Skip to content

Commit

Permalink
db queries are guarded by user_id of current user
Browse files Browse the repository at this point in the history
  • Loading branch information
fullerzz committed Aug 10, 2024
1 parent da92cb4 commit d8ed717
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
11 changes: 6 additions & 5 deletions src/smolvault/clients/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None:
session.add(FileTag(tag_name=tag, file_id=file_metadata.id))
session.commit()

def get_all_metadata(self) -> Sequence[FileMetadataRecord]:
def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = select(FileMetadataRecord)
statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id)
results = session.exec(statement)
return results.fetchall()

Expand All @@ -74,12 +74,13 @@ def get_metadata(self, filename: str, user_id: int) -> FileMetadataRecord | None
)
return session.exec(statement).first()

def select_metadata_by_tag(self, tag: str) -> Sequence[FileMetadataRecord]:
def select_metadata_by_tag(self, tag: str, user_id: int) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = (
select(FileMetadataRecord)
.where(FileTag.file_id == FileMetadataRecord.id)
.where(FileTag.tag_name == tag)
.where(FileMetadataRecord.user_id == user_id)
)
results = session.exec(statement)
return results.fetchall()
Expand All @@ -90,11 +91,11 @@ def update_metadata(self, record: FileMetadataRecord) -> None:
session.commit()
session.refresh(record)

def delete_metadata(self, record: FileMetadataRecord) -> None:
def delete_metadata(self, record: FileMetadataRecord, user_id: int) -> None:
with Session(self.engine) as session:
session.delete(record)
session.commit()
statement = select(FileTag).where(FileTag.file_id == record.id)
statement = select(FileTag).where(FileTag.file_id == record.id).where(FileMetadataRecord.user_id == user_id)
tags = session.exec(statement)
for tag in tags:
session.delete(tag)
Expand Down
6 changes: 3 additions & 3 deletions src/smolvault/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def get_files(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
) -> list[FileMetadata]:
raw_metadata = db_client.get_all_metadata()
raw_metadata = db_client.get_all_metadata(current_user.id)
logger.info("Retrieved %d records from database", len(raw_metadata))
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
return results
Expand All @@ -148,7 +148,7 @@ async def search_files(
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
tag: str,
) -> list[FileMetadata]:
raw_metadata = db_client.select_metadata_by_tag(tag)
raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id)
logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag)
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
return results
Expand Down Expand Up @@ -186,7 +186,7 @@ async def delete_file(
if record is None:
return Response(content=json.dumps({"error": "File not found"}), status_code=404, media_type="application/json")
s3_client.delete(record.object_key)
db_client.delete_metadata(record)
db_client.delete_metadata(record, current_user.id)
if record.local_path:
background_tasks.add_task(cache.delete_file, record.local_path)
return Response(
Expand Down

0 comments on commit d8ed717

Please sign in to comment.