88#include " deepx/op/print.hpp"
99#include " deepx/op/changeshape.hpp"
1010namespace deepx ::op
11- {
12- // new
13- void register_new (OpFactory &opfactory){
11+ {
12+ // tensor
13+ void register_lifecycle (OpFactory &opfactory)
14+ {
1415 opfactory.add_op (NewTensor<int8_t >());
1516 opfactory.add_op (NewTensor<int16_t >());
1617 opfactory.add_op (NewTensor<int32_t >());
@@ -35,114 +36,84 @@ namespace deepx::op
3536 opfactory.add_op (ArgSet<int32_t >());
3637 opfactory.add_op (ArgSet<float >());
3738 opfactory.add_op (ArgSet<double >());
38- }
39- // del
40- void register_del (OpFactory &opfactory){
39+
4140 opfactory.add_op (DelTensor<float >());
4241 }
43- // init
44- void register_uniform (OpFactory &opfactory){
42+
43+ // init
44+ void register_init (OpFactory &opfactory)
45+ {
4546 opfactory.add_op (Uniform<float >());
4647 opfactory.add_op (Uniform<double >());
47- }
48- void register_constant (OpFactory &opfactory){
48+
4949 opfactory.add_op (Constant<float >());
5050 opfactory.add_op (Constant<double >());
51- }
52- void register_arange (OpFactory &opfactory){
51+
5352 opfactory.add_op (Arange<float >());
5453 opfactory.add_op (Arange<double >());
5554 }
56- void register_init (OpFactory &opfactory){
57- register_uniform (opfactory);
58- register_constant (opfactory);
59- register_arange (opfactory);
60- }
61- // anytype
62- void register_anytype (OpFactory &opfactory){
55+ // io
56+ void register_util (OpFactory &opfactory)
57+ {
6358 opfactory.add_op (Print<float >());
64-
65- opfactory.add_op (Transpose<float >());
66-
67- opfactory.add_op (Reshape<float >());
6859 }
69- // elementwise
70- void register_add (OpFactory &opfactory){
60+
61+ // elementwise
62+ void register_elementwise (OpFactory &opfactory)
63+ {
7164 opfactory.add_op (Add<float >());
7265 opfactory.add_op (Add<double >());
73- }
74- void register_add_scalar (OpFactory &opfactory){
66+
7567 opfactory.add_op (Add_scalar<float >());
7668 opfactory.add_op (Add_scalar<double >());
77- }
78- void register_sub (OpFactory &opfactory){
69+
7970 opfactory.add_op (Sub<float >());
8071 opfactory.add_op (Sub<double >());
81- }
8272
83- void register_mul (OpFactory &opfactory){
8473 opfactory.add_op (Mul<float >());
8574 opfactory.add_op (Mul<double >());
86- }
87- void register_mul_scalar (OpFactory &opfactory){
75+
8876 opfactory.add_op (Mul_scalar<float >());
8977 opfactory.add_op (Mul_scalar<double >());
90- }
91- void register_div (OpFactory &opfactory){
78+
9279 opfactory.add_op (Div<float >());
9380 opfactory.add_op (Div<double >());
94- }
95- void register_div_scalar (OpFactory &opfactory){
81+
9682 opfactory.add_op (Div_scalar<float >());
9783 opfactory.add_op (Div_scalar<double >());
98- }
99- void register_rdiv_scalar (OpFactory &opfactory){
84+
10085 opfactory.add_op (RDiv_scalar<float >());
10186 opfactory.add_op (RDiv_scalar<double >());
102- }
103- void register_sqrt (OpFactory &opfactory){
87+
10488 opfactory.add_op (Sqrt<float >());
10589 opfactory.add_op (Sqrt<double >());
106- }
107- void register_exp (OpFactory &opfactory){
90+
10891 opfactory.add_op (Exp<float >());
10992 opfactory.add_op (Exp<double >());
110- }
111- void register_pow (OpFactory &opfactory){
93+
11294 opfactory.add_op (Pow<float >());
11395 opfactory.add_op (Pow<double >());
114- }
115- void register_pow_scalar (OpFactory &opfactory){
96+
11697 opfactory.add_op (Pow_scalar<float >());
11798 opfactory.add_op (Pow_scalar<double >());
11899 }
119- void register_elementwise_op (OpFactory &opfactory){
120- register_add (opfactory);
121- register_add_scalar (opfactory);
122- register_sub (opfactory);
123- register_mul (opfactory);
124- register_mul_scalar (opfactory);
125- register_div (opfactory);
126- register_div_scalar (opfactory);
127- register_rdiv_scalar (opfactory);
128- register_sqrt (opfactory);
129- register_exp (opfactory);
130- register_pow (opfactory);
131- register_pow_scalar (opfactory);
132- }
133- // concat
134-
135- void register_concat (OpFactory &opfactory){
136- opfactory.add_op (Concat<float >());
137- opfactory.add_op (Concat<double >());
138- }
139- // matmul
140- void register_matmul (OpFactory &opfactory){
100+ // matmul
101+ void register_matmul (OpFactory &opfactory)
102+ {
141103 opfactory.add_op (MatMul<float >());
142104 opfactory.add_op (MatMul<double >());
143105 }
144- // reduce
145- void register_reduce (OpFactory &opfactory){
106+ // changeshape
107+ void register_changeshape (OpFactory &opfactory)
108+ {
109+ opfactory.add_op (Transpose<float >());
110+ opfactory.add_op (Reshape<float >());
111+ opfactory.add_op (Expand<float >());
112+ opfactory.add_op (Concat<float >());
113+ }
114+ // reduce
115+ void register_reduce (OpFactory &opfactory)
116+ {
146117 opfactory.add_op (Max<float >());
147118 opfactory.add_op (Max<double >());
148119 opfactory.add_op (Max_scalar<float >());
@@ -154,13 +125,14 @@ namespace deepx::op
154125 opfactory.add_op (Sum<float >());
155126 opfactory.add_op (Sum<double >());
156127 }
157- int register_all (OpFactory &opfactory){
158- register_new (opfactory);
128+ int register_all (OpFactory &opfactory)
129+ {
130+ register_lifecycle (opfactory);
159131 register_init (opfactory);
160- register_anytype (opfactory);
161- register_elementwise_op (opfactory);
162- register_concat (opfactory);
132+ register_util (opfactory);
133+ register_elementwise (opfactory);
163134 register_matmul (opfactory);
135+ register_changeshape (opfactory);
164136 register_reduce (opfactory);
165137 return 0 ;
166138 }
0 commit comments