
Commit 233992d

changeshape:broadcast
1 parent e75a938 commit 233992d

25 files changed: 132 additions & 1471 deletions

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 25 additions & 23 deletions
@@ -4,32 +4,34 @@
 
 | Operation | Data Types | Math Formula | IR Instruction |
 |-----------|------------|--------------|----------------|
+| add_scalar | float32, float64 | T2 = T1 + 1.0 | add_scalar@float32 T1 1.0 -> T2 |
+| rdiv_scalar | float32, float64 | T3 =1 / T2 | rdiv_scalar@float32 1 T2 -> T3 |
+| constant | float32, float64 | T1 = full(shape, 0.0) | constant@float32 0.0 -> T1 |
+| uniform | float32, float64 | uniform(-1.0, 1.0,T1) | uniform@float32 -1.0 1.0 -> T1 |
+| mul_scalar | float32, float64 | T2 = T1 * 2.0 | mul_scalar@float32 T1 2.0 -> T2 |
+| deltensor | any | del T1 | deltensor@any T1 -> |
+| sub | float32, float64 | T3 = T1 - T2 | sub@int32 T1 T2 -> T3 |
 | sum | float32, float64 | T2 = sum(T1, dims=[1,2]) | sum@float32 T1 1 2 -> T2 |
-| matmul | float32, float64 | T3 = T1 @ T2 | matmul@float32 T1 T2 -> T3 |
-| concat | float32, float64 | T3 = concat([T1, T2], axis=3) | concat@float32 T1 T2 3 -> T3 |
-| pow_scalar | float32, float64 | T2 = T1 ^ 2.0 | pow_scalar@float32 T1 2.0 -> T2 |
-| pow | float32, float64 | T3 = T1 ^ T2 | pow@float32 T1 T2 -> T3 |
-| max_scalar | float32, float64 | T2 = max(T1, 0.0) | max_scalar@float32 T1 0.0 -> T2 |
-| exp | float32, float64 | T2 = exp(T1) | exp@float32 T1 -> T2 |
+| argset | float32, float64, int32 | shape = [3, 4, 5] | argset@int32 3 4 5 -> shape |
+| arange | float32, float64 | arange(start=0.0, step=1.0,T1) | arange@float32 0.0 1.0 -> T1 |
+| add | float32, float64 | T3 = T1 + T2 | add@int32 T1 T2 -> T3 |
+| copytensor | float32, float64, int16, int32, int64, int8 | T2 = T1.copy() | copytensor@float32 T1 -> T2 |
+| min | float32, float64 | C = min(A,B) | min@float32 A B -> C |
+| print | any | | print@any -> |
+| newtensor | float32, float64, int16, int32, int64, int8 | T1 = zeros(shape) | newtensor@float32 shape -> T1 |
+| div | float32, float64 | T3 = T1 / T2 | div@float32 T1 T2 -> T3 |
+| div_scalar | float32, float64 | T2 = T1 / 2.0 | div_scalar@float32 T1 2.0 -> T2 |
+| reshape | any | T2 = reshape(T1, [2,3,4]) | reshape@float32 T1 2 3 4 -> T2 |
 | min_scalar | float32, float64 | B= min(A, 1.0) | min_scalar@float32 A 1.0 -> B |
 | sqrt | float32, float64 | T2 = sqrt(T1) | sqrt@float32 T1 -> T2 |
-| div | float32, float64 | T3 = T1 / T2 | div@float32 T1 T2 -> T3 |
 | mul | float32, float64 | T3 = T1 * T2 | mul@float32 T1 T2 -> T3 |
-| newtensor | float32, float64, int16, int32, int64, int8 | T1 = zeros(shape) | newtensor@float32 shape -> T1 |
-| print | any | | print@any -> |
-| min | float32, float64 | C = min(A,B) | min@float32 A B -> C |
-| copytensor | float32, float64, int16, int32, int64, int8 | T2 = T1.copy() | copytensor@float32 T1 -> T2 |
-| clonetensor | float32, float64, int16, int32, int64, int8 | T2 = T1.clone() | clonetensor@float32 T1 -> T2 |
-| arange | float32, float64 | arange(start=0.0, step=1.0,T1) | arange@float32 0.0 1.0 -> T1 |
-| argset | float32, float64, int32 | shape = [3, 4, 5] | argset@int32 3 4 5 -> shape |
-| sub | float32, float64 | T3 = T1 - T2 | sub@int32 T1 T2 -> T3 |
-| mul_scalar | float32, float64 | T2 = T1 * 2.0 | mul_scalar@float32 T1 2.0 -> T2 |
-| uniform | float32, float64 | uniform(-1.0, 1.0,T1) | uniform@float32 -1.0 1.0 -> T1 |
-| add | float32, float64 | T3 = T1 + T2 | add@int32 T1 T2 -> T3 |
+| exp | float32, float64 | T2 = exp(T1) | exp@float32 T1 -> T2 |
+| max_scalar | float32, float64 | T2 = max(T1, 0.0) | max_scalar@float32 T1 0.0 -> T2 |
 | max | float32, float64 | T3 = max(T1,T2) | max@float32 T1 -> T2 |
-| constant | float32, float64 | T1 = full(shape, 0.0) | constant@float32 0.0 -> T1 |
-| rdiv_scalar | float32, float64 | T3 =1 / T2 | rdiv_scalar@float32 1 T2 -> T3 |
-| add_scalar | float32, float64 | T2 = T1 + 1.0 | add_scalar@float32 T1 1.0 -> T2 |
+| pow | float32, float64 | T3 = T1 ^ T2 | pow@float32 T1 T2 -> T3 |
+| pow_scalar | float32, float64 | T2 = T1 ^ 2.0 | pow_scalar@float32 T1 2.0 -> T2 |
+| matmul | float32, float64 | T3 = T1 @ T2 | matmul@float32 T1 T2 -> T3 |
+| clonetensor | float32, float64, int16, int32, int64, int8 | T2 = T1.clone() | clonetensor@float32 T1 -> T2 |
 | transpose | any | T2 = transpose(T1, dimorder=[1,0]) | transpose@float32 T1 1 0 -> T2 |
-| div_scalar | float32, float64 | T2 = T1 / 2.0 | div_scalar@float32 T1 2.0 -> T2 |
-| reshape | any | T2 = reshape(T1, [2,3,4]) | reshape@float32 T1 2 3 4 -> T2 |
+| expand | any | T2 = expand(T1, axis=[4,6,12]) | expand@float32 T1 4 6 12 -> T2 |
+| concat | float32 | T3 = concat([T1, T2], axis=3) | concat@float32 T1 T2 3 -> T3 |
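Every entry in the IR column follows the same textual scheme: `opname@dtype`, the operands, `->`, then the results. As a reading aid, a minimal sketch of a parser for that line format (`parse_ir` is a hypothetical helper; this commit does not show the executor's actual parsing code):

```python
def parse_ir(line: str):
    """Split an IR line such as 'expand@float32 T1 4 6 12 -> T2'
    into (op, dtype, args, returns). Illustrative only."""
    lhs, _, rhs = line.partition("->")
    head, *args = lhs.split()
    op, _, dtype = head.partition("@")
    return op, dtype, args, rhs.split()

print(parse_ir("expand@float32 T1 4 6 12 -> T2"))
# -> ('expand', 'float32', ['T1', '4', '6', '12'], ['T2'])
```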

excuter/op-mem-ompsimd/src/deepx/op/changeshape.hpp

Lines changed: 2 additions & 2 deletions
@@ -230,9 +230,9 @@ namespace deepx::op
             auto input_grad = mem.gettensor<T>(this->args_grad[0]).get();
             auto output_grad = mem.gettensor<T>(this->returns_grad[0]).get();
             vector<int> target_shape = this->getvector<int32_t>(1);
-            vector<int> axis = sumaxis(target_shape);
+            vector<int> axis = this->sumaxis(input_grad->shape.shape, target_shape);
             // sum: reduce over the specified dimensions
-            tensorfunc::sum(*output_grad, *input_grad, axis);
+            tensorfunc::sum(*output_grad, axis, *input_grad);
         }
         void setexample() override
         {
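The backward pass of an expand reverses a broadcast: every axis along which the input was stretched must be summed back down, which is why the fix derives the axes from both the input's shape and the target shape rather than from the target shape alone (and `tensorfunc::sum` now takes the axes before the destination tensor). A NumPy sketch of the reduction a helper like `sumaxis` presumably computes; `sum_axes`, the left-padding step, and the example shapes are illustrative assumptions, not the repo's API:

```python
import numpy as np

def sum_axes(input_shape, target_shape):
    """Axes of the broadcast (target) shape that must be summed
    to shrink a gradient back down to the input's shape."""
    # Left-pad the input shape with 1s so both shapes align on the right.
    ndim = len(target_shape)
    padded = [1] * (ndim - len(input_shape)) + list(input_shape)
    # An axis was broadcast where the input had size 1 and the target did not.
    return [i for i in range(ndim) if padded[i] == 1 and target_shape[i] != 1]

# Gradient of expand([3, 1] -> [2, 3, 4]): sum over axes 0 and 2.
output_grad = np.ones((2, 3, 4))
axes = sum_axes([3, 1], [2, 3, 4])                      # [0, 2]
input_grad = output_grad.sum(axis=tuple(axes)).reshape(3, 1)
assert input_grad.shape == (3, 1)
```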

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 48 additions & 76 deletions
@@ -8,9 +8,10 @@
 #include "deepx/op/print.hpp"
 #include "deepx/op/changeshape.hpp"
 namespace deepx::op
-{
-    //new
-    void register_new(OpFactory &opfactory){
+{
+    // tensor
+    void register_lifecycle(OpFactory &opfactory)
+    {
         opfactory.add_op(NewTensor<int8_t>());
         opfactory.add_op(NewTensor<int16_t>());
         opfactory.add_op(NewTensor<int32_t>());
@@ -35,114 +36,84 @@ namespace deepx::op
         opfactory.add_op(ArgSet<int32_t>());
         opfactory.add_op(ArgSet<float>());
         opfactory.add_op(ArgSet<double>());
-    }
-    //del
-    void register_del(OpFactory &opfactory){
+
         opfactory.add_op(DelTensor<float>());
     }
-    //init
-    void register_uniform(OpFactory &opfactory){
+
+    // init
+    void register_init(OpFactory &opfactory)
+    {
         opfactory.add_op(Uniform<float>());
         opfactory.add_op(Uniform<double>());
-    }
-    void register_constant(OpFactory &opfactory){
+
         opfactory.add_op(Constant<float>());
         opfactory.add_op(Constant<double>());
-    }
-    void register_arange(OpFactory &opfactory){
+
         opfactory.add_op(Arange<float>());
         opfactory.add_op(Arange<double>());
     }
-    void register_init(OpFactory &opfactory){
-        register_uniform(opfactory);
-        register_constant(opfactory);
-        register_arange(opfactory);
-    }
-    //anytype
-    void register_anytype(OpFactory &opfactory){
+    // io
+    void register_util(OpFactory &opfactory)
+    {
         opfactory.add_op(Print<float>());
-
-        opfactory.add_op(Transpose<float>());
-
-        opfactory.add_op(Reshape<float>());
     }
-    //elementwise
-    void register_add(OpFactory &opfactory){
+
+    // elementwise
+    void register_elementwise(OpFactory &opfactory)
+    {
         opfactory.add_op(Add<float>());
         opfactory.add_op(Add<double>());
-    }
-    void register_add_scalar(OpFactory &opfactory){
+
         opfactory.add_op(Add_scalar<float>());
         opfactory.add_op(Add_scalar<double>());
-    }
-    void register_sub(OpFactory &opfactory){
+
         opfactory.add_op(Sub<float>());
         opfactory.add_op(Sub<double>());
-    }
 
-    void register_mul(OpFactory &opfactory){
         opfactory.add_op(Mul<float>());
         opfactory.add_op(Mul<double>());
-    }
-    void register_mul_scalar(OpFactory &opfactory){
+
         opfactory.add_op(Mul_scalar<float>());
         opfactory.add_op(Mul_scalar<double>());
-    }
-    void register_div(OpFactory &opfactory){
+
         opfactory.add_op(Div<float>());
         opfactory.add_op(Div<double>());
-    }
-    void register_div_scalar(OpFactory &opfactory){
+
         opfactory.add_op(Div_scalar<float>());
         opfactory.add_op(Div_scalar<double>());
-    }
-    void register_rdiv_scalar(OpFactory &opfactory){
+
         opfactory.add_op(RDiv_scalar<float>());
         opfactory.add_op(RDiv_scalar<double>());
-    }
-    void register_sqrt(OpFactory &opfactory){
+
         opfactory.add_op(Sqrt<float>());
         opfactory.add_op(Sqrt<double>());
-    }
-    void register_exp(OpFactory &opfactory){
+
         opfactory.add_op(Exp<float>());
         opfactory.add_op(Exp<double>());
-    }
-    void register_pow(OpFactory &opfactory){
+
         opfactory.add_op(Pow<float>());
         opfactory.add_op(Pow<double>());
-    }
-    void register_pow_scalar(OpFactory &opfactory){
+
         opfactory.add_op(Pow_scalar<float>());
         opfactory.add_op(Pow_scalar<double>());
     }
-    void register_elementwise_op(OpFactory &opfactory){
-        register_add(opfactory);
-        register_add_scalar(opfactory);
-        register_sub(opfactory);
-        register_mul(opfactory);
-        register_mul_scalar(opfactory);
-        register_div(opfactory);
-        register_div_scalar(opfactory);
-        register_rdiv_scalar(opfactory);
-        register_sqrt(opfactory);
-        register_exp(opfactory);
-        register_pow(opfactory);
-        register_pow_scalar(opfactory);
-    }
-    //concat
-
-    void register_concat(OpFactory &opfactory){
-        opfactory.add_op(Concat<float>());
-        opfactory.add_op(Concat<double>());
-    }
-    //matmul
-    void register_matmul(OpFactory &opfactory){
+    // matmul
+    void register_matmul(OpFactory &opfactory)
+    {
         opfactory.add_op(MatMul<float>());
         opfactory.add_op(MatMul<double>());
     }
-    //reduce
-    void register_reduce(OpFactory &opfactory){
+    // changeshape
+    void register_changeshape(OpFactory &opfactory)
+    {
+        opfactory.add_op(Transpose<float>());
+        opfactory.add_op(Reshape<float>());
+        opfactory.add_op(Expand<float>());
+        opfactory.add_op(Concat<float>());
+    }
+    // reduce
+    void register_reduce(OpFactory &opfactory)
+    {
         opfactory.add_op(Max<float>());
         opfactory.add_op(Max<double>());
         opfactory.add_op(Max_scalar<float>());
@@ -154,13 +125,14 @@ namespace deepx::op
         opfactory.add_op(Sum<float>());
         opfactory.add_op(Sum<double>());
     }
-    int register_all(OpFactory &opfactory){
-        register_new(opfactory);
+    int register_all(OpFactory &opfactory)
+    {
+        register_lifecycle(opfactory);
         register_init(opfactory);
-        register_anytype(opfactory);
-        register_elementwise_op(opfactory);
-        register_concat(opfactory);
+        register_util(opfactory);
+        register_elementwise(opfactory);
         register_matmul(opfactory);
+        register_changeshape(opfactory);
         register_reduce(opfactory);
         return 0;
     }
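The refactor collapses the old per-op helpers into one registration function per category (lifecycle, init, util, elementwise, matmul, changeshape, reduce), mirroring the grouping of the op table in list.md; Transpose, Reshape, Expand, and Concat now live together under changeshape. A toy Python sketch of the underlying factory idea, with ops keyed by `name@dtype` as in the IR column (a hypothetical re-expression, not the C++ `OpFactory` API):

```python
# Toy registry keyed by "name@dtype", mirroring the IR instruction prefix.
class OpFactory:
    def __init__(self):
        self._ops = {}

    def add_op(self, name: str, dtype: str, fn):
        self._ops[f"{name}@{dtype}"] = fn

    def get(self, key: str):
        return self._ops[key]

factory = OpFactory()
factory.add_op("add", "float32", lambda a, b: a + b)
print(factory.get("add@float32")(1.0, 2.0))  # 3.0
```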

front/py/deepx/nn/functional/changeshape.py

Lines changed: 26 additions & 16 deletions
@@ -77,25 +77,34 @@ def reshape(t:Tensor,shape:list[int],inplace:bool=False,out:Union[Tensor,str]=''
 
 def broadcast_shape(shape_a: tuple, shape_b: tuple) -> tuple:
     """Compute the broadcast result shape of two shapes"""
-    # Align dimensions from right to left
-    reversed_dims = zip(reversed(shape_a), reversed(shape_b))
-    new_shape = []
-    for dim_a, dim_b in reversed_dims:
-        if dim_a == 1:
-            new_dim = dim_b
-        elif dim_b == 1:
-            new_dim = dim_a
-        elif dim_a != dim_b:
-            raise ValueError(f"Shapes cannot be broadcast: {shape_a} and {shape_b}")
+    # Lengths of the two shapes
+    len_a, len_b = len(shape_a), len(shape_b)
+
+    # Result shape
+    result_shape = []
+
+    # Align from the right and compute each dimension
+    for i in range(1, min(len_a, len_b) + 1):
+        dim_a = shape_a[-i]
+        dim_b = shape_b[-i]
+
+        if dim_a == 1 or dim_b == 1:
+            # Broadcasting rule: if one dimension is 1, take the other
+            result_shape.insert(0, max(dim_a, dim_b))
+        elif dim_a == dim_b:
+            # Equal dimensions stay unchanged
+            result_shape.insert(0, dim_a)
         else:
-            new_dim = dim_a
-        new_shape.append(new_dim)
+            # Dimensions differ and neither is 1: cannot broadcast
+            raise ValueError(f"Shapes cannot be broadcast: {shape_a} and {shape_b}")
 
-    # Handle shapes of different lengths
-    max_ndim = max(len(shape_a), len(shape_b))
-    new_shape += [1] * (max_ndim - len(new_shape))
+    # Prepend the extra leading dimensions of the longer shape
+    if len_a > len_b:
+        result_shape = list(shape_a[:len_a - len_b]) + result_shape
+    elif len_b > len_a:
+        result_shape = list(shape_b[:len_b - len_a]) + result_shape
 
-    return tuple(reversed(new_shape))
+    return tuple(result_shape)
 
 
 def unsqueeze(t:Tensor,dim:int)->Tensor:
@@ -112,6 +121,7 @@ def unsqueeze(t:Tensor,dim:int)->Tensor:
 
     return reshape(t, new_shape)
 
+OpNode.register("expand")
 def expand(t:Tensor,shape:list[int],out:Union[Tensor,str]='')->Tensor:
     outtensor=None
     if isinstance(out,str):
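The rewritten `broadcast_shape` implements standard right-aligned broadcasting: a size-1 dimension stretches to match the other shape, equal dimensions pass through, anything else raises, and the leading dimensions of the longer shape are kept as-is (the old version wrongly padded with 1s instead). These are NumPy's rules, so the expected outputs can be cross-checked against `np.broadcast_shapes`; a sketch under the assumption that the two agree, which the code above is written to ensure:

```python
import numpy as np

# Shapes the new broadcast_shape should accept, with NumPy as the reference.
for a, b in [((2, 3, 4), (3, 1)), ((1,), (5, 5)), ((6, 1, 8), (7, 1))]:
    print(a, b, "->", np.broadcast_shapes(a, b))
# (2, 3, 4) (3, 1) -> (2, 3, 4)
# (1,) (5, 5) -> (5, 5)
# (6, 1, 8) (7, 1) -> (6, 7, 8)
```

A pair like `(2, 3)` and `(4, 3)` fails in both implementations: the leading dimensions differ and neither is 1.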

front/py/deepx/tensor/changeshape.py

Lines changed: 6 additions & 0 deletions
@@ -24,3 +24,9 @@ def reshape_(self,*shape)->Tensor:
     from deepx.nn.functional import reshape as reshape_func
     result=reshape_func(self,shape,True)
     return result
+
+@tensor_method
+def expand(self,shape:tuple)->Tensor:
+    from deepx.nn.functional import expand as expand_func
+    result=expand_func(self,shape,False)
+    return result
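`@tensor_method` attaches the function to the `Tensor` class so `expand` can be called method-style on a tensor. A self-contained sketch of that decorator pattern; the `Tensor` stub and decorator body here are hypothetical stand-ins, not deepx's implementation:

```python
# Hypothetical re-creation of the @tensor_method pattern.
class Tensor:
    def __init__(self, shape):
        self.shape = tuple(shape)

def tensor_method(fn):
    # Attach fn to Tensor under its own name and return it unchanged.
    setattr(Tensor, fn.__name__, fn)
    return fn

@tensor_method
def expand(self, shape: tuple) -> "Tensor":
    # Stand-in body; the real method delegates to deepx.nn.functional.expand.
    return Tensor(shape)

t = Tensor((3, 1))
print(t.expand((2, 3, 4)).shape)  # (2, 3, 4)
```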
3 files renamed without changes.

front/py/examples/2_ir/2_init.dot

Lines changed: 0 additions & 53 deletions
This file was deleted.
