easydel.modules.gidd.modeling_gidd#
This module provides the core components of the GIDD model, including:
- GiddMLP: A feed-forward network with squared ReLU activation
- GiddAttention: An attention mechanism with optional query-key normalization
- GiddRMSNorm: Root Mean Square normalization layer
- GiddLayer: A transformer layer combining attention and MLP components
- GiddModel: The base transformer model
- GiddForDiffusionLM: A version of the model adapted for diffusion language modeling
The implementation leverages JAX for efficient computation and supports various optimizations including gradient checkpointing, mixed precision, and model parallelism.
- class easydel.modules.gidd.modeling_gidd.GiddAttention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
GIDD-specific attention mechanism with optional query-key normalization.
This attention module implements a multi-head attention mechanism with support for query-key normalization, rotary position embeddings, and flexible attention patterns.
- config#
Configuration object containing model parameters.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- hidden_size#
Dimensionality of the hidden states.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- use_qk_norm#
Whether to apply normalization to query and key vectors.
- Type
bool
- qk_norm_eps#
Epsilon value for numerical stability in QK normalization.
- Type
float
- qk_scale#
Scaling factor for attention scores.
- Type
float or ArrayParam
- q_proj#
Linear projection for queries.
- Type
- k_proj#
Linear projection for keys.
- Type
- v_proj#
Linear projection for values.
- Type
- o_proj#
Linear projection for outputs.
- Type
- rotary#
Rotary position embedding module.
- attention_performer#
Module that performs the actual attention computation.
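The role of `use_qk_norm`, `qk_norm_eps`, and `qk_scale` can be illustrated with a minimal sketch. This is not the EasyDeL implementation: the helper names `qk_normalize` and `attention_scores` are hypothetical, and the normalization is assumed to be an RMS normalization over `head_dim` with no learnable scale, which is one common form of query-key normalization.

```python
import jax
import jax.numpy as jnp


def qk_normalize(x, eps=1e-6):
    # Assumed form of QK normalization: RMS-normalize along head_dim.
    ms = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
    return (x * jax.lax.rsqrt(ms + eps)).astype(x.dtype)


def attention_scores(query, key, qk_scale=1.0, use_qk_norm=True, eps=1e-6):
    # query, key: [batch_size, seq_len, num_heads, head_dim]
    if use_qk_norm:
        query = qk_normalize(query, eps)
        key = qk_normalize(key, eps)
    # Result: [batch_size, num_heads, q_len, k_len]
    return jnp.einsum("bqhd,bkhd->bhqk", query, key) * qk_scale


q = jax.random.normal(jax.random.key(0), (2, 5, 4, 8))
k = jax.random.normal(jax.random.key(1), (2, 5, 4, 8))
scores = attention_scores(q, k, qk_scale=1.0 / 8)
```

With normalized queries and keys, the magnitude of each score is bounded by `head_dim * qk_scale`, which is one reason a separate scaling factor is useful in this setup.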
- concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], mask_info: MaskInfo, noise_mask: Union[Array, ndarray, bool, number], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None) -> tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Callable[[], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]][source]#
Prepare and concatenate key, value, and attention mask for attention computation.
This method handles the preprocessing of attention inputs, including:
- Validating and reshaping attention masks
- Combining attention masks with noise masks
- Creating a function to initialize attention bias
- Parameters
query – Query tensor of shape [batch_size, seq_len, num_heads, head_dim].
key – Key tensor of shape [batch_size, seq_len, num_heads, head_dim].
value – Value tensor of shape [batch_size, seq_len, num_heads, head_dim].
attention_mask – Binary mask of shape [batch_size, seq_len] or [batch_size, 1, seq_len, seq_len].
noise_mask – Binary mask for noise tokens of shape [batch_size, seq_len].
cache_view – View into the key/value cache for incremental decoding.
cache_metadata – Metadata for cache operations.
- Returns
A tuple containing:
- key: Processed key tensor.
- value: Processed value tensor.
- attention_mask: Processed attention mask.
- init_attention_bias: Function to initialize attention bias.
- cache_view: Updated cache view.
- class easydel.modules.gidd.modeling_gidd.GiddForDiffusionLM(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
GIDD model with a language modeling head for diffusion language modeling tasks.
This model extends the base GiddModel with a language modeling head, making it suitable for language modeling and generation in a diffusion setting.
- config#
Configuration for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- lm_head#
Language modeling head.
- Type
- get_decoder()[source]#
Returns the decoder part of the model’s graph definition.
- Returns
The base model, which serves as the decoder.
- get_embedding()[source]#
Returns the embedding layer of the module.
- Returns
The token embedding layer from the base model.
- class easydel.modules.gidd.modeling_gidd.GiddLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
A single transformer layer for the GIDD model.
This layer combines a self-attention mechanism with a feed-forward network (MLP), using residual connections and layer normalization. It’s designed to be stacked to form the complete transformer model.
- config#
Configuration object containing model parameters.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- resid_scale#
Scaling factor for residual connections.
- Type
float
- self_attn#
Self-attention module.
- Type
- input_layernorm#
Layer normalization before attention.
- Type
- post_attention_layernorm#
Layer normalization before MLP.
- Type
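The layer structure described above (normalization before attention, normalization before the MLP, scaled residual connections) can be sketched as follows. This is an illustration under stated assumptions, not the EasyDeL code: `layer_forward` and `rms_norm` are hypothetical helpers, the attention and MLP sub-modules are passed in as plain functions, and `resid_scale` is assumed to multiply each sub-module output before it is added to the residual stream.

```python
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Stand-in for input_layernorm / post_attention_layernorm (no learnable scale).
    return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def layer_forward(x, attn_fn, mlp_fn, resid_scale=1.0):
    # Assumption: resid_scale scales each branch before the residual add.
    h = x + resid_scale * attn_fn(rms_norm(x))
    return h + resid_scale * mlp_fn(rms_norm(h))


x = jnp.ones((2, 5, 8))
out = layer_forward(
    x,
    attn_fn=lambda h: 0.5 * h,  # placeholder for self_attn
    mlp_fn=lambda h: 2.0 * h,   # placeholder for the MLP
    resid_scale=0.1,
)
```

Because each branch only adds to the residual stream, layers of this shape can be stacked without changing the hidden-state dimensionality.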
- class easydel.modules.gidd.modeling_gidd.GiddMLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
GIDD-specific MLP (Multi-Layer Perceptron) implementation.
This MLP uses a two-layer structure with a squared ReLU activation function between the layers. It’s designed to be used within the GIDD transformer layers.
- config#
Configuration object containing model parameters.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- up_proj#
First linear layer projecting from hidden_size to intermediate_size.
- Type
- down_proj#
Second linear layer projecting from intermediate_size back to hidden_size.
- Type
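As a rough illustration of the two-layer structure with a squared ReLU activation, the following sketch uses plain weight matrices in place of the `up_proj` and `down_proj` modules. The function names, the absence of bias terms, and the weight initialization are assumptions made for the example only.

```python
import jax
import jax.numpy as jnp


def squared_relu(x):
    # Squared ReLU: relu(x) ** 2
    return jnp.square(jax.nn.relu(x))


def mlp(x, w_up, w_down):
    # hidden_size -> intermediate_size -> hidden_size (biases omitted by assumption)
    return squared_relu(x @ w_up) @ w_down


hidden_size, intermediate_size = 8, 32
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
w_up = jax.random.normal(k1, (hidden_size, intermediate_size)) * 0.02
w_down = jax.random.normal(k2, (intermediate_size, hidden_size)) * 0.02
x = jax.random.normal(k3, (2, 5, hidden_size))
y = mlp(x, w_up, w_down)
```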
- class easydel.modules.gidd.modeling_gidd.GiddModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Base GIDD model implementation.
This model implements the core transformer architecture of the GIDD model, consisting of an embedding layer, multiple transformer layers, and a final normalization layer. It serves as the foundation for more specialized models like GiddForDiffusionLM.
- config#
Configuration for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- resid_scale#
Scaling factor for residual connections.
- Type
float
- embed_tokens#
Token embedding layer.
- Type
nn.Embed
- norm#
Final normalization layer.
- Type
- get_decoder()[source]#
Returns the decoder part of the model’s graph definition.
- Returns
The model itself, as it is a decoder-only model.
- get_embedding()[source]#
Returns the embedding layer of the module.
- Returns
The token embedding layer.
- class easydel.modules.gidd.modeling_gidd.GiddRMSNorm(*args: Any, **kwargs: Any)[source]#
Bases: Module
Root Mean Square Layer Normalization (RMSNorm) for the GIDD model.
RMSNorm is a simplified variant of LayerNorm that normalizes the input by its root mean square value and applies a learnable scale parameter.
- config#
Configuration object containing model parameters.
- Type
- epsilon#
Small constant for numerical stability.
- Type
float
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- kernel#
Learnable scale parameter.
- Type
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array#
An initializer that returns a constant array full of ones.
The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
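The RMSNorm computation described for GiddRMSNorm can be sketched in a few lines. This is a generic RMSNorm, not necessarily identical to the EasyDeL layer: the function name `rms_norm`, the float32 upcast, and the direct multiplication by `kernel` are assumptions. The kernel is initialized to ones, matching `kernel_init` above.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, kernel, epsilon=1e-6):
    # Normalize by the root mean square over the last axis, then apply the scale.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(mean_sq + epsilon)
    return (kernel * normed).astype(x.dtype)


hidden_size = 8
kernel = jax.nn.initializers.ones(jax.random.key(42), (hidden_size,), jnp.float32)
x = 3.0 * jax.random.normal(jax.random.key(0), (2, 5, hidden_size))
y = rms_norm(x, kernel)
```

With a kernel of ones, the mean square of each output vector is approximately 1 regardless of the input scale.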