@@ -46,7 +46,8 @@ namespace deepx::op
             deepx::tensorfunc::add(*b_grad, *c_grad, *b_grad); // b_grad += c_grad
         }
     };
-
+
+    // Add_scalar
     template <typename T>
     class Add_scalar : public OpT<T>
     {
@@ -63,10 +64,10 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            auto a = mem.gettensor<T>(this->args[0]);
-            auto b = mem.getarg<T>(this->args[1]);
-            auto c = mem.gettensor<T>(this->returns[0]);
-            deepx::tensorfunc::add(*a, b, *c);
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            deepx::tensorfunc::add(*A, b, *C);
         }
         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
@@ -169,16 +170,16 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            auto a = mem.gettensor<T>(this->args[0]).get();
-            auto b = mem.getarg<T>(this->args[1]);
-            auto c = mem.gettensor<T>(this->returns[0]).get();
-            deepx::tensorfunc::mul(*a, b, *c);
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            deepx::tensorfunc::mul(*A, b, *C);
         }
         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
         {
             // Needs the scalar input b from the forward pass
-            auto b = mem.getarg<T>(this->args[1]); // fetch the scalar b
+            auto b = this->getarg(1, mem);
             auto a_grad = mem.gettensor<T>(this->args_grad[0]).get();
             auto c_grad = mem.gettensor<T>(this->returns_grad[0]).get();

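The remainder of `Mul_scalar::backward` is unchanged and therefore not shown, but the math it has to implement follows directly from the chain rule: for C = A·b, ∂L/∂A = ∂L/∂C · b. Element-wise, and ignoring whichever deepx::tensorfunc kernel actually performs it, the accumulation amounts to:

// Hypothetical element-wise form of Mul_scalar's remaining backward step;
// the real code uses a deepx::tensorfunc kernel rather than a raw loop.
// For C = A * b:  dL/dA_i = dL/dC_i * b
for (int i = 0; i < n; ++i)
    a_grad[i] += c_grad[i] * b;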
@@ -235,6 +236,7 @@ namespace deepx::op
     };

     // Div_scalar deliberately does not reuse Mul_scalar: when b is close to 0, Mul_scalar(1/b) is numerically unstable
+    // A/b=C
     template <typename T>
     class Div_scalar : public OpT<T>
     {
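The stability comment is easy to reproduce in isolation: near the bottom of the float range, 1/b can overflow to infinity even though A/b is still representable, so Mul_scalar(1/b) would poison the result. A standalone illustration (not deepx code):

#include <cstdio>

int main()
{
    float a = 1e-6f, b = 1e-44f;    // b is a float denormal, close to 0
    float direct = a / b;           // 1e+38, still representable
    float viamul = a * (1.0f / b);  // 1/b overflows to inf, so the product is inf
    std::printf("a/b = %g, a*(1/b) = %g\n", direct, viamul);
    return 0;
}

Even when nothing overflows, a*(1/b) rounds twice where a/b rounds once, so a dedicated division kernel is the safer default.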
@@ -251,25 +253,16 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            if (mem.existstensor(this->args[0])){
-                // C = A/b
-                auto A = mem.gettensor<T>(this->args[0]).get();
-                auto b = mem.getarg<T>(this->args[1]);
-                auto C = mem.gettensor<T>(this->returns[0]).get();
-                tensorfunc::div_scalar(*A, b, *C); // use division directly
-            }else{
-                // C = a/B
-                auto a = mem.getarg<T>(this->args[0]);
-                auto B = mem.gettensor<T>(this->args[1]).get();
-                auto C = mem.gettensor<T>(this->returns[0]).get();
-                tensorfunc::div_scalar(a, *B, *C); // use division directly
-            }
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            tensorfunc::div_scalar(*A, b, *C); // use division directly
         }

         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
         {
-            auto b = mem.getarg<T>(this->args[1]); // fetch the scalar b
+            auto b = this->getarg(1, mem);
             auto a_grad = mem.gettensor<T>(this->args_grad[0]).get();
             auto c_grad = mem.gettensor<T>(this->returns_grad[0]).get();

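For the direct case C = A/b, the unchanged rest of this backward has to compute ∂L/∂A = ∂L/∂C / b; dividing c_grad by b directly, rather than multiplying by 1/b, preserves the stability property the class comment calls out. Hypothetically, element-wise:

// Hypothetical element-wise form of Div_scalar's remaining backward step;
// the real code uses a deepx::tensorfunc kernel rather than a raw loop.
// For C = A / b:  dL/dA_i = dL/dC_i / b
for (int i = 0; i < n; ++i)
    a_grad[i] += c_grad[i] / b;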
@@ -280,6 +273,53 @@ namespace deepx::op
             // The scalar b needs no gradient
         }
     };
+
+
+    template <typename T>
+    class RDiv_scalar : public OpT<T>
+    {
+    public:
+        RDiv_scalar(){
+            this->init("rdiv_scalar", dtype<T>::name(), {}, {}, false, {}, {});
+        }
+        RDiv_scalar(vector<string> args, vector<string> returns, bool require_grad = false, vector<string> args_grad = {}, vector<string> returns_grad = {}){
+            this->init("rdiv_scalar", dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
+        }
+        RDiv_scalar(initializer_list<string> args, initializer_list<string> returns, bool require_grad = false, initializer_list<string> args_grad = {}, initializer_list<string> returns_grad = {}){
+            this->init("rdiv_scalar", dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
+        }
+
+        void forward(mem::Mem &mem) override
+        {
+            // C = a/B
+            auto a = this->getarg(0, mem);
+            auto B = mem.gettensor<T>(this->args[1]).get();
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            tensorfunc::div_scalar(a, *B, *C); // use division directly
+        }
+
+        // TODO: not yet verified
+        void backward(mem::Mem &mem) override
+        {
+            // Needs the inputs from the forward pass
+            auto a = this->getarg(0, mem);
+            auto B = mem.gettensor<T>(this->args[1]).get();
+            auto C = mem.gettensor<T>(this->returns[0]).get(); // C = a/B
+            auto B_grad = mem.gettensor<T>(this->args_grad[1]).get();
+            auto C_grad = mem.gettensor<T>(this->returns_grad[0]).get();
+
+            // Backward pass for scalar division:
+            // for C = a/B
+            // ∂L/∂B = ∂L/∂C * ∂C/∂B = ∂L/∂C * (-a/B²)
+            //       = -C_grad * (a/B²) = -C_grad * (C/B)
+            auto temp = mem.temptensor<T>(B->shape.shape).get();
+            deepx::tensorfunc::div(*C, *B, *temp); // temp = C/B
+            deepx::tensorfunc::muladd(*C_grad, *temp, T(-1), *B_grad, T(1), *B_grad); // B_grad -= C_grad * temp
+
+            // The scalar a needs no gradient
+        }
+    };

     template <typename T>
     class Sqrt : public OpT<T>
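`RDiv_scalar::backward` is still flagged TODO/unverified. The identity it relies on, ∂C/∂B = -a/B² = -C/B (which is why `a` is fetched but never actually needed there), is easy to spot-check with a central finite difference, independent of any deepx machinery:

#include <cstdio>

// Standalone check of the identity used in RDiv_scalar::backward:
// for C = a/B, dC/dB = -a/B^2 = -C/B.
int main()
{
    double a = 3.0, B = 0.7, eps = 1e-6;
    double analytic = -(a / B) / B;                               // -C/B
    double numeric = (a / (B + eps) - a / (B - eps)) / (2 * eps); // central difference
    std::printf("analytic %.9f  numeric %.9f\n", analytic, numeric);
    return 0;
}

With the identity confirmed, what remains to verify is the accumulation itself, i.e. that B_grad ends up as B_grad - C_grad⊙(C/B), assuming the muladd overload computes out = alpha·A⊙B + beta·C as the inline comment suggests.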