easydel.modules.aya_vision.modeling_aya_vision
- class easydel.modules.aya_vision.modeling_aya_vision.AyaVisionCausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None, hidden_states: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, last_hidden_state: Optional[Union[Array, ndarray, bool, number]] = None, attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None, image_hidden_states: jaxtyping.Float[Array, 'batch seq_len hidden_dim'] | None = None)
Bases: ModelOutput
Base class for AyaVision causal language model (or autoregressive) outputs.
- Parameters
loss (chex.Array of shape (1,), optional, returned when labels is provided) – Language modeling loss (for next-token prediction).
logits (chex.Array of shape (batch_size, sequence_length, config.vocab_size)) – Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (TransformerCache, optional, returned when config.use_cache=True) – Pre-computed key and value states from the self-attention blocks, one entry per layer, each with tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). Pass it back as the past_key_values input to speed up sequential decoding.
hidden_states (tuple(chex.Array), optional, returned when config.output_hidden_states=True) – Hidden states of the model at the output of each layer, plus the optional initial embedding output: one chex.Array of shape (batch_size, sequence_length, hidden_size) for the embedding output (if the model has an embedding layer) and one for each layer.
attentions (tuple(chex.Array), optional, returned when config.output_attentions=True) – Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads: one chex.Array of shape (batch_size, num_heads, sequence_length, sequence_length) per layer.
image_hidden_states (chex.Array, optional) – Image hidden states produced by the vision encoder after projecting its last hidden state; a chex.Array of size (batch_size * num_patches, num_images, sequence_length, hidden_size).
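The loss and logits fields relate through the usual shifted next-token objective: the logits at position t score the token at position t + 1. The NumPy sketch below shows that computation; it assumes integer labels aligned with the input tokens and an ignore_index of -100 for masked positions, which is a common convention and not confirmed here for EasyDeL's ForCausalLM loss.

```python
import numpy as np


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Mean next-token cross-entropy (illustrative sketch, not EasyDeL's loss code).

    logits: (batch_size, sequence_length, vocab_size) pre-softmax scores.
    labels: (batch_size, sequence_length) token ids; entries equal to
            ignore_index (assumed convention) are excluded from the mean.
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    mask = shift_labels != ignore_index
    safe_labels = np.where(mask, shift_labels, 0)
    # Log-probability assigned to each target token.
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return float(-(token_log_probs * mask).sum() / mask.sum())
```

With all-zero logits every token is equally likely, so the loss equals log(vocab_size); logits that strongly favor the correct next token drive it toward zero.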
- attentions: tuple[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number]] | None = None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- past_key_values: easydel.layers.caching.transformer.cache.TransformerCache | None = None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
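The replace and serialization methods above follow an immutable-container pattern: replace returns a new object rather than mutating the original, and to_json/from_json round-trip the fields. The sketch below illustrates that pattern with a hypothetical ToyOutput dataclass holding plain Python values; it is not EasyDeL's ModelOutput, and how EasyDeL serializes array-valued fields is not modeled here.

```python
from __future__ import annotations

import dataclasses
import json
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for a ModelOutput-style container (not EasyDeL's class)."""

    loss: Optional[float] = None
    logits: Any = None

    def replace(self, **kwargs) -> ToyOutput:
        # Returns a new instance; the original stays unchanged (frozen dataclass).
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ToyOutput:
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> ToyOutput:
        return cls.from_dict(json.loads(json_str))


out = ToyOutput(logits=[[0.1, 0.9]])
with_loss = out.replace(loss=0.5)  # `out` still has loss=None
```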
- class easydel.modules.aya_vision.modeling_aya_vision.AyaVisionForConditionalGeneration(*args: Any, **kwargs: Any)
Bases: BaseVisionLanguageModule[AyaVisionModel, AyaVisionConfig]
AyaVision model for conditional text generation based on image inputs.
Combines a vision tower and a language model with a multi-modal projector. Inherits from BaseVisionLanguageModule to leverage common VLM infrastructure.
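In vision-language models of this kind, the projected image features are typically merged into the language model's input by overwriting the embeddings of placeholder image tokens. The NumPy sketch below shows that merge step; IMAGE_TOKEN_ID and the flattened (num_image_tokens, hidden) feature layout are hypothetical assumptions, and this is not EasyDeL's merging code.

```python
import numpy as np

IMAGE_TOKEN_ID = 7  # hypothetical placeholder id for image tokens


def merge_image_features(input_ids, text_embeds, image_features, image_token_id=IMAGE_TOKEN_ID):
    """Scatter projected image features into the text embedding sequence (sketch).

    input_ids:      (batch, seq_len) token ids
    text_embeds:    (batch, seq_len, hidden) token embeddings
    image_features: (num_image_tokens, hidden), one row per placeholder token,
                    in row-major order of the placeholders.
    """
    mask = input_ids == image_token_id
    if mask.sum() != image_features.shape[0]:
        raise ValueError("number of image placeholders must match number of image feature rows")
    merged = text_embeds.copy()  # leave the caller's embeddings untouched
    # Boolean-mask assignment fills placeholder positions in row-major order.
    merged[mask] = image_features
    return merged
```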
- config
Configuration object.
- Type
AyaVisionConfig
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- Class Attributes:
_task_type: IMAGE_TEXT_TO_TEXT task type
_model_type: "aya_vision" model identifier
_supports_video: False (AyaVision is image-only)
_uses_mrope: False (uses standard RoPE)
- get_image_features(pixel_values: Float[Array, 'batch channels height width'], **kwargs) → Float[Array, 'batch num_patches hidden']
Extract and project image features from pixel values.
Delegates to the base model’s get_image_features implementation, which:
1. passes pixel_values through the vision tower;
2. applies pixel shuffling for downsampling;
3. applies the multimodal projector with gating.
- Parameters
pixel_values – Input image pixel values
**kwargs – Additional arguments (unused for AyaVision)
- Returns
Projected image features ready for merging with text embeddings
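Pixel shuffling (step 2 above) reduces the number of image tokens by folding each factor × factor block of neighboring patches into the channel dimension. The NumPy sketch below shows a generic space-to-depth version of this operation; the downsample factor of 2 used in the usage note and the exact axis ordering are assumptions, since this page does not specify EasyDeL's implementation.

```python
import numpy as np


def pixel_shuffle(features: np.ndarray, factor: int) -> np.ndarray:
    """Generic space-to-depth downsampling of a square patch grid (sketch).

    features: (batch, num_patches, dim) with num_patches = side * side
    returns:  (batch, num_patches // factor**2, dim * factor**2)
    """
    batch, num_patches, dim = features.shape
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches or side % factor:
        raise ValueError("patches must form a square grid divisible by factor")
    grid = features.reshape(batch, side, side, dim)
    # Split each spatial axis into (blocks, within-block offset).
    grid = grid.reshape(batch, side // factor, factor, side // factor, factor, dim)
    # Reorder to (batch, block_row, block_col, row_offset, col_offset, dim).
    grid = grid.transpose(0, 1, 3, 2, 4, 5)
    # Fold each factor x factor block into the channel dimension.
    return grid.reshape(batch, (side // factor) ** 2, factor * factor * dim)
```

For example, a 4 × 4 grid of 2-dimensional patch features with factor=2 becomes 4 tokens of dimension 8, each holding one 2 × 2 block of neighboring patches.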
- init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)
Initialize KV cache for generation.
- loss_type = 'ForCausalLM'
- class easydel.modules.aya_vision.modeling_aya_vision.AyaVisionModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Base AyaVision model that combines a vision tower and a language model with a multi-modal projector. As a base model, it has no language-model head.
- config
Configuration object.
- Type
AyaVisionConfig
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- get_encoder()
Returns the encoder part of the model’s graph definition. The vision tower acts as the encoder in this multi-modal setup.
- get_image_features(pixel_values: Union[Array, ndarray, bool, number]) → Union[Array, ndarray, bool, number]
Extracts and projects image features from the vision tower.
- Parameters
pixel_values (chex.Array) – Input pixel values for the images.
- Returns
Processed image features ready for the language model.
- Return type
chex.Array
- get_lm_head()
Intended to return the language-model head of the module. AyaVisionModel is a base model and has no language-model head.
- init_cache(batch_size, max_length, starts=None, shardings=None, pad_token_id=None)
Initializes and returns a standard (non-paged) Key-Value cache.
This method first creates the necessary metadata using create_cache_metadata and then calls TransformerCache.init_cache to allocate and initialize the cache tensors based on the model’s configuration, dtype, sharding, quantization settings, and provided batch size and maximum length.
- Parameters
batch_size (int) – The batch size for the cache.
max_length (int) – The maximum sequence length the cache needs to support.
starts (int | None) – Optional starting positions for the cache sequences; if provided, they set the initial state. Defaults to None (typically treated as 0).
shardings (dict | None) – Optional dictionary of sharding configurations. Note: this argument appears to be unused by the current implementation.
pad_token_id (int | None) – The ID of the padding token. If None, it’s inferred.
- Returns
An initialized standard TransformerCache object.
- Return type
TransformerCache
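A standard (non-paged) cache preallocates fixed-size key/value buffers for max_length positions, so each decoding step writes into the next free slots instead of growing the tensors. The NumPy sketch below illustrates that idea with a hypothetical per-layer dict layout using the (batch_size, num_heads, sequence_length, embed_size_per_head) shape described for past_key_values; TransformerCache's real structure, sharding, and quantization are not modeled.

```python
import numpy as np


def init_kv_cache(num_layers, batch_size, num_heads, max_length, head_dim, dtype=np.float32):
    """Preallocate per-layer key/value buffers (hypothetical layout, not TransformerCache)."""
    shape = (batch_size, num_heads, max_length, head_dim)
    # A fresh pair of zero buffers per layer; layers never share storage.
    return [
        {"key": np.zeros(shape, dtype), "value": np.zeros(shape, dtype)}
        for _ in range(num_layers)
    ]


def write_kv(layer_cache, new_k, new_v, start):
    """Write keys/values of shape (batch, heads, step, head_dim) into positions [start, start + step)."""
    step = new_k.shape[2]
    layer_cache["key"][:, :, start:start + step] = new_k
    layer_cache["value"][:, :, start:start + step] = new_v
    return start + step  # next write position
```

Usage: after prefilling a 3-token prompt with write_kv(cache[0], k, v, 0), the returned position 3 is where the next generated token's keys and values go.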
- prepare_inputs_for_generation(input_ids: Int[Array, 'batch seq_len'], max_length: int, pad_token_id: int, starts: int | None = None, pixel_values: Optional[Union[Array, ndarray, bool, number]] = None, attention_mask: jaxtyping.Bool[Array, 'batch seq_len'] | None = None)
Prepares inputs for text generation, including pixel values if provided.
- Parameters
input_ids (chex.Array) – Initial input token IDs.
max_length (int) – Maximum generation length.
pad_token_id (int) – The ID of the padding token.
pixel_values (Optional[chex.Array]) – Pixel values for image input.
attention_mask (Optional[chex.Array]) – Attention mask.
- Returns
Model inputs ready for generation.
- Return type
dict
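This page does not say how prepare_inputs_for_generation builds its inputs. A common pattern in Flax-style generation code, assumed here for illustration, is to extend the prompt's attention mask to max_length and derive position ids from the mask so that left padding does not shift positions. The sketch below is hypothetical: it omits cache initialization, pad_token_id, and starts, and is not EasyDeL's implementation.

```python
import numpy as np


def prepare_generation_inputs(input_ids, max_length, attention_mask=None, pixel_values=None):
    """Hypothetical sketch: extend the mask to max_length and derive position ids."""
    batch_size, seq_len = input_ids.shape
    if attention_mask is None:
        attention_mask = np.ones((batch_size, seq_len), dtype=np.int32)
    # Positions beyond the prompt are set to 1 so tokens generated later can be attended to.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    extended_mask[:, :seq_len] = attention_mask
    # With left padding, the first real token gets position 0; pad slots are clipped to 0.
    position_ids = np.clip(np.cumsum(attention_mask, axis=-1) - 1, 0, None)
    inputs = {"attention_mask": extended_mask, "position_ids": position_ids}
    if pixel_values is not None:
        # Images are only needed for the first (prefill) forward pass.
        inputs["pixel_values"] = pixel_values
    return inputs
```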
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates model inputs for the next step of generation, removing pixel values after the first step.
- Parameters
model_outputs – Outputs from the previous generation step.
model_kwargs – Current keyword arguments for the model.
- Returns
Updated model keyword arguments.
- Return type
dict
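The per-step update described above can be sketched as follows. The rationale for dropping pixel_values is that, after the first step, the image tokens' keys and values are already in the cache, so the image does not need to be re-encoded. The dict keys and the position-id update are assumptions modeled on common Flax generation loops, not EasyDeL's exact code.

```python
import numpy as np


def update_generation_inputs(model_outputs: dict, model_kwargs: dict) -> dict:
    """Hypothetical sketch of a per-step generation update."""
    updated = dict(model_kwargs)  # shallow copy so the caller's dict is untouched
    # Carry the cache filled during this step into the next one.
    updated["past_key_values"] = model_outputs["past_key_values"]
    # The next step feeds a single token whose position follows the last one.
    updated["position_ids"] = updated["position_ids"][:, -1:] + 1
    # Image keys/values are already cached after step one, so stop passing pixels.
    updated.pop("pixel_values", None)
    return updated
```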
- class easydel.modules.aya_vision.modeling_aya_vision.AyaVisionMultiModalProjector(*args: Any, **kwargs: Any)
Bases: Module
A multi-modal projector module for AyaVision that processes image features. It applies pixel shuffling, layer normalization, and linear transformations.
- config
Configuration object.
- Type
AyaVisionConfig
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
JAX precision level.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
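After pixel shuffling, the projector normalizes the features and maps them to the language model's hidden size with linear layers. The NumPy sketch below assumes a SwiGLU-style gate (the first linear output is split in two and combined as SiLU(gate) * x), which is how the Hugging Face AyaVision projector is commonly written. It takes already-shuffled features, uses LayerNorm without learned scale or bias, and omits linear biases; w1 and w2 are hypothetical weight matrices, not EasyDeL parameter names.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis (no learned scale/bias in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def silu(x):
    return x / (1.0 + np.exp(-x))


def project(features, w1, w2):
    """Gated projection sketch.

    features: (batch, tokens, d_in), already pixel-shuffled
    w1: (d_in, 2 * d_mid), w2: (d_mid, d_out)
    """
    h = layer_norm(features) @ w1
    # Split the first projection into a value half and a gate half.
    x, gate = np.split(h, 2, axis=-1)
    return (silu(gate) * x) @ w2
```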