easydel.inference.whisper_inference

class easydel.inference.whisper_inference.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[vWhisperInferenceConfig] = None, dtype: jax.typing.DTypeLike = jnp.float32)

Bases: object

Whisper inference pipeline for performing speech-to-text transcription or translation.

Parameters
  • model (WhisperForConditionalGeneration) – The Whisper model (pretrained or fine-tuned) to use for inference.

  • tokenizer (WhisperTokenizer) – Tokenizer for Whisper.

  • processor (WhisperProcessor) – Processor for Whisper.

  • inference_config (vWhisperInferenceConfig, optional) – Inference configuration.

  • dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.

Example usage:

>>> import jax.numpy as jnp
>>> import easydel as ed
>>> from easydel.inference.whisper_inference import vWhisperInference
>>> from transformers import WhisperTokenizer, WhisperProcessor
>>> REPO_ID = "openai/whisper-small"  # Replace with your desired model
>>> model = ed.AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     REPO_ID,
...     # ... (config_kwargs as needed)
... )
>>> tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
>>> processor = WhisperProcessor.from_pretrained(REPO_ID)
>>> inference = vWhisperInference(
...     model=model,
...     tokenizer=tokenizer,
...     processor=processor,
...     dtype=jnp.float16,  # Or jnp.float32
... )
>>> result = inference("sample1.flac", return_timestamps=True)
>>> print(result)
>>> # Example using a URL:
>>> result_url = inference(
...     "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/raw/main/common_voice_en_100038.mp3",
...     return_timestamps=True,
... )
>>> print(result_url)
>>> # Example specifying language and task:
>>> result_lang_task = inference(
...     "sample1.flac", language="en", task="transcribe", return_timestamps=True
... )
>>> print(result_lang_task)
chunk_iter_with_batch(audio_array: Array, chunk_length: int, stride_left: int, stride_right: int, batch_size: int)
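The method above is documented only by its signature. As a rough illustration of what strided, batched chunking usually involves, here is a hypothetical sketch (chunk_iter_with_batch_sketch is not part of EasyDeL, and EasyDeL's real chunking and feature extraction may differ). Lengths are in samples. The first chunk gets no left stride and the chunk that reaches the end of the audio gets no right stride, so the unstrided middle parts of consecutive chunks cover the input without gaps.

```python
import numpy as np


def chunk_iter_with_batch_sketch(audio_array, chunk_length, stride_left, stride_right, batch_size):
    """Hypothetical sketch: yield batches of overlapping chunks and their strides.

    All lengths are in samples. Each stride tuple is (chunk_len, left, right).
    """
    n = audio_array.shape[0]
    step = chunk_length - stride_left - stride_right
    if step <= 0:
        raise ValueError("chunk_length must exceed stride_left + stride_right")

    # Collect chunk start offsets, stopping at the first chunk that reaches the end.
    starts = []
    for start in range(0, n, step):
        starts.append(start)
        if start + chunk_length >= n:
            break

    # Group the chunks into batches of at most `batch_size`.
    for i in range(0, len(starts), batch_size):
        chunks, strides = [], []
        for start in starts[i:i + batch_size]:
            end = start + chunk_length
            chunk = audio_array[start:end]
            left = 0 if start == 0 else stride_left   # nothing overlaps before the first chunk
            right = 0 if end >= n else stride_right   # nothing overlaps after the last chunk
            chunks.append(chunk)
            strides.append((int(chunk.shape[0]), left, right))
        yield chunks, strides
```

For example, 100 samples with chunk_length=40, stride_left=stride_right=10 and batch_size=2 give chunks starting at 0, 20, 40 and 60, yielded as two batches of two.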
generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)

Transcribe or translate audio input.

Parameters
  • audio_input (Union[str, bytes, np.ndarray, Dict[str, Union[np.ndarray, int]]]) – Input audio: a local file path, a URL, raw bytes, a NumPy array, or a dictionary containing the array and its sampling rate.

  • chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.

  • stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.

  • batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.

  • language (str, optional) – Language of the input audio. Defaults to the language in inference_config.

  • task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.

Returns

A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.

Return type

dict
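The chunking parameters of generate are given in seconds. A minimal sketch of how they translate into sample counts, assuming Whisper's 16 kHz input rate, that a scalar stride_length_s applies to both the left and right side of each chunk, and that a two-element list means [left, right] (these interpretations are assumptions, and chunk_params_in_samples is a hypothetical helper, not an EasyDeL function):

```python
def chunk_params_in_samples(chunk_length_s=30.0, stride_length_s=None, sampling_rate=16_000):
    """Hypothetical helper: convert chunk/stride lengths from seconds to samples."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    if isinstance(stride_length_s, (int, float)):
        # Assumption: a scalar stride is applied to both sides of each chunk.
        stride_length_s = [stride_length_s, stride_length_s]
    chunk_len = int(round(chunk_length_s * sampling_rate))
    stride_left = int(round(stride_length_s[0] * sampling_rate))
    stride_right = int(round(stride_length_s[1] * sampling_rate))
    # Distance between the start offsets of consecutive chunks.
    step = chunk_len - stride_left - stride_right
    return chunk_len, stride_left, stride_right, step
```

With the defaults (chunk_length_s=30.0, stride 30 / 6 = 5 s), each 480,000-sample chunk starts 320,000 samples after the previous one, so neighbouring chunk windows overlap by 10 s.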

get_decoder_input_ids(generation_config: Optional[Any] = None, task: Optional[str] = None, language: Optional[str] = None, return_timestamps: bool = False) → list[Tuple[int, int]]
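get_decoder_input_ids is documented only by its signature. Its list[Tuple[int, int]] return type has the same shape as the (position, token_id) pairs that Hugging Face Whisper models use as forced decoder ids, so it presumably serves that role. The following hypothetical sketch shows how such pairs are commonly assembled from language, task and return_timestamps. The function name and the toy token ids are illustrative only; real ids should be looked up with the tokenizer (for example via tokenizer.convert_tokens_to_ids).

```python
def forced_decoder_ids_sketch(token_ids, language=None, task=None,
                              return_timestamps=False, is_multilingual=True):
    """Hypothetical sketch: build (position, token_id) pairs for Whisper's decoder prompt.

    Position 0 holds <|startoftranscript|>, so forced tokens start at position 1.
    """
    forced = []
    if is_multilingual:
        if language is not None:
            # Force the language token; if omitted, the model predicts it at position 1.
            forced.append((1, token_ids[f"<|{language}|>"]))
        task = task or "transcribe"
        forced.append((2, token_ids[f"<|{task}|>"]))
    if not return_timestamps:
        # Request text without timestamps by forcing <|notimestamps|> after the prompt.
        next_pos = forced[-1][0] + 1 if forced else 1
        forced.append((next_pos, token_ids["<|notimestamps|>"]))
    return forced


# Toy ids for illustration only; real ids come from the tokenizer.
TOY_IDS = {"<|en|>": 10, "<|transcribe|>": 20, "<|translate|>": 21, "<|notimestamps|>": 30}
```

With language="en", the default task and return_timestamps=False, the sketch returns [(1, 10), (2, 20), (3, 30)] for the toy ids.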
class easydel.inference.whisper_inference.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)

Bases: object

Configuration class for Whisper inference.

Parameters
  • batch_size (int, optional, defaults to 1) – Batch size used for inference.

  • max_length (int, optional) – Maximum sequence length for generation.

  • generation_config (transformers.GenerationConfig, optional) – Generation configuration object.

  • logits_processor (optional) – Not used.

  • return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.

  • task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).

  • language (str, optional) – Language of the input audio.

  • is_multilingual (bool, optional) – Whether the model is multilingual.
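generate documents that its batch_size, language, task and return_timestamps arguments fall back to the values in inference_config when left as None. A minimal sketch of that precedence rule, using a hypothetical stand-in dataclass (WhisperConfigStandIn and resolve_generate_args are illustrative names, not EasyDeL API):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class WhisperConfigStandIn:
    """Hypothetical stand-in mirroring a few documented vWhisperInferenceConfig fields."""
    batch_size: Optional[int] = 1
    language: Optional[str] = None
    task: Optional[str] = None
    return_timestamps: Optional[bool] = None


def resolve_generate_args(config, batch_size=None, language=None, task=None,
                          return_timestamps=None):
    """Per-call values win; a None argument falls back to the config value."""
    def pick(value, default):
        return default if value is None else value

    return {
        "batch_size": pick(batch_size, config.batch_size),
        "language": pick(language, config.language),
        "task": pick(task, config.task),
        "return_timestamps": pick(return_timestamps, config.return_timestamps),
    }
```

Testing against None rather than truthiness matters for return_timestamps: an explicit False still overrides a config value of True.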

batch_size: Optional[int] = 1
generation_config: Optional[Any] = None
is_multilingual = None
language = None
logits_processor = None
max_length: Optional[int] = None
return_timestamps = None
task = None
tree_flatten()
classmethod tree_unflatten(aux, children)
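tree_flatten and tree_unflatten are the method pair that jax.tree_util.register_pytree_node_class expects, which suggests the config is registered as a JAX pytree. The following is a hypothetical sketch of that protocol (PytreeConfigSketch and its choice of children are assumptions; EasyDeL's split between children and auxiliary data may differ). It runs without JAX and checks the round trip by hand.

```python
from typing import Any, Optional


class PytreeConfigSketch:
    """Hypothetical config implementing JAX's tree_flatten / tree_unflatten protocol."""

    def __init__(self, batch_size: Optional[int] = 1, max_length: Optional[int] = None,
                 generation_config: Optional[Any] = None):
        self.batch_size = batch_size
        self.max_length = max_length
        self.generation_config = generation_config

    def tree_flatten(self):
        # Assumption: every field is returned as a child and no auxiliary data is needed.
        children = (self.batch_size, self.max_length, self.generation_config)
        aux = None
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Rebuild an instance from the children produced by tree_flatten.
        return cls(*children)
```

With JAX installed, decorating the class with @jax.tree_util.register_pytree_node_class lets jax.tree_util.tree_flatten and transformations such as jax.jit traverse instances. Values returned as children become leaves (and therefore tracers under jax.jit), whereas values returned as auxiliary data stay static, so a config whose fields must stay static would return them as aux instead.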