Skip to content

Commit

Permalink
Implements System<T>::ExecuteInitializationEvents()
Browse files Browse the repository at this point in the history
and uses it in MeshcatPoseSliders() and JointSliders()
  • Loading branch information
RussTedrake committed Jun 13, 2023
1 parent af98335 commit 0fcaf9b
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 28 deletions.
3 changes: 3 additions & 0 deletions bindings/pydrake/systems/framework_py_systems.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ struct Impl {
.def("CalcForcedUnrestrictedUpdate",
&System<T>::CalcForcedUnrestrictedUpdate, py::arg("context"),
py::arg("state"), doc.System.CalcForcedUnrestrictedUpdate.doc)
.def("ExecuteInitializationEvents",
&System<T>::ExecuteInitializationEvents, py::arg("context"),
doc.System.ExecuteInitializationEvents.doc)
.def("GetUniquePeriodicDiscreteUpdateAttribute",
&System<T>::GetUniquePeriodicDiscreteUpdateAttribute,
doc.System.GetUniquePeriodicDiscreteUpdateAttribute.doc)
Expand Down
21 changes: 21 additions & 0 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,27 @@ def _system_reset(self, system, context, event, state):
self.assertTrue(system.called_reset)
self.assertTrue(system.called_system_reset)

# Test ExecuteInitializationEvents.
system = TrivialSystem()
context = system.CreateDefaultContext()
system.ExecuteInitializationEvents(context=context)
self.assertFalse(system.called_per_step)
self.assertFalse(system.called_periodic)
self.assertTrue(system.called_initialize_publish)
self.assertTrue(system.called_initialize_discrete)
self.assertTrue(system.called_initialize_unrestricted)
self.assertFalse(system.called_periodic_publish)
self.assertFalse(system.called_periodic_discrete)
self.assertFalse(system.called_periodic_unrestricted)
self.assertFalse(system.called_per_step_publish)
self.assertFalse(system.called_per_step_discrete)
self.assertFalse(system.called_per_step_unrestricted)
self.assertFalse(system.called_getwitness)
self.assertFalse(system.called_witness)
self.assertFalse(system.called_guard)
self.assertFalse(system.called_reset)
self.assertFalse(system.called_system_reset)

def test_event_handler_returns_none(self):
"""Checks that a Python event handler callback function is allowed to
(implicitly) return None, instead of an EventStatus. Because of all the
Expand Down
2 changes: 2 additions & 0 deletions multibody/meshcat/joint_sliders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ Eigen::VectorXd JointSliders<T>::Run(const Diagram<T>& diagram,
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();

diagram.ExecuteInitializationEvents(root_context.get());

// Set the context to the initial slider values.
plant_->SetPositions(&plant_context,
this->get_output_port().Eval(sliders_context));
Expand Down
5 changes: 4 additions & 1 deletion multibody/meshcat/joint_sliders.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ class JointSliders final : public systems::LeafSystem<T> {
value of the plant Context.
@pre `diagram` must be a top-level (i.e., "root") diagram.
@pre `diagram` must contain this JointSliders system.
@pre `diagram` must contain the `plant` that was passed into this
JointSliders system's constructor.
@pre `diagram` must contain this JointSliders system, however the output of
these sliders need not be connected (even indirectly) to any `plant` input
port. The positions of the `plant` will be updated directly using a call to
`plant.SetPositions(...)` when the slider values change.
*/
Eigen::VectorXd Run(const systems::Diagram<T>& diagram,
std::optional<double> timeout = std::nullopt,
Expand Down
27 changes: 27 additions & 0 deletions systems/framework/system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,33 @@ void System<T>::GetInitializationEvents(
DoGetInitializationEvents(context, events);
}

template <typename T>
void System<T>::ExecuteInitializationEvents(Context<T>* context) const {
auto discrete_updates = AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = AllocateCompositeEventCollection();

// NOTE: The execution order here must match the code in
// Simulator::Initialize().
GetInitializationEvents(*context, init_events.get());
// Do unrestricted updates first.
if (init_events->get_unrestricted_update_events().HasEvents()) {
CalcUnrestrictedUpdate(*context,
init_events->get_unrestricted_update_events(),
state.get());
}
// Do restricted (discrete variable) updates next.
if (init_events->get_discrete_update_events().HasEvents()) {
CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
}
// Do any publishes last.
if (init_events->get_publish_events().HasEvents()) {
Publish(*context, init_events->get_publish_events());
}
}

template <typename T>
std::optional<PeriodicEventData>
System<T>::GetUniquePeriodicDiscreteUpdateAttribute() const {
Expand Down
6 changes: 6 additions & 0 deletions systems/framework/system.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,12 @@ class System : public SystemBase {
void GetInitializationEvents(const Context<T>& context,
CompositeEventCollection<T>* events) const;

/** This method triggers all of the initialization events returned by
GetInitializationEvents(). The method allocates temporary storage to perform
the updates, and is intended only as a convenience method for callers who do
not want to use the full Simulator workflow. */
void ExecuteInitializationEvents(Context<T>* context) const;

/** Determines whether there exists a unique periodic timing (offset and
period) that triggers one or more discrete update events (and, if so, returns
that unique periodic timing). Thus, this method can be used (1) as a test to
Expand Down
12 changes: 1 addition & 11 deletions systems/framework/test/diagram_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3752,17 +3752,7 @@ GTEST_TEST(InitializationTest, InitializationTest) {
auto dut = builder.Build();

auto context = dut->CreateDefaultContext();
auto discrete_updates = dut->AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = dut->AllocateCompositeEventCollection();
dut->GetInitializationEvents(*context, init_events.get());

dut->Publish(*context, init_events->get_publish_events());
dut->CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
dut->CalcUnrestrictedUpdate(
*context, init_events->get_unrestricted_update_events(), state.get());
dut->ExecuteInitializationEvents(context.get());

EXPECT_TRUE(sys0->get_pub_init());
EXPECT_TRUE(sys0->get_dis_update_init());
Expand Down
10 changes: 10 additions & 0 deletions systems/framework/test/leaf_system_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2880,6 +2880,16 @@ GTEST_TEST(InitializationTest, InitializationTest) {
EXPECT_TRUE(dut.get_pub_init());
EXPECT_TRUE(dut.get_dis_update_init());
EXPECT_TRUE(dut.get_unres_update_init());

// Now again with the ExecuteInitializationEvents method.
InitializationTestSystem dut2;
auto context2 = dut2.CreateDefaultContext();

dut2.ExecuteInitializationEvents(context2.get());

EXPECT_TRUE(dut2.get_pub_init());
EXPECT_TRUE(dut2.get_dis_update_init());
EXPECT_TRUE(dut2.get_unres_update_init());
}

// Although many of the tests above validate behavior of events when the
Expand Down
17 changes: 1 addition & 16 deletions visualization/meshcat_pose_sliders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,7 @@ RigidTransformd MeshcatPoseSliders<T>::Run(
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();

// Handle initialization events.
auto init_events = system.AllocateCompositeEventCollection();
system.GetInitializationEvents(system_context, init_events.get());
if (init_events->get_publish_events().HasEvents()) {
system.Publish(system_context, init_events->get_publish_events());
}
if (init_events->get_discrete_update_events().HasEvents()) {
system.CalcDiscreteVariableUpdate(
system_context, init_events->get_discrete_update_events(),
&root_context->get_mutable_discrete_state());
}
if (init_events->get_unrestricted_update_events().HasEvents()) {
system.CalcUnrestrictedUpdate(system_context,
init_events->get_unrestricted_update_events(),
&root_context->get_mutable_state());
}
system.ExecuteInitializationEvents(root_context.get());

RigidTransformd pose = nominal_pose_;

Expand Down

0 comments on commit 0fcaf9b

Please sign in to comment.