Call agent's pre-interaction during evaluation
Toni-SM committed Oct 10, 2024
1 parent bacb0f5 commit ab0956d
Showing 7 changed files with 307 additions and 351 deletions.
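
In short, the evaluation loops below now mirror the training loops: the agent's pre_interaction() is called before acting, and post_interaction() is delegated to the base Agent class so that only tracking data is written (no learning update). A condensed, illustrative sketch of one single-agent evaluation step after this commit; local names such as agent, env and states stand in for the trainer attributes, and this paraphrases the diff below rather than reproducing it:

    for timestep in range(initial_timestep, timesteps):
        # pre-interaction (newly called during evaluation)
        agent.pre_interaction(timestep=timestep, timesteps=timesteps)
        # compute actions
        actions = agent.act(states, timestep=timestep, timesteps=timesteps)[0]
        # step the environments
        next_states, rewards, terminated, truncated, infos = env.step(actions)
        # write tracking data to TensorBoard
        agent.record_transition(states=states, actions=actions, rewards=rewards,
                                next_states=next_states, terminated=terminated,
                                truncated=truncated, infos=infos,
                                timestep=timestep, timesteps=timesteps)
        # post-interaction: call the base Agent implementation only, so no update is performed
        super(type(agent), agent).post_interaction(timestep=timestep, timesteps=timesteps)
        # carry the observations over to the next step
        states = next_states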
104 changes: 56 additions & 48 deletions skrl/trainers/jax/base.py
@@ -58,6 +58,7 @@ def __init__(self,
self.headless = self.cfg.get("headless", False)
self.disable_progressbar = self.cfg.get("disable_progressbar", False)
self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
self.environment_info = self.cfg.get("environment_info", "episode")

self.initial_timestep = 0

@@ -172,19 +173,18 @@ def single_agent_train(self) -> None:
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

# compute actions
with contextlib.nullcontext():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

# record the environments' transitions
with contextlib.nullcontext():
# record the environments' transitions
self.agents.record_transition(states=states,
actions=actions,
rewards=rewards,
@@ -226,18 +226,20 @@ def single_agent_eval(self) -> None:

for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):

# compute actions
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

with contextlib.nullcontext():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

with contextlib.nullcontext():
# write data to TensorBoard
self.agents.record_transition(states=states,
actions=actions,
@@ -248,7 +250,9 @@ def single_agent_eval(self) -> None:
infos=infos,
timestep=timestep,
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# post-interaction
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
if self.env.num_envs > 1:
@@ -285,22 +289,21 @@ def multi_agent_train(self) -> None:
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

# compute actions
with contextlib.nullcontext():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
shared_next_states = self.env.state()
infos["shared_states"] = shared_states
infos["shared_next_states"] = shared_next_states
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
shared_next_states = self.env.state()
infos["shared_states"] = shared_states
infos["shared_next_states"] = shared_next_states

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

# record the environments' transitions
with contextlib.nullcontext():
# record the environments' transitions
self.agents.record_transition(states=states,
actions=actions,
rewards=rewards,
@@ -315,13 +318,13 @@ def multi_agent_train(self) -> None:
self.agents.post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
with contextlib.nullcontext():
if not self.env.agents:
if not self.env.agents:
with contextlib.nullcontext():
states, infos = self.env.reset()
shared_states = self.env.state()
else:
states = next_states
shared_states = shared_next_states
else:
states = next_states
shared_states = shared_next_states

def multi_agent_eval(self) -> None:
"""Evaluate multi-agents
@@ -342,21 +345,23 @@ def multi_agent_eval(self) -> None:

for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):

# compute actions
# pre-interaction
self.agents.pre_interaction(timestep=timestep, timesteps=self.timesteps)

with contextlib.nullcontext():
# compute actions
actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
shared_next_states = self.env.state()
infos["shared_states"] = shared_states
infos["shared_next_states"] = shared_next_states
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
shared_next_states = self.env.state()
infos["shared_states"] = shared_states
infos["shared_next_states"] = shared_next_states

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

with contextlib.nullcontext():
# write data to TensorBoard
self.agents.record_transition(states=states,
actions=actions,
@@ -367,12 +372,15 @@ def multi_agent_eval(self) -> None:
infos=infos,
timestep=timestep,
timesteps=self.timesteps)
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
if not self.env.agents:
# post-interaction
super(type(self.agents), self.agents).post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
if not self.env.agents:
with contextlib.nullcontext():
states, infos = self.env.reset()
shared_states = self.env.state()
else:
states = next_states
shared_states = shared_next_states
else:
states = next_states
shared_states = shared_next_states
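
The evaluation paths above call super(type(self.agents), self.agents).post_interaction(...) instead of self.agents.post_interaction(...). A minimal, self-contained illustration of that Python idiom follows; the classes are stand-ins for this example only, not skrl's actual agents. Invoking the base-class method on the instance bypasses the subclass override, so during evaluation only the base bookkeeping runs and the agent's update logic is skipped:

    class Agent:
        def post_interaction(self, timestep, timesteps):
            print("base Agent: write tracking data")

    class PPO(Agent):
        def post_interaction(self, timestep, timesteps):
            print("PPO: update the policy")  # training-only work, skipped during evaluation
            super().post_interaction(timestep, timesteps)

    agent = PPO()
    # evaluation-style call: bypasses PPO.post_interaction and runs only the base bookkeeping
    super(type(agent), agent).post_interaction(timestep=0, timesteps=100)
    # prints: base Agent: write tracking data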
55 changes: 31 additions & 24 deletions skrl/trainers/jax/sequential.py
@@ -18,6 +18,7 @@
"headless": False, # whether to use headless mode (no rendering)
"disable_progressbar": False, # whether to disable the progressbar. If None, disable on non-TTY
"close_environment_at_exit": True, # whether to close the environment on normal program termination
"environment_info": "episode", # key used to get and log environment info
}
# [end-config-dict-jax]

@@ -93,20 +94,19 @@ def train(self) -> None:
for agent in self.agents:
agent.pre_interaction(timestep=timestep, timesteps=self.timesteps)

# compute actions
with contextlib.nullcontext():
# compute actions
actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \
for agent, scope in zip(self.agents, self.agents_scope)])

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

# record the environments' transitions
with contextlib.nullcontext():
# record the environments' transitions
for agent, scope in zip(self.agents, self.agents_scope):
agent.record_transition(states=states[scope[0]:scope[1]],
actions=actions[scope[0]:scope[1]],
@@ -123,11 +123,11 @@ def train(self) -> None:
agent.post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
with contextlib.nullcontext():
if terminated.any() or truncated.any():
if terminated.any() or truncated.any():
with contextlib.nullcontext():
states, infos = self.env.reset()
else:
states = next_states
else:
states = next_states

def eval(self) -> None:
"""Evaluate the agents sequentially
@@ -161,19 +161,22 @@ def eval(self) -> None:

for timestep in tqdm.tqdm(range(self.initial_timestep, self.timesteps), disable=self.disable_progressbar, file=sys.stdout):

# compute actions
# pre-interaction
for agent in self.agents:
agent.pre_interaction(timestep=timestep, timesteps=self.timesteps)

with contextlib.nullcontext():
# compute actions
actions = jnp.vstack([agent.act(states[scope[0]:scope[1]], timestep=timestep, timesteps=self.timesteps)[0] \
for agent, scope in zip(self.agents, self.agents_scope)])

# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)
# step the environments
next_states, rewards, terminated, truncated, infos = self.env.step(actions)

# render scene
if not self.headless:
self.env.render()
# render scene
if not self.headless:
self.env.render()

with contextlib.nullcontext():
# write data to TensorBoard
for agent, scope in zip(self.agents, self.agents_scope):
agent.record_transition(states=states[scope[0]:scope[1]],
@@ -185,10 +188,14 @@ def eval(self) -> None:
infos=infos,
timestep=timestep,
timesteps=self.timesteps)
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
if terminated.any() or truncated.any():
# post-interaction
for agent in self.agents:
super(type(agent), agent).post_interaction(timestep=timestep, timesteps=self.timesteps)

# reset environments
if terminated.any() or truncated.any():
with contextlib.nullcontext():
states, infos = self.env.reset()
else:
states = next_states
else:
states = next_states
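
The sequential trainer above slices the stacked states by each agent's scope and restacks the per-agent actions with jnp.vstack. A minimal, self-contained sketch of that splitting pattern; the act helper and the scope values are invented for illustration and are not skrl code:

    import jax.numpy as jnp

    states = jnp.arange(12.0).reshape(6, 2)  # 6 environments, 2 observations each
    agents_scope = [(0, 4), (4, 6)]          # agent 0 handles envs 0..3, agent 1 handles envs 4..5

    def act(agent_id, observations):
        # stand-in for agent.act(...)[0]: one action row per environment in the slice
        return jnp.full((observations.shape[0], 1), float(agent_id))

    actions = jnp.vstack([act(i, states[scope[0]:scope[1]])
                          for i, scope in enumerate(agents_scope)])
    print(actions.shape)  # (6, 1): one stacked action row per environment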
(diffs for the remaining 5 changed files are not shown)
