diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 6c5e3e9c8..cafc8c664 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -9,13 +9,11 @@ from pymongo import MongoClient from bson.objectid import ObjectId - from application.core.settings import settings from application.llm.llm_creator import LLMCreator from application.retriever.retriever_creator import RetrieverCreator from application.error import bad_request - logger = logging.getLogger(__name__) mongo = MongoClient(settings.MONGO_URI) @@ -75,8 +73,10 @@ def run_async_chain(chain, question, chat_history): def get_data_from_api_key(api_key): data = api_key_collection.find_one({"key": api_key}) + + # # Raise custom exception if the API key is not found if data is None: - return bad_request(401, "Invalid API key") + raise Exception("Invalid API Key, please generate new key", 401) return data @@ -128,10 +128,10 @@ def save_conversation(conversation_id, question, response, source_log_docs, llm) "content": "Summarise following conversation in no more than 3 " "words, respond ONLY with the summary, use the same " "language as the system \n\nUser: " - + question - + "\n\n" - + "AI: " - + response, + +question + +"\n\n" + +"AI: " + +response, }, { "role": "user", @@ -173,33 +173,39 @@ def get_prompt(prompt_id): def complete_stream(question, retriever, conversation_id, user_api_key): - response_full = "" - source_log_docs = [] - answer = retriever.gen() - for line in answer: - if "answer" in line: - response_full += str(line["answer"]) - data = json.dumps(line) - yield f"data: {data}\n\n" - elif "source" in line: - source_log_docs.append(line["source"]) - - llm = LLMCreator.create_llm( - settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key - ) - conversation_id = save_conversation( - conversation_id, question, response_full, source_log_docs, llm - ) - - # send data.type = "end" to indicate that the stream has ended as json - data = json.dumps({"type": "id", "id": str(conversation_id)}) - yield f"data: {data}\n\n" - data = json.dumps({"type": "end"}) - yield f"data: {data}\n\n" + try: + response_full = "" + source_log_docs = [] + answer = retriever.gen() + for line in answer: + if "answer" in line: + response_full += str(line["answer"]) + data = json.dumps(line) + yield f"data: {data}\n\n" + elif "source" in line: + source_log_docs.append(line["source"]) + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=user_api_key + ) + conversation_id = save_conversation( + conversation_id, question, response_full, source_log_docs, llm + ) + + # send data.type = "end" to indicate that the stream has ended as json + data = json.dumps({"type": "id", "id": str(conversation_id)}) + yield f"data: {data}\n\n" + data = json.dumps({"type": "end"}) + yield f"data: {data}\n\n" + except Exception as e: + data = json.dumps({"type": "error","error":"Please try again later. We apologize for any inconvenience.", + "error_exception": str(e)}) + yield f"data: {data}\n\n" + return @answer.route("/stream", methods=["POST"]) def stream(): + try: data = request.get_json() # get parameter from url question question = data["question"] @@ -273,7 +279,29 @@ def stream(): ), mimetype="text/event-stream", ) - + + except ValueError: + message = "Malformed request body" + return Response( + error_stream_generate(message), + status=400, + mimetype="text/event-stream", + ) + except Exception as e: + print("err",str(e)) + message = e.args[0] + status_code = 400 + # # Custom exceptions with two arguments, index 1 as status code + if(len(e.args) >= 2): + status_code = e.args[1] + return Response( + error_stream_generate(message), + status=status_code, + mimetype="text/event-stream", + ) +def error_stream_generate(err_response): + data = json.dumps({"type": "error", "error":err_response}) + yield f"data: {data}\n\n" @answer.route("/api/answer", methods=["POST"]) def api_answer(): diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 443faddc5..62addc260 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -257,8 +257,8 @@ def combined_json(): } ] # structure: name, language, version, description, fullName, date, docLink - # append data from vectors_collection - for index in vectors_collection.find({"user": user}): + # append data from vectors_collection in sorted order in descending order of date + for index in vectors_collection.find({"user": user}).sort("date", -1): data.append( { "name": index["name"], diff --git a/frontend/src/conversation/conversationSlice.ts b/frontend/src/conversation/conversationSlice.ts index 738806988..7ab9f8fe3 100644 --- a/frontend/src/conversation/conversationSlice.ts +++ b/frontend/src/conversation/conversationSlice.ts @@ -68,6 +68,15 @@ export const fetchAnswer = createAsyncThunk( query: { conversationId: data.id }, }), ); + } else if (data.type === 'error') { + // set status to 'failed' + dispatch(conversationSlice.actions.setStatus('failed')); + dispatch( + conversationSlice.actions.raiseError({ + index: state.conversation.queries.length - 1, + message: data.error, + }), + ); } else { const result = data.answer; dispatch( @@ -191,6 +200,13 @@ export const conversationSlice = createSlice({ setStatus(state, action: PayloadAction) { state.status = action.payload; }, + raiseError( + state, + action: PayloadAction<{ index: number; message: string }>, + ) { + const { index, message } = action.payload; + state.queries[index].error = message; + }, }, extraReducers(builder) { builder @@ -204,7 +220,7 @@ export const conversationSlice = createSlice({ } state.status = 'failed'; state.queries[state.queries.length - 1].error = - 'Something went wrong. Please try again later.'; + 'Something went wrong. Please check your internet connection.'; }); }, });