swapped asyncio for thread (#726)
Fixed the streaming buffer delay by moving the stream handler onto a separate thread.
anish-palakurthi authored Jun 28, 2024
1 parent edd6d0e commit e4f2daa
Showing 13 changed files with 126 additions and 69 deletions.
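The core of the change is in stream.py below: the stream handler now runs on its own daemon thread, partial events cross back to the caller through a thread-safe queue.Queue, and the final result is handed over via a concurrent.futures.Future that the async caller awaits with asyncio.wrap_future. A minimal, self-contained sketch of that pattern (illustrative names only; this is not the BamlStream API itself):

import asyncio
import concurrent.futures
import queue
import threading


def start_producer(events: "queue.Queue[str | None]") -> "concurrent.futures.Future[str]":
    future: "concurrent.futures.Future[str]" = concurrent.futures.Future()

    async def produce() -> None:
        try:
            for i in range(3):
                await asyncio.sleep(0.1)            # stand-in for awaiting the FFI stream
                events.put_nowait(f"partial {i}")
            future.set_result("final result")
        except Exception as exc:
            future.set_exception(exc)
            raise
        finally:
            events.put_nowait(None)                 # sentinel: no more partial events

    # The producer gets its own event loop on a daemon thread, so a slow consumer
    # on the main loop no longer delays the stream being drained.
    threading.Thread(target=lambda: asyncio.run(produce()), daemon=True).start()
    return future


async def consume() -> None:
    events: "queue.Queue[str | None]" = queue.Queue()
    future = start_producer(events)
    while True:
        event = events.get()                        # blocking get, mirroring the new __aiter__
        if event is None:
            break
        print(event)
    print(await asyncio.wrap_future(future))        # bridge the thread's result back into asyncio


if __name__ == "__main__":
    asyncio.run(consume())

Presumably this decoupling is what removes the buffering delay described in the commit message: the handler keeps draining the stream even while the consumer is busy.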
42 changes: 30 additions & 12 deletions engine/language_client_python/python_src/baml_py/stream.py
@@ -6,8 +6,11 @@
     TypeBuilder,
 )
 from typing import Callable, Generic, Optional, TypeVar
+import threading
 import asyncio
+import concurrent.futures
+import queue
 
 PartialOutputType = TypeVar("PartialOutputType")
 FinalOutputType = TypeVar("FinalOutputType")
@@ -18,9 +21,10 @@ class BamlStream(Generic[PartialOutputType, FinalOutputType]):
     __partial_coerce: Callable[[FunctionResult], PartialOutputType]
     __final_coerce: Callable[[FunctionResult], FinalOutputType]
     __ctx_manager: RuntimeContextManager
-    __task: Optional[asyncio.Task[FunctionResult]]
-    __event_queue: asyncio.Queue[Optional[FunctionResult]]
+    __task: Optional[threading.Thread]
+    __event_queue: queue.Queue[Optional[FunctionResult]]
     __tb: Optional[TypeBuilder]
+    __future: concurrent.futures.Future[FunctionResult]
 
     def __init__(
         self,
@@ -29,41 +33,55 @@ def __init__(
         final_coerce: Callable[[FunctionResult], FinalOutputType],
         ctx_manager: RuntimeContextManager,
         tb: Optional[TypeBuilder],
     ):
         self.__ffi_stream = ffi_stream.on_event(self.__enqueue)
         self.__partial_coerce = partial_coerce
         self.__final_coerce = final_coerce
         self.__ctx_manager = ctx_manager
         self.__task = None
-        self.__event_queue = asyncio.Queue()
+        self.__event_queue = queue.Queue()
         self.__tb = tb
+        self.__future = concurrent.futures.Future() # Initialize the future here
 
     def __enqueue(self, data: FunctionResult) -> None:
         self.__event_queue.put_nowait(data)
 
     async def __drive_to_completion(self) -> FunctionResult:
         try:
             retval = await self.__ffi_stream.done(self.__ctx_manager)
+            self.__future.set_result(retval)
             return retval
+        except Exception as e:
+            self.__future.set_exception(e)
+            raise
         finally:
             self.__event_queue.put_nowait(None)
 
-    def __drive_to_completion_in_bg(self) -> asyncio.Task[FunctionResult]:
-        # Doing this without using a compare-and-swap or lock is safe,
-        # because we don't cross an await point during it
+    def __drive_to_completion_in_bg(self) -> concurrent.futures.Future[FunctionResult]:
        if self.__task is None:
-            self.__task = asyncio.create_task(self.__drive_to_completion())
-        return self.__task
+            self.__task = threading.Thread(target = self.threading_target, daemon=True)
+            self.__task.start()
+        return self.__future
+
+    def threading_target(self):
+        asyncio.run(self.__drive_to_completion(), debug=True)
 
     async def __aiter__(self):
         self.__drive_to_completion_in_bg()
         while True:
-            event = await self.__event_queue.get()
+            event = self.__event_queue.get()
             if event is None:
                 break
             yield self.__partial_coerce(event.parsed())
 
     async def get_final_response(self):
-        final = await self.__drive_to_completion_in_bg()
-        return self.__final_coerce(final.parsed())
+        final = self.__drive_to_completion_in_bg()
+        return self.__final_coerce((await asyncio.wrap_future(final)).parsed())
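For context, consuming the stream looks the same as before the change; a hedged usage sketch against the generated client (the baml_client import path and the b handle are assumed here, mirroring the integ tests further down):

from baml_client import b  # assumed import path for the generated client

async def demo() -> None:
    stream = b.stream.PromptTestStreaming(input="hello world")
    async for partial in stream:               # partials come off the thread-fed queue.Queue
        print(partial)
    final = await stream.get_final_response()  # resolves via asyncio.wrap_future on the background future
    print(final)

# asyncio.run(demo())  # requires a configured client and API key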
integ-tests/baml_src/test-files/functions/prompts/no-chat-messages.baml
@@ -7,9 +7,19 @@ function PromptTestClaude(input: string) -> string {
   "#
 }
 
-function PromptTestOpenAI(input: string) -> string {
+
+function PromptTestStreaming(input: string) -> string {
   client GPT35
   prompt #"
-    Tell me a haiku about {{ input }}
+    Tell me a short story about {{ input }}
   "#
 }
+
+test TestName {
+  functions [PromptTestStreaming]
+  args {
+    input #"
+      hello world
+    "#
+  }
+}
42 changes: 21 additions & 21 deletions integ-tests/python/baml_client/client.py
@@ -869,7 +869,7 @@ async def PromptTestClaudeChatNoSystem(
mdl = create_model("PromptTestClaudeChatNoSystemReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

-async def PromptTestOpenAI(
+async def PromptTestOpenAIChat(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -881,17 +881,17 @@ async def PromptTestOpenAI(
tb = None

raw = await self.__runtime.call_function(
"PromptTestOpenAI",
"PromptTestOpenAIChat",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
)
mdl = create_model("PromptTestOpenAIReturnType", inner=(str, ...))
mdl = create_model("PromptTestOpenAIChatReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

-async def PromptTestOpenAIChat(
+async def PromptTestOpenAIChatNoSystem(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -903,17 +903,17 @@ async def PromptTestOpenAIChat(
tb = None

raw = await self.__runtime.call_function(
"PromptTestOpenAIChat",
"PromptTestOpenAIChatNoSystem",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
)
mdl = create_model("PromptTestOpenAIChatReturnType", inner=(str, ...))
mdl = create_model("PromptTestOpenAIChatNoSystemReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

-async def PromptTestOpenAIChatNoSystem(
+async def PromptTestStreaming(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -925,14 +925,14 @@ async def PromptTestOpenAIChatNoSystem(
tb = None

raw = await self.__runtime.call_function(
"PromptTestOpenAIChatNoSystem",
"PromptTestStreaming",
{
"input": input,
},
self.__ctx_manager.get(),
tb,
)
mdl = create_model("PromptTestOpenAIChatNoSystemReturnType", inner=(str, ...))
mdl = create_model("PromptTestStreamingReturnType", inner=(str, ...))
return coerce(mdl, raw.parsed())

async def TestAnthropic(
@@ -2551,7 +2551,7 @@ def PromptTestClaudeChatNoSystem(
tb,
)

-def PromptTestOpenAI(
+def PromptTestOpenAIChat(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -2563,7 +2563,7 @@ def PromptTestOpenAI(
tb = None

raw = self.__runtime.stream_function(
"PromptTestOpenAI",
"PromptTestOpenAIChat",
{
"input": input,
},
@@ -2572,8 +2572,8 @@ def PromptTestOpenAI(
tb,
)

mdl = create_model("PromptTestOpenAIReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestOpenAIPartialReturnType", inner=(Optional[str], ...))
mdl = create_model("PromptTestOpenAIChatReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestOpenAIChatPartialReturnType", inner=(Optional[str], ...))

return baml_py.BamlStream[Optional[str], str](
raw,
Expand All @@ -2583,7 +2583,7 @@ def PromptTestOpenAI(
tb,
)

-def PromptTestOpenAIChat(
+def PromptTestOpenAIChatNoSystem(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -2595,7 +2595,7 @@ def PromptTestOpenAIChat(
tb = None

raw = self.__runtime.stream_function(
"PromptTestOpenAIChat",
"PromptTestOpenAIChatNoSystem",
{
"input": input,
},
@@ -2604,8 +2604,8 @@ def PromptTestOpenAIChat(
tb,
)

mdl = create_model("PromptTestOpenAIChatReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestOpenAIChatPartialReturnType", inner=(Optional[str], ...))
mdl = create_model("PromptTestOpenAIChatNoSystemReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestOpenAIChatNoSystemPartialReturnType", inner=(Optional[str], ...))

return baml_py.BamlStream[Optional[str], str](
raw,
Expand All @@ -2615,7 +2615,7 @@ def PromptTestOpenAIChat(
tb,
)

-def PromptTestOpenAIChatNoSystem(
+def PromptTestStreaming(
self,
input: str,
baml_options: BamlCallOptions = {},
@@ -2627,7 +2627,7 @@ def PromptTestOpenAIChatNoSystem(
tb = None

raw = self.__runtime.stream_function(
"PromptTestOpenAIChatNoSystem",
"PromptTestStreaming",
{
"input": input,
},
@@ -2636,8 +2636,8 @@ def PromptTestOpenAIChatNoSystem(
tb,
)

mdl = create_model("PromptTestOpenAIChatNoSystemReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestOpenAIChatNoSystemPartialReturnType", inner=(Optional[str], ...))
mdl = create_model("PromptTestStreamingReturnType", inner=(str, ...))
partial_mdl = create_model("PromptTestStreamingPartialReturnType", inner=(Optional[str], ...))

return baml_py.BamlStream[Optional[str], str](
raw,
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/inlinedbaml.py
@@ -56,7 +56,7 @@
"test-files/functions/output/optional.baml": "class OptionalTest_Prop1 {\n omega_a string\n omega_b int\n}\n\nenum OptionalTest_CategoryType {\n Aleph\n Beta\n Gamma\n}\n \nclass OptionalTest_ReturnType {\n omega_1 OptionalTest_Prop1?\n omega_2 string?\n omega_3 (OptionalTest_CategoryType?)[]\n} \n \nfunction OptionalTest_Function(input: string) -> (OptionalTest_ReturnType?)[]\n{\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest OptionalTest_Function {\n functions [OptionalTest_Function]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/string-list.baml": "function FnOutputStringList(input: string) -> string[] {\n client GPT35\n prompt #\"\n Return a list of strings in json format like [\"string1\", \"string2\", \"string3\"].\n\n JSON:\n \"#\n}\n\ntest FnOutputStringList {\n functions [FnOutputStringList]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/unions.baml": "class UnionTest_ReturnType {\n prop1 string | bool\n prop2 (float | bool)[]\n prop3 (bool[] | int[])\n}\n\nfunction UnionTest_Function(input: string | bool) -> UnionTest_ReturnType {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest UnionTest_Function {\n functions [UnionTest_Function]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Claude\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAI(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}",
"test-files/functions/prompts/no-chat-messages.baml": "\n\nfunction PromptTestClaude(input: string) -> string {\n client Claude\n prompt #\"\n Tell me a haiku about {{ input }}\n \"#\n}\n\n\nfunction PromptTestStreaming(input: string) -> string {\n client GPT35\n prompt #\"\n Tell me a short story about {{ input }}\n \"#\n}\n\ntest TestName {\n functions [PromptTestStreaming]\n args {\n input #\"\n hello world\n \"#\n }\n}\n",
"test-files/functions/prompts/with-chat-messages.baml": "\nfunction PromptTestOpenAIChat(input: string) -> string {\n client GPT35\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestOpenAIChatNoSystem(input: string) -> string {\n client GPT35\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChat(input: string) -> string {\n client Claude\n prompt #\"\n {{ _.role(\"system\") }}\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\nfunction PromptTestClaudeChatNoSystem(input: string) -> string {\n client Claude\n prompt #\"\n You are an assistant that always responds in a very excited way with emojis and also outputs this word 4 times after giving a response: {{ input }}\n \n {{ _.role(\"user\") }}\n Tell me a haiku about {{ input }}\n \"#\n}\n\ntest TestSystemAndNonSystemChat1 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"cats\"\n }\n}\n\ntest TestSystemAndNonSystemChat2 {\n functions [PromptTestClaude, PromptTestOpenAI, PromptTestOpenAIChat, PromptTestOpenAIChatNoSystem, PromptTestClaudeChat, PromptTestClaudeChatNoSystem]\n args {\n input \"lion\"\n }\n}",
"test-files/functions/v2/basic.baml": "\n\nfunction ExtractResume2(resume: string) -> Resume {\n client GPT4\n prompt #\"\n {{ _.role('system') }}\n\n Extract the following information from the resume:\n\n Resume:\n <<<<\n {{ resume }}\n <<<<\n\n Output JSON schema:\n {{ ctx.output_format }}\n\n JSON:\n \"#\n}\n\n\nclass WithReasoning {\n value string\n reasoning string @description(#\"\n Why the value is a good fit.\n \"#)\n}\n\n\nclass SearchParams {\n dateRange int? @description(#\"\n In ISO duration format, e.g. P1Y2M10D.\n \"#)\n location string[]\n jobTitle WithReasoning? @description(#\"\n An exact job title, not a general category.\n \"#)\n company WithReasoning? @description(#\"\n The exact name of the company, not a product or service.\n \"#)\n description WithReasoning[] @description(#\"\n Any specific projects or features the user is looking for.\n \"#)\n tags (Tag | string)[]\n}\n\nenum Tag {\n Security\n AI\n Blockchain\n}\n\nfunction GetQuery(query: string) -> SearchParams {\n client GPT4\n prompt #\"\n Extract the following information from the query:\n\n Query:\n <<<<\n {{ query }}\n <<<<\n\n OUTPUT_JSON_SCHEMA:\n {{ ctx.output_format }}\n\n Before OUTPUT_JSON_SCHEMA, list 5 intentions the user may have.\n --- EXAMPLES ---\n 1. <intent>\n 2. <intent>\n 3. <intent>\n 4. <intent>\n 5. <intent>\n\n {\n ... // OUTPUT_JSON_SCHEMA\n }\n \"#\n}\n\nclass RaysData {\n dataType DataType\n value Resume | Event\n}\n\nenum DataType {\n Resume\n Event\n}\n\nclass Event {\n title string\n date string\n location string\n description string\n}\n\nfunction GetDataType(text: string) -> RaysData {\n client GPT4\n prompt #\"\n Extract the relevant info.\n\n Text:\n <<<<\n {{ text }}\n <<<<\n\n Output JSON schema:\n {{ ctx.output_format }}\n\n JSON:\n \"#\n}",
"test-files/providers/providers.baml": "function TestAnthropic(input: string) -> string {\n client Claude\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAI(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestAzure(input: string) -> string {\n client GPT35Azure\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOllama(input: string) -> string {\n client Ollama\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestGemini(input: string) -> string {\n client Gemini\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}\n\n\ntest TestProvider {\n functions [TestAnthropic, TestOpenAI, TestAzure, TestOllama, TestGemini]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\n\n",
14 changes: 12 additions & 2 deletions integ-tests/python/test_functions.py
@@ -160,12 +160,23 @@ async def test_gemini():
 
 @pytest.mark.asyncio
 async def test_streaming():
-    stream = b.stream.PromptTestOpenAI(input="Programming languages are fun to create")
+    stream = b.stream.PromptTestStreaming(input="Programming languages are fun to create")
     msgs = []
+    start_time = asyncio.get_event_loop().time()
+    last_msg_time = start_time
     async for msg in stream:
         msgs.append(msg)
+        if len(msgs) == 1:
+            first_msg_time = asyncio.get_event_loop().time()
+        last_msg_time = asyncio.get_event_loop().time()
 
     final = await stream.get_final_response()
 
+    assert first_msg_time - start_time <= 1.5, "Expected first message within 1 second but it took longer."
+    assert last_msg_time - start_time >= 1, "Expected last message after 1.5 seconds but it was earlier."
     assert len(final) > 0, "Expected non-empty final but got empty."
     assert len(msgs) > 0, "Expected at least one streamed response but got none."
     for prev_msg, msg in zip(msgs, msgs[1:]):
@@ -177,7 +188,6 @@ async def test_streaming():
         )
     assert msgs[-1] == final, "Expected last stream message to match final response."
-
 
 @pytest.mark.asyncio
 async def test_streaming_uniterated():
     final = await b.stream.PromptTestOpenAI(
24 changes: 12 additions & 12 deletions integ-tests/ruby/baml_client/client.rb
@@ -796,11 +796,11 @@ def PromptTestClaudeChatNoSystem(
).returns(String)

}
-def PromptTestOpenAI(
+def PromptTestOpenAIChat(
input:
)
raw = @runtime.call_function(
"PromptTestOpenAI",
"PromptTestOpenAIChat",
{
"input" => input,
},
@@ -816,11 +816,11 @@ def PromptTestOpenAI(
).returns(String)

}
-def PromptTestOpenAIChat(
+def PromptTestOpenAIChatNoSystem(
input:
)
raw = @runtime.call_function(
"PromptTestOpenAIChat",
"PromptTestOpenAIChatNoSystem",
{
"input" => input,
},
@@ -836,11 +836,11 @@ def PromptTestOpenAIChat(
).returns(String)

}
-def PromptTestOpenAIChatNoSystem(
+def PromptTestStreaming(
input:
)
raw = @runtime.call_function(
"PromptTestOpenAIChatNoSystem",
"PromptTestStreaming",
{
"input" => input,
},
@@ -2018,11 +2018,11 @@ def PromptTestClaudeChatNoSystem(
input: String,
).returns(Baml::BamlStream[String])
}
-def PromptTestOpenAI(
+def PromptTestOpenAIChat(
input:
)
raw = @runtime.stream_function(
"PromptTestOpenAI",
"PromptTestOpenAIChat",
{
"input" => input,
},
@@ -2039,11 +2039,11 @@ def PromptTestOpenAI(
input: String,
).returns(Baml::BamlStream[String])
}
-def PromptTestOpenAIChat(
+def PromptTestOpenAIChatNoSystem(
input:
)
raw = @runtime.stream_function(
"PromptTestOpenAIChat",
"PromptTestOpenAIChatNoSystem",
{
"input" => input,
},
@@ -2060,11 +2060,11 @@ def PromptTestOpenAIChat(
input: String,
).returns(Baml::BamlStream[String])
}
-def PromptTestOpenAIChatNoSystem(
+def PromptTestStreaming(
input:
)
raw = @runtime.stream_function(
"PromptTestOpenAIChatNoSystem",
"PromptTestStreaming",
{
"input" => input,
},