easydel.modules.llama4.modeling_llama4_flax
- class easydel.modules.llama4.modeling_llama4_flax.Llama4CausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: Optional[TransformerCache] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, image_hidden_states: Optional[Union[Array, ndarray, bool, number]] = None)
Bases:
ModelOutput
Base class for Llama4Vision causal language model (or autoregressive) outputs.
- Parameters
loss (chex.Array of shape (1,), optional, returned when labels is provided) - Language modeling loss (for next-token prediction).
logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (tuple(tuple(chex.Array)), optional, returned when use_cache=True is passed or when config.use_cache=True) - Tuple of tuple(chex.Array) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head).
Contains pre-computed hidden states (keys and values in the self-attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
hidden_states (tuple(chex.Array), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) - Tuple of chex.Array (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (tuple(chex.Array), optional, returned when output_attentions=True is passed or when config.output_attentions=True) - Tuple of chex.Array (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
image_hidden_states (chex.Array, optional) - A chex.Array of size (batch_size * num_patches, num_images, sequence_length, hidden_size). Image hidden states produced by the vision encoder after projecting the last hidden state.
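The relationship between logits and the next-token loss can be illustrated with a small sketch. It assumes the conventional shifted cross-entropy (the logits at position t are scored against the label at position t + 1) and uses plain Python lists instead of arrays. The function name next_token_loss is hypothetical, and the loss EasyDeL actually computes may add masking or other options not shown here.

```python
import math


def next_token_loss(logits, labels):
    """Mean cross-entropy between the logits at position t and the label at t + 1.

    logits: nested list shaped (batch_size, sequence_length, vocab_size)
    labels: nested list shaped (batch_size, sequence_length)
    Assumption: standard shifted cross-entropy, no padding mask.
    """
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logits, labels):
        # Drop the last logit row and the first label so position t predicts token t + 1.
        for row, target in zip(seq_logits[:-1], seq_labels[1:]):
            # Numerically stable log-sum-exp over the vocabulary axis.
            row_max = max(row)
            log_z = row_max + math.log(sum(math.exp(x - row_max) for x in row))
            total += log_z - row[target]
            count += 1
    return total / count


# One sequence of three tokens over a two-token vocabulary, with uniform scores.
logits = [[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]]
labels = [[0, 1, 0]]
loss = next_token_loss(logits, labels)
```

With uniform scores over two tokens, every prediction has probability 1/2, so the mean loss equals ln 2.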
- classmethod from_dict(data: Dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- past_key_values: Optional[TransformerCache] = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
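The method set above (replace, to_dict, from_dict, to_json, from_json) follows a familiar immutable-record pattern. The sketch below mirrors that pattern with a standard-library dataclass. ToyOutput is a hypothetical stand-in with plain Python fields, not the EasyDeL class, and the real implementation may serialize arrays differently.

```python
import dataclasses
import json
from typing import Any, Dict, List, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for an output container; not the EasyDeL class."""

    logits: Optional[List[float]] = None
    loss: Optional[float] = None

    def replace(self, **kwargs: Any) -> "ToyOutput":
        # Return a new instance; the original stays unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyOutput":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # Extra keyword arguments are forwarded to json.dumps (for example indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyOutput":
        return cls.from_dict(json.loads(json_str))


out = ToyOutput(logits=[0.1, 0.9])
with_loss = out.replace(loss=0.25)
round_trip = ToyOutput.from_json(with_loss.to_json())
```

Because replace returns a new object, out keeps loss=None while with_loss carries the updated value, and the JSON round trip reproduces with_loss.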
- class easydel.modules.llama4.modeling_llama4_flax.Llama4ForCausalLM(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
- class easydel.modules.llama4.modeling_llama4_flax.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Llama4Vision model for conditional text generation based on image inputs. Combines a vision tower and a language model with a multi-modal projector.
- config
Configuration object.
- Type
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- config: tp.Union[EasyDeLBaseConfig, _CP]
- dtype: jnp.dtype
- get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) -> Union[Array, ndarray, bool, number]
Extracts and projects image features from the vision tower.
- Parameters
pixel_values (chex.Array) - Input pixel values for the images.
- Returns
Processed image features ready for the language model.
- Return type
chex.Array
- init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model's configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
- Parameters
batch_size (int) - The batch size for the cache.
max_length (int) - The maximum sequence length the cache needs to support.
starts (int | None) - Optional starting positions for the cache sequences. If provided, influences the initial state. Defaults to None (usually 0).
shardings (dict | None) - Optional dictionary specifying sharding configurations. (Note: This argument appears unused in the current implementation shown).
pad_token_id (int | None) - The ID of the padding token. If None, it's inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
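To get a feel for what allocating a cache for a given batch_size and max_length implies, the following sketch estimates the size of a dense key-value cache. It assumes the per-layer layout described for past_key_values, (batch_size, num_heads, sequence_length, embed_size_per_head), with one key tensor and one value tensor per layer. The helper names and the bytes_per_element default are hypothetical; the layout TransformerCache really uses, and any quantization, are not modeled.

```python
from typing import Tuple


def kv_cache_shape(
    batch_size: int, num_heads: int, max_length: int, head_dim: int
) -> Tuple[int, int, int, int]:
    """Shape of a single key (or value) tensor for one layer (assumed layout)."""
    return (batch_size, num_heads, max_length, head_dim)


def kv_cache_bytes(
    num_layers: int,
    batch_size: int,
    num_heads: int,
    max_length: int,
    head_dim: int,
    bytes_per_element: int = 2,  # assumption: a 16-bit dtype such as bfloat16
) -> int:
    """Total bytes for keys and values across all layers, ignoring metadata."""
    elements_per_tensor = batch_size * num_heads * max_length * head_dim
    # Factor 2 accounts for storing both keys and values in every layer.
    return 2 * num_layers * elements_per_tensor * bytes_per_element


shape = kv_cache_shape(batch_size=1, num_heads=4, max_length=8, head_dim=16)
total_bytes = kv_cache_bytes(
    num_layers=2, batch_size=1, num_heads=4, max_length=8, head_dim=16
)
```

The estimate grows linearly in both batch_size and max_length, which is why max_length should be no larger than the longest sequence the generation loop will produce.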
- loss_type = 'ForCausalLM'
- param_dtype: jnp.dtype
- precision: lax.PrecisionLike
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) - Initial input token IDs.
max_length (int) - Maximum generation length.
pixel_values (Optional[chex.Array]) - Pixel values for image input.
attention_mask (Optional[chex.Array]) - Attention mask.
- Returns
Model inputs ready for generation.
- Return type
dict
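As a rough illustration of what such a preparation step tends to involve, the sketch below extends the attention mask to max_length, derives position IDs from the mask, and passes pixel values through only when they are supplied. The function name sketch_prepare_inputs, the returned keys, and the position-ID rule are assumptions modeled on common Flax generation code, not a description of the EasyDeL implementation, which also sets up the cache.

```python
from typing import Any, Dict, List, Optional


def sketch_prepare_inputs(
    input_ids: List[List[int]],
    max_length: int,
    pixel_values: Optional[Any] = None,
    attention_mask: Optional[List[List[int]]] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch of preparing generation inputs with plain lists."""
    seq_length = len(input_ids[0])
    if attention_mask is None:
        attention_mask = [[1] * seq_length for _ in input_ids]

    # Extend the mask to max_length so it covers tokens generated later.
    extended_mask = [row + [1] * (max_length - seq_length) for row in attention_mask]

    # Position IDs count attended tokens only: cumulative sum minus one, clamped at zero.
    position_ids = []
    for row in attention_mask:
        running, positions = 0, []
        for mask_value in row:
            running += mask_value
            positions.append(max(running - 1, 0))
        position_ids.append(positions)

    inputs: Dict[str, Any] = {
        "attention_mask": extended_mask,
        "position_ids": position_ids,
    }
    # Image inputs are included only when the caller provides them.
    if pixel_values is not None:
        inputs["pixel_values"] = pixel_values
    return inputs


text_only = sketch_prepare_inputs([[5, 6, 7]], max_length=5, attention_mask=[[0, 1, 1]])
with_image = sketch_prepare_inputs([[5, 6, 7]], max_length=5, pixel_values="pixels")
```

In this sketch a left-padded prompt keeps position 0 for its first real token, and the text-only call produces no pixel_values entry at all.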
- rngs: nn.Rngs
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates model inputs for the next step of generation, removing pixel values after the first step.
- Parameters
model_outputs - Outputs from the previous generation step.
model_kwargs - Current keyword arguments for the model.
- Returns
Updated model keyword arguments.
- Return type
dict
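The behaviour described above, dropping pixel values once the first step has consumed them, can be sketched with plain dictionaries. Treating model_outputs as a dict and carrying past_key_values forward are assumptions made for illustration; the function name sketch_update_inputs is hypothetical and the actual method may update further fields.

```python
from typing import Any, Dict


def sketch_update_inputs(
    model_outputs: Dict[str, Any], model_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical sketch: refresh the cache and drop image inputs."""
    updated = dict(model_kwargs)  # shallow copy so the caller's dict is untouched
    # Assumption: the cache returned by the last step feeds the next step.
    updated["past_key_values"] = model_outputs.get("past_key_values")
    # The image was already encoded during the first step, so it is not passed again.
    updated.pop("pixel_values", None)
    return updated


first_step_kwargs = {"pixel_values": "pixels", "attention_mask": [[1, 1, 1]]}
next_kwargs = sketch_update_inputs({"past_key_values": "cache-1"}, first_step_kwargs)
```

After the update, later steps run on text tokens and the cache alone, which avoids re-running the vision tower for every generated token.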
- class easydel.modules.llama4.modeling_llama4_flax.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
Llama model for sequence classification tasks.
This class extends the base Llama model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
- config
Configuration for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
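A classification head applied to a decoder-only model yields one score vector per token, so a single vector per sequence has to be selected. A common convention, assumed here and not confirmed by this page, is to take the scores at the last non-padding token of each right-padded sequence. The helper pool_last_token is hypothetical and uses plain lists.

```python
from typing import List


def pool_last_token(
    logits: List[List[List[float]]],
    input_ids: List[List[int]],
    pad_token_id: int,
) -> List[List[float]]:
    """Pick the per-class scores at the last non-padding token of each sequence.

    logits: nested list shaped (batch_size, sequence_length, num_labels)
    Assumption: sequences are right-padded with pad_token_id.
    """
    pooled = []
    for seq_logits, seq_ids in zip(logits, input_ids):
        last = len(seq_ids) - 1
        # Walk back over trailing padding; stop at index 0 at the latest.
        while last > 0 and seq_ids[last] == pad_token_id:
            last -= 1
        pooled.append(seq_logits[last])
    return pooled


logits = [[[0.1, 0.9], [0.8, 0.2], [0.0, 0.0]]]
input_ids = [[11, 12, 0]]  # the final token is padding
pooled = pool_last_token(logits, input_ids, pad_token_id=0)
predicted = [max(range(len(scores)), key=scores.__getitem__) for scores in pooled]
```

Here the padded final position is skipped, so the prediction is read from the second token's scores.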
- class easydel.modules.llama4.modeling_llama4_flax.Llama4MultiModalProjector(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextAttention(*args: Any, **kwargs: Any)
Bases:
AttentionModule
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextDecoderLayer(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextExperts(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextL2Norm(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextMLP(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextModel(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
- class easydel.modules.llama4.modeling_llama4_flax.Llama4TextMoe(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4UnfoldConvolution(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionAttention(*args: Any, **kwargs: Any)
Bases:
AttentionModule
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionEncoder(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionEncoderLayer(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionMLP(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionMLP2(*args: Any, **kwargs: Any)
Bases:
Module
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionModel(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
- class easydel.modules.llama4.modeling_llama4_flax.Llama4VisionPixelShuffleMLP(*args: Any, **kwargs: Any)
Bases:
Module