From 590605a652e7c897ab967313280466837758f258 Mon Sep 17 00:00:00 2001 From: Bhark Date: Mon, 3 Jun 2024 16:56:07 +0200 Subject: [PATCH] checkpoints --- sneat/evolve.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/sneat/evolve.py b/sneat/evolve.py index dc0bd53..9ec0cec 100644 --- a/sneat/evolve.py +++ b/sneat/evolve.py @@ -23,6 +23,19 @@ def evaluate_population(pop, ff): for g, fitness in zip(pop.genomes, fitness_scores): g.fitness = fitness +def save_checkpoint(pop): + with open('checkpoint.pkl', 'wb') as f: + pkl.dump(pop, f) + +def load_checkpoint(): + try: + with open('checkpoint.pkl', 'rb') as f: + pop = pkl.load(f) + print(f'[i] Restoring from checkpoint (gen. {pop.generation})...') + return pop + except FileNotFoundError: + return None + def print_stats(pop): print(f'\n\n[i] Gen. {pop.generation}:') headers = ['Species', 'Members', 'Best Fitness', 'Average Fitness'] @@ -35,7 +48,8 @@ def print_stats(pop): def evolve(fitness_function): config = get_config() - pop = Population() + + pop = load_checkpoint() or Population() max_generations = config.getint('Evolution', 'max_generations') or np.inf max_fitness = config.getfloat('Evolution', 'max_fitness') or np.inf @@ -52,6 +66,10 @@ def evolve(fitness_function): # reproduce pop.reproduce() + # save checkpoint + if pop.generation % 10 == 0: + save_checkpoint(pop) + best_fitness = max(g.fitness for g in pop.genomes) if best_fitness >= max_fitness: winner = max(pop.genomes, key=lambda x: x.fitness)