Add basic eval table logging for WandbCallback #31050
base: main
Changes from all commits: 27d7393, 4a68300, bf1f801, e4dec9e, c84d0c4, a80db7a, b8d5c6e, b176f29, 2d9b503
src/transformers/integrations/integration_utils.py

@@ -692,10 +692,19 @@ def print_to_file(s):
 class WandbCallback(TrainerCallback):
     """
-    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
+    A [`TrainerCallback`] that logs metrics, media, evals, and model checkpoints to [Weight and Biases](https://www.wandb.com/).
     """
-    def __init__(self):
+    def __init__(
+        self,
+        *,
+        trainer=None,
+        tokenizer=None,
+        dataset=None,
+        num_samples: int = 10,
+        freq: int = 1,
+        ignore_tokens: Optional[list] = None,
+    ):
         has_wandb = is_wandb_available()
         if not has_wandb:
             raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
@@ -704,6 +713,48 @@ def __init__(self):
         self._wandb = wandb
         self._initialized = False
+        # Setup for evals if user requests it
+        if os.getenv("WANDB_LOG_EVALS"):
+            if trainer is not None:
+                self.trainer = trainer
+
+            if tokenizer is None:
+                tokenizer = self.trainer.tokenizer
Review comment: This assumes `self.trainer` is not None.

+            self.tokenizer = tokenizer
Review comment: Do we want to set even if …
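A guarded version of this fallback might look like the sketch below (illustrative only; the control flow and error wording are not part of the PR):

```python
# Sketch: fail loudly instead of assuming self.trainer exists.
if tokenizer is None:
    if trainer is None:
        raise ValueError("WandbCallback needs a `tokenizer`, or a `trainer` to take one from.")
    tokenizer = trainer.tokenizer  # may itself be None; worth validating too
self.tokenizer = tokenizer
```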
+            if dataset is None:
+                dataset = self.trainer.eval_dataset
Review comment: Same here - this assumes `self.trainer` and `self.trainer.eval_dataset` are not None.
+            try:
+                sampled_dataset = dataset.select(range(num_samples))
Review comment: This assumes …
+            except IndexError as e:
+                print(f"WARNING: Could not get those indices: {e=}")

Review comment: We should log rather than print.

Review comment: Could we make this a bit clearer? The user never specifies indices, so it's a bit weird to refer to them as "those indices". Maybe something along the lines of "Could not select {num_samples=} rows from the dataset".

+                sampled_dataset = dataset
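For the logging suggestion, a standalone sketch of what the fallback could look like (`sample_rows` is a hypothetical helper, and stdlib `logging` is used here; the real code would presumably use the `transformers` logger utilities):

```python
import logging

logger = logging.getLogger(__name__)

def sample_rows(dataset, num_samples: int):
    """Take the first `num_samples` rows, falling back to the full dataset."""
    try:
        return dataset.select(range(num_samples))
    except IndexError:
        logger.warning("Could not select %d rows from the eval dataset; using it unsampled.", num_samples)
        return dataset
```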
+            self.sample_dataset = sampled_dataset

Review comment: Do we want to store both the full and sampled dataset?
+            self.freq = freq

+            if ignore_tokens is None:
+                ignore_tokens = [-100]

+            padding_token_id = self.tokenizer.pad_token_id
Review comment: Assumes `tokenizer` is not None. Confusingly, the …
+            def replace_ignored_tokens(a):
+                if isinstance(a, np.ndarray):
+                    mask = np.isin(a, ignore_tokens)
+                elif isinstance(a, torch.Tensor):
+                    mask = torch.isin(a, torch.tensor(ignore_tokens, dtype=a.dtype))
+                else:
+                    raise TypeError(f"Unsupported type for token replacement: {type(a)}")

+                a[mask] = padding_token_id
+                return a
+            self._replace_ignored_tokens_func = replace_ignored_tokens

Review comment: Why do we need this functionality in the callback?
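For context on what the helper is for: labels padded with the loss's ignore index (-100, the usual `Trainer` convention) can't be passed to `batch_decode` directly, so ignored positions are mapped back to a real token id first. A minimal standalone illustration with made-up token ids:

```python
import numpy as np

labels = np.array([[15496, 11, -100, -100, 995]])  # -100 marks ignored positions
mask = np.isin(labels, [-100])
labels[mask] = 0  # pretend pad_token_id == 0
print(labels)  # [[15496 11 0 0 995]] -> now safe to batch_decode
```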
+            self._collected_eval_rows = []

         # log model
         if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
             DeprecationWarning(
@@ -933,6 +984,46 @@ def on_predict(self, args, state, control, metrics, **kwargs):
         metrics = rewrite_logs(metrics)
         self._wandb.log(metrics)
+    def on_evaluate(self, args, state, control, **kwargs):
+        if os.getenv("WANDB_LOG_EVALS"):
+            eval_loop_output = self.trainer.eval_loop_output
Review comment: Where is `eval_loop_output` set?

Reply: Thanks for catching, I missed this commit. It's added here: …
+            inputs = eval_loop_output.inputs
+            decoded_inputs = None
+            if inputs is not None:
+                decoded_inputs = self.tokenizer.batch_decode(inputs, skip_special_tokens=True)
+            preds = eval_loop_output.predictions
+            outputs = preds.argmax(axis=-1)
+            decoded_outputs = None
+            if outputs is not None:
+                decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            expected = eval_loop_output.label_ids
+            decoded_expected = None
+            if expected is not None:
+                expected = self._replace_ignored_tokens_func(expected)
+                decoded_expected = self.tokenizer.batch_decode(expected, skip_special_tokens=True)
+            # Determine which fields are available
+            available_fields = [
+                ("decoded_inputs", decoded_inputs),
+                ("decoded_outputs", decoded_outputs),
+                ("decoded_expected", decoded_expected),
+            ]
+            available_fields = [(name, value) for name, value in available_fields if value is not None]
+            # Create rows using only available fields
+            for items in zip(*(value for _, value in available_fields)):
+                row = {name: item for (name, _), item in zip(available_fields, items)}
+                self._collected_eval_rows.append(row)
||
if self._collected_eval_rows: | ||
table = self._wandb.Table(columns=list(row.keys())) | ||
for row in self._collected_eval_rows: | ||
table.add_data(*row.values()) | ||
self._wandb.log({"evaluation_table": table}) | ||
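For reference, wiring this up end to end would presumably look something like the sketch below (the constructor arguments follow this PR, which is still under review; `model`, `training_args`, `tokenizer`, and `eval_dataset` are assumed to be defined already):

```python
import os

from transformers import Trainer
from transformers.integrations import WandbCallback

os.environ["WANDB_LOG_EVALS"] = "true"  # opt-in switch used by this PR

callback = WandbCallback(
    tokenizer=tokenizer,   # used to decode inputs/predictions/labels
    dataset=eval_dataset,  # defaults to trainer.eval_dataset if omitted
    num_samples=10,        # rows sampled into the logged table
    freq=1,                # log every `freq`-th evaluation
)
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=eval_dataset,
    callbacks=[callback],
)
trainer.evaluate()  # logs an "evaluation_table" to W&B
```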
 class CometCallback(TrainerCallback):
     """
src/transformers/trainer.py

@@ -570,6 +570,14 @@ def __init__(
         )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+        # Add a reference to the trainer in case callbacks need it
+        def init_callback(cb):
+            cb.trainer = self
+            return cb

+        callbacks = [init_callback(cb) for cb in callbacks]
Review comment on lines +575 to +579: Is there a better way we can do this rather than adding the trainer to each callback?

Reply: I didn't love this either, but it's a tradeoff.

Review comment: I'd much prefer the non-idiomatic way, please.
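One more explicit alternative would be to hand the trainer only to callbacks that opt in, rather than mutating every callback unconditionally (a sketch, not from the PR discussion; the `needs_trainer` attribute is hypothetical):

```python
# Inside Trainer.__init__: only attach the trainer to callbacks that declare they want it.
for cb in callbacks:
    if getattr(cb, "needs_trainer", False):
        cb.trainer = self
```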
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
@@ -3660,6 +3668,7 @@ def evaluate(
             )
         )

+        self.eval_loop_output = output
         self.log(output.metrics)

         if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
@@ -3939,7 +3948,9 @@ def evaluation_loop(
             if not key.startswith(f"{metric_key_prefix}_"):
                 metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

-        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+        return EvalLoopOutput(
+            predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples, inputs=all_inputs
+        )
Review comment on lines +3951 to +3953: Some things to be very careful with here: I'd appreciate checking the memory usage before and after this change, to make sure we don't have a memory leak and we don't increase the VRAM used by the user by utilizing this.

Reply: Any profiling tools you recommend / would want to see?

Review comment: What's the update here?
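As a first-pass answer to the profiling question, peak CUDA memory before and after the change can be compared with PyTorch's built-in counters (a sketch; assumes a CUDA device and an already-constructed `trainer`):

```python
import torch

# Compare this number on main vs. this branch to spot extra VRAM retained by `inputs`.
torch.cuda.reset_peak_memory_stats()
trainer.evaluate()
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak CUDA memory during evaluate(): {peak_gib:.2f} GiB")
```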
     def _nested_gather(self, tensors, name=None):
         """
Review comment: There should be a docstring for all these args, in particular `num_samples` and `freq`, which aren't obvious.
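A possible docstring along the lines requested (the wording is illustrative, not taken from the PR):

```python
def __init__(
    self,
    *,
    trainer=None,
    tokenizer=None,
    dataset=None,
    num_samples: int = 10,
    freq: int = 1,
    ignore_tokens: Optional[list] = None,
):
    """
    Args:
        trainer (`Trainer`, *optional*): Trainer to pull the tokenizer and eval dataset from.
        tokenizer (*optional*): Tokenizer used to decode inputs, predictions, and labels;
            defaults to `trainer.tokenizer`.
        dataset (*optional*): Dataset sampled for the eval table; defaults to `trainer.eval_dataset`.
        num_samples (`int`, *optional*, defaults to 10): Number of rows sampled into the logged table.
        freq (`int`, *optional*, defaults to 1): Log the eval table every `freq`-th evaluation.
        ignore_tokens (`list`, *optional*): Label token ids to replace with the pad token before
            decoding; defaults to `[-100]`.
    """
```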