easydel.modules.llama4.modeling_llama4


class easydel.modules.llama4.modeling_llama4.Llama4CausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None)[source]#

Bases: ModelOutput

Base class for Llama4Vision causal language model (or autoregressive) outputs.

Parameters
  • loss (chex.Array of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).

  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (tuple(tuple(chex.Array)), optional, returned when config.use_cache=True) – Tuple of tuple(chex.Array) of length config.n_layers, with each inner tuple holding 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). The class signature annotates this field as TransformerCache | None.

    Contains pre-computed hidden states (keys and values in the self-attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(chex.Array), optional, returned when config.output_hidden_states=True) – Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer), each of shape (batch_size, sequence_length, hidden_size).

    Hidden states of the model at the output of each layer plus the optional initial embedding outputs.

  • attentions (tuple(chex.Array), optional, returned when config.output_attentions=True) – Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • image_hidden_states (chex.Array, optional) – A chex.Array of size (batch_size * num_patches, num_images, sequence_length, hidden_size). Image hidden states produced by the vision encoder after projecting the last hidden state.

attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
classmethod from_dict(data: dict[str, Any]) -> T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) -> T#

Deserializes a JSON string into a PyTree object.

hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None#
image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() -> dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) -> str#

Serializes the PyTree object to a JSON string.
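The helpers above (from_dict, from_json, to_dict, to_json, replace) follow a familiar immutable-record pattern. The sketch below illustrates that pattern with a standard-library dataclass. ToyOutput and its two fields are hypothetical stand-ins, not the EasyDeL class, and real array fields would need their own JSON encoding.

```python
import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for an output container (not the EasyDeL class)."""

    loss: Optional[float] = None
    logits: Optional[list] = None

    def to_dict(self) -> dict[str, Any]:
        # Serialize every field into a plain dictionary.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToyOutput":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyOutput":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "ToyOutput":
        # Returns a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)


out = ToyOutput(logits=[0.1, 0.9])
updated = out.replace(loss=0.5)
round_trip = ToyOutput.from_json(updated.to_json())
```

Because the record is frozen, replace is the only way to "change" a field, which is the usual reason such containers expose it.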

class easydel.modules.llama4.modeling_llama4.Llama4ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[Llama4TextModel, Llama4TextConfig]

Llama4 model with a Causal Language Modeling head.

class easydel.modules.llama4.modeling_llama4.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama4 Vision model for conditional text generation based on image inputs.

Combines a vision tower and a language model with a multi-modal projector.

Note: Llama4 has a unique architecture where the language_model is already a complete Llama4ForCausalLM (with its own lm_head), unlike other VLMs where the base model doesn’t include the lm_head.

config#

Configuration object.

Type

Llama4Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

Class Attributes:

  • _task_type – IMAGE_TEXT_TO_TEXT task type.

  • _model_type – “llama4” model identifier.

  • _supports_video – True (Llama4 supports video input).

  • _uses_mrope – False (uses standard RoPE).

get_decoder()[source]#

Returns the decoder part of the model.

get_embedding()[source]#

Returns the embedding layer.

get_encoder()[source]#

Returns the encoder part of the model (vision tower).

get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) -> Union[Array, ndarray, bool, number][source]#

Extracts and projects image features from the vision tower.

Parameters

pixel_values (chex.Array) – Input pixel values for the images.

Returns

Processed image features ready for the language model.

Return type

chex.Array

get_language_model() -> Module[source]#

Returns the language model component.

get_lm_head()[source]#

Returns the language model head.

get_projector() -> Module[source]#

Returns the multimodal projector component.

get_vision_tower() -> Module[source]#

Returns the vision tower component.

init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)[source]#

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they influence the initial cache state. Defaults to None, which is usually treated as 0.

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
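As a rough mental model of what a standard (non-paged) key-value cache holds, the NumPy sketch below preallocates per-layer key and value buffers and writes new positions into them. Everything here (init_toy_cache, write_step, the dictionary layout, the axis order) is a hypothetical illustration; it does not model starts, sharding, quantization, or the real TransformerCache API.

```python
import numpy as np


def init_toy_cache(num_layers, batch_size, max_length, num_kv_heads, head_dim,
                   dtype=np.float32):
    """Allocate zeroed key/value buffers for every layer (hypothetical layout)."""
    shape = (batch_size, num_kv_heads, max_length, head_dim)
    return [
        {"key": np.zeros(shape, dtype), "value": np.zeros(shape, dtype), "index": 0}
        for _ in range(num_layers)
    ]


def write_step(layer_cache, new_key, new_value):
    """Write new positions at the current index, then advance the index."""
    start = layer_cache["index"]
    stop = start + new_key.shape[2]  # axis 2 is the sequence axis
    layer_cache["key"][:, :, start:stop] = new_key
    layer_cache["value"][:, :, start:stop] = new_value
    layer_cache["index"] = stop


cache = init_toy_cache(num_layers=2, batch_size=1, max_length=8,
                       num_kv_heads=2, head_dim=4)
prefill = np.ones((1, 2, 3, 4), np.float32)  # three prompt positions
write_step(cache[0], prefill, prefill)
```

Preallocating to max_length keeps buffer shapes static, which is generally what JIT-compiled decoding loops want.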

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Int[Array, 'batch seq_len'], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: jaxtyping.Bool[Array, 'batch seq_len'] | None = None)[source]#

Prepares inputs for text generation, including pixel values if provided.

Parameters
  • input_ids (chex.Array) – Initial input token IDs.

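  • max_length (int) – Maximum generation length.

  • pad_token_id (int) – The ID of the padding token.

  • starts (int | None) – Optional starting positions for the cache sequences. Defaults to None.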

  • pixel_values (Optional[chex.Array]) – Pixel values for image input.

  • attention_mask (Optional[chex.Array]) – Attention mask.

Returns

Model inputs ready for generation.

Return type

dict

update_inputs_for_generation(model_outputs, model_kwargs)[source]#

Updates model inputs for the next step of generation, removing pixel values after the first step.

Parameters
  • model_outputs – Outputs from the previous generation step.

  • model_kwargs – Current keyword arguments for the model.

Returns

Updated model keyword arguments.

Return type

dict
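The "remove pixel values after the first step" behaviour can be pictured with a plain dictionary. This is only a sketch: drop_pixel_values_after_first_step is a hypothetical helper, and the real method also consumes model_outputs (for example, to carry the cache forward), which is not modelled here.

```python
def drop_pixel_values_after_first_step(model_kwargs: dict) -> dict:
    """Return a copy of the kwargs without `pixel_values`.

    Image inputs are only needed on the first (prefill) call; later decoding
    steps work from the cached state, so the key is dropped.
    """
    updated = dict(model_kwargs)  # shallow copy; the caller's dict is untouched
    updated.pop("pixel_values", None)
    return updated


step0_kwargs = {"pixel_values": [[0.0, 1.0]], "attention_mask": [1, 1, 1]}
step1_kwargs = drop_pixel_values_after_first_step(step0_kwargs)
```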

class easydel.modules.llama4.modeling_llama4.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: BaseSequenceClassificationModule[Llama4TextModel, Llama4TextConfig]

Llama4 model for sequence classification tasks.

class easydel.modules.llama4.modeling_llama4.Llama4MultiModalProjector(*args: Any, **kwargs: Any)[source]#

Bases: Module

Multi-modal projector for Llama4 vision-language models.

Projects vision features into the text embedding space using MLP layers, enabling cross-modal understanding and generation.

class easydel.modules.llama4.modeling_llama4.Llama4TextAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

Attention module for the Llama4 text decoder with optional sliding windows.

forward(hidden_states: Float[Array, 'batch seq_len hidden_dim'], mask_info: ejkernel.types.mask.MaskInfo | None, position_ids: Int[Array, 'batch seq_len'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], cache_view: easydel.layers.caching.transformer.cache.TransformerCacheView | easydel.layers.caching.ragged_page.cache.RaggedPagesCacheView | None = None, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | easydel.layers.caching.ragged_page.cache.RaggedPagesMetadata | None = None, output_attentions: bool = False, frequencies: jaxtyping.Float[Array, 'seq_len head_dim'] | None = None, alibi: jaxtyping.Float[Array, 'batch_or_1 heads qseq_len_or_1 kvseq_len_or_1'] | None = None) -> AttentionLayerOutput[source]#

Standard RoPE-based attention (default path).

Used by most models: Llama, Mistral, Gemma, Qwen, etc.

Flow:
  1. Project Q/K/V

  2. Reshape to multi-head format

  3. POST-PROCESS: Apply Q/K norm via _postprocess_qkv()

  4. Apply sharding

  5. Apply RoPE

  6. KV cache concatenation

  7. Compute attention

  8. Merge heads and output projection

  9. Optional residual dropout

Parameters
  • hidden_states – Input tensor [batch, seq_len, hidden_dim]

  • mask_info – Mask information for attention

  • position_ids – Position indices for RoPE

  • mode – Runtime mode; one of '__autoregressive__', '__prefill__', '__train__', or '__insert__'

  • cache_view – Optional cache view for KV caching

  • cache_metadata – Optional cache metadata

  • output_attentions – Whether to return attention weights

  • frequencies – Optional precomputed RoPE frequencies

  • alibi – Optional external ALiBi positional bias (unused in standard attention)

Returns

AttentionLayerOutput with attention output and optional weights
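The flow listed above can be sketched in NumPy for intuition. This is a simplified illustration under stated assumptions, not the EasyDeL implementation: it covers steps 1, 2, 5, 7 and 8 only (no Q/K norm, sharding, KV cache, or dropout), uses equal query and key/value head counts, and applies a half-split RoPE convention that may differ from the one the model uses. The names rope and toy_attention are hypothetical.

```python
import numpy as np


def rope(x, positions, base=10000.0):
    """Rotate channel pairs of x [batch, seq, heads, head_dim] by position."""
    half = x.shape[-1] // 2
    inv_freq = base ** (-np.arange(half) / half)
    angles = positions[:, :, None, None] * inv_freq  # [batch, seq, 1, half]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def toy_attention(hidden, wq, wk, wv, wo, positions, num_heads):
    batch, seq, dim = hidden.shape
    head_dim = dim // num_heads
    # Steps 1-2: project Q/K/V and reshape to [batch, seq, heads, head_dim].
    q = (hidden @ wq).reshape(batch, seq, num_heads, head_dim)
    k = (hidden @ wk).reshape(batch, seq, num_heads, head_dim)
    v = (hidden @ wv).reshape(batch, seq, num_heads, head_dim)
    # Step 5: apply rotary position embeddings to queries and keys.
    q, k = rope(q, positions), rope(k, positions)
    # Step 7: causal scaled dot-product attention.
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(head_dim)
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Step 8: merge heads and apply the output projection.
    merged = np.einsum("bhqk,bkhd->bqhd", weights, v).reshape(batch, seq, dim)
    return merged @ wo, weights


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 5, 8))
wq, wk, wv, wo = (rng.normal(size=(8, 8)) * 0.1 for _ in range(4))
positions = np.broadcast_to(np.arange(5), (2, 5))
output, attn = toy_attention(hidden, wq, wk, wv, wo, positions, num_heads=2)
```

Here output has the same shape as hidden, and attn holds one causal weight matrix per batch element and head.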

class easydel.modules.llama4.modeling_llama4.Llama4TextDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Single Llama4 text decoder block combining attention and MLP.

class easydel.modules.llama4.modeling_llama4.Llama4TextExperts(*args: Any, **kwargs: Any)[source]#

Bases: Module

Mixture of Experts module for Llama4 text models.

Implements a sparse mixture of experts with top-k routing, enabling efficient scaling and specialization of model capacity.
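To make "top-k routing" concrete, here is a generic NumPy sketch. It is not the EasyDeL code: the softmax over the selected logits, the linear experts, and the names top_k_route and sparse_moe are all assumptions chosen for brevity, and the actual Llama4 gate function and expert MLPs may differ.

```python
import numpy as np


def top_k_route(router_logits, k):
    """Return (indices, weights) of the k highest-scoring experts per token."""
    top_idx = np.argsort(-router_logits, axis=-1)[:, :k]
    top_logits = np.take_along_axis(router_logits, top_idx, axis=-1)
    # Renormalize over the selected experts only.
    weights = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return top_idx, weights


def sparse_moe(tokens, router_w, expert_ws, k):
    """Mix the outputs of each token's top-k experts (linear experts here)."""
    idx, weights = top_k_route(tokens @ router_w, k)
    out = np.zeros_like(tokens)
    for slot in range(k):
        for e, w in enumerate(expert_ws):
            chosen = idx[:, slot] == e  # tokens whose slot-th pick is expert e
            gate = weights[chosen][:, slot:slot + 1]
            out[chosen] += gate * (tokens[chosen] @ w)
    return out


logits = np.array([[0.1, 2.0, -1.0], [3.0, 0.0, 1.0]])
idx, w = top_k_route(logits, k=2)

tokens = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
experts = [np.eye(3) * (e + 1) for e in range(3)]
mixed = sparse_moe(tokens, np.eye(3), experts, k=1)
```

Only k experts run per token, so compute grows with k rather than with the total number of experts.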

class easydel.modules.llama4.modeling_llama4.Llama4TextL2Norm(*args: Any, **kwargs: Any)[source]#

Bases: Module

L2 normalization layer for Llama4 text models.

Normalizes inputs using L2 norm with learned scaling parameters, providing stable gradients during training.
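A minimal NumPy sketch of this kind of normalization follows. It assumes a mean-square (RMS-style) formulation over the last axis and an optional per-channel scale; the function name l2_norm, the eps default, and the exact formula are assumptions, not taken from the EasyDeL source.

```python
import numpy as np


def l2_norm(x, scale=None, eps=1e-6):
    """Normalize the last axis to unit mean square, then optionally rescale."""
    normed = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return normed if scale is None else normed * scale


x = np.array([[3.0, 4.0]])
y = l2_norm(x)                                   # unit mean square per row
scaled = l2_norm(x, scale=np.array([2.0, 0.5]))  # per-channel rescaling
```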

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array#

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
class easydel.modules.llama4.modeling_llama4.Llama4TextMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Multi-Layer Perceptron for Llama4 text models.

Implements a feedforward network with a SwiGLU activation.
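For reference, a SwiGLU feedforward block multiplies a SiLU-activated gate projection with a linear up projection before projecting back down. The NumPy sketch below shows that structure only; the weight names, shapes, and absence of biases are assumptions rather than details of Llama4TextMLP.

```python
import numpy as np


def silu(x):
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def swiglu_mlp(x, w_gate, w_up, w_down):
    """down( silu(gate(x)) * up(x) ), with bias-free linear projections."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


rng = np.random.default_rng(0)
w_gate, w_up = rng.normal(size=(4, 6)), rng.normal(size=(4, 6))
w_down = rng.normal(size=(6, 4))
y = swiglu_mlp(rng.normal(size=(2, 3, 4)), w_gate, w_up, w_down)
```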

class easydel.modules.llama4.modeling_llama4.Llama4TextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Decoder-only Llama4 text model built from embeddings and decoder blocks.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.llama4.modeling_llama4.Llama4TextMoe(*args: Any, **kwargs: Any)[source]#

Bases: Module

Mixture of Experts layer for Llama4 text models.

Routes inputs to specialized expert networks based on learned routing, allowing for conditional computation and increased model capacity.

class easydel.modules.llama4.modeling_llama4.Llama4UnfoldConvolution(*args: Any, **kwargs: Any)[source]#

Bases: Module

Unfold convolution module for Llama4 vision models.

Implements patch extraction with optional convolution, converting images into sequences of patch embeddings.
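Patch extraction followed by a linear map can be sketched in NumPy as below. The channels-last layout, non-overlapping patches, and the names unfold_patches and patch_embed are assumptions made for illustration and may not match the module's actual layout or stride handling.

```python
import numpy as np


def unfold_patches(images, patch_size):
    """Split [batch, height, width, channels] images into flattened patches."""
    b, h, w, c = images.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("height and width must be divisible by patch_size")
    x = images.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # [b, h/p, w/p, p, p, c]
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def patch_embed(images, patch_size, kernel):
    """Project each flattened patch with a [p*p*c, hidden] kernel."""
    return unfold_patches(images, patch_size) @ kernel


images = np.arange(16, dtype=np.float64).reshape(1, 4, 4, 1)
patches = unfold_patches(images, patch_size=2)          # [1, 4, 4]
embedded = patch_embed(images, 2, np.ones((4, 3)))      # [1, 4, 3]
```

Unfolding first and multiplying second is equivalent to a convolution whose stride equals its kernel size, which is why the two are often used interchangeably for patch embedding.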

class easydel.modules.llama4.modeling_llama4.Llama4VisionAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

Attention module for the Llama4 vision transformer.

class easydel.modules.llama4.modeling_llama4.Llama4VisionEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

Vision encoder stack for Llama4 models.

Stacks multiple vision encoder layers to progressively encode visual features for downstream processing.

class easydel.modules.llama4.modeling_llama4.Llama4VisionEncoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Single encoder layer for Llama4 vision models.

Combines self-attention and feedforward networks with layer normalization and residual connections for vision feature encoding.

class easydel.modules.llama4.modeling_llama4.Llama4VisionMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

MLP module for Llama4 vision transformer.

Standard feedforward network with GELU activation for vision feature transformation within transformer blocks.

class easydel.modules.llama4.modeling_llama4.Llama4VisionMLP2(*args: Any, **kwargs: Any)[source]#

Bases: Module

Two-layer MLP module for Llama4 vision models.

Implements a simple two-layer feedforward network with GELU activation for vision feature transformation.

class easydel.modules.llama4.modeling_llama4.Llama4VisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Vision transformer for Llama4 including patchify stem, transformer blocks, and final norm.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition. This is an encoder-only model and does not have a decoder.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. This vision model acts as the encoder.

get_lm_head()[source]#

Returns the language model head of the module. This vision model does not have a language model head.

class easydel.modules.llama4.modeling_llama4.Llama4VisionPixelShuffleMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Pixel shuffle MLP for Llama4 vision models.

Performs spatial downsampling of vision features through pixel shuffling and MLP transformations for efficient processing.

easydel.modules.llama4.modeling_llama4.bmm(inputs, kernel, precision)[source]#

Batch matrix multiplication helper that works for 2D or higher-rank inputs.
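One way to write a matrix multiply that accepts both 2D and higher-rank inputs is an einsum with an ellipsis over the leading axes. The sketch below uses NumPy and ignores the precision argument; bmm_sketch and the chosen subscripts are assumptions, not the actual implementation.

```python
import numpy as np


def bmm_sketch(inputs, kernel):
    """Multiply the last two axes, broadcasting over any leading batch axes."""
    return np.einsum("...ik,...kj->...ij", inputs, kernel)


rng = np.random.default_rng(0)
a3, k3 = rng.normal(size=(3, 4, 5)), rng.normal(size=(3, 5, 2))
a2, k2 = rng.normal(size=(4, 5)), rng.normal(size=(5, 2))
batched = bmm_sketch(a3, k3)  # one matmul per batch entry
plain = bmm_sketch(a2, k2)    # ordinary 2D matmul
```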

easydel.modules.llama4.modeling_llama4.pixel_shuffle(input_tensor, shuffle_ratio)[source]#

Rearrange flattened vision tokens to a denser spatial grid.
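The rearrangement can be pictured with the NumPy sketch below, which trades spatial tokens for channel depth. It assumes a square token grid and the reshape/transpose order commonly used in reference Llama 4 code; pixel_shuffle_sketch is a hypothetical name, and the real function may differ in detail.

```python
import math

import numpy as np


def pixel_shuffle_sketch(tokens, shuffle_ratio):
    """Rearrange [batch, num_patches, channels] tokens on a square grid."""
    batch, num_patches, channels = tokens.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError("num_patches must be a perfect square")
    x = tokens.reshape(batch, side, side, channels)
    # Fold part of the width axis into the channel axis.
    x = x.reshape(batch, side, int(side * shuffle_ratio),
                  int(channels / shuffle_ratio))
    x = x.transpose(0, 2, 1, 3)
    # Fold part of the height axis into the channel axis as well.
    x = x.reshape(batch, int(side * shuffle_ratio), int(side * shuffle_ratio),
                  int(channels / shuffle_ratio ** 2))
    x = x.transpose(0, 2, 1, 3)
    return x.reshape(batch, -1, x.shape[-1])


tokens = np.arange(16 * 8, dtype=np.float64).reshape(1, 16, 8)
shuffled = pixel_shuffle_sketch(tokens, shuffle_ratio=0.5)
```

With shuffle_ratio=0.5, a 4x4 grid of 8-channel tokens becomes a 2x2 grid of 32-channel tokens; no values are dropped, only regrouped.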

easydel.modules.llama4.modeling_llama4.reshape_for_broadcast(frequencies: Array, query: Array) -> Array[source]#

Reshape rotary frequencies so they broadcast over the complex query tensor.

easydel.modules.llama4.modeling_llama4.vision_apply_rotary_emb(query: Array, key: Array, frequencies: Array) -> tuple[jax.Array, jax.Array][source]#

Apply rotary position embeddings to complex-valued vision queries and keys.
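Complex-valued rotary embeddings treat each adjacent channel pair as one complex number and multiply it by a unit-magnitude frequency. The NumPy sketch below shows that idea together with the broadcasting reshape. It assumes [batch, seq, heads, head_dim] inputs and [seq, head_dim / 2] complex frequencies; the function names are hypothetical, and the real functions may handle shapes and dtypes differently.

```python
import numpy as np


def reshape_for_broadcast_sketch(frequencies, query):
    """Keep the seq (axis 1) and last axes; set every other axis to 1."""
    shape = [d if i in (1, query.ndim - 1) else 1
             for i, d in enumerate(query.shape)]
    return frequencies.reshape(shape)


def apply_rotary_sketch(query, key, frequencies):
    """Rotate channel pairs of query and key by complex multiplication."""

    def to_complex(x):
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        return pairs[..., 0] + 1j * pairs[..., 1]

    def to_real(z):
        return np.stack([z.real, z.imag], axis=-1).reshape(*z.shape[:-1], -1)

    q_c, k_c = to_complex(query), to_complex(key)
    freqs = reshape_for_broadcast_sketch(frequencies, q_c)
    return to_real(q_c * freqs), to_real(k_c * freqs)


# batch=1, seq=1, heads=1, head_dim=4: channel pairs (1, 0) and (0, 2).
q = np.array([[[[1.0, 0.0, 0.0, 2.0]]]])
freqs = np.exp(1j * np.array([[np.pi / 2, np.pi]]))  # [seq=1, half=2]
q_rot, k_rot = apply_rotary_sketch(q, q, freqs)
```

In this example the first pair is rotated by a quarter turn and the second by a half turn, so the vector norms are unchanged.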