Add zdnn_get_quantized_matmul_work_area_size() helper

Adds a public helper that reports how many bytes of `work_area` a caller
should provide to zdnn_quantized_matmul_op() for a given `input_c`.

  * zdnn/quantized_matmul_work_area.c (new)
      Returns 0 when `pre_computed` is true, when `input_c` or its
      `transformed_desc` is NULL, or when the transformed type is not
      ZDNN_BINARY_INT8.  Otherwise returns zdnn_getsize_ztensor() of a
      descriptor that copies input_c's transformed layout, format and
      dim4..dim1 but uses ZDNN_DLFLOAT16 as the element type.
  * zdnn/zdnn.h, zdnn/zdnn.map
      Declare the helper and add it to the exported symbol list.
  * tests/testDriver_quantized_matmul_work_area_size_apis.c (new)
      Unity tests covering the four return paths listed above.
  * tests/testDriver_zdnn_quantized_matmul_op.c
      Size the second-pass work_area with the helper instead of
      biases->buffer_size.

Review notes:
  * The system #include in zdnn/quantized_matmul_work_area.c had lost its
    header name.  It is restored here as <stdbool.h> and <stdint.h> (the
    headers for `bool` and `uint64_t`), so that hunk is 56 lines long.  The
    `index` line of that file is kept as received.
  * Leading whitespace on context lines of the hunks against existing files
    was reconstructed.  If a hunk is rejected, retry with
    `git apply --ignore-whitespace` and confirm against the target tree.
  * The test driver passes a literal `false` for pre_computed.  Confirm this
    matches every case the driver runs: the helper returns 0 when input_c's
    transformed type is not ZDNN_BINARY_INT8.

diff --git a/tests/testDriver_quantized_matmul_work_area_size_apis.c b/tests/testDriver_quantized_matmul_work_area_size_apis.c
new file mode 100644
index 0000000..26a7485
--- /dev/null
+++ b/tests/testDriver_quantized_matmul_work_area_size_apis.c
@@ -0,0 +1,88 @@
+// SPDX-License-Identifier: Apache-2.0
+/*
+ * Copyright IBM Corp. 2024
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "testsupport.h"
+
+/*
+ * Tests for quantized matmul work-area size helper API.
+ */
+
+void setUp(void) {}
+void tearDown(void) {}
+
+void test_quantized_matmul_work_area_size_matches_internal_qc_tilde_shape(void) {
+  // Use dim1=65 so size differs between int8 source tensor and DLF16 qc_tilde.
+  zdnn_tensor_desc input_c_desc;
+  init_transformed_desc(ZDNN_NHWC, ZDNN_BINARY_INT8, ZDNN_FORMAT_4DFEATURE,
+                        &input_c_desc, 1, 1, 1, 65);
+
+  zdnn_tensor_desc qc_tilde_desc;
+  init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE,
+                        &qc_tilde_desc, 1, 1, 1, 65);
+
+  zdnn_ztensor input_c;
+  zdnn_init_ztensor(&input_c_desc, &input_c_desc, &input_c);
+
+  uint64_t expected = zdnn_getsize_ztensor(&qc_tilde_desc);
+  uint64_t actual = zdnn_get_quantized_matmul_work_area_size(&input_c, false);
+
+  TEST_ASSERT_EQUAL_UINT64(expected, actual);
+}
+
+void test_quantized_matmul_work_area_size_returns_zero_for_precomputed(void) {
+  zdnn_tensor_desc input_c_desc;
+  init_transformed_desc(ZDNN_NHWC, ZDNN_BINARY_INT8, ZDNN_FORMAT_4DFEATURE,
+                        &input_c_desc, 1, 1, 2, 17);
+
+  zdnn_ztensor input_c;
+  zdnn_init_ztensor(&input_c_desc, &input_c_desc, &input_c);
+
+  TEST_ASSERT_EQUAL_UINT64(
+      0, zdnn_get_quantized_matmul_work_area_size(&input_c, true));
+}
+
+void test_quantized_matmul_work_area_size_rejects_invalid_input(void) {
+  TEST_ASSERT_EQUAL_UINT64(0,
+                           zdnn_get_quantized_matmul_work_area_size(NULL, false));
+
+  zdnn_ztensor missing_desc = {0};
+  TEST_ASSERT_EQUAL_UINT64(
+      0, zdnn_get_quantized_matmul_work_area_size(&missing_desc, false));
+}
+
+void test_quantized_matmul_work_area_size_rejects_non_int8_input_c(void) {
+  zdnn_tensor_desc invalid_desc;
+  init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE,
+                        &invalid_desc, 1, 1, 1, 64);
+
+  zdnn_ztensor input_c;
+  zdnn_init_ztensor(&invalid_desc, &invalid_desc, &input_c);
+
+  TEST_ASSERT_EQUAL_UINT64(
+      0, zdnn_get_quantized_matmul_work_area_size(&input_c, false));
+}
+
+int main(void) {
+  UNITY_BEGIN();
+
+  RUN_TEST(test_quantized_matmul_work_area_size_matches_internal_qc_tilde_shape);
+  RUN_TEST(test_quantized_matmul_work_area_size_returns_zero_for_precomputed);
+  RUN_TEST(test_quantized_matmul_work_area_size_rejects_invalid_input);
+  RUN_TEST(test_quantized_matmul_work_area_size_rejects_non_int8_input_c);
+
+  return UNITY_END();
+}
diff --git a/tests/testDriver_zdnn_quantized_matmul_op.c b/tests/testDriver_zdnn_quantized_matmul_op.c
index d8d54f7..cefd8bf 100644
--- a/tests/testDriver_zdnn_quantized_matmul_op.c
+++ b/tests/testDriver_zdnn_quantized_matmul_op.c
@@ -490,7 +490,9 @@ void test_zdnn_api_quantized_matmul(
 
     // Set work_area during second pass
     if (work_area_pass == 1) {
-      work_area = alloc_quantized_matmul_work_area(biases->buffer_size);
+      size_t work_area_size =
+          zdnn_get_quantized_matmul_work_area_size(biases, false);
+      work_area = alloc_quantized_matmul_work_area(work_area_size);
     }
 
     zdnn_status status;
diff --git a/zdnn/quantized_matmul_work_area.c b/zdnn/quantized_matmul_work_area.c
new file mode 100644
index 0000000..ca24cfa
--- /dev/null
+++ b/zdnn/quantized_matmul_work_area.c
@@ -0,0 +1,56 @@
+// SPDX-License-Identifier: Apache-2.0
+/*
+ * Copyright IBM Corp. 2024
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "zdnn.h"
+#include "zdnn_private.h"
+
+#ifdef __MVS__
+#pragma export(zdnn_get_quantized_matmul_work_area_size)
+#endif
+
+uint64_t zdnn_get_quantized_matmul_work_area_size(const zdnn_ztensor *input_c,
+                                                  bool pre_computed) {
+  // pre_computed mode does not create qc_tilde and therefore does not require
+  // quantized-matmul work area scratch storage.
+  if (pre_computed) {
+    return 0;
+  }
+
+  if (!input_c || !input_c->transformed_desc) {
+    return 0;
+  }
+
+  // aiu_quantized_matmul() requires input_c transformed type to be int8 when
+  // pre_computed is false.
+  if (input_c->transformed_desc->type != ZDNN_BINARY_INT8) {
+    return 0;
+  }
+
+  // The internal temporary tensor (qc_tilde) reuses input_c's transformed
+  // dimensions/layout/format but promotes the element type to DLFLOAT16.
+  zdnn_tensor_desc qc_tilde_desc;
+  init_transformed_desc(
+      input_c->transformed_desc->layout, ZDNN_DLFLOAT16,
+      input_c->transformed_desc->format, &qc_tilde_desc,
+      input_c->transformed_desc->dim4, input_c->transformed_desc->dim3,
+      input_c->transformed_desc->dim2, input_c->transformed_desc->dim1);
+
+  return zdnn_getsize_ztensor(&qc_tilde_desc);
+}
diff --git a/zdnn/zdnn.h b/zdnn/zdnn.h
index 0c9cacc..7d43c8a 100644
--- a/zdnn/zdnn.h
+++ b/zdnn/zdnn.h
@@ -564,6 +564,14 @@ zdnn_status zdnn_quantized_matmul_op(
     const int8_t clip_max, const bool disable_clipping, const bool dequantize,
     const bool pre_computed, void *work_area, zdnn_ztensor *output);
 
+// Convenience helper to compute the required 4K-aligned `work_area` size for
+// zdnn_quantized_matmul_op().
+//
+// Returns 0 when no work area is required (`pre_computed == true`) or when
+// inputs are invalid.
+uint64_t zdnn_get_quantized_matmul_work_area_size(const zdnn_ztensor *input_c,
+                                                  bool pre_computed);
+
 // -----------------------------------------------------------------------------
 // External Norm Operations
 // -----------------------------------------------------------------------------
diff --git a/zdnn/zdnn.map b/zdnn/zdnn.map
index d80d3a1..5936ed8 100644
--- a/zdnn/zdnn.map
+++ b/zdnn/zdnn.map
@@ -132,6 +132,7 @@ ZDNN_1.0 {
     zdnn_matmul_bcast_op;
     zdnn_matmul_transpose_op;
     zdnn_quantized_matmul_op;
+    zdnn_get_quantized_matmul_work_area_size;
     zdnn_batchnorm;
     zdnn_norm;
     zdnn_moments;