Skip to content

Commit

Permalink
checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
bhark committed Jun 3, 2024
1 parent add52ab commit 590605a
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion sneat/evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def evaluate_population(pop, ff):
for g, fitness in zip(pop.genomes, fitness_scores):
g.fitness = fitness

def save_checkpoint(pop):
with open('checkpoint.pkl', 'wb') as f:
pkl.dump(pop, f)

def load_checkpoint():
try:
with open('checkpoint.pkl', 'rb') as f:
pop = pkl.load(f)
print(f'[i] Restoring from checkpoint (gen. {pop.generation})...')
return pop
except FileNotFoundError:
return None

def print_stats(pop):
print(f'\n\n[i] Gen. {pop.generation}:')
headers = ['Species', 'Members', 'Best Fitness', 'Average Fitness']
Expand All @@ -35,7 +48,8 @@ def print_stats(pop):

def evolve(fitness_function):
config = get_config()
pop = Population()

pop = load_checkpoint() or Population()

max_generations = config.getint('Evolution', 'max_generations') or np.inf
max_fitness = config.getfloat('Evolution', 'max_fitness') or np.inf
Expand All @@ -52,6 +66,10 @@ def evolve(fitness_function):
# reproduce
pop.reproduce()

# save checkpoint
if pop.generation % 10 == 0:
save_checkpoint(pop)

best_fitness = max(g.fitness for g in pop.genomes)
if best_fitness >= max_fitness:
winner = max(pop.genomes, key=lambda x: x.fitness)
Expand Down

0 comments on commit 590605a

Please sign in to comment.