-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
upd: better db class, better testing
- Loading branch information
1 parent
5936f34
commit 2a7daad
Showing
6 changed files
with
129 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,51 @@ | ||
import os | ||
os.environ["TESTING"] = "True" | ||
|
||
from pathlib import Path | ||
from fastapi.testclient import TestClient | ||
import pytest | ||
|
||
from database.database import Database | ||
from main import app | ||
|
||
client = TestClient(app) | ||
|
||
def test_signup(): | ||
@pytest.fixture() | ||
def initialize_database(): | ||
db = Database() | ||
with db: | ||
db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql") | ||
yield db | ||
db.delete_db() | ||
|
||
def test_signup(initialize_database): | ||
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 200 | ||
assert response.json()["email"] == "[email protected]" | ||
|
||
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 400 | ||
assert "detail" in response.json() | ||
assert response.json()["detail"] == "User [email protected] already registered" | ||
|
||
def test_login(initialize_database): | ||
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 200 | ||
assert response.json()["email"] == "[email protected]" | ||
response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 200 | ||
assert "access_token" in response.json() | ||
|
||
def test_login(): | ||
def test_user_me(initialize_database): | ||
response = client.post("/user/signup", json={"email": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 200 | ||
assert response.json()["email"] == "[email protected]" | ||
|
||
response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"}) | ||
assert response.status_code == 200 | ||
assert "access_token" in response.json() | ||
|
||
def test_user_me(): | ||
login_response = client.post("/user/login", data={"username": "[email protected]", "password": "testpassword"}) | ||
token = login_response.json()["access_token"] | ||
token = response.json()["access_token"] | ||
response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) | ||
assert response.status_code == 200 | ||
assert response.json()["email"] == "[email protected]" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from datetime import timedelta, datetime | ||
import os | ||
from pydantic import BaseModel | ||
from jose import jwt | ||
|
||
|
||
from database.database import Database | ||
|
||
SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key") | ||
ALGORITHM = "HS256" | ||
|
||
class User(BaseModel): | ||
email: str = None | ||
password: str = None | ||
|
||
def create_user(user: User): | ||
with Database() as connection: | ||
connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)) | ||
|
||
def user_exists(email: str) -> bool: | ||
with Database() as connection: | ||
result = connection.query("SELECT 1 FROM user WHERE email = ?", (email,))[0] | ||
return bool(result) | ||
|
||
def get_user(email: str): | ||
with Database() as connection: | ||
user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0] | ||
for row in user_row: | ||
return User(**row) | ||
raise Exception("User not found") | ||
|
||
def delete_user(email: str): | ||
with Database() as connection: | ||
connection.query("DELETE FROM user WHERE email = ?", (email,)) | ||
|
||
def authenticate_user(username: str, password: str): | ||
user = get_user(username) | ||
if not user or not password == user.password: | ||
return False | ||
return user | ||
|
||
def create_access_token(*, data: dict, expires_delta: timedelta = None): | ||
to_encode = data.copy() | ||
if expires_delta: | ||
expire = datetime.utcnow() + expires_delta | ||
else: | ||
expire = datetime.utcnow() + timedelta(minutes=15) | ||
to_encode.update({"exp": expire}) | ||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) | ||
return encoded_jwt |