# Source code for easydel.kernels.gpu_ops._utils
import typing as tp
import chex
import numpy
from jax import numpy as jnp
# Type variable bound to any callable; used in `safe_autotune`'s return
# annotation to express that the decorator maps a callable type to itself.
F = tp.TypeVar("F", bound=tp.Callable[..., tp.Any])
def safe_autotune(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=None,
    rep=None,
    use_cuda_graph=False,
    do_bench=None,
) -> tp.Callable[[F], F]:
    """Build a decorator that wraps a kernel in Triton's ``Autotuner`` when possible.

    The returned decorator is best-effort: if Triton cannot be imported, or if
    constructing the ``Autotuner`` for a given function raises, the function is
    returned unchanged instead of propagating the error.

    Args:
        configs: Candidate configurations, passed to ``Autotuner`` as-is.
        key: Autotuning key, passed to ``Autotuner`` as-is.
        prune_configs_by: Forwarded to ``Autotuner`` as a keyword argument.
        reset_to_zero: Forwarded to ``Autotuner`` positionally.
        restore_value: Forwarded to ``Autotuner`` positionally.
        pre_hook: Forwarded to ``Autotuner`` as a keyword argument.
        post_hook: Forwarded to ``Autotuner`` as a keyword argument.
        warmup: Forwarded to ``Autotuner`` as a keyword argument.
        rep: Forwarded to ``Autotuner`` as a keyword argument.
        use_cuda_graph: Forwarded to ``Autotuner`` as a keyword argument.
        do_bench: Optional benchmarking callable. Forwarded to ``Autotuner``
            only when it is not ``None``.

    Returns:
        A decorator that returns either an ``Autotuner`` wrapping the function
        or, on any failure, the original function itself.
    """
    try:
        from triton.runtime.autotuner import Autotuner
    except Exception as err:
        # Triton is missing or failed to initialise: report once and hand back
        # an identity decorator so decorated kernels are left untouched.
        print(f"Couldn't autotune given function due to {err}")

        def decorator(fn):
            return fn

        return decorator

    # Only pass `do_bench` when the caller actually set it, so the default call
    # stays valid on Triton versions whose `Autotuner` lacks that keyword.
    # NOTE(review): on such a version an explicit `do_bench` would raise inside
    # the constructor and trigger the silent fallback below — confirm the
    # minimum supported Triton release.
    optional_kwargs = {}
    if do_bench is not None:
        optional_kwargs["do_bench"] = do_bench

    def decorator(fn):
        try:
            return Autotuner(
                fn,
                fn.arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook=pre_hook,
                post_hook=post_hook,
                prune_configs_by=prune_configs_by,
                warmup=warmup,
                rep=rep,
                use_cuda_graph=use_cuda_graph,
                **optional_kwargs,
            )
        except Exception:
            # Deliberate best-effort: any failure while building the autotuner
            # (e.g. `fn` has no `arg_names`) falls back to the plain function.
            return fn

    return decorator
def dtype_index(x: jnp.array) -> int:
    """Map an array's floating-point dtype to a small integer code.

    Args:
        x: Object exposing a ``dtype`` attribute.

    Returns:
        ``1`` for ``float16``, ``2`` for ``bfloat16``, ``3`` for ``float32``.

    Raises:
        ValueError: If the dtype is none of the three above; the offending
            dtype is the exception argument.
    """
    dt = x.dtype
    # Codes are the 1-based position in this tuple.
    known_dtypes = (jnp.float16, jnp.bfloat16, jnp.float32)
    for code, candidate in enumerate(known_dtypes, start=1):
        if dt == candidate:
            return code
    raise ValueError(dt)
def get_sharding(arr: chex.Array):
    """Return the ``sharding`` attribute of ``arr`` if it has one.

    Args:
        arr: Object to inspect.

    Returns:
        The value of ``arr.sharding``, or ``None`` when the attribute is absent.
    """
    try:
        found = arr.sharding
    except AttributeError:
        # No such attribute on this object: report "no sharding".
        found = None
    return found
def get_strides(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Compute element strides for a row-major (last axis fastest) layout.

    The stride of axis ``i`` is the product of all dimensions after ``i``, so
    the last axis always has stride ``1``.

    Args:
        shape: Sequence of dimension sizes, or any object with a ``shape``
            attribute (in which case that attribute is used).

    Returns:
        Tuple of Python ``int`` strides, one per dimension, measured in
        elements. An empty shape yields an empty tuple.
    """
    if hasattr(shape, "shape"):
        shape = shape.shape
    # Accumulate the product of trailing dimensions from right to left. This
    # avoids dividing by a dimension size, which fails for zero-length axes,
    # and stays in arbitrary-precision Python ints.
    strides = []
    trailing = 1
    for dim in reversed(tuple(shape)):
        strides.append(trailing)
        trailing *= int(dim)
    strides.reverse()
    return tuple(strides)