train_resnet.py
#!/usr/bin/env python
# flake8: noqa
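"""Catalyst example: train a ResNet-9 on CIFAR-10 with the low-level IRunner API.

The compute backend (CPU / GPU / DDP, etc.) is selected at runtime via --engine;
E2E from common.py is expected to map engine names to catalyst engine classes.
"""
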
from argparse import ArgumentParser, RawTextHelpFormatter
import os
from common import E2E
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from catalyst import dl, utils
from catalyst.contrib.datasets import CIFAR10
from catalyst.contrib.nn import ResidualBlock
from catalyst.data import transforms
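

# Conv2d -> BatchNorm2d -> ReLU block; an optional 2x2 max-pool halves the spatial size.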
def conv_block(in_channels, out_channels, pool=False):
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
]
if pool:
layers.append(nn.MaxPool2d(2))
return nn.Sequential(*layers)
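

# ResNet-9: stacked conv blocks with two residual stages, followed by a
# max-pool -> flatten -> dropout -> linear classifier head.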
def resnet9(in_channels: int, num_classes: int, size: int = 16):
sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
return nn.Sequential(
conv_block(in_channels, sz),
conv_block(sz, sz2, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
conv_block(sz2, sz4, pool=True),
conv_block(sz4, sz8, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz8, sz8), conv_block(sz8, sz8))),
nn.Sequential(nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(sz8, num_classes)),
)
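

# Custom IRunner that wires together engine, loggers, data, model, optimization
# and callbacks for a single 10-epoch training stage.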
class CustomRunner(dl.IRunner):
def __init__(self, logdir: str, engine: str, sync_bn: bool = False):
super().__init__()
self._logdir = logdir
self._engine = engine
self._sync_bn = sync_bn
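
    # E2E (from common.py) maps the --engine name to an engine class; sync_bn
    # presumably enables SyncBatchNorm for distributed runs.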
def get_engine(self):
return E2E[self._engine](sync_bn=True) if self._sync_bn else E2E[self._engine]()

    def get_loggers(self):
return {
"console": dl.ConsoleLogger(),
"csv": dl.CSVLogger(logdir=self._logdir),
"tensorboard": dl.TensorboardLogger(logdir=self._logdir),
}
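
    # A single "train" stage that runs for 10 epochs (see get_stage_len).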

    @property
def stages(self):
return ["train"]

    def get_stage_len(self, stage: str) -> int:
return 10
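
    # CIFAR-10 train/valid loaders; under DDP each process reads its own
    # shard through a DistributedSampler.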
def get_loaders(self, stage: str):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_data = CIFAR10(os.getcwd(), train=True, download=True, transform=transform)
valid_data = CIFAR10(os.getcwd(), train=False, download=True, transform=transform)
if self.engine.is_ddp:
train_sampler = DistributedSampler(
train_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=True,
)
valid_sampler = DistributedSampler(
valid_data,
num_replicas=self.engine.world_size,
rank=self.engine.rank,
shuffle=False,
)
else:
train_sampler = valid_sampler = None
return {
"train": DataLoader(train_data, batch_size=32, sampler=train_sampler, num_workers=4),
"valid": DataLoader(valid_data, batch_size=32, sampler=valid_sampler, num_workers=4),
}
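
    # Reuse the model if one is already attached to the runner, otherwise
    # build a fresh ResNet-9 for 3-channel, 10-class CIFAR-10 input.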
def get_model(self, stage: str):
model = self.model if self.model is not None else resnet9(in_channels=3, num_classes=10)
return model

    def get_criterion(self, stage: str):
return nn.CrossEntropyLoss()

    def get_optimizer(self, stage: str, model):
return optim.Adam(model.parameters(), lr=1e-3)

    def get_scheduler(self, stage: str, optimizer):
return optim.lr_scheduler.MultiStepLR(optimizer, [5, 8], gamma=0.3)
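
    # Loss, optimizer and scheduler plumbing, top-1/3/5 accuracy, and
    # checkpointing of the best model by validation accuracy.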
def get_callbacks(self, stage: str):
return {
"criterion": dl.CriterionCallback(
metric_key="loss", input_key="logits", target_key="targets"
),
"optimizer": dl.OptimizerCallback(metric_key="loss"),
"scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
"accuracy": dl.AccuracyCallback(
input_key="logits", target_key="targets", topk_args=(1, 3, 5)
),
"checkpoint": dl.CheckpointCallback(
self._logdir,
loader_key="valid",
metric_key="accuracy",
minimize=False,
save_n_best=1,
),
# "tqdm": dl.TqdmCallback(),
}
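
    # Forward pass only: loss computation and the optimizer step are handled
    # by CriterionCallback and OptimizerCallback above.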
def handle_batch(self, batch):
x, y = batch
logits = self.model(x)
self.batch = {
"features": x,
"targets": y,
"logits": logits,
}
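

# CLI entry point: choose an engine from E2E, optionally enable SyncBatchNorm,
# and derive a default logdir from both settings.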
if __name__ == "__main__":
parser = ArgumentParser(formatter_class=RawTextHelpFormatter)
parser.add_argument("--logdir", type=str, default=None)
parser.add_argument("--engine", type=str, choices=list(E2E.keys()))
utils.boolean_flag(parser, "sync-bn", default=False)
args, _ = parser.parse_known_args()
args.logdir = args.logdir or f"logs_resnet_{args.engine}_sbn{int(args.sync_bn)}".replace(
"-", "_"
)
runner = CustomRunner(args.logdir, args.engine, args.sync_bn)
runner.run()
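
# Example invocation (a sketch; valid engine names depend on the E2E registry in common.py):
#   python train_resnet.py --engine <engine-name> --sync-bn --logdir ./logs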