Skip to content

Commit

Permalink
refactor(maybe revert)t: minor changes+refactor of object comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
jws-1 committed Jul 14, 2024
1 parent 9f6f7ee commit d89f64d
Showing 1 changed file with 104 additions and 238 deletions.
342 changes: 104 additions & 238 deletions tasks/gpsr/src/gpsr/states/object_comparison.py
Original file line number Diff line number Diff line change
@@ -1,280 +1,146 @@
#!/usr/bin/env python3
import smach
from lasr_skills import Detect3DInArea, Say
from lasr_skills import Detect3D
from lasr_skills import Detect
import smach_ros
from shapely.geometry.polygon import Polygon
from typing import List, Union, Dict
import rospy
from typing import List, Literal, Optional
from lasr_vision_msgs.srv import CroppedDetection, CroppedDetectionRequest
from lasr_vision_msgs.msg import CDRequest


class ObjectComparison(smach.StateMachine):
class CountObjectTypes(smach.State):
def __init__(self, area_polygon: Polygon):
smach.State.__init__(
self,
outcomes=["succeeded", "failed"],
input_keys=["detections_3d"],
output_keys=["detections_types", "object_dict"],
)

def count_types(self, detections):
object_counts = {}
for detection in detections:
object_type = detection.name
if object_type in object_counts:
object_counts[object_type] += 1
else:
object_counts[object_type] = 1
return object_counts
_object_comp_list: List[str] = [
"biggest",
"largets",
"smallest",
"heaviest",
"lightest",
"thinnest",
]

def execute(self, userdata):
filtered_detections = userdata.detections_3d
rospy.loginfo(filtered_detections)
object_counts = self.count_types(filtered_detections.detected_objects)
userdata.object_dict = object_counts
userdata.detections_types = list(object_counts.keys())
return "succeeded"
_biggest_list: List[str] = []

class CountCategory(smach.State):
def __init__(self, object_weight: Union[List[dict], None] = None):
smach.State.__init__(
self,
outcomes=["succeeded", "failed"],
input_keys=["object_dict"],
output_keys=["category_dict", "detections_categories"],
)
self.object_weight = object_weight
_largest_list: List[str] = []

def count_category(self, object_dictionary, count_object):
category_count = {category: 0 for category in object_dictionary.keys()}
for category, items in object_dictionary.items():
for obj in count_object.keys():
if obj in items:
category_count[category] += count_object[obj]
return category_count
_smallest_list: List[str] = []

def execute(self, userdata):
detected_objects = userdata.object_dict
counts = self.count_category(self.object_weight, detected_objects)
category_counts = {
key: value for key, value in counts.items() if value != 0
}
userdata.category_dict = category_counts
userdata.detections_categories = list(category_counts.keys())
return "succeeded"
_heaviest_list: List[str] = []

class ObjectWeight(smach.State):
def __init__(self, object_weight: Union[List[dict], None] = None):
smach.State.__init__(
self,
outcomes=["succeeded", "failed"],
input_keys=["detections_types"],
output_keys=["sorted_weights"],
)
self.object_weight = object_weight
_lightest_list: List[str] = []

def get_weight(self, detections, average_weights):
weight_dict = {}
for category, items in average_weights.items():
for i in detections:
if i in items:
weight = items[i]
weight_dict[i] = weight
return weight_dict
_thinnest_list: List[str] = []

def execute(self, userdata):
weights_dict = self.get_weight(
userdata.detections_types, self.object_weight
)
sorted_weights = sorted(
weights_dict.items(), key=lambda item: item[1], reverse=True
_query: Literal[
"biggest", "largest", "smallest", "heaviest", "lightest", "thinnest"
]

def _compare_objects(self, userdata):
detections = userdata.responses[0].detections_3d
if self._query == "biggest":
biggest_object = next(
(obj for obj in self._biggest_list if obj in detections), None
)
userdata.sorted_weights = sorted_weights
if not biggest_object:
return "failed"
userdata.query_result = biggest_object
return "succeeded"

class ObjectSize(smach.State):
def __init__(self):
smach.State.__init__(
self,
outcomes=["succeeded", "failed"],
input_keys=["object_dict", "detections_3d"],
output_keys=["sorted_size"],
elif self._query == "largest":
largest_object = next(
(obj for obj in self._largest_list if obj in detections), None
)

def property_size_calculation(self, detections, result):
area = dict()
for i in detections:
for object in result.detected_objects:
if i == object.name:
area[i] = object.xywh[2] * object.xywh[3]
return area

def execute(self, userdata):
detections_types = list(userdata.object_dict.keys())
area_dict = self.property_size_calculation(
detections_types, userdata.detections_3d
if not largest_object:
return "failed"
userdata.query_result = largest_object
return "succeeded"
elif self._query == "smallest":
smallest_object = next(
(obj for obj in self._smallest_list if obj in detections), None
)
sorted_size = sorted(
area_dict.items(), key=lambda item: item[1], reverse=True
if not smallest_object:
return "failed"
userdata.query_result = smallest_object
return "succeeded"
elif self._query == "heaviest":
heaviest_object = next(
(obj for obj in self._heaviest_list if obj in detections), None
)
userdata.sorted_size = sorted_size
if not heaviest_object:
return "failed"
userdata.query_result = heaviest_object
return "succeeded"

class DecideOperation(smach.State):
def __init__(self):
smach.State.__init__(
self,
outcomes=[
"do_object_count",
"do_category_count" "do_weight",
"do_size",
"failed",
],
input_keys=["operation_label"],
elif self._query == "lightest":
lightest_object = next(
(obj for obj in self._lightest_list if obj in detections), None
)

def execute(self, userdata):
if userdata.operation_label == "object":
return "do_object_count"
elif userdata.operation_label == "categoty":
return "do_category_count"
elif userdata.operation_label == "weight":
return "do_weight"
elif userdata.operation_label == "size":
return "do_size"
else:
if not lightest_object:
return "failed"

class SayResult(smach.State):
def __init__(self):
smach.State.__init__(
self,
outcomes=["succeeded", "failed"],
input_keys=[
"operation_label",
"detections_types",
"detections_categories",
"sorted_size",
"sorted_weights",
],
output_keys=["say_text"],
userdata.query_result = lightest_object
return "succeeded"
elif self._query == "thinnest":
thinnest_object = next(
(obj for obj in self._thinnest_list if obj in detections), None
)

def execute(self, userdata):
try:
if userdata.operation_label == "count":
object_count = len(userdata.detections_types)
userdata.say_text = f"There are {object_count} objects"
elif userdata.operation_label == "category":
category_count = len(userdata.detections_categories)
userdata.say_text = f"There are {category_count} objects"
elif userdata.operation_label == "weight":
heaviest_object = userdata.sorted_weights[0][0]
userdata.say_text = f"The heaviest object is {heaviest_object}"
elif userdata.operation_label == "size":
biggest_object = userdata.sorted_size[0][0]
userdata.say_text = f"The biggest object is {biggest_object}"
else:
return "failed"
return "succeeded"
except Exception as e:
rospy.logerr(str(e))
if not thinnest_object:
return "failed"
userdata.query_result = thinnest_object
return "succeeded"
else:
return "failed"

def __init__(
self,
area_polygon: Polygon,
operation_label: str,
depth_topic: str = "/xtion/depth_registered/points",
model: str = "yolov8n-seg.pt",
filter: Union[List[str], None] = None,
query: Literal[
"biggest", "largest", "smallest", "heaviest", "lightest", "thinnest"
],
area_polygon: Polygon, # input key
model: str = "yolov8x-seg.pt",
objects: Optional[List[str]] = None,
confidence: float = 0.5,
nms: float = 0.3,
object_weight: Union[List[dict], None] = None,
):

self._query = query

smach.StateMachine.__init__(
self,
outcomes=["succeeded", "failed"],
output_keys=["detections_3d", "object_dict", "say_text"],
self, outcomes=["succeeded", "failed"], output_keys=["query_result"]
)

# Set the operation label in the userdata to decide which task to perform
self.userdata.operation_label = operation_label

with self:

smach.StateMachine.add(
"DETECT_OBJECTS_3D",
Detect3DInArea(
depth_topic=depth_topic,
model=model,
filter=filter,
confidence=confidence,
nms=nms,
smach_ros.ServiceState(
"/vision/cropped_detection",
CroppedDetection,
request=CroppedDetectionRequest(
requests=[
CDRequest(
method="closest",
use_mask=True,
yolo_model=model,
yolo_model_confidence=confidence,
yolo_nms_threshold=nms,
return_sensor_reading=False,
object_names=objects,
polygons=[area_polygon],
)
]
),
output_keys=["responses"],
response_slots=["responses"],
),
transitions={"succeeded": "COUNTOBJECTS", "failed": "failed"},
)

smach.StateMachine.add(
"COUNTOBJECTS",
self.CountObjectTypes(area_polygon),
transitions={"succeeded": "DECIDE_OPERATION", "failed": "failed"},
)

smach.StateMachine.add(
"DECIDE_OPERATION",
self.DecideOperation(),
transitions={
"do_object_count": "SAY_RESULT",
"do_category_count": "COUNTCATEGORY",
"do_weight": "GETWEIGHT",
"do_size": "GETSIZE",
"failed": "failed",
"succeeded": "COMPARE_OBJECTS",
"aborted": "failed",
"preempted": "failed",
},
)

smach.StateMachine.add(
"COUNTCATEGORY",
self.CountCategory(object_weight=object_weight),
transitions={"succeeded": "SAY_RESULT", "failed": "failed"},
)

smach.StateMachine.add(
"GETWEIGHT",
self.ObjectWeight(object_weight=object_weight),
transitions={"succeeded": "SAY_RESULT", "failed": "failed"},
)

smach.StateMachine.add(
"GETSIZE",
self.ObjectSize(),
transitions={"succeeded": "SAY_RESULT", "failed": "failed"},
)

smach.StateMachine.add(
"SAY_RESULT",
self.SayResult(),
transitions={"succeeded": "SAY", "failed": "failed"},
)

smach.StateMachine.add(
"SAY",
Say(),
transitions={"succeeded": "succeeded", "failed": "failed"},
remapping={"text": "say_text"},
"COMPARE_OBJECTS",
smach.CBState(
self._compare_objects,
outcomes=["succeeded", "failed"],
input_keys=["responses"],
),
transitions={"succeeded": "succeeded"},
)


# if __name__ == "__main__":
# import rospy
# from sensor_msgs.msg import PointCloud2

# rospy.init_node("test_object_comparison")
# weight = rospy.get_param("/Object_list/Object")

# polygon = Polygon([[-1, 0], [1, 0], [0, 1], [1, 1]])
# sm = ObjectComparison(Polygon(), filter=["bottle", "cup", "cola"], object=weight)
# sm.userdata.pcl_msg = rospy.wait_for_message(
# "/xtion/depth_registered/points", PointCloud2
# )
# sm.execute()

0 comments on commit d89f64d

Please sign in to comment.