-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathbridge_tml1m.py
156 lines (134 loc) · 3.88 KB
/
bridge_tml1m.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# The BRIDGE method from the "rLLM: Relational Table Learning with LLMs" paper.
# ArXiv: https://arxiv.org/abs/2407.20157
# Dataset: TML1M
# Accuracy: 0.397
import time
import argparse
import sys
import os.path as osp
import torch
import torch.nn.functional as F
sys.path.append("./")
sys.path.append("../")
sys.path.append("../../")
from rllm.datasets import TML1MDataset
from rllm.transforms.graph_transforms import GCNTransform
from rllm.transforms.table_transforms import TabTransformerTransform
from rllm.nn.conv.graph_conv import GCNConv
from rllm.nn.conv.table_conv import TabTransformerConv
from rllm.nn.models import BRIDGE, TableEncoder, GraphEncoder
from utils import build_homo_graph, reorder_ids
# Command-line hyperparameters for training.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=100, help="Training epochs")
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
args = parser.parse_args()
# Set device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load data
# Dataset files are cached under ``data/`` two directories above this script.
# NOTE(review): force_reload=True presumably rebuilds the cached dataset on
# every run — confirm this is intended and not a debugging leftover.
path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data")
dataset = TML1MDataset(cached_dir=path, force_reload=True)
# Get the required data
# ``dataset.data_list`` yields four items; the second is unused here.
(
    user_table,
    _,
    rating_table,
    movie_embeddings,
) = dataset.data_list
# The embedding width (size(1)) is reused as the hidden size of both
# encoders; the row count (size(0)) is used below as the number of movie nodes.
emb_size = movie_embeddings.size(1)
user_size = len(user_table)
# NOTE(review): presumably remaps UserID/MovieID in the rating table to
# contiguous node indices given ``n_src`` user nodes — confirm in utils.
ordered_rating = reorder_ids(
    relation_df=rating_table.df,
    src_col_name="UserID",
    tgt_col_name="MovieID",
    n_src=user_size,
)
target_table = user_table.to(device)
y = user_table.y.long().to(device)
movie_embeddings = movie_embeddings.to(device)
# Build graph
# The graph has one node per user plus one node per movie embedding row;
# presumably its edges come from the reordered ratings — confirm in utils.
graph = build_homo_graph(
    relation_df=ordered_rating,
    n_all=user_size + movie_embeddings.size(0),
).to(device)
# Transform data
# The table transform targets the same width (emb_size) as the movie
# embeddings; ``.adj`` of the GCN-transformed graph is what the model
# later receives as ``adj``.
table_transform = TabTransformerTransform(
    out_dim=emb_size, metadata=target_table.metadata
)
target_table = table_transform(target_table)
graph_transform = GCNTransform()
adj = graph_transform(graph).adj
# Split data
# Train/val/test boolean masks over the user rows.
train_mask, val_mask, test_mask = (
    user_table.train_mask,
    user_table.val_mask,
    user_table.test_mask,
)
# Set up model and optimizer
# The table encoder keeps width emb_size; the graph encoder maps emb_size
# features to an output of width num_classes (the classification logits).
t_encoder = TableEncoder(
    in_dim=emb_size,
    out_dim=emb_size,
    table_conv=TabTransformerConv,
    metadata=target_table.metadata,
)
g_encoder = GraphEncoder(
    in_dim=emb_size,
    out_dim=target_table.num_classes,
    graph_conv=GCNConv,
)
model = BRIDGE(
    table_encoder=t_encoder,
    graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.wd,
)
def train() -> float:
    """Run one full-batch optimization step on the training users.

    Reads the module-level ``model``, ``optimizer``, ``target_table``,
    ``movie_embeddings``, ``adj``, ``y`` and ``train_mask``.

    Returns:
        The training cross-entropy loss as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    # Feed the transformed table: the table encoder was built from
    # ``target_table.metadata``, and ``target_table`` is the table that was
    # moved to ``device`` and passed through ``table_transform``.
    logits = model(
        table=target_table,
        non_table=movie_embeddings,
        adj=adj,
    )
    # Logits are (num_rows, num_classes) (``test`` takes argmax over dim 1),
    # so no squeeze is needed; squeezing would drop the batch dimension when
    # only one row is selected.
    loss = F.cross_entropy(logits[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    """Evaluate the model on the train, validation and test splits.

    Reads the same module-level objects as ``train`` plus ``val_mask`` and
    ``test_mask``. Runs without gradient tracking.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of classification
        accuracies as floats.
    """
    model.eval()
    # Use the same (transformed) table input as ``train``.
    logits = model(
        table=target_table,
        non_table=movie_embeddings,
        adj=adj,
    )
    preds = logits.argmax(dim=1)
    accs = []
    for mask in [train_mask, val_mask, test_mask]:
        correct = float(preds[mask].eq(y[mask]).sum().item())
        accs.append(correct / int(mask.sum()))
    return accs
# Train for ``args.epochs`` epochs. The reported test accuracy is the one
# measured at the epoch with the highest validation accuracy.
start_time = time.time()
best_val_acc = best_test_acc = 0
for epoch in range(1, args.epochs + 1):
    train_loss = train()
    train_acc, val_acc, test_acc = test()
    print(
        f"Epoch: [{epoch}/{args.epochs}] "
        f"Loss: {train_loss:.4f} train_acc: {train_acc:.4f} "
        f"val_acc: {val_acc:.4f} test_acc: {test_acc:.4f} "
    )
    # Strict ">" keeps the earliest epoch when validation accuracy ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
print(f"Total Time: {time.time() - start_time:.4f}s")
print(
    "BRIDGE result: "
    f"Best Val acc: {best_val_acc:.4f}, "
    f"Best Test acc: {best_test_acc:.4f}"
)