11#ifndef DEEPX_TF_TFFACTORY_HPP
22#define DEEPX_TF_TFFACTORY_HPP
33
4- #include < unordered_map>
5- #include < string>
6- #include < memory>
7- #include < vector>
8- #include < algorithm>
9-
104#include " deepx/tf/tf.hpp"
115
126namespace deepx ::tf
137{
8+
149 struct TypeSignature
1510 {
1611 vector<TypeDef> args;
1712 vector<TypeDef> returns;
1813
19- bool is_compatible (const TypeSignature& other) const {
20- return is_compatible_types (args, other.args ) &&
14+ bool is_compatible (const TypeSignature &other) const
15+ {
16+ return is_compatible_types (args, other.args ) &&
2117 is_compatible_types (returns, other.returns );
2218 }
2319
2420 private:
25- static bool is_compatible_types (const vector<TypeDef>& a, const vector<TypeDef>& b) {
26- if (a.size () != b.size ()) return false ;
27- for (size_t i = 0 ; i < a.size (); i++) {
28- if ((static_cast <uint8_t >(a[i].parts .category ) &
29- static_cast <uint8_t >(b[i].parts .category )) == 0 ) {
21+ static bool is_compatible_types (const vector<TypeDef> &a, const vector<TypeDef> &b)
22+ {
23+ if (a.size () != b.size ())
24+ return false ;
25+ for (size_t i = 0 ; i < a.size (); i++)
26+ {
27+ if ((static_cast <uint8_t >(a[i].parts .category ) &
28+ static_cast <uint8_t >(b[i].parts .category )) == 0 )
29+ {
3030 return false ;
3131 }
32- if (a[i].parts .precision != Precision::Any &&
33- b[i].parts .precision != Precision::Any &&
34- a[i].parts .precision != b[i].parts .precision ) {
32+ if (a[i].parts .precision != Precision::Any &&
33+ b[i].parts .precision != Precision::Any &&
34+ a[i].parts .precision != b[i].parts .precision )
35+ {
3536 return false ;
3637 }
3738 }
@@ -44,23 +45,28 @@ namespace deepx::tf
4445 vector<std::shared_ptr<TF>> tfs;
4546
4647 // 获取匹配的TF实现
47- std::shared_ptr<TF> get_matching_tf (const vector<TypeDef>& arg_types,
48- const vector<TypeDef>& return_types) const {
48+ std::shared_ptr<TF> get_matching_tf (const vector<TypeDef> &arg_types,
49+ const vector<TypeDef> &return_types) const
50+ {
4951 TypeSignature target{arg_types, return_types};
50-
51- for (const auto & tf : tfs) {
52+
53+ for (const auto &tf : tfs)
54+ {
5255 vector<TypeDef> tf_arg_types;
53- for (const auto & arg : tf->args ) {
56+ for (const auto &arg : tf->args )
57+ {
5458 tf_arg_types.push_back (arg.dtype );
5559 }
56-
60+
5761 vector<TypeDef> tf_return_types;
58- for (const auto & ret : tf->returns ) {
62+ for (const auto &ret : tf->returns )
63+ {
5964 tf_return_types.push_back (ret.dtype );
6065 }
61-
66+
6267 TypeSignature current{tf_arg_types, tf_return_types};
63- if (target.is_compatible (current)) {
68+ if (target.is_compatible (current))
69+ {
6470 return tf;
6571 }
6672 }
@@ -101,9 +107,6 @@ namespace deepx::tf
101107 // 输出为markdown表格格式
102108 string print_markdown () const ;
103109 };
104-
105- int register_all (TfFactory &tfactory);
106110}
107-
108111
109112#endif
0 commit comments