From 7ab360680fbe36758d3155b611c2fbcb58e15966 Mon Sep 17 00:00:00 2001 From: Zach Fuller Date: Sat, 31 Aug 2024 15:41:15 -0700 Subject: [PATCH] system now only allows user file upload if user is whitelisted and user has uploaded less than daily amount of bytes --- src/smolvault/clients/database.py | 9 +- src/smolvault/config.py | 1 + src/smolvault/main.py | 92 +++++++++++-------- .../validators/operation_validator.py | 39 ++++++++ tests/test_file_uploads.py | 3 +- tests/testing.env | 1 + 6 files changed, 102 insertions(+), 43 deletions(-) create mode 100644 src/smolvault/validators/operation_validator.py diff --git a/src/smolvault/clients/database.py b/src/smolvault/clients/database.py index 71abded..98ce33f 100644 --- a/src/smolvault/clients/database.py +++ b/src/smolvault/clients/database.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from datetime import datetime from sqlmodel import Field, Session, SQLModel, create_engine, select @@ -59,9 +60,15 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None: session.add(FileTag(tag_name=tag, file_id=file_metadata.id)) session.commit() - def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]: + def get_all_metadata( + self, user_id: int, start_time: datetime | None = None, end_time: datetime | None = None + ) -> Sequence[FileMetadataRecord]: with Session(self.engine) as session: statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id) + if start_time: + statement = statement.where(FileMetadataRecord.upload_timestamp >= start_time.isoformat()) + if end_time: + statement = statement.where(FileMetadataRecord.upload_timestamp <= end_time.isoformat()) results = session.exec(statement) return results.fetchall() diff --git a/src/smolvault/config.py b/src/smolvault/config.py index adc9016..3076a09 100644 --- a/src/smolvault/config.py +++ b/src/smolvault/config.py @@ -10,6 +10,7 @@ class Settings(BaseSettings): smolvault_db: str smolvault_cache: str auth_secret_key: str + user_whitelist: str model_config = SettingsConfigDict(env_file=".env") diff --git a/src/smolvault/main.py b/src/smolvault/main.py index e509962..502eca8 100644 --- a/src/smolvault/main.py +++ b/src/smolvault/main.py @@ -10,7 +10,10 @@ BackgroundTasks, Depends, FastAPI, + File, + Form, HTTPException, + UploadFile, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -22,12 +25,13 @@ create_access_token, get_current_user, ) -from smolvault.auth.models import Token, User +from smolvault.auth.models import NewUserDTO, Token, User from smolvault.cache.cache_manager import CacheManager from smolvault.clients.aws import S3Client from smolvault.clients.database import DatabaseClient, FileMetadataRecord from smolvault.config import Settings, get_settings -from smolvault.models import FileMetadata, FileTagsDTO +from smolvault.models import FileMetadata, FileTagsDTO, FileUploadDTO +from smolvault.validators.operation_validator import UploadValidator logging.basicConfig( handlers=[ @@ -61,12 +65,12 @@ async def read_root(current_user: Annotated[User, Depends(get_current_user)]) -> return current_user -# @app.post("/users/new") -# async def create_user( -# user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)] -# ) -> dict[str, str]: -# db_client.add_user(user) -# return {"username": user.username} +@app.post("/users/new") +async def create_user( + user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)] +) -> dict[str, str]: + db_client.add_user(user) + return {"username": user.username} @app.post("/token") @@ -88,38 +92,46 @@ async def login_for_access_token( return access_token -# @app.post("/file/upload") -# async def upload_file( -# current_user: Annotated[User, Depends(get_current_user)], -# db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], -# file: Annotated[UploadFile, File()], -# tags: str | None = Form(default=None), -# ) -> Response: -# logger.info("Received file upload request from %s", current_user.username) -# contents = await file.read() -# if file.filename is None: -# logger.error("Filename not received in request") -# raise ValueError("Filename is required") -# file_upload = FileUploadDTO( -# name=file.filename, -# size=len(contents), -# content=contents, -# tags=tags, -# user_id=current_user.id, -# ) -# logger.info( -# "Uploading file to S3 with name %s uploaded by %s", -# file_upload.name, -# current_user.username, -# ) -# object_key = s3_client.upload(data=file_upload) -# db_client.add_metadata(file_upload, object_key) -# logger.info("File %s uploaded successfully", file_upload.name) -# return Response( -# content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})), -# status_code=201, -# media_type="application/json", -# ) +@app.post("/file/upload") +async def upload_file( + current_user: Annotated[User, Depends(get_current_user)], + db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], + op_validator: Annotated[UploadValidator, Depends(UploadValidator)], + file: Annotated[UploadFile, File()], + tags: str | None = Form(default=None), +) -> Response: + logger.info("Received file upload request from %s", current_user.username) + if not op_validator.upload_allowed(current_user.id, db_client): + logger.error("Upload limit exceeded for user %s", current_user.username) + return Response( + content=json.dumps({"error": "Upload limit exceeded"}), + status_code=400, + media_type="application/json", + ) + contents = await file.read() + if file.filename is None: + logger.error("Filename not received in request") + raise ValueError("Filename is required") + file_upload = FileUploadDTO( + name=file.filename, + size=len(contents), + content=contents, + tags=tags, + user_id=current_user.id, + ) + logger.info( + "Uploading file to S3 with name %s uploaded by %s", + file_upload.name, + current_user.username, + ) + object_key = s3_client.upload(data=file_upload) + db_client.add_metadata(file_upload, object_key) + logger.info("File %s uploaded successfully", file_upload.name) + return Response( + content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})), + status_code=201, + media_type="application/json", + ) @app.get("/file/original") diff --git a/src/smolvault/validators/operation_validator.py b/src/smolvault/validators/operation_validator.py new file mode 100644 index 0000000..fcec0ad --- /dev/null +++ b/src/smolvault/validators/operation_validator.py @@ -0,0 +1,39 @@ +import logging +from datetime import datetime, timedelta + +from smolvault.clients.database import DatabaseClient +from smolvault.config import get_settings + +DAILY_UPLOAD_LIMIT_BYTES = 1_000_000_000 +logger = logging.getLogger(__name__) + + +class UploadValidator: + def __init__(self) -> None: + self.settings = get_settings() + self.whitelist = self.settings.user_whitelist.split(",") + + def upload_allowed(self, user_id: int, db_client: DatabaseClient) -> bool: + valid = self._uploads_under_limit_prev_24h(user_id, db_client) and self._user_on_whitelist(user_id) + logger.info("Upload allowed result for user %s: %s", user_id, valid) + return valid + + def _uploads_under_limit_prev_24h(self, user_id: int, db_client: DatabaseClient) -> bool: + logger.info("Checking upload limit for user %s", user_id) + start_time = datetime.now() - timedelta(days=1) + metadata = db_client.get_all_metadata(user_id, start_time=start_time) + bytes_uploaded = sum([record.size for record in metadata]) + logger.info("User %s has uploaded %d bytes in the last 24 hours", user_id, bytes_uploaded) + return bytes_uploaded < DAILY_UPLOAD_LIMIT_BYTES + + def _user_on_whitelist(self, user_id: int) -> bool: + logger.info("Checking whitelist for user %s", user_id) + return str(user_id) in self.whitelist + + +class UserCreationValidator: + def __init__(self, database_client: DatabaseClient) -> None: + self.database_client = database_client + + def user_creation_allowed(self, email: str) -> bool: + raise NotImplementedError diff --git a/tests/test_file_uploads.py b/tests/test_file_uploads.py index 02fc442..2e92859 100644 --- a/tests/test_file_uploads.py +++ b/tests/test_file_uploads.py @@ -43,8 +43,7 @@ async def test_upload_file_no_tags(client: AsyncClient, camera_img: bytes, acces files={"file": (filename, camera_img, "image/png")}, headers={"Authorization": f"Bearer {access_token}"}, ) + assert response.status_code == 201 actual: dict[str, Any] = response.json() actual.pop("upload_timestamp") - - assert response.status_code == 201 assert actual == expected diff --git a/tests/testing.env b/tests/testing.env index a1456d4..7f05d94 100644 --- a/tests/testing.env +++ b/tests/testing.env @@ -2,3 +2,4 @@ ENVIRONMENT="dev" SMOLVAULT_BUCKET="test-bucket" SMOLVAULT_DB="test.db" SMOLVAULT_CACHE="./uploads/" +USER_WHITELIST="1,2"