diff --git a/pyproject.toml b/pyproject.toml index 92cde0a..32189a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev-dependencies = [ "types-pyjwt>=1.7.1", "httpx>=0.27.0", "pytest-sugar>=1.0.0", + "anyio>=4.4.0", ] [tool.pytest.ini_options] diff --git a/scripts/test.sh b/scripts/test.sh index e764f6c..2d344b6 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -5,6 +5,7 @@ export SMOLVAULT_DB="test.db" export SMOLVAULT_CACHE="./uploads/" export AUTH_SECRET_KEY="09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" # key from FastAPI docs to use in tests export DAILY_UPLOAD_LIMIT_BYTES="50000" +export USERS_LIMIT="3" # remove test db if it exists if [ -f $SMOLVAULT_DB ]; then diff --git a/src/smolvault/auth/models.py b/src/smolvault/auth/models.py index 7e26e82..42d53e4 100644 --- a/src/smolvault/auth/models.py +++ b/src/smolvault/auth/models.py @@ -17,7 +17,7 @@ class NewUserDTO(BaseModel): full_name: str password: SecretStr - @computed_field # type: ignore + @computed_field # type: ignore[misc] @cached_property def hashed_password(self) -> str: return bcrypt.hashpw(self.password.get_secret_value().encode(), bcrypt.gensalt()).decode() diff --git a/src/smolvault/main.py b/src/smolvault/main.py index f200e4c..85d3e50 100644 --- a/src/smolvault/main.py +++ b/src/smolvault/main.py @@ -19,7 +19,7 @@ from smolvault.clients.database import DatabaseClient, FileMetadataRecord from smolvault.config import Settings, get_settings from smolvault.models import FileMetadata, FileTagsDTO, FileUploadDTO -from smolvault.validators.operation_validator import UploadValidator +from smolvault.validators.operation_validator import UploadValidator, UserCreationValidator logging.basicConfig( handlers=[ @@ -55,10 +55,21 @@ async def read_root(current_user: Annotated[User, Depends(get_current_user)]) -> @app.post("/users/new") async def create_user( - user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)] + user: NewUserDTO, + db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], + op_validator: Annotated[UserCreationValidator, Depends(UserCreationValidator)], ) -> dict[str, str]: - db_client.add_user(user) - return {"username": user.username} + logger.info("Received new user creation request for %s", user.username) + if op_validator.user_creation_allowed(db_client): + logger.info("Creating new user", extra=user.model_dump(exclude={"password"})) + db_client.add_user(user) + return {"username": user.username} + else: + logger.error("User creation failed. User limit exceeded") + raise HTTPException( + status_code=400, + detail="User limit exceeded", + ) @app.post("/token") diff --git a/tests/conftest.py b/tests/conftest.py index 813e275..d035d36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,8 @@ import pathlib from collections.abc import Generator from datetime import datetime -from typing import Any +from typing import Any, Literal +from uuid import uuid4 from zoneinfo import ZoneInfo import boto3 @@ -22,24 +23,38 @@ class TestDatabaseClient(DatabaseClient): - def __init__(self) -> None: - self.engine = create_engine("sqlite:///test.db", echo=False, connect_args={"check_same_thread": False}) + def __init__(self, filename: str) -> None: + self.engine = create_engine(f"sqlite:///{filename}", echo=False, connect_args={"check_same_thread": False}) SQLModel.metadata.create_all(self.engine) -@pytest.fixture(scope="session") -def _user() -> None: - client = TestDatabaseClient() +@pytest.fixture +def anyio_backend() -> Literal["asyncio"]: + return "asyncio" + + +@pytest.fixture +def temp_db(monkeypatch: pytest.MonkeyPatch) -> Generator[TestDatabaseClient, Any, Any]: + db_filename = f"test-{uuid4().hex}.db" + os.environ["SMOLVAULT_DB"] = db_filename + monkeypatch.setenv("SMOLVAULT_DB", db_filename) + client = TestDatabaseClient(db_filename) + yield client + pathlib.Path(db_filename).unlink() + + +@pytest.fixture +def _user(temp_db: TestDatabaseClient) -> None: user = NewUserDTO( username="testuser", password="testpassword", # type: ignore # noqa: S106 email="test@email.com", full_name="John Smith", ) - client.add_user(user) + temp_db.add_user(user) -@pytest.fixture(scope="module") +@pytest.fixture def client(_user: None) -> AsyncClient: app.dependency_overrides[DatabaseClient] = TestDatabaseClient return AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") # type: ignore diff --git a/tests/test_delete_file.py b/tests/test_delete_file.py index 09a3346..c838fcc 100644 --- a/tests/test_delete_file.py +++ b/tests/test_delete_file.py @@ -7,7 +7,7 @@ from smolvault.models import FileUploadDTO -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_delete_file(client: AsyncClient, camera_img: bytes, access_token: str) -> None: # first upload the file diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 0434451..35072a9 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -9,7 +9,7 @@ from smolvault.models import FileMetadata -@pytest.mark.asyncio +@pytest.mark.anyio async def test_read_root(client: AsyncClient, access_token: str) -> None: response = await client.get("/", headers={"Authorization": f"Bearer {access_token}"}) assert response.status_code == 200 @@ -21,7 +21,7 @@ async def test_read_root(client: AsyncClient, access_token: str) -> None: } -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_list_files( client: AsyncClient, @@ -39,7 +39,7 @@ def mock_get_all_files(*args: Any, **kwargs: Any) -> Sequence[FileMetadataRecord assert response.json() == [file_metadata.model_dump(by_alias=True)] -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_bucket_w_camera_img") async def test_get_file( client: AsyncClient, @@ -64,7 +64,7 @@ async def test_get_file( assert response.content == camera_img -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_get_file_not_found(client: AsyncClient, access_token: str) -> None: response = await client.get( diff --git a/tests/test_file_uploads.py b/tests/test_file_uploads.py index 2e92859..59028a0 100644 --- a/tests/test_file_uploads.py +++ b/tests/test_file_uploads.py @@ -7,7 +7,7 @@ from smolvault.models import FileUploadDTO -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_upload_file(client: AsyncClient, camera_img: bytes, access_token: str) -> None: filename = f"{uuid4().hex[:6]}-camera.png" @@ -31,7 +31,7 @@ async def test_upload_file(client: AsyncClient, camera_img: bytes, access_token: assert actual == expected -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_upload_file_no_tags(client: AsyncClient, camera_img: bytes, access_token: str) -> None: filename = f"{uuid4().hex[:6]}-camera.png" diff --git a/tests/test_search_by_tag.py b/tests/test_search_by_tag.py index ad57b56..2592ebb 100644 --- a/tests/test_search_by_tag.py +++ b/tests/test_search_by_tag.py @@ -4,7 +4,7 @@ from smolvault.models import FileMetadata -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_bucket_w_camera_img") async def test_search_tag_exists( client: AsyncClient, @@ -37,7 +37,7 @@ async def test_search_tag_exists( assert actual == expected -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_search_tag_not_found(client: AsyncClient, access_token: str) -> None: response = await client.get( diff --git a/tests/test_security.py b/tests/test_security.py index 083a51a..2e9480d 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -4,7 +4,7 @@ from httpx import AsyncClient -@pytest.fixture(scope="module") +@pytest.fixture async def user_john(client: AsyncClient) -> str: """ Creates a new user 'John' and returns the access token for John. @@ -29,7 +29,7 @@ async def user_john(client: AsyncClient) -> str: return response.json()["access_token"] -@pytest.fixture(scope="module") +@pytest.fixture async def user_jane(client: AsyncClient) -> str: """ Creates a new user 'Jane' and returns the access token for Jane. @@ -51,7 +51,7 @@ async def user_jane(client: AsyncClient) -> str: return response.json()["access_token"] -@pytest.fixture(scope="module") +@pytest.fixture async def user_jack(client: AsyncClient) -> str: """ Creates a new user 'Jack' and returns the access token for Jack. @@ -76,7 +76,7 @@ async def user_jack(client: AsyncClient) -> str: return response.json()["access_token"] -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_test_bucket") async def test_get_file(client: AsyncClient, camera_img: bytes, user_john: str, user_jane: str) -> None: """ @@ -128,7 +128,7 @@ async def _fully_populated_user_bucket( bytes_uploaded += img_size -@pytest.mark.asyncio +@pytest.mark.anyio @pytest.mark.usefixtures("_fully_populated_user_bucket") async def test_user_over_daily_upload_limit(client: AsyncClient, camera_img: bytes, user_jack: str) -> None: """ @@ -144,3 +144,23 @@ async def test_user_over_daily_upload_limit(client: AsyncClient, camera_img: byt ) assert response.status_code == 400 assert response.json() == {"error": "Upload limit exceeded"} + + +@pytest.mark.anyio +@pytest.mark.usefixtures("_test_bucket") +async def test_user_creation_limit(client: AsyncClient, user_john: str, user_jane: str, user_jack: str) -> None: + """ + Test that the system blocks new user creation if the user limit has been reached. + """ + + user_data = { + "username": "kate", + "password": "testpassword", + "email": "email4@email.com", + "full_name": "Kate Smith", + } + response = await client.post( + "/users/new", + json=user_data, + ) + assert response.status_code == 400 diff --git a/uv.lock b/uv.lock index f4bd0e5..6f7696e 100644 --- a/uv.lock +++ b/uv.lock @@ -1542,6 +1542,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "anyio" }, { name = "boto3-stubs" }, { name = "boto3-stubs", extra = ["essential"] }, { name = "httpx" }, @@ -1574,6 +1575,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "anyio", specifier = ">=4.4.0" }, { name = "boto3-stubs", extras = ["essential"], specifier = ">=1.35.2" }, { name = "httpx", specifier = ">=0.27.0" }, { name = "invoke", specifier = ">=2.2.0" }, @@ -1583,7 +1585,7 @@ dev = [ { name = "pytest", specifier = ">=8.3.2" }, { name = "pytest-asyncio", specifier = ">=0.23.8" }, { name = "pytest-cov", specifier = ">=5.0.0" }, - { name = "pytest-sugar" }, + { name = "pytest-sugar", specifier = ">=1.0.0" }, { name = "rich", specifier = ">=13.7.1" }, { name = "ruff", specifier = ">=0.6.1" }, { name = "types-pyjwt", specifier = ">=1.7.1" },