Commit ab2e792

Fix DEKR's replace_head & improve __repr__ for keypoints transforms (#1227)
* YOLO-NAS Pose Estimation Experiment
* Added logging
* Remove test
* Tune recipe for L
* Tune recipe for S
* Tune optimizer params
* Lower LR
* Added __repr__ for keypoint transforms to improve their printing in notebooks
* If OKS sigmas are not given explicitly, initialize with the default 17-keypoint values only if num_joints == 17; otherwise use fallback values and emit a warning
* Added recipe to train rescoring for yolo_nas_pose_l
* Added YOLO-NAS-POSE scores
* Added YOLO-NAS-POSE-M recipe
* Fine-tuning notebook for pose estimation
* Added links to pretrained models
* Remove links to S & M models
* Update notebook
* Fix apply_sigmoid=True
* Increased eps to prevent division by zero
* Update notebook
* Adding predict
* Adding joint information to dataset configs
* Added makefile target recipe_accuracy_tests
* Remove temp files
* Rename variables for better clarity
* Move predict() related files to super_gradients.training.utils.predict
* Rename file poses.py -> pose_estimation.py
* Rename joint_colors/joint_links -> edge_colors/edge_links
* Disable showing bounding box by default
* Allow passing edge & keypoint colors as None, in which case colors are generated randomly
* Update docstrings
* Fix test
* Added unit tests to verify that setting preprocessing params from dataset works
* Added predict
* Added default preprocessing settings for yolo-nas-pose
* Added __repr__ to KeypointsImageToTensor
* _pad_image cannot work with a pad_value that is a tuple (r, g, b), so change the keypoint transform defaults in config to use a single scalar value
* Fix pad_value in keypoint transforms to accept a single scalar value, making it compatible with _pad_image
* Update signature of base YoloNasPose class (dropped arch_params)
* Simplify recipes to train YOLO-NAS-POSE
* Implement replace_head for DEKR
* Make the __repr__ implementation more readable
* Change .format to string interpolation
1 parent 41d455f commit ab2e792

File tree

4 files changed: +82 -10 lines changed


src/super_gradients/training/metrics/pose_estimation_metrics.py

+8 -1

@@ -103,7 +103,14 @@ def __init__(
         self.greater_component_is_better = dict((k, True) for k in self.stats_names)
 
         if oks_sigmas is None:
-            oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
+            if num_joints == 17:
+                oks_sigmas = np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) / 10.0
+            else:
+                oks_sigmas = np.array([0.1] * num_joints)
+                logger.warning(
+                    f"Using default OKS sigmas of `0.1` for a custom dataset with {num_joints} joints. "
+                    f"To silence this warning, you may want to specify OKS sigmas explicitly as it has direct impact on the AP score."
+                )
 
         if len(oks_sigmas) != num_joints:
             raise ValueError(f"Length of oks_sigmas ({len(oks_sigmas)}) should be equal to num_joints {num_joints}")
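The warning matters because the sigmas enter the metric directly: in the standard COCO OKS formulation, each keypoint's similarity decays exponentially with its squared distance to the ground truth, scaled by the object area and a per-joint sigma, so a blanket 0.1 is only a rough stand-in for tuned values. A minimal sketch of that formula in plain NumPy (an illustrative helper, not the library's evaluation code):

    import numpy as np

    def object_keypoint_similarity(pred_xy, gt_xy, visibility, area, sigmas):
        # Squared pixel distances between predicted and ground-truth joints.
        d2 = np.sum((pred_xy - gt_xy) ** 2, axis=-1)
        # COCO uses k = 2 * sigma; a larger sigma makes a joint more forgiving.
        k2 = (2.0 * sigmas) ** 2
        per_joint = np.exp(-d2 / (2.0 * area * k2 + np.spacing(1)))
        labeled = visibility > 0
        # Average similarity over labeled joints only.
        return per_joint[labeled].sum() / max(labeled.sum(), 1)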

src/super_gradients/training/models/pose_estimation_models/dekr_hrnet.py

+19 -9

@@ -328,22 +328,32 @@ def __init__(self, arch_params):
             setattr(self, "stage{}".format(i + 2), stage)
 
         # build head net
-        inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1]))
-        config_heatmap = self.spec.HEAD_HEATMAP
-        config_offset = self.spec.HEAD_OFFSET
+        self.head_inp_channels = int(sum(self.stages_spec.NUM_CHANNELS[-1]))
+        self.config_heatmap = self.spec.HEAD_HEATMAP
+        self.config_offset = self.spec.HEAD_OFFSET
         self.num_joints = arch_params.num_classes
         self.num_offset = self.num_joints * 2
         self.num_joints_with_center = self.num_joints + 1
-        self.offset_prekpt = config_offset["NUM_CHANNELS_PERKPT"]
+        self.offset_prekpt = self.config_offset["NUM_CHANNELS_PERKPT"]
 
         offset_channels = self.num_joints * self.offset_prekpt
-        self.transition_heatmap = self._make_transition_for_head(inp_channels, config_heatmap["NUM_CHANNELS"])
-        self.transition_offset = self._make_transition_for_head(inp_channels, offset_channels)
-        self.head_heatmap = self._make_heatmap_head(config_heatmap)
-        self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(config_offset)
-        self.heatmap_activation = nn.Sigmoid() if config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity()
+        self.transition_heatmap = self._make_transition_for_head(self.head_inp_channels, self.config_heatmap["NUM_CHANNELS"])
+        self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels)
+        self.head_heatmap = self._make_heatmap_head(self.config_heatmap)
+        self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset)
+        self.heatmap_activation = nn.Sigmoid() if self.config_heatmap["HEATMAP_APPLY_SIGMOID"] else nn.Identity()
         self.init_weights()
 
+    def replace_head(self, new_num_classes: int):
+        self.num_joints = new_num_classes
+        self.num_offset = new_num_classes * 2
+        self.num_joints_with_center = new_num_classes + 1
+
+        offset_channels = self.num_joints * self.offset_prekpt
+        self.head_heatmap = self._make_heatmap_head(self.config_heatmap)
+        self.transition_offset = self._make_transition_for_head(self.head_inp_channels, offset_channels)
+        self.offset_feature_layers, self.offset_final_layer = self._make_separete_regression_head(self.config_offset)
+
     def _make_transition_for_head(self, inplanes: int, outplanes: int) -> nn.Module:
         transition_layer = [nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False), nn.BatchNorm2d(outplanes), nn.ReLU(True)]
         return nn.Sequential(*transition_layer)
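Since the head configs and input channel count are now stored on self, replace_head can rebuild the heatmap and offset heads for a new number of joints while leaving the backbone untouched. A minimal usage sketch, mirroring the new unit test added below (20 joints is an arbitrary example):

    import torch
    from super_gradients.common.object_names import Models
    from super_gradients.training import models

    # Requesting a num_classes different from the 17 COCO joints triggers
    # replace_head(): the heatmap/offset heads are rebuilt for 20 joints,
    # while the COCO-pretrained backbone weights are kept.
    model = models.get(Models.DEKR_W32_NO_DC, num_classes=20, pretrained_weights="coco_pose").eval()

    heatmap, offsets = model.forward(torch.randn(1, 3, 640, 640))
    assert heatmap.size(1) == 20 + 1  # one heatmap per joint, plus the center heatmap
    assert offsets.size(1) == 20 * 2  # an (x, y) offset per joint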

src/super_gradients/training/transforms/keypoint_transforms.py

+48
@@ -71,6 +71,13 @@ def get_equivalent_preprocessing(self) -> List:
             preprocessing += t.get_equivalent_preprocessing()
         return preprocessing
 
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += f"\t{repr(t)}"
+        format_string += "\n)"
+        return format_string
+
 
 @register_transform(Transforms.KeypointsImageToTensor)
 class KeypointsImageToTensor(KeypointTransform):
@@ -87,6 +94,9 @@ def get_equivalent_preprocessing(self) -> List:
             {Processings.ImagePermute: {"permutation": (2, 0, 1)}},
         ]
 
+    def __repr__(self):
+        return self.__class__.__name__ + f"(permutation={self.permutation})"
+
 
 @register_transform(Transforms.KeypointsImageStandardize)
 class KeypointsImageStandardize(KeypointTransform):
@@ -107,6 +117,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, areas
     def get_equivalent_preprocessing(self) -> List[Dict]:
         return [{Processings.StandardizeImage: {"max_value": self.max_value}}]
 
+    def __repr__(self):
+        return self.__class__.__name__ + f"(max_value={self.max_value})"
+
 
 @register_transform(Transforms.KeypointsImageNormalize)
 class KeypointsImageNormalize(KeypointTransform):
@@ -122,6 +135,9 @@ def __call__(self, image: np.ndarray, mask: np.ndarray, joints: np.ndarray, areas
         image = (image - self.mean) / self.std
         return image, mask, joints, areas, bboxes
 
+    def __repr__(self):
+        return self.__class__.__name__ + f"(mean={self.mean}, std={self.std})"
+
     def get_equivalent_preprocessing(self) -> List:
         return [{Processings.NormalizeImage: {"mean": self.mean, "std": self.std}}]
@@ -143,6 +159,9 @@ def __init__(self, flip_index: List[int], prob: float = 0.5):
         self.flip_index = flip_index
         self.prob = prob
 
+    def __repr__(self):
+        return self.__class__.__name__ + f"(flip_index={self.flip_index}, prob={self.prob})"
+
     def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Optional[np.ndarray]):
         if image.shape[:2] != mask.shape[:2]:
             raise RuntimeError(f"Image shape ({image.shape[:2]}) does not match mask shape ({mask.shape[:2]}).")
@@ -218,6 +237,9 @@ def apply_to_bboxes(self, bboxes, rows):
     def get_equivalent_preprocessing(self) -> List:
         raise RuntimeError("KeypointsRandomHorizontalFlip does not have equivalent preprocessing.")
 
+    def __repr__(self):
+        return self.__class__.__name__ + f"(prob={self.prob})"
+
 
 @register_transform(Transforms.KeypointsLongestMaxSize)
 class KeypointsLongestMaxSize(KeypointTransform):
@@ -278,6 +300,13 @@ def apply_to_keypoints(cls, keypoints, scale):
     def apply_to_bboxes(cls, bboxes, scale):
         return bboxes * scale
 
+    def __repr__(self):
+        return (
+            self.__class__.__name__ + f"(max_height={self.max_height}, "
+            f"max_width={self.max_width}, "
+            f"interpolation={self.interpolation}, prob={self.prob})"
+        )
+
     def get_equivalent_preprocessing(self) -> List:
         return [{Processings.KeypointsLongestMaxSizeRescale: {"output_shape": (self.max_height, self.max_width)}}]
@@ -318,6 +347,14 @@ def __call__(self, image, mask, joints, areas: Optional[np.ndarray], bboxes: Opt
 
         return image, mask, joints, areas, bboxes
 
+    def __repr__(self):
+        return (
+            self.__class__.__name__ + f"(min_height={self.min_height}, "
+            f"min_width={self.min_width}, "
+            f"image_pad_value={self.image_pad_value}, "
+            f"mask_pad_value={self.mask_pad_value})"
+        )
+
     def get_equivalent_preprocessing(self) -> List:
         return [{Processings.KeypointsBottomRightPadding: {"output_shape": (self.min_height, self.min_width), "pad_value": self.image_pad_value}}]
@@ -353,6 +390,17 @@ def __init__(
         self.mask_pad_value = mask_pad_value
         self.prob = prob
 
+    def __repr__(self):
+        return (
+            self.__class__.__name__ + f"(max_rotation={self.max_rotation}, "
+            f"min_scale={self.min_scale}, "
+            f"max_scale={self.max_scale}, "
+            f"max_translate={self.max_translate}, "
+            f"image_pad_value={self.image_pad_value}, "
+            f"mask_pad_value={self.mask_pad_value}, "
+            f"prob={self.prob})"
+        )
+
     def _get_affine_matrix(self, img, angle, scale, dx, dy):
         """
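With these __repr__ methods, printing a transform in a notebook now shows its parameters instead of the default <KeypointsImageStandardize object at 0x...> form. A small illustration, assuming the constructors accept the same arguments their __repr__ prints:

    from super_gradients.training.transforms.keypoint_transforms import (
        KeypointsImageNormalize,
        KeypointsImageStandardize,
    )

    print(KeypointsImageStandardize(max_value=255.0))
    # KeypointsImageStandardize(max_value=255.0)

    print(KeypointsImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    # KeypointsImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])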

tests/unit_tests/replace_head_test.py

+7
@@ -30,6 +30,13 @@ def test_yolo_nas_replace_head(self):
         (_, pred_scores), _ = model.forward(input)
         self.assertEqual(pred_scores.size(2), 100)
 
+    def test_dekr_replace_head(self):
+        input = torch.randn(1, 3, 640, 640).to(self.device)
+        model = models.get(Models.DEKR_W32_NO_DC, num_classes=20, pretrained_weights="coco_pose").to(self.device).eval()
+        heatmap, offsets = model.forward(input)
+        self.assertEqual(heatmap.size(1), 20 + 1)
+        self.assertEqual(offsets.size(1), 20 * 2)
+
     def tearDown(self) -> None:
         if os.path.exists("~/.cache/torch/hub/"):
             shutil.rmtree("~/.cache/torch/hub/")
