@@ -16,66 +16,68 @@ class TestTrainer(unittest.TestCase):
16
16
def setUp (cls ):
17
17
super_gradients .init_trainer ()
18
18
# NAMES FOR THE EXPERIMENTS TO LATER DELETE
19
- cls .folder_names = ['test_train' , 'test_save_load' , 'test_load_w' , 'test_load_w2' ,
20
- 'test_load_w3' , 'test_checkpoint_content' , 'analyze' ]
21
- cls .training_params = {"max_epochs" : 1 ,
22
- "silent_mode" : True ,
23
- "lr_decay_factor" : 0.1 ,
24
- "initial_lr" : 0.1 ,
25
- "lr_updates" : [4 ],
26
- "lr_mode" : "step" ,
27
- "loss" : "cross_entropy" , "train_metrics_list" : [Accuracy (), Top5 ()],
28
- "valid_metrics_list" : [Accuracy (), Top5 ()],
29
- "metric_to_watch" : "Accuracy" ,
30
- "greater_metric_to_watch_is_better" : True }
19
+ cls .folder_names = ["test_train" , "test_save_load" , "test_load_w" , "test_load_w2" , "test_load_w3" , "test_checkpoint_content" , "analyze" ]
20
+ cls .training_params = {
21
+ "max_epochs" : 1 ,
22
+ "silent_mode" : True ,
23
+ "lr_decay_factor" : 0.1 ,
24
+ "initial_lr" : 0.1 ,
25
+ "lr_updates" : [4 ],
26
+ "lr_mode" : "step" ,
27
+ "loss" : "cross_entropy" ,
28
+ "train_metrics_list" : [Accuracy (), Top5 ()],
29
+ "valid_metrics_list" : [Accuracy (), Top5 ()],
30
+ "metric_to_watch" : "Accuracy" ,
31
+ "greater_metric_to_watch_is_better" : True ,
32
+ }
31
33
32
34
@classmethod
33
35
def tearDownClass (cls ) -> None :
34
36
# ERASE ALL THE FOLDERS THAT WERE CREATED DURING THIS TEST
35
37
for folder in cls .folder_names :
36
- if os .path .isdir (os .path .join (' checkpoints' , folder )):
37
- shutil .rmtree (os .path .join (' checkpoints' , folder ))
38
+ if os .path .isdir (os .path .join (" checkpoints" , folder )):
39
+ shutil .rmtree (os .path .join (" checkpoints" , folder ))
38
40
39
41
@staticmethod
40
- def get_classification_trainer (name = '' ):
42
+ def get_classification_trainer (name = "" ):
41
43
trainer = Trainer (name )
42
44
model = models .get ("resnet18" , num_classes = 5 )
43
45
return trainer , model
44
46
45
47
def test_train (self ):
46
48
trainer , model = self .get_classification_trainer (self .folder_names [0 ])
47
- trainer .train (model = model , training_params = self .training_params , train_loader = classification_test_dataloader (),
48
- valid_loader = classification_test_dataloader ())
49
+ trainer .train (
50
+ model = model , training_params = self .training_params , train_loader = classification_test_dataloader (), valid_loader = classification_test_dataloader ()
51
+ )
49
52
50
53
def test_save_load (self ):
51
54
trainer , model = self .get_classification_trainer (self .folder_names [1 ])
52
- trainer .train (model = model , training_params = self .training_params , train_loader = classification_test_dataloader (),
53
- valid_loader = classification_test_dataloader ())
55
+ trainer .train (
56
+ model = model , training_params = self .training_params , train_loader = classification_test_dataloader (), valid_loader = classification_test_dataloader ()
57
+ )
54
58
resume_training_params = self .training_params .copy ()
55
59
resume_training_params ["resume" ] = True
56
60
resume_training_params ["max_epochs" ] = 2
57
61
trainer , model = self .get_classification_trainer (self .folder_names [1 ])
58
- trainer .train (model = model , training_params = resume_training_params ,
59
- train_loader = classification_test_dataloader (),
60
- valid_loader = classification_test_dataloader () )
62
+ trainer .train (
63
+ model = model , training_params = resume_training_params , train_loader = classification_test_dataloader (), valid_loader = classification_test_dataloader ()
64
+ )
61
65
62
66
def test_checkpoint_content (self ):
63
67
"""VERIFY THAT ALL CHECKPOINTS ARE SAVED AND CONTAIN ALL THE EXPECTED KEYS"""
64
68
trainer , model = self .get_classification_trainer (self .folder_names [5 ])
65
69
params = self .training_params .copy ()
66
70
params ["save_ckpt_epoch_list" ] = [1 ]
67
- trainer .train (model = model , training_params = params , train_loader = classification_test_dataloader (),
68
- valid_loader = classification_test_dataloader ())
69
- ckpt_filename = ['ckpt_best.pth' , 'ckpt_latest.pth' , 'ckpt_epoch_1.pth' ]
71
+ trainer .train (model = model , training_params = params , train_loader = classification_test_dataloader (), valid_loader = classification_test_dataloader ())
72
+ ckpt_filename = ["ckpt_best.pth" , "ckpt_latest.pth" , "ckpt_epoch_1.pth" ]
70
73
ckpt_paths = [os .path .join (trainer .checkpoints_dir_path , suf ) for suf in ckpt_filename ]
71
74
for ckpt_path in ckpt_paths :
72
75
ckpt = torch .load (ckpt_path )
73
- self .assertListEqual (['net' , 'acc' , 'epoch' , 'optimizer_state_dict' , 'scaler_state_dict' ],
74
- list (ckpt .keys ()))
76
+ self .assertListEqual (["net" , "acc" , "epoch" , "optimizer_state_dict" , "scaler_state_dict" ], list (ckpt .keys ()))
75
77
trainer ._save_checkpoint ()
76
- weights_only = torch .load (os .path .join (trainer .checkpoints_dir_path , ' ckpt_latest_weights_only.pth' ))
77
- self .assertListEqual ([' net' ], list (weights_only .keys ()))
78
+ weights_only = torch .load (os .path .join (trainer .checkpoints_dir_path , " ckpt_latest_weights_only.pth" ))
79
+ self .assertListEqual ([" net" ], list (weights_only .keys ()))
78
80
79
81
80
- if __name__ == ' __main__' :
82
+ if __name__ == " __main__" :
81
83
unittest .main ()
0 commit comments