From d5d59a7b67aa454bf2fa274c4eb240a1328f3b1b Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Wed, 8 Jan 2025 11:34:11 +0100 Subject: [PATCH] feat: temp env var decorator --- src/modalities/utils/env_variables.py | 38 +++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/src/modalities/utils/env_variables.py b/src/modalities/utils/env_variables.py index 15c4630d..89d7150d 100644 --- a/src/modalities/utils/env_variables.py +++ b/src/modalities/utils/env_variables.py @@ -1,11 +1,14 @@ import os from contextlib import contextmanager +from functools import wraps +from typing import Any + @contextmanager def temporary_env_var(key, value): """ Temporarily set an environment variable. - + Args: key (str): The environment variable name. value (str): The temporary value to set. @@ -19,4 +22,35 @@ def temporary_env_var(key, value): if original_value is None: del os.environ[key] else: - os.environ[key] = original_value \ No newline at end of file + os.environ[key] = original_value + + +def temporary_env_vars_decorator(env_vars: dict[str, Any]): + """ + Decorator to temporarily set multiple environment variables for the duration of a function call. + + Args: + env_vars (dict): A dictionary of environment variable names and their temporary values. + """ + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + original_values = {} # Store original values of environment variables + try: + # Set the temporary environment variables + for key, value in env_vars.items(): + original_values[key] = os.environ.get(key) # Save original value + os.environ[key] = value # Set temporary value + return func(*args, **kwargs) # Execute the decorated function + finally: + # Restore original values or delete keys if not originally set + for key, original_value in original_values.items(): + if original_value is None: + del os.environ[key] + else: + os.environ[key] = original_value + + return wrapper + + return decorator