# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from functools import cached_property, partial
import chex
import jax
import jax.numpy as jnp
from eformer import common_types
from eformer.escale import apply_logical_sharding
from eformer.pytree import auto_pytree
from flax import nnx as nn
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
AttentionLayerOutput,
BaseModelOutput,
CausalLMOutput,
DecoderLayerOutput,
ModelOutput,
SequenceClassifierOutput,
)
from easydel.infra.utils import (
ACT2FN,
auto_remat,
block_wise_ffn,
get_dot_general_by_bits,
)
from easydel.layers.attention import AttentionModule, FlexibleAttentionModule
from easydel.layers.caching import (
PagedAttentionCache,
PagedAttentionCacheView,
PagedAttentionMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ParallelLinear
from easydel.layers.norms import float8s
from easydel.modules.auto.auto_modeling import AutoEasyDeLVisionModel
from easydel.utils.helpers import get_logger
from .gemma3_configuration import Gemma3Config, Gemma3TextConfig
logger = get_logger(__name__)
[docs]@auto_pytree
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
Base class for Gemma3 causal language model (or autoregressive) outputs.
Args:
loss (`chex.Array` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`chex.Array` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(chex.Array))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(chex.Array)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(chex.Array)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `chex.Array` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(chex.Array)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `chex.Array` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`chex.Array`, *optional*):
A `chex.Array` of size `(batch_size, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: tp.Optional[chex.Array] = None
logits: chex.Array = None
past_key_values: tp.Optional[TransformerCache] = None
hidden_states: tp.Optional[tp.Tuple[chex.Array]] = None
attentions: tp.Optional[tp.Tuple[chex.Array]] = None
image_hidden_states: tp.Optional[chex.Array] = None
[docs]class Gemma3RMSNorm(nn.Module):
def __init__(
self,
config: Gemma3TextConfig,
param_dtype: jnp.dtype = jnp.float32,
dim: tp.Optional[int] = None,
epsilon: tp.Optional[float] = None,
):
self.config = config
self.epsilon = self.config.rms_norm_eps if epsilon is None else epsilon
self.param_dtype = param_dtype
dim = self.config.hidden_size if dim is None else dim
self.kernel = nn.Param(jnp.ones(dim, dtype=param_dtype))
def _norm(self, x: jax.Array) -> jax.Array:
return x * (1 / jnp.sqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.epsilon))
def __call__(self, hidden_states: jax.Array) -> jax.Array:
variance = self._norm(hidden_states.astype(jnp.float32)).astype(self.param_dtype)
out = (1 + self.kernel.value.astype(self.param_dtype)) * variance
if out.dtype in float8s:
out = out.astype(jnp.bfloat16)
return out
[docs]class Gemma3Attention(AttentionModule):
def __init__(
self,
config: Gemma3TextConfig,
layer_idx: int,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: tp.Optional[tp.Union[str, jax.lax.Precision]] = None,
causal: bool = True,
is_cross_attention: bool = False,
*,
rngs: nn.Rngs,
):
super().__init__(config)
self.layer_idx = layer_idx
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.is_cross_attention = is_cross_attention
self.rngs = rngs
self.causal = causal
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = getattr(
config,
"head_dim",
config.hidden_size // config.num_attention_heads,
)
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
kernel = jax.nn.initializers.normal(config.initializer_range)
linear = partial(
ParallelLinear,
use_bias=config.attention_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
kernel_init=kernel,
rngs=rngs,
**get_dot_general_by_bits(config.bits, config.easy_method),
)
self.q_proj = linear(self.embed_dim, self.num_heads * self.head_dim)
self.k_proj = linear(self.embed_dim, self.num_key_value_heads * self.head_dim)
self.v_proj = linear(self.embed_dim, self.num_key_value_heads * self.head_dim)
self.o_proj = linear(self.num_heads * self.head_dim, self.embed_dim)
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
self.sliding_window = config.sliding_window if self.is_sliding else None
self.q_norm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
dim=self.head_dim,
)
self.k_norm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
dim=self.head_dim,
)
self.attention_performer = FlexibleAttentionModule(
base_config=config,
softmax_scale=self.config.query_pre_attn_scalar**-0.5,
dropout_prob=config.attention_dropout,
)
self.rotary = self.config.get_basic_rope(
self.dtype,
self.head_dim,
self.head_dim,
True,
)
def _merge_heads(self, hidden_states):
"""
Merges the attention heads into a single hidden state tensor.
Args:
hidden_states (chex.Array): The hidden states with separate head dimensions.
Returns:
chex.Array: The hidden states with merged head dimensions.
"""
return hidden_states.reshape(
hidden_states.shape[:2] + (self.num_heads * self.head_dim,)
)
def _split_heads(self, hidden_states, num_heads):
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def __call__(
self,
hidden_states: chex.Array,
attention_mask: chex.Array,
position_ids: chex.Array,
causal_mask: tp.Optional[chex.Array | bool],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
segment_ids: tp.Optional[chex.Array] = None,
token_type_ids: tp.Optional[chex.Array] = None,
output_attentions: bool = False,
fcm_mask: tp.Optional[chex.Array] = None,
frequencies: tp.Optional[chex.Array] = None,
):
"""
Forward pass of the attention module.
Args:
hidden_states (chex.Array): Input hidden states.
attention_mask (chex.Array): Mask to apply on the attention scores.
position_ids (chex.Array): Position indices for the tokens.
causal_mask (chex.Array): Causal mask for ensuring autoregressive behavior.
segment_ids (tp.Optional[chex.Array]): Segment IDs for segment-based attention (optional).
deterministic (bool): If True, disables dropout for deterministic behavior.
init_cache (bool): If True, initializes cache for caching keys and values.
output_attentions (bool): If True, outputs attention weights alongside the hidden states.
fcm_mask (tp.Optional[chex.Array]): fcm mask to be combined with attn mask and causal mask.
Returns:
tp.Tuple[chex.Array, chex.Array]: A tuple containing the attention output and the attention weights.
"""
batch_size, sequence_length = hidden_states.shape[:2]
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
(query_states, key_states, value_states) = (
self.q_proj(hidden_states),
self.k_proj(hidden_states),
self.v_proj(hidden_states),
)
query_states = query_states.reshape(*hidden_shape)
key_states = key_states.reshape(*hidden_shape)
value_states = value_states.reshape(*hidden_shape)
query_states = self.q_norm(query_states)
key_states = self.k_norm(key_states)
(
query_states,
key_states,
value_states,
) = self.apply_qkv_shardings(query_states, key_states, value_states)
query_states, key_states = self.rotary(
positions=position_ids,
query=query_states,
key=key_states,
frequencies=frequencies,
)
(
key_states,
value_states,
attention_mask,
init_attention_bias,
cache_view,
) = self.concatenate(
query=query_states,
key=key_states,
value=value_states,
cache_view=cache_view,
cache_metadata=cache_metadata,
attention_mask=attention_mask,
causal_mask=causal_mask,
token_type_ids=token_type_ids,
fcm_mask=fcm_mask,
sliding_window=None,
)
if self.is_sliding:
cache_pos = self.build_cache_pos(attention_mask, cache_view)
if isinstance(cache_view, TransformerCacheView):
curr_index = cache_view.index
else:
curr_index = jnp.repeat(
jnp.array([0], "i4").reshape(-1),
query_states.shape[0],
0,
)
sliding_mask = self._create_sliding_mask(
cache_pos=cache_pos,
curr_index=curr_index,
cache_length=attention_mask.shape[-1],
sliding_window=self.sliding_window,
)
attention_mask = jnp.logical_and(sliding_mask, attention_mask)
def init_attention_bias():
return jax.lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
attentions = self.attention_performer.forward(
query_states=query_states,
key_states=key_states,
value_states=value_states,
mode=mode,
bias=None,
sliding_window=self.sliding_window,
cache_metadata=cache_metadata,
cache_view=cache_view,
init_bias=init_attention_bias,
attention_mask=attention_mask,
segment_ids=segment_ids,
causal=True,
dropout_rng=self.rngs.params(),
)
attn_output = self.shard_attention_prod(
self._merge_heads(attentions.attention_outputs)
)
attn_output = self.o_proj(attn_output)
return AttentionLayerOutput(
attention_output=attn_output,
attention_weight=attentions.attention_weights if output_attentions else None,
cache_view=cache_view,
)
[docs]class Gemma3MLP(nn.Module):
def __init__(
self,
config: Gemma3TextConfig,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: tp.Optional[tp.Union[str, jax.lax.Precision]] = None,
*,
rngs: nn.Rngs,
):
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
embed_dim = self.config.hidden_size
inner_dim = (
self.config.intermediate_size
if self.config.intermediate_size is not None
else 4 * embed_dim
)
kernel_init = jax.nn.initializers.normal(config.initializer_range)
self.act = ACT2FN[self.config.hidden_activation]
linear_class = partial(
ParallelLinear,
use_bias=False,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
kernel_init=kernel_init,
rngs=rngs,
**get_dot_general_by_bits(config.bits, config.easy_method),
)
self.gate_proj = linear_class(embed_dim, inner_dim)
self.down_proj = linear_class(inner_dim, embed_dim)
self.up_proj = linear_class(embed_dim, inner_dim)
def __call__(self, hidden_states):
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
gate = self.act(self.gate_proj(hidden_states))
up = self.up_proj(hidden_states)
hidden_states = self.down_proj(gate * up)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
return hidden_states
[docs]class Gemma3DecoderLayer(nn.Module):
def __init__(
self,
config: Gemma3TextConfig,
layer_idx: int,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: tp.Optional[tp.Union[str, jax.lax.Precision]] = None,
*,
rngs: nn.Rngs,
):
self.config = config
self.layer_idx = layer_idx
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
mlp_block = Gemma3MLP
attn_block = Gemma3Attention
attn_block, mlp_block = auto_remat(
attn_block,
mlp_block,
policy=config.gradient_checkpointing,
)
self.self_attn = attn_block(
self.config,
layer_idx=self.layer_idx,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.mlp = mlp_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.input_layernorm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
)
self.post_attention_layernorm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
)
self.pre_feedforward_layernorm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
)
self.post_feedforward_layernorm = Gemma3RMSNorm(
self.config,
param_dtype=self.param_dtype,
)
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
def __call__(
self,
hidden_states: chex.Array,
attention_mask: chex.Array,
position_ids: chex.Array,
causal_mask: tp.Optional[chex.Array | bool],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
segment_ids: tp.Optional[chex.Array] = None,
token_type_ids: tp.Optional[chex.Array] = None,
output_attentions: bool = False,
fcm_mask: tp.Optional[chex.Array] = None,
frequencies: tp.Optional[chex.Array] = None,
default_frequencies: tp.Optional[chex.Array] = None,
):
"""
Forward pass of the module block.
Args:
hidden_states (chex.Array): Input hidden states.
attention_mask (chex.Array): Mask to apply on the attention scores.
position_ids (chex.Array): Position indices for the tokens.
causal_mask (chex.Array): Causal mask for ensuring autoregressive behavior.
segment_ids (tp.Optional[chex.Array]): Segment IDs for segment-based attention (optional).
deterministic (bool): If True, disables dropout for deterministic behavior.
init_cache (bool): If True, initializes cache for caching keys and values.
output_attentions (bool): If True, outputs attention weights alongside the hidden states.
fcm_mask (tp.Optional[chex.Array]): fcm mask to be combined with attn mask and causal mask.
Returns:
tp.Tuple[chex.Array, chex.Array]: A tuple containing the attention output and the attention weights.
"""
residual = hidden_states
frequencies = default_frequencies if self.is_sliding else frequencies
hidden_states = self.input_layernorm(hidden_states)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
attn_outputs = self.self_attn(
hidden_states,
attention_mask,
position_ids,
causal_mask,
mode,
cache_view,
cache_metadata,
segment_ids,
token_type_ids,
output_attentions,
fcm_mask,
frequencies,
)
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
hidden_states = self.post_attention_layernorm(attn_outputs.attention_output)
hidden_states = residual + hidden_states
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
if self.config.use_scan_mlp:
hidden_states = block_wise_ffn(
self.mlp,
hidden_states,
self.config.scan_mlp_chunk_size,
)
else:
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
return DecoderLayerOutput(
hidden_states=hidden_states,
attention_weight=attn_outputs.attention_weight,
cache_view=attn_outputs.cache_view,
)
[docs]@register_module(
TaskType.BASE_MODULE,
config=Gemma3TextConfig,
model_type="gemma3_text",
)
class Gemma3TextModel(EasyDeLBaseModule):
def __init__(
self,
config: Gemma3TextConfig,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.hidden_size = self.config.hidden_size
self.embed_tokens = nn.Embed(
self.config.vocab_size,
self.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.layers = [
Gemma3DecoderLayer(
self.config,
layer_idx=i,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
for i in range(self.config.num_hidden_layers)
]
self.norm = Gemma3RMSNorm(self.config, param_dtype=self.dtype)
@cached_property
def default_frequencies(self):
from easydel.infra.utils import ModuleCaches
from easydel.layers.rotary_embedding import get_frequencies
frequencies = get_frequencies(
head_size=self.config.head_dim,
rotary_dim=self.config.head_dim,
max_position=self.config.granted_freq_max_position_embedding,
base=self.config.rope_local_base_freq,
rope_scaling=None,
)
return ModuleCaches(frequencies)
def __call__(
self,
input_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
segment_ids: tp.Optional[chex.Array] = None,
token_type_ids: tp.Optional[chex.Array] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
) -> BaseModelOutput:
"""
Forward pass through the Gemma2 module.
Args:
input_ids (chex.Array): Input tensor containing token IDs.
attention_mask (chex.Array): Mask for attention.
position_ids (chex.Array): Positional indices.
segment_ids (tp.Optional[chex.Array]): Segment IDs for different input parts.
inputs_embeds (tp.Optional[chex.Array]): Embedded input tensor.
output_attentions (tp.Optional[bool]): If True, output attention weights.
output_hidden_states (tp.Optional[bool]): If True, output hidden states.
init_cache (bool): If True, initialize cache for decoding.
deterministic (bool): If True, disable dropout.
Returns:
BaseModelOutput | tp.Tuple: Model output, either as a named tuple or a standard tuple.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids.astype("i4")) * (
self.config.hidden_size**0.5
)
batch_size, sequence_length, _ = inputs_embeds.shape
if attention_mask is None:
attention_mask = jnp.ones((batch_size, sequence_length), "b1")
else:
if attention_mask.dtype != jnp.bool:
attention_mask = jnp.astype(attention_mask == 1, "b1")
if position_ids is None:
position_ids = jnp.broadcast_to(
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
(batch_size, sequence_length),
)
inputs_embeds = inputs_embeds
assert sequence_length <= self.config.max_position_embeddings, (
f"Maximum Position Embedding Reached ! (Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
)
if attention_mask.ndim == 2:
attention_mask = jnp.expand_dims(attention_mask, (1, 2))
hidden_states = inputs_embeds
if mode is None:
mode = (
common_types.MODE_DECODE
if sequence_length == 1 and past_key_values is not None
else common_types.MODE_TRAIN
)
if past_key_values is None:
past_key_values = TransformerCache.init_empty(len(self.layers))
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
causal_mask = self.causal_mask
for idx, block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
mode=mode,
cache_view=past_key_values.views[idx],
cache_metadata=cache_metadata,
causal_mask=causal_mask,
token_type_ids=token_type_ids,
output_attentions=output_attentions,
segment_ids=segment_ids,
frequencies=self.frequencies,
default_frequencies=self.default_frequencies,
)
hidden_states = layer_outputs.hidden_states
if output_attentions:
all_attentions += (layer_outputs.attention_weight,)
past_key_values[idx] = layer_outputs.cache_view
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
past_key_values=past_key_values,
)
[docs]@register_module(
TaskType.CAUSAL_LM,
config=Gemma3TextConfig,
model_type="gemma3_text",
)
class Gemma3ForCausalLM(EasyDeLBaseModule):
"""Gemma3 model with a language modeling head for causal language modeling tasks.
This model extends the base Gemma3TextModel by incorporating a linear language modeling head on top
of the base model, designed for generative tasks and text generation. The model can optionally apply
softcapping to logits based on configuration settings.
"""
def __init__(
self,
config: Gemma3TextConfig,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
if param_dtype == jnp.float16 or param_dtype == "f2":
logger.error(
"Gemma-3's recommended dtype is bfloat16, but you are using float16. "
"This may result in junk responses or incorrect predictions."
)
self.model = Gemma3TextModel(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.lm_head = ParallelLinear(
config.hidden_size,
config.vocab_size,
use_bias=False,
rngs=rngs,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
kernel_init=jax.nn.initializers.normal(stddev=config.initializer_range),
**get_dot_general_by_bits(config.bits, config.easy_method),
)
def __call__(
self,
input_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
segment_ids: tp.Optional[chex.Array] = None,
token_type_ids: tp.Optional[chex.Array] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
) -> CausalLMOutput:
"""
Forward pass through the Gemma3 model.
Args:
input_ids (tp.Optional[chex.Array]): Input tensor containing token IDs.
attention_mask (tp.Optional[chex.Array]): Mask for attention.
position_ids (tp.Optional[chex.Array]): Positional indices.
segment_ids (tp.Optional[chex.Array]): Segment IDs for different input parts.
token_type_ids (tp.Optional[chex.Array]): Token type IDs for handling different types of tokens.
inputs_embeds (tp.Optional[chex.Array]): Embedded input tensor.
output_attentions (tp.Optional[bool]): If True, output attention weights.
output_hidden_states (tp.Optional[bool]): If True, output hidden states.
past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]): Cached key values for faster inference.
cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Metadata for cache handling.
Returns:
CausalLMOutput | tp.Tuple: Model output, either as a named tuple or a standard tuple.
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mode=mode,
past_key_values=past_key_values,
cache_metadata=cache_metadata,
inputs_embeds=inputs_embeds,
segment_ids=segment_ids,
token_type_ids=token_type_ids,
)
hidden_states = outputs.last_hidden_state
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
if self.config.tie_word_embeddings:
lm_logits = jax.lax.dot_general(
hidden_states,
self.model.embed_tokens.embedding.value.T,
(((hidden_states.ndim - 1), (0,)), ((), ())),
)
else:
lm_logits = self.lm_head(hidden_states)
if self.config.final_logit_softcapping is not None:
cap = jnp.array(self.config.final_logit_softcapping, dtype=lm_logits.dtype)
lm_logits = cap * jax.nn.tanh(lm_logits / cap)
return CausalLMOutput(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
past_key_values=outputs.past_key_values,
)
[docs]@register_module(
TaskType.SEQUENCE_CLASSIFICATION,
config=Gemma3TextConfig,
model_type="gemma3_text",
)
class Gemma3ForSequenceClassification(EasyDeLBaseModule):
def __init__(
self,
config: Gemma3TextConfig,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.model = Gemma3TextModel(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
assert hasattr(config, "num_labels"), (
"in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
)
self.score = ParallelLinear(
self.config.hidden_size,
config.num_labels,
dtype=dtype,
param_dtype=param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(stddev=config.initializer_range),
precision=self.precision,
rngs=rngs,
)
def __call__(
self,
input_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
segment_ids: tp.Optional[chex.Array] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
) -> SequenceClassifierOutput:
transformer_outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
mode=mode,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
inputs_embeds=inputs_embeds,
segment_ids=segment_ids,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (
jnp.argmax(jnp.equal(input_ids, self.config.pad_token_id).astype("i4"), -1)
- 1
)
sequence_lengths = sequence_lengths % input_ids.shape[-1]
else:
sequence_lengths = -1
pooled_logits = logits[jnp.arange(batch_size), sequence_lengths]
return SequenceClassifierOutput(
logits=pooled_logits,
past_key_values=past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
[docs]class Gemma3MultiModalProjector(nn.Module):
def __init__(
self,
config: Gemma3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.rngs = rngs
self.mm_input_projection_weight = nn.Param(
jnp.zeros(
(
config.text_config.hidden_size,
config.vision_config.hidden_size,
),
dtype=param_dtype,
)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config,
param_dtype=param_dtype,
dim=config.vision_config.hidden_size,
epsilon=config.vision_config.layer_norm_eps,
)
self.patches_per_image = int(
config.vision_config.image_size // config.vision_config.patch_size
)
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
kernel_size = self.patches_per_image // self.tokens_per_side
self.kernel_size = kernel_size
self.avg_pool = lambda x: jax.lax.reduce_window(
x,
init_value=0.0,
computation=jax.lax.add,
window_dimensions=(1, 1, kernel_size, kernel_size),
window_strides=(1, 1, kernel_size, kernel_size),
padding="VALID",
) / (kernel_size * kernel_size)
def __call__(self, vision_outputs):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = jnp.transpose(vision_outputs, (0, 2, 1))
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size,
seq_length,
self.patches_per_image,
self.patches_per_image,
)
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.reshape(batch_size, seq_length, -1)
pooled_vision_outputs = jnp.transpose(pooled_vision_outputs, (0, 2, 1))
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = jax.lax.dot_general(
normed_vision_outputs,
self.mm_input_projection_weight.T,
(((normed_vision_outputs.ndim - 1), (0,)), ((), ())),
)
return projected_vision_outputs.astype(vision_outputs.dtype)
[docs]@register_module(
TaskType.IMAGE_TEXT_TO_TEXT,
config=Gemma3Config,
model_type="gemma3",
)
class Gemma3ForConditionalGeneration(EasyDeLBaseModule):
loss_type = "ForCausalLM"
def __init__(
self,
config: Gemma3Config,
dtype: jnp.dtype = jnp.float32,
param_dtype: jnp.dtype = jnp.float32,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.vision_tower = AutoEasyDeLVisionModel.from_config(
config=config.vision_config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.multi_modal_projector = Gemma3MultiModalProjector(
config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.vocab_size = config.text_config.vocab_size
self.language_model = Gemma3ForCausalLM(
config=config.text_config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
[docs] def get_image_features(self, pixel_values: chex.Array) -> chex.Array:
vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def __call__(
self,
input_ids: chex.Array = None,
pixel_values: chex.Array = None,
attention_mask: tp.Optional[chex.Array] = None,
position_ids: tp.Optional[chex.Array] = None,
mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None, # type:ignore
past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
token_type_ids: tp.Optional[chex.Array] = None,
inputs_embeds: tp.Optional[chex.Array] = None,
output_attentions: tp.Optional[bool] = None,
output_hidden_states: tp.Optional[bool] = None,
**lm_kwargs,
):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids
llm_input_ids = jnp.where(special_image_mask, 0, llm_input_ids)
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.language_model.model.embed_tokens(llm_input_ids) * (
self.config.text_config.hidden_size**0.5
)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.language_model.model.embed_tokens(
jnp.array(self.config.image_token_index, dtype="i4")
)
else:
special_image_mask = jnp.expand_dims(
(input_ids == self.config.image_token_index),
-1,
)
special_image_mask = jnp.broadcast_to(special_image_mask, inputs_embeds.shape)
image_features = image_features.astype(inputs_embeds.dtype)
inputs_embeds = jnp.place(
inputs_embeds,
special_image_mask,
image_features,
inplace=False,
)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
mode=mode,
past_key_values=past_key_values,
cache_metadata=cache_metadata,
inputs_embeds=inputs_embeds,
token_type_ids=token_type_ids,
segment_ids=None,
**lm_kwargs,
)
return Gemma3CausalLMOutputWithPast(
loss=None,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
[docs] def init_cache(
self,
batch_size,
max_length,
starts=None,
shardings=None,
pad_token_id=None,
):
return self.language_model.init_cache(
batch_size,
max_length,
starts,
shardings,
pad_token_id,
)
def _get_compile_model_kwargs(
self,
batch_size: int,
input_tokens_length: int,
input_sharding: jax.sharding.PartitionSpec,
rngs: jax.random.PRNGKey,
vision_included: bool = False,
vision_batch_size: int = 1,
vision_channels: int = 3,
vision_height: tp.Optional[int] = None,
vision_width: tp.Optional[int] = None,
required_props: tp.Optional[tp.Mapping[str, tp.Dict[str, tp.Any]]] = None,
**kwargs,
):
basics = super()._get_compile_model_kwargs(
batch_size=batch_size,
input_tokens_length=input_tokens_length,
input_sharding=input_sharding,
rngs=rngs,
vision_included=vision_included,
vision_batch_size=vision_batch_size,
vision_channels=vision_channels,
vision_height=vision_height,
vision_width=vision_width,
required_props=required_props,
**kwargs,
)
token_type_ids = jnp.ones(
(batch_size, input_tokens_length),
dtype="i4",
device=input_sharding,
)
basics.update({"token_type_ids": token_type_ids})
if vision_included:
pixel_values = jnp.ones(
(
vision_batch_size or 1,
vision_channels or 3,
self.config.vision_config.image_size,
self.config.vision_config.image_size,
),
dtype="f4",
)
basics.update({"pixel_values": pixel_values})
return basics