11import graphviz
2- from .graph import Graph
3- from ._datanode import TensorNode
2+ from .graph import Graph , graph_method
3+ from ._datanode import DataNode
44from ._opnode import OpNode
5- from ._controlflownode import ConstArgNode
65
7- def graph_method (f ):
8- """装饰器:将函数注册为Graph类的方法"""
9- # 延迟到模块加载完成后再绑定方法
10- from .graph import Graph
11- setattr (Graph , f .__name__ , f )
12- return f
136
147@graph_method
15- def to_dot (self ):
8+ def to_dot (self )-> graphviz . Digraph :
169 """生成DOT格式的计算图可视化"""
1710 dot = graphviz .Digraph (comment = 'Computational Graph' )
1811 dot .attr (rankdir = 'TB' ) # 从上到下的布局
@@ -27,49 +20,51 @@ def to_dot(self):
2720 'labeljust' : 'l'
2821 }
2922
30- if isinstance (node , TensorNode ):
31- # 张量节点:显示形状和梯度信息
32- label = f"{ node .name } \n { node .tensor ().shape if node .tensor () else '' } "
33- attrs .update ({
34- 'shape' : 'box' ,
35- 'color' : 'skyblue' ,
36- 'style' : 'filled' ,
37- 'fillcolor' : 'aliceblue'
38- })
23+ if isinstance (node , DataNode ):
24+ label = f'{ node .name } \n '
25+ match (node ._type ):
26+ case "var" :
27+ label += f"{ node .data } "
28+ attrs .update ({
29+ 'shape' : 'box' ,
30+ 'color' : 'orange' ,
31+ 'style' : 'filled' ,
32+ 'fillcolor' : 'moccasin'
33+ })
34+ case "vector" :
35+ label += f"{ node .data } "
36+ attrs .update ({
37+ 'shape' : 'box' ,
38+ 'color' : 'darkseagreen' ,
39+ 'style' : 'filled' ,
40+ 'fillcolor' : 'honeydew'
41+ })
42+ case "tensor" :
43+ label += f"{ node .data .shape if node .data else '' } "
44+ attrs .update ({
45+ 'shape' : 'box' ,
46+ 'color' : 'skyblue' ,
47+ 'style' : 'filled' ,
48+ 'fillcolor' : 'aliceblue'
49+ })
50+
3951
4052 elif isinstance (node , OpNode ):
4153 # 操作节点:突出显示操作类型
42- label = node .shortchar ()
54+ label = node .name
4355 attrs .update ({
4456 'shape' : 'box' ,
4557 'style' : 'filled' ,
4658 'fillcolor' : 'lightgray' ,
4759 'color' : 'darkslategray' ,
4860 'fontname' : 'Courier Bold'
4961 })
50-
51- elif isinstance (node , ConstArgNode ):
52- # 常量参数节点:显示参数值
53- if node .arg_type == 'int' :
54- value = str (node .get_int ())
55- elif node .arg_type == 'float' :
56- value = f"{ node .get_float ():.2f} f"
57- else : # string
58- value = node .get_string ()
59- label = value
60- attrs .update ({
61- 'shape' : 'diamond' ,
62- 'style' : 'filled' ,
63- 'fillcolor' : 'lightyellow' ,
64- 'color' : 'goldenrod'
65- })
66-
6762 # 添加节点
6863 dot .node (str (id (node )), label , ** attrs )
6964
7065 # 添加边连接
7166 for node in self .nodes :
72- for input_name , input_node in node .inputs (). items () :
67+ for input_node in node .inputs :
7368 dot .edge (
7469 str (id (input_node )),
7570 str (id (node )),
0 commit comments