|
3 | 3 |
|
4 | 4 | #include <cuda_bf16.h> |
5 | 5 | #include <cuda_fp16.h> |
| 6 | +#include <curand_kernel.h> |
| 7 | + |
| 8 | + |
6 | 9 | #include "deepx/tensorfunc/cuda.hpp" |
7 | 10 | #include "deepx/tensorfunc/authors.hpp" |
8 | 11 | #include "deepx/tensorfunc/cuda_math.cuh" |
@@ -404,6 +407,48 @@ namespace deepx::tensorfunc |
404 | 407 | template void launch_invert<int16_t>(const int16_t *a, int16_t *c, const int size); |
405 | 408 | template void launch_invert<int8_t>(const int8_t *a, int8_t *c, const int size); |
406 | 409 |
|
// Dropout kernel.
//
// Copies A into C element by element, writing 0 instead of A[idx] with
// probability p. Kept elements are copied unchanged; no 1/(1-p) rescaling is
// applied here.
//
// Parameters:
//   A    - input buffer of `size` elements
//   p    - drop probability, compared against a uniform sample from (0, 1]
//   seed - cuRAND seed shared by every thread of the launch
//   C    - output buffer of `size` elements
//   size - number of elements to process
//
// Each thread starts at its global thread index and advances by the total
// number of launched threads, so any 1D launch configuration covers all
// `size` elements.
template <typename T>
__global__ void dropout_kernel(const T *A, const float p, const unsigned int seed, T *C, const int size)
{
    // Widen before multiplying so the flat index and the stride cannot wrap
    // in 32-bit arithmetic for large launches.
    const unsigned long long tid =
        static_cast<unsigned long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    const unsigned long long stride =
        static_cast<unsigned long long>(gridDim.x) * blockDim.x;

    // Threads with no element to process skip the RNG initialisation entirely.
    if (size <= 0 || tid >= static_cast<unsigned long long>(size))
    {
        return;
    }
    const unsigned long long count = static_cast<unsigned long long>(size);

    // Initialise the generator once per thread. The subsequence is the GLOBAL
    // thread index: using threadIdx.x alone would give every block the same
    // random stream and repeat the drop mask with a period of blockDim.x.
    curandState state;
    curand_init(seed, tid, 0, &state);

    for (unsigned long long idx = tid; idx < count; idx += stride)
    {
        // curand_uniform yields values in (0, 1]; `<=` makes p == 1 drop every
        // element while p == 0 still keeps every element.
        const float sample = curand_uniform(&state);
        if (sample <= p)
        {
            C[idx] = 0;
        }
        else
        {
            C[idx] = A[idx];
        }
    }
}
| 431 | + |
// Host-side launcher for dropout_kernel.
//
// Parameters:
//   a    - input buffer of `size` elements, passed straight to the kernel
//   p    - drop probability forwarded to the kernel
//   seed - cuRAND seed forwarded to the kernel
//   c    - output buffer of `size` elements, passed straight to the kernel
//   size - number of elements; values <= 0 make this call a no-op
//
// Throws std::runtime_error when cudaGetLastError() reports an error right
// after the launch. The function does not synchronize, so this check only
// covers errors visible at launch time.
template <typename T>
void launch_dropout(const T *a, const float p, const unsigned int seed, T *c, const int size)
{
    // Nothing to do for an empty tensor; return before requesting a launch
    // configuration for zero elements.
    if (size <= 0)
    {
        return;
    }

    auto [numBlocks, blockSize] = BestDims(size);
    dropout_kernel<T><<<numBlocks, blockSize>>>(a, p, seed, c, size);

    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        throw std::runtime_error("Failed to launch dropout kernel: " +
                                 std::string(cudaGetErrorString(err)));
    }
}
// Explicit instantiations of launch_dropout for every element type exported
// from this translation unit.
// Floating-point element types.
template void launch_dropout<double>(const double *a, const float p, const unsigned int seed, double *c, const int size);
template void launch_dropout<float>(const float *a, const float p, const unsigned int seed, float *c, const int size);
template void launch_dropout<half>(const half *a, const float p, const unsigned int seed, half *c, const int size);
template void launch_dropout<nv_bfloat16>(const nv_bfloat16 *a, const float p, const unsigned int seed, nv_bfloat16 *c, const int size);
// Signed integer element types.
template void launch_dropout<int64_t>(const int64_t *a, const float p, const unsigned int seed, int64_t *c, const int size);
template void launch_dropout<int32_t>(const int32_t *a, const float p, const unsigned int seed, int32_t *c, const int size);
template void launch_dropout<int16_t>(const int16_t *a, const float p, const unsigned int seed, int16_t *c, const int size);
template void launch_dropout<int8_t>(const int8_t *a, const float p, const unsigned int seed, int8_t *c, const int size);
407 | 452 | } |
408 | 453 |
|
409 | 454 | #endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAOBYTE_BASIC_CU |
0 commit comments