Skip to content

Commit

Permalink
fix(detections): use correct bucket for URL fetching (#366)
Browse files Browse the repository at this point in the history
* fix error when using admin creds

* add pagination parameters

* 15 events by default

* refactor(detections): merged SQL query into a single one

* style(detections): silences mypy warnings

* fix(detections): fix route logic

* fix(detections): fix syntax typo

---------

Co-authored-by: Ronan <[email protected]>
Co-authored-by: F-G Fernandez <[email protected]>
  • Loading branch information
3 people authored Nov 1, 2024
1 parent f737b0d commit 40cda4a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 20 deletions.
54 changes: 35 additions & 19 deletions src/app/api/api_v1/endpoints/detections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from datetime import datetime
from typing import List, cast
from typing import List, Optional, cast

from fastapi import (
APIRouter,
Expand All @@ -19,6 +19,8 @@
UploadFile,
status,
)
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.api.dependencies import (
dispatch_webhook,
Expand All @@ -30,6 +32,7 @@
)
from app.core.config import settings
from app.crud import CameraCRUD, DetectionCRUD, OrganizationCRUD, WebhookCRUD
from app.db import get_session
from app.models import Camera, Detection, Organization, Role, UserRole
from app.schemas.detections import (
BOXES_PATTERN,
Expand Down Expand Up @@ -154,32 +157,45 @@ async def fetch_detections(
@router.get("/unlabeled/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the unlabeled detections")
async def fetch_unlabeled_detections(
from_date: datetime = Query(),
detections: DetectionCRUD = Depends(get_detection_crud),
cameras: CameraCRUD = Depends(get_camera_crud),
limit: Optional[int] = Query(15, description="Maximum number of detections to fetch"),
offset: Optional[int] = Query(0, description="Number of detections to skip before starting to fetch"),
session: AsyncSession = Depends(get_session),
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
) -> List[DetectionWithUrl]:
telemetry_client.capture(token_payload.sub, event="unacknowledged-fetch")

bucket = s3_service.get_bucket(s3_service.resolve_bucket_name(token_payload.organization_id))

def get_url(detection: Detection) -> str:
return bucket.get_public_url(detection.bucket_key)
telemetry_client.capture(token_payload.sub, event="detections-fetch-unlabeled")

if UserRole.ADMIN in token_payload.scopes:
all_unck_detections = await detections.fetch_all(
filter_pair=("is_wildfire", None), inequality_pair=("created_at", ">=", from_date)
# Custom SQL query to fetch detections along with corresponding organization_id
query = await session.exec(
select(Detection, Camera.organization_id) # type: ignore[attr-defined]
.join(Camera, Detection.camera_id == Camera.id) # type: ignore[arg-type]
.where(Detection.is_wildfire.is_(None)) # type: ignore[union-attr]
.where(Detection.created_at >= from_date)
.limit(limit)
.offset(offset)
)
results = query.all()
unlabeled_detections = [Detection(**detection.__dict__) for detection, _ in results]
urls = [
s3_service.get_bucket(s3_service.resolve_bucket_name(org_id)).get_public_url(det.bucket_key)
for det, org_id in results
]
else:
org_cams = await cameras.fetch_all(filter_pair=("organization_id", token_payload.organization_id))
all_unck_detections = await detections.fetch_all(
filter_pair=("is_wildfire", None),
in_pair=("camera_id", [camera.id for camera in org_cams]),
inequality_pair=("created_at", ">=", from_date),
query = await session.exec(
select(Detection) # type: ignore[attr-defined]
.join(Camera, Detection.camera_id == Camera.id) # type: ignore[arg-type]
.where(Detection.is_wildfire.is_(None)) # type: ignore[union-attr]
.where(Detection.created_at >= from_date)
.where(Camera.organization_id == token_payload.organization_id)
.limit(limit)
.offset(offset)
)
results = query.all()
unlabeled_detections = [Detection(**detection.__dict__) for detection in results]
bucket = s3_service.get_bucket(s3_service.resolve_bucket_name(token_payload.organization_id))
urls = [bucket.get_public_url(detection.bucket_key) for detection in unlabeled_detections]

urls = (get_url(detection) for detection in all_unck_detections)

return [DetectionWithUrl(**detection.model_dump(), url=url) for detection, url in zip(all_unck_detections, urls)]
return [DetectionWithUrl(**detection.model_dump(), url=url) for detection, url in zip(unlabeled_detections, urls)]


@router.patch("/{detection_id}/label", status_code=status.HTTP_200_OK, summary="Label the nature of the detection")
Expand Down
2 changes: 1 addition & 1 deletion src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def get_current_user(


async def dispatch_webhook(url: str, payload: BaseModel) -> None:
async with AsyncClient() as client:
async with AsyncClient(timeout=5) as client:
try:
response = await client.post(url, json=payload.model_dump_json())
response.raise_for_status()
Expand Down
8 changes: 8 additions & 0 deletions src/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ async def fetch_all(
filter_pair: Union[Tuple[str, Any], None] = None,
in_pair: Union[Tuple[str, List], None] = None,
inequality_pair: Optional[Tuple[str, str, Any]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ModelType]:
statement = select(self.model) # type: ignore[var-annotated]
if isinstance(filter_pair, tuple):
Expand All @@ -84,6 +86,12 @@ async def fetch_all(
else:
raise ValueError(f"Unsupported inequality operator: {op}")

if offset is not None:
statement = statement.offset(offset)

if limit is not None:
statement = statement.limit(limit)

result = await self.session.exec(statement=statement)
return [r for r in result]

Expand Down

0 comments on commit 40cda4a

Please sign in to comment.