diff --git a/examples/lightrag_zhipu_demo.py b/examples/lightrag_zhipu_demo.py
index bcade616..0924656d 100644
--- a/examples/lightrag_zhipu_demo.py
+++ b/examples/lightrag_zhipu_demo.py
@@ -1,9 +1,6 @@
-import asyncio
 import os
-import inspect
 import logging
-from dotenv import load_dotenv
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import zhipu_complete, zhipu_embedding
 
@@ -21,7 +18,6 @@
     raise Exception("Please set ZHIPU_API_KEY in your environment")
 
-
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=zhipu_complete,
@@ -31,9 +27,7 @@
     embedding_func=EmbeddingFunc(
         embedding_dim=2048,  # Zhipu embedding-3 dimension
         max_token_size=8192,
-        func=lambda texts: zhipu_embedding(
-            texts
-        ),
+        func=lambda texts: zhipu_embedding(texts),
     ),
 )
 
@@ -58,4 +52,4 @@
 # Perform hybrid search
 print(
     rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
-)
\ No newline at end of file
+)
diff --git a/lightrag/__init__.py b/lightrag/__init__.py
index 1c5cd617..85ceb3d2 100644
--- a/lightrag/__init__.py
+++ b/lightrag/__init__.py
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.0.5"
+__version__ = "1.0.6"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py
index fe046eb4..bf20ffd7 100644
--- a/lightrag/kg/milvus_impl.py
+++ b/lightrag/kg/milvus_impl.py
@@ -63,7 +63,9 @@ async def wrapped_task(batch):
             return result
 
         embedding_tasks = [wrapped_task(batch) for batch in batches]
-        pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
         embeddings_list = await asyncio.gather(*embedding_tasks)
         embeddings = np.concatenate(embeddings_list)
diff --git a/lightrag/llm.py b/lightrag/llm.py
index 591b5dc9..e89af0d8 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -604,11 +604,11 @@ async def ollama_model_complete(
 )
 async def zhipu_complete_if_cache(
     prompt: Union[str, List[Dict[str, str]]],
-    model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
+    model: str = "glm-4-flashx",  # The most cost/performance balance model in glm-4 series
     api_key: Optional[str] = None,
     system_prompt: Optional[str] = None,
     history_messages: List[Dict[str, str]] = [],
-    **kwargs
+    **kwargs,
 ) -> str:
     # dynamically load ZhipuAI
     try:
@@ -640,13 +640,11 @@ async def zhipu_complete_if_cache(
     logger.debug(f"System prompt: {system_prompt}")
 
     # Remove unsupported kwargs
-    kwargs = {k: v for k, v in kwargs.items() if k not in ['hashing_kv', 'keyword_extraction']}
+    kwargs = {
+        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
+    }
 
-    response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        **kwargs
-    )
+    response = client.chat.completions.create(model=model, messages=messages, **kwargs)
 
     return response.choices[0].message.content
@@ -663,13 +661,13 @@ async def zhipu_complete(
         Please analyze the content and extract two types of keywords:
         1. High-level keywords: Important concepts and main themes
         2. Low-level keywords: Specific details and supporting elements
-
+
         Return your response in this exact JSON format:
         {
             "high_level_keywords": ["keyword1", "keyword2"],
             "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
         }
-
+
         Only return the JSON, no other text."""
 
         # Combine with existing system prompt if any
@@ -683,15 +681,15 @@
             prompt=prompt,
             system_prompt=system_prompt,
            history_messages=history_messages,
-            **kwargs
+            **kwargs,
         )
-
+
         # Try to parse as JSON
         try:
             data = json.loads(response)
             return GPTKeywordExtractionFormat(
                 high_level_keywords=data.get("high_level_keywords", []),
-                low_level_keywords=data.get("low_level_keywords", [])
+                low_level_keywords=data.get("low_level_keywords", []),
             )
         except json.JSONDecodeError:
             # If direct JSON parsing fails, try to extract JSON from text
@@ -701,13 +699,15 @@ async def zhipu_complete(
                     data = json.loads(match.group())
                     return GPTKeywordExtractionFormat(
                         high_level_keywords=data.get("high_level_keywords", []),
-                        low_level_keywords=data.get("low_level_keywords", [])
+                        low_level_keywords=data.get("low_level_keywords", []),
                     )
                 except json.JSONDecodeError:
                     pass
-
+
             # If all parsing fails, log warning and return empty format
-            logger.warning(f"Failed to parse keyword extraction response: {response}")
+            logger.warning(
+                f"Failed to parse keyword extraction response: {response}"
+            )
             return GPTKeywordExtractionFormat(
                 high_level_keywords=[], low_level_keywords=[]
             )
@@ -722,7 +722,7 @@ async def zhipu_complete(
             prompt=prompt,
             system_prompt=system_prompt,
             history_messages=history_messages,
-            **kwargs
+            **kwargs,
         )
@@ -733,13 +733,9 @@
     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
 )
 async def zhipu_embedding(
-    texts: list[str],
-    model: str = "embedding-3",
-    api_key: str = None,
-    **kwargs
+    texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
 ) -> np.ndarray:
-
-# dynamically load ZhipuAI
+    # dynamically load ZhipuAI
     try:
         from zhipuai import ZhipuAI
     except ImportError:
@@ -758,11 +754,7 @@ async def zhipu_embedding(
     embeddings = []
     for text in texts:
         try:
-            response = client.embeddings.create(
-                model=model,
-                input=[text],
-                **kwargs
-            )
+            response = client.embeddings.create(model=model, input=[text], **kwargs)
             embeddings.append(response.data[0].embedding)
         except Exception as e:
             raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
diff --git a/lightrag/storage.py b/lightrag/storage.py
index 037a9c2f..0c880bb7 100644
--- a/lightrag/storage.py
+++ b/lightrag/storage.py
@@ -103,7 +103,9 @@ async def wrapped_task(batch):
             return result
 
         embedding_tasks = [wrapped_task(batch) for batch in batches]
-        pbar = tqdm_async(total=len(embedding_tasks), desc="Generating embeddings", unit="batch")
+        pbar = tqdm_async(
+            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
+        )
         embeddings_list = await asyncio.gather(*embedding_tasks)
         embeddings = np.concatenate(embeddings_list)
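
Note on the zhipu_complete hunks above: apart from the formatter noise (wrapped long calls, trailing commas, double quotes, stripped trailing whitespace) and the 1.0.5 -> 1.0.6 version bump, the substantive logic in this patch is the keyword-extraction fallback. The sketch below is a minimal, dependency-free restatement of that parse-then-fallback pattern, with a plain dict standing in for GPTKeywordExtractionFormat; extract_keywords and its regex are illustrative assumptions for this note, not lightrag API.

import json
import re

def extract_keywords(response: str) -> dict:
    # First attempt: the model obeyed "Only return the JSON, no other text."
    try:
        data = json.loads(response)
    except json.JSONDecodeError:
        # Fallback: pull the first {...} span out of any surrounding prose.
        match = re.search(r"\{[\s\S]*\}", response)
        try:
            data = json.loads(match.group()) if match else {}
        except json.JSONDecodeError:
            data = {}
    if not isinstance(data, dict):
        data = {}  # the model returned valid JSON that isn't an object
    return {
        "high_level_keywords": data.get("high_level_keywords", []),
        "low_level_keywords": data.get("low_level_keywords", []),
    }

# Both calls print {'high_level_keywords': ['love'], 'low_level_keywords': []}:
print(extract_keywords('{"high_level_keywords": ["love"]}'))
print(extract_keywords('Sure! {"high_level_keywords": ["love"]} Hope that helps.'))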