From e8988e82d00b57e37ec77b278d9516287ce35bf6 Mon Sep 17 00:00:00 2001 From: Siddhant Rai Date: Mon, 30 Sep 2024 00:41:34 +0530 Subject: [PATCH] refactor: answer routes to comply with OpenAPI spec using flask-restx --- application/api/answer/routes.py | 620 ++++++++++++++++++------------- application/api/user/routes.py | 85 ++--- application/app.py | 25 +- application/extensions.py | 7 + application/utils.py | 21 +- 5 files changed, 429 insertions(+), 329 deletions(-) create mode 100644 application/extensions.py diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 6f0315152..35b95174d 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -1,20 +1,24 @@ import asyncio -import os -import sys -from flask import Blueprint, request, Response, current_app -import json import datetime +import json import logging +import os +import sys import traceback -from pymongo import MongoClient -from bson.objectid import ObjectId from bson.dbref import DBRef +from bson.objectid import ObjectId +from flask import Blueprint, current_app, make_response, request, Response +from flask_restx import fields, Namespace, Resource + +from pymongo import MongoClient from application.core.settings import settings +from application.error import bad_request +from application.extensions import api from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator -from application.error import bad_request +from application.utils import check_required_fields logger = logging.getLogger(__name__) @@ -25,7 +29,10 @@ prompts_collection = db["prompts"] api_key_collection = db["api_keys"] user_logs_collection = db["user_logs"] + answer = Blueprint("answer", __name__) +answer_ns = Namespace("answer", description="Answer related operations", path="/") +api.add_namespace(answer_ns) gpt_model = "" # to have some kind of default behaviour @@ -186,10 +193,10 @@ def complete_stream( answer = retriever.gen() sources = retriever.search() for source in sources: - if("text" in source): - source["text"] = source["text"][:100].strip()+"..." - if(len(sources) > 0): - data = json.dumps({"type":"source","source":sources}) + if "text" in source: + source["text"] = source["text"][:100].strip() + "..." + if len(sources) > 0: + data = json.dumps({"type": "source", "source": sources}) yield f"data: {data}\n\n" for line in answer: if "answer" in line: @@ -243,109 +250,133 @@ def complete_stream( return -@answer.route("/stream", methods=["POST"]) -def stream(): - try: +@answer_ns.route("/stream") +class Stream(Resource): + stream_model = api.model( + "StreamModel", + { + "question": fields.String( + required=True, description="Question to be asked" + ), + "history": fields.List( + fields.String, required=False, description="Chat history" + ), + "conversation_id": fields.String( + required=False, description="Conversation ID" + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "selectedDocs": fields.String( + required=False, description="Selected documents" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, + ) + + @api.expect(stream_model) + @api.doc(description="Stream a response based on the question and retriever") + def post(self): data = request.get_json() - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + history = data.get("history", []) history = json.loads(history) - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "selectedDocs" in data and data["selectedDocs"] is None: - chunks = 0 - elif "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - - ## retriever can be "brave_search, duckduck_search or classic" - retriever_name = data["retriever"] if "retriever" in data else "classic" - - # check if active_docs or api_key is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] or retriever_name - user_api_key = data["api_key"] - - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - - else: - source = {} - user_api_key = None - - current_app.logger.info( - f"/stream - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) + conversation_id = data.get("conversation_id") + prompt_id = data.get("prompt_id", "default") + if "selectedDocs" in data and data["selectedDocs"] is None: + chunks = 0 + else: + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + prompt_id = data_key.get("prompt_id", "default") + source = {"active_docs": data_key.get("source")} + retriever_name = data_key.get("retriever", retriever_name) + user_api_key = data["api_key"] + + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name + user_api_key = None + + else: + source = {} + user_api_key = None + + current_app.logger.info( + f"/stream - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - prompt = get_prompt(prompt_id) - - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) + prompt = get_prompt(prompt_id) - return Response( - complete_stream( + retriever = RetrieverCreator.create_retriever( + retriever_name, question=question, - retriever=retriever, - conversation_id=conversation_id, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, user_api_key=user_api_key, - isNoneDoc=data.get("isNoneDoc"), - ), - mimetype="text/event-stream", - ) + ) - except ValueError: - message = "Malformed request body" - print("\033[91merr", str(message), file=sys.stderr) - return Response( - error_stream_generate(message), - status=400, - mimetype="text/event-stream", - ) - except Exception as e: - current_app.logger.error( - f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - message = e.args[0] - status_code = 400 - # # Custom exceptions with two arguments, index 1 as status code - if len(e.args) >= 2: - status_code = e.args[1] - return Response( - error_stream_generate(message), - status=status_code, - mimetype="text/event-stream", - ) + return Response( + complete_stream( + question=question, + retriever=retriever, + conversation_id=conversation_id, + user_api_key=user_api_key, + isNoneDoc=data.get("isNoneDoc"), + ), + mimetype="text/event-stream", + ) + + except ValueError: + message = "Malformed request body" + print("\033[91merr", str(message), file=sys.stderr) + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: + current_app.logger.error( + f"/stream - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + message = e.args[0] + status_code = 400 + # Custom exceptions with two arguments, index 1 as status code + if len(e.args) >= 2: + status_code = e.args[1] + return Response( + error_stream_generate(message), + status=status_code, + mimetype="text/event-stream", + ) def error_stream_generate(err_response): @@ -353,180 +384,235 @@ def error_stream_generate(err_response): yield f"data: {data}\n\n" -@answer.route("/api/answer", methods=["POST"]) -def api_answer(): - data = request.get_json() - question = data["question"] - if "history" not in data: - history = [] - else: - history = data["history"] - if "conversation_id" not in data: - conversation_id = None - else: - conversation_id = data["conversation_id"] - print("-" * 5) - if "prompt_id" in data: - prompt_id = data["prompt_id"] - else: - prompt_id = "default" - if "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - - ## retriever can be brave_search, duckduck_search or classic - retriever_name = data["retriever"] if "retriever" in data else "classic" +@answer_ns.route("/api/answer") +class Answer(Resource): + answer_model = api.model( + "AnswerModel", + { + "question": fields.String( + required=True, description="The question to answer" + ), + "history": fields.List( + fields.String, required=False, description="Conversation history" + ), + "conversation_id": fields.String( + required=False, description="Conversation ID" + ), + "prompt_id": fields.String( + required=False, default="default", description="Prompt ID" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "token_limit": fields.Integer(required=False, description="Token limit"), + "retriever": fields.String(required=False, description="Retriever type"), + "api_key": fields.String(required=False, description="API key"), + "active_docs": fields.String( + required=False, description="Active documents" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, + ) - # use try and except to check for exception - try: - # check if the vectorstore is set - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - prompt_id = data_key["prompt_id"] - source = {"active_docs": data_key["source"]} - retriever_name = data_key["retriever"] or retriever_name - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - retriever_name = get_retriever(data["active_docs"]) or retriever_name - user_api_key = None - else: - source = {} - user_api_key = None - - prompt = get_prompt(prompt_id) - - current_app.logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, - ) + @api.expect(answer_model) + @api.doc(description="Provide an answer based on the question and retriever") + def post(self): + data = request.get_json() + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + history = data.get("history", []) + conversation_id = data.get("conversation_id") + prompt_id = data.get("prompt_id", "default") + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + prompt_id = data_key.get("prompt_id", "default") + source = {"active_docs": data_key.get("source")} + retriever_name = data_key.get("retriever", retriever_name) + user_api_key = data["api_key"] + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + retriever_name = get_retriever(data["active_docs"]) or retriever_name + user_api_key = None + else: + source = {} + user_api_key = None + + prompt = get_prompt(prompt_id) + + current_app.logger.info( + f"/api/answer - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=history, - prompt=prompt, - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, - ) - source_log_docs = [] - response_full = "" - for line in retriever.gen(): - if "source" in line: - source_log_docs.append(line["source"]) - elif "answer" in line: - response_full += line["answer"] + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=history, + prompt=prompt, + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) - if data.get("isNoneDoc"): - for doc in source_log_docs: - doc["source"] = "None" + source_log_docs = [] + response_full = "" + for line in retriever.gen(): + if "source" in line: + source_log_docs.append(line["source"]) + elif "answer" in line: + response_full += line["answer"] - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) + if data.get("isNoneDoc"): + for doc in source_log_docs: + doc["source"] = "None" - result = {"answer": response_full, "sources": source_log_docs} - result["conversation_id"] = str( - save_conversation( - conversation_id, question, response_full, source_log_docs, llm + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key ) - ) - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "api_answer", - "level": "info", - "user": "local", - "api_key": user_api_key, - "question": question, - "response": response_full, - "sources": source_log_docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) - return result - except Exception as e: - current_app.logger.error( - f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", - extra={"error": str(e), "traceback": traceback.format_exc()}, - ) - return bad_request(500, str(e)) + result = {"answer": response_full, "sources": source_log_docs} + result["conversation_id"] = str( + save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) + ) + retriever_params = retriever.get_params() + user_logs_collection.insert_one( + { + "action": "api_answer", + "level": "info", + "user": "local", + "api_key": user_api_key, + "question": question, + "response": response_full, + "sources": source_log_docs, + "retriever_params": retriever_params, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + ) + except Exception as e: + current_app.logger.error( + f"/api/answer - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return bad_request(500, str(e)) -@answer.route("/api/search", methods=["POST"]) -def api_search(): - data = request.get_json() - question = data["question"] - if "chunks" in data: - chunks = int(data["chunks"]) - else: - chunks = 2 - if "api_key" in data: - data_key = get_data_from_api_key(data["api_key"]) - chunks = int(data_key["chunks"]) - source = {"active_docs":data_key["source"]} - user_api_key = data["api_key"] - elif "active_docs" in data: - source = {"active_docs": data["active_docs"]} - user_api_key = None - else: - source = {} - user_api_key = None + return make_response(result, 200) - if "retriever" in data: - retriever_name = data["retriever"] - else: - retriever_name = "classic" - if "token_limit" in data: - token_limit = data["token_limit"] - else: - token_limit = settings.DEFAULT_MAX_HISTORY - current_app.logger.info( - f"/api/answer - request_data: {data}, source: {source}", - extra={"data": json.dumps({"request_data": data, "source": source})}, +@answer_ns.route("/api/search") +class Search(Resource): + search_model = api.model( + "SearchModel", + { + "question": fields.String( + required=True, description="The question to search" + ), + "chunks": fields.Integer( + required=False, default=2, description="Number of chunks" + ), + "api_key": fields.String( + required=False, description="API key for authentication" + ), + "active_docs": fields.String( + required=False, description="Active documents for retrieval" + ), + "retriever": fields.String(required=False, description="Retriever type"), + "token_limit": fields.Integer( + required=False, description="Limit for tokens" + ), + "isNoneDoc": fields.Boolean( + required=False, description="Flag indicating if no document is used" + ), + }, ) - retriever = RetrieverCreator.create_retriever( - retriever_name, - question=question, - source=source, - chat_history=[], - prompt="default", - chunks=chunks, - token_limit=token_limit, - gpt_model=gpt_model, - user_api_key=user_api_key, + @api.expect(search_model) + @api.doc( + description="Search for relevant documents based on the question and retriever" ) - docs = retriever.search() + def post(self): + data = request.get_json() + required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) + if missing_fields: + return missing_fields + + try: + question = data["question"] + chunks = int(data.get("chunks", 2)) + token_limit = data.get("token_limit", settings.DEFAULT_MAX_HISTORY) + retriever_name = data.get("retriever", "classic") + + if "api_key" in data: + data_key = get_data_from_api_key(data["api_key"]) + chunks = int(data_key.get("chunks", 2)) + source = {"active_docs": data_key.get("source")} + user_api_key = data["api_key"] + elif "active_docs" in data: + source = {"active_docs": data["active_docs"]} + user_api_key = None + else: + source = {} + user_api_key = None + + current_app.logger.info( + f"/api/answer - request_data: {data}, source: {source}", + extra={"data": json.dumps({"request_data": data, "source": source})}, + ) - retriever_params = retriever.get_params() - user_logs_collection.insert_one( - { - "action": "api_search", - "level": "info", - "user": "local", - "api_key": user_api_key, - "question": question, - "sources": docs, - "retriever_params": retriever_params, - "timestamp": datetime.datetime.now(datetime.timezone.utc), - } - ) + retriever = RetrieverCreator.create_retriever( + retriever_name, + question=question, + source=source, + chat_history=[], + prompt="default", + chunks=chunks, + token_limit=token_limit, + gpt_model=gpt_model, + user_api_key=user_api_key, + ) - if data.get("isNoneDoc"): - for doc in docs: - doc["source"] = "None" + docs = retriever.search() + retriever_params = retriever.get_params() + + user_logs_collection.insert_one( + { + "action": "api_search", + "level": "info", + "user": "local", + "api_key": user_api_key, + "question": question, + "sources": docs, + "retriever_params": retriever_params, + "timestamp": datetime.datetime.now(datetime.timezone.utc), + } + ) + + if data.get("isNoneDoc"): + for doc in docs: + doc["source"] = "None" + + except Exception as e: + current_app.logger.error( + f"/api/search - error: {str(e)} - traceback: {traceback.format_exc()}", + extra={"error": str(e), "traceback": traceback.format_exc()}, + ) + return bad_request(500, str(e)) - return docs + return make_response(docs, 200) diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 0bc11a1be..657b3673e 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -7,13 +7,15 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, jsonify, make_response, request -from flask_restx import Api, fields, Namespace, Resource +from flask_restx import fields, Namespace, Resource from pymongo import MongoClient from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote from application.core.settings import settings +from application.extensions import api +from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator mongo = MongoClient(settings.MONGO_URI) @@ -28,14 +30,8 @@ user_logs_collection = db["user_logs"] user = Blueprint("user", __name__) -api = Api( - user, - version="1.0", - title="DocsGPT API", - description="API for DocsGPT", - default="user", - default_label="User operations", -) +user_ns = Namespace("user", description="User related operations", path="/") +api.add_namespace(user_ns) current_dir = os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -63,22 +59,7 @@ def generate_date_range(start_date, end_date): } -def check_required_fields(data, required_fields): - missing_fields = [field for field in required_fields if field not in data] - if missing_fields: - return make_response( - jsonify( - { - "success": False, - "message": f"Missing fields: {', '.join(missing_fields)}", - } - ), - 400, - ) - return None - - -@api.route("/api/delete_conversation") +@user_ns.route("/api/delete_conversation") class DeleteConversation(Resource): @api.doc( description="Deletes a conversation by ID", @@ -98,7 +79,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/delete_all_conversations") +@user_ns.route("/api/delete_all_conversations") class DeleteAllConversations(Resource): @api.doc( description="Deletes all conversations for a specific user", @@ -112,7 +93,7 @@ def get(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/get_conversations") +@user_ns.route("/api/get_conversations") class GetConversations(Resource): @api.doc( description="Retrieve a list of the latest 30 conversations", @@ -129,7 +110,7 @@ def get(self): return make_response(jsonify(list_conversations), 200) -@api.route("/api/get_single_conversation") +@user_ns.route("/api/get_single_conversation") class GetSingleConversation(Resource): @api.doc( description="Retrieve a single conversation by ID", @@ -153,7 +134,7 @@ def get(self): return make_response(jsonify(conversation["queries"]), 200) -@api.route("/api/update_conversation_name") +@user_ns.route("/api/update_conversation_name") class UpdateConversationName(Resource): @api.expect( api.model( @@ -186,7 +167,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/feedback") +@user_ns.route("/api/feedback") class SubmitFeedback(Resource): @api.expect( api.model( @@ -229,7 +210,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/delete_by_ids") +@user_ns.route("/api/delete_by_ids") class DeleteByIds(Resource): @api.doc( description="Deletes documents from the vector store by IDs", @@ -252,7 +233,7 @@ def get(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/delete_old") +@user_ns.route("/api/delete_old") class DeleteOldIndexes(Resource): @api.doc( description="Deletes old indexes", @@ -289,7 +270,7 @@ def get(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/upload") +@user_ns.route("/api/upload") class UploadFile(Resource): @api.expect( api.model( @@ -370,7 +351,7 @@ def post(self): return make_response(jsonify({"success": True, "task_id": task.id}), 200) -@api.route("/api/remote") +@user_ns.route("/api/remote") class UploadRemote(Resource): @api.expect( api.model( @@ -408,7 +389,7 @@ def post(self): return make_response(jsonify({"success": True, "task_id": task.id}), 200) -@api.route("/api/task_status") +@user_ns.route("/api/task_status") class TaskStatus(Resource): task_status_model = api.model( "TaskStatusModel", @@ -435,7 +416,7 @@ def get(self): return make_response(jsonify({"status": task.status, "result": task_meta}), 200) -@api.route("/api/combine") +@user_ns.route("/api/combine") class CombinedJson(Resource): @api.doc(description="Provide JSON file with combined available indexes") def get(self): @@ -496,7 +477,7 @@ def get(self): return make_response(jsonify(data), 200) -@api.route("/api/docs_check") +@user_ns.route("/api/docs_check") class CheckDocs(Resource): check_docs_model = api.model( "CheckDocsModel", @@ -522,7 +503,7 @@ def post(self): return make_response(jsonify({"status": "not found"}), 404) -@api.route("/api/create_prompt") +@user_ns.route("/api/create_prompt") class CreatePrompt(Resource): create_prompt_model = api.model( "CreatePromptModel", @@ -560,7 +541,7 @@ def post(self): return make_response(jsonify({"id": new_id}), 200) -@api.route("/api/get_prompts") +@user_ns.route("/api/get_prompts") class GetPrompts(Resource): @api.doc(description="Get all prompts for the user") def get(self): @@ -587,7 +568,7 @@ def get(self): return make_response(jsonify(list_prompts), 200) -@api.route("/api/get_single_prompt") +@user_ns.route("/api/get_single_prompt") class GetSinglePrompt(Resource): @api.doc(params={"id": "ID of the prompt"}, description="Get a single prompt by ID") def get(self): @@ -628,7 +609,7 @@ def get(self): return make_response(jsonify({"content": prompt["content"]}), 200) -@api.route("/api/delete_prompt") +@user_ns.route("/api/delete_prompt") class DeletePrompt(Resource): delete_prompt_model = api.model( "DeletePromptModel", @@ -652,7 +633,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/update_prompt") +@user_ns.route("/api/update_prompt") class UpdatePrompt(Resource): update_prompt_model = api.model( "UpdatePromptModel", @@ -685,7 +666,7 @@ def post(self): return make_response(jsonify({"success": True}), 200) -@api.route("/api/get_api_keys") +@user_ns.route("/api/get_api_keys") class GetApiKeys(Resource): @api.doc(description="Retrieve API keys for the user") def get(self): @@ -719,7 +700,7 @@ def get(self): return make_response(jsonify(list_keys), 200) -@api.route("/api/create_api_key") +@user_ns.route("/api/create_api_key") class CreateApiKey(Resource): create_api_key_model = api.model( "CreateApiKeyModel", @@ -764,7 +745,7 @@ def post(self): return make_response(jsonify({"id": new_id, "key": key}), 201) -@api.route("/api/delete_api_key") +@user_ns.route("/api/delete_api_key") class DeleteApiKey(Resource): delete_api_key_model = api.model( "DeleteApiKeyModel", @@ -790,7 +771,7 @@ def post(self): return {"success": True}, 200 -@api.route("/api/share") +@user_ns.route("/api/share") class ShareConversation(Resource): share_conversation_model = api.model( "ShareConversationModel", @@ -988,7 +969,7 @@ def post(self): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/shared_conversation/") +@user_ns.route("/api/shared_conversation/") class GetPubliclySharedConversations(Resource): @api.doc(description="Get publicly shared conversations by identifier") def get(self, identifier: str): @@ -1043,7 +1024,7 @@ def get(self, identifier: str): return make_response(jsonify({"success": False, "error": str(err)}), 400) -@api.route("/api/get_message_analytics") +@user_ns.route("/api/get_message_analytics") class GetMessageAnalytics(Resource): get_message_analytics_model = api.model( "GetMessageAnalyticsModel", @@ -1181,7 +1162,7 @@ def post(self): ) -@api.route("/api/get_token_analytics") +@user_ns.route("/api/get_token_analytics") class GetTokenAnalytics(Resource): get_token_analytics_model = api.model( "GetTokenAnalyticsModel", @@ -1332,7 +1313,7 @@ def post(self): ) -@api.route("/api/get_feedback_analytics") +@user_ns.route("/api/get_feedback_analytics") class GetFeedbackAnalytics(Resource): get_feedback_analytics_model = api.model( "GetFeedbackAnalyticsModel", @@ -1550,7 +1531,7 @@ def post(self): ) -@api.route("/api/get_user_logs") +@user_ns.route("/api/get_user_logs") class GetUserLogs(Resource): get_user_logs_model = api.model( "GetUserLogsModel", @@ -1629,7 +1610,7 @@ def post(self): ) -@api.route("/api/manage_sync") +@user_ns.route("/api/manage_sync") class ManageSync(Resource): manage_sync_model = api.model( "ManageSyncModel", diff --git a/application/app.py b/application/app.py index 87d9d42fa..d7727001b 100644 --- a/application/app.py +++ b/application/app.py @@ -1,15 +1,19 @@ import platform + import dotenv -from application.celery_init import celery -from flask import Flask, request, redirect -from application.core.settings import settings -from application.api.user.routes import user +from flask import Flask, redirect, request + from application.api.answer.routes import answer from application.api.internal.routes import internal +from application.api.user.routes import user +from application.celery_init import celery from application.core.logging_config import setup_logging +from application.core.settings import settings +from application.extensions import api if platform.system() == "Windows": import pathlib + pathlib.PosixPath = pathlib.WindowsPath dotenv.load_dotenv() @@ -23,16 +27,19 @@ UPLOAD_FOLDER="inputs", CELERY_BROKER_URL=settings.CELERY_BROKER_URL, CELERY_RESULT_BACKEND=settings.CELERY_RESULT_BACKEND, - MONGO_URI=settings.MONGO_URI + MONGO_URI=settings.MONGO_URI, ) celery.config_from_object("application.celeryconfig") +api.init_app(app) + @app.route("/") def home(): - if request.remote_addr in ('0.0.0.0', '127.0.0.1', 'localhost', '172.18.0.1'): - return redirect('http://localhost:5173') + if request.remote_addr in ("0.0.0.0", "127.0.0.1", "localhost", "172.18.0.1"): + return redirect("http://localhost:5173") else: - return 'Welcome to DocsGPT Backend!' + return "Welcome to DocsGPT Backend!" + @app.after_request def after_request(response): @@ -41,6 +48,6 @@ def after_request(response): response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") return response + if __name__ == "__main__": app.run(debug=settings.FLASK_DEBUG_MODE, port=7091) - diff --git a/application/extensions.py b/application/extensions.py new file mode 100644 index 000000000..b6f528931 --- /dev/null +++ b/application/extensions.py @@ -0,0 +1,7 @@ +from flask_restx import Api + +api = Api( + version="1.0", + title="DocsGPT API", + description="API for DocsGPT", +) diff --git a/application/utils.py b/application/utils.py index 70a00ce05..f0802c390 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,22 +1,41 @@ import tiktoken +from flask import jsonify, make_response _encoding = None + def get_encoding(): global _encoding if _encoding is None: _encoding = tiktoken.get_encoding("cl100k_base") return _encoding + def num_tokens_from_string(string: str) -> int: encoding = get_encoding() num_tokens = len(encoding.encode(string)) return num_tokens + def count_tokens_docs(docs): docs_content = "" for doc in docs: docs_content += doc.page_content tokens = num_tokens_from_string(docs_content) - return tokens \ No newline at end of file + return tokens + + +def check_required_fields(data, required_fields): + missing_fields = [field for field in required_fields if field not in data] + if missing_fields: + return make_response( + jsonify( + { + "success": False, + "message": f"Missing fields: {', '.join(missing_fields)}", + } + ), + 400, + ) + return None