easydel.infra.utils


Utility functions and helpers for EasyDeL infrastructure.

Provides common utilities used throughout the EasyDeL framework, including activation functions, dtype handling, module manipulation, and various helper functions for model operations.

Constants:

ACT2FN: Dictionary mapping activation names to functions
ROPE_TYPES: Supported RoPE (Rotary Position Embedding) types

Functions:

quick_gelu: Quick GELU activation function
canonicalize_dtype: Canonicalize dtype for JAX arrays
get_activation: Get activation function by name
quantize_linear_layers: Apply quantization to linear layers
replace_dot: Replace JAX dot operations

Key Features:
  • Activation function registry

  • Data type canonicalization

  • Module quantization utilities

  • Sharding constraint helpers

  • Memory optimization tools

Example

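>>> import jax.numpy as jnp
>>> from easydel.infra.utils import ACT2FN, canonicalize_dtype
>>> # Look up an activation function by name
>>> activation = ACT2FN["gelu"]
>>> y = activation(jnp.linspace(-2.0, 2.0, 5))
>>> # Canonicalize dtype (an explicit dtype overrides inference)
>>> array = jnp.ones((2, 3), dtype=jnp.bfloat16)
>>> dtype = canonicalize_dtype(array, dtype=jnp.float32)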
easydel.infra.utils.ACT2FN = {'elu': <PjitFunction of <function elu>>, 'gelu': functools.partial(<function gelu>, approximate=False), 'gelu_new': functools.partial(<function gelu>, approximate=True), 'gelu_pytorch_tanh': functools.partial(<function gelu>, approximate=True), 'glu': <PjitFunction of <function glu>>, 'leaky_relu': functools.partial(<PjitFunction of <function leaky_relu>>, negative_slope=0.01), 'quick_gelu': <function quick_gelu>, 'relu': <jax._src.custom_derivatives.custom_jvp object>, 'sigmoid': <PjitFunction of <function sigmoid>>, 'silu': <PjitFunction of <function silu>>, 'softmax': <function softmax>, 'swish': <PjitFunction of <function silu>>, 'tanh': <PjitFunction of <function tanh>>}#

Registry of activation functions by name.

Maps activation function names to their implementations. Supports common activations used in neural networks.
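To illustrate how a name-to-function registry like ACT2FN is used, here is a minimal NumPy sketch. It re-implements a few activations instead of importing EasyDeL; the names ACTIVATIONS, gelu, silu and relu are hypothetical. As in the ACT2FN repr above, 'gelu' is the exact (erf-based) GELU, while 'gelu_new' and 'gelu_pytorch_tanh' are functools.partial variants using the tanh approximation, and 'swish' is an alias of 'silu'.

```python
import functools
import math

import numpy as np

_erf = np.vectorize(math.erf)  # elementwise erf without SciPy


def gelu(x, approximate=False):
    """GELU: exact (erf-based) or the tanh approximation."""
    x = np.asarray(x, dtype=np.float64)
    if approximate:
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
        return 0.5 * x * (1.0 + np.tanh(inner))
    return 0.5 * x * (1.0 + _erf(x / math.sqrt(2.0)))


def silu(x):
    """SiLU / swish: x * sigmoid(x)."""
    x = np.asarray(x, dtype=np.float64)
    return x / (1.0 + np.exp(-x))


def relu(x):
    return np.maximum(np.asarray(x, dtype=np.float64), 0.0)


# Hypothetical registry with the same shape as ACT2FN: names map to callables,
# and variants of one function are expressed with functools.partial.
ACTIVATIONS = {
    "gelu": functools.partial(gelu, approximate=False),
    "gelu_new": functools.partial(gelu, approximate=True),
    "gelu_pytorch_tanh": functools.partial(gelu, approximate=True),
    "relu": relu,
    "silu": silu,
    "swish": silu,  # alias, as in ACT2FN
}

x = np.array([-1.0, 0.0, 1.0])
for name in ("gelu", "gelu_new", "silu"):
    print(name, ACTIVATIONS[name](x))
```

Looking a function up by string keeps model configs declarative: a config only needs to store a name such as "silu".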

class easydel.infra.utils.ActivationType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

ELU = 'elu'#
GELU = 'gelu'#
GELU_NEW = 'gelu_new'#
GELU_PYTORCH_TANH = 'gelu_pytorch_tanh'#
GLU = 'glu'#
LEAKY_RELU = 'leaky_relu'#
QUICK_GELU = 'quick_gelu'#
RELU = 'relu'#
SIGMOID = 'sigmoid'#
SILU = 'silu'#
SOFTMAX = 'softmax'#
SWISH = 'swish'#
TANH = 'tanh'#
class easydel.infra.utils.ArrayParam(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#

Bases: Param

Parameterized array with serializable initialization.

A parameter container that stores initialization metadata (method name and kwargs) as strings/dicts instead of functions, making it pickleable and serializable. This is particularly useful for checkpointing and distributed training.
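The core idea, storing an initializer *name* plus keyword arguments instead of a function object, can be sketched without JAX. This is a hypothetical illustration, not EasyDeL's ArrayParam: ParamSpec, INITIALIZERS and materialize are invented names, and NumPy's random generator stands in for JAX initializers and PRNG keys.

```python
import pickle
from dataclasses import asdict, dataclass, field

import numpy as np

# Hypothetical name -> initializer table. Only the *name* is stored on the spec,
# so the spec itself never holds a function object.
INITIALIZERS = {
    "zeros": lambda rng, shape, dtype: np.zeros(shape, dtype=dtype),
    "ones": lambda rng, shape, dtype: np.ones(shape, dtype=dtype),
    "normal": lambda rng, shape, dtype, stddev=1e-2: (
        rng.standard_normal(shape) * stddev
    ).astype(dtype),
}


@dataclass
class ParamSpec:
    shape: tuple
    dtype: str
    init_method: str = "normal"
    init_kwargs: dict = field(default_factory=dict)

    def materialize(self, seed):
        """Build the array by looking the initializer up by name."""
        rng = np.random.default_rng(seed)
        init_fn = INITIALIZERS[self.init_method]
        return init_fn(rng, self.shape, np.dtype(self.dtype), **self.init_kwargs)


spec = ParamSpec(shape=(2, 3), dtype="float32", init_kwargs={"stddev": 0.02})
# The spec is plain data (tuple, strings, dict), so it round-trips through pickle.
payload = pickle.dumps(asdict(spec))
restored = ParamSpec(**pickle.loads(payload))
weights = restored.materialize(seed=0)
print(weights.shape, weights.dtype)
```

Because the initializer is resolved only when the value is built, the same spec can be re-materialized later with a different seed, which is the role resure plays with a new PRNG key.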

shape#

The shape of the parameter array.

Type

collections.abc.Sequence[int]

dtype#

The data type of the parameter array.

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]

init_method#

Name of the JAX initializer (e.g., “normal”, “zeros”, “ones”).

Type

str

init_kwargs#

Optional kwargs passed to the initializer.

Type

easydel.infra.utils.hashable_dict | None

classmethod bound(shape: Sequence[int], dtype: Union[str, type[Any], dtype, SupportsDType], init_method: str, init_kwargs: easydel.infra.utils.hashable_dict | None = None, *, key: Optional[Union[Key[Array, ''], UInt32[Array, '2']]] = None, value: jax.Array | None = None, use_ref: bool | None = None, **metadata)[source]#

Create an ArrayParam with initialized value.

Parameters
  • shape – Shape of the parameter array.

  • dtype – Data type for the parameter.

  • init_method – Name of JAX initializer (e.g., “normal”, “zeros”, “kaiming_uniform”).

  • init_kwargs – Optional keyword arguments for the initializer.

  • key – PRNG key for random initialization. Required if value is None.

  • value – Pre-computed value. If provided, skips initialization.

  • use_ref – Whether to use reference semantics.

  • **metadata – Additional metadata to store with the parameter.

Returns

An initialized ArrayParam instance.

Return type

ArrayParam

dtype: DTypeLike#
init_kwargs: hashable_dict | None = None#
init_method: str = 'normal'#
raw_value: A#
resure(key: Union[Key[Array, ''], UInt32[Array, '2']], shard_fn: Optional[Callable[[Array], Array]] = None) None[source]#

Reinitialize the parameter value with a new random key.

Regenerates the parameter value using the stored initialization method and optional sharding function. Useful for resetting parameters or applying sharding after initialization.

Parameters
  • key – PRNG key for random initialization.

  • shard_fn – Optional function to apply sharding to the reinitialized value.

shape: Sequence[int]#
class easydel.infra.utils.AttnMaskDetail(mask_type: AttnMaskType, size: int, offset: int | None = None, chunks: int | None = None, bricks: int | None = None)[source]#

Bases: object

Details for attention mask configuration.

Specifies the type and parameters of attention masking to use. Registered as a JAX pytree for use in JAX transformations.

mask_type#

Type of attention mask (FULL, SLIDING, or CHUNK).

Type

easydel.infra.utils.AttnMaskType

size#

Size parameter for the mask (e.g., window size for sliding).

Type

int

offset#

Optional offset for mask positioning.

Type

int | None

chunks#

Optional number of chunks for chunk attention.

Type

int | None

bricks#

Optional number of bricks for hierarchical attention.

Type

int | None

Example

>>> from easydel.infra.utils import AttnMaskDetail, AttnMaskType
>>> mask_detail = AttnMaskDetail(
...     mask_type=AttnMaskType.SLIDING,
...     size=512,
...     offset=0
... )
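For intuition about what a SLIDING mask with a given size restricts compared with a FULL (causal) one, here is a NumPy sketch. It assumes size is the number of most recent key positions, including the current one, that each query may attend to; EasyDeL's exact convention, and how offset shifts the window, may differ. The function names are hypothetical.

```python
import numpy as np


def causal_mask(seq_len):
    """Causal mask: query i may attend to keys 0..i."""
    q = np.arange(seq_len)[:, None]
    k = np.arange(seq_len)[None, :]
    return k <= q


def sliding_window_mask(seq_len, size):
    """Causal mask further limited to the last `size` keys (assumed convention)."""
    q = np.arange(seq_len)[:, None]
    k = np.arange(seq_len)[None, :]
    return (k <= q) & (q - k < size)


print(sliding_window_mask(5, size=2).astype(int))
```

When size is at least the sequence length, the sliding mask degenerates to the plain causal mask.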
bricks: int | None = None#
chunks: int | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

mask_type: AttnMaskType#
offset: int | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

size: int#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.infra.utils.AttnMaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

CHUNK = 'ATTN_MASK_CHUNK'#
FULL = 'ATTN_MASK_FULL'#
SLIDING = 'ATTN_MASK_SLIDING'#
classmethod from_hf(hf_type: Literal['sliding_attention', 'full_attention', 'chunk_attention'])[source]#
class easydel.infra.utils.CompilationTracker[source]#

Bases: object

Tracks XLA compilation and FLOP counts across function calls.

Monitors the compilation of XLA executables and accumulates their FLOP counts. Useful for profiling and understanding computational costs of JAX programs.

first_time#

Whether this is the first compilation trace.

cached_flops#

Total accumulated FLOPs from all compiled functions.

functions#

List of compiled XLA executables.

Properties:

online_flops: Current total FLOPs from all tracked functions.

trace_compilation()[source]#

Context manager for tracing compilation.

Example

>>> tracker = CompilationTracker()
>>> with tracker.trace_compilation():
...     result = model(inputs)
>>> print(f"Total FLOPs: {tracker.cached_flops}")
property online_flops#
trace_compilation()[source]#
class easydel.infra.utils.FlopCalcConfig(hidden_dim: int, intermediate_dim: int, num_layers: int, num_heads: int, kv_heads: int, head_dim: int, seq_len: int, enc_num_layers: int = 0, enc_seq_len: int = 0, glu: bool = False, num_experts: int = 1, num_shared_experts: int = 0, num_experts_per_tok: int = 1, activation_type: ActivationType = ActivationType.GELU, task: TaskType = TaskType.AUTO_BIND, vocab_size: int = 0, num_labels: int = 0, vision_hidden_dim: int = 0, vision_intermediate_dim: int = 0, vision_num_layers: int = 0, vision_num_heads: int = 0, vision_seq_len: int = 0, include_loss: bool = False)[source]#

Bases: object

Configuration for calculating FLOPs in transformer models.

Comprehensive configuration that captures all parameters needed to calculate the theoretical FLOP count for various transformer architectures including encoder-decoder, MoE, and vision transformers.

hidden_dim#

Hidden dimension of the model.

Type

int

intermediate_dim#

Dimension of FFN intermediate layer.

Type

int

num_layers#

Number of decoder (or encoder-only) layers.

Type

int

num_heads#

Number of attention heads.

Type

int

kv_heads#

Number of key-value heads (for GQA/MQA).

Type

int

head_dim#

Dimension of each attention head.

Type

int

seq_len#

Sequence length for decoder or encoder-only models.

Type

int

enc_num_layers#

Number of encoder layers (for seq2seq).

Type

int

enc_seq_len#

Encoder sequence length (for seq2seq).

Type

int

glu#

Whether using GLU activation in FFN.

Type

bool

num_experts#

Number of MoE experts.

Type

int

num_shared_experts#

Number of shared experts in MoE.

Type

int

num_experts_per_tok#

Experts activated per token.

Type

int

activation_type#

Type of activation function.

Type

easydel.infra.utils.ActivationType

task#

Model task type (affects head computation).

Type

easydel.infra.utils.TaskType

vocab_size#

Vocabulary size for LM head.

Type

int

num_labels#

Number of labels for classification.

Type

int

vision_hidden_dim#

Hidden dim for vision transformer.

Type

int

vision_intermediate_dim#

FFN dim for vision transformer.

Type

int

vision_num_layers#

Number of vision transformer layers.

Type

int

vision_num_heads#

Number of vision attention heads.

Type

int

vision_seq_len#

Vision sequence length (patches).

Type

int

include_loss#

Whether to include loss computation in FLOPs.

Type

bool

Example

>>> from easydel.infra.utils import FlopCalcConfig, TaskType, flops_per_token
>>> config = FlopCalcConfig(
...     hidden_dim=768,
...     intermediate_dim=3072,
...     num_layers=12,
...     num_heads=12,
...     kv_heads=12,
...     head_dim=64,
...     seq_len=1024,
...     task=TaskType.CAUSAL_LM,
...     vocab_size=50000
... )
>>> flops = flops_per_token(config)
activation_type: ActivationType = 'gelu'#
enc_num_layers: int = 0#
enc_seq_len: int = 0#
glu: bool = False#
head_dim: int#
hidden_dim: int#
include_loss: bool = False#
intermediate_dim: int#
kv_heads: int#
num_experts: int = 1#
num_experts_per_tok: int = 1#
num_heads: int#
num_labels: int = 0#
num_layers: int#
num_shared_experts: int = 0#
seq_len: int#
task: TaskType = 'auto-bind'#
vision_hidden_dim: int = 0#
vision_intermediate_dim: int = 0#
vision_num_heads: int = 0#
vision_num_layers: int = 0#
vision_seq_len: int = 0#
vocab_size: int = 0#
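As a rough cross-check of what a per-token FLOP estimate involves, the sketch below counts matrix-multiply FLOPs (2 per multiply-add) for one forward pass of a dense decoder-only model. It is a common back-of-the-envelope approximation, not EasyDeL's flops_per_token: it ignores layer norms, activations, softmax and MoE routing, and it charges every token for all seq_len keys, which overestimates causal attention. approx_decoder_flops_per_token is a hypothetical name.

```python
def approx_decoder_flops_per_token(
    hidden_dim, intermediate_dim, num_layers, num_heads, kv_heads, head_dim,
    seq_len, vocab_size, glu=False,
):
    """Forward-pass matmul FLOPs per token for a dense decoder (rough estimate)."""
    # Q projection plus K/V projections (fewer K/V heads under GQA/MQA).
    qkv = 2 * hidden_dim * (num_heads + 2 * kv_heads) * head_dim
    # Attention output projection back to hidden_dim.
    out_proj = 2 * num_heads * head_dim * hidden_dim
    # Scores (q . k) and the weighted sum over values, each against seq_len keys.
    attn = 2 * 2 * num_heads * head_dim * seq_len
    # FFN: two weight matrices, or three for a gated (GLU) FFN.
    mlp = (3 if glu else 2) * 2 * hidden_dim * intermediate_dim
    per_layer = qkv + out_proj + attn + mlp
    lm_head = 2 * hidden_dim * vocab_size
    return num_layers * per_layer + lm_head


# Same shapes as the FlopCalcConfig example above.
flops = approx_decoder_flops_per_token(
    hidden_dim=768, intermediate_dim=3072, num_layers=12, num_heads=12,
    kv_heads=12, head_dim=64, seq_len=1024, vocab_size=50000,
)
print(f"{flops / 1e6:.1f} MFLOPs per token")
```

Setting glu=True adds one extra hidden_dim x intermediate_dim matmul per layer, and lowering kv_heads shrinks only the K/V projection term, which is how the glu and kv_heads fields of FlopCalcConfig enter such a count.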
class easydel.infra.utils.FunctionTracer[source]#

Bases: object

Tracer for capturing new XLA executables during compilation.

Used to track which functions are compiled during a trace operation. Captures the difference between executables before and after tracing.

new_executables#

List of TraceResult objects for newly compiled functions.

_before#

Set of executables that existed before tracing started.

Example

>>> with trace_functions() as tracer:
...     result = jitted_function(x)
>>> print(f"Compiled {len(tracer.new_executables)} functions")
>>> print(f"Total FLOPs: {sum(t.flops for t in tracer.new_executables)}")
class easydel.infra.utils.ModuleCaches(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#

Bases: Cache

Cache container for module-level cached values.

Extends flax.nnx.Cache to provide caching functionality for EasyDeL modules, particularly for caching computed values like frequencies, masks, and other reusable tensors.

raw_value: A#
class easydel.infra.utils.OverWriteWithGradient(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#

Bases: Param

Parameter type that allows gradient overwrites.

Special parameter container that permits gradients to directly overwrite the parameter values during optimization, useful for certain advanced optimization techniques.

raw_value: A#
class easydel.infra.utils.TaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#

Bases: str, Enum

AUDIO_CLASSIFICATION = 'audio-classification'#
AUTO_BIND = 'auto-bind'#
BASE_MODULE = 'base-module'#
BASE_VISION = 'vision-module'#
CAUSAL_LM = 'causal-language-model'#
DIFFUSION_LM = 'diffusion-language-model'#
IMAGE_CLASSIFICATION = 'image-classification'#
IMAGE_TEXT_TO_TEXT = 'image-text-to-text'#
SEQUENCE_CLASSIFICATION = 'sequence-classification'#
SEQUENCE_TO_SEQUENCE = 'sequence-to-sequence'#
SPEECH_SEQUENCE_TO_SEQUENCE = 'speech-sequence-to-sequence'#
VISION_LM = 'vision-language-model'#
ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification'#
class easydel.infra.utils.TraceResult(executable)[source]#

Bases: object

Container for XLA executable trace results with cost analysis.

Wraps an XLA executable and provides lazy access to its cost analysis, including FLOP counts and other performance metrics.

_executable#

The underlying XLA executable.

_cached_cost#

Cached cost analysis result.

Properties:

cost_analysis: Returns the cost analysis dict (cached after first access).
flops: Returns the FLOP count from cost analysis.

property cost_analysis#
property flops#
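The lazy, cached access described above can be sketched with functools.cached_property. FakeExecutable is a hypothetical stand-in for a compiled executable so that the sketch runs without JAX. The list handling is there because, depending on the JAX version, cost_analysis() on a compiled function may return either a dict or a list of dicts; whether EasyDeL normalizes it this way is an assumption.

```python
import functools


class FakeExecutable:
    """Hypothetical stand-in for a compiled XLA executable."""

    def __init__(self, flops):
        self._flops = flops
        self.calls = 0  # counts how often the (expensive) analysis runs

    def cost_analysis(self):
        self.calls += 1
        return [{"flops": self._flops}]


class LazyTraceResult:
    def __init__(self, executable):
        self._executable = executable

    @functools.cached_property
    def cost_analysis(self):
        """Computed on first access, then reused."""
        cost = self._executable.cost_analysis()
        if isinstance(cost, (list, tuple)):  # normalize list-of-dicts results
            cost = cost[0] if cost else {}
        return cost or {}

    @property
    def flops(self):
        return float(self.cost_analysis.get("flops", 0.0))


exe = FakeExecutable(flops=1.5e9)
result = LazyTraceResult(exe)
print(result.flops, result.flops, exe.calls)
```

Nothing is analyzed at construction time; the first access pays the cost and later accesses read the cached dict.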
easydel.infra.utils.add_start_docstrings(*docstr)[source]#

Decorator factory that prepends docstrings to a function. add_start_docstrings takes any number of strings and returns a decorator. Applied to a function fn, that decorator sets fn’s docstring to the concatenation of all the strings, followed by fn’s original docstring if it has one.

Parameters

*docstr – Strings to prepend to the decorated function’s docstring.

Returns

A decorator that adds the docstrings to the function
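A minimal sketch of the behaviour described above: join the strings, then append any existing docstring. It is written in the style of the similarly named helper in Hugging Face Transformers; EasyDeL's version may differ in details such as separators, and forward is a hypothetical example function.

```python
def add_start_docstrings(*docstr):
    """Return a decorator that prepends the given strings to fn.__doc__."""
    def decorator(fn):
        original = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = "".join(docstr) + original
        return fn
    return decorator


@add_start_docstrings("Shared intro. ", "Model-specific note. ")
def forward(x):
    """Runs the forward pass."""
    return x


print(forward.__doc__)
```

The strings are joined with no separator, so any spacing or newlines must be part of the strings themselves.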

easydel.infra.utils.apply_lora_to_layers(model: Module, /, *, lora_rank: int, lora_pattern: str | None = None, verbose: bool = True, rngs: flax.nnx.rnglib.Rngs | None = None) Module[source]#

Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.

Parameters
  • model – The EasyDeL model to modify.

  • lora_rank – The rank of the LoRA adapters.

  • lora_pattern – A regular expression pattern to match the names of modules to which LoRA should be applied. Defaults to “.*” (all linear layers).

  • verbose – Whether to display a progress bar.

  • rngs – A flax.nnx.Rngs instance for random number generation. If None, initializes with a seed of 0.

Returns

The modified model with LoRA applied to the specified layers.
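The computation that lora_rank controls can be sketched in NumPy for a single linear layer: a frozen weight W plus a trainable low-rank product A @ B. The alpha / rank scaling and zero initialization of B follow the convention of the original LoRA paper and are assumptions here; apply_lora_to_layers itself only exposes lora_rank. All names in the sketch are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim, rank, alpha = 16, 8, 4, 8.0
scaling = alpha / rank  # common LoRA scaling convention (assumption)

W = rng.standard_normal((in_dim, out_dim))      # frozen base weight
A = rng.standard_normal((in_dim, rank)) * 0.01  # trainable down-projection
B = np.zeros((rank, out_dim))                   # trainable up-projection, zero-init


def lora_linear(x, W, A, B):
    """Base projection plus the scaled low-rank update x @ A @ B."""
    return x @ W + (x @ A) @ B * scaling


x = rng.standard_normal((3, in_dim))
# With B = 0 the adapted layer initially matches the base layer exactly.
assert np.allclose(lora_linear(x, W, A, B), x @ W)

# After "training" B, the update can be folded into a single dense weight.
B = rng.standard_normal((rank, out_dim)) * 0.01
W_merged = W + A @ B * scaling
print(np.abs(lora_linear(x, W, A, B) - x @ W_merged).max())
```

Only A and B (in_dim * rank + rank * out_dim values) are trained instead of the full in_dim * out_dim matrix, which is why a small lora_rank keeps the adapter cheap.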

easydel.infra.utils.apply_sparsity_to_params(params: dict[str, Any] | Any, sparsify_module: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', verbose: bool = True) dict[str, Any] | Any[source]#
easydel.infra.utils.auto_remat(module: type[M], /, *, policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) type[M][source]#
easydel.infra.utils.auto_remat(module1: type[M], module2: type[M], /, *, policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) tuple[type[M], type[M]]
easydel.infra.utils.auto_remat(*modules: type[M], policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) tuple[type[M], ...]

Apply gradient checkpointing (rematerialization) to module(s).

Wraps module __call__ methods with JAX’s remat (rematerialization) to trade compute for memory during training. Supports fine-grained control via checkpoint_name annotations added to models.

Parameters
  • *modules – One or more module classes to wrap with remat.

  • policy – Checkpointing policy. Can be:
      - an EasyDeLGradientCheckPointers enum value;
      - a string policy name (e.g., ‘dots_saveable’, ‘nothing_saveable’);
      - a custom callable policy (e.g., from create_transformer_checkpoint_policy);
      - ‘save_only_these_names’, used together with the save_names parameter;
      - ‘save_anything_except_these_names’, used together with the exclude_names parameter.

  • prevent_cse – If True, prevents common subexpression elimination.

  • save_names – List of checkpoint names to save (for ‘save_only_these_names’). Works with checkpoint_name calls in models.

  • exclude_names – List of checkpoint names to exclude from saving.

Returns

Single module or tuple of modules with remat applied.

Examples

>>> # Basic usage with predefined policy
>>> AttentionModule = auto_remat(AttentionModule, policy='dots_saveable')
>>>
>>> # Multiple modules
>>> AttentionModule, MLPModule = auto_remat(
...     AttentionModule, MLPModule,
...     policy='nothing_saveable'
... )
>>>
>>> # Custom policy saving only specific checkpoints
>>> model = auto_remat(
...     model,
...     policy='save_only_these_names',
...     save_names=['attn_output', 'mlp_output', 'residual']
... )
>>>
>>> # Using transformer-optimized policy
>>> policy = create_transformer_checkpoint_policy(
...     save_attention=True,
...     save_mlp=False  # Recompute MLP to save memory
... )
>>> model = auto_remat(model, policy=policy)
easydel.infra.utils.block_wise_ffn(remat_ffn: Callable, inputs: Array, chunk_size: int) Array[source]#

Apply a feed-forward network block-wise to reduce memory usage.

Implements the block-wise feed-forward approach from the near-infinite context length paper. This technique processes the FFN in chunks along the sequence dimension to reduce peak memory usage during training.

Parameters
  • remat_ffn – The feed-forward network function to apply. Should be rematerialized (checkpointed) for memory efficiency.

  • inputs – Input tensor with shape (batch_size, sequence_length, hidden_dim).

  • chunk_size – Size of chunks to process. Sequence length must be divisible by chunk_size.

Returns

Output tensor with same shape as inputs.

Raises

EasyDeLBlockWiseFFNError – If inputs have wrong shape or chunk_size doesn’t divide sequence length evenly.

Note

  • For generation (sequence_length=1), applies FFN directly without chunking

  • For training, processes sequence in chunks to reduce memory

  • Requires sequence_length to be divisible by chunk_size

Example

>>> ffn = lambda x: mlp(x)  # Your FFN function
>>> chunked_output = block_wise_ffn(ffn, inputs, chunk_size=256)
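Because a feed-forward network acts on each position independently, splitting the sequence into chunks does not change the result. The NumPy sketch below shows the chunking and the shape checks listed in the Note; block_wise_ffn_sketch is a hypothetical name, and it uses a plain Python loop and ValueError rather than EasyDeL's implementation and EasyDeLBlockWiseFFNError. In NumPy this only demonstrates correctness; the memory saving in EasyDeL depends on pairing chunking with a rematerialized FFN under JAX.

```python
import numpy as np


def block_wise_ffn_sketch(ffn, inputs, chunk_size):
    """Apply `ffn` to sequence chunks of shape (batch, chunk_size, hidden)."""
    if inputs.ndim != 3:
        raise ValueError("expected (batch, seq_len, hidden_dim)")
    batch, seq_len, hidden = inputs.shape
    if seq_len == 1:  # generation step: nothing to chunk
        return ffn(inputs)
    if seq_len % chunk_size != 0:
        raise ValueError("seq_len must be divisible by chunk_size")
    chunks = inputs.reshape(batch, seq_len // chunk_size, chunk_size, hidden)
    # Each call sees only one chunk, so its intermediate activations stay small.
    outputs = [ffn(chunks[:, i]) for i in range(chunks.shape[1])]
    return np.stack(outputs, axis=1).reshape(batch, seq_len, -1)


rng = np.random.default_rng(0)
w_in = rng.standard_normal((8, 32))
w_out = rng.standard_normal((32, 8))
ffn = lambda h: np.maximum(h @ w_in, 0.0) @ w_out  # toy position-wise MLP

x = rng.standard_normal((2, 12, 8))
y = block_wise_ffn_sketch(ffn, x, chunk_size=4)
print(y.shape, np.allclose(y, ffn(x)))
```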
easydel.infra.utils.canonicalize_dtype(*args, dtype: numpy.dtype | None = None, inexact: bool = True) dtype[source]#

Canonicalize an optional dtype to the definitive dtype.

Infers or validates the dtype for JAX operations. If dtype is None, infers from input arguments. Otherwise validates and returns the specified dtype.

Parameters
  • *args – JAX array compatible values (None values ignored).

  • dtype – Optional dtype override. If specified, arguments are cast to this dtype and inference is disabled.

  • inexact – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don’t work directly on integers, such as taking a mean.

Returns

The dtype that *args should be cast to.
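The inference rules above can be sketched in NumPy. The logic mirrors Flax's canonicalize_dtype helper, which EasyDeL's function is assumed to follow; canonicalize_dtype_sketch is a hypothetical name. NumPy's promotion table differs from JAX's (for example, int32 combined with float32 gives float64 in NumPy but float32 in JAX), so exact results can differ from the real function.

```python
import numpy as np


def canonicalize_dtype_sketch(*args, dtype=None, inexact=True):
    """Infer a common dtype from args (ignoring None) unless one is given."""
    if dtype is None:
        arrays = [np.asarray(a) for a in args if a is not None]
        dtype = np.result_type(*arrays)
        if inexact and not np.issubdtype(dtype, np.inexact):
            # Promote integer/bool results to a floating type.
            dtype = np.promote_types(np.float32, dtype)
    dtype = np.dtype(dtype)
    if inexact and not np.issubdtype(dtype, np.inexact):
        raise ValueError(f"dtype must be inexact: {dtype}")
    return dtype


print(canonicalize_dtype_sketch(None, np.zeros(2, np.float16), np.zeros(2, np.float32)))
print(canonicalize_dtype_sketch(np.zeros(2, np.int8)))
```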

easydel.infra.utils.count_flop_jaxpr(jaxpr) int[source]#

Count flops in a Jaxpr.

easydel.infra.utils.create_transformer_checkpoint_policy(save_attention: bool = True, save_mlp: bool = True, save_residuals: bool = True, save_layer_outputs: bool = False, save_embeddings: bool = False, custom_names: list[str] | None = None) Callable[source]#

Create a checkpoint policy optimized for transformer models.

Creates a custom checkpoint policy that selectively saves transformer components based on the checkpoint_name annotations present in EasyDeL models.

Parameters
  • save_attention – Whether to save attention outputs (attn_query, attn_key, attn_value, attn_output)

  • save_mlp – Whether to save MLP outputs (mlp_gate, mlp_up, mlp_down, mlp_output)

  • save_residuals – Whether to save residual connections

  • save_layer_outputs – Whether to save layer outputs

  • save_embeddings – Whether to save embeddings and model outputs

  • custom_names – Additional checkpoint names to save

Returns

JAX checkpoint policy function

Example

>>> # Save only critical transformer components
>>> policy = create_transformer_checkpoint_policy(
...     save_attention=True,
...     save_mlp=False,  # Recompute MLP
...     save_residuals=True
... )
>>> model = auto_remat(model, policy=policy)
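The flag-to-name selection can be sketched in plain Python using the checkpoint names listed in the parameters above. The resulting names could then be passed to a name-based JAX policy such as jax.checkpoint_policies.save_only_these_names(*names); whether EasyDeL builds its policy exactly this way is an assumption. "residual" is taken from the auto_remat example and is hypothetical, and the layer-output and embedding names are omitted because they are not listed here.

```python
ATTENTION_NAMES = ("attn_query", "attn_key", "attn_value", "attn_output")
MLP_NAMES = ("mlp_gate", "mlp_up", "mlp_down", "mlp_output")
RESIDUAL_NAMES = ("residual",)  # hypothetical, from the auto_remat example


def select_checkpoint_names(save_attention=True, save_mlp=True,
                            save_residuals=True, custom_names=None):
    """Collect the checkpoint names whose values should be kept, not recomputed."""
    names = set()
    if save_attention:
        names.update(ATTENTION_NAMES)
    if save_mlp:
        names.update(MLP_NAMES)
    if save_residuals:
        names.update(RESIDUAL_NAMES)
    if custom_names:
        names.update(custom_names)
    return frozenset(names)


names = select_checkpoint_names(save_mlp=False)
print(sorted(names))
```

Anything not in the set is recomputed in the backward pass, so turning save_mlp off trades extra MLP compute for lower activation memory.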
easydel.infra.utils.extract_static_parameters(module)[source]#

Extract static_argnums for specified parameters across functions in a module.

Parameters

module (types.ModuleType) – The module to inspect

Returns

A dictionary mapping function names to their static parameter indices

Return type

dict

easydel.infra.utils.flop_activation(activation_type: ActivationType, dim: int) float[source]#

Calculate FLOPs for different activation functions.

easydel.infra.utils.flop_attention(hidden_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_len: int) float[source]#
easydel.infra.utils.flop_cls_head(hidden_dim: int, num_labels: int) float[source]#
easydel.infra.utils.flop_cross_attention(hidden_dim: int, num_heads: int, enc_seq_len: int, dec_seq_len: int) float[source]#
easydel.infra.utils.flop_layernorm(hidden_dim: int) float[source]#
easydel.infra.utils.flop_lm_head(hidden_dim: int, vocab_size: int) float[source]#
easydel.infra.utils.flop_loss(num_classes: int) float[source]#
easydel.infra.utils.flop_mlp(cfg: FlopCalcConfig, hidden_dim: int, intermediate_dim: int) float[source]#
easydel.infra.utils.flop_seq2seq(cfg: FlopCalcConfig) float[source]#
easydel.infra.utils.flop_transformer_body(layers: int, seq_len: int, hidden_dim: int, intermediate_dim: int, cfg: FlopCalcConfig) float[source]#
easydel.infra.utils.flop_vision_tower(cfg: FlopCalcConfig) float[source]#
easydel.infra.utils.flops_per_token(cfg: FlopCalcConfig) float[source]#
easydel.infra.utils.get_dot_general_by_bits(bits: int | None = None, mode: Literal['train', 'serve', 'convert'] = 'train') dict[source]#

Return the dot_general configuration for a given quantization bit width. When bits is specified, the returned dict holds a q_flax.QDotGeneral (as dot_general_cls) configured with that number of bits for the forward and backward passes; if bits is None, no quantized dot_general is configured.

Parameters
  • bits – Number of bits for quantization, or None to disable quantization.

  • mode – How the model will be used, which determines how the QDot method is initialized: ‘train’, ‘serve’, or ‘convert’.

Returns

A dict containing dot_general_cls.

easydel.infra.utils.get_gradient_checkpoint_policy(name: str | easydel.infra.etils.EasyDeLGradientCheckPointers, save_names: list[str] | None = None, exclude_names: list[str] | None = None) Callable[source]#

Get a gradient checkpointing policy by name or create a custom one.

Retrieves a JAX gradient checkpointing policy function that determines which intermediate values to save during forward pass for use in backward pass. This is used to trade compute for memory in gradient calculations.

Parameters
  • name – Name of the checkpointing policy or EasyDeLGradientCheckPointers enum. Supported values:
      - ‘everything_saveable’: Save all intermediate values
      - ‘nothing_saveable’: Save no intermediate values (maximum recomputation)
      - ‘dots_saveable’: Save dot product results
      - ‘checkpoint_dots’: Checkpoint dot operations
      - ‘dots_with_no_batch_dims_saveable’: Save dots without batch dimensions
      - ‘checkpoint_dots_with_no_batch_dims’: Checkpoint dots without batch dims
      - ‘save_anything_except_these_names’: Save all except specified names
      - ‘save_any_names_but_these’: Save any names except specified
      - ‘save_only_these_names’: Save only specified names
      - ‘save_from_both_policies’: Combine two policies

  • save_names – List of checkpoint names to save (used with ‘save_only_these_names’)

  • exclude_names – List of checkpoint names to exclude (used with ‘save_anything_except_these_names’)

Returns

The corresponding JAX checkpoint policy function.

Raises
  • KeyError – If the policy name is not recognized.

  • ValueError – If save_names or exclude_names are not provided when required.

Example

>>> # Basic policy
>>> policy = get_gradient_checkpoint_policy('dots_saveable')
>>>
>>> # Custom policy saving only specific checkpoints
>>> policy = get_gradient_checkpoint_policy(
...     'save_only_these_names',
...     save_names=['attn_output', 'mlp_output']
... )
class easydel.infra.utils.hashable_dict[source]#

Bases: dict
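hashable_dict has no documented behaviour beyond subclassing dict. A plausible reason for it is that kwargs such as ArrayParam.init_kwargs must be hashable wherever they serve as static or metadata values. The sketch below is a hypothetical version (HashableDict is an invented name) that hashes the sorted items, so dicts with equal contents hash equally regardless of insertion order. It assumes keys are mutually comparable and values are hashable, and the dict must not be mutated while in use as a key.

```python
class HashableDict(dict):
    """dict whose hash depends on its sorted items, so equal dicts hash equally."""

    def __hash__(self):
        return hash(tuple(sorted(self.items())))


a = HashableDict(stddev=0.02, mean=0.0)
b = HashableDict(mean=0.0, stddev=0.02)
cache = {a: "initializer kwargs"}
print(cache[b])  # found: b has the same items as a
```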

easydel.infra.utils.is_flatten(pytree: dict)[source]#
The is_flatten function checks if the pytree is flattened.

If it is, the first key in the dictionary is a tuple (a key path); otherwise it is a plain, non-tuple key.

Parameters

pytree – The pytree (a dict) to check.

Returns

True if the pytree is a flattened tree, and false otherwise
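Under the description above (a flattened tree has tuple keys), the check can be sketched as follows. flatten_dict here is a hypothetical helper that follows the tuple-key-path convention of flax.traverse_util.flatten_dict; is_flatten_sketch only inspects the first key, as the description implies, and treating an empty dict as not flattened is an assumption.

```python
def flatten_dict(tree, prefix=()):
    """Turn {'a': {'b': 1}} into {('a', 'b'): 1}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten_dict(value, path))
        else:
            flat[path] = value
    return flat


def is_flatten_sketch(pytree):
    """Flattened trees use tuple key paths; nested trees use plain keys."""
    return len(pytree) > 0 and isinstance(next(iter(pytree)), tuple)


nested = {"layer": {"kernel": 1, "bias": 2}}
flat = flatten_dict(nested)
print(flat, is_flatten_sketch(flat), is_flatten_sketch(nested))
```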

easydel.infra.utils.merge_lora_params(model: Module, lora_tree: dict) Module[source]#

Merge LoRA (Low-Rank Adaptation) parameters back into the layers of a model.

Parameters
  • model – The EasyDeL model.

  • lora_tree – The tree of LoRA parameters to merge.

Returns

The model with the LoRA parameters merged in.

easydel.infra.utils.quantize_linear_layers(model: Module, /, *, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, verbose: bool = True) Module[source]#

Quantize parameters to requested precision, excluding specified layers.

Parameters
  • model – The model to quantize.

  • quantization_config – Quantization config specifying dtype, block_size, and pattern.

  • verbose – Whether to display a tqdm progress bar.

Returns

Quantized parameters in the same structure as the input.
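quantization_config carries a dtype, a block_size and a pattern. As a generic illustration of what a block size means (not EasyDeL's quantizer, whose formats and rounding may differ), here is symmetric absmax int8 quantization with one scale per block of block_size values. The function names are hypothetical, and the sketch assumes the weight size is divisible by block_size.

```python
import numpy as np


def quantize_blockwise_int8(w, block_size):
    """Symmetric absmax int8 quantization with one scale per block."""
    flat = w.reshape(-1, block_size)  # assumes w.size % block_size == 0
    scale = np.abs(flat).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero for all-zero blocks
    q = np.clip(np.round(flat / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize_blockwise_int8(q, scale, shape):
    return (q.astype(np.float32) * scale).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 64)).astype(np.float32)
q, scale = quantize_blockwise_int8(w, block_size=32)
w_hat = dequantize_blockwise_int8(q, scale, w.shape)
print(np.abs(w - w_hat).max())
```

Smaller blocks mean more scales to store but a tighter fit to local magnitudes, since the rounding error in each block is at most half of that block's scale.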

easydel.infra.utils.quick_gelu(x)[source]#

Quick GELU activation function.

A faster approximation of GELU using sigmoid.

Parameters

x – Input array.

Returns

Activated array.
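A NumPy sketch of the usual quick GELU formula, x · sigmoid(1.702 · x), the constant used by CLIP-style models and by Hugging Face's QuickGELUActivation. That EasyDeL uses exactly this constant is an assumption; quick_gelu_sketch and exact_gelu are hypothetical names. The comparison shows how closely it tracks the exact GELU.

```python
import math

import numpy as np


def quick_gelu_sketch(x):
    """x * sigmoid(1.702 * x), written as x / (1 + exp(-1.702 x))."""
    x = np.asarray(x, dtype=np.float64)
    return x / (1.0 + np.exp(-1.702 * x))


def exact_gelu(x):
    """Reference GELU: x * Phi(x), using the error function."""
    x = np.asarray(x, dtype=np.float64)
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))


xs = np.linspace(-4.0, 4.0, 801)
print(np.abs(quick_gelu_sketch(xs) - exact_gelu(xs)).max())
```

The sigmoid form avoids evaluating erf, at the cost of a small approximation error (about 0.02 at most on this range).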

easydel.infra.utils.split_lora_params(model: Module) Module[source]#

Extract the LoRA (Low-Rank Adaptation) weights from the layers of a model.

Parameters

model – The EasyDeL model.

Returns

LoRA Layer Weights.

easydel.infra.utils.trace_functions()[source]#
easydel.infra.utils.unwrap_lora_to_layers(model: Module, /, *, verbose: bool = True) Module[source]#

Unwrap LoRA (Low-Rank Adaptation) from the specified linear layers within a model.