Save out detector artifacts
ksikka committed Dec 4, 2024
1 parent d5f01c1 commit 8b1a08d
Showing 77 changed files with 893 additions and 333 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
tests/**/*.mp4 filter=lfs diff=lfs merge=lfs -text
tests/**/*.png filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,17 @@
ImageEncoderViT
===============

.. currentmodule:: lightning_pose.models.backbones.vit_img_encoder

.. autoclass:: ImageEncoderViT
:show-inheritance:

.. rubric:: Methods Summary

.. autosummary::

~ImageEncoderViT.forward

.. rubric:: Methods Documentation

.. automethod:: forward
2 changes: 0 additions & 2 deletions docs/api/lightning_pose.models.base.BaseSupervisedTracker.rst
@@ -12,7 +12,6 @@ BaseSupervisedTracker

~BaseSupervisedTracker.evaluate_labeled
~BaseSupervisedTracker.get_loss_inputs_labeled
~BaseSupervisedTracker.get_parameters
~BaseSupervisedTracker.test_step
~BaseSupervisedTracker.training_step
~BaseSupervisedTracker.validation_step
@@ -21,7 +20,6 @@ BaseSupervisedTracker

.. automethod:: evaluate_labeled
.. automethod:: get_loss_inputs_labeled
.. automethod:: get_parameters
.. automethod:: test_step
.. automethod:: training_step
.. automethod:: validation_step
@@ -12,12 +12,10 @@ SemiSupervisedTrackerMixin

~SemiSupervisedTrackerMixin.evaluate_unlabeled
~SemiSupervisedTrackerMixin.get_loss_inputs_unlabeled
~SemiSupervisedTrackerMixin.get_parameters
~SemiSupervisedTrackerMixin.training_step

.. rubric:: Methods Documentation

.. automethod:: evaluate_unlabeled
.. automethod:: get_loss_inputs_unlabeled
.. automethod:: get_parameters
.. automethod:: training_step
@@ -11,9 +11,7 @@ SemiSupervisedHeatmapTrackerMHCRNN
.. autosummary::

~SemiSupervisedHeatmapTrackerMHCRNN.get_loss_inputs_unlabeled
~SemiSupervisedHeatmapTrackerMHCRNN.get_parameters

.. rubric:: Methods Documentation

.. automethod:: get_loss_inputs_unlabeled
.. automethod:: get_parameters
6 changes: 0 additions & 6 deletions docs/api/lightning_pose.utils.io.load_label_csv_from_cfg.rst

This file was deleted.

@@ -1,6 +1,6 @@
export_predictions_and_labeled_video
====================================

.. currentmodule:: lightning_pose.utils.scripts
.. currentmodule:: lightning_pose.utils.predictions

.. autofunction:: export_predictions_and_labeled_video
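
This docs entry now points at the module where the function actually lives, matching the import change in train.py further below; a minimal import sketch for reference:

# The function is documented (and imported) from utils.predictions rather than utils.scripts.
from lightning_pose.utils.predictions import export_predictions_and_labeled_video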
6 changes: 0 additions & 6 deletions docs/api/lightning_pose.utils.predictions.get_cfg_file.rst

This file was deleted.

6 changes: 0 additions & 6 deletions docs/api/lightning_pose.utils.predictions.make_cmap.rst

This file was deleted.

37 changes: 3 additions & 34 deletions lightning_pose/data/datasets.py
@@ -1,6 +1,5 @@
"""Dataset objects store images, labels, and functions for manipulation."""

from __future__ import annotations # python 3.8 compatibility for sphinx

import os
import re
@@ -21,7 +20,8 @@
MultiviewHeatmapLabeledExampleDict,
generate_heatmaps,
)
from lightning_pose.utils.io import get_keypoint_names
from lightning_pose.utils.io import get_context_img_paths, get_keypoint_names


# to ignore imports for sphix-autoapidoc
__all__ = [
@@ -133,7 +133,7 @@ def __getitem__(self, idx: int) -> BaseLabeledExampleDict:
transformed_images = self.pytorch_transform(transformed_images)

else:
context_img_paths = _get_context_img_paths(img_path)
context_img_paths = get_context_img_paths(img_path)
# read the images from image list to create dataset
images = []
for path in context_img_paths:
@@ -494,34 +494,3 @@ def __getitem__(self, idx: int) -> MultiviewHeatmapLabeledExampleDict:
view_names=self.view_names, # List[str]
)


def _get_context_img_paths(center_img_path: Path) -> list[Path]:
"""Given the path to a center image frame, return paths of 5 context frames
(n-2, n-1, n, n+1, n+2).
Negative indices are floored at 0.
"""
match = re.search(r"(\d+)", center_img_path.stem)
assert (
match is not None
), f"No frame index in filename, can't get context frames: {center_img_path.name}"

center_index_string = match.group()
center_index = int(center_index_string)

context_img_paths = []
for index in range(
center_index - 2, center_index + 3
): # End at n+3 exclusive, n+2 inclusive.
# Negative indices are floored at 0.
index = max(index, 0)

# Add leading zeros to match center_index_string length.
index_string = str(index).zfill(len(center_index_string))

stem = center_img_path.stem.replace(center_index_string, index_string)
path = center_img_path.with_stem(stem)

context_img_paths.append(path)

return context_img_paths
46 changes: 36 additions & 10 deletions lightning_pose/train.py
@@ -13,17 +13,25 @@
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
from typeguard import typechecked

from lightning_pose.utils.cropzoom import (
generate_cropped_labeled_frames,
generate_cropped_video,
)
from lightning_pose.utils import pretty_print_cfg, pretty_print_str
from lightning_pose.utils.io import (
check_video_paths,
ckpt_path_from_base_path,
return_absolute_data_paths,
return_absolute_path,
)
from lightning_pose.utils.predictions import load_model_from_checkpoint, predict_dataset
from lightning_pose.utils.predictions import (
export_predictions_and_labeled_video,
load_model_from_checkpoint,
predict_dataset,
)
from lightning_pose.utils.scripts import (
calculate_train_batches,
compute_metrics,
export_predictions_and_labeled_video,
get_callbacks,
get_data_module,
get_dataset,
@@ -168,7 +176,9 @@ def train(cfg: DictConfig) -> None:
# Post-training analysis
# ----------------------------------------------------------------------------------
# get best ckpt
best_ckpt = os.path.abspath(trainer.checkpoint_callback.best_model_path)
best_ckpt = ckpt_path_from_base_path(
base_path=hydra_output_directory, model_name=cfg.model.model_name
)
print(f"Best checkpoint: {os.path.basename(best_ckpt)}")
# check if best_ckpt is a file
if not os.path.isfile(best_ckpt):
@@ -219,10 +229,20 @@
preds_file = multiview_pred_files
compute_metrics(cfg=cfg, preds_file=preds_file, data_module=data_module_pred)

is_detector = (
cfg.get("detector") is not None and cfg.detector.get("crop_ratio") is not None
)
if is_detector:
generate_cropped_labeled_frames(
root_directory=Path(data_dir),
output_directory=Path(hydra_output_directory),
detector_cfg=cfg.detector,
)

# ----------------------------------------------------------------------------------
# predict folder of videos
# ----------------------------------------------------------------------------------
if cfg.eval.predict_vids_after_training:
if cfg.eval.predict_vids_after_training or is_detector:
pretty_print_str("Predicting videos...")
if cfg.eval.test_videos_directory is None:
filenames = []
@@ -235,6 +255,7 @@

for video_file in filenames:
assert os.path.isfile(video_file)

pretty_print_str(f"Predicting video: {video_file}...")
# get save name for prediction csv file
video_pred_dir = os.path.join(hydra_output_directory, "video_preds")
@@ -255,18 +276,23 @@
trainer=trainer,
model=model,
data_module=data_module_pred,
save_heatmaps=cfg.eval.get("predict_vids_after_training_save_heatmaps", False),
save_heatmaps=cfg.eval.get(
"predict_vids_after_training_save_heatmaps", False
),
)
# compute and save various metrics
# try:

compute_metrics(
cfg=cfg,
preds_file=prediction_csv_file,
data_module=data_module_pred,
)
# except Exception as e:
# print(f"Error predicting on video {video_file}:\n{e}")
# continue

if is_detector:
generate_cropped_video(
video_path=Path(video_file),
detector_model_dir=Path(hydra_output_directory),
detector_cfg=cfg.detector,
)

# ----------------------------------------------------------------------------------
# predict on OOD frames
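
To make the new gate in train.py concrete, here is a minimal, hypothetical sketch of the config shape that enables the detector post-processing. Only the detector / crop_ratio keys are taken from the diff; the value is invented for illustration.

from omegaconf import OmegaConf

# Hypothetical config: only "detector" and "crop_ratio" are referenced by the
# new train.py code; the value 1.5 is a made-up example, not from this commit.
cfg = OmegaConf.create({"detector": {"crop_ratio": 1.5}})

is_detector = (
    cfg.get("detector") is not None and cfg.detector.get("crop_ratio") is not None
)
# When True, cropped labeled frames are generated after training, and each
# predicted video is additionally cropped via generate_cropped_video().
print(is_detector)  # True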
57 changes: 38 additions & 19 deletions lightning_pose/utils/io.py
@@ -1,8 +1,12 @@
"""Path handling functions."""
from __future__ import annotations # python 3.8 compatibility for sphinx


import os
from typing import List, Optional, Tuple, Union

from pathlib import Path
import re
import pandas as pd
from omegaconf import DictConfig, ListConfig
from typeguard import typechecked
@@ -11,7 +15,6 @@
__all__ = [
"ckpt_path_from_base_path",
"check_if_semi_supervised",
"load_label_csv_from_cfg",
"get_keypoint_names",
"return_absolute_path",
"return_absolute_data_paths",
@@ -25,7 +28,6 @@ def ckpt_path_from_base_path(
base_path: str,
model_name: str,
logging_dir_name: str = "tb_logs/",
version: int = 0,
) -> str:
"""Given a path to a hydra output with trained model, extract the model .ckpt file.
Expand All @@ -48,7 +50,7 @@ def ckpt_path_from_base_path(
base_path,
logging_dir_name, # may change when we switch from Tensorboard
model_name, # get the name string of the model (determined pre-training)
"version_%i" % version, # always version_0 because ptl starts a version_0 dir
"version_0", # always version_0 because enable_version_counter=False
"checkpoints",
"*.ckpt",
)
@@ -84,22 +86,6 @@ def check_if_semi_supervised(losses_to_use: Union[ListConfig, list, None] = None
return semi_supervised


@typechecked
def load_label_csv_from_cfg(cfg: Union[DictConfig, dict]) -> pd.DataFrame:
"""Helper function for easy loading.
Args:
cfg: DictConfig
Returns:
pd.DataFrame
"""

csv_file = os.path.join(cfg["data"]["data_dir"], cfg["data"]["csv_file"])
labels_df = pd.read_csv(csv_file, header=[0, 1, 2], index_col=0)
return labels_df


@typechecked
def get_keypoint_names(
cfg: Optional[DictConfig] = None,
@@ -252,3 +238,36 @@ def check_video_paths(
assert f.endswith(".mp4"), "video files must be mp4 format!"

return filenames


@typechecked
def get_context_img_paths(center_img_path: Path) -> list[Path]:
"""Given the path to a center image frame, return paths of 5 context frames
(n-2, n-1, n, n+1, n+2).
Negative indices are floored at 0.
"""
match = re.search(r"(\d+)", center_img_path.stem)
assert (
match is not None
), f"No frame index in filename, can't get context frames: {center_img_path.name}"

center_index_string = match.group()
center_index = int(center_index_string)

context_img_paths = []
for index in range(
center_index - 2, center_index + 3
): # End at n+3 exclusive, n+2 inclusive.
# Negative indices are floored at 0.
index = max(index, 0)

# Add leading zeros to match center_index_string length.
index_string = str(index).zfill(len(center_index_string))

stem = center_img_path.stem.replace(center_index_string, index_string)
path = center_img_path.with_stem(stem)

context_img_paths.append(path)

return context_img_paths
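
The helper that was private to datasets.py (removed above) is now exported from lightning_pose.utils.io as get_context_img_paths; a quick usage sketch with a hypothetical frame path:

from pathlib import Path

from lightning_pose.utils.io import get_context_img_paths

# Hypothetical file name; only the digits in the file stem are used to derive
# the n-2 .. n+2 context frames, zero-padded to the same width.
paths = get_context_img_paths(Path("labeled-data/session0/img000123.png"))
print([p.name for p in paths])
# ['img000121.png', 'img000122.png', 'img000123.png', 'img000124.png', 'img000125.png']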