easydel.infra.__init__
- class easydel.infra.__init__.EasyDeLBaseConfig(axis_dims: ~typing.Sequence[int] = (1, -1, 1, 1), dcn_axis_dims: ~typing.Optional[~typing.Sequence[int]] = None, axis_names: ~typing.Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), attn_mechanism: ~typing.Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = 'vanilla', blocksize_k: int = 128, blocksize_q: int = 128, blocksize_b: int = 1, partition_axis: ~eformer.escale.partition.constraints.PartitionAxis = PartitionAxis(batch_axis=('fsdp', 'dp'), sequence_axis='sp', query_sequence_axis='sp', head_axis='tp', key_sequence_axis='sp', hidden_state_axis='tp', attention_dim_axis=None, bias_head_sequence_axis=None, bias_key_sequence_axis=None, generation_query_sequence_axis=None, generation_head_axis='tp', generation_key_sequence_axis='sp', generation_attention_dim_axis=None), shard_attention_computation: bool = True, use_sharded_kv_caching: bool = False, use_sharding_constraint: bool = False, backend: ~typing.Optional[~easydel.infra.etils.EasyDeLBackends] = None, platform: ~typing.Optional[~easydel.infra.etils.EasyDeLPlatforms] = None, easy_method: ~typing.Literal['train', 'serve', 'convert'] = 'train', bits: ~typing.Optional[int] = None, scan_ring_attention: bool = True, scan_attention_layers: bool = False, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, sequence_axis_name: str = 'sp', gradient_checkpointing: ~easydel.infra.etils.EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, kv_cache_quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, kv_cache_quantization_blocksize: int = 64, quantization_method: ~easydel.infra.etils.EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NONE, quantization_pattern: str = '.*', quantization_blocksize: int = 64, kv_cache_sharding_sequence_axis_name: ~typing.Union[str, ~typing.Tuple[str, ...]] = 'sp', flash_attention_backward_pass_impl: ~typing.Literal['triton', 'xla'] = 'triton', 
attn_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, attn_softmax_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, fcm_max_ratio: float = 0.0, fcm_min_ratio: float = 0.0, hardware_abstraction: bool = False, pallas_m_block_size: int = 128, pallas_k_block_size: int = 128, pallas_n_block_size: int = 128, **kwargs)
Bases: PretrainedConfig
Initialize the configuration for EasyDeL.
- Parameters
axis_dims (tp.Sequence[int]) – Dimensions of the mesh axes. Default is (1, -1, 1, 1).
axis_names (tp.Sequence[str]) – Names of the mesh axes. Default is (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS) – Attention mechanism to use. Default is DEFAULT_ATTENTION_MECHANISM (“vanilla”).
blocksize_k (int) – Block size for keys. Default is 128.
blocksize_q (int) – Block size for queries. Default is 128.
blocksize_b (int) – Block size for the batch. Default is 1.
partition_axis (PartitionAxis) – Partition axis configuration. Default is PartitionAxis().
shard_attention_computation (bool) – Whether to shard the attention computation. Default is True.
use_sharded_kv_caching (bool) – Whether to use sharded key-value caching. Default is False.
use_sharding_constraint (bool) – Whether to apply sharding constraints. Default is False.
backend (tp.Optional[EasyDeLBackends]) – Backend to use. Default is None.
platform (tp.Optional[EasyDeLPlatforms]) – Platform to use. Default is None.
easy_method (tp.Literal[“train”, “serve”, “convert”]) – Mode in which the model is used. Default is “train” (EasyMethod.TRAIN).
bits (tp.Optional[int]) – Number of bits for quantization. Default is None.
scan_ring_attention (bool) – Whether to use scan for ring attention. Default is True.
scan_attention_layers (bool) – Whether to use scan for the attention layers. Default is False.
use_scan_mlp (bool) – Whether to use the scanned MLP. Default is False.
scan_mlp_chunk_size (int) – Chunk size for the scanned MLP. Default is 1024.
sequence_axis_name (str) – Name of the sequence axis. Default is “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing method. Default is EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods) – Key-value cache quantization method. Default is EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int) – Block size for key-value cache quantization. Default is 64.
quantization_method (EasyDeLQuantizationMethods) – Quantization method. Default is EasyDeLQuantizationMethods.NONE.
quantization_pattern (str) – Regular-expression pattern selecting the layers to quantize. Default is “.*”.
quantization_blocksize (int) – Block size for quantization. Default is 64.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]]) – Name of the axis used to shard the sequence dimension of the key-value cache. Default is “sp”.
flash_attention_backward_pass_impl (tp.Literal[“triton”, “xla”]) – Implementation of the flash attention backward pass. Default is “triton”.
attn_dtype (jnp.dtype) – Data type for attention. Default is jnp.float32.
attn_softmax_dtype (jnp.dtype) – Data type for softmax ops in attention. Default is jnp.float32.
fcm_max_ratio (float) – Maximum ratio for forgetful causal masking (FCM). Default is 0.0.
fcm_min_ratio (float) – Minimum ratio for forgetful causal masking (FCM). Default is 0.0.
hardware_abstraction (bool) – Whether to use hardware abstraction (custom Pallas kernels instead of JAX). Default is DEFAULT_HARDWARE_ABSTRACTION (False).
pallas_m_block_size (int) – Block size of the M dimension for Pallas kernels. Default is DEFAULT_PALLAS_M_BLOCK_SIZE (128).
pallas_k_block_size (int) – Block size of the K dimension for Pallas kernels. Default is DEFAULT_PALLAS_K_BLOCK_SIZE (128).
pallas_n_block_size (int) – Block size of the N dimension for Pallas kernels. Default is DEFAULT_PALLAS_N_BLOCK_SIZE (128).
**kwargs – Additional keyword arguments.
- Raises
Warning – If kv_cache_quantization_method is not NONE and use_sharded_kv_caching is True.
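The pallas_m_block_size, pallas_k_block_size and pallas_n_block_size options name the tile sizes of a matrix multiplication A(m×k) @ B(k×n) = C(m×n). As a rough illustration only, the pure-Python sketch below tiles the three loops the same way; tiled_matmul is a hypothetical helper, not EasyDeL's Pallas kernel, which runs these tiles on the accelerator.

```python
def tiled_matmul(a, b, block_m=2, block_k=2, block_n=2):
    """Compute C = A @ B for nested lists, one (block_m x block_n) output tile at a time."""
    m, k = len(a), len(a[0])
    if len(b) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(b)}")
    n = len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, block_m):              # tiles along M
        for j0 in range(0, n, block_n):          # tiles along N
            for p0 in range(0, k, block_k):      # accumulate over tiles along K
                for i in range(i0, min(i0 + block_m, m)):
                    for j in range(j0, min(j0 + block_n, n)):
                        acc = c[i][j]
                        for p in range(p0, min(p0 + block_k, k)):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


a = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
b = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
print(tiled_matmul(a, b))         # [[58.0, 64.0], [139.0, 154.0]]
```

The block sizes never change the result, only how much of A, B and C is touched per step; a real kernel typically picks them so that tiles fit in fast on-chip memory.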
- add_basic_configurations(axis_dims: Sequence[int] = Ellipsis, dcn_axis_dims: Optional[Sequence[int]] = Ellipsis, axis_names: Sequence[str] = Ellipsis, attn_mechanism: Literal['vanilla', 'flash_attn2', 'splash', 'ring', 'cudnn', 'blockwise', 'sdpa'] = Ellipsis, blocksize_k: int = Ellipsis, blocksize_q: int = Ellipsis, blocksize_b: int = Ellipsis, partition_axis: PartitionAxis = Ellipsis, shard_attention_computation: bool = Ellipsis, use_sharded_kv_caching: bool = Ellipsis, backend: Optional[EasyDeLBackends] = Ellipsis, platform: Optional[EasyDeLPlatforms] = Ellipsis, easy_method: Literal['train', 'serve', 'convert'] = Ellipsis, bits: Optional[int] = Ellipsis, scan_ring_attention: bool = Ellipsis, scan_attention_layers: bool = Ellipsis, use_sharding_constraint: bool = Ellipsis, use_scan_mlp: bool = Ellipsis, scan_mlp_chunk_size: int = Ellipsis, sequence_axis_name: str = Ellipsis, gradient_checkpointing: EasyDeLGradientCheckPointers = Ellipsis, kv_cache_quantization_method: EasyDeLQuantizationMethods = Ellipsis, kv_cache_quantization_blocksize: int = Ellipsis, quantization_method: EasyDeLQuantizationMethods = Ellipsis, quantization_blocksize: int = Ellipsis, quantization_pattern: str = Ellipsis, kv_cache_sharding_sequence_axis_name: Union[str, Tuple[str, ...]] = Ellipsis, flash_attention_backward_pass_impl: Literal['triton', 'xla'] = Ellipsis, attn_dtype: dtype = Ellipsis, attn_softmax_dtype: dtype = Ellipsis, hardware_abstraction: bool = Ellipsis, pallas_m_block_size: int = Ellipsis, pallas_k_block_size: int = Ellipsis, pallas_n_block_size: int = Ellipsis)
Sets the basic EasyDeL configuration attributes on this configuration object.
- Parameters
axis_dims (tp.Sequence[int], optional) – Size of each mesh axis; an entry of -1 takes up the remaining devices. Defaults to (1, -1, 1, 1).
axis_names (tp.Sequence[str], optional) – Set the names of the axes. Defaults to (“dp”, “fsdp”, “tp”, “sp”).
attn_mechanism (AVAILABLE_ATTENTION_MECHANISMS, optional) – Attention mechanism to use. Defaults to DEFAULT_ATTENTION_MECHANISM (“vanilla”).
blocksize_k (int, optional) – block size of key_states. Defaults to 128.
blocksize_q (int, optional) – block size of query_states. Defaults to 128.
blocksize_b (int, optional) – block size of bias. Defaults to 1.
partition_axis (PartitionAxis, optional) – PartitionAxis is a new module used for partitioning arrays in EasyDeL. Defaults to PartitionAxis().
shard_attention_computation (bool, optional) – Whether to use shard_map for attention. Defaults to True.
use_sharded_kv_caching (bool, optional) – Whether to use shard_map and sharding for keys and values. Defaults to False.
backend (tp.Optional[EasyDeLBackends], optional) – Specify the backend to use. Defaults to None.
platform (tp.Optional[EasyDeLPlatforms], optional) – Specify the platform to use. Defaults to None.
easy_method (tp.Literal["train", "serve", "convert"], optional) – Mode in which the model is used (training, serving, or conversion). Defaults to “train” (EasyMethod.TRAIN).
bits (tp.Optional[int], optional) – Model bits for quantization. Defaults to None.
scan_ring_attention (bool, optional) – Whether to use scan for ring attention. Defaults to True.
scan_attention_layers (bool, optional) – Whether to use scan for attention layers. Defaults to False.
use_sharding_constraint (bool, optional) – whether to use sharding constraint for the arrays. Defaults to False.
use_scan_mlp (bool, optional) – Determine whether to use scan_mlp or not. Defaults to False.
scan_mlp_chunk_size (int, optional) – Size of chunks in scan MLP. Defaults to 1024.
sequence_axis_name (str, optional) – Name of the sequence axis. Defaults to “sp”.
gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing method for the created or loaded module (most often applied to the MLP and attention layers). Defaults to EasyDeLGradientCheckPointers.NONE.
kv_cache_quantization_method (EasyDeLQuantizationMethods, optional) – key and value quantization type. Defaults to EasyDeLQuantizationMethods.NONE.
kv_cache_quantization_blocksize (int, optional) – Block size for key-value cache quantization. Defaults to 64.
quantization_method (EasyDeLQuantizationMethods, optional) – Quantization type for linear modules. Defaults to EasyDeLQuantizationMethods.NONE.
quantization_blocksize (int, optional) – Block size for linear-layer quantization. Defaults to 64.
quantization_pattern (str, optional) – Regular-expression pattern selecting the layers to quantize. Defaults to “.*”.
kv_cache_sharding_sequence_axis_name (tp.Union[str, tp.Tuple[str, ...]], optional) – axis name to target for sharding sequences. Defaults to “sp”.
flash_attention_backward_pass_impl (tp.Literal["triton", "xla"], optional) – Specify the backward pass kernel for flash attention. Defaults to “triton”.
attn_dtype (jnp.dtype, optional) – Data type for attention computations. Defaults to jnp.float32.
attn_softmax_dtype (jnp.dtype, optional) – Data type for softmax in attention op computations. Defaults to jnp.float32.
fcm_max_ratio (float, optional) – Maximum ratio for forgetful causal masking (FCM). Defaults to 0.0.
fcm_min_ratio (float, optional) – Minimum ratio for forgetful causal masking (FCM). Defaults to 0.0.
hardware_abstraction (bool, optional) – Whether to switch to custom Pallas kernels instead of JAX. Defaults to DEFAULT_HARDWARE_ABSTRACTION (False).
pallas_m_block_size (int, optional) – Block size of the m dimension in the Pallas matmul kernel A(mk) @ B(kn) = C(mn). Defaults to DEFAULT_PALLAS_M_BLOCK_SIZE (128).
pallas_k_block_size (int, optional) – Block size of the k dimension in the Pallas matmul kernel A(mk) @ B(kn) = C(mn). Defaults to DEFAULT_PALLAS_K_BLOCK_SIZE (128).
pallas_n_block_size (int, optional) – Block size of the n dimension in the Pallas matmul kernel A(mk) @ B(kn) = C(mn). Defaults to DEFAULT_PALLAS_N_BLOCK_SIZE (128).
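Every parameter of add_basic_configurations defaults to Ellipsis (...). A plausible reading, stated here as an assumption rather than documented behavior, is that ... acts as a "not passed" sentinel: passed values overwrite the attribute, while omitted ones leave an existing value alone or fall back to a default. The class BasicConfigSketch and its three-parameter subset are hypothetical.

```python
class BasicConfigSketch:
    # Hypothetical fallback values, copied from the EasyDeLBaseConfig signature.
    DEFAULTS = {"axis_dims": (1, -1, 1, 1), "attn_mechanism": "vanilla", "blocksize_q": 128}

    def add_basic_configurations(self, axis_dims=..., attn_mechanism=..., blocksize_q=...):
        passed = {"axis_dims": axis_dims, "attn_mechanism": attn_mechanism, "blocksize_q": blocksize_q}
        for name, value in passed.items():
            if value is not ...:
                setattr(self, name, value)                # explicitly passed: overwrite
            elif not hasattr(self, name):
                setattr(self, name, self.DEFAULTS[name])  # omitted and unset: use the default
            # omitted but already set: keep the current value


cfg = BasicConfigSketch()
cfg.add_basic_configurations()                        # fills in the defaults
cfg.add_basic_configurations(attn_mechanism="splash")
cfg.add_basic_configurations(blocksize_q=64)          # attn_mechanism stays "splash"
print(cfg.attn_mechanism, cfg.blocksize_q, cfg.axis_dims)
```

A sentinel other than None fits here because None is itself a meaningful value for several fields (backend, platform and bits all default to None).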
- static create_mesh(axis_dims: Sequence[int] = (1, -1, 1, 1), axis_names: Sequence[str] = ('dp', 'fsdp', 'tp', 'sp'), dcn_axis_dims: Optional[Sequence[int]] = None, process_is_granule: bool = False, should_sort_granules_by_key: bool = True, allow_split_physical_axes: bool = True, backend: Optional[str] = None)
The create_mesh function creates a mesh object that can be used to shard arrays.
- Returns
A mesh object
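With the default axis_dims=(1, -1, 1, 1), the -1 entry absorbs whatever devices the other axes leave over, in the spirit of a NumPy reshape. The sketch below shows only that arithmetic in plain Python; resolve_axis_dims is a hypothetical helper, and the real create_mesh additionally handles dcn_axis_dims, granules and backend selection.

```python
import math


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 so that the product of the mesh dims equals num_devices."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known:
            raise ValueError(f"{num_devices} devices cannot be split over {tuple(axis_dims)}")
        dims[dims.index(-1)] = num_devices // known
    elif known != num_devices:
        raise ValueError(f"{tuple(axis_dims)} needs {known} devices, got {num_devices}")
    return tuple(dims)


axis_names = ("dp", "fsdp", "tp", "sp")
print(dict(zip(axis_names, resolve_axis_dims((1, -1, 1, 1), 8))))  # {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1}
print(resolve_axis_dims((2, -1, 2, 1), 8))                           # (2, 2, 2, 1)
```

With JAX available, the resolved shape would typically be used to reshape the device array before building a jax.sharding.Mesh(devices, axis_names); that step is left out so the sketch runs without accelerators.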
- classmethod from_pretrained(pretrained_model_name_or_path: Union[str, PathLike], cache_dir: Optional[Union[str, PathLike]] = None, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = 'main', **kwargs) PretrainedConfig
Instantiate a PretrainedConfig (or a derived class) from a pretrained model configuration.
- Parameters
pretrained_model_name_or_path (str or os.PathLike) –
This can be either:
a string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co.
a path to a directory containing a configuration file saved using the PretrainedConfig.save_pretrained method, e.g., ./my_model_directory/.
a path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
cache_dir (str or os.PathLike, optional) – Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
force_download (bool, optional, defaults to False) – Whether or not to force to (re-)download the configuration files and override the cached versions if they exist.
resume_download – Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
proxies (Dict[str, str], optional) – A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
token (str or bool, optional) – The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
revision (str, optional, defaults to “main”) –
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
Tip: To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".
return_unused_kwargs (bool, optional, defaults to False) –
If False, then this function returns just the final configuration object.
If True, then this function returns a tuple (config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes, i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
subfolder (str, optional, defaults to “”) – In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
kwargs (Dict[str, tp.Any], optional) – The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
- Returns
The configuration object instantiated from this pretrained model.
- Return type
PretrainedConfig
Examples:
>>> from transformers import BertConfig
>>> # We can't instantiate the base class *PretrainedConfig* directly, so the examples
>>> # use a derived class: BertConfig
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased"
... )  # Download configuration from huggingface.co and cache.
>>> config = BertConfig.from_pretrained(
...     "./test/saved_model/"
... )  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
>>> config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
>>> config = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased", output_attentions=True, foo=False
... )
>>> assert config.output_attentions == True
>>> config, unused_kwargs = BertConfig.from_pretrained(
...     "google-bert/bert-base-uncased",
...     output_attentions=True,
...     foo=False,
...     return_unused_kwargs=True,
... )
>>> assert config.output_attentions == True
>>> assert unused_kwargs == {"foo": False}
- get_axis_dims() Sequence[int]
The get_axis_dims function returns a sequence of integers representing the dimensions of each axis.
- Returns
The dimensions of the axes
- get_axis_names() Sequence[str]
The get_axis_names function returns a list of the names of the axes.
- Returns
A list of the names of all axes
- get_backend() str
The get_backend function returns the backend that is currently being used. If no backend has been set, it returns the default JAX backend.
- Returns
The backend platform
- get_basic_frequencies(head_size: Optional[int] = None, rotary_dim: Optional[int] = None, base: Optional[float] = None) Any
Get basic frequencies for rotary embeddings.
- Parameters
head_size – Size of attention heads (defaults to self.head_dim)
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
ModuleCaches instance containing computed frequencies
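Assuming get_basic_frequencies follows the standard rotary-embedding recipe (inverse frequencies base^(-2i/rotary_dim), with base defaulting to rope_theta), the computation reduces to the plain-Python sketch below. The function names are hypothetical, and the real method returns a ModuleCaches instance rather than lists.

```python
import math


def rope_inv_frequencies(rotary_dim, base=10000.0):
    # One inverse frequency per pair of rotated channels: base ** (-2i / rotary_dim).
    return [1.0 / base ** (2 * i / rotary_dim) for i in range(rotary_dim // 2)]


def rope_cos_sin(max_position, rotary_dim, base=10000.0):
    # The rotation angle for position p and channel pair i is p * inv_freq[i].
    inv_freq = rope_inv_frequencies(rotary_dim, base)
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_position)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_position)]
    return cos, sin


print(rope_inv_frequencies(4))  # [1.0, 0.01]
```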
- get_basic_rope(dtype: Union[Array, ndarray, bool, number], head_size: int, rotary_dim: Optional[int] = None, is_neox_style: bool = True, base: Optional[float] = None)
Get basic rotary position embeddings.
- Parameters
dtype – Data type for the embeddings
head_size – Size of attention heads
rotary_dim – Dimension for rotary embeddings (defaults to head_size)
is_neox_style – Whether to use NeoX style embeddings
base – Base value for frequency computation (defaults to self.rope_theta)
- Returns
The rotary position embedding function.
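The is_neox_style flag usually selects how channels are paired for rotation: NeoX style pairs channel i with channel i + d/2, while the GPT-J (interleaved) style pairs adjacent channels 2i and 2i + 1. The sketch below assumes rotary_dim equals head_size and uses a hypothetical apply_rope helper to show both pairings on a single head vector.

```python
import math


def apply_rope(x, position, base=10000.0, is_neox_style=True):
    """Rotate one head vector x (even length d, rotary_dim == d) to the given position."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        angle = position / base ** (2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        if is_neox_style:
            a, b = i, i + half           # NeoX: channel i pairs with channel i + d/2
        else:
            a, b = 2 * i, 2 * i + 1      # GPT-J / interleaved: adjacent channels pair up
        out[a] = x[a] * c - x[b] * s     # 2-D rotation of the pair (x[a], x[b])
        out[b] = x[b] * c + x[a] * s
    return out


print(apply_rope([1.0, 0.0, 0.0, 0.0], position=1))
print(apply_rope([1.0, 0.0, 0.0, 0.0], position=1, is_neox_style=False))
```

Both styles apply the same rotation angles; they differ only in which channels form a pair, so weights trained with one layout need the matching flag.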
- get_partition_rules(*args, **kwargs)
Get the partition rules for the model.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
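Partition rules are pairs of a regular expression and a PartitionSpec. The sketch below assumes the common convention that the first rule whose pattern matches a parameter's path wins; the rules and paths are hypothetical, and plain tuples of mesh-axis names stand in for jax.sharding.PartitionSpec so the snippet runs without JAX.

```python
import re

# Hypothetical (regex, spec) rules; tuples of axis names stand in for PartitionSpec.
RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    (r"self_attn/o_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r".*", (None,)),  # catch-all: replicate everything else
)


def match_partition_rule(path, rules=RULES):
    """Return the spec of the first rule whose pattern occurs in the parameter path."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")


print(match_partition_rule("model/layers/0/self_attn/q_proj/kernel"))
print(match_partition_rule("model/norm/scale"))
```

Because matching is first-wins in this sketch, specific patterns go first and the catch-all goes last.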
- property granted_freq_max_position_embedding: int
- property granted_mask_max_position_embedding: int
- property mesh
The mesh property is a helper property that creates a Mesh object from the axis_dims and axis_names attributes of an object, which are assumed to be sequences of integers and strings, respectively. The platform attribute is also used if it exists.
- Returns
A jax.sharding.Mesh
- class easydel.infra.__init__.EasyDeLBaseModule(*args: Any, **kwargs: Any)
Bases: Module, BaseModuleProtocol, EasyBridgeMixin, EasyGenerationMixin
Base class for EasyDeL modules, providing common functionality for model initialization, parameter handling, and integration with the EasyDeL ecosystem.
- apply_lora_to_layers(lora_rank: int, lora_pattern: Optional[str] = None, verbose: bool = False, rngs: Optional[Rngs] = None) SELF
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
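LoRA keeps a frozen weight W and learns a low-rank update: for rank r, an (in × r) matrix A and an (r × out) matrix B, so the layer computes x W + (alpha / r) * x A B. The alpha scaling and the zero initialization of B are common LoRA conventions assumed here; apply_lora_to_layers itself only documents lora_rank and lora_pattern. The helpers below are hypothetical plain-Python illustrations.

```python
def matmul(x, w):
    # x: (n, k), w: (k, m) as nested lists -> (n, m)
    return [[sum(xi * wij for xi, wij in zip(row, col)) for col in zip(*w)] for row in x]


def lora_linear(x, w, a, b, alpha, rank):
    base = matmul(x, w)              # frozen pretrained weight
    delta = matmul(matmul(x, a), b)  # low-rank path: (n, r), then (n, out)
    scale = alpha / rank
    return [[y + scale * d for y, d in zip(base_row, delta_row)]
            for base_row, delta_row in zip(base, delta)]


x = [[1.0, 2.0]]
w = [[1.0, 0.0], [0.0, 1.0]]   # 2 x 2 identity
a = [[1.0], [1.0]]             # 2 x 1, rank 1
b_zero = [[0.0, 0.0]]          # zero-initialized B: LoRA starts as a no-op
b = [[1.0, 1.0]]
print(lora_linear(x, w, a, b_zero, alpha=2.0, rank=1))  # [[1.0, 2.0]]
print(lora_linear(x, w, a, b, alpha=2.0, rank=1))       # [[7.0, 8.0]]
```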
- compute_loss(*, labels: Optional[Union[Array, ndarray, bool, number]] = None, loss_config: Optional[LossConfig] = None, loss_kwargs: Optional[Dict] = None, **batch) Tuple[Any, LossMetrics]
Basic loss computation for the given batch and labels, configured through loss_config and loss_kwargs.
- fully_shard(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None) SELF
- gather_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF
Gathers the model’s parameters based on the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for gathering.
mesh (jax.sharding.Mesh, optional) – The mesh to gather from.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Overlay functions to apply to the model.
- Returns
The gathered model.
- Return type
SELF
- property graphtree_params_shape: Dict
Evaluates the shape of the model’s parameters and returns a dictionary.
- property graphtree_shape: Dict
Evaluates the shape of the model and returns a dictionary.
- classmethod lazy_init(*args, **kwargs) SELF
Initializes the class under nnx.eval_shape, so that parameters are created as abstract shapes rather than allocated arrays.
- property loss_function
- merge_lora_params(pytree: Dict) SELF
Merges the given pytree of LoRA parameters into the current LoRA module.
- merge_params_dict(params_dict: Dict) SELF
Merges the model parameters from a dictionary into the current model.
- Parameters
params_dict (tp.Dict) – A dictionary containing the parameters to merge.
- Returns
The model with merged parameters.
- Return type
SELF
- property model_task: Optional[str]
Returns the model task.
- property model_type: Optional[str]
Returns the model type.
- property parameters: Dict
- property params: Dict
- property params_sharding: Dict
Returns the sharding of the model parameters.
- property pure_transform_fn
Generates a pure transform function for converting a PyTorch model to an EasyDeL module.
- quantize(method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.A8BIT, block_size: int = 128, quantization_pattern: Optional[str] = None, quantize_tensors: bool = True, verbose: Optional[bool] = None) SELF
Quantizes the model’s linear layers.
- Parameters
method (EasyDeLQuantizationMethods, optional) – The quantization method to use. Defaults to EasyDeLQuantizationMethods.A8BIT.
block_size (int, optional) – The block size for quantization. Defaults to 128.
quantization_pattern (str, optional) – The quantization pattern to use.
quantize_tensors (bool, optional) – Whether to quantize tensors or to quantize Linear layers. Defaults to True.
verbose (bool, optional) – Whether to log the quantization process verbosely.
- Returns
The quantized model.
- Return type
SELF
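quantize defaults to an 8-bit method with block_size=128. The sketch below shows generic symmetric absmax blockwise quantization, one scale per block of block_size values, as an assumed illustration of the idea; it is not EasyDeL's A8BIT implementation, and both helper names are hypothetical.

```python
def quantize_blockwise(values, block_size=4):
    """Symmetric absmax int8 quantization with one float scale per block."""
    q, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale = max(abs(v) for v in block) / 127 or 1.0  # all-zero block: any scale works
        scales.append(scale)
        q.extend(max(-127, min(127, round(v / scale))) for v in block)
    return q, scales


def dequantize_blockwise(q, scales, block_size=4):
    # Each int8 value is rescaled by the scale of the block it belongs to.
    return [qv * scales[i // block_size] for i, qv in enumerate(q)]


vals = [0.5, -1.0, 0.25, 0.0, 10.0, -5.0, 2.5, 0.0]
q, scales = quantize_blockwise(vals, block_size=4)
print(q, scales)
print(dequantize_blockwise(q, scales, block_size=4))
```

Smaller blocks track local magnitudes more closely (lower error) at the cost of storing more scales, which is the trade-off block_size controls.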
- shard_model(partition_rules: Optional[Union[Mapping[str, Callable], Mapping[tuple, Callable]]] = None, mesh: Optional[Mesh] = None, overlay_fns: Optional[Mapping[str, Callable]] = None) SELF
Shards the model’s parameters using the specified partitioning rules and mesh.
- Parameters
partition_rules (PartitionLike, optional) – Partitioning rules for sharding.
mesh (jax.sharding.Mesh, optional) – The mesh to shard across.
overlay_fns (tp.Optional[tp.Mapping[str, tp.Callable]], optional) – Overlay functions to apply to the model.
- Returns
The sharded model.
- Return type
SELF
- split_params_dict(extract_fn: Optional[Callable] = None, remove_none: bool = True) Dict
Splits the model parameters and returns them as a dictionary, removing VariableState from the tree.
- Parameters
extract_fn (tp.Optional[tp.Callable], optional) – Function to extract values from the parameters.
remove_none (bool, optional) – Whether to remove None values from the dictionary.
- Returns
The dictionary of split parameters.
- Return type
tp.Dict
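split_params_dict unwraps VariableState leaves into a plain nested dictionary. The sketch below mimics that on ordinary dicts: extract_fn is applied to every leaf and, with remove_none=True, leaves that end up as None are dropped. Variable is a hypothetical stand-in for a VariableState, and this sketch keeps empty sub-dictionaries.

```python
class Variable:
    """Hypothetical stand-in for a VariableState wrapper holding a .value."""

    def __init__(self, value):
        self.value = value


def split_params(tree, extract_fn=None, remove_none=True):
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = split_params(value, extract_fn, remove_none)  # recurse into sub-trees
            continue
        if extract_fn is not None:
            value = extract_fn(value)
        if remove_none and value is None:
            continue  # drop empty leaves
        out[key] = value
    return out


tree = {"dense": {"kernel": Variable([1.0, 2.0]), "bias": Variable(None)}}
print(split_params(tree, extract_fn=lambda v: v.value))  # {'dense': {'kernel': [1.0, 2.0]}}
```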
- property static_arguments: Tuple
- property transform_fn
Generates a transform function for converting a PyTorch model to an EasyDeL module.
- class easydel.infra.__init__.EasyDeLState(step: int | jax.Array, graphdef: nn.GraphDef, graphstate: nn.GraphState, graphother: nn.GraphState, tx: optax.GradientTransformation, opt_state: tp.Optional[optax.OptState], apply_fn: tp.Optional[tp.Callable] = None)
Bases: PyTreeNode
EasyDeLState: A Snapshot of Your EasyDeL Model
The EasyDeLState class is a container that holds all the essential information about your EasyDeL model at a given point in time; think of it as a snapshot of your model. It includes the training step (step), the model’s graph definition (graphdef), its graph state (graphstate) and remaining graph state (graphother), the gradient transformation (tx) and its optimizer state (opt_state), and an optional apply function (apply_fn).
- apply_fn: tp.Optional[tp.Callable] = None
- apply_gradients(*, grads)
Applies gradients to the model parameters and updates the optimizer state. This function is typically called during training to update the model based on the computed gradients.
- Parameters
grads – A dictionary of gradients, where keys correspond to model parameters.
- Returns
An updated EasyDeLState object with modified parameters and optimizer state.
- Return type
EasyDeLState
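apply_gradients follows the usual functional-update pattern: it returns a new EasyDeLState instead of mutating the old one. In a typical optax setup that means calling tx.update(grads, opt_state, params), then optax.apply_updates(params, updates), then incrementing step. The sketch below keeps only the shape of that pattern, with plain SGD standing in for the optax transformation and a hypothetical frozen dataclass TinyState standing in for EasyDeLState.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyState:
    """Hypothetical stand-in for EasyDeLState: immutable, updated via replace()."""
    step: int
    params: dict
    learning_rate: float = 0.1

    def apply_gradients(self, *, grads):
        # Plain SGD stands in for tx.update(...) followed by optax.apply_updates(...).
        new_params = {name: p - self.learning_rate * grads[name] for name, p in self.params.items()}
        return replace(self, step=self.step + 1, params=new_params)


s0 = TinyState(step=0, params={"w": 1.0})
s1 = s0.apply_gradients(grads={"w": 2.0})
print(s0.step, s1.step, s1.params)  # s0 is left untouched
```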
- classmethod create(*, step: tp.Optional[int] = None, graphdef: tp.Optional[nn.GraphDef] = None, graphstate: tp.Optional[nn.GraphState] = None, graphother: tp.Optional[nn.GraphState] = None, model: tp.Optional[nn.Module] = None, tx: tp.Optional[optax.GradientTransformation] = None, opt_state: tp.Optional[optax.OptState] = None, init_opt_state: bool = False) EasyDeLState
Create an instance with flexible initialization options.
- Parameters
step – Optional number of training steps.
graphdef – Optional graph definition.
graphstate – Optional graph state.
graphother – Optional graph state for the remaining variables.
model – Optional neural network module.
tx – Optional gradient transformation.
opt_state – Optional optimizer state.
- Raises
ValueError – If initialization parameters are inconsistent.
- gather_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState
Gathers the model according to the provided partition rules.
- Returns
An updated EasyDeLState object with the gathered model.
- Return type
EasyDeLState
- gather_state()
Gathers the entire state.
- Returns
An updated EasyDeLState object with the gathered state.
- Return type
EasyDeLState
- graphdef: nn.GraphDef
- graphother: nn.GraphState
- graphstate: nn.GraphState
- init_tx(tx: GradientTransformation, partition_rules: Any = None) EasyDeLState
Initialize the optimizer state with the given gradient transformation.
- Parameters
tx (optax.GradientTransformation) – A gradient transformation to initialize the optimizer state.
partition_rules (Optional[Any], optional) – Rules for partitioning the optimizer state. Defaults to None.
- Returns
An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.
- Return type
EasyDeLState
- merge_to_state(tree) EasyDeLState
- property model: Any
- opt_state: tp.Optional[optax.OptState]
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- save_state(save_directory: Union[str, PathLike], float_dtype: Optional[dtype] = None, verbose: bool = True, mismatch_allowed: bool = True, save_optimizer: bool = True, enable: Optional[bool] = None)
- shard_model(partition_rules: Any = None, mesh: Optional[Any] = None) EasyDeLState
Shards the model according to the provided partition rules.
- Parameters
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
mesh (Optional[Mesh]) – The mesh to be used for sharding. If None, the method will use the mesh from self.model.
- Returns
An updated EasyDeLState object with the sharded model.
- Return type
EasyDeLState
- shard_optimizer_state(opt_state: Optional[Any] = None, partition_rules: Any = None) Any
Shards the optimizer state according to the provided partition rules.
- Parameters
opt_state (Optional[Any]) – The optimizer state to be sharded. If None, the method will use self.opt_state. Raises a ValueError if both opt_state and self.opt_state are None.
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
- Returns
The sharded optimizer state.
- Return type
Any
- Raises
ValueError – If both opt_state and self.opt_state are None.
- shard_state(partition_rules: Any = None) EasyDeLState
Shards the entire state, according to the provided partition rules.
- Parameters
partition_rules (Optional[Any]) – The partition rules to be used for sharding. If None, the method will use the partition rules from self.model.config.
- Returns
An updated EasyDeLState object with the sharded state.
- Return type
EasyDeLState
- shard_with_shape(shape) EasyDeLState
Shards the current state according to the given shape.
- property shardings
Returns the sharding information for the state.
- Returns
The sharding information.
- Return type
Any
- property size: int
Calculates the total size of the optimizer state and model graph state.
- Returns
The total size in bytes.
- Return type
int
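The size property reports a byte count. Assuming it is the usual sum over array leaves of element count times bytes per element, the arithmetic looks like the sketch below; the leaf list, the dtype table and the two optimizer buffers (an Adam-like mu and nu) are hypothetical inputs, not values taken from EasyDeL.

```python
import math

# Bytes per element for a few common dtypes.
ITEMSIZE = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def tree_size_bytes(leaves):
    """leaves: iterable of (shape, dtype_name) pairs; a scalar has shape ()."""
    return sum(math.prod(shape) * ITEMSIZE[dtype] for shape, dtype in leaves)


params = [((4096, 4096), "bfloat16"), ((4096,), "float32")]
# Hypothetical optimizer state: two float32 buffers per parameter (Adam-like mu and nu).
opt_state = [((4096, 4096), "float32"), ((4096,), "float32")] * 2
total = tree_size_bytes(params) + tree_size_bytes(opt_state)
print(total, f"{total / 2**20:.1f} MiB")
```

In this example the optimizer buffers account for roughly four fifths of the total, which is why sharding the optimizer state matters as much as sharding the model.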
- tx: optax.GradientTransformation
- class easydel.infra.__init__.LossConfig(ignore_index: int = -100, label_smoothing: float = 0.0, z_loss: float = 0.0, loss_normalizing_factor: Union[float, int, str, easydel.infra.loss_utils.SpecialLossNormalizingFactor, NoneType] = 'NUM_REAL_TARGET_TOKENS', num_labels: Optional[str] = None, problem_type: Optional[str] = None, divide_weight_sum: bool = False, shift_tokens: bool = True, break_on_nan: bool = True, reduction: Optional[Literal['none', 'mean', 'sum']] = None, num_classification_labels: Optional[int] = None, classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None)
Bases: Mapping
- break_on_nan: bool = True
- classification_problem_type: Optional[Literal['regression', 'single_label_classification', 'multi_label_classification']] = None
- divide_weight_sum: bool = False
- from_tuple()
- ignore_index: int = -100
- items() a set-like object providing a view on D's items
- keys() a set-like object providing a view on D's keys
- label_smoothing: float = 0.0
- loss_normalizing_factor: Optional[Union[float, int, str, SpecialLossNormalizingFactor]] = 'NUM_REAL_TARGET_TOKENS'
- num_classification_labels: Optional[int] = None
- num_labels: Optional[str] = None
- problem_type: Optional[str] = None
- reduction: Optional[Literal['none', 'mean', 'sum']] = None
- replace(**kwargs)
- shift_tokens: bool = True
- to_tuple()
- values() an object providing a view on D's values
- z_loss: float = 0.0
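To show how the main LossConfig fields typically interact in a causal language-modeling loss, the sketch below shifts tokens (shift_tokens), skips positions labelled ignore_index, applies label smoothing and a z-loss penalty on log Z, and divides by the number of real target tokens (the default loss_normalizing_factor). This is an assumed, textbook formulation in plain Python, not EasyDeL's loss code; label smoothing here spreads eps over the whole vocabulary, whereas some implementations use eps / (V - 1). causal_lm_loss is a hypothetical name.

```python
import math


def causal_lm_loss(logits, labels, ignore_index=-100, label_smoothing=0.0,
                   z_loss=0.0, shift_tokens=True):
    """logits: [seq][vocab] floats, labels: [seq] ints. Mean loss over real target tokens."""
    if shift_tokens:
        logits, labels = logits[:-1], labels[1:]  # position t predicts token t + 1
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # padding / prompt positions do not count
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        log_probs = [v - log_z for v in row]
        # Smoothed target: (1 - eps) on the label plus eps spread uniformly over the vocab.
        nll = -(1 - label_smoothing) * log_probs[label]
        nll -= label_smoothing * sum(log_probs) / len(row)
        total += nll + z_loss * log_z ** 2  # z-loss keeps log Z close to 0
        count += 1
    return total / max(count, 1)  # normalize by the number of real target tokens


uniform = [[0.0] * 4 for _ in range(3)]
print(causal_lm_loss(uniform, [-100, 1, 2]))  # log(4) for a uniform 4-way prediction
```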