Skip to content

Commit

Permalink
Refactor remove BaseMujocoEnv class (#1075)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kallinteris-Andreas authored Jun 5, 2024
1 parent 04fb345 commit b984e6b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 142 deletions.
200 changes: 83 additions & 117 deletions gymnasium/envs/mujoco/mujoco_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from os import path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -34,19 +34,22 @@ def expand_model_path(model_path: str) -> str:
return fullpath


class BaseMujocoEnv(gym.Env[NDArray[np.float64], NDArray[np.float32]]):
"""Superclass for all MuJoCo environments."""
class MujocoEnv(gym.Env):
"""Superclass for MuJoCo based environments."""

def __init__(
self,
model_path,
frame_skip,
model_path: str,
frame_skip: int,
observation_space: Optional[Space],
render_mode: Optional[str] = None,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
camera_id: Optional[int] = None,
camera_name: Optional[str] = None,
default_camera_config: Optional[Dict[str, Union[float, int]]] = None,
max_geom: int = 1000,
visual_options: Dict[int, bool] = {},
):
"""Base abstract class for mujoco based environments.
Expand All @@ -59,6 +62,9 @@ def __init__(
height: The height of the render window.
camera_id: The camera ID used.
camera_name: The name of the camera used (can not be used in conjunction with `camera_id`).
default_camera_config: configuration for rendering camera.
max_geom: max number of rendered geometries.
visual_options: render flag options.
Raises:
OSError: when the `model_path` does not exist.
Expand Down Expand Up @@ -93,48 +99,78 @@ def __init__(
self.camera_name = camera_name
self.camera_id = camera_id

from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer

self.mujoco_renderer = MujocoRenderer(
self.model,
self.data,
default_camera_config,
self.width,
self.height,
max_geom,
camera_id,
camera_name,
visual_options,
)

def _set_action_space(self):
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
low, high = bounds.T
self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)
return self.action_space

# methods to override:
# ----------------------------
def step(
self, action: NDArray[np.float32]
) -> Tuple[NDArray[np.float64], np.float64, bool, bool, Dict[str, np.float64]]:
raise NotImplementedError

def reset_model(self) -> NDArray[np.float64]:
def _initialize_simulation(
self,
) -> Tuple["mujoco.MjModel", "mujoco.MjData"]:
"""
Reset the robot degrees of freedom (qpos and qvel).
Implement this in each subclass.
Initialize MuJoCo simulation data structures `mjModel` and `mjData`.
"""
raise NotImplementedError
model = mujoco.MjModel.from_xml_path(self.fullpath)
# MjrContext will copy model.vis.global_.off* to con.off*
model.vis.global_.offwidth = self.width
model.vis.global_.offheight = self.height
data = mujoco.MjData(model)
return model, data

def _initialize_simulation(self) -> Tuple[Any, Any]:
"""
Initialize MuJoCo simulation data structures mjModel and mjData.
def set_state(self, qpos, qvel):
"""Set the joints position qpos and velocity qvel of the model.
Note: `qpos` and `qvel` is not the full physics state for all mujoco models/environments https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate
"""
raise NotImplementedError
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
self.data.qpos[:] = np.copy(qpos)
self.data.qvel[:] = np.copy(qvel)
if self.model.na == 0:
self.data.act[:] = None
mujoco.mj_forward(self.model, self.data)

def _step_mujoco_simulation(self, ctrl, n_frames) -> None:
def _step_mujoco_simulation(self, ctrl, n_frames):
"""
Step over the MuJoCo simulation.
"""
raise NotImplementedError
self.data.ctrl[:] = ctrl

def render(self) -> Union[NDArray[np.float64], None]:
mujoco.mj_step(self.model, self.data, nstep=n_frames)

# As of MuJoCo 2.0, force-related quantities like cacc are not computed
# unless there's a force sensor in the model.
# See https://github.com/openai/gym/issues/1541
mujoco.mj_rnePostConstraint(self.model, self.data)

def render(self):
"""
Render a frame from the MuJoCo simulation as specified by the render_mode.
"""
raise NotImplementedError
return self.mujoco_renderer.render(self.render_mode)

# -----------------------------
def _get_reset_info(self) -> Dict[str, float]:
"""Function that generates the `info` that is returned during a `reset()`."""
return {}
def close(self):
"""Close rendering contexts processes."""
if self.mujoco_renderer is not None:
self.mujoco_renderer.close()

def get_body_com(self, body_name):
"""Return the cartesian position of a body frame."""
return self.data.body(body_name).xpos

def reset(
self,
Expand Down Expand Up @@ -168,99 +204,29 @@ def do_simulation(self, ctrl, n_frames) -> None:
)
self._step_mujoco_simulation(ctrl, n_frames)

def close(self):
"""Close all processes like rendering contexts"""
raise NotImplementedError

def get_body_com(self, body_name) -> NDArray[np.float64]:
"""Return the cartesian position of a body frame"""
raise NotImplementedError

def state_vector(self) -> NDArray[np.float64]:
"""Return the position and velocity joint states of the model"""
return np.concatenate([self.data.qpos.flat, self.data.qvel.flat])


class MujocoEnv(BaseMujocoEnv):
"""Superclass for MuJoCo environments."""
"""Return the position and velocity joint states of the model.
def __init__(
self,
model_path,
frame_skip,
observation_space: Optional[Space],
render_mode: Optional[str] = None,
width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE,
camera_id: Optional[int] = None,
camera_name: Optional[str] = None,
default_camera_config: Optional[Dict[str, Union[float, int]]] = None,
max_geom: int = 1000,
visual_options: Dict[int, bool] = {},
):
super().__init__(
model_path,
frame_skip,
observation_space,
render_mode,
width,
height,
camera_id,
camera_name,
)

from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer

self.mujoco_renderer = MujocoRenderer(
self.model,
self.data,
default_camera_config,
self.width,
self.height,
max_geom,
camera_id,
camera_name,
visual_options,
)

def _initialize_simulation(
self,
) -> Tuple["mujoco._structs.MjModel", "mujoco._structs.MjData"]:
model = mujoco.MjModel.from_xml_path(self.fullpath)
# MjrContext will copy model.vis.global_.off* to con.off*
model.vis.global_.offwidth = self.width
model.vis.global_.offheight = self.height
data = mujoco.MjData(model)
return model, data

def set_state(self, qpos, qvel):
"""Set the joints position qpos and velocity qvel of the model.
Note: `qpos` and `qvel` is not the full physics state for all mujoco models/environments https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate
Note: `qpos` and `qvel` does not constitute the full physics state for all `mujoco` environments see https://mujoco.readthedocs.io/en/stable/computation/index.html#the-state.
"""
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
self.data.qpos[:] = np.copy(qpos)
self.data.qvel[:] = np.copy(qvel)
if self.model.na == 0:
self.data.act[:] = None
mujoco.mj_forward(self.model, self.data)

def _step_mujoco_simulation(self, ctrl, n_frames):
self.data.ctrl[:] = ctrl

mujoco.mj_step(self.model, self.data, nstep=n_frames)
return np.concatenate([self.data.qpos.flat, self.data.qvel.flat])

# As of MuJoCo 2.0, force-related quantities like cacc are not computed
# unless there's a force sensor in the model.
# See https://github.com/openai/gym/issues/1541
mujoco.mj_rnePostConstraint(self.model, self.data)
# methods to override:
# ----------------------------
def step(
self, action: NDArray[np.float32]
) -> Tuple[NDArray[np.float64], np.float64, bool, bool, Dict[str, np.float64]]:
raise NotImplementedError

def render(self):
return self.mujoco_renderer.render(self.render_mode)
def reset_model(self) -> NDArray[np.float64]:
"""
Reset the robot degrees of freedom (qpos and qvel).
Implement this in each environment subclass.
"""
raise NotImplementedError

def close(self):
if self.mujoco_renderer is not None:
self.mujoco_renderer.close()
def _get_reset_info(self) -> Dict[str, float]:
"""Function that generates the `info` that is returned during a `reset()`."""
return {}

def get_body_com(self, body_name):
return self.data.body(body_name).xpos
# -----------------------------
Loading

0 comments on commit b984e6b

Please sign in to comment.