easydel.modules.gemma3.modeling_gemma3

class easydel.modules.gemma3.modeling_gemma3.Gemma3Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Gemma3 Attention with Q/K normalization.

Inherits Q/K normalization from QKNormAttention. Features:

  • Custom Gemma3RMSNorm for Q/K normalization

  • Layer-specific sliding window

  • Custom softmax scaling (query_pre_attn_scalar)
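The three features listed above can be illustrated with a small NumPy sketch of single-sequence attention. It is not EasyDeL's implementation. It assumes the softmax scale is query_pre_attn_scalar ** -0.5 (the convention in the Hugging Face Transformers Gemma3 port), drops the learned Gemma3RMSNorm weight, and uses a hypothetical sliding_window argument (None for a global layer) to model the layer-specific window.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # Normalize over head_dim; the learned Gemma3RMSNorm weight is omitted in this sketch.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def qk_norm_attention(q, k, v, query_pre_attn_scalar, sliding_window=None):
    """q, k, v: (seq_len, num_heads, head_dim) -> (output, probs).

    sliding_window is hypothetical: None means global (full causal) attention.
    """
    seq_len = q.shape[0]
    q, k = rms_norm(q), rms_norm(k)                    # Q/K normalization
    scale = query_pre_attn_scalar ** -0.5              # assumed scale, not head_dim ** -0.5
    scores = np.einsum("ihd,jhd->hij", q, k) * scale   # (num_heads, seq_len, seq_len)
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    allowed = j <= i                                   # causal: never attend to the future
    if sliding_window is not None:
        allowed &= (i - j) < sliding_window            # local layers see only recent keys
    scores = np.where(allowed[None], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.einsum("hij,jhd->ihd", probs, v), probs
```

Because the scale comes from query_pre_attn_scalar rather than head_dim, the two can be configured independently.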

class easydel.modules.gemma3.modeling_gemma3.Gemma3CausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Optional[Union[Array, ndarray, bool, number]] = None, last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None)

Bases: ModelOutput

Base class for Gemma3 causal language model (or autoregressive) outputs.

Parameters
  • loss (chex.Array of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).

  • logits (chex.Array of shape (batch_size, sequence_length, config.get_text_config().vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (TransformerCache, optional) – Cache of pre-computed hidden states (keys and values in the self-attention blocks) that can be reused (see the past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(chex.Array), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(chex.Array), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • image_hidden_states (chex.Array, optional) – A chex.Array of shape (batch_size, sequence_length, hidden_size) holding the image hidden states produced by the vision encoder after projecting its last hidden state.
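The loss field holds a next-token prediction loss. The NumPy sketch below shows how such a loss is typically derived from logits and labels. The one-position shift and the ignore_index=-100 convention are assumptions taken from common causal-LM practice, and causal_lm_loss is a hypothetical helper, not an EasyDeL function.

```python
import numpy as np


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy (hypothetical helper).

    logits: (batch, seq_len, vocab); labels: (batch, seq_len) integer token ids.
    The logits at position t are scored against labels[t + 1].
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis.
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    mask = shift_labels != ignore_index
    safe_labels = np.where(mask, shift_labels, 0)      # any valid index; masked out below
    token_lp = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return -(token_lp * mask).sum() / np.maximum(mask.sum(), 1)
```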

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None
image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None
last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None
logits: Optional[Union[Array, ndarray, bool, number]] = None
loss: Optional[Union[Array, ndarray, bool, number]] = None
past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.modules.gemma3.modeling_gemma3.Gemma3DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single decoder layer for Gemma3 models.

Combines self-attention, optional cross-attention, and feedforward networks with residual connections and layer normalization.
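A minimal NumPy sketch of one decoder layer follows. It assumes the pre- and post-sublayer RMSNorm arrangement used by the Hugging Face Transformers Gemma3 decoder layer. attn_fn, mlp_fn, and the norm names are hypothetical stand-ins, and the optional cross-attention path mentioned above is omitted.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # RMS normalization over the hidden axis with a per-channel scale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, norms):
    """x: (seq_len, hidden). attn_fn / mlp_fn map (seq_len, hidden) -> (seq_len, hidden).

    norms: dict of four hypothetical RMSNorm weight vectors.
    """
    # Attention block: normalize the input, attend, normalize the output, add the residual.
    h = attn_fn(rms_norm(x, norms["input"]))
    x = x + rms_norm(h, norms["post_attention"])
    # Feedforward block with the same pre-/post-norm arrangement.
    h = mlp_fn(rms_norm(x, norms["pre_feedforward"]))
    return x + rms_norm(h, norms["post_feedforward"])
```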

class easydel.modules.gemma3.modeling_gemma3.Gemma3ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Gemma3TextModel, Gemma3TextConfig]

Gemma3 model with a language modeling head for causal language modeling tasks.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: BaseVisionLanguageModule[Gemma3Model, Gemma3Config]

Gemma3 multimodal language model for conditional generation.

Combines a vision tower and a language model with a multi-modal projector. Inherits from BaseVisionLanguageModule to leverage common VLM infrastructure.

Features a custom apply_lm_head with final_logit_softcapping for improved training stability.

Class Attributes:

  • _task_type: IMAGE_TEXT_TO_TEXT task type

  • _model_type: “gemma3” model identifier

  • _supports_video: False (Gemma3 is image-only)

  • _uses_mrope: False (uses standard RoPE)

apply_lm_head(hidden_states: Float[Array, 'batch seq_len hidden_dim']) → Union[Array, ndarray, bool, number]

Apply the language modeling head with optional logit softcapping.

Gemma3 uses final_logit_softcapping to prevent extreme logit values, which improves training stability.

Parameters

hidden_states – Hidden states from the model

Returns

LM logits with optional softcapping applied
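Logit softcapping is commonly written as cap * tanh(logits / cap), which bounds every logit to the open interval (-cap, cap). A NumPy sketch follows, with apply_softcap as a hypothetical helper name and final_logit_softcapping treated as optional (None disables it).

```python
import numpy as np


def apply_softcap(logits, final_logit_softcapping=None):
    """Squash logits into (-cap, cap) via cap * tanh(logits / cap); no-op when cap is None."""
    if final_logit_softcapping is None:
        return logits
    cap = final_logit_softcapping
    return cap * np.tanh(logits / cap)
```

For logits much smaller than cap in magnitude, tanh(z) ≈ z, so ordinary logits pass through almost unchanged and only extreme values are compressed.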

get_image_features(pixel_values: Float[Array, 'batch channels height width'], **kwargs) → Float[Array, 'batch num_patches hidden']

Extract and project image features from pixel values.

Delegates to the base model’s get_image_features implementation.

Parameters
  • pixel_values – Input image pixel values

  • **kwargs – Additional arguments (unused for Gemma3)

Returns

Projected image features ready for merging with text embeddings
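Projected image features are typically merged by overwriting, in order, the embeddings of image placeholder tokens in the text sequence. A NumPy sketch of that merge, assuming one placeholder token per image feature vector; merge_image_features and image_token_id are hypothetical names.

```python
import numpy as np


def merge_image_features(inputs_embeds, input_ids, image_features, image_token_id):
    """Scatter projected image features into the placeholder-token slots.

    inputs_embeds: (batch, seq_len, hidden); input_ids: (batch, seq_len)
    image_features: (total_image_tokens, hidden), in reading order.
    """
    mask = input_ids == image_token_id
    if mask.sum() != image_features.shape[0]:
        raise ValueError("number of image tokens does not match number of image features")
    merged = inputs_embeds.copy()
    merged[mask] = image_features   # boolean indexing fills the slots in row-major order
    return merged
```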

get_language_model() → Module

Returns the language model component.

get_projector() → Module

Returns the multimodal projector component.

get_vision_tower() → Module

Returns the vision tower component.

init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)

Initialize KV cache for generation.

loss_type = 'ForCausalLM'

class easydel.modules.gemma3.modeling_gemma3.Gemma3ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Text-only Gemma3 backbone with a classification head for sequence tasks.
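Decoder-only classifiers usually score every position and keep the logits of the last real token in each sequence. The NumPy sketch below assumes that convention and right padding. It illustrates the idea only and is not the confirmed pooling rule of Gemma3ForSequenceClassification; pool_last_token is a hypothetical name.

```python
import numpy as np


def pool_last_token(logits, input_ids, pad_token_id):
    """Pick each row's logits at its last non-padding token (right padding assumed).

    logits: (batch, seq_len, num_labels); input_ids: (batch, seq_len).
    """
    non_pad = input_ids != pad_token_id
    # With right padding, the last real token sits at (number of real tokens - 1).
    last_idx = np.maximum(non_pad.sum(axis=-1) - 1, 0)
    return logits[np.arange(logits.shape[0]), last_idx]
```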

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.gemma3.modeling_gemma3.Gemma3MLP(*args: Any, **kwargs: Any)

Bases: Module

Multi-Layer Perceptron module for Gemma3 models.

Implements the feedforward network with gated activation functions and optional Float8 scaling for improved performance.
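A NumPy sketch of a gated feedforward block follows. It assumes the gate/up/down projection layout and the tanh-approximated GELU that the Hugging Face Transformers Gemma3 configuration uses by default (hidden_activation="gelu_pytorch_tanh"). The weight names are hypothetical and the Float8 scaling is omitted.

```python
import numpy as np


def gelu_tanh(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def gated_mlp(x, w_gate, w_up, w_down):
    """x: (..., hidden); w_gate, w_up: (hidden, intermediate); w_down: (intermediate, hidden)."""
    # The activated gate projection multiplies the linear up projection element-wise.
    return (gelu_tanh(x @ w_gate) * (x @ w_up)) @ w_down
```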

class easydel.modules.gemma3.modeling_gemma3.Gemma3Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Multimodal Gemma3 stack combining a vision tower, projector, and language model.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Gemma3 is a multi-modal model with a vision tower, but for typical LLM usage, it’s considered a decoder-only architecture.

get_image_features(pixel_values: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]
get_lm_head()

Returns the language model head of the module. Base Models don’t have a Language Model Head.

init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused by the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
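A standard cache of this kind preallocates key and value buffers for max_length positions and fills them as tokens are processed. The NumPy sketch below illustrates the idea only. The (batch, max_length, num_kv_heads, head_dim) layout, the dict structure, and both function names are hypothetical and do not reflect TransformerCache's actual layout, quantization, or sharding.

```python
import numpy as np


def init_kv_cache(num_layers, batch_size, max_length, num_kv_heads, head_dim, dtype=np.float32):
    """Preallocate zero-filled key/value buffers per layer plus a write index (hypothetical)."""
    shape = (batch_size, max_length, num_kv_heads, head_dim)
    return [
        {"key": np.zeros(shape, dtype), "value": np.zeros(shape, dtype), "index": 0}
        for _ in range(num_layers)
    ]


def write_kv(layer_cache, new_k, new_v):
    """Append new_k / new_v of shape (batch, n_new, heads, head_dim) at the current index."""
    start = layer_cache["index"]
    n_new = new_k.shape[1]
    if start + n_new > layer_cache["key"].shape[1]:
        raise ValueError("cache is full: increase max_length")
    layer_cache["key"][:, start:start + n_new] = new_k
    layer_cache["value"][:, start:start + n_new] = new_v
    layer_cache["index"] = start + n_new
```

The in-place slice assignment is a NumPy convenience. JAX arrays are immutable, so a JAX version would use x.at[...].set(...) or jax.lax.dynamic_update_slice instead.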

prepare_inputs_for_generation(input_ids: Int[Array, 'batch seq_len'], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: jaxtyping.Bool[Array, 'batch seq_len'] | None = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • pixel_values (tp.Optional[chex.Array]) – Optional input image pixel values for multimodal prompts.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
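The position and mask preparation described above can be sketched in NumPy. The sketch assumes left padding and max_length >= seq_len, follows the cumsum(mask) - 1 convention for position IDs used by the Flax models in Hugging Face Transformers, clips padded positions to 0, and marks not-yet-generated slots as attendable (the causal mask handles them). prepare_positions_and_mask is a hypothetical name.

```python
import numpy as np


def prepare_positions_and_mask(attention_mask, max_length):
    """attention_mask: (batch, seq_len) of 0/1. Returns (position_ids, extended_mask)."""
    batch, seq_len = attention_mask.shape
    # Real tokens are numbered 0, 1, 2, ...; leading pads are clipped to position 0.
    position_ids = np.clip(np.cumsum(attention_mask, axis=-1) - 1, 0, None)
    # Slots beyond the prompt start as 1 and are filled by generated tokens.
    extended = np.ones((batch, max_length), dtype=attention_mask.dtype)
    extended[:, :seq_len] = attention_mask
    return position_ids, extended
```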

update_inputs_for_generation(model_outputs, model_kwargs)

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
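The update step can be sketched in a few lines of Python. update_inputs_sketch is a hypothetical stand-in that copies the dictionary instead of mutating it; keeping only the last position per sequence and adding one mirrors the convention of the Flax models in Hugging Face Transformers.

```python
from types import SimpleNamespace

import numpy as np


def update_inputs_sketch(model_outputs, model_kwargs):
    """Return a copy of model_kwargs prepared for the next decoding step (hypothetical)."""
    updated = dict(model_kwargs)   # shallow copy; the caller's dict is left untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # Keep only the last position per sequence and advance it by one.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


# Usage with a stand-in for the model output object.
step_outputs = SimpleNamespace(past_key_values="cache-after-step-1")
kwargs = {"position_ids": np.array([[0, 1, 2], [0, 0, 1]])}
next_kwargs = update_inputs_sketch(step_outputs, kwargs)
print(next_kwargs["position_ids"].tolist())  # [[3], [2]]
```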

class easydel.modules.gemma3.modeling_gemma3.Gemma3ModelOutputWithPast(last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None, past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None)

Bases: ModelOutput

Parameters
  • past_key_values (TransformerCache, optional) – Cache of pre-computed hidden states (keys and values in the self-attention blocks) that can be reused (see the past_key_values input) to speed up sequential decoding.

  • image_hidden_states (chex.Array, optional) – A chex.Array of shape (batch_size, num_images, sequence_length, hidden_size) holding the image hidden states produced by the vision encoder after projecting its last hidden state.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None
image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None
last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None
past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.modules.gemma3.modeling_gemma3.Gemma3MultiModalProjector(*args: Any, **kwargs: Any)

Bases: Module

Multi-modal projector for Gemma3 vision-language models.

Projects vision features into the text embedding space, enabling cross-modal understanding and generation in Gemma3.
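The projection can be sketched in NumPy, assuming the design of the Hugging Face Transformers Gemma3 projector: average-pool the square patch grid from the vision tower down to a fixed number of tokens, RMS-normalize them, then apply a linear map into the text hidden size. The function and argument names are hypothetical, and the exact norm weighting may differ from Gemma3RMSNorm.

```python
import numpy as np


def project_image_patches(patch_embeds, tokens_per_side, norm_weight, proj, eps=1e-6):
    """patch_embeds: (batch, num_patches, vision_dim) from a square patch grid.

    Returns (batch, tokens_per_side**2, text_dim) using proj: (vision_dim, text_dim).
    """
    b, n, d = patch_embeds.shape
    side = int(round(np.sqrt(n)))
    if side * side != n or side % tokens_per_side != 0:
        raise ValueError("patch grid must be square and divisible by tokens_per_side")
    k = side // tokens_per_side                                   # pooling kernel == stride
    grid = patch_embeds.reshape(b, tokens_per_side, k, tokens_per_side, k, d)
    pooled = grid.mean(axis=(2, 4)).reshape(b, tokens_per_side * tokens_per_side, d)
    normed = pooled / np.sqrt(np.mean(pooled ** 2, axis=-1, keepdims=True) + eps) * norm_weight
    return normed @ proj
```

For example, pooling a 64×64 patch grid with tokens_per_side=16 uses a 4×4 kernel and yields 256 tokens per image.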

class easydel.modules.gemma3.modeling_gemma3.Gemma3RMSNorm(*args: Any, **kwargs: Any)

Bases: Module

Root Mean Square Layer Normalization for Gemma3 models.

Implements RMS normalization with Float8 support for efficient computation and memory usage in Gemma3 architecture.
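A NumPy sketch of RMS normalization with a learned per-channel scale follows. Two conventions exist: scaling by weight directly (weight initialized to ones, like kernel_init below) or scaling by (1 + weight) with weight initialized to zeros, which is what the Hugging Face Transformers Gemma implementations do. Which one EasyDeL uses is not stated here, so the add_unit_offset flag is a hypothetical switch; Float8 handling is omitted.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6, add_unit_offset=False):
    """RMS normalization over the last axis with a per-channel scale."""
    x32 = x.astype(np.float32)   # normalize in float32 for numerical stability
    normed = x32 / np.sqrt(np.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    scale = (1.0 + weight) if add_unit_offset else weight
    return (normed * scale).astype(x.dtype)
```

The two conventions give identical outputs when the plain weight equals one plus the offset weight, so converting between them only shifts the stored weight by one.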

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

class easydel.modules.gemma3.modeling_gemma3.Gemma3TextModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Decoder-only Gemma3 text transformer with embeddings and stacked decoder layers.

property default_frequencies
get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base Models don’t have a Language Model Head.