
Commit

add tests
JianwuZheng413 committed Oct 30, 2024
1 parent a38b035 commit 95b7216
Showing 9 changed files with 135 additions and 89 deletions.
3 changes: 2 additions & 1 deletion examples/bridge/bridge_tacm12k.py
@@ -149,9 +149,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
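The same reordering closes all three bridge examples in this commit: the timing line now prints before the result line, so the last line each script writes to stdout is the Bridge result that the new tests in test/examples/test_bridge.py parse. A minimal sketch of the resulting output shape, with hypothetical accuracy values:

import time

start_time = time.time()
best_val_acc, best_test_acc = 0.5678, 0.5432  # hypothetical values

# Post-commit order: timing first, result line last on stdout.
print(f"Total Time: {time.time() - start_time:.4f}s")
print(
    "Bridge result: "
    f"Best Val acc: {best_val_acc:.4f}, "
    f"Best Test acc: {best_test_acc:.4f}"
)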
57 changes: 21 additions & 36 deletions examples/bridge/bridge_tlf2k.py
@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -43,16 +41,16 @@
 # We assume it a homogeneous graph,
 # so we need to reorder the user and artist id.
 ordered_ua = ua_table.df.assign(
-    artistID=ua_table.df['artistID']-1,
-    userID=ua_table.df['userID']+len(artist_table)-1,
+    artistID=ua_table.df["artistID"] - 1,
+    userID=ua_table.df["userID"] + len(artist_table) - 1,
 )
 
 # Making graph
 emb_size = 384  # Since user doesn't have an embedding, randomly select a dim.
 len_artist = len(artist_table)
-len_user = ua_table.df['userID'].max()
+len_user = ua_table.df["userID"].max()
 # Randomly initialize the embedding, artist embedding will be further trained
-x = torch.randn(len_artist+len_user, emb_size)
+x = torch.randn(len_artist + len_user, emb_size)
 graph = build_homo_graph(
     df=ordered_ua,
     n_src=len_artist,
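The reindexing above packs both node types into a single 0-based ID space, artists first and users offset by the artist count. A toy illustration with hypothetical numbers (three artists, two users):

import pandas as pd

# Toy version of the reordering in bridge_tlf2k.py: 1-based table IDs
# become node IDs in one shared space, artists before users.
ua = pd.DataFrame({"userID": [1, 2], "artistID": [1, 3]})
n_artists = 3  # hypothetical; the script uses len(artist_table)
ordered = ua.assign(
    artistID=ua["artistID"] - 1,          # artists -> 0..n_artists-1
    userID=ua["userID"] + n_artists - 1,  # users -> n_artists, n_artists+1, ...
)
print(ordered)  # artistID becomes [0, 2]; userID becomes [3, 4]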
@@ -66,7 +64,7 @@
 train_mask, val_mask, test_mask = (
     graph.artist_table.train_mask,
     graph.artist_table.val_mask,
-    graph.artist_table.test_mask
+    graph.artist_table.test_mask,
 )
 output_dim = graph.artist_table.num_classes
 
@@ -79,15 +77,9 @@ def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
-    )
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
@@ -97,11 +89,7 @@ def train_epoch() -> float:
 def test_epoch():
     model.eval()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
     preds = logits.argmax(dim=1)
     y = graph.y
@@ -122,11 +110,7 @@ def test_epoch():
 
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -138,9 +122,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
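With the IDs in one space, each row of the reordered table becomes an edge of a homogeneous graph. The build_homo_graph call is truncated in this diff, so here is only a rough sketch of the underlying idea in plain torch, not rllm's actual helper:

import torch

# Hypothetical reordered pairs in the shared node ID space.
src = torch.tensor([3, 4])  # userID column after reordering
dst = torch.tensor([0, 2])  # artistID column after reordering

# Keep both directions so the homogeneous graph is undirected.
edge_index = torch.cat([torch.stack([src, dst]), torch.stack([dst, src])], dim=1)
print(edge_index.shape)  # torch.Size([2, 4])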
57 changes: 21 additions & 36 deletions examples/bridge/bridge_tml1m.py
@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -42,8 +40,9 @@
 # We assume it a homogeneous graph,
 # so we need to reorder the user and movie id.
 ordered_rating = rating_table.df.assign(
-    UserID=rating_table.df['UserID']-1,
-    MovieID=rating_table.df['MovieID']+len(user_table)-1)
+    UserID=rating_table.df["UserID"] - 1,
+    MovieID=rating_table.df["MovieID"] + len(user_table) - 1,
+)
 
 # Making graph
 emb_size = movie_embeddings.size(1)
@@ -66,26 +65,20 @@
 train_mask, val_mask, test_mask = (
     graph.user_table.train_mask,
     graph.user_table.val_mask,
-    graph.user_table.test_mask
+    graph.user_table.test_mask,
 )
 output_dim = graph.user_table.num_classes
 
+
 def accuracy_score(preds, truth):
     return (preds == truth).sum(dim=0) / len(truth)
 
 
 def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
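The accuracy_score helper visible in context above is plain tensor arithmetic; a quick standalone check with made-up tensors:

import torch

def accuracy_score(preds, truth):
    # Fraction of positions where predictions equal the labels.
    return (preds == truth).sum(dim=0) / len(truth)

preds = torch.tensor([1, 0, 2, 2])
truth = torch.tensor([1, 0, 1, 2])
print(accuracy_score(preds, truth))  # tensor(0.7500)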
@@ -94,12 +87,7 @@ def train_epoch() -> float:
 @torch.no_grad()
 def test_epoch():
     model.eval()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
     preds = logits.argmax(dim=1)
     y = graph.y
     train_acc = accuracy_score(preds[train_mask], y[train_mask])
@@ -119,11 +107,7 @@ def test_epoch():
 
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -135,9 +119,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
4 changes: 2 additions & 2 deletions examples/ft_transformer.py
@@ -15,7 +15,7 @@
 from rllm.types import ColType
 from rllm.datasets.titanic import Titanic
 from rllm.transforms.table_transforms import FTTransformerTransform
-from rllm.nn.conv.table_conv import FTTransformerConvs
+from rllm.nn.conv.table_conv import FTTransformerConv
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -63,7 +63,7 @@ def __init__(
             out_dim=hidden_dim,
             col_stats_dict=col_stats_dict,
         )
-        self.convs = FTTransformerConvs(dim=hidden_dim, layers=layers)
+        self.convs = FTTransformerConv(dim=hidden_dim, layers=layers)
         self.fc = self.decoder = Sequential(
             LayerNorm(hidden_dim),
             ReLU(),
8 changes: 2 additions & 6 deletions rllm/nn/conv/table_conv/__init__.py
@@ -1,8 +1,4 @@
-from .ft_transformer_conv import FTTransformerConvs
-
+from .ft_transformer_conv import FTTransformerConv
 from .tab_transformer_conv import TabTransformerConv
 
-__all__ = [
-    'FTTransformerConvs',
-    'TabTransformerConv'
-]
+__all__ = ["FTTransformerConv", "TabTransformerConv"]
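After this rename, downstream code imports the class under its singular name; a minimal usage sketch (hidden_dim is hypothetical here; examples/ft_transformer.py passes its own value):

from rllm.nn.conv.table_conv import FTTransformerConv

hidden_dim, layers = 64, 3
convs = FTTransformerConv(dim=hidden_dim, layers=layers)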
11 changes: 6 additions & 5 deletions rllm/nn/conv/table_conv/ft_transformer_conv.py
@@ -12,7 +12,7 @@
 )
 
 
-class FTTransformerConvs(torch.nn.Module):
+class FTTransformerConv(torch.nn.Module):
     r"""The FT-Transformer backbone in the
     `"Revisiting Deep Learning Models for Tabular Data"
     <https://arxiv.org/abs/2106.11959>`_ paper.
@@ -33,6 +33,7 @@ class FTTransformerConvs(torch.nn.Module):
         dropout (int): The dropout value (default: 0.1)
         activation (str): The activation function (default: :obj:`relu`)
     """
+
     def __init__(
         self,
         dim: int,
@@ -41,7 +42,7 @@ def __init__(
         layers: int = 3,
         heads: int = 8,
         dropout: float = 0.2,
-        activation: str = 'relu',
+        activation: str = "relu",
     ):
         super().__init__()
 
@@ -56,9 +57,9 @@ def __init__(
             batch_first=True,
         )
         encoder_norm = LayerNorm(dim)
-        self.transformer = TransformerEncoder(encoder_layer=encoder_layer,
-                                              num_layers=layers,
-                                              norm=encoder_norm)
+        self.transformer = TransformerEncoder(
+            encoder_layer=encoder_layer, num_layers=layers, norm=encoder_norm
+        )
         self.cls_embedding = Parameter(torch.empty(dim))
         self.reset_parameters()
 
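For orientation, the pieces in this hunk assemble a standard PyTorch encoder stack. A sketch using the signature defaults shown above (dim is hypothetical since it has no default; the real module fills cls_embedding in reset_parameters and prepends it to the column tokens in forward):

import torch
from torch.nn import LayerNorm, Parameter, TransformerEncoder, TransformerEncoderLayer

dim, heads, layers, dropout = 64, 8, 3, 0.2  # dim hypothetical; rest are defaults
encoder_layer = TransformerEncoderLayer(
    d_model=dim, nhead=heads, dropout=dropout, activation="relu", batch_first=True
)
encoder_norm = LayerNorm(dim)
transformer = TransformerEncoder(
    encoder_layer=encoder_layer, num_layers=layers, norm=encoder_norm
)
cls_embedding = Parameter(torch.empty(dim))  # initialized later by reset_parameters()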
5 changes: 2 additions & 3 deletions rllm/transforms/table_transforms/tab_transformer_transform.py
@@ -4,6 +4,7 @@
 from rllm.transforms.table_transforms import ColTypeTransform, TableTypeTransform
 from rllm.nn.pre_encoder import EmbeddingEncoder, StackEncoder
 
+
 class TabTransformerTransform(TableTypeTransform):
     def __init__(
         self,
@@ -12,10 +13,8 @@ def __init__(
         col_types_transform_dict: dict[ColType, ColTypeTransform] = None,
     ) -> None:
         if col_types_transform_dict is None:
-            col_types_transform_dict={
+            col_types_transform_dict = {
                 ColType.CATEGORICAL: EmbeddingEncoder(),
                 ColType.NUMERICAL: StackEncoder(),
             }
         super().__init__(out_dim, col_stats_dict, col_types_transform_dict)
-
-
40 changes: 40 additions & 0 deletions test/examples/test_bridge.py
@@ -0,0 +1,40 @@
+import os
+import subprocess
+
+EXAMPLE_ROOT = os.path.join(
+    os.path.dirname(os.path.relpath(__file__)),
+    "..",
+    "..",
+    "examples",
+    "bridge",
+)
+
+
+def test_bridge_tml1m():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tml1m.py")
+    out = subprocess.run(["python", str(script)], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout[-9:]) > 0.42
+
+
+def test_gat():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tlf2k.py")
+    out = subprocess.run(["python", str(script)], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout[-9:]) > 0.49
+
+
+def test_han():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tacm12k.py")
+    out = subprocess.run(["python", str(script)], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout[-9:]) > 0.32
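Each test assumes the example script prints its Bridge result line last, so the tail of the decoded stdout ends with the best test accuracy; the print reordering in the three bridge examples above is what arranges this.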