From 07738425d169fd844e2730bfca1c6bf9f4ede095 Mon Sep 17 00:00:00 2001 From: Siyao Date: Mon, 15 Jul 2024 14:25:22 +0100 Subject: [PATCH] Revise on find the object --- .../src/gpsr/states/go_find_the_object.py | 164 +++++++----------- 1 file changed, 65 insertions(+), 99 deletions(-) diff --git a/tasks/gpsr/src/gpsr/states/go_find_the_object.py b/tasks/gpsr/src/gpsr/states/go_find_the_object.py index 9256005db..0671e6ff3 100755 --- a/tasks/gpsr/src/gpsr/states/go_find_the_object.py +++ b/tasks/gpsr/src/gpsr/states/go_find_the_object.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 import smach -from lasr_skills import Detect3D, GoToLocation +import smach_ros +from lasr_skills import Detect3D, GoToLocation, LookToPoint from shapely.geometry.polygon import Polygon from typing import List, Union -from geometry_msgs.msg import Pose, Point, Quaternion -from lasr_skills import Say, PlayMotion +from geometry_msgs.msg import Pose, Point, Quaternion, CDRequest +from lasr_skills import Say, LookToPoint +from lasr_vision_msgs.srv import CroppedDetection, CroppedDetectionRequest import rospy """ @@ -26,6 +28,8 @@ def __init__(self): def execute(self, userdata): userdata.location = userdata.waypoints[userdata.location_index] + userdata.polygon = userdata.polygons[userdata.location_index] + userdata.look_point = userdata.look_points[userdata.location_index] return "succeeded" class check_objects(smach.State): @@ -59,59 +63,44 @@ def execute(self, userdata): return "succeeded" - class cumulate_result(smach.State): - def __init__(self): - smach.State.__init__( - self, - outcomes=["succeeded", "failed"], - input_keys=["detection_result", "cumulated_result"], - output_keys=["cumulated_result"], - ) - - def execute(self, userdata): - if "cumulated_result" not in userdata: - userdata.cumulated_result = list() - userdata.cumulated_result.append(userdata.detection_result) - else: - userdata.cumulated_result.append( - userdata.detection_result - ) # the outcome of the 3d detection - return "succeeded" - class detection_result(smach.State): def __init__(self): smach.State.__init__( self, outcomes=["object_found", "object_not_found", "failed"], - input_keys=["cumulated_result"], - output_keys=["result", "cumulated_result"], + input_keys=["detection_result"], + output_keys=["result"], ) def execute(self, userdata): - if any(userdata.cumulated_result): + if any(userdata.detection_result): userdata.result = True - userdata.cumulated_result = list() return "object_found" else: - userdata.cumulated_result = list() userdata.result = False return "object_not_found" def __init__( self, - depth_topic: str = "/xtion/depth_registered/points", - model: str = "yolov8n-seg.pt", - filter: Union[List[str], None] = None, - waypoints: Union[List[Pose], None] = None, + model: str = "yolov8x-seg.pt", + filter: Union[List[str], None] = None, # <- input + waypoints: Union[List[Pose], None] = None, # <- input locations: Union[str, None] = None, confidence: float = 0.5, nms: float = 0.3, - motions: Union[List[str], None] = None, + poly_points: List[Polygon] = None, # polygen list of polygen for every location <- input + look_points: List[Point] = None # point to look at for each location <- input ): smach.StateMachine.__init__(self, outcomes=["succeeded", "failed"]) if waypoints is None and locations is None: raise ValueError("Either waypoints or location_param must be provided") + if poly_points is None or len(poly_points) != len(waypoints_to_iterate): + raise ValueError("Poly points must be provided for each waypoint") + + if look_points is None or len(look_points) != len(waypoints_to_iterate): + raise ValueError("Look points must be provided for each waypoint") + if waypoints is None: waypoints_to_iterate: List[Pose] = [] @@ -162,81 +151,58 @@ def __init__( "GO_TO_LOCATION", GoToLocation(), transitions={ - "succeeded": "INNER_ITERATOR", + "succeeded": "LOOK_AT_POINT", "failed": "failed", }, ) - inner_iterator = smach.Iterator( - outcomes=["succeeded", "failed", "continue"], - it=lambda: range(len(motions)), - it_label="motion_index", - input_keys=["waypoints", "location_index"], - output_keys=["cumulated_result"], - exhausted_outcome="succeeded", + smach.StateMachine.add( + "LOOK_AT_POINT", + LookToPoint(), + transitions={ + "succeeded": "DETECT_OBJECTS_3D", + "failed": "failed", + }, + remapping={"look_point": "look_point"}, ) - with inner_iterator: - inner_container_sm = smach.StateMachine( - outcomes=["succeeded", "failed", "continue"], - input_keys=["motion_index", "location_index", "waypoints"], - output_keys=["cumulated_result"], - ) - - with inner_container_sm: - smach.StateMachine.add( - "LOOK_AROUND", - PlayMotion(), - transitions={ - "succeeded": "DETECT_OBJECTS_3D", - "aborted": "DETECT_OBJECTS_3D", - "preempted": "failed", - }, - ) - - smach.StateMachine.add( - "DETECT_OBJECTS_3D", - Detect3D( - depth_topic=depth_topic, - model=model, - filter=filter, - confidence=confidence, - nms=nms, - ), - transitions={ - "succeeded": "RESULT", - "failed": "failed", - }, - ) - - smach.StateMachine.add( - "RESULT", - self.check_objects(), - transitions={ - "succeeded": "SAVE_RESULT", - "failed": "failed", - }, - ) - - smach.StateMachine.add( - "SAVE_RESULT", - self.cumulate_result(), - transitions={ - "succeeded": "continue", - "failed": "failed", - }, - ) - - inner_iterator.set_contained_state( - "INNER_CONTAINER_STATE", - inner_container_sm, - loop_outcomes=["continue"], - ) + smach.StateMachine.add( + "DETECT_OBJECTS_3D", + smach_ros.ServiceState( + "/vision/cropped_detection", + CroppedDetection, + request=CroppedDetectionRequest( + requests=[ + CDRequest( + method="closest", + use_mask=True, + yolo_model=model, + yolo_model_confidence=confidence, + yolo_nms_threshold=nms, + return_sensor_reading=False, + object_names= filter, + polygons= [], + ) + ] + ), + output_keys=["responses"], + response_slots=["responses"], + ), + transitions={ + "succeeded": "RESULT", + "aborted": "failed", + "preempted": "failed", + }, + remapping={"polygon": "polygon"} + ) smach.StateMachine.add( - "INNER_ITERATOR", - inner_iterator, - {"succeeded": "CHECK_RESULT", "failed": "failed"}, + "RESULT", + self.check_objects(), + transitions={ + "succeeded": "CHECK_RESULT", + "failed": "failed", + }, ) smach.StateMachine.add(