Skip to content

Commit

Permalink
add deploy examples
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoruXue committed Sep 25, 2024
1 parent 87d2ec2 commit 28bfa48
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 34 deletions.
42 changes: 38 additions & 4 deletions dial_mpc/deploy/dial_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from tqdm import tqdm
import art
import emoji

import functools
from functools import partial
Expand All @@ -26,7 +27,12 @@
import dial_mpc.envs as dial_envs
from dial_mpc.core.dial_core import DialConfig, MBDPI
from dial_mpc.envs.base_env import BaseEnv, BaseEnvConfig
from dial_mpc.utils.io_utils import load_dataclass_from_dict
from dial_mpc.utils.io_utils import (
load_dataclass_from_dict,
get_model_path,
get_example_path,
)
from dial_mpc.examples import deploy_examples

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get("XLA_FLAGS", "")
Expand Down Expand Up @@ -224,11 +230,39 @@ def reverse_scan(rng_Y0_state, factor):
def main(args=None):
art.tprint("LeCAR @ CMU\nDIAL-MPC\nPLANNER", font="big", chr_ignore=True)
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--config", type=str, default="config.yaml", help="Path to config file"
)
group.add_argument(
"--example",
type=str,
default=None,
help="Example to run",
)
group.add_argument(
"--list-examples",
action="store_true",
help="List available examples",
)
args = parser.parse_args(args)

print("Creating environment")
config_dict = yaml.safe_load(open(args.config))
if args.list_examples:
print("Available examples:")
for example in deploy_examples:
print(f" - {example}")
return
if args.example is not None:
if args.example not in deploy_examples:
print(f"Example {args.example} not found.")
return
config_dict = yaml.safe_load(
open(get_example_path(args.example + ".yaml"), "r")
)
else:
config_dict = yaml.safe_load(open(args.config, "r"))

print(emoji.emojize(":rocket:") + "Creating environment")
dial_config = load_dataclass_from_dict(DialConfig, config_dict)
env_config_type = dial_envs.get_config(dial_config.env_name)
env_config = load_dataclass_from_dict(
Expand Down
44 changes: 39 additions & 5 deletions dial_mpc/deploy/dial_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@

from dial_mpc.envs.base_env import BaseEnvConfig
from dial_mpc.core.dial_core import DialConfig
from dial_mpc.utils.io_utils import load_dataclass_from_dict
from dial_mpc.utils.io_utils import (
load_dataclass_from_dict,
get_model_path,
get_example_path,
)
from dial_mpc.examples import deploy_examples

plt.style.use(["science"])


@dataclass
class DialSimConfig:
sim_model_path: str
robot_name: str
scene_name: str
sim_leg_control: str
plot: bool
record: bool
Expand Down Expand Up @@ -52,7 +58,9 @@ def __init__(
self.kp = env_config.kp
self.kd = env_config.kd
self.leg_control = sim_config.sim_leg_control
self.mj_model = mujoco.MjModel.from_xml_path(sim_config.sim_model_path)
self.mj_model = mujoco.MjModel.from_xml_path(
get_model_path(sim_config.robot_name, sim_config.scene_name).as_posix()
)
self.mj_model.opt.timestep = self.sim_dt
self.mj_data = mujoco.MjData(self.mj_model)
self.q_history = np.zeros((self.n_acts, self.mj_model.nu))
Expand Down Expand Up @@ -289,11 +297,37 @@ def close(self):
def main(args=None):
art.tprint("LeCAR @ CMU\nDIAL-MPC\nSIMULATOR", font="big", chr_ignore=True)
parser = argparse.ArgumentParser()
parser.add_argument(
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--config", type=str, default="config.yaml", help="Path to config file"
)
group.add_argument(
"--example",
type=str,
default=None,
help="Example to run",
)
group.add_argument(
"--list-examples",
action="store_true",
help="List available examples",
)
args = parser.parse_args(args)
config_dict = yaml.safe_load(open(args.config, "r"))

if args.list_examples:
print("Available examples:")
for example in deploy_examples:
print(f" - {example}")
return
if args.example is not None:
if args.example not in deploy_examples:
print(f"Example {args.example} not found.")
return
config_dict = yaml.safe_load(
open(get_example_path(args.example + ".yaml"), "r")
)
else:
config_dict = yaml.safe_load(open(args.config, "r"))
sim_config = load_dataclass_from_dict(DialSimConfig, config_dict)
env_config = load_dataclass_from_dict(
BaseEnvConfig, config_dict, convert_list_to_array=True
Expand Down
5 changes: 5 additions & 0 deletions dial_mpc/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@
"unitree_go2_seq_jump",
"unitree_go2_crate_climb",
]

deploy_examples = [
"unitree_go2_trot_deploy",
"unitree_go2_seq_jump_deploy",
]
24 changes: 8 additions & 16 deletions dial_mpc/examples/unitree_go2_seq_jump.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,12 @@ action_scale: 1.0

# Go2
jump_dt: 1.0
pose_target_sequence: [
[0.0, 0.0, 0.27],
[0.4, 0.0, 0.27],
[0.8, 0.0, 0.27],
[1.2, 0.0, 0.27],
[1.6, 0.0, 0.27],
]
pose_target_sequence:
[
[0.0, 0.0, 0.27],
[0.4, 0.0, 0.27],
[0.8, 0.0, 0.27],
[1.2, 0.0, 0.27],
[1.6, 0.0, 0.27],
]
yaw_target_sequence: [0.0, 0.0, 0.0, 0.0, 0.0]

# Sim
sim_model_path: /home/haoru/research/dial/dial-mpc/dial_mpc/models/unitree_go2/scene.xml
sim_leg_control: torque
plot: false
record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false
43 changes: 43 additions & 0 deletions dial_mpc/examples/unitree_go2_seq_jump_deploy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# DIAL-MPC
seed: 0
output_dir: unitree_go2_seq_jump
n_steps: 400

env_name: unitree_go2_seq_jump
Nsample: 2048
Hsample: 20
Hnode: 5
Ndiffuse: 1
Ndiffuse_init: 10
temp_sample: 0.1
horizon_diffuse_factor: 0.9
traj_diffuse_factor: 0.5
update_method: mppi

# Base environment
dt: 0.02
timestep: 0.02
leg_control: torque
action_scale: 1.0

# Go2
jump_dt: 1.0
pose_target_sequence:
[
[0.0, 0.0, 0.27],
[0.4, 0.0, 0.27],
[0.8, 0.0, 0.27],
[1.2, 0.0, 0.27],
[1.6, 0.0, 0.27],
]
yaw_target_sequence: [0.0, 0.0, 0.0, 0.0, 0.0]

# Sim
robot_name: "unitree_go2"
scene_name: "scene.xml"
sim_leg_control: torque
plot: false
record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false
9 changes: 0 additions & 9 deletions dial_mpc/examples/unitree_go2_trot.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,3 @@ default_vy: 0.0
default_vyaw: 0.0
ramp_up_time: 1.0
gait: trot

# Sim
sim_model_path: /home/haoru/research/dial/dial-mpc/dial_mpc/models/unitree_go2/scene.xml
sim_leg_control: torque
plot: false
record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false
38 changes: 38 additions & 0 deletions dial_mpc/examples/unitree_go2_trot_deploy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# DIAL-MPC
seed: 0
output_dir: unitree_go2_trot
n_steps: 400

env_name: unitree_go2_walk
Nsample: 2048
Hsample: 16
Hnode: 4
Ndiffuse: 1
Ndiffuse_init: 10
temp_sample: 0.05
horizon_diffuse_factor: 0.9
traj_diffuse_factor: 0.5
update_method: mppi

# Base environment
dt: 0.02
timestep: 0.02
leg_control: torque
action_scale: 1.0

# Go2
default_vx: 1.0
default_vy: 0.0
default_vyaw: 0.0
ramp_up_time: 1.0
gait: trot

# Sim
robot_name: "unitree_go2"
scene_name: "scene.xml"
sim_leg_control: torque
plot: false
record: false
real_time_factor: 1.0
sim_dt: 0.005
sync_mode: false

0 comments on commit 28bfa48

Please sign in to comment.