diff --git a/.DS_Store b/.DS_Store
index 651e36ed..7489d923 100644
Binary files a/.DS_Store and b/.DS_Store differ
diff --git a/examples/lightrag_zhipu_demo.py b/examples/lightrag_zhipu_demo.py
new file mode 100644
index 00000000..bcade616
--- /dev/null
+++ b/examples/lightrag_zhipu_demo.py
@@ -0,0 +1,57 @@
+import os
+import logging
+
+from dotenv import load_dotenv
+
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import zhipu_complete, zhipu_embedding
+from lightrag.utils import EmbeddingFunc
+
+WORKING_DIR = "./dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+load_dotenv()
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+api_key = os.environ.get("ZHIPUAI_API_KEY")
+if api_key is None:
+    raise Exception("Please set ZHIPUAI_API_KEY in your environment")
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=zhipu_complete,
+    llm_model_name="glm-4-flashx",  # The best cost/performance model in the GLM-4 series; change it here if needed
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=2048,  # Zhipu embedding-3 dimension
+        max_token_size=8192,
+        func=lambda texts: zhipu_embedding(texts),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
\ No newline at end of file
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 636f03cb..591b5dc9 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -2,9 +2,10 @@
 import copy
 import json
 import os
+import re
 import struct
 from functools import lru_cache
-from typing import List, Dict, Callable, Any, Union
+from typing import List, Dict, Callable, Any, Union, Optional
 import aioboto3
 import aiohttp
 import numpy as np
@@ -596,6 +597,161 @@ async def ollama_model_complete(
     )
 
 
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def zhipu_complete_if_cache(
+    prompt: Union[str, List[Dict[str, str]]],
+    model: str = "glm-4-flashx",  # The best cost/performance model in the glm-4 series
+    api_key: Optional[str] = None,
+    system_prompt: Optional[str] = None,
+    history_messages: List[Dict[str, str]] = [],
+    **kwargs,
+) -> str:
+    # Dynamically import ZhipuAI so it stays an optional dependency
+    try:
+        from zhipuai import ZhipuAI
+    except ImportError:
+        raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
+
+    if api_key:
+        client = ZhipuAI(api_key=api_key)
+    else:
+        # Falls back to the ZHIPUAI_API_KEY environment variable
+        client = ZhipuAI()
+
+    if not system_prompt:
+        system_prompt = "You are a helpful assistant. Replace sensitive words in the content with ***."
+
+    messages = [{"role": "system", "content": system_prompt}]
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    # Debug logging
+    logger.debug("===== Query Input to LLM =====")
+    logger.debug(f"Query: {prompt}")
+    logger.debug(f"System prompt: {system_prompt}")
+
+    # Drop kwargs the Zhipu chat API does not accept
+    kwargs = {
+        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
+    }
+
+    response = client.chat.completions.create(model=model, messages=messages, **kwargs)
+
+    return response.choices[0].message.content
+
+
+async def zhipu_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+):
+    if keyword_extraction:
+        # Add a system prompt to guide the model to return JSON format
+        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
+        Please analyze the content and extract two types of keywords:
+        1. High-level keywords: Important concepts and main themes
+        2. Low-level keywords: Specific details and supporting elements
+
+        Return your response in this exact JSON format:
+        {
+            "high_level_keywords": ["keyword1", "keyword2"],
+            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
+        }
+
+        Only return the JSON, no other text."""
+
+        # Combine with the existing system prompt, if any
+        if system_prompt:
+            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
+        else:
+            system_prompt = extraction_prompt
+
+        try:
+            response = await zhipu_complete_if_cache(
+                prompt=prompt,
+                system_prompt=system_prompt,
+                history_messages=history_messages,
+                **kwargs,
+            )
+
+            # Try to parse the whole response as JSON first
+            try:
+                data = json.loads(response)
+                return GPTKeywordExtractionFormat(
+                    high_level_keywords=data.get("high_level_keywords", []),
+                    low_level_keywords=data.get("low_level_keywords", []),
+                )
+            except json.JSONDecodeError:
+                # Fall back to extracting the first JSON object embedded in the text
+                match = re.search(r"\{[\s\S]*\}", response)
+                if match:
+                    try:
+                        data = json.loads(match.group())
+                        return GPTKeywordExtractionFormat(
+                            high_level_keywords=data.get("high_level_keywords", []),
+                            low_level_keywords=data.get("low_level_keywords", []),
+                        )
+                    except json.JSONDecodeError:
+                        pass
+
+            # If all parsing fails, log a warning and return an empty result
+            logger.warning(f"Failed to parse keyword extraction response: {response}")
+            return GPTKeywordExtractionFormat(
+                high_level_keywords=[], low_level_keywords=[]
+            )
+        except Exception as e:
+            logger.error(f"Error during keyword extraction: {str(e)}")
+            return GPTKeywordExtractionFormat(
+                high_level_keywords=[], low_level_keywords=[]
+            )
+    else:
+        # For plain completion, return the raw response string
+        return await zhipu_complete_if_cache(
+            prompt=prompt,
+            system_prompt=system_prompt,
+            history_messages=history_messages,
+            **kwargs,
+        )
+
+
+@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=60),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def zhipu_embedding(
+    texts: list[str], model: str = "embedding-3", api_key: Optional[str] = None, **kwargs
+) -> np.ndarray:
+    # Dynamically import ZhipuAI so it stays an optional dependency
+    try:
+        from zhipuai import ZhipuAI
+    except ImportError:
ImportError("Please install zhipuai before initialize zhipuai backend.") + if api_key: + client = ZhipuAI(api_key=api_key) + else: + # please set ZHIPUAI_API_KEY in your environment + # os.environ["ZHIPUAI_API_KEY"] + client = ZhipuAI() + + # Convert single text to list if needed + if isinstance(texts, str): + texts = [texts] + + embeddings = [] + for text in texts: + try: + response = client.embeddings.create( + model=model, + input=[text], + **kwargs + ) + embeddings.append(response.data[0].embedding) + except Exception as e: + raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") + + return np.array(embeddings) + + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( stop=stop_after_attempt(3),