ggml: GGML_OP_ARGMAX returns the first (smallest) index on ties, on every backend

What the hunks below change:
  * CPU (vec.h): the old MAX()/== loop overwrote idx on every element equal to
    the running max, i.e. the last tied index won. A strict '>' keeps the first.
  * CUDA / SYCL / Vulkan: the reductions only compared with a strict '>', so on
    equal values the receiving slot kept its own index regardless of which index
    was smaller. An explicit "equal value and smaller index" branch is added.
  * Metal: simd_max over the candidate indices picked the largest tied index.
    It now uses simd_min, with INT32_MAX as the "not a candidate" sentinel.
  * WebGPU: '>=' in the per-thread scan and 'b.index > a.index' in the tree
    reduction both preferred the larger index; both are flipped.
  * tests: new test_argmax_ties places the maximum at three positions per row
    and requires an exact match (max_nmse_err == 0).

Reviewer notes on this revision of the patch:
  * The patch text had its newlines collapsed and every <...> template argument
    stripped. Line structure and template arguments are restored here.
    Indentation and intra-line alignment are reconstructed, so if a context line
    does not match, apply with `git apply --ignore-whitespace`.
  * test_argmax_ties::initialize_tensors now asserts ne0 > first_idx (the writes
    to data[first_idx] were unchecked), builds the row buffer once instead of
    once per row (every row has identical contents), and uses an int64_t loop
    index to match ne0. The test hunk is therefore two lines longer; hunk
    headers are adjusted, and the post-image blob id on the tests "index" line
    predates these edits.

diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index 5de9cb5b7e09..40fdcc585d56 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -1556,11 +1556,14 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 }
 
 inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+    // first index wins on ties (torch.argmax / np.argmax convention)
     float max = -INFINITY;
     int idx = 0;
     for (int i = 0; i < n; ++i) {
-        max = MAX(max, x[i]);
-        if (max == x[i]) { idx = i; }
+        if (x[i] > max) {
+            max = x[i];
+            idx = i;
+        }
     }
     *s = idx;
 }
diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu
index 51967c667cfd..45852a3c1fda 100644
--- a/ggml/src/ggml-cuda/argmax.cu
+++ b/ggml/src/ggml-cuda/argmax.cu
@@ -27,6 +27,9 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
         if (val > maxval) {
             maxval = val;
             argmax = col;
+        } else if (val == maxval && col < argmax) {
+            // keep the smallest index on ties (first-occurrence wins)
+            argmax = col;
         }
     }
 
@@ -56,6 +59,9 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
                 if (val > maxval) {
                     maxval = val;
                     argmax = col;
+                } else if (val == maxval && col < argmax) {
+                    // keep the smallest index on ties (first-occurrence wins)
+                    argmax = col;
                 }
             }
         }
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 25e78e100898..60405d24e458 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -2914,8 +2914,10 @@ kernel void kernel_argmax_f32(
     }
 
     // find the argmax value in the block
+    // keep the smallest index on ties: non-max threads get a sentinel so simd_min ignores them
+    constexpr int32_t ARGMAX_SENTINEL = numeric_limits<int32_t>::max();
     float max_val = simd_max(lmax);
-    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+    int32_t arg_val = simd_min(select(ARGMAX_SENTINEL, larg, lmax == max_val));
 
     device int32_t * dst_i32 = (device int32_t *) dst;
 
@@ -2941,7 +2943,7 @@ kernel void kernel_argmax_f32(
         arg_val = shared_argmax[tiisg];
 
         float max_val_reduced = simd_max(max_val);
-        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+        int32_t arg_val_reduced = simd_min(select(ARGMAX_SENTINEL, arg_val, max_val == max_val_reduced));
 
         dst_i32[tgpig] = arg_val_reduced;
 
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 41449db665ec..20a33dcdf4af 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2287,6 +2287,9 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
                         if (val2 > val1) {
                             shared_data[tid] = val2;
                             shared_indices[tid] = shared_indices[tid + stride];
+                        } else if (val2 == val1 && shared_indices[tid + stride] < shared_indices[tid]) {
+                            // keep the smallest index on ties (first-occurrence wins)
+                            shared_indices[tid] = shared_indices[tid + stride];
                         }
                     }
                     item_ct1.barrier(sycl::access::fence_space::local_space);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
index 7c128776710e..bbad824e2d58 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
@@ -49,6 +49,9 @@ void main() {
             if (tmpmax[col] < tmpmax[col + s]) {
                 tmpmax[col] = tmpmax[col + s];
                 tmp[col] = tmp[col + s];
+            } else if (tmpmax[col] == tmpmax[col + s] && tmp[col + s] < tmp[col]) {
+                // keep the smallest index on ties (first-occurrence wins)
+                tmp[col] = tmp[col + s];
             }
         }
         barrier();
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
index ca5bfcc4d4c9..d5ceb175b37c 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
@@ -38,14 +38,14 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
         let vec_val = src[row_idx / VEC_SIZE + col];
         for (var v = 0u; v < VEC_SIZE; v++) {
             let val = vec_val[v];
-            if (val >= local_pair.value) {
+            if (val > local_pair.value) {
                 local_pair = Pair(val, i32(col * VEC_SIZE + v));
             }
         }
     }
 #else
     for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
-        if (src[row_idx + col] >= local_pair.value) {
+        if (src[row_idx + col] > local_pair.value) {
             local_pair = Pair(src[row_idx + col], i32(col));
         }
     }
@@ -59,7 +59,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
             let b = shared_max[lid.x + offset];
             if (b.value > a.value) {
                 shared_max[lid.x] = b;
-            } else if (b.value == a.value && b.index > a.index) {
+            } else if (b.value == a.value && b.index < a.index) {
                 shared_max[lid.x] = b;
             }
         }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3f18dbe220c2..1a9fca4b8fbd 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2649,6 +2649,61 @@ struct test_argmax : public test_case {
     }
 };
 
+// GGML_OP_ARGMAX with ties: first-occurrence index must match across backends
+struct test_argmax_ties : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_argmax_ties(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 100, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_argmax(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_F32) {
+                // Ramp 0..ne0-1 with the max (ne0) at three positions; the
+                // first occurrence is at a non-zero index so the NMSE metric
+                // (mse(a,b)/mse(a,0)) is well-defined, not 0/0 under -ffast-math.
+                const int64_t ne0 = t->ne[0];
+                const float max_val = (float) ne0;
+                constexpr int64_t first_idx = 3;
+                GGML_ASSERT(ne0 > first_idx);
+                // every row has the same contents, so build the row once and upload it per row
+                std::vector<float> data(ne0);
+                for (int64_t i = 0; i < ne0; i++) {
+                    data[i] = (float) i;
+                }
+                data[first_idx] = max_val;
+                data[ne0 / 2] = max_val;
+                data[ne0 - 1] = max_val;
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], ne0 * sizeof(float));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
 // GGML_OP_COUNT_EQUAL
 struct test_count_equal : public test_case {
     const ggml_type type;
@@ -8094,6 +8149,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
 
+    // argmax with ties: sizes span per-thread and cross-thread reduction paths
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {32, 1, 1, 1}));
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {33, 1, 1, 1}));
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {100, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {2000, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {5438, 3, 1, 1}));
+
     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));