Merge pull request #135 from rllm-team/develop
split pre_encoders and add annotations for HGTConv
JianwuZheng413 authored Nov 26, 2024
2 parents 60ca8a8 + 633c13e commit 5835463
Showing 8 changed files with 210 additions and 183 deletions.
8 changes: 4 additions & 4 deletions examples/bridge/utils.py
@@ -132,8 +132,8 @@ def __init__(
     def forward(self, table):
         feat_dict = table.get_feat_dict()  # A dict contains feature tensor.
         x = self.table_transform(feat_dict)
-        for table_conv in self.convs:
-            x = table_conv(x)
+        for conv in self.convs:
+            x = conv(x)
         x = x.mean(dim=1)
         return x

@@ -170,9 +170,9 @@ def __init__(

     def forward(self, x, adj):
         adj = self.graph_transform(adj)
-        for graph_conv in self.convs[:-1]:
+        for conv in self.convs[:-1]:
             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = F.relu(graph_conv(x, adj))
+            x = F.relu(conv(x, adj))
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.convs[-1](x, adj)
         return x
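
The rename above (table_conv / graph_conv to conv) leaves the control flow untouched: both encoders simply iterate over a stack of convolution layers. For readers skimming the diff, here is a minimal, self-contained sketch of the graph-side pattern; DenseGraphConv is a hypothetical stand-in for rllm's actual convolution classes, and only the forward loop mirrors the hunk above.

import torch
import torch.nn.functional as F
from torch import nn


class DenseGraphConv(nn.Module):
    """Hypothetical stand-in for rllm's graph convolution layers."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)

    def forward(self, x, adj):
        # Aggregate neighbors via a dense adjacency, then project.
        return self.lin(adj @ x)


class GraphEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=2, dropout=0.5):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        self.convs = nn.ModuleList(
            DenseGraphConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = dropout

    def forward(self, x, adj):
        # Same control flow as the diff: dropout + ReLU on all but the last layer.
        for conv in self.convs[:-1]:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = F.relu(conv(x, adj))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj)
        return x


x = torch.randn(5, 8)   # 5 nodes, 8 input features
adj = torch.eye(5)      # trivial adjacency, just for the demo
print(GraphEncoder(8, 16, 4)(x, adj).shape)   # torch.Size([5, 4])
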
4 changes: 0 additions & 4 deletions rllm/nn/__init__.py

This file was deleted.

19 changes: 19 additions & 0 deletions rllm/nn/conv/graph_conv/hgt_conv.py
@@ -76,6 +76,25 @@ def __init__(
group: str = "sum",
dropout_rate: float = 0.0,
):
r"""The Heterogeneous Graph Transformer (HGT) layer,
as introduced in the `"Heterogeneous Graph Transformer"
<https://arxiv.org/abs/2003.01332>`__ paper.

Args:
in_dim (Union[int, Dict[str, int]]): Size of each input sample of every node type.
out_dim (int): Size of each output sample of every node type.
metadata (Tuple[List[str], List[Tuple[str, str]]]): The metadata of the heterogeneous
graph, *i.e.* its node and edge types given by a list of strings and a list of
string triplets, respectively.
heads (int, optional): Number of multi-head attentions. Defaults to 1.
group (str, optional): Aggregation method, either 'sum', 'mean', or 'max'. Defaults to
'sum'.
dropout_rate (float, optional): Dropout probability of the normalized attention
coefficients which exposes each node to a stochastically sampled neighborhood during
training. Defaults to 0.0.
"""

super().__init__()

if not isinstance(in_dim, dict):
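
For orientation, constructing the layer with the arguments documented above might look like the sketch below. The import path is an assumption inferred from the file location, and the edge-type tuples simply follow the annotated Tuple[str, str] type; treat this as illustrative, not the library's confirmed API.

from rllm.nn.conv.graph_conv import HGTConv  # assumed import path

node_types = ["user", "item"]
edge_types = [("user", "item"), ("item", "user")]  # shape per the annotated type

conv = HGTConv(
    in_dim={"user": 16, "item": 32},   # per-node-type input sizes
    out_dim=64,
    metadata=(node_types, edge_types),
    heads=4,
    group="sum",                       # 'sum', 'mean', or 'max'
    dropout_rate=0.1,
)
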
8 changes: 3 additions & 5 deletions rllm/nn/pre_encoder/__init__.py
@@ -1,8 +1,6 @@
-from .coltype_encoders import (
-    EmbeddingEncoder,
-    LinearEncoder,
-    StackEncoder,
-)
+from .embedding_encoder import EmbeddingEncoder
+from .linear_encoder import LinearEncoder
+from .stack_encoder import StackEncoder
 
 __all__ = [
     "EmbeddingEncoder",
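
Since all three encoders remain re-exported from the package __init__, package-level imports keep working after the split; only module-level imports need the new file names. A quick sketch of both styles, with paths following the files added in this commit:

# Package-level import: unchanged by the split.
from rllm.nn.pre_encoder import EmbeddingEncoder, LinearEncoder, StackEncoder

# Module-level imports: now follow the new per-encoder files.
from rllm.nn.pre_encoder.embedding_encoder import EmbeddingEncoder
from rllm.nn.pre_encoder.linear_encoder import LinearEncoder
from rllm.nn.pre_encoder.stack_encoder import StackEncoder
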
170 changes: 0 additions & 170 deletions rllm/nn/pre_encoder/coltype_encoders.py

This file was deleted.

71 changes: 71 additions & 0 deletions rllm/nn/pre_encoder/embedding_encoder.py
@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, Module

from rllm.types import ColType, NAMode, StatType
from rllm.transforms.table_transforms import ColTypeTransform


class EmbeddingEncoder(ColTypeTransform):
r"""An simple embedding look-up based Transform for categorical features.
It applies :class:`torch.nn.Embedding` for each categorical feature and
concatenates the output embeddings.
"""

supported_types = {ColType.CATEGORICAL}

def __init__(
self,
out_dim: int | None = None,
stats_list: List[Dict[StatType, Any]] | None = None,
col_type: ColType | None = ColType.CATEGORICAL,
post_module: Module | None = None,
na_mode: NAMode | None = None,
) -> None:
super().__init__(out_dim, stats_list, col_type, post_module, na_mode)

def post_init(self):
r"""This is the actual initialization function."""
num_categories_list = [0]
for stats in self.stats_list:
num_categories = stats[StatType.COUNT]
num_categories_list.append(num_categories)
# Single embedding module that stores embeddings of all categories
# across all categorical columns.
# 0-th category is for NaN.
self.emb = Embedding(
sum(num_categories_list) + 1,
self.out_dim,
padding_idx=0,
)
# [num_cols, ]
self.register_buffer(
"offset",
torch.cumsum(
torch.tensor(num_categories_list[:-1], dtype=torch.long), dim=0
),
)
self.reset_parameters()

def reset_parameters(self):
super().reset_parameters()
self.emb.reset_parameters()

def encode_forward(
self,
feat: Tensor,
col_names: List[str] | None = None,
) -> Tensor:
# feat: [batch_size, num_cols]
# Get NaN mask
na_mask = feat < 0
# Increment the index by one so it does not conflict with the padding idx
# Also add a per-column offset so categories from different columns do not collide
feat = feat + self.offset + 1
# Use 0th index for NaN
feat[na_mask] = 0
# [batch_size, num_cols, dim]
return self.emb(feat)
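
The offset buffer is what lets one shared torch.nn.Embedding serve every categorical column without index collisions: each column's raw category ids are shifted into that column's own slice of the table, and index 0 is reserved for missing values. A standalone sketch of the same index arithmetic, in plain torch with made-up category counts (3 and 4):

import torch
from torch.nn import Embedding

num_categories_list = [0, 3, 4]          # leading 0, then per-column counts
emb = Embedding(sum(num_categories_list) + 1, 8, padding_idx=0)  # 8 = out_dim
offset = torch.cumsum(torch.tensor(num_categories_list[:-1]), dim=0)  # tensor([0, 3])

feat = torch.tensor([[2, 0],
                     [-1, 3]])           # -1 marks a missing value
na_mask = feat < 0
idx = feat + offset + 1                  # column 0 maps to 1..3, column 1 to 4..7
idx[na_mask] = 0                         # missing entries hit the padding row
print(idx)                               # tensor([[3, 4], [0, 7]])
print(emb(idx).shape)                    # torch.Size([2, 2, 8])
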
65 changes: 65 additions & 0 deletions rllm/nn/pre_encoder/linear_encoder.py
@@ -0,0 +1,65 @@
from __future__ import annotations
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from rllm.types import ColType, NAMode, StatType
from rllm.transforms.table_transforms import ColTypeTransform


class LinearEncoder(ColTypeTransform):
r"""A linear function based Transform for numerical features. It applies
a linear layer :obj:`torch.nn.Linear(1, out_dim)` to each raw numerical
feature and concatenates the output embeddings. Note that the
implementation does this for all numerical features in a batched manner.
"""

supported_types = {ColType.NUMERICAL}

def __init__(
self,
out_dim: int | None = None,
stats_list: List[Dict[StatType, Any]] | None = None,
col_type: ColType | None = ColType.NUMERICAL,
post_module: Module | None = None,
na_mode: NAMode | None = None,
activate: Module | None = None,
):
super().__init__(out_dim, stats_list, col_type, post_module, na_mode)
self.activate = activate

def post_init(self):
r"""This is the actual initialization function."""
mean = torch.tensor([stats[StatType.MEAN] for stats in self.stats_list])
self.register_buffer("mean", mean)
std = torch.tensor([stats[StatType.STD] for stats in self.stats_list]) + 1e-6
self.register_buffer("std", std)
num_cols = len(self.stats_list)
self.weight = Parameter(torch.empty(num_cols, self.out_dim))
self.bias = Parameter(torch.empty(num_cols, self.out_dim))
self.reset_parameters()

def reset_parameters(self) -> None:
super().reset_parameters()
torch.nn.init.normal_(self.weight, std=0.01)
torch.nn.init.zeros_(self.bias)

def encode_forward(
self,
feat: Tensor,
col_names: List[str] | None = None,
) -> Tensor:
# feat: [batch_size, num_cols]
feat = (feat - self.mean) / self.std
# [batch_size, num_cols], [dim, num_cols]
# -> [batch_size, num_cols, dim]
x_lin = torch.einsum("ij,jk->ijk", feat, self.weight)
# [batch_size, num_cols, dim] + [num_cols, dim]
# -> [batch_size, num_cols, dim]
x = x_lin + self.bias

if self.activate is not None:
x = self.activate(x)
return x
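
The einsum("ij,jk->ijk", ...) call is a batched per-column linear map: each standardized scalar in column j scales that column's own out_dim-sized weight row, and the column bias then broadcasts over the batch. A small standalone check of the shape bookkeeping, with arbitrary sizes:

import torch

batch_size, num_cols, out_dim = 4, 3, 8
feat = torch.randn(batch_size, num_cols)           # already standardized
weight = torch.randn(num_cols, out_dim)
bias = torch.zeros(num_cols, out_dim)

# Each value feat[i, j] scales its own weight row weight[j].
x_lin = torch.einsum("ij,jk->ijk", feat, weight)   # [batch, cols, dim]
x = x_lin + bias                                   # bias broadcasts over batch

# Equivalent unsqueeze-and-multiply formulation, as a sanity check.
assert torch.allclose(x_lin, feat.unsqueeze(-1) * weight.unsqueeze(0))
print(x.shape)                                     # torch.Size([4, 3, 8])
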