Skip to content

Commit

Permalink
Fix tests and post-training analysis behavior (#221)
Browse files Browse the repository at this point in the history
* Fix: post-training analysis should use result of best-model checkpointing

* Fix: create split dataset eagerly for test_train multiview test

* Fix: predict using the config that includes hydra overrides
  • Loading branch information
ksikka authored Nov 19, 2024
1 parent 14ea627 commit d5f01c1
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
16 changes: 10 additions & 6 deletions lightning_pose/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Example model training function."""

import copy
import os
import random
import shutil
Expand All @@ -19,7 +19,7 @@
return_absolute_data_paths,
return_absolute_path,
)
from lightning_pose.utils.predictions import predict_dataset
from lightning_pose.utils.predictions import load_model_from_checkpoint, predict_dataset
from lightning_pose.utils.scripts import (
calculate_train_batches,
compute_metrics,
Expand Down Expand Up @@ -176,7 +176,7 @@ def train(cfg: DictConfig) -> None:

# make unaugmented data_loader if necessary
if cfg.training.imgaug != "default":
cfg_pred = cfg.copy()
cfg_pred = copy.deepcopy(cfg)
cfg_pred.training.imgaug = "default"
imgaug_transform_pred = get_imgaug_transform(cfg=cfg_pred)
dataset_pred = get_dataset(
Expand All @@ -187,6 +187,13 @@ def train(cfg: DictConfig) -> None:
else:
data_module_pred = data_module

model = load_model_from_checkpoint(
cfg=cfg,
ckpt_file=best_ckpt,
eval=True,
data_module=data_module_pred,
)

# ----------------------------------------------------------------------------------
# predict on all labeled frames (train/val/test)
# ----------------------------------------------------------------------------------
Expand All @@ -200,7 +207,6 @@ def train(cfg: DictConfig) -> None:
trainer=trainer,
model=model,
data_module=data_module_pred,
ckpt_file=best_ckpt,
preds_file=preds_file,
)
# compute and save various metrics
Expand Down Expand Up @@ -244,7 +250,6 @@ def train(cfg: DictConfig) -> None:
export_predictions_and_labeled_video(
video_file=video_file,
cfg=cfg,
ckpt_file=best_ckpt,
prediction_csv_file=prediction_csv_file,
labeled_mp4_file=labeled_mp4_file,
trainer=trainer,
Expand Down Expand Up @@ -299,7 +304,6 @@ def train(cfg: DictConfig) -> None:
trainer=trainer,
model=model,
data_module=data_module_ood,
ckpt_file=best_ckpt,
preds_file=preds_file_ood,
)
# compute and save various metrics
Expand Down
2 changes: 1 addition & 1 deletion scripts/predict_new_vids.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def predict_videos_in_dir(cfg: DictConfig):
absolute_cfg_path = return_absolute_path(hydra_relative_path, n_dirs_back=2)

# load model
model_cfg = OmegaConf.load(os.path.join(absolute_cfg_path, ".hydra/config.yaml"))
model_cfg = OmegaConf.load(os.path.join(absolute_cfg_path, "config.yaml"))
ckpt_file = ckpt_path_from_base_path(
base_path=absolute_cfg_path, model_name=model_cfg.model.model_name
)
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def cfg_multiview() -> dict:
return cfg


def make_multiview_dataset() -> None:

def create_multiview_dataset_if_not_exists() -> None:
# create multiview dataset
repo_dir = os.path.dirname(os.path.dirname(os.path.join(__file__)))
base_dir = os.path.join(repo_dir, TOY_DATA_ROOT_DIR)
Expand Down Expand Up @@ -238,7 +237,6 @@ def heatmap_dataset(cfg, imgaug_transform) -> HeatmapDataset:
@pytest.fixture
def multiview_heatmap_dataset(cfg_multiview, imgaug_transform) -> MultiviewHeatmapDataset:
"""Create a dataset for heatmap models from toy data."""
make_multiview_dataset()
# setup
cfg_tmp = copy.deepcopy(cfg_multiview)
cfg_tmp.model.model_type = "heatmap"
Expand Down Expand Up @@ -604,3 +602,6 @@ def _run_model_test(cfg, data_module, video_dataloader, trainer, remove_logs_fn)
remove_logs_fn()

return _run_model_test


create_multiview_dataset_if_not_exists()

0 comments on commit d5f01c1

Please sign in to comment.