easydel.utils.parameters_transformation
- easydel.utils.parameters_transformation.match_keywords(string, ts, ns)
The match_keywords function takes a string and two lists of strings. The first list holds the "must-have" keywords and the second holds the "not-allowed" keywords. It returns True if every must-have keyword occurs in string and none of the not-allowed keywords do.
- Parameters
string – The text to search
ts – List of required keywords; every one must appear in string
ns – List of negative keywords; none may appear in string
- Returns
True if all the keywords in ts are present in string and none of the keywords in ns are; otherwise False
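The matching rule described above can be illustrated with a minimal sketch. This is a reimplementation written from the description only, assuming plain substring matching; it is not copied from the EasyDeL source, and the parameter keys in the usage lines are made up for illustration.

```python
from typing import Sequence


def match_keywords(string: str, ts: Sequence[str], ns: Sequence[str]) -> bool:
    """Sketch of the documented rule, assuming plain substring checks."""
    # Every required keyword must occur somewhere in the string.
    has_all_required = all(t in string for t in ts)
    # No negative keyword may occur anywhere in the string.
    has_no_forbidden = not any(n in string for n in ns)
    return has_all_required and has_no_forbidden


# Hypothetical parameter keys, used only to show the behaviour.
print(match_keywords("model.layers.0.attn.kernel", ["kernel"], ["bias"]))  # True
print(match_keywords("model.layers.0.attn.bias", ["kernel"], ["bias"]))    # False
```

With both lists empty the sketch returns True, since there is nothing to require and nothing to forbid.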
- easydel.utils.parameters_transformation.module_to_huggingface_model(module: Any, config: Any, base_huggingface_module: Any, base_huggingface_module_kwarguments: Optional[Dict] = None, dtype: dtype = jax.numpy.float16, use_meta_torch: bool = True, **kw)
- easydel.utils.parameters_transformation.module_to_torch(module: Any, dtype: dtype = jax.numpy.float16)
- easydel.utils.parameters_transformation.process_tensor(key: str, tensor: Any, config: Dict[str, Any]) → Optional[Tuple[tuple, Array]]
Process a single tensor and return its processed key and value.
- Parameters
key – The parameter key
tensor – The tensor to process
config – Dictionary containing processing configuration
- Returns
Tuple of the processed key tuple and the JAX array, or None if the tensor should be skipped
- easydel.utils.parameters_transformation.torch_dict_to_easydel_params(state_dict: Dict[str, Any], *, device: Optional[Device] = None, embedding_layer_names: Optional[List[str]] = None, layernorm_names: Optional[List[str]] = None, shard_fns: Optional[Mapping[tuple, Callable]] = None, dtype: dtype = jax.numpy.float16, verbose: bool = True, callback: Optional[Callable[[Array, tuple], Array]] = None, remove_state_dict: bool = False, lm_head_name: Optional[str] = None, uses_tie_word_embedding: bool = False, **kwargs) → Dict[str, Any]
Convert a PyTorch state dict to the EasyDeL parameter format.
- Parameters
state_dict – PyTorch state dictionary
device – JAX device to use
embedding_layer_names – Names of embedding layers
layernorm_names – Names of layer normalization layers
shard_fns – Mapping from parameter key tuples to sharding functions
block_size – Size of processing blocks
params_pattern_selection – Regex pattern for parameter selection
dtype – Target dtype for parameters
verbose – Whether to show a progress bar
callback – Callback applied to each tensor after it is converted to a JAX array
remove_state_dict – Whether to delete state_dict after conversion
lm_head_name – Name of language model head
uses_tie_word_embedding – Whether model uses tied embeddings
**kwargs – Additional arguments
- Returns
Dictionary of converted parameters in EasyDel format
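To make the role of embedding_layer_names and layernorm_names more concrete, here is a hypothetical sketch of how a dotted PyTorch parameter name could be mapped to a tuple key. The helper names (`_swap_suffix`, `sketch_convert_key`) and the renaming rules (`.weight` becoming `.embedding`, `.scale`, or `.kernel`, following common Flax naming) are assumptions for illustration; the actual rules used by torch_dict_to_easydel_params may differ, and the sketch ignores the tensors themselves.

```python
from typing import Sequence, Tuple


def _swap_suffix(key: str, old: str, new: str) -> str:
    """Replace a trailing `old` suffix with `new`; leave other keys unchanged."""
    return key[: -len(old)] + new if key.endswith(old) else key


def sketch_convert_key(
    key: str,
    embedding_layer_names: Sequence[str] = (),
    layernorm_names: Sequence[str] = (),
) -> Tuple[str, ...]:
    """Hypothetical key mapping; not the library's implementation."""
    if any(name in key for name in embedding_layer_names):
        # Assumed rule: embedding weights are stored under "embedding".
        key = _swap_suffix(key, ".weight", ".embedding")
    elif any(name in key for name in layernorm_names):
        # Assumed rule: normalization weights are stored under "scale".
        key = _swap_suffix(key, ".weight", ".scale")
    else:
        # Assumed rule: all other weights are stored under "kernel".
        key = _swap_suffix(key, ".weight", ".kernel")
    # Dotted name -> tuple key, the shape process_tensor is documented to return.
    return tuple(key.split("."))


# Made-up parameter names, for illustration only.
print(sketch_convert_key("model.embed_tokens.weight", ["embed_tokens"], ["norm"]))
# ('model', 'embed_tokens', 'embedding')
print(sketch_convert_key("model.norm.weight", ["embed_tokens"], ["norm"]))
# ('model', 'norm', 'scale')
```

Keys that do not end in `.weight`, such as biases, pass through the sketch with only the split applied.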