Fix scoring for generated target
fraimondo committed Sep 3, 2024
1 parent 6add24f commit 9411da8
Showing 2 changed files with 33 additions and 4 deletions.
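A target generator creates y inside the pipeline (for example, from the features), so there is nothing to invert back to the original target space; the scorer therefore has to be wrapped and evaluated against the generated target. A minimal sketch of the distinction, using the transform_target/generate_target method names that appear in the diff below but otherwise hypothetical objects:

# Hypothetical objects -- only the method names match the diff below.
# Transformed target: if the mapping can be inverted, predictions can be
# mapped back and scored against the original y, so no wrapping is needed.
y_trans = target_transformer.transform_target(X_trans, y)
# Generated target: y is produced inside the pipeline, so there is no
# inverse mapping and the scorer must be wrapped to score against y_gen.
y_gen = target_generator.generate_target(X_trans, y)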
11 changes: 11 additions & 0 deletions julearn/api.py
@@ -233,6 +233,9 @@ def run_cross_validation( # noqa: C901

wrap_score = False
if isinstance(model, (PipelineCreator, list)):
logger.debug(
"Generating pipeline from PipelineCreator or list of them"
)
if preprocess is not None:
raise_error(
"If model is a PipelineCreator (or list of), "
@@ -266,6 +269,7 @@ def run_cross_validation( # noqa: C901
expanded_models.extend(m.split())

has_target_transformer = expanded_models[-1]._added_target_transformer
has_target_generator = expanded_models[-1]._added_target_generator
all_pipelines = [
model.to_pipeline(X_types=X_types, search_params=search_params)
for model in expanded_models
@@ -279,12 +283,16 @@ def run_cross_validation( # noqa: C901
pipeline = all_pipelines[0]

if has_target_transformer:
logger.debug("Pipeline has target transformer")
if isinstance(pipeline, BaseSearchCV):
last_step = pipeline.estimator[-1] # type: ignore
else:
last_step = pipeline[-1]
if not last_step.can_inverse_transform():
wrap_score = True
if has_target_generator:
logger.debug("Pipeline has target generator")
wrap_score = True
problem_type = model[0].problem_type

elif not isinstance(model, (str, BaseEstimator)):
@@ -343,12 +351,15 @@ def run_cross_validation( # noqa: C901
"The following model_params are incorrect: " f"{unused_params}"
)
has_target_transformer = pipeline_creator._added_target_transformer
has_target_generator = pipeline_creator._added_target_generator
pipeline = pipeline_creator.to_pipeline(
X_types=X_types, search_params=search_params
)

if has_target_transformer and not pipeline[-1].can_inverse_transform():
wrap_score = True
if has_target_generator:
wrap_score = True

# Log some information
logger.info("= Data Information =")
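Both branches above apply the same rule: wrap the scorer whenever the target cannot be recovered in its original space. A condensed, hypothetical restatement of that decision (needs_wrapped_scorer is not part of julearn; the attribute and method names mirror the diff):

def needs_wrapped_scorer(creator, final_step):
    # Hypothetical helper summarizing the wrap_score logic added above.
    if creator._added_target_generator:
        # A generated target has no inverse mapping, so always wrap.
        return True
    if creator._added_target_transformer:
        # Wrap only if the target transformation cannot be inverted.
        return not final_step.can_inverse_transform()
    return False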
26 changes: 22 additions & 4 deletions julearn/scoring/available_scorers.py
@@ -16,6 +16,9 @@
from ..transformers.target.ju_transformed_target_model import (
TransformedTargetWarning,
)
from ..transformers.target.ju_generated_target_model import (
GeneratedTargetWarning,
)
from ..utils import logger, raise_error, warn_with_log
from ..utils.typing import EstimatorLike, ScorerLike
from .metrics import r2_corr, r_corr
@@ -145,11 +148,14 @@ def check_scoring(
scoring to check
wrap_score : bool
Does the score need to be wrapped
to handle non_inverse transformable target pipelines.
to handle non_inverse transformable/generated target pipelines.
"""
if scoring is None:
return scoring
logger.debug(
f"Checking scoring{' (wrapping)' if wrap_score else ''}: [{scoring}]"
)
if isinstance(scoring, str):
scoring = _extend_scorer(get_scorer(scoring), wrap_score)
if callable(scoring):
@@ -188,12 +194,24 @@ def __call__(self, estimator, X, y): # noqa: N803
X_trans = X
for _, transform in estimator.steps[:-1]:
X_trans = transform.transform(X_trans)
y_true = estimator.steps[-1][-1].transform_target( # last est
X_trans, y
)
if hasattr(estimator.steps[-1][-1], "transform_target"):
y_true = estimator.steps[-1][-1].transform_target( # last est
X_trans, y
)
elif hasattr(estimator.steps[-1][-1], "generate_target"):
y_true = estimator.steps[-1][-1].generate_target(X_trans, y)
else:
raise_error(
"A scorer was wrapped to handle non-invertible or generated "
"but the model does not generate or transform the target."
)
with warnings.catch_warnings():
warnings.filterwarnings(
action="ignore", category=TransformedTargetWarning
)
warnings.filterwarnings(
action="ignore",
category=GeneratedTargetWarning,
)
scores = self.scorer(estimator, X, y_true)
return scores
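Taken together, a wrapped scorer re-derives y_true from the fitted pipeline before delegating to the underlying scorer. A rough usage sketch, where fitted_pipeline, X and y are hypothetical placeholders and the second argument to _extend_scorer is the wrap_score flag, as in check_scoring above:

# Rough sketch; fitted_pipeline, X and y are placeholders.
wrapped = _extend_scorer(get_scorer("accuracy"), True)  # wrap_score=True
# Internally, __call__ pushes X through every step but the last, asks the
# final estimator to generate (or transform) the target, and only then
# computes the underlying score against that y_true.
scores = wrapped(fitted_pipeline, X, y)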
