From b084e3074dc2781e285abf9ab50d922900f52f1b Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Fri, 27 Sep 2024 20:08:46 +0530 Subject: [PATCH 1/3] refactor: user routes to comply with OpenAPI spec using flask-restx --- application/api/user/routes.py | 2372 +++++++++++++-------- application/requirements.txt | 1 + frontend/src/components/Dropdown.tsx | 4 +- frontend/src/modals/CreateAPIKeyModal.tsx | 3 +- frontend/src/settings/Analytics.tsx | 6 +- 5 files changed, 1439 insertions(+), 947 deletions(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 764195e32..0bc11a1be 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -6,7 +6,8 @@ from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId -from flask import Blueprint, jsonify, request +from flask import Blueprint, jsonify, make_response, request +from flask_restx import Api, fields, Namespace, Resource from pymongo import MongoClient from werkzeug.utils import secure_filename @@ -27,6 +28,14 @@ user_logs_collection = db["user_logs"] user = Blueprint("user", __name__) +api = Api( + user, + version="1.0", + title="DocsGPT API", + description="API for DocsGPT", + default="user", + default_label="User operations", +) current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -54,502 +63,865 @@ def generate_date_range(start_date, end_date): } -@user.route("/api/delete_conversation", methods=["POST"]) -def delete_conversation(): - # deletes a conversation from the database - conversation_id = request.args.get("id") - # write to mongodb - conversations_collection.delete_one( - { - "_id": ObjectId(conversation_id), - } +def check_required_fields(data, required_fields): + missing_fields = [field for field in required_fields if field not in data] + if missing_fields: + return make_response( + jsonify( + { + "success": False, + "message": f"Missing fields: {', '.join(missing_fields)}", + } + ), + 400, + ) + return None + + +@api.route("/api/delete_conversation") +class DeleteConversation(Resource): + @api.doc( + description="Deletes a conversation by ID", + params={"id": "The ID of the conversation to delete"}, ) + def post(self): + conversation_id = request.args.get("id") + if not conversation_id: + return make_response( + jsonify({"success": False, "message": "ID is required"}), 400 + ) - return {"status": "ok"} + try: + conversations_collection.delete_one({"_id": ObjectId(conversation_id)}) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"success": True}), 200) -@user.route("/api/delete_all_conversations", methods=["GET"]) -def delete_all_conversations(): - user_id = "local" - conversations_collection.delete_many({"user": user_id}) - return {"status": "ok"} +@api.route("/api/delete_all_conversations") +class DeleteAllConversations(Resource): + @api.doc( + description="Deletes all conversations for a specific user", + ) + def get(self): + user_id = "local" + try: + conversations_collection.delete_many({"user": user_id}) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"success": True}), 200) -@user.route("/api/get_conversations", methods=["get"]) -def get_conversations(): - # provides a list of conversations - conversations = conversations_collection.find().sort("date", -1).limit(30) - list_conversations = [] - for 
conversation in conversations: - list_conversations.append( - {"id": str(conversation["_id"]), "name": conversation["name"]} - ) +@api.route("/api/get_conversations") +class GetConversations(Resource): + @api.doc( + description="Retrieve a list of the latest 30 conversations", + ) + def get(self): + try: + conversations = conversations_collection.find().sort("date", -1).limit(30) + list_conversations = [ + {"id": str(conversation["_id"]), "name": conversation["name"]} + for conversation in conversations + ] + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify(list_conversations), 200) - # list_conversations = [{"id": "default", "name": "default"}, {"id": "jeff", "name": "jeff"}] - return jsonify(list_conversations) +@api.route("/api/get_single_conversation") +class GetSingleConversation(Resource): + @api.doc( + description="Retrieve a single conversation by ID", + params={"id": "The conversation ID"}, + ) + def get(self): + conversation_id = request.args.get("id") + if not conversation_id: + return make_response( + jsonify({"success": False, "message": "ID is required"}), 400 + ) + + try: + conversation = conversations_collection.find_one( + {"_id": ObjectId(conversation_id)} + ) + if not conversation: + return make_response(jsonify({"status": "not found"}), 404) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify(conversation["queries"]), 200) + + +@api.route("/api/update_conversation_name") +class UpdateConversationName(Resource): + @api.expect( + api.model( + "UpdateConversationModel", + { + "id": fields.String(required=True, description="Conversation ID"), + "name": fields.String( + required=True, description="New name of the conversation" + ), + }, + ) + ) + @api.doc( + description="Updates the name of a conversation", + ) + def post(self): + data = request.get_json() + required_fields = ["id", "name"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + try: + conversations_collection.update_one( + {"_id": ObjectId(data["id"])}, {"$set": {"name": data["name"]}} + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) -@user.route("/api/get_single_conversation", methods=["get"]) -def get_single_conversation(): - # provides data for a conversation - conversation_id = request.args.get("id") - conversation = conversations_collection.find_one({"_id": ObjectId(conversation_id)}) - return jsonify(conversation["queries"]) + return make_response(jsonify({"success": True}), 200) -@user.route("/api/update_conversation_name", methods=["POST"]) -def update_conversation_name(): - # update data for a conversation - data = request.get_json() - id = data["id"] - name = data["name"] - conversations_collection.update_one({"_id": ObjectId(id)}, {"$set": {"name": name}}) - return {"status": "ok"} +@api.route("/api/feedback") +class SubmitFeedback(Resource): + @api.expect( + api.model( + "FeedbackModel", + { + "question": fields.String( + required=True, description="The user question" + ), + "answer": fields.String(required=True, description="The AI answer"), + "feedback": fields.String(required=True, description="User feedback"), + "api_key": fields.String(description="Optional API key"), + }, + ) + ) + @api.doc( + description="Submit feedback for a conversation", + ) + def post(self): + data = request.get_json() + required_fields = 
["question", "answer", "feedback"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + new_doc = { + "question": data["question"], + "answer": data["answer"], + "feedback": data["feedback"], + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + if "api_key" in data: + new_doc["api_key"] = data["api_key"] -@user.route("/api/feedback", methods=["POST"]) -def api_feedback(): - data = request.get_json() - question = data["question"] - answer = data["answer"] - feedback = data["feedback"] - new_doc = { - "question": question, - "answer": answer, - "feedback": feedback, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - if "api_key" in data: - new_doc["api_key"] = data["api_key"] - feedback_collection.insert_one(new_doc) - return {"status": "ok"} + try: + feedback_collection.insert_one(new_doc) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"success": True}), 200) -@user.route("/api/delete_by_ids", methods=["get"]) -def delete_by_ids(): - """Delete by ID. These are the IDs in the vectorstore""" - ids = request.args.get("path") - if not ids: - return {"status": "error"} +@api.route("/api/delete_by_ids") +class DeleteByIds(Resource): + @api.doc( + description="Deletes documents from the vector store by IDs", + params={"path": "Comma-separated list of IDs"}, + ) + def get(self): + ids = request.args.get("path") + if not ids: + return make_response( + jsonify({"success": False, "message": "Missing required fields"}), 400 + ) - if settings.VECTOR_STORE == "faiss": - result = sources_collection.delete_index(ids=ids) - if result: - return {"status": "ok"} - return {"status": "error"} + try: + result = sources_collection.delete_index(ids=ids) + if result: + return make_response(jsonify({"success": True}), 200) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"success": False, "error": str(err)}), 400) -@user.route("/api/delete_old", methods=["get"]) -def delete_old(): - """Delete old indexes.""" - import shutil - source_id = request.args.get("source_id") - doc = sources_collection.find_one( - { - "_id": ObjectId(source_id), - "user": "local", - } +@api.route("/api/delete_old") +class DeleteOldIndexes(Resource): + @api.doc( + description="Deletes old indexes", + params={"source_id": "The source ID to delete"}, ) - if doc is None: - return {"status": "not found"}, 404 - if settings.VECTOR_STORE == "faiss": + def get(self): + source_id = request.args.get("source_id") + if not source_id: + return make_response( + jsonify({"success": False, "message": "Missing required fields"}), 400 + ) + try: - shutil.rmtree(os.path.join(current_dir, str(doc["_id"]))) + doc = sources_collection.find_one( + {"_id": ObjectId(source_id), "user": "local"} + ) + if not doc: + return make_response(jsonify({"status": "not found"}), 404) + + if settings.VECTOR_STORE == "faiss": + shutil.rmtree(os.path.join(current_dir, "indexes", str(doc["_id"]))) + else: + vectorstore = VectorCreator.create_vectorstore( + settings.VECTOR_STORE, source_id=str(doc["_id"]) + ) + vectorstore.delete_index() + + sources_collection.delete_one({"_id": ObjectId(source_id)}) except FileNotFoundError: pass - else: - vetorstore = VectorCreator.create_vectorstore( - settings.VECTOR_STORE, source_id=str(doc["_id"]) + except Exception as err: + return make_response(jsonify({"success": False, "error": 
str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@api.route("/api/upload") +class UploadFile(Resource): + @api.expect( + api.model( + "UploadModel", + { + "user": fields.String(required=True, description="User ID"), + "name": fields.String(required=True, description="Job name"), + "file": fields.Raw(required=True, description="File(s) to upload"), + }, ) - vetorstore.delete_index() - sources_collection.delete_one( - { - "_id": ObjectId(source_id), - } ) + @api.doc( + description="Uploads a file to be vectorized and indexed", + ) + def post(self): + data = request.form + files = request.files.getlist("file") + required_fields = ["user", "name"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields or not files or all(file.filename == "" for file in files): + return make_response( + jsonify( + { + "status": "error", + "message": "Missing required fields or files", + } + ), + 400, + ) - return {"status": "ok"} - + user = secure_filename(request.form["user"]) + job_name = secure_filename(request.form["name"]) + try: + save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) + os.makedirs(save_dir, exist_ok=True) -@user.route("/api/upload", methods=["POST"]) -def upload_file(): - """Upload a file to get vectorized and indexed.""" - if "user" not in request.form: - return {"status": "no user"} - user = secure_filename(request.form["user"]) - if "name" not in request.form: - return {"status": "no name"} - job_name = secure_filename(request.form["name"]) - # check if the post request has the file part - files = request.files.getlist("file") + if len(files) > 1: + temp_dir = os.path.join(save_dir, "temp") + os.makedirs(temp_dir, exist_ok=True) - if not files or all(file.filename == "" for file in files): - return {"status": "no file name"} + for file in files: + filename = secure_filename(file.filename) + file.save(os.path.join(temp_dir, filename)) - # Directory where files will be saved - save_dir = os.path.join(current_dir, settings.UPLOAD_FOLDER, user, job_name) - os.makedirs(save_dir, exist_ok=True) + zip_path = shutil.make_archive( + base_name=os.path.join(save_dir, job_name), + format="zip", + root_dir=temp_dir, + ) + final_filename = os.path.basename(zip_path) + shutil.rmtree(temp_dir) + else: + file = files[0] + final_filename = secure_filename(file.filename) + file_path = os.path.join(save_dir, final_filename) + file.save(file_path) + + task = ingest.delay( + settings.UPLOAD_FOLDER, + [ + ".rst", + ".md", + ".pdf", + ".txt", + ".docx", + ".csv", + ".epub", + ".html", + ".mdx", + ], + job_name, + final_filename, + user, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) - if len(files) > 1: - # Multiple files; prepare them for zip - temp_dir = os.path.join(save_dir, "temp") - os.makedirs(temp_dir, exist_ok=True) + return make_response(jsonify({"success": True, "task_id": task.id}), 200) - for file in files: - filename = secure_filename(file.filename) - file.save(os.path.join(temp_dir, filename)) - # Use shutil.make_archive to zip the temp directory - zip_path = shutil.make_archive( - base_name=os.path.join(save_dir, job_name), format="zip", root_dir=temp_dir +@api.route("/api/remote") +class UploadRemote(Resource): + @api.expect( + api.model( + "RemoteUploadModel", + { + "user": fields.String(required=True, description="User ID"), + "source": fields.String( + required=True, description="Source of the data" + ), + "name": fields.String(required=True, 
description="Job name"), + "data": fields.String(required=True, description="Data to process"), + }, ) - final_filename = os.path.basename(zip_path) - - # Clean up the temporary directory after zipping - shutil.rmtree(temp_dir) - else: - # Single file - file = files[0] - final_filename = secure_filename(file.filename) - file_path = os.path.join(save_dir, final_filename) - file.save(file_path) - - # Call ingest with the single file or zipped file - task = ingest.delay( - settings.UPLOAD_FOLDER, - [".rst", ".md", ".pdf", ".txt", ".docx", ".csv", ".epub", ".html", ".mdx"], - job_name, - final_filename, - user, ) + @api.doc( + description="Uploads remote source for vectorization", + ) + def post(self): + data = request.form + required_fields = ["user", "source", "name", "data"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields - return {"status": "ok", "task_id": task.id} - - -@user.route("/api/remote", methods=["POST"]) -def upload_remote(): - """Upload a remote source to get vectorized and indexed.""" - if "user" not in request.form: - return {"status": "no user"} - user = secure_filename(request.form["user"]) - if "source" not in request.form: - return {"status": "no source"} - source = secure_filename(request.form["source"]) - if "name" not in request.form: - return {"status": "no name"} - job_name = secure_filename(request.form["name"]) - if "data" not in request.form: - print("No data") - return {"status": "no data"} - source_data = request.form["data"] - - if source_data: - task = ingest_remote.delay( - source_data=source_data, job_name=job_name, user=user, loader=source - ) - task_id = task.id - return {"status": "ok", "task_id": task_id} - else: - return {"status": "error"} + try: + task = ingest_remote.delay( + source_data=data["data"], + job_name=data["name"], + user=data["user"], + loader=data["source"], + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True, "task_id": task.id}), 200) + + +@api.route("/api/task_status") +class TaskStatus(Resource): + task_status_model = api.model( + "TaskStatusModel", + {"task_id": fields.String(required=True, description="Task ID")}, + ) + @api.expect(task_status_model) + @api.doc(description="Get celery job status") + def get(self): + task_id = request.args.get("task_id") + if not task_id: + return make_response( + jsonify({"success": False, "message": "Task ID is required"}), 400 + ) -@user.route("/api/task_status", methods=["GET"]) -def task_status(): - """Get celery job status.""" - task_id = request.args.get("task_id") - from application.celery_init import celery + try: + from application.celery_init import celery - task = celery.AsyncResult(task_id) - task_meta = task.info - return {"status": task.status, "result": task_meta} + task = celery.AsyncResult(task_id) + task_meta = task.info + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"status": task.status, "result": task_meta}), 200) -@user.route("/api/combine", methods=["GET"]) -def combined_json(): - user = "local" - """Provide json file with combined available indexes.""" - # get json from https://d3dg1063dc54p9.cloudfront.net/combined.json - data = [ - { - "name": "default", - "date": "default", - "model": settings.EMBEDDINGS_NAME, - "location": "remote", - "tokens": "", - "retriever": "classic", - } - ] - # structure: name, language, version, 
description, fullName, date, docLink - # append data from sources_collection in sorted order in descending order of date - for index in sources_collection.find({"user": user}).sort("date", -1): - data.append( - { - "id": str(index["_id"]), - "name": index.get("name"), - "date": index.get("date"), - "model": settings.EMBEDDINGS_NAME, - "location": "local", - "tokens": index.get("tokens", ""), - "retriever": index.get("retriever", "classic"), - "syncFrequency": index.get("sync_frequency", ""), - } - ) - if "duckduck_search" in settings.RETRIEVERS_ENABLED: - data.append( - { - "name": "DuckDuckGo Search", - "date": "duckduck_search", - "model": settings.EMBEDDINGS_NAME, - "location": "custom", - "tokens": "", - "retriever": "duckduck_search", - } - ) - if "brave_search" in settings.RETRIEVERS_ENABLED: - data.append( +@api.route("/api/combine") +class CombinedJson(Resource): + @api.doc(description="Provide JSON file with combined available indexes") + def get(self): + user = "local" + data = [ { - "name": "Brave Search", - "language": "en", - "date": "brave_search", + "name": "default", + "date": "default", "model": settings.EMBEDDINGS_NAME, - "location": "custom", + "location": "remote", "tokens": "", - "retriever": "brave_search", + "retriever": "classic", } - ) + ] - return jsonify(data) + try: + for index in sources_collection.find({"user": user}).sort("date", -1): + data.append( + { + "id": str(index["_id"]), + "name": index.get("name"), + "date": index.get("date"), + "model": settings.EMBEDDINGS_NAME, + "location": "local", + "tokens": index.get("tokens", ""), + "retriever": index.get("retriever", "classic"), + "syncFrequency": index.get("sync_frequency", ""), + } + ) + if "duckduck_search" in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "DuckDuckGo Search", + "date": "duckduck_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + "tokens": "", + "retriever": "duckduck_search", + } + ) -@user.route("/api/docs_check", methods=["POST"]) -def check_docs(): - data = request.get_json() + if "brave_search" in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "Brave Search", + "language": "en", + "date": "brave_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + "tokens": "", + "retriever": "brave_search", + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) - vectorstore = "vectors/" + secure_filename(data["docs"]) - if os.path.exists(vectorstore) or data["docs"] == "default": - return {"status": "exists"} - else: - return {"status": "not found"} + return make_response(jsonify(data), 200) -@user.route("/api/create_prompt", methods=["POST"]) -def create_prompt(): - data = request.get_json() - content = data["content"] - name = data["name"] - if name == "": - return {"status": "error"} - user = "local" - resp = prompts_collection.insert_one( - { - "name": name, - "content": content, - "user": user, - } +@api.route("/api/docs_check") +class CheckDocs(Resource): + check_docs_model = api.model( + "CheckDocsModel", + {"docs": fields.String(required=True, description="Document name")}, ) - new_id = str(resp.inserted_id) - return {"id": new_id} - - -@user.route("/api/get_prompts", methods=["GET"]) -def get_prompts(): - user = "local" - prompts = prompts_collection.find({"user": user}) - list_prompts = [] - list_prompts.append({"id": "default", "name": "default", "type": "public"}) - list_prompts.append({"id": "creative", "name": "creative", "type": "public"}) - 
list_prompts.append({"id": "strict", "name": "strict", "type": "public"}) - for prompt in prompts: - list_prompts.append( - {"id": str(prompt["_id"]), "name": prompt["name"], "type": "private"} - ) - return jsonify(list_prompts) - - -@user.route("/api/get_single_prompt", methods=["GET"]) -def get_single_prompt(): - prompt_id = request.args.get("id") - if prompt_id == "default": - with open( - os.path.join(current_dir, "prompts", "chat_combine_default.txt"), "r" - ) as f: - chat_combine_template = f.read() - return jsonify({"content": chat_combine_template}) - elif prompt_id == "creative": - with open( - os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), "r" - ) as f: - chat_reduce_creative = f.read() - return jsonify({"content": chat_reduce_creative}) - elif prompt_id == "strict": - with open( - os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" - ) as f: - chat_reduce_strict = f.read() - return jsonify({"content": chat_reduce_strict}) - - prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)}) - return jsonify({"content": prompt["content"]}) - - -@user.route("/api/delete_prompt", methods=["POST"]) -def delete_prompt(): - data = request.get_json() - id = data["id"] - prompts_collection.delete_one( + @api.expect(check_docs_model) + @api.doc(description="Check if document exists") + def post(self): + data = request.get_json() + required_fields = ["docs"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + vectorstore = "vectors/" + secure_filename(data["docs"]) + if os.path.exists(vectorstore) or data["docs"] == "default": + return {"status": "exists"}, 200 + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"status": "not found"}), 404) + + +@api.route("/api/create_prompt") +class CreatePrompt(Resource): + create_prompt_model = api.model( + "CreatePromptModel", { - "_id": ObjectId(id), - } + "content": fields.String( + required=True, description="Content of the prompt" + ), + "name": fields.String(required=True, description="Name of the prompt"), + }, ) - return {"status": "ok"} - - -@user.route("/api/update_prompt", methods=["POST"]) -def update_prompt_name(): - data = request.get_json() - id = data["id"] - name = data["name"] - content = data["content"] - # check if name is null - if name == "": - return {"status": "error"} - prompts_collection.update_one( - {"_id": ObjectId(id)}, {"$set": {"name": name, "content": content}} + + @api.expect(create_prompt_model) + @api.doc(description="Create a new prompt") + def post(self): + data = request.get_json() + required_fields = ["content", "name"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + user = "local" + try: + + resp = prompts_collection.insert_one( + { + "name": data["name"], + "content": data["content"], + "user": user, + } + ) + new_id = str(resp.inserted_id) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"id": new_id}), 200) + + +@api.route("/api/get_prompts") +class GetPrompts(Resource): + @api.doc(description="Get all prompts for the user") + def get(self): + user = "local" + try: + prompts = prompts_collection.find({"user": user}) + list_prompts = [ + {"id": "default", "name": "default", "type": "public"}, + {"id": "creative", "name": "creative", "type": "public"}, + {"id": "strict", 
"name": "strict", "type": "public"}, + ] + + for prompt in prompts: + list_prompts.append( + { + "id": str(prompt["_id"]), + "name": prompt["name"], + "type": "private", + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify(list_prompts), 200) + + +@api.route("/api/get_single_prompt") +class GetSinglePrompt(Resource): + @api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID") + def get(self): + prompt_id = request.args.get("id") + if not prompt_id: + return make_response( + jsonify({"success": False, "message": "ID is required"}), 400 + ) + + try: + if prompt_id == "default": + with open( + os.path.join(current_dir, "prompts", "chat_combine_default.txt"), + "r", + ) as f: + chat_combine_template = f.read() + return make_response(jsonify({"content": chat_combine_template}), 200) + + elif prompt_id == "creative": + with open( + os.path.join(current_dir, "prompts", "chat_combine_creative.txt"), + "r", + ) as f: + chat_reduce_creative = f.read() + return make_response(jsonify({"content": chat_reduce_creative}), 200) + + elif prompt_id == "strict": + with open( + os.path.join(current_dir, "prompts", "chat_combine_strict.txt"), "r" + ) as f: + chat_reduce_strict = f.read() + return make_response(jsonify({"content": chat_reduce_strict}), 200) + + prompt = prompts_collection.find_one({"_id": ObjectId(prompt_id)}) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"content": prompt["content"]}), 200) + + +@api.route("/api/delete_prompt") +class DeletePrompt(Resource): + delete_prompt_model = api.model( + "DeletePromptModel", + {"id": fields.String(required=True, description="Prompt ID to delete")}, ) - return {"status": "ok"} - - -@user.route("/api/get_api_keys", methods=["GET"]) -def get_api_keys(): - user = "local" - keys = api_key_collection.find({"user": user}) - list_keys = [] - for key in keys: - if "source" in key and isinstance(key["source"], DBRef): - source = db.dereference(key["source"]) - if source is None: - continue - else: - source_name = source["name"] - elif "retriever" in key: - source_name = key["retriever"] - else: - continue - list_keys.append( - { - "id": str(key["_id"]), - "name": key["name"], - "key": key["key"][:4] + "..." 
+ key["key"][-4:], - "source": source_name, - "prompt_id": key["prompt_id"], - "chunks": key["chunks"], - } - ) - return jsonify(list_keys) - - -@user.route("/api/create_api_key", methods=["POST"]) -def create_api_key(): - data = request.get_json() - name = data["name"] - prompt_id = data["prompt_id"] - chunks = data["chunks"] - key = str(uuid.uuid4()) - user = "local" - new_api_key = { - "name": name, - "key": key, - "user": user, - "prompt_id": prompt_id, - "chunks": chunks, - } - if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key["source"] = DBRef("sources", ObjectId(data["source"])) - if "retriever" in data: - new_api_key["retriever"] = data["retriever"] - resp = api_key_collection.insert_one(new_api_key) - new_id = str(resp.inserted_id) - return {"id": new_id, "key": key} - - -@user.route("/api/delete_api_key", methods=["POST"]) -def delete_api_key(): - data = request.get_json() - id = data["id"] - api_key_collection.delete_one( + @api.expect(delete_prompt_model) + @api.doc(description="Delete a prompt by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + prompts_collection.delete_one({"_id": ObjectId(data["id"])}) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@api.route("/api/update_prompt") +class UpdatePrompt(Resource): + update_prompt_model = api.model( + "UpdatePromptModel", { - "_id": ObjectId(id), - } + "id": fields.String(required=True, description="Prompt ID to update"), + "name": fields.String(required=True, description="New name of the prompt"), + "content": fields.String( + required=True, description="New content of the prompt" + ), + }, ) - return {"status": "ok"} - -# route to share conversation -##isPromptable should be passed through queries -@user.route("/api/share", methods=["POST"]) -def share_conversation(): - try: + @api.expect(update_prompt_model) + @api.doc(description="Update an existing prompt") + def post(self): data = request.get_json() - user = "local" if "user" not in data else data["user"] - conversation_id = data["conversation_id"] - isPromptable = request.args.get("isPromptable").lower() == "true" + required_fields = ["id", "name", "content"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields - conversation = conversations_collection.find_one( - {"_id": ObjectId(conversation_id)} - ) - if conversation is None: - raise Exception("Conversation does not exist") - current_n_queries = len(conversation["queries"]) + try: + prompts_collection.update_one( + {"_id": ObjectId(data["id"])}, + {"$set": {"name": data["name"], "content": data["content"]}}, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) + + +@api.route("/api/get_api_keys") +class GetApiKeys(Resource): + @api.doc(description="Retrieve API keys for the user") + def get(self): + user = "local" + try: + keys = api_key_collection.find({"user": user}) + list_keys = [] + for key in keys: + if "source" in key and isinstance(key["source"], DBRef): + source = db.dereference(key["source"]) + if source is None: + continue + source_name = source["name"] + elif "retriever" in key: + source_name = key["retriever"] + else: + continue + + list_keys.append( + { + 
"id": str(key["_id"]), + "name": key["name"], + "key": key["key"][:4] + "..." + key["key"][-4:], + "source": source_name, + "prompt_id": key["prompt_id"], + "chunks": key["chunks"], + } + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify(list_keys), 200) - ##generate binary representation of uuid - explicit_binary = Binary.from_uuid(uuid.uuid4(), UuidRepresentation.STANDARD) - if isPromptable: - prompt_id = "default" if "prompt_id" not in data else data["prompt_id"] - chunks = "2" if "chunks" not in data else data["chunks"] +@api.route("/api/create_api_key") +class CreateApiKey(Resource): + create_api_key_model = api.model( + "CreateApiKeyModel", + { + "name": fields.String(required=True, description="Name of the API key"), + "prompt_id": fields.String(required=True, description="Prompt ID"), + "chunks": fields.Integer(required=True, description="Chunks count"), + "source": fields.String(description="Source ID (optional)"), + "retriever": fields.String(description="Retriever (optional)"), + }, + ) + + @api.expect(create_api_key_model) + @api.doc(description="Create a new API key") + def post(self): + data = request.get_json() + required_fields = ["name", "prompt_id", "chunks"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields - name = conversation["name"] + "(shared)" - new_api_key_data = { - "prompt_id": prompt_id, - "chunks": chunks, + user = "local" + try: + key = str(uuid.uuid4()) + new_api_key = { + "name": data["name"], + "key": key, "user": user, + "prompt_id": data["prompt_id"], + "chunks": data["chunks"], } if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key_data["source"] = DBRef("sources", ObjectId(data["source"])) - elif "retriever" in data: - new_api_key_data["retriever"] = data["retriever"] - - pre_existing_api_document = api_key_collection.find_one(new_api_key_data) - if pre_existing_api_document: - api_uuid = pre_existing_api_document["key"] - pre_existing = shared_conversations_collections.find_one( - { - "conversation_id": DBRef( - "conversations", ObjectId(conversation_id) - ), - "isPromptable": isPromptable, - "first_n_queries": current_n_queries, - "user": user, - "api_key": api_uuid, - } + new_api_key["source"] = DBRef("sources", ObjectId(data["source"])) + if "retriever" in data: + new_api_key["retriever"] = data["retriever"] + + resp = api_key_collection.insert_one(new_api_key) + new_id = str(resp.inserted_id) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"id": new_id, "key": key}), 201) + + +@api.route("/api/delete_api_key") +class DeleteApiKey(Resource): + delete_api_key_model = api.model( + "DeleteApiKeyModel", + {"id": fields.String(required=True, description="API Key ID to delete")}, + ) + + @api.expect(delete_api_key_model) + @api.doc(description="Delete an API key by ID") + def post(self): + data = request.get_json() + required_fields = ["id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + result = api_key_collection.delete_one({"_id": ObjectId(data["id"])}) + if result.deleted_count == 0: + return {"success": False, "message": "API Key not found"}, 404 + except Exception as err: + return {"success": False, "error": str(err)}, 400 + + return {"success": True}, 200 + + +@api.route("/api/share") +class ShareConversation(Resource): + 
share_conversation_model = api.model( + "ShareConversationModel", + { + "conversation_id": fields.String( + required=True, description="Conversation ID" + ), + "user": fields.String(description="User ID (optional)"), + "prompt_id": fields.String(description="Prompt ID (optional)"), + "chunks": fields.Integer(description="Chunks count (optional)"), + }, + ) + + @api.expect(share_conversation_model) + @api.doc(description="Share a conversation") + def post(self): + data = request.get_json() + required_fields = ["conversation_id"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + is_promptable = request.args.get("isPromptable") + if is_promptable is None: + return make_response( + jsonify({"success": False, "message": "isPromptable is required"}), 400 + ) + + user = data.get("user", "local") + conversation_id = data["conversation_id"] + + try: + conversation = conversations_collection.find_one( + {"_id": ObjectId(conversation_id)} + ) + if conversation is None: + return make_response( + jsonify( + { + "status": "error", + "message": "Conversation does not exist", + } + ), + 404, ) - if pre_existing is not None: - return ( - jsonify( + + current_n_queries = len(conversation["queries"]) + explicit_binary = Binary.from_uuid( + uuid.uuid4(), UuidRepresentation.STANDARD + ) + + if is_promptable.lower() == "true": + prompt_id = data.get("prompt_id", "default") + chunks = data.get("chunks", "2") + + name = conversation["name"] + "(shared)" + new_api_key_data = { + "prompt_id": prompt_id, + "chunks": chunks, + "user": user, + } + + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef( + "sources", ObjectId(data["source"]) + ) + if "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + + pre_existing_api_document = api_key_collection.find_one( + new_api_key_data + ) + if pre_existing_api_document: + api_uuid = pre_existing_api_document["key"] + pre_existing = shared_conversations_collections.find_one( + { + "conversation_id": DBRef( + "conversations", ObjectId(conversation_id) + ), + "isPromptable": is_promptable.lower() == "true", + "first_n_queries": current_n_queries, + "user": user, + "api_key": api_uuid, + } + ) + if pre_existing is not None: + return make_response( + jsonify( + { + "success": True, + "identifier": str(pre_existing["uuid"].as_uuid()), + } + ), + 200, + ) + else: + shared_conversations_collections.insert_one( { - "success": True, - "identifier": str(pre_existing["uuid"].as_uuid()), + "uuid": explicit_binary, + "conversation_id": { + "$ref": "conversations", + "$id": ObjectId(conversation_id), + }, + "isPromptable": is_promptable.lower() == "true", + "first_n_queries": current_n_queries, + "user": user, + "api_key": api_uuid, } - ), - 200, - ) + ) + return make_response( + jsonify( + { + "success": True, + "identifier": str(explicit_binary.as_uuid()), + } + ), + 201, + ) else: + api_uuid = str(uuid.uuid4()) + new_api_key_data["key"] = api_uuid + new_api_key_data["name"] = name + + if "source" in data and ObjectId.is_valid(data["source"]): + new_api_key_data["source"] = DBRef( + "sources", ObjectId(data["source"]) + ) + if "retriever" in data: + new_api_key_data["retriever"] = data["retriever"] + + api_key_collection.insert_one(new_api_key_data) shared_conversations_collections.insert_one( { "uuid": explicit_binary, @@ -557,27 +929,43 @@ def share_conversation(): "$ref": "conversations", "$id": ObjectId(conversation_id), }, - "isPromptable": isPromptable, + 
"isPromptable": is_promptable.lower() == "true", "first_n_queries": current_n_queries, "user": user, "api_key": api_uuid, } ) - return jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} + return make_response( + jsonify( + { + "success": True, + "identifier": str(explicit_binary.as_uuid()), + } + ), + 201, ) - else: - api_uuid = str(uuid.uuid4()) - new_api_key_data["key"] = api_uuid - new_api_key_data["name"] = name - if "source" in data and ObjectId.is_valid(data["source"]): - new_api_key_data["source"] = DBRef( - "sources", ObjectId(data["source"]) - ) - if "retriever" in data: - new_api_key_data["retriever"] = data["retriever"] - api_key_collection.insert_one(new_api_key_data) + pre_existing = shared_conversations_collections.find_one( + { + "conversation_id": DBRef( + "conversations", ObjectId(conversation_id) + ), + "isPromptable": is_promptable.lower() == "false", + "first_n_queries": current_n_queries, + "user": user, + } + ) + if pre_existing is not None: + return make_response( + jsonify( + { + "success": True, + "identifier": str(pre_existing["uuid"].as_uuid()), + } + ), + 200, + ) + else: shared_conversations_collections.insert_one( { "uuid": explicit_binary, @@ -585,80 +973,54 @@ def share_conversation(): "$ref": "conversations", "$id": ObjectId(conversation_id), }, - "isPromptable": isPromptable, + "isPromptable": is_promptable.lower() == "false", "first_n_queries": current_n_queries, "user": user, - "api_key": api_uuid, } ) - ## Identifier as route parameter in frontend - return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), - 201, - ) - - ##isPromptable = False - pre_existing = shared_conversations_collections.find_one( - { - "conversation_id": DBRef("conversations", ObjectId(conversation_id)), - "isPromptable": isPromptable, - "first_n_queries": current_n_queries, - "user": user, - } - ) - if pre_existing is not None: - return ( - jsonify( - {"success": True, "identifier": str(pre_existing["uuid"].as_uuid())} - ), - 200, - ) - else: - shared_conversations_collections.insert_one( - { - "uuid": explicit_binary, - "conversation_id": { - "$ref": "conversations", - "$id": ObjectId(conversation_id), - }, - "isPromptable": isPromptable, - "first_n_queries": current_n_queries, - "user": user, - } - ) - ## Identifier as route parameter in frontend - return ( - jsonify( - {"success": True, "identifier": str(explicit_binary.as_uuid())} - ), - 201, - ) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 + return make_response( + jsonify( + {"success": True, "identifier": str(explicit_binary.as_uuid())} + ), + 201, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) -# route to get publicly shared conversations -@user.route("/api/shared_conversation/", methods=["GET"]) -def get_publicly_shared_conversations(identifier: str): - try: - query_uuid = Binary.from_uuid( - uuid.UUID(identifier), UuidRepresentation.STANDARD - ) - shared = shared_conversations_collections.find_one({"uuid": query_uuid}) - conversation_queries = [] - if ( - shared - and "conversation_id" in shared - and isinstance(shared["conversation_id"], DBRef) - ): - # Resolve the DBRef - conversation_ref = shared["conversation_id"] - conversation = db.dereference(conversation_ref) - if conversation is None: - return ( +@api.route("/api/shared_conversation/") +class GetPubliclySharedConversations(Resource): + @api.doc(description="Get publicly shared conversations 
by identifier") + def get(self, identifier: str): + try: + query_uuid = Binary.from_uuid( + uuid.UUID(identifier), UuidRepresentation.STANDARD + ) + shared = shared_conversations_collections.find_one({"uuid": query_uuid}) + conversation_queries = [] + + if ( + shared + and "conversation_id" in shared + and isinstance(shared["conversation_id"], DBRef) + ): + conversation_ref = shared["conversation_id"] + conversation = db.dereference(conversation_ref) + if conversation is None: + return make_response( + jsonify( + { + "sucess": False, + "error": "might have broken url or the conversation does not exist", + } + ), + 404, + ) + conversation_queries = conversation["queries"][ + : (shared["first_n_queries"]) + ] + else: + return make_response( jsonify( { "sucess": False, @@ -667,470 +1029,577 @@ def get_publicly_shared_conversations(identifier: str): ), 404, ) - conversation_queries = conversation["queries"][ - : (shared["first_n_queries"]) - ] - else: - return ( - jsonify( - { - "sucess": False, - "error": "might have broken url or the conversation does not exist", - } - ), - 404, - ) - date = conversation["_id"].generation_time.isoformat() - res = { - "success": True, - "queries": conversation_queries, - "title": conversation["name"], - "timestamp": date, - } - if shared["isPromptable"] and "api_key" in shared: - res["api_key"] = shared["api_key"] - return jsonify(res), 200 - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - - -@user.route("/api/get_message_analytics", methods=["POST"]) -def get_message_analytics(): - data = request.get_json() - api_key_id = data.get("api_key_id") - filter_option = data.get("filter_option", "last_30_days") - - try: - api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] - if api_key_id - else None - ) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - end_date = datetime.datetime.now(datetime.timezone.utc) - - if filter_option == "last_hour": - start_date = end_date - datetime.timedelta(hours=1) - group_format = "%Y-%m-%d %H:%M:00" - group_stage = { - "$group": { - "_id": { - "minute": { - "$dateToString": {"format": group_format, "date": "$date"} - } - }, - "total_messages": {"$sum": 1}, + date = conversation["_id"].generation_time.isoformat() + res = { + "success": True, + "queries": conversation_queries, + "title": conversation["name"], + "timestamp": date, } - } + if shared["isPromptable"] and "api_key" in shared: + res["api_key"] = shared["api_key"] + return make_response(jsonify(res), 200) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) - elif filter_option == "last_24_hour": - start_date = end_date - datetime.timedelta(hours=24) - group_format = "%Y-%m-%d %H:00" - group_stage = { - "$group": { - "_id": { - "hour": {"$dateToString": {"format": group_format, "date": "$date"}} - }, - "total_messages": {"$sum": 1}, - } - } - else: - if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: - filter_days = ( - 6 - if filter_option == "last_7_days" - else (14 if filter_option == "last_15_days" else 29) +@api.route("/api/get_message_analytics") +class GetMessageAnalytics(Resource): + get_message_analytics_model = api.model( + "GetMessageAnalyticsModel", + { + "api_key_id": fields.String( + required=False, + description="API Key ID", + ), + "filter_option": fields.String( + required=False, + description="Filter option for analytics", + default="last_30_days", + enum=[ 
+ "last_hour", + "last_24_hour", + "last_7_days", + "last_15_days", + "last_30_days", + ], + ), + }, + ) + + @api.expect(get_message_analytics_model) + @api.doc(description="Get message analytics based on filter option") + def post(self): + data = request.get_json() + api_key_id = data.get("api_key_id") + filter_option = data.get("filter_option", "last_30_days") + + try: + api_key = ( + api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + if api_key_id + else None ) - else: - return jsonify({"success": False, "error": "Invalid option"}), 400 - start_date = end_date - datetime.timedelta(days=filter_days) - start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) - group_format = "%Y-%m-%d" - group_stage = { - "$group": { - "_id": { - "day": {"$dateToString": {"format": group_format, "date": "$date"}} - }, - "total_messages": {"$sum": 1}, - } - } + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + end_date = datetime.datetime.now(datetime.timezone.utc) - try: - match_stage = { - "$match": { - "date": {"$gte": start_date, "$lte": end_date}, + if filter_option == "last_hour": + start_date = end_date - datetime.timedelta(hours=1) + group_format = "%Y-%m-%d %H:%M:00" + group_stage = { + "$group": { + "_id": { + "minute": { + "$dateToString": {"format": group_format, "date": "$date"} + } + }, + "total_messages": {"$sum": 1}, + } } - } - if api_key: - match_stage["$match"]["api_key"] = api_key - message_data = conversations_collection.aggregate( - [ - match_stage, - group_stage, - {"$sort": {"_id": 1}}, - ] - ) - if filter_option == "last_hour": - intervals = generate_minute_range(start_date, end_date) elif filter_option == "last_24_hour": - intervals = generate_hourly_range(start_date, end_date) + start_date = end_date - datetime.timedelta(hours=24) + group_format = "%Y-%m-%d %H:00" + group_stage = { + "$group": { + "_id": { + "hour": { + "$dateToString": {"format": group_format, "date": "$date"} + } + }, + "total_messages": {"$sum": 1}, + } + } + else: - intervals = generate_date_range(start_date, end_date) + if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: + filter_days = ( + 6 + if filter_option == "last_7_days" + else (14 if filter_option == "last_15_days" else 29) + ) + else: + return make_response( + jsonify({"success": False, "message": "Invalid option"}), 400 + ) + start_date = end_date - datetime.timedelta(days=filter_days) + start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) + end_date = end_date.replace( + hour=23, minute=59, second=59, microsecond=999999 + ) + group_format = "%Y-%m-%d" + group_stage = { + "$group": { + "_id": { + "day": { + "$dateToString": {"format": group_format, "date": "$date"} + } + }, + "total_messages": {"$sum": 1}, + } + } - daily_messages = {interval: 0 for interval in intervals} + try: + match_stage = { + "$match": { + "date": {"$gte": start_date, "$lte": end_date}, + } + } + if api_key: + match_stage["$match"]["api_key"] = api_key + message_data = conversations_collection.aggregate( + [ + match_stage, + group_stage, + {"$sort": {"_id": 1}}, + ] + ) - for entry in message_data: if filter_option == "last_hour": - daily_messages[entry["_id"]["minute"]] = entry["total_messages"] + intervals = generate_minute_range(start_date, end_date) elif filter_option == "last_24_hour": - daily_messages[entry["_id"]["hour"]] = entry["total_messages"] + intervals = 
generate_hourly_range(start_date, end_date) else: - daily_messages[entry["_id"]["day"]] = entry["total_messages"] + intervals = generate_date_range(start_date, end_date) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - - return jsonify({"success": True, "messages": daily_messages}), 200 + daily_messages = {interval: 0 for interval in intervals} + for entry in message_data: + if filter_option == "last_hour": + daily_messages[entry["_id"]["minute"]] = entry["total_messages"] + elif filter_option == "last_24_hour": + daily_messages[entry["_id"]["hour"]] = entry["total_messages"] + else: + daily_messages[entry["_id"]["day"]] = entry["total_messages"] -@user.route("/api/get_token_analytics", methods=["POST"]) -def get_token_analytics(): - data = request.get_json() - api_key_id = data.get("api_key_id") - filter_option = data.get("filter_option", "last_30_days") + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) - try: - api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] - if api_key_id - else None + return make_response( + jsonify({"success": True, "messages": daily_messages}), 200 ) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - end_date = datetime.datetime.now(datetime.timezone.utc) - - if filter_option == "last_hour": - start_date = end_date - datetime.timedelta(hours=1) - group_format = "%Y-%m-%d %H:%M:00" - group_stage = { - "$group": { - "_id": { - "minute": { - "$dateToString": {"format": group_format, "date": "$timestamp"} - } - }, - "total_tokens": { - "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} - }, - } - } - elif filter_option == "last_24_hour": - start_date = end_date - datetime.timedelta(hours=24) - group_format = "%Y-%m-%d %H:00" - group_stage = { - "$group": { - "_id": { - "hour": { - "$dateToString": {"format": group_format, "date": "$timestamp"} - } - }, - "total_tokens": { - "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} - }, - } - } - else: - if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: - filter_days = ( - 6 - if filter_option == "last_7_days" - else (14 if filter_option == "last_15_days" else 29) - ) - else: - return jsonify({"success": False, "error": "Invalid option"}), 400 - start_date = end_date - datetime.timedelta(days=filter_days) - start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) - group_format = "%Y-%m-%d" - group_stage = { - "$group": { - "_id": { - "day": { - "$dateToString": {"format": group_format, "date": "$timestamp"} - } - }, - "total_tokens": { - "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} - }, - } - } +@api.route("/api/get_token_analytics") +class GetTokenAnalytics(Resource): + get_token_analytics_model = api.model( + "GetTokenAnalyticsModel", + { + "api_key_id": fields.String(required=False, description="API Key ID"), + "filter_option": fields.String( + required=False, + description="Filter option for analytics", + default="last_30_days", + enum=[ + "last_hour", + "last_24_hour", + "last_7_days", + "last_15_days", + "last_30_days", + ], + ), + }, + ) - try: - match_stage = { - "$match": { - "timestamp": {"$gte": start_date, "$lte": end_date}, - } - } - if api_key: - match_stage["$match"]["api_key"] = api_key + @api.expect(get_token_analytics_model) + @api.doc(description="Get token analytics data") + 
def post(self): + data = request.get_json() + api_key_id = data.get("api_key_id") + filter_option = data.get("filter_option", "last_30_days") - token_usage_data = token_usage_collection.aggregate( - [ - match_stage, - group_stage, - {"$sort": {"_id": 1}}, - ] - ) + try: + api_key = ( + api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + if api_key_id + else None + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + end_date = datetime.datetime.now(datetime.timezone.utc) if filter_option == "last_hour": - intervals = generate_minute_range(start_date, end_date) + start_date = end_date - datetime.timedelta(hours=1) + group_format = "%Y-%m-%d %H:%M:00" + group_stage = { + "$group": { + "_id": { + "minute": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + } + }, + "total_tokens": { + "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} + }, + } + } + elif filter_option == "last_24_hour": - intervals = generate_hourly_range(start_date, end_date) + start_date = end_date - datetime.timedelta(hours=24) + group_format = "%Y-%m-%d %H:00" + group_stage = { + "$group": { + "_id": { + "hour": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + } + }, + "total_tokens": { + "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} + }, + } + } + else: - intervals = generate_date_range(start_date, end_date) + if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: + filter_days = ( + 6 + if filter_option == "last_7_days" + else (14 if filter_option == "last_15_days" else 29) + ) + else: + return make_response( + jsonify({"success": False, "message": "Invalid option"}), 400 + ) + start_date = end_date - datetime.timedelta(days=filter_days) + start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) + end_date = end_date.replace( + hour=23, minute=59, second=59, microsecond=999999 + ) + group_format = "%Y-%m-%d" + group_stage = { + "$group": { + "_id": { + "day": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + } + }, + "total_tokens": { + "$sum": {"$add": ["$prompt_tokens", "$generated_tokens"]} + }, + } + } - daily_token_usage = {interval: 0 for interval in intervals} + try: + match_stage = { + "$match": { + "timestamp": {"$gte": start_date, "$lte": end_date}, + } + } + if api_key: + match_stage["$match"]["api_key"] = api_key + + token_usage_data = token_usage_collection.aggregate( + [ + match_stage, + group_stage, + {"$sort": {"_id": 1}}, + ] + ) - for entry in token_usage_data: if filter_option == "last_hour": - daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"] + intervals = generate_minute_range(start_date, end_date) elif filter_option == "last_24_hour": - daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"] + intervals = generate_hourly_range(start_date, end_date) else: - daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"] + intervals = generate_date_range(start_date, end_date) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - - return jsonify({"success": True, "token_usage": daily_token_usage}), 200 + daily_token_usage = {interval: 0 for interval in intervals} + for entry in token_usage_data: + if filter_option == "last_hour": + daily_token_usage[entry["_id"]["minute"]] = entry["total_tokens"] + elif filter_option == "last_24_hour": + daily_token_usage[entry["_id"]["hour"]] = entry["total_tokens"] + else: + 
daily_token_usage[entry["_id"]["day"]] = entry["total_tokens"] -@user.route("/api/get_feedback_analytics", methods=["POST"]) -def get_feedback_analytics(): - data = request.get_json() - api_key_id = data.get("api_key_id") - filter_option = data.get("filter_option", "last_30_days") + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) - try: - api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] - if api_key_id - else None + return make_response( + jsonify({"success": True, "token_usage": daily_token_usage}), 200 ) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - end_date = datetime.datetime.now(datetime.timezone.utc) - - if filter_option == "last_hour": - start_date = end_date - datetime.timedelta(hours=1) - group_format = "%Y-%m-%d %H:%M:00" - group_stage_1 = { - "$group": { - "_id": { - "minute": { - "$dateToString": {"format": group_format, "date": "$timestamp"} + + +@api.route("/api/get_feedback_analytics") +class GetFeedbackAnalytics(Resource): + get_feedback_analytics_model = api.model( + "GetFeedbackAnalyticsModel", + { + "api_key_id": fields.String(required=False, description="API Key ID"), + "filter_option": fields.String( + required=False, + description="Filter option for analytics", + default="last_30_days", + enum=[ + "last_hour", + "last_24_hour", + "last_7_days", + "last_15_days", + "last_30_days", + ], + ), + }, + ) + + @api.expect(get_feedback_analytics_model) + @api.doc(description="Get feedback analytics data") + def post(self): + data = request.get_json() + api_key_id = data.get("api_key_id") + filter_option = data.get("filter_option", "last_30_days") + + try: + api_key = ( + api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + if api_key_id + else None + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + end_date = datetime.datetime.now(datetime.timezone.utc) + + if filter_option == "last_hour": + start_date = end_date - datetime.timedelta(hours=1) + group_format = "%Y-%m-%d %H:%M:00" + group_stage_1 = { + "$group": { + "_id": { + "minute": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + }, + "feedback": "$feedback", }, - "feedback": "$feedback", - }, - "count": {"$sum": 1}, + "count": {"$sum": 1}, + } } - } - group_stage_2 = { - "$group": { - "_id": "$_id.minute", - "likes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "LIKE"]}, - "$count", - 0, - ] - } - }, - "dislikes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "DISLIKE"]}, - "$count", - 0, - ] - } - }, + group_stage_2 = { + "$group": { + "_id": "$_id.minute", + "likes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "LIKE"]}, + "$count", + 0, + ] + } + }, + "dislikes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "DISLIKE"]}, + "$count", + 0, + ] + } + }, + } } - } - elif filter_option == "last_24_hour": - start_date = end_date - datetime.timedelta(hours=24) - group_format = "%Y-%m-%d %H:00" - group_stage_1 = { - "$group": { - "_id": { - "hour": { - "$dateToString": {"format": group_format, "date": "$timestamp"} + elif filter_option == "last_24_hour": + start_date = end_date - datetime.timedelta(hours=24) + group_format = "%Y-%m-%d %H:00" + group_stage_1 = { + "$group": { + "_id": { + "hour": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + }, + "feedback": "$feedback", }, - "feedback": "$feedback", - }, - "count": 
{"$sum": 1}, + "count": {"$sum": 1}, + } } - } - group_stage_2 = { - "$group": { - "_id": "$_id.hour", - "likes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "LIKE"]}, - "$count", - 0, - ] - } - }, - "dislikes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "DISLIKE"]}, - "$count", - 0, - ] - } - }, + group_stage_2 = { + "$group": { + "_id": "$_id.hour", + "likes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "LIKE"]}, + "$count", + 0, + ] + } + }, + "dislikes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "DISLIKE"]}, + "$count", + 0, + ] + } + }, + } } - } - else: - if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: - filter_days = ( - 6 - if filter_option == "last_7_days" - else (14 if filter_option == "last_15_days" else 29) - ) else: - return jsonify({"success": False, "error": "Invalid option"}), 400 - start_date = end_date - datetime.timedelta(days=filter_days) - start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) - end_date = end_date.replace(hour=23, minute=59, second=59, microsecond=999999) - group_format = "%Y-%m-%d" - group_stage_1 = { - "$group": { - "_id": { - "day": { - "$dateToString": {"format": group_format, "date": "$timestamp"} + if filter_option in ["last_7_days", "last_15_days", "last_30_days"]: + filter_days = ( + 6 + if filter_option == "last_7_days" + else (14 if filter_option == "last_15_days" else 29) + ) + else: + return make_response( + jsonify({"success": False, "message": "Invalid option"}), 400 + ) + start_date = end_date - datetime.timedelta(days=filter_days) + start_date = start_date.replace(hour=0, minute=0, second=0, microsecond=0) + end_date = end_date.replace( + hour=23, minute=59, second=59, microsecond=999999 + ) + group_format = "%Y-%m-%d" + group_stage_1 = { + "$group": { + "_id": { + "day": { + "$dateToString": { + "format": group_format, + "date": "$timestamp", + } + }, + "feedback": "$feedback", }, - "feedback": "$feedback", - }, - "count": {"$sum": 1}, + "count": {"$sum": 1}, + } } - } - group_stage_2 = { - "$group": { - "_id": "$_id.day", - "likes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "LIKE"]}, - "$count", - 0, - ] - } - }, - "dislikes": { - "$sum": { - "$cond": [ - {"$eq": ["$_id.feedback", "DISLIKE"]}, - "$count", - 0, - ] - } - }, + group_stage_2 = { + "$group": { + "_id": "$_id.day", + "likes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "LIKE"]}, + "$count", + 0, + ] + } + }, + "dislikes": { + "$sum": { + "$cond": [ + {"$eq": ["$_id.feedback", "DISLIKE"]}, + "$count", + 0, + ] + } + }, + } } - } - try: - match_stage = { - "$match": { - "timestamp": {"$gte": start_date, "$lte": end_date}, + try: + match_stage = { + "$match": { + "timestamp": {"$gte": start_date, "$lte": end_date}, + } } - } - if api_key: - match_stage["$match"]["api_key"] = api_key - - feedback_data = feedback_collection.aggregate( - [ - match_stage, - group_stage_1, - group_stage_2, - {"$sort": {"_id": 1}}, - ] - ) - - if filter_option == "last_hour": - intervals = generate_minute_range(start_date, end_date) - elif filter_option == "last_24_hour": - intervals = generate_hourly_range(start_date, end_date) - else: - intervals = generate_date_range(start_date, end_date) + if api_key: + match_stage["$match"]["api_key"] = api_key + + feedback_data = feedback_collection.aggregate( + [ + match_stage, + group_stage_1, + group_stage_2, + {"$sort": {"_id": 1}}, + ] + ) - daily_feedback = { - interval: {"positive": 0, "negative": 0} for interval in intervals - } + if 
filter_option == "last_hour": + intervals = generate_minute_range(start_date, end_date) + elif filter_option == "last_24_hour": + intervals = generate_hourly_range(start_date, end_date) + else: + intervals = generate_date_range(start_date, end_date) - for entry in feedback_data: - daily_feedback[entry["_id"]] = { - "positive": entry["likes"], - "negative": entry["dislikes"], + daily_feedback = { + interval: {"positive": 0, "negative": 0} for interval in intervals } - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 + for entry in feedback_data: + daily_feedback[entry["_id"]] = { + "positive": entry["likes"], + "negative": entry["dislikes"], + } - return jsonify({"success": True, "feedback": daily_feedback}), 200 + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response( + jsonify({"success": True, "feedback": daily_feedback}), 200 + ) -@user.route("/api/get_user_logs", methods=["POST"]) -def get_user_logs(): - data = request.get_json() - page = int(data.get("page", 1)) - api_key_id = data.get("api_key_id") - page_size = int(data.get("page_size", 10)) - skip = (page - 1) * page_size - try: - api_key = ( - api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] - if api_key_id - else None - ) - except Exception as err: - print(err) - return jsonify({"success": False, "error": str(err)}), 400 - - query = {} - if api_key: - query = {"api_key": api_key} - items_cursor = ( - user_logs_collection.find(query) - .sort("timestamp", -1) - .skip(skip) - .limit(page_size + 1) +@api.route("/api/get_user_logs") +class GetUserLogs(Resource): + get_user_logs_model = api.model( + "GetUserLogsModel", + { + "page": fields.Integer( + required=False, + description="Page number for pagination", + default=1, + ), + "api_key_id": fields.String(required=False, description="API Key ID"), + "page_size": fields.Integer( + required=False, + description="Number of logs per page", + default=10, + ), + }, ) - items = list(items_cursor) - results = [] - for item in items[:page_size]: - results.append( + @api.expect(get_user_logs_model) + @api.doc(description="Get user logs with pagination") + def post(self): + data = request.get_json() + page = int(data.get("page", 1)) + api_key_id = data.get("api_key_id") + page_size = int(data.get("page_size", 10)) + skip = (page - 1) * page_size + + try: + api_key = ( + api_key_collection.find_one({"_id": ObjectId(api_key_id)})["key"] + if api_key_id + else None + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + query = {} + if api_key: + query = {"api_key": api_key} + + items_cursor = ( + user_logs_collection.find(query) + .sort("timestamp", -1) + .skip(skip) + .limit(page_size + 1) + ) + items = list(items_cursor) + + results = [ { "id": str(item.get("_id")), "action": item.get("action"), @@ -1141,42 +1610,65 @@ def get_user_logs(): "retriever_params": item.get("retriever_params"), "timestamp": item.get("timestamp"), } + for item in items[:page_size] + ] + + has_more = len(items) > page_size + + return make_response( + jsonify( + { + "success": True, + "logs": results, + "page": page, + "page_size": page_size, + "has_more": has_more, + } + ), + 200, ) - has_more = len(items) > page_size - return ( - jsonify( - { - "success": True, - "logs": results, - "page": page, - "page_size": page_size, - "has_more": has_more, - } - ), - 200, + +@api.route("/api/manage_sync") +class ManageSync(Resource): + 
manage_sync_model = api.model( + "ManageSyncModel", + { + "source_id": fields.String(required=True, description="Source ID"), + "sync_frequency": fields.String( + required=True, + description="Sync frequency (never, daily, weekly, monthly)", + ), + }, ) + @api.expect(manage_sync_model) + @api.doc(description="Manage sync frequency for sources") + def post(self): + data = request.get_json() + required_fields = ["source_id", "sync_frequency"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields -@user.route("/api/manage_sync", methods=["POST"]) -def manage_sync(): - data = request.get_json() - source_id = data.get("source_id") - sync_frequency = data.get("sync_frequency") + source_id = data["source_id"] + sync_frequency = data["sync_frequency"] - if sync_frequency not in ["never", "daily", "weekly", "monthly"]: - return jsonify({"status": "invalid frequency"}), 400 + if sync_frequency not in ["never", "daily", "weekly", "monthly"]: + return make_response( + jsonify({"success": False, "message": "Invalid frequency"}), 400 + ) - update_data = {"$set": {"sync_frequency": sync_frequency}} - try: - sources_collection.update_one( - { - "_id": ObjectId(source_id), - "user": "local", - }, - update_data, - ) - except Exception as err: - print(err) - return jsonify({"status": "error"}), 400 - return jsonify({"status": "ok"}), 200 + update_data = {"$set": {"sync_frequency": sync_frequency}} + try: + sources_collection.update_one( + { + "_id": ObjectId(source_id), + "user": "local", + }, + update_data, + ) + except Exception as err: + return make_response(jsonify({"success": False, "error": str(err)}), 400) + + return make_response(jsonify({"success": True}), 200) diff --git a/application/requirements.txt b/application/requirements.txt index 023d2ef32..d7621cfdd 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -13,6 +13,7 @@ esprima==4.0.1 esutils==1.0.1 Flask==3.0.3 faiss-cpu==1.8.0.post1 +flask-restx==1.3.0 gunicorn==23.0.0 html2text==2024.2.26 javalang==0.13.0 diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx index 3daa39117..07f33650a 100644 --- a/frontend/src/components/Dropdown.tsx +++ b/frontend/src/components/Dropdown.tsx @@ -82,12 +82,12 @@ function Dropdown({ }`} > {typeof selectedValue === 'string' ? ( - + {selectedValue} ) : ( diff --git a/frontend/src/modals/CreateAPIKeyModal.tsx b/frontend/src/modals/CreateAPIKeyModal.tsx index 71d86330c..eb085a28c 100644 --- a/frontend/src/modals/CreateAPIKeyModal.tsx +++ b/frontend/src/modals/CreateAPIKeyModal.tsx @@ -97,14 +97,13 @@ export default function CreateAPIKeyModal({
{ setSourcePath(selection); - console.log(selection); }} options={extractDocPaths()} size="w-full" diff --git a/frontend/src/settings/Analytics.tsx b/frontend/src/settings/Analytics.tsx index a385c4713..5ddab2cbf 100644 --- a/frontend/src/settings/Analytics.tsx +++ b/frontend/src/settings/Analytics.tsx @@ -181,8 +181,8 @@ export default function Analytics() { border="border" />
Note: the bodies of the two frontend/src/settings/Analytics.tsx hunks (at lines 181 and 227) were garbled during extraction, leaving only bare +/- markers; the surviving fragments indicate wrapper-markup changes around the "Messages" chart and the chart titled
Token Usage From e8988e82d00b57e37ec77b278d9516287ce35bf6 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 30 Sep 2024 00:41:34 +0530 Subject: [PATCH 2/3] refactor: answer routes to comply with OpenAPI spec using flask-restx --- application/api/answer/routes.py | 620 ++++++++++++++++++------------- application/api/user/routes.py | 85 ++--- application/app.py | 25 +- application/extensions.py | 7 + application/utils.py | 21 +- 5 files changed, 429 insertions(+), 329 deletions(-) create mode 100644 application/extensions.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 6f0315152..35b95174d 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -1,20 +1,24 @@ import asyncio -import os -import sys -from flask import Blueprint, request, Response, current_app -import json import datetime +import json import logging +import os +import sys import traceback -from pymongo import MongoClient -from bson.objectid import ObjectId from bson.dbref import DBRef +from bson.objectid import ObjectId +from flask import Blueprint, current_app, make_response, request, Response +from flask_restx import fields, Namespace, Resource + +from pymongo import MongoClient from application.core.settings import settings +from application.error import bad_request +from application.extensions import api from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator -from application.error import bad_request +from application.utils import check_required_fields logger = logging.getLogger(__name__) @@ -25,7 +29,10 @@ prompts_collection = db["prompts"] api_key_collection = db["api_keys"] user_logs_collection = db["user_logs"] + answer = Blueprint("answer", __name__) +answer_ns = Namespace("answer", description="Answer related operations", path="/") +api.add_namespace(answer_ns) gpt_model = "" # to have some kind of default behaviour @@ -186,10 +193,10 @@ def complete_stream( answer = retriever.gen() sources = retriever.search() for source in sources: - if("text" in source): - source["text"] = source["text"][:100].strip()+"..." - if(len(sources) > 0): - data = json.dumps({"type":"source","source":sources}) + if "text" in source: + source["text"] = source["text"][:100].strip() + "..." 
+ if len(sources) > 0: + data = json.dumps({"type": "source", "source": sources}) yield f"data: {data}\n\n" for line in answer: if "answer" in line: @@ -243,109 +250,133 @@ def complete_stream( return -@answer.route("/stream", methods=["POST"]) -def stream(): - try: +@answer_ns.route("/stream") +class Stream(Resource): + stream_model = api.model( + "StreamModel", + { + "question": fields.String( + required=True, description="Question to be asked" + ), + "history": fields.List( + fields.String, required=False, description="Chat history" + ), + "conversation_id": fields.String( + required=False, description="Conversation ID" + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "selectedDocs": fields.String( + required=False, description="Selected documents" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, + ) + + @api.expect(stream_model) + @api.doc(description="Stream a response based on the question and retriever") + def post(self): data = request.get_json() - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + history = data.get("history", []) history = json.loads(history) - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "selectedDocs" in data and data["selectedDocs"] is None: - chunks = 0 - elif "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - - ## retriever can be "brave_search, duckduck_search or classic" - retriever_name = data["retriever"] if "retriever" in data else "classic" - - # check if active_docs or api_key is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] or retriever_name - user_api_key = data["api_key"] - - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - - else: - source = {} - user_api_key = None - - current_app.logger.info( - f"/stream - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) + conversation_id = data.get("conversation_id") + prompt_id = data.get("prompt_id", "default") + if "selectedDocs" in data and data["selectedDocs"] is None: + chunks = 0 + else: + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = 
get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + prompt_id = data_key.get("prompt_id", "default") + source = {"active_docs": data_key.get("source")} + retriever_name = data_key.get("retriever", retriever_name) + user_api_key = data["api_key"] + + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name + user_api_key = None + + else: + source = {} + user_api_key = None + + current_app.logger.info( + f"/stream - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - prompt = get_prompt(prompt_id) - - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) + prompt = get_prompt(prompt_id) - return Response( - complete_stream( + retriever = RetrieverCreator.create_retriever( + retriever_name, question=question, - retriever=retriever, - conversation_id=conversation_id, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, user_api_key=user_api_key, - isNoneDoc=data.get("isNoneDoc"), - ), - mimetype="text/event-stream", - ) + ) - except ValueError: - message = "Malformed request body" - print("\033[91merr", str(message), file=sys.stderr) - return Response( - error_stream_generate(message), - status=400, - mimetype="text/event-stream", - ) - except Exception as e: - current_app.logger.error( - f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - message = e.args[0] - status_code = 400 - # # Custom exceptions with two arguments, index 1 as status code - if len(e.args) >= 2: - status_code = e.args[1] - return Response( - error_stream_generate(message), - status=status_code, - mimetype="text/event-stream", - ) + return Response( + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + isNoneDoc=data.get("isNoneDoc"), + ), + mimetype="text/event-stream", + ) + + except ValueError: + message = "Malformed request body" + print("\033[91merr", str(message), file=sys.stderr) + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: + current_app.logger.error( + f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + message = e.args[0] + status_code = 400 + # Custom exceptions with two arguments, index 1 as status code + if len(e.args) >= 2: + status_code = e.args[1] + return Response( + error_stream_generate(message), + status=status_code, + mimetype="text/event-stream", + ) def error_stream_generate(err_response): @@ -353,180 +384,235 @@ def error_stream_generate(err_response): yield f"data: {data}\n\n" -@answer.route("/api/answer", methods=["POST"]) -def api_answer(): - data = request.get_json() - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - print("-" * 5) - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "chunks" in data: - chunks = 
int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - - ## retriever can be brave_search, duckduck_search or classic - retriever_name = data["retriever"] if "retriever" in data else "classic" +@answer_ns.route("/api/answer") +class Answer(Resource): + answer_model = api.model( + "AnswerModel", + { + "question": fields.String( + required=True, description="The question to answer" + ), + "history": fields.List( + fields.String, required=False, description="Conversation history" + ), + "conversation_id": fields.String( + required=False, description="Conversation ID" + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, + ) - # use try and except to check for exception - try: - # check if the vectorstore is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] or retriever_name - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - else: - source = {} - user_api_key = None - - prompt = get_prompt(prompt_id) - - current_app.logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) + @api.expect(answer_model) + @api.doc(description="Provide an answer based on the question and retriever") + def post(self): + data = request.get_json() + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + history = data.get("history", []) + conversation_id = data.get("conversation_id") + prompt_id = data.get("prompt_id", "default") + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + prompt_id = data_key.get("prompt_id", "default") + source = {"active_docs": data_key.get("source")} + retriever_name = data_key.get("retriever", retriever_name) + user_api_key = data["api_key"] + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name + user_api_key = None + else: + source = {} + user_api_key = None + + prompt = get_prompt(prompt_id) + + current_app.logger.info( + f"/api/answer - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - 
chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) - source_log_docs = [] - response_full = "" - for line in retriever.gen(): - if "source" in line: - source_log_docs.append(line["source"]) - elif "answer" in line: - response_full += line["answer"] + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) - if data.get("isNoneDoc"): - for doc in source_log_docs: - doc["source"] = "None" + source_log_docs = [] + response_full = "" + for line in retriever.gen(): + if "source" in line: + source_log_docs.append(line["source"]) + elif "answer" in line: + response_full += line["answer"] - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) + if data.get("isNoneDoc"): + for doc in source_log_docs: + doc["source"] = "None" - result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = str( - save_conversation( - conversation_id, question, response_full, source_log_docs, llm + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key ) - ) - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "api_answer", - "level": "info", - "user": "local", - "api_key": user_api_key, - "question": question, - "response": response_full, - "sources": source_log_docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) - return result - except Exception as e: - current_app.logger.error( - f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - return bad_request(500, str(e)) + result = {"answer": response_full, "sources": source_log_docs} + result["conversation_id"] = str( + save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) + ) + retriever_params = retriever.get_params() + user_logs_collection.insert_one( + { + "action": "api_answer", + "level": "info", + "user": "local", + "api_key": user_api_key, + "question": question, + "response": response_full, + "sources": source_log_docs, + "retriever_params": retriever_params, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + ) + except Exception as e: + current_app.logger.error( + f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return bad_request(500, str(e)) -@answer.route("/api/search", methods=["POST"]) -def api_search(): - data = request.get_json() - question = data["question"] - if "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - source = {"active_docs":data_key["source"]} - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - user_api_key = None - else: - source = {} - user_api_key = None + return make_response(result, 200) - if "retriever" in data: - retriever_name = data["retriever"] - else: - retriever_name = "classic" - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - 
current_app.logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, +@answer_ns.route("/api/search") +class Search(Resource): + search_model = api.model( + "SearchModel", + { + "question": fields.String( + required=True, description="The question to search" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "api_key": fields.String( + required=False, description="API key for authentication" + ), + "active_docs": fields.String( + required=False, description="Active documents for retrieval" + ), + "retriever": fields.String(required=False, description="Retriever type"), + "token_limit": fields.Integer( + required=False, description="Limit for tokens" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, ) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=[], - prompt="default", - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, + @api.expect(search_model) + @api.doc( + description="Search for relevant documents based on the question and retriever" ) - docs = retriever.search() + def post(self): + data = request.get_json() + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + source = {"active_docs": data_key.get("source")} + user_api_key = data["api_key"] + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + user_api_key = None + else: + source = {} + user_api_key = None + + current_app.logger.info( + f"/api/answer - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "api_search", - "level": "info", - "user": "local", - "api_key": user_api_key, - "question": question, - "sources": docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=[], + prompt="default", + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) - if data.get("isNoneDoc"): - for doc in docs: - doc["source"] = "None" + docs = retriever.search() + retriever_params = retriever.get_params() + + user_logs_collection.insert_one( + { + "action": "api_search", + "level": "info", + "user": "local", + "api_key": user_api_key, + "question": question, + "sources": docs, + "retriever_params": retriever_params, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + ) + + if data.get("isNoneDoc"): + for doc in docs: + doc["source"] = "None" + + except Exception as e: + current_app.logger.error( + f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return bad_request(500, str(e)) - return docs + return make_response(docs, 200) 
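With the answer endpoints now expressed as flask-restx Resources, clients call them exactly as before. A minimal client sketch follows, assuming the backend from application/app.py is serving on localhost:7091 and a source has already been ingested; the endpoint paths and response keys come from the handlers above, while the example question and the requests dependency are illustrative:

import json

import requests

BASE = "http://localhost:7091"  # port passed to app.run(...) in application/app.py

# Blocking call: /api/answer returns "answer", "sources", and "conversation_id",
# matching the result dict assembled in Answer.post(). "history" is sent as a
# JSON-encoded string, which is the form the /stream handler's json.loads() expects.
payload = {"question": "What is DocsGPT?", "history": "[]", "chunks": 2}
result = requests.post(f"{BASE}/api/answer", json=payload).json()
print(result["answer"], result["conversation_id"])

# Streaming call: /stream responds with text/event-stream, one "data: <json>\n\n"
# frame at a time as emitted by complete_stream(), e.g. a
# {"type": "source", "source": [...]} frame followed by answer chunks.
with requests.post(f"{BASE}/stream", json=payload, stream=True) as resp:
    for raw in resp.iter_lines():
        if raw.startswith(b"data: "):
            print(json.loads(raw[len(b"data: "):]))
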
diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 0bc11a1be..657b3673e 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -7,13 +7,15 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, jsonify, make_response, request -from flask_restx import Api, fields, Namespace, Resource +from flask_restx import fields, Namespace, Resource from pymongo import MongoClient from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote from application.core.settings import settings +from application.extensions import api +from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator mongo = MongoClient(settings.MONGO_URI) @@ -28,14 +30,8 @@ user_logs_collection = db["user_logs"] user = Blueprint("user", __name__) -api = Api( - user, - version="1.0", - title="DocsGPT API", - description="API for DocsGPT", - default="user", - default_label="User operations", -) +user_ns = Namespace("user", description="User related operations", path="/") +api.add_namespace(user_ns) current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -63,22 +59,7 @@ def generate_date_range(start_date, end_date): } -def check_required_fields(data, required_fields): - missing_fields = [field for field in required_fields if field not in data] - if missing_fields: - return make_response( - jsonify( - { - "success": False, - "message": f"Missing fields: {', '.join(missing_fields)}", - } - ), - 400, - ) - return None - - -@api.route("/api/delete_conversation") +@user_ns.route("/api/delete_conversation") class DeleteConversation(Resource): @api.doc( description="Deletes a conversation by ID", @@ -98,7 +79,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/delete_all_conversations") +@user_ns.route("/api/delete_all_conversations") class DeleteAllConversations(Resource): @api.doc( description="Deletes all conversations for a specific user", @@ -112,7 +93,7 @@ def get(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/get_conversations") +@user_ns.route("/api/get_conversations") class GetConversations(Resource): @api.doc( description="Retrieve a list of the latest 30 conversations", @@ -129,7 +110,7 @@ def get(self): return make_response(jsonify(list_conversations), 200) -@api.route("/api/get_single_conversation") +@user_ns.route("/api/get_single_conversation") class GetSingleConversation(Resource): @api.doc( description="Retrieve a single conversation by ID", @@ -153,7 +134,7 @@ def get(self): return make_response(jsonify(conversation["queries"]), 200) -@api.route("/api/update_conversation_name") +@user_ns.route("/api/update_conversation_name") class UpdateConversationName(Resource): @api.expect( api.model( @@ -186,7 +167,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/feedback") +@user_ns.route("/api/feedback") class SubmitFeedback(Resource): @api.expect( api.model( @@ -229,7 +210,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/delete_by_ids") +@user_ns.route("/api/delete_by_ids") class DeleteByIds(Resource): @api.doc( description="Deletes documents from the vector store by IDs", @@ -252,7 +233,7 @@ def get(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/delete_old") 
+@user_ns.route("/api/delete_old") class DeleteOldIndexes(Resource): @api.doc( description="Deletes old indexes", @@ -289,7 +270,7 @@ def get(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/upload") +@user_ns.route("/api/upload") class UploadFile(Resource): @api.expect( api.model( @@ -370,7 +351,7 @@ def post(self): return make_response(jsonify({"success": True, "task_id": task.id}), 200) -@api.route("/api/remote") +@user_ns.route("/api/remote") class UploadRemote(Resource): @api.expect( api.model( @@ -408,7 +389,7 @@ def post(self): return make_response(jsonify({"success": True, "task_id": task.id}), 200) -@api.route("/api/task_status") +@user_ns.route("/api/task_status") class TaskStatus(Resource): task_status_model = api.model( "TaskStatusModel", @@ -435,7 +416,7 @@ def get(self): return make_response(jsonify({"status": task.status, "result": task_meta}), 200) -@api.route("/api/combine") +@user_ns.route("/api/combine") class CombinedJson(Resource): @api.doc(description="Provide JSON file with combined available indexes") def get(self): @@ -496,7 +477,7 @@ def get(self): return make_response(jsonify(data), 200) -@api.route("/api/docs_check") +@user_ns.route("/api/docs_check") class CheckDocs(Resource): check_docs_model = api.model( "CheckDocsModel", @@ -522,7 +503,7 @@ def post(self): return make_response(jsonify({"status": "not found"}), 404) -@api.route("/api/create_prompt") +@user_ns.route("/api/create_prompt") class CreatePrompt(Resource): create_prompt_model = api.model( "CreatePromptModel", @@ -560,7 +541,7 @@ def post(self): return make_response(jsonify({"id": new_id}), 200) -@api.route("/api/get_prompts") +@user_ns.route("/api/get_prompts") class GetPrompts(Resource): @api.doc(description="Get all prompts for the user") def get(self): @@ -587,7 +568,7 @@ def get(self): return make_response(jsonify(list_prompts), 200) -@api.route("/api/get_single_prompt") +@user_ns.route("/api/get_single_prompt") class GetSinglePrompt(Resource): @api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID") def get(self): @@ -628,7 +609,7 @@ def get(self): return make_response(jsonify({"content": prompt["content"]}), 200) -@api.route("/api/delete_prompt") +@user_ns.route("/api/delete_prompt") class DeletePrompt(Resource): delete_prompt_model = api.model( "DeletePromptModel", @@ -652,7 +633,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/update_prompt") +@user_ns.route("/api/update_prompt") class UpdatePrompt(Resource): update_prompt_model = api.model( "UpdatePromptModel", @@ -685,7 +666,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/get_api_keys") +@user_ns.route("/api/get_api_keys") class GetApiKeys(Resource): @api.doc(description="Retrieve API keys for the user") def get(self): @@ -719,7 +700,7 @@ def get(self): return make_response(jsonify(list_keys), 200) -@api.route("/api/create_api_key") +@user_ns.route("/api/create_api_key") class CreateApiKey(Resource): create_api_key_model = api.model( "CreateApiKeyModel", @@ -764,7 +745,7 @@ def post(self): return make_response(jsonify({"id": new_id, "key": key}), 201) -@api.route("/api/delete_api_key") +@user_ns.route("/api/delete_api_key") class DeleteApiKey(Resource): delete_api_key_model = api.model( "DeleteApiKeyModel", @@ -790,7 +771,7 @@ def post(self): return {"success": True}, 200 -@api.route("/api/share") +@user_ns.route("/api/share") class ShareConversation(Resource): share_conversation_model = 
api.model( "ShareConversationModel", @@ -988,7 +969,7 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/shared_conversation/") +@user_ns.route("/api/shared_conversation/") class GetPubliclySharedConversations(Resource): @api.doc(description="Get publicly shared conversations by identifier") def get(self, identifier: str): @@ -1043,7 +1024,7 @@ def get(self, identifier: str): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/get_message_analytics") +@user_ns.route("/api/get_message_analytics") class GetMessageAnalytics(Resource): get_message_analytics_model = api.model( "GetMessageAnalyticsModel", @@ -1181,7 +1162,7 @@ def post(self): ) -@api.route("/api/get_token_analytics") +@user_ns.route("/api/get_token_analytics") class GetTokenAnalytics(Resource): get_token_analytics_model = api.model( "GetTokenAnalyticsModel", @@ -1332,7 +1313,7 @@ def post(self): ) -@api.route("/api/get_feedback_analytics") +@user_ns.route("/api/get_feedback_analytics") class GetFeedbackAnalytics(Resource): get_feedback_analytics_model = api.model( "GetFeedbackAnalyticsModel", @@ -1550,7 +1531,7 @@ def post(self): ) -@api.route("/api/get_user_logs") +@user_ns.route("/api/get_user_logs") class GetUserLogs(Resource): get_user_logs_model = api.model( "GetUserLogsModel", @@ -1629,7 +1610,7 @@ def post(self): ) -@api.route("/api/manage_sync") +@user_ns.route("/api/manage_sync") class ManageSync(Resource): manage_sync_model = api.model( "ManageSyncModel", diff --git a/application/app.py b/application/app.py index 87d9d42fa..d7727001b 100644 --- a/application/app.py +++ b/application/app.py @@ -1,15 +1,19 @@ import platform + import dotenv -from application.celery_init import celery -from flask import Flask, request, redirect -from application.core.settings import settings -from application.api.user.routes import user +from flask import Flask, redirect, request + from application.api.answer.routes import answer from application.api.internal.routes import internal +from application.api.user.routes import user +from application.celery_init import celery from application.core.logging_config import setup_logging +from application.core.settings import settings +from application.extensions import api if platform.system() == "Windows": import pathlib + pathlib.PosixPath = pathlib.WindowsPath dotenv.load_dotenv() @@ -23,16 +27,19 @@ UPLOAD_FOLDER="inputs", CELERY_BROKER_URL=settings.CELERY_BROKER_URL, CELERY_RESULT_BACKEND=settings.CELERY_RESULT_BACKEND, - MONGO_URI=settings.MONGO_URI + MONGO_URI=settings.MONGO_URI, ) celery.config_from_object("application.celeryconfig") +api.init_app(app) + @app.route("/") def home(): - if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'): - return redirect('http://localhost:5173') + if request.remote_addr in ("0.0.0.0", "127.0.0.1", "localhost", "172.18.0.1"): + return redirect("http://localhost:5173") else: - return 'Welcome to DocsGPT Backend!' + return "Welcome to DocsGPT Backend!" 
+ @app.after_request def after_request(response): @@ -41,6 +48,6 @@ def after_request(response): response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") return response + if __name__ == "__main__": app.run(debug=settings.FLASK_DEBUG_MODE, port=7091) - diff --git a/application/extensions.py b/application/extensions.py new file mode 100644 index 000000000..b6f528931 --- /dev/null +++ b/application/extensions.py @@ -0,0 +1,7 @@ +from flask_restx import Api + +api = Api( + version="1.0", + title="DocsGPT API", + description="API for DocsGPT", +) diff --git a/application/utils.py b/application/utils.py index 70a00ce05..f0802c390 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,22 +1,41 @@ import tiktoken +from flask import jsonify, make_response _encoding = None + def get_encoding(): global _encoding if _encoding is None: _encoding = tiktoken.get_encoding("cl100k_base") return _encoding + def num_tokens_from_string(string: str) -> int: encoding = get_encoding() num_tokens = len(encoding.encode(string)) return num_tokens + def count_tokens_docs(docs): docs_content = "" for doc in docs: docs_content += doc.page_content tokens = num_tokens_from_string(docs_content) - return tokens \ No newline at end of file + return tokens + + +def check_required_fields(data, required_fields): + missing_fields = [field for field in required_fields if field not in data] + if missing_fields: + return make_response( + jsonify( + { + "success": False, + "message": f"Missing fields: {', '.join(missing_fields)}", + } + ), + 400, + ) + return None From c9976020dd6363280e1915a872ac283dfc0c5f13 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 30 Sep 2024 01:15:36 +0530 Subject: [PATCH 3/3] fix: lint error --- application/api/user/routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 657b3673e..340d020ae 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -230,7 +230,7 @@ def get(self): except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) - return make_response(jsonify({"success": False, "error": str(err)}), 400) + return make_response(jsonify({"success": False}), 400) @user_ns.route("/api/delete_old")
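
Taken together, the three patches converge on one wiring pattern: a bare Api object in application/extensions.py, per-blueprint Namespaces added to it, request models declared through api.model(), and a single api.init_app(app) call binding everything at startup, which is what lets flask-restx publish one Swagger/OpenAPI document covering both blueprints. Below is a minimal, self-contained sketch of that pattern; the route, model, and field names are illustrative placeholders, not DocsGPT's:

from flask import Flask, jsonify, make_response, request
from flask_restx import Api, Namespace, Resource, fields

# Shared Api object, as in application/extensions.py.
api = Api(version="1.0", title="Example API", description="Demo wiring")

# Per-module namespace, mirroring user_ns / answer_ns above.
demo_ns = Namespace("demo", description="Demo operations", path="/")
api.add_namespace(demo_ns)


@demo_ns.route("/api/echo")
class Echo(Resource):
    echo_model = api.model(
        "EchoModel",
        {"message": fields.String(required=True, description="Text to echo")},
    )

    @api.expect(echo_model)
    @api.doc(description="Echo a message back")
    def post(self):
        data = request.get_json()
        # Same response shape as check_required_fields() in application/utils.py.
        if "message" not in data:
            return make_response(
                jsonify({"success": False, "message": "Missing fields: message"}),
                400,
            )
        return make_response(jsonify({"success": True, "echo": data["message"]}), 200)


if __name__ == "__main__":
    app = Flask(__name__)
    api.init_app(app)  # binds the shared Api, mirroring application/app.py
    app.run(port=5000)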