From 35d60f3e6c61fa82d23fac852a1392302e4ee51a Mon Sep 17 00:00:00 2001 From: Phillipp Schoppmann Date: Tue, 3 Feb 2026 19:24:15 -0800 Subject: [PATCH] Take argument to UnwrapFfiStatus by value This will allow us to avoid unnecessary copies when replacing FfiStatus by a wrapped absl::Status. PiperOrigin-RevId: 865155378 --- ffi_utils/status.cc | 4 +- ffi_utils/status.h | 2 +- shell_wrapper/ahe_test.cc | 20 +- shell_wrapper/shell_types_test.cc | 65 +++--- willow/src/api/BUILD | 7 +- willow/src/api/client.cc | 21 +- willow/src/api/client.rs | 50 ++--- willow/src/api/server_accumulator.cc | 74 ++----- willow/src/api/server_accumulator.rs | 194 +++++------------- willow/src/testing_utils/BUILD | 6 +- .../testing_utils/shell_testing_decryptor.cc | 34 +-- .../testing_utils/shell_testing_decryptor.rs | 94 +++------ 12 files changed, 173 insertions(+), 398 deletions(-) diff --git a/ffi_utils/status.cc b/ffi_utils/status.cc index 087c169..87cdd04 100644 --- a/ffi_utils/status.cc +++ b/ffi_utils/status.cc @@ -35,12 +35,12 @@ FfiStatus MakeFfiStatus(int code, std::string message) { return FfiStatus{code, std::make_unique(std::move(message))}; } -absl::Status UnwrapFfiStatus(const FfiStatus& status) { +absl::Status UnwrapFfiStatus(FfiStatus status) { if (status.code == 0) { return absl::OkStatus(); } absl::string_view message = ""; - if (status.message != nullptr || status.message->empty()) { + if (status.message != nullptr) { message = *status.message; } return absl::Status(static_cast(status.code), message); diff --git a/ffi_utils/status.h b/ffi_utils/status.h index f6849da..6735be7 100644 --- a/ffi_utils/status.h +++ b/ffi_utils/status.h @@ -28,7 +28,7 @@ struct FfiStatus; FfiStatus MakeFfiStatus(); FfiStatus MakeFfiStatus(absl::Status status); FfiStatus MakeFfiStatus(int code, std::string message); -absl::Status UnwrapFfiStatus(const FfiStatus& status); +absl::Status UnwrapFfiStatus(FfiStatus status); // To be called from Rust. FfiStatus MakeFfiStatus(int32_t code, rust::Slice message); diff --git a/shell_wrapper/ahe_test.cc b/shell_wrapper/ahe_test.cc index a474e35..87b133f 100644 --- a/shell_wrapper/ahe_test.cc +++ b/shell_wrapper/ahe_test.cc @@ -54,19 +54,17 @@ constexpr double kSBase = 12.8; TEST(AheTest, EncryptDecryptOne) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper sk_share; SECAGG_ASSERT_OK( @@ -118,19 +116,17 @@ TEST(AheTest, EncryptDecryptOne) { TEST(AheTest, ExternCRecoveryTest) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); std::vector secret_key_shares; std::vector public_key_shares; diff --git a/shell_wrapper/shell_types_test.cc b/shell_wrapper/shell_types_test.cc index 9f17742..5c001f9 100644 --- a/shell_wrapper/shell_types_test.cc +++ b/shell_wrapper/shell_types_test.cc @@ -154,14 +154,11 @@ TEST(ShellTypesTest, AddErrors) { TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -176,9 +173,8 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int buffer_len = 2 * kNumCoeffs; int64_t buffer[buffer_len]; uint64_t n_written; - FfiStatus res = WriteSmallRnsPolynomialToBuffer( - &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); // We get 1 indeed. @@ -201,23 +197,19 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { poly_wrapper = {.ptr = std::make_unique(std::move(poly))}; // Write the polynomial to a buffer. - res = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); } TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -234,9 +226,8 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { } RnsPolynomialWrapper poly{nullptr}; - status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, 1 << kLogN, - moduli_wrapper, &poly); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(ReadSmallRnsPolynomialFromBuffer( + buffer, buffer_len, 1 << kLogN, moduli_wrapper, &poly))); EXPECT_NE(poly.ptr, nullptr); // The coefficients of the polynomial should match the ones in the @@ -267,23 +258,19 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { // Write the polynomial back to another buffer should give the same result. int64_t buffer_out[buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly, moduli_wrapper, buffer_len, - buffer_out, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly, moduli_wrapper, buffer_len, buffer_out, &n_written))); EXPECT_EQ(absl::MakeSpan(buffer_out), absl::MakeSpan(buffer)); } TEST(ShellTypesTest, ReadWriteErrors) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -301,13 +288,13 @@ TEST(ShellTypesTest, ReadWriteErrors) { constexpr int output_buffer_len = 2 * kNumCoeffs; int64_t output_buffer[output_buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - output_buffer_len, output_buffer, - &n_written); + auto status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, + output_buffer_len, + output_buffer, &n_written); // We should get an error because large coefficients don't have the same value // mod q_1 and q_2. Note that the buffer gets filled still. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not a small polynomial"))); EXPECT_EQ(n_written, kNumCoeffs); @@ -324,7 +311,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { long_input_buffer_len, 1 << kLogN, moduli_wrapper, &poly_wrapper); // We should get an error. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Buffer has too many coefficients"))); } @@ -332,9 +319,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, @@ -388,9 +373,7 @@ TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { TEST(AheTest, WriteRnsPolynomialToBuffer128FailsWhenBufferLenIsWrong) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index b0ee1cf..8491704 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -59,6 +59,8 @@ cc_library( "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", "@cxx.rs//:core", + "//ffi_utils:status_cc", + "//ffi_utils:status_macros", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", @@ -90,6 +92,7 @@ rust_cxx_bridge( src = "server_accumulator.rs", deps = [ ":server_accumulator", + "//ffi_utils:status_cxx", ], ) @@ -153,6 +156,7 @@ rust_cxx_bridge( src = "client.rs", deps = [ ":client", + "//ffi_utils:status_cxx", ], ) @@ -165,7 +169,8 @@ cc_library( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@cxx.rs//:core", - "//shell_wrapper:shell_types_cc", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/api/client.cc b/willow/src/api/client.cc index 92cd8d3..76df46f 100644 --- a/willow/src/api/client.cc +++ b/willow/src/api/client.cc @@ -22,8 +22,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" -#include "shell_wrapper/shell_types.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" @@ -40,12 +41,8 @@ absl::StatusOr GenerateClientContribution( std::string config_str = aggregation_config.SerializeAsString(); auto config_ptr = std::make_unique(std::move(config_str)); secure_aggregation::WillowShellClient* client_ptr = nullptr; - std::unique_ptr status_message; - int status_code = - initialize_client(std::move(config_ptr), &client_ptr, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + initialize_client(std::move(config_ptr), &client_ptr)); // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. auto client = client_into_box(client_ptr); @@ -69,15 +66,9 @@ absl::StatusOr GenerateClientContribution( auto key_ptr = std::make_unique(std::move(key_str)); rust::Slice nonce_slice = ToRustSlice(nonce); rust::Vec result_bytes; - std::unique_ptr status_message_gen; - // Encrypt data. - int status_code_gen = - generate_contribution(client, entries_slice, std::move(key_ptr), - nonce_slice, &result_bytes, &status_message_gen); - if (status_code_gen != 0) { - return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen); - } + SECAGG_RETURN_IF_FFI_ERROR(generate_contribution( + client, entries_slice, std::move(key_ptr), nonce_slice, &result_bytes)); // Parse string to ClientMessage. willow::ClientMessage client_message; diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index fdcb128..8fd5771 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -46,6 +46,11 @@ pub mod ffi { values: &'a [u64], } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { // cxx: types used as extern Rust types are required to be defined by the same crate that // contains the bridge using them @@ -54,8 +59,7 @@ pub mod ffi { pub unsafe fn initialize_client( config: UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box; @@ -65,8 +69,7 @@ pub mod ffi { key: UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; } } @@ -112,23 +115,14 @@ impl WillowShellClient { } } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn initialize_client( config: cxx::UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match WillowShellClient::new_from_serialized_config(config) { - Ok(client) => { - *out = Box::into_raw(Box::new(client)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + WillowShellClient::new_from_serialized_config(config) + .map(|client| *out = Box::into_raw(Box::new(client))) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw` @@ -139,24 +133,16 @@ unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box Box::from_raw(ptr) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn generate_contribution( client: &mut Box, data: &[ffi::DataEntryView], public_key: cxx::UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match client.generate_contribution(data, public_key, nonce) { - Ok(contribution) => { - *out = contribution; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + client + .generate_contribution(data, public_key, nonce) + .map(|contribution| *out = contribution) + .into() } diff --git a/willow/src/api/server_accumulator.cc b/willow/src/api/server_accumulator.cc index 8bc148a..661bf03 100644 --- a/willow/src/api/server_accumulator.cc +++ b/willow/src/api/server_accumulator.cc @@ -23,39 +23,31 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/server_accumulator.rs.h" #include "willow/src/input_encoding/codec.h" -namespace secure_aggregation { -namespace willow { +namespace secure_aggregation::willow { absl::StatusOr> ServerAccumulator::Create( const AggregationConfigProto& aggregation_config) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::NewServerAccumulatorFromSerializedConfig( std::make_unique(aggregation_config.SerializeAsString()), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } absl::StatusOr> ServerAccumulator::CreateFromSerializedState(std::string serialized_state) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState( - std::make_unique(std::move(serialized_state)), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + secure_aggregation::NewServerAccumulatorFromSerializedState( + std::make_unique(std::move(serialized_state)), &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } @@ -68,35 +60,22 @@ absl::Status ServerAccumulator::ProcessClientMessages( absl::Status ServerAccumulator::ProcessClientMessages( std::string serialized_client_messages) { - std::unique_ptr status_message; - int status_code = accumulator_->ProcessClientMessages( - std::make_unique(std::move(serialized_client_messages)), - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(accumulator_->ProcessClientMessages( + std::make_unique(std::move(serialized_client_messages)))); return absl::OkStatus(); } absl::Status ServerAccumulator::Merge( std::unique_ptr other) { - std::unique_ptr status_message; - int status_code = - accumulator_->Merge(std::move(other->accumulator_), &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->Merge(std::move(other->accumulator_))); return absl::OkStatus(); } absl::StatusOr ServerAccumulator::ToSerializedState() { rust::Vec serialized_state; - std::unique_ptr status_message; - int status_code = - accumulator_->ToSerializedState(&serialized_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->ToSerializedState(&serialized_state)); return std::string(reinterpret_cast(serialized_state.data()), serialized_state.size()); } @@ -105,13 +84,9 @@ absl::StatusOr ServerAccumulator::Finalize() && { // Finalize accumulator in Rust and store the serialized results. rust::Vec decryption_request; rust::Vec final_result_decryptor_state; - std::unique_ptr status_message; - int status_code = secure_aggregation::FinalizeServerAccumulator( + SECAGG_RETURN_IF_FFI_ERROR(secure_aggregation::FinalizeServerAccumulator( std::move(accumulator_), &decryption_request, - &final_result_decryptor_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &final_result_decryptor_state)); // Pack the two serialized results into a single proto. FinalizedAccumulatorResult result_proto; @@ -129,15 +104,11 @@ absl::StatusOr> FinalResultDecryptor::CreateFromSerialized( std::string final_result_decryptor_state) { secure_aggregation::FinalResultDecryptor* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::CreateFinalResultDecryptorFromSerialized( std::make_unique( std::move(final_result_decryptor_state)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new FinalResultDecryptor( secure_aggregation::FinalResultDecryptorIntoBox(out))); } @@ -145,14 +116,10 @@ FinalResultDecryptor::CreateFromSerialized( absl::StatusOr FinalResultDecryptor::Decrypt( std::string serialized_partial_decryption_response) { rust::Vec out; - std::unique_ptr status_message; - int status_code = aggregated_ciphertexts_->Decrypt( + SECAGG_RETURN_IF_FFI_ERROR(aggregated_ciphertexts_->Decrypt( std::make_unique( std::move(serialized_partial_decryption_response)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); EncodedData encoded_data; for (const auto& rust_entry : out) { std::string key(rust_entry.key); @@ -166,5 +133,4 @@ absl::StatusOr FinalResultDecryptor::Decrypt( return encoded_data; } -} // namespace willow -} // namespace secure_aggregation \ No newline at end of file +} // namespace secure_aggregation::willow diff --git a/willow/src/api/server_accumulator.rs b/willow/src/api/server_accumulator.rs index ae62b8f..abb6dbb 100644 --- a/willow/src/api/server_accumulator.rs +++ b/willow/src/api/server_accumulator.rs @@ -46,11 +46,16 @@ pub mod ffi { values: Vec, } + // Re-define FfiStatus since CXX requires shared structs to be defined in the same module + // (https://github.com/dtolnay/cxx/issues/297#issuecomment-727042059) + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { type ServerAccumulator; - // We cannot use status::FfiStatus because CXX requires shared structs to be defined in the - // same module. So using separate message and pointer as a workaround. // SAFETY: All functions in this module are only called from the wrapping C++ library, // ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer // arguments are not null. @@ -59,36 +64,26 @@ pub mod ffi { unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "NewServerAccumulatorFromSerializedState"] unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ProcessClientMessages"] - unsafe fn process_client_messages_ffi( + fn process_client_messages_ffi( self: &mut ServerAccumulator, client_messages: UniquePtr, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ToSerializedState"] - unsafe fn to_serialized_state_ffi( - self: &ServerAccumulator, - out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + unsafe fn to_serialized_state_ffi(self: &ServerAccumulator, out: *mut Vec) + -> FfiStatus; #[cxx_name = "Merge"] - unsafe fn merge_ffi( - self: &mut ServerAccumulator, - other: Box, - out_status_message: *mut UniquePtr, - ) -> i32; + fn merge_ffi(self: &mut ServerAccumulator, other: Box) -> FfiStatus; #[cxx_name = "IntoBox"] unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box; @@ -100,23 +95,20 @@ pub mod ffi { accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "Decrypt"] unsafe fn decrypt_ffi( self: &mut FinalResultDecryptor, serialized_partial_decryption_response: UniquePtr, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "CreateFinalResultDecryptorFromSerialized"] unsafe fn create_final_result_decryptor_from_serialized( serialized_final_result_decryptor_state: UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "FinalResultDecryptorIntoBox"] unsafe fn final_result_decryptor_into_box( @@ -125,8 +117,6 @@ pub mod ffi { } } -use status::ffi::FfiStatus; - pub struct ServerAccumulator { // Server struct used to perform aggregation of client contributions. server: WillowV1Server, @@ -334,21 +324,11 @@ impl ServerAccumulator { Ok(()) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn process_client_messages_ffi( + pub fn process_client_messages_ffi( &mut self, client_messages: cxx::UniquePtr, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.process_client_messages_serialized(client_messages) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.process_client_messages_serialized(client_messages).into() } // Atomically merges the other accumulator into `self`. @@ -414,39 +394,16 @@ impl ServerAccumulator { }) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn merge_ffi( + pub fn merge_ffi( self: &mut ServerAccumulator, other: Box, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.merge(other) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.merge(other).into() } - pub unsafe fn to_serialized_state_ffi( - &self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.to_serialized_state() { - Ok(serialized_state) => { - *out = serialized_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + pub unsafe fn to_serialized_state_ffi(&self, out: *mut Vec) -> ffi::FfiStatus { + self.to_serialized_state().map(|result| *out = result).into() } } @@ -521,46 +478,24 @@ impl FromProto for ServerAccumulator { } } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } // SAFETY: @@ -623,26 +558,19 @@ fn finalize_accumulator(accumulator: ServerAccumulator) -> Result<(Vec, Vec< Ok((serialized_decryption_request, serialized_final_result_decryptor_state)) } -/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`, -/// `out_status_message`) must be valid for writes. +/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`) +/// must be valid for writes. pub unsafe fn finalize_accumulator_ffi( accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match finalize_accumulator(*accumulator) { - Ok((decryption_request, final_result_decryptor_state)) => { +) -> ffi::FfiStatus { + finalize_accumulator(*accumulator) + .map(|(decryption_request, final_result_decryptor_state)| { *out_decryption_request = decryption_request; *out_final_result_decryptor_state = final_result_decryptor_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + }) + .into() } impl FinalResultDecryptor { @@ -695,42 +623,22 @@ impl FinalResultDecryptor { Ok(entries) } - /// SAFETY: `out` and `out_status_message` must not be null. + /// SAFETY: `out` must be valid for writes. pub unsafe fn decrypt_ffi( &mut self, serialized_partial_decryption_response: cxx::UniquePtr, out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt(serialized_partial_decryption_response) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt(serialized_partial_decryption_response).map(|result| *out = result).into() } } -/// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. +/// SAFETY: `out` must be valid for writes. unsafe fn create_final_result_decryptor_from_serialized( serialized_proto: cxx::UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match FinalResultDecryptor::new_from_serialized(serialized_proto) { - Ok(final_result_decryptor) => { - *out = Box::into_raw(Box::new(final_result_decryptor)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + FinalResultDecryptor::new_from_serialized(serialized_proto) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index 458411d..f639f7d 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -97,6 +97,7 @@ rust_cxx_bridge( src = "shell_testing_decryptor.rs", deps = [ ":shell_testing_decryptor", + "//ffi_utils:status_cxx", ], ) @@ -109,7 +110,6 @@ rust_library( "@protobuf//rust:protobuf", "@cxx.rs//:cxx", "//ffi_utils:status", - "//shell_wrapper:shell_types_cc", "//willow/proto/willow:aggregation_config_rust_proto", "//willow/proto/willow:messages_rust_proto", "//willow/src/api:aggregation_config", @@ -136,7 +136,9 @@ cc_library( "@abseil-cpp//absl/memory", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", - "//shell_wrapper:shell_types_cc", + "@cxx.rs//:core", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/testing_utils/shell_testing_decryptor.cc b/willow/src/testing_utils/shell_testing_decryptor.cc index 1167812..8330b07 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.cc +++ b/willow/src/testing_utils/shell_testing_decryptor.cc @@ -25,7 +25,8 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "shell_wrapper/shell_types.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "willow/src/input_encoding/codec.h" #include "willow/src/testing_utils/shell_testing_decryptor.rs.h" @@ -43,13 +44,8 @@ ShellTestingDecryptor::Create( rust::Slice slice = ToRustSlice(aggregation_config_proto); ShellTestingDecryptorRust* out; - std::unique_ptr status_message; - int status_code = - create_shell_testing_decryptor(slice, &out, &status_message); + SECAGG_RETURN_IF_FFI_ERROR(create_shell_testing_decryptor(slice, &out)); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out))); } @@ -57,12 +53,7 @@ ShellTestingDecryptor::Create( absl::StatusOr ShellTestingDecryptor::GeneratePublicKey() { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_public_key(&out, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_public_key(&out)); willow::ShellAhePublicKey public_key; if (!public_key.ParseFromArray(out.data(), out.size())) { @@ -79,13 +70,7 @@ absl::StatusOr ShellTestingDecryptor::Decrypt( contribution_proto.size()); rust::Vec rust_flat_data; - std::unique_ptr status_message; - int status_code = - decryptor_->decrypt(slice, &rust_flat_data, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->decrypt(slice, &rust_flat_data)); willow::EncodedData encoded_data; for (const auto& rust_entry : rust_flat_data) { @@ -105,13 +90,8 @@ absl::StatusOr ShellTestingDecryptor::GenerateSerializedPartialDecryptionResponse( std::string serialized_partial_decryption_request) { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_partial_decryption_response( - ToRustSlice(serialized_partial_decryption_request), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_partial_decryption_response( + ToRustSlice(serialized_partial_decryption_request), &out)); return std::string(reinterpret_cast(out.data()), out.size()); } diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index ae74447..ed91583 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -29,7 +29,6 @@ use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; -use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; use std::cell::RefCell; use vahe_shell::ShellVahe; @@ -118,23 +117,9 @@ impl ShellTestingDecryptor { .map_err(|e| status::internal(format!("Serialize error: {}", e))) } - /// SAFETY: `out` and `out_status_message` must not be null. - unsafe fn generate_public_key_ffi( - &mut self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_public_key_serialized() { - Ok(pk) => { - *out = pk; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + unsafe fn generate_public_key_ffi(&mut self, out: *mut Vec) -> ffi::FfiStatus { + self.generate_public_key_serialized().map(|pk| *out = pk).into() } fn decrypt_serialized( @@ -163,24 +148,13 @@ impl ShellTestingDecryptor { Ok(entries) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn decrypt_ffi( &mut self, contribution: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt_serialized(contribution) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt_serialized(contribution).map(|result| *out = result).into() } fn generate_partial_decryption_response( @@ -219,24 +193,15 @@ impl ShellTestingDecryptor { .map_err(|e| status::internal(format!("Serialize error: {}", e))) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn generate_partial_decryption_response_ffi( &mut self, request: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_partial_decryption_response_serialized(request) { - Ok(response) => { - *out = response; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.generate_partial_decryption_response_serialized(request) + .map(|response| *out = response) + .into() } } @@ -252,6 +217,12 @@ pub mod ffi { values: Vec, } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + #[namespace = "secure_aggregation"] + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { #[cxx_name = "ShellTestingDecryptorRust"] type ShellTestingDecryptor; @@ -259,31 +230,27 @@ pub mod ffi { unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_public_key_ffi"] unsafe fn generate_public_key( self: &mut ShellTestingDecryptor, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "decrypt_ffi"] unsafe fn decrypt( self: &mut ShellTestingDecryptor, contribution: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_partial_decryption_response_ffi"] unsafe fn generate_partial_decryption_response( self: &mut ShellTestingDecryptor, request: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) -> Box; @@ -301,23 +268,14 @@ fn create_shell_testing_decryptor_impl( Ok(Box::new(decryptor)) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match create_shell_testing_decryptor_impl(config) { - Ok(decryptor) => { - *out = Box::into_raw(decryptor); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + create_shell_testing_decryptor_impl(config) + .map(|decryptor| *out = Box::into_raw(decryptor)) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw`