Skip to content

Commit

Permalink
implement stagnation
Browse files Browse the repository at this point in the history
  • Loading branch information
bhark committed Jun 3, 2024
1 parent 51e212c commit 847ed8d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='sneat',
version='0.0.6',
version='0.0.7',
packages=find_packages(),
package_data={'sneat': ['default_config.ini']},
install_requires=[
Expand Down
4 changes: 3 additions & 1 deletion sneat/default_config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ remove_node=0.08

[Evolution]
max_generations = 100
max_fitness = 4
max_fitness = 4
max_stagnation = 15
min_species = 3
4 changes: 2 additions & 2 deletions sneat/evolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def load_checkpoint():

def print_stats(pop):
print(f'\n\n[i] Gen. {pop.generation}:')
headers = ['Species', 'Members', 'Best Fitness', 'Average Fitness']
headers = ['Species', 'Members', 'Best Fitness', 'Average Fitness', 'Stagnation', 'Best Complexity']
for s in pop.species:
s.members = sorted(s.members, key=lambda x: x.fitness, reverse=True)
species = sorted(pop.species, key=lambda x: x.members[0].fitness, reverse=True)
data = [[s.id, len(s.members), round(max(g.fitness for g in s.members), 2), round(np.mean([g.fitness for g in s.members]), 2)] for s in species]
data = [[s.id, len(s.members), round(max(g.fitness for g in s.members), 2), round(np.mean([g.fitness for g in s.members]), 2), s.stagnation, f'{len(s.members[0].network.nodes)}n + {len(s.members[0].network.connections)}'] for s in species]
print(tb(data, headers=headers))
print('-' * 55)

Expand Down
21 changes: 19 additions & 2 deletions sneat/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,30 @@ def reproduce(self):
for g in self.genomes:
g.normalized_fitness = (g.fitness - min_fitness) / (max_fitness - min_fitness)

# assign adjusted fitness scores
for s in self.species:
s_size = len(s.members)

# bump stagnation
if s.members[0].fitness > s.best_fitness:
s.best_fitness = s.members[0].fitness
s.stagnation = 0
else:
s.stagnation += 1

# assign adjusted fitness scores
s_size = len(s.members)
for g in s.members:
g.adjusted_fitness = max(g.normalized_fitness / s_size, 0.0001) # avoid division by zero

# remove stagnant species
while len(self.species) >= self.config.getint('Evolution', 'min_species'):
stagnant_species = [s for s in self.species if s.stagnation >= self.config.getint('Evolution', 'max_stagnation')]
stagnant_species = sorted(stagnant_species, key=lambda x: x.best_fitness, reverse=True)
if not stagnant_species:
break
extinct = stagnant_species.pop()
self.species.remove(extinct)
print(f'[i] Species {extinct.id} went extinct')

# perform reproduction inside of each species
offspring = []
for s in self.species:
Expand Down
1 change: 1 addition & 0 deletions sneat/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def __init__(self, representative, callbacks):
self.members = [representative]
self.id = callbacks['get_next_species_id']()
self.stagnation = 0
self.best_fitness = float('-inf')

def add_member(self, genome):
self.members.append(genome)

0 comments on commit 847ed8d

Please sign in to comment.