Skip to content

Commit

Permalink
fixup! reinitialization
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri committed Aug 26, 2023
1 parent 5ec8aca commit f8e1ea2
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 70 deletions.
16 changes: 9 additions & 7 deletions bindings/pydrake/systems/_lcm_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ def _make_lcm_subscriber(channel,
lcm: LCM service instance.
use_cpp_serializer: Use C++ serializer to interface with LCM converter
systems that are implemented in C++. LCM types must be registered
in C++ via `BindCppSerializer`.
wait_for_message_on_initialization_timeout: The number of seconds
(wall-clock elapsed time) to wait for GetInternalMessageCount() to
be > 0. If this timeout is <= 0, then the initialization event does
not handle any new messages, but only processes existing received
messages. If the timeout is > 0, then the initialization event will
call lcm.HandleSubscriptions() until at least one message is
in C++ via ``BindCppSerializer``.
wait_for_message_on_initialization_timeout: Configures the behavior of
initialization events (see ``System.ExecuteInitializationEvents``
and ``Simulator.Initialize``) by specifying the number of seconds
(wall-clock elapsed time) to wait for a new message. If this
timeout is <= 0, initialization will copy any already-received
messages into the Context but will not process any new messages.
If this timeout is > 0, initialization will call
``lcm.HandleSubscriptions()`` until at least one message is
received or until the timeout. Pass ∞ to wait indefinitely.
"""
# TODO(eric.cousineau): Make `use_cpp_serializer` be kwarg-only.
Expand Down
91 changes: 50 additions & 41 deletions systems/lcm/lcm_subscriber_system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,21 @@ LcmSubscriberSystem::~LcmSubscriberSystem() {

// This function processes the internal received message and store the results
// to the abstract states, which include both the message and message counts.
systems::EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState(
const Context<double>&, State<double>* state) const {
AbstractValues& abstract_state = state->get_mutable_abstract_state();
EventStatus LcmSubscriberSystem::ProcessMessageAndStoreToAbstractState(
const Context<double>& context, State<double>* state) const {
std::lock_guard<std::mutex> lock(received_message_mutex_);
if (!received_message_.empty()) {
serializer_->Deserialize(
received_message_.data(), received_message_.size(),
&abstract_state.get_mutable_value(kStateIndexMessage));
const int context_message_count = GetMessageCount(context);
if (context_message_count == received_message_count_) {
state->SetFrom(context.get_state());
return EventStatus::DidNothing();
}
abstract_state.get_mutable_value(kStateIndexMessageCount)
.get_mutable_value<int>() = received_message_count_;

return systems::EventStatus::Succeeded();
serializer_->Deserialize(
received_message_.data(), received_message_.size(),
&state->get_mutable_abstract_state().get_mutable_value(
kStateIndexMessage));
state->get_mutable_abstract_state<int>(kStateIndexMessageCount) =
received_message_count_;
return EventStatus::Succeeded();
}

int LcmSubscriberSystem::GetMessageCount(const Context<double>& context) const {
Expand Down Expand Up @@ -218,40 +220,47 @@ int LcmSubscriberSystem::GetInternalMessageCount() const {

EventStatus LcmSubscriberSystem::Initialize(const Context<double>& context,
State<double>* state) const {
if (GetInternalMessageCount() < 1 &&
wait_for_message_on_initialization_timeout_ > 0.0) {
log()->info("Waiting for messages on {}", channel_);

using Clock = std::chrono::steady_clock;
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();

while (GetInternalMessageCount() < 1 &&
Duration(Clock::now() - start_time).count() <
wait_for_message_on_initialization_timeout_) {
// Since the DrakeLcmInterface will not be handling subscriptions during
// this initialization, we must handle them directly here.
lcm_->HandleSubscriptions(1 /* timeout_millis*/);
}
if (GetInternalMessageCount() > 0) {
// In the default case when waiting is disabled, we'll opportunistically try
// to update our state, but we might return EventStatus::DidNothing().
if (wait_for_message_on_initialization_timeout_ <= 0.0) {
return ProcessMessageAndStoreToAbstractState(context, state);
}

// The user has requested to pause initialization until context changes.
// Start by peeking to see if there's already a message waiting.
DRAKE_DEMAND(lcm_ != nullptr);
lcm_->HandleSubscriptions(0 /* timeout_millis */);
EventStatus result = ProcessMessageAndStoreToAbstractState(context, state);
if (result.severity() != EventStatus::kDidNothing) {
return result;
}

// No message was pending. We'll spin until we get one (or run out of time).
log()->info("Waiting for messages on {}", channel_);
using Clock = std::chrono::steady_clock;
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();
while (Duration(Clock::now() - start_time).count() <
wait_for_message_on_initialization_timeout_) {
// Since the DrakeLcmInterface will not be handling subscriptions during
// this initialization, we must handle them directly here.
lcm_->HandleSubscriptions(1 /* timeout_millis*/);
result = ProcessMessageAndStoreToAbstractState(context, state);
if (result.severity() != EventStatus::kDidNothing) {
log()->info("Received message on {}", channel_);
} else {
// TODO(russt): Remove this once EventStatus are actually propagated.
throw std::runtime_error(fmt::format(
"Timed out without receiving any message on channel {} at url {}",
channel_, lcm_->get_lcm_url()));
return result;
}
}

if (GetInternalMessageCount() > 0) {
return ProcessMessageAndStoreToAbstractState(context, state);
} else if (wait_for_message_on_initialization_timeout_ <= 0.0) {
return EventStatus::DidNothing();
} else {
return EventStatus::Failed(
this,
fmt::format("Timed out without receiving any message on {}", channel_));
}
// We ran out of time.
result = EventStatus::Failed(
this,
fmt::format(
"Timed out without receiving any message on channel {} at url {}",
channel_, lcm_->get_lcm_url()));
// TODO(russt): Once EventStatus are actually propagated, return the status
// instead of throwing it.
throw std::runtime_error(result.message());
}

} // namespace lcm
Expand Down
37 changes: 19 additions & 18 deletions systems/lcm/lcm_subscriber_system.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ class LcmSubscriberSystem : public LeafSystem<double> {
* `wait_for_message_on_initialization_timeout > 0`, then the pointer must
* remain valid for the lifetime of the returned system.
*
* @param wait_for_message_on_initialization_timeout The number of seconds
* (wall-clock elapsed time) to wait for GetInternalMessageCount() to be > 0.
* If this timeout is <= 0, then the initialization event does not handle any
* new messages, but only processes existing received messages. If the
* timeout is > 0, then the initialization event will call
* lcm->HandleSubscriptions() until at least one message is received or until
* the timeout. Pass ∞ to wait indefinitely.
* @param wait_for_message_on_initialization_timeout Configures the behavior
* of initialization events (see System::ExecuteInitializationEvents() and
* Simulator::Initialize()) by specifying the number of seconds (wall-clock
* elapsed time) to wait for a new message. If this timeout is <= 0,
* initialization will copy any already-received messages into the Context but
* will not process any new messages. If this timeout is > 0, initialization
* will call lcm->HandleSubscriptions() until at least one message is received
* or until the timeout. Pass ∞ to wait indefinitely.
*/
template <typename LcmMessage>
static std::unique_ptr<LcmSubscriberSystem> Make(
Expand All @@ -92,13 +93,14 @@ class LcmSubscriberSystem : public LeafSystem<double> {
* `wait_for_message_on_initialization_timeout > 0`, then the pointer must
* remain valid for the lifetime of the returned system.
*
* @param wait_for_message_on_initialization_timeout The number of seconds
* (wall-clock elapsed time) to wait for GetInternalMessageCount() to be > 0.
* If this timeout is <= 0, then the initialization event does not handle any
* new messages, but only processes existing received messages. If the
* timeout is > 0, then the initialization event will call
* lcm->HandleSubscriptions() until at least one message is received or until
* the timeout. Pass ∞ to wait indefinitely.
* @param wait_for_message_on_initialization_timeout Configures the behavior
* of initialization events (see System::ExecuteInitializationEvents() and
* Simulator::Initialize()) by specifying the number of seconds (wall-clock
* elapsed time) to wait for a new message. If this timeout is <= 0,
* initialization will copy any already-received messages into the Context but
* will not process any new messages. If this timeout is > 0, initialization
* will call lcm->HandleSubscriptions() until at least one message is received
* or until the timeout. Pass ∞ to wait indefinitely.
*/
LcmSubscriberSystem(const std::string& channel,
std::shared_ptr<const SerializerInterface> serializer,
Expand Down Expand Up @@ -149,11 +151,10 @@ class LcmSubscriberSystem : public LeafSystem<double> {
systems::CompositeEventCollection<double>* events,
double* time) const final;

systems::EventStatus ProcessMessageAndStoreToAbstractState(
const Context<double>&, State<double>* state) const;
EventStatus ProcessMessageAndStoreToAbstractState(const Context<double>&,
State<double>* state) const;

systems::EventStatus Initialize(const Context<double>&,
State<double>* state) const;
EventStatus Initialize(const Context<double>&, State<double>* state) const;

// The channel on which to receive LCM messages.
const std::string channel_;
Expand Down
35 changes: 31 additions & 4 deletions systems/lcm/test/lcm_subscriber_system_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,26 @@ GTEST_TEST(LcmSubscriberSystemTest, InitializationNoWaitTest) {

// Produce a sample message.
SampleData sample_data;
sample_data.PublishAndHandle(&lcm, channel_name);
// Publish, but do not call handle.
Publish(&lcm, channel_name, sample_data.value);

// Use the initialization event to process the message.
// Fire the initialization event. It should NOT process the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());

const AbstractValue* abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
auto value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, lcmt_drake_signal{}));

// Receive the message.
lcm.HandleSubscriptions(0);

// Now the initialization event should process the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));
}

Expand Down Expand Up @@ -155,11 +166,27 @@ GTEST_TEST(LcmSubscriberSystemTest, InitializationWithWaitTest) {
// Now the initialization event calls handle and obtains the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());

const AbstractValue* abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
auto value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));

// A second initialization event will fail (timeout) with no *new* message.
DRAKE_EXPECT_THROWS_MESSAGE(dut->ExecuteInitializationEvents(context.get()),
"Timed out without receiving any message on "
"channel channel_name at url.*");

// Publish, but do not call handle, with a new message.
sample_data.value.timestamp += 1;
Publish(&lcm, channel_name, sample_data.value);

// Now the initialization event calls handle and obtains the message.
dut->ExecuteInitializationEvents(context.get());
dut->CalcOutput(*context, output.get());
abstract_value = output->get_data(0);
ASSERT_NE(abstract_value, nullptr);
value = abstract_value->get_value<lcmt_drake_signal>();
EXPECT_TRUE(CompareLcmtDrakeSignalMessages(value, sample_data.value));
}

GTEST_TEST(LcmSubscriberSystemTest, WaitTest) {
Expand Down

0 comments on commit f8e1ea2

Please sign in to comment.