@@ -17,6 +17,8 @@ namespace deepx::tensorfunc
1717 }
1818 }
1919
20+
21+
2022 // 实现特化版本的成员函数
2123 void _constant_func<miaobyte, float >::func(Tensor<float > &tensor, const float value)
2224 {
@@ -59,4 +61,57 @@ namespace deepx::tensorfunc
5961 throw std::runtime_error (" Failed to launch constant kernel" );
6062 }
6163 }
64+
65+ // 添加kernel函数
66+ template <typename T>
67+ __global__ void kernel_arange (T *data, int size, T start, T step)
68+ {
69+ int idx = blockIdx .x * blockDim .x + threadIdx .x ;
70+ if (idx < size)
71+ {
72+ data[idx] = start + step * static_cast <T>(idx);
73+ }
74+ }
75+
76+ void _arange_func<miaobyte, float >::func(Tensor<float > &tensor, const float start, const float step)
77+ {
78+ int size = tensor.shape .size ;
79+ int blockSize = 256 ;
80+ int numBlocks = (size + blockSize - 1 ) / blockSize;
81+
82+ kernel_arange<<<numBlocks, blockSize>>> (tensor.data , size, start, step);
83+
84+ cudaError_t err = cudaGetLastError ();
85+ if (err != cudaSuccess) {
86+ throw std::runtime_error (" Failed to launch arange kernel" );
87+ }
88+ }
89+
90+ void _arange_func<miaobyte, double >::func(Tensor<double > &tensor, const double start, const double step)
91+ {
92+ int size = tensor.shape .size ;
93+ int blockSize = 256 ;
94+ int numBlocks = (size + blockSize - 1 ) / blockSize;
95+
96+ kernel_arange<<<numBlocks, blockSize>>> (tensor.data , size, start, step);
97+
98+ cudaError_t err = cudaGetLastError ();
99+ if (err != cudaSuccess) {
100+ throw std::runtime_error (" Failed to launch arange kernel" );
101+ }
102+ }
103+
104+ void _arange_func<miaobyte, __half>::func(Tensor<__half> &tensor, const __half start, const __half step)
105+ {
106+ int size = tensor.shape .size ;
107+ int blockSize = 256 ;
108+ int numBlocks = (size + blockSize - 1 ) / blockSize;
109+
110+ kernel_arange<<<numBlocks, blockSize>>> (tensor.data , size, start, step);
111+
112+ cudaError_t err = cudaGetLastError ();
113+ if (err != cudaSuccess) {
114+ throw std::runtime_error (" Failed to launch arange kernel" );
115+ }
116+ }
62117}
0 commit comments