Merge pull request #101 from rllm-team/develop
Update file architecture
JianwuZheng413 authored Oct 29, 2024
2 parents 175d539 + 8ffe2f8 commit a38b035
Showing 52 changed files with 915 additions and 991 deletions.
46 changes: 18 additions & 28 deletions examples/bridge/bridge_tacm12k.py
@@ -9,24 +9,24 @@
 import os.path as osp
 import pandas as pd
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
 
-import rllm.transforms as T
+import rllm.transforms.graph_transforms as T
 from rllm.datasets import TACM12KDataset
-from rllm.transforms import build_homo_graph
 from rllm.nn.models import Bridge
+from rllm.transforms.graph_transforms import build_homo_graph
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
-    "--tab_dim", type=int, default=256,
-    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
+    "--tab_dim", type=int, default=256, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
 parser.add_argument("--epochs", type=int, default=100, help="Training epochs")
 parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
 parser.add_argument("--wd", type=float, default=5e-4, help="Weight decay")
@@ -45,21 +45,17 @@
     author_embeddings,
 ) = dataset.data_list
 
-cite = cite_table.df.assign(
-    Target=cite_table.df["paper_id_cited"]
-)
+cite = cite_table.df.assign(Target=cite_table.df["paper_id_cited"])
 author2id = {
-    author_id: idx+paper_embeddings.size(0) for idx, author_id in enumerate(author_table.df.index.to_numpy())
+    author_id: idx + paper_embeddings.size(0)
+    for idx, author_id in enumerate(author_table.df.index.to_numpy())
 }
-writed = writing_table.df.assign(
-    Target=writing_table.df["author_id"].map(author2id)
-)
+writed = writing_table.df.assign(Target=writing_table.df["author_id"].map(author2id))
 
 # Get relation with cite_table and writing_table
 relation_df = pd.concat(
-    [cite.iloc[:, [0, 2]], writed.iloc[:, [0, 2]]],
-    axis=0,
-    ignore_index=True)
+    [cite.iloc[:, [0, 2]], writed.iloc[:, [0, 2]]], axis=0, ignore_index=True
+)
 x = torch.cat([paper_embeddings, author_embeddings], dim=0)
 
 # Making graph
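Note: the author2id mapping in this hunk is the key indexing step. Papers occupy node ids 0..num_papers-1, so author rows are offset by paper_embeddings.size(0) before both embedding matrices are concatenated into one node-feature tensor. A minimal, self-contained sketch of the same trick (toy sizes, made-up data):

    import pandas as pd
    import torch

    # Toy stand-ins: 3 papers and 3 authors with 8-dim features.
    paper_embeddings = torch.randn(3, 8)    # papers take node ids 0, 1, 2
    author_embeddings = torch.randn(3, 8)
    author_df = pd.DataFrame(index=[101, 205, 309])  # raw author ids

    # Offset each author's index by the paper count so node ids never collide.
    author2id = {
        author_id: idx + paper_embeddings.size(0)
        for idx, author_id in enumerate(author_df.index.to_numpy())
    }
    # author2id == {101: 3, 205: 4, 309: 5}

    # One feature matrix for the homogeneous graph: papers first, then authors.
    x = torch.cat([paper_embeddings, author_embeddings], dim=0)  # shape [6, 8]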
@@ -122,26 +118,20 @@ def test_epoch():
 
 
 model = Bridge(
-    table_hidden_dim=args.tab_dim,
-    table_output_dim=emb_size,
+    table_hidden_dim=emb_size,
     graph_output_dim=output_dim,
     stats_dict=graph.paper_table.stats_dict,
     graph_dropout=args.gcn_dropout,
     graph_layers=2,
-    graph_hidden_dim=(128),
+    graph_hidden_dim=128,
 ).to(device)
 
 start_time = time.time()
 best_val_acc = best_test_acc = 0
 optimizer = torch.optim.Adam(
     [
-        dict(
-            params=model.table_encoder.parameters(),
-            lr=0.001),
-        dict(
-            params=model.graph_encoder.parameters(),
-            lr=0.01,
-            weight_decay=1e-4),
+        dict(params=model.table_encoder.parameters(), lr=0.001),
+        dict(params=model.graph_encoder.parameters(), lr=0.01, weight_decay=1e-4),
     ]
     # model.parameters(),
     # lr=args.lr,
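Note: the optimizer in this hunk is only reformatted, but the pattern deserves a word. Passing a list of dicts to Adam creates independent parameter groups, so the table encoder trains at lr=0.001 while the graph encoder gets lr=0.01 plus weight decay. A runnable sketch with stand-in modules (the real ones are Bridge's table_encoder and graph_encoder):

    import torch

    # Hypothetical stand-ins for Bridge's two submodules.
    table_encoder = torch.nn.Linear(16, 8)
    graph_encoder = torch.nn.Linear(8, 4)

    # Each dict is its own parameter group; options omitted from a group
    # fall back to the optimizer-wide defaults.
    optimizer = torch.optim.Adam(
        [
            dict(params=table_encoder.parameters(), lr=0.001),
            dict(params=graph_encoder.parameters(), lr=0.01, weight_decay=1e-4),
        ]
    )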
8 changes: 3 additions & 5 deletions examples/bridge/bridge_tlf2k.py
@@ -14,10 +14,10 @@
 import torch
 import torch.nn.functional as F
 
-import rllm.transforms as T
+import rllm.transforms.graph_transforms as T
 from rllm.datasets import TLF2KDataset
 from rllm.nn.models import Bridge
-from rllm.transforms import build_homo_graph
+from rllm.transforms.graph_transforms import build_homo_graph
 
 
 parser = argparse.ArgumentParser()
@@ -112,10 +112,8 @@ def test_epoch():
 
 
 model = Bridge(
-    table_hidden_dim=args.tab_dim,
-    table_output_dim=emb_size,
+    table_hidden_dim=emb_size,
     graph_layers=2,
     graph_hidden_dim=emb_size,
     graph_output_dim=output_dim,
     stats_dict=graph.artist_table.stats_dict,
     graph_dropout=args.gcn_dropout,
9 changes: 3 additions & 6 deletions examples/bridge/bridge_tml1m.py
@@ -14,10 +14,10 @@
 import torch
 import torch.nn.functional as F
 
-import rllm.transforms as T
+import rllm.transforms.graph_transforms as T
 from rllm.datasets import TML1MDataset
 from rllm.nn.models import Bridge
-from rllm.transforms import build_homo_graph
+from rllm.transforms.graph_transforms import build_homo_graph
 
 
 parser = argparse.ArgumentParser()
@@ -70,7 +70,6 @@
 )
 output_dim = graph.user_table.num_classes
 
-
 def accuracy_score(preds, truth):
     return (preds == truth).sum(dim=0) / len(truth)
 
@@ -110,10 +109,8 @@ def test_epoch():
 
 
 model = Bridge(
-    table_hidden_dim=args.tab_dim,
-    table_output_dim=emb_size,
+    table_hidden_dim=emb_size,
     graph_layers=2,
     graph_hidden_dim=emb_size,
     graph_output_dim=output_dim,
     stats_dict=graph.user_table.stats_dict,
     graph_dropout=args.gcn_dropout,
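Note: the accuracy_score helper visible in this file's first hunk is plain tensor arithmetic, elementwise equality summed and divided by the batch size. For example:

    import torch

    def accuracy_score(preds, truth):
        return (preds == truth).sum(dim=0) / len(truth)

    preds = torch.tensor([1, 0, 2, 2])
    truth = torch.tensor([1, 0, 1, 2])
    print(accuracy_score(preds, truth))  # tensor(0.7500): 3 of 4 correct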
112 changes: 74 additions & 38 deletions examples/ft_transformer.py
@@ -1,49 +1,81 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
+
+sys.path.append("../")
 
 import torch
+from torch import Tensor
 import torch.nn.functional as F
+from torch.nn import LayerNorm, Linear, ReLU, Sequential
 from sklearn.metrics import roc_auc_score
 from torch.utils.data import DataLoader
-from torch.optim.lr_scheduler import ExponentialLR
 from tqdm import tqdm
 
+from rllm.types import ColType
 from rllm.datasets.titanic import Titanic
-from rllm.nn.models.ft_transformer import FTTransformer
+from rllm.transforms.table_transforms import FTTransformerTransform
+from rllm.nn.conv.table_conv import FTTransformerConvs
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', type=str, default='titanic',
-                    choices=["titanic",])
-parser.add_argument('--dim', help='embedding dim.', type=int, default=32)
-parser.add_argument('--num_layers', type=int, default=3)
-parser.add_argument('--batch_size', type=int, default=128)
-parser.add_argument('--lr', type=float, default=0.001)
-parser.add_argument('--epochs', type=int, default=50)
-parser.add_argument('--seed', type=int, default=42)
-parser.add_argument('--compile', action='store_true')
+parser.add_argument(
+    "--dataset",
+    type=str,
+    default="titanic",
+    choices=[
+        "titanic",
+    ],
+)
+parser.add_argument("--dim", help="embedding dim.", type=int, default=32)
+parser.add_argument("--num_layers", type=int, default=3)
+parser.add_argument("--batch_size", type=int, default=128)
+parser.add_argument("--lr", type=float, default=0.001)
+parser.add_argument("--epochs", type=int, default=50)
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--wd", type=float, default=5e-4)
 args = parser.parse_args()
 
 torch.manual_seed(args.seed)
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Prepare datasets
-path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
+path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")
 dataset = Titanic(cached_dir=path)[0]
 dataset.to(device)
 dataset.shuffle()
 
 # Split dataset, here the ratio of train-val-test is 80%-10%-10%
-train_dataset, val_dataset, test_dataset = dataset.get_dataset(0.8, 0.1, 0.1)
-train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
-                          shuffle=True)
-val_loader = DataLoader(val_dataset, batch_size=args.batch_size)
-test_loader = DataLoader(test_dataset, batch_size=args.batch_size)
+train_loader, val_loader, test_loader = dataset.get_dataloader(
+    0.8, 0.1, 0.1, batch_size=args.batch_size
+)
 
 
 # Set up model and optimizer
-cat_dims = tuple(dataset.count_categorical_features().values())
-cont_nums = len(dataset.count_numerical_features())
+class FTTransformer(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        output_dim: int,
+        layers: int,
+        col_stats_dict: dict[ColType, list[dict[str,]]],
+    ):
+        super().__init__()
+        self.transform = FTTransformerTransform(
+            out_dim=hidden_dim,
+            col_stats_dict=col_stats_dict,
+        )
+        self.convs = FTTransformerConvs(dim=hidden_dim, layers=layers)
+        self.fc = self.decoder = Sequential(
+            LayerNorm(hidden_dim),
+            ReLU(),
+            Linear(hidden_dim, output_dim),
+        )
+
+    def forward(self, x) -> Tensor:
+        x, _ = self.transform(x)
+        x, x_cls = self.convs(x)
+        out = self.fc(x_cls)
+        return out
+
 
 model = FTTransformer(
     hidden_dim=args.dim,
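Note: this hunk rebuilds FTTransformer locally from two rLLM pieces. FTTransformerTransform embeds every column into hidden_dim, FTTransformerConvs returns per-column features plus a CLS-style summary vector (the x_cls in forward), and a small Sequential head turns that summary into class logits. A shape-level sketch of the head alone, in plain torch with a dummy tensor standing in for x_cls:

    import torch
    from torch.nn import LayerNorm, Linear, ReLU, Sequential

    batch, hidden_dim, output_dim = 4, 32, 2

    # FTTransformerConvs yields (x, x_cls); x_cls summarizes each row,
    # much like BERT's [CLS] token. Dummy values here:
    x_cls = torch.randn(batch, hidden_dim)

    # Same decoder head as in the diff: normalize, activate, project.
    fc = Sequential(
        LayerNorm(hidden_dim),
        ReLU(),
        Linear(hidden_dim, output_dim),
    )
    logits = fc(x_cls)  # shape [4, 2], consumed by F.cross_entropy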
@@ -52,19 +84,19 @@
     col_stats_dict=dataset.stats_dict,
 ).to(device)
 
-
-model = torch.compile(model, dynamic=True) if args.compile else model
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-# optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
-lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
+optimizer = torch.optim.Adam(
+    model.parameters(),
+    lr=args.lr,
+    weight_decay=args.wd,
+)
 
 
 def train(epoch: int) -> float:
     model.train()
     loss_accum = total_count = 0
-    for batch in tqdm(train_loader, desc=f'Epoch: {epoch}'):
-        feat_dict, y = batch
-        pred = model.forward(feat_dict)
+    for batch in tqdm(train_loader, desc=f"Epoch: {epoch}"):
+        x, y = batch
+        pred = model.forward(x)
         loss = F.cross_entropy(pred, y.long())
         optimizer.zero_grad()
         loss.backward()
@@ -80,8 +112,8 @@ def test(loader: DataLoader) -> float:
     all_preds = []
     all_labels = []
     for batch in loader:
-        feat_dict, y = batch
-        pred = model.forward(feat_dict)
+        x, y = batch
+        pred = model.forward(x)
         all_labels.append(y.cpu())
         all_preds.append(pred[:, 1].detach().cpu())
     all_labels = torch.cat(all_labels).numpy()
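Note: the test loop feeds roc_auc_score one score per sample, here the raw class-1 logit pred[:, 1]. AUC depends only on how the scores rank positives against negatives, so any strictly monotone rescaling of a given score leaves it unchanged. A minimal illustration with made-up data:

    import torch
    from sklearn.metrics import roc_auc_score

    labels = torch.tensor([0, 1, 1, 0])
    scores = torch.tensor([0.2, 0.9, 0.6, 0.4])  # higher = more positive

    # Rank-based: a sigmoid (strictly monotone) does not change the AUC.
    print(roc_auc_score(labels.numpy(), scores.numpy()))            # 1.0
    print(roc_auc_score(labels.numpy(), scores.sigmoid().numpy()))  # 1.0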
@@ -92,7 +124,7 @@ def test(loader: DataLoader) -> float:
     return overall_auc
 
 
-metric = 'AUC'
+metric = "AUC"
 best_val_metric = 0
 best_test_metric = 0
 for epoch in range(1, args.epochs + 1):
@@ -105,9 +137,13 @@ def test(loader: DataLoader) -> float:
         best_val_metric = val_metric
         best_test_metric = test_metric
 
-    print(f'Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, '
-          f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}')
-    lr_scheduler.step()
+    print(
+        f"Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, "
+        f"Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}"
+    )
+    optimizer.step()
 
-print(f'Best Val {metric}: {best_val_metric:.4f}, '
-      f'Best Test {metric}: {best_test_metric:.4f}')
+print(
+    f"Best Val {metric}: {best_val_metric:.4f}, "
+    f"Best Test {metric}: {best_test_metric:.4f}"
+)