Skip to content

Commit

Permalink
Merge pull request #73 from rllm-team/dev_gnn
Browse files Browse the repository at this point in the history
update test and annotation
  • Loading branch information
JianwuZheng413 authored Aug 30, 2024
2 parents 54d1f37 + f9ddabe commit 25f765b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
14 changes: 8 additions & 6 deletions rllm/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Union, Tuple
from sympy import primefactors
from torch import Tensor
import torch
import torch.nn as nn
Expand Down Expand Up @@ -182,7 +181,7 @@ def forward(
attentions_edge, p=0.6, training=self.training)

# Element-wise product. Operator * does the same thing as torch.mul
# shape = (E, NH, F_OUT) * (E, NH, 1) -> (E, NH, F_OUT)
# shape = (E, H, F_OUT) * (E, H, 1) -> (E, H, F_OUT)
# 1 gets broadcast into F_OUT
nodes_features_weighted = nodes_features_selected * attentions_edge

Expand All @@ -199,7 +198,7 @@ def forward(

def aggregate_neighborhoods(self, nodes_features, idx_target, num_nodes):
size = list(nodes_features.shape)
size[self.nodes_dim] = num_nodes # shape = (N, NH, FOUT)
size[self.nodes_dim] = num_nodes # shape = (N, H, FOUT)
nodes_features_aggregated = torch.zeros(
size, dtype=nodes_features.dtype, device=nodes_features.device
)
Expand Down Expand Up @@ -240,9 +239,12 @@ def score_edge_wiht_neighborhood(
idx_target,
num_nodes
):
# The shape must be the same as in scores_edge_exp (required by scatter_add_) i.e. from E -> (E, NH)
# The shape must be the same as in scores_edge_exp (required by scatter_add_)
# i.e. from E -> (E, H)
idx_target_broadcasted = self.expand_dim(idx_target, scores_edge_exp)
# shape = (N, NH), where N is the number of nodes and NH the number of attention heads

# shape = (N, H), where N is the number of nodes
# H the number of attention heads
size = list(
scores_edge_exp.shape
) # convert to list otherwise assignment is not possible
Expand All @@ -254,7 +256,7 @@ def score_edge_wiht_neighborhood(
self.nodes_dim, idx_target_broadcasted, scores_edge_exp
)

# shape = (N, NH) -> (E, NH)
# shape = (N, H) -> (E, H)
return neighborhood_sums.index_select(self.nodes_dim, idx_target)

def expand_dim(self, src, trg):
Expand Down
4 changes: 2 additions & 2 deletions test/examples/test_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_gat():
out.returncode == 0
), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
stdout = out.stdout.decode("utf-8")
assert float(stdout[-9:]) > 0.75
assert float(stdout[-9:]) > 0.82


def test_han():
Expand All @@ -36,7 +36,7 @@ def test_han():
out.returncode == 0
), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
stdout = out.stdout.decode("utf-8")
assert float(stdout[-9:]) > 0.54
assert float(stdout[-9:]) > 0.56


def test_rect():
Expand Down

0 comments on commit 25f765b

Please sign in to comment.