diff --git a/agents/goal_rgcn_agent.py b/agents/goal_rgcn_agent.py
index 238e999c..fbcd2b9b 100644
--- a/agents/goal_rgcn_agent.py
+++ b/agents/goal_rgcn_agent.py
@@ -1,4 +1,5 @@
 import time
+import psutil
 import numpy as np
 from operator import itemgetter
 
@@ -25,14 +26,16 @@
 critic_discount = 0.5
 entropy_beta = 0.001
 
-ppo_steps = 3232
-mini_batch_size = 202
+ppo_steps = 808
+mini_batch_size = 101
 ppo_epochs = 10
 
 min_po_attempts = 1000
 target_po_ratio = 0.95
 
 render_each_step = False
 
+max_swap_percent = 80
+
 # ------------------------------------------------------------------------------
 def compute_gae(next_value, rewards, masks, values, gamma=gamma, lam=gae_lambda):
@@ -114,7 +117,9 @@ def train():
 
     # ------ TRAINING LOOP ------
 
-    while True:
+    can_stop_training = False
+
+    while not can_stop_training:
 
         iter_start = time.time()
         log_probs = []
@@ -124,7 +129,8 @@
         rewards = []
         masks = []
 
-        for i in range(ppo_steps):
+        i = 0
+        while i < ppo_steps:
 
             dist, value = model(state)
             action = dist.sample()
@@ -142,6 +148,14 @@
 
             state = next_state
             frame_idx += 1
+            i += 1
+
+            # Since the graphs are memory-heavy, check for swap overflow every mini_batch_size steps.
+            if i % mini_batch_size == 0:
+                swap_stats = psutil.swap_memory()
+                if swap_stats.percent > max_swap_percent:
+                    print("memory capacity is exhausted, stopping training...")
+                    exit(1)
 
         _, next_value = model(next_state)
         returns = compute_gae(next_value, rewards, masks, values)
@@ -159,19 +173,19 @@
 
         print("\nPPO iteration {0} done in {1} secs. Env stats:".format(iteration, round(iter_stop - iter_start, 2)))
         env.render(mode='cli_basic')
 
-        if can_stop_training(env):
-            print("Training stop condition met, finishing training.")
-            break
-        else:
-            iteration += 1
+        iteration += 1
+        can_stop_training = check_training_progress(env)
 
     # ------ cleanup ------
     env.close()
 
 
-def can_stop_training(env):
-    return len(env.po_success_history) > min_po_attempts and env.po_percent > target_po_ratio
+def check_training_progress(env):
+    is_good_progress = len(env.po_success_history) > min_po_attempts and env.po_percent > target_po_ratio
+    if is_good_progress:
+        print("training successful, will stop training soon...")
+    return is_good_progress
 
 
 if __name__ == "__main__":
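
Note on the swap guard added above: psutil.swap_memory() returns system-wide swap statistics, and its .percent field is the fraction of swap in use (0-100). A minimal standalone sketch of the same check, runnable outside the training loop (the swap_exhausted helper name is illustrative, not part of the agent code):

import psutil

# Mirrors the max_swap_percent constant introduced in the diff above.
MAX_SWAP_PERCENT = 80

def swap_exhausted(max_percent=MAX_SWAP_PERCENT):
    # psutil.swap_memory() reports system-wide swap usage; .percent is 0-100.
    return psutil.swap_memory().percent > max_percent

if __name__ == "__main__":
    if swap_exhausted():
        print("memory capacity is exhausted, stopping training...")
    else:
        print("swap usage OK: {0:.1f}%".format(psutil.swap_memory().percent))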