-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathaccelerator.py
2617 lines (2216 loc) · 110 KB
/
accelerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import math
import os
import re
import shutil
import sys
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Union

import torch
import torch.utils.hooks as hooks
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
from .optimizer import AcceleratedOptimizer
from .scheduler import AcceleratedScheduler
from .state import AcceleratorState, GradientState, PartialState, parse_flag_from_env
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
from .utils import (
MODEL_NAME,
DeepSpeedPlugin,
DistributedDataParallelKwargs,
DistributedType,
DynamoBackend,
FP8RecipeKwargs,
FullyShardedDataParallelPlugin,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
KwargsHandler,
LoggerType,
MegatronLMPlugin,
PrecisionType,
ProjectConfiguration,
RNGType,
TorchDynamoPlugin,
compare_versions,
convert_model,
convert_outputs_to_fp32,
extract_model_from_parallel,
gather,
get_pretty_name,
has_transformer_engine_layers,
is_bf16_available,
is_deepspeed_available,
is_fp8_available,
is_megatron_lm_available,
is_torch_version,
is_tpu_available,
pad_across_processes,
parse_choice_from_env,
recursively_apply,
reduce,
release_memory,
save,
wait_for_everyone,
)
if is_deepspeed_available():
import deepspeed
from .utils import (
DeepSpeedEngineWrapper,
DeepSpeedOptimizerWrapper,
DeepSpeedSchedulerWrapper,
DummyOptim,
DummyScheduler,
)
if is_fp8_available():
import transformer_engine.common.recipe as te_recipe
from transformer_engine.pytorch import fp8_autocast
if is_megatron_lm_available():
from .utils import (
MegatronEngine,
MegatronLMDummyDataLoader,
MegatronLMDummyScheduler,
MegatronLMOptimizerWrapper,
MegatronLMSchedulerWrapper,
megatron_lm_initialize,
megatron_lm_prepare_data_loader,
megatron_lm_prepare_model,
megatron_lm_prepare_optimizer,
megatron_lm_prepare_scheduler,
)
if is_torch_version(">", "1.10.0"):
from torch.distributed.algorithms.join import Join
if is_tpu_available(check_device=False):
import torch_xla.distributed.xla_multiprocessing as xmp
try:
from torch.optim.lr_scheduler import LRScheduler
except ImportError:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
# Module-level logger for this file, created with the package's own `get_logger` helper (imported from `.logging`).
logger = get_logger(__name__)
class Accelerator:
    """
    Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training.

    Args:
        device_placement (`bool`, *optional*, defaults to `True`):
            Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
            etc...).
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
            `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
            round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
            in your script multiplied by the number of processes.
        mixed_precision (`str`, *optional*):
            Whether or not to use mixed precision training. Choose from 'no', 'fp16', 'bf16' or 'fp8'. Will default to
            the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in
            the accelerate config of the current system or the flag passed with the `accelerate.launch` command.
            'fp16' requires pytorch 1.6 or higher. 'bf16' requires pytorch 1.10 or higher. 'fp8' requires the
            installation of transformer-engine.
        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
            The number of steps over which gradients are accumulated before they are synchronized. A number > 1
            should be combined with `Accelerator.accumulate`. If not passed, will default to the value in the
            environment variable `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a
            `GradientAccumulationPlugin`, but not together with this argument.
        cpu (`bool`, *optional*):
            Whether or not to force the script to execute on CPU. Will ignore any available GPU if set to `True` and
            force the execution on one process only.
        deepspeed_plugin (`DeepSpeedPlugin`, *optional*):
            Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
            directly using *accelerate config*.
        fsdp_plugin (`FullyShardedDataParallelPlugin`, *optional*):
            Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
            using *accelerate config*. FSDP requires PyTorch >= 1.12.0.
        megatron_lm_plugin (`MegatronLMPlugin`, *optional*):
            Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
            directly using *accelerate config*.
        rng_types (list of `str` or [`~utils.RNGType`]):
            The list of random number generators to synchronize at the beginning of each iteration in your prepared
            dataloaders. Should be one or several of:

            - `"torch"`: the base torch random number generator
            - `"cuda"`: the CUDA random number generator (GPU only)
            - `"xla"`: the XLA random number generator (TPU only)
            - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
              dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.

            Will default to `["generator"]`.
        log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
            A list of loggers to be setup for experiment tracking. Should be one or several of:

            - `"all"`
            - `"tensorboard"`
            - `"wandb"`
            - `"comet_ml"`

            If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
            also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
        project_config (`ProjectConfiguration`, *optional*):
            A configuration for how saving the state can be handled.
        project_dir (`str`, `os.PathLike`, *optional*):
            A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved
            checkpoints.
        logging_dir (`str`, `os.PathLike`, *optional*):
            Deprecated, use `project_dir` instead. When passed, it is stored as the `logging_dir` of the project
            configuration and a `FutureWarning` is emitted.
        dispatch_batches (`bool`, *optional*):
            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
            underlying dataset is an `IterableDataset`, `False` otherwise.
        even_batches (`bool`, *optional*, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
            Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
            done under certain circumstances (at the end of each epoch, for instance).
        kwargs_handlers (`List[KwargsHandler]`, *optional*):
            A list of `KwargsHandler` to customize how the objects related to distributed training or mixed precision
            are created. At most one handler of each type may be passed. See [kwargs](kwargs) for more information.
        dynamo_backend (`str` or `DynamoBackend`, *optional*, defaults to `"no"`):
            Set to one of the possible dynamo backends to optimize your training with torch dynamo.
        gradient_accumulation_plugin (`GradientAccumulationPlugin`, *optional*):
            A configuration for how gradient accumulation should be handled, if more tweaking than just the
            `gradient_accumulation_steps` is needed.

    **Available attributes:**

        - **device** (`torch.device`) -- The device to use.
        - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration.
        - **local_process_index** (`int`) -- The process index on the current machine.
        - **mixed_precision** (`str`) -- The configured mixed precision mode.
        - **num_processes** (`int`) -- The total number of processes used for training.
        - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of
          gradient overflow in mixed precision), in which case the learning rate should not be changed.
        - **process_index** (`int`) -- The overall index of the current process among all processes.
        - **state** ([`~state.AcceleratorState`]) -- The distributed setup state.
        - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes.
        - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training.
    """
def __init__(
    self,
    device_placement: bool = True,
    split_batches: bool = False,
    mixed_precision: Union[PrecisionType, str] = None,
    gradient_accumulation_steps: int = 1,
    cpu: bool = False,
    deepspeed_plugin: DeepSpeedPlugin = None,
    fsdp_plugin: FullyShardedDataParallelPlugin = None,
    megatron_lm_plugin: MegatronLMPlugin = None,
    rng_types: Optional[List[Union[str, RNGType]]] = None,
    log_with: Optional[List[Union[str, LoggerType, GeneralTracker]]] = None,
    project_dir: Optional[Union[str, os.PathLike]] = None,
    project_config: Optional[ProjectConfiguration] = None,
    logging_dir: Optional[Union[str, os.PathLike]] = None,
    gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None,
    dispatch_batches: Optional[bool] = None,
    even_batches: bool = True,
    step_scheduler_with_optimizer: bool = True,
    kwargs_handlers: Optional[List[KwargsHandler]] = None,
    dynamo_backend: Union[DynamoBackend, str] = None,
):
    """
    Combine the explicit arguments with the `ACCELERATE_*` environment variables, build the shared
    `AcceleratorState` / `GradientState`, and initialize the bookkeeping used by the rest of the class.

    See the class docstring for the meaning of each argument.

    Raises:
        ValueError: On an unknown `mixed_precision`, FSDP with PyTorch < 1.12.0, a duplicated kwargs handler type,
            both `gradient_accumulation_steps` and `gradient_accumulation_plugin` passed, gradient accumulation
            requested on TPU, an invalid `downcast_bf16` setup, or a mixed precision mode the device cannot run.
        TypeError: If `fsdp_plugin` or `megatron_lm_plugin` has the wrong type.
        AssertionError: If `deepspeed_plugin` or an entry of `kwargs_handlers` has the wrong type.
        ImportError: If DeepSpeed / Megatron-LM is requested but unavailable (or DeepSpeed is older than 0.6.5), or
            `dispatch_batches=True` is used with PyTorch < 1.8.0.
    """
    # --- Project configuration (where logs and checkpoints go) ---
    if project_config is not None:
        self.project_configuration = project_config
    else:
        self.project_configuration = ProjectConfiguration(project_dir=project_dir)
    if logging_dir is not None:
        warnings.warn(
            "`logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. Use `project_dir` instead.",
            FutureWarning,
        )
        self.project_configuration.logging_dir = logging_dir
    # An explicit `project_dir` only fills in a `project_config` that left it unset.
    if project_dir is not None and self.project_dir is None:
        self.project_configuration.project_dir = project_dir

    if mixed_precision is not None:
        mixed_precision = str(mixed_precision)
        if mixed_precision not in PrecisionType:
            raise ValueError(
                f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
            )

    dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend)

    # --- DeepSpeed ---
    if deepspeed_plugin is None:  # init from env variables
        deepspeed_plugin = (
            DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None
        )
    else:
        assert isinstance(
            deepspeed_plugin, DeepSpeedPlugin
        ), "`deepspeed_plugin` must be an `accelerate.utils.DeepSpeedPlugin` object."
        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"  # use DeepSpeed if plugin is provided
    if deepspeed_plugin:
        if not is_deepspeed_available():
            raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
        if compare_versions("deepspeed", "<", "0.6.5"):
            raise ImportError("DeepSpeed version must be >= 0.6.5. Please update DeepSpeed.")
        # The DeepSpeed plugin needs a concrete mode, so fall back to the environment when none was passed.
        mixed_precision = (
            os.environ.get("ACCELERATE_MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
        )
        deepspeed_plugin.set_mixed_precision(mixed_precision)
        deepspeed_plugin.set_deepspeed_weakref()

    # --- FSDP ---
    if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance(
        fsdp_plugin, FullyShardedDataParallelPlugin
    ):
        if is_torch_version("<", "1.12.0"):
            raise ValueError("FSDP requires PyTorch >= 1.12.0")
    if fsdp_plugin is None:  # init from env variables
        fsdp_plugin = (
            FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
        )
    else:
        if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
            raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
        os.environ["ACCELERATE_USE_FSDP"] = "true"  # use FSDP if plugin is provided

    # --- Megatron-LM ---
    if megatron_lm_plugin is None:  # init from env variables
        megatron_lm_plugin = (
            MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
        )
    else:
        if not isinstance(megatron_lm_plugin, MegatronLMPlugin):
            raise TypeError("`megatron_lm_plugin` must be a MegatronLMPlugin object.")
        os.environ["ACCELERATE_USE_MEGATRON_LM"] = "true"  # use MegatronLM if plugin is provided
    if megatron_lm_plugin:
        if not is_megatron_lm_available():
            raise ImportError("Megatron is not installed. please build it from source.")

    # --- Kwargs handlers: at most one handler of each supported type ---
    self.ddp_handler = None
    self.scaler_handler = None
    self.init_handler = None
    self.fp8_recipe_handler = None
    if kwargs_handlers is not None:
        for handler in kwargs_handlers:
            assert isinstance(
                handler, KwargsHandler
            ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
            if isinstance(handler, DistributedDataParallelKwargs):
                if self.ddp_handler is not None:
                    raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.")
                else:
                    self.ddp_handler = handler
            elif isinstance(handler, GradScalerKwargs):
                if self.scaler_handler is not None:
                    raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.")
                else:
                    self.scaler_handler = handler
            elif isinstance(handler, InitProcessGroupKwargs):
                if self.init_handler is not None:
                    raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.")
                else:
                    self.init_handler = handler
            elif isinstance(handler, FP8RecipeKwargs):
                if self.fp8_recipe_handler is not None:
                    raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.")
                else:
                    self.fp8_recipe_handler = handler

    # The process-group kwargs (if any) are forwarded to the shared state.
    kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
    self.state = AcceleratorState(
        mixed_precision=mixed_precision,
        cpu=cpu,
        dynamo_plugin=dynamo_plugin,
        deepspeed_plugin=deepspeed_plugin,
        fsdp_plugin=fsdp_plugin,
        megatron_lm_plugin=megatron_lm_plugin,
        _from_accelerator=True,
        **kwargs,
    )

    # --- Gradient accumulation ---
    # Resolve the plugin before the TPU check below, so that check also sees a `gradient_accumulation_steps`
    # request and never dereferences a `None` plugin.
    if gradient_accumulation_plugin is not None:
        if gradient_accumulation_steps != 1:
            raise ValueError(
                "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object."
            )
    else:
        # `parse_choice_from_env` receives the argument as its fallback value.
        gradient_accumulation_steps = int(
            parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps)
        )
        gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps)
    if self.state.distributed_type == DistributedType.TPU:
        if gradient_accumulation_plugin.num_steps != 1:
            raise ValueError(
                "Gradient accumulation is not supported on TPU. Please set `gradient_accumulation_steps` to 1 and don't pass in a `GradientAccumulationPlugin` object."
            )

    # --- Experiment trackers ---
    trackers = filter_trackers(log_with, self.logging_dir)
    if len(trackers) < 1 and log_with is not None:
        warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
    self.log_with = trackers

    if (
        (mixed_precision != "bf16")
        and getattr(self.state, "downcast_bfloat", False)
        and (self.state.distributed_type != DistributedType.TPU)
    ):
        raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")

    self.gradient_state = GradientState(
        gradient_accumulation_plugin=gradient_accumulation_plugin,
    )

    # --- Dataloader behaviour ---
    self.device_placement = device_placement
    self.split_batches = split_batches
    self.dispatch_batches = dispatch_batches
    if dispatch_batches is True and is_torch_version("<", "1.8.0"):
        raise ImportError(
            f"Using `DataLoaderDispatcher` requires PyTorch 1.8.0 minimum. You have {torch.__version__}."
        )
    self.even_batches = even_batches
    self.step_scheduler_with_optimizer = step_scheduler_with_optimizer

    # --- Mixed precision attributes ---
    self.scaler = None
    self.native_amp = False
    err = "{mode} mixed precision requires {requirement}"
    if (
        self.state.mixed_precision == "fp16"
        and self.device.type != "cpu"
        and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
    ):
        self.native_amp = True
        if not torch.cuda.is_available() and not parse_flag_from_env("ACCELERATE_USE_MPS_DEVICE"):
            raise ValueError(err.format(mode="fp16", requirement="a GPU"))
        kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
        if self.distributed_type == DistributedType.FSDP:
            from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

            self.scaler = ShardedGradScaler(**kwargs)
        else:
            self.scaler = torch.cuda.amp.GradScaler(**kwargs)
    elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
        DistributedType.DEEPSPEED,
        DistributedType.MEGATRON_LM,
    ):
        if self.device.type == "cpu":
            self.native_amp = is_torch_version(">=", "1.10")
        else:
            self.native_amp = is_bf16_available(True)
        # NOTE(review): this tests the `mixed_precision` argument, not `self.state.mixed_precision`, so bf16 that was
        # presumably resolved from the environment/config leaves `native_amp=False` without raising — confirm intended.
        if mixed_precision == "bf16" and not self.native_amp and not is_tpu_available():
            raise ValueError(err.format(mode="bf16", requirement="PyTorch >= 1.10 and a supported device."))

    # Start of internal step tracking
    self.step = 0

    # Internal references to the training objects
    self._optimizers = []
    self._models = []
    self._schedulers = []
    self._dataloaders = []
    self._custom_objects = []

    # Hooks
    self._load_model_state_pre_hook = OrderedDict()
    self._save_model_state_pre_hook = OrderedDict()

    # RNG Types
    self.rng_types = rng_types
    if self.rng_types is None:
        self.rng_types = ["generator"]
@property
def use_distributed(self):
    """
    `True` when the underlying state is configured for distributed training.
    """
    state = self.state
    return state.use_distributed

@property
def distributed_type(self):
    """The distributed training configuration recorded on `self.state`."""
    state = self.state
    return state.distributed_type

@property
def num_processes(self):
    """Total number of processes, as recorded on `self.state`."""
    state = self.state
    return state.num_processes

@property
def process_index(self):
    """Global index of this process, as recorded on `self.state`."""
    state = self.state
    return state.process_index

@property
def local_process_index(self):
    """Index of this process on its own machine, as recorded on `self.state`."""
    state = self.state
    return state.local_process_index

@property
def device(self):
    """Device selected by `self.state`."""
    state = self.state
    return state.device
@property
def project_dir(self):
    """Project directory stored on `self.project_configuration`."""
    config = self.project_configuration
    return config.project_dir

@property
def logging_dir(self):
    """Logging directory stored on `self.project_configuration`."""
    config = self.project_configuration
    return config.logging_dir

@property
def save_iteration(self):
    """The `iteration` counter stored on `self.project_configuration`."""
    config = self.project_configuration
    return config.iteration
@property
def is_main_process(self):
    """Whether this is the single global main process (delegates to `self.state`)."""
    state = self.state
    return state.is_main_process

@property
def is_local_main_process(self):
    """Whether this is the main process of its own server (delegates to `self.state`)."""
    state = self.state
    return state.is_local_main_process
@property
def use_fp16(self):
    """
    Deprecated: whether any mixed precision mode is enabled.

    Returns `True` for every `mixed_precision` value other than `"no"` (so `"bf16"` and `"fp8"` count too), which
    is broader than the `mixed_precision == 'fp16'` replacement suggested by the deprecation warning.
    """
    # stacklevel=2 attributes the deprecation warning to the caller's line rather than to this property.
    warnings.warn(
        "The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use "
        "`Accelerator.mixed_precision == 'fp16'` instead.",
        FutureWarning,
        stacklevel=2,
    )
    return self.mixed_precision != "no"
@property
def is_last_process(self):
    """`True` only on the process whose index is the highest (`num_processes - 1`)."""
    last_index = self.num_processes - 1
    return self.process_index == last_index

@property
def mixed_precision(self):
    """The mixed precision mode recorded on `self.state`."""
    state = self.state
    return state.mixed_precision
def on_main_process(self, function: Callable[..., Any] = None):
    """
    A decorator that will run the decorated function on the main process only. Can also be called using the
    `PartialState` class.

    Args:
        function (`Callable`): The function to decorate.

    Raises:
        `ValueError`: If no function is given, except when decorating a method inside the `Accelerator` class.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()


    >>> @accelerator.on_main_process
    ... def print_something():
    ...     print("This will be printed by process 0 only.")


    >>> print_something()
    This will be printed by process 0 only.
    ```
    """
    # For times when the `Accelerator` class itself uses this decorator on one of its methods: the decorated
    # function then arrives as `self` and `function` is None. Instances do not normally carry `__qualname__`, so it is
    # read with a default; a bare `accelerator.on_main_process()` then reaches the ValueError below.
    if function is None:
        if "Accelerator." in getattr(self, "__qualname__", ""):
            function = self
        else:
            raise ValueError(
                "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
            )

    # Preserve the decorated function's name/docstring on the wrapper.
    @wraps(function)
    def _inner(*args, **kwargs):
        return PartialState().on_main_process(function)(*args, **kwargs)

    return _inner
def on_local_main_process(self, function: Callable[..., Any] = None):
    """
    A decorator that will run the decorated function on the local main process only. Can also be called using the
    `PartialState` class.

    Args:
        function (`Callable`): The function to decorate.

    Raises:
        `ValueError`: If no function is given, except when decorating a method inside the `Accelerator` class.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_local_main_process
    def print_something():
        print("This will be printed by process 0 only on each server.")


    print_something()
    # On server 1:
    "This will be printed by process 0 only"
    # On server 2:
    "This will be printed by process 0 only"
    ```
    """
    # For times when the `Accelerator` class itself uses this decorator on one of its methods: the decorated
    # function then arrives as `self`. Read `__qualname__` with a default so an instance reaches the ValueError.
    if function is None:
        if "Accelerator." in getattr(self, "__qualname__", ""):
            function = self
        else:
            raise ValueError(
                "The `on_local_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
            )

    # Preserve the decorated function's name/docstring on the wrapper.
    @wraps(function)
    def _inner(*args, **kwargs):
        return PartialState().on_local_main_process(function)(*args, **kwargs)

    return _inner
def on_last_process(self, function: Callable[..., Any] = None):
    """
    A decorator that will run the decorated function on the last process only. Can also be called using the
    `PartialState` class.

    Args:
        function (`Callable`, *optional*): The function to decorate. May be omitted only when this decorator is
            applied to a method inside the `Accelerator` class (the method is then received as `self`).

    Raises:
        `ValueError`: If no function is given, except when decorating a method inside the `Accelerator` class.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_last_process
    def print_something():
        print(f"Printed on process {accelerator.process_index}")


    print_something()
    "Printed on process 3"
    ```
    """
    # `function` defaults to None (like the sibling decorators) so the class-body usage below is reachable.
    # Read `__qualname__` with a default so an Accelerator instance reaches the ValueError.
    if function is None:
        if "Accelerator." in getattr(self, "__qualname__", ""):
            function = self
        else:
            raise ValueError(
                "The `on_last_process` decorator must be called with a function on an instantiated `Accelerator` object."
            )

    # Preserve the decorated function's name/docstring on the wrapper.
    @wraps(function)
    def _inner(*args, **kwargs):
        return PartialState().on_last_process(function)(*args, **kwargs)

    return _inner
def on_process(self, function: Callable[..., Any] = None, process_index: int = None):
    """
    A decorator that will run the decorated function on a given process index only. Can also be called using the
    `PartialState` class.

    Args:
        function (`Callable`, *optional*):
            The function to decorate.
        process_index (`int`, *optional*):
            The index of the process on which to run the function.

    Raises:
        `ValueError`: If no function is given, except when decorating a method inside the `Accelerator` class.

    Example:
    ```python
    # Assume we have 4 processes.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_process(process_index=2)
    def print_something():
        print(f"Printed on process {accelerator.process_index}")


    print_something()
    "Printed on process 2"
    ```
    """
    # Initial construction of the decorator: `@accelerator.on_process(process_index=...)` receives no function yet,
    # so return a decorator with the index bound.
    if (self is not None) and (process_index is not None) and (function is None):
        return partial(self.on_process, process_index=process_index)
    # For times when the `Accelerator` class itself uses this decorator on one of its methods (received as `self`).
    # Read `__qualname__` with a default so an Accelerator instance reaches the ValueError.
    if function is None:
        if "Accelerator." in getattr(self, "__qualname__", ""):
            function = self
        else:
            raise ValueError(
                "The `on_process` decorator must be called with a function on an instantiated `Accelerator` object."
            )

    # Preserve the decorated function's name/docstring on the wrapper.
    @wraps(function)
    def _inner(*args, **kwargs):
        return PartialState().on_process(function, process_index)(*args, **kwargs)

    return _inner
def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None):
    """
    A decorator that will run the decorated function on a given local process index only. Can also be called using
    the `PartialState` class.

    Args:
        function (`Callable`, *optional*):
            The function to decorate.
        local_process_index (`int`, *optional*):
            The index of the local process on which to run the function.

    Raises:
        `ValueError`: If no function is given, except when decorating a method inside the `Accelerator` class.

    Example:
    ```python
    # Assume we have 2 servers with 4 processes each.
    from accelerate import Accelerator

    accelerator = Accelerator()


    @accelerator.on_local_process(local_process_index=2)
    def print_something():
        print(f"Printed on process {accelerator.local_process_index}")


    print_something()
    # On server 1:
    "Printed on process 2"
    # On server 2:
    "Printed on process 2"
    ```
    """
    # Initial construction of the decorator: `@accelerator.on_local_process(local_process_index=...)` receives no
    # function yet, so return a decorator with the index bound.
    if (self is not None) and (local_process_index is not None) and (function is None):
        return partial(self.on_local_process, local_process_index=local_process_index)
    # For times when the `Accelerator` class itself uses this decorator on one of its methods (received as `self`).
    # Read `__qualname__` with a default so an Accelerator instance reaches the ValueError.
    if function is None:
        if "Accelerator." in getattr(self, "__qualname__", ""):
            function = self
        else:
            raise ValueError(
                "The `on_local_process` decorator must be called with a function on an instantiated `Accelerator` object."
            )

    # Preserve the decorated function's name/docstring on the wrapper.
    @wraps(function)
    def _inner(*args, **kwargs):
        return PartialState().on_local_process(function, local_process_index)(*args, **kwargs)

    return _inner
@contextmanager
def main_process_first(self):
    """
    Lets the main process go first inside a with block.

    The other processes will enter the with block after the main process exits.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> with accelerator.main_process_first():
    ...     # This will be printed first by process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {accelerator.process_index}")
    ```
    """
    # The state's `main_process_first()` must be *entered*, not merely created and yielded, for its ordering to
    # apply around the caller's block.
    with self.state.main_process_first():
        yield
@contextmanager
def local_main_process_first(self):
    """
    Lets the local main process go first inside a with block.

    The other processes will enter the with block after the local main process exits.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> with accelerator.local_main_process_first():
    ...     # This will be printed first by local process 0 then in a seemingly
    ...     # random order by the other processes.
    ...     print(f"This will be printed by process {accelerator.local_process_index}")
    ```
    """
    # The state's `local_main_process_first()` must be *entered*, not merely created and yielded, for its ordering
    # to apply around the caller's block.
    with self.state.local_main_process_first():
        yield
@contextmanager
def no_sync(self, model):
    """
    A context manager to disable gradient synchronizations across DDP processes by calling
    `torch.nn.parallel.DistributedDataParallel.no_sync`.

    If the accelerator is not distributed, or `model` has no `no_sync` method (e.g. it is not wrapped in DDP), this
    context manager does nothing.

    Args:
        model (`torch.nn.Module`):
            PyTorch Module that was prepared with `Accelerator.prepare`

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)
    >>> input_a = next(iter(dataloader))
    >>> input_b = next(iter(dataloader))

    >>> with accelerator.no_sync(model):
    ...     outputs = model(input_a)
    ...     loss = loss_func(outputs)
    ...     accelerator.backward(loss)
    ...     # No synchronization across processes, only accumulate gradients
    >>> outputs = model(input_b)
    >>> loss = loss_func(outputs)
    >>> accelerator.backward(loss)
    >>> # Synchronization across all processes
    >>> optimizer.step()
    >>> optimizer.zero_grad()
    ```
    """
    # Fall back to a no-op context unless we are distributed AND the model exposes `no_sync`.
    context = contextlib.nullcontext
    if self.use_distributed:
        context = getattr(model, "no_sync", context)

    with context():
        yield
def _do_sync(self):
"Sets the right `sync_gradients` context and either resets or increases `self.step`"
if self.gradient_state.end_of_dataloader:
self.step = 0
self.gradient_state._set_sync_gradients(True)
else:
self.step += 1
self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
@property
def sync_gradients(self):
    """Whether gradients are currently synchronized, as tracked by `self.gradient_state`."""
    grad_state = self.gradient_state
    return grad_state.sync_gradients

@property
def gradient_accumulation_steps(self):
    """Number of accumulation steps configured on `self.gradient_state`."""
    grad_state = self.gradient_state
    return grad_state.num_steps
@contextmanager
def accumulate(self, model):
    """
    A context manager that will lightly wrap around and perform gradient accumulation automatically.

    Enter it once per training step. Each entry advances the internal step counter; unless this step completes an
    accumulation cycle (or the dataloader has ended), the body runs under `no_sync(model)` so gradients are only
    accumulated locally.

    Args:
        model (`torch.nn.Module`):
            PyTorch Module that was prepared with `Accelerator.prepare`

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator(gradient_accumulation_steps=2)
    >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

    >>> for input, output in dataloader:
    ...     with accelerator.accumulate(model):
    ...         outputs = model(input)
    ...         loss = loss_func(outputs)
    ...         accelerator.backward(loss)
    ...         optimizer.step()
    ...         scheduler.step()
    ...         optimizer.zero_grad()
    ```
    """
    # Update `sync_gradients` for this step before choosing the context.
    self._do_sync()
    if self.sync_gradients:
        context = contextlib.nullcontext
    else:
        context = self.no_sync

    with context(model):
        yield
@contextmanager
def join_uneven_inputs(self, joinables, even_batches=None):
    """
    A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper
    around `torch.distributed.algorithms.join`. This is useful when the total batch size does not evenly divide the
    length of the dataset.

    Args:
        joinables (`List[torch.distributed.algorithms.Joinable]`):
            A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a
            PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training.
        even_batches (`bool`, *optional*)
            If set, this will override the value of `even_batches` set in the `Accelerator`. If it is not provided,
            the default `Accelerator` value will be used.

    Raises:
        ValueError: If the installed PyTorch version is older than 1.10.0.

    <Tip warning={true}>

    `join_uneven_inputs` is only supported for Distributed Data Parallel training on multiple GPUs. For any other
    configuration, this method will have no effect.

    </Tip>

    <Tip warning={true}>

    Overriding `even_batches` will not affect iterable-style data loaders.

    </Tip>

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator(even_batches=True)
    >>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    >>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
    ...     for input, output in dataloader:
    ...         outputs = ddp_model(input)
    ...         loss = loss_func(outputs)
    ...         accelerator.backward(loss)
    ...         optimizer.step()
    ...         optimizer.zero_grad()
    ```
    """
    if is_torch_version("<", "1.10.0"):
        raise ValueError(f"Joining uneven inputs requires PyTorch >= 1.10.0, You have {torch.__version__}.")

    if self.distributed_type == DistributedType.MULTI_GPU:
        # (dataloader index, original even_batches) pairs, used to restore overridden samplers on exit.
        dl_even_batches_values = []

        if even_batches is not None:
            iterable_dl_seen = False
            # override value in batch sampler for map-style datasets
            for dl_idx, dl in enumerate(self._dataloaders):
                batch_sampler = getattr(dl, "batch_sampler", None)
                # Dispatcher loaders, and loaders whose batch sampler has no `even_batches` flag, cannot be
                # overridden: skip them (and warn below) instead of failing with an AttributeError.
                if isinstance(dl, DataLoaderDispatcher) or not hasattr(batch_sampler, "even_batches"):
                    iterable_dl_seen = True
                    continue
                dl_even_batches_values.append((dl_idx, batch_sampler.even_batches))
                batch_sampler.even_batches = even_batches

            if iterable_dl_seen:
                warnings.warn(
                    "Overridding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
                )
        else:
            even_batches = self.even_batches

        # Join is only enabled when even batches are off, i.e. when processes may see different numbers of batches.
        enable_join = not even_batches
        try:
            with Join(joinables, enable=enable_join, throw_on_early_termination=False):
                yield
        finally:
            # reset any batch samplers that have been modified
            for dl_idx, even_batches_value in dl_even_batches_values:
                self._dataloaders[dl_idx].batch_sampler.even_batches = even_batches_value
    else:
        # Even when disabled, Join expects models to subclass Joinable, so skip entirely for single process runs
        if self.distributed_type != DistributedType.NO:
            warnings.warn(
                "Joining uneven inputs is only supported for multi-GPU training, as a result `join_uneven_inputs` will have no effect."
            )

        with contextlib.nullcontext(joinables):
            yield
def print(self, *args, **kwargs):
    """
    Drop-in replacement for the builtin `print()` that only prints once per server.

    All positional and keyword arguments are forwarded unchanged to `self.state.print`.

    Example:

    ```python
    >>> from accelerate import Accelerator

    >>> accelerator = Accelerator()
    >>> accelerator.print("Hello world!")
    ```
    """
    state_print = self.state.print
    state_print(*args, **kwargs)
def _prepare_one(self, obj, first_pass=False, device_placement=None):
    """
    Dispatch a single object to the matching `prepare_*` method.

    Args:
        obj: The object to prepare.
        first_pass (`bool`): When true, DataLoaders, modules and optimizers are prepared (forwarding
            `device_placement`). When false, only LR schedulers are prepared, because they need the full list of
            already-prepared optimizers.
        device_placement: Forwarded to the first-pass `prepare_*` methods.

    Returns:
        The prepared object, or `obj` unchanged if no rule applies in the current pass.
    """
    if not first_pass:
        # Second pass of preparation: LR scheduler (which need the full list of optimizers)
        if isinstance(obj, LRScheduler):
            return self.prepare_scheduler(obj)
        return obj

    # First pass of preparation: DataLoader, model, optimizer
    if isinstance(obj, torch.utils.data.DataLoader):
        return self.prepare_data_loader(obj, device_placement=device_placement)
    if isinstance(obj, torch.nn.Module):
        return self.prepare_model(obj, device_placement=device_placement)
    if isinstance(obj, torch.optim.Optimizer):
        return self.prepare_optimizer(obj, device_placement=device_placement)
    return obj
def _prepare_fsdp(self, *args):
result = []
for obj in args:
if isinstance(obj, torch.nn.Module):
model = obj
break
optimizers = []
self._schedulers = []
self._models = []
intermediate_result = []
for obj in args:
if isinstance(obj, torch.optim.Optimizer):
if len(obj.param_groups) > 1:
logger.warning(
"FSDP Warning: When using FSDP, several parameter groups will be conflated into "
"a single one due to nested module wrapping and parameter flattening."
)
try:
optimizer = obj.optimizer.__class__(model.parameters(), **obj.optimizer.defaults)