Skip to content

Commit

Permalink
Change MODEL_CONFIG to get_transform
Browse files Browse the repository at this point in the history
  • Loading branch information
JianwuZheng413 committed Dec 3, 2024
1 parent 224b600 commit 468669e
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 48 deletions.
5 changes: 3 additions & 2 deletions examples/excelformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
sys.path.append("./")
sys.path.append("../")
from rllm.types import ColType
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.datasets.titanic import Titanic
from rllm.nn.conv.table_conv import ExcelFormerConv

Expand Down Expand Up @@ -52,14 +52,15 @@ def __init__(
metadata: Dict[ColType, List[Dict[str, Any]]],
):
super().__init__()
# Resolve the default table transform registered for ExcelFormerConv
# (diff leftover removed: the old MODEL_CONFIG[...] lookup duplicated this call).
self.transform = get_transform(ExcelFormerConv)(
    out_dim=hidden_dim,
    metadata=metadata,
)

self.convs = torch.nn.ModuleList(
[ExcelFormerConv(dim=hidden_dim) for _ in range(num_layers)]
)

self.fc = torch.nn.Sequential(
torch.nn.LayerNorm(hidden_dim),
torch.nn.ReLU(),
Expand Down
4 changes: 2 additions & 2 deletions examples/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets import Titanic
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.nn.conv.table_conv import FTTransformerConv

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
metadata: Dict[ColType, List[Dict[str, Any]]],
):
super().__init__()
# Resolve the default table transform registered for FTTransformerConv
# (diff leftover removed: the old MODEL_CONFIG[...] lookup duplicated this call).
self.transform = get_transform(FTTransformerConv)(
    out_dim=hidden_dim,
    metadata=metadata,
)
Expand Down
25 changes: 6 additions & 19 deletions examples/gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ArXiv: https://arxiv.org/abs/1710.10903

# Datasets CiteSeer Cora PubMed
# Acc 0.717 0.823 0.778
# Acc 0.717 0.830 0.778
# Time 16.6s 8.4s 15.6s

import argparse
Expand All @@ -16,7 +16,7 @@

sys.path.append("./")
sys.path.append("../")
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.datasets.planetoid import PlanetoidDataset
from rllm.nn.conv.graph_conv import GATConv

Expand All @@ -30,21 +30,11 @@
parser.add_argument("--wd", type=float, default=5e-4, help="Weight decay")
parser.add_argument("--epochs", type=int, default=100, help="Training epochs")
parser.add_argument("--dropout", type=float, default=0.5, help="Graph Dropout")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

torch.manual_seed(args.seed)
path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")
# Build the dataset with the default GAT transform; the transform handles the
# adjacency preprocessing that was previously done by hand (self-loops +
# sparse conversion), so those manual steps are removed.
dataset = PlanetoidDataset(path, args.dataset, transform=get_transform(GATConv)())
data = dataset[0]


class GAT(torch.nn.Module):
Expand All @@ -58,13 +48,10 @@ def __init__(
):
super().__init__()
self.dropout = dropout
# Stale per-model transform lookup removed: the dataset applies the default
# transform, so the module only builds its two attention layers.
self.conv1 = GATConv(in_dim, hidden_dim, heads, concat=True)
self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1)

def forward(self, data):
data = self.transform(data)
x, adj = data.x, data.adj
def forward(self, x, adj):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.elu(self.conv1(x, adj))
x = F.dropout(x, p=self.dropout, training=self.training)
Expand All @@ -89,7 +76,7 @@ def forward(self, data):
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.adj)  # forward now takes (x, adj) directly
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
Expand All @@ -99,7 +86,7 @@ def train():
@torch.no_grad()
def test():
model.eval()
out = model(data.x, data.adj)  # forward now takes (x, adj) directly
pred = out.argmax(dim=1)

accs = []
Expand Down
14 changes: 5 additions & 9 deletions examples/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

sys.path.append("./")
sys.path.append("../")
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.datasets.planetoid import PlanetoidDataset
from rllm.nn.conv.graph_conv import GCNConv

Expand All @@ -33,23 +33,19 @@
args = parser.parse_args()

torch.manual_seed(args.seed)

path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")
# Attach the default GCN transform at dataset-construction time
# (diff leftover removed: the transform-less constructor call duplicated this).
dataset = PlanetoidDataset(path, args.dataset, transform=get_transform(GCNConv)())
data = dataset[0]


class GCN(torch.nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, dropout):
    """Two-layer GCN.

    Feature/adjacency preprocessing is applied by the dataset transform,
    so no per-model transform is stored here (stale MODEL_CONFIG lookup removed).
    """
    super().__init__()
    self.dropout = dropout
    self.conv1 = GCNConv(in_dim, hidden_dim)
    self.conv2 = GCNConv(hidden_dim, out_dim)

def forward(self, data):
data = self.transform(data)
x, adj = data.x, data.adj
def forward(self, x, adj):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.relu(self.conv1(x, adj))
x = F.dropout(x, p=self.dropout, training=self.training)
Expand All @@ -73,7 +69,7 @@ def forward(self, data):
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.adj)  # forward now takes (x, adj) directly
loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
Expand All @@ -83,7 +79,7 @@ def train():
@torch.no_grad()
def test():
model.eval()
out = model(data.x, data.adj)  # forward now takes (x, adj) directly
pred = out.argmax(dim=1)

accs = []
Expand Down
4 changes: 2 additions & 2 deletions examples/rect.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

sys.path.append("./")
sys.path.append("../")
from rllm.nn.models import RECT_L, MODEL_CONFIG
from rllm.nn.models import RECT_L, get_transform
from rllm.datasets.planetoid import PlanetoidDataset
from rllm.transforms.utils import RemoveTrainingClasses

Expand All @@ -37,7 +37,7 @@
parser.add_argument("--epochs", type=int, default=200, help="Training epochs")
args = parser.parse_args()

# Look up RECT_L's default transform via the registry helper
# (diff leftover removed: the old MODEL_CONFIG[...] line duplicated this).
transform = get_transform(RECT_L)()

path = osp.join(osp.dirname(osp.realpath(__file__)), "..", "data")
dataset = PlanetoidDataset(path, args.dataset, transform=transform, force_reload=True)
Expand Down
4 changes: 2 additions & 2 deletions examples/tab_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets import Titanic
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.nn.conv.table_conv import TabTransformerConv

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(
metadata: Dict[ColType, List[Dict[str, Any]]],
):
super().__init__()
# Resolve the default table transform registered for TabTransformerConv
# (diff leftover removed: the old MODEL_CONFIG[...] lookup duplicated this call).
self.transform = get_transform(TabTransformerConv)(
    out_dim=hidden_dim,
    metadata=metadata,
)
Expand Down
4 changes: 2 additions & 2 deletions examples/tabnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets import Adult
from rllm.nn.models import TabNet, MODEL_CONFIG
from rllm.nn.models import TabNet, get_transform

parser = argparse.ArgumentParser()
parser.add_argument("--dim", help="embedding dim", type=int, default=32)
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
metadata: Dict[ColType, List[Dict[str, Any]]],
):
super().__init__()
# Resolve the default table transform registered for TabNet
# (diff leftover removed: the old MODEL_CONFIG[...] lookup duplicated this call).
self.transform = get_transform(TabNet)(
    out_dim=hidden_dim,
    metadata=metadata,
)
Expand Down
4 changes: 2 additions & 2 deletions examples/trompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sys.path.append("../")
from rllm.types import ColType
from rllm.datasets import Adult
from rllm.nn.models import MODEL_CONFIG
from rllm.nn.models import get_transform
from rllm.nn.conv.table_conv import TromptConv

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(

self.transforms = torch.nn.ModuleList(
[
# One transform per layer, resolved from the registry helper
# (diff leftover removed: the old MODEL_CONFIG[...] lookup duplicated this call).
get_transform(TromptConv)(
    out_dim=hidden_dim,
    metadata=metadata,
)
Expand Down
4 changes: 2 additions & 2 deletions rllm/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# GLU_Block,
# GLU_Layer
)
from .model_config import MODEL_CONFIG
from .get_transform import get_transform

# Public API of rllm.nn.models. The stale "MODEL_CONFIG" entry is dropped:
# that name is no longer imported by this module (replaced by get_transform).
__all__ = [
    "RECT_L",
    "TabNet",
    "get_transform",
]
34 changes: 28 additions & 6 deletions rllm/nn/models/model_config.py → rllm/nn/models/get_transform.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any, Dict, Optional, Type

from rllm.nn.conv.graph_conv import GATConv
from rllm.nn.conv.graph_conv import GCNConv
from rllm.nn.conv.table_conv import ExcelFormerConv
from rllm.nn.conv.table_conv import TromptConv
from rllm.nn.conv.table_conv import TabTransformerConv
from rllm.nn.conv.table_conv import FTTransformerConv

from rllm.nn.models.tabnet import TabNet
from rllm.nn.models.rect import RECT_L
from rllm.nn.models import RECT_L
from rllm.nn.models import TabNet

from rllm.transforms.graph_transforms import GCNTransform
from rllm.transforms.graph_transforms import RECTTransform
Expand All @@ -15,15 +17,35 @@
from rllm.transforms.table_transforms import TabNetTransform


# Maps each GNN conv/model class to its default graph transform.
# (Old MODEL_CONFIG header and stray "# TNN models" diff leftovers removed.)
GNN_CONV_TO_TRANSFORM: Dict[Type[Any], Type[Any]] = {
    GCNConv: GCNTransform,
    GATConv: GCNTransform,  # GAT reuses the GCN normalization transform
    RECT_L: RECTTransform,
}

# Maps each TNN conv/model class to its default table transform.
# Duplicate `TabNet: TabNetTransform` entry removed: a repeated dict key is
# silently overwritten in Python, so only one entry is kept.
TNN_CONV_TO_TRANSFORM: Dict[Type[Any], Type[Any]] = {
    TabTransformerConv: TabTransformerTransform,
    FTTransformerConv: FTTransformerTransform,
    ExcelFormerConv: FTTransformerTransform,
    TromptConv: FTTransformerTransform,
    TabNet: TabNetTransform,
}


def get_transform(conv: Type[Any]) -> Optional[Type[Any]]:
    """Look up the default transform class registered for ``conv``.

    Args:
        conv (Type[Any]): The conv (or model) class to look up.

    Returns:
        Optional[Type[Any]]: The registered transform class, or ``None``
        when ``conv`` has no default transform.
    """
    graph_transform = GNN_CONV_TO_TRANSFORM.get(conv)
    if graph_transform is not None:
        return graph_transform
    # Fall back to the table-transform registry; .get returns None on a miss.
    return TNN_CONV_TO_TRANSFORM.get(conv)

0 comments on commit 468669e

Please sign in to comment.