easydel.modules.pixtral.modeling_pixtral_flax

class easydel.modules.pixtral.modeling_pixtral_flax.PixtralAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Pixtral Attention module.

This module implements the multi-head self-attention mechanism used in the Pixtral vision model. It utilizes Rotary Position Embeddings (RoPE).

config

Configuration object for the model.

Type

PixtralVisionConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

hidden_size

Dimensionality of the hidden states.

Type

int

head_dim

Dimensionality of each attention head.

Type

int

num_key_value_groups

Number of query heads that share each key/value head (1 for standard multi-head attention, where every query head has its own key/value head).

Type

int

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

o_proj

Linear layer for the output projection.

Type

ParallelLinear

attention_performer

Module to perform the core attention computation.

Type

FlexibleAttentionModule
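The core of the attention computation, including how num_key_value_groups maps each key/value head onto a group of query heads, can be sketched in plain JAX. This is a minimal sketch under stated assumptions: it uses raw weight matrices instead of ParallelLinear, omits RoPE and the attention_performer backend, and the function `grouped_attention` and its arguments are hypothetical, not EasyDeL's API.

```python
import jax
import jax.numpy as jnp


def grouped_attention(x, wq, wk, wv, wo, num_heads, num_kv_heads, mask=None):
    """Hypothetical self-attention where each key/value head serves a group of query heads.

    x:    (batch, seq, hidden)
    wq:   (hidden, num_heads * head_dim)
    wk:   (hidden, num_kv_heads * head_dim), wv likewise
    wo:   (num_heads * head_dim, hidden)
    mask: optional additive mask broadcastable to (batch, heads, seq, seq)
    """
    batch, seq, _ = x.shape
    head_dim = wq.shape[1] // num_heads
    groups = num_heads // num_kv_heads  # analogous to num_key_value_groups

    # Project and split into heads: (batch, seq, heads, head_dim)
    q = (x @ wq).reshape(batch, seq, num_heads, head_dim)
    k = (x @ wk).reshape(batch, seq, num_kv_heads, head_dim)
    v = (x @ wv).reshape(batch, seq, num_kv_heads, head_dim)

    # Repeat each key/value head so it lines up with its group of query heads
    k = jnp.repeat(k, groups, axis=2)
    v = jnp.repeat(v, groups, axis=2)

    # Scaled dot-product scores: (batch, heads, seq_q, seq_k)
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * head_dim**-0.5
    if mask is not None:
        scores = scores + mask
    probs = jax.nn.softmax(scores, axis=-1)

    # Weighted sum of values, merge heads, then output projection
    out = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
    return out.reshape(batch, seq, num_heads * head_dim) @ wo


# Usage with random weights: 4 query heads sharing 2 key/value heads (groups = 2)
hidden, heads, kv_heads, head_dim = 32, 4, 2, 8
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (2, 5, hidden))
wq = jax.random.normal(keys[1], (hidden, heads * head_dim)) * 0.1
wk = jax.random.normal(keys[2], (hidden, kv_heads * head_dim)) * 0.1
wv = jax.random.normal(keys[3], (hidden, kv_heads * head_dim)) * 0.1
wo = jax.random.normal(keys[4], (heads * head_dim, hidden)) * 0.1
y = grouped_attention(x, wq, wk, wv, wo, heads, kv_heads)
```

Repeating key/value heads with `jnp.repeat` keeps adjacent query heads in the same group, so query heads 0 and 1 both read key/value head 0 here.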

class easydel.modules.pixtral.modeling_pixtral_flax.PixtralBlock(*args: Any, **kwargs: Any)

Bases: Module

Pixtral Transformer Block.

This module represents a single transformer block in the Pixtral vision model, containing self-attention and MLP sub-layers with residual connections and RMS normalization.

config

Configuration object for the model.

Type

PixtralVisionConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

ln_1

RMS normalization applied before the attention layer.

Type

RMSNorm

ln_2

RMS normalization applied before the MLP layer.

Type

RMSNorm

attention

The self-attention module.

Type

PixtralAttention

feed_forward

The feed-forward (MLP) module.

Type

PixtralMLP
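The block structure described above (RMS normalization before each sub-layer, with a residual connection around it) can be sketched as follows. This is a hedged sketch: `rms_norm` and `block_forward` are hypothetical helpers, the eps value of 1e-5 is an assumption, and the attention and MLP sub-layers are passed in as plain functions rather than PixtralAttention and PixtralMLP.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    # Divide by the root-mean-square over the last axis, then apply a learned per-channel scale
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def block_forward(x, ln_1_weight, ln_2_weight, attention_fn, feed_forward_fn):
    # Pre-norm residual pattern: normalize, run the sub-layer, add the result to its input
    h = x + attention_fn(rms_norm(x, ln_1_weight))
    return h + feed_forward_fn(rms_norm(h, ln_2_weight))


# Usage with stand-in sub-layers (hypothetical; the real block uses PixtralAttention and PixtralMLP)
x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))
ones = jnp.ones((16,))
out = block_forward(x, ones, ones, attention_fn=lambda t: 0.5 * t, feed_forward_fn=jnp.tanh)
```

Because each sub-layer only adds to the residual stream, a block whose sub-layers output zeros reduces to the identity.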

class easydel.modules.pixtral.modeling_pixtral_flax.PixtralMLP(*args: Any, **kwargs: Any)

Bases: Module

Pixtral MLP module.

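This module implements the feed-forward network (MLP) used in the Pixtral vision model. It uses a gated (GLU-style) structure: the activated gate projection is multiplied elementwise by the up projection, and the result is passed through the down projection. The activation function is set by the configuration.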

config

Configuration object for the model.

Type

PixtralVisionConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj

Linear layer for the GLU gate.

Type

ParallelLinear

down_proj

Linear layer for the down projection.

Type

ParallelLinear

up_proj

Linear layer for the GLU value.

Type

ParallelLinear

act_fn

Activation function applied to the gate projection, selected by the configuration (the reference Pixtral vision configuration uses GELU).

Type

callable
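The gated feed-forward computation can be sketched with plain matrices standing in for gate_proj, up_proj and down_proj. This is a minimal sketch, not EasyDeL's implementation: `gated_mlp` is a hypothetical function, and the exact (erf-based) GELU default is an assumption about the configured activation; any callable can be passed as `act_fn`.

```python
import functools

import jax
import jax.numpy as jnp

# Exact (erf-based) GELU; assumed here as the configured activation
exact_gelu = functools.partial(jax.nn.gelu, approximate=False)


def gated_mlp(x, w_gate, w_up, w_down, act_fn=exact_gelu):
    # GLU-style feed-forward: act(gate_proj(x)) * up_proj(x), then down_proj
    gate = act_fn(x @ w_gate)  # (..., intermediate)
    value = x @ w_up  # (..., intermediate)
    return (gate * value) @ w_down  # (..., hidden)


hidden, intermediate = 16, 64
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (2, 5, hidden))
w_gate = jax.random.normal(k2, (hidden, intermediate)) * 0.1
w_up = jax.random.normal(k3, (hidden, intermediate)) * 0.1
w_down = jax.random.normal(k4, (intermediate, hidden)) * 0.1
y = gated_mlp(x, w_gate, w_up, w_down)
y_silu = gated_mlp(x, w_gate, w_up, w_down, act_fn=jax.nn.silu)  # same structure, SiLU gate
```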

class easydel.modules.pixtral.modeling_pixtral_flax.PixtralTransformer(*args: Any, **kwargs: Any)

Bases: Module

Pixtral Transformer stack.

This module represents the main stack of transformer blocks in the Pixtral vision model. It takes patch embeddings as input, processes them through a sequence of PixtralBlock layers, and applies a final RMS normalization.

config

Configuration object for the model.

Type

PixtralVisionConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

layers

List of transformer blocks.

Type

tp.List[PixtralBlock]

ln_post

Final RMS normalization applied after the transformer blocks.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
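The stack's forward pass, with optional gradient checkpointing, can be sketched with `jax.checkpoint`, which recomputes a function's intermediate activations during the backward pass instead of storing them. This is a hedged sketch: `transformer_forward` and the stand-in layers are hypothetical, a simple on/off flag replaces the EasyDeLGradientCheckPointers policy, and eps=1e-5 is an assumption.

```python
import jax
import jax.numpy as jnp


def transformer_forward(x, layer_fns, ln_post_weight, use_checkpointing=False, eps=1e-5):
    for layer_fn in layer_fns:
        if use_checkpointing:
            # Rematerialize: recompute this layer's activations during the backward pass
            layer_fn = jax.checkpoint(layer_fn)
        x = layer_fn(x)
    # Final RMS normalization (the ln_post step)
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * ln_post_weight


# Stand-in layers (hypothetical); real layers would be PixtralBlock instances
layers = [lambda t: t + jnp.tanh(t), lambda t: t + 0.5 * jnp.sin(t)]
x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))
w = jnp.ones((8,))


def loss(inp, use_ckpt):
    return jnp.sum(transformer_forward(inp, layers, w, use_checkpointing=use_ckpt) ** 2)


grad_plain = jax.grad(loss)(x, False)
grad_ckpt = jax.grad(loss)(x, True)
```

Checkpointing trades extra compute for lower memory; the gradients themselves should match the non-checkpointed run.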

class easydel.modules.pixtral.modeling_pixtral_flax.PixtralVisionModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The Pixtral Vision Model transformer.

This class implements the complete Pixtral vision model, including patch embedding via convolution and the main transformer stack.

config

Configuration object for the model.

Type

PixtralVisionConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

patch_conv

Convolutional layer for patch embedding.

Type

nn.Conv

transformer

The main transformer stack.

Type

PixtralTransformer

ln_pre

RMS normalization applied to the patch embeddings before the transformer blocks.

Type

RMSNorm
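Patch embedding via convolution uses a kernel whose size and stride both equal the patch size, so each output position sees exactly one non-overlapping patch. A minimal sketch with `jax.lax.conv_general_dilated` in NHWC layout follows; `patch_embed` is a hypothetical helper, and the NHWC/HWIO layout and bias-free kernel are assumptions rather than a description of patch_conv's exact configuration.

```python
import jax
import jax.numpy as jnp


def patch_embed(images, kernel, patch_size):
    """images: (batch, H, W, C) in NHWC; kernel: (patch, patch, C, hidden) in HWIO."""
    # Stride equals kernel size and no padding: one output position per non-overlapping patch
    feats = jax.lax.conv_general_dilated(
        images,
        kernel,
        window_strides=(patch_size, patch_size),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # (batch, H / patch, W / patch, hidden) -> (batch, num_patches, hidden), row-major patch order
    b, gh, gw, d = feats.shape
    return feats.reshape(b, gh * gw, d), (gh, gw)


patch_size, channels, hidden = 4, 3, 16
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(k1, (1, 8, 12, channels))  # one 8x12 image -> 2x3 patch grid
kernel = jax.random.normal(k2, (patch_size, patch_size, channels, hidden))
embeds, grid = patch_embed(images, kernel, patch_size)
```

The convolution is equivalent to cutting the image into patches, flattening each one, and multiplying by the reshaped kernel; the grid shape is kept because later steps (position IDs, RoPE) need each patch's row and column.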

property frequencies

Cached property to compute and retrieve RoPE frequencies.

easydel.modules.pixtral.modeling_pixtral_flax.apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=0)

Applies Rotary Position Embedding to the query and key tensors.

Parameters
  • q (jnp.ndarray) – The query tensor.

  • k (jnp.ndarray) – The key tensor.

  • cos (jnp.ndarray) – The cosine part of the rotary embedding.

  • sin (jnp.ndarray) – The sine part of the rotary embedding.

  • position_ids (jnp.ndarray, optional) – Deprecated and unused.

  • unsqueeze_dim (int, optional) – The dimension along which cos and sin are unsqueezed so that they broadcast against q and k. For example, if cos and sin have shape [batch_size, seq_len, head_dim] and q and k have shape [batch_size, heads, seq_len, head_dim], set unsqueeze_dim=1; if q and k have shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

Returns

tuple(jnp.ndarray) containing the query and key tensors rotated with the Rotary Position Embedding.
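The rotation itself follows the standard rotate-half RoPE formula, `q * cos + rotate_half(q) * sin`. The sketch below assumes that formula, ignores the unused position_ids argument, and uses a hypothetical name (`apply_rotary_pos_emb_sketch`); it builds 1D cos/sin tables only for illustration and is not EasyDeL's code.

```python
import jax.numpy as jnp


def rotate_half(x):
    # (x1, x2) -> (-x2, x1) on the two halves of the last axis
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=0):
    # Insert a singleton axis so cos/sin broadcast against the head axis of q and k
    cos = jnp.expand_dims(cos, unsqueeze_dim)
    sin = jnp.expand_dims(sin, unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


# q, k: [batch_size, seq_len, heads, head_dim]; cos, sin: [batch_size, seq_len, head_dim]
batch, seq, heads, head_dim = 2, 6, 4, 8
q = jnp.ones((batch, seq, heads, head_dim))
k = jnp.ones((batch, seq, heads, head_dim))
inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, head_dim, 2) / head_dim))
angles = jnp.outer(jnp.arange(seq), inv_freq)  # (seq, head_dim / 2)
emb = jnp.concatenate([angles, angles], axis=-1)  # duplicated to match rotate_half's halves
cos = jnp.broadcast_to(jnp.cos(emb), (batch, seq, head_dim))
sin = jnp.broadcast_to(jnp.sin(emb), (batch, seq, head_dim))
q_rot, k_rot = apply_rotary_pos_emb_sketch(q, k, cos, sin, unsqueeze_dim=2)
```

Because the angle table is duplicated across both halves, each pair (x[i], x[i + head_dim/2]) is rotated by the same angle, so vector norms are preserved and position 0 (angle 0) is left unchanged.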

easydel.modules.pixtral.modeling_pixtral_flax.compute_frequencies(dim: int, max_patches_per_side: int, theta: float = 10000.0)

Computes RoPE frequencies for a fixed-size grid of max_patches_per_side × max_patches_per_side patch positions.

Parameters
  • dim – Embedding dimension.

  • max_patches_per_side – Maximum number of patches per side of the image.

  • theta – Base of the RoPE frequencies; larger values produce lower rotation frequencies.

Returns

inv_freq: the computed frequencies, of shape (max_patches_per_side**2, dim).

Return type

jnp.ndarray
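One way to lay out 2D RoPE frequencies on a patch grid, modelled on the Hugging Face Transformers Pixtral reference as we understand it, is sketched below: the standard 1D inverse frequencies are split so that even-indexed ones encode the patch row and odd-indexed ones the patch column. This is an assumption about the layout, not a copy of EasyDeL's code; `compute_frequencies_sketch` is hypothetical and requires dim to be divisible by 4.

```python
import jax.numpy as jnp


def compute_frequencies_sketch(dim, max_patches_per_side, theta=10000.0):
    p = max_patches_per_side
    # Standard 1D RoPE inverse frequencies: theta ** (-2i / dim), dim / 2 values
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    pos = jnp.arange(p, dtype=jnp.float32)
    # Even-indexed frequencies encode the patch row, odd-indexed ones the patch column
    freqs_h = jnp.outer(pos, freqs[::2])  # (p, dim / 4)
    freqs_w = jnp.outer(pos, freqs[1::2])  # (p, dim / 4)
    # Pair every row with every column: cell (h, w) -> [h * freqs[::2], w * freqs[1::2]]
    grid = jnp.concatenate(
        [
            jnp.broadcast_to(freqs_h[:, None, :], (p, p, freqs_h.shape[-1])),
            jnp.broadcast_to(freqs_w[None, :, :], (p, p, freqs_w.shape[-1])),
        ],
        axis=-1,
    ).reshape(p * p, dim // 2)
    # Duplicate along the last axis so both halves line up with rotate_half
    return jnp.concatenate([grid, grid], axis=-1)  # (p * p, dim)


table = compute_frequencies_sketch(dim=8, max_patches_per_side=3)
```

Row h * max_patches_per_side + w of the table holds the angles for the patch at (h, w), which is why the result has max_patches_per_side**2 rows.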

easydel.modules.pixtral.modeling_pixtral_flax.generate_block_attention_mask(patch_embeds_list, tensor)

Generates a block-diagonal attention mask for multi-image processing.

This mask ensures that attention is only computed within each image’s patches, preventing cross-image attention.

Parameters
  • patch_embeds_list (list[int]) – A list containing the number of patches for each image.

  • tensor (chex.Array) – The input tensor (e.g., hidden states) with shape (batch_size, sequence_length, …).

Returns

A block-diagonal attention mask of shape (batch_size, 1, sequence_length, sequence_length). The mask contains 0.0 for allowed attention positions and a large negative number (the minimum value of the tensor's float dtype) for masked positions.

Return type

chex.Array
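The block-diagonal structure can be built by assigning each sequence position the index of the image it came from and allowing attention only between positions with the same index. This is a hedged sketch: `block_attention_mask_sketch` is hypothetical, and it assumes the images' patches are concatenated along the sequence axis with no padding, so the patch counts sum to sequence_length.

```python
import jax.numpy as jnp


def block_attention_mask_sketch(patch_counts, tensor):
    batch, seq_len = tensor.shape[0], tensor.shape[1]
    # Assumes the images' patches are concatenated along the sequence axis with no padding
    assert sum(patch_counts) == seq_len, "patch counts must cover the sequence exactly"
    # Image index of every sequence position, e.g. [2, 3] -> [0, 0, 1, 1, 1]
    segment_ids = jnp.concatenate(
        [jnp.full((count,), i, dtype=jnp.int32) for i, count in enumerate(patch_counts)]
    )
    same_image = segment_ids[:, None] == segment_ids[None, :]
    # 0.0 where attention is allowed, the dtype's minimum value where it is blocked
    mask = jnp.where(same_image, 0.0, jnp.finfo(tensor.dtype).min).astype(tensor.dtype)
    return jnp.broadcast_to(mask[None, None], (batch, 1, seq_len, seq_len))


hidden_states = jnp.zeros((1, 5, 4), dtype=jnp.float32)
mask = block_attention_mask_sketch([2, 3], hidden_states)
```

Comparing segment IDs builds every block in one vectorized step instead of filling each diagonal block in a loop.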

easydel.modules.pixtral.modeling_pixtral_flax.position_ids_in_meshgrid(patch_embeds_list, max_width)

Generates position IDs based on a meshgrid for a list of patch embeddings.

Parameters
  • patch_embeds_list (list[chex.Array]) – A list of patch embeddings, where each element has shape (…, height, width).

  • max_width (int) – The maximum number of patches along the image width, used as the row stride when converting each (row, column) patch coordinate into a linear index (row * max_width + column).

Returns

A 1D array of position IDs corresponding to the flattened patches.

Return type

chex.Array
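The meshgrid-based position IDs can be sketched as a row-major linear index on a grid that is always max_width wide, so a patch at a given (row, column) receives the same ID in every image. This is a minimal sketch with a hypothetical name (`position_ids_in_meshgrid_sketch`); it assumes each element's last two axes are (height, width) in patches, as the parameter description states.

```python
import jax.numpy as jnp


def position_ids_in_meshgrid_sketch(patch_embeds_list, max_width):
    positions = []
    for patch in patch_embeds_list:
        height, width = patch.shape[-2:]
        rows, cols = jnp.meshgrid(jnp.arange(height), jnp.arange(width), indexing="ij")
        # Row-major linear index with a fixed stride of max_width
        positions.append((rows * max_width + cols).reshape(-1))
    return jnp.concatenate(positions)


embed_dim = 16
patch_grids = [jnp.zeros((embed_dim, 2, 3)), jnp.zeros((embed_dim, 1, 2))]  # 2x3 and 1x2 grids
ids = position_ids_in_meshgrid_sketch(patch_grids, max_width=4)
```

With max_width equal to max_patches_per_side, these IDs can index rows of a frequency table laid out as max_patches_per_side**2 grid cells.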

easydel.modules.pixtral.modeling_pixtral_flax.rotate_half(x)

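Rotates half the hidden dimensions of the input: the last axis is split into two halves (x1, x2) and the result is their concatenation (-x2, x1).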
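A minimal JAX sketch of the standard rotate-half operation is shown below; it assumes the usual (-x2, x1) convention. Applying it twice negates the input, which is the algebraic property (a 90-degree rotation of each (x[i], x[i + d/2]) pair) that makes `x * cos + rotate_half(x) * sin` a rotation.

```python
import jax.numpy as jnp


def rotate_half(x):
    # Split the last axis into halves (x1, x2) and return (-x2, x1)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
rotated = rotate_half(x)  # values: -3, -4, 1, 2
twice = rotate_half(rotated)  # values: -1, -2, -3, -4
```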