Skip to content

Commit

Permalink
remove torch_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
SherylHYX committed Feb 9, 2025
1 parent 4fed7ff commit 1a6cc70
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions torch_geometric_signed_directed/nn/directed/DGCNConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from torch_geometric.typing import Adj, OptTensor

from torch import Tensor
from torch_geometric.typing import SparseTensor, torch_sparse
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import spmm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm

Expand Down Expand Up @@ -99,4 +100,4 @@ def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
return torch_sparse.mul(adj_t, x, reduce=self.aggr)
return spmm(adj_t, x, reduce=self.aggr)
5 changes: 3 additions & 2 deletions torch_geometric_signed_directed/nn/signed/SGCNConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import SparseTensor, torch_sparse
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import spmm
from torch_geometric.nn.conv import MessagePassing


Expand Down Expand Up @@ -130,7 +131,7 @@ def message(self, x_j: Tensor) -> Tensor:
def message_and_aggregate(self, adj_t: SparseTensor,
x: PairTensor) -> Tensor:
adj_t = adj_t.set_value(None, layout=None)
return torch_sparse.mul(adj_t, x[0], reduce=self.aggr)
return spmm(adj_t, x[0], reduce=self.aggr)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_dim}, '
Expand Down

0 comments on commit 1a6cc70

Please sign in to comment.