Skip to content

Commit

Permalink
fix: fix empty outputs in langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanbohan committed Dec 25, 2023
1 parent b740f0a commit a0fc340
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/greptimeai/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def get_span_context(

context_list = self._spans.get(key, [])
if len(context_list) == 0:
logger.warning(f"get_span_context: { key } not found for { span_name= }")
logger.debug(f"get_span_context: { key } not found for { span_name= }")
return None

if not span_name:
Expand Down Expand Up @@ -232,7 +232,7 @@ def pop_span_context(

context_list = self._spans.get(key, [])
if len(context_list) == 0:
logger.warning(f"pop_span_context: { key } not found for { span_name= }")
logger.debug(f"pop_span_context: { key } not found for { span_name= }")
return None

target_context = None
Expand Down
67 changes: 47 additions & 20 deletions src/greptimeai/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage

from greptimeai import logger

Check failure on line 7 in src/greptimeai/langchain/__init__.py

View workflow job for this annotation

GitHub Actions / ci (3.8)

Ruff (F401)

src/greptimeai/langchain/__init__.py:7:24: F401 `greptimeai.logger` imported but unused

_SPAN_NAME_CHAIN = "langchain_chain"
_SPAN_NAME_AGENT = "langchain_agent"
_SPAN_NAME_LLM = "langchain_llm"
Expand Down Expand Up @@ -74,27 +76,26 @@ def _parse_output(raw_output: dict) -> Any:
)


def _parse_generation(gen: Generation) -> Optional[Dict[str, Any]]:
"""
Generation, or ChatGeneration (which contains message field)
"""
if not gen:
return None

info = gen.generation_info or {}
attrs = {
"text": gen.text,
# the following is OpenAI only?
"finish_reason": info.get("finish_reason"),
"log_probability": info.get("logprobs"),
}
def _str_generations(gens: Sequence[Generation]) -> str:
def _str_generation(gen: Generation) -> Optional[str]:
"""
Generation, or ChatGeneration (which contains message field)
"""
if not gen:
return None

if isinstance(gen, ChatGeneration):
message: BaseMessage = gen.message
attrs["additional_kwargs"] = message.additional_kwargs
attrs["type"] = message.type
info = gen.generation_info or {}
reason = info.get("finish_reason")
if reason in ["function_call", "tool_calls"] and isinstance(
gen, ChatGeneration
):
kwargs = gen.message.additional_kwargs
return f"{reason}: kwargs={kwargs}"
else:
return gen.text

return attrs
texts = list(filter(None, [_str_generation(gen) for gen in gens]))
return "\n".join(texts)


def _parse_generations(
Expand All @@ -103,8 +104,33 @@ def _parse_generations(
"""
parse LLMResult.generations[0] to structured fields
"""

def _parse_generation(gen: Generation) -> Optional[Dict[str, Any]]:
"""
Generation, or ChatGeneration (which contains message field)
"""
if not gen:
return None

gen.to_json()

info = gen.generation_info or {}
attrs = {
"text": gen.text,
# the following is OpenAI only?
"finish_reason": info.get("finish_reason"),
"log_probability": info.get("logprobs"),
}

if isinstance(gen, ChatGeneration):
message: BaseMessage = gen.message
attrs["additional_kwargs"] = message.additional_kwargs
attrs["type"] = message.type

return attrs

if gens and len(gens) > 0:
return list(filter(None, [_parse_generation(gen) for gen in gens if gen]))
return list(filter(None, [_parse_generation(gen) for gen in gens]))

return None

Expand All @@ -124,3 +150,4 @@ def _parse_doc(doc: Document) -> Dict[str, Any]:
return [_parse_doc(doc) for doc in docs]

return None
return None
4 changes: 2 additions & 2 deletions src/greptimeai/langchain/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
_parse_generations,
_parse_input,
_parse_output,
_str_generations,
)


Expand Down Expand Up @@ -271,8 +272,7 @@ def on_llm_end(
if response and len(response.generations) > 0
else []
)
texts = [generation.text for generation in generations]
outputs = " ".join(texts)
outputs = _str_generations(generations)

output = response.llm_output or {}
model_name: Optional[str] = output.get("model_name")
Expand Down

0 comments on commit a0fc340

Please sign in to comment.