Skip to content

Commit

Permalink
Merge pull request #30 from AlexWendland/master
Browse files Browse the repository at this point in the history
  • Loading branch information
hiive authored Mar 3, 2024
2 parents da5d4dc + 0f1008d commit c23243b
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 7 deletions.
1 change: 0 additions & 1 deletion mlrose_hiive/fitness/max_k_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Author: Genevieve Hayes (Modified by Andrew Rollings)
# License: BSD 3 clause

import numpy as np


class MaxKColor:
Expand Down
13 changes: 12 additions & 1 deletion mlrose_hiive/fitness/queens.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class Queens:
optimization problems *only*.
"""

def __init__(self):
def __init__(self, maximize=False):

self.prob_type = 'discrete'
self.maximize = maximize

@staticmethod
def shift(a, num, fill_value=np.nan):
Expand Down Expand Up @@ -84,6 +85,8 @@ def evaluate(self, state):

f_d = np.sum(state_shifts == state) // 2 # each diagonal piece is counted twice
fitness = f_h + f_d
if self.maximize:
fitness = self.get_max_size(ls) - fitness
return fitness

def get_prob_type(self):
Expand All @@ -96,3 +99,11 @@ def get_prob_type(self):
or 'either'.
"""
return self.prob_type

@staticmethod
def get_max_size(problem_size):
if problem_size <= 1:
return 0
if problem_size == 2:
return 1
return 3*(problem_size-2)
5 changes: 3 additions & 2 deletions mlrose_hiive/generators/queens_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

class QueensGenerator:
@staticmethod
def generate(seed, size=20):
def generate(seed, size=20, maximize=False):
np.random.seed(seed)
problem = QueensOpt(length=size)
problem = QueensOpt(length=size, maximize=maximize)
return problem

4 changes: 3 additions & 1 deletion mlrose_hiive/opt_probs/max_k_color_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, edges=None, length=None, fitness_fn=None, maximize=False,
else:
self.source_graph = source_graph

self.stop_fitness = self.source_graph.number_of_edges() if maximize else 0

fitness_fn.set_graph(self.source_graph)
# if none is provided, make a reasonable starting guess.
# the max val is going to be the one plus the maximum number of neighbors of any one node.
Expand All @@ -58,4 +60,4 @@ def __init__(self, edges=None, length=None, fitness_fn=None, maximize=False,
self.set_state(state)

def can_stop(self):
return int(self.get_fitness()) == 0
return int(self.get_fitness()) == self.stop_fitness
6 changes: 4 additions & 2 deletions mlrose_hiive/opt_probs/queens_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, length=None, fitness_fn=None, maximize=False,
self.length = length

if fitness_fn is None:
fitness_fn = Queens()
fitness_fn = Queens(maximize=maximize)

self.stop_fitness = Queens.get_max_size(length) if maximize else 0

self.max_val = length
crossover = UniformCrossOver(self) if crossover is None else crossover
Expand All @@ -36,4 +38,4 @@ def __init__(self, length=None, fitness_fn=None, maximize=False,
self.set_state(state)

def can_stop(self):
return int(self.get_fitness()) == 0
return int(self.get_fitness()) == self.stop_fitness

0 comments on commit c23243b

Please sign in to comment.