Skip to content

Commit 65ea28c

Browse files
committed
changeshape:broadcast
1 parent 233992d commit 65ea28c

7 files changed

Lines changed: 224 additions & 26 deletions

File tree

front/py/deepx/nn/functional/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .matmul import matmul
55
from .init import *
66
from .reduce import reduce_max,reduce_min,sum,prod,mean
7-
from .changeshape import transpose,reshape,broadcast_shape,broadcast,unsqueeze
7+
from .changeshape import transpose,reshape,broadcast_shape,broadcast_to,unsqueeze
88
from .activite import relu,sigmoid,swish
99
__all__ = [
1010
"newtensor",
@@ -13,7 +13,7 @@
1313
"add","sub","mul","div","clamp","exp","sqrt","rsqrt",
1414
"matmul",
1515
"max","min","sum","prod","mean",
16-
"transpose","reshape","broadcast_shape","broadcast","unsqueeze",
16+
"transpose","reshape","broadcast_shape","broadcast_to","unsqueeze",
1717
"relu","sigmoid","swish",
18-
"broadcast"
18+
1919
]

front/py/deepx/nn/functional/changeshape.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,13 @@ def expand(t:Tensor,shape:list[int],out:Union[Tensor,str]='')->Tensor:
138138
ir=DeepxIR("expand",'',[t.node.name,*map(str, shape)], [outtensor.node.name])
139139
send(ir)
140140
return outtensor
141-
# 修复 broadcast 函数缩进
142-
def broadcast(a: Tensor, b: Tensor) -> Tuple[Tensor,Tensor]:
141+
142+
def broadcast_to(a: Tensor, shape: tuple,out:Union[Tensor,str]='') -> Tensor:
143143
# 计算广播后的形状
144144
try:
145-
target_shape = broadcast_shape(a.shape, b.shape)
145+
target_shape = broadcast_shape(a.shape, shape)
146+
if target_shape!=shape:
147+
raise ValueError(f"广播失败:{a.shape} 无法广播为 {shape} ")
146148
except ValueError as e:
147149
raise ValueError(f"广播失败:{e}") from e
148150

@@ -152,21 +154,11 @@ def broadcast(a: Tensor, b: Tensor) -> Tuple[Tensor,Tensor]:
152154
a_reshaped = reshape(a,a_reshape)
153155
else:
154156
a_reshaped=a
155-
if b.shape != target_shape:
156-
b_reshape = [1] * (len(target_shape) - b.ndimension) + list(b.shape)
157-
b_reshaped = reshape(b,b_reshape)
158-
else:
159-
b_reshaped=b
160-
157+
161158
# 执行实际广播
162159
if a_reshaped.shape != target_shape:
163-
a_broadcasted = expand(a_reshaped,target_shape)
160+
a_broadcasted = expand(a_reshaped,target_shape,out)
164161
else:
165162
a_broadcasted=a_reshaped
166-
167-
if b_reshaped.shape != target_shape:
168-
b_broadcasted = expand(b_reshaped,target_shape)
169-
else:
170-
b_broadcasted=b_reshaped
171163

172-
return a_broadcasted, b_broadcasted
164+
return a_broadcasted

front/py/deepx/tensor/changeshape.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,16 @@ def expand(self,shape:tuple)->Tensor:
3030
from deepx.nn.functional import expand as expand_func
3131
result=expand_func(self,shape,False)
3232
return result
33+
34+
@tensor_method
35+
def broadcastshape(self,other:Tensor)->tuple[int]:
36+
from deepx.nn.functional import broadcastshape as broadcastshape_func
37+
result=broadcastshape_func(self.shape,other.shape)
38+
return result
39+
40+
@tensor_method
41+
def broadcast_to(self,shape:tuple,out:Union[Tensor,str]='')->Tensor:
42+
from deepx.nn.functional import broadcast_to as broadcast_to_func
43+
result=broadcast_to_func(self,shape,out)
44+
return result
45+

front/py/examples/2_ir/2_broadcast.py

Lines changed: 0 additions & 7 deletions
This file was deleted.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Computational Graph
2+
digraph {
3+
rankdir=TB
4+
node [shape=record]
5+
139089324388896 [label="a
6+
(4, 2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7+
139089324378240 [label=constant color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
8+
139089324389088 [label="var_1
9+
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
10+
139089324621536 [label="b
11+
(2, 1)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
12+
139089325015104 [label=constant color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
13+
139089324621584 [label="var_2
14+
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
15+
139089325015200 [label="tensor_3
16+
(1, 2, 1)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
17+
139089325015488 [label=reshape color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
18+
139089325015440 [label="vector_1
19+
[1, 2, 1]" color=darkseagreen fillcolor=honeydew fontname="Sans-Serif" labeljust=l shape=box style=filled]
20+
139089325015584 [label="b.broadcasted
21+
(4, 2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
22+
139089325015872 [label=expand color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
23+
139089325015824 [label="vector_2
24+
(4, 2, 3)" color=darkseagreen fillcolor=honeydew fontname="Sans-Serif" labeljust=l shape=box style=filled]
25+
139089324378240 -> 139089324388896 [arrowsize=0.8 color=gray40 penwidth=1.2]
26+
139089324389088 -> 139089324378240 [arrowsize=0.8 color=gray40 penwidth=1.2]
27+
139089325015104 -> 139089324621536 [arrowsize=0.8 color=gray40 penwidth=1.2]
28+
139089324621584 -> 139089325015104 [arrowsize=0.8 color=gray40 penwidth=1.2]
29+
139089325015488 -> 139089325015200 [arrowsize=0.8 color=gray40 penwidth=1.2]
30+
139089324621536 -> 139089325015488 [arrowsize=0.8 color=gray40 penwidth=1.2]
31+
139089325015440 -> 139089325015488 [arrowsize=0.8 color=gray40 penwidth=1.2]
32+
139089325015872 -> 139089325015584 [arrowsize=0.8 color=gray40 penwidth=1.2]
33+
139089325015200 -> 139089325015872 [arrowsize=0.8 color=gray40 penwidth=1.2]
34+
139089325015824 -> 139089325015872 [arrowsize=0.8 color=gray40 penwidth=1.2]
35+
}
Lines changed: 153 additions & 0 deletions
Loading
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from deepx import Tensor,ones,broadcast_to
2+
3+
a=ones( 4,2,3 ,name="a")
4+
b=ones( 2,1 ,name='b')
5+
6+
bb=b.broadcast_to( a.shape,out="b.broadcasted")
7+
8+
print(bb)
9+
import os
10+
script_name = os.path.splitext(os.path.basename( os.path.abspath(__file__)))[0] # 获取不带后缀的脚本名
11+
str=b.graph.to_dot()
12+
str.render(script_name+".dot", format='svg')

0 commit comments

Comments
 (0)