Commit

update Step_3.py and openai compatible script

russellkim committed Oct 17, 2024
1 parent a2f1654 commit 70dbca1
Showing 3 changed files with 167 additions and 2 deletions.
66 changes: 66 additions & 0 deletions reproduce/Step_1_openai_compatible.py
@@ -0,0 +1,66 @@
import os
import json
import time
import numpy as np

from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.llm import openai_complete_if_cache, openai_embedding

## For Upstage API
# Make sure embedding_dim=4096 is set in lightrag.py and llm.py in the lightrag directory.
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "solar-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
        **kwargs
    )

async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar"
    )
## /For Upstage API
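
# Optional sanity check (a sketch, not part of the original script): verify
# that the dimension returned by the Upstage embedding endpoint matches the
# embedding_dim=4096 configured for LightRAG below. The probe text is arbitrary.
# import asyncio
# probe = asyncio.run(embedding_func(["dimension probe"]))
# assert probe.shape[1] == 4096, f"unexpected embedding dim: {probe.shape[1]}"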

def insert_text(rag, file_path):
    with open(file_path, mode='r') as f:
        unique_contexts = json.load(f)

    # Retry the bulk insert a few times; transient API errors are common.
    retries = 0
    max_retries = 3
    while retries < max_retries:
        try:
            rag.insert(unique_contexts)
            break
        except Exception as e:
            retries += 1
            print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
            time.sleep(10)
    if retries == max_retries:
        print("Insertion failed after exceeding the maximum number of retries")

cls = "mix"
WORKING_DIR = f"../{cls}"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=4096,
        max_token_size=8192,
        func=embedding_func
    )
)

insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
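
# Usage sketch (assumptions: run from the reproduce/ directory so the relative
# ../mix and ../datasets paths resolve, with the Upstage key exported):
#   UPSTAGE_API_KEY=... python Step_1_openai_compatible.py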
4 changes: 2 additions & 2 deletions reproduce/Step_3.py
@@ -53,10 +53,10 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 if __name__ == "__main__":
     cls = "agriculture"
     mode = "hybrid"
-    WORKING_DIR = "../{cls}"
+    WORKING_DIR = f"../{cls}"
 
     rag = LightRAG(working_dir=WORKING_DIR)
     query_param = QueryParam(mode=mode)
 
     queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json")
+    run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
99 changes: 99 additions & 0 deletions reproduce/Step_3_openai_compatible.py
@@ -0,0 +1,99 @@
import os
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
from lightrag.llm import openai_complete_if_cache, openai_embedding
from lightrag.utils import EmbeddingFunc
import numpy as np

## For Upstage API
# Make sure embedding_dim=4096 is set in lightrag.py and llm.py in the lightrag directory.
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "solar-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
        **kwargs
    )

async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar"
    )
## /For Upstage API

def extract_queries(file_path):
    with open(file_path, 'r') as f:
        data = f.read()

    # Strip markdown bold markers before matching the question lines.
    data = data.replace('**', '')

    queries = re.findall(r'- Question \d+: (.+)', data)

    return queries
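
# Expected question-file format (an assumption inferred from the regex above;
# bold markers around "Question N" are tolerated because '**' is stripped):
#   - Question 1: <question text>
#   - Question 2: <question text>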

async def process_query(query_text, rag_instance, query_param):
    try:
        result, context = await rag_instance.aquery(query_text, param=query_param)
        return {"query": query_text, "result": result, "context": context}, None
    except Exception as e:
        return None, {"query": query_text, "error": str(e)}

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    # Reuse the current event loop if one exists; otherwise create and register a new one.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop

def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
    loop = always_get_an_event_loop()

    # Open in write mode: each run emits one complete JSON array, and appending
    # to a previous run's output would yield invalid JSON.
    with open(output_file, 'w', encoding='utf-8') as result_file, open(error_file, 'w', encoding='utf-8') as err_file:
        result_file.write("[\n")
        first_entry = True

        for query_text in tqdm(queries, desc="Processing queries", unit="query"):
            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))

            if result:
                if not first_entry:
                    result_file.write(",\n")
                json.dump(result, result_file, ensure_ascii=False, indent=4)
                first_entry = False
            elif error:
                json.dump(error, err_file, ensure_ascii=False, indent=4)
                err_file.write("\n")

        result_file.write("\n]")
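
# Illustrative shape of output_file as produced by the writer above:
# [
#     {"query": "...", "result": "...", "context": "..."},
#     {"query": "...", "result": "...", "context": "..."}
# ]
# error_file holds one JSON object per failed query, newline-separated.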

if __name__ == "__main__":
    cls = "mix"
    mode = "hybrid"
    WORKING_DIR = f"../{cls}"

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=4096,
            max_token_size=8192,
            func=embedding_func
        )
    )
    query_param = QueryParam(mode=mode)

    base_dir = '../datasets/questions'
    queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
    run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
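
# Usage sketch (assumption: Step_1_openai_compatible.py has already built the
# ../mix index for the same working directory):
#   UPSTAGE_API_KEY=... python Step_3_openai_compatible.py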
