
Commit a6dca37

opfamily: support multiple implementations per op
1 parent e241495 commit a6dca37

27 files changed

Lines changed: 2932 additions & 2387 deletions

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 35 additions & 33 deletions
@@ -2,36 +2,38 @@
 
 This page is generated by `excuter/op-mem-ompsimd/src/deepx/op/opfactory.hpp`; do not edit it by hand.
 
-| Operation | Data Types | Math Formula | IR Instruction |
-|-----------|------------|--------------|----------------|
-| add_scalar | float32, float64 | T2 = T1 + 1.0 | add_scalar@float32 T1 1.0 -> T2 |
-| rdiv_scalar | float32, float64 | T3 =1 / T2 | rdiv_scalar@float32 1 T2 -> T3 |
-| constant | float32, float64 | T1 = full(shape, 0.0) | constant@float32 0.0 -> T1 |
-| uniform | float32, float64 | uniform(-1.0, 1.0,T1) | uniform@float32 -1.0 1.0 -> T1 |
-| mul_scalar | float32, float64 | T2 = T1 * 2.0 | mul_scalar@float32 T1 2.0 -> T2 |
-| deltensor | any | del T1 | deltensor@any T1 -> |
-| sub | float32, float64 | T3 = T1 - T2 | sub@int32 T1 T2 -> T3 |
-| sum | float32, float64 | T2 = sum(T1, dims=[1,2]) | sum@float32 T1 1 2 -> T2 |
-| argset | float32, float64, int32 | shape = [3, 4, 5] | argset@int32 3 4 5 -> shape |
-| arange | float32, float64 | arange(start=0.0, step=1.0,T1) | arange@float32 0.0 1.0 -> T1 |
-| add | float32, float64 | T3 = T1 + T2 | add@int32 T1 T2 -> T3 |
-| copytensor | float32, float64, int16, int32, int64, int8 | T2 = T1.copy() | copytensor@float32 T1 -> T2 |
-| min | float32, float64 | C = min(A,B) | min@float32 A B -> C |
-| print | any | | print@any -> |
-| newtensor | float32, float64, int16, int32, int64, int8 | T1 = zeros(shape) | newtensor@float32 shape -> T1 |
-| div | float32, float64 | T3 = T1 / T2 | div@float32 T1 T2 -> T3 |
-| div_scalar | float32, float64 | T2 = T1 / 2.0 | div_scalar@float32 T1 2.0 -> T2 |
-| reshape | any | T2 = reshape(T1, [2,3,4]) | reshape@float32 T1 2 3 4 -> T2 |
-| min_scalar | float32, float64 | B= min(A, 1.0) | min_scalar@float32 A 1.0 -> B |
-| sqrt | float32, float64 | T2 = sqrt(T1) | sqrt@float32 T1 -> T2 |
-| mul | float32, float64 | T3 = T1 * T2 | mul@float32 T1 T2 -> T3 |
-| exp | float32, float64 | T2 = exp(T1) | exp@float32 T1 -> T2 |
-| max_scalar | float32, float64 | T2 = max(T1, 0.0) | max_scalar@float32 T1 0.0 -> T2 |
-| max | float32, float64 | T3 = max(T1,T2) | max@float32 T1 -> T2 |
-| pow | float32, float64 | T3 = T1 ^ T2 | pow@float32 T1 T2 -> T3 |
-| pow_scalar | float32, float64 | T2 = T1 ^ 2.0 | pow_scalar@float32 T1 2.0 -> T2 |
-| matmul | float32, float64 | T3 = T1 @ T2 | matmul@float32 T1 T2 -> T3 |
-| clonetensor | float32, float64, int16, int32, int64, int8 | T2 = T1.clone() | clonetensor@float32 T1 -> T2 |
-| transpose | any | T2 = transpose(T1, dimorder=[1,0]) | transpose@float32 T1 1 0 -> T2 |
-| expand | any | T2 = expand(T1, axis=[4,6,12]) | expand@float32 T1 4 6 12 -> T2 |
-| concat | float32 | T3 = concat([T1, T2], axis=3) | concat@float32 T1 T2 3 -> T3 |
+| Operation | Author | Data Types | Math Formula | IR Instruction |
+|-----------|--------|------------|--------------|----------------|
+| divscalar | miaobyte | float32, float64 | T2 = T1 / 2.0 | divscalar@float32 T1 2.0 -> T2 |
+| addscalar | miaobyte | float32, float64 | T2 = T1 + 1.0 | addscalar@float32 T1 1.0 -> T2 |
+| uniform | | float32, float64 | uniform(-1.0, 1.0,T1) | uniform@float32 -1.0 1.0 -> T1 |
+| deltensor | | any | del T1 | deltensor@any T1 -> |
+| minscalar | | float32, float64 | B= min(A, 1.0) | minscalar@float32 A 1.0 -> B |
+| rdivscalar | miaobyte | float32, float64 | T3 =1 / T2 | rdivscalar@float32 1 T2 -> T3 |
+| constant | | float32, float64 | T1 = full(shape, 0.0) | constant@float32 0.0 -> T1 |
+| powscalar | miaobyte | float32, float64 | T2 = T1 ^ 2.0 | powscalar@float32 T1 2.0 -> T2 |
+| sub | cblas | float32, float64 | T3 = T1 - T2 | sub@int32 T1 T2 -> T3 |
+| sub | miaobyte | float32, float64 | T3 = T1 - T2 | sub@int32 T1 T2 -> T3 |
+| sum | | float32, float64 | T2 = sum(T1, dims=[1,2]) | sum@float32 T1 1 2 -> T2 |
+| argset | | float32, float64, int32 | shape = [3, 4, 5] | argset@int32 3 4 5 -> shape |
+| arange | | float32, float64 | arange(start=0.0, step=1.0,T1) | arange@float32 0.0 1.0 -> T1 |
+| transpose | | any | T2 = transpose(T1, dimorder=[1,0]) | transpose@float32 T1 1 0 -> T2 |
+| clonetensor | | float32, float64, int16, int32, int64, int8 | T2 = T1.clone() | clonetensor@float32 T1 -> T2 |
+| add | cblas | float32, float64 | T3 = T1 + T2 | add@int32 T1 T2 -> T3 |
+| add | miaobyte | float32, float64, int16, int32, int64, int8 | T3 = T1 + T2 | add@int32 T1 T2 -> T3 |
+| copytensor | | float32, float64, int16, int32, int64, int8 | T2 = T1.copy() | copytensor@float32 T1 -> T2 |
+| min | | float32, float64 | C = min(A,B) | min@float32 A B -> C |
+| print | | any | | print@any -> |
+| newtensor | | float32, float64, int16, int32, int64, int8 | T1 = zeros(shape) | newtensor@float32 shape -> T1 |
+| mulscalar | miaobyte | float32, float64 | T2 = T1 * 2.0 | mulscalar@float32 T1 2.0 -> T2 |
+| div | miaobyte | float32, float64 | T3 = T1 / T2 | div_miaobyte@float32 T1 T2 -> T3 |
+| sqrt | miaobyte | float32, float64 | T2 = sqrt(T1) | sqrt@float32 T1 -> T2 |
+| mul | miaobyte | float32, float64 | T3 = T1 * T2 | mul@float32 T1 T2 -> T3 |
+| exp | miaobyte | float32, float64 | T2 = exp(T1) | exp@float32 T1 -> T2 |
+| max | | float32, float64 | T3 = max(T1,T2) | max@float32 T1 -> T2 |
+| pow | miaobyte | float32, float64 | T3 = T1 ^ T2 | pow@float32 T1 T2 -> T3 |
+| maxscalar | | float32, float64 | T2 = max(T1, 0.0) | maxscalar@float32 T1 0.0 -> T2 |
+| matmul | | float32, float64 | T3 = T1 @ T2 | matmul@float32 T1 T2 -> T3 |
+| reshape | | any | T2 = reshape(T1, [2,3,4]) | reshape@float32 T1 2 3 4 -> T2 |
+| expand | | any | T2 = expand(T1, axis=[4,6,12]) | expand@float32 T1 4 6 12 -> T2 |
+| concat | | float32 | T3 = concat([T1, T2], axis=3) | concat@float32 T1 T2 3 -> T3 |
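
Note on the new Author column: an op such as `add` or `sub` now lists one row per implementation (`cblas`, `miaobyte`), while a blank author means a single unnamed implementation is registered. The IR column above still shows the pre-existing `op@dtype` header. Under the new header format documented in `op.cpp` below, selecting an author would presumably be written as follows (a hedged illustration, not generated output):

```
add@int32 T1 T2 -> T3              // no author: the family default (or first registered) is used
miaobyte@add[int32] T1 T2 -> T3    // explicit author, new header format
```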

excuter/common/src/deepx/op/op.cpp

Lines changed: 3 additions & 1 deletion
@@ -9,7 +9,8 @@ namespace deepx::op
 {
     // Corresponds to deepx/front/py/deepx/nn/deepxir.py
 
-    // New-format example: mul@float32 a(a_grad) b(b_grad) -> a(a_grad) //id=1 create_time=1714512000 send_time=1714512000
+    // Forward example: miaobyte@mul[float32] a b -> a //id=1 create_time=1714512000 send_time=1714512000
+    // Backward, mixed-precision example: miaobyte@matmul a[float16](a_grad[float16]) b[float16](b_grad[float16]) <- c[float32](a_grad[float32]) //id=1 create_time=1714512000 send_time=1714512000
     void Op::load(const string &input)
     {
         // Split off the metadata section
 
@@ -144,6 +145,7 @@ namespace deepx::op
                   const vector<string> &returns_grad)
     {
         this->name = opname;
+        this->author = "";
         this->dtype = dtype;
         this->args = args;
         this->returns = returns;
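
A minimal sketch of how the new `author@opname[dtype]` header token could be split, assuming the format shown in the comments above; `OpHeader`, `parse_header`, and the field names are hypothetical illustrations, not the actual `Op::load` code:

```cpp
#include <iostream>
#include <string>

// Hypothetical holder for the header token of the new IR format,
// e.g. "miaobyte@mul[float32]"; not part of the repo.
struct OpHeader {
    std::string author; // empty when author selection is left to the factory
    std::string name;
    std::string dtype;  // empty when no [dtype] suffix is present
};

OpHeader parse_header(const std::string &token) {
    OpHeader h;
    std::string rest = token;
    auto at = rest.find('@');
    if (at != std::string::npos) { // optional "author@" prefix
        h.author = rest.substr(0, at);
        rest = rest.substr(at + 1);
    }
    auto lb = rest.find('[');
    if (lb != std::string::npos && !rest.empty() && rest.back() == ']') {
        h.name = rest.substr(0, lb); // optional "[dtype]" suffix
        h.dtype = rest.substr(lb + 1, rest.size() - lb - 2);
    } else {
        h.name = rest;
    }
    return h;
}

int main() {
    OpHeader h = parse_header("miaobyte@mul[float32]");
    std::cout << h.author << " " << h.name << " " << h.dtype << "\n"; // miaobyte mul float32
}
```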

excuter/common/src/deepx/op/op.hpp

Lines changed: 6 additions & 3 deletions
@@ -8,10 +8,12 @@
 #include <iostream>
 #include <sstream>
 #include <chrono>
+
 #include "deepx/tensor.hpp"
 #include "deepx/mem/mem.hpp"
 #include "deepx/dtype.hpp"
 
+#include "stdutil/error.hpp"
 namespace deepx::op
 {
     using deepx::mem::Mem;
 
@@ -21,6 +23,7 @@ namespace deepx::op
 {
 public:
     string name;
+    string author;
     string dtype;
     vector<string> args;
     vector<string> args_grad;
 
@@ -46,14 +49,14 @@ namespace deepx::op
     // Changed to plain virtual functions with a default implementation
     virtual void forward(mem::Mem &mem)
     {
-        throw std::runtime_error("forward not implemented");
+        throw NotImplementError(name);
     }
 
     virtual void backward(mem::Mem &mem)
     {
-        throw std::runtime_error("backward not implemented");
+        throw NotImplementError(name);
    }
-
+
     virtual string math_formula() const {
         return "";
     }
stdutil/error.hpp (new file)

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#ifndef STDUTIL_ERROR_HPP
+#define STDUTIL_ERROR_HPP
+
+#include <stdexcept>
+#include <string>
+
+
+class NotImplementError : public std::logic_error {
+public:
+    explicit NotImplementError(const std::string& method_name)
+        : std::logic_error("Not implement: " + method_name) {}
+};
+class UnsupportedOperationException : public std::logic_error {
+public:
+    explicit UnsupportedOperationException(const std::string& method_name)
+        : std::logic_error("Unsupported method: " + method_name) {}
+};
+
+#endif // STDUTIL_ERROR_HPP
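
A quick usage sketch of the new exception, mirroring how the `Op` base class now throws `NotImplementError` from its default `forward`/`backward`; `DemoOp` is illustrative only, and the include assumes the new header is on the include path:

```cpp
#include <iostream>
#include "stdutil/error.hpp"

// Illustrative stand-in for an op whose author forgot to override forward().
struct DemoOp {
    void forward() { throw NotImplementError("demo_op"); }
};

int main() {
    try {
        DemoOp{}.forward();
    } catch (const std::logic_error &e) {
        // Both new exception types derive from std::logic_error,
        // so one handler can report either.
        std::cerr << e.what() << std::endl; // prints "Not implement: demo_op"
    }
    return 0;
}
```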

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 22 additions & 4 deletions
@@ -59,18 +59,36 @@ int main()
         opresp.id = op.id;
         opresp.recv_at = op.recv_at;
 
-        if (opfactory.ops.find(op.name)==opfactory.ops.end()){
+
+        if ( opfactory.op_families.find(op.name)==opfactory.op_families.end()){
             cerr<<"<op> "<<op.name<<" not found"<<endl;
             opresp.error("op"+op.name+" not found");
             continue;
+        }
+        auto op_family = *(opfactory.op_families.find(op.name)->second);
+        string op_author_name= op.author;
+        if (op.author==""){
+            op_author_name= op_family._default;
+            if (op_author_name=="" && op_family.op_authors.size()>0){
+                op_author_name=op_family.op_authors.begin()->first;
+            }else{
+                cerr<<"<op> "<<op.name<<" no author implement"<<endl;
+                opresp.error("op"+op.name+" no author implement");
+                continue;
+            }
         }
-        auto &type_map = opfactory.ops.find(op.name)->second;
-        if (type_map.find(op.dtype)==type_map.end()){
+        if (op_family.op_authors.find(op_author_name)==op_family.op_authors.end()){
+            cerr<<"<op> "<<op.name<<" "<<op_author_name<<" not found"<<endl;
+            opresp.error("op"+op.name+" "+op_author_name+" not found");
+            continue;
+        }
+        auto &type_map =*(op_family.op_authors.find(op_author_name)->second);
+        if (type_map.ops.find(op.dtype)==type_map.ops.end()){
             cerr<<"<op>"<<op.name<<" "<<op.dtype<<" not found"<<endl;
             opresp.error("op"+op.dtype+" not found");
             continue;
         }
-        auto src = type_map.find(op.dtype)->second;
+        auto src = type_map.ops.find(op.dtype)->second;
 
         (*src).init(op.name, op.dtype, op.args, op.returns, op.grad, op.args_grad, op.returns_grad);
         memmutex.lock();
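
The dispatch above implies a two-level registry: op name -> family, author -> typed op set, dtype -> op instance. A compilable sketch of that assumed shape follows; the member names (`op_families`, `_default`, `op_authors`, `ops`) mirror the diff, but the concrete types in `opfactory.hpp` may differ:

```cpp
#include <map>
#include <memory>
#include <string>

namespace deepx::op {
    struct Op; // the real Op is defined in excuter/common/src/deepx/op/op.hpp

    // dtype ("float32", "int32", ...) -> op instance for that dtype
    struct TypeMap {
        std::map<std::string, std::shared_ptr<Op>> ops;
    };

    // author ("miaobyte", "cblas", ...) -> that author's typed implementations
    struct OpFamily {
        std::string _default; // preferred author; "" means "pick the first registered"
        std::map<std::string, std::shared_ptr<TypeMap>> op_authors;
    };

    // op name ("add", "matmul", ...) -> all known implementations of that op
    struct OpFactory {
        std::map<std::string, std::shared_ptr<OpFamily>> op_families;
    };
}

int main() {
    deepx::op::OpFactory f;
    // Mirrors the first check in main.cpp: unknown op name -> "not found" path.
    return f.op_families.find("add") == f.op_families.end() ? 0 : 1;
}
```

Holding each level behind `shared_ptr` is consistent with the `*(...->second)` dereferences in the dispatch code above.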

excuter/op-mem-ompsimd/src/deepx/op/arg.hpp

Lines changed: 2 additions & 12 deletions
@@ -9,21 +9,11 @@ namespace deepx::op{
     class ArgSet : public Op{
     public:
         ArgSet(){
-            this->init("argset",deepx::dtype<T>::name(), {}, {}, false, {}, {});
+            this->init("argset", deepx::dtype<T>::name(), {}, {}, false, {}, {});
         }
 
-        ArgSet(string name,T value){
-            this->init("argset",deepx::dtype<T>::name(), {name,value}, {}, false, {}, {});
-        }
-
-        ArgSet(string name,vector<T> value){
-            this->init("argset",deepx::dtype<T>::name(), {name,value}, {}, false, {}, {});
-        }
-        ArgSet(initializer_list<string> args){
-            this->init("argset",deepx::dtype<T>::name(), args, {}, false, {}, {});
-        }
         void setexample() override {
-            this->init("argset", "int32", {"3", "4", "5"}, {"shape"}, false, {}, {});
+            this->init("argset", "int32", {"3", "4", "5"}, {"shape"}, false, {}, {});
         }
         string math_formula() const override {
             return "shape = [3, 4, 5]";
