Skip to content

Commit

Permalink
Restore --flagfile fix compatibility
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695220238
  • Loading branch information
Conchylicultor authored and The ml_collections Authors committed Nov 11, 2024
1 parent 06499ac commit b4368a2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ The ML Collections documentation can be found here: https://ml-collections.readt
# How to build the docs
1. Install the requirements in ml_collections/docs/requirements.txt.
2. Ensure `pandoc` is installed.
3. Run `make html` to locally generate documentation.
3. Run `make html` to locally generate documentation.
12 changes: 8 additions & 4 deletions ml_collections/config_flags/config_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,12 @@ def __init__(
self._sys_argv = sys_argv
super(_ConfigFlag, self).__init__(**kwargs)

def _GetArgv(self):
"""Lazily fetches sys.argv and expands any potential --flagfile=..."""
argv = sys.argv if self._sys_argv is None else self._sys_argv
argv = flags.FLAGS.read_flags_from_files(argv, force_gnu=False)
return argv

def _GetOverrides(self, argv):
"""Parses the command line arguments for the overrides."""
# We use a dict to keep the order of the overrides.
Expand Down Expand Up @@ -756,8 +762,7 @@ def _IsConfigSpecified(self, argv):
return self._FindConfigSpecified(argv) >= 0

def _set_default(self, default):
if self._IsConfigSpecified(
sys.argv if self._sys_argv is None else self._sys_argv):
if self._IsConfigSpecified(self._GetArgv()):
self.default = default
else:
super(_ConfigFlag, self)._set_default(default) # pytype: disable=attribute-error
Expand Down Expand Up @@ -787,8 +792,7 @@ def _parse(self, argument):
config = super(_ConfigFlag, self)._parse(argument)

# Get list or overrides
overrides = self._GetOverrides(
sys.argv if self._sys_argv is None else self._sys_argv)
overrides = self._GetOverrides(self._GetArgv())
# Iterate over overridden fields and create valid parsers
self._override_values = {}
self._initialize_missing_parent_fields(config, overrides)
Expand Down
14 changes: 14 additions & 0 deletions ml_collections/config_flags/tests/config_overriding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import shlex
import sys
import tempfile

from absl import flags
from absl.testing import absltest
Expand Down Expand Up @@ -832,6 +833,19 @@ def testOverridesSerialize(self):
serialize_parse('test_config.type_tuple',
values.test_config.type_tuple))

def testFlagfile(self):
config = config_dict.ConfigDict()
config.foo = 3
config.bar = 4

with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
f.write('--test_config.foo=7\n')
f.flush()
f.close()
values = _parse_flags(f'./program --flagfile={f.name}', config=config)
self.assertEqual(values.test_config.foo, 7)
self.assertEqual(values.test_config.bar, 4)


def main():
absltest.main()
Expand Down

0 comments on commit b4368a2

Please sign in to comment.