From 05d8cd8a2da701a22eb766bcaaf6a4a958797bb3 Mon Sep 17 00:00:00 2001
From: Kaitlyn Davis
Date: Thu, 12 Feb 2026 08:22:20 -0800
Subject: [PATCH] transform: add typed wrappers for FICO and ZRH gate inputs

What: Introduce typed wrapper APIs for commonly-used transform inputs
(FICO + ZRH gates).
Why: Typed wrappers reduce integration errors and simplify safe language
bindings compared to varargs/loose call sites.
Expected impact: Cleaner, less error-prone caller code; additive API with
no behavior change to existing entry points.
Tests: add dedicated coverage for new wrappers
(tests/testDriver_transform_typed_wrappers.c).

Signed-off-by: Kaitlyn Davis
Signed-off-by: Kaitlyn Davis
---
 tests/testDriver_transform_typed_wrappers.c | 168 ++++++++++++++++++++
 zdnn/stickify.c                             |  58 +++++++
 zdnn/zdnn.h                                 |  11 ++
 zdnn/zdnn.map                               |   2 +
 4 files changed, 239 insertions(+)
 create mode 100644 tests/testDriver_transform_typed_wrappers.c

diff --git a/tests/testDriver_transform_typed_wrappers.c b/tests/testDriver_transform_typed_wrappers.c
new file mode 100644
index 0000000..759d445
--- /dev/null
+++ b/tests/testDriver_transform_typed_wrappers.c
@@ -0,0 +1,168 @@
+// SPDX-License-Identifier: Apache-2.0
+/*
+ * Copyright IBM Corp. 2024
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "testsupport.h"
+
+#include <string.h>
+
+/*
+ * Tests for typed wrappers over the varargs zdnn_transform_ztensor() API.
+ *
+ * Goals:
+ * - The wrappers must reject incorrect transformed layouts (avoid silent
+ *   varargs misuse).
+ * - For valid concatenated tensors, wrapper output must match the equivalent
+ *   varargs call byte-for-byte.
+ */
+
+void setUp(void) {}
+void tearDown(void) {}
+
+static void fill_fp16(uint16_t *buf, size_t n, uint16_t v) {
+  for (size_t i = 0; i < n; i++) {
+    buf[i] = v;
+  }
+}
+
+void test_fico_wrapper_rejects_non_fico_layout(void) {
+  zdnn_tensor_desc pre = {0}, tfrmd = {0};
+  zdnn_ztensor z = {0};
+
+  zdnn_init_pre_transformed_desc(ZDNN_1D, FP16, &pre, 4);
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_generate_transformed_desc(&pre, &tfrmd));
+
+  zdnn_init_ztensor(&pre, &tfrmd, &z);
+
+  // Should fail fast before dereferencing any gate pointers.
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_INVALID_LAYOUT,
+                           zdnn_transform_ztensor_fico(&z, NULL, NULL, NULL, NULL));
+}
+
+void test_zrh_wrapper_rejects_non_zrh_layout(void) {
+  zdnn_tensor_desc pre = {0}, tfrmd = {0};
+  zdnn_ztensor z = {0};
+
+  zdnn_init_pre_transformed_desc(ZDNN_1D, FP16, &pre, 4);
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_generate_transformed_desc(&pre, &tfrmd));
+
+  zdnn_init_ztensor(&pre, &tfrmd, &z);
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_INVALID_LAYOUT,
+                           zdnn_transform_ztensor_zrh(&z, NULL, NULL, NULL));
+}
+
+void test_fico_wrapper_matches_varargs(void) {
+  const uint32_t num_dirs = 1;
+  const uint32_t num_hidden = 7;
+  const size_t gate_el = (size_t)num_dirs * num_hidden;
+
+  zdnn_tensor_desc pre = {0}, tfrmd = {0};
+
+  // Concatenated LSTM biases.
+  zdnn_concat_info info = RNN_TYPE_LSTM | USAGE_BIASES | PREV_LAYER_NONE;
+
+  // For biases: pre-transformed shape per gate is (num_dirs, num_hidden) using
+  // ZDNN_2DS.
+  zdnn_init_pre_transformed_desc(ZDNN_2DS, FP16, &pre, num_dirs, num_hidden);
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_generate_transformed_desc_concatenated(&pre, info,
+                                                                       &tfrmd));
+
+  zdnn_ztensor z_varargs = {0}, z_wrapper = {0};
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_varargs));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_wrapper));
+
+  // Distinct FP16 values per gate to catch ordering mistakes.
+  // (IEEE-754 half): 1.0=0x3C00, 2.0=0x4000, 3.0=0x4200, 4.0=0x4400.
+  uint16_t f[gate_el], i[gate_el], c[gate_el], o[gate_el];
+  fill_fp16(f, gate_el, 0x3C00);
+  fill_fp16(i, gate_el, 0x4000);
+  fill_fp16(c, gate_el, 0x4200);
+  fill_fp16(o, gate_el, 0x4400);
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_transform_ztensor(&z_varargs, f, i, c, o));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_transform_ztensor_fico(&z_wrapper, f, i, c, o));
+
+  TEST_ASSERT_EQUAL_UINT64_MESSAGE(z_varargs.buffer_size, z_wrapper.buffer_size,
+                                   "Buffer sizes unexpectedly differ");
+
+  TEST_ASSERT_EQUAL_INT_MESSAGE(0,
+                                memcmp(z_varargs.buffer, z_wrapper.buffer,
+                                       z_varargs.buffer_size),
+                                "Wrapper output differs from varargs output");
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_varargs));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_wrapper));
+}
+
+void test_zrh_wrapper_matches_varargs(void) {
+  const uint32_t num_dirs = 1;
+  const uint32_t num_hidden = 7;
+  const size_t gate_el = (size_t)num_dirs * num_hidden;
+
+  zdnn_tensor_desc pre = {0}, tfrmd = {0};
+
+  // Concatenated GRU biases.
+  zdnn_concat_info info = RNN_TYPE_GRU | USAGE_BIASES | PREV_LAYER_NONE;
+
+  zdnn_init_pre_transformed_desc(ZDNN_2DS, FP16, &pre, num_dirs, num_hidden);
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_generate_transformed_desc_concatenated(&pre, info,
+                                                                       &tfrmd));
+
+  zdnn_ztensor z_varargs = {0}, z_wrapper = {0};
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_varargs));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK,
+                           zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_wrapper));
+
+  uint16_t z[gate_el], r[gate_el], h[gate_el];
+  fill_fp16(z, gate_el, 0x3C00);
+  fill_fp16(r, gate_el, 0x4000);
+  fill_fp16(h, gate_el, 0x4200);
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_transform_ztensor(&z_varargs, z, r, h));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_transform_ztensor_zrh(&z_wrapper, z, r, h));
+
+  TEST_ASSERT_EQUAL_UINT64_MESSAGE(z_varargs.buffer_size, z_wrapper.buffer_size,
+                                   "Buffer sizes unexpectedly differ");
+
+  TEST_ASSERT_EQUAL_INT_MESSAGE(0,
+                                memcmp(z_varargs.buffer, z_wrapper.buffer,
+                                       z_varargs.buffer_size),
+                                "Wrapper output differs from varargs output");
+
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_varargs));
+  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_wrapper));
+}
+
+int main(void) {
+  UNITY_BEGIN();
+
+  RUN_TEST(test_fico_wrapper_rejects_non_fico_layout);
+  RUN_TEST(test_zrh_wrapper_rejects_non_zrh_layout);
+  RUN_TEST(test_fico_wrapper_matches_varargs);
+  RUN_TEST(test_zrh_wrapper_matches_varargs);
+
+  return UNITY_END();
+}
diff --git a/zdnn/stickify.c b/zdnn/stickify.c
index c45256b..d9fc644 100644
--- a/zdnn/stickify.c
+++ b/zdnn/stickify.c
@@ -29,6 +29,8 @@
 
 #ifdef __MVS__
 #pragma export(zdnn_transform_ztensor)
+#pragma export(zdnn_transform_ztensor_fico)
+#pragma export(zdnn_transform_ztensor_zrh)
 #pragma export(zdnn_transform_origtensor)
 #pragma export(zdnn_transform_quantized_ztensor)
 #pragma export(zdnn_transform_ztensor_with_saturation)
@@ -1329,6 +1331,62 @@ zdnn_status zdnn_transform_ztensor(zdnn_ztensor *ztensor, ...) {
   return status;
 }
 
+/// Typed wrapper for transforming concatenated LSTM gate tensors (FICO order).
+///
+/// This is a convenience helper over the varargs zdnn_transform_ztensor() API
+/// to make call sites explicit and prevent argument ordering mistakes.
+///
+/// \param[in,out] ztensor Target ztensor to be transformed into. Must have a
+///                        transformed layout of ZDNN_FICO or ZDNN_BIDIR_FICO.
+/// \param[in] f Pointer to Forget gate data buffer
+/// \param[in] i Pointer to Input gate data buffer
+/// \param[in] c Pointer to Cell gate data buffer
+/// \param[in] o Pointer to Output gate data buffer
+///
+/// \return Same as zdnn_transform_ztensor()
+///
+zdnn_status zdnn_transform_ztensor_fico(zdnn_ztensor *ztensor, const void *f,
+                                        const void *i, const void *c,
+                                        const void *o) {
+  if (ztensor->transformed_desc->layout != ZDNN_FICO &&
+      ztensor->transformed_desc->layout != ZDNN_BIDIR_FICO) {
+    return ZDNN_STATUS(
+        ZDNN_INVALID_LAYOUT,
+        "zdnn_transform_ztensor_fico() expects transformed_desc->layout to be "
+        "ZDNN_FICO or ZDNN_BIDIR_FICO (found: %s)",
+        get_data_layout_str(ztensor->transformed_desc->layout));
+  }
+
+  return zdnn_transform_ztensor(ztensor, f, i, c, o);
+}
+
+/// Typed wrapper for transforming concatenated GRU gate tensors (ZRH order).
+///
+/// This is a convenience helper over the varargs zdnn_transform_ztensor() API
+/// to make call sites explicit and prevent argument ordering mistakes.
+///
+/// \param[in,out] ztensor Target ztensor to be transformed into. Must have a
+///                        transformed layout of ZDNN_ZRH or ZDNN_BIDIR_ZRH.
+/// \param[in] z Pointer to (Z)update gate data buffer
+/// \param[in] r Pointer to Reset gate data buffer
+/// \param[in] h Pointer to Hidden gate data buffer
+///
+/// \return Same as zdnn_transform_ztensor()
+///
+zdnn_status zdnn_transform_ztensor_zrh(zdnn_ztensor *ztensor, const void *z,
+                                       const void *r, const void *h) {
+  if (ztensor->transformed_desc->layout != ZDNN_ZRH &&
+      ztensor->transformed_desc->layout != ZDNN_BIDIR_ZRH) {
+    return ZDNN_STATUS(
+        ZDNN_INVALID_LAYOUT,
+        "zdnn_transform_ztensor_zrh() expects transformed_desc->layout to be "
+        "ZDNN_ZRH or ZDNN_BIDIR_ZRH (found: %s)",
+        get_data_layout_str(ztensor->transformed_desc->layout));
+  }
+
+  return zdnn_transform_ztensor(ztensor, z, r, h);
+}
+
 /// Converts the input tensor to the supported stick format for
 /// execution by zDNN operations.
 ///
diff --git a/zdnn/zdnn.h b/zdnn/zdnn.h
index 0c9cacc..bc1fb4d 100644
--- a/zdnn/zdnn.h
+++ b/zdnn/zdnn.h
@@ -616,8 +616,19 @@ zdnn_status zdnn_conv2d(const zdnn_ztensor *input, const zdnn_ztensor *kernel,
 // External Tensor Transform Operations
 // -----------------------------------------------------------------------------
 
+// NOTE: zdnn_transform_ztensor() is a varargs API. For concatenated tensors
+// (FICO/ZRH/BIDIR_*), typed wrappers are provided to make call sites explicit
+// and avoid argument ordering mistakes.
+
 zdnn_status zdnn_transform_ztensor(zdnn_ztensor *ztensor, ...);
 
+zdnn_status zdnn_transform_ztensor_fico(zdnn_ztensor *ztensor, const void *f,
+                                        const void *i, const void *c,
+                                        const void *o);
+
+zdnn_status zdnn_transform_ztensor_zrh(zdnn_ztensor *ztensor, const void *z,
+                                       const void *r, const void *h);
+
 zdnn_status zdnn_transform_ztensor_with_saturation(zdnn_ztensor *ztensor, ...);
 
 zdnn_status zdnn_transform_quantized_ztensor(zdnn_ztensor *ztensor,
diff --git a/zdnn/zdnn.map b/zdnn/zdnn.map
index d80d3a1..fd63133 100644
--- a/zdnn/zdnn.map
+++ b/zdnn/zdnn.map
@@ -142,6 +142,8 @@ ZDNN_1.0 {
     zdnn_maxpool2d;
     zdnn_conv2d;
     zdnn_transform_ztensor;
+    zdnn_transform_ztensor_fico;
+    zdnn_transform_ztensor_zrh;
     zdnn_transform_ztensor_with_saturation;
     zdnn_transform_quantized_ztensor;
     zdnn_transform_origtensor;