easydel.layers.norms
Normalization layers for neural networks.
Provides efficient normalization layers optimized for JAX/Flax, with support for mixed precision and float8 data types.
- Classes:
    RMSNorm: Root Mean Square normalization layer
- Constants:
    float8s: List of supported float8 data types
- Key Features:
    Efficient RMS normalization
    Support for float8 quantization
    Mixed precision computation
    Automatic dtype promotion
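The float8 and dtype-promotion features listed above can be illustrated with a minimal sketch. It is not the module's actual code: `FLOAT8_DTYPES` is a hypothetical stand-in for the `float8s` constant (whose real contents are not shown here), and `upcast_if_float8` is an illustrative helper that assumes float8 inputs are widened before any arithmetic.

```python
import jax.numpy as jnp

# Hypothetical stand-in for the module's `float8s` constant;
# the real list may contain different or additional dtypes.
FLOAT8_DTYPES = [jnp.float8_e4m3fn, jnp.float8_e5m2]


def upcast_if_float8(x, compute_dtype=jnp.float32):
    """Promote float8 arrays to a wider dtype; leave other dtypes untouched."""
    if any(x.dtype == jnp.dtype(d) for d in FLOAT8_DTYPES):
        return x.astype(compute_dtype)
    return x


x8 = jnp.array([0.5, 1.0, 2.0], dtype=jnp.float8_e4m3fn)
x32 = upcast_if_float8(x8)  # float8 input is widened to float32
xbf = upcast_if_float8(jnp.ones((3,), dtype=jnp.bfloat16))  # bfloat16 passes through
```

Widening before the reduction matters because float8 formats carry only a few mantissa bits, so summing squares directly in float8 would lose most of the signal.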
Example
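>>> import jax.numpy as jnp
>>> from easydel.layers.norms import RMSNorm
>>> norm = RMSNorm(
...     dim=768,
...     eps=1e-6,
...     dtype=jnp.bfloat16,
... )
>>> inputs = jnp.ones((2, 16, 768), dtype=jnp.bfloat16)  # (batch, seq, dim)
>>> normalized = norm(inputs)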
Note
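RMSNorm is well suited to large language models: it learns only a per-feature scale, whereas LayerNorm usually learns both a scale and a bias, and it skips mean-centering. It therefore needs fewer parameters and less computation while providing similar normalization benefits.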
- class easydel.layers.norms.RMSNorm(*args: Any, **kwargs: Any)
Bases: Module
Root Mean Square normalization layer.
RMSNorm normalizes inputs by their root mean square value, providing a simpler and more efficient alternative to LayerNorm.
- dim
Dimension of the input features.
- eps
Small constant for numerical stability.
- dtype
Data type for computations.
- param_dtype
Data type for parameters.
- kernel
Learnable scale parameters.
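For orientation, the attributes above map onto the standard RMSNorm computation roughly as follows. This is a minimal reference sketch in plain JAX, not the class's actual implementation: the function name `rms_norm`, the float32 accumulation, and the order of the scale and cast steps are assumptions, and some model families use a slightly different scale convention.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, kernel, eps=1e-6, dtype=jnp.bfloat16):
    """Illustrative RMS normalization over the last axis."""
    # Accumulate in float32 so the mean of squares keeps its precision.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    # Divide by the root mean square; eps guards against division by zero.
    normed = x32 * jax.lax.rsqrt(mean_sq + eps)
    # Apply the learnable per-feature scale, then cast to the compute dtype.
    return (normed * kernel.astype(jnp.float32)).astype(dtype)


dim = 768
kernel = jnp.ones((dim,), dtype=jnp.float32)  # same values a ones initializer gives
inputs = jax.random.normal(jax.random.key(0), (2, 4, dim), dtype=jnp.float32)
normalized = rms_norm(inputs, kernel)
```

With a kernel of ones, each output vector has a root mean square close to 1, up to bfloat16 rounding.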
- static kernel_init(key: Array, shape: Sequence[Union[int, Any]], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None = None) -> Array
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)