From 0b2af9ffaa89cec3e1b27ced0839dc0e07902bd0 Mon Sep 17 00:00:00 2001 From: Russ Tedrake Date: Wed, 23 Aug 2023 21:50:44 -0400 Subject: [PATCH] Adds an initialization event to LcmSubscriberSystem Also supports a new opt-in ability to wait for at least one message to be received during initialization. --- bindings/pydrake/systems/_lcm_extra.py | 21 ++++- bindings/pydrake/systems/lcm_py.cc | 8 +- bindings/pydrake/systems/test/lcm_test.py | 6 +- systems/lcm/BUILD.bazel | 1 + systems/lcm/lcm_subscriber_system.cc | 87 ++++++++++++++++--- systems/lcm/lcm_subscriber_system.h | 52 +++++++++-- .../lcm/test/lcm_subscriber_system_test.cc | 87 +++++++++++++++++++ 7 files changed, 231 insertions(+), 31 deletions(-) diff --git a/bindings/pydrake/systems/_lcm_extra.py b/bindings/pydrake/systems/_lcm_extra.py index ab86afb92bbe..6ef10a187c8a 100644 --- a/bindings/pydrake/systems/_lcm_extra.py +++ b/bindings/pydrake/systems/_lcm_extra.py @@ -39,7 +39,12 @@ def Serialize(self, abstract_value): @staticmethod -def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False): +def _make_lcm_subscriber(channel, + lcm_type, + lcm, + use_cpp_serializer=False, + *, + wait_for_message_on_initialization_timeout=0.0): """Convenience to create an LCM subscriber system with a concrete type. Args: @@ -48,7 +53,16 @@ def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False): lcm: LCM service instance. use_cpp_serializer: Use C++ serializer to interface with LCM converter systems that are implemented in C++. LCM types must be registered - in C++ via `BindCppSerializer`. + in C++ via ``BindCppSerializer``. + wait_for_message_on_initialization_timeout: Configures the behavior of + initialization events (see ``System.ExecuteInitializationEvents`` + and ``Simulator.Initialize``) by specifying the number of seconds + (wall-clock elapsed time) to wait for a new message. If this + timeout is <= 0, initialization will copy any already-received + messages into the Context but will not process any new messages. + If this timeout is > 0, initialization will call + ``lcm.HandleSubscriptions()`` until at least one message is + received or until the timeout. Pass ∞ to wait indefinitely. """ # TODO(eric.cousineau): Make `use_cpp_serializer` be kwarg-only. # N.B. This documentation is actually public, as it is assigned to classes @@ -57,7 +71,8 @@ def _make_lcm_subscriber(channel, lcm_type, lcm, use_cpp_serializer=False): serializer = PySerializer(lcm_type) else: serializer = _Serializer_[lcm_type]() - return LcmSubscriberSystem(channel, serializer, lcm) + return LcmSubscriberSystem(channel, serializer, lcm, + wait_for_message_on_initialization_timeout) @staticmethod diff --git a/bindings/pydrake/systems/lcm_py.cc b/bindings/pydrake/systems/lcm_py.cc index 545ad039b193..88b61ad10adc 100644 --- a/bindings/pydrake/systems/lcm_py.cc +++ b/bindings/pydrake/systems/lcm_py.cc @@ -230,14 +230,16 @@ PYBIND11_MODULE(lcm, m) { py::class_>(m, "LcmSubscriberSystem") .def(py::init, - LcmInterfaceSystem*>(), + LcmInterfaceSystem*, double>(), py::arg("channel"), py::arg("serializer"), py::arg("lcm"), + py::arg("wait_for_message_on_initialization_timeout") = 0.0, // Keep alive, reference: `self` keeps `lcm` alive. py::keep_alive<1, 4>(), doc.LcmSubscriberSystem.ctor.doc) .def(py::init, - DrakeLcmInterface*>(), + std::shared_ptr, DrakeLcmInterface*, + double>(), py::arg("channel"), py::arg("serializer"), py::arg("lcm"), + py::arg("wait_for_message_on_initialization_timeout") = 0.0, // Keep alive, reference: `self` keeps `lcm` alive. py::keep_alive<1, 4>(), doc.LcmSubscriberSystem.ctor.doc) .def("WaitForMessage", &Class::WaitForMessage, diff --git a/bindings/pydrake/systems/test/lcm_test.py b/bindings/pydrake/systems/test/lcm_test.py index 882d371c6b9d..c13fd6ba00aa 100644 --- a/bindings/pydrake/systems/test/lcm_test.py +++ b/bindings/pydrake/systems/test/lcm_test.py @@ -151,7 +151,8 @@ def _process_event(self, dut): def test_subscriber(self): lcm = DrakeLcm() dut = mut.LcmSubscriberSystem.Make( - channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm) + channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm, + wait_for_message_on_initialization_timeout=0.0) model_message = self._model_message() lcm.Publish(channel="TEST_CHANNEL", buffer=model_message.encode()) lcm.HandleSubscriptions(0) @@ -172,7 +173,8 @@ def test_subscriber_cpp(self): lcm = DrakeLcm() dut = mut.LcmSubscriberSystem.Make( channel="TEST_CHANNEL", lcm_type=lcmt_quaternion, lcm=lcm, - use_cpp_serializer=True) + use_cpp_serializer=True, + wait_for_message_on_initialization_timeout=0.0) model_message = self._model_message() lcm.Publish(channel="TEST_CHANNEL", buffer=model_message.encode()) lcm.HandleSubscriptions(0) diff --git a/systems/lcm/BUILD.bazel b/systems/lcm/BUILD.bazel index 80e4134eb5f6..783e3ef2905b 100644 --- a/systems/lcm/BUILD.bazel +++ b/systems/lcm/BUILD.bazel @@ -203,6 +203,7 @@ drake_cc_googletest( name = "lcm_subscriber_system_test", deps = [ ":lcm_subscriber_system", + "//common/test_utilities:expect_throws_message", "//lcm:drake_lcm", "//lcm:lcmt_drake_signal_utils", ], diff --git a/systems/lcm/lcm_subscriber_system.cc b/systems/lcm/lcm_subscriber_system.cc index 736b98d3a4b2..17e01e253e69 100644 --- a/systems/lcm/lcm_subscriber_system.cc +++ b/systems/lcm/lcm_subscriber_system.cc @@ -23,12 +23,19 @@ constexpr int kMagic = 6832; // An arbitrary value. LcmSubscriberSystem::LcmSubscriberSystem( const std::string& channel, std::shared_ptr serializer, - drake::lcm::DrakeLcmInterface* lcm) + drake::lcm::DrakeLcmInterface* lcm, + double wait_for_message_on_initialization_timeout) : channel_(channel), serializer_(std::move(serializer)), - magic_number_{kMagic} { + magic_number_{kMagic}, + // Only capture the lcm pointer if it is required. + lcm_{wait_for_message_on_initialization_timeout > 0 ? lcm : nullptr}, + wait_for_message_on_initialization_timeout_{ + wait_for_message_on_initialization_timeout} { DRAKE_THROW_UNLESS(serializer_ != nullptr); - DRAKE_THROW_UNLESS(lcm != nullptr); + DRAKE_THROW_UNLESS(wait_for_message_on_initialization_timeout_ <= 0 || + lcm_ != nullptr); + DRAKE_THROW_UNLESS(!std::isnan(wait_for_message_on_initialization_timeout)); subscription_ = lcm->Subscribe( channel_, [this](const void* buffer, int size) { @@ -58,6 +65,11 @@ LcmSubscriberSystem::LcmSubscriberSystem( this->DeclareForcedUnrestrictedUpdateEvent( &LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState); + // On initialization, we process any existing received messages and maybe + // wait for new messages. + this->DeclareInitializationUnrestrictedUpdateEvent( + &LcmSubscriberSystem::Initialize); + set_name(make_name(channel_)); } @@ -68,19 +80,21 @@ LcmSubscriberSystem::~LcmSubscriberSystem() { // This function processes the internal received message and store the results // to the abstract states, which include both the message and message counts. -systems::EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState( - const Context&, State* state) const { - AbstractValues& abstract_state = state->get_mutable_abstract_state(); +EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState( + const Context& context, State* state) const { std::lock_guard lock(received_message_mutex_); - if (!received_message_.empty()) { - serializer_->Deserialize( - received_message_.data(), received_message_.size(), - &abstract_state.get_mutable_value(kStateIndexMessage)); + const int context_message_count = GetMessageCount(context); + if (context_message_count == received_message_count_) { + state->SetFrom(context.get_state()); + return EventStatus::DidNothing(); } - abstract_state.get_mutable_value(kStateIndexMessageCount) - .get_mutable_value() = received_message_count_; - - return systems::EventStatus::Succeeded(); + serializer_->Deserialize( + received_message_.data(), received_message_.size(), + &state->get_mutable_abstract_state().get_mutable_value( + kStateIndexMessage)); + state->get_mutable_abstract_state(kStateIndexMessageCount) = + received_message_count_; + return EventStatus::Succeeded(); } int LcmSubscriberSystem::GetMessageCount(const Context& context) const { @@ -204,6 +218,51 @@ int LcmSubscriberSystem::GetInternalMessageCount() const { return received_message_count_; } +EventStatus LcmSubscriberSystem::Initialize(const Context& context, + State* state) const { + // In the default case when waiting is disabled, we'll opportunistically try + // to update our state, but we might return EventStatus::DidNothing(). + if (wait_for_message_on_initialization_timeout_ <= 0.0) { + return ProcessMessageAndStoreToAbstractState(context, state); + } + + // The user has requested to pause initialization until context changes. + // Start by peeking to see if there's already a message waiting. + DRAKE_DEMAND(lcm_ != nullptr); + lcm_->HandleSubscriptions(0 /* timeout_millis */); + EventStatus result = ProcessMessageAndStoreToAbstractState(context, state); + if (result.severity() != EventStatus::kDidNothing) { + return result; + } + + // No message was pending. We'll spin until we get one (or run out of time). + log()->info("Waiting for messages on {}", channel_); + using Clock = std::chrono::steady_clock; + using Duration = std::chrono::duration; + const auto start_time = Clock::now(); + while (Duration(Clock::now() - start_time).count() < + wait_for_message_on_initialization_timeout_) { + // Since the DrakeLcmInterface will not be handling subscriptions during + // this initialization, we must handle them directly here. + lcm_->HandleSubscriptions(1 /* timeout_millis*/); + result = ProcessMessageAndStoreToAbstractState(context, state); + if (result.severity() != EventStatus::kDidNothing) { + log()->info("Received message on {}", channel_); + return result; + } + } + + // We ran out of time. + result = EventStatus::Failed( + this, + fmt::format( + "Timed out without receiving any message on channel {} at url {}", + channel_, lcm_->get_lcm_url())); + // TODO(russt): Once EventStatus are actually propagated, return the status + // instead of throwing it. + throw std::runtime_error(result.message()); +} + } // namespace lcm } // namespace systems } // namespace drake diff --git a/systems/lcm/lcm_subscriber_system.h b/systems/lcm/lcm_subscriber_system.h index e0e66d25137a..2c6714baa5d0 100644 --- a/systems/lcm/lcm_subscriber_system.h +++ b/systems/lcm/lcm_subscriber_system.h @@ -57,30 +57,55 @@ class LcmSubscriberSystem : public LeafSystem { * * @param[in] channel The LCM channel on which to subscribe. * - * @param lcm A non-null pointer to the LCM subsystem to subscribe on. + * @param lcm A non-null pointer to the LCM subsystem to subscribe on. If + * `wait_for_message_on_initialization_timeout > 0`, then the pointer must + * remain valid for the lifetime of the returned system. + * + * @param wait_for_message_on_initialization_timeout Configures the behavior + * of initialization events (see System::ExecuteInitializationEvents() and + * Simulator::Initialize()) by specifying the number of seconds (wall-clock + * elapsed time) to wait for a new message. If this timeout is <= 0, + * initialization will copy any already-received messages into the Context but + * will not process any new messages. If this timeout is > 0, initialization + * will call lcm->HandleSubscriptions() until at least one message is received + * or until the timeout. Pass ∞ to wait indefinitely. */ template static std::unique_ptr Make( - const std::string& channel, drake::lcm::DrakeLcmInterface* lcm) { + const std::string& channel, drake::lcm::DrakeLcmInterface* lcm, + double wait_for_message_on_initialization_timeout = 0.0) { return std::make_unique( - channel, std::make_unique>(), lcm); + channel, std::make_unique>(), lcm, + wait_for_message_on_initialization_timeout); } /** * Constructor that returns a subscriber System that provides message objects - * on its sole abstract-valued output port. The type of the message object is - * determined by the @p serializer. + * on its sole abstract-valued output port. The type of the message object + * is determined by the @p serializer. * * @param[in] channel The LCM channel on which to subscribe. * * @param[in] serializer The serializer that converts between byte vectors * and LCM message objects. Cannot be null. * - * @param lcm A non-null pointer to the LCM subsystem to subscribe on. + * @param lcm A non-null pointer to the LCM subsystem to subscribe on. If + * `wait_for_message_on_initialization_timeout > 0`, then the pointer must + * remain valid for the lifetime of the returned system. + * + * @param wait_for_message_on_initialization_timeout Configures the behavior + * of initialization events (see System::ExecuteInitializationEvents() and + * Simulator::Initialize()) by specifying the number of seconds (wall-clock + * elapsed time) to wait for a new message. If this timeout is <= 0, + * initialization will copy any already-received messages into the Context but + * will not process any new messages. If this timeout is > 0, initialization + * will call lcm->HandleSubscriptions() until at least one message is received + * or until the timeout. Pass ∞ to wait indefinitely. */ LcmSubscriberSystem(const std::string& channel, std::shared_ptr serializer, - drake::lcm::DrakeLcmInterface* lcm); + drake::lcm::DrakeLcmInterface* lcm, + double wait_for_message_on_initialization_timeout = 0.0); ~LcmSubscriberSystem() override; @@ -126,8 +151,10 @@ class LcmSubscriberSystem : public LeafSystem { systems::CompositeEventCollection* events, double* time) const final; - systems::EventStatus ProcessMessageAndStoreToAbstractState( - const Context&, State* state) const; + EventStatus ProcessMessageAndStoreToAbstractState(const Context&, + State* state) const; + + EventStatus Initialize(const Context&, State* state) const; // The channel on which to receive LCM messages. const std::string channel_; @@ -154,6 +181,13 @@ class LcmSubscriberSystem : public LeafSystem { // A little hint to help catch use-after-free. int magic_number_{}; + + // The lcm interface is (maybe) used to handle subscriptions during + // Initialization. + drake::lcm::DrakeLcmInterface* const lcm_; + + // A timeout in seconds. + const double wait_for_message_on_initialization_timeout_; }; } // namespace lcm diff --git a/systems/lcm/test/lcm_subscriber_system_test.cc b/systems/lcm/test/lcm_subscriber_system_test.cc index e192c1f0c11e..97a59b8f495b 100644 --- a/systems/lcm/test/lcm_subscriber_system_test.cc +++ b/systems/lcm/test/lcm_subscriber_system_test.cc @@ -5,6 +5,7 @@ #include +#include "drake/common/test_utilities/expect_throws_message.h" #include "drake/lcm/drake_lcm.h" #include "drake/lcm/lcmt_drake_signal_utils.h" #include "drake/lcmt_drake_signal.hpp" @@ -102,6 +103,92 @@ GTEST_TEST(LcmSubscriberSystemTest, ReceiveTest) { EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value)); } +// Tests LcmSubscriberSystem using a Serializer. +GTEST_TEST(LcmSubscriberSystemTest, InitializationNoWaitTest) { + drake::lcm::DrakeLcm lcm; + const std::string channel_name = "channel_name"; + + // The "device under test". + auto dut = LcmSubscriberSystem::Make(channel_name, &lcm); + + // Establish the context and output for the dut. + std::unique_ptr> context = dut->CreateDefaultContext(); + std::unique_ptr> output = dut->AllocateOutput(); + + // Produce a sample message. + SampleData sample_data; + // Publish, but do not call handle. + Publish(&lcm, channel_name, sample_data.value); + + // Fire the initialization event. It should NOT process the message. + dut->ExecuteInitializationEvents(context.get()); + dut->CalcOutput(*context, output.get()); + const AbstractValue* abstract_value = output->get_data(0); + ASSERT_NE(abstract_value, nullptr); + auto value = abstract_value->get_value(); + EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, lcmt_drake_signal{})); + + // Receive the message. + lcm.HandleSubscriptions(0); + + // Now the initialization event should process the message. + dut->ExecuteInitializationEvents(context.get()); + dut->CalcOutput(*context, output.get()); + abstract_value = output->get_data(0); + ASSERT_NE(abstract_value, nullptr); + value = abstract_value->get_value(); + EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value)); +} + +GTEST_TEST(LcmSubscriberSystemTest, InitializationWithWaitTest) { + drake::lcm::DrakeLcm lcm; + const std::string channel_name = "channel_name"; + const double wait_for_message_on_initialization_timeout{0.01}; + + // The "device under test". + auto dut = LcmSubscriberSystem::Make( + channel_name, &lcm, wait_for_message_on_initialization_timeout); + + // Establish the context and output for the dut. + std::unique_ptr> context = dut->CreateDefaultContext(); + std::unique_ptr> output = dut->AllocateOutput(); + + // The initialization event will fail (timeout) if no message is received. + DRAKE_EXPECT_THROWS_MESSAGE(dut->ExecuteInitializationEvents(context.get()), + "Timed out without receiving any message on " + "channel channel_name at url.*"); + + // Produce a sample message. + SampleData sample_data; + // Publish, but do not call handle. + Publish(&lcm, channel_name, sample_data.value); + + // Now the initialization event calls handle and obtains the message. + dut->ExecuteInitializationEvents(context.get()); + dut->CalcOutput(*context, output.get()); + const AbstractValue* abstract_value = output->get_data(0); + ASSERT_NE(abstract_value, nullptr); + auto value = abstract_value->get_value(); + EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value)); + + // A second initialization event will fail (timeout) with no *new* message. + DRAKE_EXPECT_THROWS_MESSAGE(dut->ExecuteInitializationEvents(context.get()), + "Timed out without receiving any message on " + "channel channel_name at url.*"); + + // Publish, but do not call handle, with a new message. + sample_data.value.timestamp += 1; + Publish(&lcm, channel_name, sample_data.value); + + // Now the initialization event calls handle and obtains the message. + dut->ExecuteInitializationEvents(context.get()); + dut->CalcOutput(*context, output.get()); + abstract_value = output->get_data(0); + ASSERT_NE(abstract_value, nullptr); + value = abstract_value->get_value(); + EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value)); +} + GTEST_TEST(LcmSubscriberSystemTest, WaitTest) { // Ensure that `WaitForMessage` works as expected. drake::lcm::DrakeLcm lcm;