-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLearner.java
More file actions
32 lines (23 loc) · 1.06 KB
/
Copy pathLearner.java
File metadata and controls
32 lines (23 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
package eu.redzoo.ml;
import java.util.List;
public class Learner {

    /**
     * Performs one batch gradient-descent step for a linear regression hypothesis.
     *
     * <p>For every parameter index {@code j} the new value is
     * {@code theta[j] - alpha * (1/m) * sum_i (h(x_i) - y_i) * x_i[j]}, where {@code h} is
     * {@code targetFunction.apply}, {@code x_i} is the i-th feature vector of {@code dataset}
     * and {@code y_i} is the i-th entry of {@code labels}. The thetas array obtained from
     * {@code targetFunction} is only read; the updated values are returned in a new
     * {@link LinearRegressionFunction}.
     *
     * @param targetFunction the current hypothesis whose thetas are updated
     * @param dataset feature vectors; each must have at least as many entries as
     *     {@code targetFunction} has thetas (NOTE(review): presumably index 0 is a constant
     *     bias feature — confirm against LinearRegressionFunction)
     * @param labels expected output for each feature vector, index-aligned with {@code dataset}
     * @param alpha learning rate (step size) applied to the averaged gradient
     * @return a new function holding the updated thetas
     * @throws IllegalArgumentException if {@code dataset} is empty or its size differs from
     *     the size of {@code labels}
     */
    public static LinearRegressionFunction train(LinearRegressionFunction targetFunction, List<Double[]> dataset, List<Double> labels, double alpha) {
        int m = dataset.size();
        if (m == 0) {
            // The gradient is an average over the samples; with none, 1.0 / m is Infinity
            // and every resulting theta would silently become NaN.
            throw new IllegalArgumentException("dataset must not be empty");
        }
        if (labels.size() != m) {
            throw new IllegalArgumentException(
                    "dataset has " + m + " rows but labels has " + labels.size() + " entries");
        }

        double[] thetaVector = targetFunction.getThetas();

        // The prediction error of a sample does not depend on the theta index j, so evaluate
        // the hypothesis once per sample instead of once per (sample, theta) pair.
        Double[][] features = new Double[m][];
        double[] errors = new double[m];
        for (int i = 0; i < m; i++) {
            features[i] = dataset.get(i);
            errors[i] = targetFunction.apply(features[i]) - labels.get(i);
        }

        // compute the new theta of each element of the theta array
        double[] newThetaVector = new double[thetaVector.length];
        for (int j = 0; j < thetaVector.length; j++) {
            // summarize the error gap * feature, in the same sample order as before
            double sumErrors = 0;
            for (int i = 0; i < m; i++) {
                sumErrors += errors[i] * features[i][j];
            }
            double gradient = (1.0 / m) * sumErrors;
            newThetaVector[j] = thetaVector[j] - alpha * gradient;
        }
        return new LinearRegressionFunction(newThetaVector);
    }
}