Skip to content

Commit 1a67ac1

Browse files
committed
py: resolve the naming scheme for Tensors inside Modules (module.Tensor name system) [original: py:module.Tensor的name体系解决]
1 parent dc9a7df commit 1a67ac1

File tree

13 files changed

+641
-315
lines changed

13 files changed

+641
-315
lines changed

front/py/deepx/nn/functional/activite.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,36 @@ def sigmoid(
3636
from .elementwise import exp
3737
outtensor=1/(1+(t*(-1)).exp())
3838
return outtensor
39+
40+
def swiglu(
        x: Tensor,
        w: Tensor,              # first projection matrix (gate branch)
        v: Tensor,              # second projection matrix (value branch)
        beta: float = 1.0,      # scaling factor of the Swish gate
        out: Union[Tensor, str] = '') -> Tensor:
    """SwiGLU activation: Swish_beta(x @ w) * (x @ v).

    Args:
        x: input tensor
        w: first projection matrix (gated through Swish)
        v: second projection matrix (multiplied element-wise)
        beta: beta parameter of the Swish gate, defaults to 1.0
        out: graph name for the result when given as a string.
             NOTE(review): a Tensor passed as ``out`` is currently
             ignored — the result is always a freshly built tensor;
             confirm whether in-place output should be supported here.

    Returns:
        Tensor holding Swish(x @ w) * (x @ v).
    """
    # The two linear projections of the input.
    xw = x @ w
    xv = x @ v

    # Swish(xw) = xw * sigmoid(beta * xw).
    # Use `* (-1)` instead of unary `-` for consistency with sigmoid()
    # above — Tensor is not shown to implement __neg__, only scalar mul.
    beta_xw = xw * beta
    sigmoid_beta_xw = 1 / (1 + (beta_xw * (-1)).exp())
    swish = xw * sigmoid_beta_xw

    # Element-wise gating of the second projection.
    result = swish * xv

    # Register the result in the graph under the requested name.
    if isinstance(out, str):
        result.addtograph(out)

    return result

front/py/deepx/nn/functional/elementwise.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,29 @@ def exp(
165165
ir=DeepxIR("exp", a.dtype, [a.node.name], [outtensor.node.name])
166166
send(ir)
167167

168+
#pow
OpNode.register("pow")
def pow(
        a:Tensor,
        b:Union[float,int],
        out:Union[Tensor,str]=''):
    """Element-wise power: raise every element of ``a`` to the scalar ``b``.

    Args:
        a: base tensor
        b: scalar exponent
        out: either a name for a freshly created output tensor (str),
             or an existing Tensor to receive the op node as input.

    Returns:
        The output tensor (new when ``out`` is a str, else ``out`` itself).
    """
    g=a.graph
    opnode = g.add_op("pow")
    opnode.add_input(a.node)
    # The scalar exponent is recorded as an anonymous graph variable.
    opnode.add_input(g.add_var('',b))

    outtensor=None
    if isinstance(out,str):
        outtensor=Tensor(shape=a.shape, dtype=a.dtype, device=a.device)
        outtensor.addtograph(out)
    else:
        outtensor=out
    outtensor.node.add_input(opnode)
    if a.graph.eager:
        # NOTE(review): args mix a node name (str) and the raw scalar b —
        # presumably DeepxIR serializes scalars itself; verify.
        ir=DeepxIR("pow", a.dtype, [a.node.name,b], [outtensor.node.name])
        send(ir)
    # Bug fix: the original fell off the end and returned None, so
    # Tensor.pow()/pow_() (which return this value) produced None.
    return outtensor
190+
168191
#sqrt
169192
OpNode.register("sqrt")
170193
def sqrt(

front/py/deepx/nn/modules/linear.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def __init__(
2020
self.in_features = in_features
2121
self.out_features = out_features
2222
self.weight = Tensor(shape=(out_features, in_features),dtype=dtype)
23-
2423
if bias:
2524
self.bias = Tensor(shape=(out_features,),dtype=dtype)
2625
else:

front/py/deepx/tensor/elementwise.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,22 @@ def exp_(self):
139139
exp_func(self,self)
140140
return self
141141

142+
@tensor_method
def pow(self,
        b:Union[float,int],
        out:Union[Tensor,str]=''):
    """Element-wise power ``self ** b`` via the functional pow op.

    Args:
        b: scalar exponent.
        out: output tensor, or a graph name for a new one.
    """
    from deepx.nn.functional import pow as pow_func
    return pow_func(self, b, out)
149+
150+
@tensor_method
def pow_(self,
        b:Union[float,int]):
    """In-place element-wise power: ``self = self ** b``.

    Args:
        b: scalar exponent.

    Returns:
        self, matching the in-place convention used by exp_() above.
    """
    from deepx.nn.functional import pow as pow_func
    pow_func(self,b,self)
    # exp_() returns self directly; do the same so the chained result is
    # well-defined even if the functional op returns None.
    return self
156+
157+
142158
@tensor_method
143159
def sqrt(self,out:Optional[Union[str]]=None):
144160
result = Tensor(dtype=self.dtype,shape=self.shape)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from deepx.nn.modules import Module
2+
from deepx import Tensor,ones,rsqrt
3+
class LlamaRMSNorm(Module):
    """Root-mean-square layer normalization as used by Llama.

    LlamaRMSNorm is equivalent to T5LayerNorm: it rescales by the RMS of
    the last dimension (no mean subtraction) and applies a learned gain.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the norm with a per-channel gain of ones.

        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: small constant added to the variance for stability.
        """
        super().__init__()
        self.weight = ones(hidden_size)      # learnable gain, init to 1
        self.variance_epsilon = eps          # numerical-stability term

    def forward(self, hidden_states:Tensor):
        # Mean of squares over the last axis, kept for broadcasting.
        mean_square = hidden_states.pow(2).mean(-1, keepdim=True)
        # Scale inputs by 1/sqrt(ms + eps), then apply the gain.
        normed = hidden_states * rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

front/py/examples/2_ir/5_reduce.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55
t.addtograph("t")
66
t.uniform_(low=-1,high=1)
77
print((t))
8-
s=sum(t,dims=[0,2],out="s")
8+
s=sum(t,dim=[0,2],out="s")
99
print(s)
1010

1111

1212
t1=ones(4,5,6,name="t1")
1313
print(t1)
14-
t2=sum(t1,dims=[0,1],out='t2')
14+
t2=sum(t1,dim=[0,1],out='t2')
1515
print(t2)
1616

1717
t3=arange(0,120,1,name="t3")
1818
t3.reshape_(4,5,6)
1919
print(t3)
2020

21-
t3_mean=mean(t3,dims=[0,1],out='t3_mean')
21+
t3_mean=mean(t3,dim=[0,1],out='t3_mean')
2222
print(t3_mean)
2323

2424
gviz=t.graph.to_dot()
25-
gviz.render('sum.dot',format='svg')
25+
gviz.render('reduce.dot',format='svg')

front/py/examples/2_ir/math.dot

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Computational Graph
// Generated artifact (examples/2_ir/math.dot) — node keys are opaque
// numeric ids (presumably Python object ids; regenerated on each run).
// Box styles: skyblue = tensors, gray = ops, orange = scalar vars.
digraph {
	rankdir=TB
	node [shape=record]
	// --- nodes ---
	128076282530160 [label="tensor_1
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076282537744 [label=constant color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
	128076282538512 [label="var_1
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283260656 [label="tensor_2
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283260752 [label=add_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
	128076283260704 [label="var_2
3" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283260848 [label="tensor_3
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283261088 [label=sqrt color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
	128076283261184 [label="tensor_4
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283261424 [label=sqrt color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
	128076283261664 [label="tensor_5
(2, 3)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
	128076283261760 [label=rdiv_scalar color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
	128076283261712 [label="var_3
1" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
	// --- edges (op -> produced tensor, inputs -> op) ---
	128076282537744 -> 128076282530160 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076282538512 -> 128076282537744 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283260752 -> 128076283260656 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076282530160 -> 128076283260752 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283260704 -> 128076283260752 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283261088 -> 128076283260848 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283260656 -> 128076283261088 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283261424 -> 128076283261184 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283260656 -> 128076283261424 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283261760 -> 128076283261664 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283261712 -> 128076283261760 [arrowsize=0.8 color=gray40 penwidth=1.2]
	128076283261184 -> 128076283261760 [arrowsize=0.8 color=gray40 penwidth=1.2]
}
Lines changed: 171 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)