Skip to content

Commit

Permalink
feat: command similarity matcher state
Browse files Browse the repository at this point in the history
  • Loading branch information
m-barker committed Apr 18, 2024
1 parent 859de79 commit b053626
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class TxtQueryService:
best_distances, best_matches = zip(
*sorted(zip(best_distances, best_matches))
)
best_distances.sort()

return TxtQueryResponse(
closest_sentences=best_matches,
Expand Down
41 changes: 41 additions & 0 deletions tasks/gpsr/states/command_similarity_matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import smach
import rospy
import rospkg
import os
from lasr_vector_databases_msgs.srv import TxtQuery, TxtQueryRequest


class CommandSimilarityMatcher(smach.State):
def __init__(self):
smach.State.__init__(
self,
outcomes=["success"],
input_keys=["command"],
output_keys=["matched_command"],
)

self._query_service = rospy.ServiceProxy("lasr_faiss/txt_query", TxtQuery)
self._text_directory = os.path.join(
rospkg.RosPack().get_path("gpsr"), "data", "command_data"
)
self._index_directory = os.path.join(
rospkg.RosPack().get_path("gpsr"), "data", "index_data"
)
self._text_paths = [
os.path.join(self._text_directory, f"all_gpsr_commands_chunk_{i+1}.txt")
for i in range(10)
]
self._index_paths = [
os.path.join(self._index_directory, f"all_gpsr_commands_chunk_{i+1}.index")
for i in range(10)
]

def execute(self, userdata):
request = TxtQueryRequest()
request.txt_paths = self._text_paths
request.index_paths = self._index_paths
request.query_sentence = userdata.command
request.k = 1
response = self._query_service(request)
userdata.matched_command = response.closest_sentences[0]
return "success"

0 comments on commit b053626

Please sign in to comment.