Skip to content

Commit

Permalink
enhance fsdp for 3rd devices
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Oct 19, 2024
1 parent 31e3812 commit f863645
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,16 @@ def _check_strategy_and_fallback(self) -> None:

if (
strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy
) and self._accelerator_flag not in ("cuda", "gpu"):
) and self._accelerator_flag not in ("cuda", "gpu") and isinstance(self._accelerator_flag, str):
raise ValueError(
f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:"
f" {self._accelerator_flag}"
)
elif isinstance(self._accelerator_flag, Accelerator):
Warning(
f"Using a custom accelerator `{self._accelerator_flag.__class__.__name__}`."
f" Please ensure it is compatible with the selected strategy `{strategy_flag}`."
)
if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
raise ValueError(
f"You selected `Trainer(strategy='{strategy_flag}')` but process forking is not supported on this"
Expand Down Expand Up @@ -501,7 +506,7 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_flag) # type: ignore[arg-type]
if isinstance(self.strategy, FSDPStrategy):
return FSDPPrecision(precision=self._precision_input, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
return FSDPPrecision(precision=self._precision_flag, device=self._accelerator_flag.get_device() if isinstance(self._accelerator_flag, Accelerator) else None)
if self._precision_flag in ("16-true", "bf16-true"):
return HalfPrecision(self._precision_flag) # type: ignore
if self._precision_flag == "32-true":
Expand Down

0 comments on commit f863645

Please sign in to comment.