-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
34 lines (27 loc) · 1009 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import os
from models import lenet, resnet, alexnet
class Model(nn.Module):
    """Thin wrapper that selects a backbone network from ``args.dataset``.

    The wrapped network is stored in ``self.model``; ``forward`` simply
    delegates to it.  ``save``/``load`` persist the wrapper's ``state_dict``
    under the ``'model'`` key of a checkpoint dictionary.
    """

    def __init__(self, args):
        """Build the backbone that matches ``args.dataset``.

        Args:
            args: Object exposing a ``dataset`` attribute; one of
                ``'MNIST'``, ``'CIFAR10'`` or ``'ImageNet'``.

        Raises:
            ValueError: If ``args.dataset`` is not a supported name.
        """
        super().__init__()

        if args.dataset == 'MNIST':
            self.model = lenet.LeNet(input_dim=1)
        elif args.dataset == 'CIFAR10':
            # Alternative CIFAR10 backbones that can be swapped in here:
            # resnet.ResNet34(), resnet.ResNet50(), alexnet.AlexNet().
            self.model = resnet.ResNet18()
        elif args.dataset == 'ImageNet':
            # NOTE(review): the meaning of latent_feature is defined in
            # models/resnet.py, not here -- confirm 2048 is intended for
            # a BasicBlock network.
            self.model = resnet.ResNet(
                resnet.BasicBlock,
                [2, 2, 2, 2],
                num_classes=1000,
                latent_feature=2048,
            )
        else:
            # Fail at construction time instead of leaving ``self.model``
            # undefined and crashing later inside ``forward``.
            raise ValueError(
                f"Unsupported dataset: {args.dataset!r}; "
                "expected one of 'MNIST', 'CIFAR10', 'ImageNet'"
            )

    def forward(self, image):
        """Run ``image`` through the wrapped backbone and return its output."""
        return self.model(image)

    def save(self, path):
        """Write this module's ``state_dict`` to ``path``.

        The file holds a dict with a single ``'model'`` entry.
        """
        checkpoint = {'model': self.state_dict()}
        torch.save(checkpoint, path)

    def load(self, path):
        """Restore parameters from a checkpoint written by ``save``.

        Args:
            path: Checkpoint location; resolved to an absolute path first.
        """
        path = os.path.abspath(path)
        # map_location='cpu' lets checkpoints saved from a GPU be read on
        # CPU-only hosts; load_state_dict copies values into the existing
        # parameters, so the module's current device is unaffected.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(path, map_location='cpu')
        self.load_state_dict(checkpoint['model'])