Skip to content

Commit

Permalink
Implements System<T>::ExecuteInitializationEvents() (RobotLocomotion#…
Browse files Browse the repository at this point in the history
…19577)

and uses it in MeshcatPoseSliders() and JointSliders()
  • Loading branch information
RussTedrake committed Jul 3, 2023
1 parent 2ee647e commit f13f312
Show file tree
Hide file tree
Showing 16 changed files with 216 additions and 147 deletions.
3 changes: 3 additions & 0 deletions bindings/pydrake/systems/framework_py_systems.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ struct Impl {
.def("CalcForcedUnrestrictedUpdate",
&System<T>::CalcForcedUnrestrictedUpdate, py::arg("context"),
py::arg("state"), doc.System.CalcForcedUnrestrictedUpdate.doc)
.def("ExecuteInitializationEvents",
&System<T>::ExecuteInitializationEvents, py::arg("context"),
doc.System.ExecuteInitializationEvents.doc)
.def("GetUniquePeriodicDiscreteUpdateAttribute",
&System<T>::GetUniquePeriodicDiscreteUpdateAttribute,
doc.System.GetUniquePeriodicDiscreteUpdateAttribute.doc)
Expand Down
21 changes: 21 additions & 0 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,27 @@ def _system_reset(self, system, context, event, state):
self.assertTrue(system.called_reset)
self.assertTrue(system.called_system_reset)

# Test ExecuteInitializationEvents.
system = TrivialSystem()
context = system.CreateDefaultContext()
system.ExecuteInitializationEvents(context=context)
self.assertFalse(system.called_per_step)
self.assertFalse(system.called_periodic)
self.assertTrue(system.called_initialize_publish)
self.assertTrue(system.called_initialize_discrete)
self.assertTrue(system.called_initialize_unrestricted)
self.assertFalse(system.called_periodic_publish)
self.assertFalse(system.called_periodic_discrete)
self.assertFalse(system.called_periodic_unrestricted)
self.assertFalse(system.called_per_step_publish)
self.assertFalse(system.called_per_step_discrete)
self.assertFalse(system.called_per_step_unrestricted)
self.assertFalse(system.called_getwitness)
self.assertFalse(system.called_witness)
self.assertFalse(system.called_guard)
self.assertFalse(system.called_reset)
self.assertFalse(system.called_system_reset)

def test_event_handler_returns_none(self):
"""Checks that a Python event handler callback function is allowed to
(implicitly) return None, instead of an EventStatus. Because of all the
Expand Down
1 change: 1 addition & 0 deletions multibody/meshcat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ drake_cc_googletest(
"//geometry:meshcat_visualizer",
"//geometry/test_utilities:meshcat_environment",
"//multibody/parsing",
"//systems/framework/test_utilities",
],
)

Expand Down
2 changes: 2 additions & 0 deletions multibody/meshcat/joint_sliders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ Eigen::VectorXd JointSliders<T>::Run(const Diagram<T>& diagram,
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();

diagram.ExecuteInitializationEvents(root_context.get());

// Set the context to the initial slider values.
plant_->SetPositions(&plant_context,
this->get_output_port().Eval(sliders_context));
Expand Down
5 changes: 4 additions & 1 deletion multibody/meshcat/joint_sliders.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ class JointSliders final : public systems::LeafSystem<T> {
value of the plant Context.
@pre `diagram` must be a top-level (i.e., "root") diagram.
@pre `diagram` must contain this JointSliders system.
@pre `diagram` must contain the `plant` that was passed into this
JointSliders system's constructor.
@pre `diagram` must contain this JointSliders system, however the output of
these sliders need not be connected (even indirectly) to any `plant` input
port. The positions of the `plant` will be updated directly using a call to
`plant.SetPositions(...)` when the slider values change.
*/
Eigen::VectorXd Run(const systems::Diagram<T>& diagram,
std::optional<double> timeout = std::nullopt,
Expand Down
9 changes: 9 additions & 0 deletions multibody/meshcat/test/joint_sliders_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "drake/multibody/parsing/parser.h"
#include "drake/multibody/plant/multibody_plant.h"
#include "drake/systems/framework/diagram_builder.h"
#include "drake/systems/framework/test_utilities/initialization_test_system.h"

namespace drake {
namespace multibody {
Expand Down Expand Up @@ -282,13 +283,21 @@ TEST_F(JointSlidersTest, Run) {
MeshcatVisualizer<double>::AddToBuilder(&builder_, scene_graph_, meshcat_);
auto* dut = builder_.AddSystem<JointSliders<double>>(meshcat_, &plant_,
initial_value);

auto init_system = builder_.AddSystem<systems::InitializationTestSystem>();

auto diagram = builder_.Build();

// Run for a while.
const double timeout = 1.0;
Eigen::VectorXd q = dut->Run(*diagram, timeout);
EXPECT_TRUE(CompareMatrices(q, initial_value));

// Confirm that initialization events were triggered.
EXPECT_TRUE(init_system->get_pub_init());
EXPECT_TRUE(init_system->get_dis_update_init());
EXPECT_TRUE(init_system->get_unres_update_init());

// Note: the stop button is deleted on timeout, so we cannot easily check
// that it was created correctly here.

Expand Down
39 changes: 23 additions & 16 deletions systems/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -746,7 +747,8 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand Down Expand Up @@ -803,7 +805,9 @@ drake_cc_googletest(
"//common/test_utilities:is_dynamic_castable",
"//examples/pendulum:pendulum_plant",
"//systems/analysis/test_utilities:stateless_system",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:initialization_test_system",
"//systems/framework/test_utilities:pack_value",
"//systems/framework/test_utilities:scalar_conversion",
"//systems/primitives:adder",
"//systems/primitives:constant_value_source",
"//systems/primitives:constant_vector_source",
Expand All @@ -822,7 +826,7 @@ drake_cc_googletest(
"//common/test_utilities:eigen_matrix_compare",
"//common/test_utilities:expect_throws_message",
"//common/test_utilities:is_dynamic_castable",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand Down Expand Up @@ -863,7 +867,7 @@ drake_cc_googletest(
":leaf_context",
"//common:essential",
"//common/test_utilities",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -874,7 +878,9 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities",
"//common/test_utilities:limit_malloc",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:initialization_test_system",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -889,7 +895,7 @@ drake_cc_googletest(
name = "model_values_test",
deps = [
":model_values",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -898,7 +904,7 @@ drake_cc_googletest(
deps = [
":abstract_values",
"//common:essential",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -907,7 +913,7 @@ drake_cc_googletest(
deps = [
":parameters",
"//common:essential",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand Down Expand Up @@ -941,7 +947,8 @@ drake_cc_googletest(
deps = [
":system_output",
"//common:essential",
"//systems/framework/test_utilities",
"//common/test_utilities:is_dynamic_castable",
"//systems/framework/test_utilities:my_vector",
],
)

Expand Down Expand Up @@ -972,7 +979,7 @@ drake_cc_googletest(
":input_port",
":leaf_system",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -984,7 +991,7 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/primitives:constant_vector_source",
],
)
Expand All @@ -1011,7 +1018,7 @@ drake_cc_googletest(
"//common:unused",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -1030,7 +1037,7 @@ drake_cc_googletest(
"//common:essential",
"//common:unused",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -1048,7 +1055,7 @@ drake_cc_googletest(
name = "single_output_vector_source_test",
deps = [
":single_output_vector_source",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:scalar_conversion",
],
)

Expand All @@ -1058,7 +1065,7 @@ drake_cc_googletest(
":vector_system",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:scalar_conversion",
"//systems/primitives:integrator",
],
)
Expand Down
31 changes: 31 additions & 0 deletions systems/framework/system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,37 @@ void System<T>::GetInitializationEvents(
DoGetInitializationEvents(context, events);
}

template <typename T>
void System<T>::ExecuteInitializationEvents(Context<T>* context) const {
auto discrete_updates = AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = AllocateCompositeEventCollection();

// NOTE: The execution order here must match the code in
// Simulator::Initialize().
GetInitializationEvents(*context, init_events.get());
// Do unrestricted updates first.
if (init_events->get_unrestricted_update_events().HasEvents()) {
CalcUnrestrictedUpdate(*context,
init_events->get_unrestricted_update_events(),
state.get());
ApplyUnrestrictedUpdate(init_events->get_unrestricted_update_events(),
state.get(), context);
}
// Do restricted (discrete variable) updates next.
if (init_events->get_discrete_update_events().HasEvents()) {
CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
ApplyDiscreteVariableUpdate(init_events->get_discrete_update_events(),
discrete_updates.get(), context);
}
// Do any publishes last.
if (init_events->get_publish_events().HasEvents()) {
Publish(*context, init_events->get_publish_events());
}
}

template <typename T>
std::optional<PeriodicEventData>
System<T>::GetUniquePeriodicDiscreteUpdateAttribute() const {
Expand Down
10 changes: 10 additions & 0 deletions systems/framework/system.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,16 @@ class System : public SystemBase {
void GetInitializationEvents(const Context<T>& context,
CompositeEventCollection<T>* events) const;

/** This method triggers all of the initialization events returned by
GetInitializationEvents(). The method allocates temporary storage to perform
the updates, and is intended only as a convenience method for callers who do
not want to use the full Simulator workflow.
Note that this is not fully equivalent to Simulator::Initialize() because
_only_ initialization events are handled here, while Simulator::Initialize()
also processes other events associated with time zero. */
void ExecuteInitializationEvents(Context<T>* context) const;

/** Determines whether there exists a unique periodic timing (offset and
period) that triggers one or more discrete update events (and, if so, returns
that unique periodic timing). Thus, this method can be used (1) as a test to
Expand Down
66 changes: 2 additions & 64 deletions systems/framework/test/diagram_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "drake/systems/framework/fixed_input_port_value.h"
#include "drake/systems/framework/leaf_system.h"
#include "drake/systems/framework/output_port.h"
#include "drake/systems/framework/test_utilities/initialization_test_system.h"
#include "drake/systems/framework/test_utilities/pack_value.h"
#include "drake/systems/framework/test_utilities/scalar_conversion.h"
#include "drake/systems/primitives/adder.h"
Expand Down Expand Up @@ -3691,59 +3692,6 @@ GTEST_TEST(RandomContextTest, SetRandomTest) {

// Tests initialization works properly for all subsystems.
GTEST_TEST(InitializationTest, InitializationTest) {
// Note: this class is duplicated in leaf_system_test.
class InitializationTestSystem : public LeafSystem<double> {
public:
InitializationTestSystem() {
PublishEvent<double> pub_event(
TriggerType::kInitialization,
std::bind(&InitializationTestSystem::InitPublish, this,
std::placeholders::_1, std::placeholders::_2));
DeclareInitializationEvent(pub_event);

DeclareInitializationEvent(DiscreteUpdateEvent<double>(
TriggerType::kInitialization));
DeclareInitializationEvent(UnrestrictedUpdateEvent<double>(
TriggerType::kInitialization));
}

bool get_pub_init() const { return pub_init_; }
bool get_dis_update_init() const { return dis_update_init_; }
bool get_unres_update_init() const { return unres_update_init_; }

private:
void InitPublish(const Context<double>&,
const PublishEvent<double>& event) const {
EXPECT_EQ(event.get_trigger_type(),
TriggerType::kInitialization);
pub_init_ = true;
}

void DoCalcDiscreteVariableUpdates(
const Context<double>&,
const std::vector<const DiscreteUpdateEvent<double>*>& events,
DiscreteValues<double>*) const final {
EXPECT_EQ(events.size(), 1);
EXPECT_EQ(events.front()->get_trigger_type(),
TriggerType::kInitialization);
dis_update_init_ = true;
}

void DoCalcUnrestrictedUpdate(
const Context<double>&,
const std::vector<const UnrestrictedUpdateEvent<double>*>& events,
State<double>*) const final {
EXPECT_EQ(events.size(), 1);
EXPECT_EQ(events.front()->get_trigger_type(),
TriggerType::kInitialization);
unres_update_init_ = true;
}

mutable bool pub_init_{false};
mutable bool dis_update_init_{false};
mutable bool unres_update_init_{false};
};

DiagramBuilder<double> builder;

auto sys0 = builder.AddSystem<InitializationTestSystem>();
Expand All @@ -3752,17 +3700,7 @@ GTEST_TEST(InitializationTest, InitializationTest) {
auto dut = builder.Build();

auto context = dut->CreateDefaultContext();
auto discrete_updates = dut->AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = dut->AllocateCompositeEventCollection();
dut->GetInitializationEvents(*context, init_events.get());

dut->Publish(*context, init_events->get_publish_events());
dut->CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
dut->CalcUnrestrictedUpdate(
*context, init_events->get_unrestricted_update_events(), state.get());
dut->ExecuteInitializationEvents(context.get());

EXPECT_TRUE(sys0->get_pub_init());
EXPECT_TRUE(sys0->get_dis_update_init());
Expand Down
Loading

0 comments on commit f13f312

Please sign in to comment.