diff --git a/shell_wrapper/ahe.cc b/shell_wrapper/ahe.cc index e7f9012..43fc811 100644 --- a/shell_wrapper/ahe.cc +++ b/shell_wrapper/ahe.cc @@ -83,7 +83,7 @@ class PublicKeyRawFactory { namespace secure_aggregation { FfiStatus CreateAhePublicParameters(uint64_t log_n, uint64_t t, - const uint64_t* qs, size_t num_qs, + const uint64_t* qs, uint64_t num_qs, uint64_t error_variance, double s_base_flood, double s_flood, rust::Slice seed, @@ -193,7 +193,7 @@ FfiStatus GeneratePublicKeyShareWrapper( public_key_share_error->ptr.get(), wraparound_ptr)); } -FfiStatus AheEncrypt(const uint64_t* input_values, size_t num_input_values, +FfiStatus AheEncrypt(const uint64_t* input_values, uint64_t num_input_values, const RnsPolynomialWrapper& public_key_b, const AhePublicParameters& params, SingleThreadHkdfWrapper* prng, @@ -270,8 +270,8 @@ FfiStatus PartialDecrypt(const RnsPolynomialWrapper& ciphertext_component_a, FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions, const RnsPolynomialWrapper& ciphertext_component_b, const AhePublicParameters& params, - size_t output_values_length, uint64_t* output_values, - size_t* n_written) { + uint64_t output_values_length, + uint64_t* output_values, uint64_t* n_written) { if (output_values == nullptr || sum_partial_decryptions.ptr == nullptr || ciphertext_component_b.ptr == nullptr || n_written == nullptr) { return MakeFfiStatus(absl::InvalidArgumentError( @@ -288,7 +288,8 @@ FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions, } // Copy messages from vector to output buffer. - *n_written = std::min(output_values_length, messages->size()); + *n_written = + std::min(static_cast(output_values_length), messages->size()); std::copy_n(messages->begin(), *n_written, output_values); return MakeFfiStatus(); } diff --git a/shell_wrapper/ahe.h b/shell_wrapper/ahe.h index e08ce3f..e908bb3 100644 --- a/shell_wrapper/ahe.h +++ b/shell_wrapper/ahe.h @@ -43,7 +43,7 @@ namespace secure_aggregation { // `seed` is the seed used to generate the public parameter (polynomial) u. // Writes the result to `out` and returns a wrapped absl::Status. FfiStatus CreateAhePublicParameters(uint64_t log_n, uint64_t t, - const uint64_t* qs, size_t num_qs, + const uint64_t* qs, uint64_t num_qs, uint64_t error_variance, double s_base_flood, double s_flood, rust::Slice seed, @@ -93,7 +93,7 @@ FfiStatus GeneratePublicKeyShareWrapper( // ciphertext in `ciphertext_component_b` (a.k.a. ct0) and // `ciphertext_component_a` (a.k.a. ct1). Also stores the secret and error for // ZK proofs. Returns a wrapped absl::Status. -FfiStatus AheEncrypt(const uint64_t* input_values, size_t num_input_values, +FfiStatus AheEncrypt(const uint64_t* input_values, uint64_t num_input_values, const RnsPolynomialWrapper& public_key_b, const AhePublicParameters& params, SingleThreadHkdfWrapper* prng, @@ -120,8 +120,8 @@ FfiStatus PartialDecrypt(const RnsPolynomialWrapper& ciphertext_component_a, FfiStatus RecoverMessages(const RnsPolynomialWrapper& sum_partial_decryptions, const RnsPolynomialWrapper& ciphertext_component_b, const AhePublicParameters& params, - size_t output_values_length, uint64_t* output_values, - size_t* n_written); + uint64_t output_values_length, + uint64_t* output_values, uint64_t* n_written); // Creates a zero polynomial with the same RNS parameters as `params`. Writes // the result to `out` and returns a wrapped absl::Status. diff --git a/shell_wrapper/ahe.rs b/shell_wrapper/ahe.rs index 7faaccb..e182c5f 100644 --- a/shell_wrapper/ahe.rs +++ b/shell_wrapper/ahe.rs @@ -59,7 +59,7 @@ mod ffi { log_n: u64, t: u64, qs: *const u64, - num_qs: usize, + num_qs: u64, error_variance: u64, s_base_flood: f64, s_flood: f64, @@ -93,7 +93,7 @@ mod ffi { pub unsafe fn AheEncrypt( input_values: *const u64, - num_input_values: usize, + num_input_values: u64, public_key_b: &RnsPolynomialWrapper, params: &AhePublicParameters, prng: *mut SingleThreadHkdfWrapper, @@ -118,9 +118,9 @@ mod ffi { sum_partial_decryptions: &RnsPolynomialWrapper, ciphertext_component_b: &RnsPolynomialWrapper, params: &AhePublicParameters, - num_output_values: usize, + num_output_values: u64, output_values: *mut u64, - n_written: *mut usize, + n_written: *mut u64, ) -> FfiStatus; pub unsafe fn CreateZeroRnsPolynomialWrapper( @@ -332,8 +332,8 @@ pub fn recover_messages( ciphertext_component_b: &RnsPolynomial, params: &AhePublicParameters, output_values: &mut [u64], -) -> Result { - let mut n_written = 0usize; +) -> Result { + let mut n_written = 0u64; let (out_ptr, out_len) = to_cpp_pointer_len_pair_mut(output_values); // SAFETY: No lifetime constraints (`BufferResult` just holds two ints and // does not keep any reference to the inputs). `Decrypt` only modifies the diff --git a/shell_wrapper/ahe_test.cc b/shell_wrapper/ahe_test.cc index 6873c56..a474e35 100644 --- a/shell_wrapper/ahe_test.cc +++ b/shell_wrapper/ahe_test.cc @@ -101,7 +101,7 @@ TEST(AheTest, EncryptDecryptOne) { // Recover from partial decryptions and component b uint64_t decrypted_buffer[2 * num_messages]; - size_t n_written; + uint64_t n_written; SECAGG_ASSERT_OK(UnwrapFfiStatus(RecoverMessages( partial_decryption, ciphertext_b, public_params, std::size(decrypted_buffer), decrypted_buffer, &n_written))); @@ -228,7 +228,7 @@ TEST(AheTest, ExternCRecoveryTest) { // Recover from partial decryptions and component b uint64_t decrypted_buffer[2 * num_messages]; - size_t n_written; + uint64_t n_written; SECAGG_ASSERT_OK(UnwrapFfiStatus(RecoverMessages( sum_partial_decryptions, ciphertext_b, public_params, std::size(decrypted_buffer), decrypted_buffer, &n_written))); diff --git a/shell_wrapper/ahe_test.rs b/shell_wrapper/ahe_test.rs index 0889109..9882cb2 100644 --- a/shell_wrapper/ahe_test.rs +++ b/shell_wrapper/ahe_test.rs @@ -143,7 +143,7 @@ fn encrypt_add_recover() -> Result<()> { let mut output_values = vec![0; 2 * num_messages]; let n_written = recover_messages(&sum_partial_decryptions, &ciphertext_b, ¶ms, &mut output_values)?; - expect_that!(n_written, eq(num_messages)); + expect_that!(n_written, eq(num_messages as u64)); // Check homomorphism. for i in 0..num_messages { diff --git a/shell_wrapper/shell_types.h b/shell_wrapper/shell_types.h index c497d15..0b30817 100644 --- a/shell_wrapper/shell_types.h +++ b/shell_wrapper/shell_types.h @@ -81,7 +81,7 @@ RnsPolynomialVecWrapper RustVecToRnsPolynomialVecWrapper( // format usable by a ZK proof library. FfiStatus WriteSmallRnsPolynomialToBuffer(const RnsPolynomialWrapper* poly, ModuliWrapper moduli, - size_t buffer_len, int64_t* buffer, + uint64_t buffer_len, int64_t* buffer, uint64_t* n_written); // Takes prime moduli {q_i}, and a buffer of `buffer_len` signed integers, diff --git a/shell_wrapper/shell_types.rs b/shell_wrapper/shell_types.rs index 7e2f4e3..2ec29e5 100644 --- a/shell_wrapper/shell_types.rs +++ b/shell_wrapper/shell_types.rs @@ -79,7 +79,7 @@ pub mod ffi { pub unsafe fn WriteRnsPolynomialToBuffer128( rns_context: *const RnsContext, poly: *const RnsPolynomialWrapper, - buffer_len: usize, + buffer_len: u64, buffer: *mut u64, ) -> FfiStatus; @@ -159,20 +159,20 @@ pub struct RnsContextRef<'a> { } /// Converts slices into C++-safe pointer and length pairs. -pub fn to_cpp_pointer_len_pair(s: &[T]) -> (*const T, usize) { +pub fn to_cpp_pointer_len_pair(s: &[T]) -> (*const T, u64) { if s.is_empty() { (std::ptr::null(), 0) } else { - (s.as_ptr(), s.len()) + (s.as_ptr(), s.len() as u64) } } /// Converts slices into C++-safe mutable pointer and length pairs. -pub fn to_cpp_pointer_len_pair_mut(s: &mut [T]) -> (*mut T, usize) { +pub fn to_cpp_pointer_len_pair_mut(s: &mut [T]) -> (*mut T, u64) { if s.is_empty() { (std::ptr::null_mut(), 0) } else { - (s.as_mut_ptr(), s.len()) + (s.as_mut_ptr(), s.len() as u64) } } @@ -186,7 +186,7 @@ pub fn write_small_rns_polynomial_to_buffer( poly: &RnsPolynomial, moduli: &Moduli, buffer: &mut [i64], -) -> Result { +) -> Result { let mut n_written: u64 = 0; // SAFETY: No lifetime constraints. Writes to `buffer` within a valid range. rust_status_from_cpp(unsafe { @@ -198,7 +198,7 @@ pub fn write_small_rns_polynomial_to_buffer( &mut n_written, ) })?; - Ok(n_written as usize) + Ok(n_written) } /// Takes prime moduli {q_i}, and a buffer of `buffer_len` signed integers, diff --git a/shell_wrapper/single_thread_hkdf.cc b/shell_wrapper/single_thread_hkdf.cc index e8b431f..58aaf08 100644 --- a/shell_wrapper/single_thread_hkdf.cc +++ b/shell_wrapper/single_thread_hkdf.cc @@ -66,12 +66,12 @@ FfiStatus Rand8(SingleThreadHkdfWrapper& prng, uint8_t& out) { return MakeFfiStatus(); } -size_t SingleThreadHkdfSeedLength() { +uint64_t SingleThreadHkdfSeedLength() { return rlwe::SingleThreadHkdfPrng::SeedLength(); } FfiStatus ComputeHkdfWrapper(rust::Slice input, rust::Slice salt, - rust::Slice info, size_t out_len, + rust::Slice info, uint64_t out_len, std::unique_ptr& out) { auto statusor = crypto::tink::subtle::Hkdf::ComputeHkdf( crypto::tink::subtle::SHA256, ToAbslStringView(input), diff --git a/shell_wrapper/single_thread_hkdf.h b/shell_wrapper/single_thread_hkdf.h index 1e488fa..79137c5 100644 --- a/shell_wrapper/single_thread_hkdf.h +++ b/shell_wrapper/single_thread_hkdf.h @@ -41,10 +41,10 @@ FfiStatus Rand8(SingleThreadHkdfWrapper& prng, uint8_t& out); // hash function SHA256. FfiStatus ComputeHkdfWrapper(rust::Slice input, rust::Slice salt, - rust::Slice info, size_t out_len, + rust::Slice info, uint64_t out_len, std::unique_ptr& out); -size_t SingleThreadHkdfSeedLength(); +uint64_t SingleThreadHkdfSeedLength(); } // namespace secure_aggregation diff --git a/shell_wrapper/single_thread_hkdf.rs b/shell_wrapper/single_thread_hkdf.rs index b1c92a5..8383a18 100644 --- a/shell_wrapper/single_thread_hkdf.rs +++ b/shell_wrapper/single_thread_hkdf.rs @@ -34,13 +34,13 @@ mod ffi { pub fn CreateSingleThreadHkdf(seed: &[u8], out: &mut SingleThreadHkdfWrapper) -> FfiStatus; pub fn Rand8(prng: &mut SingleThreadHkdfWrapper, out: &mut u8) -> FfiStatus; - pub fn SingleThreadHkdfSeedLength() -> usize; + pub fn SingleThreadHkdfSeedLength() -> u64; pub fn ComputeHkdfWrapper( input: &[u8], salt: &[u8], info: &[u8], - out_len: usize, + out_len: u64, out: &mut UniquePtr, ) -> FfiStatus; @@ -115,7 +115,7 @@ pub fn compute_hkdf( input: &[u8], salt: &[u8], info: &[u8], - out_len: usize, + out_len: u64, ) -> Result { let mut out = cxx::UniquePtr::null(); let status = ffi::ComputeHkdfWrapper(input.into(), salt.into(), info.into(), out_len, &mut out); diff --git a/willow/src/shell/ahe.rs b/willow/src/shell/ahe.rs index e152c35..29364f8 100644 --- a/willow/src/shell/ahe.rs +++ b/willow/src/shell/ahe.rs @@ -681,7 +681,7 @@ impl Recover for ShellAhe { &self.public_ahe_parameters, &mut unsigned_values[start..end], )?; - if n_written != end - start { + if n_written != (end - start) as u64 { return Err(status::internal(format!( "Expected {} recovered messages, but got {}", end - start, diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index ad285f1..6378aec 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -329,7 +329,7 @@ impl TrySecretKeyInto> for ShellKahe { let moduli = kahe::get_moduli(&self.public_kahe_parameters); let n_written = write_small_rns_polynomial_to_buffer(&sk.0, &moduli, &mut signed_values[..])?; - if n_written != self.num_coeffs { + if n_written != self.num_coeffs as u64 { return Err(status::internal(format!( "Expected {} coefficients, but got {}.", self.num_coeffs, n_written diff --git a/willow/src/shell/vahe.rs b/willow/src/shell/vahe.rs index 9a5ca2f..0845afe 100644 --- a/willow/src/shell/vahe.rs +++ b/willow/src/shell/vahe.rs @@ -105,8 +105,8 @@ pub struct ShellVahe { impl ShellVahe { fn transcript_seed(&self) -> &[u8] { - &self.public_seed.as_bytes() - [single_thread_hkdf::seed_length()..2 * single_thread_hkdf::seed_length()] + let seed_len = single_thread_hkdf::seed_length() as usize; + &self.public_seed.as_bytes()[seed_len..2 * seed_len] } fn transcript( @@ -151,11 +151,12 @@ impl AheBase for ShellVahe { type Config = ShellAheConfig; fn new(config: Self::Config, context_string: &[u8]) -> Result { + let seed_len = single_thread_hkdf::seed_length(); let public_seed = single_thread_hkdf::compute_hkdf( context_string, b"", b"ShellVahe.public_seed", - 2 * single_thread_hkdf::seed_length(), // Separate seeds for transcripts and proofs. + 2 * seed_len, // Separate seeds for transcripts and proofs. )?; let mut q = 1; for modulus in &config.qs { @@ -163,7 +164,7 @@ impl AheBase for ShellVahe { } let ahe = ShellAhe::new(config, context_string)?; let rlwe_zk = RlweRelationProverVerifier::new( - &public_seed.as_bytes()[..single_thread_hkdf::seed_length()], + &public_seed.as_bytes()[..seed_len as usize], ahe.num_coeffs(), ); Ok(ShellVahe { ahe: ahe, q: q, public_seed: public_seed, rlwe_zk: rlwe_zk })