Refactor metrics_definitions.py #10059

Merged: 11 commits, Oct 25, 2023. Diff below shows changes from 10 commits.
197 changes: 95 additions & 102 deletions mlflow/metrics/metric_definitions.py
@@ -8,6 +8,11 @@

_logger = logging.getLogger(__name__)

+targets_err_msg = "the column specified by the `targets` parameter"
+predictions_err_msg = (
+    "the column specified by the `predictions` parameter or the model output column"
+)

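These two constants replace the bare column names ("targets" / "predictions") that the eval functions previously passed to `_validate_text_data`, so a failed validation now names the column's role in the evaluation rather than an internal argument name. A rough illustration of the resulting warning text, using made-up metric and row values (not part of this diff):

```python
# Hypothetical illustration only: metric_name, row, and the message layout mirror
# _validate_text_data's warning; the sample values are invented.
predictions_err_msg = (
    "the column specified by the `predictions` parameter or the model output column"
)

metric_name, row = "toxicity", 2
print(
    f"Cannot calculate {metric_name} for non-string inputs. "
    f"Non-string found for {predictions_err_msg} on row {row}. Skipping metric logging."
)
# Cannot calculate toxicity for non-string inputs. Non-string found for the column
# specified by the `predictions` parameter or the model output column on row 2. ...
```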

def standard_aggregations(scores):
return {
@@ -18,15 +23,15 @@ def standard_aggregations(scores):


def _validate_text_data(data, metric_name, column_name):
"""Validates that the data is text and is non-empty"""
if len(data) == 0:
"""Validates that the data is a list of strs and is non-empty"""
if data is None or len(data) == 0:
return False

for row, line in enumerate(data):
if not isinstance(line, str):
_logger.warning(
f"Cannot calculate {metric_name} for non-string inputs. "
-                + f"Non-string found for {column_name} on row {row}. skipping metric logging."
+                f"Non-string found for {column_name} on row {row}. Skipping metric logging."
)
return False

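The added `data is None` check is what allows the ROUGE functions further down to drop their outer `if targets is not None and len(targets) != 0:` guard: a missing `targets` column now simply fails validation instead of raising `TypeError` on `len(None)`. A simplified, self-contained restatement of the validator's new behaviour (the trailing `return True` is collapsed out of the diff above and is assumed here):

```python
import logging

_logger = logging.getLogger(__name__)


def _validate_text_data(data, metric_name, column_name):
    """Validates that the data is a list of strs and is non-empty."""
    if data is None or len(data) == 0:  # None/empty short-circuits instead of raising
        return False
    for row, line in enumerate(data):
        if not isinstance(line, str):
            _logger.warning(
                f"Cannot calculate {metric_name} for non-string inputs. "
                f"Non-string found for {column_name} on row {row}. Skipping metric logging."
            )
            return False
    return True  # assumed: this line sits in the collapsed portion of the hunk


assert _validate_text_data(None, "rouge1", "the `targets` column") is False
assert _validate_text_data(["a short summary"], "rouge1", "the `targets` column") is True
```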
@@ -60,7 +65,7 @@ def _cached_evaluate_load(path, module_type=None):


def _toxicity_eval_fn(predictions, targets=None, metrics=None):
-    if not _validate_text_data(predictions, "toxicity", "predictions"):
+    if not _validate_text_data(predictions, "toxicity", predictions_err_msg):
return
try:
toxicity = _cached_evaluate_load("toxicity", module_type="measurement")
@@ -84,7 +89,7 @@ def _toxicity_eval_fn(predictions, targets=None, metrics=None):


def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):
-    if not _validate_text_data(predictions, "flesch_kincaid", "predictions"):
+    if not _validate_text_data(predictions, "flesch_kincaid", predictions_err_msg):
return

try:
@@ -101,7 +106,7 @@ def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):


def _ari_eval_fn(predictions, targets=None, metrics=None):
-    if not _validate_text_data(predictions, "ari", "predictions"):
+    if not _validate_text_data(predictions, "ari", predictions_err_msg):
return

try:
@@ -128,111 +133,99 @@ def _accuracy_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):


def _rouge1_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rouge1", "targets") or not _validate_text_data(
-            predictions, "rouge1", "predictions"
-        ):
-            return
-
-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
-
-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rouge1"],
-            use_aggregator=False,
-        )["rouge1"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    if not _validate_text_data(targets, "rouge1", targets_err_msg) or not _validate_text_data(
+        predictions, "rouge1", predictions_err_msg
+    ):
+        return
+
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return
+
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rouge1"],
+        use_aggregator=False,
+    )["rouge1"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )

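All four ROUGE helpers (rouge1 above; rouge2, rougeL, and rougeLsum below) now share this flatter shape: validate both columns, lazily load the Hugging Face `evaluate` "rouge" module, compute per-row scores with `use_aggregator=False`, and wrap them in a `MetricValue`. A hedged usage sketch; it assumes the `evaluate` and `rouge_score` packages are installed, and the aggregate keys come from `standard_aggregations`, which is collapsed above:

```python
# Hypothetical call pattern; the example strings are invented.
predictions = ["the cat sat on the mat", "mlflow logs metrics"]
targets = ["a cat sat on a mat", "mlflow records metrics"]

value = _rouge1_eval_fn(predictions, targets)
if value is None:
    # Validation failed (e.g. targets was None/empty) or the metric module did not load.
    print("rouge1 was skipped")
else:
    print(value.scores)             # one rouge1 score per row (use_aggregator=False)
    print(value.aggregate_results)  # aggregates produced by standard_aggregations(scores)
```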

def _rouge2_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rouge2", "targets") or not _validate_text_data(
-            predictions, "rouge2", "predictions"
-        ):
-            return
-
-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
-
-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rouge2"],
-            use_aggregator=False,
-        )["rouge2"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    if not _validate_text_data(targets, "rouge2", targets_err_msg) or not _validate_text_data(
+        predictions, "rouge2", predictions_err_msg
+    ):
+        return
+
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return
+
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rouge2"],
+        use_aggregator=False,
+    )["rouge2"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _rougeL_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rougeL", "targets") or not _validate_text_data(
-            predictions, "rougeL", "predictions"
-        ):
-            return
-
-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
-
-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rougeL"],
-            use_aggregator=False,
-        )["rougeL"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    if not _validate_text_data(targets, "rougeL", targets_err_msg) or not _validate_text_data(
+        predictions, "rougeL", predictions_err_msg
+    ):
+        return
+
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return
+
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rougeL"],
+        use_aggregator=False,
+    )["rougeL"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _rougeLsum_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rougeLsum", "targets") or not _validate_text_data(
-            predictions, "rougeLsum", "predictions"
-        ):
-            return
-
-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
-
-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rougeLsum"],
-            use_aggregator=False,
-        )["rougeLsum"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    if not _validate_text_data(targets, "rougeLsum", targets_err_msg) or not _validate_text_data(
+        predictions, "rougeLsum", predictions_err_msg
+    ):
+        return
+
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return
+
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rougeLsum"],
+        use_aggregator=False,
+    )["rougeLsum"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _mae_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):