add tests #102

Merged · 1 commit · Oct 30, 2024
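Summary (inferred from the diff below): this PR adds smoke tests for the three bridge examples under test/examples/test_bridge.py, moves each example's "Total Time" print ahead of the final "Bridge result" print so that the best test accuracy is the last token on stdout (which the new tests parse), renames FTTransformerConvs to FTTransformerConv in the library and the ft_transformer example, and reformats the touched files in Black style.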
3 changes: 2 additions & 1 deletion examples/bridge/bridge_tacm12k.py

@@ -149,9 +149,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
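Note: this print reorder (repeated in the two bridge examples below) is what lets the new tests in test/examples/test_bridge.py recover the accuracy from the end of the captured stdout. A minimal sketch of that parsing, with a made-up stdout value for illustration:

# Hypothetical captured output of a bridge example after this PR:
stdout = (
    "Total Time: 12.3456s\n"
    "Bridge result: Best Val acc: 0.4000, Best Test acc: 0.4321\n"
)
# The best test accuracy is the last whitespace-separated token.
best_test_acc = float(stdout.strip().split()[-1])
assert best_test_acc == 0.4321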
57 changes: 21 additions & 36 deletions examples/bridge/bridge_tlf2k.py

@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@


 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -43,16 +41,16 @@
 # We assume it is a homogeneous graph,
 # so we need to reorder the user and artist ids.
 ordered_ua = ua_table.df.assign(
-    artistID=ua_table.df['artistID']-1,
-    userID=ua_table.df['userID']+len(artist_table)-1,
+    artistID=ua_table.df["artistID"] - 1,
+    userID=ua_table.df["userID"] + len(artist_table) - 1,
 )
 
 # Making graph
 emb_size = 384  # Since users don't have embeddings, randomly select a dim.
 len_artist = len(artist_table)
-len_user = ua_table.df['userID'].max()
+len_user = ua_table.df["userID"].max()
 # Randomly initialize the embeddings; the artist embeddings will be further trained.
-x = torch.randn(len_artist+len_user, emb_size)
+x = torch.randn(len_artist + len_user, emb_size)
 graph = build_homo_graph(
     df=ordered_ua,
     n_src=len_artist,
@@ -66,7 +64,7 @@
 train_mask, val_mask, test_mask = (
     graph.artist_table.train_mask,
     graph.artist_table.val_mask,
-    graph.artist_table.test_mask
+    graph.artist_table.test_mask,
 )
 output_dim = graph.artist_table.num_classes
 
@@ -79,15 +77,9 @@ def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
-    )
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
@@ -97,11 +89,7 @@ def train_epoch() -> float:
 def test_epoch():
     model.eval()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
     preds = logits.argmax(dim=1)
     y = graph.y
@@ -122,11 +110,7 @@ def test_epoch():
 
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -138,9 +122,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
57 changes: 21 additions & 36 deletions examples/bridge/bridge_tml1m.py

@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@


 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -42,8 +40,9 @@
 # We assume it is a homogeneous graph,
 # so we need to reorder the user and movie ids.
 ordered_rating = rating_table.df.assign(
-    UserID=rating_table.df['UserID']-1,
-    MovieID=rating_table.df['MovieID']+len(user_table)-1)
+    UserID=rating_table.df["UserID"] - 1,
+    MovieID=rating_table.df["MovieID"] + len(user_table) - 1,
+)
 
 # Making graph
 emb_size = movie_embeddings.size(1)
@@ -66,26 +65,20 @@
 train_mask, val_mask, test_mask = (
     graph.user_table.train_mask,
     graph.user_table.val_mask,
-    graph.user_table.test_mask
+    graph.user_table.test_mask,
 )
 output_dim = graph.user_table.num_classes
 
+
 def accuracy_score(preds, truth):
     return (preds == truth).sum(dim=0) / len(truth)
 
 
 def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
@@ -94,12 +87,7 @@ def train_epoch() -> float:
 @torch.no_grad()
 def test_epoch():
     model.eval()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
     preds = logits.argmax(dim=1)
     y = graph.y
     train_acc = accuracy_score(preds[train_mask], y[train_mask])
@@ -119,11 +107,7 @@ def test_epoch():
 
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -135,9 +119,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
4 changes: 2 additions & 2 deletions examples/ft_transformer.py

@@ -15,7 +15,7 @@
 from rllm.types import ColType
 from rllm.datasets.titanic import Titanic
 from rllm.transforms.table_transforms import FTTransformerTransform
-from rllm.nn.conv.table_conv import FTTransformerConvs
+from rllm.nn.conv.table_conv import FTTransformerConv
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -63,7 +63,7 @@ def __init__(
             out_dim=hidden_dim,
             col_stats_dict=col_stats_dict,
         )
-        self.convs = FTTransformerConvs(dim=hidden_dim, layers=layers)
+        self.convs = FTTransformerConv(dim=hidden_dim, layers=layers)
         self.fc = self.decoder = Sequential(
             LayerNorm(hidden_dim),
             ReLU(),
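For downstream code, the FTTransformerConvs → FTTransformerConv rename is a one-line migration. A sketch, assuming the constructor keeps the signature used in this example:

# Before this PR:
# from rllm.nn.conv.table_conv import FTTransformerConvs
# convs = FTTransformerConvs(dim=64, layers=3)

# After this PR (new name, same arguments):
from rllm.nn.conv.table_conv import FTTransformerConv

convs = FTTransformerConv(dim=64, layers=3)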
8 changes: 2 additions & 6 deletions rllm/nn/conv/table_conv/__init__.py

@@ -1,8 +1,4 @@
-from .ft_transformer_conv import FTTransformerConvs
-
+from .ft_transformer_conv import FTTransformerConv
 from .tab_transformer_conv import TabTransformerConv
 
-__all__ = [
-    'FTTransformerConvs',
-    'TabTransformerConv'
-]
+__all__ = ["FTTransformerConv", "TabTransformerConv"]
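The rename is breaking for any external code that imports FTTransformerConvs. If backwards compatibility were a concern, a deprecation alias could be kept in this __init__.py; a hypothetical shim, not part of this PR:

from .ft_transformer_conv import FTTransformerConv
from .tab_transformer_conv import TabTransformerConv

# Hypothetical compatibility shim: keep the old name importable for one release.
FTTransformerConvs = FTTransformerConv

__all__ = ["FTTransformerConv", "FTTransformerConvs", "TabTransformerConv"]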
11 changes: 6 additions & 5 deletions rllm/nn/conv/table_conv/ft_transformer_conv.py

@@ -12,7 +12,7 @@
 )
 
 
-class FTTransformerConvs(torch.nn.Module):
+class FTTransformerConv(torch.nn.Module):
     r"""The FT-Transformer backbone in the
     `"Revisiting Deep Learning Models for Tabular Data"
     <https://arxiv.org/abs/2106.11959>`_ paper.
@@ -33,6 +33,7 @@ class FTTransformerConvs(torch.nn.Module):
         dropout (int): The dropout value (default: 0.1)
         activation (str): The activation function (default: :obj:`relu`)
     """
+
     def __init__(
         self,
         dim: int,
@@ -41,7 +42,7 @@ def __init__(
         layers: int = 3,
         heads: int = 8,
         dropout: float = 0.2,
-        activation: str = 'relu',
+        activation: str = "relu",
     ):
         super().__init__()
 
@@ -56,9 +57,9 @@ def __init__(
             batch_first=True,
         )
         encoder_norm = LayerNorm(dim)
-        self.transformer = TransformerEncoder(encoder_layer=encoder_layer,
-                                              num_layers=layers,
-                                              norm=encoder_norm)
+        self.transformer = TransformerEncoder(
+            encoder_layer=encoder_layer, num_layers=layers, norm=encoder_norm
+        )
         self.cls_embedding = Parameter(torch.empty(dim))
         self.reset_parameters()
 
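The cls_embedding parameter plus the TransformerEncoder implement the usual FT-Transformer readout: a learned [CLS] token is prepended to the per-column embeddings, and its output position summarizes the row. The class's forward is not shown in this diff, so the following is only a generic sketch of that pattern, not the actual implementation:

import torch

batch, num_cols, dim = 32, 10, 64
x = torch.randn(batch, num_cols, dim)  # per-column embeddings for a batch of rows

# Learned [CLS] embedding (the initialization scheme here is illustrative only).
cls_embedding = torch.nn.Parameter(torch.empty(dim))
torch.nn.init.normal_(cls_embedding, std=0.01)

encoder_layer = torch.nn.TransformerEncoderLayer(
    d_model=dim, nhead=8, dropout=0.2, batch_first=True
)
encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=3)

# Prepend [CLS] to every row, encode, and read the row summary at position 0.
cls = cls_embedding.expand(batch, 1, dim)   # (batch, 1, dim)
out = encoder(torch.cat([cls, x], dim=1))   # (batch, num_cols + 1, dim)
row_repr = out[:, 0]                        # (batch, dim)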
5 changes: 2 additions & 3 deletions rllm/transforms/table_transforms/tab_transformer_transform.py

@@ -4,6 +4,7 @@
 from rllm.transforms.table_transforms import ColTypeTransform, TableTypeTransform
 from rllm.nn.pre_encoder import EmbeddingEncoder, StackEncoder
 
+
 class TabTransformerTransform(TableTypeTransform):
     def __init__(
         self,
@@ -12,10 +13,8 @@ def __init__(
         col_types_transform_dict: dict[ColType, ColTypeTransform] = None,
     ) -> None:
         if col_types_transform_dict is None:
-            col_types_transform_dict={
+            col_types_transform_dict = {
                 ColType.CATEGORICAL: EmbeddingEncoder(),
                 ColType.NUMERICAL: StackEncoder(),
             }
         super().__init__(out_dim, col_stats_dict, col_types_transform_dict)
-
-
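The "if col_types_transform_dict is None" guard above is the standard defense against a shared mutable default argument. A self-contained illustration of the idiom, with generic names rather than the rllm API:

def build_config(overrides=None):
    # A mutable default (overrides={}) would be created once at definition
    # time and shared across calls; the None sentinel gives each call a
    # fresh dict instead.
    if overrides is None:
        overrides = {"categorical": "embedding", "numerical": "stack"}
    return overrides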
40 changes: 40 additions & 0 deletions test/examples/test_bridge.py

@@ -0,0 +1,40 @@
+import os
+import subprocess
+
+EXAMPLE_ROOT = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)),
+    "..",
+    "..",
+    "examples",
+    "bridge",
+)
+
+
+def test_bridge_tml1m():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tml1m.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    # The example ends by printing "... Best Test acc: X.XXXX", so the
+    # accuracy is the last whitespace-separated token of stdout.
+    assert float(stdout.strip().split()[-1]) > 0.42
+
+
+def test_bridge_tlf2k():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tlf2k.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.strip().split()[-1]) > 0.49
+
+
+def test_bridge_tacm12k():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tacm12k.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.strip().split()[-1]) > 0.32
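The three tests share one body apart from the script name and the accuracy threshold; a parametrized variant would remove the duplication. A sketch, assuming pytest is the test runner (invoked as: pytest test/examples/test_bridge.py):

import os
import subprocess

import pytest

EXAMPLE_ROOT = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "bridge"
)


@pytest.mark.parametrize(
    "script, min_acc",
    [
        ("bridge_tml1m.py", 0.42),
        ("bridge_tlf2k.py", 0.49),
        ("bridge_tacm12k.py", 0.32),
    ],
)
def test_bridge_example(script, min_acc):
    out = subprocess.run(
        ["python", os.path.join(EXAMPLE_ROOT, script)], capture_output=True
    )
    assert out.returncode == 0, out.stderr.decode("utf-8")
    # The accuracy is the last token of the final "Bridge result" line.
    assert float(out.stdout.decode("utf-8").strip().split()[-1]) > min_acc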