Skip to content

Commit

Permalink
Merge pull request #7 from rradules/crazyrl
Browse files Browse the repository at this point in the history
crazyrl implemented
  • Loading branch information
ffelten authored Nov 2, 2023
2 parents 59c3ce7 + 8b84d4f commit 47b171b
Show file tree
Hide file tree
Showing 16 changed files with 916 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ Clone the repo and run `pre-commit install` to setup the pre-commit hooks.
3. Define the factory functions to create your class: `parallel_env` returns a parallel version of the env, `env` returns an AEC version, and `raw_env` that is the pure class constructor (it is not used in practice). (!) use the conversions that are defined inside our repository, e.g. `mo_parallel_to_aec` instead of `parallel_to_aec` from PZ.
4. (!) do not use `OrderEnforcingWrapper`, it prevents from accessing the `reward_space` of the env :-(;
5. Add a versioned constructor of your env in the directory which exports the factory functions (see `mobeach_v0.py` for an example).
6. Add your environment to the tests in `tests/all_modules.py`
6. Add your environment to the tests in `utils/all_modules.py`
7. Run `pytest` to check that everything works
1 change: 1 addition & 0 deletions momadm_benchmarks/envs/crazyrl/catch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Catch environment for multi-agent reinforcement learning."""
130 changes: 130 additions & 0 deletions momadm_benchmarks/envs/crazyrl/catch/catch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Catch environment for Crazyflie 2. Each agent is supposed to learn to surround a common target point trying to escape."""

from typing_extensions import override

import numpy as np
from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper

from momadm_benchmarks.envs.crazyrl.crazyRL_base import FPS, CrazyRLBaseParallelEnv
from momadm_benchmarks.utils.conversions import mo_parallel_to_aec


def env(*args, **kwargs):
"""Returns the wrapped environment in `AEC` format.
Args:
**kwargs: keyword args to forward to the raw_env function.
Returns:
A wrapped AEC env.
"""
env = raw_env(*args, **kwargs)
env = mo_parallel_to_aec(env)
env = AssertOutOfBoundsWrapper(env)
return env


def parallel_env(*args, **kwargs):
"""Returns the wrapped env in `parallel` format.
Args:
**kwargs: keyword args to forward to the raw_env function.
Returns:
A parallel env.
"""
env = raw_env(*args, **kwargs)
return env


def raw_env(*args, **kwargs):
"""Returns the environment in `Parallel` format.
Args:
**kwargs: keyword args to forward to create the `Catch` environment.
Returns:
A raw env.
"""
return Catch(*args, **kwargs)


class Catch(CrazyRLBaseParallelEnv):
"""A Parallel Environment where drone learn how to surround a moving target trying to escape."""

metadata = {"render_modes": ["human"], "name": "mocatch_v0", "is_parallelizable": True, "render_fps": FPS}

@override
def __init__(self, *args, target_speed=0.1, **kwargs):
"""Catch environment in CrazyRL.
Args:
render_mode (str, optional): The mode to display the rendering of the environment. Can be human or None.
size (int, optional): Size of the area sides
num_drones: amount of drones
init_flying_pos: 2d array containing the coordinates of the agents
is a (3)-shaped array containing the initial XYZ position of the drones.
init_target_location: Array of the initial position of the moving target
target_speed: Distance traveled by the target at each timestep
"""

super().__init__(*args, **kwargs)
self.target_speed = target_speed

def _move_target(self):
# mean of the agent's positions
mean = np.array([0, 0, 0])
for agent in self.agents:
mean = mean + self.agent_location[agent]

mean = mean / self.num_drones

dist = np.linalg.norm(mean - self.target_location)
self.target_location = self.target_location.copy()

# go to the opposite direction of the mean of the agents
if dist > 0.2:
self.target_location += (self.target_location - mean) / dist * self.target_speed

# if the mean of the agents is too close to the target, move the target in a random direction, slowly because
# it hesitates
else:
self.target_location += np.random.random_sample(3) * self.target_speed * 0.1

# if the target is out of the map, put it back in the map
np.clip(
self.target_location,
[-self.size, -self.size, 0.2],
[self.size, self.size, 3],
out=self.target_location,
)

@override
def _transition_state(self, actions):
target_point_action = dict()
state = self.agent_location

# new targets
self.previous_target = self.target_location.copy()
self._move_target()

for agent in self.agents:
# Actions are clipped to stay in the map and scaled to do max 20cm in one step
target_point_action[agent] = np.clip(
state[agent] + actions[agent] * 0.2, [-self.size, -self.size, 0], [self.size, self.size, 3]
)

return target_point_action


if __name__ == "__main__":
prll_env = Catch(render_mode="human")

observations, infos = prll_env.reset()

while prll_env.agents:
actions = {
agent: prll_env.action_space(agent).sample() for agent in prll_env.agents
} # this is where you would insert your policy
observations, rewards, terminations, truncations, infos = prll_env.step(actions)
prll_env.render()
5 changes: 5 additions & 0 deletions momadm_benchmarks/envs/crazyrl/catch/mocatch_v0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""CrazyRL/Catch environment for MOMARL."""
from momadm_benchmarks.envs.crazyrl.catch.catch import env, parallel_env, raw_env


__all__ = ["env", "parallel_env", "raw_env"]
Loading

0 comments on commit 47b171b

Please sign in to comment.