This repository was archived by the owner on Jan 26, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 222
Expand file tree
/
Copy pathc_api.cpp
More file actions
115 lines (93 loc) · 3.61 KB
/
c_api.cpp
File metadata and controls
115 lines (93 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include "multiverso/c_api.h"
#include "multiverso/multiverso.h"
#include "multiverso/table/array_table.h"
#include "multiverso/table/matrix_table.h"
#include "multiverso/util/log.h"
#include "multiverso/updater/updater.h"
extern "C" {
void MV_Init(int* argc, char* argv[]) {
multiverso::MV_Init(argc, argv);
}
void MV_ShutDown(){
multiverso::MV_ShutDown();
}
void MV_Barrier(){
multiverso::MV_Barrier();
}
/// C-linkage wrapper around multiverso::MV_NumWorkers.
/// @return whatever multiverso::MV_NumWorkers reports, as an int.
int MV_NumWorkers() {
  const int num_workers = multiverso::MV_NumWorkers();
  return num_workers;
}
/// C-linkage wrapper around multiverso::MV_WorkerId.
/// @return whatever multiverso::MV_WorkerId reports, as an int.
int MV_WorkerId() {
  const int worker_id = multiverso::MV_WorkerId();
  return worker_id;
}
/// C-linkage wrapper around multiverso::MV_ServerId.
/// @return whatever multiverso::MV_ServerId reports, as an int.
int MV_ServerId() {
  const int server_id = multiverso::MV_ServerId();
  return server_id;
}
// Array Table
void MV_NewArrayTable(int size, TableHandler* out) {
*out = multiverso::MV_CreateTable(multiverso::ArrayTableOption<float>(size));
}
/// Reinterprets `handler` as a multiverso::ArrayWorker<float>* and calls
/// Get(data, size) on it.
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to ArrayWorker::Get.
/// @param size    element count forwarded to ArrayWorker::Get.
void MV_GetArrayTable(TableHandler handler, float* data, int size) {
  using FloatArrayWorker = multiverso::ArrayWorker<float>;
  auto* table = reinterpret_cast<FloatArrayWorker*>(handler);
  table->Get(data, size);
}
/// Reinterprets `handler` as a multiverso::ArrayWorker<float>* and calls
/// Add(data, size) on it.
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to ArrayWorker::Add.
/// @param size    element count forwarded to ArrayWorker::Add.
void MV_AddArrayTable(TableHandler handler, float* data, int size) {
  using FloatArrayWorker = multiverso::ArrayWorker<float>;
  auto* table = reinterpret_cast<FloatArrayWorker*>(handler);
  table->Add(data, size);
}
/// Like MV_AddArrayTable, but also passes a multiverso::AddOption built from
/// the current worker id and the supplied hyper-parameters.
/// @param handler opaque handle; reinterpreted as ArrayWorker<float>*.
/// @param data    buffer forwarded to ArrayWorker::Add.
/// @param size    element count forwarded to ArrayWorker::Add.
/// @param lr      stored via AddOption::set_learning_rate.
/// @param mom     stored via AddOption::set_momentum.
/// @param rho     stored via AddOption::set_rho.
/// @param lambda  stored via AddOption::set_lambda.
void MV_AddArrayTableOption(TableHandler handler, float* data, int size,
                            float lr, float mom, float rho, float lambda) {
  using FloatArrayWorker = multiverso::ArrayWorker<float>;
  auto* table = reinterpret_cast<FloatArrayWorker*>(handler);

  // The option lives on the stack; only its address is handed to Add().
  multiverso::AddOption add_opt;
  add_opt.set_worker_id(multiverso::MV_WorkerId());
  add_opt.set_learning_rate(lr);
  add_opt.set_momentum(mom);
  add_opt.set_rho(rho);
  add_opt.set_lambda(lambda);

  table->Add(data, size, &add_opt);
}
/// Reinterprets `handler` as a multiverso::ArrayWorker<float>* and calls
/// AddAsync(data, size) on it.
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to ArrayWorker::AddAsync.
/// @param size    element count forwarded to ArrayWorker::AddAsync.
void MV_AddAsyncArrayTable(TableHandler handler, float* data, int size) {
  using FloatArrayWorker = multiverso::ArrayWorker<float>;
  auto* table = reinterpret_cast<FloatArrayWorker*>(handler);
  table->AddAsync(data, size);
}
/// Like MV_AddAsyncArrayTable, but also passes a multiverso::AddOption built
/// from the current worker id and the supplied hyper-parameters.
/// @param handler opaque handle; reinterpreted as ArrayWorker<float>*.
/// @param data    buffer forwarded to ArrayWorker::AddAsync.
/// @param size    element count forwarded to ArrayWorker::AddAsync.
/// @param lr      stored via AddOption::set_learning_rate.
/// @param mom     stored via AddOption::set_momentum.
/// @param rho     stored via AddOption::set_rho.
/// @param lambda  stored via AddOption::set_lambda.
void MV_AddAsyncArrayTableOption(TableHandler handler, float* data, int size,
                                 float lr, float mom, float rho, float lambda) {
  using FloatArrayWorker = multiverso::ArrayWorker<float>;
  auto* table = reinterpret_cast<FloatArrayWorker*>(handler);

  // NOTE(review): the option is a stack local whose address is passed to an
  // *async* call — confirm AddAsync does not retain the pointer after return.
  multiverso::AddOption add_opt;
  add_opt.set_worker_id(multiverso::MV_WorkerId());
  add_opt.set_learning_rate(lr);
  add_opt.set_momentum(mom);
  add_opt.set_rho(rho);
  add_opt.set_lambda(lambda);

  table->AddAsync(data, size, &add_opt);
}
// MatrixTable
void MV_NewMatrixTable(int num_row, int num_col, TableHandler* out) {
*out = multiverso::MV_CreateTable(multiverso::MatrixTableOption<float>(num_row, num_col));
}
/// Reinterprets `handler` as a multiverso::MatrixWorkerTable<float>* and
/// calls its two-argument Get(data, size).
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to MatrixWorkerTable::Get.
/// @param size    element count forwarded to MatrixWorkerTable::Get.
void MV_GetMatrixTableAll(TableHandler handler, float* data, int size) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->Get(data, size);
}
/// Reinterprets `handler` as a multiverso::MatrixWorkerTable<float>* and
/// calls its two-argument Add(data, size).
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to MatrixWorkerTable::Add.
/// @param size    element count forwarded to MatrixWorkerTable::Add.
void MV_AddMatrixTableAll(TableHandler handler, float* data, int size) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->Add(data, size);
}
/// Reinterprets `handler` as a multiverso::MatrixWorkerTable<float>* and
/// calls its two-argument AddAsync(data, size).
/// @param handler opaque handle; cast without validation.
/// @param data    buffer forwarded to MatrixWorkerTable::AddAsync.
/// @param size    element count forwarded to MatrixWorkerTable::AddAsync.
void MV_AddAsyncMatrixTableAll(TableHandler handler, float* data, int size) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->AddAsync(data, size);
}
/// Row-selective variant of MV_GetMatrixTableAll: reinterprets `handler` as a
/// multiverso::MatrixWorkerTable<float>* and calls
/// Get(data, size, row_ids, row_ids_n).
/// @param handler   opaque handle; cast without validation.
/// @param data      buffer forwarded to MatrixWorkerTable::Get.
/// @param size      element count forwarded to MatrixWorkerTable::Get.
/// @param row_ids   array of row indices, forwarded unchanged.
/// @param row_ids_n number of entries in row_ids, forwarded unchanged.
void MV_GetMatrixTableByRows(TableHandler handler, float* data, int size,
                             int row_ids[], int row_ids_n) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->Get(data, size, row_ids, row_ids_n);
}
/// Row-selective variant of MV_AddMatrixTableAll: reinterprets `handler` as a
/// multiverso::MatrixWorkerTable<float>* and calls
/// Add(data, size, row_ids, row_ids_n).
/// @param handler   opaque handle; cast without validation.
/// @param data      buffer forwarded to MatrixWorkerTable::Add.
/// @param size      element count forwarded to MatrixWorkerTable::Add.
/// @param row_ids   array of row indices, forwarded unchanged.
/// @param row_ids_n number of entries in row_ids, forwarded unchanged.
void MV_AddMatrixTableByRows(TableHandler handler, float* data, int size,
                             int row_ids[], int row_ids_n) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->Add(data, size, row_ids, row_ids_n);
}
/// Row-selective variant of MV_AddAsyncMatrixTableAll: reinterprets `handler`
/// as a multiverso::MatrixWorkerTable<float>* and calls
/// AddAsync(data, size, row_ids, row_ids_n).
/// @param handler   opaque handle; cast without validation.
/// @param data      buffer forwarded to MatrixWorkerTable::AddAsync.
/// @param size      element count forwarded to MatrixWorkerTable::AddAsync.
/// @param row_ids   array of row indices, forwarded unchanged.
/// @param row_ids_n number of entries in row_ids, forwarded unchanged.
void MV_AddAsyncMatrixTableByRows(TableHandler handler, float* data, int size,
                                  int row_ids[], int row_ids_n) {
  using FloatMatrixWorker = multiverso::MatrixWorkerTable<float>;
  auto* table = reinterpret_cast<FloatMatrixWorker*>(handler);
  table->AddAsync(data, size, row_ids, row_ids_n);
}
}