diff --git a/ffi_utils/BUILD b/ffi_utils/BUILD index 360b914..594316f 100644 --- a/ffi_utils/BUILD +++ b/ffi_utils/BUILD @@ -44,7 +44,10 @@ rust_library( rust_cxx_bridge( name = "status_cxx", src = "status.rs", - deps = [":status_cc"], + deps = [ + ":status_cc", + "@abseil-cpp//absl/status", + ], ) rust_test( @@ -60,6 +63,7 @@ cc_library( srcs = ["status.cc"], hdrs = ["status.h"], deps = [ + ":cxx_utils", ":status_cxx/include", "@abseil-cpp//absl/status", "@abseil-cpp//absl/strings:string_view", diff --git a/ffi_utils/status.cc b/ffi_utils/status.cc index 087c169..2ffc705 100644 --- a/ffi_utils/status.cc +++ b/ffi_utils/status.cc @@ -16,39 +16,59 @@ #include #include -#include #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "ffi_utils/cxx_utils.h" #include "ffi_utils/status.rs.h" #include "include/cxx.h" namespace secure_aggregation { -FfiStatus MakeFfiStatus() { return FfiStatus{.code = 0, .message = nullptr}; } +FfiStatus MakeFfiStatus() { return FfiStatus{nullptr}; } + FfiStatus MakeFfiStatus(absl::Status status) { - return FfiStatus{static_cast(status.code()), - std::make_unique(status.message())}; + if (status.ok()) { + return MakeFfiStatus(); + } + return FfiStatus{std::make_unique(std::move(status))}; } -FfiStatus MakeFfiStatus(int code, std::string message) { - return FfiStatus{code, std::make_unique(std::move(message))}; + +FfiStatus MakeFfiStatus(int32_t code, rust::Slice message) { + return MakeFfiStatus(absl::Status(static_cast(code), + ToAbslStringView(message))); } -absl::Status UnwrapFfiStatus(const FfiStatus& status) { - if (status.code == 0) { +FfiStatus CloneFfiStatus(const FfiStatus& status) { + if (status.ptr == nullptr) { + return MakeFfiStatus(); + } + return MakeFfiStatus(*status.ptr); +} + +absl::Status UnwrapFfiStatus(FfiStatus status) { + if (status.ptr == nullptr) { return absl::OkStatus(); } - absl::string_view message = ""; - if (status.message != nullptr || status.message->empty()) { - message = *status.message; + absl::Status out = std::move(*status.ptr); + return out; +} + +int32_t FfiStatusCode(const FfiStatus& status) { + if (status.ptr == nullptr) { + return 0; } - return absl::Status(static_cast(status.code), message); + return static_cast(status.ptr->code()); } -FfiStatus MakeFfiStatus(int32_t code, rust::Slice message) { - return secure_aggregation::MakeFfiStatus( - code, std::string(message.begin(), message.end())); +rust::Slice FfiStatusMessage(const FfiStatus& status) { + if (status.ptr == nullptr) { + return rust::Slice(); + } + return rust::Slice( + reinterpret_cast(status.ptr->message().data()), + status.ptr->message().size()); } } // namespace secure_aggregation diff --git a/ffi_utils/status.h b/ffi_utils/status.h index f6849da..906d645 100644 --- a/ffi_utils/status.h +++ b/ffi_utils/status.h @@ -27,12 +27,14 @@ struct FfiStatus; FfiStatus MakeFfiStatus(); FfiStatus MakeFfiStatus(absl::Status status); -FfiStatus MakeFfiStatus(int code, std::string message); -absl::Status UnwrapFfiStatus(const FfiStatus& status); - -// To be called from Rust. FfiStatus MakeFfiStatus(int32_t code, rust::Slice message); +FfiStatus CloneFfiStatus(const FfiStatus& status); + +absl::Status UnwrapFfiStatus(FfiStatus status); +int32_t FfiStatusCode(const FfiStatus& status); +rust::Slice FfiStatusMessage(const FfiStatus& status); + } // namespace secure_aggregation #endif // SECURE_AGGREGATION_FFI_UTILS_STATUS_H_ diff --git a/ffi_utils/status.rs b/ffi_utils/status.rs index 96f4e94..ad23d8e 100644 --- a/ffi_utils/status.rs +++ b/ffi_utils/status.rs @@ -21,24 +21,57 @@ //! corresponding to the widely used RETURN_IF_ERROR macro. use std::borrow::Cow; +use std::fmt::Debug; #[cxx::bridge(namespace = "secure_aggregation")] pub mod ffi { - // A simple Status alternative which is cxx-compatible (because it directly uses unique_ptr). + unsafe extern "C++" { + include!("absl/status/status.h"); + #[namespace = "absl"] + type Status; + } + + // A simple Status wrapper which is cxx-compatible (because it directly uses unique_ptr). pub struct FfiStatus { - pub code: i32, - pub message: UniquePtr, + // Wrapped absl::Status. A nullptr is interpreted as an OK status. + ptr: UniquePtr, } unsafe extern "C++" { include!("ffi_utils/status.h"); + #[rust_name = "make_ok_ffi_status"] + pub fn MakeFfiStatus() -> FfiStatus; + #[rust_name = "make_ffi_status"] pub fn MakeFfiStatus(code: i32, message: &[u8]) -> FfiStatus; + #[rust_name = "ffi_status_code"] + pub fn FfiStatusCode(status: &FfiStatus) -> i32; + #[rust_name = "ffi_status_message"] + pub fn FfiStatusMessage<'a>(status: &'a FfiStatus) -> &'a [u8]; + #[rust_name = "clone_ffi_status"] + pub fn CloneFfiStatus(status: &FfiStatus) -> FfiStatus; } } pub type Status = Result<(), StatusError>; pub type StatusOr = Result; +impl Clone for ffi::FfiStatus { + fn clone(&self) -> Self { + ffi::clone_ffi_status(self) + } +} + +impl Debug for ffi::FfiStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "FfiStatus({}: {})", + ffi::ffi_status_code(self), + String::from_utf8_lossy(ffi::ffi_status_message(self)) + ) + } +} + /// All cases of C++ StatusErrorCode except `StatusErrorCode::kOk`. #[derive(Debug, PartialEq, Eq, Clone, Copy)] #[non_exhaustive] @@ -62,39 +95,48 @@ pub enum StatusErrorCode { Unauthenticated = 16, } -/// Holds components of absl::Status in the error case. +/// Holds a wrapped non-OK absl::Status. /// We optionally keep a source location, but note that it cannot be passed to /// C++ yet. -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, Clone)] pub struct StatusError { - code: StatusErrorCode, - message: MaybeString, + ffi_status: ffi::FfiStatus, loc: Option<&'static core::panic::Location<'static>>, } impl StatusError { pub fn new( code: StatusErrorCode, - message: impl Into>, + message: &[u8], loc: &'static core::panic::Location<'static>, ) -> Self { - StatusError { code, message: MaybeString(message.into()), loc: Some(loc) } + StatusError { ffi_status: ffi::make_ffi_status(code as i32, message), loc: Some(loc) } + } + + pub fn from_ffi_status( + ffi_status: ffi::FfiStatus, + loc: Option<&'static core::panic::Location<'static>>, + ) -> Self { + if ffi::ffi_status_code(&ffi_status) == 0 { + panic!("Cannot create StatusError from OK status"); + } + StatusError { ffi_status, loc } } /// Create a new StatusError with no source code location. - pub fn new_untracked(code: StatusErrorCode, message: impl Into>) -> Self { - StatusError { code, message: MaybeString(message.into()), loc: None } + pub fn new_untracked(code: StatusErrorCode, message: &[u8]) -> Self { + StatusError { ffi_status: ffi::make_ffi_status(code as i32, message), loc: None } } /// Create a new StatusError pointing to the current source location. #[track_caller] - pub fn new_with_current_location(code: StatusErrorCode, message: impl Into>) -> Self { + pub fn new_with_current_location(code: StatusErrorCode, message: &[u8]) -> Self { StatusError::new_untracked(code, message).with_current_location() } /// Returns the canonical error code of this status. pub fn code(&self) -> StatusErrorCode { - self.code + ffi::ffi_status_code(&self.ffi_status).try_into().unwrap_or(StatusErrorCode::Unknown) } /// Returns the error message associated with this error code. @@ -102,12 +144,12 @@ impl StatusError { /// unusual for the error message to be the empty string. As a result, /// prefer `Display` for debug logging. pub fn message(&self) -> Cow { - String::from_utf8_lossy(&self.message.0) + String::from_utf8_lossy(self.message_bytes()) } /// Returns the raw bytes of the error message. pub fn message_bytes(&self) -> &[u8] { - &self.message.0 + ffi::ffi_status_message(&self.ffi_status) } /// Returns location of the error message. @@ -118,7 +160,7 @@ impl StatusError { /// Returns a new `StatusError` with the same code and message but pointing /// to the provided source location. pub fn with_location(self, location: &'static core::panic::Location<'static>) -> Self { - StatusError { code: self.code, message: self.message, loc: Some(location) } + StatusError { ffi_status: self.ffi_status, loc: Some(location) } } /// Returns a new `StatusError` with the same code and message but pointing @@ -132,9 +174,9 @@ impl StatusError { impl std::fmt::Display for StatusError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { if let Some(loc) = self.loc { - write!(f, "{}:{}:{}: {}", self.code.as_str(), loc.file(), loc.line(), self.message()) + write!(f, "{}:{}:{}: {}", self.code().as_str(), loc.file(), loc.line(), self.message()) } else { - write!(f, "{}: {}", self.code.as_str(), self.message()) + write!(f, "{}: {}", self.code().as_str(), self.message()) } } } @@ -240,141 +282,150 @@ impl std::error::Error for StatusErrorCodeTryFromError {} impl From for ffi::FfiStatus { fn from(error: StatusError) -> Self { - ffi::MakeFfiStatus(error.code as i32, error.message.0.as_slice()) + error.ffi_status } } impl From for ffi::FfiStatus { fn from(status: Status) -> Self { match status { - Ok(()) => ffi::FfiStatus { code: 0, message: cxx::UniquePtr::null() }, + Ok(()) => ffi::make_ok_ffi_status(), Err(error) => error.into(), } } } pub fn rust_status_from_cpp(status: ffi::FfiStatus) -> Status { - if status.code == 0 { + if ffi::ffi_status_code(&status) == 0 { Ok(()) } else { - let message = if status.message.is_null() { b"" } else { status.message.as_bytes() }; - Err(StatusError::new( - status.code.try_into().unwrap_or(StatusErrorCode::Unknown), - message, - core::panic::Location::caller(), - )) + Err(StatusError::from_ffi_status(status, Some(core::panic::Location::caller()))) } } #[track_caller] -pub fn cancelled>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Cancelled, msg.into(), core::panic::Location::caller()) +pub fn cancelled(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::Cancelled, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn unknown>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Unknown, msg.into(), core::panic::Location::caller()) +pub fn unknown(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::Unknown, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn invalid_argument>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::InvalidArgument, msg.into(), core::panic::Location::caller()) +pub fn invalid_argument(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::InvalidArgument, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[track_caller] -pub fn deadline_exceeded>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::DeadlineExceeded, msg.into(), core::panic::Location::caller()) +pub fn deadline_exceeded(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::DeadlineExceeded, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[track_caller] -pub fn not_found>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::NotFound, msg.into(), core::panic::Location::caller()) +pub fn not_found(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::NotFound, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn already_exists>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::AlreadyExists, msg.into(), core::panic::Location::caller()) +pub fn already_exists(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::AlreadyExists, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[track_caller] -pub fn permission_denied>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::PermissionDenied, msg.into(), core::panic::Location::caller()) +pub fn permission_denied(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::PermissionDenied, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[track_caller] -pub fn resource_exhausted>(msg: S) -> StatusError { +pub fn resource_exhausted(msg: &str) -> StatusError { StatusError::new( StatusErrorCode::ResourceExhausted, - msg.into(), + msg.as_bytes(), core::panic::Location::caller(), ) } #[track_caller] -pub fn failed_precondition>(msg: S) -> StatusError { +pub fn failed_precondition(msg: &str) -> StatusError { StatusError::new( StatusErrorCode::FailedPrecondition, - msg.into(), + msg.as_bytes(), core::panic::Location::caller(), ) } #[track_caller] -pub fn aborted>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Aborted, msg.into(), core::panic::Location::caller()) +pub fn aborted(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::Aborted, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn out_of_range>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::OutOfRange, msg.into(), core::panic::Location::caller()) +pub fn out_of_range(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::OutOfRange, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn unimplemented>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Unimplemented, msg.into(), core::panic::Location::caller()) +pub fn unimplemented(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::Unimplemented, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[track_caller] -pub fn internal>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Internal, msg.into(), core::panic::Location::caller()) +pub fn internal(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::Internal, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn unavailable>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Unavailable, msg.into(), core::panic::Location::caller()) +pub fn unavailable(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::Unavailable, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn data_loss>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::DataLoss, msg.into(), core::panic::Location::caller()) +pub fn data_loss(msg: &str) -> StatusError { + StatusError::new(StatusErrorCode::DataLoss, msg.as_bytes(), core::panic::Location::caller()) } #[track_caller] -pub fn unauthenticated>(msg: S) -> StatusError { - StatusError::new(StatusErrorCode::Unauthenticated, msg.into(), core::panic::Location::caller()) -} - -/// Holds a sequence of bytes that may be UTF-8. This primarily exists to give -/// it a String-like Debug implementation. -#[derive(PartialEq, Eq, Clone)] -struct MaybeString(pub Vec); - -impl std::fmt::Debug for MaybeString { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - let string = String::from_utf8_lossy(&self.0); - write!(f, "{:?}", string) - } +pub fn unauthenticated(msg: &str) -> StatusError { + StatusError::new( + StatusErrorCode::Unauthenticated, + msg.as_bytes(), + core::panic::Location::caller(), + ) } #[cfg(test)] mod tests { use super::*; + use ffi::ffi_status_code; + use ffi::ffi_status_message; use googletest::prelude::*; #[allow(dead_code)] fn compile_test() -> Status { if 0 == 1 { - return Err(cancelled(format!("bad stuff: {}", 0))); + return Err(cancelled(&format!("bad stuff: {}", 0))); } Ok(()) } @@ -390,12 +441,22 @@ mod tests { #[gtest] fn test() -> Result<()> { - match fail_whale() { - Err(StatusError { code: StatusErrorCode::Cancelled, .. }) => Ok(()), - status => fail!("unexpected status: {:?}", status), + let status = fail_whale(); + if status.is_err() && status.as_ref().err().unwrap().code() == StatusErrorCode::Cancelled { + Ok(()) + } else { + fail!("unexpected status: {:?}", status) } } + #[gtest] + fn test_clone() { + let status = ffi::make_ffi_status(1, b"test"); + let status2 = status.clone(); + expect_eq!(ffi_status_code(&status2), 1); + expect_eq!(ffi_status_message(&status2), b"test"); + } + #[gtest] fn test_try_from() { for i in 1..=16 { @@ -413,15 +474,15 @@ mod tests { #[gtest] fn test_ffi_status_from_status_error() { - let error = StatusError::new_untracked(StatusErrorCode::Cancelled, "test"); + let error = StatusError::new_untracked(StatusErrorCode::Cancelled, b"test"); let ffi_status: ffi::FfiStatus = error.into(); - expect_eq!(ffi_status.code, 1); - expect_eq!(ffi_status.message.as_bytes(), b"test"); + expect_eq!(ffi_status_code(&ffi_status), 1); + expect_eq!(ffi_status_message(&ffi_status), b"test"); } #[gtest] fn test_rust_status_from_cpp() { - let ffi_status = ffi::MakeFfiStatus(1, b"test"); + let ffi_status = ffi::make_ffi_status(1, b"test"); let rust_status = rust_status_from_cpp(ffi_status); assert!(rust_status.is_err()); expect_eq!(&rust_status.as_ref().err().unwrap().code(), &StatusErrorCode::Cancelled); @@ -432,15 +493,15 @@ mod tests { fn test_ffi_status_from_ok_status() { let rust_status = Ok(()); let ffi_status: ffi::FfiStatus = rust_status.into(); - expect_eq!(ffi_status.code, 0); - expect_eq!(ffi_status.message.is_null(), true); + expect_eq!(ffi_status_code(&ffi_status), 0); + expect_eq!(ffi_status_message(&ffi_status).is_empty(), true); } #[gtest] fn test_ffi_status_from_non_ok_status() { - let rust_status = Err(StatusError::new_untracked(StatusErrorCode::Cancelled, "test")); + let rust_status = Err(StatusError::new_untracked(StatusErrorCode::Cancelled, b"test")); let ffi_status: ffi::FfiStatus = rust_status.into(); - expect_eq!(ffi_status.code, StatusErrorCode::Cancelled as i32); - expect_eq!(ffi_status.message.as_bytes(), b"test"); + expect_eq!(ffi_status_code(&ffi_status), StatusErrorCode::Cancelled as i32); + expect_eq!(ffi_status_message(&ffi_status), b"test"); } } diff --git a/shell_wrapper/ahe_test.cc b/shell_wrapper/ahe_test.cc index a474e35..87b133f 100644 --- a/shell_wrapper/ahe_test.cc +++ b/shell_wrapper/ahe_test.cc @@ -54,19 +54,17 @@ constexpr double kSBase = 12.8; TEST(AheTest, EncryptDecryptOne) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper sk_share; SECAGG_ASSERT_OK( @@ -118,19 +116,17 @@ TEST(AheTest, EncryptDecryptOne) { TEST(AheTest, ExternCRecoveryTest) { // Create the public parameters. std::unique_ptr public_seed; - auto status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); AhePublicParameters public_params; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters( kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood, ToRustSlice(*public_seed), &public_params))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); std::vector secret_key_shares; std::vector public_key_shares; diff --git a/shell_wrapper/kahe.rs b/shell_wrapper/kahe.rs index f796c28..4f0a23d 100644 --- a/shell_wrapper/kahe.rs +++ b/shell_wrapper/kahe.rs @@ -209,7 +209,10 @@ pub fn encrypt( for (id, packed_vector_config) in packed_vector_configs.iter() { let id: &str = id; if !input_vectors.contains_key(id) { - return Err(status::invalid_argument(format!("Input vector with id {} not found", id))); + return Err(status::invalid_argument(&format!( + "Input vector with id {} not found", + id + ))); } rust_status_from_cpp(unsafe { ffi::PackMessagesRaw( diff --git a/shell_wrapper/shell_serialization.rs b/shell_wrapper/shell_serialization.rs index 2511016..de75ece 100644 --- a/shell_wrapper/shell_serialization.rs +++ b/shell_wrapper/shell_serialization.rs @@ -17,7 +17,6 @@ use protobuf::prelude::*; use rns_serialization_rust_proto::SerializedRnsPolynomial; use shell_types::{create_empty_rns_polynomial, Moduli, RnsPolynomial}; -use status::{StatusError, StatusErrorCode}; #[cxx::bridge(namespace = "secure_aggregation")] mod ffi { @@ -59,12 +58,8 @@ pub fn rns_polynomial_to_proto( rust_status_from_cpp(unsafe { ffi::SerializeRnsPolynomialToBytes(poly, moduli.moduli, &mut out) })?; - SerializedRnsPolynomial::parse(out.as_bytes()).map_err(|parse_error| { - StatusError::new_with_current_location( - StatusErrorCode::Internal, - format!("{parse_error:?}"), - ) - }) + SerializedRnsPolynomial::parse(out.as_bytes()) + .map_err(|parse_error| status::internal(&format!("{parse_error:?}"))) } // Deserialize a SerializedRnsPolynomial proto to a RnsPolynomial. @@ -72,12 +67,10 @@ pub fn rns_polynomial_from_proto( serialized: impl protobuf::AsView, moduli: &Moduli, ) -> Result { - let serialized_bytes = serialized.as_view().serialize().map_err(|serialize_error| { - StatusError::new_with_current_location( - StatusErrorCode::Internal, - format!("{serialize_error:?}"), - ) - })?; + let serialized_bytes = serialized + .as_view() + .serialize() + .map_err(|serialize_error| status::internal(&format!("{serialize_error:?}")))?; // SAFETY: No lifetime constraints (`create_empty_rns_polynomial` creates and returns an empty // C++ object). diff --git a/shell_wrapper/shell_types.cc b/shell_wrapper/shell_types.cc index 1557eaf..d04c83f 100644 --- a/shell_wrapper/shell_types.cc +++ b/shell_wrapper/shell_types.cc @@ -264,7 +264,7 @@ FfiStatus WriteRnsPolynomialToBuffer128( buffer[2 * i + 1] = absl::Uint128High64(coeff_128); } - return {}; + return MakeFfiStatus(); } RnsPolynomialWrapper CloneRnsPolynomialWrapper(const RnsPolynomialWrapper* in) { diff --git a/shell_wrapper/shell_types_test.cc b/shell_wrapper/shell_types_test.cc index 9f17742..5c001f9 100644 --- a/shell_wrapper/shell_types_test.cc +++ b/shell_wrapper/shell_types_test.cc @@ -154,14 +154,11 @@ TEST(ShellTypesTest, AddErrors) { TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -176,9 +173,8 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { constexpr int buffer_len = 2 * kNumCoeffs; int64_t buffer[buffer_len]; uint64_t n_written; - FfiStatus res = WriteSmallRnsPolynomialToBuffer( - &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); // We get 1 indeed. @@ -201,23 +197,19 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) { poly_wrapper = {.ptr = std::make_unique(std::move(poly))}; // Write the polynomial to a buffer. - res = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - buffer_len, buffer, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(res)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written))); EXPECT_EQ(n_written, kNumCoeffs); } TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -234,9 +226,8 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { } RnsPolynomialWrapper poly{nullptr}; - status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, 1 << kLogN, - moduli_wrapper, &poly); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(ReadSmallRnsPolynomialFromBuffer( + buffer, buffer_len, 1 << kLogN, moduli_wrapper, &poly))); EXPECT_NE(poly.ptr, nullptr); // The coefficients of the polynomial should match the ones in the @@ -267,23 +258,19 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { // Write the polynomial back to another buffer should give the same result. int64_t buffer_out[buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly, moduli_wrapper, buffer_len, - buffer_out, &n_written); - SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); + SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer( + &poly, moduli_wrapper, buffer_len, buffer_out, &n_written))); EXPECT_EQ(absl::MakeSpan(buffer_out), absl::MakeSpan(buffer)); } TEST(ShellTypesTest, ReadWriteErrors) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms_wrapper); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms_wrapper))); ModuliWrapper moduli_wrapper = CreateModuliWrapperFromKaheParams(params_wrapper); @@ -301,13 +288,13 @@ TEST(ShellTypesTest, ReadWriteErrors) { constexpr int output_buffer_len = 2 * kNumCoeffs; int64_t output_buffer[output_buffer_len]; uint64_t n_written; - status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, - output_buffer_len, output_buffer, - &n_written); + auto status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper, + output_buffer_len, + output_buffer, &n_written); // We should get an error because large coefficients don't have the same value // mod q_1 and q_2. Note that the buffer gets filled still. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("not a small polynomial"))); EXPECT_EQ(n_written, kNumCoeffs); @@ -324,7 +311,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { long_input_buffer_len, 1 << kLogN, moduli_wrapper, &poly_wrapper); // We should get an error. - EXPECT_THAT(UnwrapFfiStatus(status), + EXPECT_THAT(UnwrapFfiStatus(std::move(status)), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Buffer has too many coefficients"))); } @@ -332,9 +319,7 @@ TEST(ShellTypesTest, ReadWriteErrors) { TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, @@ -388,9 +373,7 @@ TEST(AheTest, TestWriteRnsPolynomialToBuffer128) { TEST(AheTest, WriteRnsPolynomialToBuffer128FailsWhenBufferLenIsWrong) { constexpr int num_public_polynomials = 1; std::unique_ptr public_seed; - FfiStatus status; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params_wrapper; SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, diff --git a/willow/src/api/BUILD b/willow/src/api/BUILD index b0ee1cf..8491704 100644 --- a/willow/src/api/BUILD +++ b/willow/src/api/BUILD @@ -59,6 +59,8 @@ cc_library( "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", "@cxx.rs//:core", + "//ffi_utils:status_cc", + "//ffi_utils:status_macros", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", "//willow/proto/willow:server_accumulator_cc_proto", @@ -90,6 +92,7 @@ rust_cxx_bridge( src = "server_accumulator.rs", deps = [ ":server_accumulator", + "//ffi_utils:status_cxx", ], ) @@ -153,6 +156,7 @@ rust_cxx_bridge( src = "client.rs", deps = [ ":client", + "//ffi_utils:status_cxx", ], ) @@ -165,7 +169,8 @@ cc_library( "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@cxx.rs//:core", - "//shell_wrapper:shell_types_cc", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/api/client.cc b/willow/src/api/client.cc index 92cd8d3..76df46f 100644 --- a/willow/src/api/client.cc +++ b/willow/src/api/client.cc @@ -22,8 +22,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" -#include "shell_wrapper/shell_types.h" #include "willow/proto/shell/ciphertexts.pb.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" @@ -40,12 +41,8 @@ absl::StatusOr GenerateClientContribution( std::string config_str = aggregation_config.SerializeAsString(); auto config_ptr = std::make_unique(std::move(config_str)); secure_aggregation::WillowShellClient* client_ptr = nullptr; - std::unique_ptr status_message; - int status_code = - initialize_client(std::move(config_ptr), &client_ptr, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + initialize_client(std::move(config_ptr), &client_ptr)); // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. auto client = client_into_box(client_ptr); @@ -69,15 +66,9 @@ absl::StatusOr GenerateClientContribution( auto key_ptr = std::make_unique(std::move(key_str)); rust::Slice nonce_slice = ToRustSlice(nonce); rust::Vec result_bytes; - std::unique_ptr status_message_gen; - // Encrypt data. - int status_code_gen = - generate_contribution(client, entries_slice, std::move(key_ptr), - nonce_slice, &result_bytes, &status_message_gen); - if (status_code_gen != 0) { - return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen); - } + SECAGG_RETURN_IF_FFI_ERROR(generate_contribution( + client, entries_slice, std::move(key_ptr), nonce_slice, &result_bytes)); // Parse string to ClientMessage. willow::ClientMessage client_message; diff --git a/willow/src/api/client.rs b/willow/src/api/client.rs index fdcb128..cd6aa19 100644 --- a/willow/src/api/client.rs +++ b/willow/src/api/client.rs @@ -46,6 +46,11 @@ pub mod ffi { values: &'a [u64], } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { // cxx: types used as extern Rust types are required to be defined by the same crate that // contains the bridge using them @@ -54,8 +59,7 @@ pub mod ffi { pub unsafe fn initialize_client( config: UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box; @@ -65,8 +69,7 @@ pub mod ffi { key: UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; } } @@ -78,7 +81,7 @@ impl WillowShellClient { ) -> Result { let aggregation_config_proto = AggregationConfigProto::parse(config.as_bytes()).map_err(|e| { - status::internal(format!("Failed to parse AggregationConfigProto: {}", e)) + status::internal(&format!("Failed to parse AggregationConfigProto: {}", e)) })?; let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?; let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config)?; @@ -98,37 +101,28 @@ impl WillowShellClient { let mut plaintext_slice: HashMap<&str, &[u64]> = HashMap::new(); for entry in data { let key = std::str::from_utf8(entry.key) - .map_err(|e| status::internal(format!("Failed to parse key as UTF-8: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse key as UTF-8: {}", e)))?; plaintext_slice.insert(key, entry.values); } let public_key_proto = ShellAhePublicKey::parse(public_key.as_bytes()) - .map_err(|e| status::internal(format!("Failed to parse ShellAhePublicKey: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse ShellAhePublicKey: {}", e)))?; let public_key_rust = PublicKey::from_proto(public_key_proto, self.0.vahe.as_ref())?; let message = self.0.create_client_message(&plaintext_slice, &public_key_rust, nonce)?; Ok(message .to_proto(&self.0)? .serialize() - .map_err(|e| status::internal(format!("Failed to serialize ClientMessage: {}", e)))?) + .map_err(|e| status::internal(&format!("Failed to serialize ClientMessage: {}", e)))?) } } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn initialize_client( config: cxx::UniquePtr, out: *mut *mut WillowShellClient, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match WillowShellClient::new_from_serialized_config(config) { - Ok(client) => { - *out = Box::into_raw(Box::new(client)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + WillowShellClient::new_from_serialized_config(config) + .map(|client| *out = Box::into_raw(Box::new(client))) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw` @@ -139,24 +133,16 @@ unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box Box::from_raw(ptr) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn generate_contribution( client: &mut Box, data: &[ffi::DataEntryView], public_key: cxx::UniquePtr, nonce: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match client.generate_contribution(data, public_key, nonce) { - Ok(contribution) => { - *out = contribution; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + client + .generate_contribution(data, public_key, nonce) + .map(|contribution| *out = contribution) + .into() } diff --git a/willow/src/api/server_accumulator.cc b/willow/src/api/server_accumulator.cc index 8bc148a..661bf03 100644 --- a/willow/src/api/server_accumulator.cc +++ b/willow/src/api/server_accumulator.cc @@ -23,39 +23,31 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ffi_utils/status_macros.h" #include "include/cxx.h" #include "willow/proto/willow/aggregation_config.pb.h" #include "willow/proto/willow/server_accumulator.pb.h" #include "willow/src/api/server_accumulator.rs.h" #include "willow/src/input_encoding/codec.h" -namespace secure_aggregation { -namespace willow { +namespace secure_aggregation::willow { absl::StatusOr> ServerAccumulator::Create( const AggregationConfigProto& aggregation_config) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::NewServerAccumulatorFromSerializedConfig( std::make_unique(aggregation_config.SerializeAsString()), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } absl::StatusOr> ServerAccumulator::CreateFromSerializedState(std::string serialized_state) { secure_aggregation::ServerAccumulator* out; - std::unique_ptr status_message; - int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState( - std::make_unique(std::move(serialized_state)), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + secure_aggregation::NewServerAccumulatorFromSerializedState( + std::make_unique(std::move(serialized_state)), &out)); return absl::WrapUnique(new ServerAccumulator(IntoBox(out))); } @@ -68,35 +60,22 @@ absl::Status ServerAccumulator::ProcessClientMessages( absl::Status ServerAccumulator::ProcessClientMessages( std::string serialized_client_messages) { - std::unique_ptr status_message; - int status_code = accumulator_->ProcessClientMessages( - std::make_unique(std::move(serialized_client_messages)), - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(accumulator_->ProcessClientMessages( + std::make_unique(std::move(serialized_client_messages)))); return absl::OkStatus(); } absl::Status ServerAccumulator::Merge( std::unique_ptr other) { - std::unique_ptr status_message; - int status_code = - accumulator_->Merge(std::move(other->accumulator_), &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->Merge(std::move(other->accumulator_))); return absl::OkStatus(); } absl::StatusOr ServerAccumulator::ToSerializedState() { rust::Vec serialized_state; - std::unique_ptr status_message; - int status_code = - accumulator_->ToSerializedState(&serialized_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR( + accumulator_->ToSerializedState(&serialized_state)); return std::string(reinterpret_cast(serialized_state.data()), serialized_state.size()); } @@ -105,13 +84,9 @@ absl::StatusOr ServerAccumulator::Finalize() && { // Finalize accumulator in Rust and store the serialized results. rust::Vec decryption_request; rust::Vec final_result_decryptor_state; - std::unique_ptr status_message; - int status_code = secure_aggregation::FinalizeServerAccumulator( + SECAGG_RETURN_IF_FFI_ERROR(secure_aggregation::FinalizeServerAccumulator( std::move(accumulator_), &decryption_request, - &final_result_decryptor_state, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &final_result_decryptor_state)); // Pack the two serialized results into a single proto. FinalizedAccumulatorResult result_proto; @@ -129,15 +104,11 @@ absl::StatusOr> FinalResultDecryptor::CreateFromSerialized( std::string final_result_decryptor_state) { secure_aggregation::FinalResultDecryptor* out; - std::unique_ptr status_message; - int status_code = + SECAGG_RETURN_IF_FFI_ERROR( secure_aggregation::CreateFinalResultDecryptorFromSerialized( std::make_unique( std::move(final_result_decryptor_state)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); return absl::WrapUnique(new FinalResultDecryptor( secure_aggregation::FinalResultDecryptorIntoBox(out))); } @@ -145,14 +116,10 @@ FinalResultDecryptor::CreateFromSerialized( absl::StatusOr FinalResultDecryptor::Decrypt( std::string serialized_partial_decryption_response) { rust::Vec out; - std::unique_ptr status_message; - int status_code = aggregated_ciphertexts_->Decrypt( + SECAGG_RETURN_IF_FFI_ERROR(aggregated_ciphertexts_->Decrypt( std::make_unique( std::move(serialized_partial_decryption_response)), - &out, &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + &out)); EncodedData encoded_data; for (const auto& rust_entry : out) { std::string key(rust_entry.key); @@ -166,5 +133,4 @@ absl::StatusOr FinalResultDecryptor::Decrypt( return encoded_data; } -} // namespace willow -} // namespace secure_aggregation \ No newline at end of file +} // namespace secure_aggregation::willow diff --git a/willow/src/api/server_accumulator.rs b/willow/src/api/server_accumulator.rs index ae62b8f..22439ee 100644 --- a/willow/src/api/server_accumulator.rs +++ b/willow/src/api/server_accumulator.rs @@ -46,11 +46,16 @@ pub mod ffi { values: Vec, } + // Re-define FfiStatus since CXX requires shared structs to be defined in the same module + // (https://github.com/dtolnay/cxx/issues/297#issuecomment-727042059) + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { type ServerAccumulator; - // We cannot use status::FfiStatus because CXX requires shared structs to be defined in the - // same module. So using separate message and pointer as a workaround. // SAFETY: All functions in this module are only called from the wrapping C++ library, // ensuring that output pointers are correctly wrapped by a rust::Box, and that pointer // arguments are not null. @@ -59,36 +64,26 @@ pub mod ffi { unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "NewServerAccumulatorFromSerializedState"] unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ProcessClientMessages"] - unsafe fn process_client_messages_ffi( + fn process_client_messages_ffi( self: &mut ServerAccumulator, client_messages: UniquePtr, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "ToSerializedState"] - unsafe fn to_serialized_state_ffi( - self: &ServerAccumulator, - out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + unsafe fn to_serialized_state_ffi(self: &ServerAccumulator, out: *mut Vec) + -> FfiStatus; #[cxx_name = "Merge"] - unsafe fn merge_ffi( - self: &mut ServerAccumulator, - other: Box, - out_status_message: *mut UniquePtr, - ) -> i32; + fn merge_ffi(self: &mut ServerAccumulator, other: Box) -> FfiStatus; #[cxx_name = "IntoBox"] unsafe fn into_box(ptr: *mut ServerAccumulator) -> Box; @@ -100,23 +95,20 @@ pub mod ffi { accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "Decrypt"] unsafe fn decrypt_ffi( self: &mut FinalResultDecryptor, serialized_partial_decryption_response: UniquePtr, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "CreateFinalResultDecryptorFromSerialized"] unsafe fn create_final_result_decryptor_from_serialized( serialized_final_result_decryptor_state: UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[cxx_name = "FinalResultDecryptorIntoBox"] unsafe fn final_result_decryptor_into_box( @@ -125,8 +117,6 @@ pub mod ffi { } } -use status::ffi::FfiStatus; - pub struct ServerAccumulator { // Server struct used to perform aggregation of client contributions. server: WillowV1Server, @@ -168,7 +158,7 @@ impl ServerAccumulator { let serialized_aggregation_config_proto = AggregationConfigProto::parse( serialized_aggregation_config.as_bytes(), ) - .map_err(|e| status::internal(format!("Failed to parse AggregationConfigProto: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse AggregationConfigProto: {}", e)))?; let aggregation_config = AggregationConfig::from_proto(serialized_aggregation_config_proto, ())?; Self::new(aggregation_config) @@ -180,7 +170,7 @@ impl ServerAccumulator { let serialized_server_accumulator_proto = ServerAccumulatorState::parse( serialized_server_accumulator.as_bytes(), ) - .map_err(|e| status::internal(format!("Failed to parse ServerAccumulatorState: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse ServerAccumulatorState: {}", e)))?; Self::from_proto(serialized_server_accumulator_proto, ()) } @@ -264,7 +254,7 @@ impl ServerAccumulator { client_messages.sort_by(|a, b| a.nonce.cmp(&b.nonce)); // Check that `nonce_range` does not overlap with any range already processed. if self.ranges_processed.overlaps(&nonce_range) { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "Invalid nonce range: [{:?}, {:?}) overlaps with already processed range", nonce_range.start, nonce_range.end, ))); @@ -276,7 +266,7 @@ impl ServerAccumulator { for message in client_messages { if message.nonce < nonce_range.start || message.nonce >= nonce_range.end { // Return immediately in case a message is out of range. - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "Invalid nonce: {:?} outside of range [{:?}, {:?})", message.nonce, nonce_range.start, nonce_range.end, ))); @@ -319,7 +309,7 @@ impl ServerAccumulator { client_messages: cxx::UniquePtr, ) -> Result<(), StatusError> { let client_range_proto = ClientMessageRange::parse(client_messages.as_bytes()) - .map_err(|e| status::internal(format!("Failed to parse ClientMessageRange: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse ClientMessageRange: {}", e)))?; std::mem::drop(client_messages); // Release memory early. `client_messages` can be huge. if !client_range_proto.client_messages().is_empty() { let client_messages: Result, _> = client_range_proto @@ -334,21 +324,11 @@ impl ServerAccumulator { Ok(()) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn process_client_messages_ffi( + pub fn process_client_messages_ffi( &mut self, client_messages: cxx::UniquePtr, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.process_client_messages_serialized(client_messages) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.process_client_messages_serialized(client_messages).into() } // Atomically merges the other accumulator into `self`. @@ -373,13 +353,13 @@ impl ServerAccumulator { if nonce_range.start != nonce_range_start { // This should not happen on properly constructed accumulators, since the order of // iteration over BTreeMap is by ascending key, same as RangeSet. - return Err(status::internal(format!( + return Err(status::internal(&format!( "Nonce range mismatch: range starts at {:?}, but verifier state indexed by {:?}", nonce_range.start, nonce_range_start, ))); } if self.ranges_processed.overlaps(&nonce_range) { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "Invalid nonce range: [{:?}, {:?}) overlaps with already processed range", nonce_range.start, nonce_range.end, ))); @@ -410,43 +390,20 @@ impl ServerAccumulator { fn to_serialized_state(&self) -> Result, StatusError> { self.to_proto(())?.serialize().map_err(|e| { - status::internal(format!("Failed to serialize ServerAccumulatorState: {}", e)) + status::internal(&format!("Failed to serialize ServerAccumulatorState: {}", e)) }) } - // SAFETY: - // - `out_status_message` must not be null. - pub unsafe fn merge_ffi( + pub fn merge_ffi( self: &mut ServerAccumulator, other: Box, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.merge(other) { - Ok(()) => 0, - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.merge(other).into() } - pub unsafe fn to_serialized_state_ffi( - &self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.to_serialized_state() { - Ok(serialized_state) => { - *out = serialized_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + pub unsafe fn to_serialized_state_ffi(&self, out: *mut Vec) -> ffi::FfiStatus { + self.to_serialized_state().map(|result| *out = result).into() } } @@ -455,7 +412,7 @@ fn nonce_range_from_proto( ) -> Result>, StatusError> { let proto = proto.as_view(); if proto.start() >= proto.end() { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "Invalid nonce range: {:?} >= {:?}", proto.start(), proto.end(), @@ -510,7 +467,7 @@ impl FromProto for ServerAccumulator { result.verifier_states.insert(nonce_range_proto.start().to_vec(), verifier_state); // Check that range insertion indeed added another range that didn't overlap. if result.ranges_processed.len() != result.verifier_states.len() { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "Invalid range: [{:?}, {:?}) either overlaps or is adjacent to another range", nonce_range_proto.start(), nonce_range_proto.end() @@ -521,46 +478,24 @@ impl FromProto for ServerAccumulator { } } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_config( serialized_aggregation_config: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_config(serialized_aggregation_config) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } -// SAFETY: -// - `out` must not be null. It must be turned into a rust::Box on the C++ side. -// - `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. It must be turned into a rust::Box on the C++ side. unsafe fn new_server_accumulator_from_serialized_state( serialized_server_accumulator: cxx::UniquePtr, out: *mut *mut ServerAccumulator, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) { - Ok(server_accumulator) => { - *out = Box::into_raw(Box::new(server_accumulator)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + ServerAccumulator::new_from_serialized_state(serialized_server_accumulator) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } // SAFETY: @@ -607,7 +542,7 @@ fn finalize_accumulator(accumulator: ServerAccumulator) -> Result<(Vec, Vec< let serialized_decryption_request = partial_decryption_request .to_proto(&accumulator.server)? .serialize() - .map_err(|e| status::internal(format!("Failed to serialize: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to serialize: {}", e)))?; // Extract the server state (i.e. sum of KAHE ciphertexts and sum of AHE ct_0 ciphertexts). let server_state_proto = accumulator.server_state.to_proto(&accumulator.server)?; @@ -618,31 +553,24 @@ fn finalize_accumulator(accumulator: ServerAccumulator) -> Result<(Vec, Vec< }); let serialized_final_result_decryptor_state = final_result_decryptor_state .serialize() - .map_err(|e| status::internal(format!("Failed to serialize: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to serialize: {}", e)))?; Ok((serialized_decryption_request, serialized_final_result_decryptor_state)) } -/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`, -/// `out_status_message`) must be valid for writes. +/// SAFETY: all pointer arguments (`out_decryption_request`, `out_final_result_decryptor_state`) +/// must be valid for writes. pub unsafe fn finalize_accumulator_ffi( accumulator: Box, out_decryption_request: *mut Vec, out_final_result_decryptor_state: *mut Vec, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match finalize_accumulator(*accumulator) { - Ok((decryption_request, final_result_decryptor_state)) => { +) -> ffi::FfiStatus { + finalize_accumulator(*accumulator) + .map(|(decryption_request, final_result_decryptor_state)| { *out_decryption_request = decryption_request; *out_final_result_decryptor_state = final_result_decryptor_state; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + }) + .into() } impl FinalResultDecryptor { @@ -652,7 +580,7 @@ impl FinalResultDecryptor { // Parse aggregation config and server state protos. let final_result_decryptor_state_proto = FinalResultDecryptorState::parse(serialized_proto.as_bytes()).map_err(|e| { - status::internal(format!("Failed to parse FinalResultDecryptorState: {}", e)) + status::internal(&format!("Failed to parse FinalResultDecryptorState: {}", e)) })?; let server_state_proto = final_result_decryptor_state_proto.server_state(); let aggregation_config_proto = final_result_decryptor_state_proto.aggregation_config(); @@ -677,7 +605,7 @@ impl FinalResultDecryptor { serialized_partial_decryption_response.as_bytes(), ) .map_err(|e| { - status::internal(format!("Failed to parse PartialDecryptionResponse: {}", e)) + status::internal(&format!("Failed to parse PartialDecryptionResponse: {}", e)) })?; let pd = PartialDecryptionResponse::from_proto(pd_proto, &self.server)?; @@ -695,42 +623,22 @@ impl FinalResultDecryptor { Ok(entries) } - /// SAFETY: `out` and `out_status_message` must not be null. + /// SAFETY: `out` must be valid for writes. pub unsafe fn decrypt_ffi( &mut self, serialized_partial_decryption_response: cxx::UniquePtr, out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt(serialized_partial_decryption_response) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt(serialized_partial_decryption_response).map(|result| *out = result).into() } } -/// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. +/// SAFETY: `out` must be valid for writes. unsafe fn create_final_result_decryptor_from_serialized( serialized_proto: cxx::UniquePtr, out: *mut *mut FinalResultDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match FinalResultDecryptor::new_from_serialized(serialized_proto) { - Ok(final_result_decryptor) => { - *out = Box::into_raw(Box::new(final_result_decryptor)); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + FinalResultDecryptor::new_from_serialized(serialized_proto) + .map(|result| *out = Box::into_raw(Box::new(result))) + .into() } diff --git a/willow/src/shell/ahe.rs b/willow/src/shell/ahe.rs index 29364f8..27ff8ae 100644 --- a/willow/src/shell/ahe.rs +++ b/willow/src/shell/ahe.rs @@ -47,7 +47,7 @@ const S_BASE_FLOOD: f64 = 12.8; fn check_vec_len(left: &Vec, right: &Vec) -> Result<(), status::StatusError> { if left.len() != right.len() { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "left and right must have the same length, got {} and {}", left.len(), right.len() @@ -659,7 +659,7 @@ impl Recover for ShellAhe { if let Some(l) = plaintext_len { let min_buffer_len = (pd.0.len() - 1) * self.num_coeffs; if l < min_buffer_len { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "received plaintext_len = {}, but the ciphertexts contain at least {} values", l, min_buffer_len ))); @@ -682,7 +682,7 @@ impl Recover for ShellAhe { &mut unsigned_values[start..end], )?; if n_written != (end - start) as u64 { - return Err(status::internal(format!( + return Err(status::internal(&format!( "Expected {} recovered messages, but got {}", end - start, n_written, diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index 6378aec..d4a52dc 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -58,7 +58,7 @@ impl ShellKahe { /// Validates KAHE parameters in ShellKaheConfig. fn validate_kahe_config(config: &ShellKaheConfig) -> Result<(), status::StatusError> { if config.log_t > BIG_INT_BITS { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "log_t must be <= {} for plaintexts to fit in the C++ big integer type, got {}", BIG_INT_BITS, config.log_t ))); @@ -68,23 +68,23 @@ impl ShellKahe { let dimension = packed_vector_config.dimension; let num_packed_coeffs = packed_vector_config.num_packed_coeffs; if base <= 1 { - return Err(status::invalid_argument(format!("base must be > 1, got {}", base))); + return Err(status::invalid_argument(&format!("base must be > 1, got {}", base))); } if dimension <= 0 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For packing id {}, dimension must be > 0, got {}", id, dimension ))); } if num_packed_coeffs <= 0 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For packing id {}, num_packed_coeffs must be > 0, got {}", id, num_packed_coeffs ))); } let log_base = (base as f64).log2().ceil() as u64; if log_base * dimension > config.log_t as u64 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For packing id {}, base^dimension must not be larger than the KAHE plaintext modulus 2^log_t+1: base = {}, dimension = {}, log_t = {}", id, base, @@ -218,7 +218,7 @@ impl KaheBase for ShellKahe { right: &mut Self::Plaintext, ) -> Result<(), status::StatusError> { if left.len() != right.len() { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "left and right must have the same length, got {} and {}", left.len(), right.len() @@ -227,7 +227,7 @@ impl KaheBase for ShellKahe { for (id, values) in left.iter() { if let Some(right_values) = right.get_mut(id) { if right_values.len() != values.len() { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "right values for key {} must have the same length as left, got {} and {}", id, right_values.len(), @@ -238,7 +238,7 @@ impl KaheBase for ShellKahe { right_values[i] += v; } } else { - return Err(status::invalid_argument(format!("right must contain key {}", id))); + return Err(status::invalid_argument(&format!("right must contain key {}", id))); } } Ok(()) @@ -275,7 +275,7 @@ impl KaheEncrypt for ShellKahe { let max_length = packed_vector_config.dimension * packed_vector_config.num_packed_coeffs; if values.len() > max_length as usize { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "plaintext for id {} can have at most {} elements, got {}", id, max_length, @@ -284,7 +284,7 @@ impl KaheEncrypt for ShellKahe { } for v in values.iter() { if *v >= packed_vector_config.base { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "plaintext for id {} cannot contain values larger than the input bound {}, got {}", id, packed_vector_config.base, @@ -293,7 +293,7 @@ impl KaheEncrypt for ShellKahe { } } } else { - return Err(status::invalid_argument(format!("unknown plaintext id {}", id))); + return Err(status::invalid_argument(&format!("unknown plaintext id {}", id))); } } @@ -330,7 +330,7 @@ impl TrySecretKeyInto> for ShellKahe { let n_written = write_small_rns_polynomial_to_buffer(&sk.0, &moduli, &mut signed_values[..])?; if n_written != self.num_coeffs as u64 { - return Err(status::internal(format!( + return Err(status::internal(&format!( "Expected {} coefficients, but got {}.", self.num_coeffs, n_written ))); @@ -346,7 +346,7 @@ impl TrySecretKeyFrom> for ShellKahe { sk_buffer: Vec, ) -> Result { if sk_buffer.len() < self.num_coeffs { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "secret key buffer is too short: {} < {}", sk_buffer.len(), self.num_coeffs diff --git a/willow/src/shell/parameters.rs b/willow/src/shell/parameters.rs index 5797727..77e1dd9 100644 --- a/willow/src/shell/parameters.rs +++ b/willow/src/shell/parameters.rs @@ -88,7 +88,7 @@ pub fn create_shell_ahe_config( max_number_of_decryptors: i64, ) -> Result { if max_number_of_decryptors > AHE_FIXED_MAX_NUM_DECRYPTORS { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "`max_number_of_decryptors` cannot be larger than {}", AHE_FIXED_MAX_NUM_DECRYPTORS ))); @@ -113,7 +113,7 @@ pub fn create_shell_kahe_config( .sum(); let max_input_bound = aggregation_config.vector_lengths_and_bounds.values().map(|(_, bound)| bound).max().ok_or( - status::invalid_argument(format!("empty vector configs in aggregation config")), + status::invalid_argument(&format!("empty vector configs in aggregation config")), )?; if total_input_length <= 1000 @@ -179,7 +179,7 @@ pub fn create_shell_kahe_config( }); } - Err(status::invalid_argument(format!( + Err(status::invalid_argument(&format!( "input setting is not supported: aggregation_config = {:?}", aggregation_config ))) diff --git a/willow/src/shell/parameters_generation.rs b/willow/src/shell/parameters_generation.rs index 0e135bc..3eabcd3 100644 --- a/willow/src/shell/parameters_generation.rs +++ b/willow/src/shell/parameters_generation.rs @@ -39,7 +39,7 @@ pub fn generate_packing_config( return Err(status::invalid_argument("`plaintext_bits` must be positive.")); } if plaintext_bits >= BIG_INT_BITS { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "`plaintext_bits` must be less than {}.", BIG_INT_BITS ))); @@ -50,13 +50,13 @@ pub fn generate_packing_config( let mut packing_configs = BTreeMap::::new(); for (id, (length, bound)) in agg_config.vector_lengths_and_bounds.iter() { if *length <= 0 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For id = {}, input length must be positive.", id ))); } if *bound <= 0 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For id = {}, input bound must be positive.", id ))); @@ -66,18 +66,18 @@ pub fn generate_packing_config( let agg_bound: i64 = agg_config.max_number_of_clients * bound; let base_bits: usize = (agg_bound as f64 + 1.0).log2().ceil() as usize; if base_bits > MAX_PACKING_BASE_BITS { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For id = {}, input bound * max_number_of_clients is too large.", id, ))); } if base_bits == 0 { - return Err(status::invalid_argument(format!("For id = {}, base bits is 0.", id,))); + return Err(status::invalid_argument(&format!("For id = {}, base bits is 0.", id,))); } let base = 1i64 << base_bits; let dimension = plaintext_bits / base_bits; if dimension == 0 { - return Err(status::invalid_argument(format!( + return Err(status::invalid_argument(&format!( "For id = {}, plaintext_bits is too small; got {}, expected at least {}.", id, plaintext_bits, base_bits ))); diff --git a/willow/src/shell/parameters_utils.rs b/willow/src/shell/parameters_utils.rs index d675ee1..b4f0203 100644 --- a/willow/src/shell/parameters_utils.rs +++ b/willow/src/shell/parameters_utils.rs @@ -54,7 +54,7 @@ fn create_human_readable_shell_config_impl( ) -> Result, status::StatusError> { let config_proto = AggregationConfigProto::parse(aggregation_config_proto.as_bytes()).map_err(|e| { - status::invalid_argument(format!("Failed to parse AggregationConfigProto: {}", e)) + status::invalid_argument(&format!("Failed to parse AggregationConfigProto: {}", e)) })?; let config = AggregationConfig::from_proto(config_proto, ())?; let (kahe_config, ahe_config) = create_shell_configs(&config)?; diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index 458411d..f639f7d 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -97,6 +97,7 @@ rust_cxx_bridge( src = "shell_testing_decryptor.rs", deps = [ ":shell_testing_decryptor", + "//ffi_utils:status_cxx", ], ) @@ -109,7 +110,6 @@ rust_library( "@protobuf//rust:protobuf", "@cxx.rs//:cxx", "//ffi_utils:status", - "//shell_wrapper:shell_types_cc", "//willow/proto/willow:aggregation_config_rust_proto", "//willow/proto/willow:messages_rust_proto", "//willow/src/api:aggregation_config", @@ -136,7 +136,9 @@ cc_library( "@abseil-cpp//absl/memory", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", - "//shell_wrapper:shell_types_cc", + "@cxx.rs//:core", + "//ffi_utils:cxx_utils", + "//ffi_utils:status_macros", "//willow/proto/shell:shell_ciphertexts_cc_proto", "//willow/proto/willow:aggregation_config_cc_proto", "//willow/proto/willow:messages_cc_proto", diff --git a/willow/src/testing_utils/shell_testing_decryptor.cc b/willow/src/testing_utils/shell_testing_decryptor.cc index 1167812..8330b07 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.cc +++ b/willow/src/testing_utils/shell_testing_decryptor.cc @@ -25,7 +25,8 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "shell_wrapper/shell_types.h" +#include "ffi_utils/cxx_utils.h" +#include "ffi_utils/status_macros.h" #include "willow/src/input_encoding/codec.h" #include "willow/src/testing_utils/shell_testing_decryptor.rs.h" @@ -43,13 +44,8 @@ ShellTestingDecryptor::Create( rust::Slice slice = ToRustSlice(aggregation_config_proto); ShellTestingDecryptorRust* out; - std::unique_ptr status_message; - int status_code = - create_shell_testing_decryptor(slice, &out, &status_message); + SECAGG_RETURN_IF_FFI_ERROR(create_shell_testing_decryptor(slice, &out)); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } // Use `into_box` to avoid linker issues arising from rust::Box::from_raw. return absl::WrapUnique(new ShellTestingDecryptor(decryptor_into_box(out))); } @@ -57,12 +53,7 @@ ShellTestingDecryptor::Create( absl::StatusOr ShellTestingDecryptor::GeneratePublicKey() { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_public_key(&out, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_public_key(&out)); willow::ShellAhePublicKey public_key; if (!public_key.ParseFromArray(out.data(), out.size())) { @@ -79,13 +70,7 @@ absl::StatusOr ShellTestingDecryptor::Decrypt( contribution_proto.size()); rust::Vec rust_flat_data; - std::unique_ptr status_message; - int status_code = - decryptor_->decrypt(slice, &rust_flat_data, &status_message); - - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->decrypt(slice, &rust_flat_data)); willow::EncodedData encoded_data; for (const auto& rust_entry : rust_flat_data) { @@ -105,13 +90,8 @@ absl::StatusOr ShellTestingDecryptor::GenerateSerializedPartialDecryptionResponse( std::string serialized_partial_decryption_request) { rust::Vec out; - std::unique_ptr status_message; - int status_code = decryptor_->generate_partial_decryption_response( - ToRustSlice(serialized_partial_decryption_request), &out, - &status_message); - if (status_code != 0) { - return absl::Status(absl::StatusCode(status_code), *status_message); - } + SECAGG_RETURN_IF_FFI_ERROR(decryptor_->generate_partial_decryption_response( + ToRustSlice(serialized_partial_decryption_request), &out)); return std::string(reinterpret_cast(out.data()), out.size()); } diff --git a/willow/src/testing_utils/shell_testing_decryptor.rs b/willow/src/testing_utils/shell_testing_decryptor.rs index ae74447..95e03f8 100644 --- a/willow/src/testing_utils/shell_testing_decryptor.rs +++ b/willow/src/testing_utils/shell_testing_decryptor.rs @@ -29,7 +29,6 @@ use prng_traits::SecurePrng; use proto_serialization_traits::{FromProto, ToProto}; use protobuf::prelude::*; use single_thread_hkdf::SingleThreadHkdfPrng; -use status::ffi::FfiStatus; use status::{StatusError, StatusErrorCode}; use std::cell::RefCell; use vahe_shell::ShellVahe; @@ -90,10 +89,7 @@ impl ShellTestingDecryptor { let rest_of_ciphertext = self.vahe.get_recover_ciphertext(&client_message.ahe_ciphertext)?; match &self.secret_key { - None => Err(StatusError::new_with_current_location( - StatusErrorCode::InvalidArgument, - "No secret key available", - )), + None => Err(status::invalid_argument("No secret key available")), Some(sk_share) => { let partial_decryption = self.vahe.partial_decrypt( &partial_dec_ciphertext, @@ -113,28 +109,14 @@ impl ShellTestingDecryptor { fn generate_public_key_serialized(&mut self) -> Result, StatusError> { let pk = self.generate_public_key()?; pk.to_proto(&self.vahe) - .map_err(|e| status::internal(format!("ToProto error: {}", e)))? + .map_err(|e| status::internal(&format!("ToProto error: {}", e)))? .serialize() - .map_err(|e| status::internal(format!("Serialize error: {}", e))) + .map_err(|e| status::internal(&format!("Serialize error: {}", e))) } - /// SAFETY: `out` and `out_status_message` must not be null. - unsafe fn generate_public_key_ffi( - &mut self, - out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_public_key_serialized() { - Ok(pk) => { - *out = pk; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + /// SAFETY: `out` must be valid for writes. + unsafe fn generate_public_key_ffi(&mut self, out: *mut Vec) -> ffi::FfiStatus { + self.generate_public_key_serialized().map(|pk| *out = pk).into() } fn decrypt_serialized( @@ -142,7 +124,7 @@ impl ShellTestingDecryptor { contribution: &[u8], ) -> Result, StatusError> { let client_message_proto = ClientMessageProto::parse(contribution) - .map_err(|e| status::internal(format!("Failed to parse ClientMessageProto: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse ClientMessageProto: {}", e)))?; let kahe_ciphertext = KaheCiphertext::from_proto(client_message_proto.kahe_ciphertext(), &self.kahe)?; @@ -163,24 +145,13 @@ impl ShellTestingDecryptor { Ok(entries) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn decrypt_ffi( &mut self, contribution: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.decrypt_serialized(contribution) { - Ok(result) => { - *out = result; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.decrypt_serialized(contribution).map(|result| *out = result).into() } fn generate_partial_decryption_response( @@ -188,10 +159,7 @@ impl ShellTestingDecryptor { request: &PartialDecryptionRequest, ) -> Result, StatusError> { match &self.secret_key { - None => Err(StatusError::new_with_current_location( - StatusErrorCode::InvalidArgument, - "No secret key available", - )), + None => Err(status::invalid_argument("No secret key available")), Some(sk_share) => { let partial_decryption = self.vahe.partial_decrypt( &request.partial_dec_ciphertext, @@ -208,35 +176,26 @@ impl ShellTestingDecryptor { request: &[u8], ) -> Result, StatusError> { let request_proto = PartialDecryptionRequestProto::parse(request).map_err(|e| { - status::internal(format!("Failed to parse PartialDecryptionRequestProto: {}", e)) + status::internal(&format!("Failed to parse PartialDecryptionRequestProto: {}", e)) })?; let request = PartialDecryptionRequest::from_proto(request_proto, self)?; let response = self.generate_partial_decryption_response(&request)?; response .to_proto(self) - .map_err(|e| status::internal(format!("ToProto error: {}", e)))? + .map_err(|e| status::internal(&format!("ToProto error: {}", e)))? .serialize() - .map_err(|e| status::internal(format!("Serialize error: {}", e))) + .map_err(|e| status::internal(&format!("Serialize error: {}", e))) } - /// SAFETY: all pointer arguments (`out`, `out_status_message`) must be valid for writes. + /// SAFETY: `out` must be valid for writes. unsafe fn generate_partial_decryption_response_ffi( &mut self, request: &[u8], out: *mut Vec, - out_status_message: *mut cxx::UniquePtr, - ) -> i32 { - match self.generate_partial_decryption_response_serialized(request) { - Ok(response) => { - *out = response; - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } + ) -> ffi::FfiStatus { + self.generate_partial_decryption_response_serialized(request) + .map(|response| *out = response) + .into() } } @@ -252,6 +211,12 @@ pub mod ffi { values: Vec, } + unsafe extern "C++" { + include!("ffi_utils/status.rs.h"); + #[namespace = "secure_aggregation"] + type FfiStatus = status::ffi::FfiStatus; + } + extern "Rust" { #[cxx_name = "ShellTestingDecryptorRust"] type ShellTestingDecryptor; @@ -259,31 +224,27 @@ pub mod ffi { unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_public_key_ffi"] unsafe fn generate_public_key( self: &mut ShellTestingDecryptor, out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "decrypt_ffi"] unsafe fn decrypt( self: &mut ShellTestingDecryptor, contribution: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; #[rust_name = "generate_partial_decryption_response_ffi"] unsafe fn generate_partial_decryption_response( self: &mut ShellTestingDecryptor, request: &[u8], out: *mut Vec, - out_status_message: *mut UniquePtr, - ) -> i32; + ) -> FfiStatus; unsafe fn decryptor_into_box(ptr: *mut ShellTestingDecryptor) -> Box; @@ -294,30 +255,21 @@ fn create_shell_testing_decryptor_impl( config: &[u8], ) -> Result, StatusError> { let aggregation_config_proto = AggregationConfigProto::parse(config) - .map_err(|e| status::internal(format!("Failed to parse AggregationConfigProto: {}", e)))?; + .map_err(|e| status::internal(&format!("Failed to parse AggregationConfigProto: {}", e)))?; let aggregation_config = AggregationConfig::from_proto(aggregation_config_proto, ())?; let context_bytes = &aggregation_config.key_id; let decryptor = ShellTestingDecryptor::new(&aggregation_config, context_bytes)?; Ok(Box::new(decryptor)) } -/// SAFETY: `out` and `out_status_message` must not be null. +/// SAFETY: `out` must be valid for writes. unsafe fn create_shell_testing_decryptor( config: &[u8], out: *mut *mut ShellTestingDecryptor, - out_status_message: *mut cxx::UniquePtr, -) -> i32 { - match create_shell_testing_decryptor_impl(config) { - Ok(decryptor) => { - *out = Box::into_raw(decryptor); - 0 - } - Err(status_error) => { - let ffi_status: FfiStatus = status_error.into(); - *out_status_message = ffi_status.message; - ffi_status.code - } - } +) -> ffi::FfiStatus { + create_shell_testing_decryptor_impl(config) + .map(|decryptor| *out = Box::into_raw(decryptor)) + .into() } /// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw` diff --git a/willow/src/testing_utils/shell_testing_parameters.rs b/willow/src/testing_utils/shell_testing_parameters.rs index 6d699ed..3c02eb2 100644 --- a/willow/src/testing_utils/shell_testing_parameters.rs +++ b/willow/src/testing_utils/shell_testing_parameters.rs @@ -67,7 +67,7 @@ pub fn make_kahe_config_for( num_public_polynomials: divide_and_roundup(total_num_coeffs, 1 << 12), }) } - _ => Err(status::invalid_argument(format!( + _ => Err(status::invalid_argument(&format!( "No KAHE configuration for plaintext_modulus_bits = {}", plaintext_modulus_bits ))), diff --git a/willow/src/willow_v1/decryptor.rs b/willow/src/willow_v1/decryptor.rs index b911c7d..65705a7 100644 --- a/willow/src/willow_v1/decryptor.rs +++ b/willow/src/willow_v1/decryptor.rs @@ -125,7 +125,7 @@ where ) -> Result, status::StatusError> { let Some(ref sk_share) = decryptor_state.sk_share else { return Err(status::failed_precondition( - "decryptor_state does not contain a secret key share".to_string(), + "decryptor_state does not contain a secret key share", )); }; // Compute the partial decryption. diff --git a/willow/src/willow_v1/server.rs b/willow/src/willow_v1/server.rs index 85f3138..1a216e7 100644 --- a/willow/src/willow_v1/server.rs +++ b/willow/src/willow_v1/server.rs @@ -189,7 +189,7 @@ where server_state: &mut Self::ServerState, ) -> Result<(), status::StatusError> { if server_state.decryptor_public_key_shares.contains_key(decryptor_id) { - return Err(status::failed_precondition(format!( + return Err(status::failed_precondition(&format!( "Public key share for decryptor with ID '{decryptor_id}' has already been handled." ))); } diff --git a/willow/src/zk/linear_ip.rs b/willow/src/zk/linear_ip.rs index 7ea38b5..3fc23c5 100644 --- a/willow/src/zk/linear_ip.rs +++ b/willow/src/zk/linear_ip.rs @@ -108,7 +108,7 @@ impl LinearInnerProductProverVerifier { ) -> Result { if a.len() != self.params.G.len() { return Err(status::permission_denied( - "Length of a doesnt match length specified at prover construction.".to_string(), + "Length of a doesnt match length specified at prover construction.", )); } let C = RistrettoPoint::vartime_multiscalar_mul( @@ -128,13 +128,12 @@ impl LinearInnerProductProverVerifier { ) -> Result { if end > self.params.G.len() { return Err(status::permission_denied( - "End of range is beyond the size speicified at prover construction.".to_string(), + "End of range is beyond the size speicified at prover construction.", )); } if (end - start) != a.len() { return Err(status::permission_denied( - "Length of a doesnt match the length of the specified range of generators." - .to_string(), + "Length of a doesnt match the length of the specified range of generators.", )); } let C = RistrettoPoint::vartime_multiscalar_mul( @@ -162,19 +161,18 @@ impl ) -> Result { if self.params.n != statement.n { return Err(status::permission_denied( - "Number of generators doesnt match length in statement".to_string(), + "Number of generators doesnt match length in statement", )); } if witness.a.len() != statement.n { return Err(status::permission_denied( - "Length of witness doesnt match length in statement".to_string(), + "Length of witness doesnt match length in statement", )); } if statement.b.len() != statement.n { return Err(status::permission_denied( - "Length of vector b in statement doesnt match value for length in statement" - .to_string(), + "Length of vector b in statement doesnt match value for length in statement", )); } diff --git a/willow/src/zk/quadratic_ip.rs b/willow/src/zk/quadratic_ip.rs index 2017983..f7efec3 100644 --- a/willow/src/zk/quadratic_ip.rs +++ b/willow/src/zk/quadratic_ip.rs @@ -114,12 +114,12 @@ impl QuadraticInnerProductProverVerifier { ) -> Result { if a.len() != self.params.n { return Err(status::permission_denied( - "Length of a doesnt match length specified at prover construction.".to_string(), + "Length of a doesnt match length specified at prover construction.", )); } if b.len() != self.params.n { return Err(status::permission_denied( - "Length of b doesnt match length specified at prover construction.".to_string(), + "Length of b doesnt match length specified at prover construction.", )); } let C = RistrettoPoint::vartime_multiscalar_mul( @@ -152,12 +152,12 @@ impl ) -> Result { if witness.a.len() != self.params.n { return Err(status::permission_denied( - "Length of witness a doesn't match length in parameters".to_string(), + "Length of witness a doesn't match length in parameters", )); } if witness.b.len() != self.params.n { return Err(status::permission_denied( - "Length of witness b doesnt match length in parameters".to_string(), + "Length of witness b doesnt match length in parameters", )); } diff --git a/willow/src/zk/rlwe_relation.rs b/willow/src/zk/rlwe_relation.rs index 60da5d9..96e99f6 100644 --- a/willow/src/zk/rlwe_relation.rs +++ b/willow/src/zk/rlwe_relation.rs @@ -163,7 +163,7 @@ fn check_statement_can_be_handled( // Check that the n of the statement matches that used at construction. if n != statement.n { return Err(status::failed_precondition( - "n in statement does not match n specified at construction.".to_string(), + "n in statement does not match n specified at construction.", )); } let q = statement.q; @@ -177,14 +177,13 @@ fn check_statement_can_be_handled( let log_gap = ceil_log_2(2500 * (usize::isqrt(n) as u128 + 1)); if log_bound_w + log_gap + log_q > 251 { return Err(status::failed_precondition( - "q^2*n^(3/2)*2500 exceeds (or almost exceeds) 2^251 so the proof would overflow." - .to_string(), + "q^2*n^(3/2)*2500 exceeds (or almost exceeds) 2^251 so the proof would overflow.", )); } let samples_required = calculate_samples_required(n, q, 128) as usize; if samples_required > MAX_RHOS { return Err(status::failed_precondition( - "Too many samples required to prove the relation. n is too close to q, if we can't use a larger q or smaller n. We could change the constant MAX_RHOS to be larger.".to_string(), + "Too many samples required to prove the relation. n is too close to q, if we can't use a larger q or smaller n. We could change the constant MAX_RHOS to be larger.", )); } Ok(()) @@ -341,9 +340,7 @@ pub fn try_matrices_and_compute_z( ) -> Result, status::StatusError> { let n = v.len(); if n != R1.len() || n != R2.len() { - return Err(status::failed_precondition( - "R1, R2, and v must have the same length".to_string(), - )); + return Err(status::failed_precondition("R1, R2, and v must have the same length")); } let mut z = vec![0 as i128; 128]; for j in 0..128 { @@ -375,7 +372,7 @@ pub fn flatten_challenge_matrix( ) -> Result<(Vec, Vec), status::StatusError> { let n = R1.len(); if n != R2.len() { - return Err(status::failed_precondition("R1 and R2 have different lengths".to_string())); + return Err(status::failed_precondition("R1 and R2 have different lengths")); } let mut Rplus = vec![0u128; n]; @@ -414,7 +411,7 @@ fn check_loose_bound_will_not_overflow(bound: u128, n: usize) -> Result<(), stat // to account for machine precision errors. if log_loose_bound > 126.99 { return Err(status::failed_precondition( - "The bound requested is too large, the product would overflow".to_string(), + "The bound requested is too large, the product would overflow", )); } Ok(()) @@ -498,12 +495,12 @@ fn generate_range_product( for x in v { if x.abs() > bound as i128 { return Err(status::failed_precondition( - "Provided vector doesn't satisfy the given bound.".to_string(), + "Provided vector doesn't satisfy the given bound.", )); } } return Err(status::internal( - "Rejection sampling failed too many times. This should never happen by chance and is likely a bug.".to_string(), + "Rejection sampling failed too many times. This should never happen by chance and is likely a bug.", )); } // Restore the backup transcript to try again. @@ -551,7 +548,7 @@ fn generate_range_product_for_verification_and_verify_z_bound( || u128::from_le_bytes(z_bytes[0..16].try_into().unwrap()) > loose_bound { return Err(status::permission_denied( - "Provided z doesn't satisfy the required uniform bound.".to_string(), + "Provided z doesn't satisfy the required uniform bound.", )); } } @@ -717,7 +714,7 @@ impl<'a> ZeroKnowledgeProver, RlweRelationProofWi // = crho + q*wrho over the integers. if (arrho_lower + erho_lower + vprho_lower) % q != crho_vec[j] { return Err(status::failed_precondition( - "The provided witness does not satisfy the relation.".to_string(), + "The provided witness does not satisfy the relation.", )); } wrho_vec[j] = arrho_upper