Commit
Merge pull request #462 from JasonGuoo/main
Supporting Zhipu AI API
LarFii authored Dec 13, 2024
2 parents ae0c43b + e64cf50 commit 9cac3b0
Showing 3 changed files with 236 additions and 1 deletion.
Binary file modified .DS_Store
61 changes: 61 additions & 0 deletions examples/lightrag_zhipu_demo.py
@@ -0,0 +1,61 @@
import asyncio
import os
import inspect
import logging

from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.llm import zhipu_complete, zhipu_embedding
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
    raise Exception("Please set ZHIPUAI_API_KEY in your environment")

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=zhipu_complete,
    llm_model_name="glm-4-flashx",  # Best cost/performance model in the glm-4 series; change it here if needed.
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    embedding_func=EmbeddingFunc(
        embedding_dim=2048,  # Zhipu embedding-3 dimension
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts),
    ),
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
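
Both helpers also accept an explicit api_key argument, so the key read above can be passed through directly instead of being re-read from the environment. A minimal sketch of an alternative embedding_func, using the same imports and variables as the demo above:

    embedding_func = EmbeddingFunc(
        embedding_dim=2048,  # Zhipu embedding-3 dimension
        max_token_size=8192,
        func=lambda texts: zhipu_embedding(texts, api_key=api_key),
    )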
176 changes: 175 additions & 1 deletion lightrag/llm.py
@@ -2,9 +2,10 @@
import copy
import json
import os
import re
import struct
from functools import lru_cache
from typing import List, Dict, Callable, Any, Union
from typing import List, Dict, Callable, Any, Union, Optional
import aioboto3
import aiohttp
import numpy as np
@@ -596,6 +597,179 @@ async def ollama_model_complete(
)


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # Best cost/performance model in the glm-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: List[Dict[str, str]] = [],
    **kwargs,
) -> str:
    # Dynamically import ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # Falls back to the ZHIPUAI_API_KEY environment variable
        client = ZhipuAI()

    messages = []

    if not system_prompt:
        system_prompt = "You are a helpful assistant. Replace any sensitive words in the content with ***."

    # A system prompt is always present at this point (default or caller-supplied)
    messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    # Drop kwargs that the Zhipu client does not accept
    kwargs = {k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]}

    response = client.chat.completions.create(model=model, messages=messages, **kwargs)

    return response.choices[0].message.content
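
# Minimal usage sketch, assuming ZHIPUAI_API_KEY is exported and the default
# model applies:
#
#     import asyncio
#     answer = asyncio.run(
#         zhipu_complete_if_cache("Summarize LightRAG in one sentence.")
#     )
#     print(answer)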


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
    # keyword_extraction is captured by the named parameter, so **kwargs never
    # contains it; zhipu_complete_if_cache additionally strips it defensively.

    if keyword_extraction:
        # Add a system prompt that guides the model to return strict JSON
        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
        Please analyze the content and extract two types of keywords:
        1. High-level keywords: Important concepts and main themes
        2. Low-level keywords: Specific details and supporting elements
        Return your response in this exact JSON format:
        {
            "high_level_keywords": ["keyword1", "keyword2"],
            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
        }
        Only return the JSON, no other text."""

        # Combine with the existing system prompt, if any
        if system_prompt:
            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
        else:
            system_prompt = extraction_prompt

        try:
            response = await zhipu_complete_if_cache(
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                **kwargs,
            )

            # Try to parse the response directly as JSON
            try:
                data = json.loads(response)
                return GPTKeywordExtractionFormat(
                    high_level_keywords=data.get("high_level_keywords", []),
                    low_level_keywords=data.get("low_level_keywords", []),
                )
            except json.JSONDecodeError:
                # If direct parsing fails, try to extract a JSON object from the text
                match = re.search(r"\{[\s\S]*\}", response)
                if match:
                    try:
                        data = json.loads(match.group())
                        return GPTKeywordExtractionFormat(
                            high_level_keywords=data.get("high_level_keywords", []),
                            low_level_keywords=data.get("low_level_keywords", []),
                        )
                    except json.JSONDecodeError:
                        pass

                # If all parsing fails, log a warning and return an empty result
                logger.warning(f"Failed to parse keyword extraction response: {response}")
                return GPTKeywordExtractionFormat(
                    high_level_keywords=[], low_level_keywords=[]
                )
        except Exception as e:
            logger.error(f"Error during keyword extraction: {str(e)}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )
    else:
        # For non-keyword-extraction calls, return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )
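
# Minimal usage sketch for the extraction path (GPTKeywordExtractionFormat is
# the result model referenced above):
#
#     result = await zhipu_complete(
#         "LightRAG combines graph indexing with retrieval.",
#         keyword_extraction=True,
#     )
#     print(result.high_level_keywords, result.low_level_keywords)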


@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)  # embedding-3 is 2048-dimensional (see demo above)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def zhipu_embedding(
    texts: list[str],
    model: str = "embedding-3",
    api_key: Optional[str] = None,
    **kwargs,
) -> np.ndarray:

    # Dynamically import ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # Falls back to the ZHIPUAI_API_KEY environment variable
        client = ZhipuAI()

    # Accept a single string as well as a list
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            response = client.embeddings.create(model=model, input=[text], **kwargs)
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            raise Exception(f"Error calling Zhipu Embedding API: {str(e)}")

    return np.array(embeddings)
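
# Minimal usage sketch, assuming ZHIPUAI_API_KEY is set; embedding-3 returns
# 2048-dimensional vectors:
#
#     vectors = await zhipu_embedding(["hello", "world"])
#     print(vectors.shape)  # expected: (2, 2048)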


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
