diff --git a/node_script/node.py b/node_script/node.py
index 81cb306..c89034d 100755
--- a/node_script/node.py
+++ b/node_script/node.py
@@ -2,7 +2,7 @@
 from typing import Optional
 
 import rospy
-from jsk_recognition_msgs.msg import LabelArray, VectorArray
+from jsk_recognition_msgs.msg import LabelArray, RectArray, VectorArray
 from node_config import NodeConfig
 from rospy import Publisher, Subscriber
 from sensor_msgs.msg import Image
@@ -24,6 +24,7 @@ class DeticRosNode:
     pub_segimg: Optional[Publisher]
     pub_labels: Optional[Publisher]
     pub_score: Optional[Publisher]
+    pub_rects: Optional[Publisher]
     # otherwise, the following publisher will be used
     pub_info: Optional[Publisher]
 
@@ -45,6 +46,7 @@ def __init__(self, node_config: Optional[NodeConfig] = None):
             self.pub_segimg = rospy.Publisher('~segmentation', Image, queue_size=1)
             self.pub_labels = rospy.Publisher('~detected_classes', LabelArray, queue_size=1)
             self.pub_score = rospy.Publisher('~score', VectorArray, queue_size=1)
+            self.pub_rects = rospy.Publisher('~rects', RectArray, queue_size=1)
         else:
             self.pub_info = rospy.Publisher('~segmentation_info', SegmentationInfo, queue_size=1)
 
@@ -77,9 +79,11 @@ def callback_image(self, msg: Image):
             seg_img = raw_result.get_ros_segmentaion_image()
             labels = raw_result.get_label_array()
             scores = raw_result.get_score_array()
+            rects = raw_result.get_rect_array()
             self.pub_segimg.publish(seg_img)
             self.pub_labels.publish(labels)
             self.pub_score.publish(scores)
+            self.pub_rects.publish(rects)
         else:
             assert self.pub_info is not None
             seg_info = raw_result.get_segmentation_info()
diff --git a/node_script/wrapper.py b/node_script/wrapper.py
index 1f0c3ff..bddb706 100644
--- a/node_script/wrapper.py
+++ b/node_script/wrapper.py
@@ -10,7 +10,7 @@
 from cv_bridge import CvBridge
 from detectron2.utils.visualizer import VisImage
 from detic.predictor import VisualizationDemo
-from jsk_recognition_msgs.msg import Label, LabelArray, VectorArray
+from jsk_recognition_msgs.msg import Label, LabelArray, Rect, RectArray, VectorArray
 from node_config import NodeConfig
 from sensor_msgs.msg import Image
 from std_msgs.msg import Header
@@ -28,6 +28,7 @@ class InferenceRawResult:
     visualization: Optional[VisImage]
     header: Header
     detected_class_names: List[str]
+    boxes: List[List[float]]
 
     def get_ros_segmentaion_image(self) -> Image:
         seg_img = _cv_bridge.cv2_to_imgmsg(self.segmentation_raw_image, encoding="32SC1")
@@ -68,6 +69,14 @@ def get_segmentation_info(self) -> SegmentationInfo:
                                     header=self.header)
         return seg_info
 
+    def get_rect_array(self) -> RectArray:
+        rects = [Rect(x=int(box[0]),
+                      y=int(box[1]),
+                      width=int(box[2] - box[0]),
+                      height=int(box[3] - box[1])) for box in self.boxes]
+        rec_arr = RectArray(header=self.header, rects=rects)
+        return rec_arr
+
 
 class DeticWrapper:
     predictor: VisualizationDemo
@@ -122,12 +131,14 @@ def infer(self, msg: Image) -> InferenceRawResult:
         pred_masks = list(instances.pred_masks)
         scores = instances.scores.tolist()
         class_indices = instances.pred_classes.tolist()
+        boxes = list(instances.pred_boxes)
 
         if len(scores) > 0 and self.node_config.output_highest:
             best_index = np.argmax(scores)
             pred_masks = [pred_masks[best_index]]
             scores = [scores[best_index]]
             class_indices = [class_indices[best_index]]
+            boxes = [boxes[best_index]]
 
             if self.node_config.verbose:
                 rospy.loginfo("{} with highest score {}".format(self.class_names[class_indices[0]], scores[best_index]))
@@ -150,5 +161,6 @@ def infer(self, msg: Image) -> InferenceRawResult:
                                     scores,
                                     visualized_output,
                                     msg.header,
-                                    detected_classes_names)
+                                    detected_classes_names,
+                                    boxes)
         return result
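
Usage note (not part of the patch): when the jsk-message output path is enabled, the node now also publishes a jsk_recognition_msgs/RectArray on ~rects, one Rect (x, y, width, height) per detected instance, in the same order as the labels on ~detected_classes and the scores on ~score. Below is a minimal subscriber sketch; the node name detic_segmentor is an assumption, so substitute the resolved topic name from your own launch configuration.

#!/usr/bin/env python3
# Minimal sketch: consume the new ~rects topic added by this patch.
# Assumption: the detic node runs under the name "detic_segmentor"; the
# actual resolved topic depends on how node.py is launched.
import rospy
from jsk_recognition_msgs.msg import RectArray


def callback(msg: RectArray) -> None:
    # Each Rect carries the top-left corner (x, y) plus width and height,
    # as filled in by InferenceRawResult.get_rect_array().
    for i, rect in enumerate(msg.rects):
        rospy.loginfo("rect %d: x=%d y=%d w=%d h=%d",
                      i, rect.x, rect.y, rect.width, rect.height)


if __name__ == '__main__':
    rospy.init_node('rect_listener')
    rospy.Subscriber('/detic_segmentor/rects', RectArray, callback)
    rospy.spin()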