easydel.modules.llama4.__init__
- class easydel.modules.llama4.__init__.Llama4Config(vision_config=None, text_config=None, boi_token_index=200080, eoi_token_index=200081, image_token_index=200092, tie_word_embeddings=False, **kwargs)
Bases: EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)
Get the partition rules for the Llama4 model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
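Partition rules are a tuple of (regex, PartitionSpec) pairs. The sketch below shows one plausible way such rules are resolved against parameter paths: the first pattern that matches via `re.search` wins, so a catch-all rule goes last. The parameter paths, the rule patterns, the mesh-axis names, and the plain-tuple stand-in for `jax.sharding.PartitionSpec` are all illustrative assumptions, not the rules `Llama4Config.get_partition_rules` actually returns.

```python
import re

# Stand-in for jax.sharding.PartitionSpec: one mesh-axis name (or nothing)
# per array dimension; () means fully replicated.
# These rules are HYPOTHETICAL and are not the ones EasyDeL ships.
HYPOTHETICAL_RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", ()),  # catch-all: replicate everything else
)


def resolve_spec(param_path, rules):
    """Return the spec of the first rule whose regex matches param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


for path in (
    "model/embed_tokens/embedding",
    "model/layers/0/self_attn/k_proj/kernel",
    "model/norm/kernel",
):
    print(path, "->", resolve_spec(path, HYPOTHETICAL_RULES))
```

Because matching is order-sensitive, more specific patterns must precede broader ones.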
- model_type: str = 'llama4'
- sub_configs: dict[str, 'PretrainedConfig'] = {'text_config': <class 'easydel.modules.llama4.llama4_configuration.Llama4TextConfig'>, 'vision_config': <class 'easydel.modules.llama4.llama4_configuration.Llama4VisionConfig'>}
- class easydel.modules.llama4.__init__.Llama4ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- class easydel.modules.llama4.__init__.Llama4ForConditionalGeneration(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Llama4 vision-language model for conditional text generation from image inputs. It combines a vision tower and a language model, connected by a multi-modal projector.
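Vision-language models of this kind typically splice the projected image features into the text embedding sequence at the positions of an image placeholder token. The NumPy sketch below shows that merge step, assuming one feature vector per occurrence of `image_token_index` (default 200092 in `Llama4Config`) and features supplied in prompt order. The function name and shapes are illustrative assumptions, not EasyDeL's internals.

```python
import numpy as np

IMAGE_TOKEN_INDEX = 200092  # default image_token_index of Llama4Config


def merge_image_features(input_ids, text_embeds, image_features,
                         image_token_index=IMAGE_TOKEN_INDEX):
    """Overwrite each image-placeholder embedding with the next image feature.

    input_ids:      (batch, seq) integer token IDs
    text_embeds:    (batch, seq, hidden) token embeddings
    image_features: (num_image_tokens, hidden) projected image features
    """
    mask = input_ids == image_token_index  # (batch, seq) boolean
    if mask.sum() != image_features.shape[0]:
        raise ValueError("number of image tokens and image features differ")
    merged = text_embeds.copy()
    # Boolean indexing visits masked positions in row-major order,
    # so features are consumed left to right, batch row by batch row.
    merged[mask] = image_features
    return merged


ids = np.array([[1, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 5]])
out = merge_image_features(ids, np.zeros((1, 4, 3)), np.ones((2, 3)))
print(out[0])
```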
- config
Configuration object.
- Type
Llama4Config
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- get_image_features(pixel_values: Union[Array, ndarray, bool, number], **kwargs) → Union[Array, ndarray, bool, number]
Extracts and projects image features from the vision tower.
- Parameters
pixel_values (chex.Array) - Input pixel values for the images.
- Returns
Processed image features ready for the language model.
- Return type
chex.Array
- init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors according to the model's configuration, dtype, sharding, and quantization settings and the provided batch size and maximum length.
- Parameters
batch_size (int) - The batch size for the cache.
max_length (int) - The maximum sequence length the cache needs to support.
starts (int | None) - Optional starting positions for the cache sequences. If provided, they set the initial state of the cache. Defaults to None (usually treated as 0).
shardings (dict | None) - Optional dictionary specifying sharding configurations. Note: this argument appears to be unused by the current implementation.
pad_token_id (int | None) - The ID of the padding token. If None, it is inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
- loss_type = 'ForCausalLM'
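Because init_cache allocates the whole cache up front for batch_size × max_length, its footprint can be estimated before calling it. The back-of-the-envelope sketch below assumes one K and one V tensor of shape (batch_size, max_length, num_key_value_heads, head_dim) per layer, the `Llama4TextConfig` defaults (48 layers, 8 KV heads, head_dim 128), a bfloat16 cache (2 bytes per element), and no cache quantization. The actual TransformerCache layout, padding, and sharding may change the real number.

```python
def kv_cache_bytes(batch_size, max_length, num_hidden_layers=48,
                   num_key_value_heads=8, head_dim=128, bytes_per_elem=2):
    """Estimate the memory of a dense (non-paged) K/V cache in bytes.

    Assumes two tensors (K and V) per layer, each shaped
    (batch_size, max_length, num_key_value_heads, head_dim).
    Defaults mirror Llama4TextConfig; bytes_per_elem=2 assumes bfloat16.
    """
    elems_per_layer = 2 * batch_size * max_length * num_key_value_heads * head_dim
    return elems_per_layer * num_hidden_layers * bytes_per_elem


print(kv_cache_bytes(1, 8192) / 2**30)  # 1.5 (GiB)
```

The estimate scales linearly in both batch_size and max_length, so max_length is usually the first knob to lower when the cache does not fit.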
- prepare_inputs_for_generation(input_ids: Union[Array, ndarray, bool, number], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: Optional[Union[Array, ndarray, bool, number]] = None)
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) - Initial input token IDs.
max_length (int) - Maximum generation length.
pad_token_id (int) - The ID of the padding token.
starts (int | None) - Optional starting positions for the sequences.
pixel_values (Optional[chex.Array]) - Pixel values for image input.
attention_mask (Optional[chex.Array]) - Attention mask.
- Returns
Model inputs ready for generation.
- Return type
dict
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates model inputs for the next step of generation, removing pixel values after the first step.
- Parameters
model_outputs - Outputs from the previous generation step.
model_kwargs - Current keyword arguments for the model.
- Returns
Updated model keyword arguments.
- Return type
dict
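The two generation hooks split the work: prepare_inputs_for_generation passes pixel_values into the first (prefill) step, and update_inputs_for_generation drops them afterwards, presumably because the image tokens are by then represented in the key/value cache. The pure-Python sketch below mirrors the update hook's documented behavior. The function name, the dict layout of `model_outputs`, and the `past_key_values` key are hypothetical, not EasyDeL's exact structures.

```python
def update_inputs_for_generation_sketch(model_outputs, model_kwargs):
    """Hypothetical re-implementation of the documented update behavior.

    - Copies model_kwargs so the caller's dict is not mutated.
    - Drops "pixel_values": the image is only needed on the first step.
    - Carries the new cache forward (key name "past_key_values" is assumed).
    """
    updated = dict(model_kwargs)
    updated.pop("pixel_values", None)
    updated["past_key_values"] = model_outputs["past_key_values"]
    return updated


step0_kwargs = {"pixel_values": [[0.0]], "attention_mask": [1, 1]}
step1_kwargs = update_inputs_for_generation_sketch(
    {"past_key_values": "cache-after-prefill"}, step0_kwargs
)
print(sorted(step1_kwargs))  # ['attention_mask', 'past_key_values']
```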
- class easydel.modules.llama4.__init__.Llama4ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Llama4 model for sequence classification tasks.
This class extends the base Llama4 model by adding a linear classification head to perform sequence classification tasks such as sentiment analysis or text classification.
- config
Configuration for the model.
- Type
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
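The reference does not say how a whole sequence is reduced to a single prediction. A common choice for decoder-only models (used, for example, by Hugging Face's LlamaForSequenceClassification) is to apply the linear head at every position and read out the logits at the last non-padding token. The NumPy sketch below shows that pooling, assuming right-padded inputs with no padding inside a sequence and a bias-free head; whether Llama4ForSequenceClassification pools the same way is an assumption.

```python
import numpy as np


def classify_last_token(hidden_states, input_ids, weight, pad_token_id):
    """Score each sequence with a linear head read at its last non-pad token.

    hidden_states: (batch, seq, hidden) final hidden states
    input_ids:     (batch, seq) token IDs, right-padded with pad_token_id
    weight:        (hidden, num_labels) classification head, no bias
    returns:       (batch, num_labels) logits
    """
    logits = hidden_states @ weight                     # (batch, seq, num_labels)
    lengths = (input_ids != pad_token_id).sum(axis=-1)  # real tokens per row
    last = lengths - 1                                  # index of last real token
    return logits[np.arange(logits.shape[0]), last]


hs = np.arange(12, dtype=float).reshape(2, 3, 2)
ids = np.array([[5, 6, 0], [5, 6, 7]])  # first row has one pad token (0)
print(classify_last_token(hs, ids, np.eye(2), pad_token_id=0))
```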
- class easydel.modules.llama4.__init__.Llama4TextConfig(vocab_size=202048, hidden_size=5120, intermediate_size=8192, intermediate_size_mlp=16384, num_hidden_layers=48, num_attention_heads=40, num_key_value_heads=8, head_dim=128, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=500000, attention_dropout=0.0, num_experts_per_tok=1, num_local_experts=16, moe_layers=None, interleave_moe_layer_step=1, use_qk_norm=True, output_router_logits=False, router_aux_loss_coef=0.001, router_jitter_noise=0.0, rope_scaling=None, no_rope_layers=None, no_rope_layer_interval=4, attention_chunk_size=8192, attn_temperature_tuning=4, floor_scale=8192, attn_scale=0.1, **kwargs)
Bases: EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)
Get the partition rules for the Llama4Text model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'llama4_text'
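Several Llama4TextConfig fields describe a per-layer layout rather than a single setting. The sketch below derives those layouts the way Hugging Face Transformers' Llama4TextConfig does when `no_rope_layers` and `moe_layers` are left as None: every `no_rope_layer_interval`-th layer skips RoPE, and MoE layers are placed every `interleave_moe_layer_step` layers (every layer with the default step of 1). It also builds the chunked causal mask implied by `attention_chunk_size`, under which a token attends only to earlier tokens inside its own chunk. That EasyDeL derives these defaults identically is an assumption.

```python
import numpy as np


def layer_layout(num_hidden_layers=48, no_rope_layer_interval=4,
                 interleave_moe_layer_step=1):
    """Return per-layer (uses_rope, is_moe) flags.

    Follows the Hugging Face Llama4TextConfig defaults (assumed to match EasyDeL):
    layer i skips RoPE when (i + 1) is a multiple of no_rope_layer_interval;
    MoE layers are i = step-1, 2*step-1, ...
    """
    uses_rope = [(i + 1) % no_rope_layer_interval != 0 for i in range(num_hidden_layers)]
    moe = set(range(interleave_moe_layer_step - 1, num_hidden_layers,
                    interleave_moe_layer_step))
    is_moe = [i in moe for i in range(num_hidden_layers)]
    return uses_rope, is_moe


def chunked_causal_mask(seq_len, attention_chunk_size=8192):
    """True where query i may attend key j: causal AND in the same chunk."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (i // attention_chunk_size == j // attention_chunk_size)


rope, moe = layer_layout()
print(sum(rope), sum(moe))  # 36 48
print(chunked_causal_mask(6, attention_chunk_size=3).astype(int))
```

With the defaults, 12 of the 48 layers (indices 3, 7, 11, ...) run without RoPE.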
- class easydel.modules.llama4.__init__.Llama4TextModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
- class easydel.modules.llama4.__init__.Llama4VisionConfig(hidden_size: int = 768, hidden_act: str = 'gelu', num_hidden_layers: int = 34, num_attention_heads: int = 16, num_channels: int = 3, intermediate_size: int = 5632, vision_output_dim: int = 7680, image_size: int = 448, patch_size: int = 14, norm_eps: float = 1e-05, vision_feature_layer=-1, vision_feature_select_strategy='default', initializer_range: float = 0.02, pixel_shuffle_ratio=0.5, projector_input_dim=4096, projector_output_dim=4096, multi_modal_projector_bias=False, projector_dropout=0.0, attention_dropout=0.0, rope_theta=10000, **kwargs)
Bases: EasyDeLBaseConfig
- base_config_key: str = 'vision_config'
- get_partition_rules(*args, **kwargs)
Get the partition rules for the Llama4Vision model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- model_type: str = 'llama4_vision_model'
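The vision defaults fix how many tokens one image contributes. With image_size=448 and patch_size=14 the encoder sees a 32 × 32 grid of 1024 patches; a pixel_shuffle_ratio of 0.5 then folds each 2 × 2 block of neighboring patches into a single token with four times the channels, leaving 256 tokens. The NumPy sketch below implements that space-to-depth step, assuming any class token has already been removed. The channel ordering inside each merged token may differ from EasyDeL's implementation.

```python
import numpy as np


def pixel_shuffle(patches, ratio=0.5):
    """Space-to-depth on a square patch grid.

    patches: (batch, num_patches, channels), num_patches a perfect square.
    Each (1/ratio) x (1/ratio) block of neighboring patches becomes one token,
    so tokens shrink by (1/ratio)**2 and channels grow by the same factor.
    """
    b, n, c = patches.shape
    side = int(round(n ** 0.5))
    assert side * side == n, "expected a square patch grid"
    f = int(round(1 / ratio))                  # patches merged per side
    x = patches.reshape(b, side // f, f, side // f, f, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)          # gather each f x f block together
    return x.reshape(b, (side // f) ** 2, f * f * c)


image_size, patch_size = 448, 14                # Llama4VisionConfig defaults
num_patches = (image_size // patch_size) ** 2   # 32 * 32 = 1024
tokens = pixel_shuffle(np.zeros((1, num_patches, 768)))
print(tokens.shape)  # (1, 256, 3072)
```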
- class easydel.modules.llama4.__init__.Llama4VisionModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule