easydel.modules.llama4.__init__#

class easydel.modules.llama4.__init__.Llama4Config(vision_config=None, text_config=None, boi_token_index=200080, eoi_token_index=200081, image_token_index=200092, tie_word_embeddings=False, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4 model. Returns: the partition rules. Return type: tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4'#
sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.modules.llama4.llama4_configuration.Llama4TextConfig'>, 'vision_config': <class 'easydel.modules.llama4.llama4_configuration.Llama4VisionConfig'>}#
class easydel.modules.llama4.__init__.Llama4ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.llama4.__init__.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama4 model for conditional text generation based on image inputs. It combines a vision tower and a language model through a multi-modal projector.

config#

Configuration object.

Type

Llama4Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) → Union[Array, ndarray, bool, number][source]#

Extracts and projects image features from the vision tower.

Parameters

pixel_values (chex.Array) – Input pixel values for the images.

Returns

Processed image features ready for the language model.

Return type

chex.Array

init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)[source]#

Initializes and returns a standard (non-paged) Key-Value cache.

This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model's configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.

Parameters
  • batch_size (int) – The batch size for the cache.

  • max_length (int) – The maximum sequence length the cache needs to support.

  • starts (int | None) – Optional starting positions for the cache sequences. If provided, they influence the initial state. Defaults to None (usually treated as 0).

  • shardings (dict | None) – Optional dictionary specifying sharding configurations. (Note: this argument appears to be unused in the current implementation.)

  • pad_token_id (int | None) – The ID of the padding token. If None, it is inferred.

Returns

An initialized standard TransformerCache object.

Return type

TransformerCache

loss_type = 'ForCausalLM'#
prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Prepares inputs for text generation, including pixel values if provided.

Parameters
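  • input_ids (chex.Array) – Initial input token IDs.

  • max_length (int) – Maximum generation length.

  • pad_token_id (int) – The ID of the padding token.

  • starts (int | None) – Optional starting positions for the sequences. Defaults to None.

  • pixel_values (Optional[chex.Array]) – Pixel values for image input.

  • attention_mask (Optional[chex.Array]) – Attention mask.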

Returns

Model inputs ready for generation.

Return type

dict

update_inputs_for_generation(model_outputs, model_kwargs)[source]#

Updates model inputs for the next step of generation, removing pixel values after the first step.

Parameters
  • model_outputs – Outputs from the previous generation step.

  • model_kwargs – Current keyword arguments for the model.

Returns

Updated model keyword arguments.

Return type

dict

class easydel.modules.llama4.__init__.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Llama4 model for sequence classification tasks.

This class extends the base Llama4 model with a linear classification head for sequence classification tasks such as sentiment analysis or text classification.

config#

Configuration for the model.

Type

LlamaConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

class easydel.modules.llama4.__init__.Llama4TextConfig(vocab_size=202048, hidden_size=5120, intermediate_size=8192, intermediate_size_mlp=16384, num_hidden_layers=48, num_attention_heads=40, num_key_value_heads=8, head_dim=128, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=500000, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=16, moe_layers=None, interleave_moe_layer_step=1, use_qk_norm=True, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, rope_scaling=None, no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, attn_temperature_tuning=4, floor_scale=8192, attn_scale=0.1, **kwargs)[source]#

Bases: EasyDeLBaseConfig

get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4Text model. Returns: the partition rules. Return type: tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4_text'#
class easydel.modules.llama4.__init__.Llama4TextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.llama4.__init__.Llama4VisionConfig(hidden_size: int = 768, hidden_act: str = 'gelu', num_hidden_layers: int = 34, num_attention_heads: int = 16, num_channels: int = 3, intermediate_size: int = 5632, vision_output_dim: int = 7680, image_size: int = 448, patch_size: int = 14, norm_eps: float = 1e-05, vision_feature_layer=-1, vision_feature_select_strategy='default', initializer_range: float = 0.02, pixel_shuffle_ratio=0.5, projector_input_dim=4096, projector_output_dim=4096, multi_modal_projector_bias=False, projector_dropout=0.0, attention_dropout=0.0, rope_theta=10000, **kwargs)[source]#

Bases: EasyDeLBaseConfig

base_config_key: str = 'vision_config'#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the Llama4Vision model. Returns: the partition rules. Return type: tp.Tuple[tp.Tuple[str, PartitionSpec]]

model_type: str = 'llama4_vision_model'#
class easydel.modules.llama4.__init__.Llama4VisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule