diff --git a/main.py b/main.py index c089cdd..cbc829d 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,11 @@ from datetime import timedelta from typing import List +from uuid import uuid4 from fastapi import FastAPI, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import jwt, JWTError +from database.database import Database import document_store from user_management import (authenticate_user, create_access_token, create_user, @@ -103,6 +105,10 @@ async def chat_new(current_user: User = Depends(get_current_user)): async def chat_prompt(current_user: User = Depends(get_current_user)): pass +@app.post("/chat/regenerate") +async def chat_regenerate(current_user: User = Depends(get_current_user)): + pass + @app.get("/chat/list") async def chat_list(current_user: User = Depends(get_current_user)): pass @@ -116,17 +122,15 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)): ### Feedback ### ############################################ -@app.post("/feedback/thumbs_up") -async def feedback_thumbs_up(current_user: User = Depends(get_current_user)): - pass - -@app.post("/feedback/thumbs_down") -async def feedback_thumbs_down(current_user: User = Depends(get_current_user)): - pass +@app.post("/feedback/{message_id}/thumbs_up") +async def feedback_thumbs_up(message_id, current_user: User = Depends(get_current_user)): + with Database() as connection: + connection.query("INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)", (str(uuid4()), message_id, "thumbs_up")) -@app.post("/feedback/regenerate") -async def feedback_regenerate(current_user: User = Depends(get_current_user)): - pass +@app.post("/feedback/{message_id}/thumbs_down") +async def feedback_thumbs_down(message_id, current_user: User = Depends(get_current_user)): + with Database() as connection: + connection.query("INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)", (str(uuid4()), message_id, "thumbs_down")) ############################################ diff --git a/tests/test_feedback.py b/tests/test_feedback.py new file mode 100644 index 0000000..a0b8031 --- /dev/null +++ b/tests/test_feedback.py @@ -0,0 +1,55 @@ +import os +os.environ["TESTING"] = "True" + +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from database.database import Database +from main import app + + +client = TestClient(app) + +@pytest.fixture(scope="module") +def context(): + db = Database() + with db: + db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql") + + user_data = { + "email": "test@example.com", + "password": "testpassword" + } + + response = client.post("/user/signup", json=user_data) + assert response.status_code == 200 + response = client.post("/user/login", data={"username": user_data["email"], "password": user_data["password"]}) + assert response.status_code == 200 + token = response.json()["access_token"] + client.headers = { + **client.headers, + "Authorization": f"Bearer {token}" + } + + yield client.headers, db + db.delete_db() + +def test_feedback_thumbs_up(context): + headers, db = context[0], context[1] + message_id = "test_message_id_1" + response = client.post(f"/feedback/{message_id}/thumbs_up", headers=headers) + assert response.status_code == 200 + with db: + result = db.query("SELECT 1 FROM feedback WHERE message_id = ?", (message_id, ))[0] + assert len(result) == 1 + +def test_feedback_thumbs_down(context): + headers, db = context[0], context[1] + message_id = "test_message_id_2" + response = client.post(f"/feedback/{message_id}/thumbs_down", headers=headers) + assert response.status_code == 200 + with db: + result = db.query("SELECT 1 FROM feedback WHERE message_id = ?", (message_id, ))[0] + assert len(result) == 1 diff --git a/tests/test_api.py b/tests/test_users.py similarity index 100% rename from tests/test_api.py rename to tests/test_users.py