improved training rendering
Flunzmas committed Nov 11, 2020
1 parent 3c94ef2 commit cc861b5
Showing 2 changed files with 56 additions and 42 deletions.
29 changes: 19 additions & 10 deletions agents/goal_rgcn_agent.py
@@ -1,3 +1,5 @@
+import time
+
 import numpy as np
 from operator import itemgetter
 
@@ -29,6 +31,7 @@
 
 min_po_attempts = 1000
 target_po_ratio = 0.95
+render_each_step = False
 
 # ------------------------------------------------------------------------------
 
@@ -107,14 +110,13 @@ def train():
     # ------ vars etc. ------
 
     frame_idx = 0
-    train_epoch = 0
+    iteration = 0
 
     # ------ TRAINING LOOP ------
 
-    can_stop_training = False
-
-    while not can_stop_training:
+    while True:
 
+        iter_start = time.time()
         log_probs = []
         values = []
         states = []
@@ -127,7 +129,8 @@ def train():
             dist, value = model(state)
             action = dist.sample()
             next_state, reward, done, _ = env.step(action)
-            # env.render()
+            if render_each_step:
+                env.render()
             log_prob = dist.log_prob(action)
 
             log_probs.append(log_prob)
@@ -150,18 +153,24 @@ def train():
         advantage = returns - values
         advantage = normalize(advantage)
 
-        env.render()
         ppo_update(model, optimizer, frame_idx, states, actions, log_probs, returns, advantage)
-        train_epoch += 1
-        can_stop_training = check_stop_cond(env)
+
+        iter_stop = time.time()
+        print("\nPPO iteration {0} done in {1} secs. Env stats:".format(iteration, round(iter_stop - iter_start, 2)))
+        env.render(mode='cli_basic')
 
+        if can_stop_training(env):
+            print("Training stop condition met, finishing training.")
+            break
+        else:
+            iteration += 1
+
     # ------ cleanup ------
 
     env.close()
 
 
-def check_stop_cond(env):
-    print("CAN_STOP: {0} | {1}".format(len(env.po_success_history), env.po_percent))
+def can_stop_training(env):
    return len(env.po_success_history) > min_po_attempts and env.po_percent > target_po_ratio
 
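For readers skimming the diff, here is a minimal, self-contained sketch of the loop shape this commit produces: time each PPO iteration, print a one-line environment summary via render(mode='cli_basic'), and stop once the proof-obligation (PO) success rate clears the configured threshold. StubEnv and the rollout placeholder are illustrative stand-ins, not the repository's classes.

import time

min_po_attempts = 1000   # require at least this many finished proof obligations (POs)
target_po_ratio = 0.95   # stop once the running PO success ratio exceeds this


def can_stop_training(env):
    # Same predicate as in the diff: enough attempts and a high enough success rate.
    return len(env.po_success_history) > min_po_attempts and env.po_percent > target_po_ratio


class StubEnv:
    """Stand-in exposing only the fields the stop condition and render read."""

    def __init__(self):
        self.po_success_history = [1] * 1200  # 1 = PO closed, 0 = PO failed
        self.po_percent = 0.97                # running PO success ratio

    def render(self, mode='cli_full'):
        print("PO%: {0}".format(self.po_percent))


env = StubEnv()
iteration = 0
while True:
    iter_start = time.time()
    # ... collect a rollout and call ppo_update(...) here ...
    print("\nPPO iteration {0} done in {1} secs. Env stats:".format(iteration, round(time.time() - iter_start, 2)))
    env.render(mode='cli_basic')
    if can_stop_training(env):
        print("Training stop condition met, finishing training.")
        break
    iteration += 1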
69 changes: 37 additions & 32 deletions gym_autokey/envs/autokey_env.py
@@ -23,7 +23,7 @@ class AutokeyEnv(gym.Env):
     the amount of closed goals and POs.
     """
 
-    metadata = {'render.modes': ['human']}
+    metadata = {'render.modes': ['cli_full', 'cli_basic']}
 
     def __init__(self, self_render: bool = True):
         """
@@ -88,7 +88,7 @@ def _del(self):
         if self.connector:
             self.connector.quit_key()
 
-    # -------------------------------------------------------------------------------------------
+    # -------------------------------------------------------------------------------------------
 
     def step(self, action):
         """
@@ -97,13 +97,9 @@ def step(self, action):
         proving process.
         """
 
-        # time.sleep(0.3)
-        if self.env_steps == 3:
-            print(self._observe())
 
         # render the past step at the beginning if self_render
-        if self.self_render:
-            self.render()
+        self.render()
         if self.env_steps % 1000 == 0 and len(self.po_success_history) > 0:
             with open(self.po_percent_logfile, 'a') as po_p_file:
                 po_p_file.write(str(round(100 * sum(self.po_success_history) / len(self.po_success_history), 1)) + "\n")
@@ -131,7 +127,7 @@ def step(self, action):
 
         # tactic preparation
         cur_tactic = self.connector.available_tactics[action]
-        cur_tactic_app_str = ' {0}'.format(cf.TACTIC_ABBR[cur_tactic])\
+        cur_tactic_app_str = ' {0}'.format(cf.TACTIC_ABBR[cur_tactic]) \
             + ' (#{0})'.format(self.cur_subepis.cur_goal.id).rjust(9)
         self.tactic_history.append(cur_tactic_app_str)
 
@@ -157,7 +153,7 @@ def step(self, action):
                 self.last_action = "crash (an id is -1):" + cur_tactic_app_str
                 return self._episode_exit("crash")
             new_goals.append(new_goal_node)
-
+
        # Multiple new goals -> subepisode becomes parent and env continues with a subepisode
        if len(new_goals) > 1:
            child_episodes = [
@@ -180,7 +176,7 @@ def step(self, action):
         self.last_action = "open:" + cur_tactic_app_str
         return self._observe(), 0, False, {}  # return obs, rew, done, infos
 
-    # -------------------------------------------------------------------------------------------
+    # -------------------------------------------------------------------------------------------
 
     def _episode_exit(self, status):
         """
@@ -196,7 +192,7 @@ def _episode_exit(self, status):
         self.finalize_subepisode()
         if status == "fail" or status == 'crash':
             self.po_closable = False  # current po cannot be closed anymore.
-
+
         # end whole root episode on crash or if pre_kill is set to True.
         if status == 'crash' or self.pre_kill:
             while self.open_subepisodes:  # status is 'open' for children, 'parent' for parents
@@ -252,7 +248,7 @@ def _update_episode_counter(self):
         if self.cur_subepis.status == "success":
             self.successful_topgoal_count += 1
 
-    # -------------------------------------------------------------------------------------------
+    # -------------------------------------------------------------------------------------------
 
     def _observe(self):
         """
@@ -262,15 +258,15 @@ def _observe(self):
         # self.obs_space.render(obs)
         return obs
 
-    # -------------------------------------------------------------------------------------------
+    # -------------------------------------------------------------------------------------------
 
     def reset(self, exit_status="none"):
         """
         Resets the environment;
         If there are still open subepisodes: takes the next subepisode.
         Otherwise returns initial observation of ast and features of a random PO.
         """
 
         # print("reset():: {}".format(exit_status))
         # self.print_open_goals("reset")
 
@@ -290,7 +286,7 @@ def reset(self, exit_status="none"):
         if self.env_steps > 0:
 
             self.total_po_count += 1
-
+
             # all goals went through nicely -> po successful
             if self.po_closable:
                 # assert that KeY doesn't have open goals left either
@@ -355,15 +351,15 @@ def _sample_file(self):
         # print('sample_file()::new goals: {0}'.format(sorted([se.cur_goal.id for se in goal_subepisodes])))
         self.open_subepisodes.extend(goal_subepisodes)
 
-    # -------------------------------------------------------------------------------------------
+    # -------------------------------------------------------------------------------------------
 
     def episode_to_line(self):
         """
         Returns a compactified information string for the current episode.
         """
         return "EP.: {epis_steps} steps".format(epis_steps=str(self.po_steps).rjust(5))
-    def env_to_line(self):
+
+    def env_to_line(self):
         """
         Returns a compactified information string for the environment.
         """
@@ -375,25 +371,34 @@ def env_to_line(self):
         po_done_str = "T, +"
         po_p_str = str(self.po_percent).rjust(5) if self.po_percent is not None else "-----"
 
-        return "{ac} \u2588 step: {n_env} ENV, {n_po} PO \u2588"\
-               " \u2713: {suc_po}/{tot_po} PO | {suc_tg}/{tot_tg} TG | {suc_se}/{tot_se} subep"\
-               " \u2588 PO%: {po_p} \u2588 TG_done: {po_dn}"\
+        return "{ac} \u2588 step: {n_env} ENV, {n_po} PO \u2588" \
+               " PO%: {po_p} \u2588 TG_done: {po_dn}" \
+               " \u2588 \u2713: {suc_po}/{tot_po} PO | {suc_tg}/{tot_tg} TG | {suc_se}/{tot_se} subep" \
             .format(
-            ac=self.last_action.rjust(34), n_env=str(self.env_steps).rjust(6),
-            n_po=str(self.po_steps).rjust(4), tot_po=self.total_po_count, suc_po=self.successful_po_count,
-            tot_tg=self.total_topgoal_count, suc_tg=self.successful_topgoal_count, tot_se=self.total_subep_count,
-            suc_se=self.successful_subep_count, po_p=po_p_str, po_dn=po_done_str.rjust(4))
+                ac=self.last_action.rjust(34), n_env=str(self.env_steps).rjust(6),
+                n_po=str(self.po_steps).rjust(4), tot_po=self.total_po_count, suc_po=self.successful_po_count,
+                tot_tg=self.total_topgoal_count, suc_tg=self.successful_topgoal_count, tot_se=self.total_subep_count,
+                suc_se=self.successful_subep_count, po_p=po_p_str, po_dn=po_done_str.rjust(4))
 
-    def render(self, mode='human'):
+    def render(self, mode='cli_full'):
         """
         Prints output to trace the learning process.
         """
-        open_goals_print = [se.cur_goal.id for se in self.open_subepisodes]
-        if len(open_goals_print) > 3:
-            open_goals_print = '[..., ' + str(open_goals_print[-3:])[1:]
-        print(
-            self.env_to_line() + ' \u2588 now open: {1} | active: #{0}]'
-            .format(self.cur_subepis.cur_goal.id, str(open_goals_print)[:-1]))
+        if mode == 'cli_full':
+            open_goals_print = [se.cur_goal.id for se in self.open_subepisodes]
+            if len(open_goals_print) > 3:
+                open_goals_print = '[..., ' + str(open_goals_print[-3:])[1:]
+            print(
+                self.env_to_line() + ' \u2588 now open: {1} | active: #{0}]'
+                .format(self.cur_subepis.cur_goal.id, str(open_goals_print)[:-1]))
+        elif mode == 'cli_basic':
+            print("env steps: {st} \u2588 PO%: {po_p} \u2588 \u2713: " \
+                  "{suc_po}/{tot_po} PO | {suc_tg}/{tot_tg} TG | {suc_se}/{tot_se} subep"
+                  .format(st=str(self.env_steps).rjust(8), tot_po=self.total_po_count, suc_po=self.successful_po_count,
+                          tot_tg=self.total_topgoal_count, suc_tg=self.successful_topgoal_count,
+                          tot_se=self.total_subep_count,
+                          suc_se=self.successful_subep_count,
+                          po_p=str(self.po_percent).rjust(5) if self.po_percent is not None else "-----"))
 
     def print_open_goals(self, origin_func: str):
         """