# Source code for easydel.inference.vwhisper.generation
# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flax import nnx as nn
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from easydel.utils.compiling_utils import ejit
@ejit(static_argnames=["graphdef", "inference_config", "return_timestamps"])
def _compiled_generate(
    graphdef,
    graphstate,
    inference_config,
    input_features,
    decoder_input_ids,
    return_timestamps,
):
    """Rebuild the model from its graph parts and run forced generation.

    ``graphdef``, ``inference_config`` and ``return_timestamps`` are declared
    static through the ``ejit`` decorator (presumably a JIT-compilation
    wrapper -- confirm in ``easydel.utils.compiling_utils``).

    Args:
        graphdef: Graph definition, merged with ``graphstate`` via ``nn.merge``.
        graphstate: Graph state, merged with ``graphdef`` via ``nn.merge``.
        inference_config: Object whose ``generation_config`` attribute is
            forwarded to the model.
        input_features: Forwarded unchanged as ``input_features``.
        decoder_input_ids: Forwarded as ``forced_decoder_ids``.
        return_timestamps: Forwarded unchanged as ``return_timestamps``.

    Returns:
        Whatever the merged model's ``_force_generate`` returns.
    """
    module = nn.merge(graphdef, graphstate)
    # The generation call runs inside the model's own mesh context.
    with module.mesh:
        generated = module._force_generate(
            input_features=input_features,
            forced_decoder_ids=decoder_input_ids,
            return_timestamps=return_timestamps,
            generation_config=inference_config.generation_config,
        )
    return generated
def get_decoder_input_ids(
    model_config,
    generation_config=None,
    task=None,
    language=None,
    return_timestamps=False,
):
    """Build the list of forced ``(index, token_id)`` pairs for Whisper decoding.

    Args:
        model_config: Fallback configuration, used when ``generation_config``
            is falsy.
        generation_config: Configuration providing ``is_multilingual``,
            ``lang_to_id``, ``task_to_id`` and ``no_timestamps_token_id``.
        task: Key into ``generation_config.task_to_id``. When ``None``,
            ``"transcribe"`` is used.
        language: Optional language, matched case-insensitively. Accepted
            forms are a key of ``generation_config.lang_to_id``, a value of
            ``TO_LANGUAGE_CODE`` (wrapped as ``<|code|>``), or a key of
            ``TO_LANGUAGE_CODE``.
        return_timestamps: When falsy, the no-timestamps token is appended
            after the last forced entry, provided at least one entry exists.

    Returns:
        A list of ``(index, token_id)`` tuples. It is empty when the
        configuration is not multilingual.

    Raises:
        ValueError: If ``language`` matches none of the accepted forms.

    Lookup errors raised by ``lang_to_id`` / ``task_to_id`` propagate
    unchanged.
    """
    generation_config = generation_config or model_config
    is_multilingual = getattr(generation_config, "is_multilingual", None)
    decoder_input_ids = []

    if is_multilingual:
        if language is not None:
            language = language.lower()
            if language in generation_config.lang_to_id:
                language_token = language
            elif language in TO_LANGUAGE_CODE.values():
                language_token = f"<|{language}|>"
            elif language in TO_LANGUAGE_CODE:
                language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
            else:
                # Suggest the spelling family the caller most likely intended.
                if len(language) == 2:
                    acceptable_languages = list(TO_LANGUAGE_CODE.values())
                elif "<" in language or "|" in language or ">" in language:
                    acceptable_languages = list(generation_config.lang_to_id)
                else:
                    acceptable_languages = list(TO_LANGUAGE_CODE)
                raise ValueError(f"Unsupported language: {language}. Language should be one of: {acceptable_languages}.")
            decoder_input_ids.append((1, generation_config.lang_to_id[language_token]))

        task_key = "transcribe" if task is None else task
        decoder_input_ids.append((2, generation_config.task_to_id[task_key]))

    no_timestamps_id = generation_config.no_timestamps_token_id
    # Compare the last forced *token id* (element 1) with the no-timestamps
    # token; element 0 is the slot index and is not comparable to a token id.
    if (
        not return_timestamps
        and decoder_input_ids
        and decoder_input_ids[-1][1] != no_timestamps_id
    ):
        next_idx = decoder_input_ids[-1][0] + 1
        decoder_input_ids.append((next_idx, no_timestamps_id))

    return decoder_input_ids