-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
298 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# | ||
# Copyright (c) 2021 Nathan Juraj Michlo | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
|
||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
from ruck import * | ||
from ruck.external.deap import select_nsga2 | ||
|
||
|
||
class MultiObjectiveMinimalModule(EaModule): | ||
""" | ||
Minimal onemax example | ||
- The goal is to flip all the bits of a boolean array to True | ||
- Offspring are generated as bit flipped versions of the previous population | ||
- Selection tournament is performed between the previous population and the offspring | ||
""" | ||
|
||
# evaluate unevaluated members | ||
def evaluate_values(self, values): | ||
return [(y - x**2, x - y**2) for (x, y) in values] | ||
|
||
# generate values in the range [-1, 1] | ||
def gen_starting_values(self): | ||
return [np.random.random(2) * 2 - 1 for _ in range(100)] | ||
|
||
# randomly offset the members by a small amount | ||
def generate_offspring(self, population): | ||
return [Member(np.clip(m.value + np.random.randn(2) * 0.05, -1, 1)) for m in population] | ||
|
||
# apply nsga2 to population, which tries to maintain a diverse set of solutions | ||
def select_population(self, population, offspring): | ||
return select_nsga2(population + offspring, len(population)) | ||
|
||
|
||
if __name__ == '__main__': | ||
# create and train the population | ||
module = MultiObjectiveMinimalModule() | ||
pop, logbook, halloffame = Trainer(generations=100, progress=True).fit(module) | ||
|
||
print('initial stats:', logbook[0]) | ||
print('final stats:', logbook[-1]) | ||
print('best member:', halloffame.members[0]) | ||
|
||
# plot path | ||
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5)) | ||
# plot points | ||
ax0.set_title('Pareto Optimal Values') | ||
ax0.scatter(*zip(*(m.value for m in pop))) | ||
ax0.set_xlabel('X') | ||
ax0.set_ylabel('Y') | ||
# plot pareto optimal solution | ||
ax1.set_title('Pareto Optimal Scores') | ||
ax1.scatter(*zip(*(m.fitness for m in pop))) | ||
ax1.set_xlabel('Distances') | ||
ax1.set_ylabel('Smoothness') | ||
# display | ||
fig.tight_layout() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
# MIT License | ||
# | ||
# Copyright (c) 2021 Nathan Juraj Michlo | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
|
||
import random | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
from ruck import * | ||
|
||
|
||
class TravelingSalesmanModule(EaModule): | ||
|
||
def __init__(self, points, num_individuals: int = 128, closed_path=False): | ||
self.num_individuals = int(num_individuals) | ||
self.points = np.array(points) | ||
self.num_points = len(self.points) | ||
self.closed_path = bool(closed_path) | ||
# checks | ||
assert self.points.ndim == 2 | ||
assert self.num_points > 0 | ||
assert self.num_individuals > 0 | ||
|
||
# OVERRIDE | ||
|
||
def gen_starting_values(self): | ||
values = [np.arange(self.num_points) for _ in range(self.num_individuals)] | ||
[np.random.shuffle(v) for v in values] | ||
return values | ||
|
||
def generate_offspring(self, population): | ||
# there are definitely much better ways to do this | ||
return [Member(self._two_opt_swap(random.choice(population).value)) for _ in range(self.num_individuals)] | ||
|
||
def evaluate_values(self, values): | ||
# we negate because we want to minimize dist | ||
return [-self._get_dist(v) for v in values] | ||
|
||
def select_population(self, population, offspring): | ||
return R.select_tournament(population + offspring, len(population), k=3) | ||
|
||
# HELPER | ||
|
||
def _two_opt_swap(self, idxs): | ||
i, j = np.random.randint(0, self.num_points, 2) | ||
i, j = min(i, j), max(i, j) | ||
nidxs = np.concatenate([idxs[:i], idxs[i:j][::-1], idxs[j:]]) | ||
return nidxs | ||
|
||
def _get_dist(self, value): | ||
if self.closed_path: | ||
idxs_from, idxs_to = value, np.roll(value, -1) | ||
else: | ||
idxs_from, idxs_to = value[:-1], value[1:] | ||
# compute dist | ||
return np.sum(np.linalg.norm(self.points[idxs_from] - self.points[idxs_to], ord=2, axis=-1)) | ||
|
||
def get_plot_points(self, value): | ||
idxs = value.value if isinstance(value, Member) else value | ||
# handle case | ||
if self.closed_path: | ||
idxs = np.concatenate([idxs, [idxs[0]]]) | ||
# get consecutive points | ||
xs, ys = self.points[idxs].T | ||
return xs, ys | ||
|
||
|
||
if __name__ == '__main__': | ||
# determinism | ||
random.seed(42) | ||
np.random.seed(42) | ||
# get points | ||
points = np.random.rand(72, 2) | ||
# train | ||
module = TravelingSalesmanModule(points=points, num_individuals=128, closed_path=False) | ||
population, logbook, halloffame = Trainer(generations=1024).fit(module) | ||
|
||
# plot path | ||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | ||
ax.plot(*module.get_plot_points(halloffame[0])) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
pip>=21.0 | ||
numpy>=1.19 | ||
tqdm>=4 | ||
|
||
# requirements needed for examples too | ||
ray>=1.6.0 | ||
deap>=1.3 | ||
matplotlib>=3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
pip>=21.0 | ||
numpy>=1.19 | ||
tqdm>=4 | ||
# ray should be an optional requirement | ||
ray>=1.6.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
# MIT License | ||
# | ||
# Copyright (c) 2021 Nathan Juraj Michlo | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
# MIT License | ||
# | ||
# Copyright (c) 2021 Nathan Juraj Michlo | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in | ||
# all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
# SOFTWARE. | ||
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ | ||
|
||
from typing import Optional | ||
from typing import Tuple | ||
from ruck.functional import check_selection | ||
|
||
|
||
try: | ||
import deap | ||
except ImportError as e: | ||
import warnings | ||
warnings.warn('failed to import deap, please install it: $ pip install deap') | ||
raise e | ||
|
||
# ========================================================================= # | ||
# deap helper # | ||
# ========================================================================= # | ||
|
||
|
||
@check_selection | ||
def select_nsga2(population, num_offspring: int, weights: Optional[Tuple[float, ...]] = None): | ||
""" | ||
This is hacky... ruck doesn't yet have NSGA2 | ||
support, but we will add it in future! | ||
""" | ||
# get a fitness value to perform checks | ||
f = population[0].fitness | ||
# check fitness | ||
try: | ||
for _ in f: break | ||
except: | ||
raise ValueError('fitness values do not have multiple values!') | ||
# get weights | ||
if weights is None: | ||
weights = tuple(1.0 for _ in f) | ||
# get deap | ||
from deap import creator, tools, base | ||
# initialize creator | ||
creator.create('_SelIdxFitness', base.Fitness, weights=weights) | ||
creator.create('_SelIdxIndividual', int, fitness=creator._SelIdxFitness) | ||
# convert to deap population | ||
idx_individuals = [] | ||
for i, m in enumerate(population): | ||
ind = creator._SelIdxIndividual(i) | ||
ind.fitness.values = m.fitness | ||
idx_individuals.append(ind) | ||
# run nsga2 | ||
chosen_idx = tools.selNSGA2(individuals=idx_individuals, k=num_offspring) | ||
# return values | ||
return [population[i] for i in chosen_idx] | ||
|
||
|
||
# ========================================================================= # | ||
# END # | ||
# ========================================================================= # |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,7 @@ | |
author="Nathan Juraj Michlo", | ||
author_email="[email protected]", | ||
|
||
version="0.2.1", | ||
version="0.2.2", | ||
python_requires=">=3.6", | ||
packages=setuptools.find_packages(), | ||
|
||
|