// lmtest.cpp -- trains and evaluates a log-bilinear n-gram language model (nnet::lblm).
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include <boost/fusion/include/for_each.hpp>
#include <boost/mpl/for_each.hpp>
#include <boost/mpl/range_c.hpp>
#include "logbilinear_lm.h"
#include "nnopt.h"
#undef _GNU_SOURCE
#define _GNU_SOURCE
#include <fenv.h>
/**
 * Loads the corpus in `file` using the fixed vocabulary `vocmap`.
 *
 * Convenience overload that forwards to the pointer-taking overload with the
 * address of `vocmap`.
 */
template<int Order,class FF>
typename nnet::lblm<Order,FF>::dataset lblm_load_data(const char *file, const typename nnet::lblm<Order,FF>::dataset::vocmap_type &vocmap);

/**
 * Loads the corpus in `file` into a dataset for an order-`Order` model.
 *
 * @param file   Path of the corpus file to read.
 * @param vocmap Vocabulary to reuse, or nullptr (the default) to build a new
 *               vocabulary from the words found in the file. With a non-null
 *               vocabulary, words that are not in it are mapped to index 1.
 * @return The dataset holding the vocabulary, the inputs and the targets.
 */
template<int Order,class FF>
typename nnet::lblm<Order,FF>::dataset lblm_load_data(const char *file, const typename nnet::lblm<Order,FF>::dataset::vocmap_type *vocmap = nullptr);
// Reference-taking overload: lets callers pass an existing vocabulary object
// directly and delegates the actual loading to the pointer-taking overload.
template<int Order,class FF>
typename nnet::lblm<Order,FF>::dataset lblm_load_data(const char *file, const typename nnet::lblm<Order,FF>::dataset::vocmap_type &vocmap) {
	typedef typename nnet::lblm<Order,FF>::dataset::vocmap_type vocmap_type;
	const vocmap_type *const fixed_vocabulary = &vocmap;
	return lblm_load_data<Order,FF>(file, fixed_vocabulary);
}
namespace {

/**
 * Function object for boost::mpl::for_each that writes the context words of a
 * single n-gram into the input part of a dataset.
 *
 * It is invoked once per integral constant T of the iterated range. Each call
 * stores the word index currently pointed to by `p_` in row `row_` of input
 * slot `Order - T::value - 1`, then steps the pointer one element back.
 * The pointer is not moved past an element equal to 0 (the index the loader in
 * this file uses for the sentence boundary), so once a boundary is reached all
 * remaining slots receive 0 as well and the context never extends into the
 * previous sentence.
 *
 * @tparam Order      Number of context slots.
 * @tparam OutputType Type of the input container; may be a reference type.
 * @tparam Idx        Word index type.
 */
template<int Order,class OutputType,class Idx>
struct process_ngram {
	typedef typename std::remove_reference<OutputType>::type output_type;

	output_type &out_;   // container whose slots are filled
	std::size_t row_;    // row (training example) being written
	Idx *p_;             // current context word; walks backwards through the corpus

	process_ngram(output_type &out, std::size_t row, Idx *p) : out_(out), row_(row), p_(p) {}

	// Sparse one-hot representation: put a 1 in column `p` of row `row`.
	// insert() is a member of the concrete sparse matrix type, not of
	// SparseMatrixBase, so it has to be reached through derived().
	template<class Derived>
	void set_row(Eigen::SparseMatrixBase<Derived> &sparse, std::size_t row, Idx p) const {
		sparse.derived().insert(row, p) = 1;
	}

	// Index-vector representation: store the word index itself.
	void set_row(nnet::vocidx_vector &idxmat, std::size_t row, Idx p) const {
		idxmat(row) = p;
	}

	// Only the type of the argument is used; its value is irrelevant.
	template<class T>
	void operator()(T) {
		set_row(out_.template at<Order - T::value - 1>(), row_, *p_);
		// Stay on a boundary (index 0) instead of crossing into the previous sentence.
		if(*p_ != 0)
			p_--;
	}
};

}
/**
 * Reads a text corpus from `file` and turns it into a dataset for an
 * order-`Order` model.
 *
 * Every line of the file is one sentence; its words are separated by spaces.
 * The words are mapped to vocabulary indices and appended to a single index
 * sequence in which each sentence is preceded and followed by the sentence
 * boundary index 0. Every element after the initial boundary yields one
 * example: a one-hot row in the sparse target matrix, and the preceding words
 * as inputs (filled in by process_ngram).
 *
 * @param file   Path of the corpus file.
 * @param vocmap If non-null, this vocabulary is copied into the result and any
 *               word not contained in it is mapped to index 1. If null, a new
 *               vocabulary is built: it starts with "</s>" (0) and "<unk>" (1),
 *               and every previously unseen word receives the next free index.
 * @return The dataset holding the vocabulary, the inputs and the targets.
 *
 * If the file cannot be opened, a message is written to std::cerr and the
 * process is terminated with exit status 1.
 */
template<int Order,class FF>
typename nnet::lblm<Order,FF>::dataset lblm_load_data(const char *file, const typename nnet::lblm<Order,FF>::dataset::vocmap_type *vocmap) {
	typedef typename nnet::vocidx_type idx;

	const idx SENTENCE_BOUNDARY = 0;
	const idx UNKNOWN_WORD = 1;

	typename nnet::lblm<Order,FF>::dataset out;

	// Number of vocabulary entries; doubles as the next free index while a new
	// vocabulary is being built.
	idx vocsize = 2;
	if(vocmap != nullptr) {
		out.vocmap() = *vocmap;
		vocsize = out.vocmap().size();
	} else {
		out.vocmap().insert(std::make_pair("</s>", SENTENCE_BOUNDARY));
		out.vocmap().insert(std::make_pair("<unk>", UNKNOWN_WORD));
	}

	std::ifstream is(file);
	if(!is) {
		std::cerr << "Problem reading file: " << file << std::endl;
		std::exit(1);
	}

	std::vector<idx> corpus;
	corpus.push_back(SENTENCE_BOUNDARY);
	for(std::string line; std::getline(is, line);) {
		std::istringstream ts(line);
		for(std::string token; std::getline(ts, token, ' ');) {
			// Leading or repeated spaces produce empty tokens; they are not
			// words and must not end up in the vocabulary or the corpus.
			if(token.empty())
				continue;

			auto it = out.vocmap().find(token);
			idx tokidx;
			if(it == out.vocmap().end()) {
				if(vocmap != nullptr)
					tokidx = UNKNOWN_WORD;   // fixed vocabulary: out-of-vocabulary word
				else {
					tokidx = vocsize++;
					out.vocmap().insert(std::make_pair(token, tokidx));
				}
			} else
				tokidx = it->second;
			corpus.push_back(tokidx);
		}
		corpus.push_back(SENTENCE_BOUNDARY);
	}

	// One example per corpus element except the initial boundary.
	auto setup_matrix = [&] (nnet::vocidx_vector &mat) {
		mat.resize(corpus.size() - 1);
	};
	boost::fusion::for_each(out.inputs().sequence(), setup_matrix);

	out.targets().matrix().resize(corpus.size() - 1, vocsize);
	out.targets().matrix().reserve(Eigen::VectorXi::Constant(corpus.size() - 1, 1));

	for(std::size_t i = 1; i < corpus.size(); i++) { // the first element is just a boundary
		out.targets().matrix().insert(i - 1, corpus[i]) = 1;
		// The context of target i starts at the word immediately before it.
		boost::mpl::for_each<boost::mpl::range_c<int,0,Order> >
			(process_ngram<Order,decltype(out.inputs()),idx>(out.inputs(), i - 1, &corpus[i-1]));
	}

	return out;
}
int main(int argc, char **argv) {
feenableexcept(FE_INVALID | FE_DIVBYZERO);
std::string suffix = "cl";
if(argc == 2)
suffix = argv[1];
const int ngram_order = 3;
typedef nnet::lblm<ngram_order,float> net_type;
net_type::dataset trainset = lblm_load_data<ngram_order,float>((std::string("train.") + suffix).c_str());
net_type::dataset valset = lblm_load_data<ngram_order,float>((std::string("val.") + suffix).c_str(), trainset.vocmap());
net_type::dataset testset = lblm_load_data<ngram_order,float>((std::string("test.") + suffix).c_str(), trainset.vocmap());
net_type net(trainset.vocmap().size(), 80);
nnet::crossentropy_loss loss;
nnet::nnopt<net_type> opt(net);
nnet::nnopt_results<net_type> res = opt.train(net, loss, trainset, valset);
std::cout << "Training energy: ";
std::copy(res.trainerr.begin(), res.trainerr.end(), std::ostream_iterator<net_type::float_type>(std::cout, " "));
std::cout << "\nValidation energy: ";
std::copy(res.valerr.begin(), res.valerr.end(), std::ostream_iterator<net_type::float_type>(std::cout, " "));
std::cout << std::endl;
const auto &testout = net(res.best_weights, testset.inputs());
std::cout << "Test energy: " << evaluate_loss(loss, testout, testset.targets()) << '\n';
std::cout << "BEST WEIGHTS:\n" << res.best_weights << '\n';
return 0;
}