From 7ec7bb54085fd2d0d48f184546dd482fcf15729a Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Sun, 25 Aug 2024 20:39:54 -0700 Subject: [PATCH] Adding multiple configs (#685) * Adding multiple configs * Add type hinting --- src/levanter/config.py | 61 +++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/src/levanter/config.py b/src/levanter/config.py index fefe1b2d3..8caafc256 100644 --- a/src/levanter/config.py +++ b/src/levanter/config.py @@ -93,26 +93,51 @@ def _maybe_get_config_path_and_cmdline_args(args: List[str]): If URL, we need to download it and save it to a temp file. We then want to remove --config_path from the cmdline args so that draccus doesn't try to load it as a config path and return it separately here along with the modified cmdline args. + We also accept ... --configs ... and concatenate them into a single config. """ - if "--config_path" not in args and "--config" not in args: + if "--config_path" not in args and "--config" not in args and "--configs" not in args: return None, args else: - try: - config_path_index = args.index("--config_path") - except ValueError: - config_path_index = args.index("--config") - - config_path = args[config_path_index + 1] - - if urllib.parse.urlparse(config_path).scheme: - fs: AbstractFileSystem - fs, fs_path = fsspec.core.url_to_fs(config_path) - temp_file = tempfile.NamedTemporaryFile(prefix="config", suffix=".yaml", delete=False) - atexit.register(lambda: os.unlink(temp_file.name)) - fs.get(fs_path, temp_file.name) - config_path = temp_file.name + config_args = ["--config_path", "--config", "--configs"] + found_indices = [args.index(arg) for arg in config_args if arg in args] + if len(found_indices) > 1: + raise ValueError(f"Multiple config args found in {args}") + config_path_index = found_indices[0] + 1 + config_paths: List[str] = [] args = args.copy() - del args[config_path_index] - del args[config_path_index] - return config_path, args + del args[config_path_index - 1] + config_path_index -= 1 + + while config_path_index < len(args) and not args[config_path_index].startswith("-"): + + config_path = args[config_path_index] + + if urllib.parse.urlparse(config_path).scheme: + fs: AbstractFileSystem + fs, fs_path = fsspec.core.url_to_fs(config_path) + temp_file = tempfile.NamedTemporaryFile(prefix="config", suffix=".yaml", delete=False) + atexit.register(lambda: os.unlink(temp_file.name)) + fs.get(fs_path, temp_file.name) + config_path = temp_file.name + + config_paths.append(config_path) + del args[config_path_index] + + merged_config_path = None + + if len(config_paths) == 1: + merged_config_path = config_paths[0] + elif len(config_paths) > 1: + # merge the configs by concatenating them + temp_merged_config_path = tempfile.NamedTemporaryFile(prefix="config_merged", suffix=".yaml", delete=False) + atexit.register(lambda: os.unlink(temp_merged_config_path.name)) + with open(temp_merged_config_path.name, "w") as f: + for config_path in config_paths: + with open(config_path) as config_file: + f.write(config_file.read()) + merged_config_path = temp_merged_config_path.name + else: + raise ValueError("No config path found in args") + + return merged_config_path, args