Skip to content

Commit 9189fe7

Browse files
committed
op:注册op
1 parent 0e71199 commit 9189fe7

8 files changed

Lines changed: 94 additions & 126 deletions

File tree

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
#include "deepx/op/op.hpp"
1313
#include "deepx/op/elementwise.hpp"
1414
#include "deepx/op/reduce.hpp"
15-
1615
#include "deepx/mem/mem.hpp"
1716
#include "client/udpserver.hpp"
18-
#include "client/yml.hpp"
17+
1918
using namespace deepx::tensorfunc;
2019
using namespace deepx::mem;
2120

@@ -31,14 +30,18 @@ int main()
3130
mem.add("result", std::make_shared<deepx::Tensor<float>>(result));
3231

3332
client::udpserver server(8080);
34-
server.func = [&mem](char *buffer)
33+
deepx::op::OpFactory opfactory;
34+
server.func = [&mem, &opfactory](char *buffer)
3535
{
36-
// auto op = client::parse(buffer);
37-
38-
// op->forward(mem);
39-
40-
// print(*mem.gettensor<float>("result"));
41-
36+
deepx::op::Op op;
37+
op.load(buffer);
38+
op.forward(mem);
39+
40+
shared_ptr<deepx::op::Op> opsrc = opfactory.get_op(op);
41+
42+
(*opsrc).init(op.name, op.dtype, op.args, op.returns, op.require_grad, op.args_grad, op.returns_grad);
43+
(*opsrc).forward(mem);
44+
print(*mem.gettensor<float>("result"));
4245
};
4346
server.start();
4447
return 0;

excuter/op-mem-ompsimd/src/client/yml.cpp

Lines changed: 0 additions & 34 deletions
This file was deleted.

excuter/op-mem-ompsimd/src/client/yml.hpp

Lines changed: 0 additions & 15 deletions
This file was deleted.

excuter/op-mem-ompsimd/src/deepx/op/concat.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ namespace deepx::op
1111
{
1212
public:
1313
Concat(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
14-
this->init("concat", args, returns, require_grad, args_grad, returns_grad);
14+
this->init("concat",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
1515
}
1616
Concat(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
17-
this->init("concat", args, returns, require_grad, args_grad, returns_grad);
17+
this->init("concat",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
1818
}
1919
void forward(mem::Mem<T> &mem) override
2020
{
@@ -38,5 +38,8 @@ namespace deepx::op
3838
cpu::split(output,axis,input);
3939
};
4040
};
41+
// 注册concat算子
42+
auto concat_float = OpFactory::add_op<Concat<float>>("concat");
43+
auto concat_double = OpFactory::add_op<Concat<double>>("concat");
4144
}
4245
#endif // DEEPX_OP_CONCAT_HPP

excuter/op-mem-ompsimd/src/deepx/op/elementwise.hpp

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ namespace deepx::op
1717
class Add : public OpT<T>
1818
{
1919
public:
20+
Add()=default;
2021
Add(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
21-
this->init("add"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
22+
this->init("add",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
2223
}
2324
Add(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
24-
this->init("add"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
25+
this->init("add",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
2526
}
2627
void forward(mem::Mem &mem) override
2728
{
@@ -43,15 +44,17 @@ namespace deepx::op
4344
deepx::tensorfunc::add(*b_grad, *c_grad, *b_grad); // b_grad += c_grad
4445
}
4546
};
47+
// 注册add算子
48+
4649
template <typename T>
4750
class Add_scalar : public OpT<T>
4851
{
4952
public:
5053
Add_scalar(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
51-
this->init("add_scalar"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
54+
this->init("add_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
5255
}
5356
Add_scalar(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
54-
this->init("add_scalar"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
57+
this->init("add_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
5558
}
5659
//已验证,2025-02-19,lipeng
5760
void forward(mem::Mem &mem) override
@@ -78,10 +81,10 @@ namespace deepx::op
7881
{
7982
public:
8083
Sub(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
81-
this->init("sub"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
84+
this->init("sub",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
8285
}
8386
Sub(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
84-
this->init("sub"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
87+
this->init("sub",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
8588
}
8689
void forward(mem::Mem &mem) override
8790
{
@@ -109,10 +112,10 @@ namespace deepx::op
109112
{
110113
public:
111114
Mul(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
112-
this->init("mul"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
115+
this->init("mul",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
113116
}
114117
Mul(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
115-
this->init("mul"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
118+
this->init("mul",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
116119
}
117120
void forward(mem::Mem &mem) override
118121
{
@@ -145,10 +148,10 @@ namespace deepx::op
145148
{
146149
public:
147150
Mul_scalar(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
148-
this->init("mul_scalar"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
151+
this->init("mul_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
149152
}
150153
Mul_scalar(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
151-
this->init("mul_scalar"+dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
154+
this->init("mul_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
152155
}
153156
//已验证,2025-02-19,lipeng
154157
void forward(mem::Mem &mem) override
@@ -179,10 +182,10 @@ namespace deepx::op
179182
{
180183
public:
181184
Div(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
182-
this->init("div", args, returns, require_grad, args_grad, returns_grad);
185+
this->init("div",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
183186
}
184187
Div(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
185-
this->init("div", args, returns, require_grad, args_grad, returns_grad);
188+
this->init("div",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
186189
}
187190
void forward(mem::Mem &mem) override
188191
{
@@ -221,10 +224,10 @@ namespace deepx::op
221224
{
222225
public:
223226
Div_scalar(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
224-
this->init("div_scalar", args, returns, require_grad, args_grad, returns_grad);
227+
this->init("div_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
225228
}
226229
Div_scalar(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
227-
this->init("div_scalar", args, returns, require_grad, args_grad, returns_grad);
230+
this->init("div_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
228231
}
229232
//已验证,2025-02-19,lipeng
230233
void forward(mem::Mem &mem) override
@@ -254,10 +257,10 @@ namespace deepx::op
254257
class Sqrt : public OpT<T>{
255258
public:
256259
Sqrt(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
257-
this->init("sqrt", args, returns, require_grad, args_grad, returns_grad);
260+
this->init("sqrt",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
258261
}
259262
Sqrt(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
260-
this->init("sqrt", args, returns, require_grad, args_grad, returns_grad);
263+
this->init("sqrt",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
261264
}
262265
void forward(mem::Mem &mem) override
263266
{
@@ -284,10 +287,10 @@ namespace deepx::op
284287
{
285288
public:
286289
Exp(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
287-
this->init("exp", args, returns, require_grad, args_grad, returns_grad);
290+
this->init("exp",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
288291
}
289292
Exp(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
290-
this->init("exp", args, returns, require_grad, args_grad, returns_grad);
293+
this->init("exp",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
291294
}
292295
void forward(mem::Mem &mem) override
293296
{
@@ -315,10 +318,10 @@ namespace deepx::op
315318
{
316319
public:
317320
Pow(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
318-
this->init("pow", args, returns, require_grad, args_grad, returns_grad);
321+
this->init("pow",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
319322
}
320323
Pow(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
321-
this->init("pow", args, returns, require_grad, args_grad, returns_grad);
324+
this->init("pow",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
322325
}
323326
void forward(mem::Mem &mem) override
324327
{
@@ -361,10 +364,10 @@ namespace deepx::op
361364
{
362365
public:
363366
Pow_scalar(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
364-
this->init("pow_scalar", args, returns, require_grad, args_grad, returns_grad);
367+
this->init("pow_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
365368
}
366369
Pow_scalar(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
367-
this->init("pow_scalar", args, returns, require_grad, args_grad, returns_grad);
370+
this->init("pow_scalar",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
368371
}
369372
void forward(mem::Mem &mem) override
370373
{
@@ -399,10 +402,10 @@ namespace deepx::op
399402
{
400403
public:
401404
Log(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
402-
this->init("log", args, returns, require_grad, args_grad, returns_grad);
405+
this->init("log",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
403406
}
404407
Log(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
405-
this->init("log", args, returns, require_grad, args_grad, returns_grad);
408+
this->init("log",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
406409
}
407410
void forward(mem::Mem &mem) override
408411
{

excuter/op-mem-ompsimd/src/deepx/op/matmul.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ namespace deepx::op
1919
{
2020
public:
2121
MatMul(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
22-
this->init("matmul", args, returns, require_grad, args_grad, returns_grad);
22+
this->init("matmul",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
2323
}
2424
MatMul(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
25-
this->init("matmul", args, returns, require_grad, args_grad, returns_grad);
25+
this->init("matmul",dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
2626
}
2727
void forward(mem::Mem &mem) override
2828
{

0 commit comments

Comments
 (0)