From 95b7216b6931f5d53a6e2eaa9b2692d85b3bdc63 Mon Sep 17 00:00:00 2001
From: z3u5 <1592050303@qq.com>
Date: Wed, 30 Oct 2024 13:03:59 +0800
Subject: [PATCH] add tests

---
 examples/bridge/bridge_tacm12k.py             |  3 +-
 examples/bridge/bridge_tlf2k.py               | 57 +++++++------
 examples/bridge/bridge_tml1m.py               | 57 +++++++------
 examples/ft_transformer.py                    |  4 +-
 rllm/nn/conv/table_conv/__init__.py           |  8 +--
 .../nn/conv/table_conv/ft_transformer_conv.py | 11 ++--
 .../tab_transformer_transform.py              |  5 +-
 test/examples/test_bridge.py                  | 40 +++++++++++++
 test/examples/test_tnn.py                     | 39 +++++++++++++
 9 files changed, 135 insertions(+), 89 deletions(-)
 create mode 100644 test/examples/test_bridge.py
 create mode 100644 test/examples/test_tnn.py

diff --git a/examples/bridge/bridge_tacm12k.py b/examples/bridge/bridge_tacm12k.py
index 6b68478d..0a61d965 100644
--- a/examples/bridge/bridge_tacm12k.py
+++ b/examples/bridge/bridge_tacm12k.py
@@ -149,9 +149,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
diff --git a/examples/bridge/bridge_tlf2k.py b/examples/bridge/bridge_tlf2k.py
index 745e3b3f..ad17b772 100644
--- a/examples/bridge/bridge_tlf2k.py
+++ b/examples/bridge/bridge_tlf2k.py
@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -43,16 +41,16 @@
 # We assume it a homogeneous graph,
 # so we need to reorder the user and artist id.
 ordered_ua = ua_table.df.assign(
-    artistID=ua_table.df['artistID']-1,
-    userID=ua_table.df['userID']+len(artist_table)-1,
+    artistID=ua_table.df["artistID"] - 1,
+    userID=ua_table.df["userID"] + len(artist_table) - 1,
 )
 
 # Making graph
 emb_size = 384  # Since user doesn't have an embedding, randomly select a dim.
 len_artist = len(artist_table)
-len_user = ua_table.df['userID'].max()
+len_user = ua_table.df["userID"].max()
 # Randomly initialize the embedding, artist embedding will be further trained
-x = torch.randn(len_artist+len_user, emb_size)
+x = torch.randn(len_artist + len_user, emb_size)
 graph = build_homo_graph(
     df=ordered_ua,
     n_src=len_artist,
@@ -66,7 +64,7 @@
 train_mask, val_mask, test_mask = (
     graph.artist_table.train_mask,
     graph.artist_table.val_mask,
-    graph.artist_table.test_mask
+    graph.artist_table.test_mask,
 )
 
 output_dim = graph.artist_table.num_classes
@@ -79,15 +77,9 @@ def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
-    )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
@@ -97,11 +89,7 @@ def train_epoch() -> float:
 def test_epoch():
     model.eval()
     logits = model(
-        graph.artist_table,
-        graph.x,
-        graph.adj,
-        len_artist,
-        len_artist+len_user
+        graph.artist_table, graph.x, graph.adj, len_artist, len_artist + len_user
     )
     preds = logits.argmax(dim=1)
     y = graph.y
@@ -122,11 +110,7 @@ def test_epoch():
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -138,9 +122,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
diff --git a/examples/bridge/bridge_tml1m.py b/examples/bridge/bridge_tml1m.py
index 0330b22c..c563b000 100644
--- a/examples/bridge/bridge_tml1m.py
+++ b/examples/bridge/bridge_tml1m.py
@@ -8,8 +8,9 @@
 import argparse
 import os.path as osp
 import sys
-sys.path.append('../')
-sys.path.append('../../')
+
+sys.path.append("../")
+sys.path.append("../../")
 
 import torch
 import torch.nn.functional as F
@@ -21,16 +22,13 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--tab_dim", type=int, default=64,
-                    help="Tab Transformer categorical embedding dim")
-parser.add_argument("--gcn_dropout", type=float, default=0.5,
-                    help="Dropout for GCN")
-parser.add_argument("--epochs", type=int, default=200,
-                    help="Training epochs")
-parser.add_argument("--lr", type=float, default=0.001,
-                    help="Learning rate")
-parser.add_argument("--wd", type=float, default=1e-4,
-                    help="Weight decay")
+parser.add_argument(
+    "--tab_dim", type=int, default=64, help="Tab Transformer categorical embedding dim"
+)
+parser.add_argument("--gcn_dropout", type=float, default=0.5, help="Dropout for GCN")
+parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
+parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+parser.add_argument("--wd", type=float, default=1e-4, help="Weight decay")
 args = parser.parse_args()
 
 # Prepare datasets
@@ -42,8 +40,9 @@
 # We assume it a homogeneous graph,
 # so we need to reorder the user and movie id.
 ordered_rating = rating_table.df.assign(
-    UserID=rating_table.df['UserID']-1,
-    MovieID=rating_table.df['MovieID']+len(user_table)-1)
+    UserID=rating_table.df["UserID"] - 1,
+    MovieID=rating_table.df["MovieID"] + len(user_table) - 1,
+)
 
 # Making graph
 emb_size = movie_embeddings.size(1)
@@ -66,10 +65,11 @@
 train_mask, val_mask, test_mask = (
     graph.user_table.train_mask,
     graph.user_table.val_mask,
-    graph.user_table.test_mask
+    graph.user_table.test_mask,
 )
 
 output_dim = graph.user_table.num_classes
 
+
 def accuracy_score(preds, truth):
     return (preds == truth).sum(dim=0) / len(truth)
@@ -77,15 +77,8 @@ def accuracy_score(preds, truth):
 def train_epoch() -> float:
     model.train()
     optimizer.zero_grad()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
-    loss = F.cross_entropy(
-        logits[train_mask].squeeze(), graph.y[train_mask]
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
+    loss = F.cross_entropy(logits[train_mask].squeeze(), graph.y[train_mask])
     loss.backward()
     optimizer.step()
     return loss.item()
@@ -94,12 +87,7 @@ def train_epoch() -> float:
 @torch.no_grad()
 def test_epoch():
     model.eval()
-    logits = model(
-        graph.user_table,
-        graph.x, graph.adj,
-        len_user,
-        len_user+len_movie
-    )
+    logits = model(graph.user_table, graph.x, graph.adj, len_user, len_user + len_movie)
     preds = logits.argmax(dim=1)
     y = graph.y
     train_acc = accuracy_score(preds[train_mask], y[train_mask])
@@ -119,11 +107,7 @@ def test_epoch():
 start_time = time.time()
 best_val_acc = best_test_acc = 0
-optimizer = torch.optim.Adam(
-    model.parameters(),
-    lr=args.lr,
-    weight_decay=args.wd
-)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
 
 for epoch in range(1, args.epochs + 1):
     train_loss = train_epoch()
     train_acc, val_acc, test_acc = test_epoch()
@@ -135,9 +119,10 @@ def test_epoch():
     if val_acc > best_val_acc:
         best_val_acc = val_acc
         best_test_acc = test_acc
+
+print(f"Total Time: {time.time() - start_time:.4f}s")
 print(
     "Bridge result: "
     f"Best Val acc: {best_val_acc:.4f}, "
     f"Best Test acc: {best_test_acc:.4f}"
 )
-print(f"Total Time: {time.time() - start_time:.4f}s")
diff --git a/examples/ft_transformer.py b/examples/ft_transformer.py
index c9e579ca..eb415777 100644
--- a/examples/ft_transformer.py
+++ b/examples/ft_transformer.py
@@ -15,7 +15,7 @@
 from rllm.types import ColType
 from rllm.datasets.titanic import Titanic
 from rllm.transforms.table_transforms import FTTransformerTransform
-from rllm.nn.conv.table_conv import FTTransformerConvs
+from rllm.nn.conv.table_conv import FTTransformerConv
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -63,7 +63,7 @@ def __init__(
             out_dim=hidden_dim,
             col_stats_dict=col_stats_dict,
         )
-        self.convs = FTTransformerConvs(dim=hidden_dim, layers=layers)
+        self.convs = FTTransformerConv(dim=hidden_dim, layers=layers)
         self.fc = self.decoder = Sequential(
             LayerNorm(hidden_dim),
             ReLU(),
diff --git a/rllm/nn/conv/table_conv/__init__.py b/rllm/nn/conv/table_conv/__init__.py
index ffdc54b4..f1bff2e6 100644
--- a/rllm/nn/conv/table_conv/__init__.py
+++ b/rllm/nn/conv/table_conv/__init__.py
@@ -1,8 +1,4 @@
-from .ft_transformer_conv import FTTransformerConvs
-
+from .ft_transformer_conv import FTTransformerConv
 from .tab_transformer_conv import TabTransformerConv
 
-__all__ = [
-    'FTTransformerConvs',
-    'TabTransformerConv'
-]
+__all__ = ["FTTransformerConv", "TabTransformerConv"]
diff --git a/rllm/nn/conv/table_conv/ft_transformer_conv.py b/rllm/nn/conv/table_conv/ft_transformer_conv.py
index 70a8d5b0..3710c886 100644
--- a/rllm/nn/conv/table_conv/ft_transformer_conv.py
+++ b/rllm/nn/conv/table_conv/ft_transformer_conv.py
@@ -12,7 +12,7 @@
 )
 
 
-class FTTransformerConvs(torch.nn.Module):
+class FTTransformerConv(torch.nn.Module):
     r"""The FT-Transformer backbone in the
     `"Revisiting Deep Learning Models for Tabular Data"
     <https://arxiv.org/abs/2106.11959>`_ paper.
@@ -33,6 +33,7 @@ class FTTransformerConvs(torch.nn.Module):
         dropout (int): The dropout value (default: 0.1)
         activation (str): The activation function (default: :obj:`relu`)
     """
+
     def __init__(
         self,
         dim: int,
@@ -41,7 +42,7 @@ def __init__(
         layers: int = 3,
         heads: int = 8,
         dropout: float = 0.2,
-        activation: str = 'relu',
+        activation: str = "relu",
     ):
         super().__init__()
 
@@ -56,9 +57,9 @@ def __init__(
             batch_first=True,
         )
         encoder_norm = LayerNorm(dim)
-        self.transformer = TransformerEncoder(encoder_layer=encoder_layer,
-                                              num_layers=layers,
-                                              norm=encoder_norm)
+        self.transformer = TransformerEncoder(
+            encoder_layer=encoder_layer, num_layers=layers, norm=encoder_norm
+        )
         self.cls_embedding = Parameter(torch.empty(dim))
 
         self.reset_parameters()
diff --git a/rllm/transforms/table_transforms/tab_transformer_transform.py b/rllm/transforms/table_transforms/tab_transformer_transform.py
index e37ef995..3a3b691e 100644
--- a/rllm/transforms/table_transforms/tab_transformer_transform.py
+++ b/rllm/transforms/table_transforms/tab_transformer_transform.py
@@ -4,6 +4,7 @@
 
 from rllm.transforms.table_transforms import ColTypeTransform, TableTypeTransform
 from rllm.nn.pre_encoder import EmbeddingEncoder, StackEncoder
 
+
 class TabTransformerTransform(TableTypeTransform):
     def __init__(
@@ -12,10 +13,8 @@ def __init__(
         col_types_transform_dict: dict[ColType, ColTypeTransform] = None,
     ) -> None:
         if col_types_transform_dict is None:
-            col_types_transform_dict={
+            col_types_transform_dict = {
                 ColType.CATEGORICAL: EmbeddingEncoder(),
                 ColType.NUMERICAL: StackEncoder(),
             }
         super().__init__(out_dim, col_stats_dict, col_types_transform_dict)
-
-
diff --git a/test/examples/test_bridge.py b/test/examples/test_bridge.py
new file mode 100644
index 00000000..23bb13a4
--- /dev/null
+++ b/test/examples/test_bridge.py
@@ -0,0 +1,40 @@
+import os
+import subprocess
+
+EXAMPLE_ROOT = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)),
+    "..",
+    "..",
+    "examples",
+    "bridge",
+)
+
+
+def test_bridge_tml1m():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tml1m.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.42  # best test acc is printed last
+
+
+def test_bridge_tlf2k():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tlf2k.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.49  # best test acc is printed last
+
+
+def test_bridge_tacm12k():
+    script = os.path.join(EXAMPLE_ROOT, "bridge_tacm12k.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.32  # best test acc is printed last
diff --git a/test/examples/test_tnn.py b/test/examples/test_tnn.py
new file mode 100644
index 00000000..1a16f12e
--- /dev/null
+++ b/test/examples/test_tnn.py
@@ -0,0 +1,39 @@
+import os
+import subprocess
+
+EXAMPLE_ROOT = os.path.join(
+    os.path.dirname(os.path.abspath(__file__)),
+    "..",
+    "..",
+    "examples",
+)
+
+
+def test_ft_transformer():
+    script = os.path.join(EXAMPLE_ROOT, "ft_transformer.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.80  # assumes the accuracy ends the output
+
+
+def test_tab_transformer():
+    script = os.path.join(EXAMPLE_ROOT, "tab_transformer.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.80  # assumes the accuracy ends the output
+
+
+def test_tabnet():
+    script = os.path.join(EXAMPLE_ROOT, "tabnet.py")
+    out = subprocess.run(["python", script], capture_output=True)
+    assert (
+        out.returncode == 0
+    ), f"stdout: {out.stdout.decode('utf-8')}\nstderr: {out.stderr.decode('utf-8')}"
+    stdout = out.stdout.decode("utf-8")
+    assert float(stdout.split(":")[-1]) > 0.80  # assumes the accuracy ends the output
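
Note on running the new suites: both files follow pytest naming conventions (test_*.py modules containing test_* functions), so plain test discovery picks them up. The sketch below is not part of the patch; it assumes pytest is installed in the project environment and is run from the repository root. The -q flag and explicit paths are illustrative only.

    import pytest

    # Collect and run the two new example-level suites; pytest.main returns
    # a standard exit code (0 on success), which we propagate to the shell.
    raise SystemExit(
        pytest.main(["-q", "test/examples/test_bridge.py", "test/examples/test_tnn.py"])
    )

Equivalently, `pytest test/examples` from the repository root collects both files; each test shells out to its example script, so the relevant datasets must already be available locally.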