diff --git a/include/internal/common.h b/include/internal/common.h index 24fa781..fde5680 100644 --- a/include/internal/common.h +++ b/include/internal/common.h @@ -474,6 +474,16 @@ static inline bool checkForEmptyPencils(const cudecompGridDesc_t grid_desc, int return false; } +// Workspace buffer alignment in bytes +static constexpr int CUDECOMP_WORKSPACE_ALIGN_BYTES = 256; + +// Helper to round element count up so its byte size is a multiple of nbytes (assuming smallest supported type float) +static inline int64_t alignCountToBytes(int64_t count, int nbytes) { + int64_t count_bytes = count * sizeof(float); + int64_t count_bytes_rounded = ((count_bytes + nbytes - 1) / nbytes) * nbytes; + return count_bytes_rounded / sizeof(float); +} + } // namespace cudecomp #endif // CUDECOMP_COMMON_H diff --git a/include/internal/halo.h b/include/internal/halo.h index a37a3dc..c9d890f 100644 --- a/include/internal/halo.h +++ b/include/internal/halo.h @@ -191,9 +191,9 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG cudecompBatchedD2DMemcpy3DParams memcpy_params; std::array lx{}; - size_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim]; + int64_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim]; T* send_buff = work; - T* recv_buff = work + 2 * halo_size; + T* recv_buff = work + 2 * alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES); // Pack // Left @@ -204,7 +204,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG // Right lx[dim] = shape_g_h_p[dim] - 2 * halo_extents[dim] - padding[dim]; memcpy_params.src[1] = input + getPencilPtrOffset(pinfo_h_p, lx); - memcpy_params.dest[1] = send_buff + halo_size; + memcpy_params.dest[1] = send_buff + alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES); for (int i = 0; i < 2; ++i) { memcpy_params.src_strides[0][i] = pinfo_h_p.shape[0] * pinfo_h_p.shape[1]; @@ -222,7 +222,7 @@ void 
cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG std::array counts{static_cast(halo_size), static_cast(halo_size)}; std::array offsets{}; - offsets[1] = halo_size; + offsets[1] = static_cast(alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES)); if (handle->performance_report_enable && current_sample) { current_sample->sendrecv_bytes = 0; @@ -239,7 +239,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG memcpy_params.dest[0] = input + getPencilPtrOffset(pinfo_h_p, {0, 0, 0}); // Right - memcpy_params.src[1] = recv_buff + halo_size; + memcpy_params.src[1] = recv_buff + alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES); lx[dim] = shape_g_h_p[dim] - halo_extents[dim] - padding[dim]; memcpy_params.dest[1] = input + getPencilPtrOffset(pinfo_h_p, lx); @@ -273,7 +273,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG // Contiguous (direct send/recv) std::array lx{}; - size_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim]; + int64_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim]; std::array counts{static_cast(halo_size), static_cast(halo_size)}; std::array send_offsets; std::array recv_offsets; diff --git a/include/internal/transpose.h b/include/internal/transpose.h index d4140c9..6328f93 100644 --- a/include/internal/transpose.h +++ b/include/internal/transpose.h @@ -63,7 +63,7 @@ static inline cutensorComputeDescriptor_t getCutensorComputeType(cutensorDataTyp template static inline uint32_t getAlignment(const T* ptr) { auto i_ptr = reinterpret_cast(ptr); - for (uint32_t d = 16; d > 0; d >>= 1) { + for (uint32_t d = CUDECOMP_WORKSPACE_ALIGN_BYTES; d > 0; d >>= 1) { if (i_ptr % d == 0) return d; } return 1; @@ -277,13 +277,13 @@ static void cudecompTranspose_(int ax, int dir, const cudecompHandle_t handle, c // Set input/output pointers for each phase T* i1 = input; T* o1 = work; - T* 
o2 = work + pinfo_a.size; + T* o2 = work + alignCountToBytes(pinfo_a.size, CUDECOMP_WORKSPACE_ALIGN_BYTES); T* o3 = output; #ifdef ENABLE_NVSHMEM if (transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend)) { auto max_pencil_size_a = getGlobalMaxPencilSize(handle, grid_desc, ax_a); - o2 = work + max_pencil_size_a; + o2 = work + alignCountToBytes(max_pencil_size_a, CUDECOMP_WORKSPACE_ALIGN_BYTES); // NVSHMEM team synchronization between transpose operations if (splits_a.size() != 1) { diff --git a/src/cudecomp.cc b/src/cudecomp.cc index 6cbdffa..c37902f 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -1127,7 +1127,14 @@ cudecompResult_t cudecompGetTransposeWorkspaceSize(cudecompHandle_t handle, cude int64_t max_pencil_size_x = getGlobalMaxPencilSize(handle, grid_desc, 0); int64_t max_pencil_size_y = getGlobalMaxPencilSize(handle, grid_desc, 1); int64_t max_pencil_size_z = getGlobalMaxPencilSize(handle, grid_desc, 2); - *workspace_size = std::max(max_pencil_size_x + max_pencil_size_y, max_pencil_size_y + max_pencil_size_z); + + // Round send portion of workspace to 256 byte boundary (in elements, assuming float) + int64_t wsize_xy = alignCountToBytes(max_pencil_size_x, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_y; + int64_t wsize_yx = alignCountToBytes(max_pencil_size_y, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_x; + int64_t wsize_yz = alignCountToBytes(max_pencil_size_y, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_z; + int64_t wsize_zy = alignCountToBytes(max_pencil_size_z, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_y; + + *workspace_size = std::max({wsize_xy, wsize_yx, wsize_yz, wsize_zy}); } catch (const cudecomp::BaseException& e) { std::cerr << e.what(); @@ -1148,11 +1155,16 @@ cudecompResult_t cudecompGetHaloWorkspaceSize(cudecompHandle_t handle, cudecompG cudecompPencilInfo_t pinfo; CHECK_CUDECOMP(cudecompGetPencilInfo(handle, grid_desc, &pinfo, axis, halo_extents, nullptr)); auto shape_g = getShapeG(pinfo); 
- size_t halo_size_x = 4 * shape_g[1] * shape_g[2] * pinfo.halo_extents[0]; - size_t halo_size_y = 4 * shape_g[0] * shape_g[2] * pinfo.halo_extents[1]; - size_t halo_size_z = 4 * shape_g[0] * shape_g[1] * pinfo.halo_extents[2]; - *workspace_size = std::max(halo_size_x, std::max(halo_size_y, halo_size_z)); + // Round all halo slots in workspace to 256 byte boundary (in elements, assuming float) + int64_t halo_size_x = + 4 * alignCountToBytes(shape_g[1] * shape_g[2] * pinfo.halo_extents[0], CUDECOMP_WORKSPACE_ALIGN_BYTES); + int64_t halo_size_y = + 4 * alignCountToBytes(shape_g[0] * shape_g[2] * pinfo.halo_extents[1], CUDECOMP_WORKSPACE_ALIGN_BYTES); + int64_t halo_size_z = + 4 * alignCountToBytes(shape_g[0] * shape_g[1] * pinfo.halo_extents[2], CUDECOMP_WORKSPACE_ALIGN_BYTES); + + *workspace_size = std::max({halo_size_x, halo_size_y, halo_size_z}); } catch (const cudecomp::BaseException& e) { std::cerr << e.what(); return e.getResult();