Skip to content

Commit

Permalink
load variables when getting config
Browse files Browse the repository at this point in the history
  • Loading branch information
胡霁 authored and ice-black-tea committed Apr 16, 2023
1 parent a0121ff commit 8b4476e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
41 changes: 28 additions & 13 deletions src/linktools/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import abc
import errno
import os
import threading
from types import ModuleType
from typing import Optional, Union, Callable, IO, Any, Mapping, Dict, List, Type

Expand All @@ -42,13 +43,24 @@


class _Loader(abc.ABC):

def load(self, env: BaseEnviron, key: any):
try:
return self._load(env, key)
except Exception as e:
env.logger.error(f"Load config \"{key}\" error", exc_info=e)
raise e
__missing__ = object
__lock__ = threading.RLock()

def __init__(self):
self._data: Union[str, object] = self.__missing__

def load(self, env: BaseEnviron, key: any) -> Optional[str]:
if self._data is not self.__missing__:
return self._data
with self.__lock__:
if self._data is not self.__missing__:
return self._data
try:
self._data = self._load(env, key)
except Exception as e:
env.logger.error(f"Load config \"{key}\" error", exc_info=e)
raise e
return self._data

@abc.abstractmethod
def _load(self, env: BaseEnviron, key: any):
Expand All @@ -64,9 +76,10 @@ def __init__(
choices: Optional[List[str]] = None,
default: Any = None,
type: Type = str,
cached: bool = None,
cached: bool = False,
trim: bool = True,
):
super().__init__()
self.prompt = prompt
self.password = password
self.choices = choices
Expand Down Expand Up @@ -138,6 +151,7 @@ def process_result(data):
class _Lazy(_Loader):

def __init__(self, func: Callable[[BaseEnviron], Any]):
super().__init__()
self.func = func

def _load(self, env: BaseEnviron, key: any):
Expand All @@ -151,6 +165,12 @@ def __init__(self, env: BaseEnviron, defaults: Optional[dict] = None):
super().__init__(defaults or {})
self.environ = env

def load_value(self, key) -> Any:
value = self[key]
if isinstance(value, _Loader):
value = value.load(self.environ, key)
return value

def from_envvar(self, variable_name: str, silent: bool = False) -> bool:
rv = os.environ.get(variable_name)
if not rv:
Expand Down Expand Up @@ -207,8 +227,3 @@ def from_mapping(self, mapping: Optional[Mapping[str, Any]] = None, **kwargs: An
if key[0].isupper():
self[key] = value
return True

def __setitem__(self, key, value):
if isinstance(value, _Loader):
value = utils.lazy_load(value.load, self.environ, key)
return super().__setitem__(key, value)
16 changes: 11 additions & 5 deletions src/linktools/_environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_configs(self, namespace: str, lowercase: bool = True, trim_namespace: bo
根据命名空间获取配置列表
"""
rv = {}
for k, v in self._config.items():
for k in self._config:
if not k.startswith(namespace):
continue
if trim_namespace:
Expand All @@ -195,7 +195,7 @@ def get_configs(self, namespace: str, lowercase: bool = True, trim_namespace: bo
key = k
if lowercase:
key = key.lower()
rv[key] = v
rv[key] = self._config.load_value(k)
return rv

def get_config(self, key: str, type: Type[T] = None, empty: bool = False, default: T = None) -> Optional[T]:
Expand All @@ -213,7 +213,7 @@ def get_config(self, key: str, type: Type[T] = None, empty: bool = False, defaul

try:
if key in self._config:
value = self._config.get(key)
value = self._config.load_value(key)
if empty or value:
return value if type is None else type(value)
except Exception as e:
Expand All @@ -234,16 +234,22 @@ def set_config(self, key: str, value: Any) -> None:
self._config[key] = value

def load_config_file(self, path: str) -> bool:
"""
加载配置文件,按照扩展名来匹配相应的加载规则
"""
if path.endswith(".py"):
return self._config.from_pyfile(path)
elif path.endswith(".json"):
return self._config.from_file(path, load=json.load)
elif path.endswith(".yml"):
return self._config.from_file(path, load=yaml.safe_load)
self.logger.debug(f"Unsupport config file: {path}")
self.logger.debug(f"Unsupported config file: {path}")
return False

def load_config_dir(self, path: str, recursion: bool = False) -> bool:
"""
加载配置文件目录,按照扩展名来匹配相应的加载规则
"""
# 路径不存在
if not os.path.exists(path):
return False
Expand All @@ -270,7 +276,7 @@ def walk_configs(self, include_internal: bool = False) -> Generator[Tuple[str, A
internal_config = self._internal_config
for key in self._config:
if include_internal or key not in internal_config:
yield key, self.get_config(key, type=None)
yield key, self.get_config(key)

@cached_property
def tools(self):
Expand Down

0 comments on commit 8b4476e

Please sign in to comment.