Skip to content

Commit cc078a1

Browse files
committed
py:解决循环import
1 parent 7528a06 commit cc078a1

3 files changed

Lines changed: 29 additions & 16 deletions

File tree

front/py/deepx/autograd/_tensornode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .node import Node, NodeType
2-
from deepx import Tensor
32

43
class TensorNode(Node):
54
def __init__(self, name=None):
@@ -10,6 +9,7 @@ def tensor(self):
109
return self._tensor
1110

1211
def set_tensor(self, tensor):
12+
from deepx import Tensor
1313
if not isinstance(tensor, Tensor):
1414
raise TypeError("tensor must be an instance of Tensor")
1515
self._tensor = tensor

front/py/deepx/autograd/node.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,35 @@ class NodeType(IntEnum):
66
CONST_ARG = 2
77

88
class Node:
9-
def __init__(self, name=None, ntype=NodeType.TENSOR):
10-
self.ntype = ntype
11-
self.name = name
9+
def __init__(self,
10+
ntype:NodeType=NodeType.TENSOR,
11+
name:str=None,
12+
graph =None):
13+
from .graph import Graph
14+
if graph == None:
15+
self._graph = Graph.get_default()
16+
else:
17+
self._graph = graph
18+
19+
self._ntype = ntype
20+
self._name = name
1221
self._inputs = {}
1322
self.attrs = {}
1423

24+
@property
1525
def ntype(self):
16-
return self.ntype
26+
return self._ntype
1727

28+
@property
1829
def name(self):
19-
return self.name
20-
21-
def inputs(self):
22-
return self._inputs
23-
24-
def input(self, name):
25-
return self._inputs.get(name)
30+
return self._name
31+
32+
@property
33+
def input(self, name=None):
34+
if name is None:
35+
return self._inputs
36+
else:
37+
return self._inputs.get(name)
2638

2739
def add_input(self, name, input_node):
2840
self._inputs[name] = input_node

front/py/deepx/autograd/op.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any, Tuple, List, Optional
2-
from deepx import Tensor
3-
from deepx.autograd.graph import Graph
2+
from .graph import Graph
43

54
class Op:
65
def __init__(self,args:List[str],returns:List[str],grad:bool,args_grad:List[str],returns_grad:List[str]):
@@ -11,9 +10,11 @@ def __init__(self,args:List[str],returns:List[str],grad:bool,args_grad:List[str]
1110
self._grad=grad
1211
self._args_grad=args_grad
1312
self._returns_grad=returns_grad
14-
def forward(self, *input:Tensor) -> Tuple[Tensor, ...]:
13+
14+
def forward(self, *input) -> Tuple:
1515
raise NotImplementedError
16-
def backward(self, *grad_outputs: Tensor) -> Tuple[Optional[Tensor], ...]:
16+
17+
def backward(self, *grad_outputs) -> Tuple:
1718
raise NotImplementedError
1819

1920
def to_ir(self, dtype: str, is_backward: bool = False) -> str:

0 commit comments

Comments
 (0)