diff --git a/backend/core/tasks.py b/backend/core/tasks.py index bfbf7c3b..b037ab01 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -7,9 +7,7 @@ import traceback from shutil import rmtree -import hot_fair_utilities -import ramp.utils -import tensorflow as tf + from celery import shared_task from core.models import AOI, Feedback, FeedbackAOI, FeedbackLabel, Label, Training from core.serializers import ( @@ -25,8 +23,6 @@ from django.contrib.gis.geos import GEOSGeometry from django.shortcuts import get_object_or_404 from django.utils import timezone -from hot_fair_utilities import preprocess, train -from hot_fair_utilities.training import run_feedback from predictor import download_imagery, get_start_end_download_coords logger = logging.getLogger(__name__) @@ -82,6 +78,13 @@ def train_model( feedback=None, freeze_layers=False, ): + #importing them here so that it won't be necessary when sending tasks ( api only) + import hot_fair_utilities + import ramp.utils + import tensorflow as tf + from hot_fair_utilities import preprocess, train + from hot_fair_utilities.training import run_feedback + training_instance = get_object_or_404(Training, id=training_id) training_instance.status = "RUNNING" training_instance.started_at = timezone.now()