Skip to content

Commit

Permalink
Merge branch 'state-space'
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 6, 2024
2 parents 4ea996a + 93bb9bb commit 9e106fb
Show file tree
Hide file tree
Showing 13 changed files with 483 additions and 193 deletions.
59 changes: 46 additions & 13 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@ debug:
env:
headless: False
stream_wrapper: False
init_state: victory_road
max_steps: 16
init_state: "victory_road"
state_dir: pyboy_states
max_steps: 20480
log_frequency: 1
disable_wild_encounters: True
disable_ai_actions: True
use_global_map: False
reduce_res: False
reduce_res: True
animate_scripts: True
save_state: False
train:
device: cpu
compile: False
compile_mode: default
num_envs: 1
envs_per_worker: 1
num_workers: 1
env_batch_size: 4
env_batch_size: 128
env_pool: True
zero_copy: False
batch_size: 128
Expand All @@ -36,8 +39,8 @@ debug:
verbose: False
env_pool: False
load_optimizer_state: False
# swarm_frequency: 10
# swarm_keep_pct: .1
async_wrapper: False
archive_states: False

env:
headless: True
Expand All @@ -54,7 +57,7 @@ env:
perfect_ivs: True
reduce_res: True
two_bit: True
log_frequency: 1000
log_frequency: 2000
auto_flash: True
disable_wild_encounters: True
disable_ai_actions: False
Expand All @@ -68,8 +71,10 @@ env:
auto_pokeflute: True
infinite_money: True
use_global_map: False
save_state: False
save_state: True
animate_scripts: False
exploration_inc: 1.0
exploration_max: 1.0



Expand Down Expand Up @@ -102,7 +107,7 @@ train:

num_envs: 288
num_workers: 24
env_batch_size: 72
env_batch_size: 36
env_pool: True
zero_copy: False

Expand All @@ -116,10 +121,9 @@ train:
pool_kernel: [0]
load_optimizer_state: False
use_rnn: True
async_wrapper: False

# swarm_frequency: 500
# swarm_keep_pct: .8
async_wrapper: True
archive_states: True
swarm: True

wrappers:
empty:
Expand Down Expand Up @@ -257,6 +261,35 @@ rewards:
required_item: 5.0
useful_item: 1.0
pokecenter_heal: 1.0

baseline.ObjectRewardRequiredEventsEnvTilesetExploration:
reward:
event: 1.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
hm_count: 10.0
level: 1.0
badges: 5.0
cut_coords: 0.0
cut_tiles: 0.0
start_menu: 0.0
pokemon_menu: 0.0
stats_menu: 0.0
bag_menu: 0.0
explore_hidden_objs: 0.02
seen_action_bag_menu: 0.0
required_event: 5.0
required_item: 5.0
useful_item: 1.0
pokecenter_heal: 0.2
exploration: 0.02
exploration_gym: 0.025
exploration_facility: 0.10
exploration_plateau: 0.025
exploration_lobby: 0.035 # for game corner
a_press: 0.00001
explore_warps: 0.05



Expand Down
167 changes: 109 additions & 58 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import argparse
import heapq
import math
import ast
from datetime import datetime
from functools import partial
import os
import pathlib
import random
import time
from collections import defaultdict, deque
Expand Down Expand Up @@ -138,6 +140,9 @@ class CleanPuffeRL:
stats: dict = field(default_factory=lambda: {})
msg: str = ""
infos: dict = field(default_factory=lambda: defaultdict(list))
states: dict = field(default_factory=lambda: defaultdict(partial(deque, maxlen=5)))
event_tracker: dict = field(default_factory=lambda: {})
max_event_count: int = 0

def __post_init__(self):
seed_everything(self.config.seed, self.config.torch_deterministic)
Expand Down Expand Up @@ -197,53 +202,15 @@ def __post_init__(self):
self.taught_cut = False
self.log = False

if self.config.archive_states:
self.archive_path = pathlib.Path(datetime.now().strftime("%Y%m%d-%H%M%S"))
self.archive_path.mkdir(exist_ok=False)

@pufferlib.utils.profile
def evaluate(self):
# Clear all self.infos except for the state
# states are managed separately so dont worry about deleting them
for k in list(self.infos.keys()):
if k != "state":
del self.infos[k]

# now for a tricky bit:
# if we have swarm_frequency, we will take the top swarm_keep_pct envs and evenly distribute
# their states to the bottom 90%.
# we do this here so the environment can remain "pure"
if (
self.config.async_wrapper
and hasattr(self.config, "swarm_frequency")
and hasattr(self.config, "swarm_keep_pct")
and self.epoch % self.config.swarm_frequency == 0
and "reward/event" in self.infos
and "state" in self.infos
):
# collect the top swarm_keep_pct % of envs
largest = [
x[0]
for x in heapq.nlargest(
math.ceil(self.config.num_envs * self.config.swarm_keep_pct),
enumerate(self.infos["reward/event"]),
key=lambda x: x[1],
)
]
print("Migrating states:")
waiting_for = []
# Need a way not to reset the env id counter for the driver env
# Until then env ids are 1-indexed
for i in range(self.config.num_envs):
if i not in largest:
new_state = random.choice(largest)
print(
f'\t {i+1} -> {new_state+1}, event scores: {self.infos["reward/event"][i]} -> {self.infos["reward/event"][new_state]}'
)
self.env_recv_queues[i + 1].put(self.infos["state"][new_state])
waiting_for.append(i + 1)
# Now copy the hidden state over
# This may be a little slow, but so is this whole process
self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
for i in waiting_for:
self.env_send_queues[i].get()
print("State migration complete")
del self.infos[k]

with self.profile.eval_misc:
policy = self.policy
Expand Down Expand Up @@ -289,15 +256,101 @@ def evaluate(self):

for i in info:
for k, v in pufferlib.utils.unroll_nested_dict(i):
if k == "state":
self.infos[k] = [v]
if "state/" in k:
_, key = k.split("/")
key: tuple[str] = ast.literal_eval(key)
self.states[key].append(v)
if self.config.archive_states:
state_dir = self.archive_path / str(hash(key))
if not state_dir.exists():
state_dir.mkdir(exist_ok=True)
with open(state_dir / "desc.txt", "w") as f:
f.write(str(key))
with open(state_dir / f"{hash(v)}.state", "wb") as f:
f.write(v)
elif "required_count" == k:
for count, eid in zip(
self.infos["required_count"], self.infos["env_id"]
):
self.event_tracker[eid] = count
self.infos[k].append(v)
else:
self.infos[k].append(v)

with self.profile.env:
self.vecenv.send(actions)

with self.profile.eval_misc:
# now for a tricky bit:
# if we have swarm_frequency, we will migrate the bottom
# % of envs in the batch (by required events count)
# and migrate them to a new state at random.
# Now this has a lot of gotchas and is really unstable
# E.g. Some envs could just constantly be on the bottom since they're never
# progressing
# env id in async queues is the index within self.infos - self.config.num_envs + 1
if (
self.config.async_wrapper
and hasattr(self.config, "swarm")
and self.config.swarm
# and self.epoch % self.config.swarm_frequency == 0
and "required_count" in self.infos
and self.states
):
"""
# V1 implementation -
# collect the top swarm_keep_pct % of the envs in the batch
# migrate the envs not in the largest keep pct to one of the top states
largest = [
x[1][0]
for x in heapq.nlargest(
math.ceil(len(self.event_tracker) * self.config.swarm_keep_pct),
enumerate(self.event_tracker.items()),
key=lambda x: x[1][0],
)
]
to_migrate_keys = set(self.event_tracker.keys()) - set(largest)
print(f"Migrating {len(to_migrate_keys)} states:")
for key in to_migrate_keys:
# we store states in a weird format
# pull a list of states corresponding to a required event completion state
new_state_key = random.choice(list(self.states.keys()))
# pull a state within that list
new_state = random.choice(self.states[new_state_key])
"""

# V2 implementation
# check if we have a new highest required_count with N save states available
# If we do, migrate 100% of states to one of the states
max_event_count = 0
new_state_key = ""
for key in self.states.keys():
if len(key) > max_event_count:
max_event_count = len(key)
new_state_key = key
max_state: deque = self.states[key]
if max_event_count > self.max_event_count and len(max_state) == max_state.maxlen:
self.max_event_count = max_event_count

# Need a way not to reset the env id counter for the driver env
# Until then env ids are 1-indexed
for key in self.event_tracker.keys():
new_state = random.choice(self.states[new_state_key])

print(f"Environment ID: {key}")
print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}")
print(f"\tNew events: {new_state_key}")
self.env_recv_queues[key].put(new_state)
# Now copy the hidden state over
# This may be a little slow, but so is this whole process
# self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
# self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
for key in self.event_tracker.keys():
# print(f"\tWaiting for message from env-id {key}")
self.env_send_queues[key].get()
print("State migration complete")

self.stats = {}

for k, v in self.infos.items():
Expand All @@ -307,9 +360,7 @@ def evaluate(self):
if self.epoch % self.config.overlay_interval == 0:
overlay = make_pokemon_red_overlay(np.stack(self.infos[k], axis=0))
if self.wandb_client is not None:
self.stats["Media/aggregate_exploration_map"] = wandb.Image(
overlay, file_type="jpg"
)
self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay)
elif "state" in k:
continue

Expand Down Expand Up @@ -484,7 +535,7 @@ def train(self):
)

if self.epoch % self.config.checkpoint_interval == 0 or done_training:
self.save_checkpoint()
# self.save_checkpoint()
self.msg = f"Checkpoint saved at update {self.epoch}"

def close(self):
Expand All @@ -493,11 +544,11 @@ def close(self):
self.utilization.stop()

if self.wandb_client is not None:
artifact_name = f"{self.exp_name}_model"
artifact = wandb.Artifact(artifact_name, type="model")
model_path = self.save_checkpoint()
artifact.add_file(model_path)
self.wandb_client.log_artifact(artifact)
# artifact_name = f"{self.exp_name}_model"
# artifact = wandb.Artifact(artifact_name, type="model")
# model_path = self.save_checkpoint()
# artifact.add_file(model_path)
# self.wandb_client.log_artifact(artifact)
self.wandb_client.finish()

def save_checkpoint(self):
Expand Down Expand Up @@ -541,7 +592,7 @@ def __enter__(self):

def __exit__(self, *args):
print("Done training.")
self.save_checkpoint()
# self.save_checkpoint()
self.close()
print("Run complete")

Expand Down
6 changes: 6 additions & 0 deletions pokemonred_puffer/data/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import c_uint8, LittleEndianStructure, Union
import re

from pyboy import PyBoy

Expand Down Expand Up @@ -2585,6 +2586,11 @@ def get_event(self, event_name: str) -> bool:
return bool(getattr(self.b, event_name))


EVENTS = {
event for event, _, _ in EventFlagsBits._fields_ if not re.search("EVENT_[0-9,A-F]{3}$", event)
}


REQUIRED_EVENTS = {
"EVENT_FOLLOWED_OAK_INTO_LAB",
"EVENT_PALLET_AFTER_GETTING_POKEBALLS",
Expand Down
Loading

0 comments on commit 9e106fb

Please sign in to comment.