# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import re
import types
import typing as tp
import warnings
from contextlib import contextmanager
from functools import lru_cache, partial
from typing import List, Set
import flax
import flax.core
import jax
import jax.experimental
import jax.tree_util
import numpy as np
from eformer.escale import PartitionAxis, with_sharding_constraint
from einops import rearrange
from flax import nnx as nn
from jax.sharding import PartitionSpec
from tqdm.auto import tqdm
from easydel.layers.linear import ParallelLinear
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import flatten_dict, unflatten_dict
from .base_config import EasyMethod
from .errors import EasyDeLBlockWiseFFNError
from .etils import (
AVAILABLE_SPARSE_MODULE_TYPES,
EasyDeLGradientCheckPointers,
EasyDeLQuantizationMethods,
)
warnings.filterwarnings(
"ignore",
message="Primitive dynamic_update_slice was not handled by class",
)
logger = get_logger(__name__)
[docs]def quick_gelu(x):
return x * jax.nn.sigmoid(1.702 * x)
ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
"gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
"tanh": nn.tanh,
"sigmoid": nn.sigmoid,
"leaky_relu": partial(nn.leaky_relu, negative_slope=0.01),
"glu": nn.glu,
"elu": nn.elu,
"softmax": nn.softmax,
"quick_gelu": quick_gelu,
}
ROPE_TYPES = tp.Optional[
tp.Literal[
"none",
"linear",
"dynamic",
"yarn",
"su",
"llama3",
"longrope",
]
]
with_sharding_constraint = with_sharding_constraint
[docs]def canonicalize_dtype(
*args,
dtype: tp.Optional[jax.numpy.dtype] = None,
inexact: bool = True,
) -> jax.numpy.dtype:
"""Canonicalize an optional dtype to the definitive dtype.
If the ``dtype`` is None this function will infer the dtype. If it is not
None it will be returned unmodified or an exceptions is raised if the dtype
is invalid.
from the input arguments using ``jnp.result_type``.
Args:
*args: JAX array compatible values. None values
are ignored.
dtype: tp.Optional dtype override. If specified the arguments are cast to
the specified dtype instead and dtype inference is disabled.
inexact: When True, the output dtype must be a subdtype
of `jnp.inexact`. Inexact dtypes are real or complex floating points. This
is useful when you want to apply operations that don'position_ids work directly on
integers like taking a mean for example.
Returns:
The dtype that *args should be cast to.
"""
if dtype is None:
args_filtered = [jax.numpy.asarray(x) for x in args if x is not None]
dtype = jax.numpy.result_type(*args_filtered)
if inexact and not jax.numpy.issubdtype(dtype, jax.numpy.inexact):
dtype = jax.numpy.promote_types(jax.numpy.float32, dtype)
if inexact and not jax.numpy.issubdtype(dtype, jax.numpy.inexact):
raise ValueError(f"Dtype must be inexact: {dtype}")
return dtype
[docs]def get_gradient_checkpoint_policy(name):
"""
The get_gradient_checkpoint_policy function is a helper function that returns the gradient checkpoint policy
specified by the name parameter.
"""
gradients = dict(
everything_saveable=jax.checkpoint_policies.everything_saveable,
nothing_saveable=jax.checkpoint_policies.nothing_saveable,
dots_saveable=jax.checkpoint_policies.dots_saveable,
checkpoint_dots=jax.checkpoint_policies.checkpoint_dots,
dots_with_no_batch_dims_saveable=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
checkpoint_dots_with_no_batch_dims=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
save_anything_except_these_names=jax.checkpoint_policies.save_anything_except_these_names,
save_any_names_but_these=jax.checkpoint_policies.save_any_names_but_these,
save_only_these_names=jax.checkpoint_policies.save_only_these_names,
save_from_both_policies=jax.checkpoint_policies.save_from_both_policies,
)
return gradients[name]
[docs]def add_start_docstrings(*docstr):
"""The add_start_docstrings function is a decorator that adds the docstrings to the beginning of a function.
The add_start_docstrings function takes in an arbitrary number of strings and returns a decorator.
The returned decorator takes in one argument, fn, which is assumed to be a function. The docstring for fn is set equal to
the concatenation of all the strings passed into add_start_docstrings plus (if it exists) the original docstring for fn.
Args:
*docstr: Pass in a variable number of arguments to the function
Returns:
A decorator that adds the docstrings to the function
"""
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn
return docstring_decorator
[docs]def get_dot_general_by_bits(
bits: tp.Optional[int] = None,
mode: tp.Literal["train", "serve", "convert"] = EasyMethod.TRAIN,
) -> dict:
"""The get_general_dot function is a helper function that returns a q_flax.QDotGeneral object
with the specified number of bits for forward and backward passes. If no bits are specified,
the function returns None.
Args:
bits: tp.Optional[int]: Specify the number of bits for quantization
mode: EasyMethod: Specify the use of model to init the QDot
Method for (e.q TRAIN,SERVE,...)
Returns:
A dict that contain dot_general_cls
"""
if bits is not None:
try:
from aqt.jax.v2 import config as q_config # type: ignore
from aqt.jax.v2.flax import aqt_flax as q_flax # type: ignore
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"No module named `aqt` has been found, please "
"install aqt before using bits option in EasyDeL"
) from e
if mode == EasyMethod.TRAIN:
rhs_quant_mode = q_flax.QuantMode.TRAIN
elif mode == EasyMethod.EVAL or mode == EasyMethod.SERVE:
rhs_quant_mode = q_flax.QuantMode.SERVE
elif mode == EasyMethod.CONVERT:
rhs_quant_mode = q_flax.QuantMode.CONVERT
else:
raise ValueError("Unknown Quant Method for EasyMethod")
return {
"dot_general_cls": functools.partial(
q_flax.AqtDotGeneral,
cfg=q_config.fully_quantized(fwd_bits=bits, bwd_bits=bits),
rhs_quant_mode=rhs_quant_mode,
)
}
return {} # empty just in case of not getting any error
[docs]def block_wise_ffn(remat_ffn, inputs, chunk_size: int):
generating = inputs.shape[1] == 1
try:
if generating:
return remat_ffn(inputs)
else:
return rearrange(
jax.lax.scan(
f=lambda carry, idx: (carry.at[:, idx].set(remat_ffn(carry[:, idx])), None),
init=rearrange(inputs, "b (c n) d -> b c n d", c=chunk_size),
xs=jax.numpy.arange(chunk_size),
length=chunk_size,
unroll=True,
)[0],
"b c n d -> b (c n) d",
)
except Exception as e:
raise EasyDeLBlockWiseFFNError(
"You Are using BlockWise FFN from near-infinite-context length paper and you might be passing "
"input arguments in wrong way in case that you don'position_ids want to use this just pass `use_scan_mlp=False` in "
"model config or in config_kwargs in AutoEasyDeLModelFor... or change `scan_mlp_chunk_size` "
f"in configs for more information read Docs.\nOriginal Error\n{e}"
) from e
[docs]def control_mlp_sharding(
x: jax.Array,
partition_axis: PartitionAxis,
force_spec: PartitionSpec = None,
):
"""
handles MLP Shardings
"""
sqax = (
partition_axis.sequence_axis
if x.shape[1] != 1
else partition_axis.generation_query_sequence_axis
)
force_spec = force_spec or PartitionSpec(
partition_axis.batch_axis,
sqax,
partition_axis.hidden_state_axis,
)
x = with_sharding_constraint(x, sharding=force_spec)
return x
[docs]def is_flatten(pytree: dict):
"""The is_flatten function checks if the pytree is flattened.
If it is, then the first key in the dictionary will be a tuple of (mpl, mpl_id).
Otherwise, it will be an integer representing mpl_id.
Args:
pytree: dict: Pass the pytree to the function
Returns:
True if the pytree is a flattened tree, and false otherwise
"""
mpl = [k for k in pytree.keys()][0]
return True if isinstance(mpl, tuple) else False
[docs]def quantize_linear_layers(
model: nn.Module,
/,
*,
method: tp.Optional[EasyDeLQuantizationMethods] = None,
block_size: int = 256,
quantization_pattern: tp.Optional[str] = None,
verbose: bool = True,
) -> nn.Module:
"""
Quantize parameters to requested precision, excluding specified layers.
Args:
model: The model to quantize.
method (EasyDeLQuantizationMethods): quantization method for params.
quantization_pattern (str): re pattern for layers to be quantized.
verbose (bool): whenever to use tqdm for logging stuff.
Returns:
Quantized parameters in the same structure as the input.
"""
if method == EasyDeLQuantizationMethods.NONE or method is None:
return model
from easydel.layers.quantization import Linear8bit, LinearNF4
from easydel.utils.graph_utils import (
get_module_from_path,
iter_module_search,
set_module_from_path,
)
quantizer: Linear8bit = {
EasyDeLQuantizationMethods.NF4: LinearNF4,
EasyDeLQuantizationMethods.A8BIT: Linear8bit,
EasyDeLQuantizationMethods.A4Q: LinearNF4,
EasyDeLQuantizationMethods.A8Q: Linear8bit,
}.get(method, None)
if quantizer is None:
raise NotImplementedError("Requested Quantizer is not Supported")
if quantization_pattern is None:
quantization_pattern = ".*"
if hasattr(model, "config"):
model.config.quantization_method = method
model.config.quantization_block_size = block_size
model.config.quantization_pattern = quantization_pattern
pattern = re.compile(quantization_pattern)
with tqdm(
total=len([p[0] for p in iter_module_search(model, ParallelLinear)]),
desc=f"Quantizing to {method}",
disable=not verbose,
) as pbar:
for path, _ in iter_module_search(model, ParallelLinear):
if pattern.search(".".join([str(p) for p in path])):
set_module_from_path(
model=model,
path=path,
new_value=quantizer.from_linear(
linear=get_module_from_path(model=model, path=path),
rngs=None,
block_size=block_size,
),
)
pbar.update(1)
return model
[docs]def apply_lora_to_layers(
model: nn.Module,
/,
*,
lora_rank: int,
lora_pattern: tp.Optional[str] = None,
verbose: bool = True,
rngs: tp.Optional[nn.Rngs] = None,
) -> nn.Module:
"""
Applies LoRA (Low-Rank Adaptation) to specified linear layers within a model.
Args:
model: The EasyDeL model to modify.
lora_rank: The rank of the LoRA adapters.
lora_pattern: A regular expression pattern to match the names of
modules to which LoRA should be applied. Defaults to ".*" (all linear layers).
verbose: Whether to display a progress bar.
rngs: A `flax.nnx.Rngs` instance for random number generation. If None, initializes with a seed of 0.
Returns:
The modified model with LoRA applied to the specified layers.
"""
from easydel.utils.graph_utils import (
get_module_from_path,
iter_module_search,
set_module_from_path,
)
if not (lora_rank > 0):
raise ValueError("lora_rank should be a positive value and higher than `0`.")
if lora_pattern is None:
lora_pattern = ".*"
if rngs is None:
rngs = nn.Rngs(0)
pattern = re.compile(lora_pattern)
with tqdm(
total=len([p[0] for p in iter_module_search(model, ParallelLinear)]),
desc="Applying LoRA",
disable=not verbose,
) as pbar:
for path, _ in iter_module_search(model, ParallelLinear):
if pattern.search(".".join([str(p) for p in path])):
base_module: ParallelLinear = get_module_from_path(model=model, path=path)
set_module_from_path(
model=model,
path=path,
new_value=nn.LoRA(
base_module=base_module,
rngs=rngs,
dtype=base_module.dtype,
param_dtype=base_module.param_dtype,
in_features=base_module.in_features,
lora_rank=lora_rank,
out_features=base_module.out_features,
),
)
pbar.update(1)
return model
[docs]def split_lora_params(model: nn.Module) -> nn.Module:
"""
get LoRA (Low-Rank Adaptation) from layers within a model.
Args:
model: The EasyDeL model.
Returns:
LoRA Layer Weights.
"""
from easydel.utils.graph_utils import get_module_from_path, iter_module_search
od = {}
with tqdm(
total=len([p[0] for p in iter_module_search(model, nn.LoRA)]),
desc="Split LoRA Params",
) as pbar:
for path, _ in iter_module_search(model, nn.LoRA):
base_module: nn.LoRA = get_module_from_path(model=model, path=path)
od.update({path: {"lora_a": base_module.lora_a, "lora_b": base_module.lora_b}})
pbar.update(1)
return unflatten_dict(od)
[docs]def merge_lora_params(model: nn.Module, lora_tree: tp.Dict) -> nn.Module:
"""
get LoRA (Low-Rank Adaptation) from layers within a model.
Args:
model: The EasyDeL model.
Returns:
LoRA Layer Weights.
"""
from easydel.utils.graph_utils import get_module_from_path, iter_module_search
if not is_flatten(lora_tree):
lora_tree = flatten_dict(lora_tree)
with tqdm(
total=len([p[0] for p in iter_module_search(model, nn.LoRA)]),
desc="Merge LoRA Params",
) as pbar:
for path, _ in iter_module_search(model, nn.LoRA):
base_module: nn.LoRA = get_module_from_path(model=model, path=path)
base_module.lora_b = lora_tree[path + ("lora_b",)]
base_module.lora_a = lora_tree[path + ("lora_a",)]
pbar.update(1)
return model
[docs]def unwrap_lora_to_layers(
model: nn.Module,
/,
*,
verbose: bool = True,
) -> nn.Module:
"""
UnWrap LoRA (Low-Rank Adaptation) from specified linear layers within a model.
"""
from easydel.utils.graph_utils import (
get_module_from_path,
iter_module_search,
set_module_from_path,
)
with tqdm(
total=len([p[0] for p in iter_module_search(model, ParallelLinear)]),
desc="Unwarping LoRA Layers",
disable=not verbose,
) as pbar:
for path, _ in iter_module_search(model, nn.LoRA):
base_module: nn.LoRA = get_module_from_path(model=model, path=path)
with jax.default_matmul_precision("float32"):
base_module.base_module.kernel.value = (
base_module.base_module.kernel.value
+ base_module.lora_a.value @ base_module.lora_b.value
)
del base_module.lora_a, base_module.lora_b
set_module_from_path(
model=model,
path=path,
new_value=base_module.base_module,
)
pbar.update(1)
return model
[docs]def apply_sparsity_to_params(
params: tp.Union[tp.Dict[str, tp.Any], tp.Any],
sparsify_module: AVAILABLE_SPARSE_MODULE_TYPES = "bcoo",
verbose: bool = True,
) -> tp.Union[tp.Dict[str, tp.Any], tp.Any]:
its_frozen = isinstance(params, flax.core.FrozenDict)
flatten = is_flatten(params)
if not flatten:
params = flatten_dict(params)
from jax.experimental import sparse
sparser = {
"bcoo": sparse.BCOO,
"bcsr": sparse.BCSR,
"coo": sparse.COO,
"csr": sparse.CSR,
}.get(sparsify_module, None)
assert sparser is not None, f"unkown type of sparser {sparsify_module}"
def filter_params(path, array):
layer_name = ".".join(path[0].key)
if layer_name.endswith("kernel") and 4 > array.ndim > 1:
array = sparser.fromdense(array)
return array
total_params = len(jax.tree_util.tree_leaves(params))
with tqdm(
total=total_params,
desc=f"{sparsify_module.capitalize()}",
disable=not verbose,
) as pbar:
def _with_progress(path, array):
pbar.set_postfix_str(".".join(path[0].key))
result = filter_params(path, array)
pbar.update(1)
return result
params = jax.tree_util.tree_map_with_path(_with_progress, params)
if not flatten:
params = unflatten_dict(params)
if its_frozen:
return flax.core.FrozenDict(params)
return params
M = tp.TypeVar("M")
[docs]def auto_remat(
*modules: tp.Type[M],
policy: tp.Union[
EasyDeLGradientCheckPointers, str
] = EasyDeLGradientCheckPointers.NONE,
prevent_cse: bool = True,
) -> tp.Tuple[tp.Type[M], ...]:
if policy == EasyDeLGradientCheckPointers.NONE:
return modules
if isinstance(policy, str):
policy = get_gradient_checkpoint_policy(policy)
outs = ()
for module in modules:
assert issubclass(module, nn.Module)
static_argnums = extract_static_parameters(module=module)
if static_argnums is None:
static_argnums = ()
module.__call__ = nn.remat(
f=module.__call__,
prevent_cse=prevent_cse,
static_argnums=static_argnums,
policy=policy,
)
outs += (module,)
return outs
# Main FLOP counting function
[docs]def count_flop_jaxpr(jaxpr) -> int:
"""Count flops in a Jaxpr."""
def get_shape_size(shape) -> int:
"""Calculate total size of an array shape."""
return int(np.prod(shape)) if shape else 1
def compute_binary_op_flops(eqn) -> int:
"""Generic FLOP counter for binary operations with broadcasting."""
shape0 = eqn.invars[0].aval.shape
shape1 = eqn.invars[1].aval.shape
output_shape = np.broadcast_shapes(shape0, shape1)
return get_shape_size(output_shape)
def compute_unary_op_flops(eqn) -> int:
"""FLOP counter for unary operations."""
shape = eqn.invars[0].aval.shape
return get_shape_size(shape)
def compute_dot_general_flops(eqn) -> int:
"""Compute FLOPs for dot_general operation."""
shapes = [var.aval.shape for var in eqn.invars]
if len(shapes) != 2:
return 0
dimension_numbers = eqn.params.get("dimension_numbers", None)
if not dimension_numbers:
return 0
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
# Calculate sizes for contracting dimensions
contracting_size = np.prod([shapes[0][d] for d in lhs_contract])
# Calculate output shape size
batch_size = np.prod([shapes[0][d] for d in lhs_batch])
lhs_remaining = [
d for i, d in enumerate(shapes[0]) if i not in lhs_contract and i not in lhs_batch
]
rhs_remaining = [
d for i, d in enumerate(shapes[1]) if i not in rhs_contract and i not in rhs_batch
]
out_size = batch_size * np.prod(lhs_remaining) * np.prod(rhs_remaining)
# Each output element requires 2*contracting_size - 1 operations
return out_size * (2 * contracting_size - 1)
def compute_conv_flops(eqn) -> int:
"""Compute FLOPs for convolution operation."""
lhs_shape = eqn.invars[0].aval.shape
rhs_shape = eqn.invars[1].aval.shape
dimension_numbers = eqn.params.get("dimension_numbers", None)
if not dimension_numbers:
return 0
lhs_spec, rhs_spec, out_spec = dimension_numbers
batch_size = lhs_shape[lhs_spec.index("N")]
in_channels = lhs_shape[lhs_spec.index("C")]
out_channels = rhs_shape[rhs_spec.index("O")]
spatial_size = 1
kernel_size = 1
for d in range(len(lhs_spec) - 2):
spatial_size *= lhs_shape[lhs_spec.index(str(d))]
kernel_size *= rhs_shape[rhs_spec.index(str(d))]
ops_per_point = 2 * kernel_size * in_channels - 1
total_points = batch_size * spatial_size * out_channels
return ops_per_point * total_points
def compute_reduce_flops(eqn) -> int:
"""Compute FLOPs for reduction operations."""
shape = eqn.invars[0].aval.shape
reduced_axes = eqn.params.get("axes", ())
if not reduced_axes:
return 0
reduced_size = np.prod([shape[ax] for ax in reduced_axes])
remaining_shape = [s for i, s in enumerate(shape) if i not in reduced_axes]
remaining_size = np.prod(remaining_shape) if remaining_shape else 1
return remaining_size * (reduced_size - 1)
def compute_attention_flops(eqn) -> int:
"""Compute FLOPs for attention operation."""
q_shape = eqn.invars[0].aval.shape
k_shape = eqn.invars[1].aval.shape
batch, q_len, num_heads, head_dim = q_shape
_, kv_len, _, _ = k_shape
qk_flops = batch * num_heads * q_len * kv_len * (2 * head_dim - 1)
softmax_flops = batch * num_heads * q_len * (kv_len + (kv_len - 1) + 1)
av_flops = batch * num_heads * q_len * head_dim * (2 * kv_len - 1)
return qk_flops + softmax_flops + av_flops
def count_scan_flops(eqn) -> int:
"""Count FLOPs in a scan operation."""
scan_jaxpr = eqn.params.get("jaxpr", None)
if scan_jaxpr:
body_flops = count_flop_jaxpr(scan_jaxpr)
length = eqn.invars[0].aval.shape[0]
return body_flops * length
return 0
def count_cond_flops(eqn) -> int:
"""Count FLOPs in a conditional operation."""
true_jaxpr = eqn.params.get("true_jaxpr", None)
false_jaxpr = eqn.params.get("false_jaxpr", None)
total_flops = 0
if true_jaxpr:
total_flops += count_flop_jaxpr(true_jaxpr)
if false_jaxpr:
total_flops += count_flop_jaxpr(false_jaxpr)
return total_flops // 2
def get_scatter_flops(eqn) -> int:
"""Count FLOPs in a scatter operation."""
updates_shape = eqn.invars[2].aval.shape
return get_shape_size(updates_shape)
def compute_select_n_flops(eqn) -> int:
"""Compute FLOPs for select_n operation."""
pred_shape = eqn.invars[0].aval.shape
return get_shape_size(pred_shape)
def compute_cumsum_flops(eqn) -> int:
"""Compute FLOPs for cumulative sum."""
shape = eqn.invars[0].aval.shape
axis = eqn.params.get("axis", 0)
# Each element adds to the previous sum
return get_shape_size(shape) - shape[axis]
def compute_max_flops(eqn) -> int:
"""Compute FLOPs for max operation."""
if len(eqn.invars) == 2:
# Binary max
return compute_binary_op_flops(eqn)
# Unary max
return compute_unary_op_flops(eqn)
def compute_pow_flops(eqn) -> int:
"""Compute FLOPs for power operation."""
if len(eqn.invars) == 2:
shape0 = eqn.invars[0].aval.shape
shape1 = eqn.invars[1].aval.shape
output_shape = np.broadcast_shapes(shape0, shape1)
return 8 * get_shape_size(output_shape) # Power is expensive
return 8 * get_shape_size(eqn.invars[0].aval.shape)
def compute_integer_pow_flops(eqn) -> int:
"""Compute FLOPs for integer power."""
shape = eqn.invars[0].aval.shape
power = eqn.params.get("y", 2)
return (power - 1) * get_shape_size(shape)
def compute_and_flops(eqn) -> int:
"""Compute FLOPs for logical and operation."""
return compute_binary_op_flops(eqn)
def count_custom_vjp_flops(eqn) -> int:
"""Count FLOPs in custom VJP operation."""
fwd_jaxpr = eqn.params.get("fun_jaxpr", None)
if fwd_jaxpr:
return count_flop_jaxpr(fwd_jaxpr)
return 0
def compute_sqrt_flops(eqn) -> int:
"""Compute FLOPs for square root operation."""
# Square root is typically more expensive than basic operations
return 4 * compute_unary_op_flops(eqn)
def compute_argmax_flops(eqn) -> int:
"""Compute FLOPs for argmax operation."""
shape = eqn.invars[0].aval.shape
axis = eqn.params.get("axes", (0,))[0]
# For each output element, we need to compare n-1 elements where n is the size of the reduction axis
remaining_size = get_shape_size(shape) // shape[axis]
return remaining_size * (shape[axis] - 1)
def compute_min_flops(eqn) -> int:
"""Compute FLOPs for min operation."""
if len(eqn.invars) == 2:
# Binary min
return compute_binary_op_flops(eqn)
# Unary min
return compute_unary_op_flops(eqn)
def compute_rem_flops(eqn) -> int:
"""Compute FLOPs for remainder operation."""
# Remainder typically involves division and multiplication
return 2 * compute_binary_op_flops(eqn)
def compute_square_flops(eqn) -> int:
"""Compute FLOPs for square operation (x * x)."""
# Square is a single multiplication of a number by itself
shape = eqn.invars[0].aval.shape
return get_shape_size(shape)
def compute_triangular_solve_flops(eqn) -> int:
"""Compute FLOPs for triangular solve operation."""
# For a triangular solve with a matrix of size n x n,
# each row/column requires n^2/2 multiply-adds
matrix_shape = eqn.invars[0].aval.shape
n = matrix_shape[-1] # Size of the last dimension
batch_dims = matrix_shape[:-2]
batch_size = np.prod(batch_dims) if batch_dims else 1
return batch_size * n * (n + 1) * (2 * n + 1) // 6
def compute_erf_inv_flops(eqn) -> int:
"""Compute FLOPs for inverse error function."""
# erf_inv is computationally expensive, typically implemented
# as a series expansion or numerical approximation
return 15 * compute_unary_op_flops(eqn)
def compute_or_flops(eqn) -> int:
"""Compute FLOPs for logical or operation."""
return compute_binary_op_flops(eqn)
def compute_shift_right_logical_flops(eqn) -> int:
"""Compute FLOPs for logical right shift."""
return compute_binary_op_flops(eqn)
# Dictionary mapping primitives to their FLOP counting functions
primitive_flops: tp.Dict[str, tp.Callable] = {
# Binary operations
"mul": compute_binary_op_flops,
"add": compute_binary_op_flops,
"sub": compute_binary_op_flops,
"div": compute_binary_op_flops,
"gt": compute_binary_op_flops,
"lt": compute_binary_op_flops,
"ge": compute_binary_op_flops,
"le": compute_binary_op_flops,
"ne": compute_binary_op_flops,
"eq": compute_binary_op_flops,
# Unary operations
"neg": compute_unary_op_flops,
"sin": lambda eqn: 5 * compute_unary_op_flops(eqn),
"cos": lambda eqn: 5 * compute_unary_op_flops(eqn),
"exp": lambda eqn: 4 * compute_unary_op_flops(eqn),
"log": lambda eqn: 6 * compute_unary_op_flops(eqn),
"log1p": lambda eqn: 6 * compute_unary_op_flops(eqn),
"tanh": lambda eqn: 7 * compute_unary_op_flops(eqn),
"rsqrt": lambda eqn: 6 * compute_unary_op_flops(eqn),
# Linear algebra
"dot_general": compute_dot_general_flops,
"conv_general_dilated": compute_conv_flops,
# Reduction operations
"reduce_sum": compute_reduce_flops,
"reduce_max": compute_reduce_flops,
"reduce_min": compute_reduce_flops,
# Special operations
"scatter-add": get_scatter_flops,
"scan": count_scan_flops,
"cond": count_cond_flops,
# Memory operations (0 FLOPs)
"broadcast_in_dim": lambda eqn: 0,
"reshape": lambda eqn: 0,
"transpose": lambda eqn: 0,
"slice": lambda eqn: 0,
"gather": lambda eqn: 0,
"concatenate": lambda eqn: 0,
"convert_element_type": lambda eqn: 0,
"dynamic_slice": lambda eqn: 0,
"pad": lambda eqn: 0,
# Parallel/Sharding operations (0 FLOPs)
"pjit": lambda eqn: 0,
"shard_map": lambda eqn: 0,
"sharding_constraint": lambda eqn: 0,
# Other operations
"dot_product_attention_fwd_wrapper": compute_attention_flops,
"select_n": compute_select_n_flops,
"cumsum": compute_cumsum_flops,
"max": compute_max_flops,
"iota": lambda eqn: 0, # Memory operation, no FLOPs
"pow": compute_pow_flops,
"integer_pow": compute_integer_pow_flops,
"and": compute_and_flops,
"random_fold_in": lambda eqn: 0, # Random number generation, no FLOPs
"custom_vjp_call_jaxpr": count_custom_vjp_flops,
"logistic": lambda eqn: 4 * compute_unary_op_flops(eqn), # sigmoid function
# No-op operations (0 FLOPs)
"stop_gradient": lambda eqn: 0, # Just passes through the value
"squeeze": lambda eqn: 0, # Reshapes data, no computation
"copy": lambda eqn: 0, # Memory operation only
"split": lambda eqn: 0,
"remat2": lambda eqn: 0,
"random_seed": lambda eqn: 0,
"random_unwrap": lambda eqn: 0,
"random_wrap": lambda eqn: 0,
"random_split": lambda eqn: 0,
"random_bits": lambda eqn: 0,
# Bitwise and type conversion operations
"shift_right_logical": compute_shift_right_logical_flops,
"or": compute_or_flops,
"bitcast_convert_type": lambda eqn: 0, # Type conversion, no computation
# Mathematical operations
"abs": compute_unary_op_flops, # Single comparison/selection per element
"erf_inv": compute_erf_inv_flops, # Inverse error function
"triangular_solve": compute_triangular_solve_flops,
# Computation operations
"square": compute_square_flops,
"sqrt": compute_sqrt_flops,
"argmax": compute_argmax_flops,
"add_any": compute_binary_op_flops, # Similar to regular add
"min": compute_min_flops,
"rem": compute_rem_flops,
}
flops = 0
def visit_jaxpr(jaxpr):
nonlocal flops
for eqn in jaxpr.eqns:
primitive_name = eqn.primitive.name
if primitive_name in primitive_flops:
flops += primitive_flops[primitive_name](eqn)
else:
warnings.warn(f"Unhandled primitive {primitive_name}", stacklevel=1)
# Recursively visit subjaxprs
for subjaxpr in jax.core.jaxprs_in_params(eqn.params):
visit_jaxpr(subjaxpr)
visit_jaxpr(jaxpr)
return flops
[docs]class TraceResult:
def __init__(self, executable):
self._executable = executable
self._cached_cost = None
@property
@lru_cache(maxsize=1) # noqa
def cost_analysis(self):
return self._executable.cost_analysis()
@property
def flops(self):
return self.cost_analysis["flops"]
[docs]class FunctionTracer:
def __init__(self):
self.new_executables: List[TraceResult] = []
self._before: Set = set()
def __getitem__(self, idx):
return self.new_executables[idx]
[docs]class CompilationTracker:
def __init__(self):
self.first_time = True
self.cached_flops = 0
self.functions = None
@property
def online_flops(self):
if self.functions is None:
return 0
cached_flops = 0
for cm in self.functions:
try:
cached_flops += cm.cost_analysis()["flops"]
except Exception:
...
return cached_flops
[docs] @contextmanager
def trace_compilation(self):
if self.first_time:
before = set(jax.lib.xla_bridge.get_backend().live_executables())
yield
after = set(jax.lib.xla_bridge.get_backend().live_executables())
new = after - before
if new:
cmpf = list(new)
self.functions = cmpf
for cm in cmpf:
try:
self.cached_flops += cm.cost_analysis()["flops"]
except Exception:
...
self.first_time = False
else:
yield
[docs]@contextmanager
def trace_functions():
tracer = FunctionTracer()
tracer._before = set(jax.lib.xla_bridge.get_backend().live_executables())
try:
yield tracer
finally:
after = set(jax.lib.xla_bridge.get_backend().live_executables())
new = after - tracer._before
tracer.new_executables = [TraceResult(exe) for exe in new]
[docs]class ModuleCaches(nn.Cache): ...
[docs]class OverWriteWithGradient(nn.Param): ...
if tp.TYPE_CHECKING:
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
)
ProcessingClassType = tp.Optional[
tp.Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
]
]
else:
ProcessingClassType = tp.Any