fix: full gpu hybrid model #963

Merged · 16 commits · Jan 6, 2025
4 changes: 2 additions & 2 deletions .github/workflows/run_one_use_cases_example.yaml
@@ -70,7 +70,7 @@ jobs:
needs: [start-runner-linux]
runs-on: ${{ needs.start-runner-linux.outputs.label-38 }}
container:
image: ubuntu:20.04
image: ubuntu:22.04
defaults:
run:
shell: bash
@@ -96,7 +96,7 @@ jobs:
$(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null
apt-get update
apt-get install -y docker-ce docker-ce-cli containerd.io docker-compose-plugin
apt-get install -y python3-venv make git git-lfs binutils
apt-get install -y python3-venv make git git-lfs binutils python3-pip

- name: Checkout Code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
2 changes: 1 addition & 1 deletion docs/deep-learning/lora_training.md
@@ -46,7 +46,7 @@ class SimpleMLP(nn.Module):
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)

def forward(self, x):
def forward(self, x, labels=None):
"""Forward pass of the MLP."""
out = self.fc1(x)
out = self.relu(out)
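The documentation change above only widens `SimpleMLP.forward` to accept an optional `labels` argument. A minimal, self-contained sketch of that pattern follows; the layer sizes are arbitrary, and the assumption (not stated in the diff) is that fine-tuning wrappers call `model(x, labels=y)` and must not fail on the extra keyword:

```python
import torch
import torch.nn as nn


class SimpleMLP(nn.Module):
    """Toy MLP mirroring the docs example; `labels` is accepted but unused here."""

    def __init__(self, input_size=10, hidden_size=32, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, labels=None):
        """Forward pass; the optional `labels` keyword keeps training
        wrappers that pass labels working."""
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)


# Both call styles work after the signature change
model = SimpleMLP()
x = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))
assert model(x).shape == model(x, labels=y).shape == (4, 2)
```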
144 changes: 83 additions & 61 deletions src/concrete/ml/quantization/linear_op_glwe_backend.py
@@ -3,9 +3,11 @@
import json

import numpy
import torch

from ..common.utils import HybridFHEMode, to_tuple
from .quantized_module import QuantizedModule
from .quantizers import TorchUniformQuantizer


def has_glwe_backend():
@@ -54,9 +56,29 @@ def keygen(self):
self.glwe_crypto_params
)

def _forward_clear(
self,
x: torch.Tensor,
q_module: QuantizedModule,
transpose_inputs1: bool,
q_weight: numpy.ndarray,
) -> torch.Tensor:
quantizer = TorchUniformQuantizer(q_module.input_quantizers[0])
out_quantizer = TorchUniformQuantizer(q_module.output_quantizers[0])

q_x_torch = quantizer.quant(x, dtype=torch.float32)
q_x_torch = torch.transpose(q_x_torch, 1, 0) if transpose_inputs1 else q_x_torch

# There is no need to add the bias to the de-quantized values
# as the bias is already included in the output quantizer
# zero-point, in the analytical calibration
q_w = torch.from_numpy(q_weight).to(q_x_torch.device)
mm = torch.matmul(q_x_torch, q_w)
return out_quantizer.dequant(mm)

def forward(
self, x: numpy.ndarray, q_module: QuantizedModule, fhe: HybridFHEMode
) -> numpy.ndarray:
self, x: torch.Tensor, q_module: QuantizedModule, fhe: HybridFHEMode
) -> torch.Tensor:
"""Perform the inference of this linear layer.

Args:
@@ -91,78 +113,78 @@ def forward(
assert weight_bias[0].quantizer.quant_params.zero_point == 0

# Retrieve quantized weights
q_weight = weight_bias[0].qvalues
q_weight = weight_bias[0].values
assert isinstance(q_weight, numpy.ndarray)
assert q_weight.dtype == numpy.float32

q_weight = numpy.transpose(q_weight) if transpose_inputs2 else q_weight

q_x = q_module.quantize_input(x)
if fhe == HybridFHEMode.DISABLE:
return self._forward_clear(x, q_module, transpose_inputs1, q_weight)

if self.private_key is None:
self.keygen() # pragma: no cover

x_device = x.device
q_x = q_module.quantize_input(x.cpu().numpy())
assert q_x is not None
assert isinstance(q_x, numpy.ndarray)

q_x = numpy.transpose(q_x) if transpose_inputs1 else q_x

if fhe == HybridFHEMode.DISABLE:
# There is no need to add the bias to the de-quantized values
# as the bias is already included in the output quantizer
# zero-point, in the analytical calibration
q_x = q_x.astype(numpy.float32)
q_weight = q_weight.astype(numpy.float32)
y = q_module.dequantize_output(*to_tuple(numpy.matmul(q_x, q_weight)))
else:
# Need to slice the last GLWE (this will be improved in later cml-extensions)
num_valid_glwe_values_in_last_ciphertext = (
q_weight.shape[1] % self.poly_size or self.poly_size
)
# Need to slice the last GLWE (this will be improved in later cml-extensions)
num_valid_glwe_values_in_last_ciphertext = (
q_weight.shape[1] % self.poly_size or self.poly_size
)

# The GLWE backend needs uint64 encoding for both neg/pos values
q_weight = q_weight.astype(numpy.uint64)
# The GLWE backend needs uint64 encoding for both neg/pos values
q_weight = q_weight.astype(numpy.uint64)

# Some models have (B, C, H)-size activations,
# for example LLMs: B=batch size, C=context length, H=hidden dim
# while other models only have (B, H)-size activations.
# Add a B=1 dimension if needed
return_2d = False
if q_x.ndim == 2:
return_2d = True
q_x = numpy.expand_dims(q_x, 0)
# Some models have (B, C, H)-size activations,
# for example LLMs: B=batch size, C=context length, H=hidden dim
# while other models only have (B, H)-size activations.
# Add a B=1 dimension if needed
return_2d = False
if q_x.ndim == 2:
return_2d = True
q_x = numpy.expand_dims(q_x, 0)

# The GLWE backend needs contiguous memory uint64 encoding for both neg/pos values
q_x = numpy.ascontiguousarray(q_x.astype(numpy.uint64))
# The GLWE backend needs contiguous memory uint64 encoding for both neg/pos values
q_x = numpy.ascontiguousarray(q_x.astype(numpy.uint64))

assert q_weight.ndim == 2
result_buffer = numpy.zeros(
(q_x.shape[0], q_x.shape[1], q_weight.shape[1]), dtype=numpy.int64
assert q_weight.ndim == 2
result_buffer = numpy.zeros(
(q_x.shape[0], q_x.shape[1], q_weight.shape[1]), dtype=numpy.int64
)

for idx, q_x_sample in enumerate(q_x):

ciphertext = self.fhext.encrypt_matrix( # pylint: disable=no-member
pkey=self.private_key, crypto_params=self.glwe_crypto_params, data=q_x_sample
)
encrypted_result = self.fhext.matrix_multiplication( # pylint: disable=no-member
encrypted_matrix=ciphertext,
data=q_weight.astype(numpy.uint64),
compression_key=self.compression_key,
)
q_result = self.fhext.decrypt_matrix( # pylint: disable=no-member
encrypted_result,
self.private_key,
self.glwe_crypto_params,
num_valid_glwe_values_in_last_ciphertext,
)
q_result = q_result.astype(numpy.int64)

result_buffer[idx, :] = q_result

# There is no need to add the bias to the de-quantized values
# as the bias is already included in the output quantizer
# zero-point, in the analytical calibration
y = q_module.dequantize_output(*to_tuple(result_buffer))

for idx, q_x_sample in enumerate(q_x):

ciphertext = self.fhext.encrypt_matrix( # pylint: disable=no-member
pkey=self.private_key, crypto_params=self.glwe_crypto_params, data=q_x_sample
)
encrypted_result = self.fhext.matrix_multiplication( # pylint: disable=no-member
encrypted_matrix=ciphertext,
data=q_weight.astype(numpy.uint64),
compression_key=self.compression_key,
)
q_result = self.fhext.decrypt_matrix( # pylint: disable=no-member
encrypted_result,
self.private_key,
self.glwe_crypto_params,
num_valid_glwe_values_in_last_ciphertext,
)
q_result = q_result.astype(numpy.int64)

result_buffer[idx, :] = q_result

# There is no need to add the bias to the de-quantized values
# as the bias is already included in the output quantizer
# zero-point, in the analytical calibration
y = q_module.dequantize_output(*to_tuple(result_buffer))

if return_2d:
y = numpy.squeeze(y)

# Only single outputs are supported
assert isinstance(y, numpy.ndarray)

return y
if return_2d:
y = numpy.squeeze(y)

return torch.Tensor(y.astype(numpy.float32)).to(x_device)
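Both the new `_forward_clear` path and the encrypted path above follow the same uniform-quantization recipe: quantize the input and weights to integers, run the matrix multiplication on integers (in the clear or under GLWE encryption), then rescale. The sketch below illustrates that identity with made-up scales and zero-points standing in for the calibrated quantizer parameters; in the library code the zero-point correction is folded into the output quantizer ("analytical calibration"), which is why the comments say no bias needs to be added:

```python
import numpy

# Made-up quantization parameters (calibration would normally provide these);
# the weight quantizer has zero_point == 0, as asserted in forward()
s_x, zp_x = 0.005, 3   # input scale / zero-point
s_w = 0.002            # weight scale

rng = numpy.random.default_rng(0)
x = rng.normal(size=(4, 8)).astype(numpy.float32)
w = rng.normal(size=(8, 16)).astype(numpy.float32)

# Quantize to integers
q_x = numpy.round(x / s_x + zp_x)
q_w = numpy.round(w / s_w)

# Integer matrix multiplication: the only step the GLWE backend must run under encryption
q_y = q_x @ q_w

# De-quantize; the zero-point correction written out explicitly here is what
# the output quantizer absorbs analytically in the library code
y_approx = s_x * s_w * (q_y - zp_x * q_w.sum(axis=0))

assert numpy.allclose(y_approx, x @ w, atol=1e-1)
```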
8 changes: 6 additions & 2 deletions src/concrete/ml/quantization/post_training.py
@@ -730,14 +730,18 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
node_results[output_name] = node_output[0]
constants.add(output_name)

def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
def quantize_module(
self, *calibration_data: numpy.ndarray, keep_onnx: Optional[bool] = True
Review comment from @kcelia (Contributor, Jan 6, 2025): you mean `keep_onnx: Optional[bool] = None`? or `keep_onnx: bool = True`?
) -> QuantizedModule:
"""Quantize numpy module.

Following https://arxiv.org/abs/1712.05877 guidelines.

Args:
calibration_data (numpy.ndarray): Data that will be used to compute the bounds,
scales and zero point values for every quantized object.
keep_onnx (bool): keep the onnx model inside the QuantizedModule. Set to False
to save memory. Keeping the onnx model is useful for debugging

Returns:
QuantizedModule: Quantized numpy module
@@ -760,7 +764,7 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
graph_output.name for graph_output in self.numpy_model.onnx_model.graph.output
),
quant_layers_dict=self.quant_ops_dict,
onnx_model=self.numpy_model.onnx_model,
onnx_model=self.numpy_model.onnx_model if keep_onnx else None,
onnx_preprocessing=self.numpy_model.onnx_preprocessing,
)

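For illustration, a hedged usage sketch of the new flag follows. It assumes a `PostTrainingAffineQuantization` instance is already built (the class lives in the file touched above; its construction from a model is elided here), so treat the helper below as a sketch rather than the canonical API:

```python
from concrete.ml.quantization.post_training import PostTrainingAffineQuantization


def quantize_for_deployment(post_training: PostTrainingAffineQuantization, calibration_data):
    """Quantize a module while dropping the ONNX graph to save memory.

    Keep the default (keep_onnx=True) while debugging: the stored ONNX model
    makes the resulting QuantizedModule easier to inspect.
    """
    return post_training.quantize_module(calibration_data, keep_onnx=False)
```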
10 changes: 6 additions & 4 deletions src/concrete/ml/quantization/quantized_module.py
@@ -702,12 +702,14 @@ def _fhe_forward(
return q_results

def quantize_input(
self, *x: Optional[numpy.ndarray]
self, *x: Optional[numpy.ndarray], dtype: numpy.typing.DTypeLike = numpy.int64
) -> Union[numpy.ndarray, Tuple[Optional[numpy.ndarray], ...]]:
"""Take the inputs in fp32 and quantize it using the learned quantization parameters.

Args:
x (Optional[numpy.ndarray]): Floating point x or None.
dtype (numpy.typing.DTypeLike): optional user-specified datatype for the output


Returns:
Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]: Quantized (numpy.int64) x, or None if
@@ -729,7 +731,7 @@ def quantize_input(
# cannot be None
q_x = tuple(
(
self.input_quantizers[idx].quant(x[idx]) # type: ignore[arg-type]
self.input_quantizers[idx].quant(x[idx], dtype) # type: ignore[arg-type]
if x[idx] is not None
else None
)
Expand All @@ -738,7 +740,7 @@ def quantize_input(

# Make sure all inputs are quantized to the requested dtype
assert all_values_are_of_dtype(
*q_x, dtypes="int64", allow_none=True
*q_x, dtypes=numpy.dtype(dtype).name, allow_none=True
), f"Inputs were not quantized to {numpy.dtype(dtype).name}"

if len(q_x) == 1:
@@ -750,7 +752,7 @@

def dequantize_output(
self, *q_y_preds: numpy.ndarray
) -> Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]:
) -> Union[numpy.ndarray, Tuple[Union[numpy.ndarray], ...]]:
"""Take the last layer q_out and use its de-quant function.

Args:
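The `dtype` argument added here is simply forwarded to each input quantizer's `quant` method (see the quantizers.py change below). A simplified, self-contained stand-in for that behavior, with made-up scale and zero-point values, shows why requesting `numpy.float32` is useful: the quantized integers stay numerically identical but live in a float container that GPU code can consume directly:

```python
import numpy


def quant(values, scale, zero_point, n_bits=8, dtype=numpy.int64):
    """Simplified stand-in for UniformQuantizer.quant with the new dtype argument
    (the real implementation also handles offsets, narrow ranges and clipping modes)."""
    qvalues = numpy.rint(values / scale + zero_point)
    qvalues = qvalues.clip(0, 2**n_bits - 1)
    return qvalues.astype(dtype)


x = numpy.random.randn(2, 3).astype(numpy.float32)
q_int = quant(x, scale=0.1, zero_point=128)                          # default: int64 container
q_float = quant(x, scale=0.1, zero_point=128, dtype=numpy.float32)   # same values, float container
assert numpy.array_equal(q_int.astype(numpy.float32), q_float)
```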
70 changes: 68 additions & 2 deletions src/concrete/ml/quantization/quantizers.py
@@ -7,6 +7,8 @@
from typing import Any, Dict, Optional, TextIO, Union, get_type_hints

import numpy
import numpy.typing
import torch
from concrete.fhe.tracing.tracer import Tracer

from ..common.debugging import assert_true
@@ -671,11 +673,14 @@ def dump(self, file: TextIO) -> None:
"""
dump(self, file)

def quant(self, values: numpy.ndarray) -> numpy.ndarray:
def quant(
self, values: numpy.ndarray, dtype: numpy.typing.DTypeLike = numpy.int64
) -> numpy.ndarray:
"""Quantize values.

Args:
values (numpy.ndarray): float values to quantize
dtype (numpy.typing.DTypeLike): optional user-specified datatype for the output

Returns:
numpy.ndarray: Integer quantized values.
@@ -686,6 +691,8 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray:
assert self.offset is not None
assert self.scale is not None

assert dtype in (numpy.int64, numpy.int32, numpy.float32, numpy.float64)
Review comment (Contributor): maybe:
valid_dtypes = (numpy.int64, numpy.int32, numpy.float32, numpy.float64)
assert dtype in valid_dtypes, f"Invalid dtype: `{dtype}`. Expected one of {valid_dtypes}."


if QUANT_ROUND_LIKE_ROUND_PBS:
qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5) # pragma: no cover
else:
@@ -707,7 +714,9 @@

qvalues = qvalues.clip(min_value, 2 ** (self.n_bits) - 1 - self.offset)

return qvalues.astype(numpy.int64)
qvalues = qvalues.astype(dtype)

return qvalues

def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]:
"""De-quantize values.
@@ -737,6 +746,63 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer]
return values


class TorchUniformQuantizer:
"""Uniform quantizer with a PyTorch implementation.

Contains all information necessary for uniform quantization and provides
quantization/de-quantization functionality on torch tensors.

Args:
quantizer (UniformQuantizer): Underlying numpy quantizer containing all parameters
"""

_np_quant: UniformQuantizer

def __init__(self, quantizer: UniformQuantizer):
self._np_quant = quantizer

def quant(self, values: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Quantize values.

Args:
values (torch.Tensor): float values to quantize
dtype (Optional[torch.dtype]): optional user-specified datatype for the output

Returns:
torch.Tensor: Integer quantized values.
"""
qvalues = torch.round(values / self._np_quant.scale + self._np_quant.zero_point)

if not self._np_quant.no_clipping:
assert self._np_quant.offset is not None
min_value = -self._np_quant.offset
if self._np_quant.is_narrow:
min_value += 1

qvalues = torch.clip(
qvalues, min_value, 2 ** (self._np_quant.n_bits) - 1 - self._np_quant.offset
)

if dtype is not None:
qvalues = qvalues.type(dtype)

return qvalues

def dequant(self, qvalues: torch.Tensor) -> torch.Tensor:
"""De-quantize values.

Args:
qvalues (torch.Tensor): integer values to de-quantize

Returns:
torch.Tensor: De-quantized float values.
"""
zp_tensor = torch.tensor(self._np_quant.zero_point).type(qvalues.dtype).to(qvalues.device)

values = self._np_quant.scale * (qvalues - zp_tensor)
return values


class QuantizedArray:
"""Abstraction of quantized array.

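To see the new `TorchUniformQuantizer` in action, here is a hedged round-trip sketch. It duck-types the underlying numpy quantizer with a `SimpleNamespace` carrying only the attributes the class reads (scale, zero_point, offset, n_bits, is_narrow, no_clipping), with made-up values, and assumes this PR's branch of Concrete ML is installed so the class is importable:

```python
import torch
from types import SimpleNamespace

from concrete.ml.quantization.quantizers import TorchUniformQuantizer

# Duck-typed stand-in for a calibrated UniformQuantizer; offset=128 gives a
# signed [-128, 127] clipping range for n_bits=8
np_quant = SimpleNamespace(
    scale=0.05, zero_point=12, offset=128, n_bits=8, is_narrow=False, no_clipping=False
)
tq = TorchUniformQuantizer(np_quant)  # type: ignore[arg-type]

x = torch.randn(4, 8)
q = tq.quant(x, dtype=torch.float32)  # integer values stored in a float32 tensor
x_hat = tq.dequant(q)                 # approximate reconstruction of x

# Round-trip error is bounded by half a quantization step
assert torch.allclose(x, x_hat, atol=np_quant.scale)
```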