easydel.modules.qwen2_vl.modeling_qwen2_vl_flax

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.PatchEmbed(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.PatchMerger(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLCausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: Optional[List[Union[Array, ndarray, bool, number]]] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, rope_deltas: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

Bases: ModelOutput

Base class for Qwen2VL causal language model (or autoregressive) outputs.

attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None#
logits: Union[Array, ndarray, bool, number] = None#
loss: Optional[Union[Array, ndarray, bool, number]] = None#
past_key_values: Optional[List[Union[Array, ndarray, bool, number]]] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

rope_deltas: Optional[Union[Array, ndarray, bool, number]] = None#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
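The helper methods listed above (replace, to_dict, from_dict, to_json, from_json) follow a common immutable-container pattern. The sketch below illustrates that pattern with a hypothetical stand-in dataclass named OutputSketch; it is not the EasyDeL class, it holds plain Python lists instead of arrays, and its field subset is chosen only for illustration.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in for the output container, not the EasyDeL class."""

    logits: Optional[list] = None
    loss: Optional[float] = None
    rope_deltas: Optional[list] = None

    def replace(self, **kwargs) -> "OutputSketch":
        # Build a new instance; the original object is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OutputSketch":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "OutputSketch":
        return cls.from_dict(json.loads(json_str))


out = OutputSketch(logits=[0.1, 0.9])
updated = out.replace(rope_deltas=[2])          # new object with one field changed
restored = OutputSketch.from_json(updated.to_json())  # JSON round trip
```

Because replace returns a new object, the original output can be kept around safely while a modified copy is passed on.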

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLForConditionalGeneration(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

config: tp.Union[EasyDeLBaseConfig, _CP]#
dtype: jnp.dtype#
get_decoder()[source]#
get_input_embeddings()[source]#
get_output_embeddings()[source]#
get_static_arguments()[source]#

Returns a tuple of static arguments required by the module’s __call__ method.

Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.

Returns

A tuple containing static arguments.

Return type

tp.Tuple

loss_type = 'ForCausalLM'#
param_dtype: jnp.dtype#
precision: lax.PrecisionLike#
prepare_inputs_for_call(image_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, video_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, image_max_grid_size: int = None, video_max_grid_size: int = None, drop_ids: bool = True, **others)[source]#

Prepares keyword arguments before passing them to the module’s __call__ method.

This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).

Parameters

**others – The remaining keyword arguments intended for __call__.

Returns

The prepared keyword arguments.

Return type

dict

prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs)[source]#

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
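The position and mask preparation described above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the function name sketch_generation_inputs is hypothetical, cache initialization and sharding are omitted, and clamping padded positions to 0 is an assumption made here.

```python
import numpy as np


def sketch_generation_inputs(input_ids, max_length, attention_mask=None):
    """Hypothetical helper: derive position_ids and an extended mask."""
    batch_size, seq_length = input_ids.shape
    # The extended mask spans the full cache length; future slots default to 1.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    if attention_mask is not None:
        # Position of an attended token = number of attended tokens before it.
        position_ids = np.cumsum(attention_mask, axis=-1) - 1
        # Assumption: clamp padded positions to 0 instead of leaving them at -1.
        position_ids = np.maximum(position_ids, 0)
        extended_mask[:, :seq_length] = attention_mask
    else:
        # No mask given: assume a contiguous range 0..seq_length-1.
        position_ids = np.broadcast_to(
            np.arange(seq_length), (batch_size, seq_length)
        )
    return {"attention_mask": extended_mask, "position_ids": position_ids}


input_ids = np.array([[0, 11, 12]])  # one leading pad token
mask = np.array([[0, 1, 1]])
prepared = sketch_generation_inputs(input_ids, max_length=5, attention_mask=mask)
unmasked = sketch_generation_inputs(input_ids, max_length=5)
```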

rngs: nn.Rngs#
update_inputs_for_generation(model_outputs, model_kwargs)[source]#

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
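A minimal sketch of this update step is shown below. The function name sketch_update_inputs and the placeholder cache value are hypothetical, and keeping only the last position before incrementing is an assumption borrowed from common Flax generation loops rather than something this page confirms.

```python
from types import SimpleNamespace

import numpy as np


def sketch_update_inputs(model_outputs, model_kwargs):
    """Hypothetical helper mirroring the update described above."""
    # Carry the refreshed KV cache into the next step.
    model_kwargs["past_key_values"] = model_outputs.past_key_values
    # Assumption: keep only the last position, then advance it by one.
    model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return model_kwargs


outputs = SimpleNamespace(past_key_values="cache-after-step")  # placeholder cache
kwargs = {"past_key_values": None, "position_ids": np.array([[0, 1, 2]])}
kwargs = sketch_update_inputs(outputs, kwargs)
```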

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLVisionBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VisionTransformerPretrainedModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

config_class#

alias of Qwen2VLVisionConfig

get_dtype() dtype[source]#
rot_pos_emb(grid_thw, max_grid_size)[source]#

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.VisionAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.VisionMlp(*args: Any, **kwargs: Any)[source]#

Bases: Module

easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.apply_rotary_pos_emb_vision(array: Union[Array, ndarray, bool, number], freqs: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.create_attention_mask(cu_seqlens, seq_length, dtype)[source]#

Creates an attention mask matrix.

Parameters
  • cu_seqlens – Cumulative sequence lengths.

  • seq_length – Length of each sequence.

  • dtype – Data type of the mask.

Returns

Attention mask matrix.
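To make the role of cumulative sequence lengths concrete, the sketch below builds a block-diagonal additive mask in NumPy. It rests on assumptions not stated on this page: that seq_length is the total packed length, that the mask is additive (0 where attention is allowed, the most negative finite value elsewhere), and that the mask is two-dimensional. The name sketch_block_mask is hypothetical.

```python
import numpy as np


def sketch_block_mask(cu_seqlens, seq_length, dtype=np.float32):
    """Hypothetical block-diagonal mask built from cumulative lengths."""
    # Start fully masked: every entry holds the most negative finite value.
    mask = np.full((seq_length, seq_length), np.finfo(dtype).min, dtype=dtype)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Tokens inside the same segment may attend to each other (bias 0).
        mask[start:end, start:end] = 0
    return mask


# Two packed segments: tokens 0-1 and tokens 2-4.
mask = sketch_block_mask(cu_seqlens=[0, 2, 5], seq_length=5)
```

With cu_seqlens = [0, 2, 5], tokens 0 and 1 see only each other, and tokens 2 to 4 form a second independent block.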

easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.get_rope_index(input_ids: ndarray, image_grid_thw: Optional[ndarray] = None, video_grid_thw: Optional[ndarray] = None, attention_mask: Optional[ndarray] = None, spatial_merge_size: int = 1, image_token_id: int = -1, video_token_id: int = -1, vision_start_token_id: int = -1) Tuple[ndarray, ndarray][source]#

Calculates the 3D RoPE position index from the temporal, height, and width dimensions of each image and video as seen by the LLM.

Parameters
  • input_ids (np.ndarray of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.

  • image_grid_thw (np.ndarray of shape (num_images, 3), optional) – The temporal, height, and width dimensions of each image's feature grid in the LLM.

  • video_grid_thw (np.ndarray of shape (num_videos, 3), optional) – The temporal, height, and width dimensions of each video's feature grid in the LLM.

  • attention_mask (np.ndarray of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values are in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.

  • spatial_merge_size (int) – The spatial merge size for vision embeddings.

  • image_token_id (int) – The token ID representing an image.

  • video_token_id (int) – The token ID representing a video.

  • vision_start_token_id (int) – The token ID representing the start of a vision sequence.

Returns

  • position_ids (np.ndarray of shape (3, batch_size, sequence_length))

  • mrope_position_deltas (np.ndarray of shape (batch_size))
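The two building blocks of a 3D position index can be sketched separately: positions for plain text, and positions for one vision block. Both helpers below are hypothetical and simplified. They assume the convention of the reference Hugging Face Qwen2-VL code (text tokens share one position on all three axes, vision tokens are enumerated over the merged grid), and they do not locate vision tokens inside input_ids the way the full function must.

```python
import numpy as np


def text_only_rope_index(attention_mask):
    """Hypothetical text-only case: same position on all three axes."""
    # Count attended tokens to get 1D positions; padded slots are set to 1.
    positions = np.cumsum(attention_mask, axis=-1) - 1
    positions = np.where(attention_mask == 0, 1, positions)
    # Replicate across the temporal, height and width axes.
    position_ids = np.broadcast_to(positions, (3, *positions.shape))
    # Offset between the largest position used and the padded sequence length.
    deltas = positions.max(axis=-1) + 1 - attention_mask.shape[-1]
    return position_ids, deltas


def vision_block_positions(t, h, w, spatial_merge_size=1, start=0):
    """Hypothetical helper: (3, num_tokens) positions for one image or video."""
    # Grid size after spatial merging, as seen by the language model.
    llm_h, llm_w = h // spatial_merge_size, w // spatial_merge_size
    t_idx = np.repeat(np.arange(t), llm_h * llm_w)
    h_idx = np.tile(np.repeat(np.arange(llm_h), llm_w), t)
    w_idx = np.tile(np.arange(llm_w), t * llm_h)
    # Offset by `start`, the first free position after the preceding text.
    return np.stack([t_idx, h_idx, w_idx]) + start


text_ids, text_deltas = text_only_rope_index(np.array([[1, 1, 1, 0]]))
image_ids = vision_block_positions(t=1, h=4, w=4, spatial_merge_size=2, start=3)
```

In the image example, a 4x4 patch grid with spatial_merge_size 2 becomes a 2x2 grid of LLM tokens, so the height and width rows differ while the temporal row stays constant.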

easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.jax_scatter(sec_embeds, ids, fir_embeds, TKN_ID)[source]#
easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.precompute_vl_rotary(dim, theta, max_position)[source]#
easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.rotate_half(x)[source]#

Rotates half the hidden dims of the input.
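A minimal NumPy sketch of this operation, assuming the conventional formulation used with rotary position embeddings (split the last axis in half, negate the second half, and swap the halves). The name rotate_half_sketch is hypothetical and the library version operates on JAX arrays.

```python
import numpy as np


def rotate_half_sketch(x):
    """Hypothetical NumPy version of the half rotation."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    # (x1, x2) -> (-x2, x1) along the last axis.
    return np.concatenate([-x2, x1], axis=-1)


x = np.array([1.0, 2.0, 3.0, 4.0])
rotated = rotate_half_sketch(x)
```

Applying the rotation twice negates the input, which is a quick sanity check for an implementation.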