Skip to content

Commit

Permalink
Skip and raise NotImplementedError for gather_for_metrics for now (#580)
Browse files Browse the repository at this point in the history
* Skip and raise NotImplementedError for now
  • Loading branch information
muellerzr authored Jul 28, 2022
1 parent c826b51 commit 5030571
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 14 deletions.
11 changes: 10 additions & 1 deletion examples/by_feature/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,23 @@ def training_function(config, args):
accelerator.save_state(output_dir)

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
12 changes: 11 additions & 1 deletion examples/by_feature/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,23 @@ def training_function(config, args):
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
11 changes: 10 additions & 1 deletion examples/by_feature/fsdp_with_peak_mem_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,23 @@ def collate_fn(examples):
# context manager to track the peak memory usage during the evaluation
with TorchTracemalloc() as tracemalloc:
model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
12 changes: 11 additions & 1 deletion examples/by_feature/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,23 @@ def training_function(config, args):
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
12 changes: 11 additions & 1 deletion examples/by_feature/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,23 @@ def inner_training_loop(batch_size):
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
2 changes: 0 additions & 2 deletions examples/by_feature/multi_process_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def training_function(config, args):
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
# All of this can be avoided if you use `Accelerator.gather_for_metrics` instead of `Accelerator.gather`:
# accelerator.gather_for_metrics((predictions, references))
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
11 changes: 10 additions & 1 deletion examples/by_feature/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,23 @@ def training_function(config, args):
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
14 changes: 12 additions & 2 deletions examples/complete_cv_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,25 @@ def training_function(config, args):
accelerator.save_state(output_dir)
model.eval()
accurate = 0
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
inputs = (batch["image"] - mean) / std
with torch.no_grad():
outputs = model(inputs)
predictions = outputs.argmax(dim=-1)
predictions, labels = accelerator.gather_for_metrics((predictions, batch["label"]))
accurate_preds = predictions == labels
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["label"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
accurate_preds = predictions == references
accurate += accurate_preds.long().sum()

eval_metric = accurate.item() / accelerator.gradient_state.samples_seen
Expand Down
11 changes: 10 additions & 1 deletion examples/complete_nlp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,23 @@ def collate_fn(examples):
accelerator.save_state(output_dir)

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
14 changes: 12 additions & 2 deletions examples/cv_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,25 @@ def training_function(config, args):
model.eval()
accurate = 0
num_elems = 0
samples_seen = 0
for _, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
inputs = (batch["image"] - mean) / std
with torch.no_grad():
outputs = model(inputs)
predictions = outputs.argmax(dim=-1)
predictions, labels = accelerator.gather_for_metrics((predictions, batch["label"]))
accurate_preds = predictions == labels
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["label"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
accurate_preds = predictions == references
num_elems += accurate_preds.shape[0]
accurate += accurate_preds.long().sum()

Expand Down
12 changes: 11 additions & 1 deletion examples/nlp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,23 @@ def training_function(config, args):
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
# It is slightly faster to call this once, than multiple times
predictions, references = accelerator.gather((predictions, batch["labels"]))
if accelerator.use_distributed:
if step == len(eval_dataloader) - 1:
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
Expand Down
3 changes: 3 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,9 @@ def gather_for_metrics(self, tensor):
tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
The tensors for calculating metrics across all processes.
"""
raise NotImplementedError(
"Currently there are a number of bugs with this method. You should use `Accelerator.gather()` and drop the samples yourself for the time being."
)
tensor = self.gather(tensor)
if self.use_distributed:
if self.gradient_state.remainder == -1:
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
require_multi_gpu,
require_single_gpu,
require_tpu,
skip,
slow,
)
from .training import RegressionDataset, RegressionModel
Expand Down
5 changes: 5 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)


def skip(test_case):
    """
    Decorator that skips a test unconditionally.

    Args:
        test_case: The test function or test class to mark as skipped.

    Returns:
        The result of applying `unittest.skip("Test was skipped")` to `test_case`.
    """
    # Build the stdlib skip decorator once with the fixed reason, then apply it.
    unconditional_skip = unittest.skip("Test was skipped")
    return unconditional_skip(test_case)


def slow(test_case):
"""
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
Expand Down
2 changes: 2 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
require_cpu,
require_multi_gpu,
require_single_gpu,
skip,
test_metrics,
)
from accelerate.utils import get_launch_prefix, patch_environment


@skip
class MetricTester(unittest.TestCase):
def setUp(self):
mod_file = inspect.getfile(accelerate.test_utils)
Expand Down

0 comments on commit 5030571

Please sign in to comment.