77namespace deepx ::op
88{
99 template <typename T>
10- class Concat : public Op <T>
10+ class Concat : public OpT <T>
1111 {
1212 public:
13- Concat (std::vector<string> input,string output,int axis)
14- {
15- this ->name = std::string (" concat" ) + " _" + dtype<T>::name ();
16- this ->args = input;
17- this ->returns .push_back (output);
18- std::string axisstr=std::to_string (axis);
19- this ->args .push_back (axisstr);
20- };
13+ Concat (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
14+ this ->init (" concat" , args, returns, require_grad, args_grad, returns_grad);
15+ }
16+ Concat (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
17+ this ->init (" concat" , args, returns, require_grad, args_grad, returns_grad);
18+ }
2119 void forward (mem::Mem<T> &mem) override
2220 {
2321 std::vector<Tensor<T>*> input;
2422 for (int i=0 ;i<this ->args .size ()-1 ;i++){
2523 input.push_back (mem.get (this ->args [i]).get ());
2624 }
2725 auto output = mem.get (this ->returns [0 ]).get ();
28- int axis = std::stoi (this ->args .back ());
26+
27+ int axis = mem.get <int >(this ->args .back ());
2928 cpu::concat (input,axis,*output);
3029 };
3130 void backward (mem::Mem<T> &mem) override
@@ -34,7 +33,7 @@ namespace deepx::op
3433 for (int i=0 ;i<this ->args .size ()-1 ;i++){
3534 input.push_back (mem.get (this ->args [i]).get ());
3635 }
37- int axis = std::stoi (this ->args .back ());
36+ int axis = mem. get < int > (this ->args .back ());
3837 auto output = mem.get (this ->returns [0 ]).get ();
3938 cpu::split (output,axis,input);
4039 };
0 commit comments