# Source code for easydel.__init__.modules.siglip.modeling_siglip_flax

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from functools import partial

import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
from jax import image as jimg

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
	FlaxBaseModelOutput,
	FlaxBaseModelOutputWithPooling,
	FlaxImageClassifierOutput,
	ModelOutput,
)
from easydel.infra.utils import ACT2FN, control_mlp_sharding
from easydel.layers.attention import FlaxAttentionModule, FlexibleAttentionModule
from easydel.utils import traversals as etr

from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


@etr.auto_pytree
class SiglipVisionModelOutput(ModelOutput):
	"""Pytree record for vision-side results.

	Not referenced by any other definition visible in this file. The field
	descriptions below follow the field names and presumably mirror the
	Hugging Face SigLIP vision output — confirm against callers.
	"""

	# presumably the pooled/projected image embedding — TODO confirm
	image_embeds: tp.Optional[chex.Array] = None
	# presumably the final encoder hidden states — TODO confirm
	last_hidden_state: chex.Array = None
	# optional per-layer hidden states
	hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
	# optional per-layer attention weights
	attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None


@etr.auto_pytree
class SiglipTextModelOutput(ModelOutput):
	"""Pytree record for text-side results.

	Not referenced by any other definition visible in this file. The field
	descriptions below follow the field names and presumably mirror the
	Hugging Face SigLIP text output — confirm against callers.
	"""

	# presumably the pooled/projected text embedding — TODO confirm
	text_embeds: tp.Optional[chex.Array] = None
	# presumably the final encoder hidden states — TODO confirm
	last_hidden_state: chex.Array = None
	# optional per-layer hidden states
	hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
	# optional per-layer attention weights
	attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None


@etr.auto_pytree
class SiglipOutput(ModelOutput):
	"""Joint text/image result record returned by ``SiglipModel.__call__``.

	``loss`` is only populated when the caller asked for it; the two
	``*_model_output`` fields hold the raw outputs of the text and vision towers.
	"""

	loss: tp.Optional[chex.Array] = None
	logits_per_image: chex.Array = None
	logits_per_text: chex.Array = None
	text_embeds: chex.Array = None
	image_embeds: chex.Array = None
	text_model_output: FlaxBaseModelOutputWithPooling = None
	vision_model_output: FlaxBaseModelOutputWithPooling = None

	def to_tuple(self) -> tp.Tuple[tp.Any]:
		"""Flatten the record into a tuple, in ``self.keys()`` order.

		The two nested tower outputs are themselves converted with their own
		``to_tuple()``; every other entry is taken as-is.
		"""
		nested_fields = ("text_model_output", "vision_model_output")
		flattened = []
		for key in self.keys():
			if key in nested_fields:
				flattened.append(getattr(self, key).to_tuple())
			else:
				flattened.append(self[key])
		return tuple(flattened)


class SiglipVisionEmbeddings(nn.Module):
	"""Patch embeddings plus learned position embeddings for the vision tower.

	A convolution whose stride equals its kernel size (``patch_size``) with
	``VALID`` padding turns the image into non-overlapping patch vectors; one
	learned position vector per patch is then added.
	"""

	def __init__(
		self,
		config: SiglipVisionConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the position table and the patch convolution.

		Args:
			config: Vision config; reads ``hidden_size``, ``image_size``,
				``patch_size`` and ``num_channels``.
			dtype: Computation dtype of the sub-layers.
			param_dtype: Storage dtype of the parameters.
			precision: Precision passed to the patch convolution.
			rngs: NNX random streams used for parameter initialisation.
		"""
		self.config = config
		self.embed_dim = config.hidden_size
		self.image_size = config.image_size
		self.patch_size = config.patch_size
		self.num_patches = (self.image_size // self.patch_size) ** 2
		self.num_positions = self.num_patches
		self.position_embedding = nn.Embed(
			self.num_positions,
			self.embed_dim,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.patch_embedding = nn.Conv(
			in_features=config.num_channels,
			out_features=self.embed_dim,
			kernel_size=(self.patch_size, self.patch_size),
			strides=(self.patch_size, self.patch_size),
			padding="VALID",
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
			precision=precision,
		)

	def interpolate(self, embeddings: chex.Array, height: int, width: int):
		"""Return position embeddings matching the patch grid of the input image.

		Args:
			embeddings: Patch embeddings; axis 1 is the number of patches and
				the last axis is the embedding width.
			height: Input image height in pixels.
			width: Input image width in pixels.

		Returns:
			An array with a leading axis of 1. When the patch count equals the
			size of the learned table and the image is square, this is the
			learned table itself; otherwise the table is viewed as a square
			grid and resized with cubic interpolation to
			``(height // patch_size, width // patch_size)``.
		"""
		num_patches = embeddings.shape[1]
		# `nn.Embed` keeps its table under `.embedding` (there is no `.weight`).
		num_positions = self.position_embedding.embedding.shape[0]
		if num_patches == num_positions and height == width:
			return self.position_embedding(
				jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
			)

		dim = embeddings.shape[-1]
		new_height = height // self.patch_size
		new_width = width // self.patch_size
		sqrt_num_positions = int(num_positions**0.5)

		# JAX arrays have no `.unsqueeze`; the reshape adds the leading axis directly.
		patch_pos_embed = jnp.reshape(
			self.position_embedding.embedding.value,
			(1, sqrt_num_positions, sqrt_num_positions, dim),
		)
		# Channels-first so the two trailing axes are the ones being resized.
		patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 3, 1, 2))
		patch_pos_embed = jimg.resize(
			patch_pos_embed,
			(1, dim, new_height, new_width),
			method="cubic",
		)
		# Back to channels-last, then flatten the grid into a sequence.
		return jnp.reshape(jnp.transpose(patch_pos_embed, (0, 2, 3, 1)), (1, -1, dim))

	def __call__(self, pixel_values: chex.Array, interpolate_pos_encoding=False):
		"""Embed an image batch into a sequence of patch vectors.

		Args:
			pixel_values: 4-D channels-first batch ``(batch, channels, height, width)``.
			interpolate_pos_encoding: When true, resize the learned position
				table to the input's patch grid via :meth:`interpolate`
				instead of adding the table as-is.

		Returns:
			Array of shape ``(batch, num_patches, embed_dim)``.
		"""
		_, _, height, width = pixel_values.shape
		target_dtype = self.patch_embedding.kernel.dtype

		# The convolution expects channels-last input.
		pixel_values = pixel_values.transpose(0, 2, 3, 1).astype(dtype=target_dtype)
		patch_embeds = self.patch_embedding(pixel_values).transpose(0, 3, 1, 2)

		# (batch, dim, h, w) -> (batch, dim, h*w) -> (batch, h*w, dim)
		embeddings = jnp.reshape(patch_embeds, patch_embeds.shape[:2] + (-1,))
		embeddings = jnp.transpose(embeddings, (0, 2, 1))
		if interpolate_pos_encoding:
			# The method is named `interpolate`; calling the flag's name raised AttributeError.
			embeddings = embeddings + self.interpolate(embeddings, height, width)
		else:
			embeddings = embeddings + self.position_embedding(
				jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
			)
		return embeddings


class SiglipTextEmbeddings(nn.Module):
	"""Token embeddings plus learned absolute position embeddings for text."""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the token table and the position table.

		Args:
			config: Text config; reads ``hidden_size``, ``vocab_size`` and
				``max_position_embeddings``.
			dtype: Computation dtype of the embedding layers.
			param_dtype: Storage dtype of the embedding tables.
			precision: Accepted for signature parity; not used here.
			rngs: NNX random streams used for initialisation.
		"""
		width = config.hidden_size

		# Token table is created first, then the position table.
		self.token_embedding = nn.Embed(
			config.vocab_size,
			width,
			embedding_init=jax.nn.initializers.normal(),
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.position_embedding = nn.Embed(
			config.max_position_embeddings,
			width,
			embedding_init=jax.nn.initializers.normal(),
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
	) -> chex.Array:
		"""Return token embeddings with position embeddings added.

		Args:
			input_ids: Token ids; the last axis is the sequence axis.
			position_ids: Position indices; defaults to ``0..seq_length-1``
				with a leading axis of 1.
			inputs_embeds: Pre-computed token embeddings, used when
				``input_ids`` is ``None`` for the length and when given
				instead of a table lookup.

		Raises:
			ValueError: If the sequence is longer than the position table.
		"""
		if input_ids is None:
			seq_length = inputs_embeds.shape[-2]
		else:
			seq_length = input_ids.shape[-1]

		max_position_embedding = self.position_embedding.embedding.shape[0]
		if seq_length > max_position_embedding:
			raise ValueError(
				f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
				f"{seq_length} and max_position_embeddings: {max_position_embedding}"
			)

		if inputs_embeds is None:
			inputs_embeds = self.token_embedding(input_ids)
		if position_ids is None:
			position_ids = jnp.arange(seq_length, dtype="i4").reshape(1, -1)

		return inputs_embeds + self.position_embedding(position_ids)


class SiglipAttention(FlaxAttentionModule):
	"""Non-causal multi-head self-attention used by the SigLIP encoder layers.

	Queries, keys and values are separate linear projections of the same input;
	the attention itself is delegated to ``FlexibleAttentionModule``.
	"""

	def __init__(
		self,
		config,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the four projections and the attention backend.

		Args:
			config: Reads ``hidden_size``, ``num_attention_heads`` and
				``attention_dropout``.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to the linear layers.
			rngs: NNX random streams.

		Raises:
			ValueError: If ``hidden_size`` is not divisible by the head count.
		"""
		super().__init__(config=config)
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs
		self.embed_dim = config.hidden_size
		self.num_heads = config.num_attention_heads
		self.head_dim = self.embed_dim // self.num_heads
		if self.head_dim * self.num_heads != self.embed_dim:
			raise ValueError(
				f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
				f" {self.num_heads})."
			)

		self.dropout = config.attention_dropout
		linear_class = partial(
			nn.Linear,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
			kernel_init=jax.nn.initializers.normal(0.01),
		)
		self.k_proj = linear_class(self.embed_dim, self.embed_dim)
		self.v_proj = linear_class(self.embed_dim, self.embed_dim)
		self.q_proj = linear_class(self.embed_dim, self.embed_dim)
		self.out_proj = linear_class(self.embed_dim, self.embed_dim)

		self.causal = False
		self.attention_performer = FlexibleAttentionModule(
			base_config=config,
			softmax_scale=self.head_dim**-0.5,
			dropout_prob=config.attention_dropout,
		)

	def _split_heads(self, hidden_states):
		"""Reshape ``(batch, seq, embed_dim)`` to ``(batch, seq, heads, head_dim)``."""
		return hidden_states.reshape(
			hidden_states.shape[:2] + (self.num_heads, self.head_dim)
		)

	def _merge_heads(self, hidden_states):
		"""Inverse of :meth:`_split_heads`: fold the head axes back into ``embed_dim``."""
		return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
	):
		"""Run self-attention over ``hidden_states``.

		Args:
			hidden_states: Input of shape ``(batch, seq, embed_dim)``.
			attention_mask: Optional mask; entries ``> 0`` are kept. A 2-D mask
				gets two broadcast axes inserted before its last axis.
			output_attentions: Whether to return the attention weights.

		Returns:
			``(attn_output, attention_weights)``; the second element is ``None``
			unless ``output_attentions`` is true.

		Raises:
			NotImplementedError: If ``self.causal`` has been switched on.
		"""
		# Causal masking is not supported, so no causal mask ever needs merging.
		if self.causal:
			raise NotImplementedError()

		query = self._split_heads(self.q_proj(hidden_states))
		key = self._split_heads(self.k_proj(hidden_states))
		value = self._split_heads(self.v_proj(hidden_states))

		# The mask is handed to the backend as an additive bias: 0 where kept,
		# the most negative finite value of the compute dtype where masked.
		attention_bias = None
		if attention_mask is not None:
			if attention_mask.ndim == 2:
				attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
			attention_bias = jax.lax.select(
				attention_mask > 0,
				jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
				jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
			)

		attentions = self.attention_performer.forward(
			query_states=query,
			key_states=key,
			value_states=value,
			bias=None,
			init_bias=lambda: attention_bias,
			# Masking is fully expressed through the bias above.
			attention_mask=None,
			segment_ids=None,
			causal=self.causal,
			dropout_rng=self.rngs.params(),
		)

		attn_output = self.out_proj(self._merge_heads(attentions.attention_outputs))
		return (
			attn_output,
			attentions.attention_weights if output_attentions else None,
		)


class SiglipMLP(nn.Module):
	"""Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create both biased linear layers and look up the activation.

		Args:
			config: Reads ``hidden_act``, ``hidden_size``, ``intermediate_size``
				(and ``partition_axis`` at call time).
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to the linear layers.
			rngs: NNX random streams.
		"""
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs
		self.activation_fn = ACT2FN[config.hidden_act]

		# Both layers share every option except their in/out widths.
		shared_options = dict(
			use_bias=True,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
			kernel_init=jax.nn.initializers.normal(0.01),
		)
		self.fc1 = nn.Linear(
			config.hidden_size, config.intermediate_size, **shared_options
		)
		self.fc2 = nn.Linear(
			config.intermediate_size, config.hidden_size, **shared_options
		)

	def __call__(self, hidden_states: chex.Array) -> chex.Array:
		"""Apply the feed-forward block; output width equals ``hidden_size``."""
		sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
		expanded = self.fc1(sharded)
		activated = self.activation_fn(expanded)
		return self.fc2(activated)


class SiglipEncoderLayer(nn.Module):
	"""Pre-norm transformer layer: LayerNorm -> attention and LayerNorm -> MLP,
	each wrapped in a residual connection."""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the attention block, the MLP and their two layer norms.

		Args:
			config: Reads ``hidden_size`` and ``layer_norm_eps``; also passed
				on to the attention and MLP sub-modules.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to attention and MLP.
			rngs: NNX random streams.
		"""
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs

		block_options = dict(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		norm_options = dict(
			epsilon=config.layer_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		# Construction order is kept: attention, norm1, mlp, norm2.
		self.self_attn = SiglipAttention(**block_options)
		self.layer_norm1 = nn.LayerNorm(config.hidden_size, **norm_options)
		self.mlp = SiglipMLP(**block_options)
		self.layer_norm2 = nn.LayerNorm(config.hidden_size, **norm_options)

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
	):
		"""Apply the layer.

		Returns:
			``(hidden_states, *rest)`` where ``rest`` is whatever the attention
			block returned after its first element (its attention weights slot).
		"""
		# Attention sub-block with residual.
		attn_outputs = self.self_attn(
			hidden_states=self.layer_norm1(hidden_states),
			attention_mask=attention_mask,
			output_attentions=output_attentions,
		)
		hidden_states = hidden_states + attn_outputs[0]

		# Feed-forward sub-block with residual.
		hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

		return (hidden_states, *attn_outputs[1:])


class SiglipEncoder(nn.Module):
	"""Stack of ``config.num_hidden_layers`` :class:`SiglipEncoderLayer` blocks."""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the layer stack.

		Args:
			config: Reads ``num_hidden_layers``; passed on to every layer.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to every layer.
			rngs: NNX random streams.
		"""
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs
		self.layers = [
			SiglipEncoderLayer(
				config=config,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)
			for _ in range(config.num_hidden_layers)
		]

	def __call__(
		self,
		inputs_embeds: chex.Array,
		attention_mask: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
		output_hidden_states: bool = False,
		return_dict: bool = True,
	):
		"""Run the inputs through every layer in order.

		Args:
			inputs_embeds: Embedded input sequence.
			attention_mask: Optional mask forwarded to each layer.
			output_attentions: Collect each layer's attention weights.
			output_hidden_states: Collect the input of every layer plus the
				final output.
			return_dict: Return a ``FlaxBaseModelOutput`` instead of a tuple.

		Returns:
			A ``FlaxBaseModelOutput``, or — when ``return_dict`` is false — a
			tuple of the non-``None`` entries among
			``(last_hidden_state, hidden_states, attentions)``.
		"""
		hidden_states = inputs_embeds
		all_attentions = () if output_attentions else None
		all_hidden_states = () if output_hidden_states else None

		for layer in self.layers:
			if output_hidden_states:
				all_hidden_states += (hidden_states,)

			layer_outputs = layer(
				hidden_states=hidden_states,
				attention_mask=attention_mask,
				output_attentions=output_attentions,
			)
			hidden_states = layer_outputs[0]

			if output_attentions:
				all_attentions += (layer_outputs[1],)

		if output_hidden_states:
			all_hidden_states += (hidden_states,)

		if not return_dict:
			# Include the collected extras; previously only the last hidden
			# state was returned, so callers slicing `[1:]` lost them.
			outputs = (hidden_states, all_hidden_states, all_attentions)
			return tuple(v for v in outputs if v is not None)

		return FlaxBaseModelOutput(
			last_hidden_state=hidden_states,
			hidden_states=all_hidden_states,
			attentions=all_attentions,
		)


class SiglipTextTransformer(EasyDeLBaseModule):
	"""Text tower: embeddings -> encoder -> final LayerNorm -> projection head.

	The pooled output is the head applied to the last sequence position.
	"""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the text tower.

		Args:
			config: Reads ``hidden_size``, ``layer_norm_eps`` and
				``projection_size``; passed on to the sub-modules.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to sub-modules.
			rngs: NNX random streams.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		submodule_options = dict(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.embeddings = SiglipTextEmbeddings(**submodule_options)
		self.encoder = SiglipEncoder(**submodule_options)

		self.final_layer_norm = nn.LayerNorm(
			config.hidden_size,
			epsilon=config.layer_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.head = nn.Linear(
			config.hidden_size,
			config.projection_size,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: chex.Array,
		attention_mask: chex.Array,
		position_ids: chex.Array,
		output_attentions: bool = False,
		output_hidden_states: bool = False,
		return_dict: bool = True,
	):
		"""Encode token ids.

		Any of the three flags may be passed as ``None`` to fall back to the
		corresponding config value.

		Returns:
			``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true,
			otherwise ``(last_hidden_state, pooled_output, *encoder_extras)``.
		"""
		if output_attentions is None:
			output_attentions = self.config.output_attentions
		if output_hidden_states is None:
			output_hidden_states = self.config.output_hidden_states
		if return_dict is None:
			return_dict = self.config.use_return_dict

		# Collapse any leading axes so ids are (batch, seq).
		flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
		embedded = self.embeddings(input_ids=flat_ids, position_ids=position_ids)

		encoder_outputs = self.encoder(
			inputs_embeds=embedded,
			attention_mask=attention_mask,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
		)

		last_hidden_state = self.final_layer_norm(encoder_outputs[0])
		# Pool by taking the last sequence position, then project.
		pooled_output = self.head(last_hidden_state[:, -1, :])

		if return_dict:
			return FlaxBaseModelOutputWithPooling(
				last_hidden_state=last_hidden_state,
				pooler_output=pooled_output,
				hidden_states=encoder_outputs.hidden_states,
				attentions=encoder_outputs.attentions,
			)
		return (last_hidden_state, pooled_output) + encoder_outputs[1:]


@register_module(
	TaskType.BASE_MODULE,
	config=SiglipTextConfig,
	model_type="siglip_text_model",
)
class SiglipTextModel(EasyDeLBaseModule):
	"""Registered wrapper that exposes :class:`SiglipTextTransformer` as ``text_model``."""

	def __init__(
		self,
		config: SiglipTextConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the wrapped text transformer.

		Args:
			config: Text configuration, forwarded unchanged.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to the transformer.
			rngs: NNX random streams.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.text_model = SiglipTextTransformer(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
	) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
		"""Forward every argument to the wrapped text transformer.

		``return_dict`` falls back to ``config.use_return_dict`` when ``None``;
		the other two flags are resolved by the transformer itself.
		"""
		return self.text_model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict
			if return_dict is not None
			else self.config.use_return_dict,
		)


class SiglipVisionTransformer(EasyDeLBaseModule):
	"""Vision tower: patch embeddings -> encoder -> LayerNorm -> optional pooling head."""

	def __init__(
		self,
		config: SiglipVisionConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the vision tower.

		Args:
			config: Vision configuration; reads ``hidden_size``,
				``layer_norm_eps`` and, when present, ``vision_use_head``.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to sub-modules.
			rngs: NNX random streams.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		embed_dim = config.hidden_size
		self.embeddings = SiglipVisionEmbeddings(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.encoder = SiglipEncoder(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.post_layernorm = nn.LayerNorm(
			embed_dim,
			epsilon=config.layer_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		# The pooling head is on unless the config explicitly says otherwise.
		self.use_head = getattr(config, "vision_use_head", True)
		if self.use_head:
			self.head = SiglipMultiheadAttentionPoolingHead(
				config=config,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)

	def __call__(
		self,
		pixel_values,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
		interpolate_pos_encoding: tp.Optional[bool] = False,
	) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
		"""Encode an image batch.

		Flags passed as ``None`` fall back to the corresponding config value.

		Returns:
			``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true,
			otherwise ``(last_hidden_state, pooler_output, *encoder_extras)``.
			``pooler_output`` is ``None`` when the pooling head is disabled.
		"""
		output_attentions = (
			output_attentions
			if output_attentions is not None
			else self.config.output_attentions
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		return_dict = (
			return_dict if return_dict is not None else self.config.use_return_dict
		)

		hidden_states = self.embeddings(
			pixel_values,
			interpolate_pos_encoding=interpolate_pos_encoding,
		)
		encoder_outputs = self.encoder(
			inputs_embeds=hidden_states,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
		)

		last_hidden_state = encoder_outputs[0]
		last_hidden_state = self.post_layernorm(last_hidden_state)
		pooler_output = self.head(last_hidden_state) if self.use_head else None

		if not return_dict:
			return (last_hidden_state, pooler_output) + encoder_outputs[1:]

		return FlaxBaseModelOutputWithPooling(
			last_hidden_state=last_hidden_state,
			pooler_output=pooler_output,
			hidden_states=encoder_outputs.hidden_states,
			attentions=encoder_outputs.attentions,
		)


class MultiheadAttention(nn.Module):
	"""Minimal multi-head attention with a fused q/k/v input projection."""

	def __init__(
		self,
		embed_dim,
		num_heads,
		bias=True,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the fused input projection and the output projection.

		Args:
			embed_dim: Model width; must be positive and divisible by ``num_heads``.
			num_heads: Number of attention heads; must be positive.
			bias: Whether ``out_proj`` has a bias (the fused input bias is
				always created).
			dtype: Computation dtype of ``out_proj``.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to ``out_proj``.
			rngs: NNX random streams.

		Raises:
			ValueError: On non-positive sizes or an indivisible ``embed_dim``.
		"""
		if embed_dim <= 0 or num_heads <= 0:
			raise ValueError(
				f"embed_dim and num_heads must be greater than 0,"
				f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
			)
		self.embed_dim = embed_dim
		self.num_heads = num_heads
		self.head_dim = embed_dim // num_heads
		# Raise rather than assert: asserts are stripped under `python -O`.
		if self.head_dim * num_heads != self.embed_dim:
			raise ValueError("embed_dim must be divisible by num_heads")

		def normal_init(*shape):
			# NOTE(review): other modules in this file draw from `rngs.params()`;
			# confirm that the `param` stream name is intended here.
			return nn.initializers.xavier_uniform()(rngs.param(), shape, param_dtype)

		def ze_init(*shape):
			return jnp.zeros(shape, param_dtype)

		# Stored as (3 * embed_dim, embed_dim), i.e. (out, in) per projection.
		self.in_proj_weight = nn.Param(normal_init(embed_dim * 3, embed_dim))
		self.in_proj_bias = nn.Param(ze_init(3 * embed_dim))
		self.out_proj = nn.Linear(
			embed_dim,
			embed_dim,
			use_bias=bias,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(
		self,
		query: chex.Array,
		key: chex.Array,
		value: chex.Array,
	):
		"""Attend from ``query`` over ``key``/``value``.

		Args:
			query: ``(batch, q_len, embed_dim)``.
			key: ``(batch, kv_len, embed_dim)``.
			value: ``(batch, kv_len, embed_dim)``.

		Returns:
			``(batch, q_len, embed_dim)`` after the output projection.
		"""
		qbs, qss, qds = query.shape
		b, s, _ = value.shape

		in_proj_weight = self.in_proj_weight.value
		# The parameter is created as (3E, E). Splitting that on the last axis
		# and right-multiplying cannot match shapes, so bring it to (E, 3E)
		# first. A weight already stored as (E, 3E) is used unchanged.
		if in_proj_weight.shape[0] == 3 * self.embed_dim:
			in_proj_weight = in_proj_weight.T

		qb, kb, vb = jnp.split(self.in_proj_bias, 3, -1)
		qw, kw, vw = jnp.split(in_proj_weight, 3, -1)

		qout = ((query @ qw) + qb).reshape(qbs, qss, self.num_heads, -1)
		kout = ((key @ kw) + kb).reshape(b, s, self.num_heads, -1)
		vout = ((value @ vw) + vb).reshape(b, s, self.num_heads, -1)

		# Scaled dot-product scores, softmax over the key axis.
		scores = jnp.einsum(
			"bqhd,bkhd->bhqk",
			qout * (qout.shape[-1] ** -0.5),
			kout,
		)
		attn = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores), vout)
		return self.out_proj(attn.reshape(qbs, qss, qds))


class SiglipMultiheadAttentionPoolingHead(nn.Module):
	"""Pools a sequence into one vector by attending from a learned probe token."""

	def __init__(
		self,
		config: SiglipVisionConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the probe, the attention block, the layer norm and the MLP.

		Args:
			config: Reads ``hidden_size``, ``num_attention_heads`` and
				``layer_norm_eps``; also passed to the MLP.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to attention and MLP.
			rngs: NNX random streams.
		"""
		self.probe = nn.Param(
			jax.random.normal(
				rngs.param(),
				(1, 1, config.hidden_size),
				param_dtype,
			)
		)
		self.attention = MultiheadAttention(
			config.hidden_size,
			config.num_attention_heads,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.layernorm = nn.LayerNorm(
			config.hidden_size,
			epsilon=config.layer_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.mlp = SiglipMLP(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(self, hidden_state):
		"""Return one pooled vector per batch element, ``(batch, hidden_size)``."""
		batch_size = hidden_state.shape[0]
		# One copy of the probe per batch element acts as the single query.
		probe = self.probe.value.repeat(batch_size, 0)

		hidden_state = self.attention(probe, hidden_state, hidden_state)

		residual = hidden_state
		hidden_state = self.layernorm(hidden_state)
		hidden_state = residual + self.mlp(hidden_state)

		# The query length is 1, so position 0 is the pooled vector.
		return hidden_state[:, 0]


@register_module(
	TaskType.BASE_VISION,
	config=SiglipVisionConfig,
	model_type="siglip_vision_model",
)
class SiglipVisionModel(EasyDeLBaseModule):
	"""Registered wrapper that exposes :class:`SiglipVisionTransformer` as ``vision_model``.

	Derives from ``EasyDeLBaseModule`` and stores its config, matching
	``SiglipTextModel`` and the other registered modules in this file.
	"""

	def __init__(
		self,
		config: SiglipVisionConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the wrapped vision transformer.

		Args:
			config: Vision configuration, forwarded unchanged.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to the transformer.
			rngs: NNX random streams.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.vision_model = SiglipVisionTransformer(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

	def __call__(
		self,
		pixel_values,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
		interpolate_pos_encoding: bool = False,
	) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
		"""Forward every argument unchanged to the wrapped vision transformer."""
		return self.vision_model(
			pixel_values=pixel_values,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
			interpolate_pos_encoding=interpolate_pos_encoding,
		)


@register_module(
	TaskType.BASE_MODULE,
	config=SiglipConfig,
	model_type="siglip",
)
class SiglipModel(EasyDeLBaseModule):
	"""Dual-tower SigLIP model producing pairwise text/image logits.

	Holds a text transformer, a vision transformer and two learned scalars
	(``logit_scale`` and ``logit_bias``) applied to the cosine similarities.
	"""

	def __init__(
		self,
		config: SiglipConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build both towers and the logit scale/bias parameters.

		Args:
			config: Must carry a ``SiglipTextConfig`` as ``text_config`` and a
				``SiglipVisionConfig`` as ``vision_config``.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to both towers.
			rngs: NNX random streams.

		Raises:
			TypeError: If either sub-config has the wrong type.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)

		if not isinstance(config.text_config, SiglipTextConfig):
			raise TypeError(
				"config.text_config is expected to be of type SiglipTextConfig but is of type"
				f" {type(config.text_config)}."
			)

		if not isinstance(config.vision_config, SiglipVisionConfig):
			raise TypeError(
				"config.vision_config is expected to be of type SiglipVisionConfig but is of type"
				f" {type(config.vision_config)}."
			)

		text_config = config.text_config
		vision_config = config.vision_config

		text_model = SiglipTextModel(
			text_config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		vision_model = SiglipVisionModel(
			vision_config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		# Only the inner transformers are kept; the wrappers are discarded.
		self.text_model = text_model.text_model
		self.vision_model = vision_model.vision_model

		self.logit_scale = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))
		self.logit_bias = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))

	def get_text_features(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
	) -> chex.Array:
		"""Return the text tower's pooled output (element 1 of its result).

		Flags passed as ``None`` fall back to the corresponding config value.
		"""
		output_attentions = (
			output_attentions
			if output_attentions is not None
			else self.config.output_attentions
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		return_dict = (
			return_dict if return_dict is not None else self.config.use_return_dict
		)

		text_outputs = self.text_model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
		)

		pooled_output = text_outputs[1]
		return pooled_output

	def get_image_features(
		self,
		pixel_values: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
		interpolate_pos_encoding: bool = False,
	) -> chex.Array:
		"""Return the vision tower's pooled output (element 1 of its result).

		Flags passed as ``None`` fall back to the corresponding config value.
		"""
		output_attentions = (
			output_attentions
			if output_attentions is not None
			else self.config.output_attentions
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		return_dict = (
			return_dict if return_dict is not None else self.config.use_return_dict
		)

		vision_outputs = self.vision_model(
			pixel_values=pixel_values,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
			interpolate_pos_encoding=interpolate_pos_encoding,
		)

		pooled_output = vision_outputs[1]
		return pooled_output

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		pixel_values: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		return_loss: tp.Optional[bool] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
		interpolate_pos_encoding: bool = False,
	) -> tp.Union[tp.Tuple, SiglipOutput]:
		"""Score every text against every image.

		Args:
			input_ids: Token ids for the text tower.
			pixel_values: Image batch for the vision tower.
			attention_mask: Optional text attention mask.
			position_ids: Optional text position ids.
			return_loss: When truthy, also compute the pairwise sigmoid loss.
			output_attentions: Forwarded to both towers (``None`` -> config).
			output_hidden_states: Forwarded to both towers (``None`` -> config).
			return_dict: Return ``SiglipOutput`` instead of a tuple
				(``None`` -> config).
			interpolate_pos_encoding: Forwarded to the vision tower.

		Returns:
			``SiglipOutput``, or a tuple ``(logits_per_image, logits_per_text,
			text_embeds, image_embeds, text_outputs, vision_outputs)`` with the
			loss prepended when it was computed.
		"""
		output_attentions = (
			output_attentions
			if output_attentions is not None
			else self.config.output_attentions
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		return_dict = (
			return_dict if return_dict is not None else self.config.use_return_dict
		)

		vision_outputs = self.vision_model(
			pixel_values=pixel_values,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
			interpolate_pos_encoding=interpolate_pos_encoding,
		)

		text_outputs = self.text_model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
		)

		image_embeds = vision_outputs[1]
		text_embeds = text_outputs[1]

		# normalized features
		image_embeds = image_embeds / jnp.linalg.norm(
			image_embeds,
			ord=2,
			axis=-1,
			keepdims=True,
		)
		text_embeds = text_embeds / jnp.linalg.norm(
			text_embeds,
			ord=2,
			axis=-1,
			keepdims=True,
		)

		# cosine similarity as logits
		logits_per_text = jnp.matmul(text_embeds, image_embeds.T)

		logit_scale, logit_bias = (self.logit_scale, self.logit_bias)
		logits_per_text = logits_per_text * jnp.exp(logit_scale) + logit_bias

		logits_per_image = logits_per_text.T

		loss = None
		if return_loss:
			# +1 on the diagonal (matching pairs), -1 everywhere else.
			m1_diag1 = -jnp.ones_like(logits_per_text) + 2 * jnp.eye(
				logits_per_text.shape[0]
			)
			loglik = jax.nn.log_sigmoid(m1_diag1 * logits_per_text)
			nll = -jnp.sum(loglik, axis=-1)
			loss = nll.mean()

		if not return_dict:
			output = (
				logits_per_image,
				logits_per_text,
				text_embeds,
				image_embeds,
				text_outputs,
				vision_outputs,
			)
			return ((loss,) + output) if loss is not None else output

		return SiglipOutput(
			loss=loss,
			logits_per_image=logits_per_image,
			logits_per_text=logits_per_text,
			text_embeds=text_embeds,
			image_embeds=image_embeds,
			text_model_output=text_outputs,
			vision_model_output=vision_outputs,
		)


@register_module(
	TaskType.IMAGE_CLASSIFICATION,
	config=SiglipConfig,
	model_type="siglip",
)
class SiglipForImageClassification(EasyDeLBaseModule):
	"""SigLIP vision tower with a mean-pooled linear classification head."""

	def __init__(
		self,
		config: SiglipConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Build the vision tower and, when ``num_labels > 0``, the classifier.

		Args:
			config: Reads ``num_labels`` and ``vision_config``.
			dtype: Computation dtype.
			param_dtype: Parameter storage dtype.
			precision: Precision forwarded to sub-modules.
			rngs: NNX random streams.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.num_labels = config.num_labels

		vision_model = SiglipVisionModel(
			config.vision_config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		# Only the inner transformer is kept; the wrapper is discarded.
		self.vision_model = vision_model.vision_model

		self.use_classif = config.num_labels > 0
		# Classifier head
		if self.use_classif:
			self.classifier = nn.Linear(
				config.vision_config.hidden_size,
				config.num_labels,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)

	def __call__(
		self,
		pixel_values: tp.Optional[chex.Array] = None,
		labels: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		return_dict: tp.Optional[bool] = None,
		interpolate_pos_encoding: bool = False,
	) -> tp.Union[tuple, FlaxImageClassifierOutput]:
		"""Classify an image batch.

		Args:
			pixel_values: Image batch for the vision tower.
			labels: Accepted for signature parity; not used — no loss is
				computed here.
			output_attentions: Forwarded to the vision tower (``None`` -> config).
			output_hidden_states: Forwarded to the vision tower (``None`` -> config).
			return_dict: Return ``FlaxImageClassifierOutput`` instead of a
				tuple (``None`` -> config).
			interpolate_pos_encoding: Forwarded to the vision tower.

		Returns:
			Logits (the mean-pooled features themselves when no classifier
			exists) plus any hidden states / attentions from the tower.
		"""
		output_attentions = (
			output_attentions
			if output_attentions is not None
			else self.config.output_attentions
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		return_dict = (
			return_dict if return_dict is not None else self.config.use_return_dict
		)

		outputs = self.vision_model(
			pixel_values,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			return_dict=return_dict,
			interpolate_pos_encoding=interpolate_pos_encoding,
		)

		sequence_output = outputs[0]
		# Average over the patch axis to get one vector per image.
		logits = jnp.mean(sequence_output, axis=1)
		if self.use_classif:
			logits = self.classifier(logits)

		if not return_dict:
			# Skip last_hidden_state and pooler_output; keep the extras.
			output = (logits,) + outputs[2:]
			return output

		return FlaxImageClassifierOutput(
			logits=logits,
			hidden_states=outputs.hidden_states,
			attentions=outputs.attentions,
		)