diff --git a/src/smolvault/clients/database.py b/src/smolvault/clients/database.py index 4ad9778..774351d 100644 --- a/src/smolvault/clients/database.py +++ b/src/smolvault/clients/database.py @@ -70,8 +70,8 @@ def get_all_metadata( *, start_time: datetime | None = None, end_time: datetime | None = None, - offset: int = 0, - limit: Annotated[int, PydanticField(default=10, lt=100)] = 10, + offset: int | None = 0, + limit: Annotated[int | None, PydanticField(default=10, lt=100)] = 10, ) -> Sequence[FileMetadataRecord]: with Session(self.engine) as session: statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id) @@ -91,8 +91,13 @@ def get_metadata(self, filename: str, user_id: int) -> FileMetadataRecord | None ) return session.exec(statement).first() - # TODO: Add offset and limit parameters - def select_metadata_by_tag(self, tag: str, user_id: int) -> Sequence[FileMetadataRecord]: + def select_metadata_by_tag( + self, + tag: str, + user_id: int, + offset: int | None = 0, + limit: Annotated[int | None, PydanticField(default=10, lt=100)] = 10, + ) -> Sequence[FileMetadataRecord]: with Session(self.engine) as session: statement = ( select(FileMetadataRecord) @@ -100,7 +105,7 @@ def select_metadata_by_tag(self, tag: str, user_id: int) -> Sequence[FileMetadat .where(FileTag.tag_name == tag) .where(FileMetadataRecord.user_id == user_id) ) - results = session.exec(statement) + results = session.exec(statement.offset(offset).limit(limit)) return results.fetchall() def update_metadata(self, record: FileMetadataRecord) -> None: diff --git a/src/smolvault/main.py b/src/smolvault/main.py index 631bd95..8a0c1e3 100644 --- a/src/smolvault/main.py +++ b/src/smolvault/main.py @@ -175,28 +175,30 @@ async def get_file_metadata( return None -# TODO: Add offset and limit query parameters @app.get("/files") async def get_files( current_user: Annotated[User, Depends(get_current_user)], db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], + offset: int | None = None, + limit: int | None = None, ) -> list[FileMetadata]: logger.info("Retrieving all files for user %s", current_user.username) - raw_metadata = db_client.get_all_metadata(current_user.id) + raw_metadata = db_client.get_all_metadata(user_id=current_user.id, offset=offset, limit=limit) logger.info("Retrieved %d records from database", len(raw_metadata)) results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata] return results -# TODO: Add offset and limit query parameters @app.get("/files/search") async def search_files( current_user: Annotated[User, Depends(get_current_user)], db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], tag: str, + offset: int | None = None, + limit: int | None = None, ) -> list[FileMetadata]: logger.info("Retrieving files with tag %s for user %s", tag, current_user.username) - raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id) + raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id, offset=offset, limit=limit) logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag) results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata] return results diff --git a/tests/conftest.py b/tests/conftest.py index 6400295..e9272c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,7 +48,7 @@ def user(user_factory: UserFactory, db_client: TestDatabaseClient) -> tuple[str, @pytest.fixture(scope="module") def client() -> AsyncClient: app.dependency_overrides[DatabaseClient] = TestDatabaseClient - return AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") # type: ignore + return AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") @pytest.fixture