# Source code for easydel.inference.vwhisper.utils
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import typing as tp
import numpy as np
import requests
def chunk_iter_with_batch(
    audio_array: np.ndarray,
    chunk_length: int,
    stride_left: int,
    stride_right: int,
    batch_size: int,
    feature_extractor,
) -> tp.Iterator[tp.Dict[str, tp.Any]]:
    """
    Split an audio array into overlapping chunks and yield them in batches.

    Chunk start positions advance by
    ``chunk_length - stride_left - stride_right`` samples, so consecutive
    chunks overlap by the sum of the two strides. The chunks are grouped into
    ``ceil(num_chunks / batch_size)`` batches with ``np.array_split`` (batch
    sizes therefore differ by at most one) and every batch is passed through
    ``feature_extractor``.

    Args:
        audio_array: Input audio array; it is sliced along its first axis.
        chunk_length: Length of each chunk in samples.
        stride_left: Left stride in samples.
        stride_right: Right stride in samples.
        batch_size: Maximum number of chunks per yielded batch.
        feature_extractor: Callable exposing a ``sampling_rate`` attribute. It
            is called as ``feature_extractor(chunks, sampling_rate=...,
            return_tensors="np")`` and must return a mapping.

    Yields:
        One dict per batch holding the feature extractor's outputs plus a
        ``"stride"`` entry: a list of ``(chunk_len, stride_left, stride_right)``
        tuples, one per chunk. The left stride is 0 for the chunk starting at
        sample 0 and the right stride is 0 for chunks that reach the end of
        the audio.

    Raises:
        ValueError: If ``batch_size`` is not positive or if the strides leave
            no room to advance (``chunk_length <= stride_left + stride_right``).
    """
    step = chunk_length - stride_left - stride_right
    if step <= 0:
        raise ValueError(
            "chunk_length must be greater than stride_left + stride_right, got "
            f"chunk_length={chunk_length}, stride_left={stride_left}, "
            f"stride_right={stride_right}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    inputs_len = audio_array.shape[0]
    if inputs_len == 0:
        # Nothing to chunk; np.array_split would reject a request for 0 sections.
        return

    all_chunk_start_idx = np.arange(0, inputs_len, step)
    num_chunks = len(all_chunk_start_idx)
    num_batches = math.ceil(num_chunks / batch_size)

    for idx in np.array_split(np.arange(num_chunks), num_batches):
        chunk_start_idx = all_chunk_start_idx[idx]
        chunk_end_idx = chunk_start_idx + chunk_length
        # Slicing clamps at the end of the array, so trailing chunks may be
        # shorter than chunk_length.
        chunks = [
            audio_array[chunk_start:chunk_end]
            for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)
        ]
        processed = feature_extractor(
            chunks,
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="np",
        )

        # The first chunk has nothing to its left, so its left stride is 0.
        left_strides = np.where(chunk_start_idx == 0, 0, stride_left)
        # With a right stride, a chunk ending exactly at the last sample still
        # keeps its right stride (only chunks running past the end drop it);
        # without one, a chunk that ends exactly at the end counts as final.
        if stride_right > 0:
            reaches_end = chunk_end_idx > inputs_len
        else:
            reaches_end = chunk_end_idx >= inputs_len
        right_strides = np.where(reaches_end, 0, stride_right)

        yield {
            "stride": [
                (chunk.shape[0], left, right)
                for chunk, left, right in zip(chunks, left_strides, right_strides)
            ],
            **processed,
        }
def process_audio_input(
    audio_input: tp.Union[
        str, bytes, np.ndarray, tp.Dict[str, tp.Union[np.ndarray, int]]
    ],
    feature_extractor,
) -> tp.Tuple[np.ndarray, tp.Optional[tp.Tuple[int, int, int]]]:
    """
    Process audio input into a numpy array with correct sampling rate.

    Accepted inputs:
        * ``str``: an ``http://`` / ``https://`` URL (downloaded) or a local
          file path (read as bytes);
        * ``bytes``: encoded audio, decoded with ``ffmpeg_read`` at
          ``feature_extractor.sampling_rate``;
        * ``dict``: must contain ``"array"`` and ``"sampling_rate"`` and may
          contain ``"stride"``; the array is resampled with ``librosa`` when
          its rate differs from ``feature_extractor.sampling_rate``;
        * ``np.ndarray``: used as-is.

    Args:
        audio_input: Input audio in one of the formats above.
        feature_extractor: Object exposing the target ``sampling_rate``.

    Returns:
        Tuple of ``(audio_array, stride)``. ``stride`` is ``None`` unless the
        input dict carried a ``"stride"`` entry, in which case it is
        ``(num_samples, stride_left, stride_right)`` with both strides
        rescaled to the target sampling rate.

    Raises:
        requests.HTTPError: If downloading a URL returns an error status.
        ImportError: If resampling is required but ``librosa`` is missing.
        ValueError: If a dict input lacks the required keys, the resulting
            audio is not a one-dimensional ``np.ndarray``, or the strides
            exceed the audio length.
    """
    stride = None
    # Factor converting sample counts from the input rate to the target rate.
    ratio = 1

    if isinstance(audio_input, str):
        if audio_input.startswith(("http://", "https://")):
            # Bound the wait and fail on HTTP errors instead of handing an
            # error page to the audio decoder.
            response = requests.get(audio_input, timeout=60)
            response.raise_for_status()
            audio_input = response.content
        else:
            with open(audio_input, "rb") as f:
                audio_input = f.read()

    if isinstance(audio_input, bytes):
        from transformers.pipelines.audio_utils import ffmpeg_read

        audio_input = ffmpeg_read(audio_input, feature_extractor.sampling_rate)

    if isinstance(audio_input, dict):
        stride = audio_input.get("stride", None)
        if not ("sampling_rate" in audio_input and "array" in audio_input):
            raise ValueError(
                "When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
                "containing the numpy array representing the audio, and a 'sampling_rate' key "
                "containing the sampling rate associated with the audio array."
            )
        in_sampling_rate = audio_input.get("sampling_rate")
        audio_input = audio_input.get("array", None)
        if in_sampling_rate != feature_extractor.sampling_rate:
            try:
                import librosa  # type:ignore
            except ImportError as err:
                raise ImportError(
                    "To support resampling audio files, please install 'librosa' and 'soundfile'."
                ) from err
            audio_input = librosa.resample(
                audio_input,
                orig_sr=in_sampling_rate,
                target_sr=feature_extractor.sampling_rate,
            )
            ratio = feature_extractor.sampling_rate / in_sampling_rate

    if not isinstance(audio_input, np.ndarray):
        raise ValueError(f"We expect a numpy ndarray as input, got `{type(audio_input)}`")
    if len(audio_input.shape) != 1:
        raise ValueError(
            "We expect a single channel audio input for AutomaticSpeechRecognitionPipeline"
        )

    if stride is not None:
        # Rescale first so the bounds check compares like with like: the
        # strides are given at the input rate, while the array has already
        # been resampled to the target rate.
        scaled_left = int(round(stride[0] * ratio))
        scaled_right = int(round(stride[1] * ratio))
        if scaled_left + scaled_right > audio_input.shape[0]:
            raise ValueError("Stride is too large for input")
        # Keep the full length alongside the strides in the returned tuple.
        stride = (audio_input.shape[0], scaled_left, scaled_right)

    return audio_input, stride