Skip to content

Commit 386f30b

Browse files
added correct kwargs to lifting layer
1 parent a3998cf commit 386f30b

File tree

6 files changed

+102
-4
lines changed

6 files changed

+102
-4
lines changed

gconv/gnn/kernels/kernel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def forward(self, H) -> Tensor:
181181
weight = self.sample_Rn(
182182
self.weight.repeat_interleave(H.shape[0], dim=0),
183183
H_product.repeat(self.out_channels, *product_dims),
184-
**self.sample_H_kwargs,
184+
**self.sample_Rn_kwargs,
185185
).view(
186186
self.out_channels, num_H, self.in_channels // self.groups, *self.kernel_size
187187
)

gconv/gnn/modules/gconv.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
"""
66
from __future__ import annotations
77

8+
from matplotlib import pyplot as plt
9+
810
from gconv.gnn.kernels import (
911
GroupKernel,
1012
GLiftingKernel,

gconv/tests/test_equivariance.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import sys
2+
3+
sys.path.append("..")
4+
5+
6+
from gconv.gnn import GLiftingConvSE3
7+
from gconv.geometry.groups import so3 as R
8+
from gconv.gnn import functional as gF
9+
import torch
10+
11+
from matplotlib import pyplot as plt
12+
13+
from torch.nn.functional import grid_sample
14+
15+
16+
def plot_activations(activations):
17+
B, _, H, *_ = activations.shape
18+
fig = plt.figure()
19+
for i in range(B):
20+
for j in range(H):
21+
ax = fig.add_subplot(B, H, 1 + j + i * H)
22+
ax.imshow(activations[i, 1, j, 2].detach().numpy())
23+
ax.axis(False)
24+
plt.show()
25+
26+
27+
def test_se3_lifting_conv():
28+
torch.manual_seed(0)
29+
30+
batch_size = 1
31+
in_channels = 2
32+
out_channels = 3
33+
kernel_size = 5
34+
group_kernel_size = 4
35+
groups = 1
36+
bias = False
37+
38+
input = torch.zeros(batch_size, in_channels, 5, 5, 5)
39+
input[:, :, 2, 2, :] = 1
40+
41+
grid_H = torch.Tensor(
42+
[
43+
[
44+
[1, 0, 0],
45+
[0, 1, 0],
46+
[0, 0, 1],
47+
],
48+
[
49+
[-1, 0, 0],
50+
[0, -1, 0],
51+
[0, 0, 1],
52+
],
53+
]
54+
)
55+
from math import pi
56+
57+
grid_H = R.matrix_z(torch.linspace(0, 2 * pi, 5)[:-1])
58+
59+
grid_R3 = gF.create_grid_R3(5)
60+
61+
grid_R3_rotated = R.left_apply_to_R3(grid_H, grid_R3)
62+
input_rotated = grid_sample(
63+
input.repeat(grid_H.shape[0], 1, 1, 1, 1),
64+
grid_R3_rotated,
65+
mode="nearest",
66+
padding_mode="zeros",
67+
)
68+
69+
model = GLiftingConvSE3(
70+
in_channels,
71+
out_channels,
72+
kernel_size,
73+
group_kernel_size=group_kernel_size,
74+
padding="same",
75+
groups=groups,
76+
bias=bias,
77+
sampling_mode="nearest",
78+
sampling_padding_mode="zeros",
79+
mask=True,
80+
permute_output_grid=False,
81+
)
82+
83+
output, H = model(input_rotated, grid_H)
84+
85+
# plot_activations(input[:, :, None])
86+
# plot_activations(input_rotated[:, :, None])
87+
print(output.shape)
88+
plot_activations(output)
89+
90+
91+
def main():
92+
test_se3_lifting_conv()
93+
94+
95+
if __name__ == "__main__":
96+
main()

gconv/tests/test_gconv_e2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from gconv.nn import GLiftingConvE2, GSeparableConvE2, GConvE2
7+
from gconv.gnn import GLiftingConvE2, GSeparableConvE2, GConvE2
88
from gconv.geometry.groups import o2
99

1010

gconv/tests/test_gconv_e3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77

8-
from gconv.nn import GLiftingConvE3, GSeparableConvE3, GConvE3
8+
from gconv.gnn import GLiftingConvE3, GSeparableConvE3, GConvE3
99
from gconv.geometry.groups import o3
1010

1111

gconv/tests/test_kernel_se3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
sys.path.append("..")
44

5-
from gconv.nn import kernels
5+
from gconv.gnn import kernels
66
from gconv.geometry.groups import so3
77

88

0 commit comments

Comments
 (0)