diff --git a/Cargo.lock b/Cargo.lock index b5eabcd4..1f3d17e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -971,7 +971,7 @@ checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "salt" -version = "1.0.0" +version = "1.0.1" dependencies = [ "ark-ff", "auto_impl", diff --git a/ipa-multipoint/src/ipa.rs b/ipa-multipoint/src/ipa.rs index 688b6fb2..8b714998 100644 --- a/ipa-multipoint/src/ipa.rs +++ b/ipa-multipoint/src/ipa.rs @@ -27,8 +27,17 @@ impl IPAProof { let mut L_vec = Vec::with_capacity(num_points as usize); let mut R_vec = Vec::with_capacity(num_points as usize); - assert_eq!(((num_points * 2) + 1) * 32, bytes.len() as u32); - assert!(bytes.len().is_multiple_of(32)); + let expected_len = ((num_points * 2) + 1) * 32; + if bytes.len() != expected_len as usize { + return Err(IOError::new( + IOErrorKind::InvalidData, + format!( + "invalid proof length, expected {} bytes, got {} bytes", + expected_len, + bytes.len() + ), + )); + } // Chunk the byte slice into 32 bytes let mut chunks = bytes.chunks_exact(32); diff --git a/ipa-multipoint/src/multiproof.rs b/ipa-multipoint/src/multiproof.rs index d10b5a41..2d4d0c85 100644 --- a/ipa-multipoint/src/multiproof.rs +++ b/ipa-multipoint/src/multiproof.rs @@ -239,8 +239,15 @@ impl MultiPointProof { pub fn from_bytes(bytes: &[u8], poly_degree: usize) -> crate::IOResult { use crate::{IOError, IOErrorKind}; + if bytes.len() < 32 { + return Err(IOError::new( + IOErrorKind::InvalidData, + "bytes length is less than 32", + )); + } + let g_x_comm_bytes = &bytes[0..32]; - let ipa_bytes = &bytes[32..]; // TODO: we should return a Result here incase the user gives us bad bytes + let ipa_bytes = &bytes[32..]; let point: Element = Element::from_bytes(g_x_comm_bytes.try_into().unwrap()) .map_err(|_| IOError::from(IOErrorKind::InvalidData))?; let g_x_comm = point; @@ -661,4 +668,12 @@ mod tests { } transcript.state } + + #[test] + fn test_from_bytes_invalid_length() { + let bytes = [0u8; 31]; + let result = MultiPointProof::from_bytes(&bytes, 256); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::InvalidData); + } }