# Source code for easydel.layers.quantization.quantizers

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import re
import typing as tp

import chex
import jax
import tqdm
from eformer.ops.quantization import Array8B, ArrayNF4
from flax import nnx as nn

from easydel.infra.etils import EasyDeLPlatforms, EasyDeLQuantizationMethods

from .linear_8bit import Linear8bit
from .linear_nf4 import LinearNF4

# Sub-module names that make up the default pattern, kept in the exact order
# in which they appear in the resulting alternation.
# NOTE(review): these look like attention/MLP projection layer names from
# several model families -- confirm against the supported architectures.
_DEFAULT_QUANTIZATION_TARGETS = (
	"wo",
	"wq",
	"wk",
	"wv",
	"q_proj",
	"k_proj",
	"v_proj",
	"o_proj",
	"w1",
	"w2",
	"w3",
	"gate_proj",
	"up_proj",
	"down_proj",
	"dense_4h_to_h",
	"dense_h_to_4h",
	"query_key_value",
	"wqkv",
	"Wqkv",
	"dense",
	"proj_1",
	"proj_2",
	"out_proj",
	"qkv_proj",
)

# Regex consisting of a single group that alternates over the names above.
DEFAULT_QUANTIZATION_PATTERN = "(" + "|".join(_DEFAULT_QUANTIZATION_TARGETS) + ")"

# Maps a quantization method to the linear-layer class that replaces
# ``nn.Linear`` modules in ``EasyQuantizer.quantize_linears``. A method that
# has no entry here makes ``quantize_linears`` raise ``NotImplementedError``.
METHOD_TO_LINEAR_MAPPING = {
	EasyDeLQuantizationMethods.NF4: LinearNF4,
	EasyDeLQuantizationMethods.A8BIT: Linear8bit,
}


[docs]class EasyQuantizer: def __init__( self, quantization_method: EasyDeLQuantizationMethods = EasyDeLQuantizationMethods.NF4, quantization_platform: tp.Optional[EasyDeLPlatforms] = EasyDeLPlatforms.JAX, quantization_pattern: tp.Optional[str] = None, block_size: int = 256, **kwargs, ) -> None: self.block_size = block_size self.quantization_method = quantization_method self.quantization_platform = quantization_platform if quantization_pattern is None: quantization_pattern = r".*(?:embedding|layernorm|norm)$" self.quantization_pattern = quantization_pattern @jax.named_scope("easydel-easyquantize-call") def __call__( self, array, path: tp.Optional[tp.Union[str, tp.Tuple[str]]] = None, ) -> chex.Array: should_be_quantized = True if path is not None: if isinstance(path, list): path = tuple(path) if isinstance(path, tuple): path = map(str, path) path = ".".join(path) if self.quantization_pattern is not None: should_be_quantized = not bool(re.match(self.quantization_pattern, path)) if not should_be_quantized: return array match self.quantization_method: case EasyDeLQuantizationMethods.A8BIT: return Array8B.quantize(array=array) case EasyDeLQuantizationMethods.NF4: should_be_quantized = True if array.size % self.block_size != 0: should_be_quantized = False if should_be_quantized: return ArrayNF4.quantize(array=array, block_size=self.block_size) return array case EasyDeLQuantizationMethods.NONE: return array case None: return array case _: raise ValueError(f"unknown quantization_method {self.quantization_method}.")
[docs] def quantize_linears( self, model: nn.Module, /, *, quantization_pattern: tp.Optional[str] = None, verbose: bool = True, ) -> nn.Module: """ Quantize parameters to requested precision, excluding specified layers. Args: model: The model to quantize. quantization_pattern (str): re pattern for layers to be quantized. verbose (bool): whenever to use tqdm for logging stuff. Returns: Quantized parameters in the same structure as the input. """ if ( self.quantization_method == EasyDeLQuantizationMethods.NONE or self.quantization_method is None ): return model from easydel.utils.graph_utils import ( get_module_from_path, iter_module_search, set_module_from_path, ) quantizer = METHOD_TO_LINEAR_MAPPING.get(self.quantization_method, None) if quantizer is None: raise NotImplementedError("Requested Quantizer is not Supported") if quantization_pattern is None: quantization_pattern = self.quantization_pattern if hasattr(model, "config"): model.config.quantization_method = self.quantization_method model.config.quantization_block_size = self.block_size model.config.quantization_pattern = quantization_pattern pattern = re.compile(quantization_pattern) with tqdm.tqdm( total=len([p[0] for p in iter_module_search(model, nn.Linear)]), desc=f"Quantizing to {self.quantization_method}", disable=not verbose, ) as pbar: for path, _ in iter_module_search(model, nn.Linear): if pattern.search(".".join([str(p) for p in path])): set_module_from_path( model=model, path=path, new_value=quantizer.from_linear( linear=get_module_from_path(model=model, path=path), rngs=None, block_size=self.block_size, ), ) pbar.update(1) return model