easydel.modules.gemma2.modeling_gemma2
- class easydel.modules.gemma2.modeling_gemma2.Gemma2Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Multi-head attention layer with RoPE embeddings for Gemma2 models.
Inherits from UnifiedAttention with two Gemma2-specific customizations:
  - Sliding-window attention, enabled on a per-layer basis.
  - A custom query pre-attention scalar.
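The two customizations can be illustrated with a small single-head NumPy sketch. It assumes two conventions from the Hugging Face Transformers Gemma2 reference:

- Queries are scaled by `query_pre_attn_scalar ** -0.5` rather than by `head_dim ** -0.5`.
- On a sliding-window layer, query position i attends only to key positions j with i - window < j <= i.

Whether this module follows exactly these conventions is an assumption. `sliding_window_causal_mask` and `attention` are hypothetical helpers, not EasyDeL's API.

```python
import numpy as np
from typing import Optional, Tuple

def sliding_window_causal_mask(seq_len: int, window: Optional[int]) -> np.ndarray:
    """Hypothetical helper: True where query position i may attend to key position j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    mask = j <= i  # causal: never attend to future tokens
    if window is not None:
        mask &= (i - j) < window  # sliding window: keep only the `window` most recent keys
    return mask

def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
              query_pre_attn_scalar: float,
              window: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Single-head attention; q, k, v have shape (seq_len, head_dim)."""
    q = q * query_pre_attn_scalar ** -0.5  # assumed Gemma2-style query scaling
    scores = q @ k.T
    mask = sliding_window_causal_mask(q.shape[0], window)
    scores = np.where(mask, scores, -np.inf)  # masked positions get zero weight
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

# Usage: with window=3, the last of 6 positions attends only to positions 3, 4 and 5.
rng = np.random.default_rng(0)
x = rng.standard_normal((6, 4))
out, w = attention(x, x, x, query_pre_attn_scalar=4.0, window=3)
```

Passing `window=None` gives ordinary full causal attention. This matches the idea that only some layers use the sliding window.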
- class easydel.modules.gemma2.modeling_gemma2.Gemma2DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Single decoder layer for Gemma2 models.
Combines multi-head attention and a feedforward network with residual connections and layer normalization to form a complete transformer decoder layer.
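A minimal NumPy sketch of how such a layer can wire its sublayers together. It assumes the "sandwich" layout of the Hugging Face Transformers Gemma2 reference, in which both the attention and the feedforward sublayer have an RMSNorm before and after them, followed by a residual add. Whether this module uses exactly that layout is an assumption. `attn_fn` and `mlp_fn` are hypothetical stand-ins, and the norms here have no learnable scale.

```python
import numpy as np
from typing import Callable

Array = np.ndarray

def rms_norm(x: Array, eps: float = 1e-6) -> Array:
    # Scale-free RMSNorm; enough to show where normalization happens.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def decoder_layer(h: Array,
                  attn_fn: Callable[[Array], Array],
                  mlp_fn: Callable[[Array], Array]) -> Array:
    # Attention sublayer: pre-norm -> attention -> post-norm -> residual add.
    residual = h
    h = rms_norm(h)
    h = attn_fn(h)
    h = rms_norm(h)
    h = residual + h
    # Feedforward sublayer with the same pre/post-norm pattern.
    residual = h
    h = rms_norm(h)
    h = mlp_fn(h)
    h = rms_norm(h)
    return residual + h
```

Because each sublayer's output is added back onto its input, a sublayer that outputs zeros leaves the hidden state unchanged. This identity path is what makes deep stacks of such layers trainable.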
- class easydel.modules.gemma2.modeling_gemma2.Gemma2ForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[Gemma2Model, Gemma2Config]
Gemma2 model with a language modeling head for causal language modeling tasks.
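A hedged NumPy sketch of what a causal language-modeling head computes. Final hidden states are projected to vocabulary logits, and training uses cross-entropy between the logits at position t and the token at position t + 1.

Two parts are assumptions here, not facts about this class:

- The projection reuses (ties) the input embedding matrix, as in the Hugging Face Gemma2 reference.
- The names `lm_logits` and `causal_lm_loss` are hypothetical.

```python
import numpy as np

def lm_logits(hidden: np.ndarray, embedding: np.ndarray) -> np.ndarray:
    # hidden: (seq_len, d_model); embedding: (vocab, d_model).
    # Assumed tied head: the input embedding matrix doubles as the output projection.
    return hidden @ embedding.T  # (seq_len, vocab)

def causal_lm_loss(logits: np.ndarray, input_ids: np.ndarray) -> float:
    # Position t predicts token t + 1: drop the last logit row and the first label.
    shift_logits = logits[:-1]
    labels = input_ids[1:]
    # Numerically stable log-softmax over the vocabulary axis.
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    # Mean negative log-likelihood of the true next tokens.
    return float(-log_probs[np.arange(len(labels)), labels].mean())
```

With all-zero logits every token is equally likely, so the loss equals log(vocab_size). This is a handy sanity check for a freshly initialized head.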
- class easydel.modules.gemma2.modeling_gemma2.Gemma2ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Gemma2 text encoder with a classification head for sequence-level tasks.
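A common way to use a decoder-only model as a sequence classifier is to score every position and then read out the score at the last non-padding token. This is how the Hugging Face Transformers decoder `*ForSequenceClassification` models do it.

The NumPy sketch below illustrates that strategy. Whether this class pools the same way is not stated here and is an assumption. `classify_last_token`, `score_w`, and `pad_token_id` are hypothetical names.

```python
import numpy as np

def classify_last_token(hidden: np.ndarray, input_ids: np.ndarray,
                        score_w: np.ndarray, pad_token_id: int) -> np.ndarray:
    # hidden: (batch, seq_len, d_model); score_w: (d_model, num_labels).
    logits = hidden @ score_w  # class scores for every position
    # Index of the last non-padding token per row (works for left or right padding).
    non_pad = input_ids != pad_token_id
    positions = np.arange(input_ids.shape[1])
    last_idx = np.max(np.where(non_pad, positions, -1), axis=1)
    # An all-padding row falls back to index -1, i.e. the final position.
    return logits[np.arange(input_ids.shape[0]), last_idx]  # (batch, num_labels)
```

Pooling at the last real token fits a causal model because only that position has attended to the whole input.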
- class easydel.modules.gemma2.modeling_gemma2.Gemma2MLP(*args: Any, **kwargs: Any)
Bases: Module
Multi-Layer Perceptron module for Gemma2 models.
Implements the feedforward component of the transformer architecture as a gated linear unit with a configurable activation function.
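A gated feedforward block can be sketched in NumPy as `(act(x @ W_gate) * (x @ W_up)) @ W_down`. Two parts are assumptions about this module rather than facts stated here:

- The default activation is the tanh-approximate GELU, as in the Hugging Face Gemma2 configuration (`hidden_activation="gelu_pytorch_tanh"`).
- The weight names are hypothetical.

```python
import numpy as np

def gelu_tanh(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def gated_mlp(x, w_gate, w_up, w_down, act=gelu_tanh):
    # x: (..., d_model); w_gate, w_up: (d_model, d_ff); w_down: (d_ff, d_model).
    gate = act(x @ w_gate)        # activated gate branch
    up = x @ w_up                 # linear branch
    return (gate * up) @ w_down   # elementwise gating, then project back to d_model
```

The `act` parameter stands in for the configurable activation. The gate branch decides, per hidden unit, how much of the linear branch passes through.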
- class easydel.modules.gemma2.modeling_gemma2.Gemma2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Decoder-only Gemma2 transformer composed of embedding, decoder stack, and final norm.
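The three stages named above can be sketched as a NumPy forward pass. Scaling the embeddings by sqrt(d_model) follows the Hugging Face Gemma-family reference and is an assumption about this class. The layers are hypothetical callables standing in for decoder layers.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Scale-free RMSNorm used as the final norm in this sketch.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def model_forward(input_ids, embedding, layers, eps=1e-6):
    # 1) Embedding lookup, scaled by sqrt(d_model) (assumed Gemma-style normalizer).
    d_model = embedding.shape[1]
    h = embedding[input_ids] * np.sqrt(d_model)
    # 2) Decoder stack: each layer maps (seq_len, d_model) -> (seq_len, d_model).
    for layer in layers:
        h = layer(h)
    # 3) Final normalization before any task-specific head.
    return rms_norm(h, eps)
```

A head such as the causal LM projection or the classification scorer would consume the normalized output of this stack.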
- class easydel.modules.gemma2.modeling_gemma2.Gemma2RMSNorm(*args: Any, **kwargs: Any)
Bases: Module
Root Mean Square Layer Normalization for Gemma2 models.
This normalization rescales inputs by their root mean square without subtracting the mean (unlike LayerNorm). It stabilizes training at a lower computational cost than LayerNorm.
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
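A minimal NumPy sketch of RMS normalization with a learnable per-feature scale. The scale is initialized to ones, matching the ones initializer documented for kernel_init. The function names and the default `eps` are illustrative assumptions, not this module's attributes.

```python
import numpy as np

def init_rms_scale(dim, dtype=np.float32):
    # Ones initializer: at init the norm is a pure RMS rescaling.
    return np.ones((dim,), dtype=dtype)

def rms_norm(x, scale, eps=1e-6):
    # Divide by the root mean square over the feature axis; no mean subtraction.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return (x / rms) * scale
```

For example, the row [3, 4] has an RMS of sqrt(12.5), so after normalization the mean of its squared entries is about 1.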