Skip to content

Commit

Permalink
add gcn_conv_test
Browse files Browse the repository at this point in the history
  • Loading branch information
JianwuZheng413 committed Aug 13, 2024
1 parent 8403cb9 commit 117f2fb
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/nn/conv/test_gcn_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from rllm.nn.conv import GCNConv


def test_gcn_conv():
node_size = 4
in_dim = 16
out_dim = 8

# Feature-based embeddings and adj
x = torch.randn(size=(node_size, in_dim))
adj = torch.tensor([[1., 1., 1., 1.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],])

conv = GCNConv(in_dim, out_dim)
assert str(conv) == 'GCNConv(16, 8)'

x_out = conv(x, adj)
assert x_out.shape == (node_size, out_dim)

0 comments on commit 117f2fb

Please sign in to comment.