Skip to content

Commit c03ed9b

Browse files
committed
sigmoid,swish:ok
1 parent 41537b0 commit c03ed9b

12 files changed

Lines changed: 357 additions & 266 deletions

File tree

excuter/op-mem-ompsimd/src/deepx/op/init.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@ namespace deepx::op{
1919
}
2020
// Fills the tensor named by returns[0] with uniformly distributed values.
// Arguments 0 and 1 supply the lower and upper bound; an optional third
// argument supplies the seed. When no third argument is present the seed
// passed to tensorfunc::uniform is 0.
void forward(mem::Mem &mem) override{
    auto output = mem.gettensor<T>(this->returns[0]).get();
    const T low = this->getarg<T>(0,mem);
    const T high = this->getarg<T>(1,mem);
    // Only read the seed argument when exactly three arguments were given.
    const bool has_seed = (this->args.size() == 3);
    const uint32_t seed = has_seed ? this->getarg<uint32_t>(2,mem) : uint32_t{0};
    tensorfunc::uniform(*output,low,high,seed);
}
3230
void backward(mem::Mem &mem) override{
3331
throw std::runtime_error("Uniform op does not support backward");

excuter/op-mem-ompsimd/src/deepx/tensorfunc/init.hpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace deepx::tensorfunc
1111
{
1212
template <typename T>
13-
void uniform(Tensor<T> &tensor,const T low = 0,const T high = 1)
13+
void uniform(Tensor<T> &tensor, const T low = 0, const T high = 1, const unsigned int seed = 0)
1414
{
1515
std::uniform_real_distribution<double> distribution(low, high);
1616
std::random_device rd;
@@ -20,7 +20,17 @@ namespace deepx::tensorfunc
2020
std::vector<std::default_random_engine> generators(num_threads);
2121
for (int i = 0; i < num_threads; ++i)
2222
{
23-
generators[i].seed(rd());
23+
if (seed == 0)
24+
{
25+
// 使用随机设备生成种子
26+
std::random_device rd;
27+
generators[i].seed(rd());
28+
}
29+
else
30+
{
31+
// 使用主seed和线程ID生成确定性种子
32+
generators[i].seed(seed + i);
33+
}
2434
}
2535

2636
#pragma omp parallel for
@@ -32,18 +42,16 @@ namespace deepx::tensorfunc
3242
}
3343

3444
template <typename T>
35-
void constant(Tensor<T> &tensor,const T value)
45+
void constant(Tensor<T> &tensor, const T value)
3646
{
3747
std::fill(tensor.data, tensor.data + tensor.shape.size, value);
3848
}
3949

4050
template <typename T>
41-
void arange(Tensor<T> &tensor, const T start,const T step = 1)
51+
void arange(Tensor<T> &tensor, const T start, const T step = 1)
4252
{
4353
tensor.shape.rangeParallel(tensor.shape.dim, [&](int idx_linear)
44-
{
45-
tensor.data[idx_linear] = start + (idx_linear)*step;
46-
});
54+
{ tensor.data[idx_linear] = start + (idx_linear)*step; });
4755
}
4856

4957
// template <typename T>

front/py/deepx/nn/functional/activite.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def sigmoid(
3636
else:
3737
outtensor=out
3838
from .elementwise import exp
39-
outtensor=1/((t*(-1)+1).exp())
39+
outtensor = 1 / ((t*-1).exp()+1)
4040
return outtensor
4141

4242
def swish(
@@ -56,32 +56,3 @@ def swish(
5656
输出张量
5757
"""
5858
return x*sigmoid(x*beta,out=out)
59-
60-
def swiglu(
61-
x: Tensor,
62-
W: Tensor, # 第一个投影矩阵
63-
V: Tensor, # 第二个投影矩阵
64-
beta: float = 1.0, # swish函数的缩放因子
65-
out: Union[Tensor,str] = '') -> Tensor:
66-
"""SwiGLU激活函数
67-
68-
.. math::
69-
\text{SwiGLU}(x, W, V) = \text{swish}(xW) \odot (xV)
70-
71-
其中:
72-
- :math:`\odot` 表示逐元素乘法
73-
- :math:`\text{swish}(x)` 是swish激活函数
74-
- :math:`W` 和 :math:`V` 是投影矩阵
75-
76-
Args:
77-
x: 输入张量
78-
W: 第一个投影矩阵
79-
V: 第二个投影矩阵
80-
beta: swish函数的缩放因子
81-
out: 输出张量或名称
82-
83-
Returns:
84-
输出张量
85-
"""
86-
result=swish(x@W,beta=beta).mul(x@V,out=out)
87-
return result

front/py/deepx/nn/functional/elementwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def sub(
9595
if isinstance(b,Tensor):
9696
return _A_B_elementwiseop_C(a,b,"sub",out)
9797
else:
98-
return _A_b_elementwiseop_C(a,b,"sub_scalar",out)
98+
return _A_b_elementwiseop_C(a,b*-1,"add_scalar",out)
9999

100100
#mul
101101
OpNode.register("mul")

front/py/deepx/nn/functional/init.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,21 @@ def arange(start=0, end=None, step=1,dtype=None, device=None,name:Union[Tensor,s
5454
return outtensor
5555

5656
OpNode.register("uniform")
57-
def uniform(t:Tensor,low=0, high=1,seed:int=0)->Tensor:
    """Record (and, in eager mode, dispatch) a "uniform" init op on ``t``.

    Args:
        t: tensor to initialise; must not be None.
        low: lower bound of the distribution; must be strictly below ``high``.
        high: upper bound of the distribution.
        seed: seed forwarded to the backend. Pass ``None`` to omit the seed
            argument entirely.

    Returns:
        The same tensor ``t``.

    Raises:
        ValueError: if ``low >= high`` or ``t`` is None.
    """
    if low >= high:
        raise ValueError(f"low({low})必须小于high({high})")
    if t is None:
        raise ValueError("t不能为None")
    g=t.graph

    # Build the argument list once so the graph node and the eager IR always
    # agree; previously the IR carried `seed` even when it was None while the
    # graph node skipped it.
    args=[low, high]
    if seed is not None:
        args.append(seed)

    opnode = g.add_op("uniform")
    for arg in args:
        opnode.add_input(g.add_var('',arg))
    t.node.add_input(opnode)
    if t.graph.eager:
        ir=DeepxIR("uniform", t.dtype, args, [t.node.name])
        send(ir)
    return t
7374

front/py/deepx/nn/modules/activation.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from deepx.tensor import Tensor
1+
from typing import Union
2+
from deepx import Tensor,ones
23
import deepx.nn.functional as F
34
from .module import Module
45

@@ -24,4 +25,32 @@ def __init__(self):
2425

2526
def forward(self, input: Tensor) -> Tensor:
2627
return F.sigmoid(input)
27-
28+
29+
class Swish(Module):
    """Module wrapper around the functional ``F.swish`` activation."""

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``F.swish`` to ``input`` with its default arguments."""
        activated = F.swish(input)
        return activated
35+
36+
37+
class Swiglu(Module):
    """SwiGLU activation module: ``swish(x @ W) * (x @ V)``.

    ``W`` and ``V`` are the two projection matrices; both are created here as
    ones-tensors of shape (1, 1) named after the module.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): relies on the base Module providing `full_name` by the
        # time __init__ runs -- confirm against Module.
        self.W = ones(shape=(1,1),name=self.full_name+"_W")
        self.V = ones(shape=(1,1),name=self.full_name+"_V")

    # Declared static because the signature takes no `self`. Without this,
    # `self.swiglu(input, self.W, self.V)` bound the module instance to `x`
    # and shifted every other argument by one position.
    @staticmethod
    def swiglu(
        x: Tensor,
        W: Tensor,  # first projection matrix
        V: Tensor,  # second projection matrix
        beta: float = 1.0,  # scaling factor passed to swish
        out: Union[Tensor,str] = '') -> Tensor:
        """Compute ``swish(x @ W, beta) * (x @ V)``, writing into ``out``."""
        from deepx.nn.functional import swish
        result=swish(x@W,beta=beta).mul(x@V,out=out)
        return result

    def forward(self, input: Tensor) -> Tensor:
        """Apply SwiGLU to ``input`` using this module's ``W`` and ``V``."""
        return self.swiglu(input,self.W,self.V)
56+

front/py/deepx/tensor/init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def ones_(self):
1616
constant_func(self,value=1)
1717

1818
@tensor_method
def uniform_(self,low=0, high=1,seed:int=0):
    """Fill this tensor via the functional ``uniform`` op.

    ``low``, ``high`` and ``seed`` are forwarded unchanged; nothing is
    returned from this wrapper.
    """
    from deepx.nn.functional import uniform as uniform_func
    forwarded = dict(low=low, high=high, seed=seed)
    uniform_func(self, **forwarded)
2222

2323
@tensor_method
2424
def rand_(self):

front/py/examples/3_functional/1_sigmoid.dot

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,55 @@
22
digraph {
33
rankdir=TB
44
node [shape=record]
5-
124458605174432 [label="x
5+
130115172244864 [label="x
66
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7-
124458605170256 [label="var_1
8-
-1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
9-
124458605174000 [label="var_2
10-
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
11-
124458605169248 [label=uniform color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
12-
124458605656912 [label="out
13-
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
14-
124458605657104 [label="tensor_3
7+
130117430101712 [label=reshape color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
8+
130117430101376 [label="vector_1
9+
(3, 4, 5)" color=darkseagreen fillcolor=honeydew fontname="Sans-Serif" labeljust=l shape=box style=filled]
10+
130115170569184 [label=div_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
11+
130117430113280 [label="var_1
12+
10.0" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
13+
130115170569280 [label=add_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
14+
130115170569136 [label="var_2
15+
-3.0" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
16+
130115170569088 [label="out
1517
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
16-
124458605657200 [label=mul_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
17-
124458605657152 [label="var_3
18+
130115170569520 [label=mul_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
19+
130115170569472 [label="var_3
1820
-1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
19-
124458605657488 [label="tensor_4
21+
130115170569664 [label="tensor_3
2022
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
21-
124458605657584 [label=add_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
22-
124458605657536 [label="var_4
23-
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
24-
124458605657872 [label="tensor_5
23+
130115170569904 [label=exp color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
24+
130115170570000 [label="tensor_4
2525
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
26-
124458605657968 [label=exp color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
27-
124458605904032 [label="tensor_6
26+
130115170576384 [label=add_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
27+
130115170576336 [label="var_4
28+
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
29+
130115170576528 [label="tensor_5
2830
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
29-
124458605904128 [label=rdiv_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
30-
124458605904080 [label="var_5
31+
130115170576768 [label=rdiv_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
32+
130115170576720 [label="var_5
3133
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
32-
124458605169248 -> 124458605174432 [arrowsize=0.8 color=gray40 penwidth=1.2]
33-
124458605170256 -> 124458605169248 [arrowsize=0.8 color=gray40 penwidth=1.2]
34-
124458605174000 -> 124458605169248 [arrowsize=0.8 color=gray40 penwidth=1.2]
35-
124458605657200 -> 124458605657104 [arrowsize=0.8 color=gray40 penwidth=1.2]
36-
124458605174432 -> 124458605657200 [arrowsize=0.8 color=gray40 penwidth=1.2]
37-
124458605657152 -> 124458605657200 [arrowsize=0.8 color=gray40 penwidth=1.2]
38-
124458605657584 -> 124458605657488 [arrowsize=0.8 color=gray40 penwidth=1.2]
39-
124458605657104 -> 124458605657584 [arrowsize=0.8 color=gray40 penwidth=1.2]
40-
124458605657536 -> 124458605657584 [arrowsize=0.8 color=gray40 penwidth=1.2]
41-
124458605657968 -> 124458605657872 [arrowsize=0.8 color=gray40 penwidth=1.2]
42-
124458605657488 -> 124458605657968 [arrowsize=0.8 color=gray40 penwidth=1.2]
43-
124458605904128 -> 124458605904032 [arrowsize=0.8 color=gray40 penwidth=1.2]
44-
124458605904080 -> 124458605904128 [arrowsize=0.8 color=gray40 penwidth=1.2]
45-
124458605657872 -> 124458605904128 [arrowsize=0.8 color=gray40 penwidth=1.2]
34+
130115170576912 [label="tensor_6
35+
(3, 4, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
36+
130117430101712 -> 130115172244864 [arrowsize=0.8 color=gray40 penwidth=1.2]
37+
130115170569184 -> 130115172244864 [arrowsize=0.8 color=gray40 penwidth=1.2]
38+
130115170569280 -> 130115172244864 [arrowsize=0.8 color=gray40 penwidth=1.2]
39+
130115172244864 -> 130117430101712 [arrowsize=0.8 color=gray40 penwidth=1.2]
40+
130117430101376 -> 130117430101712 [arrowsize=0.8 color=gray40 penwidth=1.2]
41+
130115172244864 -> 130115170569184 [arrowsize=0.8 color=gray40 penwidth=1.2]
42+
130117430113280 -> 130115170569184 [arrowsize=0.8 color=gray40 penwidth=1.2]
43+
130115172244864 -> 130115170569280 [arrowsize=0.8 color=gray40 penwidth=1.2]
44+
130115170569136 -> 130115170569280 [arrowsize=0.8 color=gray40 penwidth=1.2]
45+
130115172244864 -> 130115170569520 [arrowsize=0.8 color=gray40 penwidth=1.2]
46+
130115170569472 -> 130115170569520 [arrowsize=0.8 color=gray40 penwidth=1.2]
47+
130115170569520 -> 130115170569664 [arrowsize=0.8 color=gray40 penwidth=1.2]
48+
130115170569664 -> 130115170569904 [arrowsize=0.8 color=gray40 penwidth=1.2]
49+
130115170569904 -> 130115170570000 [arrowsize=0.8 color=gray40 penwidth=1.2]
50+
130115170570000 -> 130115170576384 [arrowsize=0.8 color=gray40 penwidth=1.2]
51+
130115170576336 -> 130115170576384 [arrowsize=0.8 color=gray40 penwidth=1.2]
52+
130115170576384 -> 130115170576528 [arrowsize=0.8 color=gray40 penwidth=1.2]
53+
130115170576720 -> 130115170576768 [arrowsize=0.8 color=gray40 penwidth=1.2]
54+
130115170576528 -> 130115170576768 [arrowsize=0.8 color=gray40 penwidth=1.2]
55+
130115170576768 -> 130115170576912 [arrowsize=0.8 color=gray40 penwidth=1.2]
4656
}

0 commit comments

Comments
 (0)