easydel.inference.__init__#
- class easydel.inference.__init__.vInference(model: None, processor_class: None, generation_config: Optional[vInferenceConfig] = None, seed: Optional[int] = None, input_partition_spec: Optional[PartitionSpec] = None, max_new_tokens: int = 512, inference_name: Optional[str] = None)[source]#
Bases: object
Class for performing text generation using a pre-trained language model in EasyDeL.
This class handles the generation process, including initialization, precompilation, and generating text in streaming chunks.
- property SEQUENCE_DIM_MAPPING#
- generate(input_ids: Array, attention_mask: Optional[Array] = None, *, graphstate: Optional[State[Key, VariableState[Any]]] = None, graphother: Optional[State[Key, VariableState[Any]]] = None, **model_kwargs) Generator[Union[SampleState, Any], SampleState, SampleState][source]#
Generates text in streaming chunks with comprehensive input adjustment.
- Parameters
input_ids – Input token IDs as a JAX array
attention_mask – Optional attention mask for the input
graphstate (nn.GraphState, optional) – Supply this to update the model state used for generation.
graphother (nn.GraphState, optional) – Supply this to update the model's other state used for generation.
**model_kwargs – Additional model-specific keyword arguments
- Returns
Generator yielding SampleState objects containing generation results and metrics
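The streaming pattern can be illustrated without a model. The sketch below is a toy generator that yields a snapshot after every `streaming_chunks` tokens until `max_new_tokens` is reached. `ToySampleState`, `stream_tokens`, and the dummy token values are hypothetical, and the reading of `streaming_chunks` as "tokens per yielded chunk" is an assumption rather than documented behaviour:

```python
from dataclasses import dataclass, field
from typing import Iterator, List


@dataclass
class ToySampleState:
    """Stand-in for SampleState: tokens produced so far (illustration only)."""

    tokens: List[int] = field(default_factory=list)
    generated: int = 0


def stream_tokens(max_new_tokens: int = 64, streaming_chunks: int = 16) -> Iterator[ToySampleState]:
    state = ToySampleState()
    while state.generated < max_new_tokens:
        # Produce at most one chunk, never exceeding the token budget.
        n = min(streaming_chunks, max_new_tokens - state.generated)
        new_tokens = list(range(state.generated, state.generated + n))  # dummy ids
        state = ToySampleState(state.tokens + new_tokens, state.generated + n)
        yield state


for snapshot in stream_tokens(max_new_tokens=40, streaming_chunks=16):
    print(snapshot.generated)  # 16, then 32, then 40
```

A caller consumes the real generator the same way, reading each yielded state as it arrives instead of waiting for the full sequence.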
- property inference_name#
- property metrics#
- property model#
- property model_prefill_length: int#
Calculate the maximum length available for input prefill by subtracting the maximum new tokens from the model’s maximum sequence length.
- Returns
The maximum length available for input prefill
- Return type
int
- Raises
ValueError – If no maximum sequence length configuration is found
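The calculation described above amounts to a single subtraction. A minimal sketch, assuming the maximum sequence length is stored under a hypothetical `max_position_embeddings` key in a plain dict; the function name `prefill_length` and the dict are illustrative stand-ins, not EasyDeL's actual attributes:

```python
from typing import Mapping, Optional


def prefill_length(model_config: Mapping[str, int], max_new_tokens: int) -> int:
    """Return how many prompt tokens fit once room for generation is reserved."""
    # Hypothetical key name; the real config attribute may differ.
    max_sequence_length: Optional[int] = model_config.get("max_position_embeddings")
    if max_sequence_length is None:
        raise ValueError("no maximum sequence length found in the model config")
    return max_sequence_length - max_new_tokens


print(prefill_length({"max_position_embeddings": 2048}, max_new_tokens=512))  # 1536
```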
- precompile(config: vInferencePreCompileConfig)[source]#
Precompiles the generation functions for a given batch size and input length.
This function checks if the generation functions have already been compiled for the given configuration. If not, it compiles them asynchronously and stores them in a cache.
- Returns
True if precompilation was successful, False otherwise.
- Return type
bool
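The check-then-compile behaviour can be pictured as a cache keyed by the configuration. A minimal synchronous sketch: `PrecompileCache`, `compile_fn`, and the `(batch_size, prefill_length)` key are hypothetical, and the asynchronous compilation mentioned above is left out for brevity:

```python
from typing import Callable, Dict, Tuple

Key = Tuple[int, int]  # (batch_size, prefill_length), an assumed cache key


class PrecompileCache:
    """Toy cache that compiles a function once per configuration."""

    def __init__(self, compile_fn: Callable[[int, int], Callable]):
        self._compile_fn = compile_fn
        self._compiled: Dict[Key, Callable] = {}

    def precompile(self, batch_size: int, prefill_length: int) -> bool:
        key = (batch_size, prefill_length)
        if key in self._compiled:
            return True  # already compiled, nothing to do
        try:
            self._compiled[key] = self._compile_fn(batch_size, prefill_length)
        except Exception:
            return False  # report failure instead of raising
        return True


calls = []


def fake_compile(batch_size: int, prefill_length: int) -> Callable:
    calls.append((batch_size, prefill_length))  # record each real compilation
    return lambda: None


cache = PrecompileCache(fake_compile)
cache.precompile(1, 1024)
cache.precompile(1, 1024)  # second call hits the cache
print(len(calls))  # 1
```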
- property tokenizer#
- class easydel.inference.__init__.vInferenceApiServer(inference_map: Union[Dict[str, Any], Any] = None, inference_init_call: Optional[Callable[[], Any]] = None, max_workers: int = 10)[source]#
Bases: object
- async chat_completions(request: ChatCompletionRequest)[source]#
- count_tokens(request: CountTokenRequest)[source]#
- class easydel.inference.__init__.vInferenceConfig(max_new_tokens: int = 64, min_length: Optional[int] = None, streaming_chunks: int = 16, temperature: float = 0.0, top_p: float = 0.95, top_k: int = 50, do_sample: bool = True, no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Union[int, Dict[int, int], NoneType] = 1, suppress_tokens: Optional[list] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Union[int, List[int], NoneType] = None, partition_rules: Optional[Tuple[Tuple[str, Any]]] = None, partition_axis: Optional[eformer.escale.partition.constraints.PartitionAxis] = None, _loop_rows: Optional[int] = None)[source]#
Bases: object
- bos_token_id: Optional[int] = None#
- do_sample: bool = True#
- eos_token_id: Optional[Union[int, List[int]]] = None#
- forced_bos_token_id: Optional[int] = None#
- forced_eos_token_id: Optional[int] = None#
- get_partition_rules(runtime_config: Optional[vInferencePreCompileConfig] = None)[source]#
- max_new_tokens: int = 64#
- min_length: Optional[int] = None#
- no_repeat_ngram_size: Optional[int] = None#
- num_return_sequences: Optional[Union[int, Dict[int, int]]] = 1#
- pad_token_id: Optional[int] = None#
- partition_axis: Optional[PartitionAxis] = None#
- partition_rules: Optional[Tuple[Tuple[str, Any]]] = None#
- replace(**kwargs)#
- streaming_chunks: int = 16#
- suppress_tokens: Optional[list] = None#
- temperature: float = 0.0#
- top_k: int = 50#
- top_p: float = 0.95#
- class easydel.inference.__init__.vInferencePreCompileConfig(batch_size: Union[int, List[int]] = 1, prefill_length: Union[int, List[int], NoneType] = None, vision_included: Union[bool, List[bool]] = False, vision_batch_size: Union[int, List[int], NoneType] = None, vision_channels: Union[int, List[int], NoneType] = None, vision_height: Union[int, List[int], NoneType] = None, vision_width: Union[int, List[int], NoneType] = None, required_props: Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]], NoneType] = None)[source]#
Bases: object
- batch_size: Union[int, List[int]] = 1#
- get_standalones()[source]#
Creates standalone configurations when any field contains a list. Returns a list of standalone vInferencePreCompileConfig instances.
For example, if batch_size=[1, 2, 3, 4], it will create 4 standalone configs with batch_size values 1, 2, 3, and 4 respectively.
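The expansion can be sketched with a two-field stand-in. `ToyPreCompileConfig` is hypothetical and carries only `batch_size` and `prefill_length`. How several list-valued fields combine is not stated here, so the sketch assumes a Cartesian product; the real method may pair values differently:

```python
from dataclasses import dataclass, fields, replace
from itertools import product
from typing import List, Optional, Union


@dataclass(frozen=True)
class ToyPreCompileConfig:
    """Two-field stand-in for vInferencePreCompileConfig (illustration only)."""

    batch_size: Union[int, List[int]] = 1
    prefill_length: Optional[Union[int, List[int]]] = None

    def get_standalones(self) -> List["ToyPreCompileConfig"]:
        names = [f.name for f in fields(self)]
        # Wrap scalars so every field is a list of candidate values.
        options = [
            v if isinstance(v, list) else [v]
            for v in (getattr(self, n) for n in names)
        ]
        # Assumption: list-valued fields combine as a Cartesian product.
        return [replace(self, **dict(zip(names, combo))) for combo in product(*options)]


configs = ToyPreCompileConfig(batch_size=[1, 2, 3, 4], prefill_length=1024).get_standalones()
print([c.batch_size for c in configs])  # [1, 2, 3, 4]
```

Each resulting config holds only scalar values, so it can be handed to a compile step one at a time.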
- prefill_length: Optional[Union[int, List[int]]] = None#
- replace(**kwargs)#
- required_props: Optional[Union[Mapping[str, Dict[str, Any]], List[Mapping[str, Dict[str, Any]]]]] = None#
- vision_batch_size: Optional[Union[int, List[int]]] = None#
- vision_channels: Optional[Union[int, List[int]]] = None#
- vision_height: Optional[Union[int, List[int]]] = None#
- vision_included: Union[bool, List[bool]] = False#
- vision_width: Optional[Union[int, List[int]]] = None#
- class easydel.inference.__init__.vWhisperInference(model: ~typing.Any, tokenizer: ~typing.Any, processor: ~typing.Any, inference_config: ~typing.Optional[~easydel.inference.whisper_inference.vWhisperInferenceConfig] = None, dtype: ~typing.Union[str, type[typing.Any], ~numpy.dtype, ~jax._src.typing.SupportsDType] = <class 'jax.numpy.float32'>)[source]#
Bases: object
Whisper inference pipeline for performing speech-to-text transcription or translation.
- Parameters
model (WhisperForConditionalGeneration) – The fine-tuned Whisper model to use for inference.
tokenizer (WhisperTokenizer) – Tokenizer for Whisper.
processor (WhisperProcessor) – Processor for Whisper.
inference_config (vWhisperInferenceConfig, optional) – Inference configuration.
dtype (jax.typing.DTypeLike, optional, defaults to jnp.float32) – Data type for computations.
Example usage:
>>> import easydel as ed
>>> import jax.numpy as jnp
>>> from easydel.inference import vWhisperInference
>>> from transformers import WhisperTokenizer, WhisperProcessor
>>> REPO_ID = "openai/whisper-small"  # Replace with your desired model
>>> model = ed.AutoEasyDeLModelForSpeechSeq2Seq.from_pretrained(
...     REPO_ID,
...     # ... (config_kwargs as needed)
... )
>>> tokenizer = WhisperTokenizer.from_pretrained(REPO_ID)
>>> processor = WhisperProcessor.from_pretrained(REPO_ID)
>>> inference = vWhisperInference(
...     model=model,
...     tokenizer=tokenizer,
...     processor=processor,
...     dtype=jnp.float16,  # Or jnp.float32
... )
>>> result = inference("sample1.flac", return_timestamps=True)
>>> print(result)
>>> # Example using a URL:
>>> result_url = inference(
...     "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/raw/main/common_voice_en_100038.mp3",
...     return_timestamps=True,
... )
>>> print(result_url)
>>> # Example specifying language and task:
>>> result_lang_task = inference(
...     "sample1.flac", language="en", task="transcribe", return_timestamps=True
... )
>>> print(result_lang_task)
- chunk_iter_with_batch(audio_array: Array, chunk_length: int, stride_left: int, stride_right: int, batch_size: int)[source]#
- generate(audio_input: Union[str, bytes, ndarray, Dict[str, Union[ndarray, int]]], chunk_length_s: float = 30.0, stride_length_s: Optional[Union[float, list[float]]] = None, batch_size: Optional[int] = None, language: Optional[str] = None, task: Optional[str] = None, return_timestamps: Optional[bool] = None)[source]#
Transcribe or translate audio input.
- Parameters
audio_input (tp.Union[str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]]) – Input audio. Can be a local file path, URL, bytes, numpy array, or a dictionary containing the array and sampling rate.
chunk_length_s (float, optional, defaults to 30.0) – Length of audio chunks in seconds.
stride_length_s (float or list[float], optional) – Stride length for chunking audio, in seconds. Defaults to chunk_length_s / 6.
batch_size (int, optional) – Batch size for processing. Defaults to the batch_size in inference_config.
language (str, optional) – Language of the input audio. Defaults to the language in inference_config.
task (str, optional) – Task to perform (e.g., “transcribe”, “translate”). Defaults to the task in inference_config.
return_timestamps (bool, optional) – Whether to return timestamps with the transcription. Defaults to the return_timestamps in inference_config.
- Returns
A dictionary containing the transcribed text (“text”) and optionally other information like timestamps or detected language.
- Return type
dict
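To make the chunking defaults concrete, the sketch below converts `chunk_length_s` and the default stride of `chunk_length_s / 6` into sample counts and lists the chunk start offsets. It assumes a 16 kHz sampling rate, the same stride on both sides, and a step of `chunk - stride_left - stride_right` samples between chunk starts. `chunk_starts` is a hypothetical helper, not part of EasyDeL:

```python
from typing import List, Optional


def chunk_starts(
    num_samples: int,
    chunk_length_s: float = 30.0,
    stride_length_s: Optional[float] = None,
    sampling_rate: int = 16_000,  # assumed; Whisper models expect 16 kHz audio
) -> List[int]:
    """Return the start offset (in samples) of each overlapping chunk."""
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6  # documented default
    chunk = round(chunk_length_s * sampling_rate)
    stride = round(stride_length_s * sampling_rate)
    step = chunk - 2 * stride  # assumed: equal left and right strides
    if step <= 0:
        raise ValueError("chunk length must exceed twice the stride")
    starts = []
    start = 0
    while True:
        starts.append(start)
        if start + chunk >= num_samples:
            break  # this chunk reaches the end of the audio
        start += step
    return starts


# 60 s of audio: 30 s chunks, 5 s stride on each side, so a 20 s step.
print([s / 16_000 for s in chunk_starts(60 * 16_000)])  # [0.0, 20.0, 40.0]
```

With the defaults, neighbouring chunks overlap by 10 s, which gives the model context on both sides of each chunk boundary.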
- class easydel.inference.__init__.vWhisperInferenceConfig(batch_size: Optional[int] = 1, max_length: Optional[int] = None, generation_config: Optional[Any] = None)[source]#
Bases: object
Configuration class for Whisper inference.
- Parameters
batch_size (int, optional, defaults to 1) – Batch size used for inference.
max_length (int, optional) – Maximum sequence length for generation.
generation_config (transformers.GenerationConfig, optional) – Generation configuration object.
logits_processor (optional) – Not used.
return_timestamps (bool, optional) – Whether to return timestamps with the transcribed text.
task (str, optional) – Task for the model (e.g., “transcribe”, “translate”).
language (str, optional) – Language of the input audio.
is_multilingual (bool, optional) – Whether the model is multilingual.
- batch_size: Optional[int] = 1#
- generation_config: Optional[Any] = None#
- is_multilingual = None#
- language = None#
- logits_processor = None#
- max_length: Optional[int] = None#
- replace(**kwargs)#
- return_timestamps = None#
- task = None#