Skip to content

Commit f7319de

Browse files
committed
linear:linear验证
1 parent 0e366bb commit f7319de

File tree

7 files changed

+77
-5
lines changed

7 files changed

+77
-5
lines changed

excuter/common/src/deepx/mem/mem.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ namespace deepx::mem
133133
template <typename T>
134134
shared_ptr<Tensor<T>> gettensor(const string &name) const
135135
{
136+
if (mem.find(name)== mem.end())
137+
{
138+
throw std::runtime_error("tensor not found: " + name);
139+
}
136140
auto ptr = mem.at(name);
137141
return std::static_pointer_cast<Tensor<T>>(ptr);
138142
}

front/py/deepx/nn/modules/linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, in_features, out_features, bias=True,dtype:str="float32"):
1313
self.bias = None
1414

1515
def forward(self, input):
16-
output=input.matmul_(self.weight.T)
16+
output=input.matmul(self.weight.T)
1717
if self.bias is not None:
1818
output=output+self.bias
1919
return output

front/py/deepx/nn/modules/module.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
Any, List, overload)
44
from collections import OrderedDict
55
from deepx import Tensor
6-
6+
from deepx.autograd import Graph
77
class Module:
88
def __init__(self, name: Optional[str] = None):
9+
self._graph=Graph.get_default()
910
self._name = name or self._generate_default_name()
1011
self._parent: Optional[Module] = None
1112
self._modules: OrderedDict[str, Module] = OrderedDict()
@@ -20,6 +21,10 @@ def _generate_default_name(self) -> str:
2021
self.__class__._instance_counter += 1
2122
return f"{base_name}_{count}"
2223

24+
@property
25+
def graph(self):
26+
return self._graph
27+
2328
def __setattr__(self, name: str, value: Any) -> None:
2429
if not name.startswith('_'):
2530
if isinstance(value, Module):

front/py/deepx/tensor/tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ def __truediv__(self, other):
109109
def __matmul__(self, other):
110110
return self.matmul(other)
111111

112+
#自动转置最后两个维度,适用于二维矩阵
113+
@property
114+
def T(self) -> str:
115+
return self.transpose(1,0)
116+
112117
def __repr__(self) -> str:
113118
from deepx.nn.functional import printtensor
114119
s=printtensor(self)
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
from deepx.nn.modules import Linear, Module
2-
2+
from deepx import Tensor
33
class Net(Module):
44
def __init__(self):
55
super().__init__()
6-
self.fc1 = Linear(10, 5)
7-
self.fc2 = Linear(5, 1)
6+
self.fc1 = Linear(64, 32)
7+
self.fc2 = Linear(32, 2)
88
def forward(self, x):
99
x = self.fc1(x)
1010
x = self.fc2(x)
1111
return x
1212

1313
net = Net()
14+
input=Tensor(shape=[1,64])
15+
out=net.forward(input)
16+
net.graph.to_dot().render('linear.dot',format='svg')
17+
1418
for name, param in net.named_parameters():
1519
print(f"{name}: {param.shape}")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Computational Graph
2+
digraph {
3+
rankdir=TB
4+
node [shape=record]
5+
125633308838112 [label="tensor_1
6+
(5, 10)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
7+
125633308837296 [label="tensor_2
8+
(5,)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
9+
125633308838160 [label="tensor_3
10+
(1, 5)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
11+
125633308838400 [label="tensor_4
12+
(1,)" color=skyblue fillcolor=aliceblue fontname="Sans-Serif" labeljust=l shape=box style=filled]
13+
}
Lines changed: 41 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)