Skip to content

Commit cbfcebe

Browse files
committed
max,min待参考pytorch (max/min: semantics still to be aligned with the PyTorch reference)
1 parent d5b668e commit cbfcebe

2 files changed

Lines changed: 99 additions & 46 deletions

File tree

front/py/deepx/tensor/elementwise.py

Lines changed: 55 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,45 @@
44
from .deepxir import DeepxIR
55
from deepx.scheduler import send
66

7-
def _A_B_elementwiseop_C(
        a: Tensor,
        b: Tensor,
        op: str = None,
        out: Tensor = None):
    """Wire an elementwise tensor-tensor op `out = a <op> b` into the graph.

    Adds an op node with both tensor operands as inputs and attaches it to
    `out`; in eager mode the corresponding IR is sent to the scheduler
    immediately.
    """
    node = a.graph.add_op(op)
    node.add_input(a.node)
    node.add_input(b.node)
    out.node.add_input(node)
    if a.graph.eager:
        # Eager execution: dispatch the instruction right away.
        instr = DeepxIR(op, a.dtype, [a.node.name, b.node.name], [out.node.name])
        send(str(instr))
19+
def _A_b_elementwiseop_C(
        a: Tensor,
        b: Optional[Union[float, int]] = None,
        op: str = None,
        out: Tensor = None):
    """Wire an elementwise tensor-scalar op `out = a <op> b` into the graph.

    The scalar operand is stored in an anonymous graph variable; in eager
    mode an "argset" instruction uploads the scalar before the
    `<op>_scalar` instruction is dispatched.
    """
    node = a.graph.add_op(op)
    node.add_input(a.node)
    scalar_node = a.graph.add_var("", b)
    node.add_input(scalar_node)
    out.node.add_input(node)
    if a.graph.eager:
        # Upload the scalar argument first, then run the scalar variant of the op.
        send(str(DeepxIR("argset", a.dtype, [b], [scalar_node.name])))
        send(str(DeepxIR(op + "_scalar", a.dtype,
                         [a.node.name, scalar_node.name], [out.node.name])))
29-
3034
#add
OpNode.register("add")
OpNode.register("add_scalar")

def add(
        a: Tensor,
        b: Optional[Union[Tensor, float, int]] = None,
        out: Tensor = None):
    """Elementwise addition `out = a + b`.

    Dispatches to the tensor-tensor helper when `b` is a Tensor, otherwise
    to the tensor-scalar helper (which emits the `add_scalar` IR).
    """
    if isinstance(b, Tensor):
        _A_B_elementwiseop_C(a, b, "add", out)
    else:
        _A_b_elementwiseop_C(a, b, "add", out)
3746

3847
@tensor_method
3948
def add_(self, other):
@@ -42,8 +51,16 @@ def add_(self, other):
4251
return result
4352
#sub
OpNode.register("sub")
OpNode.register("sub_scalar")

def sub(
        a: Tensor,
        b: Optional[Union[Tensor, float, int]] = None,
        out: Tensor = None):
    """Elementwise subtraction `out = a - b`.

    Tensor `b` goes through the tensor-tensor helper; a scalar `b` goes
    through the tensor-scalar helper (which emits the `sub_scalar` IR).
    """
    if isinstance(b, Tensor):
        _A_B_elementwiseop_C(a, b, "sub", out)
    else:
        _A_b_elementwiseop_C(a, b, "sub", out)
4764
@tensor_method
4865
def sub_(self, other):
4966
result = Tensor(dtype=self.dtype,shape=self.shape)
@@ -52,8 +69,16 @@ def sub_(self, other):
5269

5370
#mul
OpNode.register("mul")
OpNode.register("mul_scalar")

def mul(
        a: Tensor,
        b: Optional[Union[Tensor, float, int]] = None,
        out: Tensor = None):
    """Elementwise multiplication `out = a * b`.

    Tensor `b` goes through the tensor-tensor helper; a scalar `b` goes
    through the tensor-scalar helper (which emits the `mul_scalar` IR).
    """
    if isinstance(b, Tensor):
        _A_B_elementwiseop_C(a, b, "mul", out)
    else:
        _A_b_elementwiseop_C(a, b, "mul", out)
5782
@tensor_method
5883
def mul_(self, other):
5984
result = Tensor(dtype=self.dtype,shape=self.shape)
@@ -63,35 +88,23 @@ def mul_(self, other):
6388

6489
#div
OpNode.register("div")
OpNode.register("div_scalar")

def div(
        a: Tensor,
        b: Optional[Union[Tensor, float, int]] = None,
        out: Tensor = None):
    """Elementwise division `out = a / b`.

    Tensor `b` goes through the tensor-tensor helper; a scalar `b` goes
    through the tensor-scalar helper (which emits the `div_scalar` IR).
    """
    if isinstance(b, Tensor):
        _A_B_elementwiseop_C(a, b, "div", out)
    else:
        _A_b_elementwiseop_C(a, b, "div", out)
68101
@tensor_method
def div_(self, other):
    """Divide this tensor by `other`.

    NOTE(review): despite the trailing-underscore name this is NOT
    in-place — it allocates and returns a new tensor of the same
    dtype/shape, like its siblings in this file.
    """
    out = Tensor(dtype=self.dtype, shape=self.shape)
    div(self, other, out)
    return out
73106

74107

75-
#max
76-
OpNode.register("max")
77-
def max(a:Tensor,b:Tensor,out:Tensor):
78-
_A_B_op_C(a,b,"max",out)
79-
80-
@tensor_method
81-
def max_(self, other):
82-
result = Tensor(dtype=self.dtype,shape=self.shape)
83-
max(self,other,result)
84-
return result
85-
#min
86-
OpNode.register("min")
87-
def min(a:Tensor,b:Tensor,out:Tensor):
88-
_A_B_op_C(a,b,"min",out)
89-
90-
@tensor_method
91-
def min_(self, other):
92-
result = Tensor(dtype=self.dtype,shape=self.shape)
93-
min(self,other,result)
94-
return result
95108

96109
# OpNode.register("ReLU", 101)
97110
# OpNode.register("Placeholder", 102)

front/py/deepx/tensor/reduce.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from deepx.autograd.graph import OpNode
55
from .deepxir import DeepxIR
66
from deepx.scheduler import send
7+
from .elementwise import _A_b_elementwiseop_C
78

8-
def _A_v_op_C(
9+
def _A_v_reduceop_C(
910
a:Tensor,
1011
v: Optional[Union[Tensor, float, int]] = None,
1112
op:str=None,
@@ -21,12 +22,51 @@ def _A_v_op_C(
2122
send(str(varir))
2223
ir=DeepxIR(op+"_scalar", a.dtype, [a.node.name,vector_node.name], [out.node.name])
2324
send(str(ir))
24-
25+
26+
27+
#max
OpNode.register("max")
OpNode.register("max_scalar")

def max(
        a: Tensor,
        b: Optional[Union[list, float, int]] = None,
        out: Tensor = None):
    """Maximum of `a` against `b`.

    A list `b` is treated as reduction dimensions and routed to the
    reduce helper; a scalar `b` is routed to the elementwise scalar
    helper.

    Fixes over the committed version:
    - The annotation `Optional[Union[float,int],Union[Tensor,float,int]]`
      was invalid typing syntax (`Optional`/subscripted `Union` take a
      single parenthesized argument) and raised at definition time.
    - Passing "max_scalar" to `_A_b_elementwiseop_C` produced a
      "max_scalar_scalar" IR, since that helper appends "_scalar" itself;
      pass the base op name "max", matching add/sub/mul/div.
    """
    if isinstance(b, list):
        # Reduce over the dimensions listed in b.
        _A_v_reduceop_C(a, b, "max", out)
    else:
        # Scalar comparison; the helper emits the "max_scalar" IR itself.
        _A_b_elementwiseop_C(a, b, "max", out)
39+
40+
@tensor_method
def max_(self, other):
    """Maximum of this tensor and `other`.

    NOTE(review): despite the trailing-underscore name this is NOT
    in-place — it allocates and returns a new tensor of the same
    dtype/shape.
    """
    out = Tensor(dtype=self.dtype, shape=self.shape)
    max(self, other, out)
    return out
45+
46+
#min
OpNode.register("min")
OpNode.register("min_scalar")

def min(
        a: Tensor,
        b: Optional[Union[list, float, int]] = None,
        out: Tensor = None):
    """Minimum of `a` against `b`.

    A list `b` is treated as reduction dimensions and routed to the
    reduce helper; a scalar `b` is routed to the elementwise scalar
    helper.

    Fixes over the committed version:
    - The signature `(a:Tensor, b:Tensor, out:Tensor)` contradicted the
      body's `isinstance(b, list)` dispatch and was inconsistent with
      `max`; defaults keep the call order backward-compatible.
    - Passing "min_scalar" to `_A_b_elementwiseop_C` produced a
      "min_scalar_scalar" IR, since that helper appends "_scalar" itself;
      pass the base op name "min", matching add/sub/mul/div.
    """
    if isinstance(b, list):
        # Reduce over the dimensions listed in b.
        _A_v_reduceop_C(a, b, "min", out)
    else:
        # Scalar comparison; the helper emits the "min_scalar" IR itself.
        _A_b_elementwiseop_C(a, b, "min", out)
55+
56+
@tensor_method
def min_(self, other):
    """Minimum of this tensor and `other`.

    NOTE(review): despite the trailing-underscore name this is NOT
    in-place — it allocates and returns a new tensor of the same
    dtype/shape.
    """
    out = Tensor(dtype=self.dtype, shape=self.shape)
    min(self, other, out)
    return out
61+
2562

2663
#sum
OpNode.register("sum")

def sum(
        a: Tensor,
        b: list[int],
        out: Tensor):
    """Sum-reduce `a` over the dimensions listed in `b`, writing to `out`."""
    _A_v_reduceop_C(a, b, "sum", out)
3070

3171
@tensor_method
3272
def sum_(self, other):

0 commit comments

Comments
 (0)