Skip to content

Commit

Permalink
feat: close tool controllers on run end (#1254)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Dec 18, 2024
1 parent e5e06e4 commit 27a1e82
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
4 changes: 2 additions & 2 deletions python/assistant-stream/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[project]
name = "assistant-stream"
version = "0.0.3"
version = "0.0.5"
authors = [
{ name="Simon Farshid", email="[email protected]" },
]
Expand All @@ -22,7 +22,7 @@ Issues = "https://github.com/Yonom/assistant-ui/issues"

[tool.poetry]
name = "assistant-stream"
version = "0.0.3"
version = "0.0.5"
description = ""
authors = ["Simon Farshid <[email protected]>"]

Expand Down
22 changes: 13 additions & 9 deletions python/assistant-stream/src/assistant_stream/create_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,23 @@

class RunController:
def __init__(self, queue):
self.queue = queue
self.loop = asyncio.get_event_loop()
self.stream_tasks = []
self._queue = queue
self._loop = asyncio.get_event_loop()
self._dispose_callbacks = []
self._stream_tasks = []

def append_text(self, text_delta: str):
"""Append a text delta to the stream."""
chunk = TextDeltaChunk(type="text-delta", text_delta=text_delta)
self.loop.call_soon_threadsafe(self.queue.put_nowait, chunk)
self._loop.call_soon_threadsafe(self._queue.put_nowait, chunk)

async def add_tool_call(
self, tool_name: str, tool_call_id: str = generate_openai_style_tool_call_id()
) -> ToolCallController:
"""Add a tool call to the stream."""

stream, controller = await create_tool_call(tool_name, tool_call_id)
self._dispose_callbacks.append(controller.close)
self.add_stream(stream)
return controller

Expand All @@ -33,10 +35,10 @@ def add_stream(self, stream: AsyncGenerator[AssistantStreamChunk, None]):

async def reader():
async for chunk in stream:
await self.queue.put(chunk)
await self._queue.put(chunk)

task = asyncio.create_task(reader())
self.stream_tasks.append(task)
self._stream_tasks.append(task)


async def create_run(
Expand All @@ -49,18 +51,20 @@ async def background_task():
try:
await callback(controller)

for task in controller.stream_tasks:
for dispose in controller._dispose_callbacks:
dispose()
for task in controller._stream_tasks:
await task
finally:
asyncio.get_event_loop().call_soon_threadsafe(queue.put_nowait, None)

task = asyncio.create_task(background_task())

while True:
chunk = await controller.queue.get()
chunk = await controller._queue.get()
if chunk is None:
break
yield chunk
controller.queue.task_done()
controller._queue.task_done()

await task
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def set_result(self, result: Any):
result=result,
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, chunk)
self.close()

def close(self):
"""Close the stream."""
self.loop.call_soon_threadsafe(self.queue.put_nowait, None)


Expand Down

0 comments on commit 27a1e82

Please sign in to comment.