Skip to content

Commit 1f1e9a9

Browse files
committed
cuda dtype transfer
1 parent 73addbb commit 1f1e9a9

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include <cuda_fp16.h>
55
#include <cuda_bf16.h>
6+
#include <cuda_fp8.h>
7+
#include <cuda_fp4.h>
68

79
#include "deepx/dtype.hpp"
810

@@ -34,6 +36,27 @@ namespace deepx
3436
else
3537
return Precision::Any;
3638
}
39+
40+
41+
// Specialization: Precision::BFloat16 is represented on the CUDA side by
// nv_bfloat16 (provided by <cuda_bf16.h>).
template <>
struct to_tensor_type<PrecisionWrapper<Precision::BFloat16>>
{
    using type = nv_bfloat16;
};
45+
46+
// Specialization: Precision::Float16 is represented on the CUDA side by
// half (provided by <cuda_fp16.h>).
template <>
struct to_tensor_type<PrecisionWrapper<Precision::Float16>>
{
    using type = half;
};
50+
51+
// Specialization: Precision::Float8E5M2 is represented on the CUDA side by
// __nv_fp8_e5m2 (provided by <cuda_fp8.h>).
template <>
struct to_tensor_type<PrecisionWrapper<Precision::Float8E5M2>>
{
    using type = __nv_fp8_e5m2;
};
55+
56+
// Specialization: Precision::Float8e4m3 is represented on the CUDA side by
// __nv_fp8_e4m3 (provided by <cuda_fp8.h>).
//
// NOTE(review): the enumerator is spelled `Float8e4m3` here while the sibling
// specialization uses `Float8E5M2`; confirm against the Precision enum in
// deepx/dtype.hpp that this casing is the declared name.
template <>
struct to_tensor_type<PrecisionWrapper<Precision::Float8e4m3>>
{
    using type = __nv_fp8_e4m3;
}; // the terminating ';' is required after a struct definition
3760
}
3861

3962
#endif // DEEPX_DTYPE_CUDA_HPP

0 commit comments

Comments
 (0)