# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.nn import initializers
from flax.typing import (
DotGeneralT,
Dtype,
Initializer,
PrecisionLike,
)
from jax import lax
from .base_quant import QauntModule
# Type aliases. `Array` annotates `Linear8bit.__call__`; `Axis` and `Size` are
# not referenced by the code visible in this module.
Array = jax.Array
Axis = int
Size = int
# Default initializers used as the `kernel_init` / `bias_init` defaults of
# `Linear8bit.__init__`: LeCun-normal for the kernel, zeros for the bias.
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()
def quantize_8bit(x):
    """
    Quantize an array to int8 with one scale per slice along the last axis.

    The scale of each slice is ``max(|slice|) / 127`` (the maximum is floored
    at ``1e-5`` so an all-zero slice does not produce a zero scale).

    Args:
        x: Array to quantize; the last axis is the one sharing a scale.

    Returns:
        A ``(qweight, qscale)`` tuple: ``qweight`` is int8 with the shape of
        ``x``; ``qscale`` has the last axis reduced to size 1 and the dtype
        of ``x``.
    """
    # The absolute maximum is taken in float32 regardless of the input dtype.
    magnitudes = jnp.abs(x.astype(jnp.float32))
    row_peak = jnp.maximum(jnp.amax(magnitudes, axis=-1, keepdims=True), 1e-5)
    scale = row_peak / 127

    # Multiply by the reciprocal, round to nearest, then clamp to int8 range.
    inv_scale = 1.0 / scale
    rounded = jnp.round(x * inv_scale)
    quantized = jnp.clip(rounded, -128, 127).astype(jnp.int8)

    return quantized, scale.astype(x.dtype)
def dequantize_8bit(quants, scales):
    """
    Rebuild approximate values from int8 quants and their scales.

    Args:
        quants: Quantized values.
        scales: Scales, broadcast against ``quants`` by the multiplication.

    Returns:
        The elementwise product ``quants * scales``.
    """
    return quants * scales
def _dequantized_operand(qweight, qscale, transpose_weight):
    """Return the float matrix that the input is right-multiplied by."""
    weight = dequantize_8bit(qweight, qscale)
    return weight.T if transpose_weight else weight


@partial(jax.custom_vjp, nondiff_argnums=(3, 4))
def quantized_matmul(x, qweight, qscale, transpose_weight=False, smt=True):
    """
    Matrix-multiply ``x`` by an 8-bit quantized weight.

    Computes ``x @ W`` where ``W = dequantize_8bit(qweight, qscale)``,
    transposed first when ``transpose_weight`` is true.

    Args:
        x: Input array whose last axis is contracted with ``W``.
        qweight: int8 quantized weight matrix.
        qscale: Scales for ``qweight``.
        transpose_weight: Non-differentiable flag; transpose the dequantized
            weight before multiplying.
        smt: Non-differentiable flag kept for interface compatibility; it is
            not used by the computation.

    Returns:
        The product ``x @ W``.
    """
    return jnp.matmul(x, _dequantized_operand(qweight, qscale, transpose_weight))


def quantized_matmul_fwd(x, qweight, qscale, transpose_weight=False, smt=True):
    """
    Forward rule of the custom VJP.

    JAX calls the forward rule with exactly the arguments of the primal
    function (non-differentiable ones included), so the signature mirrors
    ``quantized_matmul``.

    Returns:
        ``(out, residuals)`` where ``residuals = (x, qweight, qscale)``.
    """
    out = jnp.matmul(x, _dequantized_operand(qweight, qscale, transpose_weight))
    # Only the compact inputs are saved; the float weight is rebuilt in the
    # backward rule instead of being kept alive as a residual.
    return out, (x, qweight, qscale)


def quantized_matmul_bwd(transpose_weight, smt, res, grad_output):
    """
    Backward rule of the custom VJP.

    JAX passes the non-differentiable arguments first (in ``nondiff_argnums``
    order), then the residuals and the output cotangent.

    Args:
        transpose_weight: Whether the weight was transposed in the forward pass.
        smt: Unused; present because it is a non-differentiable primal argument.
        res: Residuals ``(x, qweight, qscale)`` saved by the forward rule.
        grad_output: Cotangent of the output.

    Returns:
        A 3-tuple of cotangents for ``(x, qweight, qscale)``. The ``qweight``
        entry is ``None`` because an integer input carries no gradient.
    """
    x, qweight, qscale = res
    weight = _dequantized_operand(qweight, qscale, transpose_weight)

    # out = x @ W  =>  dL/dx = dL/dout @ W^T, for either setting of the flag,
    # because ``weight`` is already the matrix that was actually multiplied.
    grad_x = jnp.matmul(grad_output, weight.T).astype(x.dtype)

    # dL/dW = x^T @ dL/dout, with any leading batch axes folded into rows.
    x_rows = x.reshape(-1, x.shape[-1])
    grad_rows = grad_output.reshape(-1, grad_output.shape[-1])
    grad_weight = jnp.matmul(x_rows.T, grad_rows)
    if transpose_weight:
        # Map the gradient back to the layout in which the weight is stored.
        grad_weight = grad_weight.T

    # W = qweight * qscale  =>  dL/dqscale = sum_j dL/dW[i, j] * qweight[i, j].
    # Assumes one scale per row with a trailing axis of size 1, which is the
    # layout ``quantize_8bit`` produces.
    grad_qscale = jnp.sum(
        grad_weight * qweight.astype(grad_weight.dtype),
        axis=-1,
        keepdims=True,
    ).astype(qscale.dtype)

    return grad_x, None, grad_qscale


# Register the custom VJP (Vector-Jacobian Product) rules
quantized_matmul.defvjp(quantized_matmul_fwd, quantized_matmul_bwd)
class Linear8bit(QauntModule):
    """An 8-bit quantized version of the linear transformation applied over the last dimension of the input.

    The weight is stored as two parameters: ``quant_kernel`` (int8 values)
    and ``quant_scales`` (scales), both produced by ``quantize_8bit``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        do_init: bool = False,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: rnglib.Rngs,
    ):
        """
        Build the layer, optionally initializing and quantizing a fresh kernel.

        Args:
            in_features: Size of the input's last axis.
            out_features: Size of the output's last axis.
            use_bias: Whether ``__call__`` adds a bias.
            dtype: Stored on the module and forwarded to the base class.
            param_dtype: Dtype used to initialize parameters and to cast the
                dequantized kernel.
            precision: Stored on the module and forwarded to the base class.
            do_init: When true, sample a kernel (and a bias if ``use_bias``)
                and quantize it; otherwise the parameters hold ``None``
                placeholders to be filled in later.
            kernel_init: Initializer for the kernel.
            bias_init: Initializer for the bias.
            dot_general: Stored so the layer can be converted back to
                ``nnx.Linear``.
            rngs: Source of the ``params`` RNG keys.
        """
        super().__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
        )
        if do_init:
            kernel_key = rngs.params()
            quant_kernel = kernel_init(kernel_key, (in_features, out_features), param_dtype)
            quantized_kernel, quant_scales = self._quantize_kernel(quant_kernel)
        else:
            quantized_kernel, quant_scales = None, None

        self.quant_kernel = nnx.Param(quantized_kernel)
        self.quant_scales = nnx.Param(quant_scales)

        # A bias is only materialized when initialization was requested.
        if use_bias and do_init:
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            self.bias = nnx.Param(None)

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.dot_general = dot_general

    @classmethod
    def from_linear(
        cls,
        linear: nnx.Linear,
        rngs: tp.Optional[rnglib.Rngs] = None,
        **kwargs,
    ) -> "Linear8bit":
        """
        Create a Linear8bit module from a regular Linear module.

        Args:
            linear: The source Linear module.
            rngs: Random number generator state; defaults to ``nnx.Rngs(0)``.
            **kwargs: Accepted but not used by this method.

        Returns:
            A new Linear8bit module with quantized weights.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Build the module under eval_shape so no parameters are allocated.
        instance = nnx.eval_shape(
            lambda: cls(
                in_features=linear.in_features,
                out_features=linear.out_features,
                use_bias=linear.use_bias,
                dtype=linear.dtype,
                param_dtype=linear.param_dtype,
                precision=linear.precision,
                kernel_init=linear.kernel_init,
                bias_init=linear.bias_init,
                dot_general=linear.dot_general,
                rngs=rngs,
            )
        )

        # Quantize the kernel of the original layer and install it.
        quantized_kernel, quant_scales = cls._quantize_kernel(linear.kernel.value)
        instance.quant_kernel = nnx.Param(quantized_kernel)
        instance.quant_scales = nnx.Param(quant_scales)

        # Copy the bias if it exists.
        if linear.use_bias:
            instance.bias = nnx.Param(linear.bias.value)

        return instance

    def to_linear(self, rngs: tp.Optional[rnglib.Rngs] = None) -> nnx.Linear:
        """
        Convert this Linear8bit module back to a regular Linear module.

        Args:
            rngs: Random number generator state; defaults to ``nnx.Rngs(0)``.

        Returns:
            A new Linear module whose ``kernel`` holds the dequantized weights.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)

        # Build the module under eval_shape so no parameters are allocated.
        linear = nnx.eval_shape(
            lambda: nnx.Linear(
                in_features=self.in_features,
                out_features=self.out_features,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                dot_general=self.dot_general,
                rngs=rngs,
            )
        )

        # nnx.Linear keeps its weight in ``kernel``; writing anywhere else
        # would leave the abstract placeholder from eval_shape in place.
        dequantized_kernel = self._dequantize_kernel()
        linear.kernel = nnx.Param(dequantized_kernel)

        # Copy the bias if it exists.
        if self.use_bias:
            linear.bias = nnx.Param(self.bias.value)

        return linear

    @staticmethod
    def _quantize_kernel(quant_kernel):
        """Quantize a kernel; ``None`` or an abstract shape yields ``(None, None)``."""
        if quant_kernel is None or isinstance(quant_kernel, jax.ShapeDtypeStruct):
            return None, None
        quantized, quant_scales = quantize_8bit(quant_kernel)
        return quantized, quant_scales

    def _dequantize_kernel(self):  # in case somebody using tie word embedding.
        """Dequantize the stored kernel.

        Returns:
            ``None`` when nothing is stored, the raw kernel value when no
            scales are stored, otherwise the dequantized kernel cast to
            ``param_dtype``.
        """
        kernel = self.quant_kernel.value
        scales = self.quant_scales.value
        if kernel is None and scales is None:
            return None
        if scales is None:
            # Return the underlying value so every branch yields an array
            # (or None) rather than the Param wrapper.
            return kernel
        return dequantize_8bit(kernel, scales).astype(self.param_dtype)

    @jax.named_scope("easydel-linear-8bit-call")
    def __call__(self, inputs: Array) -> Array:
        """Apply the layer: quantized matmul, plus the bias when ``use_bias`` is set."""
        out = quantized_matmul(
            inputs,
            self.quant_kernel.value,
            self.quant_scales.value,
            transpose_weight=False,
        )
        if self.use_bias:
            out = out + self.bias.value
        return out

    def get_kernel(self):
        """Get the dequantized kernel weights."""
        return self._dequantize_kernel()

    def get_quantized_kernel(self):
        """Get the quantized kernel weights and their scales."""
        return self.quant_kernel.value, self.quant_scales.value

    @staticmethod
    def quantization_mapping():
        """Map the unquantized parameter name to the parameters replacing it."""
        return {"kernel": ["quant_kernel", "quant_scales"]}