easydel.modules.whisper.modeling_whisper_flax#
- class easydel.modules.whisper.modeling_whisper_flax.WhisperAttention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
Whisper Attention mechanism.
This module implements the standard multi-head attention mechanism used in both the encoder and decoder of the Whisper model.
- config#
Configuration object for the model.
- Type
- embed_dim#
Dimensionality of the embedding layer.
- Type
int
- num_heads#
Number of attention heads.
- Type
int
- dropout#
Dropout probability.
- Type
float
- causal#
Whether this attention is causal (used in decoder self-attention).
- Type
bool
- bias#
Whether to include bias in linear projections.
- Type
bool
- head_dim#
Dimensionality of each attention head.
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- out_proj#
Linear layer for output projection.
- Type
- attention_performer#
Module for performing attention computation.
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
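The relationship between embed_dim, num_heads, head_dim, and causal can be illustrated with a minimal JAX sketch of multi-head scaled dot-product attention. It assumes the inputs have already been passed through q_proj, k_proj, and v_proj; the function name multi_head_attention is hypothetical and this is not the computation performed by attention_performer, only an illustration of the standard mechanism the class description refers to.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(q, k, v, num_heads, causal=False):
    """Scaled dot-product attention over already-projected q, k, v.

    q: (batch, q_len, embed_dim); k and v: (batch, kv_len, embed_dim).
    Hypothetical helper for illustration only.
    """
    batch, q_len, embed_dim = q.shape
    kv_len = k.shape[1]
    head_dim = embed_dim // num_heads  # mirrors the head_dim attribute

    # Split the embedding axis into (num_heads, head_dim).
    q = q.reshape(batch, q_len, num_heads, head_dim)
    k = k.reshape(batch, kv_len, num_heads, head_dim)
    v = v.reshape(batch, kv_len, num_heads, head_dim)

    # Attention scores per head, scaled by sqrt(head_dim).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    if causal:
        # Each query position may only attend to itself and earlier keys.
        mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
    # Merge the heads back into a single embedding axis.
    return out.reshape(batch, q_len, embed_dim), weights
```

With causal=True the returned weights are zero above the diagonal, which is the behaviour needed for decoder self-attention; with causal=False every query attends to every key, as in the encoder and in cross-attention.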
- class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoder(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The Whisper Decoder transformer stack.
This module processes the target token IDs, incorporates positional embeddings, and attends to both the input sequence (self-attention) and the encoder outputs (cross-attention) through a stack of WhisperDecoderLayer modules.
- config#
Configuration object for the model.
- Type
- embed_tokens#
Embedding layer for target tokens.
- Type
nn.Embed
- embed_positions#
Positional embedding layer.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
nn.List[WhisperDecoderLayer]
- layer_norm#
Final layer normalization (applied to pre-final outputs).
- Type
nn.LayerNorm
- dropout#
Dropout layer.
- Type
nn.Dropout
- padding_idx#
Index of the padding token.
- Type
int
- max_target_positions#
Maximum sequence length for the decoder.
- Type
int
- embed_scale#
Scaling factor for embeddings.
- Type
float | None
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
A single layer for the Whisper decoder.
This layer consists of self-attention, cross-attention (attending to encoder outputs), and a feed-forward network (FFN), each followed by residual connections and layer normalization.
- config#
Configuration object for the model.
- Type
- embed_dim#
Dimensionality of the input and output features.
- Type
int
- self_attn#
Self-attention module (causal).
- Type
- encoder_attn#
Cross-attention module (attends to encoder outputs).
- Type
- self_attn_layer_norm#
Layer normalization before self-attention.
- Type
nn.LayerNorm
- encoder_attn_layer_norm#
Layer normalization before cross-attention.
- Type
nn.LayerNorm
- fc1#
First linear layer of the FFN.
- Type
- fc2#
Second linear layer of the FFN.
- Type
- final_layer_norm#
Layer normalization after the FFN.
- Type
nn.LayerNorm
- activation_fn#
Activation function for the FFN.
- Type
callable
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoder(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The Whisper Encoder transformer stack.
This module processes the input audio features (log-Mel spectrogram) through convolutional layers followed by a stack of WhisperEncoderLayer modules.
- config#
Configuration object for the model.
- Type
- conv1#
First convolutional layer.
- Type
nn.Conv
- conv2#
Second convolutional layer.
- Type
nn.Conv
- embed_positions#
Positional embedding layer.
- Type
nn.Embed
- layers#
List of encoder layers.
- Type
nn.List[WhisperEncoderLayer]
- layer_norm#
Final layer normalization.
- Type
nn.LayerNorm
- embed_dim#
Dimensionality of the model.
- Type
int
- num_mel_bins#
Number of Mel frequency bins in the input features.
- Type
int
- padding_idx#
Index of the padding token.
- Type
int
- max_source_positions#
Maximum sequence length for the encoder.
- Type
int
- scale_embedding#
Scaling factor for embeddings.
- Type
float | None
- embed_scale#
Alias for scale_embedding.
- Type
float | None
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
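The convolutional front end formed by conv1 and conv2 can be sketched in plain JAX. The kernel size (3), padding (1), strides (1 then 2), GELU activations, and the (batch, frames, num_mel_bins) input layout are assumptions taken from the reference Whisper architecture, not from this page; conv1d and conv_frontend are hypothetical names.

```python
import jax
import jax.numpy as jnp


def conv1d(x, kernel, stride):
    """1D convolution over (batch, length, channels) with padding 1 per side."""
    return jax.lax.conv_general_dilated(
        x,
        kernel,  # (kernel_width, in_channels, out_channels)
        window_strides=(stride,),
        padding=[(1, 1)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def conv_frontend(features, kernel1, kernel2):
    """Two convolutions with GELU; the second (stride 2) halves the frame count."""
    hidden = jax.nn.gelu(conv1d(features, kernel1, stride=1))
    hidden = jax.nn.gelu(conv1d(hidden, kernel2, stride=2))
    return hidden
```

Under these assumptions, an input with 100 frames and num_mel_bins channels yields 50 frames of embed_dim channels, which is the sequence the positional embeddings and encoder layers then operate on.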
- class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
A single layer for the Whisper encoder.
This layer consists of a self-attention mechanism followed by a feed-forward network (FFN), with residual connections and layer normalization.
- config#
Configuration object for the model.
- Type
- embed_dim#
Dimensionality of the input and output features.
- Type
int
- self_attn#
Self-attention module.
- Type
- self_attn_layer_norm#
Layer normalization before self-attention.
- Type
nn.LayerNorm
- fc1#
First linear layer of the FFN.
- Type
- fc2#
Second linear layer of the FFN.
- Type
- final_layer_norm#
Layer normalization after the FFN.
- Type
nn.LayerNorm
- activation_fn#
Activation function for the FFN.
- Type
callable
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
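The data flow through one encoder layer can be sketched as two residual sub-blocks. This sketch assumes a pre-norm ordering (normalize, transform, then add the residual) as in the reference Whisper architecture, a parameter-free layer norm, and a GELU activation; the exact ordering in this module should be checked against the source. The names layer_norm and encoder_layer, and the weight arguments w1 and w2 standing in for fc1 and fc2, are hypothetical.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Parameter-free layer normalization over the last axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def encoder_layer(hidden, self_attn, w1, w2):
    """One encoder layer: self-attention block, then feed-forward block."""
    # Self-attention sub-block with a residual connection.
    residual = hidden
    hidden = self_attn(layer_norm(hidden))
    hidden = residual + hidden

    # Feed-forward sub-block (fc1 -> activation -> fc2) with a residual connection.
    residual = hidden
    hidden = jax.nn.gelu(layer_norm(hidden) @ w1) @ w2
    return residual + hidden
```

A decoder layer follows the same pattern with one extra residual sub-block for cross-attention between the self-attention and feed-forward blocks.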
- class easydel.modules.whisper.modeling_whisper_flax.WhisperForAudioClassification(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
- class easydel.modules.whisper.modeling_whisper_flax.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) -> Tuple[Any, LossMetrics]#
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree typically including logits, hidden states, etc.)
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
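As a rough illustration of the kind of token-level loss a 'ForCausalLM' loss type typically computes, the sketch below averages cross-entropy over positions that carry a real label. The ignore_index value of -100 and the function name masked_cross_entropy are assumptions for illustration; the actual loss function, its configuration through LossConfig, and the contents of LossMetrics are defined by EasyDeL and are not reproduced here.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len).
    Hypothetical helper for illustration only.
    """
    valid = labels != ignore_index
    # Replace ignored labels with a valid index so the gather below is safe.
    safe_labels = jnp.where(valid, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Average the negative log-likelihood over valid positions only.
    return -(token_ll * valid).sum() / jnp.maximum(valid.sum(), 1)
```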
- config: tp.Union[EasyDeLBaseConfig, _CP]#
- decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None)#
- dtype: jnp.dtype#
- encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs)#
- generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)#
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_features (chex.Array) – The input audio features (log-Mel spectrogram) on which generation is conditioned.
generation_config (GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them. If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Please note that unspecified parameters will inherit GenerationConfig’s default values, whose documentation should be checked to parameterize generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.
logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logits processor is passed that was already created from the arguments or the generation config, an error is thrown. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.
- Returns
ModelOutput.
- loss_type = 'ForCausalLM'#
- param_dtype: jnp.dtype#
- precision: lax.PrecisionLike#
- prepare_inputs_for_generation(decoder_input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)#
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
decoder_input_ids (chex.Array) – The initial sequence of decoder token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
- A dictionary containing the prepared inputs, typically including:
"past_key_values": The initialized KV cache.
"attention_mask": The extended attention mask for generation.
"position_ids": The calculated initial position IDs.
"token_type_ids": (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
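The position ID and extended mask preparation described above can be sketched as follows. This is a minimal illustration under stated assumptions: the mask is extended with ones (future slots are left attendable and causal masking hides them), padded positions are clamped to position 0, and the KV cache and sharding steps are omitted. The helper name prepare_mask_and_positions is hypothetical.

```python
import jax
import jax.numpy as jnp


def prepare_mask_and_positions(attention_mask, max_length):
    """Build an extended mask of length max_length and initial position IDs.

    attention_mask: integer array of shape (batch_size, seq_length).
    Hypothetical helper for illustration only.
    """
    batch_size, seq_length = attention_mask.shape

    # Position IDs count attended tokens; padded slots are clamped to 0.
    position_ids = jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)

    # Start from an all-ones mask covering the full cache length, then copy
    # the provided mask into the first seq_length columns.
    extended = jnp.ones((batch_size, max_length), dtype=attention_mask.dtype)
    extended = jax.lax.dynamic_update_slice(extended, attention_mask, (0, 0))
    return extended, position_ids
```

For a mask of [[0, 1, 1]] and max_length=5, this yields an extended mask of [[0, 1, 1, 1, 1]] and position IDs of [[0, 0, 1]].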
- rngs: nn.Rngs#
- update_inputs_for_generation(model_outputs, model_kwargs)#
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
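The per-step update described above amounts to carrying the cache forward and advancing the position by one. A minimal sketch, assuming the position IDs live under a "position_ids" key as in the description (the key actually used by this model may differ) and returning a new dictionary rather than mutating the input; update_inputs is a hypothetical name.

```python
from types import SimpleNamespace

import jax.numpy as jnp


def update_inputs(model_outputs, model_kwargs):
    """Return model_kwargs updated for the next decoding step."""
    updated = dict(model_kwargs)  # shallow copy; the caller's dict is untouched
    # Carry the cache produced by the current step into the next call.
    updated["past_key_values"] = model_outputs.past_key_values
    # Keep only the last position and advance it by one.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


# Usage with stand-in objects (the cache is an opaque placeholder here).
outputs = SimpleNamespace(past_key_values="cache-after-step")
kwargs = {"past_key_values": None, "position_ids": jnp.array([[0, 1, 2]])}
next_kwargs = update_inputs(outputs, kwargs)
```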
- class easydel.modules.whisper.modeling_whisper_flax.WhisperModel(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The base Whisper Model transformer implementing the encoder-decoder architecture.
- config#
Configuration object for the model.
- Type
- encoder#
The encoder stack.
- Type
- decoder#
The decoder stack.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- decode(encoder_hidden_states: Array, decoder_input_ids: Array, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: bool = False, output_hidden_states: bool = False)#
Performs decoding using the decoder module.
- Parameters
encoder_hidden_states (jnp.ndarray) – Hidden states from the encoder.
decoder_input_ids (jnp.ndarray) – Decoder input token IDs.
decoder_attention_mask (tp.Optional[jnp.ndarray]) – Mask for decoder self-attention.
decoder_position_ids (tp.Optional[jnp.ndarray]) – Position IDs for decoder inputs.
past_key_values (tp.Optional[TransformerCache]) – Cached key/value states.
cache_metadata (tp.Optional[TransformerMetadata]) – Metadata for paged attention.
output_attentions (bool) – Whether to return attention weights.
output_hidden_states (bool) – Whether to return hidden states for all layers.
- Returns
Decoder output.
- Return type
- encode(input_features: Array, output_attentions: bool = False, output_hidden_states: bool = False)#
Performs encoding using the encoder module.
- Parameters
input_features (jnp.ndarray) – Input audio features.
output_attentions (bool) – Whether to return attention weights.
output_hidden_states (bool) – Whether to return hidden states for all layers.
- Returns
Encoder output.
- Return type
BaseModelOutput | tuple
- easydel.modules.whisper.modeling_whisper_flax.shift_tokens_right(input_ids: Array, pad_token_id: int, decoder_start_token_id: int)#
Shift input ids one token to the right using JAX.
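A minimal sketch of what shifting right means here: the sequence moves one slot to the right, decoder_start_token_id fills the first slot, and the last token is dropped. The replacement of -100 label placeholders with pad_token_id follows a common convention and is an assumption, not something this page states; shift_right_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def shift_right_sketch(input_ids, pad_token_id, decoder_start_token_id):
    """Shift token IDs one position to the right along the sequence axis."""
    shifted = jnp.zeros_like(input_ids)
    # Move every token one slot to the right, dropping the last one.
    shifted = shifted.at[:, 1:].set(input_ids[:, :-1])
    # The first slot holds the decoder start token.
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    # Assumed convention: -100 marks ignored labels and becomes padding.
    return jnp.where(shifted == -100, pad_token_id, shifted)
```

For example, shifting [[5, -100, 7]] with pad_token_id=0 and decoder_start_token_id=2 gives [[2, 5, 0]].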
- easydel.modules.whisper.modeling_whisper_flax.sinusoidal_embedding_init(key, shape, dtype=<class 'jax.numpy.float64'>) -> Array#
Initializes sinusoidal positional embeddings.
- Parameters
key – JAX PRNG key (unused, but part of standard initializer signature).
shape (tuple) – Shape of the embedding matrix (length, channels).
dtype – Data type of the embeddings (default: jnp.float_).
- Returns
Sinusoidal positional embedding matrix.
- Return type
- Raises
ValueError – If the number of channels is not even.
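A sketch of a sinusoidal initializer with this signature is shown below. The specific formula (timescales spaced geometrically up to 10000, sine values in the first half of the channels and cosine values in the second half) is an assumption based on the reference Whisper positional embeddings; the name sinusoidal_init_sketch is hypothetical and the default dtype is float32 here to avoid requiring 64-bit mode.

```python
import math

import jax.numpy as jnp


def sinusoidal_init_sketch(key, shape, dtype=jnp.float32):
    """Return a (length, channels) matrix of sinusoidal position embeddings."""
    del key  # unused, kept only to match the initializer signature
    length, channels = shape
    if channels % 2 != 0:
        raise ValueError(f"Number of channels must be even, got {channels}.")

    half = channels // 2
    # Geometric progression of inverse timescales from 1 down to 1/10000.
    log_timescale_increment = math.log(10000) / (half - 1)
    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(half))
    scaled_time = jnp.arange(length).reshape(-1, 1) * inv_timescales.reshape(1, -1)
    # Sine in the first half of the channels, cosine in the second half.
    embedding = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    return embedding.astype(dtype)
```

Position 0 maps to zeros in the sine half and ones in the cosine half, and an odd channel count raises ValueError as documented above. Note that this sketch needs at least four channels, since two channels would make the timescale increment divide by zero.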