easydel.inference.whisper_inference
- class easydel.inference.whisper_inference.vWhisperInference(model: Any, tokenizer: Any, processor: Any, inference_config: Optional[vWhisperInferenceConfig] = None, dtype: Union[str, type[Any], numpy.dtype, SupportsDType] = jax.numpy.float32)
Bases: object
Whisper inference pipeline for speech-to-text transcription or translation.
- Parameters
model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.
tokenizer (WhisperTokenizer) – Tokenizer for Whisper.
processor (WhisperProcessor) – Processor for Whisper.
inference_config (vWhisperInferenceConfig, optional) – Inference configuration.
dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.
Example usage:
>>> import easydel as ed
>>> import jax.numpy as jnp
>>> from easydel.inference.whisper_inference import vWhisperInference
>>> from transformers import WhisperTokenizer, WhisperProcessor

>>> REPO_ID = "openai/whisper-small"  # Replace with your desired model

>>> model = ed.AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     REPO_ID,
...     # ... (config_kwargs as needed)
... )
>>> tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
>>> processor = WhisperProcessor.from_pretrained(REPO_ID)

>>> inference = vWhisperInference(
...     model=model,
...     tokenizer=tokenizer,
...     processor=processor,
...     dtype=jnp.float16,  # Or jnp.float32
... )

>>> result = inference("sample1.flac", return_timestamps=True)
>>> print(result)

>>> # Example using a URL:
>>> result_url = inference(
...     "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/raw/main/common_voice_en_100038.mp3",
...     return_timestamps=True,
... )
>>> print(result_url)

>>> # Example specifying language and task:
>>> result_lang_task = inference(
...     "sample1.flac", language="en", task="transcribe", return_timestamps=True
... )
>>> print(result_lang_task)
- chunk_iter_with_batch(audio_array: Array, chunk_length: int, stride_left: int, stride_right: int, batch_size: int)
- generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)
Transcribe or translate audio input.
- Parameters
audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.
chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.
stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.
batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.
language (str, optional) – Language of the input audio. Defaults to the language in inference_config.
task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.
return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.
- Returns
A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.
- Return type
dict
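The interaction between chunk_length_s, stride_length_s, and batch_size can be illustrated with a small, dependency-free sketch. It assumes the common overlapping-window scheme in which each chunk shares stride samples with its neighbours and the first and last chunks have no outer stride. The helpers chunk_windows and batched are hypothetical names introduced for illustration only, not part of EasyDeL, and the sketch handles a single float stride rather than the list form. It is not a reproduction of chunk_iter_with_batch, whose exact behaviour should be checked in the source.

```python
from typing import List, Optional, Tuple

SAMPLING_RATE = 16_000  # Whisper feature extractors expect 16 kHz audio

# (start, end, left_stride, right_stride), all measured in samples
Window = Tuple[int, int, int, int]


def chunk_windows(
    num_samples: int,
    chunk_length_s: float = 30.0,
    stride_length_s: Optional[float] = None,
    sampling_rate: int = SAMPLING_RATE,
) -> List[Window]:
    """Hypothetical helper: plan overlapping chunk windows over an audio array."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # default stated in the docs above
    chunk_len = round(chunk_length_s * sampling_rate)
    stride = round(stride_length_s * sampling_rate)
    step = chunk_len - 2 * stride  # fresh samples contributed by each chunk
    if step <= 0:
        raise ValueError("chunk length must exceed twice the stride")

    windows: List[Window] = []
    start = 0
    while start < num_samples:
        is_last = start + chunk_len >= num_samples
        end = min(start + chunk_len, num_samples)
        left = 0 if start == 0 else stride   # no left overlap on the first chunk
        right = 0 if is_last else stride     # no right overlap on the last chunk
        windows.append((start, end, left, right))
        if is_last:
            break
        start += step
    return windows


def batched(windows: List[Window], batch_size: int) -> List[List[Window]]:
    """Hypothetical helper: group planned windows into batches."""
    return [windows[i:i + batch_size] for i in range(0, len(windows), batch_size)]


# Plan a 70-second clip with the default 30 s chunks and 5 s strides.
plan = chunk_windows(num_samples=70 * SAMPLING_RATE)
for batch in batched(plan, batch_size=2):
    print(batch)
```

Under these assumptions a 70-second clip yields three windows, processed as one batch of two and one batch of one. Clips shorter than chunk_length_s produce a single window with no strides.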
- class easydel.inference.whisper_inference.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)
Bases: object
Configuration class for Whisper inference.
- Parameters
batch_size (int, optional, defaults to 1) – Batch size used for inference.
max_length (int, optional) – Maximum sequence length for generation.
generation_config (transformers.GenerationConfig, optional) – Generation configuration object.
logits_processor (optional) – Not used.
return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.
task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).
language (str, optional) – Language of the input audio.
is_multilingual (bool, optional) – Whether the model is multilingual.
- batch_size: Optional[int] = 1
- generation_config: Optional[Any] = None
- is_multilingual = None
- language = None
- logits_processor = None
- max_length: Optional[int] = None
- return_timestamps = None
- task = None
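The generate documentation states that batch_size, language, task, and return_timestamps default to the values in inference_config when they are not passed. A minimal sketch of that fallback pattern follows. ConfigSketch is a stand-in dataclass and resolve is a hypothetical helper, both introduced here for illustration. They are not EasyDeL classes, and the library's actual resolution logic may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Stand-in for an inference config; NOT the real vWhisperInferenceConfig."""

    batch_size: Optional[int] = 1
    return_timestamps: Optional[bool] = None
    task: Optional[str] = None
    language: Optional[str] = None


def resolve(call_value, config_value):
    """Hypothetical helper: prefer the per-call value, else the config value."""
    # Compare against None explicitly so that falsy values such as False or 0
    # passed by the caller still override the config.
    return config_value if call_value is None else call_value


config = ConfigSketch(batch_size=4, task="transcribe", return_timestamps=True)

# Caller passes nothing: every setting comes from the config.
batch_size = resolve(None, config.batch_size)
task = resolve(None, config.task)

# Caller overrides selected settings for one call only.
task_override = resolve("translate", config.task)
timestamps_override = resolve(False, config.return_timestamps)

print(batch_size, task, task_override, timestamps_override)
```

Checking against None rather than truthiness matters here: passing return_timestamps=False should disable timestamps even when the config enables them.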