
Commit 4fa9c62

shaydeci authored and geoffrey-g-delhomme committed
Bug/sg 861 decouple qat from train from config (Deci-AI#1001)
* added unit tests
* changed local
* switch to EMA model before quantization if it exists
* modifying method complete
* modifying method call in pre-launch callback
* removed option to get the defaults from previous training
* added unit tests passing
* updated docs and test names
* moved logger init
* comments resolved
1 parent: 11d2ecd

11 files changed (+694, -253 lines)


.circleci/config.yml (+2)

@@ -465,6 +465,8 @@ jobs:
 python3.8 -m pip install -r requirements.txt
 python3.8 -m pip install .
 python3.8 -m pip install torch torchvision torchaudio
+python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
+
 python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=coco2017_pose_dekr_w32_no_dc experiment_name=shortened_coco2017_pose_dekr_w32_ap_test batch_size=4 val_batch_size=8 epochs=1 training_hyperparams.lr_warmup_steps=0 training_hyperparams.average_best_models=False training_hyperparams.max_train_batches=1000 training_hyperparams.max_valid_batches=100 multi_gpu=DDP num_gpus=4
 python3.8 src/super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=cifar10_resnet experiment_name=shortened_cifar10_resnet_accuracy_test epochs=100 training_hyperparams.average_best_models=False multi_gpu=DDP num_gpus=4
 python3.8 src/super_gradients/examples/convert_recipe_example/convert_recipe_example.py --config-name=cifar10_conversion_params experiment_name=shortened_cifar10_resnet_accuracy_test
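This CI step pulls pytorch-quantization 2.1.2 via NVIDIA's NGC package index, which the QAT/PTQ code paths depend on. A minimal local availability check, mirroring the optional-import guard the library itself uses (nothing here is SG-specific):

from importlib.metadata import version

try:
    import pytorch_quantization  # noqa: F401  # installed via --extra-index-url https://pypi.ngc.nvidia.com
    print(f"pytorch-quantization {version('pytorch-quantization')} is available")
except ImportError as import_err:
    # Same failure mode the library tolerates: quantization features simply stay disabled.
    print(f"pytorch_quantization not importable: {import_err}")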

src/super_gradients/__init__.py (+3)

@@ -3,6 +3,7 @@
 from super_gradients.common.registry.registry import ARCHITECTURES
 from super_gradients.sanity_check import env_sanity_check
 from super_gradients.training.utils.distributed_training_utils import setup_device
+from super_gradients.training.pre_launch_callbacks import AutoTrainBatchSizeSelectionCallback, QATRecipeModificationCallback

 __all__ = [
     "ARCHITECTURES",
@@ -18,6 +19,8 @@
     "is_distributed",
     "env_sanity_check",
     "setup_device",
+    "QATRecipeModificationCallback",
+    "AutoTrainBatchSizeSelectionCallback",
 ]

 __version__ = "3.1.1"

src/super_gradients/qat_from_recipe.py (+2 -3)

@@ -8,13 +8,12 @@
 import hydra
 from omegaconf import DictConfig

-from super_gradients import init_trainer
-from super_gradients.training.qat_trainer.qat_trainer import QATTrainer
+from super_gradients import init_trainer, Trainer


 @hydra.main(config_path="recipes", version_base="1.2")
 def _main(cfg: DictConfig) -> None:
-    QATTrainer.train_from_config(cfg)
+    Trainer.quantize_from_config(cfg)


 def main():
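The recipe entry point now calls the classmethod Trainer.quantize_from_config instead of QATTrainer.train_from_config. The same call can be made programmatically via Hydra's compose API; a minimal sketch, where the recipe name and checkpoint path are placeholders and the recipe is assumed to define quantization_params:

from hydra import compose, initialize

from super_gradients import Trainer, init_trainer

init_trainer()

# config_path is resolved relative to the calling module; "my_qat_recipe" is a placeholder recipe name.
with initialize(config_path="recipes", version_base="1.2"):
    cfg = compose(config_name="my_qat_recipe", overrides=["checkpoint_params.checkpoint_path=/path/to/checkpoint.pth"])

quantized_model, train_result = Trainer.quantize_from_config(cfg)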

src/super_gradients/training/__init__.py (+2)

@@ -5,6 +5,7 @@
 from super_gradients.training.kd_trainer import KDTrainer
 from super_gradients.training.qat_trainer import QATTrainer
 from super_gradients.common import MultiGPUMode, StrictLoad, EvaluationType
+from super_gradients.training.pre_launch_callbacks import modify_params_for_qat

 __all__ = [
     "distributed_training_utils",
@@ -16,4 +17,5 @@
     "MultiGPUMode",
     "StrictLoad",
     "EvaluationType",
+    "modify_params_for_qat",
 ]
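modify_params_for_qat is now also exposed from super_gradients.training, so plain (non-recipe) training scripts can reuse the same QAT parameter adaptation. Its signature is not part of the rendered diff, so the sketch below only imports and inspects it rather than guessing arguments:

import inspect

from super_gradients.training import modify_params_for_qat

# Print the expected parameters (training/dataset/dataloader params) and the docstring before wiring it in.
print(inspect.signature(modify_params_for_qat))
print(inspect.getdoc(modify_params_for_qat))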

src/super_gradients/training/pre_launch_callbacks/__init__.py (+2 -1)

@@ -2,7 +2,8 @@
     PreLaunchCallback,
     AutoTrainBatchSizeSelectionCallback,
     QATRecipeModificationCallback,
+    modify_params_for_qat,
 )
 from super_gradients.common.registry.registry import ALL_PRE_LAUNCH_CALLBACKS

-__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
+__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "QATRecipeModificationCallback", "ALL_PRE_LAUNCH_CALLBACKS", "modify_params_for_qat"]

src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py (+180 -55)

Large diff; not rendered by default.
src/super_gradients/training/qat_trainer/qat_trainer.py

@@ -1,201 +1,17 @@
-import os
 from typing import Union, Tuple

-import copy
-import hydra
-import torch.cuda
+from deprecated import deprecated
 from omegaconf import DictConfig
-from omegaconf import OmegaConf
 from torch import nn

 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment.device_utils import device_config
-from super_gradients.training import utils as core_utils, models, dataloaders, pre_launch_callbacks
 from super_gradients.training.sg_trainer import Trainer
-from super_gradients.training.utils import get_param
-from super_gradients.training.utils.distributed_training_utils import setup_device
-from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches

 logger = get_logger(__name__)
-try:
-    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
-    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
-    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
-
-    _imported_pytorch_quantization_failure = None
-
-except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.debug("Failed to import pytorch_quantization:")
-    logger.debug(import_err)
-    _imported_pytorch_quantization_failure = import_err


 class QATTrainer(Trainer):
     @classmethod
-    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
-        """
-        Perform quantization aware training (QAT) according to a recipe configuration.
-
-        This method will instantiate all the objects specified in the recipe, build and quantize the model,
-        and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
-        method will be returned.
-
-        The quantized model will be exported to ONNX along with other checkpoints.
-
-        :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
-        :return: A tuple containing the quantized model and the output of trainer.train() method.
-        :rtype: Tuple[nn.Module, Tuple]
-
-        :raises ValueError: If the recipe does not have the required key `quantization_params` or
-            `checkpoint_params.checkpoint_path` in it.
-        :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
-        :raises ImportError: If pytorch-quantization import was unsuccessful
-
-        """
-        if _imported_pytorch_quantization_failure is not None:
-            raise _imported_pytorch_quantization_failure
-
-        # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
-
-        # TRIGGER CFG MODIFYING CALLBACKS
-        cfg = cls._trigger_cfg_modifying_callbacks(cfg)
-
-        if "quantization_params" not in cfg:
-            raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")
-
-        if "checkpoint_path" not in cfg.checkpoint_params:
-            raise ValueError("Starting checkpoint is a must for QAT finetuning.")
-
-        num_gpus = core_utils.get_param(cfg, "num_gpus")
-        multi_gpu = core_utils.get_param(cfg, "multi_gpu")
-        device = core_utils.get_param(cfg, "device")
-        if num_gpus != 1:
-            raise NotImplementedError(
-                f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
-            )
-
-        setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)
-
-        # INSTANTIATE DATA LOADERS
-        train_dataloader = dataloaders.get(
-            name=get_param(cfg, "train_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
-        )
-
-        val_dataloader = dataloaders.get(
-            name=get_param(cfg, "val_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
-        )
-
-        if "calib_dataloader" in cfg:
-            calib_dataloader_name = get_param(cfg, "calib_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
-        else:
-            calib_dataloader_name = get_param(cfg, "train_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)
-
-        # if we use whole dataloader for calibration, don't shuffle it
-        # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
-        # if we use several batches, we don't want it to be from one class if it's sequential in dataloader
-        # model is in eval mode, so BNs will not be affected
-        calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
-        # we don't need training transforms during calibration, distribution of activations will be skewed
-        calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
-
-        calib_dataloader = dataloaders.get(
-            name=calib_dataloader_name,
-            dataset_params=calib_dataset_params,
-            dataloader_params=calib_dataloader_params,
-        )
-
-        # BUILD MODEL
-        model = models.get(
-            model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
-            num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
-            arch_params=cfg.arch_params,
-            strict_load=cfg.checkpoint_params.strict_load,
-            pretrained_weights=cfg.checkpoint_params.pretrained_weights,
-            checkpoint_path=cfg.checkpoint_params.checkpoint_path,
-            load_backbone=False,
-        )
-        model.to(device_config.device)
-
-        # QUANTIZE MODEL
-        model.eval()
-        fuse_repvgg_blocks_residual_branches(model)
-
-        q_util = SelectiveQuantizer(
-            default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
-            default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
-            default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
-            default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
-            verbose=cfg.quantization_params.calib_params.verbose,
-        )
-        q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
-        q_util.quantize_module(model)
-
-        # CALIBRATE MODEL
-        logger.info("Calibrating model...")
-        calibrator = QuantizationCalibrator(
-            verbose=cfg.quantization_params.calib_params.verbose,
-            torch_hist=True,
-        )
-        calibrator.calibrate_model(
-            model,
-            method=cfg.quantization_params.calib_params.histogram_calib_method,
-            calib_data_loader=calib_dataloader,
-            num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
-            percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
-        )
-        calibrator.reset_calibrators(model)  # release memory taken by calibrators
-
-        # VALIDATE PTQ MODEL AND PRINT SUMMARY
-        logger.info("Validating PTQ model...")
-        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-        valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
-        results = ["PTQ Model Validation Results"]
-        results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
-        logger.info("\n".join(results))
-
-        # TRAIN
-        if cfg.quantization_params.ptq_only:
-            logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
-            suffix = "ptq"
-            res = None
-        else:
-            model.train()
-            recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
-            trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-            torch.cuda.empty_cache()
-
-            res = trainer.train(
-                model=model,
-                train_loader=train_dataloader,
-                valid_loader=val_dataloader,
-                training_params=cfg.training_hyperparams,
-                additional_configs_to_log=recipe_logged_cfg,
-            )
-            suffix = "qat"
-
-        # EXPORT QUANTIZED MODEL TO ONNX
-        input_shape = next(iter(val_dataloader))[0].shape
-        os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)
-
-        qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
-        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
-        export_quantized_module_to_onnx(
-            model=model.cpu(),
-            onnx_filename=qdq_onnx_path,
-            input_shape=input_shape,
-            input_size=input_shape,
-            train=False,
-        )
-
-        logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")
-
-        return model, res
+    @deprecated(version="3.2.0", reason="QATTrainer is deprecated and will be removed in future release, use Trainer " "class instead.")
+    def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
+        return Trainer.quantize_from_config(cfg)
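QATTrainer survives only as a thin deprecated shim: old call sites keep working while delegating to Trainer. A short sketch of the two call paths (cfg would come from a Hydra recipe, e.g. composed as in the earlier sketch):

import warnings

from omegaconf import DictConfig

from super_gradients import Trainer
from super_gradients.training import QATTrainer


def quantize(cfg: DictConfig, use_legacy_api: bool = False):
    if use_legacy_api:
        # Deprecated path: still delegates to Trainer.quantize_from_config, warning via @deprecated.
        with warnings.catch_warnings():
            warnings.simplefilter("always", DeprecationWarning)
            return QATTrainer.quantize_from_config(cfg)
    # Preferred path after this commit.
    return Trainer.quantize_from_config(cfg)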
