easydel.modules.gemma2.modeling_gemma2

class easydel.modules.gemma2.modeling_gemma2.Gemma2Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Multi-head attention layer with RoPE embeddings for Gemma2 models.

Inherits from UnifiedAttention with the following Gemma2-specific customizations:

- Sliding window attention (layer-specific)
- Custom query pre-attention scalar
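The two customizations listed above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not EasyDeL's implementation: the helpers `sliding_window_causal_mask` and `attention` are hypothetical, queries are assumed to be scaled by `query_pre_attn_scalar ** -0.5` in place of the usual `head_dim ** -0.5`, and a layer is assumed to opt into the sliding window by passing a window size (or `None` for full causal attention).

```python
from typing import Optional

import jax
import jax.numpy as jnp


def sliding_window_causal_mask(seq_len: int, window: Optional[int]) -> jax.Array:
    """Boolean [seq, seq] mask; True means query i may attend to key j (hypothetical helper)."""
    i = jnp.arange(seq_len)[:, None]  # query positions
    j = jnp.arange(seq_len)[None, :]  # key positions
    mask = j <= i  # causal: never attend to future tokens
    if window is not None:
        # Sliding window: keep only the `window` most recent keys, including i itself.
        mask = mask & (i - j < window)
    return mask


def attention(q, k, v, query_pre_attn_scalar: float, window: Optional[int] = None):
    """q, k, v: [seq, heads, head_dim]. Returns [seq, heads, head_dim]."""
    # Assumed Gemma2-style scaling: use query_pre_attn_scalar instead of head_dim.
    q = q * query_pre_attn_scalar ** -0.5
    scores = jnp.einsum("qhd,khd->hqk", q, k)
    mask = sliding_window_causal_mask(q.shape[0], window)
    scores = jnp.where(mask[None, :, :], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)


# Usage: 6 tokens, 2 heads, head_dim 4, local window of 3 tokens.
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 6, 2, 4))
out = attention(q, k, v, query_pre_attn_scalar=4.0, window=3)
```

Passing `window=None` gives ordinary causal attention, so the same function covers both local and global layers.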

class easydel.modules.gemma2.modeling_gemma2.Gemma2DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single decoder layer for Gemma2 models.

Combines multi-head attention and feedforward networks with residual connections and layer normalization to form a complete transformer decoder layer.
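A minimal sketch of how such a layer wires its pieces together, assuming the layout of the reference Gemma2 implementation, in which each sublayer is wrapped by a norm on its input and another on its output before the residual add. The sublayers are passed in as plain callables, the learned norm scales are omitted for brevity, and `rms_norm` and `decoder_layer` are hypothetical names rather than EasyDeL APIs.

```python
from typing import Callable

import jax
import jax.numpy as jnp

Array = jax.Array


def rms_norm(x: Array, eps: float = 1e-6) -> Array:
    # Parameter-free RMS normalization over the last axis (learned scale omitted).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def decoder_layer(x: Array, attn: Callable[[Array], Array], mlp: Callable[[Array], Array]) -> Array:
    """One assumed Gemma2-style block: norm around each sublayer, then residual add."""
    # Attention sublayer: pre-norm, attention, post-norm, residual add.
    x = x + rms_norm(attn(rms_norm(x)))
    # Feedforward sublayer: same pattern with the MLP.
    x = x + rms_norm(mlp(rms_norm(x)))
    return x


# Usage: with sublayers that output zeros, only the residual path remains.
x = jax.random.normal(jax.random.PRNGKey(0), (5, 8))
zero = lambda t: jnp.zeros_like(t)
y = decoder_layer(x, attn=zero, mlp=zero)
print(jnp.allclose(y, x))  # True
```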

class easydel.modules.gemma2.modeling_gemma2.Gemma2ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Gemma2Model, Gemma2Config]

Gemma2 model with a language modeling head for causal language modeling tasks.
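What a language modeling head adds on top of the base model can be sketched as a projection to vocabulary logits plus the standard next-token loss. This is an illustration under assumptions: the projection reuses the token embedding matrix (tied weights, as in the released Gemma2 checkpoints), further Gemma2 logit post-processing is left out, and `lm_logits` and `causal_lm_loss` are hypothetical helpers, not methods of `Gemma2ForCausalLM`.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, embedding):
    """Project hidden states [seq, d] to vocabulary logits [seq, vocab] with tied weights."""
    return hidden @ embedding.T


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: position t predicts token t + 1."""
    shift_logits = logits[:-1]    # predictions made at positions 0..T-2
    shift_labels = input_ids[1:]  # targets are the following tokens
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, shift_labels[:, None], axis=-1)[:, 0]
    return nll.mean()


# Usage with random weights: vocabulary of 8 tokens, hidden size 4.
vocab, d = 8, 4
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab, d))
hidden = jax.random.normal(jax.random.PRNGKey(1), (5, d))
input_ids = jnp.array([1, 4, 2, 7, 3])
loss = causal_lm_loss(lm_logits(hidden, embedding), input_ids)
```

With uniform logits the loss equals `log(vocab)`, a handy sanity check for any causal LM loss.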

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.gemma2.modeling_gemma2.Gemma2ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Gemma2 model with a classification head for sequence-level tasks.
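Because the Gemma2 backbone is decoder-only, a sequence-level prediction has to be read from a single position. A common convention, used for example by Hugging Face sequence-classification heads on decoder models, is to score the hidden state of the last non-padding token. The sketch assumes that convention and right padding; `pool_last_token`, `classify`, and `score_weight` are hypothetical names, and EasyDeL's exact pooling rule should be treated as an assumption.

```python
import jax
import jax.numpy as jnp


def pool_last_token(hidden, attention_mask):
    """hidden: [batch, seq, d]; attention_mask: [batch, seq], 1 for real tokens.

    Returns [batch, d]: the hidden state of the last real token in each row
    (assumes right padding).
    """
    last_idx = attention_mask.sum(axis=-1) - 1  # index of the last non-pad token
    return hidden[jnp.arange(hidden.shape[0]), last_idx]


def classify(hidden, attention_mask, score_weight):
    """score_weight: [d, num_labels]. Returns logits [batch, num_labels]."""
    return pool_last_token(hidden, attention_mask) @ score_weight


# Usage: batch of 2 sequences, the first padded after 2 tokens.
hidden = jnp.arange(12, dtype=jnp.float32).reshape(2, 3, 2)
attention_mask = jnp.array([[1, 1, 0], [1, 1, 1]])
pooled = pool_last_token(hidden, attention_mask)  # rows hidden[0, 1] and hidden[1, 2]
```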

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. This model has a sequence classification head, not an LM head.

class easydel.modules.gemma2.modeling_gemma2.Gemma2MLP(*args: Any, **kwargs: Any)

Bases: Module

Multi-Layer Perceptron module for Gemma2 models.

Implements the feedforward network component of the transformer architecture as a gated linear unit with a configurable activation function.
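A gated feedforward block of this kind can be sketched as `down(act(gate(x)) * up(x))`. The sketch assumes the tanh-approximated GELU that Gemma2 configurations use by default (`hidden_activation="gelu_pytorch_tanh"` in the Hugging Face config), omits biases, and uses the hypothetical name `gated_mlp`; the weight names are illustrative only.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down):
    """x: [..., d_model]; w_gate, w_up: [d_model, d_ff]; w_down: [d_ff, d_model]."""
    gate = jax.nn.gelu(x @ w_gate, approximate=True)  # activated gate branch
    up = x @ w_up                                      # linear branch
    return (gate * up) @ w_down                        # elementwise gating, project back


# Usage with random weights: d_model = 4, d_ff = 16.
d_model, d_ff = 4, 16
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (3, d_model))
w_gate = jax.random.normal(keys[1], (d_model, d_ff))
w_up = jax.random.normal(keys[2], (d_model, d_ff))
w_down = jax.random.normal(keys[3], (d_ff, d_model))
y = gated_mlp(x, w_gate, w_up, w_down)  # shape (3, 4)
```

The elementwise product lets the activated branch decide, per feature, how much of the linear branch passes through, which is what makes the unit "gated".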

class easydel.modules.gemma2.modeling_gemma2.Gemma2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Decoder-only Gemma2 transformer composed of embedding, decoder stack, and final norm.
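The embedding, decoder stack, and final norm pipeline can be sketched as a plain function. It assumes the Gemma-family convention of multiplying token embeddings by `sqrt(hidden_size)` before the first layer, treats each decoder layer as an opaque callable, omits the learned scale of the final norm, and uses the hypothetical name `model_forward`; it is an illustration, not EasyDeL's forward pass.

```python
import math

import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Final RMS normalization over the feature axis (learned scale omitted).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def model_forward(input_ids, embedding, layers):
    """input_ids: [seq] ints; embedding: [vocab, d]; layers: callables [seq, d] -> [seq, d]."""
    d_model = embedding.shape[-1]
    # Look up token embeddings and scale by sqrt(d_model) (assumed Gemma convention).
    x = embedding[input_ids] * math.sqrt(d_model)
    for layer in layers:  # decoder stack
        x = layer(x)
    return rms_norm(x)    # final norm


# Usage with stand-in layers in place of real decoder layers.
vocab, d_model = 10, 8
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab, d_model))
layers = [lambda h: h + 0.1 * jnp.tanh(h)] * 2
out = model_forward(jnp.array([3, 1, 4, 1, 5]), embedding, layers)  # shape (5, 8)
```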

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.gemma2.modeling_gemma2.Gemma2RMSNorm(*args: Any, **kwargs: Any)

Bases: Module

Root Mean Square Layer Normalization for Gemma2 models.

This normalization technique normalizes the inputs by the root mean square, providing stability during training while being computationally efficient.
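Concretely, RMS normalization computes `y = x / sqrt(mean(x**2) + eps) * w` over the feature axis; unlike LayerNorm, there is no mean subtraction and no bias. The sketch below is a minimal version that assumes a plain multiplicative scale `w` (some Gemma implementations instead multiply by `1 + w`), an `eps` of `1e-6`, and float32 statistics; `rms_norm` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
    """y = x / sqrt(mean(x**2) + eps) * scale, normalized over the last axis."""
    x32 = x.astype(jnp.float32)  # compute statistics in float32 for stability
    rms_inv = jax.lax.rsqrt(jnp.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    return (x32 * rms_inv * scale).astype(x.dtype)  # cast back to the input dtype


# Usage: mean(x**2) = (9 + 16) / 2 = 12.5, so y ≈ x / 3.5355 ≈ [[0.8485, 1.1314]].
x = jnp.array([[3.0, 4.0]])
scale = jnp.ones(2)  # a ones-initialized scale leaves the normalized values unchanged
y = rms_norm(x, scale)
```

Because the input is divided by its own RMS, rescaling `x` by a constant leaves the output (almost) unchanged, up to the small `eps` term.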

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)