Skip to content

Commit

Permalink
Merge pull request #274 from juaml/fix/wrap_trans
Browse files Browse the repository at this point in the history
[ENH]: Optimise step wrapping
  • Loading branch information
fraimondo authored Oct 17, 2024
2 parents eac0462 + f42ed99 commit 36fb37a
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/changes/newsfragments/274.enh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimise the pipeline by wrapping steps and models only when they are applied to a subset of the features, by `Fede Raimondo`_
7 changes: 7 additions & 0 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ Here you can find the comprehensive list of flags that can be set:
- | Disable printing the list of expanded column names in ``X_types``.
| If set to ``True``, the list of types of X will not be printed.
- The user will not see the expanded ``X_types`` column names.
* - ``enable_parallel_column_transformers``
- | This flag enables parallel execution of column transformers by
| reverting to the default behaviour of scikit-learn
| (instead of using ``n_jobs=1``).
| If set to ``True``, the parameter will be set back to None.
- | Column transformers will be applied in parallel, using more resources
| than expected.
2 changes: 1 addition & 1 deletion examples/02_inspection/plot_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
sns.scatterplot(x=X[0], y=X[1], data=df, ax=axes[0])
axes[0].set_title("Raw features")
sns.scatterplot(x="pca__pca0", y="pca__pca1", data=X_after_pca, ax=axes[1])
sns.scatterplot(x="pca0", y="pca1", data=X_after_pca, ax=axes[1])
axes[1].set_title("PCA components")

###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion examples/04_confounds/run_return_confounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@
# (including confounds and categorical variables).
# Here we can see that the model is using 10 features (9 deconfounded features
# and the confound).
print(len(model.steps[-1][1].model.coef_))
print(len(model.steps[-1][1].coef_))
2 changes: 1 addition & 1 deletion examples/99_docs/run_model_inspection_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
for fold_inspector in inspector.folds:
fold_model = fold_inspector.model
c_values.append(
fold_model.get_fitted_params()["svm__model_"].get_params()["C"]
fold_model.get_fitted_params()["svm__C"]
)

##############################################################################
Expand Down
1 change: 1 addition & 0 deletions julearn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
_global_config["disable_xtypes_check"] = False
_global_config["disable_x_verbose"] = False
_global_config["disable_xtypes_verbose"] = False
_global_config["enable_parallel_column_transformers"] = False


def set_config(key: str, value: Any) -> None:
Expand Down
5 changes: 4 additions & 1 deletion julearn/inspect/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,14 @@ def get_fitted_params(self):
),
}

return {
private_params = {
param: val
for param, val in all_params.items()
if re.match(r"^[a-zA-Z].*[a-zA-Z0-9]*_$", param)
}
out = self.get_params()
out.update(private_params)
return out

@property
def estimator(self):
Expand Down
26 changes: 24 additions & 2 deletions julearn/inspect/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,25 @@ def test_steps(
@pytest.mark.parametrize(
"est,fitted_params",
[
[MockTestEst(), {"param_0_": 0, "param_1_": 1}],
[
MockTestEst(),
{"hype_0": 0, "hype_1": 1, "param_0_": 0, "param_1_": 1},
],
[
JuColumnTransformer(
"test",
MockTestEst(), # type: ignore
"continuous",
),
{"param_0_": 0, "param_1_": 1},
{
"hype_0": 0,
"hype_1": 1,
"param_0_": 0,
"param_1_": 1,
"needed_types": None,
"row_select_col_type": None,
"row_select_vals": None,
},
],
],
)
Expand All @@ -183,6 +194,9 @@ def test_inspect_estimator(
assert est.get_params() == inspector.get_params()
inspect_params = inspector.get_fitted_params()
inspect_params.pop("column_transformer_", None)
inspect_params.pop("apply_to", None)
inspect_params.pop("transformer", None)
inspect_params.pop("name", None)
assert fitted_params == inspect_params


Expand All @@ -196,8 +210,14 @@ def test_inspect_pipeline(df_iris: "pd.DataFrame") -> None:
"""
expected_fitted_params = {
"jucolumntransformer__hype_0": 0,
"jucolumntransformer__hype_1": 1,
"jucolumntransformer__param_0_": 0,
"jucolumntransformer__param_1_": 1,
"jucolumntransformer__needed_types": None,
"jucolumntransformer__row_select_col_type": None,
"jucolumntransformer__row_select_vals": None,
"jucolumntransformer__name": "test",
}

pipe = (
Expand All @@ -216,6 +236,8 @@ def test_inspect_pipeline(df_iris: "pd.DataFrame") -> None:
inspector = PipelineInspector(pipe)
inspect_params = inspector.get_fitted_params()
inspect_params.pop("jucolumntransformer__column_transformer_", None)
inspect_params.pop("jucolumntransformer__transformer", None)
inspect_params.pop("jucolumntransformer__apply_to", None)
inspect_params = {
key: val
for key, val in inspect_params.items()
Expand Down
44 changes: 39 additions & 5 deletions julearn/pipeline/pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@
from .target_pipeline_creator import TargetPipelineCreator


def _should_wrap_this_step(
X_types: Dict[str, List[str]], # noqa: N803
apply_to: ColumnTypesLike,
) -> bool:
"""Check if we should wrap the step.
Parameters
----------
X_types : Dict[str, List[str]]
The types of the columns in the data.
apply_to : ColumnTypesLike
The types to apply this step to.
Returns
-------
bool
Whether we should wrap the step.
"""

# If we have a wildcard, we will not wrap the step
if any(x in ["*", ".*"] for x in apply_to):
return False

# If any of the X_types is not in the apply_to, we will wrap the step
if any(x not in apply_to for x in X_types.keys()):
return True

return False


def _params_to_pipeline(
param: Any,
X_types: Dict[str, List], # noqa: N803
Expand Down Expand Up @@ -511,7 +542,9 @@ def to_pipeline(
logger.debug(f"\t Params to tune: {step_params_to_tune}")

# Wrap in a JuTransformer if needed
if self.wrap and not isinstance(estimator, JuTransformer):
if _should_wrap_this_step(
X_types, step_dict.apply_to
) and not isinstance(estimator, JuTransformer):
estimator = self._wrap_step(
name,
estimator,
Expand Down Expand Up @@ -539,7 +572,9 @@ def to_pipeline(
for k, v in model_params.items()
}
model_estimator.set_params(**model_params)
if self.wrap and not isinstance(model_estimator, JuModelLike):
if _should_wrap_this_step(
X_types, model_step.apply_to
) and not isinstance(model_estimator, JuModelLike):
logger.debug(f"Wrapping {model_name}")
model_estimator = WrapModel(model_estimator, model_step.apply_to)

Expand Down Expand Up @@ -789,12 +824,11 @@ def _check_X_types(
"this type."
)

self.wrap = needed_types != {"continuous"}
return X_types

@staticmethod
def _is_transformer_step(
step: Union[str, EstimatorLike, TargetPipelineCreator]
step: Union[str, EstimatorLike, TargetPipelineCreator],
) -> bool:
"""Check if a step is a transformer."""
if step in list_transformers():
Expand All @@ -805,7 +839,7 @@ def _is_transformer_step(

@staticmethod
def _is_model_step(
step: Union[EstimatorLike, str, TargetPipelineCreator]
step: Union[EstimatorLike, str, TargetPipelineCreator],
) -> bool:
"""Check if a step is a model."""
if step in list_models():
Expand Down
53 changes: 50 additions & 3 deletions julearn/pipeline/tests/test_pipeline_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from sklearn.pipeline import Pipeline


def test_construction_working(
def test_construction_working_wrapping(
model: str, preprocess: Union[str, List[str]], problem_type: str
) -> None:
"""Test that the pipeline constructions works as expected.
"""Test that the pipeline constructions works as expected (wrapping).
Parameters
----------
Expand All @@ -46,7 +46,7 @@ def test_construction_working(
for step in preprocess:
creator.add(step, apply_to="categorical")
creator.add(model)
X_types = {"categorical": ["A"]}
X_types = {"categorical": ["A"], "continuous": ["B"]}
pipeline = creator.to_pipeline(X_types=X_types)

# check preprocessing steps
Expand All @@ -72,6 +72,53 @@ def test_construction_working(
assert len(preprocess) + 2 == len(pipeline.steps)


def test_construction_working_nowrapping(
    model: str, preprocess: Union[str, List[str]], problem_type: str
) -> None:
    """Test that the pipeline construction works as expected (no wrapping).

    Every preprocessing step is added with the ``"*"`` wildcard and the
    model is added with both column types present in ``X_types``, so the
    resulting pipeline steps are expected to be the plain estimators.

    Parameters
    ----------
    model : str
        The model to test.
    preprocess : str or list of str
        The preprocessing steps to test.
    problem_type : str
        The problem type to test.

    """
    if not isinstance(preprocess, list):
        preprocess = [preprocess]

    creator = PipelineCreator(problem_type=problem_type)
    for step in preprocess:
        creator.add(step, apply_to="*")
    creator.add(model, apply_to=["categorical", "continuous"])

    pipeline = creator.to_pipeline(
        X_types={"categorical": ["A"], "continuous": ["B"]}
    )

    # Preprocessing steps: skip the first step (types) and the last (model)
    inner_steps = pipeline.steps[1:-1]
    for _preprocess, (name, transformer) in zip(preprocess, inner_steps):
        assert name.startswith(f"{_preprocess}")
        assert not isinstance(transformer, JuColumnTransformer)
        expected_transformer_cls = get_transformer(_preprocess).__class__
        assert isinstance(transformer, expected_transformer_cls)

    # Model step: must be the bare model, not a WrapModel
    model_name, model = pipeline.steps[-1]
    assert not isinstance(model, WrapModel)
    expected_model_cls = get_model(
        model_name,
        problem_type=problem_type,
    ).__class__
    assert isinstance(model, expected_model_cls)

    # One types step + the preprocessing steps + the model
    assert len(preprocess) + 2 == len(pipeline.steps)



def test_fit_and_transform_no_error(
X_iris: pd.DataFrame, # noqa: N803
y_iris: pd.Series,
Expand Down
2 changes: 1 addition & 1 deletion julearn/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def test_api_stacking_models() -> None:
# The final model should be a stacking model in which the first estimator
# is a grid search
assert isinstance(
final.steps[1][1].model.estimators[0][1], # type: ignore
final.steps[1][1].estimators[0][1], # type: ignore
GridSearchCV,
)

Expand Down
4 changes: 4 additions & 0 deletions julearn/transformers/ju_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.utils.validation import check_is_fitted

from ..base import ColumnTypesLike, JuTransformer, ensure_column_types
from ..config import get_config
from ..utils.logging import raise_error
from ..utils.typing import DataLike, EstimatorLike

Expand Down Expand Up @@ -93,6 +94,9 @@ def _fit(
[(self.name, self.transformer, self.apply_to.to_type_selector())],
verbose_feature_names_out=verbose_feature_names_out,
remainder="passthrough",
n_jobs=None
if get_config("enable_parallel_column_transformers")
else 1,
)
self.column_transformer_.fit(X, y, **fit_params)

Expand Down

0 comments on commit 36fb37a

Please sign in to comment.