model_def.py

"""
This example shows how to interact with the Determined PyTorch interface to
build a basic MNIST network.
In the `__init__` method, the model and optimizer are wrapped with `wrap_model`
and `wrap_optimizer`. This model is single-input and single-output.
The methods `train_batch` and `evaluate_batch` define the forward pass
for training and evaluation respectively.
"""

from typing import Any, Dict, Sequence, Tuple, Union, cast

import torch
from torch import nn

import data  # local helper module: downloads and loads the MNIST dataset
from layers import Flatten  # local helper module: flattens conv feature maps

import determined as det
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext

# A batch from a Determined DataLoader may be a single tensor, a sequence of
# tensors, or a dict of tensors.
TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]


class MNistTrial(PyTorchTrial):
def __init__(self, context: PyTorchTrialContext) -> None:
self.context = context
# Create a unique download directory for each rank so they don't overwrite each
# other when doing distributed training.
self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}"
self.data_downloaded = False
self.model = self.context.wrap_model(
nn.Sequential(
nn.Conv2d(1, self.context.get_hparam("n_filters1"), 3, 1),
nn.ReLU(),
nn.Conv2d(
self.context.get_hparam("n_filters1"),
self.context.get_hparam("n_filters2"),
3,
),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout2d(self.context.get_hparam("dropout1")),
Flatten(),
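                # Shape bookkeeping: a 28x28 MNIST image becomes 26x26 after the
                # first 3x3 conv, 24x24 after the second, and 12x12 after the
                # 2x2 max-pool, so each filter contributes 12 * 12 = 144 features.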
nn.Linear(144 * self.context.get_hparam("n_filters2"), 128),
nn.ReLU(),
nn.Dropout2d(self.context.get_hparam("dropout2")),
nn.Linear(128, 10),
nn.LogSoftmax(),
)
)
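
        # Wrap the optimizer so Determined can manage its state across
        # checkpoints and distributed workers.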
self.optimizer = self.context.wrap_optimizer(
torch.optim.Adadelta(
self.model.parameters(), lr=self.context.get_hparam("learning_rate")
)
)

        # Demonstrate the HP search constraints API: reject hyperparameter
        # combinations where the two filter counts differ by fewer than 10.
        if abs(self.context.get_hparam("n_filters2") - self.context.get_hparam("n_filters1")) < 10:
            raise det.InvalidHP("example for using HP search constraints API")

def train_batch(
self, batch: TorchData, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
# A det.InvalidHP can be raised here too based on the value of loss
# or other training metrics
batch = cast(Tuple[torch.Tensor, torch.Tensor], batch)
data, labels = batch
output = self.model(data)
loss = torch.nn.functional.nll_loss(output, labels)
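        # Delegate the backward pass and optimizer step to the trial context so
        # that gradients are aggregated correctly in distributed training.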
self.context.backward(loss)
self.context.step_optimizer(self.optimizer)
return {"loss": loss}

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
# A det.InvalidHP can be raised here too based on the value of validation_loss,
# accuracy, or other validation metrics
batch = cast(Tuple[torch.Tensor, torch.Tensor], batch)
data, labels = batch
output = self.model(data)
validation_loss = torch.nn.functional.nll_loss(output, labels).item()
pred = output.argmax(dim=1, keepdim=True)
accuracy = pred.eq(labels.view_as(pred)).sum().item() / len(data)
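        # Determined averages these per-batch metrics over the validation set
        # by default.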
return {"validation_loss": validation_loss, "accuracy": accuracy}

    def build_training_data_loader(self) -> DataLoader:
if not self.data_downloaded:
self.download_directory = data.download_dataset(
download_directory=self.download_directory,
data_config=self.context.get_data_config(),
)
self.data_downloaded = True
train_data = data.get_dataset(self.download_directory, train=True)
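        # Use the per-slot batch size: Determined divides the experiment's
        # global batch size evenly across slots (GPUs) in distributed training.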
return DataLoader(train_data, batch_size=self.context.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> DataLoader:
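        # Guard the download here as well, since either data loader may be
        # built first.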
if not self.data_downloaded:
self.download_directory = data.download_dataset(
download_directory=self.download_directory,
data_config=self.context.get_data_config(),
)
self.data_downloaded = True
validation_data = data.get_dataset(self.download_directory, train=False)
return DataLoader(validation_data, batch_size=self.context.get_per_slot_batch_size())
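

# A hypothetical experiment config entry pointing at this trial class might
# look like the following sketch (the searcher settings are assumptions, not
# the repository's actual config):
#
#   entrypoint: model_def:MNistTrial
#   searcher:
#     name: adaptive_asha
#     metric: validation_loss
#     smaller_is_better: true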