Skip to content

Commit

Permalink
Merge pull request #585 from ARISE-Initiative/domain-rand-wrapper
Browse files Browse the repository at this point in the history
Update domain randomization wrapper to work with  mujoco!=3.1.1 for some settings
  • Loading branch information
kevin-thankyou-lin authored Dec 25, 2024
2 parents 154491f + 795c4ba commit 06550e0
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 3 deletions.
13 changes: 10 additions & 3 deletions robosuite/demos/demo_domain_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Script to showcase domain randomization functionality.
"""

import mujoco
import time

import robosuite.macros as macros
from robosuite.utils.input_utils import *
Expand All @@ -12,7 +12,6 @@
macros.USING_INSTANCE_RANDOMIZATION = True

if __name__ == "__main__":
assert mujoco.__version__ == "3.1.1", "Script requires mujoco-py version 3.1.1 to run"
# Create dict to hold options that will be passed to env creation call
options = {}

Expand Down Expand Up @@ -57,10 +56,17 @@
control_freq=20,
hard_reset=False, # TODO: Not setting this flag to False brings up a segfault on macos or glfw error on linux
)
env = DomainRandomizationWrapper(env)
env = DomainRandomizationWrapper(
env,
randomize_color=False, # randomize_color currently only works for mujoco==3.1.1
randomize_camera=False, # less jarring when visualizing
randomize_dynamics=False,
)
env.reset()
env.viewer.set_camera(camera_id=0)

max_frame_rate = 20 # Set the desired maximum frame rate

# Get action limits
low, high = env.action_spec

Expand All @@ -69,3 +75,4 @@
action = np.random.uniform(low, high)
obs, reward, done, _ = env.step(action)
env.render()
time.sleep(1 / max_frame_rate)
47 changes: 47 additions & 0 deletions robosuite/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This file contains utility classes and functions for logging to stdout and stderr
Adapted from robomimic: https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/log_utils.py
"""
import inspect
import logging
import os
import time
Expand Down Expand Up @@ -97,6 +98,52 @@ def get_logger(self):
return logger


def format_message(level: str, message: str) -> str:
"""
Format a message with colors based on the level and include file and line number.
Args:
level (str): The logging level (e.g., "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL").
message (str): The message to format.
Returns:
str: The formatted message with file and line number.
"""
# Get the caller's file name and line number
frame = inspect.currentframe().f_back
filename = frame.f_code.co_filename
lineno = frame.f_lineno

# Level-based coloring
level_colors = {
"DEBUG": "blue",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red",
}
attrs = ["bold"]
if level == "CRITICAL":
attrs.append("reverse")

color = level_colors.get(level, "white")
formatted_message = colored(f"[{level}] {filename}:{lineno} - {message}", color, attrs=attrs)
return formatted_message


def rs_assert(condition: bool, message: str):
"""
Assert a condition and raise an error with a formatted message if the condition fails.
Args:
condition (bool): The condition to check.
message (str): The error message to display if the assertion fails.
"""
if not condition:
formatted_message = format_message("ERROR", message)
raise AssertionError(formatted_message)


ROBOSUITE_DEFAULT_LOGGER = DefaultLogger(
console_logging_level=macros.CONSOLE_LOGGING_LEVEL,
file_logging_level=macros.FILE_LOGGING_LEVEL,
Expand Down
9 changes: 9 additions & 0 deletions robosuite/wrappers/domain_randomization_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
This file implements a wrapper for facilitating domain randomization over
robosuite environments.
"""
import mujoco
import numpy as np

from robosuite.utils.log_utils import rs_assert
from robosuite.utils.mjmod import CameraModder, DynamicsModder, LightingModder, TextureModder
from robosuite.wrappers import Wrapper

Expand Down Expand Up @@ -154,6 +156,13 @@ def __init__(
self.modders = []

if self.randomize_color:
rs_assert(
mujoco.__version__ == "3.1.1",
(
"TextureModder requires mujoco version 3.1.1 to run. "
"Pending support for later versions. Alternatively, you can set randomize_color=False."
),
)
self.tex_modder = TextureModder(
sim=self.env.sim, random_state=self.random_state, **self.color_randomization_args
)
Expand Down

0 comments on commit 06550e0

Please sign in to comment.