From 865d974c0285d345425d76df30c38614c828c8c8 Mon Sep 17 00:00:00 2001 From: Brax Team Date: Mon, 7 Oct 2024 13:56:06 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 683315255 Change-Id: I7e46153ca4e925d5708de0682609f97fe926a49a --- README.md | 2 +- brax/__init__.py | 2 +- brax/envs/base.py | 8 ++++-- brax/generalized/pipeline.py | 1 + brax/io/json.py | 38 +++------------------------ brax/io/torch.py | 6 ++--- brax/mjx/pipeline.py | 9 +++++-- brax/mjx/pipeline_test.py | 16 ++++++++++++ brax/positional/pipeline.py | 1 + brax/spring/pipeline.py | 1 + brax/training/learner.py | 6 ++--- brax/v1/jumpy.py | 10 ++++--- brax/v1/tests/jumpy_test.py | 2 +- brax/v1/tests/physics_test.py | 3 ++- brax/visualizer/js/system.js | 42 ++++-------------------------- brax/visualizer/js/viewer.js | 28 +++++++++----------- docs/release-notes/next-release.md | 1 + docs/release-notes/v0.11.0.md | 5 ++++ setup.py | 2 +- 19 files changed, 77 insertions(+), 106 deletions(-) create mode 100644 docs/release-notes/next-release.md create mode 100644 docs/release-notes/v0.11.0.md diff --git a/README.md b/README.md index 6db868bf..e06b1ba9 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use: author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem}, title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation}, url = {http://github.com/google/brax}, - version = {0.10.5}, + version = {0.11.0}, year = {2021}, } ``` diff --git a/brax/__init__.py b/brax/__init__.py index 018e4115..0138714a 100644 --- a/brax/__init__.py +++ b/brax/__init__.py @@ -14,7 +14,7 @@ """Import top-level classes and functions here for encapsulation/clarity.""" -__version__ = '0.10.5' +__version__ = '0.11.0' from brax.base import Motion from brax.base import State diff --git a/brax/envs/base.py b/brax/envs/base.py index 4d48ce1e..9097ac2f 100644 --- a/brax/envs/base.py +++ b/brax/envs/base.py @@ -113,10 +113,14 @@ def __init__( self._debug = debug def pipeline_init( - self, q: jax.Array, qd: jax.Array, act: Optional[jax.Array] = None + self, + q: jax.Array, + qd: jax.Array, + act: Optional[jax.Array] = None, + ctrl: Optional[jax.Array] = None, ) -> base.State: """Initializes the pipeline state.""" - return self._pipeline.init(self.sys, q, qd, act, self._debug) + return self._pipeline.init(self.sys, q, qd, act, ctrl, self._debug) def pipeline_step(self, pipeline_state: Any, action: jax.Array) -> base.State: """Takes a physics step using the physics pipeline.""" diff --git a/brax/generalized/pipeline.py b/brax/generalized/pipeline.py index 808fbfe2..982e749e 100644 --- a/brax/generalized/pipeline.py +++ b/brax/generalized/pipeline.py @@ -34,6 +34,7 @@ def init( q: jax.Array, qd: jax.Array, unused_act: Optional[jax.Array] = None, + unused_ctrl: Optional[jax.Array] = None, debug: bool = False, ) -> State: """Initializes physics state. diff --git a/brax/io/json.py b/brax/io/json.py index 4614f72a..94014f5b 100644 --- a/brax/io/json.py +++ b/brax/io/json.py @@ -16,7 +16,7 @@ """Saves a system config and trajectory as json.""" import json -from typing import List, Text, Tuple +from typing import Optional, List, Text, Tuple from brax.base import State, System from etils import epath @@ -26,12 +26,8 @@ import numpy as np -# State attributes needed for the visualizer. -_STATE_ATTR = ['x', 'contact'] - # Fields that get json encoded. _ENCODE_FIELDS = [ - 'contact', 'opt', 'timestep', 'face', @@ -84,30 +80,6 @@ def _to_dict(obj): return obj -def _compress_contact(states: State) -> State: - """Reduces the number of contacts based on penetration > 0.""" - if states.contact is None or states.contact.pos.shape[0] == 0: - return states - - contact_mask = states.contact.dist < 0 - n_contact = contact_mask.sum(axis=1).max() - - def pad(arr, n): - r = jp.zeros(n) - if len(arr.shape) > 1: - r = jp.zeros((n, *arr.shape[1:])) - r = r.at[: arr.shape[0], ...].set(arr) - return r - - def compress(contact, i): - return jax.tree.map( - lambda x: pad(x[contact_mask[i]], n_contact), contact.take(i) - ) - - c = [compress(states.contact, i) for i in range(states.x.pos.shape[0])] - return states.replace(contact=jax.tree.map(lambda *x: jp.stack(x), *c)) - - def _get_mesh(mj: mujoco.MjModel, i: int) -> Tuple[np.ndarray, np.ndarray]: """Gets mesh from mj at index i.""" last = (i + 1) >= mj.nmesh @@ -179,12 +151,8 @@ def dumps(sys: System, states: List[State]) -> Text: d['geoms'] = link_geoms - # stack states for the viewer - states = jax.tree.map(lambda *x: jp.stack(x), *states) - states = _compress_contact(states) - - states = _to_dict(states) - d['states'] = {k: states[k] for k in _STATE_ATTR} + # add states for the viewer, we only need 'x' (positions and orientations). + d['states'] = {'x': [_to_dict(s.x) for s in states]} return json.dumps(d) diff --git a/brax/io/torch.py b/brax/io/torch.py index c32b3bbd..ce555664 100644 --- a/brax/io/torch.py +++ b/brax/io/torch.py @@ -64,7 +64,7 @@ def _torch_dict_to_jax( @functools.singledispatch -def jax_to_torch(value: Any, device: Device = None) -> Any: +def jax_to_torch(value: Any, device: Union[Device, None] = None) -> Any: """Convert JAX values to PyTorch Tensors. Args: @@ -82,7 +82,7 @@ def jax_to_torch(value: Any, device: Device = None) -> Any: @jax_to_torch.register(jax.Array) def _jaxarray_to_tensor( - value: jax.Array, device: Device = None + value: jax.Array, device: Union[Device, None] = None ) -> torch.Tensor: """Converts a jax.Array into PyTorch Tensor.""" dpack = jax_dlpack.to_dlpack(value.astype("float32")) @@ -95,7 +95,7 @@ def _jaxarray_to_tensor( @jax_to_torch.register(abc.Mapping) def _jax_dict_to_torch( value: Dict[str, Union[jax.Array, Any]], - device: Device = None) -> Dict[str, Union[torch.Tensor, Any]]: + device: Union[Device, None] = None) -> Dict[str, Union[torch.Tensor, Any]]: """Converts a dict of jax.Arrays into a dict of PyTorch tensors.""" return type(value)( **{k: jax_to_torch(v, device=device) for k, v in value.items()}) # type: ignore diff --git a/brax/mjx/pipeline.py b/brax/mjx/pipeline.py index 3a017d17..1fb5c947 100644 --- a/brax/mjx/pipeline.py +++ b/brax/mjx/pipeline.py @@ -25,7 +25,7 @@ def _reformat_contact(sys: System, data: State) -> State: """Reformats the mjx.Contact into a brax.base.Contact.""" - if data.contact is None or data.ncon == 0: + if data.contact is None: return data elasticity = jp.zeros(data.contact.pos.shape[0]) @@ -45,6 +45,7 @@ def init( q: jax.Array, qd: jax.Array, act: Optional[jax.Array] = None, + ctrl: Optional[jax.Array] = None, unused_debug: bool = False, ) -> State: """Initializes physics data. @@ -54,6 +55,7 @@ def init( q: (q_size,) joint angle vector qd: (qd_size,) joint velocity vector act: actuator activations + ctrl: actuator controls unused_debug: ignored Returns: @@ -64,6 +66,8 @@ def init( data = data.replace(qpos=q, qvel=qd) if act is not None: data = data.replace(act=act) + if ctrl is not None: + data = data.replace(ctrl=ctrl) data = mjx.forward(sys, data) @@ -106,5 +110,6 @@ def step( offset = Transform.create(pos=offset) xd = offset.vmap().do(cvel) - data = _reformat_contact(sys, data) + if data.ncon > 0: + data = _reformat_contact(sys, data) return data.replace(q=q, qd=qd, x=x, xd=xd) diff --git a/brax/mjx/pipeline_test.py b/brax/mjx/pipeline_test.py index 12375317..254d175b 100644 --- a/brax/mjx/pipeline_test.py +++ b/brax/mjx/pipeline_test.py @@ -13,10 +13,12 @@ # limitations under the License. # pylint:disable=g-multiple-import +# pylint:disable=g-importing-member """Tests for spring physics pipeline.""" from absl.testing import absltest from brax import test_utils +from brax.base import Contact from brax.mjx import pipeline import jax from jax import numpy as jp @@ -30,6 +32,9 @@ def test_pendulum(self): model = test_utils.load_fixture('double_pendulum.xml') state = pipeline.init(model, model.init_q, jp.zeros(model.qd_size())) + + self.assertIsInstance(state.contact, Contact) + step_fn = jax.jit(pipeline.step) for _ in range(20): state = step_fn(model, state, jp.zeros(model.act_size())) @@ -43,6 +48,17 @@ def test_pendulum(self): np.testing.assert_almost_equal(data.qvel, state.qd, decimal=3) np.testing.assert_almost_equal(data.xpos[1:], state.x.pos, decimal=4) + def test_pipeline_init_with_ctrl(self): + model = test_utils.load_fixture('single_spherical_pendulum_position.xml') + ctrl = jp.array([0.3, 0.5, 0.4]) + state = pipeline.init( + model, + model.init_q, + jp.zeros(model.qd_size()), + ctrl=ctrl, + ) + np.testing.assert_array_almost_equal(state.ctrl, ctrl) + if __name__ == '__main__': absltest.main() diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py index 7f666300..4969d5d2 100644 --- a/brax/positional/pipeline.py +++ b/brax/positional/pipeline.py @@ -34,6 +34,7 @@ def init( q: jax.Array, qd: jax.Array, unused_act: Optional[jax.Array] = None, + unused_ctrl: Optional[jax.Array] = None, debug: bool = False, ) -> State: """Initializes physics state. diff --git a/brax/spring/pipeline.py b/brax/spring/pipeline.py index 887e4816..58ce436d 100644 --- a/brax/spring/pipeline.py +++ b/brax/spring/pipeline.py @@ -35,6 +35,7 @@ def init( q: jax.Array, qd: jax.Array, unused_act: Optional[jax.Array] = None, + unused_ctrl: Optional[jax.Array] = None, debug: bool = False, ) -> State: """Initializes physics state. diff --git a/brax/training/learner.py b/brax/training/learner.py index 9dd3c3c5..d8523cc5 100644 --- a/brax/training/learner.py +++ b/brax/training/learner.py @@ -46,8 +46,8 @@ flags.DEFINE_bool('use_v2', True, 'Use Brax v2.') flags.DEFINE_enum( 'backend', - 'generalized', - ['spring', 'generalized', 'positional'], + 'mjx', + ['mjx', 'spring', 'generalized', 'positional'], 'The physics backend to use.', ) flags.DEFINE_bool('legacy_spring', False, 'Brax v1 backend.') @@ -121,7 +121,7 @@ 'A reward shift to get rid of "stay alive" bonus.') # ARS hps. -flags.DEFINE_integer('policy updates', None, +flags.DEFINE_integer('policy_updates', None, 'Number of policy updates in APG.') diff --git a/brax/v1/jumpy.py b/brax/v1/jumpy.py index a5e40ca3..168d9c10 100644 --- a/brax/v1/jumpy.py +++ b/brax/v1/jumpy.py @@ -40,9 +40,13 @@ def _in_jit() -> bool: """Returns true if currently inside a jax.jit call or jit is disabled.""" if jax.config.jax_disable_jit: return True - return 'DynamicJaxprTrace' in str( - jax.core.thread_local_state.trace_state.trace_stack - ) + + if jax.__version_info__ <= (0, 4, 33): + return 'DynamicJaxprTrace' in str( + jax.core.thread_local_state.trace_state.trace_stack + ) + + return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE() def _which_np(*args): diff --git a/brax/v1/tests/jumpy_test.py b/brax/v1/tests/jumpy_test.py index 6cd55b4d..49d62fa0 100644 --- a/brax/v1/tests/jumpy_test.py +++ b/brax/v1/tests/jumpy_test.py @@ -26,7 +26,7 @@ class ForiLoopTest(absltest.TestCase): def testForiLoopTest(self): a = jp.fori_loop(2, 4, lambda i, x: i + x, jp.array(1.)) - self.assertIsInstance(a, np.float_) + self.assertIsInstance(a, np.float64) self.assertEqual(a.shape, ()) self.assertAlmostEqual(a, 1.0 + 2.0 + 3.0) diff --git a/brax/v1/tests/physics_test.py b/brax/v1/tests/physics_test.py index 9bef1458..a906432c 100644 --- a/brax/v1/tests/physics_test.py +++ b/brax/v1/tests/physics_test.py @@ -522,7 +522,8 @@ def test_pendulum_period(self, mass, radius, vel): vel=jp.array([[0., 0., 0.], [0., vel, 0.]]), ang=jp.array([[0., 0., 0.], [vel, 0., 0.]])) qp, _ = jax.jit(sys.step)(qp, jp.array([])) - self.assertAlmostEqual(qp.pos[1, 1], 0., 3) # returned to the origin + # returned to the origin + self.assertAlmostEqual(qp.pos[1, 1], 0.0, delta=1e-3) offsets = [-15, 15, -45, 45, -75, 75] axes = [ diff --git a/brax/visualizer/js/system.js b/brax/visualizer/js/system.js index 671bbd71..93ad6ec1 100644 --- a/brax/visualizer/js/system.js +++ b/brax/visualizer/js/system.js @@ -217,30 +217,11 @@ function createScene(system) { scene.add(parent); }); - if (system.states.contact) { - /* add contact point spheres */ - for (let i = 0; i < system.states.contact.pos[0].length; i++) { - const parent = new THREE.Group(); - parent.name = 'contact' + i; - let child; - - const mat = new THREE.MeshPhongMaterial({color: 0xff0000}); - const sphere_geom = new THREE.SphereGeometry(minAxisSize / 20.0, 6, 6); - child = new THREE.Mesh(sphere_geom, mat); - child.baseMaterial = child.material; - child.castShadow = false; - child.position.set(0, 0, 0); - - parent.add(child); - scene.add(parent); - } - } - return scene; } function createTrajectory(system) { - const times = [...Array(system.states.x.pos.length).keys()].map( + const times = [...Array(system.states.x.length).keys()].map( x => x * system.opt.timestep); const tracks = []; @@ -252,30 +233,17 @@ function createTrajectory(system) { return; } const group = name.replaceAll('/', '_'); // sanitize node name - const pos = system.states.x.pos.map(p => [p[i][0], p[i][1], p[i][2]]); + const pos = system.states.x.map( + x => [x.pos[i][0], x.pos[i][1], x.pos[i][2]]); const rot = - system.states.x.rot.map(r => [r[i][1], r[i][2], r[i][3], r[i][0]]); + system.states.x.map( + x => [x.rot[i][1], x.rot[i][2], x.rot[i][3], x.rot[i][0]]); tracks.push(new THREE.VectorKeyframeTrack( 'scene/' + group + '.position', times, pos.flat())); tracks.push(new THREE.QuaternionKeyframeTrack( 'scene/' + group + '.quaternion', times, rot.flat())); }); - if (system.states.contact) { - /* add contact debug point trajectory */ - for (let i = 0; i < system.states.contact.pos[0].length; i++) { - const group = 'contact' + i; - const pos = system.states.contact.pos.map(p => [p[i][0], p[i][1], p[i][2]]); - const visible = system.states.contact.dist.map(p => p[i] < -1e-6); - tracks.push(new THREE.VectorKeyframeTrack( - 'scene/' + group + '.position', times, pos.flat(), - THREE.InterpolateDiscrete)); - tracks.push(new THREE.BooleanKeyframeTrack( - 'scene/' + group + '.visible', times, visible, - THREE.InterpolateDiscrete)); - } - } - return new THREE.AnimationClip('Action', -1, tracks); } diff --git a/brax/visualizer/js/viewer.js b/brax/visualizer/js/viewer.js index 8082f59c..58ffffde 100644 --- a/brax/visualizer/js/viewer.js +++ b/brax/visualizer/js/viewer.js @@ -38,25 +38,21 @@ function downloadFile(name, contents, mime) { } /** - * Toggles the contact point debug in the scene. + * Toggles the debug mode for the scene. * @param {!ObjType} obj A Scene or Mesh object. - * @param {boolean} debug Whether to add contact debugging. + * @param {boolean} debug Whether to add axes and transparency to the scene. */ -function toggleContactDebug(obj, debug) { +function toggleDebugMode(obj, debug) { for (let i = 0; i < obj.children.length; i++) { let c = obj.children[i]; if (c.type == 'AxesHelper') { /* toggle visibility on world axis */ c.visible = debug; } - if (c.type == 'Group' && c.name && c.name.startsWith('contact')) { - /* toggle visibility of all contact points */ - c.children[0].visible = debug; - } if (c.type == 'Group') { /* recurse over group's children */ for (let j = 0; j < c.children.length; j++) { - toggleContactDebug(c.children[j], debug); + toggleDebugMode(c.children[j], debug); } } } @@ -160,7 +156,7 @@ class Viewer { this.bodyFolders = {}; for (let c of this.scene.children) { - if (!c.name || c.name.startsWith('contact')) continue; + if (!c.name) continue; const folder = bodiesFolder.addFolder(c.name); this.bodyFolders[c.name] = folder; folder.close(); @@ -185,11 +181,11 @@ class Viewer { saveFolder.close(); /* debugger */ - this.contactDebug = system.states.contact !== null; + this.debugMode = false; let debugFolder = this.gui.addFolder('Debugger'); - debugFolder.add(this, 'contactDebug') - .name(system.states.contact ? 'contacts' : 'axis') - .onChange((value) => this.setContactDebug(value)); + debugFolder.add(this, 'debugMode') + .name('axis') + .onChange((value) => this.setDebugMode(value)); /* done setting up the gui */ this.gui.close(); @@ -253,7 +249,7 @@ class Viewer { } render() { - toggleContactDebug(this.scene, this.contactDebug); + toggleDebugMode(this.scene, this.debugMode); this.renderer.render(this.scene, this.camera); this.needsRender = false; } @@ -319,8 +315,8 @@ class Viewer { downloadFile('system.json', JSON.stringify(this.system)); } - setContactDebug(val) { - this.contactDebug = val; + setDebugMode(val) { + this.debugMode = val; } setHover(object, hovering) { diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md new file mode 100644 index 00000000..cbefb69c --- /dev/null +++ b/docs/release-notes/next-release.md @@ -0,0 +1 @@ +# Brax Release Notes \ No newline at end of file diff --git a/docs/release-notes/v0.11.0.md b/docs/release-notes/v0.11.0.md new file mode 100644 index 00000000..cf568001 --- /dev/null +++ b/docs/release-notes/v0.11.0.md @@ -0,0 +1,5 @@ +# Brax v0.11.0 Release Notes + +* Remove contact debugging from the viewer. This is a breaking change compared to older versions of brax. Old saved files will not render using the new viewer. +* Added `ctrl` as input to the `MJX pipeline.init`. +* Fixes to #513, #512, and #504. diff --git a/setup.py b/setup.py index 3c88c621..76b7dd54 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ setup( name="brax", - version="0.10.5", + version="0.11.0", description="A differentiable physics engine written in JAX.", author="Brax Authors", author_email="no-reply@google.com",