easydel.modules.qwen2_vl.modeling_qwen2_vl_flax#
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.PatchEmbed(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.PatchMerger(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLAttention(*args: Any, **kwargs: Any)[source]#
Bases:
AttentionModule
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLCausalLMOutputWithPast(loss: Optional[Union[Array, ndarray, bool, number]] = None, logits: Union[Array, ndarray, bool, number] = None, past_key_values: Optional[List[Union[Array, ndarray, bool, number]]] = None, hidden_states: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, attentions: Optional[Tuple[Union[Array, ndarray, bool, number]]] = None, rope_deltas: Optional[Union[Array, ndarray, bool, number]] = None)[source]#
Bases:
ModelOutput
Base class for Qwen2VL causal language model (or autoregressive) outputs.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
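The serialization helpers above follow a familiar immutable-record pattern. The sketch below illustrates that pattern with a hypothetical stand-in class, `ToyOutput`, built on the standard library; it is not the EasyDeL implementation, and the real output class holds arrays rather than plain lists.

```python
import dataclasses
import json
from typing import Any, Dict, Optional


@dataclasses.dataclass(frozen=True)
class ToyOutput:
    """Hypothetical stand-in for an output record; not the EasyDeL class."""

    logits: Optional[list] = None
    rope_deltas: Optional[list] = None

    def to_dict(self) -> Dict[str, Any]:
        # Serialize every field into a plain dictionary.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToyOutput":
        # Rebuild an instance from a dictionary of field values.
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ToyOutput":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "ToyOutput":
        # Return a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)
```

Because `replace` returns a new object, an output can be updated between generation steps without mutating the value the caller still holds.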
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLForConditionalGeneration(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- config: tp.Union[EasyDeLBaseConfig, _CP]#
- dtype: jnp.dtype#
- get_static_arguments()[source]#
Returns a tuple of static arguments required by the module’s __call__ method.
Static arguments are those that don’t change across calls and can be potentially cached or handled differently by JIT compilation. This base implementation returns an empty tuple. Subclasses should override this if they have static arguments.
- Returns
A tuple containing static arguments.
- Return type
tp.Tuple
- loss_type = 'ForCausalLM'#
- param_dtype: jnp.dtype#
- precision: lax.PrecisionLike#
- prepare_inputs_for_call(image_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, video_grid_thw: Optional[Union[Array, ndarray, bool, number]] = None, image_max_grid_size: int = None, video_max_grid_size: int = None, drop_ids: bool = True, **others)[source]#
Prepares keyword arguments before passing them to the module’s __call__ method.
This base implementation simply returns the kwargs as is. Subclasses can override this to modify or add arguments as needed (e.g., for generation).
- Parameters
**kwargs – The keyword arguments intended for __call__.
- Returns
The prepared keyword arguments.
- Return type
dict
- prepare_inputs_for_generation(input_ids, max_length: int, pad_token_id: int, starts: int | None = None, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, pixel_values=None, pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, **kwargs)[source]#
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
A dictionary containing the prepared inputs, typically including:
- "past_key_values": The initialized KV cache.
- "attention_mask": The extended attention mask for generation.
- "position_ids": The calculated initial position IDs.
- "token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
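The position and mask preparation described for prepare_inputs_for_generation can be sketched in NumPy as follows. The helper name `initial_generation_inputs` is hypothetical, KV-cache creation and sharding are omitted, and filling the future mask positions with ones is an assumption borrowed from the common Flax generation pattern rather than something this page confirms.

```python
import numpy as np


def initial_generation_inputs(input_ids, max_length, attention_mask=None):
    """Hypothetical sketch: initial position_ids and an extended attention mask."""
    batch_size, seq_length = input_ids.shape
    if attention_mask is None:
        # No mask given: assume a contiguous range 0..seq_length-1 per row.
        prompt_mask = np.ones((batch_size, seq_length), dtype=np.int32)
        position_ids = np.broadcast_to(
            np.arange(seq_length), (batch_size, seq_length)
        ).copy()
    else:
        prompt_mask = attention_mask.astype(np.int32)
        # Count attended tokens so far; padding positions are clipped to 0.
        position_ids = np.clip(np.cumsum(prompt_mask, axis=-1) - 1, 0, None)
    # Assumption: the mask spans the full cache length, with the prompt's
    # mask written into the first seq_length columns.
    extended_mask = np.ones((batch_size, max_length), dtype=np.int32)
    extended_mask[:, :seq_length] = prompt_mask
    return position_ids, extended_mask
```

With left-padded prompts, the first real token of each row receives position 0 regardless of how many pad tokens precede it.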
- rngs: nn.Rngs#
- update_inputs_for_generation(model_outputs, model_kwargs)[source]#
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLMLP(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VLVisionBlock(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.Qwen2VisionTransformerPretrainedModel(*args: Any, **kwargs: Any)[source]#
Bases:
EasyDeLBaseModule
- config_class#
alias of
Qwen2VLVisionConfig
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.VisionAttention(*args: Any, **kwargs: Any)[source]#
Bases:
AttentionModule
- class easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.VisionMlp(*args: Any, **kwargs: Any)[source]#
Bases:
Module
- easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.apply_rotary_pos_emb_vision(array: Union[Array, ndarray, bool, number], freqs: Union[Array, ndarray, bool, number]) Union[Array, ndarray, bool, number][source]#
- easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.create_attention_mask(cu_seqlens, seq_length, dtype)[source]#
Creates an attention mask matrix.
- Parameters
cu_seqlens – Cumulative sequence lengths.
seq_length – Length of each sequence.
dtype – Data type of the mask.
- Returns
Attention mask matrix.
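One common way to build a mask from cumulative sequence lengths is a block-diagonal additive mask, where tokens attend only within their own packed segment. The sketch below assumes that behavior and assumes `seq_length` is the total packed length; the helper name `block_diagonal_mask` is hypothetical, and the real function may differ in shape or in the fill value it uses.

```python
import numpy as np


def block_diagonal_mask(cu_seqlens, seq_length, dtype=np.float32):
    """Hypothetical sketch: additive mask, 0 inside a segment, large negative outside."""
    # Start fully masked with the most negative value representable in dtype.
    mask = np.full((1, seq_length, seq_length), np.finfo(dtype).min, dtype=dtype)
    # Consecutive boundary pairs delimit one segment each.
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = 0
    return mask
```

For `cu_seqlens = [0, 2, 5]` and `seq_length = 5`, tokens 0 to 1 form one block and tokens 2 to 4 another; entries that cross blocks keep the large negative value, so they vanish after softmax.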
- easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.get_rope_index(input_ids: ndarray, image_grid_thw: Optional[ndarray] = None, video_grid_thw: Optional[ndarray] = None, attention_mask: Optional[ndarray] = None, spatial_merge_size: int = 1, image_token_id: int = -1, video_token_id: int = -1, vision_start_token_id: int = -1) Tuple[ndarray, ndarray][source]#
Calculates the 3D RoPE position index from the temporal, height, and width dimensions of the images and videos in the LLM input.
- Parameters
input_ids (np.ndarray of shape (batch_size, sequence_length)) – Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
image_grid_thw (np.ndarray of shape (num_images, 3), optional) – The temporal, height, and width dimensions of each image's feature grid in the LLM.
video_grid_thw (np.ndarray of shape (num_videos, 3), optional) – The temporal, height, and width dimensions of each video's feature grid in the LLM.
attention_mask (np.ndarray of shape (batch_size, sequence_length), optional) – Mask to avoid performing attention on padding token indices. Mask values are selected in [0, 1]: 1 for tokens that are not masked, 0 for tokens that are masked.
spatial_merge_size (int) – The spatial merge size for vision embeddings.
image_token_id (int) – The token ID representing an image.
video_token_id (int) – The token ID representing a video.
vision_start_token_id (int) – The token ID representing the start of a vision sequence.
- Returns
position_ids (np.ndarray of shape (3, batch_size, sequence_length))
mrope_position_deltas (np.ndarray of shape (batch_size))
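To make the 3D index concrete, the sketch below covers two simplified pieces: the text-only case, where all three axes share ordinary 1D positions, and the grid of (temporal, height, width) positions for a single vision block. Both helpers, `text_only_rope_index` and `vision_block_positions`, are hypothetical names; the logic follows the commonly used Qwen2-VL scheme and is an assumption about this function, not a copy of it.

```python
import numpy as np


def text_only_rope_index(attention_mask):
    """Hypothetical sketch: positions and deltas when no image or video is present."""
    positions = np.cumsum(attention_mask, axis=-1) - 1
    # Padding positions get a dummy value of 1; they are masked out anyway.
    positions = np.where(attention_mask == 0, 1, positions)
    # The same 1D positions are repeated on the temporal, height, and width axes.
    position_ids = np.broadcast_to(positions[None], (3,) + positions.shape).copy()
    max_position = position_ids.max(axis=0).max(axis=-1)
    deltas = max_position + 1 - attention_mask.shape[-1]
    return position_ids, deltas


def vision_block_positions(t, h, w, spatial_merge_size=1, offset=0):
    """Hypothetical sketch: (3, num_tokens) positions for one image or video grid."""
    llm_h, llm_w = h // spatial_merge_size, w // spatial_merge_size
    t_idx = np.repeat(np.arange(t), llm_h * llm_w)
    h_idx = np.tile(np.repeat(np.arange(llm_h), llm_w), t)
    w_idx = np.tile(np.arange(llm_w), t * llm_h)
    # offset shifts the block to start after the preceding text tokens.
    return np.stack([t_idx, h_idx, w_idx]) + offset
```

Under these assumptions, a 1 x 4 x 4 image grid with `spatial_merge_size=2` yields four LLM tokens whose height and width indices walk a 2 x 2 grid, which is why the position range consumed by an image is shorter than its token count and a per-sequence delta is returned.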
- easydel.modules.qwen2_vl.modeling_qwen2_vl_flax.jax_scatter(sec_embeds, ids, fir_embeds, TKN_ID)[source]#