diff --git a/skills/src/lasr_skills/clip_vqa.py b/skills/src/lasr_skills/clip_vqa.py index 456c0f0d3..4cb992046 100755 --- a/skills/src/lasr_skills/clip_vqa.py +++ b/skills/src/lasr_skills/clip_vqa.py @@ -1,29 +1,15 @@ -import smach -from smach import UserData -from typing import Union -from lasr_skills import Say -from lasr_vision_clip.srv import VqaRequest, VqaResponse, Vqa -import rospy +import smach_ros +from lasr_vision_clip.srv import Vqa -class QueryImage(smach.State): +class QueryImage(smach_ros.ServiceState): + def __init__( self, - model_device: str = "cuda", ): - smach.State.__init__( - self, - outcomes=["succeeded", "failed"], - input_keys=["question", "answers"], - output_keys=["answer", "similarity_score"], + super(smach_ros.ServiceState, self).__init__( + "/clip_vqa/query_service", + Vqa, + request_slots=["answers"], + response_slots=["answer", "similarity_score"], ) - self._service_proxy = rospy.ServiceProxy("/clip_vqa/query_service", Vqa) - - def execute(self, userdata: UserData): - answers = userdata.answers - request = VqaRequest() - request.possible_answers = answers - response = self._service_proxy(request) - userdata.answer = response.answer - userdata.similarity_score = response.similarity - return "succeeded"