Skip to content

Commit 4df1dbc

Browse files
committed
tensorfunc dispatcher: matmul, cblas
1 parent 0fae478 commit 4df1dbc

5 files changed

Lines changed: 319 additions & 312 deletions

File tree

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef DEEPX_TENSORFUNC_CUDA_HPP
2+
#define DEEPX_TENSORFUNC_CUDA_HPP
3+
#include <cuda_fp16.h> // for half-precision support
4+
#include <cuda_bf16.h>
5+
#include <cublas_v2.h>
6+
#include <cstdint>
7+
#include <stdexcept>
8+
9+
#include "deepx/tensor.hpp"
10+
#include "authors.hpp"
11+
12+
namespace deepx::tensorfunc
13+
{
14+
class CublasHandle
15+
{
16+
public:
17+
CublasHandle()
18+
{
19+
if (cublasCreate(&handle_) != CUBLAS_STATUS_SUCCESS)
20+
{
21+
throw std::runtime_error("Failed to create cuBLAS handle");
22+
}
23+
}
24+
25+
~CublasHandle()
26+
{
27+
if (handle_)
28+
cublasDestroy(handle_);
29+
}
30+
31+
cublasHandle_t get() { return handle_; }
32+
33+
private:
34+
cublasHandle_t handle_;
35+
};
36+
37+
}
38+
39+
#endif

excuter/op-mem-cuda/src/deepx/tensorfunc/elementwise_cublas_basic.hpp

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,11 @@
1111
#include "deepx/tensorfunc/elementwise.hpp"
1212
#include "deepx/tensorfunc/elementwise_basic.hpp"
1313
#include "deepx/tensorfunc/authors.hpp"
14+
#include "deepx/tensorfunc/cuda.hpp"
1415
namespace deepx::tensorfunc
1516
{
1617
// cuBLAS handle管理
17-
class CublasHandle
18-
{
19-
public:
20-
CublasHandle()
21-
{
22-
if (cublasCreate(&handle_) != CUBLAS_STATUS_SUCCESS)
23-
{
24-
throw std::runtime_error("Failed to create cuBLAS handle");
25-
}
26-
}
27-
28-
~CublasHandle()
29-
{
30-
if (handle_)
31-
cublasDestroy(handle_);
32-
}
33-
34-
cublasHandle_t get() { return handle_; }
35-
36-
private:
37-
cublasHandle_t handle_;
38-
};
39-
18+
4019
// cublas作者的特化实现
4120
template <>
4221
struct _author_add<cublas>

excuter/op-mem-cuda/src/deepx/tensorfunc/matmul.hpp

Lines changed: 0 additions & 56 deletions
This file was deleted.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifndef DEEPX_TENSORFUNC_MATMUL_CUBLAS_HPP
2+
#define DEEPX_TENSORFUNC_MATMUL_CUBLAS_HPP
3+
4+
#include "deepx/tensor.hpp"
5+
#include "authors.hpp"
6+
7+
namespace deepx::tensorfunc
8+
{
9+
10+
template <typename T>
11+
struct matmulDispatcher<cublas,T>
12+
{
13+
static void matmul(const Tensor<T> &A, const Tensor<T> &B, Tensor<T> &C)
14+
{
15+
if (!check_matmul_shape(A.shape, B.shape))
16+
{
17+
throw std::invalid_argument("A.shape could matmul with B.shape");
18+
}
19+
C.shape.rangeParallel(C.shape.dim - 2, [&](const std::vector<int> &indices)
20+
{
21+
int aIdx=A.shape.linearat(indices);
22+
int bIdx=B.shape.linearat(indices);
23+
int cIdx=C.shape.linearat(indices);
24+
int m=A.shape[-2];
25+
int k=A.shape[-1];
26+
int n=B.shape[-1];
27+
for(int i=0;i<m;i++){
28+
for(int j=0;j<n;j++){
29+
T sum=0;
30+
for(int l=0;l<k;l++){
31+
sum+=A.data[aIdx+i*k+l]*B.data[bIdx+l*n+j];
32+
}
33+
C.data[cIdx+i*n+j]=sum;
34+
}
35+
} });
36+
}
37+
};
38+
39+
template <>
40+
void matmul<float>(const Tensor<float> &a, const Tensor<float> &b, Tensor<float> &c)
41+
{
42+
}
43+
44+
template <>
45+
void matmul<double>(const Tensor<double> &a, const Tensor<double> &b, Tensor<double> &c)
46+
{
47+
}
48+
template <typename T>
49+
void matmuladd(const Tensor<T> &a, const Tensor<T> &b, const T &alpha, const T &beta, Tensor<T> &c)
50+
{
51+
}
52+
53+
template <>
54+
void matmuladd<float>(const Tensor<float> &a, const Tensor<float> &b, const float &alpha, const float &beta, Tensor<float> &c)
55+
{
56+
}
57+
58+
template <>
59+
void matmuladd<double>(const Tensor<double> &a, const Tensor<double> &b, const double &alpha, const double &beta, Tensor<double> &c)
60+
{
61+
}
62+
}
63+
#endif // DEEPX_TENSORFUNC_MATMUL_CUBLAS_HPP

0 commit comments

Comments
 (0)