Fix issue 1058 (#1060)
Fix an issue where the exporter imported onnx during init whenever torch was found:
* In fakely_quant_onnx_pytorch_exporter.py, import onnx-related modules only if onnx was found. Add a "dummy" FakelyQuantONNXPyTorchExporter that raises an error if it is used without ONNX installed (sketched below).
* Move DEFAULT_ONNX_OPSET_VERSION to pytorch_export_facade.py so it is used only there, since it is part of the pytorch_export_facade API.
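
A condensed sketch of that guard (illustrative only; the real module defines the full exporter class inside the `if` branch, as the diff below shows):

```python
# Simplified sketch of the conditional import added in fakely_quant_onnx_pytorch_exporter.py.
from model_compression_toolkit.constants import FOUND_ONNX
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter

if FOUND_ONNX:
    # onnx-related imports now happen only when the 'onnx' package is actually present
    import onnx
    from mct_quantizers.pytorch.metadata import add_onnx_metadata

    class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
        ...  # full exporter implementation (see the diff below)

else:
    # "Dummy" stand-in: keeps the name importable, but any attempt to use it
    # without ONNX installed fails loudly via the logger.
    def FakelyQuantONNXPyTorchExporter(*args, **kwargs):
        Logger.critical("ONNX must be installed to use 'FakelyQuantONNXPyTorchExporter'. "
                        "The 'onnx' package is missing.")
```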

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored May 9, 2024
1 parent 98f771f commit c022f09
Showing 6 changed files with 121 additions and 115 deletions.
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
@@ -16,117 +16,123 @@
from io import BytesIO

import torch.nn
import onnx

from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from model_compression_toolkit.constants import FOUND_ONNX
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
from mct_quantizers import pytorch_quantizers
from mct_quantizers.pytorch.metadata import add_onnx_metadata

DEFAULT_ONNX_OPSET_VERSION=15


class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
"""
Exporter for fakely-quant PyTorch models.
The exporter expects to receive an exportable model (where each layer's full quantization parameters
can be retrieved), and convert it into a fakely-quant model (namely, weights that are in fake-quant
format) and fake-quant layers for the activations.
"""

def __init__(self,
model: torch.nn.Module,
is_layer_exportable_fn: Callable,
save_model_path: str,
repr_dataset: Callable,
use_onnx_custom_quantizer_ops: bool = False,
onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION):
"""
Args:
model: Model to export.
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
save_model_path: Path to save the exported model.
repr_dataset: Representative dataset (needed for creating torch script).
use_onnx_custom_quantizer_ops: Whether to export quantizers custom ops in ONNX or not.
onnx_opset_version: ONNX opset version to use for exported ONNX model.
"""

super().__init__(model,
is_layer_exportable_fn,
save_model_path,
repr_dataset)

self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
self._onnx_opset_version = onnx_opset_version
if FOUND_ONNX:
import onnx
from mct_quantizers.pytorch.metadata import add_onnx_metadata

def export(self) -> None:
class FakelyQuantONNXPyTorchExporter(BasePyTorchExporter):
"""
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.
Returns:
Fake-quant PyTorch model.
Exporter for fakely-quant PyTorch models.
The exporter expects to receive an exportable model (where each layer's full quantization parameters
can be retrieved), and convert it into a fakely-quant model (namely, weights that are in fake-quant
format) and fake-quant layers for the activations.
"""
for layer in self.model.children():
self.is_layer_exportable_fn(layer)

# Set forward that is used during onnx export.
# If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
# the custom implementation when exporting the operator into onnx model. If not, it removes the
# wraps and quantizes the ops in place (for weights, for activation torch quantization function is
# exported since it's used during forward).
if self._use_onnx_custom_quantizer_ops:
self._enable_onnx_custom_ops_export()
else:
self._substitute_fully_quantized_model()

if self._use_onnx_custom_quantizer_ops:
Logger.info(f"Exporting onnx model with MCTQ quantizers: {self.save_model_path}")
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})

def _enable_onnx_custom_ops_export(self):
"""
Enable the custom implementation forward in quantizers, so it is exported
with custom quantizers.
"""

for n, m in self.model.named_modules():
if isinstance(m, PytorchActivationQuantizationHolder):
assert isinstance(m.activation_holder_quantizer, pytorch_quantizers.BasePyTorchInferableQuantizer)
m.activation_holder_quantizer.enable_custom_impl()

if isinstance(m, PytorchQuantizationWrapper):
for wq in m.weights_quantizers.values():
assert isinstance(wq, pytorch_quantizers.BasePyTorchInferableQuantizer)
wq.enable_custom_impl()
def __init__(self,
model: torch.nn.Module,
is_layer_exportable_fn: Callable,
save_model_path: str,
repr_dataset: Callable,
onnx_opset_version: int,
use_onnx_custom_quantizer_ops: bool = False):
"""
Args:
model: Model to export.
is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
save_model_path: Path to save the exported model.
repr_dataset: Representative dataset (needed for creating torch script).
onnx_opset_version: ONNX opset version to use for exported ONNX model.
use_onnx_custom_quantizer_ops: Whether to export quantizers custom ops in ONNX or not.
"""

super().__init__(model,
is_layer_exportable_fn,
save_model_path,
repr_dataset)

self._use_onnx_custom_quantizer_ops = use_onnx_custom_quantizer_ops
self._onnx_opset_version = onnx_opset_version

def export(self) -> None:
"""
Convert an exportable (fully-quantized) PyTorch model to a fakely-quant model
(namely, weights that are in fake-quant format) and fake-quant layers for the activations.
Returns:
Fake-quant PyTorch model.
"""
for layer in self.model.children():
self.is_layer_exportable_fn(layer)

# Set forward that is used during onnx export.
# If _use_onnx_custom_quantizer_ops is set to True, the quantizer forward function will use
# the custom implementation when exporting the operator into onnx model. If not, it removes the
# wraps and quantizes the ops in place (for weights, for activation torch quantization function is
# exported since it's used during forward).
if self._use_onnx_custom_quantizer_ops:
self._enable_onnx_custom_ops_export()
else:
self._substitute_fully_quantized_model()

if self._use_onnx_custom_quantizer_ops:
Logger.info(f"Exporting onnx model with MCTQ quantizers: {self.save_model_path}")
else:
Logger.info(f"Exporting fake-quant onnx model: {self.save_model_path}")

model_input = to_torch_tensor(next(self.repr_dataset())[0])

if hasattr(self.model, 'metadata'):
onnx_bytes = BytesIO()
torch.onnx.export(self.model,
model_input,
onnx_bytes,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
onnx_model = onnx.load_from_string(onnx_bytes.getvalue())
onnx_model = add_onnx_metadata(onnx_model, self.model.metadata)
onnx.save_model(onnx_model, self.save_model_path)
else:
torch.onnx.export(self.model,
model_input,
self.save_model_path,
opset_version=self._onnx_opset_version,
verbose=False,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})

def _enable_onnx_custom_ops_export(self):
"""
Enable the custom implementation forward in quantizers, so it is exported
with custom quantizers.
"""

for n, m in self.model.named_modules():
if isinstance(m, PytorchActivationQuantizationHolder):
assert isinstance(m.activation_holder_quantizer, pytorch_quantizers.BasePyTorchInferableQuantizer)
m.activation_holder_quantizer.enable_custom_impl()

if isinstance(m, PytorchQuantizationWrapper):
for wq in m.weights_quantizers.values():
assert isinstance(wq, pytorch_quantizers.BasePyTorchInferableQuantizer)
wq.enable_custom_impl()

else:
def FakelyQuantONNXPyTorchExporter(*args, **kwargs):
Logger.critical("ONNX must be installed to use 'FakelyQuantONNXPyTorchExporter'. "
"The 'onnx' package is missing.") # pragma: no cover
model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py
@@ -19,15 +19,15 @@
from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
PytorchExportSerializationFormat
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities


DEFAULT_ONNX_OPSET_VERSION = 15


if FOUND_TORCH:
import torch.nn
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
FakelyQuantONNXPyTorchExporter, DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import \
FakelyQuantTorchScriptPyTorchExporter
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable

supported_serialization_quantization_export_dict = {
@@ -21,8 +21,8 @@
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
generate_pytorch_tpc
@@ -20,8 +20,8 @@
from onnx import numpy_helper

import model_compression_toolkit as mct
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from tests.pytorch_tests.exporter_tests.base_pytorch_export_test import BasePytorchExportTest


@@ -23,8 +23,8 @@
import mct_quantizers
import model_compression_toolkit as mct
from model_compression_toolkit.constants import FOUND_ONNX, FOUND_ONNXRUNTIME
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
generate_pytorch_tpc
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
@@ -24,8 +24,8 @@
import model_compression_toolkit as mct
from mct_quantizers import QuantizationMethod
from model_compression_toolkit.constants import FOUND_ONNX, FOUND_ONNXRUNTIME
from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import \
DEFAULT_ONNX_OPSET_VERSION
from model_compression_toolkit.exporter.model_exporter.pytorch.pytorch_export_facade import DEFAULT_ONNX_OPSET_VERSION

from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
generate_pytorch_tpc
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model
