Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tests/testDriver_quantized_matmul_work_area_size_apis.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// SPDX-License-Identifier: Apache-2.0
/*
* Copyright IBM Corp. 2024
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "testsupport.h"

/*
* Tests for quantized matmul work-area size helper API.
*/

void setUp(void) {}
void tearDown(void) {}

void test_quantized_matmul_work_area_size_matches_internal_qc_tilde_shape(void) {
// Use dim1=65 so size differs between int8 source tensor and DLF16 qc_tilde.
zdnn_tensor_desc input_c_desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_BINARY_INT8, ZDNN_FORMAT_4DFEATURE,
&input_c_desc, 1, 1, 1, 65);

zdnn_tensor_desc qc_tilde_desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE,
&qc_tilde_desc, 1, 1, 1, 65);

zdnn_ztensor input_c;
zdnn_init_ztensor(&input_c_desc, &input_c_desc, &input_c);

uint64_t expected = zdnn_getsize_ztensor(&qc_tilde_desc);
uint64_t actual = zdnn_get_quantized_matmul_work_area_size(&input_c, false);

TEST_ASSERT_EQUAL_UINT64(expected, actual);
}

void test_quantized_matmul_work_area_size_returns_zero_for_precomputed(void) {
zdnn_tensor_desc input_c_desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_BINARY_INT8, ZDNN_FORMAT_4DFEATURE,
&input_c_desc, 1, 1, 2, 17);

zdnn_ztensor input_c;
zdnn_init_ztensor(&input_c_desc, &input_c_desc, &input_c);

TEST_ASSERT_EQUAL_UINT64(
0, zdnn_get_quantized_matmul_work_area_size(&input_c, true));
}

void test_quantized_matmul_work_area_size_rejects_invalid_input(void) {
TEST_ASSERT_EQUAL_UINT64(0,
zdnn_get_quantized_matmul_work_area_size(NULL, false));

zdnn_ztensor missing_desc = {0};
TEST_ASSERT_EQUAL_UINT64(
0, zdnn_get_quantized_matmul_work_area_size(&missing_desc, false));
}

void test_quantized_matmul_work_area_size_rejects_non_int8_input_c(void) {
zdnn_tensor_desc invalid_desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE,
&invalid_desc, 1, 1, 1, 64);

zdnn_ztensor input_c;
zdnn_init_ztensor(&invalid_desc, &invalid_desc, &input_c);

TEST_ASSERT_EQUAL_UINT64(
0, zdnn_get_quantized_matmul_work_area_size(&input_c, false));
}

int main(void) {
UNITY_BEGIN();

RUN_TEST(test_quantized_matmul_work_area_size_matches_internal_qc_tilde_shape);
RUN_TEST(test_quantized_matmul_work_area_size_returns_zero_for_precomputed);
RUN_TEST(test_quantized_matmul_work_area_size_rejects_invalid_input);
RUN_TEST(test_quantized_matmul_work_area_size_rejects_non_int8_input_c);

return UNITY_END();
}
4 changes: 3 additions & 1 deletion tests/testDriver_zdnn_quantized_matmul_op.c
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ void test_zdnn_api_quantized_matmul(

// Set work_area during second pass
if (work_area_pass == 1) {
work_area = alloc_quantized_matmul_work_area(biases->buffer_size);
size_t work_area_size =
zdnn_get_quantized_matmul_work_area_size(biases, false);
work_area = alloc_quantized_matmul_work_area(work_area_size);
}

zdnn_status status;
Expand Down
55 changes: 55 additions & 0 deletions zdnn/quantized_matmul_work_area.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// SPDX-License-Identifier: Apache-2.0
/*
* Copyright IBM Corp. 2024
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdint.h>

#include "zdnn.h"
#include "zdnn_private.h"

#ifdef __MVS__
#pragma export(zdnn_get_quantized_matmul_work_area_size)
#endif

uint64_t zdnn_get_quantized_matmul_work_area_size(const zdnn_ztensor *input_c,
bool pre_computed) {
// pre_computed mode does not create qc_tilde and therefore does not require
// quantized-matmul work area scratch storage.
if (pre_computed) {
return 0;
}

if (!input_c || !input_c->transformed_desc) {
return 0;
}

// aiu_quantized_matmul() requires input_c transformed type to be int8 when
// pre_computed is false.
if (input_c->transformed_desc->type != ZDNN_BINARY_INT8) {
return 0;
}

// The internal temporary tensor (qc_tilde) reuses input_c's transformed
// dimensions/layout/format but promotes the element type to DLFLOAT16.
zdnn_tensor_desc qc_tilde_desc;
init_transformed_desc(
input_c->transformed_desc->layout, ZDNN_DLFLOAT16,
input_c->transformed_desc->format, &qc_tilde_desc,
input_c->transformed_desc->dim4, input_c->transformed_desc->dim3,
input_c->transformed_desc->dim2, input_c->transformed_desc->dim1);

return zdnn_getsize_ztensor(&qc_tilde_desc);
}
8 changes: 8 additions & 0 deletions zdnn/zdnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,14 @@ zdnn_status zdnn_quantized_matmul_op(
const int8_t clip_max, const bool disable_clipping, const bool dequantize,
const bool pre_computed, void *work_area, zdnn_ztensor *output);

// Convenience helper to compute the required 4K-aligned `work_area` size for
// zdnn_quantized_matmul_op().
//
// Returns 0 when no work area is required (`pre_computed == true`) or when
// inputs are invalid.
uint64_t zdnn_get_quantized_matmul_work_area_size(const zdnn_ztensor *input_c,
bool pre_computed);

// -----------------------------------------------------------------------------
// External Norm Operations
// -----------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions zdnn/zdnn.map
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ ZDNN_1.0 {
zdnn_matmul_bcast_op;
zdnn_matmul_transpose_op;
zdnn_quantized_matmul_op;
zdnn_get_quantized_matmul_work_area_size;
zdnn_batchnorm;
zdnn_norm;
zdnn_moments;
Expand Down