Skip to content

Commit

Permalink
Adds an initialization event to LcmSubscriberSystem (#20072)
Browse files Browse the repository at this point in the history
Also supports a new opt-in ability to wait for at least one message to
be received during initialization.
  • Loading branch information
RussTedrake authored Aug 28, 2023
1 parent 5c1284b commit 086e93a
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 30 deletions.
21 changes: 18 additions & 3 deletions bindings/pydrake/systems/_lcm_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def Serialize(self, abstract_value):


@staticmethod
def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False):
def _make_lcm_subscriber(channel,
lcm_type,
lcm,
use_cpp_serializer=False,
*,
wait_for_message_on_initialization_timeout=0.0):
"""Convenience to create an LCM subscriber system with a concrete type.
Args:
Expand All @@ -48,7 +53,16 @@ def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False):
lcm: LCM service instance.
use_cpp_serializer: Use C++ serializer to interface with LCM converter
systems that are implemented in C++. LCM types must be registered
in C++ via `BindCppSerializer`.
in C++ via ``BindCppSerializer``.
wait_for_message_on_initialization_timeout: Configures the behavior of
initialization events (see ``System.ExecuteInitializationEvents``
and ``Simulator.Initialize``) by specifying the number of seconds
(wall-clock elapsed time) to wait for a new message. If this
timeout is <= 0, initialization will copy any already-received
messages into the Context but will not process any new messages.
If this timeout is > 0, initialization will call
``lcm.HandleSubscriptions()`` until at least one message is
received or until the timeout. Pass ∞ to wait indefinitely.
"""
# TODO(eric.cousineau): Make `use_cpp_serializer` be kwarg-only.
# N.B. This documentation is actually public, as it is assigned to classes
Expand All @@ -57,7 +71,8 @@ def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False):
serializer = PySerializer(lcm_type)
else:
serializer = _Serializer_[lcm_type]()
return LcmSubscriberSystem(channel, serializer, lcm)
return LcmSubscriberSystem(channel, serializer, lcm,
wait_for_message_on_initialization_timeout)


@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions bindings/pydrake/systems/lcm_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,16 @@ PYBIND11_MODULE(lcm, m) {
py::class_<Class, LeafSystem<double>>(m, "LcmSubscriberSystem")
.def(py::init<const std::string&,
std::shared_ptr<const SerializerInterface>,
LcmInterfaceSystem*>(),
LcmInterfaceSystem*, double>(),
py::arg("channel"), py::arg("serializer"), py::arg("lcm"),
py::arg("wait_for_message_on_initialization_timeout") = 0.0,
// Keep alive, reference: `self` keeps `lcm` alive.
py::keep_alive<1, 4>(), doc.LcmSubscriberSystem.ctor.doc)
.def(py::init<const std::string&,
std::shared_ptr<const SerializerInterface>,
DrakeLcmInterface*>(),
std::shared_ptr<const SerializerInterface>, DrakeLcmInterface*,
double>(),
py::arg("channel"), py::arg("serializer"), py::arg("lcm"),
py::arg("wait_for_message_on_initialization_timeout") = 0.0,
// Keep alive, reference: `self` keeps `lcm` alive.
py::keep_alive<1, 4>(), doc.LcmSubscriberSystem.ctor.doc)
.def("WaitForMessage", &Class::WaitForMessage,
Expand Down
6 changes: 4 additions & 2 deletions bindings/pydrake/systems/test/lcm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def _process_event(self, dut):
def test_subscriber(self):
lcm = DrakeLcm()
dut = mut.LcmSubscriberSystem.Make(
channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm)
channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm,
wait_for_message_on_initialization_timeout=0.0)
model_message = self._model_message()
lcm.Publish(channel="TEST_CHANNEL", buffer=model_message.encode())
lcm.HandleSubscriptions(0)
Expand All @@ -172,7 +173,8 @@ def test_subscriber_cpp(self):
lcm = DrakeLcm()
dut = mut.LcmSubscriberSystem.Make(
channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm,
use_cpp_serializer=True)
use_cpp_serializer=True,
wait_for_message_on_initialization_timeout=0.0)
model_message = self._model_message()
lcm.Publish(channel="TEST_CHANNEL", buffer=model_message.encode())
lcm.HandleSubscriptions(0)
Expand Down
1 change: 1 addition & 0 deletions systems/lcm/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ drake_cc_googletest(
name = "lcm_subscriber_system_test",
deps = [
":lcm_subscriber_system",
"//common/test_utilities:expect_throws_message",
"//lcm:drake_lcm",
"//lcm:lcmt_drake_signal_utils",
],
Expand Down
84 changes: 71 additions & 13 deletions systems/lcm/lcm_subscriber_system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ constexpr int kMagic = 6832; // An arbitrary value.
LcmSubscriberSystem::LcmSubscriberSystem(
const std::string& channel,
std::shared_ptr<const SerializerInterface> serializer,
drake::lcm::DrakeLcmInterface* lcm)
drake::lcm::DrakeLcmInterface* lcm,
double wait_for_message_on_initialization_timeout)
: channel_(channel),
serializer_(std::move(serializer)),
magic_number_{kMagic} {
magic_number_{kMagic},
// Only capture the lcm pointer if it is required.
lcm_{wait_for_message_on_initialization_timeout > 0 ? lcm : nullptr},
wait_for_message_on_initialization_timeout_{
wait_for_message_on_initialization_timeout} {
DRAKE_THROW_UNLESS(serializer_ != nullptr);
DRAKE_THROW_UNLESS(lcm != nullptr);
DRAKE_THROW_UNLESS(!std::isnan(wait_for_message_on_initialization_timeout));

subscription_ = lcm->Subscribe(
channel_, [this](const void* buffer, int size) {
Expand Down Expand Up @@ -58,6 +64,11 @@ LcmSubscriberSystem::LcmSubscriberSystem(
this->DeclareForcedUnrestrictedUpdateEvent(
&LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState);

// On initialization, we process any existing received messages and maybe
// wait for new messages.
this->DeclareInitializationUnrestrictedUpdateEvent(
&LcmSubscriberSystem::Initialize);

set_name(make_name(channel_));
}

Expand All @@ -68,19 +79,21 @@ LcmSubscriberSystem::~LcmSubscriberSystem() {

// This function processes the internal received message and store the results
// to the abstract states, which include both the message and message counts.
systems::EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState(
const Context<double>&, State<double>* state) const {
AbstractValues& abstract_state = state->get_mutable_abstract_state();
EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState(
const Context<double>& context, State<double>* state) const {
std::lock_guard<std::mutex> lock(received_message_mutex_);
if (!received_message_.empty()) {
serializer_->Deserialize(
received_message_.data(), received_message_.size(),
&abstract_state.get_mutable_value(kStateIndexMessage));
const int context_message_count = GetMessageCount(context);
if (context_message_count == received_message_count_) {
state->SetFrom(context.get_state());
return EventStatus::DidNothing();
}
abstract_state.get_mutable_value(kStateIndexMessageCount)
.get_mutable_value<int>() = received_message_count_;

return systems::EventStatus::Succeeded();
serializer_->Deserialize(
received_message_.data(), received_message_.size(),
&state->get_mutable_abstract_state().get_mutable_value(
kStateIndexMessage));
state->get_mutable_abstract_state<int>(kStateIndexMessageCount) =
received_message_count_;
return EventStatus::Succeeded();
}

int LcmSubscriberSystem::GetMessageCount(const Context<double>& context) const {
Expand Down Expand Up @@ -204,6 +217,51 @@ int LcmSubscriberSystem::GetInternalMessageCount() const {
return received_message_count_;
}

EventStatus LcmSubscriberSystem::Initialize(const Context<double>& context,
State<double>* state) const {
// In the default case when waiting is disabled, we'll opportunistically try
// to update our state, but we might return EventStatus::DidNothing().
if (wait_for_message_on_initialization_timeout_ <= 0.0) {
return ProcessMessageAndStoreToAbstractState(context, state);
}

// The user has requested to pause initialization until context changes.
// Start by peeking to see if there's already a message waiting.
DRAKE_DEMAND(lcm_ != nullptr);
lcm_->HandleSubscriptions(0 /* timeout_millis */);
EventStatus result = ProcessMessageAndStoreToAbstractState(context, state);
if (result.severity() != EventStatus::kDidNothing) {
return result;
}

// No message was pending. We'll spin until we get one (or run out of time).
log()->info("Waiting for messages on {}", channel_);
using Clock = std::chrono::steady_clock;
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();
while (Duration(Clock::now() - start_time).count() <
wait_for_message_on_initialization_timeout_) {
// Since the DrakeLcmInterface will not be handling subscriptions during
// this initialization, we must handle them directly here.
lcm_->HandleSubscriptions(1 /* timeout_millis*/);
result = ProcessMessageAndStoreToAbstractState(context, state);
if (result.severity() != EventStatus::kDidNothing) {
log()->info("Received message on {}", channel_);
return result;
}
}

// We ran out of time.
result = EventStatus::Failed(
this,
fmt::format(
"Timed out without receiving any message on channel {} at url {}",
channel_, lcm_->get_lcm_url()));
// TODO(russt): Once EventStatus are actually propagated, return the status
// instead of throwing it.
throw std::runtime_error(result.message());
}

} // namespace lcm
} // namespace systems
} // namespace drake
Expand Down
52 changes: 43 additions & 9 deletions systems/lcm/lcm_subscriber_system.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,55 @@ class LcmSubscriberSystem : public LeafSystem<double> {
*
* @param[in] channel The LCM channel on which to subscribe.
*
* @param lcm A non-null pointer to the LCM subsystem to subscribe on.
* @param lcm A non-null pointer to the LCM subsystem to subscribe on. If
* `wait_for_message_on_initialization_timeout > 0`, then the pointer must
* remain valid for the lifetime of the returned system.
*
* @param wait_for_message_on_initialization_timeout Configures the behavior
* of initialization events (see System::ExecuteInitializationEvents() and
* Simulator::Initialize()) by specifying the number of seconds (wall-clock
* elapsed time) to wait for a new message. If this timeout is <= 0,
* initialization will copy any already-received messages into the Context but
* will not process any new messages. If this timeout is > 0, initialization
* will call lcm->HandleSubscriptions() until at least one message is received
* or until the timeout. Pass ∞ to wait indefinitely.
*/
template <typename LcmMessage>
static std::unique_ptr<LcmSubscriberSystem> Make(
const std::string& channel, drake::lcm::DrakeLcmInterface* lcm) {
const std::string& channel, drake::lcm::DrakeLcmInterface* lcm,
double wait_for_message_on_initialization_timeout = 0.0) {
return std::make_unique<LcmSubscriberSystem>(
channel, std::make_unique<Serializer<LcmMessage>>(), lcm);
channel, std::make_unique<Serializer<LcmMessage>>(), lcm,
wait_for_message_on_initialization_timeout);
}

/**
* Constructor that returns a subscriber System that provides message objects
* on its sole abstract-valued output port. The type of the message object is
* determined by the @p serializer.
* on its sole abstract-valued output port. The type of the message object
* is determined by the @p serializer.
*
* @param[in] channel The LCM channel on which to subscribe.
*
* @param[in] serializer The serializer that converts between byte vectors
* and LCM message objects. Cannot be null.
*
* @param lcm A non-null pointer to the LCM subsystem to subscribe on.
* @param lcm A non-null pointer to the LCM subsystem to subscribe on. If
* `wait_for_message_on_initialization_timeout > 0`, then the pointer must
* remain valid for the lifetime of the returned system.
*
* @param wait_for_message_on_initialization_timeout Configures the behavior
* of initialization events (see System::ExecuteInitializationEvents() and
* Simulator::Initialize()) by specifying the number of seconds (wall-clock
* elapsed time) to wait for a new message. If this timeout is <= 0,
* initialization will copy any already-received messages into the Context but
* will not process any new messages. If this timeout is > 0, initialization
* will call lcm->HandleSubscriptions() until at least one message is received
* or until the timeout. Pass ∞ to wait indefinitely.
*/
LcmSubscriberSystem(const std::string& channel,
std::shared_ptr<const SerializerInterface> serializer,
drake::lcm::DrakeLcmInterface* lcm);
drake::lcm::DrakeLcmInterface* lcm,
double wait_for_message_on_initialization_timeout = 0.0);

~LcmSubscriberSystem() override;

Expand Down Expand Up @@ -126,8 +151,10 @@ class LcmSubscriberSystem : public LeafSystem<double> {
systems::CompositeEventCollection<double>* events,
double* time) const final;

systems::EventStatus ProcessMessageAndStoreToAbstractState(
const Context<double>&, State<double>* state) const;
EventStatus ProcessMessageAndStoreToAbstractState(const Context<double>&,
State<double>* state) const;

EventStatus Initialize(const Context<double>&, State<double>* state) const;

// The channel on which to receive LCM messages.
const std::string channel_;
Expand All @@ -154,6 +181,13 @@ class LcmSubscriberSystem : public LeafSystem<double> {

// A little hint to help catch use-after-free.
int magic_number_{};

// The lcm interface is (maybe) used to handle subscriptions during
// Initialization.
drake::lcm::DrakeLcmInterface* const lcm_;

// A timeout in seconds.
const double wait_for_message_on_initialization_timeout_;
};

} // namespace lcm
Expand Down
87 changes: 87 additions & 0 deletions systems/lcm/test/lcm_subscriber_system_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <gtest/gtest.h>

#include "drake/common/test_utilities/expect_throws_message.h"
#include "drake/lcm/drake_lcm.h"
#include "drake/lcm/lcmt_drake_signal_utils.h"
#include "drake/lcmt_drake_signal.hpp"
Expand Down Expand Up @@ -102,6 +103,92 @@ GTEST_TEST(LcmSubscriberSystemTest, ReceiveTest) {
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));
}

// Tests LcmSubscriberSystem using a Serializer.
GTEST_TEST(LcmSubscriberSystemTest, InitializationNoWaitTest) {
drake::lcm::DrakeLcm lcm;
const std::string channel_name = "channel_name";

// The "device under test".
auto dut = LcmSubscriberSystem::Make<lcmt_drake_signal>(channel_name, &lcm);

// Establish the context and output for the dut.
std::unique_ptr<Context<double>> context = dut->CreateDefaultContext();
std::unique_ptr<SystemOutput<double>> output = dut->AllocateOutput();

// Produce a sample message.
SampleData sample_data;
// Publish, but do not call handle.
Publish(&lcm, channel_name, sample_data.value);

// Fire the initialization event. It should NOT process the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
const AbstractValue* abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
auto value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, lcmt_drake_signal{}));

// Receive the message.
lcm.HandleSubscriptions(0);

// Now the initialization event should process the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));
}

GTEST_TEST(LcmSubscriberSystemTest, InitializationWithWaitTest) {
drake::lcm::DrakeLcm lcm;
const std::string channel_name = "channel_name";
const double wait_for_message_on_initialization_timeout{0.01};

// The "device under test".
auto dut = LcmSubscriberSystem::Make<lcmt_drake_signal>(
channel_name, &lcm, wait_for_message_on_initialization_timeout);

// Establish the context and output for the dut.
std::unique_ptr<Context<double>> context = dut->CreateDefaultContext();
std::unique_ptr<SystemOutput<double>> output = dut->AllocateOutput();

// The initialization event will fail (timeout) if no message is received.
DRAKE_EXPECT_THROWS_MESSAGE(dut->ExecuteInitializationEvents(context.get()),
"Timed out without receiving any message on "
"channel channel_name at url.*");

// Produce a sample message.
SampleData sample_data;
// Publish, but do not call handle.
Publish(&lcm, channel_name, sample_data.value);

// Now the initialization event calls handle and obtains the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
const AbstractValue* abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
auto value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));

// A second initialization event will fail (timeout) with no *new* message.
DRAKE_EXPECT_THROWS_MESSAGE(dut->ExecuteInitializationEvents(context.get()),
"Timed out without receiving any message on "
"channel channel_name at url.*");

// Publish, but do not call handle, with a new message.
sample_data.value.timestamp += 1;
Publish(&lcm, channel_name, sample_data.value);

// Now the initialization event calls handle and obtains the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));
}

GTEST_TEST(LcmSubscriberSystemTest, WaitTest) {
// Ensure that `WaitForMessage` works as expected.
drake::lcm::DrakeLcm lcm;
Expand Down

0 comments on commit 086e93a

Please sign in to comment.