# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import collections
from dataclasses import dataclass
from enum import Enum, auto
import functools
import itertools
import logging
from math import inf
import os
import tempfile
import time
import traceback
import typing
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Mapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
Deque,
)
import torch
from torch.autograd import Variable
import torch.distributed as dist
from torch.distributed import ProcessGroup
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
ProcessGroupName,
enable_pytorch_sync_bn,
get_process_group_cached,
validate_process_group,
)
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from . import fsdp_optim_utils as ou
if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401
# TODO: Remove the toggle here when github open issue #801 is resolved.
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
enable_nccl_base_collectives = False
else:
enable_nccl_base_collectives = True
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True
except ImportError:
# The latest nightly PyTorch version is required for ssd_offload.
import_ssd_offload = False
pass
from logging import getLogger
logger = getLogger()
class _FreeEventQueue:
"""
This tracks all pending frees corresponding to inflight all-gathers. The
queueing pattern is iterative enqueues with a single dequeue per iteration
once the limit ``_max_num_inflight_all_gathers`` is reached.
"""
def __init__(self) -> None:
self._queue: Deque[torch.cuda.Event] = collections.deque()
self._max_num_inflight_all_gathers = 0 # empirically chosen
def enqueue(self, free_event: torch.cuda.Event) -> None:
"""Enqueues a free event."""
self._queue.append(free_event)
def dequeue_if_needed(self) -> Optional[torch.cuda.Event]:
"""Dequeues a single event if the limit is reached."""
if len(self._queue) >= self._max_num_inflight_all_gathers:
return self._dequeue()
return None
def _dequeue(self) -> Optional[torch.cuda.Event]:
"""Dequeues a free event if possible."""
if self._queue:
event = self._queue.popleft()
return event
return None
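# A rough sketch of how this queue is meant to rate-limit inflight all-gathers
# (illustrative only; the names below are placeholders, and the real call
# sites appear later in this file):
#
#     queue = _FreeEventQueue()
#     ...
#     event = torch.cuda.Event()
#     event.record()                 # recorded right after issuing an all-gather
#     queue.enqueue(event)
#     ...
#     pending = queue.dequeue_if_needed()
#     if pending is not None:
#         pending.synchronize()      # wait before issuing the next all-gather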
class TrainingState(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
to make sure APIs are called in the correct state.
.. note::
The BACKWARD_PRE and BACKWARD_POST states are used to ensure we
receive backward hooks in the correct order. They are used to catch
an unexpected ordering of hook calls (likely due to our
hook registration logic or autograd engine logic changes).
TODO (Min): It would be nice to capture the stepping state as well.
Maybe we can use the model.zero_grad() call, but not sure if it
is called if optim.zero_grad() is used instead.
It would be nice to have the state transitions be explicit, like:
zero_grad -> fwd -> bwd -> optionally accum grad by repeating
fwd/bwd -> stepping -> loop back to zero_grad
"""
IDLE = auto()
FORWARD = auto()
BACKWARD_PRE = auto()
BACKWARD_POST = auto()
SUMMON_FULL_PARAMS = auto()
# Data classes containing FSDP parameter constructs
# Offload config for specifying SSD options (initially at least)
@dataclass
class OffloadConfig:
"""Class for specifying all arguments related to offloading parameters."""
# Offload type: currently only supports: "ssd_offload"
offload_type: Optional[str] = None
# Path to the directory for storing parameters offloaded to disk.
dir: Optional[str] = None
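# A minimal sketch of how an OffloadConfig might be constructed and passed to FSDP
# (SSD offload is experimental; "/mnt/ssd/fsdp_params" is just a placeholder path):
#
#     offload_cfg = OffloadConfig(offload_type="ssd_offload", dir="/mnt/ssd/fsdp_params")
#     sharded_module = FullyShardedDataParallel(
#         my_module, flatten_parameters=True, offload_config=offload_cfg
#     )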
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
FullyShardedDataParallel is commonly shortened to FSDP.
.. _`Xu et al.`: https://arxiv.org/abs/2004.13336
.. _DeepSpeed: https://www.deepspeed.ai/
Pseudo-code usage::
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
torch.cuda.set_device(device_id)
sharded_module = FSDP(my_module)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
x = sharded_module(x, y=3, z=torch.Tensor([1]))
loss = x.sum()
loss.backward()
optim.step()
It is also possible to shard individual layers separately and have an outer
wrapper handle any leftover parameters. This can further reduce GPU
memory usage, reduce system memory usage when initializing large
models, and improve training speed by overlapping the all-gather step
across the forward pass. For example::
import torch
from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils.testing import dist_init, teardown, rmf
result = dist_init(0, 1, "/tmp/t1", "/tmp/t2")
assert result
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(l1, FSDP)
# Wraps layer in FSDP by default if within context
# Separately wraps child modules with more than 1e8 params
large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12,
num_decoder_layers=12)
l2 = auto_wrap(large_tfmr)
assert isinstance(l2.encoder, FSDP)
assert isinstance(l2.decoder, FSDP)
print(l2) # You can print the model to examine FSDP wrapping.
teardown()
rmf("/tmp/t1")
rmf("/tmp/t2")
.. warning::
The optimizer must be initialized *after* the module has been wrapped,
since FSDP will shard parameters in-place and this will break any
previously initialized optimizers.
.. warning::
If you wrap every parameter inside a nested FSDP and leave the outer
FSDP without any parameters, activation checkpointing may trigger
an assert during the backward pass. The solution is to leave some parameters
to the outer FSDP.
.. warning::
If activation checkpointing is used with FSDP, it is strongly encouraged
to use the ``checkpoint_wrapper`` function from FairScale instead of the
``checkpoint`` function from PyTorch.
Args:
module (nn.Module):
module to be wrapped with FSDP.
process_group (Optional):
process group for sharding
process_group_reduce_scatter (Optional):
process group for the reduce-scatter operation.
It defaults to ProcessGroupName.reduce_scatter, in which case a separate process group is
initialized and assigned to the reduce_scatter operation, and the reduce_scatter operation
overlaps with other operations in the backward propagation.
If a specific ProcessGroup is given, the reduce_scatter operates on that ProcessGroup and the overlap still happens.
To disable the overlap feature, set the process group to ProcessGroupName.default. In this case, the reduce_scatter
operation uses the same process group as the default group.
If the reduce-scatter process group size differs from the default process group size, the reduce_scatter
operation falls back to using the same process group as the default process group.
reshard_after_forward (bool, Optional):
if ``True``, reshard parameters after the forward pass. This saves
memory but slows training. This is only relevant when resharding
individual layers.
disable_reshard_on_root (bool, Optional):
If ``True``, ``reshard_after_forward`` will be set to ``False`` if the module is an
FSDP root module, to improve performance. In many cases, we do not reshard the full
parameters of an FSDP root module since those parameters are needed immediately for the
backward pass.
If ``False``, performance will be lower, but this setting helps save memory. Consider the
case where an FSDP root module is a submodule of a larger model: the backward pass may not
start immediately after the FSDP root module finishes its forward pass, so resharding the
parameters of the FSDP root module can help save memory in this case.
Default: True.
mixed_precision (bool, Optional):
if ``True``, inputs, activations and gradients will be kept in FP16;
computation and communication will occur in FP16; and a (sharded)
master copy of the model weights will be maintained in FP32.
fp32_reduce_scatter (bool, Optional):
if ``True``, then reduce-scatter gradients in FP32. This is only
relevant when *``mixed_precision``* is ``True``.
flatten_parameters (bool, Optional):
if ``True``, flatten parameters into a single contiguous tensor,
which improves training speed.
move_params_to_cpu (bool, Optional):
if ``True``, offload params to CPU.
compute_dtype (torch.dtype, Optional):
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
it defaults to ``torch.float16``.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
move_grads_to_cpu (bool, Optional):
move gradient shard to CPU after reduction. This is useful when
combined with CPU-based optimizers. It defaults to the value of
*``move_params_to_cpu``*.
bucket_cap_mb (int, Optional):
FSDP will bucket parameters so that gradient reduction can
be more efficient for small parameters.
``bucket_cap_mb`` controls the bucket size in MegaBytes (MB). Buckets
are sub-divided based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. There is one bucketer (with potentially
multiple ``bucket_cap_mb``-sized buffers) shared by all FSDP instances.
Large gradient tensors are directly reduced without using the buffers.
The buffers are there to reduce communication overhead for small tensors.
Overlapping with computation happens due to use of a different CUDA stream
than the computation CUDA stream. The total memory overhead per buffer is around
``bucket_cap_mb / world_size * (world_size + 1)``.
The buffers are allocated during the backward pass and freed at the end
of the backward pass to save more memory for other phases of the
training process.
Note, the memory vs. speed tradeoff of bucket size is very different
from that of the DDP engine. In DDP, the buffer size is ``1MB + n*cap_mb``,
with ``n`` growing until the buffers cover the entire model size. The order
in which buffers become ready is more rigid there, and DDP requires all
gradients to be computed in the backward pass. In FSDP, the buffer size
does not change with model size (it changes based on the number of
<dtype, device, process_group> tuples) and gradient-ready order matters
little, since FSDP has a final flush call that ensures everything is reduced
and not all gradients need to be known upfront. Overlapping with compute is
done differently too.
Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
no_broadcast_optim_state (bool, Optional):
do not broadcast this module's optimizer state when ``gather_full_optim_state_dict`` is called.
If you set this to ``True``, you are expected to overwrite the relevant state entries of the returned
optimizer state dict with the proper state at each rank. This is useful for situations, like Mixture of Experts,
where all but a few parameters can fit on one node.
Default: False
state_dict_device (torch.device, Optional):
device for parameters returned by :func:`state_dict`. If not given,
this will default to ``compute_device``. Note that only the device
type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
clear_autocast_cache (bool):
When using mixed precision training with `torch.amp.autocast`, if the model weights
are in FP32, autocast maintains a cache of downcasted weights. This cache can cause
GPU OOM during the forward pass. Setting this flag to ``True`` helps clear the
cache as inner FSDP instances finish their part of the forward pass, to save GPU memory.
Default: False
force_input_to_fp32 (bool):
Set to ``True`` to force input floating point tensors to be FP32 (if they are FP16)
when the FSDP instance is in full precision mode. This helps avoid issues when running
SyncBatchNorm with AMP and checkpoint_wrapper.
Default: False
verbose (bool):
Set this to ``True`` to turn on verbose output for the model's string representation.
Default: False
cpu_offload (bool, Optional):
if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
*``move_params_to_cpu``* in an upcoming release.
offload_config (OffloadConfig):
The `OffloadConfig` object is used to specify the type of offload (i.e. SSD or CPU) and
other required knobs when offloading parameters from the GPU. Currently, OffloadConfig
only supports specifying SSD offload as an option. Note: This is an experimental feature.
state_dict_on_rank_0_only (bool):
When set to ``True``, ``model.state_dict()`` will only return the full state dict on
rank 0 and return an empty dict on non-zero ranks, which allows FullyShardedDataParallel to
skip the GPU -> CPU copy on non-zero ranks altogether and prevents OOM.
Default: False
optimize_backward_concat (bool):
If ``True``, only let the backward pass propagate to ``self.params`` (which invokes
``_post_backward_hook()`` and the ``concat()`` op) when ``self._require_backward_grad_sync``
is ``True`` (e.g. on the last microbatch).
Default: False
NOTE: this will likely incur more GPU memory usage.
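Example (a minimal sketch of enabling mixed precision with FP32 gradient
reduction; ``my_module`` is a placeholder, as in the pseudo-code above)::

    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    sharded_module = FSDP(
        my_module,
        mixed_precision=True,        # compute/communicate in FP16, keep FP32 master weights
        fp32_reduce_scatter=True,    # reduce-scatter gradients in FP32 for better numerics
        reshard_after_forward=True,  # free full params after forward to save memory
    )
    optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)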
"""
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
# The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.default,
reshard_after_forward: bool = True,
disable_reshard_on_root: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
move_params_to_cpu: bool = False,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None,
clear_autocast_cache: bool = False,
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
limit_all_gather_events: bool = False,
limit_reduce_scatter_events: bool = False,
cast_input: bool = True,
optimize_backward_concat: bool = False,
):
try:
import torch._C
torch._C._log_api_usage_once("fairscale.fsdp")
except ImportError:
pass
init_start = time.time()
super().__init__()
self.process_group = process_group or get_process_group_cached()
# If ProcessGroupName.default is passed in, the reduce_scatter will use the same process group as
# the rest of the operations. The overlap feature in the backward propagation is disabled.
if process_group_reduce_scatter == ProcessGroupName.default:
self.process_group_reduce_scatter = self.process_group
# If ProcessGroupName.reduce_scatter is passed in, the reduce_scatter uses a separate process group
# so that the overlap feature in the backward propagation is enabled.
elif process_group_reduce_scatter == ProcessGroupName.reduce_scatter:
self.process_group_reduce_scatter = get_process_group_cached(ProcessGroupName.reduce_scatter)
else:
# If a specific process group is passed in, the reduce_scatter will use the passed in process group.
if isinstance(process_group_reduce_scatter, ProcessGroup):
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
if not hasattr(process_group_reduce_scatter, "allgather") and hasattr(
process_group_reduce_scatter, "rank"
):
# Likely a dummy pg for unit test
self.process_group_reduce_scatter = process_group_reduce_scatter
else:
raise TypeError("unsupported type for reduce_scatter process group")
self.rank = self.process_group.rank()
self.world_size = self.process_group.size()
# In a dummy unit test environment, the process_group_reduce_scatter can be None.
if self.process_group_reduce_scatter is not None:
reduce_scatter_group_size = self.process_group_reduce_scatter.size()
# Roll back to using the default process group for the reduce-scatter operation when
# the world size and the reduce-scatter process group size are different.
if self.world_size != reduce_scatter_group_size:
self.process_group_reduce_scatter = self.process_group
logging.warning(
"Rolled back to use the default process group for the reduce scatter "
"operation because the reduce_scatter process group "
f"size is {reduce_scatter_group_size}, which is different from the "
f"world size {self.world_size}. Please make sure the process_group "
"parameter uses all the available ranks for optimal performance."
)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.cast_input = cast_input
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.move_params_to_cpu = move_params_to_cpu or cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device or _get_default_cuda_device(module)
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size
)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self.numel_padded_per_param: List[int] = []
self._tstart = time.time()
# if self.fp32_reduce_scatter and not self.mixed_precision:
# raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.ssd_offload and not self.flatten_parameters:
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
# skip validation if the process group was created above
if process_group:
validate_process_group(self.compute_device, self.process_group)
# enable pytorch sync_bn just in case model contains sync_bn layers.
enable_pytorch_sync_bn(module)
# Only handle params which are not already sharded. This enables
# sharding individual layers of a Module, with an outer wrapper to
# shard any leftover parameters.
param_names = []
params = []
for param_name, param in module.named_parameters():
if not hasattr(param, "_is_sharded"):
param_names.append(param_name)
params.append(param)
self._has_params = len(params) > 0
self._has_shared_params = False
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
self.buffer_size = sum(p.numel() for p in params)
self.ssd_directory = tempfile.gettempdir()
if self.ssd_offload:
assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
if offload_config and offload_config.dir:
self.ssd_directory = offload_config.dir
self.move_grads_to_cpu = True
self.move_params_to_cpu = True
# For now, either all parameters are flattened or none are. This will be extended to
# multiple flatten groups in a follow-up PR.
to_be_flatten_params: List[List[Parameter]] = [[]]
non_flatten_params = params
param_name_groups = [[n] for n in param_names]
if self.flatten_parameters:
to_be_flatten_params = [params]
non_flatten_params = []
param_name_groups = [param_names]
del param_names
self.optimize_backward_concat = optimize_backward_concat
if self.optimize_backward_concat:
assert self.fp32_reduce_scatter, f"{optimize_backward_concat=} requires self.fp32_reduce_scatter=True"
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory, optimize_backward_concat=self.optimize_backward_concat,
)
del module # free original module in case it helps garbage collection
# Now, in this FSDP wrapper class, we keep a list of to-be-flattened and not-to-be-flattened
# params for doing sharding, gradient hooks, etc. Note, the ordering of the
# list matters: flattened params always come first.
#
# The self._num_flatten_params and self._param_name_groups are computed
# and kept here to support summon_full_params and shard-to-full weight
# consolidation.
self.params = cast(List[Parameter], self._fsdp_wrapped_module.flat_params) + non_flatten_params
self._num_flatten_params = len(self._fsdp_wrapped_module.flat_params)
self._param_name_groups = param_name_groups
# Shard module parameters in place
self._shard_parameters_()
# Make sure all parameters are sharded.
for n, p in self.named_parameters():
assert hasattr(p, "_is_sharded"), f"found unsharded parameter: {n} ; {p.size()}"
self._reset_lazy_init()
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self._require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
# Register hook after state_dict() to remove the "_fsdp_wrapped_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
# TODO(anj): This should by default be set to False for ssd_offload=True
# unless we are in the summon_full_params context.
self._return_full_state_dict = True
init_end = time.time()
logging.debug(
f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
)
# Flag to guard against preparing gradients multiple times per iteration.
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
# Free all params at the end of initialization.
if self.ssd_offload:
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()
self.dont_wait_current_stream_for_post_all_gather = False
self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor *= 2
return float(factor)
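# For example, with world_size=64 the loop above stops at factor=8 (roughly
# sqrt(world_size)), so gradients are pre-divided by 8 before the reduction and
# post-divided by 64 / 8 = 8 afterwards. Splitting the division this way helps
# avoid FP16 overflow/underflow compared to dividing by the full world size at once.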
def set_gradient_divide_factors(self, pre: float, post: float, recursive: bool) -> None:
"""Allowing user to override the pre and post divide factors.
Args:
pre (float): divide factor before the reduction.
post (float): divide factor after the reduction.
recursive (bool): recursively set it for all child FSDP instances or not.
"""
self.assert_state(TrainingState.IDLE)
if recursive:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel) and module != self:
module.set_gradient_divide_factors(pre, post, False)
self.gradient_predivide_factor = pre
self.gradient_postdivide_factor = post
@property
def module(self) -> FlattenParamsWrapper:
"""make model.module accessible, just like DDP."""
assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
return self._fsdp_wrapped_module
def append_shared_param(self, p: Parameter) -> None:
"""Add a param that's already owned by another FSDP wrapper.
.. warning:: This is experimental!
This only works when all sharing FSDP modules are un-flattened.
``p`` must already be sharded by the owning module.
Check the corresponding unit test to see how it is used and tested.
In particular, the sharing FSDP wrappers are "siblings", not "parent"
and "child", of each other in the nested module structure.
Args:
p (Parameter):
The shared parameter.
"""
assert self._is_root is None
assert not self.flatten_parameters
assert isinstance(p, Parameter)
assert p._is_sharded
p._is_shared = True
assert (
len(list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))) > 0
), "Must have at least 1 non-shared param."
self.params.append(p)
self._has_shared_params = True
def non_shared_params(self) -> List[nn.Parameter]:
"""Return the list of non-shared parameters."""
if self._has_shared_params:
return list(filter(lambda p: not (hasattr(p, "_is_shared") and p._is_shared), self.params))
else:
return self.params
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.
Args:
fn (Callable): function to be applied to each submodule
Returns:
Module: self
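Example (a minimal weight-init sketch; ``sharded_module`` as in the
class-level pseudo-code)::

    def init_weights(m: torch.nn.Module) -> None:
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)

    # fn runs while the full (unsharded) parameters are gathered
    sharded_module.apply(init_weights)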
"""
is_uninitialized = self._is_root is None
self.assert_state(TrainingState.IDLE)
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on children instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value
def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested FSDP instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, FullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if (p.requires_grad and (p.grad is not None or p.main_grad is not None))]
@torch.no_grad()
def clip_grad_norm_(
self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
# filter_params_fn: Callable[[Any], Any] = None,
) -> torch.Tensor:
"""
Clip all gradients at this point in time. The norm is computed over all
gradients together, as if they were concatenated into a single vector.
Gradients are modified in-place.
Args:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'``
for infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
.. note:: This is analogous to `torch.nn.utils.clip_grad_norm_` but
handles the partitioning and multiple devices per rank under the
hood. The default torch util is not applicable here, because each
rank only has a partial view of all the grads in the model, so
calling it in the OSS context would lead to different scaling being
applied per subset of model parameters.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
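Example (a minimal sketch; ``sharded_module``, ``x`` and ``optim`` as in the
class-level pseudo-code)::

    loss = sharded_module(x).sum()
    loss.backward()
    # must run on every rank; returns the total norm viewed as a single vector
    sharded_module.clip_grad_norm_(max_norm=1.0)
    optim.step()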
"""
# We don't call torch.cuda.synchronize() here, since clipping can be
# inside the train loop and we probably don't want to force a GPU-CPU sync.
# _lazy_init should be sufficient, since it will force the other streams
# to sync with the default stream (via _wait_for_previous_optim_step).
self._lazy_init()
assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
self.assert_state(TrainingState.IDLE)
max_norm = float(max_norm)
norm_type = float(norm_type)
params_with_grad = self.params_with_grad
if not self.children_share_process_group:
raise NotImplementedError(
"clip_grad_norm requires that all params share one process group. clip_grad_by_value_ should work"
)
# Computes the max norm for this shard's gradients and sync's across workers
local_norm = calc_grad_norm(params_with_grad, norm_type).cuda()
if norm_type == inf:
total_norm = local_norm
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
else:
total_norm = local_norm**norm_type
dist.all_reduce(total_norm, group=self.process_group)
total_norm = total_norm ** (1.0 / norm_type)
if self.move_grads_to_cpu:
total_norm = total_norm.cpu()
# Now multiply each grad by (max_norm / total_norm), same as torch 1.7 (https://tinyurl.com/3wtxhhqq)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
# multiply by clip_coef
for p in params_with_grad:
assert p.grad is not None
p.grad.detach().mul_(clip_coef.to(p.grad.device))
return total_norm
@torch.no_grad()
def _shard_parameters_(self) -> None:
"""
At initialization we wrap a module with full parameters and shard the
parameters in-place. Sharding is implemented by viewing each parameter
as a 1D Tensor and retaining only a single slice, where the slice size
is determined by the number of data parallel workers.
Wrapping modules with many small parameters (or with a very large data
parallel world size) will result in many small parameter shards and slow
performance. In this case it's better to set *``flatten_parameters``* to
``True``, so that all of the small parameters in the module are combined
into a single contiguous Tensor and sharded once.
After this initial sharding is complete, the user can initialize a
``torch.optim.Optimizer`` in the usual way, i.e.::
.. code-block:: python
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
The optimizer will see only a single slice of parameters and will thus
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.numel_padded_per_param = []
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
if self.mixed_precision:
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
p._orig_size = p.data.size()
if not p._is_sharded:
if not self.ssd_offload:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# Replace p.data with the relevant shard.
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
p.to_file()
else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[self.rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard, num_to_pad
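# Worked example (hypothetical numbers): with world_size=4 and a flattened
# tensor of 10 elements, torch.chunk yields chunks of sizes [3, 3, 3, 1].
# Rank 3's shard is padded with 2 zeros so that every rank holds exactly 3
# elements, and num_to_pad=2 is recorded in numel_padded_per_param.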
def extra_repr(self) -> str:
repr = (
f"world_size={self.world_size}, "
f"flatten_parameters={self.flatten_parameters}, "
f"mixed_precision={self.mixed_precision}, "
)
if self.verbose:
repr = (
f"self={id(self)} is_root={self._is_root}, "
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"move_params_to_cpu={self.move_params_to_cpu}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"optimize_backward_concat={self.optimize_backward_concat}"
)
return repr
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.module, name)
def __getstate__(self) -> Dict[str, str]:
"""Serialize the state of the current FSDP instance.
Some properties are not serializable (e.g., process groups, streams), so
we remove them and try to reconstruct them in :func:`__setstate__`.
"""
state = copy.copy(self.__dict__)
state["is_sharded"] = [p._is_sharded for p in self.params]
state["orig_sizes"] = [p._orig_size for p in self.params]
if state["process_group"] is not None:
state["process_group"] = "MISSING" # process_group isn't pickleable
if state["process_group_reduce_scatter"] is not None:
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
self._reset_lazy_init()
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Intercept state setting and perform needed changes on params."""
super().__setstate__(state)
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
p._is_sharded = is_sharded
p._orig_size = size
return p
self.params = [
fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
]
del self.is_sharded
del self.orig_sizes
self._reset_lazy_init()
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
"""Returns an iterator over the module parameters, yielding all the parameters
part of the model.
"""
return super().parameters(recurse=recurse)
def named_parameters(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Parameter]]:
"""Returns an iterator over the module parameters, yielding both the name of the
parameter as well as the parameter.
With FSDP, the `named_parameters` function implemented in `nn.Module` will not
be able to return the original names and params when we use flattened parameters,
unless we call this function under a `summon_full_params` context.
If you want the full params to be returned, you should call this function
under a `summon_full_params` context, whether you use flattened or original params.
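Example (a minimal sketch; ``sharded_module`` as in the class-level
pseudo-code)::

    with sharded_module.summon_full_params():
        for name, param in sharded_module.named_parameters():
            print(name, param.shape)  # original (unflattened) names and shapes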
"""
named_param = super().named_parameters(*args, **kwargs)
for name, param in named_param:
if (
hasattr(self, "flatten_parameters")
and self.flatten_parameters
and hasattr(self, "training_state")
and self.training_state != TrainingState.SUMMON_FULL_PARAMS
):
yield name, param
else:
yield _clean_path(name), param
def __getitem__(self, key: int) -> Any:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.module.__getitem__(key)
@typing.overload
def state_dict(
self, destination: Mapping[str, torch.Tensor], prefix: str = ..., keep_vars: bool = ...
) -> Mapping[str, torch.Tensor]:
...
@typing.overload
def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> "OrderedDict[str, torch.Tensor]":
...
# Since we have overloads above, we can use Any here.
def state_dict(self, *args: Any, **kwargs: Any) -> Any:
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic. Returned tensors
will be full precision (e.g., FP32).
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
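Example (a minimal checkpointing sketch; ``sharded_module`` as in the
class-level pseudo-code, and "ckpt.pt" is a placeholder path)::

    full_sd = sharded_module.state_dict()  # call on every rank
    if torch.distributed.get_rank() == 0:
        torch.save(full_sd, "ckpt.pt")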
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
is_uninitialized = self._is_root is None # See comment below on why we use this.
self._lazy_init()
def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
if self.mixed_precision:
self._cast_buffers(dtype=dtype)
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
with self.summon_full_params(recurse=False, volatile=True):
maybe_cast_buffers(torch.float32)
state_dict = super().state_dict(*args, **kwargs)
else:
maybe_cast_buffers(torch.float32)
state_dict = super().state_dict(*args, **kwargs)
else:
maybe_cast_buffers(torch.float32)
state_dict = self.module.flat_state_dict(*args, **kwargs)
if self.move_params_to_cpu:
for k in state_dict.keys():
state_dict[k] = state_dict[k].cpu()
# In case we are in mixed precision, restore buffers back to buffer_dtype.
maybe_cast_buffers()
# We shouldn't change the init state in case this was an inner module and
# users simply wanted to get state_dict before training.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()