Apply weight compression after model save to reduce peak RAM during export #878

Merged
Changes from 5 commits
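
In short, the change moves weight compression from export time (while the full-precision PyTorch model is still in memory) to after the OpenVINO IR has been written to disk and the source model has been freed, so the full-precision weights never have to sit in RAM twice. A minimal sketch of that flow, with placeholder paths and assuming NNCF is installed (the real logic lives in `main_export` in the first diff below):

```python
import gc
from pathlib import Path

import nncf
from openvino.runtime import Core, save_model

output = Path("exported_model")          # placeholder output directory
submodel_paths = ["openvino_model.xml"]  # relative paths returned by export_from_model

# ... the uncompressed IR is exported here, then the torch model is deleted ...
gc.collect()

core = Core()
for rel_path in submodel_paths:
    xml_path = output / rel_path
    submodel = core.read_model(xml_path)

    # compress the weights of the already-saved model with NNCF
    submodel = nncf.compress_weights(submodel, mode=nncf.CompressWeightsMode.INT8_ASYM)

    # write the compressed copy next to the original, then swap the files
    compressed_path = xml_path.parent / f"{xml_path.stem}_compressed.xml"
    save_model(submodel, compressed_path, compress_to_fp16=False)
    del submodel

    xml_path.unlink()
    xml_path.with_suffix(".bin").unlink()
    compressed_path.rename(xml_path)
    compressed_path.with_suffix(".bin").rename(xml_path.with_suffix(".bin"))
```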
54 changes: 52 additions & 2 deletions optimum/exporters/openvino/__main__.py
@@ -14,7 +14,9 @@

import gc
import logging
import operator
import warnings
from functools import reduce
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

@@ -23,18 +25,20 @@
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from transformers.utils import is_torch_available

from openvino.runtime import Core, Type, save_model
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.utils.import_utils import (
is_nncf_available,
is_openvino_tokenizers_available,
is_openvino_version,
is_transformers_version,
)
from optimum.utils.save_utils import maybe_load_preprocessors

from .utils import clear_class_registry
from .utils import _MAX_UNCOMPRESSED_SIZE, clear_class_registry


if TYPE_CHECKING:
@@ -402,7 +406,7 @@ class StoreAttr(object):
model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
)

export_from_model(
submodel_paths = export_from_model(
model=model,
output=output,
task=task,
@@ -425,6 +429,52 @@
del model
gc.collect()

core = Core()
compressed_submodel_paths = []
for submodel_path in submodel_paths:
submodel_path = Path(output) / submodel_path
submodel = core.read_model(submodel_path)

quantization_config = None
if ov_config is None:
num_parameters = 0
for op in submodel.get_ops():
if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
num_parameters += reduce(operator.mul, op.shape, 1)
if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
quantization_config = {"bits": 8, "sym": False}
logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 "
"requires nncf. Please install it with `pip install nncf`"
)
break
else:
quantization_config = ov_config.quantization_config
if quantization_config is None:
continue

if not is_nncf_available():
raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")

from optimum.intel.openvino.quantization import _weight_only_quantization

_weight_only_quantization(submodel, quantization_config)

compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
save_model(submodel, compressed_submodel_path, compress_to_fp16=ov_config and ov_config.dtype == "fp16")
compressed_submodel_paths.append((submodel_path, compressed_submodel_path))

del submodel

for submodel_path, compressed_submodel_path in compressed_submodel_paths:
submodel_path.unlink()
submodel_path.with_suffix(".bin").unlink()
compressed_submodel_path.rename(submodel_path)
compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))

# Unpatch modules after GPTQ export
if do_gptq_patching:
torch.cuda.is_available = orig_cuda_check
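
Because the PyTorch model has already been released by the time compression runs, the "is this model large enough for default int8 compression?" check can no longer call `num_parameters()`; it is recomputed from the saved IR by summing the element counts of floating-point Constant ops, as in the loop above. A small standalone sketch of the same counting, with a placeholder path:

```python
import operator
from functools import reduce

from openvino.runtime import Core, Type

model = Core().read_model("exported_model/openvino_model.xml")  # placeholder path

# total number of elements held by floating-point Constant nodes
num_parameters = sum(
    reduce(operator.mul, op.shape, 1)
    for op in model.get_ops()
    if op.get_type_name() == "Constant" and op.get_element_type() in (Type.f16, Type.f32, Type.bf16)
)
print(f"constant parameters: {num_parameters:,}")
# main_export compares this total against _MAX_UNCOMPRESSED_SIZE to decide whether to
# apply the default int8_asym compression when no quantization config was provided.
```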
38 changes: 3 additions & 35 deletions optimum/exporters/openvino/convert.py
@@ -49,7 +49,6 @@
from .model_patcher import patch_model_with_bettertransformer
from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful
from .utils import (
_MAX_UNCOMPRESSED_SIZE,
OV_XML_FILE_NAME,
clear_class_registry,
flattenize_inputs,
@@ -76,21 +75,7 @@


def _save_model(model, path: str, ov_config: Optional["OVConfig"] = None, library_name: Optional[str] = None):
compress_to_fp16 = False

if ov_config is not None:
if ov_config.quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights to int8 requires nncf, please install it with `pip install nncf`"
)

from optimum.intel.openvino.quantization import _weight_only_quantization

_weight_only_quantization(model, ov_config.quantization_config)

compress_to_fp16 = ov_config.dtype == "fp16"

compress_to_fp16 = ov_config is not None and ov_config.dtype == "fp16"
model = _add_version_info_to_model(model, library_name)
save_model(model, path, compress_to_fp16)

@@ -643,25 +628,6 @@ def export_from_model(
)
logging.disable(logging.NOTSET)

if ov_config is None:
if library_name == "diffusers":
num_parameters = model.unet.num_parameters()
else:
num_parameters = sum(param.numel() for param in list(model.parameters()) if param.requires_grad)

if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
if is_nncf_available():
from ...intel.openvino.configuration import OVConfig

ov_config = OVConfig(quantization_config={"bits": 8, "sym": False})

logger.info("The model weights will be quantized to int8_asym.")
else:
logger.warning(
"The model will be converted with no weights quantization. Quantization of the weights to int8 requires nncf."
"please install it with `pip install nncf`"
)

if library_name != "diffusers":
# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
@@ -720,6 +686,8 @@ def export_from_model(
patch_16bit_model=patch_16bit_model,
)

return files_subpaths


def export_tokenizer(
tokenizer,
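
With the quantization branch removed, `_save_model` in convert.py only handles the optional fp16 conversion and version info, and `export_from_model` now returns the submodel paths so the caller can compress them afterwards. The user-facing API is unchanged; a short usage sketch with a placeholder model id, matching the calls exercised by the tests below:

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# default path: models above the size threshold get int8_asym weights automatically
model = OVModelForCausalLM.from_pretrained("my-org/my-llama", export=True)  # placeholder id

# explicit config: 4-bit symmetric weight compression, still applied after the save
model = OVModelForCausalLM.from_pretrained(
    "my-org/my-llama",
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
```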
138 changes: 72 additions & 66 deletions tests/openvino/test_quantization.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

# ruff: noqa

@@ -22,6 +23,7 @@
from enum import Enum
from functools import partial
from typing import Union

import pytest
import evaluate
import numpy as np
@@ -538,76 +540,80 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
self.assertEqual(0, num_int8)

def test_ovmodel_load_large_model_with_default_compressed_weights(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(quantization_config={"bits": 8}),
library_name="transformers",
)
def main_export_in_stacktrace(*args, **kwargs):
# Compression was called from `main_export`
self.assertTrue(inspect.stack()[5].function == "main_export")

with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch(
"nncf.compress_weights", side_effect=main_export_in_stacktrace
) as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT8_ASYM,
"ratio": 1.0,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(
unittest.mock.ANY,
**compression_params,
)

def test_ovmodel_load_large_model_with_uncompressed_weights(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(dtype="auto"),
library_name="transformers",
)
with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
compress_weights_patch.assert_not_called()

def test_ovmodel_load_large_model_with_additional_quantization_config(self):
with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
mock_tensor = unittest.mock.Mock()
mock_tensor.numel = lambda: 2000000000
mock_tensor.requires_grad = True
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"],
export=True,
compile=False,
use_cache=False,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
# quantization will be performed later, using load_model
save_model_patch.assert_called_with(
unittest.mock.ANY,
unittest.mock.ANY,
ov_config=OVConfig(dtype="auto"),
library_name="transformers",
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"ratio": 0.8,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)
def main_export_not_in_stacktrace(*args, **kwargs):
# Compression was not called from `main_export`
self.assertTrue(all(frame_info.function != "main_export" for frame_info in inspect.stack()))

with unittest.mock.patch(
"openvino.runtime.op.Constant.shape", new_callable=unittest.mock.PropertyMock
) as ov_constant_shape:
ov_constant_shape.return_value = (2000000000,)
with unittest.mock.patch(
"nncf.compress_weights", side_effect=main_export_not_in_stacktrace
) as compress_weights_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"],
export=True,
compile=False,
use_cache=False,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
)
compression_params = {
"mode": nncf.CompressWeightsMode.INT4_SYM,
"ratio": 0.8,
"group_size": -1,
"all_layers": None,
"sensitivity_metric": None,
"dataset": None,
"ignored_scope": nncf.IgnoredScope(),
"awq": None,
"subset_size": 128,
"scale_estimation": None,
}
compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)

@parameterized.expand(LOAD_IN_4_BITS_SCOPE)
def test_ovmodel_4bit_dynamic_with_config(self, model_cls, model_name, quantization_config, expected_ov_int4):
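
The updated tests no longer mock `torch.nn.Module.parameters`; they fake a ~2B-parameter model by patching the `shape` property of OpenVINO `Constant` ops and use `inspect.stack()` to check whether `nncf.compress_weights` was reached from `main_export`. A generic, standalone illustration of the `PropertyMock` pattern they rely on (toy class, not OpenVINO):

```python
import unittest.mock


class Blob:
    @property
    def shape(self):
        return (10,)


with unittest.mock.patch.object(Blob, "shape", new_callable=unittest.mock.PropertyMock) as mocked_shape:
    mocked_shape.return_value = (2_000_000_000,)
    assert Blob().shape == (2_000_000_000,)  # every read of the property returns the fake shape
```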