easydel.utils.parameters_transformation

class easydel.utils.parameters_transformation.DtypeHandler

Bases: object

Handles dtype conversions and operations.

static float_tensor_to_dtype(tensor: Any, dtype: str | numpy.dtype | None) -> Any

Convert float tensor to specified dtype.

static get_dtype(dtype: str | numpy.dtype) -> dtype

Convert string dtype representation to JAX dtype.

class easydel.utils.parameters_transformation.ModelConverter

Bases: object

Handles model conversions between EasyDeL and HuggingFace formats.

static easydel_to_huggingface(module: EasyDeLBaseModule, config: EasyDeLBaseConfig, base_huggingface_module: PreTrainedModel, base_huggingface_module_kwarguments: dict | None = None, dtype: jnp.dtype = jnp.float16, use_meta_torch: bool = True, reform_param: dict | None = None, **kw) -> tp.Any

Convert EasyDeL module to HuggingFace model.

class easydel.utils.parameters_transformation.StateDictConverter

Bases: object

Handles conversion between PyTorch and EasyDeL state dictionaries.

static apply_moe_transformations(state_dict: dict[str, Any], moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, tensor_transform: Optional[Callable] = None) -> tuple[dict[str, Any], set[str]]

Transform MoE weights from HuggingFace format (separate experts) to EasyDeL format (stacked experts).

Converts from:

model.layers.3.block_sparse_moe.experts.0.w3.weight -> shape (128, 256)
model.layers.3.block_sparse_moe.experts.1.w3.weight -> shape (128, 256)
…

To:

model.layers.3.block_sparse_moe.experts.w3.weight -> shape (num_experts, 128, 256)

static apply_moe_transformations_reverse(state_dict: dict[str, Any], moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, tensor_transform: Optional[Callable] = None) -> dict[str, Any]

Transform MoE weights from EasyDeL format (stacked experts) to HuggingFace format (separate experts).

Converts from:

model.layers.3.block_sparse_moe.experts.w3.weight -> shape (num_experts, 128, 256)

To:

model.layers.3.block_sparse_moe.experts.0.w3.weight -> shape (128, 256)
model.layers.3.block_sparse_moe.experts.1.w3.weight -> shape (128, 256)
…

static easydel_to_torch(module: EasyDeLBaseModule, dtype: jnp.dtype = jnp.float16, **kwargs) -> dict[str, tp.Any]

Convert EasyDeL module to PyTorch state dict.

static huggingface_to_easydel(state_dict: dict[str, typing.Any], *, device: jaxlib._jax.Device | None = None, embedding_layer_names: list[str] | None = None, layernorm_names: list[str] | None = None, moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, shard_fns: typing.Optional[typing.Mapping[tuple, typing.Callable]] = None, dtype: numpy.dtype = jax.numpy.float16, verbose: bool = True, callback: typing.Optional[typing.Callable[[jax.Array, tuple], jax.Array]] = None, remove_state_dict: bool = False, lm_head_name: str | None = None, uses_tie_word_embedding: bool = False, reform_param: dict | None = None, **kwargs) -> dict[str, Any]

Convert PyTorch state dict to EasyDeL format with MoE transformations.

static match_keywords(string: str, required: list[str], forbidden: list[str]) -> bool

Check if string contains all required keywords and none of the forbidden ones.

static process_tensor(key: str, tensor: Any, config: dict[str, Any]) -> list[tuple[tuple, jax.Array]] | None

Process a single tensor and return a list of processed (key, value) pairs, or None.

class easydel.utils.parameters_transformation.TensorConverter

Bases: object

Handles tensor conversions between PyTorch and JAX.

static convert_pytorch_to_jnp(tensor: Any, dtype: dtype) -> Array

Convert PyTorch tensor to JAX array.

static get_torch()

Import and return torch module (cached).

static jax_to_pytorch(x: Array) -> Any

Convert JAX array to PyTorch tensor.

static pytorch_to_jax(x: Any) -> Array

Convert PyTorch tensor to JAX array.