diff --git a/backend/core/admin.py b/backend/core/admin.py index ccd5af55..20f4947b 100644 --- a/backend/core/admin.py +++ b/backend/core/admin.py @@ -8,12 +8,12 @@ @admin.register(Dataset) class DatasetAdmin(geoadmin.OSMGeoAdmin): - list_display = ["name", "created_by"] + list_display = ["name", "user"] @admin.register(Model) class ModelAdmin(geoadmin.OSMGeoAdmin): - list_display = ["get_dataset_id", "name", "status", "created_at", "created_by"] + list_display = ["get_dataset_id", "name", "status", "created_at", "user"] def get_dataset_id(self, obj): return obj.dataset.id @@ -28,7 +28,7 @@ class TrainingAdmin(geoadmin.OSMGeoAdmin): "description", "status", "zoom_level", - "created_by", + "user", "accuracy", ] list_filter = ["status"] @@ -47,3 +47,17 @@ class FeedbackAOIAdmin(geoadmin.OSMGeoAdmin): @admin.register(Feedback) class FeedbackAdmin(geoadmin.OSMGeoAdmin): list_display = ["feedback_type", "training", "user", "created_at"] + + +@admin.register(Banner) +class BannerAdmin(admin.ModelAdmin): + list_display = ("message", "start_date", "end_date", "is_displayable") + list_filter = ("start_date", "end_date") + search_fields = ("message",) + readonly_fields = ("is_displayable",) + + def is_displayable(self, obj): + return obj.is_displayable() + + is_displayable.boolean = True + is_displayable.short_description = "Currently Displayable" diff --git a/backend/core/models.py b/backend/core/models.py index 60247ff9..3964587d 100644 --- a/backend/core/models.py +++ b/backend/core/models.py @@ -2,7 +2,7 @@ from django.contrib.postgres.fields import ArrayField from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models - +from django.utils import timezone from login.models import OsmUser # Create your models here. @@ -15,7 +15,7 @@ class DatasetStatus(models.IntegerChoices): DRAFT = -1 name = models.CharField(max_length=255) - created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) + user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) last_modified = models.DateTimeField(auto_now=True) created_at = models.DateTimeField(auto_now_add=True) source_imagery = models.URLField(blank=True, null=True) @@ -47,6 +47,11 @@ class Label(models.Model): class Model(models.Model): + BASE_MODEL_CHOICES = ( + ("RAMP", "RAMP"), + ("YOLO", "YOLO"), + ) + class ModelStatus(models.IntegerChoices): ARCHIVED = 1 PUBLISHED = 0 @@ -57,9 +62,12 @@ class ModelStatus(models.IntegerChoices): created_at = models.DateTimeField(auto_now_add=True) last_modified = models.DateTimeField(auto_now=True) description = models.TextField(max_length=500, null=True, blank=True) - created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) + user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) published_training = models.PositiveIntegerField(null=True, blank=True) - status = models.IntegerField(default=-1, choices=ModelStatus.choices) # + status = models.IntegerField(default=-1, choices=ModelStatus.choices) + base_model = models.CharField( + choices=BASE_MODEL_CHOICES, default="RAMP", max_length=10 + ) class Training(models.Model): @@ -81,7 +89,7 @@ class Training(models.Model): models.PositiveIntegerField(), size=4, ) - created_by = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) + user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) started_at = models.DateTimeField(null=True, blank=True) finished_at = models.DateTimeField(null=True, blank=True) accuracy = models.FloatField(null=True, blank=True) @@ -89,6 +97,7 @@ class Training(models.Model): chips_length = models.PositiveIntegerField(default=0) batch_size = models.PositiveIntegerField() freeze_layers = models.BooleanField(default=False) + centroid = geomodels.PointField(srid=4326, null=True, blank=True) class Feedback(models.Model): @@ -146,6 +155,19 @@ class ApprovedPredictions(models.Model): srid=4326 ) ## Making this geometry field to support point/line prediction later on approved_at = models.DateTimeField(auto_now_add=True) - approved_by = models.ForeignKey( - OsmUser, to_field="osm_id", on_delete=models.CASCADE - ) + user = models.ForeignKey(OsmUser, to_field="osm_id", on_delete=models.CASCADE) + + +class Banner(models.Model): + message = models.TextField() + start_date = models.DateTimeField(default=timezone.now) + end_date = models.DateTimeField(null=True, blank=True) + + def is_displayable(self): + now = timezone.now() + return (self.start_date <= now) and ( + self.end_date is None or self.end_date >= now + ) + + def __str__(self): + return self.message diff --git a/backend/core/serializers.py b/backend/core/serializers.py index fbcd1efc..6896dbe0 100644 --- a/backend/core/serializers.py +++ b/backend/core/serializers.py @@ -18,14 +18,14 @@ class Meta: model = Dataset fields = "__all__" # defining all the fields to be included in curd for now , we can restrict few if we want read_only_fields = ( - "created_by", + "user", "created_at", "last_modified", ) def create(self, validated_data): user = self.context["request"].user - validated_data["created_by"] = user + validated_data["user"] = user return super().create(validated_data) @@ -46,7 +46,7 @@ class Meta: class ModelSerializer(serializers.ModelSerializer): - created_by = UserSerializer(read_only=True) + user = UserSerializer(read_only=True) accuracy = serializers.SerializerMethodField() thumbnail_url = serializers.SerializerMethodField() @@ -56,13 +56,13 @@ class Meta: read_only_fields = ( "created_at", "last_modified", - "created_by", + "user", "published_training", ) def create(self, validated_data): user = self.context["request"].user - validated_data["created_by"] = user + validated_data["user"] = user return super().create(validated_data) # def get_training(self, obj): @@ -393,3 +393,14 @@ def validate(self, data): data["area_threshold"] ) return data + + +class BannerSerializer(serializers.ModelSerializer): + class Meta: + model = Banner + fields = [ + "id", + "message", + "start_date", + "end_date", + ] diff --git a/backend/core/tasks.py b/backend/core/tasks.py index 05b674a1..fc9bc303 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -2,18 +2,13 @@ import logging import os import shutil +import subprocess import sys import tarfile import traceback from shutil import rmtree from celery import shared_task -from django.conf import settings -from django.contrib.gis.db.models.aggregates import Extent -from django.contrib.gis.geos import GEOSGeometry -from django.shortcuts import get_object_or_404 -from django.utils import timezone - from core.models import AOI, Feedback, FeedbackAOI, FeedbackLabel, Label, Training from core.serializers import ( AOISerializer, @@ -23,6 +18,11 @@ LabelFileSerializer, ) from core.utils import bbox, is_dir_empty +from django.conf import settings +from django.contrib.gis.db.models.aggregates import Extent +from django.contrib.gis.geos import GEOSGeometry +from django.shortcuts import get_object_or_404 +from django.utils import timezone logger = logging.getLogger(__name__) @@ -135,6 +135,10 @@ def train_model( raise ValueError( f"No AOI is attached with supplied dataset id:{dataset_id}, Create AOI first", ) + first_aoi_centroid = aois[0].geom.centroid + training_instance.centroid = first_aoi_centroid + training_instance.save() + for obj in aois: bbox_coords = bbox(obj.geom.coords[0]) for z in zoom_level: @@ -309,6 +313,18 @@ def train_model( ) as f: f.write(json.dumps(aoi_serializer.data)) + tippecanoe_command = f"""tippecanoe -o {os.path.join(output_path,"meta.pmtiles")} -Z7 -z18 -L aois:{ os.path.join(output_path, "aois.geojson")} -L labels:{os.path.join(output_path, "labels.geojson")} --force --read-parallel -rg --drop-densest-as-needed""" + logging.info("Starting to generate vector tiles for aois and labels") + try: + result = subprocess.run( + tippecanoe_command, shell=True, check=True, capture_output=True + ) + logging.info(result.stdout.decode("utf-8")) + except subprocess.CalledProcessError as ex: + logger.error(ex.output) + raise ex + logging.info("Vector tile generation done !") + # copy aois and labels to preprocess output before compressing it to tar shutil.copyfile( os.path.join(output_path, "aois.geojson"), @@ -332,7 +348,7 @@ def train_model( training_instance.save() response = {} response["accuracy"] = float(final_accuracy) - # response["model_path"] = os.path.join(output_path, "checkpoint.tf") + response["tiles_path"] = os.path.join(output_path, "meta.pmtiles") response["model_path"] = os.path.join(output_path, "checkpoint.h5") response["graph_path"] = os.path.join(output_path, "graphs") sys.stdout = sys.__stdout__ diff --git a/backend/core/urls.py b/backend/core/urls.py index b54bc169..9a236206 100644 --- a/backend/core/urls.py +++ b/backend/core/urls.py @@ -7,6 +7,7 @@ from .views import ( # APIStatus, AOIViewSet, ApprovedPredictionsViewSet, + BannerViewSet, ConflateGeojson, DatasetViewSet, FeedbackAOIViewset, @@ -26,6 +27,7 @@ UsersView, download_training_data, geojson2osmconverter, + get_kpi_stats, publish_training, run_task_status, ) @@ -44,6 +46,7 @@ router.register(r"feedback", FeedbackViewset) router.register(r"feedback-aoi", FeedbackAOIViewset) router.register(r"feedback-label", FeedbackLabelViewset) +router.register(r"banner", BannerViewSet) urlpatterns = [ @@ -71,6 +74,7 @@ "workspace/download//", TrainingWorkspaceDownloadView.as_view() ), path("workspace//", TrainingWorkspaceView.as_view()), + path("kpi/stats/", get_kpi_stats, name="get_kpi_stats"), ] if settings.ENABLE_PREDICTION_API: urlpatterns.append(path("prediction/", PredictionView.as_view())) diff --git a/backend/core/views.py b/backend/core/views.py index c89e5974..8c02394a 100644 --- a/backend/core/views.py +++ b/backend/core/views.py @@ -23,9 +23,15 @@ StreamingHttpResponse, ) from django.shortcuts import get_object_or_404, redirect +from django.utils import timezone +from django.utils.decorators import method_decorator +from django.views.decorators.cache import cache_page +from django.views.decorators.vary import vary_on_cookie, vary_on_headers from django_filters.rest_framework import DjangoFilterBackend from drf_yasg.utils import swagger_auto_schema from geojson2osm import geojson2osm +from login.authentication import OsmAuthentication +from login.permissions import IsAdminUser, IsOsmAuthenticated, IsStaffUser from orthogonalizer import othogonalize_poly from osmconflator import conflate_geojson from rest_framework import decorators, filters, serializers, status, viewsets @@ -36,12 +42,10 @@ from rest_framework.views import APIView from rest_framework_gis.filters import InBBoxFilter, TMSTileFilter -from login.authentication import OsmAuthentication -from login.permissions import IsOsmAuthenticated - from .models import ( AOI, ApprovedPredictions, + Banner, Dataset, Feedback, FeedbackAOI, @@ -54,6 +58,7 @@ from .serializers import ( AOISerializer, ApprovedPredictionsSerializer, + BannerSerializer, DatasetSerializer, FeedbackAOISerializer, FeedbackFileSerializer, @@ -82,7 +87,7 @@ class DatasetViewSet( ): # This is datasetviewset , will be tightly coupled with the models authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = Dataset.objects.all() serializer_class = DatasetSerializer # connecting serializer @@ -105,7 +110,7 @@ class Meta: read_only_fields = ( "created_at", "status", - "created_by", + "user", "started_at", "finished_at", "accuracy", @@ -142,7 +147,7 @@ def create(self, validated_data): ) user = self.context["request"].user - validated_data["created_by"] = user + validated_data["user"] = user # create the model instance multimasks = validated_data.get("multimasks", False) input_contact_spacing = validated_data.get("input_contact_spacing", 0.75) @@ -187,17 +192,27 @@ class TrainingViewSet( ): # This is TrainingViewSet , will be tightly coupled with the models authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = Training.objects.all() http_method_names = ["get", "post", "delete"] serializer_class = TrainingSerializer # connecting serializer filterset_fields = ["model", "status"] + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + serializer = self.get_serializer(instance) + feedback_count = Feedback.objects.filter( + training=instance.id + ).count() # cal feedback count + data = serializer.data + data["feedback_count"] = feedback_count + return Response(data, status=status.HTTP_200_OK) + class FeedbackViewset(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = Feedback.objects.all() http_method_names = ["get", "post", "patch", "delete"] serializer_class = FeedbackSerializer # connecting serializer @@ -207,7 +222,7 @@ class FeedbackViewset(viewsets.ModelViewSet): class FeedbackAOIViewset(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = FeedbackAOI.objects.all() http_method_names = ["get", "post", "patch", "delete"] serializer_class = FeedbackAOISerializer @@ -220,7 +235,7 @@ class FeedbackAOIViewset(viewsets.ModelViewSet): class FeedbackLabelViewset(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = FeedbackLabel.objects.all() http_method_names = ["get", "post", "patch", "delete"] serializer_class = FeedbackLabelSerializer @@ -238,7 +253,7 @@ class ModelViewSet( ): # This is ModelViewSet , will be tightly coupled with the models authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = Model.objects.all() filter_backends = ( InBBoxFilter, # it will take bbox like this api/v1/model/?in_bbox=-90,29,-89,35 , @@ -251,11 +266,11 @@ class ModelViewSet( "status": ["exact"], "created_at": ["exact", "gt", "gte", "lt", "lte"], "last_modified": ["exact", "gt", "gte", "lt", "lte"], - "created_by": ["exact"], + "user": ["exact"], "id": ["exact"], } ordering_fields = ["created_at", "last_modified", "id", "status"] - search_fields = ["name"] + search_fields = ["name", "id"] class ModelCentroidView(ListAPIView): @@ -288,7 +303,7 @@ class UsersView(ListAPIView): class AOIViewSet(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = AOI.objects.all() serializer_class = AOISerializer # connecting serializer filter_backends = [DjangoFilterBackend] @@ -298,7 +313,7 @@ class AOIViewSet(viewsets.ModelViewSet): class LabelViewSet(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = Label.objects.all() serializer_class = LabelSerializer # connecting serializer bbox_filter_field = "geom" @@ -337,7 +352,7 @@ def create(self, request, *args, **kwargs): class ApprovedPredictionsViewSet(viewsets.ModelViewSet): authentication_classes = [OsmAuthentication] permission_classes = [IsOsmAuthenticated] - permission_allowed_methods = ["GET"] + public_methods = ["GET"] queryset = ApprovedPredictions.objects.all() serializer_class = ApprovedPredictionsSerializer bbox_filter_field = "geom" @@ -577,7 +592,7 @@ def post(self, request, *args, **kwargs): model=training_instance.model, status="SUBMITTED", description=f"Feedback of Training {training_id}", - created_by=self.request.user, + user=self.request.user, zoom_level=zoom_level, epochs=epochs, batch_size=batch_size, @@ -712,16 +727,24 @@ def post(self, request, *args, **kwargs): def publish_training(request, training_id: int): """Publishes training for model""" training_instance = get_object_or_404(Training, id=training_id) + if training_instance.status != "FINISHED": return Response("Training is not FINISHED", status=404) if training_instance.accuracy < 70: return Response( - "Can't publish the training since it's accuracy is below 70 %", status=404 + "Can't publish the training since its accuracy is below 70%", status=404 ) + model_instance = get_object_or_404(Model, id=training_instance.model.id) + + # Check if the current user is the owner of the model + if model_instance.user != request.user: + return Response("You are not allowed to publish this training", status=403) + model_instance.published_training = training_instance.id model_instance.status = 0 model_instance.save() + return Response("Training Published", status=status.HTTP_201_CREATED) @@ -788,8 +811,8 @@ def get(self, request, lookup_dir=None): class TrainingWorkspaceDownloadView(APIView): - # authentication_classes = [OsmAuthentication] - # permission_classes = [IsOsmAuthenticated] + authentication_classes = [OsmAuthentication] + permission_classes = [IsOsmAuthenticated] def get(self, request, lookup_dir): base_dir = os.path.join(settings.TRAINING_WORKSPACE, lookup_dir) @@ -829,3 +852,36 @@ def get(self, request, lookup_dir): os.path.basename(base_dir) ) return response + + +class BannerViewSet(viewsets.ModelViewSet): + queryset = Banner.objects.all() + serializer_class = BannerSerializer + authentication_classes = [OsmAuthentication] + permission_classes = [IsAdminUser, IsStaffUser] + public_methods = ["GET"] + + def get_queryset(self): + now = timezone.now() + return Banner.objects.filter(start_date__lte=now).filter( + end_date__gte=now + ) | Banner.objects.filter(end_date__isnull=True) + + +@cache_page(60 * 15) ## Cache for 15 mins +# @vary_on_cookie , if you wanna do user specific cache +@api_view(["GET"]) +def get_kpi_stats(request): + total_models_with_status_published = Model.objects.filter(status=0).count() + total_registered_users = OsmUser.objects.count() + total_approved_predictions = ApprovedPredictions.objects.count() + total_feedback_labels = FeedbackLabel.objects.count() + + data = { + "total_models_published": total_models_with_status_published, + "total_registered_users": total_registered_users, + "total_accepted_predictions": total_approved_predictions, + "total_feedback_labels": total_feedback_labels, + } + + return Response(data) diff --git a/backend/login/admin.py b/backend/login/admin.py index ef46bed6..566991b1 100644 --- a/backend/login/admin.py +++ b/backend/login/admin.py @@ -1,10 +1,93 @@ +from django import forms from django.contrib import admin +from django.contrib.auth.forms import UserChangeForm, UserCreationForm +from django.db import models from .models import OsmUser -# Register your models here. + +class OsmUserCreationForm(UserCreationForm): + class Meta: + model = OsmUser + fields = ( + "username", + "email", + "osm_id", + "img_url", + "is_staff", + "is_superuser", + "is_active", + ) + + +class OsmUserChangeForm(UserChangeForm): + class Meta: + model = OsmUser + fields = ( + "username", + "email", + "osm_id", + "img_url", + "is_staff", + "is_superuser", + "is_active", + ) @admin.register(OsmUser) -class DatasetAdmin(admin.ModelAdmin): - list_display = ["osm_id", "username"] +class OsmUserAdmin(admin.ModelAdmin): + add_form = OsmUserCreationForm + form = OsmUserChangeForm + model = OsmUser + + list_display = [ + "osm_id", + "username", + "email", + "is_staff", + "is_superuser", + "last_login", + ] + list_filter = ["is_staff", "is_superuser", "is_active"] + search_fields = ["username", "email", "osm_id"] + readonly_fields = ["last_login", "date_joined"] + + fieldsets = ( + (None, {"fields": ("username", "osm_id", "email", "img_url")}), + ( + "Permissions", + { + "fields": ( + "is_active", + "is_staff", + "is_superuser", + "groups", + "user_permissions", + ) + }, + ), + ("Important dates", {"fields": ("last_login", "date_joined")}), + ) + + add_fieldsets = ( + ( + None, + { + "classes": ("wide",), + "fields": ( + "username", + "email", + "osm_id", + "img_url", + "is_staff", + "is_superuser", + "is_active", + ), + }, + ), + ) + + def formfield_for_dbfield(self, db_field, request, **kwargs): + if db_field.name == "username": + kwargs["validators"] = [] ## override the validation for sername + return super().formfield_for_dbfield(db_field, request, **kwargs) diff --git a/backend/login/authentication.py b/backend/login/authentication.py index 83e3c106..2fce1872 100644 --- a/backend/login/authentication.py +++ b/backend/login/authentication.py @@ -47,6 +47,7 @@ def authenticate(self, request): except Exception as ex: print(ex) + # raise ex raise exceptions.AuthenticationFailed( f"Osm Authentication Failed" ) # raise exception if user does not exist diff --git a/backend/login/permissions.py b/backend/login/permissions.py index fd8b482c..b40c1390 100644 --- a/backend/login/permissions.py +++ b/backend/login/permissions.py @@ -7,10 +7,52 @@ class IsOsmAuthenticated(permissions.BasePermission): def has_permission(self, request, view): - permission_allowed_methods = getattr(view, "permission_allowed_methods", []) - if request.method in permission_allowed_methods: # if request method is set to allowed give them permission + public_methods = getattr(view, "public_methods", []) + if request.method in public_methods: return True - if request.user: + + if request.user and request.user.is_authenticated: + # Global access + if request.user.is_staff or request.user.is_superuser: + return True + + return True + + return False + + def has_object_permission(self, request, view, obj): + + if request.method in permissions.SAFE_METHODS: + return True + + # Allow modification (PUT, DELETE) if the user is staff or admin + if request.user.is_staff or request.user.is_superuser: return True + ## if the object it is trying to access has user info + if hasattr(obj, "user"): + # in order to change it it needs to be in his/her name + if obj.user == request.user: + return True + else: + if request.method == "POST": + # if object doesn't have user in it then he has permission to access the object , considered as common object + return True + return False - return False \ No newline at end of file + +class IsAdminUser(permissions.BasePermission): + def has_permission(self, request, view): + public_methods = getattr(view, "public_methods", []) + if request.method in public_methods: + return True + return ( + request.user and request.user.is_authenticated and request.user.is_superuser + ) + + +class IsStaffUser(permissions.BasePermission): + def has_permission(self, request, view): + public_methods = getattr(view, "public_methods", []) + if request.method in public_methods: + return True + return request.user and request.user.is_authenticated and request.user.is_staff diff --git a/backend/requirements.txt b/backend/requirements.txt index e6edc6ac..d7b4e8cc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,3 +1,4 @@ -r api-requirements.txt hot-fair-utilities==1.3.0 -tflite-runtime==2.14.0 \ No newline at end of file +tflite-runtime==2.14.0 +tippecanoe==2.45.0 \ No newline at end of file diff --git a/backend/tests/factories.py b/backend/tests/factories.py index 829c1b4e..0a407535 100644 --- a/backend/tests/factories.py +++ b/backend/tests/factories.py @@ -1,16 +1,16 @@ import factory -from login.models import OsmUser -from django.contrib.gis.geos import Polygon from core.models import ( - Dataset, AOI, - Label, - Model, - Training, + Dataset, Feedback, FeedbackAOI, FeedbackLabel, + Label, + Model, + Training, ) +from django.contrib.gis.geos import Polygon +from login.models import OsmUser class OsmUserFactory(factory.django.DjangoModelFactory): @@ -26,7 +26,7 @@ class Meta: name = "My test dataset" source_imagery = "https://tiles.openaerialmap.org/5ac4fc6f26964b0010033112/0/5ac4fc6f26964b0010033113/{z}/{x}/{y}" - created_by = factory.SubFactory(OsmUserFactory) + user = factory.SubFactory(OsmUserFactory) class AoiFactory(factory.django.DjangoModelFactory): @@ -67,7 +67,7 @@ class Meta: dataset = factory.SubFactory(DatasetFactory) name = "My test model" - created_by = factory.SubFactory(OsmUserFactory) + user = factory.SubFactory(OsmUserFactory) class TrainingFactory(factory.django.DjangoModelFactory): @@ -76,7 +76,7 @@ class Meta: model = factory.SubFactory(ModelFactory) description = "My very first training" - created_by = factory.SubFactory(OsmUserFactory) + user = factory.SubFactory(OsmUserFactory) epochs = 1 zoom_level = [20, 21] batch_size = 1 diff --git a/backend/tests/test_endpoints.py b/backend/tests/test_endpoints.py index 246993c5..e6461768 100644 --- a/backend/tests/test_endpoints.py +++ b/backend/tests/test_endpoints.py @@ -2,18 +2,19 @@ import os import shutil -from django.conf import settings import validators +from django.conf import settings from rest_framework import status from rest_framework.test import APILiveServerTestCase, RequestsClient + from .factories import ( - OsmUserFactory, - TrainingFactory, - DatasetFactory, AoiFactory, + DatasetFactory, + FeedbackAoiFactory, LabelFactory, ModelFactory, - FeedbackAoiFactory, + OsmUserFactory, + TrainingFactory, ) API_BASE = "http://testserver/api/v1" @@ -30,9 +31,9 @@ def setUp(self): # Create a request factory instance self.client = RequestsClient() self.user = OsmUserFactory(osm_id=123) - self.dataset = DatasetFactory(created_by=self.user) + self.dataset = DatasetFactory(user=self.user) self.aoi = AoiFactory(dataset=self.dataset) - self.model = ModelFactory(dataset=self.dataset, created_by=self.user) + self.model = ModelFactory(dataset=self.dataset, user=self.user) self.json_type_header = headersList.copy() self.json_type_header["content-type"] = "application/json" @@ -187,11 +188,11 @@ def test_create_training(self): ) self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) - self.training = TrainingFactory(model=self.model, created_by=self.user) + self.training = TrainingFactory(model=self.model, user=self.user) def test_create_label(self): self.label = LabelFactory(aoi=self.aoi) - self.training = TrainingFactory(model=self.model, created_by=self.user) + self.training = TrainingFactory(model=self.model, user=self.user) # create label @@ -236,7 +237,7 @@ def test_create_label(self): def test_fetch_feedbackAoi_osm_label(self): # create feedback aoi - training = TrainingFactory(model=self.model, created_by=self.user) + training = TrainingFactory(model=self.model, user=self.user) feedbackAoi = FeedbackAoiFactory(training=training, user=self.user) # download available osm data as labels for the feedback aoi @@ -249,7 +250,7 @@ def test_fetch_feedbackAoi_osm_label(self): self.assertEqual(res.status_code, status.HTTP_201_CREATED) def test_get_runStatus(self): - training = TrainingFactory(model=self.model, created_by=self.user) + training = TrainingFactory(model=self.model, user=self.user) # get running training status @@ -259,7 +260,7 @@ def test_get_runStatus(self): self.assertEqual(res.status_code, status.HTTP_200_OK) def test_submit_training_feedback(self): - training = TrainingFactory(model=self.model, created_by=self.user) + training = TrainingFactory(model=self.model, user=self.user) # apply feedback to training published checkpoints @@ -278,7 +279,7 @@ def test_submit_training_feedback(self): self.assertEqual(res.status_code, status.HTTP_400_BAD_REQUEST) def test_publish_training(self): - training = TrainingFactory(model=self.model, created_by=self.user) + training = TrainingFactory(model=self.model, user=self.user) # publish an unfinished training should not pass @@ -288,7 +289,7 @@ def test_publish_training(self): self.assertEqual(res.status_code, status.HTTP_404_NOT_FOUND) def test_get_GpxView(self): - training = TrainingFactory(model=self.model, created_by=self.user) + training = TrainingFactory(model=self.model, user=self.user) feedbackAoi = FeedbackAoiFactory(training=training, user=self.user) # generate aoi GPX view - aoi_id diff --git a/docker-compose-cpu.yml b/docker-compose-cpu.yml index 20777bcc..a6e6eae9 100644 --- a/docker-compose-cpu.yml +++ b/docker-compose-cpu.yml @@ -69,7 +69,7 @@ services: context: ./frontend dockerfile: Dockerfile.frontend container_name: frontend - command: npm start -- --host 0.0.0.0 --port 3000 + command: npm run dev -- --host 0.0.0.0 --port 3000 ports: - 3000:3000 depends_on: diff --git a/docker-compose.yml b/docker-compose.yml index ac37dbee..cb0c6460 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -74,7 +74,7 @@ services: context: ./frontend dockerfile: Dockerfile.frontend container_name: frontend - command: npm start -- --host 0.0.0.0 --port 3000 + command: npm run dev -- --host 0.0.0.0 --port 3000 ports: - 3000:3000 depends_on: diff --git a/frontend/Dockerfile.frontend b/frontend/Dockerfile.frontend new file mode 100644 index 00000000..0738b662 --- /dev/null +++ b/frontend/Dockerfile.frontend @@ -0,0 +1,15 @@ + +FROM node:20.18 + +WORKDIR /app + + +COPY . /app + + +RUN npm install --force + + +# RUN npm run build + +# EXPOSE 3000 \ No newline at end of file diff --git a/setup-ramp.sh b/setup-ramp.sh index 817844ee..84abdc31 100644 --- a/setup-ramp.sh +++ b/setup-ramp.sh @@ -2,29 +2,27 @@ ## To run this activate your venv and hit bash setup-ramp.sh -# Step 1: Create a new folder called 'ramp' outside fAIr -mkdir -p ramp -# Step 2: Install gdown for downloading files from Google Drive -pip install gdown +if [ ! -d "ramp_base" ]; then -# Step 3: Download BaseModel Checkpoint from Google Drive -gdown --fuzzy https://drive.google.com/uc?id=1YQsY61S_rGfJ_f6kLQq4ouYE2l3iRe1k + mkdir -p ramp_base -# Step 4: Clone the Ramp code repository -git clone https://github.com/kshitijrajsharma/ramp-code-fAIr.git ramp-code -# Step 5: Unzip the downloaded BaseModel checkpoint into the 'ramp' directory inside the cloned repository -unzip checkpoint.tf.zip -d ramp-code/ramp + pip install gdown -# Step 6: Define the current location for environment variables -RAMP_HOME="$(pwd)/ramp" + + gdown --fuzzy https://drive.google.com/uc?id=1YQsY61S_rGfJ_f6kLQq4ouYE2l3iRe1k + + git clone https://github.com/kshitijrajsharma/ramp-code-fAIr.git "$(pwd)/ramp_base/ramp-code" + + + unzip checkpoint.tf.zip -d "$(pwd)/ramp_base/ramp-code/ramp" + + echo "Setup complete. Please run 'source .env' to apply the environment variables." +fi + +RAMP_HOME="$(pwd)/ramp_base" TRAINING_WORKSPACE="$(pwd)/trainings" -# Step 7: Create a '.env' file with the exported variables echo "export RAMP_HOME=$RAMP_HOME" > .env echo "export TRAINING_WORKSPACE=$TRAINING_WORKSPACE" >> .env - -# Print success message -echo "Setup complete. Please run 'source .env' to apply the environment variables." -