Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions shell_wrapper/ahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PublicKeyRawFactory {
namespace secure_aggregation {

FfiStatus CreateAhePublicParameters(uint64_t log_n, uint64_t t,
const uint64_t* qs, size_t num_qs,
const uint64_t* qs, uint64_t num_qs,
uint64_t error_variance,
double s_base_flood, double s_flood,
rust::Slice<const uint8_t> seed,
Expand Down Expand Up @@ -193,7 +193,7 @@ FfiStatus GeneratePublicKeyShareWrapper(
public_key_share_error->ptr.get(), wraparound_ptr));
}

FfiStatus AheEncrypt(const uint64_t* input_values, size_t num_input_values,
FfiStatus AheEncrypt(const uint64_t* input_values, uint64_t num_input_values,
const RnsPolynomialWrapper& public_key_b,
const AhePublicParameters& params,
SingleThreadHkdfWrapper* prng,
Expand Down Expand Up @@ -270,8 +270,8 @@ FfiStatus PartialDecrypt(const RnsPolynomialWrapper& ciphertext_component_a,
FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions,
const RnsPolynomialWrapper& ciphertext_component_b,
const AhePublicParameters& params,
size_t output_values_length, uint64_t* output_values,
size_t* n_written) {
uint64_t output_values_length,
uint64_t* output_values, uint64_t* n_written) {
if (output_values == nullptr || sum_partial_decryptions.ptr == nullptr ||
ciphertext_component_b.ptr == nullptr || n_written == nullptr) {
return MakeFfiStatus(absl::InvalidArgumentError(
Expand All @@ -288,7 +288,8 @@ FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions,
}

// Copy messages from vector to output buffer.
*n_written = std::min(output_values_length, messages->size());
*n_written =
std::min(static_cast<size_t>(output_values_length), messages->size());
std::copy_n(messages->begin(), *n_written, output_values);
return MakeFfiStatus();
}
Expand Down
8 changes: 4 additions & 4 deletions shell_wrapper/ahe.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace secure_aggregation {
// `seed` is the seed used to generate the public parameter (polynomial) u.
// Writes the result to `out` and returns a wrapped absl::Status.
FfiStatus CreateAhePublicParameters(uint64_t log_n, uint64_t t,
const uint64_t* qs, size_t num_qs,
const uint64_t* qs, uint64_t num_qs,
uint64_t error_variance,
double s_base_flood, double s_flood,
rust::Slice<const uint8_t> seed,
Expand Down Expand Up @@ -93,7 +93,7 @@ FfiStatus GeneratePublicKeyShareWrapper(
// ciphertext in `ciphertext_component_b` (a.k.a. ct0) and
// `ciphertext_component_a` (a.k.a. ct1). Also stores the secret and error for
// ZK proofs. Returns a wrapped absl::Status.
FfiStatus AheEncrypt(const uint64_t* input_values, size_t num_input_values,
FfiStatus AheEncrypt(const uint64_t* input_values, uint64_t num_input_values,
const RnsPolynomialWrapper& public_key_b,
const AhePublicParameters& params,
SingleThreadHkdfWrapper* prng,
Expand All @@ -120,8 +120,8 @@ FfiStatus PartialDecrypt(const RnsPolynomialWrapper& ciphertext_component_a,
FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions,
const RnsPolynomialWrapper& ciphertext_component_b,
const AhePublicParameters& params,
size_t output_values_length, uint64_t* output_values,
size_t* n_written);
uint64_t output_values_length,
uint64_t* output_values, uint64_t* n_written);

// Creates a zero polynomial with the same RNS parameters as `params`. Writes
// the result to `out` and returns a wrapped absl::Status.
Expand Down
12 changes: 6 additions & 6 deletions shell_wrapper/ahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod ffi {
log_n: u64,
t: u64,
qs: *const u64,
num_qs: usize,
num_qs: u64,
error_variance: u64,
s_base_flood: f64,
s_flood: f64,
Expand Down Expand Up @@ -93,7 +93,7 @@ mod ffi {

pub unsafe fn AheEncrypt(
input_values: *const u64,
num_input_values: usize,
num_input_values: u64,
public_key_b: &RnsPolynomialWrapper,
params: &AhePublicParameters,
prng: *mut SingleThreadHkdfWrapper,
Expand All @@ -118,9 +118,9 @@ mod ffi {
sum_partial_decryptions: &RnsPolynomialWrapper,
ciphertext_component_b: &RnsPolynomialWrapper,
params: &AhePublicParameters,
num_output_values: usize,
num_output_values: u64,
output_values: *mut u64,
n_written: *mut usize,
n_written: *mut u64,
) -> FfiStatus;

pub unsafe fn CreateZeroRnsPolynomialWrapper(
Expand Down Expand Up @@ -332,8 +332,8 @@ pub fn recover_messages(
ciphertext_component_b: &RnsPolynomial,
params: &AhePublicParameters,
output_values: &mut [u64],
) -> Result<usize, status::StatusError> {
let mut n_written = 0usize;
) -> Result<u64, status::StatusError> {
let mut n_written = 0u64;
let (out_ptr, out_len) = to_cpp_pointer_len_pair_mut(output_values);
// SAFETY: No lifetime constraints (`BufferResult` just holds two ints and
// does not keep any reference to the inputs). `Decrypt` only modifies the
Expand Down
4 changes: 2 additions & 2 deletions shell_wrapper/ahe_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ TEST(AheTest, EncryptDecryptOne) {

// Recover from partial decryptions and component b
uint64_t decrypted_buffer[2 * num_messages];
size_t n_written;
uint64_t n_written;
SECAGG_ASSERT_OK(UnwrapFfiStatus(RecoverMessages(
partial_decryption, ciphertext_b, public_params,
std::size(decrypted_buffer), decrypted_buffer, &n_written)));
Expand Down Expand Up @@ -228,7 +228,7 @@ TEST(AheTest, ExternCRecoveryTest) {

// Recover from partial decryptions and component b
uint64_t decrypted_buffer[2 * num_messages];
size_t n_written;
uint64_t n_written;
SECAGG_ASSERT_OK(UnwrapFfiStatus(RecoverMessages(
sum_partial_decryptions, ciphertext_b, public_params,
std::size(decrypted_buffer), decrypted_buffer, &n_written)));
Expand Down
2 changes: 1 addition & 1 deletion shell_wrapper/ahe_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ fn encrypt_add_recover() -> Result<()> {
let mut output_values = vec![0; 2 * num_messages];
let n_written =
recover_messages(&sum_partial_decryptions, &ciphertext_b, &params, &mut output_values)?;
expect_that!(n_written, eq(num_messages));
expect_that!(n_written, eq(num_messages as u64));

// Check homomorphism.
for i in 0..num_messages {
Expand Down
2 changes: 1 addition & 1 deletion shell_wrapper/shell_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ RnsPolynomialVecWrapper RustVecToRnsPolynomialVecWrapper(
// format usable by a ZK proof library.
FfiStatus WriteSmallRnsPolynomialToBuffer(const RnsPolynomialWrapper* poly,
ModuliWrapper moduli,
size_t buffer_len, int64_t* buffer,
uint64_t buffer_len, int64_t* buffer,
uint64_t* n_written);

// Takes prime moduli {q_i}, and a buffer of `buffer_len` signed integers,
Expand Down
14 changes: 7 additions & 7 deletions shell_wrapper/shell_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub mod ffi {
pub unsafe fn WriteRnsPolynomialToBuffer128(
rns_context: *const RnsContext,
poly: *const RnsPolynomialWrapper,
buffer_len: usize,
buffer_len: u64,
buffer: *mut u64,
) -> FfiStatus;

Expand Down Expand Up @@ -159,20 +159,20 @@ pub struct RnsContextRef<'a> {
}

/// Converts slices into C++-safe pointer and length pairs.
pub fn to_cpp_pointer_len_pair<T>(s: &[T]) -> (*const T, usize) {
pub fn to_cpp_pointer_len_pair<T>(s: &[T]) -> (*const T, u64) {
if s.is_empty() {
(std::ptr::null(), 0)
} else {
(s.as_ptr(), s.len())
(s.as_ptr(), s.len() as u64)
}
}

/// Converts slices into C++-safe mutable pointer and length pairs.
pub fn to_cpp_pointer_len_pair_mut<T>(s: &mut [T]) -> (*mut T, usize) {
pub fn to_cpp_pointer_len_pair_mut<T>(s: &mut [T]) -> (*mut T, u64) {
if s.is_empty() {
(std::ptr::null_mut(), 0)
} else {
(s.as_mut_ptr(), s.len())
(s.as_mut_ptr(), s.len() as u64)
}
}

Expand All @@ -186,7 +186,7 @@ pub fn write_small_rns_polynomial_to_buffer(
poly: &RnsPolynomial,
moduli: &Moduli,
buffer: &mut [i64],
) -> Result<usize, status::StatusError> {
) -> Result<u64, status::StatusError> {
let mut n_written: u64 = 0;
// SAFETY: No lifetime constraints. Writes to `buffer` within a valid range.
rust_status_from_cpp(unsafe {
Expand All @@ -198,7 +198,7 @@ pub fn write_small_rns_polynomial_to_buffer(
&mut n_written,
)
})?;
Ok(n_written as usize)
Ok(n_written)
}

/// Takes prime moduli {q_i}, and a buffer of `buffer_len` signed integers,
Expand Down
4 changes: 2 additions & 2 deletions shell_wrapper/single_thread_hkdf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ FfiStatus Rand8(SingleThreadHkdfWrapper& prng, uint8_t& out) {
return MakeFfiStatus();
}

size_t SingleThreadHkdfSeedLength() {
uint64_t SingleThreadHkdfSeedLength() {
return rlwe::SingleThreadHkdfPrng::SeedLength();
}
FfiStatus ComputeHkdfWrapper(rust::Slice<const uint8_t> input,
rust::Slice<const uint8_t> salt,
rust::Slice<const uint8_t> info, size_t out_len,
rust::Slice<const uint8_t> info, uint64_t out_len,
std::unique_ptr<std::string>& out) {
auto statusor = crypto::tink::subtle::Hkdf::ComputeHkdf(
crypto::tink::subtle::SHA256, ToAbslStringView(input),
Expand Down
4 changes: 2 additions & 2 deletions shell_wrapper/single_thread_hkdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ FfiStatus Rand8(SingleThreadHkdfWrapper& prng, uint8_t& out);
// hash function SHA256.
FfiStatus ComputeHkdfWrapper(rust::Slice<const uint8_t> input,
rust::Slice<const uint8_t> salt,
rust::Slice<const uint8_t> info, size_t out_len,
rust::Slice<const uint8_t> info, uint64_t out_len,
std::unique_ptr<std::string>& out);

size_t SingleThreadHkdfSeedLength();
uint64_t SingleThreadHkdfSeedLength();

} // namespace secure_aggregation

Expand Down
6 changes: 3 additions & 3 deletions shell_wrapper/single_thread_hkdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ mod ffi {
pub fn CreateSingleThreadHkdf(seed: &[u8], out: &mut SingleThreadHkdfWrapper) -> FfiStatus;
pub fn Rand8(prng: &mut SingleThreadHkdfWrapper, out: &mut u8) -> FfiStatus;

pub fn SingleThreadHkdfSeedLength() -> usize;
pub fn SingleThreadHkdfSeedLength() -> u64;

pub fn ComputeHkdfWrapper(
input: &[u8],
salt: &[u8],
info: &[u8],
out_len: usize,
out_len: u64,
out: &mut UniquePtr<CxxString>,
) -> FfiStatus;

Expand Down Expand Up @@ -115,7 +115,7 @@ pub fn compute_hkdf(
input: &[u8],
salt: &[u8],
info: &[u8],
out_len: usize,
out_len: u64,
) -> Result<SeedWrapper, status::StatusError> {
let mut out = cxx::UniquePtr::null();
let status = ffi::ComputeHkdfWrapper(input.into(), salt.into(), info.into(), out_len, &mut out);
Expand Down
2 changes: 1 addition & 1 deletion willow/src/shell/ahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ impl Recover for ShellAhe {
&self.public_ahe_parameters,
&mut unsigned_values[start..end],
)?;
if n_written != end - start {
if n_written != (end - start) as u64 {
return Err(status::internal(format!(
"Expected {} recovered messages, but got {}",
end - start,
Expand Down
2 changes: 1 addition & 1 deletion willow/src/shell/kahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ impl TrySecretKeyInto<Vec<i64>> for ShellKahe {
let moduli = kahe::get_moduli(&self.public_kahe_parameters);
let n_written =
write_small_rns_polynomial_to_buffer(&sk.0, &moduli, &mut signed_values[..])?;
if n_written != self.num_coeffs {
if n_written != self.num_coeffs as u64 {
return Err(status::internal(format!(
"Expected {} coefficients, but got {}.",
self.num_coeffs, n_written
Expand Down
9 changes: 5 additions & 4 deletions willow/src/shell/vahe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ pub struct ShellVahe {

impl ShellVahe {
fn transcript_seed(&self) -> &[u8] {
&self.public_seed.as_bytes()
[single_thread_hkdf::seed_length()..2 * single_thread_hkdf::seed_length()]
let seed_len = single_thread_hkdf::seed_length() as usize;
&self.public_seed.as_bytes()[seed_len..2 * seed_len]
}

fn transcript(
Expand Down Expand Up @@ -151,19 +151,20 @@ impl AheBase for ShellVahe {
type Config = ShellAheConfig;

fn new(config: Self::Config, context_string: &[u8]) -> Result<Self, status::StatusError> {
let seed_len = single_thread_hkdf::seed_length();
let public_seed = single_thread_hkdf::compute_hkdf(
context_string,
b"",
b"ShellVahe.public_seed",
2 * single_thread_hkdf::seed_length(), // Separate seeds for transcripts and proofs.
2 * seed_len, // Separate seeds for transcripts and proofs.
)?;
let mut q = 1;
for modulus in &config.qs {
q *= *modulus as u128;
}
let ahe = ShellAhe::new(config, context_string)?;
let rlwe_zk = RlweRelationProverVerifier::new(
&public_seed.as_bytes()[..single_thread_hkdf::seed_length()],
&public_seed.as_bytes()[..seed_len as usize],
ahe.num_coeffs(),
);
Ok(ShellVahe { ahe: ahe, q: q, public_seed: public_seed, rlwe_zk: rlwe_zk })
Expand Down