Skip to content

Commit

Permalink
Fix split calculation and allow for not embedding (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo authored Sep 29, 2024
1 parent 391df6e commit bbee932
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def write_labels(
"""
if Path(labels_path).exists():
Path(labels_path).unlink()
if embed is not None:
if embed:
embed_videos(labels_path, labels, embed)
write_videos(labels_path, labels.videos, restore_source=(embed == "source"))
write_tracks(labels_path, labels.tracks)
Expand Down
33 changes: 28 additions & 5 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ def make_training_splits(
n_test: int | float | None = None,
save_dir: str | Path | None = None,
seed: int | None = None,
embed: bool = True,
) -> tuple[Labels, Labels] | tuple[Labels, Labels, Labels]:
"""Make splits for training with embedded images.
Expand All @@ -635,6 +636,10 @@ def make_training_splits(
test split will not be saved.
save_dir: If specified, save splits to SLP files with embedded images.
seed: Optional integer seed to use for reproducibility.
embed: If `True` (the default), embed user labeled frame images in the saved
files, which is useful for portability but can be slow for large
projects. If `False`, labels are saved with references to the source
videos files.
Returns:
A tuple of `labels_train, labels_val` or
Expand All @@ -652,6 +657,12 @@ def make_training_splits(
- `{save_dir}/val.pkg.slp`
- `{save_dir}/test.pkg.slp` (if `n_test` is specified)
If `embed` is `False`, the files will be saved without embedded images to:
- `{save_dir}/train.slp`
- `{save_dir}/val.slp`
- `{save_dir}/test.slp` (if `n_test` is specified)
See also: `Labels.split`
"""
# Clean up labels.
Expand All @@ -660,16 +671,23 @@ def make_training_splits(
labels.suggestions = []
labels.clean()

# Make splits.
# Make train split.
labels_train, labels_rest = labels.split(n_train, seed=seed)

# Make test split.
if n_test is not None:
if n_test < 1:
n_test = (n_test * len(labels)) / len(labels_rest)
labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed)

# Make val split.
if n_val is not None:
if n_val < 1:
n_val = (n_val * len(labels)) / len(labels_rest)
labels_val, _ = labels_rest.split(n=n_val, seed=seed)
if isinstance(n_val, float) and n_val == 1.0:
labels_val = labels_rest
else:
labels_val, _ = labels_rest.split(n=n_val, seed=seed)
else:
labels_val = labels_rest

Expand All @@ -678,9 +696,14 @@ def make_training_splits(
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)

labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
labels_test.save(save_dir / "test.pkg.slp", embed="user")
if embed:
labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
labels_test.save(save_dir / "test.pkg.slp", embed="user")
else:
labels_train.save(save_dir / "train.slp", embed=False)
labels_val.save(save_dir / "val.slp", embed=False)
labels_test.save(save_dir / "test.slp", embed=False)

if n_test is None:
return labels_train, labels_val
Expand Down
42 changes: 41 additions & 1 deletion tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_split(slp_real_data, tmp_path):
)


def test_make_training_splits(slp_real_data, tmp_path):
def test_make_training_splits(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels.user_labeled_frames) == 5

Expand Down Expand Up @@ -568,6 +568,11 @@ def test_make_training_splits(slp_real_data, tmp_path):
assert len(val) == 1
assert len(test) == 1

train, val, test = labels.make_training_splits(n_train=0.4, n_val=0.4, n_test=0.2)
assert len(train) == 2
assert len(val) == 2
assert len(test) == 1


def test_make_training_splits_save(slp_real_data, tmp_path):
labels = load_slp(slp_real_data)
Expand All @@ -587,3 +592,38 @@ def test_make_training_splits_save(slp_real_data, tmp_path):
assert train_.provenance["source_labels"] == slp_real_data
assert val_.provenance["source_labels"] == slp_real_data
assert test_.provenance["source_labels"] == slp_real_data


@pytest.mark.parametrize("embed", [True, False])
def test_make_training_splits_save(slp_real_data, tmp_path, embed):
labels = load_slp(slp_real_data)

train, val, test = labels.make_training_splits(
0.6, 0.2, 0.2, save_dir=tmp_path, embed=embed
)

if embed:
train_, val_, test_ = (
load_slp(tmp_path / "train.pkg.slp"),
load_slp(tmp_path / "val.pkg.slp"),
load_slp(tmp_path / "test.pkg.slp"),
)
else:
train_, val_, test_ = (
load_slp(tmp_path / "train.slp"),
load_slp(tmp_path / "val.slp"),
load_slp(tmp_path / "test.slp"),
)

assert len(train_) == len(train)
assert len(val_) == len(val)
assert len(test_) == len(test)

if embed:
assert train_.provenance["source_labels"] == slp_real_data
assert val_.provenance["source_labels"] == slp_real_data
assert test_.provenance["source_labels"] == slp_real_data
else:
assert train_.video.filename == labels.video.filename
assert val_.video.filename == labels.video.filename
assert test_.video.filename == labels.video.filename

0 comments on commit bbee932

Please sign in to comment.