# Source code for easydel.modules.siglip.modeling_siglip_flax

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from functools import partial

import chex
import jax
import jax.numpy as jnp
from eformer.pytree import auto_pytree
from flax import nnx as nn
from jax import image as jimg

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
	BaseModelOutput,
	BaseModelOutputWithPooling,
	ImageClassifierOutput,
	ModelOutput,
)
from easydel.infra.utils import ACT2FN, control_mlp_sharding
from easydel.layers.attention import AttentionModule, FlexibleAttentionModule
from easydel.layers.linear import ParallelLinear

from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


@auto_pytree
class SiglipVisionModelOutput(ModelOutput):
    """Output container for a SigLIP vision forward pass.

    Every field defaults to ``None`` so the container can be built with only
    the entries a given forward pass produced.
    """

    # Image embeddings (optional).
    image_embeds: tp.Optional[chex.Array] = None
    # Hidden states of the last encoder layer; defaults to None like the other fields.
    last_hidden_state: tp.Optional[chex.Array] = None
    # Tuple of hidden states, populated only when requested by the caller.
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    # Tuple of attention weights, populated only when requested by the caller.
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@auto_pytree
class SiglipTextModelOutput(ModelOutput):
    """Output container for a SigLIP text forward pass.

    Every field defaults to ``None`` so the container can be built with only
    the entries a given forward pass produced.
    """

    # Text embeddings (optional).
    text_embeds: tp.Optional[chex.Array] = None
    # Hidden states of the last encoder layer; defaults to None like the other fields.
    last_hidden_state: tp.Optional[chex.Array] = None
    # Tuple of hidden states, populated only when requested by the caller.
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    # Tuple of attention weights, populated only when requested by the caller.
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@auto_pytree
class SiglipOutput(ModelOutput):
    """Result of a joint SigLIP text/image forward pass (built by ``SiglipModel``).

    Holds the optional sigmoid loss, the text/image similarity logits in both
    orientations, the normalized embeddings, and the raw outputs of the two
    sub-models.
    """

    loss: tp.Optional[chex.Array] = None
    logits_per_image: chex.Array = None
    logits_per_text: chex.Array = None
    text_embeds: chex.Array = None
    image_embeds: chex.Array = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tp.Tuple[tp.Any]:
        """Return the populated fields as a tuple, in ``keys()`` order.

        The two nested sub-model outputs are flattened with their own
        ``to_tuple`` instead of being returned as objects.
        """
        nested_outputs = ("text_model_output", "vision_model_output")
        flattened = []
        for field_name in self.keys():
            if field_name in nested_outputs:
                flattened.append(getattr(self, field_name).to_tuple())
            else:
                flattened.append(self[field_name])
        return tuple(flattened)
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus learned position embedding for SigLIP images.

    A strided convolution (kernel = stride = ``patch_size``, ``VALID`` padding)
    cuts the image into non-overlapping patches. One learned position vector
    per patch is then added.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embed(
            self.num_positions,
            self.embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.patch_embedding = nn.Conv(
            in_features=config.num_channels,
            out_features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            precision=precision,
        )

    def _default_position_embeddings(self) -> chex.Array:
        """Look up the position embedding for every patch index (shape ``(1, P, D)``)."""
        return self.position_embedding(
            jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
        )

    def interpolate(self, embeddings: chex.Array, height: int, width: int):
        """Return position embeddings matching an input of ``height`` x ``width`` pixels.

        If the patch count equals the trained number of positions and the image
        is square, the stored table is used unchanged. Otherwise the square grid
        of stored position vectors is resized with cubic interpolation to
        ``(height // patch_size, width // patch_size)``.

        Args:
            embeddings: Patch embeddings of shape ``(batch, num_patches, dim)``.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings of shape ``(1, new_height * new_width, dim)``.
        """
        num_patches = embeddings.shape[1]
        # nnx.Embed keeps its table in `embedding`; there is no torch-style `.weight`.
        position_table = self.position_embedding.embedding.value
        num_positions = position_table.shape[0]
        if num_patches == num_positions and height == width:
            return self._default_position_embeddings()

        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        # num_positions == (image_size // patch_size) ** 2, so the table is a square grid.
        grid_side = int(num_positions**0.5)
        grid = jnp.reshape(position_table, (1, grid_side, grid_side, dim))
        # jax.image.resize only resamples axes whose size changes, so resizing the
        # NHWC grid directly is the same as resizing an NCHW copy.
        grid = jimg.resize(grid, (1, new_height, new_width, dim), method="cubic")
        # Match the dtype of the regular nn.Embed lookup path.
        return jnp.reshape(grid, (1, -1, dim)).astype(embeddings.dtype)

    def __call__(self, pixel_values: chex.Array, interpolate_pos_encoding=False):
        """Embed NCHW ``pixel_values`` into a ``(batch, num_patches, dim)`` sequence.

        Args:
            pixel_values: Images laid out as ``(batch, channels, height, width)``.
            interpolate_pos_encoding: When True, resize the position table to the
                input resolution instead of using it as-is.
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.kernel.dtype
        # nn.Conv is channels-last; inputs arrive channels-first.
        pixel_values = pixel_values.transpose(0, 2, 3, 1).astype(dtype=target_dtype)
        patch_embeds = self.patch_embedding(pixel_values)  # (B, H/p, W/p, D)
        # Row-major flatten of the patch grid -> (B, num_patches, D).
        embeddings = jnp.reshape(
            patch_embeds, (patch_embeds.shape[0], -1, patch_embeds.shape[-1])
        )
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate(embeddings, height, width)
        else:
            embeddings = embeddings + self._default_position_embeddings()
        return embeddings
class SiglipTextEmbeddings(nn.Module):
    """Token embedding plus learned absolute position embedding for SigLIP text."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        embed_dim = config.hidden_size
        self.token_embedding = nn.Embed(
            config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.position_embedding = nn.Embed(
            config.max_position_embeddings,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
    ) -> chex.Array:
        """Return ``token_embeddings + position_embeddings``.

        Args:
            input_ids: Token ids; the last axis is the sequence axis. Ignored for
                the token lookup when ``inputs_embeds`` is given.
            position_ids: Position ids; defaults to ``arange(seq_len)`` with a
                leading broadcast axis.
            inputs_embeds: Precomputed token embeddings whose second-to-last axis
                is the sequence axis.

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )
        max_position_embedding = self.position_embedding.embedding.shape[0]
        if seq_length > max_position_embedding:
            raise ValueError(
                "Sequence length must be less than or equal to max_position_embeddings "
                f"(got `sequence length`: {seq_length} and max_position_embeddings: "
                f"{max_position_embedding})"
            )
        if position_ids is None:
            position_ids = jnp.arange(seq_length, dtype="i4").reshape(1, -1)
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings
class SiglipAttention(AttentionModule):
    """Multi-head self-attention used inside every SigLIP encoder layer.

    q/k/v/out are ``ParallelLinear`` projections. The attention itself is
    delegated to ``FlexibleAttentionModule`` with scale ``head_dim ** -0.5``.
    Only the non-causal form is supported: ``causal`` is False, and setting it
    to True makes ``__call__`` raise ``NotImplementedError``.
    """

    def __init__(
        self,
        config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(config=config)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.dropout = config.attention_dropout

        make_projection = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        width = self.embed_dim
        # Construction order (k, v, q, out) fixes the RNG draws used for init.
        self.k_proj = make_projection(width, width)
        self.v_proj = make_projection(width, width)
        self.q_proj = make_projection(width, width)
        self.out_proj = make_projection(width, width)

        self.causal = False
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=config.attention_dropout,
        )

    def _split_heads(self, hidden_states):
        """Reshape ``(B, S, embed_dim)`` -> ``(B, S, num_heads, head_dim)``."""
        batch, length = hidden_states.shape[:2]
        return hidden_states.reshape(batch, length, self.num_heads, self.head_dim)

    def _merge_heads(self, hidden_states):
        """Reshape ``(B, S, num_heads, head_dim)`` -> ``(B, S, embed_dim)``."""
        batch, length = hidden_states.shape[:2]
        return hidden_states.reshape(batch, length, self.embed_dim)

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Attend over ``hidden_states``.

        Args:
            hidden_states: Input sequence, presumably ``(batch, seq, embed_dim)``.
            attention_mask: Optional mask; positions with value > 0 are attended.
                A 2-D mask (presumably ``(batch, kv_len)``) gets two singleton
                axes inserted at (-3, -2) before use.
            output_attentions: Accepted for API compatibility; the performer's
                attention weights are returned regardless.

        Returns:
            ``(attn_output, attention_weights)``.
        """
        query, key, value = (
            self._split_heads(projection(hidden_states))
            for projection in (self.q_proj, self.k_proj, self.v_proj)
        )

        if self.causal:
            raise NotImplementedError()

        # The mask is applied as an additive bias: 0 where allowed, dtype-min elsewhere.
        attention_bias = None
        if attention_mask is not None:
            if attention_mask.ndim == 2:
                attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = jnp.where(
                attention_mask > 0,
                jnp.asarray(0.0, dtype=self.dtype),
                jnp.asarray(jnp.finfo(self.dtype).min, dtype=self.dtype),
            )

        attentions = self.attention_performer.forward(
            query_states=query,
            key_states=key,
            value_states=value,
            bias=None,
            init_bias=lambda: attention_bias,
            attention_mask=None,
            segment_ids=None,
            causal=self.causal,
            dropout_rng=self.rngs.params(),
        )
        merged = self._merge_heads(attentions.attention_outputs)
        return self.out_proj(merged), attentions.attention_weights
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))``.

    ``fc1`` projects ``hidden_size -> intermediate_size`` and ``fc2`` projects
    back. The activation is looked up by ``config.hidden_act`` in ``ACT2FN``.
    """

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.activation_fn = ACT2FN[config.hidden_act]

        dense = partial(
            ParallelLinear,
            use_bias=True,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        self.fc1 = dense(config.hidden_size, config.intermediate_size)
        self.fc2 = dense(config.intermediate_size, config.hidden_size)

    def __call__(self, hidden_states: chex.Array) -> chex.Array:
        """Apply the sharding hint, then the feed-forward transform."""
        sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
        expanded = self.activation_fn(self.fc1(sharded))
        return self.fc2(expanded)
class SiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then MLP, each residual."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        # Submodule construction order (attn, ln1, mlp, ln2) fixes the init RNG draws.
        self.self_attn = SiglipAttention(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.layer_norm1 = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mlp = SiglipMLP(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.layer_norm2 = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Run the block.

        Returns:
            ``(hidden_states, *extra)`` where ``extra`` is everything the
            attention module returned after its first element.
        """
        attn_outputs = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        after_attention = hidden_states + attn_outputs[0]
        after_mlp = after_attention + self.mlp(self.layer_norm2(after_attention))
        return (after_mlp, *attn_outputs[1:])
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` blocks.

    Shared by the text and vision towers.
    """

    def __init__(
        self,
        config: tp.Union[SiglipTextConfig, SiglipVisionConfig],
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.layers = [
            SiglipEncoderLayer(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for _ in range(config.num_hidden_layers)
        ]

    def __call__(
        self,
        inputs_embeds: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every layer in sequence.

        Args:
            inputs_embeds: Embedded input sequence.
            attention_mask: Optional mask forwarded unchanged to every layer.
            output_attentions: Collect each layer's attention weights.
            output_hidden_states: Collect the input to every layer plus the
                final output.
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            ``BaseModelOutput``, or a tuple
            ``(last_hidden_state, [all_hidden_states], [all_attentions])`` whose
            optional entries appear only when requested. Callers slice
            ``outputs[1:]`` / ``outputs[2:]`` relying on that layout.
        """
        hidden_states = inputs_embeds
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Previously only (hidden_states,) was returned here, which silently
            # dropped the collected hidden states / attentions.
            return tuple(
                v
                for v in (hidden_states, all_hidden_states, all_attentions)
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class SiglipTextTransformer(EasyDeLBaseModule):
    """SigLIP text tower: embeddings -> encoder -> final LayerNorm -> ``head``.

    The pooled output is the normalized hidden state at the last sequence
    position (``[:, -1, :]``) projected by ``head``. NOTE(review): this
    presumably relies on the tokenizer padding to a fixed length; confirm.
    """

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_dim = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.encoder = SiglipEncoder(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.head = ParallelLinear(
            embed_dim,
            config.projection_size,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode ``input_ids``.

        ``attention_mask`` and ``position_ids`` are optional. The encoder and the
        embeddings already handle ``None``, so they now default to it. Passing
        ``None`` for any of the output flags falls back to the config value.

        Returns:
            ``BaseModelOutputWithPooling``, or the tuple
            ``(last_hidden_state, pooled_output, *encoder_extras)``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Collapse any leading axes into one batch axis.
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])
        pooled_output = self.head(last_hidden_state[:, -1, :])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipTextConfig,
    model_type="siglip_text_model",
)
class SiglipTextModel(EasyDeLBaseModule):
    """Registered wrapper exposing ``SiglipTextTransformer`` as ``text_model``."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        shared = dict(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        super().__init__(config=config, **shared)
        self.text_model = SiglipTextTransformer(config=config, **shared)

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> tp.Union[tp.Tuple, BaseModelOutputWithPooling]:
        """Forward to ``text_model``; ``return_dict=None`` falls back to the config."""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class SiglipVisionTransformer(EasyDeLBaseModule):
    """SigLIP vision tower: patch embeddings -> encoder -> post LayerNorm -> pooling head.

    The attention-pooling head is built only when ``config.vision_use_head``
    is truthy; a missing attribute counts as True. Without it,
    ``pooler_output`` is ``None``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_dim = config.hidden_size
        self.embeddings = SiglipVisionEmbeddings(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.encoder = SiglipEncoder(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.post_layernorm = nn.LayerNorm(
            embed_dim,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: tp.Optional[bool] = False,
    ) -> tp.Union[tp.Tuple, BaseModelOutputWithPooling]:
        """Encode NCHW ``pixel_values``.

        Returns:
            ``BaseModelOutputWithPooling``, or the tuple
            ``(last_hidden_state, pooler_output, *encoder_extras)``.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.post_layernorm(encoder_outputs[0])
        pooler_output = self.head(last_hidden_state) if self.use_head else None

        if not return_dict:
            return (last_hidden_state, pooler_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class MultiheadAttention(nn.Module):
    """Minimal cross-attention with a packed q/k/v input projection.

    ``in_proj_weight`` is stored as ``(embed_dim, 3 * embed_dim)``, and
    ``__call__`` right-multiplies each of the three column blocks
    (``x @ W_q``). It used to be initialized as ``(3 * embed_dim, embed_dim)``,
    which does not match that forward pass and raised a shape error on a fresh
    module. NOTE(review): confirm the checkpoint converter transposes the
    torch ``in_proj_weight`` into this layout.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Raise instead of assert so the check survives `python -O`.
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        # Columns [0:E], [E:2E], [2E:3E] are the q, k, v projections.
        self.in_proj_weight = nn.Param(
            nn.initializers.xavier_uniform()(
                rngs.param(), (embed_dim, 3 * embed_dim), param_dtype
            )
        )
        self.in_proj_bias = nn.Param(jnp.zeros((3 * embed_dim,), param_dtype))
        self.out_proj = ParallelLinear(
            embed_dim,
            embed_dim,
            use_bias=bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        query: chex.Array,
        key: chex.Array,
        value: chex.Array,
    ):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``(batch, q_len, embed_dim)``.
            key: ``(batch, kv_len, embed_dim)``.
            value: ``(batch, kv_len, embed_dim)``.

        Returns:
            ``(batch, q_len, embed_dim)`` after ``out_proj``.
        """
        q_batch, q_len, q_dim = query.shape
        k_batch, k_len, _ = key.shape
        v_batch, v_len, _ = value.shape

        q_bias, k_bias, v_bias = jnp.split(self.in_proj_bias.value, 3, axis=-1)
        q_weight, k_weight, v_weight = jnp.split(self.in_proj_weight.value, 3, axis=-1)

        q_heads = ((query @ q_weight) + q_bias).reshape(
            q_batch, q_len, self.num_heads, self.head_dim
        )
        k_heads = ((key @ k_weight) + k_bias).reshape(
            k_batch, k_len, self.num_heads, self.head_dim
        )
        v_heads = ((value @ v_weight) + v_bias).reshape(
            v_batch, v_len, self.num_heads, self.head_dim
        )

        scores = jnp.einsum(
            "bqhd,bkhd->bhqk",
            q_heads * (self.head_dim**-0.5),
            k_heads,
        )
        weights = jax.nn.softmax(scores)
        attn = jnp.einsum("bhqk,bkhd->bqhd", weights, v_heads)
        return self.out_proj(attn.reshape(q_batch, q_len, q_dim))
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Attention pooling: a learned probe token attends over the sequence.

    The attended probe gets a residual ``mlp(layernorm(.))`` update, and its
    first position is returned as the pooled vector.
    """

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        width = config.hidden_size
        # Construction order (probe, attention, layernorm, mlp) fixes the init RNG draws.
        self.probe = nn.Param(jax.random.normal(rngs.param(), (1, 1, width), param_dtype))
        self.attention = MultiheadAttention(
            width,
            config.num_attention_heads,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.layernorm = nn.LayerNorm(
            width,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mlp = SiglipMLP(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(self, hidden_state):
        """Pool ``hidden_state`` into one vector per batch element."""
        queries = jnp.repeat(self.probe.value, hidden_state.shape[0], axis=0)
        pooled = self.attention(queries, hidden_state, hidden_state)
        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]
@register_module(
    TaskType.BASE_VISION,
    config=SiglipVisionConfig,
    model_type="siglip_vision_model",
)
class SiglipVisionModel(EasyDeLBaseModule):
    """Registered wrapper exposing ``SiglipVisionTransformer`` as ``vision_model``.

    It now derives from ``EasyDeLBaseModule``, which receives config, dtype and
    rngs via ``super().__init__``, mirroring ``SiglipTextModel``. Previously it
    was a bare ``nn.Module`` with no ``config`` even though it is registered as
    a model.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.vision_model = SiglipVisionTransformer(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        pixel_values,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, BaseModelOutputWithPooling]:
        """Forward to ``vision_model``; ``None`` flags are resolved there from the config."""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipModel(EasyDeLBaseModule):
    """Joint SigLIP model: text and vision towers plus a learned logit scale/bias.

    Similarity logits are ``cos_sim * exp(logit_scale) + logit_bias``. The
    optional loss is the pairwise sigmoid loss: diagonal pairs are positives
    and all other pairs are negatives.
    """

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        text_model = SiglipTextModel(
            text_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        vision_model = SiglipVisionModel(
            vision_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Keep only the inner transformers; the wrappers are discarded.
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model

        self.logit_scale = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))
        self.logit_bias = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))

    def _resolve_output_flags(self, output_attentions, output_hidden_states, return_dict):
        """Replace any ``None`` flag with the corresponding config default."""
        config = self.config
        return (
            output_attentions if output_attentions is not None else config.output_attentions,
            output_hidden_states
            if output_hidden_states is not None
            else config.output_hidden_states,
            return_dict if return_dict is not None else config.use_return_dict,
        )

    def get_text_features(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> chex.Array:
        """Return the text tower's pooled (projected) output, not L2-normalized."""
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs[1]

    def get_image_features(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> chex.Array:
        """Return the vision tower's pooled output, not L2-normalized."""
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return vision_outputs[1]

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        pixel_values: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        return_loss: tp.Optional[bool] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, SiglipOutput]:
        """Encode both modalities and score every text against every image.

        Returns:
            ``SiglipOutput``, or the tuple ``([loss,] logits_per_image,
            logits_per_text, text_embeds, image_embeds, text_outputs,
            vision_outputs)``. ``loss`` is present only when ``return_loss``.
        """
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]

        # L2-normalize so the dot product below is a cosine similarity.
        image_embeds = image_embeds / jnp.linalg.norm(
            image_embeds, ord=2, axis=-1, keepdims=True
        )
        text_embeds = text_embeds / jnp.linalg.norm(
            text_embeds, ord=2, axis=-1, keepdims=True
        )

        # Read the parameter arrays explicitly rather than passing nnx.Param
        # objects into jnp.exp / arithmetic.
        logit_scale = self.logit_scale.value
        logit_bias = self.logit_bias.value
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T)
        logits_per_text = logits_per_text * jnp.exp(logit_scale) + logit_bias
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            # +1 on the diagonal (matching pairs), -1 elsewhere.
            m1_diag1 = -jnp.ones_like(logits_per_text) + 2 * jnp.eye(
                logits_per_text.shape[0]
            )
            loglik = jax.nn.log_sigmoid(m1_diag1 * logits_per_text)
            nll = -jnp.sum(loglik, axis=-1)
            loss = nll.mean()

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@register_module(
    TaskType.IMAGE_CLASSIFICATION,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipForImageClassification(EasyDeLBaseModule):
    """SigLIP vision tower with mean pooling and an optional linear classifier.

    When ``config.num_labels`` is 0 no classifier is built, and the
    mean-pooled features are returned as ``logits``. ``labels`` is accepted but
    not used by this implementation (no loss is computed).
    """

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.num_labels = config.num_labels
        vision_config = config.vision_config
        self.vision_model = SiglipVisionModel(
            vision_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        ).vision_model

        self.use_classif = config.num_labels > 0
        if self.use_classif:
            self.classifier = ParallelLinear(
                vision_config.hidden_size,
                config.num_labels,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        labels: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tuple, ImageClassifierOutput]:
        """Classify ``pixel_values``.

        Returns:
            ``ImageClassifierOutput``, or ``(logits, *vision_extras)`` when
            ``return_dict`` resolves to False.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Mean over the sequence (patch) axis of the last hidden state.
        pooled = jnp.mean(outputs[0], axis=1)
        logits = self.classifier(pooled) if self.use_classif else pooled

        if not return_dict:
            return (logits,) + outputs[2:]
        return ImageClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )