Feature/sg 1060 yolo nas pose release pr to add datasets and metric #1506

Merged
3 changes: 3 additions & 0 deletions src/super_gradients/module_interfaces/__init__.py
@@ -1,5 +1,6 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException
from .pose_estimation_post_prediction_callback import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions

__all__ = [
"HasPredict",
@@ -8,4 +9,6 @@
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
"AbstractPoseEstimationPostPredictionCallback",
"PoseEstimationPredictions",
]
37 changes: 37 additions & 0 deletions src/super_gradients/module_interfaces/pose_estimation_post_prediction_callback.py
@@ -0,0 +1,37 @@
import abc
import dataclasses
import numpy as np

from typing import Any, List, Optional, Union
from torch import Tensor

__all__ = ["PoseEstimationPredictions", "AbstractPoseEstimationPostPredictionCallback"]


@dataclasses.dataclass
class PoseEstimationPredictions:
"""
A data class that encapsulates pose estimation predictions for a single image.

    :param poses: Array of shape [N, K, 3] where N is the number of poses and K is the number of joints.
                  The last dimension is [x, y, score], where score is the confidence score of the specific joint
                  in [0..1] range.
    :param scores: Array of shape [N] with confidence scores for each pose in [0..1] range.
    :param bboxes_xyxy: Array of shape [N, 4] with a bounding box for each pose in XYXY format.
                        Can be None if bounding boxes are not available (for instance, the DEKR model does not output boxes).
"""

poses: Union[Tensor, np.ndarray]
scores: Union[Tensor, np.ndarray]
bboxes_xyxy: Optional[Union[Tensor, np.ndarray]]


class AbstractPoseEstimationPostPredictionCallback(abc.ABC):
"""
A protocol interface of a post-prediction callback for pose estimation models.
"""

@abc.abstractmethod
def __call__(self, predictions: Any) -> List[PoseEstimationPredictions]:
...
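To illustrate how the two interfaces above fit together, here is a minimal sketch of a concrete post-prediction callback. The raw prediction format and the `ConfidenceThresholdPoseCallback` name are assumptions for illustration only; real decoders (e.g. for YOLO-NAS Pose) perform model-specific decoding and NMS.

```python
from typing import Any, List

from super_gradients.module_interfaces import AbstractPoseEstimationPostPredictionCallback, PoseEstimationPredictions


class ConfidenceThresholdPoseCallback(AbstractPoseEstimationPostPredictionCallback):
    """Hypothetical callback that keeps only poses above a fixed confidence threshold."""

    def __init__(self, min_confidence: float = 0.5):
        self.min_confidence = min_confidence

    def __call__(self, predictions: Any) -> List[PoseEstimationPredictions]:
        decoded: List[PoseEstimationPredictions] = []
        # Assumed raw format: an iterable of (poses, scores) pairs, one per image,
        # where poses is a [N, K, 3] tensor and scores is a [N] tensor.
        for poses, scores in predictions:
            keep = scores >= self.min_confidence
            decoded.append(
                PoseEstimationPredictions(
                    poses=poses[keep],
                    scores=scores[keep],
                    bboxes_xyxy=None,  # e.g. a DEKR-style model emits no boxes
                )
            )
        return decoded
```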
@@ -0,0 +1,72 @@
# This file is not a "true" dataset_params file, but rather a collection of settings that describe
# skeleton configuration specific to COCO dataset. It is used by other dataset_params files to
# avoid code duplication.

num_joints: 17

# OKS sigma values taken from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py#L523
oks_sigmas: [0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089]

flip_indexes: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]

edge_links:
- [0, 1]
- [0, 2]
- [1, 2]
- [1, 3]
- [2, 4]
- [3, 5]
- [4, 6]
- [5, 6]
- [5, 7]
- [5, 11]
- [6, 8]
- [6, 12]
- [7, 9]
- [8, 10]
- [11, 12]
- [11, 13]
- [12, 14]
- [13, 15]
- [14, 16]

edge_colors:
- [214, 39, 40] # Nose -> LeftEye
- [148, 103, 189] # Nose -> RightEye
- [44, 160, 44] # LeftEye -> RightEye
- [140, 86, 75] # LeftEye -> LeftEar
- [227, 119, 194] # RightEye -> RightEar
- [127, 127, 127] # LeftEar -> LeftShoulder
- [188, 189, 34] # RightEar -> RightShoulder
- [127, 127, 127] # Shoulders
- [188, 189, 34] # LeftShoulder -> LeftElbow
- [140, 86, 75] # LeftTorso
- [23, 190, 207] # RightShoulder -> RightElbow
- [227, 119, 194] # RightTorso
- [31, 119, 180] # LeftElbow -> LeftWrist
- [255, 127, 14] # RightElbow -> RightWrist
- [148, 103, 189] # Waist
- [255, 127, 14] # Left Hip -> Left Knee
- [214, 39, 40] # Right Hip -> Right Knee
- [31, 119, 180] # Left Knee -> Left Ankle
- [44, 160, 44] # Right Knee -> Right Ankle


keypoint_colors:
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
- [31, 119, 180]
- [148, 103, 189]
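For context: `oks_sigmas` are the per-joint constants used by COCO's OKS metric, and `flip_indexes` maps each joint to its left/right counterpart so horizontal-flip augmentation (or flip test-time augmentation) can swap joint channels after mirroring the coordinates. A minimal standalone sketch of how a flip transform might consume this table (not the transform shipped in SG):

```python
import numpy as np

# COCO flip mapping from the config above: joint i becomes FLIP_INDEXES[i]
FLIP_INDEXES = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]


def flip_keypoints_horizontally(joints: np.ndarray, image_width: int) -> np.ndarray:
    """Mirror keypoints of shape [N, 17, 3] ([x, y, visibility]) around the vertical axis."""
    flipped = joints.copy()
    flipped[:, :, 0] = image_width - 1 - flipped[:, :, 0]  # mirror x coordinates
    # Swap left/right joints so that, e.g., the left-eye channel still holds a left eye
    return flipped[:, FLIP_INDEXES, :]
```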
@@ -86,14 +86,12 @@ train_dataloader_params:
batch_size: 8
num_workers: 8
drop_last: True
collate_fn:
_target_: super_gradients.training.datasets.pose_estimation_datasets.KeypointsCollate
collate_fn: KeypointsCollate

val_dataloader_params:
batch_size: 24
num_workers: 8
drop_last: False
collate_fn:
_target_: super_gradients.training.datasets.pose_estimation_datasets.KeypointsCollate
collate_fn: KeypointsCollate

_convert_: all
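Note the simplification in this hunk: the explicit Hydra `_target_` block is replaced by the short registered name `KeypointsCollate`, which the recipe loader resolves through a name-to-class registry. A rough sketch of that pattern, using hypothetical `register_collate_function`/`resolve_collate_fn` helpers (the actual SG registry API may differ):

```python
from typing import Callable, Dict, Type

# Hypothetical registry mapping short names used in YAML recipes to collate classes
COLLATE_FUNCTIONS: Dict[str, Type] = {}


def register_collate_function(cls: Type) -> Type:
    """Class decorator that registers a collate class under its class name."""
    COLLATE_FUNCTIONS[cls.__name__] = cls
    return cls


def resolve_collate_fn(entry) -> Callable:
    """Resolve a `collate_fn` recipe entry: a registered name becomes an instance."""
    if isinstance(entry, str):
        return COLLATE_FUNCTIONS[entry]()  # e.g. "KeypointsCollate" -> KeypointsCollate()
    return entry  # already an instantiated object (e.g. from a Hydra _target_ block)
```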
10 changes: 7 additions & 3 deletions src/super_gradients/training/datasets/__init__.py
@@ -19,12 +19,14 @@
from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset, CityscapesConcatDataset
from super_gradients.training.datasets.segmentation_datasets.coco_segmentation import CoCoSegmentationDataSet
from super_gradients.training.datasets.segmentation_datasets.supervisely_persons_segmentation import SuperviselyPersonsDataset
from super_gradients.training.datasets.pose_estimation_datasets import COCOKeypointsDataset

cv2.setNumThreads(0)
from super_gradients.training.datasets.pose_estimation_datasets import (
COCOKeypointsDataset,
BaseKeypointsDataset,
)


__all__ = [
"BaseKeypointsDataset",
"DataAugmentation",
"ListDataset",
"DirectoryDataSet",
@@ -45,3 +47,5 @@
"SuperviselyPersonsDataset",
"COCOKeypointsDataset",
]

cv2.setNumThreads(0)
src/super_gradients/training/datasets/pose_estimation_datasets/__init__.py
@@ -2,4 +2,13 @@
from super_gradients.training.datasets.pose_estimation_datasets.base_keypoints import BaseKeypointsDataset, KeypointsCollate
from super_gradients.training.datasets.pose_estimation_datasets.target_generators import KeypointsTargetsGenerator, DEKRTargetsGenerator

__all__ = ["COCOKeypointsDataset", "BaseKeypointsDataset", "KeypointsCollate", "KeypointsTargetsGenerator", "DEKRTargetsGenerator"]
from .abstract_pose_estimation_dataset import AbstractPoseEstimationDataset

__all__ = [
"AbstractPoseEstimationDataset",
"COCOKeypointsDataset",
"BaseKeypointsDataset",
"KeypointsCollate",
"KeypointsTargetsGenerator",
"DEKRTargetsGenerator",
]
99 changes: 99 additions & 0 deletions src/super_gradients/training/datasets/pose_estimation_datasets/abstract_pose_estimation_dataset.py
@@ -0,0 +1,99 @@
import abc
import random
from typing import Tuple, List, Union

import numpy as np
from torch.utils.data import Dataset

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.object_names import Processings
from super_gradients.module_interfaces import HasPreprocessingParams
from super_gradients.training.samples import PoseEstimationSample
from super_gradients.training.transforms.keypoint_transforms import KeypointsCompose, AbstractKeypointTransform
from super_gradients.training.utils.visualization.utils import generate_color_mapping

logger = get_logger(__name__)


class AbstractPoseEstimationDataset(Dataset, HasPreprocessingParams):
"""
Abstract class for strongly typed dataset classes for pose estimation task.
    This new concept was introduced in SG 3.3 and will be used in the future to replace the old BaseKeypointsDataset.
    The reasoning behind a strongly typed dataset includes:
    1. Introduction of a new concept of a "data sample" that has a clear definition (via @dataclass), thus reducing the chance of bugs/confusion.
    2. The data sample becomes a central concept in data augmentation transforms and metrics.
    3. The dataset implementation is decoupled from the model & loss - now the dataset returns data sample objects,
       and model/loss-specific conversion happens only in the collate function.

    Descendants should implement the load_sample method to read a sample from disk and return a PoseEstimationSample object.
"""

def __init__(
self,
transforms: List[AbstractKeypointTransform],
num_joints: int,
edge_links: Union[List[Tuple[int, int]], np.ndarray],
edge_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
keypoint_colors: Union[List[Tuple[int, int, int]], np.ndarray, None],
):
"""

:param transforms: Transforms to be applied to the image & keypoints
:param num_joints: Number of joints to be predicted
:param edge_links: Edge links between joints
:param edge_colors: Color of the edge links. If None, the color will be generated randomly.
:param keypoint_colors: Color of the keypoints. If None, the color will be generated randomly.
"""
super().__init__()
self.transforms = KeypointsCompose(
transforms,
load_sample_fn=self.load_random_sample,
)
self.num_joints = num_joints
self.edge_links = edge_links
self.edge_colors = edge_colors or generate_color_mapping(len(edge_links))
self.keypoint_colors = keypoint_colors or generate_color_mapping(num_joints)

@abc.abstractmethod
def __len__(self) -> int:
raise NotImplementedError()

@abc.abstractmethod
def load_sample(self, index: int) -> PoseEstimationSample:
"""
Read a sample from the disk and return a PoseEstimationSample
:param index: Sample index
        :return: An instance of PoseEstimationSample that holds the complete sample (image and annotations)
"""
raise NotImplementedError()

def load_random_sample(self) -> PoseEstimationSample:
"""
Return a random sample from the dataset

:return: Instance of PoseEstimationSample
"""
num_samples = len(self)
random_index = random.randrange(0, num_samples)
return self.load_sample(random_index)

def __getitem__(self, index: int) -> PoseEstimationSample:
sample = self.load_sample(index)
sample = self.transforms.apply_to_sample(sample)
return sample

def get_dataset_preprocessing_params(self):
"""

:return:
"""
image_to_tensor = {Processings.ImagePermute: {"permutation": (2, 0, 1)}}
pipeline = self.transforms.get_equivalent_preprocessing() + [image_to_tensor]
params = dict(
conf=0.05,
image_processor={Processings.ComposeProcessing: {"processings": pipeline}},
edge_links=self.edge_links,
edge_colors=self.edge_colors,
keypoint_colors=self.keypoint_colors,
)
return params
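A concrete dataset therefore only needs `__len__` and `load_sample`; transform application, random-sample loading (used by mixing augmentations such as mosaic or mixup) and preprocessing-params export come for free. A hypothetical minimal subclass (the `PoseEstimationSample` field names below are assumptions, not shown in this diff):

```python
import cv2
import numpy as np

from super_gradients.training.datasets.pose_estimation_datasets import AbstractPoseEstimationDataset
from super_gradients.training.samples import PoseEstimationSample


class InMemoryPoseDataset(AbstractPoseEstimationDataset):
    """Hypothetical dataset over pre-parsed annotations held in memory."""

    def __init__(self, image_paths, joints_per_image, transforms, num_joints, edge_links, edge_colors, keypoint_colors):
        super().__init__(
            transforms=transforms,
            num_joints=num_joints,
            edge_links=edge_links,
            edge_colors=edge_colors,
            keypoint_colors=keypoint_colors,
        )
        self.image_paths = image_paths
        self.joints_per_image = joints_per_image  # one [N, num_joints, 3] array per image

    def __len__(self) -> int:
        return len(self.image_paths)

    def load_sample(self, index: int) -> PoseEstimationSample:
        image = cv2.imread(self.image_paths[index])
        joints = self.joints_per_image[index]
        # Field names below are assumed for illustration; check the PoseEstimationSample
        # dataclass for the exact signature.
        return PoseEstimationSample(
            image=image,
            mask=np.ones(image.shape[:2], dtype=np.float32),  # no ignored regions
            joints=joints,
            areas=None,
            bboxes_xywh=None,
            is_crowd=None,
            additional_samples=None,
        )
```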
src/super_gradients/training/datasets/pose_estimation_datasets/base_keypoints.py
@@ -45,10 +45,7 @@ def __init__(
"""
super().__init__()
self.target_generator = target_generator
self.transforms = KeypointsCompose(
transforms,
load_sample_fn=None,
)
self.transforms = KeypointsCompose(transforms)
self.min_instance_area = min_instance_area
self.num_joints = num_joints
self.edge_links = edge_links