From 37fa1272a16c99a819fc8d5437aefb1b289f12e6 Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Fri, 14 Nov 2025 17:48:02 +0000 Subject: [PATCH 1/3] DX-108149: Add support for AES CBC encryption mode --- cpp/src/gandiva/CMakeLists.txt | 3 + cpp/src/gandiva/encrypt_utils_cbc.cc | 197 +++++++++++++++++++ cpp/src/gandiva/encrypt_utils_cbc.h | 67 +++++++ cpp/src/gandiva/encrypt_utils_cbc_test.cc | 191 ++++++++++++++++++ cpp/src/gandiva/encrypt_utils_common.cc | 33 ++++ cpp/src/gandiva/encrypt_utils_common.h | 33 ++++ cpp/src/gandiva/encrypt_utils_ecb.cc | 9 +- cpp/src/gandiva/function_registry_string.cc | 10 + cpp/src/gandiva/gdv_function_stubs.cc | 202 ++++++++++++++++++++ cpp/src/gandiva/gdv_function_stubs.h | 15 ++ cpp/src/gandiva/gdv_function_stubs_test.cc | 177 +++++++++++++++++ 11 files changed, 929 insertions(+), 8 deletions(-) create mode 100644 cpp/src/gandiva/encrypt_utils_cbc.cc create mode 100644 cpp/src/gandiva/encrypt_utils_cbc.h create mode 100644 cpp/src/gandiva/encrypt_utils_cbc_test.cc create mode 100644 cpp/src/gandiva/encrypt_utils_common.cc create mode 100644 cpp/src/gandiva/encrypt_utils_common.h diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 02c79f6782ab..8d198416d907 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -56,7 +56,9 @@ set(SRC_FILES decimal_xlarge.cc engine.cc date_utils.cc + encrypt_utils_common.cc encrypt_utils_ecb.cc + encrypt_utils_cbc.cc expr_decomposer.cc expr_validator.cc expression.cc @@ -257,6 +259,7 @@ add_gandiva_test(internals-test annotator_test.cc tree_expr_test.cc encrypt_utils_ecb_test.cc + encrypt_utils_cbc_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc expression_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_utils_cbc.cc b/cpp/src/gandiva/encrypt_utils_cbc.cc new file mode 100644 index 000000000000..15ac0413a75c --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_cbc.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/encrypt_utils_cbc.h"
+#include "gandiva/encrypt_utils_common.h"
+#include <openssl/err.h>
+#include <openssl/evp.h>
+#include <cctype>
+#include <memory>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+
+namespace gandiva {
+
+namespace {
+
+// AES block size in bytes. CBC needs an IV of exactly one block.
+constexpr int32_t kAesBlockSize = 16;
+
+// Padding mode enum
+enum class PaddingMode { PKCS7, NONE };
+
+// Frees the cipher context on every exit path, including exceptions, so the
+// individual error branches do not have to remember to do it.
+struct CipherCtxDeleter {
+  void operator()(EVP_CIPHER_CTX* ctx) const { EVP_CIPHER_CTX_free(ctx); }
+};
+using CipherCtxPtr = std::unique_ptr<EVP_CIPHER_CTX, CipherCtxDeleter>;
+
+// Maps a key length in bytes to the matching AES-CBC cipher.
+// Throws std::runtime_error for any length other than 16, 24 or 32.
+const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) {
+  switch (key_length) {
+    case 16:
+      return EVP_aes_128_cbc();
+    case 24:
+      return EVP_aes_192_cbc();
+    case 32:
+      return EVP_aes_256_cbc();
+    default: {
+      std::ostringstream oss;
+      oss << "Unsupported key length for AES-CBC: " << key_length
+          << " bytes. Supported lengths: 16, 24, 32 bytes";
+      throw std::runtime_error(oss.str());
+    }
+  }
+}
+
+// Case-insensitive comparison of a length-delimited (not necessarily
+// NUL-terminated) buffer against `expected`. The length is checked first so
+// the buffer is never read past `value_len`.
+bool equals_ignore_case(const char* value, int32_t value_len,
+                        const std::string& expected) {
+  if (value_len != static_cast<int32_t>(expected.size())) {
+    return false;
+  }
+  for (int32_t i = 0; i < value_len; ++i) {
+    // Cast through unsigned char: std::toupper is undefined for negative values.
+    const int lhs = std::toupper(static_cast<unsigned char>(value[i]));
+    const int rhs = std::toupper(static_cast<unsigned char>(expected[i]));
+    if (lhs != rhs) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Parses the padding mode name ("PKCS7" or "NONE", case-insensitive).
+// Throws std::runtime_error when the name is null, empty or unknown.
+PaddingMode get_padding_mode(const char* padding_str, int32_t padding_len) {
+  if (padding_str == nullptr || padding_len <= 0) {
+    throw std::runtime_error("Invalid padding parameter: null or empty");
+  }
+
+  if (equals_ignore_case(padding_str, padding_len, "PKCS7")) {
+    return PaddingMode::PKCS7;
+  }
+  if (equals_ignore_case(padding_str, padding_len, "NONE")) {
+    return PaddingMode::NONE;
+  }
+
+  std::ostringstream oss;
+  oss << "Invalid padding mode: '" << std::string(padding_str, padding_len)
+      << "'. Supported modes: PKCS7, NONE (case-insensitive)";
+  throw std::runtime_error(oss.str());
+}
+
+// Shared AES-CBC implementation behind aes_encrypt_cbc / aes_decrypt_cbc.
+//
+// `encrypt` selects the direction. `output` must have room for at least
+// input_len + kAesBlockSize bytes. Returns the number of bytes written.
+// Throws std::runtime_error on invalid arguments or on any OpenSSL failure
+// (for example a non block-aligned input with padding "NONE", or a bad
+// PKCS7 trailer while decrypting).
+int32_t run_aes_cbc(bool encrypt, const char* input, int32_t input_len, const char* key,
+                    int32_t key_len, const char* iv, int32_t iv_len,
+                    const char* padding, int32_t padding_len, unsigned char* output) {
+  const std::string operation = encrypt ? "encryption" : "decryption";
+
+  // Validate IV length
+  if (iv_len != kAesBlockSize) {
+    std::ostringstream oss;
+    oss << "Invalid IV length for AES-CBC: " << iv_len
+        << " bytes. IV must be exactly 16 bytes";
+    throw std::runtime_error(oss.str());
+  }
+
+  const PaddingMode padding_mode = get_padding_mode(padding, padding_len);
+
+  // Resolve the cipher before allocating the context: this call throws for an
+  // unsupported key length and must not leave a context behind.
+  const EVP_CIPHER* cipher_algo = get_cbc_cipher_algo(key_len);
+
+  if (key == nullptr || iv == nullptr || output == nullptr || input_len < 0 ||
+      (input == nullptr && input_len > 0)) {
+    throw std::runtime_error("Invalid arguments for AES-CBC " + operation +
+                             ": null buffer or negative length");
+  }
+
+  CipherCtxPtr ctx(EVP_CIPHER_CTX_new());
+  if (!ctx) {
+    throw std::runtime_error("Could not create EVP cipher context for " + operation +
+                             ": " + get_openssl_error_string());
+  }
+
+  if (!EVP_CipherInit_ex(ctx.get(), cipher_algo, nullptr,
+                         reinterpret_cast<const unsigned char*>(key),
+                         reinterpret_cast<const unsigned char*>(iv), encrypt ? 1 : 0)) {
+    throw std::runtime_error("Could not initialize EVP cipher context for " +
+                             operation + ": " + get_openssl_error_string());
+  }
+
+  // OpenSSL's built-in block padding is PKCS7; 0 disables it.
+  const int padding_flag = (padding_mode == PaddingMode::PKCS7) ? 1 : 0;
+  if (!EVP_CIPHER_CTX_set_padding(ctx.get(), padding_flag)) {
+    throw std::runtime_error("Could not set padding mode for " + operation + ": " +
+                             get_openssl_error_string());
+  }
+
+  int update_len = 0;
+  if (!EVP_CipherUpdate(ctx.get(), output, &update_len,
+                        reinterpret_cast<const unsigned char*>(input), input_len)) {
+    throw std::runtime_error("Could not update EVP cipher context for " + operation +
+                             ": " + get_openssl_error_string());
+  }
+
+  // The final call flushes (or, when decrypting, validates and strips) the
+  // last block right after the bytes produced by the update call.
+  int final_len = 0;
+  if (!EVP_CipherFinal_ex(ctx.get(), output + update_len, &final_len)) {
+    throw std::runtime_error("Could not finalize EVP cipher context for " + operation +
+                             ": " + get_openssl_error_string());
+  }
+
+  return static_cast<int32_t>(update_len + final_len);
+}
+
+}  // namespace
+
+// Encrypts `plaintext` with AES-CBC; see encrypt_utils_cbc.h for the contract.
+GANDIVA_EXPORT
+int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len, unsigned char* cipher) {
+  return run_aes_cbc(/*encrypt=*/true, plaintext, plaintext_len, key, key_len, iv,
+                     iv_len, padding, padding_len, cipher);
+}
+
+// Decrypts `ciphertext` with AES-CBC; see encrypt_utils_cbc.h for the contract.
+GANDIVA_EXPORT
+int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len,
+                        unsigned char* plaintext) {
+  return run_aes_cbc(/*encrypt=*/false, ciphertext, ciphertext_len, key, key_len, iv,
+                     iv_len, padding, padding_len, plaintext);
+}
+
+}  // namespace gandiva
diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h
new file mode 100644
index 000000000000..9b5bcaa6ab23
--- /dev/null
+++ b/cpp/src/gandiva/encrypt_utils_cbc.h
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+/**
+ * Encrypt data using AES-CBC algorithm with explicit padding mode
+ *
+ * @param plaintext The data to encrypt
+ * @param plaintext_len Length of plaintext in bytes
+ * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys)
+ * @param key_len Length of key in bytes
+ * @param iv The initialization vector (must be exactly 16 bytes)
+ * @param iv_len Length of IV in bytes (must be 16)
+ * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive);
+ *        it is length-delimited and does not need to be NUL-terminated
+ * @param padding_len Length of padding string in bytes
+ * @param cipher Output buffer for encrypted data. The caller owns it and must
+ *        provide room for at least plaintext_len + 16 bytes: PKCS7 always
+ *        appends between 1 and 16 bytes of padding.
+ * @return Length of encrypted data in bytes
+ * @throws std::runtime_error on encryption failure or invalid parameters
+ *         (including input that is not a multiple of 16 bytes when padding is
+ *         "NONE")
+ */
+GANDIVA_EXPORT
+int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len, unsigned char* cipher);
+
+/**
+ * Decrypt data using AES-CBC algorithm with explicit padding mode
+ *
+ * @param ciphertext The data to decrypt
+ * @param ciphertext_len Length of ciphertext in bytes
+ * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys)
+ * @param key_len Length of key in bytes
+ * @param iv The initialization vector (must be exactly 16 bytes)
+ * @param iv_len Length of IV in bytes (must be 16)
+ * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive);
+ *        it must match the mode used for encryption
+ * @param padding_len Length of padding string in bytes
+ * @param plaintext Output buffer for decrypted data. The caller owns it and
+ *        should provide room for ciphertext_len + 16 bytes, which is the bound
+ *        OpenSSL documents for EVP_DecryptUpdate.
+ * @return Length of decrypted data in bytes
+ * @throws std::runtime_error on decryption failure or invalid parameters
+ */
+GANDIVA_EXPORT
+int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key,
+                        int32_t key_len, const char* iv, int32_t iv_len,
+                        const char* padding, int32_t padding_len,
+                        unsigned char* plaintext);
+
+}  // namespace gandiva
+
diff --git a/cpp/src/gandiva/encrypt_utils_cbc_test.cc b/cpp/src/gandiva/encrypt_utils_cbc_test.cc
new file mode 100644
index 000000000000..88e5403d1206
--- /dev/null
+++ b/cpp/src/gandiva/encrypt_utils_cbc_test.cc
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/encrypt_utils_cbc.h"
+
+#include <gtest/gtest.h>
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+
+namespace {
+
+// Scratch buffer size used by the helpers below. Every input in this file is
+// at most 16 bytes, so 64 bytes covers the input plus one block of padding.
+constexpr int kBufferSize = 64;
+
+// Encrypts `plaintext` with AES-CBC and returns the raw ciphertext bytes.
+// Exceptions thrown by aes_encrypt_cbc propagate to the caller.
+std::string CbcEncrypt(const std::string& key, const std::string& iv,
+                       const std::string& plaintext, const std::string& padding) {
+  unsigned char cipher[kBufferSize];
+  const int32_t cipher_len = gandiva::aes_encrypt_cbc(
+      plaintext.data(), static_cast<int32_t>(plaintext.size()), key.data(),
+      static_cast<int32_t>(key.size()), iv.data(), static_cast<int32_t>(iv.size()),
+      padding.data(), static_cast<int32_t>(padding.size()), cipher);
+  return std::string(reinterpret_cast<const char*>(cipher), cipher_len);
+}
+
+// Encrypts `plaintext`, decrypts the result with the same key, IV and padding
+// mode, and returns the recovered bytes.
+std::string CbcRoundTrip(const std::string& key, const std::string& iv,
+                         const std::string& plaintext, const std::string& padding) {
+  const std::string cipher = CbcEncrypt(key, iv, plaintext, padding);
+
+  unsigned char decrypted[kBufferSize];
+  const int32_t decrypted_len = gandiva::aes_decrypt_cbc(
+      cipher.data(), static_cast<int32_t>(cipher.size()), key.data(),
+      static_cast<int32_t>(key.size()), iv.data(), static_cast<int32_t>(iv.size()),
+      padding.data(), static_cast<int32_t>(padding.size()), decrypted);
+  return std::string(reinterpret_cast<const char*>(decrypted), decrypted_len);
+}
+
+}  // namespace
+
+// Test PKCS#7 padding with 16-byte key
+TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) {
+  const std::string to_encrypt = "some test string";
+  EXPECT_EQ(to_encrypt,
+            CbcRoundTrip("12345678abcdefgh", "1234567890123456", to_encrypt, "PKCS7"));
+}
+
+// Test PKCS#7 padding with 24-byte key
+TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) {
+  const std::string to_encrypt = "some\ntest\nstring";
+  EXPECT_EQ(to_encrypt, CbcRoundTrip("12345678abcdefgh12345678", "1234567890123456",
+                                     to_encrypt, "PKCS7"));
+}
+
+// Test PKCS#7 padding with 32-byte key
+TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) {
+  const std::string to_encrypt = "New\ntest\nstring";
+  EXPECT_EQ(to_encrypt, CbcRoundTrip("12345678abcdefgh12345678abcdefgh",
+                                     "1234567890123456", to_encrypt, "PKCS7"));
+}
+
+// Test no-padding mode with block-aligned data (16 bytes)
+TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptNoPadding_16) {
+  const std::string to_encrypt = "1234567890123456";  // Exactly 16 bytes
+  EXPECT_EQ(to_encrypt,
+            CbcRoundTrip("12345678abcdefgh", "1234567890123456", to_encrypt, "NONE"));
+}
+
+// Test case-insensitive padding mode
+TEST(TestAesCbcEncryptUtils, TestCaseInsensitivePadding) {
+  const std::string key = "12345678abcdefgh";
+  const std::string iv = "1234567890123456";
+  const std::string to_encrypt = "test";
+
+  const std::string cipher_lower = CbcEncrypt(key, iv, to_encrypt, "pkcs7");
+  const std::string cipher_upper = CbcEncrypt(key, iv, to_encrypt, "PKCS7");
+
+  // Both should produce same ciphertext
+  EXPECT_EQ(cipher_lower.size(), cipher_upper.size());
+  EXPECT_EQ(cipher_lower, cipher_upper);
+}
+
+// Test invalid IV length
+TEST(TestAesCbcEncryptUtils, TestInvalidIVLength) {
+  // "short" is too short for an IV
+  ASSERT_THROW(CbcEncrypt("12345678abcdefgh", "short", "test", "PKCS7"),
+               std::runtime_error);
+}
+
+// Test invalid key length
+TEST(TestAesCbcEncryptUtils, TestInvalidKeyLength) {
+  // "short" is too short for a key
+  ASSERT_THROW(CbcEncrypt("short", "1234567890123456", "test", "PKCS7"),
+               std::runtime_error);
+}
+
+// Test invalid padding mode
+TEST(TestAesCbcEncryptUtils, TestInvalidPaddingMode) {
+  ASSERT_THROW(CbcEncrypt("12345678abcdefgh", "1234567890123456", "test", "INVALID"),
+               std::runtime_error);
+}
+
diff --git a/cpp/src/gandiva/encrypt_utils_common.cc b/cpp/src/gandiva/encrypt_utils_common.cc
new file mode 100644
index 000000000000..b1b7d1729495
--- /dev/null
+++ 
b/cpp/src/gandiva/encrypt_utils_common.cc
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/encrypt_utils_common.h"
+#include <openssl/err.h>
+#include <string>
+
+namespace gandiva {
+
+// Pops one entry from the calling thread's OpenSSL error queue and returns a
+// description of it, or "Unknown OpenSSL error" when the queue is empty.
+std::string get_openssl_error_string() {
+  const unsigned long error_code = ERR_get_error();
+  if (error_code == 0) {
+    return "Unknown OpenSSL error";
+  }
+
+  // ERR_reason_error_string() returns NULL when no reason text is registered
+  // for the code; building a std::string from NULL is undefined behaviour.
+  const char* reason = ERR_reason_error_string(error_code);
+  if (reason != nullptr) {
+    return std::string(reason);
+  }
+
+  // Fall back to the generic "error:<code>:<lib>:<func>:<reason>" form, which
+  // ERR_error_string_n() always NUL-terminates within the given buffer.
+  char buffer[256];
+  ERR_error_string_n(error_code, buffer, sizeof(buffer));
+  return std::string(buffer);
+}
+
+}  // namespace gandiva
+
diff --git a/cpp/src/gandiva/encrypt_utils_common.h b/cpp/src/gandiva/encrypt_utils_common.h
new file mode 100644
index 000000000000..3887368b5a7a
--- /dev/null
+++ b/cpp/src/gandiva/encrypt_utils_common.h
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef GANDIVA_ENCRYPT_UTILS_COMMON_H
+#define GANDIVA_ENCRYPT_UTILS_COMMON_H
+
+#include <string>
+
+namespace gandiva {
+
+/// @brief Get a human-readable error string from OpenSSL's error queue.
+///
+/// Reads the queue through ERR_get_error(), which returns the earliest queued
+/// error of the calling thread and removes that entry from the queue.
+///
+/// @return A string describing that error, or "Unknown OpenSSL error"
+///         if no error is available.
+std::string get_openssl_error_string();
+
+}  // namespace gandiva
+
+#endif  // GANDIVA_ENCRYPT_UTILS_COMMON_H
+
diff --git a/cpp/src/gandiva/encrypt_utils_ecb.cc b/cpp/src/gandiva/encrypt_utils_ecb.cc
index 4d2a310121f8..5cc13f335a70 100644
--- a/cpp/src/gandiva/encrypt_utils_ecb.cc
+++ b/cpp/src/gandiva/encrypt_utils_ecb.cc
@@ -16,6 +16,7 @@
 // under the License.
#include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_common.h" #include #include #include @@ -26,14 +27,6 @@ namespace gandiva { namespace { -std::string get_openssl_error_string() { - unsigned long error_code = ERR_get_error(); - if (error_code == 0) { - return "Unknown OpenSSL error"; - } - return std::string(ERR_reason_error_string(error_code)); -} - const EVP_CIPHER* get_ecb_cipher_algo(int32_t key_length) { switch (key_length) { case 16: diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index fb662b623fbf..ceca40893730 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -513,6 +513,16 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_fn_aes_decrypt_ecb", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + // CBC mode specific functions + // Binary-based signatures (BINARY, BINARY, BINARY, UTF8, UTF8) -> BINARY + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_cbc", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), kResultNullIfNull, "gdv_mask_first_n_utf8_int32", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index d2610c4ee92c..402c1f2cca45 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -31,6 +31,7 @@ #include "arrow/util/value_parsing.h" #include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_cbc.h" #include "gandiva/engine.h" #include 
"gandiva/exported_funcs.h"
 #include "gandiva/in_holder.h"
@@ -584,6 +585,163 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int
   return result;
 }
 
+// CBC mode specific functions - core implementation with explicit padding
+
+// Shared implementation of gdv_fn_aes_encrypt_cbc / gdv_fn_aes_decrypt_cbc.
+//
+// Validates the arguments, allocates the output buffer from the context arena
+// and runs the AES-CBC transform selected by `encrypt`. On success returns the
+// arena buffer and stores its length in *out_len. On any failure records a
+// message on the execution context, sets *out_len to 0 and returns "", so
+// callers never receive a null pointer.
+static const char* gdv_fn_aes_cbc_impl(int64_t context, bool encrypt, const char* data,
+                                       int32_t data_len, const char* key_data,
+                                       int32_t key_data_len, const char* mode,
+                                       int32_t mode_len, const char* iv_data,
+                                       int32_t iv_data_len, const char* padding,
+                                       int32_t padding_len, int32_t* out_len) {
+  // AES block size in bytes: the required IV length and the largest number of
+  // bytes the cipher can emit beyond the input length.
+  constexpr int32_t kAesBlockSize = 16;
+  const std::string operation = encrypt ? "encryption" : "decryption";
+
+  const auto fail = [context, out_len](const std::string& message) -> const char* {
+    gdv_fn_context_set_error_msg(context, message.c_str());
+    *out_len = 0;
+    return "";
+  };
+
+  // Validate mode parameter. A negative length must be rejected before it
+  // reaches the std::string constructor, which would throw.
+  if (mode == nullptr || mode_len < 0) {
+    return fail("Invalid mode parameter for AES " + operation);
+  }
+
+  std::string mode_str(mode, mode_len);
+  // Convert to uppercase for comparison. The lambda takes unsigned char
+  // because ::toupper is undefined for negative char values.
+  std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(),
+                 [](unsigned char c) { return static_cast<char>(::toupper(c)); });
+
+  if (mode_str != "CBC") {
+    return fail("AES " + operation +
+                " mode mismatch: function signature indicates CBC mode, but '" +
+                mode_str + "' was provided instead");
+  }
+
+  if (data_len < 0) {
+    return fail("Invalid data length for AES " + operation + ": " +
+                std::to_string(data_len) + " (must be >= 0)");
+  }
+
+  if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) {
+    return fail("Invalid key length for AES " + operation + ": " +
+                std::to_string(key_data_len) +
+                " bytes. Supported lengths: 16, 24, 32 bytes");
+  }
+
+  if (iv_data_len != kAesBlockSize) {
+    return fail("Invalid IV length for AES-CBC: " + std::to_string(iv_data_len) +
+                " bytes. IV must be exactly 16 bytes");
+  }
+
+  // Guard the size computation below against signed overflow.
+  if (data_len > INT32_MAX - kAesBlockSize) {
+    return fail("Invalid data length for AES " + operation + ": " +
+                std::to_string(data_len) + " (input is too large)");
+  }
+
+  // One extra block covers PKCS7 padding on encryption (1 to 16 bytes) and
+  // the inl + block_size bound OpenSSL documents for EVP_DecryptUpdate. It
+  // also keeps the allocation non-empty when the input is empty.
+  const int32_t max_out_len = data_len + kAesBlockSize;
+  char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, max_out_len));
+  if (ret == nullptr) {
+    return fail("Could not allocate memory for output buffer");
+  }
+
+  try {
+    unsigned char* out = reinterpret_cast<unsigned char*>(ret);
+    *out_len = encrypt ? gandiva::aes_encrypt_cbc(data, data_len, key_data,
+                                                  key_data_len, iv_data, iv_data_len,
+                                                  padding, padding_len, out)
+                       : gandiva::aes_decrypt_cbc(data, data_len, key_data,
+                                                  key_data_len, iv_data, iv_data_len,
+                                                  padding, padding_len, out);
+  } catch (const std::exception& e) {
+    // No exception may escape into the generated code that calls this stub.
+    return fail(e.what());
+  }
+
+  return ret;
+}
+
+// AES-CBC encryption stub: (data, key, mode, iv, padding) -> ciphertext.
+// `mode` must be "CBC" (case-insensitive); see gdv_fn_aes_cbc_impl.
+GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len,
+                                   const char* key_data, int32_t key_data_len,
+                                   const char* mode, int32_t mode_len,
+                                   const char* iv_data, int32_t iv_data_len,
+                                   const char* padding, int32_t padding_len,
+                                   int32_t* out_len) {
+  return gdv_fn_aes_cbc_impl(context, /*encrypt=*/true, data, data_len, key_data,
+                             key_data_len, mode, mode_len, iv_data, iv_data_len,
+                             padding, padding_len, out_len);
+}
+
+// AES-CBC decryption stub: (data, key, mode, iv, padding) -> plaintext.
+// `mode` must be "CBC" (case-insensitive); see gdv_fn_aes_cbc_impl.
+GANDIVA_EXPORT
+const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len,
+                                   const char* key_data, int32_t key_data_len,
+                                   const char* mode, int32_t mode_len,
+                                   const char* iv_data, int32_t iv_data_len,
+                                   const char* padding, int32_t padding_len,
+                                   int32_t* out_len) {
+  return gdv_fn_aes_cbc_impl(context, /*encrypt=*/false, data, data_len, key_data,
+                             key_data_len, mode, mode_len, iv_data, iv_data_len,
+                             padding, padding_len, out_len);
+}
+
 GANDIVA_EXPORT
 const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data,
@@ -1293,6 +1451,50 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
                                   types->i8_ptr_type() /*return_type*/, args,
                                   reinterpret_cast<void*>(gdv_fn_aes_decrypt_ecb_legacy));
 
+  // gdv_fn_aes_encrypt_cbc and gdv_fn_aes_decrypt_cbc
+  // Note: Mode and IV parameters are passed as binary strings (data + length)
+  // Both functions share one signature: (context, data, data_len, key_data,
+  // key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len)
+  args = {
+      types->i64_type(),     // context
+      types->i8_ptr_type(),  // data
+      types->i32_type(),     // data_length
+      types->i8_ptr_type(),  // key_data
+      types->i32_type(),     // key_data_length
+      types->i8_ptr_type(),  // mode (binary string)
+      types->i32_type(),     // mode_length
+      types->i8_ptr_type(),  // iv (binary string)
+      types->i32_type(),     // iv_length
+      types->i8_ptr_type(),  // padding (binary string)
+      types->i32_type(),     // padding_length
+      types->i32_ptr_type()  // out_length
+  };
+
+  engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_cbc",
+                                  types->i8_ptr_type() /*return_type*/, args,
+                                  reinterpret_cast<void*>(gdv_fn_aes_encrypt_cbc));
+
+  engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_cbc",
+                                  types->i8_ptr_type() /*return_type*/, args,
+                                  reinterpret_cast<void*>(gdv_fn_aes_decrypt_cbc));
+
   // gdv_mask_first_n and gdv_mask_last_n
   std::vector<llvm::Type*> mask_args = {
       types->i64_type(),     // context
diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h
index 5c5452df2805..324ecfb2c03d 100644
--- a/cpp/src/gandiva/gdv_function_stubs.h
+++ b/cpp/src/gandiva/gdv_function_stubs.h
@@ -213,7 +213,22 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int
                                           const char* key_data, int32_t key_data_len,
                                           int32_t* out_len);
 
+// CBC mode specific functions
+GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len,
+                                   const char* key_data, int32_t 
key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, + const char* mode, int32_t mode_len, + const char* iv_data, int32_t iv_data_len, + const char* padding, int32_t padding_len, + int32_t* out_len); GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index da785d1a511d..55572880bedf 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1480,4 +1480,181 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { ctx.Reset(); } +// Tests for CBC mode encryption/decryption +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbc) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding.c_str(), padding_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_cbc( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + EXPECT_EQ(data, 
std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcNoPadding) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "1234567890123456"; // Exactly 16 bytes + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "NONE"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding.c_str(), padding_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_cbc( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, + iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); + EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcCaseInsensitive) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len1 = 0; + int32_t cipher_len2 = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding_upper = "PKCS7"; + auto padding_upper_len = static_cast(padding_upper.length()); + std::string padding_lower = "pkcs7"; + auto padding_lower_len = static_cast(padding_lower.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher1 = 
gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding_upper.c_str(), padding_upper_len, + &cipher_len1); + const char* cipher2 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), + key16_len, mode.c_str(), mode_len, iv.c_str(), + iv_len, padding_lower.c_str(), padding_lower_len, + &cipher_len2); + + // Both should produce same ciphertext + EXPECT_EQ(cipher_len1, cipher_len2); + EXPECT_EQ(std::string(cipher1, cipher_len1), std::string(cipher2, cipher_len2)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidIV) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "short"; // Too short + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid IV length")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidKey) { + gandiva::ExecutionContext ctx; + std::string key = "short"; // Too short + auto key_len = static_cast(key.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); 
+ + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key.c_str(), key_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid key length")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidPadding) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = "CBC"; + auto mode_len = static_cast(mode.length()); + std::string padding = "INVALID"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), + padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid padding mode")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcModeValidation) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int32_t cipher_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string invalid_mode = "ECB"; + auto invalid_mode_len = static_cast(invalid_mode.length()); + std::string padding = "PKCS7"; + auto padding_len = static_cast(padding.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + // Test encrypt with invalid mode + gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, + padding.c_str(), padding_len, &cipher_len); + EXPECT_THAT(ctx.get_error(), 
::testing::HasSubstr("AES encryption mode mismatch")); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("ECB")); + ctx.Reset(); +} + } // namespace gandiva From 00f53d47c7b225b57c6a2aee9567e9ca9ae4c1a4 Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Wed, 19 Nov 2025 17:50:55 +0000 Subject: [PATCH 2/3] Extract AES mode validation into ensure_mode() helper function and update macOS runner versions - Create ensure_mode() helper that throws std::runtime_error for invalid modes - Update all ECB and CBC AES functions (3 encrypt + 3 decrypt) to use ensure_mode() - Wrap all function bodies in try-catch to handle exceptions from ensure_mode() - Consistent error handling across ECB and CBC modes - Update CBC function signatures to include mode parameter: (data, key, mode, iv, padding) - Update function registry to reflect new CBC parameter order - Update LLVM mappings comments for clarity - Update test expectations to match new error messages - Update macOS runner versions from macos-13 to macos-15-intel in CI workflows - Excludes GCM mode changes from the original commits --- .github/workflows/csharp.yml | 4 +- .github/workflows/java.yml | 4 +- .github/workflows/js.yml | 4 +- cpp/src/gandiva/function_registry_string.cc | 7 +- cpp/src/gandiva/gdv_function_stubs.cc | 313 ++++++-------------- cpp/src/gandiva/gdv_function_stubs_test.cc | 4 +- dev/tasks/java-jars/github.yml | 4 +- dev/tasks/r/github.packages.yml | 6 +- dev/tasks/tasks.yml | 6 +- 9 files changed, 107 insertions(+), 245 deletions(-) diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index 5f657e6c1bf5..d479456d0d8e 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -94,8 +94,8 @@ jobs: run: ci/scripts/csharp_test.sh $(pwd) macos: - name: AMD64 macOS 13 C# ${{ matrix.dotnet }} - runs-on: macos-13 + name: AMD64 macOS 15 C# ${{ matrix.dotnet }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 15 strategy: diff 
--git a/.github/workflows/java.yml b/.github/workflows/java.yml index 5766c63bf522..6c0cf0991168 100644 --- a/.github/workflows/java.yml +++ b/.github/workflows/java.yml @@ -106,8 +106,8 @@ jobs: run: archery docker push ${{ matrix.image }} macos: - name: AMD64 macOS 13 Java JDK ${{ matrix.jdk }} - runs-on: macos-13 + name: AMD64 macOS 15 Java JDK ${{ matrix.jdk }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 30 strategy: diff --git a/.github/workflows/js.yml b/.github/workflows/js.yml index 031310fd4027..a51ad867aa70 100644 --- a/.github/workflows/js.yml +++ b/.github/workflows/js.yml @@ -81,8 +81,8 @@ jobs: run: archery docker push debian-js macos: - name: AMD64 macOS 13 NodeJS ${{ matrix.node }} - runs-on: macos-13 + name: AMD64 macOS 15 NodeJS ${{ matrix.node }} + runs-on: macos-15-intel if: ${{ !contains(github.event.pull_request.title, 'WIP') }} timeout-minutes: 30 strategy: diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index ceca40893730..054fb813a483 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -514,12 +514,13 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // CBC mode specific functions - // Binary-based signatures (BINARY, BINARY, BINARY, UTF8, UTF8) -> BINARY - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, UTF8) -> BINARY + // Parameters: data, key, mode, iv, padding + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), kResultNullIfNull, "gdv_fn_aes_encrypt_cbc", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), binary(), utf8(), utf8()}, binary(), + 
NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 402c1f2cca45..ef47fb68b54a 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -396,8 +396,25 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL +// Helper function to validate AES mode parameter +// Throws std::runtime_error if mode is invalid +static void ensure_mode(const char* mode, int32_t mode_len, + const std::string& expected_mode) { + if (mode == nullptr) { + throw std::runtime_error("Invalid mode parameter for AES encryption"); + } + std::string mode_str(mode, mode_len); + // Convert to uppercase for comparison + std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); + if (mode_str != expected_mode) { + std::ostringstream oss; + oss << "AES encryption mode mismatch: function signature indicates " << expected_mode + << " mode, but '" << mode_str << "' was provided instead"; + throw std::runtime_error(oss.str()); + } +} // ECB mode specific functions - core implementation // This handles both string and binary inputs (they have the same C signature) @@ -406,68 +423,40 @@ const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t da const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES encryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), 
mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "ECB") { - std::ostringstream oss; - oss << "AES encryption mode mismatch: function signature indicates ECB mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "ECB"); - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (data_len < 0) { + throw std::runtime_error( + std::string("Invalid data length for AES encryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - std::ostringstream oss; - oss << "Invalid key length for AES encryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + throw std::runtime_error( + std::string("Invalid key length for AES encryption: ") + std::to_string(key_data_len) + + " bytes. 
Supported lengths: 16, 24, 32 bytes"); + } - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES encryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES encryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { *out_len = gandiva::aes_encrypt_ecb(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } // Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 @@ -498,68 +487,40 @@ const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t da const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES decryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "ECB") { - std::ostringstream 
oss; - oss << "AES decryption mode mismatch: function signature indicates ECB mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "ECB"); - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + if (data_len < 0) { + throw std::runtime_error( + std::string("Invalid data length for AES decryption: ") + std::to_string(data_len) + + " (must be >= 0)"); + } - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - std::ostringstream oss; - oss << "Invalid key length for AES decryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { + throw std::runtime_error( + std::string("Invalid key length for AES decryption: ") + std::to_string(key_data_len) + + " bytes. 
Supported lengths: 16, 24, 32 bytes"); + } - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = - static_cast(arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - std::ostringstream oss; - oss << "Could not allocate memory for AES decryption output: " << *out_len << " bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return nullptr; - } + // AES block size is always 16 bytes (128 bits), regardless of key length + int64_t kAesBlockSize = 16; + *out_len = static_cast( + arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + throw std::runtime_error(std::string("Could not allocate memory for AES decryption output: ") + + std::to_string(*out_len) + " bytes"); + } - try { *out_len = gandiva::aes_decrypt_ecb(data, data_len, key_data, key_data_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } // Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 @@ -593,75 +554,25 @@ const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t da const char* iv_data, int32_t iv_data_len, const char* padding, int32_t padding_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES encryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "CBC") { - std::ostringstream 
oss; - oss << "AES encryption mode mismatch: function signature indicates CBC mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES encryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { - std::ostringstream oss; - oss << "Invalid key length for AES encryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (iv_data_len != 16) { - std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_data_len - << " bytes. IV must be exactly 16 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "CBC"); - // Allocate output buffer with padding overhead - int32_t max_out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len + 16), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); - if (ret == nullptr) { - gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); - *out_len = 0; - return nullptr; - } + // Allocate output buffer (max size: input + 16 bytes for padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len + 16)); + if (ret == nullptr) { + throw std::runtime_error("Could not allocate memory for AES-CBC encryption"); + } - try { - *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, - iv_data, iv_data_len, padding, padding_len, + *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, iv_data, iv_data_len, 
+ padding, padding_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } GANDIVA_EXPORT @@ -671,75 +582,25 @@ const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t da const char* iv_data, int32_t iv_data_len, const char* padding, int32_t padding_len, int32_t* out_len) { - // Validate mode parameter - if (mode == nullptr) { - std::ostringstream oss; - oss << "Invalid mode parameter for AES decryption"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != "CBC") { - std::ostringstream oss; - oss << "AES decryption mode mismatch: function signature indicates CBC mode, but '" - << mode_str << "' was provided instead"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (data_len < 0) { - std::ostringstream oss; - oss << "Invalid data length for AES decryption: " << data_len << " (must be >= 0)"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (key_data_len < 0 || (key_data_len != 16 && key_data_len != 24 && key_data_len != 32)) { - std::ostringstream oss; - oss << "Invalid key length for AES decryption: " << key_data_len - << " bytes. Supported lengths: 16, 24, 32 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } - - if (iv_data_len != 16) { - std::ostringstream oss; - oss << "Invalid IV length for AES-CBC: " << iv_data_len - << " bytes. 
IV must be exactly 16 bytes"; - gdv_fn_context_set_error_msg(context, oss.str().c_str()); - *out_len = 0; - return ""; - } + try { + // Validate mode parameter + ensure_mode(mode, mode_len, "CBC"); - // Allocate output buffer - int32_t max_out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), 16)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, max_out_len)); - if (ret == nullptr) { - gdv_fn_context_set_error_msg(context, "Could not allocate memory for output buffer"); - *out_len = 0; - return nullptr; - } + // Allocate output buffer (max size: input size, since decryption removes padding) + char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len)); + if (ret == nullptr) { + throw std::runtime_error("Could not allocate memory for AES-CBC decryption"); + } - try { - *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, - iv_data, iv_data_len, padding, padding_len, + *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, iv_data, iv_data_len, + padding, padding_len, reinterpret_cast(ret)); + return ret; } catch (const std::runtime_error& e) { gdv_fn_context_set_error_msg(context, e.what()); *out_len = 0; return nullptr; } - - return ret; } @@ -1452,7 +1313,7 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); // gdv_fn_aes_encrypt_cbc - // Note: Mode and IV parameters are passed as binary strings (data + length) + // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context @@ -1474,7 +1335,7 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { reinterpret_cast(gdv_fn_aes_encrypt_cbc)); // gdv_fn_aes_decrypt_cbc - // Note: Mode and IV parameters are passed as 
binary strings (data + length) + // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) args = { types->i64_type(), // context diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 55572880bedf..f70bbb9d26df 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1474,7 +1474,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, invalid_mode.c_str(), invalid_mode_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("AES decryption mode mismatch")); + ::testing::HasSubstr("AES encryption mode mismatch")); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("CBC")); ctx.Reset(); @@ -1607,7 +1607,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidKey) { gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key.c_str(), key_len, mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid key length")); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length")); ctx.Reset(); } diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index f7dd177e8757..ff1834e63b91 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -91,7 +91,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: ["macos-13"], arch: "x86_64"} + - { runs_on: ["macos-15-intel"], arch: "x86_64"} env: MACOSX_DEPLOYMENT_TARGET: "12.0" steps: @@ -190,7 +190,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: ["macos-13"], arch: "x86_64"} + - { runs_on: ["macos-15-intel"], arch: "x86_64"} needs: - build-cpp-ubuntu - build-cpp-macos 
diff --git a/dev/tasks/r/github.packages.yml b/dev/tasks/r/github.packages.yml index 839e3d534107..a0005e19035d 100644 --- a/dev/tasks/r/github.packages.yml +++ b/dev/tasks/r/github.packages.yml @@ -66,7 +66,7 @@ jobs: fail-fast: false matrix: platform: - - { runs_on: macos-13, arch: "x86_64" } + - { runs_on: macos-15-intel, arch: "x86_64" } - { runs_on: macos-14, arch: "arm64" } openssl: ['3.0', '1.1'] @@ -208,7 +208,7 @@ jobs: matrix: platform: - { runs_on: 'windows-latest', name: "Windows"} - - { runs_on: macos-13, name: "macOS x86_64"} + - { runs_on: macos-15-intel, name: "macOS x86_64"} - { runs_on: macos-14, name: "macOS arm64" } r_version: [oldrel, release] steps: @@ -389,7 +389,7 @@ jobs: matrix: platform: - {runs_on: "ubuntu-latest", name: "Linux"} - - {runs_on: "macos-13" , name: "macOS"} + - {runs_on: "macos-15-intel" , name: "macOS"} steps: - name: Install R uses: r-lib/actions/setup-r@v2 diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 91e1c07e1fc2..5b889678765c 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -425,7 +425,7 @@ tasks: python_version: "{{ python_version }}" python_abi_tag: "{{ abi_tag }}" macos_deployment_target: "12.0" - runs_on: "macos-13" + runs_on: "macos-15-intel" vcpkg_arch: "amd64" artifacts: - pyarrow-{no_rc_version}-{{ python_tag }}-{{ abi_tag }}-macosx_12_0_x86_64.whl @@ -967,7 +967,7 @@ tasks: params: target: {{ target }} use_conda: True - github_runner: "macos-13" + github_runner: "macos-15-intel" {% endfor %} {% for target in ["cpp", @@ -982,7 +982,7 @@ tasks: template: verify-rc/github.macos.yml params: target: {{ target }} - github_runner: "macos-13" + github_runner: "macos-15-intel" {% endfor %} {% for target in ["cpp", From 68e180d5335537da62b52d2c40338a0414d24f66 Mon Sep 17 00:00:00 2001 From: Tim Hurski Date: Tue, 25 Nov 2025 05:16:19 +0000 Subject: [PATCH 3/3] Add AES dispatcher function signatures for new VARBINARY mode --- cpp/src/gandiva/CMakeLists.txt | 2 + 
cpp/src/gandiva/encrypt_mode_dispatcher.cc | 81 +++ cpp/src/gandiva/encrypt_mode_dispatcher.h | 83 +++ cpp/src/gandiva/encrypt_utils_cbc.cc | 36 +- cpp/src/gandiva/encrypt_utils_cbc.h | 10 +- cpp/src/gandiva/encrypt_utils_cbc_test.cc | 76 +-- cpp/src/gandiva/encrypt_utils_common.cc | 19 +- cpp/src/gandiva/encrypt_utils_common.h | 4 +- cpp/src/gandiva/encrypt_utils_common_test.cc | 95 ++++ cpp/src/gandiva/function_registry_string.cc | 27 +- cpp/src/gandiva/gdv_function_stubs.cc | 505 ++++++++++--------- cpp/src/gandiva/gdv_function_stubs.h | 77 +-- cpp/src/gandiva/gdv_function_stubs_test.cc | 271 +++------- 13 files changed, 689 insertions(+), 597 deletions(-) create mode 100644 cpp/src/gandiva/encrypt_mode_dispatcher.cc create mode 100644 cpp/src/gandiva/encrypt_mode_dispatcher.h create mode 100644 cpp/src/gandiva/encrypt_utils_common_test.cc diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 8d198416d907..2b5fe0d9a6f6 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -59,6 +59,7 @@ set(SRC_FILES encrypt_utils_common.cc encrypt_utils_ecb.cc encrypt_utils_cbc.cc + encrypt_mode_dispatcher.cc expr_decomposer.cc expr_validator.cc expression.cc @@ -260,6 +261,7 @@ add_gandiva_test(internals-test tree_expr_test.cc encrypt_utils_ecb_test.cc encrypt_utils_cbc_test.cc + encrypt_utils_common_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc expression_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.cc b/cpp/src/gandiva/encrypt_mode_dispatcher.cc new file mode 100644 index 000000000000..dc93779ed088 --- /dev/null +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.cc @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_mode_dispatcher.h" +#include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_cbc.h" +#include "arrow/util/string.h" +#include +#include +#include + +namespace gandiva { + +int32_t EncryptModeDispatcher::encrypt( + const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len, + unsigned char* cipher) { + std::string mode_str = + arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); + + if (mode_str == "AES-ECB") { + return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher); + } else if (mode_str == "AES-CBC-PKCS7") { + return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, + iv, iv_len, true, cipher); + } else if (mode_str == "AES-CBC-NONE") { + return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, + iv, iv_len, false, cipher); + } else if (mode_str == "AES-GCM") { + throw std::runtime_error("AES-GCM encryption mode is not yet implemented"); + } else { + std::ostringstream oss; + oss << "Unsupported encryption mode: " << mode_str + << ". 
Supported modes: AES-ECB, AES-CBC-PKCS7, AES-CBC-NONE"; + throw std::runtime_error(oss.str()); + } +} + +int32_t EncryptModeDispatcher::decrypt( + const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, const char* mode, int32_t mode_len, const char* iv, + int32_t iv_len, const char* fifth_argument, int32_t fifth_argument_len, + unsigned char* plaintext) { + std::string mode_str = + arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); + + if (mode_str == "AES-ECB") { + return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext); + } else if (mode_str == "AES-CBC-PKCS7") { + return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, true, plaintext); + } else if (mode_str == "AES-CBC-NONE") { + return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, false, plaintext); + } else if (mode_str == "AES-GCM") { + throw std::runtime_error("AES-GCM decryption mode is not yet implemented"); + } else { + std::ostringstream oss; + oss << "Unsupported decryption mode: " << mode_str + << ". Supported modes: AES-ECB, AES-CBC-PKCS7, AES-CBC-NONE"; + throw std::runtime_error(oss.str()); + } +} + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.h b/cpp/src/gandiva/encrypt_mode_dispatcher.h new file mode 100644 index 000000000000..20326845bd02 --- /dev/null +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.h @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef GANDIVA_ENCRYPT_MODE_DISPATCHER_H +#define GANDIVA_ENCRYPT_MODE_DISPATCHER_H + +#include + +namespace gandiva { + +/** + * Dispatcher for AES encryption/decryption based on mode string. + * Routes calls to appropriate implementation. + */ +class EncryptModeDispatcher { + public: + /** + * Encrypt data using the specified mode + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key + * @param key_len Length of key in bytes + * @param mode Mode string + * @param mode_len Length of mode string in bytes + * @param iv The initialization vector (optional, only for modes that support it) + * @param iv_len Length of the IV in bytes + * @param fifth_argument Additional parameter (optional, only for modes that support it) + * @param fifth_argument_len Length of fifth_argument in bytes + * @param cipher Output buffer for encrypted data + * @return Length of encrypted data in bytes + * @throws std::runtime_error on encryption failure or unsupported mode + */ + static int32_t encrypt(const char* plaintext, int32_t plaintext_len, + const char* key, int32_t key_len, + const char* mode, int32_t mode_len, + const char* iv, int32_t iv_len, + const char* fifth_argument, int32_t fifth_argument_len, + unsigned char* cipher); + + /** + * Decrypt data using the specified mode + * + * @param ciphertext The data to decrypt + * @param ciphertext_len Length of ciphertext in bytes + * @param key The decryption key + * @param key_len Length of key in bytes + * @param mode Mode string + * 
@param mode_len Length of mode string in bytes + * @param iv The initialization vector (optional, only for modes that support it) + * @param iv_len Length of the IV in bytes + * @param fifth_argument Additional parameter (optional, only for modes that support it) + * @param fifth_argument_len Length of fifth_argument in bytes + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes + * @throws std::runtime_error on decryption failure or unsupported mode + */ + static int32_t decrypt(const char* ciphertext, int32_t ciphertext_len, + const char* key, int32_t key_len, + const char* mode, int32_t mode_len, + const char* iv, int32_t iv_len, + const char* fifth_argument, int32_t fifth_argument_len, + unsigned char* plaintext); +}; + +} // namespace gandiva + +#endif // GANDIVA_ENCRYPT_MODE_DISPATCHER_H + diff --git a/cpp/src/gandiva/encrypt_utils_cbc.cc b/cpp/src/gandiva/encrypt_utils_cbc.cc index 15ac0413a75c..04eb60c96a77 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.cc +++ b/cpp/src/gandiva/encrypt_utils_cbc.cc @@ -28,12 +28,6 @@ namespace gandiva { namespace { -// Padding mode enum -enum class PaddingMode { - PKCS7, - NONE -}; - const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) { switch (key_length) { case 16: @@ -51,30 +45,12 @@ const EVP_CIPHER* get_cbc_cipher_algo(int32_t key_length) { } } -PaddingMode get_padding_mode(const char* padding_str, int32_t padding_len) { - if (padding_str == nullptr || padding_len <= 0) { - throw std::runtime_error("Invalid padding parameter: null or empty"); - } - - // Case-insensitive comparison using strncasecmp - if (strncasecmp(padding_str, "PKCS7", padding_len) == 0 && padding_len == 5) { - return PaddingMode::PKCS7; - } else if (strncasecmp(padding_str, "NONE", padding_len) == 0 && padding_len == 4) { - return PaddingMode::NONE; - } else { - std::ostringstream oss; - oss << "Invalid padding mode: '" << std::string(padding_str, padding_len) - << "'. 
Supported modes: PKCS7, NONE (case-insensitive)"; - throw std::runtime_error(oss.str()); - } -} - } // namespace GANDIVA_EXPORT int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, - const char* padding, int32_t padding_len, unsigned char* cipher) { + bool use_padding, unsigned char* cipher) { // Validate IV length if (iv_len != 16) { std::ostringstream oss; @@ -83,8 +59,6 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char throw std::runtime_error(oss.str()); } - PaddingMode padding_mode = get_padding_mode(padding, padding_len); - int32_t cipher_len = 0; int32_t len = 0; EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); @@ -103,7 +77,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char get_openssl_error_string()); } - int padding_flag = (padding_mode == PaddingMode::PKCS7) ? 1 : 0; + int padding_flag = use_padding ? 1 : 0; if (!EVP_CIPHER_CTX_set_padding(en_ctx, padding_flag)) { EVP_CIPHER_CTX_free(en_ctx); throw std::runtime_error("Could not set padding mode for encryption: " + @@ -135,7 +109,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char GANDIVA_EXPORT int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, - const char* padding, int32_t padding_len, unsigned char* plaintext) { + bool use_padding, unsigned char* plaintext) { // Validate IV length if (iv_len != 16) { std::ostringstream oss; @@ -144,8 +118,6 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch throw std::runtime_error(oss.str()); } - PaddingMode padding_mode = get_padding_mode(padding, padding_len); - int32_t plaintext_len = 0; int32_t len = 0; EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); @@ -164,7 +136,7 @@ int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const ch 
get_openssl_error_string()); } - int padding_flag = (padding_mode == PaddingMode::PKCS7) ? 1 : 0; + int padding_flag = use_padding ? 1 : 0; if (!EVP_CIPHER_CTX_set_padding(de_ctx, padding_flag)) { EVP_CIPHER_CTX_free(de_ctx); throw std::runtime_error("Could not set padding mode for decryption: " + diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h index 9b5bcaa6ab23..5cec4eb39c30 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.h +++ b/cpp/src/gandiva/encrypt_utils_cbc.h @@ -32,8 +32,7 @@ namespace gandiva { * @param key_len Length of key in bytes * @param iv The initialization vector (must be exactly 16 bytes) * @param iv_len Length of IV in bytes (must be 16) - * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive) - * @param padding_len Length of padding string in bytes + * @param use_padding Whether to use PKCS7 padding (true) or no padding (false) * @param cipher Output buffer for encrypted data * @return Length of encrypted data in bytes * @throws std::runtime_error on encryption failure or invalid parameters @@ -41,7 +40,7 @@ namespace gandiva { GANDIVA_EXPORT int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, - const char* padding, int32_t padding_len, unsigned char* cipher); + bool use_padding, unsigned char* cipher); /** * Decrypt data using AES-CBC algorithm with explicit padding mode @@ -52,8 +51,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char * @param key_len Length of key in bytes * @param iv The initialization vector (must be exactly 16 bytes) * @param iv_len Length of IV in bytes (must be 16) - * @param padding Padding mode string: "PKCS7" or "NONE" (case-insensitive) - * @param padding_len Length of padding string in bytes + * @param use_padding Whether to use PKCS7 padding (true) or no padding (false) * @param plaintext Output buffer for decrypted data * @return Length of 
decrypted data in bytes * @throws std::runtime_error on decryption failure or invalid parameters @@ -61,7 +59,7 @@ int32_t aes_encrypt_cbc(const char* plaintext, int32_t plaintext_len, const char GANDIVA_EXPORT int32_t aes_decrypt_cbc(const char* ciphertext, int32_t ciphertext_len, const char* key, int32_t key_len, const char* iv, int32_t iv_len, - const char* padding, int32_t padding_len, unsigned char* plaintext); + bool use_padding, unsigned char* plaintext); } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils_cbc_test.cc b/cpp/src/gandiva/encrypt_utils_cbc_test.cc index 88e5403d1206..8bf9227d65b4 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc_test.cc +++ b/cpp/src/gandiva/encrypt_utils_cbc_test.cc @@ -18,6 +18,7 @@ #include "gandiva/encrypt_utils_cbc.h" #include +#include #include // Test PKCS#7 padding with 16-byte key @@ -32,12 +33,12 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_16) { unsigned char cipher[64]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher); + iv, iv_len, true, cipher); unsigned char decrypted[64]; int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, - "PKCS7", 5, decrypted); + true, decrypted); EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), std::string(reinterpret_cast(decrypted), decrypted_len)); @@ -55,12 +56,12 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_24) { unsigned char cipher[64]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher); + iv, iv_len, true, cipher); unsigned char decrypted[64]; int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, - "PKCS7", 5, decrypted); + true, decrypted); EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), std::string(reinterpret_cast(decrypted), decrypted_len)); @@ -78,12 +79,12 @@ 
TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptPkcs7_32) { unsigned char cipher[64]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher); + iv, iv_len, true, cipher); unsigned char decrypted[64]; int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, - "PKCS7", 5, decrypted); + true, decrypted); EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), std::string(reinterpret_cast(decrypted), decrypted_len)); @@ -101,43 +102,17 @@ TEST(TestAesCbcEncryptUtils, TestAesEncryptDecryptNoPadding_16) { unsigned char cipher[64]; int32_t cipher_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "NONE", 4, cipher); + iv, iv_len, false, cipher); unsigned char decrypted[64]; int32_t decrypted_len = gandiva::aes_decrypt_cbc(reinterpret_cast(cipher), cipher_len, key, key_len, iv, iv_len, - "NONE", 4, decrypted); + false, decrypted); EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), std::string(reinterpret_cast(decrypted), decrypted_len)); } -// Test case-insensitive padding mode -TEST(TestAesCbcEncryptUtils, TestCaseInsensitivePadding) { - auto* key = "12345678abcdefgh"; - auto* iv = "1234567890123456"; - auto* to_encrypt = "test"; - - auto key_len = static_cast(strlen(key)); - auto iv_len = static_cast(strlen(iv)); - auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher1[64]; - unsigned char cipher2[64]; - - // Test with "pkcs7" (lowercase) - int32_t cipher1_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "pkcs7", 5, cipher1); - - // Test with "PKCS7" (uppercase) - int32_t cipher2_len = gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher2); - - // Both should produce same ciphertext - EXPECT_EQ(cipher1_len, cipher2_len); - EXPECT_EQ(std::string(reinterpret_cast(cipher1), cipher1_len), - 
std::string(reinterpret_cast(cipher2), cipher2_len)); -} - // Test invalid IV length TEST(TestAesCbcEncryptUtils, TestInvalidIVLength) { auto* key = "12345678abcdefgh"; @@ -149,10 +124,13 @@ TEST(TestAesCbcEncryptUtils, TestInvalidIVLength) { auto to_encrypt_len = static_cast(strlen(to_encrypt)); unsigned char cipher[64]; - ASSERT_THROW({ + try { gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher); - }, std::runtime_error); + iv, iv_len, true, cipher); + FAIL() << "Expected std::runtime_error"; + } catch (const std::runtime_error& e) { + EXPECT_THAT(e.what(), testing::HasSubstr("Invalid IV length for AES-CBC")); + } } // Test invalid key length @@ -166,26 +144,14 @@ TEST(TestAesCbcEncryptUtils, TestInvalidKeyLength) { auto to_encrypt_len = static_cast(strlen(to_encrypt)); unsigned char cipher[64]; - ASSERT_THROW({ + try { gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "PKCS7", 5, cipher); - }, std::runtime_error); + iv, iv_len, true, cipher); + FAIL() << "Expected std::runtime_error"; + } catch (const std::runtime_error& e) { + EXPECT_THAT(e.what(), testing::HasSubstr("Unsupported key length for AES-CBC")); + } } -// Test invalid padding mode -TEST(TestAesCbcEncryptUtils, TestInvalidPaddingMode) { - auto* key = "12345678abcdefgh"; - auto* iv = "1234567890123456"; - auto* to_encrypt = "test"; - auto key_len = static_cast(strlen(key)); - auto iv_len = static_cast(strlen(iv)); - auto to_encrypt_len = static_cast(strlen(to_encrypt)); - unsigned char cipher[64]; - - ASSERT_THROW({ - gandiva::aes_encrypt_cbc(to_encrypt, to_encrypt_len, key, key_len, - iv, iv_len, "INVALID", 7, cipher); - }, std::runtime_error); -} diff --git a/cpp/src/gandiva/encrypt_utils_common.cc b/cpp/src/gandiva/encrypt_utils_common.cc index b1b7d1729495..3213e0c6e1a1 100644 --- a/cpp/src/gandiva/encrypt_utils_common.cc +++ b/cpp/src/gandiva/encrypt_utils_common.cc @@ -18,15 +18,28 @@ #include 
"gandiva/encrypt_utils_common.h" #include #include +#include namespace gandiva { std::string get_openssl_error_string() { - unsigned long error_code = ERR_get_error(); - if (error_code == 0) { + std::string error_string; + unsigned long error_code; + char error_buffer[256]; + + // Loop through all errors in the queue + while ((error_code = ERR_get_error()) != 0) { + if (!error_string.empty()) { + error_string += "; "; + } + ERR_error_string(error_code, error_buffer); + error_string += std::string(error_buffer); + } + + if (error_string.empty()) { return "Unknown OpenSSL error"; } - return std::string(ERR_reason_error_string(error_code)); + return error_string; } } // namespace gandiva diff --git a/cpp/src/gandiva/encrypt_utils_common.h b/cpp/src/gandiva/encrypt_utils_common.h index 3887368b5a7a..62dc14db348e 100644 --- a/cpp/src/gandiva/encrypt_utils_common.h +++ b/cpp/src/gandiva/encrypt_utils_common.h @@ -23,7 +23,9 @@ namespace gandiva { /// @brief Get a human-readable error string from OpenSSL's error queue. -/// @return A string describing the most recent OpenSSL error, or "Unknown OpenSSL error" +/// @details Retrieves all errors from the OpenSSL error queue and concatenates them +/// with "; " as a separator. This ensures complete error information is captured. +/// @return A string describing all OpenSSL errors in the queue, or "Unknown OpenSSL error" /// if no error is available. std::string get_openssl_error_string(); diff --git a/cpp/src/gandiva/encrypt_utils_common_test.cc b/cpp/src/gandiva/encrypt_utils_common_test.cc new file mode 100644 index 000000000000..de55758d5377 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_common_test.cc @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_common.h" + +#include +#include +#include +#include + +// Test that get_openssl_error_string returns "Unknown OpenSSL error" when queue is empty +TEST(TestOpenSSLErrorUtils, TestEmptyErrorQueue) { + // Clear any existing errors + ERR_clear_error(); + + std::string error_string = gandiva::get_openssl_error_string(); + + EXPECT_EQ(error_string, "Unknown OpenSSL error"); +} + +// Test that get_openssl_error_string captures a single error +TEST(TestOpenSSLErrorUtils, TestSingleError) { + // Clear any existing errors + ERR_clear_error(); + + // Add a single error to the queue + ERR_raise(ERR_LIB_EVP, EVP_R_UNSUPPORTED_ALGORITHM); + + std::string error_string = gandiva::get_openssl_error_string(); + + // Verify that the error string is not empty and not the default message + EXPECT_NE(error_string, "Unknown OpenSSL error"); + EXPECT_GT(error_string.length(), 0); +} + +// Test that get_openssl_error_string captures multiple errors +TEST(TestOpenSSLErrorUtils, TestMultipleErrors) { + // Clear any existing errors + ERR_clear_error(); + + // Populate the OpenSSL error queue with multiple errors + ERR_raise(ERR_LIB_EVP, EVP_R_UNSUPPORTED_ALGORITHM); + ERR_raise(ERR_LIB_EVP, EVP_R_INVALID_KEY_LENGTH); + ERR_raise(ERR_LIB_EVP, EVP_R_INVALID_OPERATION); + + // Call our function to get all errors + std::string error_string = gandiva::get_openssl_error_string(); + + 
// Verify that the error string is not empty + EXPECT_NE(error_string, "Unknown OpenSSL error"); + + // Verify that all errors are captured (they should be separated by "; ") + // The exact error messages depend on OpenSSL version, so we just check + // that we got multiple errors (indicated by the separator) + EXPECT_THAT(error_string, testing::HasSubstr(";")); + + // Verify the error string contains meaningful content (not just separators) + EXPECT_GT(error_string.length(), 10); +} + +// Test that error queue is properly drained after calling get_openssl_error_string +TEST(TestOpenSSLErrorUtils, TestErrorQueueDrained) { + // Clear any existing errors + ERR_clear_error(); + + // Add errors to the queue + ERR_raise(ERR_LIB_EVP, EVP_R_UNSUPPORTED_ALGORITHM); + ERR_raise(ERR_LIB_EVP, EVP_R_INVALID_KEY_LENGTH); + + // Call our function to get all errors + std::string error_string = gandiva::get_openssl_error_string(); + + // Verify we got errors + EXPECT_NE(error_string, "Unknown OpenSSL error"); + + // Now call it again - the queue should be empty + std::string second_call = gandiva::get_openssl_error_string(); + + EXPECT_EQ(second_call, "Unknown OpenSSL error"); +} + diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 054fb813a483..721f13d90eb0 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -504,24 +504,31 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_fn_aes_decrypt_ecb_legacy", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - // Binary-based signatures (BINARY, BINARY, UTF8) -> BINARY + // Parameters: data, key, mode (e.g. 
ECB mode) NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_encrypt_ecb", + kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_decrypt_ecb", + kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - // CBC mode specific functions - // Binary-based signatures (BINARY, BINARY, UTF8, BINARY, UTF8) -> BINARY - // Parameters: data, key, mode, iv, padding - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_encrypt_cbc", + // Parameters: data, key, mode, iv (e.g. CBC mode) + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_4args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_decrypt_cbc", + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_4args", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + // Parameters: data, key, mode, iv, fifth_argument (e.g. 
GCM mode) + NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_5args", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_5args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index ef47fb68b54a..a4c98373e728 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -32,6 +32,7 @@ #include "gandiva/encrypt_utils_ecb.h" #include "gandiva/encrypt_utils_cbc.h" +#include "gandiva/encrypt_mode_dispatcher.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" #include "gandiva/in_holder.h" @@ -396,212 +397,7 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) #undef GDV_FN_CAST_VARCHAR_INTEGER #undef GDV_FN_CAST_VARCHAR_REAL -// Helper function to validate AES mode parameter -// Throws std::runtime_error if mode is invalid -static void ensure_mode(const char* mode, int32_t mode_len, - const std::string& expected_mode) { - if (mode == nullptr) { - throw std::runtime_error("Invalid mode parameter for AES encryption"); - } - - std::string mode_str(mode, mode_len); - // Convert to uppercase for comparison - std::transform(mode_str.begin(), mode_str.end(), mode_str.begin(), ::toupper); - - if (mode_str != expected_mode) { - std::ostringstream oss; - oss << "AES encryption mode mismatch: function signature indicates " << expected_mode - << " mode, but '" << mode_str << "' was provided instead"; - throw std::runtime_error(oss.str()); - } -} - -// ECB mode specific functions - core implementation -// This handles both string and binary inputs 
(they have the same C signature) -GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - int32_t* out_len) { - try { - // Validate mode parameter - ensure_mode(mode, mode_len, "ECB"); - - if (data_len < 0) { - throw std::runtime_error( - std::string("Invalid data length for AES encryption: ") + std::to_string(data_len) + - " (must be >= 0)"); - } - - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - throw std::runtime_error( - std::string("Invalid key length for AES encryption: ") + std::to_string(key_data_len) + - " bytes. Supported lengths: 16, 24, 32 bytes"); - } - - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - throw std::runtime_error(std::string("Could not allocate memory for AES encryption output: ") + - std::to_string(*out_len) + " bytes"); - } - - *out_len = gandiva::aes_encrypt_ecb(data, data_len, key_data, key_data_len, - reinterpret_cast(ret)); - return ret; - } catch (const std::runtime_error& e) { - gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; - return nullptr; - } -} - -// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 -// This is called by the LLVM engine with string calling convention -// WARNING: This function is for backward compatibility only. Encrypted binary data -// is not guaranteed to be valid UTF-8. Use binary signatures for new code. 
-GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len) { - // Delegate to the core implementation with ECB mode - const char* mode = "ECB"; - int32_t mode_len = 3; - const char* result = gdv_fn_aes_encrypt_ecb(context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); - - // Add null terminator for string compatibility - // Note: This may not be valid UTF-8, but it's needed for string handling - if (result != nullptr) { - char* mutable_result = const_cast(result); - mutable_result[*out_len] = '\0'; - } - - return result; -} - -GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - int32_t* out_len) { - try { - // Validate mode parameter - ensure_mode(mode, mode_len, "ECB"); - - if (data_len < 0) { - throw std::runtime_error( - std::string("Invalid data length for AES decryption: ") + std::to_string(data_len) + - " (must be >= 0)"); - } - - if (key_data_len != 16 && key_data_len != 24 && key_data_len != 32) { - throw std::runtime_error( - std::string("Invalid key length for AES decryption: ") + std::to_string(key_data_len) + - " bytes. 
Supported lengths: 16, 24, 32 bytes"); - } - - // AES block size is always 16 bytes (128 bits), regardless of key length - int64_t kAesBlockSize = 16; - *out_len = static_cast( - arrow::bit_util::RoundUpToPowerOf2(static_cast(data_len), kAesBlockSize)); - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); - if (ret == nullptr) { - throw std::runtime_error(std::string("Could not allocate memory for AES decryption output: ") + - std::to_string(*out_len) + " bytes"); - } - - *out_len = gandiva::aes_decrypt_ecb(data, data_len, key_data, key_data_len, - reinterpret_cast(ret)); - return ret; - } catch (const std::runtime_error& e) { - gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; - return nullptr; - } -} -// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 -// This is called by the LLVM engine with string calling convention -// WARNING: This function is for backward compatibility only. Decrypted data may not be -// valid UTF-8. Use binary signatures for new code. 
-GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len) { - // Delegate to the core implementation with ECB mode - const char* mode = "ECB"; - int32_t mode_len = 3; - const char* result = gdv_fn_aes_decrypt_ecb(context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); - - // Add null terminator for string compatibility - // Note: This may not be valid UTF-8, but it's needed for string handling - if (result != nullptr) { - char* mutable_result = const_cast(result); - mutable_result[*out_len] = '\0'; - } - - return result; -} - -// CBC mode specific functions - core implementation with explicit padding -GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len) { - try { - // Validate mode parameter - ensure_mode(mode, mode_len, "CBC"); - - // Allocate output buffer (max size: input + 16 bytes for padding) - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len + 16)); - if (ret == nullptr) { - throw std::runtime_error("Could not allocate memory for AES-CBC encryption"); - } - - *out_len = gandiva::aes_encrypt_cbc(data, data_len, key_data, key_data_len, iv_data, iv_data_len, - padding, padding_len, - reinterpret_cast(ret)); - return ret; - } catch (const std::runtime_error& e) { - gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; - return nullptr; - } -} - -GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len) { - try { - // 
Validate mode parameter - ensure_mode(mode, mode_len, "CBC"); - - // Allocate output buffer (max size: input size, since decryption removes padding) - char* ret = reinterpret_cast(gdv_fn_context_arena_malloc(context, data_len)); - if (ret == nullptr) { - throw std::runtime_error("Could not allocate memory for AES-CBC decryption"); - } - - *out_len = gandiva::aes_decrypt_cbc(data, data_len, key_data, key_data_len, iv_data, iv_data_len, - padding, padding_len, - reinterpret_cast(ret)); - return ret; - } catch (const std::runtime_error& e) { - gdv_fn_context_set_error_msg(context, e.what()); - *out_len = 0; - return nullptr; - } -} GANDIVA_EXPORT @@ -1058,6 +854,158 @@ const char* gdv_mask_show_last_n_utf8_int32(int64_t context, const char* data, namespace gandiva { +// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8 +// This is called by the LLVM engine with string calling convention +// WARNING: This function is for backward compatibility only. Encrypted binary +// data is not guaranteed to be valid UTF-8. Use binary signatures for new code. 
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data,
+                                          int32_t data_len,
+                                          const char* key_data,
+                                          int32_t key_data_len,
+                                          int32_t* out_len) {
+  // Delegate to the core implementation with ECB mode
+  // This function is ECB-only, so we enforce the mode
+  const char* mode = "AES-ECB";
+  int32_t mode_len = 7;
+  const char* result = gdv_fn_aes_encrypt_dispatcher_3args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, out_len);
+
+  // Add null terminator for string compatibility
+  // Note: This may not be valid UTF-8, but it's needed for string handling
+  // The terminator is written at index *out_len, i.e. one byte past the
+  // ciphertext. gdv_fn_aes_encrypt_dispatcher_5args reserves that byte when
+  // it sizes the output buffer.
+  if (result != nullptr) {
+    char* mutable_result = const_cast<char*>(result);
+    mutable_result[*out_len] = '\0';
+  }
+
+  return result;
+}
+
+// Legacy wrapper for string-based signatures (UTF8, UTF8) -> UTF8
+// This is called by the LLVM engine with string calling convention
+// WARNING: This function is for backward compatibility only. Decrypted data
+// may not be valid UTF-8. Use binary signatures for new code.
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data,
+                                          int32_t data_len,
+                                          const char* key_data,
+                                          int32_t key_data_len,
+                                          int32_t* out_len) {
+  // Delegate to the core implementation with ECB mode
+  // This function is ECB-only, so we enforce the mode
+  const char* mode = "AES-ECB";
+  int32_t mode_len = 7;
+  const char* result = gdv_fn_aes_decrypt_dispatcher_3args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, out_len);
+
+  // Add null terminator for string compatibility
+  // Note: This may not be valid UTF-8, but it's needed for string handling
+  // NOTE(review): the terminator is written at index *out_len, so the buffer
+  // returned by the decrypt stub must hold at least *out_len + 1 bytes. That
+  // holds while the plaintext is shorter than the ciphertext (padding
+  // removed) - confirm aes_decrypt_ecb always strips padding.
+  if (result != nullptr) {
+    char* mutable_result = const_cast<char*>(result);
+    mutable_result[*out_len] = '\0';
+  }
+
+  return result;
+}
+
+// The 3- and 4-arg signatures exist to support optional IV and other arguments
+// Each one forwards to the 5-arg stub, passing nullptr / 0 for the arguments
+// it does not take.
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_dispatcher_3args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    int32_t* out_len) {
+  return gdv_fn_aes_encrypt_dispatcher_5args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr,
+      0, nullptr, 0, out_len);
+}
+
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_decrypt_dispatcher_3args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    int32_t* out_len) {
+  return gdv_fn_aes_decrypt_dispatcher_5args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr,
+      0, nullptr, 0, out_len);
+}
+
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_dispatcher_4args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    const char* iv_data, int32_t iv_data_len, int32_t* out_len) {
+  return gdv_fn_aes_encrypt_dispatcher_5args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data,
+      iv_data_len, nullptr, 0, out_len);
+}
+
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_decrypt_dispatcher_4args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    const char* iv_data, int32_t iv_data_len, int32_t* out_len) {
+  return gdv_fn_aes_decrypt_dispatcher_5args(
+      context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data,
+      iv_data_len, nullptr, 0, out_len);
+}
+
+// Encrypts `data` with the mode named by `mode` and returns a buffer taken
+// from the context arena, storing the ciphertext length in *out_len.
+// On failure the message is recorded on the context, *out_len is set to 0
+// and nullptr is returned.
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_encrypt_dispatcher_5args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    const char* iv_data, int32_t iv_data_len, const char* fifth_argument,
+    int32_t fifth_argument_len, int32_t* out_len) {
+  try {
+    // PKCS7 padding can add up to one full AES block (16 bytes); for
+    // block-aligned input the ciphertext is exactly data_len + 16 bytes.
+    // One further byte is reserved so the legacy string wrapper can append
+    // its NUL terminator at index *out_len without writing past the buffer.
+    // In cases of no-padding modes, this extra space is not used
+    constexpr int32_t kAesBlockSize = 16;
+    constexpr int32_t kOutputSlack = kAesBlockSize + 1;
+
+    // Reject lengths that are negative or would overflow the size
+    // computation below (signed overflow is undefined behaviour).
+    if (data_len < 0 || data_len > INT32_MAX - kOutputSlack) {
+      throw std::runtime_error(
+          std::string("Invalid data length for AES encryption: ") +
+          std::to_string(data_len));
+    }
+
+    auto* output = reinterpret_cast<unsigned char*>(
+        gdv_fn_context_arena_malloc(context, data_len + kOutputSlack));
+    if (output == nullptr) {
+      throw std::runtime_error(
+          "Memory allocation failed for encryption output");
+    }
+
+    int32_t cipher_len = EncryptModeDispatcher::encrypt(
+        data, data_len, key_data, key_data_len, mode, mode_len, iv_data,
+        iv_data_len, fifth_argument, fifth_argument_len, output);
+
+    *out_len = cipher_len;
+    return reinterpret_cast<const char*>(output);
+  } catch (const std::runtime_error& e) {
+    gdv_fn_context_set_error_msg(context, e.what());
+    *out_len = 0;
+    return nullptr;
+  }
+}
+
+extern "C" GANDIVA_EXPORT
+const char* gdv_fn_aes_decrypt_dispatcher_5args(
+    int64_t context, const char* data, int32_t data_len, const char* key_data,
+    int32_t key_data_len, const char* mode, int32_t mode_len,
+    const char* iv_data, int32_t iv_data_len, const char* fifth_argument,
+    int32_t fifth_argument_len, int32_t* out_len) {
+  try {
+    auto* output =
reinterpret_cast( + gdv_fn_context_arena_malloc(context, data_len)); + if (output == nullptr) { + throw std::runtime_error( + "Memory allocation failed for decryption output"); + } + + int32_t plaintext_len = EncryptModeDispatcher::decrypt( + data, data_len, key_data, key_data_len, mode, mode_len, iv_data, + iv_data_len, fifth_argument, fifth_argument_len, output); + + *out_len = plaintext_len; + return reinterpret_cast(output); + } catch (const std::runtime_error& e) { + gdv_fn_context_set_error_msg(context, e.what()); + *out_len = 0; + return nullptr; + } +} + arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -1248,73 +1196,86 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_base64_decode_utf8)); - // gdv_fn_aes_encrypt_ecb - // Note: The mode parameter is passed as a UTF8 string (data + length) - // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, out_len) + // gdv_fn_aes_encrypt_ecb_legacy (wrapper for string-based signatures) + // Note: Mode is hardcoded internally as "AES-ECB", not passed as parameter + // Function signature: (context, data, data_len, key_data, key_data_len, out_len) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length - types->i8_ptr_type(), // mode (UTF8 string) - types->i32_type(), // mode_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_ecb", + engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_ecb_legacy", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_ecb)); + reinterpret_cast(gdv_fn_aes_encrypt_ecb_legacy)); - // gdv_fn_aes_decrypt_ecb - // Note: The mode parameter is passed as a UTF8 string (data + length) - // Function signature: 
(context, data, data_len, key_data, key_data_len, mode, mode_len, out_len) + // gdv_fn_aes_decrypt_ecb_legacy (wrapper for string-based signatures) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length - types->i8_ptr_type(), // mode (UTF8 string) - types->i32_type(), // mode_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_ecb", + engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_ecb_legacy", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_ecb)); + reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); - // gdv_fn_aes_encrypt_ecb_legacy (wrapper for string-based signatures) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv (binary string) + types->i32_type(), // iv_length + types->i8_ptr_type(), // fifth_argument (binary string) + types->i32_type(), // fifth_argument_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_ecb_legacy", - types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_ecb_legacy)); + // gdv_fn_aes_encrypt_dispatcher_3args (data, key, mode) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length + types->i32_ptr_type() // out_length + }; - // gdv_fn_aes_decrypt_ecb_legacy (wrapper for string-based signatures) + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_encrypt_dispatcher_3args", + types->i8_ptr_type() /*return_type*/, args, 
+ reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_3args)); + + // gdv_fn_aes_decrypt_dispatcher_3args (data, key, mode) args = { types->i64_type(), // context types->i8_ptr_type(), // data types->i32_type(), // data_length types->i8_ptr_type(), // key_data types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_ecb_legacy", - types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_ecb_legacy)); + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_decrypt_dispatcher_3args", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_3args)); - // gdv_fn_aes_encrypt_cbc - // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) - // Function signature: (context, data, data_len, key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) + // gdv_fn_aes_encrypt_dispatcher_4args (data, key, mode, iv) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1325,18 +1286,15 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type(), // mode_length types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length - types->i8_ptr_type(), // padding (binary string) - types->i32_type(), // padding_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_encrypt_cbc", - types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_cbc)); + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_encrypt_dispatcher_4args", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_4args)); - // gdv_fn_aes_decrypt_cbc - // Note: The mode, IV and padding parameters are passed as binary/UTF8 strings (data + length) - // Function signature: (context, data, data_len, 
key_data, key_data_len, mode, mode_len, iv, iv_len, padding, padding_len, out_len) + // gdv_fn_aes_decrypt_dispatcher_4args (data, key, mode, iv) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1347,14 +1305,57 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type(), // mode_length types->i8_ptr_type(), // iv (binary string) types->i32_type(), // iv_length - types->i8_ptr_type(), // padding (binary string) - types->i32_type(), // padding_length types->i32_ptr_type() // out_length }; - engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt_cbc", - types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_cbc)); + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_decrypt_dispatcher_4args", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_4args)); + + // gdv_fn_aes_encrypt_dispatcher_5args (data, key, mode, iv, + // fifth_argument) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // mode_length + types->i8_ptr_type(), // iv (binary string) + types->i32_type(), // iv_length + types->i8_ptr_type(), // fifth_argument (binary string) + types->i32_type(), // fifth_argument_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_encrypt_dispatcher_5args", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_5args)); + + // gdv_fn_aes_decrypt_dispatcher_5args (data, key, mode, iv, + // fifth_argument) + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i8_ptr_type(), // key_data + types->i32_type(), // key_data_length + types->i8_ptr_type(), // mode (binary string) + types->i32_type(), // 
mode_length + types->i8_ptr_type(), // iv (binary string) + types->i32_type(), // iv_length + types->i8_ptr_type(), // fifth_argument (binary string) + types->i32_type(), // fifth_argument_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_aes_decrypt_dispatcher_5args", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_5args)); // gdv_mask_first_n and gdv_mask_last_n std::vector mask_args = { @@ -1453,7 +1454,8 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_cast_intervalday_utf8_int32", types->i64_type() /*return_type*/, args, + "gdv_fn_cast_intervalday_utf8_int32", + types->i64_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalday_utf8_int32)); // gdv_fn_cast_intervalyear_utf8 @@ -1470,12 +1472,15 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8)); -#define ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(LLVM_TYPE, DATA_TYPE) \ - args = {types->i64_type(), types->i8_ptr_type(), types->i32_ptr_type(), \ - types->i64_type(), types->LLVM_TYPE##_ptr_type(), types->i32_type(), types->i32_ptr_type()}; \ - engine->AddGlobalMappingForFunc( \ - "gdv_fn_populate_list_" #DATA_TYPE "_vector", types->i32_type() /*return_type*/, \ - args, reinterpret_cast(gdv_fn_populate_list_##DATA_TYPE##_vector)); +#define ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION( \ + LLVM_TYPE, DATA_TYPE) \ + args = {types->i64_type(), types->i8_ptr_type(), types->i32_ptr_type(), \ + types->i64_type(), types->LLVM_TYPE##_ptr_type(), \ + types->i32_type(), types->i32_ptr_type()}; \ + engine->AddGlobalMappingForFunc( \ + "gdv_fn_populate_list_" #DATA_TYPE "_vector", \ + types->i32_type() /*return_type*/, args, \ + reinterpret_cast(gdv_fn_populate_list_##DATA_TYPE##_vector)); 
ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i32, int32_t) ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i64, int64_t) diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 324ecfb2c03d..8224378f60a5 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -189,46 +189,65 @@ float gdv_fn_castFLOAT4_varbinary(gdv_int64 context, const char* in, int32_t in_ GANDIVA_EXPORT double gdv_fn_castFLOAT8_varbinary(gdv_int64 context, const char* in, int32_t in_len); -// ECB mode specific functions +// Legacy wrappers for string-based AES-ECB signatures GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_ecb(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - int32_t* out_len); +const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, + int32_t data_len, + const char* key_data, + int32_t key_data_len, + int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_ecb(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - int32_t* out_len); +const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, + int32_t data_len, + const char* key_data, + int32_t key_data_len, + int32_t* out_len); -// Legacy wrappers for string-based signatures +// 3-argument dispatcher: (data, key, mode) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - int32_t* out_len); +const char* gdv_fn_aes_encrypt_dispatcher_3args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, const char* mode, + int32_t mode_len, int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, int32_t data_len, - const char* 
key_data, int32_t key_data_len, - int32_t* out_len); +const char* gdv_fn_aes_decrypt_dispatcher_3args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, const char* mode, + int32_t mode_len, int32_t* out_len); -// CBC mode specific functions +// 4-argument dispatcher: (data, key, mode, iv) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_cbc(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len); +const char* gdv_fn_aes_encrypt_dispatcher_4args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, const char* mode, + int32_t mode_len, const char* iv_data, int32_t iv_data_len, + int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_cbc(int64_t context, const char* data, int32_t data_len, - const char* key_data, int32_t key_data_len, - const char* mode, int32_t mode_len, - const char* iv_data, int32_t iv_data_len, - const char* padding, int32_t padding_len, - int32_t* out_len); +const char* gdv_fn_aes_decrypt_dispatcher_4args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, const char* mode, + int32_t mode_len, const char* iv_data, int32_t iv_data_len, + int32_t* out_len); + +// 5-argument dispatcher: (data, key, mode, iv, fifth_argument) +GANDIVA_EXPORT +const char* gdv_fn_aes_encrypt_dispatcher_5args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t key_data_len, const char* mode, + int32_t mode_len, const char* iv_data, int32_t iv_data_len, + const char* fifth_argument, int32_t fifth_argument_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_aes_decrypt_dispatcher_5args( + int64_t context, const char* data, int32_t data_len, + const char* key_data, int32_t 
key_data_len, const char* mode, + int32_t mode_len, const char* iv_data, int32_t iv_data_len, + const char* fifth_argument, int32_t fifth_argument_len, + int32_t* out_len); GANDIVA_EXPORT const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index f70bbb9d26df..7a7cc361d1e2 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1353,14 +1353,20 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "ECB"; + std::string mode = "AES-ECB"; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); + const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &cipher_len); + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { @@ -1371,15 +1377,21 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "ECB"; + std::string mode = "AES-ECB"; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const 
char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, mode.c_str(), + mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), + mode_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { @@ -1390,15 +1402,21 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "ECB"; + std::string mode = "AES-ECB"; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, mode.c_str(), + mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &decrypted_len); + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), + mode_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); } TEST(TestGdvFnStubs, 
TestAesEncryptDecryptValidation) { @@ -1408,20 +1426,24 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "ECB"; + std::string mode = "AES-ECB"; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); std::string cipher = "12345678abcdefgh12345678abcdefghb"; auto cipher_len = static_cast(cipher.length()); - gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &cipher_len); + gdv_fn_aes_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, + key33.c_str(), key33_len, mode.c_str(), + mode_len, &cipher_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Invalid key length for AES encryption")); + ::testing::HasSubstr("Unsupported key length for AES-ECB")); ctx.Reset(); - gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &decrypted_len); + gdv_fn_aes_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len, + key33.c_str(), key33_len, mode.c_str(), + mode_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("Invalid key length for AES decryption")); + ::testing::HasSubstr("Unsupported key length for AES-ECB")); ctx.Reset(); } @@ -1434,17 +1456,21 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "ECB"; + std::string mode = "AES-ECB"; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, &cipher_len); + const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &cipher_len); 
EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_aes_decrypt_ecb( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &decrypted_len); + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); } TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { @@ -1455,205 +1481,28 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string invalid_mode = "CBC"; + std::string invalid_mode = "AES-INVALID"; auto invalid_mode_len = static_cast(invalid_mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); // Test encrypt with invalid mode - gdv_fn_aes_encrypt_ecb(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, &cipher_len); + gdv_fn_aes_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, + key16.c_str(), key16_len, + invalid_mode.c_str(), invalid_mode_len, + &cipher_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("AES encryption mode mismatch")); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("CBC")); + ::testing::HasSubstr("Unsupported encryption mode")); ctx.Reset(); // Test decrypt with invalid mode std::string cipher = "12345678abcdefgh12345678abcdefgh"; auto cipher_len_val = static_cast(cipher.length()); - gdv_fn_aes_decrypt_ecb(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, &decrypted_len); - EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("AES encryption mode mismatch")); + gdv_fn_aes_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len_val, + key16.c_str(), key16_len, + 
invalid_mode.c_str(), invalid_mode_len, + &decrypted_len); EXPECT_THAT(ctx.get_error(), - ::testing::HasSubstr("CBC")); - ctx.Reset(); -} - -// Tests for CBC mode encryption/decryption -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbc) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding.c_str(), padding_len, &cipher_len); - EXPECT_GT(cipher_len, 0); - - const char* decrypted_value = gdv_fn_aes_decrypt_cbc( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, - iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcNoPadding) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - int32_t decrypted_len = 0; - std::string data = "1234567890123456"; // Exactly 16 bytes - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "NONE"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - const char* cipher = gdv_fn_aes_encrypt_cbc(ctx_ptr, 
data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding.c_str(), padding_len, &cipher_len); - EXPECT_GT(cipher_len, 0); - - const char* decrypted_value = gdv_fn_aes_decrypt_cbc( - ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, - iv.c_str(), iv_len, padding.c_str(), padding_len, &decrypted_len); - EXPECT_EQ(data, std::string(reinterpret_cast(decrypted_value), decrypted_len)); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcCaseInsensitive) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len1 = 0; - int32_t cipher_len2 = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding_upper = "PKCS7"; - auto padding_upper_len = static_cast(padding_upper.length()); - std::string padding_lower = "pkcs7"; - auto padding_lower_len = static_cast(padding_lower.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - const char* cipher1 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding_upper.c_str(), padding_upper_len, - &cipher_len1); - const char* cipher2 = gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), - key16_len, mode.c_str(), mode_len, iv.c_str(), - iv_len, padding_lower.c_str(), padding_lower_len, - &cipher_len2); - - // Both should produce same ciphertext - EXPECT_EQ(cipher_len1, cipher_len2); - EXPECT_EQ(std::string(cipher1, cipher_len1), std::string(cipher2, cipher_len2)); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidIV) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv 
= "short"; // Too short - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid IV length")); - ctx.Reset(); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidKey) { - gandiva::ExecutionContext ctx; - std::string key = "short"; // Too short - auto key_len = static_cast(key.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key.c_str(), key_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length")); - ctx.Reset(); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcInvalidPadding) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string mode = "CBC"; - auto mode_len = static_cast(mode.length()); - std::string padding = "INVALID"; - auto 
padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - mode.c_str(), mode_len, iv.c_str(), iv_len, padding.c_str(), - padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid padding mode")); - ctx.Reset(); -} - -TEST(TestGdvFnStubs, TestAesEncryptDecryptModeCbcModeValidation) { - gandiva::ExecutionContext ctx; - std::string key16 = "12345678abcdefgh"; - auto key16_len = static_cast(key16.length()); - std::string iv = "1234567890123456"; - auto iv_len = static_cast(iv.length()); - int32_t cipher_len = 0; - std::string data = "test string"; - auto data_len = static_cast(data.length()); - std::string invalid_mode = "ECB"; - auto invalid_mode_len = static_cast(invalid_mode.length()); - std::string padding = "PKCS7"; - auto padding_len = static_cast(padding.length()); - int64_t ctx_ptr = reinterpret_cast(&ctx); - - // Test encrypt with invalid mode - gdv_fn_aes_encrypt_cbc(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, - invalid_mode.c_str(), invalid_mode_len, iv.c_str(), iv_len, - padding.c_str(), padding_len, &cipher_len); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("AES encryption mode mismatch")); - EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("ECB")); + ::testing::HasSubstr("Unsupported decryption mode")); ctx.Reset(); }