Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update examples to show how to deal with extra validation copies #319

Merged
merged 8 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions examples/by_feature/multi_process_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator, DistributedType
from datasets import load_dataset, load_metric
from transformers import (
AdamW,
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
set_seed,
)


########################################################################
# This is a fully working simple example to use Accelerate,
# specifically showcasing how to properly calculate the metrics on the
# validation dataset when in a distributed system, and builds off the
# `nlp_example.py` script.
#
# This example trains a Bert base model on GLUE MRPC
# in any of the following settings (with the same script):
# - single CPU or single GPU
# - multi GPUS (using PyTorch distributed mode)
# - (multi) TPUs
# - fp16 (mixed-precision) or fp32 (normal precision)
#
# To help focus on the differences in the code, building `DataLoaders`
# was refactored into its own function.
# New additions from the base script can be found quickly by
# looking for the # New Code # tags
#
# To run it in each of these various modes, follow the instructions
# in the readme for examples:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################


MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
"""
Creates a set of `DataLoader`s for the `glue` dataset,
using "bert-base-cased" as the tokenizer.

Args:
accelerator (`Accelerator`):
An `Accelerator` object
batch_size (`int`, *optional*):
The batch size for the train and validation DataLoaders.
"""
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
datasets = load_dataset("glue", "mrpc")

def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs

# Apply the method we just defined to all the examples in all the splits of the dataset
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)

# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
# transformers library
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

def collate_fn(examples):
# On TPU it's best to pad everything to the same length or training will be very slow.
if accelerator.distributed_type == DistributedType.TPU:
return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
return tokenizer.pad(examples, padding="longest", return_tensors="pt")

# Instantiate dataloaders.
train_dataloader = DataLoader(
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
)

return train_dataloader, eval_dataloader


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
correct_bias = config["correct_bias"]
seed = int(config["seed"])
batch_size = int(config["batch_size"])

metric = load_metric("glue", "mrpc")

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

set_seed(seed)
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

# We could avoid this line since the accelerator is set with `device_placement=True` (default value).
# Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
# creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
model = model.to(accelerator.device)

# Instantiate optimizer
optimizer = AdamW(params=model.parameters(), lr=lr, correct_bias=correct_bias)

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

# Prepare everything
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
# prepare method.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Now we train the model
for epoch in range(num_epochs):
model.train()
for step, batch in enumerate(train_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
outputs = model(**batch)
loss = outputs.loss
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
if step % gradient_accumulation_steps == 0:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

model.eval()
samples_seen = 0
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather((predictions, batch["labels"]))
# New Code #
# First we check if it's a distributed system
if accelerator.num_processes > 1:
# Then see if we're on the last batch of our eval dataloader
if step == len(eval_dataloader):
# Last batch needs to be truncated on distributed systems as it contains additional samples
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
references = references[: len(eval_dataloader.dataset) - samples_seen]
else:
# Otherwise we add the number of samples seen
samples_seen += references.shape[0]
metric.add_batch(
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)


def main():
parser = argparse.ArgumentParser(description="Simple example of training script.")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
)
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "correct_bias": True, "seed": 42, "batch_size": 16}
training_function(config, args)


if __name__ == "__main__":
main()
33 changes: 20 additions & 13 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@
if SRC_DIRS is not None:
import checkpointing
import cross_validation
import multi_process_metrics
import tracking

# DataLoaders built from `test_samples/MRPC` for quick testing
# Should mock `{script_name}.get_dataloaders` via:
# @mock.patch("{script_name}.get_dataloaders", mocked_dataloaders)

EXCLUDE_EXAMPLES = ["cross_validation.py"]
EXCLUDE_EXAMPLES = ["cross_validation.py", "multi_process_metrics.py"]


def mocked_dataloaders(accelerator, batch_size: int = 16):
Expand Down Expand Up @@ -182,18 +183,6 @@ def test_checkpointing_by_steps(self):
checkpointing.main()
self.assertTrue(os.path.exists(os.path.join(tmpdir, "step_2")))

@mock.patch("tracking.get_dataloaders", mocked_dataloaders)
def test_tracking(self):
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
tracking.py
--with_tracking
--logging_dir {tmpdir}
""".split()
with mock.patch.object(sys, "argv", testargs):
tracking.main()
self.assertTrue(os.path.exists(os.path.join(tmpdir, "tracking")))

@slow
def test_cross_validation(self):
testargs = """
Expand All @@ -205,3 +194,21 @@ def test_cross_validation(self):
cross_validation.main()
call = mocked_print.mock_calls[-1]
self.assertGreaterEqual(call.args[1]["accuracy"], 0.75)

@mock.patch("multi_process_metrics.get_dataloaders", mocked_dataloaders)
def test_multi_process_metrics(self):
testargs = ["multi_process_metrics.py"]
with mock.patch.object(sys, "argv", testargs):
multi_process_metrics.main()

@mock.patch("tracking.get_dataloaders", mocked_dataloaders)
def test_tracking(self):
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
tracking.py
--with_tracking
--logging_dir {tmpdir}
""".split()
with mock.patch.object(sys, "argv", testargs):
tracking.main()
self.assertTrue(os.path.exists(os.path.join(tmpdir, "tracking")))