Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683315255
Change-Id: I7e46153ca4e925d5708de0682609f97fe926a49a
  • Loading branch information
Brax Team authored and btaba committed Oct 7, 2024
1 parent c89a3ad commit 865d974
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 106 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use:
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
version = {0.10.5},
version = {0.11.0},
year = {2021},
}
```
Expand Down
2 changes: 1 addition & 1 deletion brax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Import top-level classes and functions here for encapsulation/clarity."""

__version__ = '0.10.5'
__version__ = '0.11.0'

from brax.base import Motion
from brax.base import State
Expand Down
8 changes: 6 additions & 2 deletions brax/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,14 @@ def __init__(
self._debug = debug

def pipeline_init(
self, q: jax.Array, qd: jax.Array, act: Optional[jax.Array] = None
self,
q: jax.Array,
qd: jax.Array,
act: Optional[jax.Array] = None,
ctrl: Optional[jax.Array] = None,
) -> base.State:
"""Initializes the pipeline state."""
return self._pipeline.init(self.sys, q, qd, act, self._debug)
return self._pipeline.init(self.sys, q, qd, act, ctrl, self._debug)

def pipeline_step(self, pipeline_state: Any, action: jax.Array) -> base.State:
"""Takes a physics step using the physics pipeline."""
Expand Down
1 change: 1 addition & 0 deletions brax/generalized/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def init(
q: jax.Array,
qd: jax.Array,
unused_act: Optional[jax.Array] = None,
unused_ctrl: Optional[jax.Array] = None,
debug: bool = False,
) -> State:
"""Initializes physics state.
Expand Down
38 changes: 3 additions & 35 deletions brax/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Saves a system config and trajectory as json."""

import json
from typing import List, Text, Tuple
from typing import Optional, List, Text, Tuple

from brax.base import State, System
from etils import epath
Expand All @@ -26,12 +26,8 @@
import numpy as np


# State attributes needed for the visualizer.
_STATE_ATTR = ['x', 'contact']

# Fields that get json encoded.
_ENCODE_FIELDS = [
'contact',
'opt',
'timestep',
'face',
Expand Down Expand Up @@ -84,30 +80,6 @@ def _to_dict(obj):
return obj


def _compress_contact(states: State) -> State:
"""Reduces the number of contacts based on penetration > 0."""
if states.contact is None or states.contact.pos.shape[0] == 0:
return states

contact_mask = states.contact.dist < 0
n_contact = contact_mask.sum(axis=1).max()

def pad(arr, n):
r = jp.zeros(n)
if len(arr.shape) > 1:
r = jp.zeros((n, *arr.shape[1:]))
r = r.at[: arr.shape[0], ...].set(arr)
return r

def compress(contact, i):
return jax.tree.map(
lambda x: pad(x[contact_mask[i]], n_contact), contact.take(i)
)

c = [compress(states.contact, i) for i in range(states.x.pos.shape[0])]
return states.replace(contact=jax.tree.map(lambda *x: jp.stack(x), *c))


def _get_mesh(mj: mujoco.MjModel, i: int) -> Tuple[np.ndarray, np.ndarray]:
"""Gets mesh from mj at index i."""
last = (i + 1) >= mj.nmesh
Expand Down Expand Up @@ -179,12 +151,8 @@ def dumps(sys: System, states: List[State]) -> Text:

d['geoms'] = link_geoms

# stack states for the viewer
states = jax.tree.map(lambda *x: jp.stack(x), *states)
states = _compress_contact(states)

states = _to_dict(states)
d['states'] = {k: states[k] for k in _STATE_ATTR}
# add states for the viewer, we only need 'x' (positions and orientations).
d['states'] = {'x': [_to_dict(s.x) for s in states]}

return json.dumps(d)

Expand Down
6 changes: 3 additions & 3 deletions brax/io/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _torch_dict_to_jax(


@functools.singledispatch
def jax_to_torch(value: Any, device: Device = None) -> Any:
def jax_to_torch(value: Any, device: Union[Device, None] = None) -> Any:
"""Convert JAX values to PyTorch Tensors.
Args:
Expand All @@ -82,7 +82,7 @@ def jax_to_torch(value: Any, device: Device = None) -> Any:

@jax_to_torch.register(jax.Array)
def _jaxarray_to_tensor(
value: jax.Array, device: Device = None
value: jax.Array, device: Union[Device, None] = None
) -> torch.Tensor:
"""Converts a jax.Array into PyTorch Tensor."""
dpack = jax_dlpack.to_dlpack(value.astype("float32"))
Expand All @@ -95,7 +95,7 @@ def _jaxarray_to_tensor(
@jax_to_torch.register(abc.Mapping)
def _jax_dict_to_torch(
value: Dict[str, Union[jax.Array, Any]],
device: Device = None) -> Dict[str, Union[torch.Tensor, Any]]:
device: Union[Device, None] = None) -> Dict[str, Union[torch.Tensor, Any]]:
"""Converts a dict of jax.Arrays into a dict of PyTorch tensors."""
return type(value)(
**{k: jax_to_torch(v, device=device) for k, v in value.items()}) # type: ignore
9 changes: 7 additions & 2 deletions brax/mjx/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def _reformat_contact(sys: System, data: State) -> State:
"""Reformats the mjx.Contact into a brax.base.Contact."""
if data.contact is None or data.ncon == 0:
if data.contact is None:
return data

elasticity = jp.zeros(data.contact.pos.shape[0])
Expand All @@ -45,6 +45,7 @@ def init(
q: jax.Array,
qd: jax.Array,
act: Optional[jax.Array] = None,
ctrl: Optional[jax.Array] = None,
unused_debug: bool = False,
) -> State:
"""Initializes physics data.
Expand All @@ -54,6 +55,7 @@ def init(
q: (q_size,) joint angle vector
qd: (qd_size,) joint velocity vector
act: actuator activations
ctrl: actuator controls
unused_debug: ignored
Returns:
Expand All @@ -64,6 +66,8 @@ def init(
data = data.replace(qpos=q, qvel=qd)
if act is not None:
data = data.replace(act=act)
if ctrl is not None:
data = data.replace(ctrl=ctrl)

data = mjx.forward(sys, data)

Expand Down Expand Up @@ -106,5 +110,6 @@ def step(
offset = Transform.create(pos=offset)
xd = offset.vmap().do(cvel)

data = _reformat_contact(sys, data)
if data.ncon > 0:
data = _reformat_contact(sys, data)
return data.replace(q=q, qd=qd, x=x, xd=xd)
16 changes: 16 additions & 0 deletions brax/mjx/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

# pylint:disable=g-multiple-import
# pylint:disable=g-importing-member
"""Tests for spring physics pipeline."""

from absl.testing import absltest
from brax import test_utils
from brax.base import Contact
from brax.mjx import pipeline
import jax
from jax import numpy as jp
Expand All @@ -30,6 +32,9 @@ def test_pendulum(self):
model = test_utils.load_fixture('double_pendulum.xml')

state = pipeline.init(model, model.init_q, jp.zeros(model.qd_size()))

self.assertIsInstance(state.contact, Contact)

step_fn = jax.jit(pipeline.step)
for _ in range(20):
state = step_fn(model, state, jp.zeros(model.act_size()))
Expand All @@ -43,6 +48,17 @@ def test_pendulum(self):
np.testing.assert_almost_equal(data.qvel, state.qd, decimal=3)
np.testing.assert_almost_equal(data.xpos[1:], state.x.pos, decimal=4)

def test_pipeline_init_with_ctrl(self):
model = test_utils.load_fixture('single_spherical_pendulum_position.xml')
ctrl = jp.array([0.3, 0.5, 0.4])
state = pipeline.init(
model,
model.init_q,
jp.zeros(model.qd_size()),
ctrl=ctrl,
)
np.testing.assert_array_almost_equal(state.ctrl, ctrl)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions brax/positional/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def init(
q: jax.Array,
qd: jax.Array,
unused_act: Optional[jax.Array] = None,
unused_ctrl: Optional[jax.Array] = None,
debug: bool = False,
) -> State:
"""Initializes physics state.
Expand Down
1 change: 1 addition & 0 deletions brax/spring/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def init(
q: jax.Array,
qd: jax.Array,
unused_act: Optional[jax.Array] = None,
unused_ctrl: Optional[jax.Array] = None,
debug: bool = False,
) -> State:
"""Initializes physics state.
Expand Down
6 changes: 3 additions & 3 deletions brax/training/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
flags.DEFINE_bool('use_v2', True, 'Use Brax v2.')
flags.DEFINE_enum(
'backend',
'generalized',
['spring', 'generalized', 'positional'],
'mjx',
['mjx', 'spring', 'generalized', 'positional'],
'The physics backend to use.',
)
flags.DEFINE_bool('legacy_spring', False, 'Brax v1 backend.')
Expand Down Expand Up @@ -121,7 +121,7 @@
'A reward shift to get rid of "stay alive" bonus.')

# ARS hps.
flags.DEFINE_integer('policy updates', None,
flags.DEFINE_integer('policy_updates', None,
'Number of policy updates in APG.')


Expand Down
10 changes: 7 additions & 3 deletions brax/v1/jumpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def _in_jit() -> bool:
"""Returns true if currently inside a jax.jit call or jit is disabled."""
if jax.config.jax_disable_jit:
return True
return 'DynamicJaxprTrace' in str(
jax.core.thread_local_state.trace_state.trace_stack
)

if jax.__version_info__ <= (0, 4, 33):
return 'DynamicJaxprTrace' in str(
jax.core.thread_local_state.trace_state.trace_stack
)

return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()


def _which_np(*args):
Expand Down
2 changes: 1 addition & 1 deletion brax/v1/tests/jumpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ForiLoopTest(absltest.TestCase):

def testForiLoopTest(self):
a = jp.fori_loop(2, 4, lambda i, x: i + x, jp.array(1.))
self.assertIsInstance(a, np.float_)
self.assertIsInstance(a, np.float64)
self.assertEqual(a.shape, ())
self.assertAlmostEqual(a, 1.0 + 2.0 + 3.0)

Expand Down
3 changes: 2 additions & 1 deletion brax/v1/tests/physics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,8 @@ def test_pendulum_period(self, mass, radius, vel):
vel=jp.array([[0., 0., 0.], [0., vel, 0.]]),
ang=jp.array([[0., 0., 0.], [vel, 0., 0.]]))
qp, _ = jax.jit(sys.step)(qp, jp.array([]))
self.assertAlmostEqual(qp.pos[1, 1], 0., 3) # returned to the origin
# returned to the origin
self.assertAlmostEqual(qp.pos[1, 1], 0.0, delta=1e-3)

offsets = [-15, 15, -45, 45, -75, 75]
axes = [
Expand Down
42 changes: 5 additions & 37 deletions brax/visualizer/js/system.js
Original file line number Diff line number Diff line change
Expand Up @@ -217,30 +217,11 @@ function createScene(system) {
scene.add(parent);
});

if (system.states.contact) {
/* add contact point spheres */
for (let i = 0; i < system.states.contact.pos[0].length; i++) {
const parent = new THREE.Group();
parent.name = 'contact' + i;
let child;

const mat = new THREE.MeshPhongMaterial({color: 0xff0000});
const sphere_geom = new THREE.SphereGeometry(minAxisSize / 20.0, 6, 6);
child = new THREE.Mesh(sphere_geom, mat);
child.baseMaterial = child.material;
child.castShadow = false;
child.position.set(0, 0, 0);

parent.add(child);
scene.add(parent);
}
}

return scene;
}

function createTrajectory(system) {
const times = [...Array(system.states.x.pos.length).keys()].map(
const times = [...Array(system.states.x.length).keys()].map(
x => x * system.opt.timestep);
const tracks = [];

Expand All @@ -252,30 +233,17 @@ function createTrajectory(system) {
return;
}
const group = name.replaceAll('/', '_'); // sanitize node name
const pos = system.states.x.pos.map(p => [p[i][0], p[i][1], p[i][2]]);
const pos = system.states.x.map(
x => [x.pos[i][0], x.pos[i][1], x.pos[i][2]]);
const rot =
system.states.x.rot.map(r => [r[i][1], r[i][2], r[i][3], r[i][0]]);
system.states.x.map(
x => [x.rot[i][1], x.rot[i][2], x.rot[i][3], x.rot[i][0]]);
tracks.push(new THREE.VectorKeyframeTrack(
'scene/' + group + '.position', times, pos.flat()));
tracks.push(new THREE.QuaternionKeyframeTrack(
'scene/' + group + '.quaternion', times, rot.flat()));
});

if (system.states.contact) {
/* add contact debug point trajectory */
for (let i = 0; i < system.states.contact.pos[0].length; i++) {
const group = 'contact' + i;
const pos = system.states.contact.pos.map(p => [p[i][0], p[i][1], p[i][2]]);
const visible = system.states.contact.dist.map(p => p[i] < -1e-6);
tracks.push(new THREE.VectorKeyframeTrack(
'scene/' + group + '.position', times, pos.flat(),
THREE.InterpolateDiscrete));
tracks.push(new THREE.BooleanKeyframeTrack(
'scene/' + group + '.visible', times, visible,
THREE.InterpolateDiscrete));
}
}

return new THREE.AnimationClip('Action', -1, tracks);
}

Expand Down
Loading

0 comments on commit 865d974

Please sign in to comment.