# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import typing as tp
import chex
import jax.lax
from flax import nnx as nn
from jax import numpy as jnp
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import FlaxBaseModelOutput
from easydel.infra.utils import (
ACT2FN,
auto_remat,
block_wise_ffn,
control_mlp_sharding,
get_dot_general_by_bits,
)
from easydel.layers.attention import FlaxAttentionModule, FlexibleAttentionModule
from easydel.layers.norms import RMSNorm
from .pixtral_configuration import PixtralVisionConfig
def position_ids_in_meshgrid(patch_embeds_list, max_width):
    """Flatten the (row, col) coordinates of every patch grid into scalar ids.

    For each array in ``patch_embeds_list`` the trailing two dimensions are read
    as ``(height, width)``. The cell at ``(row, col)`` gets the id
    ``row * max_width + col``, visiting cells in row-major order. The ids of all
    grids are concatenated, in list order, into one 1-D array.

    Args:
        patch_embeds_list: Sequence of arrays whose last two dims are the patch
            grid height and width.
        max_width: Row stride used to linearise coordinates.

    Returns:
        1-D integer array with ``sum(height * width)`` entries.
    """
    per_grid_ids = []
    for grid in patch_embeds_list:
        rows, cols = grid.shape[-2:]
        # Row-major enumeration: the row index repeats `cols` times,
        # the column index cycles `rows` times.
        row_idx = jnp.repeat(jnp.arange(rows), cols)
        col_idx = jnp.tile(jnp.arange(cols), rows)
        per_grid_ids.append(row_idx * max_width + col_idx)
    return jnp.concatenate(per_grid_ids)
def generate_block_attention_mask(patch_embeds_list, tensor):
    """Build an additive block-diagonal attention mask for packed patch sequences.

    The sequence axis of ``tensor`` is split into consecutive blocks whose
    lengths are given by ``patch_embeds_list``. A position may attend to every
    position of its own block (mask value ``0``). Every other pair gets the most
    negative finite value of ``tensor.dtype``, and so does any trailing position
    past ``sum(patch_embeds_list)``.

    The mask is derived from per-position block ids rather than from slice
    updates with data-dependent bounds. That makes it valid both eagerly and
    under ``jax.jit``.

    Args:
        patch_embeds_list: Sequence of non-negative ints, the length of each
            block in packing order.
        tensor: Array whose dim 0 is the batch size and dim 1 the packed
            sequence length; its dtype determines the mask dtype.

    Returns:
        Array of shape ``(tensor.shape[0], 1, seq_len, seq_len)`` and dtype
        ``tensor.dtype``.
    """
    dtype = tensor.dtype
    batch_size, seq_len = tensor.shape[0], tensor.shape[1]
    d_min = jnp.finfo(dtype).min

    block_sizes = jnp.asarray(patch_embeds_list, dtype=jnp.int32).reshape(-1)
    num_blocks = block_sizes.shape[0]
    block_end_idx = jnp.cumsum(block_sizes)

    positions = jnp.arange(seq_len, dtype=jnp.int32)
    # Block k covers [end_{k-1}, end_k), so the number of block ends <= pos is k.
    # Positions past the last block get id == num_blocks, meaning "no block".
    block_ids = jnp.searchsorted(block_end_idx, positions, side="right")
    same_block = (block_ids[:, None] == block_ids[None, :]) & (
        block_ids[:, None] < num_blocks
    )

    mask = jnp.where(
        same_block,
        jnp.zeros((), dtype=dtype),
        jnp.asarray(d_min, dtype=dtype),
    )
    mask = mask[None, None, :, :]
    return jnp.broadcast_to(mask, (batch_size, 1, seq_len, seq_len))
def compute_frequencies(dim: int, max_patches_per_side: int, theta: float = 10000.0):
    """
    Build the 2-D rotary frequency table for a square grid of patches.
    Args:
        dim: Embedding dimension (expected to be even).
        max_patches_per_side: Number of patches along each side of the grid.
        theta: Base of the geometric frequency progression.
    Returns:
        inv_freq: Array of shape (max_patches_per_side**2, dim). Row
            ``h * max_patches_per_side + w`` holds the row-``h`` frequencies
            followed by the column-``w`` frequencies, and that half is then
            repeated once along the last axis.
    """
    side = max_patches_per_side
    exponents = jnp.arange(0, dim, 2).astype(jnp.float32) / dim
    base_freqs = 1.0 / (theta**exponents)

    coords = jnp.arange(side)
    # Even-indexed base frequencies encode the row, odd-indexed ones the column.
    row_part = jnp.outer(coords, base_freqs[::2])
    col_part = jnp.outer(coords, base_freqs[1::2])

    grid = jnp.concatenate(
        [
            jnp.broadcast_to(row_part[:, None, :], (side, side, row_part.shape[-1])),
            jnp.broadcast_to(col_part[None, :, :], (side, side, col_part.shape[-1])),
        ],
        axis=-1,
    )
    # Flatten the grid so the table is indexed by a single linear position id.
    half_table = grid.reshape(-1, dim // 2)
    return jnp.tile(half_table, (1, 2))
# Adapted from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotate the last axis by half: ``[lo, hi] -> [-hi, lo]``.

    ``lo`` is the first ``n // 2`` entries of the last axis and ``hi`` the
    remainder, where ``n = x.shape[-1]``.
    """
    split_at = x.shape[-1] // 2
    lower = x[..., :split_at]
    upper = x[..., split_at:]
    return jnp.concatenate((jnp.negative(upper), lower), axis=-1)
# Adapted from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=0):
    """Apply rotary position embeddings to a query/key pair.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``. Before that,
    ``cos`` and ``sin`` gain a singleton axis at ``unsqueeze_dim`` so they
    broadcast against ``q`` and ``k``.

    Args:
        q (`jnp.ndarray`): Query tensor.
        k (`jnp.ndarray`): Key tensor.
        cos (`jnp.ndarray`): Cosine of the rotary angles.
        sin (`jnp.ndarray`): Sine of the rotary angles.
        position_ids (`jnp.ndarray`, *optional*):
            Ignored; kept for signature compatibility.
        unsqueeze_dim (`int`, *optional*):
            Axis at which a size-1 dimension is inserted into ``cos`` and
            ``sin``. Example: with ``cos`` of shape [batch, seq, head_dim] and
            ``q`` of shape [batch, heads, seq, head_dim], use 1; with ``q`` of
            shape [batch, seq, heads, head_dim], use 2.
    Returns:
        `tuple(jnp.ndarray)`: The rotated ``(q, k)``.
    """
    del position_ids  # unused by design
    cos_b = jnp.expand_dims(cos, axis=unsqueeze_dim)
    sin_b = jnp.expand_dims(sin, axis=unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)
class PixtralMLP(nn.Module):
    """Gated feed-forward layer computing ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free linear layers. They are initialised from
    a normal distribution with std ``config.initializer_range``, and the
    activation is looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        dense = functools.partial(
            nn.Linear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            rngs=rngs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        # Creation order (gate, down, up) is kept so rng draws match.
        self.gate_proj = dense(model_width, inner_width)
        self.down_proj = dense(inner_width, model_width)
        self.up_proj = dense(model_width, inner_width)
        self.act_fn = ACT2FN[config.hidden_act]

    def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        sharded = control_mlp_sharding(hidden_states, self.config.partition_axis)
        gated = self.act_fn(self.gate_proj(sharded))
        return self.down_proj(gated * self.up_proj(sharded))
class PixtralAttention(FlaxAttentionModule):
    """Multi-head self-attention with 2-D rotary embeddings for the Pixtral encoder.

    Queries, keys and values are bias-free projections reshaped to
    (batch, seq, heads, head_dim). Rotary embeddings are applied from
    ``frequencies``, and the result goes through ``FlexibleAttentionModule``.
    The attention is non-causal: isolation between packed images comes only from
    ``attention_mask`` (see ``generate_block_attention_mask``).
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(config=config)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.hidden_size = config.hidden_size
        self.head_dim = self.config.head_dim
        # Keys/values use the same head count as queries (no grouped-query attention).
        self.num_key_value_groups = 1

        linear_class = functools.partial(
            nn.Linear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        projected_size = config.num_attention_heads * self.head_dim
        self.q_proj = linear_class(config.hidden_size, projected_size, rngs=rngs)
        self.k_proj = linear_class(config.hidden_size, projected_size, rngs=rngs)
        self.v_proj = linear_class(config.hidden_size, projected_size, rngs=rngs)
        # The output projection maps the merged heads back to the model width.
        # Parameter shapes are unchanged whenever hidden_size == heads * head_dim.
        self.o_proj = linear_class(projected_size, config.hidden_size, rngs=rngs)
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=config.attention_dropout,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        output_attentions: bool = False,
        frequencies: tp.Optional[chex.Array] = None,
    ):
        batch_size, sequence_length = hidden_states.shape[:2]
        head_shape = (
            batch_size,
            sequence_length,
            self.config.num_attention_heads,
            self.head_dim,
        )
        query_states = self.q_proj(hidden_states).reshape(head_shape)
        key_states = self.k_proj(hidden_states).reshape(head_shape)
        value_states = self.v_proj(hidden_states).reshape(head_shape)

        # States are (batch, seq, heads, head_dim). Per-token frequencies of shape
        # (seq, head_dim) need a singleton heads axis at position 1 to give
        # (seq, 1, head_dim); unsqueezing at 0 would pit seq against heads when
        # broadcasting. Other ranks keep the previous axis choice.
        # NOTE(review): the 2-D case matches the (seq, head_dim) table built in
        # PixtralVisionModel; confirm PixtralTransformer passes it through unchanged.
        rope_unsqueeze_dim = 1 if frequencies.ndim == 2 else 0
        query_states, key_states = apply_rotary_pos_emb(
            q=query_states,
            k=key_states,
            cos=jnp.cos(frequencies),
            sin=jnp.sin(frequencies),
            position_ids=position_ids,
            unsqueeze_dim=rope_unsqueeze_dim,
        )
        (
            key_states,
            value_states,
            attention_mask,
            init_attention_bias,
        ) = self.concatenate(
            query=query_states,
            key=key_states,
            cache_view=None,
            value=value_states,
            attention_mask=attention_mask,
            causal_mask=None,
            fcm_mask=None,
        )
        attentions = self.attention_performer.forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=None,
            init_bias=init_attention_bias,
            attention_mask=attention_mask,
            segment_ids=None,
            # Vision patches attend bidirectionally within their image; the
            # block-diagonal attention_mask already prevents cross-image attention.
            causal=False,
            dropout_rng=self.rngs.params(),
        )
        attn_output = self.shard_attention_prod(
            self._merge_heads(attentions.attention_outputs)
        )
        attn_output = self.o_proj(attn_output)
        if output_attentions:
            return (attn_output, attentions.attention_weights)
        return (attn_output,)
class PixtralBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual.

    Both sub-modules are wrapped by ``auto_remat`` using
    ``config.gradient_checkpointing``. When ``config.use_scan_mlp`` is set, the
    MLP runs chunk-wise through ``block_wise_ffn``.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

        attn_cls, mlp_cls = auto_remat(
            PixtralAttention,
            PixtralMLP,
            policy=config.gradient_checkpointing,
        )
        sublayer_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.attention = attn_cls(**sublayer_kwargs)
        self.feed_forward = mlp_cls(**sublayer_kwargs)

        norm_kwargs = dict(
            dim=config.hidden_size,
            eps=1e-5,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.attention_norm = RMSNorm(**norm_kwargs)
        self.ffn_norm = RMSNorm(**norm_kwargs)

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        output_attentions: bool = False,
        frequencies: tp.Optional[chex.Array] = None,
    ):
        # Arguments stay positional: the attention module may be remat-wrapped.
        attn_out = self.attention(
            self.attention_norm(hidden_states),
            attention_mask,
            position_ids,
            output_attentions,
            frequencies,
        )
        hidden_states = attn_out[0] + hidden_states

        ffn_input = self.ffn_norm(hidden_states)
        if not self.config.use_scan_mlp:
            ffn_out = self.feed_forward(ffn_input)
        else:
            ffn_out = block_wise_ffn(
                self.feed_forward, ffn_input, self.config.scan_mlp_chunk_size
            )
        hidden_states = hidden_states + ffn_out

        return (hidden_states, attn_out[1]) if output_attentions else (hidden_states,)
@register_module(
    TaskType.BASE_VISION,
    config=PixtralVisionConfig,
    model_type="pixtral",
)
class PixtralVisionModel(EasyDeLBaseModule):
    """Pixtral vision encoder.

    Each image is turned into a grid of patch embeddings by a strided
    convolution. The grids are flattened and packed into one sequence, normalised
    by ``ln_pre``, and passed to ``PixtralTransformer``. The transformer also
    receives a block-diagonal mask and per-patch rotary frequencies.
    """

    def __init__(
        self,
        config: PixtralVisionConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        patch_window = (config.patch_size, config.patch_size)
        self.patch_conv = nn.Conv(
            in_features=config.num_channels,
            out_features=config.hidden_size,
            kernel_size=patch_window,
            strides=patch_window,
            use_bias=False,
            precision=precision,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.ln_pre = RMSNorm(
            config.hidden_size,
            eps=1e-5,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.transformer = PixtralTransformer(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    @functools.cached_property
    def frequencies(self):
        """Rotary frequency table, indexed by linear patch position id."""
        return compute_frequencies(
            dim=self.config.head_dim,
            theta=self.config.rope_theta,
            max_patches_per_side=self.config.image_size // self.config.patch_size,
        )

    def __call__(
        self,
        pixel_values: tp.List[chex.Array],
        output_hidden_states: tp.Optional[bool] = False,
        output_attentions: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        *args,
        **kwargs,
    ) -> tp.Union[tp.Tuple, FlaxBaseModelOutput]:
        # NOTE(review): output_hidden_states, output_attentions and return_dict
        # are accepted but not forwarded to self.transformer; confirm that is intended.
        patches_per_side = self.config.image_size // self.config.patch_size

        # Each image gets a leading batch axis and has axis 0 moved last (NHWC)
        # for the convolution; the result is moved back to channels-first.
        patch_grids = []
        for image in pixel_values:
            nhwc = jnp.expand_dims(image, 0).astype(self.dtype).transpose(0, 2, 3, 1)
            patch_grids.append(self.patch_conv(nhwc).transpose(0, 3, 1, 2))

        # (1, hidden, h, w) -> (1, h*w, hidden); all images are packed along axis 1.
        token_chunks = []
        for grid in patch_grids:
            flat = jnp.reshape(grid, (grid.shape[0], grid.shape[1], -1))
            token_chunks.append(jnp.transpose(flat, (0, 2, 1)))
        patch_embeds = self.ln_pre(jnp.concatenate(token_chunks, axis=1))

        position_ids = position_ids_in_meshgrid(
            patch_grids,
            max_width=patches_per_side,
        )
        # NOTE(review): generate_block_attention_mask returns an additive mask
        # (0 inside a block, dtype-min elsewhere); confirm the attention path
        # expects this convention rather than a boolean/0-1 mask.
        block_sizes = [grid.shape[-2] * grid.shape[-1] for grid in patch_grids]
        attention_mask = generate_block_attention_mask(block_sizes, patch_embeds)

        return self.transformer(
            inputs_embeds=patch_embeds,
            attention_mask=attention_mask,
            position_embeddings=self.frequencies[position_ids],
        )