Skip to content

Commit

Permalink
refactor: wip describe people changes
Browse files Browse the repository at this point in the history
  • Loading branch information
insertish committed Dec 2, 2023
1 parent fbbaf49 commit adbe01c
Showing 1 changed file with 12 additions and 89 deletions.
101 changes: 12 additions & 89 deletions skills/src/lasr_skills/test_describe_people.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from lasr_vision_msgs.msg import BodyPixMaskRequest, ColourPrediction, FeatureWithColour
from lasr_vision_msgs.srv import YoloDetection, BodyPixDetection, TorchFaceFeatureDetection

from .vision import GetImage
from .vision import GetImage, ImageMsgToCv2


class TestDescribePeople(smach.StateMachine):
Expand All @@ -28,16 +28,16 @@ def __init__(self):

with self:
smach.StateMachine.add('GET_IMAGE', GetImage(), transitions={
'succeeded': 'CONVERT_IMAGE'})
smach.StateMachine.add('CONVERT_IMAGE', ImageMsgToCv2(), transitions={
'succeeded': 'SEGMENT'})
# smach.StateMachine.add('RESIZE_TEST', self.ResizeTest(), transitions={'succeeded': 'SEGMENT_FACE'})
# smach.StateMachine.add('SEGMENT_FACE', self.SegmentFace())

sm_con = smach.Concurrence(outcomes=['succeeded', 'failed'],
default_outcome='failed',
outcome_map={'succeeded': {
'SEGMENT_YOLO': 'succeeded', 'SEGMENT_BODYPIX': 'succeeded'}},
input_keys=['img_msg'],
output_keys=['people_detections', 'masks'])
input_keys=['img', 'img_msg'],
output_keys=['people_detections', 'bodypix_masks'])

with sm_con:
smach.Concurrence.add('SEGMENT_YOLO', self.SegmentYolo())
Expand Down Expand Up @@ -73,7 +73,7 @@ def execute(self, userdata):
class SegmentBodypix(smach.State):
def __init__(self):
smach.State.__init__(self, outcomes=['succeeded', 'failed'], input_keys=[
'img_msg'], output_keys=['masks'])
'img_msg'], output_keys=['bodypix_masks'])
self.bodypix = rospy.ServiceProxy(
'/bodypix/detect', BodyPixDetection)

Expand All @@ -88,93 +88,16 @@ def execute(self, userdata):
masks = [torso, head]

result = self.bodypix(userdata.img_msg, "resnet50", 0.7, masks)
userdata.masks = result.masks
userdata.bodypix_masks = result.masks
return 'succeeded'
except rospy.ServiceException as e:
rospy.logwarn(f"Unable to perform inference. ({str(e)})")
return 'failed'

class SegmentFace(smach.State):
    """Experimental state: crop each detected person's face/head region and
    run face-feature inference on the crop.

    Userdata inputs:
        img_msg           -- sensor_msgs/Image to process.
        people_detections -- detections whose .xyseg is a flat list of
                             polygon vertices for the person mask.
        masks             -- BodyPix mask list; index 1 is used as the head
                             mask here (presumably [torso, head] order —
                             TODO confirm against SegmentBodypix).

    Outcomes: 'succeeded' on completion, 'failed' on a service error.
    """

    def __init__(self):
        smach.State.__init__(self, outcomes=['succeeded', 'failed'], input_keys=[
            'img_msg', 'people_detections', 'masks'], output_keys=[])
        # Proxy for the torch face-feature detection service.
        self.torch_face_features = rospy.ServiceProxy(
            '/torch/detect/face_features', TorchFaceFeatureDetection)
        # Debug publisher: republishes each cropped face region for inspection.
        self.test = rospy.Publisher('/test', Image)

    def execute(self, userdata):
        try:

            # uncomment to make everything work as previously:
            # self.torch(userdata.img_msg, "resnet50", 0.7, [])

            # TODO: remove
            # Decode the raw Image message into a PIL image.
            rospy.loginfo('Decoding')
            size = (userdata.img_msg.width, userdata.img_msg.height)
            if userdata.img_msg.encoding in ['bgr8', '8UC3']:
                img = PillowImage.frombytes(
                    'RGB', size, userdata.img_msg.data, 'raw')

                # BGR => RGB (reverse the channel axis)
                img = PillowImage.fromarray(np.array(img)[:, :, ::-1])
            elif userdata.img_msg.encoding == 'rgb8':
                img = PillowImage.frombytes(
                    'RGB', size, userdata.img_msg.data, 'raw')
            else:
                # NOTE: raises plain Exception, which is NOT caught by the
                # rospy.ServiceException handler below — it will propagate.
                raise Exception("Unsupported format.")

            # Back to a BGR ndarray for OpenCV-style processing.
            frame = np.array(img)
            frame = frame[:, :, ::-1].copy()

            for person in userdata.people_detections:
                # mask: rasterise the person's segmentation polygon.
                mask_image = np.zeros((size[1], size[0]), np.uint8)
                contours = np.array(person.xyseg).reshape(-1, 2)
                cv2.fillPoly(mask_image, pts=np.int32(
                    [contours]), color=(255, 255, 255))
                # mask_bin = mask_image > 128

                # crop out face: intersect the person mask with the BodyPix
                # head mask (masks[1]) so only the head region survives.
                face_mask = np.array(userdata.masks[1].mask).reshape(
                    userdata.masks[1].shape[0], userdata.masks[1].shape[1])
                mask_image[face_mask == 0] = 0

                # Extract the masked region and repack it as an Image message.
                a = cv2_img.extract_mask_region(frame, mask_image)
                height, width, _ = a.shape

                msg = Image()
                msg.header.stamp = rospy.Time.now()
                msg.width = width
                msg.height = height
                msg.encoding = 'bgr8'
                msg.is_bigendian = 1
                msg.step = 3 * width
                msg.data = a.tobytes()
                # Publish the crop for debugging, then run inference on it.
                self.test.publish(msg)
                print(self.torch_face_features(msg))

            return 'succeeded'
        except rospy.ServiceException as e:
            rospy.logwarn(f"Unable to perform inference. ({str(e)})")
            return 'failed'

class ResizeTest(smach.State):
    """Debug state that shrinks the current image message to 128x128,
    replacing ``img_msg`` on the userdata in place."""

    def __init__(self):
        smach.State.__init__(self, outcomes=['succeeded', 'failed'], input_keys=[
            'img_msg'], output_keys=['img_msg'])

    def execute(self, userdata):
        # Temporary/debug behaviour: decode, downscale, re-encode.
        rospy.loginfo('Decoding')
        small = cv2.resize(
            cv2_img.msg_to_cv2_img(userdata.img_msg), (128, 128))
        userdata.img_msg = cv2_img.cv2_img_to_msg(small)
        return 'succeeded'

class Filter(smach.State):
def __init__(self):
smach.State.__init__(self, outcomes=['succeeded', 'failed'], input_keys=[
'img_msg', 'people_detections', 'masks'], output_keys=['people'])
'img', 'people_detections', 'bodypix_masks'], output_keys=['people'])
self.torch_face_features = rospy.ServiceProxy(
'/torch/detect/face_features', TorchFaceFeatureDetection)

Expand All @@ -189,7 +112,7 @@ def execute(self, userdata):

# decode the image
rospy.loginfo('Decoding')
img = cv2_img.msg_to_cv2_img(userdata.img_msg)
img = userdata.img
height, width, _ = img.shape

people = []
Expand All @@ -208,7 +131,7 @@ def execute(self, userdata):
features = []

# process part masks
for (bodypix_mask, part) in zip(userdata.masks, ['torso', 'head']):
for (bodypix_mask, part) in zip(userdata.bodypix_masks, ['torso', 'head']):
part_mask = np.array(bodypix_mask.mask).reshape(
bodypix_mask.shape[0], bodypix_mask.shape[1])

Expand Down Expand Up @@ -252,8 +175,8 @@ def execute(self, userdata):
# mask_bin = mask_image > 128

# crop out face
face_mask = np.array(userdata.masks[1].mask).reshape(
userdata.masks[1].shape[0], userdata.masks[1].shape[1])
face_mask = np.array(userdata.bodypix_masks[1].mask).reshape(
userdata.bodypix_masks[1].shape[0], userdata.bodypix_masks[1].shape[1])
mask_image[face_mask == 0] = 0

a = cv2_img.extract_mask_region(img, mask_image)
Expand Down

0 comments on commit adbe01c

Please sign in to comment.