diff --git a/dace/config.py b/dace/config.py index da53023171..978cf82fda 100644 --- a/dace/config.py +++ b/dace/config.py @@ -3,6 +3,7 @@ import os import platform import tempfile +import io from typing import Any, Dict import yaml import warnings @@ -39,10 +40,11 @@ def temporary_config(): Config.set("optimizer", "autooptimize", value=True) foo() """ - with tempfile.NamedTemporaryFile() as fp: - Config.save(fp.name) + with tempfile.NamedTemporaryFile(mode='w+t') as fp: + Config.save(file=fp) yield - Config.load(fp.name) + fp.seek(0) # rewind to the beginning of the file. + Config.load(file=fp) def _env2bool(envval): @@ -157,19 +159,21 @@ def initialize(): Config.save(all=False) @staticmethod - def load(filename=None): + def load(filename=None, file=None): """ Loads a configuration from an existing file. :param filename: The file to load. If unspecified, uses default configuration file. + :param file: Load the configuration from the file object. """ - if filename is None: - filename = Config._cfg_filename - # Read configuration file - with open(filename, 'r') as f: - Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader) + if file is not None: + assert filename is None + Config._config = yaml.load(file.read(), Loader=yaml.SafeLoader) + else: + with open(filename if filename else Config._cfg_filename, 'r') as f: + Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader) if Config._config is None: Config._config = {} @@ -191,7 +195,7 @@ def load_schema(filename=None): Config._config_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader) @staticmethod - def save(path=None, all: bool = False): + def save(path=None, all: bool = False, file=None): """ Saves the current configuration to a file. @@ -199,8 +203,9 @@ def save(path=None, all: bool = False): uses default configuration file. :param all: If False, only saves non-default configuration entries. Otherwise saves all entries. + :param file: A file object to use directly. """ - if path is None: + if path is None and file is None: path = Config._cfg_filename if path is None: # Try to create a new config file in reversed priority order, and if all else fails keep config in memory @@ -217,8 +222,11 @@ def save(path=None, all: bool = False): return # Write configuration file - with open(path, 'w') as f: - yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False) + if file is not None: + yaml.dump(Config._config if all else Config.nondefaults(), file, default_flow_style=False) + else: + with open(path, 'w') as f: + yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False) @staticmethod def get_metadata(*key_hierarchy): diff --git a/tests/config_test.py b/tests/config_test.py index be765b262c..e1a7ef5cc6 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -1,5 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.config import set_temporary, Config +from dace.config import Config, set_temporary, temporary_config def test_set_temporary(): @@ -10,5 +10,15 @@ def test_set_temporary(): assert Config.get(*path) == current_value +def test_temporary_config(): + path = ["compiler", "build_type"] + current_value = Config.get(*path) + with temporary_config(): + Config.set(*path, value="I'm not a build type") + assert Config.get(*path) == "I'm not a build type" + assert Config.get(*path) == current_value + + if __name__ == '__main__': test_set_temporary() + test_temporary_config()