
Commit

Merge pull request #6 from BerkeleyAutomation/replayer-fix
Replayer fix
KeplerC authored Feb 11, 2024
2 parents 2ba7f04 + 2e33efb commit 2ad5326
Showing 5 changed files with 157 additions and 90 deletions.
1 change: 0 additions & 1 deletion fogros2-rt-x/fogros2_rt_x/dataset_spec.py
@@ -34,7 +34,6 @@
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from fogros2_rt_x_msgs.msg import Step, Observation, Action
from cv_bridge import CvBridge
from std_msgs.msg import MultiArrayLayout, MultiArrayDimension
from sensor_msgs.msg import Image
71 changes: 71 additions & 0 deletions fogros2-rt-x/fogros2_rt_x/recorder.py
@@ -0,0 +1,71 @@
import socket

import rclpy
from rosbag2_py import Recorder, RecordOptions, StorageOptions
from rclpy.node import Node
import time
from std_srvs.srv import Empty
from threading import Thread


class DatasetRecorder(Node):
"""
A class for replaying datasets in ROS2.
Args:
dataset_name (str): The name of the dataset to be replayed.
Attributes:
publisher (Publisher): The publisher for sending step information.
dataset (Dataset): The loaded RLDS dataset.
logger (Logger): The logger for logging information.
feature_spec (DatasetFeatureSpec): The feature specification for the dataset.
episode (Episode): The current episode being replayed.
"""
def __init__(self):
super().__init__("fogros2_rt_x_recorder")

        self.new_episode_notification_service = self.create_service(
            Empty,
            'new_episode_notification_service',
            self.new_episode_notification_service_callback,
        )
self.episode_recorder = Recorder()

self.logger = self.get_logger()
self.episode_counter = 1
self.init_recorder()
self.logger.info("Recording started")

def new_episode_notification_service_callback(self, request, response):
self.logger.info("Received request to start new episode")
self.stop_recorder()
self.start_new_recorder()
return response

def init_recorder(self):
self.start_new_recorder()

    def start_new_recorder(self):
        self.logger.info(f"Starting episode #{self.episode_counter}")
        storage_options = StorageOptions(
            uri=f"rosbags/episode_{self.episode_counter}",
            storage_id="sqlite3"
        )
        record_options = RecordOptions()
        record_options.all = True
        # Record in a background thread so the service callback can return;
        # keep the Thread object itself (Thread.start() returns None).
        self.thread = Thread(
            target=self.episode_recorder.record,
            args=(storage_options, record_options),
        )
        self.thread.start()
        self.episode_counter += 1

def stop_recorder(self):
self.episode_recorder.cancel()

def main(args=None):

rclpy.init(args=args)
node = DatasetRecorder()

rclpy.spin(node)

node.destroy_node()
rclpy.shutdown()


if __name__ == "__main__":
main()
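The recorder exposes episode rollover through a plain std_srvs Empty service, so it can be driven by the replayer below or by hand. A minimal sketch of a manual trigger (the node and variable names here are illustrative; the service name and message type follow the diff above):

import rclpy
from std_srvs.srv import Empty

# Sketch: call the recorder's rollover service from a throwaway node.
rclpy.init()
node = rclpy.create_node("episode_rollover_client")  # illustrative name
client = node.create_client(Empty, "new_episode_notification_service")
if client.wait_for_service(timeout_sec=5.0):
    future = client.call_async(Empty.Request())
    rclpy.spin_until_future_complete(node, future)  # recorder starts a new bag
node.destroy_node()
rclpy.shutdown()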
164 changes: 75 additions & 89 deletions fogros2-rt-x/fogros2_rt_x/replayer.py
@@ -34,13 +34,16 @@
import socket

import rclpy
import rosbag2_py as rosbag
from rosbag2_py import Recorder
from rclpy.node import Node
from .dataset_utils import *
from fogros2_rt_x_msgs.msg import Step, Observation, Action
from .dataset_spec import DatasetFeatureSpec
from .plugins.conf_base import *
from .dataset_spec import DatasetFeatureSpec, FeatureSpec
from .dataset_spec import tf_feature_definition_to_ros_msg_class_str
from .conf_base import *
import time

import tensorflow_datasets as tfds
from std_srvs.srv import Empty

class DatasetReplayer(Node):
"""
@@ -62,8 +65,6 @@ def __init__(self):

self.declare_parameter("dataset_name", "berkeley_fanuc_manipulation")
dataset_name = self.get_parameter("dataset_name").value
self.config = get_dataset_plugin_config_from_str(dataset_name)
self.feature_spec = self.config.get_dataset_feature_spec()

self.declare_parameter("per_episode_interval", 5) # second
self.per_episode_interval = self.get_parameter("per_episode_interval").value
@@ -77,115 +78,103 @@

self.dataset = load_rlds_dataset(dataset_name)
self.logger = self.get_logger()
self.logger.info("Loading Dataset " + str(get_dataset_info([dataset_name])))
self.dataset_info = get_dataset_info([dataset_name])
self.logger.info("Loading Dataset " + str(self.dataset_info))

self.episode = next(iter(self.dataset))
self.dataset_features = self.dataset_info[0][1].features
self.step_features = self.dataset_features["steps"]
self.topics = list()
self.episode_counter = 1
self.init_topics_from_features(self.step_features)

        # create a client for the Empty ROS2 service that starts a new recording episode
self.new_episode_notification_client = self.create_client(Empty, 'new_episode_notification_service')
self.new_episode_notification_req = Empty.Request()

if replay_type == "as_separate_topics":
self.topic_name_to_publisher_dict = dict()
self.topic_name_to_recorder_dict = dict()
self.init_publisher_separate_topics()
elif replay_type == "as_single_topic":
self.init_publisher_single_topic()
elif replay_type == "both":
self.topic_name_to_publisher_dict = dict()
self.init_publisher_separate_topics()
self.init_publisher_single_topic()
else:
raise ValueError(
"Invalid replay_type: "
+ str(replay_type)
+ ". Must be one of: as_separate_topics, as_single_topic."
)

def init_publisher_separate_topics(self):
for observation in self.feature_spec.observation_spec:
publisher = self.create_publisher(
observation.ros_type, observation.ros_topic_name, 10
)
self.topic_name_to_publisher_dict[observation.ros_topic_name] = publisher
def init_topics_from_features(self, features):
for name, tf_feature in features.items():
if isinstance(tf_feature, tfds.features.FeaturesDict):
self.init_topics_from_features(tf_feature)
else:
if tf_feature.shape == () and tf_feature.dtype.is_bool:
self.topics.append(FeatureSpec(name, Scalar(dtype=tf.bool)))
else:
self.topics.append(FeatureSpec(name, tf_feature))

for action in self.feature_spec.action_spec:
def init_publisher_separate_topics(self):
for topic in self.topics:
publisher = self.create_publisher(
action.ros_type, action.ros_topic_name, 10
topic.ros_type, topic.ros_topic_name, 10
)
self.topic_name_to_publisher_dict[action.ros_topic_name] = publisher

for step in self.feature_spec.step_spec:
publisher = self.create_publisher(step.ros_type, step.ros_topic_name, 10)
self.topic_name_to_publisher_dict[step.ros_topic_name] = publisher
self.topic_name_to_publisher_dict[topic.ros_topic_name] = publisher

self.create_timer(
self.per_episode_interval, self.timer_callback_separate_topics
)

def init_publisher_single_topic(self):
self.publisher = self.create_publisher(Step, "step_info", 10)
callback = self.timer_callback_single_topic
self.create_timer(self.per_episode_interval, callback)

def timer_callback_separate_topics(self):
for step in self.episode["steps"]:
for observation in self.feature_spec.observation_spec:
if observation.tf_name not in step["observation"]:
self.logger.warn(
f"Observation {observation.tf_name} not found in step data"
for topic in self.topics:
if topic.tf_name in step:
# Fetch from step data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step[topic.tf_name]
)
continue
msg = observation.convert_tf_tensor_data_to_ros2_msg(
step["observation"][observation.tf_name]
)
self.logger.info(
f"Publishing observation {observation.tf_name} on topic {observation.ros_topic_name}"
)
self.topic_name_to_publisher_dict[observation.ros_topic_name].publish(
msg
)

for action in self.feature_spec.action_spec:
if type(step["action"]) is not dict:
# action is only one tensor/datatype, not a dictionary
msg = action.convert_tf_tensor_data_to_ros2_msg(step["action"])
else:
if action.tf_name not in step["action"]:
self.logger.warn(
f"Action {action.tf_name} not found in step data"
)
continue
msg = action.convert_tf_tensor_data_to_ros2_msg(
step["action"][action.tf_name]
self.logger.info(
f"Publishing step {topic.tf_name} on topic {topic.ros_topic_name}"
)
self.logger.info(
f"Publishing action {action.tf_name} on topic {action.ros_topic_name}"
)
self.topic_name_to_publisher_dict[action.ros_topic_name].publish(msg)

for step_feature in self.feature_spec.step_spec:
if step_feature.tf_name not in step:
self.logger.warn(
f"Step {step_feature.tf_name} not found in step data"
if type(step["observation"]) is dict and topic.tf_name in step["observation"]:
# Fetch from observation data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step["observation"][topic.tf_name]
)
continue
msg = step_feature.convert_tf_tensor_data_to_ros2_msg(
step[step_feature.tf_name]
)
self.logger.info(
f"Publishing step {step_feature.tf_name} on topic {step_feature.ros_topic_name}"
)
self.topic_name_to_publisher_dict[step_feature.ros_topic_name].publish(
msg
)

self.logger.info(
f"Publishing observation {topic.tf_name} on topic {topic.ros_topic_name}"
)
if type(step["action"]) is dict and topic.tf_name in step["action"]:
# Fetch from action data
msg = topic.convert_tf_tensor_data_to_ros2_msg(
step["action"][topic.tf_name]
)
self.logger.info(
f"Publishing action {topic.tf_name} on topic {topic.ros_topic_name}"
)

self.topic_name_to_publisher_dict[topic.ros_topic_name].publish(msg)

self.check_last_step_update_recorder(step)
time.sleep(self.per_step_interval)

self.episode = next(iter(self.dataset))

def timer_callback_single_topic(self):
for step in self.episode["steps"]:
msg = self.feature_spec.convert_tf_step_to_ros2_msg(
step, step["action"], step["observation"]
)
self.publisher.publish(msg)
self.episode = next(iter(self.dataset))
self.create_timer(
self.per_episode_interval, self.timer_callback_separate_topics
)

def check_last_step_update_recorder(self, step):
if step["is_last"]:
self.logger.info(f"End of the current episode")
self.episode_counter += 1
# restart_recorder(self.episode_counter)
# restart_recorder()
while not self.new_episode_notification_client.wait_for_service(timeout_sec=1.0):
self.get_logger().info('Service not available, waiting again...')
self.future = self.new_episode_notification_client.call_async(self.new_episode_notification_req)


def main(args=None):
@@ -196,9 +185,6 @@ def main(args=None):

rclpy.spin(node)

# Destroy the timer attached to the node explicitly
# (optional - otherwise it will be done automatically
# when the garbage collector destroys the node object)
node.destroy_node()
rclpy.shutdown()

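The reworked replayer derives its topic list directly from the dataset's feature dictionary and, on each step's is_last flag, pings the recorder so that bag boundaries track RLDS episode boundaries. A minimal sketch of the RLDS structure it walks (assuming the dataset resolves via plain tfds.load from the Open X-Embodiment bucket; the repo's load_rlds_dataset may locate storage differently):

import tensorflow_datasets as tfds

# Sketch: iterate one RLDS episode the way the replayer does.
# data_dir is an assumption; the repo may resolve the dataset elsewhere.
ds = tfds.load(
    "berkeley_fanuc_manipulation",
    data_dir="gs://gresearch/robotics",
    split="train",
)
episode = next(iter(ds))
for step in episode["steps"]:
    observation = step["observation"]  # dict of features (or a single tensor)
    action = step["action"]
    if step["is_last"]:
        # this is where the replayer calls new_episode_notification_service
        break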
10 changes: 10 additions & 0 deletions fogros2-rt-x/launch/replayer.launch.py
@@ -49,4 +49,14 @@ def generate_launch_description():
)

ld.add_action(replayer_node)

recorder_node = Node(
package="fogros2_rt_x",
executable="recorder",
output="screen",
parameters = [
{"dataset_name": "berkeley_fanuc_manipulation"},
]
)
ld.add_action(recorder_node)
return ld
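The launch file now starts both nodes with the dataset name hard-coded in two places. One way to keep them in sync (a sketch, not part of this PR) is a shared launch argument:

from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node

# Sketch: a single dataset_name argument feeding both nodes.
def generate_launch_description():
    dataset_arg = DeclareLaunchArgument(
        "dataset_name", default_value="berkeley_fanuc_manipulation"
    )
    dataset_name = LaunchConfiguration("dataset_name")
    nodes = [
        Node(
            package="fogros2_rt_x",
            executable=executable,
            output="screen",
            parameters=[{"dataset_name": dataset_name}],
        )
        for executable in ("replayer", "recorder")
    ]
    return LaunchDescription([dataset_arg, *nodes])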
1 change: 1 addition & 0 deletions fogros2-rt-x/setup.py
@@ -48,6 +48,7 @@
entry_points={
"console_scripts": [
"replayer = fogros2_rt_x.replayer:main",
"recorder = fogros2_rt_x.recorder:main",
],
"ros2cli.command": [
"fgr = fogros2_rt_x.cli:FogCommand",
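With the new console script registered, the recorder's entry point can also be invoked without going through ros2 run. A trivial sketch (assuming the package is built and on the Python path):

# Sketch: the console script maps to this function, so a test harness
# can start the recorder node programmatically.
from fogros2_rt_x.recorder import main

if __name__ == "__main__":
    main()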
