Skip to content

Commit

Permalink
Added log save dir
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Nov 12, 2024
1 parent 4a1f2a6 commit 4d46b4a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions hot_fair_utilities/training/yolo_v8_v1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight
from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics
# Get environment variables with fallbacks
ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))
LOGS_ROOT = str(Path(os.getenv("YOLO_LOGS_ROOT", ROOT / "checkpoints")))
# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))
# LOGS_ROOT = str(Path(os.getenv("YOLO_LOGS_ROOT", ROOT / "checkpoints")))

# Different hyperparameters from default in YOLOv8 release models
# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
Expand Down Expand Up @@ -118,6 +118,7 @@ def train(
epochs=int(epochs),
resume=resume,
deterministic=False,
save_dir= os.path.join(output_path),
device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu,
**kwargs,
)
Expand Down
5 changes: 3 additions & 2 deletions hot_fair_utilities/training/yolo_v8_v2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# Reader imports
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight

ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))
# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))


HYPERPARAM_CHANGES = {
Expand Down Expand Up @@ -80,6 +80,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path,
epochs=int(epochs),
resume=resume,
deterministic=False,
save_dir= os.path.join(output_path),
device=[int(i) for i in gpu.split(",")] if "," in gpu else gpu,
**kwargs,
)
Expand Down

0 comments on commit 4d46b4a

Please sign in to comment.