diff --git a/BUILD.bazel b/BUILD.bazel index a16077088d8..84e9d9139ca 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -904,7 +904,7 @@ xnnpack_cc_library( ":xnnpack_h", "//src/configs:config_hdrs", "//src/configs:hardware_config", - "//src/configs:microkernel_configs", + "//src/subgraph/rewrites:fp16_to_fp32", "@pthreadpool", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 47c6ff59f5a..99f15a232ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -557,6 +557,7 @@ SET(SUBGRAPH_SRCS src/subgraph/max-pooling-2d.c src/subgraph/pack-lh.c src/subgraph/reshape-helpers.c + src/subgraph/rewrites/fp16_to_fp32.cc src/subgraph/rope.c src/subgraph/softmax.c src/subgraph/space-to-depth-2d.c @@ -982,7 +983,7 @@ IF(XNNPACK_BUILD_LIBRARY) TARGET_LINK_LIBRARIES(xnnpack-operator-run PRIVATE xnnpack-base xnnpack-logging) TARGET_LINK_LIBRARIES(xnnpack-operator-utils PRIVATE xnnpack-base xnnpack-logging) TARGET_LINK_LIBRARIES(xnnpack-reference-ukernels PRIVATE xnnpack-base xnnpack-datatype) - TARGET_LINK_LIBRARIES(xnnpack-subgraph PRIVATE xnnpack-base xnnpack-allocator xnnpack-logging xnnpack-memory xnnpack-mutex xnnpack-operators xnnpack-operator-run xnnpack-datatype) + TARGET_LINK_LIBRARIES(xnnpack-subgraph PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache xnnpack-logging xnnpack-memory xnnpack-mutex xnnpack-operators xnnpack-operator-run xnnpack-datatype) TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run diff --git a/src/runtime.c b/src/runtime.c index 0f137062ada..67db629e403 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -10,6 +10,7 @@ #include // For snprintf. 
#include #include +#include "src/subgraph/rewrites/fp16_to_fp32.h" #if defined(__EMSCRIPTEN__) #include @@ -585,6 +586,9 @@ enum xnn_status xnn_create_runtime_v4( goto error; } + XNN_IF_ERROR_GOTO(error, xnn_subgraph_alias_fp16_fp32_fallback_data( + subgraph, weights_cache)); + status = xnn_status_out_of_memory; runtime = xnn_allocate_zero_memory(sizeof(struct xnn_runtime)); @@ -1161,7 +1165,7 @@ enum xnn_status xnn_delete_runtime( xnn_release_memory(runtime->opdata); if (runtime->values != NULL) { - // Release the buffers created during FP16 rewrite. + // Release buffers created during rewrites. for (size_t i = 0; i < runtime->num_values; i++) { struct xnn_runtime_value* value = &runtime->values[i]; if (value->allocation_type == xnn_allocation_type_dynamic || diff --git a/src/subgraph.c b/src/subgraph.c index 5b2b35577a1..eaf79306b6e 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -15,6 +15,7 @@ #include "include/experimental.h" #include "include/xnnpack.h" +#include "src/subgraph/rewrites/fp16_to_fp32.h" #include "src/xnnpack/allocation-type.h" #include "src/xnnpack/allocator.h" #include "src/xnnpack/common.h" @@ -4336,6 +4337,10 @@ enum xnn_status xnn_subgraph_optimize(xnn_subgraph_t subgraph, XNN_RETURN_IF_ERROR( xnn_subgraph_optimize_packed_lhs(subgraph, optimization_flags)); + if (!xnn_is_f16_supported_natively(hardware_config)) { + XNN_RETURN_IF_ERROR(xnn_subgraph_fallback_from_fp16_to_fp32(subgraph, optimization_flags)); + } + return xnn_status_success; } diff --git a/src/subgraph/rewrites/BUILD b/src/subgraph/rewrites/BUILD new file mode 100644 index 00000000000..76537ecf0ea --- /dev/null +++ b/src/subgraph/rewrites/BUILD @@ -0,0 +1,21 @@ +load( + "//:build_defs.bzl", + "xnnpack_cxx_library", +) + +package(default_visibility = ["//:__subpackages__"]) + +xnnpack_cxx_library( + name = "fp16_to_fp32", + srcs = ["fp16_to_fp32.cc"], + hdrs = ["fp16_to_fp32.h"], + deps = [ + "//:allocator", + "//:cache", + "//:internal", + "//:logging", + "//:node_type", + 
"//:subgraph_h", + "//:xnnpack_h", + ], +) diff --git a/src/subgraph/rewrites/fp16_to_fp32.cc b/src/subgraph/rewrites/fp16_to_fp32.cc new file mode 100644 index 00000000000..8dbeddaddc1 --- /dev/null +++ b/src/subgraph/rewrites/fp16_to_fp32.cc @@ -0,0 +1,323 @@ +#include "src/subgraph/rewrites/fp16_to_fp32.h" + +#include +#include +#include +#include +#include + +#include "include/xnnpack.h" +#include "src/xnnpack/allocator.h" +#include "src/xnnpack/cache.h" +#include "src/xnnpack/internal.h" +#include "src/xnnpack/log.h" +#include "src/xnnpack/node-type.h" +#include "src/xnnpack/subgraph.h" + +namespace xnnpack { + +namespace { + +bool ReplaceInSet(uint32_t* set, uint32_t size, uint32_t old_value, + uint32_t new_value) { + bool replaced = false; + for (uint32_t i = 0; i < size; i++) { + if (set[i] == old_value) { + set[i] = new_value; + replaced |= true; + } + } + return replaced; +} + +enum class OpAction { + kNone, // Don't do anything, skip the node. + kRewrite, // Force outputs to fp32, insert converts from fp16 input + // to fp32 + kNeedsFP16Inputs, // The inputs must be converted back to fp16 if they were + // rewritten. + kTransparent, // If the inputs have been rewritten, the outputs must be also. + kElide, // This op should be removed (eg. convert(fp32, fp32)). +}; + +// Is this op supported when fp16 hardware is missing (allow-list). +// TODO: b/487077315 - Add allow list for supported ops. +OpAction GetOpAction(const xnn_subgraph_t subgraph, const xnn_node& node) { + switch (node.type) { + case xnn_node_type_unary_elementwise: { + switch (node.unary_operator) { + case xnn_unary_convert: { + const xnn_value& output = subgraph->values[node.outputs[0]]; + const xnn_value& input = subgraph->values[node.inputs[0]]; + // Elide converts from T to T. These are no-ops that may be introduced + // by the rewrite.
+ if (output.datatype == input.datatype) { + return OpAction::kElide; + } + if (output.datatype == xnn_datatype_fp16 || + input.datatype == xnn_datatype_fp16) { + return OpAction::kNone; + } + } break; + default: + break; + } + } break; + case xnn_node_type_static_reshape: + return OpAction::kTransparent; + default: + break; + } + return OpAction::kRewrite; +} + +// Checks if an op has an fp16 input or output and whether we currently +// support that. +bool HasFp16Values(const xnn_subgraph_t subgraph, const xnn_node& node) { + auto IsFp16Value = [subgraph](uint32_t id) { + return subgraph->values[id].datatype == xnn_datatype_fp16 || + subgraph->values[id].fp16_to_fp32_fallback.was_overwritten; + }; + return std::any_of(node.inputs, node.inputs + node.num_inputs, IsFp16Value) || + std::any_of(node.outputs, node.outputs + node.num_outputs, + IsFp16Value); +} + +void RemoveFlag(uint32_t& bitfield, uint32_t flag) { + const uint32_t mask = 0xFFFFFFFF ^ flag; + bitfield &= mask; +} + +} // namespace + +} // namespace xnnpack + +enum xnn_status xnn_subgraph_fallback_from_fp16_to_fp32( + xnn_subgraph_t subgraph, int optimization_flags) { + // Maps fp16 value ids to the corresponding fp32 value id if a conversion + // has been inserted. + std::vector fp16_id_to_fp32_id(subgraph->num_values, + XNN_INVALID_VALUE_ID); + // Maps fp32 value ids to the corresponding fp16 value id if a conversion + // has been inserted. + std::vector fp32_id_to_fp16_id(subgraph->num_values, + XNN_INVALID_VALUE_ID); + + xnn_log_debug("Running fp16 analysis and falling back to fp32."); + + // Go through the graph. Count nodes that will need to be converted. + const uint32_t original_num_nodes = subgraph->num_nodes; + for (uint32_t n = 0; n < original_num_nodes; ++n) { + // Editing the subgraph may reallocate nodes, we need to access the + // current node through the array each time. 
+ auto CurrentNode = [=]() -> xnn_node& { return subgraph->nodes[n]; }; + if (CurrentNode().type == xnn_node_type_invalid) { + continue; + } + + if (!xnnpack::HasFp16Values(subgraph, CurrentNode())) { + xnn_log_debug("node %d doesn't have fp16 values", n); + continue; + } + + const xnnpack::OpAction op_action = + xnnpack::GetOpAction(subgraph, CurrentNode()); + if (op_action == xnnpack::OpAction::kNone) { + continue; + } + + if (op_action == xnnpack::OpAction::kNeedsFP16Inputs) { + // Check for overwritten inputs that need to be converted back to fp16. + for (uint32_t i = 0; i < CurrentNode().num_inputs; i++) { + // The value is copied because adding new values may invalidate + // references. + const xnn_value value = subgraph->values[CurrentNode().inputs[i]]; + if (value.datatype == xnn_datatype_fp32 && + value.fp16_to_fp32_fallback.was_overwritten) { + if (fp32_id_to_fp16_id[value.id] == XNN_INVALID_VALUE_ID) { + XNN_RETURN_IF_ERROR(xnn_subgraph_add_internal_values(subgraph, 1)); + xnn_value& fp16_value = subgraph->values[subgraph->num_values - 1]; + xnn_value_copy(&fp16_value, &value); + fp16_value.datatype = xnn_datatype_fp16; + fp16_value.size = xnn_tensor_get_size(&fp16_value); + xnn_log_debug("Adding a convert[fp32, fp16](%d, %d) node.", + value.id, fp16_value.id); + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + value.id, fp16_value.id, + /*flags=*/0); + fp32_id_to_fp16_id[value.id] = fp16_value.id; + } else { + xnn_log_debug("Reusing convert[fp32, fp16](%d, %d) node.", value.id, + fp32_id_to_fp16_id[value.id]); + } + CurrentNode().inputs[i] = fp32_id_to_fp16_id[value.id]; + } + } + continue; + } + + if (op_action == xnnpack::OpAction::kElide) { + if (CurrentNode().num_inputs != CurrentNode().num_outputs) { + xnn_log_error("Node %" PRIu32 + " should be elided but it doesn't have the same number " + "of inputs and outputs.", + n); + } else { + xnn_node& node = CurrentNode(); + bool cancel_eliding = false; + for (uint32_t i = 0; i < 
node.num_inputs; ++i) { + xnn_value& input = subgraph->values[node.inputs[i]]; + if (xnn_value_is_external_output(input.flags)) { + cancel_eliding = true; + break; + } + } + + if (cancel_eliding) { + xnn_log_debug( + "Node %" PRIu32 + " should be elided but one of its inputs is also a graph output.", + n); + } else { + xnn_log_debug("Eliding node %" PRIu32, n); + for (uint32_t i = 0; i < node.num_inputs; ++i) { + xnn_value& output = subgraph->values[node.outputs[i]]; + xnn_value& input = subgraph->values[node.inputs[i]]; + xnn_node& producer = subgraph->nodes[input.producer]; + // Overwrite the input producer to write to this node's output. + xnnpack::ReplaceInSet(producer.outputs, producer.num_outputs, + input.id, output.id); + // Update the input's consumers' input set. + uint32_t k = std::min(n, input.first_consumer); + for (int j = 0; j < input.num_consumers && k < subgraph->num_nodes; + ++k) { + xnn_node& node_k = subgraph->nodes[k]; + j += xnnpack::ReplaceInSet(node_k.inputs, node_k.num_inputs, + input.id, output.id); + } + output.producer = input.producer; + output.num_consumers += input.num_consumers - 1; + output.first_consumer = + std::min(input.first_consumer, output.first_consumer); + node.type = xnn_node_type_invalid; + } + continue; + } + } + } + + if (op_action == xnnpack::OpAction::kTransparent) { + // If an input has been rewritten from fp16 to fp32, the outputs should + // also be rewritten. + bool needs_output_rewrite = false; + for (uint32_t i = 0; i < CurrentNode().num_inputs; i++) { + xnn_value& value = subgraph->values[CurrentNode().inputs[i]]; + if (value.datatype == xnn_datatype_fp32 && + value.fp16_to_fp32_fallback.was_overwritten) { + needs_output_rewrite = true; + break; + } + } + if (!needs_output_rewrite) { + continue; + } + xnn_log_debug("Node %" PRIu32 + " is transparent and its inputs have been rewritten.", + n); + } + + // Force outputs to be fp32.
+ for (uint32_t i = 0; i < CurrentNode().num_outputs; i++) { + xnn_value& value = subgraph->values[CurrentNode().outputs[i]]; + if (value.datatype != xnn_datatype_fp16) { + continue; + } + + if (CurrentNode().outputs[i] < subgraph->external_value_ids) { + // External values can't be overwritten, so we insert a value to get + // the fp32 output and a conversion to the original external tensor. + XNN_RETURN_IF_ERROR(xnn_subgraph_add_internal_values(subgraph, 1)); + xnn_value& fp32_value = subgraph->values[subgraph->num_values - 1]; + xnn_value_copy(&fp32_value, &value); + fp32_value.datatype = xnn_datatype_fp32; + fp32_value.size = xnn_tensor_get_size(&fp32_value); + xnnpack::RemoveFlag( + fp32_value.flags, + XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT); + CurrentNode().outputs[i] = fp32_value.id; + xnn_log_debug("Adding a convert[fp32, fp16](%d, %d) node.", + fp32_value.id, value.id); + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + fp32_value.id, value.id, + /*flags=*/0); + } else { + xnn_log_debug("Overriding value %d from fp16 to fp32.", value.id); + value.datatype = xnn_datatype_fp32; + value.size = xnn_tensor_get_size(&value); + value.fp16_to_fp32_fallback.was_overwritten = true; + } + } + + // Insert conversions to fp32 for fp16 inputs. + for (uint32_t i = 0; i < CurrentNode().num_inputs; i++) { + // The value is copied because adding new values may invalidate + // references. 
+ const xnn_value value = subgraph->values[CurrentNode().inputs[i]]; + if (value.datatype == xnn_datatype_fp16) { + if (fp16_id_to_fp32_id[value.id] == XNN_INVALID_VALUE_ID) { + XNN_RETURN_IF_ERROR(xnn_subgraph_add_internal_values(subgraph, 1)); + xnn_value& fp32_value = subgraph->values[subgraph->num_values - 1]; + xnn_value_copy(&fp32_value, &value); + fp32_value.datatype = xnn_datatype_fp32; + fp32_value.size = xnn_tensor_get_size(&fp32_value); + xnnpack::RemoveFlag( + fp32_value.flags, + XNN_VALUE_FLAG_EXTERNAL_INPUT | XNN_VALUE_FLAG_EXTERNAL_OUTPUT); + if (xnn_value_is_static(value.allocation_type)) { + xnn_log_debug("Converting static value %d to new fp32 value %d.", + value.id, fp32_value.id); + // We convert static values directly to the new value without + // inserting a convert node. + fp32_value.data = xnn_allocate_zero_memory( + xnn_tensor_get_size(&value) * 2 + XNN_EXTRA_BYTES); + fp32_value.flags |= XNN_VALUE_FLAG_NEEDS_CLEANUP; + fp32_value.fp16_to_fp32_fallback.original_data = value.data; + xnn_run_unary_elementwise_nc( + xnn_unary_convert, xnn_datatype_fp16, xnn_datatype_fp32, + /*params=*/nullptr, /*input_quantization=*/nullptr, + /*output_quantization=*/nullptr, /*flags=*/0, + /*batch_size=*/xnn_shape_multiply_all_dims(&value.shape), + /*channels=*/1, + /*input_stride=*/1, /*output_stride=*/1, /*threadpool=*/nullptr, + /*input=*/value.data, /*output=*/fp32_value.data); + } else { + xnn_log_debug("Adding a convert[fp16, fp32](%d, %d) node.", + value.id, fp32_value.id); + xnn_define_unary(subgraph, xnn_unary_convert, /*params=*/nullptr, + value.id, fp32_value.id, + /*flags=*/0); + } + fp16_id_to_fp32_id[value.id] = fp32_value.id; + } + CurrentNode().inputs[i] = fp16_id_to_fp32_id[value.id]; + } + } + } + + xnn_subgraph_clean_up(subgraph); + return xnn_status_success; +} + +enum xnn_status xnn_subgraph_alias_fp16_fp32_fallback_data( + xnn_subgraph_t subgraph, xnn_weights_cache_t cache) { + if (cache) { + for (uint32_t i = 0; i < 
subgraph->num_values; ++i) { + const xnn_value& value = subgraph->values[i]; + if (value.fp16_to_fp32_fallback.original_data) { + XNN_RETURN_IF_ERROR(xnn_weights_cache_alias_data( + cache, value.fp16_to_fp32_fallback.original_data, value.data)); + } + } + } + return xnn_status_success; +} diff --git a/src/subgraph/rewrites/fp16_to_fp32.h b/src/subgraph/rewrites/fp16_to_fp32.h new file mode 100644 index 00000000000..e460fe95177 --- /dev/null +++ b/src/subgraph/rewrites/fp16_to_fp32.h @@ -0,0 +1,30 @@ +#ifndef XNNPACK_SRC_SUBGRAPH_REWRITES_FP16_TO_FP32_H_ +#define XNNPACK_SRC_SUBGRAPH_REWRITES_FP16_TO_FP32_H_ + +#include + +#include "include/xnnpack.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Rewrites unsupported fp16 operations to fp32. +// +// Inserts fp16-fp32 converts for inputs and fp32-fp16 converts for outputs. +// This tries to minimise the number of converts by avoiding chains of fp32 -> +// fp16 -> fp32 converts. +enum xnn_status xnn_subgraph_fallback_from_fp16_to_fp32(xnn_subgraph_t subgraph, + int optimization_flags); + +// Updates the weight cache with data aliases for static values that were +// converted during a previous call to +// `xnn_subgraph_fallback_from_fp16_to_fp32`. +enum xnn_status xnn_subgraph_alias_fp16_fp32_fallback_data( + xnn_subgraph_t subgraph, xnn_weights_cache_t cache); + +#ifdef __cplusplus +} +#endif + +#endif // XNNPACK_SRC_SUBGRAPH_REWRITES_FP16_TO_FP32_H_ diff --git a/src/xnnpack/subgraph.h b/src/xnnpack/subgraph.h index d35562f2131..77777a45293 100644 --- a/src/xnnpack/subgraph.h +++ b/src/xnnpack/subgraph.h @@ -177,6 +177,16 @@ struct xnn_value { void* fp16_temp_data; } fp16_rewrite; + struct fp16_fp32_fallback { + // This marks nodes that have been forcefully rewritten from fp16 to fp32 + // with inserting a convert. + bool was_overwritten; + // For static values, this points to the original static value's data. 
This + // allows the runtime to inform the weight cache about the relationship between + // the original tensors and the converted tensors. + void* original_data; + } fp16_to_fp32_fallback; + // Pointer to a `xnn_gemm_config` if this value is packed for a specific GEMM. const struct xnn_gemm_config* gemm_config; // Pointer to original fp32 data if this value was converted from fp32 to fp16 @@ -643,6 +653,8 @@ void xnn_subgraph_rewrite_for_nchw(xnn_subgraph_t subgraph); // Rewrites subgraph for FP16, returns true if success, false if rewrite failed. bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph); +void xnn_subgraph_clean_up(xnn_subgraph_t subgraph); + void xnn_node_clear(struct xnn_node* node); void xnn_node_copy(struct xnn_node* dst_node, const struct xnn_node* src_node); void xnn_value_clear(struct xnn_value* value);