easydel.modules.llama4.modeling_llama4_flax

class easydel.modules.llama4.modeling_llama4_flax.Llama4CausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: Optional[TransformerCache] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None)

Bases: ModelOutput

Base class for Llama4Vision causal language model (or autoregressive) outputs.

Parameters
  • loss (chex.Array of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).

  • logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

  • past_key_values (TransformerCache, optional, returned when use_cache=True is passed or when config.use_cache=True) –

    Cache holding one set of key and value tensors per decoder layer.

    Contains pre-computed hidden states (keys and values of the self-attention blocks) that can be passed back to the model (see the past_key_values input) to speed up sequential decoding.

  • hidden_states (tuple(chex.Array), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) –

    Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).

    Hidden states of the model at the output of each layer, plus the optional initial embedding outputs.

  • attentions (tuple(chex.Array), optional, returned when output_attentions=True is passed or when config.output_attentions=True) –

    Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).

    Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

  • image_hidden_states (chex.Array, optional) – A chex.Array of shape (batch_size * num_patches, num_images, sequence_length, hidden_size). Image hidden states produced by the vision encoder, after its last hidden state has been projected.
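The loss field is described only as a next-token language-modeling loss. To show how that relates to the logits shape (batch_size, sequence_length, vocab_size), here is a minimal sketch of a mean shifted cross-entropy in plain JAX. It is not EasyDeL's loss implementation; the function name causal_lm_loss and the -100 ignore index are assumptions borrowed from common practice, and the sketch returns a scalar.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(logits: jax.Array, labels: jax.Array, ignore_index: int = -100) -> jax.Array:
    """Mean next-token cross-entropy (illustrative sketch, not EasyDeL's loss).

    logits: (batch_size, sequence_length, vocab_size)
    labels: (batch_size, sequence_length); ignore_index marks positions to skip
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    valid = shift_labels != ignore_index
    # Replace ignored labels with 0 so the gather below stays in range.
    safe_labels = jnp.where(valid, shift_labels, 0)
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_nll = -jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Average only over positions that carry a real label.
    return jnp.sum(token_nll * valid) / jnp.maximum(jnp.sum(valid), 1)


logits = jnp.zeros((2, 5, 11))  # uniform scores over an 11-token vocabulary
labels = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, -100, -100]])
loss = causal_lm_loss(logits, labels)  # equals log(11) for uniform logits
```

With uniform logits every valid position contributes log(vocab_size), which makes the function easy to sanity-check.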

attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None
image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None
logits: Union[Array, ndarray, bool, number] = None
loss: Optional[Union[Array, ndarray, bool, number]] = None
past_key_values: Optional[TransformerCache] = None
replace(**kwargs)

Creates a new instance with the specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
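replace, to_dict/to_json and from_dict/from_json follow the familiar immutable-dataclass pattern: replace returns a new object and leaves the original unchanged. The pure-Python sketch below mimics those semantics with a hypothetical ToyOutput class (not part of EasyDeL) so that the round trip is easy to follow. The real class holds JAX arrays and a TransformerCache, which this toy does not serialize.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for a ModelOutput; fields hold plain Python data here."""

    logits: Any = None
    loss: Optional[float] = None

    def replace(self, **kwargs) -> "ToyOutput":
        # Return a copy with the given fields swapped; self is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyOutput":
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyOutput":
        return cls.from_dict(json.loads(json_str))


out = ToyOutput(logits=[[0.1, 0.9]])
out2 = out.replace(loss=0.5)              # new instance; out.loss is still None
restored = ToyOutput.from_json(out2.to_json())
```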

class easydel.modules.llama4.modeling_llama4_flax.Llama4ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.llama4.modeling_llama4_flax.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama4Vision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.

config

Configuration object.

Type

Llama4VisionConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

config: tp.Union[EasyDeLBaseConfig, _CP]
dtype: jnp.dtype
get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) → Union[Array, ndarray, bool, number]

Extracts and projects image features from the vision tower.

Parameters

pixel_values (chex.Array) – Input pixel values for the images.

Returns

Processed image features ready for the language model.

Return type

chex.Array

init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model's configuration, dtype, sharding and quantization settings, and the provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they set the initial state of the cache. Defaults to None (usually treated as 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. Note: this argument appears to be unused in the current implementation.

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache
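To make the shapes behind a standard (non-paged) cache concrete, here is a minimal sketch that preallocates one zero-filled key buffer and one value buffer per layer, plus a per-sequence write index seeded from starts. The names LayerKV and init_toy_cache, the (batch_size, max_length, num_kv_heads, head_dim) layout and the default dtype are assumptions for illustration. TransformerCache's actual layout, sharding and quantization handling are not shown here.

```python
import dataclasses
from typing import List, Optional, Union

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class LayerKV:
    key: jax.Array    # (batch_size, max_length, num_kv_heads, head_dim), zero-filled
    value: jax.Array  # same shape as key
    index: jax.Array  # (batch_size,) next write position for each sequence


def init_toy_cache(
    num_layers: int,
    batch_size: int,
    max_length: int,
    num_kv_heads: int,
    head_dim: int,
    dtype=jnp.bfloat16,
    starts: Optional[Union[int, jax.Array]] = None,
) -> List[LayerKV]:
    """Hypothetical sketch: preallocate one key/value buffer pair per layer."""
    shape = (batch_size, max_length, num_kv_heads, head_dim)
    # starts=None means every sequence begins writing at position 0.
    index = jnp.broadcast_to(jnp.asarray(0 if starts is None else starts, dtype=jnp.int32), (batch_size,))
    return [
        LayerKV(key=jnp.zeros(shape, dtype), value=jnp.zeros(shape, dtype), index=index)
        for _ in range(num_layers)
    ]


cache = init_toy_cache(num_layers=2, batch_size=3, max_length=16, num_kv_heads=4, head_dim=8)
```

Because the buffers are allocated up front, memory grows linearly with both batch_size and max_length: roughly 2 × num_layers × batch_size × max_length × num_kv_heads × head_dim × bytes per element.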

loss_type = 'ForCausalLM'
param_dtype: jnp.dtype
precision: lax.PrecisionLike
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)

Prepares inputs for text generation, including pixel values if provided.

Parameters
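  • input_ids (chex.Array) – Initial input token IDs.

  • max_length (int) – Maximum generation length.

  • pad_token_id (int) – The ID of the padding token.

  • starts (int | None) – Optional starting positions for the sequences. Defaults to None.

  • pixel_values (Optional[chex.Array]) – Pixel values for image input.

  • attention_mask (Optional[chex.Array]) – Attention mask.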

Returns

Model inputs ready for generation.

Return type

dict
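One common way to prepare such inputs, used by Hugging Face's Flax generation helpers, is to widen the prompt's attention mask to max_length and derive position IDs from its cumulative sum, so that left-padded prompts start counting at their first real token. The sketch below follows that pattern and additionally clamps padded positions to 0. The function name toy_prepare_inputs and the returned keys are assumptions. EasyDeL's method also sets up the cache (using pad_token_id and starts), which is omitted here.

```python
import jax
import jax.numpy as jnp


def toy_prepare_inputs(input_ids, max_length, attention_mask=None, pixel_values=None):
    """Hypothetical sketch: widen the mask to max_length and derive position ids."""
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, seq_length), dtype="i4")
    attention_mask = attention_mask.astype("i4")
    # Slots after the prompt are pre-marked as attendable for tokens generated later.
    extended_mask = jnp.ones((batch_size, max_length), dtype="i4")
    extended_mask = jax.lax.dynamic_update_slice(extended_mask, attention_mask, (0, 0))
    # Count real tokens so left-padded rows start at position 0; pad slots clamp to 0.
    position_ids = jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)
    inputs = {"attention_mask": extended_mask, "position_ids": position_ids}
    if pixel_values is not None:
        inputs["pixel_values"] = pixel_values  # only needed on the first forward pass
    return inputs


ids = jnp.array([[0, 0, 5, 6], [7, 8, 9, 10]])
mask = jnp.array([[0, 0, 1, 1], [1, 1, 1, 1]])  # first row is left-padded
out = toy_prepare_inputs(ids, 8, mask)
```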

rngs: nn.Rngs
update_inputs_for_generation(model_outputs, model_kwargs)

Updates model inputs for the next step of generation, removing pixel values after the first step.

Parameters
  • model_outputs – Outputs from the previous generation step.

  • model_kwargs – Current keyword arguments for the model.

Returns

Updated model keyword arguments.

Return type

dict
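A typical per-step update, in the style of Hugging Face's Flax generation loop, carries the returned cache forward, advances the position IDs by one, and drops pixel_values once the image has been consumed on the first step, as described above. This is a hypothetical sketch. Here model_outputs is a plain dict rather than a model output object, and toy_update_inputs is an illustrative name.

```python
import jax.numpy as jnp


def toy_update_inputs(model_outputs: dict, model_kwargs: dict) -> dict:
    """Hypothetical sketch of the update between two decoding steps."""
    updated = dict(model_kwargs)  # shallow copy so the caller's dict is not mutated
    # Reuse the cache returned by the previous step.
    updated["past_key_values"] = model_outputs["past_key_values"]
    # From now on one token is fed per step, so keep only the next position.
    updated["position_ids"] = updated["position_ids"][:, -1:] + 1
    # The image was already consumed on the first step, so stop passing it.
    updated.pop("pixel_values", None)
    return updated


kwargs = {
    "position_ids": jnp.array([[0, 1, 2], [0, 0, 1]]),
    "pixel_values": jnp.zeros((2, 3, 4, 4)),
    "past_key_values": None,
}
step = toy_update_inputs({"past_key_values": "cache-from-step-1"}, kwargs)
```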

class easydel.modules.llama4.modeling_llama4_flax.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Llama4 model for sequence classification tasks.

This class extends the base Llama4 model with a linear classification head for sequence classification tasks such as sentiment analysis or text classification.

config

Configuration for the model.

Type

LlamaConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.
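The description mentions a linear classification head but does not say how a single vector is chosen per sequence. A common approach for decoder-only models, used for example by Hugging Face's LlamaForSequenceClassification, is to score the hidden state of the last non-padding token. The sketch below shows that pooling. pool_last_token, classify and score_kernel are hypothetical names, and EasyDeL's pooling rule may differ.

```python
import jax
import jax.numpy as jnp


def pool_last_token(hidden_states: jax.Array, input_ids: jax.Array, pad_token_id: int) -> jax.Array:
    """Select, per row, the hidden state of the rightmost non-padding token.

    hidden_states: (batch_size, sequence_length, hidden_size)
    input_ids: (batch_size, sequence_length)
    """
    non_pad = (input_ids != pad_token_id).astype(jnp.int32)
    positions = jnp.arange(input_ids.shape[-1])
    # argmax of position * mask is the last real token, for left or right padding.
    last_index = jnp.argmax(positions * non_pad, axis=-1)
    return hidden_states[jnp.arange(hidden_states.shape[0]), last_index]


def classify(hidden_states, input_ids, pad_token_id, score_kernel):
    """score_kernel: (hidden_size, num_labels), a hypothetical bias-free linear head."""
    pooled = pool_last_token(hidden_states, input_ids, pad_token_id)
    return pooled @ score_kernel  # (batch_size, num_labels)


hidden = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
ids = jnp.array([[5, 6, 0, 0], [1, 2, 3, 4]])  # pad_token_id = 0; first row is right-padded
scores = classify(hidden, ids, 0, jnp.ones((3, 2)))
```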

class easydel.modules.llama4.modeling_llama4_flax.Llama4MultiModalProjector(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextExperts(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextL2Norm(*args: Any, **kwargs: Any)

Bases: Module
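Llama4TextL2Norm has no description here. In the Hugging Face Transformers Llama 4 reference, the class of the same name divides each vector by its root-mean-square (plus a small epsilon) and, unlike RMSNorm, has no learnable scale. It is used there to normalize queries and keys. Assuming EasyDeL's module computes the same thing, a functional JAX sketch follows. The name l2_norm and the eps default are assumptions.

```python
import jax
import jax.numpy as jnp


def l2_norm(x: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Weightless RMS normalization over the last axis, computed in float32."""
    x32 = x.astype(jnp.float32)
    mean_square = jnp.mean(x32 * x32, axis=-1, keepdims=True)
    return (x32 * jax.lax.rsqrt(mean_square + eps)).astype(x.dtype)


y = l2_norm(jnp.array([[3.0, 4.0]]))  # mean square is 12.5, so y is about [[0.8485, 1.1314]]
```

After normalization each vector has a root-mean-square of about 1, regardless of its original scale.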

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextMLP(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.llama4.modeling_llama4_flax.Llama4TextMoe(*args: Any, **kwargs: Any)

Bases: Module
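Llama4TextMoe is listed without a description. The Hugging Face Transformers Llama 4 reference routes each token to its top-k experts, gates the expert inputs with a sigmoid of the selected router logits (unselected experts receive zeros), sums the expert outputs, and adds an always-on shared expert. The dense JAX sketch below follows that idea under two simplifying assumptions: every expert is a single linear map instead of a gated MLP, and all experts run on all tokens. routing_weights and moe_forward are hypothetical names, not EasyDeL's API.

```python
import jax
import jax.numpy as jnp


def routing_weights(router_logits: jax.Array, top_k: int) -> jax.Array:
    """(tokens, num_experts) logits -> gates: sigmoid(logit) for the top-k, else 0."""
    top_vals, top_idx = jax.lax.top_k(router_logits, top_k)
    rows = jnp.arange(router_logits.shape[0])[:, None]
    return jnp.zeros_like(router_logits).at[rows, top_idx].set(jax.nn.sigmoid(top_vals))


def moe_forward(x, router_kernel, expert_kernels, shared_kernel, top_k=1):
    """Dense reference MoE in which every expert sees every (gated) token.

    x: (tokens, hidden); router_kernel: (hidden, num_experts)
    expert_kernels: (num_experts, hidden, hidden), linear stand-ins for the expert MLPs
    shared_kernel: (hidden, hidden), linear stand-in for the shared expert
    """
    gates = routing_weights(x @ router_kernel, top_k)               # (tokens, E)
    # Scale each expert's copy of the input by its gate; unselected experts get zeros.
    expert_in = gates.T[:, :, None] * x[None, :, :]                 # (E, tokens, hidden)
    expert_out = jnp.einsum("eth,ehd->etd", expert_in, expert_kernels)
    return expert_out.sum(axis=0) + x @ shared_kernel               # (tokens, hidden)


w = routing_weights(jnp.array([[0.1, 2.0, -1.0, 0.5], [3.0, 0.0, 0.2, 1.0]]), top_k=2)
```

Running every expert on every token keeps the sketch short and easy to check, but wastes compute. Production implementations dispatch each token only to its selected experts.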

class easydel.modules.llama4.modeling_llama4_flax.Llama4UnfoldConvolution(*args: Any, **kwargs: Any)

Bases: Module
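Llama4UnfoldConvolution has no description here. In the Hugging Face Transformers Llama 4 reference, the class of the same name splits the image into non-overlapping patches with torch.nn.Unfold and projects each flattened patch with a bias-free linear layer. This is equivalent to a convolution whose stride equals its kernel size. Assuming EasyDeL mirrors that design, a JAX sketch follows, using the hypothetical names unfold_patches and unfold_conv. It assumes channels-first input whose height and width are divisible by patch_size.

```python
import jax
import jax.numpy as jnp


def unfold_patches(pixel_values: jax.Array, patch_size: int) -> jax.Array:
    """(batch, channels, H, W) -> (batch, num_patches, channels * patch_size**2).

    Each patch is flattened in (channel, row, col) order, like torch.nn.Unfold,
    and patches are ordered row-major over the patch grid.
    """
    b, c, h, w = pixel_values.shape
    p = patch_size
    x = pixel_values.reshape(b, c, h // p, p, w // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (batch, patch_row, patch_col, c, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


def unfold_conv(pixel_values: jax.Array, kernel: jax.Array, patch_size: int) -> jax.Array:
    """kernel: (channels * patch_size**2, hidden_size), a bias-free linear projection."""
    return unfold_patches(pixel_values, patch_size) @ kernel


pixels = jnp.arange(1 * 3 * 4 * 4).reshape(1, 3, 4, 4)
patches = unfold_patches(pixels, 2)  # (1, 4, 12): a 2 x 2 grid of 3-channel 2 x 2 patches
```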

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionEncoder(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionEncoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionMLP(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionMLP2(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionPixelShuffleMLP(*args: Any, **kwargs: Any)

Bases: Module

easydel.modules.llama4.modeling_llama4_flax.bmm(inputs, kernel, precision)
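Judging only by its name and signature, bmm most likely performs a batched matrix multiplication with an explicit JAX precision, for example to apply one weight matrix per expert. A sketch under that assumption follows. The name toy_bmm and the (E, tokens, in) × (E, in, out) layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_bmm(inputs: jax.Array, kernel: jax.Array, precision=None) -> jax.Array:
    """Batched matmul over a leading batch/expert axis.

    inputs: (E, tokens, in_features); kernel: (E, in_features, out_features)
    returns: (E, tokens, out_features)
    """
    return jnp.einsum("bti,bio->bto", inputs, kernel, precision=precision)


a = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
k = jnp.arange(2 * 4 * 5, dtype=jnp.float32).reshape(2, 4, 5)
out = toy_bmm(a, k, precision=jax.lax.Precision.HIGHEST)
```

Passing jax.lax.Precision.HIGHEST trades speed for accuracy on TPUs, where the default float32 matmul may use lower-precision passes.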
easydel.modules.llama4.modeling_llama4_flax.pixel_shuffle(input_tensor, shuffle_ratio)
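pixel_shuffle is listed without a description. In the Hugging Face Transformers Llama 4 reference, a function of the same name trades spatial resolution for channel depth. With shuffle_ratio r (for example 0.5), a square grid of num_patches patch embeddings becomes num_patches * r² tokens with channels / r² features each, which reduces the number of image tokens passed on. The JAX sketch below reproduces that reshaping. It assumes a square patch grid and assumes EasyDeL's version behaves the same way. The name toy_pixel_shuffle is hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def toy_pixel_shuffle(x: jax.Array, shuffle_ratio: float) -> jax.Array:
    """(batch, num_patches, channels) -> (batch, num_patches * r**2, channels / r**2)."""
    batch, num_patches, channels = x.shape
    side = math.isqrt(num_patches)  # patches are assumed to form a square grid
    height = width = side
    x = x.reshape(batch, height, width, channels)
    # Fold part of the width axis into the channels.
    x = x.reshape(batch, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
    x = x.transpose(0, 2, 1, 3)  # (batch, width * r, height, channels / r)
    # Fold part of the height axis into the channels as well.
    x = x.reshape(batch, int(width * shuffle_ratio), int(height * shuffle_ratio), int(channels / shuffle_ratio**2))
    x = x.transpose(0, 2, 1, 3)  # (batch, height * r, width * r, channels / r**2)
    return x.reshape(batch, -1, x.shape[-1])


x = jnp.arange(2 * 16 * 8, dtype=jnp.float32).reshape(2, 16, 8)
y = toy_pixel_shuffle(x, 0.5)  # 16 patches of width 8 -> 4 tokens of width 32
```

Only reshapes and transposes are involved, so the output is a rearrangement of the input values: nothing is lost or mixed between images.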
easydel.modules.llama4.modeling_llama4_flax.reshape_for_broadcast(frequencies: Array, query: Array) → Array
easydel.modules.llama4.modeling_llama4_flax.vision_apply_rotary_emb(query: Array, key: Array, frequencies: Array) → Tuple[Array, Array]
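In the Hugging Face Transformers Llama 4 reference, the vision rotary embedding treats each consecutive pair of query/key features as a complex number and multiplies it by a precomputed complex frequency. The same-named reshape_for_broadcast reshapes the (sequence_length, head_dim // 2) frequency table so that it broadcasts over the batch and head axes. Assuming EasyDeL's helpers follow that design, a JAX sketch looks like the following. The toy_ prefixes mark the names as hypothetical. The vision tower's 2D frequency table is not shown, so the usage builds a simple 1D table instead.

```python
import jax
import jax.numpy as jnp


def toy_reshape_for_broadcast(frequencies: jax.Array, query: jax.Array) -> jax.Array:
    """Reshape (seq_len, head_dim // 2) frequencies to broadcast against query.

    query: complex array (batch, seq_len, num_heads, head_dim // 2); only axis 1
    and the last axis keep their size, every other axis becomes 1.
    """
    ndim = query.ndim
    shape = [d if i in (1, ndim - 1) else 1 for i, d in enumerate(query.shape)]
    return frequencies.reshape(shape)


def toy_vision_apply_rotary_emb(query, key, frequencies):
    """Rotate each (even, odd) feature pair of query and key by a complex frequency."""

    def rotate(x):
        pairs = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)
        x_complex = jax.lax.complex(pairs[..., 0], pairs[..., 1])
        rotated = x_complex * toy_reshape_for_broadcast(frequencies, x_complex)
        out = jnp.stack([rotated.real, rotated.imag], axis=-1)
        return out.reshape(x.shape).astype(x.dtype)

    return rotate(query), rotate(key)


batch, seq_len, heads, head_dim = 2, 5, 3, 8
q = jnp.arange(batch * seq_len * heads * head_dim, dtype=jnp.float32).reshape(batch, seq_len, heads, head_dim) / 100
k = jnp.flip(q, axis=1)
inv_freq = 1.0 / 10000 ** (jnp.arange(0, head_dim, 2) / head_dim)
angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, head_dim // 2)
freqs = jnp.exp(1j * angles)  # unit-magnitude complex rotations
q_rot, k_rot = toy_vision_apply_rotary_emb(q, k, freqs)
```

Because each frequency has magnitude 1, the rotation changes only the angle of every feature pair. Pair norms are preserved, and position 0 (angle 0) is left unchanged.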