Results on TPU worse than on GPU (using colab) #526

Closed · koba35 opened this issue Jul 17, 2022 · 8 comments
Labels: bug (Something isn't working)

koba35 commented Jul 17, 2022

System Info

- `Accelerate` version: 0.11.0.dev0
- Platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- Numpy version: 1.21.6
- PyTorch version (GPU?): 1.11.0+cu102 (False)
- `Accelerate` default config:
	- compute_environment: LOCAL_MACHINE
	- distributed_type: TPU
	- mixed_precision: no
	- use_cpu: False
	- num_processes: 8
	- machine_rank: 0
	- num_machines: 1
	- main_process_ip: None
	- main_process_port: None
	- main_training_function: main
	- deepspeed_config: {}
	- fsdp_config: {}

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

I created a notebook for reproducing this, but the steps are very simple:

  1. Install all the libraries (I tried different versions; all show the same behavior)
  2. Run accelerate config and choose TPU (or GPU with fp16)
  3. Run accelerate launch accelerate/examples/nlp_example.py

When I train with TPU, I get an F1 score of 0.848. When I train with GPU, I get more than 0.9. I also tried different scripts and always got much worse results with TPU. Maybe it is something Colab-specific, because as far as I can see in other TPU-related issues (for example), people get results similar to my GPU results.

Expected behavior

When I run the example scripts in Colab, I should get similar results with TPU and GPU.
koba35 added the bug label on Jul 17, 2022
muellerzr self-assigned this on Jul 17, 2022
muellerzr (Collaborator) commented Jul 18, 2022

I'm working on a very large benchmark for this, as this issue/confusion has come up multiple times. But please take a look at this issue first, because if you do not adjust the batch size at all, you are not actually benchmarking your results accurately: #450

koba35 (Author) commented Jul 18, 2022

Yep, I tried adjusting the batch size and learning rate (multiplying/dividing them by various factors), but still no luck: TPU in multi-processing performs much worse. I cannot reach the same performance (0.9+ F1) there, though I can now reach it with a single TPU.
BTW, this bug exists not only in Colab but also on a TPU VM.

muellerzr (Collaborator) commented Jul 18, 2022

For us to reproduce your results, please include the exact configurations you used for each test, including each batch size as you adjusted it as well as the learning rates. These are critically important.

Since you are in Colab, could you also try launching via the notebook_launcher to see whether you get similar or worse results?
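
For reference, a minimal launch from a Colab cell looks roughly like this (training_function and the config values are placeholders, not the exact example code):

    from accelerate import notebook_launcher

    # Placeholder hyperparameters, roughly in the spirit of nlp_example.py.
    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}

    def training_function(config):
        # Build the model, dataloaders, optimizer, etc. and run the
        # training loop here (essentially the body of nlp_example.py).
        ...

    # On a Colab TPU runtime this spawns 8 processes, one per TPU core.
    notebook_launcher(training_function, args=(config,), num_processes=8)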

koba35 (Author) commented Jul 18, 2022

Well, I tried dividing or multiplying the lr/bs by powers of 2. And when I launch via the notebook launcher, I get similar results (well, exactly the same results).
But it seems I found the root of the problem: when I init the model outside of the train_fn and multiply the lr by 8, I can get 0.9 F1. I will do more checks tomorrow and will be more specific about whether this fixes the problem.

muellerzr (Collaborator) commented

@koba35 great, thanks! Very curious to see your results there.

koba35 (Author) commented Jul 19, 2022

As it turned out, everything is much more complicated. When I move the model out of the training function and increase the lr, I am able to get normal results, but apparently part of the point is that in this case the model is initialized before we call set_seed. If we call set_seed before we start multiprocessing, the results drop again. It seems to me that in this example several different details just converged: a relatively small dataset, a large initial LR, a fixed number of steps in the scheduler, model initialization inside the function, and a fixed seed inside each process (if I understand correctly, each fork should set a different seed).

In my opinion, it is worth adding the following tweaks (a combined sketch follows the list):

  1. Initialize the model outside the training function
  2. Add the process number to the seed (like seed = int(config["seed"]) + accelerator.process_index)
  3. Add warm-up steps to the scheduler as a percentage of the total number of training steps, e.g. num_warmup_steps=len(train_dataloader) * num_epochs * 0.1
  4. Multiply lr by num_processes.
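
A rough sketch of what these tweaks could look like together (get_dataloaders here is a hypothetical stand-in for the MRPC dataloader setup in nlp_example.py, and the hyperparameter values are placeholders):

    import torch
    from accelerate import Accelerator, notebook_launcher
    from accelerate.utils import set_seed
    from transformers import AutoModelForSequenceClassification, get_linear_schedule_with_warmup

    config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}

    # 1. Create the model once, outside the training function, so every
    #    process starts from identical weights.
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

    def training_function(model, config):
        accelerator = Accelerator()

        # 2. Offset the seed by the process index so each fork draws
        #    different random numbers (e.g. for dataloader shuffling).
        set_seed(int(config["seed"]) + accelerator.process_index)

        train_dataloader, eval_dataloader = get_dataloaders(config)  # hypothetical helper

        # 4. Scale the learning rate by the number of processes.
        optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"] * accelerator.num_processes)

        # 3. Warm up for ~10% of the total number of training steps.
        total_steps = len(train_dataloader) * config["num_epochs"]
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(total_steps * 0.1),
            num_training_steps=total_steps,
        )

        # ... accelerator.prepare(...) everything and run the usual training loop.

    notebook_launcher(training_function, args=(model, config), num_processes=8)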

Unfortunately, I can't say anything more precise; maybe it's actually something else, but I couldn't find it.

huggingface deleted a comment from the github-actions bot on Aug 16, 2022
muellerzr (Collaborator) commented Aug 23, 2022

Hey @koba35, I believe I've finally solved this regression issue. Could you try doing the following in your code:

        model, optimizer, scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
            model, optimizer, scheduler, train_dataloader, eval_dataloader
        )

        scheduler.split_batches = True # <- THIS PART RIGHT HERE
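        # (With split_batches=True the prepared scheduler steps once per batch
        # instead of once per process, so the LR schedule advances at the same
        # rate as a single-process run.)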

and tell me what your results are? This improved my initial findings as I was investigating, and aligned almost perfectly with the results when Accelerate wasn't used at all.

github-actions (bot) commented

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
