Skip to content

Commit

Permalink
learn
Browse files Browse the repository at this point in the history
  • Loading branch information
Freakwill committed Dec 8, 2023
1 parent 877d294 commit 9709b9f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
58 changes: 58 additions & 0 deletions pyrimidine/learn/linear_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3


from pyrimidine import *
from digit_converter import IntervalConverter

c = IntervalConverter(lb=-60, ub=60)
class _BinaryChromosome(BinaryChromosome):
def decode(self):
return c(self)


import numpy as np
import numpy.linalg as LA
from sklearn.linear_model import LinearRegression
from ..learn import BaseEstimator


class GALinearRegression(BaseEstimator, LinearRegression):
'''Linear Regression by GA
'''

@classmethod
def create_model(cls, *args, **kwargs):
return LinearRegression(*args, **kwargs)

def _fit(self, X, Y):
self.pop.ezolve(n_iter=self.max_iter)
model_ = self.pop.solution
self.coef_ = model_.coef_
self.intercept_ = model_.intercept_

@classmethod
def config(cls, X, Y, n_individuals=10, *args, **kwargs):

input_dim = X.shape[1]
# output_dim = Y.shape[1]

class MyIndividual(MixedIndividual):
params={'sigma':0.02}
element_class = FloatChromosome, _BinaryChromosome

def decode(self):
model = cls.create_model(*args, **kwargs)
model.coef_ = np.asarray(self[0])
model.intercept_ = self[1].decode()
return model

def _fitness(self):
model = self.decode()
return model.score(X, Y)

class MyPopulation(HOFPopulation):
element_class = MyIndividual

pop = MyPopulation.random(n_individuals=n_individuals, size=(input_dim, 8))

return pop
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[pytest]
python_files = tests/test_*.py
python_files = tests/test_*.py
1 change: 1 addition & 0 deletions tests/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_ann():
s1 = model.score(X, Y)
assert s0 <= s1


def test_lr():
X = np.array([[0,0], [0,1], [1,0], [1,1]])
Y = np.array([2,1,0,0])
Expand Down

0 comments on commit 9709b9f

Please sign in to comment.