easydel.modules.pixtral.modeling_pixtral_flax#
- class easydel.modules.pixtral.modeling_pixtral_flax.PixtralAttention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
Pixtral Attention module.
This module implements the multi-head self-attention mechanism used in the Pixtral vision model. It utilizes Rotary Position Embeddings (RoPE).
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
Dimensionality of the hidden states.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head (1 for standard multi-head attention).
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- class easydel.modules.pixtral.modeling_pixtral_flax.PixtralBlock(*args: Any, **kwargs: Any)#
Bases: Module
Pixtral Transformer Block.
This module represents a single transformer block in the Pixtral vision model, containing self-attention and MLP sub-layers with residual connections and RMS normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attention#
The self-attention module.
- Type
- feed_forward#
The feed-forward (MLP) module.
- Type
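A minimal sketch of how a block of this kind might combine its sub-layers, assuming a pre-norm layout in which RMS normalization is applied before each sub-layer and the residual is added afterwards. The names rms_norm and block_forward, the eps value, and the stand-in sub-layers are hypothetical and are not part of the EasyDeL API.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    # Scale by the reciprocal root-mean-square over the feature axis.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def block_forward(x, attention_fn, feed_forward_fn, attn_norm_weight, ffn_norm_weight):
    # Assumed pre-norm ordering: normalize, apply the sub-layer, add the residual.
    h = x + attention_fn(rms_norm(x, attn_norm_weight))
    return h + feed_forward_fn(rms_norm(h, ffn_norm_weight))


# Stand-in sub-layers that return zeros, so the residual path is easy to inspect.
x = jnp.ones((1, 4, 8))
w = jnp.ones((8,))
out = block_forward(x, jnp.zeros_like, jnp.zeros_like, w, w)
```

With zero-valued sub-layers the output equals the input, which shows that the residual connections carry the hidden states through unchanged.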
- class easydel.modules.pixtral.modeling_pixtral_flax.PixtralMLP(*args: Any, **kwargs: Any)#
Bases: Module
Pixtral MLP module.
This module implements the feed-forward network (MLP) used in the Pixtral vision model. It uses a Gated Linear Unit (GLU) structure whose gate activation is act_fn.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_proj#
Linear layer for the GLU gate.
- Type
- down_proj#
Linear layer for the down projection.
- Type
- up_proj#
Linear layer for the GLU value.
- Type
- act_fn#
Activation function for the GLU gate. The original Pixtral config specifies GELU, although SiLU is commonly used in similar models.
- Type
callable
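A minimal sketch of the GLU computation described by these attributes, assuming bias-free projections stored as plain weight matrices. The function name glu_mlp, the weight names, and the sizes are hypothetical; act_fn defaults to SiLU here only for illustration and can be swapped for jax.nn.gelu.

```python
import jax
import jax.numpy as jnp


def glu_mlp(x, w_gate, w_up, w_down, act_fn=jax.nn.silu):
    # Gate branch: projection followed by the activation function.
    gate = act_fn(x @ w_gate)
    # Value branch: plain linear projection.
    value = x @ w_up
    # Element-wise gating, then projection back to the hidden size.
    return (gate * value) @ w_down


hidden_size, intermediate_size = 8, 16
k_x, k_gate, k_up, k_down = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, 5, hidden_size))
w_gate = jax.random.normal(k_gate, (hidden_size, intermediate_size))
w_up = jax.random.normal(k_up, (hidden_size, intermediate_size))
w_down = jax.random.normal(k_down, (intermediate_size, hidden_size))

y = glu_mlp(x, w_gate, w_up, w_down)
y_gelu = glu_mlp(x, w_gate, w_up, w_down, act_fn=jax.nn.gelu)
```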
- class easydel.modules.pixtral.modeling_pixtral_flax.PixtralTransformer(*args: Any, **kwargs: Any)#
Bases: Module
Pixtral Transformer stack.
This module represents the main stack of transformer blocks in the Pixtral vision model. It takes patch embeddings as input and processes them through multiple PixtralBlock layers, applying a final layer normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- layers#
List of transformer blocks.
- Type
tp.List[PixtralBlock]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.modules.pixtral.modeling_pixtral_flax.PixtralVisionModel(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The Pixtral Vision Model transformer.
This class implements the complete Pixtral vision model, including patch embedding via convolution and the main transformer stack.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- patch_conv#
Convolutional layer for patch embedding.
- Type
nn.Conv
- transformer#
The main transformer stack.
- Type
- property frequencies#
Cached property to compute and retrieve RoPE frequencies.
- easydel.modules.pixtral.modeling_pixtral_flax.apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=0)#
Applies Rotary Position Embedding to the query and key tensors.
- Parameters
q (jnp.ndarray) – The query tensor.
k (jnp.ndarray) – The key tensor.
cos (jnp.ndarray) – The cosine part of the rotary embedding.
sin (jnp.ndarray) – The sine part of the rotary embedding.
position_ids (jnp.ndarray, optional) – Deprecated and unused.
unsqueeze_dim (int, optional) – The dimension along which to unsqueeze cos and sin so that they broadcast against q and k. cos and sin have the shape [batch_size, seq_len, head_dim]. If q and k have the shape [batch_size, heads, seq_len, head_dim], set unsqueeze_dim=1; if q and k have the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
- Returns
tuple(jnp.ndarray) containing the query and key tensors rotated using the Rotary Position Embedding.
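A minimal sketch of the usual rotate-half formulation of RoPE, shown to illustrate how unsqueeze_dim makes cos and sin broadcast against q and k. The helper names rotate_half and apply_rope_sketch are hypothetical, and the rotate-half form is an assumption about the implementation rather than something this page confirms.

```python
import jax.numpy as jnp


def rotate_half(x):
    # Split the last axis in two halves and map (x1, x2) to (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope_sketch(q, k, cos, sin, unsqueeze_dim=1):
    # Insert a singleton axis so cos and sin broadcast over the heads axis.
    cos = jnp.expand_dims(cos, unsqueeze_dim)
    sin = jnp.expand_dims(sin, unsqueeze_dim)
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot


batch, heads, seq_len, head_dim = 1, 2, 3, 4
q = jnp.ones((batch, heads, seq_len, head_dim))
k = jnp.ones((batch, heads, seq_len, head_dim))

# Zero rotation angles: cos is 1 and sin is 0, so q and k come back unchanged.
angles = jnp.zeros((batch, seq_len, head_dim))
q_rot, k_rot = apply_rope_sketch(q, k, jnp.cos(angles), jnp.sin(angles), unsqueeze_dim=1)
```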
- easydel.modules.pixtral.modeling_pixtral_flax.compute_frequencies(dim: int, max_patches_per_side: int, theta: float = 10000.0)#
Computes frequencies with a fixed max length for RoPE.
- Parameters
dim – Embedding dimension.
max_patches_per_side – Maximum number of patches per side of the image.
theta – Base of the RoPE frequency schedule.
- Returns
Computed frequencies (inv_freq) of shape (max_patches_per_side**2, dim).
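One way a two-dimensional frequency table with the documented output shape can be built is sketched below. It assumes that dim is divisible by 4, that even-indexed frequencies encode the patch row and odd-indexed frequencies encode the patch column, and that the half-width table is duplicated along the last axis. These layout choices and the name compute_frequencies_sketch are assumptions for illustration, not a statement of what EasyDeL does internally.

```python
import jax.numpy as jnp


def compute_frequencies_sketch(dim, max_patches_per_side, theta=10000.0):
    side = max_patches_per_side
    # Inverse frequencies theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    positions = jnp.arange(side, dtype=jnp.float32)
    # Assumption: even-indexed frequencies for rows, odd-indexed for columns.
    freqs_h = jnp.outer(positions, freqs[::2])   # (side, dim // 4)
    freqs_w = jnp.outer(positions, freqs[1::2])  # (side, dim // 4)
    quarter = freqs_h.shape[-1]
    grid = jnp.concatenate(
        [
            jnp.broadcast_to(freqs_h[:, None, :], (side, side, quarter)),
            jnp.broadcast_to(freqs_w[None, :, :], (side, side, quarter)),
        ],
        axis=-1,
    ).reshape(side * side, -1)                   # (side * side, dim // 2)
    # Duplicate along the last axis to reach the full dimension.
    return jnp.concatenate([grid, grid], axis=-1)


table = compute_frequencies_sketch(dim=8, max_patches_per_side=4)
```

Row r * max_patches_per_side + c of the table holds the angles for the patch at row r and column c, so the entry for patch (0, 0) is all zeros.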
- easydel.modules.pixtral.modeling_pixtral_flax.generate_block_attention_mask(patch_embeds_list, tensor)#
Generates a block-diagonal attention mask for multi-image processing.
This mask ensures that attention is only computed within each image’s patches, preventing cross-image attention.
- Parameters
patch_embeds_list (list[int]) – A list containing the number of patches for each image.
tensor (chex.Array) – The input tensor (e.g., hidden states) with shape (batch_size, sequence_length, …).
- Returns
A block-diagonal attention mask of shape (batch_size, 1, sequence_length, sequence_length). The mask contains 0.0 for allowed attention positions and a large negative number (minimum float value) for masked positions.
- Return type
chex.Array
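The mask described above can be sketched as follows, assuming the patch counts sum to the sequence length. The function name block_mask_sketch is hypothetical, and the loop-based construction is one straightforward way to obtain the documented shape and values.

```python
import jax.numpy as jnp


def block_mask_sketch(patch_counts, tensor):
    batch_size, seq_len = tensor.shape[0], tensor.shape[1]
    min_value = jnp.finfo(tensor.dtype).min
    # Start fully masked, then open one square block per image.
    mask = jnp.full((seq_len, seq_len), min_value, dtype=tensor.dtype)
    start = 0
    for count in patch_counts:
        end = start + count
        mask = mask.at[start:end, start:end].set(0.0)
        start = end
    # Add the singleton heads axis and repeat over the batch.
    return jnp.broadcast_to(mask[None, None], (batch_size, 1, seq_len, seq_len))


# Two images with 2 and 3 patches, packed into one sequence of length 5.
hidden_states = jnp.zeros((2, 5, 8), dtype=jnp.float32)
mask = block_mask_sketch([2, 3], hidden_states)
```

Positions 0 and 1 can attend to each other, positions 2 to 4 can attend to each other, and every cross-image entry holds the minimum float value.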
- easydel.modules.pixtral.modeling_pixtral_flax.position_ids_in_meshgrid(patch_embeds_list, max_width)#
Generates position IDs based on a meshgrid for a list of patch embeddings.
- Parameters
patch_embeds_list (list[chex.Array]) – A list of patch embeddings, where each element has shape (…, height, width).
max_width (int) – The maximum width across all patches, used for calculating the linear index.
- Returns
A 1D array of position IDs corresponding to the flattened patches.
- Return type
chex.Array
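A minimal sketch of the meshgrid-based indexing, assuming the linear index of a patch is row * max_width + column and that the IDs of successive images are concatenated in order. The name position_ids_sketch and the example shapes are hypothetical.

```python
import jax.numpy as jnp


def position_ids_sketch(patch_embeds_list, max_width):
    positions = []
    for patch in patch_embeds_list:
        height, width = patch.shape[-2:]
        # Row and column index of every patch in this image.
        rows, cols = jnp.meshgrid(jnp.arange(height), jnp.arange(width), indexing="ij")
        # Assumed linear index: row * max_width + column, flattened row by row.
        positions.append((rows * max_width + cols).reshape(-1))
    return jnp.concatenate(positions)


# One image of 2 x 3 patches and one of 1 x 2 patches, each with 8 channels.
patches = [jnp.zeros((8, 2, 3)), jnp.zeros((8, 1, 2))]
ids = position_ids_sketch(patches, max_width=4)
```

Because the index uses max_width rather than each image's own width, the same (row, column) position maps to the same ID in every image, which lets all images share one frequency table.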