|
1 | 1 | // |
2 | | -// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. |
| 2 | +// Copyright © 2017-2021,2023 Arm Ltd and Contributors. All rights reserved. |
3 | 3 | // SPDX-License-Identifier: MIT |
4 | 4 | // |
5 | 5 |
|
@@ -767,4 +767,67 @@ void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools) |
767 | 767 | #endif |
768 | 768 | } |
769 | 769 | } |
| 770 | + |
| 771 | +size_t GetSize(const V1_0::Request& request, const V1_0::RequestArgument& requestArgument) |
| 772 | +{ |
| 773 | + return request.pools[requestArgument.location.poolIndex].size(); |
| 774 | +} |
| 775 | + |
| 776 | +#ifdef ARMNN_ANDROID_NN_V1_3 |
| 777 | +size_t GetSize(const V1_3::Request& request, const V1_0::RequestArgument& requestArgument) |
| 778 | +{ |
| 779 | + if (request.pools[requestArgument.location.poolIndex].getDiscriminator() == |
| 780 | + V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) |
| 781 | + { |
| 782 | + return request.pools[requestArgument.location.poolIndex].hidlMemory().size(); |
| 783 | + } |
| 784 | + else |
| 785 | + { |
| 786 | + return 0; |
| 787 | + } |
| 788 | +} |
| 789 | +#endif |
| 790 | + |
| 791 | +template <typename ErrorStatus, typename Request> |
| 792 | +ErrorStatus ValidateRequestArgument(const Request& request, |
| 793 | + const armnn::TensorInfo& tensorInfo, |
| 794 | + const V1_0::RequestArgument& requestArgument, |
| 795 | + std::string descString) |
| 796 | +{ |
| 797 | + if (requestArgument.location.poolIndex >= request.pools.size()) |
| 798 | + { |
| 799 | + std::string err = fmt::format("Invalid {} pool at index {} the pool index is greater than the number " |
| 800 | + "of available pools {}", |
| 801 | + descString, requestArgument.location.poolIndex, request.pools.size()); |
| 802 | + ALOGE(err.c_str()); |
| 803 | + return ErrorStatus::GENERAL_FAILURE; |
| 804 | + } |
| 805 | + const size_t size = GetSize(request, requestArgument); |
| 806 | + size_t totalLength = tensorInfo.GetNumBytes(); |
| 807 | + |
| 808 | + if (static_cast<size_t>(requestArgument.location.offset) + totalLength > size) |
| 809 | + { |
| 810 | + std::string err = fmt::format("Invalid {} pool at index {} the offset {} and length {} are greater " |
| 811 | + "than the pool size {}", descString, requestArgument.location.poolIndex, |
| 812 | + requestArgument.location.offset, totalLength, size); |
| 813 | + ALOGE(err.c_str()); |
| 814 | + return ErrorStatus::GENERAL_FAILURE; |
| 815 | + } |
| 816 | + return ErrorStatus::NONE; |
| 817 | +} |
| 818 | + |
| 819 | +template V1_0::ErrorStatus ValidateRequestArgument<V1_0::ErrorStatus, V1_0::Request>( |
| 820 | + const V1_0::Request& request, |
| 821 | + const armnn::TensorInfo& tensorInfo, |
| 822 | + const V1_0::RequestArgument& requestArgument, |
| 823 | + std::string descString); |
| 824 | + |
| 825 | +#ifdef ARMNN_ANDROID_NN_V1_3 |
| 826 | +template V1_3::ErrorStatus ValidateRequestArgument<V1_3::ErrorStatus, V1_3::Request>( |
| 827 | + const V1_3::Request& request, |
| 828 | + const armnn::TensorInfo& tensorInfo, |
| 829 | + const V1_0::RequestArgument& requestArgument, |
| 830 | + std::string descString); |
| 831 | +#endif |
| 832 | + |
770 | 833 | } // namespace armnn_driver |
0 commit comments