Skip to content

Commit

Permalink
feat: rename and started rne post constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Jul 4, 2024
1 parent 723deee commit c4a977d
Show file tree
Hide file tree
Showing 6 changed files with 474 additions and 96 deletions.
388 changes: 388 additions & 0 deletions examples/match_rne.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,388 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-07-04 10:15:46.022376: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.typing as jpt\n",
"import mujoco as mj\n",
"import mujoco.mjx as mjx\n",
"from mujoco.mjx._src import scan\n",
"from mujoco.mjx._src.math import transform_motion\n",
"from mujoco.mjx._src.types import DisableBit\n",
"from robot_descriptions.z1_mj_description import MJCF_PATH\n",
"\n",
"jnp.set_printoptions(precision=5, suppress=True, linewidth=500)\n",
"\n",
"key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/lvjonok/github.com/lvjonok/mujoco-sysid/venv/lib/python3.10/site-packages/jax/_src/interpreters/xla.py:155: RuntimeWarning: overflow encountered in cast\n",
" return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))\n"
]
}
],
"source": [
"mjmodel = mj.MjModel.from_xml_path(MJCF_PATH)\n",
"mjdata = mj.MjData(mjmodel)\n",
"\n",
"# alter the model so it becomes mjx compatible\n",
"mjmodel.dof_frictionloss = 0\n",
"mjmodel.opt.integrator = 1\n",
"\n",
"mjxmodel = mjx.put_model(mjmodel)\n",
"mjxdata = mjx.put_data(mjmodel, mjdata)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"q, v, dv = jax.random.normal(key, (3, mjmodel.nq))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"mj_inverse + mj_rnePostConstraint"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mjdata.qpos = q\n",
"mjdata.qvel = v\n",
"mjdata.qacc = dv\n",
"\n",
"mj.mj_inverse(mjmodel, mjdata)\n",
"mjdata.cacc"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0. , 0. , 0. , -0. , -0. , 9.81 ],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81 ],\n",
" [ 0. , 0. , -0.24564, 0.04451, 0.03984, 9.81 ],\n",
" [ 0.26177, 0.17445, -0.24564, 0.05323, 0.02676, 9.88573],\n",
" [ 2.22556, -0.4742 , -0.24564, -0.02487, -0.14671, 9.75662],\n",
" [ 3.04146, -1.35822, -0.24564, -0.14381, 0.16344, 10.17372],\n",
" [ 3.97377, -0.57563, 3.77315, 0.23265, 0.7747 , 10.4696 ],\n",
" [ 4.84202, 0.20971, 3.80627, 0.14318, 0.83105, 10.06915]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mj.mj_rnePostConstraint(mjmodel, mjdata)\n",
"\n",
"mjdata.cacc"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[ 0. , 0. , 1. , -0.1812 , -0.1622 , 0. ],\n",
" [ 0.7655 , 0.64344, 0. , 0.03215, -0.03825, 0.24308],\n",
" [ 0.7655 , 0.64344, 0. , 0.06266, -0.07455, -0.1037 ],\n",
" [ 0.7655 , 0.64344, 0. , -0.0581 , 0.06912, -0.22839],\n",
" [-0.63989, 0.76128, -0.10492, -0.14118, -0.11763, 0.00753],\n",
" [-0.75908, -0.60487, 0.24068, 0.14133, -0.08628, 0.22892]]),\n",
" array([[ 0. , 0. , 0. , 0. , 0. , 0. ],\n",
" [-0.61642, 0.73335, 0. , 0.03664, 0.0308 , 0.00726],\n",
" [-0.61642, 0.73335, 0. , 0.07954, 0.05036, 0.00898],\n",
" [-0.61642, 0.73335, 0. , 0.09224, -0.24418, -0.34267],\n",
" [-0.51814, -0.86424, -3.1108 , -0.24798, -0.43955, -0.23407],\n",
" [-0.05795, -0.44941, -1.31224, -0.30063, 0.1761 , 0.61546]]))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mjdata.cdof, mjdata.cdof_dot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"mjx_rne + mjx_rnePostConstraint"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"mjxdata = mjxdata.replace(qpos=q, qvel=v, qacc=dv)\n",
"\n",
"\n",
"def mjx_invPosition(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n",
" d = mjx.kinematics(m, d)\n",
" d = mjx.com_pos(m, d)\n",
" d = mjx.camlight(m, d)\n",
" # flex is missing\n",
" # tendon is missing\n",
"\n",
" d = mjx.crb(m, d)\n",
" d = mjx.factor_m(m, d)\n",
" d = mjx.collision(m, d)\n",
" d = mjx.make_constraint(m, d)\n",
" d = mjx.transmission(m, d)\n",
"\n",
" return d\n",
"\n",
"\n",
"def mjx_invVelocity(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n",
" d = mjx.fwd_velocity(m, d)\n",
"\n",
" return d\n",
"\n",
"\n",
"def mjx_invConstraint(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n",
" # return data if there are no constraints\n",
" if d.nefc == 0:\n",
" return d\n",
"\n",
" # jar = Jac*qacc - aref\n",
" jar = d.efc_J @ d.qacc - d.efc_aref"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mjx_inverse(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n",
" d = mjx_invPosition(m, d)\n",
" d = mjx_invVelocity(m, d)\n",
"\n",
" # acceleration dependent\n",
" mjx\n",
"\n",
" d = mjx.rne(m, d)\n",
"\n",
" return d\n",
"\n",
"\n",
"mjxdata = mjx.rne(mjxmodel, mjxdata)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [0. , 0. , 0. , 0. , 0. , 9.81]], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def com_acc(m: mjx.Model, d: mjx.Data) -> jpt.ArrayLike:\n",
" # forward scan over tree: accumulate link center of mass acceleration\n",
" def cacc_fn(cacc, cdof_dot, qvel):\n",
" if cacc is None:\n",
" if m.opt.disableflags & DisableBit.GRAVITY:\n",
" cacc = jnp.zeros((6,))\n",
" else:\n",
" cacc = jnp.concatenate((jnp.zeros((3,)), -m.opt.gravity))\n",
"\n",
" cacc += jnp.sum(jax.vmap(jnp.multiply)(cdof_dot, qvel), axis=0)\n",
"\n",
" return cacc\n",
"\n",
" return scan.body_tree(m, cacc_fn, \"vv\", \"b\", d.cdof_dot, d.qvel)\n",
"\n",
"\n",
"com_acc(mjxmodel, mjxdata)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.]], dtype=float32),\n",
" Array([[0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0.]], dtype=float32))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mjxdata.cdof, mjxdata.cdof_dot"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[ 0. , 0. , 0. , -0. , -0. , 9.81],\n",
" [ 0. , 0. , 0. , -0. , -0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n",
" [ 0. , 0. , 0. , 0. , 0. , 9.81]], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def mjx_rnePostConstraint(m: mjx.Model, d: mjx.Data):\n",
" nbody = m.nbody\n",
"\n",
" all_cacc = jnp.zeros((nbody, 6))\n",
"\n",
" # clear cacc, set world acceleration to -gravity\n",
" if not m.opt.disableflags & DisableBit.GRAVITY:\n",
" all_cacc = all_cacc.at[0, 3:].set(-m.opt.gravity)\n",
"\n",
" # FIXME: assumption that xfrc_applied is zero\n",
" # FIXME: assumption that contacts are zero\n",
" # FIXME: assumption that connect and weld constraints are zero\n",
"\n",
" # forward pass over bodies: compute acc\n",
" for j in range(nbody):\n",
" bda = m.body_dofadr[j]\n",
"\n",
" # cacc = cacc_parent + cdofdot * qvel + cdof * qacc\n",
" cacc_j = all_cacc[m.body_parentid[j]] + d.cdof_dot[bda] * d.qvel[bda] + d.cdof[bda] * d.qacc[bda]\n",
" all_cacc = all_cacc.at[j].set(cacc_j)\n",
"\n",
" return all_cacc\n",
"\n",
"\n",
"mjx_rnePostConstraint(mjxmodel, mjxdata)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit c4a977d

Please sign in to comment.