Source code for easydel.infra.base_config

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import typing as tp
import warnings
from dataclasses import dataclass

import chex
import jax
import jax.extend
import jax.tree_util
from eformer.escale import PartitionAxis
from flax import nnx as nn
from jax import numpy as jnp
from transformers.configuration_utils import PretrainedConfig

from easydel.utils.compiling_utils import hash_fn
from easydel.utils.helpers import get_logger

from .etils import (
	AVAILABLE_ATTENTION_MECHANISMS,
	DEFAULT_ATTENTION_MECHANISM,
	EasyDeLBackends,
	EasyDeLGradientCheckPointers,
	EasyDeLPlatforms,
	EasyDeLQuantizationMethods,
)

if tp.TYPE_CHECKING:
	from easydel.layers.rotary_embedding import RopeConfig

	from .utils import ModuleCaches
else:
	RopeConfig = tp.Any
	ModuleCaches = tp.Any
logger = get_logger(__name__)

# File name used for serialized model parameters.
FLAX_WEIGHTS_NAME = "easydel-model.parameters"

# Default values for the `pallas_{m,k,n}_block_size` configuration options.
DEFAULT_PALLAS_M_BLOCK_SIZE = 128
DEFAULT_PALLAS_K_BLOCK_SIZE = 128
DEFAULT_PALLAS_N_BLOCK_SIZE = 128

# Environment switches: an unset variable counts as "false"; the values
# "true", "1" and "yes" (case-insensitive) turn the switch on.
_TRUTHY_ENV_VALUES = ("true", "1", "yes")

ED_DEFAULT_HARDWARE_ABSTRACTION = (
	os.getenv("ED_DEFAULT_HARDWARE_ABSTRACTION", "false").lower() in _TRUTHY_ENV_VALUES
)
EKERNEL_OPS = os.getenv("EKERNEL_OPS", "false").lower() in _TRUTHY_ENV_VALUES

# Hardware abstraction is off unless the environment switch enables it.
DEFAULT_HARDWARE_ABSTRACTION = True if ED_DEFAULT_HARDWARE_ABSTRACTION else False

if DEFAULT_HARDWARE_ABSTRACTION:
	logger.info("HARDWARE_ABSTRACTION is ON by default")

if EKERNEL_OPS:
	logger.info(
		"`EKERNEL_OPS` is ON and some operations will automatically be replaced by EasyDeL."
	)
	# Imported lazily so the kernel module is only loaded when the switch is on.
	from easydel.kernels.matmul import replace_dot_general_with_matmul

	replace_dot_general_with_matmul()




def set_attrs_smartly(self, attr_name: str, default: tp.Any, new_attr: tp.Any):
	"""Set `attr_name` on `self`, using `Ellipsis` as the "not provided" marker.

	Args:
		self: Object whose attribute is being initialized or updated.
		attr_name: Name of the attribute to set.
		default: Value assigned only when `self` does not have `attr_name` yet.
		new_attr: Value to assign unconditionally, unless it is `Ellipsis`
			(`...`), in which case the current/default value is kept.
	"""
	if not hasattr(self, attr_name):
		setattr(self, attr_name, default)
	# Identity check on the sentinel: `==` would be dispatched to `new_attr`,
	# and array-like values return an element-wise result whose truth value
	# is ambiguous (raising instead of answering "is this the sentinel?").
	if new_attr is not Ellipsis:
		setattr(self, attr_name, new_attr)
@dataclass
class EasyMethod:
	"""String constants naming the modes accepted by the `easy_method` option."""

	TRAIN: str = "train"
	SERVE: str = "serve"
	# NOTE(review): EVAL shares the value "serve" with SERVE; confirm whether
	# this aliasing is intentional before relying on EVAL being distinct.
	EVAL: str = "serve"
	CONVERT: str = "convert"
warnings.filterwarnings( "ignore", message="Passing `gradient_checkpointing` to a config initialization is deprecated", # EasyDeL will handle this ) warnings.filterwarnings("ignore", message="You are using a model of type")
class EasyDeLBaseConfigDict(tp.TypedDict, total=False):
	"""Typed dictionary of EasyDeL configuration options.

	Every key is optional (`total=False`). The keys mirror the keyword
	arguments of `EasyDeLBaseConfig.__init__`, plus the two
	`*_max_position_embeddings` overrides read by the config's
	`granted_*_max_position_embedding` properties.
	"""

	axis_dims: tp.Sequence[int]
	dcn_axis_dims: tp.Optional[tp.Sequence[int]]
	axis_names: tp.Sequence[str]
	attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS
	blocksize_k: int
	blocksize_q: int
	blocksize_b: int
	partition_axis: PartitionAxis
	shard_attention_computation: bool
	use_sharded_kv_caching: bool
	use_sharding_constraint: bool
	backend: tp.Optional[EasyDeLBackends]
	platform: tp.Optional[EasyDeLPlatforms]
	easy_method: tp.Literal["train", "serve", "convert"]
	bits: tp.Optional[int]
	scan_ring_attention: bool
	scan_attention_layers: bool
	use_scan_mlp: bool
	scan_mlp_chunk_size: int
	sequence_axis_name: str
	gradient_checkpointing: EasyDeLGradientCheckPointers
	kv_cache_quantization_method: EasyDeLQuantizationMethods
	kv_cache_quantization_blocksize: int
	# Linear-layer quantization options accepted by `EasyDeLBaseConfig.__init__`.
	quantization_method: EasyDeLQuantizationMethods
	quantization_pattern: str
	quantization_blocksize: int
	kv_cache_sharding_sequence_axis_name: tp.Union[str, tp.Tuple[str, ...]]
	flash_attention_backward_pass_impl: tp.Literal["triton", "xla"]
	attn_dtype: jnp.dtype
	attn_softmax_dtype: jnp.dtype
	fcm_max_ratio: float
	fcm_min_ratio: float
	hardware_abstraction: bool
	pallas_m_block_size: int
	pallas_k_block_size: int
	pallas_n_block_size: int
	mask_max_position_embeddings: int
	freq_max_position_embeddings: int


class EasyDeLBaseConfig(PretrainedConfig):
	"""
	Base configuration for EasyDeL models.

	Extends `PretrainedConfig` with mesh/sharding, attention, quantization and
	kernel options. Each option is stored with `getattr(self, name, value)`, so
	an attribute that already exists on the instance (or its class) takes
	precedence over the constructor argument.

	Args:
	    axis_dims (tp.Sequence[int]): Dimensions of the axes. Default is (1, -1, 1, 1).
	    dcn_axis_dims (tp.Optional[tp.Sequence[int]]): Forwarded to the mesh builder as
	        `dcn_mesh_dims`. Default is None.
	    axis_names (tp.Sequence[str]): Names of the axes. Default is ("dp", "fsdp", "tp", "sp").
	    attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS): Attention mechanism to use.
	        Default is DEFAULT_ATTENTION_MECHANISM.
	    blocksize_k (int): Block size for key. Default is 128.
	    blocksize_q (int): Block size for query. Default is 128.
	    blocksize_b (int): Block size for batch. Default is 1.
	    partition_axis (PartitionAxis): Partition axis configuration. Default is PartitionAxis().
	    shard_attention_computation (bool): Whether to shard attention computation. Default is True.
	    use_sharded_kv_caching (bool): Whether to use sharded key-value caching. Default is False.
	    use_sharding_constraint (bool): Whether to use sharding constraint. Default is False.
	    backend (tp.Optional[EasyDeLBackends]): Backend to use. When None,
	        `jax.default_backend()` is used.
	    platform (tp.Optional[EasyDeLPlatforms]): Platform to use. When None, "triton" is used
	        if the default JAX backend is "gpu", otherwise "jax".
	    easy_method (tp.Literal["train", "serve", "convert"]): Method to use.
	        Default is EasyMethod.TRAIN.
	    bits (tp.Optional[int]): Number of bits for quantization. Default is None.
	    scan_ring_attention (bool): Whether to scan ring attention. Default is True.
	    scan_attention_layers (bool): Whether to scan attention layers. Default is False.
	    use_scan_mlp (bool): Whether to use scan MLP. Default is False.
	    scan_mlp_chunk_size (int): Chunk size for scan MLP. Default is 1024.
	    sequence_axis_name (str): Name of the attention axis. Default is "sp".
	    gradient_checkpointing (EasyDeLGradientCheckPointers): Gradient checkpointing method.
	        Default is EasyDeLGradientCheckPointers.NONE.
	    kv_cache_quantization_method (EasyDeLQuantizationMethods): Key-value cache quantization
	        method. Default is EasyDeLQuantizationMethods.NONE.
	    kv_cache_quantization_blocksize (int): Block size for key-value cache quantization.
	        Default is 64.
	    quantization_method (EasyDeLQuantizationMethods): Quantization method.
	        Default is EasyDeLQuantizationMethods.NONE.
	    quantization_pattern (str): Pattern for quantization. Default is ".*".
	    quantization_blocksize (int): Block size for quantization. Default is 64.
	    kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]]): Name of the
	        key-value cache sharding sequence axis. Default is "sp".
	    flash_attention_backward_pass_impl (tp.Literal["triton", "xla"]): Implementation for
	        flash attention backward pass. Default is "triton".
	    attn_dtype (jnp.dtype): Data type for attention. Default is jnp.float32.
	    attn_softmax_dtype (jnp.dtype): Data type for softmax ops in attention.
	        Default is jnp.float32.
	    fcm_max_ratio (float): Maximum ratio for FCM. Default is 0.0.
	    fcm_min_ratio (float): Minimum ratio for FCM. Default is 0.0.
	    hardware_abstraction (bool): Whether to use hardware abstraction.
	        Default is DEFAULT_HARDWARE_ABSTRACTION.
	    pallas_m_block_size (int): Block size for Pallas M. Default is DEFAULT_PALLAS_M_BLOCK_SIZE.
	    pallas_k_block_size (int): Block size for Pallas K. Default is DEFAULT_PALLAS_K_BLOCK_SIZE.
	    pallas_n_block_size (int): Block size for Pallas N. Default is DEFAULT_PALLAS_N_BLOCK_SIZE.
	    **kwargs: Additional keyword arguments, forwarded to `PretrainedConfig.__init__`.

	Warns:
	    UserWarning: If `kv_cache_quantization_method` is not NONE and
	        `use_sharded_kv_caching` is True.
	"""

	_show_private_attrs: bool = False

	def __init__(
		self,
		axis_dims: tp.Sequence[int] = (1, -1, 1, 1),
		dcn_axis_dims: tp.Optional[tp.Sequence[int]] = None,
		axis_names: tp.Sequence[str] = ("dp", "fsdp", "tp", "sp"),
		attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = DEFAULT_ATTENTION_MECHANISM,
		blocksize_k: int = 128,
		blocksize_q: int = 128,
		blocksize_b: int = 1,
		partition_axis: PartitionAxis = PartitionAxis(),  # noqa
		shard_attention_computation: bool = True,
		use_sharded_kv_caching: bool = False,
		use_sharding_constraint: bool = False,
		backend: tp.Optional[EasyDeLBackends] = None,
		platform: tp.Optional[EasyDeLPlatforms] = None,
		easy_method: tp.Literal["train", "serve", "convert"] = EasyMethod.TRAIN,
		bits: tp.Optional[int] = None,
		scan_ring_attention: bool = True,
		scan_attention_layers: bool = False,
		use_scan_mlp: bool = False,
		scan_mlp_chunk_size: int = 1024,
		sequence_axis_name: str = "sp",
		gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE,
		kv_cache_quantization_method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE,
		kv_cache_quantization_blocksize: int = 64,
		quantization_method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE,
		quantization_pattern: str = ".*",
		quantization_blocksize: int = 64,
		kv_cache_sharding_sequence_axis_name: tp.Union[str, tp.Tuple[str, ...]] = "sp",
		flash_attention_backward_pass_impl: tp.Literal["triton", "xla"] = "triton",
		attn_dtype: jnp.dtype = jnp.float32,
		attn_softmax_dtype: jnp.dtype = jnp.float32,
		fcm_max_ratio: float = 0.0,
		fcm_min_ratio: float = 0.0,
		hardware_abstraction: bool = DEFAULT_HARDWARE_ABSTRACTION,
		pallas_m_block_size: int = DEFAULT_PALLAS_M_BLOCK_SIZE,
		pallas_k_block_size: int = DEFAULT_PALLAS_K_BLOCK_SIZE,
		pallas_n_block_size: int = DEFAULT_PALLAS_N_BLOCK_SIZE,
		**kwargs,
	):
		# Every option uses `getattr(self, name, argument)`: an attribute that is
		# already present on the instance/class wins over the argument.
		self.axis_dims = getattr(self, "axis_dims", axis_dims)
		self.dcn_axis_dims = getattr(self, "dcn_axis_dims", dcn_axis_dims)
		self.axis_names = getattr(self, "axis_names", axis_names)
		self.backend = getattr(
			self,
			"backend",
			backend if backend is not None else jax.default_backend(),
		)
		self.platform = getattr(
			self,
			"platform",
			platform
			if platform is not None
			else ("triton" if jax.default_backend() == "gpu" else "jax"),
		)
		# fmt:off
		self.easy_method = getattr(self, "easy_method", easy_method)
		self.attn_mechanism = getattr(self, "attn_mechanism", attn_mechanism)
		self.blocksize_b = getattr(self, "blocksize_b", blocksize_b)
		self.blocksize_k = getattr(self, "blocksize_k", blocksize_k)
		self.blocksize_q = getattr(self, "blocksize_q", blocksize_q)
		self.partition_axis = getattr(self, "partition_axis", partition_axis)
		self.shard_attention_computation = getattr(self, "shard_attention_computation", shard_attention_computation)
		self.bits = getattr(self, "bits", bits)
		self.scan_attention_layers = getattr(self, "scan_attention_layers", scan_attention_layers)
		self.scan_ring_attention = getattr(self, "scan_ring_attention", scan_ring_attention)
		self.use_sharded_kv_caching = getattr(self, "use_sharded_kv_caching", use_sharded_kv_caching)
		self.use_scan_mlp = getattr(self, "use_scan_mlp", use_scan_mlp)
		self.scan_mlp_chunk_size = getattr(self, "scan_mlp_chunk_size", scan_mlp_chunk_size)
		self.use_sharding_constraint = getattr(self, "use_sharding_constraint", use_sharding_constraint)
		self.sequence_axis_name = getattr(self, "sequence_axis_name", sequence_axis_name)
		self.kv_cache_sharding_sequence_axis_name = getattr(self, "kv_cache_sharding_sequence_axis_name", kv_cache_sharding_sequence_axis_name)
		self.gradient_checkpointing = getattr(self, "gradient_checkpointing", gradient_checkpointing)
		self.kv_cache_quantization_method = getattr(self, "kv_cache_quantization_method", kv_cache_quantization_method)
		self.kv_cache_quantization_blocksize = getattr(self, "kv_cache_quantization_blocksize", kv_cache_quantization_blocksize)
		self.quantization_method = getattr(self, "quantization_method", quantization_method)
		self.quantization_blocksize = getattr(self, "quantization_blocksize", quantization_blocksize)
		self.quantization_pattern = getattr(self, "quantization_pattern", quantization_pattern)
		self.flash_attention_backward_pass_impl = getattr(self, "flash_attention_backward_pass_impl", flash_attention_backward_pass_impl)
		self.attn_dtype = getattr(self, "attn_dtype", attn_dtype)
		self.attn_softmax_dtype = getattr(self, "attn_softmax_dtype", attn_softmax_dtype)
		self.fcm_max_ratio = getattr(self, "fcm_max_ratio", fcm_max_ratio)
		self.fcm_min_ratio = getattr(self, "fcm_min_ratio", fcm_min_ratio)
		self.hardware_abstraction = getattr(self, "hardware_abstraction", hardware_abstraction)
		self.pallas_m_block_size = getattr(self, "pallas_m_block_size", pallas_m_block_size)
		self.pallas_k_block_size = getattr(self, "pallas_k_block_size", pallas_k_block_size)
		self.pallas_n_block_size = getattr(self, "pallas_n_block_size", pallas_n_block_size)
		# fmt:on
		self.pretraining_tp = 1  # it's for pytorch models.
		if (
			self.kv_cache_quantization_method != EasyDeLQuantizationMethods.NONE
			and self.use_sharded_kv_caching
		):
			# Rebound to a local so the `{name=}` f-string prints a short name.
			use_sharded_kv_caching = self.use_sharded_kv_caching
			warnings.warn(
				f"`{self.kv_cache_quantization_method=}` and `{use_sharded_kv_caching=}`"
				" can't be used together at the moment.",
				stacklevel=1,
			)
		super().__init__(**kwargs)

	@staticmethod
	def create_mesh(
		axis_dims: tp.Sequence[int] = (1, -1, 1, 1),
		axis_names: tp.Sequence[str] = ("dp", "fsdp", "tp", "sp"),
		dcn_axis_dims: tp.Optional[tp.Sequence[int]] = None,
		process_is_granule: bool = False,
		should_sort_granules_by_key: bool = True,
		allow_split_physical_axes: bool = True,
		backend: tp.Optional[str] = None,
	):
		"""
		Create a mesh by delegating to `eformer.escale.create_mesh`.

		All arguments are forwarded unchanged, except that `dcn_axis_dims` is
		passed as `dcn_mesh_dims` and an empty-string `backend` is converted
		to None.

		Returns:
		    The object returned by `eformer.escale.create_mesh`.
		"""
		from eformer.escale import create_mesh

		if backend == "":
			backend = None
		return create_mesh(
			axis_dims=axis_dims,
			axis_names=axis_names,
			dcn_mesh_dims=dcn_axis_dims,
			process_is_granule=process_is_granule,
			should_sort_granules_by_key=should_sort_granules_by_key,
			allow_split_physical_axes=allow_split_physical_axes,
			backend=backend,
		)

	@property
	def mesh(self):
		"""Mesh built from this config's `axis_dims`, `axis_names` and `dcn_axis_dims`.

		Dict-valued axis settings are reduced to their values. The optional
		attributes `should_sort_granules_by_key` and `allow_split_physical_axes`
		default to True when missing or None; a missing or None `backend` is
		passed as "" (which `create_mesh` turns into None).

		Returns:
		    The result of `create_mesh`.
		"""

		def _values_if_dict(value):
			return list(value.values()) if isinstance(value, dict) else value

		def _flag_or_true(name):
			flag = getattr(self, name, None)
			return True if flag is None else flag

		backend = getattr(self, "backend", None)
		return self.create_mesh(
			axis_dims=_values_if_dict(self.axis_dims),
			axis_names=_values_if_dict(self.axis_names),
			dcn_axis_dims=_values_if_dict(self.dcn_axis_dims),
			should_sort_granules_by_key=_flag_or_true("should_sort_granules_by_key"),
			allow_split_physical_axes=_flag_or_true("allow_split_physical_axes"),
			backend="" if backend is None else backend,
		)

	def get_mesh(self):
		"""Return `self.mesh`; method form of the `mesh` property."""
		return self.mesh

	def jax_mesh(self):
		"""Deprecated alias: emit a warning and return the mesh via `get_mesh`."""
		warnings.warn("`jax_mesh` is deprecated use `get_mesh` or `mesh`", stacklevel=1)
		return self.get_mesh()

	def get_partition_rules(self, *args, **kwargs):
		"""
		Get the partition rules for the model.

		Not implemented here; always raises.

		Returns:
		    `tp.Tuple[tp.Tuple[str, PartitionSpec]]`: The partition rules.

		Raises:
		    NotImplementedError: Always, in this base class.
		"""
		raise NotImplementedError("`get_partition_rules` is not implemented.")

	def get_axis_dims(self) -> tp.Sequence[int]:
		"""Return `self.axis_dims`."""
		return self.axis_dims

	def get_axis_names(self) -> tp.Sequence[str]:
		"""Return `self.axis_names`."""
		return self.axis_names

	def get_backend(self) -> str:
		"""Return the configured backend, or the current JAX backend platform.

		Returns:
		    `self.backend` when it is set; when it is missing, None or "",
		    `jax.extend.backend.get_backend().platform`.
		"""
		backend = getattr(self, "backend", None)
		# `add_basic_configurations` may leave `backend` as None, so None is
		# treated the same as the empty string.
		if backend is None or backend == "":
			return jax.extend.backend.get_backend().platform
		return backend

	def add_basic_configurations(
		self,
		axis_dims: tp.Sequence[int] = ...,
		dcn_axis_dims: tp.Optional[tp.Sequence[int]] = ...,
		axis_names: tp.Sequence[str] = ...,
		attn_mechanism: AVAILABLE_ATTENTION_MECHANISMS = ...,
		blocksize_k: int = ...,
		blocksize_q: int = ...,
		blocksize_b: int = ...,
		partition_axis: PartitionAxis = ...,
		shard_attention_computation: bool = ...,
		use_sharded_kv_caching: bool = ...,
		backend: tp.Optional[EasyDeLBackends] = ...,
		platform: tp.Optional[EasyDeLPlatforms] = ...,
		easy_method: tp.Literal["train", "serve", "convert"] = ...,
		bits: tp.Optional[int] = ...,
		scan_ring_attention: bool = ...,
		scan_attention_layers: bool = ...,
		use_sharding_constraint: bool = ...,
		use_scan_mlp: bool = ...,
		scan_mlp_chunk_size: int = ...,
		sequence_axis_name: str = ...,
		gradient_checkpointing: EasyDeLGradientCheckPointers = ...,
		kv_cache_quantization_method: EasyDeLQuantizationMethods = ...,
		kv_cache_quantization_blocksize: int = ...,
		quantization_method: EasyDeLQuantizationMethods = ...,
		quantization_blocksize: int = ...,
		quantization_pattern: str = ...,
		kv_cache_sharding_sequence_axis_name: tp.Union[str, tp.Tuple[str, ...]] = ...,
		flash_attention_backward_pass_impl: tp.Literal["triton", "xla"] = ...,
		attn_dtype: jnp.dtype = ...,
		attn_softmax_dtype: jnp.dtype = ...,
		hardware_abstraction: bool = ...,
		pallas_m_block_size: int = ...,
		pallas_k_block_size: int = ...,
		pallas_n_block_size: int = ...,
	):
		"""
		Update configuration options in place.

		Every parameter defaults to `...` (Ellipsis), meaning "leave unchanged".
		For each option, `set_attrs_smartly` first assigns the fallback listed
		below if the attribute does not exist yet, then overwrites it with the
		argument when the argument is not `...`.

		Args:
		    axis_dims (tp.Sequence[int], optional): Dimensions of the axes. Fallback (1, -1, 1, 1).
		    dcn_axis_dims (tp.Optional[tp.Sequence[int]], optional): DCN axis dimensions.
		        Fallback None.
		    axis_names (tp.Sequence[str], optional): Names of the axes.
		        Fallback ("dp", "fsdp", "tp", "sp").
		    attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional): Attention mechanism to use.
		        Fallback "jax_flash_attn2".
		    blocksize_k (int, optional): Block size of key_states. Fallback 512.
		    blocksize_q (int, optional): Block size of query_states. Fallback 512.
		    blocksize_b (int, optional): Block size of bias. Fallback 1.
		    partition_axis (PartitionAxis, optional): Partitioning configuration.
		        Fallback PartitionAxis().
		    shard_attention_computation (bool, optional): Whether to use shard_map for attention.
		        Fallback True.
		    use_sharded_kv_caching (bool, optional): Whether to use shard_map and sharding for key
		        and value. Fallback False.
		    backend (tp.Optional[EasyDeLBackends], optional): Backend to use. Fallback None.
		    platform (tp.Optional[EasyDeLPlatforms], optional): Platform to use. Fallback "jax".
		    easy_method (tp.Literal["train", "serve", "convert"], optional): Method the config is
		        used for. Fallback EasyMethod.TRAIN.
		    bits (tp.Optional[int], optional): Model bits for quantization. Fallback None.
		    scan_ring_attention (bool, optional): Whether to use scan for ring attention.
		        Fallback True.
		    scan_attention_layers (bool, optional): Whether to use scan for attention layers.
		        Fallback True.
		    use_sharding_constraint (bool, optional): Whether to use sharding constraint for the
		        arrays. Fallback False.
		    use_scan_mlp (bool, optional): Whether to use scan_mlp. Fallback False.
		    scan_mlp_chunk_size (int, optional): Size of chunks in scan MLP. Fallback 1024.
		    sequence_axis_name (str, optional): Name of the attention axis. Fallback "sp".
		    gradient_checkpointing (EasyDeLGradientCheckPointers, optional): Gradient checkpointing
		        method. Fallback EasyDeLGradientCheckPointers.NONE.
		    kv_cache_quantization_method (EasyDeLQuantizationMethods, optional): Key and value
		        quantization type. Fallback EasyDeLQuantizationMethods.NONE.
		    kv_cache_quantization_blocksize (int, optional): Size of kv cache quantization.
		        Fallback 128.
		    quantization_method (EasyDeLQuantizationMethods, optional): Linear modules quantization
		        type. Fallback EasyDeLQuantizationMethods.NONE.
		    quantization_blocksize (int, optional): Size of linear quantization. Fallback 64.
		    quantization_pattern (str, optional): `re` pattern used for selecting layers to
		        quantize. Fallback ".*".
		    kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional):
		        Axis name to target for sharding sequences. Fallback "sp".
		    flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional): Backward
		        pass kernel for flash attention. Fallback "triton".
		    attn_dtype (jnp.dtype, optional): Data type for attention computations.
		        Fallback jnp.float32.
		    attn_softmax_dtype (jnp.dtype, optional): Data type for softmax in attention.
		        Fallback jnp.float32.
		    hardware_abstraction (bool, optional): Whether to switch to custom pallas kernels
		        instead of JAX. Fallback DEFAULT_HARDWARE_ABSTRACTION.
		    pallas_m_block_size (int, optional): Block size of the m dim in matmul for the pallas
		        kernel `A(mk)@B(kn)=B(mn)`. Fallback DEFAULT_PALLAS_M_BLOCK_SIZE.
		    pallas_k_block_size (int, optional): Block size of the k dim in matmul for the pallas
		        kernel `A(mk)@B(kn)=B(mn)`. Fallback DEFAULT_PALLAS_K_BLOCK_SIZE.
		    pallas_n_block_size (int, optional): Block size of the n dim in matmul for the pallas
		        kernel `A(mk)@B(kn)=B(mn)`. Fallback DEFAULT_PALLAS_N_BLOCK_SIZE.
		"""
		# NOTE(review): several fallbacks differ from the `__init__` defaults
		# (blocksize_q/k 512 vs 128, scan_attention_layers True vs False,
		# kv_cache_quantization_blocksize 128 vs 64, attn_mechanism); they only
		# apply when the attribute is missing. Confirm which set is intended.
		# fmt: off
		options = (
			# (attribute name, fallback when missing, requested value)
			("axis_dims", (1, -1, 1, 1), axis_dims),
			("dcn_axis_dims", None, dcn_axis_dims),
			("axis_names", ("dp", "fsdp", "tp", "sp"), axis_names),
			("blocksize_q", 512, blocksize_q),
			("blocksize_k", 512, blocksize_k),
			("blocksize_b", 1, blocksize_b),
			("partition_axis", PartitionAxis(), partition_axis),
			("use_sharding_constraint", False, use_sharding_constraint),
			("backend", None, backend),
			("platform", "jax", platform),
			("shard_attention_computation", True, shard_attention_computation),
			("use_sharded_kv_caching", False, use_sharded_kv_caching),
			("attn_mechanism", "jax_flash_attn2", attn_mechanism),
			("easy_method", EasyMethod.TRAIN, easy_method),
			("bits", None, bits),
			("scan_attention_layers", True, scan_attention_layers),
			("scan_ring_attention", True, scan_ring_attention),
			("use_scan_mlp", False, use_scan_mlp),
			("scan_mlp_chunk_size", 1024, scan_mlp_chunk_size),
			("sequence_axis_name", "sp", sequence_axis_name),
			("kv_cache_quantization_blocksize", 128, kv_cache_quantization_blocksize),
			("kv_cache_sharding_sequence_axis_name", "sp", kv_cache_sharding_sequence_axis_name),
			("gradient_checkpointing", EasyDeLGradientCheckPointers.NONE, gradient_checkpointing),
			("kv_cache_quantization_method", EasyDeLQuantizationMethods.NONE, kv_cache_quantization_method),
			("quantization_method", EasyDeLQuantizationMethods.NONE, quantization_method),
			# A block size is an integer; 64 matches the `__init__` default.
			("quantization_blocksize", 64, quantization_blocksize),
			("quantization_pattern", ".*", quantization_pattern),
			("flash_attention_backward_pass_impl", "triton", flash_attention_backward_pass_impl),
			("attn_dtype", jnp.float32, attn_dtype),
			("attn_softmax_dtype", jnp.float32, attn_softmax_dtype),
			("hardware_abstraction", DEFAULT_HARDWARE_ABSTRACTION, hardware_abstraction),
			("pallas_m_block_size", DEFAULT_PALLAS_M_BLOCK_SIZE, pallas_m_block_size),
			("pallas_k_block_size", DEFAULT_PALLAS_K_BLOCK_SIZE, pallas_k_block_size),
			("pallas_n_block_size", DEFAULT_PALLAS_N_BLOCK_SIZE, pallas_n_block_size),
		)
		# fmt: on
		for attr_name, fallback, requested in options:
			set_attrs_smartly(self, attr_name, fallback, requested)

	def __repr__(self):
		"""Build a multi-line, human-readable listing of the public attributes.

		Attributes whose name starts with "_" are skipped. An entry of 500
		characters or more is abbreviated to `ClassName(...)`.

		Returns:
		    A string representation of the object.
		"""
		string = f"{self.__class__.__name__}(\n"
		for k, v in self.__dict__.items():
			if not k.startswith("_"):
				try:
					# `v.__str__()` is called directly (not `str(v)`): when `v` is
					# itself a class this raises TypeError and the entry is skipped.
					repr_src = f" {k} : " + v.__str__().replace("\n", "\n ") + "\n"
					string += (
						repr_src
						if len(repr_src) < 500
						else f" {k} : " + f"{v.__class__.__name__}(...)" + "\n"
					)
				except TypeError:
					pass
		return string + ")"

	def to_dict(self) -> tp.Dict[str, tp.Any]:
		"""Serialize the config with `PretrainedConfig.to_dict`.

		Attributes whose value's class is named `_ScalarMeta` (presumably JAX
		scalar types such as `jnp.float32` — TODO confirm) are temporarily
		removed from `__dict__` so they are left out of the result, then put back.

		Returns:
		    The dictionary produced by the parent `to_dict`.
		"""
		state = self.__dict__
		forbidden_types = ["_ScalarMeta"]
		stashed = {
			k: state.pop(k)
			for k in list(state.keys())
			if state.get(k).__class__.__name__ in forbidden_types
		}
		try:
			return super().to_dict()
		finally:
			# Restore even when serialization fails, so the config never
			# silently loses attributes.
			for k, v in stashed.items():
				state[k] = v

	def add_jax_args(self, **kwargs):
		"""Set every keyword argument as an attribute (unless its value is `...`)."""
		for k, v in kwargs.items():
			set_attrs_smartly(self, k, v, v)

	def __str__(self):
		"""Return the same text as `__repr__`."""
		return self.__repr__()

	@classmethod  # From HF.
	def from_pretrained(
		cls,
		pretrained_model_name_or_path: tp.Union[str, os.PathLike],
		cache_dir: tp.Optional[tp.Union[str, os.PathLike]] = None,
		force_download: bool = False,
		local_files_only: bool = False,
		token: tp.Optional[tp.Union[str, bool]] = None,
		revision: str = "main",
		**kwargs,
	) -> "PretrainedConfig":
		r"""
		Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.

		The explicit arguments are copied into `kwargs`, which is then handed to
		`cls.get_config_dict` and `cls.from_dict`.

		Args:
		    pretrained_model_name_or_path (`str` or `os.PathLike`):
		        This can be either:

		        - a string, the *model id* of a pretrained model configuration hosted inside a model
		          repo on huggingface.co.
		        - a path to a *directory* containing a configuration file saved using the
		          [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
		        - a path or url to a saved configuration JSON *file*, e.g.,
		          `./my_model_directory/configuration.json`.
		    cache_dir (`str` or `os.PathLike`, *optional*):
		        Path to a directory in which a downloaded pretrained model configuration should be
		        cached if the standard cache should not be used.
		    force_download (`bool`, *optional*, defaults to `False`):
		        Whether or not to force to (re-)download the configuration files and override the
		        cached versions if they exist.
		    local_files_only (`bool`, *optional*, defaults to `False`):
		        Forwarded as `kwargs["local_files_only"]`.
		    token (`str` or `bool`, *optional*):
		        The token to use as HTTP bearer authorization for remote files. If `True`, or not
		        specified, will use the token generated when running `huggingface-cli login`
		        (stored in `~/.huggingface`).
		    revision (`str`, *optional*, defaults to `"main"`):
		        The specific model version to use. It can be a branch name, a tag name, or a commit
		        id, since we use a git-based system for storing models and other artifacts on
		        huggingface.co, so `revision` can be any identifier allowed by git.

		        <Tip>

		        To test a pull request you made on the Hub, you can pass
		        `revision="refs/pr/<pr_number>"`.

		        </Tip>

		    kwargs (`Dict[str, tp.Any]`, *optional*):
		        The values in kwargs of any keys which are configuration attributes will be used to
		        override the loaded values. Behavior concerning key/value pairs whose keys are *not*
		        configuration attributes is controlled by the `return_unused_kwargs` keyword
		        parameter. Other options passed through `kwargs`:

		        - resume_download: Deprecated and ignored. All downloads are now resumed by default
		          when possible. Will be removed in v5 of Transformers.
		        - proxies (`Dict[str, str]`, *optional*): A dictionary of proxy servers to use by
		          protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
		          'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
		        - return_unused_kwargs (`bool`, *optional*, defaults to `False`): If `False`, then
		          this function returns just the final configuration object. If `True`, then this
		          functions returns a `tp.Tuple(config, unused_kwargs)` where *unused_kwargs* is a
		          dictionary consisting of the key/value pairs whose keys are not configuration
		          attributes: i.e., the part of `kwargs` which has not been used to update `config`
		          and is otherwise ignored.
		        - subfolder (`str`, *optional*, defaults to `""`): In case the relevant files are
		          located inside a subfolder of the model repo on huggingface.co, you can specify
		          the folder name here.

		Returns:
		    [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.

		Examples:

		```python
		# We can't instantiate directly the base class *PretrainedConfig* so let's show the
		# examples on a derived class: BertConfig
		config = BertConfig.from_pretrained(
		    "google-bert/bert-base-uncased"
		)  # Download configuration from huggingface.co and cache.
		config = BertConfig.from_pretrained(
		    "./test/saved_model/"
		)  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
		config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
		config = BertConfig.from_pretrained(
		    "google-bert/bert-base-uncased", output_attentions=True, foo=False
		)
		assert config.output_attentions == True
		config, unused_kwargs = BertConfig.from_pretrained(
		    "google-bert/bert-base-uncased",
		    output_attentions=True,
		    foo=False,
		    return_unused_kwargs=True,
		)
		assert config.output_attentions == True
		assert unused_kwargs == {"foo": False}
		```"""
		kwargs["cache_dir"] = cache_dir
		kwargs["force_download"] = force_download
		kwargs["local_files_only"] = local_files_only
		kwargs["revision"] = revision

		cls._set_token_in_kwargs(kwargs, token)

		config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

		return cls.from_dict(config_dict, **kwargs)

	@property
	def granted_freq_max_position_embedding(self) -> int:
		"""`freq_max_position_embeddings` if set, else `max_position_embeddings`."""
		return getattr(
			self,
			"freq_max_position_embeddings",
			self.max_position_embeddings,
		)

	@property
	def granted_mask_max_position_embedding(self) -> int:
		"""`mask_max_position_embeddings` if set, else `max_position_embeddings`."""
		return getattr(
			self,
			"mask_max_position_embeddings",
			self.max_position_embeddings,
		)

	def _get_rope_config(self) -> RopeConfig:
		"""Get RoPE configuration from the instance attributes.

		Uses `RopeConfig.from_dict(self.rope_scaling)` when `rope_scaling` is set,
		otherwise a default `RopeConfig()`. A missing
		`original_max_position_embeddings` is filled from the attribute of the
		same name on this config (None when absent).
		"""
		from easydel.layers.rotary_embedding import RopeConfig

		if not hasattr(self, "rope_scaling") or self.rope_scaling is None:
			config = RopeConfig()
		else:
			config = RopeConfig.from_dict(self.rope_scaling)

		if config.original_max_position_embeddings is None:
			config.original_max_position_embeddings = getattr(
				self,
				"original_max_position_embeddings",
				None,
			)

		return config

	def get_basic_rope(
		self,
		dtype: jnp.dtype,
		head_size: int,
		rotary_dim: tp.Optional[int] = None,
		is_neox_style: bool = True,
		base: tp.Optional[float] = None,
	):
		"""
		Get basic rotary position embeddings.

		Args:
		    dtype: Data type for the embeddings
		    head_size: Size of attention heads
		    rotary_dim: Dimension for rotary embeddings (defaults to head_size)
		    is_neox_style: Whether to use NeoX style embeddings
		    base: Base value for frequency computation (defaults to self.rope_theta)

		Returns:
		    The result of `get_rope` from `easydel.layers.rotary_embedding`.
		"""
		from easydel.layers.rotary_embedding import get_rope

		rotary_dim = rotary_dim or head_size
		rope_config = self._get_rope_config()

		return get_rope(
			head_size=head_size,
			rotary_dim=rotary_dim,
			max_position=self.granted_freq_max_position_embedding,
			base=base or self.rope_theta,
			dtype=dtype,
			is_neox_style=is_neox_style,
			rope_scaling=rope_config.to_dict(),
		)

	def get_basic_frequencies(
		self,
		head_size: tp.Optional[int] = None,
		rotary_dim: tp.Optional[int] = None,
		base: tp.Optional[float] = None,
	) -> ModuleCaches:
		"""
		Get basic frequencies for rotary embeddings.

		Args:
		    head_size: Size of attention heads (defaults to self.head_dim)
		    rotary_dim: Dimension for rotary embeddings (defaults to head_size)
		    base: Base value for frequency computation (defaults to self.rope_theta)

		Returns:
		    ModuleCaches instance containing computed frequencies
		"""
		from easydel.layers.rotary_embedding import get_frequencies

		from .utils import ModuleCaches

		head_size = head_size or self.head_dim
		rotary_dim = rotary_dim or head_size
		rope_config = self._get_rope_config()

		frequencies = get_frequencies(
			head_size=head_size,
			rotary_dim=rotary_dim,
			max_position=self.granted_freq_max_position_embedding,
			base=base or self.rope_theta,
			rope_scaling=rope_config.to_dict(),
		)

		return ModuleCaches(frequencies)

	def get_basic_causal_mask(self, dtype="bool"):
		"""Return a `ModuleCaches` wrapping a causal mask.

		The mask is produced by `nn.make_causal_mask` from a ones array of shape
		`(1, granted_mask_max_position_embedding)` with the given `dtype`.
		"""
		from .utils import ModuleCaches

		return ModuleCaches(
			nn.make_causal_mask(
				jnp.ones(
					shape=(1, self.granted_mask_max_position_embedding),
					dtype=dtype,
				),
				dtype=dtype,
			)
		)

	def get_fcm_mask(self, batch_size, seq_length, deterministic: bool):
		"""Sample a random attention mask ("forgetful causal mask"), or return None.

		Args:
		    batch_size: Number of masks to sample.
		    seq_length: Side length of each square mask.
		    deterministic: When True, no mask is produced.

		Returns:
		    A boolean array of shape `(batch_size, 1, seq_length, seq_length)`
		    when `deterministic` is False and `fcm_max_ratio > 0`; otherwise None.
		"""
		# NOTE(review): `make_rng` is not defined in this class; presumably it is
		# supplied by a subclass or mixin — verify against callers.
		if not deterministic and self.fcm_max_ratio > 0:
			# Apply forgetful causal mask
			fcm_ratio = jax.random.uniform(
				self.make_rng("fcm"),
				shape=(batch_size, 1, 1, 1),
				minval=self.fcm_min_ratio,
				maxval=self.fcm_max_ratio,
			)
			fcm_mask = (
				jax.random.uniform(
					self.make_rng("fcm"), shape=(batch_size, 1, seq_length, seq_length)
				)
				> fcm_ratio
			)
			# The first key position is always kept.
			fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
			fcm_mask = fcm_mask.astype("bool")
		else:
			fcm_mask = None
		return fcm_mask

	__hash__ = hash_fn


# The option documentation lives on the class docstring; `__init__` has no
# docstring of its own, so copying `__init__.__doc__` would assign None.
EasyDeLBaseConfigDict.__doc__ = EasyDeLBaseConfig.__doc__
# NOTE(review): this replaces the TypedDict's field annotations with the
# class-level annotations of `EasyDeLBaseConfig`; confirm whether
# `EasyDeLBaseConfig.__init__.__annotations__` was intended instead.
EasyDeLBaseConfigDict.__annotations__ = EasyDeLBaseConfig.__annotations__