@@ -1,201 +1,17 @@
-import os
 from typing import Union, Tuple
 
-import copy
-import hydra
-import torch.cuda
+from deprecated import deprecated
 from omegaconf import DictConfig
-from omegaconf import OmegaConf
 from torch import nn
 
 from super_gradients.common.abstractions.abstract_logger import get_logger
-from super_gradients.common.environment.device_utils import device_config
-from super_gradients.training import utils as core_utils, models, dataloaders, pre_launch_callbacks
 from super_gradients.training.sg_trainer import Trainer
-from super_gradients.training.utils import get_param
-from super_gradients.training.utils.distributed_training_utils import setup_device
-from super_gradients.modules.repvgg_block import fuse_repvgg_blocks_residual_branches
 
 logger = get_logger(__name__)
-try:
-    from super_gradients.training.utils.quantization.calibrator import QuantizationCalibrator
-    from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx
-    from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer
-
-    _imported_pytorch_quantization_failure = None
-
-except (ImportError, NameError, ModuleNotFoundError) as import_err:
-    logger.debug("Failed to import pytorch_quantization:")
-    logger.debug(import_err)
-    _imported_pytorch_quantization_failure = import_err
 
 
 class QATTrainer(Trainer):
     @classmethod
-    def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
-        """
-        Perform quantization aware training (QAT) according to a recipe configuration.
-
-        This method will instantiate all the objects specified in the recipe, build and quantize the model,
-        and calibrate the quantized model. The resulting quantized model and the output of the trainer.train()
-        method will be returned.
-
-        The quantized model will be exported to ONNX along with other checkpoints.
-
-        :param cfg: The parsed DictConfig object from yaml recipe files or a dictionary.
-        :return: A tuple containing the quantized model and the output of trainer.train() method.
-        :rtype: Tuple[nn.Module, Tuple]
-
-        :raises ValueError: If the recipe does not have the required key `quantization_params` or
-                            `checkpoint_params.checkpoint_path` in it.
-        :raises NotImplementedError: If the recipe requests multiple GPUs or num_gpus is not equal to 1.
-        :raises ImportError: If pytorch-quantization import was unsuccessful
-
-        """
-        if _imported_pytorch_quantization_failure is not None:
-            raise _imported_pytorch_quantization_failure
-
-        # INSTANTIATE ALL OBJECTS IN CFG
-        cfg = hydra.utils.instantiate(cfg)
-
-        # TRIGGER CFG MODIFYING CALLBACKS
-        cfg = cls._trigger_cfg_modifying_callbacks(cfg)
-
-        if "quantization_params" not in cfg:
-            raise ValueError("Your recipe does not have quantization_params. Add them to use QAT.")
-
-        if "checkpoint_path" not in cfg.checkpoint_params:
-            raise ValueError("Starting checkpoint is a must for QAT finetuning.")
-
-        num_gpus = core_utils.get_param(cfg, "num_gpus")
-        multi_gpu = core_utils.get_param(cfg, "multi_gpu")
-        device = core_utils.get_param(cfg, "device")
-        if num_gpus != 1:
-            raise NotImplementedError(
-                f"Recipe requests multi_gpu={cfg.multi_gpu} and num_gpus={cfg.num_gpus}. QAT is proven to work correctly only with multi_gpu=OFF and num_gpus=1"
-            )
-
-        setup_device(device=device, multi_gpu=multi_gpu, num_gpus=num_gpus)
-
-        # INSTANTIATE DATA LOADERS
-        train_dataloader = dataloaders.get(
-            name=get_param(cfg, "train_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.train_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.train_dataloader_params),
-        )
-
-        val_dataloader = dataloaders.get(
-            name=get_param(cfg, "val_dataloader"),
-            dataset_params=copy.deepcopy(cfg.dataset_params.val_dataset_params),
-            dataloader_params=copy.deepcopy(cfg.dataset_params.val_dataloader_params),
-        )
-
-        if "calib_dataloader" in cfg:
-            calib_dataloader_name = get_param(cfg, "calib_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.calib_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.calib_dataset_params)
-        else:
-            calib_dataloader_name = get_param(cfg, "train_dataloader")
-            calib_dataloader_params = copy.deepcopy(cfg.dataset_params.train_dataloader_params)
-            calib_dataset_params = copy.deepcopy(cfg.dataset_params.train_dataset_params)
-
-        # if we use whole dataloader for calibration, don't shuffle it
-        # HistogramCalibrator collection routine is sensitive to order of batches and produces slightly different results
-        # if we use several batches, we don't want it to be from one class if it's sequential in dataloader
-        # model is in eval mode, so BNs will not be affected
-        calib_dataloader_params.shuffle = cfg.quantization_params.calib_params.num_calib_batches is not None
-        # we don't need training transforms during calibration, distribution of activations will be skewed
-        calib_dataset_params.transforms = cfg.dataset_params.val_dataset_params.transforms
-
-        calib_dataloader = dataloaders.get(
-            name=calib_dataloader_name,
-            dataset_params=calib_dataset_params,
-            dataloader_params=calib_dataloader_params,
-        )
-
-        # BUILD MODEL
-        model = models.get(
-            model_name=cfg.arch_params.get("model_name", None) or cfg.architecture,
-            num_classes=cfg.get("num_classes", None) or cfg.arch_params.num_classes,
-            arch_params=cfg.arch_params,
-            strict_load=cfg.checkpoint_params.strict_load,
-            pretrained_weights=cfg.checkpoint_params.pretrained_weights,
-            checkpoint_path=cfg.checkpoint_params.checkpoint_path,
-            load_backbone=False,
-        )
-        model.to(device_config.device)
-
-        # QUANTIZE MODEL
-        model.eval()
-        fuse_repvgg_blocks_residual_branches(model)
-
-        q_util = SelectiveQuantizer(
-            default_quant_modules_calibrator_weights=cfg.quantization_params.selective_quantizer_params.calibrator_w,
-            default_quant_modules_calibrator_inputs=cfg.quantization_params.selective_quantizer_params.calibrator_i,
-            default_per_channel_quant_weights=cfg.quantization_params.selective_quantizer_params.per_channel,
-            default_learn_amax=cfg.quantization_params.selective_quantizer_params.learn_amax,
-            verbose=cfg.quantization_params.calib_params.verbose,
-        )
-        q_util.register_skip_quantization(layer_names=cfg.quantization_params.selective_quantizer_params.skip_modules)
-        q_util.quantize_module(model)
-
-        # CALIBRATE MODEL
-        logger.info("Calibrating model...")
-        calibrator = QuantizationCalibrator(
-            verbose=cfg.quantization_params.calib_params.verbose,
-            torch_hist=True,
-        )
-        calibrator.calibrate_model(
-            model,
-            method=cfg.quantization_params.calib_params.histogram_calib_method,
-            calib_data_loader=calib_dataloader,
-            num_calib_batches=cfg.quantization_params.calib_params.num_calib_batches or len(train_dataloader),
-            percentile=get_param(cfg.quantization_params.calib_params, "percentile", 99.99),
-        )
-        calibrator.reset_calibrators(model)  # release memory taken by calibrators
-
-        # VALIDATE PTQ MODEL AND PRINT SUMMARY
-        logger.info("Validating PTQ model...")
-        trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-        valid_metrics_dict = trainer.test(model=model, test_loader=val_dataloader, test_metrics_list=cfg.training_hyperparams.valid_metrics_list)
-        results = ["PTQ Model Validation Results"]
-        results += [f" - {metric:10}: {value}" for metric, value in valid_metrics_dict.items()]
-        logger.info("\n".join(results))
-
-        # TRAIN
-        if cfg.quantization_params.ptq_only:
-            logger.info("cfg.quantization_params.ptq_only=True. Performing PTQ only!")
-            suffix = "ptq"
-            res = None
-        else:
-            model.train()
-            recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
-            trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=get_param(cfg, "ckpt_root_dir", default_val=None))
-            torch.cuda.empty_cache()
-
-            res = trainer.train(
-                model=model,
-                train_loader=train_dataloader,
-                valid_loader=val_dataloader,
-                training_params=cfg.training_hyperparams,
-                additional_configs_to_log=recipe_logged_cfg,
-            )
-            suffix = "qat"
-
-        # EXPORT QUANTIZED MODEL TO ONNX
-        input_shape = next(iter(val_dataloader))[0].shape
-        os.makedirs(trainer.checkpoints_dir_path, exist_ok=True)
-
-        qdq_onnx_path = os.path.join(trainer.checkpoints_dir_path, f"{cfg.experiment_name}_{'x'.join((str(x) for x in input_shape))}_{suffix}.onnx")
-        # TODO: modify SG's convert_to_onnx for quantized models and use it instead
-        export_quantized_module_to_onnx(
-            model=model.cpu(),
-            onnx_filename=qdq_onnx_path,
-            input_shape=input_shape,
-            input_size=input_shape,
-            train=False,
-        )
-
-        logger.info(f"Exported {suffix.upper()} ONNX to {qdq_onnx_path}")
-
-        return model, res
+    @deprecated(version="3.2.0", reason="QATTrainer is deprecated and will be removed in future release, use Trainer " "class instead.")
+    def quantize_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tuple]:
+        return Trainer.quantize_from_config(cfg)
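
In effect, this commit reduces QATTrainer to a thin deprecation shim: the former train_from_config method, which ran the whole PTQ/QAT pipeline inline (recipe instantiation, calibration dataloader setup, SelectiveQuantizer, histogram calibration, optional QAT fine-tuning, ONNX export), is deleted, and the class now forwards to Trainer.quantize_from_config while warning the caller. Below is a minimal migration sketch, assuming a SuperGradients build that includes this change (3.2.0 or later); the recipe path is hypothetical, since no concrete recipe appears in the diff:

    from omegaconf import OmegaConf

    from super_gradients.training.sg_trainer import Trainer

    # Hypothetical recipe file; it must carry the keys the quantization pipeline
    # expects, e.g. quantization_params and checkpoint_params.checkpoint_path.
    cfg = OmegaConf.load("recipes/my_qat_recipe.yaml")

    # Old call path: still works, but the @deprecated decorator (from the
    # `deprecated` package) emits a DeprecationWarning at call time.
    #   model, res = QATTrainer.quantize_from_config(cfg)

    # New call path: call Trainer directly, which is what the shim does anyway.
    model, res = Trainer.quantize_from_config(cfg)

Keeping the shim preserves backward compatibility for code that still imports QATTrainer, while the warning text steers callers toward Trainer before the class is removed.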