Use override=True for config1 in adjust_config_by_multi_process_divid… #1714

Merged · 2 commits · Dec 11, 2024
alf/config_helpers.py: 5 additions & 5 deletions
@@ -154,7 +154,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
             tag,
             math.ceil(num_parallel_environments / multi_process_divider),
             raise_if_used=False,
-            override_sole_init=True)
+            override_all=True)
 
     # Adjust the mini_batch_size. If the original configured value is 64 and
     # there are 4 processes, it should mean that "jointly the 4 processes have
@@ -167,7 +167,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
             tag,
             math.ceil(mini_batch_size / multi_process_divider),
             raise_if_used=False,
-            override_sole_init=True)
+            override_all=True)
 
     # If the termination condition is num_env_steps instead of num_iterations,
     # we need to adjust it as well since each process only sees env steps taking
@@ -179,15 +179,15 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
             tag,
             math.ceil(num_env_steps / multi_process_divider),
             raise_if_used=False,
-            override_sole_init=True)
+            override_all=True)
 
     tag = 'TrainerConfig.initial_collect_steps'
     init_collect_steps = get_config_value(tag)
     config1(
         tag,
         math.ceil(init_collect_steps / multi_process_divider),
         raise_if_used=False,
-        override_sole_init=True)
+        override_all=True)
 
     # Only allow process with rank 0 to have evaluate. Enabling evaluation for
     # other parallel processes is a waste as such evaluation does not offer more
@@ -197,7 +197,7 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
             'TrainerConfig.evaluate',
             False,
             raise_if_used=False,
-            override_sole_init=True)
+            override_all=True)
 
 
 def parse_config(conf_file, conf_params, create_env=True):
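
The comment blocks in this diff all apply the same scaling rule: a value configured for the whole training job is split evenly across the DDP processes, rounding up. A minimal sketch of that arithmetic (per_process_value is a hypothetical name, not part of this PR):

import math

# Hypothetical helper illustrating the scaling rule above: a jointly
# configured value is divided among multi_process_divider processes,
# rounded up with math.ceil so the joint total never falls short.
def per_process_value(joint_value: int, multi_process_divider: int) -> int:
    return math.ceil(joint_value / multi_process_divider)

print(per_process_value(64, 4))   # mini_batch_size: 64 jointly -> 16 per process
print(per_process_value(100, 8))  # 13 per process (8 * 13 = 104 >= 100)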
alf/config_util.py: 22 additions & 2 deletions
@@ -366,7 +366,8 @@ def config1(config_name,
             mutable=True,
             raise_if_used=True,
             sole_init=False,
-            override_sole_init=False):
+            override_sole_init=False,
+            override_all=False):
     """Set one configurable value.
 
     Args:
@@ -391,6 +392,10 @@ def config1(config_name,
             where the student must override certain configs inherited from the
             teacher). If the config is immutable, a warning will be declared with
             no changes made.
+        override_all (bool): If True, the value of the config will be set
+            regardless of any pre-existing ``mutable`` or ``sole_init`` settings.
+            This should be used only when absolutely necessary (e.g., adjusting
+            certain configs such as mini_batch_size for DDP workers).
     """
     config_node = _get_config_node(config_name)
 
@@ -399,7 +404,22 @@ def config1(config_name,
             "Config '%s' has already been used. You should config "
             "its value before using it." % config_name)
 
-    if override_sole_init:
+    if override_all:
+        if config_node.get_sole_init():
+            logging.warning(
+                "The value of config '%s' (%s) is protected by sole_init. "
+                "It is now being overridden by the override_all flag to a new "
+                "value %s. Use at your own risk." %
+                (config_name, config_node.get_value(), value))
+        if not config_node.is_mutable():
+            logging.warning(
+                "The value of config '%s' (%s) is immutable. "
+                "It is now being overridden by the override_all flag to a new "
+                "value %s. Use at your own risk." %
+                (config_name, config_node.get_value(), value))
+        config_node.set_value(value)
+        return
+    elif override_sole_init:
         if config_node.is_configured():
             if not config_node.is_mutable():
                 logging.warning(
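
Taken together, the two warnings above mean override_all bypasses both sole_init protection and immutability, logging rather than raising. A minimal usage sketch, under the assumption that TrainerConfig.mini_batch_size is a registered config and that config1 and get_config_value are imported from alf.config_util as in this diff:

from alf.config_util import config1, get_config_value

# Assumes TrainerConfig has already been registered so the key exists.
# A value pinned with sole_init=True would normally reject later changes.
config1('TrainerConfig.mini_batch_size', 64, sole_init=True)

# With override_all=True the new value is applied anyway; per this PR a
# warning is logged instead of an exception being raised.
config1(
    'TrainerConfig.mini_batch_size',
    16,  # e.g., 64 split across 4 DDP processes
    raise_if_used=False,
    override_all=True)

assert get_config_value('TrainerConfig.mini_batch_size') == 16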