-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] update database_talker #9
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,11 @@ | |
import json | ||
import os | ||
import random | ||
import re | ||
import rospkg | ||
import shutil | ||
import sys | ||
import yaml | ||
import tempfile | ||
import time | ||
import traceback | ||
|
@@ -50,6 +52,9 @@ | |
class MessageListener(object): | ||
|
||
def __init__(self): | ||
self.robot_name = rospy.get_param('robot/name') | ||
rospy.loginfo("using '{}' database".format(self.robot_name)) | ||
|
||
rospy.loginfo("wait for '/google_chat_ros/send'") | ||
self.chat_ros_ac = actionlib.SimpleActionClient('/google_chat_ros/send', SendMessageAction) | ||
self.chat_ros_ac.wait_for_server() | ||
|
@@ -58,7 +63,7 @@ def __init__(self): | |
rospy.loginfo("wait for '/message_store/query_messages'") | ||
rospy.wait_for_service('/message_store/query_messages') | ||
self.query = rospy.ServiceProxy('/message_store/query_messages', MongoQueryMsg) | ||
|
||
rospy.loginfo("wait for '/classification/inference_server'") | ||
self.classification_ac = actionlib.SimpleActionClient('/classification/inference_server' , ClassificationTaskAction) | ||
self.classification_ac.wait_for_server() | ||
|
@@ -82,59 +87,65 @@ def __init__(self): | |
self.analyze_text_ac = actionlib.SimpleActionClient('/analyze_text/text' , AnalyzeTextAction) | ||
self.analyze_text_ac.wait_for_server() | ||
|
||
# rospy.loginfo("subscribe '/google_chat_ros/message_activity'") | ||
# self.sub = rospy.Subscriber('/google_chat_ros/message_activity', MessageEvent, self.cb) | ||
rospy.loginfo("subscribe '/dialogflow_client/text_action/result'") | ||
self.sub = rospy.Subscriber('/dialogflow_client/text_action/result', DialogTextActionResult, self.cb) | ||
rospy.loginfo("subscribe '/google_chat_ros/message_activity'") | ||
self.sub = rospy.Subscriber('/google_chat_ros/message_activity', MessageEvent, self.cb) | ||
|
||
rospy.loginfo("all done, ready") | ||
|
||
|
||
def make_reply(self, message, lang="en"): | ||
rospy.logwarn("Run make_reply({})".format(message)) | ||
def make_reply(self, message, lang="en", startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24), duration=datetime.timedelta(hours=24) ): | ||
enddate = startdate+duration | ||
rospy.logwarn("Run make_reply({} from {} to {})".format(message, startdate, enddate)) | ||
query = self.text_to_salience(message) | ||
rospy.logwarn("query using salience word '{}'".format(query)) | ||
# look for images | ||
try: | ||
# get chat message | ||
timestamp = datetime.datetime.now(JST) | ||
results, chat_msgs = self.query_dialogflow(query, timestamp, threshold=0.25) | ||
retry = 0 | ||
while retry < -1 and len(results) == 0 and len(chat_msgs.metas) > 0: | ||
meta = json.loads(chat_msgs.metas[-1].pairs[0].second) | ||
results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)) | ||
retry = retry + 1 | ||
results, chat_msgs = self.query_dialogflow(query, startdate, enddate, threshold=0.25) | ||
# retry = 0 | ||
# while retry < 3 and len(results) == 0 and len(chat_msgs.metas) > 0: | ||
# meta = json.loads(chat_msgs.metas[-1].pairs[0].second) | ||
# results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)) | ||
# retry = retry + 1 | ||
# sort based on similarity with 'query' | ||
chat_msgs_sorted = sorted(results, key=lambda x: x['similarity'], reverse=True) | ||
|
||
if len(chat_msgs_sorted) == 0: | ||
rospy.logwarn("no chat message was found") | ||
else: | ||
# query images that was taken when chat_msgs are stored | ||
msg = chat_msgs_sorted[0]['msg'] | ||
meta = chat_msgs_sorted[0]['meta'] | ||
text = chat_msgs_sorted[0]['message'] | ||
timestamp = chat_msgs_sorted[0]['timestamp'] | ||
startdate = chat_msgs_sorted[0]['timestamp'] | ||
action = chat_msgs_sorted[0]['action'] | ||
similarity = chat_msgs_sorted[0]['similarity'] | ||
# query chat to get response | ||
#meta = json.loads(chat_msgs_sorted[0]['meta'].pairs[0].second) | ||
# text = msg.message.argument_text or msg.message.text | ||
# timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST) | ||
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, timestamp.strftime('%Y-%m-%d %H:%M:%S'), query, similarity)) | ||
# startdate = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST) | ||
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, startdate.strftime('%Y-%m-%d %H:%M:%S'), query, similarity)) | ||
|
||
start_time = timestamp-datetime.timedelta(minutes=300) | ||
end_time = timestamp+datetime.timedelta(minutes=30) | ||
# query images when chat was received | ||
start_time = startdate # startdate is updated with found chat space | ||
end_time = enddate # enddate is not modified within this function, it is given from chat | ||
results = self.query_images_and_classify(query=query, start_time=start_time, end_time=end_time) | ||
|
||
# no images found | ||
if len(results) == 0: | ||
return {'text': '記憶がありません🤯'} | ||
|
||
end_time = results[-1]['timestamp'] | ||
|
||
# sort | ||
results = sorted(results, key=lambda x: x['similarities'], reverse=True) | ||
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['label'], x['similarities']), results)))) | ||
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), x['similarities']), results)))) | ||
best_result = results[0] | ||
|
||
|
||
''' | ||
# if probability is too low, try again | ||
while len(results) > 0 and results[0]['similarities'] < 0.25: | ||
|
||
start_time = end_time-datetime.timedelta(hours=24) | ||
timestamp = datetime.datetime.now(JST) | ||
results = self.query_images_and_classify(query=query, start_time=start_time, end_time=end_time, limit=300) | ||
|
@@ -147,39 +158,35 @@ def make_reply(self, message, lang="en"): | |
best_result = results[0] | ||
|
||
rospy.loginfo("Found '{}' image with {:0.2f} % simiarity at {}".format(best_result['label'], best_result['similarities'], best_result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'))) | ||
''' | ||
|
||
## make prompt | ||
goal = VQATaskGoal() | ||
goal.compressed_image = best_result['image'] | ||
|
||
# unusual objects | ||
if random.randint(0,1) == 1: | ||
goal.questions = ['what unusual things can be seen?'] | ||
reaction = 'and you saw ' | ||
else: | ||
goal.questions = ['what is the atmosphere of this place?'] | ||
reaction = 'and the atmosphere of the scene was ' | ||
|
||
# get vqa result | ||
self.vqa_ac.send_goal(goal) | ||
self.vqa_ac.wait_for_result() | ||
result = self.vqa_ac.get_result() | ||
reaction += result.result.result[0].answer | ||
reaction = self.describe_image_scene(best_result['image']) | ||
if len(chat_msgs_sorted) > 0 and chat_msgs_sorted[0]['action'] and 'action' in chat_msgs_sorted[0]: | ||
reaction += " and you felt " + chat_msgs_sorted[0]['action'] | ||
rospy.loginfo("reaction = {}".format(reaction)) | ||
|
||
# make prompt | ||
prompt = 'if you are a pet and someone tells you \"' + message + '\" when we went together, ' + \ | ||
reaction + ' in your memory of that moment, what would you reply? '+ \ | ||
'and ' + reaction + ' in your memory of that moment, what would you reply? '+ \ | ||
'Show only the reply in {lang}'.format(lang={'en': 'English', 'ja':'Japanese'}[lang]) | ||
result = self.completion(prompt=prompt,temperature=0) | ||
loop = 0 | ||
result = None | ||
while loop < 3 and result is None: | ||
try: | ||
result = self.completion(prompt=prompt,temperature=0) | ||
except rospy.ServiceException as e: | ||
rospy.logerr("Service call failed: %s"%e) | ||
result = None | ||
loop += 1 | ||
result.text = result.text.lstrip() | ||
rospy.loginfo("prompt = {}".format(prompt)) | ||
rospy.loginfo("result = {}".format(result)) | ||
# pubish as card | ||
filename = tempfile.mktemp(suffix=".jpg", dir=rospkg.get_ros_home()) | ||
self.write_image_with_annotation(filename, best_result, prompt) | ||
self.publish_google_chat_card(result.text, filename) | ||
return {'text': result.text, 'filename': filename} | ||
|
||
except Exception as e: | ||
raise ValueError("Query failed {} {}".format(e, traceback.format_exc())) | ||
|
||
|
@@ -199,19 +206,27 @@ def write_image_with_annotation(self, filename, best_result, prompt): | |
rospy.logwarn("save images to {}".format(filename)) | ||
|
||
|
||
def query_dialogflow(self, query, end_time, limit=30, threshold=0.0): | ||
rospy.logwarn("Query dialogflow until {}".format(end_time)) | ||
meta_query= {'inserted_at': {"$lt": end_time}} | ||
def query_dialogflow(self, query, start_time, end_time, limit=30, threshold=0.0): | ||
rospy.logwarn("Query dialogflow from {} until {}".format(start_time, end_time)) | ||
meta_query= {'inserted_at': {"$lt": end_time, "$gt": start_time}} | ||
meta_tuple = (StringPair(MongoQueryMsgRequest.JSON_QUERY, json.dumps(meta_query, default=json_util.default)),) | ||
chat_msgs = self.query(database = 'jsk_robot_lifelog', | ||
collection = 'fetch1075', | ||
collection = self.robot_name, | ||
# type = 'google_chat_ros/MessageEvent', | ||
type = 'dialogflow_task_executive/DialogTextActionResult', | ||
single = False, | ||
limit = limit, | ||
meta_query = StringPairList(meta_tuple), | ||
# limit = limit, | ||
meta_query = StringPairList(meta_tuple), | ||
sort_query = StringPairList([StringPair('_meta.inserted_at', '-1')])) | ||
|
||
# optimization... send translate once | ||
messages = '' | ||
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas): | ||
msg = deserialise_message(msg) | ||
message = msg.result.response.query.replace('\n','') | ||
messages += message + '\n' | ||
messages = self.translate(messages, dest="en").text.split('\n') | ||
|
||
# show chats | ||
results = [] | ||
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas): | ||
|
@@ -220,7 +235,8 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0): | |
timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST) | ||
# message = msg.message.argument_text or msg.message.text | ||
message = msg.result.response.query | ||
message_translate = self.translate(message, dest="en").text | ||
#message_translate = self.translate(message, dest="en").text | ||
message_translate = messages.pop(0).strip() | ||
result = {'message': message, | ||
'message_translate': message_translate, | ||
'timestamp': timestamp, | ||
|
@@ -229,34 +245,32 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0): | |
'msg': msg, | ||
'meta': meta} | ||
if msg.result.response.action in ['make_reply', 'input.unknown']: | ||
rospy.logwarn("Found dialogflow messages {} at {} but skipping (action:{})".format(result['message'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action)) | ||
rospy.logwarn("Found dialogflow messages {}({}) at {} but skipping (action:{})".format(result['message'], result['message_translate'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action)) | ||
else: | ||
rospy.logwarn("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity'])) | ||
rospy.loginfo("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity'])) | ||
if ( result['similarity'] > threshold): | ||
results.append(result) | ||
else: | ||
rospy.logwarn(" ... skipping (threshold: {:.2f})".format(threshold)) | ||
|
||
|
||
return results, chat_msgs | ||
|
||
|
||
def query_images_and_classify(self, query, start_time, end_time, limit=30): | ||
def query_images_and_classify(self, query, start_time, end_time, limit=10): | ||
rospy.logwarn("Query images from {} to {}".format(start_time, end_time)) | ||
# meta_query= {'input_topic': '/spot/camera/hand_color/image/compressed/throttled', | ||
# 'inserted_at': {"$gt": start_time, "$lt": end_time}} | ||
meta_query= {'input_topic': '/head_camera/rgb/image_rect_color/compressed/throttled', | ||
'inserted_at': {"$gt": start_time, "$lt": end_time}} | ||
meta_query= {#'input_topic': '/spot/camera/hand_color/image/compressed/throttled', | ||
'inserted_at': {"$gt": start_time, "$lt": end_time}} | ||
meta_tuple = (StringPair(MongoQueryMsgRequest.JSON_QUERY, json.dumps(meta_query, default=json_util.default)),) | ||
msgs = self.query(database = 'jsk_robot_lifelog', | ||
collection = 'fetch1075', | ||
collection = self.robot_name, | ||
type = 'sensor_msgs/CompressedImage', | ||
single = False, | ||
limit = limit, | ||
meta_query = StringPairList(meta_tuple), | ||
sort_query = StringPairList([StringPair('_meta.inserted_at', '-1')])) | ||
|
||
rospy.loginfo("Found {} images".format(len(msgs.messages))) | ||
rospy.loginfo("Found {} images".format(len(msgs.messages))) | ||
if len(msgs.messages) == 0: | ||
rospy.logwarn("no images was found") | ||
|
||
|
@@ -283,13 +297,31 @@ def query_images_and_classify(self, query, start_time, end_time, limit=30): | |
# we do not sorty by probabilites, becasue we also need oldest timestamp | ||
return results | ||
|
||
def describe_image_scene(self, image): | ||
goal = VQATaskGoal() | ||
goal.compressed_image = image | ||
|
||
# unusual objects | ||
if random.randint(0,1) == 1: | ||
goal.questions = ['what unusual things can be seen?'] | ||
reaction = 'you saw ' | ||
else: | ||
goal.questions = ['what is the atmosphere of this place?'] | ||
reaction = 'the atmosphere of the scene was ' | ||
|
||
# get vqa result | ||
self.vqa_ac.send_goal(goal) | ||
self.vqa_ac.wait_for_result() | ||
result = self.vqa_ac.get_result() | ||
reaction += result.result.result[0].answer | ||
return reaction | ||
|
||
def publish_google_chat_card(self, text, filename=None): | ||
def publish_google_chat_card(self, text, space, filename=None): | ||
goal = SendMessageGoal() | ||
goal.text = text | ||
if filename: | ||
goal.cards = [Card(sections=[Section(widgets=[WidgetMarkup(image=Image(localpath=filename))])])] | ||
goal.space = 'spaces/AAAAoTwLBL0' | ||
goal.space = space | ||
rospy.logwarn("send {} to {}".format(goal.text, goal.space)) | ||
self.chat_ros_ac.send_goal_and_wait(goal, execute_timeout=rospy.Duration(0.10)) | ||
|
||
|
@@ -305,7 +337,6 @@ def text_to_salience(self, text): | |
return text | ||
|
||
def translate(self, text, dest): | ||
return Translated(text=text, dest=dest, src="en", origin="unknown", pronunciation="unknown") | ||
global translator | ||
loop = 3 | ||
while loop > 0: | ||
|
@@ -318,11 +349,13 @@ def translate(self, text, dest): | |
translator = Translator() | ||
loop = loop - 1 | ||
return Translated(text=text, dest=dest) | ||
|
||
|
||
def cb(self, msg): | ||
if msg._type == 'google_chat_ros.msg/MessageEvent': | ||
text = message.message.argument_text.lstrip() or message.message.text.lstrip() | ||
space = 'spaces/AAAAoTwLBL0' ## default space JskRobotBot | ||
if msg._type == 'google_chat_ros/MessageEvent': | ||
text = msg.message.argument_text.lstrip() or msg.message.text.lstrip() | ||
space = msg.space.name | ||
rospy.logwarn("Received chat message '{}'".format(text)) | ||
|
||
# ask dialogflow for intent | ||
|
@@ -342,20 +375,37 @@ def cb(self, msg): | |
rospy.logwarn("received dialogflow action '{}'".format(result.response.action)) | ||
print(result.response) | ||
if result.response.action == 'input.unknown': | ||
self.publish_google_chat_card("🤖") | ||
self.publish_google_chat_card("🤖", space) | ||
elif result.response.action == 'make_reply': | ||
self.publish_google_chat_card("・・・", space) | ||
|
||
parameters = yaml.safe_load(result.response.parameters) | ||
startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24) | ||
duration=datetime.timedelta(hours=24) | ||
if parameters['date']: | ||
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date']), "%Y-%m-%dT%H:%M:%S%z") | ||
duration = datetime.timedelta(hours=24) | ||
if parameters['date-period']: | ||
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['startDate']), "%Y-%m-%dT%H:%M:%S%z") | ||
duration = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['endDate']), "%Y-%m-%dT%H:%M:%S%z") - startdate | ||
Comment on lines
+385
to
+390
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect that this section cause something wrong with searching database???? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I misunderstood the problem. This section has no problem. This is because a long search, such as a 24-hour search, will not return results until the search is completed, unless the number of data matching the condition is equal to the number of limit items. |
||
print(startdate) | ||
print(duration) | ||
translated = self.translate(result.response.query, dest="en") | ||
self.make_reply(translated.text, translated.src) | ||
ret = self.make_reply(translated.text, translated.src, startdate=startdate, duration=duration) | ||
if 'filename' in ret: | ||
# upload text first, then upload images | ||
self.publish_google_chat_card(ret['text'], space) | ||
self.publish_google_chat_card('', space, ret['filename']) | ||
else: | ||
self.publish_google_chat_card(ret['text'], space) | ||
else: | ||
self.publish_google_chat_card(result.response.response) | ||
self.publish_google_chat_card(result.response.response, space) | ||
|
||
except Exception as e: | ||
rospy.logerr("Callback failed {} {}".format(e, traceback.format_exc())) | ||
self.publish_google_chat_card("💀 {}".format(e)) | ||
self.publish_google_chat_card("💀 {}".format(e), space) | ||
|
||
if __name__ == '__main__': | ||
rospy.init_node('test', anonymous=True) | ||
ml = MessageListener() | ||
#ml.cb2(0) | ||
#ml.cb2('chair') | ||
rospy.spin() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@a-ichikura もうちょっとマシな返答テキストがあれば,教えてくれると助かります.