Skip to content

Commit 8f959d3

Browse files
authored
Merge branch 'master' into master
2 parents baf5765 + 15c0fe1 commit 8f959d3

File tree

4 files changed

+110
-54
lines changed

4 files changed

+110
-54
lines changed

documentation/source/qat_ptq_yolo_nas.md

+20-20
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,27 @@ Now, let's get to it.
2020

2121
## Step 0: Installations and Dataset Setup
2222

23-
Follow the setup instructions for RF100:
23+
Follow the [official instructions](https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog) to download Roboflow100:
24+
25+
To use this dataset, you **must** download the "coco" format, **NOT** the "yolov5" format.
26+
2427
```
25-
- Follow the official instructions to download Roboflow100: https://github.com/roboflow/roboflow-100-benchmark?ref=roboflow-blog
26-
//!\\ To use this dataset, you must download the "coco" format, NOT the yolov5.
27-
28-
- Your dataset should look like this:
29-
rf100
30-
├── 4-fold-defect
31-
│ ├─ train
32-
│ │ ├─ 000000000001.jpg
33-
│ │ ├─ ...
34-
│ │ └─ _annotations.coco.json
35-
│ ├─ valid
36-
│ │ └─ ...
37-
│ └─ test
38-
│ └─ ...
39-
├── abdomen-mri
40-
│ └─ ...
41-
└── ...
42-
43-
- Install CoCo API: https://github.com/pdollar/coco/tree/master/PythonAPI
28+
- Your dataset should look like this:
29+
rf100
30+
├── 4-fold-defect
31+
│ ├─ train
32+
│ │ ├─ 000000000001.jpg
33+
│ │ ├─ ...
34+
│ │ └─ _annotations.coco.json
35+
│ ├─ valid
36+
│ │ └─ ...
37+
│ └─ test
38+
│ └─ ...
39+
├── abdomen-mri
40+
│ └─ ...
41+
└── ...
42+
43+
- Install the COCO API: https://github.com/pdollar/coco/tree/master/PythonAPI
4444
```
4545

4646
Install the latest version of SG:

src/super_gradients/training/models/model_factory.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ def instantiate_model(
153153
net = architecture_cls(arch_params=arch_params)
154154

155155
if pretrained_weights:
156+
# The logic is as follows - first we initialize the preprocessing params using default hard-coded params
157+
# If pretrained checkpoint contains preprocessing params, new params will be loaded and override the ones from
158+
# this step in load_pretrained_weights_local/load_pretrained_weights
159+
if isinstance(net, HasPredict):
160+
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
161+
net.set_dataset_processing_params(**processing_params)
162+
156163
if is_remote and pretrained_weights_path:
157164
load_pretrained_weights_local(net, model_name, pretrained_weights_path)
158165
else:
@@ -162,11 +169,6 @@ def instantiate_model(
162169
net.replace_head(new_num_classes=num_classes_new_head)
163170
arch_params.num_classes = num_classes_new_head
164171

165-
# STILL NEED TO GET PREPROCESSING PARAMS IN CASE CHECKPOINT HAS NO RECIPE
166-
if isinstance(net, HasPredict):
167-
processing_params = get_pretrained_processing_params(model_name, pretrained_weights)
168-
net.set_dataset_processing_params(**processing_params)
169-
170172
_add_model_name_attribute(net, model_name)
171173

172174
return net

src/super_gradients/training/utils/checkpoint_utils.py

+57-26
Original file line numberDiff line numberDiff line change
@@ -1517,16 +1517,7 @@ def load_checkpoint_to_model(
15171517
message_model = "model" if not load_backbone else "model's backbone"
15181518
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
15191519

1520-
if (isinstance(net, HasPredict)) and load_processing_params:
1521-
if "processing_params" not in checkpoint.keys():
1522-
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
1523-
try:
1524-
net.set_dataset_processing_params(**checkpoint["processing_params"])
1525-
except Exception as e:
1526-
logger.warning(
1527-
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
1528-
"predict make sure to call set_dataset_processing_params."
1529-
)
1520+
_maybe_load_preprocessing_params(net, checkpoint)
15301521

15311522
if load_weights_only or load_backbone:
15321523
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
@@ -1549,10 +1540,12 @@ def __init__(self, desc):
15491540
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
15501541
"""
15511542
Loads pretrained weights from the MODEL_URLS dictionary to model
1552-
:param architecture: name of the model's architecture
1553-
:param model: model to load pretrinaed weights for
1554-
:param pretrained_weights: name for the pretrianed weights (i.e imagenet)
1555-
:return: None
1543+
1544+
:param architecture: name of the model's architecture
1545+
:param model: model to load pretrained weights for
1546+
:param pretrained_weights: name for the pretrained weights (e.g. imagenet)
1547+
1548+
:return: None
15561549
"""
15571550
from super_gradients.common.object_names import Models
15581551

@@ -1569,19 +1562,19 @@ def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretraine
15691562
"By downloading the pre-trained weight files you agree to comply with these terms."
15701563
)
15711564

1572-
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
1573-
map_location = torch.device("cpu")
1574-
with wait_for_the_master(get_local_rank()):
1575-
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
1576-
_load_weights(architecture, model, pretrained_state_dict)
1577-
1565+
# Basically this check allows setting pretrained weights from a local path using the file:///path/to/weights scheme
1566+
# which is a valid URI scheme for local files
1567+
# Supporting local files and file URIs allows us to modify the pretrained weights dicts in unit tests
1568+
if url.startswith("file://") or os.path.exists(url):
1569+
pretrained_state_dict = torch.load(url.replace("file://", ""), map_location="cpu")
1570+
else:
1571+
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
1572+
map_location = torch.device("cpu")
1573+
with wait_for_the_master(get_local_rank()):
1574+
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
15781575

1579-
def _load_weights(architecture, model, pretrained_state_dict):
1580-
if "ema_net" in pretrained_state_dict.keys():
1581-
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
1582-
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
1583-
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
1584-
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
1576+
_load_weights(architecture, model, pretrained_state_dict)
1577+
_maybe_load_preprocessing_params(model, pretrained_state_dict)
15851578

15861579

15871580
def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
@@ -1598,3 +1591,41 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre
15981591

15991592
pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
16001593
_load_weights(architecture, model, pretrained_state_dict)
1594+
_maybe_load_preprocessing_params(model, pretrained_state_dict)
1595+
1596+
1597+
def _load_weights(architecture, model, pretrained_state_dict):
1598+
if "ema_net" in pretrained_state_dict.keys():
1599+
pretrained_state_dict["net"] = pretrained_state_dict["ema_net"]
1600+
solver = YoloXCheckpointSolver() if "yolox" in architecture else DefaultCheckpointSolver()
1601+
adaptive_load_state_dict(net=model, state_dict=pretrained_state_dict, strict=StrictLoad.NO_KEY_MATCHING, solver=solver)
1602+
logger.info(f"Successfully loaded pretrained weights for architecture {architecture}")
1603+
1604+
1605+
def _maybe_load_preprocessing_params(model: Union[nn.Module, HasPredict], checkpoint: Mapping[str, Tensor]) -> bool:
1606+
"""
1607+
Tries to load preprocessing params from the checkpoint to the model.
1608+
The function does not crash; it logs a warning if the loading fails.
1609+
:param model: Instance of nn.Module
1610+
:param checkpoint: Entire checkpoint dict (not state_dict with model weights)
1611+
:return: True if the loading was successful, False otherwise.
1612+
"""
1613+
model = unwrap_model(model)
1614+
checkpoint_has_preprocessing_params = "processing_params" in checkpoint.keys()
1615+
model_has_predict = isinstance(model, HasPredict)
1616+
logger.debug(
1617+
f"Trying to load preprocessing params from checkpoint. Preprocessing params in checkpoint: {checkpoint_has_preprocessing_params}. "
1618+
f"Model {model.__class__.__name__} inherit HasPredict: {model_has_predict}"
1619+
)
1620+
1621+
if model_has_predict and checkpoint_has_preprocessing_params:
1622+
try:
1623+
model.set_dataset_processing_params(**checkpoint["processing_params"])
1624+
logger.debug(f"Successfully loaded preprocessing params from checkpoint {checkpoint['processing_params']}")
1625+
return True
1626+
except Exception as e:
1627+
logger.warning(
1628+
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
1629+
"predict make sure to call set_dataset_processing_params."
1630+
)
1631+
return False

tests/unit_tests/pretrained_models_unit_test.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
1+
import os
2+
import shutil
3+
import tempfile
14
import unittest
5+
6+
import numpy as np
7+
import torch
8+
29
import super_gradients
310
from super_gradients.common.object_names import Models
4-
from super_gradients.training import models
511
from super_gradients.training import Trainer
12+
from super_gradients.training import models
613
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
714
from super_gradients.training.metrics import Accuracy
8-
import os
9-
import shutil
15+
from super_gradients.training.pretrained_models import MODEL_URLS, PRETRAINED_NUM_CLASSES
16+
from super_gradients.training.processing.processing import default_yolo_nas_coco_processing_params
1017

1118

1219
class PretrainedModelsUnitTest(unittest.TestCase):
@@ -29,6 +36,22 @@ def test_pretrained_repvgg_a0_imagenet(self):
2936
model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
3037
trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)
3138

39+
def test_pretrained_models_load_preprocessing_params(self):
40+
"""
41+
Test that checks whether preprocessing params from pretrained model load correctly.
42+
"""
43+
state = {"net": models.get(Models.YOLO_NAS_S, num_classes=80).state_dict(), "processing_params": default_yolo_nas_coco_processing_params()}
44+
with tempfile.TemporaryDirectory() as td:
45+
checkpoint_path = os.path.join(td, "yolo_nas_s_coco.pth")
46+
torch.save(state, checkpoint_path)
47+
48+
MODEL_URLS[Models.YOLO_NAS_S + "_test"] = checkpoint_path
49+
PRETRAINED_NUM_CLASSES["test"] = 80
50+
51+
model = models.get(Models.YOLO_NAS_S, pretrained_weights="test")
52+
# .predict() would fail if the model has no preprocessing params
53+
self.assertIsNotNone(model.predict(np.zeros(shape=(512, 512, 3), dtype=np.uint8)))
54+
3255
def tearDown(self) -> None:
3356
if os.path.exists("~/.cache/torch/hub/"):
3457
shutil.rmtree("~/.cache/torch/hub/")

0 commit comments

Comments
 (0)