diff --git a/hot_fair_utilities/training/yolo_v8_v1/train.py b/hot_fair_utilities/training/yolo_v8_v1/train.py index e9ff763..05d2ebe 100644 --- a/hot_fair_utilities/training/yolo_v8_v1/train.py +++ b/hot_fair_utilities/training/yolo_v8_v1/train.py @@ -11,9 +11,9 @@ from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics # Get environment variables with fallbacks -ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute())) -DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training"))) -LOGS_ROOT = str(Path(os.getenv("YOLO_LOGS_ROOT", ROOT / "checkpoints"))) +# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute())) +# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training"))) +# LOGS_ROOT = str(Path(os.getenv("YOLO_LOGS_ROOT", ROOT / "checkpoints"))) # Different hyperparameters from default in YOLOv8 release models # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml @@ -118,6 +118,7 @@ def train( epochs=int(epochs), resume=resume, deterministic=False, + save_dir= os.path.join(output_path), device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu, **kwargs, ) diff --git a/hot_fair_utilities/training/yolo_v8_v2/train.py b/hot_fair_utilities/training/yolo_v8_v2/train.py index 9afd1c7..37467ba 100644 --- a/hot_fair_utilities/training/yolo_v8_v2/train.py +++ b/hot_fair_utilities/training/yolo_v8_v2/train.py @@ -10,8 +10,8 @@ # Reader imports from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight -ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute())) -DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training"))) +# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute())) +# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training"))) HYPERPARAM_CHANGES = { @@ -80,6 +80,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path, epochs=int(epochs), resume=resume, deterministic=False, + save_dir= os.path.join(output_path), device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu, **kwargs, )