@@ -17,11 +17,12 @@ namespace deepx::op
1717 class Add : public OpT <T>
1818 {
1919 public:
20+ Add ()=default ;
2021 Add (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
21- this ->init (" add" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
22+ this ->init (" add" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
2223 }
2324 Add (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
24- this ->init (" add" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
25+ this ->init (" add" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
2526 }
2627 void forward (mem::Mem &mem) override
2728 {
@@ -43,15 +44,17 @@ namespace deepx::op
4344 deepx::tensorfunc::add (*b_grad, *c_grad, *b_grad); // b_grad += c_grad
4445 }
4546 };
47+ // 注册add算子
48+
4649 template <typename T>
4750 class Add_scalar : public OpT <T>
4851 {
4952 public:
5053 Add_scalar (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
51- this ->init (" add_scalar" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
54+ this ->init (" add_scalar" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
5255 }
5356 Add_scalar (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
54- this ->init (" add_scalar" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
57+ this ->init (" add_scalar" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
5558 }
5659 // 已验证,2025-02-19,lipeng
5760 void forward (mem::Mem &mem) override
@@ -78,10 +81,10 @@ namespace deepx::op
7881 {
7982 public:
8083 Sub (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
81- this ->init (" sub" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
84+ this ->init (" sub" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
8285 }
8386 Sub (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
84- this ->init (" sub" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
87+ this ->init (" sub" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
8588 }
8689 void forward (mem::Mem &mem) override
8790 {
@@ -109,10 +112,10 @@ namespace deepx::op
109112 {
110113 public:
111114 Mul (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
112- this ->init (" mul" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
115+ this ->init (" mul" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
113116 }
114117 Mul (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
115- this ->init (" mul" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
118+ this ->init (" mul" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
116119 }
117120 void forward (mem::Mem &mem) override
118121 {
@@ -145,10 +148,10 @@ namespace deepx::op
145148 {
146149 public:
147150 Mul_scalar (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
148- this ->init (" mul_scalar" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
151+ this ->init (" mul_scalar" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
149152 }
150153 Mul_scalar (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
151- this ->init (" mul_scalar" + dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
154+ this ->init (" mul_scalar" , dtype<T>::name (), args, returns, require_grad, args_grad, returns_grad);
152155 }
153156 // 已验证,2025-02-19,lipeng
154157 void forward (mem::Mem &mem) override
@@ -179,10 +182,10 @@ namespace deepx::op
179182 {
180183 public:
181184 Div (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
182- this ->init (" div" , args, returns, require_grad, args_grad, returns_grad);
185+ this ->init (" div" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
183186 }
184187 Div (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
185- this ->init (" div" , args, returns, require_grad, args_grad, returns_grad);
188+ this ->init (" div" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
186189 }
187190 void forward (mem::Mem &mem) override
188191 {
@@ -221,10 +224,10 @@ namespace deepx::op
221224 {
222225 public:
223226 Div_scalar (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
224- this ->init (" div_scalar" , args, returns, require_grad, args_grad, returns_grad);
227+ this ->init (" div_scalar" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
225228 }
226229 Div_scalar (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
227- this ->init (" div_scalar" , args, returns, require_grad, args_grad, returns_grad);
230+ this ->init (" div_scalar" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
228231 }
229232 // 已验证,2025-02-19,lipeng
230233 void forward (mem::Mem &mem) override
@@ -254,10 +257,10 @@ namespace deepx::op
254257 class Sqrt : public OpT <T>{
255258 public:
256259 Sqrt (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
257- this ->init (" sqrt" , args, returns, require_grad, args_grad, returns_grad);
260+ this ->init (" sqrt" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
258261 }
259262 Sqrt (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
260- this ->init (" sqrt" , args, returns, require_grad, args_grad, returns_grad);
263+ this ->init (" sqrt" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
261264 }
262265 void forward (mem::Mem &mem) override
263266 {
@@ -284,10 +287,10 @@ namespace deepx::op
284287 {
285288 public:
286289 Exp (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
287- this ->init (" exp" , args, returns, require_grad, args_grad, returns_grad);
290+ this ->init (" exp" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
288291 }
289292 Exp (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
290- this ->init (" exp" , args, returns, require_grad, args_grad, returns_grad);
293+ this ->init (" exp" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
291294 }
292295 void forward (mem::Mem &mem) override
293296 {
@@ -315,10 +318,10 @@ namespace deepx::op
315318 {
316319 public:
317320 Pow (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
318- this ->init (" pow" , args, returns, require_grad, args_grad, returns_grad);
321+ this ->init (" pow" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
319322 }
320323 Pow (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
321- this ->init (" pow" , args, returns, require_grad, args_grad, returns_grad);
324+ this ->init (" pow" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
322325 }
323326 void forward (mem::Mem &mem) override
324327 {
@@ -361,10 +364,10 @@ namespace deepx::op
361364 {
362365 public:
363366 Pow_scalar (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
364- this ->init (" pow_scalar" , args, returns, require_grad, args_grad, returns_grad);
367+ this ->init (" pow_scalar" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
365368 }
366369 Pow_scalar (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
367- this ->init (" pow_scalar" , args, returns, require_grad, args_grad, returns_grad);
370+ this ->init (" pow_scalar" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
368371 }
369372 void forward (mem::Mem &mem) override
370373 {
@@ -399,10 +402,10 @@ namespace deepx::op
399402 {
400403 public:
401404 Log (vector< string> args, vector< string> returns, bool require_grad = false , vector< string> args_grad = {}, vector< string> returns_grad = {}){
402- this ->init (" log" , args, returns, require_grad, args_grad, returns_grad);
405+ this ->init (" log" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
403406 }
404407 Log (initializer_list< string> args, initializer_list< string> returns, bool require_grad = false , initializer_list< string> args_grad = {}, initializer_list< string> returns_grad = {}){
405- this ->init (" log" , args, returns, require_grad, args_grad, returns_grad);
408+ this ->init (" log" ,dtype<T>:: name (), args, returns, require_grad, args_grad, returns_grad);
406409 }
407410 void forward (mem::Mem &mem) override
408411 {
0 commit comments