Source code for easydel.modules.siglip.modeling_siglip_flax

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from functools import partial

import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
from jax import image as jimg

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
	FlaxBaseModelOutput,
	FlaxBaseModelOutputWithPooling,
	FlaxImageClassifierOutput,
	ModelOutput,
)
from easydel.infra.utils import ACT2FN, control_mlp_sharding
from easydel.layers.attention import FlaxAttentionModule, FlexibleAttentionModule
from easydel.utils import traversals as etr

from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


@etr.auto_pytree
class SiglipVisionModelOutput(ModelOutput):
    """Output container for the SigLIP vision tower.

    Every field defaults to ``None``; field order is significant and must not
    be changed.
    """

    # Pooled image embeddings (optional).
    image_embeds: tp.Optional[chex.Array] = None
    # Hidden states produced by the final encoder layer.
    last_hidden_state: chex.Array = None
    # Optional tuple of intermediate hidden states.
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    # Optional tuple of attention weights.
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@etr.auto_pytree
class SiglipTextModelOutput(ModelOutput):
    """Output container for the SigLIP text tower.

    Every field defaults to ``None``; field order is significant and must not
    be changed.
    """

    # Pooled text embeddings (optional).
    text_embeds: tp.Optional[chex.Array] = None
    # Hidden states produced by the final encoder layer.
    last_hidden_state: chex.Array = None
    # Optional tuple of intermediate hidden states.
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    # Optional tuple of attention weights.
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@etr.auto_pytree
class SiglipOutput(ModelOutput):
    """Joint output of the SigLIP text/vision model.

    Holds the optional loss, both logit matrices, both embedding sets and the
    raw outputs of the two towers.
    """

    # Contrastive loss; only set when the caller asked for it.
    loss: tp.Optional[chex.Array] = None
    # Similarity logits indexed (image, text).
    logits_per_image: chex.Array = None
    # Similarity logits indexed (text, image).
    logits_per_text: chex.Array = None
    # Text embeddings used to build the logits.
    text_embeds: chex.Array = None
    # Image embeddings used to build the logits.
    image_embeds: chex.Array = None
    # Raw output of the text tower.
    text_model_output: FlaxBaseModelOutputWithPooling = None
    # Raw output of the vision tower.
    vision_model_output: FlaxBaseModelOutputWithPooling = None

    def to_tuple(self) -> tp.Tuple[tp.Any]:
        """Return the stored values as a tuple.

        The two nested tower outputs are themselves converted with their own
        ``to_tuple``; every other key is returned as stored.
        """
        nested_keys = ["text_model_output", "vision_model_output"]
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
class SiglipVisionEmbeddings(nn.Module):
    """Patch + learned position embeddings for the SigLIP vision tower.

    Images are cut into non-overlapping ``patch_size`` x ``patch_size`` patches
    by a strided convolution; a learned position embedding (one row per patch
    of the configured ``image_size``) is then added.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the patch convolution and the position-embedding table.

        Args:
            config: Vision configuration (``hidden_size``, ``image_size``,
                ``patch_size`` and ``num_channels`` are read).
            dtype: Computation dtype of the sub-layers.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the convolution.
            rngs: Random number generators used for parameter init.
        """
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embed(
            self.num_positions,
            self.embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.patch_embedding = nn.Conv(
            in_features=config.num_channels,
            out_features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            precision=precision,
        )

    def interpolate(self, embeddings: chex.Array, height: int, width: int):
        """Return position embeddings matching the patch grid of the input.

        Args:
            embeddings: Patch embeddings; axis 1 is the number of patches and
                the last axis is the embedding width.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings with a leading axis of size 1. When the patch
            count equals the table size and the image is square, the learned
            table is returned as-is; otherwise the table is bicubically
            resized to ``(height // patch_size, width // patch_size)``.
        """
        num_patches = embeddings.shape[1]
        # nnx.Embed stores its table under `.embedding` (there is no `.weight`).
        position_table = self.position_embedding.embedding.value
        num_positions = position_table.shape[0]

        if num_patches == num_positions and height == width:
            return self.position_embedding(
                jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
            )

        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        sqrt_num_positions = int(num_positions**0.5)

        # Lay the table out as a (1, H, W, dim) grid and resize only the two
        # spatial axes; batch and channel sizes are left unchanged.
        patch_pos_embed = jnp.reshape(
            position_table, (1, sqrt_num_positions, sqrt_num_positions, dim)
        )
        patch_pos_embed = jimg.resize(
            patch_pos_embed,
            (1, new_height, new_width, dim),
            method="cubic",
        )
        patch_pos_embed = jnp.reshape(patch_pos_embed, (1, -1, dim))
        # Keep the result in the dtype of the patch embeddings it is added to.
        return patch_pos_embed.astype(embeddings.dtype)

    def __call__(self, pixel_values: chex.Array, interpolate_pos_encoding=False):
        """Embed a batch of images.

        Args:
            pixel_values: Images laid out as (batch, channels, height, width);
                the array is transposed to channels-last for the convolution.
            interpolate_pos_encoding: When true, resize the position table to
                the patch grid of the input instead of using it directly.

        Returns:
            Array of shape (batch, num_patches, embed_dim).
        """
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.kernel.dtype
        # nnx.Conv expects channels-last input.
        pixel_values = pixel_values.transpose(0, 2, 3, 1).astype(dtype=target_dtype)
        patch_embeds = self.patch_embedding(pixel_values).transpose(0, 3, 1, 2)

        # (batch, dim, h, w) -> (batch, dim, h*w) -> (batch, h*w, dim)
        embeddings = jnp.reshape(patch_embeds, patch_embeds.shape[:2] + (-1,))
        embeddings = jnp.transpose(embeddings, (0, 2, 1))

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
            )
        return embeddings
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the token and position embedding tables.

        Args:
            config: Text configuration (``hidden_size``, ``vocab_size`` and
                ``max_position_embeddings`` are read).
            dtype: Computation dtype of the tables.
            param_dtype: Storage dtype of the tables.
            precision: Accepted for signature parity; not used here.
            rngs: Random number generators used for parameter init.
        """
        hidden = config.hidden_size

        def _make_table(num_rows: int) -> nn.Embed:
            # Both tables share width, dtypes and the normal initializer.
            return nn.Embed(
                num_rows,
                hidden,
                embedding_init=jax.nn.initializers.normal(),
                dtype=dtype,
                param_dtype=param_dtype,
                rngs=rngs,
            )

        self.token_embedding = _make_table(config.vocab_size)
        self.position_embedding = _make_table(config.max_position_embeddings)

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
    ) -> chex.Array:
        """Return token embeddings summed with position embeddings.

        Args:
            input_ids: Token ids; used for the lookup when ``inputs_embeds``
                is not given.
            position_ids: Position indices; defaults to ``0..seq_length-1``.
            inputs_embeds: Pre-computed token embeddings, used instead of
                looking up ``input_ids``.

        Raises:
            ValueError: If the sequence is longer than the position table.
        """
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        table_rows = self.position_embedding.embedding.shape[0]
        if seq_length > table_rows:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {table_rows}"
            )

        if position_ids is None:
            position_ids = jnp.arange(seq_length, dtype="i4").reshape(1, -1)
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class SiglipAttention(FlaxAttentionModule):
    """Non-causal multi-head self-attention used by the SigLIP encoders."""

    def __init__(
        self,
        config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the q/k/v/out projections and the attention backend.

        Args:
            config: Configuration (``hidden_size``, ``num_attention_heads`` and
                ``attention_dropout`` are read).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the linear layers.
            rngs: Random number generators used for init and dropout.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the head count.
        """
        super().__init__(config=config)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.dropout = config.attention_dropout

        linear_class = partial(
            nn.Linear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        self.k_proj = linear_class(self.embed_dim, self.embed_dim)
        self.v_proj = linear_class(self.embed_dim, self.embed_dim)
        self.q_proj = linear_class(self.embed_dim, self.embed_dim)
        self.out_proj = linear_class(self.embed_dim, self.embed_dim)

        self.causal = False
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=config.attention_dropout,
        )

    def _split_heads(self, hidden_states):
        """Reshape (batch, seq, embed) -> (batch, seq, heads, head_dim)."""
        return hidden_states.reshape(
            hidden_states.shape[:2] + (self.num_heads, self.head_dim)
        )

    def _merge_heads(self, hidden_states):
        """Reshape (batch, seq, heads, head_dim) -> (batch, seq, embed)."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Apply self-attention.

        Args:
            hidden_states: Input of shape (batch, seq, embed_dim).
            attention_mask: Optional mask; entries ``> 0`` are attended to. A
                2-D mask gets two singleton axes inserted before its last axis.
            output_attentions: Whether to return the attention weights.

        Returns:
            ``(attn_output, attention_weights_or_None)``.

        Raises:
            NotImplementedError: If ``self.causal`` has been switched on.
        """
        # Causal attention is not supported, so no causal mask is ever built
        # and only the user-supplied mask has to be handled below.
        if self.causal:
            raise NotImplementedError()

        query = self._split_heads(self.q_proj(hidden_states))
        key = self._split_heads(self.k_proj(hidden_states))
        value = self._split_heads(self.v_proj(hidden_states))

        attention_bias = None
        if attention_mask is not None:
            if attention_mask.ndim == 2:
                attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            # Turn the mask into an additive bias: 0 where attended, the most
            # negative finite value of the compute dtype where masked out.
            attention_bias = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(
                    self.dtype
                ),
            )

        # The mask travels only as a bias; the backend gets no explicit mask.
        attentions = self.attention_performer.forward(
            query_states=query,
            key_states=key,
            value_states=value,
            bias=None,
            init_bias=lambda: attention_bias,
            attention_mask=None,
            segment_ids=None,
            causal=self.causal,
            dropout_rng=self.rngs.params(),
        )
        attn_output = self._merge_heads(attentions.attention_outputs)
        attn_output = self.out_proj(attn_output)

        if output_attentions:
            return (attn_output, attentions.attention_weights)
        return (attn_output, None)
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))``."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create both projections and look up the activation.

        Args:
            config: Configuration (``hidden_act``, ``hidden_size`` and
                ``intermediate_size`` are read here).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the linear layers.
            rngs: Random number generators used for parameter init.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.activation_fn = ACT2FN[config.hidden_act]

        shared_kwargs = dict(
            use_bias=True,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        self.fc1 = nn.Linear(
            config.hidden_size, config.intermediate_size, **shared_kwargs
        )
        self.fc2 = nn.Linear(
            config.intermediate_size, config.hidden_size, **shared_kwargs
        )

    def __call__(self, hidden_states: chex.Array) -> chex.Array:
        """Apply the feed-forward block to ``hidden_states``."""
        sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
        expanded = self.fc1(sharded)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the attention, MLP and the two layer norms.

        Args:
            config: Configuration (``hidden_size`` and ``layer_norm_eps`` are
                read here; the rest is forwarded to the sub-modules).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to attention and MLP.
            rngs: Random number generators used for parameter init.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs

        module_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        norm_kwargs = dict(
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

        self.self_attn = SiglipAttention(**module_kwargs)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, **norm_kwargs)
        self.mlp = SiglipMLP(**module_kwargs)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, **norm_kwargs)

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Run one encoder layer.

        Returns:
            A tuple whose first element is the new hidden states, followed by
            whatever the attention module returned after its first element.
        """
        # Attention sub-block: norm, attend, add the skip connection.
        attn_outputs = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        after_attention = hidden_states + attn_outputs[0]

        # Feed-forward sub-block: norm, MLP, add the skip connection.
        after_mlp = after_attention + self.mlp(self.layer_norm2(after_attention))

        return (after_mlp,) + attn_outputs[1:]
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`SiglipEncoderLayer`."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the encoder layers.

        Args:
            config: Configuration (``num_hidden_layers`` is read here; the
                rest is forwarded to each layer).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the layers.
            rngs: Random number generators used for parameter init.
        """
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.layers = [
            SiglipEncoderLayer(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for _ in range(config.num_hidden_layers)
        ]

    def __call__(
        self,
        inputs_embeds: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every layer in order.

        Args:
            inputs_embeds: Embedded inputs fed to the first layer.
            attention_mask: Optional mask forwarded to every layer.
            output_attentions: Collect each layer's attention weights.
            output_hidden_states: Collect the input of every layer plus the
                final output.
            return_dict: Return a ``FlaxBaseModelOutput`` instead of a tuple.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict`` is true; otherwise a
            tuple ``(last_hidden_state[, hidden_states][, attentions])`` with
            the entries that were not requested left out.
        """
        hidden_states = inputs_embeds
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # Include the collected states/attentions so tuple callers see the
            # same information as dict callers; unrequested entries are None
            # and get filtered out.
            return tuple(
                v
                for v in (hidden_states, all_hidden_states, all_attentions)
                if v is not None
            )
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class SiglipTextTransformer(EasyDeLBaseModule):
    """Text tower: embeddings, encoder, final layer norm and projection head."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the sub-modules of the text tower.

        Args:
            config: Text configuration (``hidden_size``, ``layer_norm_eps`` and
                ``projection_size`` are read here).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the sub-modules.
            rngs: Random number generators used for parameter init.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        module_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.embeddings = SiglipTextEmbeddings(**module_kwargs)
        self.encoder = SiglipEncoder(**module_kwargs)
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.head = nn.Linear(
            config.hidden_size,
            config.projection_size,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode token ids.

        ``None`` for any of the three output flags falls back to the value on
        ``self.config``.

        Returns:
            ``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true,
            else ``(last_hidden_state, pooled_output, *extra_encoder_outputs)``.
            The pooled output is the head applied to the last sequence position.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Collapse any leading axes so the ids are (batch, seq).
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
        embedded = self.embeddings(input_ids=flat_ids, position_ids=position_ids)

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.final_layer_norm(encoder_outputs[0])
        # Pool by taking the final sequence position, then project.
        pooled_output = self.head(last_hidden_state[:, -1, :])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipTextConfig,
    model_type="siglip_text_model",
)
class SiglipTextModel(EasyDeLBaseModule):
    """Registered wrapper that exposes :class:`SiglipTextTransformer`."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the wrapped text transformer with the same settings."""
        shared_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        super().__init__(**shared_kwargs)
        self.text_model = SiglipTextTransformer(**shared_kwargs)

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Forward everything to the wrapped text transformer.

        ``return_dict`` is resolved against ``self.config.use_return_dict``
        here; the other flags are passed through unchanged.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class SiglipVisionTransformer(EasyDeLBaseModule):
    """Vision tower: patch embeddings, encoder, post layer norm, optional head."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the sub-modules of the vision tower.

        Args:
            config: Vision configuration. ``hidden_size`` and
                ``layer_norm_eps`` are read here; ``vision_use_head`` is
                optional and treated as ``True`` when absent.
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the sub-modules.
            rngs: Random number generators used for parameter init.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.encoder = SiglipEncoder(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.post_layernorm = nn.LayerNorm(
            embed_dim,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        # The pooling head is on unless the config explicitly disables it.
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: tp.Optional[bool] = False,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Encode a batch of images.

        ``None`` for any of the three output flags falls back to the value on
        ``self.config``.

        Args:
            pixel_values: Images, forwarded to the embedding layer.
            output_attentions: Whether the encoder collects attention weights.
            output_hidden_states: Whether the encoder collects hidden states.
            return_dict: Return a structured output instead of a tuple.
            interpolate_pos_encoding: Forwarded to the embedding layer.

        Returns:
            ``FlaxBaseModelOutputWithPooling`` when ``return_dict`` is true,
            else ``(last_hidden_state, pooler_output, *extra_encoder_outputs)``.
            ``pooler_output`` is ``None`` when the pooling head is disabled.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)
        pooler_output = self.head(last_hidden_state) if self.use_head else None

        if not return_dict:
            return (last_hidden_state, pooler_output) + encoder_outputs[1:]
        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class MultiheadAttention(nn.Module):
    """Multi-head attention with a single packed q/k/v input projection.

    The packed projection is stored as ``in_proj_weight`` of shape
    ``(3 * embed_dim, embed_dim)`` and ``in_proj_bias`` of shape
    ``(3 * embed_dim,)``; the three ``embed_dim``-sized row groups are the
    query, key and value projections, in that order.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Allocate the packed input projection and the output projection.

        Args:
            embed_dim: Width of the inputs and outputs.
            num_heads: Number of attention heads; must divide ``embed_dim``.
            bias: Whether the output projection has a bias.
            dtype: Computation dtype of the output projection.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the output projection.
            rngs: Random number generators used for parameter init.

        Raises:
            ValueError: If ``embed_dim`` or ``num_heads`` is not positive, or
                ``embed_dim`` is not divisible by ``num_heads``.
        """
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        def normal_init(*shape):
            return nn.initializers.xavier_uniform()(rngs.param(), shape, param_dtype)

        def ze_init(*shape):
            return jnp.zeros(shape, param_dtype)

        self.in_proj_weight = nn.Param(normal_init(embed_dim * 3, embed_dim))
        self.in_proj_bias = nn.Param(ze_init(3 * embed_dim))
        self.out_proj = nn.Linear(
            embed_dim,
            embed_dim,
            use_bias=bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        query: chex.Array,
        key: chex.Array,
        value: chex.Array,
    ):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Array of shape (batch, q_len, embed_dim).
            key: Array of shape (batch, kv_len, embed_dim).
            value: Array of shape (batch, kv_len, embed_dim).

        Returns:
            Array of shape (batch, q_len, embed_dim).
        """
        q_batch, q_len, q_dim = query.shape
        kv_batch, kv_len, _ = value.shape

        # The packed parameters stack q, k, v along the FIRST axis, so each
        # chunk is (embed_dim, embed_dim) / (embed_dim,). Weights are stored
        # as (out_features, in_features), hence the transpose when applying.
        q_bias, k_bias, v_bias = jnp.split(self.in_proj_bias.value, 3, axis=0)
        q_weight, k_weight, v_weight = jnp.split(
            self.in_proj_weight.value, 3, axis=0
        )

        q_heads = (query @ q_weight.T + q_bias).reshape(
            q_batch, q_len, self.num_heads, self.head_dim
        )
        k_heads = (key @ k_weight.T + k_bias).reshape(
            kv_batch, kv_len, self.num_heads, self.head_dim
        )
        v_heads = (value @ v_weight.T + v_bias).reshape(
            kv_batch, kv_len, self.num_heads, self.head_dim
        )

        # Scaled dot-product scores, normalised over the key axis.
        scores = jnp.einsum(
            "bqhd,bkhd->bhqk",
            q_heads * (self.head_dim**-0.5),
            k_heads,
        )
        weights = jax.nn.softmax(scores, axis=-1)
        attended = jnp.einsum("bhqk,bkhd->bqhd", weights, v_heads)

        return self.out_proj(attended.reshape(q_batch, q_len, q_dim))
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Pool a token sequence into one vector with a learned probe query."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the probe, attention, layer norm and MLP.

        Args:
            config: Configuration of the tower this head belongs to
                (``hidden_size``, ``num_attention_heads`` and
                ``layer_norm_eps`` are read here).
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to attention and MLP.
            rngs: Random number generators used for parameter init.
        """
        # Single learned query vector, shape (1, 1, hidden_size).
        self.probe = nn.Param(
            jax.random.normal(
                rngs.param(),
                (1, 1, config.hidden_size),
                param_dtype,
            )
        )
        self.attention = MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.layernorm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mlp = SiglipMLP(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(self, hidden_state):
        """Return one pooled vector per batch element.

        Args:
            hidden_state: Token sequence whose first axis is the batch.

        Returns:
            The attended probe after a residual MLP, indexed at sequence
            position 0.
        """
        batch_size = hidden_state.shape[0]
        # One copy of the probe per batch element: (batch, 1, hidden_size).
        probe = self.probe.value.repeat(batch_size, 0)

        hidden_state = self.attention(probe, hidden_state, hidden_state)

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
@register_module(
    TaskType.BASE_VISION,
    config=SiglipVisionConfig,
    model_type="siglip_vision_model",
)
class SiglipVisionModel(EasyDeLBaseModule):
    """Registered wrapper that exposes :class:`SiglipVisionTransformer`.

    Derives from ``EasyDeLBaseModule`` and runs its initialiser, matching the
    sibling ``SiglipTextModel`` wrapper, so the registered class carries the
    same base-module state (e.g. ``self.config``) as the other registered
    models in this file.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the wrapped vision transformer.

        Args:
            config: Vision configuration.
            dtype: Computation dtype.
            param_dtype: Storage dtype of the parameters.
            precision: Precision forwarded to the transformer.
            rngs: Random number generators used for parameter init.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.vision_model = SiglipVisionTransformer(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        pixel_values,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Forward all arguments unchanged to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipModel(EasyDeLBaseModule):
    """Joint SigLIP model: a text tower, a vision tower and a logit scale/bias."""

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build both towers and the two scalar logit parameters.

        Raises:
            TypeError: If ``config.text_config`` or ``config.vision_config``
                is not an instance of the expected configuration class.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

        text_config = config.text_config
        vision_config = config.vision_config
        if not isinstance(text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        tower_kwargs = dict(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Build the registered wrappers, then keep only the inner transformers.
        text_wrapper = SiglipTextModel(text_config, **tower_kwargs)
        vision_wrapper = SiglipVisionModel(vision_config, **tower_kwargs)
        self.text_model = text_wrapper.text_model
        self.vision_model = vision_wrapper.vision_model

        self.logit_scale = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))
        self.logit_bias = nn.Param(jax.random.normal(rngs.param(), (1,), param_dtype))

    def get_text_features(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> chex.Array:
        """Return the pooled output (index 1) of the text tower.

        Flags left as ``None`` fall back to the values on ``self.config``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs[1]

    def get_image_features(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> chex.Array:
        """Return the pooled output (index 1) of the vision tower.

        Flags left as ``None`` fall back to the values on ``self.config``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return vision_outputs[1]

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        pixel_values: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        return_loss: tp.Optional[bool] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, SiglipOutput]:
        """Run both towers and compute the pairwise text/image logits.

        Returns:
            ``SiglipOutput`` when ``return_dict`` is true; otherwise the tuple
            ``(logits_per_image, logits_per_text, text_embeds, image_embeds,
            text_outputs, vision_outputs)``, prefixed with the loss when one
            was computed.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # L2-normalise the pooled outputs of both towers.
        image_embeds = vision_outputs[1]
        image_norm = jnp.linalg.norm(image_embeds, ord=2, axis=-1, keepdims=True)
        image_embeds = image_embeds / image_norm

        text_embeds = text_outputs[1]
        text_norm = jnp.linalg.norm(text_embeds, ord=2, axis=-1, keepdims=True)
        text_embeds = text_embeds / text_norm

        # Cosine similarity, scaled by exp(logit_scale) and shifted by the bias.
        similarity = jnp.matmul(text_embeds, image_embeds.T)
        logits_per_text = similarity * jnp.exp(self.logit_scale) + self.logit_bias
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            # +1 on the diagonal, -1 everywhere else.
            identity = jnp.eye(logits_per_text.shape[0])
            signs = -jnp.ones_like(logits_per_text) + 2 * identity
            loglik = jax.nn.log_sigmoid(signs * logits_per_text)
            loss = (-jnp.sum(loglik, axis=-1)).mean()

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            if loss is not None:
                return (loss,) + output
            return output

        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@register_module(
    TaskType.IMAGE_CLASSIFICATION,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipForImageClassification(EasyDeLBaseModule):
    """SigLIP vision tower followed by an optional linear classifier."""

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the vision tower and, when ``num_labels > 0``, the classifier."""
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.num_labels = config.num_labels

        # Build the registered wrapper, then keep only the inner transformer.
        vision_wrapper = SiglipVisionModel(
            config.vision_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.vision_model = vision_wrapper.vision_model

        # Without labels there is no classifier and the pooled features are
        # returned as the "logits".
        self.use_classif = config.num_labels > 0
        if self.use_classif:
            self.classifier = nn.Linear(
                config.vision_config.hidden_size,
                config.num_labels,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        labels: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tuple, FlaxImageClassifierOutput]:
        """Classify a batch of images.

        The token sequence from the vision tower is mean-pooled over axis 1
        and, when a classifier exists, projected to ``num_labels`` logits.
        ``labels`` is accepted but not used by this method.

        Returns:
            ``FlaxImageClassifierOutput`` when ``return_dict`` is true, else
            ``(logits, *vision_outputs[2:])``.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Average over the patch axis to get one vector per image.
        pooled = jnp.mean(outputs[0], axis=1)
        logits = self.classifier(pooled) if self.use_classif else pooled

        if not return_dict:
            return (logits,) + outputs[2:]
        return FlaxImageClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )