Skip to content

Commit

Permalink
Control parallel column transformers for wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Oct 8, 2024
1 parent 88ee8fe commit 731abbd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ Here you can find the comprehensive list of flags that can be set:
- | Disable printing the list of expanded column names in ``X_types``.
| If set to ``True``, the list of types of X will not be printed.
- The user will not see the expanded ``X_types`` column names.
* - ``enable_parallel_column_transformers``
- | This flag enables parallel execution of column transformers by
| reverting to the default behaviour of scikit-learn
| (instead of using ``n_jobs=1``)
| If set to ``True``, the parameter will be set back to None.
- | Column transformers will be applied in parallel, using more resources.
| than expected.
1 change: 1 addition & 0 deletions julearn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_global_config["disable_xtypes_check"] = False
_global_config["disable_x_verbose"] = False
_global_config["disable_xtypes_verbose"] = False
_global_config["enable_parallel_column_transformers"] = False


def set_config(key: str, value: Any) -> None:
Expand Down
4 changes: 4 additions & 0 deletions julearn/transformers/ju_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.utils.validation import check_is_fitted

from ..base import ColumnTypesLike, JuTransformer, ensure_column_types
from ..config import get_config
from ..utils.logging import raise_error
from ..utils.typing import DataLike, EstimatorLike

Expand Down Expand Up @@ -93,6 +94,9 @@ def _fit(
[(self.name, self.transformer, self.apply_to.to_type_selector())],
verbose_feature_names_out=verbose_feature_names_out,
remainder="passthrough",
n_jobs=None
if get_config("enable_parallel_column_transformers")
else 1,
)
self.column_transformer_.fit(X, y, **fit_params)

Expand Down

0 comments on commit 731abbd

Please sign in to comment.