Skip to content

Commit

Permalink
system now only allows user file upload if user is whitelisted and us…
Browse files Browse the repository at this point in the history
…er has uploaded less than daily amount of bytes
  • Loading branch information
fullerzz committed Aug 31, 2024
1 parent c4c942d commit 7ab3606
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 43 deletions.
9 changes: 8 additions & 1 deletion src/smolvault/clients/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from datetime import datetime

from sqlmodel import Field, Session, SQLModel, create_engine, select

Expand Down Expand Up @@ -59,9 +60,15 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None:
session.add(FileTag(tag_name=tag, file_id=file_metadata.id))
session.commit()

def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]:
def get_all_metadata(
self, user_id: int, start_time: datetime | None = None, end_time: datetime | None = None
) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id)
if start_time:
statement = statement.where(FileMetadataRecord.upload_timestamp >= start_time.isoformat())
if end_time:
statement = statement.where(FileMetadataRecord.upload_timestamp <= end_time.isoformat())
results = session.exec(statement)
return results.fetchall()

Expand Down
1 change: 1 addition & 0 deletions src/smolvault/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Settings(BaseSettings):
smolvault_db: str
smolvault_cache: str
auth_secret_key: str
user_whitelist: str

model_config = SettingsConfigDict(env_file=".env")

Expand Down
92 changes: 52 additions & 40 deletions src/smolvault/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
BackgroundTasks,
Depends,
FastAPI,
File,
Form,
HTTPException,
UploadFile,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
Expand All @@ -22,12 +25,13 @@
create_access_token,
get_current_user,
)
from smolvault.auth.models import Token, User
from smolvault.auth.models import NewUserDTO, Token, User
from smolvault.cache.cache_manager import CacheManager
from smolvault.clients.aws import S3Client
from smolvault.clients.database import DatabaseClient, FileMetadataRecord
from smolvault.config import Settings, get_settings
from smolvault.models import FileMetadata, FileTagsDTO
from smolvault.models import FileMetadata, FileTagsDTO, FileUploadDTO
from smolvault.validators.operation_validator import UploadValidator

logging.basicConfig(
handlers=[
Expand Down Expand Up @@ -61,12 +65,12 @@ async def read_root(current_user: Annotated[User, Depends(get_current_user)]) ->
return current_user


# @app.post("/users/new")
# async def create_user(
# user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)]
# ) -> dict[str, str]:
# db_client.add_user(user)
# return {"username": user.username}
@app.post("/users/new")
async def create_user(
user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)]
) -> dict[str, str]:
db_client.add_user(user)
return {"username": user.username}


@app.post("/token")
Expand All @@ -88,38 +92,46 @@ async def login_for_access_token(
return access_token


# @app.post("/file/upload")
# async def upload_file(
# current_user: Annotated[User, Depends(get_current_user)],
# db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
# file: Annotated[UploadFile, File()],
# tags: str | None = Form(default=None),
# ) -> Response:
# logger.info("Received file upload request from %s", current_user.username)
# contents = await file.read()
# if file.filename is None:
# logger.error("Filename not received in request")
# raise ValueError("Filename is required")
# file_upload = FileUploadDTO(
# name=file.filename,
# size=len(contents),
# content=contents,
# tags=tags,
# user_id=current_user.id,
# )
# logger.info(
# "Uploading file to S3 with name %s uploaded by %s",
# file_upload.name,
# current_user.username,
# )
# object_key = s3_client.upload(data=file_upload)
# db_client.add_metadata(file_upload, object_key)
# logger.info("File %s uploaded successfully", file_upload.name)
# return Response(
# content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})),
# status_code=201,
# media_type="application/json",
# )
@app.post("/file/upload")
async def upload_file(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
op_validator: Annotated[UploadValidator, Depends(UploadValidator)],
file: Annotated[UploadFile, File()],
tags: str | None = Form(default=None),
) -> Response:
logger.info("Received file upload request from %s", current_user.username)
if not op_validator.upload_allowed(current_user.id, db_client):
logger.error("Upload limit exceeded for user %s", current_user.username)
return Response(
content=json.dumps({"error": "Upload limit exceeded"}),
status_code=400,
media_type="application/json",
)
contents = await file.read()
if file.filename is None:
logger.error("Filename not received in request")
raise ValueError("Filename is required")
file_upload = FileUploadDTO(
name=file.filename,
size=len(contents),
content=contents,
tags=tags,
user_id=current_user.id,
)
logger.info(
"Uploading file to S3 with name %s uploaded by %s",
file_upload.name,
current_user.username,
)
object_key = s3_client.upload(data=file_upload)
db_client.add_metadata(file_upload, object_key)
logger.info("File %s uploaded successfully", file_upload.name)
return Response(
content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})),
status_code=201,
media_type="application/json",
)


@app.get("/file/original")
Expand Down
39 changes: 39 additions & 0 deletions src/smolvault/validators/operation_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import logging
from datetime import datetime, timedelta

from smolvault.clients.database import DatabaseClient
from smolvault.config import get_settings

DAILY_UPLOAD_LIMIT_BYTES = 1_000_000_000
logger = logging.getLogger(__name__)


class UploadValidator:
def __init__(self) -> None:
self.settings = get_settings()
self.whitelist = self.settings.user_whitelist.split(",")

def upload_allowed(self, user_id: int, db_client: DatabaseClient) -> bool:
valid = self._uploads_under_limit_prev_24h(user_id, db_client) and self._user_on_whitelist(user_id)
logger.info("Upload allowed result for user %s: %s", user_id, valid)
return valid

def _uploads_under_limit_prev_24h(self, user_id: int, db_client: DatabaseClient) -> bool:
logger.info("Checking upload limit for user %s", user_id)
start_time = datetime.now() - timedelta(days=1)
metadata = db_client.get_all_metadata(user_id, start_time=start_time)
bytes_uploaded = sum([record.size for record in metadata])
logger.info("User %s has uploaded %d bytes in the last 24 hours", user_id, bytes_uploaded)
return bytes_uploaded < DAILY_UPLOAD_LIMIT_BYTES

def _user_on_whitelist(self, user_id: int) -> bool:
logger.info("Checking whitelist for user %s", user_id)
return str(user_id) in self.whitelist


class UserCreationValidator:
def __init__(self, database_client: DatabaseClient) -> None:
self.database_client = database_client

def user_creation_allowed(self, email: str) -> bool:
raise NotImplementedError
3 changes: 1 addition & 2 deletions tests/test_file_uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ async def test_upload_file_no_tags(client: AsyncClient, camera_img: bytes, acces
files={"file": (filename, camera_img, "image/png")},
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 201
actual: dict[str, Any] = response.json()
actual.pop("upload_timestamp")

assert response.status_code == 201
assert actual == expected
1 change: 1 addition & 0 deletions tests/testing.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ ENVIRONMENT="dev"
SMOLVAULT_BUCKET="test-bucket"
SMOLVAULT_DB="test.db"
SMOLVAULT_CACHE="./uploads/"
USER_WHITELIST="1,2"

0 comments on commit 7ab3606

Please sign in to comment.