Skip to content

Commit

Permalink
Merge pull request #29 from froody/flagfile
Browse files Browse the repository at this point in the history
Add support for --flagfile to DEFINE_config_*
marksandler2 authored Jul 29, 2024
2 parents 295c938 + e68b688 commit 2699256
Showing 2 changed files with 21 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ml_collections/config_flags/config_flags.py
Original file line number Diff line number Diff line change
@@ -719,6 +719,12 @@ def __init__(
self._sys_argv = sys_argv
super(_ConfigFlag, self).__init__(**kwargs)

def _GetArgv(self):
"""Lazily fetches sys.argv and expands any potential --flagfile=... arguments."""
argv = sys.argv if self._sys_argv is None else self._sys_argv
argv = flags.FLAGS.read_flags_from_files(argv, force_gnu=False)
return argv

def _GetOverrides(self, argv):
"""Parses the command line arguments for the overrides."""
# We use a dict to keep the order of the overrides.
@@ -755,8 +761,7 @@ def _IsConfigSpecified(self, argv):
return self._FindConfigSpecified(argv) >= 0

def _set_default(self, default):
if self._IsConfigSpecified(
sys.argv if self._sys_argv is None else self._sys_argv):
if self._IsConfigSpecified(self._GetArgv()):
self.default = default
else:
super(_ConfigFlag, self)._set_default(default) # pytype: disable=attribute-error
@@ -786,8 +791,7 @@ def _parse(self, argument):
config = super(_ConfigFlag, self)._parse(argument)

# Get list or overrides
overrides = self._GetOverrides(
sys.argv if self._sys_argv is None else self._sys_argv)
overrides = self._GetOverrides(self._GetArgv())
# Iterate over overridden fields and create valid parsers
self._override_values = {}
self._initialize_missing_parent_fields(config, overrides)
13 changes: 13 additions & 0 deletions ml_collections/config_flags/tests/config_overriding_test.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
import copy
import shlex
import sys
import tempfile

from absl import flags
from absl.testing import absltest
@@ -832,6 +833,18 @@ def testOverridesSerialize(self):
serialize_parse('test_config.type_tuple',
values.test_config.type_tuple))

def testFlagfile(self):
config = config_dict.ConfigDict()
config.foo = 3
config.bar = 4

with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
f.write('--test_config.foo=7\n')
f.flush()
f.close()
values = _parse_flags(f'./program --flagfile={f.name}', config=config)
self.assertEqual(values.test_config.foo, 7)
self.assertEqual(values.test_config.bar, 4)

def main():
absltest.main()

0 comments on commit 2699256

Please sign in to comment.