easydel.modules.whisper.__init__
- class easydel.modules.whisper.__init__.WhisperConfig(vocab_size=51865, num_mel_bins=80, encoder_layers=4, encoder_attention_heads=6, decoder_layers=4, decoder_attention_heads=6, decoder_ffn_dim=1536, encoder_ffn_dim=1536, encoder_layerdrop=0.0, decoder_layerdrop=0.0, decoder_start_token_id=50257, use_cache=True, is_encoder_decoder=True, activation_function='gelu', d_model=384, dropout=0.0, attention_dropout=0.0, activation_dropout=0.0, init_std=0.02, scale_embedding=False, max_source_positions=1500, max_target_positions=448, pad_token_id=50256, bos_token_id=50256, eos_token_id=50256, suppress_tokens=None, begin_suppress_tokens=[220, 50256], use_weighted_layer_sum=False, classifier_proj_size=256, apply_spec_augment=False, mask_time_prob=0.05, mask_time_length=10, mask_time_min_masks=2, mask_feature_prob=0.0, mask_feature_length=10, mask_feature_min_masks=0, median_filter_width=7, bits: Optional[int] = None, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, **kwargs)
Bases:
EasyDeLBaseConfig
Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.
- Parameters
vocab_size (int, optional, defaults to 51865) – Vocabulary size of the Whisper model. Defines the number of different tokens that can be represented by the decoder_input_ids passed when calling [~easydel.modules.WhisperModel].
num_mel_bins (int, optional, defaults to 80) – Number of mel bins used by the feature extractor.
encoder_layers (int, optional, defaults to 4) – Number of encoder layers.
encoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer encoder.
decoder_layers (int, optional, defaults to 4) – Number of decoder layers.
decoder_attention_heads (int, optional, defaults to 6) – Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the decoder feed-forward network (FFN) layer.
encoder_ffn_dim (int, optional, defaults to 1536) – Dimensionality of the encoder feed-forward network (FFN) layer.
encoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
decoder_layerdrop (float, optional, defaults to 0.0) – The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
d_model (int, optional, defaults to 384) – Dimensionality of the layers and the pooler layer.
activation_function (str, optional, defaults to “gelu”) – The non-linear activation function (function or string) in the encoder and pooler. If string, “gelu”, “relu”, “silu” and “gelu_new” are supported.
dropout (float, optional, defaults to 0.0) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.
activation_dropout (float, optional, defaults to 0.0) – The dropout ratio for activations inside the fully connected layer.
init_std (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_embedding (bool, optional, defaults to False) – Scale embeddings by dividing by sqrt(d_model).
max_source_positions (int, optional, defaults to 1500) – The maximum sequence length allowed for the source input to the model. Any longer inputs will be truncated.
max_target_positions (int, optional, defaults to 448) – The maximum sequence length allowed for the target text input to the model. Any longer inputs will be truncated.
use_cache (bool, optional, defaults to True) – Whether or not the model should return the last key/values attentions (not used by all models).
apply_spec_augment (bool, optional, defaults to False) – Whether to apply SpecAugment data augmentation.
mask_time_prob (float, optional, defaults to 0.05) – Probability of each feature vector along the time axis being chosen as the start of the vector span to be masked. Approximately mask_time_prob * sequence_length // mask_time_length feature vectors will be masked along the time axis. This is only relevant if apply_spec_augment is set to True.
mask_time_length (int, optional, defaults to 10) – Length of vector span along the time axis.
mask_time_min_masks (int, optional, defaults to 2) – The minimum number of masks of length mask_time_length generated along the time axis. Each mask is filled with floats sampled in (random_lower_bound, random_upper_bound).
mask_feature_prob (float, optional, defaults to 0.0) – Probability of each feature vector along the feature axis being chosen as the start of the vector span to be masked. Approximately mask_feature_prob * hidden_size // mask_feature_length feature vectors will be masked along the feature axis. This is only relevant if apply_spec_augment is set to True.
mask_feature_length (int, optional, defaults to 10) – Length of vector span along the feature axis.
mask_feature_min_masks (int, optional, defaults to 0) – The minimum number of masks of length mask_feature_length generated along the feature axis. Each mask is filled with floats sampled in (random_lower_bound, random_upper_bound).
median_filter_width (int, optional, defaults to 7) – The width of the median filter applied to the mask.
bits (int, optional) – The number of bits to quantize the model to. If None, the model is not quantized.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – What to save during gradient checkpointing, such as “nothing_saveable”, “first_half_saveable”, or “full_saveable”.
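The approximate mask counts quoted for mask_time_prob and mask_feature_prob can be checked with plain arithmetic. The sketch below only evaluates the formula from the parameter descriptions with this config's defaults; the helper name `approx_num_masks` is hypothetical, using max_source_positions as the sequence length and flooring at the `*_min_masks` value are assumptions, and the real masking routine may round differently.

```python
def approx_num_masks(mask_prob: float, axis_length: int, mask_length: int, min_masks: int = 0) -> int:
    """Evaluate `mask_prob * axis_length // mask_length`, floored at `min_masks` (assumed)."""
    estimate = int(mask_prob * axis_length // mask_length)
    return max(estimate, min_masks)


# Time axis: mask_time_prob=0.05, sequence length assumed to be max_source_positions=1500.
time_masks = approx_num_masks(0.05, 1500, 10, min_masks=2)

# Feature axis: mask_feature_prob=0.0, hidden_size resolves to d_model=384.
feature_masks = approx_num_masks(0.0, 384, 10, min_masks=0)

print(time_masks, feature_masks)
```

With the defaults, feature masking is effectively disabled because mask_feature_prob is 0.0 and mask_feature_min_masks is 0.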
- attribute_map: dict[str, str] = {'hidden_size': 'd_model', 'num_attention_heads': 'encoder_attention_heads'}
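An attribute map of this kind lets generic code read `hidden_size` or `num_attention_heads` while the Whisper config stores `d_model` and `encoder_attention_heads`. The following is a minimal sketch of the aliasing idea only; `AliasedConfig` is a hypothetical stand-in and is not how EasyDeLBaseConfig is necessarily implemented.

```python
class AliasedConfig:
    """Toy config that resolves alias names through a class-level attribute_map."""

    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(self, d_model: int = 384, encoder_attention_heads: int = 6):
        self.d_model = d_model
        self.encoder_attention_heads = encoder_attention_heads

    def __getattr__(self, name):
        # __getattr__ runs only when normal attribute lookup fails,
        # so real attributes such as d_model never pass through here.
        target = type(self).attribute_map.get(name)
        if target is None:
            raise AttributeError(name)
        return getattr(self, target)


cfg = AliasedConfig()
print(cfg.hidden_size, cfg.num_attention_heads)
```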
- get_partition_rules(*args, **kwargs)
Returns the partition rules for the Whisper model. Arguments are ignored.
- Returns
Partition rules.
- Return type
tuple
- model_type: str = 'whisper'
- class easydel.modules.whisper.__init__.WhisperForAudioClassification(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
- class easydel.modules.whisper.__init__.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)
Bases:
EasyDeLBaseModule
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) -> Tuple[Any, LossMetrics]
Computes the loss for the model given a batch of inputs and labels.
This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.
- Parameters
labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.
loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.
loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.
**batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).
- Returns
- A tuple containing:
The model’s output (a pytree typically including logits, hidden states, etc.).
A LossMetrics object containing the calculated loss and potentially other metrics.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- Raises
AssertionError – If labels are required for the loss function but are not provided or inferred.
AssertionError – If sequence classification loss is used without num_labels in the config.
- decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, mode: Optional[Literal['__autoregressive__', '__prefill__', '__train__', '__insert__']] = None, past_key_values: Optional[TransformerCache] = None, cache_metadata: Optional[TransformerMetadata] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None)
- encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs)
- generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)
Generates sequences of token ids for models with a language modeling head.
- Parameters
input_features (chex.Array) – The input features (for Whisper, log-mel audio features) that condition the generation.
generation_config (~generation.GenerationConfig, optional) – The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config will override them. If generation_config is not provided, the default is used, with the following loading priority: 1) the generation_config.json model file, if it exists; 2) the model configuration. Unspecified parameters inherit the default values of [~generation.GenerationConfig]; check its documentation when parameterizing generation.
trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.
logits_processor (`LogitsProcessorList`, optional) – Custom logits processors that complement the default logits processors built from arguments and generation config. If a logits processor is passed that was already created from the arguments or a generation config, an error is thrown. This feature is intended for advanced users.
kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that will be forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed and decoder-specific kwargs should be prefixed with decoder_.
- Returns
[~utils.ModelOutput].
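The override rule described for generation_config and kwargs can be illustrated in isolation: keyword arguments whose names match config attributes replace the config values, and everything else is forwarded to the model. This is a sketch under stated assumptions; `ToyGenerationConfig` and `split_generate_kwargs` are hypothetical names and do not reflect the actual internals of generate.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config with a few fields."""

    max_length: int = 448
    num_beams: int = 1
    return_timestamps: bool = False


def split_generate_kwargs(config, **kwargs):
    """Return (config with matching kwargs applied, remaining model kwargs)."""
    config_names = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in config_names}
    model_kwargs = {k: v for k, v in kwargs.items() if k not in config_names}
    return replace(config, **overrides), model_kwargs


new_config, model_kwargs = split_generate_kwargs(
    ToyGenerationConfig(),
    max_length=64,                # matches a config attribute: overrides it
    decoder_attention_mask=None,  # no match: forwarded to the model
)
print(new_config, model_kwargs)
```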
- loss_type = 'ForCausalLM'
- prepare_inputs_for_generation(decoder_input_ids, max_length: int, pad_token_id: int, starts: int | None = None, shardings=None, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)
Sets up the initial inputs required for starting autoregressive generation.
This function initializes the Key-Value cache (past_key_values) using init_cache, calculates the initial position_ids based on the input attention_mask (or assumes a contiguous range if no mask is provided), and prepares an extended attention_mask suitable for caching. It ensures inputs are placed on the correct devices/shards.
- Parameters
decoder_input_ids (chex.Array) – The initial sequence of token IDs. Shape (batch_size, seq_length).
max_length (int) – The maximum sequence length that the KV cache should support.
pad_token_id (int) – The ID used for padding tokens. Used to calculate starts if not provided.
starts (int | None) – Optional pre-calculated starting positions (number of leading pads). If None, calculated using compute_prefill_length.
shardings (dict | None) – Optional sharding configuration passed to init_cache.
attention_mask (tp.Optional[chex.Array]) – An optional mask indicating which tokens should be attended to. Shape (batch_size, seq_length).
token_type_ids (tp.Optional[chex.Array]) – Optional segment IDs for models that use them.
- Returns
- A dictionary containing the prepared inputs, typically including:
“past_key_values”: The initialized KV cache.
“attention_mask”: The extended attention mask for generation.
“position_ids”: The calculated initial position IDs.
“token_type_ids”: (Optional) Prepared token type IDs.
This dictionary is then passed through prepare_inputs_for_call.
- Return type
dict
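The position and mask preparation described above can be sketched for a single sequence with plain lists. The helpers below are hypothetical and only illustrate the idea: position IDs follow the running count of attended tokens (clamping padded positions to 0 is an assumption of this sketch), and the attention mask is extended with ones up to the cache length.

```python
from itertools import accumulate


def initial_position_ids(attention_mask=None, seq_length=None):
    """Position IDs from a 0/1 mask, or a contiguous range when no mask is given."""
    if attention_mask is None:
        return list(range(seq_length))
    # Running count of attended tokens minus one; padded slots are clamped to 0.
    return [max(count - 1, 0) for count in accumulate(attention_mask)]


def extend_attention_mask(attention_mask, max_length):
    """Pad the mask with ones so it covers every slot of the KV cache."""
    return list(attention_mask) + [1] * (max_length - len(attention_mask))


mask = [0, 0, 1, 1]  # two leading pads, then two real tokens
positions = initial_position_ids(mask)
extended = extend_attention_mask(mask, max_length=6)
print(positions, extended)
```

Extending with ones is workable in a cached setting because causal masking is typically what keeps not-yet-written cache slots from being attended to.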
- update_inputs_for_generation(model_outputs, model_kwargs)
Updates the keyword arguments for the next generation step.
Specifically, it takes the past_key_values from the model_outputs of the current step and updates the model_kwargs with them. It also increments the position_ids by one for the next token prediction.
- Parameters
model_outputs – The output object from the model’s forward pass in the previous step (should contain a past_key_values attribute).
model_kwargs (dict) – The dictionary of keyword arguments used for the model call. This dictionary will be modified in-place or a new one returned.
- Returns
The updated model_kwargs dictionary ready for the next generation step.
- Return type
dict
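A toy version of this update step, using plain Python objects, might look as follows. `toy_update_inputs` is a hypothetical name; keeping only the last position per row before incrementing is an assumption borrowed from common cached-generation loops, and the `position_ids` key follows the description above.

```python
from types import SimpleNamespace


def toy_update_inputs(model_outputs, model_kwargs):
    """Carry the cache forward and advance position IDs by one step."""
    updated = dict(model_kwargs)  # shallow copy; the caller's dict is left untouched
    updated["past_key_values"] = model_outputs.past_key_values
    # Assumption: after the first step only the newest token is fed to the model,
    # so each row keeps a single position, the previous last position plus one.
    updated["position_ids"] = [[row[-1] + 1] for row in model_kwargs["position_ids"]]
    return updated


outputs = SimpleNamespace(past_key_values="cache-after-step-0")
kwargs = {"past_key_values": "empty-cache", "position_ids": [[0, 1, 2], [0, 1, 2]]}
next_kwargs = toy_update_inputs(outputs, kwargs)
print(next_kwargs)
```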
- class easydel.modules.whisper.__init__.WhisperTimeStampLogitsProcessor(generate_config, model_config, decoder_input_length)
Bases:
LogitsProcessor
A specialized [LogitsProcessor] tailored for handling timestamp tokens during generation with Whisper-style models used for Automatic Speech Recognition (ASR).
It enforces several constraints specific to timestamp prediction:
1. Suppresses `<|notimestamps|>`: Prevents the model from predicting the token that indicates the absence of timestamps.
2. Alternating Tokens: Enforces that text tokens and timestamp tokens generally alternate. If the last generated token was a timestamp, it biases against predicting another timestamp immediately after (unless it’s the very beginning or certain edge cases).
3. Initial Timestamp Limit: Restricts the maximum value of the first timestamp token predicted using max_initial_timestamp_index.
4. Timestamp Probability Check: If the total probability mass assigned to all valid timestamp tokens is higher than the probability of the single most likely non-timestamp token, it forces the model to sample a timestamp token by suppressing all non-timestamp tokens.
Note
This processor assumes the existence of specific token IDs related to timestamps (e.g., eos_token_id, no_timestamps_token_id, timestamp_begin) which are typically defined in the model’s generation configuration.
- Parameters
generate_config – Configuration object containing Whisper-specific generation parameters like eos_token_id, no_timestamps_token_id, is_multilingual, max_initial_timestamp_index.
model_config – The model’s configuration (used for vocab_size as a fallback).
decoder_input_length – The length of the initial input sequence provided to the decoder (e.g., the prompt length).
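The timestamp probability check listed among the constraints above can be sketched for a single logits vector with the standard library only. This is an illustration of the rule as described, not the processor's actual code: `force_timestamp_if_likely` is a hypothetical helper, it assumes all token IDs at or above `timestamp_begin` are timestamps, and it assumes finite logits.

```python
import math


def log_sum_exp(values):
    """Numerically stable log(sum(exp(v))) for a list of finite floats."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def force_timestamp_if_likely(logits, timestamp_begin):
    """Suppress text tokens when total timestamp mass beats the best text token."""
    log_z = log_sum_exp(logits)
    logprobs = [x - log_z for x in logits]
    timestamp_logprob = log_sum_exp(logprobs[timestamp_begin:])
    max_text_logprob = max(logprobs[:timestamp_begin])
    if timestamp_logprob > max_text_logprob:
        # Text tokens get -inf so only timestamp tokens can be sampled.
        return [-math.inf if i < timestamp_begin else x for i, x in enumerate(logits)]
    return list(logits)


# Two text tokens (ids 0-1) and two timestamp tokens (ids 2-3).
forced = force_timestamp_if_likely([2.0, 0.0, 1.5, 1.5], timestamp_begin=2)
kept = force_timestamp_if_likely([5.0, 0.0, 1.0, 1.0], timestamp_begin=2)
print(forced, kept)
```

In the first call the two timestamp tokens together outweigh the best text token, so the text tokens are masked; in the second call the best text token dominates and the logits are returned unchanged.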