Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ggml/src/ggml-cpu/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -1556,11 +1556,14 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
}

inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
// first index wins on ties (torch.argmax / np.argmax convention)
float max = -INFINITY;
int idx = 0;
for (int i = 0; i < n; ++i) {
max = MAX(max, x[i]);
if (max == x[i]) { idx = i; }
if (x[i] > max) {
max = x[i];
idx = i;
}
}
*s = idx;
}
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
if (val > maxval) {
maxval = val;
argmax = col;
} else if (val == maxval && col < argmax) {
// keep the smallest index on ties (first-occurrence wins)
argmax = col;
}
}

Expand Down Expand Up @@ -56,6 +59,9 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
if (val > maxval) {
maxval = val;
argmax = col;
} else if (val == maxval && col < argmax) {
// keep the smallest index on ties (first-occurrence wins)
argmax = col;
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2914,8 +2914,10 @@ kernel void kernel_argmax_f32(
}

// find the argmax value in the block
// keep the smallest index on ties: non-max threads get a sentinel so simd_min ignores them
constexpr int32_t ARGMAX_SENTINEL = numeric_limits<int32_t>::max();
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
int32_t arg_val = simd_min(select(ARGMAX_SENTINEL, larg, lmax == max_val));

device int32_t * dst_i32 = (device int32_t *) dst;

Expand All @@ -2941,7 +2943,7 @@ kernel void kernel_argmax_f32(
arg_val = shared_argmax[tiisg];

float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
int32_t arg_val_reduced = simd_min(select(ARGMAX_SENTINEL, arg_val, max_val == max_val_reduced));

dst_i32[tgpig] = arg_val_reduced;

Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,9 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
} else if (val2 == val1 && shared_indices[tid + stride] < shared_indices[tid]) {
// keep the smallest index on ties (first-occurrence wins)
shared_indices[tid] = shared_indices[tid + stride];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
Expand Down
3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ void main() {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
} else if (tmpmax[col] == tmpmax[col + s] && tmp[col + s] < tmp[col]) {
// keep the smallest index on ties (first-occurrence wins)
tmp[col] = tmp[col + s];
}
}
barrier();
Expand Down
6 changes: 3 additions & 3 deletions ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let vec_val = src[row_idx / VEC_SIZE + col];
for (var v = 0u; v < VEC_SIZE; v++) {
let val = vec_val[v];
if (val >= local_pair.value) {
if (val > local_pair.value) {
local_pair = Pair(val, i32(col * VEC_SIZE + v));
}
}
}
#else
for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
if (src[row_idx + col] >= local_pair.value) {
if (src[row_idx + col] > local_pair.value) {
local_pair = Pair(src[row_idx + col], i32(col));
}
}
Expand All @@ -59,7 +59,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let b = shared_max[lid.x + offset];
if (b.value > a.value) {
shared_max[lid.x] = b;
} else if (b.value == a.value && b.index > a.index) {
} else if (b.value == a.value && b.index < a.index) {
shared_max[lid.x] = b;
}
}
Expand Down
61 changes: 61 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2649,6 +2649,59 @@ struct test_argmax : public test_case {
}
};

// GGML_OP_ARGMAX with ties: first-occurrence index must match across backends
struct test_argmax_ties : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_argmax_ties(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 100, 1, 1})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");

ggml_tensor * out = ggml_argmax(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_F32) {
// Ramp 0..ne0-1 with the max (ne0) at three positions; the
// first occurrence is at a non-zero index so the NMSE metric
// (mse(a,b)/mse(a,0)) is well-defined, not 0/0 under -ffast-math.
const int64_t ne0 = t->ne[0];
const float max_val = (float) ne0;
constexpr int64_t first_idx = 3;
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<float> data(ne0);
for (int i = 0; i < ne0; i++) {
data[i] = (float) i;
}
data[first_idx] = max_val;
data[ne0 / 2] = max_val;
data[ne0 - 1] = max_val;
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], ne0 * sizeof(float));
}
} else {
init_tensor_uniform(t);
}
}
}

double max_nmse_err() override {
return 0.0;
}
};

// GGML_OP_COUNT_EQUAL
struct test_count_equal : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -8094,6 +8147,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));

// argmax with ties: sizes span per-thread and cross-thread reduction paths
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {33, 1, 1, 1}));
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {100, 10, 1, 1}));
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {2000, 10, 1, 1}));
test_cases.emplace_back(new test_argmax_ties(GGML_TYPE_F32, {5438, 3, 1, 1}));

for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
Expand Down
Loading