Skip to content

Commit

Permalink
learn
Browse files Browse the repository at this point in the history
  • Loading branch information
Freakwill committed Dec 15, 2023
1 parent 35a8bfc commit 0c9df50
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 28 deletions.
12 changes: 11 additions & 1 deletion pyrimidine/learn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@

class BaseEstimator(BE):

estimated_params = ()

@classmethod
def config(cls, X, Y, *args, **kwargs):
# configure a population for GA
raise NotImplementedError

def fit(self, X, Y, pop=None, warm_start=False):
if warm_start:
self.pop = pop or self.pop or self.config(X, Y)
else:
self.pop = pop or self.config(X, Y)
self._fit(X, Y)
self._fit(X, Y)
return self

def _fit(self, X, Y):
self.pop.ezolve(n_iter=self.max_iter)
model_ = self.pop.solution
for k in self.estimated_params:
setattr(self, k, getattr(model_, k))
9 changes: 0 additions & 9 deletions pyrimidine/learn/base.py

This file was deleted.

16 changes: 6 additions & 10 deletions pyrimidine/learn/linear_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3


from pyrimidine import *
from pyrimidine import StandardPopulation
from digit_converter import IntervalConverter

c = IntervalConverter(lb=-60, ub=60)
Expand All @@ -20,24 +20,21 @@ class GALinearRegression(BaseEstimator, LinearRegression):
'''Linear Regression by GA
'''

estimated_params = ('coefs_', 'intercepts_')

@classmethod
def create_model(cls, *args, **kwargs):
return LinearRegression(*args, **kwargs)

def _fit(self, X, Y):
self.pop.ezolve(n_iter=self.max_iter)
model_ = self.pop.solution
self.coef_ = model_.coef_
self.intercept_ = model_.intercept_

@classmethod
def config(cls, X, Y, n_individuals=10, *args, **kwargs):

input_dim = X.shape[1]
assert np.ndim(Y) == 1, 'only support 1D array for `Y`'
# output_dim = Y.shape[1]

class MyIndividual(MixedIndividual):
params={'sigma':0.02}

element_class = FloatChromosome, _BinaryChromosome

def decode(self):
Expand All @@ -50,8 +47,7 @@ def _fitness(self):
model = self.decode()
return model.score(X, Y)

class MyPopulation(HOFPopulation):
element_class = MyIndividual
MyPopulation = StandardPopulation[MyIndividual]

pop = MyPopulation.random(n_individuals=n_individuals, size=(input_dim, 8))

Expand Down
12 changes: 4 additions & 8 deletions pyrimidine/learn/neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.neural_network import MLPRegressor

from .. import MixedIndividual, FloatChromosome, FloatMatrixChromosome
from ..population import HOFPopulation
from ..population import StandardPopulation
from ..learn import BaseEstimator


Expand All @@ -22,6 +22,8 @@ class GAANN(BaseEstimator, MLPRegressor):
max_iter = 100
n_layers = 3

estimated_params = ('coefs_', 'intercepts_')

@classmethod
def create_model(cls, *args, **kwargs):
# create MLPRegressor object
Expand Down Expand Up @@ -58,13 +60,7 @@ def decode(self):
model.n_layers_ = 3
return model

MyPopulation = HOFPopulation[MyIndividual]
MyPopulation = StandardPopulation[MyIndividual]

return MyPopulation.random(n_individuals=n_individuals, size=((input_dim, cls.hidden_dim), cls.hidden_dim, (cls.hidden_dim, output_dim), output_dim))

def _fit(self, X, Y):
self.pop.ezolve(n_iter=self.max_iter)
model_ = self.pop.solution
self.coefs_ = model_.coefs_
self.intercepts_ = model_.intercepts_

0 comments on commit 0c9df50

Please sign in to comment.