# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
DotGeneralT,
Dtype,
Initializer,
PrecisionLike,
)
from jax import lax
from .base_quant import QauntModule
# Type aliases used in the annotations of this module.
Array = jax.Array
Axis = Size = int
# Default initializers for a freshly created (full-precision) kernel and bias.
default_bias_init = initializers.zeros_init()
default_kernel_init = initializers.lecun_normal()
@partial(jax.jit, static_argnames=["block_size"])
def single_quantize_and_pack_nf4(blocks, block_size=64):
    """Quantize an array to NF4 codes and pack two codes per byte.

    Normalization, quantization and packing are fused into one jitted function.
    The input is flattened into rows of ``block_size`` elements. No padding is
    applied, so ``blocks.size`` must be a multiple of ``block_size``, and the
    total number of elements must be even, because codes are packed in pairs.

    Args:
        blocks: Array of values to quantize (any shape).
        block_size: Number of consecutive elements sharing one absmax scale.

    Returns:
        A tuple ``(packed, absmax)``. ``packed`` is a 1-D ``uint8`` array holding
        two 4-bit codes per byte, with the first code in the high nibble.
        ``absmax`` holds the per-block absolute maximum used as the scale.
    """
    # Reshape into blocks (no padding: the size must divide evenly).
    blocks = blocks.reshape(-1, block_size)
    # Per-block absolute maximum; this is the scale returned to the caller.
    absmax = jnp.max(jnp.abs(blocks), axis=1)
    # An all-zero block has absmax == 0. Divide by 1 instead to avoid 0/0 = NaN.
    # Its values then normalize to 0 (the exact NF4 zero code) and still
    # dequantize to 0, because the returned scale stays 0.
    safe_absmax = jnp.where(absmax == 0, jnp.ones_like(absmax), absmax)
    normalized = blocks / safe_absmax[:, None]
    # The boundaries are the midpoints between adjacent NF4 levels, preceded by a
    # -inf sentinel. searchsorted(...) - 1 therefore yields the index of the
    # nearest level, in [0, 15].
    quantized = (
        jnp.searchsorted(
            jnp.array(
                [
                    -float("inf"),
                    -0.8480964004993439,
                    -0.6106329262256622,
                    -0.4599952697753906,
                    -0.33967943489551544,
                    -0.23460740596055984,
                    -0.13791173323988914,
                    -0.045525018125772476,
                    0.03979014977812767,
                    0.1202552504837513,
                    0.2035212516784668,
                    0.2920137718319893,
                    0.3893125355243683,
                    0.5016634166240692,
                    0.6427869200706482,
                    0.8614784181118011,
                ],
                dtype=jnp.float32,
            ),
            normalized.reshape(-1),
        )
        - 1
    )
    # Pack pairs of 4-bit codes into one byte: first code in the high nibble.
    quantized = quantized.reshape(-1, 2)
    packed = (quantized[:, 0] << 4) | quantized[:, 1]
    return packed.astype(jnp.uint8), absmax
@partial(jax.jit, static_argnames=["block_size"])
def single_dequantize_nf4(packed_values, absmax, block_size):
    """Unpack NF4 codes and rescale them back to real values.

    Args:
        packed_values: Bytes that each hold two 4-bit NF4 codes, high nibble
            first. Stacking on axis 1 before flattening interleaves the nibbles
            correctly only for a 1-D array.
        absmax: Per-block scales, one entry per block.
        block_size: Number of values in each block.

    Returns:
        Array of shape ``(len(absmax), block_size)`` with the dequantized values.
    """
    nf4_levels = jnp.array(
        [
            -1.0,
            -0.6961928009986877,
            -0.5250730514526367,
            -0.39491748809814453,
            -0.28444138169288635,
            -0.18477343022823334,
            -0.09105003625154495,
            0.0,
            0.07958029955625534,
            0.16093020141124725,
            0.24611230194568634,
            0.33791524171829224,
            0.44070982933044434,
            0.5626170039176941,
            0.7229568362236023,
            1.0,
        ],
        dtype=jnp.float32,
    )
    # Split each byte into its two codes, keeping the high nibble first.
    high_nibbles = (packed_values >> 4) & 0xF
    low_nibbles = packed_values & 0xF
    codes = jnp.stack([high_nibbles, low_nibbles], axis=1).reshape(-1)
    # Look up each code's level, group into blocks, then apply each block's scale.
    per_block = nf4_levels[codes].reshape(len(absmax), block_size)
    return per_block * absmax[:, None]
@partial(jax.jit, static_argnames=["block_size"])
def quantize_and_pack_nf4(blocks, block_size=64):
    """Quantize and pack ``blocks`` to NF4, batching over leading axes.

    ``single_quantize_and_pack_nf4`` flattens and quantizes inputs with at most
    two axes in one go. Higher-rank inputs are mapped over their leading axis
    (recursively), so each slice along that axis is quantized independently
    and the outputs gain that leading axis.

    Returns:
        ``(packed, absmax)`` as produced by ``single_quantize_and_pack_nf4``,
        stacked along any mapped leading axes.
    """
    if blocks.ndim <= 2:
        return single_quantize_and_pack_nf4(blocks, block_size)
    per_slice = jax.vmap(quantize_and_pack_nf4, in_axes=(0, None), out_axes=(0, 0))
    return per_slice(blocks, block_size)
@partial(jax.jit, static_argnames=["block_size"])
def dequantize_nf4(packed_values, absmax, block_size):
    """Dequantize NF4-packed values, batching over leading axes.

    This mirrors ``quantize_and_pack_nf4``. That function turns a rank-1 or
    rank-2 input into 1-D packed values, and an input of rank ``r > 2`` into
    packed values of rank ``r - 1`` (one leading axis per vmap level). Any
    packed array with more than one axis is therefore mapped over its leading
    axis until the 1-D case handled by ``single_dequantize_nf4`` is reached.

    Args:
        packed_values: Packed NF4 codes (two 4-bit codes per byte).
        absmax: Per-block scales with the same leading axes as ``packed_values``.
        block_size: Number of values per quantization block.

    Returns:
        Dequantized values of shape ``(*leading_axes, num_blocks, block_size)``.
    """
    # The threshold is `> 1`, not `> 2`: quantization drops one rank, so 2-D
    # packed data (from a 3-D input) must still be split per leading slice.
    # `single_dequantize_nf4` only interleaves nibbles correctly for 1-D input.
    if packed_values.ndim > 1:
        # The mapped function returns a single array, so out_axes must be a
        # plain 0; a 1-tuple is not a valid tree prefix for a lone array.
        return jax.vmap(dequantize_nf4, in_axes=(0, 0, None), out_axes=0)(
            packed_values, absmax, block_size
        )
    return single_dequantize_nf4(packed_values, absmax, block_size)
class LinearNF4(QauntModule):
    """A 4-bit quantized version of the linear transformation using NF4 quantization.

    The kernel is stored as packed NF4 codes (``quant_kernel``) together with
    per-block absolute-max scales (``quant_scales``). It is dequantized on
    every call before the matrix product.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        do_init: bool = False,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: rnglib.Rngs,
        block_size: int = 64,
    ):
        """Create the layer.

        Args:
            in_features: Size of the input (contracted) dimension.
            out_features: Size of the output dimension.
            use_bias: Whether the layer adds a bias.
            dtype: Computation dtype passed to ``dtypes.promote_dtype``.
            param_dtype: Dtype used when initializing the kernel and bias.
            precision: Precision passed to ``dot_general``.
            do_init: If True, initialize, quantize and store a kernel (and the
                bias when ``use_bias`` is set). Otherwise the parameters start
                as ``None``; presumably a loader or ``from_linear`` fills them in.
            kernel_init: Initializer for the full-precision kernel.
            bias_init: Initializer for the bias.
            dot_general: Contraction function used in ``__call__``.
            rngs: RNG streams; ``rngs.params()`` is drawn only when ``do_init``.
            block_size: Number of kernel elements sharing one NF4 scale.
        """
        super().__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
        )
        # Initialize the quant_kernel
        if do_init:
            kernel_key = rngs.params()
            kernel = kernel_init(kernel_key, (in_features, out_features), param_dtype)
            # `_quantize_kernel` is a staticmethod and needs the block size explicitly.
            quant_kernel, quant_scales = self._quantize_kernel(kernel, block_size)
        else:
            quant_kernel, quant_scales = None, None
        self.quant_kernel = nnx.Param(quant_kernel)
        self.quant_scales = nnx.Param(quant_scales)
        if use_bias and do_init:
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            self.bias = nnx.Param(None)
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.dot_general = dot_general
        self.block_size = block_size

    @classmethod
    def from_linear(
        cls,
        linear: nnx.Linear,
        rngs: tp.Optional[rnglib.Rngs] = None,
        block_size: int = 128,
        **kwargs,
    ) -> "LinearNF4":
        """Build a ``LinearNF4`` by quantizing an existing ``nnx.Linear``.

        The module skeleton is created under ``nnx.eval_shape``, so no real
        initialization runs. The quantized kernel, its scales and the bias
        (if ``linear`` has one) are then attached.

        Args:
            linear: Source layer; its ``kernel`` and ``bias`` values are read.
            rngs: Only used to build the abstract skeleton; defaults to
                ``nnx.Rngs(0)``.
            block_size: Number of kernel elements sharing one NF4 scale.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The quantized layer.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)
        instance = nnx.eval_shape(
            lambda: cls(
                in_features=linear.in_features,
                out_features=linear.out_features,
                use_bias=linear.use_bias,
                dtype=linear.dtype,
                param_dtype=linear.param_dtype,
                precision=linear.precision,
                kernel_init=linear.kernel_init,
                bias_init=linear.bias_init,
                dot_general=linear.dot_general,
                block_size=block_size,
                rngs=rngs,
            )
        )
        quant_kernel, quant_scales = cls._quantize_kernel(linear.kernel.value, block_size)
        instance.quant_kernel = nnx.Param(quant_kernel)
        instance.quant_scales = nnx.Param(quant_scales)
        if linear.use_bias:
            instance.bias = nnx.Param(linear.bias.value)
        return instance

    def to_linear(self, rngs: tp.Optional[rnglib.Rngs] = None) -> nnx.Linear:
        """Convert this layer back to a full-precision ``nnx.Linear``.

        Args:
            rngs: Only used to build the abstract ``nnx.Linear`` skeleton;
                defaults to ``nnx.Rngs(0)``.

        Returns:
            An ``nnx.Linear`` whose ``kernel`` is the dequantized kernel and
            whose ``bias`` (when ``use_bias``) is copied from this layer.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)
        linear = nnx.eval_shape(
            lambda: nnx.Linear(
                in_features=self.in_features,
                out_features=self.out_features,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                dot_general=self.dot_general,
                rngs=rngs,
            )
        )
        dequantized_kernel = self._dequantize_kernel()
        # nnx.Linear keeps its weights in `kernel` (see `from_linear`). Writing
        # to any other attribute would leave the abstract eval_shape kernel in place.
        linear.kernel = nnx.Param(dequantized_kernel)
        if self.use_bias:
            linear.bias = nnx.Param(self.bias.value)
        return linear

    @staticmethod
    def _quantize_kernel(quant_kernel, block_size):
        """Quantize the quant_kernel weights using NF4.

        Returns ``(None, None)`` for a missing or abstract (``ShapeDtypeStruct``)
        kernel. Otherwise returns ``(packed, scales)`` from ``quantize_and_pack_nf4``.
        """
        if quant_kernel is None or isinstance(quant_kernel, jax.ShapeDtypeStruct):
            return None, None
        return quantize_and_pack_nf4(quant_kernel, block_size)

    def _dequantize_kernel(self):  # in case someone's using tie word embedding.
        """Dequantize the quant_kernel weights from NF4.

        Returns:
            ``None`` when no kernel has been loaded. The raw kernel array when it
            was stored without NF4 scales and without a block size (i.e.
            unquantized). Otherwise the dequantized kernel, reshaped to
            ``(in_features, out_features)``.
        """
        kernel = self.quant_kernel.value
        scales = self.quant_scales.value
        if kernel is None:
            # Nothing loaded yet; `__call__` turns this into an explicit error
            # instead of failing deep inside `dequantize_nf4`.
            return None
        if scales is None and self.block_size is None:
            # Unquantized kernel: return the array itself, not the Param
            # wrapper, so every branch yields the same kind of value.
            return kernel
        return dequantize_nf4(
            kernel,
            scales,
            self.block_size,
        ).reshape(self.in_features, self.out_features)

    @jax.named_scope("easydel-linear-nf4-call")
    def __call__(self, inputs: Array) -> Array:
        """Applies a quantized linear transformation to the inputs along the last dimension.

        Args:
            inputs: Input array whose last axis has size ``in_features``.

        Returns:
            ``inputs @ kernel (+ bias)``, computed in the dtype chosen by
            ``dtypes.promote_dtype``.
        """
        quant_kernel = self._dequantize_kernel()
        assert quant_kernel is not None, (
            "loaded and dequantized quant_kernel is None, which means it have been loaded from another None Kernel Linear"
        )
        bias = self.bias.value
        inputs, quant_kernel, bias = dtypes.promote_dtype(
            (inputs, quant_kernel, bias), dtype=self.dtype
        )
        # Contract the last axis of `inputs` with axis 0 of the kernel.
        y = self.dot_general(
            inputs,
            quant_kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        assert self.use_bias == (bias is not None)
        if bias is not None:
            # Broadcast the bias over every leading axis of the output.
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y

    def get_kernel(self):
        """Get the dequantized quant_kernel weights."""
        return self._dequantize_kernel()

    def get_quantized_kernel(self):
        """Get the quantized quant_kernel weights and quant_scales."""
        return self.quant_kernel.value, self.quant_scales.value

    @staticmethod
    def quantization_mapping():
        """Map the original ``kernel`` name to the attributes that replace it here."""
        return {"kernel": ["quant_kernel", "quant_scales", "block_size"]}