From c9947b680b0eba65e0ed145233285b44b57184b5 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 2 Jul 2024 21:12:55 -0400 Subject: [PATCH] Basic COCO writing implementation --- sleap_io/io/coco.py | 136 ++++++++++++++++++++++++++++++++++++++++---- sleap_io/io/main.py | 2 +- 2 files changed, 126 insertions(+), 12 deletions(-) diff --git a/sleap_io/io/coco.py b/sleap_io/io/coco.py index 349289b8..c9d8c596 100644 --- a/sleap_io/io/coco.py +++ b/sleap_io/io/coco.py @@ -8,19 +8,20 @@ from sleap_io import ( Video, Skeleton, - Edge, - Symmetry, - Node, Track, - SuggestionFrame, - Point, - PredictedPoint, Instance, - PredictedInstance, LabeledFrame, Labels, ) +import sys +import imageio.v3 as iio + +try: + import cv2 +except ImportError: + pass + def read_ann(ann_json_path: str | Path): """Read annotations JSON file. @@ -51,9 +52,9 @@ def make_skeleton(ann: dict) -> Skeleton: This assumes that `skeleton` (edge indices) are 1-based. """ return Skeleton( - nodes=ann["categories"]["keypoints"], - edges=(np.array(ann["categories"]["skeleton"]) - 1).tolist(), - name=ann["categories"].get("name", None), + nodes=ann["categories"][0]["keypoints"], + edges=(np.array(ann["categories"][0]["skeleton"]) - 1).tolist(), + name=ann["categories"][0].get("name", None), ) @@ -84,7 +85,7 @@ def make_videos( video_id_map = {} for img in ann["images"]: shape = img["height"], img["width"] - img_filename = img["filename"] + img_filename = img["file_name"] if imgs_prefix is not None: img_filename = (imgs_prefix / img_filename).as_posix() imgs_by_shape[shape].append(img_filename) @@ -175,3 +176,116 @@ def read_labels( skeleton = make_skeleton(ann) labels = make_labels(ann, videos, video_id_map, skeleton) return labels + + +def write_labels( + labels: Labels, + dataset_folder: str | Path, + split: str | None = None, + img_format: str = "png", +): + """Save a `Labels` to COCO format. + + Args: + labels: A `Labels` object. + dataset_folder: Path to a folder to save data to. + split: Optional string specifying the split name. + img_format: Format to save images to. Formats: "png" (default) or "jpg". + + Notes: + If `split` was not provided, the annotations will be saved to + `{dataset_folder}/annotations/ann.json` and images will be saved to + `{dataset_folder}/images`. + + If `split` was provided, the annotations will be saved to + `{dataset_folder}/annotations/ann_{split}.json` and images will be saved to + `{dataset_folder}/images/{split}`. + + Calling this multiple times with the same dataset folder may overwrite previous + data if `split` is not provided. + """ + if split is None: + ann_path = dataset_folder / "annotations" / "ann.json" + imgs_folder = dataset_folder / "images" + else: + ann_path = dataset_folder / "annotations" / f"ann_{split}.json" + imgs_folder = dataset_folder / "images" / split + + ann_path.parent.mkdir(parents=True, exist_ok=True) + imgs_folder.mkdir(parents=True, exist_ok=True) + + lfs = labels.user_labeled_frames + + imgs = [] + img_filename_map = {} + for img_id, lf in enumerate(lfs): + img_filename = f"{img_id}.{img_format}" + img_shape = video.shape[[1, 2]] + imgs.append( + { + "id": img_id, + "file_name": img_filename.as_posix(), + "height": img_shape[0], + "width": img_shape[1], + } + ) + img_filename_map[(lf.video, lf.frame_idx)] = img_filename + + for (video, frame_idx), img_filename in img_filename_map.items(): + img = video[frame_idx] + img_path = (imgs_folder / img_filename).as_posix() + if "cv2" in sys.modules: + cv2.imwrite(img_path, img) + else: + iio.imwrite(img_path, img) + + inst_id = 0 + annotations = [] + for img_id, lf in enumerate(lfs): + for inst in lf: + ann = {} + + pts = inst.numpy() + vis = np.isnan(pts).any(axis=1, keepdims=True).astype(int) + vis[vis == 0] = 2 # labeled and visible + # 1: labeled but not visible + vis[vis == 1] = 0 # not labeled + pts[np.isnan(pts)] = -1 + kps = np.concatenate([pts, vis], axis=1).reshape(-1).tolist() + ann["keypoints"] = kps + ann["id"] = inst_id + ann["image_id"] = img_id + ann["num_keypoints"] = len(pts) + + x, y = np.nanmin(pts, axis=0) + w, h = np.nanmax(pts, axis=0) - np.nanmin(pts, axis=0) + ann["bbox"] = [x, y, w, h] + ann["iscrowd"] = 0 + ann["area"] = w * h + ann["category_id"] = labels.skeletons.index(inst.skeleton) + + if inst.track is not None: + ann["track_id"] = labels.tracks.index(inst.track) + + annotations.append(ann) + inst_id += 1 + + categories = [] + for skel_ind, skel in enumerate(labels.skeletons): + category = {} + category["supercategory"] = "animal" + category["id"] = skel_ind + category["name"] = skel.name + category["keypoints"] = skel.node_names + category["skeleton"] = (np.array(skel.edge_inds) + 1).tolist() + categories.append(category) + + ann = { + "info": labels.provenance.get("info", {}), + "images": imgs, + "annotations": annotations, + "categories": categories, + } + + with open(ann_path, "w") as f: + json.dump(ann, f) diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py index 0e558f82..38044ce7 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -192,7 +192,7 @@ def load_file( filename: Path to a file. format: Optional format to load as. If not provided, will be inferred from the file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs" - and "video". + "coco" and "video". Returns: A `Labels` or `Video` object.