Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: 🐛 FIxed obsk=None: obsk only influence obs construction, but… #228

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions gymnasium_robotics/envs/multiagent_mujoco/mujoco_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
This project is covered by the Apache 2.0 License.
"""


from __future__ import annotations

import os
Expand All @@ -31,7 +30,6 @@
CoupledHalfCheetahEnv,
)
from gymnasium_robotics.envs.multiagent_mujoco.obsk import (
Node,
build_obs,
get_joints_at_kdist,
get_parts_and_edges,
Expand Down Expand Up @@ -94,7 +92,7 @@ def __init__(
If set to 0 it only observes local state,
If set to 1 it observes local state + 1 joint over,
If set to 2 it observes local state + 2 joints over,
If it set to None the task becomes single agent (the agent observes the entire environment, and performs all the actions)
If it set to None the task becomes single agent (the agents observe the entire environment)
The Default value is: 1
agent_factorization: A custom factorization of the MuJoCo environment (overwrites agent_conf),
see DOC [how to create new agent factorizations](https://robotics.farama.org/envs/MaMuJoCo/index.html#how-to-create-new-agent-factorizations).
Expand Down Expand Up @@ -127,25 +125,16 @@ def __init__(
self.agent_obsk = agent_obsk # if None, fully observable else k>=0 implies observe nearest k agents or joints

# load the agent factorization
if self.agent_obsk is not None:
if agent_factorization is None:
(
self.agent_action_partitions,
mujoco_edges,
self.mujoco_globals,
) = get_parts_and_edges(scenario, agent_conf)
else:
self.agent_action_partitions = agent_factorization["partition"]
mujoco_edges = agent_factorization["edges"]
self.mujoco_globals = agent_factorization["globals"]
if agent_factorization is None:
(
self.agent_action_partitions,
mujoco_edges,
self.mujoco_globals,
) = get_parts_and_edges(scenario, agent_conf)
else:
self.agent_action_partitions = [
tuple(
Node("dummy_node", None, None, i)
for i in range(self.single_agent_env.action_space.shape[0])
)
]
mujoco_edges = []
self.agent_action_partitions = agent_factorization["partition"]
mujoco_edges = agent_factorization["edges"]
self.mujoco_globals = agent_factorization["globals"]

# Create agent lists
self.possible_agents = [
Expand Down Expand Up @@ -293,9 +282,6 @@ def map_local_actions_to_global_action(
AssertionError:
If the Agent action factorization is badly defined (if an action is double defined or not defined at all)
"""
if self.agent_obsk is None:
return actions[self.possible_agents[0]]

assert self.single_agent_env.action_space.shape is not None
global_action = (
np.zeros((self.single_agent_env.action_space.shape[0],)) + np.nan
Expand Down Expand Up @@ -329,9 +315,6 @@ def map_global_action_to_local_actions(
AssertionError:
If the Agent action factorization sizes are badly defined
"""
if self.agent_obsk is None:
return {self.possible_agents[0]: action}

local_actions = {}
for agent_id, partition in enumerate(self.agent_action_partitions):
local_actions[self.possible_agents[agent_id]] = np.array(
Expand Down Expand Up @@ -417,12 +400,6 @@ def create_observation_mapping(self) -> dict[str, np.ndarray[np.float64]]:
Returns:
A cache that indexes global osbervations to local.
"""
if self.agent_obsk is None:
return {
self.possible_agents[0]: np.arange(
self.single_agent_env.observation_space.shape[0]
)
}
if not hasattr(self.single_agent_env.unwrapped, "observation_structure"):
return None

Expand Down
Loading