Add a utility for writing a barebones config file #371

Merged May 18, 2022 (10 commits)
Changes from 6 commits
2 changes: 2 additions & 0 deletions docs/source/internal.mdx
@@ -67,3 +67,5 @@ The main work on your PyTorch `DataLoader` is done by the following function:
[[autodoc]] utils.synchronize_rng_states

[[autodoc]] utils.wait_for_everyone

[[autodoc]] utils.write_basic_config
9 changes: 8 additions & 1 deletion src/accelerate/utils/__init__.py
@@ -72,5 +72,12 @@

from .launch import PrepareForLaunch
from .memory import find_executable_batch_size
from .other import extract_model_from_parallel, get_pretty_name, patch_environment, save, wait_for_everyone
from .other import (
    extract_model_from_parallel,
    get_pretty_name,
    patch_environment,
    save,
    wait_for_everyone,
    write_basic_config,
)
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
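
With this re-export in place, the new helper should be importable straight from `accelerate.utils` rather than from the internal `other` module. A minimal sketch of the intended import path, assuming this branch of accelerate is installed (the `fp16` argument is purely illustrative):

# Import the helper through the public utils namespace re-exported above,
# then write a barebones single-machine config to the default location.
from accelerate.utils import write_basic_config

write_basic_config(mixed_precision="fp16")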
43 changes: 43 additions & 0 deletions src/accelerate/utils/other.py
@@ -14,9 +14,12 @@

import os
from contextlib import contextmanager
from pathlib import Path

import torch

from ..commands.config.cluster import ClusterConfig
from ..commands.config.config_args import default_json_config_file
from ..state import AcceleratorState
from .dataclasses import DistributedType
from .imports import is_deepspeed_available, is_tpu_available
@@ -109,3 +112,43 @@ def get_pretty_name(obj):
    if hasattr(obj, "__name__"):
        return obj.__name__
    return str(obj)


def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file):
    """
    Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also
    set CPU if it is a CPU-only machine.

    Args:
        mixed_precision (`str`, *optional*, defaults to "no"):
            Mixed Precision to use. Should be one of "no", "fp16", or "bf16"
        save_location (`str`, *optional*, defaults to "~/.cache/huggingface/accelerate/default_config.yaml"):
Collaborator review comment:

This won't be the default location for everyone, so I'd just say it defaults to the default config path, expand a bit in the docstring that it's inside the default Hugging Face cache, and maybe add which env variable controls it.

            Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`.
    """
    path = Path(save_location)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        print(
            f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`."
        )
        return
    mixed_precision = mixed_precision.lower()
    if mixed_precision not in ["no", "fp16", "bf16"]:
        raise ValueError(f"`mixed_precision` should be one of 'no', 'fp16', or 'bf16'. Received {mixed_precision}")
    config = {"compute_environment": "LOCAL_MACHINE", "mixed_precision": mixed_precision}
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        config["num_processes"] = num_gpus
        config["use_cpu"] = False
        if num_gpus > 1:
            config["distributed_type"] = "MULTI_GPU"
        else:
            config["distributed_type"] = "NO"
    else:
        num_gpus = 0
        config["use_cpu"] = True
        config["num_processes"] = 1
        config["distributed_type"] = "NO"
    if not path.exists():
        config = ClusterConfig(**config)
        config.to_json_file(path)
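
Taken together with the docstring, the intended workflow appears to be: call the utility once to materialize a barebones config, then point `accelerate launch` at it via `--config_file`. A hedged end-to-end sketch (the custom path and the `train.py` script name are illustrative, not part of the PR):

from accelerate.utils import write_basic_config

# Write a single-machine config next to the project; if the file already
# exists the function prints a notice and returns without overwriting it.
write_basic_config(mixed_precision="bf16", save_location="./accelerate_config.json")

# Then launch with the generated file, as the docstring suggests:
#   accelerate launch --config_file ./accelerate_config.json train.py

On the reviewer's point about the default `save_location`: it sits inside the Hugging Face cache directory, whose root is typically governed by environment variables such as `HF_HOME` (or `XDG_CACHE_HOME`), so the literal `~/.cache/huggingface/...` path in the docstring won't hold for every setup.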