Skip to content

Commit

Permalink
[Feature] Add support for mask-free keypoints conversion (#142)
Browse files Browse the repository at this point in the history
*  Add generate_mask_from_confidence in HumanData

 * Add return_mask switch in convert_kps that allow mask to be not returned
  • Loading branch information
caizhongang authored Apr 1, 2022
1 parent b683420 commit 5808b11
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
16 changes: 14 additions & 2 deletions mmhuman3d/core/conventions/keypoints_mapping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def convert_kps(
approximate: bool = False,
mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
keypoints_factory: dict = KEYPOINTS_FACTORY,
return_mask: bool = True
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
"""Convert keypoints following the mapping correspondence between src and
dst keypoints definition. Supported conventions by now: agora, coco, smplx,
Expand All @@ -84,14 +85,22 @@ def convert_kps(
Defaults to None.
keypoints_factory (dict, optional): A class to store the attributes.
Defaults to keypoints_factory.
return_mask (bool, optional): whether to return a mask as part of the
output. It is unnecessary to return a mask if the keypoints consist
of confidence. Any invalid keypoints will have zero confidence.
Defaults to True.
Returns:
Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]
: tuple of (out_keypoints, mask). out_keypoints and mask will be of
the same type.
"""
assert keypoints.ndim in {3, 4}
if src == dst:
return keypoints, np.ones((keypoints.shape[-2]))
if return_mask:
return keypoints, np.ones((keypoints.shape[-2]))
else:
return keypoints

src_names = keypoints_factory[src.lower()]
dst_names = keypoints_factory[dst.lower()]
extra_dims = keypoints.shape[:-2]
Expand Down Expand Up @@ -130,7 +139,10 @@ def convert_kps(
mask[dst_idxs] = original_mask[src_idxs] \
if original_mask is not None else 1.0

return out_keypoints, mask
if return_mask:
return out_keypoints, mask
else:
return out_keypoints


def compress_converted_kps(
Expand Down
50 changes: 50 additions & 0 deletions mmhuman3d/data/data_structures/human_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,56 @@ def __check_value_len__(self, key: Any, val: Any) -> bool:
msg=err_msg, logger=self.__class__.logger, level=logging.ERROR)
return ret_bool

def generate_mask_from_confidence(self, keys=None) -> None:
"""Generate mask from keypoints' confidence. Keypoints that have zero
confidence in all occurrences will have a zero mask. Note that the last
value of the keypoint is assumed to be confidence.
Args:
keys: None, str, or list of str.
None: all keys with `keypoint` in it will have mask
generated from their confidence.
str: key of the keypoint, the mask has name f'{key}_name'
list of str: a list of keys of the keypoints.
Generate mask for multiple keypoints.
Defaults to None.
Returns:
None
Raises:
KeyError:
A key is not not found
"""
if keys is None:
keys = []
for key in self.keys():
val = self.get_raw_value(key)
if isinstance(val, np.ndarray) and \
'keypoints' in key and \
'_mask' not in key:
keys.append(key)
elif isinstance(keys, str):
keys = [keys]
elif isinstance(keys, list):
for key in keys:
assert isinstance(key, str)
else:
raise TypeError(f'`Keys` must be None, str, or list of str, '
f'got {type(keys)}.')

update_dict = {}
for kpt_key in keys:
kpt_array = self.get_raw_value(kpt_key)
num_joints = kpt_array.shape[-2]
# if all conf of a joint are zero, this joint is masked
joint_conf = kpt_array[..., -1].reshape(-1, num_joints)
mask_array = (joint_conf > 0).astype(np.uint8).max(axis=0)
assert len(mask_array) == num_joints
# generate mask
update_dict[f'{kpt_key}_mask'] = mask_array
self.update(update_dict)

def compress_keypoints_by_mask(self) -> None:
"""If a key contains 'keypoints', and f'{key}_mask' is in self.keys(),
invalid zeros will be removed and f'{key}_mask' will be locked.
Expand Down
8 changes: 8 additions & 0 deletions tests/test_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,16 @@ def test_conventions():
np.zeros((f, n_person, J, 3)),
np.zeros((f, n_person, J, 2))
]:
# with mask
keypoints_dst, mask = convert_kps(keypoints, src_name,
dst_name)

# without mask
keypoints_dst_wo_mask = convert_kps(
keypoints, src_name, dst_name, return_mask=False)

assert np.all(keypoints_dst == keypoints_dst_wo_mask)

exp_shape = list(keypoints.shape)
exp_shape[-2] = J_dst
assert keypoints_dst.shape == tuple(exp_shape)
Expand Down
37 changes: 37 additions & 0 deletions tests/test_human_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,43 @@ def test_compression():
human_data.decompress_keypoints()


def test_generate_mask_from_keypoints():

human_data = HumanData.new(key_strict=False)
keypoints2d = np.random.rand(3, 144, 3)
keypoints3d = np.random.rand(3, 144, 4)

# set confidence
keypoints2d[:, :72, -1] = 0
keypoints3d[:, 72:, -1] = 0
human_data['keypoints2d'] = keypoints2d
human_data['keypoints3d'] = keypoints3d

# test all keys
with pytest.raises(KeyError):
human_data['keypoints2d_mask']
with pytest.raises(KeyError):
human_data['keypoints3d_mask']
human_data.generate_mask_from_confidence()
assert 'keypoints2d_mask' in human_data
assert (human_data['keypoints2d_mask'][:72] == 0).all()
assert (human_data['keypoints2d_mask'][72:] == 1).all()
assert 'keypoints3d_mask' in human_data
assert (human_data['keypoints3d_mask'][72:] == 0).all()
assert (human_data['keypoints3d_mask'][:72] == 1).all()

# test str keys
human_data.generate_mask_from_confidence(keys='keypoints2d')

# test list of str keys
human_data.generate_mask_from_confidence(
keys=['keypoints2d', 'keypoints3d'])

# test compression with generated mask
human_data.compress_keypoints_by_mask()
assert human_data.check_keypoints_compressed() is True


def test_pop_unsupported_items():
# directly pop them
human_data = HumanData.new(key_strict=False)
Expand Down

0 comments on commit 5808b11

Please sign in to comment.