easydel.modules.whisper.modeling_whisper_flax

class easydel.modules.whisper.modeling_whisper_flax.WhisperAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Whisper Attention mechanism.

This module implements the standard multi-head attention mechanism used in both the encoder and decoder of the Whisper model.
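The core computation can be sketched in plain JAX as below. This is a minimal illustration, not EasyDeL's implementation: the function name `multi_head_attention` is hypothetical, the `q_proj`/`k_proj`/`v_proj`/`out_proj` projections, dropout, KV caching, and the `FlexibleAttentionModule` backends are omitted, and the causal mask assumes queries and keys cover the same positions (no cache offset).

```python
import jax
import jax.numpy as jnp


def multi_head_attention(q, k, v, num_heads, causal=False):
    """Scaled dot-product attention split across heads.

    q: (batch, q_len, embed_dim); k and v: (batch, kv_len, embed_dim).
    Returns the merged output and the per-head attention weights.
    """
    batch, q_len, embed_dim = q.shape
    kv_len = k.shape[1]
    head_dim = embed_dim // num_heads

    # Split the embedding dimension into (num_heads, head_dim).
    qh = q.reshape(batch, q_len, num_heads, head_dim)
    kh = k.reshape(batch, kv_len, num_heads, head_dim)
    vh = v.reshape(batch, kv_len, num_heads, head_dim)

    # Scores have shape (batch, num_heads, q_len, kv_len), scaled by 1/sqrt(head_dim).
    scores = jnp.einsum("bqhd,bkhd->bhqk", qh, kh) * (head_dim ** -0.5)
    if causal:
        # Query i may attend only to key positions <= i (assumes q_len == kv_len).
        mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, vh)
    # Merge the heads back into a single embedding dimension.
    return out.reshape(batch, q_len, embed_dim), weights
```

In the decoder, the same routine covers cross-attention by passing decoder states as `q`, encoder outputs as `k` and `v`, and `causal=False`; the key length may then differ from the query length.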

config

Configuration object for the model.

Type

WhisperConfig

embed_dim

Dimensionality of the embedding layer.

Type

int

num_heads

Number of attention heads.

Type

int

dropout

Dropout probability.

Type

float

causal

Whether this attention is causal (used in decoder self-attention).

Type

bool

bias

Whether to include bias in linear projections.

Type

bool

head_dim

Dimensionality of each attention head.

Type

int

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

out_proj

Linear layer for output projection.

Type

ParallelLinear

attention_performer

Module for performing attention computation.

Type

FlexibleAttentionModule

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoder(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The Whisper Decoder transformer stack.

This module embeds the target token IDs, adds positional embeddings, and passes the result through a stack of WhisperDecoderLayer modules, each of which attends to the preceding target tokens (causal self-attention) and to the encoder outputs (cross-attention).
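The embedding step at the start of the decoder can be sketched as a token-table lookup plus a learned positional embedding. This is a sketch under stated assumptions: `embed_decoder_inputs` and its arguments are hypothetical names, the two tables stand in for the `embed_tokens` and `embed_positions` weights, and `embed_scale` and dropout are omitted.

```python
import jax.numpy as jnp


def embed_decoder_inputs(token_table, position_table, decoder_input_ids, position_ids=None):
    """Token embedding lookup plus learned positional embedding.

    token_table: (vocab_size, embed_dim); position_table: (max_target_positions, embed_dim);
    decoder_input_ids: (batch, seq_len) integer IDs.
    """
    batch, seq_len = decoder_input_ids.shape
    if position_ids is None:
        # Assume contiguous positions 0..seq_len-1 for every sequence in the batch.
        position_ids = jnp.broadcast_to(jnp.arange(seq_len), (batch, seq_len))
    # Gather one row per token from each table and add them element-wise.
    return token_table[decoder_input_ids] + position_table[position_ids]
```

Passing explicit `position_ids` matters during cached generation, where only the newest token is embedded but its position is not 0.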

config

Configuration object for the model.

Type

WhisperConfig

embed_tokens

Embedding layer for target tokens.

Type

nn.Embed

embed_positions

Positional embedding layer.

Type

nn.Embed

layers

List of decoder layers.

Type

nn.List[WhisperDecoderLayer]

layer_norm

Final layer normalization, applied to the output of the last decoder layer.

Type

nn.LayerNorm

dropout

Dropout layer.

Type

nn.Dropout

padding_idx

Index of the padding token.

Type

int

max_target_positions

Maximum sequence length for the decoder.

Type

int

embed_scale

Scaling factor for embeddings.

Type

float | None

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

A single layer for the Whisper decoder.

This layer consists of self-attention, cross-attention (attending to encoder outputs), and a feed-forward network (FFN). Each sublayer follows the pre-norm pattern: layer normalization is applied to its input, and its output is added back through a residual connection.
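The sublayer ordering can be sketched as follows. The helper names are hypothetical, the three sublayers are passed in as plain callables, and this `layer_norm` has no learned scale or bias, whereas the real `self_attn_layer_norm`, `encoder_attn_layer_norm`, and `final_layer_norm` each carry their own parameters.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; the real modules also apply a learned scale and bias.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_layer(x, encoder_states, self_attn, cross_attn, ffn):
    """Pre-norm decoder layer: every sublayer reads a normalized input and adds to the residual."""
    # 1) Causal self-attention over the target tokens.
    x = x + self_attn(layer_norm(x))
    # 2) Cross-attention: queries come from the decoder, keys/values from the encoder output.
    x = x + cross_attn(layer_norm(x), encoder_states)
    # 3) Position-wise feed-forward network (fc1 -> activation -> fc2).
    x = x + ffn(layer_norm(x))
    return x
```

Because the residual path is never normalized, the final decoder output still needs the stack-level `layer_norm` listed on WhisperDecoder.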

config

Configuration object for the model.

Type

WhisperConfig

embed_dim

Dimensionality of the input and output features.

Type

int

self_attn

Self-attention module (causal).

Type

WhisperAttention

encoder_attn

Cross-attention module (attends to encoder outputs).

Type

WhisperAttention

self_attn_layer_norm

Layer normalization before self-attention.

Type

nn.LayerNorm

encoder_attn_layer_norm

Layer normalization before cross-attention.

Type

nn.LayerNorm

fc1

First linear layer of the FFN.

Type

ParallelLinear

fc2

Second linear layer of the FFN.

Type

ParallelLinear

final_layer_norm

Layer normalization applied to the FFN input (pre-norm).

Type

nn.LayerNorm

activation_fn

Activation function for the FFN.

Type

callable

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoder(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The Whisper Encoder transformer stack.

This module processes the input audio features (log-Mel spectrogram) with two 1-D convolutional layers, adds positional embeddings, and passes the result through a stack of WhisperEncoderLayer modules.
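A sketch of the convolutional front end, assuming the layout used by the original Whisper and Hugging Face implementations: input features of shape `(batch, num_mel_bins, frames)`, two kernel-size-3 convolutions with exact GELU, and a stride of 2 in the second convolution, which halves the time axis (3000 frames become 1500 positions). Function names and weight shapes are hypothetical, and biases are omitted.

```python
import jax
import jax.numpy as jnp


def conv1d(x, w, stride):
    """Channels-last 1-D convolution. x: (batch, time, in_ch); w: (kernel, in_ch, out_ch)."""
    return jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(stride,),
        padding=[(1, 1)],  # one frame of padding on each side for kernel size 3
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def encoder_frontend(features, w1, w2, positions):
    """features: (batch, num_mel_bins, frames) -> (batch, ceil(frames / 2), embed_dim)."""
    x = features.transpose(0, 2, 1)  # to channels-last: (batch, frames, num_mel_bins)
    x = jax.nn.gelu(conv1d(x, w1, stride=1), approximate=False)  # (batch, frames, embed_dim)
    x = jax.nn.gelu(conv1d(x, w2, stride=2), approximate=False)  # time axis halved
    # Add one positional embedding row per output time step.
    return x + positions[: x.shape[1]]
```

With kernel 3, padding 1, and stride 2, the output length is floor((frames - 1) / 2) + 1, which equals ceil(frames / 2).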

config

Configuration object for the model.

Type

WhisperConfig

conv1

First convolutional layer.

Type

nn.Conv

conv2

Second convolutional layer.

Type

nn.Conv

embed_positions

Positional embedding layer.

Type

nn.Embed

layers

List of encoder layers.

Type

nn.List[WhisperEncoderLayer]

layer_norm

Final layer normalization.

Type

nn.LayerNorm

embed_dim

Dimensionality of the model.

Type

int

num_mel_bins

Number of Mel frequency bins in the input features.

Type

int

padding_idx

Index of the padding token.

Type

int

max_source_positions

Maximum sequence length for the encoder.

Type

int

scale_embedding

Scaling factor for embeddings.

Type

float | None

embed_scale

Alias for scale_embedding.

Type

float | None

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoderLayer(*args: Any, **kwargs: Any)

Bases: Module

A single layer for the Whisper encoder.

This layer consists of a self-attention mechanism followed by a feed-forward network (FFN). Each sublayer applies layer normalization to its input and adds its output back through a residual connection (pre-norm).
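The pre-norm encoder layer can be sketched as follows. Names and weight shapes are hypothetical (`ffn_dim` stands in for the FFN width from the config), the attention sublayer is passed in as a callable, `layer_norm` has no learned parameters, and exact GELU is assumed as `activation_fn`, which is the Whisper default.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalize over the feature axis; the real modules also apply a learned scale and bias.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def encoder_layer(x, self_attn, fc1_w, fc1_b, fc2_w, fc2_b):
    """Pre-norm encoder layer. fc1_w: (embed_dim, ffn_dim); fc2_w: (ffn_dim, embed_dim)."""
    # Self-attention sublayer with a residual connection.
    x = x + self_attn(layer_norm(x))
    # FFN sublayer: expand to ffn_dim, apply GELU, project back to embed_dim.
    h = jax.nn.gelu(layer_norm(x) @ fc1_w + fc1_b, approximate=False)
    return x + (h @ fc2_w + fc2_b)
```

A layer whose sublayers output zeros reduces to the identity, which is one reason residual stacks like this train stably at depth.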

config

Configuration object for the model.

Type

WhisperConfig

embed_dim

Dimensionality of the input and output features.

Type

int

self_attn

Self-attention module.

Type

WhisperAttention

self_attn_layer_norm

Layer normalization before self-attention.

Type

nn.LayerNorm

fc1

First linear layer of the FFN.

Type

ParallelLinear

fc2

Second linear layer of the FFN.

Type

ParallelLinear

final_layer_norm

Layer normalization applied to the FFN input (pre-norm).

Type

nn.LayerNorm

activation_fn

Activation function for the FFN.

Type

callable

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.whisper.modeling_whisper_flax.WhisperForAudioClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.whisper.modeling_whisper_flax.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) → Tuple[Any, LossMetrics]

Computes the loss for the model given a batch of inputs and labels.

This method runs a forward pass with the provided batch arguments and then computes the loss with the selected loss_function. It can infer labels when none are provided (e.g., using input_ids as labels for causal LM) and applies default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree, typically including logits and hidden states).

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.
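For orientation, a token-level cross-entropy of the kind a causal-LM loss computes can be sketched as below. It assumes labels use `-100` to mark ignored positions (the Hugging Face convention) and that `labels[t]` is the target for `logits[t]`. It does not reproduce EasyDeL's `LossConfig` handling, label inference, or any label shifting the `'ForCausalLM'` loss may apply, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def masked_token_cross_entropy(logits, labels, ignore_index=-100):
    """Mean negative log-likelihood over positions whose label is not ignore_index.

    logits: (batch, seq_len, vocab_size); labels: (batch, seq_len) integer token IDs.
    """
    valid = labels != ignore_index
    # Replace ignored labels with 0 so the gather below stays in range.
    safe_labels = jnp.where(valid, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Sum over valid tokens only, then divide by their count (at least 1 to avoid 0/0).
    total = -(token_log_probs * valid).sum()
    return total / jnp.maximum(valid.sum(), 1)
```

With uniform logits over a vocabulary of size V, every valid token contributes log V, so the mean loss is log V regardless of how many positions are ignored.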

config: tp.Union[EasyDeLBaseConfig, _CP]
decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None)
dtype: jnp.dtype
encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs)
generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)

Generates sequences of token IDs (transcriptions or translations) conditioned on the input audio features.

Parameters
  • input_features (chex.Array) – Log-Mel spectrogram features of the audio to transcribe or translate.

  • generation_config (GenerationConfig, optional) – The generation configuration to use as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, the default is used, with the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Unspecified parameters inherit the default values of GenerationConfig, whose documentation should be consulted to parameterize generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (LogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from the arguments and the generation config. If you pass a logits processor that duplicates one already created from the arguments or the generation config, an error is raised. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the forward function of the model. For encoder-decoder models such as Whisper, encoder-specific kwargs should not be prefixed, and decoder-specific kwargs should be prefixed with decoder_.

Returns

ModelOutput.

loss_type = 'ForCausalLM'
param_dtype: jnp.dtype
precision: lax.PrecisionLike
prepare_inputs_for_generation(decoder_input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)

Sets up the initial inputs required for starting autoregressive generation.

This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.

Parameters
  • decoder_input_ids (chex.Array) – The initial sequence of decoder token IDs. Shape (batch_size, seq_length).

  • max_length (int) – The maximum sequence length that the KV cache should support.

  • pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.

  • starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.

  • shardings (dict | None) – Optional sharding configuration passed to init_cache.

  • attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).

  • token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.

Returns

A dictionary containing the prepared inputs, typically including:
  • "past_key_values": The initialized KV cache.

  • "attention_mask": The extended attention mask for generation.

  • "position_ids": The calculated initial position IDs.

  • "token_type_ids": (Optional) Prepared token type IDs.

This dictionary is then passed through prepare_inputs_for_call.

Return type

dict
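The position and mask preparation described above can be sketched as follows, mirroring a pattern common in Flax generation code: positions count the attended tokens, and the mask is extended with ones for the slots that generation will fill. The function name is hypothetical, and cache initialization, the starts computation, and sharding are omitted.

```python
import jax
import jax.numpy as jnp


def prepare_positions_and_mask(attention_mask, max_length):
    """attention_mask: (batch, seq_len) of 0/1; returns (position_ids, extended_mask)."""
    batch = attention_mask.shape[0]
    # Each token's position is the number of attended tokens up to and including it, minus 1;
    # left-padding slots would get -1, so clamp them to 0.
    position_ids = jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)
    # Start from an all-ones mask of the cache length and write the prompt mask at the front,
    # so every slot that generation fills later is attendable.
    extended_mask = jnp.ones((batch, max_length), dtype=attention_mask.dtype)
    extended_mask = jax.lax.dynamic_update_slice(extended_mask, attention_mask, (0, 0))
    return position_ids, extended_mask
```

Counting attended tokens rather than raw indices keeps the first real token of a left-padded row at position 0.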

rngs: nn.Rngs
update_inputs_for_generation(model_outputs, model_kwargs)

Updates the keyword arguments for the next generation step.

Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.

Parameters
  • model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).

  • model_kwargs (dict) – The dictionary of keyword arguments used for the model call. It may be modified in place, or a new dictionary may be returned.

Returns

The updated model_kwargs dictionary ready for the next generation step.

Return type

dict
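A minimal sketch of this update, with a hypothetical function name. It returns a new dictionary rather than mutating model_kwargs, and it assumes only the newest token is fed to the model at each step; the `SimpleNamespace` object is a stand-in for a real model output.

```python
from types import SimpleNamespace

import jax.numpy as jnp


def update_generation_kwargs(model_outputs, model_kwargs):
    """Return a new kwargs dict for the next decoding step."""
    updated = dict(model_kwargs)  # shallow copy; the caller's dict is left unchanged
    # Carry the cache produced by this step into the next call.
    updated["past_key_values"] = model_outputs.past_key_values
    # Only the newest token is fed next, so keep the last position and add one.
    updated["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
    return updated


# Example: a stand-in output object carrying a placeholder cache.
outputs = SimpleNamespace(past_key_values="cache-after-step")
next_kwargs = update_generation_kwargs(outputs, {"position_ids": jnp.array([[0, 1, 2]])})
```

Here `next_kwargs["position_ids"]` holds `[[3]]`: after a three-token prefill, the next token sits at position 3.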

class easydel.modules.whisper.modeling_whisper_flax.WhisperModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Whisper model: an encoder-decoder transformer without a task-specific head.

config

Configuration object for the model.

Type

WhisperConfig

encoder

The encoder stack.

Type

WhisperEncoder

decoder

The decoder stack.

Type

WhisperDecoder

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

decode(encoder_hidden_states: Array, decoder_input_ids: Array, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: bool = False, output_hidden_states: bool = False)

Performs decoding using the decoder module.

Parameters
  • encoder_hidden_states (jnp.ndarray) – Hidden states from the encoder.

  • decoder_input_ids (jnp.ndarray) – Decoder input token IDs.

  • decoder_attention_mask (tp.Optional[jnp.ndarray]) – Mask for decoder self-attention.

  • decoder_position_ids (tp.Optional[jnp.ndarray]) – Position IDs for decoder inputs.

  • past_key_values (tp.Optional[TransformerCache]) – Cached key/value states.

  • cache_metadata (tp.Optional[TransformerMetadata]) – Metadata for paged attention.

  • output_attentions (bool) – Whether to return attention weights.

  • output_hidden_states (bool) – Whether to return hidden states for all layers.

Returns

Decoder output.

Return type

BaseModelOutputWithPastAndCrossAttentions | tuple

encode(input_features: Array, output_attentions: bool = False, output_hidden_states: bool = False)

Performs encoding using the encoder module.

Parameters
  • input_features (jnp.ndarray) – Input audio features.

  • output_attentions (bool) – Whether to return attention weights.

  • output_hidden_states (bool) – Whether to return hidden states for all layers.

Returns

Encoder output.

Return type

BaseModelOutput | tuple

easydel.modules.whisper.modeling_whisper_flax.shift_tokens_right(input_ids: Array, pad_token_id: int, decoder_start_token_id: int)

Shift input ids one token to the right using JAX.
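A sketch of what this shift typically looks like, following the Hugging Face formulation: prepend decoder_start_token_id, drop the last token, and replace any `-100` label markers with pad_token_id. Whether EasyDeL's function performs the `-100` replacement is an assumption, and the name `shift_tokens_right_sketch` is hypothetical.

```python
import jax.numpy as jnp


def shift_tokens_right_sketch(input_ids, pad_token_id, decoder_start_token_id):
    """input_ids: (batch, seq_len) -> decoder inputs of the same shape."""
    batch = input_ids.shape[0]
    # Prepend the decoder start token and drop the last token of every row.
    start = jnp.full((batch, 1), decoder_start_token_id, dtype=input_ids.dtype)
    shifted = jnp.concatenate([start, input_ids[:, :-1]], axis=1)
    # Replace the -100 "ignore" marker (Hugging Face convention) with the pad token.
    return jnp.where(shifted == -100, pad_token_id, shifted)
```

Shifting this way makes the decoder input at step t the label from step t - 1, so labels can be used directly as training targets.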

easydel.modules.whisper.modeling_whisper_flax.sinusoidal_embedding_init(key, shape, dtype=<class 'jax.numpy.float64'>) → Array

Initializes sinusoidal positional embeddings.

Parameters
  • key – JAX PRNG key (unused; present only to match the standard initializer signature).

  • shape (tuple) – Shape of the embedding matrix (length, channels).

  • dtype – Data type of the embeddings (default: jnp.float_).

Returns

Sinusoidal positional embedding matrix.

Return type

jax.Array

Raises

ValueError – If the number of channels is not even.
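A sketch of the sinusoidal scheme used by the original Whisper code: inverse timescales form a geometric progression from 1 down to 1 / max_timescale, sines fill the first half of the channels, and cosines fill the second. The function name and the `max_timescale=10000` default come from that scheme and are assumptions about EasyDeL's version; the default dtype here is float32 rather than the float64 shown in the signature above.

```python
import math

import jax
import jax.numpy as jnp


def sinusoidal_init_sketch(key, shape, dtype=jnp.float32, max_timescale=10000.0):
    """Returns a (length, channels) table: sines in the first half, cosines in the second."""
    del key  # unused; kept to match the standard initializer signature
    length, channels = shape
    if channels % 2 != 0:
        raise ValueError(f"Number of channels must be even, got {channels}.")
    half = channels // 2
    # Inverse timescales decay geometrically from 1 to 1 / max_timescale.
    log_increment = math.log(max_timescale) / max(half - 1, 1)
    inv_timescales = jnp.exp(-log_increment * jnp.arange(half))
    scaled_time = jnp.arange(length)[:, None] * inv_timescales[None, :]
    table = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
    return table.astype(dtype)


# Example: a table sized for 1500 encoder positions and 384 channels.
table = sinusoidal_init_sketch(jax.random.PRNGKey(0), (1500, 384))
```

Because the table is fixed rather than learned, it can be generated for any length without training.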