easydel.utils.parameters_transformation
- class easydel.utils.parameters_transformation.DtypeHandler
Bases: object
Handles dtype conversions and operations.
- static float_tensor_to_dtype(tensor: Any, dtype: str | numpy.dtype | None) -> Any
Convert a floating-point tensor to the specified dtype.
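The sketch below shows one plausible behavior for `float_tensor_to_dtype`. It uses NumPy so it runs without JAX. The rule that non-floating tensors (for example integer token ids) and a `None` dtype pass through unchanged is an assumption inferred from the method name and signature, not something the source confirms. `float_tensor_to_dtype_sketch` is a hypothetical name.

```python
import numpy as np


def float_tensor_to_dtype_sketch(tensor, dtype):
    """Hypothetical stand-in for DtypeHandler.float_tensor_to_dtype.

    Assumption: only floating-point arrays are cast; a dtype of None,
    or a non-floating array, leaves the values unchanged.
    """
    if dtype is None:
        return tensor  # nothing to convert to
    arr = np.asarray(tensor)
    if not np.issubdtype(arr.dtype, np.floating):
        return arr  # leave integer/bool tensors untouched
    return arr.astype(np.dtype(dtype))
```

Restricting the cast to floating-point inputs avoids silently turning integer buffers into floats when a whole parameter tree is converted at once.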
- static get_dtype(dtype: str | numpy.dtype) -> dtype
Convert a dtype given as a string (or a NumPy dtype) to a JAX dtype.
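A minimal sketch of string-to-dtype resolution, assuming a lookup table of short aliases. The alias table `_DTYPE_ALIASES` and the function `get_dtype_sketch` are hypothetical. The real `get_dtype` returns JAX dtypes and may accept different or additional names. NumPy has no native bfloat16, so this NumPy version omits it.

```python
import numpy as np

# Hypothetical alias table; the real DtypeHandler may accept other names.
_DTYPE_ALIASES = {
    "fp16": np.float16,
    "float16": np.float16,
    "fp32": np.float32,
    "float32": np.float32,
    "fp64": np.float64,
    "float64": np.float64,
}


def get_dtype_sketch(dtype):
    """Resolve a string alias, or pass through a dtype, as a NumPy dtype."""
    if isinstance(dtype, str):
        key = dtype.lower()
        if key not in _DTYPE_ALIASES:
            raise ValueError(f"Unsupported dtype string: {dtype!r}")
        return np.dtype(_DTYPE_ALIASES[key])
    return np.dtype(dtype)  # already a dtype (or dtype-like): normalize it
```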
- class easydel.utils.parameters_transformation.ModelConverter
Bases: object
Handles model conversions between EasyDeL and HuggingFace formats.
- static easydel_to_huggingface(module: EasyDeLBaseModule, config: EasyDeLBaseConfig, base_huggingface_module: PreTrainedModel, base_huggingface_module_kwarguments: dict | None = None, dtype: jnp.dtype = jnp.float16, use_meta_torch: bool = True, reform_param: dict | None = None, **kw) -> tp.Any
Convert an EasyDeL module to a HuggingFace model.
- class easydel.utils.parameters_transformation.StateDictConverter
Bases: object
Handles conversion between PyTorch and EasyDeL state dictionaries.
- static apply_moe_transformations(state_dict: dict[str, Any], moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, tensor_transform: Optional[Callable] = None) -> tuple[dict[str, Any], set[str]]
Transform MoE weights from the HuggingFace format (one tensor per expert) to the EasyDeL format (experts stacked along a new leading axis).
Converts from:
model.layers.3.block_sparse_moe.experts.0.w3.weight -> shape (128, 256)
model.layers.3.block_sparse_moe.experts.1.w3.weight -> shape (128, 256)
…
To:
model.layers.3.block_sparse_moe.experts.w3.weight -> shape (num_experts, 128, 256)
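The key rewrite and stacking described above can be sketched with NumPy as follows. This is an illustration under assumptions, not EasyDeL's implementation: the regex `_EXPERT_KEY` and the function `stack_experts` are hypothetical, expert keys are assumed to carry their index as the numeric segment right after `experts`, and unlike the real method (which also returns a `set[str]`) this sketch returns only the new dict. Indices are sorted numerically so that expert 10 lands after expert 9.

```python
import re

import numpy as np

# Hypothetical pattern for per-expert keys such as
# "model.layers.3.block_sparse_moe.experts.0.w3.weight".
# Assumption: the expert index is the numeric segment right after "experts".
_EXPERT_KEY = re.compile(r"^(?P<prefix>.*\.experts)\.(?P<idx>\d+)\.(?P<suffix>.+)$")


def stack_experts(state_dict):
    """Stack per-expert tensors into one (num_experts, ...) tensor per weight."""
    out = {}
    groups = {}  # stacked key -> {expert index: tensor}
    for key, value in state_dict.items():
        match = _EXPERT_KEY.match(key)
        if match is None:
            out[key] = value  # non-expert weights pass through unchanged
            continue
        stacked_key = f"{match['prefix']}.{match['suffix']}"
        groups.setdefault(stacked_key, {})[int(match["idx"])] = value
    for stacked_key, experts in groups.items():
        indices = sorted(experts)  # numeric sort: 2 comes before 10
        if indices != list(range(len(indices))):
            raise ValueError(f"Missing experts for {stacked_key}: {indices}")
        out[stacked_key] = np.stack([experts[i] for i in indices], axis=0)
    return out
```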
- static apply_moe_transformations_reverse(state_dict: dict[str, Any], moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, tensor_transform: Optional[Callable] = None) -> dict[str, Any]
Transform MoE weights from the EasyDeL format (experts stacked along a leading axis) to the HuggingFace format (one tensor per expert).
Converts from:
model.layers.3.block_sparse_moe.experts.w3.weight -> shape (num_experts, 128, 256)
To:
model.layers.3.block_sparse_moe.experts.0.w3.weight -> shape (128, 256)
model.layers.3.block_sparse_moe.experts.1.w3.weight -> shape (128, 256)
…
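The reverse direction can be sketched as splitting the leading axis back into indexed keys. As with any sketch here, the names are hypothetical (`_STACKED_KEY`, `unstack_experts`), and it assumes a stacked key is one where `experts` is followed directly by a non-numeric segment and that axis 0 indexes the experts.

```python
import re

import numpy as np

# Hypothetical pattern for stacked keys such as
# "model.layers.3.block_sparse_moe.experts.w3.weight".
# Assumption: "experts" is followed directly by a non-numeric segment.
_STACKED_KEY = re.compile(r"^(?P<prefix>.*\.experts)\.(?P<suffix>\D.*)$")


def unstack_experts(state_dict):
    """Split each (num_experts, ...) tensor into one tensor per expert."""
    out = {}
    for key, value in state_dict.items():
        match = _STACKED_KEY.match(key)
        if match is None:
            out[key] = value  # non-expert weights pass through unchanged
            continue
        arr = np.asarray(value)
        for idx in range(arr.shape[0]):  # leading axis indexes the experts
            out[f"{match['prefix']}.{idx}.{match['suffix']}"] = arr[idx]
    return out
```

Applied to the output of a stacking pass, this restores the original per-expert keys and shapes.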
- static easydel_to_torch(module: EasyDeLBaseModule, dtype: jnp.dtype = jnp.float16, **kwargs) -> dict[str, tp.Any]
Convert an EasyDeL module to a PyTorch state dict.
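A common step in this kind of conversion is flattening the nested parameter tree into dotted PyTorch-style keys. Flax `Dense` layers store a `kernel` of shape `(in_features, out_features)`, while `torch.nn.Linear` stores `weight` as `(out_features, in_features)`, so 2-D kernels are transposed. The sketch below assumes EasyDeL follows these conventions; `flatten_params` and `to_torch_style_state_dict` are hypothetical names, and the real method works on a live module rather than a plain dict.

```python
import numpy as np


def flatten_params(params, prefix=()):
    """Yield (path_tuple, leaf) pairs from a nested dict of arrays."""
    for name, value in params.items():
        path = prefix + (name,)
        if isinstance(value, dict):
            yield from flatten_params(value, path)
        else:
            yield path, value


def to_torch_style_state_dict(params, dtype=np.float16):
    """Hypothetical sketch: nested Flax-style params -> dotted PyTorch keys."""
    state_dict = {}
    for path, leaf in flatten_params(params):
        arr = np.asarray(leaf)
        # Assumption: 2-D "kernel" leaves are (in, out) Dense matrices,
        # renamed to "weight" and transposed to PyTorch's (out, in).
        if path[-1] == "kernel" and arr.ndim == 2:
            path = path[:-1] + ("weight",)
            arr = arr.T
        state_dict[".".join(path)] = arr.astype(dtype)
    return state_dict
```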
- static huggingface_to_easydel(state_dict: dict[str, Any], *, device: jaxlib._jax.Device | None = None, embedding_layer_names: list[str] | None = None, layernorm_names: list[str] | None = None, moe_block_names: list[str] | None = None, moe_names: list[str] | None = None, moe_block_path: list[str] | None = None, moe_path: list[str] | None = None, shard_fns: Optional[Mapping[tuple, Callable]] = None, dtype: numpy.dtype = jnp.float16, verbose: bool = True, callback: Optional[Callable[[jax.Array, tuple], jax.Array]] = None, remove_state_dict: bool = False, lm_head_name: str | None = None, uses_tie_word_embedding: bool = False, reform_param: dict | None = None, **kwargs) -> dict[str, Any]
Convert a PyTorch state dict to the EasyDeL format, applying the MoE transformations along the way.
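The `embedding_layer_names` and `layernorm_names` parameters suggest that embedding and normalization weights are treated differently from ordinary linear weights. The sketch below shows one plausible set of rules, assuming Flax-style leaf names (`embedding` for embeddings, `scale` for norms, `kernel` for linear layers): embedding tables and 1-D norm weights are renamed but not transposed, while other 2-D weights are transposed from PyTorch's `(out, in)` to `(in, out)`. The function `hf_to_flax_style` and its default name lists are hypothetical. The real method additionally handles device placement, `shard_fns`, `callback`, MoE stacking, and tied embeddings, none of which are covered here.

```python
import numpy as np


def hf_to_flax_style(
    state_dict,
    embedding_layer_names=("embed_tokens",),
    layernorm_names=("input_layernorm", "post_attention_layernorm", "norm"),
    dtype=np.float16,
):
    """Hypothetical sketch: dotted PyTorch keys -> tuple paths, Flax leaf names."""
    params = {}
    for key, tensor in state_dict.items():
        *parent, leaf = key.split(".")
        parent = tuple(parent)
        arr = np.asarray(tensor)
        if leaf == "weight":
            if any(name in parent for name in embedding_layer_names):
                leaf = "embedding"  # embedding tables keep (vocab, hidden)
            elif any(name in parent for name in layernorm_names):
                leaf = "scale"  # norm weights are 1-D and are not transposed
            elif arr.ndim == 2:
                leaf, arr = "kernel", arr.T  # Linear (out, in) -> Dense (in, out)
        params[parent + (leaf,)] = arr.astype(dtype)
    return params
```

Matching on whole path segments rather than substrings keeps a name such as `norm` from accidentally matching unrelated layers.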