Convert dspy.settings to a ContextVar, improve ParallelExecutor (isolate even if 1 thread), and permit user-launched threads (#1852)

* Convert dspy.settings to a ContextVar, improve ParallelExecutor (isolate even if 1 thread), and permit user-launched threads

* Fixes
okhat authored Nov 24, 2024
1 parent 0eb1e04 commit 5c2ee6a
Showing 3 changed files with 124 additions and 118 deletions.
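
As a quick orientation, the intended behavior is sketched below (a minimal sketch, assuming `settings` is exposed as `dspy.settings`; the configured `lm` strings are placeholder values, not part of this diff): configuration applied on the main thread is visible to user-launched threads, while `dspy.settings.context(...)` overrides stay local to the context that applies them.

import threading
import dspy

# Placeholder configuration value; a real setup would pass an LM object.
dspy.settings.configure(lm="main-thread-lm")

def worker():
    # A user-launched thread inherits the main thread's configuration.
    print("inherited:", dspy.settings.lm)

    # Overrides applied via the context manager are local to this context
    # and are rolled back on exit.
    with dspy.settings.context(lm="worker-local-lm"):
        print("override:", dspy.settings.lm)

    print("restored:", dspy.settings.lm)

t = threading.Thread(target=worker)
t.start()
t.join()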
123 changes: 67 additions & 56 deletions dsp/utils/settings.py
@@ -1,7 +1,8 @@
import copy
import threading
from contextlib import contextmanager
from copy import deepcopy

from contextlib import contextmanager
from contextvars import ContextVar
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
@@ -27,85 +28,95 @@
async_max_workers=8,
)

# Global base configuration
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)

# Initialize the context variable with an empty dict as default
dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())


class Settings:
"""DSP configuration settings."""
"""
A singleton class for DSPy configuration settings.
This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
- If native threading is used, the thread inherits the initial config from the main thread.
- If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
"""

_instance = None

def __new__(cls):
"""
Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/
"""

if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.lock = threading.Lock()
cls._instance.main_tid = threading.get_ident()
cls._instance.main_stack = []
cls._instance.stack_by_thread = {}
cls._instance.stack_by_thread[threading.get_ident()] = cls._instance.main_stack
cls._instance.lock = threading.Lock() # maintained here for assertions
return cls._instance

# TODO: remove first-class support for re-ranker and potentially combine with RM to form a pipeline of sorts
# eg: RetrieveThenRerankPipeline(RetrievalModel, Reranker)
# downstream operations like dsp.retrieve would use configs from the defined pipeline.
def __getattr__(self, name):
overrides = dspy_ctx_overrides.get()
if name in overrides:
return overrides[name]
elif name in main_thread_config:
return main_thread_config[name]
else:
raise AttributeError(f"'Settings' object has no attribute '{name}'")

# make a deepcopy of the default config to avoid modifying the default config
cls._instance.__append(deepcopy(DEFAULT_CONFIG))
def __setattr__(self, name, value):
if name in ('_instance',):
super().__setattr__(name, value)
else:
self.configure(**{name: value})

return cls._instance
# Dictionary-like access

@property
def config(self):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
return self.stack_by_thread[thread_id][-1]
def __getitem__(self, key):
return self.__getattr__(key)

def __getattr__(self, name):
if hasattr(self.config, name):
return getattr(self.config, name)
def __setitem__(self, key, value):
self.__setattr__(key, value)

if name in self.config:
return self.config[name]
def __contains__(self, key):
overrides = dspy_ctx_overrides.get()
return key in overrides or key in main_thread_config

super().__getattr__(name)
def get(self, key, default=None):
try:
return self[key]
except AttributeError:
return default

def __append(self, config):
thread_id = threading.get_ident()
if thread_id not in self.stack_by_thread:
self.stack_by_thread[thread_id] = [self.main_stack[-1].copy()]
self.stack_by_thread[thread_id].append(config)
def copy(self):
overrides = dspy_ctx_overrides.get()
return dotdict({**main_thread_config, **overrides})

def __pop(self):
thread_id = threading.get_ident()
if thread_id in self.stack_by_thread:
self.stack_by_thread[thread_id].pop()
# Configuration methods

def configure(self, inherit_config: bool = True, **kwargs):
"""Set configuration settings.
def configure(self, return_token=False, **kwargs):
global main_thread_config
overrides = dspy_ctx_overrides.get()
new_overrides = dotdict({**main_thread_config, **overrides, **kwargs})
token = dspy_ctx_overrides.set(new_overrides)

Args:
inherit_config (bool, optional): Set configurations for the given, and use existing configurations for the rest. Defaults to True.
"""
if inherit_config:
config = {**self.config, **kwargs}
else:
config = {**kwargs}
# Update main_thread_config, in the main thread only
if threading.current_thread() is threading.main_thread():
main_thread_config = new_overrides

self.__append(config)
if return_token:
return token

@contextmanager
def context(self, inherit_config=True, **kwargs):
self.configure(inherit_config=inherit_config, **kwargs)

def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
token = self.configure(return_token=True, **kwargs)
try:
yield
finally:
self.__pop()
dspy_ctx_overrides.reset(token)

def __repr__(self) -> str:
return repr(self.config)
def __repr__(self):
overrides = dspy_ctx_overrides.get()
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)


settings = Settings()
settings = Settings()
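
For readers unfamiliar with the token-based reset that `context` now relies on, here is a minimal standalone sketch of the same pattern using a plain ContextVar (no DSPy): `configure` returns the token produced by `ContextVar.set`, and the context manager resets to the previous state on exit, so nested overrides unwind cleanly.

from contextlib import contextmanager
from contextvars import ContextVar

_overrides = ContextVar("overrides", default={})

@contextmanager
def override(**kwargs):
    # Merge the new overrides on top of the current ones and remember the token.
    token = _overrides.set({**_overrides.get(), **kwargs})
    try:
        yield
    finally:
        # Restore the exact previous state, even when nested.
        _overrides.reset(token)

with override(model="a"):
    with override(temperature=0.7):
        assert _overrides.get() == {"model": "a", "temperature": 0.7}
    assert _overrides.get() == {"model": "a"}
assert _overrides.get() == {}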
21 changes: 3 additions & 18 deletions dspy/utils/asyncify.py
@@ -24,22 +24,7 @@ def get_limiter():


def asyncify(program):
import dspy
import threading

assert threading.get_ident() == dspy.settings.main_tid, "asyncify can only be called from the main thread"

def wrapped(*args, **kwargs):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()
creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

try:
return program(*args, **kwargs)
finally:
del thread_stacks[threading.get_ident()]

return asyncer.asyncify(wrapped, abandon_on_cancel=True, limiter=get_limiter())
assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread"
# NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py
return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
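
For illustration, a minimal usage sketch of the simplified asyncify (here `my_program` is a hypothetical stand-in for a synchronous DSPy program; the real concurrency bound comes from `async_max_workers` via `get_limiter()` above):

import asyncio

from dspy.utils.asyncify import asyncify

def my_program(question: str) -> str:
    # Hypothetical stand-in for a synchronous dspy.Module call.
    return f"answer to {question!r}"

async def main():
    # Per the assertion above, asyncify must be called from the main thread.
    async_program = asyncify(my_program)

    # Each call runs the sync program in a worker thread, bounded by the limiter.
    answers = await asyncio.gather(async_program("q1"), async_program("q2"))
    print(answers)

asyncio.run(main())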
98 changes: 54 additions & 44 deletions dspy/utils/parallelizer.py
@@ -1,16 +1,15 @@
import logging
import sys
import tqdm
import dspy
import signal
import logging
import threading
import traceback
import contextlib

from contextvars import copy_context
from tqdm.contrib.logging import logging_redirect_tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed


logger = logging.getLogger(__name__)


@@ -23,6 +22,8 @@ def __init__(
provide_traceback=False,
compare_results=False,
):
"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""

self.num_threads = num_threads
self.disable_progress_bar = disable_progress_bar
self.max_errors = max_errors
@@ -33,34 +34,18 @@
self.error_lock = threading.Lock()
self.cancel_jobs = threading.Event()


def execute(self, function, data):
wrapped_function = self._wrap_function(function)
if self.num_threads == 1:
return self._execute_single_thread(wrapped_function, data)
return self._execute_isolated_single_thread(wrapped_function, data)
else:
return self._execute_multi_thread(wrapped_function, data)


def _wrap_function(self, function):
# Wrap the function with threading context and error handling
def wrapped(item, parent_id=None):
thread_stacks = dspy.settings.stack_by_thread
current_thread_id = threading.get_ident()
creating_new_thread = current_thread_id not in thread_stacks

assert creating_new_thread or threading.get_ident() == dspy.settings.main_tid

if creating_new_thread:
# If we have a parent thread ID, copy its stack. TODO: Should the caller just pass a copy of the stack?
if parent_id and parent_id in thread_stacks:
thread_stacks[current_thread_id] = list(thread_stacks[parent_id])
else:
thread_stacks[current_thread_id] = list(dspy.settings.main_stack)

# TODO: Consider the behavior below.
# import copy; thread_stacks[current_thread_id].append(copy.deepcopy(thread_stacks[current_thread_id][-1]))

# Wrap the function with error handling
def wrapped(item):
if self.cancel_jobs.is_set():
return None
try:
return function(item)
except Exception as e:
@@ -79,45 +64,53 @@ def wrapped(item, parent_id=None):
f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
)
return None
finally:
if creating_new_thread:
del thread_stacks[threading.get_ident()]
return wrapped


def _execute_single_thread(self, function, data):
def _execute_isolated_single_thread(self, function, data):
results = []
pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout,
file=sys.stdout
)

for item in data:
with logging_redirect_tqdm():
if self.cancel_jobs.is_set():
break
result = function(item)

# Create an isolated context for each task
task_ctx = copy_context()
result = task_ctx.run(function, item)
results.append(result)

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in data if r is not None]))
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in data if r is not None]),
)
else:
self._update_progress(pbar, len(results), len(data))

pbar.close()

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")
return results

return results

def _update_progress(self, pbar, nresults, ntotal):
if self.compare_results:
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({round(100 * nresults / ntotal, 1) if ntotal > 0 else 0}%)")
percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
else:
pbar.set_description(f"Processed {nresults} / {ntotal} examples")
pbar.update()

pbar.update()

def _execute_multi_thread(self, function, data):
results = [None] * len(data) # Pre-allocate results list to maintain order
@@ -132,6 +125,7 @@ def interrupt_handler_manager():
def interrupt_handler(sig, frame):
self.cancel_jobs.set()
logger.warning("Received SIGINT. Cancelling execution.")
# Re-raise the signal to allow default behavior
default_handler(sig, frame)

signal.signal(signal.SIGINT, interrupt_handler)
@@ -143,37 +137,53 @@ def interrupt_handler(sig, frame):
# If not in the main thread, skip setting signal handlers
yield

def cancellable_function(index_item, parent_id=None):
def cancellable_function(index_item):
index, item = index_item
if self.cancel_jobs.is_set():
return index, job_cancelled
return index, function(item, parent_id)

parent_id = threading.get_ident() if threading.current_thread() is not threading.main_thread() else None
return index, function(item)

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
futures = {executor.submit(cancellable_function, pair, parent_id): pair for pair in enumerate(data)}
futures = {}
for pair in enumerate(data):
# Capture the context for each task
task_ctx = copy_context()
future = executor.submit(task_ctx.run, cancellable_function, pair)
futures[future] = pair

pbar = tqdm.tqdm(
total=len(data),
dynamic_ncols=True,
disable=self.disable_progress_bar,
file=sys.stdout,
file=sys.stdout
)

for future in as_completed(futures):
index, result = future.result()

if result is job_cancelled:
continue

results[index] = result

if self.compare_results:
# Assumes score is the last element of the result tuple
self._update_progress(pbar, sum([r[-1] for r in results if r is not None]), len([r for r in results if r is not None]))
self._update_progress(
pbar,
sum([r[-1] for r in results if r is not None]),
len([r for r in results if r is not None]),
)
else:
self._update_progress(pbar, len([r for r in results if r is not None]), len(data))
self._update_progress(
pbar,
len([r for r in results if r is not None]),
len(data),
)

pbar.close()

if self.cancel_jobs.is_set():
logger.warning("Execution was cancelled due to errors.")
raise Exception("Execution was cancelled due to errors.")

return results
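
To see why copy_context() buys the isolation the executor wants, here is a small standalone illustration using plain contextvars (no DSPy): every submitted task runs in its own copy of the submitting context, so ContextVar writes inside one task never leak into another task or back into the caller.

from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar, copy_context

setting = ContextVar("setting", default="base")

def task(name):
    # This write is visible only inside this task's private context copy.
    setting.set(name)
    return name, setting.get()

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(copy_context().run, task, f"task-{i}") for i in range(4)]
    print([f.result() for f in futures])  # each task sees only its own value

print(setting.get())  # the caller's context is untouched: "base"

In the same spirit, ParallelExecutor(num_threads=8).execute(process_item, items) would map process_item over items with that per-task isolation on either execution path (process_item and items are placeholder names).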
