diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 29528feb515c..72f667385c5d 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -692,10 +692,19 @@ def print_to_file(s):
 
 class WandbCallback(TrainerCallback):
     """
-    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
+    A [`TrainerCallback`] that logs metrics, media, evals, and model checkpoints to [Weights & Biases](https://www.wandb.com/).
     """
 
-    def __init__(self):
+    def __init__(
+        self,
+        *,
+        trainer=None,
+        tokenizer=None,
+        dataset=None,
+        num_samples: int = 10,
+        freq: int = 1,
+        ignore_tokens: Optional[list] = None,
+    ):
         has_wandb = is_wandb_available()
         if not has_wandb:
             raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
@@ -704,6 +713,48 @@ def __init__(self):
 
             self._wandb = wandb
         self._initialized = False
+
+        # Set up eval-sample logging if the user requests it
+        if os.getenv("WANDB_LOG_EVALS"):
+            if trainer is not None:
+                self.trainer = trainer
+
+            if tokenizer is None:
+                tokenizer = self.trainer.tokenizer
+            self.tokenizer = tokenizer
+
+            if dataset is None:
+                dataset = self.trainer.eval_dataset
+
+            try:
+                sampled_dataset = dataset.select(range(num_samples))
+            except IndexError as e:
+                logger.warning(f"Could not select {num_samples} samples from the eval dataset, using all of it: {e}")
+                sampled_dataset = dataset
+
+            self.sample_dataset = sampled_dataset
+            self.freq = freq
+
+            if ignore_tokens is None:
+                ignore_tokens = [-100]
+
+            padding_token_id = self.tokenizer.pad_token_id
+
+            def replace_ignored_tokens(a):
+                if isinstance(a, np.ndarray):
+                    mask = np.isin(a, ignore_tokens)
+                elif isinstance(a, torch.Tensor):
+                    mask = torch.isin(a, torch.tensor(ignore_tokens, dtype=a.dtype))
+                else:
+                    raise TypeError(f"Unsupported type for replacing ignored tokens: {type(a)}")
+
+                a[mask] = padding_token_id
+                return a
+
+            self._replace_ignored_tokens_func = replace_ignored_tokens
+
+            self._collected_eval_rows = []
+
         # log model
         if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
             DeprecationWarning(
@@ -933,6 +984,46 @@ def on_predict(self, args, state, control, metrics, **kwargs):
         metrics = rewrite_logs(metrics)
         self._wandb.log(metrics)
 
+    def on_evaluate(self, args, state, control, **kwargs):
+        if os.getenv("WANDB_LOG_EVALS"):
+            eval_loop_output = self.trainer.eval_loop_output
+
+            inputs = eval_loop_output.inputs
+            decoded_inputs = None
+            if inputs is not None:
+                decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True)
+
+            preds = eval_loop_output.predictions
+            decoded_outputs = None
+            if preds is not None:
+                outputs = preds.argmax(axis=-1)
+                decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+            expected = eval_loop_output.label_ids
+            decoded_expected = None
+            if expected is not None:
+                expected = self._replace_ignored_tokens_func(expected)
+                decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True)
+
+            # Determine which fields are available
+            available_fields = [
+                ("decoded_inputs", decoded_inputs),
+                ("decoded_outputs", decoded_outputs),
+                ("decoded_expected", decoded_expected),
+            ]
+            available_fields = [(name, value) for name, value in available_fields if value is not None]
+
+            # Create rows using only the available fields
+            for items in zip(*(value for _, value in available_fields)):
+                row = {name: item for (name, _), item in zip(available_fields, items)}
+                self._collected_eval_rows.append(row)
+
+            if self._collected_eval_rows:
+                table = self._wandb.Table(columns=list(self._collected_eval_rows[0].keys()))
+                for row in self._collected_eval_rows:
+                    table.add_data(*row.values())
+                self._wandb.log({"evaluation_table": table})
+
 
 class CometCallback(TrainerCallback):
     """
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 71c3ee43af2c..88b2ce32a6c7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -570,6 +570,14 @@ def __init__(
         )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+
+        # Add a reference to the trainer in case callbacks need it
+        def init_callback(cb):
+            cb.trainer = self
+            return cb
+
+        callbacks = [init_callback(cb) for cb in callbacks]
+
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
@@ -3660,6 +3668,7 @@ def evaluate(
             )
         )
 
+        self.eval_loop_output = output
         self.log(output.metrics)
 
         if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
@@ -3939,7 +3948,9 @@ def evaluation_loop(
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
 
-        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+        return EvalLoopOutput(
+            predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs
+        )
 
     def _nested_gather(self, tensors, name=None):
         """
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 5f6900658840..1eaa79d9609a 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -188,6 +188,7 @@ class EvalLoopOutput(NamedTuple):
     label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metrics: Optional[Dict[str, float]]
     num_samples: Optional[int]
+    inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
 
 
 class PredictionOutput(NamedTuple):
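
As a usage reference, here is a minimal sketch of how the eval-logging path added above might be exercised. It is illustrative only: `model`, `tokenizer`, and `tokenized_eval_dataset` are hypothetical placeholders assumed to exist (with the eval set already tokenized), and `WANDB_LOG_EVALS` must be set before the callback is constructed. Since `WandbCallback.__init__` falls back to `self.trainer` for a missing tokenizer or dataset but the trainer does not exist yet at that point, both are passed explicitly here:

import os

from transformers import Trainer, TrainingArguments
from transformers.integrations import WandbCallback

os.environ["WANDB_LOG_EVALS"] = "1"  # opt in to eval-sample logging

# `model`, `tokenizer`, and `tokenized_eval_dataset` are hypothetical placeholders.
# Tokenizer and dataset are passed explicitly; the trainer does not exist yet,
# so the `self.trainer` fallbacks inside __init__ are not usable at this point.
callback = WandbCallback(
    tokenizer=tokenizer,
    dataset=tokenized_eval_dataset,
    num_samples=10,        # decode and log the first 10 eval samples
    ignore_tokens=[-100],  # replaced by pad_token_id before batch_decode
)

trainer = Trainer(
    model=model,
    # include_inputs_for_metrics makes evaluation_loop collect all_inputs,
    # so EvalLoopOutput.inputs is populated (it stays None otherwise)
    args=TrainingArguments(output_dir="out", include_inputs_for_metrics=True),
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,
    callbacks=[callback],  # Trainer.__init__ sets callback.trainer = trainer
)

# evaluate() stores its EvalLoopOutput on the trainer; on_evaluate then decodes
# inputs/predictions/labels and logs an "evaluation_table" to W&B.
trainer.evaluate()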