Skip to content

Commit

Permalink
Support default scalars for TrajectorySource
Browse files Browse the repository at this point in the history
  • Loading branch information
RussTedrake committed Aug 9, 2023
1 parent 7b2bdee commit e7e35c3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 17 deletions.
18 changes: 9 additions & 9 deletions bindings/pydrake/systems/primitives_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,15 @@ PYBIND11_MODULE(primitives, m) {
.doc_3args_period_sec_abstract_model_value_offset_sec)
.def("period", &ZeroOrderHold<T>::period, doc.ZeroOrderHold.period.doc)
.def("offset", &ZeroOrderHold<T>::offset, doc.ZeroOrderHold.offset.doc);

DefineTemplateClassWithDefault<TrajectorySource<T>, LeafSystem<T>>(
m, "TrajectorySource", GetPyParam<T>(), doc.TrajectorySource.doc)
.def(py::init<const trajectories::Trajectory<T>&, int, bool>(),
py::arg("trajectory"), py::arg("output_derivative_order") = 0,
py::arg("zero_derivatives_beyond_limits") = true,
doc.TrajectorySource.ctor.doc)
.def("UpdateTrajectory", &TrajectorySource<T>::UpdateTrajectory,
py::arg("trajectory"), doc.TrajectorySource.UpdateTrajectory.doc);
};
type_visit(bind_common_scalar_types, CommonScalarPack{});

Expand Down Expand Up @@ -743,15 +752,6 @@ PYBIND11_MODULE(primitives, m) {
py::arg("num_outputs"), py::arg("sampling_interval_sec"),
doc.RandomSource.ctor.doc);

py::class_<TrajectorySource<double>, LeafSystem<double>>(
m, "TrajectorySource", doc.TrajectorySource.doc)
.def(py::init<const trajectories::Trajectory<double>&, int, bool>(),
py::arg("trajectory"), py::arg("output_derivative_order") = 0,
py::arg("zero_derivatives_beyond_limits") = true,
doc.TrajectorySource.ctor.doc)
.def("UpdateTrajectory", &TrajectorySource<double>::UpdateTrajectory,
py::arg("trajectory"), doc.TrajectorySource.UpdateTrajectory.doc);

m.def("AddRandomInputs", &AddRandomInputs<double>,
py::arg("sampling_interval_sec"), py::arg("builder"),
doc.AddRandomInputs.doc)
Expand Down
3 changes: 2 additions & 1 deletion bindings/pydrake/systems/test/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
SymbolicVectorSystem, SymbolicVectorSystem_,
TrajectoryAffineSystem, TrajectoryAffineSystem_,
TrajectoryLinearSystem, TrajectoryLinearSystem_,
TrajectorySource,
TrajectorySource, TrajectorySource_,
VectorLog, VectorLogSink, VectorLogSink_,
WrapToSystem, WrapToSystem_,
ZeroOrderHold, ZeroOrderHold_,
Expand Down Expand Up @@ -110,6 +110,7 @@ def test_instantiations(self):
supports_symbolic=False)
self._check_instantiations(TrajectoryLinearSystem_,
supports_symbolic=False)
self._check_instantiations(TrajectorySource_)
self._check_instantiations(VectorLogSink_)
self._check_instantiations(WrapToSystem_)
self._check_instantiations(ZeroOrderHold_)
Expand Down
16 changes: 16 additions & 0 deletions systems/primitives/test/trajectory_source_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ TEST_F(TrajectorySourceTest, ConstantVectorSourceIsStateless) {
EXPECT_EQ(0, context_->num_continuous_states());
}

template <typename T>
void TestScalar() {
auto pp = PiecewisePolynomial<T>::ZeroOrderHold(Vector3<T>{0, 1, 2},
RowVector3<T>{1.2, 3, 1.5});
TrajectorySource<T> source(pp);
auto context = source.CreateDefaultContext();
context->SetTime(0.5);
EXPECT_NEAR(ExtractDoubleOrThrow(source.get_output_port().Eval(*context))[0],
1.2, 1e-14);
}

GTEST_TEST(AdditionalTrajectorySourceTests, Scalars) {
TestScalar<AutoDiffXd>();
TestScalar<symbolic::Expression>();
}

} // namespace
} // namespace systems
} // namespace drake
21 changes: 16 additions & 5 deletions systems/primitives/trajectory_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,19 @@ void TrajectorySource<T>::DoCalcVectorOutput(
int len = trajectory_->rows();
output->head(len) = trajectory_->value(context.get_time());

double time = context.get_time();
bool set_zero = clamp_derivatives_ && (time > trajectory_->end_time() ||
time < trajectory_->start_time());
T time = context.get_time();
bool set_zero = false;
if (clamp_derivatives_ && !scalar_predicate<T>::is_bool) {
// zero_derivatives_beyond_limits is true by default, but presumably most
// users will not want the clamped derivatives for symbolic types.
log()->warn(
"TrajectorySource: Derivatives are not clamped for symbolic types. "
"Pass zero_derivatives_beyond_limits=false in the constructor to avoid "
"this warning");
} else {
set_zero = clamp_derivatives_ && (time > trajectory_->end_time() ||
time < trajectory_->start_time());
}

for (size_t i = 0; i < derivatives_.size(); ++i) {
if (set_zero) {
Expand All @@ -64,7 +74,8 @@ void TrajectorySource<T>::DoCalcVectorOutput(
}
}

template class TrajectorySource<double>;

} // namespace systems
} // namespace drake

DRAKE_DEFINE_CLASS_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_SCALARS(
class ::drake::systems::TrajectorySource)
8 changes: 6 additions & 2 deletions systems/primitives/trajectory_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace systems {
/// - y0
/// @endsystem
///
/// @tparam_double_only
/// @tparam_default_scalar
/// @ingroup primitive_systems
template <typename T>
class TrajectorySource final : public SingleOutputVectorSource<T> {
Expand All @@ -37,7 +37,8 @@ class TrajectorySource final : public SingleOutputVectorSource<T> {
/// @param output_derivative_order The number of times to take the derivative.
/// Must be greater than or equal to zero.
/// @param zero_derivatives_beyond_limits All derivatives will be zero before
/// the start time or after the end time of @p trajectory.
/// the start time or after the end time of @p trajectory. However, this
/// clamping is ignored for T=Expression.
/// @pre The value of `trajectory` is a column vector. More precisely,
/// trajectory.cols() == 1.
explicit TrajectorySource(const trajectories::Trajectory<T>& trajectory,
Expand Down Expand Up @@ -66,3 +67,6 @@ class TrajectorySource final : public SingleOutputVectorSource<T> {

} // namespace systems
} // namespace drake

DRAKE_DECLARE_CLASS_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_SCALARS(
class ::drake::systems::TrajectorySource)

0 comments on commit e7e35c3

Please sign in to comment.