diff --git a/src/encode.rs b/src/encode.rs index 27ec034..75d74ed 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -11,6 +11,11 @@ const NODE_NONE: u8 = 3; const LEAF_VALUE: u8 = 0; const LEAF_HASH: u8 = 1; +// Keys are 256-bit hashes. Every Internal node's prefix plus the +// discriminating bit that follows must fit within that budget, so the +// effective depth after entering an Internal node is bounded by 255. +const KEY_BITS: usize = 256; + /// Serializes a `SubTreeNode` into a writer. pub(crate) fn serialize_node( node: &SubTreeNode, @@ -55,6 +60,10 @@ pub(crate) fn serialize_node( /// Deserializes a `SubTreeNode` from a reader. pub(crate) fn deserialize_node(reader: &mut R) -> borsh::io::Result { + deserialize_node_at(reader, 0) +} + +fn deserialize_node_at(reader: &mut R, depth: usize) -> borsh::io::Result { let mut tag = [0u8; 1]; reader.read_exact(&mut tag)?; match tag[0] { @@ -92,8 +101,24 @@ pub(crate) fn deserialize_node(reader: &mut R) -> borsh::io::Result KEY_BITS { + return Err(IoError::new( + ErrorKind::InvalidData, + "subtree internal node exceeds key length", + )); + } + + let left = Box::new(deserialize_node_at(reader, next_depth)?); + let right = Box::new(deserialize_node_at(reader, next_depth)?); Ok(SubTreeNode::Internal { prefix, left, @@ -109,3 +134,45 @@ pub(crate) fn deserialize_node(reader: &mut R) -> borsh::io::Result Err(IoError::new(ErrorKind::InvalidData, "unknown node tag")), } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Sha256Hasher, subtree::SubTree}; + + /// A crafted subtree where two nested Internal nodes claim cumulative + /// prefixes longer than 256 bits should be rejected by the + /// deserializer rather than producing a structure that can later + /// panic during traversal. + #[test] + fn deserialize_rejects_oversized_prefix_chain() { + // Root: Internal { prefix bit_len = 130 }, descending into + // another Internal { prefix bit_len = 130 }. After the root we + // are at depth 131; the child would push us to 262, past the + // 256-bit key. + let mut bytes = alloc::vec::Vec::new(); + // Root tag + prefix (bit_len 130 -> 17 bytes of payload + 1 length byte) + bytes.push(NODE_INTERNAL); + bytes.push(130); + bytes.extend(core::iter::repeat_n(0xFFu8, 17)); + // Left child: another oversized Internal + bytes.push(NODE_INTERNAL); + bytes.push(130); + bytes.extend(core::iter::repeat_n(0xFFu8, 17)); + // Two trivial Hash leaves under the inner Internal so the structure + // is at least syntactically complete. + bytes.push(NODE_HASH); + bytes.extend(core::iter::repeat_n(0u8, 32)); + bytes.push(NODE_HASH); + bytes.extend(core::iter::repeat_n(0u8, 32)); + // Right child of root + bytes.push(NODE_HASH); + bytes.extend(core::iter::repeat_n(0u8, 32)); + + let result = SubTree::::from_slice(&bytes); + assert!( + result.is_err(), + "deserializer must reject a subtree whose nested prefixes overrun the key length" + ); + } +} diff --git a/src/path.rs b/src/path.rs index ae6a932..db76cf4 100644 --- a/src/path.rs +++ b/src/path.rs @@ -186,7 +186,12 @@ impl> PathUtils for T { } fn split_point(&self, start: usize, b: S) -> Option { - let max_bit_len = core::cmp::min(self.bit_len(), b.bit_len()); + assert!(self.bit_len() >= start, "start must be within self"); + // Cap by self's remaining bits past `start`, not by self's full + // length. Otherwise a `b` longer than the tail of self causes the + // comparison to overrun into self's padding and report a spurious + // split point. + let max_bit_len = core::cmp::min(self.bit_len() - start, b.bit_len()); let (src_start_byte, src_start_bit, seg_end_byte) = (start / 8, start % 8, max_bit_len.div_ceil(8)); let mut count = 0; @@ -353,6 +358,35 @@ mod tests { assert_eq!(parent.to_string(), "000001111101000"); } + #[test] + fn split_point_does_not_overrun_self() { + use crate::path::Path; + + // self is a 256-bit key, all ones. + let key = Path([0xFFu8; 32]); + + // A malformed segment claims bit_len=100, but only 56 bits past + // start=200 remain in self. Past those 56 bits, comparison should + // stop — anything beyond is outside self's data. + // + // Make the segment's first 56 bits match self's remaining 56 bits + // (all 0xFF). The remaining 44 bits of the segment are zero. + let mut malicious = PathSegment([0u8; 33]); + malicious.set_len(100); + for byte in malicious.0[1..=7].iter_mut() { + *byte = 0xFF; + } + + // Correct behaviour: compare 56 bits, all match -> no split point. + // Buggy behaviour: cap is 100 bits, comparison runs out of self at + // bit 56, returns Some(56) reporting a spurious divergence. + assert_eq!( + key.split_point(200, malicious), + None, + "split_point must not report a divergence beyond self's length" + ); + } + #[test] fn test_extend_from_byte() { let mut segment = PathSegment([0u8; 33]);