Skip to content

Commit 05469d7

Browse files
committed
py:tensor的+-*/已修正
1 parent 036704f commit 05469d7

3 files changed

Lines changed: 237 additions & 23 deletions

File tree

front/py/deepx/nn/functional/elementwise.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,43 +28,48 @@ def _A_B_elementwiseop_C(
2828
send(ir)
2929
return outtensor
3030
def _A_b_elementwiseop_C(
31-
a:Optional[Union[ Tensor, float, int]] = None,
32-
b: Optional[Union[ Tensor, float, int]] = None,
31+
a:Tensor,
32+
b: Union[ float, int] ,
3333
op:str=None,
3434
out:Union[Tensor,str]="")->Tensor:
35-
if isinstance(a,Tensor):
36-
g=a.graph
37-
else:
38-
g=b.graph
39-
35+
g=a.graph
4036
opnode = g.add_op(op)
41-
if isinstance(a,Tensor):
42-
opnode.add_input(a.node)
43-
else:
44-
varnode=g.add_var("",a)
45-
opnode.add_input(varnode)
37+
opnode.add_input(a.node)
38+
opnode.add_input(g.add_var("",b))
4639

47-
if isinstance(b,Tensor):
48-
opnode.add_input(b.node)
40+
outtensor=None
41+
if isinstance(out,str):
42+
outtensor=Tensor(shape=a.shape, dtype=a.dtype, device=a.device)
43+
outtensor.addtograph(out)
4944
else:
50-
varnode=g.add_var("",b)
51-
opnode.add_input(varnode)
45+
outtensor=out
46+
outtensor.node.add_input(opnode)
47+
if g.eager:
48+
ir=DeepxIR(op, a.dtype, [a.node.name,b], [outtensor.node.name])
49+
send(ir)
50+
return outtensor
51+
def _a_B_elementwiseop_C(
52+
a: Union[ float, int] ,
53+
b: Tensor,
54+
op:str=None,
55+
out:Union[Tensor,str]="")->Tensor:
56+
g=b.graph
57+
opnode = g.add_op(op)
58+
opnode.add_input(g.add_var("",a))
59+
opnode.add_input(b.node)
5260

5361
outtensor=None
5462
if isinstance(out,str):
55-
outtensor=Tensor(shape=a.shape, dtype=a.dtype, device=a.device)
63+
outtensor=Tensor(shape=b.shape, dtype=b.dtype, device=b.device)
5664
outtensor.addtograph(out)
5765
else:
5866
outtensor=out
5967
outtensor.node.add_input(opnode)
6068
if g.eager:
61-
ir=None
62-
if isinstance(a,Tensor):
63-
ir=DeepxIR(op, a.dtype, [a.node.name,b], [outtensor.node.name])
64-
else:
65-
ir=DeepxIR(op, b.dtype, [a,b.node.name], [outtensor.node.name])
69+
ir=DeepxIR(op, b.dtype, [a,b.node.name], [outtensor.node.name])
6670
send(ir)
6771
return outtensor
72+
6873
#add
6974
OpNode.register("add")
7075
OpNode.register("add_scalar")
@@ -122,7 +127,7 @@ def div(
122127
return _A_b_elementwiseop_C(a,b,"div_scalar",out)
123128
else:
124129
#C=a/B
125-
return _A_b_elementwiseop_C(a,b,"rdiv_scalar",out)
130+
return _a_B_elementwiseop_C(a,b,"rdiv_scalar",out)
126131
#clamp
127132
OpNode.register("clamp")
128133
def clamp(

front/py/examples/2_ir/6_math.dot

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Computational Graph
2+
digraph {
3+
rankdir=TB
4+
node [shape=record]
5+
136406866977952 [label="tensor_1
6+
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7+
136406866572096 [label=constant color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
8+
136406866572000 [label="var_1
9+
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
10+
136406866572048 [label=add_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
11+
136406866984816 [label="var_2
12+
3" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
13+
136406866561488 [label="tensor_2
14+
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
15+
136406866571808 [label="tensor_3
16+
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
17+
136406866571568 [label=sqrt color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
18+
136406866571472 [label="tensor_4
19+
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
20+
136406866570944 [label=sqrt color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
21+
136406866570848 [label=rdiv_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
22+
136406866570992 [label="var_3
23+
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
24+
136406866570608 [label="tensor_5
25+
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
26+
136406866572096 -> 136406866977952 [arrowsize=0.8 color=gray40 penwidth=1.2]
27+
136406866572000 -> 136406866572096 [arrowsize=0.8 color=gray40 penwidth=1.2]
28+
136406866977952 -> 136406866572048 [arrowsize=0.8 color=gray40 penwidth=1.2]
29+
136406866984816 -> 136406866572048 [arrowsize=0.8 color=gray40 penwidth=1.2]
30+
136406866572048 -> 136406866561488 [arrowsize=0.8 color=gray40 penwidth=1.2]
31+
136406866571568 -> 136406866571808 [arrowsize=0.8 color=gray40 penwidth=1.2]
32+
136406866561488 -> 136406866571568 [arrowsize=0.8 color=gray40 penwidth=1.2]
33+
136406866570944 -> 136406866571472 [arrowsize=0.8 color=gray40 penwidth=1.2]
34+
136406866561488 -> 136406866570944 [arrowsize=0.8 color=gray40 penwidth=1.2]
35+
136406866570992 -> 136406866570848 [arrowsize=0.8 color=gray40 penwidth=1.2]
36+
136406866571472 -> 136406866570848 [arrowsize=0.8 color=gray40 penwidth=1.2]
37+
136406866570848 -> 136406866570608 [arrowsize=0.8 color=gray40 penwidth=1.2]
38+
}
Lines changed: 171 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)