easydel.layers.quantization.linear_8bit

class easydel.layers.quantization.linear_8bit.Linear8bit(*args: Any, **kwargs: Any)

Bases: QauntModule

An 8-bit quantized version of the linear transformation applied over the last dimension of the input.

classmethod from_linear(linear: Linear, rngs: Optional[Rngs] = None, **kwargs) → Linear8bit

Create a Linear8bit module from a regular Linear module.

Parameters
  • linear – The source Linear module

  • rngs – Random number generator state

Returns

A new Linear8bit module with quantized weights

get_kernel()

Return the kernel weights dequantized from quant_kernel using quant_scales.

get_quantized_kernel()

Return the stored quantized weights (quant_kernel) together with their scales (quant_scales).

static metadata()
static quantization_mapping()
to_linear(rngs: Optional[Rngs] = None) → Linear

Convert this Linear8bit module back to a regular Linear module.

Parameters

rngs – Random number generator state

Returns

A new Linear module with dequantized weights

easydel.layers.quantization.linear_8bit.dequantize_8bit(quants, scales)

Dequantize 8-bit integers back to floating-point values using blockwise scaling.
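The following is a minimal NumPy sketch of blockwise dequantization. It assumes that each contiguous block of values along the last axis shares one floating-point scale and that value = int8 code × scale. The name dequantize_blockwise and the block size of 4 are hypothetical and chosen for illustration. The actual block size and scale layout used by dequantize_8bit are not documented on this page.

```python
import numpy as np

BLOCK_SIZE = 4  # hypothetical block length chosen for this demo

def dequantize_blockwise(quants, scales, block_size=BLOCK_SIZE):
    """Hypothetical stand-in for dequantize_8bit: value = int8 code * block scale."""
    *lead, n = quants.shape
    if n % block_size != 0:
        raise ValueError("last axis must be a multiple of block_size")
    # Split the last axis into (num_blocks, block_size).
    blocks = quants.astype(np.float32).reshape(*lead, n // block_size, block_size)
    # scales has shape (..., num_blocks); broadcast one scale over each block.
    values = blocks * scales[..., None]
    return values.reshape(*lead, n)

# Two blocks of four int8 codes, one float32 scale per block.
q = np.array([[127, -64, 0, 32, 10, -127, 5, 0]], dtype=np.int8)
s = np.array([[0.01, 0.1]], dtype=np.float32)
restored = dequantize_blockwise(q, s)
```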

easydel.layers.quantization.linear_8bit.quantize_8bit(x)

Quantize a row of float32 values to 8-bit integers with blockwise scaling.
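A matching sketch of blockwise quantization follows. It assumes symmetric absmax scaling: each block's scale is max|x| / 127, and codes are rounded and clipped to [-127, 127]. The helper quantize_blockwise and its block_size=4 default are hypothetical, and quantize_8bit may differ in block size, rounding, or scale layout. The last lines dequantize again to measure the round-trip error, which under this scheme is at most half a scale step per block.

```python
import numpy as np

def quantize_blockwise(x, block_size=4):
    """Hypothetical stand-in for quantize_8bit: symmetric absmax int8 per block."""
    x = np.asarray(x, dtype=np.float32)
    *lead, n = x.shape
    if n % block_size != 0:
        raise ValueError("last axis must be a multiple of block_size")
    blocks = x.reshape(*lead, n // block_size, block_size)
    absmax = np.max(np.abs(blocks), axis=-1)  # one value per block
    # Map each block's largest magnitude to 127; all-zero blocks get scale 1.
    scales = np.where(absmax > 0, absmax / 127.0, 1.0).astype(np.float32)
    codes = np.clip(np.round(blocks / scales[..., None]), -127, 127)
    return codes.astype(np.int8).reshape(*lead, n), scales

row = np.array([0.5, -1.0, 0.25, 0.0, 3.0, 1.5, -0.75, 2.0], dtype=np.float32)
q, s = quantize_blockwise(row)
# Dequantize again to measure the round-trip error per element.
restored = (q.reshape(-1, 4).astype(np.float32) * s[:, None]).reshape(-1)
err = np.abs(restored - row)
```

Keeping one scale per small block, rather than one per tensor, limits how much a single outlier can coarsen the quantization step for the other values in the tensor.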

easydel.layers.quantization.linear_8bit.quantized_matmul_bwd(transpose_weight, res, grad_output)

Backward pass computing gradients for the quantized matrix multiplication.

Parameters
  • transpose_weight – Whether the weight matrix was transposed

  • res – Values saved by the forward pass

  • grad_output – Gradient of the loss with respect to the output
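Assume the forward product is y = x · op(Ŵ), where Ŵ is the dequantized weight and op transposes it when transpose_weight is true. Under that assumption, the chain rule gives the input gradient that the backward pass has to produce, with G = grad_output:

```latex
y = x\,\operatorname{op}(\hat W), \qquad
\operatorname{op}(\hat W) =
\begin{cases}
  \hat W^{\top} & \text{if transpose\_weight} \\
  \hat W        & \text{otherwise}
\end{cases}

\frac{\partial L}{\partial x} = G\,\operatorname{op}(\hat W)^{\top},
\qquad G = \frac{\partial L}{\partial y}
```

The only weight-dependent term is Ŵ, which can be rebuilt from the quantized weights and scales. Exactly which values res holds is not stated on this page.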

easydel.layers.quantization.linear_8bit.quantized_matmul_fwd(x, qweight, qscale, transpose_weight)

Forward pass of the quantized matrix multiplication that also saves the values required by the backward pass.
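The fwd/bwd signatures appear to follow the jax.custom_vjp convention. The forward returns (output, residuals), and the backward receives the non-differentiable argument (transpose_weight) first, then the residuals, then the output cotangent. The NumPy sketch below mimics that contract without JAX and checks the hand-written input gradient against central finite differences.

This is a simplified illustration under stated assumptions:
  • matmul_fwd, matmul_bwd, numeric_grad, and the per-row dequant helper are hypothetical simplifications. The real functions presumably use dequantize_8bit.
  • A JAX backward returns a tuple with one cotangent per remaining input, not the input gradient alone.

```python
import numpy as np

def dequant(qweight, qscale):
    # Hypothetical simplification: one scale per weight row, not per block.
    return qweight.astype(np.float64) * qscale[:, None]

def matmul_fwd(x, qweight, qscale, transpose_weight):
    """Forward: y = x @ op(W), returned as (output, residuals) like a custom-VJP fwd."""
    w = dequant(qweight, qscale)
    w = w.T if transpose_weight else w
    # Keep the compact int8 weights and scales as residuals, not the float copy.
    return x @ w, (qweight, qscale)

def matmul_bwd(transpose_weight, residuals, grad_output):
    """Backward: dL/dx = grad_output @ op(W).T; the non-differentiable flag comes first."""
    qweight, qscale = residuals
    w = dequant(qweight, qscale)
    w = w.T if transpose_weight else w
    return grad_output @ w.T

def numeric_grad(x, g, qweight, qscale, transpose_weight, eps=1e-6):
    """Central finite differences of L = sum(y * g) with respect to x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        lp = np.sum(matmul_fwd(xp, qweight, qscale, transpose_weight)[0] * g)
        lm = np.sum(matmul_fwd(xm, qweight, qscale, transpose_weight)[0] * g)
        grad[idx] = (lp - lm) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
qweight = rng.integers(-127, 128, size=(3, 5)).astype(np.int8)
qscale = rng.uniform(0.01, 0.02, size=3)

# transpose_weight=False: x has 3 features, the output has 5.
x = rng.normal(size=(2, 3))
g = rng.normal(size=(2, 5))
y, res = matmul_fwd(x, qweight, qscale, False)
grad_x = matmul_bwd(False, res, g)

# transpose_weight=True: x has 5 features, the output has 3.
x_t = rng.normal(size=(2, 5))
g_t = rng.normal(size=(2, 3))
y_t, res_t = matmul_fwd(x_t, qweight, qscale, True)
grad_x_t = matmul_bwd(True, res_t, g_t)
```

Saving the int8 weights and scales instead of the dequantized float matrix keeps residual memory small. The cost is a second dequantization in the backward pass.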