Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/internal/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,16 @@ static inline bool checkForEmptyPencils(const cudecompGridDesc_t grid_desc, int
return false;
}

// Workspace buffer alignment in bytes
static constexpr int CUDECOMP_WORKSPACE_ALIGN_BYTES = 256;

// Helper to round element count to nearest multiple of nbytes (assuming smallest supported type float)
static inline int64_t alignCountToBytes(int64_t count, int nbytes) {
int64_t count_bytes = count * sizeof(float);
int64_t count_bytes_rounded = ((count_bytes + nbytes - 1) / nbytes) * nbytes;
return count_bytes_rounded / sizeof(float);
}

} // namespace cudecomp

#endif // CUDECOMP_COMMON_H
12 changes: 6 additions & 6 deletions include/internal/halo.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
cudecompBatchedD2DMemcpy3DParams<T> memcpy_params;
std::array<int32_t, 3> lx{};

size_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim];
int64_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim];
T* send_buff = work;
T* recv_buff = work + 2 * halo_size;
T* recv_buff = work + 2 * alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES);

// Pack
// Left
Expand All @@ -204,7 +204,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
// Right
lx[dim] = shape_g_h_p[dim] - 2 * halo_extents[dim] - padding[dim];
memcpy_params.src[1] = input + getPencilPtrOffset(pinfo_h_p, lx);
memcpy_params.dest[1] = send_buff + halo_size;
memcpy_params.dest[1] = send_buff + alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES);

for (int i = 0; i < 2; ++i) {
memcpy_params.src_strides[0][i] = pinfo_h_p.shape[0] * pinfo_h_p.shape[1];
Expand All @@ -222,7 +222,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG

std::array<comm_count_t, 2> counts{static_cast<comm_count_t>(halo_size), static_cast<comm_count_t>(halo_size)};
std::array<size_t, 2> offsets{};
offsets[1] = halo_size;
offsets[1] = static_cast<size_t>(alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES));

if (handle->performance_report_enable && current_sample) {
current_sample->sendrecv_bytes = 0;
Expand All @@ -239,7 +239,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
memcpy_params.dest[0] = input + getPencilPtrOffset(pinfo_h_p, {0, 0, 0});

// Right
memcpy_params.src[1] = recv_buff + halo_size;
memcpy_params.src[1] = recv_buff + alignCountToBytes(halo_size, CUDECOMP_WORKSPACE_ALIGN_BYTES);
lx[dim] = shape_g_h_p[dim] - halo_extents[dim] - padding[dim];
memcpy_params.dest[1] = input + getPencilPtrOffset(pinfo_h_p, lx);

Expand Down Expand Up @@ -273,7 +273,7 @@ void cudecompUpdateHalos_(int ax, const cudecompHandle_t handle, const cudecompG
// Contiguous (direct send/recv)
std::array<int32_t, 3> lx{};

size_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim];
int64_t halo_size = shape_g_h[(dim + 1) % 3] * shape_g_h[(dim + 2) % 3] * halo_extents[dim];
std::array<comm_count_t, 2> counts{static_cast<comm_count_t>(halo_size), static_cast<comm_count_t>(halo_size)};
std::array<size_t, 2> send_offsets;
std::array<size_t, 2> recv_offsets;
Expand Down
6 changes: 3 additions & 3 deletions include/internal/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static inline cutensorComputeDescriptor_t getCutensorComputeType(cutensorDataTyp

template <typename T> static inline uint32_t getAlignment(const T* ptr) {
auto i_ptr = reinterpret_cast<std::uintptr_t>(ptr);
for (uint32_t d = 16; d > 0; d >>= 1) {
for (uint32_t d = CUDECOMP_WORKSPACE_ALIGN_BYTES; d > 0; d >>= 1) {
if (i_ptr % d == 0) return d;
}
return 1;
Expand Down Expand Up @@ -277,13 +277,13 @@ static void cudecompTranspose_(int ax, int dir, const cudecompHandle_t handle, c
// Set input/output pointers for each phase
T* i1 = input;
T* o1 = work;
T* o2 = work + pinfo_a.size;
T* o2 = work + alignCountToBytes(pinfo_a.size, CUDECOMP_WORKSPACE_ALIGN_BYTES);
T* o3 = output;

#ifdef ENABLE_NVSHMEM
if (transposeBackendRequiresNvshmem(grid_desc->config.transpose_comm_backend)) {
auto max_pencil_size_a = getGlobalMaxPencilSize(handle, grid_desc, ax_a);
o2 = work + max_pencil_size_a;
o2 = work + alignCountToBytes(max_pencil_size_a, CUDECOMP_WORKSPACE_ALIGN_BYTES);

// NVSHMEM team synchronization between transpose operations
if (splits_a.size() != 1) {
Expand Down
22 changes: 17 additions & 5 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,14 @@ cudecompResult_t cudecompGetTransposeWorkspaceSize(cudecompHandle_t handle, cude
int64_t max_pencil_size_x = getGlobalMaxPencilSize(handle, grid_desc, 0);
int64_t max_pencil_size_y = getGlobalMaxPencilSize(handle, grid_desc, 1);
int64_t max_pencil_size_z = getGlobalMaxPencilSize(handle, grid_desc, 2);
*workspace_size = std::max(max_pencil_size_x + max_pencil_size_y, max_pencil_size_y + max_pencil_size_z);

// Round send portion of workspace to 256 byte boundary (in elements, assuming float)
int64_t wsize_xy = alignCountToBytes(max_pencil_size_x, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_y;
int64_t wsize_yx = alignCountToBytes(max_pencil_size_y, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_x;
int64_t wsize_yz = alignCountToBytes(max_pencil_size_y, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_z;
int64_t wsize_zy = alignCountToBytes(max_pencil_size_z, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_y;

*workspace_size = std::max({wsize_xy, wsize_yx, wsize_yz, wsize_zy});

} catch (const cudecomp::BaseException& e) {
std::cerr << e.what();
Expand All @@ -1148,11 +1155,16 @@ cudecompResult_t cudecompGetHaloWorkspaceSize(cudecompHandle_t handle, cudecompG
cudecompPencilInfo_t pinfo;
CHECK_CUDECOMP(cudecompGetPencilInfo(handle, grid_desc, &pinfo, axis, halo_extents, nullptr));
auto shape_g = getShapeG(pinfo);
size_t halo_size_x = 4 * shape_g[1] * shape_g[2] * pinfo.halo_extents[0];
size_t halo_size_y = 4 * shape_g[0] * shape_g[2] * pinfo.halo_extents[1];
size_t halo_size_z = 4 * shape_g[0] * shape_g[1] * pinfo.halo_extents[2];

*workspace_size = std::max(halo_size_x, std::max(halo_size_y, halo_size_z));
// Round all halo slots in workspace to 256 byte boundary (in elements, assuming float)
int64_t halo_size_x =
4 * alignCountToBytes(shape_g[1] * shape_g[2] * pinfo.halo_extents[0], CUDECOMP_WORKSPACE_ALIGN_BYTES);
int64_t halo_size_y =
4 * alignCountToBytes(shape_g[0] * shape_g[2] * pinfo.halo_extents[1], CUDECOMP_WORKSPACE_ALIGN_BYTES);
int64_t halo_size_z =
4 * alignCountToBytes(shape_g[0] * shape_g[1] * pinfo.halo_extents[2], CUDECOMP_WORKSPACE_ALIGN_BYTES);

*workspace_size = std::max({halo_size_x, halo_size_y, halo_size_z});
} catch (const cudecomp::BaseException& e) {
std::cerr << e.what();
return e.getResult();
Expand Down
Loading