-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmain.cpp
More file actions
85 lines (72 loc) · 2.9 KB
/
main.cpp
File metadata and controls
85 lines (72 loc) · 2.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "Espnet_TRT_Transformer_Encoder.h"
#include "Espnet_TRT_Transformer_Decoder.h"
#ifndef TEST
#include "plugin_registery.h"
#endif
// Builds the TensorRT encoder and decoder engines for an ESPnet Transformer
// ASR model. Options are given as "--key value" pairs on the command line;
// any option not supplied falls back to the default in `configure` below.
// Returns 0 on success, 1 on a command-line error.
int main(int argc, char** argv) {
#ifndef TEST
    // Register the custom TensorRT plugins both engines depend on.
    registryFinalSlicePlugin();
    registryLayerNormalizaitonPlugin();
    registryPositionWisePlugin();
    registrySelfAttentionPlugin();
    registrySrcAttentionPlugin();
#endif
    // Too few arguments: print usage. (Execution still continues with the
    // defaults below, matching the original behavior.)
    if (argc < 3) {
        std::cout
            << "--path [the transformer model weight path, Required!]" << std::endl
            << "--idim [input feature dimension, default 83]" << std::endl
            << "--n_Head [the number of head in attention, default 4]" << std::endl
            << "--odim [feature dimension, default 256]" << std::endl
            << "--feed_forward [feed forward dimension, default 2048]" << std::endl
            << "--nvocab [the size of vocabulary, Required!]" << std::endl
            << "--dtype [the data type used for computation {float/half}, default float]" << std::endl
            << "--concat_after [concat is used, default false]" << std::endl
            << "--normalize_before [default true]" << std::endl
            << "--encoder_layers [the number of attention in encoder, default 12]" << std::endl
            << "--decoder_layers [the number of attention in decoder, default 6]" << std::endl
            << "--batchsize [the max batchsize of decoder, default 16]" << std::endl
            << "--topk [the topk in each decoder step, default 16]" << std::endl
            // FIX: help claimed "default 500" but the actual default is 5000.
            << "--maxseql [the max sequence length of encoder, default 5000]" << std::endl
            << "--model_name [the output trt model name, default asr]" << std::endl;
    }
    // Option table: recognized flags and their default values.
    std::map<std::string, std::string> configure{
        {"--path","asr"},
        {"--idim","83"},
        {"--n_Head","4"},
        {"--odim","256"},
        {"--feed_forward","2048"},
        {"--nvocab","7244"},
        {"--dtype","float"},
        {"--concat_after","false"},
        {"--normalize_before","true"},
        {"--encoder_layers","12"},
        {"--decoder_layers","6"},
        {"--batchsize","16"},
        {"--topk","16"},
        {"--maxseql","5000"},
        {"--model_name","asr"}
    };
    // Parse "--key value" pairs; every option must be followed by a value.
    for (int i = 1; i < argc; i += 2) {
        if (i + 1 >= argc) {
            // FIX: the original read argv[i + 1] unconditionally, so a trailing
            // option without a value dereferenced argv[argc] (a null pointer).
            std::cerr << "Missing value for option " << argv[i] << "!" << std::endl;
            return 1;
        }
        if (configure.count(argv[i]) > 0) {
            configure[argv[i]] = argv[i + 1];
        }
        else {
            std::cerr << "Option is not supported!" << std::endl;
            return 1;  // FIX: was exit(0) — report failure with a nonzero status
        }
    }
    if (configure["--path"].empty()) {
        std::cerr << "The path to model weight is not specified!" << std::endl;
        std::cout << "To see more information by running program without option input!" << std::endl;
        return 1;  // FIX: was exit(0) on an error path
    }
    if (configure["--nvocab"].empty()) {
        std::cerr << "The size of vocabulary is not specified!" << std::endl;
        std::cout << "To see more information by running program without option input!" << std::endl;
        return 1;  // FIX: the original printed the error but fell through and built anyway
    }
    std::cout << "building encoder model ..." << std::endl;
    Espnet_TRT_Transformer_Encoder(configure);
    std::cout << "building decoder model ..." << std::endl;
    Espnet_TRT_Transformer_Decoder(configure);
    return 0;
}