Skip to content

Commit

Permalink
Add mujoco rgbd rendering (#1229)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidPL1 authored Oct 28, 2024
1 parent 0e94c46 commit 988999c
Show file tree
Hide file tree
Showing 28 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/environments/mujoco.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ env = gymnasium.make("Ant-v5", render_mode="rgb_array", width=1280, height=720)

| Parameter | Type | Default | Description |
|-------------------------|-------------------------------------|---------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `render_mode` | **str** | `None` | The modality of the render result. Must be one of `human`, `rgb_array`, `depth_array`, or `rgbd_tuple`. Note that `human` does not return a rendered image, but renders directly to the window |
| `width` | **int** | `480` | The width of the render window |
| `height` | **int** | `480` | The height of the render window |
| `camera_id` | **int \| None** | `None` | The camera ID used for the render window |
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/ant_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 20,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/ant_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -297,6 +298,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/half_cheetah_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 20,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/half_cheetah_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -198,6 +199,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/hopper_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 125,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/hopper_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/humanoid_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 67,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/humanoid_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -381,6 +382,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/humanoidstandup_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 67,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/humanoidstandup_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -357,6 +358,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/inverted_double_pendulum_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 20,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/inverted_double_pendulum_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/inverted_pendulum_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 25,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/inverted_pendulum_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -146,6 +147,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
], self.metadata["render_modes"]
if "render_fps" in self.metadata:
assert (
Expand Down
30 changes: 19 additions & 11 deletions gymnasium/envs/mujoco/mujoco_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,12 @@ def render(

mujoco.mjr_readPixels(rgb_arr, depth_arr, self.viewport, self.con)

if render_mode == "depth_array":
# Process rendered images according to render_mode
if render_mode in ["depth_array", "rgbd_tuple"]:
depth_img = depth_arr.reshape(self.viewport.height, self.viewport.width)
# original image is upside-down, so flip it
return depth_img[::-1, :]
else:
depth_img = depth_img[::-1, :]
if render_mode in ["rgb_array", "rgbd_tuple"]:
rgb_img = rgb_arr.reshape(self.viewport.height, self.viewport.width, 3)

if segmentation:
Expand All @@ -280,9 +281,16 @@ def render(
seg_ids[geom.segid + 1, 0] = geom.objtype
seg_ids[geom.segid + 1, 1] = geom.objid
rgb_img = seg_ids[seg_img]
# original image is upside-down, so flip it
rgb_img = rgb_img[::-1, :, :]

# original image is upside-down, so flip i
return rgb_img[::-1, :, :]
# Return processed images based on render_mode
if render_mode == "rgb_array":
return rgb_img
elif render_mode == "depth_array":
return depth_img
else: # "rgbd_tuple"
return rgb_img, depth_img

def close(self):
self.free()
Expand Down Expand Up @@ -697,10 +705,10 @@ def render(
"""Renders a frame of the simulation in a specific format and camera view.
Args:
render_mode: The format to render the frame, it can be: "human", "rgb_array", or "depth_array"
render_mode: The format to render the frame, it can be: "human", "rgb_array", "depth_array", or "rgbd_tuple"
Returns:
If render_mode is "rgb_array" or "depth_array" it returns a numpy array in the specified format. "human" render mode does not return anything.
If render_mode is "rgb_array" or "depth_array" it returns a numpy array in the specified format. "rgbd_tuple" returns a tuple of numpy arrays of the form (rgb, depth). "human" render mode does not return anything.
"""
if render_mode != "human":
assert (
Expand All @@ -709,15 +717,15 @@ def render(

viewer = self._get_viewer(render_mode=render_mode)

if render_mode in ["rgb_array", "depth_array"]:
if render_mode in ["rgb_array", "depth_array", "rgbd_tuple"]:
return viewer.render(render_mode=render_mode, camera_id=self.camera_id)
elif render_mode == "human":
return viewer.render()

def _get_viewer(self, render_mode: Optional[str]):
"""Initializes and returns a viewer class depending on the render_mode
- `WindowViewer` class for "human" render mode
- `OffScreenViewer` class for "rgb_array" or "depth_array" render mode
- `OffScreenViewer` class for "rgb_array", "depth_array", or "rgbd_tuple" render mode
"""
self.viewer = self._viewers.get(render_mode)
if self.viewer is None:
Expand All @@ -730,7 +738,7 @@ def _get_viewer(self, render_mode: Optional[str]):
self.max_geom,
self._vopt,
)
elif render_mode in {"rgb_array", "depth_array"}:
elif render_mode in {"rgb_array", "depth_array", "rgbd_tuple"}:
self.viewer = OffScreenViewer(
self.model,
self.data,
Expand All @@ -741,7 +749,7 @@ def _get_viewer(self, render_mode: Optional[str]):
)
else:
raise AttributeError(
f"Unexpected mode: {render_mode}, expected modes: human, rgb_array, or depth_array"
f"Unexpected mode: {render_mode}, expected modes: human, rgb_array, depth_array, or rgbd_tuple"
)
# Add default camera parameters
self._set_cam_config()
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/pusher_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 20,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/pusher_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -207,6 +208,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/reacher_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 50,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/reacher_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -182,6 +183,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/swimmer_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 25,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/swimmer_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
1 change: 1 addition & 0 deletions gymnasium/envs/mujoco/walker2d_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": 125,
}
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/mujoco/walker2d_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand Down Expand Up @@ -231,6 +232,7 @@ def __init__(
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
2 changes: 2 additions & 0 deletions tests/envs/mujoco/test_mujoco_custom_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class PointEnv(MujocoEnv, utils.EzPickle):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
}

Expand All @@ -43,6 +44,7 @@ def __init__(self, xml_file="point.xml", frame_skip=1, **kwargs):
"human",
"rgb_array",
"depth_array",
"rgbd_tuple",
],
"render_fps": int(np.round(1.0 / self.dt)),
}
Expand Down
Loading

0 comments on commit 988999c

Please sign in to comment.