"""Train DDRNet-23 (or its slim variant) as a 1000-class ImageNet classifier."""
import argparse

import torch
from torchvision.transforms import ColorJitter, Normalize, RandomHorizontalFlip, ToTensor

import super_gradients
from super_gradients.common import MultiGPUMode
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer, dataloaders, models
from super_gradients.training.datasets.data_augmentation import RandomErase
from super_gradients.training.datasets.datasets_utils import RandomResizedCropAndInterpolation
from super_gradients.training.metrics import Accuracy, Top5

parser = argparse.ArgumentParser()
super_gradients.init_trainer()

parser.add_argument("--reload", action="store_true")
parser.add_argument("--max_epochs", type=int, default=100)
parser.add_argument("--batch", type=int, default=3)
parser.add_argument("--experiment_name", type=str, default="ddrnet_23")
parser.add_argument("-s", "--slim", action="store_true", help="train the slim version of DDRNet23")

# parse_known_args: unknown flags are tolerated (they may belong to the launcher).
args, _ = parser.parse_known_args()
distributed = super_gradients.is_distributed()
# In DDP each process owns exactly one device; otherwise scale by visible GPUs.
devices = torch.cuda.device_count() if not distributed else 1
|
37 | 39 |
|
# Training hyper-parameters: SGD with step decay at epochs 30/60/90 and a base
# learning rate scaled linearly by the number of devices.
train_params_ddr = dict(
    max_epochs=args.max_epochs,
    lr_mode="step",
    lr_updates=[30, 60, 90],
    lr_decay_factor=0.1,
    initial_lr=0.1 * devices,
    optimizer="SGD",
    optimizer_params=dict(weight_decay=0.0001, momentum=0.9, nesterov=True),
    loss="cross_entropy",
    train_metrics_list=[Accuracy(), Top5()],
    valid_metrics_list=[Accuracy(), Top5()],
    metric_to_watch="Accuracy",
    greater_metric_to_watch_is_better=True,
)

# NOTE(review): dataset_params is not referenced anywhere below — the train
# loader is built with an inline dict instead. Kept as-is for backward
# compatibility; confirm it is unused before removing.
dataset_params = dict(
    batch_size=args.batch,
    color_jitter=0.4,
    random_erase_prob=0.2,
    random_erase_value="random",
    train_interpolation="random",
)

# ImageNet augmentation pipeline: random-interpolation resized crop, flip,
# color jitter, tensor conversion, standard ImageNet normalization, then
# random erasing.
train_transforms = [
    RandomResizedCropAndInterpolation(size=224, interpolation="random"),
    RandomHorizontalFlip(),
    ColorJitter(0.4, 0.4, 0.4),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErase(0.2, "random"),
]

# DDP when launched distributed, plain DataParallel otherwise.
gpu_mode = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL if distributed else MultiGPUMode.DATA_PARALLEL
trainer = Trainer(experiment_name=args.experiment_name, multi_gpu=gpu_mode, device="cuda")

train_loader = dataloaders.imagenet_train(
    dataset_params={"transforms": train_transforms},
    dataloader_params={"batch_size": args.batch},
)
valid_loader = dataloaders.imagenet_val()
|
76 | 79 |
|
# Build DDRNet-23 (slim when -s/--slim is passed) in classification mode:
# no auxiliary segmentation head, 0.3 dropout, 1000 ImageNet classes.
architecture = Models.DDRNET_23_SLIM if args.slim else Models.DDRNET_23
model = models.get(
    architecture,
    arch_params={"aux_head": False, "classification_mode": True, "dropout_prob": 0.3},
    num_classes=1000,
)

trainer.train(model=model, training_params=train_params_ddr, train_loader=train_loader, valid_loader=valid_loader)
|