-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(maybe revert)t: minor changes+refactor of object comparison
- Loading branch information
Showing
1 changed file
with
104 additions
and
238 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,280 +1,146 @@ | ||
#!/usr/bin/env python3 | ||
import smach | ||
from lasr_skills import Detect3DInArea, Say | ||
from lasr_skills import Detect3D | ||
from lasr_skills import Detect | ||
import smach_ros | ||
from shapely.geometry.polygon import Polygon | ||
from typing import List, Union, Dict | ||
import rospy | ||
from typing import List, Literal, Optional | ||
from lasr_vision_msgs.srv import CroppedDetection, CroppedDetectionRequest | ||
from lasr_vision_msgs.msg import CDRequest | ||
|
||
|
||
class ObjectComparison(smach.StateMachine): | ||
class CountObjectTypes(smach.State): | ||
def __init__(self, area_polygon: Polygon): | ||
smach.State.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=["detections_3d"], | ||
output_keys=["detections_types", "object_dict"], | ||
) | ||
|
||
def count_types(self, detections): | ||
object_counts = {} | ||
for detection in detections: | ||
object_type = detection.name | ||
if object_type in object_counts: | ||
object_counts[object_type] += 1 | ||
else: | ||
object_counts[object_type] = 1 | ||
return object_counts | ||
_object_comp_list: List[str] = [ | ||
"biggest", | ||
"largets", | ||
"smallest", | ||
"heaviest", | ||
"lightest", | ||
"thinnest", | ||
] | ||
|
||
def execute(self, userdata): | ||
filtered_detections = userdata.detections_3d | ||
rospy.loginfo(filtered_detections) | ||
object_counts = self.count_types(filtered_detections.detected_objects) | ||
userdata.object_dict = object_counts | ||
userdata.detections_types = list(object_counts.keys()) | ||
return "succeeded" | ||
_biggest_list: List[str] = [] | ||
|
||
class CountCategory(smach.State): | ||
def __init__(self, object_weight: Union[List[dict], None] = None): | ||
smach.State.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=["object_dict"], | ||
output_keys=["category_dict", "detections_categories"], | ||
) | ||
self.object_weight = object_weight | ||
_largest_list: List[str] = [] | ||
|
||
def count_category(self, object_dictionary, count_object): | ||
category_count = {category: 0 for category in object_dictionary.keys()} | ||
for category, items in object_dictionary.items(): | ||
for obj in count_object.keys(): | ||
if obj in items: | ||
category_count[category] += count_object[obj] | ||
return category_count | ||
_smallest_list: List[str] = [] | ||
|
||
def execute(self, userdata): | ||
detected_objects = userdata.object_dict | ||
counts = self.count_category(self.object_weight, detected_objects) | ||
category_counts = { | ||
key: value for key, value in counts.items() if value != 0 | ||
} | ||
userdata.category_dict = category_counts | ||
userdata.detections_categories = list(category_counts.keys()) | ||
return "succeeded" | ||
_heaviest_list: List[str] = [] | ||
|
||
class ObjectWeight(smach.State): | ||
def __init__(self, object_weight: Union[List[dict], None] = None): | ||
smach.State.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=["detections_types"], | ||
output_keys=["sorted_weights"], | ||
) | ||
self.object_weight = object_weight | ||
_lightest_list: List[str] = [] | ||
|
||
def get_weight(self, detections, average_weights): | ||
weight_dict = {} | ||
for category, items in average_weights.items(): | ||
for i in detections: | ||
if i in items: | ||
weight = items[i] | ||
weight_dict[i] = weight | ||
return weight_dict | ||
_thinnest_list: List[str] = [] | ||
|
||
def execute(self, userdata): | ||
weights_dict = self.get_weight( | ||
userdata.detections_types, self.object_weight | ||
) | ||
sorted_weights = sorted( | ||
weights_dict.items(), key=lambda item: item[1], reverse=True | ||
_query: Literal[ | ||
"biggest", "largest", "smallest", "heaviest", "lightest", "thinnest" | ||
] | ||
|
||
def _compare_objects(self, userdata): | ||
detections = userdata.responses[0].detections_3d | ||
if self._query == "biggest": | ||
biggest_object = next( | ||
(obj for obj in self._biggest_list if obj in detections), None | ||
) | ||
userdata.sorted_weights = sorted_weights | ||
if not biggest_object: | ||
return "failed" | ||
userdata.query_result = biggest_object | ||
return "succeeded" | ||
|
||
class ObjectSize(smach.State): | ||
def __init__(self): | ||
smach.State.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=["object_dict", "detections_3d"], | ||
output_keys=["sorted_size"], | ||
elif self._query == "largest": | ||
largest_object = next( | ||
(obj for obj in self._largest_list if obj in detections), None | ||
) | ||
|
||
def property_size_calculation(self, detections, result): | ||
area = dict() | ||
for i in detections: | ||
for object in result.detected_objects: | ||
if i == object.name: | ||
area[i] = object.xywh[2] * object.xywh[3] | ||
return area | ||
|
||
def execute(self, userdata): | ||
detections_types = list(userdata.object_dict.keys()) | ||
area_dict = self.property_size_calculation( | ||
detections_types, userdata.detections_3d | ||
if not largest_object: | ||
return "failed" | ||
userdata.query_result = largest_object | ||
return "succeeded" | ||
elif self._query == "smallest": | ||
smallest_object = next( | ||
(obj for obj in self._smallest_list if obj in detections), None | ||
) | ||
sorted_size = sorted( | ||
area_dict.items(), key=lambda item: item[1], reverse=True | ||
if not smallest_object: | ||
return "failed" | ||
userdata.query_result = smallest_object | ||
return "succeeded" | ||
elif self._query == "heaviest": | ||
heaviest_object = next( | ||
(obj for obj in self._heaviest_list if obj in detections), None | ||
) | ||
userdata.sorted_size = sorted_size | ||
if not heaviest_object: | ||
return "failed" | ||
userdata.query_result = heaviest_object | ||
return "succeeded" | ||
|
||
class DecideOperation(smach.State): | ||
def __init__(self): | ||
smach.State.__init__( | ||
self, | ||
outcomes=[ | ||
"do_object_count", | ||
"do_category_count" "do_weight", | ||
"do_size", | ||
"failed", | ||
], | ||
input_keys=["operation_label"], | ||
elif self._query == "lightest": | ||
lightest_object = next( | ||
(obj for obj in self._lightest_list if obj in detections), None | ||
) | ||
|
||
def execute(self, userdata): | ||
if userdata.operation_label == "object": | ||
return "do_object_count" | ||
elif userdata.operation_label == "categoty": | ||
return "do_category_count" | ||
elif userdata.operation_label == "weight": | ||
return "do_weight" | ||
elif userdata.operation_label == "size": | ||
return "do_size" | ||
else: | ||
if not lightest_object: | ||
return "failed" | ||
|
||
class SayResult(smach.State): | ||
def __init__(self): | ||
smach.State.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=[ | ||
"operation_label", | ||
"detections_types", | ||
"detections_categories", | ||
"sorted_size", | ||
"sorted_weights", | ||
], | ||
output_keys=["say_text"], | ||
userdata.query_result = lightest_object | ||
return "succeeded" | ||
elif self._query == "thinnest": | ||
thinnest_object = next( | ||
(obj for obj in self._thinnest_list if obj in detections), None | ||
) | ||
|
||
def execute(self, userdata): | ||
try: | ||
if userdata.operation_label == "count": | ||
object_count = len(userdata.detections_types) | ||
userdata.say_text = f"There are {object_count} objects" | ||
elif userdata.operation_label == "category": | ||
category_count = len(userdata.detections_categories) | ||
userdata.say_text = f"There are {category_count} objects" | ||
elif userdata.operation_label == "weight": | ||
heaviest_object = userdata.sorted_weights[0][0] | ||
userdata.say_text = f"The heaviest object is {heaviest_object}" | ||
elif userdata.operation_label == "size": | ||
biggest_object = userdata.sorted_size[0][0] | ||
userdata.say_text = f"The biggest object is {biggest_object}" | ||
else: | ||
return "failed" | ||
return "succeeded" | ||
except Exception as e: | ||
rospy.logerr(str(e)) | ||
if not thinnest_object: | ||
return "failed" | ||
userdata.query_result = thinnest_object | ||
return "succeeded" | ||
else: | ||
return "failed" | ||
|
||
def __init__( | ||
self, | ||
area_polygon: Polygon, | ||
operation_label: str, | ||
depth_topic: str = "/xtion/depth_registered/points", | ||
model: str = "yolov8n-seg.pt", | ||
filter: Union[List[str], None] = None, | ||
query: Literal[ | ||
"biggest", "largest", "smallest", "heaviest", "lightest", "thinnest" | ||
], | ||
area_polygon: Polygon, # input key | ||
model: str = "yolov8x-seg.pt", | ||
objects: Optional[List[str]] = None, | ||
confidence: float = 0.5, | ||
nms: float = 0.3, | ||
object_weight: Union[List[dict], None] = None, | ||
): | ||
|
||
self._query = query | ||
|
||
smach.StateMachine.__init__( | ||
self, | ||
outcomes=["succeeded", "failed"], | ||
output_keys=["detections_3d", "object_dict", "say_text"], | ||
self, outcomes=["succeeded", "failed"], output_keys=["query_result"] | ||
) | ||
|
||
# Set the operation label in the userdata to decide which task to perform | ||
self.userdata.operation_label = operation_label | ||
|
||
with self: | ||
|
||
smach.StateMachine.add( | ||
"DETECT_OBJECTS_3D", | ||
Detect3DInArea( | ||
depth_topic=depth_topic, | ||
model=model, | ||
filter=filter, | ||
confidence=confidence, | ||
nms=nms, | ||
smach_ros.ServiceState( | ||
"/vision/cropped_detection", | ||
CroppedDetection, | ||
request=CroppedDetectionRequest( | ||
requests=[ | ||
CDRequest( | ||
method="closest", | ||
use_mask=True, | ||
yolo_model=model, | ||
yolo_model_confidence=confidence, | ||
yolo_nms_threshold=nms, | ||
return_sensor_reading=False, | ||
object_names=objects, | ||
polygons=[area_polygon], | ||
) | ||
] | ||
), | ||
output_keys=["responses"], | ||
response_slots=["responses"], | ||
), | ||
transitions={"succeeded": "COUNTOBJECTS", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"COUNTOBJECTS", | ||
self.CountObjectTypes(area_polygon), | ||
transitions={"succeeded": "DECIDE_OPERATION", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"DECIDE_OPERATION", | ||
self.DecideOperation(), | ||
transitions={ | ||
"do_object_count": "SAY_RESULT", | ||
"do_category_count": "COUNTCATEGORY", | ||
"do_weight": "GETWEIGHT", | ||
"do_size": "GETSIZE", | ||
"failed": "failed", | ||
"succeeded": "COMPARE_OBJECTS", | ||
"aborted": "failed", | ||
"preempted": "failed", | ||
}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"COUNTCATEGORY", | ||
self.CountCategory(object_weight=object_weight), | ||
transitions={"succeeded": "SAY_RESULT", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"GETWEIGHT", | ||
self.ObjectWeight(object_weight=object_weight), | ||
transitions={"succeeded": "SAY_RESULT", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"GETSIZE", | ||
self.ObjectSize(), | ||
transitions={"succeeded": "SAY_RESULT", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"SAY_RESULT", | ||
self.SayResult(), | ||
transitions={"succeeded": "SAY", "failed": "failed"}, | ||
) | ||
|
||
smach.StateMachine.add( | ||
"SAY", | ||
Say(), | ||
transitions={"succeeded": "succeeded", "failed": "failed"}, | ||
remapping={"text": "say_text"}, | ||
"COMPARE_OBJECTS", | ||
smach.CBState( | ||
self._compare_objects, | ||
outcomes=["succeeded", "failed"], | ||
input_keys=["responses"], | ||
), | ||
transitions={"succeeded": "succeeded"}, | ||
) | ||
|
||
|
||
# if __name__ == "__main__": | ||
# import rospy | ||
# from sensor_msgs.msg import PointCloud2 | ||
|
||
# rospy.init_node("test_object_comparison") | ||
# weight = rospy.get_param("/Object_list/Object") | ||
|
||
# polygon = Polygon([[-1, 0], [1, 0], [0, 1], [1, 1]]) | ||
# sm = ObjectComparison(Polygon(), filter=["bottle", "cup", "cola"], object=weight) | ||
# sm.userdata.pcl_msg = rospy.wait_for_message( | ||
# "/xtion/depth_registered/points", PointCloud2 | ||
# ) | ||
# sm.execute() |