Add support for downcasting bf16 on TPUs #523

Merged · 7 commits · Jul 20, 2022
7 changes: 7 additions & 0 deletions src/accelerate/accelerator.py
@@ -235,6 +235,13 @@ def __init__(
                **kwargs,
            )

        # `downcast_bf16` is only meaningful with bf16 mixed precision on a TPU
        if (
            getattr(self.state, "downcast_bfloat", False)
            and (mixed_precision != "bf16" or self.state.distributed_type != DistributedType.TPU)
        ):
            raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")

        if gradient_accumulation_steps > 1:
            if self.state.distributed_type == DistributedType.TPU:
                raise NotImplementedError(
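To make the intent of this guard concrete, here is a standalone sketch of the rule it enforces (the helper name is illustrative, not accelerate API):

```python
# Illustrative truth table for the guard above; not part of accelerate.
def downcast_request_is_valid(mixed_precision: str, on_tpu: bool, downcast_bf16: bool) -> bool:
    """`downcast_bf16` is only meaningful with bf16 mixed precision on a TPU."""
    if not downcast_bf16:
        return True
    return mixed_precision == "bf16" and on_tpu

assert downcast_request_is_valid("bf16", on_tpu=True, downcast_bf16=True)
assert not downcast_request_is_valid("fp16", on_tpu=True, downcast_bf16=True)
assert not downcast_request_is_valid("bf16", on_tpu=False, downcast_bf16=True)
```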
7 changes: 7 additions & 0 deletions src/accelerate/commands/config/cluster.py
@@ -260,11 +260,18 @@ def get_cluster_input():
    else:
        mixed_precision = "no"

downcast_bf16 = "no"
if distributed_type == DistributedType.TPU and mixed_precision == "bf16":
downcast_bf16 = _ask_field(
"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
)

    return ClusterConfig(
        compute_environment=ComputeEnvironment.LOCAL_MACHINE,
        distributed_type=distributed_type,
        num_processes=num_processes,
        mixed_precision=mixed_precision,
        downcast_bf16=downcast_bf16,
        machine_rank=machine_rank,
        num_machines=num_machines,
        main_process_ip=main_process_ip,
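The prompt relies on accelerate's existing yes/no converter (`_convert_yes_no_to_bool` from the config utilities); a rough, illustrative re-implementation of its behavior:

```python
# Rough sketch of the yes/no conversion used by _ask_field prompts;
# illustrative only, the real helper lives in accelerate's config utilities.
def convert_yes_no_to_bool(value: str) -> bool:
    return {"yes": True, "no": False}[value.lower()]

assert convert_yes_no_to_bool("no") is False
assert convert_yes_no_to_bool("YES") is True
```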
2 changes: 2 additions & 0 deletions src/accelerate/commands/config/config_args.py
@@ -143,6 +143,8 @@ class ClusterConfig(BaseConfig):
    deepspeed_config: dict = None
    # args for fsdp
    fsdp_config: dict = None
    # args for TPU
    downcast_bf16: bool = False

    def __post_init__(self):
        if self.deepspeed_config is None:
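Once written by `accelerate config`, the new field round-trips through the saved config file like any other cluster argument. A sketch (`from_yaml_file` is accelerate's existing config loader; the file path is illustrative):

```python
# Hedged sketch: load a saved cluster config and inspect the new field.
from accelerate.commands.config.config_args import ClusterConfig

config = ClusterConfig.from_yaml_file("default_config.yaml")  # path is illustrative
print(config.downcast_bf16)  # False unless enabled during `accelerate config`
```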
19 changes: 18 additions & 1 deletion src/accelerate/commands/launch.py
@@ -34,6 +34,7 @@
    get_launch_prefix,
    is_deepspeed_available,
    is_sagemaker_available,
    patch_environment,
)
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from accelerate.utils.dataclasses import SageMakerDistributedType
@@ -198,6 +199,11 @@ def launch_command_parser(subparsers=None):
        default=None,
        help="The name of the main function to be executed in your script (only for TPU training).",
    )
    parser.add_argument(
        "--downcast_bf16",
        action="store_true",
        help="When using bf16 precision on TPUs, whether to cast `torch.float` to `bfloat16` but keep `torch.double` as `float32`, instead of casting both to `bfloat16`.",
    )
    parser.add_argument(
        "-m",
        "--module",
@@ -425,9 +431,19 @@ def deepspeed_launcher(args):
def tpu_launcher(args):
    import torch_xla.distributed.xla_multiprocessing as xmp

    current_env = {}

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    if args.mixed_precision == "bf16":
        # XLA_USE_BF16=1 casts both float and double tensors to bfloat16;
        # XLA_DOWNCAST_BF16=1 casts float to bfloat16 but keeps double as float32.
        if args.downcast_bf16:
            current_env["XLA_USE_BF16"] = "0"
            current_env["XLA_DOWNCAST_BF16"] = "1"
        else:
            current_env["XLA_USE_BF16"] = "1"
            current_env["XLA_DOWNCAST_BF16"] = "0"

    if args.module:
        mod_name = args.training_script
    else:
@@ -447,7 +463,8 @@ def tpu_launcher(args):
    sys.argv = [mod.__file__] + args.training_script_args

    main_function = getattr(mod, args.main_training_function)
    with patch_environment(**current_env):
        xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)


def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]:
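The spawn call is wrapped in `patch_environment` so the XLA flags only apply for the duration of the launch. A minimal sketch of the behavior assumed from that helper (accelerate ships its own implementation; this one is illustrative):

```python
# Sketch of the assumed context-manager behavior: set env vars for the
# duration of a block, then restore the previous state afterwards.
import os
from contextlib import contextmanager

@contextmanager
def patch_environment_sketch(**kwargs):
    previous = {key: os.environ.get(key) for key in kwargs}
    os.environ.update({key: str(value) for key, value in kwargs.items()})
    try:
        yield
    finally:
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

# Usage mirrors tpu_launcher above:
with patch_environment_sketch(XLA_USE_BF16="0", XLA_DOWNCAST_BF16="1"):
    assert os.environ["XLA_DOWNCAST_BF16"] == "1"
```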
9 changes: 8 additions & 1 deletion src/accelerate/state.py
@@ -115,7 +115,14 @@ def __init__(
            self.local_process_index = xm.get_local_ordinal()
            self.device = xm.xla_device()
            if mixed_precision == "bf16":
                # The launcher exports XLA_DOWNCAST_BF16 ("1" or "0"); honor it here
                if os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
                    os.environ["XLA_USE_BF16"] = str(0)
                    os.environ["XLA_DOWNCAST_BF16"] = str(1)
                    self.downcast_bfloat = True
                else:
                    os.environ["XLA_USE_BF16"] = str(1)
                    os.environ["XLA_DOWNCAST_BF16"] = str(0)
                    self.downcast_bfloat = False
            self.mixed_precision = mixed_precision
        elif os.environ.get("USE_DEEPSPEED", "false") == "true" and not cpu:
            assert (
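For reference, the effect of the downcast flags inside a launched script matches the config prompt wording: `torch.float` runs as `bfloat16` on the XLA device while `torch.double` stays `float32`. A sketch, assuming a TPU runtime with `torch_xla` installed and that the flags are set before `torch_xla` initializes:

```python
# Illustration of the downcast semantics; requires a TPU runtime with
# torch_xla. Flags are set before importing torch_xla so XLA picks them up.
import os
os.environ["XLA_USE_BF16"] = "0"
os.environ["XLA_DOWNCAST_BF16"] = "1"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.rand(2, 2).to(device)           # float32 -> stored as bfloat16
y = torch.rand(2, 2).double().to(device)  # float64 -> stored as float32
```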