diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 062cf0f76..ed9c377d2 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -35,8 +35,9 @@ class which inherits from `AppCommand` (or a more specialized class such as import traceback from enum import Enum from glob import glob +from itertools import permutations, product from pathlib import Path, PurePath -from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast import attr import cv2 @@ -53,7 +54,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.gui.state import GuiState from sleap.gui.suggestions import VideoFrameSuggestions from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track -from sleap.io.cameras import Camcorder, RecordingSession +from sleap.io.cameras import Camcorder, InstanceGroup, FrameGroup, RecordingSession from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor @@ -1947,7 +1948,6 @@ class AddSession(EditCommand): @staticmethod def do_action(context: CommandContext, params: dict): - camera_calibration = params["camera_calibration"] session = RecordingSession.load(filename=camera_calibration) @@ -3406,409 +3406,114 @@ def do_action(cls, context: CommandContext, params: dict): ask_again: If True, then ask for views/instances again. Default is False. """ - # Check if we already ran ask - ask_again = params.get("ask_again", False) - - # Add "instances" to params dict without GUI, otherwise taken care of in ask - if ask_again: - params["show_dialog"] = False - enough_instances = cls.verify_views_and_instances( - context=context, params=params - ) - if not enough_instances: - return - - # Get params + # Get `FrameGroup` for the current frame index video = params.get("video", None) or context.state["video"] session = params.get("session", None) or context.labels.get_session(video) - instances = params["instances"] - - # Update instances - TriangulateSession.update_instances(session=session, instances=instances) - - @classmethod - def ask(cls, context: CommandContext, params: dict): - """Add "instances" to params dict if enough views/instances, warning user otherwise. - - Args: - context: The command context. - params: The command parameters. - video: The `Video` object to use. Default is current video. - session: The `RecordingSession` object to use. Default is current - video's session. - frame_idx: The frame index to use. Default is current frame index. - instance: The `Instance` object to use. Default is current instance. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - True if enough views/instances for triangulation, False otherwise. - """ - - return cls.verify_views_and_instances(context=context, params=params) - - @classmethod - def verify_views_and_instances(cls, context: CommandContext, params: dict): - """Verify that there are enough views and instances to triangulate. - - Also adds "instances" to params dict if there are enough views and instances. - - Args: - context: The command context. - params: The command parameters. - video: The `Video` object used to lookup a `session` (if not provided). - Default is current video. - session: The `RecordingSession` object to use. Default is current - video's session. - frame_idx: The frame index to use. Default is current frame index. - instance: The `Instance` object to use. Default is current instance. - show_dialog: If True, then show a warning dialog. Default is True. + frame_idx: int = params["frame_idx"] + frame_group: FrameGroup = ( + params.get("frame_group", None) or session.frame_groups[frame_idx] + ) - Returns: - True if enough views/instances for triangulation, False otherwise. - """ - video = params.get("video", None) or context.state["video"] - session = params.get("session", None) or context.labels.get_session(video) + # Get the `InstanceGroup` from `Instance` if any instance = params.get("instance", None) or context.state["instance"] - show_dialog = params.get("show_dialog", True) - - # This value could possibly be 0, so we can't use "or" - frame_idx = params.get("frame_idx", None) - frame_idx = frame_idx if frame_idx is not None else context.state["frame_idx"] + instance_group = frame_group.get_instance_group(instance) - # Return if we don't have a session for video or an instance selected. - if session is None or instance is None: - return - - track = instance.track - cams_to_include = params.get("cams_to_include", None) or session.linked_cameras - - # If not enough `Camcorder`s available/specified, then return - if not TriangulateSession.verify_enough_views( - context=context, - session=session, - cams_to_include=cams_to_include, - show_dialog=show_dialog, - ): - return False + # If instance_group is None, then we will try to triangulate entire frame_group + instance_groups = ( + [instance_group] + if instance_group is not None + else frame_group.instance_groups + ) - # Get all instances accross views at this frame index - instances = TriangulateSession.get_and_verify_enough_instances( - context=context, - session=session, + # Retain instance groups that have enough views/instances for triangulation + instance_groups = TriangulateSession.has_enough_instances( + frame_group=frame_group, + instance_groups=instance_groups, frame_idx=frame_idx, - cams_to_include=cams_to_include, - track=track, - show_dialog=show_dialog, + instance=instance, ) + if instance_groups is None or len(instance_groups) == 0: + return # Not enough instances for triangulation - # Return if not enough instances - if not instances: - return False + # Get the `FrameGroup` of shape M=include x T x N x 2 + fg_tensor = frame_group.numpy(instance_groups=instance_groups, pred_as_nan=True) - # Add instances to params dict - params["instances"] = instances + # Add extra dimension for number of frames + frame_group_tensor = np.expand_dims(fg_tensor, axis=1) # M=include x F=1 xTxNx2 - return True + # Triangulate to one 3D pose per instance + points_3d = triangulate( + p2d=frame_group_tensor, + calib=session.camera_cluster, + excluded_views=frame_group.excluded_views, + ) # F x T x N x 3 - @staticmethod - def get_and_verify_enough_instances( - session: RecordingSession, - frame_idx: int, - context: Optional[CommandContext] = None, - cams_to_include: Optional[List[Camcorder]] = None, - track: Optional[Track] = None, - show_dialog: bool = True, - ) -> Union[Dict[Camcorder, Instance], bool]: - """Get all instances accross views at this frame index. + # Reproject onto all views + pts_reprojected = reproject( + points_3d, + calib=session.camera_cluster, + excluded_views=frame_group.excluded_views, + ) # M=include x F=1 x T x N x 2 - If not enough `Instance`s are available at this frame index, then return False. + # Sqeeze back to the original shape + points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2 - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - context: The optional command context used to display a dialog. - cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is None. - show_dialog: If True, then show a warning dialog. Default is True. - - Returns: - Dict with `Camcorder` keys and `Instances` values (or False if not enough - instances at this frame index). - """ - try: - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, - ) - return instances - except Exception as e: - # If not enough views, instances or some other error, then return - message = str(e) - message += "\n\tSkipping triangulation and reprojection." - logger.warning(message) - return False + # Update or create/insert ("upsert") instance points + frame_group.upsert_points( + points=points_reprojected, + instance_groups=instance_groups, + exclude_complete=True, + ) - @staticmethod - def verify_enough_views( - session: RecordingSession, - context: Optional[CommandContext] = None, - cams_to_include: Optional[List[Camcorder]] = None, - show_dialog=True, - ): - """If not enough `Camcorder`s available/specified, then return False. + @classmethod + def has_enough_instances( + cls, + frame_group: FrameGroup, + instance_groups: Optional[List[InstanceGroup]], + frame_idx: Optional[int] = None, + instance: Optional[Instance] = None, + ) -> Optional[List[InstanceGroup]]: + """Filters out instance groups without enough instances for triangulation. Args: - session: The `RecordingSession` containing the `Camcorder`s. - context: The optional command context, used to display a dialog. - cams_to_include: List of `Camcorder`s to include. Default is all. - show_dialog: If True, then show a warning dialog. Default is True. + frame_group: The `FrameGroup` object to use. + instance_groups: A list of `InstanceGroup` objects to use. + frame_idx: The frame index to use. + instance: The `Instance` object to use (only used in logging). Returns: - True if enough views are available, False otherwise. + A list of `InstanceGroup` objects with enough instances for triangulation. """ - if (cams_to_include is not None and len(cams_to_include) <= 1) or ( - len(session.videos) <= 1 - ): - message = ( - "One or less cameras available. " - "Multiple cameras needed to triangulate. " - "Skipping triangulation and reprojection." - ) - if show_dialog and context is not None: - QtWidgets.QMessageBox.warning(context.app, "Triangulation", message) - else: - logger.warning(message) - - return False - - return True - - @staticmethod - def get_instances_across_views( - session: RecordingSession, - frame_idx: int, - cams_to_include: Optional[List[Camcorder]] = None, - track: Optional["Track"] = None, - require_multiple_views: bool = False, - ) -> Dict[Camcorder, "Instance"]: - """Get all `Instances` accross all views at a given frame index. + if instance is None: + instance = "" # Just used for logging - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get instances from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - track: `Track` object used to find instances accross views. Default is None. - require_multiple_views: If True, then raise and error if one or less views - or instances are found. - - Returns: - Dict with `Camcorder` keys and `Instances` values. - - Raises: - ValueError if require_multiple_view is true and one or less views or - instances are found. - """ + if frame_idx is None: + frame_idx = "" # Just used for logging - def _message(views: bool): - views_or_instances = "views" if views else "instances" - return ( - f"One or less {views_or_instances} found for frame " - f"{frame_idx} in {session.camera_cluster}. " - "Multiple instances accross multiple views needed to triangulate." + if len(instance_groups) < 1: + logger.warning( + f"Require at least 1 instance group, but found " + f"{len(frame_group.instance_groups)} for frame group {frame_group} at " + f"frame {frame_idx}." + f"\nSkipping triangulation." ) - - # Get all views at this frame index - views: Dict[ - Camcorder, "LabeledFrame" - ] = TriangulateSession.get_all_views_at_frame( - session=session, - frame_idx=frame_idx, - cams_to_include=cams_to_include, - ) - - # If not enough views, then raise error - if len(views) <= 1 and require_multiple_views: - raise ValueError(_message(views=True)) - - # Find all instance accross all views - instances: Dict[Camcorder, "Instance"] = {} - for cam, lf in views.items(): - insts = lf.find(track=track) - if len(insts) > 0: - instances[cam] = insts[0] - - # If not enough instances for multiple views, then raise error - if len(instances) <= 1 and require_multiple_views: - raise ValueError(_message(views=False)) - - return instances - - @staticmethod - def get_all_views_at_frame( - session: RecordingSession, - frame_idx, - cams_to_include: Optional[List[Camcorder]] = None, - ) -> Dict[Camcorder, "LabeledFrame"]: - """Get all views at a given frame index. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - frame_idx: Frame index to get views from (0-indexed). - cams_to_include: List of `Camcorder`s to include. Default is all. - - Returns: - Dict with `Camcorder` keys and `LabeledFrame` values. - """ - - views: Dict[Camcorder, "LabeledFrame"] = {} - videos: Dict[Camcorder, Video] = session.get_videos_from_selected_cameras( - cams_to_include=cams_to_include - ) - for cam, video in videos.items(): - lfs: List["LabeledFrame"] = session.labels.get((video, [frame_idx])) - if len(lfs) == 0: - logger.debug( - f"No LabeledFrames found for video {video} at {frame_idx}." - ) - continue - - lf = lfs[0] - if len(lf.instances) == 0: + return None # No instance groups found + + # Assert that there are enough views and instances + instance_groups_to_tri = [] + for instance_group in instance_groups: + instances = instance_group.get_instances(frame_group.cams_to_include) + if len(instances) < 2: + # Not enough instances logger.warning( - f"No Instances found for {lf}." - " There should not be empty LabeledFrames." + f"Not enough instances in {instance_group} for triangulation." + f"\nSkipping instance group." ) continue + instance_groups_to_tri.append(instance_group) - views[cam] = lf - - return views - - @staticmethod - def get_instances_matrices(instances_ordered: List[Instance]) -> np.ndarray: - """Gather instances from views into M x F x T x N x 2 an array. - - M: # views, F: # frames = 1, T: # tracks = 1, N: # nodes, 2: x, y - - Args: - instances_ordered: List of instances from view (following the order of the - `RecordingSession.cameras` if using for triangulation). - - Returns: - M x F x T x N x 2 array of instances coordinates. - """ - - # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks = 1, N = # nodes, 2 = x, y) - inst_coords = np.stack( - [inst.numpy() for inst in instances_ordered], axis=0 - ) # M x N x 2 - inst_coords = np.expand_dims(inst_coords, axis=1) # M x T=1 x N x 2 - inst_coords = np.expand_dims(inst_coords, axis=1) # M x F=1 x T=1 x N x 2 - - return inst_coords - - @staticmethod - def calculate_excluded_views( - session: RecordingSession, - instances: Dict[Camcorder, "Instance"], - ) -> Tuple[str]: - """Get excluded views from dictionary of `Camcorder` to `Instance`. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with `Camcorder` key and `Instance` values. - - Returns: - Tuple of excluded view names. - """ - - # Calculate excluded views from included cameras - cams_excluded = set(session.cameras) - set(instances.keys()) - excluded_views = tuple(cam.name for cam in cams_excluded) - - return excluded_views - - @staticmethod - def calculate_reprojected_points( - session: RecordingSession, instances: Dict[Camcorder, "Instance"] - ) -> Iterator[Tuple["Instance", np.ndarray]]: - """Triangulate and reproject instance coordinates. - - Note that the order of the instances in the list must match the order of the - cameras in the `CameraCluster`, that is why we require instances be passed in as - a dictionary mapping back to its `Camcorder`. - https://github.com/lambdaloop/aniposelib/blob/d03b485c4e178d7cff076e9fe1ac36837db49158/aniposelib/cameras.py#L491 - - Args: - instances: Dict with `Camcorder` keys and `Instance` values. - - Returns: - A zip of the ordered instances and the related reprojected coordinates. Each - element in the coordinates is a numpy array of shape (1, N, 2) where N is - the number of nodes. - """ - - # TODO (LM): Support multiple tracks and optimize - - excluded_views = TriangulateSession.calculate_excluded_views( - session=session, instances=instances - ) - instances_ordered = [ - instances[cam] for cam in session.cameras if cam in instances - ] - - # Gather instances into M x F x T x N x 2 arrays (require specific order) - # (M = # views, F = # frames = 1, T = # tracks = 1, N = # nodes, 2 = x, y) - inst_coords = TriangulateSession.get_instances_matrices( - instances_ordered=instances_ordered - ) # M x F=1 x T=1 x N x 2 - points_3d = triangulate( - p2d=inst_coords, - calib=session.camera_cluster, - excluded_views=excluded_views, - ) # F=1, T=1, N, 3 - - # Update the views with the new 3D points - inst_coords_reprojected = reproject( - points_3d, calib=session.camera_cluster, excluded_views=excluded_views - ) # M x F=1 x T=1 x N x 2 - insts_coords_list: List[np.ndarray] = np.split( - inst_coords_reprojected.squeeze(), inst_coords_reprojected.shape[0], axis=0 - ) # len(M) of T=1 x N x 2 - - return zip(instances_ordered, insts_coords_list) - - @staticmethod - def update_instances(session, instances: Dict[Camcorder, Instance]): - """Triangulate, reproject, and update coordinates of `Instances`. - - Args: - session: The `RecordingSession` containing the `Camcorder`s. - instances: Dict with `Camcorder` keys and `Instance` values. - - Returns: - None - """ - - # Triangulate and reproject instance coordinates. - instances_and_coords: Iterator[ - Tuple["Instance", np.ndarray] - ] = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - - # Update the instance coordinates. - for inst, inst_coord in instances_and_coords: - inst.update_points( - points=inst_coord[0], exclude_complete=True - ) # inst_coord is (1, N, 2) + return instance_groups_to_tri # `InstanceGroup`s with enough instances def open_website(url: str): diff --git a/sleap/gui/dialogs/delete.py b/sleap/gui/dialogs/delete.py index 7e8d39e6b..a0a281e74 100644 --- a/sleap/gui/dialogs/delete.py +++ b/sleap/gui/dialogs/delete.py @@ -216,7 +216,7 @@ def _delete(self, lf_inst_list: List[Tuple[LabeledFrame, Instance]]): for lf, inst in lf_inst_list: self.context.labels.remove_instance(lf, inst, in_transaction=True) if not lf.instances: - self.context.labels.remove(lf) + self.context.labels.remove_frame(lf=lf, update_cache=False) # Update caches since we skipped doing this after each deletion self.context.labels.update_cache() diff --git a/sleap/instance.py b/sleap/instance.py index 1da784416..ed1fa3d07 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -500,7 +500,6 @@ def _points_dict_to_array( ) try: parray[skeleton.node_to_index(node)] = point - # parray[skeleton.node_to_index(node.name)] = point except: logger.debug( f"Could not set point for node {node} in {skeleton} " @@ -729,9 +728,31 @@ def update_points(self, points: np.ndarray, exclude_complete: bool = False): for point_new, points_old, node_name in zip( points, self._points, self.skeleton.node_names ): + + # Skip if new point is nan or old point is complete if np.isnan(point_new).any() or (exclude_complete and points_old.complete): continue - points_dict[node_name] = Point(x=point_new[0], y=point_new[1]) + + # Grab the x, y from the new point and visible, complete from the old point + x, y = point_new + visible = points_old.visible + complete = points_old.complete + + # Create a new point and add to the dict + if type(self._points) == PredictedPointArray: + # TODO(LM): The point score is meant to rate the confidence of the + # prediction, but this method updates from triangulation. + score = points_old.score + point_obj = PredictedPoint( + x=x, y=y, visible=visible, complete=complete, score=score + ) + else: + point_obj = Point(x=x, y=y, visible=visible, complete=complete) + + # Update the points dict + points_dict[node_name] = point_obj + + # Update the points if len(points_dict) > 0: Instance._points_dict_to_array(points_dict, self._points, self.skeleton) diff --git a/sleap/io/cameras.py b/sleap/io/cameras.py index 0cf830feb..d8bac5807 100644 --- a/sleap/io/cameras.py +++ b/sleap/io/cameras.py @@ -1,8 +1,9 @@ """Module for storing information for camera groups.""" + import logging import tempfile from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast, Set import cattr import numpy as np @@ -10,9 +11,9 @@ from aniposelib.cameras import Camera, CameraGroup, FisheyeCamera from attrs import define, field from attrs.validators import deep_iterable, instance_of -from sleap_anipose import reproject, triangulate # from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer +from sleap.instance import LabeledFrame, Instance, PredictedInstance from sleap.io.video import Video from sleap.util import deep_iterable_converter @@ -394,6 +395,507 @@ def to_calibration_dict(self) -> Dict[str, str]: return calibration_dict +@define +class InstanceGroup: + """Defines a group of instances across the same frame index. + + Args: + camera_cluster: `CameraCluster` object. + instances: List of `Instance` objects. + + """ + + _name: str = field() + frame_idx: int = field(validator=instance_of(int)) + _instance_by_camcorder: Dict[Camcorder, Instance] = field(factory=dict) + _camcorder_by_instance: Dict[Instance, Camcorder] = field(factory=dict) + _dummy_instance: Optional[Instance] = field(default=None) + + # Class attributes + camera_cluster: Optional[CameraCluster] = None + + def __attrs_post_init__(self): + """Initialize `InstanceGroup` object.""" + + instance = None + for cam, instance in self._instance_by_camcorder.items(): + self._camcorder_by_instance[instance] = cam + + # Create a dummy instance to fill in for missing instances + if self._dummy_instance is None: + self._create_dummy_instance(instance=instance) + + def _create_dummy_instance(self, instance: Optional[Instance] = None): + """Create a dummy instance to fill in for missing instances. + + Args: + instance: Optional `Instance` object to use as an example instance. If None, + then the first instance in the `InstanceGroup` is used. + + Raises: + ValueError: If no instances are available to create a dummy instance. + """ + + if self._dummy_instance is None: + # Get an example instance + if instance is None: + if len(self.instances) < 1: + raise ValueError( + "Cannot create a dummy instance without any instances." + ) + instance = self.instances[0] + + # Use the example instance to create a dummy instance + skeleton: "Skeleton" = instance.skeleton + self._dummy_instance = PredictedInstance.from_numpy( + points=np.full( + shape=(len(skeleton.nodes), 2), + fill_value=np.nan, + ), + point_confidences=np.full( + shape=(len(skeleton.nodes),), + fill_value=np.nan, + ), + instance_score=np.nan, + skeleton=skeleton, + ) + + @property + def dummy_instance(self) -> PredictedInstance: + """Dummy `PredictedInstance` object to fill in for missing instances. + + Also used to create instances that are not found in the `InstanceGroup`. + + Returns: + `PredictedInstance` object or None if unable to create the dummy instance. + """ + + if self._dummy_instance is None: + self._create_dummy_instance() + return self._dummy_instance + + @property + def name(self) -> str: + """Name of the `InstanceGroup`.""" + + return self._name + + @name.setter + def name(self, name: str): + """Set the name of the `InstanceGroup`.""" + + raise ValueError( + "Cannot set name directly. Use `set_name` method instead (preferably " + "through FrameGroup.set_instance_group_name)." + ) + + def set_name(self, name: str, name_registry: Set[str]): + """Set the name of the `InstanceGroup`. + + This function mutates the name_registry input (see side-effect). + + Args: + name: Name to set for the `InstanceGroup`. + name_registry: Set of names to check for uniqueness. + + Raises: + ValueError: If the name is already in use (in the name_registry). + """ + + # Check if the name is already in use + if name in name_registry: + raise ValueError( + f"Name {name} already in use. Please use a unique name not currently " + f"in the registry: {name_registry}" + ) + + # Remove the old name from the registry + if self._name in name_registry: + name_registry.remove(self._name) + + self._name = name + name_registry.add(name) + + @classmethod + def return_unique_name(cls, name_registry: Set[str]) -> str: + """Return a unique name for the `InstanceGroup`. + + Args: + name_registry: Set of names to check for uniqueness. + + Returns: + Unique name for the `InstanceGroup`. + """ + + base_name = "instance_group_" + count = len(name_registry) + new_name = f"{base_name}{count}" + + while new_name in name_registry: + count += 1 + new_name = f"{base_name}{count}" + + return new_name + + @property + def instances(self) -> List[Instance]: + """List of `Instance` objects.""" + return list(self._instance_by_camcorder.values()) + + @property + def cameras(self) -> List[Camcorder]: + """List of `Camcorder` objects.""" + return list(self._instance_by_camcorder.keys()) + + @property + def instance_by_camcorder(self) -> Dict[Camcorder, Instance]: + """Dictionary of `Instance` objects by `Camcorder`.""" + return self._instance_by_camcorder + + def numpy(self, pred_as_nan: bool = False) -> np.ndarray: + """Return instances as a numpy array of shape (n_views, n_nodes, 2). + + The ordering of views is based on the ordering of `Camcorder`s in the + `self.camera_cluster: CameraCluster`. + + If an instance is missing for a `Camcorder`, then the instance is filled in with + the dummy instance (all NaNs). + + Args: + pred_as_nan: If True, then replaces `PredictedInstance`s with all nan + self.dummy_instance. Default is False. + + Returns: + Numpy array of shape (n_views, n_nodes, 2). + """ + + instance_numpys: List[np.ndarray] = [] # len(M) x N x 2 + for cam in self.camera_cluster.cameras: + instance = self.get_instance(cam) + + # Determine whether to use a dummy (all nan) instance + instance_is_missing = instance is None + instance_as_nan = pred_as_nan and isinstance(instance, PredictedInstance) + use_dummy_instance = instance_is_missing or instance_as_nan + + # Add the dummy instance if the instance is missing + if use_dummy_instance: + instance = self.dummy_instance # This is an all nan PredictedInstance + + instance_numpy: np.ndarray = instance.numpy() # N x 2 + instance_numpys.append(instance_numpy) + + return np.stack(instance_numpys, axis=0) # M x N x 2 + + def create_and_add_instance(self, cam: Camcorder, labeled_frame: LabeledFrame): + """Create an `Instance` at a labeled_frame and add it to the `InstanceGroup`. + + Args: + cam: `Camcorder` object that the `Instance` is for. + labeled_frame: `LabeledFrame` object that the `Instance` is contained in. + + Returns: + All nan `PredictedInstance` created and added to the `InstanceGroup`. + """ + + # Get the `Skeleton` + skeleton: "Skeleton" = self.dummy_instance.skeleton + + # Create an all nan `Instance` + instance: PredictedInstance = PredictedInstance.from_numpy( + points=self.dummy_instance.points_array, + point_confidences=self.dummy_instance.scores, + instance_score=self.dummy_instance.score, + skeleton=skeleton, + ) + instance.frame = labeled_frame + + # Add the instance to the `InstanceGroup` + self.add_instance(cam, instance) + + return instance + + def add_instance(self, cam: Camcorder, instance: Instance): + """Add an `Instance` to the `InstanceGroup`. + + Args: + cam: `Camcorder` object that the `Instance` is for. + instance: `Instance` object to add. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + ValueError: If the `Instance` is already in the `InstanceGroup` at another + camera. + """ + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Ensure the `Instance` is not already in the `InstanceGroup` at another camera + if ( + instance in self._camcorder_by_instance + and self._camcorder_by_instance[instance] != cam + ): + raise ValueError( + f"Instance {instance} is already in this InstanceGroup at camera " + f"{self.get_instance(instance)}." + ) + + # Add the instance to the `InstanceGroup` + self.replace_instance(cam, instance) + + def replace_instance(self, cam: Camcorder, instance: Instance): + """Replace an `Instance` in the `InstanceGroup`. + + If the `Instance` is already in the `InstanceGroup`, then it is removed and + replaced. If the `Instance` is not already in the `InstanceGroup`, then it is + added. + + Args: + cam: `Camcorder` object that the `Instance` is for. + instance: `Instance` object to replace. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + """ + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Remove the instance if it already exists + self.remove_instance(instance_or_cam=instance) + + # Replace the instance in the `InstanceGroup` + self._instance_by_camcorder[cam] = instance + self._camcorder_by_instance[instance] = cam + + def remove_instance(self, instance_or_cam: Union[Instance, Camcorder]): + """Remove an `Instance` from the `InstanceGroup`. + + Args: + instance_or_cam: `Instance` or `Camcorder` object to remove from + `InstanceGroup`. + + Raises: + ValueError: If the `Camcorder` is not in the `CameraCluster`. + """ + + if isinstance(instance_or_cam, Camcorder): + cam = instance_or_cam + + # Ensure the `Camcorder` is in the `CameraCluster` + self._raise_if_cam_not_in_cluster(cam=cam) + + # Remove the instance from the `InstanceGroup` + if cam in self._instance_by_camcorder: + instance = self._instance_by_camcorder.pop(cam) + self._camcorder_by_instance.pop(instance) + + else: + # The input is an `Instance` + instance = instance_or_cam + + # Remove the instance from the `InstanceGroup` + if instance in self._camcorder_by_instance: + cam = self._camcorder_by_instance.pop(instance) + self._instance_by_camcorder.pop(cam) + else: + logger.debug( + f"Instance {instance} not found in this InstanceGroup {self}." + ) + + def _raise_if_cam_not_in_cluster(self, cam: Camcorder): + """Raise a ValueError if the `Camcorder` is not in the `CameraCluster`.""" + + if cam not in self.camera_cluster: + raise ValueError( + f"Camcorder {cam} is not in this InstanceGroup's " + f"{self.camera_cluster}." + ) + + def get_instance(self, cam: Camcorder) -> Optional[Instance]: + """Retrieve `Instance` linked to `Camcorder`. + + Args: + camcorder: `Camcorder` object. + + Returns: + If `Camcorder` in `self.camera_cluster`, then `Instance` object if found, else + `None` if `Camcorder` has no linked `Instance`. + """ + + if cam not in self._instance_by_camcorder: + logger.debug( + f"Camcorder {cam} has no linked `Instance` in this `InstanceGroup` " + f"{self}." + ) + return None + + return self._instance_by_camcorder[cam] + + def get_instances(self, cams: List[Camcorder]) -> List[Instance]: + instances = [] + for cam in cams: + instance = self.get_instance(cam) + instances.append(instance) + return instance + + def get_cam(self, instance: Instance) -> Optional[Camcorder]: + """Retrieve `Camcorder` linked to `Instance`. + + Args: + instance: `Instance` object. + + Returns: + `Camcorder` object if found, else `None`. + """ + + if instance not in self._camcorder_by_instance: + logger.debug( + f"{instance} is not in this InstanceGroup.instances: " + f"\n\t{self.instances}." + ) + return None + + return self._camcorder_by_instance[instance] + + def update_points( + self, + points: np.ndarray, + cams_to_include: Optional[List[Camcorder]] = None, + exclude_complete: bool = True, + ): + """Update the points in the `Instance` for the specified `Camcorder`s. + + Args: + points: Numpy array of shape (M, N, 2) where M is the number of views, N is + the number of Nodes, and 2 is for x, y. + cams_to_include: List of `Camcorder`s to include in the update. The order of + the `Camcorder`s in the list should match the order of the views in the + `points` array. If None, then all `Camcorder`s in the `CameraCluster` + are included. Default is None. + exclude_complete: If True, then do not update points that are marked as + complete. Default is True. + """ + + # If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster` + if cams_to_include is None: + cams_to_include = self.camera_cluster.cameras + + # Check that correct shape was passed in + n_views, n_nodes, _ = points.shape + assert n_views == len(cams_to_include), ( + f"Number of views in `points` ({n_views}) does not match the number of " + f"Camcorders in `cams_to_include` ({len(cams_to_include)})." + ) + + for cam_idx, cam in enumerate(cams_to_include): + # Get the instance for the cam + instance: Optional[Instance] = self.get_instance(cam) + if instance is None: + logger.warning( + f"Camcorder {cam.name} not found in this InstanceGroup's instances." + ) + continue + + # Update the points (and scores) for the (predicted) instance + instance.update_points( + points=points[cam_idx, :, :], exclude_complete=exclude_complete + ) + + def __getitem__( + self, idx_or_key: Union[int, Camcorder, Instance] + ) -> Union[Camcorder, Instance]: + """Grab a `Camcorder` of `Instance` from the `InstanceGroup`.""" + + def _raise_key_error(): + raise KeyError(f"Key {idx_or_key} not found in {self.__class__.__name__}.") + + # Try to find in `self.camera_cluster.cameras` + if isinstance(idx_or_key, int): + try: + return self.instances[idx_or_key] + except IndexError: + _raise_key_error() + + # Return a `Instance` if `idx_or_key` is a `Camcorder`` + if isinstance(idx_or_key, Camcorder): + return self.get_instance(idx_or_key) + + else: + # isinstance(idx_or_key, Instance): + try: + return self.get_cam(idx_or_key) + except: + pass + + _raise_key_error() + + def __len__(self): + return len(self.instances) + + def __repr__(self): + return f"{self.__class__.__name__}(frame_idx={self.frame_idx}, instances={len(self)}, camera_cluster={self.camera_cluster})" + + def __hash__(self) -> int: + return hash(self._name) + + @classmethod + def from_dict( + cls, d: dict, name: str, name_registry: Set[str] + ) -> Optional["InstanceGroup"]: + """Creates an `InstanceGroup` object from a dictionary. + + Args: + d: Dictionary with `Camcorder` keys and `Instance` values. + name: Name to use for the `InstanceGroup`. + name_registry: Set of names to check for uniqueness. + + Raises: + ValueError: If the `InstanceGroup` name is already in use. + + Returns: + `InstanceGroup` object or None if no "real" (determined by `frame_idx` other + than None) instances found. + """ + + # Ensure not to mutate the original dictionary + d_copy = d.copy() + + frame_idx = None + for cam, instance in d_copy.copy().items(): + camera_cluster = cam.camera_cluster + + # Remove dummy instances (determined by not having a frame index) + if instance.frame_idx is None: + d_copy.pop(cam) + # Grab the frame index from non-dummy instances + elif frame_idx is None: + frame_idx = instance.frame_idx + # Ensure all instances have the same frame index + elif frame_idx != instance.frame_idx: + raise ValueError( + f"Cannot create `InstanceGroup`: Frame index {frame_idx} does " + f"not match instance frame index {instance.frame_idx}." + ) + + if len(d_copy) == 0: + raise ValueError("Cannot create `InstanceGroup`: No real instances found.") + + if name in name_registry: + raise ValueError( + f"Cannot create `InstanceGroup`: Name {name} already in use. Please " + f"use a unique name that is not in the registry: {name_registry}." + ) + + return cls( + name=name, + frame_idx=frame_idx, + camera_cluster=camera_cluster, + instance_by_camcorder=d_copy, + ) + + @define(eq=False) class RecordingSession: """Class for storing information for a recording session. @@ -412,8 +914,9 @@ class RecordingSession: # TODO(LM): Consider implementing Observer pattern for `camera_cluster` and `labels` camera_cluster: CameraCluster = field(factory=CameraCluster) metadata: dict = field(factory=dict) + labels: Optional["Labels"] = field(default=None) _video_by_camcorder: Dict[Camcorder, Video] = field(factory=dict) - labels: Optional["Labels"] = None + _frame_group_by_frame_idx: Dict[int, "FrameGroup"] = field(factory=dict) @property def videos(self) -> List[Video]: @@ -423,15 +926,38 @@ def videos(self) -> List[Video]: @property def linked_cameras(self) -> List[Camcorder]: - """List of `Camcorder`s in `self.camera_cluster` that are linked to a video.""" + """List of `Camcorder`s in `self.camera_cluster` that are linked to a video. - return list(self._video_by_camcorder.keys()) + The list is ordered based on the order of the `Camcorder`s in the `CameraCluster`. + """ + + return sorted( + self._video_by_camcorder.keys(), key=self.camera_cluster.cameras.index + ) @property def unlinked_cameras(self) -> List[Camcorder]: - """List of `Camcorder`s in `self.camera_cluster` that are not linked to a video.""" + """List of `Camcorder`s in `self.camera_cluster` that are not linked to a video. + + The list is ordered based on the order of the `Camcorder`s in the `CameraCluster`. + """ + + return sorted( + set(self.camera_cluster.cameras) - set(self.linked_cameras), + key=self.camera_cluster.cameras.index, + ) + + @property + def frame_groups(self) -> Dict[int, "FrameGroup"]: + """Dict of `FrameGroup`s by frame index.""" - return list(set(self.camera_cluster.cameras) - set(self.linked_cameras)) + return self._frame_group_by_frame_idx + + @property + def frame_inds(self) -> List[int]: + """List of frame indices.""" + + return list(self.frame_groups.keys()) def get_video(self, camcorder: Camcorder) -> Optional[Video]: """Retrieve `Video` linked to `Camcorder`. @@ -490,9 +1016,7 @@ def add_video(self, video: Video, camcorder: Camcorder): """ # Ensure the `Camcorder` is in this `RecordingSession`'s `CameraCluster` - try: - assert camcorder in self.camera_cluster - except AssertionError: + if camcorder not in self.camera_cluster: raise ValueError( f"Camcorder {camcorder.name} is not in this RecordingSession's " f"{self.camera_cluster}." @@ -519,6 +1043,11 @@ def add_video(self, video: Video, camcorder: Camcorder): # Add camcorder-to-video (1-to-1) map to `RecordingSession` self._video_by_camcorder[camcorder] = video + # Sort `_videos_by_session` by order of linked `Camcorder` in `CameraCluster.cameras` + self.camera_cluster._videos_by_session[self].sort( + key=lambda video: self.camera_cluster.cameras.index(self.get_camera(video)) + ) + # Update labels cache if self.labels is not None: self.labels.update_session(self, video) @@ -545,7 +1074,22 @@ def remove_video(self, video: Video): # Update labels cache if self.labels is not None and self.labels.get_session(video) is not None: - self.labels.remove_session_video(self, video) + self.labels.remove_session_video(video=video) + + def new_frame_group(self, frame_idx: int): + """Creates and adds an empty `FrameGroup` to the `RecordingSession`. + + Args: + frame_idx: Frame index for the `FrameGroup`. + + Returns: + `FrameGroup` object. + """ + + # `FrameGroup.__attrs_post_init` will manage `_frame_group_by_frame_idx` + frame_group = FrameGroup(frame_idx=frame_idx, session=self) + + return frame_group def get_videos_from_selected_cameras( self, cams_to_include: Optional[List[Camcorder]] = None @@ -755,3 +1299,718 @@ def make_cattr(videos_list: List[Video]): RecordingSession, lambda x: x.to_session_dict(video_to_idx) ) return sessions_cattr + + +@define +class FrameGroup: + """Defines a group of `InstanceGroups` across views at the same frame index.""" + + # Instance attributes + frame_idx: int = field(validator=instance_of(int)) + session: RecordingSession = field(validator=instance_of(RecordingSession)) + _instance_groups: List[InstanceGroup] = field( + factory=list, + validator=deep_iterable( + member_validator=instance_of(InstanceGroup), + iterable_validator=instance_of(list), + ), + ) # Akin to `LabeledFrame.instances` + _instance_group_name_registry: Set[str] = field(factory=set) + + # "Hidden" class attribute + _cams_to_include: Optional[List[Camcorder]] = None + _excluded_views: Optional[Tuple[str]] = () + + # "Hidden" instance attributes + + # TODO(LM): This dict should be updated each time a LabeledFrame is added/removed + # from the Labels object. Or if a video is added/removed from the RecordingSession. + _labeled_frame_by_cam: Dict[Camcorder, LabeledFrame] = field(factory=dict) + _cam_by_labeled_frame: Dict[LabeledFrame, Camcorder] = field(factory=dict) + _instances_by_cam: Dict[Camcorder, Set[Instance]] = field(factory=dict) + + def __attrs_post_init__(self): + """Initialize `FrameGroup` object.""" + + # Check that `InstanceGroup` names unique (later added via add_instance_group) + instance_group_name_registry_copy = set(self._instance_group_name_registry) + for instance_group in self.instance_groups: + if instance_group.name in instance_group_name_registry_copy: + raise ValueError( + f"InstanceGroup name {instance_group.name} already in use. " + f"Please use a unique name not currently in the registry: " + f"{self._instance_group_name_registry}" + ) + instance_group_name_registry_copy.add(instance_group.name) + + # Remove existing `FrameGroup` object from the `RecordingSession._frame_group_by_frame_idx` + self.enforce_frame_idx_unique(self.session, self.frame_idx) + + # Reorder `cams_to_include` to match `CameraCluster` order (via setter method) + if self._cams_to_include is not None: + self.cams_to_include = self._cams_to_include + + # Add `FrameGroup` to `RecordingSession` + self.session._frame_group_by_frame_idx[self.frame_idx] = self + + # Build `_labeled_frame_by_cam` and `_instances_by_cam` dictionary + for camera in self.session.camera_cluster.cameras: + self._instances_by_cam[camera] = set() + for instance_group in self.instance_groups: + self.add_instance_group(instance_group) + + @property + def instance_groups(self) -> List[InstanceGroup]: + """List of `InstanceGroup`s.""" + + return self._instance_groups + + @instance_groups.setter + def instance_groups(self, instance_groups: List[InstanceGroup]): + """Setter for `instance_groups` that updates `LabeledFrame`s and `Instance`s.""" + + instance_groups_to_remove = set(self.instance_groups) - set(instance_groups) + instance_groups_to_add = set(instance_groups) - set(self.instance_groups) + + # Update the `_labeled_frame_by_cam` and `_instances_by_cam` dictionary + for instance_group in instance_groups_to_remove: + self.remove_instance_group(instance_group=instance_group) + + for instance_group in instance_groups_to_add: + self.add_instance_group(instance_group=instance_group) + + @property + def cams_to_include(self) -> Optional[List[Camcorder]]: + """List of `Camcorder`s to include in this `FrameGroup`.""" + + if self._cams_to_include is None: + self._cams_to_include = self.session.camera_cluster.cameras.copy() + + # TODO(LM): Should we store this in another attribute? + # Filter cams to include based on videos linked to the session + cams_to_include = [ + cam for cam in self._cams_to_include if cam in self.session.linked_cameras + ] + + return cams_to_include + + @property + def excluded_views(self) -> Optional[Tuple[str]]: + """List of excluded views (names of Camcorders).""" + + return self._excluded_views + + @cams_to_include.setter + def cams_to_include(self, cams_to_include: List[Camcorder]): + """Setter for `cams_to_include` that sorts by `CameraCluster` order.""" + + # Sort the `Camcorder`s to include based on the order of `CameraCluster` cameras + self._cams_to_include = cams_to_include.sort( + key=self.session.camera_cluster.cameras.index + ) + + # Update the `excluded_views` attribute + excluded_cams = list( + set(self.session.camera_cluster.cameras) - set(cams_to_include) + ) + excluded_cams.sort(key=self.session.camera_cluster.cameras.index) + self._excluded_views = (cam.name for cam in excluded_cams) + + @property + def labeled_frames(self) -> List[LabeledFrame]: + """List of `LabeledFrame`s.""" + + # TODO(LM): Revisit whether we need to return a list instead of a view object + return list(self._labeled_frame_by_cam.values()) + + @property + def cameras(self) -> List[Camcorder]: + """List of `Camcorder`s.""" + + # TODO(LM): Revisit whether we need to return a list instead of a view object + return list(self._labeled_frame_by_cam.keys()) + + def numpy( + self, + instance_groups: Optional[List[InstanceGroup]] = None, + pred_as_nan: bool = False, + ) -> np.ndarray: + """Numpy array of all `InstanceGroup`s in `FrameGroup.cams_to_include`. + + Args: + instance_groups: `InstanceGroup`s to include. Default is None and uses all + self.instance_groups. + pred_as_nan: If True, then replaces `PredictedInstance`s with all nan + self.dummy_instance. Default is False. + + Returns: + Numpy array of shape (M, T, N, 2) where M is the number of views (determined + by self.cames_to_include), T is the number of `InstanceGroup`s, N is the + number of Nodes, and 2 is for x, y. + """ + + # Use all `InstanceGroup`s if not specified + if instance_groups is None: + instance_groups = self.instance_groups + else: + # Ensure that `InstanceGroup`s is in this `FrameGroup` + for instance_group in instance_groups: + if instance_group not in self.instance_groups: + raise ValueError( + f"InstanceGroup {instance_group} is not in this FrameGroup: " + f"{self.instance_groups}" + ) + + instance_group_numpys: List[np.ndarray] = [] # len(T) M=all x N x 2 + for instance_group in instance_groups: + instance_group_numpy = instance_group.numpy( + pred_as_nan=pred_as_nan + ) # M=all x N x 2 + instance_group_numpys.append(instance_group_numpy) + + frame_group_numpy = np.stack(instance_group_numpys, axis=1) # M=all x T x N x 2 + cams_to_include_mask = np.array( + [1 if cam in self.cams_to_include else 0 for cam in self.cameras] + ) # M=include x 1 + + return frame_group_numpy[cams_to_include_mask] # M=include x T x N x 2 + + def add_instance( + self, + instance: Instance, + camera: Camcorder, + instance_group: Optional[InstanceGroup] = None, + ): + """Add an (existing) `Instance` to the `FrameGroup`. + + If no `InstanceGroup` is provided, then check the `Instance` is already in an + `InstanceGroup` contained in the `FrameGroup`. Otherwise, add the `Instance` to + the `InstanceGroup` and `FrameGroup`. + + Args: + instance: `Instance` to add to the `FrameGroup`. + camera: `Camcorder` to link the `Instance` to. + instance_group: `InstanceGroup` to add the `Instance` to. If None, then + check the `Instance` is already in an `InstanceGroup`. + + Raises: + ValueError: If the `InstanceGroup` is not in the `FrameGroup`. + ValueError: If the `Instance` is not linked to a `LabeledFrame`. + ValueError: If the frame index of the `Instance` does not match the frame index + of the `FrameGroup`. + ValueError: If the `LabeledFrame` of the `Instance` does not match the existing + `LabeledFrame` for the `Camcorder` in the `FrameGroup`. + ValueError: If the `Instance` is not in an `InstanceGroup` in the + `FrameGroup`. + """ + + # Ensure the `InstanceGroup` is in this `FrameGroup` + if instance_group is not None: + self._raise_if_instance_group_not_in_frame_group( + instance_group=instance_group + ) + + # Ensure `Instance` is compatible with `FrameGroup` + self._raise_if_instance_incompatibile(instance=instance, camera=camera) + + # Add the `Instance` to the `InstanceGroup` + if instance_group is not None: + instance_group.add_instance(cam=camera, instance=instance) + else: + self._raise_if_instance_not_in_instance_group(instance=instance) + + # Add the `Instance` to the `FrameGroup` + self._instances_by_cam[camera].add(instance) + + # Update the labeled frames if necessary + labeled_frame = self.get_labeled_frame(camera=camera) + if labeled_frame is None: + labeled_frame = instance.frame + self.add_labeled_frame(labeled_frame=labeled_frame, camera=camera) + + def remove_instance(self, instance: Instance): + """Removes an `Instance` from the `FrameGroup`. + + Args: + instance: `Instance` to remove from the `FrameGroup`. + """ + + instance_group = self.get_instance_group(instance=instance) + + if instance_group is None: + logger.warning( + f"Instance {instance} not found in this FrameGroup.instance_groups: " + f"{self.instance_groups}." + ) + return + + # Remove the `Instance` from the `InstanceGroup` + camera = instance_group.get_cam(instance=instance) + instance_group.remove_instance(instance=instance) + + # Remove the `Instance` from the `FrameGroup` + self._instances_by_cam[camera].remove(instance) + + # Remove "empty" `LabeledFrame`s from the `FrameGroup` + if len(self._instances_by_cam[camera]) < 1: + self.remove_labeled_frame(labeled_frame_or_camera=camera) + + def add_instance_group(self, instance_group: Optional[InstanceGroup] = None): + """Add an `InstanceGroup` to the `FrameGroup`. + + This method updates the underlying dictionaries in calling add_instance: + - `_instances_by_cam` + - `_labeled_frame_by_cam` + - `_cam_by_labeled_frame` + + Args: + instance_group: `InstanceGroup` to add to the `FrameGroup`. If None, then + create a new `InstanceGroup` and add it to the `FrameGroup`. + + Raises: + ValueError: If the `InstanceGroup` is already in the `FrameGroup`. + """ + + if instance_group is None: + + # Find a unique name for the `InstanceGroup` + instance_group_name = InstanceGroup.return_unique_name( + name_registry=self._instance_group_name_registry + ) + + # Create an empty `InstanceGroup` with the frame index of the `FrameGroup` + instance_group = InstanceGroup( + name=instance_group_name, + frame_idx=self.frame_idx, + camera_cluster=self.session.camera_cluster, + ) + else: + # Ensure the `InstanceGroup` is compatible with the `FrameGroup` + self._raise_if_instance_group_incompatible(instance_group=instance_group) + + # Add the `InstanceGroup` to the `FrameGroup` + # We only expect this to be false on initialization + if instance_group not in self.instance_groups: + self.instance_groups.append(instance_group) + + # Add instance group name to the registry + self._instance_group_name_registry.add(instance_group.name) + + # Add `Instance`s and `LabeledFrame`s to the `FrameGroup` + for camera, instance in instance_group.instance_by_camcorder.items(): + self.add_instance(instance=instance, camera=camera) + + def remove_instance_group(self, instance_group: InstanceGroup): + """Remove an `InstanceGroup` from the `FrameGroup`.""" + + if instance_group not in self.instance_groups: + logger.warning( + f"InstanceGroup {instance_group} not found in this FrameGroup: " + f"{self.instance_groups}." + ) + return + + # Remove the `InstanceGroup` from the `FrameGroup` + self.instance_groups.remove(instance_group) + self._instance_group_name_registry.remove(instance_group.name) + + # Remove the `Instance`s from the `FrameGroup` + for camera, instance in instance_group.instance_by_camcorder.items(): + self._instances_by_cam[camera].remove(instance) + + # Remove the `LabeledFrame` from the `FrameGroup` + labeled_frame = self.get_labeled_frame(camera=camera) + if labeled_frame is not None: + self.remove_labeled_frame(camera=camera) + + def get_instance_group(self, instance: Instance) -> Optional[InstanceGroup]: + """Get `InstanceGroup` that contains `Instance` if exists. Otherwise, None. + + Args: + instance: `Instance` + + Returns: + `InstanceGroup` + """ + + instance_group: Optional[InstanceGroup] = next( + ( + instance_group + for instance_group in self.instance_groups + if instance in instance_group.instances + ), + None, + ) + + return instance_group + + def set_instance_group_name(self, instance_group: InstanceGroup, name: str): + """Set the name of an `InstanceGroup` in the `FrameGroup`.""" + + self._raise_if_instance_group_not_in_frame_group(instance_group=instance_group) + + instance_group.set_name( + name=name, name_registry=self._instance_group_name_registry + ) + + def add_labeled_frame(self, labeled_frame: LabeledFrame, camera: Camcorder): + """Add a `LabeledFrame` to the `FrameGroup`. + + Args: + labeled_frame: `LabeledFrame` to add to the `FrameGroup`. + camera: `Camcorder` to link the `LabeledFrame` to. + + Raises: + ValueError: If the `LabeledFrame` is not compatible with the `FrameGroup`. + """ + + # Some checks to ensure the `LabeledFrame` is compatible with the `FrameGroup` + if not isinstance(labeled_frame, LabeledFrame): + raise ValueError( + f"Cannot add LabeledFrame: {labeled_frame} is not a LabeledFrame." + ) + elif labeled_frame.frame_idx != self.frame_idx: + raise ValueError( + f"Cannot add LabeledFrame: Frame index {labeled_frame.frame_idx} does " + f"not match FrameGroup frame index {self.frame_idx}." + ) + elif not isinstance(camera, Camcorder): + raise ValueError(f"Cannot add LabeledFrame: {camera} is not a Camcorder.") + + # Add the `LabeledFrame` to the `FrameGroup` + self._labeled_frame_by_cam[camera] = labeled_frame + self._cam_by_labeled_frame[labeled_frame] = camera + + # Add the `LabeledFrame` to the `RecordingSession`'s `Labels` object + if (self.session.labels is not None) and ( + labeled_frame not in self.session.labels + ): + self.session.labels.append(labeled_frame) + + def remove_labeled_frame( + self, labeled_frame_or_camera: Union[LabeledFrame, Camcorder] + ): + """Remove a `LabeledFrame` from the `FrameGroup`. + + Args: + labeled_frame_or_camera: `LabeledFrame` or `Camcorder` to remove the + `LabeledFrame` for. + """ + + if isinstance(labeled_frame_or_camera, LabeledFrame): + labeled_frame: LabeledFrame = labeled_frame_or_camera + camera = self.get_camera(labeled_frame=labeled_frame) + + elif isinstance(labeled_frame_or_camera, Camcorder): + camera: Camcorder = labeled_frame_or_camera + labeled_frame = self.get_labeled_frame(camera=camera) + + else: + logger.warning( + f"Cannot remove LabeledFrame: {labeled_frame_or_camera} is not a " + "LabeledFrame or Camcorder." + ) + + # Remove the `LabeledFrame` from the `FrameGroup` + self._labeled_frame_by_cam.pop(camera, None) + self._cam_by_labeled_frame.pop(labeled_frame, None) + + def get_labeled_frame(self, camera: Camcorder) -> Optional[LabeledFrame]: + """Get `LabeledFrame` for `Camcorder` if exists. Otherwise, None. + + Args: + camera: `Camcorder` + + Returns: + `LabeledFrame` + """ + + return self._labeled_frame_by_cam.get(camera, None) + + def get_camera(self, labeled_frame: LabeledFrame) -> Optional[Camcorder]: + """Get `Camcorder` for `LabeledFrame` if exists. Otherwise, None. + + Args: + labeled_frame: `LabeledFrame` + + Returns: + `Camcorder` + """ + + return self._cam_by_labeled_frame.get(labeled_frame, None) + + def _create_and_add_labeled_frame(self, camera: Camcorder) -> LabeledFrame: + """Create and add a `LabeledFrame` to the `FrameGroup`. + + This also adds the `LabeledFrame` to the `RecordingSession`'s `Labels` object. + + Args: + camera: `Camcorder` + + Returns: + `LabeledFrame` that was created and added to the `FrameGroup`. + """ + + video = self.session.get_video(camera) + if video is None: + # There should be a `Video` linked to all cams_to_include + raise ValueError( + f"Camcorder {camera} is not linked to a video in this " + f"RecordingSession {self.session}." + ) + + labeled_frame = LabeledFrame(video=video, frame_idx=self.frame_idx) + self.add_labeled_frame(labeled_frame=labeled_frame) + + return labeled_frame + + def _create_and_add_instance( + self, + instance_group: InstanceGroup, + camera: Camcorder, + labeled_frame: LabeledFrame, + ): + """Add an `Instance` to the `InstanceGroup` (and `FrameGroup`). + + Args: + instance_group: `InstanceGroup` to add the `Instance` to. + camera: `Camcorder` to link the `Instance` to. + labeled_frame: `LabeledFrame` that the `Instance` is in. + """ + + # Add the `Instance` to the `InstanceGroup` + instance = instance_group.create_and_add_instance( + cam=camera, labeled_frame=labeled_frame + ) + + # Add the `Instance` to the `FrameGroup` + self._instances_by_cam[camera].add(instance=instance) + + def create_and_add_missing_instances(self, instance_group: InstanceGroup): + """Add missing instances to `FrameGroup` from `InstanceGroup`s. + + If an `InstanceGroup` does not have an `Instance` for a `Camcorder` in + `FrameGroup.cams_to_include`, then create an `Instance` and add it to the + `InstanceGroup`. + + Args: + instance_group: `InstanceGroup` objects to add missing `Instance`s for. + + Raises: + ValueError: If a `Camcorder` in `FrameGroup.cams_to_include` is not in the + `InstanceGroup`. + """ + + # Check that the `InstanceGroup` has `LabeledFrame`s for all included views + for cam in self.cams_to_include: + + # If the `Camcorder` is in the `InstanceGroup`, then `Instance` exists + if cam in instance_group.cameras: + continue # Skip to next cam + + # Get the `LabeledFrame` for the view + labeled_frame = self.get_labeled_frame(camera=cam) + if labeled_frame is None: + # There is no `LabeledFrame` for this view, so lets make one + labeled_frame = self._create_and_add_labeled_frame(camera=cam) + + # Create an instance + self._create_and_add_instance( + instance_group=instance_group, cam=cam, labeled_frame=labeled_frame + ) + + def upsert_points( + self, + points: np.ndarray, + instance_groups: List[InstanceGroup], + exclude_complete: bool = True, + ): + """Upsert points for `Instance`s at included cams in specified `InstanceGroup`. + + This will update the points for existing `Instance`s in the `InstanceGroup`s and + also add new `Instance`s if they do not exist. + + + Included cams are specified by `FrameGroup.cams_to_include`. + + The ordering of the `InstanceGroup`s in `instance_groups` should match the + ordering of the second dimension (T) in `points`. + + Args: + points: Numpy array of shape (M, T, N, 2) where M is the number of views, T + is the number of Tracks, N is the number of Nodes, and 2 is for x, y. + instance_groups: List of `InstanceGroup` objects to update points for. + exclude_complete: If True, then only update points that are not marked as + complete. Default is True. + """ + + # Check that the correct shape was passed in + n_views, n_instances, n_nodes, n_coords = points.shape + assert n_views == len( + self.cams_to_include + ), f"Expected {len(self.cams_to_include)} views, got {n_views}." + assert n_instances == len( + instance_groups + ), f"Expected {len(instance_groups)} instances, got {n_instances}." + assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}." + + # Update points for each `InstanceGroup` + for ig_idx, instance_group in enumerate(instance_groups): + # Ensure that `InstanceGroup`s is in this `FrameGroup` + self._raise_if_instance_group_not_in_frame_group( + instance_group=instance_group + ) + + # Check that the `InstanceGroup` has `Instance`s for all cams_to_include + self.create_and_add_missing_instances(instance_group=instance_group) + + # Update points for each `Instance` in `InstanceGroup` + instance_points = points[:, ig_idx, :, :] # M x N x 2 + instance_group.update_points( + points=instance_points, + cams_to_include=self.cams_to_include, + exclude_complete=exclude_complete, + ) + + def _raise_if_instance_not_in_instance_group(self, instance: Instance): + """Raise a ValueError if the `Instance` is not in an `InstanceGroup`. + + Args: + instance: `Instance` to check if in an `InstanceGroup`. + + Raises: + ValueError: If the `Instance` is not in an `InstanceGroup`. + """ + + instance_group = self.get_instance_group(instance=instance) + if instance_group is None: + raise ValueError( + f"Instance {instance} is not in an InstanceGroup within the FrameGroup." + ) + + def _raise_if_instance_incompatibile(self, instance: Instance, camera: Camcorder): + """Raise a ValueError if the `Instance` is incompatible with the `FrameGroup`. + + The `Instance` is incompatible if: + 1. the `Instance` is not linked to a `LabeledFrame`. + 2. the frame index of the `Instance` does not match the frame index of the + `FrameGroup`. + 3. the `LabeledFrame` of the `Instance` does not match the existing + `LabeledFrame` for the `Camcorder` in the `FrameGroup`. + + Args: + instance: `Instance` to check compatibility of. + camera: `Camcorder` to link the `Instance` to. + """ + + labeled_frame = instance.frame + if labeled_frame is None: + raise ValueError( + f"Instance {instance} is not linked to a LabeledFrame. " + "Cannot add to FrameGroup." + ) + + frame_idx = labeled_frame.frame_idx + if frame_idx != self.frame_idx: + raise ValueError( + f"Instance {instance} frame index {frame_idx} does not match " + f"FrameGroup frame index {self.frame_idx}." + ) + + labeled_frame_fg = self.get_labeled_frame(camera=camera) + if labeled_frame_fg is None: + pass + elif labeled_frame != labeled_frame_fg: + raise ValueError( + f"Instance's LabeledFrame {labeled_frame} is not the same as " + f"FrameGroup's LabeledFrame {labeled_frame_fg} for Camcorder {camera}." + ) + + def _raise_if_instance_group_incompatible(self, instance_group: InstanceGroup): + """Raise a ValueError if `InstanceGroup` is incompatible with `FrameGroup`. + + An `InstanceGroup` is incompatible if + - the `frame_idx` does not match the `FrameGroup`'s `frame_idx`. + - the `InstanceGroup.name` is already used in the `FrameGroup`. + + Args: + instance_group: `InstanceGroup` to check compatibility of. + + Raises: + ValueError: If the `InstanceGroup` is incompatible with the `FrameGroup`. + """ + + if instance_group.frame_idx != self.frame_idx: + raise ValueError( + f"InstanceGroup {instance_group} frame index {instance_group.frame_idx} " + f"does not match FrameGroup frame index {self.frame_idx}." + ) + + if instance_group.name in self._instance_group_name_registry: + raise ValueError( + f"InstanceGroup name {instance_group.name} is already registered in " + "this FrameGroup's list of names: " + f"{self._instance_group_name_registry}\n" + "Please use a unique name for the new InstanceGroup." + ) + + def _raise_if_instance_group_not_in_frame_group( + self, instance_group: InstanceGroup + ): + """Raise a ValueError if `InstanceGroup` is not in this `FrameGroup`.""" + + if instance_group not in self.instance_groups: + raise ValueError( + f"InstanceGroup {instance_group} is not in this FrameGroup: " + f"{self.instance_groups}." + ) + + @classmethod + def from_instance_groups( + cls, + session: RecordingSession, + instance_groups: List["InstanceGroup"], + ) -> Optional["FrameGroup"]: + """Creates a `FrameGroup` object from an `InstanceGroup` object. + + Args: + session: `RecordingSession` object. + instance_groups: A list of `InstanceGroup` objects. + + Returns: + `FrameGroup` object or None if no "real" (determined by `frame_idx` other + than None) frames found. + """ + + if len(instance_groups) == 0: + raise ValueError("instance_groups must contain at least one InstanceGroup") + + # Get frame index from first instance group + frame_idx = instance_groups[0].frame_idx + + # Create and return `FrameGroup` object + return cls( + frame_idx=frame_idx, instance_groups=instance_groups, session=session + ) + + def enforce_frame_idx_unique( + self, session: RecordingSession, frame_idx: int + ) -> bool: + """Enforces that all frame indices are unique in `RecordingSession`. + + Removes existing `FrameGroup` object from the + `RecordingSession._frame_group_by_frame_idx`. + + Args: + session: `RecordingSession` object. + frame_idx: Frame index. + """ + + if session.frame_groups.get(frame_idx, None) is not None: + # Remove existing `FrameGroup` object from the + # `RecordingSession._frame_group_by_frame_idx` + logger.warning( + f"Frame index {frame_idx} for FrameGroup already exists in this " + "RecordingSession. Overwriting." + ) + session.frame_groups.pop(frame_idx) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 1221cadae..185b235a8 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -274,6 +274,12 @@ def remove_video(self, video: Video): del self._lf_by_video[video] if video in self._frame_idx_map: del self._frame_idx_map[video] + self.remove_session_video(video=video) + + def remove_session_video(self, video: Video): + """Remove video from session in cache.""" + + # TODO(LM): Also remove LabeledFrames from frame_group if video in self._session_by_video: del self._session_by_video[video] @@ -442,8 +448,7 @@ def _del_count_cache(self, video, video_idx, frame_idx, type_key: str): @attr.s(auto_attribs=True, repr=False, str=False) class Labels(MutableSequence): - """ - The :class:`Labels` class collects the data for a SLEAP project. + """The :class:`Labels` class collects the data for a SLEAP project. This class is front-end for all interactions with loading, writing, and modifying these labels. The actual storage backend for the data @@ -967,6 +972,9 @@ def remove_frame(self, lf: LabeledFrame, update_cache: bool = True): update_cache: If True, update the internal frame cache. If False, cache update can be postponed (useful when removing many frames). """ + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + self.labeled_frames.remove(lf) if update_cache: self._cache.remove_frame(lf) @@ -977,6 +985,8 @@ def remove_frames(self, lfs: List[LabeledFrame]): Args: lfs: A sequence of labeled frames to remove. """ + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. to_remove = set(lfs) self.labeled_frames = [lf for lf in self.labeled_frames if lf not in to_remove] self.update_cache() @@ -1000,6 +1010,8 @@ def remove_empty_instances(self, keep_empty_frames: bool = True): def remove_empty_frames(self): """Remove frames with no instances.""" + + # TODO(LM): Remove LabeledFrame from any frame groups it's in. self.labeled_frames = [ lf for lf in self.labeled_frames if len(lf.instances) > 0 ] @@ -1657,7 +1669,8 @@ def remove_video(self, video: Video): # Delete video self.videos.remove(video) - self._cache.remove_video(video) + self.remove_session_video(video=video) + self._cache.remove_video(video=video) def add_session(self, session: RecordingSession): """Add a recording session to the labels. @@ -1702,16 +1715,21 @@ def get_session(self, video: Video) -> Optional[RecordingSession]: """ return self._cache._session_by_video.get(video, None) - def remove_session_video(self, session: RecordingSession, video: Video): - """Remove a video from a recording session. + def remove_session_video(self, video: Video): + """Remove a video from its linked recording session (if any). Args: - session: `RecordingSession` instance video: `Video` instance """ - self._cache._session_by_video.pop(video, None) - if video in session.videos: + session = self.get_session(video) + + if session is None: + return + + # Need to remove from cache first to avoid circular reference + self._cache.remove_session_video(video=video) + if session.get_camera(video) is not None: session.remove_video(video) @classmethod @@ -1845,6 +1863,8 @@ def remove_user_instances(self, new_labels: Optional["Labels"] = None): # Keep only labeled frames with no conflicting predictions. self.labeled_frames = keep_lfs + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + def remove_predictions(self, new_labels: Optional["Labels"] = None): """Clear predicted instances from the labels. @@ -1881,6 +1901,8 @@ def remove_predictions(self, new_labels: Optional["Labels"] = None): # Keep only labeled frames with no conflicting predictions. self.labeled_frames = keep_lfs + # TODO(LM): Remove LabeledFrame from any frame groups it's in. + def remove_untracked_instances(self, remove_empty_frames: bool = True): """Remove instances that do not have a track assignment. @@ -1998,6 +2020,7 @@ def merge_matching_frames(self, video: Optional[Video] = None): for vid in {lf.video for lf in self.labeled_frames}: self.merge_matching_frames(video=vid) else: + # TODO(LM): Remove LabeledFrame from any frame groups it's in. self.labeled_frames = LabeledFrame.merge_frames( self.labeled_frames, video=video ) diff --git a/tests/data/cameras/minimal_session/min_session_user_labeled.slp b/tests/data/cameras/minimal_session/min_session_user_labeled.slp new file mode 100644 index 000000000..c7d8fb2dd Binary files /dev/null and b/tests/data/cameras/minimal_session/min_session_user_labeled.slp differ diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index 790f29946..529087088 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -278,3 +278,11 @@ def multiview_min_session_labels(): "tests/data/cameras/minimal_session/min_session.slp", video_search=["tests/data/videos/"], ) + + +@pytest.fixture +def multiview_min_session_user_labels(): + return Labels.load_file( + "tests/data/cameras/minimal_session/min_session_user_labeled.slp", + video_search=["tests/data/videos/"], + ) diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index c20af8614..78219c52c 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -220,7 +220,6 @@ def assert_videos_written(num_videos: int, labels_path: str = None): context.state["filename"] = None if csv: - context.state["filename"] = centered_pair_predictions_hdf5_path params = {"all_videos": True, "csv": csv} @@ -955,344 +954,3 @@ def test_AddSession( assert len(labels.sessions) == 2 assert context.state["session"] is session assert labels.sessions[1] is not session - - -def test_triangulate_session_get_all_views_at_frame( - multiview_min_session_labels: Labels, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - frame_idx = lf.frame_idx - - # Test with no cams_to_include, expect views from all linked cameras - views = TriangulateSession.get_all_views_at_frame(session, frame_idx) - assert len(views) == len(session.linked_cameras) - for cam in session.linked_cameras: - assert views[cam].frame_idx == frame_idx - assert views[cam].video == session[cam] - - # Test with cams_to_include, expect views from only those cameras - cams_to_include = session.linked_cameras[0:2] - views = TriangulateSession.get_all_views_at_frame( - session, frame_idx, cams_to_include=cams_to_include - ) - assert len(views) == len(cams_to_include) - for cam in cams_to_include: - assert views[cam].frame_idx == frame_idx - assert views[cam].video == session[cam] - - -def test_triangulate_session_get_instances_across_views( - multiview_min_session_labels: Labels, -): - - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test get_instances_across_views - lf: LabeledFrame = labels[0] - track = labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track - ) - assert len(instances) == len(session.videos) - for vid in session.videos: - cam = session[vid] - inst = instances[cam] - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid - - # Try with excluding cam views - lf: LabeledFrame = labels[2] - track = labels.tracks[1] - cams_to_include = session.linked_cameras[:4] - videos_to_include: Dict[ - Camcorder, Video - ] = session.get_videos_from_selected_cameras(cams_to_include=cams_to_include) - assert len(cams_to_include) == 4 - assert len(videos_to_include) == len(cams_to_include) - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - track=track, - cams_to_include=cams_to_include, - ) - assert len(instances) == len( - videos_to_include - ) # May not be true if no instances at that frame - for cam, vid in videos_to_include.items(): - inst = instances[cam] - assert inst.frame_idx == lf.frame_idx - assert inst.track == track - assert inst.video == vid - - # Try with only a single view - cams_to_include = [session.linked_cameras[0]] - with pytest.raises(ValueError): - instances = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, - ) - - # Try with multiple views, but not enough instances - track = labels.tracks[1] - cams_to_include = session.linked_cameras[4:6] - with pytest.raises(ValueError): - instances = TriangulateSession.get_instances_across_views( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - track=track, - require_multiple_views=True, - ) - - -def test_triangulate_session_get_and_verify_enough_instances( - multiview_min_session_labels: Labels, - caplog, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - lf = labels.labeled_frames[0] - track = labels.tracks[1] - - # Test with no cams_to_include, expect views from all linked cameras - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_idx=lf.frame_idx, track=track - ) - assert len(instances) == 6 # Some views don't have an instance at this track - for cam in session.linked_cameras: - if cam.name in ["side", "sideL"]: # The views that don't have an instance - continue - assert instances[cam].frame_idx == lf.frame_idx - assert instances[cam].track == track - assert instances[cam].video == session[cam] - - # Test with cams_to_include, expect views from only those cameras - cams_to_include = session.linked_cameras[-2:] - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, - frame_idx=lf.frame_idx, - cams_to_include=cams_to_include, - track=track, - ) - assert len(instances) == len(cams_to_include) - for cam in cams_to_include: - assert instances[cam].frame_idx == lf.frame_idx - assert instances[cam].track == track - assert instances[cam].video == session[cam] - - # Test with not enough instances, expect views from only those cameras - cams_to_include = session.linked_cameras[0:2] - instances = TriangulateSession.get_and_verify_enough_instances( - session=session, frame_idx=lf.frame_idx, cams_to_include=cams_to_include - ) - assert isinstance(instances, bool) - assert not instances - messages = "".join([rec.message for rec in caplog.records]) - assert "One or less instances found for frame" in messages - - -def test_triangulate_session_verify_enough_views( - multiview_min_session_labels: Labels, caplog -): - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views - enough_views = TriangulateSession.verify_enough_views( - session=session, show_dialog=False - ) - assert enough_views - messages = "".join([rec.message for rec in caplog.records]) - assert len(messages) == 0 - caplog.clear() - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - enough_views = TriangulateSession.verify_enough_views( - session=session, cams_to_include=cams_to_include, show_dialog=False - ) - assert not enough_views - messages = "".join([rec.message for rec in caplog.records]) - assert "One or less cameras available." in messages - - -def test_triangulate_session_verify_views_and_instances( - multiview_min_session_labels: Labels, -): - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views and instances - lf = labels.labeled_frames[0] - instance = lf.instances[0] - - context = CommandContext.from_labels(labels) - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "show_dialog": False, - } - enough_views = TriangulateSession.verify_views_and_instances(context, params) - assert enough_views - assert "instances" in params - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "cams_to_include": cams_to_include, - "show_dialog": False, - } - enough_views = TriangulateSession.verify_views_and_instances(context, params) - assert not enough_views - assert "instances" not in params - - -def test_triangulate_session_calculate_reprojected_points( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.calculate_reprojected_points`.""" - - session = multiview_min_session_labels.sessions[0] - lf: LabeledFrame = multiview_min_session_labels[0] - track = multiview_min_session_labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track - ) - instances_and_coords = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - - # Check that we get the same number of instances as input - assert len(instances) == len(list(instances_and_coords)) - - # Check that each instance has the same number of points - for inst, inst_coords in instances_and_coords: - assert inst_coords.shape[1] == len(inst.skeleton) # (1, 15, 2) - - -def test_triangulate_session_get_instances_matrices( - multiview_min_session_labels: Labels, -): - """Test `TriangulateSession.get_instance_matrices`.""" - labels = multiview_min_session_labels - session = labels.sessions[0] - lf: LabeledFrame = labels[0] - track = labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track - ) - instances_matrices = TriangulateSession.get_instances_matrices( - instances_ordered=instances.values() - ) - - # Verify shape - n_views = len(instances) - n_frames = 1 - n_tracks = 1 - n_nodes = len(labels.skeleton) - assert instances_matrices.shape == (n_views, n_frames, n_tracks, n_nodes, 2) - - -def test_triangulate_session_update_instances(multiview_min_session_labels: Labels): - """Test `RecordingSession.update_instances`.""" - - # Test update_instances - session = multiview_min_session_labels.sessions[0] - lf: LabeledFrame = multiview_min_session_labels[0] - track = multiview_min_session_labels.tracks[0] - instances: Dict[ - Camcorder, Instance - ] = TriangulateSession.get_instances_across_views( - session=session, frame_idx=lf.frame_idx, track=track - ) - instances_and_coordinates = TriangulateSession.calculate_reprojected_points( - session=session, instances=instances - ) - for inst, inst_coords in instances_and_coordinates: - assert inst_coords.shape == (1, len(inst.skeleton), 2) # Tracks, Nodes, 2 - # Assert coord are different from original - assert not np.array_equal(inst_coords, inst.points_array) - - # Just run for code coverage testing, do not test output here (race condition) - # (see "functional core, imperative shell" pattern) - TriangulateSession.update_instances(session=session, instances=instances) - - -def test_triangulate_session_do_action(multiview_min_session_labels: Labels): - """Test `TriangulateSession.do_action`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - - # Test with enough views and instances - lf = labels.labeled_frames[0] - instance = lf.instances[0] - - context = CommandContext.from_labels(labels) - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "ask_again": True, - } - TriangulateSession.do_action(context, params) - - # Test with not enough views - cams_to_include = [session.linked_cameras[0]] - params = { - "video": session.videos[0], - "session": session, - "frame_idx": lf.frame_idx, - "instance": instance, - "cams_to_include": cams_to_include, - "ask_again": True, - } - TriangulateSession.do_action(context, params) - - -def test_triangulate_session(multiview_min_session_labels: Labels): - """Test `TriangulateSession`.""" - - labels = multiview_min_session_labels - session = labels.sessions[0] - video = session.videos[0] - lf = labels.labeled_frames[0] - instance = lf.instances[0] - context = CommandContext.from_labels(labels) - - # Test with enough views and instances so we don't get any GUI pop-ups - context.triangulateSession( - frame_idx=lf.frame_idx, - video=video, - instance=instance, - session=session, - ) - - # Test with using state to gather params - context.state["session"] = session - context.state["video"] = video - context.state["instance"] = instance - context.state["frame_idx"] = lf.frame_idx - context.triangulateSession() diff --git a/tests/io/test_cameras.py b/tests/io/test_cameras.py index 35ecaa50e..16fbfd0a7 100644 --- a/tests/io/test_cameras.py +++ b/tests/io/test_cameras.py @@ -1,12 +1,18 @@ """Module to test functions in `sleap.io.cameras`.""" -from typing import Dict, List +from typing import Dict, List, Tuple, Union import numpy as np import pytest -from sleap.io.cameras import Camcorder, CameraCluster, RecordingSession -from sleap.io.dataset import Instance, LabeledFrame, Labels, LabelsDataCache +from sleap.io.cameras import ( + Camcorder, + CameraCluster, + InstanceGroup, + FrameGroup, + RecordingSession, +) +from sleap.io.dataset import Instance, Labels from sleap.io.video import Video @@ -164,6 +170,13 @@ def test_recording_session( # Test __repr__ assert f"{session.__class__.__name__}(" in repr(session) + # Test new_frame_group + frame_group = session.new_frame_group(frame_idx=0) + assert isinstance(frame_group, FrameGroup) + assert frame_group.session == session + assert frame_group.frame_idx == 0 + assert frame_group == session.frame_groups[0] + # Test add_video camcorder = session.camera_cluster.cameras[0] session.add_video(centered_pair_vid, camcorder) @@ -280,3 +293,177 @@ def test_recording_session_remove_video(multiview_min_session_labels: Labels): session.remove_video(video) assert labels_cache._session_by_video.get(video, None) is None assert video not in session.videos + + +# TODO(LM): Remove after adding method to (de)seralize `InstanceGroup` +def create_instance_group( + labels: Labels, + frame_idx: int, + add_dummy: bool = False, +) -> Union[ + InstanceGroup, Tuple[InstanceGroup, Dict[Camcorder, Instance], Instance, Camcorder] +]: + """Create an `InstanceGroup` from a `Labels` object. + + Args: + labels: The `Labels` object to use. + frame_idx: The frame index to use. + add_dummy: Whether to add a dummy instance to the `InstanceGroup`. + + Returns: + The `InstanceGroup` object. + """ + + session = labels.sessions[0] + + lf = labels.labeled_frames[0] + instance = lf.instances[0] + + instance_by_camera = {} + for cam in session.linked_cameras: + video = session.get_video(cam) + lfs_in_view = labels.find(video=video, frame_idx=frame_idx) + if len(lfs_in_view) > 0: + instance = lfs_in_view[0].instances[0] + instance_by_camera[cam] = instance + + # Add a dummy instance to make sure it gets ignored + if add_dummy: + dummy_instance = Instance.from_numpy( + np.full( + shape=(len(instance.skeleton.nodes), 2), + fill_value=np.nan, + ), + skeleton=instance.skeleton, + ) + instance_by_camera[cam] = dummy_instance + + instance_group = InstanceGroup.from_dict( + d=instance_by_camera, name="test_instance_group", name_registry={} + ) + return ( + (instance_group, instance_by_camera, dummy_instance, cam) + if add_dummy + else instance_group + ) + + +def test_instance_group(multiview_min_session_labels: Labels): + """Test `InstanceGroup` data structure.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + camera_cluster = session.camera_cluster + + lf = labels.labeled_frames[0] + frame_idx = lf.frame_idx + + # Test `from_dict` + instance_group, instance_by_camera, dummy_instance, cam = create_instance_group( + labels=labels, frame_idx=frame_idx, add_dummy=True + ) + assert isinstance(instance_group, InstanceGroup) + assert instance_group.frame_idx == frame_idx + assert instance_group.camera_cluster == camera_cluster + for camera in session.linked_cameras: + if camera == cam: + assert instance_by_camera[camera] == dummy_instance + assert camera not in instance_group.cameras + else: + instance = instance_group[camera] + assert isinstance(instance, Instance) + assert instance_group[camera] == instance_by_camera[camera] + assert instance_group[instance] == camera + + # Test `__repr__` + print(instance_group) + + # Test `__len__` + assert len(instance_group) == len(instance_by_camera) - 1 + + # Test `get_cam` + assert instance_group.get_cam(dummy_instance) is None + + # Test `get_instance` + assert instance_group.get_instance(cam) is None + + # Test `instances` property + assert len(instance_group.instances) == len(instance_by_camera) - 1 + + # Test `cameras` property + assert len(instance_group.cameras) == len(instance_by_camera) - 1 + + # Test `__getitem__` with `int` key + assert isinstance(instance_group[0], Instance) + with pytest.raises(KeyError): + instance_group[len(instance_group)] + + # Test `_dummy_instance` property + assert ( + instance_group.dummy_instance.skeleton == instance_group.instances[0].skeleton + ) + assert isinstance(instance_group.dummy_instance, Instance) + + # Test `numpy` method + instance_group_numpy = instance_group.numpy() + assert isinstance(instance_group_numpy, np.ndarray) + n_views, n_nodes, n_coords = instance_group_numpy.shape + assert n_views == len(instance_group.camera_cluster.cameras) + assert n_nodes == len(instance_group.dummy_instance.skeleton.nodes) + assert n_coords == 2 + + # Test `update_points` method + instance_group.update_points(np.full((n_views, n_nodes, n_coords), 0)) + instance_group_numpy = instance_group.numpy() + np.nan_to_num(instance_group_numpy, nan=0) + assert np.all(np.nan_to_num(instance_group_numpy, nan=0) == 0) + + # Populate with only dummy instance and test `from_dict` + instance_by_camera = {cam: dummy_instance} + with pytest.raises(ValueError): + instance_group = InstanceGroup.from_dict( + d=instance_by_camera, name="test_instance_group", name_registry={} + ) + + +def test_frame_group(multiview_min_session_labels: Labels): + """Test `FrameGroup` data structure.""" + + labels = multiview_min_session_labels + session = labels.sessions[0] + + # Test `from_instance_groups` from list of instance groups + frame_idx_1 = 0 + instance_group = create_instance_group(labels=labels, frame_idx=frame_idx_1) + instance_groups: List[InstanceGroup] = [instance_group] + frame_group_1 = FrameGroup.from_instance_groups( + session=session, instance_groups=instance_groups + ) + assert isinstance(frame_group_1, FrameGroup) + assert frame_idx_1 in session.frame_groups + assert len(session.frame_groups) == 1 + assert frame_group_1 == session.frame_groups[frame_idx_1] + assert len(frame_group_1.instance_groups) == 1 + + # Test `RecordingSession.frame_groups` property + frame_idx_2 = 1 + instance_group = create_instance_group(labels=labels, frame_idx=frame_idx_2) + instance_groups: List[InstanceGroup] = [instance_group] + frame_group_2 = FrameGroup.from_instance_groups( + session=session, instance_groups=instance_groups + ) + assert isinstance(frame_group_2, FrameGroup) + assert frame_idx_2 in session.frame_groups + assert len(session.frame_groups) == 2 + assert frame_group_2 == session.frame_groups[frame_idx_2] + assert len(frame_group_2.instance_groups) == 1 + + frame_idx_3 = 2 + frame_group_3 = FrameGroup(frame_idx=frame_idx_3, session=session) + assert isinstance(frame_group_3, FrameGroup) + assert frame_idx_3 in session.frame_groups + assert len(session.frame_groups) == 3 + assert frame_group_3 == session.frame_groups[frame_idx_3] + assert len(frame_group_3.instance_groups) == 0 + + # TODO(LM): Test underlying dictionaries more thoroughly diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 020dd64ed..a544b7703 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1030,7 +1030,7 @@ def test_add_session_and_update_session( assert labels._cache._session_by_video == {video: session} assert labels.get_session(video) == session - labels.remove_session_video(session, video) + labels.remove_session_video(video=video) assert video not in session.videos assert video not in labels._cache._session_by_video