Skip to content

Commit

Permalink
updated get and search files endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
fullerzz committed Sep 22, 2024
1 parent b5ad684 commit d62bb1e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/smolvault/clients/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def get_all_metadata(
*,
start_time: datetime | None = None,
end_time: datetime | None = None,
offset: int = 0,
limit: Annotated[int, PydanticField(default=10, lt=100)] = 10,
offset: int | None = 0,
limit: Annotated[int | None, PydanticField(default=10, lt=100)] = 10,
) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id)
Expand All @@ -91,16 +91,21 @@ def get_metadata(self, filename: str, user_id: int) -> FileMetadataRecord | None
)
return session.exec(statement).first()

# TODO: Add offset and limit parameters
def select_metadata_by_tag(self, tag: str, user_id: int) -> Sequence[FileMetadataRecord]:
def select_metadata_by_tag(
self,
tag: str,
user_id: int,
offset: int | None = 0,
limit: Annotated[int | None, PydanticField(default=10, lt=100)] = 10,
) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = (
select(FileMetadataRecord)
.where(FileTag.file_id == FileMetadataRecord.id)
.where(FileTag.tag_name == tag)
.where(FileMetadataRecord.user_id == user_id)
)
results = session.exec(statement)
results = session.exec(statement.offset(offset).limit(limit))
return results.fetchall()

def update_metadata(self, record: FileMetadataRecord) -> None:
Expand Down
10 changes: 6 additions & 4 deletions src/smolvault/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,28 +175,30 @@ async def get_file_metadata(
return None


# TODO: Add offset and limit query parameters
@app.get("/files")
async def get_files(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
offset: int | None = None,
limit: int | None = None,
) -> list[FileMetadata]:
logger.info("Retrieving all files for user %s", current_user.username)
raw_metadata = db_client.get_all_metadata(current_user.id)
raw_metadata = db_client.get_all_metadata(user_id=current_user.id, offset=offset, limit=limit)
logger.info("Retrieved %d records from database", len(raw_metadata))
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
return results


# TODO: Add offset and limit query parameters
@app.get("/files/search")
async def search_files(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
tag: str,
offset: int | None = None,
limit: int | None = None,
) -> list[FileMetadata]:
logger.info("Retrieving files with tag %s for user %s", tag, current_user.username)
raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id)
raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id, offset=offset, limit=limit)
logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag)
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
return results
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def user(user_factory: UserFactory, db_client: TestDatabaseClient) -> tuple[str,
@pytest.fixture(scope="module")
def client() -> AsyncClient:
app.dependency_overrides[DatabaseClient] = TestDatabaseClient
return AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") # type: ignore
return AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver")


@pytest.fixture
Expand Down

0 comments on commit d62bb1e

Please sign in to comment.