From f9ddabe636af05cfc80e51df3859b9ae3faee630 Mon Sep 17 00:00:00 2001 From: z3u5 <1592050303@qq.com> Date: Fri, 30 Aug 2024 17:54:31 +0800 Subject: [PATCH] update test and annotation --- rllm/nn/conv/gat_conv.py | 14 ++++++++------ test/examples/test_gnn.py | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/rllm/nn/conv/gat_conv.py b/rllm/nn/conv/gat_conv.py index 9c6c0576..239d8209 100644 --- a/rllm/nn/conv/gat_conv.py +++ b/rllm/nn/conv/gat_conv.py @@ -1,5 +1,4 @@ from typing import Union, Tuple -from sympy import primefactors from torch import Tensor import torch import torch.nn as nn @@ -182,7 +181,7 @@ def forward( attentions_edge, p=0.6, training=self.training) # Element-wise product. Operator * does the same thing as torch.mul - # shape = (E, NH, F_OUT) * (E, NH, 1) -> (E, NH, F_OUT) + # shape = (E, H, F_OUT) * (E, H, 1) -> (E, H, F_OUT) # 1 gets broadcast into F_OUT nodes_features_weighted = nodes_features_selected * attentions_edge @@ -199,7 +198,7 @@ def forward( def aggregate_neighborhoods(self, nodes_features, idx_target, num_nodes): size = list(nodes_features.shape) - size[self.nodes_dim] = num_nodes # shape = (N, NH, FOUT) + size[self.nodes_dim] = num_nodes # shape = (N, H, FOUT) nodes_features_aggregated = torch.zeros( size, dtype=nodes_features.dtype, device=nodes_features.device ) @@ -240,9 +239,12 @@ def score_edge_wiht_neighborhood( idx_target, num_nodes ): - # The shape must be the same as in scores_edge_exp (required by scatter_add_) i.e. from E -> (E, NH) + # The shape must be the same as in scores_edge_exp (required by scatter_add_) + # i.e. from E -> (E, H) idx_target_broadcasted = self.expand_dim(idx_target, scores_edge_exp) - # shape = (N, NH), where N is the number of nodes and NH the number of attention heads + + # shape = (N, H), where N is the number of nodes + # H the number of attention heads size = list( scores_edge_exp.shape ) # convert to list otherwise assignment is not possible @@ -254,7 +256,7 @@ def score_edge_wiht_neighborhood( self.nodes_dim, idx_target_broadcasted, scores_edge_exp ) - # shape = (N, NH) -> (E, NH) + # shape = (N, H) -> (E, H) return neighborhood_sums.index_select(self.nodes_dim, idx_target) def expand_dim(self, src, trg): diff --git a/test/examples/test_gnn.py b/test/examples/test_gnn.py index 9782e797..4610adf4 100644 --- a/test/examples/test_gnn.py +++ b/test/examples/test_gnn.py @@ -26,7 +26,7 @@ def test_gat(): out.returncode == 0 ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}" stdout = out.stdout.decode("utf-8") - assert float(stdout[-9:]) > 0.75 + assert float(stdout[-9:]) > 0.82 def test_han(): @@ -36,7 +36,7 @@ def test_han(): out.returncode == 0 ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}" stdout = out.stdout.decode("utf-8") - assert float(stdout[-9:]) > 0.54 + assert float(stdout[-9:]) > 0.56 def test_rect():