Benchmarking model performance across diverse DataFrameBenchmark datasets.

First install additional dependencies:

pip install optuna
pip install torchmetrics
pip install xgboost
pip install catboost
pip install lightgbm

Then run:

# Specify the model from [TabNet, FTTransformer, ResNet, MLP, TabTransformer,
# Trompt, ExcelFormer, FTTransformerBucket, XGBoost, CatBoost, LightGBM]
model_type=TabNet

# Specify the task type from [binary_classification, regression,
# multiclass_classification]
task_type=binary_classification

# Specify the dataset scale from [small, medium, large]
scale=small

# Specify the dataset idx from [0, 1, ...]
idx=0

# Specify the number of AutoML search trials
num_trials=20

# Specify the path to save the results
result_path=results.pt

# Run hyper-parameter tuning and training of the specified model on the specified
# dataset.
python data_frame_benchmark.py --model_type $model_type\
                               --task_type $task_type\
                               --scale $scale\
                               --idx $idx\
                               --num_trials $num_trials\
                               --result_path $result_path
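To benchmark every model on one dataset, the command above can be repeated once per model. The following is a minimal sketch, assuming a bash or POSIX shell. The `print_sweep` helper is hypothetical. It only prints one `data_frame_benchmark.py` command per model (a dry run), using the flags shown above. Writing each result to its own `results_<model>.pt` file is also an assumption and not part of the original instructions.

```shell
#!/usr/bin/env bash
# Hypothetical helper: prints (does not run) one benchmark command per model
# for a single dataset, reusing the CLI flags documented above.
# Assumption: each model writes to its own result file, results_<model>.pt.
print_sweep() {
  local task_type=$1 scale=$2 idx=$3 num_trials=$4
  for model in TabNet FTTransformer ResNet MLP TabTransformer Trompt \
               ExcelFormer FTTransformerBucket XGBoost CatBoost LightGBM; do
    echo "python data_frame_benchmark.py --model_type $model" \
         "--task_type $task_type --scale $scale --idx $idx" \
         "--num_trials $num_trials" \
         "--result_path results_${model}.pt"
  done
}

# Dry run: list the commands for dataset 0 of the small binary-classification set.
print_sweep binary_classification small 0 20
```

Piping the output to the shell (for example `print_sweep binary_classification small 0 20 | bash`) runs the commands one after another. Printing first makes it easy to inspect or edit the list before launching long runs.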

Leaderboard

We show the current model performance across different datasets. Each row corresponds to a model and each column to a dataset idx. Each cell reports the mean and standard deviation of the model performance, along with the total time spent, including Optuna-based hyper-parameter search and final model training.

For the mapping from dataset idx to the actual dataset object, please see the documentation.
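For post-processing the leaderboard, each cell can be split into its three fields: mean, standard deviation, and total time in seconds. The following is a minimal sketch in shell. The `parse_cell` helper is hypothetical. Mapping `OOM` and `Too slow*` cells to `NA` is an assumption about how to treat missing results.

```shell
#!/usr/bin/env bash
# Hypothetical helper: split a leaderboard cell such as "0.931±0.000 (41s)"
# into "mean std seconds". Cells without a "±" (e.g. "OOM", "Too slow*")
# are reported as "NA NA NA" (an assumption, not part of the benchmark).
parse_cell() {
  local cell=$1 mean rest std secs
  case $cell in
    *±*) ;;
    *) echo "NA NA NA"; return ;;
  esac
  mean=${cell%%±*}      # text before "±"
  rest=${cell#*±}       # e.g. "0.000 (41s)"
  std=${rest%% *}       # text before the first space
  secs=${rest#*"("}     # e.g. "41s)"
  secs=${secs%")"}      # drop the closing parenthesis
  secs=${secs%s}        # drop the trailing "s" unit, if present
  echo "$mean $std $secs"
}

parse_cell "0.931±0.000 (41s)"   # prints: 0.931 0.000 41
```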

task_type: binary_classification

Metric: ROC-AUC, higher is better.

scale: small

Experimental setting: 20 Optuna search trials. 50 epochs of training.

dataset_0 dataset_1 dataset_2 dataset_3 dataset_4 dataset_5 dataset_6 dataset_7 dataset_8 dataset_9 dataset_10 dataset_11 dataset_12 dataset_13
XGBoost 0.931±0.000 (41s) 1.000±0.000 (3s) 0.940±0.000 (389s) 0.947±0.000 (42s) 0.885±0.000 (109s) 0.966±0.000 (14s) 0.862±0.000 (10s) 0.779±0.000 (79s) 0.984±0.000 (376s) 0.714±0.000 (10s) 0.787±0.000 (9s) 0.951±0.000 (103s) 0.999±0.000 (434s) 0.925±0.000 (848s)
CatBoost 0.930±0.000 (152s) 1.000±0.000 (9s) 0.938±0.000 (164s) 0.924±0.000 (29s) 0.881±0.000 (27s) 0.963±0.000 (48s) 0.861±0.000 (12s) 0.772±0.000 (10s) 0.930±0.000 (91s) 0.628±0.000 (10s) 0.796±0.000 (15s) 0.948±0.000 (46s) 0.998±0.000 (38s) 0.926±0.000 (115s)
LightGBM 0.931±0.000 (15s) 0.999±0.000 (1s) 0.943±0.000 (23s) 0.943±0.000 (14s) 0.887±0.000 (5s) 0.972±0.000 (11s) 0.862±0.000 (6s) 0.774±0.000 (3s) 0.979±0.000 (41s) 0.732±0.000 (13s) 0.787±0.000 (3s) 0.951±0.000 (13s) 0.999±0.000 (10s) 0.927±0.000 (24s)
Trompt 0.919±0.000 (9627s) 1.000±0.000 (5341s) 0.945±0.000 (14679s) 0.942±0.001 (2752s) 0.881±0.001 (2640s) 0.964±0.001 (5173s) 0.855±0.002 (4249s) 0.778±0.002 (8789s) 0.933±0.001 (9353s) 0.686±0.008 (3105s) 0.793±0.002 (8255s) 0.952±0.001 (4876s) 1.000±0.000 (3558s) 0.916±0.001 (30002s)
ResNet 0.917±0.000 (615s) 1.000±0.000 (71s) 0.937±0.001 (787s) 0.938±0.002 (230s) 0.865±0.001 (183s) 0.960±0.001 (349s) 0.828±0.001 (248s) 0.768±0.002 (205s) 0.925±0.002 (958s) 0.665±0.006 (140s) 0.794±0.002 (76s) 0.946±0.002 (145s) 1.000±0.000 (93s) 0.911±0.001 (880s)
MLP 0.913±0.001 (112s) 1.000±0.000 (45s) 0.934±0.001 (274s) 0.938±0.001 (66s) 0.863±0.002 (61s) 0.953±0.000 (92s) 0.830±0.001 (68s) 0.769±0.002 (56s) 0.903±0.002 (159s) 0.666±0.015 (58s) 0.789±0.001 (48s) 0.940±0.002 (107s) 1.000±0.000 (48s) 0.910±0.001 (149s)
FTTransformerBucket 0.915±0.001 (690s) 0.999±0.001 (354s) 0.936±0.002 (1705s) 0.939±0.002 (484s) 0.876±0.002 (321s) 0.960±0.001 (746s) 0.857±0.000 (549s) 0.771±0.003 (654s) 0.909±0.002 (1177s) 0.636±0.012 (244s) 0.788±0.002 (710s) 0.950±0.001 (510s) 0.999±0.000 (634s) 0.913±0.001 (1164s)
ExcelFormer 0.918±0.001 (1587s) 1.000±0.000 (634s) 0.939±0.001 (1827s) 0.939±0.002 (378s) 0.883±0.001 (289s) 0.969±0.000 (678s) 0.833±0.011 (435s) 0.780±0.002 (938s) 0.940±0.003 (919s) 0.670±0.017 (464s) 0.794±0.003 (683s) 0.950±0.001 (405s) 0.999±0.000 (1169s) 0.919±0.001 (1798s)
FTTransformer 0.918±0.001 (871s) 1.000±0.000 (571s) 0.940±0.001 (1371s) 0.936±0.001 (458s) 0.874±0.002 (200s) 0.959±0.001 (622s) 0.828±0.001 (339s) 0.773±0.002 (521s) 0.909±0.002 (1488s) 0.635±0.011 (392s) 0.790±0.001 (556s) 0.949±0.002 (374s) 1.000±0.000 (713s) 0.912±0.000 (1855s)
TabNet 0.911±0.001 (150s) 1.000±0.000 (35s) 0.931±0.005 (254s) 0.937±0.003 (125s) 0.864±0.002 (52s) 0.944±0.001 (116s) 0.828±0.001 (79s) 0.771±0.005 (93s) 0.913±0.005 (177s) 0.606±0.014 (65s) 0.790±0.003 (41s) 0.936±0.003 (104s) 1.000±0.000 (64s) 0.910±0.001 (294s)
TabTransformer 0.910±0.001 (2044s) 1.000±0.000 (1321s) 0.928±0.001 (2519s) 0.918±0.003 (134s) 0.829±0.002 (64s) 0.928±0.001 (105s) 0.816±0.002 (99s) 0.757±0.003 (645s) 0.885±0.001 (1167s) 0.652±0.006 (282s) 0.780±0.002 (112s) 0.937±0.001 (117s) 0.996±0.000 (76s) 0.905±0.001 (2283s)

scale: medium

Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs of training for deep learning models.

dataset_0 dataset_1 dataset_2 dataset_3 dataset_4 dataset_5 dataset_6 dataset_7 dataset_8
XGBoost 0.594±0.000 (466s) 0.955±0.000 (6340s) 0.653±0.000 (19s) 0.986±0.000 (195s) 0.721±0.000 (62s) 0.998±0.000 (70626s) 0.868±0.000 (159s) 0.888±0.000 (2945s) 0.803±0.000 (371s)
CatBoost 0.631±0.000 (1201s) 0.956±0.000 (2963s) 0.649±0.000 (26s) 0.986±0.000 (352s) 0.719±0.000 (244s) 0.987±0.000 (2561s) 0.863±0.000 (212s) 0.896±0.000 (740s) 0.803±0.000 (140s)
LightGBM 0.639±0.000 (49s) 0.955±0.000 (126s) 0.652±0.000 (7s) 0.986±0.000 (99s) 0.723±0.000 (16s) 0.997±0.000 (172s) 0.881±0.000 (83s) 0.914±0.000 (86s) 0.809±0.000 (76s)
Trompt OOM 0.950±0.000 (28212s) 0.652±0.000 (5962s) 0.982±0.000 (19936s) 0.716±0.000 (7110s) 0.966±0.000 (106916s) 0.882±0.000 (13644s) 0.883±0.000 (17863s) 0.705±0.006 (11563s)
ResNet 0.637±0.000 (810s) 0.948±0.000 (1051s) 0.649±0.000 (185s) 0.983±0.000 (239s) 0.705±0.001 (226s) 0.989±0.000 (1967s) 0.871±0.001 (173s) 0.890±0.001 (315s) 0.719±0.001 (245s)
MLP 0.634±0.002 (392s) 0.946±0.001 (2306s) 0.650±0.000 (263s) 0.978±0.000 (468s) 0.699±0.001 (357s) 0.991±0.000 (2491s) 0.869±0.001 (449s) 0.883±0.001 (695s) 0.727±0.002 (368s)
FTTransformerBucket 0.637±0.000 (8032s) 0.947±0.000 (6571s) 0.649±0.001 (714s) 0.986±0.000 (2138s) 0.651±0.060 (1473s) 0.832±0.153 (8248s) 0.866±0.001 (1531s) 0.877±0.000 (2960s) 0.688±0.001 (1983s)
ExcelFormer OOM 0.948±0.000 (6278s) 0.651±0.001 (602s) 0.982±0.000 (2691s) 0.716±0.001 (1263s) 0.995±0.002 (20487s) 0.879±0.001 (2541s) 0.883±0.002 (2983s) 0.814±0.001 (3040s)
FTTransformer 0.632±0.001 (7669s) 0.946±0.001 (4613s) 0.652±0.000 (587s) 0.981±0.000 (3048s) 0.704±0.001 (980s) 0.984±0.001 (15615s) 0.871±0.002 (1424s) 0.878±0.002 (2933s) 0.713±0.001 (1656s)
TabNet 0.628±0.002 (282s) 0.945±0.001 (338s) 0.650±0.001 (91s) 0.977±0.000 (95s) 0.706±0.001 (62s) 0.993±0.000 (671s) 0.862±0.001 (80s) 0.889±0.000 (120s) 0.797±0.001 (93s)
TabTransformer 0.634±0.001 (10791s) 0.942±0.001 (10599s) 0.642±0.000 (128s) 0.980±0.000 (130s) 0.698±0.002 (742s) 0.968±0.002 (24409s) 0.867±0.000 (94s) 0.873±0.000 (1132s) 0.788±0.000 (147s)

scale: large

Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.

dataset_0
XGBoost 0.792±0.000 (28889s)
CatBoost 0.788±0.000 (240s)
LightGBM 0.831±0.000 (167s)
Trompt 0.822±0.002 (15418s)
ResNet 0.831±0.001 (764s)
MLP 0.824±0.001 (220s)
FTTransformerBucket 0.825±0.001 (5387s)
ExcelFormer 0.842±0.002 (5264s)
FTTransformer 0.835±0.001 (5072s)
TabNet 0.837±0.001 (404s)
TabTransformer 0.790±0.002 (457s)

task_type: regression

Metric: RMSE, lower is better.

scale: small

Experimental setting: 20 Optuna search trials. 50 epochs of training.

dataset_0 dataset_1 dataset_2 dataset_3 dataset_4 dataset_5 dataset_6 dataset_7 dataset_8 dataset_9 dataset_10 dataset_11 dataset_12
XGBoost 0.250±0.000 (22s) 0.038±0.000 (1011s) 0.187±0.000 (19s) 0.475±0.000 (439s) 0.328±0.000 (32s) 0.401±0.000 (375s) 0.249±0.000 (340s) 0.363±0.000 (378s) 0.904±0.000 (2400s) 0.056±0.000 (250s) 0.820±0.000 (721s) 0.857±0.000 (487s) 0.418±0.000 (46s)
CatBoost 0.265±0.000 (116s) 0.062±0.000 (129s) 0.128±0.000 (97s) 0.336±0.000 (103s) 0.346±0.000 (110s) 0.443±0.000 (97s) 0.375±0.000 (46s) 0.273±0.000 (693s) 0.881±0.000 (660s) 0.040±0.000 (80s) 0.756±0.000 (44s) 0.876±0.000 (110s) 0.439±0.000 (101s)
LightGBM 0.253±0.000 (38s) 0.054±0.000 (24s) 0.112±0.000 (10s) 0.302±0.000 (30s) 0.325±0.000 (30s) 0.384±0.000 (23s) 0.295±0.000 (15s) 0.272±0.000 (26s) 0.877±0.000 (16s) 0.011±0.000 (12s) 0.702±0.000 (13s) 0.863±0.000 (5s) 0.395±0.000 (40s)
Trompt 0.261±0.003 (8390s) 0.015±0.005 (3792s) 0.118±0.001 (3836s) 0.262±0.001 (10037s) 0.323±0.001 (9255s) 0.418±0.003 (9071s) 0.329±0.009 (2977s) 0.312±0.002 (21967s) OOM 0.008±0.001 (1889s) 0.779±0.006 (775s) 0.874±0.004 (3723s) 0.424±0.005 (3185s)
ResNet 0.288±0.006 (220s) 0.018±0.003 (187s) 0.124±0.001 (135s) 0.268±0.001 (330s) 0.335±0.001 (471s) 0.434±0.004 (345s) 0.325±0.012 (178s) 0.324±0.004 (365s) 0.895±0.005 (142s) 0.036±0.002 (172s) 0.794±0.006 (120s) 0.875±0.004 (122s) 0.468±0.004 (303s)
MLP 0.300±0.002 (108s) 0.141±0.015 (76s) 0.125±0.001 (44s) 0.272±0.002 (69s) 0.348±0.001 (103s) 0.435±0.002 (33s) 0.331±0.008 (43s) 0.380±0.004 (125s) 0.893±0.002 (69s) 0.017±0.001 (48s) 0.784±0.007 (29s) 0.881±0.005 (30s) 0.467±0.003 (92s)
FTTransformerBucket 0.325±0.008 (619s) 0.096±0.005 (290s) 0.360±0.354 (332s) 0.284±0.005 (768s) 0.342±0.004 (757s) 0.441±0.003 (835s) 0.345±0.007 (191s) 0.339±0.003 (3321s) OOM 0.105±0.011 (199s) 0.807±0.010 (156s) 0.885±0.008 (820s) 0.468±0.006 (706s)
ExcelFormer 0.262±0.004 (770s) 0.099±0.003 (490s) 0.128±0.000 (362s) 0.264±0.003 (796s) 0.331±0.003 (1121s) 0.411±0.005 (469s) 0.298±0.012 (222s) 0.308±0.007 (5522s) OOM 0.011±0.001 (227s) 0.785±0.011 (314s) 0.890±0.003 (1186s) 0.431±0.006 (682s)
FTTransformer 0.335±0.010 (338s) 0.161±0.022 (370s) 0.140±0.002 (244s) 0.277±0.004 (516s) 0.335±0.003 (973s) 0.445±0.003 (599s) 0.361±0.018 (286s) 0.345±0.005 (2443s) OOM 0.106±0.012 (150s) 0.826±0.005 (121s) 0.896±0.007 (832s) 0.461±0.003 (647s)
TabNet 0.279±0.003 (68s) 0.224±0.016 (53s) 0.141±0.010 (34s) 0.275±0.002 (61s) 0.348±0.003 (110s) 0.451±0.007 (82s) 0.355±0.030 (49s) 0.332±0.004 (168s) 0.992±0.182 (53s) 0.015±0.002 (57s) 0.805±0.014 (27s) 0.885±0.013 (46s) 0.544±0.011 (112s)
TabTransformer 0.624±0.003 (1225s) 0.229±0.003 (1200s) 0.369±0.005 (52s) 0.340±0.004 (163s) 0.388±0.002 (1137s) 0.539±0.003 (100s) 0.619±0.005 (73s) 0.351±0.001 (125s) 0.893±0.005 (389s) 0.431±0.001 (489s) 0.819±0.002 (52s) 0.886±0.005 (46s) 0.545±0.004 (95s)

scale: medium

Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna search trials and 25 epochs of training for deep learning models.

dataset_0 dataset_1 dataset_2 dataset_3 dataset_4 dataset_5
XGBoost 0.663±0.000 (18528s) 0.014±0.000 (380s) 0.089±0.000 (2441s) 0.140±0.000 (1632s) 0.539±0.000 (22047s) 0.900±0.000 (1420s)
CatBoost 0.669±0.000 (2037s) 0.018±0.000 (649s) 0.092±0.000 (391s) 0.145±0.000 (271s) 0.549±0.000 (1347s) 0.898±0.000 (122s)
LightGBM 0.660±0.000 (199s) 0.015±0.000 (86s) 0.085±0.000 (39s) 0.141±0.000 (35s) 0.524±0.000 (148s) 0.895±0.000 (7s)
Trompt OOM 0.014±0.000 (19976s) 0.092±0.001 (4060s) 0.140±0.000 (3487s) 0.537±0.000 (26520s) 0.901±0.000 (2333s)
ResNet 0.676±0.000 (894s) 0.016±0.000 (548s) 0.101±0.001 (176s) 0.147±0.000 (503s) 0.555±0.003 (1121s) 0.903±0.000 (116s)
MLP 0.680±0.001 (907s) 0.016±0.000 (1015s) 0.105±0.000 (254s) 0.140±0.000 (313s) 0.558±0.001 (1756s) 0.905±0.001 (240s)
FTTransformerBucket 0.738±0.029 (17223s) 0.023±0.000 (2573s) 0.113±0.002 (645s) 0.147±0.000 (970s) 0.545±0.000 (3009s) 0.908±0.000 (360s)
ExcelFormer 0.667±0.000 (35946s) 0.015±0.001 (2677s) 0.090±0.001 (603s) 0.142±0.000 (1162s) 0.526±0.001 (2403s) 0.901±0.003 (330s)
FTTransformer 0.673±0.000 (18524s) 0.056±0.003 (3348s) 0.119±0.003 (396s) 0.141±0.000 (1049s) 0.561±0.001 (2403s) 0.907±0.002 (302s)
TabNet 0.683±0.001 (521s) 0.024±0.001 (437s) 0.115±0.003 (72s) 0.140±0.000 (319s) 0.549±0.001 (760s) 0.899±0.001 (37s)
TabTransformer OOM 0.799±0.000 (2829s) 0.148±0.000 (720s) 0.708±0.000 (182s) 0.755±0.000 (4008s) 0.964±0.000 (599s)

scale: large

Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.

dataset_0 dataset_1 dataset_2 dataset_3 dataset_4 dataset_5
XGBoost 0.966±0.000 (19327s)
CatBoost 0.971±0.000 (223s)
LightGBM 0.965±0.000 (67s)
Trompt 0.970±0.000 (12358s) Too slow* Too slow* Too slow* 0.796±0.001 (157380s) 0.799±0.001 (55577s)
ResNet 0.969±0.000 (5800s) 0.976±0.000 (57298s) 0.655±0.000 (4612s) 0.880±0.000 (20101s) 0.757±0.000 (53493s) 0.772±0.001 (3906s)
MLP 0.973±0.000 (223s) 0.980±0.000 (1104s) 0.681±0.002 (91s) 0.908±0.001 (529s) 0.792±0.001 (680s) 0.810±0.002 (117s)
FTTransformerBucket 0.970±0.000 (2071s) 0.977±0.000 (10953s) 0.776±0.004 (1608s) 0.896±0.002 (33883s) 0.785±0.002 (145035s) 0.808±0.003 (10407s)
ExcelFormer 0.969±0.000 (1785s) 0.978±0.001 (47745s) 0.649±0.000 (15888s) Too slow* Too slow* 0.782±0.001 (14729s)
FTTransformer 0.969±0.000 (25604s) 0.975±0.000 (219716s) 0.670±0.004 (17119s) 0.888±0.002 (323939s) Too slow* 0.788±0.001 (103434s)
TabNet 0.968±0.001 (3177s) 0.974±0.000 (44035s) 0.655±0.001 (2679s) 0.891±0.001 (6800s) 0.767±0.001 (6242s) 0.784±0.001 (2307s)
TabTransformer 0.984±0.000 (2813s) 0.986±0.000 (13457s) 0.713±0.002 (3767s) 0.895±0.001 (5679s) 0.764±0.001 (3763s) 0.781±0.000 (3540s)

task_type: multiclass_classification

Metric: Accuracy, higher is better.

scale: medium

Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs of training for deep learning models.

*: Too slow: a single trial takes more than a day.

dataset_0 dataset_1 dataset_2
XGBoost Too slow* Too slow* Too slow*
CatBoost Too slow* Too slow* Too slow*
LightGBM Too slow* Too slow* Too slow*
Trompt OOM 0.373±0.004 (9114s) OOM
ResNet 0.951±0.000 (419s) 0.378±0.001 (171s) 0.723±0.001 (257s)
MLP 0.947±0.001 (1133s) 0.371±0.002 (462s) 0.723±0.002 (495s)
FTTransformerBucket 0.879±0.006 (9104s) 0.365±0.002 (1067s) 0.722±0.001 (2366s)
ExcelFormer OOM 0.378±0.001 (1790s) 0.734±0.002 (3024s)
FTTransformer 0.923±0.003 (14517s) 0.357±0.001 (754s) 0.724±0.004 (2621s)
TabNet 0.934±0.005 (218s) 0.349±0.001 (64s) 0.716±0.001 (153s)
TabTransformer 0.950±0.000 (160s) 0.352±0.001 (98s) 0.705±0.003 (103s)

scale: large

Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.

dataset_0 dataset_1 dataset_2
XGBoost Too slow* Too slow* Too slow*
CatBoost Too slow* Too slow* Too slow*
LightGBM Too slow* Too slow* Too slow*
Trompt OOM 0.889±0.063 (55428s) 0.804±0.013 (23304s)
ResNet 0.892±0.002 (417s) 0.999±0.001 (396s) 0.915±0.001 (405s)
MLP 0.770±0.001 (170s) 0.549±0.000 (223s) 0.895±0.001 (192s)
FTTransformerBucket 0.897±0.004 (4436s) 0.502±0.000 (1892s) 0.888±0.009 (4414s)
ExcelFormer OOM 0.999±0.001 (13952s) 0.951±0.002 (10236s)
FTTransformer 0.872±0.005 (7004s) 0.540±0.068 (3355s) 0.908±0.004 (7514s)
TabNet 0.912±0.004 (219s) 0.995±0.001 (301s) 0.919±0.003 (187s)
TabTransformer 0.843±0.003 (2810s) 0.657±0.187 (2843s) 0.854±0.001 (284s)

Benchmarking pytorch-frame and pytorch-tabular

pytorch_tabular_benchmark compares the performance of pytorch-frame to pytorch-tabular. pytorch-tabular excels at providing an accessible approach to standard tabular tasks, allowing users to quickly implement and experiment with existing tabular learning models. It also stands out for its training loop modifications and explainability features. pytorch-frame, on the other hand, offers greater flexibility for exploring and building novel tabular learning approaches while still providing access to established models. It distinguishes itself through support for a wider array of data types, more sophisticated encoding schemes, and streamlined integration with LLMs. The following table compares the training speed of pytorch-frame and pytorch-tabular on their implementations of TabNet and FTTransformer.

Package Model Num iters/sec
PyTorch Tabular TabNet 41.7
PyTorch Frame TabNet 45.0
PyTorch Tabular FTTransformer 40.1
PyTorch Frame FTTransformer 43.7
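To express the throughput numbers in relative terms, the following sketch computes PyTorch Frame's speedup over PyTorch Tabular from the iterations/sec column of the table above. The `speedup` function name is hypothetical. The inputs are taken directly from the table.

```shell
#!/usr/bin/env bash
# Hypothetical helper: percentage speedup of PyTorch Frame over PyTorch Tabular.
#   $1 = PyTorch Frame iterations/sec, $2 = PyTorch Tabular iterations/sec
speedup() {
  awk -v frame="$1" -v tabular="$2" \
    'BEGIN { printf "%.1f%%\n", (frame / tabular - 1) * 100 }'
}

echo "TabNet:        $(speedup 45.0 41.7)"
echo "FTTransformer: $(speedup 43.7 40.1)"
```

With the reported figures, this gives about 7.9% higher throughput for TabNet and about 9.0% for FTTransformer.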