Skip to content

Commit

Permalink
Initial work towards removing local_reward & some clarity refactors.
Browse files Browse the repository at this point in the history
  • Loading branch information
anordin95 committed Dec 4, 2024
1 parent 369848b commit 3557a2e
Showing 1 changed file with 31 additions and 49 deletions.
80 changes: 31 additions & 49 deletions pettingzoo/butterfly/pistonball/pistonball.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@
from pettingzoo.utils import AgentSelector, wrappers
from pettingzoo.utils.conversions import parallel_wrapper_fn

_image_library = {}

FPS = 20

__all__ = ["ManualPolicy", "env", "parallel_env", "raw_env"]
Expand Down Expand Up @@ -239,8 +237,7 @@ def __init__(
)
self.recentPistons = set() # Set of pistons that have touched the ball recently
self.time_penalty = time_penalty
# TODO: this was a bad idea and the logic this uses should be removed at some point
self.local_ratio = 0

self.ball_mass = ball_mass
self.ball_friction = ball_friction
self.ball_elasticity = ball_elasticity
Expand Down Expand Up @@ -466,8 +463,8 @@ def reset(self, seed=None, options=None):
-6 * math.pi, 6 * math.pi
)

self.lastX = int(self.ball.position[0] - self.ball_radius)
self.distance = self.lastX - self.wall_width
self.ball_prev_pos = self._get_ball_position()
self.distance_to_wall_at_game_start = self.ball_prev_pos - self.wall_width

self.draw_background()
self.draw()
Expand Down Expand Up @@ -566,30 +563,6 @@ def draw(self):
)
self.draw_pistons()

def get_nearby_pistons(self):
    """Return the indices of the piston closest to the ball and its neighbours.

    Closeness is measured along x, between each piston's ``position.x`` and
    the ball's left edge (``int(ball.position[0] - ball_radius)``). Pistons
    are indexed from the leftmost one. When two pistons are equally close,
    the lower index wins.

    Returns:
        A list of piston indices in ascending order: the closest piston,
        preceded by its left neighbour and followed by its right neighbour
        whenever those exist.
    """
    ball_left_edge = int(self.ball.position[0] - self.ball_radius)

    def gap(idx):
        # Horizontal distance between piston ``idx`` and the ball's left edge.
        return abs(self.pistonList[idx].position.x - ball_left_edge)

    # Seed the search with piston 0; only a strictly smaller gap replaces
    # the current best, so ties resolve to the lowest index.
    nearest, smallest_gap = 0, gap(0)
    for idx in range(1, self.n_pistons):
        candidate_gap = gap(idx)
        if candidate_gap < smallest_gap:
            nearest, smallest_gap = idx, candidate_gap

    neighbours = [nearest]
    if nearest > 0:
        neighbours.insert(0, nearest - 1)
    if nearest < self.n_pistons - 1:
        neighbours.append(nearest + 1)
    return neighbours

def get_local_reward(self, prev_position, curr_position):
    """Return half of the ball's leftward displacement.

    The result is positive when ``curr_position`` is smaller than
    ``prev_position`` and negative when it is larger.
    """
    leftward_shift = prev_position - curr_position
    return 0.5 * leftward_shift

def render(self):
if self.render_mode is None:
gymnasium.logger.warn(
Expand All @@ -612,6 +585,15 @@ def render(self):
if self.render_mode == "rgb_array"
else None
)

def _get_ball_position(self) -> int:
"""Return the leftmost x-position of the ball. If the ball
extends beyond the leftmost wall, return the position of that
wall-edge."""
ball_position = int(self.ball.position[0] - self.ball_radius)
# check if the ball is touching/within the left-most wall.
clipped_ball_position = max(self.wall_width, ball_position)
return clipped_ball_position

def step(self, action):
if (
Expand All @@ -633,30 +615,30 @@ def step(self, action):

self.space.step(self.dt)
if self._agent_selector.is_last():
ball_min_x = int(self.ball.position[0] - self.ball_radius)
ball_next_x = (
self.ball.position[0]
- self.ball_radius
+ self.ball.velocity[0] * self.dt
ball_curr_pos = self._get_ball_position()

# A rough, first-order prediction (i.e. velocity-only) of the ball's next position.
# The physics environment may bounce the ball off the wall in the next time-step
# without us first registering that win-condition.
ball_predicted_next_pos = (
ball_curr_pos +
self.ball.velocity[0] * self.dt
)
if ball_next_x <= self.wall_width + 1:
# Include a single-pixel fudge-factor for the approximation.
if ball_predicted_next_pos <= self.wall_width + 1:
self.terminate = True
# ensures that the ball can't pass through the wall
ball_min_x = max(self.wall_width, ball_min_x)

self.draw()
local_reward = self.get_local_reward(self.lastX, ball_min_x)
# Opposite order due to moving right to left
global_reward = (100 / self.distance) * (self.lastX - ball_min_x)

# The displacement is negated because the x-axis increases from left to right: when the
# ball's x position decreases, it has moved closer to the left wall, so the reward
# should be positive.
global_reward = -1 * (ball_curr_pos - self.ball_prev_pos) * (100 / self.distance_to_wall_at_game_start)
if not self.terminate:
global_reward += self.time_penalty
total_reward = [
global_reward * (1 - self.local_ratio)
] * self.n_pistons # start with global reward
local_pistons_to_reward = self.get_nearby_pistons()
for index in local_pistons_to_reward:
total_reward[index] += local_reward * self.local_ratio
self.rewards = dict(zip(self.agents, total_reward))
self.lastX = ball_min_x

self.rewards = {agent: global_reward for agent in self.agents}
self.ball_prev_pos = ball_curr_pos
self.frames += 1
else:
self._clear_rewards()
Expand Down

0 comments on commit 3557a2e

Please sign in to comment.