easydel.modules.gemma3.modeling_gemma3_flax
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3Attention(*args: Any, **kwargs: Any)
Bases: AttentionModule
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3CausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: Optional[TransformerCache] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None)
Bases: ModelOutput
Base class for Gemma3 causal language model (or autoregressive) outputs.
- Parameters
loss (chex.Array of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).
logits (chex.Array of shape (batch_size, sequence_length, config.text_config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (TransformerCache, optional, returned when use_cache=True is passed or when config.use_cache=True) –
Cache of pre-computed hidden states (keys and values of the self-attention blocks).
It can be passed back as the past_key_values input to speed up sequential decoding.
hidden_states (tuple(chex.Array), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) –
Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (tuple(chex.Array), optional, returned when output_attentions=True is passed or when config.output_attentions=True) –
Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
image_hidden_states (chex.Array, optional) – Image hidden states produced by the vision encoder after projecting its last hidden state, of shape (batch_size, sequence_length, hidden_size).
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- past_key_values: Optional[TransformerCache] = None
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
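The serialization helpers listed above follow a familiar round-trip contract: to_dict/from_dict and to_json/from_json invert each other, and replace returns a modified copy. The sketch below illustrates that contract with a hypothetical frozen dataclass, `ToyOutput`, which is not part of EasyDeL. It holds plain lists instead of arrays and uses only the standard library. It is not EasyDeL's implementation, which operates on PyTrees of arrays.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for an output class; fields hold plain lists, not arrays."""

    logits: Optional[list] = None
    loss: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        # Serialize every field into a plain dictionary.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyOutput":
        # Rebuild an instance from a dictionary produced by to_dict().
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments (e.g. indent=2) are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyOutput":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "ToyOutput":
        # Returns a new instance; the frozen original is left unchanged.
        return dataclasses.replace(self, **kwargs)


out = ToyOutput(logits=[[0.1, 0.9]], loss=None)
restored = ToyOutput.from_json(out.to_json())
updated = out.replace(loss=0.5)
print(out.loss, updated.loss)  # None 0.5
```

Making the dataclass frozen mirrors the idea that replace produces a new object rather than mutating the existing one.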
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Gemma3 model with a language modeling head for causal language modeling tasks.
This model extends the base Gemma3TextModel by incorporating a linear language modeling head on top of the base model, designed for generative tasks and text generation. The model can optionally apply softcapping to logits based on configuration settings.
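Logit softcapping in the Gemma family is commonly defined as `cap * tanh(logits / cap)`. This bounds every logit to the open interval (-cap, cap) while leaving small values nearly unchanged. The sketch below assumes that formula and a hypothetical cap of 30.0. The value actually used comes from the model configuration (in Hugging Face's Gemma configurations this is a field named `final_logit_softcapping`, and EasyDeL's field name is an assumption here). No capping is applied when the cap is unset.

```python
import math
from typing import List, Optional


def softcap_logits(logits: List[float], cap: Optional[float]) -> List[float]:
    """Apply tanh softcapping, cap * tanh(x / cap); return a copy unchanged if cap is None."""
    if cap is None:
        return list(logits)
    return [cap * math.tanh(x / cap) for x in logits]


raw = [0.5, 10.0, 100.0, -1000.0]
capped = softcap_logits(raw, cap=30.0)
# Small logits are almost untouched; large magnitudes saturate toward +/-30.
print([round(v, 3) for v in capped])
```

Because tanh is monotonic, the ranking of the logits (and therefore greedy decoding) is preserved. Only the spread of very large values is compressed.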
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3ForConditionalGeneration(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- config: tp.Union[EasyDeLBaseConfig, _CP]
- dtype: jnp.dtype
- get_image_features(pixel_values: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]
- init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
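To make the allocation step above concrete, the following plain-Python sketch preallocates one pair of zero-filled key and value buffers per layer. Each buffer is laid out as (batch_size, max_length, num_kv_heads, head_dim), and each layer also gets a per-sequence write index seeded from `starts`. Everything here is hypothetical: `init_simple_cache`, `LayerCache`, and the buffer layout are illustrative assumptions. The dtype, sharding, quantization, and metadata handling that TransformerCache performs are left out.

```python
from dataclasses import dataclass
from typing import List, Optional


def zeros(shape):
    """Nested list of zeros with the given shape (a pure-Python stand-in for an array)."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


@dataclass
class LayerCache:
    keys: list        # shape (batch_size, max_length, num_kv_heads, head_dim)
    values: list      # same shape as keys
    index: List[int]  # next write position for each sequence in the batch


def init_simple_cache(
    num_layers: int,
    batch_size: int,
    max_length: int,
    num_kv_heads: int,
    head_dim: int,
    starts: Optional[int] = None,
) -> List[LayerCache]:
    """Hypothetical sketch: preallocate zero-filled key/value buffers for every layer."""
    start = 0 if starts is None else starts  # default: writing begins at position 0
    shape = (batch_size, max_length, num_kv_heads, head_dim)
    return [
        LayerCache(keys=zeros(shape), values=zeros(shape), index=[start] * batch_size)
        for _ in range(num_layers)
    ]


cache = init_simple_cache(num_layers=2, batch_size=3, max_length=8, num_kv_heads=1, head_dim=4)
print(len(cache), len(cache[0].keys), len(cache[0].keys[0]))  # 2 3 8
```

Preallocating to max_length up front keeps buffer shapes static across decoding steps. Static shapes are what JIT-compiled decoding loops generally require.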
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).
shardings (dict | None) – Optional dictionary specifying sharding configurations. (Note: this argument appears to be unused in the current implementation.)
pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
- loss_type = 'ForCausalLM'
- param_dtype: jnp.dtype
- precision: lax.PrecisionLike
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None)
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
- A dictionary containing the prepared inputs, typically including:
"past_key_values": The initialized KV cache.
"attention_mask": The extended attention mask for generation.
"position_ids": The calculated initial position IDs.
"token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
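The preparation steps described above can be sketched in plain Python. This is a simplified, hypothetical illustration, not EasyDeL's code. `sketch_prepare_inputs` and `count_leading_pads` are made-up names, and cache allocation and device placement are omitted. The extended mask follows the pattern used in Hugging Face's Flax `prepare_inputs_for_generation`: the prompt mask is copied into an all-ones buffer of length `max_length`, and causal masking is relied on to hide future slots. Clipping pad positions to 0 is an additional assumption.

```python
from typing import Dict, List, Optional


def count_leading_pads(row: List[int], pad_token_id: int) -> int:
    """Count left-padding tokens before the first non-pad token (hypothetical helper)."""
    n = 0
    for tok in row:
        if tok != pad_token_id:
            break
        n += 1
    return n


def sketch_prepare_inputs(
    input_ids: List[List[int]],
    max_length: int,
    pad_token_id: int,
    attention_mask: Optional[List[List[int]]] = None,
) -> Dict[str, list]:
    batch_size, seq_length = len(input_ids), len(input_ids[0])
    # Per-sequence starting offsets = number of leading pad tokens.
    starts = [count_leading_pads(row, pad_token_id) for row in input_ids]
    if attention_mask is None:
        # No mask given: attend to every token; positions form a contiguous range.
        attention_mask = [[1] * seq_length for _ in range(batch_size)]
        position_ids = [list(range(seq_length)) for _ in range(batch_size)]
    else:
        # Cumulative sum of the mask minus one, clipped so pad slots get position 0.
        position_ids = []
        for row in attention_mask:
            running, positions = 0, []
            for m in row:
                running += m
                positions.append(max(running - 1, 0))
            position_ids.append(positions)
    # Extended mask of length max_length: prompt mask first, future slots set to 1.
    extended_mask = [list(row) + [1] * (max_length - seq_length) for row in attention_mask]
    return {"starts": starts, "attention_mask": extended_mask, "position_ids": position_ids}


PAD = 0
ids = [[PAD, PAD, 5, 6], [7, 8, 9, 10]]
mask = [[0, 0, 1, 1], [1, 1, 1, 1]]
inputs = sketch_prepare_inputs(ids, max_length=6, pad_token_id=PAD, attention_mask=mask)
print(inputs["starts"])        # [2, 0]
print(inputs["position_ids"])  # [[0, 0, 0, 1], [0, 1, 2, 3]]
```

Deriving positions from the mask means that left-padded rows still number their first real token 0. This keeps rotary position encodings consistent across rows of different prompt lengths.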
- rngs: nn.Rngs
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
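The update step described above amounts to two assignments. The sketch below mirrors the pattern of Hugging Face's Flax `update_inputs_for_generation`: carry the new cache forward, then keep only the last position of each row and add one. `sketch_update_inputs` is a hypothetical name, and the string used as a cache is a placeholder for a real cache object. Unlike a fully in-place update, this version returns a shallow copy so that the caller's dictionary is untouched.

```python
from types import SimpleNamespace
from typing import Any, Dict


def sketch_update_inputs(model_outputs: Any, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Carry the cache forward and advance position_ids by one (hypothetical sketch)."""
    updated = dict(model_kwargs)  # shallow copy so the caller's dict is not mutated
    updated["past_key_values"] = model_outputs.past_key_values
    # The next step decodes one token per row: keep the last position and add 1.
    updated["position_ids"] = [[row[-1] + 1] for row in model_kwargs["position_ids"]]
    return updated


# "cache-after-step-1" is a placeholder standing in for a real cache object.
outputs = SimpleNamespace(past_key_values="cache-after-step-1")
kwargs = {"position_ids": [[0, 0, 0, 1], [0, 1, 2, 3]], "past_key_values": None}
new_kwargs = sketch_update_inputs(outputs, kwargs)
print(new_kwargs["position_ids"])  # [[2], [4]]
```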
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3MLP(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3MultiModalProjector(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3RMSNorm(*args: Any, **kwargs: Any)
Bases: Module
- class easydel.modules.gemma3.modeling_gemma3_flax.Gemma3TextModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- property default_frequencies