Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements System<T>::ExecuteInitializationEvents() #19577

Merged
merged 1 commit into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bindings/pydrake/systems/framework_py_systems.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ struct Impl {
.def("CalcForcedUnrestrictedUpdate",
&System<T>::CalcForcedUnrestrictedUpdate, py::arg("context"),
py::arg("state"), doc.System.CalcForcedUnrestrictedUpdate.doc)
.def("ExecuteInitializationEvents",
&System<T>::ExecuteInitializationEvents, py::arg("context"),
doc.System.ExecuteInitializationEvents.doc)
.def("GetUniquePeriodicDiscreteUpdateAttribute",
&System<T>::GetUniquePeriodicDiscreteUpdateAttribute,
doc.System.GetUniquePeriodicDiscreteUpdateAttribute.doc)
Expand Down
21 changes: 21 additions & 0 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,27 @@ def _system_reset(self, system, context, event, state):
self.assertTrue(system.called_reset)
self.assertTrue(system.called_system_reset)

# Test ExecuteInitializationEvents.
system = TrivialSystem()
context = system.CreateDefaultContext()
system.ExecuteInitializationEvents(context=context)
self.assertFalse(system.called_per_step)
self.assertFalse(system.called_periodic)
self.assertTrue(system.called_initialize_publish)
self.assertTrue(system.called_initialize_discrete)
self.assertTrue(system.called_initialize_unrestricted)
self.assertFalse(system.called_periodic_publish)
self.assertFalse(system.called_periodic_discrete)
self.assertFalse(system.called_periodic_unrestricted)
self.assertFalse(system.called_per_step_publish)
self.assertFalse(system.called_per_step_discrete)
self.assertFalse(system.called_per_step_unrestricted)
self.assertFalse(system.called_getwitness)
self.assertFalse(system.called_witness)
self.assertFalse(system.called_guard)
self.assertFalse(system.called_reset)
self.assertFalse(system.called_system_reset)

def test_event_handler_returns_none(self):
"""Checks that a Python event handler callback function is allowed to
(implicitly) return None, instead of an EventStatus. Because of all the
Expand Down
1 change: 1 addition & 0 deletions multibody/meshcat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ drake_cc_googletest(
"//geometry:meshcat_visualizer",
"//geometry/test_utilities:meshcat_environment",
"//multibody/parsing",
"//systems/framework/test_utilities",
],
)

Expand Down
2 changes: 2 additions & 0 deletions multibody/meshcat/joint_sliders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ Eigen::VectorXd JointSliders<T>::Run(const Diagram<T>& diagram,
using Duration = std::chrono::duration<double>;
const auto start_time = Clock::now();

diagram.ExecuteInitializationEvents(root_context.get());

// Set the context to the initial slider values.
plant_->SetPositions(&plant_context,
this->get_output_port().Eval(sliders_context));
Expand Down
5 changes: 4 additions & 1 deletion multibody/meshcat/joint_sliders.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,12 @@ class JointSliders final : public systems::LeafSystem<T> {
value of the plant Context.

@pre `diagram` must be a top-level (i.e., "root") diagram.
@pre `diagram` must contain this JointSliders system.
@pre `diagram` must contain the `plant` that was passed into this
JointSliders system's constructor.
@pre `diagram` must contain this JointSliders system, however the output of
these sliders need not be connected (even indirectly) to any `plant` input
port. The positions of the `plant` will be updated directly using a call to
`plant.SetPositions(...)` when the slider values change.
*/
Eigen::VectorXd Run(const systems::Diagram<T>& diagram,
std::optional<double> timeout = std::nullopt,
Expand Down
9 changes: 9 additions & 0 deletions multibody/meshcat/test/joint_sliders_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "drake/multibody/parsing/parser.h"
#include "drake/multibody/plant/multibody_plant.h"
#include "drake/systems/framework/diagram_builder.h"
#include "drake/systems/framework/test_utilities/initialization_test_system.h"

namespace drake {
namespace multibody {
Expand Down Expand Up @@ -282,13 +283,21 @@ TEST_F(JointSlidersTest, Run) {
MeshcatVisualizer<double>::AddToBuilder(&builder_, scene_graph_, meshcat_);
auto* dut = builder_.AddSystem<JointSliders<double>>(meshcat_, &plant_,
initial_value);

auto init_system = builder_.AddSystem<systems::InitializationTestSystem>();

auto diagram = builder_.Build();

// Run for a while.
const double timeout = 1.0;
Eigen::VectorXd q = dut->Run(*diagram, timeout);
EXPECT_TRUE(CompareMatrices(q, initial_value));

// Confirm that initialization events were triggered.
EXPECT_TRUE(init_system->get_pub_init());
EXPECT_TRUE(init_system->get_dis_update_init());
EXPECT_TRUE(init_system->get_unres_update_init());

// Note: the stop button is deleted on timeout, so we cannot easily check
// that it was created correctly here.

Expand Down
39 changes: 23 additions & 16 deletions systems/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,8 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -746,7 +747,8 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand Down Expand Up @@ -803,7 +805,9 @@ drake_cc_googletest(
"//common/test_utilities:is_dynamic_castable",
"//examples/pendulum:pendulum_plant",
"//systems/analysis/test_utilities:stateless_system",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:initialization_test_system",
"//systems/framework/test_utilities:pack_value",
"//systems/framework/test_utilities:scalar_conversion",
"//systems/primitives:adder",
"//systems/primitives:constant_value_source",
"//systems/primitives:constant_vector_source",
Expand All @@ -822,7 +826,7 @@ drake_cc_googletest(
"//common/test_utilities:eigen_matrix_compare",
"//common/test_utilities:expect_throws_message",
"//common/test_utilities:is_dynamic_castable",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand Down Expand Up @@ -863,7 +867,7 @@ drake_cc_googletest(
":leaf_context",
"//common:essential",
"//common/test_utilities",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -874,7 +878,9 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities",
"//common/test_utilities:limit_malloc",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:initialization_test_system",
"//systems/framework/test_utilities:my_vector",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -889,7 +895,7 @@ drake_cc_googletest(
name = "model_values_test",
deps = [
":model_values",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -898,7 +904,7 @@ drake_cc_googletest(
deps = [
":abstract_values",
"//common:essential",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand All @@ -907,7 +913,7 @@ drake_cc_googletest(
deps = [
":parameters",
"//common:essential",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:pack_value",
],
)

Expand Down Expand Up @@ -941,7 +947,8 @@ drake_cc_googletest(
deps = [
":system_output",
"//common:essential",
"//systems/framework/test_utilities",
"//common/test_utilities:is_dynamic_castable",
"//systems/framework/test_utilities:my_vector",
],
)

Expand Down Expand Up @@ -972,7 +979,7 @@ drake_cc_googletest(
":input_port",
":leaf_system",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -984,7 +991,7 @@ drake_cc_googletest(
"//common:essential",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
"//systems/primitives:constant_vector_source",
],
)
Expand All @@ -1011,7 +1018,7 @@ drake_cc_googletest(
"//common:unused",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -1030,7 +1037,7 @@ drake_cc_googletest(
"//common:essential",
"//common:unused",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:my_vector",
],
)

Expand All @@ -1048,7 +1055,7 @@ drake_cc_googletest(
name = "single_output_vector_source_test",
deps = [
":single_output_vector_source",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:scalar_conversion",
],
)

Expand All @@ -1058,7 +1065,7 @@ drake_cc_googletest(
":vector_system",
"//common/test_utilities:expect_no_throw",
"//common/test_utilities:expect_throws_message",
"//systems/framework/test_utilities",
"//systems/framework/test_utilities:scalar_conversion",
"//systems/primitives:integrator",
],
)
Expand Down
31 changes: 31 additions & 0 deletions systems/framework/system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,37 @@ void System<T>::GetInitializationEvents(
DoGetInitializationEvents(context, events);
}

template <typename T>
void System<T>::ExecuteInitializationEvents(Context<T>* context) const {
auto discrete_updates = AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = AllocateCompositeEventCollection();

// NOTE: The execution order here must match the code in
// Simulator::Initialize().
GetInitializationEvents(*context, init_events.get());
// Do unrestricted updates first.
if (init_events->get_unrestricted_update_events().HasEvents()) {
CalcUnrestrictedUpdate(*context,
init_events->get_unrestricted_update_events(),
state.get());
ApplyUnrestrictedUpdate(init_events->get_unrestricted_update_events(),
state.get(), context);
}
// Do restricted (discrete variable) updates next.
if (init_events->get_discrete_update_events().HasEvents()) {
CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
ApplyDiscreteVariableUpdate(init_events->get_discrete_update_events(),
discrete_updates.get(), context);
}
// Do any publishes last.
if (init_events->get_publish_events().HasEvents()) {
Publish(*context, init_events->get_publish_events());
}
}

template <typename T>
std::optional<PeriodicEventData>
System<T>::GetUniquePeriodicDiscreteUpdateAttribute() const {
Expand Down
10 changes: 10 additions & 0 deletions systems/framework/system.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,16 @@ class System : public SystemBase {
void GetInitializationEvents(const Context<T>& context,
CompositeEventCollection<T>* events) const;

/** This method triggers all of the initialization events returned by
GetInitializationEvents(). The method allocates temporary storage to perform
the updates, and is intended only as a convenience method for callers who do
not want to use the full Simulator workflow.

Note that this is not fully equivalent to Simulator::Initialize() because
_only_ initialization events are handled here, while Simulator::Initialize()
also processes other events associated with time zero. */
void ExecuteInitializationEvents(Context<T>* context) const;

/** Determines whether there exists a unique periodic timing (offset and
period) that triggers one or more discrete update events (and, if so, returns
that unique periodic timing). Thus, this method can be used (1) as a test to
Expand Down
66 changes: 2 additions & 64 deletions systems/framework/test/diagram_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "drake/systems/framework/fixed_input_port_value.h"
#include "drake/systems/framework/leaf_system.h"
#include "drake/systems/framework/output_port.h"
#include "drake/systems/framework/test_utilities/initialization_test_system.h"
#include "drake/systems/framework/test_utilities/pack_value.h"
#include "drake/systems/framework/test_utilities/scalar_conversion.h"
#include "drake/systems/primitives/adder.h"
Expand Down Expand Up @@ -3691,59 +3692,6 @@ GTEST_TEST(RandomContextTest, SetRandomTest) {

// Tests initialization works properly for all subsystems.
GTEST_TEST(InitializationTest, InitializationTest) {
// Note: this class is duplicated in leaf_system_test.
class InitializationTestSystem : public LeafSystem<double> {
public:
InitializationTestSystem() {
PublishEvent<double> pub_event(
TriggerType::kInitialization,
std::bind(&InitializationTestSystem::InitPublish, this,
std::placeholders::_1, std::placeholders::_2));
DeclareInitializationEvent(pub_event);

DeclareInitializationEvent(DiscreteUpdateEvent<double>(
TriggerType::kInitialization));
DeclareInitializationEvent(UnrestrictedUpdateEvent<double>(
TriggerType::kInitialization));
}

bool get_pub_init() const { return pub_init_; }
bool get_dis_update_init() const { return dis_update_init_; }
bool get_unres_update_init() const { return unres_update_init_; }

private:
void InitPublish(const Context<double>&,
const PublishEvent<double>& event) const {
EXPECT_EQ(event.get_trigger_type(),
TriggerType::kInitialization);
pub_init_ = true;
}

void DoCalcDiscreteVariableUpdates(
const Context<double>&,
const std::vector<const DiscreteUpdateEvent<double>*>& events,
DiscreteValues<double>*) const final {
EXPECT_EQ(events.size(), 1);
EXPECT_EQ(events.front()->get_trigger_type(),
TriggerType::kInitialization);
dis_update_init_ = true;
}

void DoCalcUnrestrictedUpdate(
const Context<double>&,
const std::vector<const UnrestrictedUpdateEvent<double>*>& events,
State<double>*) const final {
EXPECT_EQ(events.size(), 1);
EXPECT_EQ(events.front()->get_trigger_type(),
TriggerType::kInitialization);
unres_update_init_ = true;
}

mutable bool pub_init_{false};
mutable bool dis_update_init_{false};
mutable bool unres_update_init_{false};
};

DiagramBuilder<double> builder;

auto sys0 = builder.AddSystem<InitializationTestSystem>();
Expand All @@ -3752,17 +3700,7 @@ GTEST_TEST(InitializationTest, InitializationTest) {
auto dut = builder.Build();

auto context = dut->CreateDefaultContext();
auto discrete_updates = dut->AllocateDiscreteVariables();
auto state = context->CloneState();
auto init_events = dut->AllocateCompositeEventCollection();
dut->GetInitializationEvents(*context, init_events.get());

dut->Publish(*context, init_events->get_publish_events());
dut->CalcDiscreteVariableUpdate(*context,
init_events->get_discrete_update_events(),
discrete_updates.get());
dut->CalcUnrestrictedUpdate(
*context, init_events->get_unrestricted_update_events(), state.get());
dut->ExecuteInitializationEvents(context.get());

EXPECT_TRUE(sys0->get_pub_init());
EXPECT_TRUE(sys0->get_dis_update_init());
Expand Down
Loading