Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions tests/testDriver_transform_typed_wrappers.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// SPDX-License-Identifier: Apache-2.0
/*
* Copyright IBM Corp. 2024
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "testsupport.h"

#include <string.h>

/*
* Tests for typed wrappers over the varargs zdnn_transform_ztensor() API.
*
* Goals:
* - The wrappers must reject incorrect transformed layouts (avoid silent
* varargs misuse).
* - For valid concatenated tensors, wrapper output must match the equivalent
* varargs call byte-for-byte.
*/

// Unity per-test fixture hooks. These tests build and release everything they
// need inside each test body, so both hooks are intentionally empty.
void setUp(void) {
  /* nothing to prepare */
}

void tearDown(void) {
  /* nothing to release */
}

/*
 * Stores the raw 16-bit pattern `v` into each of the first `n` elements of
 * `buf`. When `n` is 0 the buffer is not touched.
 */
static void fill_fp16(uint16_t *buf, size_t n, uint16_t v) {
  uint16_t *cursor = buf;
  size_t remaining = n;

  while (remaining > 0) {
    *cursor = v;
    cursor++;
    remaining--;
  }
}

/*
 * The FICO wrapper must refuse a ztensor whose transformed descriptor was
 * generated from a plain (non-concatenated) ZDNN_1D descriptor, reporting
 * ZDNN_INVALID_LAYOUT.
 */
void test_fico_wrapper_rejects_non_fico_layout(void) {
  zdnn_tensor_desc pre_desc = {0};
  zdnn_tensor_desc tfrmd_desc = {0};
  zdnn_ztensor plain_ztensor = {0};

  // Build an ordinary descriptor pair via the non-concatenated generator.
  zdnn_init_pre_transformed_desc(ZDNN_1D, FP16, &pre_desc, 4);
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_generate_transformed_desc(&pre_desc, &tfrmd_desc));

  zdnn_init_ztensor(&pre_desc, &tfrmd_desc, &plain_ztensor);

  // Every gate pointer is NULL on purpose: the wrapper is expected to return
  // the layout error without ever using them.
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_INVALID_LAYOUT,
      zdnn_transform_ztensor_fico(&plain_ztensor, NULL, NULL, NULL, NULL));
}

/*
 * The ZRH wrapper must refuse a ztensor whose transformed descriptor was
 * generated from a plain (non-concatenated) ZDNN_1D descriptor, reporting
 * ZDNN_INVALID_LAYOUT.
 */
void test_zrh_wrapper_rejects_non_zrh_layout(void) {
  zdnn_tensor_desc pre_desc = {0};
  zdnn_tensor_desc tfrmd_desc = {0};
  zdnn_ztensor plain_ztensor = {0};

  // Build an ordinary descriptor pair via the non-concatenated generator.
  zdnn_init_pre_transformed_desc(ZDNN_1D, FP16, &pre_desc, 4);
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_generate_transformed_desc(&pre_desc, &tfrmd_desc));

  zdnn_init_ztensor(&pre_desc, &tfrmd_desc, &plain_ztensor);

  // Gate pointers are NULL on purpose: only the layout check should run.
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_INVALID_LAYOUT,
      zdnn_transform_ztensor_zrh(&plain_ztensor, NULL, NULL, NULL));
}

/*
 * For a concatenated LSTM bias tensor, zdnn_transform_ztensor_fico() must
 * produce a transformed buffer that is byte-for-byte identical to the one
 * produced by the equivalent zdnn_transform_ztensor() varargs call.
 */
void test_fico_wrapper_matches_varargs(void) {
  // Compile-time dimensions. A const-qualified size_t is not a constant
  // expression in C, so sizing the gate arrays with one would make them VLAs
  // (optional since C11). Enum constants give ordinary fixed-size arrays.
  enum { NUM_DIRS = 1, NUM_HIDDEN = 7, GATE_EL = NUM_DIRS * NUM_HIDDEN };

  // Keep the dimensions as uint32_t objects for the descriptor call.
  const uint32_t num_dirs = NUM_DIRS;
  const uint32_t num_hidden = NUM_HIDDEN;

  zdnn_tensor_desc pre = {0}, tfrmd = {0};

  // Concatenated LSTM biases.
  zdnn_concat_info info = RNN_TYPE_LSTM | USAGE_BIASES | PREV_LAYER_NONE;

  // For biases: pre-transformed shape per gate is (num_dirs, num_hidden) using
  // ZDNN_2DS.
  zdnn_init_pre_transformed_desc(ZDNN_2DS, FP16, &pre, num_dirs, num_hidden);
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_generate_transformed_desc_concatenated(&pre, info, &tfrmd));

  // Two independent targets built from the same descriptors: one per API.
  zdnn_ztensor z_varargs = {0}, z_wrapper = {0};

  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_varargs));
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_wrapper));

  // Distinct FP16 values per gate to catch ordering mistakes.
  // (IEEE-754 half): 1.0=0x3C00, 2.0=0x4000, 3.0=0x4200, 4.0=0x4400.
  uint16_t gate_f[GATE_EL], gate_i[GATE_EL], gate_c[GATE_EL], gate_o[GATE_EL];
  fill_fp16(gate_f, GATE_EL, 0x3C00);
  fill_fp16(gate_i, GATE_EL, 0x4000);
  fill_fp16(gate_c, GATE_EL, 0x4200);
  fill_fp16(gate_o, GATE_EL, 0x4400);

  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK,
      zdnn_transform_ztensor(&z_varargs, gate_f, gate_i, gate_c, gate_o));
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK,
      zdnn_transform_ztensor_fico(&z_wrapper, gate_f, gate_i, gate_c, gate_o));

  // Sizes must agree before the buffers can be compared over that length.
  TEST_ASSERT_EQUAL_UINT64_MESSAGE(z_varargs.buffer_size, z_wrapper.buffer_size,
                                   "Buffer sizes unexpectedly differ");

  TEST_ASSERT_EQUAL_INT_MESSAGE(
      0, memcmp(z_varargs.buffer, z_wrapper.buffer, z_varargs.buffer_size),
      "Wrapper output differs from varargs output");

  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_varargs));
  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_wrapper));
}

/*
 * For a concatenated GRU bias tensor, zdnn_transform_ztensor_zrh() must
 * produce a transformed buffer that is byte-for-byte identical to the one
 * produced by the equivalent zdnn_transform_ztensor() varargs call.
 */
void test_zrh_wrapper_matches_varargs(void) {
  // Compile-time dimensions. A const-qualified size_t is not a constant
  // expression in C, so sizing the gate arrays with one would make them VLAs
  // (optional since C11). Enum constants give ordinary fixed-size arrays.
  enum { NUM_DIRS = 1, NUM_HIDDEN = 7, GATE_EL = NUM_DIRS * NUM_HIDDEN };

  // Keep the dimensions as uint32_t objects for the descriptor call.
  const uint32_t num_dirs = NUM_DIRS;
  const uint32_t num_hidden = NUM_HIDDEN;

  zdnn_tensor_desc pre = {0}, tfrmd = {0};

  // Concatenated GRU biases.
  zdnn_concat_info info = RNN_TYPE_GRU | USAGE_BIASES | PREV_LAYER_NONE;

  zdnn_init_pre_transformed_desc(ZDNN_2DS, FP16, &pre, num_dirs, num_hidden);
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_generate_transformed_desc_concatenated(&pre, info, &tfrmd));

  // Two independent targets built from the same descriptors: one per API.
  zdnn_ztensor z_varargs = {0}, z_wrapper = {0};

  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_varargs));
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_init_ztensor_with_malloc(&pre, &tfrmd, &z_wrapper));

  // Distinct FP16 bit patterns per gate so a swapped gate changes the output.
  uint16_t gate_z[GATE_EL], gate_r[GATE_EL], gate_h[GATE_EL];
  fill_fp16(gate_z, GATE_EL, 0x3C00);
  fill_fp16(gate_r, GATE_EL, 0x4000);
  fill_fp16(gate_h, GATE_EL, 0x4200);

  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_transform_ztensor(&z_varargs, gate_z, gate_r, gate_h));
  TEST_ASSERT_EQUAL_UINT32(
      ZDNN_OK, zdnn_transform_ztensor_zrh(&z_wrapper, gate_z, gate_r, gate_h));

  // Sizes must agree before the buffers can be compared over that length.
  TEST_ASSERT_EQUAL_UINT64_MESSAGE(z_varargs.buffer_size, z_wrapper.buffer_size,
                                   "Buffer sizes unexpectedly differ");

  TEST_ASSERT_EQUAL_INT_MESSAGE(
      0, memcmp(z_varargs.buffer, z_wrapper.buffer, z_varargs.buffer_size),
      "Wrapper output differs from varargs output");

  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_varargs));
  TEST_ASSERT_EQUAL_UINT32(ZDNN_OK, zdnn_free_ztensor_buffer(&z_wrapper));
}

// Unity entry point: runs the two layout-rejection tests followed by the two
// wrapper-vs-varargs equivalence tests and returns Unity's result as the
// process exit status.
int main(void) {
  UNITY_BEGIN();

  // Layout validation.
  RUN_TEST(test_fico_wrapper_rejects_non_fico_layout);
  RUN_TEST(test_zrh_wrapper_rejects_non_zrh_layout);

  // Output equivalence with the varargs API.
  RUN_TEST(test_fico_wrapper_matches_varargs);
  RUN_TEST(test_zrh_wrapper_matches_varargs);

  const int unity_result = UNITY_END();
  return unity_result;
}
58 changes: 58 additions & 0 deletions zdnn/stickify.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#ifdef __MVS__
#pragma export(zdnn_transform_ztensor)
#pragma export(zdnn_transform_ztensor_fico)
#pragma export(zdnn_transform_ztensor_zrh)
#pragma export(zdnn_transform_origtensor)
#pragma export(zdnn_transform_quantized_ztensor)
#pragma export(zdnn_transform_ztensor_with_saturation)
Expand Down Expand Up @@ -1329,6 +1331,62 @@ zdnn_status zdnn_transform_ztensor(zdnn_ztensor *ztensor, ...) {
return status;
}

/// Fixed-arity front end to zdnn_transform_ztensor() for concatenated LSTM
/// gate data, taking the four gate buffers in FICO order.
///
/// The transformed layout is inspected first. The gate pointers are forwarded
/// to zdnn_transform_ztensor() only when that layout is ZDNN_FICO or
/// ZDNN_BIDIR_FICO; for any other layout the function returns without using
/// them. ztensor and ztensor->transformed_desc are dereferenced
/// unconditionally.
///
/// \param[in,out] ztensor Target ztensor to be transformed into. Must have a
///                        transformed layout of ZDNN_FICO or ZDNN_BIDIR_FICO.
/// \param[in] f Pointer to Forget gate data buffer
/// \param[in] i Pointer to Input gate data buffer
/// \param[in] c Pointer to Cell gate data buffer
/// \param[in] o Pointer to Output gate data buffer
///
/// \return ZDNN_INVALID_LAYOUT (reported through ZDNN_STATUS) when the layout
///         check fails, otherwise the result of zdnn_transform_ztensor()
///
zdnn_status zdnn_transform_ztensor_fico(zdnn_ztensor *ztensor, const void *f,
                                        const void *i, const void *c,
                                        const void *o) {
  switch (ztensor->transformed_desc->layout) {
  case ZDNN_FICO:
  case ZDNN_BIDIR_FICO:
    // Accepted layouts: hand the gates to the varargs routine in F, I, C, O
    // order.
    return zdnn_transform_ztensor(ztensor, f, i, c, o);
  default:
    return ZDNN_STATUS(
        ZDNN_INVALID_LAYOUT,
        "zdnn_transform_ztensor_fico() expects transformed_desc->layout to be "
        "ZDNN_FICO or ZDNN_BIDIR_FICO (found: %s)",
        get_data_layout_str(ztensor->transformed_desc->layout));
  }
}

/// Fixed-arity front end to zdnn_transform_ztensor() for concatenated GRU
/// gate data, taking the three gate buffers in ZRH order.
///
/// The transformed layout is inspected first. The gate pointers are forwarded
/// to zdnn_transform_ztensor() only when that layout is ZDNN_ZRH or
/// ZDNN_BIDIR_ZRH; for any other layout the function returns without using
/// them. ztensor and ztensor->transformed_desc are dereferenced
/// unconditionally.
///
/// \param[in,out] ztensor Target ztensor to be transformed into. Must have a
///                        transformed layout of ZDNN_ZRH or ZDNN_BIDIR_ZRH.
/// \param[in] z Pointer to (Z)update gate data buffer
/// \param[in] r Pointer to Reset gate data buffer
/// \param[in] h Pointer to Hidden gate data buffer
///
/// \return ZDNN_INVALID_LAYOUT (reported through ZDNN_STATUS) when the layout
///         check fails, otherwise the result of zdnn_transform_ztensor()
///
zdnn_status zdnn_transform_ztensor_zrh(zdnn_ztensor *ztensor, const void *z,
                                       const void *r, const void *h) {
  switch (ztensor->transformed_desc->layout) {
  case ZDNN_ZRH:
  case ZDNN_BIDIR_ZRH:
    // Accepted layouts: hand the gates to the varargs routine in Z, R, H
    // order.
    return zdnn_transform_ztensor(ztensor, z, r, h);
  default:
    return ZDNN_STATUS(
        ZDNN_INVALID_LAYOUT,
        "zdnn_transform_ztensor_zrh() expects transformed_desc->layout to be "
        "ZDNN_ZRH or ZDNN_BIDIR_ZRH (found: %s)",
        get_data_layout_str(ztensor->transformed_desc->layout));
  }
}

/// Converts the input tensor to the supported stick format for
/// execution by zDNN operations.
///
Expand Down
11 changes: 11 additions & 0 deletions zdnn/zdnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,19 @@ zdnn_status zdnn_conv2d(const zdnn_ztensor *input, const zdnn_ztensor *kernel,
// External Tensor Transform Operations
// -----------------------------------------------------------------------------

// NOTE: zdnn_transform_ztensor() is a varargs API, so the compiler cannot
// check how many data buffers a call site passes. For concatenated tensors
// (FICO/ZRH/BIDIR_*), fixed-arity wrappers are provided: they pin the number
// of gate buffers, name the expected gate order in the signature, and return
// ZDNN_INVALID_LAYOUT when the ztensor's transformed layout does not match.
// The gate parameters are all `const void *`, so passing gates in the wrong
// order is still not detected at compile time.

zdnn_status zdnn_transform_ztensor(zdnn_ztensor *ztensor, ...);

// Concatenated LSTM gates, in Forget/Input/Cell/Output order. Requires
// ztensor->transformed_desc->layout to be ZDNN_FICO or ZDNN_BIDIR_FICO;
// otherwise behaves as zdnn_transform_ztensor(ztensor, f, i, c, o).
zdnn_status zdnn_transform_ztensor_fico(zdnn_ztensor *ztensor, const void *f,
const void *i, const void *c,
const void *o);

// Concatenated GRU gates, in (Z)update/Reset/Hidden order. Requires
// ztensor->transformed_desc->layout to be ZDNN_ZRH or ZDNN_BIDIR_ZRH;
// otherwise behaves as zdnn_transform_ztensor(ztensor, z, r, h).
zdnn_status zdnn_transform_ztensor_zrh(zdnn_ztensor *ztensor, const void *z,
const void *r, const void *h);

zdnn_status zdnn_transform_ztensor_with_saturation(zdnn_ztensor *ztensor, ...);

zdnn_status zdnn_transform_quantized_ztensor(zdnn_ztensor *ztensor,
Expand Down
2 changes: 2 additions & 0 deletions zdnn/zdnn.map
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ ZDNN_1.0 {
zdnn_maxpool2d;
zdnn_conv2d;
zdnn_transform_ztensor;
zdnn_transform_ztensor_fico;
zdnn_transform_ztensor_zrh;
zdnn_transform_ztensor_with_saturation;
zdnn_transform_quantized_ztensor;
zdnn_transform_origtensor;
Expand Down