diff --git a/skills/src/lasr_skills/__init__.py b/skills/src/lasr_skills/__init__.py index 233a436cc..abb498c39 100755 --- a/skills/src/lasr_skills/__init__.py +++ b/skills/src/lasr_skills/__init__.py @@ -14,3 +14,4 @@ from .receive_object import ReceiveObject from .handover_object import HandoverObject from .ask_and_listen import AskAndListen +from .clip_vqa import QueryImage diff --git a/skills/src/lasr_skills/clip_vqa.py b/skills/src/lasr_skills/clip_vqa.py index 4cb992046..7126d10c0 100755 --- a/skills/src/lasr_skills/clip_vqa.py +++ b/skills/src/lasr_skills/clip_vqa.py @@ -1,15 +1,24 @@ import smach_ros -from lasr_vision_clip.srv import Vqa +from lasr_vision_msgs.srv import Vqa, VqaRequest + +from typing import List, Union class QueryImage(smach_ros.ServiceState): - def __init__( - self, - ): - super(smach_ros.ServiceState, self).__init__( - "/clip_vqa/query_service", - Vqa, - request_slots=["answers"], - response_slots=["answer", "similarity_score"], - ) + def __init__(self, possible_answers: Union[None, List[str]] = None): + + if possible_answers is not None: + super(QueryImage, self).__init__( + "/clip_vqa/query_service", + Vqa, + request=VqaRequest(possible_answers=possible_answers), + response_slots=["answer", "similarity"], + ) + else: + super(QueryImage, self).__init__( + "/clip_vqa/query_service", + Vqa, + request_slots=["possible_answers"], + response_slots=["answer", "similarity"], + )