delete unused imports #130

Merged · 1 commit · Nov 24, 2024
4 changes: 2 additions & 2 deletions examples/ft_transformer.py
@@ -64,7 +64,7 @@ def __init__(
             col_stats_dict=col_stats_dict,
         )

-        self.convs = FTTransformerConv(
+        self.conv = FTTransformerConv(
             dim=hidden_dim,
             layers=layers,
             use_cls=True,
@@ -77,7 +77,7 @@ def __init__(

     def forward(self, x) -> Tensor:
         x = self.transform(x)
-        x_cls = self.convs(x)
+        x_cls = self.conv(x)
         out = self.fc(x_cls)
         return out

4 changes: 2 additions & 2 deletions examples/tab_transformer.py
@@ -77,8 +77,8 @@ def __init__(

     def forward(self, x):
         x = self.transform(x)
-        for tab_transformer_conv in self.convs:
-            x = tab_transformer_conv(x)
+        for conv in self.convs:
+            x = conv(x)
         out = self.fc(x.mean(dim=1))
         return out

136 changes: 67 additions & 69 deletions examples/tape/core/GNNs/RevGAT/model.py
@@ -51,19 +51,19 @@ def forward(self, x):

 class GATConv(nn.Module):
     def __init__(
-            self,
-            in_feats,
-            out_feats,
-            num_heads=1,
-            feat_drop=0.0,
-            attn_drop=0.0,
-            edge_drop=0.0,
-            negative_slope=0.2,
-            use_attn_dst=True,
-            residual=False,
-            activation=None,
-            allow_zero_in_degree=False,
-            use_symmetric_norm=False,
+        self,
+        in_feats,
+        out_feats,
+        num_heads=1,
+        feat_drop=0.0,
+        attn_drop=0.0,
+        edge_drop=0.0,
+        negative_slope=0.2,
+        use_attn_dst=True,
+        residual=False,
+        activation=None,
+        allow_zero_in_degree=False,
+        use_symmetric_norm=False,
     ):
         super(GATConv, self).__init__()
         self._num_heads = num_heads
@@ -73,17 +73,18 @@ def __init__(
         self._use_symmetric_norm = use_symmetric_norm
         if isinstance(in_feats, tuple):
             self.fc_src = nn.Linear(
-                self._in_src_feats, out_feats * num_heads, bias=False)
+                self._in_src_feats, out_feats * num_heads, bias=False
+            )
             self.fc_dst = nn.Linear(
-                self._in_dst_feats, out_feats * num_heads, bias=False)
+                self._in_dst_feats, out_feats * num_heads, bias=False
+            )
         else:
-            self.fc = nn.Linear(self._in_src_feats,
-                                out_feats * num_heads, bias=False)
-        self.attn_l = nn.Parameter(
-            torch.FloatTensor(size=(1, num_heads, out_feats)))
+            self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
+        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
         if use_attn_dst:
             self.attn_r = nn.Parameter(
-                torch.FloatTensor(size=(1, num_heads, out_feats)))
+                torch.FloatTensor(size=(1, num_heads, out_feats))
+            )
         else:
             self.register_buffer("attn_r", None)
         self.feat_drop = nn.Dropout(feat_drop)
@@ -94,7 +95,8 @@ def __init__(
         self.leaky_relu = nn.LeakyReLU(negative_slope)
         if residual:
             self.res_fc = nn.Linear(
-                self._in_dst_feats, num_heads * out_feats, bias=False)
+                self._in_dst_feats, num_heads * out_feats, bias=False
+            )
         else:
             self.register_buffer("res_fc", None)
         self.reset_parameters()
@@ -128,15 +130,12 @@ def forward(self, graph, feat, perm=None):
                 if not hasattr(self, "fc_src"):
                     self.fc_src, self.fc_dst = self.fc, self.fc
                 feat_src, feat_dst = h_src, h_dst
-                feat_src = self.fc_src(
-                    h_src).view(-1, self._num_heads, self._out_feats)
-                feat_dst = self.fc_dst(
-                    h_dst).view(-1, self._num_heads, self._out_feats)
+                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
+                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
             else:
                 h_src = self.feat_drop(feat)
                 feat_src = h_src
-                feat_src = self.fc(
-                    h_src).view(-1, self._num_heads, self._out_feats)
+                feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
                 if graph.is_block:
                     h_dst = h_src[: graph.number_of_dst_nodes()]
                     feat_dst = feat_src[: graph.number_of_dst_nodes()]
@@ -174,13 +173,13 @@ def forward(self, graph, feat, perm=None):

             if self.training and self.edge_drop > 0:
                 if perm is None:
-                    perm = torch.randperm(
-                        graph.number_of_edges(), device=e.device)
+                    perm = torch.randperm(graph.number_of_edges(), device=e.device)
                 bound = int(graph.number_of_edges() * self.edge_drop)
                 eids = perm[bound:]
                 graph.edata["a"] = torch.zeros_like(e)
                 graph.edata["a"][eids] = self.attn_drop(
-                    edge_softmax(graph, e[eids], eids=eids))
+                    edge_softmax(graph, e[eids], eids=eids)
+                )
             else:
                 graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))

@@ -197,8 +196,7 @@ def forward(self, graph, feat, perm=None):

             # residual
             if self.res_fc is not None:
-                resval = self.res_fc(h_dst).view(
-                    h_dst.shape[0], -1, self._out_feats)
+                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                 rst = rst + resval

             # activation
@@ -209,20 +207,20 @@

 class RevGATBlock(nn.Module):
     def __init__(
-            self,
-            node_feats,
-            edge_feats,
-            edge_emb,
-            out_feats,
-            n_heads=1,
-            attn_drop=0.0,
-            edge_drop=0.0,
-            negative_slope=0.2,
-            residual=True,
-            activation=None,
-            use_attn_dst=True,
-            allow_zero_in_degree=True,
-            use_symmetric_norm=False,
+        self,
+        node_feats,
+        edge_feats,
+        edge_emb,
+        out_feats,
+        n_heads=1,
+        attn_drop=0.0,
+        edge_drop=0.0,
+        negative_slope=0.2,
+        residual=True,
+        activation=None,
+        use_attn_dst=True,
+        allow_zero_in_degree=True,
+        use_symmetric_norm=False,
     ):
         super(RevGATBlock, self).__init__()

@@ -269,21 +267,22 @@ def forward(self, x, graph, dropout_mask=None, perm=None, efeat=None):

 class RevGAT(nn.Module):
     def __init__(
-            self,
-            in_feats,
-            n_classes,
-            n_hidden,
-            n_layers,
-            n_heads,
-            activation,
-            dropout=0.0,
-            input_drop=0.0,
-            attn_drop=0.0,
-            edge_drop=0.0,
-            use_attn_dst=True,
-            use_symmetric_norm=False,
-            group=2, input_norm=True,
-            use_pred=False,
+        self,
+        in_feats,
+        n_classes,
+        n_hidden,
+        n_layers,
+        n_heads,
+        activation,
+        dropout=0.0,
+        input_drop=0.0,
+        attn_drop=0.0,
+        edge_drop=0.0,
+        use_attn_dst=True,
+        use_symmetric_norm=False,
+        group=2,
+        input_norm=True,
+        use_pred=False,
     ):
         super().__init__()
         self.in_feats = in_feats
@@ -295,7 +294,7 @@ def __init__(
         self.use_pred = use_pred

         if self.use_pred:
-            self.encoder = torch.nn.Embedding(n_classes+1, n_hidden)
+            self.encoder = torch.nn.Embedding(n_classes + 1, n_hidden)
         self.convs = nn.ModuleList()
         self.norm = nn.BatchNorm1d(n_heads * n_hidden)
         if input_norm:
@@ -339,11 +338,11 @@ def __init__(
             else:
                 Fms.append(copy.deepcopy(fm))

-            invertible_module = memgcn.GroupAdditiveCoupling(Fms,
-                                                             group=self.group)
+            invertible_module = memgcn.GroupAdditiveCoupling(Fms, group=self.group)

-            conv = memgcn.InvertibleModuleWrapper(fn=invertible_module,
-                                                  keep_input=False)
+            conv = memgcn.InvertibleModuleWrapper(
+                fn=invertible_module, keep_input=False
+            )

             self.convs.append(conv)

@@ -359,7 +358,7 @@ def forward(self, graph, feat, output_hidden_layer=None):
         if self.use_pred:
             h = self.encoder(h)
             h = torch.flatten(h, start_dim=1)
-        if hasattr(self, 'input_norm'):
+        if hasattr(self, "input_norm"):
             h = self.input_norm(h)
             # h2 = self.input_norm(h2)
         h = self.input_drop(h)
@@ -370,8 +369,7 @@ def forward(self, graph, feat, output_hidden_layer=None):

         self.perms = []
         for i in range(self.n_layers):
-            perm = torch.randperm(graph.number_of_edges(),
-                                  device=graph.device)
+            perm = torch.randperm(graph.number_of_edges(), device=graph.device)
             self.perms.append(perm)

         h = self.convs[0](graph, h, self.perms[0]).flatten(1, -1)
29 changes: 7 additions & 22 deletions rllm/data/__init__.py
@@ -4,40 +4,25 @@
 from .storage import BaseStorage, NodeStorage, EdgeStorage, recursive_apply  # noqa
 from .view import MappingView, KeysView, ValuesView, ItemsView  # noqa

-dataset_classes = [
+__all__ = [
+    # dataset_classes
     "Dataset",
-]
-
-graph_data_classes = [
+    # graph_data_classes
     "BaseGraph",
     "GraphData",
     "HeteroGraphData",
-]
-
-table_data_classes = [
+    # table_data_classes
     "BaseTable",
-    "TableDataset",
     "TableData",
-]
-
-storage_classes = [
+    "TableDataset",
+    # storage_classes
     "BaseStorage",
     "NodeStorage",
     "EdgeStorage",
     "recursive_apply",
-]
-
-view_classes = [
+    # view_classes
     "MappingView",
     "KeysView",
     "ValuesView",
     "ItemsView",
 ]
-
-__all__ = (
-    dataset_classes
-    + graph_data_classes
-    + table_data_classes
-    + storage_classes
-    + view_classes
-)
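
The hunk above folds the five intermediate name lists and their tuple concatenation into one flat __all__, keeping the old group names only as comments. A minimal sketch of what the merged list means in practice (it assumes the rllm package is importable; only names visible in the diff are used):

# A single flat __all__ is the module's whole public surface.
import rllm.data

print(rllm.data.__all__)  # ["Dataset", "BaseGraph", ..., "ItemsView"]
# `from rllm.data import *` binds exactly these names and nothing else,
# which is also what documentation tools and linters read.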
5 changes: 1 addition & 4 deletions rllm/llm/llm_module/base.py
@@ -1,8 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import (
-    Any,
-    Sequence,
-)
+from typing import Sequence

 from rllm.llm.types import (
     ChatMessage,
2 changes: 1 addition & 1 deletion rllm/llm/llm_module/langchain_llm.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, Sequence
+from typing import Callable, Optional, Sequence

 from rllm.llm.types import (
     LLMMetadata,
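
The last two files are where imports are literally deleted (the unused Any in both typing imports); the earlier files are rename and formatting cleanups. As a rough, hedged sketch of how such imports get flagged in the first place (this is not the repository's tooling, just the general idea behind linter rules such as F401): parse the module, collect the names each import binds, and report any that are never referenced.

# Rough unused-import check; illustrative only, not the project's linter.
import ast

source = """
from typing import Any, Callable, Optional, Sequence

def call(fn: Callable, args: Optional[Sequence]) -> None:
    fn(*(args or ()))
"""

tree = ast.parse(source)
# Names bound by import statements (handles `import x.y` and `as` aliases).
imported = {
    alias.asname or alias.name.split(".")[0]
    for node in ast.walk(tree)
    if isinstance(node, (ast.Import, ast.ImportFrom))
    for alias in node.names
}
# Every bare name actually referenced anywhere in the module.
used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
print(sorted(imported - used))  # -> ['Any'], the same kind of finding as this PR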