Skip to content

Commit

Permalink
controllers: Use WrapPtr to ensure Context<T> is not copied by pybind.
Browse files Browse the repository at this point in the history
  • Loading branch information
EricCousineau-TRI committed Feb 13, 2018
1 parent 7f15c8a commit 709abd2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
1 change: 1 addition & 0 deletions bindings/pydrake/systems/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ drake_pybind_library(
drake_pybind_library(
name = "controllers_py",
cc_deps = [
"//bindings/pydrake/util:wrap_function",
"//systems/controllers:dynamic_programming",
],
cc_so_name = "controllers",
Expand Down
25 changes: 24 additions & 1 deletion bindings/pydrake/systems/controllers_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,34 @@
#include <pybind11/stl.h>

#include "drake/bindings/pydrake/pydrake_pybind.h"
#include "drake/bindings/pydrake/util/wrap_function.h"
#include "drake/systems/controllers/dynamic_programming.h"

namespace drake {
namespace pydrake {

namespace {

template <typename T, typename = void>
struct wrap_ptr : public wrap_arg_default<T> {};

template <typename T>
struct wrap_ptr<const systems::Context<T>&> {
using Type = systems::Context<T>;
static auto wrap(const Type& arg) { return &arg; }
static auto unwrap(const Type* arg_wrapped) { return *arg_wrapped; }
};

// Ensures that `const Context<T>&` is wrapped with `const Context<T>*`.
// TODO(eric.cousineau): Replace this with general wrappper, place in
// `pydrake_pybind` or somewhere related.
template <typename Func>
auto WrapPtr(Func&& func) {
return WrapFunction<wrap_ptr>(std::forward<Func>(func));
}

} // namespace

PYBIND11_MODULE(controllers, m) {
// NOLINTNEXTLINE(build/namespaces): Emulate placement in namespace.
using namespace drake::systems::controllers;
Expand All @@ -21,7 +44,7 @@ PYBIND11_MODULE(controllers, m) {
.def_readwrite("discount_factor",
&DynamicProgrammingOptions::discount_factor);

m.def("FittedValueIteration", &FittedValueIteration);
m.def("FittedValueIteration", WrapPtr(&FittedValueIteration));
}

} // namespace pydrake
Expand Down

0 comments on commit 709abd2

Please sign in to comment.