easydel.infra.utils#
Utility functions and helpers for EasyDeL infrastructure.
Provides common utilities used throughout the EasyDeL framework, including activation functions, dtype handling, module manipulation, and various helper functions for model operations.
- Constants:
ACT2FN: Dictionary mapping activation names to functions
ROPE_TYPES: Supported RoPE (Rotary Position Embedding) types
- Functions:
quick_gelu: Quick GELU activation function
canonicalize_dtype: Canonicalize dtype for JAX arrays
get_activation: Get activation function by name
quantize_linear: Apply quantization to linear layers
replace_dot: Replace JAX dot operations
- Key Features:
Activation function registry
Data type canonicalization
Module quantization utilities
Sharding constraint helpers
Memory optimization tools
Example
>>> import jax.numpy as jnp
>>> from easydel.infra.utils import ACT2FN, canonicalize_dtype
>>> # Get activation function
>>> activation = ACT2FN["gelu"]
>>> # Canonicalize dtype
>>> array = jnp.ones((2, 2))
>>> dtype = canonicalize_dtype(array, dtype=jnp.float32)
- easydel.infra.utils.ACT2FN = {'elu': <PjitFunction of <function elu>>, 'gelu': functools.partial(<function gelu>, approximate=False), 'gelu_new': functools.partial(<function gelu>, approximate=True), 'gelu_pytorch_tanh': functools.partial(<function gelu>, approximate=True), 'glu': <PjitFunction of <function glu>>, 'leaky_relu': functools.partial(<PjitFunction of <function leaky_relu>>, negative_slope=0.01), 'quick_gelu': <function quick_gelu>, 'relu': <jax._src.custom_derivatives.custom_jvp object>, 'sigmoid': <PjitFunction of <function sigmoid>>, 'silu': <PjitFunction of <function silu>>, 'softmax': <function softmax>, 'swish': <PjitFunction of <function silu>>, 'tanh': <PjitFunction of <function tanh>>}#
Registry of activation functions by name.
Maps activation function names to their implementations. Supports common activations used in neural networks.
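The registry pattern described here can be sketched in plain Python. This is a standalone illustration rather than EasyDeL's code: `ACT2FN_SKETCH` and the scalar activations are hypothetical stand-ins for the JAX functions shown in the repr, chosen so the snippet runs without JAX.

```python
import functools
import math


def gelu(x: float, approximate: bool = False) -> float:
    """Scalar GELU; approximate=True uses the tanh approximation."""
    if approximate:
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
        return 0.5 * x * (1.0 + math.tanh(inner))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x: float) -> float:
    """Scalar SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


# Hypothetical stand-in for ACT2FN: names map to callables, and variants of
# one function are expressed with functools.partial, as in the repr above.
ACT2FN_SKETCH = {
    "gelu": functools.partial(gelu, approximate=False),
    "gelu_new": functools.partial(gelu, approximate=True),
    "silu": silu,
    "swish": silu,  # alias: the same function object under a second name
}

activation = ACT2FN_SKETCH["gelu_new"]
print(activation(1.0))
```

Keeping aliases such as `swish` and `silu` on the same callable means a config can use either name without any extra dispatch logic.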
- class easydel.infra.utils.ActivationType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
- ELU = 'elu'#
- GELU = 'gelu'#
- GELU_NEW = 'gelu_new'#
- GELU_PYTORCH_TANH = 'gelu_pytorch_tanh'#
- GLU = 'glu'#
- LEAKY_RELU = 'leaky_relu'#
- QUICK_GELU = 'quick_gelu'#
- RELU = 'relu'#
- SIGMOID = 'sigmoid'#
- SILU = 'silu'#
- SOFTMAX = 'softmax'#
- SWISH = 'swish'#
- TANH = 'tanh'#
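Because the enum derives from both str and Enum, its members behave like their string values. The miniature replica below (the class name `ActivationTypeSketch` is hypothetical and only two members are copied) shows what that buys when activation names come from a config file.

```python
from enum import Enum


class ActivationTypeSketch(str, Enum):
    """Miniature replica of a str-based Enum with two of the listed members."""

    GELU = "gelu"
    SILU = "silu"


# Members compare equal to their string values ...
assert ActivationTypeSketch.GELU == "gelu"
# ... and can be looked up from a plain string, e.g. a value read from a config.
member = ActivationTypeSketch("silu")
print(member.name, member.value)
```

Looking up an unknown string raises ValueError, which gives early validation of config values.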
- class easydel.infra.utils.ArrayParam(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#
Bases:
Param
Parameterized array with serializable initialization.
A parameter container that stores initialization metadata (method name and kwargs) as strings/dicts instead of functions, making it pickleable and serializable. This is particularly useful for checkpointing and distributed training.
- shape#
The shape of the parameter array.
- Type
collections.abc.Sequence[int]
- dtype#
The data type of the parameter array.
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType]
- init_method#
Name of the JAX initializer (e.g., “normal”, “zeros”, “ones”).
- Type
str
- init_kwargs#
Optional kwargs passed to the initializer.
- Type
- classmethod bound(shape: Sequence[int], dtype: Union[str, type[Any], dtype, SupportsDType], init_method: str, init_kwargs: easydel.infra.utils.hashable_dict | None = None, *, key: Optional[Union[Key[Array, ''], UInt32[Array, '2']]] = None, value: jax.Array | None = None, use_ref: bool | None = None, **metadata)[source]#
Create an ArrayParam with initialized value.
- Parameters
shape – Shape of the parameter array.
dtype – Data type for the parameter.
init_method – Name of JAX initializer (e.g., “normal”, “zeros”, “kaiming_uniform”).
init_kwargs – Optional keyword arguments for the initializer.
key – PRNG key for random initialization. Required if value is None.
value – Pre-computed value. If provided, skips initialization.
use_ref – Whether to use reference semantics.
**metadata – Additional metadata to store with the parameter.
- Returns
An initialized ArrayParam instance.
- Return type
- dtype: DTypeLike#
- init_kwargs: hashable_dict | None = None#
- init_method: str = 'normal'#
- raw_value: A#
- resure(key: Union[Key[Array, ''], UInt32[Array, '2']], shard_fn: Optional[Callable[[Array], Array]] = None) None[source]#
Reinitialize the parameter value with a new random key.
Regenerates the parameter value using the stored initialization method and optional sharding function. Useful for resetting parameters or applying sharding after initialization.
- Parameters
key – PRNG key for random initialization.
shard_fn – Optional function to apply sharding to the reinitialized value.
- shape: Sequence[int]#
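The idea behind storing an initializer name and kwargs instead of a function object can be sketched without JAX. Everything below is hypothetical: `ParamSpecSketch`, the `INITIALIZERS` table, and the list-based initializers stand in for the real class and for JAX initializers.

```python
import json
import random
from dataclasses import asdict, dataclass, field

# Hypothetical initializer registry keyed by name. The real class resolves
# JAX initializers; list-based stand-ins are used here.
INITIALIZERS = {
    "zeros": lambda rng, n: [0.0] * n,
    "ones": lambda rng, n: [1.0] * n,
    "normal": lambda rng, n, stddev=1.0: [rng.gauss(0.0, stddev) for _ in range(n)],
}


@dataclass
class ParamSpecSketch:
    """Stores how to initialize (a name and kwargs), not a function object."""

    size: int
    init_method: str = "normal"
    init_kwargs: dict = field(default_factory=dict)

    def materialize(self, seed: int) -> list:
        # Resolve the initializer by name only when a value is needed.
        init_fn = INITIALIZERS[self.init_method]
        return init_fn(random.Random(seed), self.size, **self.init_kwargs)


spec = ParamSpecSketch(size=4, init_method="normal", init_kwargs={"stddev": 0.02})

# The spec holds only ints, strings, and dicts, so it serializes trivially.
payload = json.dumps(asdict(spec))
restored = ParamSpecSketch(**json.loads(payload))

# Re-materializing with the same seed reproduces the same values.
assert restored.materialize(seed=0) == spec.materialize(seed=0)
```

Calling `materialize` again with a different seed plays the role that reinitialization with a new key plays in the class above.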
- class easydel.infra.utils.AttnMaskDetail(mask_type: AttnMaskType, size: int, offset: int | None = None, chunks: int | None = None, bricks: int | None = None)[source]#
Bases:
object
Details for attention mask configuration.
Specifies the type and parameters of attention masking to use. Registered as a JAX pytree for use in JAX transformations.
- mask_type#
Type of attention mask (FULL, SLIDING, or CHUNK).
- size#
Size parameter for the mask (e.g., window size for sliding).
- Type
int
- offset#
Optional offset for mask positioning.
- Type
int | None
- chunks#
Optional number of chunks for chunk attention.
- Type
int | None
- bricks#
Optional number of bricks for hierarchical attention.
- Type
int | None
Example
>>> mask_detail = AttnMaskDetail(
...     mask_type=AttnMaskType.SLIDING,
...     size=512,
...     offset=0
... )
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- mask_type: AttnMaskType#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- size: int#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.infra.utils.AttnMaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
- CHUNK = 'ATTN_MASK_CHUNK'#
- FULL = 'ATTN_MASK_FULL'#
- SLIDING = 'ATTN_MASK_SLIDING'#
- class easydel.infra.utils.CompilationTracker[source]#
Bases:
object
Tracks XLA compilation and FLOP counts across function calls.
Monitors the compilation of XLA executables and accumulates their FLOP counts. Useful for profiling and understanding computational costs of JAX programs.
- first_time#
Whether this is the first compilation trace.
- cached_flops#
Total accumulated FLOPs from all compiled functions.
- functions#
List of compiled XLA executables.
- Properties:
online_flops: Current total FLOPs from all tracked functions.
Example
>>> tracker = CompilationTracker()
>>> with tracker.trace_compilation():
...     result = model(inputs)
>>> print(f"Total FLOPs: {tracker.cached_flops}")
- property online_flops#
- class easydel.infra.utils.FlopCalcConfig(hidden_dim: int, intermediate_dim: int, num_layers: int, num_heads: int, kv_heads: int, head_dim: int, seq_len: int, enc_num_layers: int = 0, enc_seq_len: int = 0, glu: bool = False, num_experts: int = 1, num_shared_experts: int = 0, num_experts_per_tok: int = 1, activation_type: ActivationType = ActivationType.GELU, task: TaskType = TaskType.AUTO_BIND, vocab_size: int = 0, num_labels: int = 0, vision_hidden_dim: int = 0, vision_intermediate_dim: int = 0, vision_num_layers: int = 0, vision_num_heads: int = 0, vision_seq_len: int = 0, include_loss: bool = False)[source]#
Bases:
object
Configuration for calculating FLOPs in transformer models.
Comprehensive configuration that captures all parameters needed to calculate the theoretical FLOP count for various transformer architectures including encoder-decoder, MoE, and vision transformers.
- hidden_dim#
Hidden dimension of the model.
- Type
int
- intermediate_dim#
Dimension of FFN intermediate layer.
- Type
int
- num_layers#
Number of decoder (or encoder-only) layers.
- Type
int
- num_heads#
Number of attention heads.
- Type
int
- kv_heads#
Number of key-value heads (for GQA/MQA).
- Type
int
- head_dim#
Dimension of each attention head.
- Type
int
- seq_len#
Sequence length for decoder or encoder-only models.
- Type
int
- enc_num_layers#
Number of encoder layers (for seq2seq).
- Type
int
- enc_seq_len#
Encoder sequence length (for seq2seq).
- Type
int
- glu#
Whether using GLU activation in FFN.
- Type
bool
- num_experts#
Number of MoE experts.
- Type
int
- num_shared_experts#
Number of shared experts in MoE.
- Type
int
- num_experts_per_tok#
Experts activated per token.
- Type
int
- activation_type#
Type of activation function.
- task#
Model task type (affects head computation).
- vocab_size#
Vocabulary size for LM head.
- Type
int
- num_labels#
Number of labels for classification.
- Type
int
- vision_hidden_dim#
Hidden dim for vision transformer.
- Type
int
- vision_intermediate_dim#
FFN dim for vision transformer.
- Type
int
- vision_num_layers#
Number of vision transformer layers.
- Type
int
- vision_num_heads#
Number of vision attention heads.
- Type
int
- vision_seq_len#
Vision sequence length (patches).
- Type
int
- include_loss#
Whether to include loss computation in FLOPs.
- Type
bool
Example
>>> config = FlopCalcConfig(
...     hidden_dim=768,
...     intermediate_dim=3072,
...     num_layers=12,
...     num_heads=12,
...     kv_heads=12,
...     head_dim=64,
...     seq_len=1024,
...     task=TaskType.CAUSAL_LM,
...     vocab_size=50000
... )
>>> flops = flops_per_token(config)
- activation_type: ActivationType = 'gelu'#
- enc_num_layers: int = 0#
- enc_seq_len: int = 0#
- glu: bool = False#
- head_dim: int#
- hidden_dim: int#
- include_loss: bool = False#
- intermediate_dim: int#
- kv_heads: int#
- num_experts: int = 1#
- num_experts_per_tok: int = 1#
- num_heads: int#
- num_labels: int = 0#
- num_layers: int#
- num_shared_experts: int = 0#
- seq_len: int#
- vision_hidden_dim: int = 0#
- vision_intermediate_dim: int = 0#
- vision_num_heads: int = 0#
- vision_num_layers: int = 0#
- vision_seq_len: int = 0#
- vocab_size: int = 0#
- class easydel.infra.utils.FunctionTracer[source]#
Bases:
object
Tracer for capturing new XLA executables during compilation.
Used to track which functions are compiled during a trace operation. Captures the difference between executables before and after tracing.
- new_executables#
List of TraceResult objects for newly compiled functions.
- _before#
Set of executables that existed before tracing started.
Example
>>> with trace_functions() as tracer:
...     result = jitted_function(x)
>>> print(f"Compiled {len(tracer.new_executables)} functions")
>>> print(f"Total FLOPs: {sum(t.flops for t in tracer.new_executables)}")
- class easydel.infra.utils.ModuleCaches(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#
Bases:
Cache
Cache container for module-level cached values.
Extends flax.nnx.Cache to provide caching functionality for EasyDeL modules, particularly for caching computed values like frequencies, masks, and other reusable tensors.
- raw_value: A#
- class easydel.infra.utils.OverWriteWithGradient(value: Union[A, VariableMetadata[A]], *, use_ref: bool | None = None, **metadata: Any)[source]#
Bases:
Param
Parameter type that allows gradient overwrites.
Special parameter container that permits gradients to directly overwrite the parameter values during optimization, useful for certain advanced optimization techniques.
- raw_value: A#
- class easydel.infra.utils.TaskType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]#
Bases:
str, Enum
- AUDIO_CLASSIFICATION = 'audio-classification'#
- AUTO_BIND = 'auto-bind'#
- BASE_MODULE = 'base-module'#
- BASE_VISION = 'vision-module'#
- CAUSAL_LM = 'causal-language-model'#
- DIFFUSION_LM = 'diffusion-language-model'#
- IMAGE_CLASSIFICATION = 'image-classification'#
- IMAGE_TEXT_TO_TEXT = 'image-text-to-text'#
- SEQUENCE_CLASSIFICATION = 'sequence-classification'#
- SEQUENCE_TO_SEQUENCE = 'sequence-to-sequence'#
- SPEECH_SEQUENCE_TO_SEQUENCE = 'speech-sequence-to-sequence'#
- VISION_LM = 'vision-language-model'#
- ZERO_SHOT_IMAGE_CLASSIFICATION = 'zero-shot-image-classification'#
- class easydel.infra.utils.TraceResult(executable)[source]#
Bases:
object
Container for XLA executable trace results with cost analysis.
Wraps an XLA executable and provides lazy access to its cost analysis, including FLOP counts and other performance metrics.
- _executable#
The underlying XLA executable.
- _cached_cost#
Cached cost analysis result.
- Properties:
cost_analysis: Returns the cost analysis dict (cached after first access).
flops: Returns the FLOP count from cost analysis.
- property cost_analysis#
- property flops#
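The lazy, cached access described for TraceResult can be sketched as follows. `FakeExecutable` and `TraceResultSketch` are hypothetical, and the fallback of 0.0 when no "flops" entry exists is an assumption, not documented behaviour.

```python
class FakeExecutable:
    """Hypothetical stand-in for an XLA executable; counts analysis calls."""

    def __init__(self, flops: float):
        self._flops = flops
        self.calls = 0

    def cost_analysis(self) -> dict:
        self.calls += 1  # pretend this is an expensive query
        return {"flops": self._flops}


class TraceResultSketch:
    """Wraps an executable and computes its cost analysis at most once."""

    def __init__(self, executable):
        self._executable = executable
        self._cached_cost = None

    @property
    def cost_analysis(self) -> dict:
        if self._cached_cost is None:
            self._cached_cost = self._executable.cost_analysis()
        return self._cached_cost

    @property
    def flops(self) -> float:
        # Assumption: fall back to 0.0 when the analysis has no "flops" key.
        return self.cost_analysis.get("flops", 0.0)


exe = FakeExecutable(flops=1.5e9)
result = TraceResultSketch(exe)
total = result.flops + result.flops  # the second access hits the cache
print(total, exe.calls)
```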
- easydel.infra.utils.add_start_docstrings(*docstr)[source]#
Decorator factory that prepends docstrings to a function. add_start_docstrings takes an arbitrary number of strings and returns a decorator. The decorator takes one argument, fn, and sets fn's docstring to the concatenation of all the given strings followed by fn's original docstring, if it has one.
- Parameters
*docstr – One or more strings to prepend to the decorated function's docstring.
- Returns
A decorator that adds the docstrings to the function
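A minimal sketch of a decorator with this behaviour is shown below. It follows the description above rather than the library source, and the name `add_start_docstrings_sketch` is hypothetical.

```python
def add_start_docstrings_sketch(*docstr):
    """Return a decorator that prepends docstr to a function's docstring."""

    def decorator(fn):
        original = fn.__doc__ if fn.__doc__ is not None else ""
        fn.__doc__ = "".join(docstr) + original
        return fn

    return decorator


@add_start_docstrings_sketch("Shared header.\n", "Shared arguments.\n")
def forward(x):
    """Run the forward pass."""
    return x


print(forward.__doc__)
```

The decorated function itself is returned unchanged, so only its documentation differs.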
- easydel.infra.utils.apply_lora_to_layers(model: Module, /, *, lora_rank: int, lora_pattern: str | None = None, verbose: bool = True, rngs: flax.nnx.rnglib.Rngs | None = None) Module[source]#
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
- Parameters
model – The EasyDeL model to modify.
lora_rank – The rank of the LoRA adapters.
lora_pattern – A regular expression pattern to match the names of modules to which LoRA should be applied. Defaults to “.*” (all linear layers).
verbose – Whether to display a progress bar.
rngs – A flax.nnx.Rngs instance for random number generation. If None, initializes with a seed of 0.
- Returns
The modified model with LoRA applied to the specified layers.
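How a lora_pattern narrows the set of adapted layers can be illustrated with plain regular expressions. The module paths and the helper `select_lora_targets` are hypothetical, and whether the library searches or anchors the pattern is an assumption flagged in the code.

```python
import re

# Hypothetical module paths, written the way a model's layers might be named.
module_paths = [
    "model/layers/0/self_attn/q_proj",
    "model/layers/0/self_attn/v_proj",
    "model/layers/0/mlp/up_proj",
    "lm_head",
]


def select_lora_targets(paths, lora_pattern=None):
    """Return the paths matching the pattern (no pattern means match all)."""
    pattern = re.compile(lora_pattern if lora_pattern is not None else ".*")
    # Assumption: the pattern is searched anywhere in the path; the library
    # may instead anchor the match at the start of the name.
    return [path for path in paths if pattern.search(path)]


print(select_lora_targets(module_paths, r".*(q_proj|v_proj).*"))
print(len(select_lora_targets(module_paths)))
```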
- easydel.infra.utils.apply_sparsity_to_params(params: dict[str, Any] | Any, sparsify_module: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', verbose: bool = True) dict[str, Any] | Any[source]#
- easydel.infra.utils.auto_remat(module: type[M], /, *, policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) type[M][source]#
- easydel.infra.utils.auto_remat(module1: type[M], module2: type[M], /, *, policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) tuple[type[M], type[M]]
- easydel.infra.utils.auto_remat(*modules: type[M], policy: Union[EasyDeLGradientCheckPointers, str, Callable] = EasyDeLGradientCheckPointers.NONE, prevent_cse: bool = True, save_names: list[str] | None = None, exclude_names: list[str] | None = None) tuple[type[M], ...]
Apply gradient checkpointing (rematerialization) to module(s).
Wraps module __call__ methods with JAX’s remat (rematerialization) to trade compute for memory during training. Supports fine-grained control via checkpoint_name annotations added to models.
- Parameters
*modules – One or more module classes to wrap with remat.
policy – Checkpointing policy. Can be:
- EasyDeLGradientCheckPointers enum value
- String policy name (e.g., ‘dots_saveable’, ‘nothing_saveable’)
- Custom callable policy (e.g., from create_transformer_checkpoint_policy)
- ‘save_only_these_names’: Use with save_names param
- ‘save_anything_except_these_names’: Use with exclude_names param
prevent_cse – If True, prevents common subexpression elimination.
save_names – List of checkpoint names to save (for ‘save_only_these_names’). Works with checkpoint_name calls in models.
exclude_names – List of checkpoint names to exclude from saving.
- Returns
Single module or tuple of modules with remat applied.
Examples
>>> # Basic usage with predefined policy
>>> AttentionModule = auto_remat(AttentionModule, policy='dots_saveable')
>>>
>>> # Multiple modules
>>> AttentionModule, MLPModule = auto_remat(
...     AttentionModule, MLPModule,
...     policy='nothing_saveable'
... )
>>>
>>> # Custom policy saving only specific checkpoints
>>> model = auto_remat(
...     model,
...     policy='save_only_these_names',
...     save_names=['attn_output', 'mlp_output', 'residual']
... )
>>>
>>> # Using transformer-optimized policy
>>> policy = create_transformer_checkpoint_policy(
...     save_attention=True,
...     save_mlp=False  # Recompute MLP to save memory
... )
>>> model = auto_remat(model, policy=policy)
- easydel.infra.utils.block_wise_ffn(remat_ffn: Callable, inputs: Array, chunk_size: int) Array[source]#
Apply a feed-forward network block-wise to reduce memory usage.
Implements the block-wise feed-forward approach from the near-infinite context length paper. This technique processes the FFN in chunks along the sequence dimension to reduce peak memory usage during training.
- Parameters
remat_ffn – The feed-forward network function to apply. Should be rematerialized (checkpointed) for memory efficiency.
inputs – Input tensor with shape (batch_size, sequence_length, hidden_dim).
chunk_size – Size of chunks to process. Sequence length must be divisible by chunk_size.
- Returns
Output tensor with same shape as inputs.
- Raises
EasyDeLBlockWiseFFNError – If inputs have wrong shape or chunk_size doesn’t divide sequence length evenly.
Note
For generation (sequence_length=1), applies FFN directly without chunking
For training, processes sequence in chunks to reduce memory
Requires sequence_length to be divisible by chunk_size
Example
>>> ffn = lambda x: mlp(x)  # Your FFN function
>>> chunked_output = block_wise_ffn(ffn, inputs, chunk_size=256)
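Chunking is safe because a feed-forward block acts on each token independently, so processing the sequence in slices gives the same result as one full pass. The sketch below shows this with nested Python lists instead of JAX arrays; `block_wise_apply` and `toy_ffn` are hypothetical, and ValueError stands in for the library's own error type.

```python
def block_wise_apply(ffn, inputs, chunk_size):
    """Apply ffn to inputs (batch x seq x hidden nested lists) chunk by chunk."""
    outputs = []
    for sequence in inputs:
        seq_len = len(sequence)
        if seq_len == 1:
            # Generation step: a single token, nothing to chunk.
            outputs.append(ffn(sequence))
            continue
        if seq_len % chunk_size != 0:
            raise ValueError("sequence length must be divisible by chunk_size")
        processed = []
        for start in range(0, seq_len, chunk_size):
            # Only one chunk of intermediate activations is live at a time.
            processed.extend(ffn(sequence[start:start + chunk_size]))
        outputs.append(processed)
    return outputs


def toy_ffn(tokens):
    """Position-wise toy FFN: doubles every feature of every token."""
    return [[2.0 * value for value in token] for token in tokens]


# batch=1, sequence_length=8, hidden_dim=2
inputs = [[[float(i), float(i + 1)] for i in range(8)]]
chunked = block_wise_apply(toy_ffn, inputs, chunk_size=4)
assert chunked == [toy_ffn(inputs[0])]  # identical to the unchunked call
```

The memory saving comes from the intermediate activations inside the FFN, which are sized by the chunk rather than by the full sequence.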
- easydel.infra.utils.canonicalize_dtype(*args, dtype: numpy.dtype | None = None, inexact: bool = True) dtype[source]#
Canonicalize an optional dtype to the definitive dtype.
Infers or validates the dtype for JAX operations. If dtype is None, infers from input arguments. Otherwise validates and returns the specified dtype.
- Parameters
*args – JAX array compatible values (None values ignored).
dtype – Optional dtype override. If specified, arguments are cast to this dtype and inference is disabled.
inexact – When True, the output dtype must be a subdtype of jnp.inexact. Inexact dtypes are real or complex floating points. This is useful when you want to apply operations that don’t work directly on integers, such as taking a mean.
- Returns
The dtype that *args should be cast to.
- easydel.infra.utils.create_transformer_checkpoint_policy(save_attention: bool = True, save_mlp: bool = True, save_residuals: bool = True, save_layer_outputs: bool = False, save_embeddings: bool = False, custom_names: list[str] | None = None) Callable[source]#
Create a checkpoint policy optimized for transformer models.
Creates a custom checkpoint policy that selectively saves transformer components based on the checkpoint_name calls we’ve added to all models.
- Parameters
save_attention – Whether to save attention outputs (attn_query, attn_key, attn_value, attn_output)
save_mlp – Whether to save MLP outputs (mlp_gate, mlp_up, mlp_down, mlp_output)
save_residuals – Whether to save residual connections
save_layer_outputs – Whether to save layer outputs
save_embeddings – Whether to save embeddings and model outputs
custom_names – Additional checkpoint names to save
- Returns
JAX checkpoint policy function
Example
>>> # Save only critical transformer components
>>> policy = create_transformer_checkpoint_policy(
...     save_attention=True,
...     save_mlp=False,  # Recompute MLP
...     save_residuals=True
... )
>>> model = auto_remat(model, policy=policy)
- easydel.infra.utils.extract_static_parameters(module)[source]#
Extract static_argnums for specified parameters across functions in a module.
- Parameters
module (types.ModuleType) – The module to inspect
- Returns
A dictionary mapping function names to their static parameter indices
- Return type
dict
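One way to derive such a mapping is to inspect function signatures and record the positions of parameters with known names. The sketch below is an assumption about the approach, not the library's implementation; `STATIC_NAMES`, `extract_static_indices`, and the throwaway `demo` module are all hypothetical.

```python
import inspect
import types

# Hypothetical set of parameter names to treat as static.
STATIC_NAMES = ("chunk_size", "causal")


def extract_static_indices(module, static_names=STATIC_NAMES):
    """Map each function in module to the positions of its static parameters."""
    result = {}
    for name, fn in inspect.getmembers(module, inspect.isfunction):
        params = list(inspect.signature(fn).parameters)
        indices = tuple(i for i, p in enumerate(params) if p in static_names)
        if indices:
            result[name] = indices
    return result


# Build a throwaway module so the example is self-contained.
demo = types.ModuleType("demo")
exec(
    "def attend(q, k, v, causal): return q\n"
    "def ffn(x, chunk_size): return x\n"
    "def plain(x): return x\n",
    demo.__dict__,
)
print(extract_static_indices(demo))
```

The resulting index tuples have the shape that a static_argnums argument expects.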
- easydel.infra.utils.flop_activation(activation_type: ActivationType, dim: int) float[source]#
Calculate FLOPs for different activation functions.
- easydel.infra.utils.flop_attention(hidden_dim: int, num_heads: int, num_kv_heads: int, head_dim: int, seq_len: int) float[source]#
- easydel.infra.utils.flop_cross_attention(hidden_dim: int, num_heads: int, enc_seq_len: int, dec_seq_len: int) float[source]#
- easydel.infra.utils.flop_mlp(cfg: FlopCalcConfig, hidden_dim: int, intermediate_dim: int) float[source]#
- easydel.infra.utils.flop_seq2seq(cfg: FlopCalcConfig) float[source]#
- easydel.infra.utils.flop_transformer_body(layers: int, seq_len: int, hidden_dim: int, intermediate_dim: int, cfg: FlopCalcConfig) float[source]#
- easydel.infra.utils.flop_vision_tower(cfg: FlopCalcConfig) float[source]#
- easydel.infra.utils.flops_per_token(cfg: FlopCalcConfig) float[source]#
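The FLOP helpers above are listed without descriptions. As a rough guide to the kind of arithmetic involved, the sketch below uses the common estimate that an (m x k) by (k x n) matrix product costs about 2*m*k*n FLOPs. It is a back-of-the-envelope model under stated assumptions, not EasyDeL's formulas: it covers only the forward pass and ignores softmax, normalization, activation cost, the LM head, and masking. All names in it are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class FlopSketchConfig:
    """Subset of the FlopCalcConfig fields, reused for a rough estimate."""

    hidden_dim: int
    intermediate_dim: int
    num_layers: int
    num_heads: int
    kv_heads: int
    head_dim: int
    seq_len: int
    glu: bool = False


def matmul_flops(m: int, k: int, n: int) -> float:
    """An (m x k) @ (k x n) product costs about 2*m*k*n FLOPs."""
    return 2.0 * m * k * n


def attention_flops(cfg: FlopSketchConfig) -> float:
    s, h = cfg.seq_len, cfg.hidden_dim
    q_dim = cfg.num_heads * cfg.head_dim
    kv_dim = cfg.kv_heads * cfg.head_dim  # smaller than q_dim under GQA/MQA
    projections = matmul_flops(s, h, q_dim) + 2 * matmul_flops(s, h, kv_dim)
    scores = matmul_flops(s, q_dim, s)    # Q @ K^T summed over heads
    weighted = matmul_flops(s, s, q_dim)  # attention weights @ V over heads
    output = matmul_flops(s, q_dim, h)
    return projections + scores + weighted + output


def mlp_flops(cfg: FlopSketchConfig) -> float:
    # A GLU-style FFN has three projections (gate, up, down) instead of two.
    n_matmuls = 3 if cfg.glu else 2
    return n_matmuls * matmul_flops(cfg.seq_len, cfg.hidden_dim, cfg.intermediate_dim)


def forward_flops_per_token(cfg: FlopSketchConfig) -> float:
    per_layer = attention_flops(cfg) + mlp_flops(cfg)
    return cfg.num_layers * per_layer / cfg.seq_len


cfg = FlopSketchConfig(
    hidden_dim=768, intermediate_dim=3072, num_layers=12,
    num_heads=12, kv_heads=12, head_dim=64, seq_len=1024,
)
print(f"{forward_flops_per_token(cfg):.3e} forward FLOPs per token")
```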
- easydel.infra.utils.get_dot_general_by_bits(bits: int | None = None, mode: Literal['train', 'serve', 'convert'] = 'train') dict[source]#
The get_dot_general_by_bits function is a helper that returns a q_flax.QDotGeneral object with the specified number of bits for the forward and backward passes. If no bits are specified, the function returns None.
- Parameters
bits – Number of bits to use for quantization (optional).
mode – Mode the model is used in, which determines how the QDot method is initialized (e.g., TRAIN, SERVE, …).
- Returns
A dict that contains dot_general_cls.
- easydel.infra.utils.get_gradient_checkpoint_policy(name: str | easydel.infra.etils.EasyDeLGradientCheckPointers, save_names: list[str] | None = None, exclude_names: list[str] | None = None) Callable[source]#
Get a gradient checkpointing policy by name or create a custom one.
Retrieves a JAX gradient checkpointing policy function that determines which intermediate values to save during forward pass for use in backward pass. This is used to trade compute for memory in gradient calculations.
- Parameters
name – Name of the checkpointing policy or EasyDeLGradientCheckPointers enum. Supported values:
- ‘everything_saveable’: Save all intermediate values
- ‘nothing_saveable’: Save no intermediate values (maximum recomputation)
- ‘dots_saveable’: Save dot product results
- ‘checkpoint_dots’: Checkpoint dot operations
- ‘dots_with_no_batch_dims_saveable’: Save dots without batch dimensions
- ‘checkpoint_dots_with_no_batch_dims’: Checkpoint dots without batch dims
- ‘save_anything_except_these_names’: Save all except specified names
- ‘save_any_names_but_these’: Save any names except specified
- ‘save_only_these_names’: Save only specified names
- ‘save_from_both_policies’: Combine two policies
save_names – List of checkpoint names to save (used with ‘save_only_these_names’)
exclude_names – List of checkpoint names to exclude (used with ‘save_anything_except_these_names’)
- Returns
The corresponding JAX checkpoint policy function.
- Raises
KeyError – If the policy name is not recognized.
ValueError – If save_names or exclude_names are not provided when required.
Example
>>> # Basic policy
>>> policy = get_gradient_checkpoint_policy('dots_saveable')
>>>
>>> # Custom policy saving only specific checkpoints
>>> policy = get_gradient_checkpoint_policy(
...     'save_only_these_names',
...     save_names=['attn_output', 'mlp_output']
... )
- easydel.infra.utils.is_flatten(pytree: dict)[source]#
Checks whether the pytree is flattened.
If it is, the first key in the dictionary is a tuple of (mpl, mpl_id). Otherwise, it is an integer representing mpl_id.
- Parameters
pytree – The pytree (a dict) to check.
- Returns
True if the pytree is a flattened tree, and False otherwise.
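A minimal sketch of such a check, assuming that inspecting the type of the first key is sufficient, could look like this. The name `is_flatten_sketch` and the example trees are hypothetical, and an empty dict is not handled.

```python
def is_flatten_sketch(pytree: dict) -> bool:
    """Return True when the first key of pytree is a tuple path."""
    first_key = next(iter(pytree))
    return isinstance(first_key, tuple)


nested = {"model": {"embed": {"kernel": 1.0}}}
flat = {("model", "embed", "kernel"): 1.0}
print(is_flatten_sketch(nested), is_flatten_sketch(flat))
```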
- easydel.infra.utils.merge_lora_params(model: Module, lora_tree: dict) Module[source]#
Get LoRA (Low-Rank Adaptation) from layers within a model.
- Parameters
model – The EasyDeL model.
- Returns
LoRA Layer Weights.
- easydel.infra.utils.quantize_linear_layers(model: Module, /, *, quantization_config: easydel.layers.quantization.quantizers.EasyDeLQuantizationConfig | None = None, verbose: bool = True) Module[source]#
Quantize parameters to requested precision, excluding specified layers.
- Parameters
model – The model to quantize.
quantization_config – Quantization config specifying dtype, block_size, and pattern.
verbose – Whether to use tqdm for logging.
- Returns
Quantized parameters in the same structure as the input.
- easydel.infra.utils.quick_gelu(x)[source]#
Quick GELU activation function.
A faster approximation of GELU using sigmoid.
- Parameters
x – Input array.
- Returns
Activated array.
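Quick GELU is commonly defined as x * sigmoid(1.702 * x). Assuming this function follows that form, the scalar sketch below compares it with the exact erf-based GELU. Both helper names are hypothetical, and the real function operates on arrays rather than Python floats.

```python
import math


def quick_gelu_scalar(x: float) -> float:
    """Scalar quick GELU: x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))


def exact_gelu_scalar(x: float) -> float:
    """Scalar exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    quick = quick_gelu_scalar(x)
    exact = exact_gelu_scalar(x)
    print(f"x={x:+.1f}  quick={quick:+.4f}  exact={exact:+.4f}")
```

The sigmoid form avoids evaluating erf, which is where the speed advantage comes from, at the cost of a small approximation error.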