Related to the discussion in #8771 (#8771 (comment)), which suggests MP can be done in 4.3.0 just by calling model.parallelize() after loading. I made a separate issue rather than hijack that one, which is about MP improvements in general.
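For anyone following along, the suggestion amounts to roughly this (a minimal sketch, assuming at least two visible GPUs):

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-large")

# Spread the encoder/decoder blocks across the visible GPUs.
model.parallelize()

# Inputs go to the first device (cuda:0 with the default device map).
inputs = tokenizer("summarize: model parallelism test", return_tensors="pt").to("cuda:0")
summary_ids = model.generate(**inputs)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```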
Information
Model I am using (Bert, XLNet ...): T5
The problem arises when using:
my own modified scripts (details below)
Added one line to finetune_trainer.py after the model is loaded (model.parallelize(), see the diff below):
+++ b/examples/seq2seq/finetune_trainer.py
@@ -215,6 +215,9 @@ def main():
# use task specific params
use_task_specific_params(model, data_args.task)
+ # PJ: Parallelize model
+ model.parallelize()
+
# set num_beams for evaluation
if data_args.eval_beams is None:
data_args.eval_beams = model.config.num_beams
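For reference, parallelize() also accepts an explicit device_map instead of the default even split; roughly (a sketch assuming t5-large's 24 blocks spread over 2 GPUs):

```python
# GPU id -> list of T5 block indices to place on that GPU.
device_map = {
    0: list(range(0, 12)),   # blocks 0-11 on cuda:0
    1: list(range(12, 24)),  # blocks 12-23 on cuda:1
}
model.parallelize(device_map)
```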
The task I am working on is:
an official task/dataset: running the official seq2seq example (xsum summarization)
To reproduce
Steps to reproduce the behavior:
On 4.3.0-dev (tonight):
Fresh pull of transformers. Add change above ( model.parallelize() ).
Run the runscript (below). The error appears to reproduce for models of any size (I'm using t5-11b, but it also happens with t5-large).
...
[INFO|modeling_utils.py:1152] 2021-01-21 00:52:03,923 >> All the weights of T5ForConditionalGeneration were initialized from the model checkpoint at t5-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use T5ForConditionalGeneration for predictions without further training.
01/21/2021 00:52:03 - INFO - utils - setting model.config to task specific params for summarization:
{'early_stopping': True, 'length_penalty': 2.0, 'max_length': 200, 'min_length': 30, 'no_repeat_ngram_size': 3, 'num_beams': 4, 'prefix': 'summarize: '}
01/21/2021 00:52:03 - INFO - utils - note: command line args may override some of these
[INFO|trainer.py:362] 2021-01-21 00:52:14,376 >> Using amp fp16 backend
01/21/2021 00:52:14 - INFO - __main__ - *** Train ***
[INFO|trainer.py:813] 2021-01-21 00:52:14,383 >> ***** Running training *****
[INFO|trainer.py:814] 2021-01-21 00:52:14,383 >> Num examples = 204016
[INFO|trainer.py:815] 2021-01-21 00:52:14,383 >> Num Epochs = 1
[INFO|trainer.py:816] 2021-01-21 00:52:14,383 >> Instantaneous batch size per device = 8
[INFO|trainer.py:817] 2021-01-21 00:52:14,383 >> Total train batch size (w. parallel, distributed & accumulation) = 8
[INFO|trainer.py:818] 2021-01-21 00:52:14,383 >> Gradient Accumulation steps = 1
[INFO|trainer.py:819] 2021-01-21 00:52:14,383 >> Total optimization steps = 25502
0%| | 0/25502 [00:00<?, ?it/s]Traceback (most recent call last):
File "finetune_trainer.py", line 370, in <module>
main()
File "finetune_trainer.py", line 301, in main
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/transformers/trainer.py", line 910, in train
tr_loss += self.training_step(model, inputs)
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/transformers/trainer.py", line 1272, in training_step
loss = self.compute_loss(model, inputs)
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/transformers/trainer.py", line 1300, in compute_loss
outputs = model(**inputs)
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/torch/nn/modules/module.py", line 873, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 1500, in forward
return_dict=return_dict,
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/torch/nn/modules/module.py", line 873, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/pajansen/anaconda3/envs/transformers-4.1.1a/lib/python3.7/site-packages/transformers/models/t5/modeling_t5.py", line 938, in forward
head_mask = head_mask.to(hidden_states.device)
AttributeError: 'list' object has no attribute 'to'
0%| | 0/25502 [00:00<?, ?it/s]
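My reading of the traceback (a toy illustration, not a verified fix): under MP, T5Stack tries to move head_mask onto each block's device, but by that point head_mask has already been expanded into a per-layer list, which has no .to():

```python
import torch

head_mask = [None] * 24  # what get_head_mask() returns when no head mask is passed

try:
    head_mask = head_mask.to("cuda:0")   # what modeling_t5.py line 938 effectively does
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'to'

# A per-element guard avoids the crash (illustrative only, not necessarily the upstream fix):
head_mask = [m.to("cuda:0") if isinstance(m, torch.Tensor) else m for m in head_mask]
```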
For comparison, on transformers 4.1.1 (same change, same runscript), the output works fine:
01/21/2021 01:25:01 - INFO - __main__ - *** Train ***
[INFO|trainer.py:703] 2021-01-21 01:25:01,016 >> ***** Running training *****
[INFO|trainer.py:704] 2021-01-21 01:25:01,016 >> Num examples = 999
[INFO|trainer.py:705] 2021-01-21 01:25:01,016 >> Num Epochs = 1
[INFO|trainer.py:706] 2021-01-21 01:25:01,016 >> Instantaneous batch size per device = 1
[INFO|trainer.py:707] 2021-01-21 01:25:01,016 >> Total train batch size (w. parallel, distributed & accumulation) = 1
[INFO|trainer.py:708] 2021-01-21 01:25:01,016 >> Gradient Accumulation steps = 1
[INFO|trainer.py:709] 2021-01-21 01:25:01,017 >> Total optimization steps = 999
0%| | 0/999 [00:00<?, ?it/s]/home/pajansen/anaconda3/envs/transformers-4.1.1/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:134: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
{'loss': nan, 'learning_rate': 1.4984984984984986e-05, 'epoch': 0.5005005005005005}
50%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 500/999 [02:25<02:20, 3.54it/s][INFO|trainer.py:1226] 2021-01-21 01:27:26,134 >> Saving model checkpoint to xsum-mini_results/checkpoint-500
[INFO|configuration_utils.py:289] 2021-01-21 01:27:26,138 >> Configuration saved in xsum-mini_results/checkpoint-500/config.json
[INFO|modeling_utils.py:814] 2021-01-21 01:27:29,444 >> Model weights saved in xsum-mini_results/checkpoint-500/pytorch_model.bin
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [04:54<00:00, 3.30it/s][INFO|trainer.py:862] 2021-01-21 01:29:55,140 >>
Training completed. Do not forget to share your model on huggingface.co/models =)
{'epoch': 1.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [04:54<00:00, 3.40it/s]
[INFO|trainer.py:1226] 2021-01-21 01:29:55,141 >> Saving model checkpoint to xsum-mini_results
[INFO|configuration_utils.py:289] 2021-01-21 01:29:55,146 >> Configuration saved in xsum-mini_results/config.json
[INFO|modeling_utils.py:814] 2021-01-21 01:29:58,207 >> Model weights saved in xsum-mini_results/pytorch_model.bin
01/21/2021 01:29:58 - INFO - __main__ - ***** train metrics *****
01/21/2021 01:29:58 - INFO - __main__ - train_samples_per_second = -0.003
01/21/2021 01:29:58 - INFO - __main__ - train_runtime = 294.1311
01/21/2021 01:29:58 - INFO - __main__ - train_n_ojbs = -1
(Note: I substituted the xsum dataset above with a shorter version made with head, keeping just the first 1000 lines of each file, to see whether training would run to completion without taking 15 hours on the full example dataset. It looks okay.) It's worth noting that if the validation arguments are added:
then 4.1.1 will die at the checkpoints (500 iterations) with "RuntimeError: Input, output and indices must be on the current device". I don't fully understand that one -- I'm assuming it means train/eval have to be done separately with MP, which is entirely manageable. #9336 showed a similar error, but that person was using BART (which doesn't have MP in 4.1.1) rather than T5, so I don't think it's the same thing.
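What I mean by doing eval separately is roughly this (a sketch; evaluate_batch is just an illustrative helper, and it assumes inputs have to sit on the parallelized model's first device, cuda:0 with the default map):

```python
import torch

def evaluate_batch(model, tokenizer, texts, first_device="cuda:0"):
    """Run generation outside the Trainer, keeping inputs on the model's first device."""
    model.eval()
    enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(first_device)
    with torch.no_grad():
        generated = model.generate(**enc, max_length=200, num_beams=4)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
```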
Expected behavior
Model parallelism -- spreading large models across multiple GPUs.
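A quick way to sanity-check that the split actually happened (a sketch, assuming model has already been parallelized):

```python
from collections import Counter

device_counts = Counter(str(p.device) for p in model.parameters())
print(device_counts)  # e.g. {'cuda:0': ..., 'cuda:1': ...} when the split is active
```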
If you run into other issues, please try this PR: #9323
which has lots of improvements. It just hasn't been merged yet since, I think, we're waiting for me to sort out the whole MP/PP situation before moving forward.
Environment info
transformers version: 4.3.0.dev0
Who can help
@stas00 @alexorona @sgugger