diff --git a/mlrose_hiive/fitness/max_k_color.py b/mlrose_hiive/fitness/max_k_color.py index fb816918..8bbda047 100644 --- a/mlrose_hiive/fitness/max_k_color.py +++ b/mlrose_hiive/fitness/max_k_color.py @@ -3,7 +3,6 @@ # Author: Genevieve Hayes (Modified by Andrew Rollings) # License: BSD 3 clause -import numpy as np class MaxKColor: diff --git a/mlrose_hiive/fitness/queens.py b/mlrose_hiive/fitness/queens.py index 48e61163..9f0b2504 100644 --- a/mlrose_hiive/fitness/queens.py +++ b/mlrose_hiive/fitness/queens.py @@ -36,9 +36,10 @@ class Queens: optimization problems *only*. """ - def __init__(self): + def __init__(self, maximize=False): self.prob_type = 'discrete' + self.maximize = maximize @staticmethod def shift(a, num, fill_value=np.nan): @@ -84,6 +85,8 @@ def evaluate(self, state): f_d = np.sum(state_shifts == state) // 2 # each diagonal piece is counted twice fitness = f_h + f_d + if self.maximize: + fitness = self.get_max_size(ls) - fitness return fitness def get_prob_type(self): @@ -96,3 +99,11 @@ def get_prob_type(self): or 'either'. """ return self.prob_type + + @staticmethod + def get_max_size(problem_size): + if problem_size <= 1: + return 0 + if problem_size == 2: + return 1 + return 3*(problem_size-2) diff --git a/mlrose_hiive/generators/queens_generator.py b/mlrose_hiive/generators/queens_generator.py index 68aa98bf..b0131d57 100644 --- a/mlrose_hiive/generators/queens_generator.py +++ b/mlrose_hiive/generators/queens_generator.py @@ -5,7 +5,8 @@ class QueensGenerator: @staticmethod - def generate(seed, size=20): + def generate(seed, size=20, maximize=False): np.random.seed(seed) - problem = QueensOpt(length=size) + problem = QueensOpt(length=size, maximize=maximize) return problem + diff --git a/mlrose_hiive/opt_probs/max_k_color_opt.py b/mlrose_hiive/opt_probs/max_k_color_opt.py index 1a9da521..a370a97b 100644 --- a/mlrose_hiive/opt_probs/max_k_color_opt.py +++ b/mlrose_hiive/opt_probs/max_k_color_opt.py @@ -39,6 +39,8 @@ def __init__(self, edges=None, length=None, fitness_fn=None, maximize=False, else: self.source_graph = source_graph + self.stop_fitness = self.source_graph.number_of_edges() if maximize else 0 + fitness_fn.set_graph(self.source_graph) # if none is provided, make a reasonable starting guess. # the max val is going to be the one plus the maximum number of neighbors of any one node. @@ -58,4 +60,4 @@ def __init__(self, edges=None, length=None, fitness_fn=None, maximize=False, self.set_state(state) def can_stop(self): - return int(self.get_fitness()) == 0 + return int(self.get_fitness()) == self.stop_fitness diff --git a/mlrose_hiive/opt_probs/queens_opt.py b/mlrose_hiive/opt_probs/queens_opt.py index 3e2b9f03..7f00a0d4 100644 --- a/mlrose_hiive/opt_probs/queens_opt.py +++ b/mlrose_hiive/opt_probs/queens_opt.py @@ -24,7 +24,9 @@ def __init__(self, length=None, fitness_fn=None, maximize=False, self.length = length if fitness_fn is None: - fitness_fn = Queens() + fitness_fn = Queens(maximize=maximize) + + self.stop_fitness = Queens.get_max_size(length) if maximize else 0 self.max_val = length crossover = UniformCrossOver(self) if crossover is None else crossover @@ -36,4 +38,4 @@ def __init__(self, length=None, fitness_fn=None, maximize=False, self.set_state(state) def can_stop(self): - return int(self.get_fitness()) == 0 + return int(self.get_fitness()) == self.stop_fitness