Skip to content

Commit

Permalink
Merge pull request #1118 from arc53/1059-migrating-database-to-new-model
Browse files Browse the repository at this point in the history
  • Loading branch information
dartpain authored Sep 9, 2024
2 parents a1d3592 + a1b32ff commit 72842ec
Show file tree
Hide file tree
Showing 28 changed files with 374 additions and 440 deletions.
95 changes: 49 additions & 46 deletions application/api/answer/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pymongo import MongoClient
from bson.objectid import ObjectId
from bson.dbref import DBRef

from application.core.settings import settings
from application.llm.llm_creator import LLMCreator
Expand All @@ -20,7 +21,7 @@
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
sources_collection = db["sources"]
prompts_collection = db["prompts"]
api_key_collection = db["api_keys"]
answer = Blueprint("answer", __name__)
Expand All @@ -36,9 +37,7 @@
gpt_model = settings.MODEL_NAME

# load the prompts
current_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
with open(os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r") as f:
chat_combine_template = f.read()

Expand Down Expand Up @@ -74,35 +73,34 @@ def run_async_chain(chain, question, chat_history):

def get_data_from_api_key(api_key):
    """Look up and normalise the stored configuration for an API key.

    Returns the api_keys document with two guarantees for callers:
    - ``data["retriever"]`` always exists (``None`` when unset);
    - ``data["source"]`` is either the string ObjectId of the dereferenced
      source document (inheriting that source's retriever when present)
      or ``{}`` when the key has no ``DBRef`` source.

    Raises:
        Exception: ("Invalid API Key, ...", 401) when the key is unknown.
    """
    data = api_key_collection.find_one({"key": api_key})
    # Raise a custom exception if the API key is not found.
    if data is None:
        raise Exception("Invalid API Key, please generate new key", 401)

    # Guarantee the field exists so callers may read data["retriever"] directly.
    data.setdefault("retriever", None)

    if "source" in data and isinstance(data["source"], DBRef):
        # Resolve the DBRef to the actual source document.
        source_doc = db.dereference(data["source"])
        data["source"] = str(source_doc["_id"])
        # A retriever stored on the source overrides the key-level one.
        if "retriever" in source_doc:
            data["retriever"] = source_doc["retriever"]
    else:
        # No linked source document; use an empty placeholder.
        data["source"] = {}
    return data


def get_retriever(source_id: str):
    """Return the retriever name stored on a source document.

    Args:
        source_id: hex string form of the source document's ObjectId.

    Returns:
        The document's "retriever" value, or ``None`` when the field is absent.

    Raises:
        Exception: ("Source document does not exist", 404) when no source
            with the given id exists.
    """
    doc = sources_collection.find_one({"_id": ObjectId(source_id)})
    if doc is None:
        raise Exception("Source document does not exist", 404)
    # dict.get already yields None for a missing key — no conditional needed.
    return doc.get("retriever")



def is_azure_configured():
    """Report whether all Azure OpenAI settings required for use are set."""
    # Deliberately an and-chain (not all()): it short-circuits on the first
    # unset value and preserves the original truthy/falsy return value.
    return (
        settings.OPENAI_API_BASE
        and settings.OPENAI_API_VERSION
        and settings.AZURE_DEPLOYMENT_NAME
    )


def save_conversation(conversation_id, question, response, source_log_docs, llm):
Expand Down Expand Up @@ -247,32 +245,33 @@ def stream():
else:
token_limit = settings.DEFAULT_MAX_HISTORY

# check if active_docs or api_key is set
## retriever can be "brave_search, duckduck_search or classic"
retriever_name = data["retriever"] if "retriever" in data else "classic"

# check if active_docs or api_key is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]

elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
source = {"active_docs" : data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None

else:
source = {}
user_api_key = None

if source["active_docs"].split("/")[0] in ["default", "local"]:
retriever_name = "classic"
else:
retriever_name = source["active_docs"]

current_app.logger.info(f"/stream - request_data: {data}, source: {source}",
extra={"data": json.dumps({"request_data": data, "source": source})}
)

prompt = get_prompt(prompt_id)

retriever = RetrieverCreator.create_retriever(
retriever_name,
question=question,
Expand Down Expand Up @@ -351,22 +350,26 @@ def api_answer():
else:
token_limit = settings.DEFAULT_MAX_HISTORY

## retriever can be brave_search, duckduck_search or classic
retriever_name = data["retriever"] if "retriever" in data else "classic"

# use try and except to check for exception
try:
# check if the vectorstore is set
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
prompt_id = data_key["prompt_id"]
source = {"active_docs": data_key["source"]}
retriever_name = data_key["retriever"] or retriever_name
user_api_key = data["api_key"]
else:
source = data
elif "active_docs" in data:
source = {"active_docs":data["active_docs"]}
retriever_name = get_retriever(data["active_docs"]) or retriever_name
user_api_key = None

if source["active_docs"].split("/")[0] in ["default", "local"]:
retriever_name = "classic"
else:
retriever_name = source["active_docs"]
source = {}
user_api_key = None

prompt = get_prompt(prompt_id)

Expand Down Expand Up @@ -402,8 +405,8 @@ def api_answer():
)

result = {"answer": response_full, "sources": source_log_docs}
result["conversation_id"] = save_conversation(
conversation_id, question, response_full, source_log_docs, llm
result["conversation_id"] = str(
save_conversation(conversation_id, question, response_full, source_log_docs, llm)
)

return result
Expand All @@ -425,19 +428,19 @@ def api_search():
if "api_key" in data:
data_key = get_data_from_api_key(data["api_key"])
chunks = int(data_key["chunks"])
source = {"active_docs": data_key["source"]}
user_api_key = data["api_key"]
source = {"active_docs":data_key["source"]}
user_api_key = data_key["api_key"]
elif "active_docs" in data:
source = {"active_docs": data["active_docs"]}
source = {"active_docs":data["active_docs"]}
user_api_key = None
else:
source = {}
user_api_key = None

if source["active_docs"].split("/")[0] in ["default", "local"]:
retriever_name = "classic"
if "retriever" in data:
retriever_name = data["retriever"]
else:
retriever_name = source["active_docs"]
retriever_name = "classic"
if "token_limit" in data:
token_limit = data["token_limit"]
else:
Expand Down
23 changes: 15 additions & 8 deletions application/api/internal/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from flask import Blueprint, request, send_from_directory
from pymongo import MongoClient
from werkzeug.utils import secure_filename

from bson.objectid import ObjectId

from application.core.settings import settings
mongo = MongoClient(settings.MONGO_URI)
db = mongo["docsgpt"]
conversations_collection = db["conversations"]
vectors_collection = db["vectors"]
sources_collection = db["sources"]

current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

Expand All @@ -35,7 +35,12 @@ def upload_index_files():
return {"status": "no name"}
job_name = secure_filename(request.form["name"])
tokens = secure_filename(request.form["tokens"])
save_dir = os.path.join(current_dir, "indexes", user, job_name)
retriever = secure_filename(request.form["retriever"])
id = secure_filename(request.form["id"])
type = secure_filename(request.form["type"])
remote_data = secure_filename(request.form["remote_data"]) if "remote_data" in request.form else None

save_dir = os.path.join(current_dir, "indexes", str(id))
if settings.VECTOR_STORE == "faiss":
if "file_faiss" not in request.files:
print("No file part")
Expand All @@ -55,17 +60,19 @@ def upload_index_files():
os.makedirs(save_dir)
file_faiss.save(os.path.join(save_dir, "index.faiss"))
file_pkl.save(os.path.join(save_dir, "index.pkl"))
# create entry in vectors_collection
vectors_collection.insert_one(
# create entry in sources_collection
sources_collection.insert_one(
{
"_id": ObjectId(id),
"user": user,
"name": job_name,
"language": job_name,
"location": save_dir,
"date": datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
"model": settings.EMBEDDINGS_NAME,
"type": "local",
"tokens": tokens
"type": type,
"tokens": tokens,
"retriever": retriever,
"remote_data": remote_data
}
)
return {"status": "ok"}
Loading

0 comments on commit 72842ec

Please sign in to comment.