Skip to content

Commit be986be

Browse files
Address review feedback: split validations with error messages
- tensor.cc: Split combined check into separate num_dims and channel_dim validations with YNN_LOG_ERROR messages. Replace asserts with proper error returns for channelwise_zero_point. Remove assert(data) per reviewer (XNNPACK limitation, not YNNPACK). - reduce.cc: Change define_reduce to return ynn_status. Add output rank validation after min_max dimension push. Keep rank >= 1 as assert (internal invariant). Propagate error via YNN_RETURN_IF_ERROR at call site. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ce2d95a commit be986be

2 files changed

Lines changed: 28 additions & 13 deletions

File tree

ynnpack/subgraph/reduce.cc

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <vector>
1717

1818
#include "ynnpack/base/base.h"
19+
#include "ynnpack/base/log.h"
1920
#include "ynnpack/base/type.h"
2021
#include "ynnpack/include/ynnpack.h"
2122
#include "ynnpack/subgraph/reduce.h"
@@ -240,9 +241,8 @@ uint32_t get_reduce_identity_value(ynn_subgraph& subgraph,
240241
value_f32[0] = std::numeric_limits<float>::infinity();
241242
value_f32[1] = -std::numeric_limits<float>::infinity();
242243
rank = output.rank();
243-
if (rank < 1 || rank > YNN_MAX_TENSOR_RANK) {
244-
return YNN_INVALID_VALUE_ID;
245-
}
244+
assert(rank >= 1);
245+
assert(rank < YNN_MAX_TENSOR_RANK);
246246
dims[rank - 1] = 2;
247247
break;
248248
default:
@@ -281,10 +281,10 @@ uint32_t get_reduce_identity_value(ynn_subgraph& subgraph,
281281

282282
} // namespace
283283

284-
void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
285-
ynn_reduce_operator op, const ynn::axes_set& k_dims,
286-
uint32_t input_a_id, uint32_t input_b_id,
287-
uint32_t* output_id, bool keep_dims) {
284+
ynn_status define_reduce(ynn_subgraph& subgraph, ynn_node& node,
285+
ynn_reduce_operator op, const ynn::axes_set& k_dims,
286+
uint32_t input_a_id, uint32_t input_b_id,
287+
uint32_t* output_id, bool keep_dims) {
288288
const ynn_value& a = subgraph.value(input_a_id);
289289

290290
if (*output_id == YNN_INVALID_VALUE_ID) {
@@ -348,6 +348,12 @@ void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
348348
output.extents.push_back(2);
349349
}
350350

351+
if (output.rank() >= YNN_MAX_TENSOR_RANK) {
352+
YNN_LOG_ERROR() << "output rank " << output.rank()
353+
<< " exceeds YNN_MAX_TENSOR_RANK " << YNN_MAX_TENSOR_RANK;
354+
return ynn_status_unsupported_parameter;
355+
}
356+
351357
if (input_b_id == YNN_INVALID_VALUE_ID) {
352358
input_b_id = get_reduce_identity_value(subgraph, output, op);
353359
} else {
@@ -432,6 +438,7 @@ void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
432438
runtime.funcs.push_back(std::move(func));
433439
return ynn_status_success;
434440
};
441+
return ynn_status_success;
435442
}
436443

437444
extern "C" {
@@ -476,8 +483,8 @@ ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,
476483

477484
// Make the node.
478485
ynn_node node;
479-
define_reduce(*subgraph, node, op, k_dims, input_a_id, input_b_id, output_id,
480-
keep_dims);
486+
YNN_RETURN_IF_ERROR(define_reduce(*subgraph, node, op, k_dims, input_a_id,
487+
input_b_id, output_id, keep_dims));
481488
subgraph->add_node(std::move(node));
482489

483490
if (convert_to_id != YNN_INVALID_VALUE_ID) {

ynnpack/xnnpack/tensor.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,18 @@ xnn_status xnn_define_channelwise_quantized_tensor_value_v3(
120120
const float* scale, size_t num_dims, size_t channel_dim, const size_t* dims,
121121
const void* data, uint32_t external_id, uint32_t flags, uint32_t* id_out,
122122
const float* channelwise_zero_point) {
123-
// Channelwise zero points are not supported yet.
124-
assert(channelwise_zero_point == nullptr);
125-
assert(data);
126-
if (channel_dim >= num_dims || num_dims > YNN_MAX_TENSOR_RANK) {
123+
if (channelwise_zero_point) {
124+
YNN_LOG_ERROR() << "channelwise zero points are not supported";
125+
return xnn_status_unsupported_parameter;
126+
}
127+
if (num_dims > YNN_MAX_TENSOR_RANK) {
128+
YNN_LOG_ERROR() << "num_dims " << num_dims << " exceeds YNN_MAX_TENSOR_RANK "
129+
<< YNN_MAX_TENSOR_RANK;
130+
return xnn_status_unsupported_parameter;
131+
}
132+
if (channel_dim >= num_dims) {
133+
YNN_LOG_ERROR() << "channel_dim " << channel_dim << " must be in [0, "
134+
<< num_dims << ")";
127135
return xnn_status_invalid_parameter;
128136
}
129137
uint32_t zero_point_id = YNN_INVALID_VALUE_ID;

0 commit comments

Comments
 (0)