Skip to content

Commit

Permalink
modify the BRIDGE examples and rect model
Browse files Browse the repository at this point in the history
  • Loading branch information
JianwuZheng413 committed Nov 26, 2024
1 parent 43bcab7 commit 41bcb58
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 70 deletions.
51 changes: 24 additions & 27 deletions examples/bridge/bridge_tacm12k.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,20 @@
_,
) = dataset.data_list

# Making graph
paper_embeddings = paper_embeddings.to(device)
adj = build_homo_adj(
relation_df=citations_table.df,
n_all=len(papers_table),
).to(device)
target_table = papers_table.to(device)
y = papers_table.y.long().to(device)
paper_embeddings = paper_embeddings.to(device)


train_mask, val_mask, test_mask = (
papers_table.train_mask,
papers_table.val_mask,
papers_table.test_mask,
)
out_dim = papers_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -79,6 +77,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=paper_embeddings.size(1),
out_dim=paper_embeddings.size(1),
table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=paper_embeddings.size(1),
out_dim=papers_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -112,32 +133,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=paper_embeddings.size(1),
out_dim=paper_embeddings.size(1),
table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=paper_embeddings.size(1),
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)

start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(
[
dict(params=model.table_encoder.parameters(), lr=0.001),
dict(params=model.graph_encoder.parameters(), lr=0.01, weight_decay=1e-4),
]
)

for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
43 changes: 23 additions & 20 deletions examples/bridge/bridge_tlf2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
artist_table.val_mask,
artist_table.test_mask,
)
out_dim = artist_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -79,6 +78,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=artist_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -112,27 +134,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)


start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
43 changes: 23 additions & 20 deletions examples/bridge/bridge_tml1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
user_table.val_mask,
user_table.test_mask,
)
out_dim = user_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -83,6 +82,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=user_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -116,27 +138,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)


start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
4 changes: 2 additions & 2 deletions examples/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Datasets Citeseer | Cora | Pubmed
# Unseen Classes [1, 2, 5] [3, 4] | [1, 2, 3] [3, 4, 6] | [2]
# RECT-L 66.50 68.40 | 74.80 72.20 | 75.30
# RECT-L 61.10 66.10 | 71.20 70.40 | 69.90

import argparse
import copy
Expand Down Expand Up @@ -41,7 +41,7 @@
[
UT.NormalizeFeatures("l2"),
UT.SVDFeatureReduction(200),
GT.GDC(),
GT.GCNNorm(),
]
)

Expand Down
3 changes: 2 additions & 1 deletion rllm/nn/models/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.prelu = torch.nn.PReLU()
self.conv = GCNConv(in_dim, hidden_dim)
self.lin = Linear(hidden_dim, in_dim)
self.reset_parameters()
Expand All @@ -41,7 +42,7 @@ def reset_parameters(self):
torch.nn.init.xavier_uniform_(self.lin.weight.data)

def forward(self, x: Tensor, adj: Tensor):
x = self.conv(x, adj)
x = self.prelu(self.conv(x, adj))
x = F.dropout(x, p=self.dropout, training=self.training)
return self.lin(x)

Expand Down

0 comments on commit 41bcb58

Please sign in to comment.