Skip to content

Commit

Permalink
fix(agents-api): Patch unexpected response from litellm (finish_reason) (#894)
Browse files Browse the repository at this point in the history

<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> Add function to handle 'eos' finish reasons in litellm responses and
reformat print statements in `codec.py`.
> 
>   - **Behavior**:
> - Add `patch_litellm_response()` in `litellm.py` to convert 'eos'
finish reasons to 'stop'.
> - Apply `patch_litellm_response()` to `model_response` in
`acompletion()`.
>   - **Misc**:
> - Reformat print statements in `serialize()` and `deserialize()` in
`codec.py` for readability.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup>
for 653b8f3. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
HamadaSalhab authored Nov 27, 2024
1 parent 58e8f3c commit 32bd599
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
25 changes: 24 additions & 1 deletion agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@
litellm.drop_params = True


def patch_litellm_response(
    model_response: ModelResponse | CustomStreamWrapper,
) -> ModelResponse | CustomStreamWrapper:
    """
    Normalize quirks in litellm responses before handing them to callers.

    Some providers report a finish reason of "eos", which downstream code
    does not expect; rewrite it in place to the standard "stop" value.
    Works for both buffered responses and streaming wrappers.
    """

    if isinstance(model_response, ModelResponse):
        # Buffered response: fix up each choice individually.
        for patched_choice in model_response.choices:
            if patched_choice.finish_reason == "eos":
                patched_choice.finish_reason = "stop"

    elif isinstance(model_response, CustomStreamWrapper):
        # Streaming wrapper keeps a single accumulated finish reason.
        if model_response.received_finish_reason == "eos":
            model_response.received_finish_reason = "stop"

    return model_response


@wraps(_acompletion)
@beartype
async def acompletion(
Expand All @@ -45,14 +64,18 @@ async def acompletion(
for message in messages
]

return await _acompletion(
model_response = await _acompletion(
model=model,
messages=messages,
**settings,
base_url=None if custom_api_key else litellm_url,
api_key=custom_api_key or litellm_master_key,
)

model_response = patch_litellm_response(model_response)

return model_response


@wraps(_aembedding)
@beartype
Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/worker/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def serialize(x: Any) -> bytes:

duration = time.time() - start_time
if duration > 1:
print(f"||| [SERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(x) / 1000}kb")
print(
f"||| [SERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(x) / 1000}kb"
)

return compressed

Expand All @@ -43,7 +45,9 @@ def deserialize(b: bytes) -> Any:

duration = time.time() - start_time
if duration > 1:
print(f"||| [DESERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(b) / 1000}kb")
print(
f"||| [DESERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(b) / 1000}kb"
)

return object

Expand Down

0 comments on commit 32bd599

Please sign in to comment.