Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633254707
Change-Id: If5931c0a90944c754305a300046e2e10c53e0ff3
  • Loading branch information
Brax Team authored and btaba committed May 13, 2024
1 parent 3fdbd82 commit 0d513cd
Show file tree
Hide file tree
Showing 34 changed files with 109 additions and 71 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ If you would like to reference Brax in a publication, please use:
author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
url = {http://github.com/google/brax},
version = {0.10.3},
version = {0.10.4},
year = {2021},
}
```
Expand Down
2 changes: 1 addition & 1 deletion brax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Import top-level classes and functions here for encapsulation/clarity."""

__version__ = '0.10.3'
__version__ = '0.10.4'

from brax.base import Motion
from brax.base import State
Expand Down
14 changes: 10 additions & 4 deletions brax/actuator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def _actuator_step(pipeline, sys, q, qd, act, dt, n):
sys = sys.replace(dt=dt)
sys = sys.tree_replace({'opt.timestep': dt})
def f(state, _):
return jax.jit(pipeline.step)(sys, state, act), None

Expand Down Expand Up @@ -163,7 +163,9 @@ def test_spherical_pendulum_mj_generalized(self, config):
mujoco.mj_step(mj_model, mj_data)
mq, mqd = jp.asarray(mj_data.qpos), jp.asarray(mj_data.qvel)

gq, gqd = _actuator_step(g_pipeline, sys, q, qd, act=act, dt=sys.dt, n=1000)
gq, gqd = _actuator_step(
g_pipeline, sys, q, qd, act=act, dt=sys.opt.timestep, n=1000
)
np.testing.assert_array_almost_equal(gq, mq, 3)
np.testing.assert_array_almost_equal(gqd, mqd, 3)

Expand All @@ -178,8 +180,12 @@ def test_single_pendulum_spring_positional(self, config):

q, qd = sys.init_q, jp.zeros(sys.qd_size())

sq, sqd = _actuator_step(s_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
pq, pqd = _actuator_step(p_pipeline, sys, q, qd, act=act, dt=sys.dt, n=500)
sq, sqd = _actuator_step(
s_pipeline, sys, q, qd, act=act, dt=sys.opt.timestep, n=500
)
pq, pqd = _actuator_step(
p_pipeline, sys, q, qd, act=act, dt=sys.opt.timestep, n=500
)
np.testing.assert_array_almost_equal(sq, pq, 2)
np.testing.assert_array_almost_equal(sqd, pqd, 2)

Expand Down
2 changes: 0 additions & 2 deletions brax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ class System(mjx.Model):
r"""Describes a physical environment: its links, joints and geometries.
Attributes:
dt: timestep used for the simulation
gravity: (3,) linear universal force applied during forward dynamics
viscosity: (1,) viscosity of the medium applied to all links
density: (1,) density of the medium applied to all links
Expand Down Expand Up @@ -451,7 +450,6 @@ class System(mjx.Model):
mj_model: mujoco.MjModel that was used to build this brax System
"""

dt: jax.Array
gravity: jax.Array
viscosity: Union[float, jax.Array]
density: Union[float, jax.Array]
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
sys = sys.tree_replace({'opt.timestep': 0.005})
n_frames = 10

if backend == 'mjx':
Expand Down
10 changes: 5 additions & 5 deletions brax/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from brax.spring import pipeline as s_pipeline
from flax import struct
import jax
import mujoco
from mujoco import mjx
import numpy as np


Expand Down Expand Up @@ -114,9 +112,11 @@ def __init__(
self._n_frames = n_frames
self._debug = debug

def pipeline_init(self, q: jax.Array, qd: jax.Array) -> base.State:
def pipeline_init(
self, q: jax.Array, qd: jax.Array, act: Optional[jax.Array] = None
) -> base.State:
"""Initializes the pipeline state."""
return self._pipeline.init(self.sys, q, qd, self._debug)
return self._pipeline.init(self.sys, q, qd, act, self._debug)

def pipeline_step(self, pipeline_state: Any, action: jax.Array) -> base.State:
"""Takes a physics step using the physics pipeline."""
Expand All @@ -132,7 +132,7 @@ def f(state, _):
@property
def dt(self) -> jax.Array:
"""The timestep used for each env step."""
return self.sys.dt * self._n_frames
return self.sys.opt.timestep * self._n_frames # pytype: disable=attribute-error

@property
def observation_size(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.003125)
sys = sys.tree_replace({'opt.timestep': 0.003125})
n_frames = 16
gear = jp.array([120, 90, 60, 120, 100, 100])
sys = sys.replace(actuator=sys.actuator.replace(gear=gear))
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.0015)
sys = sys.tree_replace({'opt.timestep': 0.0015})
n_frames = 10
gear = jp.array([
350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0,
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/humanoidstandup.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self, backend='generalized', **kwargs):
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.0015)
sys = sys.tree_replace({'opt.timestep': 0.0015})
n_frames = 10
sys = sys.replace(
actuator=sys.actuator.replace(
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/inverted_double_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(self, backend='generalized', **kwargs):
n_frames = 2

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
sys = sys.tree_replace({'opt.timestep': 0.005})
n_frames = 4

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/inverted_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, backend='generalized', **kwargs):
n_frames = 2

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
sys = sys.tree_replace({'opt.timestep': 0.005})
n_frames = 4

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(self, backend='generalized', **kwargs):
n_frames = 5

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.001)
sys = sys.tree_replace({'opt.timestep': 0.001})
sys = sys.replace(
actuator=sys.actuator.replace(gear=jp.array([20.0] * sys.act_size()))
)
Expand Down
2 changes: 1 addition & 1 deletion brax/envs/reacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(self, backend='generalized', **kwargs):
n_frames = 2

if backend in ['spring', 'positional']:
sys = sys.replace(dt=0.005)
sys = sys.tree_replace({'opt.timestep': 0.005})
sys = sys.replace(
actuator=sys.actuator.replace(gear=jp.array([25.0, 25.0]))
)
Expand Down
2 changes: 1 addition & 1 deletion brax/experimental/barkour/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
" path = epath.Path('mujoco_menagerie/google_barkour_v0/scene_mjx.xml')\n",
" sys = mjcf.load(path.as_posix())\n",
" self._dt = 0.02 # this environment is 50 fps\n",
" sys = sys.tree_replace({'opt.timestep': 0.004, 'dt': 0.004})\n",
" sys = sys.tree_replace({'opt.timestep': 0.004})\n",
"\n",
" # override menagerie params for smoother policy\n",
" sys = sys.replace(\n",
Expand Down
10 changes: 5 additions & 5 deletions brax/generalized/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@

def _integrate_q_axis(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array:
"""Integrates next q for revolute/prismatic joints."""
return q + qd * sys.dt
return q + qd * sys.opt.timestep


def _integrate_q_free(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array:
"""Integrates next q for free joints."""
rot, ang = q[3:7], qd[3:6]
ang_norm = jp.linalg.norm(ang) + 1e-8
axis = ang / ang_norm
angle = sys.dt * ang_norm
angle = sys.opt.timestep * ang_norm
qrot = math.quat_rot_axis(axis, angle)
rot = math.quat_mul(rot, qrot)
rot = rot / jp.linalg.norm(rot)
pos, vel = q[0:3], qd[0:3]
pos += vel * sys.dt
pos += vel * sys.opt.timestep

return jp.concatenate([pos, rot])

Expand All @@ -56,12 +56,12 @@ def integrate(sys: System, state: State) -> State:
# integrate joint damping implicitly to increase stability when we are not
# using approximate inverse
if sys.matrix_inv_iterations == 0:
mx = state.mass_mx + jp.diag(sys.dof.damping) * sys.dt
mx = state.mass_mx + jp.diag(sys.dof.damping) * sys.opt.timestep
mx_inv = jax.scipy.linalg.solve(mx, jp.eye(sys.qd_size()), assume_a='pos')
else:
mx_inv = state.mass_mx_inv
qdd = mx_inv @ (state.qf_smooth + state.qf_constraint)
qd = state.qd + qdd * sys.dt
qd = state.qd + qdd * sys.opt.timestep

def q_fn(typ, link, q, qd):
q = q.reshape(link.transform.pos.shape[0], -1)
Expand Down
7 changes: 6 additions & 1 deletion brax/generalized/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# pylint:disable=g-multiple-import
"""Physics pipeline for generalized coordinates engine."""

from typing import Optional
from brax import actuator
from brax import contact
from brax import kinematics
Expand All @@ -29,7 +30,11 @@


def init(
sys: System, q: jax.Array, qd: jax.Array, debug: bool = False
sys: System,
q: jax.Array,
qd: jax.Array,
unused_act: Optional[jax.Array] = None,
debug: bool = False,
) -> State:
"""Initializes physics state.
Expand Down
5 changes: 3 additions & 2 deletions brax/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def render(
format=fmt,
append_images=frames[1:],
save_all=True,
duration=sys.dt * 1000,
loop=0)
duration=sys.opt.timestep * 1000,
loop=0,
)

return f.getvalue()
3 changes: 2 additions & 1 deletion brax/io/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
# Fields that get json encoded.
_ENCODE_FIELDS = [
'contact',
'dt',
'opt',
'timestep',
'face',
'size',
'link_idx',
Expand Down
1 change: 0 additions & 1 deletion brax/io/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ def load_model(mj: mujoco.MjModel) -> System:
mjx_model = mjx.put_model(mj)

sys = System( # pytype: disable=wrong-arg-types # jax-ndarray
dt=mj.opt.timestep,
gravity=mj.opt.gravity,
viscosity=mj.opt.viscosity,
density=mj.opt.density,
Expand Down
11 changes: 10 additions & 1 deletion brax/mjx/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Physics pipeline for fully articulated dynamics and collisiion."""
# pylint:disable=g-multiple-import
# pylint:disable=g-importing-member
from typing import Optional
from brax.base import Contact, Motion, System, Transform
from brax.mjx.base import State
import jax
Expand All @@ -40,14 +41,19 @@ def _reformat_contact(sys: System, data: State) -> State:


def init(
sys: System, q: jax.Array, qd: jax.Array, unused_debug: bool = False
sys: System,
q: jax.Array,
qd: jax.Array,
act: Optional[jax.Array] = None,
unused_debug: bool = False,
) -> State:
"""Initializes physics data.
Args:
sys: a brax System
q: (q_size,) joint angle vector
qd: (qd_size,) joint velocity vector
act: actuator activations
unused_debug: ignored
Returns:
Expand All @@ -56,6 +62,9 @@ def init(

data = mjx.make_data(sys)
data = data.replace(qpos=q, qvel=qd)
if act is not None:
data = data.replace(act=act)

data = mjx.forward(sys, data)

q, qd = data.qpos, data.qvel
Expand Down
2 changes: 1 addition & 1 deletion brax/positional/collisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def impulse(contact, dlambda):
v_t = rel_vel - n * v_n
v_t_dir, v_t_norm = math.normalize(v_t)
dvel = -v_t_dir * jp.minimum(
contact.friction[0] * jp.abs(dlambda) / sys.dt, v_t_norm
contact.friction[0] * jp.abs(dlambda) / sys.opt.timestep, v_t_norm
)

angw_1 = jp.cross((contact.pos - x.pos[0]), v_t_dir)
Expand Down
17 changes: 10 additions & 7 deletions brax/positional/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def integrate_xdv(sys: System, xd: Motion, xdv: Motion) -> Motion:
xd: updated velocity
"""
damp = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
xd = jax.tree_map(lambda d, x: jp.exp(d * sys.dt) * x, damp, xd) + xdv
xd = (
jax.tree_map(lambda d, x: jp.exp(d * sys.opt.timestep) * x, damp, xd)
+ xdv
)

return xd

Expand All @@ -58,14 +61,14 @@ def integrate_xdd(
xd: updated velocity
"""

xd = xd + xdd * sys.dt
xd = xd + xdd * sys.opt.timestep
damp = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
xd = jax.tree_map(lambda d, x: jp.exp(d * sys.dt) * x, damp, xd)
xd = jax.tree_map(lambda d, x: jp.exp(d * sys.opt.timestep) * x, damp, xd)

@jax.vmap
def op(x, xd):
pos = x.pos + xd.vel * sys.dt
rot_at_ang_quat = math.ang_to_quat(xd.ang) * 0.5 * sys.dt
pos = x.pos + xd.vel * sys.opt.timestep
rot_at_ang_quat = math.ang_to_quat(xd.ang) * 0.5 * sys.opt.timestep
rot, _ = math.normalize(x.rot + math.quat_mul(rot_at_ang_quat, x.rot))
return Transform(pos=pos, rot=rot)

Expand All @@ -91,9 +94,9 @@ def project_xd(sys: System, x: Transform, x_prev: Transform) -> Motion:

@jax.vmap
def op(x, x_prev):
vel = (x.pos - x_prev.pos) / sys.dt
vel = (x.pos - x_prev.pos) / sys.opt.timestep
dq = math.relative_quat(x_prev.rot, x.rot)
ang = 2.0 * dq[1:] / sys.dt
ang = 2.0 * dq[1:] / sys.opt.timestep
scale = jp.where(dq[0] >= 0.0, 1.0, -1.0)
ang = scale * ang
return Motion(vel=vel, ang=ang)
Expand Down
7 changes: 6 additions & 1 deletion brax/positional/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Physics pipeline for fully articulated dynamics and collisiion."""
# pylint:disable=g-multiple-import
from typing import Optional
from brax import actuator
from brax import com
from brax import contact
Expand All @@ -29,7 +30,11 @@


def init(
sys: System, q: jax.Array, qd: jax.Array, debug: bool = False
sys: System,
q: jax.Array,
qd: jax.Array,
unused_act: Optional[jax.Array] = None,
debug: bool = False,
) -> State:
"""Initializes physics state.
Expand Down
Loading

0 comments on commit 0d513cd

Please sign in to comment.