diff --git a/ds4_cuda.cu b/ds4_cuda.cu index 8b6241ca3..8f467b63c 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -1302,9 +1302,22 @@ extern "C" ds4_gpu_tensor *ds4_gpu_tensor_alloc(uint64_t bytes) { if (bytes == 0) bytes = 1; ds4_gpu_tensor *t = (ds4_gpu_tensor *)calloc(1, sizeof(*t)); if (!t) return NULL; - if (!cuda_ok(cudaMalloc(&t->ptr, (size_t)bytes), "tensor alloc")) { - free(t); - return NULL; + + if (getenv("DS4_CUDA_MANAGED") != NULL) { + /* Use cudaMallocManaged with cudaMemAttachGlobal so the allocation + * is GPU-accessible across all streams. On UMA platforms (Strix + * Halo, Grace-Hopper) this allocates from the full unified pool, + * bypassing the BIOS VRAM carve-out. */ + if (!cuda_ok(cudaMallocManaged(&t->ptr, (size_t)bytes, cudaMemAttachGlobal), + "managed tensor alloc")) { + free(t); + return NULL; + } + } else { + if (!cuda_ok(cudaMalloc(&t->ptr, (size_t)bytes), "tensor alloc")) { + free(t); + return NULL; + } } t->bytes = bytes; t->owner = 1; @@ -6168,7 +6181,7 @@ extern "C" int ds4_gpu_attention_prefill_raw_heads_tensor(ds4_gpu_tensor *heads, if (!tmp) return 0; float *scores = tmp; float *out_tmp = (float *)((char *)tmp + out_offset); - const float alpha = rsqrtf((float)head_dim); + const float alpha = 1.0f / sqrtf((float)head_dim); const float beta = 0.0f; cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, CUBLAS_OP_T, @@ -6538,7 +6551,7 @@ static int attention_prefill_mixed_launch( n_comp, head_dim); if (!cuda_ok(cudaGetLastError(), "attention mixed kv pack launch")) return 0; - const float alpha = rsqrtf((float)head_dim); + const float alpha = 1.0f / sqrtf((float)head_dim); const float beta = 0.0f; cublasStatus_t st = cublasSgemmStridedBatched(g_cublas, CUBLAS_OP_T, diff --git a/ds4_rocm.h b/ds4_rocm.h index 0400910df..55907c428 100644 --- a/ds4_rocm.h +++ b/ds4_rocm.h @@ -13,6 +13,7 @@ #define cudaSuccess hipSuccess #define cudaErrorNotSupported hipErrorNotSupported +#define cudaMemAttachGlobal hipMemAttachGlobal #define cudaErrorInvalidValue hipErrorInvalidValue #define cudaGetLastError hipGetLastError #define cudaGetErrorString hipGetErrorString