
Commit a9fb13b

rsqrt: implementing a single rsqrt requires this many changes
1 parent 9308ed6 commit a9fb13b

13 files changed

Lines changed: 423 additions & 55 deletions


doc/deepxIR/func.md

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+`torch.rsqrt` is a PyTorch function that computes the reciprocal of the square root of each element of the input tensor. Mathematically, for a number \(x\), \(\text{rsqrt}(x)\) is defined as:
+
+\(\text{rsqrt}(x)=\frac{1}{\sqrt{x}}\)
+
+where \(\sqrt{x}\) denotes the square root of \(x\).
+
+For example, for the tensor `x = torch.tensor([4., 9., 16.])`, `torch.rsqrt(x)` returns `tensor([0.5000, 0.3333, 0.2500])`, the reciprocals of the square roots of \(4\), \(9\), and \(16\).
+
+Numerically, `torch.rsqrt` follows floating-point arithmetic: it computes the square root first and then takes the reciprocal, so the finite precision of floating-point representation can introduce small rounding errors.
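
A quick way to sanity-check this definition is to compare `torch.rsqrt` against the explicit `1 / sqrt` composition; a minimal sketch, assuming a standard PyTorch installation:

    import torch

    x = torch.tensor([4., 9., 16.])
    print(torch.rsqrt(x))  # tensor([0.5000, 0.3333, 0.2500])
    # equal to the explicit composition, up to float rounding
    print(torch.allclose(torch.rsqrt(x), 1.0 / torch.sqrt(x)))  # True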

excuter/common/src/deepx/mem/mem.hpp

Lines changed: 4 additions & 0 deletions
@@ -129,6 +129,10 @@ namespace deepx::mem
         {
             return mem.find(name) != mem.end();
         }
+        bool existarg(const string &name) const
+        {
+            return args.find(name) != args.end();
+        }
 
         template <typename T>
         shared_ptr<Tensor<T>> gettensor(const string &name) const

excuter/common/src/deepx/op/op.hpp

Lines changed: 9 additions & 0 deletions
@@ -72,6 +72,15 @@ namespace deepx::op
        {
            return deepx::dtype<T>::name();
        }
+        T getarg(int idx, mem::Mem &mem){
+            auto x = T(0);
+            if (mem.existarg(this->args[idx])){
+                x = mem.getarg<T>(this->args[idx]);
+            }else{
+                x = T(std::stof(this->args[idx].c_str()));
+            }
+            return x;
+        }
    };
 }
 #endif
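
The new getarg helper lets an op argument be either the name of a scalar already registered in Mem or an inline numeric literal; when the name is not registered, the string is parsed with std::stof. A rough standalone Python analogy of that resolution order (an illustration only, not code from this repository; resolve_scalar and the dict-based mem are hypothetical):

    def resolve_scalar(arg: str, mem: dict) -> float:
        # Prefer a named scalar stored in memory; otherwise parse the literal,
        # mirroring the existarg/getarg-then-stof fallback in OpT<T>::getarg.
        if arg in mem:
            return float(mem[arg])
        return float(arg)

    print(resolve_scalar("lr", {"lr": 0.01}))  # 0.01 (named scalar)
    print(resolve_scalar("2.5", {}))           # 2.5  (inline literal)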

excuter/op-mem-ompsimd/src/deepx/op/elementwise.hpp

Lines changed: 64 additions & 24 deletions
@@ -46,7 +46,8 @@ namespace deepx::op
             deepx::tensorfunc::add(*b_grad, *c_grad, *b_grad); // b_grad += c_grad
         }
     };
-
+
+    // Add_scalar
     template <typename T>
     class Add_scalar : public OpT<T>
     {
@@ -63,10 +64,10 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            auto a = mem.gettensor<T>(this->args[0]);
-            auto b = mem.getarg<T>(this->args[1]);
-            auto c = mem.gettensor<T>(this->returns[0]);
-            deepx::tensorfunc::add(*a, b, *c);
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            deepx::tensorfunc::add(*A, b, *C);
         }
         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
@@ -169,16 +170,16 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            auto a = mem.gettensor<T>(this->args[0]).get();
-            auto b = mem.getarg<T>(this->args[1]);
-            auto c = mem.gettensor<T>(this->returns[0]).get();
-            deepx::tensorfunc::mul(*a, b, *c);
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            deepx::tensorfunc::mul(*A, b, *C);
         }
         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
         {
             // needs the scalar input b from the forward pass
-            auto b = mem.getarg<T>(this->args[1]); // get scalar b
+            auto b = this->getarg(1, mem);
             auto a_grad = mem.gettensor<T>(this->args_grad[0]).get();
             auto c_grad = mem.gettensor<T>(this->returns_grad[0]).get();
 
@@ -235,6 +236,7 @@ namespace deepx::op
     };
 
     // Div_scalar does not reuse Mul_scalar, to avoid the instability of Mul_scalar(1/b) when b is close to 0
+    // A/b = C
     template <typename T>
     class Div_scalar : public OpT<T>
     {
@@ -251,25 +253,16 @@ namespace deepx::op
         // Verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            if (mem.existstensor(this->args[0])){
-                // C = A/b
-                auto A = mem.gettensor<T>(this->args[0]).get();
-                auto b = mem.getarg<T>(this->args[1]);
-                auto C = mem.gettensor<T>(this->returns[0]).get();
-                tensorfunc::div_scalar(*A, b, *C); // use division directly
-            }else{
-                // C = a/B
-                auto a = mem.getarg<T>(this->args[0]);
-                auto B = mem.gettensor<T>(this->args[1]).get();
-                auto C = mem.gettensor<T>(this->returns[0]).get();
-                tensorfunc::div_scalar(a, *B, *C); // use division directly
-            }
+            auto A = mem.gettensor<T>(this->args[0]).get();
+            auto b = this->getarg(1, mem);
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            tensorfunc::div_scalar(*A, b, *C); // use division directly
         }
 
         // Verified, 2025-02-19, lipeng
         void backward(mem::Mem &mem) override
         {
-            auto b = mem.getarg<T>(this->args[1]); // get scalar b
+            auto b = this->getarg(1, mem);
             auto a_grad = mem.gettensor<T>(this->args_grad[0]).get();
             auto c_grad = mem.gettensor<T>(this->returns_grad[0]).get();
 
@@ -280,6 +273,53 @@ namespace deepx::op
             // scalar b needs no gradient
         }
     };
+
+
+    template <typename T>
+    class RDiv_scalar : public OpT<T>
+    {
+    public:
+        RDiv_scalar(){
+            this->init("rdiv_scalar", dtype<T>::name(), {}, {}, false, {}, {});
+        }
+        RDiv_scalar(vector<string> args, vector<string> returns, bool require_grad = false, vector<string> args_grad = {}, vector<string> returns_grad = {}){
+            this->init("rdiv_scalar", dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
+        }
+        RDiv_scalar(initializer_list<string> args, initializer_list<string> returns, bool require_grad = false, initializer_list<string> args_grad = {}, initializer_list<string> returns_grad = {}){
+            this->init("rdiv_scalar", dtype<T>::name(), args, returns, require_grad, args_grad, returns_grad);
+        }
+
+        void forward(mem::Mem &mem) override
+        {
+            // C = a/B
+            auto a = this->getarg(0, mem);
+            auto B = mem.gettensor<T>(this->args[1]).get();
+            auto C = mem.gettensor<T>(this->returns[0]).get();
+            tensorfunc::div_scalar(a, *B, *C); // use division directly
+
+        }
+
+        // TODO: not verified
+        void backward(mem::Mem &mem) override
+        {
+            // needs the forward-pass inputs
+            auto a = this->getarg(0, mem);
+            auto B = mem.gettensor<T>(this->args[1]).get();
+            auto C = mem.gettensor<T>(this->returns[0]).get(); // C = a/B
+            auto B_grad = mem.gettensor<T>(this->args_grad[1]).get();
+            auto C_grad = mem.gettensor<T>(this->returns_grad[0]).get();
+
+            // Backward pass for scalar division:
+            // for C = a/B,
+            // ∂L/∂B = ∂L/∂C * ∂C/∂B = ∂L/∂C * (-a/B²)
+            //       = -C_grad * (a/B²) = -C_grad * (C/B)
+            auto temp = mem.temptensor<T>(B->shape.shape).get();
+            deepx::tensorfunc::div(*C, *B, *temp); // temp = C/B
+            deepx::tensorfunc::muladd(*C_grad, *temp, T(-1), *B_grad, T(1), *B_grad); // B_grad -= C_grad * temp
+
+            // scalar a needs no gradient
+        }
+    };
 
     template <typename T>
     class Sqrt : public OpT<T>
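
The gradient formula in RDiv_scalar::backward can be checked numerically outside the executor. A small NumPy sketch (illustration only, not part of this commit) compares the closed form -C_grad*(C/B) against a central-difference estimate of ∂C/∂B for C = a/B:

    import numpy as np

    a = 2.0                              # the scalar numerator
    B = np.array([1.5, 3.0, 4.0])        # the tensor denominator
    C = a / B                            # forward: C = a/B
    C_grad = np.array([0.7, -1.2, 0.3])  # pretend upstream gradient dL/dC

    B_grad_formula = -C_grad * (C / B)   # closed form used by RDiv_scalar::backward
    eps = 1e-6
    dC_dB = (a / (B + eps) - a / (B - eps)) / (2 * eps)   # numeric ∂C/∂B ≈ -a/B²
    B_grad_numeric = C_grad * dC_dB
    print(np.allclose(B_grad_formula, B_grad_numeric, atol=1e-6))  # True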

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 5 additions & 0 deletions
@@ -97,6 +97,10 @@ namespace deepx::op
        opfactory.add_op(Div_scalar<float>());
        opfactory.add_op(Div_scalar<double>());
    }
+    void register_rdiv_scalar(OpFactory &opfactory){
+        opfactory.add_op(RDiv_scalar<float>());
+        opfactory.add_op(RDiv_scalar<double>());
+    }
    void register_sqrt(OpFactory &opfactory){
        opfactory.add_op(Sqrt<float>());
        opfactory.add_op(Sqrt<double>());
@@ -113,6 +117,7 @@ namespace deepx::op
        register_mul_scalar(opfactory);
        register_div(opfactory);
        register_div_scalar(opfactory);
+        register_rdiv_scalar(opfactory);
        register_sqrt(opfactory);
        register_exp(opfactory);
    }

front/py/deepx/__init__.py

Lines changed: 4 additions & 14 deletions
@@ -1,23 +1,13 @@
 from .tensor import Tensor,Shape,Device,DeviceType
-from deepx.nn.functional import *
+from deepx.nn.functional import *                     # import all functional functions
+from deepx.nn.functional import __all__ as _func_all  # get functional's export list
+
 __all__ = [
     #tensor
     'Tensor',
     'Shape',
     'Device','DeviceType',
-    #nn.functional
-    #init
-    'full','zeros', 'ones', 'arange', 'rand', 'randn', 'eye',
-    #elementwise
-    "add","sub","mul","div","clamp",
-    #matmul
-    "matmul",
-    #reduce
-    "max","min","sum","prod","mean",
-    #transpose
-    "transpose",
-    #relu
-    "relu",
+    *_func_all
 ]
 
 # to support the usage: import deepx as dx

front/py/deepx/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from .elementwise import add,sub,mul,div,clamp
+from .elementwise import *
 from .new import newtensor
 from .print import printtensor
 from .matmul import matmul
@@ -11,7 +11,7 @@
     "newtensor",
     "printtensor",
     "constant","full","zeros","ones","uniform","arange","rand","randn","eye",
-    "add","sub","mul","div","clamp",
+    "add","sub","mul","div","clamp","exp","sqrt","rsqrt",
     "matmul",
     "max","min","sum","prod","mean",
     "transpose","reshape",

front/py/deepx/nn/functional/elementwise.py

Lines changed: 63 additions & 14 deletions
@@ -9,27 +9,47 @@ def _A_B_elementwiseop_C(
         b: Tensor,
         op:str=None,
         out:Tensor=None):
-    opnode = a.graph.add_op(op)
+    g = a.graph
+    if g is None:
+        g = b.graph
+
+    opnode = g.add_op(op)
     opnode.add_input(a.node)
     opnode.add_input(b.node)
     out.node.add_input(opnode)
-    if a.graph.eager:
+    if g.eager:
         ir=DeepxIR(op, a.dtype, [a.node.name, b.node.name], [out.node.name])
         send(ir)
 def _A_b_elementwiseop_C(
-        a:Tensor,
-        b: Optional[Union[ float, int]] = None,
+        a: Optional[Union[Tensor, float, int]] = None,
+        b: Optional[Union[Tensor, float, int]] = None,
         op:str=None,
         out:Tensor=None):
-    varnode=a.graph.add_var("",b)
-    opnode = a.graph.add_op(op)
-    opnode.add_input(a.node)
-    opnode.add_input(varnode)
+    if isinstance(a,Tensor):
+        g = a.graph
+    else:
+        g = b.graph
+
+    opnode = g.add_op(op)
+    if isinstance(a,Tensor):
+        opnode.add_input(a.node)
+    else:
+        varnode = g.add_var("",a)
+        opnode.add_input(varnode)
+
+    if isinstance(b,Tensor):
+        opnode.add_input(b.node)
+    else:
+        varnode = g.add_var("",b)
+        opnode.add_input(varnode)
+
     out.node.add_input(opnode)
-    if a.graph.eager:
-        varir=DeepxIR("argset", a.dtype, [b], [varnode.name])
-        send(varir)
-        ir=DeepxIR(op, a.dtype, [a.node.name,varnode.name], [out.node.name])
+    if g.eager:
+        ir=None
+        if isinstance(a,Tensor):
+            ir=DeepxIR(op, a.dtype, [a.node.name,b], [out.node.name])
+        else:
+            ir=DeepxIR(op, b.dtype, [a,b.node.name], [out.node.name])
         send(ir)
 #add
 OpNode.register("add")
@@ -75,15 +95,20 @@ def mul(
 #div
 OpNode.register("div")
 OpNode.register("div_scalar")
-
+OpNode.register("rdiv_scalar")
 def div(
     a: Optional[Union[Tensor, float, int]] = None,
     b: Optional[Union[Tensor, float, int]] = None,
     out:Tensor=None):
     if isinstance(b,Tensor) and isinstance(a,Tensor):
         _A_B_elementwiseop_C(a,b,"div",out)
     else:
-        _A_b_elementwiseop_C(a,b,"div_scalar",out)
+        if isinstance(a,Tensor):
+            #C=A/b
+            _A_b_elementwiseop_C(a,b,"div_scalar",out)
+        else:
+            #C=a/B
+            _A_b_elementwiseop_C(a,b,"rdiv_scalar",out)
 
 
 #clamp
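
For reference, the routing rule that div now implements (tensor/tensor → "div", tensor/scalar → "div_scalar", scalar/tensor → "rdiv_scalar") can be stated as a tiny standalone sketch; this only illustrates the dispatch and is not code from the repository:

    def pick_div_op(a_is_tensor: bool, b_is_tensor: bool) -> str:
        # Mirrors the branches in div() above.
        if a_is_tensor and b_is_tensor:
            return "div"          # C = A / B
        if a_is_tensor:
            return "div_scalar"   # C = A / b
        return "rdiv_scalar"      # C = a / B

    print(pick_div_op(True, True))    # div
    print(pick_div_op(True, False))   # div_scalar
    print(pick_div_op(False, True))   # rdiv_scalar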
@@ -118,6 +143,30 @@ def exp(
         ir=DeepxIR("exp", a.dtype, [a.node.name], [out.node.name])
         send(ir)
 
+#sqrt
+OpNode.register("sqrt")
+def sqrt(
+    input:Tensor,
+    out:Optional[Tensor]=None)->Tensor:
+    if out is None:
+        out=Tensor(shape=input.shape, dtype=input.dtype, device=input.device)
+    g=input.graph
+    opnode = g.add_op("sqrt")
+    opnode.add_input(input.node)
+    out.node.add_input(opnode)
+    if g.eager:
+        ir=DeepxIR("sqrt", input.dtype, [input.node.name], [out.node.name])
+        send(ir)
+    return out
+
+def rsqrt(
+    input:Tensor,
+    out:Optional[Tensor]=None)->Tensor:
+    if out is None:
+        out=Tensor(shape=input.shape, dtype=input.dtype, device=input.device)
+    out=1/sqrt(input,out)
+    return out
+
 # OpNode.register("ReLU", 101)
 # OpNode.register("Placeholder", 102)
 # OpNode.register("Neg", 103)
