Commit

fix: model_dump
yuanbohan committed Dec 21, 2023
1 parent 54e0095 commit d0aa8c2
Showing 2 changed files with 16 additions and 26 deletions.
src/greptimeai/patcher/openai_patcher/retry.py (1 addition, 2 deletions)
@@ -2,7 +2,6 @@
 from typing import Callable, Optional, Union
 
 from openai import AsyncOpenAI, OpenAI
-from pydantic import BaseModel
 from typing_extensions import override
 
 from greptimeai import logger
@@ -26,7 +25,7 @@ def __init__(
         super().__init__(collector=collector, patchees=patchees, client=client)
 
     def _add_retry_event(self, *args):
-        if len(args) > 0 and isinstance(args[0], BaseModel):
+        if len(args) > 0 and hasattr(args[0], "model_dump"):
            dict = args[0].model_dump(exclude_unset=True)
            span_id = dict.get("headers", {}).get(_EXTRA_HEADERS_X_SPAN_ID_KEY)
            if span_id:
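
The retry patch now duck-types its argument instead of pinning it to the locally imported pydantic BaseModel: anything exposing a model_dump method passes, even an object built against a different pydantic installation or a vendored copy, which an isinstance check would reject. A minimal sketch of the difference (DumpableResponse is a hypothetical class, not from this repository):

# Hypothetical class illustrating why hasattr succeeds where isinstance fails.
from pydantic import BaseModel


class DumpableResponse:
    """Exposes model_dump without subclassing pydantic.BaseModel."""

    def __init__(self, headers: dict):
        self._headers = headers

    def model_dump(self, exclude_unset: bool = False) -> dict:
        return {"headers": dict(self._headers)}


resp = DumpableResponse({"x-span-id": "abc123"})
print(isinstance(resp, BaseModel))  # False: the old check skips this object
print(hasattr(resp, "model_dump"))  # True: the new check accepts it
print(resp.model_dump(exclude_unset=True))  # {'headers': {'x-span-id': 'abc123'}}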
src/greptimeai/patcher/openai_patcher/stream.py (15 additions, 24 deletions)
@@ -19,18 +19,6 @@
 )
 
 
-def _extract_resp(resp: Any) -> Dict[str, Any]:
-    try:
-        if hasattr(resp, "model_dump"):
-            return resp.model_dump()
-        else:
-            logger.warning(f"Unknown response stream type: {type(resp)}")
-            return {}
-    except Exception as e:
-        logger.error(f"Failed to extract response: {e}")
-        return {}
-
-
 def _extract_tokens(chunk: BaseModel) -> str:
     if not chunk:
         return ""
@@ -40,25 +28,28 @@ def _extract_tokens(chunk: BaseModel) -> str:
         dict_ = chunk.model_dump()
         choices = dict_.get("choices", [])
         for choice in choices:
-            content = choice.model_dump().get("delta", {}).get("content")
-            if content:
+            content = choice.get("delta", {}).get("content")
+            if content: # chat completion
                 tokens += content
-            else:
-                text = choice.model_dump().get("text")
-                if text:
-                    tokens += text
+            elif choice.get("text"): # completion
+                tokens += choice.get("text")
     except Exception as e:
         logger.error(f"Failed to extract chunk tokens: {e}")
 
     return tokens
 
 
-def _collect_resp(
-    resp: Any, collector: Collector, span_id: str, event_name: str
+def _collect_stream_item(
+    item: Any, collector: Collector, span_id: str, event_name: str
 ) -> Tuple[str, str]:
-    event_attrs = _extract_resp(resp)
+    event_attrs = {}
+    if hasattr(item, "model_dump"):
+        event_attrs = item.model_dump()
+    else:
+        logger.warning(f"Unknown response stream type: {type(item)}")
+
     model_name = event_attrs.get("model", "")
-    tokens = _extract_tokens(resp)
+    tokens = _extract_tokens(item)
     collector._collector.add_span_event(
         span_id=span_id,
         event_name=event_name,
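
This hunk holds the actual bug behind the commit title: chunk.model_dump() already serializes nested objects, so each choice is a plain dict, and the old choice.model_dump() call raised AttributeError, which the except clause swallowed while token counts stayed empty. The renamed _collect_stream_item also inlines the duck-typed hasattr check, retiring the _extract_resp helper. A minimal sketch with assumed chunk shapes (Delta/Choice/Chunk are illustrative models, not the SDK's real classes):

# Illustrative pydantic models; real OpenAI stream chunks carry more fields.
from typing import List, Optional

from pydantic import BaseModel


class Delta(BaseModel):
    content: Optional[str] = None


class Choice(BaseModel):
    delta: Optional[Delta] = None
    text: Optional[str] = None


class Chunk(BaseModel):
    model: str
    choices: List[Choice]


chunk = Chunk(model="gpt-3.5-turbo", choices=[Choice(delta=Delta(content="hi"))])
choices = chunk.model_dump().get("choices", [])

print(type(choices[0]))  # <class 'dict'>: already dumped, no model_dump method
print(choices[0].get("delta", {}).get("content"))  # 'hi': the fixed access path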
@@ -128,7 +119,7 @@ def __iter__(self) -> Iterator[Any]:
         for item in self.stream:
             yield item
 
-            tokens, model_name = _collect_resp(
+            tokens, model_name = _collect_stream_item(
                 item, self.collector, self.span_id, "stream"
             )
             completion_tokens += tokens
@@ -168,7 +159,7 @@ async def __aiter__(self) -> AsyncIterator[Any]:
         async for item in self.astream:
             yield item
 
-            tokens, model_name = _collect_resp(
+            tokens, model_name = _collect_stream_item(
                 item, self.collector, self.span_id, "stream"
             )
             completion_tokens += tokens
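
Both call sites follow the same wrapper pattern: re-yield each chunk so the consumer keeps streaming, then collect tokens from it as a side effect after control returns. A sketch of that pattern (ObservedStream and the dict chunks are assumptions for illustration, not greptimeai's classes):

# Minimal generator wrapper: forward each item, then record it.
from typing import Any, Iterator


class ObservedStream:
    def __init__(self, stream: Iterator[Any]):
        self.stream = stream
        self.completion_tokens = ""

    def __iter__(self) -> Iterator[Any]:
        for item in self.stream:
            yield item
            # runs after the consumer has processed the item
            self.completion_tokens += item.get("text", "")


chunks = [{"text": "he"}, {"text": "llo"}]
observed = ObservedStream(iter(chunks))
for chunk in observed:
    print(chunk["text"], end="")  # prints 'hello' as it streams
print()
print(observed.completion_tokens)  # 'hello', accumulated alongside the caller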
