diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bc24479..d2ac52ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,16 @@ Changes are grouped as follows - `Fixed` for any bug fixes. - `Security` in case of vulnerabilities. +## [5.5.0] + +### Added + + * Added `load_yaml_dict` to `configtools.loaders`. + +### Fixed + + * Fixed getting the config `type` when `!env` was used in the config file. + ## [5.4.3] ### Added diff --git a/cognite/extractorutils/__init__.py b/cognite/extractorutils/__init__.py index d1a528e7..b5cecba9 100644 --- a/cognite/extractorutils/__init__.py +++ b/cognite/extractorutils/__init__.py @@ -16,5 +16,5 @@ Cognite extractor utils is a Python package that simplifies the development of new extractors. """ -__version__ = "5.4.3" +__version__ = "5.5.0" from .base import Extractor diff --git a/cognite/extractorutils/configtools/loaders.py b/cognite/extractorutils/configtools/loaders.py index cfdd74b7..2b031a4f 100644 --- a/cognite/extractorutils/configtools/loaders.py +++ b/cognite/extractorutils/configtools/loaders.py @@ -36,30 +36,31 @@ CustomConfigClass = TypeVar("CustomConfigClass", bound=BaseConfig) -def _load_yaml( +class _EnvLoader(yaml.SafeLoader): + pass + + +def _env_constructor(_: yaml.SafeLoader, node: yaml.Node) -> bool: + bool_values = { + "true": True, + "false": False, + } + expanded_value = os.path.expandvars(node.value) + return bool_values.get(expanded_value.lower(), expanded_value) + + +_EnvLoader.add_implicit_resolver("!env", re.compile(r"\$\{([^}^{]+)\}"), None) +_EnvLoader.add_constructor("!env", _env_constructor) + + +def _load_yaml_dict( source: Union[TextIO, str], - config_type: Type[CustomConfigClass], case_style: str = "hyphen", expand_envvars: bool = True, dict_manipulator: Callable[[Dict[str, Any]], Dict[str, Any]] = lambda x: x, -) -> CustomConfigClass: - def env_constructor(_: yaml.SafeLoader, node: yaml.Node) -> bool: - bool_values = { - "true": True, - "false": False, - } - expanded_value = os.path.expandvars(node.value) - return bool_values.get(expanded_value.lower(), expanded_value) - - class EnvLoader(yaml.SafeLoader): - pass +) -> Dict[str, Any]: + loader = _EnvLoader if expand_envvars else yaml.SafeLoader - EnvLoader.add_implicit_resolver("!env", re.compile(r"\$\{([^}^{]+)\}"), None) - EnvLoader.add_constructor("!env", env_constructor) - - loader = EnvLoader if expand_envvars else yaml.SafeLoader - - # Safe to use load instead of safe_load since both loader classes are based on SafeLoader try: config_dict = yaml.load(source, Loader=loader) # noqa: S506 except ScannerError as e: @@ -71,6 +72,20 @@ class EnvLoader(yaml.SafeLoader): config_dict = dict_manipulator(config_dict) config_dict = _to_snake_case(config_dict, case_style) + return config_dict + + +def _load_yaml( + source: Union[TextIO, str], + config_type: Type[CustomConfigClass], + case_style: str = "hyphen", + expand_envvars: bool = True, + dict_manipulator: Callable[[Dict[str, Any]], Dict[str, Any]] = lambda x: x, +) -> CustomConfigClass: + config_dict = _load_yaml_dict( + source, case_style=case_style, expand_envvars=expand_envvars, dict_manipulator=dict_manipulator + ) + try: config = dacite.from_dict( data=config_dict, data_class=config_type, config=dacite.Config(strict=True, cast=[Enum, TimeIntervalConfig]) @@ -133,6 +148,29 @@ def load_yaml( return _load_yaml(source=source, config_type=config_type, case_style=case_style, expand_envvars=expand_envvars) +def load_yaml_dict( + source: Union[TextIO, str], + case_style: str = "hyphen", + expand_envvars: bool = True, +) -> Dict[str, Any]: + """ + Read a YAML file and return a dictionary from its contents + + Args: + source: Input stream (as returned by open(...)) or string containing YAML. + case_style: Casing convention of config file. Valid options are 'snake', 'hyphen' or 'camel'. Should be + 'hyphen'. + expand_envvars: Substitute values with the pattern ${VAR} with the content of the environment variable VAR + + Returns: + A raw dict with the contents of the config file. + + Raises: + InvalidConfigError: If any config field is given as an invalid type, is missing or is unknown + """ + return _load_yaml_dict(source=source, case_style=case_style, expand_envvars=expand_envvars) + + class ConfigResolver(Generic[CustomConfigClass]): def __init__(self, config_path: str, config_type: Type[CustomConfigClass]): self.config_path = config_path @@ -147,7 +185,7 @@ def _reload_file(self) -> None: @property def is_remote(self) -> bool: - raw_config_type = yaml.safe_load(self._config_text).get("type") + raw_config_type = load_yaml_dict(self._config_text).get("type") if raw_config_type is None: _logger.warning("No config type specified, default to local") raw_config_type = "local" diff --git a/pyproject.toml b/pyproject.toml index 0c3f21b0..fde5b898 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cognite-extractor-utils" -version = "5.4.3" +version = "5.5.0" description = "Utilities for easier development of extractors for CDF" authors = ["Mathias Lohne "] license = "Apache-2.0" diff --git a/tests/tests_unit/test_configtools.py b/tests/tests_unit/test_configtools.py index 90016085..3d725eae 100644 --- a/tests/tests_unit/test_configtools.py +++ b/tests/tests_unit/test_configtools.py @@ -32,6 +32,7 @@ ) from cognite.extractorutils.configtools._util import _to_snake_case from cognite.extractorutils.configtools.elements import AuthenticatorConfig +from cognite.extractorutils.configtools.loaders import ConfigResolver from cognite.extractorutils.exceptions import InvalidConfigError @@ -412,3 +413,22 @@ def test_env_substitution(self): config5 = load_yaml(config_file5, SimpleStringConfig) self.assertEqual(config5.string_field, "veryheyocrowded") + + def test_env_substitution_remote_check(self): + os.environ["STRING_VALUE"] = "test" + + resolver = ConfigResolver("some-path.yml", BaseConfig) + + resolver._config_text = """ + type: local + some_field: !env "wow${STRING_VALUE}wow" + """ + assert not resolver.is_remote + + resolver._config_text = """ + type: ${STRING_VALUE} + some_field: !env "wow${STRING_VALUE}wow" + """ + + os.environ["STRING_VALUE"] = "remote" + assert resolver.is_remote