Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions ds4_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1302,9 +1302,22 @@ extern "C" ds4_gpu_tensor *ds4_gpu_tensor_alloc(uint64_t bytes) {
if (bytes == 0) bytes = 1;
ds4_gpu_tensor *t = (ds4_gpu_tensor *)calloc(1, sizeof(*t));
if (!t) return NULL;
if (!cuda_ok(cudaMalloc(&t->ptr, (size_t)bytes), "tensor alloc")) {
free(t);
return NULL;

if (getenv("DS4_CUDA_MANAGED") != NULL) {
/* Use cudaMallocManaged with cudaMemAttachGlobal so the allocation
* is GPU-accessible across all streams. On UMA platforms (Strix
* Halo, Grace-Hopper) this allocates from the full unified pool,
* bypassing the BIOS VRAM carve-out. */
if (!cuda_ok(cudaMallocManaged(&t->ptr, (size_t)bytes, cudaMemAttachGlobal),
"managed tensor alloc")) {
free(t);
return NULL;
}
} else {
if (!cuda_ok(cudaMalloc(&t->ptr, (size_t)bytes), "tensor alloc")) {
free(t);
return NULL;
}
}
t->bytes = bytes;
t->owner = 1;
Expand Down Expand Up @@ -6168,7 +6181,7 @@ extern "C" int ds4_gpu_attention_prefill_raw_heads_tensor(ds4_gpu_tensor *heads,
if (!tmp) return 0;
float *scores = tmp;
float *out_tmp = (float *)((char *)tmp + out_offset);
const float alpha = rsqrtf((float)head_dim);
const float alpha = 1.0f / sqrtf((float)head_dim);
const float beta = 0.0f;
cublasStatus_t st = cublasSgemmStridedBatched(g_cublas,
CUBLAS_OP_T,
Expand Down Expand Up @@ -6538,7 +6551,7 @@ static int attention_prefill_mixed_launch(
n_comp,
head_dim);
if (!cuda_ok(cudaGetLastError(), "attention mixed kv pack launch")) return 0;
const float alpha = rsqrtf((float)head_dim);
const float alpha = 1.0f / sqrtf((float)head_dim);
const float beta = 0.0f;
cublasStatus_t st = cublasSgemmStridedBatched(g_cublas,
CUBLAS_OP_T,
Expand Down
1 change: 1 addition & 0 deletions ds4_rocm.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#define cudaSuccess hipSuccess
#define cudaErrorNotSupported hipErrorNotSupported
#define cudaMemAttachGlobal hipMemAttachGlobal
#define cudaErrorInvalidValue hipErrorInvalidValue
#define cudaGetLastError hipGetLastError
#define cudaGetErrorString hipGetErrorString
Expand Down