easydel.kernels.rms_norm

easydel.kernels.rms_norm

easydel.kernels.rms_norm.basic_layer_norm(x: Array, weight: Array, eps: float) -> Array
easydel.kernels.rms_norm.rms_norm(x: jax.Array, w: jax.Array, blocksize_x: int = 8, eps: float = 1e-05, prod_dtype: numpy.dtype = jax.numpy.float32)
easydel.kernels.rms_norm.test_bwd_call()
easydel.kernels.rms_norm.test_fwd_call()