Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replayer fix #6

Merged
merged 6 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion fogros2-rt-x/fogros2_rt_x/dataset_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from fogros2_rt_x_msgs.msg import Step, Observation, Action
from cv_bridge import CvBridge
from std_msgs.msg import MultiArrayLayout, MultiArrayDimension
from sensor_msgs.msg import Image
Expand Down
71 changes: 71 additions & 0 deletions fogros2-rt-x/fogros2_rt_x/recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import socket

import rclpy
from rosbag2_py import Recorder, RecordOptions, StorageOptions
from rclpy.node import Node
import time
from std_srvs.srv import Empty
from threading import Thread


class DatasetRecorder(Node):
"""
A class for replaying datasets in ROS2.

Args:
dataset_name (str): The name of the dataset to be replayed.

Attributes:
publisher (Publisher): The publisher for sending step information.
dataset (Dataset): The loaded RLDS dataset.
logger (Logger): The logger for logging information.
feature_spec (DatasetFeatureSpec): The feature specification for the dataset.
episode (Episode): The current episode being replayed.
"""
def __init__(self):
super().__init__("fogros2_rt_x_recorder")

self.new_episode_notification_service = self.create_service(Empty, 'new_episode_notification_service', self.new_episode_notification_service_callback)
self.episode_recorder = Recorder()

self.logger = self.get_logger()
self.episode_counter = 1
self.init_recorder()
self.logger.info("Recording started")

def new_episode_notification_service_callback(self, request, response):
self.logger.info("Received request to start new episode")
self.stop_recorder()
self.start_new_recorder()
return response

def init_recorder(self):
self.start_new_recorder()

def start_new_recorder(self):
self.logger.info(f"starting episode #: {self.episode_counter}")
storage_options = StorageOptions(
uri=f"rosbags/episode_{self.episode_counter}",
storage_id="sqlite3"
)
record_options = RecordOptions()
record_options.all = True
self.thread = Thread(target=self.episode_recorder.record, args=(storage_options, record_options)).start()
self.episode_counter += 1

def stop_recorder(self):
self.episode_recorder.cancel()

def main(args=None):

rclpy.init(args=args)
node = DatasetRecorder()

rclpy.spin(node)

node.destroy_node()
rclpy.shutdown()


if __name__ == "__main__":
main()
164 changes: 75 additions & 89 deletions fogros2-rt-x/fogros2_rt_x/replayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@
import socket

import rclpy
import rosbag2_py as rosbag
from rosbag2_py import Recorder
from rclpy.node import Node
from .dataset_utils import *
from fogros2_rt_x_msgs.msg import Step, Observation, Action
from .dataset_spec import DatasetFeatureSpec
from .plugins.conf_base import *
from .dataset_spec import DatasetFeatureSpec, FeatureSpec
from .dataset_spec import tf_feature_definition_to_ros_msg_class_str
from .conf_base import *
import time

import tensorflow_datasets as tfds
from std_srvs.srv import Empty

class DatasetReplayer(Node):
"""
Expand All @@ -62,8 +65,6 @@ def __init__(self):

self.declare_parameter("dataset_name", "berkeley_fanuc_manipulation")
dataset_name = self.get_parameter("dataset_name").value
self.config = get_dataset_plugin_config_from_str(dataset_name)
self.feature_spec = self.config.get_dataset_feature_spec()

self.declare_parameter("per_episode_interval", 5) # second
self.per_episode_interval = self.get_parameter("per_episode_interval").value
Expand All @@ -77,115 +78,103 @@ def __init__(self):

self.dataset = load_rlds_dataset(dataset_name)
self.logger = self.get_logger()
self.logger.info("Loading Dataset " + str(get_dataset_info([dataset_name])))
self.dataset_info = get_dataset_info([dataset_name])
self.logger.info("Loading Dataset " + str(self.dataset_info))

self.episode = next(iter(self.dataset))
self.dataset_features = self.dataset_info[0][1].features
self.step_features = self.dataset_features["steps"]
self.topics = list()
self.episode_counter = 1
self.init_topics_from_features(self.step_features)

# create an empty ros2 servcice to start and stop recording
self.new_episode_notification_client = self.create_client(Empty, 'new_episode_notification_service')
self.new_episode_notification_req = Empty.Request()

if replay_type == "as_separate_topics":
self.topic_name_to_publisher_dict = dict()
self.topic_name_to_recorder_dict = dict()
self.init_publisher_separate_topics()
elif replay_type == "as_single_topic":
self.init_publisher_single_topic()
elif replay_type == "both":
self.topic_name_to_publisher_dict = dict()
self.init_publisher_separate_topics()
self.init_publisher_single_topic()
else:
raise ValueError(
"Invalid replay_type: "
+ str(replay_type)
+ ". Must be one of: as_separate_topics, as_single_topic."
)

def init_publisher_separate_topics(self):
for observation in self.feature_spec.observation_spec:
publisher = self.create_publisher(
observation.ros_type, observation.ros_topic_name, 10
)
self.topic_name_to_publisher_dict[observation.ros_topic_name] = publisher
def init_topics_from_features(self, features):
for name, tf_feature in features.items():
if isinstance(tf_feature, tfds.features.FeaturesDict):
self.init_topics_from_features(tf_feature)
else:
if tf_feature.shape == () and tf_feature.dtype.is_bool:
self.topics.append(FeatureSpec(name, Scalar(dtype=tf.bool)))
else:
self.topics.append(FeatureSpec(name, tf_feature))





for action in self.feature_spec.action_spec:
def init_publisher_separate_topics(self):
for topic in self.topics:
publisher = self.create_publisher(
action.ros_type, action.ros_topic_name, 10
topic.ros_type, topic.ros_topic_name, 10
)
self.topic_name_to_publisher_dict[action.ros_topic_name] = publisher

for step in self.feature_spec.step_spec:
publisher = self.create_publisher(step.ros_type, step.ros_topic_name, 10)
self.topic_name_to_publisher_dict[step.ros_topic_name] = publisher
self.topic_name_to_publisher_dict[topic.ros_topic_name] = publisher

self.create_timer(
self.per_episode_interval, self.timer_callback_separate_topics
)

def init_publisher_single_topic(self):
self.publisher = self.create_publisher(Step, "step_info", 10)
callback = self.timer_callback_single_topic
self.create_timer(self.per_episode_interval, callback)

def timer_callback_separate_topics(self):
for step in self.episode["steps"]:
for observation in self.feature_spec.observation_spec:
if observation.tf_name not in step["observation"]:
self.logger.warn(
f"Observation {observation.tf_name} not found in step data"
for topic in self.topics:
if topic.tf_name in step:
# Fetch from step data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step[topic.tf_name]
)
continue
msg = observation.convert_tf_tensor_data_to_ros2_msg(
step["observation"][observation.tf_name]
)
self.logger.info(
f"Publishing observation {observation.tf_name} on topic {observation.ros_topic_name}"
)
self.topic_name_to_publisher_dict[observation.ros_topic_name].publish(
msg
)

for action in self.feature_spec.action_spec:
if type(step["action"]) is not dict:
# action is only one tensor/datatype, not a dictionary
msg = action.convert_tf_tensor_data_to_ros2_msg(step["action"])
else:
if action.tf_name not in step["action"]:
self.logger.warn(
f"Action {action.tf_name} not found in step data"
)
continue
msg = action.convert_tf_tensor_data_to_ros2_msg(
step["action"][action.tf_name]
self.logger.info(
f"Publishing step {topic.tf_name} on topic {topic.ros_topic_name}"
)
self.logger.info(
f"Publishing action {action.tf_name} on topic {action.ros_topic_name}"
)
self.topic_name_to_publisher_dict[action.ros_topic_name].publish(msg)

for step_feature in self.feature_spec.step_spec:
if step_feature.tf_name not in step:
self.logger.warn(
f"Step {step_feature.tf_name} not found in step data"
if type(step["observation"]) is dict and topic.tf_name in step["observation"]:
# Fetch from observation data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step["observation"][topic.tf_name]
)
continue
msg = step_feature.convert_tf_tensor_data_to_ros2_msg(
step[step_feature.tf_name]
)
self.logger.info(
f"Publishing step {step_feature.tf_name} on topic {step_feature.ros_topic_name}"
)
self.topic_name_to_publisher_dict[step_feature.ros_topic_name].publish(
msg
)

self.logger.info(
f"Publishing observation {topic.tf_name} on topic {topic.ros_topic_name}"
)
if type(step["action"]) is dict and topic.tf_name in step["action"]:
# Fetch from action data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step["action"][topic.tf_name]
)
self.logger.info(
f"Publishing action {topic.tf_name} on topic {topic.ros_topic_name}"
)

self.topic_name_to_publisher_dict[topic.ros_topic_name].publish(msg)

self.check_last_step_update_recorder(step)
time.sleep(self.per_step_interval)

self.episode = next(iter(self.dataset))

def timer_callback_single_topic(self):
for step in self.episode["steps"]:
msg = self.feature_spec.convert_tf_step_to_ros2_msg(
step, step["action"], step["observation"]
)
self.publisher.publish(msg)
self.episode = next(iter(self.dataset))
self.create_timer(
self.per_episode_interval, self.timer_callback_separate_topics
)

def check_last_step_update_recorder(self, step):
if step["is_last"]:
self.logger.info(f"End of the current episode")
self.episode_counter += 1
# restart_recorder(self.episode_counter)
# restart_recorder()
while not self.new_episode_notification_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('Service not available, waiting again...')
self.future = self.new_episode_notification_client.call_async(self.new_episode_notification_req)


def main(args=None):
Expand All @@ -196,9 +185,6 @@ def main(args=None):

rclpy.spin(node)

# Destroy the timer attached to the node explicitly
# (optional - otherwise it will be done automatically
# when the garbage collector destroys the node object)
node.destroy_node()
rclpy.shutdown()

Expand Down
10 changes: 10 additions & 0 deletions fogros2-rt-x/launch/replayer.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,14 @@ def generate_launch_description():
)

ld.add_action(replayer_node)

recorder_node = Node(
package="fogros2_rt_x",
executable="recorder",
output="screen",
parameters = [
{"dataset_name": "berkeley_fanuc_manipulation"},
]
)
ld.add_action(recorder_node)
return ld
1 change: 1 addition & 0 deletions fogros2-rt-x/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
entry_points={
"console_scripts": [
"replayer = fogros2_rt_x.replayer:main",
"recorder = fogros2_rt_x.recorder:main",
],
"ros2cli.command": [
"fgr = fogros2_rt_x.cli:FogCommand",
Expand Down
Loading