refactor: remove deep_transformer module #92

Merged
merged 1 commit into from
Mar 9, 2023
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -112,13 +112,13 @@ repos:
args: ["--profile=black"]

- repo: https://github.com/PyCQA/pylint
-    rev: v2.16.2
+    rev: v2.17.0
hooks:
- id: pylint
args: ["--rcfile=pyproject.toml"]

- repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.1
+    rev: v1.1.1
hooks:
- id: mypy
args:
@@ -139,7 +139,7 @@ repos:
- id: nbstripout

- repo: https://github.com/python-poetry/poetry
-    rev: 1.3.0
+    rev: 1.4.0
hooks:
- id: poetry-check
- id: poetry-lock
1 change: 0 additions & 1 deletion docs/API-reference/transformer/deep_transformer.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/README.md
@@ -51,7 +51,6 @@ poetry install
| ------ | ----------- | ----------- |
|[`Datetime transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/)|[`DateColumnsTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/#sk_transformers.datetime_transformer.DateColumnsTransformer)|Splits a date column into multiple columns.|
|[`Datetime transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/)|[`DurationCalculatorTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/datetime_transformer/#sk_transformers.datetime_transformer.DurationCalculatorTransformer)|Calculates the duration between to given dates.|
-|[`Deep transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/)|[`ToVecTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/#sk_transformers.deep_transformer.ToVecTransformer)|This transformer trains an [FT-Transformer](https://paperswithcode.com/method/ft-transformer) using the [pytorch-widedeep package](https://github.com/jrzaurin/pytorch-widedeep) and extracts the embeddings from its embedding layer.|
|[`Encoder transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/)|[`MeanEncoderTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/encoder_transformer/#sk_transformers.encoder_transformer.MeanEncoderTransformer)|Scikit-learn API for the [feature-engine MeanEncoder](https://feature-engine.readthedocs.io/en/latest/api_doc/encoding/MeanEncoder.html).|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AggregateTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AggregateTransformer)|This transformer uses Pandas groupby method and aggregate to apply function on a column grouped by another column.|
|[`Generic transformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/)|[`AllowedValuesTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/generic_transformer/#sk_transformers.generic_transformer.AllowedValuesTransformer)|This transformer replaces values that are *not* in a list with another value.|
53 changes: 0 additions & 53 deletions examples/playground.ipynb
@@ -116,59 +116,6 @@
"transformer.fit_transform(X).to_numpy()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Deep transformer](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### [`ToVecTransformer`](https://chrislemke.github.io/sk-transformers/API-reference/transformer/deep_transformer/#sk_transformers.deep_transformer.ToVecTransformer)\n",
"\n",
"This transformer trains an [FT-Transformer](https://paperswithcode.com/method/ft-transformer)\n",
"using the [pytorch-widedeep package](https://github.com/jrzaurin/pytorch-widedeep) and extracts the embeddings\n",
"from its embedding layer. The output shape of the transformer is (number of rows,(`input_dim` * number of columns)).\n",
"Please refer to [this example](https://pytorch-widedeep.readthedocs.io/en/latest/examples/09_extracting_embeddings.html)\n",
"for pytorch_widedeep example on how to extract embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from pytorch_widedeep.datasets import load_adult\n",
"from sk_transformers import ToVecTransformer\n",
"\n",
"df = load_adult(as_frame=True)\n",
"df[\"target\"] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
"df = df.drop([\"income\", \"educational-num\"], axis=1)\n",
"\n",
"cat_cols, cont_cols = [], []\n",
"for col in df.columns:\n",
" if df[col].dtype == \"O\" or df[col].nunique() < 50 and col != \"target\":\n",
" cat_cols.append(col)\n",
" elif col != \"target\":\n",
" cont_cols.append(col)\n",
"\n",
"target_col = \"target\"\n",
"target = df[target_col].to_numpy()\n",
"\n",
"transformer = ToVecTransformer(\n",
" cat_cols, cont_cols, verbose=0, training_objective=\"binary\"\n",
")\n",
"transformer.fit_transform(df, target).shape"
]
},
{
"attachments": {},
"cell_type": "markdown",
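For reference, the column split in the removed `ToVecTransformer` notebook cell relies on `and` binding tighter than `or`, so its condition reads as `dtype == "O" or (nunique() < 50 and col != "target")`. The sketch below restates that heuristic with the target check made explicit. It is a standard-library illustration only, not code from this repository: the function name `split_columns`, the `max_categories` parameter, and the sample data are all hypothetical.

```python
from typing import Any


def split_columns(
    columns: dict[str, list[Any]],
    target: str,
    max_categories: int = 50,
) -> tuple[list[str], list[str]]:
    """Split column names into (categorical, continuous), skipping the target.

    Hypothetical helper: a column counts as categorical when it holds strings
    or has fewer than `max_categories` distinct values.
    """
    cat_cols: list[str] = []
    cont_cols: list[str] = []
    for name, values in columns.items():
        if name == target:
            continue  # the target column is never treated as a feature
        holds_text = any(isinstance(value, str) for value in values)
        if holds_text or len(set(values)) < max_categories:
            cat_cols.append(name)
        else:
            cont_cols.append(name)
    return cat_cols, cont_cols


# Made-up sample data with 60 rows per column.
data = {
    "workclass": ["Private", "State-gov"] * 30,    # text values
    "age": list(range(20, 80)),                    # 60 distinct numbers
    "hours_bucket": [i % 3 for i in range(60)],    # 3 distinct numbers
    "target": [i % 2 for i in range(60)],
}

cat_cols, cont_cols = split_columns(data, target="target")
```

Skipping the target before any other check keeps it out of both lists even if it were stored as text, which the original condition would not guarantee.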