Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/testDriver_work_area_size_apis.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: Apache-2.0
/*
* Copyright IBM Corp. 2024
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "testsupport.h"

/*
* Tests for RNN work-area size helper APIs.
*
* These helpers codify the README formulas:
* - LSTM: dim4 = (4 * num_timesteps) + 6
* - GRU: dim4 = (3 * num_timesteps) + 5
* and multiply by 2 for BIDIR.
*/

void setUp(void) {}
void tearDown(void) {}

static uint64_t expected_size(uint32_t dim4, uint32_t num_batches,
uint32_t num_hidden,
lstm_gru_direction direction) {
zdnn_tensor_desc desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE, &desc,
dim4, 1, num_batches, num_hidden);
uint64_t size = zdnn_getsize_ztensor(&desc);
if (direction == BIDIR) {
size *= 2;
}
return size;
}

void test_lstm_work_area_size_matches_formula(void) {
uint32_t ts = 7, b = 4, h = 64;
uint32_t dim4 = (4 * ts) + 6;

TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, FWD),
zdnn_get_lstm_work_area_size(ts, b, h, FWD));
TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BWD),
zdnn_get_lstm_work_area_size(ts, b, h, BWD));
TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BIDIR),
zdnn_get_lstm_work_area_size(ts, b, h, BIDIR));
}

void test_gru_work_area_size_matches_formula(void) {
uint32_t ts = 7, b = 4, h = 64;
uint32_t dim4 = (3 * ts) + 5;

TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, FWD),
zdnn_get_gru_work_area_size(ts, b, h, FWD));
TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BWD),
zdnn_get_gru_work_area_size(ts, b, h, BWD));
TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BIDIR),
zdnn_get_gru_work_area_size(ts, b, h, BIDIR));
}

void test_invalid_inputs_return_zero(void) {
TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(0, 1, 1, FWD));
TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(1, 0, 1, FWD));
TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(1, 1, 0, FWD));
TEST_ASSERT_EQUAL_UINT64(0,
zdnn_get_lstm_work_area_size(1, 1, 1,
(lstm_gru_direction)99));

TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(0, 1, 1, FWD));
TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(1, 0, 1, FWD));
TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(1, 1, 0, FWD));
TEST_ASSERT_EQUAL_UINT64(0,
zdnn_get_gru_work_area_size(1, 1, 1,
(lstm_gru_direction)99));
}

int main(void) {
UNITY_BEGIN();

RUN_TEST(test_lstm_work_area_size_matches_formula);
RUN_TEST(test_gru_work_area_size_matches_formula);
RUN_TEST(test_invalid_inputs_return_zero);

return UNITY_END();
}
93 changes: 93 additions & 0 deletions zdnn/work_area.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// SPDX-License-Identifier: Apache-2.0
/*
* Copyright IBM Corp. 2024
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <stdint.h>

#include "zdnn.h"
#include "zdnn_private.h"

#ifdef __MVS__
#pragma export(zdnn_get_lstm_work_area_size)
#pragma export(zdnn_get_gru_work_area_size)
#endif

// Common helper for LSTM/GRU work-area sizing.
//
// The work area is internal scratch space used by the RNN implementations for
// temporary ztensors. The required size is described in README.md; these helper
// functions codify that formula to reduce caller error and avoid repeating the
// descriptor math in every integration.
static uint64_t get_rnn_work_area_size(uint32_t num_timesteps, uint32_t dim4_mul,
uint32_t dim4_add, uint32_t num_batches,
uint32_t num_hidden,
lstm_gru_direction direction) {
if (!num_timesteps || !num_batches || !num_hidden) {
return 0;
}

uint32_t num_dirs;
switch (direction) {
case FWD:
case BWD:
num_dirs = 1;
break;
case BIDIR:
num_dirs = 2;
break;
default:
return 0;
}

// Compute dim4 using a wider intermediate to avoid overflow.
uint64_t dim4_64 = ((uint64_t)dim4_mul * (uint64_t)num_timesteps) + dim4_add;
if (dim4_64 > UINT32_MAX) {
return 0;
}

zdnn_tensor_desc desc;
init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE, &desc,
(uint32_t)dim4_64, 1, num_batches, num_hidden);

uint64_t size = zdnn_getsize_ztensor(&desc);

// For bidirectional calls, twice the amount of contiguous storage is
// required.
if (num_dirs == 2) {
if (size > (UINT64_MAX / 2)) {
return 0;
}
size *= 2;
}

return size;
}

uint64_t zdnn_get_lstm_work_area_size(uint32_t num_timesteps, uint32_t num_batches,
uint32_t num_hidden,
lstm_gru_direction direction) {
// LSTM sizing (README): dim4 = (4 * num_timesteps) + 6
return get_rnn_work_area_size(num_timesteps, 4, 6, num_batches, num_hidden,
direction);
}

uint64_t zdnn_get_gru_work_area_size(uint32_t num_timesteps, uint32_t num_batches,
uint32_t num_hidden,
lstm_gru_direction direction) {
// GRU sizing (README): dim4 = (3 * num_timesteps) + 5
return get_rnn_work_area_size(num_timesteps, 3, 5, num_batches, num_hidden,
direction);
}
13 changes: 13 additions & 0 deletions zdnn/zdnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,19 @@ zdnn_status zdnn_gelu(const zdnn_ztensor *input, zdnn_ztensor *output);

typedef enum lstm_gru_direction { FWD, BWD, BIDIR } lstm_gru_direction;

// Convenience helpers to compute the required 4K-aligned `work_area` size for
// zdnn_lstm()/zdnn_gru().
//
// These return 0 when inputs are invalid (e.g., direction is not recognized or
// any dimension is 0).
uint64_t zdnn_get_lstm_work_area_size(uint32_t num_timesteps,
uint32_t num_batches, uint32_t num_hidden,
lstm_gru_direction direction);

uint64_t zdnn_get_gru_work_area_size(uint32_t num_timesteps, uint32_t num_batches,
uint32_t num_hidden,
lstm_gru_direction direction);

zdnn_status zdnn_lstm(const zdnn_ztensor *input, const zdnn_ztensor *h0,
const zdnn_ztensor *c0, const zdnn_ztensor *weights,
const zdnn_ztensor *biases,
Expand Down
2 changes: 2 additions & 0 deletions zdnn/zdnn.map
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ ZDNN_1.0 {
zdnn_softmax;
zdnn_softmax_mask;
zdnn_gelu;
zdnn_get_lstm_work_area_size;
zdnn_get_gru_work_area_size;
zdnn_lstm;
zdnn_gru;
zdnn_matmul_op;
Expand Down