@@ -28,43 +28,48 @@ def _A_B_elementwiseop_C(
2828 send (ir )
2929 return outtensor
3030def _A_b_elementwiseop_C (
31- a :Optional [ Union [ Tensor , float , int ]] = None ,
32- b : Optional [ Union [ Tensor , float , int ]] = None ,
31+ a :Tensor ,
32+ b : Union [ float , int ] ,
3333 op :str = None ,
3434 out :Union [Tensor ,str ]= "" )-> Tensor :
35- if isinstance (a ,Tensor ):
36- g = a .graph
37- else :
38- g = b .graph
39-
35+ g = a .graph
4036 opnode = g .add_op (op )
41- if isinstance (a ,Tensor ):
42- opnode .add_input (a .node )
43- else :
44- varnode = g .add_var ("" ,a )
45- opnode .add_input (varnode )
37+ opnode .add_input (a .node )
38+ opnode .add_input (g .add_var ("" ,b ))
4639
47- if isinstance (b ,Tensor ):
48- opnode .add_input (b .node )
40+ outtensor = None
41+ if isinstance (out ,str ):
42+ outtensor = Tensor (shape = a .shape , dtype = a .dtype , device = a .device )
43+ outtensor .addtograph (out )
4944 else :
50- varnode = g .add_var ("" ,b )
51- opnode .add_input (varnode )
45+ outtensor = out
46+ outtensor .node .add_input (opnode )
47+ if g .eager :
48+ ir = DeepxIR (op , a .dtype , [a .node .name ,b ], [outtensor .node .name ])
49+ send (ir )
50+ return outtensor
51+ def _a_B_elementwiseop_C (
52+ a : Union [ float , int ] ,
53+ b : Tensor ,
54+ op :str = None ,
55+ out :Union [Tensor ,str ]= "" )-> Tensor :
56+ g = b .graph
57+ opnode = g .add_op (op )
58+ opnode .add_input (g .add_var ("" ,a ))
59+ opnode .add_input (b .node )
5260
5361 outtensor = None
5462 if isinstance (out ,str ):
55- outtensor = Tensor (shape = a .shape , dtype = a .dtype , device = a .device )
63+ outtensor = Tensor (shape = b .shape , dtype = b .dtype , device = b .device )
5664 outtensor .addtograph (out )
5765 else :
5866 outtensor = out
5967 outtensor .node .add_input (opnode )
6068 if g .eager :
61- ir = None
62- if isinstance (a ,Tensor ):
63- ir = DeepxIR (op , a .dtype , [a .node .name ,b ], [outtensor .node .name ])
64- else :
65- ir = DeepxIR (op , b .dtype , [a ,b .node .name ], [outtensor .node .name ])
69+ ir = DeepxIR (op , b .dtype , [a ,b .node .name ], [outtensor .node .name ])
6670 send (ir )
6771 return outtensor
72+
6873#add
6974OpNode .register ("add" )
7075OpNode .register ("add_scalar" )
@@ -122,7 +127,7 @@ def div(
122127 return _A_b_elementwiseop_C (a ,b ,"div_scalar" ,out )
123128 else :
124129 #C=a/B
125- return _A_b_elementwiseop_C (a ,b ,"rdiv_scalar" ,out )
130+ return _a_B_elementwiseop_C (a ,b ,"rdiv_scalar" ,out )
126131#clamp
127132OpNode .register ("clamp" )
128133def clamp (
0 commit comments