From 584258078f542fc2c4c4ff1ade8cb5342d80989e Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 14:29:16 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E9=87=8D=E6=9E=84=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 提取通用缓存处理逻辑到新函数 handle_cache 和 save_to_cache - 使用 CacheData 类统一缓存数据结构 - 优化嵌入式缓存和常规缓存的处理流程 - 添加模式参数以支持不同查询模式的缓存策略 - 重构 get_best_cached_response 函数,提高缓存查询效率 --- lightrag/llm.py | 496 ++++++++++++++++++++------------------------ lightrag/operate.py | 6 +- lightrag/utils.py | 84 ++++---- 3 files changed, 277 insertions(+), 309 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index fef8c9a3..89d74a5b 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,7 +4,8 @@ import os import struct from functools import lru_cache -from typing import List, Dict, Callable, Any +from typing import List, Dict, Callable, Any, Optional +from dataclasses import dataclass import aioboto3 import aiohttp @@ -59,39 +60,21 @@ async def openai_complete_if_cache( openai_async_client = ( AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url) ) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - # Calculate args_hash only when using cache - args_hash = compute_args_hash(model, messages) - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + kwargs.get("hashing_kv"), args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response if "response_format" in kwargs: response = await openai_async_client.beta.chat.completions.parse( @@ -105,24 +88,21 @@ async def openai_complete_if_cache( if r"\u" in content: content = content.encode("utf-8").decode("unicode_escape") - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": content, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, - } - } - ) + # Save to cache + await save_to_cache( + kwargs.get("hashing_kv"), + CacheData( + args_hash=args_hash, + 
content=content, + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) + return content @@ -155,6 +135,8 @@ async def azure_openai_complete_if_cache( ) hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + mode = kwargs.pop("mode", "default") + messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) @@ -162,56 +144,35 @@ async def azure_openai_complete_if_cache( if prompt is not None: messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - # Calculate args_hash only when using cache - args_hash = compute_args_hash(model, messages) - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response response = await openai_async_client.chat.completions.create( model=model, messages=messages, **kwargs ) + content = response.choices[0].message.content + + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=content, + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": response.choices[0].message.content, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, - } - } - ) - return response.choices[0].message.content + return content class BedrockError(Exception): @@ -253,6 +214,15 @@ async def bedrock_complete_if_cache( # Add user prompt messages.append({"role": "user", "content": [{"text": prompt}]}) + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + kwargs.get("hashing_kv"), args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response + # Initialize Converse API arguments args = {"modelId": model, "messages": messages} @@ -275,33 +245,14 @@ async def bedrock_complete_if_cache( kwargs.pop(param) ) - hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) - if hashing_kv is not None: - # Calculate args_hash only when using cache - args_hash = compute_args_hash(model, messages) - - 
# Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + kwargs.get("hashing_kv"), args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response # Call model via Converse API session = aioboto3.Session() @@ -311,30 +262,22 @@ async def bedrock_complete_if_cache( except Exception as e: raise BedrockError(e) - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": response["output"]["message"]["content"][0]["text"], - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val - if is_embedding_cache_enabled - else None, - "embedding_max": max_val - if is_embedding_cache_enabled - else None, - "original_prompt": prompt, - } - } - ) + # Save to cache + await save_to_cache( + kwargs.get("hashing_kv"), + CacheData( + args_hash=args_hash, + content=response["output"]["message"]["content"][0]["text"], + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) - return response["output"]["message"]["content"][0]["text"] + return response["output"]["message"]["content"][0]["text"] @lru_cache(maxsize=1) @@ -372,32 +315,14 @@ async def hf_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - # Calculate args_hash only when using cache - args_hash = compute_args_hash(model, messages) - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, 
messages) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response input_prompt = "" try: @@ -442,24 +367,22 @@ async def hf_model_if_cache( response_text = hf_tokenizer.decode( output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True ) - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": response_text, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, - } - } - ) + + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response_text, + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) + return response_text @@ -489,55 +412,34 @@ async def ollama_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - # Calculate args_hash only when using cache - args_hash = compute_args_hash(model, messages) - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response response = await ollama_client.chat(model=model, messages=messages, **kwargs) result = response["message"]["content"] - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": result, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, - } - } - ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=result, + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) + return result @@ -649,32 +551,14 @@ async def lmdeploy_model_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) - if hashing_kv is not None: - # Calculate args_hash only when using 
cache - args_hash = compute_args_hash(model, messages) - - # Get embedding cache configuration - embedding_cache_config = hashing_kv.global_config.get( - "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} - ) - is_embedding_cache_enabled = embedding_cache_config["enabled"] - if is_embedding_cache_enabled: - # Use embedding cache - embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] - current_embedding = await embedding_model_func([prompt]) - quantized, min_val, max_val = quantize_embedding(current_embedding[0]) - best_cached_response = await get_best_cached_response( - hashing_kv, - current_embedding[0], - similarity_threshold=embedding_cache_config["similarity_threshold"], - ) - if best_cached_response is not None: - return best_cached_response - else: - # Use regular cache - if_cache_return = await hashing_kv.get_by_id(args_hash) - if if_cache_return is not None: - return if_cache_return["return"] + # Handle cache + mode = kwargs.pop("mode", "default") + args_hash = compute_args_hash(model, messages) + cached_response, quantized, min_val, max_val = await handle_cache( + hashing_kv, args_hash, prompt, mode + ) + if cached_response is not None: + return cached_response gen_config = GenerationConfig( skip_special_tokens=skip_special_tokens, @@ -692,24 +576,21 @@ async def lmdeploy_model_if_cache( ): response += res.response - if hashing_kv is not None: - await hashing_kv.upsert( - { - args_hash: { - "return": response, - "model": model, - "embedding": quantized.tobytes().hex() - if is_embedding_cache_enabled - else None, - "embedding_shape": quantized.shape - if is_embedding_cache_enabled - else None, - "embedding_min": min_val if is_embedding_cache_enabled else None, - "embedding_max": max_val if is_embedding_cache_enabled else None, - "original_prompt": prompt, - } - } - ) + # Save to cache + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + model=model, + prompt=prompt, + quantized=quantized, + min_val=min_val, + max_val=max_val, + mode=mode, + ), + ) + return response @@ -1139,6 +1020,75 @@ async def llm_model_func( return await next_model.gen_func(**args) +async def handle_cache(hashing_kv, args_hash, prompt, mode="default"): + """Generic cache handling function""" + if hashing_kv is None: + return None, None, None, None + + # Get embedding cache configuration + embedding_cache_config = hashing_kv.global_config.get( + "embedding_cache_config", {"enabled": False, "similarity_threshold": 0.95} + ) + is_embedding_cache_enabled = embedding_cache_config["enabled"] + + quantized = min_val = max_val = None + if is_embedding_cache_enabled: + # Use embedding cache + embedding_model_func = hashing_kv.global_config["embedding_func"]["func"] + current_embedding = await embedding_model_func([prompt]) + quantized, min_val, max_val = quantize_embedding(current_embedding[0]) + best_cached_response = await get_best_cached_response( + hashing_kv, + current_embedding[0], + similarity_threshold=embedding_cache_config["similarity_threshold"], + mode=mode, + ) + if best_cached_response is not None: + return best_cached_response, None, None, None + else: + # Use regular cache + mode_cache = await hashing_kv.get_by_id(mode) or {} + if args_hash in mode_cache: + return mode_cache[args_hash]["return"], None, None, None + + return None, quantized, min_val, max_val + + +@dataclass +class CacheData: + args_hash: str + content: str + model: str + prompt: str + quantized: Optional[np.ndarray] = None + min_val: Optional[float] = 
None + max_val: Optional[float] = None + mode: str = "default" + + +async def save_to_cache(hashing_kv, cache_data: CacheData): + if hashing_kv is None: + return + + mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {} + + mode_cache[cache_data.args_hash] = { + "return": cache_data.content, + "model": cache_data.model, + "embedding": cache_data.quantized.tobytes().hex() + if cache_data.quantized is not None + else None, + "embedding_shape": cache_data.quantized.shape + if cache_data.quantized is not None + else None, + "embedding_min": cache_data.min_val, + "embedding_max": cache_data.max_val, + "original_prompt": cache_data.prompt, + } + + await hashing_kv.upsert({cache_data.mode: mode_cache}) + + if __name__ == "__main__": import asyncio diff --git a/lightrag/operate.py b/lightrag/operate.py index a846cfc5..5b911d34 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -474,7 +474,9 @@ async def kg_query( use_model_func = global_config["llm_model_func"] kw_prompt_temp = PROMPTS["keywords_extraction"] kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language) - result = await use_model_func(kw_prompt, keyword_extraction=True) + result = await use_model_func( + kw_prompt, keyword_extraction=True, mode=query_param.mode + ) logger.info("kw_prompt result:") print(result) try: @@ -534,6 +536,7 @@ async def kg_query( response = await use_model_func( query, system_prompt=sys_prompt, + mode=query_param.mode, ) if len(response) > len(sys_prompt): response = ( @@ -1035,6 +1038,7 @@ async def naive_query( response = await use_model_func( query, system_prompt=sys_prompt, + mode=query_param.mode, ) if len(response) > len(sys_prompt): diff --git a/lightrag/utils.py b/lightrag/utils.py index d080ee03..70ec4341 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -310,43 +310,57 @@ def process_combine_contexts(hl, ll): async def get_best_cached_response( - hashing_kv, current_embedding, similarity_threshold=0.95 -): - """Get the cached response with the highest similarity""" - try: - # Get all keys - all_keys = await hashing_kv.all_keys() - max_similarity = 0 - best_cached_response = None - - # Get cached data one by one - for key in all_keys: - cache_data = await hashing_kv.get_by_id(key) - if cache_data is None or "embedding" not in cache_data: - continue - - # Convert cached embedding list to ndarray - cached_quantized = np.frombuffer( - bytes.fromhex(cache_data["embedding"]), dtype=np.uint8 - ).reshape(cache_data["embedding_shape"]) - cached_embedding = dequantize_embedding( - cached_quantized, - cache_data["embedding_min"], - cache_data["embedding_max"], - ) - - similarity = cosine_similarity(current_embedding, cached_embedding) - if similarity > max_similarity: - max_similarity = similarity - best_cached_response = cache_data["return"] - - if max_similarity > similarity_threshold: - return best_cached_response + hashing_kv, + current_embedding, + similarity_threshold=0.95, + mode="default", +) -> Union[str, None]: + # Get mode-specific cache + mode_cache = await hashing_kv.get_by_id(mode) + if not mode_cache: return None - except Exception as e: - logger.warning(f"Error in get_best_cached_response: {e}") - return None + best_similarity = -1 + best_response = None + best_prompt = None + best_cache_id = None + + # Only iterate through cache entries for this mode + for cache_id, cache_data in mode_cache.items(): + if cache_data["embedding"] is None: + continue + + # Convert cached embedding list to ndarray + cached_quantized = np.frombuffer( + 
bytes.fromhex(cache_data["embedding"]), dtype=np.uint8 + ).reshape(cache_data["embedding_shape"]) + cached_embedding = dequantize_embedding( + cached_quantized, + cache_data["embedding_min"], + cache_data["embedding_max"], + ) + + similarity = cosine_similarity(current_embedding, cached_embedding) + if similarity > best_similarity: + best_similarity = similarity + best_response = cache_data["return"] + best_prompt = cache_data["original_prompt"] + best_cache_id = cache_id + + if best_similarity > similarity_threshold: + prompt_display = ( + best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt + ) + log_data = { + "event": "cache_hit", + "mode": mode, + "similarity": round(best_similarity, 4), + "cache_id": best_cache_id, + "original_prompt": prompt_display, + } + logger.info(json.dumps(log_data)) + return best_response + return None def cosine_similarity(v1, v2): From 558068f61171c4e691e1bb12808c7d1231ddc628 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 14:32:41 +0800 Subject: [PATCH 2/5] =?UTF-8?q?fix(utils):=20=E4=BF=AE=E5=A4=8D=20JSON=20?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E7=BC=96=E7=A0=81=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 在 json.dumps 中添加 ensure_ascii=False 参数,以支持非 ASCII 字符编码 -这个修改确保了包含中文等非 ASCII 字符的日志信息能够正确处理和显示 --- lightrag/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 70ec4341..4c8d7996 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -358,7 +358,7 @@ async def get_best_cached_response( "cache_id": best_cache_id, "original_prompt": prompt_display, } - logger.info(json.dumps(log_data)) + logger.info(json.dumps(log_data, ensure_ascii=False)) return best_response return None From 633fb55b5b888aaf38fee8d756622a3f1d00a370 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 15:09:50 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index dda65630..d147e416 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -4,8 +4,8 @@ import os import struct from functools import lru_cache -from typing import List, Dict, Callable, Any, Union - +from typing import List, Dict, Callable, Any, Union, Optional +from dataclasses import dataclass import aioboto3 import aiohttp import numpy as np From a1c4a036fd2187c5b463ba2c24815c39a918a835 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 15:23:18 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E7=A7=BB=E9=99=A4kwargs=E4=B8=AD=E7=9A=84h?= =?UTF-8?q?ashing=5Fkv=E5=8F=82=E6=95=B0=E5=8F=96=E4=B8=BA=E5=8F=98?= =?UTF-8?q?=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index d147e416..09e9fd74 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -73,11 +73,12 @@ async def openai_complete_if_cache( messages.extend(history_messages) messages.append({"role": "user", "content": prompt}) + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) # Handle cache mode = kwargs.pop("mode", "default") args_hash = compute_args_hash(model, messages) cached_response, quantized, min_val, max_val = await handle_cache( - kwargs.get("hashing_kv"), 
args_hash, prompt, mode + hashing_kv, args_hash, prompt, mode ) if cached_response is not None: return cached_response @@ -219,12 +220,12 @@ async def bedrock_complete_if_cache( # Add user prompt messages.append({"role": "user", "content": [{"text": prompt}]}) - + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) # Handle cache mode = kwargs.pop("mode", "default") args_hash = compute_args_hash(model, messages) cached_response, quantized, min_val, max_val = await handle_cache( - kwargs.get("hashing_kv"), args_hash, prompt, mode + hashing_kv, args_hash, prompt, mode ) if cached_response is not None: return cached_response @@ -250,12 +251,12 @@ async def bedrock_complete_if_cache( args["inferenceConfig"][inference_params_map.get(param, param)] = ( kwargs.pop(param) ) - + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) # Handle cache mode = kwargs.pop("mode", "default") args_hash = compute_args_hash(model, messages) cached_response, quantized, min_val, max_val = await handle_cache( - kwargs.get("hashing_kv"), args_hash, prompt, mode + hashing_kv, args_hash, prompt, mode ) if cached_response is not None: return cached_response From 6a010abb625d82af21cba2374a717f86f56b5c09 Mon Sep 17 00:00:00 2001 From: yuanxiaobin Date: Fri, 6 Dec 2024 15:35:09 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E7=A7=BB=E9=99=A4kwargs=E4=B8=AD=E7=9A=84h?= =?UTF-8?q?ashing=5Fkv=E5=8F=82=E6=95=B0=E5=8F=96=E4=B8=BA=E5=8F=98?= =?UTF-8?q?=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lightrag/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/llm.py b/lightrag/llm.py index 09e9fd74..63913c90 100644 --- a/lightrag/llm.py +++ b/lightrag/llm.py @@ -97,7 +97,7 @@ async def openai_complete_if_cache( # Save to cache await save_to_cache( - kwargs.get("hashing_kv"), + hashing_kv, CacheData( args_hash=args_hash, content=content, @@ -271,7 +271,7 @@ async def bedrock_complete_if_cache( # Save to cache await save_to_cache( - kwargs.get("hashing_kv"), + hashing_kv, CacheData( args_hash=args_hash, content=response["output"]["message"]["content"][0]["text"],
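
For reference, a minimal sketch of the calling pattern this series applies to every completion function (openai, azure, bedrock, hf, ollama, lmdeploy): look up the cache via handle_cache, call the model only on a miss, then persist the result with save_to_cache and a CacheData record. The helper names and CacheData fields come from the diff above; complete_with_cache and its call_llm parameter are hypothetical illustrations, and the import path for compute_args_hash is an assumption, not something this patch defines.

from lightrag.llm import CacheData, handle_cache, save_to_cache  # helpers introduced by this series
from lightrag.utils import compute_args_hash  # assumed import path


async def complete_with_cache(model, prompt, call_llm, hashing_kv=None, mode="default", **kwargs):
    """Sketch of the refactored flow; call_llm is any async provider request."""
    messages = [{"role": "user", "content": prompt}]

    # 1) Cache lookup: embedding-similarity cache when enabled via
    #    global_config["embedding_cache_config"], otherwise the exact
    #    args-hash entry stored under the mode-specific bucket.
    args_hash = compute_args_hash(model, messages)
    cached_response, quantized, min_val, max_val = await handle_cache(
        hashing_kv, args_hash, prompt, mode
    )
    if cached_response is not None:
        return cached_response

    # 2) Cache miss: issue the real request (placeholder for the
    #    provider-specific call made inside each *_complete_if_cache).
    content = await call_llm(model, messages, **kwargs)

    # 3) Persist the answer under hashing_kv[mode][args_hash], including the
    #    quantized prompt embedding when the embedding cache is enabled.
    await save_to_cache(
        hashing_kv,
        CacheData(
            args_hash=args_hash,
            content=content,
            model=model,
            prompt=prompt,
            quantized=quantized,
            min_val=min_val,
            max_val=max_val,
            mode=mode,
        ),
    )
    return content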