1- from .node import Node , NodeType
2-
3- class OpType :
4- def __init__ (self , name , shortchar ):
5- self .name = name
6- self .shortchar = shortchar
7- def shortchar (self ):
8- return self .shortchar
9- # 全局操作类型注册表
10- _op_types = {}
11-
12- def regist_op_type (name , shortchar ):
13- """注册一个操作类型"""
14- _op_types [name ] = OpType (name , shortchar )
15-
16- class OpNode (Node ):
17- def __init__ (self , op_type_name , name = None ):
18- if op_type_name not in _op_types :
19- raise ValueError (f"Unknown op type: { op_type_name } " )
1+ from .node import Node
2+ from .nodetype import NodeType
3+
4+
5+ class OpNodeMeta (type ):
6+ """操作节点元类,负责校验操作名称"""
7+ _registered_ops = set () # 已注册操作名称缓存
8+
9+ def __call__ (cls , name : str , * args , ** kwargs ):
10+ # 在实例化时进行名称校验
11+ if name not in cls ._registered_ops :
12+ raise ValueError (
13+ f"Op '{ name } ' 未注册,请先使用OpNode.register('{ name } ')注册"
14+ )
15+ return super ().__call__ (name , * args , ** kwargs )
16+
17+ @classmethod
18+ def register_op (cls , name : str ) -> None :
19+ """注册新操作类型"""
20+ if name in cls ._registered_ops :
21+ raise ValueError (f"Op '{ name } ' 已存在" )
22+ cls ._registered_ops .add (name )
23+
24+ class OpNode (Node , metaclass = OpNodeMeta ):
25+ def __init__ (self , name : str ):
2026 super ().__init__ (name = name , ntype = NodeType .OP )
21- self .op_type = _op_types [op_type_name ]
22-
23- def shortchar (self ):
24- return self .op_type .shortchar
25-
26-
27- regist_op_type ("ReLU" , "relu" )
28- regist_op_type ("Placeholder" , "ph" )
29- regist_op_type ("Neg" , "-" )
30- regist_op_type ("Less" , "<" )
31- regist_op_type ("Equal" , "==" )
32- regist_op_type ("Sigmoid" , "σ" )
33- regist_op_type ("Tanh" , "tanh" )
34- regist_op_type ("Reshape" , "reshape" )
35- regist_op_type ("Transpose" , "T" )
36- regist_op_type ("Sum" , "Σ" )
37- regist_op_type ("Mean" , "μ" )
38-
39- # 操作节点创建函数
40- def matmul (a , b , name = None ):
41- node = OpNode ("MatMul" , name )
42- node .add_input ("a" , a )
43- node .add_input ("b" , b )
44- return node
45-
46- def add (a , b , name = None ):
47- node = OpNode ("Add" , name )
48- node .add_input ("a" , a )
49- node .add_input ("b" , b )
50- return node
51-
52- def relu (x , name = None ):
53- node = OpNode ("ReLU" , name )
54- node .add_input ("x" , x )
55- return node
56-
57- def placeholder (name = None , shape = None ):
58- node = OpNode ("Placeholder" , name )
59- if shape :
60- node .set_attr ("shape" , shape )
61- return node
62-
63- def neg (x ):
64- node = OpNode ("Neg" )
65- node .add_input ("x" , x )
66- return node
67-
68- def mul (a , b ):
69- node = OpNode ("Mul" )
70- node .add_input ("a" , a )
71- node .add_input ("b" , b )
72- return node
73-
74- def div (a , b ):
75- node = OpNode ("Div" )
76- node .add_input ("a" , a )
77- node .add_input ("b" , b )
78- return node
79-
80- def sub (a , b ):
81- node = OpNode ("Sub" )
82- node .add_input ("a" , a )
83- node .add_input ("b" , b )
84- return node
85-
86- def less (a , b ):
87- node = OpNode ("Less" )
88- node .add_input ("a" , a )
89- node .add_input ("b" , b )
90- return node
91-
92- def equal (a , b ):
93- node = OpNode ("Equal" )
94- node .add_input ("a" , a )
95- node .add_input ("b" , b )
96- return node
97-
98- def sigmoid (x ):
99- node = OpNode ("Sigmoid" )
100- node .add_input ("x" , x )
101- return node
102-
103- def tanh (x ):
104- node = OpNode ("Tanh" )
105- node .add_input ("x" , x )
106- return node
107-
108- def reshape (x , shape ):
109- node = OpNode ("Reshape" )
110- node .add_input ("x" , x )
111- node .set_attr ("shape" , shape )
112- return node
113-
114- def transpose (x , dim0 , dim1 ):
115- node = OpNode ("Transpose" )
116- node .add_input ("x" , x )
117- node .set_attr ("dim0" , dim0 )
118- node .set_attr ("dim1" , dim1 )
119- return node
120-
121- def sum (x , dim = None , keepdim = False ):
122- node = OpNode ("Sum" )
123- node .add_input ("x" , x )
124- node .set_attr ("dim" , dim )
125- node .set_attr ("keepdim" , keepdim )
126- return node
12727
128- def mean (x , dim = None , keepdim = False ):
129- node = OpNode ("Mean" )
130- node .add_input ("x" , x )
131- node .set_attr ("dim" , dim )
132- node .set_attr ("keepdim" , keepdim )
133- return node
28+ @classmethod
29+ def register (cls , name : str ) -> None :
30+ cls .__class__ .register_op (name )
0 commit comments