easydel.infra.base_config

class easydel.infra.base_config.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)

Bases: PretrainedConfig

Initialize the configuration for EasyDeL.

Parameters
  • axis_dims (tp.Sequence[int]) – Sizes of the mesh axes. Default is (1, -1, 1, 1).

  • axis_names (tp.Sequence[str]) – Names of the mesh axes. Default is ("dp", "fsdp", "tp", "sp").

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is "vanilla".

  • blocksize_k (int) – Block size for keys. Default is 128.

  • blocksize_q (int) – Block size for queries. Default is 128.

  • blocksize_b (int) – Block size for the batch dimension. Default is 1.

  • partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().

  • shard_attention_computation (bool) – Whether to shard the attention computation. Default is True.

  • use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.

  • use_sharding_constraint (bool) – Whether to apply sharding constraints. Default is False.

  • backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.

  • platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.

  • easy_method (tp.Literal["train", "serve", "convert"]) – Mode the model is configured for: training, serving, or conversion. Default is EasyMethod.TRAIN.

  • bits (tp.Optional[int]) – Number of bits for quantization. Default is None.

  • scan_ring_attention (bool) – Whether to use scan for ring attention. Default is True.

  • scan_attention_layers (bool) – Whether to use scan for the attention layers. Default is False.

  • use_scan_mlp (bool) – Whether to use the scan MLP. Default is False.

  • scan_mlp_chunk_size (int) – Chunk size for the scan MLP. Default is 1024.

  • sequence_axis_name (str) – Name of the sequence mesh axis. Default is "sp".

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.

  • quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • quantization_pattern (str) – Regular expression selecting the layers to quantize. Default is ".*".

  • quantization_blocksize (int) – Block size for quantization. Default is 64.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]]) – Name of the mesh axis (or axes) used to shard the key-value cache along the sequence dimension. Default is "sp".

  • flash_attention_backward_pass_impl (tp.Literal["triton", "xla"]) – Implementation of the flash attention backward pass. Default is "triton".

  • attn_dtype (jnp.dtype) – Data type for attention. Default is jnp.float32.

  • attn_softmax_dtype (jnp.dtype) – Data type for the softmax ops in attention. Default is jnp.float32.

  • fcm_max_ratio (float) – Maximum ratio for forgetful causal masking (FCM). Default is 0.0.

  • fcm_min_ratio (float) – Minimum ratio for forgetful causal masking (FCM). Default is 0.0.

  • hardware_abstraction (bool) – Whether to use hardware abstraction (custom Pallas kernels). Default is False.

  • pallas_m_block_size (int) – Block size of the M dimension in Pallas matmul kernels. Default is 128.

  • pallas_k_block_size (int) – Block size of the K dimension in Pallas matmul kernels. Default is 128.

  • pallas_n_block_size (int) – Block size of the N dimension in Pallas matmul kernels. Default is 128.

  • **kwargs – Additional keyword arguments.
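To illustrate how a pattern such as the default ".*" selects layers, here is a minimal sketch in plain Python. It assumes the pattern is matched with re.search against slash-separated parameter paths; the layer names and the select_layers_to_quantize helper are hypothetical, not EasyDeL API.

```python
import re

# Hypothetical parameter paths; real EasyDeL module names may differ.
LAYER_NAMES = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/mlp/gate_proj",
    "model/embed_tokens",
    "lm_head",
]


def select_layers_to_quantize(names, quantization_pattern=".*"):
    """Return the names matching the pattern (illustrative helper only)."""
    pattern = re.compile(quantization_pattern)
    # Using re.search is an assumption; the library may anchor matches differently.
    return [name for name in names if pattern.search(name)]


print(select_layers_to_quantize(LAYER_NAMES))  # default ".*" selects everything
print(select_layers_to_quantize(LAYER_NAMES, r"(q_proj|gate_proj)$"))
```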

Raises

Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
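The condition above can be sketched as follows. QuantMethod is a hypothetical stand-in for EasyDeLQuantizationMethods and check_kv_cache_options is an illustrative helper; the real constructor's message and warning category may differ.

```python
import warnings
from enum import Enum


class QuantMethod(str, Enum):
    # Hypothetical stand-in for EasyDeLQuantizationMethods; only NONE matters here.
    NONE = "none"
    QUANTIZED = "quantized"


def check_kv_cache_options(kv_cache_quantization_method, use_sharded_kv_caching):
    """Warn when KV cache quantization is combined with sharded KV caching."""
    if kv_cache_quantization_method != QuantMethod.NONE and use_sharded_kv_caching:
        warnings.warn(
            "kv_cache_quantization_method is set while use_sharded_kv_caching=True",
            stacklevel=2,
        )


check_kv_cache_options(QuantMethod.NONE, True)       # no warning
check_kv_cache_options(QuantMethod.QUANTIZED, True)  # emits a UserWarning
```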

add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis)

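Set or update the basic EasyDeL configuration attributes on an existing config instance. Every parameter defaults to Ellipsis, which marks it as not provided; only the attributes passed explicitly are overridden.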

Parameters
  • axis_dims (tp.Sequence[int], optional) – Size of each mesh axis; an entry of -1 is inferred from the remaining devices. Defaults to (1, -1, 1, 1).

  • axis_names (tp.Sequence[str], optional) – Names of the mesh axes. Defaults to ("dp", "fsdp", "tp", "sp").

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.

  • blocksize_k (int, optional) – Block size for key states. Defaults to 128.

  • blocksize_q (int, optional) – Block size for query states. Defaults to 128.

  • blocksize_b (int, optional) – Block size for the batch dimension. Defaults to 1.

  • partition_axis (PartitionAxis, optional) – PartitionAxis configuration that maps logical array axes (batch, sequence, heads, and so on) to mesh axes when partitioning arrays in EasyDeL. Defaults to PartitionAxis().

  • shard_attention_computation (bool, optional) – Whether to use shard_map for the attention computation. Defaults to True.

  • use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for the key and value cache. Defaults to True.

  • backend (tp.Optional[EasyDeLBackends], optional) – Backend to use. Defaults to None.

  • platform (tp.Optional[EasyDeLPlatforms], optional) – Platform to use. Defaults to None.

  • easy_method (tp.Literal["train", "serve", "convert"], optional) – Mode the model is configured for: training, serving, or conversion. Defaults to EasyMethod.TRAIN.

  • bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.

  • scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.

  • scan_attention_layers (bool, optional) – Whether to use scan for the attention layers. Defaults to False.

  • use_sharding_constraint (bool, optional) – Whether to apply sharding constraints to the arrays. Defaults to False.

  • use_scan_mlp (bool, optional) – Whether to use the scan MLP. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Size of chunks in the scan MLP. Defaults to 1024.

  • sequence_axis_name (str, optional) – Name of the sequence mesh axis. Defaults to "sp".

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (usually applied to the MLP and attention layers). Defaults to EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for the key/value cache. Defaults to EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int, optional) – Block size for key/value cache quantization. Defaults to 64.

  • quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.

  • quantization_blocksize (int, optional) – Block size for linear-module quantization. Defaults to 64.

  • quantization_pattern (str, optional) – Regular expression selecting the layers to quantize. Defaults to ".*".

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – Mesh axis name(s) used to shard the key/value cache along the sequence dimension. Defaults to "sp".

  • flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Backward-pass kernel for flash attention. Defaults to "triton".

  • attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to device half.

  • attn_softmax_dtype (jnp.dtype, optional) – Data type for the softmax in attention computations. Defaults to jnp.float32.

  • fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.

  • hardware_abstraction (bool, optional) – Whether to use custom Pallas kernels instead of the standard JAX implementations. Defaults to DEFAULT_HARDWARE_ABSTRACTION.

  • pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.

  • pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.

  • pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(m, k) @ B(k, n) = C(m, n). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.
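The pallas_*_block_size parameters describe how the matmul A(m, k) @ B(k, n) = C(m, n) is cut into tiles. The NumPy sketch below is not the Pallas kernel; it only shows, under that assumption, how bm, bk and bn (hypothetical local names for the three block sizes) split the work, with edge tiles allowed to be smaller.

```python
import numpy as np


def blocked_matmul(a, b, bm=128, bk=128, bn=128):
    """Compute C = A @ B one (bm, bn) output tile at a time.

    a has shape (m, k) and b has shape (k, n). bm, bk and bn stand in for
    pallas_m_block_size, pallas_k_block_size and pallas_n_block_size.
    """
    m, k = a.shape
    k_b, n = b.shape
    if k != k_b:
        raise ValueError("inner dimensions must agree")
    c = np.zeros((m, n), dtype=np.result_type(a, b))
    for i in range(0, m, bm):          # walk output tiles along m
        for j in range(0, n, bn):      # walk output tiles along n
            acc = np.zeros((min(bm, m - i), min(bn, n - j)), dtype=c.dtype)
            for p in range(0, k, bk):  # accumulate partial products over k
                acc += a[i:i + bm, p:p + bk] @ b[p:p + bk, j:j + bn]
            c[i:i + bm, j:j + bn] = acc
    return c


rng = np.random.default_rng(0)
a = rng.standard_normal((300, 200))
b = rng.standard_normal((200, 130))
print(np.allclose(blocked_matmul(a, b, bm=128, bk=64, bn=32), a @ b))  # True
```

Larger blocks mean fewer loop iterations but more memory per tile, which is the trade-off these parameters expose.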

add_jax_args(**kwargs)
static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)

The create_mesh function creates a mesh object that can be used to shard arrays.

Returns

A mesh object
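create_mesh receives axis_dims such as the default (1, -1, 1, 1). A minimal sketch of how such a shape could be resolved against the available device count, assuming -1 behaves like NumPy reshape (one axis absorbs the remaining devices); resolve_axis_dims is hypothetical and not part of EasyDeL.

```python
import math


def resolve_axis_dims(axis_dims, device_count):
    """Replace a single -1 entry with the number of devices left over.

    Hypothetical helper; mirrors NumPy reshape semantics for -1.
    """
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if device_count % known:
            raise ValueError(f"{device_count} devices cannot be split by {known}")
        dims[dims.index(-1)] = device_count // known
    if math.prod(dims) != device_count:
        raise ValueError(f"mesh {dims} needs {math.prod(dims)} devices, got {device_count}")
    return tuple(dims)


axis_names = ("dp", "fsdp", "tp", "sp")
print(dict(zip(axis_names, resolve_axis_dims((1, -1, 1, 1), 8))))
# {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(dict(zip(axis_names, resolve_axis_dims((1, -1, 2, 1), 8))))
# {'dp': 1, 'fsdp': 4, 'tp': 2, 'sp': 1}
```

In JAX, a resolved shape like this would typically be combined with a device array, for example np.array(jax.devices()).reshape(shape), and the axis names to build a jax.sharding.Mesh.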

classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) → PretrainedConfig

Instantiate a PretrainedConfig (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the PretrainedConfig.save_pretrained method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force the (re-)download of the configuration files, overriding the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to "main") –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to "") – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

PretrainedConfig

Examples:

>>> from transformers import BertConfig
>>> # We can't instantiate the base class *PretrainedConfig* directly, so the examples
>>> # use a derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


get_axis_dims() → Sequence[int]

The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.

Parameters

self – The config instance.

Returns

The dimensions of the axes

get_axis_names() → Sequence[str]

The get_axis_names function returns a list of the names of the axes.

Parameters

self – The config instance.

Returns

A list of the names of all axes

get_backend() → str

The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.

Parameters

self – The config instance.

Returns

The backend platform

get_basic_causal_mask(dtype='bool')
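get_basic_causal_mask has no description here. Assuming it builds the standard lower-triangular causal mask, the construction looks like this in NumPy; the basic_causal_mask helper, its seq_len argument and the (1, 1, seq, seq) layout are assumptions for illustration only.

```python
import numpy as np


def basic_causal_mask(seq_len, dtype="bool"):
    """Hypothetical helper: query position i may attend to key positions j <= i."""
    mask = np.tril(np.ones((seq_len, seq_len), dtype="bool"))
    # Leading broadcast axes for (batch, heads); EasyDeL's exact layout is an assumption.
    return mask[None, None, :, :].astype(dtype)


print(basic_causal_mask(4)[0, 0].astype(int))
```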
get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) → Any

Get basic frequencies for rotary embeddings.

Parameters
  • head_size โ€“ Size of attention heads (defaults to self.head_dim)

  • rotary_dim โ€“ Dimension for rotary embeddings (defaults to head_size)

  • base โ€“ Base value for frequency computation (defaults to self.rope_theta)

Returns

ModuleCaches instance containing computed frequencies
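A sketch of the standard rotary-embedding frequency computation, assuming get_basic_frequencies follows the usual formula theta_i = base^(-2i / rotary_dim). The hypothetical basic_frequencies helper returns plain cos/sin tables rather than a ModuleCaches instance, and base=10000.0 is only the common RoPE default; the method itself falls back to self.rope_theta.

```python
import numpy as np


def basic_frequencies(max_positions, rotary_dim, base=10000.0):
    """Hypothetical helper building RoPE cos/sin tables."""
    # One inverse frequency per channel pair: base ** (-2i / rotary_dim).
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_dim, 2, dtype=np.float64) / rotary_dim))
    positions = np.arange(max_positions, dtype=np.float64)
    angles = np.outer(positions, inv_freq)  # (max_positions, rotary_dim // 2)
    return np.cos(angles), np.sin(angles)


cos, sin = basic_frequencies(max_positions=16, rotary_dim=8)
print(cos.shape, sin.shape)  # (16, 4) (16, 4)
```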

get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)

Get basic rotary position embeddings.

Parameters
  • dtype โ€“ Data type for the embeddings

  • head_size โ€“ Size of attention heads

  • rotary_dim โ€“ Dimension for rotary embeddings (defaults to head_size)

  • is_neox_style โ€“ Whether to use NeoX style embeddings

  • base โ€“ Base value for frequency computation (defaults to self.rope_theta)

Returns

Rotary position embedding function
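is_neox_style selects how channels are paired for rotation. A NumPy sketch of the two common conventions (NeoX pairs channel i with i + rotary_dim // 2, GPT-J style pairs adjacent channels); apply_rope is a hypothetical helper, and the function returned by get_basic_rope may differ in signature and layout.

```python
import numpy as np


def apply_rope(x, cos, sin, is_neox_style=True):
    """Rotate x of shape (seq, rotary_dim) by per-position angles.

    cos and sin have shape (seq, rotary_dim // 2). Hypothetical helper.
    """
    if is_neox_style:
        # NeoX: first half of the channels pairs with the second half.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    # GPT-J style: even channels pair with the following odd channels.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out


seq, dim = 6, 8
angles = np.outer(np.arange(seq), 1.0 / 10000.0 ** (np.arange(0, dim, 2) / dim))
x = np.random.default_rng(0).standard_normal((seq, dim))
for style in (True, False):
    y = apply_rope(x, np.cos(angles), np.sin(angles), is_neox_style=style)
    # A rotation preserves the norm of every position's vector.
    print(style, np.allclose(np.linalg.norm(y, axis=-1), np.linalg.norm(x, axis=-1)))
```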

get_fcm_mask(batch_size, seq_length, deterministic: bool)
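get_fcm_mask relates to the fcm_min_ratio and fcm_max_ratio settings. The sketch below is modeled on the forgetful causal masking (FCM) scheme used in EasyLM, which these fields resemble; treating it as EasyDeL's exact behavior is an assumption. fcm_mask is a hypothetical helper, and NumPy's random generator replaces JAX PRNG keys.

```python
import numpy as np


def fcm_mask(batch_size, seq_length, deterministic, fcm_min_ratio=0.0,
             fcm_max_ratio=0.0, rng=None):
    """Hypothetical forgetful-causal-masking sketch (EasyLM-style)."""
    if deterministic or fcm_max_ratio <= 0:
        return None  # no extra masking at eval time or when FCM is disabled
    rng = np.random.default_rng() if rng is None else rng
    # One drop ratio per example, sampled in [fcm_min_ratio, fcm_max_ratio).
    ratio = rng.uniform(fcm_min_ratio, fcm_max_ratio, size=(batch_size, 1, 1, 1))
    keep = rng.uniform(size=(batch_size, 1, seq_length, seq_length)) > ratio
    keep[:, :, :, 0] = True  # every query can still see the first token
    return keep  # intended to be combined with the causal mask via logical AND


mask = fcm_mask(2, 8, deterministic=False, fcm_min_ratio=0.1,
                fcm_max_ratio=0.3, rng=np.random.default_rng(0))
print(mask.shape)  # (2, 1, 8, 8)
```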
get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
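Partition rules pair a regular expression with a PartitionSpec. A minimal sketch of resolving a parameter path against such rules, assuming the first matching rule wins; the rules, parameter paths and spec_for are hypothetical, and plain tuples stand in for jax.sharding.PartitionSpec so the example runs without JAX.

```python
import re

# Hypothetical rules: (regex, spec). Tuples stand in for PartitionSpec;
# the empty tuple plays the role of PartitionSpec() (fully replicated).
PARTITION_RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r"mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r".*", ()),  # fallback: replicate anything not matched above
)


def spec_for(param_path, rules=PARTITION_RULES):
    """Return the spec of the first rule whose regex matches param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


print(spec_for("model/layers/0/self_attn/q_proj/kernel"))  # (('fsdp', 'sp'), 'tp')
print(spec_for("model/norm/scale"))  # ()
```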

property granted_freq_max_position_embedding: int
property granted_mask_max_position_embedding: int
jax_mesh()
property mesh

The mesh property is a helper that creates a Mesh object from the config's axis_dims and axis_names attributes, which are expected to be sequences of integers and strings, respectively. The platform attribute is also used if it is set.

Parameters

self – The config instance.

Returns

A jax.sharding.Mesh

to_dict() → Dict[str, Any]

Serializes this instance to a Python dictionary.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

Dict[str, Any]

class easydel.infra.base_config.EasyDeLBaseConfigDict

Bases: TypedDict

class easydel.infra.base_config.EasyMethod(TRAIN: str = 'train', SERVE: str = 'serve', EVAL: str = 'serve', CONVERT: str = 'convert')

Bases: object

CONVERT: str = 'convert'
EVAL: str = 'serve'
SERVE: str = 'serve'
TRAIN: str = 'train'
easydel.infra.base_config.set_attrs_smartly(self, attr_name: str, default: Any, new_attr: Any)
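set_attrs_smartly has no description. Judging from its (attr_name, default, new_attr) signature and the Ellipsis defaults of add_basic_configurations, its behavior is plausibly along these lines; set_attrs_smartly_sketch is a hypothetical reconstruction, not the library source.

```python
from typing import Any


def set_attrs_smartly_sketch(obj: Any, attr_name: str, default: Any, new_attr: Any) -> None:
    """Hypothetical reconstruction: initialize if missing, override unless Ellipsis."""
    if not hasattr(obj, attr_name):
        setattr(obj, attr_name, default)   # first use: fall back to the default
    if new_attr is not Ellipsis:
        setattr(obj, attr_name, new_attr)  # explicit value always wins


class Config:
    pass


cfg = Config()
set_attrs_smartly_sketch(cfg, "blocksize_q", 128, ...)  # missing, so default 128
set_attrs_smartly_sketch(cfg, "blocksize_q", 128, 256)  # explicit value wins
set_attrs_smartly_sketch(cfg, "blocksize_q", 128, ...)  # Ellipsis keeps 256
print(cfg.blocksize_q)  # 256
```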