diff --git a/src/flash/core/integrations/pytorch_tabular/adapter.py b/src/flash/core/integrations/pytorch_tabular/adapter.py index b9aca1f243..e68b68d35b 100644 --- a/src/flash/core/integrations/pytorch_tabular/adapter.py +++ b/src/flash/core/integrations/pytorch_tabular/adapter.py @@ -50,6 +50,7 @@ def from_task( "categorical_dim": len(categorical_fields), "continuous_dim": num_features - len(categorical_fields), "output_dim": output_dim, + "embedded_cat_dim": sum([embd_dim for _, embd_dim in embedding_sizes]), } return cls( task_type, diff --git a/src/flash/core/integrations/pytorch_tabular/backbones.py b/src/flash/core/integrations/pytorch_tabular/backbones.py index 72084ae0d8..531d987709 100644 --- a/src/flash/core/integrations/pytorch_tabular/backbones.py +++ b/src/flash/core/integrations/pytorch_tabular/backbones.py @@ -30,6 +30,7 @@ AutoIntConfig, CategoryEmbeddingModelConfig, FTTransformerConfig, + GatedAdditiveTreeEnsembleConfig, NodeConfig, TabNetModelConfig, TabTransformerConfig, @@ -88,8 +89,9 @@ def load_pytorch_tabular( AutoIntConfig, NodeConfig, CategoryEmbeddingModelConfig, + GatedAdditiveTreeEnsembleConfig, ], - ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"], + ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding", "gate"], ): PYTORCH_TABULAR_BACKBONES( functools.partial(load_pytorch_tabular, model_config_class), diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index f52f045830..074f9564dd 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -39,8 +39,7 @@ ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - # ("category_embedding", # todo: seems to be bug in tabular - # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 008a797a99..a42d07b422 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -55,7 +55,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ], @@ -68,7 +68,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ], @@ -81,7 +81,7 @@ class TestTabularClassifier(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ], diff --git a/tests/tabular/regression/test_data_model_integration.py b/tests/tabular/regression/test_data_model_integration.py index 0a01bac532..9aff8a806e 100644 --- a/tests/tabular/regression/test_data_model_integration.py +++ b/tests/tabular/regression/test_data_model_integration.py @@ -48,8 +48,7 @@ ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - # ("category_embedding", # todo: seems to be bug in tabular - # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), @@ -82,8 +81,7 @@ def test_regression_data_frame(backbone, fields, tmpdir): ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - # ("category_embedding", # todo: seems to be bug in tabular - # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), @@ -113,8 +111,7 @@ def test_regression_dicts(backbone, fields, tmpdir): ("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), ("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), - # ("category_embedding", # todo: seems to be bug in tabular - # {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), + ("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}), # No categorical / numerical fields ("tabnet", {"categorical_fields": ["category"]}), ("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}), diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index 52dd36ed9b..31674b2847 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -53,7 +53,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ], @@ -66,7 +66,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ], @@ -79,7 +79,7 @@ class TestTabularRegressor(TaskTester): {"backbone": "fttransformer"}, {"backbone": "autoint"}, {"backbone": "node"}, - # {"backbone": "category_embedding"}, # todo: seems to be bug in tabular + {"backbone": "category_embedding"}, ], ) ],