Skip to content

Commit f6e4062

Browse files
committed
deepxpy:初步实现构建计算图
1 parent 97aa4a9 commit f6e4062

14 files changed

Lines changed: 240 additions & 216 deletions

File tree

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from .graph import Graph
2-
from ..tensor.op import Op
2+
from .node import Node
3+
from .nodetype import NodeType
4+
from ._datanode import DataNode
5+
36
__all__ = [
47
'Graph',
5-
'Op',
8+
'Node',
9+
'NodeType',
10+
'DataNode',
611
]

front/py/deepx/autograd/_constargnode.py

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .node import Node
2+
from .nodetype import NodeType
3+
4+
5+
class ControlFlowNode(Node):
6+
def __init__(self, name=None):
7+
super().__init__(name="control_flow", ntype=NodeType.CONTROL_FLOW)
8+
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from .node import Node, NodeType
2-
1+
from .node import Node
2+
from .nodetype import NodeType
3+
34
class DataNode(Node):
4-
def __init__(self, name=None):
5-
super().__init__(name=name, ntype=NodeType.TENSOR)
6-
self._data = None
5+
def __init__(self, name=None, data=None):
6+
super().__init__(name=name, ntype=NodeType.DATA)
7+
self._data = data
78

89
def data(self):
910
return self._data
1011

1112
def set_data(self, data):
12-
from deepx import Tensor
13-
if not isinstance(data, Tensor):
14-
raise TypeError("data must be an instance of Tensor")
15-
self._data = data
13+
self._data = data
14+
15+

front/py/deepx/autograd/_opnode.py

Lines changed: 28 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,133 +1,30 @@
1-
from .node import Node, NodeType
2-
3-
class OpType:
4-
def __init__(self, name, shortchar):
5-
self.name = name
6-
self.shortchar = shortchar
7-
def shortchar(self):
8-
return self.shortchar
9-
# 全局操作类型注册表
10-
_op_types = {}
11-
12-
def regist_op_type(name, shortchar):
13-
"""注册一个操作类型"""
14-
_op_types[name] = OpType(name, shortchar)
15-
16-
class OpNode(Node):
17-
def __init__(self, op_type_name, name=None):
18-
if op_type_name not in _op_types:
19-
raise ValueError(f"Unknown op type: {op_type_name}")
1+
from .node import Node
2+
from .nodetype import NodeType
3+
4+
5+
class OpNodeMeta(type):
6+
"""操作节点元类,负责校验操作名称"""
7+
_registered_ops = set() # 已注册操作名称缓存
8+
9+
def __call__(cls, name: str, *args, **kwargs):
10+
# 在实例化时进行名称校验
11+
if name not in cls._registered_ops:
12+
raise ValueError(
13+
f"Op '{name}' 未注册,请先使用OpNode.register('{name}')注册"
14+
)
15+
return super().__call__(name, *args, **kwargs)
16+
17+
@classmethod
18+
def register_op(cls, name: str) -> None:
19+
"""注册新操作类型"""
20+
if name in cls._registered_ops:
21+
raise ValueError(f"Op '{name}' 已存在")
22+
cls._registered_ops.add(name)
23+
24+
class OpNode(Node, metaclass=OpNodeMeta):
25+
def __init__(self, name: str):
2026
super().__init__(name=name, ntype=NodeType.OP)
21-
self.op_type = _op_types[op_type_name]
22-
23-
def shortchar(self):
24-
return self.op_type.shortchar
25-
26-
27-
regist_op_type("ReLU", "relu")
28-
regist_op_type("Placeholder", "ph")
29-
regist_op_type("Neg", "-")
30-
regist_op_type("Less", "<")
31-
regist_op_type("Equal", "==")
32-
regist_op_type("Sigmoid", "σ")
33-
regist_op_type("Tanh", "tanh")
34-
regist_op_type("Reshape", "reshape")
35-
regist_op_type("Transpose", "T")
36-
regist_op_type("Sum", "Σ")
37-
regist_op_type("Mean", "μ")
38-
39-
# 操作节点创建函数
40-
def matmul(a, b, name=None):
41-
node = OpNode("MatMul", name)
42-
node.add_input("a", a)
43-
node.add_input("b", b)
44-
return node
45-
46-
def add(a, b, name=None):
47-
node = OpNode("Add", name)
48-
node.add_input("a", a)
49-
node.add_input("b", b)
50-
return node
51-
52-
def relu(x, name=None):
53-
node = OpNode("ReLU", name)
54-
node.add_input("x", x)
55-
return node
56-
57-
def placeholder(name=None, shape=None):
58-
node = OpNode("Placeholder", name)
59-
if shape:
60-
node.set_attr("shape", shape)
61-
return node
62-
63-
def neg(x):
64-
node = OpNode("Neg")
65-
node.add_input("x", x)
66-
return node
67-
68-
def mul(a, b):
69-
node = OpNode("Mul")
70-
node.add_input("a", a)
71-
node.add_input("b", b)
72-
return node
73-
74-
def div(a, b):
75-
node = OpNode("Div")
76-
node.add_input("a", a)
77-
node.add_input("b", b)
78-
return node
79-
80-
def sub(a, b):
81-
node = OpNode("Sub")
82-
node.add_input("a", a)
83-
node.add_input("b", b)
84-
return node
85-
86-
def less(a, b):
87-
node = OpNode("Less")
88-
node.add_input("a", a)
89-
node.add_input("b", b)
90-
return node
91-
92-
def equal(a, b):
93-
node = OpNode("Equal")
94-
node.add_input("a", a)
95-
node.add_input("b", b)
96-
return node
97-
98-
def sigmoid(x):
99-
node = OpNode("Sigmoid")
100-
node.add_input("x", x)
101-
return node
102-
103-
def tanh(x):
104-
node = OpNode("Tanh")
105-
node.add_input("x", x)
106-
return node
107-
108-
def reshape(x, shape):
109-
node = OpNode("Reshape")
110-
node.add_input("x", x)
111-
node.set_attr("shape", shape)
112-
return node
113-
114-
def transpose(x, dim0, dim1):
115-
node = OpNode("Transpose")
116-
node.add_input("x", x)
117-
node.set_attr("dim0", dim0)
118-
node.set_attr("dim1", dim1)
119-
return node
120-
121-
def sum(x, dim=None, keepdim=False):
122-
node = OpNode("Sum")
123-
node.add_input("x", x)
124-
node.set_attr("dim", dim)
125-
node.set_attr("keepdim", keepdim)
126-
return node
12727

128-
def mean(x, dim=None, keepdim=False):
129-
node = OpNode("Mean")
130-
node.add_input("x", x)
131-
node.set_attr("dim", dim)
132-
node.set_attr("keepdim", keepdim)
133-
return node
28+
@classmethod
29+
def register(cls, name: str) -> None:
30+
cls.__class__.register_op(name)

front/py/deepx/autograd/graph.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from ._datanode import TensorNode
1+
from ._datanode import DataNode
22
from ._opnode import OpNode
3-
from ._constargnode import ConstArgNode
3+
from ._controlflownode import ControlFlowNode
4+
5+
eager_mode=False
46

57
class Graph:
68
# 类属性存储默认实例
@@ -20,33 +22,38 @@ def set_default(cls, graph):
2022
raise TypeError("Must be a Graph instance")
2123
cls._default_graph = graph
2224

23-
def __init__(self):
25+
def __init__(self,eager=False):
2426
self.nodes = []
2527
self.inputs = []
26-
self.tensor_counter = 0 # 添加计数器
27-
self.constarg_counter = 0
28-
def add_tensor(self, name, dtype, shape, requires_grad):
29-
self.tensor_counter += 1
28+
self.data_counter = 0
29+
self.control_flow_counter = 0
30+
self.eager=eager or eager_mode
31+
32+
def add_data(self, name, data,inputs=[]):
33+
self.data_counter += 1
3034
if name == "":
31-
name = f"tensor_{self.tensor_counter}"
32-
node=TensorNode(name, dtype, shape, requires_grad)
33-
for input in node.inputs:
34-
node.add_input(input.name, input)
35+
name = f"data_{self.data_counter}"
36+
node=DataNode(name, data)
37+
for input in inputs:
38+
node.add_input(input)
3539
self.nodes.append(node)
3640
return node
37-
def add_op(self,name,inputs):
41+
def add_op(self,name,inputs=[]):
3842
node=OpNode(name)
3943
for input in inputs:
40-
node.add_input(input.name, input)
44+
node.add_input(input)
4145
self.nodes.append(node)
46+
if self.eager:
47+
return node.outputs[0]
4248
return node
43-
def add_constarg(self, value):
44-
self.constarg_counter += 1
49+
def add_control_flow(self,name,inputs=[]):
50+
self.control_flow_counter += 1
4551
if name == "":
46-
name = f"constarg_{self.constarg_counter}"
47-
node=ConstArgNode(value)
52+
name = f"control_flow_{self.control_flow_counter}"
53+
node=ControlFlowNode(name)
54+
for input in inputs:
55+
node.add_input(input)
4856
self.nodes.append(node)
4957
return node
50-
5158
# 初始化默认图
5259
Graph._default_graph = Graph()

front/py/deepx/autograd/graph_viz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .graph import Graph
33
from ._datanode import TensorNode
44
from ._opnode import OpNode
5-
from ._constargnode import ConstArgNode
5+
from ._controlflownode import ConstArgNode
66

77
def graph_method(f):
88
"""装饰器:将函数注册为Graph类的方法"""

front/py/deepx/autograd/node.py

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
from enum import IntEnum
2-
3-
class NodeType(IntEnum):
4-
TENSOR = 0
5-
OP = 1
6-
CONST_ARG = 2
7-
2+
from .nodetype import NodeType
3+
84
class Node:
95
def __init__(self,
10-
ntype:NodeType=NodeType.TENSOR,
6+
ntype:NodeType=None,
117
name:str=None,
128
graph =None):
139
from .graph import Graph
@@ -18,37 +14,19 @@ def __init__(self,
1814

1915
self._ntype = ntype
2016
self._name = name
21-
self._inputs = {}
22-
self.attrs = {}
23-
17+
self._inputs = []
18+
2419
@property
2520
def ntype(self):
2621
return self._ntype
2722

23+
@property
24+
def graph(self):
25+
return self._graph
26+
2827
@property
2928
def name(self):
3029
return self._name
3130

32-
@property
33-
def input(self, name=None):
34-
if name is None:
35-
return self._inputs
36-
else:
37-
return self._inputs.get(name)
38-
39-
def add_input(self, name, input_node):
40-
self._inputs[name] = input_node
41-
input_node.outputs.append(self)
42-
43-
def remove_input(self, name):
44-
if name in self._inputs:
45-
node = self._inputs[name]
46-
if self in node.outputs:
47-
node.outputs.remove(self)
48-
del self._inputs[name]
49-
50-
def set_attr(self, key, value):
51-
self.attrs[key] = value
52-
53-
def get_attr(self, key):
54-
return self.attrs.get(key)
31+
def add_input(self, input_node):
32+
self._inputs.append(input_node)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import IntEnum, EnumMeta
2+
from typing import Dict, Any
3+
4+
5+
class NodeType(IntEnum ):
6+
DATA = 0
7+
OP = 1
8+
CONTROL_FLOW = 2
9+

front/py/deepx/autograd/variable.py

Whitespace-only changes.

0 commit comments

Comments
 (0)