Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions ynnpack/subgraph/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <vector>

#include "ynnpack/base/base.h"
#include "ynnpack/base/log.h"
#include "ynnpack/base/type.h"
#include "ynnpack/include/ynnpack.h"
#include "ynnpack/subgraph/reduce.h"
Expand Down Expand Up @@ -240,6 +241,8 @@ uint32_t get_reduce_identity_value(ynn_subgraph& subgraph,
value_f32[0] = std::numeric_limits<float>::infinity();
value_f32[1] = -std::numeric_limits<float>::infinity();
rank = output.rank();
assert(rank >= 1);
assert(rank < YNN_MAX_TENSOR_RANK);
dims[rank - 1] = 2;
break;
default:
Expand Down Expand Up @@ -278,10 +281,10 @@ uint32_t get_reduce_identity_value(ynn_subgraph& subgraph,

} // namespace

void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
ynn_reduce_operator op, const ynn::axes_set& k_dims,
uint32_t input_a_id, uint32_t input_b_id,
uint32_t* output_id, bool keep_dims) {
ynn_status define_reduce(ynn_subgraph& subgraph, ynn_node& node,
ynn_reduce_operator op, const ynn::axes_set& k_dims,
uint32_t input_a_id, uint32_t input_b_id,
uint32_t* output_id, bool keep_dims) {
const ynn_value& a = subgraph.value(input_a_id);

if (*output_id == YNN_INVALID_VALUE_ID) {
Expand Down Expand Up @@ -345,6 +348,12 @@ void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
output.extents.push_back(2);
}

if (output.rank() >= YNN_MAX_TENSOR_RANK) {
YNN_LOG_ERROR() << "output rank " << output.rank()
<< " exceeds YNN_MAX_TENSOR_RANK " << YNN_MAX_TENSOR_RANK;
return ynn_status_unsupported_parameter;
}

if (input_b_id == YNN_INVALID_VALUE_ID) {
input_b_id = get_reduce_identity_value(subgraph, output, op);
} else {
Expand Down Expand Up @@ -429,6 +438,7 @@ void define_reduce(ynn_subgraph& subgraph, ynn_node& node,
runtime.funcs.push_back(std::move(func));
return ynn_status_success;
};
return ynn_status_success;
}

extern "C" {
Expand Down Expand Up @@ -473,8 +483,8 @@ ynn_status ynn_define_reduce(ynn_subgraph_t subgraph,

// Make the node.
ynn_node node;
define_reduce(*subgraph, node, op, k_dims, input_a_id, input_b_id, output_id,
keep_dims);
YNN_RETURN_IF_ERROR(define_reduce(*subgraph, node, op, k_dims, input_a_id,
input_b_id, output_id, keep_dims));
subgraph->add_node(std::move(node));

if (convert_to_id != YNN_INVALID_VALUE_ID) {
Expand Down
17 changes: 14 additions & 3 deletions ynnpack/xnnpack/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,20 @@ xnn_status xnn_define_channelwise_quantized_tensor_value_v3(
const float* scale, size_t num_dims, size_t channel_dim, const size_t* dims,
const void* data, uint32_t external_id, uint32_t flags, uint32_t* id_out,
const float* channelwise_zero_point) {
// Channelwise zero points are not supported yet.
assert(channelwise_zero_point == nullptr);
assert(data);
if (channelwise_zero_point) {
YNN_LOG_ERROR() << "channelwise zero points are not supported";
return xnn_status_unsupported_parameter;
}
if (num_dims > YNN_MAX_TENSOR_RANK) {
YNN_LOG_ERROR() << "num_dims " << num_dims << " exceeds YNN_MAX_TENSOR_RANK "
<< YNN_MAX_TENSOR_RANK;
return xnn_status_unsupported_parameter;
}
if (channel_dim >= num_dims) {
YNN_LOG_ERROR() << "channel_dim " << channel_dim << " must be in [0, "
<< num_dims << ")";
return xnn_status_invalid_parameter;
}
uint32_t zero_point_id = YNN_INVALID_VALUE_ID;
if (zero_point != 0) {
ynn_status status = ynn_define_tensor(
Expand Down
Loading