diff --git a/excuter/cpp-common/src/deepx/dtype.hpp b/excuter/cpp-common/src/deepx/dtype.hpp index 73b4ba5b..0a2625aa 100644 --- a/excuter/cpp-common/src/deepx/dtype.hpp +++ b/excuter/cpp-common/src/deepx/dtype.hpp @@ -475,5 +475,43 @@ namespace deepx } } + template + struct PrecisionWrapper {}; + + template + struct to_tensor_type; + + template <> + struct to_tensor_type> { + using type = double; + }; + + template <> + struct to_tensor_type> { + using type = float; + }; + + template <> + struct to_tensor_type> { + using type = int64_t; + }; + + template <> + struct to_tensor_type> { + using type = int32_t; + }; + + template <> + struct to_tensor_type> { + using type = int16_t; + }; + + template <> + struct to_tensor_type> { + using type = int8_t; + }; + + template + using tensor_t = typename to_tensor_type>::type; } // namespace deepx #endif diff --git a/excuter/cpp-common/test/0_dtypes.cpp b/excuter/cpp-common/test/0_dtypes.cpp index 766a1974..1761f010 100644 --- a/excuter/cpp-common/test/0_dtypes.cpp +++ b/excuter/cpp-common/test/0_dtypes.cpp @@ -6,9 +6,7 @@ using namespace std; using namespace deepx::tf; using namespace deepx; -int main(int argc, char **argv) -{ - +void test_1() { unordered_map dtype_map = { {"tensor", make_dtype(DataCategory::Tensor, Precision::Any)}, {"tensor", make_dtype(DataCategory::Tensor, Precision::Int)}, @@ -54,6 +52,26 @@ int main(int argc, char **argv) } cout << string(80, '=') << endl; +} + +// test to tensor type +void test_2() { + if (typeid(tensor_t)== typeid(double)) { + std::cout<<"it's ok"<)== typeid(float)) { + std::cout<<"it's ok"<