easydel.modules.gemma.modeling_gemma
- class easydel.modules.gemma.modeling_gemma.GemmaAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Multi-head attention layer with rotary position embeddings (RoPE) for Gemma models.
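A minimal sketch of how rotary embeddings rotate query/key vectors, assuming the "rotate half" channel layout used by Hugging Face-style Llama/Gemma code and a base of 10000. The helper names rope_frequencies and apply_rope are hypothetical and are not part of EasyDeL's API.

```python
import jax.numpy as jnp


def rope_frequencies(head_dim: int, positions: jnp.ndarray, base: float = 10000.0):
    """Return (cos, sin) tables of shape (seq_len, head_dim // 2).

    Hypothetical helper; base=10000 is an assumed default.
    """
    # One inverse frequency per pair of channels.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    angles = positions[:, None].astype(jnp.float32) * inv_freq[None, :]
    return jnp.cos(angles), jnp.sin(angles)


def apply_rope(x: jnp.ndarray, cos: jnp.ndarray, sin: jnp.ndarray) -> jnp.ndarray:
    """Rotate x of shape (seq_len, num_heads, head_dim) using the rotate-half layout."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # Broadcast the tables over the heads axis.
    cos = cos[:, None, :]
    sin = sin[:, None, :]
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```

Because each channel pair is rotated rather than scaled, vector norms are preserved and position 0 is left unchanged; only the relative angle between query and key positions affects their dot product.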
- class easydel.modules.gemma.modeling_gemma.GemmaDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Single decoder layer for Gemma models.
Combines multi-head attention and feedforward networks with residual connections and layer normalization to form a complete transformer decoder layer.
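A minimal sketch of the residual wiring inside one decoder layer, assuming the pre-norm arrangement used by Gemma (each sub-block normalizes its input, and the residual is added afterwards). The function decoder_layer is hypothetical; attention, MLP and the two norms are passed in as stand-in callables rather than EasyDeL modules.

```python
from typing import Callable

import jax.numpy as jnp

Array = jnp.ndarray


def decoder_layer(
    x: Array,
    attn: Callable[[Array], Array],
    mlp: Callable[[Array], Array],
    input_norm: Callable[[Array], Array],
    post_attn_norm: Callable[[Array], Array],
) -> Array:
    """Pre-norm transformer block (hypothetical sketch): normalize, transform, add residual."""
    # Attention sub-block with its residual connection.
    h = x + attn(input_norm(x))
    # Feedforward sub-block with its own norm and residual connection.
    return h + mlp(post_attn_norm(h))
```

Keeping the residual path free of normalization lets the identity signal pass through every layer unchanged, which is the usual reason pre-norm stacks train stably at depth.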
- class easydel.modules.gemma.modeling_gemma.GemmaForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[GemmaModel, GemmaConfig]
Gemma model with a causal language modeling head.
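A generic sketch of the next-token objective a causal language modeling head is trained with: the logit at position t is scored against the token at position t + 1. The function causal_lm_loss is hypothetical; how GemmaForCausalLM itself computes its loss (padding masks, label handling, and so on) is not described on this page.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(logits: jnp.ndarray, input_ids: jnp.ndarray) -> jnp.ndarray:
    """Mean next-token cross-entropy (hypothetical sketch, no padding mask).

    logits: (batch, seq_len, vocab); input_ids: (batch, seq_len).
    """
    # Drop the last logit and the first token so position t predicts token t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    # Pick out the log-probability assigned to each true next token.
    token_ll = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return -token_ll.mean()
```

With all-zero logits every token is equally likely, so the loss equals log(vocab_size), a handy sanity check for an untrained head.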
- class easydel.modules.gemma.modeling_gemma.GemmaForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Gemma encoder stack with a linear classification head for sequence labels.
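Decoder-only sequence classifiers in the Hugging Face style usually score the hidden state of the last non-padding token, since only that position has attended to the whole sequence. The sketch below assumes that convention and right padding; whether EasyDeL's head pools this way is an assumption, and last_token_classify is a hypothetical name.

```python
import jax.numpy as jnp


def last_token_classify(hidden: jnp.ndarray, attention_mask: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical sketch of last-token pooling plus a linear head.

    hidden: (batch, seq_len, d); attention_mask: (batch, seq_len) of 0/1;
    w: (d, num_labels). Assumes right padding.
    """
    # Index of the last non-padding token in each row.
    last_idx = attention_mask.sum(axis=-1) - 1
    pooled = hidden[jnp.arange(hidden.shape[0]), last_idx]  # (batch, d)
    return pooled @ w  # (batch, num_labels)
```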
- class easydel.modules.gemma.modeling_gemma.GemmaMLP(*args: Any, **kwargs: Any)
Bases: Module
Multi-layer perceptron module for Gemma models.
Implements the feedforward network of the transformer block as a gated linear unit with a configurable activation function.
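A minimal sketch of a gated linear unit feedforward network: an activated gate projection is multiplied elementwise by a linear up projection, then projected back down. The tanh-approximated GELU is an assumption matching Gemma's published configuration; the module's actual activation comes from its config, and gated_mlp is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x: jnp.ndarray, w_gate: jnp.ndarray, w_up: jnp.ndarray, w_down: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical GeGLU sketch. x: (..., d); w_gate, w_up: (d, f); w_down: (f, d)."""
    # Only the gate branch passes through the (assumed) activation.
    gate = jax.nn.gelu(x @ w_gate, approximate=True)
    # Elementwise gating of the linear up branch, then project back to d.
    return (gate * (x @ w_up)) @ w_down
```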
- class easydel.modules.gemma.modeling_gemma.GemmaModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Decoder-only Gemma transformer wiring embeddings, decoder blocks, and output norm.
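A minimal sketch of that wiring. The reference Gemma implementation multiplies token embeddings by sqrt(hidden_size) before the first layer; this sketch assumes GemmaModel does the same. The function gemma_forward is hypothetical, and layers and the final norm are stand-in callables.

```python
import math
from typing import Callable, Sequence

import jax.numpy as jnp


def gemma_forward(
    input_ids: jnp.ndarray,
    embedding: jnp.ndarray,
    layers: Sequence[Callable[[jnp.ndarray], jnp.ndarray]],
    final_norm: Callable[[jnp.ndarray], jnp.ndarray],
) -> jnp.ndarray:
    """Hypothetical sketch. input_ids: (batch, seq_len) ints; embedding: (vocab, d)."""
    hidden = embedding[input_ids]
    # Scale embeddings by sqrt(hidden_size), as in the reference Gemma model.
    hidden = hidden * jnp.asarray(math.sqrt(embedding.shape[1]), dtype=hidden.dtype)
    for layer in layers:
        hidden = layer(hidden)
    return final_norm(hidden)
```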
- class easydel.modules.gemma.modeling_gemma.GemmaRMSNorm(*args: Any, **kwargs: Any)
Bases: Module
Root mean square layer normalization for Gemma models.
Each input vector is divided by its root mean square over the feature axis and then multiplied by a learned per-feature scale. Unlike LayerNorm, RMSNorm does not subtract the mean, which makes it cheaper to compute while still stabilizing training.
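A minimal sketch of standard RMSNorm with a directly applied scale, which would be initialized to ones to match kernel_init. Hugging Face's Gemma implementation instead stores the weight as an offset and multiplies by (1 + weight); which convention EasyDeL uses is not stated on this page, so the plain scale and eps=1e-6 here are assumptions, and rms_norm is a hypothetical name.

```python
import jax.numpy as jnp


def rms_norm(x: jnp.ndarray, scale: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    """Hypothetical sketch: normalize the last axis by its RMS, then apply a per-feature scale."""
    # Compute in float32 for numerical stability, then cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    rms = jnp.sqrt(jnp.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return ((x32 / rms) * scale).astype(x.dtype)
```

With a scale of ones the output has unit root mean square per vector, and its direction matches the input.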
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) → Array
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)