easydel.modules.gemma.modeling_gemma

class easydel.modules.gemma.modeling_gemma.GemmaAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Multi-head attention layer with rotary position embeddings (RoPE) for Gemma models.

Inherits from UnifiedAttention.
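To illustrate what applying RoPE to queries and keys involves, here is a minimal sketch of the common "rotate-half" formulation in jax.numpy. The function names, the base=10000.0 default, and the (seq_len, num_heads, head_dim) layout are assumptions for illustration. This is not the UnifiedAttention code path.

```python
import jax.numpy as jnp


def rope_frequencies(seq_len: int, head_dim: int, base: float = 10000.0):
    """Return (cos, sin) tables of shape (seq_len, head_dim) for rotate-half RoPE."""
    # One inverse frequency per pair of channels (channel i is paired with i + head_dim // 2).
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)               # (seq_len, head_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)   # (seq_len, head_dim)
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, cos, sin):
    """Rotate x of shape (seq_len, num_heads, head_dim) by position-dependent angles."""
    # Add a heads axis so the (seq_len, head_dim) tables broadcast over heads.
    cos = cos[:, None, :]
    sin = sin[:, None, :]
    return x * cos + rotate_half(x) * sin
```

Because each channel pair is rotated rather than scaled, per-position vector norms are preserved, and the dot product between a rotated query and a rotated key depends only on their relative offset.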

class easydel.modules.gemma.modeling_gemma.GemmaDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single decoder layer for Gemma models.

Combines multi-head attention (GemmaAttention) and a feedforward network (GemmaMLP), each preceded by RMS normalization and wrapped in a residual connection, to form a complete transformer decoder layer.
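The pre-norm residual wiring of such a layer can be sketched as follows. The attention, MLP, and norm sub-modules are passed in as plain callables, and the names decoder_layer, input_norm, and pre_mlp_norm are hypothetical; the real layer also threads masks, positions, and caches, which are omitted here.

```python
from typing import Callable

import jax

Array = jax.Array


def decoder_layer(
    hidden: Array,
    attn: Callable[[Array], Array],
    mlp: Callable[[Array], Array],
    input_norm: Callable[[Array], Array],
    pre_mlp_norm: Callable[[Array], Array],
) -> Array:
    """Pre-norm decoder layer: normalize, transform, then add the residual."""
    # Attention sub-block: the residual stream bypasses the norm and attention.
    hidden = hidden + attn(input_norm(hidden))
    # Feedforward sub-block: same pattern with a second norm.
    hidden = hidden + mlp(pre_mlp_norm(hidden))
    return hidden
```

Keeping the norms inside the residual branches leaves an unnormalized identity path through every layer, which is the usual motivation for pre-norm stacks.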

class easydel.modules.gemma.modeling_gemma.GemmaForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[GemmaModel, GemmaConfig]

Gemma model with a causal language modeling head.
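What a causal language modeling head computes can be sketched in a few lines: project final hidden states to vocabulary logits, then score each position against the next token. This sketch assumes, as is common for Gemma checkpoints, that the head reuses (ties) the input embedding matrix; lm_logits and causal_lm_loss are hypothetical names, and whether GemmaForCausalLM computes its loss exactly this way is not stated here.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, embedding):
    """Project hidden states (seq, d) to vocabulary logits (seq, vocab).

    Assumes a tied head: the (vocab, d) input embedding matrix is reused
    as the output projection. An untied head would use its own weight.
    """
    return hidden @ embedding.T


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: position t predicts token t + 1."""
    # The last position has no target and the first token has no prediction.
    shifted_logits = logits[:-1]
    targets = input_ids[1:]
    log_probs = jax.nn.log_softmax(shifted_logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -token_log_probs.mean()
```

With all-zero logits every token is equally likely, so the loss equals log(vocab_size), a handy sanity check.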

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.gemma.modeling_gemma.GemmaForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Gemma decoder stack with a linear classification head for sequence-level labels.
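Because a decoder-only model has no [CLS]-style summary token, classifiers built on it commonly pool the hidden state of the last non-padding token and feed it to a linear score layer. The sketch below shows that pattern under two assumptions: right padding and a bias-free head. It is not confirmed to match this class's pooling, and classify_last_token is a hypothetical name.

```python
import jax.numpy as jnp


def classify_last_token(hidden, attention_mask, score_weight):
    """Pool each sequence at its last non-padding token and score it.

    hidden:         (batch, seq, d) final hidden states
    attention_mask: (batch, seq), 1 for real tokens, 0 for right padding
    score_weight:   (d, num_labels) linear head without bias
    """
    # Index of the last real token in each row (assumes right padding).
    last_index = attention_mask.sum(axis=-1) - 1
    pooled = hidden[jnp.arange(hidden.shape[0]), last_index]  # (batch, d)
    return pooled @ score_weight                               # (batch, num_labels)
```

Under causal attention only the last position has seen the whole sequence, which is why it, rather than the first token, is the usual pooling choice.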

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. This model has a sequence classification head, not an LM head.

class easydel.modules.gemma.modeling_gemma.GemmaMLP(*args: Any, **kwargs: Any)

Bases: Module

Multi-Layer Perceptron module for Gemma models.

Implements the feedforward network component of the transformer architecture as a gated linear unit with a configurable activation function.
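A gated linear unit of this kind can be sketched as three projections: an activated gate branch, a linear "up" branch, and a "down" projection applied to their elementwise product. The weight names and the jax.nn.gelu default (JAX's tanh approximation, a GeGLU-style choice) are assumptions for illustration; the module reads its actual activation from the model config.

```python
import jax


def gated_mlp(x, w_gate, w_up, w_down, activation=jax.nn.gelu):
    """Gated feedforward: (activation(x @ w_gate) * (x @ w_up)) @ w_down.

    x: (..., d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model).
    """
    gate = activation(x @ w_gate)  # activated gate branch, (..., d_ff)
    up = x @ w_up                  # linear value branch, (..., d_ff)
    return (gate * up) @ w_down    # elementwise gating, then project back to d_model
```

Compared with a plain two-layer MLP, the extra up projection lets the gate modulate each hidden unit multiplicatively, at the cost of one additional d_model x d_ff matrix.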

class easydel.modules.gemma.modeling_gemma.GemmaModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Decoder-only Gemma transformer that chains the token embeddings, the stack of decoder layers, and the final output normalization.
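The embed, decode, and normalize pipeline can be sketched as below. Layers and the final norm are plain callables, gemma_model_forward is a hypothetical name, and the sqrt(hidden_size) embedding scaling follows the published Gemma architecture but is an assumption about this module; masks, positions, and caching are left out.

```python
import math
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def gemma_model_forward(
    input_ids: jax.Array,
    embedding: jax.Array,
    layers: Sequence[Callable[[jax.Array], jax.Array]],
    final_norm: Callable[[jax.Array], jax.Array],
) -> jax.Array:
    """Embed tokens, run the decoder stack, and apply the output norm."""
    hidden = embedding[input_ids]  # (seq, d) token lookup
    # Gemma-style scaling of embeddings by sqrt(hidden_size) (assumed here).
    hidden = hidden * jnp.asarray(math.sqrt(embedding.shape[-1]), dtype=hidden.dtype)
    for layer in layers:
        hidden = layer(hidden)
    return final_norm(hidden)
```

The output of this base model is a hidden state per position; turning it into logits or class scores is left to a head such as the ones in GemmaForCausalLM or GemmaForSequenceClassification.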

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.gemma.modeling_gemma.GemmaRMSNorm(*args: Any, **kwargs: Any)

Bases: Module

Root Mean Square Layer Normalization for Gemma models.

It divides each input vector by its root mean square over the feature axis and applies a learned per-feature scale. Unlike LayerNorm, it does not subtract the mean, which makes it cheaper to compute while still stabilizing training.
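A minimal RMSNorm sketch in jax.numpy follows. The rms_norm name, the eps=1e-6 default, and the unit_offset flag are assumptions for illustration: Gemma's reference implementation in Hugging Face transformers multiplies by (1 + weight), while a plain RMSNorm multiplies by weight, and which convention this module uses is not stated here.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6, unit_offset=False):
    """Scale x by the reciprocal RMS of its last axis, then by a learned weight.

    unit_offset=True multiplies by (1 + weight) instead of weight
    (the Gemma reference convention; hypothetical flag for this sketch).
    """
    # Compute the statistics in float32 for stability, then cast back.
    x32 = x.astype(jnp.float32)
    normed = x32 * jax.lax.rsqrt(jnp.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    scale = (1.0 + weight) if unit_offset else weight
    return (normed * scale).astype(x.dtype)
```

With a unit scale, every output row has a root mean square of approximately 1, regardless of the input's magnitude.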

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)