easydel.layers.norms

Normalization layers for neural networks.

Provides efficient normalization layers optimized for JAX/Flax, with support for mixed precision and float8 data types.

Classes:

RMSNorm: Root Mean Square normalization layer

Constants:

float8s: List of supported float8 data types

Key Features:
  • Efficient RMS normalization

  • Support for float8 quantization

  • Mixed precision computation

  • Automatic dtype promotion
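The mixed-precision and float8 features above can be pictured with the following sketch. It assumes one plausible policy: float8 inputs are upcast to float32, and other dtypes are promoted to at least float32 before the mean square is taken. The result is then cast back to the requested dtype. `FLOAT8S`, `compute_dtype`, and `rms_norm_mixed` are hypothetical names, not part of easydel, and the module's real `float8s` list may contain different entries.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the module's `float8s` constant.
FLOAT8S = [jnp.float8_e4m3fn, jnp.float8_e5m2]


def compute_dtype(dtype):
    """Precision used for the RMS statistics (assumed policy)."""
    if dtype in FLOAT8S:
        # float8 inputs are upcast to float32 before squaring and averaging.
        return jnp.float32
    # Other dtypes are promoted to at least float32 (e.g. bfloat16 -> float32).
    return jnp.promote_types(dtype, jnp.float32)


def rms_norm_mixed(x, scale, dtype, eps=1e-6):
    """Normalize the last axis in a promoted precision, then cast to `dtype`."""
    cdt = compute_dtype(dtype)
    xc = x.astype(cdt)
    ms = jnp.mean(jnp.square(xc), axis=-1, keepdims=True)
    y = xc / jnp.sqrt(ms + eps) * scale.astype(cdt)
    return y.astype(dtype)
```

Doing the reduction in float32 avoids the precision loss that summing many squared bfloat16 values would otherwise incur, while the output stays in the cheaper dtype.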

Example

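>>> import jax.numpy as jnp
>>> from easydel.layers.norms import RMSNorm
>>> norm = RMSNorm(
...     dim=768,
...     eps=1e-6,
...     dtype=jnp.bfloat16
... )
>>> inputs = jnp.ones((2, 16, 768), dtype=jnp.bfloat16)  # (batch, seq_len, dim)
>>> normalized = norm(inputs)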

Note

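RMSNorm is well suited to large language models: it skips mean-centering and has no bias term, so it learns only a per-feature scale (dim parameters, versus 2 * dim for a LayerNorm with scale and bias) and does less work per call, while normalizing comparably in practice.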
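To make the note concrete, here is a plain-JAX comparison. It is a sketch that assumes LayerNorm carries both a scale and a bias while RMSNorm carries only a scale. It counts parameters for `dim = 768` and checks that, for inputs whose features already have zero mean, the two normalizations agree (both using an identity affine transform). The functions are illustrative, not easydel's implementation.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-6):
    # LayerNorm: subtract the mean, divide by the standard deviation, scale and shift.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias


def rms_norm(x, scale, eps=1e-6):
    # RMSNorm: no mean subtraction and no bias; divide by the RMS, then scale.
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps) * scale


dim = 768
ln_params = {"scale": jnp.ones(dim), "bias": jnp.zeros(dim)}
rms_params = {"scale": jnp.ones(dim)}
n_ln = sum(p.size for p in jax.tree_util.tree_leaves(ln_params))    # 2 * dim
n_rms = sum(p.size for p in jax.tree_util.tree_leaves(rms_params))  # dim

# For zero-mean inputs, variance equals mean square, so both layers match.
x = jax.random.normal(jax.random.key(0), (4, dim))
x = x - jnp.mean(x, axis=-1, keepdims=True)
same = jnp.allclose(layer_norm(x, **ln_params), rms_norm(x, **rms_params), atol=1e-5)
```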

class easydel.layers.norms.RMSNorm(*args: Any, **kwargs: Any)

Bases: Module

Root Mean Square normalization layer.

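RMSNorm divides each input vector by the root mean square of its features and multiplies the result by a learned per-feature scale. Unlike LayerNorm, it neither subtracts the mean nor adds a bias, which makes it a simpler and cheaper alternative.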
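In its common formulation, and assuming normalization over the last axis with ε placed inside the square root, the layer computes the following for an input vector x of size d = dim, learned scale g (kernel), and ε = eps:

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}\; g_i,
\qquad i = 1, \dots, d
```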

dim

Dimension of the input features.

eps

Small constant added to the mean square before taking the square root, so the division stays finite for all-zero inputs.
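The role of eps shows up on an all-zero feature vector, where the mean square is 0. The sketch below uses a hypothetical helper, `rms_normalize`, not easydel's code. It compares eps = 0, which yields 0 / 0 = NaN, with a small positive eps, which keeps the output finite.

```python
import jax.numpy as jnp


def rms_normalize(x, eps):
    # Divide by sqrt(mean(x^2) + eps) along the feature (last) axis.
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps)


zeros = jnp.zeros((1, 4), dtype=jnp.float32)
without_eps = rms_normalize(zeros, eps=0.0)  # 0 / sqrt(0) = 0 / 0 -> nan
with_eps = rms_normalize(zeros, eps=1e-6)    # 0 / 1e-3 -> 0
```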

dtype

Data type for computations.

param_dtype

Data type for parameters.
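A common reason to keep param_dtype separate from dtype is to store the scale in higher precision while producing activations in a cheaper format. The sketch below assumes float32 storage, float32 statistics, and bfloat16 output. The variable names are illustrative, and the cast order is one plausible choice, not necessarily easydel's.

```python
import jax.numpy as jnp

dim = 8
param_dtype = jnp.float32  # assumed storage precision of the learnable scale
dtype = jnp.bfloat16       # assumed precision of the layer's output

kernel = jnp.ones((dim,), dtype=param_dtype)  # parameter stored in param_dtype
x = jnp.arange(1.0, dim + 1.0, dtype=jnp.bfloat16).reshape(1, dim)

# Compute the statistics in float32, then cast both result and kernel to `dtype`.
xf = x.astype(jnp.float32)
normed = xf / jnp.sqrt(jnp.mean(jnp.square(xf), axis=-1, keepdims=True) + 1e-6)
out = normed.astype(dtype) * kernel.astype(dtype)
```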

kernel

Learnable scale parameters.

__call__()

Apply RMS normalization to the input along its last (feature) axis.
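Calling the layer normalizes every feature vector independently. The functional stand-in below is a hypothetical helper, `rms_norm`, that assumes normalization over the last axis. It applies the same computation to a (batch, seq_len, dim) tensor and checks that, with a unit kernel, each output vector has an RMS close to 1 regardless of the input's scale.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, kernel, eps=1e-6):
    """Hypothetical functional equivalent of calling the layer."""
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(ms + eps) * kernel


batch, seq, dim = 2, 5, 16
x = 10.0 * jax.random.normal(jax.random.key(0), (batch, seq, dim))
y = rms_norm(x, jnp.ones((dim,)))

# With a unit kernel, every (batch, position) vector comes out with RMS near 1.
row_rms = jnp.sqrt(jnp.mean(jnp.square(y), axis=-1))
```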

static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
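Because the initializer ignores its key, any key yields the same all-ones kernel, so a freshly initialized layer applies pure normalization with no rescaling. The sketch assumes the kernel has shape (dim,) and is built with jax.nn.initializers.ones, as documented above. It does not construct an easydel RMSNorm.

```python
import jax
import jax.numpy as jnp

dim = 768
init = jax.nn.initializers.ones  # the initializer kernel_init documents

# The key is ignored, so two different keys give identical all-ones kernels.
k1 = init(jax.random.key(0), (dim,), jnp.float32)
k2 = init(jax.random.key(1), (dim,), jnp.float32)

# With an all-ones kernel, the scale step leaves normalized activations unchanged.
x = jax.random.normal(jax.random.key(2), (3, dim))
normed = x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
scaled = normed * k1
```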