Skip to content

Commit

Permalink
upd: better db class, better testing
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Dec 14, 2023
1 parent 5936f34 commit 2a7daad
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 32 deletions.
6 changes: 3 additions & 3 deletions authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jose import jwt


from database.database import DatabaseConnection
from database.database import Database

SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
ALGORITHM = "HS256"
Expand All @@ -14,11 +14,11 @@ class User(BaseModel):
password: str = None

def create_user(user: User):
with DatabaseConnection() as connection:
with Database() as connection:
connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password))

def get_user(email: str):
with DatabaseConnection() as connection:
with Database() as connection:
user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0]
for row in user_row:
return User(**row)
Expand Down
19 changes: 16 additions & 3 deletions database/database.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
from pathlib import Path
import sqlite3
from typing import List

class DatabaseConnection:
class Database:
def __init__(self):
db_name = "test.sqlite" if os.getenv("TESTING", "false").lower() == "true" else "database.sqlite"
self.db_path = Path(__file__).parent / db_name

def __enter__(self):
self.conn = sqlite3.connect(Path(__file__).parent / "database.sqlite")
self.conn = sqlite3.connect(self.db_path)
self.conn.row_factory = sqlite3.Row
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
self.conn.rollback()
self.conn.commit()
self.conn.close()

Expand All @@ -26,5 +33,11 @@ def query_from_file(self, file_path):
query = file.read()
self.query(query)

with DatabaseConnection() as connection:
def delete_db(self):
if self.conn:
self.conn.close()
if self.db_path.exists():
self.db_path.unlink(missing_ok=True)

with Database() as connection:
connection.query_from_file(Path(__file__).parent / "database_init.sql")
31 changes: 25 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from jose import jwt, JWTError

import document_store
from authentication import (authenticate_user, create_access_token, create_user,
get_user, User, SECRET_KEY, ALGORITHM)
from user_management import (authenticate_user, create_access_token, create_user,
get_user, User, SECRET_KEY, ALGORITHM, user_exists)
from document_store import StorageBackend
from model import Doc

Expand Down Expand Up @@ -41,16 +41,35 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:

@app.post("/user/signup")
async def signup(user: User):
try:
user = get_user(user.email)
if user_exists(user.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"User {user.email} already registered"
)
except Exception as e:
create_user(user)

create_user(user)
return {"email": user.email}


@app.delete("/user/")
async def delete_user(current_user: User = Depends(get_current_user)):
email = current_user.email
try:
user = get_user(email)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"User {email} not found"
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error"
)


@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
user = authenticate_user(form_data.username, form_data.password)
Expand Down
15 changes: 0 additions & 15 deletions sandbox_alexis/storage_backend.py

This file was deleted.

40 changes: 35 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,51 @@
import os
os.environ["TESTING"] = "True"

from pathlib import Path
from fastapi.testclient import TestClient
import pytest

from database.database import Database
from main import app

client = TestClient(app)

def test_signup():
@pytest.fixture()
def initialize_database():
db = Database()
with db:
db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql")
yield db
db.delete_db()

def test_signup(initialize_database):
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"})
assert response.status_code == 200
assert response.json()["email"] == "[email protected]"

response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"})
assert response.status_code == 400
assert "detail" in response.json()
assert response.json()["detail"] == "User [email protected] already registered"

def test_login(initialize_database):
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"})
assert response.status_code == 200
assert response.json()["email"] == "[email protected]"
response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"})
assert response.status_code == 200
assert "access_token" in response.json()

def test_login():
def test_user_me(initialize_database):
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"})
assert response.status_code == 200
assert response.json()["email"] == "[email protected]"

response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"})
assert response.status_code == 200
assert "access_token" in response.json()

def test_user_me():
login_response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"})
token = login_response.json()["access_token"]
token = response.json()["access_token"]
response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"})
assert response.status_code == 200
assert response.json()["email"] == "[email protected]"
50 changes: 50 additions & 0 deletions user_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from datetime import timedelta, datetime
import os
from pydantic import BaseModel
from jose import jwt


from database.database import Database

SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
ALGORITHM = "HS256"

class User(BaseModel):
email: str = None
password: str = None

def create_user(user: User):
with Database() as connection:
connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password))

def user_exists(email: str) -> bool:
with Database() as connection:
result = connection.query("SELECT 1 FROM user WHERE email = ?", (email,))[0]
return bool(result)

def get_user(email: str):
with Database() as connection:
user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0]
for row in user_row:
return User(**row)
raise Exception("User not found")

def delete_user(email: str):
with Database() as connection:
connection.query("DELETE FROM user WHERE email = ?", (email,))

def authenticate_user(username: str, password: str):
user = get_user(username)
if not user or not password == user.password:
return False
return user

def create_access_token(*, data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

0 comments on commit 2a7daad

Please sign in to comment.