From 4138b3d20e29da4fbc946da27f2a67a826af3a50 Mon Sep 17 00:00:00 2001 From: Phillipp Schoppmann Date: Tue, 3 Feb 2026 20:33:51 -0800 Subject: [PATCH] Make FfiStatus `Clone` PiperOrigin-RevId: 865179516 --- ffi_utils/BUILD | 6 +- ffi_utils/status.cc | 50 +++-- ffi_utils/status.h | 10 +- ffi_utils/status.rs | 62 ++++-- shell_wrapper/ahe_test.cc | 20 +- shell_wrapper/shell_types.cc | 2 +- shell_wrapper/shell_types_test.cc | 65 +++--- willow/src/api/BUILD | 7 +- willow/src/api/client.cc | 21 +- willow/src/api/client.rs | 50 ++--- willow/src/api/server_accumulator.cc | 74 ++----- willow/src/api/server_accumulator.rs | 194 +++++------------- willow/src/testing_utils/BUILD | 6 +- .../testing_utils/shell_testing_decryptor.cc | 34 +-- .../testing_utils/shell_testing_decryptor.rs | 94 +++------ 15 files changed, 264 insertions(+), 431 deletions(-) diff --git a/ffi_utils/BUILD b/ffi_utils/BUILD index 360b914..594316f 100644 --- a/ffi_utils/BUILD +++ b/ffi_utils/BUILD @@ -44,7 +44,10 @@ rust_library( rust_cxx_bridge( name = "status_cxx", src = "status.rs", - deps = [":status_cc"], + deps = [ + ":status_cc", + "@abseil-cpp//absl/status", + ], ) rust_test( @@ -60,6 +63,7 @@ cc_library( srcs = ["status.cc"], hdrs = ["status.h"], deps = [ + ":cxx_utils", ":status_cxx/include", "@abseil-cpp//absl/status", "@abseil-cpp//absl/strings:string_view", diff --git a/ffi_utils/status.cc b/ffi_utils/status.cc index 087c169..2ffc705 100644 --- a/ffi_utils/status.cc +++ b/ffi_utils/status.cc @@ -16,39 +16,59 @@ #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "ffi_utils/cxx_utils.h" #include "ffi_utils/status.rs.h" #include "include/cxx.h" namespace secure_aggregation { -FfiStatus MakeFfiStatus() { return FfiStatus{.code = 0, .message = nullptr}; } +FfiStatus MakeFfiStatus() { return FfiStatus{nullptr}; } + FfiStatus MakeFfiStatus(absl::Status status) { - return FfiStatus{static_cast(status.code()), - std::make_unique(status.message())}; + if (status.ok()) { + return MakeFfiStatus(); + } + return FfiStatus{std::make_unique(std::move(status))}; } -FfiStatus MakeFfiStatus(int code, std::string message) { - return FfiStatus{code, std::make_unique(std::move(message))}; + +FfiStatus MakeFfiStatus(int32_t code, rust::Slice message) { + return MakeFfiStatus(absl::Status(static_cast(code), + ToAbslStringView(message))); } -absl::Status UnwrapFfiStatus(const FfiStatus& status) { - if (status.code == 0) { +FfiStatus CloneFfiStatus(const FfiStatus& status) { + if (status.ptr == nullptr) { + return MakeFfiStatus(); + } + return MakeFfiStatus(*status.ptr); +} + +absl::Status UnwrapFfiStatus(FfiStatus status) { + if (status.ptr == nullptr) { return absl::OkStatus(); } - absl::string_view message = ""; - if (status.message != nullptr || status.message->empty()) { - message = *status.message; + absl::Status out = std::move(*status.ptr); + return out; +} + +int32_t FfiStatusCode(const FfiStatus& status) { + if (status.ptr == nullptr) { + return 0; } - return absl::Status(static_cast(status.code), message); + return static_cast(status.ptr->code()); } -FfiStatus MakeFfiStatus(int32_t code, rust::Slice message) { - return secure_aggregation::MakeFfiStatus( - code, std::string(message.begin(), message.end())); +rust::Slice FfiStatusMessage(const FfiStatus& status) { + if (status.ptr == nullptr) { + return rust::Slice(); + } + return rust::Slice( + reinterpret_cast(status.ptr->message().data()), + status.ptr->message().size()); } } // namespace secure_aggregation diff --git a/ffi_utils/status.h b/ffi_utils/status.h index f6849da..906d645 100644 --- a/ffi_utils/status.h +++ b/ffi_utils/status.h @@ -27,12 +27,14 @@ struct FfiStatus; FfiStatus MakeFfiStatus(); FfiStatus MakeFfiStatus(absl::Status status); -FfiStatus MakeFfiStatus(int code, std::string message); -absl::Status UnwrapFfiStatus(const FfiStatus& status); - -// To be called from Rust. FfiStatus MakeFfiStatus(int32_t code, rust::Slice message); +FfiStatus CloneFfiStatus(const FfiStatus& status); + +absl::Status UnwrapFfiStatus(FfiStatus status); +int32_t FfiStatusCode(const FfiStatus& status); +rust::Slice FfiStatusMessage(const FfiStatus& status); + } // namespace secure_aggregation #endif // SECURE_AGGREGATION_FFI_UTILS_STATUS_H_ diff --git a/ffi_utils/status.rs b/ffi_utils/status.rs index 96f4e94..49222ce 100644 --- a/ffi_utils/status.rs +++ b/ffi_utils/status.rs @@ -24,21 +24,42 @@ use std::borrow::Cow; #[cxx::bridge(namespace = "secure_aggregation")] pub mod ffi { - // A simple Status alternative which is cxx-compatible (because it directly uses unique_ptr). + unsafe extern "C++" { + include!("absl/status/status.h"); + #[namespace = "absl"] + type Status; + } + + // A simple Status wrapper which is cxx-compatible (because it directly uses unique_ptr). pub struct FfiStatus { - pub code: i32, - pub message: UniquePtr, + // Wrapped absl::Status. A nullptr is interpreted as an OK status. + ptr: UniquePtr, } unsafe extern "C++" { include!("ffi_utils/status.h"); + #[rust_name = "make_ok_ffi_status"] + pub fn MakeFfiStatus() -> FfiStatus; + #[rust_name = "make_ffi_status"] pub fn MakeFfiStatus(code: i32, message: &[u8]) -> FfiStatus; + #[rust_name = "ffi_status_code"] + pub fn FfiStatusCode(status: &FfiStatus) -> i32; + #[rust_name = "ffi_status_message"] + pub fn FfiStatusMessage<'a>(status: &'a FfiStatus) -> &'a [u8]; + #[rust_name = "clone_ffi_status"] + pub fn CloneFfiStatus(status: &FfiStatus) -> FfiStatus; } } pub type Status = Result<(), StatusError>; pub type StatusOr = Result; +impl Clone for ffi::FfiStatus { + fn clone(&self) -> Self { + ffi::clone_ffi_status(self) + } +} + /// All cases of C++ StatusErrorCode except `StatusErrorCode::kOk`. #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[non_exhaustive] @@ -240,26 +261,27 @@ impl std::error::Error for StatusErrorCodeTryFromError {} impl From for ffi::FfiStatus { fn from(error: StatusError) -> Self { - ffi::MakeFfiStatus(error.code as i32, error.message.0.as_slice()) + ffi::make_ffi_status(error.code as i32, error.message.0.as_slice()) } } impl From for ffi::FfiStatus { fn from(status: Status) -> Self { match status { - Ok(()) => ffi::FfiStatus { code: 0, message: cxx::UniquePtr::null() }, + Ok(()) => ffi::make_ok_ffi_status(), Err(error) => error.into(), } } } pub fn rust_status_from_cpp(status: ffi::FfiStatus) -> Status { - if status.code == 0 { + let code = ffi::ffi_status_code(&status); + if code == 0 { Ok(()) } else { - let message = if status.message.is_null() { b"" } else { status.message.as_bytes() }; + let message = ffi::ffi_status_message(&status); Err(StatusError::new( - status.code.try_into().unwrap_or(StatusErrorCode::Unknown), + code.try_into().unwrap_or(StatusErrorCode::Unknown), message, core::panic::Location::caller(), )) @@ -369,6 +391,8 @@ impl std::fmt::Debug for MaybeString { #[cfg(test)] mod tests { use super::*; + use ffi::ffi_status_code; + use ffi::ffi_status_message; use googletest::prelude::*; #[allow(dead_code)] @@ -396,6 +420,14 @@ mod tests { } } + #[gtest] + fn test_clone() { + let status = ffi::make_ffi_status(1, b"test"); + let status2 = status.clone(); + expect_eq!(ffi_status_code(&status2), 1); + expect_eq!(ffi_status_message(&status2), b"test"); + } + #[gtest] fn test_try_from() { for i in 1..=16 { @@ -415,13 +447,13 @@ mod tests { fn test_ffi_status_from_status_error() { let error = StatusError::new_untracked(StatusErrorCode::Cancelled, "test"); let ffi_status: ffi::FfiStatus = error.into(); - expect_eq!(ffi_status.code, 1); - expect_eq!(ffi_status.message.as_bytes(), b"test"); + expect_eq!(ffi_status_code(&ffi_status), 1); + expect_eq!(ffi_status_message(&ffi_status), b"test"); } #[gtest] fn test_rust_status_from_cpp() { - let ffi_status = ffi::MakeFfiStatus(1, b"test"); + let ffi_status = ffi::make_ffi_status(1, b"test"); let rust_status = rust_status_from_cpp(ffi_status); assert!(rust_status.is_err()); expect_eq!(&rust_status.as_ref().err().unwrap().code(), &StatusErrorCode::Cancelled); @@ -432,15 +464,15 @@ mod tests { fn test_ffi_status_from_ok_status() { let rust_status = Ok(()); let ffi_status: ffi::FfiStatus = rust_status.into(); - expect_eq!(ffi_status.code, 0); - expect_eq!(ffi_status.message.is_null(), true); + expect_eq!(ffi_status_code(&ffi_status), 0); + expect_eq!(ffi_status_message(&ffi_status).is_empty(), true); } #[gtest] fn test_ffi_status_from_non_ok_status() { let rust_status = Err(StatusError::new_untracked(StatusErrorCode::Cancelled, "test")); let ffi_status: ffi::FfiStatus = rust_status.into(); - expect_eq!(ffi_status.code, StatusErrorCode::Cancelled as i32); - expect_eq!(ffi_status.message.as_bytes(), b"test"); + expect_eq!(ffi_status_code(&ffi_status), StatusErrorCode::Cancelled as i32); + expect_eq!(ffi_status_message(&ffi_status), b"test"); } } diff --git a/shell_wrapper/ahe_test.cc b/shell_wrapper/ahe_test.cc index a474e35..87b133f 100644 --- a/shell_wrapper/ahe_test.cc +++ b/shell_wrapper/ahe_test.cc @@ -54,19 +54,17 @@ constexpr double kSBase = 12.8; TEST(AheTest, EncryptDecryptOne) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper sk_share; SECAGG_ASSERT_OK( @@ -118,19 +116,17 @@ TEST(AheTest, EncryptDecryptOne) { TEST(AheTest, ExternCRecoveryTest) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); std::vector secret_key_shares; std::vector public_key_shares; diff --git a/shell_wrapper/shell_types.cc b/shell_wrapper/shell_types.cc index 1557eaf..d04c83f 100644 --- a/shell_wrapper/shell_types.cc +++ b/shell_wrapper/shell_types.cc @@ -264,7 +264,7 @@ FfiStatus WriteRnsPolynomialToBuffer128( buffer[2 * i + 1] = absl::Uint128High64(coeff_128); } - return {}; + return MakeFfiStatus(); } RnsPolynomialWrapper CloneRnsPolynomialWrapper(const RnsPolynomialWrapper* in) { diff --git a/shell_wrapper/shell_types_test.cc b/shell_wrapper/shell_types_test.cc index 9f17742..5c001f9 100644 --- a/shell_wrapper/shell_types_test.cc +++ b/shell_wrapper/shell_types_test.cc @@ -154,14 +154,11 @@ TEST(ShellTypesTest, AddErrors) { TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -176,9 +173,8 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int buffer_len = 2 * kNumCoeffs; int64_t buffer[buffer_len]; uint64_t n_written; - FfiStatus res = WriteSmallRnsPolynomialToBuffer( - &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); // We get 1 indeed. @@ -201,23 +197,19 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { poly_wrapper = {.ptr = std::make_unique(std::move(poly))}; // Write the polynomial to a buffer. - res = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); } TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -234,9 +226,8 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { } RnsPolynomialWrapper poly{nullptr}; - status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, 1 << kLogN, - moduli_wrapper, &poly); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(ReadSmallRnsPolynomialFromBuffer( + buffer, buffer_len, 1 << kLogN, moduli_wrapper, &poly))); EXPECT_NE(poly.ptr, nullptr); // The coefficients of the polynomial should match the ones in the @@ -267,23 +258,19 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { // Write the polynomial back to another buffer should give the same result. int64_t buffer_out[buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly, moduli_wrapper, buffer_len, - buffer_out, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly, moduli_wrapper, buffer_len, buffer_out, &n_written))); EXPECT_EQ(absl::MakeSpan(buffer_out), absl::MakeSpan(buffer)); } TEST(ShellTypesTest, ReadWriteErrors) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -301,13 +288,13 @@ TEST(ShellTypesTest, ReadWriteErrors) { constexpr int output_buffer_len = 2 * kNumCoeffs; int64_t output_buffer[output_buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - output_buffer_len, output_buffer, - &n_written); + auto status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, + output_buffer_len, + output_buffer, &n_written); // We should get an error because large coefficients don't have the same value // mod q_1 and q_2. Note that the buffer gets filled still. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not a small polynomial"))); EXPECT_EQ(n_written, kNumCoeffs); @@ -324,7 +311,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { long_input_buffer_len, 1 << kLogN, moduli_wrapper, &poly_wrapper); // We should get an error. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Buffer has too many coefficients"))); } @@ -332,9 +319,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, @@ -388,9 +373,7 @@ TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { TEST(AheTest, WriteRnsPolynomialToBuffer128FailsWhenBufferLenIsWrong) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index b0ee1cf..8491704 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -59,6 +59,8 @@ cc_library( "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", "@cxx.rs//:core", + "//ffi_utils:status_cc", + "//ffi_utils:status_macros", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", @@ -90,6 +92,7 @@ rust_cxx_bridge( src = "server_accumulator.rs", deps = [ ":server_accumulator", + "//ffi_utils:status_cxx", ], ) @@ -153,6 +156,7 @@ rust_cxx_bridge( src = "client.rs", deps = [ ":client", + "//ffi_utils:status_cxx", ], ) @@ -165,7 +169,8 @@ cc_library( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@cxx.rs//:core", - "//shell_wrapper:shell_types_cc", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/api/client.cc b/willow/src/api/client.cc index 92cd8d3..76df46f 100644 --- a/willow/src/api/client.cc +++ b/willow/src/api/client.cc @@ -22,8 +22,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" -#include "shell_wrapper/shell_types.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" @@ -40,12 +41,8 @@ absl::StatusOr GenerateClientContribution( std::string config_str = aggregation_config.SerializeAsString(); auto config_ptr = std::make_unique(std::move(config_str)); secure_aggregation::WillowShellClient* client_ptr = nullptr; - std::unique_ptr status_message; - int status_code = - initialize_client(std::move(config_ptr), &client_ptr, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + initialize_client(std::move(config_ptr), &client_ptr)); // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. auto client = client_into_box(client_ptr); @@ -69,15 +66,9 @@ absl::StatusOr GenerateClientContribution( auto key_ptr = std::make_unique(std::move(key_str)); rust::Slice nonce_slice = ToRustSlice(nonce); rust::Vec result_bytes; - std::unique_ptr status_message_gen; - // Encrypt data. - int status_code_gen = - generate_contribution(client, entries_slice, std::move(key_ptr), - nonce_slice, &result_bytes, &status_message_gen); - if (status_code_gen != 0) { - return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen); - } + SECAGG_RETURN_IF_FFI_ERROR(generate_contribution( + client, entries_slice, std::move(key_ptr), nonce_slice, &result_bytes)); // Parse string to ClientMessage. willow::ClientMessage client_message; diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index fdcb128..8fd5771 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -46,6 +46,11 @@ pub mod ffi { values: &'a [u64], } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { // cxx: types used as extern Rust types are required to be defined by the same crate that // contains the bridge using them @@ -54,8 +59,7 @@ pub mod ffi { pub unsafe fn initialize_client( config: UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box; @@ -65,8 +69,7 @@ pub mod ffi { key: UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; } } @@ -112,23 +115,14 @@ impl WillowShellClient { } } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn initialize_client( config: cxx::UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match WillowShellClient::new_from_serialized_config(config) { - Ok(client) => { - *out = Box::into_raw(Box::new(client)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + WillowShellClient::new_from_serialized_config(config) + .map(|client| *out = Box::into_raw(Box::new(client))) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw` @@ -139,24 +133,16 @@ unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box Box::from_raw(ptr) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn generate_contribution( client: &mut Box, data: &[ffi::DataEntryView], public_key: cxx::UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match client.generate_contribution(data, public_key, nonce) { - Ok(contribution) => { - *out = contribution; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + client + .generate_contribution(data, public_key, nonce) + .map(|contribution| *out = contribution) + .into() } diff --git a/willow/src/api/server_accumulator.cc b/willow/src/api/server_accumulator.cc index 8bc148a..661bf03 100644 --- a/willow/src/api/server_accumulator.cc +++ b/willow/src/api/server_accumulator.cc @@ -23,39 +23,31 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/server_accumulator.rs.h" #include "willow/src/input_encoding/codec.h" -namespace secure_aggregation { -namespace willow { +namespace secure_aggregation::willow { absl::StatusOr> ServerAccumulator::Create( const AggregationConfigProto& aggregation_config) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::NewServerAccumulatorFromSerializedConfig( std::make_unique(aggregation_config.SerializeAsString()), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } absl::StatusOr> ServerAccumulator::CreateFromSerializedState(std::string serialized_state) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState( - std::make_unique(std::move(serialized_state)), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + secure_aggregation::NewServerAccumulatorFromSerializedState( + std::make_unique(std::move(serialized_state)), &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } @@ -68,35 +60,22 @@ absl::Status ServerAccumulator::ProcessClientMessages( absl::Status ServerAccumulator::ProcessClientMessages( std::string serialized_client_messages) { - std::unique_ptr status_message; - int status_code = accumulator_->ProcessClientMessages( - std::make_unique(std::move(serialized_client_messages)), - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(accumulator_->ProcessClientMessages( + std::make_unique(std::move(serialized_client_messages)))); return absl::OkStatus(); } absl::Status ServerAccumulator::Merge( std::unique_ptr other) { - std::unique_ptr status_message; - int status_code = - accumulator_->Merge(std::move(other->accumulator_), &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->Merge(std::move(other->accumulator_))); return absl::OkStatus(); } absl::StatusOr ServerAccumulator::ToSerializedState() { rust::Vec serialized_state; - std::unique_ptr status_message; - int status_code = - accumulator_->ToSerializedState(&serialized_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->ToSerializedState(&serialized_state)); return std::string(reinterpret_cast(serialized_state.data()), serialized_state.size()); } @@ -105,13 +84,9 @@ absl::StatusOr ServerAccumulator::Finalize() && { // Finalize accumulator in Rust and store the serialized results. rust::Vec decryption_request; rust::Vec final_result_decryptor_state; - std::unique_ptr status_message; - int status_code = secure_aggregation::FinalizeServerAccumulator( + SECAGG_RETURN_IF_FFI_ERROR(secure_aggregation::FinalizeServerAccumulator( std::move(accumulator_), &decryption_request, - &final_result_decryptor_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &final_result_decryptor_state)); // Pack the two serialized results into a single proto. FinalizedAccumulatorResult result_proto; @@ -129,15 +104,11 @@ absl::StatusOr> FinalResultDecryptor::CreateFromSerialized( std::string final_result_decryptor_state) { secure_aggregation::FinalResultDecryptor* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::CreateFinalResultDecryptorFromSerialized( std::make_unique( std::move(final_result_decryptor_state)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new FinalResultDecryptor( secure_aggregation::FinalResultDecryptorIntoBox(out))); } @@ -145,14 +116,10 @@ FinalResultDecryptor::CreateFromSerialized( absl::StatusOr FinalResultDecryptor::Decrypt( std::string serialized_partial_decryption_response) { rust::Vec out; - std::unique_ptr status_message; - int status_code = aggregated_ciphertexts_->Decrypt( + SECAGG_RETURN_IF_FFI_ERROR(aggregated_ciphertexts_->Decrypt( std::make_unique( std::move(serialized_partial_decryption_response)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); EncodedData encoded_data; for (const auto& rust_entry : out) { std::string key(rust_entry.key); @@ -166,5 +133,4 @@ absl::StatusOr FinalResultDecryptor::Decrypt( return encoded_data; } -} // namespace willow -} // namespace secure_aggregation \ No newline at end of file +} // namespace secure_aggregation::willow diff --git a/willow/src/api/server_accumulator.rs b/willow/src/api/server_accumulator.rs index ae62b8f..abb6dbb 100644 --- a/willow/src/api/server_accumulator.rs +++ b/willow/src/api/server_accumulator.rs @@ -46,11 +46,16 @@ pub mod ffi { values: Vec, } + // Re-define FfiStatus since CXX requires shared structs to be defined in the same module + // (https://github.com/dtolnay/cxx/issues/297#issuecomment-727042059) + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { type ServerAccumulator; - // We cannot use status::FfiStatus because CXX requires shared structs to be defined in the - // same module. So using separate message and pointer as a workaround. // SAFETY: All functions in this module are only called from the wrapping C++ library, // ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer // arguments are not null. @@ -59,36 +64,26 @@ pub mod ffi { unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "NewServerAccumulatorFromSerializedState"] unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ProcessClientMessages"] - unsafe fn process_client_messages_ffi( + fn process_client_messages_ffi( self: &mut ServerAccumulator, client_messages: UniquePtr, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ToSerializedState"] - unsafe fn to_serialized_state_ffi( - self: &ServerAccumulator, - out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + unsafe fn to_serialized_state_ffi(self: &ServerAccumulator, out: *mut Vec) + -> FfiStatus; #[cxx_name = "Merge"] - unsafe fn merge_ffi( - self: &mut ServerAccumulator, - other: Box, - out_status_message: *mut UniquePtr, - ) -> i32; + fn merge_ffi(self: &mut ServerAccumulator, other: Box) -> FfiStatus; #[cxx_name = "IntoBox"] unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box; @@ -100,23 +95,20 @@ pub mod ffi { accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "Decrypt"] unsafe fn decrypt_ffi( self: &mut FinalResultDecryptor, serialized_partial_decryption_response: UniquePtr, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "CreateFinalResultDecryptorFromSerialized"] unsafe fn create_final_result_decryptor_from_serialized( serialized_final_result_decryptor_state: UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "FinalResultDecryptorIntoBox"] unsafe fn final_result_decryptor_into_box( @@ -125,8 +117,6 @@ pub mod ffi { } } -use status::ffi::FfiStatus; - pub struct ServerAccumulator { // Server struct used to perform aggregation of client contributions. server: WillowV1Server, @@ -334,21 +324,11 @@ impl ServerAccumulator { Ok(()) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn process_client_messages_ffi( + pub fn process_client_messages_ffi( &mut self, client_messages: cxx::UniquePtr, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.process_client_messages_serialized(client_messages) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.process_client_messages_serialized(client_messages).into() } // Atomically merges the other accumulator into `self`. @@ -414,39 +394,16 @@ impl ServerAccumulator { }) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn merge_ffi( + pub fn merge_ffi( self: &mut ServerAccumulator, other: Box, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.merge(other) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.merge(other).into() } - pub unsafe fn to_serialized_state_ffi( - &self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.to_serialized_state() { - Ok(serialized_state) => { - *out = serialized_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + pub unsafe fn to_serialized_state_ffi(&self, out: *mut Vec) -> ffi::FfiStatus { + self.to_serialized_state().map(|result| *out = result).into() } } @@ -521,46 +478,24 @@ impl FromProto for ServerAccumulator { } } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } // SAFETY: @@ -623,26 +558,19 @@ fn finalize_accumulator(accumulator: ServerAccumulator) -> Result<(Vec, Vec< Ok((serialized_decryption_request, serialized_final_result_decryptor_state)) } -/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`, -/// `out_status_message`) must be valid for writes. +/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`) +/// must be valid for writes. pub unsafe fn finalize_accumulator_ffi( accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match finalize_accumulator(*accumulator) { - Ok((decryption_request, final_result_decryptor_state)) => { +) -> ffi::FfiStatus { + finalize_accumulator(*accumulator) + .map(|(decryption_request, final_result_decryptor_state)| { *out_decryption_request = decryption_request; *out_final_result_decryptor_state = final_result_decryptor_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + }) + .into() } impl FinalResultDecryptor { @@ -695,42 +623,22 @@ impl FinalResultDecryptor { Ok(entries) } - /// SAFETY: `out` and `out_status_message` must not be null. + /// SAFETY: `out` must be valid for writes. pub unsafe fn decrypt_ffi( &mut self, serialized_partial_decryption_response: cxx::UniquePtr, out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt(serialized_partial_decryption_response) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt(serialized_partial_decryption_response).map(|result| *out = result).into() } } -/// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. +/// SAFETY: `out` must be valid for writes. unsafe fn create_final_result_decryptor_from_serialized( serialized_proto: cxx::UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match FinalResultDecryptor::new_from_serialized(serialized_proto) { - Ok(final_result_decryptor) => { - *out = Box::into_raw(Box::new(final_result_decryptor)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + FinalResultDecryptor::new_from_serialized(serialized_proto) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index 458411d..f639f7d 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -97,6 +97,7 @@ rust_cxx_bridge( src = "shell_testing_decryptor.rs", deps = [ ":shell_testing_decryptor", + "//ffi_utils:status_cxx", ], ) @@ -109,7 +110,6 @@ rust_library( "@protobuf//rust:protobuf", "@cxx.rs//:cxx", "//ffi_utils:status", - "//shell_wrapper:shell_types_cc", "//willow/proto/willow:aggregation_config_rust_proto", "//willow/proto/willow:messages_rust_proto", "//willow/src/api:aggregation_config", @@ -136,7 +136,9 @@ cc_library( "@abseil-cpp//absl/memory", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", - "//shell_wrapper:shell_types_cc", + "@cxx.rs//:core", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/testing_utils/shell_testing_decryptor.cc b/willow/src/testing_utils/shell_testing_decryptor.cc index 1167812..8330b07 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.cc +++ b/willow/src/testing_utils/shell_testing_decryptor.cc @@ -25,7 +25,8 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "shell_wrapper/shell_types.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "willow/src/input_encoding/codec.h" #include "willow/src/testing_utils/shell_testing_decryptor.rs.h" @@ -43,13 +44,8 @@ ShellTestingDecryptor::Create( rust::Slice slice = ToRustSlice(aggregation_config_proto); ShellTestingDecryptorRust* out; - std::unique_ptr status_message; - int status_code = - create_shell_testing_decryptor(slice, &out, &status_message); + SECAGG_RETURN_IF_FFI_ERROR(create_shell_testing_decryptor(slice, &out)); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out))); } @@ -57,12 +53,7 @@ ShellTestingDecryptor::Create( absl::StatusOr ShellTestingDecryptor::GeneratePublicKey() { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_public_key(&out, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_public_key(&out)); willow::ShellAhePublicKey public_key; if (!public_key.ParseFromArray(out.data(), out.size())) { @@ -79,13 +70,7 @@ absl::StatusOr ShellTestingDecryptor::Decrypt( contribution_proto.size()); rust::Vec rust_flat_data; - std::unique_ptr status_message; - int status_code = - decryptor_->decrypt(slice, &rust_flat_data, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->decrypt(slice, &rust_flat_data)); willow::EncodedData encoded_data; for (const auto& rust_entry : rust_flat_data) { @@ -105,13 +90,8 @@ absl::StatusOr ShellTestingDecryptor::GenerateSerializedPartialDecryptionResponse( std::string serialized_partial_decryption_request) { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_partial_decryption_response( - ToRustSlice(serialized_partial_decryption_request), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_partial_decryption_response( + ToRustSlice(serialized_partial_decryption_request), &out)); return std::string(reinterpret_cast(out.data()), out.size()); } diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index ae74447..ed91583 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -29,7 +29,6 @@ use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; -use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; use std::cell::RefCell; use vahe_shell::ShellVahe; @@ -118,23 +117,9 @@ impl ShellTestingDecryptor { .map_err(|e| status::internal(format!("Serialize error: {}", e))) } - /// SAFETY: `out` and `out_status_message` must not be null. - unsafe fn generate_public_key_ffi( - &mut self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_public_key_serialized() { - Ok(pk) => { - *out = pk; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + unsafe fn generate_public_key_ffi(&mut self, out: *mut Vec) -> ffi::FfiStatus { + self.generate_public_key_serialized().map(|pk| *out = pk).into() } fn decrypt_serialized( @@ -163,24 +148,13 @@ impl ShellTestingDecryptor { Ok(entries) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn decrypt_ffi( &mut self, contribution: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt_serialized(contribution) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt_serialized(contribution).map(|result| *out = result).into() } fn generate_partial_decryption_response( @@ -219,24 +193,15 @@ impl ShellTestingDecryptor { .map_err(|e| status::internal(format!("Serialize error: {}", e))) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn generate_partial_decryption_response_ffi( &mut self, request: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_partial_decryption_response_serialized(request) { - Ok(response) => { - *out = response; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.generate_partial_decryption_response_serialized(request) + .map(|response| *out = response) + .into() } } @@ -252,6 +217,12 @@ pub mod ffi { values: Vec, } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + #[namespace = "secure_aggregation"] + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { #[cxx_name = "ShellTestingDecryptorRust"] type ShellTestingDecryptor; @@ -259,31 +230,27 @@ pub mod ffi { unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_public_key_ffi"] unsafe fn generate_public_key( self: &mut ShellTestingDecryptor, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "decrypt_ffi"] unsafe fn decrypt( self: &mut ShellTestingDecryptor, contribution: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_partial_decryption_response_ffi"] unsafe fn generate_partial_decryption_response( self: &mut ShellTestingDecryptor, request: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) -> Box; @@ -301,23 +268,14 @@ fn create_shell_testing_decryptor_impl( Ok(Box::new(decryptor)) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match create_shell_testing_decryptor_impl(config) { - Ok(decryptor) => { - *out = Box::into_raw(decryptor); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + create_shell_testing_decryptor_impl(config) + .map(|decryptor| *out = Box::into_raw(decryptor)) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw`