diff --git a/docs/source/tasks/index.md b/docs/source/tasks/index.md
index 87ce7b707..ca52cbb1b 100644
--- a/docs/source/tasks/index.md
+++ b/docs/source/tasks/index.md
@@ -280,28 +280,42 @@ Using the TriFingerPro robot, rotate a cube
## Control Tasks
-### MS-CartPole-v1
-
+### MS-CartpoleBalance-v1
+
:::{dropdown} Task Card
:icon: note
:color: primary
**Task Description:**
-Keep the CartPole stable and up right by sliding it left and right
+Use the Cartpole robot to keep the pole balanced upright by sliding the cart left and right.
-**Supported Robots: None**
-**Randomizations:**
-- TODO
+**Supported Robots: Cartpole**
-**Success Conditions:**
-- the cart is within 0.25m of the center of the rail (which is at 0)
-- the cosine of the hinge angle attaching the pole is between 0.995 and 1
+**Randomizations:**
+- The pole's initial angle is randomized around the upright (vertical) position; the range is [-0.05, 0.05] radians.
-**Goal Specification:**
-- None
+**Fail Conditions:**
+- The pole drops below the horizontal plane (the cosine of the hinge angle becomes negative)
-
-
\ No newline at end of file
+
+
+
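+As a quick sanity check, the environment can be created through the standard Gymnasium API. The sketch below is only illustrative; the observation and render modes shown are the usual ManiSkill defaults and may differ in your setup:
+
+```python
+import gymnasium as gym
+
+import mani_skill.envs  # registers the MS-* environments
+
+env = gym.make("MS-CartpoleBalance-v1", obs_mode="state", render_mode="rgb_array")
+obs, _ = env.reset(seed=0)
+obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
+env.close()
+```
+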
+### MS-CartpoleSwingUp-v1
+
+:::{dropdown} Task Card
+:icon: note
+:color: primary
+
+**Task Description:**
+Use the Cartpole robot to swing the pole up from its initial downward position and balance it by sliding the cart left and right.
+
+
+**Supported Robots: Cartpole**
+
+**Randomizations:**
+- The pole starts hanging down; its initial angle is randomized around pi radians with small Gaussian noise, and the cart position is perturbed slightly around the center.
+
+**Success Conditions:**
+- There is no explicit success signal; the task is solved when the pole is swung up and kept upright for the whole episode. In practice, success can be determined by thresholding the accumulated episode reward, as sketched below.
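+
+A minimal sketch of the reward-threshold heuristic mentioned above (the random policy and the threshold of 800 are illustrative only, not part of the environment):
+
+```python
+import gymnasium as gym
+
+import mani_skill.envs  # registers the MS-* environments
+
+env = gym.make("MS-CartpoleSwingUp-v1", obs_mode="state", reward_mode="dense")
+obs, _ = env.reset(seed=0)
+episode_return, done = 0.0, False
+while not done:
+    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
+    episode_return += float(reward)
+    done = bool(terminated or truncated)
+# the dense reward is at most 1 per step, so over the 1000-step episode a return
+# near 1000 means the pole was kept upright; threshold it to declare success
+success = episode_return > 800
+env.close()
+```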
diff --git a/examples/baselines/ppo/README.md b/examples/baselines/ppo/README.md
index fbbd69482..438b87941 100644
--- a/examples/baselines/ppo/README.md
+++ b/examples/baselines/ppo/README.md
@@ -61,9 +61,15 @@ python ppo.py --env_id="RotateCubeLevel4-v1" \
--num_envs=1024 --update_epochs=8 --num_minibatches=32 \
--total_timesteps=500_000_000 --num-steps=250 --num-eval-steps=250
-python ppo.py --env_id="MS-CartPole-v1" \
+python ppo.py --env_id="MS-CartpoleBalance-v1" \
--num_envs=1024 --update_epochs=8 --num_minibatches=32 \
- --total_timesteps=10_000_000 --num-steps=500 --num-eval-steps=500 \
+ --total_timesteps=4_000_000 --num-steps=250 --num-eval-steps=1000 \
+ --gamma=0.99 --gae_lambda=0.95 \
+ --eval_freq=5
+
+python ppo.py --env_id="MS-CartpoleSwingUp-v1" \
+ --num_envs=1024 --update_epochs=8 --num_minibatches=32 \
+ --total_timesteps=10_000_000 --num-steps=250 --num-eval-steps=1000 \
--gamma=0.99 --gae_lambda=0.95 \
--eval_freq=5
diff --git a/figures/environment_demos/MS-CartPole-v1.mp4 b/figures/environment_demos/MS-CartPole-v1.mp4
deleted file mode 100644
index fe7260a51..000000000
Binary files a/figures/environment_demos/MS-CartPole-v1.mp4 and /dev/null differ
diff --git a/figures/environment_demos/MS-CartpoleBalance-v1_rt.mp4 b/figures/environment_demos/MS-CartpoleBalance-v1_rt.mp4
new file mode 100644
index 000000000..8ae31e61f
Binary files /dev/null and b/figures/environment_demos/MS-CartpoleBalance-v1_rt.mp4 differ
diff --git a/mani_skill/envs/scene.py b/mani_skill/envs/scene.py
index 22748c0df..ed0d835d2 100644
--- a/mani_skill/envs/scene.py
+++ b/mani_skill/envs/scene.py
@@ -485,17 +485,23 @@ def get_sim_state(self) -> torch.Tensor:
state_dict["articulations"][
articulation.name
] = articulation.get_state().clone()
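+        # omit empty entries so scenes without any actors or without articulations produce a minimal state dict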
+ if len(state_dict["actors"]) == 0:
+ del state_dict["actors"]
+ if len(state_dict["articulations"]) == 0:
+ del state_dict["articulations"]
return state_dict
def set_sim_state(self, state: Dict):
- for actor_id, actor_state in state["actors"].items():
- if len(actor_state.shape) == 1:
- actor_state = actor_state[None, :]
- self.actors[actor_id].set_state(actor_state)
- for art_id, art_state in state["articulations"].items():
- if len(art_state.shape) == 1:
- art_state = art_state[None, :]
- self.articulations[art_id].set_state(art_state)
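+        # get_sim_state omits empty sub-dicts, so "actors"/"articulations" may be absent from the given state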
+ if "actors" in state:
+ for actor_id, actor_state in state["actors"].items():
+ if len(actor_state.shape) == 1:
+ actor_state = actor_state[None, :]
+ self.actors[actor_id].set_state(actor_state)
+ if "articulations" in state:
+ for art_id, art_state in state["articulations"].items():
+ if len(art_state.shape) == 1:
+ art_state = art_state[None, :]
+ self.articulations[art_id].set_state(art_state)
# ---------------------------------------------------------------------------- #
# GPU Simulation Management
diff --git a/mani_skill/envs/tasks/control/__init__.py b/mani_skill/envs/tasks/control/__init__.py
index 6c9c00c7b..920f2d3f6 100644
--- a/mani_skill/envs/tasks/control/__init__.py
+++ b/mani_skill/envs/tasks/control/__init__.py
@@ -1 +1 @@
-from .cartpole import CartPoleEnv
+from .cartpole import CartpoleBalanceEnv, CartpoleSwingUpEnv
diff --git a/mani_skill/envs/tasks/control/cartpole.py b/mani_skill/envs/tasks/control/cartpole.py
index a5caef2fa..f6812c0b3 100644
--- a/mani_skill/envs/tasks/control/cartpole.py
+++ b/mani_skill/envs/tasks/control/cartpole.py
@@ -8,17 +8,23 @@
from mani_skill.agents.base_agent import BaseAgent
from mani_skill.agents.controllers import *
from mani_skill.envs.sapien_env import BaseEnv
-from mani_skill.envs.utils import randomization
+from mani_skill.envs.utils import randomization, rewards
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
from mani_skill.utils.registration import register_env
-from mani_skill.utils.structs.types import SceneConfig, SimConfig
+from mani_skill.utils.structs.pose import Pose
+from mani_skill.utils.structs.types import (
+ Array,
+ GPUMemoryConfig,
+ SceneConfig,
+ SimConfig,
+)
MJCF_FILE = f"{os.path.join(os.path.dirname(__file__), 'assets/cartpole.xml')}"
class CartPoleRobot(BaseAgent):
- uid = "cartpole"
+ uid = "cart_pole"
mjcf_path = MJCF_FILE
@property
@@ -57,15 +63,23 @@ def _load_articulation(self):
self.robot_link_ids = [link.name for link in self.robot.get_links()]
-@register_env("MS-CartPole-v1", max_episode_steps=500)
-class CartPoleEnv(BaseEnv):
- SUPPORTED_REWARD_MODES = ["sparse", "none"]
+# @register_env("MS-CartPole-v1", max_episode_steps=500)
+# class CartPoleEnv(BaseEnv):
+# SUPPORTED_REWARD_MODES = ["sparse", "none"]
- SUPPORTED_ROBOTS = [CartPoleRobot]
- agent: Union[CartPoleRobot]
+# SUPPORTED_ROBOTS = [CartPoleRobot]
+# agent: Union[CartPoleRobot]
+
+# CART_RANGE = [-0.25, 0.25]
+# ANGLE_COSINE_RANGE = [0.995, 1]
+
+# def __init__(self, *args, robot_uids=CartPoleRobot, **kwargs):
+# super().__init__(*args, robot_uids=robot_uids, **kwargs)
- CART_RANGE = [-0.25, 0.25]
- ANGLE_COSINE_RANGE = [0.995, 1]
+
+class CartpoleEnv(BaseEnv):
+
+ agent: Union[CartPoleRobot]
def __init__(self, *args, robot_uids=CartPoleRobot, **kwargs):
super().__init__(*args, robot_uids=robot_uids, **kwargs)
@@ -94,6 +108,60 @@ def _load_scene(self, options: dict):
for a in actor_builders:
a.build(a.name)
+ def evaluate(self):
+ return dict()
+
+ def _get_obs_extra(self, info: Dict):
+ obs = dict(
+ velocity=self.agent.robot.links_map["pole_1"].linear_velocity,
+ angular_velocity=self.agent.robot.links_map["pole_1"].angular_velocity,
+ )
+ return obs
+
+ @property
+ def pole_angle_cosine(self):
+ return torch.cos(self.agent.robot.joints_map["hinge_1"].qpos)
+
+ def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
+ cart_pos = self.agent.robot.links_map["cart"].pose.p[
+ :, 0
+ ] # (B, ), we only care about x position
+ centered = rewards.tolerance(cart_pos, margin=2)
+ centered = (1 + centered) / 2 # (B, )
+
+ small_control = rewards.tolerance(
+ action, margin=1, value_at_margin=0, sigmoid="quadratic"
+ )[:, 0]
+ small_control = (4 + small_control) / 5
+
+ angular_vel = self.agent.robot.get_qvel()[:, 1]
+ small_velocity = rewards.tolerance(angular_vel, margin=5)
+ small_velocity = (1 + small_velocity) / 2 # (B, )
+
+ upright = (self.pole_angle_cosine + 1) / 2 # (B, )
+
+ # upright is 1 when the pole is upright, 0 when the pole is upside down
+ # small_control is 1 when the action is small, 0.8 when the action is large
+ # small_velocity is 1 when the angular velocity is small, 0.5 when the angular velocity is large
+        # centered is 1 when the cart is centered, 0.5 when the cart is far from the center of the rail
+
+ reward = upright * centered * small_control * small_velocity
+ return reward
+
+ def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
+ # this should be equal to compute_dense_reward / max possible reward
+ max_reward = 1.0
+ return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward
+
+
+@register_env("MS-CartpoleBalance-v1", max_episode_steps=1000)
+class CartpoleBalanceEnv(CartpoleEnv):
+ def __init__(self, *args, **kwargs):
+ super().__init__(
+ *args,
+ **kwargs,
+ )
+
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
@@ -104,38 +172,24 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
self.agent.robot.set_qpos(qpos)
self.agent.robot.set_qvel(qvel)
- @property
- def pole_angle_cosine(self):
- return torch.cos(self.agent.robot.joints_map["hinge_1"].qpos)
-
def evaluate(self):
- cart_pos = self.agent.robot.joints_map["slider"].qpos
- pole_angle_cosine = self.pole_angle_cosine
- cart_in_bounds = cart_pos < self.CART_RANGE[1]
- cart_in_bounds = cart_in_bounds & (cart_pos > self.CART_RANGE[0])
- angle_in_bounds = pole_angle_cosine < self.ANGLE_COSINE_RANGE[1]
- angle_in_bounds = angle_in_bounds & (
- pole_angle_cosine > self.ANGLE_COSINE_RANGE[0]
- )
- return {"cart_in_bounds": cart_in_bounds, "angle_in_bounds": angle_in_bounds}
+ return dict(fail=self.pole_angle_cosine < 0)
- def _get_obs_extra(self, info: Dict):
- return dict()
-
- def compute_sparse_reward(self, obs: Any, action: torch.Tensor, info: Dict):
- return info["cart_in_bounds"] * info["angle_in_bounds"]
+@register_env("MS-CartpoleSwingUp-v1", max_episode_steps=1000)
+class CartpoleSwingUpEnv(CartpoleEnv):
+ def __init__(self, *args, **kwargs):
+ super().__init__(
+ *args,
+ **kwargs,
+ )
-@register_env("CartPoleSwingUp-v1", max_episode_steps=500, override=True)
-class CartPoleSwingUpEnv(CartPoleEnv):
def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
with torch.device(self.device):
b = len(env_idx)
qpos = torch.zeros((b, 2))
- qpos[:, 0] = 0.01 * torch.randn(size=(b,))
- qpos[:, 1] = torch.pi + 0.01 * torch.randn(size=(b,))
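+            # cart starts near the center of the rail and the pole starts hanging down (hinge angle around pi), each with small gaussian noise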
+ qpos[:, 0] = torch.randn((b,)) * 0.01
+ qpos[:, 1] = torch.randn((b,)) * 0.01 + torch.pi
qvel = torch.randn(size=(b, 2)) * 0.01
self.agent.robot.set_qpos(qpos)
self.agent.robot.set_qvel(qvel)
- # Note DM-Control sets some randomness to other qpos values but am not sure what they are
- # as cartpole.xml seems to only load two joints
diff --git a/mani_skill/envs/utils/rewards/common.py b/mani_skill/envs/utils/rewards/common.py
index 707258f26..23f9ee127 100644
--- a/mani_skill/envs/utils/rewards/common.py
+++ b/mani_skill/envs/utils/rewards/common.py
@@ -1 +1,58 @@
-"""Useful utilities for reward functions"""
+import torch
+
+
+def tolerance(
+ x, lower=0.0, upper=0.0, margin=0.0, sigmoid="gaussian", value_at_margin=0.1
+):
+ # modified from https://github.com/google-deepmind/dm_control/blob/554ad2753df914372597575505249f22c255979d/dm_control/utils/rewards.py#L93
+ """Returns 1 when `x` falls inside the bounds, between 0 and 1 otherwise.
+
+ Args:
+        x: A torch tensor, e.g. of shape (B,) or (B, D).
+        lower, upper: Floats specifying the inclusive `(lower, upper)` bounds for
+ the target interval. These can be infinite if the interval is unbounded
+ at one or both ends, or they can be equal to one another if the target
+ value is exact.
+ margin: Float. Parameter that controls how steeply the output decreases as
+ `x` moves out-of-bounds.
+ * If `margin == 0` then the output will be 0 for all values of `x`
+ outside of `bounds`.
+ * If `margin > 0` then the output will decrease sigmoidally with
+ increasing distance from the nearest bound.
+        sigmoid: String, choice of sigmoid type. Currently implemented values are:
+            'gaussian', 'hyperbolic', 'quadratic'.
+ value_at_margin: A float between 0 and 1 specifying the output value when
+ the distance from `x` to the nearest bound is equal to `margin`. Ignored
+            if `margin == 0`. TODO: scaling by `value_at_margin` is not implemented yet; the argument is currently ignored.
+
+ Returns:
+ A torch array with values between 0.0 and 1.0.
+
+ Raises:
+        ValueError: If `lower > upper`.
+ ValueError: If `margin` is negative.
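+
+    Example (illustrative only; `cart_pos` is assumed to be a 1-D tensor of cart
+    x-positions, as in the cartpole reward):
+
+        cart_pos = torch.tensor([0.0, 1.0, 3.0])
+        centered = tolerance(cart_pos, margin=2)  # ~[1.0, 0.88, 0.32] (gaussian)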
+ """
+ if lower > upper:
+ raise ValueError("Lower bound must be <= upper bound.")
+
+ if margin < 0:
+ raise ValueError("`margin` must be non-negative.")
+
+ in_bounds = torch.logical_and(lower <= x, x <= upper)
+
+ if margin == 0:
+ value = torch.where(in_bounds, torch.tensor(1.0), torch.tensor(0.0))
+ else:
+ d = torch.where(x < lower, lower - x, x - upper) / margin
+ if sigmoid == "gaussian":
+ value = torch.where(
+ in_bounds, torch.tensor(1.0), torch.exp(-0.5 * (d**2))
+ )
+ elif sigmoid == "hyperbolic":
+ value = torch.where(in_bounds, torch.tensor(1.0), 1 / (1 + torch.exp(d)))
+ elif sigmoid == "quadratic":
+            # clamp so values beyond the margin stay at 0 rather than going negative
+            value = torch.where(in_bounds, torch.tensor(1.0), torch.clamp(1 - d**2, min=0.0))
+ else:
+ raise ValueError(f"Unknown sigmoid type {sigmoid!r}.")
+
+ return value
diff --git a/mani_skill/trajectory/replay_trajectory.py b/mani_skill/trajectory/replay_trajectory.py
index 3f4a778d4..f04179e80 100644
--- a/mani_skill/trajectory/replay_trajectory.py
+++ b/mani_skill/trajectory/replay_trajectory.py
@@ -343,6 +343,9 @@ def parse_args(args=None):
type=str,
help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer",
)
+ parser.add_argument(
+ "--video-fps", default=30, type=int, help="The FPS of saved videos"
+ )
return parser.parse_args(args)
@@ -418,6 +421,7 @@ def _main(args, proc_id: int = 0, num_procs=1, pbar=None):
save_trajectory=args.save_traj,
trajectory_name=new_traj_name,
save_video=args.save_video,
+ video_fps=args.video_fps,
record_reward=args.record_rewards,
)
diff --git a/mani_skill/utils/wrappers/record.py b/mani_skill/utils/wrappers/record.py
index 11af5383d..faa0cbab1 100644
--- a/mani_skill/utils/wrappers/record.py
+++ b/mani_skill/utils/wrappers/record.py
@@ -216,7 +216,7 @@ def __init__(
max_steps_per_video=None,
clean_on_close=True,
record_reward=True,
- video_fps=20,
+ video_fps=30,
source_type=None,
source_desc=None,
):
@@ -597,7 +597,7 @@ def recursive_add_to_h5py(group: h5py.Group, data: dict, key):
dtype=bool,
)
episode_info.update(
- fail=self._trajectory_buffer.success[end_ptr - 1, env_idx]
+ fail=self._trajectory_buffer.fail[end_ptr - 1, env_idx]
)
recursive_add_to_h5py(group, self._trajectory_buffer.state, "env_states")
if self.record_reward: