Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] update database_talker #9

Merged
merged 2 commits into from
Jun 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 123 additions & 73 deletions database_talker/scripts/hoge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import json
import os
import random
import re
import rospkg
import shutil
import sys
import yaml
import tempfile
import time
import traceback
Expand Down Expand Up @@ -50,6 +52,9 @@
class MessageListener(object):

def __init__(self):
self.robot_name = rospy.get_param('robot/name')
rospy.loginfo("using '{}' database".format(self.robot_name))

rospy.loginfo("wait for '/google_chat_ros/send'")
self.chat_ros_ac = actionlib.SimpleActionClient('/google_chat_ros/send', SendMessageAction)
self.chat_ros_ac.wait_for_server()
Expand All @@ -58,7 +63,7 @@ def __init__(self):
rospy.loginfo("wait for '/message_store/query_messages'")
rospy.wait_for_service('/message_store/query_messages')
self.query = rospy.ServiceProxy('/message_store/query_messages', MongoQueryMsg)

rospy.loginfo("wait for '/classification/inference_server'")
self.classification_ac = actionlib.SimpleActionClient('/classification/inference_server' , ClassificationTaskAction)
self.classification_ac.wait_for_server()
Expand All @@ -82,59 +87,65 @@ def __init__(self):
self.analyze_text_ac = actionlib.SimpleActionClient('/analyze_text/text' , AnalyzeTextAction)
self.analyze_text_ac.wait_for_server()

# rospy.loginfo("subscribe '/google_chat_ros/message_activity'")
# self.sub = rospy.Subscriber('/google_chat_ros/message_activity', MessageEvent, self.cb)
rospy.loginfo("subscribe '/dialogflow_client/text_action/result'")
self.sub = rospy.Subscriber('/dialogflow_client/text_action/result', DialogTextActionResult, self.cb)
rospy.loginfo("subscribe '/google_chat_ros/message_activity'")
self.sub = rospy.Subscriber('/google_chat_ros/message_activity', MessageEvent, self.cb)

rospy.loginfo("all done, ready")


def make_reply(self, message, lang="en"):
rospy.logwarn("Run make_reply({})".format(message))
def make_reply(self, message, lang="en", startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24), duration=datetime.timedelta(hours=24) ):
enddate = startdate+duration
rospy.logwarn("Run make_reply({} from {} to {})".format(message, startdate, enddate))
query = self.text_to_salience(message)
rospy.logwarn("query using salience word '{}'".format(query))
# look for images
try:
# get chat message
timestamp = datetime.datetime.now(JST)
results, chat_msgs = self.query_dialogflow(query, timestamp, threshold=0.25)
retry = 0
while retry < -1 and len(results) == 0 and len(chat_msgs.metas) > 0:
meta = json.loads(chat_msgs.metas[-1].pairs[0].second)
results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST))
retry = retry + 1
results, chat_msgs = self.query_dialogflow(query, startdate, enddate, threshold=0.25)
# retry = 0
# while retry < 3 and len(results) == 0 and len(chat_msgs.metas) > 0:
# meta = json.loads(chat_msgs.metas[-1].pairs[0].second)
# results, chat_msgs = self.query_dialogflow(query, datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST))
# retry = retry + 1
# sort based on similarity with 'query'
chat_msgs_sorted = sorted(results, key=lambda x: x['similarity'], reverse=True)

if len(chat_msgs_sorted) == 0:
rospy.logwarn("no chat message was found")
else:
# query images that was taken when chat_msgs are stored
msg = chat_msgs_sorted[0]['msg']
meta = chat_msgs_sorted[0]['meta']
text = chat_msgs_sorted[0]['message']
timestamp = chat_msgs_sorted[0]['timestamp']
startdate = chat_msgs_sorted[0]['timestamp']
action = chat_msgs_sorted[0]['action']
similarity = chat_msgs_sorted[0]['similarity']
# query chat to get response
#meta = json.loads(chat_msgs_sorted[0]['meta'].pairs[0].second)
# text = msg.message.argument_text or msg.message.text
# timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, timestamp.strftime('%Y-%m-%d %H:%M:%S'), query, similarity))
# startdate = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
rospy.loginfo("Found message '{}'({}) at {}, corresponds to query '{}' with {:2f}%".format(text, action, startdate.strftime('%Y-%m-%d %H:%M:%S'), query, similarity))

start_time = timestamp-datetime.timedelta(minutes=300)
end_time = timestamp+datetime.timedelta(minutes=30)
# query images when chat was received
start_time = startdate # startdate is updated with found chat space
end_time = enddate # enddate is not modified within this function, it is given from chat
results = self.query_images_and_classify(query=query, start_time=start_time, end_time=end_time)

# no images found
if len(results) == 0:
return {'text': '記憶がありません🤯'}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@a-ichikura もうちょっとマシな返答テキストがあれば,教えてくれると助かります.


end_time = results[-1]['timestamp']

# sort
results = sorted(results, key=lambda x: x['similarities'], reverse=True)
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['label'], x['similarities']), results))))
rospy.loginfo("Probabilities of all images {}".format(list(map(lambda x: (x['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), x['similarities']), results))))
best_result = results[0]


'''
# if probability is too low, try again
while len(results) > 0 and results[0]['similarities'] < 0.25:

start_time = end_time-datetime.timedelta(hours=24)
timestamp = datetime.datetime.now(JST)
results = self.query_images_and_classify(query=query, start_time=start_time, end_time=end_time, limit=300)
Expand All @@ -147,39 +158,35 @@ def make_reply(self, message, lang="en"):
best_result = results[0]

rospy.loginfo("Found '{}' image with {:0.2f} % simiarity at {}".format(best_result['label'], best_result['similarities'], best_result['timestamp'].strftime('%Y-%m-%d %H:%M:%S')))
'''

## make prompt
goal = VQATaskGoal()
goal.compressed_image = best_result['image']

# unusual objects
if random.randint(0,1) == 1:
goal.questions = ['what unusual things can be seen?']
reaction = 'and you saw '
else:
goal.questions = ['what is the atmosphere of this place?']
reaction = 'and the atmosphere of the scene was '

# get vqa result
self.vqa_ac.send_goal(goal)
self.vqa_ac.wait_for_result()
result = self.vqa_ac.get_result()
reaction += result.result.result[0].answer
reaction = self.describe_image_scene(best_result['image'])
if len(chat_msgs_sorted) > 0 and chat_msgs_sorted[0]['action'] and 'action' in chat_msgs_sorted[0]:
reaction += " and you felt " + chat_msgs_sorted[0]['action']
rospy.loginfo("reaction = {}".format(reaction))

# make prompt
prompt = 'if you are a pet and someone tells you \"' + message + '\" when we went together, ' + \
reaction + ' in your memory of that moment, what would you reply? '+ \
'and ' + reaction + ' in your memory of that moment, what would you reply? '+ \
'Show only the reply in {lang}'.format(lang={'en': 'English', 'ja':'Japanese'}[lang])
result = self.completion(prompt=prompt,temperature=0)
loop = 0
result = None
while loop < 3 and result is None:
try:
result = self.completion(prompt=prompt,temperature=0)
except rospy.ServiceException as e:
rospy.logerr("Service call failed: %s"%e)
result = None
loop += 1
result.text = result.text.lstrip()
rospy.loginfo("prompt = {}".format(prompt))
rospy.loginfo("result = {}".format(result))
# pubish as card
filename = tempfile.mktemp(suffix=".jpg", dir=rospkg.get_ros_home())
self.write_image_with_annotation(filename, best_result, prompt)
self.publish_google_chat_card(result.text, filename)
return {'text': result.text, 'filename': filename}

except Exception as e:
raise ValueError("Query failed {} {}".format(e, traceback.format_exc()))

Expand All @@ -199,19 +206,27 @@ def write_image_with_annotation(self, filename, best_result, prompt):
rospy.logwarn("save images to {}".format(filename))


def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
rospy.logwarn("Query dialogflow until {}".format(end_time))
meta_query= {'inserted_at': {"$lt": end_time}}
def query_dialogflow(self, query, start_time, end_time, limit=30, threshold=0.0):
rospy.logwarn("Query dialogflow from {} until {}".format(start_time, end_time))
meta_query= {'inserted_at': {"$lt": end_time, "$gt": start_time}}
meta_tuple = (StringPair(MongoQueryMsgRequest.JSON_QUERY, json.dumps(meta_query, default=json_util.default)),)
chat_msgs = self.query(database = 'jsk_robot_lifelog',
collection = 'fetch1075',
collection = self.robot_name,
# type = 'google_chat_ros/MessageEvent',
type = 'dialogflow_task_executive/DialogTextActionResult',
single = False,
limit = limit,
meta_query = StringPairList(meta_tuple),
# limit = limit,
meta_query = StringPairList(meta_tuple),
sort_query = StringPairList([StringPair('_meta.inserted_at', '-1')]))

# optimization... send translate once
messages = ''
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas):
msg = deserialise_message(msg)
message = msg.result.response.query.replace('\n','')
messages += message + '\n'
messages = self.translate(messages, dest="en").text.split('\n')

# show chats
results = []
for msg, meta in zip(chat_msgs.messages, chat_msgs.metas):
Expand All @@ -220,7 +235,8 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
timestamp = datetime.datetime.fromtimestamp(meta['timestamp']//1000000000, JST)
# message = msg.message.argument_text or msg.message.text
message = msg.result.response.query
message_translate = self.translate(message, dest="en").text
#message_translate = self.translate(message, dest="en").text
message_translate = messages.pop(0).strip()
result = {'message': message,
'message_translate': message_translate,
'timestamp': timestamp,
Expand All @@ -229,34 +245,32 @@ def query_dialogflow(self, query, end_time, limit=30, threshold=0.0):
'msg': msg,
'meta': meta}
if msg.result.response.action in ['make_reply', 'input.unknown']:
rospy.logwarn("Found dialogflow messages {} at {} but skipping (action:{})".format(result['message'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action))
rospy.logwarn("Found dialogflow messages {}({}) at {} but skipping (action:{})".format(result['message'], result['message_translate'], result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), msg.result.response.action))
else:
rospy.logwarn("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity']))
rospy.loginfo("Found dialogflow messages {}({}) ({}) at {} ({}:{:.2f})".format(result['message'], result['message_translate'], msg.result.response.action, result['timestamp'].strftime('%Y-%m-%d %H:%M:%S'), query, result['similarity']))
if ( result['similarity'] > threshold):
results.append(result)
else:
rospy.logwarn(" ... skipping (threshold: {:.2f})".format(threshold))


return results, chat_msgs


def query_images_and_classify(self, query, start_time, end_time, limit=30):
def query_images_and_classify(self, query, start_time, end_time, limit=10):
rospy.logwarn("Query images from {} to {}".format(start_time, end_time))
# meta_query= {'input_topic': '/spot/camera/hand_color/image/compressed/throttled',
# 'inserted_at': {"$gt": start_time, "$lt": end_time}}
meta_query= {'input_topic': '/head_camera/rgb/image_rect_color/compressed/throttled',
'inserted_at': {"$gt": start_time, "$lt": end_time}}
meta_query= {#'input_topic': '/spot/camera/hand_color/image/compressed/throttled',
'inserted_at': {"$gt": start_time, "$lt": end_time}}
meta_tuple = (StringPair(MongoQueryMsgRequest.JSON_QUERY, json.dumps(meta_query, default=json_util.default)),)
msgs = self.query(database = 'jsk_robot_lifelog',
collection = 'fetch1075',
collection = self.robot_name,
type = 'sensor_msgs/CompressedImage',
single = False,
limit = limit,
meta_query = StringPairList(meta_tuple),
sort_query = StringPairList([StringPair('_meta.inserted_at', '-1')]))

rospy.loginfo("Found {} images".format(len(msgs.messages)))
rospy.loginfo("Found {} images".format(len(msgs.messages)))
if len(msgs.messages) == 0:
rospy.logwarn("no images was found")

Expand All @@ -283,13 +297,31 @@ def query_images_and_classify(self, query, start_time, end_time, limit=30):
# we do not sorty by probabilites, becasue we also need oldest timestamp
return results

def describe_image_scene(self, image):
goal = VQATaskGoal()
goal.compressed_image = image

# unusual objects
if random.randint(0,1) == 1:
goal.questions = ['what unusual things can be seen?']
reaction = 'you saw '
else:
goal.questions = ['what is the atmosphere of this place?']
reaction = 'the atmosphere of the scene was '

# get vqa result
self.vqa_ac.send_goal(goal)
self.vqa_ac.wait_for_result()
result = self.vqa_ac.get_result()
reaction += result.result.result[0].answer
return reaction

def publish_google_chat_card(self, text, filename=None):
def publish_google_chat_card(self, text, space, filename=None):
goal = SendMessageGoal()
goal.text = text
if filename:
goal.cards = [Card(sections=[Section(widgets=[WidgetMarkup(image=Image(localpath=filename))])])]
goal.space = 'spaces/AAAAoTwLBL0'
goal.space = space
rospy.logwarn("send {} to {}".format(goal.text, goal.space))
self.chat_ros_ac.send_goal_and_wait(goal, execute_timeout=rospy.Duration(0.10))

Expand All @@ -305,7 +337,6 @@ def text_to_salience(self, text):
return text

def translate(self, text, dest):
return Translated(text=text, dest=dest, src="en", origin="unknown", pronunciation="unknown")
global translator
loop = 3
while loop > 0:
Expand All @@ -318,11 +349,13 @@ def translate(self, text, dest):
translator = Translator()
loop = loop - 1
return Translated(text=text, dest=dest)


def cb(self, msg):
if msg._type == 'google_chat_ros.msg/MessageEvent':
text = message.message.argument_text.lstrip() or message.message.text.lstrip()
space = 'spaces/AAAAoTwLBL0' ## default space JskRobotBot
if msg._type == 'google_chat_ros/MessageEvent':
text = msg.message.argument_text.lstrip() or msg.message.text.lstrip()
space = msg.space.name
rospy.logwarn("Received chat message '{}'".format(text))

# ask dialogflow for intent
Expand All @@ -342,20 +375,37 @@ def cb(self, msg):
rospy.logwarn("received dialogflow action '{}'".format(result.response.action))
print(result.response)
if result.response.action == 'input.unknown':
self.publish_google_chat_card("🤖")
self.publish_google_chat_card("🤖", space)
elif result.response.action == 'make_reply':
self.publish_google_chat_card("・・・", space)

parameters = yaml.safe_load(result.response.parameters)
startdate=datetime.datetime.now(JST)-datetime.timedelta(hours=24)
duration=datetime.timedelta(hours=24)
if parameters['date']:
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date']), "%Y-%m-%dT%H:%M:%S%z")
duration = datetime.timedelta(hours=24)
if parameters['date-period']:
startdate = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['startDate']), "%Y-%m-%dT%H:%M:%S%z")
duration = datetime.datetime.strptime(re.sub('\+(\d+):(\d+)$', '+\\1\\2',parameters['date-period']['endDate']), "%Y-%m-%dT%H:%M:%S%z") - startdate
Comment on lines +385 to +390
Copy link

@tkmtnt7000 tkmtnt7000 Jun 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that this section cause something wrong with searching database????

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I misunderstood the problem. This section has no problem.

This is because a long search, such as a 24-hour search, will not return results until the search is completed, unless the number of data matching the condition is equal to the number of limit items.

print(startdate)
print(duration)
translated = self.translate(result.response.query, dest="en")
self.make_reply(translated.text, translated.src)
ret = self.make_reply(translated.text, translated.src, startdate=startdate, duration=duration)
if 'filename' in ret:
# upload text first, then upload images
self.publish_google_chat_card(ret['text'], space)
self.publish_google_chat_card('', space, ret['filename'])
else:
self.publish_google_chat_card(ret['text'], space)
else:
self.publish_google_chat_card(result.response.response)
self.publish_google_chat_card(result.response.response, space)

except Exception as e:
rospy.logerr("Callback failed {} {}".format(e, traceback.format_exc()))
self.publish_google_chat_card("💀 {}".format(e))
self.publish_google_chat_card("💀 {}".format(e), space)

if __name__ == '__main__':
rospy.init_node('test', anonymous=True)
ml = MessageListener()
#ml.cb2(0)
#ml.cb2('chair')
rospy.spin()