Skip to content

Commit a7196dd

Browse files
committed
deepxpy:初步支持eager模式IR输出
1 parent 7997620 commit a7196dd

10 files changed

Lines changed: 131 additions & 92 deletions

File tree

front/py/deepx/autograd/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
'Node',
99
'NodeType',
1010
'DataNode',
11+
1112
]

front/py/deepx/autograd/_datanode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from .nodetype import NodeType
33

44
class DataNode(Node):
5-
def __init__(self, name=None, data=None):
5+
def __init__(self, name=None, type=None, data=None):
66
super().__init__(name=name, ntype=NodeType.DATA)
77
self._data = data
8-
8+
self._type=type
99
def data(self):
1010
return self._data
1111

front/py/deepx/autograd/_opnode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def register_op(cls, name: str) -> None:
2424
class OpNode(Node, metaclass=OpNodeMeta):
2525
def __init__(self, name: str):
2626
super().__init__(name=name, ntype=NodeType.OP)
27-
2827
@classmethod
2928
def register(cls, name: str) -> None:
3029
cls.__class__.register_op(name)

front/py/deepx/autograd/graph.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from ._datanode import DataNode
22
from ._opnode import OpNode
33
from ._controlflownode import ControlFlowNode
4-
5-
eager_mode=False
6-
4+
75
class Graph:
86
# 类属性存储默认实例
97
_default_graph = None
@@ -22,29 +20,58 @@ def set_default(cls, graph):
2220
raise TypeError("Must be a Graph instance")
2321
cls._default_graph = graph
2422

25-
def __init__(self,eager=False):
23+
def __init__(self,eager=True):
2624
self.nodes = []
2725
self.inputs = []
28-
self.data_counter = 0
26+
self.var_counter = 0
27+
self.vector_counter = 0
28+
self.tensor_counter = 0
2929
self.control_flow_counter = 0
30-
self.eager=eager or eager_mode
30+
self.eager=eager
31+
32+
@property
33+
def eager(self):
34+
return self._eager
35+
@eager.setter
36+
def eager(self,value):
37+
self._eager=value
3138

32-
def add_data(self, name, data,inputs=[]):
33-
self.data_counter += 1
39+
def add_var(self, name,data,inputs=[]):
40+
self.var_counter += 1
41+
if name == "":
42+
name = f"var_{self.var_counter}"
43+
node=DataNode(name, "var", data)
44+
for input in inputs:
45+
node.add_input(input)
46+
self.nodes.append(node)
47+
return node
48+
49+
def add_vector(self, name,data,inputs=[]):
50+
self.vector_counter += 1
3451
if name == "":
35-
name = f"data_{self.data_counter}"
36-
node=DataNode(name, data)
52+
name = f"vector_{self.vector_counter}"
53+
node=DataNode(name, "vector", data)
3754
for input in inputs:
3855
node.add_input(input)
3956
self.nodes.append(node)
4057
return node
58+
59+
def add_tensor(self, name,data,inputs=[]):
60+
self.tensor_counter += 1
61+
if name == "":
62+
name = f"tensor_{self.tensor_counter}"
63+
node=DataNode(name, "tensor", data)
64+
for input in inputs:
65+
node.add_input(input)
66+
self.nodes.append(node)
67+
return node
68+
69+
4170
def add_op(self,name,inputs=[]):
4271
node=OpNode(name)
4372
for input in inputs:
4473
node.add_input(input)
4574
self.nodes.append(node)
46-
if self.eager:
47-
return node.outputs[0]
4875
return node
4976
def add_control_flow(self,name,inputs=[]):
5077
self.control_flow_counter += 1

front/py/deepx/tensor/deepxir.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Tuple, List, Optional
22
from ..autograd.graph import Graph
33

4-
class Op:
4+
class DeepxIR:
55
def __init__(self,
66
name:str,
77
dtype:str,
@@ -45,21 +45,7 @@ def forward(self, *input) -> Tuple:
4545
def backward(self, *grad_outputs) -> Tuple:
4646
raise NotImplementedError
4747

48-
def get_grad_mapping(self, is_backward: bool) -> List[Tuple[str, str]]:
49-
"""
50-
获取参数与梯度的映射关系
51-
Returns:
52-
返回(参数名,梯度名)元组列表,梯度名为空表示无梯度
53-
"""
54-
if not self._grad or not is_backward:
55-
return [(arg, "") for arg in self._args]
56-
57-
return [
58-
(arg, grad)
59-
for arg, grad in zip(self._args, self._args_grad)
60-
]
61-
62-
def ir(self) -> str:
48+
def __str__(self):
6349
"""生成IR指令的优化实现"""
6450
parts = [f"{self._name}@{self._dtype}"]
6551

front/py/deepx/tensor/dtype.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,49 @@
1-
from enum import Enum, auto
1+
import numpy as np
2+
from typing import Any
23

3-
class DType(Enum):
4-
COMPLEX64 = auto()
5-
COMPLEX128 = auto()
6-
BFLOAT16 = auto()
7-
FLOAT16 = auto()
8-
FLOAT32 = auto()
9-
FLOAT64 = auto()
10-
UINT8 = auto()
11-
INT8 = auto()
12-
UINT16 = auto()
13-
UINT32 = auto()
14-
INT32 = auto()
15-
UINT64 = auto()
16-
INT64 = auto()
17-
BOOL = auto()
4+
DTYPE_MAP = {
5+
'float16': np.float16,
6+
'float32': np.float32,
7+
'float64': np.float64,
8+
'int8': np.int8,
9+
'int16': np.int16,
10+
'int32': np.int32,
11+
'int64': np.int64,
12+
}
1813

19-
def _dtype_to_typestr(dtype):
20-
return {
21-
DType.COMPLEX64: "<c8",
22-
DType.COMPLEX128: "<c16",
23-
DType.BFLOAT16: "<V2",
24-
DType.FLOAT16: "<f2",
25-
DType.FLOAT32: "<f4",
26-
DType.FLOAT64: "<f8",
27-
DType.UINT8: "|u1",
28-
DType.INT8: "|i1",
29-
DType.UINT16: "<u2",
30-
DType.INT32: "<i4",
31-
DType.UINT32: "<u4",
32-
DType.UINT64: "<u8",
33-
DType.INT64: "<i8",
34-
DType.BOOL: "|b1",
35-
}[dtype]
14+
def infer_dtype(data: Any) -> str:
15+
"""
16+
推断数组元素的深度学习数据类型
17+
18+
支持类型优先级(从高到低):
19+
float32 > float64 > int32 > int64 > bool
20+
21+
Args:
22+
data: 输入数据,支持Python原生类型、Numpy数组、列表等
23+
24+
Returns:
25+
str: 数据类型名称('float32', 'int32'等)
26+
27+
Raises:
28+
TypeError: 当包含不支持的数据类型时
29+
"""
30+
# 转换为numpy数组进行类型推断
31+
arr = np.asarray(data)
32+
33+
# 根据数值范围自动选择精度
34+
if np.issubdtype(arr.dtype, np.integer):
35+
if arr.itemsize <= 4:
36+
return 'int32' if arr.min() >= np.iinfo(np.int32).min else 'int64'
37+
return 'int64'
38+
39+
if np.issubdtype(arr.dtype, np.floating):
40+
return 'float32' if arr.itemsize <= 4 else 'float64'
41+
42+
# 处理特殊类型(如对象数组)
43+
if arr.dtype == np.object_:
44+
unique_types = {type(x) for x in arr.flat}
45+
if {int, float} == unique_types:
46+
return 'float32'
47+
raise TypeError(f"混合类型或不支持的类型: {unique_types}")
48+
49+
raise TypeError(f"不支持的数据类型: {arr.dtype}")

front/py/deepx/tensor/elementwise.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from .tensor import Tensor,tensor_method
22
from deepx.autograd.graph import Graph,DataNode,OpNode
3-
3+
from .deepxir import DeepxIR
44
OpNode.register("add")
55

66
@tensor_method
77
def add(self, other):
8-
result = self.graph.add_data("", self)
9-
op = self.graph.add_op("add")
10-
op.add_input(self.node)
11-
op.add_input(other.node)
12-
result.add_input(op)
13-
14-
return result
8+
resultnode = self.graph.add_tensor("", self)
9+
opnode = self.graph.add_op("add")
10+
opnode.add_input(self.node)
11+
opnode.add_input(other.node)
12+
resultnode.add_input(opnode)
13+
if self.graph.eager:
14+
ir=DeepxIR("add", self._dtype, [self.node.name, other.node.name], [resultnode.name])
15+
print(ir)
16+
return resultnode
1517

1618
OpNode.register("mul")
1719

front/py/deepx/tensor/tensor.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,32 @@
33
from .shape import Shape
44
from .devicetype import Device
55
from deepx.autograd import Graph,DataNode
6-
6+
from .deepxir import DeepxIR
7+
from .dtype import infer_dtype,DTYPE_MAP
78
class Tensor:
89
def __init__(self, data=None, shape=None, device=None, dtype=None, graph=None):
10+
# 计算图相关
11+
if graph is None:
12+
self._graph = Graph.get_default()
13+
else:
14+
self._graph = graph
15+
self._node= self._graph.add_tensor("",data=self)
16+
917
# data
1018
if data is not None:
1119
import numpy as np
1220
if not isinstance(data, np.ndarray):
1321
data = np.array(data)
14-
self.data = data
22+
self.data = data
1523
self._shape = Shape(data.shape)
1624

25+
# dtype
26+
if dtype is None:
27+
self._dtype = infer_dtype(data)
28+
else:
29+
self._dtype = dtype
30+
self._data = data.astype(DTYPE_MAP[dtype], casting='safe')
31+
self._node.dtype = dtype
1732
# shape
1833
if shape is not None:
1934
if isinstance(shape, (tuple, list)) and all(isinstance(i, int) for i in shape):
@@ -22,7 +37,13 @@ def __init__(self, data=None, shape=None, device=None, dtype=None, graph=None):
2237
self._shape = shape
2338
else:
2439
raise ValueError("Invalid shape")
25-
40+
shapeNode=self._graph.add_vector("",data=self._shape.shape)
41+
self._node.add_input(shapeNode)
42+
if self._graph.eager:
43+
ir1=DeepxIR("argset", 'int32', self._shape.shape, [shapeNode.name])
44+
print(ir1)
45+
ir2=DeepxIR("newtensor", self._dtype, [shapeNode.name], [self._node.name])
46+
print(ir2)
2647
# device
2748
if isinstance(device, str):
2849
self._device = Device.from_string(device)
@@ -32,17 +53,7 @@ def __init__(self, data=None, shape=None, device=None, dtype=None, graph=None):
3253
self._device = Device.CPU # 默认设备
3354

3455
self._dtype = dtype
35-
36-
# 计算图相关
37-
if graph is None:
38-
self._graph = Graph.get_default()
39-
else:
40-
self._graph = graph
41-
self._node= self._graph.add_data(name="",data=self.data)
42-
43-
44-
self._requires_grad = False
45-
56+
4657
self.data = data
4758
# shape
4859
@property

front/py/examples/2_ir/1_ir.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from deepx.tensor.op import Op
1+
from deepx.tensor.deepxir import DeepxIR
22
# 正向传播示例
3-
op = Op(
3+
op = DeepxIR(
44
name="add",
55
args=["t1", "t2"],
66
returns=["t3"],
@@ -16,7 +16,7 @@
1616
# 输出: add@float32 t1(t1_grad) t2(t2_grad) <- t3(t3_grad)
1717

1818
# 标量操作示例
19-
scalar_op = Op(
19+
scalar_op = DeepxIR(
2020
name="scalar",
2121
args=["t1", "a1"], # a1为参数
2222
returns=["t3"],

front/py/examples/2_ir/2_add.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from deepx.tensor import Tensor
2+
print()
23

34
t1 = Tensor([1,2,3])
45
t2 = Tensor([4,5,6])
5-
print(hasattr(Tensor, 'add'))
66
t3 = t1.add(t2)
7-
g=t3.graph
8-
print(g)
7+

0 commit comments

Comments
 (0)