Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ffi_utils/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ absl::Status UnwrapFfiStatus(const FfiStatus& status) {
return absl::OkStatus();
}
absl::string_view message = "";
if (status.message != nullptr || status.message->empty()) {
if (status.message != nullptr) {
message = *status.message;
}
return absl::Status(static_cast<absl::StatusCode>(status.code), message);
Expand Down
7 changes: 6 additions & 1 deletion willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ cc_library(
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@cxx.rs//:core",
"//ffi_utils:status_cc",
"//ffi_utils:status_macros",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
Expand Down Expand Up @@ -90,6 +92,7 @@ rust_cxx_bridge(
src = "server_accumulator.rs",
deps = [
":server_accumulator",
"//ffi_utils:status_cxx",
],
)

Expand Down Expand Up @@ -153,6 +156,7 @@ rust_cxx_bridge(
src = "client.rs",
deps = [
":client",
"//ffi_utils:status_cxx",
],
)

Expand All @@ -165,7 +169,8 @@ cc_library(
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@cxx.rs//:core",
"//shell_wrapper:shell_types_cc",
"//ffi_utils:cxx_utils",
"//ffi_utils:status_macros",
"//willow/proto/shell:shell_ciphertexts_cc_proto",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
Expand Down
21 changes: 6 additions & 15 deletions willow/src/api/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "ffi_utils/cxx_utils.h"
#include "ffi_utils/status_macros.h"
#include "include/cxx.h"
#include "shell_wrapper/shell_types.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
Expand All @@ -40,12 +41,8 @@ absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
std::string config_str = aggregation_config.SerializeAsString();
auto config_ptr = std::make_unique<std::string>(std::move(config_str));
secure_aggregation::WillowShellClient* client_ptr = nullptr;
std::unique_ptr<std::string> status_message;
int status_code =
initialize_client(std::move(config_ptr), &client_ptr, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(
initialize_client(std::move(config_ptr), &client_ptr));
// Use `into_box` to avoid linker issues arising from rust::Box::from_raw.
auto client = client_into_box(client_ptr);

Expand All @@ -69,15 +66,9 @@ absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
auto key_ptr = std::make_unique<std::string>(std::move(key_str));
rust::Slice<const uint8_t> nonce_slice = ToRustSlice(nonce);
rust::Vec<uint8_t> result_bytes;
std::unique_ptr<std::string> status_message_gen;

// Encrypt data.
int status_code_gen =
generate_contribution(client, entries_slice, std::move(key_ptr),
nonce_slice, &result_bytes, &status_message_gen);
if (status_code_gen != 0) {
return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen);
}
SECAGG_RETURN_IF_FFI_ERROR(generate_contribution(
client, entries_slice, std::move(key_ptr), nonce_slice, &result_bytes));

// Parse string to ClientMessage.
willow::ClientMessage client_message;
Expand Down
50 changes: 18 additions & 32 deletions willow/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub mod ffi {
values: &'a [u64],
}

unsafe extern "C++" {
include!("ffi_utils/status.rs.h");
type FfiStatus = status::ffi::FfiStatus;
}

extern "Rust" {
// cxx: types used as extern Rust types are required to be defined by the same crate that
// contains the bridge using them
Expand All @@ -54,8 +59,7 @@ pub mod ffi {
pub unsafe fn initialize_client(
config: UniquePtr<CxxString>,
out: *mut *mut WillowShellClient,
out_status_message: *mut UniquePtr<CxxString>,
) -> i32;
) -> FfiStatus;

unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient>;

Expand All @@ -65,8 +69,7 @@ pub mod ffi {
key: UniquePtr<CxxString>,
nonce: &[u8],
out: *mut Vec<u8>,
out_status_message: *mut UniquePtr<CxxString>,
) -> i32;
) -> FfiStatus;
}
}

Expand Down Expand Up @@ -112,23 +115,14 @@ impl WillowShellClient {
}
}

/// SAFETY: `out` and `out_status_message` must not be null.
/// SAFETY: `out` must be valid for writes.
unsafe fn initialize_client(
config: cxx::UniquePtr<cxx::CxxString>,
out: *mut *mut WillowShellClient,
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
) -> i32 {
match WillowShellClient::new_from_serialized_config(config) {
Ok(client) => {
*out = Box::into_raw(Box::new(client));
0
}
Err(status_error) => {
let ffi_status: FfiStatus = status_error.into();
*out_status_message = ffi_status.message;
ffi_status.code
}
}
) -> ffi::FfiStatus {
WillowShellClient::new_from_serialized_config(config)
.map(|client| *out = Box::into_raw(Box::new(client)))
.into()
}

/// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw`
Expand All @@ -139,24 +133,16 @@ unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient>
Box::from_raw(ptr)
}

/// SAFETY: `out` and `out_status_message` must not be null.
/// SAFETY: `out` must be valid for writes.
unsafe fn generate_contribution(
client: &mut Box<WillowShellClient>,
data: &[ffi::DataEntryView],
public_key: cxx::UniquePtr<cxx::CxxString>,
nonce: &[u8],
out: *mut Vec<u8>,
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
) -> i32 {
match client.generate_contribution(data, public_key, nonce) {
Ok(contribution) => {
*out = contribution;
0
}
Err(status_error) => {
let ffi_status: FfiStatus = status_error.into();
*out_status_message = ffi_status.message;
ffi_status.code
}
}
) -> ffi::FfiStatus {
client
.generate_contribution(data, public_key, nonce)
.map(|contribution| *out = contribution)
.into()
}
74 changes: 20 additions & 54 deletions willow/src/api/server_accumulator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,31 @@
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "ffi_utils/status_macros.h"
#include "include/cxx.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
#include "willow/src/api/server_accumulator.rs.h"
#include "willow/src/input_encoding/codec.h"

namespace secure_aggregation {
namespace willow {
namespace secure_aggregation::willow {

absl::StatusOr<std::unique_ptr<ServerAccumulator>> ServerAccumulator::Create(
const AggregationConfigProto& aggregation_config) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code =
SECAGG_RETURN_IF_FFI_ERROR(
secure_aggregation::NewServerAccumulatorFromSerializedConfig(
std::make_unique<std::string>(aggregation_config.SerializeAsString()),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
&out));
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
}

absl::StatusOr<std::unique_ptr<ServerAccumulator>>
ServerAccumulator::CreateFromSerializedState(std::string serialized_state) {
secure_aggregation::ServerAccumulator* out;
std::unique_ptr<std::string> status_message;
int status_code = secure_aggregation::NewServerAccumulatorFromSerializedState(
std::make_unique<std::string>(std::move(serialized_state)), &out,
&status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(
secure_aggregation::NewServerAccumulatorFromSerializedState(
std::make_unique<std::string>(std::move(serialized_state)), &out));
return absl::WrapUnique(new ServerAccumulator(IntoBox(out)));
}

Expand All @@ -68,35 +60,22 @@ absl::Status ServerAccumulator::ProcessClientMessages(

absl::Status ServerAccumulator::ProcessClientMessages(
std::string serialized_client_messages) {
std::unique_ptr<std::string> status_message;
int status_code = accumulator_->ProcessClientMessages(
std::make_unique<std::string>(std::move(serialized_client_messages)),
&status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(accumulator_->ProcessClientMessages(
std::make_unique<std::string>(std::move(serialized_client_messages))));
return absl::OkStatus();
}

absl::Status ServerAccumulator::Merge(
std::unique_ptr<ServerAccumulator> other) {
std::unique_ptr<std::string> status_message;
int status_code =
accumulator_->Merge(std::move(other->accumulator_), &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(
accumulator_->Merge(std::move(other->accumulator_)));
return absl::OkStatus();
}

absl::StatusOr<std::string> ServerAccumulator::ToSerializedState() {
rust::Vec<uint8_t> serialized_state;
std::unique_ptr<std::string> status_message;
int status_code =
accumulator_->ToSerializedState(&serialized_state, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(
accumulator_->ToSerializedState(&serialized_state));
return std::string(reinterpret_cast<const char*>(serialized_state.data()),
serialized_state.size());
}
Expand All @@ -105,13 +84,9 @@ absl::StatusOr<FinalizedAccumulatorResult> ServerAccumulator::Finalize() && {
// Finalize accumulator in Rust and store the serialized results.
rust::Vec<uint8_t> decryption_request;
rust::Vec<uint8_t> final_result_decryptor_state;
std::unique_ptr<std::string> status_message;
int status_code = secure_aggregation::FinalizeServerAccumulator(
SECAGG_RETURN_IF_FFI_ERROR(secure_aggregation::FinalizeServerAccumulator(
std::move(accumulator_), &decryption_request,
&final_result_decryptor_state, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
&final_result_decryptor_state));

// Pack the two serialized results into a single proto.
FinalizedAccumulatorResult result_proto;
Expand All @@ -129,30 +104,22 @@ absl::StatusOr<std::unique_ptr<FinalResultDecryptor>>
FinalResultDecryptor::CreateFromSerialized(
std::string final_result_decryptor_state) {
secure_aggregation::FinalResultDecryptor* out;
std::unique_ptr<std::string> status_message;
int status_code =
SECAGG_RETURN_IF_FFI_ERROR(
secure_aggregation::CreateFinalResultDecryptorFromSerialized(
std::make_unique<std::string>(
std::move(final_result_decryptor_state)),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
&out));
return absl::WrapUnique(new FinalResultDecryptor(
secure_aggregation::FinalResultDecryptorIntoBox(out)));
}

absl::StatusOr<EncodedData> FinalResultDecryptor::Decrypt(
std::string serialized_partial_decryption_response) {
rust::Vec<EncodedDataEntry> out;
std::unique_ptr<std::string> status_message;
int status_code = aggregated_ciphertexts_->Decrypt(
SECAGG_RETURN_IF_FFI_ERROR(aggregated_ciphertexts_->Decrypt(
std::make_unique<std::string>(
std::move(serialized_partial_decryption_response)),
&out, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
&out));
EncodedData encoded_data;
for (const auto& rust_entry : out) {
std::string key(rust_entry.key);
Expand All @@ -166,5 +133,4 @@ absl::StatusOr<EncodedData> FinalResultDecryptor::Decrypt(
return encoded_data;
}

} // namespace willow
} // namespace secure_aggregation
} // namespace secure_aggregation::willow
Loading