restructure regression module
thielmaf committed May 29, 2024
1 parent 4e530a1 commit 1098f94
Showing 7 changed files with 441 additions and 224 deletions.
40 changes: 18 additions & 22 deletions mambular/base_models/embedding_regressor.py
@@ -4,6 +4,7 @@

from ..utils.config import MambularConfig
from ..utils.mamba_arch import Mamba
from ..utils.mlp_utils import MLP


class BaseEmbeddingMambularRegressor(pl.LightningModule):
@@ -57,6 +58,12 @@ def __init__(
lr_factor=0.75,
seq_size: int = 20,
raw_embeddings=False,
head_layer_sizes=[64, 32, 32],
head_dropout: float = 0.3,
head_skip_layers: bool = False,
head_activation="leakyrelu",
head_use_batch_norm: bool = False,
attn_dropout: float = 0.3,
):
super().__init__()

@@ -97,8 +104,7 @@ def __init__(
self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(self.seq_size, self.config.d_model, bias=False),
# Example using ReLU as the activation function, change as needed
self.embedding_activation,
)
@@ -128,26 +134,17 @@

self.mamba = Mamba(self.config)
self.norm_f = self.config.norm(self.config.d_model)
mlp_activation_fn = activations.get(
self.config.tabular_head_activation.lower(), nn.Identity()
)

# Dynamically create MLP layers based on config.tabular_units
mlp_layers = []
input_dim = self.config.d_model # Initial input dimension

# Iterate over the specified units for each layer in the MLP
for units in self.config.tabular_head_units:
mlp_layers.append(nn.Linear(input_dim, units))
mlp_layers.append(mlp_activation_fn)
mlp_layers.append(nn.Dropout(self.config.tabular_head_dropout))
input_dim = units

# Add the final linear layer to map to a single output value
mlp_layers.append(nn.Linear(input_dim, 1))
head_activation = activations.get(head_activation.lower(), nn.Identity())

# Combine all layers into a Sequential module
self.tabular_head = nn.Sequential(*mlp_layers)
self.tabular_head = MLP(
self.config.d_model,
hidden_units_list=head_layer_sizes,
dropout_rate=head_dropout,
use_skip_layers=head_skip_layers,
activation_fn=head_activation,
use_batch_norm=head_use_batch_norm,
)

self.pooling_method = self.config.pooling_method
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.d_model))
@@ -176,8 +173,7 @@ def forward(self, cat_features, num_features):
The output predictions of the model for regression tasks.
"""
batch_size = (
cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# Process categorical features if present
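Both regressors now delegate their prediction head to the shared `MLP` from `mambular/utils/mlp_utils.py` instead of assembling an `nn.Sequential` by hand. The sketch below only illustrates what such a head utility could look like, assuming the constructor arguments used in this diff (`hidden_units_list`, `dropout_rate`, `use_skip_layers`, `activation_fn`, `use_batch_norm`); the actual implementation in the repository may differ.

```python
import torch.nn as nn


class MLP(nn.Module):
    """Sketch of a configurable regression head; the repository's mlp_utils.MLP may differ."""

    def __init__(
        self,
        input_dim,
        hidden_units_list=(64, 32, 32),
        dropout_rate=0.3,
        use_skip_layers=False,
        activation_fn=nn.LeakyReLU(),
        use_batch_norm=False,
        output_dim=1,
    ):
        super().__init__()
        blocks = []
        in_dim = input_dim
        for units in hidden_units_list:
            layers = [nn.Linear(in_dim, units)]
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(units))
            layers.extend([activation_fn, nn.Dropout(dropout_rate)])
            blocks.append(nn.Sequential(*layers))
            in_dim = units
        self.blocks = nn.ModuleList(blocks)
        self.use_skip_layers = use_skip_layers
        self.out = nn.Linear(in_dim, output_dim)

    def forward(self, x):
        for block in self.blocks:
            h = block(x)
            # Residual connection only where input and output widths match
            x = x + h if self.use_skip_layers and h.shape == x.shape else h
        return self.out(x)
```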
194 changes: 102 additions & 92 deletions mambular/base_models/regressor.py
@@ -1,68 +1,38 @@
import lightning as pl
import torch
import torch.nn as nn

from ..utils.config import MambularConfig
from ..utils.mamba_arch import Mamba
from ..utils.mlp_utils import MLP
from ..utils.normalization_layers import (
RMSNorm,
LayerNorm,
LearnableLayerScaling,
BatchNorm,
InstanceNorm,
GroupNorm,
)
from ..utils.default_mamba_params import DefaultConfig


class BaseMambularRegressor(pl.LightningModule):
"""
A base regression module for tabular data built on PyTorch Lightning. It incorporates embeddings
for categorical and numerical features with a configurable architecture provided by MambularConfig.
This module is designed for regression tasks.
Parameters
----------
config : MambularConfig
An instance of MambularConfig containing configuration parameters for the model architecture.
cat_feature_info : dict, optional
A dictionary mapping the names of categorical features to their number of unique categories. Defaults to None.
num_feature_info : dict, optional
A dictionary mapping the names of numerical features to their number of dimensions after embedding. Defaults to None.
lr : float, optional
The initial learning rate for the optimizer. Defaults to 1e-03.
lr_patience : int, optional
The number of epochs with no improvement after which learning rate will be reduced. Defaults to 10.
weight_decay : float, optional
Weight decay (L2 penalty) coefficient. Defaults to 0.025.
lr_factor : float, optional
Factor by which the learning rate will be reduced. Defaults to 0.75.
Attributes
----------
mamba : Mamba
The core neural network module implementing the Mamba architecture.
norm_f : nn.Module
Normalization layer applied after the Mamba block.
tabular_head : nn.Linear
Final linear layer mapping the features to a single output for regression tasks.
train_mse : torchmetrics.MeanSquaredError
Metric computation module for training Mean Squared Error.
val_mse : torchmetrics.MeanSquaredError
Metric computation module for validation Mean Squared Error.
loss_fct : torch.nn.MSELoss
The loss function for regression tasks.
"""

def __init__(
self,
config: MambularConfig,
cat_feature_info: dict = None,
num_feature_info: dict = None,
lr=1e-03,
lr_patience=10,
weight_decay=0.025,
lr_factor=0.75,
cat_feature_info,
num_feature_info,
config: DefaultConfig = DefaultConfig(),
**kwargs,
):
super().__init__()

self.config = config
self.lr = lr
self.lr_patience = lr_patience
self.weight_decay = weight_decay
self.lr_factor = lr_factor
# Save all hyperparameters
self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

# Assigning values from hyperparameters
self.lr = self.hparams.get("lr", config.lr)
self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
self.cat_feature_info = cat_feature_info
self.num_feature_info = num_feature_info

@@ -75,67 +45,108 @@ def __init__(
"selu": nn.SELU(),
"gelu": nn.GELU(),
"softplus": nn.Softplus(),
"leakyrelu": nn.LeakyReLU(),
"linear": nn.Identity(),
"silu": nn.functional.silu,
}

self.embedding_activation = activations.get(
self.config.num_embedding_activation.lower()
self.embedding_activation = self.hparams.get(
"num_embedding_activation", config.num_embedding_activation
)

# Additional layers and components initialization based on hyperparameters
self.mamba = Mamba(
d_model=self.hparams.get("d_model", config.d_model),
n_layers=self.hparams.get("n_layers", config.n_layers),
expand_factor=self.hparams.get("expand_factor", config.expand_factor),
bias=self.hparams.get("bias", config.bias),
d_conv=self.hparams.get("d_conv", config.d_conv),
conv_bias=self.hparams.get("conv_bias", config.conv_bias),
dropout=self.hparams.get("dropout", config.dropout),
dt_rank=self.hparams.get("dt_rank", config.dt_rank),
d_state=self.hparams.get("d_state", config.d_state),
dt_scale=self.hparams.get("dt_scale", config.dt_scale),
dt_init=self.hparams.get("dt_init", config.dt_init),
dt_max=self.hparams.get("dt_max", config.dt_max),
dt_min=self.hparams.get("dt_min", config.dt_min),
dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
norm=globals()[self.hparams.get("norm", config.norm)],
activation=self.hparams.get("activation", config.activation),
)

# Set the normalization layer dynamically
norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
)
else:
raise ValueError(f"Unsupported normalization layer: {norm_layer}")

if self.embedding_activation is None:
raise ValueError(
f"Unsupported activation function: {self.config.num_embedding_activation}"
f"Unsupported activation function: {self.hparams.get('num_embedding_activation')}"
)

self.num_embeddings = nn.ModuleList(
[
nn.Sequential(
nn.Linear(input_shape, self.config.d_model, bias=False),
# Example using ReLU as the activation function, change as needed
nn.Linear(
input_shape,
self.hparams.get("d_model", config.d_model),
bias=False,
),
self.embedding_activation,
)
for feature_name, input_shape in num_feature_info.items()
]
)

# Create embedding layers for categorical features based on cat_feature_info
self.cat_embeddings = nn.ModuleList(
[
nn.Embedding(num_categories + 1, self.config.d_model)
nn.Embedding(
num_categories + 1, self.hparams.get("d_model", config.d_model)
)
for feature_name, num_categories in cat_feature_info.items()
]
)

self.mamba = Mamba(self.config)
self.norm_f = self.config.norm(self.config.d_model)
mlp_activation_fn = activations.get(
self.config.tabular_head_activation.lower(), nn.Identity()
)

# Dynamically create MLP layers based on config.tabular_units
mlp_layers = []
input_dim = self.config.d_model # Initial input dimension

# Iterate over the specified units for each layer in the MLP
for units in self.config.tabular_head_units:
mlp_layers.append(nn.Linear(input_dim, units))
mlp_layers.append(mlp_activation_fn)
mlp_layers.append(nn.Dropout(self.config.tabular_head_dropout))
input_dim = units
head_activation = self.hparams.get("head_activation", config.head_activation)

# Add the final linear layer to map to a single output value
mlp_layers.append(nn.Linear(input_dim, 1))

# Combine all layers into a Sequential module
self.tabular_head = nn.Sequential(*mlp_layers)
self.tabular_head = MLP(
self.hparams.get("d_model", config.d_model),
hidden_units_list=self.hparams.get(
"head_layer_sizes", config.head_layer_sizes
),
dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
use_skip_layers=self.hparams.get(
"head_skip_layers", config.head_skip_layers
),
activation_fn=head_activation,
use_batch_norm=self.hparams.get(
"head_use_batch_norm", config.head_use_batch_norm
),
)

self.pooling_method = self.config.pooling_method
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.d_model))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
)

self.loss_fct = nn.MSELoss()

if self.config.layer_norm_after_embedding:
self.embedding_norm = nn.LayerNorm(self.config.d_model)
if self.hparams.get("layer_norm_after_embedding"):
self.embedding_norm = nn.LayerNorm(
self.hparams.get("d_model", config.d_model)
)

def forward(self, cat_features, num_features):
"""
@@ -156,8 +167,7 @@
"""

batch_size = (
cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
)
cls_tokens = self.cls_token.expand(batch_size, -1, -1)

@@ -168,7 +178,7 @@ def forward(self, cat_features, num_features):
]
cat_embeddings = torch.stack(cat_embeddings, dim=1)
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
if self.config.layer_norm_after_embedding:
if self.hparams.get("layer_norm_after_embedding"):
cat_embeddings = self.embedding_norm(cat_embeddings)
else:
cat_embeddings = None
@@ -179,7 +189,7 @@ def forward(self, cat_features, num_features):
emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)
]
num_embeddings = torch.stack(num_embeddings, dim=1)
if self.config.layer_norm_after_embedding:
if self.hparams.get("layer_norm_after_embedding"):
num_embeddings = self.embedding_norm(num_embeddings)
else:
num_embeddings = None
@@ -209,7 +219,7 @@ def forward(self, cat_features, num_features):
else:
raise ValueError(f"Invalid pooling method: {self.pooling_method}")

x = self.norm_f(x)
x = self.norm_f.forward(x)
preds = self.tabular_head(x)

return preds
@@ -281,7 +291,7 @@ def configure_optimizers(self):
A dictionary containing the optimizer and lr_scheduler configurations.
"""
optimizer = torch.optim.Adam(
self.parameters(), lr=self.lr, weight_decay=self.config.weight_decay
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
)
scheduler = {
"scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
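With this restructuring, `BaseMambularRegressor` no longer takes a `MambularConfig`; it reads architecture defaults from `DefaultConfig` and lets keyword arguments override them through `save_hyperparameters()` and `self.hparams.get(...)`. A hypothetical instantiation, with made-up feature metadata and override values purely for illustration:

```python
from mambular.base_models.regressor import BaseMambularRegressor
from mambular.utils.default_mamba_params import DefaultConfig

# Hypothetical feature metadata: name -> number of categories / input width
cat_feature_info = {"color": 5, "region": 12}
num_feature_info = {"age": 1, "income": 1}

model = BaseMambularRegressor(
    cat_feature_info=cat_feature_info,
    num_feature_info=num_feature_info,
    config=DefaultConfig(),   # architecture defaults
    d_model=128,              # **kwargs override, picked up via self.hparams.get(...)
    norm="LayerNorm",
    head_layer_sizes=[64, 32],
    lr=1e-3,
)
```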