forked from irrmnv/pytorch-ieee-cmi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_models.py
35 lines (30 loc) · 1.1 KB
/
custom_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import models
class AvgPool(nn.Module):
    """Global average pooling: collapse every feature map to a single value.

    The pooling window is read from the input's own spatial extent on each
    call, so the layer accepts any height/width rather than a fixed size.
    """

    def forward(self, x):
        """Average ``x`` over its last two (spatial) dimensions.

        Args:
            x: 4-D tensor ``(N, C, H, W)``.

        Returns:
            Tensor of shape ``(N, C, 1, 1)``.
        """
        window = (x.shape[2], x.shape[3])  # one window covering the whole plane
        return F.avg_pool2d(x, window)
class ResNet(nn.Module):
    """CNN backbone whose classifier is replaced by an MLP head that also
    consumes one extra per-sample feature.

    The backbone's pooled feature vector is concatenated with ``O`` in
    ``forward`` and passed through ``self.fc`` to produce the logits.

    Args:
        num_classes: number of output logits.
        net_cls: factory building a torchvision-style network that exposes an
            ``avgpool`` attribute and a Linear ``fc`` attribute
            (default: ResNet-50).
        pretrained: forwarded unchanged to ``net_cls``.
            NOTE(review): newer torchvision releases deprecate ``pretrained``
            in favour of ``weights`` — confirm against the installed version.
    """

    def __init__(self, num_classes, net_cls=models.resnet50, pretrained=False):
        super().__init__()
        self.net = net_cls(pretrained=pretrained)
        # Pool over whatever spatial size reaches this stage, so the backbone
        # is not tied to one input resolution.
        self.net.avgpool = AvgPool()
        # New head; "+ 1" makes room for the extra feature ``O`` that
        # forward() appends to the backbone features.
        # NOTE(review): there is no nonlinearity between the Linear layers, so
        # apart from dropout the head is a single affine map. Inserting
        # activations would shift the Sequential indices and break existing
        # checkpoints, so it is left unchanged — confirm this was intended.
        self.fc = nn.Sequential(
            nn.Linear(self.net.fc.in_features + 1, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        # Turn the backbone's own classifier into a pass-through so
        # ``self.net`` returns pooled features. Identity is parameter-free,
        # like the Dropout(0.0) it replaces, so state_dict keys are unchanged.
        self.net.fc = nn.Identity()

    def fresh_params(self):
        """Return the parameters of the newly created (not pretrained) head.

        The backbone's ``fc`` is a parameter-free pass-through, so the only
        freshly initialised weights live in ``self.fc``. (Returning
        ``self.net.fc.parameters()`` would yield nothing.)
        """
        return self.fc.parameters()

    def forward(self, x, O):
        """Compute logits for a batch of images plus one extra feature each.

        Args:
            x: 4-D image batch. It is permuted ``(0, 3, 1, 2)`` before the
                backbone, presumably from channels-last ``(N, H, W, C)`` to
                the channels-first layout torchvision expects — confirm
                against the data loader.
            O: extra per-sample feature of shape ``(N,)`` or ``(N, 1)``.
                Exactly one value per sample is expected, because the head
                was sized for ``in_features + 1`` inputs.

        Returns:
            Logits of shape ``(N, num_classes)``.
        """
        # Same result as transpose(1, 3) followed by transpose(2, 3).
        out = x.permute(0, 3, 1, 2)
        out = self.net(out)
        out = out.view(out.size(0), -1)
        # Normalise O to (N, k) and match the feature dtype so torch.cat does
        # not fail on a 1-D or non-float O. A float (N, 1) O passes through
        # unchanged.
        extra = O.reshape(out.size(0), -1).to(out.dtype)
        out = torch.cat([out, extra], dim=1)
        return self.fc(out)