From 9bae8b3fb3d4a5d9de520ccdfa15e6e86d441913 Mon Sep 17 00:00:00 2001 From: Keya Loding Date: Mon, 8 Jul 2024 11:27:25 -0700 Subject: [PATCH] new function --- io_test.ipynb | 80 ++++++++++++++++++++++++--------------------- sleap_io/io/main.py | 2 +- sleap_io/io/nwb.py | 77 ++++++++++++++++--------------------------- 3 files changed, 72 insertions(+), 87 deletions(-) diff --git a/io_test.ipynb b/io_test.ipynb index 2599dbb1..3d1fef34 100644 --- a/io_test.ipynb +++ b/io_test.ipynb @@ -2,14 +2,27 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", "import sleap_io as sio\n", "import pynwb\n", + "import numpy as np\n", "from sleap_io import save_file\n", "from sleap_io.model.skeleton import Node, Edge, Symmetry, Skeleton\n", + "from pynwb.image import ImageSeries\n", "from ndx_pose import (\n", " PoseEstimation,\n", " PoseEstimationSeries,\n", @@ -18,58 +31,51 @@ " PoseTraining,\n", " SourceVideos,\n", ")\n", - "from sleap_io.io.nwb import convert_slp_skeleton_to_nwb" + "from sleap_io.io.nwb import slp_skeleton_to_nwb" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[]\n" - ] - }, - { - "ename": "ValueError", - "evalue": "CustomClassGenerator.set_init..__init__: incorrect shape for 'edges' (got '(0,)', expected '[None, 2]')", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m skel \u001b[38;5;241m=\u001b[39m Skeleton([Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mA\u001b[39m\u001b[38;5;124m\"\u001b[39m), Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m)], name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m new_skel \u001b[38;5;241m=\u001b[39m \u001b[43mconvert_slp_skeleton_to_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mskel\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:149\u001b[0m, in \u001b[0;36mconvert_slp_skeleton_to_nwb\u001b[0;34m(skeleton)\u001b[0m\n\u001b[1;32m 147\u001b[0m nwb_edges\u001b[38;5;241m.\u001b[39mappend([i, i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28mprint\u001b[39m(nwb_edges)\n\u001b[0;32m--> 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mNWBSkeleton\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskeleton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43mnodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskeleton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_names\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43medges\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnwb_edges\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muint8\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:667\u001b[0m, in \u001b[0;36mdocval..dec..func_call\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc_call\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 667\u001b[0m pargs \u001b[38;5;241m=\u001b[39m \u001b[43m_check_args\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 668\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(args[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpargs)\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:660\u001b[0m, in \u001b[0;36mdocval..dec.._check_args\u001b[0;34m(args, kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m parse_err:\n\u001b[1;32m 659\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m%\u001b[39m (func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(parse_err))\n\u001b[0;32m--> 660\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ExceptionType(msg)\n\u001b[1;32m 662\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parsed[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", - "\u001b[0;31mValueError\u001b[0m: CustomClassGenerator.set_init..__init__: incorrect shape for 'edges' (got '(0,)', expected '[None, 2]')" - ] - } - ], + "outputs": [], "source": [ - "skel = Skeleton([Node(\"A\"), Node(\"B\")], name=\"test\")\n", - "new_skel = convert_slp_skeleton_to_nwb(skel)" + "node_a = Node(\"A\")\n", + "node_b = Node(\"B\")\n", + "edge = Edge(node_a, node_b)\n", + "skel = Skeleton([Node(\"A\"), Node(\"B\")], [edge], name=\"test\")\n", + "new_skel = slp_skeleton_to_nwb(skel)\n", + "assert new_skel.nodes[0] == \"A\"\n", + "assert new_skel.nodes[1] == \"B\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'NoneType' object has no attribute 'name'", + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/keya/salk/io_fork/sleap_io/io/nwb.py:115: FutureWarning: CustomClassGenerator.set_init..__init__: Using positional arguments for this method is discouraged and will be deprecated in a future major release. Please use keyword arguments to ensure future compatibility.\n", + " training_frame_skeleton_instances = SkeletonInstances(\n" + ] + }, + { + "ename": "TypeError", + "evalue": "CustomClassGenerator.set_init..__init__: incorrect type for 'name' (got 'list', expected 'str')", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m labels_original \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_slp(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.slp\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mlabels_original\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mminimal_instance.pkg.nwb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnwb_training\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m labels_loaded \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_nwb(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminimal_instance.pkg.nwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m labels_original\u001b[38;5;241m.\u001b[39mlabeled_frames \u001b[38;5;241m==\u001b[39m labels_loaded\u001b[38;5;241m.\u001b[39mlabeled_frames\n", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m labels_original \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_slp(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.slp\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mlabels_original\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mminimal_instance.pkg.nwb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnwb_training\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m labels_loaded \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_nwb(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminimal_instance.pkg.nwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m labels_original\u001b[38;5;241m.\u001b[39mlabeled_frames \u001b[38;5;241m==\u001b[39m labels_loaded\u001b[38;5;241m.\u001b[39mlabeled_frames\n", "File \u001b[0;32m~/salk/io_fork/sleap_io/model/labels.py:372\u001b[0m, in \u001b[0;36mLabels.save\u001b[0;34m(self, filename, format, embed, **kwargs)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Save labels to file in specified format.\u001b[39;00m\n\u001b[1;32m 349\u001b[0m \n\u001b[1;32m 350\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;124;03m This argument is only valid for the SLP backend.\u001b[39;00m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msleap_io\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m save_file\n\u001b[0;32m--> 372\u001b[0m \u001b[43msave_file\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:232\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(labels, filename, format, **kwargs)\u001b[0m\n\u001b[1;32m 230\u001b[0m save_jabs(labels, pose_version, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_training\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 232\u001b[0m \u001b[43mnwb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvert_labels_to_pose_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown format \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mformat\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for filename: \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:115\u001b[0m, in \u001b[0;36mconvert_labels_to_pose_training\u001b[0;34m(labels, filename, **kwargs)\u001b[0m\n\u001b[1;32m 110\u001b[0m training_frame_name \u001b[38;5;241m=\u001b[39m name_generator(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_frame\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 111\u001b[0m training_frame_annotator \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtraining_frame_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 112\u001b[0m training_frame_skeleton_instances \u001b[38;5;241m=\u001b[39m SkeletonInstances(\n\u001b[1;32m 113\u001b[0m [\n\u001b[1;32m 114\u001b[0m SkeletonInstance(\n\u001b[0;32m--> 115\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[43minstance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m,\n\u001b[1;32m 116\u001b[0m skeleton\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mskeleton,\n\u001b[1;32m 117\u001b[0m points\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mpoints,\n\u001b[1;32m 118\u001b[0m confidence\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mpoint_scores,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m instance \u001b[38;5;129;01min\u001b[39;00m labeled_frame\u001b[38;5;241m.\u001b[39minstances\n\u001b[1;32m 121\u001b[0m ]\n\u001b[1;32m 122\u001b[0m )\n\u001b[1;32m 123\u001b[0m training_frame_video \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mvideo\n\u001b[1;32m 124\u001b[0m training_frame_video_index \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mframe_idx\n", - "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'name'" + "File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:232\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(labels, filename, format, **kwargs)\u001b[0m\n\u001b[1;32m 230\u001b[0m save_jabs(labels, pose_version, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_training\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 232\u001b[0m \u001b[43mnwb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlabels_to_pose_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown format \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mformat\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for filename: \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:115\u001b[0m, in \u001b[0;36mlabels_to_pose_training\u001b[0;34m(labels, filename, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m training_frame_name \u001b[38;5;241m=\u001b[39m name_generator(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_frame\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 114\u001b[0m training_frame_annotator \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtraining_frame_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 115\u001b[0m training_frame_skeleton_instances \u001b[38;5;241m=\u001b[39m \u001b[43mSkeletonInstances\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 116\u001b[0m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 117\u001b[0m \u001b[43m \u001b[49m\u001b[43minstance_to_skeleton_instance\u001b[49m\u001b[43m(\u001b[49m\u001b[43minstance\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43minstance\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mlabeled_frame\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minstances\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 120\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 121\u001b[0m training_frame_video \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mvideo\n\u001b[1;32m 122\u001b[0m training_frame_video_index \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mframe_idx\n", + "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:667\u001b[0m, in \u001b[0;36mdocval..dec..func_call\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc_call\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 667\u001b[0m pargs \u001b[38;5;241m=\u001b[39m \u001b[43m_check_args\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 668\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(args[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpargs)\n", + "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:660\u001b[0m, in \u001b[0;36mdocval..dec.._check_args\u001b[0;34m(args, kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m parse_err:\n\u001b[1;32m 659\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m%\u001b[39m (func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(parse_err))\n\u001b[0;32m--> 660\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ExceptionType(msg)\n\u001b[1;32m 662\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parsed[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", + "\u001b[0;31mTypeError\u001b[0m: CustomClassGenerator.set_init..__init__: incorrect type for 'name' (got 'list', expected 'str')" ] } ], diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py index 1a4570ab..ccf5e220 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -229,6 +229,6 @@ def save_file( pose_version = kwargs.pop("pose_version", 5) save_jabs(labels, pose_version, filename, **kwargs) elif format == "nwb_training": - nwb.convert_labels_to_pose_training(labels, filename, **kwargs) + nwb.labels_to_pose_training(labels, filename, **kwargs) else: raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.") \ No newline at end of file diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py index 81abd23c..665a065f 100644 --- a/sleap_io/io/nwb.py +++ b/sleap_io/io/nwb.py @@ -43,46 +43,7 @@ from sleap_io.io.utils import convert_predictions_to_dataframe -# def convert_nwb_to_slp(nwb_data_structure): -# """Converts an NWB object to its object SLEAP instance.""" - -# def convert_frame(frame: TrainingFrame) -> LabeledFrame: # type: ignore[return] -# """ -# Converts an NWB TrainingFrame instance to a LabeledFrame instance. -# """ -# return LabeledFrame( -# video=Video(filename=frame.source_video.data), -# frame_idx=frame.frame_number.data, -# instances=[ -# PredictedInstance.from_numpy( -# points=frame.points.data, -# point_scores=frame.confidence.data, -# instance_score=frame.confidence.data.mean(), -# skeleton=Skeleton( -# nodes=frame.skeleton.nodes.data, -# edges=frame.skeleton.edges.data, -# ), -# ) -# ], -# ) - -# if isinstance(nwb_data_structure, TrainingFrame): -# return convert_frame(nwb_data_structure) -# elif isinstance(nwb_data_structure, TrainingFrames): -# return [convert_frame(frame) for frame in nwb_data_structure.training_frames] -# elif isinstance(nwb_data_structure, PoseTraining): -# return Labels( -# [convert_frame(frame) for frame in nwb_data_structure.training_frames] -# ) -# elif isinstance(nwb_data_structure, SourceVideos): -# return Video(filename=nwb_data_structure.data) -# else: -# raise ValueError( -# f"Cannot convert {type(nwb_data_structure)} to SLEAP instance." -# ) - - -def convert_pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return] +def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return] """Creates a Labels object from an NWB PoseTraining object.""" labeled_frames = [] for training_frame in pose_training.training_frames: @@ -106,7 +67,7 @@ def convert_pose_training_to_labels(pose_training: PoseTraining) -> Labels: # t return Labels(labeled_frames) -def convert_labels_to_pose_training(labels: Labels, filename: str, **kwargs) -> PoseTraining: # type: ignore[return] +def labels_to_pose_training(labels: Labels, filename: str, **kwargs) -> PoseTraining: # type: ignore[return] """Creates an NWB PoseTraining object from a Labels object.""" training_frame_list = [] for i, labeled_frame in enumerate(labels.labeled_frames): @@ -114,7 +75,7 @@ def convert_labels_to_pose_training(labels: Labels, filename: str, **kwargs) -> training_frame_annotator = f"{training_frame_name}{i}" training_frame_skeleton_instances = SkeletonInstances( [ - convert_instance_to_skeleton_instance(instance) + instance_to_skeleton_instance(instance) for instance in labeled_frame.instances ] ) @@ -136,7 +97,7 @@ def convert_labels_to_pose_training(labels: Labels, filename: str, **kwargs) -> ) return pose_training -def convert_slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton: # type: ignore[return] +def slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton: # type: ignore[return] """Converts SLEAP skeleton to NWB skeleton.""" nwb_edges: list[list[int, int]] = [] for i, _ in enumerate(skeleton.edges): @@ -149,19 +110,37 @@ def convert_slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton: # type: edges=np.array(nwb_edges, dtype=np.uint8), ) -def convert_instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance: # type: ignore[return] - id = np.uint(10) - skeleton = convert_slp_skeleton_to_nwb(instance.skeleton) + +def instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance: # type: ignore[return] + skeleton = slp_skeleton_to_nwb(instance.skeleton) node_locations = skeleton.edges node_visibility = [True, False] return SkeletonInstance( - id=id, + id=np.uint(10), node_locations=node_locations, node_visibility=node_visibility, skeleton=skeleton, ) +def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignore[return] + """Converts a list of SLEAP Videos to NWB SourceVideos.""" + source_videos = [] + for video in videos: + image_series = ImageSeries( + name=video.filename, + description="Video file", + unit="NA", + format="external", + external_file=[video.filename], + dimension=[video.backend.height, video.backend.width], + starting_frame=[0], + rate=30.0, + ) + source_videos.append(image_series) + return SourceVideos(data=source_videos) + + def get_timestamps(series: PoseEstimationSeries) -> np.ndarray: """Return a vector of timestamps for a `PoseEstimationSeries`.""" if series.timestamps is not None: @@ -307,7 +286,7 @@ def write_nwb( or the sampling rate with key`video_sample_rate`. e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps) - or pose_estimation_metadata["video_sample_rate] = 15 # In Hz + or pose_estimation_metadata["video_sample_rate"] = 15 # In Hz 2) The other use of this dictionary is to ovewrite sleap-io default arguments for the PoseEstimation container. @@ -533,7 +512,7 @@ def build_pose_estimation_container_for_track( def build_track_pose_estimation_list( - track_data_df: pd.DataFrame, timestamps: ArrayLike + track_data_df: pd.DataFrame, timestamps: ArrayLike # type: ignore[return] ) -> List[PoseEstimationSeries]: """Build a list of PoseEstimationSeries from tracks.