Skip to content

Commit

Permalink
collect prompt token and cost for stream
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 6, 2023
1 parent 6826b12 commit 927502a
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
30 changes: 30 additions & 0 deletions src/greptimeai/extractor/openai_extractor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, Optional, Tuple, Union

from openai._response import APIResponse
Expand Down Expand Up @@ -27,6 +28,35 @@


class OpenaiExtractor(BaseExtractor):
@staticmethod
def is_stream(**kwargs) -> bool:
return bool(kwargs.get("stream"))

@staticmethod
def extract_req_tokens(**kwargs) -> Optional[str]:
"""
NOTE: only for completion and chat completion so far.
TODO(ynanbohan): better way to extract req tokens
"""
if kwargs.get("prompt"):
prompt = kwargs["prompt"]
if isinstance(prompt, str):
return prompt
elif isinstance(prompt, list) and all(isinstance(p, str) for p in prompt):
return " ".join(prompt)
else:
logger.warning(f"Failed to extract req tokens from {prompt=}")
return None
elif kwargs.get("messages"):
try:
return json.dumps(kwargs["messages"])
except Exception as e:
logger.warning(f"Failed to extract req tokens from {kwargs=}: {e}")
return None
else:
logger.warning(f"Failed to extract req tokens from {kwargs=}")
return None

@staticmethod
def update_trace_info(kwargs: Dict[str, Any], trace_id: str, span_id: str):
attrs = {_X_TRACE_ID_KEY: trace_id, _X_SPAN_ID_KEY: span_id}
Expand Down
41 changes: 40 additions & 1 deletion src/greptimeai/patcher/openai_patcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from greptimeai.collector import Collector
from greptimeai.extractor import Extraction
from greptimeai.extractor.openai_extractor import OpenaiExtractor
from greptimeai.labels import _MODEL_LABEL, _SPAN_NAME_LABEL
from greptimeai.labels import (
_MODEL_LABEL,
_SPAN_NAME_LABEL,
_PROMPT_COST_LABEl,
_PROMPT_TOKENS_LABEl,
)
from greptimeai.patchee import Patchee
from greptimeai.patchee.openai_patchee import OpenaiPatchees
from greptimeai.patchee.openai_patchee.audio import AudioPatchees
Expand All @@ -21,6 +26,10 @@
from greptimeai.patchee.openai_patchee.moderation import ModerationPatchees
from greptimeai.patcher import Patcher
from greptimeai.patcher.openai_patcher.stream import AsyncStream_, Stream_
from greptimeai.utils.openai.token import (
get_openai_token_cost_for_model,
num_tokens_from_messages,
)


class _OpenaiPatcher(Patcher):
Expand All @@ -40,6 +49,25 @@ def __init__(

self.patchees = patchees

def _collect_req_metrics_for_stream(
self, model_name: Optional[str], span_name: str, tokens: Optional[str]
):
model_name = model_name or ""
attrs = {
_SPAN_NAME_LABEL: span_name,
_MODEL_LABEL: model_name,
}

num = num_tokens_from_messages(tokens or "")
cost = get_openai_token_cost_for_model(model_name, num)

span_attrs = {
_PROMPT_TOKENS_LABEl: num,
_PROMPT_COST_LABEl: cost,
}

self.collector.collect_metrics(span_attrs=span_attrs, attrs=attrs)

def _pre_patch(
self,
span_name: str,
Expand All @@ -53,6 +81,17 @@ def _pre_patch(
event_attrs=extraction.event_attributes,
)
OpenaiExtractor.update_trace_info(kwargs, trace_id, span_id)

# if stream, the usage won't be included in the resp,
# so we need to extract and collect it from req for best.
if OpenaiExtractor.is_stream(**kwargs):
tokens = OpenaiExtractor.extract_req_tokens(**kwargs)
self._collect_req_metrics_for_stream(
model_name=extraction.get_model_name(),
span_name=span_name,
tokens=tokens,
)

start = time.time()
return (extraction, span_id, start, kwargs)

Expand Down

0 comments on commit 927502a

Please sign in to comment.