Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/match dynamics #6

Merged
merged 15 commits into from
Jul 4, 2024
3 changes: 2 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ jobs:
pip install ruff==0.1.11
# Update output format to enable automatic inline annotations.
- name: Run Ruff
run: ruff check --output-format=github .
# run: ruff check --output-format=github .
run: ruff format --diff .

tests:
runs-on: ubuntu-latest
Expand Down
1 change: 1 addition & 0 deletions data/ltv_lqr_traj.json

Large diffs are not rendered by default.

212 changes: 212 additions & 0 deletions examples/dynamics_a1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mujoco\n",
"import numpy as np\n",
"import pinocchio as pin\n",
"from robot_descriptions.h1_description import URDF_PATH\n",
"from robot_descriptions.h1_mj_description import MJCF_PATH\n",
"\n",
"from mujoco_sysid import parameters\n",
"from mujoco_sysid.utils import muj2pin\n",
"\n",
"mjmodel = mujoco.MjModel.from_xml_path(MJCF_PATH)\n",
"mjdata = mujoco.MjData(mjmodel)\n",
"\n",
"pinmodel = pin.buildModelFromUrdf(URDF_PATH, pin.JointModelFreeFlyer())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup the model to have the same dynamic parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 5.39 -0.00108 0.00022 -0.24374 0.0556 0.00009 0.01927 -0.00007 0.00001 0.04902]\n",
"[ 5.39 -0.00108 0.00022 -0.24374 0.0556 0.00009 0.01927 -0.00007 0.00001 0.04902]\n",
"[ 2.244 -0.11047 0.00022 0.01616 0.00269 0.00002 0.0086 0.00028 0. 0.00773]\n",
"[ 2.244 -0.11047 0.00022 0.01616 0.00269 0.00002 0.0086 0.00028 0. 0.00773]\n",
"[ 2.232 -0.01295 -0.00712 -0.0002 0.00208 -0.00001 0.00232 0. -0.00001 0.00253]\n",
"[ 2.232 -0.01295 -0.00712 -0.0002 0.00208 -0.00001 0.00232 0. -0.00001 0.00253]\n",
"[ 4.152 0.03097 -0.09741 -0.34017 0.11277 0.00006 0.10968 0.00661 -0.00078 0.00852]\n",
"[ 4.152 0.03097 -0.09741 -0.34017 0.11277 0.00006 0.10968 0.00661 -0.00078 0.00852]\n",
"[ 1.721 -0.00234 -0.00881 -0.23819 0.04522 -0.00008 0.04548 0.00076 -0.00099 0.00211]\n",
"[ 1.721 -0.00234 -0.00881 -0.23819 0.04522 -0.00008 0.04548 0.00076 -0.00099 0.00211]\n",
"[ 0.446 0.02998 0.00007 -0.02006 0.00116 -0. 0.00513 0.00163 0. 0.00416]\n",
"[ 0.474 0.02018 -0. -0.02117 0.00111 0. 0.00471 0.00104 -0. 0.00366]\n",
"changed to [ 0.474 0.02018 -0. -0.02117 0.00111 0. 0.00471 0.00104 -0. 0.00366]\n",
"[ 2.244 -0.11047 -0.00022 0.01616 0.00269 -0.00002 0.0086 0.00028 -0. 0.00773]\n",
"[ 2.244 -0.11047 -0.00022 0.01616 0.00269 -0.00002 0.0086 0.00028 -0. 0.00773]\n",
"[ 2.232 -0.01295 0.00712 -0.0002 0.00208 0.00001 0.00232 0. 0.00001 0.00253]\n",
"[ 2.232 -0.01295 0.00712 -0.0002 0.00208 0.00001 0.00232 0. 0.00001 0.00253]\n",
"[ 4.152 0.03097 0.09741 -0.34017 0.11277 -0.00006 0.10968 0.00661 0.00078 0.00852]\n",
"[ 4.152 0.03097 0.09741 -0.34017 0.11277 -0.00006 0.10968 0.00661 0.00078 0.00852]\n",
"[ 1.721 -0.00234 0.00881 -0.23819 0.04522 0.00008 0.04548 0.00076 0.00099 0.00211]\n",
"[ 1.721 -0.00234 0.00881 -0.23819 0.04522 0.00008 0.04548 0.00076 0.00099 0.00211]\n",
"[ 0.446 0.02998 -0.00007 -0.02006 0.00116 0. 0.00513 0.00163 -0. 0.00416]\n",
"[ 0.474 0.02018 0. -0.02117 0.00111 -0. 0.00471 0.00104 0. 0.00366]\n",
"changed to [ 0.474 0.02018 0. -0.02117 0.00111 -0. 0.00471 0.00104 0. 0.00366]\n",
"[17.789 0.0087 0.04976 3.6439 1.23386 -0.00056 1.15605 0.00025 -0.01094 0.12799]\n",
"[17.789 0.0087 0.04976 3.6439 1.23386 -0.00056 1.15605 0.00025 -0.01094 0.12799]\n",
"[ 1.033 0.00521 0.05543 -0.01623 0.00453 -0.0003 0.00115 0.00009 0.00091 0.00397]\n",
"[ 1.033 0.00521 0.05543 -0.01623 0.00453 -0.0003 0.00115 0.00009 0.00091 0.00397]\n",
"[ 0.793 0.00054 0.00091 -0.0746 0.00859 0. 0.00872 -0.00002 0.00002 0.00102]\n",
"[ 0.793 0.00054 0.00091 -0.0746 0.00859 0. 0.00872 -0.00002 0.00002 0.00102]\n",
"[ 0.839 0.01145 0.00232 -0.13647 0.02587 -0.00004 0.02643 0.00221 0.00045 0.00083]\n",
"[ 0.839 0.01145 0.00232 -0.13647 0.02587 -0.00004 0.02643 0.00221 0.00045 0.00083]\n",
"[ 0.669 0.10642 -0.0001 -0.01055 0.00059 -0.00002 0.0231 0.00197 0. 0.02293]\n",
"[ 0.723 0.1192 0.00009 -0.01138 0.0006 -0.00005 0.02584 0.00217 0.00001 0.02565]\n",
"changed to [ 0.723 0.1192 0.00009 -0.01138 0.0006 -0.00005 0.02584 0.00217 0.00001 0.02565]\n",
"[ 1.033 0.00521 -0.05543 -0.01623 0.00453 0.0003 0.00115 0.00009 -0.00091 0.00397]\n",
"[ 1.033 0.00521 -0.05543 -0.01623 0.00453 0.0003 0.00115 0.00009 -0.00091 0.00397]\n",
"[ 0.793 0.00054 -0.00091 -0.0746 0.00859 -0. 0.00872 -0.00002 -0.00002 0.00102]\n",
"[ 0.793 0.00054 -0.00091 -0.0746 0.00859 -0. 0.00872 -0.00002 -0.00002 0.00102]\n",
"[ 0.839 0.01145 -0.00232 -0.13647 0.02587 0.00004 0.02643 0.00221 -0.00045 0.00083]\n",
"[ 0.839 0.01145 -0.00232 -0.13647 0.02587 0.00004 0.02643 0.00221 -0.00045 0.00083]\n",
"[ 0.669 0.10642 0.0001 -0.01055 0.00059 0.00002 0.0231 0.00197 -0. 0.02293]\n",
"[ 0.723 0.1192 -0.00009 -0.01138 0.0006 0.00005 0.02584 0.00217 -0.00001 0.02565]\n",
"changed to [ 0.723 0.1192 -0.00009 -0.01138 0.0006 0.00005 0.02584 0.00217 -0.00001 0.02565]\n"
]
}
],
"source": [
"with np.printoptions(precision=5, suppress=True, linewidth=400):\n",
" for body_id in mjmodel.jnt_bodyid:\n",
" mjparams = parameters.get_dynamic_parameters(mjmodel, body_id)\n",
" pinparams = pinmodel.inertias[int(body_id)].toDynamicParameters()\n",
"\n",
" print(mjparams)\n",
" print(pinparams)\n",
" if not np.allclose(mjparams, pinparams, atol=1e-6):\n",
" parameters.set_dynamic_parameters(mjmodel, body_id, pinparams)\n",
"\n",
" mjparams = parameters.get_dynamic_parameters(mjmodel, body_id)\n",
" print(\"changed to\", mjparams)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"q = np.random.randn(mjmodel.nq)\n",
"v = np.random.randn(mjmodel.nv)\n",
"dv = np.random.randn(mjmodel.nv)\n",
"tau = np.random.randn(mjmodel.nu)\n",
"\n",
"mjdata.qpos[:] = q\n",
"mjdata.qvel[:] = v\n",
"mjdata.qacc[:] = dv\n",
"mjdata.ctrl[:] = tau\n",
"\n",
"mujoco.mj_step(mjmodel, mjdata)\n",
"mujoco.mj_inverse(mjmodel, mjdata)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# setup same data in pinocchio\n",
"pinq, pinv = muj2pin(mjdata.qpos, mjdata.qvel)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ -26.04972 -2.26343 951.46826 -99.99949 0. 0. 0. 0. 0. 0. ]\n",
" [ 5.94784 -950.59184 -3.00354 68.78174 0. 0. 0. 0. 0. 0. ]\n",
" [ 28.77278 102.78788 -67.48682 -1.14711 0. 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 28.77278 -5.94784 -68.13428 -102.78788 -0.64746 -950.59184 -1.85643 0.64746]\n",
" [ 0. -28.77278 0. -26.04972 1.39419 -67.48682 -101.39368 1.11632 -951.46826 -1.39419]\n",
" [ 0. 5.94784 26.04972 0. -0.43821 0.74011 0.43821 -68.78174 -99.99949 -951.03005]]\n"
]
}
],
"source": [
"from mujoco_sysid.regressors import joint_body_regressor\n",
"\n",
"with np.printoptions(precision=5, suppress=True, linewidth=400):\n",
" print(joint_body_regressor(mjmodel, mjdata, 2))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ -27.22494 -2.26343 951.46826 -99.99949 0. 0. 0. 0. 0. 0. ]\n",
" [ 19.5201 -950.59184 -3.00354 68.78174 0. 0. 0. 0. 0. 0. ]\n",
" [ 28.80256 102.78788 -67.48682 -1.14711 0. 0. 0. 0. 0. 0. ]\n",
" [ 0. 0. 28.80256 -19.5201 -68.13428 -102.78788 -0.64746 -950.59184 -1.85643 0.64746]\n",
" [ 0. -28.80256 0. -27.22494 1.39419 -67.48682 -101.39368 1.11632 -951.46826 -1.39419]\n",
" [ 0. 19.5201 27.22494 0. -0.43821 0.74011 0.43821 -68.78174 -99.99949 -951.03005]]\n"
]
}
],
"source": [
"pindata = pin.Data(pinmodel)\n",
"tau = pin.rnea(pinmodel, pindata, pinq, pinv, mjdata.qacc.copy())\n",
"\n",
"with np.printoptions(precision=5, suppress=True, linewidth=400):\n",
" print(pin.jointBodyRegressor(pinmodel, pindata, 2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading
Loading