44from .deepxir import DeepxIR
55from deepx .scheduler import send
66
7- def _A_B_op_C (
7+ def _A_B_elementwiseop_C (
88 a :Tensor ,
9- b : Optional [ Union [ Tensor , float , int ]] = None ,
9+ b : Tensor ,
1010 op :str = None ,
1111 out :Tensor = None ):
1212 opnode = a .graph .add_op (op )
1313 opnode .add_input (a .node )
14- varnode = None
15- if isinstance (b ,Tensor ):
16- opnode .add_input (b .node )
17- else :
18- varnode = a .graph .add_var ("" ,b )
19- opnode .add_input (varnode )
14+ opnode .add_input (b .node )
2015 out .node .add_input (opnode )
2116 if a .graph .eager :
22- if isinstance (b ,Tensor ):
23- ir = DeepxIR (op , a .dtype , [a .node .name , b .node .name ], [out .node .name ])
24- else :
25- varir = DeepxIR ("argset" , a .dtype , [b ], [varnode .name ])
26- send (str (varir ))
27- ir = DeepxIR (op + "_scalar" , a .dtype , [a .node .name ,varnode .name ], [out .node .name ])
17+ ir = DeepxIR (op , a .dtype , [a .node .name , b .node .name ], [out .node .name ])
18+ send (str (ir ))
19+ def _A_b_elementwiseop_C (
20+ a :Tensor ,
21+ b : Optional [Union [ float , int ]] = None ,
22+ op :str = None ,
23+ out :Tensor = None ):
24+ opnode = a .graph .add_op (op )
25+ opnode .add_input (a .node )
26+ varnode = a .graph .add_var ("" ,b )
27+ opnode .add_input (varnode )
28+ out .node .add_input (opnode )
29+ if a .graph .eager :
30+ varir = DeepxIR ("argset" , a .dtype , [b ], [varnode .name ])
31+ send (str (varir ))
32+ ir = DeepxIR (op + "_scalar" , a .dtype , [a .node .name ,varnode .name ], [out .node .name ])
2833 send (str (ir ))
29-
3034#add
3135OpNode .register ("add" )
36+ OpNode .register ("add_scalar" )
37+
3238def add (
3339 a :Tensor ,
3440 b : Optional [Union [Tensor , float , int ]] = None ,
3541 out :Tensor = None ):
36- _A_B_op_C (a ,b ,"add" ,out )
42+ if isinstance (b ,Tensor ):
43+ _A_B_elementwiseop_C (a ,b ,"add" ,out )
44+ else :
45+ _A_b_elementwiseop_C (a ,b ,"add" ,out )
3746
3847@tensor_method
3948def add_ (self , other ):
@@ -42,8 +51,16 @@ def add_(self, other):
4251 return result
4352#sub
4453OpNode .register ("sub" )
45- def sub (a :Tensor ,b :Tensor ,out :Tensor ):
46- _A_B_op_C (a ,b ,out )
54+ OpNode .register ("sub_scalar" )
55+
56+ def sub (
57+ a :Tensor ,
58+ b : Optional [Union [Tensor , float , int ]] = None ,
59+ out :Tensor = None ):
60+ if isinstance (b ,Tensor ):
61+ _A_B_elementwiseop_C (a ,b ,"sub" ,out )
62+ else :
63+ _A_b_elementwiseop_C (a ,b ,"sub" ,out )
4764@tensor_method
4865def sub_ (self , other ):
4966 result = Tensor (dtype = self .dtype ,shape = self .shape )
@@ -52,8 +69,16 @@ def sub_(self, other):
5269
5370#mul
5471OpNode .register ("mul" )
55- def mul (a :Tensor ,b :Tensor ,out :Tensor ):
56- _A_B_op_C (a ,b ,"mul" ,out )
72+ OpNode .register ("mul_scalar" )
73+
74+ def mul (
75+ a :Tensor ,
76+ b : Optional [Union [Tensor , float , int ]] = None ,
77+ out :Tensor = None ):
78+ if isinstance (b ,Tensor ):
79+ _A_B_elementwiseop_C (a ,b ,"mul" ,out )
80+ else :
81+ _A_b_elementwiseop_C (a ,b ,"mul" ,out )
5782@tensor_method
5883def mul_ (self , other ):
5984 result = Tensor (dtype = self .dtype ,shape = self .shape )
@@ -63,35 +88,23 @@ def mul_(self, other):
6388
6489#div
6590OpNode .register ("div" )
66- def div (a :Tensor ,b :Tensor ,out :Tensor ):
67- _A_B_op_C (a ,b ,"div" ,out )
91+ OpNode .register ("div_scalar" )
92+
93+ def div (
94+ a :Tensor ,
95+ b : Optional [Union [Tensor , float , int ]] = None ,
96+ out :Tensor = None ):
97+ if isinstance (b ,Tensor ):
98+ _A_B_elementwiseop_C (a ,b ,"div" ,out )
99+ else :
100+ _A_b_elementwiseop_C (a ,b ,"div" ,out )
68101@tensor_method
69102def div_ (self , other ):
70103 result = Tensor (dtype = self .dtype ,shape = self .shape )
71104 div (self ,other ,result )
72105 return result
73106
74107
75- #max
76- OpNode .register ("max" )
77- def max (a :Tensor ,b :Tensor ,out :Tensor ):
78- _A_B_op_C (a ,b ,"max" ,out )
79-
80- @tensor_method
81- def max_ (self , other ):
82- result = Tensor (dtype = self .dtype ,shape = self .shape )
83- max (self ,other ,result )
84- return result
85- #min
86- OpNode .register ("min" )
87- def min (a :Tensor ,b :Tensor ,out :Tensor ):
88- _A_B_op_C (a ,b ,"min" ,out )
89-
90- @tensor_method
91- def min_ (self , other ):
92- result = Tensor (dtype = self .dtype ,shape = self .shape )
93- min (self ,other ,result )
94- return result
95108
96109# OpNode.register("ReLU", 101)
97110# OpNode.register("Placeholder", 102)
0 commit comments