# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions and helpers for EasyDeL infrastructure.
Provides common utilities used throughout the EasyDeL framework, including
activation functions, dtype handling, module manipulation, and various
helper functions for model operations.
Constants:
ACT2FN: Dictionary mapping activation names to functions
ROPE_TYPES: Supported RoPE (Rotary Position Embedding) types
Functions:
quick_gelu: Quick GELU activation function
canonicalize_dtype: Canonicalize dtype for JAX arrays
get_activation: Get activation function by name
quantize_linear: Apply quantization to linear layers
replace_dot: Replace JAX dot operations
Key Features:
- Activation function registry
- Data type canonicalization
- Module quantization utilities
- Sharding constraint helpers
- Memory optimization tools
Example:
>>> from easydel.infra.utils import ACT2FN, canonicalize_dtype
>>> # Get activation function
>>> activation = ACT2FN["gelu"]
>>> # Canonicalize dtype
>>> dtype = canonicalize_dtype(array, dtype=jnp.float32)
"""
from __future__ import annotations
import functools
import inspect
import re
import types
import typing as tp
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, partial
import jax
import jax.extend
import jax.tree_util
import numpy as np
from eformer.escale import with_sharding_constraint
from eformer.loggings import get_logger
from eformer.pytree import auto_pytree
from einops import rearrange
from flax import nnx as nn
from jaxtyping import Array, DTypeLike, PRNGKeyArray
from tqdm.auto import tqdm
from easydel.layers.linear import ParallelLinear
from easydel.layers.quantization import EasyDeLQuantizationConfig, EasyQuantizer
from easydel.utils.compiling_utils import hash_fn
from easydel.utils.traversals import flatten_dict, unflatten_dict
from .base_config import EasyMethod
from .errors import EasyDeLBlockWiseFFNError
from .etils import AVAILABLE_SPARSE_MODULE_TYPES, EasyDeLGradientCheckPointers
# Silence one specific warning. The `message` filter is a regular expression
# matched against the start of the warning text, so only warnings beginning with
# "Primitive dynamic_update_slice was not handled by class" are ignored.
warnings.filterwarnings(
"ignore",
message="Primitive dynamic_update_slice was not handled by class",
)
# Module-level logger, named after this module.
logger = get_logger(__name__)
def quick_gelu(x):
    """Apply the sigmoid-based "quick" GELU approximation.

    Computes ``x * sigmoid(1.702 * x)`` element-wise.

    Args:
        x: Input array.

    Returns:
        Array with the activation applied.
    """
    gate = jax.nn.sigmoid(1.702 * x)
    return x * gate
# Lookup table from activation-name strings to callables.
# Several names are aliases of the same implementation:
#   - "silu" and "swish" both map to `nn.swish`;
#   - "gelu_new" and "gelu_pytorch_tanh" both use `nn.gelu` with `approximate=True`,
#     whereas plain "gelu" uses `approximate=False`.
ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
"gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
"tanh": nn.tanh,
"sigmoid": nn.sigmoid,
"leaky_relu": partial(nn.leaky_relu, negative_slope=0.01),
"glu": nn.glu,
"elu": nn.elu,
"softmax": nn.softmax,
"quick_gelu": quick_gelu,
}
"""Registry of activation functions by name.
Maps activation function names to their implementations.
Supports common activations used in neural networks.
"""
# Type alias listing the accepted RoPE type names; `None` is also permitted
# because the literal is wrapped in `Optional`.
ROPE_TYPES = tp.Optional[tp.Literal["none", "linear", "dynamic", "yarn", "su", "llama3", "longrope"]] # noqa
# Self-assignment of the imported helper; presumably kept so the name is an
# explicit attribute of this module for re-export — confirm against callers.
with_sharding_constraint = with_sharding_constraint
[docs]def canonicalize_dtype(
*args,
dtype: jax.numpy.dtype | None = None,
inexact: bool = True,
) -> jax.numpy.dtype:
"""Canonicalize an optional dtype to the definitive dtype.
Infers or validates the dtype for JAX operations. If dtype is None,
infers from input arguments. Otherwise validates and returns the
specified dtype.
Args:
*args: JAX array compatible values (None values ignored).
dtype: Optional dtype override. If specified, arguments are
cast to this dtype and inference is disabled.
inexact: When True, the output dtype must be a subdtype
of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
is useful when you want to apply operations that don'position_ids work directly on
integers like taking a mean for example.
Returns:
The dtype that *args should be cast to.
"""
if dtype is None:
args_filtered = [jax.numpy.asarray(x) for x in args if x is not None]
dtype = jax.numpy.result_type(*args_filtered)
if inexact and not jax.numpy.issubdtype(dtype, jax.numpy.inexact):
dtype = jax.numpy.promote_types(jax.numpy.float32, dtype)
if inexact and not jax.numpy.issubdtype(dtype, jax.numpy.inexact):
raise ValueError(f"Dtype must be inexact: {dtype}")
return dtype
def get_gradient_checkpoint_policy(
    name: str | EasyDeLGradientCheckPointers,
    save_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> tp.Callable:
    """Resolve a checkpoint-policy name to an object from ``jax.checkpoint_policies``.

    Args:
        name: Policy name, or an ``EasyDeLGradientCheckPointers`` member whose
            ``.value`` is used as the name. Recognised names:

            - ``'save_only_these_names'``: built from ``save_names``.
            - ``'save_anything_except_these_names'`` /
              ``'save_any_names_but_these'``: built from ``exclude_names``
              via ``jax.checkpoint_policies.save_any_names_but_these``.
            - ``'everything_saveable'``, ``'nothing_saveable'``,
              ``'dots_saveable'``, ``'checkpoint_dots'``,
              ``'dots_with_no_batch_dims_saveable'``,
              ``'checkpoint_dots_with_no_batch_dims'``,
              ``'save_from_both_policies'``: the attribute of the same name on
              ``jax.checkpoint_policies`` is returned as-is.
        save_names: Checkpoint names forwarded to ``save_only_these_names``.
        exclude_names: Checkpoint names forwarded to ``save_any_names_but_these``.

    Returns:
        The matching object from ``jax.checkpoint_policies``.

    Raises:
        ValueError: If the chosen policy needs ``save_names`` or
            ``exclude_names`` and the corresponding argument is None.
        KeyError: If the name is not recognised.

    Example:
        >>> policy = get_gradient_checkpoint_policy('dots_saveable')
        >>> policy = get_gradient_checkpoint_policy(
        ...     'save_only_these_names',
        ...     save_names=['attn_output', 'mlp_output'],
        ... )
    """
    policy_name = name.value if isinstance(name, EasyDeLGradientCheckPointers) else name

    # Policies parameterised by checkpoint names are constructed on demand.
    if policy_name == "save_only_these_names":
        if save_names is None:
            raise ValueError("save_names must be provided when using 'save_only_these_names' policy")
        return jax.checkpoint_policies.save_only_these_names(*save_names)
    if policy_name in ["save_anything_except_these_names", "save_any_names_but_these"]:
        if exclude_names is None:
            raise ValueError("exclude_names must be provided when using exclude-based policies")
        return jax.checkpoint_policies.save_any_names_but_these(*exclude_names)

    # Remaining policies are plain attributes of `jax.checkpoint_policies`.
    fixed_policy_names = (
        "everything_saveable",
        "nothing_saveable",
        "dots_saveable",
        "checkpoint_dots",
        "dots_with_no_batch_dims_saveable",
        "checkpoint_dots_with_no_batch_dims",
        "save_from_both_policies",
    )
    fixed_policies = {key: getattr(jax.checkpoint_policies, key) for key in fixed_policy_names}
    if policy_name in fixed_policies:
        return fixed_policies[policy_name]
    raise KeyError(f"Unknown checkpoint policy: {policy_name}")
def add_start_docstrings(*docstr):
    """Build a decorator that prepends text to a function's docstring.

    The given strings are concatenated (no separator) and placed in front of
    the decorated function's existing docstring; a missing docstring is
    treated as empty.

    Args:
        *docstr: String fragments to prepend, in order.

    Returns:
        A decorator that rewrites ``fn.__doc__`` in place and returns ``fn``.
    """

    def docstring_decorator(fn):
        existing = fn.__doc__
        if existing is None:
            existing = ""
        fn.__doc__ = "".join(docstr) + existing
        return fn

    return docstring_decorator
def get_dot_general_by_bits(
    bits: int | None = None,
    mode: tp.Literal["train", "serve", "convert"] = EasyMethod.TRAIN,
) -> dict:
    """Build keyword arguments selecting an AQT quantized ``dot_general``.

    Args:
        bits: Bit width used for both the forward and backward pass
            (``fwd_bits`` and ``bwd_bits``). When None, no quantization is
            configured.
        mode: The ``EasyMethod`` the model is used for. ``TRAIN`` maps to
            ``QuantMode.TRAIN``, ``EVAL`` and ``SERVE`` map to
            ``QuantMode.SERVE``, and ``CONVERT`` maps to ``QuantMode.CONVERT``.

    Returns:
        ``{"dot_general_cls": ...}`` holding a ``functools.partial`` of
        ``AqtDotGeneral`` when ``bits`` is given, otherwise an empty dict.

    Raises:
        ModuleNotFoundError: If ``bits`` is given but ``aqt`` is not installed.
        ValueError: If ``mode`` is not one of the handled ``EasyMethod`` values.
    """
    if bits is None:
        return {}  # empty just in case of not getting any error

    # `aqt` is an optional dependency, imported only when quantization is requested.
    try:
        from aqt.jax.v2 import config as q_config  # type: ignore
        from aqt.jax.v2.flax import aqt_flax as q_flax  # type: ignore
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "No module named `aqt` has been found, please install aqt before using bits option in EasyDeL"
        ) from e

    if mode == EasyMethod.TRAIN:
        rhs_quant_mode = q_flax.QuantMode.TRAIN
    elif mode == EasyMethod.EVAL or mode == EasyMethod.SERVE:
        rhs_quant_mode = q_flax.QuantMode.SERVE
    elif mode == EasyMethod.CONVERT:
        rhs_quant_mode = q_flax.QuantMode.CONVERT
    else:
        raise ValueError("Unknown Quant Method for EasyMethod")

    quant_cfg = q_config.fully_quantized(fwd_bits=bits, bwd_bits=bits)
    dot_general_cls = functools.partial(
        q_flax.AqtDotGeneral,
        cfg=quant_cfg,
        rhs_quant_mode=rhs_quant_mode,
    )
    return {"dot_general_cls": dot_general_cls}
def block_wise_ffn(remat_ffn: tp.Callable, inputs: jax.Array, chunk_size: int) -> jax.Array:
    """Apply a feed-forward network block-wise along the sequence axis.

    When the second axis of ``inputs`` has length 1 (single-token generation)
    the FFN is applied directly. Otherwise the second axis is split with
    ``rearrange(inputs, "b (c n) d -> b c n d", c=chunk_size)`` and
    ``remat_ffn`` is applied to each of the ``chunk_size`` groups in turn via a
    fully unrolled ``jax.lax.scan``; the result is merged back to the input
    layout.

    Args:
        remat_ffn: The feed-forward callable to apply to each block.
        inputs: Array indexed as ``(batch, sequence, hidden)``.
        chunk_size: Value bound to the ``c`` axis of the rearrange pattern,
            i.e. the number of groups the sequence axis is split into. The
            sequence length must be divisible by it.

    Returns:
        Array in the same ``(batch, sequence, hidden)`` layout as ``inputs``.

    Raises:
        EasyDeLBlockWiseFFNError: If anything fails while inspecting the input
            shape, rearranging, or applying ``remat_ffn``. The original
            exception is chained as ``__cause__``.

    Example:
        >>> ffn = lambda x: mlp(x)  # Your FFN function
        >>> chunked_output = block_wise_ffn(ffn, inputs, chunk_size=256)
    """
    try:
        # The shape probe lives inside the `try` so that inputs of the wrong
        # rank surface as the documented error instead of a raw IndexError.
        if inputs.shape[1] == 1:
            return remat_ffn(inputs)

        chunked = rearrange(inputs, "b (c n) d -> b c n d", c=chunk_size)

        def _apply_to_chunk(carry, idx):
            # Overwrite group `idx` of the carried tensor with its FFN output.
            return carry.at[:, idx].set(remat_ffn(carry[:, idx])), None

        processed, _ = jax.lax.scan(
            f=_apply_to_chunk,
            init=chunked,
            xs=jax.numpy.arange(chunk_size),
            length=chunk_size,
            unroll=True,
        )
        return rearrange(processed, "b c n d -> b (c n) d")
    except Exception as e:
        raise EasyDeLBlockWiseFFNError(
            "You Are using BlockWise FFN from near-infinite-context length paper and you might be passing "
            "input arguments in wrong way in case that you don't want to use this just pass "
            "`use_scan_mlp=False` in "
            "model config or in config_kwargs in AutoEasyDeLModelFor... or change `scan_mlp_chunk_size` "
            f"in configs for more information read Docs.\nOriginal Error\n{e}"
        ) from e
def is_flatten(pytree: dict):
    """Report whether ``pytree`` looks like a flattened dictionary.

    Only the first key is inspected: a flattened dictionary is keyed by tuples
    (one tuple per leaf path), whereas a nested one is not.

    Args:
        pytree: The mapping to inspect.

    Returns:
        True if the first key is a tuple, otherwise False. An empty mapping has
        no keys to inspect and is reported as not flattened.
    """
    # `next(..., None)` avoids StopIteration on an empty mapping.
    first_key = next(iter(pytree.keys()), None)
    return isinstance(first_key, tuple)
def quantize_linear_layers(
    model: nn.Module,
    /,
    *,
    quantization_config: EasyDeLQuantizationConfig | None = None,
    verbose: bool = True,
) -> nn.Module:
    """Quantize the linear layers of ``model`` according to a quantization config.

    Args:
        model: The model to quantize.
        quantization_config: Configuration handed to ``EasyQuantizer``. When
            None, the model is returned untouched.
        verbose: Forwarded to ``EasyQuantizer.quantize_linears``.

    Returns:
        The result of ``EasyQuantizer.quantize_linears``, or ``model`` itself
        when no configuration is given.
    """
    if quantization_config is not None:
        quantizer = EasyQuantizer(quantization_config=quantization_config)
        return quantizer.quantize_linears(model, verbose=verbose)
    return model
def apply_lora_to_layers(
    model: nn.Module,
    /,
    *,
    lora_rank: int,
    lora_pattern: str | None = None,
    verbose: bool = True,
    rngs: nn.Rngs | None = None,
) -> nn.Module:
    """Wrap matching ``ParallelLinear`` layers of ``model`` in ``nn.LoRA``.

    Every ``ParallelLinear`` found by ``iter_module_search`` whose dotted path
    matches ``lora_pattern`` is replaced, in place, by an ``nn.LoRA`` module
    that keeps the original layer as its ``base_module``.

    Args:
        model: The model to modify (mutated in place).
        lora_rank: Rank of the LoRA adapters; must be greater than 0.
        lora_pattern: Regular expression searched (``re.search``) in the
            dot-joined module path. Defaults to ``".*"``, i.e. every
            ``ParallelLinear``.
        verbose: Whether to display a progress bar.
        rngs: ``flax.nnx.Rngs`` used to initialise the adapters. Defaults to
            ``nn.Rngs(0)``.

    Returns:
        The same ``model`` object, with LoRA applied to the matching layers.

    Raises:
        ValueError: If ``lora_rank`` is not greater than 0.
    """
    from easydel.utils.traversals import get_module_from_path, iter_module_search, set_module_from_path

    if not (lora_rank > 0):
        raise ValueError("lora_rank should be a positive value and higher than `0`.")
    if lora_pattern is None:
        lora_pattern = ".*"
    if rngs is None:
        rngs = nn.Rngs(0)
    pattern = re.compile(lora_pattern)

    # Collect the paths once, up front: this avoids a second full traversal just
    # to size the progress bar, and avoids replacing modules while the search
    # iterator is still walking the model.
    linear_paths = [path for path, _ in iter_module_search(model, ParallelLinear)]

    with tqdm(total=len(linear_paths), desc="Applying LoRA", disable=not verbose) as pbar:
        for path in linear_paths:
            if pattern.search(".".join(str(p) for p in path)):
                base_module: ParallelLinear = get_module_from_path(model=model, path=path)
                set_module_from_path(
                    model=model,
                    path=path,
                    new_value=nn.LoRA(
                        base_module=base_module,
                        rngs=rngs,
                        dtype=base_module.dtype,
                        param_dtype=base_module.param_dtype,
                        in_features=base_module.in_features,
                        lora_rank=lora_rank,
                        out_features=base_module.out_features,
                    ),
                )
            # The bar tracks visited layers, whether or not they matched.
            pbar.update(1)
    return model
def split_lora_params(model: nn.Module, verbose: bool = True) -> dict:
    """Collect the LoRA adapter parameters of every ``nn.LoRA`` layer in ``model``.

    Args:
        model: The model to read from (not modified).
        verbose: Whether to display a progress bar.

    Returns:
        A nested dictionary (built with ``unflatten_dict``) that mirrors the
        module paths of the ``nn.LoRA`` layers; each layer entry holds its
        ``"lora_a"`` and ``"lora_b"`` attributes.
    """
    from easydel.utils.traversals import get_module_from_path, iter_module_search

    # Single traversal: the collected paths both size the bar and drive the loop.
    lora_paths = [path for path, _ in iter_module_search(model, nn.LoRA)]

    collected = {}
    with tqdm(total=len(lora_paths), desc="Split LoRA Params", disable=not verbose) as pbar:
        for path in lora_paths:
            lora_module: nn.LoRA = get_module_from_path(model=model, path=path)
            collected[path] = {"lora_a": lora_module.lora_a, "lora_b": lora_module.lora_b}
            pbar.update(1)
    return unflatten_dict(collected)
def merge_lora_params(model: nn.Module, lora_tree: dict, verbose: bool = True) -> nn.Module:
    """Load LoRA adapter parameters from ``lora_tree`` into the ``nn.LoRA`` layers of ``model``.

    For every ``nn.LoRA`` layer found in ``model``, the entries stored under
    ``(*path, "lora_a")`` and ``(*path, "lora_b")`` are assigned to the layer's
    ``lora_a`` and ``lora_b`` attributes.

    Args:
        model: The model whose LoRA layers are updated in place.
        lora_tree: Adapter parameters keyed by module path, either nested or
            already flattened to tuple keys. A nested tree is flattened with
            ``flatten_dict`` first.
        verbose: Whether to display a progress bar.

    Returns:
        The same ``model`` object with its LoRA parameters replaced.

    Raises:
        KeyError: If ``lora_tree`` lacks an entry for one of the model's LoRA
            layers.
    """
    from easydel.utils.traversals import get_module_from_path, iter_module_search

    # An empty tree has no first key to inspect, so skip the flatness probe.
    if lora_tree and not is_flatten(lora_tree):
        lora_tree = flatten_dict(lora_tree)

    lora_paths = [path for path, _ in iter_module_search(model, nn.LoRA)]
    with tqdm(total=len(lora_paths), desc="Merge LoRA Params", disable=not verbose) as pbar:
        for path in lora_paths:
            lora_module: nn.LoRA = get_module_from_path(model=model, path=path)
            lora_module.lora_b = lora_tree[(*path, "lora_b")]
            lora_module.lora_a = lora_tree[(*path, "lora_a")]
            pbar.update(1)
    return model
def unwrap_lora_to_layers(
    model: nn.Module,
    /,
    *,
    verbose: bool = True,
) -> nn.Module:
    """Fold LoRA adapters into their base layers and remove the ``nn.LoRA`` wrappers.

    For every ``nn.LoRA`` layer, ``lora_a @ lora_b`` is added to the base
    module's kernel (computed under ``float32`` matmul precision), the adapter
    parameters are deleted, and the wrapper is replaced in the model by its
    ``base_module``.

    Args:
        model: The model to modify (mutated in place).
        verbose: Whether to display a progress bar.

    Returns:
        The same ``model`` object with all LoRA wrappers removed.
    """
    from easydel.utils.traversals import get_module_from_path, iter_module_search, set_module_from_path

    # The bar is sized by the modules actually processed (nn.LoRA layers). The
    # paths are collected up front so wrappers are not swapped out while the
    # search iterator is still walking the model.
    lora_paths = [path for path, _ in iter_module_search(model, nn.LoRA)]

    with tqdm(total=len(lora_paths), desc="Unwarping LoRA Layers", disable=not verbose) as pbar:
        for path in lora_paths:
            lora_module: nn.LoRA = get_module_from_path(model=model, path=path)
            wrapped = lora_module.base_module
            with jax.default_matmul_precision("float32"):
                wrapped.kernel.value = wrapped.kernel.value + lora_module.lora_a.value @ lora_module.lora_b.value
            del lora_module.lora_a, lora_module.lora_b
            set_module_from_path(
                model=model,
                path=path,
                new_value=wrapped,
            )
            pbar.update(1)
    return model
def apply_sparsity_to_params(
    params: dict[str, tp.Any] | tp.Any,
    sparsify_module: AVAILABLE_SPARSE_MODULE_TYPES = "bcoo",
    verbose: bool = True,
) -> dict[str, tp.Any] | tp.Any:
    """Convert kernel arrays in a parameter tree to a JAX sparse format.

    The tree is flattened if needed. Every leaf whose dotted key path ends
    with ``"kernel"`` and whose ``ndim`` is 2 or 3 is converted with the chosen
    sparse class's ``fromdense``; all other leaves are left untouched.

    Args:
        params: Parameter dictionary, nested or flattened (tuple keys).
        sparsify_module: Sparse format name: ``"bcoo"``, ``"bcsr"``, ``"coo"``
            or ``"csr"`` (classes from ``jax.experimental.sparse``).
        verbose: Whether to display a progress bar.

    Returns:
        The parameters in the same (nested or flattened) layout they were
        given in, with matching kernels replaced by sparse arrays.

    Raises:
        AssertionError: If ``sparsify_module`` is not one of the known names.
    """
    # An empty tree has no first key to inspect; treat it as nested.
    flatten = bool(params) and is_flatten(params)
    if not flatten:
        params = flatten_dict(params)
    from jax.experimental import sparse

    sparser = {
        "bcoo": sparse.BCOO,
        "bcsr": sparse.BCSR,
        "coo": sparse.COO,
        "csr": sparse.CSR,
    }.get(sparsify_module, None)
    assert sparser is not None, f"unkown type of sparser {sparsify_module}"

    def _layer_name(path) -> str:
        # Key components are stringified so integer path parts (e.g. layer
        # indices) do not break the join.
        return ".".join(str(part) for part in path[0].key)

    def filter_params(layer_name, array):
        if layer_name.endswith("kernel") and 4 > array.ndim > 1:
            array = sparser.fromdense(array)
        return array

    total_params = len(jax.tree_util.tree_leaves(params))
    with tqdm(
        total=total_params,
        desc=f"{sparsify_module.capitalize()}",
        disable=not verbose,
    ) as pbar:

        def _with_progress(path, array):
            layer_name = _layer_name(path)
            pbar.set_postfix_str(layer_name)
            result = filter_params(layer_name, array)
            pbar.update(1)
            return result

        params = jax.tree_util.tree_map_with_path(_with_progress, params)

    if not flatten:
        params = unflatten_dict(params)
    return params
M = tp.TypeVar("M", bound=nn.Module)


@tp.overload
def auto_remat(
    module: type[M],
    /,
    *,
    policy: EasyDeLGradientCheckPointers | str | tp.Callable = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
    save_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> type[M]: ...


@tp.overload
def auto_remat(
    module1: type[M],
    module2: type[M],
    /,
    *,
    policy: EasyDeLGradientCheckPointers | str | tp.Callable = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
    save_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> tuple[type[M], type[M]]: ...


@tp.overload
def auto_remat(
    *modules: type[M],
    policy: EasyDeLGradientCheckPointers | str | tp.Callable = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
    save_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> tuple[type[M], ...]: ...


def auto_remat(
    *modules: type[M],
    policy: EasyDeLGradientCheckPointers | str | tp.Callable = EasyDeLGradientCheckPointers.NONE,
    prevent_cse: bool = True,
    save_names: list[str] | None = None,
    exclude_names: list[str] | None = None,
) -> type[M] | tuple[type[M], ...]:
    """Wrap the ``__call__`` of module classes with ``nn.remat`` (gradient checkpointing).

    Each class's ``__call__`` attribute is rebound, on the class itself, to
    ``nn.remat(module.__call__, ...)``; the same class objects are returned.
    When the policy is ``EasyDeLGradientCheckPointers.NONE`` or the empty
    string, the classes are returned unchanged.

    Args:
        *modules: One or more ``nn.Module`` subclasses to wrap.
        policy: Checkpointing policy. Either an ``EasyDeLGradientCheckPointers``
            member or a policy name string (both resolved through
            ``get_gradient_checkpoint_policy``), or an already-built callable
            policy that is passed to ``nn.remat`` directly.
        prevent_cse: Forwarded to ``nn.remat``.
        save_names: Checkpoint names for the ``'save_only_these_names'`` policy.
        exclude_names: Checkpoint names for the exclude-based policies.

    Returns:
        The single class when exactly one was passed, otherwise a tuple of the
        classes in the order given.

    Raises:
        ValueError: If ``policy`` is neither a string, an enum member, nor callable.
        AssertionError: If one of ``modules`` is not an ``nn.Module`` subclass.

    Examples:
        >>> AttentionModule = auto_remat(AttentionModule, policy='dots_saveable')
        >>> AttentionModule, MLPModule = auto_remat(
        ...     AttentionModule, MLPModule,
        ...     policy='nothing_saveable',
        ... )
    """
    # Checkpointing disabled: hand the classes back untouched.
    if policy == EasyDeLGradientCheckPointers.NONE or policy == "":
        return modules[0] if len(modules) == 1 else modules

    if isinstance(policy, str | EasyDeLGradientCheckPointers):
        policy = get_gradient_checkpoint_policy(policy, save_names, exclude_names)
    elif not callable(policy):
        raise ValueError(f"Invalid policy type: {type(policy)}")

    wrapped = []
    for module in modules:
        assert issubclass(module, nn.Module)
        static_argnums = extract_static_parameters(module=module)
        if static_argnums is None:
            static_argnums = ()
        # Rebinding on the class affects every instance of that class.
        module.__call__ = nn.remat(
            f=module.__call__,
            prevent_cse=prevent_cse,
            static_argnums=static_argnums,
            policy=policy,
        )
        wrapped.append(module)

    if len(wrapped) == 1:
        return wrapped[0]
    return tuple(wrapped)
# Main FLOP counting function
def count_flop_jaxpr(jaxpr) -> int:
    """Estimate the number of floating-point operations in a jaxpr.

    Walks every equation of ``jaxpr`` and adds up a per-primitive estimate from
    an internal lookup table. Primitives missing from the table trigger a
    warning and contribute 0. Jaxprs nested in an equation's parameters are
    visited recursively, with these special cases:

    - ``scan``: the body count is multiplied by the scan's ``length`` parameter.
    - ``cond``: the counts of all branches are averaged (integer division).
    - ``custom_vjp_call_jaxpr``: the ``fun_jaxpr`` parameter is counted once.

    The per-primitive costs (e.g. 4 per element for ``exp``) are heuristic
    weights, so the result is an estimate rather than an exact count.

    Args:
        jaxpr: A jaxpr-like object exposing ``.eqns`` (both open and closed
            jaxprs work).

    Returns:
        The estimated FLOP count as a Python ``int``.
    """

    def get_shape_size(shape) -> int:
        """Return the number of elements described by ``shape`` (1 when empty)."""
        return int(np.prod(shape)) if shape else 1

    def no_flops(eqn) -> int:
        """Handler for primitives that move or relabel data without arithmetic."""
        return 0

    def compute_binary_op_flops(eqn) -> int:
        """One FLOP per element of the broadcast output of a binary op."""
        lhs_shape = eqn.invars[0].aval.shape
        rhs_shape = eqn.invars[1].aval.shape
        return get_shape_size(np.broadcast_shapes(lhs_shape, rhs_shape))

    def compute_unary_op_flops(eqn) -> int:
        """One FLOP per element of the (single) input."""
        return get_shape_size(eqn.invars[0].aval.shape)

    def scaled_unary(cost_per_element: int):
        """Build a unary handler that charges ``cost_per_element`` per element."""

        def handler(eqn) -> int:
            return cost_per_element * compute_unary_op_flops(eqn)

        return handler

    def compute_dot_general_flops(eqn) -> int:
        """Compute FLOPs for dot_general operation."""
        shapes = [var.aval.shape for var in eqn.invars]
        if len(shapes) != 2:
            return 0
        dimension_numbers = eqn.params.get("dimension_numbers", None)
        if not dimension_numbers:
            return 0
        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
        lhs_shape, rhs_shape = shapes
        # `get_shape_size` keeps everything integral even when a dimension
        # group is empty (a bare `np.prod([])` would yield the float 1.0).
        contracting_size = get_shape_size([lhs_shape[d] for d in lhs_contract])
        batch_size = get_shape_size([lhs_shape[d] for d in lhs_batch])
        lhs_free = [d for i, d in enumerate(lhs_shape) if i not in lhs_contract and i not in lhs_batch]
        rhs_free = [d for i, d in enumerate(rhs_shape) if i not in rhs_contract and i not in rhs_batch]
        out_size = batch_size * get_shape_size(lhs_free) * get_shape_size(rhs_free)
        # Each output element requires 2*contracting_size - 1 operations
        return out_size * (2 * contracting_size - 1)

    def compute_conv_flops(eqn) -> int:
        """Compute FLOPs for convolution operation."""
        lhs_shape = eqn.invars[0].aval.shape
        rhs_shape = eqn.invars[1].aval.shape
        dimension_numbers = eqn.params.get("dimension_numbers", None)
        if not dimension_numbers:
            return 0
        lhs_spec, rhs_spec, _out_spec = dimension_numbers
        num_spatial = len(lhs_spec) - 2
        if isinstance(lhs_spec[0], str):
            # Character layout: "N"/"C"/"O" plus digit-named spatial axes.
            batch_axis = lhs_spec.index("N")
            in_axis = lhs_spec.index("C")
            out_axis = rhs_spec.index("O")
            lhs_spatial = [lhs_spec.index(str(d)) for d in range(num_spatial)]
            rhs_spatial = [rhs_spec.index(str(d)) for d in range(num_spatial)]
        else:
            # Integer layout: lhs is (batch, feature, *spatial) and rhs is
            # (out feature, in feature, *spatial), each entry an axis index.
            batch_axis, in_axis = lhs_spec[0], lhs_spec[1]
            out_axis = rhs_spec[0]
            lhs_spatial = list(lhs_spec[2:])
            rhs_spatial = list(rhs_spec[2:])
        batch_size = lhs_shape[batch_axis]
        in_channels = lhs_shape[in_axis]
        out_channels = rhs_shape[out_axis]
        spatial_size = get_shape_size([lhs_shape[ax] for ax in lhs_spatial])
        kernel_size = get_shape_size([rhs_shape[ax] for ax in rhs_spatial])
        ops_per_point = 2 * kernel_size * in_channels - 1
        total_points = batch_size * spatial_size * out_channels
        return ops_per_point * total_points

    def compute_reduce_flops(eqn) -> int:
        """Compute FLOPs for reduction operations."""
        shape = eqn.invars[0].aval.shape
        reduced_axes = eqn.params.get("axes", ())
        if not reduced_axes:
            return 0
        reduced_size = get_shape_size([shape[ax] for ax in reduced_axes])
        remaining_size = get_shape_size([s for i, s in enumerate(shape) if i not in reduced_axes])
        return remaining_size * (reduced_size - 1)

    def compute_attention_flops(eqn) -> int:
        """Compute FLOPs for attention operation."""
        q_shape = eqn.invars[0].aval.shape
        k_shape = eqn.invars[1].aval.shape
        batch, q_len, num_heads, head_dim = q_shape
        _, kv_len, _, _ = k_shape
        qk_flops = batch * num_heads * q_len * kv_len * (2 * head_dim - 1)
        softmax_flops = batch * num_heads * q_len * (kv_len + (kv_len - 1) + 1)
        av_flops = batch * num_heads * q_len * head_dim * (2 * kv_len - 1)
        return qk_flops + softmax_flops + av_flops

    def count_scan_flops(eqn) -> int:
        """Count FLOPs in a scan operation: body FLOPs times trip count."""
        scan_jaxpr = eqn.params.get("jaxpr", None)
        if not scan_jaxpr:
            return 0
        # The trip count is the primitive's `length` parameter. The first input
        # may be a constant or a carry, so its leading dimension is only used
        # as a fallback when the parameter is absent.
        length = eqn.params.get("length", None)
        if length is None:
            length = eqn.invars[0].aval.shape[0]
        return count_flop_jaxpr(scan_jaxpr) * int(length)

    def count_cond_flops(eqn) -> int:
        """Count FLOPs in a conditional operation as the average over its branches."""
        branches = eqn.params.get("branches", None)
        if not branches:
            legacy = (eqn.params.get("true_jaxpr", None), eqn.params.get("false_jaxpr", None))
            found = [branch for branch in legacy if branch]
            if not found:
                return 0
            # Legacy parameter pair: always average over two branches.
            return sum(count_flop_jaxpr(branch) for branch in found) // 2
        return sum(count_flop_jaxpr(branch) for branch in branches) // len(branches)

    def get_scatter_flops(eqn) -> int:
        """Count FLOPs in a scatter operation."""
        updates_shape = eqn.invars[2].aval.shape
        return get_shape_size(updates_shape)

    def compute_select_n_flops(eqn) -> int:
        """Compute FLOPs for select_n operation."""
        pred_shape = eqn.invars[0].aval.shape
        return get_shape_size(pred_shape)

    def compute_cumsum_flops(eqn) -> int:
        """Compute FLOPs for cumulative sum."""
        shape = eqn.invars[0].aval.shape
        axis = eqn.params.get("axis", 0)
        total = get_shape_size(shape)
        axis_len = shape[axis]
        if axis_len == 0:
            return 0
        # Each independent lane along `axis` needs (axis_len - 1) additions.
        lanes = total // axis_len
        return total - lanes

    def compute_max_flops(eqn) -> int:
        """Compute FLOPs for max operation."""
        if len(eqn.invars) == 2:
            # Binary max
            return compute_binary_op_flops(eqn)
        # Unary max
        return compute_unary_op_flops(eqn)

    def compute_pow_flops(eqn) -> int:
        """Compute FLOPs for power operation."""
        if len(eqn.invars) == 2:
            return 8 * compute_binary_op_flops(eqn)  # Power is expensive
        return 8 * compute_unary_op_flops(eqn)

    def compute_integer_pow_flops(eqn) -> int:
        """Compute FLOPs for integer power."""
        shape = eqn.invars[0].aval.shape
        power = eqn.params.get("y", 2)
        # |power| - 1 multiplications, plus one division for negative
        # exponents; never negative.
        per_element = max(abs(power) - 1, 0) + (1 if power < 0 else 0)
        return per_element * get_shape_size(shape)

    def count_custom_vjp_flops(eqn) -> int:
        """Count FLOPs in custom VJP operation."""
        fwd_jaxpr = eqn.params.get("fun_jaxpr", None)
        if fwd_jaxpr:
            return count_flop_jaxpr(fwd_jaxpr)
        return 0

    def compute_argmax_flops(eqn) -> int:
        """Compute FLOPs for argmax operation."""
        shape = eqn.invars[0].aval.shape
        axis = eqn.params.get("axes", (0,))[0]
        # For each output element, we need to compare n-1 elements where n is the size of the reduction axis
        remaining_size = get_shape_size(shape) // shape[axis]
        return remaining_size * (shape[axis] - 1)

    def compute_triangular_solve_flops(eqn) -> int:
        """Compute FLOPs for triangular solve operation."""
        # Heuristic based only on the size of the first operand's last
        # dimension and its leading batch dimensions.
        matrix_shape = eqn.invars[0].aval.shape
        n = matrix_shape[-1]  # Size of the last dimension
        batch_size = get_shape_size(matrix_shape[:-2])
        return batch_size * n * (n + 1) * (2 * n + 1) // 6

    # Dictionary mapping primitives to their FLOP counting functions
    primitive_flops: dict[str, tp.Callable] = {
        # Binary operations
        "mul": compute_binary_op_flops,
        "add": compute_binary_op_flops,
        "sub": compute_binary_op_flops,
        "div": compute_binary_op_flops,
        "gt": compute_binary_op_flops,
        "lt": compute_binary_op_flops,
        "ge": compute_binary_op_flops,
        "le": compute_binary_op_flops,
        "ne": compute_binary_op_flops,
        "eq": compute_binary_op_flops,
        # Unary operations
        "neg": compute_unary_op_flops,
        "sin": scaled_unary(5),
        "cos": scaled_unary(5),
        "exp": scaled_unary(4),
        "log": scaled_unary(6),
        "log1p": scaled_unary(6),
        "tanh": scaled_unary(7),
        "rsqrt": scaled_unary(6),
        # Linear algebra
        "dot_general": compute_dot_general_flops,
        "conv_general_dilated": compute_conv_flops,
        # Reduction operations
        "reduce_sum": compute_reduce_flops,
        "reduce_max": compute_reduce_flops,
        "reduce_min": compute_reduce_flops,
        # Special operations
        "scatter-add": get_scatter_flops,
        "scan": count_scan_flops,
        "cond": count_cond_flops,
        # Memory operations (0 FLOPs)
        "broadcast_in_dim": no_flops,
        "reshape": no_flops,
        "transpose": no_flops,
        "slice": no_flops,
        "gather": no_flops,
        "concatenate": no_flops,
        "convert_element_type": no_flops,
        "dynamic_slice": no_flops,
        "pad": no_flops,
        # Parallel/Sharding operations (0 FLOPs themselves; nested jaxprs are
        # still visited by the generic recursion below)
        "pjit": no_flops,
        "shard_map": no_flops,
        "sharding_constraint": no_flops,
        # Other operations
        "dot_product_attention_fwd_wrapper": compute_attention_flops,
        "select_n": compute_select_n_flops,
        "cumsum": compute_cumsum_flops,
        "max": compute_max_flops,
        "iota": no_flops,  # Memory operation, no FLOPs
        "pow": compute_pow_flops,
        "integer_pow": compute_integer_pow_flops,
        "and": compute_binary_op_flops,
        "random_fold_in": no_flops,  # Random number generation, no FLOPs
        "custom_vjp_call_jaxpr": count_custom_vjp_flops,
        "logistic": scaled_unary(4),  # sigmoid function
        # No-op operations (0 FLOPs)
        "stop_gradient": no_flops,  # Just passes through the value
        "squeeze": no_flops,  # Reshapes data, no computation
        "copy": no_flops,  # Memory operation only
        "split": no_flops,
        "remat2": no_flops,
        "random_seed": no_flops,
        "random_unwrap": no_flops,
        "random_wrap": no_flops,
        "random_split": no_flops,
        "random_bits": no_flops,
        # Bitwise and type conversion operations
        "shift_right_logical": compute_binary_op_flops,
        "or": compute_binary_op_flops,
        "bitcast_convert_type": no_flops,  # Type conversion, no computation
        # Mathematical operations
        "abs": compute_unary_op_flops,  # Single comparison/selection per element
        "erf_inv": scaled_unary(15),  # Inverse error function (heuristic weight)
        "triangular_solve": compute_triangular_solve_flops,
        # Computation operations
        "square": compute_unary_op_flops,  # x * x: one multiplication per element
        "sqrt": scaled_unary(4),
        "argmax": compute_argmax_flops,
        "add_any": compute_binary_op_flops,  # Similar to regular add
        "min": compute_max_flops,  # Same binary/unary rule as max
        "rem": lambda eqn: 2 * compute_binary_op_flops(eqn),  # division + multiplication
    }

    # Parameters whose nested jaxprs are already accounted for by the dedicated
    # handler above; the generic recursion must skip them to avoid counting the
    # same body twice.
    handled_param_keys = {
        "scan": ("jaxpr",),
        "cond": ("branches", "true_jaxpr", "false_jaxpr"),
        "custom_vjp_call_jaxpr": ("fun_jaxpr",),
    }

    def iter_nested_jaxprs(value):
        """Yield jaxpr-like objects (anything exposing ``.eqns``) held by a param value."""
        candidates = value if isinstance(value, (tuple, list)) else (value,)
        for candidate in candidates:
            if hasattr(candidate, "eqns"):
                yield candidate

    flops = 0

    def visit_jaxpr(jaxpr):
        nonlocal flops
        for eqn in jaxpr.eqns:
            primitive_name = eqn.primitive.name
            handler = primitive_flops.get(primitive_name)
            if handler is None:
                warnings.warn(f"Unhandled primitive {primitive_name}", stacklevel=1)
            else:
                flops += int(handler(eqn))
            # Recursively visit subjaxprs that no dedicated handler has counted
            skipped = handled_param_keys.get(primitive_name, ())
            for key, value in eqn.params.items():
                if key in skipped:
                    continue
                for subjaxpr in iter_nested_jaxprs(value):
                    visit_jaxpr(subjaxpr)

    visit_jaxpr(jaxpr)
    return flops
class TraceResult:
    """Wrap a compiled executable and expose its cost analysis lazily.

    The wrapped object only needs a ``cost_analysis()`` method whose result
    supports ``["flops"]`` lookup. The analysis is requested the first time
    it is needed and then kept on this instance.

    Attributes:
        _executable: The wrapped executable.
        _cached_cost: Result of ``_executable.cost_analysis()`` once it has
            been requested, ``None`` before that.

    Properties:
        cost_analysis: The cost analysis of the wrapped executable.
        flops: The ``"flops"`` entry of the cost analysis.
    """

    def __init__(self, executable):
        self._executable = executable
        self._cached_cost = None

    @property
    def cost_analysis(self):
        """Return the executable's cost analysis, computing it at most once."""
        # The result is stored per instance. A functools.lru_cache on the
        # getter would be shared by every TraceResult (and, with maxsize=1,
        # evicted whenever a different instance is queried) while also holding
        # a strong reference to the instance.
        if self._cached_cost is None:
            self._cached_cost = self._executable.cost_analysis()
        return self._cached_cost

    @property
    def flops(self):
        """Return the ``"flops"`` entry of :attr:`cost_analysis`."""
        return self.cost_analysis["flops"]
class FunctionTracer:
    """Record of the executables that appeared during a traced region.

    The object itself is a plain container: whoever drives the trace stores
    the set of executables seen beforehand in ``_before`` and the wrapped
    newcomers in ``new_executables``.

    Attributes:
        new_executables: List of TraceResult objects for newly seen executables.
        _before: Set of executables that existed before tracing started.

    Example:
        >>> with trace_functions() as tracer:
        ...     result = jitted_function(x)
        >>> print(f"Compiled {len(tracer.new_executables)} functions")
        >>> print(f"Total FLOPs: {sum(t.flops for t in tracer.new_executables)}")
    """

    def __init__(self):
        # Both containers start empty and are filled in by the tracing code.
        self._before: set = set()
        self.new_executables: list[TraceResult] = []

    def __getitem__(self, idx):
        """Index (or slice) directly into ``new_executables``."""
        recorded = self.new_executables
        return recorded[idx]
class CompilationTracker:
    """Record the executables created by the first traced call and their FLOPs.

    Only the first use of :meth:`trace_compilation` inspects the backend;
    every later use is a no-op wrapper around the ``with`` body.

    Attributes:
        first_time: ``True`` until the first trace has completed.
        cached_flops: FLOP total accumulated when the first trace completed.
        functions: List of executables that appeared during the first trace,
            or ``None`` if none were recorded.

    Properties:
        online_flops: FLOP total re-queried from ``functions`` on every access.

    Methods:
        trace_compilation: Context manager for tracing compilation.

    Example:
        >>> tracker = CompilationTracker()
        >>> with tracker.trace_compilation():
        ...     result = model(inputs)
        >>> print(f"Total FLOPs: {tracker.cached_flops}")
    """

    def __init__(self):
        self.functions = None
        self.cached_flops = 0
        self.first_time = True

    @staticmethod
    def _accumulate_flops(executables, total):
        """Add every readable ``cost_analysis()["flops"]`` to ``total``.

        Executables whose cost analysis cannot be read (or added) are
        skipped, so the result is a best-effort sum.
        """
        for executable in executables:
            try:
                total += executable.cost_analysis()["flops"]
            except Exception:
                # Best effort: an executable without usable cost data
                # contributes nothing.
                continue
        return total

    @property
    def online_flops(self):
        """Sum the FLOPs of the recorded executables, querying them afresh."""
        if self.functions is None:
            return 0
        return self._accumulate_flops(self.functions, 0)

    @contextmanager
    def trace_compilation(self):
        """Record executables that become live while the ``with`` body runs.

        The backend's live executables are sampled before and after the body;
        the difference is stored in ``functions`` and its FLOPs are added to
        ``cached_flops``. If the body raises, nothing is recorded and
        ``first_time`` stays ``True``.
        """
        if not self.first_time:
            # Already traced once: run the body without touching the backend.
            yield
            return

        seen_before = set(jax.extend.backend.get_backend().live_executables())
        yield
        seen_after = set(jax.extend.backend.get_backend().live_executables())

        fresh = seen_after - seen_before
        if fresh:
            self.functions = list(fresh)
            self.cached_flops = self._accumulate_flops(self.functions, self.cached_flops)
        self.first_time = False
class ActivationType(str, Enum):
    """Activation-function names used to key per-element FLOP estimates.

    Each member is also a ``str`` equal to its value, so a member can be
    compared with (or looked up by) the plain lowercase name.
    """

    GELU = "gelu"
    RELU = "relu"
    SILU = "silu"
    SWISH = "swish"
    GELU_NEW = "gelu_new"
    GELU_PYTORCH_TANH = "gelu_pytorch_tanh"
    TANH = "tanh"
    SIGMOID = "sigmoid"
    LEAKY_RELU = "leaky_relu"
    GLU = "glu"
    ELU = "elu"
    SOFTMAX = "softmax"
    QUICK_GELU = "quick_gelu"
def flop_activation(activation_type: ActivationType, dim: int) -> float:
    """Estimate the FLOPs of applying an activation to ``dim`` elements.

    Args:
        activation_type: Activation whose per-element cost is looked up.
        dim: Number of elements the activation is applied to.

    Returns:
        The per-element estimate multiplied by ``dim``. Activations that are
        not in the table are charged 1 FLOP per element.
    """
    # Rough per-element costs, grouped by estimate.
    activations_by_cost = {
        8: (  # GELU and its tanh-based approximations
            ActivationType.GELU,
            ActivationType.GELU_NEW,
            ActivationType.GELU_PYTORCH_TANH,
        ),
        5: (  # tanh, gated sigmoid, softmax
            ActivationType.TANH,
            ActivationType.GLU,
            ActivationType.SOFTMAX,
        ),
        4: (  # sigmoid, and x * sigmoid(x) for SiLU/Swish
            ActivationType.SILU,
            ActivationType.SWISH,
            ActivationType.SIGMOID,
        ),
        2: (  # comparison plus one extra op; quick GELU is x * sigmoid(1.702 * x)
            ActivationType.LEAKY_RELU,
            ActivationType.ELU,
            ActivationType.QUICK_GELU,
        ),
        1: (ActivationType.RELU,),  # a single max
    }
    cost_of = {
        activation: cost
        for cost, activations in activations_by_cost.items()
        for activation in activations
    }
    per_element = cost_of.get(activation_type, 1)
    return per_element * dim
class AttnMaskType(str, Enum):
    """Kinds of attention mask: full, sliding, or chunk.

    Members are ``str`` subclasses equal to their ``ATTN_MASK_*`` value.
    """

    FULL = "ATTN_MASK_FULL"
    SLIDING = "ATTN_MASK_SLIDING"
    CHUNK = "ATTN_MASK_CHUNK"

    @classmethod
    def from_hf(cls, hf_type: tp.Literal["sliding_attention", "full_attention", "chunk_attention"]):
        """Translate a ``*_attention`` layer-type name into a member.

        Args:
            hf_type: One of ``"sliding_attention"``, ``"full_attention"`` or
                ``"chunk_attention"``.

        Returns:
            The matching ``AttnMaskType`` member.

        Raises:
            ValueError: If ``hf_type`` is none of the names above.
        """
        known_names = (
            ("sliding_attention", AttnMaskType.SLIDING),
            ("full_attention", AttnMaskType.FULL),
            ("chunk_attention", AttnMaskType.CHUNK),
        )
        for hf_name, member in known_names:
            if hf_type == hf_name:
                return member
        raise ValueError(f"`hf_type` {hf_type} is not available")
@auto_pytree
class AttnMaskDetail:
    """Description of the attention mask a layer should use.

    The class is wrapped with ``auto_pytree`` (presumably so instances can be
    passed through JAX transformations as pytrees — confirm against the
    decorator's documentation).

    Attributes:
        mask_type: Which kind of mask to build (FULL, SLIDING, or CHUNK).
        size: Size parameter for the mask (e.g., window size for sliding).
        offset: Optional offset for mask positioning.
        chunks: Optional number of chunks for chunk attention.
        bricks: Optional number of bricks for hierarchical attention.

    Example:
        >>> mask_detail = AttnMaskDetail(
        ...     mask_type=AttnMaskType.SLIDING,
        ...     size=512,
        ...     offset=0
        ... )
    """

    # Required fields.
    mask_type: AttnMaskType
    size: int

    # Optional fields; ``None`` means "not specified".
    offset: int | None = None
    chunks: int | None = None
    bricks: int | None = None
class TaskType(str, Enum):
    """Identifiers for the kinds of task a model can be bound to.

    Members are ``str`` subclasses equal to their hyphenated value.
    """

    CAUSAL_LM = "causal-language-model"
    VISION_LM = "vision-language-model"
    DIFFUSION_LM = "diffusion-language-model"
    IMAGE_TEXT_TO_TEXT = "image-text-to-text"
    BASE_MODULE = "base-module"
    BASE_VISION = "vision-module"
    SEQUENCE_TO_SEQUENCE = "sequence-to-sequence"
    SPEECH_SEQUENCE_TO_SEQUENCE = "speech-sequence-to-sequence"
    ZERO_SHOT_IMAGE_CLASSIFICATION = "zero-shot-image-classification"
    SEQUENCE_CLASSIFICATION = "sequence-classification"
    AUDIO_CLASSIFICATION = "audio-classification"
    IMAGE_CLASSIFICATION = "image-classification"
    AUTO_BIND = "auto-bind"
@dataclass
class FlopCalcConfig:
    """Inputs for the analytical FLOP estimators (``flop_*`` / ``flops_per_token``).

    The first seven fields are required and describe the main transformer
    stack; everything else is optional and defaults to "absent".

    Attributes:
        hidden_dim: Hidden dimension of the model.
        intermediate_dim: Dimension of FFN intermediate layer.
        num_layers: Number of decoder (or encoder-only) layers.
        num_heads: Number of attention heads.
        kv_heads: Number of key-value heads (for GQA/MQA).
        head_dim: Dimension of each attention head.
        seq_len: Sequence length for decoder or encoder-only models.
        enc_num_layers: Number of encoder layers (for seq2seq).
        enc_seq_len: Encoder sequence length (for seq2seq).
        glu: Whether using GLU activation in FFN.
        num_experts: Number of MoE experts.
        num_shared_experts: Number of shared experts in MoE.
        num_experts_per_tok: Experts activated per token.
        activation_type: Type of activation function.
        task: Model task type (affects head computation).
        vocab_size: Vocabulary size for LM head.
        num_labels: Number of labels for classification.
        vision_hidden_dim: Hidden dim for vision transformer.
        vision_intermediate_dim: FFN dim for vision transformer.
        vision_num_layers: Number of vision transformer layers.
        vision_num_heads: Number of vision attention heads.
        vision_seq_len: Vision sequence length (patches).
        include_loss: Whether to include loss computation in FLOPs.

    Example:
        >>> config = FlopCalcConfig(
        ...     hidden_dim=768,
        ...     intermediate_dim=3072,
        ...     num_layers=12,
        ...     num_heads=12,
        ...     kv_heads=12,
        ...     head_dim=64,
        ...     seq_len=1024,
        ...     task=TaskType.CAUSAL_LM,
        ...     vocab_size=50000
        ... )
        >>> flops = flops_per_token(config)
    """

    # --- Main transformer stack (decoder-only or encoder-only) -------------
    hidden_dim: int
    intermediate_dim: int
    num_layers: int  # decoder (or encoder-only) layer count
    num_heads: int
    kv_heads: int
    head_dim: int
    seq_len: int  # decoder (or encoder-only) sequence length

    # --- Optional encoder (seq2seq / encoder-decoder) ----------------------
    enc_num_layers: int = 0
    enc_seq_len: int = 0

    # --- Feed-forward variants: GLU and mixture-of-experts -----------------
    glu: bool = False
    num_experts: int = 1
    num_shared_experts: int = 0
    num_experts_per_tok: int = 1

    # --- Task-dependent settings -------------------------------------------
    activation_type: ActivationType = ActivationType.GELU
    task: TaskType = TaskType.AUTO_BIND
    vocab_size: int = 0
    num_labels: int = 0

    # --- Vision tower (patch transformer) ----------------------------------
    vision_hidden_dim: int = 0
    vision_intermediate_dim: int = 0
    vision_num_layers: int = 0
    vision_num_heads: int = 0
    vision_seq_len: int = 0

    # --- Loss --------------------------------------------------------------
    include_loss: bool = False
def flop_layernorm(hidden_dim: int) -> float:
    """Estimate the FLOPs of one layer normalization over ``hidden_dim`` features.

    Args:
        hidden_dim: Number of normalized features.

    Returns:
        ``8 * hidden_dim``, i.e. a fixed estimate of 8 FLOPs per feature.
    """
    flops_per_feature = 8
    return flops_per_feature * hidden_dim
def flop_attention(
    hidden_dim: int,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int | None,
    seq_len: int,
) -> float:
    """Estimate the per-token FLOPs of one self-attention block.

    The estimate is the sum of the Q/K/V projections, the output projection,
    and the sequence-wide attention terms (query-key logits, a masking/softmax
    term, and the value-weighted sum) divided by ``seq_len``.

    Args:
        hidden_dim: Model width entering and leaving the block.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        head_dim: Width of each head; ``None`` means ``hidden_dim // num_heads``.
        seq_len: Sequence length attended over (full attention is assumed).

    Returns:
        Estimated FLOPs per token as a float.

    Raises:
        ZeroDivisionError: If ``seq_len`` is 0, or if ``head_dim`` is ``None``
            and ``num_heads`` is 0. ``flops_per_token`` catches this error for
            some tasks, so it is intentionally not suppressed here.
    """
    if head_dim is None:
        head_dim = hidden_dim // num_heads

    q_width = num_heads * head_dim
    kv_width = num_kv_heads * head_dim

    qkv_proj = 2 * hidden_dim * (q_width + 2 * kv_width)
    # The output projection consumes the concatenated heads, whose width is
    # num_heads * head_dim. That equals hidden_dim only for the conventional
    # head_dim == hidden_dim // num_heads; using the real width keeps this
    # term consistent with qkv_proj for models with an explicit head_dim.
    dense_proj = 2 * hidden_dim * q_width

    # Terms below cover the whole sequence.
    key_query_logits = 2 * seq_len**2 * num_heads * head_dim
    mask = 3 * seq_len * seq_len * num_heads
    mask_value = 2 * seq_len * seq_len * head_dim * num_heads
    seq_flops = key_query_logits + mask + mask_value

    # Convert the sequence-wide cost into a per-token cost.
    attn = seq_flops / seq_len
    return qkv_proj + dense_proj + attn
def flop_cross_attention(
    hidden_dim: int,
    num_heads: int,
    enc_seq_len: int,
    dec_seq_len: int,
) -> float:
    """Estimate the FLOPs of one cross-attention block.

    The head width is derived as ``hidden_dim // num_heads``. The estimate
    adds two ``hidden_dim x hidden_dim`` projections to three terms that
    scale with ``enc_seq_len * dec_seq_len * num_heads`` (scores, softmax at
    5 FLOPs per score, and the weighted sum). Unlike ``flop_attention``, the
    sequence-dependent terms are not divided by a sequence length.

    Args:
        hidden_dim: Model width.
        num_heads: Number of attention heads.
        enc_seq_len: Encoder (key/value) sequence length.
        dec_seq_len: Decoder (query) sequence length.

    Returns:
        The summed estimate.

    Raises:
        ZeroDivisionError: If ``num_heads`` is 0.
    """
    head_dim = hidden_dim // num_heads

    # A single hidden_dim x hidden_dim projection; charged twice (in and out).
    projection = 2 * hidden_dim * hidden_dim
    # One batched matmul over every (encoder, decoder) position pair; charged
    # once for the scores and once for the weighted sum.
    pairwise_matmul = 2 * head_dim * enc_seq_len * dec_seq_len * num_heads
    softmax = 5 * enc_seq_len * dec_seq_len * num_heads

    return projection + pairwise_matmul + softmax + pairwise_matmul + projection
def flop_mlp(
    cfg: FlopCalcConfig,
    hidden_dim: int,
    intermediate_dim: int,
) -> float:
    """Estimate the per-token FLOPs of one feed-forward (optionally MoE) block.

    Args:
        cfg: Supplies ``glu``, ``num_experts``, ``num_experts_per_tok``,
            ``num_shared_experts`` and ``activation_type``.
        hidden_dim: Input/output width of the block.
        intermediate_dim: Inner width of the block.

    Returns:
        Matmul cost (factor 3 with ``cfg.glu``, otherwise 2, per active
        expert, times 2 FLOPs per multiply-accumulate) plus the activation
        cost, plus a router term when ``cfg.num_experts > 1``.
    """
    # Experts that actually run for a token: routed ones plus shared ones.
    active_experts = cfg.num_experts_per_tok + cfg.num_shared_experts

    if cfg.glu:
        matmul_factor = 3
    else:
        matmul_factor = 2
    ffn_macs = matmul_factor * hidden_dim * intermediate_dim * active_experts

    activation_cost = flop_activation(
        cfg.activation_type,
        intermediate_dim * active_experts,
    )

    router_cost = 0
    if cfg.num_experts > 1:
        # The router term is only charged when there is more than one expert.
        router_cost = 2 * hidden_dim * cfg.num_experts

    return 2 * ffn_macs + activation_cost + router_cost
def flop_lm_head(hidden_dim: int, vocab_size: int) -> float:
    """Estimate the per-token FLOPs of a language-model head.

    Args:
        hidden_dim: Width of the hidden state fed to the head.
        vocab_size: Number of output logits.

    Returns:
        ``2 * hidden_dim * vocab_size`` for the projection plus
        ``5 * vocab_size`` (5 FLOPs per logit, presumably the softmax).
    """
    projection = 2 * hidden_dim * vocab_size
    per_logit_extra = 5 * vocab_size
    return projection + per_logit_extra
def flop_cls_head(hidden_dim: int, num_labels: int) -> float:
    """Estimate the FLOPs of a classification head.

    Args:
        hidden_dim: Width of the hidden state fed to the head.
        num_labels: Number of output classes.

    Returns:
        ``2 * hidden_dim * num_labels`` for the projection plus
        ``5 * num_labels`` (5 FLOPs per label, presumably the softmax).
    """
    projection = 2 * hidden_dim * num_labels
    per_label_extra = 5 * num_labels
    return projection + per_label_extra
def flop_loss(num_classes: int) -> float:
    """Estimate the FLOPs of the loss computation over ``num_classes`` outputs.

    Args:
        num_classes: Number of classes (or vocabulary entries) the loss covers.

    Returns:
        ``3 * num_classes + 2``: three FLOPs per class plus a constant two.
    """
    per_class = 3 * num_classes
    constant_overhead = 2
    return per_class + constant_overhead
def flop_transformer_body(
    layers: int,
    seq_len: int,
    hidden_dim: int,
    intermediate_dim: int,
    cfg: FlopCalcConfig,
) -> float:
    """Estimate the per-token FLOPs of a stack of identical transformer layers.

    Each layer is charged two layer norms, one self-attention block and one
    feed-forward block. Head settings (``num_heads``, ``kv_heads``,
    ``head_dim``) always come from ``cfg``; widths and sequence length come
    from the explicit arguments.

    Args:
        layers: Number of layers in the stack.
        seq_len: Sequence length used for the attention term.
        hidden_dim: Model width of the stack.
        intermediate_dim: Feed-forward inner width of the stack.
        cfg: Configuration supplying head, MoE and activation settings.

    Returns:
        ``layers`` times the per-layer estimate.

    Raises:
        ZeroDivisionError: Propagated from ``flop_attention`` (e.g. when
            ``seq_len`` is 0), even if ``layers`` is 0.
    """
    norm_cost = 2 * flop_layernorm(hidden_dim)
    attention_cost = flop_attention(
        hidden_dim,
        cfg.num_heads,
        cfg.kv_heads,
        cfg.head_dim,
        seq_len,
    )
    feed_forward_cost = flop_mlp(cfg, hidden_dim, intermediate_dim)

    per_layer = norm_cost + attention_cost + feed_forward_cost
    return layers * per_layer
def flop_seq2seq(cfg: FlopCalcConfig) -> float:
    """Estimate the FLOPs of an encoder-decoder transformer body.

    The encoder is a plain transformer stack of ``cfg.enc_num_layers`` layers
    over ``cfg.enc_seq_len``. Each of the ``cfg.num_layers`` decoder layers is
    charged three layer norms, self-attention over ``cfg.seq_len``,
    cross-attention between ``cfg.enc_seq_len`` and ``cfg.seq_len``, and one
    feed-forward block. Encoder and decoder share ``hidden_dim`` and
    ``intermediate_dim``.

    Args:
        cfg: Model configuration.

    Returns:
        Encoder estimate plus decoder estimate.

    Raises:
        ZeroDivisionError: Propagated from the attention estimators, e.g.
            when ``cfg.enc_seq_len`` or ``cfg.seq_len`` is 0.
    """
    width = cfg.hidden_dim
    inner_width = cfg.intermediate_dim

    encoder_cost = flop_transformer_body(
        cfg.enc_num_layers,
        cfg.enc_seq_len,
        width,
        inner_width,
        cfg,
    )

    # Per-decoder-layer components, evaluated in the same order they are summed.
    norm_cost = 3 * flop_layernorm(width)
    self_attention_cost = flop_attention(
        width,
        cfg.num_heads,
        cfg.kv_heads,
        cfg.head_dim,
        cfg.seq_len,
    )
    cross_attention_cost = flop_cross_attention(
        width,
        cfg.num_heads,
        cfg.enc_seq_len,
        cfg.seq_len,
    )
    feed_forward_cost = flop_mlp(cfg, width, inner_width)

    per_decoder_layer = norm_cost + self_attention_cost + cross_attention_cost + feed_forward_cost
    decoder_cost = cfg.num_layers * per_decoder_layer
    return encoder_cost + decoder_cost
def flop_vision_tower(cfg: FlopCalcConfig) -> float:
    """Estimate the FLOPs of the vision (patch) transformer described by ``cfg``.

    The vision layer count, sequence length, hidden width and feed-forward
    width are taken from the ``vision_*`` fields and handed to
    ``flop_transformer_body`` together with ``cfg`` itself.

    NOTE(review): ``cfg.vision_num_heads`` is not read here; the attention
    term inside ``flop_transformer_body`` uses ``cfg.num_heads``,
    ``cfg.kv_heads`` and ``cfg.head_dim`` — confirm this is intended.

    Args:
        cfg: Model configuration.

    Returns:
        The transformer-body estimate for the vision tower.

    Raises:
        ZeroDivisionError: Propagated from the attention estimator, e.g.
            when ``cfg.vision_seq_len`` is 0 (its default).
    """
    vision_layers = cfg.vision_num_layers
    vision_tokens = cfg.vision_seq_len
    vision_width = cfg.vision_hidden_dim
    vision_inner_width = cfg.vision_intermediate_dim
    return flop_transformer_body(
        vision_layers,
        vision_tokens,
        vision_width,
        vision_inner_width,
        cfg,
    )
def flops_per_token(cfg: FlopCalcConfig) -> float:
    """Estimate total FLOPs (body + head + loss) for the task in ``cfg``.

    What is counted depends on ``cfg.task``:

    - causal / diffusion LM: decoder body, LM head, optional loss.
    - sequence / image / audio classification: decoder body, classification
      head, optional loss over ``num_labels``.
    - (speech) sequence-to-sequence: seq2seq body, LM head, optional loss.
    - vision LM: vision tower only.
    - image-text-to-text: vision tower + seq2seq body + decoder body, LM
      head, optional loss. If the vision or seq2seq estimate raises
      ``ZeroDivisionError``, both of those terms are counted as 0.
    - zero-shot image classification: vision tower and classification head.
    - base module / base vision / auto-bind: decoder body only.

    Args:
        cfg: Model configuration.

    Returns:
        Sum of the body, head and loss estimates.

    Raises:
        NotImplementedError: If ``cfg.task`` is none of the tasks above.
        ZeroDivisionError: Propagated from the body estimators outside the
            image-text-to-text fallback (e.g. ``seq_len`` of 0).
    """

    def decoder_body():
        # Main transformer stack described by the core config fields.
        return flop_transformer_body(
            cfg.num_layers,
            cfg.seq_len,
            cfg.hidden_dim,
            cfg.intermediate_dim,
            cfg,
        )

    def lm_head_and_loss():
        # LM head over the vocabulary, plus the loss when requested.
        head = flop_lm_head(cfg.hidden_dim, cfg.vocab_size)
        loss = flop_loss(cfg.vocab_size) if cfg.include_loss else 0
        return head, loss

    language_modeling = {TaskType.CAUSAL_LM, TaskType.DIFFUSION_LM}
    classification = {
        TaskType.SEQUENCE_CLASSIFICATION,
        TaskType.IMAGE_CLASSIFICATION,
        TaskType.AUDIO_CLASSIFICATION,
    }
    encoder_decoder = {
        TaskType.SEQUENCE_TO_SEQUENCE,
        TaskType.SPEECH_SEQUENCE_TO_SEQUENCE,
    }
    body_only = {TaskType.BASE_MODULE, TaskType.BASE_VISION, TaskType.AUTO_BIND}

    body_cost = head_cost = loss_cost = 0

    if cfg.task in language_modeling:
        body_cost = decoder_body()
        head_cost, loss_cost = lm_head_and_loss()
    elif cfg.task in classification:
        body_cost = decoder_body()
        head_cost = flop_cls_head(cfg.hidden_dim, cfg.num_labels)
        if cfg.include_loss:
            loss_cost = flop_loss(cfg.num_labels)
    elif cfg.task in encoder_decoder:
        body_cost = flop_seq2seq(cfg)
        head_cost, loss_cost = lm_head_and_loss()
    elif cfg.task == TaskType.VISION_LM:
        body_cost = flop_vision_tower(cfg)
    elif cfg.task == TaskType.IMAGE_TEXT_TO_TEXT:
        try:
            vision_cost = flop_vision_tower(cfg)
            text_cost = flop_seq2seq(cfg)
        except ZeroDivisionError:
            # Either estimate failing (e.g. a zero sequence length) zeroes both.
            vision_cost = 0
            text_cost = 0
        body_cost = vision_cost + text_cost + decoder_body()
        head_cost, loss_cost = lm_head_and_loss()
    elif cfg.task == TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION:
        body_cost = flop_vision_tower(cfg)
        head_cost = flop_cls_head(cfg.hidden_dim, cfg.num_labels)
    elif cfg.task in body_only:
        body_cost = decoder_body()
    else:
        raise NotImplementedError(f"Unsupported task: {cfg.task}")

    return body_cost + head_cost + loss_cost
@contextmanager
def trace_functions():
    """Yield a ``FunctionTracer`` that records executables created in the block.

    The backend's live executables are sampled on entry and again on exit
    (also when the body raises); every executable present only in the second
    sample is wrapped in ``TraceResult`` and stored in
    ``tracer.new_executables``.

    Yields:
        FunctionTracer: Populated once the ``with`` block exits.
    """

    def live_executables():
        # Fresh snapshot of what the current backend reports as live.
        return set(jax.extend.backend.get_backend().live_executables())

    tracer = FunctionTracer()
    tracer._before = live_executables()
    try:
        yield tracer
    finally:
        appeared = live_executables() - tracer._before
        tracer.new_executables = [TraceResult(executable) for executable in appeared]
class ModuleCaches(nn.Cache):
    """Marker subclass of ``nn.Cache`` for values cached on modules.

    The class adds no fields or behaviour of its own; it only gives cached
    module state (the original documentation mentions frequencies, masks and
    other reusable tensors) a distinct variable type.
    """
class OverWriteWithGradient(nn.Param):
    """Marker subclass of ``nn.Param`` for overwrite-with-gradient parameters.

    The class adds no fields or behaviour of its own. Per its original
    documentation it tags parameters whose values are meant to be replaced
    directly by their gradients during optimization; the code that honours
    the tag is not part of this class.
    """
class hashable_dict(dict):
    """``dict`` subclass made hashable by using ``hash_fn`` as ``__hash__``.

    Plain dictionaries are unhashable; assigning ``__hash__`` here lets
    instances be hashed. The hash value itself is whatever ``hash_fn``
    computes for the instance.
    """

    __hash__ = hash_fn
def _resolve_initializer(init_method, init_kwargs):
    """Return the ``(key, shape, dtype)`` initializer named ``init_method``.

    Most entries of ``jax.nn.initializers`` are factories that must be called
    with their keyword arguments to obtain an initializer. ``zeros`` and
    ``ones`` are already initializers, so they are returned as-is; calling
    them as factories would raise ``TypeError``. Unknown names fall back to
    the ``normal`` factory. ``init_kwargs`` may be ``None``.
    """
    initializers = jax.nn.initializers
    factory = getattr(initializers, init_method, initializers.normal)
    if factory is initializers.zeros or factory is initializers.ones:
        return factory
    return factory(**(init_kwargs or {}))


class ArrayParam(nn.Param):
    """Parameterized array with serializable initialization.

    A parameter container that stores initialization metadata (method name
    and kwargs) as strings/dicts instead of functions, making it pickleable
    and serializable. This is particularly useful for checkpointing and
    distributed training.

    Attributes:
        shape: The shape of the parameter array.
        dtype: The data type of the parameter array.
        init_method: Name of the JAX initializer (e.g., "normal", "zeros", "ones").
        init_kwargs: Optional kwargs passed to the initializer.
    """

    shape: Sequence[int]
    dtype: DTypeLike
    init_method: str = "normal"
    init_kwargs: hashable_dict | None = None

    @classmethod
    def bound(
        cls,
        shape: Sequence[int],
        dtype: DTypeLike,
        init_method: str,
        init_kwargs: hashable_dict | None = None,
        *,
        key: PRNGKeyArray | None = None,
        value: Array | None = None,
        use_ref: bool | None = None,
        **metadata,
    ):
        """Create an ArrayParam with initialized value.

        Args:
            shape: Shape of the parameter array.
            dtype: Data type for the parameter.
            init_method: Name of JAX initializer (e.g., "normal", "zeros", "kaiming_uniform").
                Names not found in ``jax.nn.initializers`` fall back to ``normal``.
            init_kwargs: Optional keyword arguments for the initializer.
            key: PRNG key for random initialization. Required if value is None
                and the initializer is random.
            value: Pre-computed value. If provided, skips initialization.
            use_ref: Whether to use reference semantics.
            **metadata: Additional metadata to store with the parameter.

        Returns:
            ArrayParam: An initialized ArrayParam instance.

        Raises:
            TypeError: If ``init_kwargs`` are not accepted by the initializer
                factory (raised even when ``value`` is supplied).
        """
        if init_kwargs is None:
            init_kwargs = {}
        init_kwargs = hashable_dict(init_kwargs)
        # Resolved eagerly so invalid initializer kwargs are reported even
        # when a pre-computed value is supplied.
        init_fn = _resolve_initializer(init_method, init_kwargs)
        if value is None:
            value = init_fn(key, shape, dtype)
        return cls(
            shape=shape,
            dtype=dtype,
            init_method=init_method,
            init_kwargs=init_kwargs,
            value=value,
            use_ref=use_ref,
            **metadata,
        )

    def resure(self, key: PRNGKeyArray, shard_fn: tp.Callable[[Array], Array] | None = None) -> None:
        """Reinitialize the parameter value with a new random key.

        Regenerates the parameter value using the stored initialization method
        and optional sharding function. Useful for resetting parameters or
        applying sharding after initialization.

        Args:
            key: PRNG key for random initialization.
            shard_fn: Optional function to apply sharding to the reinitialized value.
        """
        # ``init_kwargs`` may still be the class default ``None``; the helper
        # treats that as "no keyword arguments".
        init_fn = _resolve_initializer(self.init_method, self.init_kwargs)
        val = init_fn(key, self.shape, self.dtype)
        if shard_fn is not None:
            val = shard_fn(val)
        self.value = val
        self.raw_value = val
# ``ProcessingClassType`` is a precise optional union for static type checkers
# and plain ``tp.Any`` at runtime, so ``transformers`` is never imported here
# when the program actually runs.
if tp.TYPE_CHECKING:
    from transformers import (
        BaseImageProcessor,
        FeatureExtractionMixin,
        PreTrainedTokenizerBase,
        ProcessorMixin,
    )

    ProcessingClassType = tp.Union[  # noqa
        PreTrainedTokenizerBase,
        BaseImageProcessor,
        FeatureExtractionMixin,
        ProcessorMixin,
        None,
    ]
else:
    ProcessingClassType = tp.Any