Commit

Merge remote-tracking branch 'origin/master' into fb-dia-1716
matt-bernstein committed Jan 29, 2025
2 parents 11368f4 + 4132c6f commit 5fc0a31
Showing 2 changed files with 15 additions and 3 deletions.
adala/runtimes/_litellm.py: 17 changes (14 additions & 3 deletions)
@@ -145,7 +145,7 @@ def _get_usage_dict(usage: Usage, model: str) -> Dict:
data["_prompt_cost_usd"] = prompt_cost
data["_completion_cost_usd"] = completion_cost
data["_total_cost_usd"] = prompt_cost + completion_cost
except NotFoundError:
except:
logger.error(f"Failed to get cost for model {model}")
data["_prompt_cost_usd"] = None
data["_completion_cost_usd"] = None
@@ -215,6 +215,8 @@ def handle_llm_exception(
     # usage = e.total_usage
     # not available here, so have to approximate by hand, assuming the same error occurred each time
     n_attempts = retries.stop.max_attempt_number
+    # Note that the default model used in token_counter is gpt-3.5-turbo as of now - if model passed in
+    # does not match a mapped model, falls back to default
     prompt_tokens = n_attempts * litellm.token_counter(
         model=model, messages=messages[:-1]
     )  # response is appended as the last message
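
The new comment is the key detail here: when litellm.token_counter does not recognize the model name, it falls back to a default tokenizer (gpt-3.5-turbo's at the time of this commit), so counts for unmapped models are only estimates. A small illustration of the behavior as described:

import litellm

messages = [{"role": "user", "content": "Classify this support ticket."}]

# Mapped model: counted with that model's own tokenizer
mapped = litellm.token_counter(model="gpt-4o", messages=messages)

# Unmapped model name: litellm falls back to its default tokenizer,
# so this count is an approximation, not an exact figure
unmapped = litellm.token_counter(model="some-custom-finetune", messages=messages)
print(mapped, unmapped)
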
@@ -368,11 +370,16 @@ def record_to_record(
             )
             usage = completion.usage
             dct = to_jsonable_python(response)
+            # With successful completions we can get canonical model name
+            usage_model = completion.model
+
         except Exception as e:
             dct, usage = handle_llm_exception(e, messages, self.model, retries)
+            # With exceptions we don't have access to completion.model
+            usage_model = self.model
 
         # Add usage data to the response (e.g. token counts, cost)
-        dct.update(_get_usage_dict(usage, model=self.model))
+        dct.update(_get_usage_dict(usage, model=usage_model))
 
         return dct
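
The intent behind usage_model: on success, completion.model reports the canonical model name the provider actually served, which can be more specific than the requested alias and therefore prices more accurately; in the exception path no completion object exists, so the configured self.model is the only name available. A rough sketch (the snapshot name in the comment is a hypothetical example, and running this requires provider credentials):

import litellm

response = litellm.completion(
    model="gpt-4o",  # the alias we request
    messages=[{"role": "user", "content": "hi"}],
)
# Providers often report the concrete snapshot they served,
# e.g. "gpt-4o-2024-08-06" (hypothetical), and that canonical
# name is a better key for cost lookup than the alias above.
print(response.model)
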

@@ -499,13 +506,17 @@ async def batch_to_batch(
                 dct, usage = handle_llm_exception(
                     response, messages, self.model, retries
                 )
+                # With exceptions we don't have access to completion.model
+                usage_model = self.model
             else:
                 resp, completion = response
                 usage = completion.usage
                 dct = to_jsonable_python(resp)
+                # With successful completions we can get canonical model name
+                usage_model = completion.model
 
             # Add usage data to the response (e.g. token counts, cost)
-            dct.update(_get_usage_dict(usage, model=self.model))
+            dct.update(_get_usage_dict(usage, model=usage_model))
 
             df_data.append(dct)

server/tasks/stream_inference.py: 1 change (1 addition & 0 deletions)
@@ -144,6 +144,7 @@ async def async_process_streaming_input(input_task_done: asyncio.Event, agent: A
         logger.error(
             f"Error in async_process_streaming_input: {e}. Traceback: {traceback.format_exc()}"
         )
+        input_task_done.set()
     # cleans up after any exceptions raised here as well as asyncio.CancelledError resulting from failure in async_process_streaming_output
     finally:
         await agent.environment.finalize()
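
The one-line fix addresses a potential shutdown hang: if the input consumer fails without setting input_task_done, any task awaiting that event blocks forever. A self-contained sketch of the pattern, with illustrative names rather than the actual task code:

import asyncio

async def consume_input(done: asyncio.Event):
    try:
        raise RuntimeError("simulated failure while reading input")
    except Exception as e:
        print(f"input consumer failed: {e}")
        # Signal completion even on failure so downstream waiters can exit
        done.set()

async def consume_output(done: asyncio.Event):
    # Without the set() in the except block, this would wait forever
    await done.wait()
    print("input finished or failed; shutting down output consumer")

async def main():
    done = asyncio.Event()
    await asyncio.gather(consume_input(done), consume_output(done))

asyncio.run(main())
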
