Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix - store snapshotindex for the latest snapshot instead of the filename #110

Merged
merged 18 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,14 @@ def make(self, key):
task_mode, output_dir = (PoseEstimationTask & key).fetch1(
"task_mode", "pose_estimation_output_dir"
)

if not output_dir:
output_dir = PoseEstimationTask.infer_output_dir(
key, relative=True, mkdir=True
)
# update pose_estimation_output_dir
PoseEstimationTask.update1(
{**key, "pose_estimation_output_dir": output_dir.as_posix()}
)
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

# Triger PoseEstimation
Expand Down
53 changes: 42 additions & 11 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import datajoint as dj
import inspect
import importlib
import os
import re
from pathlib import Path
import yaml

from element_interface.utils import find_full_path, dict_to_uuid
from .readers import dlc_reader

Expand Down Expand Up @@ -241,7 +243,7 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip
import deeplabcut

try:
from deeplabcut.utils.auxiliaryfunctions import (
Expand Down Expand Up @@ -288,13 +290,39 @@ def make(self, key):
)
model_train_folder = project_path / model_folder / "train"

# update path of the init_weight
with open(model_train_folder / "pose_cfg.yaml", "r") as f:
pose_cfg = yaml.safe_load(f)
init_weights_path = Path(pose_cfg["init_weights"])

if (
"pose_estimation_tensorflow/models/pretrained"
in init_weights_path.as_posix()
):
# this is the res_net models, construct new path here
init_weights_path = (
Path(deeplabcut.__path__[0])
/ "pose_estimation_tensorflow/models/pretrained"
/ init_weights_path.name
)
else:
# this is existing snapshot weights, update path here
init_weights_path = model_train_folder / init_weights_path.name

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
{
"project_path": project_path.as_posix(),
"init_weights": init_weights_path.as_posix(),
"dataset": Path(pose_cfg["dataset"]).as_posix(),
"metadataset": Path(pose_cfg["metadataset"]).as_posix(),
},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_input_args = list(
inspect.signature(deeplabcut.train_network).parameters
)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
Expand All @@ -304,25 +332,28 @@ def make(self, key):
train_network_kwargs[k] = int(train_network_kwargs[k])

try:
train_network(dlc_cfg_filepath, **train_network_kwargs)
deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

snapshots = list(model_train_folder.glob("*index*"))
max_modified_time = 0
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
# Here, we mean most recently generated
snapshots = sorted(model_train_folder.glob("snapshot*.index"))
max_modified_time = 0
for snapshot in snapshots:
modified_time = os.path.getmtime(snapshot)
modified_time = snapshot.stat().st_mtime
if modified_time > max_modified_time:
latest_snapshot = int(snapshot.stem[9:])
latest_snapshot_file = snapshot
latest_snapshot = int(re.search(r"(\d+)\.index", latest_snapshot_file.name).group(1))
max_modified_time = modified_time

# update snapshotindex in the config
dlc_config["snapshotindex"] = latest_snapshot
snapshotindex = snapshots.index(latest_snapshot_file)

dlc_config["snapshotindex"] = snapshotindex
edit_config(
dlc_cfg_filepath,
{"snapshotindex": latest_snapshot},
{"snapshotindex": snapshotindex},
)

self.insert1(
Expand Down
Loading