
Commit 5397848

relu+uniform: tests OK
1 parent 7ec711d commit 5397848

16 files changed

Lines changed: 191 additions & 58 deletions

File tree

excuter/op-mem-ompsimd/src/deepx/op/elementwise.hpp

Lines changed: 13 additions & 4 deletions
@@ -251,10 +251,19 @@ namespace deepx::op
         // verified, 2025-02-19, lipeng
         void forward(mem::Mem &mem) override
         {
-            auto a = mem.gettensor<T>(this->args[0]).get();
-            auto b = mem.getarg<T>(this->args[1]);
-            auto c = mem.gettensor<T>(this->returns[0]).get();
-            deepx::tensorfunc::div(*a, b, *c); // use division directly
+            if (mem.existstensor(this->args[0])){
+                // C = A / b
+                auto A = mem.gettensor<T>(this->args[0]).get();
+                auto b = mem.getarg<T>(this->args[1]);
+                auto C = mem.gettensor<T>(this->returns[0]).get();
+                tensorfunc::div_scalar(*A, b, *C); // use division directly
+            }else{
+                // C = a / B
+                auto a = mem.getarg<T>(this->args[0]);
+                auto B = mem.gettensor<T>(this->args[1]).get();
+                auto C = mem.gettensor<T>(this->returns[0]).get();
+                tensorfunc::div_scalar(a, *B, *C); // use division directly
+            }
         }
 
         // verified, 2025-02-19, lipeng
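The executor now picks a branch by whether args[0] names a tensor stored in Mem. A minimal plain-Python sketch of that dispatch rule (stand-in code, not part of deepx): an argument naming a stored tensor selects C = A / b; anything else is treated as the scalar numerator, giving C = a / B.

    # Stand-in sketch of the mem.existstensor(...) dispatch above.
    def pick_branch(args, stored_tensors):
        if args[0] in stored_tensors:
            return "C = A / b"   # tensor divided by scalar
        return "C = a / B"       # scalar divided by tensor

    print(pick_branch(["T1", "2.0"], {"T1", "T2"}))  # C = A / b
    print(pick_branch(["2.0", "T2"], {"T1", "T2"}))  # C = a / B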

excuter/op-mem-ompsimd/src/deepx/op/opfactory.cpp

Lines changed: 6 additions & 2 deletions
@@ -112,12 +112,16 @@ namespace deepx::op
     }
     //reduce
     void register_reduce(OpFactory &opfactory){
-        opfactory.add_op(Sum<float>());
-        opfactory.add_op(Sum<double>());
         opfactory.add_op(Max<float>());
         opfactory.add_op(Max<double>());
+        opfactory.add_op(Max_scalar<float>());
+        opfactory.add_op(Max_scalar<double>());
         opfactory.add_op(Min<float>());
         opfactory.add_op(Min<double>());
+        opfactory.add_op(Min_scalar<float>());
+        opfactory.add_op(Min_scalar<double>());
+        opfactory.add_op(Sum<float>());
+        opfactory.add_op(Sum<double>());
     }
     int register_all(OpFactory &opfactory){
         register_new(opfactory);

excuter/op-mem-ompsimd/src/deepx/op/reduce.hpp

Lines changed: 25 additions & 4 deletions
@@ -5,6 +5,7 @@
 #include "deepx/tensorfunc/reduce.hpp"
 #include "deepx/tensorfunc/broadcast.hpp"
 #include "deepx/tensorfunc/compare.hpp"
+#include "stdutil/num.hpp"
 
 namespace deepx::op
 {
@@ -82,14 +83,24 @@ namespace deepx::op
 
         void forward(mem::Mem &mem) override{
             auto A=mem.gettensor<T>(this->args[0]);
-            auto b=mem.getarg<T>(this->args[1]);
+            T b;
+            if (!is_float(this->args[1])){
+                b=mem.getarg<T>(this->args[1]);
+            }else{
+                b=T(atof(this->args[1].c_str()));
+            }
             auto output=mem.gettensor<T>(this->returns[0]);
             deepx::tensorfunc::max(*A, b, *output);
         }
 
         void backward(mem::Mem &mem) override{
             auto A=mem.gettensor<T>(this->args[0]);
-            auto b=mem.getarg<T>(this->args[1]);
+            T b;
+            if (!is_float(this->args[1])){
+                b=mem.getarg<T>(this->args[1]);
+            }else{
+                b=T(atof(this->args[1].c_str()));
+            }
             auto A_grad=mem.gettensor<T>(this->args_grad[0]);
             auto output_grad=mem.gettensor<T>(this->returns_grad[0]);
             deepx::tensorfunc::max_grad(*A, b, *A_grad, *output_grad);
@@ -139,14 +150,24 @@ namespace deepx::op
         }
         void forward(mem::Mem &mem) override{
             auto A=mem.gettensor<T>(this->args[0]);
-            auto b=mem.getarg<T>(this->args[1]);
+            T b;
+            if (!is_float(this->args[1])){
+                b=mem.getarg<T>(this->args[1]);
+            }else{
+                b=T(atof(this->args[1].c_str()));
+            }
             auto output=mem.gettensor<T>(this->returns[0]);
             deepx::tensorfunc::min(*A, b, *output);
         }
 
         void backward(mem::Mem &mem) override{
             auto A=mem.gettensor<T>(this->args[0]);
-            auto b=mem.getarg<T>(this->args[1]);
+            T b;
+            if (!is_float(this->args[1])){
+                b=mem.getarg<T>(this->args[1]);
+            }else{
+                b=T(atof(this->args[1].c_str()));
+            }
             auto A_grad=mem.gettensor<T>(this->args_grad[0]);
             auto output_grad=mem.gettensor<T>(this->returns_grad[0]);
             deepx::tensorfunc::min_grad(*A, b, *A_grad, *output_grad);
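The reason for the is_float() branch is visible in the Python front end of this same commit: relu() sends DeepxIR("max_scalar", ..., [t._node.name, 0], ...), so args[1] can arrive either as a numeric literal or as the name of a stored argument. A plain-Python sketch of that resolution rule (is_float itself lives in stdutil/num.hpp and is not shown here):

    # Stand-in for the literal-vs-named-argument handling above.
    def resolve_scalar(arg, named_args):
        try:
            return float(arg)        # numeric literal, e.g. the "0" sent by relu()
        except ValueError:
            return named_args[arg]   # otherwise look up the named argument

    print(resolve_scalar("0", {}))                          # 0.0
    print(resolve_scalar("threshold", {"threshold": 0.5}))  # 0.5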

excuter/op-mem-ompsimd/src/deepx/tensorfunc/elementwise.hpp

Lines changed: 47 additions & 2 deletions
@@ -506,6 +506,8 @@ namespace deepx::tensorfunc
         }
     }
 
+    //div
+    // C = A / B
     template <typename T>
     void div(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C)
     {
@@ -685,9 +687,10 @@ namespace deepx::tensorfunc
         }
     }
 
-
+    //div_scalar
+    // C = A / value
     template <typename T>
-    void div(const Tensor<T> &input, const T value, Tensor<T> &output)
+    void div_scalar(const Tensor<T> &input, const T value, Tensor<T> &output)
     {
         if (input.shape == output.shape)
         {
@@ -726,6 +729,48 @@ namespace deepx::tensorfunc
         }
     }
 
+    //div_scalar
+    // C = value / A
+    template <typename T>
+    void div_scalar(const T value, const Tensor<T> &t, Tensor<T> &output)
+    {
+        if (t.shape == output.shape)
+        {
+            output.shape.rangeParallel(output.shape.dim - 1, [&t, &output, &value](int i)
+            {
+                int shape_last = output.shape[-1];
+                const ScalableTag<T> tag;
+                const size_t lanes = Lanes(tag);
+                size_t j = 0;
+
+                // 1. handle the leading unaligned elements
+                while (j < shape_last && !IsAligned(tag, t.data + i + j)) {
+                    output.data[i + j] = value / t.data[i + j];
+                    ++j;
+                }
+
+                // 2. handle the aligned middle part with SIMD
+                size_t aligned_end = shape_last - (shape_last % lanes);
+                for (; j + lanes <= aligned_end; j += lanes)
+                {
+                    auto vec = Load(tag, t.data + i + j);
+                    auto scalar = Set(tag, value);
+                    auto vec_result = Div(scalar, vec);
+                    Store(vec_result, tag, output.data + i + j);
+                }
+
+                // 3. handle the remaining tail elements
+                for (; j < shape_last; j++)
+                {
+                    output.data[i + j] = value / t.data[i + j];
+                } });
+        }
+        else
+        {
+            throw std::invalid_argument("shape mismatch");
+        }
+    }
+
     template <typename T>
     void sqrt(const Tensor<T> &input, Tensor<T> &output)
     {
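Ignoring the SIMD and alignment machinery, the new overload simply writes value / t[i] into every output slot. A tiny plain-Python reference of that semantics (not deepx code), useful as a cross-check against the vectorized kernel:

    # Reference behaviour of div_scalar(value, t, output): out[i] = value / t[i].
    def div_scalar_ref(value, t):
        return [value / x for x in t]

    print(div_scalar_ref(1.0, [2.0, 4.0, 5.0]))  # [0.5, 0.25, 0.2]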

front/py/deepx/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,15 +2,15 @@
 from .new import newtensor
 from .print import printtensor
 from .matmul import matmul
-from .init import constant,full,zeros,ones,arange,rand,randn,eye
+from .init import constant,full,zeros,ones,uniform_,arange,rand,randn,eye
 from .reduce import max,min,sum,prod,mean
 from .transpose import transpose
 from .activite import relu
 
 __all__ = [
     "newtensor",
     "printtensor",
-    "constant","full","zeros","ones","arange","rand","randn","eye",
+    "constant","full","zeros","ones","uniform_","arange","rand","randn","eye",
     "add","sub","mul","div","clamp",
     "matmul",
     "max","min","sum","prod","mean",

front/py/deepx/nn/functional/activite.py

Lines changed: 8 additions & 1 deletion
@@ -9,4 +9,11 @@ def relu(t: Tensor,inplace:bool=False)->Tensor:
     ir=DeepxIR("max_scalar",t.dtype,[t._node.name,0], [out._node.name])
     send(ir)
     return out
-
+
+# math formula: σ(x) = 1 / (1 + exp(-x))
+def sigmoid(t: Tensor,inplace:bool=False)->Tensor:
+    out=t
+    if not inplace:
+        out=Tensor(shape=t.shape, dtype=t.dtype, device=t.device)
+    out=1/(1+(t*(-1)).exp())
+    return out
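The new sigmoid() is built purely from existing tensor operations, mirroring σ(x) = 1 / (1 + exp(-x)). A plain-Python numeric check of that formula (independent of the Tensor operator overloads it relies on):

    import math

    # Reference values for the composition used in sigmoid() above.
    def sigmoid_ref(x):
        return 1.0 / (1.0 + math.exp(-x))

    print(sigmoid_ref(0.0))             # 0.5
    print(round(sigmoid_ref(2.0), 4))   # 0.8808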

front/py/deepx/nn/functional/elementwise.py

Lines changed: 15 additions & 4 deletions
@@ -58,7 +58,6 @@ def sub(
     else:
         _A_b_elementwiseop_C(a,b,"sub_scalar",out)
 
-
 #mul
 OpNode.register("mul")
 OpNode.register("mul_scalar")
@@ -78,15 +77,15 @@ def mul(
 OpNode.register("div_scalar")
 
 def div(
-    a:Tensor,
+    a: Optional[Union[Tensor, float, int]] = None,
     b: Optional[Union[Tensor, float, int]] = None,
     out:Tensor=None):
-    if isinstance(b,Tensor):
+    if isinstance(b,Tensor) and isinstance(a,Tensor):
         _A_B_elementwiseop_C(a,b,"div",out)
     else:
         _A_b_elementwiseop_C(a,b,"div_scalar",out)
 
-
+
 #clamp
 OpNode.register("clamp")
 def clamp(
@@ -107,6 +106,18 @@ def clamp(
     varir=DeepxIR("clamp", a.dtype, [a.node.name,min,max], [out.node.name])
     send(str(varir))
 
+#exp
+OpNode.register("exp")
+def exp(
+    a:Tensor,
+    out:Tensor=None):
+    opnode = a.graph.add_op("exp")
+    opnode.add_input(a.node)
+    out.node.add_input(opnode)
+    if a.graph.eager:
+        ir=DeepxIR("exp", a.dtype, [a.node.name], [out.node.name])
+        send(ir)
+
 # OpNode.register("ReLU", 101)
 # OpNode.register("Placeholder", 102)
 # OpNode.register("Neg", 103)

front/py/deepx/nn/functional/init.py

Lines changed: 23 additions & 1 deletion
@@ -1,10 +1,16 @@
 from typing import Optional
 
-from deepx.tensor import Tensor
+from deepx import Tensor
+from deepx.autograd.graph import OpNode
 from deepx.nn.deepxir import DeepxIR
 from deepx.scheduler import send
 
 def constant(t:Tensor, fill_value):
+    opnode = t.graph.add_op("constant")
+    opnode.add_input(t.node)
+    argnode=t.graph.add_var('',fill_value)
+    opnode.add_input(argnode)
+    t.node.add_input(opnode)
     if t.graph.eager:
         ir=DeepxIR("constant", t.dtype, [fill_value], [t.node.name])
         send(ir)
@@ -23,6 +29,22 @@ def zeros(*shape, dtype=None, device=None):
 def ones(*size, dtype=None, device=None):
     return full(*size, fill_value=1, dtype=dtype, device=device)
 
+OpNode.register("uniform")
+def uniform_(t:Tensor,low=0, high=1)->Tensor:
+    if low >= high:
+        raise ValueError(f"low({low}) must be less than high({high})")
+    opnode = t.graph.add_op("uniform")
+    opnode.add_input(t.node)
+    arglow=t.graph.add_var('',low)
+    arghigh=t.graph.add_var('',high)
+    opnode.add_input(arglow)
+    opnode.add_input(arghigh)
+    t.node.add_input(opnode)
+    if t.graph.eager:
+        ir=DeepxIR("uniform", t.dtype, [low, high], [t.node.name])
+        send(ir)
+    return t
+
 def rand(*size, dtype=None, device=None):
     #TODO
     pass
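A short usage sketch for the new in-place initializer, assuming an eager graph, a running executor, and that zeros() (defined in this same file) returns a Tensor on the default device:

    from deepx.nn.functional import zeros, uniform_

    w = zeros(3, 4)                    # assumed to allocate a float tensor
    uniform_(w, low=-0.1, high=0.1)    # fills w in place via DeepxIR("uniform", ...)
    # low must be strictly less than high; otherwise ValueError is raised.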

front/py/deepx/nn/modules/activation.py

Lines changed: 9 additions & 28 deletions
@@ -4,33 +4,6 @@
 
 #copy from pytorch
 class ReLU(Module):
-    r"""Applies the rectified linear unit function element-wise.
-
-    :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
-
-    Args:
-        inplace: can optionally do the operation in-place. Default: ``False``
-
-    Shape:
-        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
-        - Output: :math:`(*)`, same shape as the input.
-
-    .. image:: ../scripts/activation_images/ReLU.png
-
-    Examples::
-
-        >>> m = nn.ReLU()
-        >>> input = torch.randn(2)
-        >>> output = m(input)
-
-
-      An implementation of CReLU - https://arxiv.org/abs/1603.05201
-
-        >>> m = nn.ReLU()
-        >>> input = torch.randn(2).unsqueeze(0)
-        >>> output = torch.cat((m(input), m(-input)))
-    """
-
     __constants__ = ["inplace"]
     inplace: bool
 
@@ -43,4 +16,12 @@ def forward(self, input: Tensor) -> Tensor:
 
     def extra_repr(self) -> str:
         inplace_str = "inplace=True" if self.inplace else ""
-        return inplace_str
+        return inplace_str
+
+class Sigmoid(Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.sigmoid(input)
+
front/py/deepx/nn/modules/module.py

Lines changed: 6 additions & 4 deletions
@@ -21,10 +21,12 @@ def _generate_default_name(self) -> str:
         return f"{base_name}_{count}"
 
     def __setattr__(self, name: str, value: Any) -> None:
-        if isinstance(value, Module):
-            self.register_module(name, value)
-        elif isinstance(value, Tensor):
-            self.register_parameter(name, value)
+        if not name.startswith('_'):
+            if isinstance(value, Module):
+                self.register_module(name, value)
+            elif isinstance(value, Tensor):
+                self.register_parameter(name, value)
+        # set the attribute via the parent class to avoid recursion
         super().__setattr__(name, value)
 
     def register_module(self, name: str, module: Optional['Module']) -> None:
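The underscore guard keeps internal bookkeeping attributes (which are themselves assigned through __setattr__) from being routed into register_module()/register_parameter() before the containers backing them exist. A stand-in sketch of the behaviour, not the real Module class:

    # Names starting with "_" are set silently; public names would be registered.
    class _Sketch:
        def __setattr__(self, name, value):
            if not name.startswith('_'):
                print(f"would register {name!r}")
            super().__setattr__(name, value)

    s = _Sketch()
    s._modules = {}       # internal state: no registration attempt
    s.weight = object()   # public attribute: prints "would register 'weight'"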
