Bump fairscale from 0.4.0 to 0.4.2 (#5461)
* Bump fairscale from 0.4.0 to 0.4.2

Bumps fairscale from 0.4.0 to 0.4.2.

---
updated-dependencies:
- dependency-name: fairscale
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <[email protected]>

* fix FairScale checkpoint wrapper wrapper

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Dirk Groeneveld <[email protected]>
Co-authored-by: epwalsh <[email protected]>
3 people authored Nov 15, 2021
1 parent 923dbde commit 05fc7f6
Showing 2 changed files with 2 additions and 19 deletions.
19 changes: 1 addition & 18 deletions allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py
@@ -22,25 +22,10 @@ class FairScaleCheckpointWrapper(CheckpointWrapper):
     :class:`allennlp.nn.parallel.fairscale_fsdp_accelerator.FairScaleFsdpAccelerator`.
     See the [T5 implementation](/api/modules/transformer/t5/) for an example
     of how to use the two together.
-
-    !!! Note
-        If using the `FairScaleFsdpAccelerator`, you need to set `maintain_forward_counter` to `True`.
-        For convenience, if `maintain_forward_counter` is not set, internally it will be
-        set to `True` if training in a distributed setup, or `False` otherwise.
     """
 
-    def __init__(
-        self, offload_to_cpu: Optional[bool] = True, maintain_forward_counter: Optional[bool] = None
-    ) -> None:
+    def __init__(self, offload_to_cpu: Optional[bool] = True) -> None:
         self._offload_to_cpu = offload_to_cpu
-        if maintain_forward_counter is None:
-            from allennlp.common.util import is_distributed
-
-            # Better to assume we need this in the distributed case, since we definitely
-            # need this when the model is wrapped with FairScale's FSDP.
-            self._maintain_forward_counter = is_distributed()
-        else:
-            self._maintain_forward_counter = maintain_forward_counter
 
     @overrides
     def wrap_module(
@@ -50,6 +35,4 @@ def wrap_module(
     ) -> nn.Module:
         if "offload_to_cpu" not in kwargs and self._offload_to_cpu is not None:
             kwargs["offload_to_cpu"] = self._offload_to_cpu
-        if "maintain_forward_counter" not in kwargs and self._maintain_forward_counter is not None:
-            kwargs["maintain_forward_counter"] = self._maintain_forward_counter
         return checkpoint_wrapper(module, **kwargs)
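
A minimal usage sketch of the wrapper after this change, assuming the import path shown in the diff above; the feedforward module below is made up for illustration, and only `offload_to_cpu` remains configurable:

# Hypothetical example (not part of this commit): wrap a block with the
# simplified FairScaleCheckpointWrapper so its activations are recomputed
# in the backward pass (and optionally offloaded to CPU) rather than stored.
import torch.nn as nn

from allennlp.nn.checkpoint.fairscale_checkpoint_wrapper import FairScaleCheckpointWrapper

wrapper = FairScaleCheckpointWrapper(offload_to_cpu=True)

# Any nn.Module can be wrapped; wrap_module returns a drop-in replacement.
feedforward = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
feedforward = wrapper.wrap_module(feedforward)
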
2 changes: 1 addition & 1 deletion setup.py
@@ -53,7 +53,7 @@
         "torch>=1.6.0,<1.11.0",
         "torchvision>=0.8.1,<0.12.0",
         "cached-path>=0.3.1,<0.4.0",
-        "fairscale==0.4.0",
+        "fairscale==0.4.2",
         "jsonnet>=0.10.0 ; sys.platform != 'win32'",
         "overrides==3.1.0",
         "nltk",
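
Because setup.py pins an exact version, existing environments need the matching upgrade. A quick sanity check, as a sketch, assuming fairscale is installed and exposes `__version__`:

# Confirm the installed fairscale matches the new pin in setup.py.
import fairscale

assert fairscale.__version__ == "0.4.2", f"expected fairscale 0.4.2, found {fairscale.__version__}"
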
