Skip to content

Commit 0e366bb

Browse files
committed
动态计算图验证:
1 parent 5397848 commit 0e366bb

31 files changed

Lines changed: 394 additions & 119 deletions

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@
22
**/build/
33
.idea
44
**/.idea
5-
**/__pycache__/
5+
**/__pycache__/
6+
**/dist/
7+
**/egg.info/
8+
front/py/deepx/deepx.egg-info/*
9+
*.pdf

front/py/deepx/autograd/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from .graph import Graph
2+
from .graph_viz import to_dot
23
from .node import Node
34
from .nodetype import NodeType
45
from ._datanode import DataNode
56
from ._opnode import OpNode
7+
68
__all__ = [
79
'Graph',
810
'Node',

front/py/deepx/autograd/_datanode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ def __init__(self, name=None, type=None, data=None):
66
super().__init__(name=name, ntype=NodeType.DATA)
77
self._data = data
88
self._type=type
9+
@property
910
def data(self):
1011
return self._data
1112

front/py/deepx/autograd/graph.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,5 +82,10 @@ def add_control_flow(self,name,inputs=[]):
8282
node.add_input(input)
8383
self.nodes.append(node)
8484
return node
85+
def graph_method(f):
86+
setattr(Graph, f.__name__, f)
87+
return f
88+
89+
8590
# 初始化默认图
8691
Graph._default_graph = Graph()

front/py/deepx/autograd/graph_viz.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11
import graphviz
2-
from .graph import Graph
3-
from ._datanode import TensorNode
2+
from .graph import Graph,graph_method
3+
from ._datanode import DataNode
44
from ._opnode import OpNode
5-
from ._controlflownode import ConstArgNode
65

7-
def graph_method(f):
8-
"""装饰器:将函数注册为Graph类的方法"""
9-
# 延迟到模块加载完成后再绑定方法
10-
from .graph import Graph
11-
setattr(Graph, f.__name__, f)
12-
return f
136

147
@graph_method
15-
def to_dot(self):
8+
def to_dot(self)->graphviz.Digraph:
169
"""生成DOT格式的计算图可视化"""
1710
dot = graphviz.Digraph(comment='Computational Graph')
1811
dot.attr(rankdir='TB') # 从上到下的布局
@@ -27,49 +20,51 @@ def to_dot(self):
2720
'labeljust': 'l'
2821
}
2922

30-
if isinstance(node, TensorNode):
31-
# 张量节点:显示形状和梯度信息
32-
label = f"{node.name}\n{node.tensor().shape if node.tensor() else ''}"
33-
attrs.update({
34-
'shape': 'box',
35-
'color': 'skyblue',
36-
'style': 'filled',
37-
'fillcolor': 'aliceblue'
38-
})
23+
if isinstance(node, DataNode):
24+
label=f'{node.name}\n'
25+
match(node._type):
26+
case "var":
27+
label += f"{node.data}"
28+
attrs.update({
29+
'shape': 'box',
30+
'color': 'orange',
31+
'style': 'filled',
32+
'fillcolor': 'moccasin'
33+
})
34+
case "vector":
35+
label += f"{node.data}"
36+
attrs.update({
37+
'shape': 'box',
38+
'color': 'darkseagreen',
39+
'style': 'filled',
40+
'fillcolor': 'honeydew'
41+
})
42+
case "tensor":
43+
label += f"{node.data.shape if node.data else ''}"
44+
attrs.update({
45+
'shape': 'box',
46+
'color': 'skyblue',
47+
'style': 'filled',
48+
'fillcolor': 'aliceblue'
49+
})
50+
3951

4052
elif isinstance(node, OpNode):
4153
# 操作节点:突出显示操作类型
42-
label = node.shortchar()
54+
label = node.name
4355
attrs.update({
4456
'shape': 'box',
4557
'style': 'filled',
4658
'fillcolor': 'lightgray',
4759
'color': 'darkslategray',
4860
'fontname': 'Courier Bold'
4961
})
50-
51-
elif isinstance(node, ConstArgNode):
52-
# 常量参数节点:显示参数值
53-
if node.arg_type == 'int':
54-
value = str(node.get_int())
55-
elif node.arg_type == 'float':
56-
value = f"{node.get_float():.2f}f"
57-
else: # string
58-
value = node.get_string()
59-
label = value
60-
attrs.update({
61-
'shape': 'diamond',
62-
'style': 'filled',
63-
'fillcolor': 'lightyellow',
64-
'color': 'goldenrod'
65-
})
66-
6762
# 添加节点
6863
dot.node(str(id(node)), label, **attrs)
6964

7065
# 添加边连接
7166
for node in self.nodes:
72-
for input_name, input_node in node.inputs().items():
67+
for input_node in node.inputs:
7368
dot.edge(
7469
str(id(input_node)),
7570
str(id(node)),

front/py/deepx/autograd/node.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ def graph(self):
2727
@property
2828
def name(self):
2929
return self._name
30-
30+
31+
@property
32+
def inputs(self):
33+
return self._inputs
34+
35+
3136
def add_input(self, input_node):
3237
self._inputs.append(input_node)

front/py/deepx/nn/functional/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from .new import newtensor
33
from .print import printtensor
44
from .matmul import matmul
5-
from .init import constant,full,zeros,ones,uniform_,arange,rand,randn,eye
5+
from .init import constant,full,zeros,ones,uniform,arange,rand,randn,eye
66
from .reduce import max,min,sum,prod,mean
77
from .transpose import transpose
88
from .activite import relu
99

1010
__all__ = [
1111
"newtensor",
1212
"printtensor",
13-
"constant","full","zeros","ones","uniform_","arange","rand","randn","eye",
13+
"constant","full","zeros","ones","uniform","arange","rand","randn","eye",
1414
"add","sub","mul","div","clamp",
1515
"matmul",
1616
"max","min","sum","prod","mean",

front/py/deepx/nn/functional/activite.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from deepx.scheduler import send
44

55
def relu(t: Tensor,inplace:bool=False)->Tensor:
6-
out=t
7-
if not inplace:
6+
from .reduce import max as max_func
7+
if inplace:
8+
max_func(t,0,t)
9+
else:
810
out=Tensor(shape=t.shape, dtype=t.dtype, device=t.device)
9-
ir=DeepxIR("max_scalar",t.dtype,[t._node.name,0], [out._node.name])
10-
send(ir)
11+
max_func(t,0,None,out)
1112
return out
1213

1314
# 数学公式:σ(x) = 1 / (1 + exp(-x))

front/py/deepx/nn/functional/elementwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ def _A_b_elementwiseop_C(
2121
b: Optional[Union[ float, int]] = None,
2222
op:str=None,
2323
out:Tensor=None):
24+
varnode=a.graph.add_var("",b)
2425
opnode = a.graph.add_op(op)
2526
opnode.add_input(a.node)
26-
varnode=a.graph.add_var("",b)
2727
opnode.add_input(varnode)
2828
out.node.add_input(opnode)
2929
if a.graph.eager:
3030
varir=DeepxIR("argset", a.dtype, [b], [varnode.name])
31-
send(str(varir))
31+
send(varir)
3232
ir=DeepxIR(op, a.dtype, [a.node.name,varnode.name], [out.node.name])
3333
send(ir)
3434
#add

front/py/deepx/nn/functional/init.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
from typing import Optional
1+
from typing import Optional,Union
22

33
from deepx import Tensor
44
from deepx.autograd.graph import OpNode
55
from deepx.nn.deepxir import DeepxIR
66
from deepx.scheduler import send
77

8-
def constant(t:Tensor, fill_value):
8+
OpNode.register("constant")
9+
10+
def constant(t:Tensor, value:Optional[Union[
11+
float,int
12+
]]=None) -> Tensor:
913
opnode = t.graph.add_op("constant")
10-
opnode.add_input(t.node)
11-
argnode=t.graph.add_var('',fill_value)
14+
argnode=t.graph.add_var('',value)
1215
opnode.add_input(argnode)
1316
t.node.add_input(opnode)
1417
if t.graph.eager:
15-
ir=DeepxIR("constant", t.dtype, [fill_value], [t.node.name])
18+
ir=DeepxIR("constant", t.dtype, [value], [t.node.name])
1619
send(ir)
1720
return t
1821

@@ -30,13 +33,15 @@ def ones(*size, dtype=None, device=None):
3033
return full(*size, fill_value=1, dtype=dtype, device=device)
3134

3235
OpNode.register("uniform")
33-
def uniform_(t:Tensor,low=0, high=1)->Tensor:
36+
def uniform(t:Tensor,low=0, high=1)->Tensor:
3437
if low >= high:
3538
raise ValueError(f"low({low})必须小于high({high})")
36-
opnode = t.graph.add_op("uniform")
37-
opnode.add_input(t.node)
38-
arglow=t.graph.add_var('',low)
39-
arghigh=t.graph.add_var('',high)
39+
if t is None:
40+
raise ValueError("t不能为None")
41+
g=t.graph
42+
arglow=g.add_var('',low)
43+
arghigh=g.add_var('',high)
44+
opnode = g.add_op("uniform")
4045
opnode.add_input(arglow)
4146
opnode.add_input(arghigh)
4247
t.node.add_input(opnode)

0 commit comments

Comments
 (0)