Skip to content

Commit

Permalink
Adding multiple configs (#685)
Browse files Browse the repository at this point in the history
* Adding multiple configs

* Add type hinting
  • Loading branch information
abhinavg4 authored Aug 26, 2024
1 parent c823c75 commit 7ec7bb5
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions src/levanter/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,51 @@ def _maybe_get_config_path_and_cmdline_args(args: List[str]):
If URL, we need to download it and save it to a temp file. We then want to remove --config_path
from the cmdline args so that draccus doesn't try to load it as a config path and return it separately here
along with the modified cmdline args.
We also accept ... --configs <config1> <config2> ... and concatenate them into a single config.
"""
if "--config_path" not in args and "--config" not in args:
if "--config_path" not in args and "--config" not in args and "--configs" not in args:
return None, args
else:
try:
config_path_index = args.index("--config_path")
except ValueError:
config_path_index = args.index("--config")

config_path = args[config_path_index + 1]

if urllib.parse.urlparse(config_path).scheme:
fs: AbstractFileSystem
fs, fs_path = fsspec.core.url_to_fs(config_path)
temp_file = tempfile.NamedTemporaryFile(prefix="config", suffix=".yaml", delete=False)
atexit.register(lambda: os.unlink(temp_file.name))
fs.get(fs_path, temp_file.name)
config_path = temp_file.name
config_args = ["--config_path", "--config", "--configs"]
found_indices = [args.index(arg) for arg in config_args if arg in args]
if len(found_indices) > 1:
raise ValueError(f"Multiple config args found in {args}")
config_path_index = found_indices[0] + 1

config_paths: List[str] = []
args = args.copy()
del args[config_path_index]
del args[config_path_index]
return config_path, args
del args[config_path_index - 1]
config_path_index -= 1

while config_path_index < len(args) and not args[config_path_index].startswith("-"):

config_path = args[config_path_index]

if urllib.parse.urlparse(config_path).scheme:
fs: AbstractFileSystem
fs, fs_path = fsspec.core.url_to_fs(config_path)
temp_file = tempfile.NamedTemporaryFile(prefix="config", suffix=".yaml", delete=False)
atexit.register(lambda: os.unlink(temp_file.name))
fs.get(fs_path, temp_file.name)
config_path = temp_file.name

config_paths.append(config_path)
del args[config_path_index]

merged_config_path = None

if len(config_paths) == 1:
merged_config_path = config_paths[0]
elif len(config_paths) > 1:
# merge the configs by concatenating them
temp_merged_config_path = tempfile.NamedTemporaryFile(prefix="config_merged", suffix=".yaml", delete=False)
atexit.register(lambda: os.unlink(temp_merged_config_path.name))
with open(temp_merged_config_path.name, "w") as f:
for config_path in config_paths:
with open(config_path) as config_file:
f.write(config_file.read())
merged_config_path = temp_merged_config_path.name
else:
raise ValueError("No config path found in args")

return merged_config_path, args

0 comments on commit 7ec7bb5

Please sign in to comment.