From 9709b9ff033daf3db51a5626fa73a977613ae82e Mon Sep 17 00:00:00 2001 From: William Song <30965609+Freakwill@users.noreply.github.com> Date: Fri, 8 Dec 2023 18:11:39 +0800 Subject: [PATCH] learn --- pyrimidine/learn/linear_regression.py | 58 +++++++++++++++++++++++++++ pytest.ini | 2 +- tests/test_learn.py | 1 + 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100755 pyrimidine/learn/linear_regression.py diff --git a/pyrimidine/learn/linear_regression.py b/pyrimidine/learn/linear_regression.py new file mode 100755 index 0000000..f921711 --- /dev/null +++ b/pyrimidine/learn/linear_regression.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + + +from pyrimidine import * +from digit_converter import IntervalConverter + +c = IntervalConverter(lb=-60, ub=60) +class _BinaryChromosome(BinaryChromosome): + def decode(self): + return c(self) + + +import numpy as np +import numpy.linalg as LA +from sklearn.linear_model import LinearRegression +from ..learn import BaseEstimator + + +class GALinearRegression(BaseEstimator, LinearRegression): + '''Linear Regression by GA + ''' + + @classmethod + def create_model(cls, *args, **kwargs): + return LinearRegression(*args, **kwargs) + + def _fit(self, X, Y): + self.pop.ezolve(n_iter=self.max_iter) + model_ = self.pop.solution + self.coef_ = model_.coef_ + self.intercept_ = model_.intercept_ + + @classmethod + def config(cls, X, Y, n_individuals=10, *args, **kwargs): + + input_dim = X.shape[1] + # output_dim = Y.shape[1] + + class MyIndividual(MixedIndividual): + params={'sigma':0.02} + element_class = FloatChromosome, _BinaryChromosome + + def decode(self): + model = cls.create_model(*args, **kwargs) + model.coef_ = np.asarray(self[0]) + model.intercept_ = self[1].decode() + return model + + def _fitness(self): + model = self.decode() + return model.score(X, Y) + + class MyPopulation(HOFPopulation): + element_class = MyIndividual + + pop = MyPopulation.random(n_individuals=n_individuals, size=(input_dim, 8)) + + return pop diff --git a/pytest.ini b/pytest.ini index fcdc3ba..b907caa 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -python_files = tests/test_*.py +python_files = tests/test_*.py \ No newline at end of file diff --git a/tests/test_learn.py b/tests/test_learn.py index db1ff14..4b42c15 100644 --- a/tests/test_learn.py +++ b/tests/test_learn.py @@ -19,6 +19,7 @@ def test_ann(): s1 = model.score(X, Y) assert s0 <= s1 + def test_lr(): X = np.array([[0,0], [0,1], [1,0], [1,1]]) Y = np.array([2,1,0,0])