# Source code for easydel.inference.vwhisper.generation

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import jax
from flax import nnx as nn
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE


@partial(
	jax.jit,
	static_argnames=("graphdef", "inference_config", "return_timestamps"),
)
def _compiled_generate(
	graphdef,
	graphstate,
	inference_config,
	input_features,
	decoder_input_ids,
	return_timestamps,
):
	"""Jit-compiled entry point that runs the model's forced generation.

	The module is rebuilt from ``graphdef`` and ``graphstate`` with
	``nn.merge`` and its ``_force_generate`` method is called inside the
	module's ``mesh`` context.

	``graphdef``, ``inference_config`` and ``return_timestamps`` are declared
	static for ``jax.jit``; the remaining arguments are traced.

	Args:
		graphdef: Static half of the split module.
		graphstate: State half of the split module, merged with ``graphdef``.
		inference_config: Object whose ``generation_config`` attribute is
			forwarded to ``_force_generate``.
		input_features: Forwarded unchanged as ``input_features``.
		decoder_input_ids: Forwarded as ``forced_decoder_ids``.
		return_timestamps: Forwarded unchanged as ``return_timestamps``.

	Returns:
		Whatever ``model._force_generate`` returns.
	"""
	module = nn.merge(graphdef, graphstate)
	with module.mesh:
		outputs = module._force_generate(
			input_features=input_features,
			forced_decoder_ids=decoder_input_ids,
			return_timestamps=return_timestamps,
			generation_config=inference_config.generation_config,
		)
	return outputs


def _resolve_language_token(language, generation_config):
	"""Map a user-supplied language to a key of ``generation_config.lang_to_id``.

	The lower-cased input is accepted when it is, in order: already a key of
	``lang_to_id``, a value of ``TO_LANGUAGE_CODE``, or a key of
	``TO_LANGUAGE_CODE``.

	Raises:
		ValueError: If the language matches none of the accepted forms. The
			message lists the candidates for the form the input most resembles.
	"""
	language = language.lower()
	if language in generation_config.lang_to_id:
		return language
	if language in TO_LANGUAGE_CODE.values():
		return f"<|{language}|>"
	if language in TO_LANGUAGE_CODE:
		return f"<|{TO_LANGUAGE_CODE[language]}|>"

	# Choose the suggestion list that matches the shape of the bad input.
	if len(language) == 2:
		acceptable_languages = list(TO_LANGUAGE_CODE.values())
	elif "<" in language or "|" in language or ">" in language:
		acceptable_languages = list(generation_config.lang_to_id)
	else:
		acceptable_languages = list(TO_LANGUAGE_CODE)
	raise ValueError(
		f"Unsupported language: {language}. Language should be one of: {acceptable_languages}."
	)


def get_decoder_input_ids(
	model_config,
	generation_config=None,
	task=None,
	language=None,
	return_timestamps=False,
):
	"""Helper function to get decoder input IDs for Whisper.

	Builds a list of ``(index, token_id)`` pairs (presumably forced decoder
	positions, as in Hugging Face's ``forced_decoder_ids`` — confirm with the
	caller).

	Args:
		model_config: Fallback configuration used when ``generation_config``
			is falsy.
		generation_config: Object providing ``is_multilingual`` (optional),
			``lang_to_id``, ``task_to_id`` and ``no_timestamps_token_id``.
		task: Key of ``task_to_id``; ``"transcribe"`` is used when ``None``.
		language: Language identifier; ``None`` adds no language entry.
		return_timestamps: When false, a no-timestamps entry is appended after
			the last entry, provided the list is non-empty.

	Returns:
		list[tuple[int, int]]: The ``(index, token_id)`` pairs; empty when the
		configuration is not multilingual.

	Raises:
		ValueError: If ``language`` is not a supported language.
		KeyError: If ``task`` is not a key of ``task_to_id``.
	"""
	generation_config = generation_config or model_config
	decoder_input_ids = []

	# NOTE(review): the task entry is emitted only for multilingual configs,
	# mirroring Hugging Face's Whisper generation logic — confirm upstream.
	if getattr(generation_config, "is_multilingual", None):
		if language is not None:
			language_token = _resolve_language_token(language, generation_config)
			decoder_input_ids.append((1, generation_config.lang_to_id[language_token]))
		task_name = task if task is not None else "transcribe"
		decoder_input_ids.append((2, generation_config.task_to_id[task_name]))

	if not return_timestamps and decoder_input_ids:
		last_index, last_token_id = decoder_input_ids[-1]
		# Compare the token id, not the index, so the no-timestamps token is
		# never duplicated and never skipped because of an index collision.
		if last_token_id != generation_config.no_timestamps_token_id:
			decoder_input_ids.append(
				(last_index + 1, generation_config.no_timestamps_token_id)
			)

	return decoder_input_ids