Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ffi_utils/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ FfiStatus MakeFfiStatus(int code, std::string message) {
return FfiStatus{code, std::make_unique<std::string>(std::move(message))};
}

absl::Status UnwrapFfiStatus(const FfiStatus& status) {
absl::Status UnwrapFfiStatus(FfiStatus status) {
if (status.code == 0) {
return absl::OkStatus();
}
absl::string_view message = "";
if (status.message != nullptr || status.message->empty()) {
if (status.message != nullptr) {
message = *status.message;
}
return absl::Status(static_cast<absl::StatusCode>(status.code), message);
Expand Down
2 changes: 1 addition & 1 deletion ffi_utils/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct FfiStatus;
FfiStatus MakeFfiStatus();
FfiStatus MakeFfiStatus(absl::Status status);
FfiStatus MakeFfiStatus(int code, std::string message);
absl::Status UnwrapFfiStatus(const FfiStatus& status);
absl::Status UnwrapFfiStatus(FfiStatus status);

// To be called from Rust.
FfiStatus MakeFfiStatus(int32_t code, rust::Slice<const uint8_t> message);
Expand Down
20 changes: 8 additions & 12 deletions shell_wrapper/ahe_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,17 @@ constexpr double kSBase = 12.8;
TEST(AheTest, EncryptDecryptOne) {
// Create the public parameters.
std::unique_ptr<std::string> public_seed;
auto status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
AhePublicParameters public_params;
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters(
kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood,
ToRustSlice(*public_seed), &public_params)));

std::unique_ptr<std::string> private_seed;
status = GenerateSingleThreadHkdfSeed(private_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed)));
SingleThreadHkdfWrapper prng;
status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(
CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng)));

RnsPolynomialWrapper sk_share;
SECAGG_ASSERT_OK(
Expand Down Expand Up @@ -118,19 +116,17 @@ TEST(AheTest, EncryptDecryptOne) {
TEST(AheTest, ExternCRecoveryTest) {
// Create the public parameters.
std::unique_ptr<std::string> public_seed;
auto status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
AhePublicParameters public_params;
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateAhePublicParameters(
kLogN, kT, kQs, std::size(kQs), kPublicKeyVariance, kSBase, kSFlood,
ToRustSlice(*public_seed), &public_params)));

std::unique_ptr<std::string> private_seed;
status = GenerateSingleThreadHkdfSeed(private_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed)));
SingleThreadHkdfWrapper prng;
status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(
CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng)));

std::vector<RnsPolynomialWrapper> secret_key_shares;
std::vector<RnsPolynomialWrapper> public_key_shares;
Expand Down
65 changes: 24 additions & 41 deletions shell_wrapper/shell_types_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,11 @@ TEST(ShellTypesTest, AddErrors) {
TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) {
constexpr int num_public_polynomials = 1;
std::unique_ptr<std::string> public_seed;
FfiStatus status;
status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params_wrapper;
status = CreateKahePublicParametersWrapper(
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
ToRustSlice(*public_seed), &params_wrapper);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
ToRustSlice(*public_seed), &params_wrapper)));
ModuliWrapper moduli_wrapper =
CreateModuliWrapperFromKaheParams(params_wrapper);

Expand All @@ -176,9 +173,8 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) {
constexpr int buffer_len = 2 * kNumCoeffs;
int64_t buffer[buffer_len];
uint64_t n_written;
FfiStatus res = WriteSmallRnsPolynomialToBuffer(
&poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written);
SECAGG_EXPECT_OK(UnwrapFfiStatus(res));
SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer(
&poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written)));
EXPECT_EQ(n_written, kNumCoeffs);

// We get 1 indeed.
Expand All @@ -201,23 +197,19 @@ TEST(ShellTypesTest, WriteSmallRnsPolynomialToBufferKahe) {
poly_wrapper = {.ptr = std::make_unique<RnsPolynomial>(std::move(poly))};

// Write the polynomial to a buffer.
res = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper,
buffer_len, buffer, &n_written);
SECAGG_EXPECT_OK(UnwrapFfiStatus(res));
SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer(
&poly_wrapper, moduli_wrapper, buffer_len, buffer, &n_written)));
EXPECT_EQ(n_written, kNumCoeffs);
}

TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) {
constexpr int num_public_polynomials = 1;
std::unique_ptr<std::string> public_seed;
FfiStatus status;
status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params_wrapper;
status = CreateKahePublicParametersWrapper(
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
ToRustSlice(*public_seed), &params_wrapper);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
ToRustSlice(*public_seed), &params_wrapper)));
ModuliWrapper moduli_wrapper =
CreateModuliWrapperFromKaheParams(params_wrapper);

Expand All @@ -234,9 +226,8 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) {
}

RnsPolynomialWrapper poly{nullptr};
status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, 1 << kLogN,
moduli_wrapper, &poly);
SECAGG_EXPECT_OK(UnwrapFfiStatus(status));
SECAGG_EXPECT_OK(UnwrapFfiStatus(ReadSmallRnsPolynomialFromBuffer(
buffer, buffer_len, 1 << kLogN, moduli_wrapper, &poly)));
EXPECT_NE(poly.ptr, nullptr);

// The coefficients of the polynomial should match the ones in the
Expand Down Expand Up @@ -267,23 +258,19 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) {
// Write the polynomial back to another buffer should give the same result.
int64_t buffer_out[buffer_len];
uint64_t n_written;
status = WriteSmallRnsPolynomialToBuffer(&poly, moduli_wrapper, buffer_len,
buffer_out, &n_written);
SECAGG_EXPECT_OK(UnwrapFfiStatus(status));
SECAGG_EXPECT_OK(UnwrapFfiStatus(WriteSmallRnsPolynomialToBuffer(
&poly, moduli_wrapper, buffer_len, buffer_out, &n_written)));
EXPECT_EQ(absl::MakeSpan(buffer_out), absl::MakeSpan(buffer));
}

TEST(ShellTypesTest, ReadWriteErrors) {
constexpr int num_public_polynomials = 1;
std::unique_ptr<std::string> public_seed;
FfiStatus status;
status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params_wrapper;
status = CreateKahePublicParametersWrapper(
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
ToRustSlice(*public_seed), &params_wrapper);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
ToRustSlice(*public_seed), &params_wrapper)));
ModuliWrapper moduli_wrapper =
CreateModuliWrapperFromKaheParams(params_wrapper);

Expand All @@ -301,13 +288,13 @@ TEST(ShellTypesTest, ReadWriteErrors) {
constexpr int output_buffer_len = 2 * kNumCoeffs;
int64_t output_buffer[output_buffer_len];
uint64_t n_written;
status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper,
output_buffer_len, output_buffer,
&n_written);
auto status = WriteSmallRnsPolynomialToBuffer(&poly_wrapper, moduli_wrapper,
output_buffer_len,
output_buffer, &n_written);

// We should get an error because large coefficients don't have the same value
// mod q_1 and q_2. Note that the buffer gets filled still.
EXPECT_THAT(UnwrapFfiStatus(status),
EXPECT_THAT(UnwrapFfiStatus(std::move(status)),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("not a small polynomial")));
EXPECT_EQ(n_written, kNumCoeffs);
Expand All @@ -324,17 +311,15 @@ TEST(ShellTypesTest, ReadWriteErrors) {
long_input_buffer_len, 1 << kLogN,
moduli_wrapper, &poly_wrapper);
// We should get an error.
EXPECT_THAT(UnwrapFfiStatus(status),
EXPECT_THAT(UnwrapFfiStatus(std::move(status)),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Buffer has too many coefficients")));
}

TEST(AheTest, TestWriteRnsPolynomialToBuffer128) {
constexpr int num_public_polynomials = 1;
std::unique_ptr<std::string> public_seed;
FfiStatus status;
status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params_wrapper;
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
Expand Down Expand Up @@ -388,9 +373,7 @@ TEST(AheTest, TestWriteRnsPolynomialToBuffer128) {
TEST(AheTest, WriteRnsPolynomialToBuffer128FailsWhenBufferLenIsWrong) {
constexpr int num_public_polynomials = 1;
std::unique_ptr<std::string> public_seed;
FfiStatus status;
status = GenerateSingleThreadHkdfSeed(public_seed);
SECAGG_ASSERT_OK(UnwrapFfiStatus(status));
SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed)));
KahePublicParametersWrapper params_wrapper;
SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper(
kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials,
Expand Down
7 changes: 6 additions & 1 deletion willow/src/api/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ cc_library(
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@cxx.rs//:core",
"//ffi_utils:status_cc",
"//ffi_utils:status_macros",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
"//willow/proto/willow:server_accumulator_cc_proto",
Expand Down Expand Up @@ -90,6 +92,7 @@ rust_cxx_bridge(
src = "server_accumulator.rs",
deps = [
":server_accumulator",
"//ffi_utils:status_cxx",
],
)

Expand Down Expand Up @@ -153,6 +156,7 @@ rust_cxx_bridge(
src = "client.rs",
deps = [
":client",
"//ffi_utils:status_cxx",
],
)

Expand All @@ -165,7 +169,8 @@ cc_library(
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@cxx.rs//:core",
"//shell_wrapper:shell_types_cc",
"//ffi_utils:cxx_utils",
"//ffi_utils:status_macros",
"//willow/proto/shell:shell_ciphertexts_cc_proto",
"//willow/proto/willow:aggregation_config_cc_proto",
"//willow/proto/willow:messages_cc_proto",
Expand Down
21 changes: 6 additions & 15 deletions willow/src/api/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "ffi_utils/cxx_utils.h"
#include "ffi_utils/status_macros.h"
#include "include/cxx.h"
#include "shell_wrapper/shell_types.h"
#include "willow/proto/shell/ciphertexts.pb.h"
#include "willow/proto/willow/aggregation_config.pb.h"
#include "willow/proto/willow/server_accumulator.pb.h"
Expand All @@ -40,12 +41,8 @@ absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
std::string config_str = aggregation_config.SerializeAsString();
auto config_ptr = std::make_unique<std::string>(std::move(config_str));
secure_aggregation::WillowShellClient* client_ptr = nullptr;
std::unique_ptr<std::string> status_message;
int status_code =
initialize_client(std::move(config_ptr), &client_ptr, &status_message);
if (status_code != 0) {
return absl::Status(absl::StatusCode(status_code), *status_message);
}
SECAGG_RETURN_IF_FFI_ERROR(
initialize_client(std::move(config_ptr), &client_ptr));
// Use `into_box` to avoid linker issues arising from rust::Box::from_raw.
auto client = client_into_box(client_ptr);

Expand All @@ -69,15 +66,9 @@ absl::StatusOr<willow::ClientMessage> GenerateClientContribution(
auto key_ptr = std::make_unique<std::string>(std::move(key_str));
rust::Slice<const uint8_t> nonce_slice = ToRustSlice(nonce);
rust::Vec<uint8_t> result_bytes;
std::unique_ptr<std::string> status_message_gen;

// Encrypt data.
int status_code_gen =
generate_contribution(client, entries_slice, std::move(key_ptr),
nonce_slice, &result_bytes, &status_message_gen);
if (status_code_gen != 0) {
return absl::Status(absl::StatusCode(status_code_gen), *status_message_gen);
}
SECAGG_RETURN_IF_FFI_ERROR(generate_contribution(
client, entries_slice, std::move(key_ptr), nonce_slice, &result_bytes));

// Parse string to ClientMessage.
willow::ClientMessage client_message;
Expand Down
50 changes: 18 additions & 32 deletions willow/src/api/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ pub mod ffi {
values: &'a [u64],
}

unsafe extern "C++" {
include!("ffi_utils/status.rs.h");
type FfiStatus = status::ffi::FfiStatus;
}

extern "Rust" {
// cxx: types used as extern Rust types are required to be defined by the same crate that
// contains the bridge using them
Expand All @@ -54,8 +59,7 @@ pub mod ffi {
pub unsafe fn initialize_client(
config: UniquePtr<CxxString>,
out: *mut *mut WillowShellClient,
out_status_message: *mut UniquePtr<CxxString>,
) -> i32;
) -> FfiStatus;

unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient>;

Expand All @@ -65,8 +69,7 @@ pub mod ffi {
key: UniquePtr<CxxString>,
nonce: &[u8],
out: *mut Vec<u8>,
out_status_message: *mut UniquePtr<CxxString>,
) -> i32;
) -> FfiStatus;
}
}

Expand Down Expand Up @@ -112,23 +115,14 @@ impl WillowShellClient {
}
}

/// SAFETY: `out` and `out_status_message` must not be null.
/// SAFETY: `out` must be valid for writes.
unsafe fn initialize_client(
config: cxx::UniquePtr<cxx::CxxString>,
out: *mut *mut WillowShellClient,
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
) -> i32 {
match WillowShellClient::new_from_serialized_config(config) {
Ok(client) => {
*out = Box::into_raw(Box::new(client));
0
}
Err(status_error) => {
let ffi_status: FfiStatus = status_error.into();
*out_status_message = ffi_status.message;
ffi_status.code
}
}
) -> ffi::FfiStatus {
WillowShellClient::new_from_serialized_config(config)
.map(|client| *out = Box::into_raw(Box::new(client)))
.into()
}

/// Converts a raw pointer to a Box. Ideally we would use `rust::Box::from_raw`
Expand All @@ -139,24 +133,16 @@ unsafe fn client_into_box(ptr: *mut WillowShellClient) -> Box<WillowShellClient>
Box::from_raw(ptr)
}

/// SAFETY: `out` and `out_status_message` must not be null.
/// SAFETY: `out` must be valid for writes.
unsafe fn generate_contribution(
client: &mut Box<WillowShellClient>,
data: &[ffi::DataEntryView],
public_key: cxx::UniquePtr<cxx::CxxString>,
nonce: &[u8],
out: *mut Vec<u8>,
out_status_message: *mut cxx::UniquePtr<cxx::CxxString>,
) -> i32 {
match client.generate_contribution(data, public_key, nonce) {
Ok(contribution) => {
*out = contribution;
0
}
Err(status_error) => {
let ffi_status: FfiStatus = status_error.into();
*out_status_message = ffi_status.message;
ffi_status.code
}
}
) -> ffi::FfiStatus {
client
.generate_contribution(data, public_key, nonce)
.map(|contribution| *out = contribution)
.into()
}
Loading