From 8582f51a3447b1b17b27b08912eb4c797aa99503 Mon Sep 17 00:00:00 2001 From: Keya Loding Date: Mon, 8 Jul 2024 10:23:17 -0700 Subject: [PATCH] new functions --- io_test.ipynb | 78 ++++++++++++++++++++++++++-------------------- sleap_io/io/nwb.py | 34 +++++++++++--------- 2 files changed, 65 insertions(+), 47 deletions(-) diff --git a/io_test.ipynb b/io_test.ipynb index 5d8adb36..2599dbb1 100644 --- a/io_test.ipynb +++ b/io_test.ipynb @@ -2,13 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import sleap_io as sio\n", "import pynwb\n", "from sleap_io import save_file\n", + "from sleap_io.model.skeleton import Node, Edge, Symmetry, Skeleton\n", "from ndx_pose import (\n", " PoseEstimation,\n", " PoseEstimationSeries,\n", @@ -17,31 +18,58 @@ " PoseTraining,\n", " SourceVideos,\n", ")\n", - "from sleap_io.io.nwb import (\n", - " convert_labels_to_pose_training,\n", - " convert_pose_training_to_labels,\n", - ")" + "from sleap_io.io.nwb import convert_slp_skeleton_to_nwb" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [ { - "ename": "TypeError", - "evalue": "'NoneType' object is not subscriptable", + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + }, + { + "ename": "ValueError", + "evalue": "CustomClassGenerator.set_init..__init__: incorrect shape for 'edges' (got '(0,)', expected '[None, 2]')", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[9], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m labels_original \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_slp(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.slp\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mlabels_original\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mminimal_instance.pkg.nwb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnwb_training\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m labels_loaded \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_nwb(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminimal_instance.pkg.nwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m labels_original\u001b[38;5;241m.\u001b[39mlabeled_frames \u001b[38;5;241m==\u001b[39m labels_loaded\u001b[38;5;241m.\u001b[39mlabeled_frames\n", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m skel \u001b[38;5;241m=\u001b[39m Skeleton([Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mA\u001b[39m\u001b[38;5;124m\"\u001b[39m), Node(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mB\u001b[39m\u001b[38;5;124m\"\u001b[39m)], name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m new_skel \u001b[38;5;241m=\u001b[39m \u001b[43mconvert_slp_skeleton_to_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mskel\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:149\u001b[0m, in \u001b[0;36mconvert_slp_skeleton_to_nwb\u001b[0;34m(skeleton)\u001b[0m\n\u001b[1;32m 147\u001b[0m nwb_edges\u001b[38;5;241m.\u001b[39mappend([i, i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28mprint\u001b[39m(nwb_edges)\n\u001b[0;32m--> 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mNWBSkeleton\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 150\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskeleton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43mnodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskeleton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnode_names\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43medges\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnwb_edges\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43muint8\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:667\u001b[0m, in \u001b[0;36mdocval..dec..func_call\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc_call\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 667\u001b[0m pargs \u001b[38;5;241m=\u001b[39m \u001b[43m_check_args\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 668\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(args[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpargs)\n", + "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/hdmf/utils.py:660\u001b[0m, in \u001b[0;36mdocval..dec.._check_args\u001b[0;34m(args, kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m parse_err:\n\u001b[1;32m 659\u001b[0m msg \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m%\u001b[39m (func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(parse_err))\n\u001b[0;32m--> 660\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ExceptionType(msg)\n\u001b[1;32m 662\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m parsed[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", + "\u001b[0;31mValueError\u001b[0m: CustomClassGenerator.set_init..__init__: incorrect shape for 'edges' (got '(0,)', expected '[None, 2]')" + ] + } + ], + "source": [ + "skel = Skeleton([Node(\"A\"), Node(\"B\")], name=\"test\")\n", + "new_skel = convert_slp_skeleton_to_nwb(skel)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'name'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m labels_original \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_slp(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.slp\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mlabels_original\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mminimal_instance.pkg.nwb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnwb_training\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m labels_loaded \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_nwb(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminimal_instance.pkg.nwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m labels_original\u001b[38;5;241m.\u001b[39mlabeled_frames \u001b[38;5;241m==\u001b[39m labels_loaded\u001b[38;5;241m.\u001b[39mlabeled_frames\n", "File \u001b[0;32m~/salk/io_fork/sleap_io/model/labels.py:372\u001b[0m, in \u001b[0;36mLabels.save\u001b[0;34m(self, filename, format, embed, **kwargs)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Save labels to file in specified format.\u001b[39;00m\n\u001b[1;32m 349\u001b[0m \n\u001b[1;32m 350\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;124;03m This argument is only valid for the SLP backend.\u001b[39;00m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msleap_io\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m save_file\n\u001b[0;32m--> 372\u001b[0m \u001b[43msave_file\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:225\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(labels, filename, format, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m save_slp(labels, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 224\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_training\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 225\u001b[0m \u001b[43msave_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabelstudio\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 227\u001b[0m save_labelstudio(labels, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:76\u001b[0m, in \u001b[0;36msave_nwb\u001b[0;34m(labels, filename, append, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m nwb\u001b[38;5;241m.\u001b[39mappend_nwb(labels, filename)\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 76\u001b[0m \u001b[43mnwb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:300\u001b[0m, in \u001b[0;36mwrite_nwb\u001b[0;34m(labels, nwbfile_path, nwb_file_kwargs, pose_estimation_metadata)\u001b[0m\n\u001b[1;32m 293\u001b[0m nwb_file_kwargs\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 294\u001b[0m session_description\u001b[38;5;241m=\u001b[39msession_description,\n\u001b[1;32m 295\u001b[0m session_start_time\u001b[38;5;241m=\u001b[39msession_start_time,\n\u001b[1;32m 296\u001b[0m identifier\u001b[38;5;241m=\u001b[39midentifier,\n\u001b[1;32m 297\u001b[0m )\n\u001b[1;32m 299\u001b[0m nwbfile \u001b[38;5;241m=\u001b[39m NWBFile(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnwb_file_kwargs)\n\u001b[0;32m--> 300\u001b[0m nwbfile \u001b[38;5;241m=\u001b[39m \u001b[43mappend_nwb_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnwbfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpose_estimation_metadata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m NWBHDF5IO(\u001b[38;5;28mstr\u001b[39m(nwbfile_path), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m io:\n\u001b[1;32m 303\u001b[0m io\u001b[38;5;241m.\u001b[39mwrite(nwbfile)\n", - "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:359\u001b[0m, in \u001b[0;36mappend_nwb_data\u001b[0;34m(labels, nwbfile, pose_estimation_metadata)\u001b[0m\n\u001b[1;32m 355\u001b[0m default_metadata\u001b[38;5;241m.\u001b[39mupdate(pose_estimation_metadata)\n\u001b[1;32m 357\u001b[0m \u001b[38;5;66;03m# For every track in that video create a PoseEstimation container\u001b[39;00m\n\u001b[1;32m 358\u001b[0m name_of_tracks_in_video \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 359\u001b[0m \u001b[43mlabels_data_df\u001b[49m\u001b[43m[\u001b[49m\u001b[43mvideo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 360\u001b[0m \u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mget_level_values(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrack_name\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 361\u001b[0m \u001b[38;5;241m.\u001b[39munique()\n\u001b[1;32m 362\u001b[0m )\n\u001b[1;32m 364\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m track_name \u001b[38;5;129;01min\u001b[39;00m name_of_tracks_in_video:\n\u001b[1;32m 365\u001b[0m pose_estimation_container \u001b[38;5;241m=\u001b[39m build_pose_estimation_container_for_track(\n\u001b[1;32m 366\u001b[0m labels_data_df,\n\u001b[1;32m 367\u001b[0m labels,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 370\u001b[0m default_metadata,\n\u001b[1;32m 371\u001b[0m )\n", - "\u001b[0;31mTypeError\u001b[0m: 'NoneType' object is not subscriptable" + "File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:232\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(labels, filename, format, **kwargs)\u001b[0m\n\u001b[1;32m 230\u001b[0m save_jabs(labels, pose_version, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 231\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_training\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 232\u001b[0m \u001b[43mnwb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvert_labels_to_pose_training\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown format \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mformat\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m for filename: \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:115\u001b[0m, in \u001b[0;36mconvert_labels_to_pose_training\u001b[0;34m(labels, filename, **kwargs)\u001b[0m\n\u001b[1;32m 110\u001b[0m training_frame_name \u001b[38;5;241m=\u001b[39m name_generator(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_frame\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 111\u001b[0m training_frame_annotator \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtraining_frame_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 112\u001b[0m training_frame_skeleton_instances \u001b[38;5;241m=\u001b[39m SkeletonInstances(\n\u001b[1;32m 113\u001b[0m [\n\u001b[1;32m 114\u001b[0m SkeletonInstance(\n\u001b[0;32m--> 115\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[43minstance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m,\n\u001b[1;32m 116\u001b[0m skeleton\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mskeleton,\n\u001b[1;32m 117\u001b[0m points\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mpoints,\n\u001b[1;32m 118\u001b[0m confidence\u001b[38;5;241m=\u001b[39minstance\u001b[38;5;241m.\u001b[39mpoint_scores,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m instance \u001b[38;5;129;01min\u001b[39;00m labeled_frame\u001b[38;5;241m.\u001b[39minstances\n\u001b[1;32m 121\u001b[0m ]\n\u001b[1;32m 122\u001b[0m )\n\u001b[1;32m 123\u001b[0m training_frame_video \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mvideo\n\u001b[1;32m 124\u001b[0m training_frame_video_index \u001b[38;5;241m=\u001b[39m labeled_frame\u001b[38;5;241m.\u001b[39mframe_idx\n", + "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'name'" ] } ], @@ -59,25 +87,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "OSError", - "evalue": "Cannot understand given URI: array([[[0],\n [0],\n [0],\n ...,\n ....", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[8], line 7\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, lf \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(labels_original):\n\u001b[1;32m 6\u001b[0m img_path \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlabels_original\u001b[38;5;241m.\u001b[39mlabeled_frames[i]\u001b[38;5;241m.\u001b[39mvideo\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlf\u001b[38;5;241m.\u001b[39mframe_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.png\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 7\u001b[0m \u001b[43miio\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimg_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 8\u001b[0m img_paths\u001b[38;5;241m.\u001b[39mappend(img_path)\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/v3.py:139\u001b[0m, in \u001b[0;36mimwrite\u001b[0;34m(uri, image, plugin, extension, format_hint, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mimwrite\u001b[39m(uri, image, \u001b[38;5;241m*\u001b[39m, plugin\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, extension\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, format_hint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 105\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Write an ndimage to the given URI.\u001b[39;00m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m The exact behavior depends on the file type and plugin used. To learn about\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 136\u001b[0m \n\u001b[1;32m 137\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mimopen\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 140\u001b[0m \u001b[43m \u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mw\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[43m \u001b[49m\u001b[43mlegacy_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[43m \u001b[49m\u001b[43mplugin\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mplugin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 144\u001b[0m \u001b[43m \u001b[49m\u001b[43mformat_hint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformat_hint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[43m \u001b[49m\u001b[43mextension\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextension\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 146\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m img_file:\n\u001b[1;32m 147\u001b[0m encoded \u001b[38;5;241m=\u001b[39m img_file\u001b[38;5;241m.\u001b[39mwrite(image, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m encoded\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/core/imopen.py:113\u001b[0m, in \u001b[0;36mimopen\u001b[0;34m(uri, io_mode, plugin, extension, format_hint, legacy_mode, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m request\u001b[38;5;241m.\u001b[39mformat_hint \u001b[38;5;241m=\u001b[39m format_hint\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m request \u001b[38;5;241m=\u001b[39m \u001b[43mRequest\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mio_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mformat_hint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mformat_hint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mextension\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextension\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 115\u001b[0m source \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(uri, \u001b[38;5;28mbytes\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m uri\n\u001b[1;32m 117\u001b[0m \u001b[38;5;66;03m# fast-path based on plugin\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;66;03m# (except in legacy mode)\u001b[39;00m\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/core/request.py:247\u001b[0m, in \u001b[0;36mRequest.__init__\u001b[0;34m(self, uri, mode, extension, format_hint, **kwargs)\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid Request.Mode: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmode\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 246\u001b[0m \u001b[38;5;66;03m# Parse what was given\u001b[39;00m\n\u001b[0;32m--> 247\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_parse_uri\u001b[49m\u001b[43m(\u001b[49m\u001b[43muri\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# Set extension\u001b[39;00m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m extension \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/imageio/core/request.py:369\u001b[0m, in \u001b[0;36mRequest._parse_uri\u001b[0;34m(self, uri)\u001b[0m\n\u001b[1;32m 367\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(uri_r) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m60\u001b[39m:\n\u001b[1;32m 368\u001b[0m uri_r \u001b[38;5;241m=\u001b[39m uri_r[:\u001b[38;5;241m57\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m...\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 369\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIOError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot understand given URI: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m uri_r)\n\u001b[1;32m 371\u001b[0m \u001b[38;5;66;03m# Check if this is supported\u001b[39;00m\n\u001b[1;32m 372\u001b[0m noWriting \u001b[38;5;241m=\u001b[39m [URI_HTTP, URI_FTP]\n", - "\u001b[0;31mOSError\u001b[0m: Cannot understand given URI: array([[[0],\n [0],\n [0],\n ...,\n ...." - ] - } - ], + "outputs": [], "source": [ "import imageio.v3 as iio\n", "from pathlib import Path\n", diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py index 7ac7123c..81abd23c 100644 --- a/sleap_io/io/nwb.py +++ b/sleap_io/io/nwb.py @@ -14,7 +14,11 @@ from numpy.typing import ArrayLike except ImportError: ArrayLike = np.ndarray -from pynwb import NWBFile, NWBHDF5IO, ProcessingModule # type: ignore[import] + +from pynwb import NWBFile, NWBHDF5IO, ProcessingModule # type: ignore[import] +from pynwb.image import ImageSeries +from pynwb.testing.mock.utils import name_generator + from ndx_pose import ( PoseEstimationSeries, PoseEstimation, @@ -25,7 +29,7 @@ TrainingFrames, PoseTraining, SourceVideos, -) # type: ignore[import] +) from sleap_io import ( Labels, @@ -37,13 +41,12 @@ PredictedInstance, ) from sleap_io.io.utils import convert_predictions_to_dataframe -from pynwb.testing.mock.utils import name_generator -# def convert_nwb(nwb_data_structure): +# def convert_nwb_to_slp(nwb_data_structure): # """Converts an NWB object to its object SLEAP instance.""" -# def convert_frame(frame: TrainingFrame) -> LabeledFrame: +# def convert_frame(frame: TrainingFrame) -> LabeledFrame: # type: ignore[return] # """ # Converts an NWB TrainingFrame instance to a LabeledFrame instance. # """ @@ -111,12 +114,7 @@ def convert_labels_to_pose_training(labels: Labels, filename: str, **kwargs) -> training_frame_annotator = f"{training_frame_name}{i}" training_frame_skeleton_instances = SkeletonInstances( [ - SkeletonInstance( - name=instance.track.name, - skeleton=instance.skeleton, - points=instance.points, - confidence=instance.point_scores, - ) + convert_instance_to_skeleton_instance(instance) for instance in labeled_frame.instances ] ) @@ -145,15 +143,23 @@ def convert_slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton: # type: if i == len(skeleton.edges): break nwb_edges.append([i, i + 1]) - print(nwb_edges) return NWBSkeleton( name=skeleton.name, nodes=skeleton.node_names, edges=np.array(nwb_edges, dtype=np.uint8), ) -def convert_nwb(): - raise +def convert_instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance: # type: ignore[return] + id = np.uint(10) + skeleton = convert_slp_skeleton_to_nwb(instance.skeleton) + node_locations = skeleton.edges + node_visibility = [True, False] + return SkeletonInstance( + id=id, + node_locations=node_locations, + node_visibility=node_visibility, + skeleton=skeleton, + ) def get_timestamps(series: PoseEstimationSeries) -> np.ndarray: