easydel.infra.__init__


class easydel.infra.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', 
attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)[source]#

Bases: PretrainedConfig

Initialize the configuration for EasyDeL.

Parameters
  • axis_dims (tp.Sequence[int]) – Dimensions of the axes. Default is (1, -1, 1, 1).

  • axis_names (tp.Sequence[str]) – Names of the axes. Default is (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is DEFAULT_ATTENTION_MECHANISM.

  • blocksize_k (int) – Block size for key. Default is 128.

  • blocksize_q (int) – Block size for query. Default is 128.

  • blocksize_b (int) – Block size for batch. Default is 1.

  • partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().

  • shard_attention_computation (bool) – Whether to shard attention computation. Default is True.

  • use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.

  • use_sharding_constraint (bool) – Whether to use sharding constraint. Default is False.

  • backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.

  • platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.

  • easy_method (tp.Literal[“train”, “serve”, “convert”]) – Method to use. Default is EasyMethod.TRAIN.

  • bits (tp.Optional[int]) – Number of bits for quantization. Default is None.

  • scan_ring_attention (bool) – Whether to scan ring attention. Default is True.

  • scan_attention_layers (bool) – Whether to scan attention layers. Default is False.

  • use_scan_mlp (bool) – Whether to use scan MLP. Default is False.

  • scan_mlp_chunk_size (int) – Chunk size for scan MLP. Default is 1024.

  • sequence_axis_name (str) – Name of the attention axis. Default is “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.

  • quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.

  • quantization_pattern (str) – Pattern for quantization. Default is “.*”.

  • quantization_blocksize (int) – Block size for quantization. Default is 64.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, …]]) – Name of the key-value cache sharding sequence axis. Default is “sp”.

  • flash_attention_backward_pass_impl (tp.Literal[“triton”, “xla”]) – Implementation for flash attention backward pass. Default is “triton”.

  • attn_dtype (jnp.dtype) – Data type for attention. Default is device half.

  • attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.

  • fcm_max_ratio (float) – Maximum ratio for FCM. Default is 0.0.

  • fcm_min_ratio (float) – Minimum ratio for FCM. Default is 0.0.

  • hardware_abstraction (bool) – Whether to use hardware abstraction. Default is DEFAULT_HARDWARE_ABSTRACTION.

  • pallas_m_block_size (int) – Block size for Pallas M. Default is DEFAULT_PALLAS_M_BLOCK_SIZE.

  • pallas_k_block_size (int) – Block size for Pallas K. Default is DEFAULT_PALLAS_K_BLOCK_SIZE.

  • pallas_n_block_size (int) – Block size for Pallas N. Default is DEFAULT_PALLAS_N_BLOCK_SIZE.

  • **kwargs – Additional keyword arguments.

Raises

Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
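The condition above can be sketched as a small check. This is only an illustration of the documented rule: `QuantMethod` is a hypothetical stand-in for EasyDeLQuantizationMethods (its string values are made up), and the warning text is not the library's.

```python
import warnings
from enum import Enum


class QuantMethod(Enum):
    """Hypothetical stand-in for EasyDeLQuantizationMethods."""

    NONE = "none"
    A8BIT = "8bit"  # illustrative value, not taken from the library


def check_kv_cache_options(kv_cache_quantization_method, use_sharded_kv_caching):
    """Warn when KV-cache quantization is combined with sharded KV caching.

    Returns True if the two options conflict, False otherwise.
    """
    conflict = bool(
        kv_cache_quantization_method is not QuantMethod.NONE and use_sharded_kv_caching
    )
    if conflict:
        warnings.warn(
            "kv_cache_quantization_method is set while use_sharded_kv_caching is True",
            UserWarning,
        )
    return conflict
```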

add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis, **kwargs)[source]#

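Sets the basic EasyDeL configuration attributes on this config instance. Every argument defaults to Ellipsis in the signature; the defaults listed below are the documented values of the corresponding attributes.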

Parameters
  • axis_dims (tp.Sequence[int], optional) – Specify the number of dimensions for each axis. Defaults to (1, -1, 1, 1).

  • axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).

  • attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM.

  • blocksize_k (int, optional) – block size of key_states. Defaults to 128.

  • blocksize_q (int, optional) – block size of query_states. Defaults to 128.

  • blocksize_b (int, optional) – block size of bias. Defaults to 1.

  • partition_axis (PartitionAxis, optional) – PartitionAxis configuration used for partitioning arrays in EasyDeL. Defaults to PartitionAxis().

  • shard_attention_computation (bool, optional) – whether to use shard_map for attention. Defaults to True.

  • use_sharded_kv_caching (bool, optional) – whether to use shard_map and sharding for key and value. Defaults to False.

  • backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.

  • platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.

  • easy_method (tp.Literal["train", "serve", "convert"], optional) – EasyDeL method that quantization is applied for. Defaults to EasyMethod.TRAIN.

  • bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.

  • scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.

  • scan_attention_layers (bool, optional) – Whether to use scan for attention layers. Defaults to False.

  • use_sharding_constraint (bool, optional) – whether to use sharding constraint for the arrays. Defaults to False.

  • use_scan_mlp (bool, optional) – Determine whether to use scan_mlp or not. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.

  • sequence_axis_name (str, optional) – Name of the attention axis. Defaults to “sp”.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (applied to the MLP and attention layers in most cases). Defaults to EasyDeLGradientCheckPointers.NONE.

  • kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – key and value quantization type. Defaults to EasyDeLQuantizationMethods.NONE.

  • kv_cache_quantization_blocksize (int, optional) – size of kv cache quantization. Defaults to 64.

  • quantization_method (EasyDeLQuantizationMethods, optional) – linear modules quantization type. Defaults to EasyDeLQuantizationMethods.NONE.

  • quantization_blocksize (int, optional) – size of linear quantization. Defaults to 64.

  • quantization_pattern (str) – Regular expression pattern used to select the layers to quantize. Defaults to “.*”.

  • kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – axis name to target for sharding sequences. Defaults to “sp”.

  • flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.

  • attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to device half.

  • attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.

  • fcm_max_ratio (float, optional) – Maximum ratio for flash cross attention. Defaults to 0.0.

  • fcm_min_ratio (float, optional) – Minimum ratio for flash cross attention. Defaults to 0.0.

  • hardware_abstraction (bool, optional) – whether to switch to custom Pallas kernels instead of JAX. Defaults to DEFAULT_HARDWARE_ABSTRACTION.

  • pallas_m_block_size (int, optional) – block size of the m dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE.

  • pallas_k_block_size (int, optional) – block size of the k dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE.

  • pallas_n_block_size (int, optional) – block size of the n dimension in the Pallas matmul kernel A(mk)@B(kn)=C(mn). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE.
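The signature uses Ellipsis, not None, as the default for every argument. A plausible reason (an assumption here, not something this page states) is that None is already a valid value for arguments such as bits or backend, so a separate sentinel is needed to mean "not provided". A minimal sketch of that pattern, with a hypothetical helper and a hypothetical config class:

```python
def set_attr_unless_ellipsis(obj, name, default, new_value=...):
    """Set `name` to `default` if it is missing, then override it
    unless `new_value` is the Ellipsis sentinel (meaning "not provided")."""
    if not hasattr(obj, name):
        setattr(obj, name, default)
    if new_value is not ...:
        setattr(obj, name, new_value)


class DemoConfig:
    """Hypothetical config object; not EasyDeLBaseConfig."""


cfg = DemoConfig()
set_attr_unless_ellipsis(cfg, "blocksize_q", 128)       # attribute missing: default applied
set_attr_unless_ellipsis(cfg, "blocksize_q", 128, 256)  # explicit value overrides
set_attr_unless_ellipsis(cfg, "blocksize_q", 128)       # Ellipsis: existing 256 is kept
set_attr_unless_ellipsis(cfg, "bits", 8, None)          # None is a real value, so it is stored
```

With this pattern a second call that omits an argument leaves the previously configured value untouched.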

attach_custom_arguments(**kwargs)[source]#
static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)[source]#

The create_mesh function creates a mesh object that can be used to shard arrays.

Returns

A mesh object
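The default axis_dims of (1, -1, 1, 1) contains a -1 entry. The sketch below assumes that -1 behaves as it does in a NumPy reshape, that is, "use all remaining devices on this axis"; this page does not state that explicitly. `resolve_axis_dims` is a hypothetical helper, not part of EasyDeL.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 entry so the product of all dims equals num_devices."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    # Product of the explicitly sized axes (1 if there are none).
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if known == 0 or num_devices % known != 0:
            raise ValueError(f"{num_devices} devices cannot be split by {known}")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError(f"axis_dims use {known} devices, but {num_devices} are available")
    return tuple(dims)


# With axis names ('dp', 'fsdp', 'tp', 'sp') and 8 devices, all 8 land on 'fsdp'.
print(resolve_axis_dims((1, -1, 1, 1), 8))
```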

classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[bool, str]] = None, revision: str = 'main', **kwargs) PretrainedConfig[source]#

Instantiate a [PretrainedConfig] (or a derived class) from a pretrained model configuration.

Parameters
  • pretrained_model_name_or_path (str or os.PathLike) –

    This can be either:

    • a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.

    • a path to a directory containing a configuration file saved using the [~PretrainedConfig.save_pretrained] method, e.g., ./my_model_directory/.

    • a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.

  • cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

  • force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.

  • resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

  • proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.

  • token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).

  • revision (str, optional, defaults to “main”) –

    The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

    Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

  • return_unused_kwargs (bool, optional, defaults to False) –

    If False, then this function returns just the final configuration object.

    If True, then this function returns a tuple (config, unused_kwargs), where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.

  • subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

  • kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.

Returns

The configuration object instantiated from this pretrained model.

Return type

[PretrainedConfig]

Examples:

>>> # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
>>> # derived class: BertConfig
>>> from transformers import BertConfig
>>> config = BertConfig.from_pretrained(
...   "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...   "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...  "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...  "google-bert/bert-base-uncased",
...  output_attentions=True,
...  foo=False,
...  return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}


get_axis_dims() Sequence[int][source]#

The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.

Parameters

self – Represent the instance of the class

Returns

The dimensions of the axes

get_axis_names() Sequence[str][source]#

The get_axis_names function returns a list of the names of the axes.

Parameters

self – Represent the instance of the class

Returns

A list of the names of all axes

get_backend() str[source]#

The get_backend function returns the backend that is currently being used. If no backend has been set, it will return the default JAX backend.

Parameters

self – Bind the method to an object

Returns

The backend platform

get_basic_causal_mask(*args, **kwargs)[source]#
get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) Any[source]#

Get basic frequencies for rotary embeddings.

Parameters
  • head_size – Size of attention heads (defaults to self.head_dim)

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

ModuleCaches instance containing computed frequencies
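For orientation, the standard rotary-embedding frequency formula is sketched below in plain Python. This is the textbook definition, inv_freq[i] = base ** (-2i / rotary_dim); whether get_basic_frequencies applies exactly this formula, or additional scaling, is not stated on this page. The default base of 10000.0 is a common choice used here as an assumption (the method itself falls back to self.rope_theta), and both function names are hypothetical.

```python
def rotary_inverse_frequencies(rotary_dim, base=10000.0):
    """Return base ** (-2i / rotary_dim) for i in range(rotary_dim // 2)."""
    return [base ** (-(2 * i) / rotary_dim) for i in range(rotary_dim // 2)]


def rotary_angles(position, rotary_dim, base=10000.0):
    """Rotation angle for each frequency at a given token position."""
    return [position * freq for freq in rotary_inverse_frequencies(rotary_dim, base)]


freqs = rotary_inverse_frequencies(rotary_dim=8)
angles = rotary_angles(position=3, rotary_dim=8)
```

The first frequency is always 1.0, and later ones decay geometrically, so low-index dimension pairs rotate quickly with position while high-index pairs rotate slowly.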

get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)[source]#

Get basic rotary position embeddings.

Parameters
  • dtype – Data type for the embeddings

  • head_size – Size of attention heads

  • rotary_dim – Dimension for rotary embeddings (defaults to head_size)

  • is_neox_style – Whether to use NeoX style embeddings

  • base – Base value for frequency computation (defaults to self.rope_theta)

Returns

Rotary position embeddings func

get_fcm_mask(batch_size, seq_length, deterministic: bool)[source]#
get_partition_rules(*args, **kwargs)[source]#

Get the partition rules for the model.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
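The return type is a tuple of (pattern, PartitionSpec) pairs. The sketch below shows one way such rules could be matched against parameter paths. It assumes first-match-wins regex search, which this page does not confirm, and uses plain tuples of mesh-axis names in place of PartitionSpec so that it runs without JAX. The rules and parameter paths are hypothetical.

```python
import re

# Hypothetical rules: (regex, spec). Tuples of mesh-axis names stand in for PartitionSpec.
RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r".*", (None,)),  # catch-all: leave everything else unpartitioned
)


def match_partition_rule(rules, param_path):
    """Return the spec of the first rule whose pattern is found in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


spec = match_partition_rule(RULES, "model/layers/0/self_attn/q_proj/kernel")
```

Ordering matters under this assumption: the catch-all pattern must come last, or it would shadow the more specific rules.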

property granted_freq_max_position_embedding: int#
property granted_mask_max_position_embedding: int#
jax_mesh()[source]#
property mesh#

The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be lists of integers and strings, respectively. The platform attribute is also used if it exists.

Parameters

self – Refer to the object itself

Returns

A JAX Mesh object

read_basics_from_config(config: EasyDeLBaseConfig)[source]#
to_dict() Dict[str, Any][source]#

Serializes this instance to a Python dictionary.

Returns

Dictionary of all the attributes that make up this configuration instance.

Return type

Dict[str, Any]

class easydel.infra.__init__.EasyDeLBaseConfigDict[source]#

Bases: TypedDict

class easydel.infra.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)[source]#

Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin

Base class for EasyDeL modules, providing common functionalities for model initialization, parameter handling, and integration with the EasyDeL ecosystem.

apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF[source]#

Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
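As background on what lora_rank controls: in standard LoRA, a frozen weight of shape (in_features, out_features) is augmented with two trainable low-rank factors of shapes (in_features, lora_rank) and (lora_rank, out_features). The arithmetic below compares parameter counts under that standard formulation. It is a sketch of the general technique, not of how apply_lora_to_layers stores its parameters, and the helper name is hypothetical.

```python
def lora_param_counts(in_features, out_features, lora_rank):
    """Return (full_weight_params, lora_trainable_params) for one linear layer."""
    full = in_features * out_features
    # Factor A: in_features x rank, factor B: rank x out_features.
    lora = lora_rank * (in_features + out_features)
    return full, lora


full, lora = lora_param_counts(in_features=4096, out_features=4096, lora_rank=8)
ratio = lora / full  # fraction of the layer's weights that are trainable
```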

property causal_mask: Array#

Returns a causal mask from the config.

compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics][source]#

basic compute_loss call

float(change_runtime_dtype: bool = True) SELF[source]#

Converts model parameters to float32.

property frequencies: Array#

Returns frequency values from the config.

fully_gather() SELF[source]#
fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF[source]#
gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Gathers the model’s parameters based on the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for gathering.

  • mesh (jax.sharding.Mesh, optional) – The mesh to gather from.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.

Returns

The gathered model.

Return type

EasyDeLBaseModule

get_static_arguments() Tuple[source]#

return static arguments kwargs for jax.jit

property graphdef: Union[NodeDef[Node], NodeRef[Node]]#
property graphother: State[Key, VariableState[Any]]#
property graphstate: State[Key, VariableState[Any]]#
property graphtree_params_shape: Dict#

Evaluates the shape of the model’s parameters and returns a dictionary.

property graphtree_shape: Dict#

Evaluates the shape of the model and returns a dictionary.

half(change_runtime_dtype: bool = True) SELF[source]#

Converts model parameters to float16.

classmethod lazy_init(*args, **kwargs) SELF[source]#

initialize the base class with nnx.eval_shape carefully

property loss_function#
merge_lora_params(pytree: Dict) SELF[source]#

Merge Given Pytree (LoRA Params) with current LoRA Module.

merge_params(tree)[source]#

merge state to the current model

merge_params_dict(params_dict: Dict) SELF[source]#

Merges the model parameters from a dictionary into the current model.

Parameters

params_dict (tp.Dict) – A dictionary containing the parameters to merge.

Returns

The model with merged parameters.

Return type

EasyDeLBaseModule

property mesh: Mesh#

Returns the mesh from the config.

property model_task: Optional[str]#

Returns the model task.

property model_type: Optional[str]#

Returns the model type.

property module_dtype: dtype#
property parameters: Dict#
property params: Dict#
property params_sharding: Dict#

return the sharding of the model parameters

prepare_inputs_for_call(**kwargs)[source]#

update inputs for calling model

property pure_transform_fn#

generates a pure transform function for converting torch to easydel module.

quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF[source]#

Quantizes the model’s linear layers.

Parameters
  • method (EasyDeLQuantizationMethods, optional) – The quantization method to use.

  • block_size (int, optional) – The block size for quantization.

  • quantization_pattern (str, optional) – The quantization pattern to use.

  • quantize_tensors (bool) – Whether to quantize tensors or quantize linear layers.

  • verbose (bool, optional) – Whether to log the quantization process verbosely.

Returns

The quantized model.

Return type

EasyDeLBaseModule
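To illustrate how a quantization_pattern can narrow which layers are affected, the sketch below filters layer paths with a regular expression. The layer paths are hypothetical, and the use of a full match against the whole path is an assumption; this page only says the pattern is a regular expression.

```python
import re


def select_layers(layer_paths, quantization_pattern=".*"):
    """Return the layer paths that fully match the given regex pattern."""
    regex = re.compile(quantization_pattern)
    return [path for path in layer_paths if regex.fullmatch(path)]


# Hypothetical layer paths, for illustration only.
paths = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/mlp/up_proj",
    "lm_head",
]

everything = select_layers(paths)             # ".*" selects every layer
mlp_only = select_layers(paths, r".*mlp.*")   # only the MLP projection
```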

shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF[source]#

Shards the model’s parameters using the specified partitioning rules and mesh.

Parameters
  • partition_rules (PartitionLike, optional) – Partitioning rules for sharding.

  • mesh (jax.sharding.Mesh, optional) – The mesh to shard across.

  • overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]]) – Overlay functions to apply to the model.

Returns

The sharded model.

Return type

EasyDeLBaseModule

split_lora_params() Dict[source]#

split Given Module (LoRA Module) and return LoRA Params.

split_params()[source]#

split the model parameters

split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict[source]#

Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.

Parameters
  • extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.

  • remove_none (bool, optional) – Whether to remove None values from the dictionary.

Returns

The dictionary of split parameters.

Return type

tp.Dict

property static_arguments: Tuple#
to_dtype(dtype: dtype) SELF[source]#

Converts the model’s parameters to the given dtype.

to_state() Any[source]#

converts current model to an EasyDeLState

to_torch(**kwargs)[source]#

converts current model to a huggingface torch model

property transform_fn#

generate transform function for converting torch to easydel module.

unwrap_lora_to_layers(verbose: bool = False) SELF[source]#

UnWrap LoRA (Low-Rank Adaptation) from specified linear layers within a model.

class easydel.infra.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)[source]#

Bases: PyTreeNode

EasyDeLState: A Snapshot of Your EasyDeL Model

The EasyDeLState class acts as a container that holds all the essential information about your EasyDeL model at a given point in time. Think of it as a snapshot of your model. It includes the fields documented below: step, graphdef, graphstate, graphother, tx, opt_state, and apply_fn.

apply_fn: tp.Optional[tp.Callable] = None#
apply_gradients(*, grads)[source]#

Applies gradients to the model parameters and updates the optimizer state. This function is typically called during training to update the model based on the computed gradients.

Parameters

grads – A dictionary of gradients, where keys correspond to model parameters.

Returns

An updated EasyDeLState object with modified parameters and optimizer state.

Return type

EasyDeLState
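Because apply_gradients returns an updated state instead of mutating the current one, a training loop rebinds the state on every step. The sketch below shows that pattern with a frozen dataclass. `TinyState` is a hypothetical stand-in, plain SGD replaces the optax transformation, and incrementing step on each update is an assumption not stated on this page.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyState:
    """Hypothetical stand-in for an immutable training state; not EasyDeLState."""

    step: int
    params: dict
    learning_rate: float = 0.1

    def apply_gradients(self, *, grads):
        """Return a new state with one SGD step applied; self is left unchanged."""
        new_params = {
            name: value - self.learning_rate * grads[name]
            for name, value in self.params.items()
        }
        return replace(self, step=self.step + 1, params=new_params)


state = TinyState(step=0, params={"w": 1.0, "b": 0.5})
new_state = state.apply_gradients(grads={"w": 2.0, "b": -1.0})
```

The original `state` still holds the old parameters afterwards, which is what makes it safe to keep earlier snapshots around.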

classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState[source]#

Create an instance with flexible initialization options.

Parameters
  • step – Optional number of training steps.

  • graphdef – Optional graph definition.

  • graphstate – Optional graph state.

  • graphother – Optional graph others.

  • model – Optional neural network module.

  • tx – Optional gradient transformation.

  • opt_state – Optional optimizer state.

Raises

ValueError – If initialization parameters are inconsistent.

gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#

Gathers the model according to the provided partition rules.

Returns

An updated EasyDeLState object with the gathered model.

Return type

EasyDeLState

gather_optimizer_state(partition_rules=None)[source]#
gather_state()[source]#

Gathers the entire state.

Returns

An updated EasyDeLState object with the gathered state.

Return type

EasyDeLState

graphdef: nn.GraphDef#
graphother: nn.GraphState#
graphstate: nn.GraphState#
init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState[source]#

Initialize the optimizer state with the given gradient transformation.

Parameters
  • tx (optax.GradientTransformation) – A gradient transformation to initialize the optimizer state.

  • partition_rules (Optional[Any], optional) – Rules for partitioning the optimizer state. Defaults to None.

Returns

An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.

Return type

EasyDeLState

load_optimizer(load_directory: Union[str, PathLike])[source]#
load_state(load_directory: Union[str, PathLike], verbose: bool = True)[source]#
merge(tree) Any[source]#
merge_to_state(tree) EasyDeLState[source]#
property model: Any#
opt_state: tp.Optional[optax.OptState]#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)[source]#
shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState[source]#

Shards the model according to the provided partition rules.

Parameters
  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

  • mesh (Optional[Mesh]) – The mesh to be used for sharding. If None, the method will use the mesh from self.model.

Returns

An updated EasyDeLState object with the sharded model.

Return type

EasyDeLState

shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any[source]#

Shards the optimizer state according to the provided partition rules.

Parameters
  • opt_state (Optional[Any]) – The optimizer state to be sharded. If None, the method will use self.opt_state. Raises a ValueError if both opt_state and self.opt_state are None.

  • partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

The sharded optimizer state.

Return type

Any

Raises

ValueError – If both opt_state and self.opt_state are None.

shard_state(partition_rules: Any = None) EasyDeLState[source]#

Shards the entire state, according to the provided partition rules.

Parameters

partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.

Returns

An updated EasyDeLState object with the sharded state.

Return type

EasyDeLState

shard_with_shape(shape) EasyDeLState[source]#

shard current state with a given shape

property shardings#

Returns the sharding information for the state.

Returns

The sharding information.

Return type

Any

property size: int#

Calculates the total size of the optimizer state and model graph state.

Returns

The total size in bytes.

Return type

int

step: int | jax.Array#
tx: optax.GradientTransformation#
class easydel.infra.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Union[float, int, str, easydel.infra.loss_utils.SpecialLossNormalizingFactor, NoneType] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)[source]#

Bases: Mapping

break_on_nan: bool = True#
classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None#
divide_weight_sum: bool = False#
from_tuple()#
ignore_index: int = -100#
items() → a set-like object providing a view on D's items#
keys() → a set-like object providing a view on D's keys#
label_smoothing: float = 0.0#
loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'#
num_classification_labels: Optional[int] = None#
num_labels: Optional[str] = None#
problem_type: Optional[str] = None#
reduction: Optional[Literal['none', 'mean', 'sum']] = None#
replace(**kwargs)#
shift_tokens: bool = True#
to_tuple()#
values() → an object providing a view on D's values#
z_loss: float = 0.0#
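The fields above carry no description, so the following is only a hedged sketch of how a few of them are commonly interpreted. It assumes that ignore_index marks target positions to skip, that a 'NUM_REAL_TARGET_TOKENS' normalizing factor means dividing by the number of non-ignored targets, and that z_loss scales a squared log-partition penalty. The function masked_cross_entropy is hypothetical and written in pure Python for readability; it is not EasyDeL's loss implementation.

```python
import math
from typing import Sequence


def masked_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_index: int = -100,
    z_loss: float = 0.0,
) -> float:
    """Mean cross-entropy over positions whose label is not ignore_index."""
    total = 0.0
    num_real = 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        peak = max(row)
        # Numerically stable log-sum-exp of the logits (the log-partition value).
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]      # negative log-likelihood of the target
        total += z_loss * log_z ** 2     # assumed auxiliary z-loss penalty
        num_real += 1
    # Assumed NUM_REAL_TARGET_TOKENS behaviour: normalize by non-ignored targets.
    return total / max(num_real, 1)


logits = [[0.0, 0.0], [5.0, 1.0], [0.0, 0.0]]
labels = [0, -100, 1]  # the middle position is ignored
loss = masked_cross_entropy(logits, labels)
loss_with_z = masked_cross_entropy(logits, labels, z_loss=0.1)
print(loss, loss_with_z)
```

With two uniform-logit positions and one ignored position, the plain loss reduces to log 2, which makes the effect of masking easy to check by hand.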
class easydel.infra.__init__.PartitionAxis(batch_axis: Optional[Union[Tuple[str, ...], str]] = ('fsdp', 'dp'), sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', hidden_state_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None)[source]#

Bases: NamedTuple

A NamedTuple representing different axes of partitioning in a model.

Each field represents an axis and its corresponding partitioning strategy. The value of each field can be:

  • None: The axis is not partitioned.

  • str: The name of the single mesh dimension across which the axis is partitioned.

  • Tuple[str, …]: A tuple of mesh dimension names, indicating a sharding strategy where the axis is split across multiple mesh dimensions.
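The three value kinds listed above can be illustrated with a local stand-in. PartitionAxisSketch below is a hypothetical NamedTuple that copies the documented field names and defaults so the example runs without the library installed; describe is likewise an illustrative helper, not part of the API.

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisType = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisSketch(NamedTuple):
    """Local stand-in mirroring the documented fields and defaults."""

    batch_axis: AxisType = ("fsdp", "dp")
    sequence_axis: AxisType = "sp"
    query_sequence_axis: AxisType = "sp"
    head_axis: AxisType = "tp"
    key_sequence_axis: AxisType = "sp"
    hidden_state_axis: AxisType = "tp"
    attention_dim_axis: AxisType = None
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None
    generation_query_sequence_axis: AxisType = None
    generation_head_axis: AxisType = "tp"
    generation_key_sequence_axis: AxisType = "sp"
    generation_attention_dim_axis: AxisType = None


def describe(axis: AxisType) -> str:
    """Render one field value according to the three documented cases."""
    if axis is None:
        return "not partitioned"
    if isinstance(axis, str):
        return f"partitioned across mesh dimension {axis!r}"
    return "partitioned across mesh dimensions " + ", ".join(repr(a) for a in axis)


paxis = PartitionAxisSketch()
for name, value in paxis._asdict().items():
    print(f"{name}: {describe(value)}")

# NamedTuples are immutable; _replace returns a modified copy.
no_sp = paxis._replace(
    sequence_axis=None, query_sequence_axis=None, key_sequence_axis=None
)
```

Because the type is a NamedTuple, fields are also reachable by position (for example, index 0 is batch_axis), and customizing a strategy means building a new tuple rather than mutating the existing one.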

batch_axis#

Partitioning strategy for the batch dimension. Defaults to (“fsdp”, “dp”).

Type

Optional[Union[Tuple[str, …], str]]

sequence_axis#

Partitioning strategy for the sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

query_sequence_axis#

Partitioning strategy for the query sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

head_axis#

Partitioning strategy for the attention head dimension. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

key_sequence_axis#

Partitioning strategy for the key sequence dimension. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

hidden_state_axis#

Partitioning strategy for the hidden state dimension. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

attention_dim_axis#

Partitioning strategy for the attention dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

bias_head_sequence_axis#

Partitioning strategy for the bias head sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

bias_key_sequence_axis#

Partitioning strategy for the bias key sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

generation_query_sequence_axis#

Partitioning strategy for the query sequence dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

generation_head_axis#

Partitioning strategy for the attention head dimension during generation. Defaults to “tp”.

Type

Optional[Union[Tuple[str, …], str]]

generation_key_sequence_axis#

Partitioning strategy for the key sequence dimension during generation. Defaults to “sp”.

Type

Optional[Union[Tuple[str, …], str]]

generation_attention_dim_axis#

Partitioning strategy for the attention dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, …], str]]

attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 6

batch_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 0

bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 7

bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 8

generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 12

generation_head_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 10

generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 11

generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 9

head_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 3

hidden_state_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 5

key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 4

query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 2

sequence_axis: Optional[Union[Tuple[str, ...], str]]#

Alias for field number 1