diff --git a/allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py b/allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py index e95325555a4..e136a81a936 100644 --- a/allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py +++ b/allennlp/nn/checkpoint/fairscale_checkpoint_wrapper.py @@ -22,25 +22,10 @@ class FairScaleCheckpointWrapper(CheckpointWrapper): :class:`allennlp.nn.parallel.fairscale_fsdp_accelerator.FairScaleFsdpAccelerator`. See the [T5 implementation](/api/modules/transformer/t5/) for an example of how to use the two together. - - !!! Note - If using the `FairScaleFsdpAccelerator`, you need to set `maintain_forward_counter` to `True`. - For convenience, if `maintain_forward_counter` is not set, internally it will be - set to `True` if training in a distributed setup, or `False` otherwise. """ - def __init__( - self, offload_to_cpu: Optional[bool] = True, maintain_forward_counter: Optional[bool] = None - ) -> None: + def __init__(self, offload_to_cpu: Optional[bool] = True) -> None: self._offload_to_cpu = offload_to_cpu - if maintain_forward_counter is None: - from allennlp.common.util import is_distributed - - # Better to assume we need this in the distributed case, since we definitely - # need this when the model is wrapped with FairScale's FSDP. - self._maintain_forward_counter = is_distributed() - else: - self._maintain_forward_counter = maintain_forward_counter @overrides def wrap_module( @@ -50,6 +35,4 @@ def wrap_module( ) -> nn.Module: if "offload_to_cpu" not in kwargs and self._offload_to_cpu is not None: kwargs["offload_to_cpu"] = self._offload_to_cpu - if "maintain_forward_counter" not in kwargs and self._maintain_forward_counter is not None: - kwargs["maintain_forward_counter"] = self._maintain_forward_counter return checkpoint_wrapper(module, **kwargs) diff --git a/setup.py b/setup.py index e6a8033bd59..ab42822651e 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ "torch>=1.6.0,<1.11.0", "torchvision>=0.8.1,<0.12.0", "cached-path>=0.3.1,<0.4.0", - "fairscale==0.4.0", + "fairscale==0.4.2", "jsonnet>=0.10.0 ; sys.platform != 'win32'", "overrides==3.1.0", "nltk",