From 6a490ecd4a9b32e9b73e89e7b83a865447fb07b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 21 Jun 2024 08:13:14 +0200 Subject: [PATCH] Use file object directly in `temporary_config()` (#1598) The context manager uses `NamedTemporaryFile` to store the current configuration, to later restore them. Instead of passing the file object directly to the save function, it just passes the file name, i.e. the save (and the load function) will open the file again, which is in itself not a problem. However, on the Github Windows image this leads to a permission error (using the created file object is fine). This commit solves this by adding the `file` argument to `Config.save()` that allows to pass a file object directly to the function. The same change is applied to the load function of the config object. --- dace/config.py | 34 +++++++++++++++++++++------------- tests/config_test.py | 12 +++++++++++- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/dace/config.py b/dace/config.py index da53023171..978cf82fda 100644 --- a/dace/config.py +++ b/dace/config.py @@ -3,6 +3,7 @@ import os import platform import tempfile +import io from typing import Any, Dict import yaml import warnings @@ -39,10 +40,11 @@ def temporary_config(): Config.set("optimizer", "autooptimize", value=True) foo() """ - with tempfile.NamedTemporaryFile() as fp: - Config.save(fp.name) + with tempfile.NamedTemporaryFile(mode='w+t') as fp: + Config.save(file=fp) yield - Config.load(fp.name) + fp.seek(0) # rewind to the beginning of the file. + Config.load(file=fp) def _env2bool(envval): @@ -157,19 +159,21 @@ def initialize(): Config.save(all=False) @staticmethod - def load(filename=None): + def load(filename=None, file=None): """ Loads a configuration from an existing file. :param filename: The file to load. If unspecified, uses default configuration file. + :param file: Load the configuration from the file object. """ - if filename is None: - filename = Config._cfg_filename - # Read configuration file - with open(filename, 'r') as f: - Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader) + if file is not None: + assert filename is None + Config._config = yaml.load(file.read(), Loader=yaml.SafeLoader) + else: + with open(filename if filename else Config._cfg_filename, 'r') as f: + Config._config = yaml.load(f.read(), Loader=yaml.SafeLoader) if Config._config is None: Config._config = {} @@ -191,7 +195,7 @@ def load_schema(filename=None): Config._config_metadata = yaml.load(f.read(), Loader=yaml.SafeLoader) @staticmethod - def save(path=None, all: bool = False): + def save(path=None, all: bool = False, file=None): """ Saves the current configuration to a file. @@ -199,8 +203,9 @@ def save(path=None, all: bool = False): uses default configuration file. :param all: If False, only saves non-default configuration entries. Otherwise saves all entries. + :param file: A file object to use directly. """ - if path is None: + if path is None and file is None: path = Config._cfg_filename if path is None: # Try to create a new config file in reversed priority order, and if all else fails keep config in memory @@ -217,8 +222,11 @@ def save(path=None, all: bool = False): return # Write configuration file - with open(path, 'w') as f: - yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False) + if file is not None: + yaml.dump(Config._config if all else Config.nondefaults(), file, default_flow_style=False) + else: + with open(path, 'w') as f: + yaml.dump(Config._config if all else Config.nondefaults(), f, default_flow_style=False) @staticmethod def get_metadata(*key_hierarchy): diff --git a/tests/config_test.py b/tests/config_test.py index be765b262c..e1a7ef5cc6 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -1,5 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from dace.config import set_temporary, Config +from dace.config import Config, set_temporary, temporary_config def test_set_temporary(): @@ -10,5 +10,15 @@ def test_set_temporary(): assert Config.get(*path) == current_value +def test_temporary_config(): + path = ["compiler", "build_type"] + current_value = Config.get(*path) + with temporary_config(): + Config.set(*path, value="I'm not a build type") + assert Config.get(*path) == "I'm not a build type" + assert Config.get(*path) == current_value + + if __name__ == '__main__': test_set_temporary() + test_temporary_config()