-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmalevich_classifier.cpp
More file actions
44 lines (40 loc) · 1.3 KB
/
malevich_classifier.cpp
File metadata and controls
44 lines (40 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include "malevich_classifier.h"

#include <cassert>
/// Creates a classifier that has not been trained yet.
///
/// @param clusterCount Stored in clusters_ and later passed to
///                     ExpectationMaximization by Learn() (presumably the
///                     number of mixture components fitted per class).
MalevichClassifier::MalevichClassifier(int clusterCount)
    : clusters_(clusterCount) {}
/// Trains the classifier on a labelled dataset.
///
/// First estimates prior_[c] as the fraction of samples whose class_label is
/// c. Then, for every class index c in [0, prior_.size()), it stores in
/// prob_[c] the result of ExpectationMaximization run on the feature vectors
/// of that class.
///
/// @param dataset Training samples. Every class_label must be non-negative:
///                labels are used directly as indices into prior_, prob_ and
///                the per-class buckets.
void MalevichClassifier::Learn(const Dataset& dataset) {
  // Threshold handed to ExpectationMaximization. NOTE(review): presumably its
  // convergence tolerance -- confirm against the EM declaration.
  const double kEmTolerance = 1e-12;

  prior_.clear();
  // Count samples per class. Labels index prior_ directly, so the vector is
  // grown (with zero-initialised entries) to cover the largest label seen.
  for (size_t i = 0; i < dataset.size(); ++i) {
    const int label = dataset[i].class_label;
    // A negative label would index prior_ out of bounds (undefined behaviour).
    assert(label >= 0 && "class labels must be non-negative");
    if (label >= static_cast<int>(prior_.size())) {
      prior_.resize(label + 1);
    }
    ++prior_[label];
  }
  // Turn the raw counts into relative frequencies.
  for (size_t i = 0; i < prior_.size(); ++i) {
    prior_[i] /= dataset.size();
  }

  // Bucket feature vectors by class for the per-class EM fit.
  vector< vector< vector<Feature> > > data(prior_.size());
  for (size_t i = 0; i < dataset.size(); ++i) {
    data[dataset[i].class_label].push_back(dataset[i].features);
  }

  // NOTE(review): if the labels are not contiguous (e.g. only 0 and 2 occur),
  // data[1] is empty and EM is run on no samples; confirm that
  // ExpectationMaximization tolerates an empty input.
  prob_.resize(prior_.size());
  for (size_t i = 0; i < prior_.size(); ++i) {
    prob_[i] = ExpectationMaximization(data[i], kEmTolerance, clusters_);
  }
}
/// Labels every sample in @p dataset in place.
///
/// Each sample gets the class index c that maximises
/// prior_[c] * prob_[c].GetProbability(features). Misclassification costs are
/// assumed equal, so no loss weighting is applied. Ties go to the lowest
/// class index. If prob_ is empty, or no score exceeds -1 (e.g. every score
/// is NaN), the sample is labelled 0.
///
/// @param dataset Samples to relabel; must not be null.
void MalevichClassifier::Classify(Dataset* dataset) {
  const size_t sampleCount = dataset->size();
  const size_t classCount = prob_.size();
  for (size_t s = 0; s < sampleCount; ++s) {
    int winner = 0;
    long double bestScore = -1;
    for (size_t c = 0; c < classCount; ++c) {
      // Equal costs lambda: the decision reduces to prior times the
      // per-class model's GetProbability value.
      const long double score =
          prior_[c] * prob_[c].GetProbability(dataset->at(s).features);
      // Strict comparison keeps the first (lowest-index) maximum.
      if (score > bestScore) {
        bestScore = score;
        winner = static_cast<int>(c);
      }
    }
    dataset->at(s).class_label = winner;
  }
}