Skip to content

Commit

Permalink
fix: refactor quantizer
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Dec 23, 2024
1 parent 08a9038 commit 2a326f3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 25 deletions.
7 changes: 5 additions & 2 deletions src/concrete/ml/quantization/linear_op_glwe_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from ..common.utils import HybridFHEMode, to_tuple
from .quantizers import TorchUniformQuantizer
from .quantized_module import QuantizedModule


Expand Down Expand Up @@ -98,8 +99,10 @@ def forward(

q_weight = numpy.transpose(q_weight) if transpose_inputs2 else q_weight

q_x = q_module.quantize_input(
x, dtype=numpy.float32 if fhe == HybridFHEMode.DISABLE else None
quantizer = TorchUniformQuantizer(q_module.input_quantizers[0])

q_x = quantizer.quant(
x, dtype=torch.float32 if fhe == HybridFHEMode.DISABLE else None
)
q_x = torch.transpose(q_x) if transpose_inputs1 else q_x

Expand Down
10 changes: 4 additions & 6 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def _fhe_forward(
return q_results

def quantize_input(
self, *x: Optional[Union[numpy.ndarray]], dtype=numpy.int64
self, *x: Optional[numpy.ndarray], dtype=numpy.int64
) -> Union[numpy.ndarray, Tuple[Optional[numpy.ndarray], ...]]:
"""Take the inputs in fp32 and quantize it using the learned quantization parameters.
Expand Down Expand Up @@ -750,8 +750,8 @@ def quantize_input(
return q_x

def dequantize_output(
self, *q_y_preds: Union[numpy.ndarray]
) -> Union[Union[numpy.ndarray], Tuple[Union[numpy.ndarray], ...]]:
self, *q_y_preds: numpy.ndarray
) -> Union[numpy.ndarray, Tuple[Union[numpy.ndarray], ...]]:
"""Take the last layer q_out and use its de-quant function.
Args:
Expand All @@ -768,12 +768,10 @@ def dequantize_output(
)

y_preds = tuple(
output_quantizer.dequant(q_y_pred)
numpy.array(output_quantizer.dequant(q_y_pred))
for q_y_pred, output_quantizer in zip(q_y_preds, self.output_quantizers)
)

y_preds = tuple(map(numpy.array, y_preds))

if len(y_preds) == 1:
return y_preds[0]

Expand Down
49 changes: 32 additions & 17 deletions src/concrete/ml/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,12 +689,10 @@ def quant(self, values: numpy.ndarray, dtype=numpy.int64) -> numpy.ndarray:

assert dtype in (numpy.int64, numpy.int32, numpy.float32, numpy.float64)

delta = 0.5 if QUANT_ROUND_LIKE_ROUND_PBS else 0

round_func = numpy.floor if QUANT_ROUND_LIKE_ROUND_PBS else numpy.rint
clip_func = numpy.clip

qvalues = round_func(values / self.scale + self.zero_point + delta)
if QUANT_ROUND_LIKE_ROUND_PBS:
qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5) # pragma: no cover
else:
qvalues = numpy.rint(values / self.scale + self.zero_point)

# Clipping must be performed for PTQ and for precomputed (for now only Brevitas) QAT
# (where quantizer parameters are available in ONNX layers).
Expand All @@ -710,7 +708,7 @@ def quant(self, values: numpy.ndarray, dtype=numpy.int64) -> numpy.ndarray:
if self.is_narrow:
min_value += 1

qvalues = clip_func(qvalues, min_value, 2 ** (self.n_bits) - 1 - self.offset)
qvalues = qvalues.clip(min_value, 2 ** (self.n_bits) - 1 - self.offset)

qvalues = qvalues.astype(dtype)

Expand Down Expand Up @@ -738,17 +736,18 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]
+ ((" " + str(self.scale.dtype)) if isinstance(self.scale, numpy.ndarray) else ""),
)

prepared_zp = numpy.asarray(self.zero_point, dtype=numpy.float64)
if isinstance(qvalues, torch.Tensor):
prepared_zp = torch.from_numpy(prepared_zp).float().to(qvalues.device)

values = self.scale * (qvalues - prepared_zp)
values = self.scale * (qvalues - numpy.asarray(self.zero_point, dtype=numpy.float64))

assert isinstance(values, (float, numpy.ndarray, torch.Tensor, Tracer)), f"{values=}, {type(values)=}"
assert isinstance(values, (float, numpy.ndarray, Tracer)), f"{values=}, {type(values)=}"
return values

class TorchUniformQuantizer:
    """Uniform quantizer operating on ``torch.Tensor`` values.

    Wraps a numpy ``UniformQuantizer`` and applies its learned quantization
    parameters (scale, zero-point, bit-width, clipping range) to torch
    tensors, keeping the computation on the tensor's device.
    """

    # Wrapped numpy quantizer providing scale / zero_point / n_bits / offset.
    _numpy_quantizer: "UniformQuantizer"

    def __init__(self, quantizer: "UniformQuantizer"):
        """Store the numpy quantizer whose parameters drive (de)quantization.

        Args:
            quantizer (UniformQuantizer): source of quantization parameters.
        """
        self._numpy_quantizer = quantizer

    def quant(self, values: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Quantize values onto the integer grid of the wrapped quantizer.

        Args:
            values (torch.Tensor): float values to quantize.
            dtype (Optional[torch.dtype]): if given, cast the result to this
                dtype; otherwise the result keeps the float dtype produced
                by ``torch.round``.

        Returns:
            torch.Tensor: quantized values.
        """
        quantizer = self._numpy_quantizer

        qvalues = torch.round(values / quantizer.scale + quantizer.zero_point)

        # Clipping must be performed for PTQ and for precomputed (for now only
        # Brevitas) QAT, mirroring UniformQuantizer.quant.
        if not quantizer.no_clipping:
            min_value = -quantizer.offset
            if quantizer.is_narrow:
                min_value += 1

            # FIX: was `self.n_bits`, which no longer exists since this class
            # stopped subclassing UniformQuantizer -- read it from the
            # wrapped quantizer instead.
            max_value = 2 ** quantizer.n_bits - 1 - quantizer.offset
            qvalues = torch.clip(qvalues, min_value, max_value)

        if dtype is not None:
            qvalues = qvalues.type(dtype)

        return qvalues

    def dequant(self, qvalues: torch.Tensor) -> torch.Tensor:
        """De-quantize values back to floats.

        Args:
            qvalues (torch.Tensor): quantized integer-grid values.

        Returns:
            torch.Tensor: de-quantized float values.
        """
        quantizer = self._numpy_quantizer
        return quantizer.scale * (qvalues - quantizer.zero_point)

class QuantizedArray:
"""Abstraction of quantized array.
Expand Down

0 comments on commit 2a326f3

Please sign in to comment.