Skip to content

Commit

Permalink
Add label upload endpoint and enhance label creation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Dec 19, 2024
1 parent 42498ba commit 4f3d5cd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
2 changes: 2 additions & 0 deletions backend/core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
FeedbackViewset,
GenerateFeedbackAOIGpxView,
GenerateGpxView,
LabelUploadView,
LabelViewSet,
ModelCentroidView,
ModelViewSet,
Expand Down Expand Up @@ -52,6 +53,7 @@
urlpatterns = [
path("", include(router.urls)),
path("label/osm/fetch/<int:aoi_id>/", RawdataApiAOIView.as_view()),
path("labels/upload/<int:aoi_id>/", LabelUploadView.as_view(), name="label-upload"),
path(
"label/feedback/osm/fetch/<int:feedbackaoi_id>/",
RawdataApiFeedbackView.as_view(),
Expand Down
52 changes: 32 additions & 20 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,29 @@ class LabelViewSet(viewsets.ModelViewSet):
)
filterset_fields = ["aoi", "aoi__dataset"]

def create(self, request, *args, **kwargs):
aoi_id = request.data.get("aoi")
geom = request.data.get("geom")

existing_label = Label.objects.filter(aoi=aoi_id, geom=geom).first()

if existing_label:
serializer = LabelSerializer(existing_label, data=request.data)
else:
serializer = LabelSerializer(data=request.data)

if serializer.is_valid():
serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)


class LabelUploadView(APIView):
authentication_classes = [OsmAuthentication]
permission_classes = [IsOsmAuthenticated]
parser_classes = (MultiPartParser, FormParser)

def create(self, request, *args, **kwargs):
def post(self, request, aoi_id, *args, **kwargs):
geojson_file = request.FILES.get("geojson_file")
if geojson_file:
try:
Expand All @@ -383,29 +403,17 @@ def create(self, request, *args, **kwargs):
async_task(
"core.views.process_labels_geojson",
geojson_data,
request.data.get("aoi"),
aoi_id,
)
return Response(
{"status": "GeoJSON file is being processed"},
status=status.HTTP_202_ACCEPTED,
)
except (json.JSONDecodeError, ValidationError) as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)

aoi_id = request.data.get("aoi")
geom = request.data.get("geom")

existing_label = Label.objects.filter(aoi=aoi_id, geom=geom).first()

if existing_label:
serializer = LabelSerializer(existing_label, data=request.data)
else:
serializer = LabelSerializer(data=request.data)

if serializer.is_valid():
serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
return Response(
{"error": "No GeoJSON file provided"}, status=status.HTTP_400_BAD_REQUEST
)

def validate_geojson(self, geojson_data):
if geojson_data.get("type") != "FeatureCollection":
Expand All @@ -427,9 +435,8 @@ def validate_geojson(self, geojson_data):
)

# Validate the first feature with the serializer
aoi_id = self.request.data.get("aoi")
label_data = {
"aoi": aoi_id,
"aoi": self.kwargs.get("aoi_id"),
"geom": first_feature["geometry"],
**first_feature["properties"],
}
Expand All @@ -447,7 +454,12 @@ def process_labels_geojson(geojson_data, aoi_id):
geom = feature["geometry"]
properties = feature["properties"]
label_data = {"aoi": aoi_id, "geom": geom, **properties}
serializer = LabelSerializer(data=label_data)

existing_label = Label.objects.filter(aoi=aoi_id, geom=geom).first()
if existing_label:
serializer = LabelSerializer(existing_label, data=label_data)
else:
serializer = LabelSerializer(data=label_data)
if serializer.is_valid():
serializer.save()

Expand Down

0 comments on commit 4f3d5cd

Please sign in to comment.