diff --git a/mambular/base_models/embedding_regressor.py b/mambular/base_models/embedding_regressor.py index 2b9904e..d7c3b93 100644 --- a/mambular/base_models/embedding_regressor.py +++ b/mambular/base_models/embedding_regressor.py @@ -4,6 +4,7 @@ from ..utils.config import MambularConfig from ..utils.mamba_arch import Mamba +from ..utils.mlp_utils import MLP class BaseEmbeddingMambularRegressor(pl.LightningModule): @@ -57,6 +58,12 @@ def __init__( lr_factor=0.75, seq_size: int = 20, raw_embeddings=False, + head_layer_sizes=[64, 32, 32], + head_dropout: float = 0.3, + head_skip_layers: bool = False, + head_activation="leakyrelu", + head_use_batch_norm: bool = False, + attn_dropout: float = 0.3, ): super().__init__() @@ -97,8 +104,7 @@ def __init__( self.num_embeddings = nn.ModuleList( [ nn.Sequential( - nn.Linear(self.seq_size, - self.config.d_model, bias=False), + nn.Linear(self.seq_size, self.config.d_model, bias=False), # Example using ReLU as the activation function, change as needed self.embedding_activation, ) @@ -128,26 +134,17 @@ def __init__( self.mamba = Mamba(self.config) self.norm_f = self.config.norm(self.config.d_model) - mlp_activation_fn = activations.get( - self.config.tabular_head_activation.lower(), nn.Identity() - ) - - # Dynamically create MLP layers based on config.tabular_units - mlp_layers = [] - input_dim = self.config.d_model # Initial input dimension - - # Iterate over the specified units for each layer in the MLP - for units in self.config.tabular_head_units: - mlp_layers.append(nn.Linear(input_dim, units)) - mlp_layers.append(mlp_activation_fn) - mlp_layers.append(nn.Dropout(self.config.tabular_head_dropout)) - input_dim = units - - # Add the final linear layer to map to a single output value - mlp_layers.append(nn.Linear(input_dim, 1)) + head_activation = activations.get(head_activation.lower(), nn.Identity()) # Combine all layers into a Sequential module - self.tabular_head = nn.Sequential(*mlp_layers) + self.tabular_head = MLP( + self.config.d_model, + hidden_units_list=head_layer_sizes, + dropout_rate=head_dropout, + use_skip_layers=head_skip_layers, + activation_fn=head_activation, + use_batch_norm=head_use_batch_norm, + ) self.pooling_method = self.config.pooling_method self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.d_model)) @@ -176,8 +173,7 @@ def forward(self, cat_features, num_features): The output predictions of the model for regression tasks. """ batch_size = ( - cat_features[0].size(0) if cat_features != [ - ] else num_features[0].size(0) + cat_features[0].size(0) if cat_features != [] else num_features[0].size(0) ) cls_tokens = self.cls_token.expand(batch_size, -1, -1) # Process categorical features if present diff --git a/mambular/base_models/regressor.py b/mambular/base_models/regressor.py index 3b0f721..2ff671c 100644 --- a/mambular/base_models/regressor.py +++ b/mambular/base_models/regressor.py @@ -1,68 +1,38 @@ import lightning as pl import torch import torch.nn as nn - -from ..utils.config import MambularConfig from ..utils.mamba_arch import Mamba +from ..utils.mlp_utils import MLP +from ..utils.normalization_layers import ( + RMSNorm, + LayerNorm, + LearnableLayerScaling, + BatchNorm, + InstanceNorm, + GroupNorm, +) +from ..utils.default_mamba_params import DefaultConfig class BaseMambularRegressor(pl.LightningModule): - """ - A base regression module for tabular data built on PyTorch Lightning. It incorporates embeddings - for categorical and numerical features with a configurable architecture provided by MambularConfig. 
- This module is designed for regression tasks. - - Parameters - ---------- - config : MambularConfig - An instance of MambularConfig containing configuration parameters for the model architecture. - cat_feature_info : dict, optional - A dictionary mapping the names of categorical features to their number of unique categories. Defaults to None. - num_feature_info : dict, optional - A dictionary mapping the names of numerical features to their number of dimensions after embedding. Defaults to None. - lr : float, optional - The initial learning rate for the optimizer. Defaults to 1e-03. - lr_patience : int, optional - The number of epochs with no improvement after which learning rate will be reduced. Defaults to 10. - weight_decay : float, optional - Weight decay (L2 penalty) coefficient. Defaults to 0.025. - lr_factor : float, optional - Factor by which the learning rate will be reduced. Defaults to 0.75. - - - Attributes - ---------- - mamba : Mamba - The core neural network module implementing the Mamba architecture. - norm_f : nn.Module - Normalization layer applied after the Mamba block. - tabular_head : nn.Linear - Final linear layer mapping the features to a single output for regression tasks. - train_mse : torchmetrics.MeanSquaredError - Metric computation module for training Mean Squared Error. - val_mse : torchmetrics.MeanSquaredError - Metric computation module for validation Mean Squared Error. - loss_fct : torch.nn.MSELoss - The loss function for regression tasks. - """ - def __init__( self, - config: MambularConfig, - cat_feature_info: dict = None, - num_feature_info: dict = None, - lr=1e-03, - lr_patience=10, - weight_decay=0.025, - lr_factor=0.75, + cat_feature_info, + num_feature_info, + config: DefaultConfig = DefaultConfig(), + **kwargs, ): super().__init__() - self.config = config - self.lr = lr - self.lr_patience = lr_patience - self.weight_decay = weight_decay - self.lr_factor = lr_factor + # Save all hyperparameters + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + # Assigning values from hyperparameters + self.lr = self.hparams.get("lr", config.lr) + self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) + self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) + self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) + self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info @@ -75,67 +45,108 @@ def __init__( "selu": nn.SELU(), "gelu": nn.GELU(), "softplus": nn.Softplus(), - "leakyrelu": nn.LeakyReLU(), "linear": nn.Identity(), + "silu": nn.functional.silu, } - self.embedding_activation = activations.get( - self.config.num_embedding_activation.lower() + self.embedding_activation = self.hparams.get( + "num_embedding_activation", config.num_embedding_activation + ) + + # Additional layers and components initialization based on hyperparameters + self.mamba = Mamba( + d_model=self.hparams.get("d_model", config.d_model), + n_layers=self.hparams.get("n_layers", config.n_layers), + expand_factor=self.hparams.get("expand_factor", config.expand_factor), + bias=self.hparams.get("bias", config.bias), + d_conv=self.hparams.get("d_conv", config.d_conv), + conv_bias=self.hparams.get("conv_bias", config.conv_bias), + dropout=self.hparams.get("dropout", config.dropout), + dt_rank=self.hparams.get("dt_rank", config.dt_rank), + d_state=self.hparams.get("d_state", config.d_state), + 
dt_scale=self.hparams.get("dt_scale", config.dt_scale), + dt_init=self.hparams.get("dt_init", config.dt_init), + dt_max=self.hparams.get("dt_max", config.dt_max), + dt_min=self.hparams.get("dt_min", config.dt_min), + dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor), + norm=globals()[self.hparams.get("norm", config.norm)], + activation=self.hparams.get("activation", config.activation), ) + + # Set the normalization layer dynamically + norm_layer = self.hparams.get("norm", config.norm) + if norm_layer == "RMSNorm": + self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model)) + elif norm_layer == "LayerNorm": + self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model)) + elif norm_layer == "BatchNorm": + self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model)) + elif norm_layer == "InstanceNorm": + self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model)) + elif norm_layer == "GroupNorm": + self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model)) + elif norm_layer == "LearnableLayerScaling": + self.norm_f = LearnableLayerScaling( + self.hparams.get("d_model", config.d_model) + ) + else: + raise ValueError(f"Unsupported normalization layer: {norm_layer}") + if self.embedding_activation is None: raise ValueError( - f"Unsupported activation function: {self.config.num_embedding_activation}" + f"Unsupported activation function: {self.hparams.get('num_embedding_activation')}" ) self.num_embeddings = nn.ModuleList( [ nn.Sequential( - nn.Linear(input_shape, self.config.d_model, bias=False), - # Example using ReLU as the activation function, change as needed + nn.Linear( + input_shape, + self.hparams.get("d_model", config.d_model), + bias=False, + ), self.embedding_activation, ) for feature_name, input_shape in num_feature_info.items() ] ) - # Create embedding layers for categorical features based on cat_feature_info self.cat_embeddings = nn.ModuleList( [ - nn.Embedding(num_categories + 1, self.config.d_model) + nn.Embedding( + num_categories + 1, self.hparams.get("d_model", config.d_model) + ) for feature_name, num_categories in cat_feature_info.items() ] ) - self.mamba = Mamba(self.config) - self.norm_f = self.config.norm(self.config.d_model) - mlp_activation_fn = activations.get( - self.config.tabular_head_activation.lower(), nn.Identity() - ) - - # Dynamically create MLP layers based on config.tabular_units - mlp_layers = [] - input_dim = self.config.d_model # Initial input dimension - - # Iterate over the specified units for each layer in the MLP - for units in self.config.tabular_head_units: - mlp_layers.append(nn.Linear(input_dim, units)) - mlp_layers.append(mlp_activation_fn) - mlp_layers.append(nn.Dropout(self.config.tabular_head_dropout)) - input_dim = units + head_activation = self.hparams.get("head_activation", config.head_activation) - # Add the final linear layer to map to a single output value - mlp_layers.append(nn.Linear(input_dim, 1)) - - # Combine all layers into a Sequential module - self.tabular_head = nn.Sequential(*mlp_layers) + self.tabular_head = MLP( + self.hparams.get("d_model", config.d_model), + hidden_units_list=self.hparams.get( + "head_layer_sizes", config.head_layer_sizes + ), + dropout_rate=self.hparams.get("head_dropout", config.head_dropout), + use_skip_layers=self.hparams.get( + "head_skip_layers", config.head_skip_layers + ), + activation_fn=head_activation, + use_batch_norm=self.hparams.get( + "head_use_batch_norm", config.head_use_batch_norm + ), + ) - self.pooling_method = 
self.config.pooling_method - self.cls_token = nn.Parameter(torch.zeros(1, 1, self.config.d_model)) + self.cls_token = nn.Parameter( + torch.zeros(1, 1, self.hparams.get("d_model", config.d_model)) + ) self.loss_fct = nn.MSELoss() - if self.config.layer_norm_after_embedding: - self.embedding_norm = nn.LayerNorm(self.config.d_model) + if self.hparams.get("layer_norm_after_embedding"): + self.embedding_norm = nn.LayerNorm( + self.hparams.get("d_model", config.d_model) + ) def forward(self, cat_features, num_features): """ @@ -156,8 +167,7 @@ def forward(self, cat_features, num_features): """ batch_size = ( - cat_features[0].size(0) if cat_features != [ - ] else num_features[0].size(0) + cat_features[0].size(0) if cat_features != [] else num_features[0].size(0) ) cls_tokens = self.cls_token.expand(batch_size, -1, -1) @@ -168,7 +178,7 @@ def forward(self, cat_features, num_features): ] cat_embeddings = torch.stack(cat_embeddings, dim=1) cat_embeddings = torch.squeeze(cat_embeddings, dim=2) - if self.config.layer_norm_after_embedding: + if self.hparams.get("layer_norm_after_embedding"): cat_embeddings = self.embedding_norm(cat_embeddings) else: cat_embeddings = None @@ -179,7 +189,7 @@ def forward(self, cat_features, num_features): emb(num_features[i]) for i, emb in enumerate(self.num_embeddings) ] num_embeddings = torch.stack(num_embeddings, dim=1) - if self.config.layer_norm_after_embedding: + if self.hparams.get("layer_norm_after_embedding"): num_embeddings = self.embedding_norm(num_embeddings) else: num_embeddings = None @@ -209,7 +219,7 @@ def forward(self, cat_features, num_features): else: raise ValueError(f"Invalid pooling method: {self.pooling_method}") - x = self.norm_f(x) + x = self.norm_f.forward(x) preds = self.tabular_head(x) return preds @@ -281,7 +291,7 @@ def configure_optimizers(self): A dictionary containing the optimizer and lr_scheduler configurations. """ optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr, weight_decay=self.config.weight_decay + self.parameters(), lr=self.lr, weight_decay=self.weight_decay ) scheduler = { "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( diff --git a/mambular/models/sklearn_regressor.py b/mambular/models/sklearn_regressor.py index ad23fd6..c75f5d8 100644 --- a/mambular/models/sklearn_regressor.py +++ b/mambular/models/sklearn_regressor.py @@ -6,30 +6,94 @@ from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader +import warnings from ..base_models.regressor import BaseMambularRegressor -from ..utils.config import MambularConfig from ..utils.dataset import MambularDataModule, MambularDataset from ..utils.preprocessor import Preprocessor +from ..utils.default_mamba_params import DefaultConfig class MambularRegressor(BaseEstimator): """ - A regressor implemented using PyTorch Lightning that follows the scikit-learn API conventions. This class is designed - to work with tabular data, offering a straightforward way to specify model configurations and preprocessing steps. It - integrates seamlessly with scikit-learn's tools such as cross-validation and grid search. + A regressor implemented using PyTorch Lightning that follows the scikit-learn API conventions. + This class is designed to work with tabular data, offering a straightforward way to specify + model configurations and preprocessing steps. It integrates seamlessly with scikit-learn's tools + such as cross-validation and grid search. 
Parameters ---------- - **kwargs : Various - Accepts any number of keyword arguments. Arguments recognized as model configuration options are passed to the - MambularConfig constructor. Remaining arguments are assumed to be preprocessor options and passed to the - Preprocessor constructor. + # configuration parameters + lr : float, optional + Learning rate for the optimizer. Default is 1e-4. + lr_patience : int, optional + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. Default is 10. + weight_decay : float, optional + Weight decay (L2 penalty) coefficient. Default is 1e-6. + lr_factor : float, optional + Factor by which the learning rate will be reduced. Default is 0.1. + d_model : int, optional + Dimension of the model. Default is 64. + n_layers : int, optional + Number of layers. Default is 8. + expand_factor : int, optional + Expansion factor. Default is 2. + bias : bool, optional + Whether to use bias. Default is False. + d_conv : int, optional + Dimension of the convolution. Default is 16. + conv_bias : bool, optional + Whether to use bias in the convolution. Default is True. + dropout : float, optional + Dropout rate in the Mamba blocks. Default is 0.05. + dt_rank : str, optional + Rank of the time dimension. Default is "auto". + d_state : int, optional + State dimension. Default is 32. + dt_scale : float, optional + Scale of the time dimension. Default is 1.0. + dt_init : str, optional + Initialization method for the time dimension. Default is "random". + dt_max : float, optional + Maximum value for the time dimension. Default is 0.1. + dt_min : float, optional + Minimum value for the time dimension. Default is 1e-4. + dt_init_floor : float, optional + Floor value for the time dimension initialization. Default is 1e-4. + norm : str, optional + Normalization method. Default is 'RMSNorm'. + activation : callable, optional + Activation function. Default is nn.SELU(). + num_embedding_activation : callable, optional + Activation function for numerical embeddings. Default is nn.Identity(). + head_layer_sizes : list, optional + Sizes of the layers in the head. Default is [128, 64, 32]. + head_dropout : float, optional + Dropout rate for the head. Default is 0.5. + head_skip_layers : bool, optional + Whether to use skip layers in the head. Default is False. + head_activation : callable, optional + Activation function for the head. Default is nn.SELU(). + head_use_batch_norm : bool, optional + Whether to use batch normalization in the head. Default is False. + + # Preprocessor Parameters + n_bins : int, optional + The number of bins to use for numerical feature binning. Default is 50. + numerical_preprocessing : str, optional + The preprocessing strategy for numerical features. Default is 'ple'. + use_decision_tree_bins : bool, optional + If True, uses decision tree regression/classification to determine optimal bin edges for numerical feature binning. Default is False. + binning_strategy : str, optional + Defines the strategy for binning numerical features. Default is 'uniform'. + task : str, optional + Indicates the type of machine learning task ('regression' or 'classification'). Default is 'regression'. + Attributes ---------- - config : MambularConfig + config : DefaultConfig An object storing the configuration settings for the model. preprocessor : Preprocessor An object responsible for preprocessing the input data, such as encoding categorical variables and scaling numerical features.
@@ -39,44 +103,60 @@ class MambularRegressor(BaseEstimator): def __init__(self, **kwargs): # Known config arguments - print("Received kwargs:", kwargs) config_arg_names = [ + "lr", + "lr_patience", + "weight_decay", + "lr_factor", "d_model", "n_layers", - "dt_rank", - "output_dimension", - "pooling_method", - "norm", - "cls", - "dt_min", - "dt_max", - "dropout", + "expand_factor", "bias", - "weight_decay", + "d_conv", "conv_bias", + "dropout", + "dt_rank", "d_state", - "expand_factor", - "d_conv", - "dt_init", "dt_scale", + "dt_init", + "dt_max", + "dt_min", "dt_init_floor", - "tabular_head_units", - "tabular_head_activation", - "tabular_head_dropout", - "num_emebedding_activation", - "layer_norm_after_embedding", + "norm", + "activation", + "num_embedding_activation", + "head_layer_sizes", + "head_dropout", + "head_skip_layers", + "head_activation", + "head_use_batch_norm", + ] + + preprocessor_arg_names = [ + "n_bins", + "numerical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", ] - self.config_kwargs = {k: v for k, - v in kwargs.items() if k in config_arg_names} - self.config = MambularConfig(**self.config_kwargs) - # The rest are assumed to be preprocessor arguments + self.config_kwargs = {k: v for k, v in kwargs.items() if k in config_arg_names} + self.config = DefaultConfig(**self.config_kwargs) + preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k not in config_arg_names + k: v for k, v in kwargs.items() if k in preprocessor_arg_names } + self.preprocessor = Preprocessor(**preprocessor_kwargs) self.model = None + # Raise a warning if task is set to 'classification' + if preprocessor_kwargs.get("task") == "classification": + warnings.warn( + "The task is set to 'classification'. MambularRegressor is designed for regression tasks.", + UserWarning, + ) + def get_params(self, deep=True): """ Get parameters for this estimator. Overrides the BaseEstimator method. @@ -86,13 +166,12 @@ def get_params(self, deep=True): deep : bool, default=True If True, returns the parameters for this estimator and contained sub-objects that are estimators. - Returns ------- params : dict Parameter names mapped to their values. """ - params = self.config_kwargs # Parameters used to initialize MambularConfig + params = self.config_kwargs # Parameters used to initialize DefaultConfig # If deep=True, include parameters from nested components like preprocessor if deep: @@ -114,7 +193,6 @@ def set_params(self, **parameters): **parameters : dict Estimator parameters to be set. - Returns ------- self : object @@ -122,8 +200,7 @@ def set_params(self, **parameters): """ # Update config_kwargs with provided parameters valid_config_keys = self.config_kwargs.keys() - config_updates = {k: v for k, - v in parameters.items() if k in valid_config_keys} + config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} self.config_kwargs.update(config_updates) # Update the config object @@ -194,8 +271,7 @@ def preprocess_data(self, X_train, y_train, X_val, y_val, batch_size, shuffle): data_module : MambularDataModule An instance of MambularDataModule containing the training and validation DataLoaders. 
""" - train_preprocessed_data = self.preprocessor.fit_transform( - X_train, y_train) + train_preprocessed_data = self.preprocessor.fit_transform(X_train, y_train) val_preprocessed_data = self.preprocessor.transform(X_val) # Update feature info based on the actual processed data @@ -215,26 +291,22 @@ def preprocess_data(self, X_train, y_train, X_val, y_val, batch_size, shuffle): cat_key = "cat_" + key # Assuming categorical keys are prefixed with 'cat_' if cat_key in train_preprocessed_data: train_cat_tensors.append( - torch.tensor( - train_preprocessed_data[cat_key], dtype=torch.long) + torch.tensor(train_preprocessed_data[cat_key], dtype=torch.long) ) if cat_key in val_preprocessed_data: val_cat_tensors.append( - torch.tensor( - val_preprocessed_data[cat_key], dtype=torch.long) + torch.tensor(val_preprocessed_data[cat_key], dtype=torch.long) ) binned_key = "num_" + key # for binned features if binned_key in train_preprocessed_data: train_cat_tensors.append( - torch.tensor( - train_preprocessed_data[binned_key], dtype=torch.long) + torch.tensor(train_preprocessed_data[binned_key], dtype=torch.long) ) if binned_key in val_preprocessed_data: val_cat_tensors.append( - torch.tensor( - val_preprocessed_data[binned_key], dtype=torch.long) + torch.tensor(val_preprocessed_data[binned_key], dtype=torch.long) ) # Populate tensors for numerical features, if present in processed data @@ -242,13 +314,11 @@ def preprocess_data(self, X_train, y_train, X_val, y_val, batch_size, shuffle): num_key = "num_" + key # Assuming numerical keys are prefixed with 'num_' if num_key in train_preprocessed_data: train_num_tensors.append( - torch.tensor( - train_preprocessed_data[num_key], dtype=torch.float) + torch.tensor(train_preprocessed_data[num_key], dtype=torch.float) ) if num_key in val_preprocessed_data: val_num_tensors.append( - torch.tensor( - val_preprocessed_data[num_key], dtype=torch.float) + torch.tensor(val_preprocessed_data[num_key], dtype=torch.float) ) train_labels = torch.tensor(y_train, dtype=torch.float) @@ -258,8 +328,7 @@ def preprocess_data(self, X_train, y_train, X_val, y_val, batch_size, shuffle): train_dataset = MambularDataset( train_cat_tensors, train_num_tensors, train_labels ) - val_dataset = MambularDataset( - val_cat_tensors, val_num_tensors, val_labels) + val_dataset = MambularDataset(val_cat_tensors, val_num_tensors, val_labels) # Create dataloaders train_dataloader = DataLoader( @@ -320,20 +389,20 @@ def fit( self, X, y, - val_size=0.2, + val_size: float = 0.2, X_val=None, y_val=None, - max_epochs=100, - random_state=101, - batch_size=128, - shuffle=True, - patience=10, - monitor="val_loss", - mode="min", - lr=1e-3, - lr_patience=5, - factor=0.75, - weight_decay=1e-06, + max_epochs: int = 100, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + patience: int = 15, + monitor: str = "val_loss", + mode: str = "min", + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, **trainer_kwargs ): """ @@ -369,7 +438,7 @@ def fit( Learning rate for the optimizer. lr_patience : int, default=10 Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.75 + factor : float, default=0.1 Factor by which the learning rate will be reduced. weight_decay : float, default=0.025 Weight decay (L2 penalty) coefficient. 
diff --git a/mambular/utils/default_mamba_params.py b/mambular/utils/default_mamba_params.py new file mode 100644 index 0000000..f46e4c4 --- /dev/null +++ b/mambular/utils/default_mamba_params.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultConfig: + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + d_model: int = 64 + n_layers: int = 8 + expand_factor: int = 2 + bias: bool = False + d_conv: int = 16 + conv_bias: bool = True + dropout: float = 0.05 + dt_rank: str = "auto" + d_state: int = 32 + dt_scale: float = 1.0 + dt_init: str = "random" + dt_max: float = 0.1 + dt_min: float = 1e-04 + dt_init_floor: float = 1e-04 + norm: str = "RMSNorm" + activation: callable = nn.SELU() + num_embedding_activation: callable = nn.Identity() + head_layer_sizes: list = (128, 64, 32) + head_dropout: float = 0.5 + head_skip_layers: bool = False + head_activation: callable = nn.SELU() + head_use_batch_norm: bool = False + layer_norm_after_embedding: bool = False + pooling_method: str = "avg" diff --git a/mambular/utils/mamba_arch.py b/mambular/utils/mamba_arch.py index 2e7ca7c..a1eb830 100644 --- a/mambular/utils/mamba_arch.py +++ b/mambular/utils/mamba_arch.py @@ -2,7 +2,14 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .config import MambularConfig +from .normalization_layers import ( + RMSNorm, + LayerNorm, + LearnableLayerScaling, + BatchNorm, + InstanceNorm, + GroupNorm, +) ### Heavily inspired and mostly taken from https://github.com/alxndrTL/mamba.py @@ -16,13 +23,48 @@ class Mamba(nn.Module): layers (nn.ModuleList): List of MambaBlocks constituting the model. """ - def __init__(self, config: MambularConfig): + def __init__( + self, + d_model=32, + n_layers=8, + expand_factor=2, + bias=False, + d_conv=8, + conv_bias=True, + dropout=0.01, + dt_rank="auto", + d_state=16, + dt_scale=1.0, + dt_init="random", + dt_max=0.1, + dt_min=1e-03, + dt_init_floor=1e-04, + norm=RMSNorm, + activation=F.silu, + ): super().__init__() - self.config = config - self.layers = nn.ModuleList( - [ResidualBlock(config) for _ in range(config.n_layers)] + self.layers = nn.ModuleList( + [ + ResidualBlock( + d_model, + expand_factor, + bias, + d_conv, + conv_bias, + dropout, + dt_rank, + d_state, + dt_scale, + dt_init, + dt_max, + dt_min, + dt_init_floor, + norm, + activation, + ) + for _ in range(n_layers) + ] ) def forward(self, x): @@ -40,11 +82,67 @@ class ResidualBlock(nn.Module): norm (RMSNorm): Normalization layer. """ - def __init__(self, config: MambularConfig): + def __init__( + self, + d_model=32, + expand_factor=2, + bias=False, + d_conv=16, + conv_bias=True, + dropout=0.01, + dt_rank="auto", + d_state=32, + dt_scale=1.0, + dt_init="random", + dt_max=0.1, + dt_min=1e-03, + dt_init_floor=1e-04, + norm=RMSNorm, + activation=F.silu, + ): super().__init__() - self.layers = MambaBlock(config) - self.norm = config.norm(config.d_model) + VALID_NORMALIZATION_LAYERS = { + "RMSNorm": RMSNorm, + "LayerNorm": LayerNorm, + "LearnableLayerScaling": LearnableLayerScaling, + "BatchNorm": BatchNorm, + "InstanceNorm": InstanceNorm, + "GroupNorm": GroupNorm, + } + + # Check if the provided normalization layer is valid + if isinstance(norm, type) and norm.__name__ not in VALID_NORMALIZATION_LAYERS: + raise ValueError( + f"Invalid normalization layer: {norm.__name__}. 
" + f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" + ) + elif isinstance(norm, str) and norm not in self.VALID_NORMALIZATION_LAYERS: + raise ValueError( + f"Invalid normalization layer: {norm}. " + f"Valid options are: {', '.join(VALID_NORMALIZATION_LAYERS.keys())}" + ) + + if dt_rank == "auto": + dt_rank = math.ceil(d_model / 16) + + self.layers = MambaBlock( + d_model=d_model, + expand_factor=expand_factor, + bias=bias, + d_conv=d_conv, + conv_bias=conv_bias, + dropout=dropout, + dt_rank=dt_rank, + d_state=d_state, + dt_scale=dt_scale, + dt_init=dt_init, + dt_max=dt_max, + dt_min=dt_min, + dt_init_floor=dt_init_floor, + activation=activation, + ) + self.norm = norm(d_model) def forward(self, x): output = self.layers(self.norm(x)) + x @@ -65,53 +163,66 @@ class MambaBlock(nn.Module): out_proj (nn.Linear): Linear projection for output. """ - def __init__(self, config: MambularConfig): + def __init__( + self, + d_model=32, + expand_factor=2, + bias=False, + d_conv=16, + conv_bias=True, + dropout=0.01, + dt_rank="auto", + d_state=32, + dt_scale=1.0, + dt_init="random", + dt_max=0.1, + dt_min=1e-03, + dt_init_floor=1e-04, + activation=F.silu, + ): super().__init__() + self.d_inner = d_model * expand_factor - self.config = config - - self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias) + self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=bias) self.conv1d = nn.Conv1d( - in_channels=config.d_inner, - out_channels=config.d_inner, - kernel_size=config.d_conv, - bias=config.conv_bias, - groups=config.d_inner, - padding=config.d_conv - 1, + in_channels=self.d_inner, + out_channels=self.d_inner, + kernel_size=d_conv, + bias=conv_bias, + groups=self.d_inner, + padding=d_conv - 1, ) - self.dropout = nn.Dropout(config.dropout) + self.dropout = nn.Dropout(dropout) + self.activation = activation - self.x_proj = nn.Linear( - config.d_inner, config.dt_rank + 2 * config.d_state, bias=False - ) + self.x_proj = nn.Linear(self.d_inner, dt_rank + 2 * d_state, bias=False) - self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True) + self.dt_proj = nn.Linear(dt_rank, self.d_inner, bias=True) - dt_init_std = config.dt_rank**-0.5 * config.dt_scale - if config.dt_init == "constant": + dt_init_std = dt_rank**-0.5 * dt_scale + if dt_init == "constant": nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif config.dt_init == "random": + elif dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError dt = torch.exp( - torch.rand(config.d_inner) - * (math.log(config.dt_max) - math.log(config.dt_min)) - + math.log(config.dt_min) - ).clamp(min=config.dt_init_floor) + torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt) - A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat( - config.d_inner, 1 - ) + A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1) self.A_log = nn.Parameter(torch.log(A)) - self.D = nn.Parameter(torch.ones(config.d_inner)) - self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias) + self.D = nn.Parameter(torch.ones(self.d_inner)) + self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias) + self.dt_rank = dt_rank + self.d_state = d_state def forward(self, x): _, L, _ = x.shape @@ -123,11 +234,11 @@ def forward(self, x): x = self.conv1d(x)[:, :, :L] x = 
x.transpose(1, 2) - x = F.silu(x) + x = self.activation(x) x = self.dropout(x) y = self.ssm(x) - z = F.silu(z) + z = self.activation(z) z = self.dropout(z) output = y * z @@ -143,7 +254,7 @@ def ssm(self, x): delta, B, C = torch.split( deltaBC, - [self.config.dt_rank, self.config.d_state, self.config.d_state], + [self.dt_rank, self.d_state, self.d_state], dim=-1, ) delta = F.softplus(self.dt_proj(delta)) @@ -160,9 +271,7 @@ def selective_scan_seq(self, x, delta, A, B, C, D): BX = deltaB * (x.unsqueeze(-1)) - h = torch.zeros( - x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device - ) + h = torch.zeros(x.size(0), self.d_inner, self.d_state, device=deltaA.device) hs = [] for t in range(0, L): diff --git a/mambular/utils/normalization_layers.py b/mambular/utils/normalization_layers.py index 817a2cd..5237177 100644 --- a/mambular/utils/normalization_layers.py +++ b/mambular/utils/normalization_layers.py @@ -15,7 +15,6 @@ class RMSNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-5): super().__init__() - self.eps = eps self.weight = nn.Parameter(torch.ones(d_model)) diff --git a/mambular/utils/preprocessor.py b/mambular/utils/preprocessor.py index 767e4f1..c443ed1 100644 --- a/mambular/utils/preprocessor.py +++ b/mambular/utils/preprocessor.py @@ -53,8 +53,8 @@ class Preprocessor: def __init__( self, - n_bins=200, - numerical_preprocessing="binning", + n_bins=50, + numerical_preprocessing="ple", use_decision_tree_bins=False, binning_strategy="uniform", task="regression",
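For reference, a minimal usage sketch of the refactored estimator (not part of the patch). It assumes the flat keyword arguments above are split between DefaultConfig and Preprocessor exactly as in MambularRegressor.__init__, and that X is a pandas DataFrame with mixed feature types; the toy data and column names are purely illustrative.

import numpy as np
import pandas as pd

from mambular.models.sklearn_regressor import MambularRegressor

# Hypothetical toy frame with one categorical and two numerical columns.
rng = np.random.default_rng(0)
X = pd.DataFrame(
    {
        "color": rng.choice(["red", "green", "blue"], size=256),
        "x1": rng.normal(size=256),
        "x2": rng.normal(size=256),
    }
)
y = (2.0 * X["x1"] + X["x2"]).to_numpy()

# Architecture/optimizer kwargs are collected into DefaultConfig,
# preprocessing kwargs into Preprocessor (see config_arg_names /
# preprocessor_arg_names in MambularRegressor.__init__).
model = MambularRegressor(
    d_model=64,
    n_layers=8,
    norm="RMSNorm",                  # resolved to the matching normalization class
    head_layer_sizes=[128, 64, 32],
    head_dropout=0.5,
    n_bins=50,                       # Preprocessor argument
    numerical_preprocessing="ple",
)

# fit() splits off a validation set, builds the DataLoaders, and configures
# early stopping plus the ReduceLROnPlateau scheduler from these kwargs.
model.fit(X, y, max_epochs=10, batch_size=128, lr=1e-4, patience=15)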