diff --git a/agents-api/agents_api/clients/litellm.py b/agents-api/agents_api/clients/litellm.py index ea4721b54..3b1860073 100644 --- a/agents-api/agents_api/clients/litellm.py +++ b/agents-api/agents_api/clients/litellm.py @@ -27,6 +27,25 @@ litellm.drop_params = True +def patch_litellm_response( + model_response: ModelResponse | CustomStreamWrapper, +) -> ModelResponse | CustomStreamWrapper: + """ + Patches the response we get from litellm to handle unexpected response formats. + """ + + if isinstance(model_response, ModelResponse): + for choice in model_response.choices: + if choice.finish_reason == "eos": + choice.finish_reason = "stop" + + elif isinstance(model_response, CustomStreamWrapper): + if model_response.received_finish_reason == "eos": + model_response.received_finish_reason = "stop" + + return model_response + + @wraps(_acompletion) @beartype async def acompletion( @@ -45,7 +64,7 @@ async def acompletion( for message in messages ] - return await _acompletion( + model_response = await _acompletion( model=model, messages=messages, **settings, @@ -53,6 +72,10 @@ async def acompletion( api_key=custom_api_key or litellm_master_key, ) + model_response = patch_litellm_response(model_response) + + return model_response + @wraps(_aembedding) @beartype diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index a9f773900..d562b4653 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -31,7 +31,9 @@ def serialize(x: Any) -> bytes: duration = time.time() - start_time if duration > 1: - print(f"||| [SERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(x) / 1000}kb") + print( + f"||| [SERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(x) / 1000}kb" + ) return compressed @@ -43,7 +45,9 @@ def deserialize(b: bytes) -> Any: duration = time.time() - start_time if duration > 1: - print(f"||| [DESERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(b) / 1000}kb") + print( + f"||| [DESERIALIZE] Time taken: {duration}s // Object size: {sys.getsizeof(b) / 1000}kb" + ) return object