Skip to content

Commit

Permalink
Basic COCO writing implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Jul 3, 2024
1 parent 7be4993 commit c9947b6
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 12 deletions.
136 changes: 125 additions & 11 deletions sleap_io/io/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@
from sleap_io import (
Video,
Skeleton,
Edge,
Symmetry,
Node,
Track,
SuggestionFrame,
Point,
PredictedPoint,
Instance,
PredictedInstance,
LabeledFrame,
Labels,
)

import sys
import imageio.v3 as iio

try:
import cv2
except ImportError:
pass


def read_ann(ann_json_path: str | Path):
"""Read annotations JSON file.
Expand Down Expand Up @@ -51,9 +52,9 @@ def make_skeleton(ann: dict) -> Skeleton:
This assumes that `skeleton` (edge indices) are 1-based.
"""
return Skeleton(
nodes=ann["categories"]["keypoints"],
edges=(np.array(ann["categories"]["skeleton"]) - 1).tolist(),
name=ann["categories"].get("name", None),
nodes=ann["categories"][0]["keypoints"],
edges=(np.array(ann["categories"][0]["skeleton"]) - 1).tolist(),
name=ann["categories"][0].get("name", None),
)


Expand Down Expand Up @@ -84,7 +85,7 @@ def make_videos(
video_id_map = {}
for img in ann["images"]:
shape = img["height"], img["width"]
img_filename = img["filename"]
img_filename = img["file_name"]
if imgs_prefix is not None:
img_filename = (imgs_prefix / img_filename).as_posix()
imgs_by_shape[shape].append(img_filename)
Expand Down Expand Up @@ -175,3 +176,116 @@ def read_labels(
skeleton = make_skeleton(ann)
labels = make_labels(ann, videos, video_id_map, skeleton)
return labels


def write_labels(
labels: Labels,
dataset_folder: str | Path,
split: str | None = None,
img_format: str = "png",
):
"""Save a `Labels` to COCO format.
Args:
labels: A `Labels` object.
dataset_folder: Path to a folder to save data to.
split: Optional string specifying the split name.
img_format: Format to save images to. Formats: "png" (default) or "jpg".
Notes:
If `split` was not provided, the annotations will be saved to
`{dataset_folder}/annotations/ann.json` and images will be saved to
`{dataset_folder}/images`.
If `split` was provided, the annotations will be saved to
`{dataset_folder}/annotations/ann_{split}.json` and images will be saved to
`{dataset_folder}/images/{split}`.
Calling this multiple times with the same dataset folder may overwrite previous
data if `split` is not provided.
"""
if split is None:
ann_path = dataset_folder / "annotations" / "ann.json"
imgs_folder = dataset_folder / "images"
else:
ann_path = dataset_folder / "annotations" / f"ann_{split}.json"
imgs_folder = dataset_folder / "images" / split

ann_path.parent.mkdir(parents=True, exist_ok=True)
imgs_folder.mkdir(parents=True, exist_ok=True)

lfs = labels.user_labeled_frames

imgs = []
img_filename_map = {}
for img_id, lf in enumerate(lfs):
img_filename = f"{img_id}.{img_format}"
img_shape = video.shape[[1, 2]]
imgs.append(
{
"id": img_id,
"file_name": img_filename.as_posix(),
"height": img_shape[0],
"width": img_shape[1],
}
)
img_filename_map[(lf.video, lf.frame_idx)] = img_filename

for (video, frame_idx), img_filename in img_filename_map.items():
img = video[frame_idx]
img_path = (imgs_folder / img_filename).as_posix()
if "cv2" in sys.modules:
cv2.imwrite(img_path, img)
else:
iio.imwrite(img_path, img)

inst_id = 0
annotations = []
for img_id, lf in enumerate(lfs):
for inst in lf:
ann = {}

pts = inst.numpy()
vis = np.isnan(pts).any(axis=1, keepdims=True).astype(int)
vis[vis == 0] = 2 # labeled and visible
# 1: labeled but not visible
vis[vis == 1] = 0 # not labeled
pts[np.isnan(pts)] = -1
kps = np.concatenate([pts, vis], axis=1).reshape(-1).tolist()
ann["keypoints"] = kps
ann["id"] = inst_id
ann["image_id"] = img_id
ann["num_keypoints"] = len(pts)

x, y = np.nanmin(pts, axis=0)
w, h = np.nanmax(pts, axis=0) - np.nanmin(pts, axis=0)
ann["bbox"] = [x, y, w, h]
ann["iscrowd"] = 0
ann["area"] = w * h
ann["category_id"] = labels.skeletons.index(inst.skeleton)

if inst.track is not None:
ann["track_id"] = labels.tracks.index(inst.track)

annotations.append(ann)
inst_id += 1

categories = []
for skel_ind, skel in enumerate(labels.skeletons):
category = {}
category["supercategory"] = "animal"
category["id"] = skel_ind
category["name"] = skel.name
category["keypoints"] = skel.node_names
category["skeleton"] = (np.array(skel.edge_inds) + 1).tolist()
categories.append(category)

ann = {
"info": labels.provenance.get("info", {}),
"images": imgs,
"annotations": annotations,
"categories": categories,
}

with open(ann_path, "w") as f:
json.dump(ann, f)
2 changes: 1 addition & 1 deletion sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load_file(
filename: Path to a file.
format: Optional format to load as. If not provided, will be inferred from the
file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs"
and "video".
"coco" and "video".
Returns:
A `Labels` or `Video` object.
Expand Down

0 comments on commit c9947b6

Please sign in to comment.