-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: rename and started rne post constraint
- Loading branch information
Showing
6 changed files
with
474 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,388 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"2024-07-04 10:15:46.022376: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import jax.typing as jpt\n", | ||
"import mujoco as mj\n", | ||
"import mujoco.mjx as mjx\n", | ||
"from mujoco.mjx._src import scan\n", | ||
"from mujoco.mjx._src.math import transform_motion\n", | ||
"from mujoco.mjx._src.types import DisableBit\n", | ||
"from robot_descriptions.z1_mj_description import MJCF_PATH\n", | ||
"\n", | ||
"jnp.set_printoptions(precision=5, suppress=True, linewidth=500)\n", | ||
"\n", | ||
"key = jax.random.PRNGKey(0)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/lvjonok/github.com/lvjonok/mujoco-sysid/venv/lib/python3.10/site-packages/jax/_src/interpreters/xla.py:155: RuntimeWarning: overflow encountered in cast\n", | ||
" return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"mjmodel = mj.MjModel.from_xml_path(MJCF_PATH)\n", | ||
"mjdata = mj.MjData(mjmodel)\n", | ||
"\n", | ||
"# alter the model so it becomes mjx compatible\n", | ||
"mjmodel.dof_frictionloss = 0\n", | ||
"mjmodel.opt.integrator = 1\n", | ||
"\n", | ||
"mjxmodel = mjx.put_model(mjmodel)\n", | ||
"mjxdata = mjx.put_data(mjmodel, mjdata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"q, v, dv = jax.random.normal(key, (3, mjmodel.nq))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"mj_inverse + mj_rnePostConstraint" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([[0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.]])" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"mjdata.qpos = q\n", | ||
"mjdata.qvel = v\n", | ||
"mjdata.qacc = dv\n", | ||
"\n", | ||
"mj.mj_inverse(mjmodel, mjdata)\n", | ||
"mjdata.cacc" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([[ 0. , 0. , 0. , -0. , -0. , 9.81 ],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81 ],\n", | ||
" [ 0. , 0. , -0.24564, 0.04451, 0.03984, 9.81 ],\n", | ||
" [ 0.26177, 0.17445, -0.24564, 0.05323, 0.02676, 9.88573],\n", | ||
" [ 2.22556, -0.4742 , -0.24564, -0.02487, -0.14671, 9.75662],\n", | ||
" [ 3.04146, -1.35822, -0.24564, -0.14381, 0.16344, 10.17372],\n", | ||
" [ 3.97377, -0.57563, 3.77315, 0.23265, 0.7747 , 10.4696 ],\n", | ||
" [ 4.84202, 0.20971, 3.80627, 0.14318, 0.83105, 10.06915]])" | ||
] | ||
}, | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"mj.mj_rnePostConstraint(mjmodel, mjdata)\n", | ||
"\n", | ||
"mjdata.cacc" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(array([[ 0. , 0. , 1. , -0.1812 , -0.1622 , 0. ],\n", | ||
" [ 0.7655 , 0.64344, 0. , 0.03215, -0.03825, 0.24308],\n", | ||
" [ 0.7655 , 0.64344, 0. , 0.06266, -0.07455, -0.1037 ],\n", | ||
" [ 0.7655 , 0.64344, 0. , -0.0581 , 0.06912, -0.22839],\n", | ||
" [-0.63989, 0.76128, -0.10492, -0.14118, -0.11763, 0.00753],\n", | ||
" [-0.75908, -0.60487, 0.24068, 0.14133, -0.08628, 0.22892]]),\n", | ||
" array([[ 0. , 0. , 0. , 0. , 0. , 0. ],\n", | ||
" [-0.61642, 0.73335, 0. , 0.03664, 0.0308 , 0.00726],\n", | ||
" [-0.61642, 0.73335, 0. , 0.07954, 0.05036, 0.00898],\n", | ||
" [-0.61642, 0.73335, 0. , 0.09224, -0.24418, -0.34267],\n", | ||
" [-0.51814, -0.86424, -3.1108 , -0.24798, -0.43955, -0.23407],\n", | ||
" [-0.05795, -0.44941, -1.31224, -0.30063, 0.1761 , 0.61546]]))" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"mjdata.cdof, mjdata.cdof_dot" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"mjx_rne + mjx_rnePostConstraint" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"mjxdata = mjxdata.replace(qpos=q, qvel=v, qacc=dv)\n", | ||
"\n", | ||
"\n", | ||
"def mjx_invPosition(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n", | ||
" d = mjx.kinematics(m, d)\n", | ||
" d = mjx.com_pos(m, d)\n", | ||
" d = mjx.camlight(m, d)\n", | ||
" # flex is missing\n", | ||
" # tendon is missing\n", | ||
"\n", | ||
" d = mjx.crb(m, d)\n", | ||
" d = mjx.factor_m(m, d)\n", | ||
" d = mjx.collision(m, d)\n", | ||
" d = mjx.make_constraint(m, d)\n", | ||
" d = mjx.transmission(m, d)\n", | ||
"\n", | ||
" return d\n", | ||
"\n", | ||
"\n", | ||
"def mjx_invVelocity(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n", | ||
" d = mjx.fwd_velocity(m, d)\n", | ||
"\n", | ||
" return d\n", | ||
"\n", | ||
"\n", | ||
"def mjx_invConstraint(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n", | ||
" # return data if there are no constraints\n", | ||
" if d.nefc == 0:\n", | ||
" return d\n", | ||
"\n", | ||
" # jar = Jac*qacc - aref\n", | ||
" jar = d.efc_J @ d.qacc - d.efc_aref" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def mjx_inverse(m: mjx.Model, d: mjx.Data) -> mjx.Data:\n", | ||
" d = mjx_invPosition(m, d)\n", | ||
" d = mjx_invVelocity(m, d)\n", | ||
"\n", | ||
" # acceleration dependent\n", | ||
" mjx\n", | ||
"\n", | ||
" d = mjx.rne(m, d)\n", | ||
"\n", | ||
" return d\n", | ||
"\n", | ||
"\n", | ||
"mjxdata = mjx.rne(mjxmodel, mjxdata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Array([[0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [0. , 0. , 0. , 0. , 0. , 9.81]], dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"def com_acc(m: mjx.Model, d: mjx.Data) -> jpt.ArrayLike:\n", | ||
" # forward scan over tree: accumulate link center of mass acceleration\n", | ||
" def cacc_fn(cacc, cdof_dot, qvel):\n", | ||
" if cacc is None:\n", | ||
" if m.opt.disableflags & DisableBit.GRAVITY:\n", | ||
" cacc = jnp.zeros((6,))\n", | ||
" else:\n", | ||
" cacc = jnp.concatenate((jnp.zeros((3,)), -m.opt.gravity))\n", | ||
"\n", | ||
" cacc += jnp.sum(jax.vmap(jnp.multiply)(cdof_dot, qvel), axis=0)\n", | ||
"\n", | ||
" return cacc\n", | ||
"\n", | ||
" return scan.body_tree(m, cacc_fn, \"vv\", \"b\", d.cdof_dot, d.qvel)\n", | ||
"\n", | ||
"\n", | ||
"com_acc(mjxmodel, mjxdata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"(Array([[0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.]], dtype=float32),\n", | ||
" Array([[0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.],\n", | ||
" [0., 0., 0., 0., 0., 0.]], dtype=float32))" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"mjxdata.cdof, mjxdata.cdof_dot" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Array([[ 0. , 0. , 0. , -0. , -0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , -0. , -0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81],\n", | ||
" [ 0. , 0. , 0. , 0. , 0. , 9.81]], dtype=float32)" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"def mjx_rnePostConstraint(m: mjx.Model, d: mjx.Data):\n", | ||
" nbody = m.nbody\n", | ||
"\n", | ||
" all_cacc = jnp.zeros((nbody, 6))\n", | ||
"\n", | ||
" # clear cacc, set world acceleration to -gravity\n", | ||
" if not m.opt.disableflags & DisableBit.GRAVITY:\n", | ||
" all_cacc = all_cacc.at[0, 3:].set(-m.opt.gravity)\n", | ||
"\n", | ||
" # FIXME: assumption that xfrc_applied is zero\n", | ||
" # FIXME: assumption that contacts are zero\n", | ||
" # FIXME: assumption that connect and weld constraints are zero\n", | ||
"\n", | ||
" # forward pass over bodies: compute acc\n", | ||
" for j in range(nbody):\n", | ||
" bda = m.body_dofadr[j]\n", | ||
"\n", | ||
" # cacc = cacc_parent + cdofdot * qvel + cdof * qacc\n", | ||
" cacc_j = all_cacc[m.body_parentid[j]] + d.cdof_dot[bda] * d.qvel[bda] + d.cdof[bda] * d.qacc[bda]\n", | ||
" all_cacc = all_cacc.at[j].set(cacc_j)\n", | ||
"\n", | ||
" return all_cacc\n", | ||
"\n", | ||
"\n", | ||
"mjx_rnePostConstraint(mjxmodel, mjxdata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.