From 1a6cc70cbe0313f164c1084cfeefbe0ba80ad73b Mon Sep 17 00:00:00 2001 From: Yixuan He Date: Sun, 9 Feb 2025 00:34:21 +0000 Subject: [PATCH] remove torch_sparse --- torch_geometric_signed_directed/nn/directed/DGCNConv.py | 5 +++-- torch_geometric_signed_directed/nn/signed/SGCNConv.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/torch_geometric_signed_directed/nn/directed/DGCNConv.py b/torch_geometric_signed_directed/nn/directed/DGCNConv.py index 8e5b4b6..900026e 100644 --- a/torch_geometric_signed_directed/nn/directed/DGCNConv.py +++ b/torch_geometric_signed_directed/nn/directed/DGCNConv.py @@ -2,7 +2,8 @@ from torch_geometric.typing import Adj, OptTensor from torch import Tensor -from torch_geometric.typing import SparseTensor, torch_sparse +from torch_geometric.typing import SparseTensor +from torch_geometric.utils import spmm from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm @@ -99,4 +100,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor: - return torch_sparse.mul(adj_t, x, reduce=self.aggr) + return spmm(adj_t, x, reduce=self.aggr) diff --git a/torch_geometric_signed_directed/nn/signed/SGCNConv.py b/torch_geometric_signed_directed/nn/signed/SGCNConv.py index fe90af3..bbb3717 100644 --- a/torch_geometric_signed_directed/nn/signed/SGCNConv.py +++ b/torch_geometric_signed_directed/nn/signed/SGCNConv.py @@ -5,7 +5,8 @@ from torch import Tensor import torch.nn.functional as F from torch_geometric.nn.dense.linear import Linear -from torch_geometric.typing import SparseTensor, torch_sparse +from torch_geometric.typing import SparseTensor +from torch_geometric.utils import spmm from torch_geometric.nn.conv import MessagePassing @@ -130,7 +131,7 @@ def message(self, x_j: Tensor) -> Tensor: def message_and_aggregate(self, adj_t: SparseTensor, x: PairTensor) -> Tensor: adj_t = adj_t.set_value(None, layout=None) - return torch_sparse.mul(adj_t, x[0], reduce=self.aggr) + return spmm(adj_t, x[0], reduce=self.aggr) def __repr__(self) -> str: return (f'{self.__class__.__name__}({self.in_dim}, '