Skip to content

Commit

Permalink
Fix KerasMetricCallback prediction with generate() and inference of c…
Browse files Browse the repository at this point in the history
…olumn names (#15351)

* Fix prediction with generate() and the inference of column names
Should now have very few differences with the PyTorch implementation

* Minor edit to parent class

* Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Explaining the dict conversion

* Putting main_input_name back

* Fixes to main_input_name

Co-authored-by: Sylvain Gugger <[email protected]>
  • Loading branch information
Rocketknight1 and sgugger authored Jan 27, 2022
1 parent da5ef25 commit 6beae76
Showing 1 changed file with 60 additions and 21 deletions.
81 changes: 60 additions & 21 deletions src/transformers/keras_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ def rouge_fn(predictions, labels):
metric names to numerical values.
eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
Validation data to be used to generate predictions for the `metric_fn`.
metric_fn_kwargs (`dict`, *optional*):
Additional keyword arguments to be passed to the metric_fn.
output_cols (`List[str], *optional*):
A list of columns to be retained from the model output as the predictions. Defaults to all.
label_cols ('`List[str]`, *optional*'):
Expand All @@ -74,7 +72,6 @@ def __init__(
self,
metric_fn: Callable,
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
metric_fn_kwargs: Optional[dict] = None,
output_cols: Optional[List[str]] = None,
label_cols: Optional[List[str]] = None,
batch_size: Optional[int] = None,
Expand All @@ -94,12 +91,6 @@ def __init__(
self.eval_dataset = eval_dataset
self.predict_with_generate = predict_with_generate
self.output_cols = output_cols
self.metric_fn_kwargs = metric_fn_kwargs or dict()

if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
self.main_input_name = self.model.encoder.main_input_name
else:
self.main_input_name = self.model.main_input_name

# This next block attempts to parse out which elements of the dataset should be appended to the labels list
# that is passed to the metric_fn
Expand All @@ -123,32 +114,75 @@ def __init__(
self.label_cols = ["labels"]
self.use_keras_label = False
logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
elif "start_positions" in input_spec and "end_positions" in input_spec:
self.label_cols = ["start_positions", "end_positions"]
self.use_keras_label = False
logging.warning(
"No label_cols specified for KerasMetricCallback, assuming you want the "
"start_positions and end_positions keys."
)
else:
raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
if parse(tf.__version__).minor < parse("2.7"):
if parse(tf.__version__) < parse("2.7"):
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")

@staticmethod
def _concatenate_batches(batches):
# Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
return [sample for batch in batches for sample in batch]
def _concatenate_batches(batches, padding_index=-100):
# If all batches are unidimensional or same length, do a simple concatenation
if batches[0].ndim == 1 or all([batch.shape[1] == batches[0].shape[1] for batch in batches]):
return np.concatenate(batches, axis=0)

# Welp, they're not the same length. Let's do some padding
max_len = max([batch.shape[1] for batch in batches])
num_samples = sum([batch.shape[0] for batch in batches])
output = np.full_like(
batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
)
# i keeps track of which part of the concatenated array we're writing the next batch to
i = 0
for batch in batches:
output[i : i + len(batch), : batch.shape[1]] = batch
i += len(batch)
return output

def _postprocess_predictions_or_labels(self, inputs):
if isinstance(inputs[0], dict):
outputs = dict()
for key in inputs[0].keys():
outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
# If it's a dict with only one key, just return the array
if len(outputs) == 1:
outputs = list(outputs.values())[0]
elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
outputs = []
for input_list in zip(*inputs):
outputs.append(self._concatenate_batches(input_list))
if len(outputs) == 1:
outputs = outputs[0] # If it's a list with only one element, just return the array
elif isinstance(inputs[0], np.ndarray):
outputs = self._concatenate_batches(inputs)
elif isinstance(inputs[0], tf.Tensor):
outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
else:
raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
return outputs

def on_epoch_end(self, epoch, logs=None):
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []

main_input_name = None
if self.predict_with_generate:
# This dense conditional recognizes the case where we have an encoder-decoder model, but
# avoids getting tangled up when we just have a model with a layer called 'encoder'
if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
if self.model.encoder.main_input_name != self.model.main_input_name:
main_input_name = self.model.encoder.main_input_name
else:
main_input_name = getattr(self.model, "main_input_name", "input_ids")

prediction_list = []
label_list = []

Expand All @@ -160,7 +194,7 @@ def on_epoch_end(self, epoch, logs=None):
labels = None
if self.predict_with_generate:
if isinstance(batch, dict):
generation_inputs = batch[self.main_input_name]
generation_inputs = batch[main_input_name]
attention_mask = batch.get("attention_mask", None)
else:
generation_inputs = batch
Expand All @@ -169,9 +203,14 @@ def on_epoch_end(self, epoch, logs=None):
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
else:
predictions = self.model.predict(batch)
predictions = dict(predictions)
if self.output_cols is not None:
predictions = {key: predictions[key] for key in self.output_cols}
if isinstance(predictions, dict):
# This converts any dict-subclass to a regular dict
# Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
predictions = dict(predictions)
if self.output_cols is not None:
predictions = {key: predictions[key] for key in self.output_cols}
else:
predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
prediction_list.append(predictions)
if not self.use_keras_label:
labels = {key: batch[key].numpy() for key in self.label_cols}
Expand All @@ -185,10 +224,10 @@ def on_epoch_end(self, epoch, logs=None):
raise TypeError(f"Confused by labels of type {type(labels)}")
label_list.append(labels)

prediction_list = self._postprocess_predictions_or_labels(prediction_list)
label_list = self._postprocess_predictions_or_labels(label_list)
all_preds = self._postprocess_predictions_or_labels(prediction_list)
all_labels = self._postprocess_predictions_or_labels(label_list)

metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
metric_output = self.metric_fn((all_preds, all_labels))
if not isinstance(metric_output, dict):
raise TypeError(
f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
Expand Down

0 comments on commit 6beae76

Please sign in to comment.