Skip to content

Commit

Permalink
Merge pull request #134 from rllm-team/develop
Browse files Browse the repository at this point in the history
modify BRIDGE examples and rect model
  • Loading branch information
JianwuZheng413 authored Nov 26, 2024
2 parents a3777fc + 41bcb58 commit 60ca8a8
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 82 deletions.
51 changes: 24 additions & 27 deletions examples/bridge/bridge_tacm12k.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,20 @@
_,
) = dataset.data_list

# Making graph
paper_embeddings = paper_embeddings.to(device)
adj = build_homo_adj(
relation_df=citations_table.df,
n_all=len(papers_table),
).to(device)
target_table = papers_table.to(device)
y = papers_table.y.long().to(device)
paper_embeddings = paper_embeddings.to(device)


train_mask, val_mask, test_mask = (
papers_table.train_mask,
papers_table.val_mask,
papers_table.test_mask,
)
out_dim = papers_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -79,6 +77,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=paper_embeddings.size(1),
out_dim=paper_embeddings.size(1),
table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=paper_embeddings.size(1),
out_dim=papers_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -112,32 +133,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=paper_embeddings.size(1),
out_dim=paper_embeddings.size(1),
table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=paper_embeddings.size(1),
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)

start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(
[
dict(params=model.table_encoder.parameters(), lr=0.001),
dict(params=model.graph_encoder.parameters(), lr=0.01, weight_decay=1e-4),
]
)

for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
43 changes: 23 additions & 20 deletions examples/bridge/bridge_tlf2k.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
artist_table.val_mask,
artist_table.test_mask,
)
out_dim = artist_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -79,6 +78,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=artist_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -112,27 +134,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)


start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
43 changes: 23 additions & 20 deletions examples/bridge/bridge_tml1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
user_table.val_mask,
user_table.test_mask,
)
out_dim = user_table.num_classes


class Bridge(torch.nn.Module):
Expand All @@ -83,6 +82,29 @@ def forward(self, table, non_table, adj):
return node_feats[: len(table), :]


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=user_table.num_classes,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
weight_decay=args.wd,
)


def accuracy_score(preds, truth):
return (preds == truth).sum(dim=0) / len(truth)

Expand Down Expand Up @@ -116,27 +138,8 @@ def test_epoch():
return train_acc.item(), val_acc.item(), test_acc.item()


t_encoder = TableEncoder(
in_dim=emb_size,
out_dim=emb_size,
table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict),
table_conv=TabTransformerConv,
)
g_encoder = GraphEncoder(
in_dim=emb_size,
out_dim=out_dim,
graph_transform=GCNNorm(),
graph_conv=GCNConv,
)
model = Bridge(
table_encoder=t_encoder,
graph_encoder=g_encoder,
).to(device)


start_time = time.time()
best_val_acc = best_test_acc = 0
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
for epoch in range(1, args.epochs + 1):
train_loss = train_epoch()
train_acc, val_acc, test_acc = test_epoch()
Expand Down
2 changes: 1 addition & 1 deletion examples/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sys.path.append("./")
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets.titanic import Titanic
from rllm.datasets import Titanic
from rllm.transforms.table_transforms import FTTransformerTransform
from rllm.nn.conv.table_conv import FTTransformerConv

Expand Down
4 changes: 2 additions & 2 deletions examples/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Datasets Citeseer | Cora | Pubmed
# Unseen Classes [1, 2, 5] [3, 4] | [1, 2, 3] [3, 4, 6] | [2]
# RECT-L 66.50 68.40 | 74.80 72.20 | 75.30
# RECT-L 61.10 66.10 | 71.20 70.40 | 69.90

import argparse
import copy
Expand Down Expand Up @@ -41,7 +41,7 @@
[
UT.NormalizeFeatures("l2"),
UT.SVDFeatureReduction(200),
GT.GDC(),
GT.GCNNorm(),
]
)

Expand Down
2 changes: 1 addition & 1 deletion examples/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
sys.path.append("./")
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets.titanic import Titanic
from rllm.datasets import Titanic
from rllm.transforms.table_transforms import TabTransformerTransform
from rllm.nn.conv.table_conv import TabTransformerConv

Expand Down
2 changes: 1 addition & 1 deletion examples/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
sys.path.append("./")
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets.adult import Adult
from rllm.datasets import Adult
from rllm.transforms.table_transforms import TabNetTransform
from rllm.nn.models import TabNet

Expand Down
4 changes: 2 additions & 2 deletions examples/trompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sys.path.append("./")
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets.adult import Adult
from rllm.datasets import Adult
from rllm.transforms.table_transforms import TromptTransform
from rllm.nn.conv.table_conv import TromptConv

Expand All @@ -43,7 +43,7 @@
dataset = Adult(cached_dir=path)[0]
dataset.to(device)

# Split dataset, here the ratio of train-val-test is 80%-10%-10%
# Split dataset, here the ratio of train-val-test is 26048-6513-16281
train_loader, val_loader, test_loader = dataset.get_dataloader(
26048, 6513, 16281, batch_size=args.batch_size
)
Expand Down
2 changes: 2 additions & 0 deletions rllm/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .adult import Adult
from .churn_modelling import ChurnModelling
from .dblp import DBLP
from .imdb import IMDB
from .planetoid import PlanetoidDataset
Expand All @@ -11,6 +12,7 @@

__all__ = [
"Adult",
"ChurnModelling",
"DBLP",
"IMDB",
"PlanetoidDataset",
Expand Down
14 changes: 7 additions & 7 deletions rllm/datasets/adult.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@ class Adult(Dataset):
fnlwgt: The number of people the census believes have this job.
Education: The highest level of education achieved (Bachelors,
Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th,
7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.).
7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool).
Education-Num: A numeric version of Education.
Marital-Status: Marital status of the individual (Married-civ-spouse,
Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.).
Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse).
Occupation: The kind of work individuals perform (Tech-support,
Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty,
Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing,
Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.).
Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces).
Relationship: Relationship to head-of-household (Wife, Own-child, Husband,
Not-in-family, Other-relative, Unmarried.).
Not-in-family, Other-relative, Unmarried).
Race: Race of the individual (White, Asian-Pac-Islander,
Amer-Indian-Eskimo, Other, Black.).
Amer-Indian-Eskimo, Other, Black).
Sex: Gender of the individual.
Capital-Gain: Total capital gains.
Capital-Loss: Total capital losses.
Expand All @@ -53,8 +53,8 @@ class Adult(Dataset):
.. parsed-literal::
Statics:
Name Passengers Features
Size 48842 14
Name Individuals Features
Size 48842 14
"""

Expand Down
Loading

0 comments on commit 60ca8a8

Please sign in to comment.