diff --git a/tests/testDriver_work_area_size_apis.c b/tests/testDriver_work_area_size_apis.c new file mode 100644 index 0000000..19a9a2d --- /dev/null +++ b/tests/testDriver_work_area_size_apis.c @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright IBM Corp. 2024 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "testsupport.h" + +/* + * Tests for RNN work-area size helper APIs. + * + * These helpers codify the README formulas: + * - LSTM: dim4 = (4 * num_timesteps) + 6 + * - GRU: dim4 = (3 * num_timesteps) + 5 + * and multiply by 2 for BIDIR. + */ + +void setUp(void) {} +void tearDown(void) {} + +static uint64_t expected_size(uint32_t dim4, uint32_t num_batches, + uint32_t num_hidden, + lstm_gru_direction direction) { + zdnn_tensor_desc desc; + init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE, &desc, + dim4, 1, num_batches, num_hidden); + uint64_t size = zdnn_getsize_ztensor(&desc); + if (direction == BIDIR) { + size *= 2; + } + return size; +} + +void test_lstm_work_area_size_matches_formula(void) { + uint32_t ts = 7, b = 4, h = 64; + uint32_t dim4 = (4 * ts) + 6; + + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, FWD), + zdnn_get_lstm_work_area_size(ts, b, h, FWD)); + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BWD), + zdnn_get_lstm_work_area_size(ts, b, h, BWD)); + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BIDIR), + zdnn_get_lstm_work_area_size(ts, b, h, BIDIR)); +} + +void test_gru_work_area_size_matches_formula(void) { + uint32_t ts = 7, b = 4, h = 64; + uint32_t dim4 = (3 * ts) + 5; + + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, FWD), + zdnn_get_gru_work_area_size(ts, b, h, FWD)); + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BWD), + zdnn_get_gru_work_area_size(ts, b, h, BWD)); + TEST_ASSERT_EQUAL_UINT64(expected_size(dim4, b, h, BIDIR), + zdnn_get_gru_work_area_size(ts, b, h, BIDIR)); +} + +void test_invalid_inputs_return_zero(void) { + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(0, 1, 1, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(1, 0, 1, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_lstm_work_area_size(1, 1, 0, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, + zdnn_get_lstm_work_area_size(1, 1, 1, + (lstm_gru_direction)99)); + + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(0, 1, 1, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(1, 0, 1, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, zdnn_get_gru_work_area_size(1, 1, 0, FWD)); + TEST_ASSERT_EQUAL_UINT64(0, + zdnn_get_gru_work_area_size(1, 1, 1, + (lstm_gru_direction)99)); +} + +int main(void) { + UNITY_BEGIN(); + + RUN_TEST(test_lstm_work_area_size_matches_formula); + RUN_TEST(test_gru_work_area_size_matches_formula); + RUN_TEST(test_invalid_inputs_return_zero); + + return UNITY_END(); +} diff --git a/zdnn/work_area.c b/zdnn/work_area.c new file mode 100644 index 0000000..cf1a919 --- /dev/null +++ b/zdnn/work_area.c @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * Copyright IBM Corp. 2024 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "zdnn.h" +#include "zdnn_private.h" + +#ifdef __MVS__ +#pragma export(zdnn_get_lstm_work_area_size) +#pragma export(zdnn_get_gru_work_area_size) +#endif + +// Common helper for LSTM/GRU work-area sizing. +// +// The work area is internal scratch space used by the RNN implementations for +// temporary ztensors. The required size is described in README.md; these helper +// functions codify that formula to reduce caller error and avoid repeating the +// descriptor math in every integration. +static uint64_t get_rnn_work_area_size(uint32_t num_timesteps, uint32_t dim4_mul, + uint32_t dim4_add, uint32_t num_batches, + uint32_t num_hidden, + lstm_gru_direction direction) { + if (!num_timesteps || !num_batches || !num_hidden) { + return 0; + } + + uint32_t num_dirs; + switch (direction) { + case FWD: + case BWD: + num_dirs = 1; + break; + case BIDIR: + num_dirs = 2; + break; + default: + return 0; + } + + // Compute dim4 using a wider intermediate to avoid overflow. + uint64_t dim4_64 = ((uint64_t)dim4_mul * (uint64_t)num_timesteps) + dim4_add; + if (dim4_64 > UINT32_MAX) { + return 0; + } + + zdnn_tensor_desc desc; + init_transformed_desc(ZDNN_NHWC, ZDNN_DLFLOAT16, ZDNN_FORMAT_4DFEATURE, &desc, + (uint32_t)dim4_64, 1, num_batches, num_hidden); + + uint64_t size = zdnn_getsize_ztensor(&desc); + + // For bidirectional calls, twice the amount of contiguous storage is + // required. + if (num_dirs == 2) { + if (size > (UINT64_MAX / 2)) { + return 0; + } + size *= 2; + } + + return size; +} + +uint64_t zdnn_get_lstm_work_area_size(uint32_t num_timesteps, uint32_t num_batches, + uint32_t num_hidden, + lstm_gru_direction direction) { + // LSTM sizing (README): dim4 = (4 * num_timesteps) + 6 + return get_rnn_work_area_size(num_timesteps, 4, 6, num_batches, num_hidden, + direction); +} + +uint64_t zdnn_get_gru_work_area_size(uint32_t num_timesteps, uint32_t num_batches, + uint32_t num_hidden, + lstm_gru_direction direction) { + // GRU sizing (README): dim4 = (3 * num_timesteps) + 5 + return get_rnn_work_area_size(num_timesteps, 3, 5, num_batches, num_hidden, + direction); +} diff --git a/zdnn/zdnn.h b/zdnn/zdnn.h index 0c9cacc..9988177 100644 --- a/zdnn/zdnn.h +++ b/zdnn/zdnn.h @@ -525,6 +525,19 @@ zdnn_status zdnn_gelu(const zdnn_ztensor *input, zdnn_ztensor *output); typedef enum lstm_gru_direction { FWD, BWD, BIDIR } lstm_gru_direction; +// Convenience helpers to compute the required 4K-aligned `work_area` size for +// zdnn_lstm()/zdnn_gru(). +// +// These return 0 when inputs are invalid (e.g., direction is not recognized or +// any dimension is 0). +uint64_t zdnn_get_lstm_work_area_size(uint32_t num_timesteps, + uint32_t num_batches, uint32_t num_hidden, + lstm_gru_direction direction); + +uint64_t zdnn_get_gru_work_area_size(uint32_t num_timesteps, uint32_t num_batches, + uint32_t num_hidden, + lstm_gru_direction direction); + zdnn_status zdnn_lstm(const zdnn_ztensor *input, const zdnn_ztensor *h0, const zdnn_ztensor *c0, const zdnn_ztensor *weights, const zdnn_ztensor *biases, diff --git a/zdnn/zdnn.map b/zdnn/zdnn.map index d80d3a1..5396134 100644 --- a/zdnn/zdnn.map +++ b/zdnn/zdnn.map @@ -126,6 +126,8 @@ ZDNN_1.0 { zdnn_softmax; zdnn_softmax_mask; zdnn_gelu; + zdnn_get_lstm_work_area_size; + zdnn_get_gru_work_area_size; zdnn_lstm; zdnn_gru; zdnn_matmul_op;