Source code for easydel.layers.quantization.linear_8bit

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import jax
import jax.numpy as jnp
from eformer.ops.quantization import Array8B
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.nn import initializers
from flax.typing import DotGeneralT, Dtype, Initializer, PrecisionLike
from jax import lax

from .base_quant import QauntModule

# Convenience type aliases used in annotations.
Array = jax.Array
Axis = int
Size = int

# Default initializers for `Linear8bit`: LeCun-normal kernel, all-zeros bias.
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()


class Linear8bit(QauntModule):
    """An 8-bit quantized version of the linear transformation applied over the last dimension of the input.

    Uses eformer's Array8B implicit array for efficient 8-bit quantization.

    The kernel is kept as two parameters, ``quant_kernel`` (the quantized
    weights) and ``quant_scales`` (their scales). It is dequantized on every
    forward call before the matrix multiplication.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        do_init: bool = False,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: rnglib.Rngs,
    ):
        """Build the layer, optionally initializing and quantizing its weights.

        Args:
            in_features: Size of the contracted (last) input dimension.
            out_features: Size of the output dimension.
            use_bias: Whether the layer adds a bias to its output.
            dtype: Computation dtype; when ``None`` inputs and weights are
                not cast in the forward pass.
            param_dtype: Dtype used to initialize the kernel and bias and to
                dequantize the kernel.
            precision: Precision forwarded to ``dot_general``.
            do_init: When ``True`` the kernel (and bias, if ``use_bias``) is
                sampled from ``rngs`` and the kernel is quantized. When
                ``False`` all parameters hold ``None`` until they are
                assigned from outside.
            kernel_init: Initializer for the full-precision kernel.
            bias_init: Initializer for the bias.
            dot_general: Function used for the matrix multiplication.
            rngs: Random number generators; only consumed when ``do_init``.
        """
        super().__init__(
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
        )

        # The kernel key is drawn before the bias key; keep this order so
        # initialization stays reproducible for a given `rngs`.
        if do_init:
            kernel_key = rngs.params()
            kernel = kernel_init(kernel_key, (in_features, out_features), param_dtype)
            quant_kernel, quant_scales = self._quantize_kernel(kernel)
        else:
            quant_kernel, quant_scales = None, None

        self.quant_kernel = nnx.Param(quant_kernel)
        self.quant_scales = nnx.Param(quant_scales)

        if use_bias and do_init:
            bias_key = rngs.params()
            self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            # Placeholder parameter; `__call__` skips the bias while it is None.
            self.bias = nnx.Param(None)

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.dot_general = dot_general

    @classmethod
    def from_linear(
        cls,
        linear: nnx.Linear,
        rngs: rnglib.Rngs | None = None,
        **kwargs,
    ) -> Linear8bit:
        """Create a Linear8bit module from a regular Linear module.

        Args:
            linear: Source layer whose configuration, kernel and bias are
                copied; the kernel is quantized on the way in.
            rngs: Random number generators for the constructor; defaults to
                ``nnx.Rngs(0)``.
            **kwargs: Accepted for call-site compatibility; not used here.

        Returns:
            A ``Linear8bit`` holding the quantized kernel and, when
            ``linear.use_bias`` is set, the source bias.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)

        def _build() -> Linear8bit:
            return cls(
                in_features=linear.in_features,
                out_features=linear.out_features,
                use_bias=linear.use_bias,
                dtype=linear.dtype,
                param_dtype=linear.param_dtype,
                precision=linear.precision,
                kernel_init=linear.kernel_init,
                bias_init=linear.bias_init,
                dot_general=linear.dot_general,
                rngs=rngs,
            )

        # Construct the module through `nnx.eval_shape`; the real parameters
        # are assigned right below.
        instance = nnx.eval_shape(_build)

        quant_kernel, quant_scales = cls._quantize_kernel(linear.kernel.value)
        instance.quant_kernel = nnx.Param(quant_kernel)
        instance.quant_scales = nnx.Param(quant_scales)

        if linear.use_bias:
            instance.bias = nnx.Param(linear.bias.value)

        return instance

    def to_linear(self, rngs: rnglib.Rngs | None = None) -> nnx.Linear:
        """Convert this Linear8bit module back to a regular Linear module.

        Args:
            rngs: Random number generators for the ``nnx.Linear``
                constructor; defaults to ``nnx.Rngs(0)``.

        Returns:
            An ``nnx.Linear`` whose kernel is the dequantized kernel of this
            layer and whose bias (when ``use_bias``) is this layer's bias.
        """
        if rngs is None:
            rngs = nnx.Rngs(0)

        def _build() -> nnx.Linear:
            return nnx.Linear(
                in_features=self.in_features,
                out_features=self.out_features,
                use_bias=self.use_bias,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                dot_general=self.dot_general,
                rngs=rngs,
            )

        linear = nnx.eval_shape(_build)

        dequantized_kernel = self._dequantize_kernel()
        linear.kernel = nnx.Param(dequantized_kernel)

        if self.use_bias:
            linear.bias = nnx.Param(self.bias.value)

        return linear

    @staticmethod
    def _quantize_kernel(kernel):
        """Quantize the kernel weights using eformer's Array8B.

        Returns:
            ``(weight, scale)`` taken from ``Array8B.quantize(kernel)``, or
            ``(None, None)`` when ``kernel`` is ``None`` or only a
            ``jax.ShapeDtypeStruct`` (no concrete values to quantize).
        """
        if kernel is None or isinstance(kernel, jax.ShapeDtypeStruct):
            return None, None
        quantized = Array8B.quantize(kernel)
        return quantized.weight, quantized.scale

    def _dequantize_kernel(self):
        """Dequantize the kernel weights.

        Returns:
            ``None`` when both the quantized kernel and its scales are
            ``None``; the stored kernel unchanged when only the scales are
            ``None``; otherwise the materialized ``Array8B``.
        """
        weight = self.quant_kernel.value
        scale = self.quant_scales.value

        if weight is None and scale is None:
            return None
        if scale is None:
            # No scales stored: hand the kernel back as it is.
            return weight

        quantized = Array8B(
            weight=weight,
            scale=scale,
            shape=(self.in_features, self.out_features),
            dtype=self.param_dtype,
        )
        return quantized.materialize()

    @jax.named_scope("easydel-linear-int8-call")
    def __call__(self, inputs: Array) -> Array:
        """Forward pass using 8-bit quantized weights.

        Args:
            inputs: Array whose last dimension is contracted with the kernel.

        Returns:
            The result of ``dot_general(inputs, kernel)`` plus the bias when
            one is enabled and present.

        Raises:
            AssertionError: If the layer has no kernel (it was built with
                ``do_init=False`` and never given weights).
        """
        kernel = self._dequantize_kernel()
        if kernel is None:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped when Python runs with -O; the exception type is kept.
            raise AssertionError(
                "loaded and dequantized kernel is None, which means it has been loaded from another None Kernel Linear"
            )

        if self.dtype is not None:
            inputs = inputs.astype(self.dtype)
            kernel = kernel.astype(self.dtype)

        # Contract the last axis of `inputs` with the first axis of `kernel`.
        out = self.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )

        if self.use_bias and self.bias.value is not None:
            bias = self.bias.value
            if self.dtype is not None:
                bias = bias.astype(self.dtype)
            # Reshape to (1, ..., 1, -1) so the bias broadcasts over leading axes.
            out = out + jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,))

        return out

    def get_kernel(self):
        """Get the dequantized kernel weights."""
        return self._dequantize_kernel()

    def get_quantized_kernel(self):
        """Get the quantized kernel weights and scales."""
        return self.quant_kernel.value, self.quant_scales.value

    @staticmethod
    def metadata():
        """Return the metadata dict identifying this layer's quantization mode."""
        return {"quant_mode": "int8"}

    @staticmethod
    def quantization_mapping():
        """Return the mapping from ``"kernel"`` to this layer's quantized parameter names."""
        return {"kernel": ["quant_kernel", "quant_scales"]}