main.py
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Tuple

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from config import config
from model import SimpleLinear, SimpleRNN
from rnn_data import NamesDataset, transform_batch
from train import Trainer
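# Documentation note: this script expects the shared `config` object from
# config.py to expose at least `dataset`, `n_parties`, `batch_size`, and
# `hidden_size`; their actual values are defined in config.py, not here.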
@contextmanager
def timer():
    """Helper for measuring runtime"""
    time0 = time.perf_counter()
    yield
    print('[elapsed time: %.2f s]' % (time.perf_counter() - time0))
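# Illustrative usage only (timer() is not invoked elsewhere in this file):
#
#     with timer():
#         trainer.fit()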
def configure_dataloaders(data_dir: Path) -> Tuple[DataLoader, DataLoader]:
    """Build the (train, valid) DataLoaders for the dataset selected in config."""
    if config.dataset == 'mnist':
        def create_loader(is_train_loader):
            return DataLoader(
                MNIST(
                    data_dir,
                    train=is_train_loader,
                    download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])
                ),
                # one combined batch covers every client
                batch_size=config.n_parties * config.batch_size,
            )
    else:
        def create_loader(is_train_loader):
            # is_train_loader is ignored here: NamesDataset has no built-in split,
            # so the same dataset backs both loaders
            return DataLoader(
                NamesDataset(),
                # one combined batch covers every client
                batch_size=config.n_parties * config.batch_size,
                collate_fn=transform_batch
            )
    return (create_loader(True), create_loader(False))
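# Illustrative only: in the MNIST branch each full batch is an (images, labels)
# pair, with images of shape (config.n_parties * config.batch_size, 1, 28, 28)
# and labels of shape (config.n_parties * config.batch_size,); the final batch
# may be smaller since drop_last is not set. The batch structure for
# NamesDataset depends on transform_batch in rnn_data.py.
#
#     images, labels = next(iter(loaders[0]))  # loaders as built in __main__ below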
def configure_model(loaders: Tuple[DataLoader, DataLoader]) -> torch.nn.Module:
    """Build the model matching the dataset selected in config."""
    if config.dataset == 'mnist':
        model = SimpleLinear(in_size=28 * 28, out_size=10)
    else:
        # infer vocabulary and label counts from the training dataset
        num_langs = len(loaders[0].dataset.langs)
        vocab_size = len(loaders[0].dataset.char2index)
        model = SimpleRNN(in_size=vocab_size, hidden_size=config.hidden_size, out_size=num_langs)
    return model
if __name__ == '__main__':
    data_dir = Path(__file__).parent / 'data'
    data_dir.mkdir(parents=True, exist_ok=True)

    loaders = configure_dataloaders(data_dir)
    model = configure_model(loaders)

    trainer = Trainer(
        model=model,
        train_loader=loaders[0],
        valid_loader=loaders[1],
    )
    try:
        trainer.fit()
    except KeyboardInterrupt:
        # allow a clean stop with Ctrl+C
        exit(0)