From ae048f7cdc4de7e5c8881940369658a0d287f411 Mon Sep 17 00:00:00 2001 From: Russ Tedrake Date: Sat, 3 Jun 2023 08:42:24 -0400 Subject: [PATCH] Implements System::ExecuteInitializationEvents() and uses it in MeshcatPoseSliders() and JointSliders() --- .../pydrake/systems/framework_py_systems.cc | 3 +++ bindings/pydrake/systems/test/custom_test.py | 21 ++++++++++++++++ multibody/meshcat/joint_sliders.cc | 2 ++ multibody/meshcat/joint_sliders.h | 5 +++- systems/framework/system.cc | 25 +++++++++++++++++++ systems/framework/system.h | 6 +++++ systems/framework/test/diagram_test.cc | 12 +-------- systems/framework/test/leaf_system_test.cc | 10 ++++++++ visualization/meshcat_pose_sliders.cc | 17 +------------ 9 files changed, 73 insertions(+), 28 deletions(-) diff --git a/bindings/pydrake/systems/framework_py_systems.cc b/bindings/pydrake/systems/framework_py_systems.cc index 5a6cc84ade14..0f689ab09d6c 100644 --- a/bindings/pydrake/systems/framework_py_systems.cc +++ b/bindings/pydrake/systems/framework_py_systems.cc @@ -414,6 +414,9 @@ struct Impl { .def("CalcForcedUnrestrictedUpdate", &System::CalcForcedUnrestrictedUpdate, py::arg("context"), py::arg("state"), doc.System.CalcForcedUnrestrictedUpdate.doc) + .def("ExecuteInitializationEvents", + &System::ExecuteInitializationEvents, py::arg("context"), + doc.System.ExecuteInitializationEvents.doc) .def("GetUniquePeriodicDiscreteUpdateAttribute", &System::GetUniquePeriodicDiscreteUpdateAttribute, doc.System.GetUniquePeriodicDiscreteUpdateAttribute.doc) diff --git a/bindings/pydrake/systems/test/custom_test.py b/bindings/pydrake/systems/test/custom_test.py index 78ea220aea36..717b9b8258d8 100644 --- a/bindings/pydrake/systems/test/custom_test.py +++ b/bindings/pydrake/systems/test/custom_test.py @@ -638,6 +638,27 @@ def _system_reset(self, system, context, event, state): self.assertTrue(system.called_reset) self.assertTrue(system.called_system_reset) + # Test ExecuteInitializationEvents. + system = TrivialSystem() + context = system.CreateDefaultContext() + system.ExecuteInitializationEvents(context=context) + self.assertFalse(system.called_per_step) + self.assertFalse(system.called_periodic) + self.assertTrue(system.called_initialize_publish) + self.assertTrue(system.called_initialize_discrete) + self.assertTrue(system.called_initialize_unrestricted) + self.assertFalse(system.called_periodic_publish) + self.assertFalse(system.called_periodic_discrete) + self.assertFalse(system.called_periodic_unrestricted) + self.assertFalse(system.called_per_step_publish) + self.assertFalse(system.called_per_step_discrete) + self.assertFalse(system.called_per_step_unrestricted) + self.assertFalse(system.called_getwitness) + self.assertFalse(system.called_witness) + self.assertFalse(system.called_guard) + self.assertFalse(system.called_reset) + self.assertFalse(system.called_system_reset) + def test_event_handler_returns_none(self): """Checks that a Python event handler callback function is allowed to (implicitly) return None, instead of an EventStatus. Because of all the diff --git a/multibody/meshcat/joint_sliders.cc b/multibody/meshcat/joint_sliders.cc index d988c16b8460..ca79f9f21c4c 100644 --- a/multibody/meshcat/joint_sliders.cc +++ b/multibody/meshcat/joint_sliders.cc @@ -279,6 +279,8 @@ Eigen::VectorXd JointSliders::Run(const Diagram& diagram, using Duration = std::chrono::duration; const auto start_time = Clock::now(); + diagram.ExecuteInitializationEvents(root_context.get()); + // Set the context to the initial slider values. plant_->SetPositions(&plant_context, this->get_output_port().Eval(sliders_context)); diff --git a/multibody/meshcat/joint_sliders.h b/multibody/meshcat/joint_sliders.h index 34eba3867ece..50fc50eda650 100644 --- a/multibody/meshcat/joint_sliders.h +++ b/multibody/meshcat/joint_sliders.h @@ -120,9 +120,12 @@ class JointSliders final : public systems::LeafSystem { value of the plant Context. @pre `diagram` must be a top-level (i.e., "root") diagram. - @pre `diagram` must contain this JointSliders system. @pre `diagram` must contain the `plant` that was passed into this JointSliders system's constructor. + @pre `diagram` must contain this JointSliders system, however the output of + these sliders need not be connected (even indirectly) to any `plant` input + port. The positions of the `plant` will be updated directly using a call to + `plant.SetPositions(...)` when the slider values change. */ Eigen::VectorXd Run(const systems::Diagram& diagram, std::optional timeout = std::nullopt, diff --git a/systems/framework/system.cc b/systems/framework/system.cc index 2fb17abfef1f..5801552122a4 100644 --- a/systems/framework/system.cc +++ b/systems/framework/system.cc @@ -478,6 +478,31 @@ void System::GetInitializationEvents( DoGetInitializationEvents(context, events); } +template +void System::ExecuteInitializationEvents(Context* context) const { + auto discrete_updates = AllocateDiscreteVariables(); + auto state = context->CloneState(); + auto init_events = AllocateCompositeEventCollection(); + + GetInitializationEvents(*context, init_events.get()); + // Do unrestricted updates first. + if (init_events->get_unrestricted_update_events().HasEvents()) { + CalcUnrestrictedUpdate(*context, + init_events->get_unrestricted_update_events(), + state.get()); + } + // Do restricted (discrete variable) updates next. + if (init_events->get_discrete_update_events().HasEvents()) { + CalcDiscreteVariableUpdate(*context, + init_events->get_discrete_update_events(), + discrete_updates.get()); + } + // Do any publishes last. + if (init_events->get_publish_events().HasEvents()) { + Publish(*context, init_events->get_publish_events()); + } +} + template std::optional System::GetUniquePeriodicDiscreteUpdateAttribute() const { diff --git a/systems/framework/system.h b/systems/framework/system.h index 53a0dcafd801..36ca15bff17f 100644 --- a/systems/framework/system.h +++ b/systems/framework/system.h @@ -727,6 +727,12 @@ class System : public SystemBase { void GetInitializationEvents(const Context& context, CompositeEventCollection* events) const; + /** This method triggers all of the initialization events returned by + GetInitializationEvents(). The method allocates temporary storage to perform + the updates, and is intended only as a convenience method for callers who do + not want to use the full Simulator workflow. */ + void ExecuteInitializationEvents(Context* context) const; + /** Determines whether there exists a unique periodic timing (offset and period) that triggers one or more discrete update events (and, if so, returns that unique periodic timing). Thus, this method can be used (1) as a test to diff --git a/systems/framework/test/diagram_test.cc b/systems/framework/test/diagram_test.cc index 510e6ae231be..779684a133ae 100644 --- a/systems/framework/test/diagram_test.cc +++ b/systems/framework/test/diagram_test.cc @@ -3752,17 +3752,7 @@ GTEST_TEST(InitializationTest, InitializationTest) { auto dut = builder.Build(); auto context = dut->CreateDefaultContext(); - auto discrete_updates = dut->AllocateDiscreteVariables(); - auto state = context->CloneState(); - auto init_events = dut->AllocateCompositeEventCollection(); - dut->GetInitializationEvents(*context, init_events.get()); - - dut->Publish(*context, init_events->get_publish_events()); - dut->CalcDiscreteVariableUpdate(*context, - init_events->get_discrete_update_events(), - discrete_updates.get()); - dut->CalcUnrestrictedUpdate( - *context, init_events->get_unrestricted_update_events(), state.get()); + dut->ExecuteInitializationEvents(context.get()); EXPECT_TRUE(sys0->get_pub_init()); EXPECT_TRUE(sys0->get_dis_update_init()); diff --git a/systems/framework/test/leaf_system_test.cc b/systems/framework/test/leaf_system_test.cc index 65baa19f3ee6..0e7f5bc46cfd 100644 --- a/systems/framework/test/leaf_system_test.cc +++ b/systems/framework/test/leaf_system_test.cc @@ -2880,6 +2880,16 @@ GTEST_TEST(InitializationTest, InitializationTest) { EXPECT_TRUE(dut.get_pub_init()); EXPECT_TRUE(dut.get_dis_update_init()); EXPECT_TRUE(dut.get_unres_update_init()); + + // Now again with the ExecuteInitializationEvents method. + InitializationTestSystem dut2; + auto context2 = dut2.CreateDefaultContext(); + + dut2.ExecuteInitializationEvents(context2.get()); + + EXPECT_TRUE(dut2.get_pub_init()); + EXPECT_TRUE(dut2.get_dis_update_init()); + EXPECT_TRUE(dut2.get_unres_update_init()); } // Although many of the tests above validate behavior of events when the diff --git a/visualization/meshcat_pose_sliders.cc b/visualization/meshcat_pose_sliders.cc index af239e5c272e..efadce7a9976 100644 --- a/visualization/meshcat_pose_sliders.cc +++ b/visualization/meshcat_pose_sliders.cc @@ -210,22 +210,7 @@ RigidTransformd MeshcatPoseSliders::Run( using Duration = std::chrono::duration; const auto start_time = Clock::now(); - // Handle initialization events. - auto init_events = system.AllocateCompositeEventCollection(); - system.GetInitializationEvents(system_context, init_events.get()); - if (init_events->get_publish_events().HasEvents()) { - system.Publish(system_context, init_events->get_publish_events()); - } - if (init_events->get_discrete_update_events().HasEvents()) { - system.CalcDiscreteVariableUpdate( - system_context, init_events->get_discrete_update_events(), - &root_context->get_mutable_discrete_state()); - } - if (init_events->get_unrestricted_update_events().HasEvents()) { - system.CalcUnrestrictedUpdate(system_context, - init_events->get_unrestricted_update_events(), - &root_context->get_mutable_state()); - } + system.ExecuteInitializationEvents(root_context.get()); RigidTransformd pose = nominal_pose_;