Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add support for mask-free keypoints conversion #142

Merged
merged 6 commits into from
Apr 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions mmhuman3d/core/conventions/keypoints_mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def convert_kps(
approximate: bool = False,
mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
keypoints_factory: dict = KEYPOINTS_FACTORY,
return_mask: bool = True
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
"""Convert keypoints following the mapping correspondence between src and
dst keypoints definition. Supported conventions by now: agora, coco, smplx,
Expand All @@ -83,14 +84,22 @@ def convert_kps(
Defaults to None.
keypoints_factory (dict, optional): A class to store the attributes.
Defaults to keypoints_factory.
return_mask (bool, optional): whether to return a mask as part of the
output. It is unnecessary to return a mask if the keypoints consist
of confidence. Any invalid keypoints will have zero confidence.
Defaults to True.
Returns:
Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]
: tuple of (out_keypoints, mask). out_keypoints and mask will be of
the same type.
"""
assert keypoints.ndim in {3, 4}
if src == dst:
return keypoints, np.ones((keypoints.shape[-2]))
if return_mask:
return keypoints, np.ones((keypoints.shape[-2]))
else:
return keypoints

src_names = keypoints_factory[src.lower()]
dst_names = keypoints_factory[dst.lower()]
extra_dims = keypoints.shape[:-2]
Expand Down Expand Up @@ -129,7 +138,10 @@ def convert_kps(
mask[dst_idxs] = original_mask[src_idxs] \
if original_mask is not None else 1.0

return out_keypoints, mask
if return_mask:
return out_keypoints, mask
else:
return out_keypoints


def compress_converted_kps(
Expand Down
50 changes: 50 additions & 0 deletions mmhuman3d/data/data_structures/human_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,56 @@ def __check_value_len__(self, key: Any, val: Any) -> bool:
msg=err_msg, logger=self.__class__.logger, level=logging.ERROR)
return ret_bool

def generate_mask_from_confidence(self, keys=None) -> None:
"""Generate mask from keypoints' confidence. Keypoints that have zero
confidence in all occurrences will have a zero mask. Note that the last
value of the keypoint is assumed to be confidence.

Args:
keys: None, str, or list of str.
None: all keys with `keypoint` in it will have mask
generated from their confidence.
str: key of the keypoint, the mask has name f'{key}_name'
list of str: a list of keys of the keypoints.
Generate mask for multiple keypoints.
Defaults to None.

Returns:
None

Raises:
KeyError:
A key is not not found
"""
if keys is None:
keys = []
for key in self.keys():
val = self.get_raw_value(key)
if isinstance(val, np.ndarray) and \
'keypoints' in key and \
'_mask' not in key:
keys.append(key)
elif isinstance(keys, str):
keys = [keys]
elif isinstance(keys, list):
for key in keys:
assert isinstance(key, str)
else:
raise TypeError(f'`Keys` must be None, str, or list of str, '
f'got {type(keys)}.')

update_dict = {}
for kpt_key in keys:
kpt_array = self.get_raw_value(kpt_key)
num_joints = kpt_array.shape[-2]
# if all conf of a joint are zero, this joint is masked
joint_conf = kpt_array[..., -1].reshape(-1, num_joints)
mask_array = (joint_conf > 0).astype(np.uint8).max(axis=0)
assert len(mask_array) == num_joints
# generate mask
update_dict[f'{kpt_key}_mask'] = mask_array
self.update(update_dict)

def compress_keypoints_by_mask(self) -> None:
"""If a key contains 'keypoints', and f'{key}_mask' is in self.keys(),
invalid zeros will be removed and f'{key}_mask' will be locked.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,16 @@ def test_conventions():
np.zeros((f, n_person, J, 3)),
np.zeros((f, n_person, J, 2))
]:
# with mask
keypoints_dst, mask = convert_kps(keypoints, src_name,
dst_name)

# without mask
keypoints_dst_wo_mask = convert_kps(
keypoints, src_name, dst_name, return_mask=False)

assert np.all(keypoints_dst == keypoints_dst_wo_mask)

exp_shape = list(keypoints.shape)
exp_shape[-2] = J_dst
assert keypoints_dst.shape == tuple(exp_shape)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_human_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,43 @@ def test_compression():
human_data.decompress_keypoints()


def test_generate_mask_from_keypoints():

human_data = HumanData.new(key_strict=False)
keypoints2d = np.random.rand(3, 144, 3)
keypoints3d = np.random.rand(3, 144, 4)

# set confidence
keypoints2d[:, :72, -1] = 0
keypoints3d[:, 72:, -1] = 0
human_data['keypoints2d'] = keypoints2d
human_data['keypoints3d'] = keypoints3d

# test all keys
with pytest.raises(KeyError):
human_data['keypoints2d_mask']
with pytest.raises(KeyError):
human_data['keypoints3d_mask']
human_data.generate_mask_from_confidence()
assert 'keypoints2d_mask' in human_data
assert (human_data['keypoints2d_mask'][:72] == 0).all()
assert (human_data['keypoints2d_mask'][72:] == 1).all()
assert 'keypoints3d_mask' in human_data
assert (human_data['keypoints3d_mask'][72:] == 0).all()
assert (human_data['keypoints3d_mask'][:72] == 1).all()

# test str keys
human_data.generate_mask_from_confidence(keys='keypoints2d')

# test list of str keys
human_data.generate_mask_from_confidence(
keys=['keypoints2d', 'keypoints3d'])

# test compression with generated mask
human_data.compress_keypoints_by_mask()
assert human_data.check_keypoints_compressed() is True


def test_pop_unsupported_items():
# directly pop them
human_data = HumanData.new(key_strict=False)
Expand Down