From 4c2afdf50cdeb3409a3e36edc9085940e23097a0 Mon Sep 17 00:00:00 2001 From: Anish Palakurthi Date: Thu, 27 Jun 2024 16:40:03 -0700 Subject: [PATCH] swapped asyncio for thread --- .../python_src/baml_py/stream.py | 40 +++++++++++++------ tools/build | 4 +- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/engine/language_client_python/python_src/baml_py/stream.py b/engine/language_client_python/python_src/baml_py/stream.py index f4a2af94f..6efbd2ecf 100644 --- a/engine/language_client_python/python_src/baml_py/stream.py +++ b/engine/language_client_python/python_src/baml_py/stream.py @@ -6,8 +6,11 @@ TypeBuilder, ) from typing import Callable, Generic, Optional, TypeVar - +import threading import asyncio +import concurrent.futures + +import queue PartialOutputType = TypeVar("PartialOutputType") FinalOutputType = TypeVar("FinalOutputType") @@ -18,9 +21,10 @@ class BamlStream(Generic[PartialOutputType, FinalOutputType]): __partial_coerce: Callable[[FunctionResult], PartialOutputType] __final_coerce: Callable[[FunctionResult], FinalOutputType] __ctx_manager: RuntimeContextManager - __task: Optional[asyncio.Task[FunctionResult]] - __event_queue: asyncio.Queue[Optional[FunctionResult]] + __task: Optional[threading.Thread] + __event_queue: queue.Queue[Optional[FunctionResult]] __tb: Optional[TypeBuilder] + __future: concurrent.futures.Future[FunctionResult] def __init__( self, @@ -29,41 +33,53 @@ def __init__( final_coerce: Callable[[FunctionResult], FinalOutputType], ctx_manager: RuntimeContextManager, tb: Optional[TypeBuilder], + ): self.__ffi_stream = ffi_stream.on_event(self.__enqueue) self.__partial_coerce = partial_coerce self.__final_coerce = final_coerce self.__ctx_manager = ctx_manager self.__task = None - self.__event_queue = asyncio.Queue() + self.__event_queue = queue.Queue() self.__tb = tb + self.__future = concurrent.futures.Future() # Initialize the future here def __enqueue(self, data: FunctionResult) -> None: + self.__event_queue.put_nowait(data) async def __drive_to_completion(self) -> FunctionResult: + print("Drive to completion") try: retval = await self.__ffi_stream.done(self.__ctx_manager) + self.__future.set_result(retval) return retval + except Exception as e: + self.__future.set_exception(e) + raise finally: self.__event_queue.put_nowait(None) - def __drive_to_completion_in_bg(self) -> asyncio.Task[FunctionResult]: - # Doing this without using a compare-and-swap or lock is safe, - # because we don't cross an await point during it + def __drive_to_completion_in_bg(self) -> concurrent.futures.Future[FunctionResult]: if self.__task is None: - self.__task = asyncio.create_task(self.__drive_to_completion()) + self.__task = threading.Thread(target = self.threading_target, daemon=True) + self.__task.start() + return self.__future + + def threading_target(self): + asyncio.run(self.__drive_to_completion(), debug=True) - return self.__task async def __aiter__(self): self.__drive_to_completion_in_bg() while True: - event = await self.__event_queue.get() + print("Loop iteration") + event = self.__event_queue.get() + print("Event") if event is None: break yield self.__partial_coerce(event.parsed()) async def get_final_response(self): - final = await self.__drive_to_completion_in_bg() - return self.__final_coerce(final.parsed()) + final = self.__drive_to_completion_in_bg() + return self.__final_coerce((await asyncio.wrap_future(final)).parsed()) \ No newline at end of file diff --git a/tools/build b/tools/build index b66d41103..a03cb9063 100755 --- a/tools/build +++ b/tools/build @@ -232,7 +232,9 @@ case "$_path" in command="env -u CONDA_PREFIX poetry run maturin develop --manifest-path ${_repo_root}/engine/language_client_python/Cargo.toml" command="${command} && poetry run baml-cli generate --from ${_repo_root}/integ-tests/baml_src" if [ "$_test_mode" -eq 1 ]; then - command="${command} && BAML_LOG=debug infisical run --env=test -- poetry run pytest app/test_functions.py::test_streaming_claude" + # command="${command} && BAML_LOG=debug infisical run --env=test -- poetry run pytest app/test_functions.py::test_streaming_claude" + command="${command} && BAML_LOG=info poetry run pytest -s test_functions.py::test_streaming" + fi if [ "$_watch_mode" -eq 1 ]; then npx nodemon \