easydel.utils.parameters_transformation

easydel.utils.parameters_transformation.convert_pytorch_tensor_to_jax(tensor, dtype)
easydel.utils.parameters_transformation.float_tensor_to_dtype(tensor, dtype)
easydel.utils.parameters_transformation.get_dtype(dtype)
easydel.utils.parameters_transformation.get_torch()
easydel.utils.parameters_transformation.jax2pt(x: Array)
easydel.utils.parameters_transformation.match_keywords(string, ts, ns)

The match_keywords function takes a string and two lists of strings: ts, the required ("must-have") keywords, and ns, the forbidden ("not-allowed") keywords. It returns True if every keyword in ts is in string and no keyword in ns is.

Parameters
  • string – The text to search

  • ts – List of required keywords; all of them must be present in string

  • ns – List of forbidden keywords; none of them may be present in string

Returns

True if all the keywords in ts are present and none of the keywords in ns are present; False otherwise
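
The behaviour described above can be sketched in a few lines. This is a minimal re-implementation for illustration only, assuming plain substring matching; the name match_keywords_sketch is hypothetical and the library's actual source may differ.

```python
from typing import Sequence


def match_keywords_sketch(string: str, ts: Sequence[str], ns: Sequence[str]) -> bool:
    """Hypothetical stand-in for match_keywords, based only on its docstring."""
    # Every required keyword must occur in the string ...
    has_all_required = all(t in string for t in ts)
    # ... and no forbidden keyword may occur in it.
    has_no_forbidden = not any(n in string for n in ns)
    return has_all_required and has_no_forbidden


key = "model.layers.0.self_attn.q_proj.weight"
print(match_keywords_sketch(key, ["self_attn", "weight"], ["bias"]))  # True
print(match_keywords_sketch(key, ["self_attn"], ["q_proj"]))          # False
```

A check of this kind is typically used to pick out parameter keys by name, for example all attention weights that are not biases.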

easydel.utils.parameters_transformation.module_to_huggingface_model(module: Any, config: Any, base_huggingface_module: Any, base_huggingface_module_kwarguments: Optional[Dict] = None, dtype: numpy.dtype = jax.numpy.float16, use_meta_torch: bool = True, **kw)
easydel.utils.parameters_transformation.module_to_torch(module: Any, dtype: numpy.dtype = jax.numpy.float16)
easydel.utils.parameters_transformation.process_tensor(key: str, tensor: Any, config: Dict[str, Any]) -> Optional[Tuple[tuple, Array]]

Process a single tensor and return its processed key and value.

Parameters
  • key – The parameter key

  • tensor – The tensor to process

  • config – Dictionary containing processing configuration

Returns

Tuple of the processed key tuple and the JAX array, or None if the tensor should be skipped
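
The contract above (key string in; key tuple and array out, or None) can be illustrated with a small sketch. Everything here is an assumption for illustration: the function name process_tensor_sketch is hypothetical, the config keys simply reuse the parameter names of torch_dict_to_easydel_params, the renaming rules (.weight to .embedding, .scale, or .kernel, with 2-D weights transposed) follow a common PyTorch-to-Flax convention that this page does not confirm, and NumPy arrays stand in for both PyTorch tensors and JAX arrays.

```python
from typing import Any, Dict, Optional, Tuple

import numpy as np


def process_tensor_sketch(
    key: str, tensor: np.ndarray, config: Dict[str, Any]
) -> Optional[Tuple[tuple, np.ndarray]]:
    """Hypothetical stand-in for process_tensor; not the library's code."""
    new_key = key
    if any(name in key for name in config.get("embedding_layer_names", [])):
        # Assumed convention: embedding tables are stored under "embedding".
        if key.endswith(".weight"):
            new_key = key[: -len(".weight")] + ".embedding"
    elif any(name in key for name in config.get("layernorm_names", [])):
        # Assumed convention: layer-norm weights are stored under "scale".
        new_key = key.replace(".weight", ".scale")
    elif key.endswith(".weight"):
        # Assumed convention: linear weights are transposed and named "kernel".
        if tensor.ndim == 2:
            tensor = tensor.T
        new_key = key[: -len(".weight")] + ".kernel"

    key_tuple = tuple(new_key.split("."))

    # With tied embeddings the LM head duplicates the embedding, so skip it.
    if (
        config.get("uses_tie_word_embedding")
        and config.get("lm_head_name")
        and key_tuple[0] == config["lm_head_name"]
    ):
        return None

    return key_tuple, tensor.astype(config.get("dtype", np.float16))


cfg = {
    "embedding_layer_names": ["embed_tokens"],
    "layernorm_names": ["norm"],
    "lm_head_name": "lm_head",
    "uses_tie_word_embedding": True,
    "dtype": np.float32,
}
result = process_tensor_sketch("model.layers.0.mlp.up_proj.weight", np.zeros((4, 2)), cfg)
print(result[0], result[1].shape)
print(process_tensor_sketch("lm_head.weight", np.zeros((4, 2)), cfg))
```

Returning None instead of raising lets a caller loop over the whole state dict and simply drop the skipped entries.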

easydel.utils.parameters_transformation.pt2jax(x)
easydel.utils.parameters_transformation.torch_dict_to_easydel_params(state_dict: Dict[str, Any], *, device: Optional[jaxlib.xla_extension.Device] = None, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None, shard_fns: Optional[Mapping[tuple, Callable]] = None, dtype: numpy.dtype = jax.numpy.float16, verbose: bool = True, callback: Optional[Callable[[jax.Array, tuple], jax.Array]] = None, remove_state_dict: bool = False, lm_head_name: Optional[str] = None, uses_tie_word_embedding: bool = False, **kwargs) -> Dict[str, Any]

Convert a PyTorch state dict to the EasyDel parameter format.

Parameters
  • state_dict – PyTorch state dictionary

  • device – JAX device to use

  • embedding_layer_names – Names of embedding layers

  • layernorm_names – Names of layer normalization layers

  • shard_fns – Mapping of parameter key tuples to sharding functions

  • block_size – Size of processing blocks

  • params_pattern_selection – Regex pattern for parameter selection

  • dtype – Target dtype for parameters

  • verbose – Whether to show a progress bar

  • callback – Callback applied to each tensor after it is converted to a JAX array

  • remove_state_dict – Whether to delete state_dict after conversion

  • lm_head_name – Name of language model head

  • uses_tie_word_embedding – Whether model uses tied embeddings

  • **kwargs – Additional arguments

Returns

Dictionary of converted parameters in EasyDel format