# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import typing as tp
from functools import partial
import chex
import jax
import jax.numpy as jnp
import numpy as np
from eformer.pytree import auto_pytree
from flax import nnx as nn
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import (
BaseModelOutput,
ModelOutput,
)
from easydel.infra.utils import (
ACT2FN,
auto_remat,
block_wise_ffn,
control_mlp_sharding,
get_dot_general_by_bits,
)
from easydel.layers.attention import AttentionModule, FlexibleAttentionModule
from easydel.layers.caching import (
PagedAttentionCache,
PagedAttentionCacheView,
PagedAttentionMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ParallelLinear
from easydel.layers.norms import RMSNorm
from .qwen2_vl_configuration import Qwen2VLConfig, Qwen2VLVisionConfig
# TODO: Convert this into a jittable JAX function and call it inside the model instead of precomputing it before the model call.
def get_rope_index(
    input_ids: np.ndarray,
    image_grid_thw: tp.Optional[np.ndarray] = None,
    video_grid_thw: tp.Optional[np.ndarray] = None,
    attention_mask: tp.Optional[np.ndarray] = None,
    spatial_merge_size: int = 1,
    image_token_id: int = -1,
    video_token_id: int = -1,
    vision_start_token_id: int = -1,
) -> tp.Tuple[np.ndarray, np.ndarray]:
    """
    Calculate the 3D rope index based on image and video's temporal, height, and width in LLM.

    Text tokens get the same position on all three axes. Every run of vision
    placeholder tokens is laid out on its merged ``(t, h, w)`` grid, offset by the
    text that precedes it; text after a vision run continues from the largest
    position used so far. Positions where ``attention_mask == 0`` are left at ``1``.

    Args:
        input_ids (`np.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.
        image_grid_thw (`np.ndarray` of shape `(num_images, 3)`, *optional*):
            The temporal, height, and width of feature shape of each image in LLM.
        video_grid_thw (`np.ndarray` of shape `(num_videos, 3)`, *optional*):
            The temporal, height, and width of feature shape of each video in LLM.
        attention_mask (`np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            Extra trailing columns beyond ``sequence_length`` are dropped.
        spatial_merge_size (int):
            The spatial merge size for vision embeddings (divides the grid height and width).
        image_token_id (int):
            The token ID representing an image.
        video_token_id (int):
            The token ID representing a video.
        vision_start_token_id (int):
            The token ID representing the start of a vision sequence.

    Returns:
        position_ids (`np.ndarray` of shape `(3, batch_size, sequence_length)`)
        mrope_position_deltas (`np.ndarray` of shape `(batch_size, 1)`):
            ``max(position) + 1 - sequence_length`` for every row.
    """
    if (
        input_ids is not None
        and attention_mask is not None
        and input_ids.shape[-1] != 1
    ):
        # Drop mask columns that extend past the current input length.
        attention_mask = attention_mask[:, : input_ids.shape[-1]]

    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        # This branch is data dependent (boolean indexing, ``tolist``), so it runs
        # eagerly on host arrays.
        input_ids = np.asarray(input_ids)
        if attention_mask is None:
            attention_mask = np.ones_like(input_ids)
        attention_mask = np.asarray(attention_mask)
        batch_size, seq_len = input_ids.shape
        position_ids = np.ones((3, batch_size, seq_len), dtype=input_ids.dtype)
        image_index, video_index = 0, 0
        mrope_position_deltas = []
        for i in range(batch_size):
            valid = attention_mask[i] == 1
            input_ids_masked = input_ids[i][valid]
            vision_start_indices = np.where(input_ids_masked == vision_start_token_id)[0]
            # A vision-start token in the last slot has no following token to classify.
            vision_start_indices = vision_start_indices[
                vision_start_indices + 1 < len(input_ids_masked)
            ]
            vision_tokens = input_ids_masked[vision_start_indices + 1]
            image_nums = int(np.sum(vision_tokens == image_token_id))
            video_nums = int(np.sum(vision_tokens == video_token_id))
            input_tokens = input_ids_masked.tolist()
            llm_pos_ids_list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Locate the next image / video placeholder at or after ``st``;
                # ``len + 1`` acts as "none left" so the other modality wins.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t, h, w = image_grid_thw[image_index][:3]
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = video_grid_thw[video_index][:3]
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t = int(t)
                llm_grid_h = int(h) // spatial_merge_size
                llm_grid_w = int(w) // spatial_merge_size

                # Text between the previous vision run and this one.
                text_len = ed - st
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                llm_pos_ids_list.append(
                    np.arange(text_len).reshape(1, -1).repeat(3, axis=0) + st_idx
                )
                # (3, t*h*w): row 0 = temporal, row 1 = height, row 2 = width index.
                grid_positions = np.indices(
                    (llm_grid_t, llm_grid_h, llm_grid_w)
                ).reshape(3, -1)
                llm_pos_ids_list.append(grid_positions + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            # Trailing text after the last vision run.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    np.arange(text_len).reshape(1, -1).repeat(3, axis=0) + st_idx
                )
            llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
            position_ids[:, i, valid] = llm_positions
            mrope_position_deltas.append(llm_positions.max() + 1 - seq_len)
        mrope_position_deltas = np.array(mrope_position_deltas).reshape(-1, 1)
        return position_ids, mrope_position_deltas

    if attention_mask is not None:
        position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
        position_ids = jnp.where(attention_mask == 0, 1, position_ids)
        position_ids = jnp.expand_dims(position_ids, axis=0).repeat(3, axis=0)
        # Reduce over the 3 rope axes and the sequence: one delta per row, shape
        # (batch, 1), matching the vision branch and the no-mask branch.
        max_position_ids = jnp.max(position_ids, axis=(0, 2)).reshape(-1, 1)
        mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
    else:
        position_ids = (
            np.arange(input_ids.shape[1])
            .reshape(1, 1, -1)
            .repeat(3, axis=0)
            .repeat(input_ids.shape[0], axis=1)
        )
        mrope_position_deltas = np.zeros((input_ids.shape[0], 1), dtype=input_ids.dtype)
    return position_ids, mrope_position_deltas
@auto_pytree
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
    """
    Outputs of the Qwen2-VL causal (autoregressive) language model.

    Attributes:
        loss: Optional loss value; ``None`` when it was not computed.
        logits: Output of the language-modelling head.
        past_key_values: Key/value cache returned by the decoder.
        hidden_states: Optional per-layer hidden states.
        attentions: Optional per-layer attention weights.
        rope_deltas: Optional multimodal-RoPE position offsets.
    """

    loss: chex.Array | None = None
    logits: chex.Array = None
    past_key_values: list[chex.Array] | None = None
    hidden_states: tuple[chex.Array] | None = None
    attentions: tuple[chex.Array] | None = None
    rope_deltas: chex.Array | None = None
def create_attention_mask(cu_seqlens, seq_length, dtype):
    """
    Build a block-diagonal additive attention mask from cumulative sequence lengths.

    Every pair of positions inside the same span ``[cu_seqlens[i-1], cu_seqlens[i])``
    gets ``0`` (may attend); every other pair gets ``jnp.finfo(dtype).min``.

    Args:
        cu_seqlens: 1-D cumulative span boundaries, e.g. ``[0, 4, 9]``.
        seq_length: Total number of positions (size of both mask axes).
        dtype: Floating-point dtype of the returned mask.

    Returns:
        Attention mask of shape ``(1, seq_length, seq_length)`` and dtype ``dtype``.
    """
    boundaries = jnp.asarray(cu_seqlens)
    positions = jnp.arange(seq_length)
    starts = boundaries[:-1][:, None, None]
    ends = boundaries[1:][:, None, None]
    query_pos = positions[None, :, None]
    key_pos = positions[None, None, :]
    # (num_spans, seq, seq) -> (seq, seq): True where query and key share a span.
    same_span = jnp.any(
        (query_pos >= starts) & (query_pos < ends) & (key_pos >= starts) & (key_pos < ends),
        axis=0,
    )
    attention_mask = jnp.where(same_span, 0, jnp.finfo(dtype).min).astype(dtype)
    return attention_mask[None]
# An unconventional scatter trick, but it works in practice.
# TODO: Restructure this helper more cleanly.
@partial(jax.jit, static_argnames=["TKN_ID"])
def jax_scatter(sec_embeds, ids, fir_embeds, TKN_ID):
    """
    Write ``sec_embeds`` into the rows of ``fir_embeds`` whose token id equals ``TKN_ID``.

    ``ids`` is broadcast over the last axis of ``fir_embeds``. The flattened values of
    ``sec_embeds`` are written, in order, into the flattened masked elements. Masked
    elements beyond the supply of ``sec_embeds`` values receive ``0``.

    Args:
        sec_embeds: Replacement embeddings (cast to ``fir_embeds.dtype``).
        ids: Token ids; ``ids[..., None]`` must broadcast to ``fir_embeds.shape``.
        fir_embeds: Base embeddings.
        TKN_ID: Static placeholder token id to replace.

    Returns:
        Array with the same shape and dtype as ``fir_embeds``.
    """
    target_shape = fir_embeds.shape
    total = fir_embeds.size
    token_mask = jnp.broadcast_to(
        jnp.expand_dims(ids == TKN_ID, axis=-1), target_shape
    ).reshape(-1)
    # Flat indices of masked elements, shifted by one: slot 0 of the padded buffer
    # is a trash slot, and unused entries (fill_value=-1 -> 0) all land there.
    (hit_positions,) = jnp.where(token_mask, size=total, fill_value=-1)
    hit_positions = hit_positions + 1

    padded_base = jnp.pad(fir_embeds.reshape(-1), (1, 0))
    padded_len = padded_base.size
    source_flat = sec_embeds.astype(fir_embeds.dtype).reshape(-1)[:total]
    source_flat = jnp.pad(source_flat, (0, padded_len - source_flat.size))
    hit_positions = jnp.pad(hit_positions, (0, padded_len - hit_positions.size))

    scattered = padded_base.at[hit_positions].set(source_flat)
    # Drop the trash slot before restoring the original shape.
    return scattered[1:].reshape(target_shape)
def precompute_vl_rotary(dim, theta, max_position):
    """
    Precompute the rotary angle table ``angles[p, i] = p / theta ** (2 i / dim)``.

    Args:
        dim: Rotary dimension; one inverse frequency per even index in ``[0, dim)``.
        theta: RoPE base.
        max_position: Number of positions (rows of the table).

    Returns:
        float32 array of shape ``(max_position, len(range(0, dim, 2)))``.
    """
    inv = 1.0 / (theta ** (jnp.arange(0, dim, 2, dtype="f4") / dim))
    # ``dtype`` must be passed by keyword: the third positional argument of
    # ``jnp.arange`` is ``step``.
    seq = jnp.arange(0, max_position, dtype="f4")
    return jnp.outer(seq, inv)
def rotate_half(x):
    """Rotate the last axis by half: ``[a, b] -> [-b, a]`` where ``a`` is the first ``n // 2`` entries."""
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return jnp.concatenate((-back, front), axis=-1)
def apply_rotary_pos_emb_vision(array: chex.Array, freqs: chex.Array) -> chex.Array:
    """
    Apply vision rotary position embeddings using the ``rotate_half`` convention.

    Args:
        array: Per-head projections, presumably ``(seq_len, num_heads, head_dim)``
            as produced by ``VisionAttention`` — TODO confirm for other callers.
        freqs: Rotary angles of shape ``(seq_len, head_dim // 2)``. They are tiled to
            ``head_dim`` so that channel ``i`` and ``i + head_dim // 2`` share an angle.

    Returns:
        The rotated ``array`` (computed in float32, cast back to the input dtype).
    """
    orig_dtype = array.dtype
    array = array.astype("f4")
    # Tile the angles as [f, f] (what ``torch.Tensor.repeat(1, 1, 2)`` does in the
    # reference). ``jnp.repeat`` would interleave them as [f0, f0, f1, f1, ...],
    # which does not match the half-split pairing used by ``rotate_half``.
    freqs = jnp.concatenate([freqs, freqs], axis=-1)
    # (seq, head_dim) -> (1, seq, 1, head_dim) so it broadcasts across heads.
    cos = jnp.expand_dims(jnp.cos(freqs), (0, 2)).astype("f4")
    sin = jnp.expand_dims(jnp.sin(freqs), (0, 2)).astype("f4")
    output = (array * cos) + (rotate_half(array) * sin)
    output = output.astype(orig_dtype)
    return output.squeeze(0)
class PatchEmbed(nn.Module):
    """Embed flattened spatio-temporal pixel patches with a strided 3-D convolution."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
        precision: jax.lax.PrecisionLike = None,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        *,
        rngs: nn.Rngs,
    ) -> None:
        self.dtype = dtype
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        # Kernel size equals stride: each patch is covered by exactly one conv window.
        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv(
            in_features=in_channels,
            out_features=embed_dim,
            kernel_size=patch_shape,
            strides=patch_shape,
            use_bias=False,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(self, hidden_states: chex.Array) -> chex.Array:
        """Return ``(num_patches, embed_dim)`` embeddings for the flattened input patches."""
        patches = hidden_states.reshape(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # Channels-first -> channels-last for the flax convolution.
        patches = patches.transpose(0, 2, 3, 4, 1)
        embedded = self.proj(patches.astype(self.dtype))
        return embedded.reshape(-1, self.embed_dim)
class PatchMerger(nn.Module):
    """Normalise patch features, group ``spatial_merge_size ** 2`` of them, and project to ``dim``."""

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        precision: jax.lax.PrecisionLike = None,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        *,
        rngs: nn.Rngs,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = nn.LayerNorm(
            context_dim,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        dense = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Applied in order: projection, exact (non-approximate) GELU, projection to ``dim``.
        self.mlp = [
            dense(self.hidden_size, self.hidden_size),
            partial(nn.gelu, approximate=False),
            dense(self.hidden_size, dim),
        ]

    def __call__(self, x: chex.Array) -> chex.Array:
        """Return merged features of shape ``(-1, dim)``."""
        merged = self.ln_q(x).reshape(-1, self.hidden_size)
        for layer in self.mlp:
            merged = layer(merged)
        return merged
class VisionMlp(nn.Module):
    """Two-layer feed-forward block of the vision tower: ``fc2(act(fc1(x)))``."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        hidden_act: str,
        precision: jax.lax.PrecisionLike = None,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        *,
        rngs: nn.Rngs,
    ) -> None:
        super().__init__()
        dense = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.fc1 = dense(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = dense(hidden_dim, dim)

    def __call__(self, x: chex.Array) -> chex.Array:
        """Expand to ``hidden_dim``, apply the activation, project back to ``dim``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
class VisionAttention(AttentionModule):
    """
    Multi-head self-attention for the vision tower.

    A token attends only to tokens inside the same ``[cu_seqlens[j], cu_seqlens[j + 1])``
    span; every other pair is masked with the dtype's most negative value.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        precision: jax.lax.PrecisionLike = None,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(config)
        self.rngs = rngs
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        dense = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # Fused query/key/value projection.
        self.qkv = dense(dim, dim * 3, use_bias=True)
        self.proj = dense(dim, dim)
        # NOTE(review): not used by ``__call__``, which computes attention explicitly.
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=0.0,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        cu_seqlens: chex.Array,
        rotary_pos_emb: chex.Array = None,
    ) -> chex.Array:
        """
        Args:
            hidden_states: ``(seq_len, dim)`` token features.
            cu_seqlens: Cumulative span boundaries (array), e.g. ``[0, n0, n0 + n1]``.
            rotary_pos_emb: Angles forwarded to ``apply_rotary_pos_emb_vision``.

        Returns:
            ``(seq_len, dim)`` output after the output projection.
        """
        seq_length = hidden_states.shape[0]
        # (seq, 3 * dim) -> (3, seq, heads, head_dim)
        fused = (
            self.qkv(hidden_states)
            .reshape(seq_length, 3, self.num_heads, -1)
            .transpose(1, 0, 2, 3)
        )
        q, k, v = fused[0], fused[1], fused[2]
        q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
        k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        # Additive block-diagonal mask built from the span boundaries.
        positions = jnp.arange(seq_length)
        starts = cu_seqlens[:-1][:, None, None]
        ends = cu_seqlens[1:][:, None, None]
        query_pos = positions[None, :, None]
        key_pos = positions[None, None, :]
        query_in_span = (query_pos >= starts) & (query_pos < ends)
        key_in_span = (key_pos >= starts) & (key_pos < ends)
        same_span = jnp.any(query_in_span & key_in_span, axis=0)
        attention_mask = jnp.where(same_span, 0.0, jnp.finfo(q.dtype).min)

        # Heads-major layout for batched matmuls: (heads, seq, head_dim).
        q, k, v = [t.swapaxes(0, 1) for t in (q, k, v)]
        scores = jnp.matmul(q, k.swapaxes(1, 2)) / math.sqrt(self.head_dim)
        scores = scores + attention_mask
        probs = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(q.dtype)
        context = jnp.matmul(probs, v).swapaxes(0, 1).reshape(seq_length, -1)
        return self.proj(context)
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm vision transformer block: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``."""

    def __init__(
        self,
        config: Qwen2VLVisionConfig,
        precision: jax.lax.PrecisionLike = None,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        *,
        rngs: nn.Rngs,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(
            config.embed_dim,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.norm2 = nn.LayerNorm(
            config.embed_dim,
            epsilon=1e-6,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
        self.attn = VisionAttention(
            config=config,
            dim=config.embed_dim,
            # Use the configured head count; fall back to VisionAttention's own
            # default (16) for configs that do not define ``num_heads``.
            num_heads=getattr(config, "num_heads", 16),
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.mlp = VisionMlp(
            dim=config.embed_dim,
            hidden_dim=mlp_hidden_dim,
            hidden_act=config.hidden_act,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> chex.Array:
        """Apply span-restricted attention and the MLP, each with a residual connection."""
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
class Qwen2VLMLP(nn.Module):
    """Gated feed-forward network of the language model: ``down(act(gate(x)) * up(x))``."""

    def __init__(
        self,
        config: Qwen2VLConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        dense = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            rngs=rngs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = dense(hidden, inner)
        self.down_proj = dense(inner, hidden)
        self.up_proj = dense(hidden, inner)
        self.act_fn = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        """Apply the gated MLP after constraining the input sharding."""
        hidden_states = control_mlp_sharding(hidden_states, self.config.partition_axis)
        gated = self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)
        return self.down_proj(gated)
class Qwen2VLAttention(AttentionModule):
    """Multi-head (grouped-query) self-attention of the Qwen2-VL language model."""

    def __init__(
        self,
        config: Qwen2VLConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(config=config)
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.rngs = rngs
        self.hidden_size = config.hidden_size
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            self.config.num_attention_heads // self.config.num_key_value_heads
        )
        if self.num_key_value_groups == 1:
            assert self.config.num_attention_heads == self.config.num_key_value_heads

        dense = partial(
            ParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            kernel_init=jax.nn.initializers.normal(config.initializer_range),
            precision=precision,
            rngs=rngs,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = dense(config.hidden_size, q_width, use_bias=True)
        self.k_proj = dense(config.hidden_size, kv_width, use_bias=True)
        self.v_proj = dense(config.hidden_size, kv_width, use_bias=True)
        self.o_proj = dense(q_width, config.hidden_size, use_bias=False)
        self.rotary = self.config.get_basic_rope(
            self.dtype,
            self.head_dim,
            self.head_dim,
            True,
        )
        self.attention_performer = FlexibleAttentionModule(
            base_config=config,
            softmax_scale=self.head_dim**-0.5,
            dropout_prob=config.attention_dropout,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: tp.Optional[chex.Array | bool],
        cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
        cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
        segment_ids: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        fcm_mask: tp.Optional[chex.Array] = None,
        frequencies: tp.Optional[chex.Array] = None,
    ) -> tp.Tuple[chex.Array, chex.Array]:
        """
        Compute causal self-attention for ``(batch, seq, hidden)`` inputs.

        Returns:
            ``(attn_output, attention_weights)``. ``output_attentions`` is accepted
            but not consulted here.
        """
        batch_size, sequence_length = hidden_states.shape[:2]
        query_states = self.q_proj(hidden_states).reshape(
            batch_size, sequence_length, self.config.num_attention_heads, self.head_dim
        )
        key_states = self.k_proj(hidden_states).reshape(
            batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim
        )
        value_states = self.v_proj(hidden_states).reshape(
            batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim
        )

        if position_ids.ndim == 3:
            # NOTE(review): only the first of the three M-RoPE position rows is used.
            position_ids = position_ids[0]
        query_states, key_states = self.rotary(
            positions=position_ids,
            query=query_states,
            key=key_states,
            frequencies=frequencies,
        )

        key_states, value_states, attention_mask, init_attention_bias = self.concatenate(
            query=query_states,
            key=key_states,
            cache_view=cache_view,
            value=value_states,
            attention_mask=attention_mask,
            causal_mask=causal_mask,
            fcm_mask=fcm_mask,
        )
        attentions = self.attention_performer.forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            bias=None,
            cache_metadata=cache_metadata,
            cache_view=cache_view,
            init_bias=init_attention_bias,
            attention_mask=attention_mask,
            segment_ids=segment_ids,
            causal=True,
            dropout_rng=self.rngs.params(),
        )
        merged_heads = self._merge_heads(attentions.attention_outputs)
        attn_output = self.o_proj(self.shard_attention_prod(attn_output=merged_heads))
        return attn_output, attentions.attention_weights
class Qwen2VLDecoderLayer(nn.Module):
    """Pre-norm decoder layer: RMSNorm -> attention -> residual, RMSNorm -> MLP -> residual."""

    def __init__(
        self,
        config: Qwen2VLConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        self.config = config
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        # Optionally wrap both sub-blocks for gradient checkpointing.
        attn_block, mlp_block = auto_remat(
            Qwen2VLAttention,
            Qwen2VLMLP,
            policy=config.gradient_checkpointing,
        )
        shared_kwargs = dict(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.self_attn = attn_block(**shared_kwargs)
        self.mlp = mlp_block(**shared_kwargs)
        rms_norm = partial(
            RMSNorm,
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.input_layernorm = rms_norm()
        self.post_attention_layernorm = rms_norm()

    def __call__(
        self,
        hidden_states: chex.Array,
        attention_mask: chex.Array,
        position_ids: chex.Array,
        causal_mask: tp.Optional[chex.Array | bool],
        cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
        cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
        segment_ids: tp.Optional[chex.Array] = None,
        output_attentions: bool = False,
        fcm_mask: tp.Optional[chex.Array] = None,
        frequencies: tp.Optional[chex.Array] = None,
    ):
        """Return ``(hidden_states, *extra_attention_outputs)``."""
        # Arguments are passed positionally to the (possibly remat-wrapped) attention.
        attn_outputs = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask,
            position_ids,
            causal_mask,
            cache_view,
            cache_metadata,
            segment_ids,
            output_attentions,
            fcm_mask,
            frequencies,
        )
        hidden_states = hidden_states + attn_outputs[0]

        ff_input = self.post_attention_layernorm(hidden_states)
        if self.config.use_scan_mlp:
            ff_output = block_wise_ffn(
                self.mlp,
                ff_input,
                self.config.scan_mlp_chunk_size,
            )
        else:
            ff_output = self.mlp(ff_input)
        return (hidden_states + ff_output,) + attn_outputs[1:]
@register_module(
    TaskType.BASE_MODULE,
    config=Qwen2VLConfig,
    model_type="qwen2_vl",
)
class Qwen2VLModel(EasyDeLBaseModule):
    """Qwen2-VL language backbone: token embeddings, decoder layers and a final RMSNorm."""

    def __init__(
        self,
        config: Qwen2VLConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.embed_tokens = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            rngs=rngs,
        )
        self.layers = [
            Qwen2VLDecoderLayer(
                config=config,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for _ in range(self.config.num_hidden_layers)
        ]
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        segment_ids: tp.Optional[chex.Array] = None,
        past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
        cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: bool = True,
    ) -> tp.Union[BaseModelOutput, tp.Tuple]:
        """
        Run the decoder stack.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. A missing
        ``attention_mask`` defaults to all-ones; a missing ``position_ids`` is derived
        from the cumulative count of attended tokens.

        Returns:
            ``BaseModelOutput`` when ``return_dict`` is true, otherwise a tuple of the
            non-``None`` entries among hidden state, all hidden states, attentions
            and cache.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        # XOR: error when both are given or when neither is.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
        batch_size, sequence_length, _ = inputs_embeds.shape
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        assert sequence_length <= self.config.max_position_embeddings, (
            f"Maximum Position Embedding Reached ! (Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
        )
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length), "b1")
        elif attention_mask.dtype != jnp.bool_:
            # ``jnp.bool_`` is available in every JAX version (``jnp.bool`` is newer).
            attention_mask = jnp.astype(attention_mask == 1, "b1")
        if position_ids is None:
            # Position = number of attended tokens so far minus one, floored at 0.
            # The lower bound is passed positionally: the ``a_min`` keyword is
            # deprecated in recent JAX releases in favour of ``min``.
            position_ids = jnp.broadcast_to(
                jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, 0),
                (batch_size, sequence_length),
            ).astype(jnp.int32)
        hidden_states = inputs_embeds
        if past_key_values is None:
            past_key_values = TransformerCache.init_empty(len(self.layers))
        for idx, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                cache_view=past_key_values.views[idx],
                cache_metadata=cache_metadata,
                causal_mask=self.causal_mask,
                output_attentions=output_attentions,
                segment_ids=segment_ids,
                frequencies=self.frequencies,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)
        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            outputs = (hidden_states, all_hidden_states, all_attentions, past_key_values)
        else:
            outputs = (hidden_states, all_attentions, past_key_values)
        if not return_dict:
            return tuple(v for v in outputs if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=past_key_values,
        )
@register_module(
    TaskType.IMAGE_TEXT_TO_TEXT,
    config=Qwen2VLConfig,
    model_type="qwen2_vl",
)
class Qwen2VLForConditionalGeneration(EasyDeLBaseModule):
    """Qwen2-VL vision tower plus language backbone with a language-modelling head."""

    loss_type = "ForCausalLM"

    def __init__(
        self,
        config: Qwen2VLConfig,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        numerics = dict(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.visual = Qwen2VisionTransformerPretrainedModel(config.vision_config, **numerics)
        self.model = Qwen2VLModel(config, **numerics)
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLinear(
            config.hidden_size,
            config.vocab_size,
            use_bias=False,
            **numerics,
        )

    def get_output_embeddings(self):
        """Return the language-modelling head."""
        return self.lm_head

    def get_decoder(self):
        """Return the language backbone."""
        return self.model

    def __call__(
        self,
        input_ids: chex.Array = None,
        attention_mask: tp.Optional[chex.Array] = None,
        position_ids: tp.Optional[chex.Array] = None,
        past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
        cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
        output_attentions: tp.Optional[bool] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        pixel_values: tp.Optional[chex.Array] = None,
        pixel_values_videos: tp.Optional[chex.Array] = None,
        image_grid_thw: tp.Optional[tuple] = None,
        video_grid_thw: tp.Optional[tuple] = None,
        rope_deltas: tp.Optional[chex.Array] = None,
        image_max_grid_size: int = None,
        video_max_grid_size: int = None,
    ) -> tp.Union[tp.Tuple, Qwen2VLCausalLMOutputWithPast]:
        """
        Embed text (splicing in image/video features), run the backbone and LM head.

        When ``inputs_embeds`` is not supplied, tokens are embedded and any
        ``pixel_values`` / ``pixel_values_videos`` are encoded by ``self.visual`` and
        scattered onto the image / video placeholder tokens. When ``position_ids``
        is missing it is derived via ``get_rope_index`` (or a plain ``arange`` when
        no cache is given but ``rope_deltas`` is).
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                image_embeds = self.visual(
                    pixel_values.astype(self.visual.get_dtype()),
                    grid_thw=np.array(image_grid_thw),
                    max_grid_size=image_max_grid_size,
                )
                inputs_embeds = jax_scatter(
                    image_embeds,
                    input_ids,
                    inputs_embeds,
                    self.config.image_token_id,
                )
            if pixel_values_videos is not None:
                video_embeds = self.visual(
                    pixel_values_videos.astype(self.visual.get_dtype()),
                    grid_thw=np.array(video_grid_thw),
                    max_grid_size=video_max_grid_size,
                )
                inputs_embeds = jax_scatter(
                    video_embeds,
                    input_ids,
                    inputs_embeds,
                    self.config.video_token_id,
                )

        needs_positions = (
            position_ids is None
            and input_ids is not None
            and (attention_mask is None or attention_mask.ndim == 2)
        )
        if needs_positions:
            if past_key_values is None and rope_deltas is not None:
                batch_size, sequence_length = inputs_embeds.shape[:2]
                flat_positions = (
                    jnp.arange(sequence_length).reshape(1, -1).repeat(batch_size, 0)
                )
                # Same positions on all three M-RoPE axes: (3, batch, seq).
                position_ids = jnp.expand_dims(flat_positions, 0).repeat(3, 0)
            else:
                position_ids, rope_deltas = get_rope_index(
                    input_ids=input_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    attention_mask=attention_mask,
                    spatial_merge_size=self.visual.spatial_merge_size,
                    image_token_id=self.config.image_token_id,
                    video_token_id=self.config.video_token_id,
                    vision_start_token_id=self.config.vision_start_token_id,
                )

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            cache_metadata=cache_metadata,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0])
        if not return_dict:
            return (logits,) + outputs[1:]
        return Qwen2VLCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=rope_deltas,
        )

    def get_static_arguments(self):
        """Names of ``__call__`` keyword arguments treated as static."""
        return (
            "video_max_grid_size",
            "image_max_grid_size",
            "image_grid_thw",
            "video_grid_thw",
        )

    def _get_compile_model_kwargs(
        self,
        batch_size: int,
        input_tokens_length: int,
        input_sharding: jax.sharding.PartitionSpec,
        rngs: jax.random.PRNGKey,
        vision_included: bool = False,
        vision_batch_size: int = 1,
        vision_channels: int = 3,
        vision_height: tp.Optional[int] = None,
        vision_width: tp.Optional[int] = None,
        required_props: tp.Optional[tp.Mapping[str, tp.Dict[str, tp.Any]]] = None,
        **kwargs,
    ):
        """Extend the base compile kwargs with dummy pixel values and ``image_grid_thw``."""
        basics = super()._get_compile_model_kwargs(
            batch_size=batch_size,
            input_tokens_length=input_tokens_length,
            input_sharding=input_sharding,
            rngs=rngs,
            vision_included=vision_included,
            vision_batch_size=vision_batch_size,
            vision_channels=vision_channels,
            vision_height=vision_height,
            vision_width=vision_width,
            required_props=required_props,
            **kwargs,
        )
        if not vision_included:
            return basics
        assert required_props is not None
        assert "image_grid_thw" in required_props.keys()
        basics.update(
            {
                "pixel_values": jnp.ones((vision_height, vision_width), dtype="f4"),
                "image_grid_thw": jnp.array(required_props["image_grid_thw"]["value"]),
            }
        )
        return basics

    def _create_required_props_from_kwargs(
        self,
        model_kwargs: tp.Dict[str, chex.Array],
    ) -> tp.Optional[tp.Mapping[str, tp.Dict[str, tp.Any]]]:
        """Collect any ``image_grid_thw`` / ``video_grid_thw`` kwargs as ``{"value": array}`` entries."""
        props = {}
        for key in ("image_grid_thw", "video_grid_thw"):
            if key in model_kwargs.keys():
                props[key] = {"value": jnp.array(model_kwargs[key])}
        return props