diff --git a/docs/source/tutorial/gnns.rst b/docs/source/tutorial/gnns.rst
index f92fa939..87508265 100644
--- a/docs/source/tutorial/gnns.rst
+++ b/docs/source/tutorial/gnns.rst
@@ -74,18 +74,15 @@ Finally, we need to implement a :obj:`train()` function and a :obj:`test()` func

 .. code-block:: python

-    def train():
+    for epoch in range(200):
         model.train()
         optimizer.zero_grad()
         out = model(data.x, data.adj)
         loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
         loss.backward()
         optimizer.step()
-        return loss.item()
-
-
-    @torch.no_grad()
-    def test():
+
+    with torch.no_grad():
         model.eval()
         out = model(data.x, data.adj)
         pred = out.argmax(dim=1)
@@ -94,29 +91,6 @@ Finally, we need to implement a :obj:`train()` function and a :obj:`test()` func
         for mask in [data.train_mask, data.val_mask, data.test_mask]:
             correct = float(pred[mask].eq(data.y[mask]).sum().item())
             accs.append(correct / int(mask.sum()))
-        return accs
-
-
-    metric = "Acc"
-    best_val_acc = best_test_acc = 0
-    times = []
-    for epoch in range(1, args.epochs + 1):
-        start = time.time()
-
-        train_loss = train()
-        train_acc, val_acc, test_acc = test()
-
-        if val_acc > best_val_acc:
-            best_val_acc = val_acc
-            best_test_acc = test_acc
-
-        times.append(time.time() - start)
-        print(
-            f"Epoch: [{epoch}/{args.epochs}] "
-            f"Train Loss: {train_loss:.4f} Train {metric}: {train_acc:.4f} "
-            f"Val {metric}: {val_acc:.4f}, Test {metric}: {test_acc:.4f} "
-        )
-    print(f"Mean time per epoch: {torch.tensor(times).mean():.4f}s")
-    print(f"Total time: {sum(times):.4f}s")
-    print(f"Best test acc: {best_test_acc:.4f}")
\ No newline at end of file
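+        # accs holds the accuracy on the train, val and test splits, in that order.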
+    print(f"Accuracy: {accs[-1]:.4f}")
+    >>> 0.8150
\ No newline at end of file
diff --git a/docs/source/tutorial/rtls.rst b/docs/source/tutorial/rtls.rst
index 35514b6f..53d498dc 100644
--- a/docs/source/tutorial/rtls.rst
+++ b/docs/source/tutorial/rtls.rst
@@ -1,9 +1,9 @@
 Design of RTLs
 ==============
-What is a RTL?
+What is an RTL?
 ----------------
-In machine learning, **Relational Table Learnings (RTLs)** typically refers to the learning of relational table data, which consists of multiple interconnected tables with significant heterogeneity. In an RTL, the input comprises multiple table signals that are interrelated. A typical RTL architecture consists of one or more Transforms followed by multiple Convolution layers, as detailed in *Understanding Transform* and *Understanding Convolution*.
+In machine learning, **Relational Table Learnings (RTLs)** typically refer to the learning of relational table data, which consists of multiple interconnected tables with significant heterogeneity. In an RTL, the input comprises multiple interrelated table signals. A typical RTL architecture consists of one or more Transforms followed by multiple Convolution layers, as detailed in *Understanding Transforms* and *Understanding Convolutions*.


 Construct a BRIDGE
@@ -41,7 +41,7 @@ For convenience, we will construct a basic homogeneous graph here, even though m

 .. code-block:: python

-    from utils import build_homo_graph, reorder_ids
+    from examples.bridge.utils import build_homo_graph, reorder_ids

     # Original movie id in datasets is unordered, so we reorder them.
     ordered_rating = reorder_ids(
@@ -106,31 +106,24 @@ After initializing the data, we instantiate the model.
 Since the task of the TML
         table_encoder=t_encoder,
         graph_encoder=g_encoder,
     ).to(device)
-    optimizer = torch.optim.Adam(
-        model.parameters(),
-        lr=args.lr,
-        weight_decay=args.wd,
-    )
+    optimizer = torch.optim.Adam(model.parameters())

-Finally, we need to implement a :obj:`train()` function and a :obj:`test()` function, the latter of which does not require gradient tracking. The model can then be trained on the training and validation sets, and the classification results can be obtained from the test set.
+Finally, we jointly train the model and evaluate it on the test set.

 .. code-block:: python

-    def train() -> float:
-        model.train()
-        optimizer.zero_grad()
-        logits = model(
-            table=user_table,
-            non_table=movie_embeddings,
-            adj=adj,
-        )
-        loss = F.cross_entropy(logits[train_mask].squeeze(), y[train_mask])
-        loss.backward()
-        optimizer.step()
-        return loss.item()
+    for epoch in range(50):
+        optimizer.zero_grad()
+        logits = model(
+            table=user_table,
+            non_table=movie_embeddings,
+            adj=adj,
+        )
+        loss = F.cross_entropy(logits[train_mask].squeeze(), y[train_mask])
+        loss.backward()
+        optimizer.step()

-    @torch.no_grad()
-    def test():
+    with torch.no_grad():
         model.eval()
         logits = model(
             table=user_table,
@@ -138,30 +131,6 @@ Finally, we need to implement a :obj:`train()` function and a :obj:`test()` func
             adj=adj,
         )
         preds = logits.argmax(dim=1)
-
-        accs = []
-        for mask in [train_mask, val_mask, test_mask]:
-            correct = float(preds[mask].eq(y[mask]).sum().item())
-            accs.append(correct / int(mask.sum()))
-        return accs
-
-    start_time = time.time()
-    best_val_acc = best_test_acc = 0
-    for epoch in range(1, args.epochs + 1):
-        train_loss = train()
-        train_acc, val_acc, test_acc = test()
-        print(
-            f"Epoch: [{epoch}/{args.epochs}]"
-            f"Loss: {train_loss:.4f} train_acc: {train_acc:.4f} "
-            f"val_acc: {val_acc:.4f} test_acc: {test_acc:.4f} "
-        )
-        if val_acc > best_val_acc:
-            best_val_acc = val_acc
-            best_test_acc = test_acc
-
-    print(f"Total Time: {time.time() - start_time:.4f}s")
-    print(
-        "BRIDGE result: "
-        f"Best Val acc: {best_val_acc:.4f}, "
-        f"Best Test acc: {best_test_acc:.4f}"
-    )
\ No newline at end of file
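+    # Fraction of correct predictions among the test-mask nodes only.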
+    acc = (preds[test_mask] == y[test_mask]).sum() / test_mask.sum()
+    print(f'Accuracy: {acc:.4f}')
+    >>> 0.3860
\ No newline at end of file
diff --git a/docs/source/tutorial/tnns.rst b/docs/source/tutorial/tnns.rst
index 583805d7..f68ef88c 100644
--- a/docs/source/tutorial/tnns.rst
+++ b/docs/source/tutorial/tnns.rst
@@ -1,8 +1,8 @@
 Design of TNNs
 ===============
 What is a TNN?
 ----------------
-In machine learning, **Table/Tabular Neural Networks (TNNs)** are recently emerging neural networks specifically designed to process tabular data. In a TNN, the input is structured tabular data, usually organized in rows and columns. A typical TNN architecture consists of an initial Transform followed by multiple Convolution layers, as detailed in *Understanding Transform* and *Understanding Convolution*.
+In machine learning, **Table/Tabular Neural Networks (TNNs)** are recently emerging neural networks specifically designed to process tabular data. In a TNN, the input is structured tabular data, usually organized in rows and columns. A typical TNN architecture consists of an initial Transform followed by multiple Convolution layers, as detailed in *Understanding Transforms* and *Understanding Convolutions*.


 Construct a TabTransformer
@@ -11,34 +11,28 @@ In this tutorial, we will learn the basic workflow of using `[TabTransformer]
-    def train(epoch) -> float:
-        model.train()
-        loss_accum = total_count = 0.0
-        for batch in tqdm(train_loader, desc=f"Epoch: {epoch}"):
+    for epoch in range(50):
+        for batch in train_loader:
             x, y = batch
-            pred = model.forward(x)
+            pred = model(x)
             loss = F.cross_entropy(pred, y.long())
             optimizer.zero_grad()
             loss.backward()
-            loss_accum += float(loss) * y.size(0)
-            total_count += y.size(0)
             optimizer.step()
-        return loss_accum / total_count
-
-
-    @torch.no_grad()
-    def test(loader: DataLoader) -> float:
+
+    with torch.no_grad():
         model.eval()
-        correct = total = 0
-        for batch in loader:
-            feat_dict, y = batch
-            pred = model.forward(feat_dict)
-            _, predicted = torch.max(pred, 1)
-            total += y.size(0)
-            correct += (predicted == y).sum().item()
-        accuracy = correct / total
-        return accuracy
-
-    metric = "Acc"
-    best_val_metric = best_test_metric = 0
-    times = []
-    for epoch in range(1, args.epochs + 1):
-        start = time.time()
-
-        train_loss = train(epoch)
-        train_metric = test(train_loader)
-        val_metric = test(val_loader)
-        test_metric = test(test_loader)
-
-        if val_metric > best_val_metric:
-            best_val_metric = val_metric
-            best_test_metric = test_metric
-
-        times.append(time.time() - start)
-        print(
-            f"Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, "
-            f"Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}"
-        )
-
-    print(f"Mean time per epoch: {torch.tensor(times).mean():.4f}s")
-    print(f"Total time: {sum(times):.4f}s")
-    print(
-        f"Best Val {metric}: {best_val_metric:.4f}, "
-        f"Best Test {metric}: {best_test_metric:.4f}"
-    )
+        correct = 0
+        for batch in test_loader:
+            x, y = batch
+            pred = model(x)
+            pred_class = pred.argmax(dim=-1)
+            correct += (y == pred_class).sum()
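+        # correct is accumulated as a 0-dim tensor, so cast it to int before dividing.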
+        acc = int(correct) / len(test_dataset)
+        print(f'Accuracy: {acc:.4f}')
+    >>> 0.8082
diff --git a/examples/bridge/bridge_tml1m.py b/examples/bridge/bridge_tml1m.py
index 178ecc4a..de97b7c0 100644
--- a/examples/bridge/bridge_tml1m.py
+++ b/examples/bridge/bridge_tml1m.py
@@ -14,6 +14,7 @@
 sys.path.append("./")
 sys.path.append("../")
+sys.path.append("../../")
 from rllm.datasets import TML1MDataset
 from rllm.transforms.graph_transforms import GCNTransform
 from rllm.transforms.table_transforms import TabTransformerTransform
diff --git a/rllm/data/__init__.py b/rllm/data/__init__.py
index 26f01797..5c085c97 100644
--- a/rllm/data/__init__.py
+++ b/rllm/data/__init__.py
@@ -1,4 +1,4 @@
-from .dataset import Dataset  # noqa
+from ..datasets.dataset import Dataset  # noqa
 from .graph_data import BaseGraph, GraphData, HeteroGraphData  # noqa
 from .table_data import BaseTable, TableData, TableDataset  # noqa
 from .storage import BaseStorage, NodeStorage, EdgeStorage, recursive_apply  # noqa
diff --git a/rllm/datasets/adult.py b/rllm/datasets/adult.py
index 1b69252b..570818b3 100644
--- a/rllm/datasets/adult.py
+++ b/rllm/datasets/adult.py
@@ -6,7 +6,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
diff --git a/rllm/datasets/bank_marketing.py b/rllm/datasets/bank_marketing.py
index 1b473875..0b5aa6d7 100644
--- a/rllm/datasets/bank_marketing.py
+++ b/rllm/datasets/bank_marketing.py
@@ -6,7 +6,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
diff --git a/rllm/datasets/churn_modelling.py b/rllm/datasets/churn_modelling.py
index 6ec85ad4..e64f1d10 100644
--- a/rllm/datasets/churn_modelling.py
+++ b/rllm/datasets/churn_modelling.py
@@ -6,7 +6,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
diff --git a/rllm/data/dataset.py b/rllm/datasets/dataset.py
similarity index 100%
rename from rllm/data/dataset.py
rename to rllm/datasets/dataset.py
diff --git a/rllm/datasets/dblp.py b/rllm/datasets/dblp.py
index d0c902d2..e4909805 100644
--- a/rllm/datasets/dblp.py
+++ b/rllm/datasets/dblp.py
@@ -7,7 +7,7 @@
 import scipy.sparse as sp
 import torch

-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.data.graph_data import HeteroGraphData
 from rllm.utils.graph_utils import sparse_mx_to_torch_sparse_tensor
 from rllm.utils.download import download_url
diff --git a/rllm/datasets/imdb.py b/rllm/datasets/imdb.py
index 56219c8c..e4996b55 100644
--- a/rllm/datasets/imdb.py
+++ b/rllm/datasets/imdb.py
@@ -9,7 +9,7 @@
 # import sys
 # sys.path.append('../')

-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.data.graph_data import HeteroGraphData
 from rllm.utils.sparse import sparse_mx_to_torch_sparse_tensor
 from rllm.utils.extract import extract_zip
diff --git a/rllm/datasets/planetoid.py b/rllm/datasets/planetoid.py
index 0f7da5f2..f0c9dfcb 100644
--- a/rllm/datasets/planetoid.py
+++ b/rllm/datasets/planetoid.py
@@ -13,7 +13,7 @@
 # import sys
 # sys.path.append('../')

-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.data.graph_data import GraphData
 from rllm.utils.sparse import sparse_mx_to_torch_sparse_tensor
 from rllm.datasets.utils import index2mask
diff --git a/rllm/datasets/sjtutables/tacm12k.py b/rllm/datasets/sjtutables/tacm12k.py
index d1fa5747..476d3d4e 100644
--- a/rllm/datasets/sjtutables/tacm12k.py
+++ b/rllm/datasets/sjtutables/tacm12k.py
@@ -8,7 +8,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
 from rllm.utils.extract import extract_zip
diff --git a/rllm/datasets/sjtutables/tlf2k.py b/rllm/datasets/sjtutables/tlf2k.py
index d940ee47..845740ac 100644
--- a/rllm/datasets/sjtutables/tlf2k.py
+++ b/rllm/datasets/sjtutables/tlf2k.py
@@ -7,7 +7,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
 from rllm.utils.extract import extract_zip
diff --git a/rllm/datasets/sjtutables/tml1m.py b/rllm/datasets/sjtutables/tml1m.py
index 5cf9ca2f..90a6874d 100644
--- a/rllm/datasets/sjtutables/tml1m.py
+++ b/rllm/datasets/sjtutables/tml1m.py
@@ -8,7 +8,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url
 from rllm.utils.extract import extract_zip
diff --git a/rllm/datasets/tagdataset.py b/rllm/datasets/tagdataset.py
index 786b7547..34614fa1 100644
--- a/rllm/datasets/tagdataset.py
+++ b/rllm/datasets/tagdataset.py
@@ -7,7 +7,7 @@
 import torch

-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.data.graph_data import GraphData
 from rllm.utils.download import download_url
 from rllm.data.storage import BaseStorage
diff --git a/rllm/datasets/tape.py b/rllm/datasets/tape.py
index 4f6c7487..56289cdd 100644
--- a/rllm/datasets/tape.py
+++ b/rllm/datasets/tape.py
@@ -14,7 +14,7 @@
 import torch

-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.data.graph_data import GraphData
 from rllm.utils.sparse import sparse_mx_to_torch_sparse_tensor
 from rllm.utils.extract import extract_zip
diff --git a/rllm/datasets/titanic.py b/rllm/datasets/titanic.py
index f5961d3a..3db7a63b 100644
--- a/rllm/datasets/titanic.py
+++ b/rllm/datasets/titanic.py
@@ -6,7 +6,7 @@
 from rllm.types import ColType
 from rllm.data.table_data import TableData
-from rllm.data.dataset import Dataset
+from rllm.datasets.dataset import Dataset
 from rllm.utils.download import download_url