From 5c2ee6ac67ee6babf57f4081a0c78a4bf6a29a84 Mon Sep 17 00:00:00 2001
From: Omar Khattab
Date: Sun, 24 Nov 2024 08:29:56 -0800
Subject: [PATCH] Convert dspy.settings to a ContextVar, improve
 ParallelExecutor (isolate even if 1 thread), and permit user-launched
 threads (#1852)

* Convert dspy.settings to a ContextVar, improve ParallelExecutor (isolate even if 1 thread), and permit user-launched threads

* Fixes
---
 dsp/utils/settings.py      | 123 ++++++++++++++++++++-----------
 dspy/utils/asyncify.py     |  21 +------
 dspy/utils/parallelizer.py |  98 ++++++++++++++++-------------
 3 files changed, 124 insertions(+), 118 deletions(-)

diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py
index 00f01eeaf..118a61fdc 100644
--- a/dsp/utils/settings.py
+++ b/dsp/utils/settings.py
@@ -1,7 +1,8 @@
+import copy
 import threading
-from contextlib import contextmanager
-from copy import deepcopy
+from contextlib import contextmanager
+from contextvars import ContextVar
 
 from dsp.utils.utils import dotdict
 
 DEFAULT_CONFIG = dotdict(
@@ -27,85 +28,95 @@
     async_max_workers=8,
 )
 
+# Global base configuration
+main_thread_config = copy.deepcopy(DEFAULT_CONFIG)
+
+# Initialize the context variable with an empty dict as default
+dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())
+
 
 class Settings:
-    """DSP configuration settings."""
+    """
+    A singleton class for DSPy configuration settings.
+
+    This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
+    - If native threading is used, the thread inherits the initial config from the main thread.
+    - If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
+    """
 
     _instance = None
 
     def __new__(cls):
-        """
-        Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/
-        """
-
         if cls._instance is None:
             cls._instance = super().__new__(cls)
-            cls._instance.lock = threading.Lock()
-            cls._instance.main_tid = threading.get_ident()
-            cls._instance.main_stack = []
-            cls._instance.stack_by_thread = {}
-            cls._instance.stack_by_thread[threading.get_ident()] = cls._instance.main_stack
+            cls._instance.lock = threading.Lock()  # maintained here for assertions
+        return cls._instance
 
-            # TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
-            #  eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
-            #  downstream operations like dsp.retrieve would use configs from the defined pipeline.
+    def __getattr__(self, name):
+        overrides = dspy_ctx_overrides.get()
+        if name in overrides:
+            return overrides[name]
+        elif name in main_thread_config:
+            return main_thread_config[name]
+        else:
+            raise AttributeError(f"'Settings' object has no attribute '{name}'")
 
-        # make a deepcopy of the default config to avoid modifying the default config
-        cls._instance.__append(deepcopy(DEFAULT_CONFIG))
+    def __setattr__(self, name, value):
+        if name in ('_instance',):
+            super().__setattr__(name, value)
+        else:
+            self.configure(**{name: value})
 
-        return cls._instance
+    # Dictionary-like access
 
-    @property
-    def config(self):
-        thread_id = threading.get_ident()
-        if thread_id not in self.stack_by_thread:
-            self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
-        return self.stack_by_thread[thread_id][-1]
+    def __getitem__(self, key):
+        return self.__getattr__(key)
 
-    def __getattr__(self, name):
-        if hasattr(self.config, name):
-            return getattr(self.config, name)
+    def __setitem__(self, key, value):
+        self.__setattr__(key, value)
 
-        if name in self.config:
-            return self.config[name]
+    def __contains__(self, key):
+        overrides = dspy_ctx_overrides.get()
+        return key in overrides or key in main_thread_config
 
-        super().__getattr__(name)
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except AttributeError:
+            return default
 
-    def __append(self, config):
-        thread_id = threading.get_ident()
-        if thread_id not in self.stack_by_thread:
-            self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
-        self.stack_by_thread[thread_id].append(config)
+    def copy(self):
+        overrides = dspy_ctx_overrides.get()
+        return dotdict({**main_thread_config, **overrides})
 
-    def __pop(self):
-        thread_id = threading.get_ident()
-        if thread_id in self.stack_by_thread:
-            self.stack_by_thread[thread_id].pop()
+    # Configuration methods
 
-    def configure(self, inherit_config: bool = True, **kwargs):
-        """Set configuration settings.
+    def configure(self, return_token=False, **kwargs):
+        global main_thread_config
+        overrides = dspy_ctx_overrides.get()
+        new_overrides = dotdict({**main_thread_config, **overrides, **kwargs})
+        token = dspy_ctx_overrides.set(new_overrides)
 
-        Args:
-            inherit_config (bool, optional): Set configurations for the given, and use existing configurations for the rest. Defaults to True.
- """ - if inherit_config: - config = {**self.config, **kwargs} - else: - config = {**kwargs} + # Update main_thread_config, in the main thread only + if threading.current_thread() is threading.main_thread(): + main_thread_config = new_overrides - self.__append(config) + if return_token: + return token @contextmanager - def context(self, inherit_config=True, **kwargs): - self.configure(inherit_config=inherit_config, **kwargs) - + def context(self, **kwargs): + """Context manager for temporary configuration changes.""" + token = self.configure(return_token=True, **kwargs) try: yield finally: - self.__pop() + dspy_ctx_overrides.reset(token) - def __repr__(self) -> str: - return repr(self.config) + def __repr__(self): + overrides = dspy_ctx_overrides.get() + combined_config = {**main_thread_config, **overrides} + return repr(combined_config) -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/dspy/utils/asyncify.py b/dspy/utils/asyncify.py index ca801e12a..03bd9a7e9 100644 --- a/dspy/utils/asyncify.py +++ b/dspy/utils/asyncify.py @@ -24,22 +24,7 @@ def get_limiter(): def asyncify(program): - import dspy import threading - - assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread" - - def wrapped(*args, **kwargs): - thread_stacks = dspy.settings.stack_by_thread - current_thread_id = threading.get_ident() - creating_new_thread = current_thread_id not in thread_stacks - - assert creating_new_thread - thread_stacks[current_thread_id] = list(dspy.settings.main_stack) - - try: - return program(*args, **kwargs) - finally: - del thread_stacks[threading.get_ident()] - - return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter()) + assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread" + # NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py + return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter()) diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index 27983632b..c6b5f3d5f 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -1,16 +1,15 @@ -import logging import sys import tqdm -import dspy import signal +import logging import threading import traceback import contextlib +from contextvars import copy_context from tqdm.contrib.logging import logging_redirect_tqdm from concurrent.futures import ThreadPoolExecutor, as_completed - logger = logging.getLogger(__name__) @@ -23,6 +22,8 @@ def __init__( provide_traceback=False, compare_results=False, ): + """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1.""" + self.num_threads = num_threads self.disable_progress_bar = disable_progress_bar self.max_errors = max_errors @@ -33,34 +34,18 @@ def __init__( self.error_lock = threading.Lock() self.cancel_jobs = threading.Event() - def execute(self, function, data): wrapped_function = self._wrap_function(function) if self.num_threads == 1: - return self._execute_single_thread(wrapped_function, data) + return self._execute_isolated_single_thread(wrapped_function, data) else: return self._execute_multi_thread(wrapped_function, data) - def _wrap_function(self, function): - # Wrap the function with threading context and error handling - def wrapped(item, parent_id=None): - thread_stacks = dspy.settings.stack_by_thread - current_thread_id = threading.get_ident() - creating_new_thread = current_thread_id not in 
-
-            assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid
-
-            if creating_new_thread:
-                # If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
-                if parent_id and parent_id in thread_stacks:
-                    thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
-                else:
-                    thread_stacks[current_thread_id] = list(dspy.settings.main_stack)
-
-                # TODO: Consider the behavior below.
-                # import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))
-
+        # Wrap the function with error handling
+        def wrapped(item):
+            if self.cancel_jobs.is_set():
+                return None
             try:
                 return function(item)
             except Exception as e:
@@ -79,45 +64,53 @@ def wrapped(item, parent_id=None):
                         f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
                     )
                 return None
-            finally:
-                if creating_new_thread:
-                    del thread_stacks[threading.get_ident()]
 
         return wrapped
 
-
-    def _execute_single_thread(self, function, data):
+    def _execute_isolated_single_thread(self, function, data):
        results = []
        pbar = tqdm.tqdm(
            total=len(data),
            dynamic_ncols=True,
            disable=self.disable_progress_bar,
-            file=sys.stdout,
+            file=sys.stdout
        )
+
        for item in data:
            with logging_redirect_tqdm():
                if self.cancel_jobs.is_set():
                    break
-                result = function(item)
+
+                # Create an isolated context for each task
+                task_ctx = copy_context()
+                result = task_ctx.run(function, item)
                results.append(result)
+
                if self.compare_results:
                    # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in data if r is not None]),
+                    )
                else:
                    self._update_progress(pbar, len(results), len(data))
+
        pbar.close()
+
        if self.cancel_jobs.is_set():
            logger.warning("Execution was cancelled due to errors.")
            raise Exception("Execution was cancelled due to errors.")
-        return results
+        return results
 
    def _update_progress(self, pbar, nresults, ntotal):
        if self.compare_results:
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
+            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
        else:
            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
-        pbar.update()
+        pbar.update()
 
    def _execute_multi_thread(self, function, data):
        results = [None] * len(data)  # Pre-allocate results list to maintain order
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
                def interrupt_handler(sig, frame):
                    self.cancel_jobs.set()
                    logger.warning("Received SIGINT. Cancelling execution.")
+                    # Re-raise the signal to allow default behavior
                    default_handler(sig, frame)
 
                signal.signal(signal.SIGINT, interrupt_handler)
@@ -143,37 +137,53 @@ def interrupt_handler(sig, frame):
                # If not in the main thread, skip setting signal handlers
                yield
 
-        def cancellable_function(index_item, parent_id=None):
+        def cancellable_function(index_item):
            index, item = index_item
            if self.cancel_jobs.is_set():
                return index, job_cancelled
-            return index, function(item, parent_id)
-
-        parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
+            return index, function(item)
 
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
+            futures = {}
+            for pair in enumerate(data):
+                # Capture the context for each task
+                task_ctx = copy_context()
+                future = executor.submit(task_ctx.run, cancellable_function, pair)
+                futures[future] = pair
+
            pbar = tqdm.tqdm(
                total=len(data),
                dynamic_ncols=True,
                disable=self.disable_progress_bar,
-                file=sys.stdout,
+                file=sys.stdout
            )
 
            for future in as_completed(futures):
                index, result = future.result()
-
                if result is job_cancelled:
                    continue
+
                results[index] = result
                if self.compare_results:
                    # Assumes score is the last element of the result tuple
-                    self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
+                    self._update_progress(
+                        pbar,
+                        sum([r[-1] for r in results if r is not None]),
+                        len([r for r in results if r is not None]),
+                    )
                else:
-                    self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
+                    self._update_progress(
+                        pbar,
+                        len([r for r in results if r is not None]),
+                        len(data),
+                    )
+
            pbar.close()
+
        if self.cancel_jobs.is_set():
            logger.warning("Execution was cancelled due to errors.")
            raise Exception("Execution was cancelled due to errors.")
+
        return results
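
For readers of the Settings change above: the sketch below (standard library only, not part of the patch; the names overrides, global_config, lookup, and configure_main are illustrative) shows why the ContextVar is paired with the module-level main_thread_config fallback. A natively spawned thread starts with a fresh context and therefore does not see ContextVar values set by the main thread, so lookups fall back to the shared global that the main thread keeps updated:

import threading
from contextvars import ContextVar

overrides = ContextVar("overrides", default={})  # per-context overrides (like dspy_ctx_overrides)
global_config = {}                               # module-level fallback (like main_thread_config)

def lookup(key, default=None):
    # Mirrors Settings.__getattr__: check the context overrides first, then the global fallback.
    return overrides.get().get(key, global_config.get(key, default))

def configure_main(key, value):
    # Mirrors Settings.configure on the main thread: update both layers.
    overrides.set({**overrides.get(), key: value})
    global_config[key] = value

configure_main("lm", "configured-lm")

def worker():
    # A new thread gets a fresh context: the ContextVar override is absent here,
    # but the global fallback still yields the main thread's configuration.
    print("worker sees:", lookup("lm"))  # -> configured-lm

t = threading.Thread(target=worker)
t.start()
t.join()

This is the mechanism behind the docstring's claim that natively launched threads inherit the initial config from the main thread.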
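
Similarly, for the ParallelExecutor change: a minimal sketch of the copy_context() pattern now used on both the single-threaded and multi-threaded paths (again standard library only; overrides, configure, and task are illustrative names). Each submission runs inside a snapshot of the submitting thread's context, so tasks inherit the parent's overrides, and anything a task sets stays inside its own snapshot, which is what allows per-task isolation even with num_threads == 1:

from contextvars import ContextVar, copy_context
from concurrent.futures import ThreadPoolExecutor

overrides = ContextVar("overrides", default={})

def configure(key, value):
    # Layer a new value onto the current context's overrides.
    overrides.set({**overrides.get(), key: value})

def task(item):
    # Reads whatever overrides were captured when the task was submitted.
    return item, overrides.get().get("lm", "default-lm")

configure("lm", "parent-lm")

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = []
    for item in range(3):
        task_ctx = copy_context()  # snapshot the submitting thread's context
        futures.append(pool.submit(task_ctx.run, task, item))
    print([f.result() for f in futures])  # every task sees "parent-lm"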