Add basic eval table logging for WandbCallback #31050

Open · wants to merge 9 commits into base: main

src/transformers/integrations/integration_utils.py (93 additions, 2 deletions)

@@ -692,10 +692,19 @@ def print_to_file(s):

class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
A [`TrainerCallback`] that logs metrics, media, evals, and model checkpoints to [Weight and Biases](https://www.wandb.com/).
"""

    def __init__(
        self,
        *,
        trainer=None,
        tokenizer=None,
        dataset=None,
        num_samples: int = 10,
        freq: int = 1,
        ignore_tokens: Optional[list] = None,
    ):

[Collaborator] There should be a docstring for all these args. In particular `num_samples` and `freq`, which aren't obvious.
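A docstring along these lines would cover them (a sketch only; descriptions are inferred from the diff, and the `freq` wording is a guess, since the code shown stores it without reading it):

        """
        Args:
            trainer ([`Trainer`], *optional*):
                Trainer instance, used as a fallback source for `tokenizer` and `dataset` below.
            tokenizer (*optional*):
                Tokenizer used to decode inputs, predictions, and labels for the eval table. Defaults to `trainer.tokenizer`.
            dataset (*optional*):
                Evaluation dataset to sample logged rows from. Defaults to `trainer.eval_dataset`.
            num_samples (`int`, *optional*, defaults to 10):
                Number of rows selected from `dataset` for the logged table.
            freq (`int`, *optional*, defaults to 1):
                How often, counted in evaluation passes, to log the table.
            ignore_tokens (`list`, *optional*, defaults to `[-100]`):
                Token ids replaced by the pad token id before decoding labels.
        """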
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
@@ -704,6 +713,48 @@ def __init__(self):

        self._wandb = wandb
        self._initialized = False

        # Setup for evals if user requests it
        if os.getenv("WANDB_LOG_EVALS"):
            if trainer is not None:
                self.trainer = trainer

            if tokenizer is None:
                tokenizer = self.trainer.tokenizer
[Collaborator] This assumes `trainer` is not None.

            self.tokenizer = tokenizer
[Collaborator] Do we want to set this even if `tokenizer` is None?


            if dataset is None:
                dataset = self.trainer.eval_dataset
[Collaborator] Same here - this assumes `self.trainer` and `self.trainer.eval_dataset` are not None.
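
A minimal sketch of guards that would address these three comments (illustrative only; the error messages are invented):

            if trainer is not None:
                self.trainer = trainer
            elif tokenizer is None or dataset is None:
                raise ValueError(
                    "WANDB_LOG_EVALS requires a `trainer`, or an explicit `tokenizer` and `dataset`."
                )

            self.tokenizer = tokenizer if tokenizer is not None else self.trainer.tokenizer
            dataset = dataset if dataset is not None else self.trainer.eval_dataset
            if dataset is None:
                raise ValueError("No evaluation dataset available to build the eval table from.")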


            try:
                sampled_dataset = dataset.select(range(num_samples))
[Collaborator] This assumes `dataset` is not None.

            except IndexError as e:
                print(f"WARNING: Could not get those indices: {e=}")
[Collaborator] We should log rather than print.

[Collaborator] Could we make this a bit clearer? The user never specifies indices, so it's a bit odd to refer to "those indices". Maybe something along the lines of "Could not select {num_samples=} rows from the dataset".

                sampled_dataset = dataset
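
Combining both suggestions, the handler might become (a sketch; it assumes the module-level `logger` that this file already defines):

            except IndexError:
                logger.warning("Could not select %d rows from the eval dataset; falling back to the full dataset.", num_samples)
                sampled_dataset = dataset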

            self.sample_dataset = sampled_dataset
[Collaborator] Do we want to store both the full and sampled dataset?

            self.freq = freq

            if ignore_tokens is None:
                ignore_tokens = [-100]

            padding_token_id = self.tokenizer.pad_token_id
[Collaborator] Assumes `tokenizer` is not None. Confusingly, the `tokenizer` for `Trainer` may not be a tokenizer at all, c.f. #32385.


            def replace_ignored_tokens(a):
                if isinstance(a, np.ndarray):
                    mask = np.isin(a, ignore_tokens)
                elif isinstance(a, torch.Tensor):
                    mask = torch.isin(a, torch.tensor(ignore_tokens, dtype=a.dtype))
                else:
                    raise TypeError(f"Unsupported type for token replacement: {type(a)}")

                a[mask] = padding_token_id
                return a

            self._replace_ignored_tokens_func = replace_ignored_tokens
[Collaborator] Why do we need this functionality in the callback?
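
For intuition, here is what the helper does on a numpy label array (hypothetical ids; assumes a pad token id of 0):

            # labels = np.array([[101, 2023, -100, -100]])
            # replace_ignored_tokens(labels) -> array([[ 101, 2023,    0,    0]])
            # i.e. `-100` placeholders become pad ids so tokenizer.batch_decode() can handle them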


            self._collected_eval_rows = []

        # log model
        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
            DeprecationWarning(
@@ -933,6 +984,46 @@ def on_predict(self, args, state, control, metrics, **kwargs):
        metrics = rewrite_logs(metrics)
        self._wandb.log(metrics)

    def on_evaluate(self, args, state, control, **kwargs):
        if os.getenv("WANDB_LOG_EVALS"):
            eval_loop_output = self.trainer.eval_loop_output
[Contributor] Where is eval_loop_output coming from exactly?

[Author] Thanks for catching, I missed this commit. It's added here: b8d5c6e


            inputs = eval_loop_output.inputs
            decoded_inputs = None
            if inputs is not None:
                decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True)

            preds = eval_loop_output.predictions
            outputs = preds.argmax(axis=-1)
            decoded_outputs = None
            if outputs is not None:
                decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

            expected = eval_loop_output.label_ids
            decoded_expected = None
            if expected is not None:
                expected = self._replace_ignored_tokens_func(expected)
                decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True)

            # Determine which fields are available
            available_fields = [
                ("decoded_inputs", decoded_inputs),
                ("decoded_outputs", decoded_outputs),
                ("decoded_expected", decoded_expected),
            ]
            available_fields = [(name, value) for name, value in available_fields if value is not None]

            # Create rows using only available fields
            for items in zip(*(value for _, value in available_fields)):
                row = {name: item for (name, _), item in zip(available_fields, items)}
                self._collected_eval_rows.append(row)

            if self._collected_eval_rows:
                # Column names come from the first collected row; the loop variable `row`
                # is undefined when this particular call produced no new rows.
                table = self._wandb.Table(columns=list(self._collected_eval_rows[0].keys()))
                for row in self._collected_eval_rows:
                    table.add_data(*row.values())
                self._wandb.log({"evaluation_table": table})
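
To make the row construction concrete: if `decoded_inputs` is unavailable and two samples were decoded, the zip yields (hypothetical strings):

            # available_fields = [("decoded_outputs", ["A", "B"]), ("decoded_expected", ["a", "b"])]
            # rows appended: {"decoded_outputs": "A", "decoded_expected": "a"}
            #                {"decoded_outputs": "B", "decoded_expected": "b"}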


class CometCallback(TrainerCallback):
"""
Expand Down

src/transformers/trainer.py (12 additions, 1 deletion)

@@ -570,6 +570,14 @@ def __init__(
        )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks

        # Add a reference to the trainer in case callbacks need it
        def init_callback(cb):
            cb.trainer = self
            return cb

        callbacks = [init_callback(cb) for cb in callbacks]
[Contributor] Is there a better way we can do this rather than adding the trainer to self here? I'm not a fan of this, because if users are saving states they end up with an entire reference to the trainer in there, which is not good.

[Author] I didn't love this either, but it's a tradeoff.

  1. The current Trainer reporting interface looks like report_to="wandb", which is fine if your callback doesn't need any args/kwargs; but in this case we do (the tokenizer, dataset, etc.), and the one object that has all of these is the Trainer.
  2. The alternative is also implemented in this PR, but you can't pass report_to="wandb" -- counter-intuitively, you have to NOT report to wandb. Instead, you need to instantiate a callback and manually pass it in; not the end of the world, but it didn't seem idiomatic:

         trainer = Trainer(...)
         wandb_callback = WandbCallback(..., tokenizer=..., dataset=...)
         trainer.add_callback(wandb_callback)
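
Spelled out fully, that manual route would look roughly like this (a sketch; `model`, `args`, `eval_dataset`, and `tokenizer` are placeholders, with `args` assumed to carry report_to="none" per the comment above so wandb isn't registered twice):

    import os

    os.environ["WANDB_LOG_EVALS"] = "1"  # gate checked in WandbCallback.__init__ and on_evaluate

    trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, tokenizer=tokenizer)
    wandb_callback = WandbCallback(trainer=trainer, tokenizer=tokenizer, dataset=eval_dataset)
    trainer.add_callback(wandb_callback)
    trainer.evaluate()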

[Contributor] I'd much prefer the non-idiomatic way please.
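
If the back-reference does stay, one way to keep a full Trainer out of saved callback state is a weak reference (illustrative; not what this PR does):

    import weakref

    # inside Trainer.__init__, instead of `cb.trainer = self`:
    def init_callback(cb):
        cb.trainer_ref = weakref.ref(self)  # callbacks resolve it with cb.trainer_ref()
        return cb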


        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
@@ -3660,6 +3668,7 @@ def evaluate(
            )
        )

        self.eval_loop_output = output
        self.log(output.metrics)
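
Given the memory concern raised further down, one cheap mitigation would be to retain the output only when eval logging is requested (illustrative; not what the diff does):

        if os.getenv("WANDB_LOG_EVALS"):
            self.eval_loop_output = output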

        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
@@ -3939,7 +3948,9 @@ def evaluation_loop(
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(
            predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs
        )
[Contributor] Some things to be very careful with here. I'd appreciate checking the memory usage before/after this change, to make sure we don't introduce a memory leak and don't increase the VRAM used by the user.

[Author] Any profiling tools you recommend / would want to see?

[Contributor] wandb logs work just fine :)

[Collaborator] What's the update here?
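
For a quick before/after comparison, a peak-VRAM probe around a single eval pass would do (a sketch; assumes a CUDA run):

    import torch

    torch.cuda.reset_peak_memory_stats()
    trainer.evaluate()
    print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")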


    def _nested_gather(self, tensors, name=None):
        """

src/transformers/trainer_utils.py (1 addition)

@@ -188,6 +188,7 @@ class EvalLoopOutput(NamedTuple):
    label_ids: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
    metrics: Optional[Dict[str, float]]
    num_samples: Optional[int]
    inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]]
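
For reference, consuming the widened tuple downstream might look like this (a sketch; variables hypothetical):

    output = EvalLoopOutput(
        predictions=all_preds, label_ids=all_labels, metrics={"eval_loss": 0.31}, num_samples=10, inputs=all_inputs
    )
    if output.inputs is not None:
        decoded = tokenizer.batch_decode(output.inputs, skip_special_tokens=True)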


class PredictionOutput(NamedTuple):