Source code for easydel.inference.esurge.utils

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions and classes for the eSurge engine.

Provides an immutable list wrapper, a list type guard, and small helpers for
integer rounding, dtype sizing, list chunking, and token truncation.

Classes:
    ConstantList: Immutable list wrapper that prevents modifications

Functions:
    is_list_of: Type guard checking that a value is a list of a given type
    chunk_list: Yield successive fixed-size chunks of a list
    cdiv: Ceiling division
    next_power_of_2: Smallest power of two >= n (inclusive)
    prev_power_of_2: Largest power of two <= n (inclusive)
    round_up: Round up to a multiple of a given value
    round_down: Round down to a multiple of a given value
    get_dtype_size: Size of a data type in bytes
    truncate_tokens: Truncate a token sequence to a target length

Example:
    >>> from easydel.inference.esurge.utils import ConstantList
    >>>
    >>> # Create immutable list
    >>> const_list = ConstantList([1, 2, 3])
    >>> print(const_list[0])  # Works
    >>> const_list.append(4)  # Raises Exception
"""

from collections.abc import Iterator, Sequence
from typing import Generic, Literal, TypeVar, overload

from jax import numpy as jnp
from typing_extensions import TypeIs

# Generic element type shared by ConstantList, is_list_of, and chunk_list.
T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    """Immutable list wrapper that prevents modifications.

    Provides read-only access to a list while rejecting the in-place mutation
    operations of ``list`` (``append``, ``extend``, ``insert``, ``pop``,
    ``remove``, ``clear``, item assignment and item deletion). The wrapped
    list is stored by reference, not copied, so changes made through the
    original list object remain visible through the wrapper.

    Args:
        x: The list to wrap and make immutable.

    Example:
        >>> const_list = ConstantList([1, 2, 3])
        >>> print(const_list[0])  # 1
        >>> print(len(const_list))  # 3
        >>> const_list.append(4)  # Raises Exception
    """

    def __init__(self, x: list[T]) -> None:
        """Initialize with a list to make immutable.

        Args:
            x: List to wrap (kept by reference).
        """
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, *args, **kwargs):
        # Accept any arguments so list-style calls such as ``insert(0, x)``
        # reach this error instead of failing with a signature TypeError.
        raise Exception("Cannot insert into a constant list")

    def pop(self, *args, **kwargs):
        # Same reasoning as ``insert``: ``pop()`` and ``pop(i)`` must both
        # produce the constant-list error.
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self, item: T, start: int = 0, stop: int | None = None) -> int:
        """Return the first index of ``item`` within ``[start, stop)``.

        Raises:
            ValueError: If ``item`` is not present in that range.
        """
        return self._x.index(item, start, stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T: ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]: ...

    def __getitem__(self, item: int | slice) -> T | list[T]:
        # Slicing returns a plain list slice of the wrapped list.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T) -> None: ...

    @overload
    def __setitem__(self, s: slice, value: T, /) -> None: ...

    def __setitem__(self, item: int | slice, value: T | list[T]) -> None:
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self) -> Iterator[T]:
        return iter(self._x)

    def __contains__(self, item) -> bool:
        return item in self._x

    def __len__(self) -> int:
        return len(self._x)

    def __repr__(self) -> str:
        return f"ConstantList({self._x})"
def is_list_of(
    value: object,
    typ: type[T] | tuple[type[T], ...],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """Check whether ``value`` is a list whose elements are instances of ``typ``.

    Args:
        value: Object to test.
        typ: Type, or tuple of types, passed to ``isinstance`` for elements.
        check: ``"first"`` inspects only the first element, which is cheap.
            ``"all"`` inspects every element. With either mode an empty list
            passes.

    Returns:
        ``True`` if ``value`` is a ``list`` that satisfies the element check,
        otherwise ``False``.

    Raises:
        ValueError: If ``check`` is neither ``"first"`` nor ``"all"``.
    """
    # Reject unknown modes explicitly rather than falling off the end of the
    # function and returning an implicit (non-bool) ``None``.
    if check not in ("first", "all"):
        raise ValueError(f"Unknown check mode: {check!r}")
    if not isinstance(value, list):
        return False
    if check == "first":
        return len(value) == 0 or isinstance(value[0], typ)
    return all(isinstance(v, typ) for v in value)
def chunk_list(lst: list[T], chunk_size: int) -> Iterator[list[T]]:
    """Yield successive ``chunk_size``-sized chunks from ``lst``.

    Each chunk is a slice of ``lst``. The final chunk is shorter when
    ``len(lst)`` is not a multiple of ``chunk_size``. An empty ``lst`` yields
    nothing.

    Args:
        lst: Sequence to split.
        chunk_size: Number of items per chunk; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    # A zero step would fail inside ``range`` with an opaque message, and a
    # negative step would silently yield nothing.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    for start in range(0, len(lst), chunk_size):
        yield lst[start : start + chunk_size]
def cdiv(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` using only integer arithmetic.

    Floor-dividing by the negated divisor and negating the result rounds
    toward positive infinity without going through floating point.
    """
    negated_floor = a // -b
    return -negated_floor
def next_power_of_2(n) -> int:
    """Return the smallest power of two that is ``>= n`` (inclusive).

    Any ``n`` below 1 maps to 1.
    """
    return 1 if n < 1 else 1 << (n - 1).bit_length()
def prev_power_of_2(n: int) -> int:
    """Return the largest power of two that is ``<= n`` (inclusive).

    Non-positive ``n`` maps to 0.
    """
    return 0 if n <= 0 else 1 << (n.bit_length() - 1)
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y`` (for positive ``y``)."""
    block_count = (x + y - 1) // y
    return block_count * y
def round_down(x: int, y: int) -> int:
    """Round ``x`` down to the nearest multiple of ``y`` (floor semantics)."""
    whole_blocks = x // y
    return whole_blocks * y
def get_dtype_size(dtype: "jnp.dtype | type | str") -> int:
    """Get the size of the data type in bytes.

    Floating and integer dtypes are sized from the ``finfo``/``iinfo`` bit
    width using integer division by 8. A dtype narrower than 8 bits therefore
    reports 0. Any other dtype kind (for example bool or complex), which
    ``iinfo`` cannot describe, falls back to the dtype's ``itemsize``.

    Args:
        dtype: A dtype or dtype-like value (e.g. ``jnp.float32``).

    Returns:
        Number of bytes per element.
    """
    if jnp.issubdtype(dtype, jnp.floating):
        return jnp.finfo(dtype).bits // 8
    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.iinfo(dtype).bits // 8
    # bool/complex/etc.: ``iinfo`` would raise here, so use the storage size.
    return jnp.dtype(dtype).itemsize
def truncate_tokens(tokens, target_len: int, mode: str = "left"):
    """Truncate ``tokens`` to at most ``target_len`` items.

    Args:
        tokens: Sliceable token sequence, e.g. a list of token ids. The
            ``"middle"`` mode joins two slices with ``+``, so it assumes
            list-style concatenation. NOTE: numpy/JAX arrays would be added
            elementwise instead.
        target_len: Maximum number of tokens to keep. Negative values are
            treated as 0.
        mode: Which tokens to drop:
            - ``"left"`` drops from the start and keeps the tail.
            - ``"right"`` drops from the end and keeps the head.
            - ``"middle"`` keeps ``ceil(target_len / 2)`` leading tokens plus
              the remaining budget of trailing tokens.

    Returns:
        ``(tokens, dropped)``: the possibly truncated sequence and the number
        of tokens removed. When no truncation is needed, the input object
        itself is returned together with ``0``.

    Raises:
        ValueError: If ``mode`` is not ``"left"``, ``"right"`` or ``"middle"``.
    """
    # Validate eagerly so a bad mode is reported even when the input is short
    # enough that no truncation happens.
    if mode not in ("left", "right", "middle"):
        raise ValueError(f"Unknown truncate_mode: {mode}")
    # A negative target would make ``tokens[:target_len]`` slice from the end
    # and keep tokens that should have been dropped.
    target_len = max(target_len, 0)
    n = len(tokens)
    if n <= target_len:
        return tokens, 0
    drop = n - target_len
    if mode == "left":
        return tokens[drop:], drop
    if mode == "right":
        return tokens[:target_len], drop
    keep_left = (target_len + 1) // 2
    keep_right = target_len - keep_left
    return tokens[:keep_left] + tokens[n - keep_right :], drop