Refactor metrics_definitions.py (mlflow#10059)
Signed-off-by: Bryan Qiu <[email protected]>
Signed-off-by: swathi <[email protected]>
bbqiu authored and KonakanchiSwathi committed Nov 29, 2023
1 parent 0e0945c commit 81405bd
Showing 1 changed file with 88 additions and 95 deletions.
183 changes: 88 additions & 95 deletions mlflow/metrics/metric_definitions.py
@@ -8,6 +8,11 @@

_logger = logging.getLogger(__name__)

+targets_col_specifier = "the column specified by the `targets` parameter"
+predictions_col_specifier = (
+    "the column specified by the `predictions` parameter or the model output column"
+)
+

def standard_aggregations(scores):
    return {
@@ -17,16 +22,16 @@ def standard_aggregations(scores):
    }


-def _validate_text_data(data, metric_name, column_name):
-    """Validates that the data is text and is non-empty"""
-    if len(data) == 0:
+def _validate_text_data(data, metric_name, col_specifier):
+    """Validates that the data is a list of strs and is non-empty"""
+    if data is None or len(data) == 0:
        return False

    for row, line in enumerate(data):
        if not isinstance(line, str):
            _logger.warning(
                f"Cannot calculate {metric_name} for non-string inputs. "
-                + f"Non-string found for {column_name} on row {row}. skipping metric logging."
+                f"Non-string found for {col_specifier} on row {row}. Skipping metric logging."
            )
            return False

@@ -82,7 +87,7 @@ def _cached_evaluate_load(path, module_type=None):


def _toxicity_eval_fn(predictions, targets=None, metrics=None):
-    if not _validate_text_data(predictions, "toxicity", "predictions"):
+    if not _validate_text_data(predictions, "toxicity", predictions_col_specifier):
        return
    try:
        toxicity = _cached_evaluate_load("toxicity", module_type="measurement")
Expand All @@ -106,7 +111,7 @@ def _toxicity_eval_fn(predictions, targets=None, metrics=None):


def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):
if not _validate_text_data(predictions, "flesch_kincaid", "predictions"):
if not _validate_text_data(predictions, "flesch_kincaid", predictions_col_specifier):
return

try:
@@ -123,7 +128,7 @@ def _flesch_kincaid_eval_fn(predictions, targets=None, metrics=None):


def _ari_eval_fn(predictions, targets=None, metrics=None):
-    if not _validate_text_data(predictions, "ari", "predictions"):
+    if not _validate_text_data(predictions, "ari", predictions_col_specifier):
        return

    try:
@@ -150,111 +155,99 @@ def _accuracy_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):


def _rouge1_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rouge1", "targets") or not _validate_text_data(
-            predictions, "rouge1", "predictions"
-        ):
-            return
+    if not _validate_text_data(targets, "rouge1", targets_col_specifier) or not _validate_text_data(
+        predictions, "rouge1", predictions_col_specifier
+    ):
+        return

-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return

-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rouge1"],
-            use_aggregator=False,
-        )["rouge1"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rouge1"],
+        use_aggregator=False,
+    )["rouge1"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _rouge2_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rouge2", "targets") or not _validate_text_data(
-            predictions, "rouge2", "predictions"
-        ):
-            return
+    if not _validate_text_data(targets, "rouge2", targets_col_specifier) or not _validate_text_data(
+        predictions, "rouge2", predictions_col_specifier
+    ):
+        return

-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return

-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rouge2"],
-            use_aggregator=False,
-        )["rouge2"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rouge2"],
+        use_aggregator=False,
+    )["rouge2"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _rougeL_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rougeL", "targets") or not _validate_text_data(
-            predictions, "rougeL", "predictions"
-        ):
-            return
+    if not _validate_text_data(targets, "rougeL", targets_col_specifier) or not _validate_text_data(
+        predictions, "rougeL", predictions_col_specifier
+    ):
+        return

-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return

-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rougeL"],
-            use_aggregator=False,
-        )["rougeL"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rougeL"],
+        use_aggregator=False,
+    )["rougeL"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _rougeLsum_eval_fn(predictions, targets=None, metrics=None):
-    if targets is not None and len(targets) != 0:
-        if not _validate_text_data(targets, "rougeLsum", "targets") or not _validate_text_data(
-            predictions, "rougeLsum", "predictions"
-        ):
-            return
+    if not _validate_text_data(
+        targets, "rougeLsum", targets_col_specifier
+    ) or not _validate_text_data(predictions, "rougeLsum", predictions_col_specifier):
+        return

-        try:
-            rouge = _cached_evaluate_load("rouge")
-        except Exception as e:
-            _logger.warning(
-                f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging."
-            )
-            return
+    try:
+        rouge = _cached_evaluate_load("rouge")
+    except Exception as e:
+        _logger.warning(f"Failed to load 'rouge' metric (error: {e!r}), skipping metric logging.")
+        return

-        scores = rouge.compute(
-            predictions=predictions,
-            references=targets,
-            rouge_types=["rougeLsum"],
-            use_aggregator=False,
-        )["rougeLsum"]
-        return MetricValue(
-            scores=scores,
-            aggregate_results=standard_aggregations(scores),
-        )
+    scores = rouge.compute(
+        predictions=predictions,
+        references=targets,
+        rouge_types=["rougeLsum"],
+        use_aggregator=False,
+    )["rougeLsum"]
+    return MetricValue(
+        scores=scores,
+        aggregate_results=standard_aggregations(scores),
+    )


def _mae_eval_fn(predictions, targets=None, metrics=None, sample_weight=None):
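
For reference, a minimal runnable sketch of the refactored validation helper as it stands after this commit. The trailing `return True` and the example call at the bottom are assumptions added for illustration; they are not shown in this diff.

import logging

_logger = logging.getLogger(__name__)

# Shared column descriptions introduced by this commit and reused by every eval function.
targets_col_specifier = "the column specified by the `targets` parameter"
predictions_col_specifier = (
    "the column specified by the `predictions` parameter or the model output column"
)


def _validate_text_data(data, metric_name, col_specifier):
    """Validates that the data is a list of strs and is non-empty"""
    if data is None or len(data) == 0:
        return False

    for row, line in enumerate(data):
        if not isinstance(line, str):
            _logger.warning(
                f"Cannot calculate {metric_name} for non-string inputs. "
                f"Non-string found for {col_specifier} on row {row}. Skipping metric logging."
            )
            return False

    return True  # assumed: the unchanged tail of the function lies outside this diff's hunks


# Hypothetical call mirroring the guard at the top of each eval function:
if not _validate_text_data(["some text", 42], "toxicity", predictions_col_specifier):
    pass  # metric logging is skipped, as the eval functions do by returning early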
