easydel.modules.gidd.modeling_gidd#

This module provides the core components of the GIDD model, including:

  • GiddMLP: A feed-forward network with squared ReLU activation

  • GiddAttention: An attention mechanism with optional query-key normalization

  • GiddRMSNorm: Root Mean Square normalization layer

  • GiddLayer: A transformer layer combining attention and MLP components

  • GiddModel: The base transformer model

  • GiddForDiffusionLM: A version of the model adapted for diffusion language modeling

The implementation leverages JAX for efficient computation and supports various optimizations including gradient checkpointing, mixed precision, and model parallelism.
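As a rough illustration of two of the optimizations named above, the sketch below keeps parameters in one dtype, computes in another, and wraps the computation in jax.checkpoint so intermediates are recomputed during the backward pass. The toy dense function and its shapes are hypothetical and are not taken from the EasyDeL source; only the dtype / param_dtype split mirrors the attributes documented on this page.

```python
import jax
import jax.numpy as jnp

# Assumption for illustration: parameters live in float32 (param_dtype)
# and are cast to bfloat16 (dtype) only for the computation.
param_dtype = jnp.float32
dtype = jnp.bfloat16

w = jax.random.normal(jax.random.key(0), (8, 4), dtype=param_dtype)
x = jnp.ones((2, 8), dtype=jnp.float32)


def dense(x, w):
    # Cast both operands to the compute dtype before multiplying.
    return jnp.dot(x.astype(dtype), w.astype(dtype))


# jax.checkpoint marks the function for rematerialization: its
# intermediates are recomputed in the backward pass instead of stored.
dense_ckpt = jax.checkpoint(dense)

y = dense_ckpt(x, w)


def loss(w_):
    # Reduce in float32 so the scalar loss is not accumulated in bfloat16.
    return dense_ckpt(x, w_).astype(jnp.float32).sum()


grads = jax.grad(loss)(w)
```

The output y is in the compute dtype, while grads comes back in the parameter dtype because differentiation is taken with respect to the float32 weights.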

class easydel.modules.gidd.modeling_gidd.GiddAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

GIDD-specific attention mechanism with optional query-key normalization.

This attention module implements a multi-head attention mechanism with support for query-key normalization, rotary position embeddings, and flexible attention patterns.

config#

Configuration object containing model parameters.

Type

GiddConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

hidden_size#

Dimensionality of the hidden states.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

use_qk_norm#

Whether to apply normalization to query and key vectors.

Type

bool

qk_norm_eps#

Epsilon value for numerical stability in QK normalization.

Type

float

qk_scale#

Scaling factor for attention scores.

Type

float or ArrayParam

q_proj#

Linear projection for queries.

Type

ParallelLinear

k_proj#

Linear projection for keys.

Type

ParallelLinear

v_proj#

Linear projection for values.

Type

ParallelLinear

o_proj#

Linear projection for outputs.

Type

ParallelLinear

rotary#

Rotary position embedding module.

attention_performer#

Module that performs the actual attention computation.

Type

FlexibleAttentionModule
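The interaction of use_qk_norm, qk_norm_eps and qk_scale can be sketched in plain JAX. This is a minimal sketch under assumptions: the normalization is taken to be an RMS normalization along the head dimension, and qk_scale is taken to multiply the raw scores. The actual GiddAttention implementation may differ on both points, and the function names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def qk_normalize(x, eps=1e-6):
    # Assumed form: RMS-normalize each head vector along the last axis.
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps)


def attention_scores(q, k, qk_scale=1.0, use_qk_norm=True, eps=1e-6):
    if use_qk_norm:
        q = qk_normalize(q, eps)
        k = qk_normalize(k, eps)
    # [batch, q_len, heads, dim] x [batch, k_len, heads, dim]
    #   -> [batch, heads, q_len, k_len]
    return qk_scale * jnp.einsum("bqhd,bkhd->bhqk", q, k)


q = jax.random.normal(jax.random.key(0), (2, 5, 4, 16))
k = jax.random.normal(jax.random.key(1), (2, 5, 4, 16))
scores = attention_scores(q, k, qk_scale=1.0)
```

With this form of normalization every query and key vector has norm close to sqrt(head_dim), so the unscaled scores are bounded by head_dim in absolute value, which is one reason a separate qk_scale is useful.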

concatenate(*, query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], mask_info: MaskInfo, noise_mask: Union[Array, ndarray, bool, number], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None) tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], Callable[[], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]]][source]#

Prepare and concatenate key, value, and attention mask for attention computation.

This method handles the preprocessing of attention inputs, including:

  • Validating and reshaping attention masks

  • Combining attention masks with noise masks

  • Creating a function to initialize attention bias

Parameters
  • query – Query tensor of shape [batch_size, seq_len, num_heads, head_dim].

  • key – Key tensor of shape [batch_size, seq_len, num_heads, head_dim].

  • value – Value tensor of shape [batch_size, seq_len, num_heads, head_dim].

  • mask_info – Mask information (MaskInfo) carrying the binary attention mask, of shape [batch_size, seq_len] or [batch_size, 1, seq_len, seq_len].

  • noise_mask – Binary mask for noise tokens of shape [batch_size, seq_len].

  • cache_view – View into the key/value cache for incremental decoding.

  • cache_metadata – Metadata for cache operations.

Returns

A tuple containing:

  • key: Processed key tensor.

  • value: Processed value tensor.

  • attention_mask: Processed attention mask.

  • init_attention_bias: Function to initialize attention bias.

  • cache_view: Updated cache view.
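The mask preprocessing described for concatenate can be sketched as follows. This is a hedged illustration, not the library's code: build_mask_and_bias is a hypothetical helper, and the rule used to combine the noise mask (a query may attend to a key if the query is a noise token or the key is not) is an assumption chosen for the example. Only the overall shape handling and the lazily evaluated bias function follow the description above.

```python
import jax.numpy as jnp


def build_mask_and_bias(attention_mask, noise_mask, dtype=jnp.float32):
    # Reshape a [batch, seq_len] padding mask to [batch, 1, 1, seq_len]
    # so it broadcasts over heads and query positions.
    if attention_mask.ndim == 2:
        attention_mask = attention_mask.astype(bool)[:, None, None, :]
    else:
        attention_mask = attention_mask.astype(bool)

    # ASSUMED combination rule (illustrative only): a query may attend to
    # a key if the query is a noise token or the key is not a noise token.
    noise = noise_mask.astype(bool)
    noise_q = noise[:, None, :, None]
    noise_k = noise[:, None, None, :]
    combined = attention_mask & (noise_q | ~noise_k)

    def init_attention_bias():
        # Built lazily: 0 where attention is allowed, a large negative
        # value where it is masked out.
        return jnp.where(combined, 0.0, jnp.finfo(dtype).min).astype(dtype)

    return combined, init_attention_bias


attention_mask = jnp.array([[1, 1, 1, 0]])  # last position is padding
noise_mask = jnp.array([[0, 1, 0, 1]])      # positions 1 and 3 are noise
combined, init_attention_bias = build_mask_and_bias(attention_mask, noise_mask)
bias = init_attention_bias()
```

Returning a function rather than the bias array lets the caller skip materializing a [batch, 1, seq_len, seq_len] tensor when the attention backend consumes the boolean mask directly.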

class easydel.modules.gidd.modeling_gidd.GiddForDiffusionLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

GIDD model with a language modeling head for diffusion language modeling tasks.

This model extends the base GiddModel with a language modeling head, making it suitable for language generation tasks in the context of diffusion models.

config#

Configuration for the model.

Type

GiddConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

model#

The base transformer model.

Type

GiddModel

lm_head#

Language modeling head.

Type

ParallelLinear
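The role of lm_head can be illustrated with a plain matrix product: hidden states from the base model are projected to one logit per vocabulary entry at every position. The shapes, the random weights and the variable names below are assumptions for illustration; the real head is a ParallelLinear layer whose options are not shown on this page.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes chosen only for the example.
batch, seq_len, hidden_size, vocab_size = 2, 6, 32, 100

# Stand-in for the output of the base model.
hidden_states = jax.random.normal(
    jax.random.key(0), (batch, seq_len, hidden_size)
)
# Stand-in for the lm_head kernel (no bias assumed).
lm_head_kernel = 0.02 * jax.random.normal(
    jax.random.key(1), (hidden_size, vocab_size)
)

# One logit per vocabulary entry at every sequence position.
logits = hidden_states @ lm_head_kernel
probs = jax.nn.softmax(logits, axis=-1)
```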

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

Returns

The base model, which serves as the decoder.

get_embedding()[source]#

Returns the embedding layer of the module.

Returns

The token embedding layer from the base model.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition.

Note

This is a decoder-only model and does not have an encoder.

Raises

NotImplementedError – Always raised as this is a decoder-only model.

get_lm_head()[source]#

Returns the language model head of the module.

Returns

The language modeling head.

class easydel.modules.gidd.modeling_gidd.GiddLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

A single transformer layer for the GIDD model.

This layer combines a self-attention mechanism with a feed-forward network (MLP), using residual connections and layer normalization. It’s designed to be stacked to form the complete transformer model.

config#

Configuration object containing model parameters.

Type

GiddConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

resid_scale#

Scaling factor for residual connections.

Type

float

self_attn#

Self-attention module.

Type

GiddAttention

mlp#

Feed-forward network module.

Type

GiddMLP

input_layernorm#

Layer normalization before attention.

Type

GiddRMSNorm

post_attention_layernorm#

Layer normalization before MLP.

Type

GiddRMSNorm
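The way these attributes fit together can be sketched as a pre-norm residual block. How resid_scale enters the computation is not stated on this page, so the sketch assumes it multiplies each sublayer output before the residual addition. The attn_fn and mlp_fn arguments are hypothetical stand-ins for self_attn and mlp, and the norm omits the learnable scale for brevity.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Scale-free RMS normalization along the feature axis.
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps)


def layer_forward(x, attn_fn, mlp_fn, resid_scale=1.0):
    # ASSUMPTION: resid_scale multiplies each sublayer output.
    # Attention sublayer: normalize, attend, add back to the input.
    h = x + resid_scale * attn_fn(rms_norm(x))
    # MLP sublayer: normalize, transform, add back.
    return h + resid_scale * mlp_fn(rms_norm(h))


x = jnp.ones((1, 3, 8))
out = layer_forward(
    x,
    attn_fn=lambda z: 0.5 * z,           # toy stand-in for self_attn
    mlp_fn=lambda z: jnp.zeros_like(z),  # toy stand-in for mlp
    resid_scale=2.0,
)
```

In the toy call, the normalized input is close to 1 everywhere, the attention stand-in contributes 0.5, and the scaled residual gives an output close to 2.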

class easydel.modules.gidd.modeling_gidd.GiddMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

GIDD-specific MLP (Multi-Layer Perceptron) implementation.

This MLP uses a two-layer structure with a squared ReLU activation function between the layers. It’s designed to be used within the GIDD transformer layers.

config#

Configuration object containing model parameters.

Type

GiddConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

up_proj#

First linear layer projecting from hidden_size to intermediate_size.

Type

ParallelLinear

down_proj#

Second linear layer projecting from intermediate_size back to hidden_size.

Type

ParallelLinear
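The two-layer structure with a squared ReLU between up_proj and down_proj can be written in a few lines of JAX. This is a minimal sketch: the weight shapes are hypothetical, and biases and sharding are left out, so it shows the computation rather than the GiddMLP implementation itself.

```python
import jax
import jax.numpy as jnp


def squared_relu(x):
    # relu(x) ** 2: zero for negative inputs, quadratic for positive ones.
    return jnp.square(jax.nn.relu(x))


def mlp(x, w_up, w_down):
    # hidden_size -> intermediate_size -> hidden_size
    return squared_relu(x @ w_up) @ w_down


hidden_size, intermediate_size = 8, 32
w_up = jax.random.normal(jax.random.key(0), (hidden_size, intermediate_size))
w_down = jax.random.normal(jax.random.key(1), (intermediate_size, hidden_size))
x = jax.random.normal(jax.random.key(2), (2, 5, hidden_size))

y = mlp(x, w_up, w_down)
activations = squared_relu(jnp.array([-2.0, 0.0, 3.0]))
```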

class easydel.modules.gidd.modeling_gidd.GiddModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Base GIDD model implementation.

This model implements the core transformer architecture of the GIDD model, consisting of an embedding layer, multiple transformer layers, and a final normalization layer. It serves as the foundation for more specialized models like GiddForDiffusionLM.

config#

Configuration for the model.

Type

GiddConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

resid_scale#

Scaling factor for residual connections.

Type

float

embed_tokens#

Token embedding layer.

Type

nn.Embed

layers#

List of transformer layers.

Type

list[GiddLayer]

norm#

Final normalization layer.

Type

GiddRMSNorm

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

Returns

The model itself, as it is a decoder-only model.

get_embedding()[source]#

Returns the embedding layer of the module.

Returns

The token embedding layer.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition.

Note

This is a decoder-only model and does not have an encoder.

Raises

NotImplementedError – Always raised as this is a decoder-only model.

get_lm_head()[source]#

Returns the language model head of the module.

Note

The base model does not have a language model head.

Raises

NotImplementedError – Always raised as the base model does not have an LM head.

class easydel.modules.gidd.modeling_gidd.GiddRMSNorm(*args: Any, **kwargs: Any)[source]#

Bases: Module

Root Mean Square Layer Normalization (RMSNorm) for the GIDD model.

RMSNorm is a simplified variant of LayerNorm that normalizes the input by its root mean square value and applies a learnable scale parameter.

config#

Configuration object containing model parameters.

Type

GiddConfig

epsilon#

Small constant for numerical stability.

Type

float

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

kernel#

Learnable scale parameter.

Type

ArrayParam

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) Array#

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
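Putting the pieces of GiddRMSNorm together, the normalization can be sketched as below, with the scale initialized by the same ones initializer. The upcast to float32 before computing the mean square is an assumption made for numerical stability, and rms_norm is a hypothetical standalone function rather than the module's own method.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, kernel, epsilon=1e-6):
    # ASSUMPTION: compute the statistics in float32, then cast back.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(mean_sq + epsilon)
    # Apply the learnable per-feature scale.
    return (kernel * normed).astype(x.dtype)


# A kernel of ones leaves the normalized values unchanged at initialization.
kernel = jax.nn.initializers.ones(jax.random.key(0), (4,), jnp.float32)
x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm(x, kernel)
```

After normalization the mean of the squared outputs is close to 1, which is the defining property of RMSNorm before the learned scale moves away from its initial value.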