Skip to content

Commit

Permalink
stream async support
Browse files Browse the repository at this point in the history
  • Loading branch information
M-Naveed-Ashraf committed Dec 31, 2024
1 parent 6b76200 commit 5f12676
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
17 changes: 9 additions & 8 deletions promptlayer/promptlayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ async def _track_request(**body):
)
return await atrack_request(**track_request_kwargs)

return await _track_request
return _track_request

async def _track_request_log(
self,
Expand Down Expand Up @@ -547,15 +547,16 @@ async def _run_internal(
)

if stream:
track_request_callable = await self._create_track_request_callable(
request_params=llm_request_params,
tags=tags,
input_variables=input_variables,
group_id=group_id,
pl_run_span_id=pl_run_span_id,
)
return astream_response(
response,
self._create_track_request_callable(
request_params=llm_request_params,
tags=tags,
input_variables=input_variables,
group_id=group_id,
pl_run_span_id=pl_run_span_id,
),
track_request_callable,
llm_request_params["stream_function"],
)

Expand Down
21 changes: 11 additions & 10 deletions promptlayer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import contextvars
import datetime
import functools
import inspect
import json
import os
import sys
Expand Down Expand Up @@ -1449,15 +1448,17 @@ async def astream_response(
results.append(result)
data["raw_response"] = result
yield data
request_response = await map_results(results)
if inspect.iscoroutinefunction(after_stream):
# after_stream is an async function
response = await after_stream(request_response=request_response.model_dump())
else:
# after_stream is synchronous
response = after_stream(request_response=request_response.model_dump())
data["request_id"] = response.get("request_id")
data["prompt_blueprint"] = response.get("prompt_blueprint")

async def async_generator_from_list(lst):
for item in lst:
yield item

request_response = await map_results(async_generator_from_list(results))
after_stream_response = await after_stream(
request_response=request_response.model_dump()
)
data["request_id"] = after_stream_response.get("request_id")
data["prompt_blueprint"] = after_stream_response.get("prompt_blueprint")
yield data


Expand Down

0 comments on commit 5f12676

Please sign in to comment.