From eb0eb4f77ed80729231105edd3002aafd7e9832c Mon Sep 17 00:00:00 2001 From: kshitijrajsharma Date: Fri, 6 Dec 2024 13:04:14 +0100 Subject: [PATCH] fix(tasks): enhanced model training zipping by adding input image path and improving file copying --- backend/core/tasks.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/backend/core/tasks.py b/backend/core/tasks.py index 95308f72..451c3d60 100644 --- a/backend/core/tasks.py +++ b/backend/core/tasks.py @@ -225,6 +225,9 @@ def ramp_model_training( shutil.rmtree(output_path) shutil.copytree(final_model_path, os.path.join(output_path, "checkpoint.tf")) shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed")) + shutil.copytree( + model_input_image_path, os.path.join(output_path, "preprocessed", "input") + ) graph_output_path = f"{base_path}/train/graphs" shutil.copytree(graph_output_path, os.path.join(output_path, "graphs")) @@ -374,11 +377,30 @@ def yolo_model_training( os.path.join(os.path.dirname(output_model_path), "best.onnx"), os.path.join(output_path, "checkpoint.onnx"), ) + shutil.copyfile( + os.path.join(os.path.dirname(output_model_path), "best.onnx"), + os.path.join(output_path, "checkpoint.onnx"), + ) # shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite")) shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed")) + shutil.copytree( + model_input_image_path, os.path.join(output_path, "preprocessed", "input") + ) os.makedirs(os.path.join(output_path, model), exist_ok=True) + shutil.copytree( + os.path.join(yolo_data_dir, "images"), + os.path.join(output_path, model, "images"), + ) + shutil.copytree( + os.path.join(yolo_data_dir, "labels"), + os.path.join(output_path, model, "labels"), + ) + shutil.copyfile( + os.path.join(yolo_data_dir, "yolo_dataset.yaml"), + os.path.join(output_path, model, "yolo_dataset.yaml"), + ) shutil.copytree( os.path.join(yolo_data_dir, "images"), os.path.join(output_path, model, "images"), @@ -473,7 +495,7 @@ def train_model( if training_instance.task_id is None or training_instance.task_id.strip() == "": training_instance.task_id = train_model.request.id training_instance.save() - log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log") + log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}_log.txt") if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None: raise ValueError("YOLO Home is not configured") @@ -481,10 +503,9 @@ def train_model( raise ValueError("Ramp Home is not configured") try: - with open(log_file, "a") as f: + with open(log_file, "w") as f: # redirect stdout to the log file sys.stdout = f - logging.info("Training Started") training_input_image_source, aoi_serializer, serialized_field = ( prepare_data( training_instance, dataset_id, feedback, zoom_level, source_imagery