First install additional dependencies:
pip install optuna
pip install torchmetrics
pip install xgboost
pip install catboost
pip install lightgbm
Then run:
# Specify the model from [TabNet, FTTransformer, ResNet, MLP, TabTransformer,
# Trompt, ExcelFormer, FTTransformerBucket, XGBoost, CatBoost, LightGBM]
model_type=TabNet
# Specify the task type from [binary_classification, regression,
# multiclass_classification]
task_type=binary_classification
# Specify the dataset scale from [small, medium, large]
scale=small
# Specify the dataset idx from [0, 1, ...]
idx=0
# Specify the number of AutoML search trials
num_trials=20
# Specify the path to save the results
result_path=results.pt
# Run hyper-parameter tuning and training of the specified model on a specified
# dataset.
python data_frame_benchmark.py --model_type $model_type \
    --task_type $task_type \
    --scale $scale \
    --idx $idx \
    --num_trials $num_trials \
    --result_path $result_path
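To sweep several models or datasets, the same command can be assembled programmatically. The sketch below uses only the flags shown above; the helper name `build_benchmark_command` and the per-model result file names are hypothetical and not part of the repository.

```python
import sys


def build_benchmark_command(model_type, task_type, scale, idx, num_trials,
                            result_path, script="data_frame_benchmark.py"):
    """Return the argv list for one benchmark run (hypothetical helper)."""
    return [
        sys.executable, script,
        "--model_type", model_type,
        "--task_type", task_type,
        "--scale", scale,
        "--idx", str(idx),
        "--num_trials", str(num_trials),
        "--result_path", result_path,
    ]


# Build one command per model, writing each result to its own file.
commands = [
    build_benchmark_command(model, "binary_classification", "small", 0, 20,
                            f"results_{model}.pt")
    for model in ["TabNet", "XGBoost", "LightGBM"]
]

for cmd in commands:
    print(" ".join(cmd))
    # To actually launch a run, uncomment the next two lines:
    # import subprocess
    # subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues, and giving each run its own result path keeps one run from overwriting another.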
We show the current model performance across different datasets.
Each row is a model and each column is a dataset `idx`.
Each cell reports the mean and standard deviation of the model performance, followed in parentheses by
the total time spent, including Optuna-based hyper-parameter search and final model training.
For the mapping from dataset `idx` to the actual dataset object, please see the documentation.
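For readers who want to process the tables programmatically, a cell can be split into its three parts. This is a minimal sketch: the regular expression is inferred from the cell layout used in this document (for example `0.931±0.000 (41s)`), not taken from the benchmark code, and `parse_cell` is a hypothetical helper name.

```python
import re

# Matches cells such as "0.931±0.000 (41s)". The pattern is an assumption
# based on the tables in this document.
CELL_PATTERN = re.compile(r"^(\d+\.\d+)±(\d+\.\d+) \((\d+)s?\)$")


def parse_cell(cell):
    """Return (mean, std, seconds), or None for cells without a score."""
    match = CELL_PATTERN.match(cell.strip())
    if match is None:
        # Covers markers such as "OOM" and "Too slow*", and empty cells.
        return None
    mean, std, seconds = match.groups()
    return float(mean), float(std), int(seconds)


print(parse_cell("0.931±0.000 (41s)"))  # (0.931, 0.0, 41)
print(parse_cell("OOM"))                # None
```

Returning `None` for non-numeric cells lets callers skip runs that did not finish instead of treating them as zero.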
Metric: ROC-AUC (higher is better).
Experimental setting: 20 Optuna search trials. 50 epochs of training.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.931±0.000 (41s) | 1.000±0.000 (3s) | 0.940±0.000 (389s) | 0.947±0.000 (42s) | 0.885±0.000 (109s) | 0.966±0.000 (14s) | 0.862±0.000 (10s) | 0.779±0.000 (79s) | 0.984±0.000 (376s) | 0.714±0.000 (10s) | 0.787±0.000 (9s) | 0.951±0.000 (103s) | 0.999±0.000 (434s) | 0.925±0.000 (848s) |
CatBoost | 0.930±0.000 (152s) | 1.000±0.000 (9s) | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | 0.796±0.000 (15s) | 0.948±0.000 (46s) | 0.998±0.000 (38s) | 0.926±0.000 (115s) |
LightGBM | 0.931±0.000 (15s) | 0.999±0.000 (1s) | 0.943±0.000 (23s) | 0.943±0.000 (14s) | 0.887±0.000 (5s) | 0.972±0.000 (11s) | 0.862±0.000 (6s) | 0.774±0.000 (3s) | 0.979±0.000 (41s) | 0.732±0.000 (13s) | 0.787±0.000 (3s) | 0.951±0.000 (13s) | 0.999±0.000 (10s) | 0.927±0.000 (24s) |
Trompt | 0.919±0.000 (9627s) | 1.000±0.000 (5341s) | 0.945±0.000 (14679s) | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | 0.952±0.001 (4876s) | 1.000±0.000 (3558s) | 0.916±0.001 (30002s) |
ResNet | 0.917±0.000 (615s) | 1.000±0.000 (71s) | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | 0.794±0.002 (76s) | 0.946±0.002 (145s) | 1.000±0.000 (93s) | 0.911±0.001 (880s) |
MLP | 0.913±0.001 (112s) | 1.000±0.000 (45s) | 0.934±0.001 (274s) | 0.938±0.001 (66s) | 0.863±0.002 (61s) | 0.953±0.000 (92s) | 0.830±0.001 (68s) | 0.769±0.002 (56s) | 0.903±0.002 (159s) | 0.666±0.015 (58s) | 0.789±0.001 (48s) | 0.940±0.002 (107s) | 1.000±0.000 (48s) | 0.910±0.001 (149s) |
FTTransformerBucket | 0.915±0.001 (690s) | 0.999±0.001 (354s) | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | 0.999±0.000 (634s) | 0.913±0.001 (1164s) |
ExcelFormer | 0.918±0.001 (1587s) | 1.000±0.000 (634s) | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.883±0.001 (289s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | 0.780±0.002 (938s) | 0.940±0.003 (919s) | 0.670±0.017 (464s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | 0.999±0.000 (1169s) | 0.919±0.001 (1798s) |
FTTransformer | 0.918±0.001 (871s) | 1.000±0.000 (571s) | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | 1.000±0.000 (713s) | 0.912±0.000 (1855s) |
TabNet | 0.911±0.001 (150s) | 1.000±0.000 (35s) | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | 1.000±0.000 (64s) | 0.910±0.001 (294s) |
TabTransformer | 0.910±0.001 (2044s) | 1.000±0.000 (1321s) | 0.928±0.001 (2519s) | 0.918±0.003 (134s) | 0.829±0.002 (64s) | 0.928±0.001 (105s) | 0.816±0.002 (99s) | 0.757±0.003 (645s) | 0.885±0.001 (1167s) | 0.652±0.006 (282s) | 0.780±0.002 (112s) | 0.937±0.001 (117s) | 0.996±0.000 (76s) | 0.905±0.001 (2283s) |
Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs of training for deep learning models.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 |
---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.594±0.000 (466s) | 0.955±0.000 (6340s) | 0.653±0.000 (19s) | 0.986±0.000 (195s) | 0.721±0.000 (62s) | 0.998±0.000 (70626s) | 0.868±0.000 (159s) | 0.888±0.000 (2945s) | 0.803±0.000 (371s) |
CatBoost | 0.631±0.000 (1201s) | 0.956±0.000 (2963s) | 0.649±0.000 (26s) | 0.986±0.000 (352s) | 0.719±0.000 (244s) | 0.987±0.000 (2561s) | 0.863±0.000 (212s) | 0.896±0.000 (740s) | 0.803±0.000 (140s) |
LightGBM | 0.639±0.000 (49s) | 0.955±0.000 (126s) | 0.652±0.000 (7s) | 0.986±0.000 (99s) | 0.723±0.000 (16s) | 0.997±0.000 (172s) | 0.881±0.000 (83s) | 0.914±0.000 (86s) | 0.809±0.000 (76s) |
Trompt | OOM | 0.950±0.000 (28212s) | 0.652±0.000 (5962s) | 0.982±0.000 (19936s) | 0.716±0.000 (7110s) | 0.966±0.000 (106916s) | 0.882±0.000 (13644s) | 0.883±0.000 (17863s) | 0.705±0.006 (11563s) |
ResNet | 0.637±0.000 (810s) | 0.948±0.000 (1051s) | 0.649±0.000 (185s) | 0.983±0.000 (239s) | 0.705±0.001 (226s) | 0.989±0.000 (1967s) | 0.871±0.001 (173s) | 0.890±0.001 (315s) | 0.719±0.001 (245s) |
MLP | 0.634±0.002 (392s) | 0.946±0.001 (2306s) | 0.650±0.000 (263s) | 0.978±0.000 (468s) | 0.699±0.001 (357s) | 0.991±0.000 (2491s) | 0.869±0.001 (449s) | 0.883±0.001 (695s) | 0.727±0.002 (368s) |
FTTransformerBucket | 0.637±0.000 (8032s) | 0.947±0.000 (6571s) | 0.649±0.001 (714s) | 0.986±0.000 (2138s) | 0.651±0.060 (1473s) | 0.832±0.153 (8248s) | 0.866±0.001 (1531s) | 0.877±0.000 (2960s) | 0.688±0.001 (1983s) |
ExcelFormer | OOM | 0.948±0.000 (6278s) | 0.651±0.001 (602s) | 0.982±0.000 (2691s) | 0.716±0.001 (1263s) | 0.995±0.002 (20487s) | 0.879±0.001 (2541s) | 0.883±0.002 (2983s) | 0.814±0.001 (3040s) |
FTTransformer | 0.632±0.001 (7669s) | 0.946±0.001 (4613s) | 0.652±0.000 (587s) | 0.981±0.000 (3048s) | 0.704±0.001 (980s) | 0.984±0.001 (15615s) | 0.871±0.002 (1424s) | 0.878±0.002 (2933s) | 0.713±0.001 (1656s) |
TabNet | 0.628±0.002 (282s) | 0.945±0.001 (338s) | 0.650±0.001 (91s) | 0.977±0.000 (95s) | 0.706±0.001 (62s) | 0.993±0.000 (671s) | 0.862±0.001 (80s) | 0.889±0.000 (120s) | 0.797±0.001 (93s) |
TabTransformer | 0.634±0.001 (10791s) | 0.942±0.001 (10599s) | 0.642±0.000 (128s) | 0.980±0.000 (130s) | 0.698±0.002 (742s) | 0.968±0.002 (24409s) | 0.867±0.000 (94s) | 0.873±0.000 (1132s) | 0.788±0.000 (147s) |
Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.
Model | dataset_0 |
---|---|
XGBoost | 0.792±0.000 (28889s) |
CatBoost | 0.788±0.000 (240s) |
LightGBM | 0.831±0.000 (167s) |
Trompt | 0.822±0.002 (15418s) |
ResNet | 0.831±0.001 (764s) |
MLP | 0.824±0.001 (220s) |
FTTransformerBucket | 0.825±0.001 (5387s) |
ExcelFormer | 0.842±0.002 (5264s) |
FTTransformer | 0.835±0.001 (5072s) |
TabNet | 0.837±0.001 (404s) |
TabTransformer | 0.790±0.002 (457s) |
Metric: RMSE (lower is better).
Experimental setting: 20 Optuna search trials. 50 epochs of training.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.250±0.000 (22s) | 0.038±0.000 (1011s) | 0.187±0.000 (19s) | 0.475±0.000 (439s) | 0.328±0.000 (32s) | 0.401±0.000 (375s) | 0.249±0.000 (340s) | 0.363±0.000 (378s) | 0.904±0.000 (2400s) | 0.056±0.000 (250s) | 0.820±0.000 (721s) | 0.857±0.000 (487s) | 0.418±0.000 (46s) |
CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) |
LightGBM | 0.253±0.000 (38s) | 0.054±0.000 (24s) | 0.112±0.000 (10s) | 0.302±0.000 (30s) | 0.325±0.000 (30s) | 0.384±0.000 (23s) | 0.295±0.000 (15s) | 0.272±0.000 (26s) | 0.877±0.000 (16s) | 0.011±0.000 (12s) | 0.702±0.000 (13s) | 0.863±0.000 (5s) | 0.395±0.000 (40s) |
Trompt | 0.261±0.003 (8390s) | 0.015±0.005 (3792s) | 0.118±0.001 (3836s) | 0.262±0.001 (10037s) | 0.323±0.001 (9255s) | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | 0.008±0.001 (1889s) | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) |
ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) |
MLP | 0.300±0.002 (108s) | 0.141±0.015 (76s) | 0.125±0.001 (44s) | 0.272±0.002 (69s) | 0.348±0.001 (103s) | 0.435±0.002 (33s) | 0.331±0.008 (43s) | 0.380±0.004 (125s) | 0.893±0.002 (69s) | 0.017±0.001 (48s) | 0.784±0.007 (29s) | 0.881±0.005 (30s) | 0.467±0.003 (92s) |
FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) |
ExcelFormer | 0.262±0.004 (770s) | 0.099±0.003 (490s) | 0.128±0.000 (362s) | 0.264±0.003 (796s) | 0.331±0.003 (1121s) | 0.411±0.005 (469s) | 0.298±0.012 (222s) | 0.308±0.007 (5522s) | OOM | 0.011±0.001 (227s) | 0.785±0.011 (314s) | 0.890±0.003 (1186s) | 0.431±0.006 (682s) |
FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) |
TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) |
TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) |
Experimental setting: 20 Optuna search trials for XGBoost and CatBoost. 5 Optuna search trials and 25 epochs of training for deep learning models.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 |
---|---|---|---|---|---|---|
XGBoost | 0.663±0.000 (18528s) | 0.014±0.000 (380s) | 0.089±0.000 (2441s) | 0.140±0.000 (1632s) | 0.539±0.000 (22047s) | 0.900±0.000 (1420s) |
CatBoost | 0.669±0.000 (2037s) | 0.018±0.000 (649s) | 0.092±0.000 (391s) | 0.145±0.000 (271s) | 0.549±0.000 (1347s) | 0.898±0.000 (122s) |
LightGBM | 0.660±0.000 (199s) | 0.015±0.000 (86s) | 0.085±0.000 (39s) | 0.141±0.000 (35s) | 0.524±0.000 (148s) | 0.895±0.000 (7s) |
Trompt | OOM | 0.014±0.000 (19976s) | 0.092±0.001 (4060s) | 0.140±0.000 (3487s) | 0.537±0.000 (26520s) | 0.901±0.000 (2333s) |
ResNet | 0.676±0.000 (894s) | 0.016±0.000 (548s) | 0.101±0.001 (176s) | 0.147±0.000 (503s) | 0.555±0.003 (1121s) | 0.903±0.000 (116s) |
MLP | 0.680±0.001 (907s) | 0.016±0.000 (1015s) | 0.105±0.000 (254s) | 0.140±0.000 (313s) | 0.558±0.001 (1756s) | 0.905±0.001 (240s) |
FTTransformerBucket | 0.738±0.029 (17223s) | 0.023±0.000 (2573s) | 0.113±0.002 (645s) | 0.147±0.000 (970s) | 0.545±0.000 (3009s) | 0.908±0.000 (360s) |
ExcelFormer | 0.667±0.000 (35946s) | 0.015±0.001 (2677s) | 0.090±0.001 (603s) | 0.142±0.000 (1162s) | 0.526±0.001 (2403s) | 0.901±0.003 (330s) |
FTTransformer | 0.673±0.000 (18524s) | 0.056±0.003 (3348s) | 0.119±0.003 (396s) | 0.141±0.000 (1049s) | 0.561±0.001 (2403s) | 0.907±0.002 (302s) |
TabNet | 0.683±0.001 (521s) | 0.024±0.001 (437s) | 0.115±0.003 (72s) | 0.140±0.000 (319s) | 0.549±0.001 (760s) | 0.899±0.001 (37s) |
TabTransformer | OOM | 0.799±0.000 (2829s) | 0.148±0.000 (720s) | 0.708±0.000 (182s) | 0.755±0.000 (4008s) | 0.964±0.000 (599s) |
Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 |
---|---|---|---|---|---|---|
XGBoost | 0.966±0.000 (19327s) | | | | | |
CatBoost | 0.971±0.000 (223s) | | | | | |
LightGBM | 0.965±0.000 (67s) | | | | | |
Trompt | 0.970±0.000 (12358s) | Too slow* | Too slow* | Too slow* | 0.796±0.001 (157380s) | 0.799±0.001 (55577s) |
ResNet | 0.969±0.000 (5800s) | 0.976±0.000 (57298s) | 0.655±0.000 (4612s) | 0.880±0.000 (20101s) | 0.757±0.000 (53493s) | 0.772±0.001 (3906s) |
MLP | 0.973±0.000 (223s) | 0.980±0.000 (1104s) | 0.681±0.002 (91s) | 0.908±0.001 (529s) | 0.792±0.001 (680s) | 0.810±0.002 (117s) |
FTTransformerBucket | 0.970±0.000 (2071s) | 0.977±0.000 (10953s) | 0.776±0.004 (1608s) | 0.896±0.002 (33883s) | 0.785±0.002 (145035s) | 0.808±0.003 (10407s) |
ExcelFormer | 0.969±0.000 (1785s) | 0.978±0.001 (47745s) | 0.649±0.000 (15888s) | Too slow* | Too slow* | 0.782±0.001 (14729s) |
FTTransformer | 0.969±0.000 (25604s) | 0.975±0.000 (219716s) | 0.670±0.004 (17119s) | 0.888±0.002 (323939s) | Too slow* | 0.788±0.001 (103434s) |
TabNet | 0.968±0.001 (3177s) | 0.974±0.000 (44035s) | 0.655±0.001 (2679s) | 0.891±0.001 (6800s) | 0.767±0.001 (6242s) | 0.784±0.001 (2307s) |
TabTransformer | 0.984±0.000 (2813s) | 0.986±0.000 (13457s) | 0.713±0.002 (3767s) | 0.895±0.001 (5679s) | 0.764±0.001 (3763s) | 0.781±0.000 (3540s) |
Metric: Accuracy (higher is better).
Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 5 Optuna search trials and 25 epochs of training for deep learning models.
*: Too slow means that a single trial takes more than a day.
Model | dataset_0 | dataset_1 | dataset_2 |
---|---|---|---|
XGBoost | Too slow* | Too slow* | Too slow* |
CatBoost | Too slow* | Too slow* | Too slow* |
LightGBM | Too slow* | Too slow* | Too slow* |
Trompt | OOM | 0.373±0.004 (9114s) | OOM |
ResNet | 0.951±0.000 (419s) | 0.378±0.001 (171s) | 0.723±0.001 (257s) |
MLP | 0.947±0.001 (1133s) | 0.371±0.002 (462s) | 0.723±0.002 (495s) |
FTTransformerBucket | 0.879±0.006 (9104s) | 0.365±0.002 (1067s) | 0.722±0.001 (2366s) |
ExcelFormer | OOM | 0.378±0.001 (1790s) | 0.734±0.002 (3024s) |
FTTransformer | 0.923±0.003 (14517s) | 0.357±0.001 (754s) | 0.724±0.004 (2621s) |
TabNet | 0.934±0.005 (218s) | 0.349±0.001 (64s) | 0.716±0.001 (153s) |
TabTransformer | 0.950±0.000 (160s) | 0.352±0.001 (98s) | 0.705±0.003 (103s) |
Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM. 3 Optuna search trials and 10 epochs of training for deep learning models.
Model | dataset_0 | dataset_1 | dataset_2 |
---|---|---|---|
XGBoost | Too slow* | Too slow* | Too slow* |
CatBoost | Too slow* | Too slow* | Too slow* |
LightGBM | Too slow* | Too slow* | Too slow* |
Trompt | OOM | 0.889±0.063 (55428s) | 0.804±0.013 (23304s) |
ResNet | 0.892±0.002 (417s) | 0.999±0.001 (396s) | 0.915±0.001 (405s) |
MLP | 0.770±0.001 (170s) | 0.549±0.000 (223s) | 0.895±0.001 (192s) |
FTTransformerBucket | 0.897±0.004 (4436s) | 0.502±0.000 (1892s) | 0.888±0.009 (4414s) |
ExcelFormer | OOM | 0.999±0.001 (13952s) | 0.951±0.002 (10236s) |
FTTransformer | 0.872±0.005 (7004s) | 0.540±0.068 (3355s) | 0.908±0.004 (7514s) |
TabNet | 0.912±0.004 (219s) | 0.995±0.001 (301s) | 0.919±0.003 (187s) |
TabTransformer | 0.843±0.003 (2810s) | 0.657±0.187 (2843s) | 0.854±0.001 (284s) |
`pytorch_tabular_benchmark` compares the performance of `pytorch-frame` to `pytorch-tabular`. `pytorch-tabular` excels at providing an accessible approach for standard tabular tasks, allowing users to quickly implement and experiment with existing tabular learning models. It also stands out for its training loop modifications and explainability features. On the other hand, `pytorch-frame` offers enhanced flexibility for exploring and building novel tabular learning approaches while still providing access to established models. It distinguishes itself through support for a wider array of data types, more sophisticated encoding schemas, and streamlined integration with LLMs.
The following table shows the speed comparison of `pytorch-frame` to `pytorch-tabular` on implementations of `TabNet` and `FTTransformer`.
Package | Model | Num iters/sec |
---|---|---|
PyTorch Tabular | TabNet | 41.7 |
PyTorch Frame | TabNet | 45.0 |
PyTorch Tabular | FTTransformer | 40.1 |
PyTorch Frame | FTTransformer | 43.7 |
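The throughput gap is easier to read as a ratio. The sketch below computes the relative speedup from the numbers in the table; the helper `relative_speedup` is a hypothetical name, and since the source does not state the hardware or batch size, the ratios describe only this table.

```python
# Iterations per second, copied from the table above.
iters_per_sec = {
    "TabNet": {"PyTorch Tabular": 41.7, "PyTorch Frame": 45.0},
    "FTTransformer": {"PyTorch Tabular": 40.1, "PyTorch Frame": 43.7},
}


def relative_speedup(rates, baseline="PyTorch Tabular",
                     candidate="PyTorch Frame"):
    """Return candidate throughput divided by baseline throughput."""
    return rates[candidate] / rates[baseline]


for model, rates in iters_per_sec.items():
    gain = (relative_speedup(rates) - 1.0) * 100.0
    print(f"{model}: PyTorch Frame is about {gain:.1f}% faster")
```

On these figures the difference is in the high single digits in percentage terms for both models.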