Skip to content

Commit

Permalink
Fix test code
Browse files Browse the repository at this point in the history
  • Loading branch information
KKIEEK authored and KKIEEK committed Nov 30, 2022
1 parent 5400e41 commit f9135f4
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions tests/test_mm/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,42 +161,36 @@ def test_discrete_test_function(mock_report):
assert isinstance(get_session().get('result'), float)


@patch.object(MMSegmentation, 'train_model', return_value=None)
@patch.object(MMSegmentation, 'train_model')
@patch.object(MMSegmentation, 'build_model')
@patch.object(MMSegmentation, 'build_dataset')
def test_mmseg(mock_build_dataset, mock_train_model):
mock_build_dataset.return_value.CLASSES = ['a', 'b', 'c']
def test_mmseg(*not_used):
os.environ['LOCAL_RANK'] = '0'

config_path = 'configs/mmseg/pspnet/pspnet_r18-d8_4x4_512x512_80k_potsdam.py' # noqa

task = MMSegmentation()
task.set_args([config_path])
task.set_args(['tests/data/test_config.py'])
task.run(args=task.args)


@patch.object(MMDetection, 'train_model', return_value=None)
@patch.object(MMDetection, 'train_model')
@patch.object(MMDetection, 'build_model')
@patch.object(MMDetection, 'build_dataset')
def test_mmdet(mock_build_dataset, mock_train_model):
mock_build_dataset.return_value.CLASSES = ['a', 'b', 'c']
def test_mmdet(*not_used):
os.environ['LOCAL_RANK'] = '0'

config_path = 'configs/mmdet/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'

task = MMDetection()
task.set_args([config_path])
task.set_args(['tests/data/test_config.py'])
task.run(args=task.args)


@patch.object(MMClassification, 'train_model', return_value=None)
@patch.object(MMClassification, 'train_model')
@patch.object(MMClassification, 'build_model')
@patch.object(MMClassification, 'build_dataset')
def test_mmcls(mock_build_dataset, mock_train_model):
mock_build_dataset.return_value.CLASSES = ['a', 'b', 'c']
def test_mmcls(*not_used):
os.environ['LOCAL_RANK'] = '0'

config_path = 'configs/mmcls/resnet/resnet18_8xb16_cifar10.py'

task = MMClassification()
task.set_args([config_path])
task.set_args(['tests/data/test_config.py'])
task.run(args=task.args)


Expand Down

0 comments on commit f9135f4

Please sign in to comment.