1+ #ifndef DEEPX_OP_CHANGESHAPE_HPP
2+ #define DEEPX_OP_CHANGESHAPE_HPP
3+
#include "deepx/dtype.hpp"
#include "deepx/op/op.hpp"
#include "deepx/tensorfunc/changeshape.hpp"
7+
8+ namespace deepx ::op
9+ {
10+ template <typename T>
11+ class Concat : public Op {
12+ public:
13+ Concat (){
14+ this ->init (" concat" ,deepx::dtype<T>::name (), {}, {}, false , {}, {});
15+ }
16+ Concat (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
17+ this ->init (" concat" ,deepx::dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
18+ }
19+ Concat (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
20+ this ->init (" concat" ,deepx::dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
21+ }
22+ void setexample () override {
23+ this ->init (" concat" , " float32" , {" T1" , " T2" , " 3" }, {" T3" }, false , {}, {});
24+ }
25+ string math_formula () const override {
26+ return " T3 = concat([T1, T2], axis=3)" ;
27+ }
28+ void forward (mem::Mem &mem) override
29+ {
30+ std::vector<Tensor<T>*> input;
31+ for (int i=0 ;i<this ->args .size ()-1 ;i++){
32+ input.push_back (mem.gettensor <T>(this ->args [i]).get ());
33+ }
34+ auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
35+
36+ int axis = mem.getarg <int >(this ->args .back ());
37+ tensorfunc::concat (input,axis,*output);
38+ };
39+ void backward (mem::Mem &mem) override
40+ {
41+ std::vector<Tensor<T>*> input;
42+ for (int i=0 ;i<this ->args .size ()-1 ;i++){
43+ input.push_back (mem.gettensor <T>(this ->args [i]).get ());
44+ }
45+ int axis = mem.getarg <int >(this ->args .back ());
46+ auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
47+ tensorfunc::split (*output,axis,input);
48+ };
49+ };
50+
51+ template <typename T>
52+ class Reshape : public Op
53+ {
54+ public:
55+ Reshape ()
56+ {
57+ this ->init (" reshape" , " any" , {}, {}, false , {}, {});
58+ }
59+ void forward (mem::Mem &mem) override
60+ {
61+ auto input = mem.gettensor <T>(this ->args [0 ]).get ();
62+ auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
63+ vector<int > shape;
64+ if (this ->args .size () == 2 && !is_integer (this ->args [1 ]))
65+ {
66+ shape = mem.getvector <int32_t >(this ->args [1 ]);
67+ }
68+ else
69+ {
70+ for (int i = 1 ; i < this ->args .size (); i++)
71+ {
72+ shape.push_back (atoi (this ->args [i].c_str ()));
73+ }
74+ }
75+ tensorfunc::reshape (*input, *output, shape);
76+ }
77+ void backward (mem::Mem &mem) override
78+ {
79+ auto return_grad = mem.gettensor <T>(this ->returns_grad [0 ]).get ();
80+ auto input_grad = mem.gettensor <T>(this ->args_grad [0 ]).get ();
81+ auto input = mem.gettensor <T>(this ->args [0 ]).get ();
82+ vector<int > shape = input->shape .shape ;
83+ tensorfunc::reshape (*return_grad, *input_grad, shape);
84+ }
85+ void setexample () override {
86+ this ->init (" reshape" , " float32" , {" T1" , " 2" ," 3" ," 4" }, {" T2" }, false , {}, {});
87+ }
88+ string math_formula () const override {
89+ return " T2 = reshape(T1, [2,3,4])" ;
90+ }
91+ };
92+
93+ template <typename T>
94+ class Transpose : public Op {
95+ public:
96+ Transpose () {
97+ this ->init (" transpose" , " any" , {}, {}, false , {}, {});
98+ }
99+ Transpose (vector<string> args, vector<string> returns, bool require_grad = false , vector<string> args_grad = {}, vector<string> returns_grad = {}) {
100+ this ->init (" transpose" , " any" , args, returns, require_grad, args_grad, returns_grad);
101+ }
102+ Transpose (initializer_list<string> args, initializer_list<string> returns, bool require_grad = false , initializer_list<string> args_grad = {}, initializer_list<string> returns_grad = {}) {
103+ this ->init (" transpose" , " any" , args, returns, require_grad, args_grad, returns_grad);
104+ }
105+ void forward (mem::Mem &mem) override {
106+ auto input = mem.gettensor <T>(this ->args [0 ]).get ();
107+ vector<int > dimOrder;
108+ if (this ->args .size ()==2 &&!is_integer (this ->args [1 ])){
109+ dimOrder=mem.getvector <int32_t >(this ->args [1 ]);
110+ }else if (this ->args .size ()>2 ){
111+ for (int i = 1 ; i < this ->args .size (); i++) {
112+ dimOrder.push_back (atoi (this ->args [i].c_str ()));
113+ }
114+ }
115+ auto output = mem.gettensor <T>(this ->returns [0 ]).get ();
116+ tensorfunc::transpose (*input, *output, dimOrder);
117+ }
118+ void backward (mem::Mem &mem) override {
119+ auto input_grad = mem.gettensor <T>(this ->args_grad [0 ]).get ();
120+ vector<int > dimOrder;
121+ if (this ->args .size ()==2 &&!is_integer (this ->args [1 ])){
122+ dimOrder=mem.getvector <int32_t >(this ->args [1 ]);
123+ }else if (this ->args .size ()>2 ){
124+ for (int i = 1 ; i < this ->args .size (); i++) {
125+ dimOrder.push_back (atoi (this ->args [i].c_str ()));
126+ }
127+ }
128+ auto output_grad = mem.gettensor <T>(this ->returns_grad [0 ]).get ();
129+ tensorfunc::transpose (*output_grad, *input_grad, dimOrder);
130+ }
131+ void setexample () override {
132+ this ->init (" transpose" , " float32" , {" T1" , " 1" ," 0" }, {" T2" }, false , {}, {});
133+ }
134+ string math_formula () const override {
135+ return " T2 = transpose(T1, dimorder=[1,0])" ;
136+ }
137+ };
138+ }
#endif // DEEPX_OP_CHANGESHAPE_HPP