Add a utility for writing a barebones config file #371

Merged
merged 10 commits on May 18, 2022
Changes from 4 commits
2 changes: 2 additions & 0 deletions docs/source/internal.mdx
@@ -67,3 +67,5 @@ The main work on your PyTorch `DataLoader` is done by the following function:
[[autodoc]] utils.synchronize_rng_states

[[autodoc]] utils.wait_for_everyone

[[autodoc]] utils.write_basic_config
9 changes: 8 additions & 1 deletion src/accelerate/utils/__init__.py
@@ -72,5 +72,12 @@

from .launch import PrepareForLaunch
from .memory import find_executable_batch_size
from .other import extract_model_from_parallel, get_pretty_name, patch_environment, save, wait_for_everyone
from .other import (
    extract_model_from_parallel,
    get_pretty_name,
    patch_environment,
    save,
    wait_for_everyone,
    write_basic_config,
)
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
41 changes: 41 additions & 0 deletions src/accelerate/utils/other.py
@@ -14,9 +14,12 @@

import os
from contextlib import contextmanager
from pathlib import Path

import torch

from ..commands.config.cluster import ClusterConfig
from ..commands.config.config_args import default_json_config_file
from ..state import AcceleratorState
from .dataclasses import DistributedType
from .imports import is_deepspeed_available, is_tpu_available
@@ -109,3 +112,41 @@ def get_pretty_name(obj):
    if hasattr(obj, "__name__"):
        return obj.__name__
    return str(obj)


def write_basic_config(mixed_precision="no"):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    configure CPU-only execution if no GPUs are available.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed precision to use. Should be one of "no", "fp16", or "bf16".
    """
    path = Path(default_json_config_file)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            "User configuration already set up, will not override existing configuration. Run `accelerate config` manually."
        )
        return
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16"]:
        raise ValueError(f"`mixed_precision` should be one of 'no', 'fp16', or 'bf16'. Received {mixed_precision}")
    config = {"compute_environment": "LOCAL_MACHINE", "mixed_precision": mixed_precision}
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        config["num_processes"] = num_gpus
        config["use_cpu"] = False
        if num_gpus > 1:
            config["distributed_type"] = "MULTI_GPU"
        else:
            config["distributed_type"] = "NO"
    else:
        num_gpus = 0
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    if not path.exists():
        config = ClusterConfig(**config)
        config.to_json_file(path)
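
For reference, a minimal usage sketch of the new utility (not part of this diff), assuming the package is installed with the changes above; it relies only on names introduced or imported in this PR (`write_basic_config` re-exported from `accelerate.utils` and `default_json_config_file` from the config args module):

from accelerate.utils import write_basic_config
from accelerate.commands.config.config_args import default_json_config_file

# Write a barebones local-machine config, requesting fp16 mixed precision.
# If a default config file already exists, the call prints a notice and returns
# without overwriting it.
write_basic_config(mixed_precision="fp16")

# Inspect the JSON config that now lives at the default location.
with open(default_json_config_file) as f:
    print(f.read())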