This repository contains an implementation of SAINT (Self-Attention and Intersample Attention Transformer) built with PyTorch Lightning, using Hydra for configuration. The paper is available on arXiv.
Check the website for more information.
pip install lit-saint
- Create a YAML file containing the configuration needed by the application, or use the default values
- Create an instance of SaintConfig using Hydra
- Create the DataFrame that will be used by the model. To split the data correctly, add a new column (named "SPLIT" in the code below) that assigns the label "train" to the rows of the training set, "validation" to the rows of the validation set and "test" to the rows of the test set
- Create an instance of SaintDatamodule and Saint
data_module = SaintDatamodule(df=df, target="TARGET", split_column="SPLIT")
model = Saint(categories=data_module.categorical_dims, continuous=data_module.numerical_columns,
              config=cfg, dim_target=data_module.dim_target)
- Create the PyTorch Lightning Trainers, one for the pretraining phase and one for the training phase
from pytorch_lightning import Trainer
pretrainer = Trainer(max_epochs=1)
trainer = Trainer(max_epochs=5)
- Create the SaintTrainer, which fits the model and makes predictions
saint_trainer = SaintTrainer(pretrainer=pretrainer, trainer=trainer)
saint_trainer.fit(model=model, datamodule=data_module, enable_pretraining=True)
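- Finally, make predictions on a new DataFrame; for a classification problem, take the argmax over the class scores of each row to get the predicted class index
import numpy as np
prediction = saint_trainer.predict(model=model, datamodule=data_module, df=df_to_predict)
df_to_predict["prediction"] = np.argmax(prediction, axis=1)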
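The split column required by the steps above can be built in several ways. Below is a minimal sketch of a random, reproducible split using only the standard library; the helper `make_split_labels` and its `val_frac`, `test_frac` and `seed` parameters are hypothetical and not part of lit-saint.

```python
import random


def make_split_labels(n_rows, val_frac=0.1, test_frac=0.1, seed=42):
    """Return a shuffled list of "train"/"validation"/"test" labels, one per row.

    Hypothetical helper: lit-saint only requires that the split column
    contains these three labels.
    """
    if not 0 <= val_frac + test_frac < 1:
        raise ValueError("val_frac + test_frac must be in [0, 1)")
    n_val = int(n_rows * val_frac)
    n_test = int(n_rows * test_frac)
    n_train = n_rows - n_val - n_test
    labels = ["train"] * n_train + ["validation"] * n_val + ["test"] * n_test
    # A seeded RNG makes the split reproducible across runs
    random.Random(seed).shuffle(labels)
    return labels


labels = make_split_labels(10, val_frac=0.2, test_frac=0.2)
print(labels.count("train"), labels.count("validation"), labels.count("test"))  # 6 2 2
```

With pandas, the result can be attached as `df["SPLIT"] = make_split_labels(len(df))`. If the rows have a time order, a chronological split may be more appropriate than a shuffle.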
Some notes on how the data is preprocessed:
- Missing values in numerical columns are filled with zeros
- Missing values in categorical columns are replaced with a new category, SAINT_NAN
- Numerical columns are scaled with a StandardScaler, unless you pass a different scaler to SaintDatamodule
- Columns of type "object" or "category" are treated as categorical
- Columns of type "int64", "float64", "int32" or "float32" are treated as numerical
- Columns of any other type are not used by the model
- During pretraining, rows whose target is NaN are used; they are dropped before the supervised training starts
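The preprocessing rules listed above can be illustrated with a plain-Python sketch in which dicts stand in for DataFrame rows and dtype strings stand in for pandas dtypes. This is not lit-saint's code: all helper names are hypothetical, each rule is shown in isolation, and the order in which the library actually applies them is not stated here.

```python
import math
from statistics import fmean, pstdev

CATEGORICAL_DTYPES = {"object", "category"}
NUMERICAL_DTYPES = {"int64", "float64", "int32", "float32"}
NAN_CATEGORY = "SAINT_NAN"


def is_missing(value):
    """Treat None and float NaN as missing (a simplification of pandas' rules)."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def split_columns(dtypes):
    """Partition column names by dtype string; columns of other dtypes are ignored."""
    categorical = [c for c, t in dtypes.items() if t in CATEGORICAL_DTYPES]
    numerical = [c for c, t in dtypes.items() if t in NUMERICAL_DTYPES]
    return categorical, numerical


def fill_missing(rows, categorical, numerical):
    """Fill numeric gaps with 0 and categorical gaps with the SAINT_NAN category."""
    filled = []
    for row in rows:
        new_row = dict(row)  # copy so the input rows are left untouched
        for col in numerical:
            if is_missing(new_row[col]):
                new_row[col] = 0.0
        for col in categorical:
            if is_missing(new_row[col]):
                new_row[col] = NAN_CATEGORY
        filled.append(new_row)
    return filled


def standard_scale(values):
    """Z-score scaling with the population std, like sklearn's StandardScaler."""
    mean, std = fmean(values), pstdev(values)
    # A constant column has std 0; StandardScaler maps it to all zeros
    return [(v - mean) / std if std > 0 else 0.0 for v in values]


def rows_for_stage(rows, target, stage):
    """Pretraining keeps every row; training drops rows whose target is missing."""
    if stage == "pretraining":
        return list(rows)
    return [r for r in rows if not is_missing(r[target])]


dtypes = {"age": "int64", "income": "float64", "city": "object", "signup": "datetime64[ns]"}
print(split_columns(dtypes))  # (['city'], ['age', 'income'])
```

In the example, the `signup` column is dropped because its dtype is in neither list, which is why the suggestions below recommend converting datetime columns first.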
Some suggestions:
- If you want to fill missing values differently from the defaults, do it before passing the DataFrame to SaintDatamodule
- To use datetime columns, extract features from them (e.g. day of week) or convert them to epoch time
- For a classification problem, the target column must be of type "object" or "category"
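The datetime and target suggestions above come down to converting values into types the model accepts. A minimal standard-library sketch, assuming timezone-aware timestamps; `datetime_features` and `target_as_labels` are hypothetical helpers, and the chosen features (day of week, month, epoch seconds) are only examples.

```python
from datetime import datetime, timezone


def datetime_features(ts):
    """Turn a timezone-aware datetime into numeric features (hypothetical helper)."""
    return {
        "day_of_week": ts.weekday(),  # Monday=0 ... Sunday=6
        "month": ts.month,
        "epoch": int(ts.timestamp()),  # seconds since 1970-01-01 UTC
    }


def target_as_labels(values):
    """Cast integer class ids to strings so the target column gets object dtype."""
    return [str(v) for v in values]


print(datetime_features(datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)))
# {'day_of_week': 0, 'month': 1, 'epoch': 1704110400}
print(target_as_labels([0, 1, 2]))  # ['0', '1', '2']
```

In pandas the same ideas map to `df["col"].dt.dayofweek` for the day of week and `df["TARGET"].astype("category")` for the target. Naive timestamps are interpreted in local time by `timestamp()`, hence the timezone-aware assumption.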
To generate a YAML file with the default configuration:
from lit_saint import SaintConfig
from omegaconf import OmegaConf

conf = OmegaConf.create(SaintConfig)
OmegaConf.save(config=conf, f="<FILE_NAME>")
To enable type validation at runtime, add the following lines at the beginning of your YAML file:
defaults:
  - base_config
We would like to thank the authors of the official SAINT implementation: https://github.com/somepago/saint