Skip to content

Commit

Permalink
reduced iteration length, added safety measure when swap is filling up
Browse files Browse the repository at this point in the history
  • Loading branch information
Flunzmas committed Nov 13, 2020
1 parent b33875e commit c43870b
Showing 1 changed file with 25 additions and 11 deletions.
36 changes: 25 additions & 11 deletions agents/goal_rgcn_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import psutil

import numpy as np
from operator import itemgetter
Expand All @@ -25,14 +26,16 @@
# NOTE(review): critic_discount and entropy_beta are not used in the visible code;
# presumably loss-weighting coefficients for the PPO update — confirm at their use site.
critic_discount = 0.5
entropy_beta = 0.001

# NOTE(review): ppo_steps and mini_batch_size are each assigned twice in this view
# (it reads like diff output showing old and new values); the later assignment wins.
ppo_steps = 3232
mini_batch_size = 202
# ppo_steps bounds the inner step loop of train(). In the visible code, mini_batch_size
# sets the interval (in steps) at which train() checks swap usage.
ppo_steps = 808
mini_batch_size = 101
# NOTE(review): ppo_epochs is not used in the visible code — presumably the number of
# optimisation passes per PPO update; confirm.
ppo_epochs = 10

# Stop-condition inputs: training may stop once env.po_success_history holds more than
# min_po_attempts entries and env.po_percent exceeds target_po_ratio.
min_po_attempts = 1000
target_po_ratio = 0.95
# NOTE(review): render_each_step is not used in the visible code.
render_each_step = False

# Swap usage (percent, as reported by psutil.swap_memory()) above which train() exits.
max_swap_percent = 80

# ------------------------------------------------------------------------------

def compute_gae(next_value, rewards, masks, values, gamma=gamma, lam=gae_lambda):
Expand Down Expand Up @@ -114,7 +117,9 @@ def train():

# ------ TRAINING LOOP ------

while True:
can_stop_training = False

while not can_stop_training:

iter_start = time.time()
log_probs = []
Expand All @@ -124,7 +129,8 @@ def train():
rewards = []
masks = []

for i in range(ppo_steps):
i = 0
while i < ppo_steps:

dist, value = model(state)
action = dist.sample()
Expand All @@ -142,6 +148,14 @@ def train():

state = next_state
frame_idx += 1
i += 1

# Since the graphs are memory-heavy, pre-detect swap overflow after every mini_batch_size steps.
if i % mini_batch_size == 0:
swap_stats = psutil.swap_memory()
if swap_stats.percent > max_swap_percent:
print("memory capacity is exhausted, stopping training...")
exit(1)

_, next_value = model(next_state)
returns = compute_gae(next_value, rewards, masks, values)
Expand All @@ -159,19 +173,19 @@ def train():
print("\nPPO iteration {0} done in {1} secs. Env stats:".format(iteration, round(iter_stop - iter_start, 2)))
env.render(mode='cli_basic')

if can_stop_training(env):
print("Training stop condition met, finishing training.")
break
else:
iteration += 1
iteration += 1
can_stop_training = check_training_progress(env)

# ------ cleanup ------

env.close()


def can_stop_training(env):
    """Return whether the training stop condition is met.

    The condition holds when env.po_success_history has more than
    min_po_attempts entries and env.po_percent exceeds target_po_ratio.
    """
    # Too few recorded attempts: the ratio is not consulted at all.
    if not len(env.po_success_history) > min_po_attempts:
        return False
    return env.po_percent > target_po_ratio
def check_training_progress(env):
    """Report whether training has reached its success criterion.

    The criterion is met when env.po_success_history has more than
    min_po_attempts entries and env.po_percent exceeds target_po_ratio.
    A notice is printed when it is met; the result is returned either way.
    """
    enough_attempts = len(env.po_success_history) > min_po_attempts
    # Short-circuit: env.po_percent is only read once enough attempts exist.
    target_reached = enough_attempts and env.po_percent > target_po_ratio
    if not target_reached:
        return target_reached
    print("training successful, will stop training soon...")
    return target_reached


if __name__ == "__main__":
Expand Down

0 comments on commit c43870b

Please sign in to comment.