diff --git a/src/greptimeai/patcher/openai_patcher/retry.py b/src/greptimeai/patcher/openai_patcher/retry.py index e173d00..02f1db2 100644 --- a/src/greptimeai/patcher/openai_patcher/retry.py +++ b/src/greptimeai/patcher/openai_patcher/retry.py @@ -2,7 +2,6 @@ from typing import Callable, Optional, Union from openai import AsyncOpenAI, OpenAI -from pydantic import BaseModel from typing_extensions import override from greptimeai import logger @@ -26,7 +25,7 @@ def __init__( super().__init__(collector=collector, patchees=patchees, client=client) def _add_retry_event(self, *args): - if len(args) > 0 and isinstance(args[0], BaseModel): + if len(args) > 0 and hasattr(args[0], "model_dump"): dict = args[0].model_dump(exclude_unset=True) span_id = dict.get("headers", {}).get(_EXTRA_HEADERS_X_SPAN_ID_KEY) if span_id: diff --git a/src/greptimeai/patcher/openai_patcher/stream.py b/src/greptimeai/patcher/openai_patcher/stream.py index 959299b..517191b 100644 --- a/src/greptimeai/patcher/openai_patcher/stream.py +++ b/src/greptimeai/patcher/openai_patcher/stream.py @@ -19,18 +19,6 @@ ) -def _extract_resp(resp: Any) -> Dict[str, Any]: - try: - if hasattr(resp, "model_dump"): - return resp.model_dump() - else: - logger.warning(f"Unknown response stream type: {type(resp)}") - return {} - except Exception as e: - logger.error(f"Failed to extract response: {e}") - return {} - - def _extract_tokens(chunk: BaseModel) -> str: if not chunk: return "" @@ -40,25 +28,28 @@ def _extract_tokens(chunk: BaseModel) -> str: dict_ = chunk.model_dump() choices = dict_.get("choices", []) for choice in choices: - content = choice.model_dump().get("delta", {}).get("content") - if content: + content = choice.get("delta", {}).get("content") + if content: # chat completion tokens += content - else: - text = choice.model_dump().get("text") - if text: - tokens += text + elif choice.get("text"): # completion + tokens += choice.get("text") except Exception as e: logger.error(f"Failed to extract chunk tokens: {e}") return tokens -def _collect_resp( - resp: Any, collector: Collector, span_id: str, event_name: str +def _collect_stream_item( + item: Any, collector: Collector, span_id: str, event_name: str ) -> Tuple[str, str]: - event_attrs = _extract_resp(resp) + event_attrs = {} + if hasattr(item, "model_dump"): + event_attrs = item.model_dump() + else: + logger.warning(f"Unknown response stream type: {type(item)}") + model_name = event_attrs.get("model", "") - tokens = _extract_tokens(resp) + tokens = _extract_tokens(item) collector._collector.add_span_event( span_id=span_id, event_name=event_name, @@ -128,7 +119,7 @@ def __iter__(self) -> Iterator[Any]: for item in self.stream: yield item - tokens, model_name = _collect_resp( + tokens, model_name = _collect_stream_item( item, self.collector, self.span_id, "stream" ) completion_tokens += tokens @@ -168,7 +159,7 @@ async def __aiter__(self) -> AsyncIterator[Any]: async for item in self.astream: yield item - tokens, model_name = _collect_resp( + tokens, model_name = _collect_stream_item( item, self.collector, self.span_id, "stream" ) completion_tokens += tokens