Skip to content

Commit 2f25a52

Browse files
committed
tensorfunc:dispatcher,cblas,openblas支持
1 parent ca75186 commit 2f25a52

8 files changed

Lines changed: 372 additions & 614 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#ifndef DEEPX_TENSORFUNC_CHANGE_SHAPE_HPP
2+
#define DEEPX_TENSORFUNC_CHANGE_SHAPE_HPP
3+
4+
#include "deepx/tensor.hpp"
5+
#include "stdutil/error.hpp"
6+
7+
namespace deepx::tensorfunc
8+
{
9+
10+
// 通用模板声明
11+
template <typename Author, typename T>
12+
struct InitDispatcher
13+
{
14+
static void reshape(Tensor<T> &tensor, const Shape &new_shape) = delete;
15+
};
16+
17+
template <typename Author, typename T>
18+
void reshape(Tensor<T> &tensor, const Shape &new_shape)
19+
{
20+
InitDispatcher<Author, T>::reshape(tensor, new_shape);
21+
}
22+
23+
// // 作者特化示例(类型无关实现)
24+
// template <typename T>
25+
// struct InitDispatcher<miaobyte, T>
26+
// {
27+
// static void reshape(Tensor<T> &tensor, const Shape &new_shape)
28+
// {
29+
// // 统一实现,不依赖T的类型
30+
// if (tensor.shape.size() != new_shape.size())
31+
// {
32+
// throw std::invalid_argument("Total elements must match");
33+
// }
34+
// tensor.shape = new_shape;
35+
// }
36+
// };
37+
// 特化作者和具体精度
38+
// template <>
39+
// struct InitDispatcher<miaobyte, float>
40+
// {
41+
// static void reshape(Tensor<float> &tensor, const Shape &new_shape)
42+
// {
43+
// // CUDA实现
44+
// }
45+
// };
46+
}
47+
48+
#endif

0 commit comments

Comments
 (0)