easydel.kernels.rms_norm
-
easydel.kernels.rms_norm.basic_layer_norm(x: Array, weight: Array, eps: float) → Array [source]
-
easydel.kernels.rms_norm.rms_norm(x: jax.Array, w: jax.Array, blocksize_x: int = 8, eps: float = 1e-05, prod_dtype: numpy.dtype = jax.numpy.float32) [source]
-
easydel.kernels.rms_norm.test_bwd_call() [source]
-
easydel.kernels.rms_norm.test_fwd_call() [source]