-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLogisticRegression.java
39 lines (31 loc) · 1.02 KB
/
LogisticRegression.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package Datascience;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
/**
 * Binary logistic regression over Colt matrices, fitted with batch gradient descent.
 *
 * <p>{@code X} holds one training example per row, {@code y} holds one label per row of
 * {@code X} (expected to be 0 or 1), and {@code theta} holds one weight per column of
 * {@code X}. No intercept term is added automatically; include a column of ones in
 * {@code X} if a bias is wanted.
 */
public class LogisticRegression {
    /** Design matrix: one row per example, one column per feature. */
    public DoubleMatrix2D X;
    /** Target labels, one per row of {@link #X}; expected to be 0 or 1. */
    public DoubleMatrix1D y;
    /** Model parameters, one per column of {@link #X}; updated in place by {@link #train}. */
    public DoubleMatrix1D theta;

    /**
     * Creates a model over the given data and initial parameters. The arguments are stored by
     * reference (not copied), so {@link #train} mutates the caller's {@code theta} vector.
     *
     * @param x design matrix, one row per example
     * @param y label vector, one entry per row of {@code x}
     * @param theta initial parameter vector, one entry per column of {@code x}
     */
    public LogisticRegression(DoubleMatrix2D x, DoubleMatrix1D y, DoubleMatrix1D theta) {
        this.X = x;
        this.y = y;
        this.theta = theta;
    }

    /**
     * Fits {@link #theta} by full-batch gradient descent on the mean logistic (cross-entropy)
     * loss. Each epoch applies {@code theta -= learningRate * (1/m) * X^T (sigmoid(X theta) - y)},
     * where {@code m} is the number of rows of {@link #X}.
     *
     * @param epoch number of gradient-descent iterations; 0 leaves {@code theta} unchanged
     * @param learningRate step size applied to the averaged gradient
     * @throws IllegalArgumentException if {@code epoch} is negative
     * @throws IllegalStateException if there are no examples or the dimensions of
     *     {@code X}, {@code y} and {@code theta} do not agree
     */
    public void train(int epoch, double learningRate) {
        if (epoch < 0) {
            throw new IllegalArgumentException("epoch must be non-negative: " + epoch);
        }
        int m = X.rows();
        if (m == 0) {
            // Averaging the gradient over zero examples would divide by zero.
            throw new IllegalStateException("no training examples");
        }
        if (y.size() != m) {
            throw new IllegalStateException(
                    "y has " + y.size() + " entries but X has " + m + " rows");
        }
        if (theta.size() != X.columns()) {
            throw new IllegalStateException(
                    "theta has " + theta.size() + " entries but X has " + X.columns() + " columns");
        }

        Algebra algebra = new Algebra();
        // X is not modified during training, so its transpose view can be built once.
        DoubleMatrix2D xTransposed = algebra.transpose(X);

        for (int e = 0; e < epoch; e++) {
            // algebra.mult returns a fresh vector, so it is safe to overwrite it in place
            // with the residuals sigmoid(x_i . theta) - y_i.
            DoubleMatrix1D residual = algebra.mult(X, theta);
            for (int i = 0; i < m; i++) {
                residual.set(i, sigmoid(residual.get(i)) - y.get(i));
            }
            DoubleMatrix1D gradient = algebra.mult(xTransposed, residual);
            for (int j = 0; j < theta.size(); j++) {
                theta.set(j, theta.get(j) - learningRate * gradient.get(j) / m);
            }
        }
    }

    /**
     * Logistic function {@code 1 / (1 + e^-z)}. Very negative inputs make {@code Math.exp}
     * overflow to infinity, which yields 0.0 rather than NaN.
     *
     * @param z input value
     * @return value in the closed interval [0, 1]
     */
    public double sigmoid(double z) {
        return 1. / (1. + Math.exp(-z));
    }

    /** Small demo: prints the number of non-zero cells of X, trains, then prints theta. */
    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {1, 2, 3}, {1, 2, 3}};
        // Logistic regression targets are class labels in {0, 1}.
        double[] y = {0, 1, 1};
        double[] theta = {1, 1, 1};
        LogisticRegression lg = new LogisticRegression(
                new DenseDoubleMatrix2D(x), new DenseDoubleMatrix1D(y), new DenseDoubleMatrix1D(theta));
        System.out.println(lg.X.cardinality());
        lg.train(1000, 0.1);
        System.out.println(lg.theta);
    }
}