Skip to content

Commit 0ec68c1

Browse files
committed
linear:linear验证
1 parent 184c4b7 commit 0ec68c1

6 files changed

Lines changed: 219 additions & 268 deletions

File tree

front/py/deepx/nn/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
__all__ = [
1111
"newtensor",
1212
"printtensor",
13-
"constant","full","zeros","ones","uniform","arange","rand","randn","eye","kaiming_uniform_",
13+
"constant","full","zeros","ones","uniform","arange","rand","randn","eye","kaiming_uniform_","calculate_fan_in_and_fan_out",
1414
"add","sub","mul","div","clamp","exp","sqrt","rsqrt",
1515
"matmul",
1616
"max","min","sum","prod","mean",

front/py/deepx/nn/functional/init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def eye(
7272
pass
7373

7474

75-
def _calculate_fan_in_and_fan_out(tensor:Tensor):
75+
def calculate_fan_in_and_fan_out(tensor:Tensor):
7676
dimensions = tensor.dim()
7777
if dimensions < 2:
7878
raise ValueError(
@@ -98,7 +98,7 @@ def _calculate_correct_fan(tensor:Tensor, mode:str):
9898
if mode not in valid_modes:
9999
raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
100100

101-
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
101+
fan_in, fan_out = calculate_fan_in_and_fan_out(tensor)
102102
return fan_in if mode == "fan_in" else fan_out
103103

104104
#copy from torch.nn/init.py

front/py/deepx/nn/modules/linear.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .module import Module
2-
from deepx.tensor import Tensor
3-
2+
from deepx import Tensor
3+
from deepx.nn.functional import uniform,kaiming_uniform_,calculate_fan_in_and_fan_out
4+
import math
45
class Linear(Module):
56
r'''
67
copy from torch.nn.Linear
@@ -30,14 +31,18 @@ def reset_parameters(self) -> None:
3031
# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
3132
# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
3233
# https://github.com/pytorch/pytorch/issues/57109
33-
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
34+
kaiming_uniform_(self.weight, a=math.sqrt(5))
3435
if self.bias is not None:
35-
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
36+
fan_in, _ = calculate_fan_in_and_fan_out(self.weight)
3637
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
37-
init.uniform_(self.bias, -bound, bound)
38+
uniform(self.bias, -bound, bound)
3839

3940
def forward(self, input: Tensor) -> Tensor:
40-
return F.linear(input, self.weight, self.bias)
41+
#`y = xA^T + b`
42+
if self.bias is None:
43+
return input @ self.weight.T
44+
else:
45+
return input @ self.weight.T + self.bias
4146

4247
def extra_repr(self) -> str:
4348
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

front/py/examples/3_module/1_linear.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
from deepx.nn.modules import Linear, Module
22
from deepx import Tensor
3-
class Net(Module):
4-
def __init__(self):
5-
super().__init__()
6-
self.fc1 = Linear(64, 32)
7-
self.fc2 = Linear(32, 2)
8-
def forward(self, x):
9-
x = self.fc1(x)
10-
x = self.fc2(x)
11-
return x
123

13-
net = Net()
4+
net = Linear(64, 4)
145
input=Tensor(shape=[1,64])
156
out=net.forward(input)
7+
print(out)
168
net.graph.to_dot().render('linear.dot',format='svg')
179

1810
for name, param in net.named_parameters():

front/py/examples/3_module/linear.dot

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,54 +2,46 @@
22
digraph {
33
rankdir=TB
44
node [shape=record]
5-
135820316918448 [label="tensor_1
6-
(32, 64)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7-
135820316918736 [label="tensor_2
8-
(32,)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
9-
135820316918592 [label="tensor_3
10-
(2, 32)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
11-
135820316918928 [label="tensor_4
12-
(2,)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
13-
135820316919120 [label="tensor_5
5+
126836298434400 [label="tensor_1
6+
(4, 64)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7+
126836299154480 [label="tensor_2
8+
(4,)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
9+
126836299154432 [label="var_1
10+
-0.12499999999999999" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
11+
126836299154624 [label="var_2
12+
0.12499999999999999" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
13+
126836299154000 [label=uniform color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
14+
126836299154528 [label="var_3
15+
-0.125" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
16+
126836299154336 [label="var_4
17+
0.125" color=orange fillcolor=moccasin fontname="Sans-Serif" labeljust=l shape=box style=filled]
18+
126836299154720 [label=uniform color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
19+
126836299154912 [label="tensor_3
1420
(1, 64)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
15-
135820316919360 [label="tensor_6
16-
(64, 32)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
17-
135820316919312 [label="vector_1
21+
126836299155152 [label="tensor_4
22+
(64, 4)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
23+
126836299155104 [label="vector_1
1824
(1, 0)" color=darkseagreen fillcolor=honeydew fontname="Sans-Serif" labeljust=l shape=box style=filled]
19-
135820316919456 [label=transpose color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
20-
135820316919648 [label="tensor_7
21-
(1, 32)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
22-
135820316919744 [label=matmul color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
23-
135820316920608 [label="tensor_8
24-
(1, 32)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
25-
135820316920704 [label=add color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
26-
135820316920944 [label="tensor_9
27-
(32, 2)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
28-
135820316920896 [label="vector_2
29-
(1, 0)" color=darkseagreen fillcolor=honeydew fontname="Sans-Serif" labeljust=l shape=box style=filled]
30-
135820316921040 [label=transpose color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
31-
135820316921232 [label="tensor_10
32-
(1, 2)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
33-
135820316921328 [label=matmul color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
34-
135820316921520 [label="tensor_11
35-
(1, 2)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
36-
135820316921616 [label=add color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
37-
135820316919456 -> 135820316919360 [arrowsize=0.8 color=gray40 penwidth=1.2]
38-
135820316918448 -> 135820316919456 [arrowsize=0.8 color=gray40 penwidth=1.2]
39-
135820316919312 -> 135820316919456 [arrowsize=0.8 color=gray40 penwidth=1.2]
40-
135820316919744 -> 135820316919648 [arrowsize=0.8 color=gray40 penwidth=1.2]
41-
135820316919120 -> 135820316919744 [arrowsize=0.8 color=gray40 penwidth=1.2]
42-
135820316919360 -> 135820316919744 [arrowsize=0.8 color=gray40 penwidth=1.2]
43-
135820316920704 -> 135820316920608 [arrowsize=0.8 color=gray40 penwidth=1.2]
44-
135820316919648 -> 135820316920704 [arrowsize=0.8 color=gray40 penwidth=1.2]
45-
135820316918736 -> 135820316920704 [arrowsize=0.8 color=gray40 penwidth=1.2]
46-
135820316921040 -> 135820316920944 [arrowsize=0.8 color=gray40 penwidth=1.2]
47-
135820316918592 -> 135820316921040 [arrowsize=0.8 color=gray40 penwidth=1.2]
48-
135820316920896 -> 135820316921040 [arrowsize=0.8 color=gray40 penwidth=1.2]
49-
135820316921328 -> 135820316921232 [arrowsize=0.8 color=gray40 penwidth=1.2]
50-
135820316920608 -> 135820316921328 [arrowsize=0.8 color=gray40 penwidth=1.2]
51-
135820316920944 -> 135820316921328 [arrowsize=0.8 color=gray40 penwidth=1.2]
52-
135820316921616 -> 135820316921520 [arrowsize=0.8 color=gray40 penwidth=1.2]
53-
135820316921232 -> 135820316921616 [arrowsize=0.8 color=gray40 penwidth=1.2]
54-
135820316918928 -> 135820316921616 [arrowsize=0.8 color=gray40 penwidth=1.2]
25+
126836299155248 [label=transpose color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
26+
126836299155440 [label="tensor_5
27+
(1, 4)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
28+
126836299155536 [label=matmul color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
29+
126836299155728 [label="tensor_6
30+
(1, 4)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
31+
126836299155824 [label=add color=darkslategray fillcolor=lightgray fontname="Courier Bold" labeljust=l shape=box style=filled]
32+
126836299154000 -> 126836298434400 [arrowsize=0.8 color=gray40 penwidth=1.2]
33+
126836299154720 -> 126836299154480 [arrowsize=0.8 color=gray40 penwidth=1.2]
34+
126836299154432 -> 126836299154000 [arrowsize=0.8 color=gray40 penwidth=1.2]
35+
126836299154624 -> 126836299154000 [arrowsize=0.8 color=gray40 penwidth=1.2]
36+
126836299154528 -> 126836299154720 [arrowsize=0.8 color=gray40 penwidth=1.2]
37+
126836299154336 -> 126836299154720 [arrowsize=0.8 color=gray40 penwidth=1.2]
38+
126836299155248 -> 126836299155152 [arrowsize=0.8 color=gray40 penwidth=1.2]
39+
126836298434400 -> 126836299155248 [arrowsize=0.8 color=gray40 penwidth=1.2]
40+
126836299155104 -> 126836299155248 [arrowsize=0.8 color=gray40 penwidth=1.2]
41+
126836299155536 -> 126836299155440 [arrowsize=0.8 color=gray40 penwidth=1.2]
42+
126836299154912 -> 126836299155536 [arrowsize=0.8 color=gray40 penwidth=1.2]
43+
126836299155152 -> 126836299155536 [arrowsize=0.8 color=gray40 penwidth=1.2]
44+
126836299155824 -> 126836299155728 [arrowsize=0.8 color=gray40 penwidth=1.2]
45+
126836299155440 -> 126836299155824 [arrowsize=0.8 color=gray40 penwidth=1.2]
46+
126836299154480 -> 126836299155824 [arrowsize=0.8 color=gray40 penwidth=1.2]
5547
}

0 commit comments

Comments
 (0)