[Trainer] fix the placement on device with fp16_full_eval #11322

Merged: 2 commits, Apr 19, 2021
115 changes: 58 additions & 57 deletions src/transformers/trainer.py
@@ -336,7 +336,7 @@ def __init__(
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or (args.deepspeed and args.do_train)
or args.deepspeed
or (args.fp16_full_eval and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
@@ -954,8 +954,15 @@ def train(
# memory metrics - must set up as early as possible
self._memory_tracker.start()

args = self.args

self.is_in_train = True

# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if args.fp16_full_eval and not args.do_train:
self.model = self.model.to(args.device)
Comment on lines +961 to +964 (Collaborator):
Perfect, that's exactly what I wanted! (but wasn't sure of the exact test)
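
For context, below is a minimal evaluate-only sketch of the scenario the `args.fp16_full_eval and not args.do_train` check targets. It is not part of this PR: the checkpoint name and the tiny in-memory dataset are placeholder assumptions, and a CUDA device is assumed so fp16 evaluation can actually run.

```python
# Hypothetical evaluate-only usage exercising the code path touched by this PR.
# The checkpoint and dummy dataset are placeholders, not taken from the PR.
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


class TinyEvalSet(Dataset):
    """Two pre-tokenized examples, just enough for Trainer.evaluate() to run."""

    def __init__(self, tokenizer):
        enc = tokenizer(["a first sentence", "a second sentence"], padding=True)
        self.examples = []
        for i in range(2):
            item = {k: torch.tensor(v[i]) for k, v in enc.items()}
            item["labels"] = torch.tensor(0)
            self.examples.append(item)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]


model_name = "distilbert-base-uncased"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

args = TrainingArguments(
    output_dir="tmp_fp16_eval",
    do_train=False,       # .train() is never called,
    do_eval=True,
    fp16_full_eval=True,  # and the whole evaluation runs in half precision.
)

# With fp16_full_eval set and do_train unset, the __init__ condition above keeps
# the model off the device, and evaluation_loop() later does
# model.half().to(args.device) itself, so the model is moved once, already in fp16.
trainer = Trainer(model=model, args=args, eval_dataset=TinyEvalSet(tokenizer))
print(trainer.evaluate())
```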


if "model_path" in kwargs:
resume_from_checkpoint = kwargs.pop("model_path")
warnings.warn(
@@ -972,17 +979,17 @@
model_reloaded = False
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
set_seed(self.args.seed)
set_seed(args.seed)
self.model = self.call_model_init(trial)
model_reloaded = True
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None

# Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
resume_from_checkpoint = get_last_checkpoint(args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None:
if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
@@ -1003,7 +1010,7 @@
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
self.model = self.model.to(args.device)
self.model_wrapped = self.model

# Keeping track whether we can len() on the dataset or not
@@ -1017,24 +1024,24 @@
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
if args.max_steps > 0:
max_steps = args.max_steps
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
args.max_steps % num_update_steps_per_epoch > 0
)
else:
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(self.args.num_train_epochs)
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(args.num_train_epochs)
else:
# see __init__. max_steps is set when the dataset has no __len__
max_steps = self.args.max_steps
num_train_epochs = int(self.args.num_train_epochs)
max_steps = args.max_steps
num_train_epochs = int(args.num_train_epochs)
num_update_steps_per_epoch = max_steps

delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
if args.deepspeed:
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
@@ -1068,24 +1075,22 @@
# Train!
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif self.args.local_rank != -1:
elif args.local_rank != -1:
world_size = dist.get_world_size()
else:
world_size = 1

total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps * world_size
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
num_examples = (
self.num_examples(train_dataloader)
if train_dataset_is_sized
else total_train_batch_size * self.args.max_steps
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
)

logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")

self.state.epoch = 0
@@ -1099,16 +1104,16 @@
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json"))
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0

logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
"batches in the first epoch."
@@ -1129,17 +1134,17 @@
self.state.is_world_process_zero = self.is_world_process_zero()

# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0).to(self.args.device)
tr_loss = torch.tensor(0.0).to(args.device)
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
self._total_flos = self.state.total_flos
model.zero_grad()

self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not self.args.ignore_data_skip:
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
@@ -1152,23 +1157,19 @@
train_dataloader.dataset.set_epoch(epoch)

if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
self.args.device
)
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = parallel_loader
else:
epoch_iterator = train_dataloader

# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
if args.past_index >= 0:
self._past = None

steps_in_epoch = (
len(epoch_iterator)
if train_dataset_is_sized
else self.args.max_steps * self.args.gradient_accumulation_steps
len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

for step, inputs in enumerate(epoch_iterator):

@@ -1177,13 +1178,13 @@
steps_trained_in_current_epoch -= 1
continue

if step % self.args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

if (
((step + 1) % self.args.gradient_accumulation_steps != 0)
and self.args.local_rank != -1
and self.args._no_sync_in_gradient_accumulation
((step + 1) % args.gradient_accumulation_steps != 0)
and args.local_rank != -1
and args._no_sync_in_gradient_accumulation
):
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
@@ -1196,13 +1197,13 @@
if self.deepspeed:
self.deepspeed.step()

if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
if (step + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= self.args.gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# Gradient clipping
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
# deepspeed does its own clipping

if self.use_amp:
@@ -1211,15 +1212,15 @@

if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
self.optimizer.clip_grad_norm(args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
model.clip_grad_norm_(args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
self.args.max_grad_norm,
args.max_grad_norm,
)

# Optimizer step
@@ -1243,17 +1244,17 @@
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_step_end(args, self.state, self.control)

self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

if self.control.should_epoch_stop or self.control.should_training_stop:
break

self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

if self.args.tpu_metrics_debug or self.args.debug:
if args.tpu_metrics_debug or args.debug:
if is_torch_tpu_available():
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
xm.master_print(met.metrics_report())
@@ -1265,16 +1266,16 @@
if self.control.should_training_stop:
break

if self.args.past_index and hasattr(self, "_past"):
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")

logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if self.args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if is_torch_tpu_available():
xm.rendezvous("load_best_model_at_end")
elif self.args.local_rank != -1:
elif args.local_rank != -1:
dist.barrier()

logger.info(
@@ -1283,7 +1284,7 @@
if isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
if self.place_model_on_device:
self.model = self.model.to(self.args.device)
self.model = self.model.to(args.device)
else:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
Expand All @@ -1299,7 +1300,7 @@ def train(
metrics["total_flos"] = self.state.total_flos
self.log(metrics)

self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()

@@ -1952,7 +1953,7 @@ def evaluation_loop(
model = self._wrap_model(self.model, training=False)

# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, half it first and then put on device
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)

@@ -2288,7 +2289,7 @@ def prediction_loop(
model = self._wrap_model(self.model, training=False)

# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
# ``train`` is running, half it first and then put on device
# ``train`` is running, halve it first and then put on device
if not self.is_in_train and self.args.fp16_full_eval:
model = model.half().to(self.args.device)
