easydel.modules.whisper.modeling_whisper_flax

class easydel.modules.whisper.modeling_whisper_flax.WhisperAttention(*args: Any, **kwargs: Any)

Bases: FlaxAttentionModule

class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoder(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.whisper.modeling_whisper_flax.WhisperDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoder(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.whisper.modeling_whisper_flax.WhisperEncoderLayer(*args: Any, **kwargs: Any)

Bases: Module

class easydel.modules.whisper.modeling_whisper_flax.WhisperForAudioClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

class easydel.modules.whisper.modeling_whisper_flax.WhisperForConditionalGeneration(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) -> Tuple[Any, LossMetrics]

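Basic compute_loss call: computes the loss for the given batch against labels, as configured by loss_config and loss_kwargs, and returns a tuple of the outputs and a LossMetrics object.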
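To illustrate the kind of token-level computation involved, here is a generic masked cross-entropy written in JAX. It is a sketch, not EasyDeL's implementation. It ignores loss_config and loss_kwargs entirely, and it assumes that label positions equal to -100 are excluded from the loss, following the common Hugging Face convention. The function name `masked_token_cross_entropy` and the constant `IGNORE_INDEX` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumption: -100 marks label positions excluded from the loss (Hugging Face convention).
IGNORE_INDEX = -100


def masked_token_cross_entropy(logits: jax.Array, labels: jax.Array) -> jax.Array:
    """Hypothetical helper: mean cross-entropy over non-ignored label positions.

    logits: (batch, seq_len, vocab_size) floats; labels: (batch, seq_len) integer ids.
    """
    mask = labels != IGNORE_INDEX
    # Swap ignored labels for a valid index so take_along_axis stays in range.
    safe_labels = jnp.where(mask, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Zero out ignored positions, then average over the tokens that remain.
    nll = -jnp.where(mask, token_log_probs, 0.0)
    return nll.sum() / jnp.maximum(mask.sum(), 1)
```

With uniform logits over a 4-token vocabulary, every kept position contributes log(4), so the mean loss is log(4) no matter how many positions are masked.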

config: tp.Union[EasyDeLBaseConfig, _CP]
decode(decoder_input_ids, encoder_outputs, encoder_attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, past_key_values: Optional[dict] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None)
dtype: jnp.dtype
encode(input_features: Array, attention_mask: Optional[Array] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs)
generate(input_features, generation_config=None, logits_processor=None, return_timestamps=None, task=None, language=None, is_multilingual=None, **kwargs)

Generates sequences of token ids for models with a language modeling head.

Parameters
  • input_features (chex.Array of shape (batch_size, feature_size, sequence_length)) – The log-mel spectrogram features of the input audio, fed to the encoder to condition generation.

  • generation_config (~generation.GenerationConfig, optional) – The generation configuration used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config override them. If generation_config is not provided, the default is used, loaded with the following priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration. Note that unspecified parameters inherit the default values of GenerationConfig, whose documentation should be checked to parameterize generation.

  • trace (bool, optional, defaults to True) – Whether to trace generation. Setting trace=False should only be used for debugging and will lead to a considerably slower runtime.

  • logits_processor (FlaxLogitsProcessorList, optional) – Custom logits processors that complement the default logits processors built from the arguments and the generation config. If a logits processor is passed that is already created from the arguments or the generation config, an error is thrown. This feature is intended for advanced users.

  • kwargs (tp.Dict[str, Any], optional) – Ad hoc parametrization of generation_config and/or additional model-specific kwargs that are forwarded to the forward function of the model. If the model is an encoder-decoder model, encoder-specific kwargs should not be prefixed, and decoder-specific kwargs should be prefixed with decoder_.

Returns

ModelOutput.
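The kwargs rules above can be sketched in plain Python: keys that match a generation_config attribute override a copy of the config, and the remaining keys are routed to the encoder or the decoder by the decoder_ prefix. This is an illustration of the documented rules, not the library's code. `ToyGenerationConfig` and `split_generate_kwargs` are hypothetical names, and the three config fields are chosen only for the example.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyGenerationConfig:
    """Hypothetical stand-in for GenerationConfig with a few illustrative fields."""
    max_length: int = 20
    num_beams: int = 1
    do_sample: bool = False


def split_generate_kwargs(
    config: ToyGenerationConfig, kwargs: Dict[str, Any]
) -> Tuple[ToyGenerationConfig, Dict[str, Any], Dict[str, Any]]:
    """Apply kwargs matching config fields, then route the rest by prefix."""
    config_keys = {f.name for f in fields(config)}
    overrides = {k: v for k, v in kwargs.items() if k in config_keys}
    # Overrides produce a new config; the base config is left unchanged.
    new_config = replace(config, **overrides)
    encoder_kwargs: Dict[str, Any] = {}
    decoder_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in config_keys:
            continue
        # decoder_-prefixed keys go to the decoder; unprefixed keys go to the encoder.
        if key.startswith("decoder_"):
            decoder_kwargs[key] = value
        else:
            encoder_kwargs[key] = value
    return new_config, encoder_kwargs, decoder_kwargs
```

For example, `split_generate_kwargs(ToyGenerationConfig(), {"num_beams": 4, "attention_mask": m, "decoder_attention_mask": d})` returns a config with num_beams=4, sends attention_mask to the encoder side, and keeps decoder_attention_mask for the decoder.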

loss_type = 'ForCausalLM'
param_dtype: jnp.dtype
precision: lax.PrecisionLike
prepare_inputs_for_generation(decoder_input_ids, max_length, attention_mask: Optional[Array] = None, decoder_attention_mask: Optional[Array] = None, encoder_outputs=None, **kwargs)

Prepares the inputs for a generation task.

Parameters
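  • decoder_input_ids – The decoder input token ids that start generation.

  • max_length – The maximum length of the sequence to be generated.

  • attention_mask (tp.Optional[chex.Array]) – Optional attention mask for the encoder inputs.

  • decoder_attention_mask (tp.Optional[chex.Array]) – Optional attention mask for the decoder inputs.

  • encoder_outputs – Precomputed encoder outputs that the decoder attends to.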

Returns

A dictionary containing the past_key_values, the attention masks and the decoder position ids.
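A common pattern for this step, used for example by Hugging Face's Flax Whisper, is to extend the decoder attention mask to max_length and derive position ids from it. After each decoding step, the position ids are advanced by one. The JAX sketch below assumes EasyDeL follows the same pattern. It omits key/value cache initialization, and the helper names are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def sketch_prepare_decoder_inputs(decoder_input_ids, max_length, decoder_attention_mask=None):
    """Hypothetical helper: build a max_length decoder mask and matching position ids."""
    batch_size, seq_length = decoder_input_ids.shape
    # Future slots default to 1 (attend); any provided mask overwrites the prefix.
    extended_mask = jnp.ones((batch_size, max_length), dtype="i4")
    if decoder_attention_mask is not None:
        # Positions count only attended tokens, so left padding does not shift them.
        position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_mask = lax.dynamic_update_slice(
            extended_mask, decoder_attention_mask.astype("i4"), (0, 0)
        )
    else:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )
    return {"decoder_attention_mask": extended_mask, "decoder_position_ids": position_ids}


def sketch_update_position_ids(model_kwargs):
    """Hypothetical per-step update: the next token sits one position after the last."""
    updated = dict(model_kwargs)
    updated["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
    return updated
```

Precomputing a fixed-size mask keeps array shapes static across decoding steps, which suits JAX tracing. Only the position ids change from step to step.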

rngs: nn.Rngs
update_inputs_for_generation(model_outputs, model_kwargs)

class easydel.modules.whisper.modeling_whisper_flax.WhisperModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

decode(encoder_hidden_states: Array, decoder_input_ids: Array, decoder_attention_mask: Optional[Array] = None, decoder_position_ids: Optional[Array] = None, past_key_values: Optional[TransformerCache] = None, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True)
encode(input_features: Array, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True)

easydel.modules.whisper.modeling_whisper_flax.shift_tokens_right(input_ids: Array, pad_token_id: int, decoder_start_token_id: int)

Shift input ids one token to the right using JAX.
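A right shift of this kind builds decoder inputs from target labels for teacher forcing. The decoder start token goes in position 0, every other token moves one slot to the right, and the last token is dropped. The sketch below mirrors the Hugging Face Flax helper of the same purpose, which also replaces -100 label markers with pad_token_id. Whether EasyDeL does the same is an assumption, and the function name is hypothetical.

```python
import jax.numpy as jnp


def shift_tokens_right_sketch(input_ids, pad_token_id: int, decoder_start_token_id: int):
    """Hypothetical helper: shift ids right by one and prepend decoder_start_token_id."""
    shifted = jnp.zeros_like(input_ids)
    # Output position i holds input token i - 1; the final input token is dropped.
    shifted = shifted.at[:, 1:].set(input_ids[:, :-1])
    shifted = shifted.at[:, 0].set(decoder_start_token_id)
    # Assumption: -100 marks ignored label positions; turn them into padding.
    return jnp.where(shifted == -100, pad_token_id, shifted)
```

For labels `[[5, -100, 7]]` with pad_token_id=0 and decoder_start_token_id=1, this sketch yields `[[1, 5, 0]]`.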

easydel.modules.whisper.modeling_whisper_flax.sinusoidal_embedding_init(key, shape, dtype=<class 'jax.numpy.float64'>) -> Array
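Judging by the name and the Flax initializer signature (key, shape, dtype), this function fills a (length, channels) table with fixed sinusoidal position features, as in Whisper's encoder. The JAX sketch below follows the layout used by Hugging Face's Flax Whisper: sines in the first half of the channels, cosines in the second, over timescales from 1 to 10000. EasyDeL's exact formula is an assumption, and the function name is hypothetical. The sketch defaults to float32 because, unless `jax_enable_x64` is set, JAX downcasts float64 requests to float32 anyway.

```python
import math

import jax
import jax.numpy as jnp


def sinusoidal_embedding_init_sketch(key, shape, dtype=jnp.float32) -> jax.Array:
    """Hypothetical helper: (length, channels) table of sin/cos position features.

    key is accepted only to match the Flax initializer signature; it is unused.
    """
    length, channels = shape
    if channels % 2 != 0 or channels < 4:
        raise ValueError("channels must be even and at least 4")
    half = channels // 2
    # Geometric progression of timescales from 1 to 10000 across the half-width.
    log_timescale_increment = math.log(10000) / (half - 1)
    inv_timescales = jnp.exp(-log_timescale_increment * jnp.arange(half))
    scaled_time = jnp.arange(length)[:, None] * inv_timescales[None, :]
    # First half of the channels holds sines, second half holds cosines.
    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)
```

Row 0 of the table is all zeros in the sine half and all ones in the cosine half, because sin(0) = 0 and cos(0) = 1.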