Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ static bool common_pull_file(httplib::Client & cli,
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
if (callback) {
callback->on_update(p);
if (callback->is_cancelled()) {
return false;
}
}
progress_step = 0;
}
Expand Down Expand Up @@ -373,6 +376,9 @@ static int common_download_file_single_online(const std::string & url,
}

for (int i = 0; i < max_attempts; ++i) {
if (opts.callback && opts.callback->is_cancelled()) {
break;
}
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
std::this_thread::sleep_for(std::chrono::seconds(delay));
Expand Down Expand Up @@ -412,6 +418,12 @@ static int common_download_file_single_online(const std::string & url,
if (opts.callback) {
opts.callback->on_done(p, success);
}
if (opts.callback && opts.callback->is_cancelled() &&
std::filesystem::exists(path_temporary)) {
if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, path_temporary.c_str());
}
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
Expand Down
1 change: 1 addition & 0 deletions common/download.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class common_download_callback {
virtual void on_start(const common_download_progress & p) = 0;
virtual void on_update(const common_download_progress & p) = 0;
virtual void on_done(const common_download_progress & p, bool ok) = 0;
virtual bool is_cancelled() const { return false; }
};

struct common_remote_params {
Expand Down
79 changes: 56 additions & 23 deletions ggml/src/ggml-cuda/argsort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,

size_t temp_storage_bytes = 0;

bool is_capturing = false;
#ifdef USE_CUDA_GRAPH
// Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does.
// See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149
// TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release.
cudaStreamCaptureStatus capture_status;
CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status));
is_capturing = (capture_status != cudaStreamCaptureStatusNone);
#endif // USE_CUDA_GRAPH

if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
offset_iterator, offset_iterator + 1, stream));
}
} else {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, stream));
}
}

Expand All @@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,

if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream));
}
} else {
if (nrows == 1) {
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols, 0, sizeof(float) * 8, stream));
} else if (is_capturing) {
CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending(
d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream));
} else {
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream));
CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys,
temp_keys, temp_indices, dst, ncols * nrows, nrows,
offset_iterator, offset_iterator + 1, stream));
}
}
}
Expand Down
97 changes: 76 additions & 21 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params {
}
};

static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type);
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);

static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
GGML_UNUSED(kv_type);

vk_fa_tuning_params result{};
result.path = FA_SCALAR;
Expand Down Expand Up @@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,

result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;

if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) {
result.block_rows /= 2;
}

Expand Down Expand Up @@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->fp16) {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, )
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, )
}
} else {
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product && device->subgroup_clustered) {
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8)
} else
#endif
{
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32)
}
}
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
Expand Down Expand Up @@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}

static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
GGML_UNUSED(f32acc);
// Needs to be kept up to date on shader changes
const uint32_t wg_size = params.workgroup_size;
Expand All @@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con

const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);

const bool mmq = device->integer_dot_product && device->subgroup_clustered &&
(kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 ||
kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 ||
kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL);

// tmpsh is overestimated slightly
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;

const uint32_t masksh = Bc * (Br + 1) * float_type_size;

const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
uint32_t Qf, kvsh, kblocksh_size;
if (mmq) {
// block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
Qf = Br * (hsk / 32) * block_b_size;

// kvsh uses D = HSV (K goes through kblocksh instead)
kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;

// block_a_cache size depends on quant type
uint32_t block_a_size;
switch (kv_type) {
case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break;
case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break;
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break;
default: block_a_size = 0; break;
}
kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
} else {
Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;

const uint32_t D = std::max(hsk, hsv);
kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;

const uint32_t D = std::max(hsk, hsv);
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
kblocksh_size = 0;
}

const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);

return supported;
}
Expand Down
Loading
Loading