Skip to content

Commit 22c724c

Browse files
committed
changeshape:补充expand等shape算子
1 parent bcaf203 commit 22c724c

24 files changed

Lines changed: 495 additions & 438 deletions
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#ifndef DEEPX_OP_CHANGESHAPE_HPP
2+
#define DEEPX_OP_CHANGESHAPE_HPP
3+
4+
#include "deepx/op/op.hpp"
5+
#include "deepx/tensorfunc/changeshape.hpp"
6+
#include "deepx/dtype.hpp"
7+
8+
namespace deepx::op
9+
{
10+
template <typename T>
11+
class Concat : public Op{
12+
public:
13+
Concat(){
14+
this->init("concat",deepx::dtype<T>::name(), {}, {}, false, {}, {});
15+
}
16+
Concat(vector< string> args, vector< string> returns, bool require_grad = false, vector< string> args_grad = {}, vector< string> returns_grad = {}){
17+
this->init("concat",deepx::dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
18+
}
19+
Concat(initializer_list< string> args, initializer_list< string> returns, bool require_grad = false, initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
20+
this->init("concat",deepx::dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
21+
}
22+
void setexample() override {
23+
this->init("concat", "float32", {"T1", "T2", "3"}, {"T3"}, false, {}, {});
24+
}
25+
string math_formula() const override {
26+
return "T3 = concat([T1, T2], axis=3)";
27+
}
28+
void forward(mem::Mem &mem) override
29+
{
30+
std::vector<Tensor<T>*> input;
31+
for (int i=0;i<this->args.size()-1;i++){
32+
input.push_back(mem.gettensor<T>(this->args[i]).get());
33+
}
34+
auto output = mem.gettensor<T>(this->returns[0]).get();
35+
36+
int axis = mem.getarg<int>(this->args.back());
37+
tensorfunc::concat(input,axis,*output);
38+
};
39+
void backward(mem::Mem &mem) override
40+
{
41+
std::vector<Tensor<T>*> input;
42+
for (int i=0;i<this->args.size()-1;i++){
43+
input.push_back(mem.gettensor<T>(this->args[i]).get());
44+
}
45+
int axis = mem.getarg<int>(this->args.back());
46+
auto output = mem.gettensor<T>(this->returns[0]).get();
47+
tensorfunc::split(*output,axis,input);
48+
};
49+
};
50+
51+
template <typename T>
52+
class Reshape : public Op
53+
{
54+
public:
55+
Reshape()
56+
{
57+
this->init("reshape", "any", {}, {}, false, {}, {});
58+
}
59+
void forward(mem::Mem &mem) override
60+
{
61+
auto input = mem.gettensor<T>(this->args[0]).get();
62+
auto output = mem.gettensor<T>(this->returns[0]).get();
63+
vector<int> shape;
64+
if (this->args.size() == 2 && !is_integer(this->args[1]))
65+
{
66+
shape = mem.getvector<int32_t>(this->args[1]);
67+
}
68+
else
69+
{
70+
for (int i = 1; i < this->args.size(); i++)
71+
{
72+
shape.push_back(atoi(this->args[i].c_str()));
73+
}
74+
}
75+
tensorfunc::reshape(*input, *output, shape);
76+
}
77+
void backward(mem::Mem &mem) override
78+
{
79+
auto return_grad = mem.gettensor<T>(this->returns_grad[0]).get();
80+
auto input_grad = mem.gettensor<T>(this->args_grad[0]).get();
81+
auto input = mem.gettensor<T>(this->args[0]).get();
82+
vector<int> shape = input->shape.shape;
83+
tensorfunc::reshape(*return_grad, *input_grad, shape);
84+
}
85+
void setexample() override {
86+
this->init("reshape", "float32", {"T1", "2","3","4"}, {"T2"}, false, {}, {});
87+
}
88+
string math_formula() const override {
89+
return "T2 = reshape(T1, [2,3,4])";
90+
}
91+
};
92+
93+
template <typename T>
94+
class Transpose : public Op {
95+
public:
96+
Transpose() {
97+
this->init("transpose", "any", {}, {}, false, {}, {});
98+
}
99+
Transpose(vector<string> args, vector<string> returns, bool require_grad = false, vector<string> args_grad = {}, vector<string> returns_grad = {}) {
100+
this->init("transpose", "any", args, returns, require_grad, args_grad, returns_grad);
101+
}
102+
Transpose(initializer_list<string> args, initializer_list<string> returns, bool require_grad = false, initializer_list<string> args_grad = {}, initializer_list<string> returns_grad = {}) {
103+
this->init("transpose", "any", args, returns, require_grad, args_grad, returns_grad);
104+
}
105+
void forward(mem::Mem &mem) override {
106+
auto input = mem.gettensor<T>(this->args[0]).get();
107+
vector<int> dimOrder;
108+
if (this->args.size()==2&&!is_integer(this->args[1])){
109+
dimOrder=mem.getvector<int32_t>(this->args[1]);
110+
}else if (this->args.size()>2){
111+
for (int i = 1; i < this->args.size(); i++) {
112+
dimOrder.push_back(atoi(this->args[i].c_str()));
113+
}
114+
}
115+
auto output = mem.gettensor<T>(this->returns[0]).get();
116+
tensorfunc::transpose(*input, *output, dimOrder);
117+
}
118+
void backward(mem::Mem &mem) override {
119+
auto input_grad = mem.gettensor<T>(this->args_grad[0]).get();
120+
vector<int> dimOrder;
121+
if (this->args.size()==2&&!is_integer(this->args[1])){
122+
dimOrder=mem.getvector<int32_t>(this->args[1]);
123+
}else if (this->args.size()>2){
124+
for (int i = 1; i < this->args.size(); i++) {
125+
dimOrder.push_back(atoi(this->args[i].c_str()));
126+
}
127+
}
128+
auto output_grad = mem.gettensor<T>(this->returns_grad[0]).get();
129+
tensorfunc::transpose(*output_grad, *input_grad, dimOrder);
130+
}
131+
void setexample() override {
132+
this->init("transpose", "float32", {"T1", "1","0"}, {"T2"}, false, {}, {});
133+
}
134+
string math_formula() const override {
135+
return "T2 = transpose(T1, dimorder=[1,0])";
136+
}
137+
};
138+
}
139+
#endif // DEEPX_OP_CHANGESHAPE_HPP

excuter/op-mem-ompsimd/src/deepx/op/concat.hpp

Lines changed: 0 additions & 52 deletions
This file was deleted.

excuter/op-mem-ompsimd/src/deepx/op/matmul.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
#ifndef DEEPX_OP_MATMUL_HPP
22
#define DEEPX_OP_MATMUL_HPP
33

4-
#include <iostream>
5-
64
#include "deepx/shape_transpose.hpp"
75
#include "deepx/op/op.hpp"
86
#include "deepx/mem/mem.hpp"
9-
#include "deepx/tensorfunc/new.hpp"
107
#include "deepx/tensorfunc/matmul.hpp"
11-
#include "deepx/tensorfunc/transpose.hpp"
8+
#include "deepx/tensorfunc/changeshape.hpp"
129
namespace deepx::op
1310
{
1411
using namespace std;

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#include "deepx/op/new.hpp"
77
#include "deepx/op/arg.hpp"
88
#include "deepx/op/print.hpp"
9-
#include "deepx/op/transpose.hpp"
10-
#include "deepx/op/reshape.hpp"
9+
#include "deepx/op/changeshape.hpp"
1110
namespace deepx::op
1211
{
1312
//new

excuter/op-mem-ompsimd/src/deepx/op/opfactory.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <algorithm>
1010

1111
#include "deepx/op/op.hpp"
12-
#include "deepx/op/concat.hpp"
1312
namespace deepx::op
1413
{
1514
using Op_dtype = std::unordered_map<std::string, std::shared_ptr<Op>>;

excuter/op-mem-ompsimd/src/deepx/op/reduce.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include "deepx/tensor.hpp"
55
#include "deepx/tensorfunc/reduce.hpp"
6-
#include "deepx/tensorfunc/broadcast.hpp"
6+
#include "deepx/tensorfunc/changeshape.hpp"
77
#include "deepx/tensorfunc/compare.hpp"
88
#include "stdutil/num.hpp"
99

excuter/op-mem-ompsimd/src/deepx/op/reshape.hpp

Lines changed: 0 additions & 54 deletions
This file was deleted.

excuter/op-mem-ompsimd/src/deepx/op/transpose.hpp

Lines changed: 0 additions & 60 deletions
This file was deleted.

0 commit comments

Comments
 (0)