Skip to content

Commit

Permalink
feat: added checkpoint fn
Browse files Browse the repository at this point in the history
  • Loading branch information
Algoboy-Kevin committed Jul 29, 2024
1 parent 1539d9c commit 3a11cba
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 60 deletions.
Binary file added cmunbtl.otf
Binary file not shown.
5 changes: 3 additions & 2 deletions game_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ def render_init(self, net, genome, config):
self.config = config

# Modify nets to accomodate display
modify_eval_functions(self.net, self.genome, self.config)
self.net.node_evals = modify_eval_functions(self.net, self.genome, self.config)
has_eval = set(eval[0] for eval in self.net.node_evals)
has_input = set(con[1] for con in self.genome.connections)

self.hidden_nodes = [node for node in self.genome.nodes if not 0 <= node <= 3 and node in has_input and node in has_eval]
self.node_centers = get_node_centers(self.net, self.genome, self.hidden_nodes)

def reset(self):
self.snake = [((int) (random() * N_COLS), (int) (random() * N_ROWS))]
self.apple = (int) (random() * N_COLS), (int) (random() * N_ROWS)
Expand Down
86 changes: 66 additions & 20 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import print_function
import multiprocessing
import pickle
from game_env import *
from render import spawn_window,render_basic
import pygame
import os
import neat
import sys
import random
import time
from utils import does_checkpoint_exist, get_last_episode
from game_env import *
from render import spawn_window,render_basic, render

WINNER_FILENAME = "winner/last-winner"

def test():
scores = []
Expand All @@ -34,12 +36,10 @@ def test():
if done:
scores.append(env.rewards)

time.sleep(0.1)
time.sleep(0.5)

# return the average scores
print(scores)
mean = np.mean(scores)
print(mean)
return mean

def simulate(net):
Expand Down Expand Up @@ -67,16 +67,36 @@ def simulate(net):
# return the average scores
return np.mean(scores)

def replay_genome(genome, config):
def replay_genome(genome, config, display=False):
net = neat.nn.FeedForwardNetwork.create(genome, config)
env = SnekEnv(net, genome, config)
env = SnekEnv()
obs = env.reset()
done = False
while not done:
activation = net.activate(obs)
action = np.argmax(activation)
obs, done = env.step(action)

if display:
env.render_init(net, genome, config)
spawn_window()
pygame.init()
while not done:
activation = env.net.activate(obs)
render(
env.snake,
env.apple,
env.net,
env.genome,
env.node_centers,
env.hidden_nodes
)
action = np.argmax(activation)
obs, done = env.step(action)
time.sleep(0.1)

pygame.quit()
else:
while not done:
activation = net.activate(obs)
action = np.argmax(activation)
obs, done = env.step(action)

# env close

def eval_genomes(genomes, config):
Expand All @@ -99,7 +119,21 @@ def eval_genome(genome, config):
fitness = simulate(net, genome, config)
return fitness

def run(config_file, arg):
def test_winner(config_file, genome):
with open(genome, "rb") as f:
winner = pickle.load(f, encoding="latin-1")

config = neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_file
)

replay_genome(winner, config, True)

def run(config_file, checkpoint_path, arg):
# Load configuration.
config = neat.Config(
neat.DefaultGenome,
Expand All @@ -110,7 +144,15 @@ def run(config_file, arg):
)

# Create the population, which is the top-level object for a NEAT run.
p = neat.Population(config)
# Check if a winner file already exists
if does_checkpoint_exist(checkpoint_path):
print("Found checkpoint on ", checkpoint_path)
last_episode = get_last_episode(checkpoint_path)
p = neat.Checkpointer.restore_checkpoint(last_episode)

else:
# Start training from scratch
p = neat.Population(config)

# Add a stdout reporter to show progress in the terminal.
p.add_reporter(neat.StdOutReporter(True))
Expand All @@ -120,26 +162,30 @@ def run(config_file, arg):

# Run for up to 500 generations.
if arg == 'serial':
winner = p.run(eval_genomes, 500)
winner = p.run(eval_genomes, 3)
elif arg == 'parallel':
pe = neat.ParallelEvaluator(multiprocessing.cpu_count(), eval_genome)
winner = p.run(pe.evaluate, n=500)

with open('winner/neat-winner', 'wb') as f:
# Save file once winner is found
with open(WINNER_FILENAME, 'wb') as f:
pickle.dump(winner, f)


if __name__ == '__main__':
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, 'config-feedforward')
checkpoint_path = os.path.join(local_dir, 'checkpoints')
winner_path = os.path.join(local_dir, WINNER_FILENAME)

if len(sys.argv) == 0:
run(config_path)
elif sys.argv[1] == 'test':
test()
elif sys.argv[1] == 'train':
run(config_path, 'serial')
run(config_path, checkpoint_path, 'serial')
elif sys.argv[1] == 'train_fast':
run(config_path, 'parallel')
run(config_path, checkpoint_path, 'parallel')
elif sys.argv[1] == 'test_winner':
test_winner(config_path, winner_path)


30 changes: 14 additions & 16 deletions render.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
GAME_TOP_LEFT = (WINDOW_BUFFER + 0, 400 + WINDOW_BUFFER)
NODE_SIZE = 10
BUFFER = 8
screen = None

# ---------------- color settings --------------

Expand Down Expand Up @@ -53,7 +52,7 @@ def render(snake, apple, net, genome, node_centers, hidden_nodes):
pygame.quit()

screen.fill(BLACK)
draw_square(screen)
draw_square()
draw_snake(snake)
draw_apple(apple)
draw_network(
Expand Down Expand Up @@ -81,26 +80,25 @@ def getLeftTop(x, y):

def draw_snake(snake):
global screen
print(snake)

for i, (x, y) in enumerate(snake):
rect = pygame.Rect(getLeftTop(x, y), (BLOCK_W - BUFFER * 2, BLOCK_H - BUFFER * 2))
pygame.draw.rect(screen, YELLOW if i == len(snake) - 1 else WHITE, rect)

def draw_connections(first_set, second_set, net, genome, node_centers):
global screen

for first in first_set:
for second in second_set:
if (first, second) in genome.connections:
start = node_centers[first]
stop = node_centers[second]
weight = genome.connections[(first, second)].weight
color = BLUE if weight >= 0 else ORANGE

surf = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT), pygame.SRCALPHA)
alpha = 255 * (0.3 + net.values[first] * 0.7)
pygame.draw.line(surf, color + (alpha,), start, stop, width=5)
screen.blit(surf, (0, 0))
for second in second_set:
if (first, second) in genome.connections:
start = node_centers[first]
stop = node_centers[second]
weight = genome.connections[(first, second)].weight
color = BLUE if weight >= 0 else ORANGE
surf = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT), pygame.SRCALPHA)
alpha = 255 * (0.3 + net.values[first] * 0.7)
pygame.draw.line(surf, color + (alpha,), start, stop, width=5)
screen.blit(surf, (0, 0))

def draw_network(net, genome, node_centers, hidden_nodes):
global screen
Expand Down Expand Up @@ -161,8 +159,8 @@ def draw_network(net, genome, node_centers, hidden_nodes):
# img = font.render(str(hidden), True, WHITE)
# screen.blit(img, center2)

pygame.draw.circle(color, center, NODE_SIZE)
pygame.draw.circle(WHITE, center, NODE_SIZE, width=5)
pygame.draw.circle(screen, color, center, NODE_SIZE)
pygame.draw.circle(screen, WHITE, center, NODE_SIZE, width=5)



Expand Down
44 changes: 22 additions & 22 deletions render_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,29 @@ def get_node_centers(net, genome, hidden_nodes):


def modify_eval_functions(net, genome, config):
"""
Modify neat-python's function to display more hidden nodes
"""
# Gather expressed connections.
connections = [cg.key for cg in genome.connections.values() if cg.enabled]

layers = feed_forward_layers(config.genome_config.input_keys, config.genome_config.output_keys, connections, genome)
node_evals = []
for layer in layers:
for node in layer:
inputs = []
for conn_key in connections:
inode, onode = conn_key
if onode == node:
cg = genome.connections[conn_key]
inputs.append((inode, cg.weight))

ng = genome.nodes[node]
aggregation_function = config.genome_config.aggregation_function_defs.get(ng.aggregation)
activation_function = config.genome_config.activation_defs.get(ng.activation)
node_evals.append((node, activation_function, aggregation_function, ng.bias, ng.response, inputs))
"""
Modify neat-python's function to display more hidden nodes
"""
# Gather expressed connections.
connections = [cg.key for cg in genome.connections.values() if cg.enabled]

layers = feed_forward_layers(config.genome_config.input_keys, config.genome_config.output_keys, connections, genome)
node_evals = []
for layer in layers:
for node in layer:
inputs = []
for conn_key in connections:
inode, onode = conn_key
if onode == node:
cg = genome.connections[conn_key]
inputs.append((inode, cg.weight))

ng = genome.nodes[node]
aggregation_function = config.genome_config.aggregation_function_defs.get(ng.aggregation)
activation_function = config.genome_config.activation_defs.get(ng.activation)
node_evals.append((node, activation_function, aggregation_function, ng.bias, ng.response, inputs))

net.node_evals = node_evals
return node_evals

def feed_forward_layers(inputs, outputs, connections, genome):
"""
Expand Down
24 changes: 24 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

def get_last_episode(folder_path):
episodes = []
for file in os.listdir(folder_path):
# if file.lower().startswith('winner'):
# print("WINNER FOUND")
# return os.path.join(folder_path, file)

if file.lower().startswith('neat-checkpoint-'):
episodes.append(file)

if episodes:
filename = episodes[-1]
file_path = os.path.join(folder_path, filename)
return file_path
else:
return "No checkpoint found in the folder."

def does_checkpoint_exist(folder_path):
for file in os.listdir(folder_path):
if 'neat-checkpoint-' in file:
return True
return False
Binary file added winner/default-winner
Binary file not shown.
Binary file added winner/last-winner
Binary file not shown.

0 comments on commit 3a11cba

Please sign in to comment.