Skip to content

Commit c4b4e37

Browse files
apply black on tests (#633)
Co-authored-by: Eugene Khvedchenya <[email protected]>
1 parent 4bbbf11 commit c4b4e37

39 files changed

+859
-634
lines changed

tests/deci_core_integration_test_suite_runner.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66

77
class CoreIntegrationTestSuiteRunner:
8-
98
def __init__(self):
109
self.test_loader = unittest.TestLoader()
1110
self.integration_tests_suite = unittest.TestSuite()
@@ -22,5 +21,5 @@ def _add_modules_to_integration_tests_suite(self):
2221
self.integration_tests_suite.addTest(self.test_loader.loadTestsFromModule(LRTest))
2322

2423

25-
if __name__ == '__main__':
24+
if __name__ == "__main__":
2625
unittest.main()

tests/end_to_end_tests/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44

55
from tests.end_to_end_tests.cifar_trainer_test import TestCifarTrainer
66

7-
__all__ = ['TestTrainer', 'TestCifarTrainer']
7+
__all__ = ["TestTrainer", "TestCifarTrainer"]

tests/end_to_end_tests/trainer_test.py

+32-30
Original file line numberDiff line numberDiff line change
@@ -16,66 +16,68 @@ class TestTrainer(unittest.TestCase):
1616
def setUp(cls):
1717
super_gradients.init_trainer()
1818
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
19-
cls.folder_names = ['test_train', 'test_save_load', 'test_load_w', 'test_load_w2',
20-
'test_load_w3', 'test_checkpoint_content', 'analyze']
21-
cls.training_params = {"max_epochs": 1,
22-
"silent_mode": True,
23-
"lr_decay_factor": 0.1,
24-
"initial_lr": 0.1,
25-
"lr_updates": [4],
26-
"lr_mode": "step",
27-
"loss": "cross_entropy", "train_metrics_list": [Accuracy(), Top5()],
28-
"valid_metrics_list": [Accuracy(), Top5()],
29-
"metric_to_watch": "Accuracy",
30-
"greater_metric_to_watch_is_better": True}
19+
cls.folder_names = ["test_train", "test_save_load", "test_load_w", "test_load_w2", "test_load_w3", "test_checkpoint_content", "analyze"]
20+
cls.training_params = {
21+
"max_epochs": 1,
22+
"silent_mode": True,
23+
"lr_decay_factor": 0.1,
24+
"initial_lr": 0.1,
25+
"lr_updates": [4],
26+
"lr_mode": "step",
27+
"loss": "cross_entropy",
28+
"train_metrics_list": [Accuracy(), Top5()],
29+
"valid_metrics_list": [Accuracy(), Top5()],
30+
"metric_to_watch": "Accuracy",
31+
"greater_metric_to_watch_is_better": True,
32+
}
3133

3234
@classmethod
3335
def tearDownClass(cls) -> None:
3436
# ERASE ALL THE FOLDERS THAT WERE CREATED DURING THIS TEST
3537
for folder in cls.folder_names:
36-
if os.path.isdir(os.path.join('checkpoints', folder)):
37-
shutil.rmtree(os.path.join('checkpoints', folder))
38+
if os.path.isdir(os.path.join("checkpoints", folder)):
39+
shutil.rmtree(os.path.join("checkpoints", folder))
3840

3941
@staticmethod
40-
def get_classification_trainer(name=''):
42+
def get_classification_trainer(name=""):
4143
trainer = Trainer(name)
4244
model = models.get("resnet18", num_classes=5)
4345
return trainer, model
4446

4547
def test_train(self):
4648
trainer, model = self.get_classification_trainer(self.folder_names[0])
47-
trainer.train(model=model, training_params=self.training_params, train_loader=classification_test_dataloader(),
48-
valid_loader=classification_test_dataloader())
49+
trainer.train(
50+
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
51+
)
4952

5053
def test_save_load(self):
5154
trainer, model = self.get_classification_trainer(self.folder_names[1])
52-
trainer.train(model=model, training_params=self.training_params, train_loader=classification_test_dataloader(),
53-
valid_loader=classification_test_dataloader())
55+
trainer.train(
56+
model=model, training_params=self.training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
57+
)
5458
resume_training_params = self.training_params.copy()
5559
resume_training_params["resume"] = True
5660
resume_training_params["max_epochs"] = 2
5761
trainer, model = self.get_classification_trainer(self.folder_names[1])
58-
trainer.train(model=model, training_params=resume_training_params,
59-
train_loader=classification_test_dataloader(),
60-
valid_loader=classification_test_dataloader())
62+
trainer.train(
63+
model=model, training_params=resume_training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
64+
)
6165

6266
def test_checkpoint_content(self):
6367
"""VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
6468
trainer, model = self.get_classification_trainer(self.folder_names[5])
6569
params = self.training_params.copy()
6670
params["save_ckpt_epoch_list"] = [1]
67-
trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(),
68-
valid_loader=classification_test_dataloader())
69-
ckpt_filename = ['ckpt_best.pth', 'ckpt_latest.pth', 'ckpt_epoch_1.pth']
71+
trainer.train(model=model, training_params=params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader())
72+
ckpt_filename = ["ckpt_best.pth", "ckpt_latest.pth", "ckpt_epoch_1.pth"]
7073
ckpt_paths = [os.path.join(trainer.checkpoints_dir_path, suf) for suf in ckpt_filename]
7174
for ckpt_path in ckpt_paths:
7275
ckpt = torch.load(ckpt_path)
73-
self.assertListEqual(['net', 'acc', 'epoch', 'optimizer_state_dict', 'scaler_state_dict'],
74-
list(ckpt.keys()))
76+
self.assertListEqual(["net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict"], list(ckpt.keys()))
7577
trainer._save_checkpoint()
76-
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, 'ckpt_latest_weights_only.pth'))
77-
self.assertListEqual(['net'], list(weights_only.keys()))
78+
weights_only = torch.load(os.path.join(trainer.checkpoints_dir_path, "ckpt_latest_weights_only.pth"))
79+
self.assertListEqual(["net"], list(weights_only.keys()))
7880

7981

80-
if __name__ == '__main__':
82+
if __name__ == "__main__":
8183
unittest.main()

tests/integration_tests/conversion_callback_test.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from super_gradients.training import models
66

77
from super_gradients import Trainer
8-
from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, \
9-
classification_test_dataloader
8+
from super_gradients.training.dataloaders.dataloaders import segmentation_test_dataloader, classification_test_dataloader
109
from super_gradients.training.utils.callbacks import ModelConversionCheckCallback
1110
from super_gradients.training.metrics import Accuracy, Top5, IoU
1211
from super_gradients.training.losses.stdc_loss import STDCLoss
@@ -63,18 +62,17 @@ def test_classification_architectures(self):
6362
"criterion_params": {},
6463
"train_metrics_list": [Accuracy(), Top5()],
6564
"valid_metrics_list": [Accuracy(), Top5()],
66-
6765
"metric_to_watch": "Accuracy",
6866
"greater_metric_to_watch_is_better": True,
6967
"phase_callbacks": phase_callbacks,
7068
}
7169

72-
trainer = Trainer(f"{architecture}_example",
73-
ckpt_root_dir=checkpoint_dir)
70+
trainer = Trainer(f"{architecture}_example", ckpt_root_dir=checkpoint_dir)
7471
model = models.get(architecture=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
7572
try:
76-
trainer.train(model=model, training_params=train_params, train_loader=classification_test_dataloader(),
77-
valid_loader=classification_test_dataloader())
73+
trainer.train(
74+
model=model, training_params=train_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
75+
)
7876
except Exception as e:
7977
self.fail(f"Model training didn't succeed due to {e}")
8078
else:
@@ -84,26 +82,22 @@ def test_segmentation_architectures(self):
8482
def get_architecture_custom_config(architecture_name: str):
8583
if re.search(r"ddrnet", architecture_name):
8684
return {
87-
8885
"loss": DDRNetLoss(num_pixels_exclude_ignored=False),
8986
}
9087
elif re.search(r"stdc", architecture_name):
9188
return {
92-
9389
"loss": STDCLoss(num_classes=5),
9490
}
9591
elif re.search(r"regseg", architecture_name):
9692
return {
97-
9893
"loss": "cross_entropy",
9994
}
10095
else:
10196
raise Exception("You tried to run a conversion test on an unknown architecture")
10297

10398
for architecture in SEMANTIC_SEGMENTATION:
10499
model_meta_data = generate_model_metadata(architecture=architecture, task=Task.SEMANTIC_SEGMENTATION)
105-
trainer = Trainer(f"{architecture}_example",
106-
ckpt_root_dir=checkpoint_dir)
100+
trainer = Trainer(f"{architecture}_example", ckpt_root_dir=checkpoint_dir)
107101
model = models.get(model_name=architecture, arch_params={"use_aux_heads": True, "aux_head": True})
108102

109103
phase_callbacks = [
@@ -128,8 +122,12 @@ def get_architecture_custom_config(architecture_name: str):
128122
train_params.update(custom_config)
129123

130124
try:
131-
trainer.train(model=model, training_params=train_params, train_loader=segmentation_test_dataloader(image_size=512),
132-
valid_loader=segmentation_test_dataloader(image_size=512))
125+
trainer.train(
126+
model=model,
127+
training_params=train_params,
128+
train_loader=segmentation_test_dataloader(image_size=512),
129+
valid_loader=segmentation_test_dataloader(image_size=512),
130+
)
133131
except Exception as e:
134132
self.fail(f"Model training didn't succeed for {architecture} due to {e}")
135133
else:

tests/integration_tests/lr_test.py

+27-20
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,26 @@ class LRTest(unittest.TestCase):
1313
@classmethod
1414
def setUp(cls):
1515
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
16-
cls.folder_name = 'lr_test'
17-
cls.training_params = {"max_epochs": 1,
18-
"silent_mode": True,
19-
"initial_lr": 0.1,
20-
"loss": "cross_entropy", "train_metrics_list": [Accuracy(), Top5()],
21-
"valid_metrics_list": [Accuracy(), Top5()],
22-
"metric_to_watch": "Accuracy",
23-
"greater_metric_to_watch_is_better": True}
16+
cls.folder_name = "lr_test"
17+
cls.training_params = {
18+
"max_epochs": 1,
19+
"silent_mode": True,
20+
"initial_lr": 0.1,
21+
"loss": "cross_entropy",
22+
"train_metrics_list": [Accuracy(), Top5()],
23+
"valid_metrics_list": [Accuracy(), Top5()],
24+
"metric_to_watch": "Accuracy",
25+
"greater_metric_to_watch_is_better": True,
26+
}
2427

2528
@classmethod
2629
def tearDownClass(cls) -> None:
2730
# ERASE THE FOLDER THAT WAS CREATED DURING THIS TEST
28-
if os.path.isdir(os.path.join('checkpoints', cls.folder_name)):
29-
shutil.rmtree(os.path.join('checkpoints', cls.folder_name))
31+
if os.path.isdir(os.path.join("checkpoints", cls.folder_name)):
32+
shutil.rmtree(os.path.join("checkpoints", cls.folder_name))
3033

3134
@staticmethod
32-
def get_trainer(name=''):
35+
def get_trainer(name=""):
3336
trainer = Trainer(name)
3437
model = models.get("resnet18_cifar", num_classes=5)
3538
return trainer, model
@@ -42,26 +45,30 @@ def test_lr_function(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwar
4245

4346
# test if we are able that lr_function supports functions with this structure
4447
training_params = {**self.training_params, "lr_mode": "function", "lr_schedule_function": test_lr_function}
45-
trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
46-
valid_loader=classification_test_dataloader())
48+
trainer.train(
49+
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
50+
)
4751
# test that we assert lr_function is callable
4852
training_params = {**self.training_params, "lr_mode": "function"}
4953
with self.assertRaises(AssertionError):
50-
trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
51-
valid_loader=classification_test_dataloader())
54+
trainer.train(
55+
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
56+
)
5257

5358
def test_cosine_lr(self):
5459
trainer, model = self.get_trainer(self.folder_name)
5560
training_params = {**self.training_params, "lr_mode": "cosine", "cosine_final_lr_ratio": 0.01}
56-
trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
57-
valid_loader=classification_test_dataloader())
61+
trainer.train(
62+
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
63+
)
5864

5965
def test_step_lr(self):
6066
trainer, model = self.get_trainer(self.folder_name)
6167
training_params = {**self.training_params, "lr_mode": "step", "lr_decay_factor": 0.1, "lr_updates": [4]}
62-
trainer.train(model=model, training_params=training_params, train_loader=classification_test_dataloader(),
63-
valid_loader=classification_test_dataloader())
68+
trainer.train(
69+
model=model, training_params=training_params, train_loader=classification_test_dataloader(), valid_loader=classification_test_dataloader()
70+
)
6471

6572

66-
if __name__ == '__main__':
73+
if __name__ == "__main__":
6774
unittest.main()

tests/unit_tests/all_architectures_test.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,30 @@
66

77

88
class AllArchitecturesTest(unittest.TestCase):
9-
109
def setUp(self):
1110
# contains all arch_params needed for initialization of all architectures
12-
self.all_arch_params = HpmStruct(**{'num_classes': 10,
13-
'width_mult': 1,
14-
'threshold': 1,
15-
'sml_net': torch.nn.Identity(),
16-
'big_net': torch.nn.Identity(),
17-
'dropout': 0,
18-
'build_residual_branches': True})
11+
self.all_arch_params = HpmStruct(
12+
**{
13+
"num_classes": 10,
14+
"width_mult": 1,
15+
"threshold": 1,
16+
"sml_net": torch.nn.Identity(),
17+
"big_net": torch.nn.Identity(),
18+
"dropout": 0,
19+
"build_residual_branches": True,
20+
}
21+
)
1922

2023
def test_architecture_is_sg_module(self):
2124
"""
2225
Validate all models from all_architectures.py are SgModule
2326
"""
2427
for arch_name in ARCHITECTURES:
2528
# skip custom constructors to keep all_arch_params as general as a possible
26-
if 'custom' in arch_name.lower() or 'nas' in arch_name.lower() or 'kd' in arch_name.lower():
29+
if "custom" in arch_name.lower() or "nas" in arch_name.lower() or "kd" in arch_name.lower():
2730
continue
2831
self.assertTrue(isinstance(ARCHITECTURES[arch_name](arch_params=self.all_arch_params), SgModule))
2932

3033

31-
if __name__ == '__main__':
34+
if __name__ == "__main__":
3235
unittest.main()

tests/unit_tests/average_meter_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def setUp(cls):
1414
cls.avg_tensor = AverageMeter()
1515
cls.left_empty = AverageMeter()
1616
cls.list_of_avg_meter = [cls.avg_float, cls.avg_tuple, cls.avg_list, cls.avg_tensor]
17-
cls.score_types = [1.2, (3., 4.), [5., 6., 7.], torch.FloatTensor([8., 9., 10.])]
17+
cls.score_types = [1.2, (3.0, 4.0), [5.0, 6.0, 7.0], torch.FloatTensor([8.0, 9.0, 10.0])]
1818
cls.batch_size = 3
1919

2020
def test_empty_return_0(self):
@@ -42,5 +42,5 @@ def test_correctness_and_typing(self):
4242
self.assertListEqual(list(avg_meter.average), list(score))
4343

4444

45-
if __name__ == '__main__':
45+
if __name__ == "__main__":
4646
unittest.main()

tests/unit_tests/cityscapes_dataset_test.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,25 @@
44
import yaml
55
from torch.utils.data import DataLoader
66

7-
from super_gradients.training.dataloaders.dataloaders import cityscapes_train, cityscapes_val, \
8-
cityscapes_stdc_seg50_train, cityscapes_stdc_seg50_val, cityscapes_stdc_seg75_val, cityscapes_ddrnet_train, \
9-
cityscapes_regseg48_val, cityscapes_regseg48_train, cityscapes_ddrnet_val, cityscapes_stdc_seg75_train
7+
from super_gradients.training.dataloaders.dataloaders import (
8+
cityscapes_train,
9+
cityscapes_val,
10+
cityscapes_stdc_seg50_train,
11+
cityscapes_stdc_seg50_val,
12+
cityscapes_stdc_seg75_val,
13+
cityscapes_ddrnet_train,
14+
cityscapes_regseg48_val,
15+
cityscapes_regseg48_train,
16+
cityscapes_ddrnet_val,
17+
cityscapes_stdc_seg75_train,
18+
)
1019
from super_gradients.training.datasets.segmentation_datasets.cityscape_segmentation import CityscapesDataset
1120

1221

1322
class CityscapesDatasetTest(unittest.TestCase):
14-
1523
def setUp(self) -> None:
16-
default_config_path = pkg_resources.resource_filename("super_gradients.recipes",
17-
"dataset_params/cityscapes_dataset_params.yaml")
18-
with open(default_config_path, 'r') as file:
24+
default_config_path = pkg_resources.resource_filename("super_gradients.recipes", "dataset_params/cityscapes_dataset_params.yaml")
25+
with open(default_config_path, "r") as file:
1926
self.recipe = yaml.safe_load(file)
2027

2128
def dataloader_tester(self, dl: DataLoader):
@@ -26,12 +33,12 @@ def dataloader_tester(self, dl: DataLoader):
2633
next(it)
2734

2835
def test_train_dataset_creation(self):
29-
train_dataset = CityscapesDataset(**self.recipe['train_dataset_params'])
36+
train_dataset = CityscapesDataset(**self.recipe["train_dataset_params"])
3037
for i in range(10):
3138
image, mask = train_dataset[i]
3239

3340
def test_val_dataset_creation(self):
34-
val_dataset = CityscapesDataset(**self.recipe['val_dataset_params'])
41+
val_dataset = CityscapesDataset(**self.recipe["val_dataset_params"])
3542
for i in range(10):
3643
image, mask = val_dataset[i]
3744

@@ -76,5 +83,5 @@ def test_cityscapes_ddrnet_val_dataloader(self):
7683
self.dataloader_tester(dl_val)
7784

7885

79-
if __name__ == '__main__':
86+
if __name__ == "__main__":
8087
unittest.main()

0 commit comments

Comments
 (0)