diff --git a/examples/bridge/bridge_tacm12k.py b/examples/bridge/bridge_tacm12k.py index 90d6ed98..e91c12ff 100644 --- a/examples/bridge/bridge_tacm12k.py +++ b/examples/bridge/bridge_tacm12k.py @@ -44,14 +44,13 @@ _, ) = dataset.data_list -# Making graph -paper_embeddings = paper_embeddings.to(device) adj = build_homo_adj( relation_df=citations_table.df, n_all=len(papers_table), ).to(device) target_table = papers_table.to(device) y = papers_table.y.long().to(device) +paper_embeddings = paper_embeddings.to(device) train_mask, val_mask, test_mask = ( @@ -59,7 +58,6 @@ papers_table.val_mask, papers_table.test_mask, ) -out_dim = papers_table.num_classes class Bridge(torch.nn.Module): @@ -79,6 +77,29 @@ def forward(self, table, non_table, adj): return node_feats[: len(table), :] +t_encoder = TableEncoder( + in_dim=paper_embeddings.size(1), + out_dim=paper_embeddings.size(1), + table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict), + table_conv=TabTransformerConv, +) +g_encoder = GraphEncoder( + in_dim=paper_embeddings.size(1), + out_dim=papers_table.num_classes, + graph_transform=GCNNorm(), + graph_conv=GCNConv, +) +model = Bridge( + table_encoder=t_encoder, + graph_encoder=g_encoder, +).to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.wd, +) + + def accuracy_score(preds, truth): return (preds == truth).sum(dim=0) / len(truth) @@ -112,32 +133,8 @@ def test_epoch(): return train_acc.item(), val_acc.item(), test_acc.item() -t_encoder = TableEncoder( - in_dim=paper_embeddings.size(1), - out_dim=paper_embeddings.size(1), - table_transorm=FTTransformerTransform(col_stats_dict=papers_table.stats_dict), - table_conv=TabTransformerConv, -) -g_encoder = GraphEncoder( - in_dim=paper_embeddings.size(1), - out_dim=out_dim, - graph_transform=GCNNorm(), - graph_conv=GCNConv, -) -model = Bridge( - table_encoder=t_encoder, - graph_encoder=g_encoder, -).to(device) - start_time = time.time() best_val_acc = 
best_test_acc = 0 -optimizer = torch.optim.Adam( - [ - dict(params=model.table_encoder.parameters(), lr=0.001), - dict(params=model.graph_encoder.parameters(), lr=0.01, weight_decay=1e-4), - ] -) - for epoch in range(1, args.epochs + 1): train_loss = train_epoch() train_acc, val_acc, test_acc = test_epoch() diff --git a/examples/bridge/bridge_tlf2k.py b/examples/bridge/bridge_tlf2k.py index 90bab716..5bc8d672 100644 --- a/examples/bridge/bridge_tlf2k.py +++ b/examples/bridge/bridge_tlf2k.py @@ -59,7 +59,6 @@ artist_table.val_mask, artist_table.test_mask, ) -out_dim = artist_table.num_classes class Bridge(torch.nn.Module): @@ -79,6 +78,29 @@ def forward(self, table, non_table, adj): return node_feats[: len(table), :] +t_encoder = TableEncoder( + in_dim=emb_size, + out_dim=emb_size, + table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict), + table_conv=TabTransformerConv, +) +g_encoder = GraphEncoder( + in_dim=emb_size, + out_dim=artist_table.num_classes, + graph_transform=GCNNorm(), + graph_conv=GCNConv, +) +model = Bridge( + table_encoder=t_encoder, + graph_encoder=g_encoder, +).to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.wd, +) + + def accuracy_score(preds, truth): return (preds == truth).sum(dim=0) / len(truth) @@ -112,27 +134,8 @@ def test_epoch(): return train_acc.item(), val_acc.item(), test_acc.item() -t_encoder = TableEncoder( - in_dim=emb_size, - out_dim=emb_size, - table_transorm=FTTransformerTransform(col_stats_dict=artist_table.stats_dict), - table_conv=TabTransformerConv, -) -g_encoder = GraphEncoder( - in_dim=emb_size, - out_dim=out_dim, - graph_transform=GCNNorm(), - graph_conv=GCNConv, -) -model = Bridge( - table_encoder=t_encoder, - graph_encoder=g_encoder, -).to(device) - - start_time = time.time() best_val_acc = best_test_acc = 0 -optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) for epoch in range(1, args.epochs + 1): train_loss = 
train_epoch() train_acc, val_acc, test_acc = test_epoch() diff --git a/examples/bridge/bridge_tml1m.py b/examples/bridge/bridge_tml1m.py index 2534eac2..fc85967b 100644 --- a/examples/bridge/bridge_tml1m.py +++ b/examples/bridge/bridge_tml1m.py @@ -63,7 +63,6 @@ user_table.val_mask, user_table.test_mask, ) -out_dim = user_table.num_classes class Bridge(torch.nn.Module): @@ -83,6 +82,29 @@ def forward(self, table, non_table, adj): return node_feats[: len(table), :] +t_encoder = TableEncoder( + in_dim=emb_size, + out_dim=emb_size, + table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict), + table_conv=TabTransformerConv, +) +g_encoder = GraphEncoder( + in_dim=emb_size, + out_dim=user_table.num_classes, + graph_transform=GCNNorm(), + graph_conv=GCNConv, +) +model = Bridge( + table_encoder=t_encoder, + graph_encoder=g_encoder, +).to(device) +optimizer = torch.optim.Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.wd, +) + + def accuracy_score(preds, truth): return (preds == truth).sum(dim=0) / len(truth) @@ -116,27 +138,8 @@ def test_epoch(): return train_acc.item(), val_acc.item(), test_acc.item() -t_encoder = TableEncoder( - in_dim=emb_size, - out_dim=emb_size, - table_transorm=FTTransformerTransform(col_stats_dict=user_table.stats_dict), - table_conv=TabTransformerConv, -) -g_encoder = GraphEncoder( - in_dim=emb_size, - out_dim=out_dim, - graph_transform=GCNNorm(), - graph_conv=GCNConv, -) -model = Bridge( - table_encoder=t_encoder, - graph_encoder=g_encoder, -).to(device) - - start_time = time.time() best_val_acc = best_test_acc = 0 -optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) for epoch in range(1, args.epochs + 1): train_loss = train_epoch() train_acc, val_acc, test_acc = test_epoch() diff --git a/examples/ft_transformer.py b/examples/ft_transformer.py index bc77ae39..dd9a0dc7 100644 --- a/examples/ft_transformer.py +++ b/examples/ft_transformer.py @@ -21,7 +21,7 @@ sys.path.append("./") 
sys.path.append("../") from rllm.types import ColType -from rllm.datasets.titanic import Titanic +from rllm.datasets import Titanic from rllm.transforms.table_transforms import FTTransformerTransform from rllm.nn.conv.table_conv import FTTransformerConv diff --git a/examples/rect.py b/examples/rect.py index cc428821..a5a49ef9 100644 --- a/examples/rect.py +++ b/examples/rect.py @@ -11,7 +11,7 @@ # Datasets Citeseer | Cora | Pubmed # Unseen Classes [1, 2, 5] [3, 4] | [1, 2, 3] [3, 4, 6] | [2] -# RECT-L 66.50 68.40 | 74.80 72.20 | 75.30 +# RECT-L 61.10 66.10 | 71.20 70.40 | 69.90 import argparse import copy @@ -41,7 +41,7 @@ [ UT.NormalizeFeatures("l2"), UT.SVDFeatureReduction(200), - GT.GDC(), + GT.GCNNorm(), ] ) diff --git a/examples/tab_transformer.py b/examples/tab_transformer.py index 13a11f86..82626792 100644 --- a/examples/tab_transformer.py +++ b/examples/tab_transformer.py @@ -20,7 +20,7 @@ sys.path.append("./") sys.path.append("../") from rllm.types import ColType -from rllm.datasets.titanic import Titanic +from rllm.datasets import Titanic from rllm.transforms.table_transforms import TabTransformerTransform from rllm.nn.conv.table_conv import TabTransformerConv diff --git a/examples/tabnet.py b/examples/tabnet.py index b8a192ce..43b06389 100644 --- a/examples/tabnet.py +++ b/examples/tabnet.py @@ -20,7 +20,7 @@ sys.path.append("./") sys.path.append("../") from rllm.types import ColType -from rllm.datasets.adult import Adult +from rllm.datasets import Adult from rllm.transforms.table_transforms import TabNetTransform from rllm.nn.models import TabNet diff --git a/examples/trompt.py b/examples/trompt.py index 2e0b2303..35f172e0 100644 --- a/examples/trompt.py +++ b/examples/trompt.py @@ -21,7 +21,7 @@ sys.path.append("./") sys.path.append("../") from rllm.types import ColType -from rllm.datasets.adult import Adult +from rllm.datasets import Adult from rllm.transforms.table_transforms import TromptTransform from rllm.nn.conv.table_conv import TromptConv @@ 
-43,7 +43,7 @@ dataset = Adult(cached_dir=path)[0] dataset.to(device) -# Split dataset, here the ratio of train-val-test is 80%-10%-10% +# Split dataset, here the train-val-test split sizes are 26048-6513-16281 train_loader, val_loader, test_loader = dataset.get_dataloader( 26048, 6513, 16281, batch_size=args.batch_size ) diff --git a/rllm/datasets/__init__.py b/rllm/datasets/__init__.py index 9cb320bf..cd24eb56 100644 --- a/rllm/datasets/__init__.py +++ b/rllm/datasets/__init__.py @@ -1,4 +1,5 @@ from .adult import Adult +from .churn_modelling import ChurnModelling from .dblp import DBLP from .imdb import IMDB from .planetoid import PlanetoidDataset @@ -11,6 +12,7 @@ __all__ = [ "Adult", + "ChurnModelling", "DBLP", "IMDB", "PlanetoidDataset", diff --git a/rllm/datasets/adult.py b/rllm/datasets/adult.py index 1533b2be..e611702c 100644 --- a/rllm/datasets/adult.py +++ b/rllm/datasets/adult.py @@ -26,18 +26,18 @@ class Adult(Dataset): fnlwgt: The number of people the census believes have this job. Education: The highest level of education achieved (Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, - 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.). + 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool). Education-Num: A numeric version of Education. Marital-Status: Marital status of the individual (Married-civ-spouse, - Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.). + Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse). Occupation: The kind of work individuals perform (Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, - Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.). + Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces). 
Relationship: Relationship to head-of-household (Wife, Own-child, Husband, - Not-in-family, Other-relative, Unmarried.). + Not-in-family, Other-relative, Unmarried). Race: Race of the individual (White, Asian-Pac-Islander, - Amer-Indian-Eskimo, Other, Black.). + Amer-Indian-Eskimo, Other, Black). Sex: Gender of the individual. Capital-Gain: Total capital gains. Capital-Loss: Total capital losses. @@ -53,8 +53,8 @@ class Adult(Dataset): .. parsed-literal:: Statics: - Name Passengers Features - Size 48842 14 + Name Individuals Features + Size 48842 14 """ diff --git a/rllm/datasets/churn_modelling.py b/rllm/datasets/churn_modelling.py new file mode 100644 index 00000000..7add9ef8 --- /dev/null +++ b/rllm/datasets/churn_modelling.py @@ -0,0 +1,119 @@ +import os +import os.path as osp +from typing import Optional + +import pandas as pd + +from rllm.types import ColType +from rllm.data.table_data import TableData +from rllm.data.dataset import Dataset +from rllm.utils.download import download_url + + +class ChurnModelling(Dataset): + r"""The Churn Modelling dataset is used to predict which customers are + likely to churn from the organization by analyzing various attributes and + applying machine learning and deep learning techniques. + + Customer churn refers to when a customer (player, subscriber, user, etc.) + ceases their relationship with a company. Online businesses typically treat + a customer as churned once a particular amount of time has elapsed since + the customer's last interaction with the site or service. + + Customer churn occurs when customers or subscribers stop doing business + with a company or service, also known as customer attrition. It is also + referred to as loss of clients or customers. Similar to predicting + employee turnover, we are going to predict customer churn using this + dataset. + + The dataset encompasses a variety of features pertaining to customers and + their interactions with the company. 
The primary objective is to predict + whether a customer will churn: + + RowNumber: Row number. + CustomerId: Unique identifier for the customer. + Surname: Surname of the customer. + CreditScore: Credit score of the customer. + Geography: Country of the customer (France, Spain, Germany). + Gender: Gender of the customer (Male, Female). + Age: Age of the customer. + Tenure: Number of years the customer has been with the company. + Balance: Account balance of the customer. + NumOfProducts: Number of products the customer has with the company. + HasCrCard: Does the customer have a credit card? (0 = No, 1 = Yes). + IsActiveMember: Is the customer an active member? (0 = No, 1 = Yes). + EstimatedSalary: Estimated salary of the customer. + Exited: Did the customer churn? (0 = No, 1 = Yes). + + Args: + cached_dir (str): Root directory where dataset should be saved. + forced_reload (bool): If set to `True`, this dataset will be + re-processed again. + + .. parsed-literal:: + + Statics: + Name Customers Features + Size 10000 14 + + """ + + url = "https://raw.githubusercontent.com/sharmaroshan/Churn-Modelling-Dataset/master/Churn_Modelling.csv" + + def __init__(self, cached_dir: str, forced_reload: Optional[bool] = False) -> None: + self.name = "churn" + root = os.path.join(cached_dir, self.name) + super().__init__(root, force_reload=forced_reload) + self.data_list = [TableData.load(self.processed_paths[0])] + + @property + def raw_filenames(self): + return ["churn.csv"] + + @property + def processed_filenames(self): + return ["data.pt"] + + def process(self): + r""" + process data and save to './cached_dir/{dataset}/processed/'. + """ + os.makedirs(self.processed_dir, exist_ok=True) + path = osp.join(self.raw_dir, self.raw_filenames[0]) + df = pd.read_csv(path, index_col=["RowNumber"]) + + # Note: the order of column in col_types must + # correspond to the order of column in files, + # except target column. 
+ col_types = { + "CreditScore": ColType.NUMERICAL, + "Geography": ColType.CATEGORICAL, + "Gender": ColType.CATEGORICAL, + "Age": ColType.NUMERICAL, + "Tenure": ColType.NUMERICAL, + "Balance": ColType.NUMERICAL, + "NumOfProducts": ColType.NUMERICAL, + "HasCrCard": ColType.NUMERICAL, + "IsActiveMember": ColType.CATEGORICAL, + "EstimatedSalary": ColType.NUMERICAL, + "Exited": ColType.CATEGORICAL, + } + data = TableData( + df=df, + col_types=col_types, + target_col="Exited", + ) + + data.save(self.processed_paths[0]) + + def download(self): + os.makedirs(self.raw_dir, exist_ok=True) + download_url(self.url, self.raw_dir, self.raw_filenames[0]) + + def __len__(self): + return 1 + + def __getitem__(self, index: int): + if index != 0: + raise IndexError + return self.data_list[index] diff --git a/rllm/nn/models/rect.py b/rllm/nn/models/rect.py index b24b8780..632eac62 100644 --- a/rllm/nn/models/rect.py +++ b/rllm/nn/models/rect.py @@ -30,6 +30,7 @@ def __init__( self.in_dim = in_dim self.hidden_dim = hidden_dim self.dropout = dropout + self.prelu = torch.nn.PReLU() self.conv = GCNConv(in_dim, hidden_dim) self.lin = Linear(hidden_dim, in_dim) self.reset_parameters() @@ -41,7 +42,7 @@ def reset_parameters(self): torch.nn.init.xavier_uniform_(self.lin.weight.data) def forward(self, x: Tensor, adj: Tensor): - x = self.conv(x, adj) + x = self.prelu(self.conv(x, adj)) x = F.dropout(x, p=self.dropout, training=self.training) return self.lin(x)