Skip to content

Commit 0e71199

Browse files
committed
op:注册op
1 parent 24f3bf1 commit 0e71199

File tree

12 files changed

+296
-674
lines changed

12 files changed

+296
-674
lines changed

excuter/op-mem-ompsimd/src/client/yml.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace client
55
{
66
using namespace deepx::op;
77
using namespace deepx::mem;
8-
shared_ptr<OpBase> parse(const char *yml)
8+
Op parse(const char *yml)
99
{
1010
YAML::Node config = YAML::Load(yml);
1111

@@ -27,10 +27,8 @@ namespace client
2727
returns_grad = config["returns_grad"].as<std::vector<std::string>>();
2828
}
2929

30-
// 通过工厂创建OP
31-
auto op = OpFactory::Create(opname, dtype, args, returns,
32-
require_grad, args_grad, returns_grad);
33-
30+
Op op;
31+
op.init(opname+dtype, args, returns, require_grad, args_grad, returns_grad);
3432
return op;
3533
}
3634
}

excuter/op-mem-ompsimd/src/client/yml.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ namespace client
1010
{
1111
using namespace deepx::op;
1212
using namespace std;
13-
shared_ptr<OpBase> parse(const char *yml);
13+
Op parse(const char *yml);
1414
}
1515
#endif

excuter/op-mem-ompsimd/src/deepx/op/concat.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,24 @@
77
namespace deepx::op
88
{
99
template <typename T>
10-
class Concat : public Op<T>
10+
class Concat : public OpT<T>
1111
{
1212
public:
13-
Concat(std::vector<string> input,string output,int axis)
14-
{
15-
this->name = std::string("concat") + "_" + dtype<T>::name();
16-
this->args = input;
17-
this->returns.push_back(output);
18-
std::string axisstr=std::to_string(axis);
19-
this->args.push_back(axisstr);
20-
};
13+
Concat(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
14+
this->init("concat", args, returns, require_grad, args_grad, returns_grad);
15+
}
16+
Concat(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
17+
this->init("concat", args, returns, require_grad, args_grad, returns_grad);
18+
}
2119
void forward(mem::Mem<T> &mem) override
2220
{
2321
std::vector<Tensor<T>*> input;
2422
for (int i=0;i<this->args.size()-1;i++){
2523
input.push_back(mem.get(this->args[i]).get());
2624
}
2725
auto output = mem.get(this->returns[0]).get();
28-
int axis = std::stoi(this->args.back());
26+
27+
int axis = mem.get<int>(this->args.back());
2928
cpu::concat(input,axis,*output);
3029
};
3130
void backward(mem::Mem<T> &mem) override
@@ -34,7 +33,7 @@ namespace deepx::op
3433
for (int i=0;i<this->args.size()-1;i++){
3534
input.push_back(mem.get(this->args[i]).get());
3635
}
37-
int axis = std::stoi(this->args.back());
36+
int axis = mem.get<int>(this->args.back());
3837
auto output = mem.get(this->returns[0]).get();
3938
cpu::split(output,axis,input);
4039
};

0 commit comments

Comments
 (0)