diff --git a/src/smolvault/clients/database.py b/src/smolvault/clients/database.py index 3ee69ee..e7138bf 100644 --- a/src/smolvault/clients/database.py +++ b/src/smolvault/clients/database.py @@ -59,9 +59,9 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None: session.add(FileTag(tag_name=tag, file_id=file_metadata.id)) session.commit() - def get_all_metadata(self) -> Sequence[FileMetadataRecord]: + def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]: with Session(self.engine) as session: - statement = select(FileMetadataRecord) + statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id) results = session.exec(statement) return results.fetchall() @@ -74,12 +74,13 @@ def get_metadata(self, filename: str, user_id: int) -> FileMetadataRecord | None ) return session.exec(statement).first() - def select_metadata_by_tag(self, tag: str) -> Sequence[FileMetadataRecord]: + def select_metadata_by_tag(self, tag: str, user_id: int) -> Sequence[FileMetadataRecord]: with Session(self.engine) as session: statement = ( select(FileMetadataRecord) .where(FileTag.file_id == FileMetadataRecord.id) .where(FileTag.tag_name == tag) + .where(FileMetadataRecord.user_id == user_id) ) results = session.exec(statement) return results.fetchall() @@ -90,11 +91,11 @@ def update_metadata(self, record: FileMetadataRecord) -> None: session.commit() session.refresh(record) - def delete_metadata(self, record: FileMetadataRecord) -> None: + def delete_metadata(self, record: FileMetadataRecord, user_id: int) -> None: with Session(self.engine) as session: session.delete(record) session.commit() - statement = select(FileTag).where(FileTag.file_id == record.id) + statement = select(FileTag).where(FileTag.file_id == record.id).where(FileMetadataRecord.user_id == user_id) tags = session.exec(statement) for tag in tags: session.delete(tag) diff --git a/src/smolvault/main.py b/src/smolvault/main.py index c839e02..bdcf4c7 100644 --- a/src/smolvault/main.py +++ b/src/smolvault/main.py @@ -136,7 +136,7 @@ async def get_files( current_user: Annotated[User, Depends(get_current_user)], db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], ) -> list[FileMetadata]: - raw_metadata = db_client.get_all_metadata() + raw_metadata = db_client.get_all_metadata(current_user.id) logger.info("Retrieved %d records from database", len(raw_metadata)) results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata] return results @@ -148,7 +148,7 @@ async def search_files( db_client: Annotated[DatabaseClient, Depends(DatabaseClient)], tag: str, ) -> list[FileMetadata]: - raw_metadata = db_client.select_metadata_by_tag(tag) + raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id) logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag) results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata] return results @@ -186,7 +186,7 @@ async def delete_file( if record is None: return Response(content=json.dumps({"error": "File not found"}), status_code=404, media_type="application/json") s3_client.delete(record.object_key) - db_client.delete_metadata(record) + db_client.delete_metadata(record, current_user.id) if record.local_path: background_tasks.add_task(cache.delete_file, record.local_path) return Response(