#
# Dynamic Routing Between Capsules
# https://arxiv.org/pdf/1710.09829.pdf
#
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import torch.nn.functional as F


class ConvUnit(nn.Module):
    def __init__(self, in_channels):
        super(ConvUnit, self).__init__()

        self.conv0 = nn.Conv2d(in_channels=in_channels,
                               out_channels=32,  # fixme constant
                               kernel_size=9,  # fixme constant
                               stride=2,  # fixme constant
                               bias=True)

    def forward(self, x):
        return self.conv0(x)


class CapsuleLayer(nn.Module):
    def __init__(self, in_units, in_channels, num_units, unit_size, use_routing):
        super(CapsuleLayer, self).__init__()

        self.in_units = in_units
        self.in_channels = in_channels
        self.num_units = num_units
        self.use_routing = use_routing

        if self.use_routing:
            # In the paper, the deeper capsule layer(s) with capsule inputs (DigitCaps) use a special
            # routing algorithm that uses this weight matrix.
            self.W = nn.Parameter(torch.randn(1, in_channels, num_units, unit_size, in_units))
        else:
            # The first convolutional capsule layer (PrimaryCapsules in the paper) does not perform routing.
            # Instead, it is composed of several convolutional units, each of which sees the full input.
            # It is implemented as a normal convolutional layer with a special nonlinearity (squash()).
            def create_conv_unit(unit_idx):
                unit = ConvUnit(in_channels=in_channels)
                self.add_module("unit_" + str(unit_idx), unit)
                return unit

            self.units = [create_conv_unit(i) for i in range(self.num_units)]

    @staticmethod
    def squash(s):
        # This is equation 1 from the paper:
        #   v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
        mag_sq = torch.sum(s**2, dim=2, keepdim=True)
        # The small epsilon guards against division by zero for zero-magnitude vectors.
        mag = torch.sqrt(mag_sq + 1e-8)
        s = (mag_sq / (1.0 + mag_sq)) * (s / mag)
        return s

    def forward(self, x):
        if self.use_routing:
            return self.routing(x)
        else:
            return self.no_routing(x)

    def no_routing(self, x):
        # Get output for each unit.
        # Each will be (batch, channels, height, width).
        u = [self.units[i](x) for i in range(self.num_units)]

        # Stack all unit outputs (batch, unit, channels, height, width).
        u = torch.stack(u, dim=1)

        # Flatten to (batch, unit, output).
        u = u.view(x.size(0), self.num_units, -1)

        # Return squashed outputs.
        return CapsuleLayer.squash(u)

    def routing(self, x):
        batch_size = x.size(0)

        # (batch, in_units, features) -> (batch, features, in_units)
        x = x.transpose(1, 2)

        # (batch, features, in_units) -> (batch, features, num_units, in_units, 1)
        x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)

        # Broadcast the weight matrix across the batch.
        # (batch, features, num_units, unit_size, in_units)
        W = torch.cat([self.W] * batch_size, dim=0)

        # Transform inputs by weight matrix.
        # (batch_size, features, num_units, unit_size, 1)
        u_hat = torch.matmul(W, x)

        # Initialize routing logits to zero, on the same device as the input.
        b_ij = Variable(torch.zeros(1, self.in_channels, self.num_units, 1, device=x.device))

        # Iterative routing.
        num_iterations = 3
        for iteration in range(num_iterations):
            # Convert routing logits to coupling coefficients with a softmax over the
            # output capsules (dim=2), following the routing algorithm in the paper.
            # (batch, features, num_units, 1, 1)
            c_ij = F.softmax(b_ij, dim=2)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            # Apply routing (c_ij) to weighted inputs (u_hat).
            # (batch_size, 1, num_units, unit_size, 1)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)

            # (batch_size, 1, num_units, unit_size, 1)
            v_j = CapsuleLayer.squash(s_j)

            # (batch_size, features, num_units, unit_size, 1)
            v_j1 = torch.cat([v_j] * self.in_channels, dim=1)

            # (1, features, num_units, 1)
            u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)

            # Update b_ij (routing)
            b_ij = b_ij + u_vj1

        return v_j.squeeze(1)
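

if __name__ == "__main__":
    # Minimal smoke test sketching how these layers might be wired together for the
    # MNIST setup described in the paper (Conv1 -> PrimaryCapsules -> DigitCaps).
    # The sizes below (256 Conv1 channels, 8 primary capsule units, 32 * 6 * 6 = 1152
    # primary capsules, 10 digit capsules of dimension 16) come from the paper, not
    # from this file; in_units and unit_size are unused when use_routing=False, so
    # placeholder values are passed for the primary layer.
    conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
    primary = CapsuleLayer(in_units=0, in_channels=256, num_units=8,
                           unit_size=-1, use_routing=False)
    digits = CapsuleLayer(in_units=8, in_channels=32 * 6 * 6, num_units=10,
                          unit_size=16, use_routing=True)

    images = torch.randn(2, 1, 28, 28)  # (batch, channels, height, width)
    x = F.relu(conv1(images))           # (2, 256, 20, 20)
    x = primary(x)                      # (2, 8, 1152)
    x = digits(x)                       # (2, 10, 16, 1)
    print(x.size())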