forked from RobotLocomotion/drake
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add bindings for barycentric mesh, barycentric mesh system, and value…
… iteration
- Loading branch information
1 parent
e684a6b
commit f075f88
Showing
11 changed files
with
221 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include <pybind11/eigen.h> | ||
#include <pybind11/functional.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "drake/bindings/pydrake/pydrake_pybind.h" | ||
#include "drake/math/barycentric.h" | ||
|
||
namespace drake { | ||
namespace pydrake { | ||
|
||
PYBIND11_MODULE(math, m) { | ||
// NOLINTNEXTLINE(build/namespaces): Emulate placement in namespace. | ||
using namespace drake::math; | ||
|
||
m.doc() = "Bindings for //math."; | ||
|
||
// TODO(eric.cousineau): At present, we only bind doubles. | ||
// In the future, we will bind more scalar types, and enable scalar | ||
// conversion. | ||
using T = double; | ||
|
||
py::class_<BarycentricMesh<T>>(m, "BarycentricMesh") | ||
.def(py::init<BarycentricMesh<T>::MeshGrid>()) | ||
.def("get_input_grid", &BarycentricMesh<T>::get_input_grid) | ||
.def("get_input_size", &BarycentricMesh<T>::get_input_size) | ||
.def("get_num_mesh_points", &BarycentricMesh<T>::get_num_mesh_points) | ||
.def("get_num_interpolants", &BarycentricMesh<T>::get_num_interpolants) | ||
.def("get_mesh_point", overload_cast_explicit<VectorX<T>,int> | ||
(&BarycentricMesh<T>::get_mesh_point)) | ||
.def("Eval", overload_cast_explicit<VectorX<T>,const Eigen::Ref<const MatrixX<T>>&, | ||
const Eigen::Ref<const VectorX<T>>&>( | ||
&BarycentricMesh<T>::Eval)) | ||
.def("MeshValuesFrom", &BarycentricMesh<T>::MeshValuesFrom); | ||
} | ||
|
||
} // namespace pydrake | ||
} // namespace drake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#include <pybind11/eigen.h> | ||
#include <pybind11/functional.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "drake/bindings/pydrake/pydrake_pybind.h" | ||
#include "drake/systems/controllers/dynamic_programming.h" | ||
|
||
namespace drake { | ||
namespace pydrake { | ||
|
||
PYBIND11_MODULE(controllers, m) { | ||
// NOLINTNEXTLINE(build/namespaces): Emulate placement in namespace. | ||
using namespace drake::systems::controllers; | ||
|
||
py::module::import("pydrake.math"); | ||
py::module::import("pydrake.systems.primitives"); | ||
|
||
py::class_<DynamicProgrammingOptions>(m, "DynamicProgrammingOptions") | ||
.def(py::init<>()) | ||
.def_readwrite("discount_factor", | ||
&DynamicProgrammingOptions::discount_factor); | ||
|
||
m.def("FittedValueIteration", &FittedValueIteration); | ||
} | ||
|
||
} // namespace pydrake | ||
} // namespace drake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python | ||
|
||
from __future__ import print_function | ||
|
||
import unittest | ||
import math | ||
import numpy as np | ||
|
||
from pydrake.examples.pendulum import PendulumPlant | ||
from pydrake.systems.analysis import Simulator | ||
from pydrake.math import BarycentricMesh | ||
from pydrake.systems.controllers import ( | ||
DynamicProgrammingOptions, FittedValueIteration) | ||
|
||
|
||
class TestControllers(unittest.TestCase): | ||
def test_fitted_value_iteration_pendulum(self): | ||
plant = PendulumPlant() | ||
simulator = Simulator(plant) | ||
|
||
def quadratic_regulator_cost(context): | ||
print("got here") | ||
# x = context.get_continuous_state_vector().CopyToVector() | ||
# u = plant.EvalVectorInput(context,0) | ||
# print(type(x.dot(x) + u.dot(u))) | ||
return 0 | ||
|
||
state_mesh = [set(np.linspace(0,2*math.pi,51)), | ||
set(np.linspace(-10,10,51))] | ||
input_limit = 2 | ||
input_mesh = [set(np.linspace(-input_limit, input_limit, 9))] | ||
timestep = 0.01 | ||
|
||
options = DynamicProgrammingOptions() | ||
|
||
policy, value_function = FittedValueIteration(simulator, | ||
quadratic_regulator_cost, | ||
state_mesh, input_mesh, timestep, options) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import unittest | ||
from pydrake.math import (BarycentricMesh) | ||
import numpy as np | ||
|
||
class TestBarycentricMesh(unittest.TestCase): | ||
def testSpelling(self): | ||
mesh = BarycentricMesh([{0, 1}, {0, 1}]) | ||
values = np.array([[0, 1, 2, 3]]) | ||
|
||
mesh.get_input_grid() | ||
self.assertTrue(mesh.get_input_size() == 2) | ||
self.assertTrue(mesh.get_num_mesh_points() == 4) | ||
self.assertTrue(mesh.get_num_interpolants() == 3) | ||
self.assertTrue((mesh.get_mesh_point(0) == [0., 0.]).all()) | ||
self.assertTrue(mesh.Eval(values, (0, 1))[0] == 2) | ||
|
||
def testMeshValuesFrom(self): | ||
mesh = BarycentricMesh([{0, 1}, {0, 1}]) | ||
|
||
def mynorm(x): | ||
return [x.dot(x)] | ||
|
||
values = mesh.MeshValuesFrom(mynorm) | ||
self.assertTrue(values.size == 4) | ||
|
||
def testSurf(self): | ||
mesh = BarycentricMesh([{0, 1}, {0, 1}]) | ||
values = np.array([[0, 1, 2, 3]]) | ||
|
||
import matplotlib.pyplot as plt | ||
from mpl_toolkits.mplot3d import Axes3D | ||
fig = plt.figure() | ||
ax = fig.add_subplot(111, projection='3d') | ||
|
||
X, Y = np.meshgrid(list(mesh.get_input_grid()[0]), list(mesh.get_input_grid()[1])) | ||
Z = X | ||
for i in range(0, X.size): | ||
Z.itemset(i, mesh.Eval(values,(X.item(i),Y.item(i)))[0]) | ||
|
||
ax.plot_surface(X,Y,Z) | ||
# plt.show() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters