diff --git a/dial_mpc/deploy/dial_plan.py b/dial_mpc/deploy/dial_plan.py index 3e0a0c8..8e9e3db 100644 --- a/dial_mpc/deploy/dial_plan.py +++ b/dial_mpc/deploy/dial_plan.py @@ -8,6 +8,7 @@ import numpy as np from tqdm import tqdm import art +import emoji import functools from functools import partial @@ -26,7 +27,12 @@ import dial_mpc.envs as dial_envs from dial_mpc.core.dial_core import DialConfig, MBDPI from dial_mpc.envs.base_env import BaseEnv, BaseEnvConfig -from dial_mpc.utils.io_utils import load_dataclass_from_dict +from dial_mpc.utils.io_utils import ( + load_dataclass_from_dict, + get_model_path, + get_example_path, +) +from dial_mpc.examples import deploy_examples # Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs xla_flags = os.environ.get("XLA_FLAGS", "") @@ -224,11 +230,39 @@ def reverse_scan(rng_Y0_state, factor): def main(args=None): art.tprint("LeCAR @ CMU\nDIAL-MPC\nPLANNER", font="big", chr_ignore=True) parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="config.yaml") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( + "--config", type=str, default="config.yaml", help="Path to config file" + ) + group.add_argument( + "--example", + type=str, + default=None, + help="Example to run", + ) + group.add_argument( + "--list-examples", + action="store_true", + help="List available examples", + ) args = parser.parse_args(args) - print("Creating environment") - config_dict = yaml.safe_load(open(args.config)) + if args.list_examples: + print("Available examples:") + for example in deploy_examples: + print(f" - {example}") + return + if args.example is not None: + if args.example not in deploy_examples: + print(f"Example {args.example} not found.") + return + config_dict = yaml.safe_load( + open(get_example_path(args.example + ".yaml"), "r") + ) + else: + config_dict = yaml.safe_load(open(args.config, "r")) + + print(emoji.emojize(":rocket:") + "Creating environment") dial_config = load_dataclass_from_dict(DialConfig, config_dict) env_config_type = dial_envs.get_config(dial_config.env_name) env_config = load_dataclass_from_dict( diff --git a/dial_mpc/deploy/dial_sim.py b/dial_mpc/deploy/dial_sim.py index e7eb7fa..bf405b1 100644 --- a/dial_mpc/deploy/dial_sim.py +++ b/dial_mpc/deploy/dial_sim.py @@ -15,14 +15,20 @@ from dial_mpc.envs.base_env import BaseEnvConfig from dial_mpc.core.dial_core import DialConfig -from dial_mpc.utils.io_utils import load_dataclass_from_dict +from dial_mpc.utils.io_utils import ( + load_dataclass_from_dict, + get_model_path, + get_example_path, +) +from dial_mpc.examples import deploy_examples plt.style.use(["science"]) @dataclass class DialSimConfig: - sim_model_path: str + robot_name: str + scene_name: str sim_leg_control: str plot: bool record: bool @@ -52,7 +58,9 @@ def __init__( self.kp = env_config.kp self.kd = env_config.kd self.leg_control = sim_config.sim_leg_control - self.mj_model = mujoco.MjModel.from_xml_path(sim_config.sim_model_path) + self.mj_model = mujoco.MjModel.from_xml_path( + get_model_path(sim_config.robot_name, sim_config.scene_name).as_posix() + ) self.mj_model.opt.timestep = self.sim_dt self.mj_data = mujoco.MjData(self.mj_model) self.q_history = np.zeros((self.n_acts, self.mj_model.nu)) @@ -289,11 +297,37 @@ def close(self): def main(args=None): art.tprint("LeCAR @ CMU\nDIAL-MPC\nSIMULATOR", font="big", chr_ignore=True) parser = argparse.ArgumentParser() - parser.add_argument( + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( "--config", type=str, default="config.yaml", help="Path to config file" ) + group.add_argument( + "--example", + type=str, + default=None, + help="Example to run", + ) + group.add_argument( + "--list-examples", + action="store_true", + help="List available examples", + ) args = parser.parse_args(args) - config_dict = yaml.safe_load(open(args.config, "r")) + + if args.list_examples: + print("Available examples:") + for example in deploy_examples: + print(f" - {example}") + return + if args.example is not None: + if args.example not in deploy_examples: + print(f"Example {args.example} not found.") + return + config_dict = yaml.safe_load( + open(get_example_path(args.example + ".yaml"), "r") + ) + else: + config_dict = yaml.safe_load(open(args.config, "r")) sim_config = load_dataclass_from_dict(DialSimConfig, config_dict) env_config = load_dataclass_from_dict( BaseEnvConfig, config_dict, convert_list_to_array=True diff --git a/dial_mpc/examples/__init__.py b/dial_mpc/examples/__init__.py index 96b6978..70c7dee 100644 --- a/dial_mpc/examples/__init__.py +++ b/dial_mpc/examples/__init__.py @@ -5,3 +5,8 @@ "unitree_go2_seq_jump", "unitree_go2_crate_climb", ] + +deploy_examples = [ + "unitree_go2_trot_deploy", + "unitree_go2_seq_jump_deploy", +] diff --git a/dial_mpc/examples/unitree_go2_seq_jump.yaml b/dial_mpc/examples/unitree_go2_seq_jump.yaml index 4c8c917..3a24020 100644 --- a/dial_mpc/examples/unitree_go2_seq_jump.yaml +++ b/dial_mpc/examples/unitree_go2_seq_jump.yaml @@ -22,20 +22,12 @@ action_scale: 1.0 # Go2 jump_dt: 1.0 -pose_target_sequence: [ - [0.0, 0.0, 0.27], - [0.4, 0.0, 0.27], - [0.8, 0.0, 0.27], - [1.2, 0.0, 0.27], - [1.6, 0.0, 0.27], -] +pose_target_sequence: + [ + [0.0, 0.0, 0.27], + [0.4, 0.0, 0.27], + [0.8, 0.0, 0.27], + [1.2, 0.0, 0.27], + [1.6, 0.0, 0.27], + ] yaw_target_sequence: [0.0, 0.0, 0.0, 0.0, 0.0] - -# Sim -sim_model_path: /home/haoru/research/dial/dial-mpc/dial_mpc/models/unitree_go2/scene.xml -sim_leg_control: torque -plot: false -record: false -real_time_factor: 1.0 -sim_dt: 0.005 -sync_mode: false \ No newline at end of file diff --git a/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml b/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml new file mode 100644 index 0000000..f6fda2f --- /dev/null +++ b/dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml @@ -0,0 +1,43 @@ +# DIAL-MPC +seed: 0 +output_dir: unitree_go2_seq_jump +n_steps: 400 + +env_name: unitree_go2_seq_jump +Nsample: 2048 +Hsample: 20 +Hnode: 5 +Ndiffuse: 1 +Ndiffuse_init: 10 +temp_sample: 0.1 +horizon_diffuse_factor: 0.9 +traj_diffuse_factor: 0.5 +update_method: mppi + +# Base environment +dt: 0.02 +timestep: 0.02 +leg_control: torque +action_scale: 1.0 + +# Go2 +jump_dt: 1.0 +pose_target_sequence: + [ + [0.0, 0.0, 0.27], + [0.4, 0.0, 0.27], + [0.8, 0.0, 0.27], + [1.2, 0.0, 0.27], + [1.6, 0.0, 0.27], + ] +yaw_target_sequence: [0.0, 0.0, 0.0, 0.0, 0.0] + +# Sim +robot_name: "unitree_go2" +scene_name: "scene.xml" +sim_leg_control: torque +plot: false +record: false +real_time_factor: 1.0 +sim_dt: 0.005 +sync_mode: false diff --git a/dial_mpc/examples/unitree_go2_trot.yaml b/dial_mpc/examples/unitree_go2_trot.yaml index 24a522b..e4ff16f 100644 --- a/dial_mpc/examples/unitree_go2_trot.yaml +++ b/dial_mpc/examples/unitree_go2_trot.yaml @@ -26,12 +26,3 @@ default_vy: 0.0 default_vyaw: 0.0 ramp_up_time: 1.0 gait: trot - -# Sim -sim_model_path: /home/haoru/research/dial/dial-mpc/dial_mpc/models/unitree_go2/scene.xml -sim_leg_control: torque -plot: false -record: false -real_time_factor: 1.0 -sim_dt: 0.005 -sync_mode: false \ No newline at end of file diff --git a/dial_mpc/examples/unitree_go2_trot_deploy.yaml b/dial_mpc/examples/unitree_go2_trot_deploy.yaml new file mode 100644 index 0000000..0a11b4e --- /dev/null +++ b/dial_mpc/examples/unitree_go2_trot_deploy.yaml @@ -0,0 +1,38 @@ +# DIAL-MPC +seed: 0 +output_dir: unitree_go2_trot +n_steps: 400 + +env_name: unitree_go2_walk +Nsample: 2048 +Hsample: 16 +Hnode: 4 +Ndiffuse: 1 +Ndiffuse_init: 10 +temp_sample: 0.05 +horizon_diffuse_factor: 0.9 +traj_diffuse_factor: 0.5 +update_method: mppi + +# Base environment +dt: 0.02 +timestep: 0.02 +leg_control: torque +action_scale: 1.0 + +# Go2 +default_vx: 1.0 +default_vy: 0.0 +default_vyaw: 0.0 +ramp_up_time: 1.0 +gait: trot + +# Sim +robot_name: "unitree_go2" +scene_name: "scene.xml" +sim_leg_control: torque +plot: false +record: false +real_time_factor: 1.0 +sim_dt: 0.005 +sync_mode: false