|
| 1 | +import sys |
| 2 | + |
| 3 | +sys.path.append("..") |
| 4 | + |
| 5 | + |
| 6 | +from gconv.gnn import GLiftingConvSE3 |
| 7 | +from gconv.geometry.groups import so3 as R |
| 8 | +from gconv.gnn import functional as gF |
| 9 | +import torch |
| 10 | + |
| 11 | +from matplotlib import pyplot as plt |
| 12 | + |
| 13 | +from torch.nn.functional import grid_sample |
| 14 | + |
| 15 | + |
| 16 | +def plot_activations(activations): |
| 17 | + B, _, H, *_ = activations.shape |
| 18 | + fig = plt.figure() |
| 19 | + for i in range(B): |
| 20 | + for j in range(H): |
| 21 | + ax = fig.add_subplot(B, H, 1 + j + i * H) |
| 22 | + ax.imshow(activations[i, 1, j, 2].detach().numpy()) |
| 23 | + ax.axis(False) |
| 24 | + plt.show() |
| 25 | + |
| 26 | + |
| 27 | +def test_se3_lifting_conv(): |
| 28 | + torch.manual_seed(0) |
| 29 | + |
| 30 | + batch_size = 1 |
| 31 | + in_channels = 2 |
| 32 | + out_channels = 3 |
| 33 | + kernel_size = 5 |
| 34 | + group_kernel_size = 4 |
| 35 | + groups = 1 |
| 36 | + bias = False |
| 37 | + |
| 38 | + input = torch.zeros(batch_size, in_channels, 5, 5, 5) |
| 39 | + input[:, :, 2, 2, :] = 1 |
| 40 | + |
| 41 | + grid_H = torch.Tensor( |
| 42 | + [ |
| 43 | + [ |
| 44 | + [1, 0, 0], |
| 45 | + [0, 1, 0], |
| 46 | + [0, 0, 1], |
| 47 | + ], |
| 48 | + [ |
| 49 | + [-1, 0, 0], |
| 50 | + [0, -1, 0], |
| 51 | + [0, 0, 1], |
| 52 | + ], |
| 53 | + ] |
| 54 | + ) |
| 55 | + from math import pi |
| 56 | + |
| 57 | + grid_H = R.matrix_z(torch.linspace(0, 2 * pi, 5)[:-1]) |
| 58 | + |
| 59 | + grid_R3 = gF.create_grid_R3(5) |
| 60 | + |
| 61 | + grid_R3_rotated = R.left_apply_to_R3(grid_H, grid_R3) |
| 62 | + input_rotated = grid_sample( |
| 63 | + input.repeat(grid_H.shape[0], 1, 1, 1, 1), |
| 64 | + grid_R3_rotated, |
| 65 | + mode="nearest", |
| 66 | + padding_mode="zeros", |
| 67 | + ) |
| 68 | + |
| 69 | + model = GLiftingConvSE3( |
| 70 | + in_channels, |
| 71 | + out_channels, |
| 72 | + kernel_size, |
| 73 | + group_kernel_size=group_kernel_size, |
| 74 | + padding="same", |
| 75 | + groups=groups, |
| 76 | + bias=bias, |
| 77 | + sampling_mode="nearest", |
| 78 | + sampling_padding_mode="zeros", |
| 79 | + mask=True, |
| 80 | + permute_output_grid=False, |
| 81 | + ) |
| 82 | + |
| 83 | + output, H = model(input_rotated, grid_H) |
| 84 | + |
| 85 | + # plot_activations(input[:, :, None]) |
| 86 | + # plot_activations(input_rotated[:, :, None]) |
| 87 | + print(output.shape) |
| 88 | + plot_activations(output) |
| 89 | + |
| 90 | + |
| 91 | +def main(): |
| 92 | + test_se3_lifting_conv() |
| 93 | + |
| 94 | + |
| 95 | +if __name__ == "__main__": |
| 96 | + main() |
0 commit comments