Skip to content

Commit

Permalink
Merge pull request #18 from fhswf/sb3_contrib
Browse files Browse the repository at this point in the history
Maskable PPO from SB3-Contrib
  • Loading branch information
steveyuwono authored Oct 25, 2024
2 parents 99d2bdc + f7078fa commit 55de436
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 10 deletions.
2 changes: 1 addition & 1 deletion doc/rtd/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Detlef Arend, Steve Yuwono, Laxmikant Shrikant Baheti et al'

# The full version, including alpha/beta/rc tags
release = '1.0.2'
release = '1.0.3'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions doc/rtd/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mlpro_int_gymnasium>=1.0.2
mlpro_int_mujoco[full]>=1.0.1
mlpro>=1.9.0
gymnasium!=1.0.0
stable_baselines3>=2.3.0
sb3-contrib>=2.2.1
torch>=2.0.0,<=2.3.1
numpy>=1.0.0,<=1.26.4
setuptools >= 75.0.0
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mlpro_int_gymnasium>=1.0.2
mlpro_int_mujoco[full]>=1.0.1
mlpro>=1.9.0
gymnasium!=1.0.0
stable_baselines3>=2.3.0
sb3-contrib>=2.2.1
torch>=2.0.0,<=2.3.1
numpy>=1.0.0,<=1.26.4
setuptools >= 75.0.0
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = mlpro-int-sb3
version = 1.0.2
version = 1.0.3
author = MLPro Team
author_email = [email protected]
description = MLPro: Integration StableBaselines3
Expand Down Expand Up @@ -28,8 +28,10 @@ where = src
full =
mlpro_int_gymnasium>=1.0.2
mlpro_int_mujoco[full]>=1.0.1
mlpro>=1.9.0
gymnasium!=1.0.0
stable_baselines3>=2.3.0
sb3-contrib>=2.2.1
torch>=2.0.0,<=2.3.1
numpy>=1.0.0,<=1.26.4
setuptools >= 75.0.0
50 changes: 43 additions & 7 deletions src/mlpro_int_sb3/wrappers/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
## -- 2023-09-25 1.2.8 DA Set minimum version for sb3 to 2.1.0
## -- 2024-02-16 1.3.0 SY Wrapper Relocation from MLPro to MLPro-Int-PettingZoo
## -- 2024-04-19 1.4.0 DA Alignment with MLPro 1.4.0
## -- 2024-10-24 1.4.1 SY Update: _compute_action_on_policy() for Maskable PPO
## -------------------------------------------------------------------------------------------------

"""
Ver. 1.4.0 (2024-04-19)
Ver. 1.4.1 (2024-10-24)
This module provides wrapper classes for integrating stable baselines3 policy algorithms.
Expand All @@ -54,7 +55,7 @@
from collections import OrderedDict
from mlpro.rl import *
from typing import Any, Dict, Optional, Union

from sb3_contrib import MaskablePPO



Expand Down Expand Up @@ -140,7 +141,7 @@ class WrPolicySB32MLPro (Wrapper, Policy):
p_action_space : MSpace
Environment Action Space
p_ada : bool
Adaptability. Defaults to True.
Adaptability. Default = True.
p_visualize : bool
Boolean switch for visualisation. Default = False.
p_logging
Expand All @@ -156,9 +157,16 @@ class WrPolicySB32MLPro (Wrapper, Policy):
C_MINIMUM_VERSION = '2.3.0'

## -------------------------------------------------------------------------------------------------
def __init__(self, p_sb3_policy, p_cycle_limit, p_observation_space:MSpace,
p_action_space:MSpace, p_ada:bool=True, p_visualize:bool=False,
p_logging=Log.C_LOG_ALL, p_num_envs:int=1, p_desired_goals=None):
def __init__(self,
p_sb3_policy,
p_cycle_limit,
p_observation_space:MSpace,
p_action_space:MSpace,
p_ada:bool=True,
p_visualize:bool=False,
p_logging=Log.C_LOG_ALL,
p_num_envs:int=1,
p_desired_goals=None):
# Set Name
WrPolicySB32MLPro.C_NAME = "Policy " + type(p_sb3_policy).__name__

Expand All @@ -169,6 +177,12 @@ def __init__(self, p_sb3_policy, p_cycle_limit, p_observation_space:MSpace,
self.last_buffer_element = None
self.last_done = False

# Check masking
if isinstance(p_sb3_policy, MaskablePPO):
self._action_masking = True
else:
self._action_masking = False

# Variable preparation for SB3
action_space = None
observation_space = None
Expand Down Expand Up @@ -278,6 +292,16 @@ def __init__(self, p_sb3_policy, p_cycle_limit, p_observation_space:MSpace,
self.sb3._logger = utils.configure_logger(0, self.sb3.tensorboard_log, "MLPro")


## -------------------------------------------------------------------------------------------------
def _get_mask(self) -> np.array:
raise NotImplementedError


## -------------------------------------------------------------------------------------------------
def _add_to_mask(self, p_action:Action):
raise NotImplementedError


## -------------------------------------------------------------------------------------------------
def _compute_action_on_policy(self, p_obs: State) -> Action:
obs = p_obs.get_values()
Expand All @@ -289,8 +313,17 @@ def _compute_action_on_policy(self, p_obs: State) -> Action:
else:
obs = torch.Tensor(obs).reshape(1, obs.size).to(self.sb3.device)

if self._action_masking:
if p_obs.get_kwargs() is not None:
act_masks = p_obs.get_kwargs()
else:
act_masks = self._get_mask()

with torch.no_grad():
actions, values, log_probs = self.sb3.policy.forward(obs)
if not self._action_masking:
actions, values, log_probs = self.sb3.policy.forward(obs)
else:
actions, values, log_probs = self.sb3.policy.forward(obs, action_masks=act_masks)

actions = actions.cpu().numpy()

Expand All @@ -307,6 +340,9 @@ def _compute_action_on_policy(self, p_obs: State) -> Action:
action_buffer = actions.flatten()
action_buffer = Action(self._id, self._action_space, action_buffer)

if self._action_masking:
self._add_to_mask(action_buffer)

# Add to additional_buffer_element
self.additional_buffer_element = dict(action=action_buffer, value=values, action_log=log_probs)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


setup(name='mlpro-int-sb3',
version='1.0.2',
version='1.0.3',
description='MLPro: Integration StableBaselines3',
author='MLPro Team',
author_mail='[email protected]',
Expand All @@ -14,8 +14,10 @@
"full": [
"mlpro_int_gymnasium>=1.0.2",
"mlpro_int_mujoco[full]>=1.0.1",
"mlpro>=1.9.0",
"gymnasium!=1.0.0",
"stable_baselines3>=2.3.0",
"sb3-contrib>=2.2.1",
"torch>=2.0.0,<=2.3.1",
"numpy>=1.0.0,<=1.26.4",
"setuptools >= 75.0.0"
Expand Down

0 comments on commit 55de436

Please sign in to comment.