easydel.infra.base_config#
- class easydel.infra.base_config.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = 'vanilla', decode_attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = None, blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.manager.PartitionAxis = PartitionAxis(data_parallel_axis='dp', fully_sharded_data_parallel_axis='fsdp', tensor_parallel_axis='tp', sequence_parallel_axis='sp', expert_parallel_axis='ep', batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', kv_head_axis=None, key_sequence_axis='sp', hidden_state_axis='tp', mlp_intermediate_axis='tp', vocab_axis='tp', expert_axis='ep', expert_gate_axis=None, attention_dim_axis=None, attention_kv_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, decode_batch_axis=('fsdp', 'dp'), decode_query_sequence_axis=None, decode_head_axis='tp', decode_kv_head_axis=None, decode_key_sequence_axis='sp', decode_attention_dim_axis=None, decode_attention_kv_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: 
~easydel.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, precompute_masks: bool = True, kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)[source]#
Bases: PretrainedConfig
Initialize the configuration for EasyDeL.
- Parameters
axis_dims (tp.Sequence[int]) – Dimensions of the axes. Default is (1, -1, 1, 1).
axis_names (tp.Sequence[str]) – Names of the axes. Default is ('dp', 'fsdp', 'tp', 'sp').
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is DEFAULT_ATTENTION_MECHANISM.
decode_attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use for the decode phase. Default is None.
blocksize_k (int) – Block size for key. Default is 128.
blocksize_q (int) – Block size for query. Default is 128.
blocksize_b (int) – Block size for batch. Default is 1.
partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().
shard_attention_computation (bool) – Whether to shard attention computation. Default is True.
use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.
use_sharding_constraint (bool) – Whether to use sharding constraint. Default is False.
backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.
platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.
easy_method (tp.Literal["train", "serve", "convert"]) – Method to use. Default is EasyMethod.TRAIN.
bits (tp.Optional[int]) – Number of bits for quantization. Default is None.
scan_ring_attention (bool) – Whether to scan ring attention. Default is True.
scan_attention_layers (bool) – Whether to scan attention layers. Default is False.
use_scan_mlp (bool) – Whether to use scan MLP. Default is False.
scan_mlp_chunk_size (int) – Chunk size for scan MLP. Default is 1024.
sequence_axis_name (str) – Name of the attention axis. Default is 'sp'.
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.
quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.
quantization_pattern (str) – Pattern for quantization. Default is '.*'.
quantization_blocksize (int) – Block size for quantization. Default is 64.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]]) – Name of the key-value cache sharding sequence axis. Default is 'sp'.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"]) – Implementation for flash attention backward pass. Default is 'triton'.
attn_dtype (jnp.dtype) – Data type for attention. Default is device half.
attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.
fcm_max_ratio (float) – Maximum ratio for FCM. Default is 0.0.
fcm_min_ratio (float) – Minimum ratio for FCM. Default is 0.0.
hardware_abstraction (bool) – Whether to use hardware abstraction. Default is DEFAULT_HARDWARE_ABSTRACTION.
pallas_m_block_size (int) – Block size for Pallas M. Default is DEFAULT_PALLAS_M_BLOCK_SIZE.
pallas_k_block_size (int) – Block size for Pallas K. Default is DEFAULT_PALLAS_K_BLOCK_SIZE.
pallas_n_block_size (int) – Block size for Pallas N. Default is DEFAULT_PALLAS_N_BLOCK_SIZE.
**kwargs – Additional keyword arguments.
- Raises
Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
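The quantization_pattern parameter has the default '.*', which suggests it is a regular expression applied to layer names. The sketch below shows how such a pattern could select layers. The layer names and the select_layers helper are hypothetical, and the use of re.search (rather than re.match or re.fullmatch) is an assumption, not confirmed EasyDeL behavior.

```python
import re

# Hypothetical parameter paths; real EasyDeL module paths may differ.
layer_names = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/mlp/up_proj",
    "model/embed_tokens",
    "lm_head",
]


def select_layers(pattern: str, names: list[str]) -> list[str]:
    """Return the names matched by `pattern` (assumes re.search semantics)."""
    compiled = re.compile(pattern)
    return [name for name in names if compiled.search(name)]


# The default pattern '.*' matches every name.
all_layers = select_layers(".*", layer_names)

# A narrower pattern restricts the selection to the projection layers.
proj_only = select_layers(r"(q|up)_proj$", layer_names)
```

With a pattern like this, embeddings and the output head would be left unquantized while the projection layers are selected.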
- add_basic_configurations(axis_dims: ~typing.Sequence[int] = <eformer.common_types._Empty object>, dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = <eformer.common_types._Empty object>, axis_names: ~typing.Sequence[str] = <eformer.common_types._Empty object>, attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = <eformer.common_types._Empty object>, decode_attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa', 'autoregressive_decodeattn'] = <eformer.common_types._Empty object>, blocksize_k: int = <eformer.common_types._Empty object>, blocksize_q: int = <eformer.common_types._Empty object>, blocksize_b: int = <eformer.common_types._Empty object>, partition_axis: ~eformer.escale.partition.manager.PartitionAxis = <eformer.common_types._Empty object>, shard_attention_computation: bool = <eformer.common_types._Empty object>, use_sharded_kv_caching: bool = <eformer.common_types._Empty object>, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = <eformer.common_types._Empty object>, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = <eformer.common_types._Empty object>, easy_method: ~typing.Literal['train', 'serve', 'convert'] = <eformer.common_types._Empty object>, bits: ~typing.Optional[int] = <eformer.common_types._Empty object>, scan_ring_attention: bool = <eformer.common_types._Empty object>, scan_attention_layers: bool = <eformer.common_types._Empty object>, use_sharding_constraint: bool = <eformer.common_types._Empty object>, use_scan_mlp: bool = <eformer.common_types._Empty object>, scan_mlp_chunk_size: int = <eformer.common_types._Empty object>, sequence_axis_name: str = <eformer.common_types._Empty object>, gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = <eformer.common_types._Empty object>, precompute_masks: bool = <eformer.common_types._Empty object>, 
kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = <eformer.common_types._Empty object>, kv_cache_quantization_blocksize: int = <eformer.common_types._Empty object>, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = <eformer.common_types._Empty object>, quantization_blocksize: int = <eformer.common_types._Empty object>, quantization_pattern: str = <eformer.common_types._Empty object>, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = <eformer.common_types._Empty object>, flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = <eformer.common_types._Empty object>, attn_dtype: ~numpy.dtype = <eformer.common_types._Empty object>, attn_softmax_dtype: ~numpy.dtype = <eformer.common_types._Empty object>, hardware_abstraction: bool = <eformer.common_types._Empty object>, pallas_m_block_size: int = <eformer.common_types._Empty object>, pallas_k_block_size: int = <eformer.common_types._Empty object>, pallas_n_block_size: int = <eformer.common_types._Empty object>, **kwargs)[source]#
Initializes all the attributes of the object. It is called when a new instance of the class is created.
- Parameters
axis_dims (tp.Sequence[int], optional) – Specify the number of dimensions for each axis. Defaults to (1, -1, 1, 1).
axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to ('dp', 'fsdp', 'tp', 'sp').
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.
decode_attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use for the decode phase. Defaults to None.
blocksize_k (int, optional) – Block size of key_states. Defaults to 128.
blocksize_q (int, optional) – Block size of query_states. Defaults to 128.
blocksize_b (int, optional) – Block size of bias. Defaults to 1.
partition_axis (PartitionAxis, optional) – PartitionAxis is the module used for partitioning arrays in EasyDeL. Defaults to PartitionAxis().
shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.
use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for key and value. Defaults to False.
backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.
easy_method (tp.Literal["train", "serve", "convert"], optional) – EasyDeL method that quantization is applied for. Defaults to EasyMethod.TRAIN.
bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.
scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.
scan_attention_layers (bool, optional) – Whether to use scan for attention layers. Defaults to False.
use_sharding_constraint (bool, optional) – Whether to use a sharding constraint for the arrays. Defaults to False.
use_scan_mlp (bool, optional) – Whether to use scan_mlp. Defaults to False.
scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.
sequence_axis_name (str, optional) – Name of the attention axis. Defaults to 'sp'.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (in most cases applied to the MLP and attention layers).
kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – Key and value quantization type. Defaults to EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int, optional) – Block size of KV cache quantization. Defaults to 64.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.
quantization_blocksize (int, optional) – Block size of linear quantization. Defaults to 64.
quantization_pattern (str) – Regular expression (re) pattern used to select the layers to quantize.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – Axis name to target for sharding sequences. Defaults to 'sp'.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to 'triton'.
attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to device half.
attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.
fcm_max_ratio (float, optional) – Maximum ratio for flash cross attention. Defaults to 0.0.
fcm_min_ratio (float, optional) – Minimum ratio for flash cross attention. Defaults to 0.0.
hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of JAX. Defaults to DEFAULT_HARDWARE_ABSTRACTION.
pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.
pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.
pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.
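Every default in the add_basic_configurations signature is an _Empty sentinel object rather than a concrete value. A sentinel of this kind is commonly used to tell "argument not supplied" apart from a legitimate None (backend and platform, for example, accept None). The sketch below illustrates the pattern in plain Python. ToyConfig, EMPTY, and the rule "an explicit argument wins, otherwise keep the existing value, otherwise fall back to a default" are assumptions for illustration and are not taken from the EasyDeL source.

```python
class _Empty:
    """Sentinel type meaning 'argument was not supplied' (stand-in for eformer's _Empty)."""


EMPTY = _Empty()


class ToyConfig:
    # Hypothetical fallback defaults, copied from the class signature above.
    _DEFAULTS = {"blocksize_q": 128, "blocksize_k": 128, "use_scan_mlp": False}

    def add_basic_configurations(
        self, blocksize_q=EMPTY, blocksize_k=EMPTY, use_scan_mlp=EMPTY
    ):
        supplied = {
            "blocksize_q": blocksize_q,
            "blocksize_k": blocksize_k,
            "use_scan_mlp": use_scan_mlp,
        }
        for name, value in supplied.items():
            if value is not EMPTY:
                # An explicitly passed argument always wins.
                setattr(self, name, value)
            elif not hasattr(self, name):
                # Nothing stored yet: fall back to the default.
                setattr(self, name, self._DEFAULTS[name])
            # Otherwise the previously stored value is kept untouched.


config = ToyConfig()
config.add_basic_configurations(blocksize_q=256)
config.add_basic_configurations(use_scan_mlp=True)  # does not reset blocksize_q
```

Under these assumptions, a second call that only passes use_scan_mlp leaves the earlier blocksize_q=256 in place.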
- static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)[source]#
The create_mesh function creates a mesh object that can be used to shard arrays.
- Returns
A mesh object
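The default axis_dims of (1, -1, 1, 1) contains a -1 entry. A plausible reading, assumed here by analogy with NumPy's reshape, is that -1 means "fill this axis with whatever devices remain". The sketch below resolves such a tuple for a given device count in plain Python. resolve_axis_dims is a hypothetical helper, not part of EasyDeL, and the real create_mesh may apply additional rules (for example for dcn_axis_dims).

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 entry so that the product equals num_devices."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known != 0:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError(f"axis_dims product {known} != {num_devices} devices")
    return tuple(dims)


axis_names = ("dp", "fsdp", "tp", "sp")

# With 8 devices, the default layout places all of them on the 'fsdp' axis.
default_shape = dict(zip(axis_names, resolve_axis_dims((1, -1, 1, 1), 8)))

# Mixing 2-way data parallelism with 2-way tensor parallelism leaves 2 for 'fsdp'.
mixed_shape = dict(zip(axis_names, resolve_axis_dims((2, -1, 2, 1), 8)))
```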
- classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) PretrainedConfig[source]#
Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force (re-)downloading the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to 'main') –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
</Tip>
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this function returns a tp.Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
subfolder (str, optional, defaults to '') – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
[PretrainedConfig]
Examples:
>>> # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
>>> # derived class: BertConfig
>>> from transformers import BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
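The return_unused_kwargs behavior described above can be pictured without downloading anything. The following is a minimal plain-Python sketch of the documented semantics only: keys that name existing configuration attributes override the loaded values, and all other keys are either returned or dropped. apply_overrides and the dictionary-based "config" are hypothetical stand-ins, not the Transformers implementation.

```python
def apply_overrides(config: dict, return_unused_kwargs: bool = False, **kwargs):
    """Override known config keys; optionally return the leftover kwargs."""
    updated = dict(config)  # leave the loaded config untouched
    unused = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # key is a configuration attribute
        else:
            unused[key] = value  # key is not a configuration attribute
    return (updated, unused) if return_unused_kwargs else updated


# Pretend these values were loaded from a saved configuration file.
loaded = {"output_attentions": False, "hidden_size": 768}

config, unused_kwargs = apply_overrides(
    loaded, return_unused_kwargs=True, output_attentions=True, foo=False
)
```

This mirrors the last doctest above, where foo=False ends up in unused_kwargs because it is not a configuration attribute.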
- get_axis_dims() Sequence[int][source]#
The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.
- Parameters
self – Represent the instance of the class
- Returns
The dimensions of the axes
- get_axis_names() Sequence[str][source]#
The get_axis_names function returns a list of the names of the axes.
- Parameters
self – Represent the instance of the class
- Returns
A list of the names of all axes
- get_backend() str[source]#
The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.
- Parameters
self – Bind the method to an object
- Returns
The backend platform
- get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) Any[source]#
Get basic frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
- get_basic_inv_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None, partial_rotary_factor: float = 1.0) Any[source]#
Get basic inverse frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
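For orientation, the standard rotary-embedding formulation computes inverse frequencies as base^(-2i/d) for each pair index i, and per-position angles as position times inverse frequency. The sketch below implements that textbook formula in plain Python. The function names and the base value of 10000.0 are assumptions (the method above defaults to self.rope_theta), and EasyDeL's versions return a ModuleCaches instance and may apply scaling that is not shown here.

```python
def inverse_frequencies(rotary_dim: int, base: float = 10000.0) -> list[float]:
    """Textbook RoPE inverse frequencies: one value per pair of dimensions."""
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    return [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]


def frequencies(max_positions: int, rotary_dim: int, base: float = 10000.0):
    """Angle table of shape [max_positions][rotary_dim // 2]."""
    inv_freq = inverse_frequencies(rotary_dim, base)
    return [[pos * f for f in inv_freq] for pos in range(max_positions)]


inv_freq = inverse_frequencies(rotary_dim=8)  # 4 values, starting at 1.0
angles = frequencies(max_positions=4, rotary_dim=8)
```

The first inverse frequency is always 1.0, and later ones decay geometrically, so low pair indices rotate quickly with position while high ones rotate slowly.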
- get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)[source]#
Get basic rotary position embeddings.
- Parameters
dtype – Data type for the embeddings
head_size – Size of attention heads
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
is_neox_style – Whether to use NeoX style embeddings
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
Rotary position embeddings func
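The is_neox_style flag conventionally selects between two ways of pairing dimensions before rotation: NeoX style pairs element i with element i + d/2 (split halves), while the alternative (often called GPT-J style) pairs adjacent elements 2i and 2i + 1. The sketch below applies both variants to a single vector in plain Python. apply_rope is a hypothetical helper that rotates the full vector with base 10000.0; it is an illustration of the general technique, not the function returned by get_basic_rope.

```python
import math


def apply_rope(x, position, is_neox_style=True, base=10000.0):
    """Rotate vector x (even length) by position-dependent angles."""
    dim = len(x)
    if dim % 2 != 0:
        raise ValueError("vector length must be even")
    half = dim // 2
    angles = [position / (base ** (i / dim)) for i in range(0, dim, 2)]
    cos = [math.cos(a) for a in angles]
    sin = [math.sin(a) for a in angles]
    if is_neox_style:
        x1, x2 = x[:half], x[half:]  # pair i with i + dim/2
    else:
        x1, x2 = x[0::2], x[1::2]  # pair 2i with 2i + 1
    r1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    r2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    if is_neox_style:
        return r1 + r2
    out = [0.0] * dim
    out[0::2] = r1  # write the rotated pairs back interleaved
    out[1::2] = r2
    return out


vec = [1.0, 2.0, 3.0, 4.0]
neox = apply_rope(vec, position=1, is_neox_style=True)
gptj = apply_rope(vec, position=1, is_neox_style=False)
```

Both variants are pure rotations, so they preserve the vector norm and leave position 0 unchanged; they differ only in which elements are rotated together.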
- get_partition_rules(*args, **kwargs)[source]#
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
- property granted_freq_max_position_embedding: int#
- property granted_mask_max_position_embedding: int#
- property mesh#
The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The platform attribute is also used if it exists.
- Parameters
self – Refer to the object itself
- Returns
A JAX Mesh object
- property partition_manager: PartitionManager#
- read_basics_from_config(config: EasyDeLBaseConfig)[source]#
- class easydel.infra.base_config.EasyMethod(TRAIN: 'str' = 'train', SERVE: 'str' = 'serve', EVAL: 'str' = 'serve', CONVERT: 'str' = 'convert')[source]#
Bases: object
- CONVERT: str = 'convert'#
- EVAL: str = 'serve'#
- SERVE: str = 'serve'#
- TRAIN: str = 'train'#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
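The serialization helpers listed for EasyMethod (from_dict, from_json, replace, to_dict, to_json) follow a familiar immutable-record pattern. The sketch below reproduces that interface with a standard-library frozen dataclass so the round trip can be tried in isolation. EasyMethodSketch is a hypothetical stand-in, and the real class is a PyTree object whose implementation may differ.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class EasyMethodSketch:
    # Field names and defaults copied from the EasyMethod signature above.
    TRAIN: str = "train"
    SERVE: str = "serve"
    EVAL: str = "serve"
    CONVERT: str = "convert"

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "EasyMethodSketch":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "EasyMethodSketch":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "EasyMethodSketch":
        # Returns a new instance; the original stays unchanged.
        return dataclasses.replace(self, **kwargs)


methods = EasyMethodSketch()
custom = methods.replace(EVAL="eval")
restored = EasyMethodSketch.from_json(methods.to_json())
```

Note that EVAL defaults to 'serve' in the signature above, so EVAL and SERVE are indistinguishable unless one of them is replaced.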