1+ from typing import Optional , Union
2+ from deepx .tensor import Tensor
3+ from deepx .autograd .graph import Graph ,DataNode ,OpNode
4+ from deepx .nn .deepxir import DeepxIR
5+ from deepx .scheduler import send
6+
7+ def _A_B_elementwiseop_C (
8+ a :Tensor ,
9+ b : Tensor ,
10+ op :str = None ,
11+ out :Tensor = None ):
12+ opnode = a .graph .add_op (op )
13+ opnode .add_input (a .node )
14+ opnode .add_input (b .node )
15+ out .node .add_input (opnode )
16+ if a .graph .eager :
17+ ir = DeepxIR (op , a .dtype , [a .node .name , b .node .name ], [out .node .name ])
18+ send (str (ir ))
19+ def _A_b_elementwiseop_C (
20+ a :Tensor ,
21+ b : Optional [Union [ float , int ]] = None ,
22+ op :str = None ,
23+ out :Tensor = None ):
24+ opnode = a .graph .add_op (op )
25+ opnode .add_input (a .node )
26+ varnode = a .graph .add_var ("" ,b )
27+ opnode .add_input (varnode )
28+ out .node .add_input (opnode )
29+ if a .graph .eager :
30+ varir = DeepxIR ("argset" , a .dtype , [b ], [varnode .name ])
31+ send (str (varir ))
32+ ir = DeepxIR (op + "_scalar" , a .dtype , [a .node .name ,varnode .name ], [out .node .name ])
33+ send (str (ir ))
34+ #add
35+ OpNode .register ("add" )
36+ OpNode .register ("add_scalar" )
37+
38+ def add (
39+ a :Tensor ,
40+ b : Optional [Union [Tensor , float , int ]] = None ,
41+ out :Tensor = None ):
42+ if isinstance (b ,Tensor ):
43+ _A_B_elementwiseop_C (a ,b ,"add" ,out )
44+ else :
45+ _A_b_elementwiseop_C (a ,b ,"add" ,out )
46+
47+
48+ #sub
49+ OpNode .register ("sub" )
50+ OpNode .register ("sub_scalar" )
51+
52+ def sub (
53+ a :Tensor ,
54+ b : Optional [Union [Tensor , float , int ]] = None ,
55+ out :Tensor = None ):
56+ if isinstance (b ,Tensor ):
57+ _A_B_elementwiseop_C (a ,b ,"sub" ,out )
58+ else :
59+ _A_b_elementwiseop_C (a ,b ,"sub" ,out )
60+
61+
62+ #mul
63+ OpNode .register ("mul" )
64+ OpNode .register ("mul_scalar" )
65+
66+ def mul (
67+ a :Tensor ,
68+ b : Optional [Union [Tensor , float , int ]] = None ,
69+ out :Tensor = None ):
70+ if isinstance (b ,Tensor ):
71+ _A_B_elementwiseop_C (a ,b ,"mul" ,out )
72+ else :
73+ _A_b_elementwiseop_C (a ,b ,"mul" ,out )
74+
75+
76+ #div
77+ OpNode .register ("div" )
78+ OpNode .register ("div_scalar" )
79+
80+ def div (
81+ a :Tensor ,
82+ b : Optional [Union [Tensor , float , int ]] = None ,
83+ out :Tensor = None ):
84+ if isinstance (b ,Tensor ):
85+ _A_B_elementwiseop_C (a ,b ,"div" ,out )
86+ else :
87+ _A_b_elementwiseop_C (a ,b ,"div" ,out )
88+
89+
90+
91+
92+ # OpNode.register("ReLU", 101)
93+ # OpNode.register("Placeholder", 102)
94+ # OpNode.register("Neg", 103)
95+ # NodeType.register("Less", 104)
96+ # NodeType.register("Equal", 105)
97+ # NodeType.register("Sigmoid", 106)
98+ # NodeType.register("Tanh", 107)
99+ # NodeType.register("Reshape", 108)
100+ # NodeType.register("Transpose", 109)
101+ # NodeType.register("Sum", 110)
102+ # NodeType.register("Mean", 111)
103+
104+ # # 操作节点创建函数
105+ # def matmul(a, b, name=None):
106+ # node = OpNode("MatMul", name)
107+ # node.add_input("a", a)
108+ # node.add_input("b", b)
109+ # return node
110+
111+ # def add(a, b, name=None):
112+ # node = OpNode("Add", name)
113+ # node.add_input("a", a)
114+ # node.add_input("b", b)
115+ # return node
116+
117+ # def relu(x, name=None):
118+ # node = OpNode("ReLU", name)
119+ # node.add_input("x", x)
120+ # return node
121+
122+ # def placeholder(name=None, shape=None):
123+ # node = OpNode("Placeholder", name)
124+ # if shape:
125+ # node.set_attr("shape", shape)
126+ # return node
127+
128+ # def neg(x):
129+ # node = OpNode("Neg")
130+ # node.add_input("x", x)
131+ # return node
132+
133+ # def mul(a, b):
134+ # node = OpNode("Mul")
135+ # node.add_input("a", a)
136+ # node.add_input("b", b)
137+ # return node
138+
139+ # def div(a, b):
140+ # node = OpNode("Div")
141+ # node.add_input("a", a)
142+ # node.add_input("b", b)
143+ # return node
144+
145+ # def sub(a, b):
146+ # node = OpNode("Sub")
147+ # node.add_input("a", a)
148+ # node.add_input("b", b)
149+ # return node
150+
151+ # def less(a, b):
152+ # node = OpNode("Less")
153+ # node.add_input("a", a)
154+ # node.add_input("b", b)
155+ # return node
156+
157+ # def equal(a, b):
158+ # node = OpNode("Equal")
159+ # node.add_input("a", a)
160+ # node.add_input("b", b)
161+ # return node
162+
163+ # def sigmoid(x):
164+ # node = OpNode("Sigmoid")
165+ # node.add_input("x", x)
166+ # return node
167+
168+ # def tanh(x):
169+ # node = OpNode("Tanh")
170+ # node.add_input("x", x)
171+ # return node
172+
173+ # def reshape(x, shape):
174+ # node = OpNode("Reshape")
175+ # node.add_input("x", x)
176+ # node.set_attr("shape", shape)
177+ # return node
178+
179+ # def transpose(x, dim0, dim1):
180+ # node = OpNode("Transpose")
181+ # node.add_input("x", x)
182+ # node.set_attr("dim0", dim0)
183+ # node.set_attr("dim1", dim1)
184+ # return node
185+
186+ # def sum(x, dim=None, keepdim=False):
187+ # node = OpNode("Sum")
188+ # node.add_input("x", x)
189+ # node.set_attr("dim", dim)
190+ # node.set_attr("keepdim", keepdim)
191+ # return node
192+
193+ # def mean(x, dim=None, keepdim=False):
194+ # node = OpNode("Mean")
195+ # node.add_input("x", x)
196+ # node.set_attr("dim", dim)
197+ # node.set_attr("keepdim", keepdim)
198+ # return node
0 commit comments