Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.6.0 - Harden Security #55

Merged
merged 13 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smolvault"
version = "0.5.0"
version = "0.6.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
Expand All @@ -27,13 +27,15 @@ dev-dependencies = [
"boto3-stubs[essential]>=1.35.2",
"pre-commit>=3.8.0",
"pytest>=8.3.2",
"pytest-asyncio>=0.23.8",
"pytest-cov>=5.0.0",
"moto[all]>=5.0.13",
"invoke>=2.2.0",
"rich>=13.7.1",
"types-pyjwt>=1.7.1",
"httpx>=0.27.0",
"pytest-sugar>=1.0.0",
"anyio>=4.4.0",
"polyfactory>=2.16.2",
]

[tool.pytest.ini_options]
Expand Down
292 changes: 224 additions & 68 deletions requirements.txt

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion scripts/start_app.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#!/bin/bash

source .venv/bin/activate
hypercorn src.smolvault.main:app -b 0.0.0.0 --debug --log-config=logging.conf --log-level=DEBUG --access-logfile=hypercorn.access.log --error-logfile=hypercorn.error.log --keep-alive=120 --workers=2
hypercorn src.smolvault.main:app -b 0.0.0.0 --debug \
--log-config=logging.conf --log-level=DEBUG \
--access-logfile=hypercorn.access.log \
--error-logfile=hypercorn.error.log \
--keep-alive=120 --workers=2
5 changes: 4 additions & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ export SMOLVAULT_BUCKET="test-bucket"
export SMOLVAULT_DB="test.db"
export SMOLVAULT_CACHE="./uploads/"
export AUTH_SECRET_KEY="09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" # key from FastAPI docs to use in tests
export DAILY_UPLOAD_LIMIT_BYTES="500000"
export USERS_LIMIT="20"
export USER_WHITELIST="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20"

# remove test db if it exists
if [ -f $SMOLVAULT_DB ]; then
Expand All @@ -18,4 +21,4 @@ fi
# create local cache dir
mkdir uploads

pytest -vvv tests/
pytest -vvv tests
2 changes: 1 addition & 1 deletion src/smolvault/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class NewUserDTO(BaseModel):
full_name: str
password: SecretStr

@computed_field # type: ignore
@computed_field # type: ignore[prop-decorator]
@cached_property
def hashed_password(self) -> str:
return bcrypt.hashpw(self.password.get_secret_value().encode(), bcrypt.gensalt()).decode()
Expand Down
15 changes: 14 additions & 1 deletion src/smolvault/clients/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from datetime import datetime

from sqlmodel import Field, Session, SQLModel, create_engine, select

Expand Down Expand Up @@ -59,9 +60,15 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None:
session.add(FileTag(tag_name=tag, file_id=file_metadata.id))
session.commit()

def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]:
def get_all_metadata(
    self, user_id: int, start_time: datetime | None = None, end_time: datetime | None = None
) -> Sequence[FileMetadataRecord]:
    """Fetch the file-metadata records owned by ``user_id``, optionally bounded by upload time.

    Args:
        user_id: Only records whose ``user_id`` equals this value are selected.
        start_time: When truthy, keep only records whose ``upload_timestamp`` is
            greater than or equal to ``start_time.isoformat()``.
        end_time: When truthy, keep only records whose ``upload_timestamp`` is
            less than or equal to ``end_time.isoformat()``.

    Returns:
        Every row matching the filters.
    """
    # NOTE(review): the bounds are compared as ISO-8601 strings; this presumably relies on
    # upload_timestamp being stored in the same format -- confirm against the record model.
    query = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id)
    if start_time:
        earliest = start_time.isoformat()
        query = query.where(FileMetadataRecord.upload_timestamp >= earliest)
    if end_time:
        latest = end_time.isoformat()
        query = query.where(FileMetadataRecord.upload_timestamp <= latest)

    with Session(self.engine) as session:
        return session.exec(query).fetchall()
fullerzz marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -106,6 +113,12 @@ def get_user(self, username: str) -> UserInfo | None:
statement = select(UserInfo).where(UserInfo.username == username)
return session.exec(statement).first()

def get_user_count(self) -> int:
    """Return the number of rows in the ``UserInfo`` table.

    The count is computed by the database (``SELECT count(*)``), so no user
    rows are fetched and materialised just to be counted.

    Returns:
        The total number of stored users.
    """
    # Function-scope import keeps this block self-contained; sqlmodel is
    # already a dependency of this module.
    from sqlmodel import func

    with Session(self.engine) as session:
        statement = select(func.count()).select_from(UserInfo)
        # An aggregate without GROUP BY always yields exactly one row.
        return session.exec(statement).one()

def add_user(self, user: NewUserDTO) -> None:
user_info = UserInfo(
username=user.username,
Expand Down
3 changes: 3 additions & 0 deletions src/smolvault/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class Settings(BaseSettings):
smolvault_db: str
smolvault_cache: str
auth_secret_key: str
user_whitelist: str
users_limit: int
daily_upload_limit_bytes: int

model_config = SettingsConfigDict(env_file=".env")

Expand Down
59 changes: 42 additions & 17 deletions src/smolvault/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,20 @@
from logging.handlers import RotatingFileHandler
from typing import Annotated

from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
File,
Form,
HTTPException,
UploadFile,
)
from fastapi import BackgroundTasks, Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.security import OAuth2PasswordRequestForm

from smolvault.auth.decoder import (
authenticate_user,
create_access_token,
get_current_user,
)
from smolvault.auth.decoder import authenticate_user, create_access_token, get_current_user
from smolvault.auth.models import NewUserDTO, Token, User
from smolvault.cache.cache_manager import CacheManager
from smolvault.clients.aws import S3Client
from smolvault.clients.database import DatabaseClient, FileMetadataRecord
from smolvault.config import Settings, get_settings
from smolvault.models import FileMetadata, FileTagsDTO, FileUploadDTO
from smolvault.validators.operation_validator import UploadValidator, UserCreationValidator

logging.basicConfig(
handlers=[
Expand Down Expand Up @@ -66,35 +55,58 @@ async def read_root(current_user: Annotated[User, Depends(get_current_user)]) ->

@app.post("/users/new")
async def create_user(
user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)]
user: NewUserDTO,
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
op_validator: Annotated[UserCreationValidator, Depends(UserCreationValidator)],
) -> dict[str, str]:
db_client.add_user(user)
return {"username": user.username}
logger.info("Received new user creation request for %s", user.username)
if op_validator.user_creation_allowed(db_client):
logger.info("Creating new user", extra=user.model_dump(exclude={"password"}))
db_client.add_user(user)
return {"username": user.username}
else:
logger.error("User creation failed. User limit exceeded")
raise HTTPException(
status_code=400,
detail="User limit exceeded",
)


@app.post("/token")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
) -> Token:
    """Exchange form-posted credentials for an access token.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.
        db_client: Database client handed to ``authenticate_user``.

    Returns:
        The value produced by ``create_access_token`` for the authenticated user.

    Raises:
        HTTPException: Status 400 with a ``WWW-Authenticate: Bearer`` header when
            ``authenticate_user`` returns a falsy result.
    """
    username = form_data.username
    logger.info("Authenticating user %s", username)

    account = authenticate_user(db_client, username, form_data.password)
    if account:
        # The token subject is the username stored on the authenticated account.
        token = create_access_token(data={"sub": account.username})
        logger.info("User %s authenticated successfully", account.username)
        return token

    logger.info("Incorrect username or password for %s", username)
    raise HTTPException(
        status_code=400,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Bearer"},
    )


@app.post("/file/upload")
async def upload_file(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
op_validator: Annotated[UploadValidator, Depends(UploadValidator)],
file: Annotated[UploadFile, File()],
tags: str | None = Form(default=None),
) -> Response:
logger.info("Received file upload request from %s", current_user.username)
if not op_validator.upload_allowed(current_user.id, db_client):
logger.error("Upload limit exceeded for user %s", current_user.username)
fullerzz marked this conversation as resolved.
Show resolved Hide resolved
return Response(
content=json.dumps({"error": "Upload limit exceeded"}),
status_code=400,
media_type="application/json",
)
contents = await file.read()
if file.filename is None:
logger.error("Filename not received in request")
Expand All @@ -113,6 +125,7 @@ async def upload_file(
)
object_key = s3_client.upload(data=file_upload)
db_client.add_metadata(file_upload, object_key)
logger.info("File %s uploaded successfully", file_upload.name)
return Response(
content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})),
status_code=201,
Expand All @@ -127,6 +140,7 @@ async def get_file(
filename: str,
background_tasks: BackgroundTasks,
) -> Response:
logger.info("Received file download request for %s from %s", filename, current_user.username)
record = db_client.get_metadata(filename, current_user.id)
if record is None:
logger.info("File not found: %s", filename)
Expand All @@ -152,9 +166,12 @@ async def get_file_metadata(
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
name: str,
) -> FileMetadata | None:
logger.info("Retrieving metadata for file %s requested by %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(urllib.parse.unquote(name), current_user.id)
if record:
logger.info("Retrieved metadata for file %s", name)
return FileMetadata.model_validate(record.model_dump())
logger.info("File metadata for %s not found", name)
return None


Expand All @@ -163,6 +180,7 @@ async def get_files(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
) -> list[FileMetadata]:
logger.info("Retrieving all files for user %s", current_user.username)
raw_metadata = db_client.get_all_metadata(current_user.id)
logger.info("Retrieved %d records from database", len(raw_metadata))
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
Expand All @@ -175,6 +193,7 @@ async def search_files(
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
tag: str,
) -> list[FileMetadata]:
logger.info("Retrieving files with tag %s for user %s", tag, current_user.username)
raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id)
logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag)
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
Expand All @@ -188,8 +207,10 @@ async def update_file_tags(
name: str,
tags: FileTagsDTO,
) -> Response:
logger.info("Updating tags for file %s requested by %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(name, current_user.id)
if record is None:
logger.info("Tag update failed. File %s not found", name)
return Response(
content=json.dumps({"error": "File not found"}),
status_code=404,
Expand All @@ -199,6 +220,7 @@ async def update_file_tags(
record.tags = tags.tags_str
db_client.update_metadata(record)
file_metadata = FileMetadata.model_validate(record.model_dump())
logger.info("Tags updated for file %s", name)
return Response(
content=json.dumps(
{
Expand All @@ -218,8 +240,10 @@ async def delete_file(
name: str,
background_tasks: BackgroundTasks,
) -> Response:
logger.info("Recieved delete request for file %s from %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(name, current_user.id)
if record is None:
logger.info("File %s not found", name)
return Response(
content=json.dumps({"error": "File not found"}),
status_code=404,
Expand All @@ -229,6 +253,7 @@ async def delete_file(
db_client.delete_metadata(record, current_user.id)
if record.local_path:
background_tasks.add_task(cache.delete_file, record.local_path)
logger.info("File %s deleted successfully", name)
return Response(
content=json.dumps({"message": "File deleted successfully", "record": record.model_dump()}),
status_code=200,
Expand Down
47 changes: 47 additions & 0 deletions src/smolvault/validators/operation_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from datetime import datetime, timedelta

from smolvault.clients.database import DatabaseClient
from smolvault.config import get_settings

logger = logging.getLogger(__name__)


class UploadValidator:
    """Decide whether a user may upload a file.

    An upload is allowed only when the user's id appears in the configured
    whitelist and the total size of the user's uploads recorded during the
    previous 24 hours is below ``daily_upload_limit_bytes``.
    """

    def __init__(self) -> None:
        """Read the upload limit and the whitelist from the application settings."""
        self.settings = get_settings()
        self.daily_upload_limit_bytes = self.settings.daily_upload_limit_bytes
        self.whitelist = self._parse_whitelist(self.settings.user_whitelist)

    @staticmethod
    def _parse_whitelist(raw: str) -> list[str]:
        """Split a comma-separated id list, trimming whitespace and dropping empty entries.

        Without trimming, a setting such as ``"1, 2"`` would produce ``" 2"``,
        which never equals ``str(user_id)``; an empty setting or a trailing
        comma would produce a spurious ``""`` entry.
        """
        return [entry for entry in (part.strip() for part in raw.split(",")) if entry]

    def upload_allowed(self, user_id: int, db_client: DatabaseClient) -> bool:
        """Return ``True`` when ``user_id`` is whitelisted and under the daily upload limit.

        Args:
            user_id: Id of the user requesting the upload.
            db_client: Client used to read the user's recent upload metadata.
        """
        # The in-memory whitelist check runs first so that users who are not
        # whitelisted never trigger the database query.
        valid = self._user_on_whitelist(user_id) and self._uploads_under_limit_prev_24h(user_id, db_client)
        logger.info("Upload allowed result for user %s: %s", user_id, valid)
        return valid

    def _uploads_under_limit_prev_24h(self, user_id: int, db_client: DatabaseClient) -> bool:
        """Return ``True`` when the bytes already recorded in the last 24 hours are below the limit.

        Only existing metadata records are summed; the size of the upload
        currently being requested is not part of the total.
        """
        logger.info("Checking upload limit for user %s", user_id)
        # NOTE(review): naive local time -- presumably matches how upload_timestamp
        # is written; confirm before switching to timezone-aware datetimes.
        start_time = datetime.now() - timedelta(days=1)
        metadata = db_client.get_all_metadata(user_id, start_time=start_time)
        bytes_uploaded = sum(record.size for record in metadata)
        logger.info(
            "User %s has uploaded %d bytes in the last 24 hours. DAILY_LIMIT: %d",
            user_id,
            bytes_uploaded,
            self.daily_upload_limit_bytes,
        )
        return bytes_uploaded < self.daily_upload_limit_bytes

    def _user_on_whitelist(self, user_id: int) -> bool:
        """Return ``True`` when ``str(user_id)`` is one of the whitelist entries."""
        logger.info("Checking whitelist for user %s", user_id)
        return str(user_id) in self.whitelist


class UserCreationValidator:
    """Gate new-account creation on the configured ``users_limit`` setting."""

    def __init__(self) -> None:
        """Load the application settings and remember the user limit."""
        settings = get_settings()
        self.settings = settings
        self.users_limit = settings.users_limit

    def user_creation_allowed(self, db_client: DatabaseClient) -> bool:
        """Return ``True`` while the stored user count is strictly below ``users_limit``.

        Args:
            db_client: Client whose ``get_user_count`` supplies the current number of users.
        """
        current_total: int = db_client.get_user_count()
        logger.info("%d users currently in the system", current_total)
        limit_reached = current_total >= self.users_limit
        return not limit_reached
33 changes: 33 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sqlite3
from datetime import datetime
from typing import Any
from zoneinfo import ZoneInfo

from invoke.context import Context
from invoke.tasks import task
from rich import print
from rich.table import Table


@task
Expand Down Expand Up @@ -36,7 +38,38 @@ def show_table(c: Context) -> None:
conn.close()


def output_table(title: str, column_names: list[str], rows: list[Any]) -> None:
    """Render ``rows`` under ``column_names`` as a titled rich table and print it.

    Args:
        title: Title shown with the table.
        column_names: One header per column, in display order.
        rows: Iterable rows; each row is unpacked into the cells of one table row.
    """
    # Header strings passed positionally to Table are registered as its columns.
    grid = Table(*column_names, title=title)
    for cells in rows:
        grid.add_row(*cells)
    print(grid)


@task
def show_users_table(c: Context) -> None:
    """Print the contents of the ``userinfo`` table in ``file_metadata.db``.

    The raw query result is printed first, followed by a formatted table with
    the id, username, hashed password and full name of every row.

    Args:
        c: Invoke context (unused; required by the task signature).
    """
    conn = sqlite3.connect("file_metadata.db")
    try:
        results = conn.execute("SELECT * FROM userinfo").fetchall()
    finally:
        # Close the connection even when the query fails (e.g. missing table).
        conn.close()

    column_names = ["id", "username", "hashed_password", "email", "full_name"]
    print(
        f"[bold cyan]Unformatted results:[/bold cyan]\n[blue]column_names=[/blue][bold purple]{column_names}[/bold purple]\n {results}"
    )
    # Going by column_names, index 3 (email) is deliberately left out of the rendered table.
    rows: list[tuple[str, str, str, str]] = [
        (str(result[0]), result[1], result[2], result[4]) for result in results
    ]
    output_table("[bold cyan]Users Table[/bold cyan]", ["id", "username", "hashed_pwd", "name"], rows)


@task
def bak_db(c: Context) -> None:
    """Copy ``file_metadata.db`` to a UTC-timestamped ``.bak.db`` file using ``cp``.

    Args:
        c: Invoke context used to run the shell command.
    """
    now_utc = datetime.now(ZoneInfo("UTC"))
    # A datetime format spec inside an f-string is applied via strftime.
    backup_name = f"file_metadata_{now_utc:%Y-%m-%d_%H:%M:%S}.bak.db"
    c.run(f"cp file_metadata.db {backup_name}", echo=True)


@task
def export_reqs(c: Context) -> None:
    """Write ``requirements.txt`` by running ``uv export`` without the project itself or dev dependencies.

    Args:
        c: Invoke context used to run the shell command.
    """
    export_command = "uv export --no-emit-project --no-dev --output-file=requirements.txt"
    c.run(export_command, echo=True, pty=True)
Loading