# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from functools import partial
import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
from jax import image as jimg
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPooling,
FlaxImageClassifierOutput,
ModelOutput,
)
from easydel.infra.utils import ACT2FN, control_mlp_sharding
from easydel.layers.attention import FlaxAttentionModule, FlexibleAttentionModule
from easydel.utils import traversals as etr
from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
@etr.auto_pytree
class SiglipVisionModelOutput(ModelOutput):
    """Pytree output container for the SigLIP vision side.

    Every field defaults to ``None`` so the container can be built with only
    the entries a caller actually produced.

    Attributes:
        image_embeds: Image embedding array, if produced.
        last_hidden_state: Last hidden-state array, if produced.
        hidden_states: Optional tuple of per-layer hidden-state arrays.
        attentions: Optional tuple of per-layer attention arrays.
    """

    image_embeds: tp.Optional[chex.Array] = None
    # Optional: it defaults to None exactly like the other fields.
    last_hidden_state: tp.Optional[chex.Array] = None
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@etr.auto_pytree
class SiglipTextModelOutput(ModelOutput):
    """Pytree output container for the SigLIP text side.

    Every field defaults to ``None`` so the container can be built with only
    the entries a caller actually produced.

    Attributes:
        text_embeds: Text embedding array, if produced.
        last_hidden_state: Last hidden-state array, if produced.
        hidden_states: Optional tuple of per-layer hidden-state arrays.
        attentions: Optional tuple of per-layer attention arrays.
    """

    text_embeds: tp.Optional[chex.Array] = None
    # Optional: it defaults to None exactly like the other fields.
    last_hidden_state: tp.Optional[chex.Array] = None
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
    attentions: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@etr.auto_pytree
class SiglipOutput(ModelOutput):
    """Output container returned by ``SiglipModel`` when ``return_dict`` is set.

    Attributes:
        loss: Sigmoid contrastive loss, only set when ``return_loss`` is requested.
        logits_per_image: Image-to-text similarity logits.
        logits_per_text: Text-to-image similarity logits.
        text_embeds: L2-normalised text embeddings.
        image_embeds: L2-normalised image embeddings.
        text_model_output: Raw output of the text tower.
        vision_model_output: Raw output of the vision tower.
    """

    loss: tp.Optional[chex.Array] = None
    logits_per_image: tp.Optional[chex.Array] = None
    logits_per_text: tp.Optional[chex.Array] = None
    text_embeds: tp.Optional[chex.Array] = None
    image_embeds: tp.Optional[chex.Array] = None
    text_model_output: tp.Optional[FlaxBaseModelOutputWithPooling] = None
    vision_model_output: tp.Optional[FlaxBaseModelOutputWithPooling] = None

    def to_tuple(self) -> tp.Tuple[tp.Any, ...]:
        """Flatten into a tuple, recursively flattening the nested tower outputs.

        A nested tower output that is ``None`` is kept as ``None`` instead of
        crashing on ``None.to_tuple()``.
        """
        nested_fields = ("text_model_output", "vision_model_output")
        return tuple(
            getattr(self, k).to_tuple()
            if k in nested_fields and getattr(self, k) is not None
            else self[k]
            for k in self.keys()
        )
class SiglipVisionEmbeddings(nn.Module):
    """Patch embeddings plus learned position embeddings for the vision tower.

    ``pixel_values`` are taken channels-first (``(batch, channels, H, W)``),
    transposed to channels-last for the flax convolution, split into
    non-overlapping ``patch_size`` patches, and summed with a position table of
    ``(image_size // patch_size) ** 2`` entries.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embed(
            self.num_positions,
            self.embed_dim,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.patch_embedding = nn.Conv(
            in_features=config.num_channels,
            out_features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=(self.patch_size, self.patch_size),
            padding="VALID",
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
            precision=precision,
        )

    def interpolate(self, embeddings: chex.Array, height: int, width: int):
        """Return position embeddings matching an image of ``height`` x ``width``.

        When the patch count equals the trained table size and the image is
        square, the table is used as-is. Otherwise the square position grid is
        resized with cubic interpolation to ``(height // patch_size,
        width // patch_size)``.

        Args:
            embeddings: Patch embeddings ``(batch, num_patches, dim)``.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings of shape ``(1, num_patches, dim)``.
        """
        num_patches = embeddings.shape[1]
        # nnx.Embed stores its table in `.embedding`; `.weight` does not exist.
        position_table = self.position_embedding.embedding.value
        num_positions = position_table.shape[0]
        if num_patches == num_positions and height == width:
            return self.position_embedding(
                jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
            )
        dim = embeddings.shape[-1]
        new_height = height // self.patch_size
        new_width = width // self.patch_size
        # Assumes the trained grid is square (num_positions is a perfect square).
        side = int(num_positions**0.5)
        grid = jnp.reshape(position_table, (1, side, side, dim))
        # Resizing the spatial axes in NHWC layout is equivalent to resizing in
        # NCHW; the untouched axes keep their size.
        grid = jimg.resize(grid, (1, new_height, new_width, dim), method="cubic")
        return jnp.reshape(grid, (1, -1, dim)).astype(embeddings.dtype)

    def __call__(self, pixel_values: chex.Array, interpolate_pos_encoding=False):
        """Embed ``pixel_values`` into ``(batch, num_patches, hidden_size)``."""
        _, _, height, width = pixel_values.shape
        target_dtype = self.patch_embedding.kernel.value.dtype
        # NCHW -> NHWC because flax convolutions are channels-last.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)).astype(target_dtype)
        patch_embeds = self.patch_embedding(pixel_values)  # (B, H/p, W/p, D)
        # Row-major flatten of the patch grid into a token sequence.
        embeddings = jnp.reshape(
            patch_embeds, (patch_embeds.shape[0], -1, patch_embeds.shape[-1])
        )
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(
                jnp.arange(self.num_positions, dtype="i4").reshape(1, -1)
            )
        return embeddings
class SiglipTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for text."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        embed_dim = config.hidden_size
        self.token_embedding = nn.Embed(
            config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.position_embedding = nn.Embed(
            config.max_position_embeddings,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(),
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
    ) -> chex.Array:
        """Return token (or provided) embeddings plus position embeddings.

        Args:
            input_ids: Token ids; the last axis is the sequence axis.
            position_ids: Optional positions; defaults to ``arange(seq_length)``.
            inputs_embeds: Precomputed embeddings used instead of ``input_ids``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given,
                or the sequence exceeds the position table.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )
        max_position_embedding = self.position_embedding.embedding.value.shape[0]
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )
        if position_ids is None:
            position_ids = jnp.arange(seq_length, dtype="i4").reshape(1, -1)
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(position_ids)
        return inputs_embeds + position_embeddings
class SiglipAttention(FlaxAttentionModule):
    """Non-causal multi-head self-attention shared by both SigLIP towers.

    Projections are built in the order k, v, q, out; the actual attention
    kernel is delegated to ``FlexibleAttentionModule``. A padding mask, if
    given, is turned into an additive bias (0 for kept keys, dtype-min for
    masked ones).
    """

    def __init__(
        self,
        config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(config=config)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.dropout = config.attention_dropout

        def make_projection():
            # Square embed_dim -> embed_dim projection with small-normal init.
            return nn.Linear(
                self.embed_dim,
                self.embed_dim,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
                kernel_init=jax.nn.initializers.normal(0.01),
            )

        # Creation order matters: each Linear draws from `rngs` in turn.
        self.k_proj = make_projection()
        self.v_proj = make_projection()
        self.q_proj = make_projection()
        self.out_proj = make_projection()
        self.causal = False
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=config.attention_dropout,
        )

    def _split_heads(self, hidden_states):
        """(batch, seq, embed_dim) -> (batch, seq, num_heads, head_dim)."""
        leading = hidden_states.shape[:2]
        return hidden_states.reshape(leading + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        """(batch, seq, num_heads, head_dim) -> (batch, seq, embed_dim)."""
        leading = hidden_states.shape[:2]
        return hidden_states.reshape(leading + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Attend over ``hidden_states``.

        Returns:
            ``(attn_output, attention_weights_or_None)``.
        """
        query = self._split_heads(self.q_proj(hidden_states))
        key = self._split_heads(self.k_proj(hidden_states))
        value = self._split_heads(self.v_proj(hidden_states))

        if self.causal:
            raise NotImplementedError()

        attention_bias = None
        if attention_mask is not None:
            if attention_mask.ndim == 2:
                # (batch, kv_len) -> (batch, 1, 1, kv_len) for broadcasting.
                attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            mask_shape = attention_mask.shape
            attention_bias = jnp.where(
                attention_mask > 0,
                jnp.zeros(mask_shape, dtype=self.dtype),
                jnp.full(mask_shape, jnp.finfo(self.dtype).min, dtype=self.dtype),
            )

        attentions = self.attention_performer.forward(
            query_states=query,
            key_states=key,
            value_states=value,
            bias=None,
            init_bias=lambda: attention_bias,
            # The mask is fully encoded in `attention_bias`.
            attention_mask=None,
            segment_ids=None,
            causal=self.causal,
            dropout_rng=self.rngs.params(),
        )
        projected = self.out_proj(self._merge_heads(attentions.attention_outputs))
        weights = attentions.attention_weights if output_attentions else None
        return projected, weights
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: hidden -> intermediate -> hidden."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.activation_fn = ACT2FN[config.hidden_act]
        dense = partial(
            nn.Linear,
            use_bias=True,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            kernel_init=jax.nn.initializers.normal(0.01),
        )
        # fc1 is built before fc2 so rng draws stay in the same order.
        self.fc1 = dense(config.hidden_size, config.intermediate_size)
        self.fc2 = dense(config.intermediate_size, config.hidden_size)

    def __call__(self, hidden_states: chex.Array) -> chex.Array:
        """Apply the sharding hint, then fc1 -> activation -> fc2."""
        sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
        expanded = self.fc1(sharded)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
class SiglipEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        shared = dict(dtype=dtype, param_dtype=param_dtype, rngs=rngs)
        # Submodules are created in this exact order so rng draws are unchanged.
        self.self_attn = SiglipAttention(config=config, precision=precision, **shared)
        self.layer_norm1 = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps, **shared
        )
        self.mlp = SiglipMLP(config=config, precision=precision, **shared)
        self.layer_norm2 = nn.LayerNorm(
            config.hidden_size, epsilon=config.layer_norm_eps, **shared
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
    ):
        """Return ``(hidden_states, *extra)``, where ``extra`` comes from attention."""
        attn_outputs = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        after_attn = hidden_states + attn_outputs[0]
        after_mlp = after_attn + self.mlp(self.layer_norm2(after_attn))
        return (after_mlp, *attn_outputs[1:])
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` blocks."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.layers = [
            SiglipEncoderLayer(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for _ in range(config.num_hidden_layers)
        ]

    def __call__(
        self,
        inputs_embeds: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every layer in sequence.

        Returns:
            ``FlaxBaseModelOutput`` when ``return_dict`` is truthy. Otherwise a
            tuple ``(last_hidden_state, all_hidden_states, all_attentions)``
            with the ``None`` entries removed.
        """
        hidden_states = inputs_embeds
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)
        if output_hidden_states:
            # Include the final layer's output as well.
            all_hidden_states += (hidden_states,)
        if not return_dict:
            # Previously only (hidden_states,) was returned, silently dropping
            # the collected hidden states/attentions in tuple mode.
            return tuple(
                v
                for v in (hidden_states, all_hidden_states, all_attentions)
                if v is not None
            )
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
class SiglipTextTransformer(EasyDeLBaseModule):
    """SigLIP text tower: embeddings -> encoder -> final LayerNorm -> linear head.

    The pooled output is the final-LayerNorm state of the last token, passed
    through ``self.head``.
    """

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        embed_dim = config.hidden_size
        self.embeddings = SiglipTextEmbeddings(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.encoder = SiglipEncoder(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.head = nn.Linear(
            embed_dim,
            config.projection_size,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: chex.Array,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode ``input_ids``.

        ``attention_mask`` and ``position_ids`` are optional: callers such as
        ``SiglipTextModel`` forward ``None`` for them.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("You have to specify input_ids")
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
        )
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = self.final_layer_norm(encoder_outputs[0])
        # Pool on the last sequence position.
        pooled_output = self.head(last_hidden_state[:, -1, :])
        if not return_dict:
            return (last_hidden_state, pooled_output) + tuple(encoder_outputs[1:])
        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipTextConfig,
    model_type="siglip_text_model",
)
class SiglipTextModel(EasyDeLBaseModule):
    """Registered text-only SigLIP model; a thin wrapper over ``SiglipTextTransformer``."""

    def __init__(
        self,
        config: SiglipTextConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        common = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        super().__init__(**common)
        self.text_model = SiglipTextTransformer(**common)

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Forward everything to the wrapped transformer, resolving ``return_dict``."""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class MultiheadAttention(nn.Module):
    """Minimal cross-attention with a packed q/k/v input projection.

    ``in_proj_weight`` is stored in the ``torch.nn.MultiheadAttention`` layout,
    ``(3 * embed_dim, embed_dim)``: rows ``[0:E]``, ``[E:2E]`` and ``[2E:3E]``
    are the q, k and v projections, and each is applied as ``x @ W.T + b``.
    No masking is applied.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the packed projection.

        Args:
            embed_dim: Model width ``E``.
            num_heads: Number of heads; must divide ``embed_dim``.
            bias: Whether ``out_proj`` has a bias (the packed input projection
                always has one).

        Raises:
            ValueError: On non-positive sizes or when ``embed_dim`` is not
                divisible by ``num_heads``.
        """
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")
        # `params` is the stream name used elsewhere in this module; nnx falls
        # back to the `default` stream when no `params` stream exists.
        self.in_proj_weight = nn.Param(
            nn.initializers.xavier_uniform()(
                rngs.params(), (3 * embed_dim, embed_dim), param_dtype
            )
        )
        self.in_proj_bias = nn.Param(jnp.zeros((3 * embed_dim,), param_dtype))
        self.out_proj = nn.Linear(
            embed_dim,
            embed_dim,
            use_bias=bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        query: chex.Array,
        key: chex.Array,
        value: chex.Array,
    ):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(batch, q_len, embed_dim)``.
            key: ``(batch, kv_len, embed_dim)``.
            value: ``(batch, kv_len, embed_dim)``.

        Returns:
            ``(batch, q_len, embed_dim)``.
        """
        q_batch, q_len, _ = query.shape
        k_batch, k_len, _ = key.shape
        v_batch, v_len, _ = value.shape
        # Split the packed weight along its output (row) axis; the previous
        # split along axis -1 produced (3E, E/3) matrices that cannot multiply
        # a (.., E) input.
        q_w, k_w, v_w = jnp.split(self.in_proj_weight.value, 3, axis=0)
        q_b, k_b, v_b = jnp.split(self.in_proj_bias.value, 3, axis=-1)
        heads = (self.num_heads, self.head_dim)
        q = (query @ q_w.T + q_b).reshape((q_batch, q_len) + heads)
        k = (key @ k_w.T + k_b).reshape((k_batch, k_len) + heads)
        v = (value @ v_w.T + v_b).reshape((v_batch, v_len) + heads)
        scores = jnp.einsum("bqhd,bkhd->bhqk", q * (self.head_dim**-0.5), k)
        weights = jax.nn.softmax(scores, axis=-1)
        attn = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
        return self.out_proj(attn.reshape(q_batch, q_len, self.embed_dim))
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Attention-pooling head: one learned probe attends over the sequence.

    The attended probe goes through a residual LayerNorm + MLP, and its single
    position is returned as a ``(batch, hidden_size)`` pooled vector.
    """

    def __init__(
        self,
        config: tp.Union[SiglipVisionConfig, SiglipTextConfig],
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        # `params` matches the stream used elsewhere in this module (nnx falls
        # back to `default` if absent), unlike the former `rngs.param()`.
        self.probe = nn.Param(
            jax.random.normal(
                rngs.params(),
                (1, 1, config.hidden_size),
                param_dtype,
            )
        )
        self.attention = MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.layernorm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.mlp = SiglipMLP(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(self, hidden_state):
        """Pool ``hidden_state`` (batch-first) into ``(batch, hidden_size)``."""
        batch_size = hidden_state.shape[0]
        probe = jnp.repeat(self.probe.value, batch_size, axis=0)
        hidden_state = self.attention(probe, hidden_state, hidden_state)
        residual = hidden_state
        hidden_state = residual + self.mlp(self.layernorm(hidden_state))
        return hidden_state[:, 0]
class SiglipVisionTransformer(EasyDeLBaseModule):
    """SigLIP vision tower: patch embeddings -> encoder -> post-LayerNorm -> pooling head.

    ``SiglipVisionModel`` (and, through it, ``SiglipModel`` and
    ``SiglipForImageClassification``) instantiates this class, but it was not
    defined in this module, so construction failed with a NameError. The
    attention-pooling head is only built when ``config.vision_use_head`` is
    truthy (or missing, which defaults to True).
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.embeddings = SiglipVisionEmbeddings(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.encoder = SiglipEncoder(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size,
            epsilon=config.layer_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values: chex.Array,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Encode channels-first ``pixel_values``.

        Returns:
            ``FlaxBaseModelOutputWithPooling`` or, when ``return_dict`` is
            falsy, ``(last_hidden_state, pooler_output, *encoder_extras)``.
            ``pooler_output`` is ``None`` when the pooling head is disabled.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = self.post_layernorm(encoder_outputs[0])
        pooler_output = self.head(last_hidden_state) if self.use_head else None
        if not return_dict:
            return (last_hidden_state, pooler_output) + tuple(encoder_outputs[1:])
        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@register_module(
    TaskType.BASE_VISION,
    config=SiglipVisionConfig,
    model_type="siglip_vision_model",
)
class SiglipVisionModel(EasyDeLBaseModule):
    """Registered vision-only SigLIP model; a thin wrapper over ``SiglipVisionTransformer``.

    Built on ``EasyDeLBaseModule`` (like ``SiglipTextModel``) so the registered
    model carries its ``config``.
    """

    def __init__(
        self,
        config: SiglipVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.vision_model = SiglipVisionTransformer(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        pixel_values,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutputWithPooling]:
        """Forward everything to the wrapped vision transformer."""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
@register_module(
    TaskType.BASE_MODULE,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipModel(EasyDeLBaseModule):
    """Dual-tower SigLIP model producing image-text similarity logits.

    Logits are ``(text @ image.T) * exp(logit_scale) + logit_bias`` on
    L2-normalised embeddings. The optional loss is the pairwise sigmoid loss
    with +1 labels on the diagonal and -1 elsewhere.
    """

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        if not isinstance(config.text_config, SiglipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type SiglipTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.vision_config, SiglipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )
        text_config = config.text_config
        vision_config = config.vision_config
        text_model = SiglipTextModel(
            text_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        vision_model = SiglipVisionModel(
            vision_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Keep only the inner transformers, not the registered wrappers.
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model
        # `params` is the stream used elsewhere in this module; nnx falls back to
        # the `default` stream when no `params` stream exists, whereas the
        # former `rngs.param()` only worked with a `default` stream.
        self.logit_scale = nn.Param(jax.random.normal(rngs.params(), (1,), param_dtype))
        self.logit_bias = nn.Param(jax.random.normal(rngs.params(), (1,), param_dtype))

    def get_text_features(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
    ) -> chex.Array:
        """Return the (un-normalised) pooled text-tower output."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return text_outputs[1]

    def get_image_features(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> chex.Array:
        """Return the (un-normalised) pooled vision-tower output."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return vision_outputs[1]

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        pixel_values: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        return_loss: tp.Optional[bool] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tp.Tuple, SiglipOutput]:
        """Encode both modalities and score every text against every image.

        Returns:
            ``SiglipOutput``, or the tuple ``(logits_per_image, logits_per_text,
            text_embeds, image_embeds, text_outputs, vision_outputs)``, prefixed
            by ``loss`` when it was computed.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[1]
        # L2-normalise so the matmul below is a cosine similarity.
        image_embeds = image_embeds / jnp.linalg.norm(
            image_embeds, ord=2, axis=-1, keepdims=True
        )
        text_embeds = text_embeds / jnp.linalg.norm(
            text_embeds, ord=2, axis=-1, keepdims=True
        )
        # Read the raw arrays from the nnx Params before applying jnp functions.
        logit_scale = self.logit_scale.value
        logit_bias = self.logit_bias.value
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T)
        logits_per_text = logits_per_text * jnp.exp(logit_scale) + logit_bias
        logits_per_image = logits_per_text.T
        loss = None
        if return_loss:
            # +1 on the diagonal (matching pairs), -1 off-diagonal.
            m1_diag1 = -jnp.ones_like(logits_per_text) + 2 * jnp.eye(
                logits_per_text.shape[0]
            )
            loglik = jax.nn.log_sigmoid(m1_diag1 * logits_per_text)
            nll = -jnp.sum(loglik, axis=-1)
            loss = nll.mean()
        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output
        return SiglipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@register_module(
    TaskType.IMAGE_CLASSIFICATION,
    config=SiglipConfig,
    model_type="siglip",
)
class SiglipForImageClassification(EasyDeLBaseModule):
    """SigLIP vision tower with a mean-pooled linear classifier.

    When ``config.num_labels`` is 0, no classifier is built and the mean-pooled
    features are returned as ``logits``. ``labels`` is accepted but unused.
    """

    def __init__(
        self,
        config: SiglipConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.num_labels = config.num_labels
        wrapper = SiglipVisionModel(
            config.vision_config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Keep only the inner transformer from the registered wrapper.
        self.vision_model = wrapper.vision_model
        self.use_classif = config.num_labels > 0
        if self.use_classif:
            self.classifier = nn.Linear(
                config.vision_config.hidden_size,
                config.num_labels,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )

    def __call__(
        self,
        pixel_values: tp.Optional[chex.Array] = None,
        labels: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tp.Union[tuple, FlaxImageClassifierOutput]:
        """Classify ``pixel_values`` by averaging the last hidden state over tokens."""
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict
        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooled = jnp.mean(outputs[0], axis=1)
        logits = self.classifier(pooled) if self.use_classif else pooled
        if return_dict:
            return FlaxImageClassifierOutput(
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return (logits, *outputs[2:])