From f8ddbf6ac540040ef8fde7068c39881252f23cb6 Mon Sep 17 00:00:00 2001 From: Arjo Chakravarty Date: Wed, 5 Mar 2025 03:14:42 +0000 Subject: [PATCH] Enable GUI rollout Signed-off-by: Arjo Chakravarty --- .../simple_cart_pole/README.md | 4 ++-- .../simple_cart_pole/cart_pole_env.py | 12 ++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/README.md b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md index e7ea450c57..a629368fa8 100644 --- a/examples/scripts/reinforcement_learning/simple_cart_pole/README.md +++ b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md @@ -1,6 +1,6 @@ # Example for Reinforcement Learning (RL) With Gazebo -This demo world shows you an example of how you can use SDFormat, Ray-RLLIB and Gazebo to perform RL with python. +This demo world shows you an example of how you can use SDFormat, Stable Baselines 3 and Gazebo to perform RL with Python. We start with a very simple cart-pole world. This world is defined in our sdf file `cart_pole.sdf`. It is analogous to the @@ -10,7 +10,7 @@ First create a virtual environment using python, ``` python3 -m venv venv ``` -Lets activate it and install rayrllib and pytorch. +Let's activate it and install Stable Baselines 3 and PyTorch. ``` . 
venv/bin/activate ``` diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py index 31d336b817..60c0b8a8ea 100644 --- a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py +++ b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py @@ -4,7 +4,7 @@ import numpy as np from gz.common6 import set_verbosity -from gz.sim9 import TestFixture, World, world_entity, Model, Link, run_gui +from gz.sim10 import TestFixture, World, world_entity, Model, Link, get_install_prefix from gz.math8 import Vector3d from gz.transport14 import Node from gz.msgs11.world_control_pb2 import WorldControl @@ -13,9 +13,17 @@ from stable_baselines3 import PPO import time +import subprocess file_path = os.path.dirname(os.path.realpath(__file__)) +def run_gui(): + if os.name == 'nt': + base = os.path.join(get_install_prefix(), "libexec", "runGui.exe") + else: + base = os.path.join(get_install_prefix(), "libexec", "runGui") + subprocess.Popen(base) + class GzRewardScorer: def __init__(self): self.fixture = TestFixture(os.path.join(file_path, 'cart_pole.sdf')) @@ -99,13 +107,13 @@ def step(self, action): obs, reward, done, truncated, info = self.env.step(action) return obs, reward, done, truncated, info - env = CustomCartPole({}) model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=25_000) vec_env = model.get_env() obs = vec_env.reset() + run_gui() time.sleep(10) for i in range(50000):