-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Add] Add EvalHook inherited from MMCV EvalHook (#90)
* upgrade eval * fix lint * fix lint * fix lint * add a unit test: test eval hook * add unit test * fix unit test * fix unit test: remove the requirement for cuda * use kwargs to receive EvalHook args * remove useless comments * create the folder if it does not exist * add new metric * fix some bugs * fix unit test * remove joint_error metric * fix unit test * fix pck thresholds * fix import error * fix import error * remove unused paramter * add more unit test * add unit test * rename p-mpjpe to pa-mpjpe * fix unit test * remove `mpjpe` in `__all__` * fix comments * add more unit tests * fix * rename * fix docsting * fix typo * update `getting_started.md` * fix docstring * add evaluation config * fix unit test * use mmhuman3d greater/less key
- Loading branch information
Showing
22 changed files
with
1,369 additions
and
365 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,16 @@ | ||
from mmhuman3d.core.evaluation import mesh_eval, mpjpe | ||
from mmhuman3d.core.evaluation import mesh_eval | ||
from mmhuman3d.core.evaluation.eval_hooks import DistEvalHook, EvalHook | ||
from mmhuman3d.core.evaluation.eval_utils import ( | ||
keypoint_3d_auc, | ||
keypoint_3d_pck, | ||
keypoint_accel_error, | ||
keypoint_mpjpe, | ||
vertice_pve, | ||
) | ||
from mmhuman3d.core.evaluation.mesh_eval import compute_similarity_transform | ||
from mmhuman3d.core.evaluation.mpjpe import keypoint_mpjpe | ||
|
||
__all__ = [ | ||
'compute_similarity_transform', 'keypoint_mpjpe', 'mesh_eval', 'mpjpe' | ||
'compute_similarity_transform', 'keypoint_mpjpe', 'mesh_eval', | ||
'DistEvalHook', 'EvalHook', 'vertice_pve', 'keypoint_3d_pck', | ||
'keypoint_3d_auc', 'keypoint_accel_error' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import tempfile | ||
import warnings | ||
|
||
from mmcv.runner import DistEvalHook as BaseDistEvalHook | ||
from mmcv.runner import EvalHook as BaseEvalHook | ||
|
||
MMHUMAN3D_GREATER_KEYS = ['3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc'] | ||
MMHUMAN3D_LESS_KEYS = ['mpjpe', 'pa-mpjpe', 'pve'] | ||
|
||
|
||
class EvalHook(BaseEvalHook): | ||
|
||
def __init__(self, | ||
dataloader, | ||
start=None, | ||
interval=1, | ||
by_epoch=True, | ||
save_best=None, | ||
rule=None, | ||
test_fn=None, | ||
greater_keys=MMHUMAN3D_GREATER_KEYS, | ||
less_keys=MMHUMAN3D_LESS_KEYS, | ||
**eval_kwargs): | ||
if test_fn is None: | ||
from mmhuman3d.apis import single_gpu_test | ||
test_fn = single_gpu_test | ||
|
||
# remove "gpu_collect" from eval_kwargs | ||
if 'gpu_collect' in eval_kwargs: | ||
warnings.warn( | ||
'"gpu_collect" will be deprecated in EvalHook.' | ||
'Please remove it from the config.', DeprecationWarning) | ||
_ = eval_kwargs.pop('gpu_collect') | ||
|
||
# update "save_best" according to "key_indicator" and remove the | ||
# latter from eval_kwargs | ||
if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): | ||
warnings.warn( | ||
'"key_indicator" will be deprecated in EvalHook.' | ||
'Please use "save_best" to specify the metric key,' | ||
'e.g., save_best="pa-mpjpe".', DeprecationWarning) | ||
|
||
key_indicator = eval_kwargs.pop('key_indicator', None) | ||
if save_best is True and key_indicator is None: | ||
raise ValueError('key_indicator should not be None, when ' | ||
'save_best is set to True.') | ||
save_best = key_indicator | ||
|
||
super().__init__(dataloader, start, interval, by_epoch, save_best, | ||
rule, test_fn, greater_keys, less_keys, **eval_kwargs) | ||
|
||
def evaluate(self, runner, results): | ||
|
||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
eval_res = self.dataloader.dataset.evaluate( | ||
results, | ||
res_folder=tmp_dir, | ||
logger=runner.logger, | ||
**self.eval_kwargs) | ||
|
||
for name, val in eval_res.items(): | ||
runner.log_buffer.output[name] = val | ||
runner.log_buffer.ready = True | ||
|
||
if self.save_best is not None: | ||
if self.key_indicator == 'auto': | ||
self._init_rule(self.rule, list(eval_res.keys())[0]) | ||
|
||
return eval_res[self.key_indicator] | ||
|
||
return None | ||
|
||
|
||
class DistEvalHook(BaseDistEvalHook): | ||
|
||
def __init__(self, | ||
dataloader, | ||
start=None, | ||
interval=1, | ||
by_epoch=True, | ||
save_best=None, | ||
rule=None, | ||
test_fn=None, | ||
greater_keys=MMHUMAN3D_GREATER_KEYS, | ||
less_keys=MMHUMAN3D_LESS_KEYS, | ||
broadcast_bn_buffer=True, | ||
tmpdir=None, | ||
gpu_collect=False, | ||
**eval_kwargs): | ||
|
||
if test_fn is None: | ||
from mmhuman3d.apis import multi_gpu_test | ||
test_fn = multi_gpu_test | ||
|
||
# update "save_best" according to "key_indicator" and remove the | ||
# latter from eval_kwargs | ||
if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): | ||
warnings.warn( | ||
'"key_indicator" will be deprecated in EvalHook.' | ||
'Please use "save_best" to specify the metric key,' | ||
'e.g., save_best="pa-mpjpe".', DeprecationWarning) | ||
|
||
key_indicator = eval_kwargs.pop('key_indicator', None) | ||
if save_best is True and key_indicator is None: | ||
raise ValueError('key_indicator should not be None, when ' | ||
'save_best is set to True.') | ||
save_best = key_indicator | ||
|
||
super().__init__(dataloader, start, interval, by_epoch, save_best, | ||
rule, test_fn, greater_keys, less_keys, | ||
broadcast_bn_buffer, tmpdir, gpu_collect, | ||
**eval_kwargs) | ||
|
||
def evaluate(self, runner, results): | ||
"""Evaluate the results. | ||
Args: | ||
runner (:obj:`mmcv.Runner`): The underlined training runner. | ||
results (list): Output results. | ||
""" | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
eval_res = self.dataloader.dataset.evaluate( | ||
results, | ||
res_folder=tmp_dir, | ||
logger=runner.logger, | ||
**self.eval_kwargs) | ||
|
||
for name, val in eval_res.items(): | ||
runner.log_buffer.output[name] = val | ||
runner.log_buffer.ready = True | ||
|
||
if self.save_best is not None: | ||
if self.key_indicator == 'auto': | ||
# infer from eval_results | ||
self._init_rule(self.rule, list(eval_res.keys())[0]) | ||
return eval_res[self.key_indicator] | ||
|
||
return None |
Oops, something went wrong.