Skip to content

Commit

Permalink
debug off policy, prep v1.0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
steveyuwono committed Dec 4, 2024
1 parent eb751ee commit 1c32245
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion doc/rtd/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Detlef Arend, Steve Yuwono, Laxmikant Shrikant Baheti et al'

# The full version, including alpha/beta/rc tags
release = '1.0.3'
release = '1.0.4'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = mlpro-int-sb3
version = 1.0.3
version = 1.0.4
author = MLPro Team
author_email = [email protected]
description = MLPro: Integration StableBaselines3
Expand Down
14 changes: 12 additions & 2 deletions src/mlpro_int_sb3/wrappers/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,20 +481,30 @@ def _add_buffer_off_policy(self, p_buffer_element: SARSElement):
data_next_obs['achieved_goal'] = np.array(datas["state_new"].get_values())
data_next_obs['desired_goal'] = np.array(self.desired_goals)
data_next_obs['observation'] = np.array(datas["state_new"].get_values())

try:
rewards = datas["reward"].get_overall_reward()
except:
rewards = datas["reward"].get_agent_reward(self._id)

self.sb3.replay_buffer.add(
obs=data_obs,
next_obs=data_next_obs,
action=datas["action"].get_sorted_values(),
reward=datas["reward"].get_overall_reward(),
reward=rewards,
done=datas["state_new"].get_terminal(),
infos=[info])
else:
try:
rewards = datas["reward"].get_overall_reward()
except:
rewards = datas["reward"].get_agent_reward(self._id)

self.sb3.replay_buffer.add(
datas["state"].get_values(),
datas["state_new"].get_values(),
datas["action"].get_sorted_values(),
datas["reward"].get_overall_reward(),
rewards,
datas["state_new"].get_terminal(),
[info])

Expand Down
2 changes: 1 addition & 1 deletion src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


setup(name='mlpro-int-sb3',
version='1.0.3',
version='1.0.4',
description='MLPro: Integration StableBaselines3',
author='MLPro Team',
author_mail='[email protected]',
Expand Down

0 comments on commit 1c32245

Please sign in to comment.