diff --git a/.gitignore b/.gitignore index 55d8a4c3a..de49443ab 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,6 @@ .vscode /docs/benchmark_graphs/.venv minimal_zkVM.synctex.gz +crates/rec_aggregation/test_data .claude -misc/.build \ No newline at end of file +misc/.build diff --git a/Cargo.lock b/Cargo.lock index a1e508d94..10209907f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,43 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloy-primitives" +version = "1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3b431b4e72cd8bd0ec7a50b4be18e73dab74de0dba180eef171055e5d5926e" +dependencies = [ + "alloy-rlp", + "bytes", + "cfg-if", + "const-hex", + "derive_more", + "foldhash 0.2.0", + "hashbrown 0.16.1", + "indexmap", + "itoa", + "k256", + "keccak-asm", + "paste", + "proptest", + "rand 0.9.4", + "rapidhash", + "ruint", + "rustc-hash", + "serde", + "sha3 0.10.9", +] + +[[package]] +name = "alloy-rlp" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc90b1e703d3c03f4ff7f48e82dd0bc1c8211ab7d079cd836a06fcfeb06651cb" +dependencies = [ + "arrayvec", + "bytes", +] + [[package]] name = "ansi_term" version = "0.12.1" @@ -76,6 +113,201 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "ark-ff" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b3235cc41ee7a12aaaf2c575a2ad7b46713a8a50bda2fc3b003a04845c05dd6" +dependencies = [ + "ark-ff-asm 0.3.0", + "ark-ff-macros 0.3.0", + "ark-serialize 0.3.0", + "ark-std 0.3.0", + "derivative", + "num-bigint", + "num-traits", + "paste", + "rustc_version 0.3.3", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm 0.4.2", + "ark-ff-macros 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", + "derivative", + "digest 0.10.7", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version 0.4.1", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm 0.5.0", + "ark-ff-macros 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db02d390bf6643fb404d3d22d31aee1c4bc4459600aef9113833d17e786c6e44" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.117", +] + +[[package]] +name = "ark-ff-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fd794a08ccb318058009eefdf15bcaaaaf6f8161eb3345f907222bac38b20" +dependencies = [ + "num-bigint", + "num-traits", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ 
+ "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "ark-serialize" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6c2b318ee6e10f8c2853e73a83adc0ccb88995aa978d8a3408d492ab2ee671" +dependencies = [ + "ark-std 0.3.0", + "digest 0.9.0", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-std 0.4.0", + "digest 0.10.7", + "num-bigint", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "num-bigint", +] + +[[package]] +name = "ark-std" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1df2c09229cbc5a028b1d70e00fdb2acee28b1055dfb5ca73eea49c5a25c4e7c" +dependencies = [ + "num-traits", + "rand 0.8.6", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand 0.8.6", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand 0.8.6", +] + +[[package]] +name = "arrayvec" 
+version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "atomic-polyfill" version = "1.0.3" @@ -85,6 +317,17 @@ dependencies = [ "critical-section", ] +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -104,15 +347,55 @@ dependencies = [ "mt-symetric", "mt-utils", "mt-whir", - "rayon", + "parallel", "tracing", + "zk-alloc", +] + +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", ] +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitflags" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] [[package]] name = "block-buffer" @@ -132,12 +415,37 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "byte-slice-cast" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7575182f7272186991736b70173b0ea045398f984bf5ebbb3804736ce1330c9d" + [[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +dependencies = [ + "serde", +] + +[[package]] +name = "cc" +version = "1.2.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -152,14 +460,14 @@ checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "rand_core", + "rand_core 0.10.1", ] [[package]] name = "clap" -version = "4.6.0" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" +checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", "clap_derive", @@ -179,14 +487,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.6.0" +version = "4.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" +checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" 
dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -210,12 +518,60 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +[[package]] +name = "const-hex" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20d9a563d167a9cce0f94153382b33cb6eded6dfabff03c69ad65a28ea1514e0" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "proptest", + "serde_core", +] + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const-oid" version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" +[[package]] +name = "const_format" +version = "0.2.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4481a617ad9a412be3b97c5d403fef8ed023103368908b9c50af598ff467cc1e" +dependencies = [ + "const_format_proc_macros", + "konst", +] + +[[package]] +name = "const_format_proc_macros" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d57c2eccfb16dbac1f4e61e206105db5820c9d26c3c472bc17c774259ef7744" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -265,6 +621,24 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] 
+name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core 0.6.4", + "subtle", + "zeroize", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -284,6 +658,73 @@ dependencies = [ "hybrid-array", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid 0.9.6", + "zeroize", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_more" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version 0.4.1", + "syn 2.0.117", + "unicode-xid", +] + 
+[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "digest" version = "0.10.7" @@ -291,7 +732,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer 0.10.4", + "const-oid 0.9.6", "crypto-common 0.1.7", + "subtle", ] [[package]] @@ -301,16 +744,61 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" dependencies = [ "block-buffer 0.12.0", - "const-oid", + "const-oid 0.10.2", "crypto-common 0.2.1", ] +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest 0.10.7", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest 0.10.7", + "ff", + "generic-array", + "group", + "pkcs8", + "rand_core 0.6.4", + "sec1", + "subtle", + "zeroize", +] + [[package]] name = 
"embedded-io" version = "0.4.0" @@ -323,18 +811,150 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" +[[package]] +name = "enum-ordinalize" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a1091a7bb1f8f2c4b28f1fe2cef4980ca2d410a3d727d67ecc3178c9b0800f0" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca9601fb2d62598ee17836250842873a413586e5d7ed88b356e38ddbb0ec631" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "ethereum_serde_utils" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dc1355dbb41fbbd34ec28d4fb2a57d9a70c67ac3c19f6a5ca4d4a176b9e997a" +dependencies = [ + "alloy-primitives", + "hex", + "serde", + "serde_derive", + "serde_json", +] + +[[package]] +name = "ethereum_ssz" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368a4a4e4273b0135111fe9464e35465067766a8f664615b5a86338b73864407" +dependencies = [ + "alloy-primitives", + "ethereum_serde_utils", + "itertools 0.14.0", + "serde", + "serde_derive", + "smallvec", + "typenum", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "fastrlp" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139834ddba373bbdd213dffe02c8d110508dcf1726c2be27e8d1f7d7e1856418" +dependencies = [ + "arrayvec", + "auto_impl", + "bytes", +] + +[[package]] +name = "fastrlp" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce8dba4714ef14b8274c371879b175aa55b16b30f269663f19d576f380018dc4" +dependencies = [ + "arrayvec", + "auto_impl", + "bytes", +] + +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fixed-hash" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "835c052cb0c08c1acf6ffd71c022172e18723949c8282f2b9f27efbc51e64534" +dependencies = [ + "byteorder", + "rand 0.8.6", + "rustc-hex", + "static_assertions", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foldhash" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + +[[package]] +name = "funty" +version = "2.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "generic-array" version = "0.14.7" @@ -343,6 +963,30 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", + "zeroize", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", ] [[package]] @@ -353,12 +997,23 @@ checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", - "rand_core", + "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", "wasip3", ] +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "hash32" version = "0.2.1" @@ -368,13 +1023,19 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -382,6 +1043,17 @@ name = "hashbrown" version = "0.16.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash 0.2.0", + "serde", + "serde_core", +] + +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" [[package]] name = "heapless" @@ -391,9 +1063,9 @@ checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" dependencies = [ "atomic-polyfill", "hash32", - "rustc_version", + "rustc_version 0.4.1", "serde", - "spin", + "spin 0.9.8", "stable_deref_trait", ] @@ -403,11 +1075,26 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest 0.10.7", +] + [[package]] name = "hybrid-array" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3944cf8cf766b40e2a1a333ee5e9b563f854d5fa49d6a8ca2764e97c6eddb214" +checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" dependencies = [ "typenum", ] @@ -418,6 +1105,26 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" +[[package]] +name = "impl-codec" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba6a270039626615617f3f36d15fc827041df3b78c439da2cadfa47455a77f2f" 
+dependencies = [ + "parity-scale-codec", +] + +[[package]] +name = "impl-trait-for-tuples" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "include_dir" version = "0.7.4" @@ -439,12 +1146,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "serde", "serde_core", ] @@ -455,20 +1162,60 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ - "either", + "either", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "k256" +version = "0.13.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6e3919bbaa2945715f0bb6d3934a173d1e9a59ac23767fbaaef277265a7411b" +dependencies = [ + "cfg-if", + "ecdsa", + "elliptic-curve", + "once_cell", + "sha2", ] [[package]] -name = "itoa" -version = "1.0.18" +name = "keccak" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" +dependencies = [ + "cpufeatures 0.2.17", +] [[package]] name = "keccak" @@ -480,6 +1227,31 @@ dependencies = [ "cpufeatures 0.3.0", ] +[[package]] +name = "keccak-asm" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa468878266ad91431012b3e5ef1bf9b170eab22883503a318d46857afa4579a" +dependencies = [ + "digest 0.10.7", + "sha3-asm", +] + +[[package]] +name = "konst" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "128133ed7824fcd73d6e7b17957c5eb7bacb885649bd8c69708b2331a10bcefb" +dependencies = [ + "konst_macro_rules", +] + +[[package]] +name = "konst_macro_rules" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4933f3f57a8e9d9da04db23fb153356ecaf00cbd14aee46279c33dc80925c37" + [[package]] name = "lazy_static" version = "1.5.0" @@ -493,13 +1265,15 @@ dependencies = [ "backend", "clap", "lean_vm", - "rand", + "leansig_wrapper", + "parallel", + "rand 0.10.1", "rec_aggregation", + "serde", "serde_json", "sub_protocols", "system-info", "utils", - "xmss", "zk-alloc", ] @@ -512,11 +1286,10 @@ dependencies = [ "lean_vm", "pest", "pest_derive", - "rand", + "rand 0.10.1", "sub_protocols", "tracing", "utils", - "xmss", ] [[package]] @@ -524,19 +1297,18 @@ name = "lean_prover" version = "0.1.0" dependencies = [ "backend", - "itertools", + "itertools 0.14.0", "lean_compiler", "lean_vm", "pest", "pest_derive", - 
"rand", + "rand 0.10.1", "rec_aggregation", "serde", "serde_json", "sub_protocols", "tracing", "utils", - "xmss", ] [[package]] @@ -544,13 +1316,66 @@ name = "lean_vm" version = "0.1.0" dependencies = [ "backend", - "itertools", + "itertools 0.14.0", + "leansig_wrapper", "pest", "pest_derive", - "rand", + "rand 0.10.1", + "serde", "tracing", "utils", - "xmss", +] + +[[package]] +name = "leansig" +version = "0.1.0" +source = "git+https://github.com/leanEthereum/leanSig#c08a3bae74b0d85379cab72dcbefa4091546ecbb" +dependencies = [ + "dashmap", + "ethereum_ssz", + "num-bigint", + "num-traits", + "p3-baby-bear", + "p3-field", + "p3-koala-bear", + "p3-symmetric", + "rand 0.10.1", + "rayon", + "serde", + "sha3 0.10.9", + "thiserror", +] + +[[package]] +name = "leansig_fast_keygen" +version = "0.1.0" +source = "git+https://github.com/TomWambsgans/leanSig?branch=devnet4-fast-keygen#0fa9e19b8946ef50a34f3d50d82918b98bcfa4a5" +dependencies = [ + "dashmap", + "ethereum_ssz", + "num-bigint", + "num-traits", + "p3-baby-bear", + "p3-field", + "p3-koala-bear", + "p3-symmetric", + "rand 0.10.1", + "rayon", + "serde", + "sha3 0.10.9", + "thiserror", +] + +[[package]] +name = "leansig_wrapper" +version = "0.1.0" +dependencies = [ + "backend", + "ethereum_ssz", + "leansig", + "leansig_fast_keygen", + "p3-field", + "rand 0.10.1", ] [[package]] @@ -565,6 +1390,18 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "lock_api" version = "0.4.14" @@ -620,7 +1457,7 @@ 
dependencies = [ "mt-koala-bear", "mt-symetric", "mt-utils", - "rayon", + "parallel", "serde", "tracing", ] @@ -629,12 +1466,12 @@ dependencies = [ name = "mt-field" version = "0.1.0" dependencies = [ - "itertools", + "itertools 0.14.0", "mt-utils", "num-bigint", + "parallel", "paste", - "rand", - "rayon", + "rand 0.10.1", "serde", "tracing", ] @@ -643,13 +1480,12 @@ dependencies = [ name = "mt-koala-bear" version = "0.1.0" dependencies = [ - "itertools", + "itertools 0.14.0", "mt-field", "mt-utils", "num-bigint", "paste", - "rand", - "rayon", + "rand 0.10.1", "serde", "tracing", ] @@ -658,14 +1494,15 @@ dependencies = [ name = "mt-poly" version = "0.1.0" dependencies = [ - "itertools", + "itertools 0.14.0", "mt-field", "mt-koala-bear", "mt-utils", - "rand", - "rayon", + "parallel", + "rand 0.10.1", "serde", "system-info", + "zk-alloc", ] [[package]] @@ -677,8 +1514,9 @@ dependencies = [ "mt-field", "mt-koala-bear", "mt-poly", - "rayon", + "parallel", "tracing", + "zk-alloc", ] [[package]] @@ -687,7 +1525,8 @@ version = "0.1.0" dependencies = [ "mt-field", "mt-koala-bear", - "rayon", + "parallel", + "zk-alloc", ] [[package]] @@ -701,7 +1540,7 @@ dependencies = [ name = "mt-whir" version = "0.1.0" dependencies = [ - "itertools", + "itertools 0.14.0", "mt-fiat-shamir", "mt-field", "mt-koala-bear", @@ -709,12 +1548,13 @@ dependencies = [ "mt-sumcheck", "mt-symetric", "mt-utils", - "rand", - "rayon", + "parallel", + "rand 0.10.1", "system-info", "tracing", "tracing-forest", "tracing-subscriber", + "zk-alloc", ] [[package]] @@ -752,6 +1592,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -791,6 +1632,222 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "p3-baby-bear" +version = 
"0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "p3-challenger", + "p3-field", + "p3-mds", + "p3-monty-31", + "p3-poseidon1", + "p3-poseidon2", + "p3-symmetric", + "rand 0.10.1", +] + +[[package]] +name = "p3-challenger" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "p3-field", + "p3-maybe-rayon", + "p3-monty-31", + "p3-symmetric", + "p3-util", + "tracing", +] + +[[package]] +name = "p3-dft" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "spin 0.10.0", + "tracing", +] + +[[package]] +name = "p3-field" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "itertools 0.14.0", + "num-bigint", + "p3-maybe-rayon", + "p3-util", + "paste", + "rand 0.10.1", + "serde", + "tracing", +] + +[[package]] +name = "p3-koala-bear" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "p3-challenger", + "p3-field", + "p3-mds", + "p3-monty-31", + "p3-poseidon1", + "p3-poseidon2", + "p3-symmetric", + "rand 0.10.1", +] + +[[package]] +name = "p3-matrix" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-maybe-rayon", + "p3-util", + "rand 0.10.1", + "serde", + "tracing", +] + +[[package]] +name = "p3-maybe-rayon" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" + +[[package]] +name = "p3-mds" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" 
+dependencies = [ + "p3-dft", + "p3-field", + "p3-symmetric", + "p3-util", + "rand 0.10.1", +] + +[[package]] +name = "p3-monty-31" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "itertools 0.14.0", + "num-bigint", + "p3-dft", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-mds", + "p3-poseidon1", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "paste", + "rand 0.10.1", + "serde", + "spin 0.10.0", + "tracing", +] + +[[package]] +name = "p3-poseidon1" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "p3-field", + "p3-symmetric", + "rand 0.10.1", +] + +[[package]] +name = "p3-poseidon2" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "p3-field", + "p3-mds", + "p3-symmetric", + "p3-util", + "rand 0.10.1", +] + +[[package]] +name = "p3-symmetric" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "itertools 0.14.0", + "p3-field", + "p3-util", + "serde", +] + +[[package]] +name = "p3-util" +version = "0.5.1" +source = "git+https://github.com/Plonky3/Plonky3.git#bde2e86e9a9ad3ed4ec5002bff31f88f23029e40" +dependencies = [ + "serde", + "transpose", +] + +[[package]] +name = "parallel" +version = "0.1.0" +dependencies = [ + "system-info", +] + +[[package]] +name = "parity-scale-codec" +version = "3.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799781ae679d79a948e13d4824a40970bfa500058d245760dd857301059810fa" +dependencies = [ + "arrayvec", + "bitvec", + "byte-slice-cast", + "const_format", + "impl-trait-for-tuples", + "parity-scale-codec-derive", + "rustversion", + "serde", +] + +[[package]] +name = "parity-scale-codec-derive" +version = "3.7.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "34b4653168b563151153c9e4c08ebed57fb8262bebfa79711552fa983c623e7a" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "paste" version = "1.0.15" @@ -827,7 +1884,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -846,6 +1903,16 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "postcard" version = "1.1.3" @@ -859,6 +1926,15 @@ dependencies = [ "serde", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -866,7 +1942,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.117", +] + +[[package]] +name = "primitive-types" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b34d9fd68ae0b74a41b21c03c2f62847aa0ffea044eee893b4c140b37e244e2" +dependencies = [ + "fixed-hash", + "impl-codec", + "uint", +] + 
+[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" +dependencies = [ + "toml_edit", ] [[package]] @@ -878,43 +1974,159 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "num-traits", + "rand 0.9.4", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" 
+dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", + "serde", +] + +[[package]] +name = "rand" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" +dependencies = [ + "chacha20", + "getrandom 0.4.2", + "rand_core 0.10.1", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" dependencies = [ - "proc-macro2", + "getrandom 0.3.4", + "serde", ] [[package]] -name = "r-efi" -version = "6.0.0" +name = "rand_core" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" +checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" [[package]] -name = "rand" -version = 
"0.10.0" +name = "rand_xorshift" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" dependencies = [ - "chacha20", - "getrandom", - "rand_core", + "rand_core 0.9.5", ] [[package]] -name = "rand_core" -version = "0.10.0" +name = "rapidhash" +version = "4.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" +checksum = "b5e48930979c155e2f33aa36ab3119b5ee81332beb6482199a8ecd6029b80b59" +dependencies = [ + "rustversion", +] [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -939,19 +2151,29 @@ dependencies = [ "lean_compiler", "lean_prover", "lean_vm", + "leansig_wrapper", "lz4_flex", "objc2", "objc2-foundation", "postcard", - "rand", + "rand 0.10.1", "serde", + "sha3 0.11.0", "sub_protocols", "tracing", "utils", - "xmss", "zk-alloc", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex-automata" version = "0.4.14" @@ -969,13 +2191,119 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" 
+dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "rlp" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb919243f34364b6bd2fc10ef797edbfa75f33c252e7998527479c6d6b47e1ec" +dependencies = [ + "bytes", + "rustc-hex", +] + +[[package]] +name = "ruint" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0298da754d1395046b0afdc2f20ee76d29a8ae310cd30ffa84ed42acba9cb12a" +dependencies = [ + "alloy-rlp", + "ark-ff 0.3.0", + "ark-ff 0.4.2", + "ark-ff 0.5.0", + "bytes", + "fastrlp 0.3.1", + "fastrlp 0.4.0", + "num-bigint", + "num-integer", + "num-traits", + "parity-scale-codec", + "primitive-types", + "proptest", + "rand 0.8.6", + "rand 0.9.4", + "rlp", + "ruint-macro", + "serde_core", + "valuable", + "zeroize", +] + +[[package]] +name = "ruint-macro" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = "rustc-hex" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" + +[[package]] +name = "rustc_version" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee" +dependencies = [ + "semver 0.11.0", +] + [[package]] name = "rustc_version" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ - "semver", + "semver 1.0.28", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", ] [[package]] @@ -984,11 +2312,43 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + +[[package]] +name = "semver" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6" +dependencies = [ + "semver-parser", +] + [[package]] name = "semver" -version = "1.0.27" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "semver-parser" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +checksum = "9900206b54a3527fdc7b8a938bffd94a568bac4f4aa8113b209df75a09c0dec2" +dependencies = [ + "pest", +] [[package]] name = "serde" @@ -1017,7 
+2377,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1044,6 +2404,16 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha3" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77fd7028345d415a4034cf8777cd4f8ab1851274233b45f84e3d955502d93874" +dependencies = [ + "digest 0.10.7", + "keccak 0.1.6", +] + [[package]] name = "sha3" version = "0.11.0" @@ -1051,7 +2421,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be176f1a57ce4e3d31c1a166222d9768de5954f811601fb7ca06fc8203905ce1" dependencies = [ "digest 0.11.2", - "keccak", + "keccak 0.2.0", +] + +[[package]] +name = "sha3-asm" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cbb88c189d6352cc8ae96a39d19c7ecad8f7330b29461187f2587fdc2988d5" +dependencies = [ + "cc", + "cfg-if", ] [[package]] @@ -1063,6 +2443,22 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest 0.10.7", + "rand_core 0.6.4", +] + [[package]] name = "smallvec" version = "1.15.1" @@ -1078,12 +2474,43 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" @@ -1097,11 +2524,29 @@ dependencies = [ "backend", "lean_prover", "lean_vm", - "rand", + "parallel", + "rand 0.10.1", "tracing", "utils", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.117" @@ -1118,7 +2563,25 @@ name = "system-info" version = "0.1.0" dependencies = [ "libc", - "rayon", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + 
"windows-sys", ] [[package]] @@ -1138,7 +2601,7 @@ checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1150,6 +2613,36 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_edit" +version = "0.25.11+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" +dependencies = [ + "indexmap", + "toml_datetime", + "toml_parser", + "winnow", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow", +] + [[package]] name = "tracing" version = "0.1.44" @@ -1169,7 +2662,7 @@ checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.117", ] [[package]] @@ -1224,6 +2717,16 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "twox-hash" version = "2.1.2" @@ -1232,9 +2735,9 @@ checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = 
"40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "ucd-trie" @@ -1242,12 +2745,36 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" +[[package]] +name = "uint" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76f64bba2c53b04fcab63c01a7d7427eadc821e3bc48c34dc9ba29c501164b52" +dependencies = [ + "byteorder", + "crunchy", + "hex", + "static_assertions", +] + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -1265,7 +2792,7 @@ name = "utils" version = "0.1.0" dependencies = [ "backend", - "rand", + "rand 0.10.1", "tracing", "tracing-forest", "tracing-subscriber", @@ -1283,13 +2810,28 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] @@ -1298,7 +2840,7 @@ version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.51.0", ] [[package]] @@ -1332,7 +2874,7 @@ dependencies = [ "bitflags", "hashbrown 0.15.5", "indexmap", - "semver", + "semver 1.0.28", ] [[package]] @@ -1372,6 +2914,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "winnow" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen" version = "0.51.0" @@ -1381,6 +2932,12 @@ dependencies = [ "wit-bindgen-rust-macro", ] +[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + [[package]] name = "wit-bindgen-core" version = "0.51.0" @@ -1402,7 +2959,7 @@ dependencies = [ "heck", "indexmap", "prettyplease", - "syn", + "syn 2.0.117", "wasm-metadata", "wit-bindgen-core", "wit-component", @@ -1418,7 +2975,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn", + "syn 2.0.117", "wit-bindgen-core", "wit-bindgen-rust", ] @@ -1452,7 +3009,7 @@ dependencies = [ "id-arena", "indexmap", "log", - "semver", + "semver 1.0.28", "serde", "serde_derive", "serde_json", @@ -1461,16 +3018,52 @@ 
dependencies = [ ] [[package]] -name = "xmss" -version = "0.1.0" +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ - "backend", - "lz4_flex", - "postcard", - "rand", - "serde", - "sha3", - "utils", + "tap", +] + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", ] [[package]] @@ -1478,7 +3071,7 @@ name = "zk-alloc" version = "0.1.0" dependencies = [ "libc", - "rayon", + "parallel", "system-info", ] diff --git a/Cargo.toml b/Cargo.toml index f8e2ada76..267dd8044 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "crates/backend/sumcheck", "crates/backend/system-info", "crates/backend/zk-alloc", + "crates/backend/parallel", ] [workspace.lints] @@ -55,20 +56,20 @@ wildcard_imports = "allow" # Local utils = { path = "crates/utils" } lean_vm = { path = "crates/lean_vm" } -xmss = { path = "crates/xmss" } sub_protocols = { path = "crates/sub_protocols" } 
lean_compiler = { path = "crates/lean_compiler" } lean_prover = { path = "crates/lean_prover" } rec_aggregation = { path = "crates/rec_aggregation" } +leansig_wrapper = { path = "crates/leansig_wrapper" } backend = { path = "crates/backend" } zk-alloc = { path = "crates/backend/zk-alloc" } system-info = { path = "crates/backend/system-info" } +parallel = { path = "crates/backend/parallel" } # External sha3 = "0.11.0" clap = { version = "4.5.59", features = ["derive"] } rand = "0.10.0" -rayon = "1.11.0" pest = "2.7" pest_derive = "2.7" itertools = "0.14.0" @@ -79,25 +80,32 @@ tracing-subscriber = { version = "0.3.23", features = ["std", "env-filter"] } tracing-forest = { version = "0.3.0", features = ["ansi", "smallvec"] } postcard = { version = "1.1.3", features = ["alloc"] } lz4_flex = "0.13.0" +leansig = { git = "https://github.com/leanEthereum/leanSig" } +leansig_fast_keygen = { git = "https://github.com/TomWambsgans/leanSig", branch = "devnet4-fast-keygen" } include_dir = "0.7" [features] prox-gaps-conjecture = ["rec_aggregation/prox-gaps-conjecture"] -standard-alloc = ["rec_aggregation/standard-alloc"] +test-config = ["rec_aggregation/test-config"] [dependencies] clap.workspace = true rec_aggregation.workspace = true zk-alloc.workspace = true +parallel.workspace = true rand.workspace = true sub_protocols.workspace = true utils.workspace = true +leansig_wrapper.workspace = true lean_vm.workspace = true -xmss.workspace = true backend.workspace = true serde_json.workspace = true system-info.workspace = true +[dev-dependencies] +serde.workspace = true + [profile.release] lto = "thin" +codegen-units = 1 diff --git a/crates/backend/Cargo.toml b/crates/backend/Cargo.toml index 3f61957af..b1f633c5b 100644 --- a/crates/backend/Cargo.toml +++ b/crates/backend/Cargo.toml @@ -9,9 +9,10 @@ poly = { path = "poly", package = "mt-poly" } sumcheck = { path = "sumcheck", package = "mt-sumcheck" } field = { path = "field", package = "mt-field" } air = { path = "air", package = 
"mt-air" } -rayon.workspace = true +parallel.workspace = true whir = { path = "../whir", package = "mt-whir" } tracing.workspace = true fiat-shamir = { path = "fiat-shamir", package = "mt-fiat-shamir" } koala-bear = { path = "koala-bear", package = "mt-koala-bear" } utils = { path = "utils", package = "mt-utils" } +zk-alloc.workspace = true diff --git a/crates/backend/air/src/constraint_folder/normal.rs b/crates/backend/air/src/constraint_folder/normal.rs index 01880ba76..68c627fc2 100644 --- a/crates/backend/air/src/constraint_folder/normal.rs +++ b/crates/backend/air/src/constraint_folder/normal.rs @@ -47,14 +47,14 @@ where self.shift } - #[inline] + #[inline(always)] fn assert_zero(&mut self, x: IF) { let alpha_power = self.extra_data.alpha_powers()[self.constraint_index]; self.accumulator += alpha_power * x; self.constraint_index += 1; } - #[inline] + #[inline(always)] fn assert_zero_ef(&mut self, x: EF) { let alpha_power = self.extra_data.alpha_powers()[self.constraint_index]; self.accumulator += alpha_power * x; diff --git a/crates/backend/air/src/constraint_folder/packed.rs b/crates/backend/air/src/constraint_folder/packed.rs index bad2d76b0..2aa9ed734 100644 --- a/crates/backend/air/src/constraint_folder/packed.rs +++ b/crates/backend/air/src/constraint_folder/packed.rs @@ -57,21 +57,21 @@ where self.shift } - #[inline] + #[inline(always)] fn assert_zero(&mut self, x: IF) { let alpha_power = self.extra_data.alpha_powers()[self.constraint_index]; self.accumulator += EFPacking::::from(alpha_power) * x; self.constraint_index += 1; } - #[inline] + #[inline(always)] fn assert_zero_ef(&mut self, x: EFPacking) { let alpha_power = self.extra_data.alpha_powers()[self.constraint_index]; self.accumulator += EFPacking::::from(alpha_power) * x; self.constraint_index += 1; } - #[inline] + #[inline(always)] fn assert_eq_low(&mut self, x: IF, y: IF) { let alpha_power = self.extra_data.alpha_powers()[self.constraint_index]; let contrib = EFPacking::::from(alpha_power) * (x 
- y); @@ -80,7 +80,7 @@ where self.constraint_index += 1; } - #[inline] + #[inline(always)] fn low_degree_block(&mut self, state: &mut [IF], block: F) where F: FnOnce(&mut Self, &mut [IF]), diff --git a/crates/backend/air/src/lib.rs b/crates/backend/air/src/lib.rs index b8b3361f1..deee0bf0f 100644 --- a/crates/backend/air/src/lib.rs +++ b/crates/backend/air/src/lib.rs @@ -56,20 +56,24 @@ pub trait AirBuilder: Sized { fn assert_zero(&mut self, x: Self::IF); fn assert_zero_ef(&mut self, x: Self::EF); + #[inline(always)] fn assert_eq(&mut self, x: Self::IF, y: Self::IF) { self.assert_zero(x - y); } + #[inline(always)] fn assert_bool(&mut self, x: Self::IF) { self.assert_zero(x.bool_check()); } + #[inline(always)] fn assert_eq_low(&mut self, x: Self::IF, y: Self::IF) { self.assert_eq(x, y); } /// Execute `block` as a low-degree sub-region whose post-state is "cacheable" /// = linear in z without the low-degree constraints + #[inline(always)] fn low_degree_block(&mut self, state: &mut [Self::IF], block: F) where F: FnOnce(&mut Self, &mut [Self::IF]), diff --git a/crates/backend/air/src/symbolic.rs b/crates/backend/air/src/symbolic.rs index f36286ce3..08dadf610 100644 --- a/crates/backend/air/src/symbolic.rs +++ b/crates/backend/air/src/symbolic.rs @@ -92,9 +92,15 @@ fn alloc_node(node: SymbolicNode) -> u32 { }) } -pub fn get_node(idx: u32) -> SymbolicNode { +/// # Safety +/// `idx` must be an offset returned by `alloc_node::` for the current (same `F`, uncleared) arena. 
+pub unsafe fn get_node(idx: u32) -> SymbolicNode { ARENA.with(|arena| { let bytes = arena.borrow(); + assert!( + idx as usize + std::mem::size_of::>() <= bytes.len(), + "arena index out of bounds" + ); unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode) } }) } diff --git a/crates/backend/fiat-shamir/Cargo.toml b/crates/backend/fiat-shamir/Cargo.toml index ec8649bc2..57f32d2ba 100644 --- a/crates/backend/fiat-shamir/Cargo.toml +++ b/crates/backend/fiat-shamir/Cargo.toml @@ -10,4 +10,4 @@ symetric = { path = "../symetric", package = "mt-symetric" } utils = { path = "../utils", package = "mt-utils" } tracing.workspace = true serde.workspace = true -rayon.workspace = true +parallel.workspace = true diff --git a/crates/backend/fiat-shamir/src/prover.rs b/crates/backend/fiat-shamir/src/prover.rs index 80bb6d13e..0485594dd 100644 --- a/crates/backend/fiat-shamir/src/prover.rs +++ b/crates/backend/fiat-shamir/src/prover.rs @@ -9,8 +9,7 @@ use field::PrimeCharacteristicRing; use field::integers::QuotientMap; use field::{ExtensionField, PrimeField64}; use koala_bear::symmetric::Permutation; -use rayon::prelude::*; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::time::Duration; use std::{fmt::Debug, sync::Mutex, time::Instant}; @@ -132,9 +131,16 @@ where let witness_found = Mutex::>>::new(None); // each batch tests lanes witnesses simultaneously let num_batches = PF::::ORDER_U64.div_ceil(lanes as u64); - (0..num_batches) - .into_par_iter() - .find_any(|&batch| { + // Work-stealing parallel search: each worker pulls batches from a shared counter and + // stops once any worker has found a witness (`found`). 
+ let next_batch = AtomicU64::new(0); + let found = AtomicBool::new(false); + parallel::for_each_index(parallel::num_threads(), |_| { + while !found.load(Ordering::Relaxed) { + let batch = next_batch.fetch_add(1, Ordering::Relaxed); + if batch >= num_batches { + break; + } let base = batch * lanes as u64; let packed_witnesses = Packed::::from_fn(|lane| { @@ -159,14 +165,14 @@ where let rand_usize = sample.as_canonical_u64() as usize; if (rand_usize & ((1 << bits) - 1)) == 0 { *witness_found.lock().unwrap() = Some(*witness); - return true; + found.store(true, Ordering::Relaxed); + break; } } - false - }) - .expect("failed to find witness"); + } + }); - let witness = witness_found.lock().unwrap().unwrap(); + let witness = witness_found.lock().unwrap().expect("failed to find witness"); self.challenger.observe_many(&[witness]); assert!(self.challenger.state[CAPACITY].as_canonical_u64() & ((1 << bits) - 1) == 0); diff --git a/crates/backend/field/Cargo.toml b/crates/backend/field/Cargo.toml index 89e87c133..28af2f64a 100644 --- a/crates/backend/field/Cargo.toml +++ b/crates/backend/field/Cargo.toml @@ -10,6 +10,6 @@ itertools.workspace = true num-bigint = "*" paste = "*" rand.workspace = true -rayon.workspace = true +parallel.workspace = true serde.workspace = true tracing.workspace = true diff --git a/crates/backend/field/src/field.rs b/crates/backend/field/src/field.rs index b44ed45ed..836529cf9 100644 --- a/crates/backend/field/src/field.rs +++ b/crates/backend/field/src/field.rs @@ -9,7 +9,6 @@ use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAss use core::{array, slice}; use num_bigint::BigUint; -use rayon::{current_num_threads, prelude::*}; use serde::Serialize; use serde::de::DeserializeOwned; use utils::{flatten_to_base, iter_array_chunks_padded}; @@ -1020,7 +1019,7 @@ impl BoundedPowers { let mut points_packed = F::Packing::zero_vec(num_packed); // Split computation evenly among threads - let num_threads = 
current_num_threads().max(1); + let num_threads = parallel::num_threads().max(1); let chunk_size = num_packed.div_ceil(num_threads); // Precompute base for each chunk. @@ -1028,16 +1027,13 @@ impl BoundedPowers { let chunk_base = base.exp_u64((chunk_size * width) as u64); let shift = self.iter.current; - points_packed - .par_chunks_mut(chunk_size) - .enumerate() - .for_each(|(chunk_idx, chunk_slice)| { - // First power in this chunk - let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); + parallel::par_chunks_mut(&mut points_packed, chunk_size, |chunk_idx, chunk_slice| { + // First power in this chunk + let chunk_start = shift * chunk_base.exp_u64(chunk_idx as u64); - // Fill the chunk with packed powers. - F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); - }); + // Fill the chunk with packed powers. + F::Packing::packed_shifted_powers(base, chunk_start).fill(chunk_slice); + }); // return the number of requested points, discarding the unused packed powers // SAFETY: size_of:: always divides size_of::. diff --git a/crates/backend/field/src/op_assign_macros.rs b/crates/backend/field/src/op_assign_macros.rs index 78ffeb4f7..a78ace7b0 100644 --- a/crates/backend/field/src/op_assign_macros.rs +++ b/crates/backend/field/src/op_assign_macros.rs @@ -266,7 +266,7 @@ macro_rules! impl_rng { impl$(<$param_name: $type_param>)? Distribution<$type$(<$param_name>)?> for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> $type$(<$param_name>)? 
{ - $type(rand::RngExt::random(rng)) + $type(StandardUniform.sample(rng)) } } } diff --git a/crates/backend/koala-bear/Cargo.toml b/crates/backend/koala-bear/Cargo.toml index aba2ab231..5ce4ad111 100644 --- a/crates/backend/koala-bear/Cargo.toml +++ b/crates/backend/koala-bear/Cargo.toml @@ -8,7 +8,6 @@ field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } rand.workspace = true -rayon.workspace = true serde.workspace = true itertools.workspace = true tracing.workspace = true diff --git a/crates/backend/koala-bear/src/benchmark_poseidons.rs b/crates/backend/koala-bear/src/benchmark_poseidons.rs index 66c6a5a0d..f92b083e4 100644 --- a/crates/backend/koala-bear/src/benchmark_poseidons.rs +++ b/crates/backend/koala-bear/src/benchmark_poseidons.rs @@ -5,6 +5,7 @@ use field::Field; use field::PackedValue; use field::PrimeCharacteristicRing; +use crate::default_koalabear_poseidon1_24; use crate::{KoalaBear, default_koalabear_poseidon1_16}; type FPacking = ::Packing; @@ -17,13 +18,17 @@ fn bench_poseidon() { let n = 1 << 23; let poseidon1_16 = default_koalabear_poseidon1_16(); + let poseidon1_24 = default_koalabear_poseidon1_24(); // warming let mut state_16: [FPacking; 16] = [FPacking::ZERO; 16]; + let mut state_24: [FPacking; 24] = [FPacking::ZERO; 24]; for _ in 0..1 << 15 { poseidon1_16.compress_in_place(&mut state_16); + poseidon1_24.compress_in_place(&mut state_24); } let _ = black_box(state_16); + let _ = black_box(state_24); let time = Instant::now(); for _ in 0..n / PACKING_WIDTH { @@ -36,4 +41,16 @@ fn bench_poseidon() { PACKING_WIDTH, (n as f64 / time_p1_simd.as_secs_f64() / 1_000_000.0) ); + + let time = Instant::now(); + for _ in 0..n / PACKING_WIDTH { + poseidon1_24.compress_in_place(&mut state_24); + } + let _ = black_box(state_24); + let time_p1_simd = time.elapsed(); + println!( + "Poseidon1 24 SIMD (width {}): {:.2}M hashes/s", + PACKING_WIDTH, + (n as f64 / time_p1_simd.as_secs_f64() / 1_000_000.0) + ); } 
diff --git a/crates/backend/koala-bear/src/lib.rs b/crates/backend/koala-bear/src/lib.rs index 959ed3ada..9fc32bedb 100644 --- a/crates/backend/koala-bear/src/lib.rs +++ b/crates/backend/koala-bear/src/lib.rs @@ -7,6 +7,7 @@ extern crate alloc; mod koala_bear; pub mod monty_31; mod poseidon1_koalabear_16; +mod poseidon1_koalabear_24; pub mod quintic_extension; pub mod symmetric; @@ -22,9 +23,13 @@ mod x86_64_avx2; #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] mod x86_64_avx512; -pub use koala_bear::*; pub use monty_31::*; + +pub use koala_bear::*; + pub use poseidon1_koalabear_16::*; +pub use poseidon1_koalabear_24::*; + pub use quintic_extension::*; #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] diff --git a/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs b/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs index 61333abca..7e08771cd 100644 --- a/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs +++ b/crates/backend/koala-bear/src/monty_31/aarch64_neon/poseidon_helpers.rs @@ -35,6 +35,29 @@ impl InternalLayer16 { } } +/// A specialized representation of the Poseidon state for a width of 24. +/// +/// Same split as `InternalLayer16` but for width 24. +#[derive(Clone, Copy)] +#[repr(C)] +pub struct InternalLayer24 { + pub(crate) s0: PackedMontyField31Neon, + pub(crate) s_hi: [uint32x4_t; 23], +} + +impl InternalLayer24 { + #[inline] + pub(crate) unsafe fn to_packed_field_array(self) -> [PackedMontyField31Neon; 24] { + unsafe { transmute(self) } + } + + #[inline] + #[must_use] + pub(crate) fn from_packed_field_array(vector: [PackedMontyField31Neon; 24]) -> Self { + unsafe { transmute(vector) } + } +} + /// Converts a scalar constant into a packed NEON vector in "negative form" (`c - P`). 
#[inline(always)] pub(crate) fn convert_to_vec_neg_form_neon(input: i32) -> int32x4_t { diff --git a/crates/backend/koala-bear/src/poseidon1_koalabear_24.rs b/crates/backend/koala-bear/src/poseidon1_koalabear_24.rs new file mode 100644 index 000000000..d2668f43d --- /dev/null +++ b/crates/backend/koala-bear/src/poseidon1_koalabear_24.rs @@ -0,0 +1,1049 @@ +// Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). + +use std::sync::OnceLock; + +use core::ops::Mul; + +use crate::KoalaBear; +use crate::symmetric::Permutation; +use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing}; + +pub const POSEIDON1_WIDTH_24: usize = 24; +pub const POSEIDON1_HALF_FULL_ROUNDS_24: usize = 4; +pub const POSEIDON1_PARTIAL_ROUNDS_24: usize = 23; +pub const POSEIDON1_SBOX_DEGREE_24: u64 = 3; +const POSEIDON1_N_ROUNDS_24: usize = 2 * POSEIDON1_HALF_FULL_ROUNDS_24 + POSEIDON1_PARTIAL_ROUNDS_24; + +// ========================================================================= +// MDS circulant matrix (first column) +// ========================================================================= + +/// First column of the circulant MDS matrix for width 24. +/// +/// Derived from the plonky3 first-row data via first_row_to_first_col: +/// col[0] = row[0], col[i] = row[24 - i] for i = 1..23. +const MDS_CIRC_COL_24: [KoalaBear; 24] = KoalaBear::new_array([ + 0x2D0AAAAB, 0x0878A07F, 0x17E118F6, 0x5C7790FA, 0x0A6E572C, 0x6BE4DF69, 0x0524C7F2, 0x0C23DC41, 0x3C2C3DBE, + 0x1689DD98, 0x5D57AFC2, 0x2495A71D, 0x68FC71C8, 0x0360405D, 0x26D52A61, 0x3C0F5038, 0x77CDA9E2, 0x729601A7, + 0x18D6F3CA, 0x60703026, 0x6D91A8D5, 0x04ECBEB5, 0x17F5551D, 0x64850517, +]); + +// ========================================================================= +// Karatsuba convolution chain: 3 → 6 → 12 → 24 +// +// Ported from Plonky3 mds/src/karatsuba_convolution.rs (FieldConvolve). 
+// ========================================================================= + +#[inline(always)] +fn parity_dot, const N: usize>( + lhs: [R; N], + rhs: [KoalaBear; N], +) -> R { + let mut acc = lhs[0] * rhs[0]; + for i in 1..N { + acc += lhs[i] * rhs[i]; + } + acc +} + +#[inline(always)] +fn conv3>(lhs: [R; 3], rhs: [KoalaBear; 3], output: &mut [R]) { + output[0] = parity_dot(lhs, [rhs[0], rhs[2], rhs[1]]); + output[1] = parity_dot(lhs, [rhs[1], rhs[0], rhs[2]]); + output[2] = parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]); +} + +#[inline(always)] +fn negacyclic_conv3>( + lhs: [R; 3], + rhs: [KoalaBear; 3], + output: &mut [R], +) { + output[0] = parity_dot(lhs, [rhs[0], -rhs[2], -rhs[1]]); + output[1] = parity_dot(lhs, [rhs[1], rhs[0], -rhs[2]]); + output[2] = parity_dot(lhs, [rhs[2], rhs[1], rhs[0]]); +} + +#[inline(always)] +fn conv_n_recursive, const N: usize, const H: usize>( + lhs: [R; N], + rhs: [KoalaBear; N], + output: &mut [R], + inner_conv: fn([R; H], [KoalaBear; H], &mut [R]), + inner_neg: fn([R; H], [KoalaBear; H], &mut [R]), +) { + let mut lp = [R::ZERO; H]; + let mut ln = [R::ZERO; H]; + let mut rp = [KoalaBear::ZERO; H]; + let mut rn = [KoalaBear::ZERO; H]; + for i in 0..H { + lp[i] = lhs[i] + lhs[i + H]; + ln[i] = lhs[i] - lhs[i + H]; + rp[i] = rhs[i] + rhs[i + H]; + rn[i] = rhs[i] - rhs[i + H]; + } + let (left, right) = output.split_at_mut(H); + inner_neg(ln, rn, left); + inner_conv(lp, rp, right); + for i in 0..H { + left[i] += right[i]; + left[i] = left[i].halve(); + right[i] -= left[i]; + } +} + +#[inline(always)] +fn negacyclic_conv_n_recursive< + R: PrimeCharacteristicRing + Mul, + const N: usize, + const H: usize, +>( + lhs: [R; N], + rhs: [KoalaBear; N], + output: &mut [R], + inner_neg: fn([R; H], [KoalaBear; H], &mut [R]), +) { + let mut le = [R::ZERO; H]; + let mut lo = [R::ZERO; H]; + let mut ls = [R::ZERO; H]; + let mut re = [KoalaBear::ZERO; H]; + let mut ro = [KoalaBear::ZERO; H]; + let mut rs = [KoalaBear::ZERO; H]; + for i in 0..H { + 
le[i] = lhs[2 * i]; + lo[i] = lhs[2 * i + 1]; + ls[i] = le[i] + lo[i]; + re[i] = rhs[2 * i]; + ro[i] = rhs[2 * i + 1]; + rs[i] = re[i] + ro[i]; + } + let mut es = [R::ZERO; H]; + let (left, right) = output.split_at_mut(H); + inner_neg(le, re, &mut es); + inner_neg(lo, ro, left); + inner_neg(ls, rs, right); + right[0] -= es[0] + left[0]; + es[0] -= left[H - 1]; + for i in 1..H { + right[i] -= es[i] + left[i]; + es[i] += left[i - 1]; + } + for i in 0..H { + output[2 * i] = es[i]; + output[2 * i + 1] = output[i + H]; + } +} + +#[inline(always)] +fn conv6>(lhs: [R; 6], rhs: [KoalaBear; 6], output: &mut [R]) { + conv_n_recursive(lhs, rhs, output, conv3::, negacyclic_conv3::); +} + +#[inline(always)] +fn negacyclic_conv6>( + lhs: [R; 6], + rhs: [KoalaBear; 6], + output: &mut [R], +) { + negacyclic_conv_n_recursive(lhs, rhs, output, negacyclic_conv3::); +} + +#[inline(always)] +fn conv12>( + lhs: [R; 12], + rhs: [KoalaBear; 12], + output: &mut [R], +) { + conv_n_recursive(lhs, rhs, output, conv6::, negacyclic_conv6::); +} + +#[inline(always)] +fn negacyclic_conv12>( + lhs: [R; 12], + rhs: [KoalaBear; 12], + output: &mut [R], +) { + negacyclic_conv_n_recursive(lhs, rhs, output, negacyclic_conv6::); +} + +/// Circulant MDS multiply via Karatsuba convolution: state = C * state. +#[inline(always)] +pub fn mds_circ_24>(state: &mut [R; 24]) { + let input = *state; + conv_n_recursive( + input, + MDS_CIRC_COL_24, + state.as_mut_slice(), + conv12::, + negacyclic_conv12::, + ); +} + +// ========================================================================= +// NEON-optimized Karatsuba using mixed_dot_product (fewer Montgomery reductions) +// +// On NEON, mixed_dot_product accumulates products in 64-bit precision and +// does a single Montgomery reduction. The rhs (MDS column) stays as scalar +// KoalaBear values throughout the recursion, avoiding redundant NEON +// add/sub operations and SIMD port contention. 
Only at the leaf level +// are scalars broadcast to NEON lanes for the multiply-accumulate. +// ========================================================================= + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +mod neon_karatsuba { + use super::*; + type P = PackedKB; + type F = KoalaBear; + + #[inline(always)] + fn pdot(lhs: [P; N], rhs: [F; N]) -> P { + P::mixed_dot_product(&lhs, &rhs) + } + + #[inline(always)] + fn conv3(lhs: [P; 3], rhs: [F; 3], output: &mut [P]) { + output[0] = pdot(lhs, [rhs[0], rhs[2], rhs[1]]); + output[1] = pdot(lhs, [rhs[1], rhs[0], rhs[2]]); + output[2] = pdot(lhs, [rhs[2], rhs[1], rhs[0]]); + } + + #[inline(always)] + fn negacyclic_conv3(lhs: [P; 3], rhs: [F; 3], output: &mut [P]) { + output[0] = pdot(lhs, [rhs[0], -rhs[2], -rhs[1]]); + output[1] = pdot(lhs, [rhs[1], rhs[0], -rhs[2]]); + output[2] = pdot(lhs, [rhs[2], rhs[1], rhs[0]]); + } + + #[inline(always)] + fn conv_n( + lhs: [P; N], + rhs: [F; N], + output: &mut [P], + inner_conv: fn([P; H], [F; H], &mut [P]), + inner_neg: fn([P; H], [F; H], &mut [P]), + ) { + let mut lp = [P::ZERO; H]; + let mut ln = [P::ZERO; H]; + let mut rp = [F::ZERO; H]; + let mut rn = [F::ZERO; H]; + for i in 0..H { + lp[i] = lhs[i] + lhs[i + H]; + ln[i] = lhs[i] - lhs[i + H]; + rp[i] = rhs[i] + rhs[i + H]; + rn[i] = rhs[i] - rhs[i + H]; + } + let (left, right) = output.split_at_mut(H); + inner_neg(ln, rn, left); + inner_conv(lp, rp, right); + for i in 0..H { + left[i] += right[i]; + left[i] = left[i].halve(); + right[i] -= left[i]; + } + } + + #[inline(always)] + fn negacyclic_conv_n( + lhs: [P; N], + rhs: [F; N], + output: &mut [P], + inner_neg: fn([P; H], [F; H], &mut [P]), + ) { + let mut le = [P::ZERO; H]; + let mut lo = [P::ZERO; H]; + let mut ls = [P::ZERO; H]; + let mut re = [F::ZERO; H]; + let mut ro = [F::ZERO; H]; + let mut rs = [F::ZERO; H]; + for i in 0..H { + le[i] = lhs[2 * i]; + lo[i] = lhs[2 * i + 1]; + ls[i] = le[i] + lo[i]; + re[i] = rhs[2 * i]; + ro[i] = rhs[2 * 
i + 1]; + rs[i] = re[i] + ro[i]; + } + let mut es = [P::ZERO; H]; + let (left, right) = output.split_at_mut(H); + inner_neg(le, re, &mut es); + inner_neg(lo, ro, left); + inner_neg(ls, rs, right); + right[0] -= es[0] + left[0]; + es[0] -= left[H - 1]; + for i in 1..H { + right[i] -= es[i] + left[i]; + es[i] += left[i - 1]; + } + for i in 0..H { + output[2 * i] = es[i]; + output[2 * i + 1] = output[i + H]; + } + } + + #[inline(always)] + fn conv6(lhs: [P; 6], rhs: [F; 6], output: &mut [P]) { + conv_n(lhs, rhs, output, conv3, negacyclic_conv3); + } + + #[inline(always)] + fn negacyclic_conv6(lhs: [P; 6], rhs: [F; 6], output: &mut [P]) { + negacyclic_conv_n(lhs, rhs, output, negacyclic_conv3); + } + + #[inline(always)] + fn conv12(lhs: [P; 12], rhs: [F; 12], output: &mut [P]) { + conv_n(lhs, rhs, output, conv6, negacyclic_conv6); + } + + #[inline(always)] + fn negacyclic_conv12(lhs: [P; 12], rhs: [F; 12], output: &mut [P]) { + negacyclic_conv_n(lhs, rhs, output, negacyclic_conv6); + } + + #[inline(always)] + pub(super) fn mds_circ_24_neon(state: &mut [P; 24], col: &[F; 24]) { + let input = *state; + conv_n(input, *col, state.as_mut_slice(), conv12, negacyclic_conv12); + } +} + +// ========================================================================= +// Sparse matrix decomposition helpers (precomputation only) +// ========================================================================= + +type F24 = [KoalaBear; 24]; +type M24 = [[KoalaBear; 24]; 24]; + +fn matrix_mul_24(a: &M24, b: &M24) -> M24 { + core::array::from_fn(|i| { + core::array::from_fn(|j| { + let mut s = KoalaBear::ZERO; + for k in 0..24 { + s += a[i][k] * b[k][j]; + } + s + }) + }) +} + +fn matrix_vec_mul_24(m: &M24, v: &F24) -> F24 { + core::array::from_fn(|i| { + let mut s = KoalaBear::ZERO; + for j in 0..24 { + s += m[i][j] * v[j]; + } + s + }) +} + +fn matrix_transpose_24(m: &M24) -> M24 { + core::array::from_fn(|i| core::array::from_fn(|j| m[j][i])) +} + +fn matrix_inverse_24(m: &M24) -> M24 { 
+ let mut aug: M24 = *m; + let mut inv: M24 = + core::array::from_fn(|i| core::array::from_fn(|j| if i == j { KoalaBear::ONE } else { KoalaBear::ZERO })); + for col in 0..24 { + let pivot_row = (col..24) + .find(|&r| aug[r][col] != KoalaBear::ZERO) + .expect("Matrix is singular"); + if pivot_row != col { + aug.swap(col, pivot_row); + inv.swap(col, pivot_row); + } + let pivot_inv = aug[col][col].inverse(); + for j in 0..24 { + aug[col][j] *= pivot_inv; + inv[col][j] *= pivot_inv; + } + for i in 0..24 { + if i == col { + continue; + } + let factor = aug[i][col]; + if factor == KoalaBear::ZERO { + continue; + } + let aug_col_row = aug[col]; + let inv_col_row = inv[col]; + for j in 0..24 { + aug[i][j] -= factor * aug_col_row[j]; + inv[i][j] -= factor * inv_col_row[j]; + } + } + } + inv +} + +/// Inverse of the 23x23 bottom-right submatrix (Vec-based). +fn submatrix_inverse_23(m: &M24) -> Vec> { + let n = 23; + let mut sub: Vec> = (0..n).map(|i| (0..n).map(|j| m[i + 1][j + 1]).collect()).collect(); + let mut inv: Vec> = (0..n) + .map(|i| { + let mut row = vec![KoalaBear::ZERO; n]; + row[i] = KoalaBear::ONE; + row + }) + .collect(); + for col in 0..n { + let pivot_row = (col..n) + .find(|&r| sub[r][col] != KoalaBear::ZERO) + .expect("Submatrix is singular"); + if pivot_row != col { + sub.swap(col, pivot_row); + inv.swap(col, pivot_row); + } + let pivot_inv = sub[col][col].inverse(); + for j in 0..n { + sub[col][j] *= pivot_inv; + inv[col][j] *= pivot_inv; + } + for i in 0..n { + if i == col { + continue; + } + let factor = sub[i][col]; + if factor == KoalaBear::ZERO { + continue; + } + let sub_col_row: Vec = sub[col].clone(); + let inv_col_row: Vec = inv[col].clone(); + for j in 0..n { + sub[i][j] -= factor * sub_col_row[j]; + inv[i][j] -= factor * inv_col_row[j]; + } + } + } + inv +} + +/// Factor the dense MDS matrix into sparse matrices for partial rounds. +/// Returns (m_i, v_collection, w_hat_collection) in forward application order. 
+fn compute_equivalent_matrices_24(mds: &M24) -> (M24, Vec, Vec) { + let rounds_p = POSEIDON1_PARTIAL_ROUNDS_24; + let mut w_hat_collection: Vec = Vec::with_capacity(rounds_p); + let mut v_collection: Vec = Vec::with_capacity(rounds_p); + + let mds_t = matrix_transpose_24(mds); + let mut m_mul = mds_t; + let mut m_i = [[KoalaBear::ZERO; 24]; 24]; + + for _ in 0..rounds_p { + let v_arr: F24 = core::array::from_fn(|j| if j < 23 { m_mul[0][j + 1] } else { KoalaBear::ZERO }); + let w: Vec = (1..24).map(|i| m_mul[i][0]).collect(); + let m_hat_inv = submatrix_inverse_23(&m_mul); + let w_hat_arr: F24 = core::array::from_fn(|i| { + if i < 23 { + let mut s = KoalaBear::ZERO; + for k in 0..23 { + s += m_hat_inv[i][k] * w[k]; + } + s + } else { + KoalaBear::ZERO + } + }); + v_collection.push(v_arr); + w_hat_collection.push(w_hat_arr); + + m_i = m_mul; + m_i[0][0] = KoalaBear::ONE; + for row in m_i.iter_mut().skip(1) { + row[0] = KoalaBear::ZERO; + } + for elem in m_i[0].iter_mut().skip(1) { + *elem = KoalaBear::ZERO; + } + m_mul = matrix_mul_24(&mds_t, &m_i); + } + + let m_i_returned = matrix_transpose_24(&m_i); + v_collection.reverse(); + w_hat_collection.reverse(); + (m_i_returned, v_collection, w_hat_collection) +} + +/// Compress round constants via backward substitution through MDS^{-1}. 
+fn equivalent_round_constants_24(partial_rc: &[F24], mds_inv: &M24) -> (F24, Vec) { + let rounds_p = partial_rc.len(); + let mut opt_partial_rc = vec![KoalaBear::ZERO; rounds_p]; + let mut tmp = partial_rc[rounds_p - 1]; + for i in (0..rounds_p - 1).rev() { + let inv_cip = matrix_vec_mul_24(mds_inv, &tmp); + opt_partial_rc[i + 1] = inv_cip[0]; + tmp = partial_rc[i]; + for j in 1..24 { + tmp[j] += inv_cip[j]; + } + } + let first_round_constants = tmp; + let scalar_constants = opt_partial_rc[1..].to_vec(); + (first_round_constants, scalar_constants) +} + +// ========================================================================= +// NEON types (conditional) +// ========================================================================= + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +type FP = crate::KoalaBearParameters; +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +type PackedKB = crate::PackedKoalaBearNeon; + +// ========================================================================= +// Precomputed constants +// ========================================================================= + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +struct NeonPrecomputed24 { + /// Initial full round constants in negative NEON form (first 3 rounds). + packed_initial_rc: [[core::arch::aarch64::int32x4_t; 24]; POSEIDON1_HALF_FULL_ROUNDS_24 - 1], + /// Last initial round constant in negative NEON form. + packed_last_initial_rc: [core::arch::aarch64::int32x4_t; 24], + /// Terminal full round constants in negative NEON form. + packed_terminal_rc: [[core::arch::aarch64::int32x4_t; 24]; POSEIDON1_HALF_FULL_ROUNDS_24], + /// MDS circulant column for NEON Karatsuba (scalar, not packed). + /// Kept as scalar so the Karatsuba recursion avoids redundant NEON + /// operations on the constant side, reducing SIMD port contention. + mds_col: [KoalaBear; 24], + /// Fused matrix: m_i * MDS. 
+ packed_fused_mi_mds: [[PackedKB; 24]; 24], + /// Fused bias: m_i * first_round_constants. + packed_fused_bias: [PackedKB; 24], + /// Pre-packed sparse first rows. + packed_sparse_first_row: [[PackedKB; 24]; POSEIDON1_PARTIAL_ROUNDS_24], + /// Pre-packed v vectors. + packed_sparse_v: [[PackedKB; 24]; POSEIDON1_PARTIAL_ROUNDS_24], + /// Pre-packed scalar round constants for partial rounds 0..RP-2. + packed_round_constants: [PackedKB; POSEIDON1_PARTIAL_ROUNDS_24 - 1], +} + +#[cfg(all(target_arch = "aarch64", target_feature = "neon"))] +impl std::fmt::Debug for NeonPrecomputed24 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NeonPrecomputed24").finish_non_exhaustive() + } +} + +#[derive(Debug)] +struct Precomputed24 { + /// First round constant vector (full width), added once before m_i multiply. + sparse_first_round_constants: F24, + /// Dense transition matrix m_i, applied once before the partial round loop. + sparse_m_i: M24, + /// Per-round full first row: [mds_0_0, ŵ[0], ..., ŵ[22]]. + sparse_first_row: Vec, + /// Per-round first-column vectors (excluding [0,0]). + sparse_v: Vec, + /// Scalar constants for partial rounds 0..RP-2. + sparse_round_constants: Vec, + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + neon: NeonPrecomputed24, +} + +static PRECOMPUTED_24: OnceLock = OnceLock::new(); + +fn precomputed_24() -> &'static Precomputed24 { + PRECOMPUTED_24.get_or_init(|| { + let mds: M24 = core::array::from_fn(|i| core::array::from_fn(|j| MDS_CIRC_COL_24[(24 + i - j) % 24])); + + let partial_rc = &POSEIDON1_RC_24 + [POSEIDON1_HALF_FULL_ROUNDS_24..POSEIDON1_HALF_FULL_ROUNDS_24 + POSEIDON1_PARTIAL_ROUNDS_24]; + + // Sparse matrix decomposition. 
+ let mds_inv = matrix_inverse_24(&mds); + let (first_round_constants, scalar_round_constants) = equivalent_round_constants_24(partial_rc, &mds_inv); + let (m_i, sparse_v, sparse_w_hat) = compute_equivalent_matrices_24(&mds); + + // Pre-assemble full first rows: [mds_0_0, ŵ[0], ..., ŵ[22]]. + let mds_0_0 = mds[0][0]; + let sparse_first_row: Vec = sparse_w_hat + .iter() + .map(|w| core::array::from_fn(|i| if i == 0 { mds_0_0 } else { w[i - 1] })) + .collect(); + + // --- NEON pre-packed constants --- + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + let neon = { + use crate::PackedMontyField31Neon; + use crate::convert_to_vec_neg_form_neon; + + let pack = |c: KoalaBear| PackedMontyField31Neon::::from(c); + let neg_form = |c: KoalaBear| convert_to_vec_neg_form_neon::(c.value as i32); + + // Initial full round constants (first 3; 4th is fused). + let init_rc = poseidon1_24_initial_constants(); + let packed_initial_rc: [[core::arch::aarch64::int32x4_t; 24]; POSEIDON1_HALF_FULL_ROUNDS_24 - 1] = + core::array::from_fn(|r| init_rc[r].map(neg_form)); + let packed_last_initial_rc = init_rc[POSEIDON1_HALF_FULL_ROUNDS_24 - 1].map(neg_form); + + // Terminal full round constants. + let term_rc = poseidon1_24_final_constants(); + let packed_terminal_rc: [[core::arch::aarch64::int32x4_t; 24]; POSEIDON1_HALF_FULL_ROUNDS_24] = + core::array::from_fn(|r| term_rc[r].map(neg_form)); + + // Pre-packed sparse constants. + let packed_sparse_first_row: [[PackedKB; 24]; POSEIDON1_PARTIAL_ROUNDS_24] = + core::array::from_fn(|r| sparse_first_row[r].map(pack)); + let packed_sparse_v: [[PackedKB; 24]; POSEIDON1_PARTIAL_ROUNDS_24] = + core::array::from_fn(|r| sparse_v[r].map(pack)); + let packed_round_constants: [PackedKB; POSEIDON1_PARTIAL_ROUNDS_24 - 1] = + core::array::from_fn(|r| pack(scalar_round_constants[r])); + + // MDS column for NEON Karatsuba (scalar, not packed). + let mds_col: [KoalaBear; 24] = MDS_CIRC_COL_24; + + // Fused matrix: m_i * MDS. 
+ let fused_mi_mds = matrix_mul_24(&m_i, &mds); + let packed_fused_mi_mds: [[PackedKB; 24]; 24] = core::array::from_fn(|i| fused_mi_mds[i].map(pack)); + + // Fused bias: m_i * first_round_constants. + let fused_bias = matrix_vec_mul_24(&m_i, &first_round_constants); + let packed_fused_bias: [PackedKB; 24] = fused_bias.map(pack); + + NeonPrecomputed24 { + packed_initial_rc, + packed_last_initial_rc, + packed_terminal_rc, + mds_col, + packed_fused_mi_mds, + packed_fused_bias, + packed_sparse_first_row, + packed_sparse_v, + packed_round_constants, + } + }; + + Precomputed24 { + sparse_first_round_constants: first_round_constants, + sparse_m_i: m_i, + sparse_first_row, + sparse_v, + sparse_round_constants: scalar_round_constants, + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + neon, + } + }) +} + +const POSEIDON1_RC_24: [[KoalaBear; 24]; POSEIDON1_N_ROUNDS_24] = KoalaBear::new_2d_array([ + // Initial full rounds (4) + [ + 0x1d0939dc, 0x6d050f8d, 0x628058ad, 0x2681385d, 0x3e3c62be, 0x032cfad8, 0x5a91ba3c, 0x015a56e6, 0x696b889c, + 0x0dbcd780, 0x5881b5c9, 0x2a076f2e, 0x55393055, 0x6513a085, 0x547ac78f, 0x4281c5b8, 0x3e7a3f6c, 0x34562c19, + 0x2c04e679, 0x0ed78234, 0x5f7a1aa9, 0x0177640e, 0x0ea4f8d1, 0x15be7692, + ], + [ + 0x6eafdd62, 0x71a572c6, 0x72416f0a, 0x31ce1ad3, 0x2136a0cf, 0x1507c0eb, 0x1eb6e07a, 0x3a0ccf7b, 0x38e4bf31, + 0x44128286, 0x6b05e976, 0x244a9b92, 0x6e4b32a8, 0x78ee2496, 0x4761115b, 0x3d3a7077, 0x75d3c670, 0x396a2475, + 0x26dd00b4, 0x7df50f59, 0x0cb922df, 0x0568b190, 0x5bd3fcd6, 0x1351f58e, + ], + [ + 0x52191b5f, 0x119171b8, 0x1e8bb727, 0x27d21f26, 0x36146613, 0x1ee817a2, 0x71abe84e, 0x44b88070, 0x5dc04410, + 0x2aeaa2f6, 0x2b7bb311, 0x6906884d, 0x0522e053, 0x0c45a214, 0x1b016998, 0x479b1052, 0x3acc89be, 0x0776021a, + 0x7a34a1f5, 0x70f87911, 0x2caf9d9e, 0x026aff1b, 0x2c42468e, 0x67726b45, + ], + [ + 0x09b6f53c, 0x73d76589, 0x5793eeb0, 0x29e720f3, 0x75fc8bdf, 0x4c2fae0e, 0x20b41db3, 0x7e491510, 0x2cadef18, + 0x57fc24d6, 0x4d1ade4a, 
0x36bf8e3c, 0x3511b63c, 0x64d8476f, 0x732ba706, 0x46634978, 0x0521c17c, 0x5ee69212, + 0x3559cba9, 0x2b33df89, 0x653538d6, 0x5fde8344, 0x4091605d, 0x2933bdde, + ], + // Partial rounds (23) + [ + 0x1395d4ca, 0x5dbac049, 0x51fc2727, 0x13407399, 0x39ac6953, 0x45e8726c, 0x75a7311c, 0x599f82c9, 0x702cf13b, + 0x026b8955, 0x44e09bbc, 0x2211207f, 0x5128b4e3, 0x591c41af, 0x674f5c68, 0x3981d0d3, 0x2d82f898, 0x707cd267, + 0x3b4cca45, 0x2ad0dc3c, 0x0cb79b37, 0x23f2f4e8, 0x3de4e739, 0x7d232359, + ], + [ + 0x389d82f9, 0x259b2e6c, 0x45a94def, 0x0d497380, 0x5b049135, 0x3c268399, 0x78feb2f9, 0x300a3eec, 0x505165bb, + 0x20300973, 0x2327c081, 0x1a45a2f4, 0x5b32ea2e, 0x2d5d1a70, 0x053e613e, 0x5433e39f, 0x495529f0, 0x1eaa1aa9, + 0x578f572a, 0x698ede71, 0x5a0f9dba, 0x398a2e96, 0x0c7b2925, 0x2e6b9564, + ], + [ + 0x026b00de, 0x7644c1e9, 0x5c23d0bd, 0x3470b5ef, 0x6013cf3a, 0x48747288, 0x13b7a543, 0x3eaebd44, 0x0004e60c, + 0x1e8363a2, 0x2343259a, 0x69da0c2a, 0x06e3e4c4, 0x1095018e, 0x0deea348, 0x1f4c5513, 0x4f9a3a98, 0x3179112b, + 0x524abb1f, 0x21615ba2, 0x23ab4065, 0x1202a1d1, 0x21d25b83, 0x6ed17c2f, + ], + [ + 0x391e6b09, 0x5e4ed894, 0x6a2f58f2, 0x5d980d70, 0x3fa48c5e, 0x1f6366f7, 0x63540f5f, 0x6a8235ed, 0x14c12a78, + 0x6edde1c9, 0x58ce1c22, 0x718588bb, 0x334313ad, 0x7478dbc7, 0x647ad52f, 0x39e82049, 0x6fee146a, 0x082c2f24, + 0x1f093015, 0x30173c18, 0x53f70c0d, 0x6028ab0c, 0x2f47a1ee, 0x26a6780e, + ], + [ + 0x3540bc83, 0x1812b49f, 0x5149c827, 0x631dd925, 0x001f2dea, 0x7dc05194, 0x3789672e, 0x7cabf72e, 0x242dbe2f, + 0x0b07a51d, 0x38653650, 0x50785c4e, 0x60e8a7e0, 0x07464338, 0x3482d6e1, 0x08a69f1e, 0x3f2aff24, 0x5814c30d, + 0x13fecab2, 0x61cb291a, 0x68c8226f, 0x5c757eea, 0x289b4e1e, 0x0198d9b3, + ], + [ + 0x070a92e6, 0x2f1b6cb3, 0x535008bb, 0x35af339a, 0x7a38e92c, 0x4ff71b5c, 0x3b193aba, 0x34d12a1e, 0x17e94240, + 0x2ec214dc, 0x43e09385, 0x7d546918, 0x71af9dfd, 0x761a21bb, 0x43fdc986, 0x05dda714, 0x2d0e78b5, 0x1fcd387b, + 0x76e10a76, 0x28a112d5, 0x1a7bd787, 0x40190de2, 0x2e27906a, 
0x2033954e, + ], + [ + 0x20afd2c8, 0x71b5ecb2, 0x57828fb3, 0x222851d8, 0x732df0e9, 0x73f48435, 0x7e63ea98, 0x058be348, 0x229e7a5f, + 0x04576a2f, 0x29939f10, 0x7afd830a, 0x5d6dd961, 0x0eb65d94, 0x39da2b79, 0x36bce8ba, 0x5f53a7d4, 0x383b1cd2, + 0x1fdc3c5f, 0x7d9ca544, 0x77480711, 0x36c51a1a, 0x009ea59b, 0x731b17fd, + ], + [ + 0x201359bd, 0x22bf6499, 0x610f1a29, 0x3c73aa45, 0x6a092599, 0x1c7cb703, 0x79533459, 0x7ef62d86, 0x5ab925ab, + 0x67722ab1, 0x33ca4cff, 0x007f7dce, 0x0eeac41e, 0x4724bea7, 0x45eaf64f, 0x21a6c90f, 0x094b4150, 0x0d942630, + 0x18712c30, 0x3a470338, 0x6eba7720, 0x487827c8, 0x77013a6d, 0x4ad07390, + ], + [ + 0x57d802ea, 0x720f5fd4, 0x5b8a5357, 0x3649db1f, 0x35ea476a, 0x4c6589f5, 0x02c9f31f, 0x16d04670, 0x62d74b20, + 0x1de813cc, 0x189966ed, 0x527add06, 0x1704f5af, 0x000f1703, 0x00152a1f, 0x2f49a365, 0x40ee4288, 0x0ab86260, + 0x080c8576, 0x36c6cc05, 0x0ab9346f, 0x62aa3ec8, 0x51109797, 0x0feb1585, + ], + [ + 0x04700024, 0x01dee723, 0x5cd4aaa8, 0x1fe43ce5, 0x25c31267, 0x58512b48, 0x54147539, 0x4e340ab9, 0x563fbaeb, + 0x60c8353a, 0x65a12d49, 0x6c499fb2, 0x7ea07556, 0x396e2bbb, 0x31a318f1, 0x11f855ae, 0x6edffb87, 0x59977042, + 0x6ec5fa94, 0x75b4f690, 0x44b6fc61, 0x02a8bed8, 0x4c88c824, 0x08e31432, + ], + [ + 0x09a4c09f, 0x4796b47d, 0x215b7e75, 0x0c639599, 0x0d93dd4c, 0x2fac41de, 0x4f46dadd, 0x03905848, 0x2b1c39c1, + 0x25fff199, 0x38621f7b, 0x69e59315, 0x1874c308, 0x024a3959, 0x2bae1f12, 0x3c200626, 0x6ba5d369, 0x2fe9b97e, + 0x674cc08e, 0x2cbb9657, 0x550e56c2, 0x5b80e0ec, 0x6549ccff, 0x54e3e61a, + ], + [ + 0x0fa689e3, 0x2c534848, 0x1eb24382, 0x61b959b5, 0x4d5f001e, 0x003a95cd, 0x1edd4507, 0x621e895d, 0x7dc6e599, + 0x0fbc2771, 0x152d0879, 0x77801087, 0x6a2dd731, 0x3644aba2, 0x2e43a814, 0x12ff923f, 0x01cfe2c9, 0x35f8a572, + 0x5789fd35, 0x16f39e7a, 0x7c0ca31c, 0x01016283, 0x2c9dcd96, 0x5d3c6f4e, + ], + [ + 0x0058a186, 0x16354360, 0x502a262b, 0x2b56f93e, 0x0bc41ecb, 0x33c83e8b, 0x21968fc3, 0x6364490c, 0x16a45aa5, + 0x286d873f, 0x2be17254, 0x381fbc06, 0x0df309aa, 
0x15d48b84, 0x0fb2c5dd, 0x7c440d21, 0x74908f00, 0x75520624, + 0x7e58f065, 0x141e1e41, 0x6582f4ae, 0x2c4479e5, 0x7a09fff8, 0x1baa979f, + ], + [ + 0x45ab39bd, 0x774f78bc, 0x3c5f9aa2, 0x115d9dc9, 0x4b1546d7, 0x196c1a55, 0x6a88fb5e, 0x4c1ca910, 0x34869067, + 0x2662dcbb, 0x0a4625d4, 0x25b121c8, 0x1a50ccd2, 0x490ea316, 0x42556ffa, 0x6b5e4f88, 0x329faf33, 0x54f39a88, + 0x3b411e09, 0x6950ae8e, 0x310a912c, 0x63bddcba, 0x347977c0, 0x52831335, + ], + [ + 0x41f32fc6, 0x67dd5acb, 0x41ae544e, 0x1d83750a, 0x4bb58d20, 0x2f5496ee, 0x353819ec, 0x412ee425, 0x1bfd2747, + 0x32a14699, 0x2f7be906, 0x38afda41, 0x5b1e6316, 0x7b810b48, 0x6aebb30d, 0x55d94f89, 0x69db4833, 0x3a6ecb6c, + 0x50e7d206, 0x148a4b69, 0x1ac5548d, 0x40019cf9, 0x1e566f2a, 0x0998a950, + ], + [ + 0x5bc887f0, 0x73fbbd18, 0x341e05a8, 0x7d0597d5, 0x582308d9, 0x7a98addf, 0x0938b854, 0x544bf13d, 0x50090144, + 0x13baf374, 0x1896a8d5, 0x75ea7475, 0x23510dd8, 0x72c93bcc, 0x1c41410e, 0x4b72d5f9, 0x103ccc4e, 0x3896bef2, + 0x2c5e0b1c, 0x1e2096de, 0x15594d47, 0x04e035ce, 0x2785d1b1, 0x795bc87d, + ], + [ + 0x373fecbf, 0x0b18c3a0, 0x6516874a, 0x2b567be9, 0x5a2a3d1b, 0x74d99c04, 0x437de605, 0x047df991, 0x322faad4, + 0x2ef2f76f, 0x5f9e7278, 0x62740235, 0x18c1e8c2, 0x0691e203, 0x3324646d, 0x59542c9f, 0x32433d0d, 0x42c17492, + 0x45ac808a, 0x685394e0, 0x316f7193, 0x5ea108a0, 0x6bb3f12f, 0x232f8865, + ], + [ + 0x7c162b62, 0x52aa9e45, 0x1b69f8db, 0x3ec35206, 0x1ef086dd, 0x34d7a5e3, 0x33aeea57, 0x03565cc8, 0x5bc5fd47, + 0x47adc343, 0x1d5857a2, 0x5e7ece76, 0x0239fba3, 0x58bdead4, 0x41671aef, 0x3c8a9189, 0x7342ed52, 0x19871456, + 0x573a02c8, 0x2ec8ad55, 0x09c4a997, 0x34b9b63a, 0x226da984, 0x6b31d16e, + ], + [ + 0x458384d2, 0x353911e1, 0x4cfd1256, 0x163c23af, 0x7609c5e0, 0x76596c08, 0x087adac7, 0x4fd4b62c, 0x3692a037, + 0x51c54b62, 0x133daf4d, 0x0c76f623, 0x387d21f3, 0x6034abe5, 0x7c982e2b, 0x63a266b4, 0x4f2b17b8, 0x0bd62f1d, + 0x70e37a7c, 0x4f162da9, 0x38f0e527, 0x6ce798d7, 0x6c74250b, 0x606f2fad, + ], + [ + 0x212b041d, 0x6724fd32, 0x73aaf9af, 
0x3ae9b76b, 0x014fe151, 0x37687943, 0x36bb7786, 0x01da85ef, 0x28c618ae, + 0x36706580, 0x3f5f610d, 0x2e0b9391, 0x5750e38d, 0x00b48d71, 0x0f1f1d7a, 0x7107c415, 0x35c1e287, 0x26ccce2f, + 0x4e29277a, 0x1580ee9d, 0x18136f74, 0x530f32ad, 0x5a19b05d, 0x3d38b320, + ], + [ + 0x6a3bf1e4, 0x39e9edbb, 0x2ce6a59e, 0x2df215e1, 0x216a17ba, 0x3a8f3cfa, 0x0a14d990, 0x1162e529, 0x1213c181, + 0x3daa68f5, 0x16c570ff, 0x1063321c, 0x06a2d0e8, 0x17c094a4, 0x39a5d9c9, 0x086d4802, 0x67ab7fe3, 0x67f51392, + 0x3649c2ac, 0x62aa8cf8, 0x55b6fdbb, 0x55c3e972, 0x2f865724, 0x314fa653, + ], + [ + 0x029f66f1, 0x016f80a2, 0x4b70e0c2, 0x1782f9ab, 0x697578ee, 0x07b2c8b7, 0x123f6681, 0x2b78db24, 0x2cd8db9d, + 0x302947b1, 0x04f4c99a, 0x1f8bcbbd, 0x61c782ea, 0x3459928c, 0x3efec720, 0x24f2b8f6, 0x5dec66b5, 0x622386cc, + 0x26b70002, 0x1fa0d640, 0x6edeaa0a, 0x670ff3e1, 0x18641d8e, 0x43b68197, + ], + [ + 0x315b1707, 0x46db526a, 0x02fa5277, 0x36f6edf9, 0x31ad912b, 0x7d518ebd, 0x61db2eea, 0x0ba28bad, 0x3c839e59, + 0x7ed007f1, 0x74447f8a, 0x6b4ce5b7, 0x7272e3a4, 0x192257d1, 0x5f882281, 0x5f890768, 0x47eec4cb, 0x2ef3e6c8, + 0x43d6e4e2, 0x668ce6ba, 0x50679e00, 0x24c067a8, 0x605be47c, 0x324ac2ec, + ], + // Terminal full rounds (4) + [ + 0x5883788f, 0x7eba66af, 0x23620f78, 0x44492c9a, 0x7cc098a4, 0x705191fa, 0x2f7185e2, 0x6ebbb07e, 0x23508c3b, + 0x6cb0f0f4, 0x1190a8c0, 0x60f8f1d0, 0x316c16a1, 0x440742c7, 0x7643f142, 0x642f9668, 0x214b7566, 0x52a5c469, + 0x1bfd90da, 0x1d7d8076, 0x6e06d1e8, 0x7d672e6d, 0x6fd2e3e3, 0x3257ae18, + ], + [ + 0x75861a51, 0x0e2996fe, 0x2bdc228b, 0x6879fcb8, 0x14ca9b1c, 0x29953d92, 0x36ee671d, 0x31366e47, 0x79c4f5f2, + 0x2b8c8639, 0x073a293d, 0x32802c31, 0x4894d32f, 0x06acc989, 0x40d852b1, 0x508857c4, 0x2ffe504d, 0x18be00c1, + 0x75a114e9, 0x4ed5922a, 0x1060ee72, 0x2176563c, 0x0b91b242, 0x6bfbf1a4, + ], + [ + 0x06f94470, 0x694f4383, 0x53cada3e, 0x1527bfd8, 0x2bdfe868, 0x120c2d2c, 0x7dfd6309, 0x10b619c2, 0x0550bc7f, + 0x488cf3dc, 0x4c5454a2, 0x00be2976, 0x349c9669, 0x2b4eb07d, 0x0450bf40, 
0x58de7343, 0x3495a265, 0x2305e3b7, + 0x661dd781, 0x1c183983, 0x46992791, 0x3eb3751f, 0x38f728c8, 0x775d0a30, + ], + [ + 0x7636645a, 0x7125aa5d, 0x0c3f2dca, 0x13b595cc, 0x5a5e9bce, 0x54bb3456, 0x069a1a5a, 0x7b9f15ee, 0x50150189, + 0x68c9157b, 0x07e06e22, 0x568aecdb, 0x1403f847, 0x436cf5da, 0x3f09c026, 0x652f7b1b, 0x3e8607f3, 0x5bb37c57, + 0x1b1a9ecf, 0x39d11cb0, 0x1841a51c, 0x1251ad48, 0x74fb5edd, 0x21fa33c6, + ], +]); + +// ========================================================================= +// Accessors +// ========================================================================= + +pub fn poseidon1_24_round_constants() -> &'static [[KoalaBear; 24]; POSEIDON1_N_ROUNDS_24] { + &POSEIDON1_RC_24 +} + +#[inline(always)] +pub fn poseidon1_24_initial_constants() -> &'static [[KoalaBear; 24]] { + &POSEIDON1_RC_24[..POSEIDON1_HALF_FULL_ROUNDS_24] +} + +#[inline(always)] +pub fn poseidon1_24_partial_constants() -> &'static [[KoalaBear; 24]] { + &POSEIDON1_RC_24[POSEIDON1_HALF_FULL_ROUNDS_24..POSEIDON1_HALF_FULL_ROUNDS_24 + POSEIDON1_PARTIAL_ROUNDS_24] +} + +#[inline(always)] +pub fn poseidon1_24_final_constants() -> &'static [[KoalaBear; 24]] { + &POSEIDON1_RC_24[POSEIDON1_HALF_FULL_ROUNDS_24 + POSEIDON1_PARTIAL_ROUNDS_24..] 
+} + +pub fn poseidon1_24_sparse_m_i() -> &'static [[KoalaBear; 24]; 24] { + &precomputed_24().sparse_m_i +} + +pub fn poseidon1_24_sparse_first_row() -> &'static Vec<[KoalaBear; 24]> { + &precomputed_24().sparse_first_row +} + +pub fn poseidon1_24_sparse_v() -> &'static Vec<[KoalaBear; 24]> { + &precomputed_24().sparse_v +} + +pub fn poseidon1_24_sparse_first_round_constants() -> &'static [KoalaBear; 24] { + &precomputed_24().sparse_first_round_constants +} + +pub fn poseidon1_24_sparse_scalar_round_constants() -> &'static Vec { + &precomputed_24().sparse_round_constants +} + +#[derive(Clone, Debug)] +pub struct Poseidon1KoalaBear24 { + pre: &'static Precomputed24, +} + +impl Poseidon1KoalaBear24 { + #[inline(always)] + #[allow(clippy::needless_range_loop)] + fn permute_generic + InjectiveMonomial<3>>(&self, state: &mut [R; 24]) { + // Initial full rounds: AddRC + S-box + Karatsuba MDS. + for rc in poseidon1_24_initial_constants() { + Self::full_round(state, rc); + } + + // Partial rounds via sparse decomposition. + // Add first-round constants. + for (s, &c) in state.iter_mut().zip(self.pre.sparse_first_round_constants.iter()) { + *s += c; + } + // Apply dense transition matrix m_i (once). + { + let input = *state; + for i in 0..24 { + state[i] = R::ZERO; + for j in 0..24 { + state[i] += input[j] * self.pre.sparse_m_i[i][j]; + } + } + } + // Loop over partial rounds: S-box + scalar constant + cheap_matmul. + let rounds_p = self.pre.sparse_first_row.len(); + for r in 0..rounds_p { + state[0] = state[0].injective_exp_n(); + if r < rounds_p - 1 { + state[0] += self.pre.sparse_round_constants[r]; + } + // cheap_matmul: O(24) sparse matrix multiply. + let old_s0 = state[0]; + state[0] = parity_dot(*state, self.pre.sparse_first_row[r]); + for i in 1..24 { + state[i] += old_s0 * self.pre.sparse_v[r][i - 1]; + } + } + + // Terminal full rounds. 
+ for rc in poseidon1_24_final_constants() { + Self::full_round(state, rc); + } + } + + #[inline(always)] + fn full_round + InjectiveMonomial<3>>(state: &mut [R; 24], rc: &[KoalaBear; 24]) { + for (s, &c) in state.iter_mut().zip(rc.iter()) { + *s += c; + } + for s in state.iter_mut() { + *s = s.injective_exp_n(); + } + mds_circ_24(state); + } + + /// NEON-specific fast path with pre-packed constants and latency hiding. + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + #[inline(always)] + fn permute_neon(&self, state: &mut [PackedKB; 24]) { + use crate::PackedMontyField31Neon; + use crate::exp_small; + use crate::{InternalLayer24, add_rc_and_sbox}; + use core::mem::transmute; + + let neon = &self.pre.neon; + + // --- Initial full rounds (first 3 of 4) --- + for round_constants in &neon.packed_initial_rc { + for (s, &rc) in state.iter_mut().zip(round_constants.iter()) { + add_rc_and_sbox::(s, rc); + } + neon_karatsuba::mds_circ_24_neon(state, &neon.mds_col); + } + + // --- Last initial full round: AddRC + S-box, then fused (m_i * MDS) --- + { + for (s, &rc) in state.iter_mut().zip(neon.packed_last_initial_rc.iter()) { + add_rc_and_sbox::(s, rc); + } + let input = *state; + for (i, s) in state.iter_mut().enumerate() { + *s = PackedMontyField31Neon::::dot_product(&input, &neon.packed_fused_mi_mds[i]) + + neon.packed_fused_bias[i]; + } + } + + // --- Partial rounds with latency hiding via InternalLayer24 split --- + { + let mut split = InternalLayer24::from_packed_field_array(*state); + + for r in 0..POSEIDON1_PARTIAL_ROUNDS_24 { + // PATH A (high latency): S-box on s0. + unsafe { + let s0_signed = split.s0.to_signed_vector(); + let s0_sboxed = exp_small::(s0_signed); + split.s0 = PackedMontyField31Neon::from_vector(s0_sboxed); + } + + // Add scalar round constant (except last round). + if r < POSEIDON1_PARTIAL_ROUNDS_24 - 1 { + split.s0 += neon.packed_round_constants[r]; + } + + // PATH B (can overlap with S-box): partial dot product on s_hi. 
+ let s_hi: &[PackedKB; 23] = unsafe { transmute(&split.s_hi) }; + let first_row = &neon.packed_sparse_first_row[r]; + let first_row_hi: &[PackedKB; 23] = first_row[1..].try_into().unwrap(); + let partial_dot = PackedMontyField31Neon::::dot_product(s_hi, first_row_hi); + + // SERIAL: complete s0 = first_row[0] * s0 + partial_dot. + let s0_val = split.s0; + split.s0 = s0_val * first_row[0] + partial_dot; + + // Rank-1 update: s_hi[j] += s0_old * v[j]. + let v = &neon.packed_sparse_v[r]; + let s_hi_mut: &mut [PackedKB; 23] = unsafe { transmute(&mut split.s_hi) }; + for j in 0..23 { + s_hi_mut[j] += s0_val * v[j]; + } + } + + *state = unsafe { split.to_packed_field_array() }; + } + + // --- Terminal full rounds --- + for round_constants in &neon.packed_terminal_rc { + for (s, &rc) in state.iter_mut().zip(round_constants.iter()) { + add_rc_and_sbox::(s, rc); + } + neon_karatsuba::mds_circ_24_neon(state, &neon.mds_col); + } + } + + /// Compression mode: output = permute(input) + input. + #[inline(always)] + pub fn compress_in_place + InjectiveMonomial<3> + Send + Sync + 'static>( + &self, + state: &mut [R; 24], + ) { + let initial = *state; + Permutation::permute_mut(self, state); + for (s, init) in state.iter_mut().zip(initial) { + *s += init; + } + } +} + +impl + InjectiveMonomial<3> + Send + Sync + 'static> Permutation<[R; 24]> + for Poseidon1KoalaBear24 +{ + fn permute_mut(&self, input: &mut [R; 24]) { + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::any::TypeId::of::() == std::any::TypeId::of::() { + let neon_state: &mut [PackedKB; 24] = unsafe { &mut *(input as *mut [R; 24] as *mut [PackedKB; 24]) }; + self.permute_neon(neon_state); + return; + } + } + self.permute_generic(input); + } +} + +pub fn default_koalabear_poseidon1_24() -> Poseidon1KoalaBear24 { + Poseidon1KoalaBear24 { pre: precomputed_24() } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::KoalaBear; + use field::PrimeField32; + + #[test] + fn 
test_plonky3_compatibility() { + /* + + use p3_symmetric::Permutation; + + use crate::{KoalaBear, default_koalabear_poseidon1_24}; + + #[test] + fn plonky3_test() { + let poseidon1 = default_koalabear_poseidon1_24(); + let mut input: [KoalaBear; 24] = KoalaBear::new_array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + ]); + poseidon1.permute_mut(&mut input); + dbg!(&input); + } + + */ + let p1 = default_koalabear_poseidon1_24(); + let mut input: [KoalaBear; 24] = KoalaBear::new_array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + ]); + p1.permute_mut(&mut input); + let vals: Vec = input.iter().map(|x| x.as_canonical_u32()).collect(); + assert_eq!( + vals, + vec![ + 511672087, 215882318, 237782537, 740528428, 712760904, 54615367, 751514671, 110231969, 1905276435, + 992525666, 918312360, 18628693, 749929200, 1916418953, 691276896, 1112901727, 1163558623, 882867603, + 673396520, 1480278156, 1402044758, 1693467175, 1766273044, 433841551, + ] + ); + } +} diff --git a/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs b/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs index 682f3cb82..1aa93cac5 100644 --- a/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs +++ b/crates/backend/koala-bear/src/quintic_extension/packed_extension.rs @@ -48,7 +48,7 @@ impl> From> for P #[inline] fn from(x: QuinticExtensionField) -> Self { Self { - value: x.value.map(Into::into), + value: array::from_fn(|i| x.value[i].into()), } } } @@ -117,10 +117,11 @@ macro_rules! 
impl_packed_ext_scalar_ops { impl Mul for PackedQuinticExtensionField { type Output = Self; #[inline] - fn mul(self, rhs: KoalaBear) -> Self { - Self { - value: self.value.map(|x| x * rhs), + fn mul(mut self, rhs: KoalaBear) -> Self { + for v in &mut self.value { + *v *= rhs; } + self } } @@ -281,10 +282,12 @@ where type Output = Self; #[inline] - fn neg(self) -> Self { - Self { - value: self.value.map(PF::neg), + fn neg(mut self) -> Self { + // Loop, not `self.value.map(..)`: avoids a thin-LTO de-inlined `Wrapped` closure. + for v in &mut self.value { + *v = -*v; } + self } } @@ -478,7 +481,7 @@ where #[inline(always)] fn mul(self, rhs: QuinticExtensionField) -> Self { - let b: [PF; 5] = rhs.value.map(|x| x.into()); + let b: [PF; 5] = array::from_fn(|i| rhs.value[i].into()); Self { value: super::extension::quintic_mul(&self.value, &b, PF::dot_product::<5>), } @@ -493,10 +496,11 @@ where type Output = Self; #[inline] - fn mul(self, rhs: PF) -> Self { - Self { - value: self.value.map(|x| x * rhs), + fn mul(mut self, rhs: PF) -> Self { + for v in &mut self.value { + *v *= rhs; } + self } } diff --git a/crates/backend/parallel/Cargo.toml b/crates/backend/parallel/Cargo.toml new file mode 100644 index 000000000..731b5163d --- /dev/null +++ b/crates/backend/parallel/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "parallel" +version.workspace = true +edition.workspace = true +description = "Minimal fixed-size thread pool for static data-parallel kernels" + +[dependencies] +system-info.workspace = true + +[lints] +workspace = true diff --git a/crates/backend/parallel/src/lib.rs b/crates/backend/parallel/src/lib.rs new file mode 100644 index 000000000..631dc8d0e --- /dev/null +++ b/crates/backend/parallel/src/lib.rs @@ -0,0 +1,443 @@ +//! Minimal fixed-size thread pool for flat data-parallel kernels ("split a range, run a closure +//! on each piece"). No work-stealing, no per-dispatch allocation; owning the runtime lets us pin +//! per-worker scratch and drop rayon. +//! 
+//! - **Model.** `NUM_THREADS-1` background workers (ids `1..NUM_THREADS`); the dispatcher is +//! worker 0 and runs its share inline. Workers claim ranges from a shared atomic counter +//! (guided self-scheduling) for load balance. +//! - **Lock-free dispatch.** Dispatch bumps a `generation` counter idle workers spin on, parking +//! after `SPIN_LIMIT` spins; completion is a `working` countdown the dispatcher spins on. +//! `parked` is SeqCst-ordered against `generation`, so on each dispatch at least one side +//! sees the other (no lost wakeup) and unpark is skipped while a worker spins. +//! - **No nesting.** A dispatch from within a task would deadlock the dispatch lock; an `IN_TASK` +//! guard panics instead. +//! - **Panics.** A task panic is caught on its worker and re-raised on the dispatcher once the +//! dispatch quiesces; the pool stays usable. +//! - **One dispatcher at a time**, serialized by the `dispatch` mutex. + +use std::any::Any; +use std::cell::{Cell, UnsafeCell}; +use std::panic::{AssertUnwindSafe, catch_unwind, resume_unwind}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Mutex, Once, OnceLock}; +use std::thread::Thread; + +/// Idle spins before a worker parks: long enough to stay hot across back-to-back dispatches, +/// short enough to yield the core during sequential gaps. +const SPIN_LIMIT: u32 = 1 << 12; + +/// Max tasks claimed in one guided-self-scheduling step: bounds load imbalance while keeping +/// million-task kernels to a few thousand claims. +const MAX_CLAIM_BATCH: usize = 1 << 12; + +/// Worker count including the dispatcher. Resolved once at runtime (see [`system_info::num_threads`]). +#[must_use] +pub fn num_threads() -> usize { + system_info::num_threads() +} + +/// Chunk size for a flat fan-out: a few chunks per worker — fine enough for the counter to +/// rebalance heterogeneous cores, coarse enough to amortize dispatch.
+#[must_use] +#[inline] +pub fn recommended_chunk_size(n_items: usize) -> usize { + n_items.div_ceil(num_threads() * 4).max(1) +} + +thread_local! { + /// Stable pool id of this thread; `0` on the dispatcher and off-pool threads. + static WORKER_ID: Cell = const { Cell::new(0) }; + /// Set while running a task; a dispatch in this state is forbidden nesting (panics). + static IN_TASK: Cell = const { Cell::new(false) }; +} + +/// Calling worker's id in `0..NUM_THREADS` (`0` off-pool). +#[must_use] +pub(crate) fn current_worker_id() -> usize { + WORKER_ID.with(Cell::get) +} + +/// Type-erased work unit. The `&dyn Fn` lifetime is erased to `'static`; it is dereferenced +/// only inside a dispatch window during which the dispatcher blocks, so the borrow outlives +/// every call. Range-based (`f(start, end)`) so a reduction looks up its per-worker +/// accumulator once per claimed batch, not per element. +struct Job { + f: NonNull, + n_tasks: usize, +} + +/// Park/unpark state, indexed by worker id (slot 0, the dispatcher, never parks). +#[derive(Debug)] +struct Worker { + /// "Currently parked", SeqCst-ordered against `Pool::generation`. + parked: AtomicBool, + /// Handle for `unpark`, published once at worker start-up. + handle: OnceLock, +} + +struct Pool { + /// Current job: written by the dispatcher before the `generation` bump, read by workers + /// after observing it (the bump supplies the happens-before). + job: UnsafeCell>, + /// Bumped once per dispatch; idle workers watch it (spin, then park). + generation: AtomicUsize, + /// Next task index to claim; reset to 0 per dispatch. + counter: AtomicUsize, + /// Background workers still draining; the dispatcher spins this to 0. + working: AtomicUsize, + /// Park flag + unpark handle per worker (slot 0 unused). + workers: Vec, + /// Serializes dispatchers: one driver at a time. + dispatch: Mutex<()>, + /// First task-panic payload of the current dispatch, re-raised by the dispatcher. 
Caught + /// here so it can't unwind across `worker_main` (which would skip the `working` decrement + /// and deadlock the completion spin). + panic: Mutex>>, +} + +// SAFETY: `job` is written only by the sole dispatcher (while workers are parked or before +// they observe the generation bump) and read only after; the generation release/acquire and +// SeqCst park protocol order the phases. The erased `Job` pointer is used only within a +// dispatch window where its borrow is live. +unsafe impl Sync for Pool {} +unsafe impl Send for Pool {} + +/// Idempotent warm-up: spawn workers and run one empty dispatch so the pool and the (macOS) +/// lazily-allocated mutex exist before timed work; otherwise the pool inits on first use. +pub fn init() { + static INIT: Once = Once::new(); + INIT.call_once(|| { + let _ = pool(); + if num_threads() > 1 { + for_each_index(num_threads(), |_| {}); + } + }); +} + +fn pool() -> &'static Pool { + static POOL: OnceLock<&'static Pool> = OnceLock::new(); + POOL.get_or_init(|| { + let n = num_threads().max(1); + let p: &'static Pool = Box::leak(Box::new(Pool { + job: UnsafeCell::new(None), + generation: AtomicUsize::new(0), + counter: AtomicUsize::new(0), + working: AtomicUsize::new(0), + workers: (0..n) + .map(|_| Worker { + parked: AtomicBool::new(false), + handle: OnceLock::new(), + }) + .collect(), + dispatch: Mutex::new(()), + panic: Mutex::new(None), + })); + for id in 1..n { + std::thread::Builder::new() + .name(format!("parallel-worker-{id}")) + .spawn(move || worker_main(p, id)) + .expect("failed to spawn pool worker"); + } + p + }) +} + +fn worker_main(pool: &'static Pool, id: usize) { + WORKER_ID.with(|c| c.set(id)); + let _ = pool.workers[id].handle.set(std::thread::current()); + // Leaked, lives for the whole process; workers never shut down. One iteration per dispatch. 
+ let mut last_gen = 0usize; + loop { + last_gen = wait_for_dispatch(pool, id, last_gen); + drain(pool); + pool.working.fetch_sub(1, Ordering::Release); + } +} + +/// Block until a new job is published, returning its generation. Spins up to [`SPIN_LIMIT`], then +/// parks: publish `parked = true`, re-check `generation`, both SeqCst — the same total order the +/// dispatcher's bump and `parked` load observe, so a wakeup can't be lost. +fn wait_for_dispatch(pool: &Pool, id: usize, last_gen: usize) -> usize { + let mut spins = 0u32; + loop { + let g = pool.generation.load(Ordering::Acquire); + if g != last_gen { + return g; + } + if spins < SPIN_LIMIT { + spins += 1; + std::hint::spin_loop(); + continue; + } + // Announce intent to park, then re-check: park only if nothing changed, else re-loop. + pool.workers[id].parked.store(true, Ordering::SeqCst); + if pool.generation.load(Ordering::SeqCst) == last_gen { + std::thread::park(); + } + pool.workers[id].parked.store(false, Ordering::SeqCst); + spins = 0; + } +} + +/// Claim and run task ranges until the counter is exhausted (guided self-scheduling: each claim +/// takes `remaining / (NUM_THREADS*2)`, clamped to `1..=`[`MAX_CLAIM_BATCH`]). Big early claims +/// cut counter contention; the proportional shrink keeps the tail balanced. +fn drain(pool: &Pool) { + // SAFETY: the dispatcher published `Some(job)` before the bump this worker observed and + // overwrites it only on the next dispatch (gated on `working == 0`); no writer during drain. + let job = unsafe { (*pool.job.get()).as_ref().expect("drain without a published job") }; + // SAFETY: `job.f` borrows a `&dyn Fn` the blocked dispatcher keeps live. 
+ let f = unsafe { job.f.as_ref() }; + let n = job.n_tasks; + let nt = num_threads(); + let prev = IN_TASK.replace(true); // catch nested dispatch (see `for_each_chunk`) + // Catch a task panic so it can't unwind across `worker_main` (skipping the `working` + // decrement → deadlock) or poison the dispatch lock; `for_each_chunk` re-raises it. + let result = catch_unwind(AssertUnwindSafe(|| { + loop { + // Stale read only affects granularity: `fetch_add` tiles `0..n` into disjoint claims. + let observed = pool.counter.load(Ordering::Relaxed); + if observed >= n { + break; + } + let batch = ((n - observed) / (nt * 2)).clamp(1, MAX_CLAIM_BATCH); + let start = pool.counter.fetch_add(batch, Ordering::Relaxed); + if start >= n { + break; + } + f(start, (start + batch).min(n)); + } + })); + IN_TASK.set(prev); + if let Err(payload) = result { + pool.panic.lock().unwrap().get_or_insert(payload); // keep the first + } +} + +/// Run `f(start, end)` over disjoint ranges tiling `0..n_tasks`, in parallel; a worker may get +/// several (guided self-scheduling, see [`drain`]). Blocks until done, the dispatcher acting as +/// worker 0. The base primitive — range-based so reductions amortize per-worker lookups. +pub fn for_each_chunk(n_tasks: usize, f: F) { + // Nesting would deadlock the dispatch lock — panic so it's caught, not silently serial. + assert!(!IN_TASK.get(), "nested parallel dispatch from within a pool task"); + + // Trivial sizes / single-core builds run inline. + let nt = num_threads(); + if nt <= 1 || n_tasks <= 1 { + if n_tasks > 0 { + f(0, n_tasks); + } + return; + } + + let pool = pool(); + let _guard = pool.dispatch.lock().unwrap(); + + // SAFETY: erase the borrow to `'static` so it fits the `Job`. The dispatcher blocks on + // `working` before returning, so `f` outlives every deref. 
`transmute` (not a `*const dyn` + // cast) is required: a bare cast would default the trait object to `'static` and force + // `F: 'static` (E0310); the transmute reinterprets the same fat pointer without that bound. + let f_ref: &(dyn Fn(usize, usize) + Sync) = &f; + let f_erased: NonNull = unsafe { std::mem::transmute(NonNull::from(f_ref)) }; + + // SAFETY: sole writer — prior dispatch fully drained (`working == 0`), next not yet observed. + unsafe { *pool.job.get() = Some(Job { f: f_erased, n_tasks }) }; + pool.counter.store(0, Ordering::Relaxed); + pool.working.store(nt - 1, Ordering::Release); + pool.generation.fetch_add(1, Ordering::SeqCst); // publish; SeqCst guards the park protocol + + // Wake only parked workers; spinning ones see the bump for free. + for worker in &pool.workers[1..] { + if worker.parked.load(Ordering::SeqCst) + && let Some(t) = worker.handle.get() + { + t.unpark(); + } + } + + drain(pool); // dispatcher runs as worker 0 + while pool.working.load(Ordering::Acquire) != 0 { + std::hint::spin_loop(); // lock-free completion wait + } + + // Re-raise the first task panic (if any) after dropping `_guard`, so the lock releases + // cleanly (no poison) and the pool stays usable. + let panicked = pool.panic.lock().unwrap().take(); + drop(_guard); + if let Some(payload) = panicked { + resume_unwind(payload); + } +} + +/// `f(i)` for every `i` in `0..n_tasks`, in parallel. `#[inline]` folds the range→index adapter +/// into the monomorphized [`for_each_chunk`]. +#[inline] +pub fn for_each_index(n_tasks: usize, f: F) { + for_each_chunk(n_tasks, |start, end| { + for i in start..end { + f(i); + } + }); +} + +/// A base `*mut` shareable across workers. Sound only because callers partition the allocation +/// by task index (disjoint regions). +#[derive(Debug)] +pub struct SendPtr(pub *mut T); +// SAFETY: accesses are partitioned by task index (see callers). 
+unsafe impl Send for SendPtr {} +unsafe impl Sync for SendPtr {} + +impl SendPtr { + /// Offset the base by `n` elements. + /// # Safety + /// `n` stays in the allocation; any write targets a slot no concurrent task touches. + #[inline] + pub unsafe fn add(&self, n: usize) -> *mut T { + unsafe { self.0.add(n) } + } + + /// Reconstruct the `len`-element slice at element offset `off`. + /// # Safety + /// `off`/`len` in-bounds and disjoint from every other concurrent task's slice. + #[inline] + pub unsafe fn slice<'a>(&self, off: usize, len: usize) -> &'a mut [T] { + unsafe { std::slice::from_raw_parts_mut(self.0.add(off), len) } + } +} + +/// Parallel `data.chunks_mut(chunk).enumerate().for_each(f)`; the final chunk may be shorter. +pub fn par_chunks_mut(data: &mut [T], chunk: usize, f: F) +where + F: Fn(usize, &mut [T]) + Sync, +{ + assert!(chunk > 0, "chunk size must be non-zero"); + let len = data.len(); + let base = SendPtr(data.as_mut_ptr()); + for_each_index(len.div_ceil(chunk), |i| { + let start = i * chunk; + // SAFETY: distinct `i` give disjoint in-bounds ranges; `data` stays borrowed. + let slice = unsafe { base.slice(start, chunk.min(len - start)) }; + f(i, slice); + }); +} + +/// Parallel `data.iter_mut().enumerate().for_each(f)`, chunked by [`recommended_chunk_size`]. +/// Hands the closure each element's **global** index. `#[inline]` folds the per-chunk adapter +/// into the monomorphized [`par_chunks_mut`]. 
+#[inline] +pub fn par_for_each_mut(data: &mut [T], f: F) +where + F: Fn(usize, &mut T) + Sync, +{ + let chunk = recommended_chunk_size(data.len()); + par_chunks_mut(data, chunk, |ci, sub| { + for (k, slot) in sub.iter_mut().enumerate() { + f(ci * chunk + k, slot); + } + }); +} + +/// [`par_for_each_mut`] over two equal-length slices at once: `f(i, &mut a[i], &mut b[i])` +#[inline] +pub fn par_for_each_mut2(a: &mut [A], b: &mut [B], f: F) +where + F: Fn(usize, &mut A, &mut B) + Sync, +{ + assert_eq!(a.len(), b.len(), "par_for_each_mut2: slices differ in length"); + let bp = SendPtr(b.as_mut_ptr()); + par_for_each_mut(a, |i, ai| { + f(i, ai, unsafe { &mut *bp.add(i) }); + }); +} + +/// Parallel `(0..n_tasks).map(f).collect::>()`: runs `f(i)` across the pool and writes each +/// result straight into the output in index order — one allocation, no `Option` slots. +pub fn par_map_collect T + Sync>(n_tasks: usize, f: F) -> Vec { + let mut out: Vec = Vec::with_capacity(n_tasks); + let base = SendPtr(out.as_mut_ptr()); + for_each_index(n_tasks, |i| { + // SAFETY: distinct `i` write disjoint, in-bounds slots (each exactly once) and the + // dispatch blocks until all writes finish. A panic in `f` leaks the slots written so + // far (`out` unwinds with `len == 0`): safe, and the dispatcher re-raises the panic. + unsafe { base.add(i).write(f(i)) }; + }); + // SAFETY: every slot in `0..n_tasks` was initialized exactly once above. + unsafe { out.set_len(n_tasks) }; + out +} + +/// Parallel `for (i, slot) in dst.iter_mut().enumerate() { *slot = build(i); }`: fill an existing +/// slice from an index closure. The in-place dual of [`par_map_collect`] (which allocates). +/// `#[inline]` folds the fill adapter into the monomorphized [`par_for_each_mut`]. Always +/// dispatches to the pool; guard the call yourself when small inputs need a sequential fast path.
+#[inline] +pub fn par_fill T + Sync>(dst: &mut [T], build: F) { + par_for_each_mut(dst, |i, slot| *slot = build(i)); +} + +/// Give each worker its own persistent `Option` slot while it drains `0..n_tasks`: +/// `run(slot, start, end)` fires once per claimed batch with that worker's slot, so state +/// accumulates across its batches. Returns the slots (rest `None`) for the caller to combine. +fn drain_into_slots(n_tasks: usize, run: impl Fn(&mut Option, usize, usize) + Sync) -> Vec> { + let mut slots: Vec> = (0..num_threads()).map(|_| None).collect(); + let ptr = SendPtr(slots.as_mut_ptr()); + for_each_chunk(n_tasks, |start, end| { + // SAFETY: `current_worker_id() < NUM_THREADS` is unique per live worker → disjoint + // slots; `slots` outlives the dispatch. + let slot = unsafe { &mut *ptr.add(current_worker_id()) }; + run(slot, start, end); + }); + slots +} + +/// Parallel map-reduce over `0..n_tasks` = `(0..n).map(map).reduce(identity, reduce)`. Each +/// worker folds its claimed indices into one local partial; the partials combine on the +/// dispatcher. `reduce` must be associative and commutative, with `identity()` neutral. +pub fn map_reduce(n_tasks: usize, identity: ID, map: M, reduce: R) -> T +where + T: Send, + ID: Fn() -> T, + M: Fn(usize) -> T + Sync, + R: Fn(T, T) -> T + Sync, +{ + let slots = drain_into_slots(n_tasks, |slot, start, end| { + // Fold the batch into the worker's partial, seeded by the first `map` so `identity` + // stays off the per-element path; take/replace the shared slot just once. + *slot = (start..end).fold(slot.take(), |acc, i| { + Some(acc.map_or_else(|| map(i), |a| reduce(a, map(i)))) + }); + }); + // `identity()` seeds the combine as a no-op left-identity; the empty and single-thread + // (`for_each_chunk` runs inline) cases then fall out without a special path.
+ slots.into_iter().flatten().fold(identity(), &reduce) +} + +/// Parallel reduce where each worker keeps reusable scratch beside its accumulator (so the +/// per-task body needn't allocate). `(scratch, acc)` are created once per worker and threaded +/// through its batches; the `acc`s combine on the dispatcher. `combine` must be associative +/// and commutative, with `init_acc()` neutral; each worker folds a non-contiguous index set. +pub fn map_reduce_with_state(n_tasks: usize, init_state: IS, init_acc: IA, fold: F, combine: C) -> A +where + S: Send, + A: Send, + IS: Fn() -> S + Sync, + IA: Fn() -> A + Sync, + F: Fn(&mut S, &mut A, usize) + Sync, + C: Fn(A, A) -> A, +{ + let slots = drain_into_slots(n_tasks, |slot, start, end| { + let (state, acc) = slot.get_or_insert_with(|| (init_state(), init_acc())); + for i in start..end { + fold(state, acc, i); + } + }); + // `init_acc()` seeds the combine as a neutral element; the empty and single-thread cases + // (`for_each_chunk` runs inline) then fall out without a special path. + slots + .into_iter() + .flatten() + .map(|(_, acc)| acc) + .fold(init_acc(), &combine) +} diff --git a/crates/backend/poly/Cargo.toml b/crates/backend/poly/Cargo.toml index dcdf80aed..ecb5c7483 100644 --- a/crates/backend/poly/Cargo.toml +++ b/crates/backend/poly/Cargo.toml @@ -7,9 +7,10 @@ edition.workspace = true field = { path = "../field", package = "mt-field" } utils = { path = "../utils", package = "mt-utils" } system-info.workspace = true +parallel.workspace = true +zk-alloc.workspace = true itertools.workspace = true -rayon.workspace = true rand.workspace = true serde.workspace = true diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 64d3733f5..593e3f8d0 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -2,37 +2,86 @@ use crate::*; use crate::{EFPacking, PF}; use ::utils::{iter_array_chunks_padded, log2_ceil_usize, log2_strict_usize}; use field::*; -use rayon::prelude::*; -use system_info::NUM_THREADS; +use
system_info::num_threads; +use zk_alloc::ArenaVec; -const LOG_NUM_THREADS: usize = log2_ceil_usize(NUM_THREADS); -const NUM_THREADS_PADDED: usize = 1 << LOG_NUM_THREADS; const LOG_BATCHED_TILE_SIZE: usize = 14; +/// log2 oversubscription for the eq_mle fan-out: emit `num_threads() << this` chunks so the +/// pool's task counter rebalances across heterogeneous cores (e.g. P/E). `0` = one chunk +/// per worker; `2` (4x) is a conservative default that balances well without over-fragmenting. +const PARALLEL_LOG_OVERSUB: usize = 2; + +/// `(log2(n_chunks), n_chunks)` for the parallel fan-out. +#[inline] +fn parallel_split() -> (usize, usize) { + let log_chunks = log2_ceil_usize(num_threads()) + PARALLEL_LOG_OVERSUB; + (log_chunks, 1 << log_chunks) +} + +#[inline] +fn par_chunks_zip(out: &mut [T], chunk: usize, buf: &[A], g: G) +where + T: Send, + A: Sync, + G: Fn(&mut [T], &A) + Sync, +{ + debug_assert_eq!(out.len(), chunk * buf.len()); + parallel::par_chunks_mut(out, chunk, |i, c| g(c, &buf[i])); +} + +#[inline] +fn par_eval_eq( + eval: &[In], + out: &mut [Out], + log_chunks: usize, + n_chunks: usize, + log_packing_width: usize, + seed: Buf, + kernel: impl Fn(&[In], &mut [Out], Buf) + Sync, +) where + In: Field, + Buf: Algebra + Copy + Send + Sync, + Out: Send, +{ + let mut buffer = Buf::zero_vec(n_chunks); + buffer[0] = seed; + fill_buffer(eval[..log_chunks].iter().rev(), &mut buffer); + + let out_chunk_size = out.len() / n_chunks; + let middle = &eval[log_chunks..(eval.len() - log_packing_width)]; + par_chunks_zip(out, out_chunk_size, &buffer, |out_chunk, buffer_val| { + kernel(middle, out_chunk, *buffer_val); + }); +} + /// Given `evals` = (α_1, ..., α_n), returns a multilinear polynomial P in n variables, /// defined on the boolean hypercube by: ∀ (x_1, ..., x_n) ∈ {0, 1}^n, /// P(x_1, ..., x_n) = Π_{i=1}^{n} (x_i.α_i + (1 - x_i).(1 - α_i)) /// (often denoted as P(x) = eq(x, evals)) -pub fn eval_eq>>(eval: &[F]) -> Vec { +/// Returns an arena-backed table (see 
[`ArenaVec`]). Every eq table is phase-local proof scratch +/// (consumed within the proving phase that built it, or system-backed when the arena is inactive, +/// e.g. in the verifier), so it never outlives a `begin_phase()` reset. +pub fn eval_eq>>(eval: &[F]) -> ArenaVec { eval_eq_scaled(eval, F::ONE) } -pub fn eval_eq_scaled>>(eval: &[F], scalar: F) -> Vec { +pub fn eval_eq_scaled>>(eval: &[F], scalar: F) -> ArenaVec { // Alloc memory without initializing it to zero. - // This is safe because we overwrite it inside `eval_eq`. - let mut out = unsafe { uninitialized_vec(1 << eval.len()) }; + // This is safe because we overwrite it inside `compute_eval_eq`. + let mut out = unsafe { ArenaVec::uninitialized(1 << eval.len()) }; compute_eval_eq::, F, false>(eval, &mut out, scalar); out } -pub fn eval_eq_packed>>(eval: &[F]) -> Vec> { +pub fn eval_eq_packed>>(eval: &[F]) -> ArenaVec> { eval_eq_packed_scaled(eval, F::ONE) } -pub fn eval_eq_packed_scaled>>(eval: &[F], scalar: F) -> Vec> { +pub fn eval_eq_packed_scaled>>(eval: &[F], scalar: F) -> ArenaVec> { // Alloc memory without initializing it to zero. - // This is safe because we overwrite it inside `eval_eq`. - let mut out = unsafe { uninitialized_vec(1 << (eval.len() - packing_log_width::())) }; + // This is safe because we overwrite it inside `compute_eval_eq_packed`. + let mut out = unsafe { ArenaVec::uninitialized(1 << (eval.len() - packing_log_width::())) }; compute_eval_eq_packed::(eval, &mut out, scalar); out } @@ -59,7 +108,7 @@ where let packed = &mut out[selector >> shift]; let mut unpacked: Vec = unpack_extension(&[*packed]); compute_sparse_eval_eq::(selector & ((1 << shift) - 1), eval, &mut unpacked, scalar); - *packed = pack_extension(&unpacked)[0]; + *packed = pack_extension::<_, Vec<_>>(&unpacked)[0]; return; } @@ -87,62 +136,33 @@ where F: Field, EF: ExtensionField, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). 
- // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. - // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Too small to be worth packing/parallelizing. eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. 
- let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_scalar::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + // Split `eval` into [leading `log_chunks` | middle | trailing `log_packing_width`]: the + // trailing vars fold into the per-chunk seed, the leading vars index the chunks, the + // middle runs in parallel. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_scalar::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } #[inline] @@ -150,16 +170,8 @@ pub fn compute_eval_eq_packed(eval: &[EF], out: &mu where EF: ExtensionField>, { - // It's possible for this to be called with F = EF (Despite F actually being an extension field). - // - // IMPORTANT: We previously checked here that `packing_width > 1`, - // but this check is **not viable** for Goldilocks on Neon or when not using `target-cpu=native`. - // - // Why? Because Neon SIMD vectors are 128 bits and Goldilocks elements are already 64 bits, - // so no packing happens (width stays 1), and there's no performance advantage. 
- // - // Be careful: this means code relying on packing optimizations should **not assume** - // `packing_width > 1` is always true. + // `packing_width` may be 1 (e.g. Goldilocks on Neon, or without `target-cpu=native`), + // so nothing here may assume it is > 1. let packing_width = packing_width::(); let log_packing_width = log2_strict_usize(packing_width); @@ -168,12 +180,13 @@ where // If the number of variables is small, there is no need to use // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. + let mut unpacked = unsafe { ArenaVec::zeroed(1 << eval.len()) }; + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -181,40 +194,22 @@ where *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of elements of size `NUM_THREADS`. 
- let mut parallel_buffer = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], scalar); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - eval_eq_with_packed_output::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - ); - }); + return; } + + // See `compute_eval_eq` for the leading/middle/trailing split. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], scalar); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + eval_eq_with_packed_output::<_, _, INITIALIZED>(middle, out_chunk, buffer_val); + }, + ); } /// Computes the equality polynomial evaluations efficiently. @@ -240,57 +235,30 @@ where F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let log_packing_width = log2_strict_usize(F::Packing::WIDTH); - - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. debug_assert_eq!(out.len(), 1 << eval.len()); - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. 
+ let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { eval_eq_basic::<_, _, _, INITIALIZED>(eval, out, scalar); return; } - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. - let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. 
- out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed::<_, _, INITIALIZED>( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar, - ); - }); + // Base-field input: seed the per-chunk buffer with `F::Packing` (not `EF::ExtensionPacking`) + // and apply `scalar` inside the kernel — slightly more ops but less data movement, which is + // faster here in practice. See `compute_eval_eq` for the leading/middle/trailing split. + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed::<_, _, INITIALIZED>(middle, out_chunk, buffer_val, scalar); + }, + ); } #[inline] @@ -302,24 +270,18 @@ pub fn compute_eval_eq_base_packed( F: Field, EF: ExtensionField, { - // we assume that packing_width is a power of 2. let packing_width = F::Packing::WIDTH; let log_packing_width = log2_strict_usize(packing_width); assert!(log_packing_width <= eval.len()); assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - // Ensure that the output buffer size is correct: - // It should be of size `2^n`, where `n` is the number of variables. - debug_assert_eq!(out.len(), 1 << (eval.len() - log_packing_width)); - - // If the number of variables is small, there is no need to use - // parallelization or packings. - if eval.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - // A basic recursive approach. - let mut output_no_packing = EF::zero_vec(1 << eval.len()); - eval_eq_basic::<_, _, _, false>(eval, &mut output_no_packing, scalar); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + let (log_chunks, n_chunks) = parallel_split(); + if eval.len() <= log_packing_width + 1 + log_chunks { + // Small case: evaluate unpacked, then pack lanes into `out`. 
+ let mut unpacked = unsafe { ArenaVec::zeroed(1 << eval.len()) }; + eval_eq_basic::<_, _, _, false>(eval, &mut unpacked, scalar); + out.iter_mut() + .zip(unpacked.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { if INITIALIZED { *out_elem += EF::ExtensionPacking::from_ext_slice(chunk); @@ -327,45 +289,24 @@ pub fn compute_eval_eq_base_packed( *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); } }); - } else { - let eval_len_min_packing = eval.len() - log_packing_width; - - // We split eval into three parts: - // - eval[..LOG_NUM_THREADS] (the first LOG_NUM_THREADS elements) - // - eval[LOG_NUM_THREADS..eval_len_min_packing] (the middle elements) - // - eval[eval_len_min_packing..] (the last log_packing_width elements) - - // The middle elements are the ones which will be computed in parallel. - // The last log_packing_width elements are the ones which will be packed. - - // We make a buffer of PackedField elements of size `NUM_THREADS`. - // Note that this is a slightly different strategy to `eval_eq` which instead - // uses PackedExtensionField elements. Whilst this involves slightly more mathematical - // operations, it seems to be faster in practice due to less data moving around. - let mut parallel_buffer = F::Packing::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / NUM_THREADS_PADDED; - - // Compute the equality polynomial corresponding to the last log_packing_width elements - // and pack these. - parallel_buffer[0] = packed_eq_poly(&eval[eval_len_min_packing..], F::ONE); - - // Update the buffer so it contains the evaluations of the equality polynomial - // with respect to parts one and three. - fill_buffer(eval[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer); - - // Finally do all computations involving the middle elements in parallel. 
- let scalar_packed = EF::ExtensionPacking::from(scalar); - out.par_chunks_exact_mut(out_chunk_size) - .zip(parallel_buffer.par_iter()) - .for_each(|(out_chunk, buffer_val)| { - base_eval_eq_packed_with_packed_output::( - &eval[LOG_NUM_THREADS..(eval.len() - log_packing_width)], - out_chunk, - *buffer_val, - scalar_packed, - ); - }); + return; } + + // Base-field input: seed with `F::Packing` and apply `scalar` in the kernel (less data + // movement — see `compute_eval_eq_base`). See `compute_eval_eq` for the split. + let scalar_packed = EF::ExtensionPacking::from(scalar); + let seed = packed_eq_poly(&eval[eval.len() - log_packing_width..], F::ONE); + par_eval_eq( + eval, + out, + log_chunks, + n_chunks, + log_packing_width, + seed, + |middle, out_chunk, buffer_val| { + base_eval_eq_packed_with_packed_output::(middle, out_chunk, buffer_val, scalar_packed); + }, + ); } #[inline] @@ -406,27 +347,27 @@ pub fn compute_eval_eq_base_packed_batched( .map(|(eval, &scalar)| { let middle = &eval[n_prefix_levels..n - log_packing_width]; let eq_suffix = packed_eq_poly::(&eval[n - log_packing_width..], F::ONE); - let mut eq_prefix: Vec = unsafe { uninitialized_vec(1 << n_prefix_levels) }; + let mut eq_prefix: ArenaVec = unsafe { ArenaVec::uninitialized(1 << n_prefix_levels) }; eval_eq_basic::(&eval[..n_prefix_levels], &mut eq_prefix, scalar); (eq_prefix, middle, eq_suffix) }) .collect(); - out.par_chunks_exact_mut(tile_packed_size) - .enumerate() - .for_each(|(tile_idx, out_tile)| { - for (eq_prefix, middle, eq_suffix) in &per_query { - // Here e could precompute the eq poly, trading some memory for less computation - // (2x faster on M4 max, but 2x slower on machines with smaller caches. - // TODO implement both and choose based on cache size?) 
- base_eval_eq_packed_with_packed_output::( - middle, - out_tile, - *eq_suffix, - EF::ExtensionPacking::from(eq_prefix[tile_idx]), - ); - } - }); + // `out` already splits into `2^n_prefix_levels` tiles — many more than there are + // workers — so the pool's task counter load-balances these directly. + parallel::par_chunks_mut(out, tile_packed_size, |tile_idx, out_tile| { + for (eq_prefix, middle, eq_suffix) in &per_query { + // Here e could precompute the eq poly, trading some memory for less computation + // (2x faster on M4 max, but 2x slower on machines with smaller caches. + // TODO implement both and choose based on cache size?) + base_eval_eq_packed_with_packed_output::( + middle, + out_tile, + *eq_suffix, + EF::ExtensionPacking::from(eq_prefix[tile_idx]), + ); + } + }); } /// Fills the `buffer` with evaluations of the equality polynomial @@ -944,39 +885,40 @@ pub fn compute_eval_eq_packed_dual( assert!(log_packing_width <= eval_a.len()); assert_eq!(out.len(), 1 << (eval_a.len() - log_packing_width)); - if eval_a.len() <= log_packing_width + 1 + LOG_NUM_THREADS { - let mut output_no_packing = EF::zero_vec(1 << eval_a.len()); + let (log_chunks, n_chunks) = parallel_split(); + if eval_a.len() <= log_packing_width + 1 + log_chunks { + let mut output_no_packing = unsafe { ArenaVec::zeroed(1 << eval_a.len()) }; eval_eq_basic::<_, _, _, false>(eval_a, &mut output_no_packing, scalar_a); eval_eq_basic::<_, _, _, true>(eval_b, &mut output_no_packing, scalar_b); - out.par_iter_mut() - .zip(output_no_packing.par_chunks_exact(packing_width)) + out.iter_mut() + .zip(output_no_packing.chunks_exact(packing_width)) .for_each(|(out_elem, chunk)| { *out_elem = EF::ExtensionPacking::from_ext_slice(chunk); }); } else { let eval_len_min_packing = eval_a.len() - log_packing_width; - let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(NUM_THREADS_PADDED); - let out_chunk_size = out.len() / 
NUM_THREADS_PADDED; + let mut parallel_buffer_a = EF::ExtensionPacking::zero_vec(n_chunks); + let mut parallel_buffer_b = EF::ExtensionPacking::zero_vec(n_chunks); + let out_chunk_size = out.len() / n_chunks; parallel_buffer_a[0] = packed_eq_poly(&eval_a[eval_len_min_packing..], scalar_a); - fill_buffer(eval_a[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_a); + fill_buffer(eval_a[..log_chunks].iter().rev(), &mut parallel_buffer_a); parallel_buffer_b[0] = packed_eq_poly(&eval_b[eval_len_min_packing..], scalar_b); - fill_buffer(eval_b[..LOG_NUM_THREADS].iter().rev(), &mut parallel_buffer_b); - - out.par_chunks_exact_mut(out_chunk_size) - .enumerate() - .for_each(|(i, out_chunk)| { - eval_eq_with_packed_output_dual::, EF>( - &eval_a[LOG_NUM_THREADS..eval_len_min_packing], - &eval_b[LOG_NUM_THREADS..eval_len_min_packing], - out_chunk, - parallel_buffer_a[i], - parallel_buffer_b[i], - ); - }); + fill_buffer(eval_b[..log_chunks].iter().rev(), &mut parallel_buffer_b); + + let middle_a = &eval_a[log_chunks..eval_len_min_packing]; + let middle_b = &eval_b[log_chunks..eval_len_min_packing]; + parallel::par_chunks_mut(out, out_chunk_size, |i, out_chunk| { + eval_eq_with_packed_output_dual::, EF>( + middle_a, + middle_b, + out_chunk, + parallel_buffer_a[i], + parallel_buffer_b[i], + ); + }); } } @@ -1200,7 +1142,7 @@ fn packed_eq_poly>(eval: &[EF], scalar: EF) -> E debug_assert_eq!(F::Packing::WIDTH, 1 << eval.len()); // We build up the evaluations of the equality polynomial in buffer. 
- let mut buffer = EF::zero_vec(1 << eval.len()); + let mut buffer = unsafe { ArenaVec::zeroed(1 << eval.len()) }; buffer[0] = scalar; fill_buffer(eval.iter().rev(), &mut buffer); @@ -1312,7 +1254,7 @@ mod tests { let time = Instant::now(); compute_eval_eq::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("EXTENSION PACKED AFTER: {:?}", time.elapsed()); @@ -1347,7 +1289,7 @@ mod tests { let time = Instant::now(); compute_eval_eq_base::(&eval, &mut out_3, scalar); let out_3_packed = out_3 - .par_chunks_exact(packing_width) + .chunks_exact(packing_width) .map(>::ExtensionPacking::from_ext_slice) .collect::>(); println!("BASE PACKED AFTER: {:?}", time.elapsed()); diff --git a/crates/backend/poly/src/evals.rs b/crates/backend/poly/src/evals.rs index 7e0e07b4f..42af95ff5 100644 --- a/crates/backend/poly/src/evals.rs +++ b/crates/backend/poly/src/evals.rs @@ -1,10 +1,9 @@ use crate::*; use crate::{EFPacking, PF}; +use ::utils::log2_ceil_usize; use field::{ExtensionField, Field, PrimeCharacteristicRing}; use itertools::Itertools; -use rayon::{join, prelude::*}; -use std::borrow::Borrow; - +use zk_alloc::ArenaVec; pub trait EvaluationsList { fn num_variables(&self) -> usize; fn num_evals(&self) -> usize; @@ -14,30 +13,30 @@ pub trait EvaluationsList { fn evaluate_sparse>(&self, selector: usize, point: &MultilinearPoint) -> EF; } -impl> EvaluationsList for EL { +impl> EvaluationsList for EL { fn num_variables(&self) -> usize { - self.borrow().len().ilog2() as usize + self.as_ref().len().ilog2() as usize } fn num_evals(&self) -> usize { - self.borrow().len() + self.as_ref().len() } fn evaluate>(&self, point: &MultilinearPoint) -> EF { - eval_multilinear::<_, _, true>(self.borrow(), point) + eval_multilinear::<_, _, true>(self.as_ref(), point) } fn evaluate_sequential>(&self, point: &MultilinearPoint) -> EF { - eval_multilinear::<_, _, 
false>(self.borrow(), point) + eval_multilinear::<_, _, false>(self.as_ref(), point) } fn as_constant(&self) -> F { - assert_eq!(self.borrow().len(), 1); - self.borrow()[0] + assert_eq!(self.as_ref().len(), 1); + self.as_ref()[0] } fn evaluate_sparse>(&self, selector: usize, point: &MultilinearPoint) -> EF { - (&self.borrow()[selector << point.len()..][..(1 << point.len())]).evaluate(point) + (&self.as_ref()[selector << point.len()..][..(1 << point.len())]).evaluate(point) } } @@ -81,16 +80,6 @@ where } } -/// Multiply the polynomial by a scalar factor. -#[must_use] -pub fn scale_poly>(poly: &[F], factor: EF) -> Vec { - if poly.len() < PARALLEL_THRESHOLD { - poly.iter().map(|&e| factor * e).collect() - } else { - poly.par_iter().map(|&e| factor * e).collect() - } -} - fn eval_multilinear(evals: &[F], point: &[EF]) -> EF where F: Field, @@ -219,7 +208,7 @@ where // The `evals` are ordered lexicographically, meaning the first variable's bit changes the slowest. // // To align our computation with this memory layout, we process the point's coordinates in reverse. - let mut point_rev = point.to_vec(); + let mut point_rev = ArenaVec::from_slice(point); point_rev.reverse(); // Split the reversed point's coordinates into two halves: @@ -236,18 +225,18 @@ where // We precompute all `2^|z1|` values of eq(v_high, p_high) and store them in `right`. // Allocate uninitialized memory for the low-order basis polynomial evaluations. - let mut left = unsafe { uninitialized_vec(1 << z0.len()) }; + let mut left: ArenaVec<_> = unsafe { ArenaVec::uninitialized(1 << z0.len()) }; // Allocate uninitialized memory for the high-order basis polynomial evaluations. - let mut right = unsafe { uninitialized_vec(1 << z1.len()) }; + let mut right: ArenaVec<_> = unsafe { ArenaVec::uninitialized(1 << z1.len()) }; // The `eval_eq` function requires the variables in their original order, so we reverse the halves back. 
- let mut z0_ordered = z0.to_vec(); + let mut z0_ordered = ArenaVec::from_slice(z0); z0_ordered.reverse(); // Compute all eq(v_low, p_low) values and fill the `left` vector. compute_eval_eq::<_, _, false>(&z0_ordered, &mut left, Point::ONE); // Repeat the process for the high-order variables. - let mut z1_ordered = z1.to_vec(); + let mut z1_ordered = ArenaVec::from_slice(z1); z1_ordered.reverse(); // Compute all eq(v_high, p_high) values and fill the `right` vector. compute_eval_eq::<_, _, false>(&z1_ordered, &mut right, Point::ONE); @@ -257,20 +246,23 @@ where // // This chain of operations computes the regrouped sum: // Σ_{v_high} eq(v_high, p_high) * (Σ_{v_low} f(v_high, v_low) * eq(v_low, p_low)) - evals - .par_chunks(left.len()) - .zip_eq(right.par_iter()) - .map(|(part, &c)| { + let left_len = left.len(); + parallel::map_reduce( + right.len(), + || Res::ZERO, + |i| { + let part = &evals[i * left_len..][..left_len]; // This is the inner sum: a dot product between the evaluation chunk and the `left` basis values. mul_res_point( part.iter() .zip_eq(left.iter()) .map(|(&a, &b)| mul_coeffs_point(a, b)) .sum::(), - c, + right[i], ) - }) - .sum() + }, + |a, b| a + b, + ) } else { evals .chunks(left.len()) @@ -290,62 +282,66 @@ where } else { // For moderately sized inputs (5 to 19 variables), use the recursive strategy. // - // Split the evaluations into two halves, corresponding to the first variable being 0 or 1. - let (f0, f1) = evals.split_at(evals.len() / 2); - - // Recursively evaluate on the two smaller hypercubes. - let (f0_eval, f1_eval) = { - // Only spawn parallel tasks if the subproblem is large enough to overcome - // the overhead of threading. 
- let work_size: usize = (1 << 15) / std::mem::size_of::(); - if evals.len() > work_size && PARALLEL { - join( - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, - || { - eval_multilinear_generic::<_, _, _, _, _, _, PARALLEL>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ) - }, + // Only spawn parallel tasks if the subproblem is large enough to overcome + // the overhead of threading. + let work_size: usize = (1 << 15) / std::mem::size_of::(); + if evals.len() > work_size && PARALLEL { + let log_work = log2_ceil_usize(work_size.max(2)); + let n_split = point.len().saturating_sub(log_work).max(1); + let (lead, sub_point) = point.split_at(n_split); + let n_chunks = 1 << n_split; + let chunk = evals.len() >> n_split; + let partials = parallel::par_map_collect(n_chunks, |j| { + eval_multilinear_generic::<_, _, _, _, _, _, false>( + &evals[j * chunk..][..chunk], + sub_point, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, ) - } else { - // For smaller subproblems, execute sequentially. - ( - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f0, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - eval_multilinear_generic::<_, _, _, _, _, _, false>( - f1, - tail, - mul_coeffs_point, - add_res_coeffs, - mul_res_point, - ), - ) - } - }; - // Perform the final linear interpolation for the first variable `x`. - f0_eval + mul_res_point(f1_eval - f0_eval, *x) + }); + interpolate_res(&partials, lead, mul_res_point) + } else { + let (f0, f1) = evals.split_at(evals.len() / 2); + let f0_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f0, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + let f1_eval = eval_multilinear_generic::<_, _, _, _, _, _, false>( + f1, + tail, + mul_coeffs_point, + add_res_coeffs, + mul_res_point, + ); + // Perform the final linear interpolation for the first variable `x`. 
+ f0_eval + mul_res_point(f1_eval - f0_eval, *x) + } } } } } +fn interpolate_res(values: &[Res], point: &[Point], mul_res_point: &MRP) -> Res +where + Point: Field, + Res: Copy + PrimeCharacteristicRing, + MRP: Fn(Res, Point) -> Res, +{ + match point { + [] => values[0], + [x, tail @ ..] => { + let (low, high) = values.split_at(values.len() / 2); + let p0 = interpolate_res(low, tail, mul_res_point); + let p1 = interpolate_res(high, tail, mul_res_point); + p0 + mul_res_point(p1 - p0, *x) + } + } +} + #[cfg(test)] mod tests { use std::time::Instant; @@ -369,7 +365,7 @@ mod tests { let res_normal = eval_multilinear::<_, _, true>(&poly, &point); println!("Normal eval time: {:?}", time.elapsed()); - let packed_poly = pack_extension(&poly); + let packed_poly: Vec<_> = pack_extension(&poly); let time = Instant::now(); let res_packed = eval_packed::<_, true>(&packed_poly, &point); println!("Packed eval time: {:?}", time.elapsed()); diff --git a/crates/backend/poly/src/lib.rs b/crates/backend/poly/src/lib.rs index adddb55dc..c72266ca0 100644 --- a/crates/backend/poly/src/lib.rs +++ b/crates/backend/poly/src/lib.rs @@ -23,3 +23,6 @@ pub use evals::*; mod wrappers; pub use wrappers::*; + +mod multilinear_utils; +pub use multilinear_utils::*; diff --git a/crates/backend/poly/src/mle/mle_group_owned.rs b/crates/backend/poly/src/mle/mle_group_owned.rs index f0efcc3b9..d4e57ed1d 100644 --- a/crates/backend/poly/src/mle/mle_group_owned.rs +++ b/crates/backend/poly/src/mle/mle_group_owned.rs @@ -1,31 +1,25 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleGroupOwned>> { - Base(Vec>>), - Extension(Vec>), - BasePacked(Vec>>), - ExtensionPacked(Vec>>), + Base(Vec>>), + Extension(Vec>), + BasePacked(Vec>>), + ExtensionPacked(Vec>>), } impl>> MleGroupOwned { - pub fn as_extension_mut(&mut self) -> Option<&mut Vec>> { - match self { - Self::Extension(e) => Some(e), - _ => None, - } - } - - pub fn 
as_extension_packed_mut(&mut self) -> Option<&mut Vec>>> { + pub fn as_extension_packed_mut(&mut self) -> Option<&mut Vec>>> { match self { Self::ExtensionPacked(e) => Some(e), _ => None, } } - pub fn as_extension(self) -> Option>> { + pub fn as_extension(self) -> Option>> { match self { Self::Extension(e) => Some(e), _ => None, diff --git a/crates/backend/poly/src/mle/mle_group_ref.rs b/crates/backend/poly/src/mle/mle_group_ref.rs index e4a993eb0..335399168 100644 --- a/crates/backend/poly/src/mle/mle_group_ref.rs +++ b/crates/backend/poly/src/mle/mle_group_ref.rs @@ -2,6 +2,7 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; use field::PackedValue; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleGroupRef<'a, EF: ExtensionField>> { @@ -158,10 +159,12 @@ impl<'a, EF: ExtensionField>> MleGroupRef<'a, EF> { pub fn clone_to_owned(&self) -> MleGroupOwned { match self { - Self::Base(pols) => MleGroupOwned::Base(pols.iter().map(|v| v.to_vec()).collect()), - Self::Extension(pols) => MleGroupOwned::Extension(pols.iter().map(|v| v.to_vec()).collect()), - Self::BasePacked(pols) => MleGroupOwned::BasePacked(pols.iter().map(|v| v.to_vec()).collect()), - Self::ExtensionPacked(pols) => MleGroupOwned::ExtensionPacked(pols.iter().map(|v| v.to_vec()).collect()), + Self::Base(pols) => MleGroupOwned::Base(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::Extension(pols) => MleGroupOwned::Extension(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::BasePacked(pols) => MleGroupOwned::BasePacked(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()), + Self::ExtensionPacked(pols) => { + MleGroupOwned::ExtensionPacked(pols.iter().map(|v| ArenaVec::from_slice(v)).collect()) + } } } } diff --git a/crates/backend/poly/src/mle/mle_single_owned.rs b/crates/backend/poly/src/mle/mle_single_owned.rs index 7a2d99c31..538060296 100644 --- a/crates/backend/poly/src/mle/mle_single_owned.rs +++ 
b/crates/backend/poly/src/mle/mle_single_owned.rs @@ -1,18 +1,19 @@ use crate::{EFPacking, Mle, MleRef, MultilinearPoint, PF, PFPacking, pack_extension, packing_width, unpack_extension}; use field::PackedValue; use field::{ExtensionField, PackedFieldExtension}; +use zk_alloc::ArenaVec; #[derive(Debug, Clone)] pub enum MleOwned>> { - Base(Vec>), - Extension(Vec), - BasePacked(Vec>), - ExtensionPacked(Vec>), + Base(ArenaVec>), + Extension(ArenaVec), + BasePacked(ArenaVec>), + ExtensionPacked(ArenaVec>), } impl>> Default for MleOwned { fn default() -> Self { - Self::Base(vec![]) + Self::Base(ArenaVec::new()) } } @@ -63,35 +64,35 @@ impl>> MleOwned { } } - pub fn as_extension_packed_mut(&mut self) -> Option<&mut Vec>> { + pub fn as_extension_packed_mut(&mut self) -> Option<&mut ArenaVec>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, } } - pub fn into_base(self) -> Option>> { + pub fn into_base(self) -> Option>> { match self { Self::Base(b) => Some(b), _ => None, } } - pub fn into_extension(self) -> Option> { + pub fn into_extension(self) -> Option> { match self { Self::Extension(e) => Some(e), _ => None, } } - pub fn into_base_backed(self) -> Option>> { + pub fn into_base_backed(self) -> Option>> { match self { Self::BasePacked(pb) => Some(pb), _ => None, } } - pub fn into_extension_packed(self) -> Option>> { + pub fn into_extension_packed(self) -> Option>> { match self { Self::ExtensionPacked(ep) => Some(ep), _ => None, diff --git a/crates/backend/poly/src/mle/mle_single_ref.rs b/crates/backend/poly/src/mle/mle_single_ref.rs index 61d607d76..e925d7eff 100644 --- a/crates/backend/poly/src/mle/mle_single_ref.rs +++ b/crates/backend/poly/src/mle/mle_single_ref.rs @@ -2,6 +2,7 @@ use crate::*; use ::utils::log2_strict_usize; use field::ExtensionField; use field::PackedValue; +use zk_alloc::ArenaVec; #[derive(Debug)] pub enum MleRef<'a, EF: ExtensionField>> { @@ -104,28 +105,26 @@ impl<'a, EF: ExtensionField>> MleRef<'a, EF> { } } - pub fn 
pack_if(&self, cond: bool) -> Mle<'a, EF> { - if cond { self.pack() } else { Mle::Ref(self.soft_clone()) } - } - pub fn clone_to_owned(&self) -> MleOwned { match self { - Self::Base(v) => MleOwned::Base(v.to_vec()), - Self::Extension(v) => MleOwned::Extension(v.to_vec()), - Self::BasePacked(pb) => MleOwned::BasePacked(pb.to_vec()), - Self::ExtensionPacked(ep) => MleOwned::ExtensionPacked(ep.to_vec()), + Self::Base(v) => MleOwned::Base(ArenaVec::from_slice(v)), + Self::Extension(v) => MleOwned::Extension(ArenaVec::from_slice(v)), + Self::BasePacked(pb) => MleOwned::BasePacked(ArenaVec::from_slice(pb)), + Self::ExtensionPacked(ep) => MleOwned::ExtensionPacked(ArenaVec::from_slice(ep)), } } pub fn fold(&self, alpha: EF) -> MleOwned { match self { - Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), - Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a)), + Self::Base(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), + Self::Extension(pols) => MleOwned::Extension(fold_multilinear(pols, alpha, &|a, b| b * a, false)), Self::BasePacked(pols) => { let alpha_packed = EFPacking::::from(alpha); - MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a)) + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha_packed, &|a, b| b * a, false)) + } + Self::ExtensionPacked(pols) => { + MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b, false)) } - Self::ExtensionPacked(pols) => MleOwned::ExtensionPacked(fold_multilinear(pols, alpha, &|a, b| a * b)), } } } diff --git a/crates/backend/poly/src/multilinear_utils.rs b/crates/backend/poly/src/multilinear_utils.rs new file mode 100644 index 000000000..d96cbc57e --- /dev/null +++ b/crates/backend/poly/src/multilinear_utils.rs @@ -0,0 +1,137 @@ +use field::{ExtensionField, Field, dot_product}; +use utils::*; + +use crate::{EFPacking, EvaluationsList as _, MultilinearPoint, PF, PFPacking}; + +pub 
fn multilinear_eval_constants_at_right(limit: usize, point: &[F]) -> F { + let n_vars = point.len(); + + // multilinear polynomial = [0 0 --- 0][1 1 --- 1] (`limit` times 0, then `2^n_vars - limit` times 1) evaluated at `point` + + assert!(limit <= (1 << n_vars), "limit {limit} is too large for n_vars {n_vars}"); + + if limit == 1 << n_vars { + return F::ZERO; + } + + if point.is_empty() { + assert!(limit <= 1); + if limit == 1 { F::ZERO } else { F::ONE } + } else { + let main_bit = limit >> (n_vars - 1); + if main_bit == 1 { + // limit is at the right half + point[0] * multilinear_eval_constants_at_right(limit - (1 << (n_vars - 1)), &point[1..]) + } else { + // limit is at left half + point[0] + (F::ONE - point[0]) * multilinear_eval_constants_at_right(limit, &point[1..]) + } + } +} + +pub fn padd_with_zero_to_next_power_of_two(pol: &[F]) -> Vec { + let next_power_of_two = pol.len().next_power_of_two(); + let mut padded = pol.to_vec(); + padded.resize(next_power_of_two, F::ZERO); + padded +} + +pub fn evaluate_as_larger_multilinear_pol>(pol: &[F], point: &[EF]) -> EF { + // [[-pol-] 0 0 0 0 ... 0 0 0 0 0] evaluated at point + let pol_n_vars = log2_strict_usize(pol.len()); + assert!(point.len() >= pol_n_vars); + point + .iter() + .take(point.len() - pol_n_vars) + .map(|x| EF::ONE - *x) + .product::() + * pol.evaluate(&MultilinearPoint(from_end(point, pol_n_vars).to_vec())) +} + +pub fn mle_of_01234567_etc(point: &[F]) -> F { + if point.is_empty() { + F::ZERO + } else { + let e = mle_of_01234567_etc(&point[1..]); + (F::ONE - point[0]) * e + point[0] * (e + F::from_usize(1 << (point.len() - 1))) + } +} + +/// Fingerprint of a logup data tuple. The `domainsep` always occupies the last +/// fingerprint slot (`alphas_eq_poly.last()`) for domain separation, while `data` +/// fills the low slots. 
+pub fn finger_print(domainsep: EF, data: &[EF], alphas_eq_poly: &[EF]) -> EF { + assert!(alphas_eq_poly.len() > data.len()); + dot_product::(alphas_eq_poly.iter().copied(), data.iter().copied()) + + *alphas_eq_poly.last().unwrap() * domainsep +} + +/// Packed variant of [`finger_print`]. +#[inline(always)] +pub fn finger_print_packed>>( + domainsep: PFPacking, + data: &[PFPacking], + alphas_packed: &[EFPacking], +) -> EFPacking { + let mut result = *alphas_packed.last().unwrap() * domainsep; + for (alpha, d) in alphas_packed.iter().zip(data) { + result += *alpha * *d; + } + result +} + +#[cfg(test)] +mod tests { + use field::PrimeCharacteristicRing; + use koala_bear::{KoalaBear, QuinticExtensionFieldKB}; + use rand::rngs::StdRng; + use rand::{RngExt, SeedableRng}; + + use super::*; + + type F = KoalaBear; + type EF = QuinticExtensionFieldKB; + + #[test] + fn test_evaluate_as_larger_multilinear_pol() { + let n_vars = 5; + let n_point_vars = 7; + let mut rng = StdRng::seed_from_u64(0); + let mut pol = F::zero_vec(1 << n_point_vars); + pol.iter_mut().take(1 << n_vars).for_each(|coeff| *coeff = rng.random()); + let point = (0..n_point_vars).map(|_| rng.random()).collect::>(); + assert_eq!( + evaluate_as_larger_multilinear_pol(&pol[..1 << n_vars], &point), + pol.evaluate(&MultilinearPoint(point)) + ); + } + + #[test] + fn test_multilinear_eval_constants_at_right() { + let n_vars = 10; + let mut rng = StdRng::seed_from_u64(0); + let point = (0..n_vars).map(|_| rng.random()).collect::>(); + for limit in [0, 1, 2, 45, 74, 451, 741, 1022, 1023] { + let eval = multilinear_eval_constants_at_right(limit, &point); + let mut pol = F::zero_vec(1 << n_vars); + pol.iter_mut() + .take(1 << n_vars) + .skip(limit) + .for_each(|coeff| *coeff = F::ONE); + assert_eq!(eval, pol.evaluate(&MultilinearPoint(point.clone()))); + } + } + + #[test] + fn test_mle_of_01234567_etc() { + let n_vars = 10; + let mut rng = StdRng::seed_from_u64(0); + let point = (0..n_vars).map(|_| 
rng.random()).collect::>(); + let eval = mle_of_01234567_etc(&point); + let mut pol = F::zero_vec(1 << n_vars); + for (i, p) in pol.iter_mut().enumerate().take(1 << n_vars) { + *p = F::from_usize(i); + } + assert_eq!(eval, pol.evaluate(&MultilinearPoint(point))); + } +} diff --git a/crates/backend/poly/src/next_mle.rs b/crates/backend/poly/src/next_mle.rs index 7c9c687c2..af8c08e96 100644 --- a/crates/backend/poly/src/next_mle.rs +++ b/crates/backend/poly/src/next_mle.rs @@ -1,4 +1,5 @@ use field::{ExtensionField, Field, PrimeCharacteristicRing}; +use zk_alloc::ArenaVec; use crate::{PF, eval_eq_scaled}; @@ -32,12 +33,12 @@ pub fn next_mle(x: &[F], y: &[F]) -> F { /// /// This is the "folded" version: the first argument (outer_challenges) is fixed, /// and the result is a vector indexed by the second argument. -pub fn matrix_next_mle_folded>>(outer_challenges: &[F]) -> Vec +pub fn matrix_next_mle_folded>>(outer_challenges: &[F]) -> ArenaVec where PF: PrimeCharacteristicRing, { let n = outer_challenges.len(); - let mut res = F::zero_vec(1 << n); + let mut res = unsafe { ArenaVec::::zeroed(1 << n) }; for k in 0..n { let outer_challenges_prod = (F::ONE - outer_challenges[n - k - 1]) * outer_challenges[n - k..].iter().copied().product::(); diff --git a/crates/backend/poly/src/utils.rs b/crates/backend/poly/src/utils.rs index 5bb5fb1b4..4f94b3f59 100644 --- a/crates/backend/poly/src/utils.rs +++ b/crates/backend/poly/src/utils.rs @@ -1,23 +1,16 @@ -use std::{ - mem::ManuallyDrop, - ops::{Add, Range, Sub}, -}; +use std::ops::{Add, Sub}; use field::*; -use rayon::{ - iter::Zip, - prelude::*, - slice::{Iter, IterMut}, -}; +use zk_alloc::{ArenaVec, OwnedBuffer}; use crate::{EFPacking, PF, PFPacking}; pub const PARALLEL_THRESHOLD: usize = 1 << 9; -pub fn pack_extension>>(slice: &[EF]) -> Vec> { +/// AoS->SoA transpose of `slice` into the already-sized packed buffer `out` (`out.len()` +/// packed elements, each consuming `packing_width` scalars). 
+fn fill_packed_extension>>(slice: &[EF], out: &mut [EFPacking]) { let width = packing_width::(); - let n_packed = slice.len() / width; - let mut out: Vec> = unsafe { uninitialized_vec(n_packed) }; let write = |slot: &mut EFPacking, chunk: &[EF]| { *slot = EFPacking::::from_ext_slice(chunk); }; @@ -26,17 +19,21 @@ pub fn pack_extension>>(slice: &[EF]) -> Vec>>(vec: &[EFPacking]) -> Vec { +pub fn pack_extension>, B: OwnedBuffer>>(slice: &[EF]) -> B { + B::build(slice.len() / packing_width::(), |out| { + fill_packed_extension(slice, out) + }) +} + +fn fill_unpacked_extension>>(vec: &[EFPacking], out: &mut [EF]) { let width = packing_width::(); - let total = vec.len() * width; - let mut out: Vec = unsafe { uninitialized_vec(total) }; + let total = out.len(); let write = |out_chunk: &mut [EF], x: &EFPacking| { let packed_coeffs = x.as_basis_coefficients_slice(); for (lane, slot) in out_chunk.iter_mut().enumerate() { @@ -48,11 +45,21 @@ pub fn unpack_extension>>(vec: &[EFPacking]) -> Ve write(chunk, x); } } else { - out.par_chunks_exact_mut(width) - .zip(vec.par_iter()) - .for_each(|(chunk, x)| write(chunk, x)); + // One pool task per group of `group` packed elements, each writing `group * width` + // contiguous output scalars from a disjoint slice of `vec`. 
+ let group = parallel::recommended_chunk_size(vec.len()); + parallel::par_chunks_mut(out, group * width, |ci, out_chunk| { + for (k, sub) in out_chunk.chunks_exact_mut(width).enumerate() { + write(sub, &vec[ci * group + k]); + } + }); } - out +} + +pub fn unpack_extension>, B: OwnedBuffer>(vec: &[EFPacking]) -> B { + B::build(vec.len() * packing_width::(), |out| { + fill_unpacked_extension(vec, out) + }) } pub const fn packing_log_width() -> usize { @@ -67,122 +74,87 @@ pub const fn must_unpack_multilinears(n_vars: usize) -> bool { n_vars <= 1 + packing_log_width::() } -pub fn batch_fold_multilinears< - EF: PrimeCharacteristicRing + Copy + Send + Sync, - IF: Copy + Sub + Send + Sync, - OF: Copy + Add + Send + Sync, - F: Fn(IF, EF) -> OF + Sync + Send, ->( - polys: &[&[IF]], - alpha: EF, - mul_if_of: F, -) -> Vec> { - let total_size: usize = polys.iter().map(|p| p.len()).sum(); - if total_size < PARALLEL_THRESHOLD { - polys - .iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() +#[inline] +fn fill_fold OF + Sync>(res: &mut [OF], seq: bool, compute: C) { + if seq || res.len() < PARALLEL_THRESHOLD { + for (i, r) in res.iter_mut().enumerate() { + *r = compute(i); + } } else { - polys - .par_iter() - .map(|poly| fold_multilinear(poly, alpha, &mul_if_of)) - .collect() + parallel::par_fill(res, &compute); } } -pub fn fold_multilinear_lsb< +#[inline] +fn fold_fill, C: Fn(usize) -> OF + Sync>(len: usize, seq: bool, compute: C) -> B { + B::build(len, |res| fill_fold(res, seq, compute)) +} + +pub fn fold_multilinear< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, - Mul: Fn(IF, EF) -> OF + Sync + Send, + F: Fn(IF, EF) -> OF + Sync + Send, + B: OwnedBuffer, >( m: &[IF], alpha: EF, - mul_if_of: &Mul, -) -> Vec { + mul_if_of: &F, + seq: bool, +) -> B { let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; - let compute = |(c, r_v): (&[IF], &mut OF)| { - *r_v = 
mul_if_of(c[1] - c[0], alpha) + c[0]; - }; - if new_size < PARALLEL_THRESHOLD { - m.chunks_exact(2).zip(res.iter_mut()).for_each(compute); - } else { - m.par_chunks_exact(2).zip(res.par_iter_mut()).for_each(compute); - } - res + fold_fill(new_size, seq, |i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) } pub fn fold_multilinear_at_bit< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, - Mul: Fn(IF, EF) -> OF + Sync + Send, + F: Fn(IF, EF) -> OF + Sync + Send, + B: OwnedBuffer, >( m: &[IF], alpha: EF, bit: usize, - mul_if_of: &Mul, -) -> Vec { - let new_size = m.len() / 2; + mul_if_of: &F, + seq: bool, +) -> B { assert!(m.len() >= 2 * (1 << bit), "bit out of range for slice length"); - if bit == 0 { - return fold_multilinear_lsb(m, alpha, mul_if_of); + return fold_fill(m.len() / 2, seq, |j| { + mul_if_of(m[2 * j + 1] - m[2 * j], alpha) + m[2 * j] + }); } - let stride = 1usize << bit; let lo_mask = stride - 1; - let mut res = unsafe { uninitialized_vec(new_size) }; - - let compute = |new_j: usize| { + fold_fill(m.len() / 2, seq, |new_j| { let i_hi = new_j >> bit; let i_lo = new_j & lo_mask; let i0 = (i_hi << (bit + 1)) | i_lo; let i1 = i0 | stride; mul_if_of(m[i1] - m[i0], alpha) + m[i0] - }; - - if new_size < PARALLEL_THRESHOLD { - for (new_j, res_v) in res.iter_mut().enumerate() { - *res_v = compute(new_j); - } - } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(compute) - .collect_into_vec(&mut res); - } - res + }) } -pub fn fold_multilinear< +pub fn batch_fold_multilinears< EF: PrimeCharacteristicRing + Copy + Send + Sync, IF: Copy + Sub + Send + Sync, OF: Copy + Add + Send + Sync, F: Fn(IF, EF) -> OF + Sync + Send, >( - m: &[IF], + polys: &[&[IF]], alpha: EF, - mul_if_of: &F, -) -> Vec { - let new_size = m.len() / 2; - let mut res = unsafe { uninitialized_vec(new_size) }; - - if new_size < PARALLEL_THRESHOLD { - for i in 0..new_size { - res[i] = mul_if_of(m[i + 
new_size] - m[i], alpha) + m[i]; - } + mul_if_of: F, +) -> Vec> { + let total_size: usize = polys.iter().map(|p| p.len()).sum(); + if total_size < PARALLEL_THRESHOLD { + polys + .iter() + .map(|poly| fold_multilinear(poly, alpha, &mul_if_of, true)) + .collect() } else { - (0..new_size) - .into_par_iter() - .with_min_len(PARALLEL_THRESHOLD) - .map(|i| mul_if_of(m[i + new_size] - m[i], alpha) + m[i]) - .collect_into_vec(&mut res); + parallel::par_map_collect(polys.len(), |i| fold_multilinear(polys[i], alpha, &mul_if_of, true)) } - res } pub fn batch_fold_multilinears_at_bit< @@ -195,18 +167,17 @@ pub fn batch_fold_multilinears_at_bit< alpha: EF, bit: usize, mul_if_of: F, -) -> Vec> { +) -> Vec> { let total_size: usize = polys.iter().map(|p| p.len()).sum(); if total_size < PARALLEL_THRESHOLD { polys .iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) + .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of, true)) .collect() } else { - polys - .par_iter() - .map(|poly| fold_multilinear_at_bit(poly, alpha, bit, &mul_if_of)) - .collect() + parallel::par_map_collect(polys.len(), |i| { + fold_multilinear_at_bit(polys[i], alpha, bit, &mul_if_of, true) + }) } } @@ -223,35 +194,6 @@ pub unsafe fn uninitialized_vec(len: usize) -> Vec { } } -pub fn split_at_many<'a, A>(slice: &'a [A], indices: &[usize]) -> Vec<&'a [A]> { - for i in 0..indices.len() { - if i > 0 { - assert!(indices[i] > indices[i - 1]); - } - assert!(indices[i] <= slice.len()); - } - - if indices.is_empty() { - return vec![slice]; - } - - let mut result = Vec::with_capacity(indices.len() + 1); - let mut current_slice = slice; - let mut prev_idx = 0; - - for &idx in indices { - let adjusted_idx = idx - prev_idx; - let (left, right) = current_slice.split_at(adjusted_idx); - result.push(left); - current_slice = right; - prev_idx = idx; - } - - result.push(current_slice); - - result -} - pub fn split_at_mut_many<'a, A>(slice: &'a mut [A], indices: &[usize]) -> Vec<&'a mut [A]> { 
for i in 0..indices.len() { if i > 0 { @@ -281,63 +223,8 @@ pub fn split_at_mut_many<'a, A>(slice: &'a mut [A], indices: &[usize]) -> Vec<&' result } -// Parallel - -#[allow(clippy::type_complexity)] -pub fn par_iter_split_4<'a, A: Sync + Send>( - u: &'a [A], -) -> Zip, Iter<'a, A>>, Zip, Iter<'a, A>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - let [u_ll, u_lr, u_rl, u_rr] = split_at_many(u, &[n / 4, n / 2, 3 * n / 4]).try_into().ok().unwrap(); - (u_ll.par_iter().zip(u_lr)).zip(u_rl.par_iter().zip(u_rr.par_iter())) -} - -pub fn par_iter_split_2<'a, A: Sync + Send>(u: &'a [A]) -> Zip, Iter<'a, A>> { - par_iter_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_split_2_capped<'a, A: Sync + Send>(u: &'a [A], range: Range) -> Zip, Iter<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at(n / 2); - u_left[range.clone()].par_iter().zip(u_right[range.clone()].par_iter()) -} - -pub fn par_iter_mut_split_2<'a, A: Sync + Send>(u: &'a mut [A]) -> Zip, IterMut<'a, A>> { - par_iter_mut_split_2_capped(u, 0..u.len() / 2) -} - -pub fn par_iter_mut_split_2_capped<'a, A: Sync + Send>( - u: &'a mut [A], - range: Range, -) -> Zip, IterMut<'a, A>> { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at_mut(n / 2); - u_left[range.clone()].par_iter_mut().zip(u_right[range].par_iter_mut()) -} - -#[allow(clippy::type_complexity)] -pub fn par_zip_fold_2<'a, 'b, A: Sync + Send, B: Sync + Send>( - u: &'a [A], - folded: &'b mut [B], -) -> Zip, Iter<'a, A>>, Zip, Iter<'a, A>>>, Zip, IterMut<'b, B>>> { - let n = u.len(); - assert!(n.is_multiple_of(4)); - assert_eq!(folded.len(), n / 2); - par_iter_split_4(u).zip(par_iter_mut_split_2(folded)) -} - // Sequential -pub fn iter_split_2(u: &[A]) -> impl Iterator { - let n = u.len(); - assert!(n.is_multiple_of(2)); - let (u_left, u_right) = u.split_at(n / 2); - u_left.iter().zip(u_right.iter()) -} - pub fn iter_split_4(u: &[A]) -> impl Iterator { let n = 
u.len(); assert!(n.is_multiple_of(4)); @@ -365,18 +252,6 @@ pub fn zip_fold_2<'a, 'b, A, B>( iter_split_4(u).zip(iter_mut_split_2(folded)) } -pub fn transmute_array(input: [A; N]) -> [A; M] { - assert_eq!(N, M, "Array sizes must match"); - - unsafe { - // Prevent input from being dropped - let input = ManuallyDrop::new(input); - - // Read the array as a pointer and cast to the output type - std::ptr::read(&*input as *const [A; N] as *const [A; M]) - } -} - pub fn to_big_endian_bits(value: usize, bit_count: usize) -> Vec { (0..bit_count).rev().map(|i| (value >> i) & 1 == 1).collect() } @@ -394,12 +269,6 @@ pub fn to_little_endian_bits(value: usize, bit_count: usize) -> Vec { res } -pub fn to_little_endian_in_field(value: usize, bit_count: usize) -> Vec { - let mut res = to_big_endian_in_field::(value, bit_count); - res.reverse(); - res -} - #[cfg(test)] mod bench_tests { use std::time::{Duration, Instant}; @@ -465,9 +334,9 @@ mod bench_tests { for &log_n in &LOG_SIZES { let n = 1usize << log_n; let ext_vec: Vec = (0..n).map(|_| rng.random()).collect(); - let packed = pack_extension(&ext_vec); - let _ = unpack_extension::(&packed); // warmup - let (avg, min_t, max_t) = measure(|| unpack_extension::(&packed)); + let packed: Vec<_> = pack_extension(&ext_vec); + let _ = unpack_extension::>(&packed); // warmup + let (avg, min_t, max_t) = measure(|| unpack_extension::>(&packed)); print_row(log_n, n, avg, min_t, max_t); } } @@ -479,8 +348,8 @@ mod bench_tests { for &log_n in &LOG_SIZES { let n = 1usize << log_n; let ext_vec: Vec = (0..n).map(|_| rng.random()).collect(); - let _ = pack_extension::(&ext_vec); // warmup - let (avg, min_t, max_t) = measure(|| pack_extension::(&ext_vec)); + let _ = pack_extension::>(&ext_vec); // warmup + let (avg, min_t, max_t) = measure(|| pack_extension::>(&ext_vec)); print_row(log_n, n, avg, min_t, max_t); } } diff --git a/crates/backend/src/lib.rs b/crates/backend/src/lib.rs index cbd44fb2b..fea3c62ee 100644 --- a/crates/backend/src/lib.rs 
+++ b/crates/backend/src/lib.rs @@ -2,10 +2,10 @@ pub use air::*; pub use fiat_shamir::*; pub use field::*; pub use koala_bear::*; +pub use parallel; pub use poly::*; -pub use rayon; -pub use rayon::prelude::*; pub use sumcheck::*; pub use symetric::*; pub use utils::*; pub use whir::*; +pub use zk_alloc::*; diff --git a/crates/backend/sumcheck/Cargo.toml b/crates/backend/sumcheck/Cargo.toml index 91085f352..51f5d7dd7 100644 --- a/crates/backend/sumcheck/Cargo.toml +++ b/crates/backend/sumcheck/Cargo.toml @@ -9,7 +9,8 @@ air = { path = "../air", package = "mt-air" } poly = { path = "../poly", package = "mt-poly" } fiat-shamir = { path = "../fiat-shamir", package = "mt-fiat-shamir" } tracing.workspace = true -rayon.workspace = true +parallel.workspace = true +zk-alloc.workspace = true [dev-dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/sumcheck/src/product_computation.rs b/crates/backend/sumcheck/src/product_computation.rs index 2828af039..c069e7519 100644 --- a/crates/backend/sumcheck/src/product_computation.rs +++ b/crates/backend/sumcheck/src/product_computation.rs @@ -1,8 +1,8 @@ use fiat_shamir::*; use field::*; use poly::*; -use rayon::prelude::*; use tracing::instrument; +use zk_alloc::ArenaVec; use crate::{SumcheckComputation, sumcheck_prove_many_rounds}; @@ -146,15 +146,13 @@ pub fn compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - pol_0[..n / 2] - .par_iter() - .zip(pol_0[n / 2..].par_iter()) - .zip(pol_1[..n / 2].par_iter().zip(pol_1[n / 2..].par_iter())) - .map(sumcheck_quadratic) - .reduce( - || (EFPacking::ZERO, EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + let half = n / 2; + parallel::map_reduce( + half, + || (EFPacking::ZERO, EFPacking::ZERO), + |i| sumcheck_quadratic(((&pol_0[i], &pol_0[half + i]), (&pol_1[i], &pol_1[half + i]))), + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); @@ -164,77 +162,6 
@@ pub fn compute_product_sumcheck_polynomial< DensePolynomial::new(vec![c0, c1, c2]) } -// using delayed modular reduction -pub fn compute_product_sumcheck_polynomial_base_ext_packed< - const DIM: usize, - F: PrimeField32, - PF: PackedField, - EFP: BasedVectorSpace + Copy + Send + Sync, - EF: Field + BasedVectorSpace, ->( - pol_0: &[PF], - pol_1: &[EFP], - sum: EF, -) -> DensePolynomial { - assert_eq!(DIM, EF::DIMENSION); - let n = pol_0.len(); - assert_eq!(n, pol_1.len()); - assert!(n.is_power_of_two()); - let half = n / 2; - - type Acc = ([u128; D], [i128; D]); - - let chunk_size = 1024; - - let (c0_acc, c2_acc) = pol_0[..half] - .par_chunks(chunk_size) - .zip(pol_0[half..].par_chunks(chunk_size)) - .zip( - pol_1[..half] - .par_chunks(chunk_size) - .zip(pol_1[half..].par_chunks(chunk_size)), - ) - .map(|((b_lo, b_hi), (e_lo, e_hi))| { - let mut c0 = [0u128; DIM]; - let mut c2 = [0i128; DIM]; - for i in 0..b_lo.len() { - let x0_lanes = b_lo[i].as_slice(); - let x1_lanes = b_hi[i].as_slice(); - let y0_coords = e_lo[i].as_basis_coefficients_slice(); - let y1_coords = e_hi[i].as_basis_coefficients_slice(); - for j in 0..DIM { - let y0_j = y0_coords[j].as_slice(); - let y1_j = y1_coords[j].as_slice(); - for lane in 0..PF::WIDTH { - let x0 = x0_lanes[lane].to_unique_u32() as u64; - let y0 = y0_j[lane].to_unique_u32(); - let y1 = y1_j[lane].to_unique_u32(); - c0[j] += (y0 as u64 * x0) as u128; - c2[j] += (y1 as i64 - y0 as i64) as i128 - * (x1_lanes[lane].to_unique_u32() as i64 - x0 as i64) as i128; - } - } - } - (c0, c2) - }) - .reduce( - || ([0u128; DIM], [0i128; DIM]), - |(mut a0, mut a2): Acc, (b0, b2): Acc| { - for j in 0..DIM { - a0[j] += b0[j]; - a2[j] += b2[j]; - } - (a0, a2) - }, - ); - - let c0 = EF::from_basis_coefficients_fn(|j| F::reduce_product_sum(c0_acc[j])); - let c2 = EF::from_basis_coefficients_fn(|j| F::reduce_signed_product_sum(c2_acc[j])); - let c1 = sum - c0.double() - c2; - - DensePolynomial::new(vec![c0, c1, c2]) -} - pub fn 
fold_and_compute_product_sumcheck_polynomial< F: PrimeCharacteristicRing + Copy + Send + Sync + 'static, EF: Field, @@ -245,14 +172,14 @@ pub fn fold_and_compute_product_sumcheck_polynomial< prev_folding_factor: EF, sum: EF, decompose: impl Fn(EFPacking) -> Vec, -) -> (DensePolynomial, Vec>) { +) -> (DensePolynomial, Vec>) { let n = pol_0.len(); assert_eq!(n, pol_1.len()); assert!(n.is_power_of_two()); let prev_folding_factor_packed = EFPacking::from(prev_folding_factor); - let mut pol_0_folded = unsafe { uninitialized_vec::(n / 2) }; - let mut pol_1_folded = unsafe { uninitialized_vec::(n / 2) }; + let mut pol_0_folded = unsafe { ArenaVec::::uninitialized(n / 2) }; + let mut pol_1_folded = unsafe { ArenaVec::::uninitialized(n / 2) }; #[allow(clippy::type_complexity)] let process_element = |(p0_prev, p0_f): (((&F, &F), (&F, &F)), (&mut EFPacking, &mut EFPacking)), @@ -283,13 +210,33 @@ pub fn fold_and_compute_product_sumcheck_polynomial< (a0 + b0, a2 + b2) }) } else { - par_zip_fold_2(pol_0, &mut pol_0_folded) - .zip(par_zip_fold_2(pol_1, &mut pol_1_folded)) - .map(|(p0, p1)| process_element(p0, p1)) - .reduce( - || (EFPacking::ZERO, EFPacking::ZERO), - |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), - ) + let quarter = n / 4; + let p0f = parallel::SendPtr(pol_0_folded.as_mut_ptr()); + let p1f = parallel::SendPtr(pol_1_folded.as_mut_ptr()); + parallel::map_reduce( + quarter, + || (EFPacking::ZERO, EFPacking::ZERO), + |i| { + let diff_0 = pol_0[2 * quarter + i] - pol_0[i]; + let diff_1 = pol_0[3 * quarter + i] - pol_0[quarter + i]; + let x_0 = prev_folding_factor_packed * diff_0 + pol_0[i]; + let x_1 = prev_folding_factor_packed * diff_1 + pol_0[quarter + i]; + + let y_0 = prev_folding_factor_packed * (pol_1[2 * quarter + i] - pol_1[i]) + pol_1[i]; + let y_1 = + prev_folding_factor_packed * (pol_1[3 * quarter + i] - pol_1[quarter + i]) + pol_1[quarter + i]; + + unsafe { + *p0f.add(i) = x_0; + *p0f.add(quarter + i) = x_1; + *p1f.add(i) = y_0; + *p1f.add(quarter + i) = y_1; 
+ } + + sumcheck_quadratic(((&x_0, &x_1), (&y_0, &y_1))) + }, + |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2), + ) }; let c0 = decompose(c0_packed).into_iter().sum::(); diff --git a/crates/backend/sumcheck/src/prove.rs b/crates/backend/sumcheck/src/prove.rs index 5fc50fada..5e0bdd5c1 100644 --- a/crates/backend/sumcheck/src/prove.rs +++ b/crates/backend/sumcheck/src/prove.rs @@ -6,33 +6,6 @@ use poly::*; use crate::*; -#[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove<'a, EF, SC, M: Into>>( - multilinears_f: M, - computation: &SC, - extra_data: &SC::ExtraData, - eq_factor: Option>, - prover_state: &mut impl FSProver, - sum: EF, - store_intermediate_foldings: bool, -) -> (MultilinearPoint, Vec, EF) -where - EF: ExtensionField>, - SC: SumcheckComputation + 'static, - SC::ExtraData: AlphaPowers, -{ - sumcheck_fold_and_prove( - multilinears_f, - None, - computation, - extra_data, - eq_factor, - prover_state, - sum, - store_intermediate_foldings, - ) -} - #[allow(clippy::too_many_arguments)] pub fn sumcheck_fold_and_prove<'a, EF, SC, M: Into>>( multilinears_f: M, diff --git a/crates/backend/sumcheck/src/sc_computation.rs b/crates/backend/sumcheck/src/sc_computation.rs index f5f84e470..57b3ffd9f 100644 --- a/crates/backend/sumcheck/src/sc_computation.rs +++ b/crates/backend/sumcheck/src/sc_computation.rs @@ -2,9 +2,16 @@ use crate::*; use air::*; use field::*; use poly::*; -use rayon::prelude::*; use std::any::TypeId; use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; +use zk_alloc::ArenaVec; + +fn add_assign_vec(mut a: Vec, b: Vec) -> Vec { + for (x, y) in a.iter_mut().zip(b) { + *x += y; + } + a +} pub trait SumcheckComputation>>: Sync { type ExtraData: Send + Sync + 'static; @@ -58,39 +65,12 @@ where } } -fn parallel_sum(size: usize, n: usize, compute_iteration: F) -> Vec -where - T: PrimeCharacteristicRing + Send + Sync, - F: Fn(usize) -> Vec + Sync + Send, -{ - let accumulate = |mut acc: Vec, sums: Vec| { - for (j, sum) in sums.into_iter().enumerate() { - 
acc[j] += sum; - } - acc - }; - - if size < PARALLEL_THRESHOLD { - (0..size).fold(T::zero_vec(n), |acc, i| accumulate(acc, compute_iteration(i))) - } else { - (0..size) - .into_par_iter() - .map(compute_iteration) - .reduce(|| T::zero_vec(n), accumulate) - } -} - fn build_evals>>( sums: impl IntoIterator, missing_mul_factor: Option, ) -> Vec { sums.into_iter() - .map(|mut sum| { - if let Some(factor) = missing_mul_factor { - sum *= factor; - } - sum - }) + .map(|sum| missing_mul_factor.map_or(sum, |f| sum * f)) .collect() } @@ -210,7 +190,7 @@ where extra_data, missing_mul_factor, packed_fold_size, - |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + |sc, pf, ed| sc.eval_packed_extension(pf, ed), packing_unpack_sum, ) } @@ -222,7 +202,7 @@ where extra_data, missing_mul_factor, packed_fold_size, - |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + |sc, pf, ed| sc.eval_packed_extension(pf, ed), packing_unpack_sum, ), MleGroupRef::BasePacked(multilinears) => { @@ -240,7 +220,7 @@ where extra_data, missing_mul_factor, packed_fold_size, - |sc, pf, ed| sc.eval_packed_base(&pf, ed), + |sc, pf, ed| sc.eval_packed_base(pf, ed), packing_unpack_sum, ) } @@ -252,7 +232,7 @@ where extra_data, missing_mul_factor, fold_size, - |sc, pf, ed| sc.eval_base(&pf, ed), + |sc, pf, ed| sc.eval_base(pf, ed), |s| s, ), MleGroupRef::Extension(multilinears) => sumcheck_compute_core( @@ -263,7 +243,7 @@ where extra_data, missing_mul_factor, fold_size, - |sc, pf, ed| sc.eval_extension(&pf, ed), + |sc, pf, ed| sc.eval_extension(pf, ed), |s| s, ), } @@ -316,7 +296,7 @@ where missing_mul_factor, compute_fold_size, |m, id| (m[id + prev_folded_size] - m[id]) * prev_folding_factor + m[id], - |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + |sc, pf, ed| sc.eval_packed_extension(pf, ed), packing_unpack_sum, MleGroupOwned::ExtensionPacked, ) @@ -332,7 +312,7 @@ where missing_mul_factor, compute_fold_size, |m, id| (m[id + prev_folded_size] - m[id]) * prev_folding_factor + m[id], - |sc, pf, ed| 
sc.eval_packed_extension(&pf, ed), + |sc, pf, ed| sc.eval_packed_extension(pf, ed), packing_unpack_sum, MleGroupOwned::ExtensionPacked, ) @@ -355,7 +335,7 @@ where missing_mul_factor, compute_fold_size, |m, id| prev_folding_factor_packed * (m[id + prev_folded_size] - m[id]) + m[id], - |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + |sc, pf, ed| sc.eval_packed_extension(pf, ed), packing_unpack_sum, MleGroupOwned::ExtensionPacked, ) @@ -371,7 +351,7 @@ where missing_mul_factor, compute_fold_size, |m, id| prev_folding_factor * (m[id + prev_folded_size] - m[id]) + m[id], - |sc, pf, ed| sc.eval_extension(&pf, ed), + |sc, pf, ed| sc.eval_extension(pf, ed), |s| s, MleGroupOwned::Extension, ) @@ -387,7 +367,7 @@ where missing_mul_factor, compute_fold_size, |m, id| (m[id + prev_folded_size] - m[id]) * prev_folding_factor + m[id], - |sc, pf, ed| sc.eval_extension(&pf, ed), + |sc, pf, ed| sc.eval_extension(pf, ed), |s| s, MleGroupOwned::Extension, ) @@ -404,7 +384,7 @@ fn sumcheck_compute_core( extra_data: &SC::ExtraData, missing_mul_factor: Option, fold_size: usize, - eval_fn: impl Fn(&SC, Vec, &SC::ExtraData) -> EFT + Sync + Send, + eval_fn: impl Fn(&SC, &[IF], &SC::ExtraData) -> EFT + Sync + Send, unpack_sum: impl Fn(EFT) -> EF, ) -> Vec where @@ -420,43 +400,46 @@ where + MulAssign, SC: SumcheckComputation, { - let compute_at = |i: usize, eq_val: Option| -> Vec { - let mut rows = multilinears - .iter() - .map(|m| { + let n_mult = multilinears.len(); + let sums = parallel::map_reduce_with_state( + fold_size, + || (Vec::<[IF; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || EFT::zero_vec(degree), + |(rows, point), acc, i| { + let eq_val = eq_at(i); + + rows.clear(); + rows.extend(multilinears.iter().map(|m| { let lo = m[i]; let hi = m[i + fold_size]; - let diff_hi_lo = hi - lo; - [lo, diff_hi_lo, hi] - }) - .collect::>(); - - // z = 0 - let point_0 = rows.iter().map(|row| row[0]).collect::>(); - let mut eval_0 = eval_fn(computation, point_0, extra_data); - 
if let Some(eq) = eq_val { - eval_0 *= eq; - } + [lo, hi - lo, hi] + })); - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); - - // z = 2, 3, ... - for _ in 1..degree { - for [_, diff_hi_lo, running] in &mut rows { - *running += *diff_hi_lo; - } - let point_f = rows.iter().map(|row| row[2]).collect::>(); - let mut eval = eval_fn(computation, point_f, extra_data); + // z = 0 + point.clear(); + point.extend(rows.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_val { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum(fold_size, degree, |i| compute_at(i, eq_at(i))); + // z = 2, 3, ... + for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_val { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, + ); let unpacked_sums = sums.into_iter().map(&unpack_sum); build_evals(unpacked_sums, missing_mul_factor) } @@ -472,9 +455,9 @@ fn sumcheck_fold_and_compute_core( missing_mul_factor: Option, compute_fold_size: usize, fold_f: impl Fn(&[IF], usize) -> FT + Sync + Send, - eval_fn: impl Fn(&SC, Vec, &SC::ExtraData) -> FT + Sync + Send, + eval_fn: impl Fn(&SC, &[FT], &SC::ExtraData) -> FT + Sync + Send, unpack_sum: impl Fn(FT) -> EF, - wrap_f: impl FnOnce(Vec>) -> MleGroupOwned, + wrap_f: impl FnOnce(Vec>) -> MleGroupOwned, ) -> (Vec, MleGroupOwned) where EF: ExtensionField>, @@ -484,17 +467,20 @@ where { let prev_folded_size = 2 * compute_fold_size; - let folded_f: Vec> = (0..multilinears.len()) - .map(|_| FT::zero_vec(prev_folded_size)) + let folded_f: Vec> = (0..multilinears.len()) + .map(|_| unsafe { ArenaVec::::zeroed(prev_folded_size) }) .collect(); - let compute_iteration = |i: usize| -> Vec { - let eq_mle_eval = 
eq_at(i); + let n_mult = multilinears.len(); + let sums = parallel::map_reduce_with_state( + compute_fold_size, + || (Vec::<[FT; 3]>::with_capacity(n_mult), Vec::::with_capacity(n_mult)), + || FT::zero_vec(degree), + |(rows_f, point), acc, i| { + let eq_mle_eval = eq_at(i); - let mut rows_f: Vec<[FT; 3]> = multilinears - .iter() - .enumerate() - .map(|(j, m)| { + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { let lo = fold_f(m, i); let hi = fold_f(m, i + compute_fold_size); unsafe { @@ -502,37 +488,34 @@ where *ptr.add(i) = lo; *ptr.add(i + compute_fold_size) = hi; } - let diff_hi_lo = hi - lo; - [lo, diff_hi_lo, hi] - }) - .collect(); - - // z = 0 - let point_0 = rows_f.iter().map(|row| row[0]).collect::>(); - let mut eval_0 = eval_fn(computation, point_0, extra_data); - if let Some(eq) = eq_mle_eval { - eval_0 *= eq; - } - - let mut evals = Vec::with_capacity(degree); - evals.push(eval_0); + [lo, hi - lo, hi] + })); - // z = 2, 3, ... - for _ in 1..degree { - for [_, diff_hi_lo, running] in &mut rows_f { - *running += *diff_hi_lo; - } - let point_f = rows_f.iter().map(|row| row[2]).collect::>(); - let mut eval = eval_fn(computation, point_f, extra_data); + // z = 0 + point.clear(); + point.extend(rows_f.iter().map(|row| row[0])); + let mut eval_0 = eval_fn(computation, point, extra_data); if let Some(eq) = eq_mle_eval { - eval *= eq; + eval_0 *= eq; } - evals.push(eval); - } - evals - }; + acc[0] += eval_0; - let sums = parallel_sum(compute_fold_size, degree, compute_iteration); + // z = 2, 3, ... 
+ for acc_d in acc.iter_mut().skip(1) { + for [_, diff_hi_lo, running] in rows_f.iter_mut() { + *running += *diff_hi_lo; + } + point.clear(); + point.extend(rows_f.iter().map(|row| row[2])); + let mut eval = eval_fn(computation, point, extra_data); + if let Some(eq) = eq_mle_eval { + eval *= eq; + } + *acc_d += eval; + } + }, + add_assign_vec, + ); let unpacked_sums = sums.into_iter().map(&unpack_sum); (build_evals(unpacked_sums, missing_mul_factor), wrap_f(folded_f)) } @@ -546,7 +529,7 @@ fn sumcheck_compute_with_split_eq( extra_data: &SC::ExtraData, missing_mul_factor: Option, fold_size: usize, - eval_fn: impl Fn(&SC, Vec>, &SC::ExtraData) -> EFPacking + Sync + Send, + eval_fn: impl Fn(&SC, &[EFPacking], &SC::ExtraData) -> EFPacking + Sync + Send, unpack_sum: impl Fn(EFPacking) -> EF, ) -> Vec where @@ -559,57 +542,57 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map(|b_lo| { + let n_mult = multilinears.len(); + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows, point, block_acc), acc, b_lo| { let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); let base = b_lo << log_packed_hi; - let mut block_acc = zero(); + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); for k in 0..packed_hi { let i = base + k; let eq_val = eq_hi[k]; - let mut rows = multilinears - .iter() - .map(|m| { - let lo = m[i]; - let hi = m[i + fold_size]; - let diff = hi - lo; - [lo, diff, hi] - }) - .collect::>(); + rows.clear(); + rows.extend(multilinears.iter().map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + [lo, hi - lo, hi] + })); // z = 0 
- let p0 = rows.iter().map(|r| r[0]).collect(); - let mut e0 = eval_fn(computation, p0, extra_data); + point.clear(); + point.extend(rows.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); e0 *= eq_val; block_acc[0] += e0; // z = 2, 3, ... for d in 1..degree { - for [_, diff, running] in &mut rows { + for [_, diff, running] in rows.iter_mut() { *running += *diff; } - let pf = rows.iter().map(|r| r[2]).collect(); - let mut ev = eval_fn(computation, pf, extra_data); + point.clear(); + point.extend(rows.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); ev *= eq_val; block_acc[d] += ev; } } - for a in &mut block_acc { - *a *= eq_lo_bc; + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; } - block_acc - }) - .reduce(zero, accumulate); + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); build_evals(unpacked, missing_mul_factor) @@ -626,9 +609,9 @@ fn sumcheck_fold_and_compute_with_split_eq( missing_mul_factor: Option, compute_fold_size: usize, fold_f: impl Fn(&[IF], usize) -> EFPacking + Sync + Send, - eval_fn: impl Fn(&SC, Vec>, &SC::ExtraData) -> EFPacking + Sync + Send, + eval_fn: impl Fn(&SC, &[EFPacking], &SC::ExtraData) -> EFPacking + Sync + Send, unpack_sum: impl Fn(EFPacking) -> EF, - wrap_f: impl FnOnce(Vec>>) -> MleGroupOwned, + wrap_f: impl FnOnce(Vec>>) -> MleGroupOwned, ) -> (Vec, MleGroupOwned) where EF: ExtensionField>, @@ -636,8 +619,8 @@ where SC: SumcheckComputation, { let prev_folded_size = 2 * compute_fold_size; - let folded_f: Vec>> = (0..multilinears.len()) - .map(|_| EFPacking::::zero_vec(prev_folded_size)) + let folded_f: Vec>> = (0..multilinears.len()) + .map(|_| unsafe { ArenaVec::>::zeroed(prev_folded_size) }) .collect(); let n_lo = split_eq.n_lo(); @@ -646,61 +629,62 @@ where let eq_lo = &split_eq.eq_lo; let eq_hi = &split_eq.eq_hi_packed; - let zero = || EFPacking::::zero_vec(degree); - let accumulate = |mut acc: Vec>, vals: Vec>| -> 
Vec> { - for (a, v) in acc.iter_mut().zip(vals.iter()) { - *a += *v; - } - acc - }; - - let sums: Vec> = (0..n_lo) - .into_par_iter() - .map(|b_lo| { + let n_mult = multilinears.len(); + let sums: Vec> = parallel::map_reduce_with_state( + n_lo, + || { + ( + Vec::<[EFPacking; 3]>::with_capacity(n_mult), + Vec::>::with_capacity(n_mult), + EFPacking::::zero_vec(degree), + ) + }, + || EFPacking::::zero_vec(degree), + |(rows_f, point, block_acc), acc, b_lo| { let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); let base = b_lo << log_packed_hi; - let mut block_acc = zero(); + block_acc.iter_mut().for_each(|x| *x = EFPacking::::ZERO); for k in 0..packed_hi { let i = base + k; let eq_val = eq_hi[k]; - let mut rows_f: Vec<[EFPacking; 3]> = multilinears - .iter() - .enumerate() - .map(|(j, m)| { - let lo = fold_f(m, i); - let hi = fold_f(m, i + compute_fold_size); - unsafe { - let ptr = folded_f[j].as_ptr() as *mut EFPacking; - *ptr.add(i) = lo; - *ptr.add(i + compute_fold_size) = hi; - } - let diff = hi - lo; - [lo, diff, hi] - }) - .collect(); - - let p0 = rows_f.iter().map(|r| r[0]).collect(); - let mut e0 = eval_fn(computation, p0, extra_data); + rows_f.clear(); + rows_f.extend(multilinears.iter().enumerate().map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut EFPacking; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = hi; + } + [lo, hi - lo, hi] + })); + + // z = 0 + point.clear(); + point.extend(rows_f.iter().map(|r| r[0])); + let mut e0 = eval_fn(computation, point, extra_data); e0 *= eq_val; block_acc[0] += e0; + // z = 2, 3, ... 
for d in 1..degree { - for [_, diff, running] in &mut rows_f { + for [_, diff, running] in rows_f.iter_mut() { *running += *diff; } - let pf = rows_f.iter().map(|r| r[2]).collect(); - let mut ev = eval_fn(computation, pf, extra_data); + point.clear(); + point.extend(rows_f.iter().map(|r| r[2])); + let mut ev = eval_fn(computation, point, extra_data); ev *= eq_val; block_acc[d] += ev; } } - for a in &mut block_acc { - *a *= eq_lo_bc; + for (a, b) in acc.iter_mut().zip(block_acc.iter()) { + *a += *b * eq_lo_bc; } - block_acc - }) - .reduce(zero, accumulate); + }, + add_assign_vec, + ); let unpacked = sums.into_iter().map(&unpack_sum); (build_evals(unpacked, missing_mul_factor), wrap_f(folded_f)) diff --git a/crates/backend/sumcheck/src/split_eq.rs b/crates/backend/sumcheck/src/split_eq.rs index a5ced294c..16a8908b8 100644 --- a/crates/backend/sumcheck/src/split_eq.rs +++ b/crates/backend/sumcheck/src/split_eq.rs @@ -1,13 +1,14 @@ use field::{ExtensionField, PackedFieldExtension}; use poly::*; +use zk_alloc::ArenaVec; #[derive(Debug)] pub struct SplitEq>> { - pub eq_lo: Vec, - pub eq_hi_packed: Vec>, + pub eq_lo: ArenaVec, + pub eq_hi_packed: ArenaVec>, pub log_packed_hi: u32, // = log2(eq_hi_packed.len()), cached for bit-shift in get_packed /// Unpacked remainder for when the packed table is empty or exhausted. 
- pub remainder: Vec, + pub remainder: ArenaVec, } impl>> SplitEq { @@ -16,8 +17,8 @@ impl>> SplitEq { if must_unpack_multilinears::(n + 1) { return Self { - eq_lo: vec![EF::ONE], - eq_hi_packed: Vec::new(), + eq_lo: ArenaVec::filled(EF::ONE, 1), + eq_hi_packed: ArenaVec::new(), log_packed_hi: 0, remainder: eval_eq(eq_point), }; @@ -32,7 +33,7 @@ impl>> SplitEq { eq_lo, eq_hi_packed, log_packed_hi, - remainder: Vec::new(), + remainder: ArenaVec::new(), } } @@ -53,7 +54,7 @@ impl>> SplitEq { self.log_packed_hi = new_len.trailing_zeros(); } else { // eq_hi_packed has 0 or 1 element — unpack to remainder and halve - let mut unpacked: Vec = EFPacking::::to_ext_iter(self.eq_hi_packed.iter().copied()).collect(); + let mut unpacked: ArenaVec = EFPacking::::to_ext_iter(self.eq_hi_packed.iter().copied()).collect(); let scale = self.eq_lo[0]; for v in &mut unpacked { *v *= scale; diff --git a/crates/backend/symetric/Cargo.toml b/crates/backend/symetric/Cargo.toml index 125fb5535..81fb33829 100644 --- a/crates/backend/symetric/Cargo.toml +++ b/crates/backend/symetric/Cargo.toml @@ -6,4 +6,5 @@ edition.workspace = true [dependencies] koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } field = { path = "../field", package = "mt-field" } -rayon.workspace = true +parallel.workspace = true +zk-alloc.workspace = true diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index adcf69b35..9ea0f162d 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -4,7 +4,7 @@ use std::array; use field::PackedValue; -use rayon::prelude::*; +use zk_alloc::ArenaVec; use crate::Compression; @@ -13,7 +13,7 @@ pub const DIGEST_ELEMS: usize = 8; /// A Merkle tree storing only the digest layers (no leaf data). 
#[derive(Debug, Clone)] pub struct MerkleTree { - pub digest_layers: Vec>, + pub digest_layers: Vec>, } impl MerkleTree { @@ -23,7 +23,7 @@ impl MerkleT P: PackedValue + Default, Comp: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { - let mut digest_layers = vec![first_layer]; + let mut digest_layers = vec![ArenaVec::from_slice(&first_layer)]; loop { let prev_layer = digest_layers.last().unwrap().as_slice(); if prev_layer.len() == 1 { @@ -50,7 +50,7 @@ impl MerkleT pub fn compress_layer( prev_layer: &[[P::Value; DIGEST_ELEMS]], comp: &Comp, -) -> Vec<[P::Value; DIGEST_ELEMS]> +) -> ArenaVec<[P::Value; DIGEST_ELEMS]> where P: PackedValue + Default, P::Value: Default + Copy, @@ -65,20 +65,18 @@ where let next_len = prev_layer.len() / 2; let default_digest = [P::Value::default(); DIGEST_ELEMS]; - let mut next_digests = vec![default_digest; next_len_padded]; - - next_digests[0..next_len] - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); - let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); - let packed_digest = crate::compress(comp, [left, right]); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + let mut next_digests = ArenaVec::filled(default_digest, next_len_padded); + + let exact_len = next_len / width * width; + parallel::par_chunks_mut(&mut next_digests[0..exact_len], width, |i, digests_chunk| { + let first_row = i * width; + let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j])); + let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j])); + let packed_digest = crate::compress(comp, [left, right]); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); for i in (next_len / width * width)..next_len { let left = 
prev_layer[2 * i]; diff --git a/crates/backend/symetric/src/permutation.rs b/crates/backend/symetric/src/permutation.rs index c129a1dc4..13c15c6dd 100644 --- a/crates/backend/symetric/src/permutation.rs +++ b/crates/backend/symetric/src/permutation.rs @@ -1,7 +1,7 @@ // Credits: Plonky3 (https://github.com/Plonky3/Plonky3) (MIT and Apache-2.0 licenses). use field::{Algebra, InjectiveMonomial}; -use koala_bear::{KoalaBear, Poseidon1KoalaBear16}; +use koala_bear::{KoalaBear, Poseidon1KoalaBear16, Poseidon1KoalaBear24}; pub trait Compression: Clone + Sync { #[inline(always)] @@ -20,3 +20,11 @@ impl + InjectiveMonomial<3> + Send + Sync + 'static> Compr self.compress_in_place(input); } } + +impl + InjectiveMonomial<3> + Send + Sync + 'static> Compression<[R; 24]> + for Poseidon1KoalaBear24 +{ + fn compress_mut(&self, input: &mut [R; 24]) { + self.compress_in_place(input); + } +} diff --git a/crates/backend/system-info/Cargo.toml b/crates/backend/system-info/Cargo.toml index c63ee1297..862e36e89 100644 --- a/crates/backend/system-info/Cargo.toml +++ b/crates/backend/system-info/Cargo.toml @@ -5,7 +5,6 @@ edition.workspace = true [dependencies] libc = "0.2" -rayon.workspace = true [lints] workspace = true diff --git a/crates/backend/system-info/build.rs b/crates/backend/system-info/build.rs deleted file mode 100644 index 5a6ffd48c..000000000 --- a/crates/backend/system-info/build.rs +++ /dev/null @@ -1,52 +0,0 @@ -// Assumes build host == run host, for simplicity (to be changed in the future) - -fn main() { - let cores = std::thread::available_parallelism().unwrap().get(); - let l1_cache_size = match detect_l1_cache_size() { - Some(size) => size, - None => { - eprintln!("Warning: failed to detect L1 cache size, defaulting to 32 KB"); - 32 * 1024 - } - }; - let out_dir = std::env::var_os("OUT_DIR").unwrap(); - let path = std::path::Path::new(&out_dir).join("info.rs"); - std::fs::write( - &path, - format!( - "pub const NUM_THREADS: usize = {cores};\n\ - pub const 
L1_CACHE_SIZE: usize = {l1_cache_size};\n" - ), - ) - .unwrap(); - println!("cargo:rerun-if-changed=build.rs"); -} - -#[cfg(target_os = "linux")] -fn detect_l1_cache_size() -> Option { - // /sys reports e.g. "32K\n", "48K\n", "1M\n". - let s = std::fs::read_to_string("/sys/devices/system/cpu/cpu0/cache/index0/size").ok()?; - let s = s.trim(); - let last = s.chars().last()?; - match last { - 'K' | 'k' => s[..s.len() - 1].parse::().ok().map(|n| n * 1024), - 'M' | 'm' => s[..s.len() - 1].parse::().ok().map(|n| n * 1024 * 1024), - c if c.is_ascii_digit() => s.parse().ok(), - _ => None, - } -} - -#[cfg(target_os = "macos")] -fn detect_l1_cache_size() -> Option { - // `hw.l1dcachesize` returns the E-core value on Apple Silicon; prefer the P-core size. - let read_sysctl = |key: &str| -> Option { - let out = std::process::Command::new("sysctl").args(["-n", key]).output().ok()?; - std::str::from_utf8(&out.stdout).ok()?.trim().parse().ok() - }; - read_sysctl("hw.perflevel0.l1dcachesize").or_else(|| read_sysctl("hw.l1dcachesize")) -} - -#[cfg(not(any(target_os = "linux", target_os = "macos")))] -fn detect_l1_cache_size() -> Option { - None -} diff --git a/crates/backend/system-info/src/lib.rs b/crates/backend/system-info/src/lib.rs index 07180559b..533c3cc71 100644 --- a/crates/backend/system-info/src/lib.rs +++ b/crates/backend/system-info/src/lib.rs @@ -1,7 +1,28 @@ -include!(concat!(env!("OUT_DIR"), "/info.rs")); +use std::sync::OnceLock; const _: () = assert!(usize::BITS == 64, "this project requires a 64-bit target (for now)"); +#[must_use] +pub fn num_threads() -> usize { + static CACHE: OnceLock = OnceLock::new(); + *CACHE.get_or_init(|| { + std::thread::available_parallelism() + .expect("failed to detect available parallelism") + .get() + }) +} + +#[must_use] +pub fn l1_cache_size() -> usize { + static CACHE: OnceLock = OnceLock::new(); + *CACHE.get_or_init(|| { + detect_l1_cache_size().unwrap_or_else(|| { + eprintln!("Warning: failed to detect L1 cache size, 
defaulting to 32 KB"); + 32 * 1024 + }) + }) +} + pub fn peak_rss_bytes() -> u64 { let mut ru: libc::rusage = unsafe { std::mem::zeroed() }; unsafe { libc::getrusage(libc::RUSAGE_SELF, &raw mut ru) }; @@ -10,35 +31,31 @@ pub fn peak_rss_bytes() -> u64 { if cfg!(target_os = "macos") { max } else { max * 1024 } } -/// Number of jobs [`flush_rayon`] pushes. Must exceed -/// `crossbeam_deque::deque::BLOCK_CAP` (currently 63 — -/// `crossbeam-deque-0.8.6/src/deque.rs:1191`). -const RAYON_FLUSH_JOBS: usize = 256; - -/// Drain rayon's internal queues so they release any storage allocated during the -/// previous phase. -/// -/// Rayon's global pool owns a `crossbeam_deque::Injector`, internally a linked list -/// of fixed-size blocks (`Block` and `Injector::push` — -/// `crossbeam-deque-0.8.6/src/deque.rs:1219` and `:1371`). A block is freed only -/// once its last slot has been consumed. -/// -/// `rayon::join` from a non-worker thread reaches that injector via -/// `join` (`rayon-core-1.13.0/src/join/mod.rs:132`) -> -/// `registry::in_worker` (`registry.rs:946`) -> -/// `Registry::in_worker_cold` (`:517`) -> -/// `Registry::inject` (`:428`) -> `Injector::push`. -/// -/// Under an arena allocator that recycles memory between phases (e.g. `zk-alloc`), -/// a block allocated *during* a phase points into a slab the next `begin_phase()` -/// will reuse. The next push then writes a `JobRef` straight through whatever the -/// application has placed on top, silently corrupting it. -/// -/// Pushing more than `BLOCK_CAP` jobs while the arena is off forces the Injector -/// to allocate a fresh tail block (which lands in System), and forces workers to -/// steal the last slot of every preceding block (which destroys them). -pub fn flush_rayon() { - for _ in 0..RAYON_FLUSH_JOBS { - rayon::join(|| {}, || {}); +#[cfg(target_os = "linux")] +fn detect_l1_cache_size() -> Option { + // /sys reports e.g. "32K\n", "48K\n", "1M\n". 
+ let s = std::fs::read_to_string("/sys/devices/system/cpu/cpu0/cache/index0/size").ok()?; + let s = s.trim(); + let last = s.chars().last()?; + match last { + 'K' | 'k' => s[..s.len() - 1].parse::().ok().map(|n| n * 1024), + 'M' | 'm' => s[..s.len() - 1].parse::().ok().map(|n| n * 1024 * 1024), + c if c.is_ascii_digit() => s.parse().ok(), + _ => None, } } + +#[cfg(target_os = "macos")] +fn detect_l1_cache_size() -> Option { + // `hw.l1dcachesize` returns the E-core value on Apple Silicon; prefer the P-core size. + let read_sysctl = |key: &str| -> Option { + let out = std::process::Command::new("sysctl").args(["-n", key]).output().ok()?; + std::str::from_utf8(&out.stdout).ok()?.trim().parse().ok() + }; + read_sysctl("hw.perflevel0.l1dcachesize").or_else(|| read_sysctl("hw.l1dcachesize")) +} + +#[cfg(not(any(target_os = "linux", target_os = "macos")))] +fn detect_l1_cache_size() -> Option { + None +} diff --git a/crates/backend/utils/src/lib.rs b/crates/backend/utils/src/lib.rs index d391caccc..30aeb48ef 100644 --- a/crates/backend/utils/src/lib.rs +++ b/crates/backend/utils/src/lib.rs @@ -6,6 +6,13 @@ use std::{ pub mod array_serialization; +/// Returns the last `n` elements of `slice`. +#[must_use] +pub fn from_end(slice: &[A], n: usize) -> &[A] { + assert!(n <= slice.len()); + &slice[slice.len() - n..] +} + /// Computes `log_2(n)` /// /// # Panics @@ -121,7 +128,6 @@ pub const fn indices_arr() -> [usize; N] { /// reallocations. /// /// # Safety -/// /// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`. 
#[inline] pub unsafe fn flatten_to_base(vec: Vec) -> Vec { @@ -129,27 +135,21 @@ pub unsafe fn flatten_to_base(vec: Vec) -> Vec assert!(align_of::() == align_of::()); assert!(size_of::().is_multiple_of(size_of::())); } - let d = size_of::() / size_of::(); - let mut values = std::mem::ManuallyDrop::new(vec); - let new_len = values.len() * d; - let new_cap = values.capacity() * d; - let ptr = values.as_mut_ptr() as *mut Base; - unsafe { Vec::from_raw_parts(ptr, new_len, new_cap) } + let mut me = mem::ManuallyDrop::new(vec); + unsafe { Vec::from_raw_parts(me.as_mut_ptr().cast::(), me.len() * d, me.capacity() * d) } } /// Convert a vector of `Base` elements to a vector of `BaseArray` elements. /// /// # Safety -/// /// This assumes that `BaseArray` has the same alignment and memory layout as `[Base; N]`. #[inline] -pub unsafe fn reconstitute_from_base(mut vec: Vec) -> Vec { +pub unsafe fn reconstitute_from_base(vec: Vec) -> Vec { const { assert!(align_of::() == align_of::()); assert!(size_of::().is_multiple_of(size_of::())); } - let d = size_of::() / size_of::(); assert!( vec.len().is_multiple_of(d), @@ -158,17 +158,13 @@ pub unsafe fn reconstitute_from_base(mut vec: Vec) d ); let new_len = vec.len() / d; - let cap = vec.capacity(); - - if cap.is_multiple_of(d) { - let mut values = std::mem::ManuallyDrop::new(vec); - let new_cap = cap / d; - let ptr = values.as_mut_ptr() as *mut BaseArray; - unsafe { Vec::from_raw_parts(ptr, new_len, new_cap) } + if vec.capacity().is_multiple_of(d) { + let mut me = mem::ManuallyDrop::new(vec); + unsafe { Vec::from_raw_parts(me.as_mut_ptr().cast::(), new_len, me.capacity() / d) } } else { - let buf_ptr = vec.as_mut_ptr().cast::(); - let slice_ref = unsafe { slice::from_raw_parts(buf_ptr, new_len) }; - slice_ref.to_vec() + // Capacity isn't a clean multiple: copy into a fresh buffer. 
+ let buf_ptr = vec.as_ptr().cast::(); + unsafe { slice::from_raw_parts(buf_ptr, new_len) }.to_vec() } } diff --git a/crates/backend/zk-alloc/Cargo.toml b/crates/backend/zk-alloc/Cargo.toml index fe4c12233..85c078050 100644 --- a/crates/backend/zk-alloc/Cargo.toml +++ b/crates/backend/zk-alloc/Cargo.toml @@ -6,11 +6,7 @@ description = "Bump+reset arena allocator for ZK proving workloads" [dependencies] system-info.workspace = true - -[dev-dependencies] -rayon.workspace = true - -[target.'cfg(not(all(target_os = "linux", target_arch = "x86_64")))'.dependencies] +parallel.workspace = true libc = "0.2" [lints] diff --git a/crates/backend/zk-alloc/src/arena_cow.rs b/crates/backend/zk-alloc/src/arena_cow.rs new file mode 100644 index 000000000..f35e3e845 --- /dev/null +++ b/crates/backend/zk-alloc/src/arena_cow.rs @@ -0,0 +1,73 @@ +use std::ops::Deref; + +use crate::ArenaVec; + +#[derive(Debug)] +pub enum ArenaCow<'a, T> { + Borrowed(&'a [T]), + Owned(ArenaVec), +} + +impl ArenaCow<'_, T> { + #[inline] + #[must_use] + pub fn as_slice(&self) -> &[T] { + match self { + Self::Borrowed(s) => s, + Self::Owned(v) => v, + } + } + + #[inline] + #[must_use] + pub fn len(&self) -> usize { + self.as_slice().len() + } + + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.as_slice().is_empty() + } +} + +impl ArenaCow<'_, T> { + /// Take ownership of the buffer, copying into the arena only if currently borrowed. 
+ #[inline] + #[must_use] + pub fn into_owned(self) -> ArenaVec { + match self { + Self::Borrowed(s) => ArenaVec::from_slice(s), + Self::Owned(v) => v, + } + } +} + +impl<'a, T> From<&'a [T]> for ArenaCow<'a, T> { + #[inline] + fn from(s: &'a [T]) -> Self { + Self::Borrowed(s) + } +} + +impl From> for ArenaCow<'_, T> { + #[inline] + fn from(v: ArenaVec) -> Self { + Self::Owned(v) + } +} + +impl Deref for ArenaCow<'_, T> { + type Target = [T]; + #[inline] + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl AsRef<[T]> for ArenaCow<'_, T> { + #[inline] + fn as_ref(&self) -> &[T] { + self.as_slice() + } +} diff --git a/crates/backend/zk-alloc/src/arena_vec.rs b/crates/backend/zk-alloc/src/arena_vec.rs new file mode 100644 index 000000000..df5ae7db9 --- /dev/null +++ b/crates/backend/zk-alloc/src/arena_vec.rs @@ -0,0 +1,464 @@ +//! [`ArenaVec`] — a minimal owning vector backed by the proving arena. +//! +//! Allocation goes through [`raw_alloc`](crate::raw_alloc) (arena bump in a phase, else system) and +//! `Drop`/growth through [`raw_dealloc`](crate::raw_dealloc), which picks arena-vs-system by pointer +//! range — the dynamic choice that lets `ArenaVec` carry no allocator type parameter. An `ArenaVec` +//! allocated in a phase is invalidated by the next [`begin_phase`](crate::begin_phase); anything +//! that must outlive a phase uses the system allocator (a plain `Vec`, or an `ArenaVec` built outside a +//! phase). + +use std::alloc::handle_alloc_error; +use std::cmp; +use std::fmt; +use std::marker::PhantomData; +use std::mem::{ManuallyDrop, align_of, size_of}; +use std::ops::{Deref, DerefMut}; +use std::ptr::{self, NonNull}; +use std::slice; + +use crate::{raw_alloc, raw_dealloc}; + +/// Owning, growable buffer allocated from the proving arena (see the module docs). +pub struct ArenaVec { + /// Always aligned and non-null; dangling (and never dereferenced for reads) while `cap == 0`. + ptr: NonNull, + len: usize, + /// Element capacity. 
For zero-sized `T` this is fixed at `usize::MAX` and no memory is owned. + cap: usize, + _marker: PhantomData, +} + +unsafe impl Send for ArenaVec {} +unsafe impl Sync for ArenaVec {} + +pub trait OwnedBuffer: DerefMut + Sized { + /// `len` uninitialized elements. + /// + /// # Safety + /// Every element must be written before it is read. + unsafe fn uninit(len: usize) -> Self; + + /// `len` elements, initialized in place by `fill` — which **must** write all of them. + #[inline] + fn build(len: usize, fill: impl FnOnce(&mut [T])) -> Self { + // SAFETY: `fill` writes every one of the `len` elements before any is read. + let mut buf = unsafe { Self::uninit(len) }; + fill(&mut buf); + buf + } +} + +impl OwnedBuffer for Vec { + #[inline] + #[allow(clippy::uninit_vec)] + unsafe fn uninit(len: usize) -> Self { + let mut v = Vec::with_capacity(len); + // SAFETY: the `uninit`/`build` contract requires all `len` slots written before read. + unsafe { v.set_len(len) }; + v + } +} + +impl OwnedBuffer for ArenaVec { + #[inline] + unsafe fn uninit(len: usize) -> Self { + // SAFETY: as above. + unsafe { Self::uninitialized(len) } + } +} + +impl ArenaVec { + /// `usize::MAX` capacity stands in for "unbounded" for zero-sized elements (which never + /// allocate); `0` otherwise. + const EMPTY_CAP: usize = if size_of::() == 0 { usize::MAX } else { 0 }; + + /// A new, empty vector. No allocation. + #[inline] + #[must_use] + pub const fn new() -> Self { + Self { + ptr: NonNull::dangling(), + len: 0, + cap: Self::EMPTY_CAP, + _marker: PhantomData, + } + } + + /// A new, empty vector with room for `cap` elements pre-reserved (exact, no over-allocation). + #[inline] + #[must_use] + pub fn with_capacity(cap: usize) -> Self { + let mut v = Self::new(); + if size_of::() != 0 && cap != 0 { + v.realloc_to(cap); + } + v + } + + /// Arena-backed `vec![value; n]`. 
+ #[inline] + #[must_use] + pub fn filled(value: T, n: usize) -> Self + where + T: Clone, + { + let mut v = Self::with_capacity(n); + v.resize(n, value); + v + } + + /// Arena-backed zero-initialized buffer of length `n`, zeroed with a single `write_bytes` + /// (`memset`) — far cheaper than [`filled`](Self::filled)'s element-wise clone loop. + /// + /// # Safety + /// `T`'s all-zero bit pattern must be a valid, fully-initialized value of `T` (true for the + /// Montgomery field types and their SIMD packings, whose `ZERO` is all-zero bytes). + #[inline] + #[must_use] + pub unsafe fn zeroed(n: usize) -> Self { + // SAFETY: every slot is initialized by the `write_bytes` below before it can be read. + let mut v = unsafe { Self::uninitialized(n) }; + // SAFETY: `v` owns `n` allocated slots; caller guarantees all-zero is a valid `T`. + unsafe { ptr::write_bytes(v.as_mut_ptr(), 0u8, n) }; + v + } + + /// Arena-backed `slice.to_vec()`. + #[inline] + #[must_use] + pub fn from_slice(slice: &[T]) -> Self + where + T: Clone, + { + let mut v = Self::with_capacity(slice.len()); + v.extend_from_slice(slice); + v + } + + /// `len` uninitialized slots. + /// + /// # Safety + /// Every element must be overwritten before it is read. + #[inline] + #[must_use] + pub unsafe fn uninitialized(len: usize) -> Self { + let mut v = Self::with_capacity(len); + // SAFETY: caller guarantees all `len` slots are written before being read. + unsafe { v.set_len(len) }; + v + } + + /// Arena-backed parallel `(0..n).map(f).collect()`: fill a vector of length `n` in parallel. + /// The single allocation happens on the calling thread; workers write disjoint slots. + #[inline] + #[must_use] + pub fn par_collect T + Sync>(n: usize, f: F) -> Self + where + T: Send, + { + // SAFETY: `par_fill` writes every slot in `0..n` exactly once before any is read. 
+ let mut v = unsafe { Self::uninitialized(n) }; + parallel::par_fill(&mut v, f); + v + } + + #[inline] + #[must_use] + pub const fn len(&self) -> usize { + self.len + } + + #[inline] + #[must_use] + pub const fn capacity(&self) -> usize { + self.cap + } + + #[inline] + #[must_use] + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline] + #[must_use] + pub const fn as_ptr(&self) -> *const T { + self.ptr.as_ptr() + } + + #[inline] + pub const fn as_mut_ptr(&mut self) -> *mut T { + self.ptr.as_ptr() + } + + #[inline] + #[must_use] + pub fn as_slice(&self) -> &[T] { + self + } + + #[inline] + pub fn as_mut_slice(&mut self) -> &mut [T] { + self + } + + /// Set the length without touching the buffer. + /// + /// # Safety + /// `new_len <= capacity()` and every element in `0..new_len` must be initialized. + #[inline] + pub unsafe fn set_len(&mut self, new_len: usize) { + debug_assert!(new_len <= self.cap); + self.len = new_len; + } + + /// Reserve space for at least `additional` more elements (amortized doubling). + #[inline] + pub fn reserve(&mut self, additional: usize) { + if size_of::() == 0 { + return; // capacity is conceptually unbounded for ZSTs + } + let required = self.len.checked_add(additional).expect("ArenaVec capacity overflow"); + if required > self.cap { + let new_cap = cmp::max(required, self.cap.saturating_mul(2)); + self.realloc_to(new_cap); + } + } + + #[inline] + pub fn push(&mut self, value: T) { + if self.len == self.cap { + // ZSTs never reach here (cap == usize::MAX); only sized types grow. + let new_cap = cmp::max(self.cap.saturating_mul(2), 4); + self.realloc_to(new_cap); + } + // SAFETY: `len < cap` now, so slot `len` is allocated and uninitialized. + unsafe { self.ptr.as_ptr().add(self.len).write(value) }; + self.len += 1; + } + + /// Append a clone of every element of `other`. 
+ #[inline] + pub fn extend_from_slice(&mut self, other: &[T]) + where + T: Clone, + { + self.reserve(other.len()); + // Bump `len` per element so a panic mid-clone leaves a consistent vector (written clones drop). + for x in other { + // SAFETY: `reserve` guaranteed room for `other.len()` more; `len` stays < `cap`. + unsafe { self.ptr.as_ptr().add(self.len).write(x.clone()) }; + self.len += 1; + } + } + + /// Grow or shrink to `new_len`, filling new slots with clones of `value`. + pub fn resize(&mut self, new_len: usize, value: T) + where + T: Clone, + { + if new_len > self.len { + self.reserve(new_len - self.len); + while self.len < new_len { + // SAFETY: room reserved above; `len < new_len <= cap`. + unsafe { self.ptr.as_ptr().add(self.len).write(value.clone()) }; + self.len += 1; + } + } else { + self.truncate(new_len); + } + } + + /// Drop the elements past `len`, keeping capacity. + pub fn truncate(&mut self, len: usize) { + if len < self.len { + let drop_count = self.len - len; + // Shorten first so a panicking `Drop` can't observe/double-drop the tail. + self.len = len; + // SAFETY: `[len, old_len)` were initialized and are now logically removed. + unsafe { + ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr.as_ptr().add(len), drop_count)); + } + } + } + + #[inline] + pub fn clear(&mut self) { + self.truncate(0); + } + + /// Decompose into raw parts, leaking the buffer. Inverse of [`from_raw_parts`](Self::from_raw_parts). + #[inline] + #[must_use] + pub fn into_raw_parts(self) -> (*mut T, usize, usize) { + let me = ManuallyDrop::new(self); + (me.ptr.as_ptr(), me.len, me.cap) + } + + /// Reconstruct from parts previously obtained via [`into_raw_parts`](Self::into_raw_parts) + /// (or a layout-compatible reinterpret thereof). 
+ /// + /// # Safety + /// `ptr` is non-null and aligned for `T`; `len <= cap`; and `ptr` either was returned by + /// [`raw_alloc`](crate::raw_alloc) for `cap * size_of::()` bytes at `align_of::()`, or + /// `cap == 0` and `ptr` is dangling-but-aligned. Exactly one `ArenaVec` may own a given pointer. + #[inline] + #[must_use] + pub unsafe fn from_raw_parts(ptr: *mut T, len: usize, cap: usize) -> Self { + Self { + // SAFETY: caller guarantees `ptr` is non-null. + ptr: unsafe { NonNull::new_unchecked(ptr) }, + len, + cap, + _marker: PhantomData, + } + } + + /// Allocate a fresh `new_cap`-element buffer, move the `len` live elements into it, and free + /// the old one. Only called for sized `T` with `new_cap >= len` and `new_cap > 0`. + fn realloc_to(&mut self, new_cap: usize) { + debug_assert!(size_of::() != 0 && new_cap >= self.len && new_cap > 0); + let align = align_of::(); + let new_bytes = new_cap.checked_mul(size_of::()).expect("ArenaVec capacity overflow"); + assert!(new_bytes <= isize::MAX as usize, "ArenaVec capacity overflow"); + + // SAFETY: `align` is a valid power of two; `new_bytes > 0`. + let raw = unsafe { raw_alloc(new_bytes, align) }.cast::(); + let Some(new_ptr) = NonNull::new(raw) else { + // Matches `Vec`: an allocation failure aborts rather than unwinds. + handle_alloc_error(unsafe { std::alloc::Layout::from_size_align_unchecked(new_bytes, align) }); + }; + + if self.cap != 0 { + // SAFETY: the two buffers are distinct; `len <= old cap` initialized elements move. + unsafe { ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr.as_ptr(), self.len) }; + // SAFETY: old buffer came from `raw_alloc` with this size/align (range-checked free). + unsafe { raw_dealloc(self.ptr.as_ptr().cast::(), self.cap * size_of::(), align) }; + } + self.ptr = new_ptr; + self.cap = new_cap; + } +} + +impl Drop for ArenaVec { + fn drop(&mut self) { + // Drop the live elements first (no-op for `Copy`/trivial types; the compiler elides it). 
+ if std::mem::needs_drop::() { + // SAFETY: `0..len` are initialized. + unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(self.ptr.as_ptr(), self.len)) }; + } + // Free the buffer. ZSTs and never-allocated vectors own nothing. + if size_of::() != 0 && self.cap != 0 { + // SAFETY: buffer came from `raw_alloc(cap * size, align)`; `raw_dealloc` range-checks + // arena-vs-system. Arena pointers free as a no-op (reclaimed at the next phase reset). + unsafe { + raw_dealloc( + self.ptr.as_ptr().cast::(), + self.cap * size_of::(), + align_of::(), + ) + }; + } + } +} + +impl Deref for ArenaVec { + type Target = [T]; + #[inline] + fn deref(&self) -> &[T] { + // SAFETY: `ptr` is aligned and `0..len` are initialized (valid for ZSTs too: a dangling + // aligned pointer is a valid base for a zero-byte-stride slice). + unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) } + } +} + +impl DerefMut for ArenaVec { + #[inline] + fn deref_mut(&mut self) -> &mut [T] { + // SAFETY: as `deref`, with unique access. 
+ unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) } + } +} + +impl AsRef<[T]> for ArenaVec { + #[inline] + fn as_ref(&self) -> &[T] { + self + } +} + +impl AsMut<[T]> for ArenaVec { + #[inline] + fn as_mut(&mut self) -> &mut [T] { + self + } +} + +impl Default for ArenaVec { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Clone for ArenaVec { + fn clone(&self) -> Self { + let mut out = Self::with_capacity(self.len); + out.extend_from_slice(self); + out + } +} + +impl fmt::Debug for ArenaVec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl PartialEq for ArenaVec { + #[inline] + fn eq(&self, other: &Self) -> bool { + **self == **other + } +} + +impl Eq for ArenaVec {} + +impl Extend for ArenaVec { + #[inline] + fn extend>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for x in iter { + self.push(x); + } + } +} + +impl FromIterator for ArenaVec { + #[inline] + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let mut v = Self::with_capacity(iter.size_hint().0); + v.extend(iter); + v + } +} + +impl<'a, T> IntoIterator for &'a ArenaVec { + type Item = &'a T; + type IntoIter = slice::Iter<'a, T>; + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T> IntoIterator for &'a mut ArenaVec { + type Item = &'a mut T; + type IntoIter = slice::IterMut<'a, T>; + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter_mut() + } +} diff --git a/crates/backend/zk-alloc/src/lib.rs b/crates/backend/zk-alloc/src/lib.rs index 1b43143d6..1452d1b6c 100644 --- a/crates/backend/zk-alloc/src/lib.rs +++ b/crates/backend/zk-alloc/src/lib.rs @@ -1,121 +1,127 @@ -//! Bump-pointer arena allocator. -//! -//! One mmap region split into per-thread slabs. Allocation = increment a thread-local -//! pointer; free = no-op. `begin_phase()` resets the arena: each thread's next -//! 
allocation starts over at the beginning of its slab, overwriting the previous -//! phase's data. Allocations that don't fit (too large, or beyond `MAX_THREADS`) fall -//! back to the system allocator. -//! -//! ```ignore -//! init(); // once, at process start -//! loop { -//! begin_phase(); // arena ON; slabs reset lazily -//! let res = heavy_work(); // fast increments -//! end_phase(); // arena OFF; new allocations go to System -//! let copy = res.clone(); // detach from arena before next phase resets it -//! } -//! ``` +//! Bump-pointer arena, used explicitly (never as a `#[global_allocator]`). One mmap region split +//! into per-thread slabs: alloc bumps a thread-local pointer, free is a no-op, `begin_phase()` +//! resets every slab. Proof data lives in [`ArenaVec`]; `raw_dealloc` picks arena-vs-system by +//! pointer range, so `ArenaVec` carries no allocator parameter. use std::alloc::{GlobalAlloc, Layout}; use std::cell::Cell; -use std::sync::Once; +use std::sync::OnceLock; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use system_info::NUM_THREADS; +use system_info::num_threads; +mod arena_cow; +mod arena_vec; mod syscall; -const SLAB_SIZE: usize = 8 << 30; // 8GB -const SLACK: usize = 4; // SLACK absorbs the main thread and any non-rayon helpers. -const MAX_THREADS: usize = NUM_THREADS + SLACK; -const REGION_SIZE: usize = SLAB_SIZE * MAX_THREADS; +pub use arena_cow::ArenaCow; +pub use arena_vec::{ArenaVec, OwnedBuffer}; -#[derive(Debug)] -pub struct ZkAllocator; - -/// Incremented by `begin_phase()`. Every thread caches the last value it saw in -/// `ARENA_GEN`; when they differ, the thread resets its allocation cursor to the start -/// of its slab on the next allocation. This is how a single store on the main thread -/// "resets" every other thread's slab without any cross-thread synchronization. -static GENERATION: AtomicUsize = AtomicUsize::new(0); +/// Build an [`ArenaVec`], mirroring [`std::vec!`]: +#[macro_export] +macro_rules! 
arena_vec { + () => { $crate::ArenaVec::new() }; + ($elem:expr; $n:expr) => { $crate::ArenaVec::filled($elem, $n) }; + ($($x:expr),+ $(,)?) => { $crate::ArenaVec::from_iter([$($x),+]) }; +} -/// Master switch for the arena. `true` (set by `begin_phase`) routes allocations -/// through the arena; `false` (set by `end_phase`) routes them to the system allocator. -static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false); +const SLAB_SIZE: usize = 8 << 30; // 8 GiB; per-thread soft cap, overflow falls back to System +const SLACK: usize = 4; // extra slabs for non-pool threads that allocate in a phase -/// Base address of the mmap'd region, or `0` before `ensure_region` runs. Read on -/// every `dealloc` to test whether a pointer belongs to us. -static REGION_BASE: AtomicUsize = AtomicUsize::new(0); +fn max_threads() -> usize { + num_threads() + SLACK +} -/// Synchronizes the one-time mmap so concurrent first-allocators don't race. -static REGION_INIT: Once = Once::new(); +fn region_size() -> usize { + static SIZE: OnceLock = OnceLock::new(); + *SIZE.get_or_init(|| SLAB_SIZE * max_threads()) +} -/// Monotonic counter handed out to threads to pick their slab. `fetch_add`'d once per -/// thread on its first arena allocation. Threads that get `idx >= MAX_THREADS` mark -/// themselves `ARENA_NO_SLAB` and permanently fall through to the system allocator. +/// Bumped by `begin_phase()`; a thread resets its slab when its cached `ARENA_GEN` lags — one store +/// resets every thread, lock-free. +static GENERATION: AtomicUsize = AtomicUsize::new(0); +/// Arena on (route to arena) vs off (route to System). +static ARENA_ACTIVE: AtomicBool = AtomicBool::new(false); +/// Process-wide opt-in; gates `begin_phase`'s all-thread reset so a stray call can't corrupt another +/// proving's buffers. Until [`enable_arena`], phases are no-ops and `ArenaVec` uses System. 
+static ARENA_ENGAGED: AtomicBool = AtomicBool::new(false); +/// mmap'd region base, mapped once; also the arena-vs-system discriminator in `raw_dealloc`. +static REGION: OnceLock = OnceLock::new(); +/// Slab index handed out once per thread; `idx >= MAX_THREADS` falls back to System. static THREAD_IDX: AtomicUsize = AtomicUsize::new(0); thread_local! { - /// Where this thread's next allocation lands. Advanced past each allocation. + /// This thread's next allocation address. static ARENA_PTR: Cell = const { Cell::new(0) }; - /// One past the last byte of this thread's slab. An alloc fits iff - /// `aligned + size <= ARENA_END`. + /// One past this thread's slab. static ARENA_END: Cell = const { Cell::new(0) }; - /// Base address of this thread's slab (`0` = not yet claimed). On reset, - /// `ARENA_PTR` is set back to this value. + /// This thread's slab base (`0` = unclaimed); the reset target. static ARENA_BASE: Cell = const { Cell::new(0) }; - /// Last `GENERATION` value this thread observed. When the global moves past - /// this, the next allocation resets `ARENA_PTR` to `ARENA_BASE` and updates - /// this field. + /// Last `GENERATION` seen; a mismatch triggers a slab reset. static ARENA_GEN: Cell = const { Cell::new(0) }; - /// `true` if this thread was created after `MAX_THREADS` was already exhausted. - /// Such threads skip arena logic entirely and always go to the system allocator. + /// Thread got no slab (`idx >= MAX_THREADS`) — always uses System. static ARENA_NO_SLAB: Cell = const { Cell::new(false) }; } -/// Returns the base address of the mmap'd region, mapping it on the first call. fn ensure_region() -> usize { - REGION_INIT.call_once(|| { - // SAFETY: mmap_anonymous returns a page-aligned pointer or null. MAP_NORESERVE - // means no physical memory is committed until pages are touched. 
- let ptr = unsafe { syscall::mmap_anonymous(REGION_SIZE) }; + *REGION.get_or_init(|| { + let size = region_size(); + // SAFETY: mmap returns a page-aligned pointer or null; lazily backed. + let ptr = unsafe { syscall::mmap_anonymous(size) }; if ptr.is_null() { std::process::abort(); } - unsafe { syscall::madvise(ptr, REGION_SIZE, syscall::MADV_NOHUGEPAGE) }; - REGION_BASE.store(ptr as usize, Ordering::Release); - }); - REGION_BASE.load(Ordering::Acquire) + unsafe { syscall::madvise(ptr, size, syscall::MADV_NOHUGEPAGE) }; + ptr as usize + }) } -/// Call once at process start, before any `begin_phase()`. -pub fn init() { - let actual_num_threads = std::thread::available_parallelism().unwrap().get(); - assert_eq!( - actual_num_threads, NUM_THREADS, - "built for {NUM_THREADS} threads but this machine reports {actual_num_threads} -> please rebuild`" - ); +/// Opt into the arena (once, at startup). Until then phases are inert and `ArenaVec` uses System. +pub fn enable_arena() { + #[cfg(target_os = "linux")] + unsafe { + // Disable heap trimming, so freed memory is kept rather than returned to the OS + libc::mallopt(libc::M_TRIM_THRESHOLD, -1); + // Disable mmap for large allocations, routing everything through the heap instead + libc::mallopt(libc::M_MMAP_MAX, 0); + } + ARENA_ENGAGED.store(true, Ordering::Release); } -/// Activates the arena and resets every thread's slab. All allocations until the next -/// `end_phase()` go to the arena; the previous phase's data is overwritten in place. +/// Activate the arena and reset every thread's slab (overwriting the previous phase). No-op until +/// [`enable_arena`]; phases must not nest. 
pub fn begin_phase() { + if !ARENA_ENGAGED.load(Ordering::Acquire) { + return; + } let prev_active = ARENA_ACTIVE.swap(true, Ordering::Release); - assert!( - !prev_active, - "begin_phase() called while another phase is already active — phases must not nest" - ); + assert!(!prev_active, "phases must not nest"); GENERATION.fetch_add(1, Ordering::Release); } -/// Deactivates the arena. New allocations go to the system allocator; existing arena -/// pointers stay valid until the next `begin_phase()` resets the slabs. -/// -/// Also calls [`system_info::flush_rayon`] to release any rayon/crossbeam storage -/// still referencing this phase's arena memory. +/// Deactivate the arena; existing arena pointers stay valid until the next `begin_phase()`. pub fn end_phase() { + if !ARENA_ENGAGED.load(Ordering::Acquire) { + return; + } ARENA_ACTIVE.store(false, Ordering::Release); - system_info::flush_rayon(); +} + +/// Guard that [`end_phase`]s on drop. +#[derive(Debug)] +pub struct PhaseGuard(()); + +impl Drop for PhaseGuard { + fn drop(&mut self) { + end_phase(); + } +} + +/// [`begin_phase`] + an RAII guard that [`end_phase`]s on drop (incl. early return / panic). +#[must_use = "the phase ends the moment the guard is dropped"] +pub fn enter_phase() -> PhaseGuard { + begin_phase(); + PhaseGuard(()) } #[cold] @@ -127,7 +133,7 @@ unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 { if base == 0 { let region = ensure_region(); let idx = THREAD_IDX.fetch_add(1, Ordering::Relaxed); - if idx >= MAX_THREADS { + if idx >= max_threads() { ARENA_NO_SLAB.set(true); return unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) }; } @@ -147,52 +153,41 @@ unsafe fn arena_alloc_cold(size: usize, align: usize) -> *mut u8 { unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) } } -// SAFETY: All pointers returned are either from our mmap'd region (valid, aligned, -// non-overlapping per thread) or from System. 
The arena is thread-local so no data -// races. Relaxed ordering on ARENA_ACTIVE/GENERATION is sound: worst case a thread -// sees a stale value and does one extra system-alloc before picking up the new -// generation on the next call. -unsafe impl GlobalAlloc for ZkAllocator { - #[inline(always)] - unsafe fn alloc(&self, layout: Layout) -> *mut u8 { - if ARENA_ACTIVE.load(Ordering::Relaxed) { - let generation = GENERATION.load(Ordering::Relaxed); - if ARENA_GEN.get() == generation { - let align = layout.align(); - let aligned = (ARENA_PTR.get() + align - 1) & !(align - 1); - let new_ptr = aligned + layout.size(); - if new_ptr <= ARENA_END.get() { - ARENA_PTR.set(new_ptr); - return aligned as *mut u8; - } +/// [`ArenaVec`]'s allocator: bump the thread's slab in an active phase, else System. The cursor is +/// thread-local, so the Relaxed reads can't race — a stale read just costs one extra System alloc. +/// +/// # Safety +/// `align` is a power of two; the result is valid for `size` bytes (or null on System failure) until +/// the next `begin_phase()`. 
+#[inline(always)] +pub(crate) unsafe fn raw_alloc(size: usize, align: usize) -> *mut u8 { + if ARENA_ACTIVE.load(Ordering::Relaxed) { + let generation = GENERATION.load(Ordering::Relaxed); + if ARENA_GEN.get() == generation { + let aligned = (ARENA_PTR.get() + align - 1) & !(align - 1); + let new_ptr = aligned + size; + if new_ptr <= ARENA_END.get() { + ARENA_PTR.set(new_ptr); + return aligned as *mut u8; } - return unsafe { arena_alloc_cold(layout.size(), layout.align()) }; - } - unsafe { std::alloc::System.alloc(layout) } - } - - #[inline(always)] - unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - let addr = ptr as usize; - let base = REGION_BASE.load(Ordering::Relaxed); - if base != 0 && addr >= base && addr < base + REGION_SIZE { - return; // arena-owned pointer — free is a no-op } - unsafe { std::alloc::System.dealloc(ptr, layout) }; + return unsafe { arena_alloc_cold(size, align) }; } + unsafe { std::alloc::System.alloc(Layout::from_size_align_unchecked(size, align)) } +} - #[inline(always)] - unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { - if new_size <= layout.size() { - return ptr; - } - // SAFETY: new_size > layout.size() > 0, align unchanged from valid layout. - let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) }; - let new_ptr = unsafe { self.alloc(new_layout) }; - if !new_ptr.is_null() { - unsafe { std::ptr::copy(ptr, new_ptr, layout.size()) }; - unsafe { self.dealloc(ptr, layout) }; - } - new_ptr +/// Free for [`raw_alloc`]: no-op for arena pointers (reclaimed at the next `begin_phase()`), else System. +/// +/// # Safety +/// `ptr` came from [`raw_alloc`] with this `size`/`align`. 
+#[inline(always)] +pub(crate) unsafe fn raw_dealloc(ptr: *mut u8, size: usize, align: usize) { + let addr = ptr as usize; + if REGION + .get() + .is_some_and(|&base| addr >= base && addr < base + region_size()) + { + return; // arena pointer — free is a no-op } + unsafe { std::alloc::System.dealloc(ptr, Layout::from_size_align_unchecked(size, align)) }; } diff --git a/crates/backend/zk-alloc/src/syscall.rs b/crates/backend/zk-alloc/src/syscall.rs index 13d71531d..9b62f5e1a 100644 --- a/crates/backend/zk-alloc/src/syscall.rs +++ b/crates/backend/zk-alloc/src/syscall.rs @@ -1,163 +1,50 @@ -// Raw syscalls instead of libc wrappers to avoid reentrancy: libc's mmap/madvise -// may internally call malloc, which would deadlock when called from inside -// #[global_allocator]. - -#[cfg(all(target_os = "linux", target_arch = "x86_64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 9; - const SYS_MADVISE: usize = 28; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - in("r10") a4, - in("r8") a5, - in("r9") a6, - lateout("rcx") _, - lateout("r11") _, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "syscall", - inlateout("rax") nr as isize => ret, - in("rdi") a1, - in("rsi") a2, - in("rdx") a3, - lateout("rcx") _, - lateout("r11") _, - lateout("r10") _, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | 
MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; +//! Anonymous `mmap` + `madvise` via `libc`. +//! +//! (Raw inline-asm syscalls when zk-alloc was a `#[global_allocator]`, to avoid `libc` re-entering +//! `malloc`. It no longer is, so `libc` is safe: its internal allocations hit the system allocator, +//! not this arena.) + +use std::ptr; + +/// `madvise` advice: disable transparent huge pages for the region. Consulted only on Linux +/// (a no-op elsewhere); see [`madvise`]. +pub const MADV_NOHUGEPAGE: usize = 15; + +/// Reserve `size` bytes of anonymous virtual address space, lazily backed by physical pages. +/// +/// # Safety +/// Always safe to call; returns a page-aligned pointer or null on failure, and the caller owns the +/// resulting mapping. +#[inline] +pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { + let flags = libc::MAP_PRIVATE | libc::MAP_ANON; + // MAP_NORESERVE (Linux) keeps the huge sparse reservation from committing swap up front; macOS + // backs anonymous mappings lazily without it. + #[cfg(target_os = "linux")] + let flags = flags | libc::MAP_NORESERVE; + // SAFETY: a null `addr` lets the kernel pick the placement; `fd` is -1 for an anonymous map. 
+ let ret = unsafe { libc::mmap(ptr::null_mut(), size, libc::PROT_READ | libc::PROT_WRITE, flags, -1, 0) }; + if ret == libc::MAP_FAILED { + ptr::null_mut() + } else { + ret.cast::() } } -#[cfg(all(target_os = "linux", target_arch = "aarch64"))] -mod imp { - use std::ptr; - - const SYS_MMAP: usize = 222; - const SYS_MADVISE: usize = 233; - - const PROT_READ: usize = 1; - const PROT_WRITE: usize = 2; - const MAP_PRIVATE: usize = 0x02; - const MAP_ANONYMOUS: usize = 0x20; - const MAP_NORESERVE: usize = 0x4000; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - unsafe fn syscall6(nr: usize, a1: usize, a2: usize, a3: usize, a4: usize, a5: usize, a6: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - in("x3") a4, - in("x4") a5, - in("x5") a6, - options(nostack), - ); - } - ret - } - - #[inline] - unsafe fn syscall3(nr: usize, a1: usize, a2: usize, a3: usize) -> isize { - let ret: isize; - unsafe { - std::arch::asm!( - "svc 0", - in("x8") nr, - inlateout("x0") a1 as isize => ret, - in("x1") a2, - in("x2") a3, - options(nostack), - ); - } - ret - } - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - let flags = MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE; - let ret = unsafe { syscall6(SYS_MMAP, 0, size, PROT_READ | PROT_WRITE, flags, usize::MAX, 0) }; - if ret < 0 { ptr::null_mut() } else { ret as *mut u8 } - } - - #[inline] - pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { - unsafe { syscall3(SYS_MADVISE, ptr as usize, size, advice) }; - } -} - -#[cfg(not(all(target_os = "linux", any(target_arch = "x86_64", target_arch = "aarch64"))))] -mod imp { - use std::ptr; - - pub const MADV_NOHUGEPAGE: usize = 15; - - #[inline] - pub unsafe fn mmap_anonymous(size: usize) -> *mut u8 { - // MAP_NORESERVE is Linux-only. 
macOS lazily backs anonymous mappings - // with physical memory by default, so the large virtual reservation - // is fine without NORESERVE. - let prot = libc::PROT_READ | libc::PROT_WRITE; - let flags = libc::MAP_PRIVATE | libc::MAP_ANON; - let ret = unsafe { libc::mmap(ptr::null_mut(), size, prot, flags, -1, 0) }; - if ret == libc::MAP_FAILED { - ptr::null_mut() - } else { - ret.cast::() - } - } - - #[inline] - pub unsafe fn madvise(_ptr: *mut u8, _size: usize, _advice: usize) { - // The advice values we pass are Linux-specific. +/// Apply `advice` to `[ptr, ptr + size)`. No-op on non-Linux (the advice values we use are +/// Linux-specific). +/// +/// # Safety +/// `ptr`/`size` must describe a live mapping returned by [`mmap_anonymous`]. +#[inline] +pub unsafe fn madvise(ptr: *mut u8, size: usize, advice: usize) { + #[cfg(target_os = "linux")] + unsafe { + // SAFETY: the caller guarantees `[ptr, ptr + size)` is a live mapping. + libc::madvise(ptr.cast::(), size, advice as libc::c_int); + } + #[cfg(not(target_os = "linux"))] + { + let _ = (ptr, size, advice); } } - -pub use imp::{MADV_NOHUGEPAGE, madvise, mmap_anonymous}; diff --git a/crates/backend/zk-alloc/tests/test_alloc.rs b/crates/backend/zk-alloc/tests/test_alloc.rs new file mode 100644 index 000000000..e78583d9e --- /dev/null +++ b/crates/backend/zk-alloc/tests/test_alloc.rs @@ -0,0 +1,49 @@ +//! `ArenaVec` drives the arena explicitly, with the process keeping its **own** allocator (no +//! `#[global_allocator]` is installed here). Only `ArenaVec`-backed buffers touch the arena; +//! everything else is untouched by a phase reset — the property that lets a library use the +//! arena without forcing its allocator on consumers. 
+ +use zk_alloc::{ArenaVec, begin_phase, enable_arena, end_phase}; + +const N: usize = 4096; + +#[test] +fn arena_vec_without_global_allocator() { + // Opt into the arena: without this, begin_phase/end_phase are inert and ArenaVec would + // transparently use the system allocator (no slab reuse to observe). + enable_arena(); + + // Phase 1: one arena allocation on this (main) thread → claims the slab at its base. + begin_phase(); + let mut v: ArenaVec = ArenaVec::with_capacity(N); + v.resize(N, 0xABCD); // fits the reservation: no realloc, pointer stays put + let p1 = v.as_ptr() as usize; + end_phase(); + + // Arena is off: this lands in the system allocator and must survive the next reset. + let canary = vec![0xAB_u8; 8192]; + + // Phase 2: the slab is reset, so an identically-shaped buffer reuses the same address. + begin_phase(); + let mut w: ArenaVec = ArenaVec::with_capacity(N); + w.resize(N, 0x1234); + let p2 = w.as_ptr() as usize; + end_phase(); + + assert_eq!( + p1, p2, + "phase reset should recycle the slab — ArenaVec must hit the arena" + ); + assert!( + canary.iter().all(|&b| b == 0xAB), + "a system allocation was corrupted by the arena reset" + ); + + // Outside any phase, ArenaVec transparently uses the system allocator (no panic). + let mut off: ArenaVec = ArenaVec::new(); + off.extend(0..1000); + assert_eq!(off.iter().sum::(), (0..1000).sum()); + + drop(v); + drop(w); +} diff --git a/crates/backend/zk-alloc/tests/test_rayon.rs b/crates/backend/zk-alloc/tests/test_rayon.rs deleted file mode 100644 index ae084af21..000000000 --- a/crates/backend/zk-alloc/tests/test_rayon.rs +++ /dev/null @@ -1,26 +0,0 @@ -//! Regression test for the bug prevented by `system_info::flush_rayon`. 
- -use rayon::prelude::*; - -#[global_allocator] -static A: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; - -#[test] -fn rayon_does_not_corrupt_zkalloc() { - zk_alloc::init(); - let _: u64 = (0..1_000_000_u64).into_par_iter().sum(); - - zk_alloc::begin_phase(); - for _ in 0..200 { - rayon::join(|| {}, || {}); - } - zk_alloc::end_phase(); - - zk_alloc::begin_phase(); - let canary = vec![0xAB_u8; 8192]; - rayon::join(|| {}, || {}); - zk_alloc::end_phase(); - - let pos = canary.iter().position(|&b| b != 0xAB); - assert!(pos.is_none(), "canary corrupted at offset {}", pos.unwrap()); -} diff --git a/crates/lean_compiler/Cargo.toml b/crates/lean_compiler/Cargo.toml index f6394cb09..f32242844 100644 --- a/crates/lean_compiler/Cargo.toml +++ b/crates/lean_compiler/Cargo.toml @@ -10,7 +10,6 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -xmss.workspace = true rand.workspace = true tracing.workspace = true diff --git a/crates/lean_compiler/snark_lib.py b/crates/lean_compiler/snark_lib.py index bec599832..f430e1443 100644 --- a/crates/lean_compiler/snark_lib.py +++ b/crates/lean_compiler/snark_lib.py @@ -73,6 +73,16 @@ def poseidon16_permute(left, right, output): _ = left, right, output +def poseidon24_compress_0_9(left, right, output): + _ = left, right, output + +def poseidon24_permute_0_9(left, right, output): + _ = left, right, output + +def poseidon24_permute_9_18(left, right, output): + _ = left, right, output + + def add_be(a, b, result, length=None): _ = a, b, result, length diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index b36419ff5..a5480e558 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -3,7 +3,7 @@ use backend::PrimeCharacteristicRing; use lean_vm::{ ALL_POSEIDON16_NAMES, Boolean, BooleanExpr, CustomHint, ExtensionOpMode, FunctionName, POSEIDON16_HALF_HARDCODED_LEFT_NAME, 
POSEIDON16_HALF_NAME, POSEIDON16_HARDCODED_LEFT_NAME, POSEIDON16_PERMUTE_NAME, - PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, + Poseidon24Mode, PrecompileArgs, PrecompileCompTimeArgs, SourceLocation, }; use std::{ collections::{BTreeMap, BTreeSet}, @@ -1852,6 +1852,38 @@ fn simplify_lines( continue; } + // Special handling for poseidon24 precompile (3 variants). + let p24_mode = match function_name.as_str() { + "poseidon24_compress_0_9" => Some(Poseidon24Mode::Compress0_9), + "poseidon24_permute_0_9" => Some(Poseidon24Mode::Permute0_9), + "poseidon24_permute_9_18" => Some(Poseidon24Mode::Permute9_18), + _ => None, + }; + if let Some(mode) = p24_mode { + if !targets.is_empty() { + return Err(format!( + "Precompile {function_name} should not return values, at {location}" + )); + } + if args.len() != 3 { + return Err(format!( + "Precompile {function_name} expects 3 arguments (ptr_a, ptr_b, ptr_res), got {}, at {location}", + args.len() + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Precompile(PrecompileArgs { + arg_0: simplified_args[0].clone(), + arg_1: simplified_args[1].clone(), + res: simplified_args[2].clone(), + data: PrecompileCompTimeArgs::Poseidon24(mode), + })); + continue; + } + // Special handling for poseidon16 precompile (5 variants). 
if ALL_POSEIDON16_NAMES.contains(&function_name.as_str()) { if !targets.is_empty() { diff --git a/crates/lean_compiler/src/c_compile_final.rs b/crates/lean_compiler/src/c_compile_final.rs index 1a3397d27..70bbba866 100644 --- a/crates/lean_compiler/src/c_compile_final.rs +++ b/crates/lean_compiler/src/c_compile_final.rs @@ -132,7 +132,8 @@ pub fn compile_to_low_level_bytecode( validate_instruction(instruction)?; } - let instructions_encoded = instructions.par_iter().map(field_representation).collect::>(); + let instructions_encoded = + parallel::par_map_collect(instructions.len(), |i| field_representation(&instructions[i])); let mut instructions_multilinear = vec![]; for instr in &instructions_encoded { diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index df6e22b17..f7b597355 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -61,6 +61,9 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { + POSEIDON_FLAG_LEFT_SHIFT * flag_left + POSEIDON_OFFSET_LEFT_SHIFT * offset_left_val } + PrecompileCompTimeArgs::Poseidon24(mode) => { + POSEIDON_24_DOMAINSEP_BASE + POSEIDON_24_DOMAINSEP_STEP * mode.as_usize() + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/lib.rs b/crates/lean_compiler/src/lib.rs index 3564cfaa6..5cdd70828 100644 --- a/crates/lean_compiler/src/lib.rs +++ b/crates/lean_compiler/src/lib.rs @@ -11,7 +11,7 @@ use crate::{ mod a_simplify_lang; mod b_compile_intermediate; mod c_compile_final; -mod instruction_encoder; +pub mod instruction_encoder; pub mod ir; mod lang; mod parser; diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 0334f5dc8..e8b3542e3 100644 --- 
a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{ALL_POSEIDON16_NAMES, CUSTOM_HINTS, ExtensionOpMode}; +use lean_vm::{CUSTOM_HINTS, ExtensionOpMode}; /// Reserved function names that users cannot define. pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -26,6 +26,10 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ "range", "parallel_range", "match_range", + "poseidon16_compress", + "poseidon24_compress_0_9", + "poseidon24_permute_0_9", + "poseidon24_permute_9_18", ]; /// Check if a function name is reserved. @@ -34,9 +38,6 @@ fn is_reserved_function_name(name: &str) -> bool { if RESERVED_FUNCTION_NAMES.contains(&name) || CUSTOM_HINTS.iter().any(|hint| hint.name() == name) { return true; } - if ALL_POSEIDON16_NAMES.contains(&name) { - return true; - } if ExtensionOpMode::from_name(name).is_some() { return true; } diff --git a/crates/lean_compiler/zkDSL.md b/crates/lean_compiler/zkDSL.md index 5cd4f92c9..fac3c3330 100644 --- a/crates/lean_compiler/zkDSL.md +++ b/crates/lean_compiler/zkDSL.md @@ -639,8 +639,8 @@ The full list: Precompiles are special instructions in the leanVM ISA, alongside the four basic ones (ADD, MUL, DEREF, JUMP). The zkDSL exposes them as built-in -functions. There are two families: Poseidon hashing and extension-field -operations. +functions. There are three families: Poseidon16 hashing, Poseidon24 hashing, +and extension-field operations. ### Poseidon16 family @@ -666,6 +666,17 @@ buffer; `off` (where present) is a compile-time address. | `poseidon16_compress_half_hardcoded_left(L, R, O, off)` | `O[0..4]` | half-output + hardcoded-left composition | | `poseidon16_permute(L, R, O)` | `O[0..16]` | raw Poseidon permutation, no feed-forward | +### Poseidon24 family + +Width-24 Poseidon. `L` is a 9-cell buffer (capacity), `R` a 15-cell buffer +(rate), `O` the 9-cell output buffer. 
+ +| Function | Cells written to `O` | Notes | +| ----------------------------------- | -------------------- | ------------------------------------------- | +| `poseidon24_compress_0_9(L, R, O)` | `O[0..9]` | `(Poseidon(L \|\| R) + (L \|\| R))[0..9]` | +| `poseidon24_permute_0_9(L, R, O)` | `O[0..9]` | raw permutation, output cells `0..9` | +| `poseidon24_permute_9_18(L, R, O)` | `O[0..9]` | raw permutation, output cells `9..18` | + ### Extension-field operations Six built-in functions, each reading two length-`n` vectors `a` and `b` and diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index 2163ed200..2bc77408e 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -13,9 +13,7 @@ prox-gaps-conjecture = [] pest.workspace = true pest_derive.workspace = true utils.workspace = true -xmss.workspace = true rand.workspace = true - tracing.workspace = true sub_protocols.workspace = true lean_vm.workspace = true @@ -25,6 +23,5 @@ itertools.workspace = true serde.workspace = true [dev-dependencies] -xmss.workspace = true rec_aggregation.workspace = true serde_json.workspace = true diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index dfcf68803..30ae3ab7e 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -28,7 +28,7 @@ pub const WHIR_SUBSEQUENT_FOLDING_FACTOR: usize = 5; pub const RS_DOMAIN_INITIAL_REDUCTION_FACTOR: usize = 5; pub const SNARK_DOMAIN_SEP: [F; 8] = F::new_array([ - 130704175, 1303721200, 493664240, 1035493700, 2063844858, 1410214009, 1938905908, 1696767928, + 1046873597, 587403661, 1441000407, 1547181303, 1522249642, 1883305763, 367566943, 2033638717, ]); pub fn fiat_shamir_domain_sep(bytecode: &Bytecode) -> [F; 8] { diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index aaf50be3b..9fbc5161a 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -1,6 
+1,7 @@ use std::collections::BTreeMap; use crate::*; +use backend::ArenaVec; use lean_vm::*; use serde::{Deserialize, Serialize}; @@ -81,7 +82,7 @@ pub fn prove_execution( tracing::info!("Trace tables sizes: {}", table_log.magenta()); // TODO parrallelize - let mut memory_acc = F::zero_vec(memory.len()); + let mut memory_acc = unsafe { ArenaVec::::zeroed(memory.len()) }; info_span!("Building memory access count").in_scope(|| -> Result<(), ProverError> { for (table, trace) in &traces { let buses = table.bus_interactions(); @@ -101,7 +102,7 @@ pub fn prove_execution( })?; // // TODO parrallelize - let mut bytecode_acc = F::zero_vec(bytecode.padded_size()); + let mut bytecode_acc = unsafe { ArenaVec::::zeroed(bytecode.padded_size()) }; info_span!("Building bytecode access count").in_scope(|| -> Result<(), ProverError> { for pc in traces[&Table::execution()].columns[EXEC_COL_PC].iter() { *bytecode_acc.get_mut(pc.to_usize()).ok_or(RunnerError::PCOutOfBounds)? += F::ONE; @@ -160,13 +161,13 @@ pub fn prove_execution( .map(|table| { traces[table].columns[..table.n_columns()] .iter() - .map(Vec::as_slice) + .map(|c| c.as_slice()) .collect() }) .collect(); let _span = info_span!("Computing shifted columns for AIR sumcheck").entered(); - let shifted_rows: Vec>> = ALL_TABLES - .par_iter() + let shifted_rows: Vec>> = ALL_TABLES + .iter() .zip(&column_refs) .map(|(table, cols)| compute_shifted_columns(table.n_shift_columns(), cols)) .collect(); @@ -192,10 +193,10 @@ pub fn prove_execution( let eq_suffix = from_end(gkr_point, log_n_rows).to_vec(); let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); - let extra_data = ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), alpha_slice); + let extra_data = ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice); let mut flat_and_shift: Vec<&[PF]> = column_refs[idx].to_vec(); - flat_and_shift.extend(shifted_rows[idx].iter().map(Vec::as_slice)); + 
flat_and_shift.extend(shifted_rows[idx].iter().map(|c| c.as_slice())); let packed = MleGroupRef::::Base(flat_and_shift).pack(); let non_padded = traces[table].non_padded_n_rows; diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 5bee4a97a..9eebd5b1c 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -1,12 +1,13 @@ use backend::*; use lean_vm::*; use std::{array, collections::BTreeMap}; -use utils::{ToUsize, get_poseidon_16_of_zero, transposed_par_iter_mut}; +use tracing::info_span; +use utils::{ToUsize, get_poseidon_16_of_zero, get_poseidon_24_of_zero, transposed_par_for_each_mut}; #[derive(Debug)] pub struct ExecutionTrace { pub traces: BTreeMap, - pub memory: Vec, // of length a multiple of public_memory_size + pub memory: ArenaVec, // of length a multiple of public_memory_size pub metadata: ExecutionMetadata, } @@ -19,88 +20,94 @@ pub fn get_execution_trace( let n_cycles = execution_result.pcs.len(); let memory = &execution_result.memory; - let mut main_trace: [Vec; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = - array::from_fn(|_| F::zero_vec(n_cycles.next_power_of_two())); + let mut main_trace: [ArenaVec; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS] = + array::from_fn(|_| unsafe { ArenaVec::::zeroed(n_cycles.next_power_of_two()) }); for col in &mut main_trace { unsafe { col.set_len(n_cycles); } } - transposed_par_iter_mut(&mut main_trace) - .zip(execution_result.pcs.par_iter()) - .zip(execution_result.fps.par_iter()) - .for_each(|((trace_row, &pc), &fp)| { - let instruction = &bytecode.code[pc].instruction; - let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
- [..N_INSTRUCTION_COLUMNS]; - - let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; - let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; - let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; - let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; - let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; - let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; - let is_deref = aux_1 == F::TWO; - - let mut addr_a = F::ZERO; - if flag_a.is_zero() && flag_ab_fp.is_zero() { - addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; - } - let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); - - let mut addr_b = F::ZERO; - if flag_b.is_zero() && flag_ab_fp.is_zero() { - addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } else if is_deref { - // DEREF: addr_B = value_A + operand_B - addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; - } - let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); + transposed_par_for_each_mut(&mut main_trace, |i, trace_row| { + let pc = execution_result.pcs[i]; + let fp = execution_result.fps[i]; + let instruction = &bytecode.code[pc].instruction; + let field_repr = &bytecode.instructions_multilinear[pc * N_INSTRUCTION_COLUMNS.next_power_of_two()..] 
+ [..N_INSTRUCTION_COLUMNS]; + + let flag_a = field_repr[instr_idx(EXEC_COL_FLAG_A)]; + let flag_b = field_repr[instr_idx(EXEC_COL_FLAG_B)]; + let flag_c = field_repr[instr_idx(EXEC_COL_FLAG_C)]; + let flag_c_fp = field_repr[instr_idx(EXEC_COL_FLAG_C_FP)]; + let flag_ab_fp = field_repr[instr_idx(EXEC_COL_FLAG_AB_FP)]; + let aux_1 = field_repr[instr_idx(EXEC_COL_AUX_1)]; + let is_deref = aux_1 == F::TWO; + + let mut addr_a = F::ZERO; + if flag_a.is_zero() && flag_ab_fp.is_zero() { + addr_a = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]; + } + let value_a = memory.0.get(addr_a.to_usize()).copied().flatten().unwrap_or_default(); + + let mut addr_b = F::ZERO; + if flag_b.is_zero() && flag_ab_fp.is_zero() { + addr_b = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } else if is_deref { + // DEREF: addr_B = value_A + operand_B + addr_b = value_a + field_repr[instr_idx(EXEC_COL_OPERAND_B)]; + } + let value_b = memory.0.get(addr_b.to_usize()).copied().flatten().unwrap_or_default(); - let mut addr_c = F::ZERO; - if flag_c.is_zero() && flag_c_fp.is_zero() { - addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; - } - let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); + let mut addr_c = F::ZERO; + if flag_c.is_zero() && flag_c_fp.is_zero() { + addr_c = F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]; + } + let value_c = memory.0.get(addr_c.to_usize()).copied().flatten().unwrap_or_default(); - for (j, field) in field_repr.iter().enumerate() { - *trace_row[j + N_RUNTIME_COLUMNS] = *field; - } + for (j, field) in field_repr.iter().enumerate() { + *trace_row[j + N_RUNTIME_COLUMNS] = *field; + } - let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] - + (F::ONE - flag_a - flag_ab_fp) * value_a - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); - let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] - + (F::ONE - flag_b - flag_ab_fp) 
* value_b - + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); - let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] - + (F::ONE - flag_c - flag_c_fp) * value_c - + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); - if let Instruction::Precompile(..) = instruction { - *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; - } - *trace_row[EXEC_COL_NU_A] = nu_a; - *trace_row[EXEC_COL_NU_B] = nu_b; - *trace_row[EXEC_COL_NU_C] = nu_c; - - *trace_row[EXEC_COL_VALUE_A] = value_a; - *trace_row[EXEC_COL_VALUE_B] = value_b; - *trace_row[EXEC_COL_VALUE_C] = value_c; - *trace_row[EXEC_COL_PC] = F::from_usize(pc); - *trace_row[EXEC_COL_FP] = F::from_usize(fp); - *trace_row[EXEC_COL_ADDR_A] = addr_a; - *trace_row[EXEC_COL_ADDR_B] = addr_b; - *trace_row[EXEC_COL_ADDR_C] = addr_c; - }); + let nu_a = flag_a * field_repr[instr_idx(EXEC_COL_OPERAND_A)] + + (F::ONE - flag_a - flag_ab_fp) * value_a + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_A)]); + let nu_b = flag_b * field_repr[instr_idx(EXEC_COL_OPERAND_B)] + + (F::ONE - flag_b - flag_ab_fp) * value_b + + flag_ab_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_B)]); + let nu_c = flag_c * field_repr[instr_idx(EXEC_COL_OPERAND_C)] + + (F::ONE - flag_c - flag_c_fp) * value_c + + flag_c_fp * (F::from_usize(fp) + field_repr[instr_idx(EXEC_COL_OPERAND_C)]); + if let Instruction::Precompile(..) 
= instruction { + *trace_row[EXEC_COL_FLAG_PRECOMPILE] = F::ONE; + } + *trace_row[EXEC_COL_NU_A] = nu_a; + *trace_row[EXEC_COL_NU_B] = nu_b; + *trace_row[EXEC_COL_NU_C] = nu_c; + + *trace_row[EXEC_COL_VALUE_A] = value_a; + *trace_row[EXEC_COL_VALUE_B] = value_b; + *trace_row[EXEC_COL_VALUE_C] = value_c; + *trace_row[EXEC_COL_PC] = F::from_usize(pc); + *trace_row[EXEC_COL_FP] = F::from_usize(fp); + *trace_row[EXEC_COL_ADDR_A] = addr_a; + *trace_row[EXEC_COL_ADDR_B] = addr_b; + *trace_row[EXEC_COL_ADDR_C] = addr_c; + }); - let mut memory_padded = memory.0.par_iter().map(|&v| v.unwrap_or(F::ZERO)).collect::>(); + let mut memory_padded: ArenaVec = ArenaVec::par_collect(memory.0.len(), |i| memory.0[i].unwrap_or(F::ZERO)); - // Write [0000000000000000 | poseidon_compress(0000000000000000)] (to make lookups work on padding-rows). + // Write [0000000000000000 | poseidon16_compress(0000000000000000) | poseidon24_compress(000000000000000000000000)] (to make lookups work on padding-rows). let padding_zero_vec_ptr = memory_padded.len(); memory_padded.extend(std::iter::repeat_n(F::ZERO, 16)); let null_poseidon_16_hash_ptr = memory_padded.len(); memory_padded.extend_from_slice(get_poseidon_16_of_zero()); + // poseidon16 padding rows read DIGEST_LEN*2 = 16 cells from `null_poseidon_16_hash_ptr` (the + // compression output lookup spans OUT_LO|OUT_HI), with OUT_HI = 0. Pad the null-16 region to 16 + // cells of [hash(8) | 0(8)] so that read matches memory; otherwise the following null-24 hash + // would occupy those cells and break the Logup balance on padding rows. 
+ memory_padded.extend(std::iter::repeat_n(F::ZERO, 8)); + let null_poseidon_24_hash_ptr = memory_padded.len(); + memory_padded.extend_from_slice(get_poseidon_24_of_zero()); // IMPORTANT: memory size should always be >= number of VM cycles let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); @@ -120,25 +127,24 @@ pub fn get_execution_trace( let permute_col = &left[POSEIDON_COL_FLAG_PERMUTE]; let nu_c_col = &left[POSEIDON_COL_NU_C]; const N: usize = HALF_DIGEST_LEN + DIGEST_LEN; - let cols: &mut [Vec; N] = (&mut right[..N]).try_into().unwrap(); - - transposed_par_iter_mut(cols) - .zip(flag_short_col) - .zip(permute_col) - .zip(nu_c_col) - .for_each(|(((row, &flag_short), &permute), &nu_c)| { - if permute == F::ZERO { - let base = nu_c.to_usize(); - if flag_short == F::ONE { - for j in 0..HALF_DIGEST_LEN { - *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; - } - } - for j in 0..DIGEST_LEN { - *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j]; + let cols: &mut [ArenaVec; N] = (&mut right[..N]).try_into().unwrap(); + + transposed_par_for_each_mut(cols, |i, row| { + let flag_short = flag_short_col[i]; + let permute = permute_col[i]; + let nu_c = nu_c_col[i]; + if permute == F::ZERO { + let base = nu_c.to_usize(); + if flag_short == F::ONE { + for j in 0..HALF_DIGEST_LEN { + *row[j] = memory_padded[base + HALF_DIGEST_LEN + j]; } } - }); + for j in 0..DIGEST_LEN { + *row[HALF_DIGEST_LEN + j] = memory_padded[base + DIGEST_LEN + j]; + } + } + }); } let extension_op_trace = traces.get_mut(&Table::extension_op()).unwrap(); @@ -163,11 +169,33 @@ pub fn get_execution_trace( &mut traces, padding_zero_vec_ptr, null_poseidon_16_hash_ptr, + null_poseidon_24_hash_ptr, bytecode.ending_pc, floor, ); } + // Ensure poseidon24 is always the smallest (last) table by padding other tables if needed. + // The recursive aggregation verifier assumes this ordering. 
+ let p24_log = traces[&Table::poseidon24()].log_n_rows; + for &table in &[Table::extension_op(), Table::poseidon16()] { + if traces[&table].log_n_rows < p24_log { + let target = 1usize << p24_log; + let trace = traces.get_mut(&table).unwrap(); + let padding = table.padding_row(padding_zero_vec_ptr, null_poseidon_16_hash_ptr, bytecode.ending_pc); + for (col, val) in trace.columns.iter_mut().zip(padding.iter()) { + col.resize(target, *val); + } + trace.log_n_rows = p24_log; + } + } + + // Fill AIR trace columns (intermediate round states + outputs). + // poseidon16 is filled earlier (before padding) together with its output override. + info_span!("Poseidon AIR trace fill").in_scope(|| { + fill_trace_poseidon_24(&mut traces.get_mut(&Table::poseidon24()).unwrap().columns); + }); + ExecutionTrace { traces, memory: memory_padded, @@ -180,22 +208,26 @@ fn pad_table( traces: &mut BTreeMap, zero_vec_ptr: usize, null_poseidon_16_hash_ptr: usize, + null_poseidon_24_hash_ptr: usize, ending_pc: usize, min_log_n_rows: usize, ) { let trace = traces.get_mut(table).unwrap(); let h = trace.columns[0].len(); - trace - .columns - .iter() - .enumerate() - .for_each(|(i, col)| assert_eq!(col.len(), h, "column {}, table {}", i, table.name())); trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(min_log_n_rows); let n_rows = 1 << trace.log_n_rows; - let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr, ending_pc); - trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { + // Each table interprets the null-hash argument it needs; poseidon24 uses the width-24 + // null hash, all others use the width-16 one (or ignore it). `ending_pc` is used by the + // execution table only. 
+ let null_hash_ptr = if *table == Table::poseidon24() { + null_poseidon_24_hash_ptr + } else { + null_poseidon_16_hash_ptr + }; + let padding_row = table.padding_row(zero_vec_ptr, null_hash_ptr, ending_pc); + parallel::par_for_each_mut(&mut trace.columns, |i, col| { assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); }); diff --git a/crates/lean_prover/src/verify_execution.rs b/crates/lean_prover/src/verify_execution.rs index f49ae8176..80899a5e1 100644 --- a/crates/lean_prover/src/verify_execution.rs +++ b/crates/lean_prover/src/verify_execution.rs @@ -126,7 +126,7 @@ pub fn verify_execution( let alpha_slice = air_alpha_powers[alpha_offset..alpha_offset + n_constraints].to_vec(); verify_data.push(TableVerifyData { table, - extra_data: ExtraDataForBuses::new(logup_alphas_eq_poly.clone(), alpha_slice), + extra_data: ExtraDataForBuses::new(&logup_alphas_eq_poly, alpha_slice), }); alpha_offset += n_constraints; diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 2b5d1832e..63e4058ca 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -10,8 +10,9 @@ workspace = true pest.workspace = true pest_derive.workspace = true utils.workspace = true -xmss.workspace = true rand.workspace = true +leansig_wrapper.workspace = true tracing.workspace = true backend.workspace = true -itertools.workspace = true \ No newline at end of file +itertools.workspace = true +serde.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 8c5baa9fb..ceaebe7a5 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -23,10 +23,11 @@ pub const MAX_BYTECODE_LOG_SIZE: usize = 22; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: 
usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. -pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ - (Table::execution(), 24), - (Table::extension_op(), 21), - (Table::poseidon16(), 21), +pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 4] = [ + (Table::execution(), 25), + (Table::extension_op(), 20), + (Table::poseidon16(), 19), + (Table::poseidon24(), 19), ]; pub fn max_log_n_rows_per_table(table: &Table) -> usize { diff --git a/crates/lean_vm/src/core/label.rs b/crates/lean_vm/src/core/label.rs index 7d2190dee..ee8d03d8f 100644 --- a/crates/lean_vm/src/core/label.rs +++ b/crates/lean_vm/src/core/label.rs @@ -1,7 +1,7 @@ use crate::SourceLocation; /// Structured label for bytecode locations -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] pub enum Label { /// Function entry point: @function_{name} Function(String), @@ -26,7 +26,7 @@ pub enum Label { Custom(String), } -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] pub enum IfKind { /// @if_{id} If, @@ -36,7 +36,7 @@ pub enum IfKind { End, } -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] pub enum AuxKind { /// @aux_var_{id} AuxVar, diff --git a/crates/lean_vm/src/core/types.rs b/crates/lean_vm/src/core/types.rs index fbad9af10..f95fb6e16 100644 --- a/crates/lean_vm/src/core/types.rs +++ b/crates/lean_vm/src/core/types.rs @@ -27,7 +27,7 @@ pub type FunctionName = String; pub type FileId = usize; /// Location in source code -#[derive(Hash, PartialEq, Eq, Debug, Clone, Copy)] +#[derive(Hash, 
PartialEq, Eq, Debug, Clone, Copy, serde::Serialize, serde::Deserialize)] pub struct SourceLocation { pub file_id: FileId, pub line_number: SourceLineNumber, diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index 2024fa083..82ffed943 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use backend::pretty_integer; +use backend::{ArenaVec, pretty_integer}; use crate::execution::Memory; use crate::{Table, TableTrace}; @@ -73,8 +73,8 @@ impl ExecutionMetadata { pub struct ExecutionResult { pub runtime_memory_size: usize, pub memory: Memory, - pub pcs: Vec, - pub fps: Vec, + pub pcs: ArenaVec, + pub fps: ArenaVec, pub traces: BTreeMap, pub metadata: ExecutionMetadata, } diff --git a/crates/lean_vm/src/execution/memory.rs b/crates/lean_vm/src/execution/memory.rs index 364eb7471..53cc895c0 100644 --- a/crates/lean_vm/src/execution/memory.rs +++ b/crates/lean_vm/src/execution/memory.rs @@ -77,7 +77,9 @@ impl MemoryAccess for Memory { impl Memory { pub fn new(public_memory: Vec) -> Self { - Self(public_memory.into_par_iter().map(Some).collect()) + Self(parallel::par_map_collect(public_memory.len(), |i| { + Some(public_memory[i]) + })) } pub fn get(&self, index: usize) -> Result { diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index f00e04880..cc64fb0dd 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -96,7 +96,7 @@ impl Trace { for (table, other_t) in other.tables { let mine = self.tables.get_mut(&table).unwrap(); for (col, new_data) in mine.columns.iter_mut().zip(other_t.columns) { - col.extend(new_data); + col.extend_from_slice(&new_data); } } } @@ -333,7 +333,12 @@ fn execute_bytecode_helper( None }; let runtime_memory_size = memory.0.len() - PUBLIC_INPUT_LEN - witness.preamble_memory_len; - let used_memory_cells 
= memory.0.par_iter().filter(|&&x| x.is_some()).count(); + let used_memory_cells = parallel::map_reduce( + memory.0.len(), + || 0usize, + |i| usize::from(memory.0[i].is_some()), + |a, b| a + b, + ); let metadata = ExecutionMetadata { cycles: trace.pcs.len(), memory: memory.0.len(), @@ -349,8 +354,8 @@ fn execute_bytecode_helper( Ok(ExecutionResult { runtime_memory_size: no_vec_runtime_memory, memory, - pcs: trace.pcs, - fps: trace.fps, + pcs: ArenaVec::from_slice(&trace.pcs), + fps: ArenaVec::from_slice(&trace.fps), traces: trace.tables, metadata, }) @@ -432,15 +437,16 @@ fn handle_parallel_batch( let split_at = batch.batch_fp + stride; // end of iteration 0's frame let (left, right) = memory.0.split_at_mut(split_at); let shared: &[Option] = &*left; - let segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); + let mut segment_slices: Vec<&mut [Option]> = right.chunks_mut(stride).take(n_par).collect(); type SegResult = Result<(Trace, Vec<(usize, F)>), RunnerError>; - let results: Vec = segment_slices - .into_par_iter() - .enumerate() - .map(|(i, seg_slice)| { + let mut results: Vec = (0..segment_slices.len()) + .map(|_| Ok((Trace::new(), Vec::new()))) + .collect(); + parallel::par_for_each_mut2(&mut segment_slices, &mut results, |i, seg_slice, result| { + *result = (|| -> SegResult { let seg_start = split_at + i * stride; - let mut seg_mem = SegmentMemory::new(shared, seg_slice, seg_start); + let mut seg_mem = SegmentMemory::new(shared, &mut **seg_slice, seg_start); let fp_i = batch.batch_fp + (i + 1) * stride; let mut seg_trace = Trace::new(); let mut seg_pc = batch.batch_pc; @@ -452,8 +458,10 @@ fn handle_parallel_batch( cursor.index += i * delta; } } - let seg_start_indices: HashMap<_, _> = - seg_named_hints.iter().map(|(name, c)| (name.clone(), c.index)).collect(); + let seg_start_indices: HashMap<_, _> = seg_named_hints + .iter() + .map(|(name, c)| (name.clone(), c.index)) + .collect(); let mut hints = HintState { diagnostics: 
None, named_hints: &mut seg_named_hints, @@ -478,8 +486,8 @@ fn handle_parallel_batch( } let deferred = seg_mem.into_deferred_writes(); Ok((seg_trace, deferred)) - }) - .collect(); + })(); + }); for (idx, result) in results.into_iter().enumerate() { let (seg_trace, deferred) = result.map_err(|e| RunnerError::ParallelSegmentFailed(idx + 1, Box::new(e)))?; diff --git a/crates/lean_vm/src/isa/bytecode.rs b/crates/lean_vm/src/isa/bytecode.rs index b21e7e91a..28ccf358d 100644 --- a/crates/lean_vm/src/isa/bytecode.rs +++ b/crates/lean_vm/src/isa/bytecode.rs @@ -1,6 +1,7 @@ //! Bytecode representation and management use backend::*; +use serde::{Deserialize, Serialize}; use crate::{DIMENSION, F, FileId, FunctionName, Hint, N_INSTRUCTION_COLUMNS, SourceLocation}; @@ -8,14 +9,14 @@ use super::Instruction; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CodeEntry { pub hints: Box<[Hint]>, // executed before the instruction pub instruction: Instruction, } /// `instructions_multilinear`, `hash`, and `ending_pc` must be checked at initialization to match `code`. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Bytecode { pub unpadded_size: usize, pub code: Vec, // assumed to be well-formed diff --git a/crates/lean_vm/src/isa/hint.rs b/crates/lean_vm/src/isa/hint.rs index 9bbc3f6d9..64d56fb71 100644 --- a/crates/lean_vm/src/isa/hint.rs +++ b/crates/lean_vm/src/isa/hint.rs @@ -13,7 +13,7 @@ use utils::ToUsize; /// VM hints provide execution guidance and debugging information, but does not appear /// in the verified bytecode. 
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum Hint { /// Compute the inverse of a field element Inverse { @@ -80,7 +80,7 @@ pub enum Hint { }, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum HintWitnessDestination { /// Write directly at `m[fp + fp_offset ..] Inline { offset: T }, @@ -99,7 +99,7 @@ impl HintWitnessDestination { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum CustomHint { // Decompose values into their custom representations: /// each field element x is decomposed to: (a0, a1, a2, ..., a11, b) where: @@ -137,7 +137,7 @@ impl CustomHint { pub fn n_args(&self) -> usize { match self { - Self::DecomposeBitsXMSS => 4, + Self::DecomposeBitsXMSS => 5, Self::DecomposeBitsMerkleWhir => 3, Self::DecomposeBits => 3, Self::LessThan => 3, @@ -153,24 +153,40 @@ impl CustomHint { ) -> Result<(), RunnerError> { match self { Self::DecomposeBitsXMSS => { + // Aborting hypercube decomposition: a_i = Q * d_i + r_i + // where d_i = floor(a_i / Q), r_i = a_i mod Q, Q = 127 + // Then d_i is decomposed into base-w digits (w = 2^chunk_size) let decomposed_ptr = args[0].read_value(ctx.memory, ctx.fp)?.to_usize(); - let to_decompose_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); - let num_to_decompose = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); - let chunk_size = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + let remaining_ptr = args[1].read_value(ctx.memory, ctx.fp)?.to_usize(); + let to_decompose_ptr = args[2].read_value(ctx.memory, ctx.fp)?.to_usize(); + let num_to_decompose = args[3].read_value(ctx.memory, ctx.fp)?.to_usize(); + let chunk_size = 
args[4].read_value(ctx.memory, ctx.fp)?.to_usize(); if chunk_size == 0 || !24_usize.is_multiple_of(chunk_size) { return Err(RunnerError::InvalidHintArguments(format!( "DecomposeBitsXMSS: chunk_size {chunk_size} must be a nonzero divisor of 24" ))); } + let q: usize = 127; // Q parameter for aborting hypercube (p = Q * w^z + 1) + let base = 1 << chunk_size; + let n_chunks = 24 / chunk_size; let mut memory_index_decomposed = decomposed_ptr; + let mut memory_index_remaining = remaining_ptr; #[allow(clippy::explicit_counter_loop)] for i in 0..num_to_decompose { let value = ctx.memory.get(to_decompose_ptr + i)?.to_usize(); - for i in 0..24 / chunk_size { - let value = F::from_usize((value >> (chunk_size * i)) & ((1 << chunk_size) - 1)); - ctx.memory.set(memory_index_decomposed, value)?; + let mut d_i = value / q; // floor(a_i / Q) + let r_i = value % q; // a_i mod Q + for _ in 0..n_chunks { + ctx.memory.set(memory_index_decomposed, F::from_usize(d_i % base))?; + d_i /= base; memory_index_decomposed += 1; } + assert_eq!( + d_i, 0, + "d_i does not fit in {n_chunks} base-{base} digits -> invalid XMSS encoding" + ); + ctx.memory.set(memory_index_remaining, F::from_usize(r_i))?; + memory_index_remaining += 1; } } Self::DecomposeBitsMerkleWhir => { @@ -235,7 +251,7 @@ impl CustomHint { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum Boolean { Equal, Different, @@ -243,7 +259,7 @@ pub enum Boolean { LessOrEqual, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct BooleanExpr { pub left: E, pub right: E, diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index d8063ea94..9dbb2b05c 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ 
-6,7 +6,7 @@ use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; use crate::tables::TableT; -use crate::{ExtensionOpMode, Table, TableTrace}; +use crate::{ExtensionOpMode, Poseidon24Mode, Table, TableTrace}; use crate::{POSEIDON16_NAME, POSEIDON16_PERMUTE_NAME}; use backend::*; use std::collections::BTreeMap; @@ -15,7 +15,7 @@ use std::ops::AddAssign; use utils::ToUsize; /// Complete set of VM instruction types with comprehensive operation support -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum Instruction { /// Basic arithmetic computation instruction (ADD, MUL) Computation { @@ -53,7 +53,7 @@ pub enum Instruction { Precompile(PrecompileInstruction), } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct PrecompileArgs { pub arg_0: V, pub arg_1: V, @@ -61,7 +61,7 @@ pub struct PrecompileArgs { pub data: PrecompileCompTimeArgs, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum PrecompileCompTimeArgs { Poseidon16 { half_output: bool, @@ -71,6 +71,7 @@ pub enum PrecompileCompTimeArgs { // Mutually exclusive with `half_output`. permute: bool, }, + Poseidon24(Poseidon24Mode), ExtensionOp { size: S, mode: ExtensionOpMode, @@ -81,6 +82,7 @@ impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { Self::Poseidon16 { .. } => Table::poseidon16(), + Self::Poseidon24(_) => Table::poseidon24(), Self::ExtensionOp { .. 
} => Table::extension_op(), } } @@ -96,6 +98,7 @@ impl PrecompileCompTimeArgs { hardcoded_offset_left: hardcoded_left_4.map(&mut f), permute, }, + Self::Poseidon24(mode) => PrecompileCompTimeArgs::Poseidon24(mode), Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -135,7 +138,7 @@ pub struct InstructionContext<'a, M: MemoryAccess> { pub memory: &'a mut M, pub fp: &'a mut usize, pub pc: &'a mut usize, - pub pcs: &'a Vec, + pub pcs: &'a [usize], pub traces: &'a mut BTreeMap, pub counts: &'a mut InstructionCounts, } @@ -277,6 +280,9 @@ impl Display for PrecompileArgs { } } } + PrecompileCompTimeArgs::Poseidon24(mode) => { + write!(f, "poseidon24(mode={mode:?}, {arg_0}, {arg_1}, {res})") + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/isa/operands/mem_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_constant.rs index 0f88f8eea..747680fc4 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_constant.rs @@ -5,7 +5,7 @@ use backend::*; use std::fmt::{Display, Formatter}; /// Represents a value that can be either a constant or memory location -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum MemOrConstant { /// Direct constant value Constant(F), diff --git a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs index 8974bb380..9f0044719 100644 --- a/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs +++ b/crates/lean_vm/src/isa/operands/mem_or_fp_or_constant.rs @@ -6,7 +6,7 @@ use backend::*; use std::fmt::{Display, Formatter}; /// Memory, frame pointer, or constant operand -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum MemOrFpOrConstant { /// memory[fp + offset] MemoryAfterFp { offset: usize }, diff --git a/crates/lean_vm/src/isa/operation.rs b/crates/lean_vm/src/isa/operation.rs index eddc475df..19268355f 100644 --- a/crates/lean_vm/src/isa/operation.rs +++ b/crates/lean_vm/src/isa/operation.rs @@ -5,7 +5,7 @@ use backend::*; use std::fmt::{Display, Formatter}; /// Basic arithmetic operations supported by the VM -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum Operation { Add, Mul, diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 471c1fdc6..1fb36f4fa 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -5,7 +5,7 @@ use backend::*; mod air; pub use air::*; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct ExecutionTable; impl TableT for ExecutionTable { diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index 6f085cd7e..10b44a735 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -13,7 +13,7 @@ pub(crate) const EXT_OP_FLAG_DOT_PRODUCT: usize = 16; pub(crate) const EXT_OP_FLAG_EQ: usize = 32; pub const EXT_OP_LEN_MULTIPLIER: usize = 64; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub enum ExtensionOp { Add, DotProduct, @@ -39,7 +39,7 @@ impl ExtensionOp { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, 
Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct ExtensionOpMode { pub op: ExtensionOp, pub flag_be: bool, @@ -75,7 +75,7 @@ impl ExtensionOpMode { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct ExtensionOpPrecompile; impl TableT for ExtensionOpPrecompile { diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index 42a3e100b..38188b4da 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -4,6 +4,9 @@ pub use extension_op::*; mod poseidon; pub use poseidon::*; +mod poseidon_24; +pub use poseidon_24::*; + mod table_enum; pub use table_enum::*; @@ -17,11 +20,11 @@ mod utils; pub(crate) use utils::*; // In logup interractions, the `domainsep` is the last entry of every tuple going into -// the bus. It separates the two precompile tables from each other (Poseidon16 is odd, -// ExtensionOp is a multiple of 4), and — since every value is odd `>= 3` (Poseidon16) or -// a multiple of 4 (ExtensionOp) — also from the memory and bytecode lookups, whose -// reserved domainseps are respectively 1 and 2. +// the bus. It separates the precompile tables from each other and — since every value +// avoids the reserved memory (1) and bytecode (2) domainseps — also from the memory and +// bytecode lookups. 
// -// Poseidon16 (odd >= 3): 3 + 2·flag_permute + 4·flag_short + 8·flag_left + 16·flag_left·offset_left -// ExtensionOp (0 mod 4): 4·flag_be + 8·flag_add + 16·flag_dot_product + 32·flag_eq + 64·len +// Poseidon16 (odd >= 3): 3 + 2·flag_permute + 4·flag_short + 8·flag_left + 16·flag_left·offset_left +// ExtensionOp (0 mod 4): 4·flag_be + 8·flag_add + 16·flag_dot_product + 32·flag_eq + 64·len +// Poseidon24 (2 mod 4, >2): 6 + 4·mode (Compress0_9=6, Permute0_9=10, Permute9_18=14) // diff --git a/crates/lean_vm/src/tables/poseidon/mod.rs b/crates/lean_vm/src/tables/poseidon/mod.rs index f90e4646a..31a33f966 100644 --- a/crates/lean_vm/src/tables/poseidon/mod.rs +++ b/crates/lean_vm/src/tables/poseidon/mod.rs @@ -126,7 +126,7 @@ pub const ALL_POSEIDON16_NAMES: [&str; 5] = [ ]; pub const HALF_DIGEST_LEN: usize = DIGEST_LEN / 2; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] pub struct Poseidon16Precompile; impl TableT for Poseidon16Precompile { diff --git a/crates/lean_vm/src/tables/poseidon/trace_gen.rs b/crates/lean_vm/src/tables/poseidon/trace_gen.rs index 3664048d0..ca2ab1558 100644 --- a/crates/lean_vm/src/tables/poseidon/trace_gen.rs +++ b/crates/lean_vm/src/tables/poseidon/trace_gen.rs @@ -7,7 +7,7 @@ use crate::{ use backend::*; #[instrument(name = "generate Poseidon16 AIR trace", skip_all)] -pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { +pub fn fill_trace_poseidon_16(trace: &mut [ArenaVec]) { let n = trace.iter().map(|col| col.len()).max().unwrap(); for col in trace.iter_mut() { if col.len() != n { @@ -19,7 +19,7 @@ pub fn fill_trace_poseidon_16(trace: &mut [Vec]) { let trace_packed: Vec<_> = trace.iter().map(|col| FPacking::::pack_slice(&col[..m])).collect(); // fill the packed rows - (0..m / packing_width::()).into_par_iter().for_each(|i| { + parallel::for_each_index(m / packing_width::(), |i| { let ptrs: Vec<*mut FPacking> = 
trace_packed .iter() .map(|col| unsafe { (col.as_ptr() as *mut FPacking).add(i) }) diff --git a/crates/lean_vm/src/tables/poseidon_24/mod.rs b/crates/lean_vm/src/tables/poseidon_24/mod.rs new file mode 100644 index 000000000..07074069d --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_24/mod.rs @@ -0,0 +1,512 @@ +use std::any::TypeId; + +use crate::*; +use backend::*; +use utils::{ToUsize, poseidon24_compress_0_9, poseidon24_permute_0_9, poseidon24_permute_9_18}; + +/// Dispatch `mds_circ_24` through concrete types. +/// For `SymbolicExpression` we use the dense form so the zkDSL generator can +/// emit `dot_product_be` precompile calls instead of Karatsuba arithmetic. +#[inline(always)] +fn mds_air_24(state: &mut [A; WIDTH_24]) { + if TypeId::of::() == TypeId::of::>() { + dense_mat_vec_air_24(mds_dense_24(), state); + return; + } + macro_rules! dispatch { + ($t:ty) => { + if TypeId::of::() == TypeId::of::<$t>() { + mds_circ_24::<$t>(unsafe { &mut *(state as *mut [A; WIDTH_24] as *mut [$t; WIDTH_24]) }); + return; + } + }; + } + dispatch!(F); + dispatch!(EF); + dispatch!(FPacking); + dispatch!(EFPacking); + unreachable!() +} + +fn mds_dense_24() -> &'static [[F; WIDTH_24]; WIDTH_24] { + use std::sync::OnceLock; + static MAT: OnceLock<[[KoalaBear; WIDTH_24]; WIDTH_24]> = OnceLock::new(); + MAT.get_or_init(|| { + let cols: [[F; WIDTH_24]; WIDTH_24] = std::array::from_fn(|j| { + let mut e = [F::ZERO; WIDTH_24]; + e[j] = F::ONE; + mds_circ_24(&mut e); + e + }); + std::array::from_fn(|i| std::array::from_fn(|j| cols[j][i])) + }) +} + +#[inline(always)] +fn add_kb_24(a: &mut A, value: F) { + macro_rules! dispatch { + ($t:ty) => { + if TypeId::of::() == TypeId::of::<$t>() { + *unsafe { &mut *(a as *mut A as *mut $t) } += value; + return; + } + }; + } + dispatch!(F); + dispatch!(EF); + dispatch!(FPacking); + dispatch!(EFPacking); + dispatch!(SymbolicExpression); + unreachable!() +} + +#[inline(always)] +fn mul_kb_24(a: A, value: F) -> A { + macro_rules! 
dispatch { + ($t:ty) => { + if TypeId::of::() == TypeId::of::<$t>() { + let r = unsafe { std::ptr::read(&a as *const A as *const $t) } * value; + return unsafe { std::ptr::read(&r as *const $t as *const A) }; + } + }; + } + dispatch!(F); + dispatch!(EF); + dispatch!(FPacking); + dispatch!(EFPacking); + dispatch!(SymbolicExpression); + unreachable!() +} + +mod trace_gen; +pub use trace_gen::fill_trace_poseidon_24; +use trace_gen::generate_trace_rows_for_perm_24; + +pub(super) const WIDTH_24: usize = 24; +const HALF_INITIAL_FULL_ROUNDS_24: usize = POSEIDON1_HALF_FULL_ROUNDS_24 / 2; +const PARTIAL_ROUNDS_24: usize = POSEIDON1_PARTIAL_ROUNDS_24; +const HALF_FINAL_FULL_ROUNDS_24: usize = POSEIDON1_HALF_FULL_ROUNDS_24 / 2; + +// domain separation (see `tables/mod.rs`): values must avoid memory (1), bytecode (2), +// Poseidon16 (odd >= 3) and ExtensionOp (multiple of 4). We use the free residue class +// `2 mod 4`, `> 2`: Compress0_9 = 6, Permute0_9 = 10, Permute9_18 = 14. +pub const POSEIDON_24_DOMAINSEP_BASE: usize = 6; +pub const POSEIDON_24_DOMAINSEP_STEP: usize = 4; + +// 3 modes for Poseidon24 precompile: +// Compress0_9: feedforward + output[0..9] -> domainsep = 6 +// Permute0_9: permutation + output[0..9] -> domainsep = 10 +// Permute9_18: permutation + output[9..18] -> domainsep = 14 +// 2 committed boolean columns: is_compress_0_9, is_permute_0_9 +// 3rd mode deduced: is_permute_9_18 = 1 - is_compress_0_9 - is_permute_0_9 +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] +pub enum Poseidon24Mode { + Compress0_9 = 0, + Permute0_9 = 1, + Permute9_18 = 2, +} + +impl Poseidon24Mode { + pub const fn as_usize(self) -> usize { + self as usize + } + + pub const fn is_compress(self) -> bool { + matches!(self, Self::Compress0_9) + } + + pub const fn is_permute_0_9(self) -> bool { + matches!(self, Self::Permute0_9) + } +} + +pub const POSEIDON_24_INPUT_LEFT_SIZE: usize = 9; +pub const POSEIDON_24_INPUT_RIGHT_SIZE: 
usize = 15; +pub const POSEIDON_24_OUTPUT_SIZE: usize = 9; + +pub const POSEIDON_24_COL_FLAG: ColIndex = 0; +pub const POSEIDON_24_COL_IS_COMPRESS_0_9: ColIndex = 1; +pub const POSEIDON_24_COL_IS_PERMUTE_0_9: ColIndex = 2; +pub const POSEIDON_24_COL_INDEX_INPUT_LEFT: ColIndex = 3; +pub const POSEIDON_24_COL_INDEX_INPUT_RIGHT: ColIndex = 4; +pub const POSEIDON_24_COL_INDEX_RES: ColIndex = 5; +pub const POSEIDON_24_COL_INPUT_START: ColIndex = 6; +pub const POSEIDON_24_COL_OUTPUT_START: ColIndex = num_cols_poseidon_24() - POSEIDON_24_OUTPUT_SIZE; + +// virtual columns (not committed) +pub const POSEIDON_24_COL_PRECOMPILE_DATA: usize = num_cols_poseidon_24(); + +pub const POSEIDON24_NAME: &str = "poseidon24_compress"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] +pub struct Poseidon24Precompile; + +impl TableT for Poseidon24Precompile { + fn name(&self) -> &'static str { + POSEIDON24_NAME + } + + fn table(&self) -> Table { + Table::poseidon24() + } + + fn bus_interactions(&self) -> Vec { + // Convention shared with the other tables: the unique Multiplicity::Column bus + // comes first; everything that follows is Multiplicity::One. 
+ let mut buses = vec![BusInteraction { + direction: BusDirection::Pull, + multiplicity: BusMultiplicity::Column(POSEIDON_24_COL_FLAG), + domainsep: BusData::Column(POSEIDON_24_COL_PRECOMPILE_DATA), + data: vec![ + BusData::Column(POSEIDON_24_COL_INDEX_INPUT_LEFT), + BusData::Column(POSEIDON_24_COL_INDEX_INPUT_RIGHT), + BusData::Column(POSEIDON_24_COL_INDEX_RES), + ], + }]; + buses.extend(memory_lookups_consecutive( + POSEIDON_24_COL_INDEX_INPUT_LEFT, + POSEIDON_24_COL_INPUT_START, + POSEIDON_24_INPUT_LEFT_SIZE, + )); + buses.extend(memory_lookups_consecutive( + POSEIDON_24_COL_INDEX_INPUT_RIGHT, + POSEIDON_24_COL_INPUT_START + POSEIDON_24_INPUT_LEFT_SIZE, + POSEIDON_24_INPUT_RIGHT_SIZE, + )); + buses.extend(memory_lookups_consecutive( + POSEIDON_24_COL_INDEX_RES, + POSEIDON_24_COL_OUTPUT_START, + POSEIDON_24_OUTPUT_SIZE, + )); + buses + } + + fn n_columns_total(&self) -> usize { + self.n_columns() + 1 // +1 for POSEIDON_24_COL_PRECOMPILE_DATA + } + + fn padding_row(&self, zero_vec_ptr: usize, null_hash_24_ptr: usize, _ending_pc: usize) -> Vec { + let mut row = vec![F::ZERO; num_cols_poseidon_24() + 1]; + let ptrs: Vec<*mut F> = (0..num_cols_poseidon_24()) + .map(|i| unsafe { row.as_mut_ptr().add(i) }) + .collect(); + + let perm: &mut Poseidon1Cols24<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols24<&mut F>) }; + perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); + *perm.flag = F::ZERO; + *perm.is_compress_0_9 = F::ONE; // convention + *perm.is_permute_0_9 = F::ZERO; + *perm.index_input_left = F::from_usize(zero_vec_ptr); + *perm.index_input_right = F::from_usize(zero_vec_ptr); + *perm.index_res = F::from_usize(null_hash_24_ptr); + + generate_trace_rows_for_perm_24(perm); + // virtual column + row[POSEIDON_24_COL_PRECOMPILE_DATA] = F::from_usize( + POSEIDON_24_DOMAINSEP_BASE + POSEIDON_24_DOMAINSEP_STEP * Poseidon24Mode::Compress0_9.as_usize(), + ); // ...following the above convention + row + } + + #[inline(always)] + fn execute( + &self, + 
index_input_left: F, + index_input_right: F, + index_res: F, + args: PrecompileCompTimeArgs, + ctx: &mut InstructionContext<'_, M>, + ) -> Result<(), RunnerError> { + let PrecompileCompTimeArgs::Poseidon24(mode) = args else { + panic!("expected Poseidon24 precompile args"); + }; + let is_compress_0_9 = mode.is_compress(); + let is_permute_0_9 = mode.is_permute_0_9(); + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + + let arg0 = ctx + .memory + .get_slice(index_input_left.to_usize(), POSEIDON_24_INPUT_LEFT_SIZE)?; + let arg1 = ctx + .memory + .get_slice(index_input_right.to_usize(), POSEIDON_24_INPUT_RIGHT_SIZE)?; + + let mut input = [F::ZERO; POSEIDON_24_INPUT_LEFT_SIZE + POSEIDON_24_INPUT_RIGHT_SIZE]; + input[..POSEIDON_24_INPUT_LEFT_SIZE].copy_from_slice(&arg0); + input[POSEIDON_24_INPUT_LEFT_SIZE..].copy_from_slice(&arg1); + + let result = match mode { + Poseidon24Mode::Compress0_9 => poseidon24_compress_0_9(input), + Poseidon24Mode::Permute0_9 => poseidon24_permute_0_9(input), + Poseidon24Mode::Permute9_18 => poseidon24_permute_9_18(input), + }; + + let res_a: [F; POSEIDON_24_OUTPUT_SIZE] = result[..POSEIDON_24_OUTPUT_SIZE].try_into().unwrap(); + + ctx.memory.set_slice(index_res.to_usize(), &res_a)?; + + trace.columns[POSEIDON_24_COL_FLAG].push(F::ONE); + trace.columns[POSEIDON_24_COL_IS_COMPRESS_0_9].push(F::from_bool(is_compress_0_9)); + trace.columns[POSEIDON_24_COL_IS_PERMUTE_0_9].push(F::from_bool(is_permute_0_9)); + trace.columns[POSEIDON_24_COL_INDEX_INPUT_LEFT].push(index_input_left); + trace.columns[POSEIDON_24_COL_INDEX_INPUT_RIGHT].push(index_input_right); + trace.columns[POSEIDON_24_COL_INDEX_RES].push(index_res); + for (i, value) in input.iter().enumerate() { + trace.columns[POSEIDON_24_COL_INPUT_START + i].push(*value); + } + trace.columns[POSEIDON_24_COL_PRECOMPILE_DATA].push(F::from_usize( + POSEIDON_24_DOMAINSEP_BASE + POSEIDON_24_DOMAINSEP_STEP * mode.as_usize(), + )); + + // the rest of the trace is filled at the end of the 
execution (for parallelism + SIMD) + + Ok(()) + } +} + +impl Air for Poseidon24Precompile { + type ExtraData = ExtraDataForBuses; + fn n_columns(&self) -> usize { + num_cols_poseidon_24() + } + fn degree_air(&self) -> usize { + 10 + } + fn low_degree_air(&self) -> Option<(usize, usize)> { + // Each partial round contributes one `assert_eq_low` per round (1 S-box / round), of degree 3 (= the "low" degree part) + Some((3, PARTIAL_ROUNDS_24)) + } + fn n_shift_columns(&self) -> usize { + 0 + } + fn n_constraints(&self) -> usize { + 2 * BUS as usize + 107 + } + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let cols: Poseidon1Cols24 = { + let flat = builder.flat(); + let (prefix, shorts, suffix) = unsafe { flat.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + unsafe { std::ptr::read(&shorts[0]) } + }; + + // domainsep = 6 + 4*mode = 6 + 4*is_permute_0_9 + 8*is_permute_9_18 + let precompile_data = AB::IF::from_usize(POSEIDON_24_DOMAINSEP_BASE) + + cols.is_permute_0_9 * AB::F::from_usize(POSEIDON_24_DOMAINSEP_STEP) + + (AB::IF::ONE - cols.is_compress_0_9 - cols.is_permute_0_9) // is_permute_9_18 + * AB::F::from_usize(2 * POSEIDON_24_DOMAINSEP_STEP); + + if BUS { + eval_bus_virtual::( + builder, + extra_data, + cols.flag, + precompile_data, + &[cols.index_input_left, cols.index_input_right, cols.index_res], + ); + } else { + builder.declare_values(std::slice::from_ref(&cols.flag)); + builder.declare_values(&[ + cols.index_input_left, + cols.index_input_right, + cols.index_res, + precompile_data, + ]); + } + + builder.assert_bool(cols.flag); + builder.assert_bool(cols.is_compress_0_9); + builder.assert_bool(cols.is_permute_0_9); + + let is_compress = cols.is_compress_0_9; + let is_output_0_9 = cols.is_compress_0_9 + cols.is_permute_0_9; + + eval_poseidon1_24(builder, &cols, is_compress, is_output_0_9) + } +} + +#[repr(C)] 
+#[derive(Debug)] +pub(super) struct Poseidon1Cols24 { + pub flag: T, + pub is_compress_0_9: T, + pub is_permute_0_9: T, + pub index_input_left: T, + pub index_input_right: T, + pub index_res: T, + + pub inputs: [T; WIDTH_24], + pub beginning_full_rounds: [[T; WIDTH_24]; HALF_INITIAL_FULL_ROUNDS_24], + pub partial_rounds: [T; PARTIAL_ROUNDS_24], + pub ending_full_rounds: [[T; WIDTH_24]; HALF_FINAL_FULL_ROUNDS_24 - 1], + pub outputs: [T; POSEIDON_24_OUTPUT_SIZE], +} + +fn eval_poseidon1_24( + builder: &mut AB, + local: &Poseidon1Cols24, + is_compress: AB::IF, + is_output_0_9: AB::IF, +) { + let mut state: [_; WIDTH_24] = local.inputs; + + // No initial linear layer for Poseidon1 + + let initial_constants = poseidon1_24_initial_constants(); + for round in 0..HALF_INITIAL_FULL_ROUNDS_24 { + eval_2_full_rounds_24( + &mut state, + &local.beginning_full_rounds[round], + &initial_constants[2 * round], + &initial_constants[2 * round + 1], + builder, + ); + } + + // --- Sparse partial rounds --- + // Transition: add first-round constants, multiply by m_i + builder.low_degree_block(&mut state, |b, state| { + let state: &mut [AB::IF; WIDTH_24] = state.try_into().unwrap(); + + let frc = poseidon1_24_sparse_first_round_constants(); + for (s, &c) in state.iter_mut().zip(frc.iter()) { + add_kb_24(s, c); + } + dense_mat_vec_air_24(poseidon1_24_sparse_m_i(), state); + + let first_rows = poseidon1_24_sparse_first_row(); + let v_vecs = poseidon1_24_sparse_v(); + let scalar_rc = poseidon1_24_sparse_scalar_round_constants(); + for round in 0..PARTIAL_ROUNDS_24 { + // S-box on state[0] + state[0] = state[0].cube(); + b.assert_eq_low(state[0], local.partial_rounds[round]); + state[0] = local.partial_rounds[round]; + // Scalar round constant (not on last round) + if round < PARTIAL_ROUNDS_24 - 1 { + add_kb_24(&mut state[0], scalar_rc[round]); + } + // Sparse matrix: new_s0 = dot(first_row, state), state[i] += old_s0 * v[i-1] + sparse_mat_air_24(state, &first_rows[round], &v_vecs[round]); 
+ } + }); + + let final_constants = poseidon1_24_final_constants(); + for round in 0..HALF_FINAL_FULL_ROUNDS_24 - 1 { + eval_2_full_rounds_24( + &mut state, + &local.ending_full_rounds[round], + &final_constants[2 * round], + &final_constants[2 * round + 1], + builder, + ); + } + + eval_last_2_full_rounds_24( + &local.inputs, + &mut state, + &local.outputs, + &final_constants[2 * (HALF_FINAL_FULL_ROUNDS_24 - 1)], + &final_constants[2 * (HALF_FINAL_FULL_ROUNDS_24 - 1) + 1], + is_compress, + is_output_0_9, + builder, + ); +} + +pub const fn num_cols_poseidon_24() -> usize { + size_of::>() +} + +#[inline] +fn eval_2_full_rounds_24( + state: &mut [AB::IF; WIDTH_24], + post_full_round: &[AB::IF; WIDTH_24], + round_constants_1: &[F; WIDTH_24], + round_constants_2: &[F; WIDTH_24], + builder: &mut AB, +) { + for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { + add_kb_24(s, *r); + *s = s.cube(); + } + mds_air_24(state); + for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { + add_kb_24(s, *r); + *s = s.cube(); + } + mds_air_24(state); + for (state_i, post_i) in state.iter_mut().zip(post_full_round) { + builder.assert_eq(*state_i, *post_i); + *state_i = *post_i; + } +} + +#[inline] +#[allow(clippy::too_many_arguments)] +fn eval_last_2_full_rounds_24( + initial_state: &[AB::IF; WIDTH_24], + state: &mut [AB::IF; WIDTH_24], + outputs: &[AB::IF; POSEIDON_24_OUTPUT_SIZE], + round_constants_1: &[F; WIDTH_24], + round_constants_2: &[F; WIDTH_24], + is_compress: AB::IF, + is_output_0_9: AB::IF, + builder: &mut AB, +) { + for (s, r) in state.iter_mut().zip(round_constants_1.iter()) { + add_kb_24(s, *r); + *s = s.cube(); + } + mds_air_24(state); + for (s, r) in state.iter_mut().zip(round_constants_2.iter()) { + add_kb_24(s, *r); + *s = s.cube(); + } + mds_air_24(state); + // conditional feedforward: only for compress mode + for (state_i, init_state_i) in state.iter_mut().zip(initial_state) { + *state_i += *init_state_i * is_compress; + } + for ((output_i, 
state_i), state_9_plus_i) in outputs + .iter() + .zip(&state[..POSEIDON_24_OUTPUT_SIZE]) + .zip(&state[POSEIDON_24_OUTPUT_SIZE..][..POSEIDON_24_OUTPUT_SIZE]) + { + builder.assert_eq( + *output_i, + *state_i * is_output_0_9 + *state_9_plus_i * (AB::IF::ONE - is_output_0_9), + ); + } +} + +#[inline] +fn dense_mat_vec_air_24(mat: &[[F; 24]; 24], state: &mut [A; WIDTH_24]) { + let input = *state; + for i in 0..WIDTH_24 { + let mut acc = A::ZERO; + for j in 0..WIDTH_24 { + acc += mul_kb_24(input[j], mat[i][j]); + } + state[i] = acc; + } +} + +#[inline] +fn sparse_mat_air_24( + state: &mut [A; WIDTH_24], + first_row: &[F; WIDTH_24], + v: &[F; WIDTH_24], +) { + let old_s0 = state[0]; + let mut new_s0 = A::ZERO; + for j in 0..WIDTH_24 { + new_s0 += mul_kb_24(state[j], first_row[j]); + } + state[0] = new_s0; + for i in 1..WIDTH_24 { + state[i] += mul_kb_24(old_s0, v[i - 1]); + } +} diff --git a/crates/lean_vm/src/tables/poseidon_24/trace_gen.rs b/crates/lean_vm/src/tables/poseidon_24/trace_gen.rs new file mode 100644 index 000000000..185209a3b --- /dev/null +++ b/crates/lean_vm/src/tables/poseidon_24/trace_gen.rs @@ -0,0 +1,173 @@ +use tracing::instrument; + +use crate::{ + F, + tables::{POSEIDON_24_OUTPUT_SIZE, Poseidon1Cols24, WIDTH_24}, +}; +use backend::*; + +#[instrument(name = "generate Poseidon24 AIR trace", skip_all)] +pub fn fill_trace_poseidon_24(trace: &mut [ArenaVec]) { + let n = trace.iter().map(|col| col.len()).max().unwrap(); + for col in trace.iter_mut() { + if col.len() != n { + col.resize(n, F::ZERO); + } + } + + let m = n - (n % packing_width::()); + let trace_packed: Vec<_> = trace.iter().map(|col| FPacking::::pack_slice(&col[..m])).collect(); + + // fill the packed rows + parallel::for_each_index(m / packing_width::(), |i| { + let ptrs: Vec<*mut FPacking> = trace_packed + .iter() + .map(|col| unsafe { (col.as_ptr() as *mut FPacking).add(i) }) + .collect(); + let perm: &mut Poseidon1Cols24<&mut FPacking> = + unsafe { &mut *(ptrs.as_ptr() as *mut 
Poseidon1Cols24<&mut FPacking>) }; + + generate_trace_rows_for_perm_24(perm); + }); + + // fill the remaining rows (non packed) + for i in m..n { + let ptrs: Vec<*mut F> = trace + .iter() + .map(|col| unsafe { (col.as_ptr() as *mut F).add(i) }) + .collect(); + let perm: &mut Poseidon1Cols24<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols24<&mut F>) }; + generate_trace_rows_for_perm_24(perm); + } +} + +pub(super) fn generate_trace_rows_for_perm_24 + Copy>(perm: &mut Poseidon1Cols24<&mut F>) { + let inputs: [F; WIDTH_24] = std::array::from_fn(|i| *perm.inputs[i]); + let mut state = inputs; + + // No initial linear layer for Poseidon1 + + for (full_round, constants) in perm + .beginning_full_rounds + .iter_mut() + .zip(poseidon1_24_initial_constants().chunks_exact(2)) + { + generate_2_full_round_24(&mut state, full_round, &constants[0], &constants[1]); + } + + // --- Sparse partial rounds --- + let frc = poseidon1_24_sparse_first_round_constants(); + for (s, &c) in state.iter_mut().zip(frc.iter()) { + *s += c; + } + let m_i = poseidon1_24_sparse_m_i(); + let input_for_mi = state; + for i in 0..WIDTH_24 { + let row: [F; WIDTH_24] = m_i[i].map(F::from); + state[i] = F::dot_product(&input_for_mi, &row); + } + + let first_rows = poseidon1_24_sparse_first_row(); + let v_vecs = poseidon1_24_sparse_v(); + let scalar_rc = poseidon1_24_sparse_scalar_round_constants(); + let n_partial = perm.partial_rounds.len(); + for round in 0..n_partial { + // S-box on state[0] + state[0] = state[0].cube(); + *perm.partial_rounds[round] = state[0]; + // Scalar round constant (not on last round) + if round < n_partial - 1 { + state[0] += scalar_rc[round]; + } + // Sparse matrix + let old_s0 = state[0]; + let row: [F; WIDTH_24] = first_rows[round].map(F::from); + let new_s0 = F::dot_product(&state, &row); + state[0] = new_s0; + for i in 1..WIDTH_24 { + state[i] += old_s0 * v_vecs[round][i - 1]; + } + } + + let n_ending_full_rounds = perm.ending_full_rounds.len(); + for 
(full_round, constants) in perm + .ending_full_rounds + .iter_mut() + .zip(poseidon1_24_final_constants().chunks_exact(2)) + { + generate_2_full_round_24(&mut state, full_round, &constants[0], &constants[1]); + } + + // Last 2 full rounds with conditional feedforward and output selection + let is_compress = *perm.is_compress_0_9; + let is_output_0_9 = *perm.is_compress_0_9 + *perm.is_permute_0_9; + generate_last_2_full_rounds_24( + &mut state, + &inputs, + &mut perm.outputs, + &poseidon1_24_final_constants()[2 * n_ending_full_rounds], + &poseidon1_24_final_constants()[2 * n_ending_full_rounds + 1], + is_compress, + is_output_0_9, + ); +} + +#[inline] +fn generate_2_full_round_24 + Copy>( + state: &mut [F; WIDTH_24], + post_full_round: &mut [&mut F; WIDTH_24], + round_constants_1: &[KoalaBear; WIDTH_24], + round_constants_2: &[KoalaBear; WIDTH_24], +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + mds_circ_24(state); + + for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + mds_circ_24(state); + + post_full_round.iter_mut().zip(*state).for_each(|(post, x)| { + **post = x; + }); +} + +#[inline] +fn generate_last_2_full_rounds_24 + Copy>( + state: &mut [F; WIDTH_24], + inputs: &[F; WIDTH_24], + outputs: &mut [&mut F; POSEIDON_24_OUTPUT_SIZE], + round_constants_1: &[KoalaBear; WIDTH_24], + round_constants_2: &[KoalaBear; WIDTH_24], + is_compress: F, + is_output_0_9: F, +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants_1) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + mds_circ_24(state); + + for (state_i, const_i) in state.iter_mut().zip(round_constants_2.iter()) { + *state_i += *const_i; + *state_i = state_i.cube(); + } + mds_circ_24(state); + + // Conditional feedforward: only for compress mode + for (state_i, input_i) in state.iter_mut().zip(inputs) { + *state_i += 
*input_i * is_compress; + } + // Select output[0..9] or output[9..18] based on is_output_0_9 + for ((output, first), second) in outputs + .iter_mut() + .zip(&state[..POSEIDON_24_OUTPUT_SIZE]) + .zip(&state[POSEIDON_24_OUTPUT_SIZE..][..POSEIDON_24_OUTPUT_SIZE]) + { + **output = *first * is_output_0_9 + *second * (F::ONE - is_output_0_9); + } +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 1a21d6066..b26b0eab2 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -3,17 +3,23 @@ use backend::*; use crate::execution::memory::MemoryAccess; use crate::*; -pub const N_TABLES: usize = 3; -pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; +pub const N_TABLES: usize = 4; +pub const ALL_TABLES: [Table; N_TABLES] = [ + Table::execution(), + Table::extension_op(), + Table::poseidon16(), + Table::poseidon24(), +]; pub const MAX_BUS_WIDTH: usize = N_INSTRUCTION_COLUMNS + 2; // + 1 for PC, + 1 for domainsep pub const LOG_MAX_BUS_WIDTH: usize = log2_ceil_usize(MAX_BUS_WIDTH); -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum Table { Execution(ExecutionTable), ExtensionOp(ExtensionOpPrecompile), Poseidon16(Poseidon16Precompile), + Poseidon24(Poseidon24Precompile), } #[macro_export] @@ -23,6 +29,7 @@ macro_rules! delegate_to_inner { match $self { Self::ExtensionOp(p) => p.$method($($($arg),*)?), Self::Poseidon16(p) => p.$method($($($arg),*)?), + Self::Poseidon24(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), } }; @@ -31,6 +38,7 @@ macro_rules! 
delegate_to_inner { match $self { Table::ExtensionOp(p) => $macro_name!(p), Table::Poseidon16(p) => $macro_name!(p), + Table::Poseidon24(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), } }; @@ -46,6 +54,9 @@ impl Table { pub const fn poseidon16() -> Self { Self::Poseidon16(Poseidon16Precompile) } + pub const fn poseidon24() -> Self { + Self::Poseidon24(Poseidon24Precompile) + } pub fn embed(&self) -> PF { PF::from_usize(self.index()) } diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index e1fa8933b..489a148c9 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -138,7 +138,7 @@ pub struct MemoryLookupGroup { #[derive(Debug, Default)] pub struct TableTrace { - pub columns: Vec>, + pub columns: Vec>, pub non_padded_n_rows: usize, pub log_n_rows: VarCount, } @@ -146,7 +146,7 @@ pub struct TableTrace { impl TableTrace { pub fn new(air: &A) -> Self { Self { - columns: vec![Vec::new(); air.n_columns_total()], + columns: (0..air.n_columns_total()).map(|_| ArenaVec::new()).collect(), non_padded_n_rows: 0, // filled later log_n_rows: 0, // filled later } @@ -167,10 +167,10 @@ pub struct ExtraDataForBuses>> { pub alpha_powers: Vec, } impl>> ExtraDataForBuses { - pub fn new(logup_alphas_eq_poly: Vec, alpha_powers: Vec) -> Self { + pub fn new(logup_alphas_eq_poly: &[EF], alpha_powers: Vec) -> Self { let logup_alphas_eq_poly_packed = logup_alphas_eq_poly.iter().map(|a| EFPacking::::from(*a)).collect(); Self { - logup_alphas_eq_poly, + logup_alphas_eq_poly: logup_alphas_eq_poly.to_vec(), logup_alphas_eq_poly_packed, alpha_powers, } diff --git a/crates/leansig_wrapper/Cargo.toml b/crates/leansig_wrapper/Cargo.toml new file mode 100644 index 000000000..73b4d1cf0 --- /dev/null +++ b/crates/leansig_wrapper/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "leansig_wrapper" +version.workspace = true +edition.workspace = true + +[lints] +workspace = true + +[dependencies] 
+leansig.workspace = true +leansig_fast_keygen.workspace = true +backend.workspace = true +rand.workspace = true +p3-field = { git = "https://github.com/Plonky3/Plonky3.git" } +ssz = { package = "ethereum_ssz", version = "0.10.0" } + +[features] +test-config = [] \ No newline at end of file diff --git a/crates/leansig_wrapper/src/lib.rs b/crates/leansig_wrapper/src/lib.rs new file mode 100644 index 000000000..ac2cad034 --- /dev/null +++ b/crates/leansig_wrapper/src/lib.rs @@ -0,0 +1,184 @@ +use backend::{KoalaBear, integers::QuotientMap}; +use leansig::{ + inc_encoding::target_sum::TargetSumEncoding, + signature::{ + SignatureScheme, + generalized_xmss::{ + GeneralizedXMSSPublicKey, GeneralizedXMSSSecretKey, GeneralizedXMSSSignature, + GeneralizedXMSSSignatureScheme, + }, + }, + symmetric::{ + message_hash::aborting::AbortingHypercubeMessageHash, prf::shake_to_field::ShakePRFtoF, + tweak_hash::poseidon::PoseidonTweakHash, + }, +}; +use leansig_fast_keygen::{ + signature::SignatureScheme as FastKeyGenSignatureScheme, symmetric::message_hash::encode_message, +}; +use p3_field::PrimeField32; +use std::array; + +#[cfg(feature = "test-config")] +pub const V: usize = 4; +#[cfg(not(feature = "test-config"))] +pub const V: usize = 46; +pub const BASE: usize = 1 << W; +const Z: usize = 8; +const Q: usize = 127; +#[cfg(feature = "test-config")] +pub const TARGET_SUM: usize = 6; +#[cfg(not(feature = "test-config"))] +pub const TARGET_SUM: usize = 200; +pub const RAND_LEN_FE: usize = 7; +pub const HASH_LEN_FE: usize = 8; +pub const MSG_LEN_FE: usize = 9; +pub const PARAMETER_LEN: usize = 5; +pub const TWEAK_LEN_FE: usize = 2; + +pub const W: usize = 3; +pub const MESSAGE_LENGTH: usize = 32; +pub const POSEIDON24_CAPACITY: usize = 9; +pub const POSEIDON24_RATE: usize = 15; + +#[cfg(feature = "test-config")] +pub const LOG_LIFETIME: usize = 8; +#[cfg(not(feature = "test-config"))] +pub const LOG_LIFETIME: usize = 32; + +pub const SIG_SIZE_FE: usize = RAND_LEN_FE + (V + 
LOG_LIFETIME) * HASH_LEN_FE; + +pub(crate) type F = KoalaBear; + +#[cfg(feature = "test-config")] +pub const WOTS_PUBKET_SPONGE_DOMAIN_SEP: [F; POSEIDON24_CAPACITY] = F::new_array([ + 627826400, 1244476188, 370678638, 978729783, 1996000804, 1380088873, 1753334201, 433326939, 1294775677, +]); +#[cfg(not(feature = "test-config"))] +pub const WOTS_PUBKET_SPONGE_DOMAIN_SEP: [F; POSEIDON24_CAPACITY] = F::new_array([ + 2060061975, 916902315, 229801915, 83751504, 2093549181, 1743125625, 721042244, 1252069948, 1192880636, +]); + +pub use leansig::symmetric::tweak_hash::TweakableHash; +use rand::CryptoRng; + +pub type LeanSigTH = PoseidonTweakHash; + +type MH = + AbortingHypercubeMessageHash; +type TH = PoseidonTweakHash; +type PrF = ShakePRFtoF; +type IE = TargetSumEncoding; + +pub type LeanSigScheme = GeneralizedXMSSSignatureScheme; +pub type XmssPublicKey = GeneralizedXMSSPublicKey; +pub type XmssSecretKey = GeneralizedXMSSSecretKey; +pub type XmssSignature = GeneralizedXMSSSignature; + +#[cfg(feature = "test-config")] +pub type FastKeyGenScheme = leansig_fast_keygen::signature::generalized_xmss::instantiations_aborting::lifetime_2_to_the_8::SchemeAbortingTargetSumLifetime8Dim46Base8; +#[cfg(feature = "test-config")] +pub type FastKeyGenSecretKey = leansig_fast_keygen::signature::generalized_xmss::instantiations_aborting::lifetime_2_to_the_8::SecretKeyAbortingTargetSumLifetime8Dim46Base8; +#[cfg(not(feature = "test-config"))] +pub type FastKeyGenScheme = leansig_fast_keygen::signature::generalized_xmss::instantiations_aborting::lifetime_2_to_the_32::SchemeAbortingTargetSumLifetime32Dim46Base8; +#[cfg(not(feature = "test-config"))] +pub type FastKeyGenSecretKey = leansig_fast_keygen::signature::generalized_xmss::instantiations_aborting::lifetime_2_to_the_32::SecretKeyAbortingTargetSumLifetime32Dim46Base8; + +pub fn pubkey_merkle_root(pub_keys: &XmssPublicKey) -> [F; HASH_LEN_FE] { + assert_eq!(pub_keys.root().len(), HASH_LEN_FE); + array::from_fn(|i| 
F::from_canonical_checked(pub_keys.root()[i].as_canonical_u32()).unwrap()) +} + +pub fn pubkey_public_parameter(pub_keys: &XmssPublicKey) -> [F; PARAMETER_LEN] { + assert_eq!(pub_keys.parameter().len(), PARAMETER_LEN); + array::from_fn(|i| F::from_canonical_checked(pub_keys.parameter()[i].as_canonical_u32()).unwrap()) +} + +pub fn chain_tweak(slot: u32, chain_idx: u32, step: u32) -> [F; TWEAK_LEN_FE] { + let [t0, t1] = LeanSigTH::chain_tweak(slot, chain_idx as u8, step as u8).to_field_elements(); + [ + F::from_canonical_checked(t0.as_canonical_u32()).unwrap(), + F::from_canonical_checked(t1.as_canonical_u32()).unwrap(), + ] +} + +pub fn merkle_tweak(level: usize, pos_in_level: u32) -> [F; TWEAK_LEN_FE] { + let [t0, t1] = LeanSigTH::tree_tweak(level as u8, pos_in_level).to_field_elements(); + [ + F::from_canonical_checked(t0.as_canonical_u32()).unwrap(), + F::from_canonical_checked(t1.as_canonical_u32()).unwrap(), + ] +} + +pub fn xmss_merkle_path(sig: &XmssSignature) -> &Vec<[F; HASH_LEN_FE]> { + unsafe { std::mem::transmute(sig.path()) } +} + +pub fn xmss_randomness(sig: &XmssSignature) -> &[F; RAND_LEN_FE] { + unsafe { std::mem::transmute(sig.rho()) } +} + +pub fn xmmss_revealed_chain_tips(sig: &XmssSignature) -> &Vec<[F; HASH_LEN_FE]> { + unsafe { std::mem::transmute(sig.hashes()) } +} + +#[allow(clippy::result_unit_err)] +pub fn xmss_public_key_from_ssz(bytes: &[u8]) -> Result { + use ssz::Decode; + XmssPublicKey::from_ssz_bytes(bytes).map_err(|_| ()) +} + +pub fn xmss_public_key_to_ssz(pk: &XmssPublicKey) -> Vec { + use ssz::Encode; + pk.as_ssz_bytes() +} + +#[allow(clippy::result_unit_err)] +pub fn xmss_signature_from_ssz(bytes: &[u8]) -> Result { + use ssz::Decode; + XmssSignature::from_ssz_bytes(bytes).map_err(|_| ()) +} + +pub fn xmss_signature_to_ssz(sig: &XmssSignature) -> Vec { + use ssz::Encode; + sig.as_ssz_bytes() +} + +#[allow(clippy::result_unit_err)] +pub fn xmss_verify( + pk: &XmssPublicKey, + slot: u32, + message: &[u8; MESSAGE_LENGTH], + sig: 
&XmssSignature, +) -> Result<(), ()> { + if LeanSigScheme::verify(pk, slot, message, sig) { + Ok(()) + } else { + Err(()) + } +} + +pub fn xmss_encode_message(message: &[u8; MESSAGE_LENGTH]) -> [F; MSG_LEN_FE] { + let encoded = encode_message::(message); + array::from_fn(|i| F::from_canonical_checked(encoded[i].as_canonical_u32()).unwrap()) +} + +pub fn xmss_keygen_fast( + rng: &mut R, + activation_epoch: u32, + num_active_epochs: u32, +) -> (FastKeyGenSecretKey, XmssPublicKey) { + let (pk, sk) = FastKeyGenScheme::key_gen(rng, activation_epoch as usize, num_active_epochs as usize); + #[allow(clippy::missing_transmute_annotations)] + let pk = unsafe { std::mem::transmute(pk) }; + (sk, pk) +} + +#[allow(clippy::result_unit_err)] +pub fn xmss_sign_fast( + sk: &FastKeyGenSecretKey, + message: &[u8; MESSAGE_LENGTH], + slot: u32, +) -> Result { + unsafe { std::mem::transmute(FastKeyGenScheme::sign(sk, slot, message).map_err(|_| ())?) } +} diff --git a/crates/rec_aggregation/Cargo.toml b/crates/rec_aggregation/Cargo.toml index ac111ba3f..cc2aa4ac3 100644 --- a/crates/rec_aggregation/Cargo.toml +++ b/crates/rec_aggregation/Cargo.toml @@ -8,13 +8,12 @@ workspace = true [features] prox-gaps-conjecture = ["lean_prover/prox-gaps-conjecture"] -standard-alloc = [] +test-config = ["leansig_wrapper/test-config"] [dependencies] utils.workspace = true -xmss.workspace = true rand.workspace = true - +leansig_wrapper.workspace = true tracing.workspace = true include_dir.workspace = true sub_protocols.workspace = true @@ -25,6 +24,7 @@ backend.workspace = true postcard.workspace = true lz4_flex.workspace = true serde.workspace = true +sha3.workspace = true zk-alloc.workspace = true [target.'cfg(target_os = "macos")'.dependencies] diff --git a/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md b/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md index d1b3853b8..93c1eff4d 100644 --- a/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md +++ b/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md @@ -33,16 +33,19 @@ 
The bytecode-claim region encodes a multilinear evaluation: a point + the result - The outer `· 5` is the **extension-field degree**: each extension element is 5 base-field elements. - For `log_size = 19`: `(19 + 4 + 1) · 5 = 24 · 5 = 120` (already a multiple of 8, but otherwise we padd it with zeros). -## Type-1 component data (fixed, 4 chunks = 32 FE) +## Type-1 component data (fixed, 5 chunks = 40 FE) + +Raw payload is 33 FE (8 + 9 + 8 + 8), padded with 7 zero FE up to the next multiple of `DIGEST_LEN = 8`. | Offset | Size | Contents | | ------ | ---- | ---------------------------------- | | `136` | `8` | Hash of all aggregated public keys | -| `144` | `8` | Message | -| `152` | `8` | Merkle chunks identifying the slot | -| `160` | `8` | Tweak-table hash | +| `144` | `9` | Message (`MSG_LEN_FE`) | +| `153` | `8` | Merkle chunks identifying the slot | +| `161` | `8` | Tweak-table hash | +| `169` | `7` | Zero padding | -**Total Type-1 buffer = 168 FE = 21 chunks** (independent of `n_sigs`). +**Total Type-1 buffer = 176 FE = 22 chunks** (independent of `n_sigs`). 
## Type-2 component data (variable, `n_components` chunks) @@ -55,8 +58,8 @@ The bytecode-claim region encodes a multilinear evaluation: a point + the result ## Picture ``` -Type-1 (168 FE): -[flag=1 | n_sigs | 0×6] [bytecode claim, 120 FE] [domsep, 8 FE] [pubkeys_hash | message | merkle_chunks | tweaks_hash] +Type-1 (176 FE): +[flag=1 | n_sigs | 0×6] [bytecode claim, 120 FE] [domsep, 8 FE] [pubkeys_hash(8) | message(9) | merkle_chunks(8) | tweaks_hash(8) | pad(7)] Type-2 ((n+17)·8 FE): [flag=0 | n | 0×6] [bytecode claim, 120 FE] [domsep, 8 FE] [digest_0] [digest_1] … [digest_{n-1}] diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/aggregation.rs similarity index 58% rename from crates/rec_aggregation/src/type_1_aggregation.rs rename to crates/rec_aggregation/src/aggregation.rs index 16316f3cf..1dda3ac16 100644 --- a/crates/rec_aggregation/src/type_1_aggregation.rs +++ b/crates/rec_aggregation/src/aggregation.rs @@ -3,29 +3,29 @@ use backend::*; use lean_prover::fiat_shamir_domain_sep; use lean_prover::prove_execution::{ExecutionProof, prove_execution}; use lean_vm::*; -use tracing::instrument; -use utils::poseidon_compress_slice; -use xmss::CHAIN_LENGTH; -use xmss::make_tweak; -use xmss::{ - LOG_LIFETIME, MESSAGE_LEN_FE, PUB_KEY_FLAT_SIZE, TWEAK_TYPE_CHAIN, TWEAK_TYPE_ENCODING, TWEAK_TYPE_MERKLE, - TWEAK_TYPE_WOTS_PK, V, WOTS_SIG_SIZE_FE, XmssPublicKey, XmssSignature, +use leansig_wrapper::{ + BASE, HASH_LEN_FE, LOG_LIFETIME, MESSAGE_LENGTH, MSG_LEN_FE, PARAMETER_LEN, SIG_SIZE_FE, V, XmssPublicKey, + XmssSignature, chain_tweak, merkle_tweak, pubkey_merkle_root, pubkey_public_parameter, xmmss_revealed_chain_tips, + xmss_encode_message, xmss_merkle_path, xmss_randomness, }; +use tracing::instrument; +use utils::{poseidon_compress_slice, poseidon_compress_slice_zero_iv}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap, HashSet}; use crate::InnerVerified; -use 
crate::bytecode_claims::compute_bytecode_value_at; +use crate::bytecode_claims::evaluation_for_bytecode_point; use crate::bytecode_claims::flatten_bytecode_claim; use crate::bytecode_claims::reduce_bytecode_claims; use crate::compilation::{ BYTECODE_CLAIM_OFFSET, MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, N_MERKLE_CHUNKS_FOR_SLOT, - PREAMBLE_MEMORY_LEN, TYPE1_FLAG, get_aggregation_bytecode, try_get_aggregation_bytecode, - type1_input_data_size_padded, + PREAMBLE_MEMORY_LEN, TYPE1_FLAG, get_aggregation_bytecode, type1_input_data_size_padded, }; -use crate::decompress_size_prepended_bounded; -use crate::verify_inner; +use crate::{lz4_postcard_decode, lz4_postcard_encode, verify_inner}; + +const CHAIN_LENGTH: usize = BASE; +const PUB_KEY_FLAT_SIZE: usize = HASH_LEN_FE + PARAMETER_LEN; /// Number of tweaks in the table: 1 encoding + V*CHAIN_LENGTH chains + 1 wots_pk + LOG_LIFETIME merkle pub(crate) const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; @@ -37,47 +37,38 @@ pub(crate) const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE pub(crate) struct Digest(pub [F; DIGEST_LEN]); #[derive(Debug, Clone, PartialEq, Eq)] -pub struct TypeOneInfo { - pub message: [F; MESSAGE_LEN_FE], - pub slot: u32, +pub struct AggregatedXMSSInfo { + pub without_pubkeys: AggregatedXMSSInfoWithoutPubkeys, pub pubkeys: Vec, - pub bytecode_claim: Evaluation, // value is trusted to be correct (should be recomputed when receiving a proof from an untrusted source) } // Aggregation of many signatures, all sharing the same (message, slot) #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TypeOneMultiSignature { - pub info: TypeOneInfo, +pub struct AggregatedXMSS { + pub info: AggregatedXMSSInfo, pub proof: ExecutionProof, } -impl Serialize for TypeOneInfo { +impl Serialize for AggregatedXMSSInfo { fn serialize(&self, s: S) -> Result { - (&self.message, &self.slot, &self.pubkeys, &self.bytecode_claim.point).serialize(s) + (&self.without_pubkeys, 
&self.pubkeys).serialize(s) } } -impl<'de> Deserialize<'de> for TypeOneInfo { +impl<'de> Deserialize<'de> for AggregatedXMSSInfo { fn deserialize>(d: D) -> Result { - let (message, slot, pubkeys, bytecode_claim_point) = - <([F; MESSAGE_LEN_FE], u32, Vec, MultilinearPoint)>::deserialize(d)?; - let bytecode = - try_get_aggregation_bytecode().ok_or_else(|| serde::de::Error::custom("bytecode not initialized"))?; - if bytecode_claim_point.len() != bytecode.cumulated_n_vars() { - return Err(serde::de::Error::custom("invalid bytecode point")); + let (without_pubkeys, pubkeys) = <(AggregatedXMSSInfoWithoutPubkeys, Vec)>::deserialize(d)?; + if !pubkeys.is_sorted() { + return Err(serde::de::Error::custom("unsorted pubkeys")); } - check_type_one_pubkeys(&pubkeys).map_err(serde::de::Error::custom)?; - let bytecode_value = compute_bytecode_value_at(&bytecode_claim_point); Ok(Self { - message, - slot, + without_pubkeys, pubkeys, - bytecode_claim: Evaluation::new(bytecode_claim_point, bytecode_value), }) } } -pub(crate) fn check_type_one_pubkeys(pubkeys: &[XmssPublicKey]) -> Result<(), &'static str> { +pub(crate) fn check_aggregation_pubkeys(pubkeys: &[XmssPublicKey]) -> Result<(), &'static str> { if pubkeys.is_empty() { return Err("pubkeys must be non-empty"); } @@ -90,16 +81,24 @@ pub(crate) fn check_type_one_pubkeys(pubkeys: &[XmssPublicKey]) -> Result<(), &' Ok(()) } -impl TypeOneMultiSignature { +impl AggregatedXMSS { pub fn compress(&self) -> Vec { - let encoded = postcard::to_allocvec(self).expect("postcard serialization failed"); - lz4_flex::compress_prepend_size(&encoded) + lz4_postcard_encode(self) } pub fn decompress(bytes: &[u8]) -> Option { - let decompressed = decompress_size_prepended_bounded(bytes)?; - let (value, rest) = postcard::take_from_bytes::(&decompressed).ok()?; - rest.is_empty().then_some(value) + lz4_postcard_decode(bytes) + } + + pub fn compress_without_pubkeys(&self) -> Vec { + lz4_postcard_encode(&(&self.info.without_pubkeys, &self.proof)) + } + + 
pub fn decompress_without_pubkeys(bytes: &[u8], pubkeys: Vec) -> Option { + let (without_pubkeys, proof) = + lz4_postcard_decode::<(AggregatedXMSSInfoWithoutPubkeys, ExecutionProof)>(bytes)?; + let info = without_pubkeys.with_pubkeys(pubkeys)?; + Some(Self { info, proof }) } pub(crate) fn bytecode_claim_flat(&self) -> Vec { @@ -107,19 +106,65 @@ impl TypeOneMultiSignature { } } -impl TypeOneInfo { +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct AggregatedXMSSInfoWithoutPubkeys { + pub message: [u8; MESSAGE_LENGTH], + pub slot: u32, + pub bytecode_claim: Evaluation, +} + +impl Serialize for AggregatedXMSSInfoWithoutPubkeys { + fn serialize(&self, s: S) -> Result { + (&self.message, &self.slot, &self.bytecode_claim.point).serialize(s) + } +} + +impl<'de> Deserialize<'de> for AggregatedXMSSInfoWithoutPubkeys { + fn deserialize>(d: D) -> Result { + let (message, slot, bytecode_claim_point) = + <([u8; MESSAGE_LENGTH], u32, MultilinearPoint)>::deserialize(d)?; + let bytecode_claim = evaluation_for_bytecode_point(bytecode_claim_point) + .ok_or_else(|| serde::de::Error::custom("invalid bytecode point"))?; + Ok(Self { + message, + slot, + bytecode_claim, + }) + } +} + +impl AggregatedXMSSInfoWithoutPubkeys { + pub(crate) fn with_pubkeys(self, mut pubkeys: Vec) -> Option { + pubkeys.sort(); + pubkeys.dedup(); + Some(AggregatedXMSSInfo { + without_pubkeys: self, + pubkeys, + }) + } +} + +impl AggregatedXMSSInfo { pub(crate) fn bytecode_claim_flat(&self) -> Vec { - flatten_bytecode_claim(&self.bytecode_claim) + flatten_bytecode_claim(&self.without_pubkeys.bytecode_claim) + } + + pub fn compress_without_pubkeys(&self) -> Vec { + lz4_postcard_encode(&self.without_pubkeys) + } + + pub fn decompress_without_pubkeys(bytes: &[u8], pubkeys: Vec) -> Option { + lz4_postcard_decode::(bytes)?.with_pubkeys(pubkeys) } pub(crate) fn build_input_data(&self) -> Vec { - let tweak_table = compute_tweak_table(self.slot); + let tweak_table = compute_tweak_table(self.without_pubkeys.slot); let 
tweaks_hash = poseidon_compress_slice(&tweak_table); - build_type1_input_data( + build_aggregation_input_data( self.pubkeys.len(), &hash_pubkeys(&self.pubkeys), - &self.message, - self.slot, + &xmss_encode_message(&self.without_pubkeys.message), + self.without_pubkeys.slot, &tweaks_hash, &self.bytecode_claim_flat(), get_aggregation_bytecode(), @@ -127,37 +172,53 @@ impl TypeOneInfo { } } +fn pub_key_flat(pk: &XmssPublicKey) -> Vec { + let mut data = Vec::with_capacity(PUB_KEY_FLAT_SIZE); + data.extend_from_slice(&pubkey_merkle_root(pk)); + data.extend_from_slice(&pubkey_public_parameter(pk)); + data +} + pub(crate) fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] { - let flat: Vec = pub_keys.iter().flat_map(|pk| pk.flaten().into_iter()).collect(); - poseidon_compress_slice(&flat) + // leansig pubkeys are PUB_KEY_FLAT_SIZE=13 FE each (not 8-aligned) -> zero-IV remainder hash, + // matching the zkDSL `slice_hash_with_iv_dynamic_unroll`. + let flat: Vec = pub_keys.iter().flat_map(pub_key_flat).collect(); + poseidon_compress_slice_zero_iv(&flat) } /// Tweak slots are 4-FE [tw[0], tw[1], 0, 0] fn compute_tweak_table(slot: u32) -> Vec { let mut table = Vec::new(); - let push_padded = |table: &mut Vec, tweak_type: usize, sub_position: usize, index: u32| { - table.extend(make_tweak(tweak_type, sub_position, index)); + let push_padded = |table: &mut Vec, tweak: [F; 2]| { + table.extend(tweak); table.extend(std::iter::repeat_n(F::ZERO, 2)); }; - // Encoding tweak - push_padded(&mut table, TWEAK_TYPE_ENCODING, 0, slot); + // Encoding tweak: encode_epoch(slot) = ((slot << 8) | TWEAK_SEPARATOR_MSG) in base-p + let acc = ((slot as u64) << 8) | 0x02u64; + let encoding_tweak = [F::from_u64(acc % F::ORDER_U64), F::from_u64(acc / F::ORDER_U64)]; + push_padded(&mut table, encoding_tweak); // Chain tweaks for i in 0..V { for s in 0..CHAIN_LENGTH { - push_padded(&mut table, TWEAK_TYPE_CHAIN, i * CHAIN_LENGTH + s, slot); + push_padded(&mut table, chain_tweak(slot, i as 
u32, s as u32)); } } - // WOTS_PK tweak - push_padded(&mut table, TWEAK_TYPE_WOTS_PK, 0, slot); + // Leaf tweak: tree_tweak(0, slot) for hashing chain ends into a leaf node + push_padded(&mut table, merkle_tweak(0, slot)); // Merkle tweaks for level in 0..LOG_LIFETIME { - let parent_index = ((slot as u64) >> (level + 1)) as u32; - push_padded(&mut table, TWEAK_TYPE_MERKLE, level + 1, parent_index); + let parent_level = level + 1; + let parent_index = if parent_level < 32 { + ((slot as u64) >> parent_level) as u32 + } else { + 0 + }; + push_padded(&mut table, merkle_tweak(parent_level, parent_index)); } table.resize(TWEAK_TABLE_SIZE_FE_PADDED, F::ZERO); table @@ -173,10 +234,10 @@ fn compute_merkle_chunks_for_slot(slot: u32) -> Vec { } /// Layout: [prefix(8) | bytecode_claim_padded | initial_fiat_shamir_cap(8) | pubkeys_hash | message | merkle_chunks | tweaks_hash]. -pub(crate) fn build_type1_input_data( +pub(crate) fn build_aggregation_input_data( n_sigs: usize, pubkeys_hash: &[F; DIGEST_LEN], - message: &[F; MESSAGE_LEN_FE], + message: &[F; MSG_LEN_FE], slot: u32, tweaks_hash: &[F; DIGEST_LEN], bytecode_claim_flat: &[F], @@ -195,44 +256,100 @@ pub(crate) fn build_type1_input_data( data.extend_from_slice(message); data.extend(compute_merkle_chunks_for_slot(slot)); data.extend_from_slice(tweaks_hash); + // Pad up to a multiple of DIGEST_LEN so slice_hash_with_iv consumes the whole buffer. 
+ data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); data } -fn encode_wots_signature(sig: &XmssSignature) -> Vec { +fn encode_xmss_signature(sig: &XmssSignature) -> Vec { let mut data = vec![]; - data.extend(sig.wots_signature.randomness.to_vec()); - data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec())); - assert_eq!(data.len(), WOTS_SIG_SIZE_FE); + data.extend_from_slice(xmss_randomness(sig)); + data.extend( + xmmss_revealed_chain_tips(sig) + .iter() + .flat_map(|digest| digest.iter().copied()), + ); + data.extend(xmss_merkle_path(sig).iter().flat_map(|digest| digest.iter().copied())); + assert_eq!(data.len(), SIG_SIZE_FE); data } -// assumes `bytecode_value` in TypeOneMultiSignature::proof is correct (it should not be read / deserialized from an untrusted source) -pub fn verify_type_1(sig: &TypeOneMultiSignature) -> Result { - check_type_one_pubkeys(&sig.info.pubkeys).map_err(|_| ProofError::InvalidProof)?; +// assumes `bytecode_value` in AggregatedXMSS::proof is correct (it should not be read / deserialized from an untrusted source) +pub(crate) fn verify_aggregation(sig: &AggregatedXMSS) -> Result { + check_aggregation_pubkeys(&sig.info.pubkeys).map_err(|_| ProofError::InvalidProof)?; verify_inner(sig.info.build_input_data(), sig.proof.proof.clone()) } +/// Verify an aggregated multi-signature against the expected `pub_keys`, `message` and `slot`. +pub fn xmss_verify_aggregation( + pub_keys: Vec, + agg_sig: &AggregatedXMSS, + message: &[u8; MESSAGE_LENGTH], + slot: u32, +) -> Result { + let mut pub_keys = pub_keys; + pub_keys.sort(); + pub_keys.dedup(); + if pub_keys != agg_sig.info.pubkeys + || agg_sig.info.without_pubkeys.message != *message + || agg_sig.info.without_pubkeys.slot != slot + { + return Err(ProofError::InvalidProof); + } + verify_aggregation(agg_sig) +} + /// Aggregate raw XMSS signatures and previously aggregated multi-signatures. /// Type 1 = single message, single slot. 
#[instrument(skip_all)] -pub fn aggregate_type_1( - children: &[TypeOneMultiSignature], +pub(crate) fn aggregate( + children: &[AggregatedXMSS], raw_xmss: Vec<(XmssPublicKey, XmssSignature)>, - message: [F; MESSAGE_LEN_FE], + message: [u8; MESSAGE_LENGTH], slot: u32, log_inv_rate: usize, -) -> Result { - aggregate_type_1_with_min_padding(children, raw_xmss, message, slot, log_inv_rate, BTreeMap::new()) +) -> Result { + xmss_aggregate_with_min_padding(children, raw_xmss, message, slot, log_inv_rate, BTreeMap::new()) } -pub(crate) fn aggregate_type_1_with_min_padding( - children: &[TypeOneMultiSignature], +/// Aggregate raw XMSS signatures and previously aggregated multi-signatures, all sharing the +/// given `message`/`slot`. Each child is `(pub_keys, signature)`; `pub_keys` must match those +/// bound inside the signature. +/// +/// Returns the sorted, deduplicated union of all signers alongside the aggregated signature. +pub fn xmss_aggregate( + children: &[(&[XmssPublicKey], AggregatedXMSS)], + raw_xmss: Vec<(XmssPublicKey, XmssSignature)>, + message: &[u8; MESSAGE_LENGTH], + slot: u32, + log_inv_rate: usize, +) -> Result<(Vec, AggregatedXMSS), AggregationError> { + let mut child_sigs = Vec::with_capacity(children.len()); + for (pub_keys, agg) in children { + let mut pub_keys = pub_keys.to_vec(); + pub_keys.sort(); + pub_keys.dedup(); + if pub_keys != agg.info.pubkeys + || agg.info.without_pubkeys.message != *message + || agg.info.without_pubkeys.slot != slot + { + return Err(AggregationError::InvalidChildProof(ProofError::InvalidProof)); + } + child_sigs.push(agg.clone()); + } + let aggregated = aggregate(&child_sigs, raw_xmss, *message, slot, log_inv_rate)?; + let pub_keys = aggregated.info.pubkeys.clone(); + Ok((pub_keys, aggregated)) +} + +pub(crate) fn xmss_aggregate_with_min_padding( + children: &[AggregatedXMSS], mut raw_xmss: Vec<(XmssPublicKey, XmssSignature)>, - message: [F; MESSAGE_LEN_FE], + message: [u8; MESSAGE_LENGTH], slot: u32, log_inv_rate: 
usize, min_table_log_n_rows: BTreeMap, -) -> Result { +) -> Result { if children.len() > MAX_RECURSIONS { return Err(AggregationError::LimitExceeded { what: "aggregation children", @@ -241,19 +358,19 @@ pub(crate) fn aggregate_type_1_with_min_padding( }); } for child in children { - if child.info.message != message { + if child.info.without_pubkeys.message != message { return Err(AggregationError::InconsistentChildren { what: "all children of a type-1 aggregation must share the same message", }); } - if child.info.slot != slot { + if child.info.without_pubkeys.slot != slot { return Err(AggregationError::InconsistentChildren { what: "all children of a type-1 aggregation must share the same slot", }); } } let message = &message; - let verified_children: Vec = children.iter().map(verify_type_1).collect::>()?; + let verified_children: Vec = children.iter().map(verify_aggregation).collect::>()?; let children: Vec<&[XmssPublicKey]> = children.iter().map(|c| c.info.pubkeys.as_slice()).collect(); let children = children.as_slice(); @@ -293,10 +410,10 @@ pub(crate) fn aggregate_type_1_with_min_padding( let reduced_claims = reduce_bytecode_claims(&verified_children); - let pub_input_data = build_type1_input_data( + let pub_input_data = build_aggregation_input_data( n_sigs, &hash_pubkeys(&global_pub_keys), - message, + &xmss_encode_message(message), slot, &tweaks_hash, &reduced_claims.final_claim_flat(), @@ -307,11 +424,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( let mut claimed: HashSet = HashSet::new(); let mut dup_pub_keys: Vec = Vec::new(); - let wots_blobs: Vec> = raw_xmss.iter().map(|(_, sig)| encode_wots_signature(sig)).collect(); - let xmss_merkle_node_blobs: Vec> = raw_xmss - .iter() - .flat_map(|(_, sig)| sig.merkle_proof.iter().map(|d| d.to_vec())) - .collect(); + let xmss_signature_blobs: Vec> = raw_xmss.iter().map(|(_, sig)| encode_xmss_signature(sig)).collect(); let raw_indices: Vec = raw_xmss .iter() @@ -326,7 +439,7 @@ pub(crate) fn 
aggregate_type_1_with_min_padding( let mut bytecode_value_hint_blobs = Vec::with_capacity(n_recursions); let mut inner_bytecode_claim_blobs = Vec::with_capacity(n_recursions); let mut proof_transcript_blobs = Vec::with_capacity(n_recursions); - let mut table_sort_perm_blobs = Vec::with_capacity(n_recursions); + let mut table_sort_perm_blobs: Vec> = Vec::with_capacity(n_recursions); let claim_size_padded = bytecode_claim_size.next_multiple_of(DIGEST_LEN); @@ -363,10 +476,10 @@ pub(crate) fn aggregate_type_1_with_min_padding( let mut pubkeys_blob: Vec = Vec::with_capacity((n_sigs + n_dup) * PUB_KEY_FLAT_SIZE); for pk in &global_pub_keys { - pubkeys_blob.extend_from_slice(&pk.flaten()); + pubkeys_blob.extend(pub_key_flat(pk)); } for pk in &dup_pub_keys { - pubkeys_blob.extend_from_slice(&pk.flaten()); + pubkeys_blob.extend(pub_key_flat(pk)); } let (merkle_leaf_blobs, merkle_path_blobs) = @@ -394,8 +507,6 @@ pub(crate) fn aggregate_type_1_with_min_padding( let fast_path = n_recursions == 1 && raw_count == 0 && dup_pub_keys.is_empty(); let sub_indices_for_hints = if fast_path { Vec::new() } else { sub_indices_blobs }; hints.insert("sub_indices".to_string(), sub_indices_for_hints); - // Standard type-1 (not a split). 
- hints.insert("is_split".to_string(), vec![vec![F::ZERO]]); hints.insert("bytecode_value_hint".to_string(), bytecode_value_hint_blobs); hints.insert("inner_bytecode_claim".to_string(), inner_bytecode_claim_blobs); hints.insert( @@ -407,8 +518,7 @@ pub(crate) fn aggregate_type_1_with_min_padding( ); hints.insert("proof_transcript".to_string(), proof_transcript_blobs); hints.insert("table_sort_perm".to_string(), table_sort_perm_blobs); - hints.insert("wots".to_string(), wots_blobs); - hints.insert("xmss_merkle_node".to_string(), xmss_merkle_node_blobs); + hints.insert("xmss_signature".to_string(), xmss_signature_blobs); hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs); hints.insert("merkle_path".to_string(), merkle_path_blobs); hints.insert("aggregate_sizes".to_string(), vec![aggregate_sizes]); @@ -427,12 +537,14 @@ pub(crate) fn aggregate_type_1_with_min_padding( }; let proof = prove_execution(bytecode, &public_input, &witness, &whir_config, false)?; - Ok(TypeOneMultiSignature { - info: TypeOneInfo { - message: *message, - slot, + Ok(AggregatedXMSS { + info: AggregatedXMSSInfo { + without_pubkeys: AggregatedXMSSInfoWithoutPubkeys { + message: *message, + slot, + bytecode_claim: reduced_claims.final_claim, + }, pubkeys: global_pub_keys, - bytecode_claim: reduced_claims.final_claim, }, proof, }) @@ -457,7 +569,7 @@ pub(crate) fn extract_merkle_hint_blobs<'a>( mod tests { use super::*; use crate::compilation::init_aggregation_bytecode; - use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; + use crate::signatures_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; /// Exercises the recursive-aggregation path when the inner proof has the /// extension-op table bigger than the execution table. 
@@ -476,14 +588,13 @@ mod tests { let mut min_padding: BTreeMap = BTreeMap::new(); min_padding.insert(Table::extension_op(), extension_padding_log); - let inner = - aggregate_type_1_with_min_padding(&[], raws_inner, message, slot, log_inv_rate, min_padding).unwrap(); - verify_type_1(&inner).unwrap(); + let inner = xmss_aggregate_with_min_padding(&[], raws_inner, message, slot, log_inv_rate, min_padding).unwrap(); + verify_aggregation(&inner).unwrap(); let inner_metadata = inner.proof.metadata.as_ref().expect("inner metadata available"); assert!(dbg!(inner_metadata.cycles) < 1usize << extension_padding_log,); - let outer = aggregate_type_1(&[inner], raws_outer, message, slot, log_inv_rate).unwrap(); - verify_type_1(&outer).unwrap(); + let outer = aggregate(&[inner], raws_outer, message, slot, log_inv_rate).unwrap(); + verify_aggregation(&outer).unwrap(); } } diff --git a/crates/rec_aggregation/src/benchmark.rs b/crates/rec_aggregation/src/benchmark.rs index 920d9f4c3..c0922287b 100644 --- a/crates/rec_aggregation/src/benchmark.rs +++ b/crates/rec_aggregation/src/benchmark.rs @@ -1,14 +1,14 @@ use backend::*; use lean_vm::*; +use leansig_wrapper::{XmssPublicKey, XmssSignature}; use serde::{Deserialize, Serialize}; use std::io::{self, Write}; use std::time::Instant; use utils::ansi as s; -use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; -use xmss::{XmssPublicKey, XmssSignature}; +use crate::aggregation::{AggregatedXMSS, aggregate, verify_aggregation}; use crate::compilation::{get_aggregation_bytecode, init_aggregation_bytecode}; -use crate::type_1_aggregation::{TypeOneMultiSignature, aggregate_type_1, verify_type_1}; +use crate::signatures_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; #[derive(Debug, Clone)] pub struct AggregationTopology { @@ -47,6 +47,15 @@ fn count_nodes(topology: &AggregationTopology) -> usize { 1 + topology.children.iter().map(count_nodes).sum::() } +#[derive(Default, Debug, 
Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum NodeKind { + #[default] + AggregateType1, + MergeManyType1, + SplitType2, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct NodeStats { pub time_secs: f64, @@ -61,6 +70,8 @@ pub struct NodeStats { pub poseidons: usize, pub dots: usize, pub n_xmss: Option, + #[serde(default)] + pub kind: NodeKind, } fn default_samples() -> usize { @@ -256,6 +267,21 @@ impl LiveTree { } } +fn print_stage(silent: bool, label: &str, stats: &NodeStats) { + if silent { + return; + } + let xmss_tag = stats.n_xmss.map(|n| format!(" n_xmss={}", n)).unwrap_or_default(); + println!( + "{:30} {:>8.3}s {:>5} KiB cycles={:>10}{}", + label, + stats.time_secs, + stats.proof_kib, + pretty_integer(stats.cycles), + xmss_tag, + ); +} + #[allow(clippy::too_many_arguments)] fn build_tree_descs( topology: &AggregationTopology, @@ -351,13 +377,13 @@ fn build_aggregation( tracing: bool, is_root: bool, repeat: usize, -) -> TypeOneMultiSignature { +) -> AggregatedXMSS { let raw_count = topology.raw_xmss; let raw_xmss: Vec<(XmssPublicKey, XmssSignature)> = (0..raw_count) .map(|i| (pub_keys[i].clone(), signatures[i].clone())) .collect(); - let mut children: Vec = vec![]; + let mut children: Vec = vec![]; let mut child_start = raw_count; let mut child_display_index = display_index; for (child_idx, child) in topology.children.iter().enumerate() { @@ -390,16 +416,12 @@ fn build_aggregation( assert!(repeat > 0); let is_leaf = topology.children.is_empty(); - let n_xmss_opt = is_leaf.then_some(topology.raw_xmss); let mut times = Vec::with_capacity(repeat); - let mut last_result: Option = None; + let mut last_result: Option = None; let own_display_index = display_index + count_nodes(topology) - 1; for _ in 0..repeat { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::begin_phase(); - let time = Instant::now(); - let result = aggregate_type_1( + let result = aggregate( &children, raw_xmss.clone(), 
message_for_benchmark(), @@ -409,13 +431,6 @@ fn build_aggregation( .unwrap(); let elapsed = time.elapsed(); - // Clone the outputs out of the arena before the next phase resets its slabs. - #[cfg(not(feature = "standard-alloc"))] - let result = { - zk_alloc::end_phase(); - result.clone() - }; - times.push(elapsed.as_secs_f64()); last_result = Some(result); @@ -435,7 +450,8 @@ fn build_aggregation( memory: meta.memory, poseidons: meta.n_poseidons, dots: meta.n_extension_ops, - n_xmss: n_xmss_opt, + n_xmss: if is_leaf { Some(topology.raw_xmss) } else { None }, + kind: NodeKind::AggregateType1, }, ); } @@ -475,7 +491,8 @@ fn build_aggregation( memory: meta.memory, poseidons: meta.n_poseidons, dots: meta.n_extension_ops, - n_xmss: n_xmss_opt, + n_xmss: if is_leaf { Some(topology.raw_xmss) } else { None }, + kind: NodeKind::AggregateType1, }; if !tracing { live_tree.update_node(own_display_index, &stats); @@ -541,11 +558,65 @@ pub fn run_aggregation_benchmark( repeat, ); - verify_type_1(&aggregated).expect("root type-1 proof failed to verify"); + verify_aggregation(&aggregated).expect("root type-1 proof failed to verify"); BenchmarkReport { nodes } } +#[allow(clippy::too_many_arguments)] +fn run_xmss_aggregate( + children: &[AggregatedXMSS], + raw_xmss: Vec<(XmssPublicKey, XmssSignature)>, + log_inv_rate: usize, + n_xmss: usize, + path: Vec, + label: &str, + silent: bool, + nodes: &mut Vec, +) -> AggregatedXMSS { + let time = Instant::now(); + + zk_alloc::begin_phase(); + + let result = aggregate( + children, + raw_xmss, + message_for_benchmark(), + BENCHMARK_SLOT, + log_inv_rate, + ) + .unwrap(); + + let result = { + zk_alloc::end_phase(); + result.clone() + }; + + let elapsed = time.elapsed(); + let meta = result.proof.metadata.as_ref().unwrap(); + let proof_kib = result.proof.proof.proof_size_fe() * F::bits() / (8 * 1024); + let stats = NodeStats { + time_secs: elapsed.as_secs_f64(), + time_ci_secs: 0.0, + samples: 1, + proof_kib, + cycles: meta.cycles, + memory: 
meta.memory, + poseidons: meta.n_poseidons, + dots: meta.n_extension_ops, + n_xmss: Some(n_xmss), + kind: NodeKind::AggregateType1, + }; + + print_stage(silent, label, &stats); + nodes.push(NodeReport { + path: path.clone(), + stats, + }); + + result +} + // TODO is there a better fix? #[cfg(target_os = "macos")] mod macos_activity { diff --git a/crates/rec_aggregation/src/bytecode_claims.rs b/crates/rec_aggregation/src/bytecode_claims.rs index b948dc0e8..eb50b670f 100644 --- a/crates/rec_aggregation/src/bytecode_claims.rs +++ b/crates/rec_aggregation/src/bytecode_claims.rs @@ -33,6 +33,14 @@ pub(crate) fn compute_bytecode_value_at(point: &MultilinearPoint) -> EF { } } +pub(crate) fn evaluation_for_bytecode_point(point: MultilinearPoint) -> Option> { + if point.len() != get_aggregation_bytecode().cumulated_n_vars() { + return None; + } + let value = compute_bytecode_value_at(&point); + Some(Evaluation::new(point, value)) +} + pub(crate) fn reduce_bytecode_claims(verified: &[InnerVerified]) -> ReducedBytecodeClaims { let bytecode = get_aggregation_bytecode(); @@ -64,12 +72,14 @@ pub(crate) fn reduce_bytecode_claims(verified: &[InnerVerified]) -> ReducedBytec let n_claims = claims.len(); let alpha_powers: Vec = alpha.powers().take(n_claims).collect(); - let weights_packed = claims - .par_iter() - .zip(&alpha_powers) - .map(|(eval, &alpha_i)| eval_eq_packed_scaled(&eval.point.0, alpha_i)) - .reduce_with(|mut acc, eq_i| { - acc.par_iter_mut().zip(&eq_i).for_each(|(w, e)| *w += *e); + // Sequential outer fold: `n_claims` is small and `eval_eq_packed_scaled` is itself parallel, + // so parallelizing here would nest parallel dispatch (forbidden by the `parallel` pool). 
+ let weights_packed = (0..n_claims) + .map(|i| eval_eq_packed_scaled(&claims[i].point.0, alpha_powers[i])) + .reduce(|mut acc, eq_i| { + for (w, e) in acc.iter_mut().zip(eq_i.iter()) { + *w += *e; + } acc }) .unwrap(); diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 830ad583a..94cd3dc01 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -5,15 +5,17 @@ use lean_prover::{ WHIR_SUBSEQUENT_FOLDING_FACTOR, default_whir_config, }; use lean_vm::*; +use leansig_wrapper::{ + LOG_LIFETIME, MSG_LEN_FE, PARAMETER_LEN, RAND_LEN_FE, TARGET_SUM, TWEAK_LEN_FE, V, W, WOTS_PUBKET_SPONGE_DOMAIN_SEP, +}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::OnceLock; use sub_protocols::{N_VARS_TO_SEND_GKR_COEFFS, min_stacked_n_vars, total_whir_statements}; use tracing::instrument; use utils::Counter; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, XMSS_DIGEST_LEN}; +use crate::aggregation::TWEAK_TABLE_SIZE_FE_PADDED; use crate::bytecode_claims::bytecode_reduction_sumcheck_proof_size; -use crate::type_1_aggregation::TWEAK_TABLE_SIZE_FE_PADDED; // preamble memory layout: see `build_preamble_memory` in utils.py: // [000.. (ZERO_VEC_LEN)][10000000 (fiat-shamir domain sep)][10000 (one in extension field)][111... 
(NUM_REPEATED_ONES)][tweak table] @@ -27,20 +29,10 @@ pub(crate) const N_MERKLE_CHUNKS_FOR_SLOT: usize = LOG_LIFETIME / MERKLE_LEVELS_ static BYTECODE: OnceLock = OnceLock::new(); -pub fn get_aggregation_bytecode() -> &'static Bytecode { - BYTECODE - .get() - .unwrap_or_else(|| panic!("call init_aggregation_bytecode() first")) -} - pub fn try_get_aggregation_bytecode() -> Option<&'static Bytecode> { BYTECODE.get() } -pub fn init_aggregation_bytecode() { - BYTECODE.get_or_init(compile_main_program_self_referential); -} - static EMBEDDED_ZK_DSL: include_dir::Dir<'_> = include_dir::include_dir!("$CARGO_MANIFEST_DIR/zkdsl_implem"); pub const MAX_RECURSIONS: usize = 16; @@ -48,11 +40,11 @@ pub const MAX_XMSS_AGGREGATED: usize = 1 << 15; // TODO increase (we would need pub const MAX_XMSS_DUPLICATES: usize = 1 << 15; // ...same pub(crate) const TYPE1_FLAG: usize = 1; -pub(crate) const TYPE2_FLAG: usize = 0; pub(crate) const BYTECODE_CLAIM_OFFSET: usize = DIGEST_LEN; -/// Type-1's component data: pubkeys_hash | message | merkle_chunks | tweaks_hash. -pub(crate) const COMPONENT_DATA_SIZE: usize = DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; +/// Type-1's component data: pubkeys_hash | message | merkle_chunks | tweaks_hash, padded to DIGEST_LEN. 
+pub(crate) const COMPONENT_DATA_SIZE: usize = + (DIGEST_LEN + MSG_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN).next_multiple_of(DIGEST_LEN); pub(crate) fn bytecode_claim_size_padded(program_log_size: usize) -> usize { let bytecode_point_n_vars = program_log_size + log2_ceil_usize(N_INSTRUCTION_COLUMNS); @@ -71,6 +63,16 @@ pub(crate) fn type1_input_data_size_padded(program_log_size: usize) -> usize { component_data_offset(program_log_size) + COMPONENT_DATA_SIZE } +pub fn get_aggregation_bytecode() -> &'static Bytecode { + BYTECODE + .get() + .unwrap_or_else(|| panic!("call init_aggregation_bytecode() first")) +} + +pub fn init_aggregation_bytecode() { + BYTECODE.get_or_init(compile_main_program_self_referential); +} + fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode { let replacements = build_replacements(program_log_size, bytecode_zero_eval); @@ -426,22 +428,29 @@ fn build_replacements(log_inner_bytecode: usize, bytecode_zero_eval: F) -> BTree // XMSS-specific replacements replacements.insert("V_PLACEHOLDER".to_string(), V.to_string()); replacements.insert("W_PLACEHOLDER".to_string(), W.to_string()); + replacements.insert("PUBLIC_PARAM_LEN_PLACEHOLDER".to_string(), PARAMETER_LEN.to_string()); + replacements.insert("TWEAK_LEN_PLACEHOLDER".to_string(), TWEAK_LEN_FE.to_string()); replacements.insert("TARGET_SUM_PLACEHOLDER".to_string(), TARGET_SUM.to_string()); replacements.insert("LOG_LIFETIME_PLACEHOLDER".to_string(), LOG_LIFETIME.to_string()); - replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MESSAGE_LEN_FE.to_string()); - replacements.insert("RANDOMNESS_LEN_PLACEHOLDER".to_string(), RANDOMNESS_LEN_FE.to_string()); - replacements.insert( - "PUBLIC_PARAM_LEN_FE_PLACEHOLDER".to_string(), - PUBLIC_PARAM_LEN_FE.to_string(), - ); + replacements.insert("MESSAGE_LEN_PLACEHOLDER".to_string(), MSG_LEN_FE.to_string()); + replacements.insert("RANDOMNESS_LEN_PLACEHOLDER".to_string(), RAND_LEN_FE.to_string()); 
replacements.insert( "MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER".to_string(), MERKLE_LEVELS_PER_CHUNK_FOR_SLOT.to_string(), ); - replacements.insert("XMSS_DIGEST_LEN_PLACEHOLDER".to_string(), XMSS_DIGEST_LEN.to_string()); + replacements.insert( + "WOTS_PUBKET_SPONGE_DOMAIN_SEP_PLACEHOLDER".to_string(), + format!( + "[{}]", + WOTS_PUBKET_SPONGE_DOMAIN_SEP + .iter() + .map(|v| v.to_string()) + .collect::>() + .join(", ") + ), + ); replacements.insert("TYPE_1_FLAG_PLACEHOLDER".to_string(), TYPE1_FLAG.to_string()); - replacements.insert("TYPE_2_FLAG_PLACEHOLDER".to_string(), TYPE2_FLAG.to_string()); replacements.insert( "MAX_XMSS_AGGREGATED_PLACEHOLDER".to_string(), MAX_XMSS_AGGREGATED.to_string(), @@ -471,6 +480,7 @@ fn all_air_evals_in_zk_dsl() -> String { res += &air_eval_in_zk_dsl(ExecutionTable:: {}); res += &air_eval_in_zk_dsl(ExtensionOpPrecompile:: {}); res += &air_eval_in_zk_dsl(Poseidon16Precompile:: {}); + res += &air_eval_in_zk_dsl(Poseidon24Precompile:: {}); res } @@ -595,7 +605,7 @@ fn eval_air_constraint( ctx.expr_cache.insert(idx, v.clone()); return v; } else { - let node = get_node::(idx); + let node = unsafe { get_node::(idx) }; let v = match node.op { SymbolicOperation::Neg => { let a = eval_air_constraint(node.lhs, None, ctx, res); @@ -632,7 +642,7 @@ fn try_emit_dot_product_be(idx: u32, dest: Option<&str>, ctx: &mut AirCodegenCtx if op_idx != idx && ctx.expr_cache.contains_key(&op_idx) { return None; } - let node = get_node::(op_idx); + let node = unsafe { get_node::(op_idx) }; if node.op != SymbolicOperation::Add { return None; } @@ -640,7 +650,7 @@ fn try_emit_dot_product_be(idx: u32, dest: Option<&str>, ctx: &mut AirCodegenCtx SymbolicExpression::Operation(i) => i, _ => return None, }; - let mul = get_node::(mul_idx); + let mul = unsafe { get_node::(mul_idx) }; if mul.op != SymbolicOperation::Mul { return None; } diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 3f2a6cb59..596b498a7 100644 --- 
a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -1,12 +1,16 @@ #![cfg_attr(not(test), allow(unused_crate_dependencies))] +mod aggregation; pub mod benchmark; mod bytecode_claims; mod compilation; mod error; -mod type_1_aggregation; -mod type_2_aggregation; +pub mod signatures_cache; +pub use aggregation::{ + AggregatedXMSS, AggregatedXMSSInfo, AggregatedXMSSInfoWithoutPubkeys, xmss_aggregate, xmss_verify_aggregation, +}; use backend::{Evaluation, Proof, ProofError, RawProof}; +pub use benchmark::AggregationTopology; pub use compilation::{ MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, NUM_REPEATED_ONES, PREAMBLE_MEMORY_LEN, ZERO_VEC_LEN, get_aggregation_bytecode, init_aggregation_bytecode, @@ -15,10 +19,6 @@ pub use error::AggregationError; pub use lean_prover::ProverError; use lean_prover::verify_execution::verify_execution; use lean_vm::{DIGEST_LEN, EF, F}; -pub use type_1_aggregation::{TypeOneInfo, TypeOneMultiSignature, aggregate_type_1, verify_type_1}; -pub use type_2_aggregation::{ - TypeTwoMultiSignature, merge_many_type_1, split_type_2, split_type_2_by_msg, verify_type_2, -}; use utils::poseidon_compress_slice; #[allow(missing_debug_implementations)] @@ -54,3 +54,15 @@ pub(crate) fn verify_inner(input_data: Vec, proof: Proof) -> Result(value: &T) -> Vec { + let encoded = postcard::to_allocvec(value).expect("postcard serialization failed"); + lz4_flex::compress_prepend_size(&encoded) +} + +/// Inverse of `lz4_postcard_encode`. Returns `None` on either lz4 or postcard failure. 
+pub(crate) fn lz4_postcard_decode(bytes: &[u8]) -> Option { + let decompressed = lz4_flex::decompress_size_prepended(bytes).ok()?; + postcard::from_bytes(&decompressed).ok() +} diff --git a/crates/xmss/src/signers_cache.rs b/crates/rec_aggregation/src/signatures_cache.rs similarity index 72% rename from crates/xmss/src/signers_cache.rs rename to crates/rec_aggregation/src/signatures_cache.rs index 6e7a9956e..775ee8a11 100644 --- a/crates/xmss/src/signers_cache.rs +++ b/crates/rec_aggregation/src/signatures_cache.rs @@ -1,6 +1,9 @@ use backend::*; +#[cfg(test)] +use leansig_wrapper::xmss_verify; +use leansig_wrapper::{MESSAGE_LENGTH, XmssPublicKey, XmssSignature, xmss_keygen_fast, xmss_sign_fast}; +use rand::SeedableRng; use rand::rngs::StdRng; -use rand::{RngExt, SeedableRng}; use serde::{Deserialize, Serialize}; use sha3::{Digest, Sha3_256}; use std::fs; @@ -9,7 +12,9 @@ use std::sync::OnceLock; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Instant; -use crate::*; +pub fn message_for_benchmark() -> [u8; MESSAGE_LENGTH] { + BENCHMARK_MESSAGE +} static SIGNERS_CACHE: OnceLock> = OnceLock::new(); @@ -18,12 +23,12 @@ pub fn get_benchmark_signatures() -> &'static Vec<(XmssPublicKey, XmssSignature) } pub const BENCHMARK_SLOT: u32 = 111; +pub const BENCHMARK_MESSAGE: [u8; MESSAGE_LENGTH] = [ + 78, 32, 21, 11, 1, 76, 255, 254, 0, 0, 22, 11, 11, 87, 87, 32, 11, 32, 11, 76, 23, 12, 11, 2, 2, 2, 2, 2, 2, 3, 4, + 5, +]; pub const NUM_BENCHMARK_SIGNERS: usize = 10_000; -pub fn message_for_benchmark() -> [F; MESSAGE_LEN_FE] { - std::array::from_fn(F::from_usize) -} - const CACHE_SCHEMA_VERSION: u32 = 2; #[derive(Serialize, Deserialize)] @@ -36,12 +41,9 @@ fn cache_footprint(first_pubkey: &XmssPublicKey) -> u128 { let mut hasher = Sha3_256::new(); hasher.update(NUM_BENCHMARK_SIGNERS.to_le_bytes()); hasher.update(BENCHMARK_SLOT.to_le_bytes()); - for f in message_for_benchmark() { - hasher.update(f.as_canonical_u32().to_le_bytes()); - } - for f in 
first_pubkey.merkle_root { - hasher.update(f.as_canonical_u32().to_le_bytes()); - } + hasher.update(BENCHMARK_MESSAGE); + let pk_bytes = postcard::to_allocvec(first_pubkey).expect("pubkey serialization failed"); + hasher.update(&pk_bytes); let hash = hasher.finalize(); u128::from_le_bytes(hash[..16].try_into().unwrap()) } @@ -65,8 +67,8 @@ fn compute_signer(index: usize) -> (XmssPublicKey, XmssSignature) { let mut rng = StdRng::seed_from_u64(index as u64); let key_start = BENCHMARK_SLOT; let key_end = BENCHMARK_SLOT + 1; - let (sk, pk) = xmss_key_gen(rng.random(), key_start, key_end).unwrap(); - let sig = xmss_sign(&mut rng, &sk, &message_for_benchmark(), BENCHMARK_SLOT).unwrap(); + let (sk, pk) = xmss_keygen_fast(&mut rng, key_start, key_end); + let sig = xmss_sign_fast(&sk, &BENCHMARK_MESSAGE, BENCHMARK_SLOT).unwrap(); (pk, sig) } @@ -89,18 +91,16 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { let completed = AtomicUsize::new(1); let time = Instant::now(); - let rest: Vec<_> = (1..NUM_BENCHMARK_SIGNERS) - .into_par_iter() - .map(|index| { - let signer = compute_signer(index); - let done = completed.fetch_add(1, Ordering::Relaxed) + 1; - print!( - "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", - 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 - ); - signer - }) - .collect(); + let rest: Vec<_> = parallel::par_map_collect(NUM_BENCHMARK_SIGNERS - 1, |j| { + let index = j + 1; + let signer = compute_signer(index); + let done = completed.fetch_add(1, Ordering::Relaxed) + 1; + print!( + "\rPrecomputing benchmark signatures (cached after first run): {:.0}%", + 100.0 * done as f64 / NUM_BENCHMARK_SIGNERS as f64 + ); + signer + }); println!( "\rGenerating signatures for benchmark (one-time operation): 100% - done ({:.2}s)", @@ -128,8 +128,9 @@ fn gen_benchmark_signers_cache() -> Vec<(XmssPublicKey, XmssSignature)> { #[test] fn test_signature_cache() { let signatures = get_benchmark_signatures(); - 
signatures.par_iter().enumerate().for_each(|(i, (pk, sig))| { - xmss_verify(pk, &message_for_benchmark(), sig, BENCHMARK_SLOT) + parallel::for_each_index(signatures.len(), |i| { + let (pk, sig) = &signatures[i]; + xmss_verify(pk, BENCHMARK_SLOT, &BENCHMARK_MESSAGE, sig) .unwrap_or_else(|_| panic!("Signature {} failed to verify", i)); }); } diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs deleted file mode 100644 index ef14e699d..000000000 --- a/crates/rec_aggregation/src/type_2_aggregation.rs +++ /dev/null @@ -1,288 +0,0 @@ -use crate::error::AggregationError; -use backend::*; -use lean_prover::default_whir_config; -use lean_prover::fiat_shamir_domain_sep; -use lean_prover::prove_execution::ExecutionProof; -use lean_prover::prove_execution::prove_execution; -use lean_vm::*; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use utils::poseidon_compress_slice; - -use crate::InnerVerified; -use crate::bytecode_claims::compute_bytecode_value_at; -use crate::bytecode_claims::flatten_bytecode_claim; -use crate::bytecode_claims::reduce_bytecode_claims; -use crate::compilation::{ - BYTECODE_CLAIM_OFFSET, MAX_RECURSIONS, PREAMBLE_MEMORY_LEN, TYPE2_FLAG, get_aggregation_bytecode, - try_get_aggregation_bytecode, -}; -use crate::decompress_size_prepended_bounded; -use crate::type_1_aggregation::{ - TypeOneInfo, TypeOneMultiSignature, check_type_one_pubkeys, extract_merkle_hint_blobs, verify_type_1, -}; -use crate::verify_inner; - -/// A bundle of `n` type-1 multi-signatures with potentially distinct (message, slot) per component, attested by a single snark. 
-#[derive(Debug, Clone)] -pub struct TypeTwoMultiSignature { - pub info: Vec, - pub bytecode_claim: Evaluation, // value is trusted to be correct (should be recomputed when receiving a proof from an untrusted source) - pub proof: ExecutionProof, -} - -impl Serialize for TypeTwoMultiSignature { - fn serialize(&self, s: S) -> Result { - (&self.info, &self.bytecode_claim.point, &self.proof).serialize(s) - } -} - -impl<'de> Deserialize<'de> for TypeTwoMultiSignature { - fn deserialize>(d: D) -> Result { - let (info, bytecode_claim_point, proof) = - <(Vec, MultilinearPoint, ExecutionProof)>::deserialize(d)?; - let bytecode = - try_get_aggregation_bytecode().ok_or_else(|| serde::de::Error::custom("bytecode not initialized"))?; - if bytecode_claim_point.len() != bytecode.cumulated_n_vars() { - return Err(serde::de::Error::custom("invalid bytecode point")); - } - let bytecode_value = compute_bytecode_value_at(&bytecode_claim_point); - Ok(TypeTwoMultiSignature { - info, - bytecode_claim: Evaluation::new(bytecode_claim_point, bytecode_value), - proof, - }) - } -} - -impl TypeTwoMultiSignature { - pub fn compress(&self) -> Vec { - let encoded = postcard::to_allocvec(self).expect("postcard serialization failed"); - lz4_flex::compress_prepend_size(&encoded) - } - - pub fn decompress(bytes: &[u8]) -> Option { - let decompressed = decompress_size_prepended_bounded(bytes)?; - let (value, rest) = postcard::take_from_bytes::(&decompressed).ok()?; - rest.is_empty().then_some(value) - } - - pub(crate) fn bytecode_claim_flat(&self) -> Vec { - flatten_bytecode_claim(&self.bytecode_claim) - } -} - -/// Layout: [prefix(8) | bytecode_claim_padded | initial_fiat_shamir_cap(8) | n × digest(8)]. 
-fn build_type2_input_data(digests: &[[F; DIGEST_LEN]], bytecode_claim_flat: &[F]) -> Vec { - let n = digests.len(); - let claim_padded = bytecode_claim_flat.len().next_multiple_of(DIGEST_LEN); - let domsep_offset = BYTECODE_CLAIM_OFFSET + claim_padded; - let digests_offset = domsep_offset + DIGEST_LEN; - let mut data = vec![F::ZERO; digests_offset + n * DIGEST_LEN]; - - data[0] = F::from_usize(TYPE2_FLAG); - data[1] = F::from_usize(n); - // data[2..8] stays zero (prefix-chunk pad). - - data[BYTECODE_CLAIM_OFFSET..][..bytecode_claim_flat.len()].copy_from_slice(bytecode_claim_flat); - let domsep = fiat_shamir_domain_sep(get_aggregation_bytecode()); - data[domsep_offset..][..DIGEST_LEN].copy_from_slice(&domsep); - - for (i, d) in digests.iter().enumerate() { - data[digests_offset + i * DIGEST_LEN..][..DIGEST_LEN].copy_from_slice(d); - } - - data -} - -pub fn merge_many_type_1( - types_1: Vec, - log_inv_rate: usize, -) -> Result { - let n_components = types_1.len(); - if n_components == 0 { - return Err(AggregationError::EmptyAggregation { - what: "type-1 components", - }); - } - if n_components > MAX_RECURSIONS { - return Err(AggregationError::LimitExceeded { - what: "type-1 components", - actual: n_components, - max: MAX_RECURSIONS, - }); - } - let whir_config = default_whir_config(log_inv_rate); - let bytecode = get_aggregation_bytecode(); - - let verified_children: Vec = types_1.iter().map(verify_type_1).collect::>()?; - - let reduced_claims = reduce_bytecode_claims(&verified_children); - - let digests: Vec<[F; DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect(); - let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat()); - let public_input_digest = poseidon_compress_slice(&pub_input_data); - - let bytecode_value_hint_blobs: Vec> = verified_children - .iter() - .map(|v| v.bytecode_evaluation.value.as_basis_coefficients_slice().to_vec()) - .collect(); - let component_layout_blobs: Vec> = 
verified_children.iter().map(|v| v.input_data.clone()).collect(); - let proof_transcript_blobs: Vec> = verified_children - .iter() - .map(|v| v.raw_proof.transcript.clone()) - .collect(); - let table_sort_perm_blobs: Vec> = verified_children - .iter() - .map(|v| v.sorted_table_perm.iter().map(|&i| F::from_usize(i)).collect()) - .collect(); - let (merkle_leaf_blobs, merkle_path_blobs) = - extract_merkle_hint_blobs(verified_children.iter().map(|v| &v.raw_proof)); - - let mut hints: HashMap>> = HashMap::new(); - hints.insert( - "input_data_num_chunks".to_string(), - vec![vec![F::from_usize(pub_input_data.len() / DIGEST_LEN)]], - ); - hints.insert("input_data".to_string(), vec![pub_input_data]); - hints.insert("bytecode_value_hint".to_string(), bytecode_value_hint_blobs); - hints.insert("component_layout".to_string(), component_layout_blobs); - hints.insert( - "proof_transcript_size".to_string(), - proof_transcript_blobs - .iter() - .map(|b| vec![F::from_usize(b.len())]) - .collect(), - ); - hints.insert("proof_transcript".to_string(), proof_transcript_blobs); - hints.insert("table_sort_perm".to_string(), table_sort_perm_blobs); - hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs); - hints.insert("merkle_path".to_string(), merkle_path_blobs); - hints.insert( - "bytecode_sumcheck_proof".to_string(), - vec![reduced_claims.sumcheck_transcript], - ); - - let witness = ExecutionWitness { - preamble_memory_len: PREAMBLE_MEMORY_LEN, - hints, - min_table_log_n_rows: Default::default(), - }; - let execution_proof = prove_execution(bytecode, &public_input_digest, &witness, &whir_config, false)?; - - Ok(TypeTwoMultiSignature { - info: types_1.into_iter().map(|sig| sig.info).collect(), - bytecode_claim: reduced_claims.final_claim, - proof: execution_proof, - }) -} - -pub fn verify_type_2(sig: &TypeTwoMultiSignature) -> Result { - if sig.info.is_empty() || sig.info.len() > MAX_RECURSIONS { - return Err(ProofError::InvalidProof); - } - for info in &sig.info { - 
check_type_one_pubkeys(&info.pubkeys).map_err(|_| ProofError::InvalidProof)?; - } - let digests = sig - .info - .iter() - .map(|info| poseidon_compress_slice(&info.build_input_data())) - .collect::>(); - let input_data = build_type2_input_data(&digests, &sig.bytecode_claim_flat()); - verify_inner(input_data, sig.proof.proof.clone()) -} - -pub fn split_type_2_by_msg( - type_2: TypeTwoMultiSignature, - msg: [F; DIGEST_LEN], - log_inv_rate: usize, -) -> Result { - let Some(index) = type_2.info.iter().position(|info| info.message == msg) else { - return Err(AggregationError::UnknownMessage); - }; - if type_2.info.iter().filter(|info| info.message == msg).count() > 1 { - return Err(AggregationError::MultipleMessages); - } - split_type_2(type_2, index, log_inv_rate) -} - -/// Recover an independent type-1 multi-signature for the component at `index` -/// from a type-2 multi-signature. -pub fn split_type_2( - type_2: TypeTwoMultiSignature, - index: usize, - log_inv_rate: usize, -) -> Result { - let n_components = type_2.info.len(); - if index >= n_components { - return Err(AggregationError::InvalidSplitIndex { index, n_components }); - } - if n_components > MAX_RECURSIONS { - return Err(AggregationError::LimitExceeded { - what: "type-2 components", - actual: n_components, - max: MAX_RECURSIONS, - }); - } - let whir_config = default_whir_config(log_inv_rate); - let bytecode = get_aggregation_bytecode(); - - let outer_verified = verify_type_2(&type_2)?; - - let reduced_claims = reduce_bytecode_claims(std::slice::from_ref(&outer_verified)); - let bytecode_value_hint_blob = flatten_scalars_to_base(&[outer_verified.bytecode_evaluation.value]); - let table_sort_perm_blob: Vec = outer_verified - .sorted_table_perm - .iter() - .map(|&i| F::from_usize(i)) - .collect(); - - let mut outer_type_1 = type_2.info[index].clone(); - outer_type_1.bytecode_claim = reduced_claims.final_claim.clone(); - let ourer_input_data = outer_type_1.build_input_data(); - let outer_digest = 
poseidon_compress_slice(&ourer_input_data); - - let inner_input_data: Vec = type_2.info[index].build_input_data(); - - let (merkle_leaf_blobs, merkle_path_blobs) = - extract_merkle_hint_blobs(std::slice::from_ref(&outer_verified.raw_proof)); - let proof_transcript = outer_verified.raw_proof.transcript; - let proof_transcript_size = vec![F::from_usize(proof_transcript.len())]; - - let mut hints: HashMap>> = HashMap::new(); - hints.insert( - "input_data_num_chunks".to_string(), - vec![vec![F::from_usize(ourer_input_data.len() / DIGEST_LEN)]], - ); - hints.insert("input_data".to_string(), vec![ourer_input_data]); - hints.insert("is_split".to_string(), vec![vec![F::ONE]]); - hints.insert( - "type2_meta".to_string(), - vec![vec![F::from_usize(n_components), F::from_usize(index)]], - ); - hints.insert("inner_type2_layout".to_string(), vec![outer_verified.input_data]); - hints.insert("kept_type1_buff".to_string(), vec![inner_input_data]); - hints.insert("bytecode_value_hint".to_string(), vec![bytecode_value_hint_blob]); - hints.insert("proof_transcript_size".to_string(), vec![proof_transcript_size]); - hints.insert("proof_transcript".to_string(), vec![proof_transcript]); - hints.insert("table_sort_perm".to_string(), vec![table_sort_perm_blob]); - hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs); - hints.insert("merkle_path".to_string(), merkle_path_blobs); - hints.insert( - "bytecode_sumcheck_proof".to_string(), - vec![reduced_claims.sumcheck_transcript], - ); - - let witness = ExecutionWitness { - preamble_memory_len: PREAMBLE_MEMORY_LEN, - hints, - min_table_log_n_rows: Default::default(), - }; - let execution_proof = prove_execution(bytecode, &outer_digest, &witness, &whir_config, false)?; - - Ok(TypeOneMultiSignature { - info: outer_type_1, - proof: execution_proof, - }) -} diff --git a/crates/rec_aggregation/tests/test_hashing.py b/crates/rec_aggregation/tests/test_hashing.py new file mode 100644 index 000000000..92ab0cd68 --- /dev/null +++ 
b/crates/rec_aggregation/tests/test_hashing.py @@ -0,0 +1,17 @@ +from snark_lib import * +from ..zkdsl_implem.utils import * + + +def main(): + build_preamble_memory() + expected_hash = 0 + input_size_buf = Array(1) + hint_witness("input_size", input_size_buf) + len = input_size_buf[0] + assert len < 2**15 + debug_assert(0 < len) + data = Array(len) + hint_witness("input", data) + hash = slice_hash_with_iv_dynamic_unroll(data, len, 15) + copy_8(hash, expected_hash) + return diff --git a/crates/rec_aggregation/tests/test_hashing.rs b/crates/rec_aggregation/tests/test_hashing.rs new file mode 100644 index 000000000..7cc0039eb --- /dev/null +++ b/crates/rec_aggregation/tests/test_hashing.rs @@ -0,0 +1,36 @@ +use backend::PrimeCharacteristicRing; +use lean_compiler::*; +use lean_vm::*; +use rand::{RngExt, SeedableRng, rngs::StdRng}; +use rec_aggregation::{NUM_REPEATED_ONES, PREAMBLE_MEMORY_LEN, ZERO_VEC_LEN}; +use std::collections::{BTreeMap, HashMap}; +use utils::poseidon_compress_slice_zero_iv; + +#[test] +fn test_slice_hashing() { + let path = format!("{}/tests/test_hashing.py", env!("CARGO_MANIFEST_DIR")); + let replacements = BTreeMap::from([ + ("ZERO_VEC_LEN_PLACEHOLDER".to_string(), ZERO_VEC_LEN.to_string()), + ( + "NUM_REPEATED_ONES_PLACEHOLDER".to_string(), + NUM_REPEATED_ONES.to_string(), + ), + ]); + let bytecode = compile_program_with_flags(&ProgramSource::Filepath(path), CompilationFlags { replacements }); + + for len in [1, 2, 6, 7, 8, 9, 15, 16, 17, 24, 100, 1000, 12345] { + let mut rng = StdRng::seed_from_u64(0); + let data: Vec = (0..len).map(|_| rng.random()).collect(); + let public_input = poseidon_compress_slice_zero_iv(&data); + let hints = HashMap::from([ + ("input_size".to_string(), vec![vec![F::from_usize(len)]]), + ("input".to_string(), vec![data]), + ]); + let witness = ExecutionWitness { + preamble_memory_len: PREAMBLE_MEMORY_LEN, + hints, + min_table_log_n_rows: BTreeMap::new(), + }; + execute_bytecode(&bytecode, &public_input, &witness, 
false); + } +} diff --git a/crates/rec_aggregation/zkdsl_implem/hashing.py b/crates/rec_aggregation/zkdsl_implem/hashing.py index 00b8a7007..e7e44fd40 100644 --- a/crates/rec_aggregation/zkdsl_implem/hashing.py +++ b/crates/rec_aggregation/zkdsl_implem/hashing.py @@ -163,6 +163,101 @@ def slice_hash_runtime(data, num_chunks): return final_state_ptr +# leansig pubkey hashing: zero-IV sponge over `len` FE (arbitrary length, last chunk zero-padded). +# Matches the Rust `poseidon_compress_slice_zero_iv`. +def slice_hash_with_iv_dynamic_unroll(data, len, len_bits: Const): + remainder = modulo_8(len, len_bits) + num_full_elements = len - remainder + num_full_chunks = num_full_elements / 8 + + if num_full_chunks == 0: + left = Array(DIGEST_LEN) + fill_padded_chunk(left, data, remainder) + result = Array(DIGEST_LEN) + poseidon16_compress(ZERO_VEC_PTR, left, result) + return result + + if num_full_chunks == 1: + if remainder == 0: + result = Array(DIGEST_LEN) + poseidon16_compress(ZERO_VEC_PTR, data, result) + return result + else: + h0 = Array(DIGEST_LEN) + poseidon16_compress(ZERO_VEC_PTR, data, h0) + right = Array(DIGEST_LEN) + fill_padded_chunk(right, data + DIGEST_LEN, remainder) + result = Array(DIGEST_LEN) + poseidon16_compress(h0, right, result) + return result + + partial_hash = slice_hash_chunks_with_iv(data, num_full_chunks, len_bits) + if remainder == 0: + return partial_hash + else: + padded_last = Array(DIGEST_LEN) + fill_padded_chunk(padded_last, data + num_full_elements, remainder) + final_hash = Array(DIGEST_LEN) + poseidon16_compress(partial_hash, padded_last, final_hash) + return final_hash + + +@inline +def slice_hash_chunks_with_iv(data, num_chunks, num_chunks_bits): + debug_assert(1 < num_chunks) + states = Array(num_chunks * DIGEST_LEN) + poseidon16_compress(ZERO_VEC_PTR, data, states) + n_iters = num_chunks - 1 + state_ptr: Mut = states + data_ptr: Mut = data + DIGEST_LEN + + n_chunks_outer, remainder = euclidian_div_runtime(n_iters, 
PARTIAL_UNROLL_BATCH) + for _ in range(0, n_chunks_outer): + for _ in unroll(0, PARTIAL_UNROLL_BATCH): + new_state = state_ptr + DIGEST_LEN + poseidon16_compress(state_ptr, data_ptr, new_state) + state_ptr = new_state + data_ptr += DIGEST_LEN + + final_state_ptr = match_range( + remainder, + range(0, PARTIAL_UNROLL_BATCH), + lambda r: absorb_n_hashes_const(r, state_ptr, data_ptr), + ) + return final_state_ptr + + +def fill_padded_chunk(dst, src, n): + debug_assert(0 < n) + debug_assert(n < DIGEST_LEN) + match_range(n, range(1, DIGEST_LEN), lambda r: fill_padded_chunk_const(dst, src, r)) + return + + +def fill_padded_chunk_const(dst, src, n: Const): + for i in unroll(0, n): + dst[i] = src[i] + for i in unroll(n, DIGEST_LEN): + dst[i] = 0 + return + + +def modulo_8(n, n_bits: Const): + debug_assert(2 < n_bits) + debug_assert(n < 2**n_bits) + bits = Array(n_bits) + hint_decompose_bits(n, bits, n_bits) + partial_sums = Array(n_bits) + partial_sums[0] = bits[n_bits - 1] + assert partial_sums[0] * (1 - partial_sums[0]) == 0 + for i in unroll(1, n_bits): + b = bits[n_bits - 1 - i] + assert b * (1 - b) == 0 + partial_sums[i] = partial_sums[i - 1] + b * 2**i + assert n == partial_sums[n_bits - 1] + return partial_sums[2] + + @inline def whir_do_4_merkle_levels(b, state_in, path_chunk, state_out): b0 = b % 2 diff --git a/crates/rec_aggregation/zkdsl_implem/main.py b/crates/rec_aggregation/zkdsl_implem/main.py index e516d6281..c3402c967 100644 --- a/crates/rec_aggregation/zkdsl_implem/main.py +++ b/crates/rec_aggregation/zkdsl_implem/main.py @@ -5,13 +5,12 @@ MAX_N_SIGS = MAX_XMSS_AGGREGATED_PLACEHOLDER MAX_N_DUPS = MAX_XMSS_DUPLICATES_PLACEHOLDER -# data_buf[0..8] = [flag, count, 0×6] (count = n_sigs for type-1, n_components for type-2). +# data_buf[0..8] = [flag, count, 0×6] (count = n_sigs). 
TYPE_1_FLAG = TYPE_1_FLAG_PLACEHOLDER -TYPE_2_FLAG = TYPE_2_FLAG_PLACEHOLDER BYTECODE_SUMCHECK_PROOF_SIZE = BYTECODE_SUMCHECK_PROOF_SIZE_PLACEHOLDER -# layout: [flag, count, 0×6 (8)] [bytecode_claim_padded] [initial_fiat_shamir_cap(8)] [type1/type2 mode-specific data] +# layout: [flag, count, 0×6 (8)] [bytecode_claim_padded] [initial_fiat_shamir_cap(8)] [aggregation data] BYTECODE_CLAIM_OFFSET = DIGEST_LEN # (right after the prefix chunk) INITIAL_FIAT_SHAMIR_CAP_OFFSET = BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE_PADDED COMPONENT_DATA_OFFSET = INITIAL_FIAT_SHAMIR_CAP_OFFSET + DIGEST_LEN @@ -19,16 +18,15 @@ # Type-1 mode-specific data (fixed): pubkeys_hash | message | merkle_chunks | tweaks_hash. TYPE_1_PUBKEYS_HASH_OFFSET = COMPONENT_DATA_OFFSET TYPE_1_MSG_HASH_OFFSET = COMPONENT_DATA_OFFSET + DIGEST_LEN -TYPE_1_MERKLE_CHUNKS_OFFSET = TYPE_1_MSG_HASH_OFFSET + DIGEST_LEN +TYPE_1_MERKLE_CHUNKS_OFFSET = TYPE_1_MSG_HASH_OFFSET + MESSAGE_LEN TYPE_1_TWEAKS_HASH_OFFSET = TYPE_1_MERKLE_CHUNKS_OFFSET + N_MERKLE_CHUNKS -TYPE_1_INPUT_DATA_SIZE_PADDED = TYPE_1_TWEAKS_HASH_OFFSET + DIGEST_LEN +TYPE_1_INPUT_DATA_SIZE_PADDED = next_multiple_of(TYPE_1_TWEAKS_HASH_OFFSET + DIGEST_LEN, DIGEST_LEN) TYPE_1_INPUT_DATA_NUM_CHUNKS = TYPE_1_INPUT_DATA_SIZE_PADDED / DIGEST_LEN -# Type-2 mode-specific data (variable): n_components × digest(8). -TYPE_2_DIGESTS_OFFSET = COMPONENT_DATA_OFFSET +# Component data (pubkeys_hash | message | merkle_chunks | tweaks_hash) is a whole number of +# DIGEST_LEN chunks; its size depends on the config (N_MERKLE_CHUNKS scales with LOG_LIFETIME). 
+COMPONENT_DATA_NUM_CHUNKS = (TYPE_1_INPUT_DATA_SIZE_PADDED - COMPONENT_DATA_OFFSET) / DIGEST_LEN -BYTECODE_CLAIM_NUM_CHUNKS = BYTECODE_CLAIM_SIZE_PADDED / DIGEST_LEN -TYPE_2_BASE_NUM_CHUNKS = BYTECODE_CLAIM_NUM_CHUNKS + 2 # prefix chunk + domsep chunk def main(): @@ -47,67 +45,8 @@ def main(): initial_fiat_shamir_cap = data_buf + INITIAL_FIAT_SHAMIR_CAP_OFFSET discriminator = data_buf[0] - if discriminator == TYPE_2_FLAG: - # Type-2: merge of n type-1 multi-signatures. - n_components = data_buf[1] - assert n_components != 0 - assert n_components <= MAX_RECURSIONS - - n_bytecode_claims = n_components * 2 - bytecode_claims = Array(n_bytecode_claims) - - for c in range(0, n_components): - component_digest = data_buf + TYPE_2_DIGESTS_OFFSET + c * DIGEST_LEN - inner_type1_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) - hint_witness("component_layout", inner_type1_buf) - ensure_well_formed_input_data(inner_type1_buf, initial_fiat_shamir_cap, TYPE_1_FLAG) - slice_hash(inner_type1_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, component_digest) - - bytecode_claims[2 * c] = inner_type1_buf + BYTECODE_CLAIM_OFFSET - bytecode_claims[2 * c + 1] = recursion(component_digest, initial_fiat_shamir_cap) - - reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output, initial_fiat_shamir_cap) - - slice_hash_range(data_buf, n_components + TYPE_2_BASE_NUM_CHUNKS, pub_mem) - return - assert discriminator == TYPE_1_FLAG - is_split_buf = Array(1) - hint_witness("is_split", is_split_buf) - if is_split_buf[0] == 1: - # ============ type-1: Split (extract a type-one from a type-two) ============ - type2_meta_hint = Array(2) - hint_witness("type2_meta", type2_meta_hint) - type2_n_components = type2_meta_hint[0] - type2_kept_index = type2_meta_hint[1] - assert type2_n_components != 0 - assert type2_n_components <= MAX_RECURSIONS - assert type2_kept_index < type2_n_components - - type2_num_chunks = type2_n_components + TYPE_2_BASE_NUM_CHUNKS - type2_data_buf = Array(type2_num_chunks * 
DIGEST_LEN) - hint_witness("inner_type2_layout", type2_data_buf) - ensure_well_formed_input_data(type2_data_buf, initial_fiat_shamir_cap, TYPE_2_FLAG) - type2_digests = type2_data_buf + TYPE_2_DIGESTS_OFFSET - - kept_type1_buff = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) - hint_witness("kept_type1_buff", kept_type1_buff) - copy_8(data_buf, kept_type1_buff) # type-1 flag | n_signatures | 0×6 - copy_32(data_buf + COMPONENT_DATA_OFFSET, kept_type1_buff + COMPONENT_DATA_OFFSET) - ensure_well_formed_input_data(kept_type1_buff, initial_fiat_shamir_cap, TYPE_1_FLAG) - digest_kept = type2_digests + type2_kept_index * DIGEST_LEN - slice_hash(kept_type1_buff, TYPE_1_INPUT_DATA_NUM_CHUNKS, digest_kept) - - inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - slice_hash_range(type2_data_buf, type2_num_chunks, inner_pub_mem) - bytecode_claims = Array(2) - bytecode_claims[0] = type2_data_buf + BYTECODE_CLAIM_OFFSET - bytecode_claims[1] = recursion(inner_pub_mem, initial_fiat_shamir_cap) - reduce_bytecode_claims(bytecode_claims, 2, bytecode_claim_output, initial_fiat_shamir_cap) - slice_hash(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) - return - # ============ Standard type-1: single (message, slot) aggregation ============ n_sigs = data_buf[1] assert n_sigs != 0 @@ -148,7 +87,7 @@ def main(): if n_raw_xmss == 0: type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) copy_8(data_buf, type1_data_buf) # prefix - copy_32(data_buf + COMPONENT_DATA_OFFSET, type1_data_buf + COMPONENT_DATA_OFFSET) + copy_8n(data_buf + COMPONENT_DATA_OFFSET, type1_data_buf + COMPONENT_DATA_OFFSET, COMPONENT_DATA_NUM_CHUNKS) hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) ensure_well_formed_input_data(type1_data_buf, initial_fiat_shamir_cap, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) @@ -161,7 +100,7 @@ def main(): return # General path - computed_pubkeys_hash = slice_hash_runtime(all_pubkeys, n_sigs) + computed_pubkeys_hash = slice_hash_with_iv_dynamic_unroll(all_pubkeys, 
n_sigs * PUB_KEY_SIZE, MAX_LOG_MEMORY_SIZE) copy_8(computed_pubkeys_hash, pubkeys_hash_expected) # Buffer for partition verification @@ -187,22 +126,19 @@ def main(): sub_indices_arr = Array(n_sub) hint_witness("sub_indices", sub_indices_arr) - - running_hash: Mut = build_iv(n_sub * PUB_KEY_SIZE) - n_chunks, remainder = euclidian_div_runtime(n_sub, PARTIAL_UNROLL_BATCH) - j: Mut = 0 - for _ in range(0, n_chunks): - for u in unroll(0, PARTIAL_UNROLL_BATCH): - counter, running_hash = absorb_recursive_pubkey(j + u, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash) - j += PARTIAL_UNROLL_BATCH - # Tail iterations - tail_counter, tail_running_hash = match_range( - remainder, - range(0, PARTIAL_UNROLL_BATCH), - lambda r: absorb_n_pubkeys_const(r, j, sub_indices_arr, n_total, all_pubkeys, buffer, counter, running_hash), + sub_pubkeys_buf = Array(n_sub * PUB_KEY_SIZE) + for j in range(0, n_sub): + idx = sub_indices_arr[j] + assert idx < n_total + buffer[idx] = counter + counter += 1 + src = all_pubkeys + idx * PUB_KEY_SIZE + dst = sub_pubkeys_buf + j * PUB_KEY_SIZE + copy_13(src, dst) + + sub_pubkeys_hash = slice_hash_with_iv_dynamic_unroll( + sub_pubkeys_buf, n_sub * PUB_KEY_SIZE, MAX_LOG_MEMORY_SIZE ) - counter = tail_counter - running_hash = tail_running_hash type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) type1_data_buf[0] = TYPE_1_FLAG @@ -210,10 +146,13 @@ def main(): for k in unroll(2, DIGEST_LEN): type1_data_buf[k] = 0 - copy_8(running_hash, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET) - copy_8(message, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET + DIGEST_LEN) - copy_8(merkle_chunks_for_slot, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET + DIGEST_LEN + MESSAGE_LEN) + copy_8(sub_pubkeys_hash, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET) + copy_9(message, type1_data_buf + TYPE_1_MSG_HASH_OFFSET) + copy_8(merkle_chunks_for_slot, type1_data_buf + TYPE_1_MERKLE_CHUNKS_OFFSET) copy_8(tweaks_hash_expected, type1_data_buf + TYPE_1_TWEAKS_HASH_OFFSET) 
+ # Zero-pad the trailing region up to TYPE_1_INPUT_DATA_SIZE_PADDED. + for k in unroll(TYPE_1_TWEAKS_HASH_OFFSET + DIGEST_LEN, TYPE_1_INPUT_DATA_SIZE_PADDED): + type1_data_buf[k] = 0 hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) ensure_well_formed_input_data(type1_data_buf, initial_fiat_shamir_cap, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) diff --git a/crates/rec_aggregation/zkdsl_implem/recursion.py b/crates/rec_aggregation/zkdsl_implem/recursion.py index 2578f8a86..6f0bc0211 100644 --- a/crates/rec_aggregation/zkdsl_implem/recursion.py +++ b/crates/rec_aggregation/zkdsl_implem/recursion.py @@ -708,6 +708,8 @@ def evaluate_air_constraints(table_index, inner_evals, air_alpha_powers, logup_a res = evaluate_air_constraints_table_1(inner_evals, air_alpha_powers, logup_alphas_eq_poly) case 2: res = evaluate_air_constraints_table_2(inner_evals, air_alpha_powers, logup_alphas_eq_poly) + case 3: + res = evaluate_air_constraints_table_3(inner_evals, air_alpha_powers, logup_alphas_eq_poly) return res diff --git a/crates/rec_aggregation/zkdsl_implem/utils.py b/crates/rec_aggregation/zkdsl_implem/utils.py index d4b2e983c..db85af6cd 100644 --- a/crates/rec_aggregation/zkdsl_implem/utils.py +++ b/crates/rec_aggregation/zkdsl_implem/utils.py @@ -39,10 +39,10 @@ def div_ceil_dynamic(a, b: Const): def powers(alpha, n): # alpha: EF # n: F - assert n < 400 + assert n < 512 assert 0 < n # 2**log2_ceil(i) is not really necessary but helps reduce byetcode size (traedoff cycles / bytecode size) - res = match_range(n, range(1, 400), lambda i: powers_const(alpha, 2 ** log2_ceil(i))) + res = match_range(n, range(1, 512), lambda i: powers_const(alpha, 2 ** log2_ceil(i))) return res @@ -418,6 +418,14 @@ def set_to_8_zeros(a): return +@inline +def set_to_9_zeros(a): + zero_ptr = ZERO_VEC_PTR + dot_product_ee(a, ONE_EF_PTR, zero_ptr) + dot_product_ee(a + (9 - DIM), ONE_EF_PTR, zero_ptr) + return + + @inline def set_to_16_zeros(a): zero_ptr = 
ZERO_VEC_PTR @@ -428,6 +436,29 @@ def set_to_16_zeros(a): return +@inline +def copy_9(a, b): + dot_product_ee(a, ONE_EF_PTR, b) + dot_product_ee(a + (9 - DIM), ONE_EF_PTR, b + (9 - DIM)) + return + + +@inline +def copy_13(a, b): + dot_product_ee(a, ONE_EF_PTR, b) + dot_product_ee(a + 5, ONE_EF_PTR, b + 5) + dot_product_ee(a + (13 - DIM), ONE_EF_PTR, b + (13 - DIM)) + return + + +@inline +def copy_15(a, b): + copy_5(a, b) + copy_5(a + 5, b + 5) + copy_5(a + 10, b + 10) + return + + @inline def copy_16(a, b): dot_product_ee(a, ONE_EF_PTR, b) @@ -454,6 +485,25 @@ def copy_32(a, b): return +@inline +def copy_40(a, b): + copy_8(a, b) + copy_8(a + 8, b + 8) + copy_8(a + 16, b + 16) + copy_8(a + 24, b + 24) + copy_8(a + 32, b + 32) + return + + +# Copy n consecutive DIGEST_LEN-sized chunks (n compile-time known). +@inline +def copy_8n(a, b, n): + for i in unroll(0, n): + dot_product_ee(a + i * DIGEST_LEN, ONE_EF_PTR, b + i * DIGEST_LEN) + dot_product_ee(a + i * DIGEST_LEN + (DIGEST_LEN - DIM), ONE_EF_PTR, b + i * DIGEST_LEN + (DIGEST_LEN - DIM)) + return + + @inline def copy_many_ef(a, b, n): for i in unroll(0, n): @@ -569,14 +619,12 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks, lea partial_sum: Mut = nibbles[0] for i in unroll(1, 6): partial_sum += nibbles[i] * 16**i - - # p = 2^31 - 2^24 + 1, so 2^24 * 127 = p - 1 ≡ -1 (mod p), hence inv(2^24) = -127. 
- # Deduce top7 from the identity partial_sum + top7 * 2^24 == a: - # top7 = (a - partial_sum) * inv(2^24) = (partial_sum - a) * 127 + # top7 = (a - partial_sum) * inv(2^24) = (partial_sum - a) * 127 (inv(2^24) = -127 mod p) top7 = (partial_sum - a) * 127 assert top7 < 2**7 if top7 == 2**7 - 1: assert partial_sum == 0 + assert partial_sum + top7 * 2**24 == a leaf_data = Array(num_chunks * DIGEST_LEN) hint_witness("merkle_leaf", leaf_data) diff --git a/crates/rec_aggregation/zkdsl_implem/xmss_aggregate.py b/crates/rec_aggregation/zkdsl_implem/xmss_aggregate.py index 20ca144d4..17e68d716 100644 --- a/crates/rec_aggregation/zkdsl_implem/xmss_aggregate.py +++ b/crates/rec_aggregation/zkdsl_implem/xmss_aggregate.py @@ -8,200 +8,214 @@ LOG_LIFETIME = LOG_LIFETIME_PLACEHOLDER MESSAGE_LEN = MESSAGE_LEN_PLACEHOLDER RANDOMNESS_LEN = RANDOMNESS_LEN_PLACEHOLDER -PUBLIC_PARAM_LEN_FE = PUBLIC_PARAM_LEN_FE_PLACEHOLDER -XMSS_DIGEST_LEN = XMSS_DIGEST_LEN_PLACEHOLDER -PUB_KEY_SIZE = XMSS_DIGEST_LEN + PUBLIC_PARAM_LEN_FE -PP_IN_LEFT = DIGEST_LEN - XMSS_DIGEST_LEN -WOTS_SIG_SIZE = RANDOMNESS_LEN + V * XMSS_DIGEST_LEN -# wots_public_key pair stride: each pair occupies 10 cells `[leading_0 | tip_a(4) | tip_b(4) | trailing_0]`. In order to be able to use copy_5 on both sides. 
-WOTS_PK_PAIR_STRIDE = 2 + 2 * XMSS_DIGEST_LEN -NUM_ENCODING_FE = div_ceil(V, (24 / W)) +PUBLIC_PARAM_LEN = PUBLIC_PARAM_LEN_PLACEHOLDER +TWEAK_LEN = TWEAK_LEN_PLACEHOLDER +PUBKEY_SIZE = DIGEST_LEN + PUBLIC_PARAM_LEN +PUB_KEY_SIZE = PUBKEY_SIZE +SIG_SIZE = RANDOMNESS_LEN + (V + LOG_LIFETIME) * DIGEST_LEN MERKLE_LEVELS_PER_CHUNK = MERKLE_LEVELS_PER_CHUNK_PLACEHOLDER N_MERKLE_CHUNKS = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK +WOTS_PUBKET_SPONGE_DOMAIN_SEP = WOTS_PUBKET_SPONGE_DOMAIN_SEP_PLACEHOLDER INNER_PUB_MEM_SIZE = 2**INNER_PUBLIC_MEMORY_LOG_SIZE # = DIGEST_LEN TWEAK_TABLE_ADDR = PREAMBLE_MEMORY_END -# Tweak table layout: all tweaks are stored as a 4-FE slot [tw[0], tw[1], 0, 0] -TWEAK_LEN = 4 # stride / slot size for non-encoding tweaks +POSEIDON24_CAP = 9 +POSEIDON24_RATE = 15 +CHUNKS_PER_FE = 24 / W # 8 +NUM_ENCODING_FE = div_ceil(V, CHUNKS_PER_FE) # ceil(V/8) +Q = 127 # Rejection parameter: p = Q * CHAIN_LENGTH^CHUNKS_PER_FE + 1 = 127 * 8^8 + 1 + +# All tweaks layout: encoding_tweak(2) + chain_tweaks(V*CHAIN_LENGTH*2) + leaf_tweak(2) + merkle_tweaks(LOG_LIFETIME*2) +N_ALL_TWEAKS = TWEAK_LEN + V * CHAIN_LENGTH * TWEAK_LEN + TWEAK_LEN + LOG_LIFETIME * TWEAK_LEN +CHAIN_TWEAKS_OFFSET = TWEAK_LEN +LEAF_TWEAK_OFFSET = TWEAK_LEN + V * CHAIN_LENGTH * TWEAK_LEN +MERKLE_TWEAKS_OFFSET = TWEAK_LEN + V * CHAIN_LENGTH * TWEAK_LEN + TWEAK_LEN + +# Padded tweak table layout: slot stride 4 (TWEAK_LEN_FE + 2 zeros), see compute_tweak_table in type_1_aggregation.rs. 
+TWEAK_SLOT_SIZE = 4 N_TWEAKS = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME -TWEAK_TABLE_SIZE_FE_PADDED = next_multiple_of(N_TWEAKS * TWEAK_LEN, DIGEST_LEN) -TWEAK_ENCODING_OFFSET = 0 -TWEAK_CHAIN_OFFSET = TWEAK_ENCODING_OFFSET + TWEAK_LEN # just after the encoding tweak -TWEAK_WOTS_PK_OFFSET = TWEAK_CHAIN_OFFSET + V * CHAIN_LENGTH * TWEAK_LEN -TWEAK_MERKLE_OFFSET = TWEAK_WOTS_PK_OFFSET + TWEAK_LEN +TWEAK_TABLE_SIZE_FE_PADDED = next_multiple_of(N_TWEAKS * TWEAK_SLOT_SIZE, DIGEST_LEN) @inline def xmss_verify(pub_key, message, merkle_chunks): - wots = Array(WOTS_SIG_SIZE) - hint_witness("wots", wots) - - public_param = pub_key + XMSS_DIGEST_LEN - randomness = wots - chain_starts = wots + RANDOMNESS_LEN - - # 1) Encode: poseidon16_compress(message[0:8], [randomness(6) | tweak_encoding(2)) - # poseidon16_compress(pre_compressed, [pp(4) | zeros(4)]) - encoding_tweak = TWEAK_TABLE_ADDR + TWEAK_ENCODING_OFFSET - a_input_right = Array(DIGEST_LEN) - copy_6(randomness, a_input_right) - a_input_right[6] = encoding_tweak[0] - a_input_right[7] = encoding_tweak[1] - pre_compressed = Array(DIGEST_LEN) - poseidon16_compress(message, a_input_right, pre_compressed) - - public_params_paded_buff = Array(DIGEST_LEN + 2) # 0 [public_param(4) | zeros(4)] 0 - copy_5(public_param - 1, public_params_paded_buff) - set_to_5_zeros(public_params_paded_buff + 5) - public_params_paded = public_params_paded_buff + 1 - encoding_fe = Array(DIGEST_LEN) - poseidon16_compress(pre_compressed, public_params_paded, encoding_fe) - - # Decompose the encoding into chunks of 2*W bits. Each chunk packs the chain step - # counts of two consecutive WOTS chains: chunk i = step_{2i} + CHAIN_LENGTH * step_{2i+1}. 
- encoding = Array(NUM_ENCODING_FE * 24 / (2 * W)) - - hint_decompose_bits_xmss(encoding, encoding_fe, NUM_ENCODING_FE, 2 * W) - - # check that the decomposition is correct - for i in unroll(0, NUM_ENCODING_FE): - for j in unroll(0, 24 / (2 * W)): - assert encoding[i * (24 / (2 * W)) + j] < CHAIN_LENGTH**2 - - partial_sum: Mut = encoding[i * (24 / (2 * W))] - for j in unroll(1, 24 / (2 * W)): - partial_sum += encoding[i * (24 / (2 * W)) + j] * (CHAIN_LENGTH**2) ** j - - # p = 2^31 - 2^24 + 1 = 127.2^24 + 1, so inv(2^24) = -127 (mod p). - # Deduce remaining_i from partial_sum + remaining_i * 2^24 == encoding_fe[i]: - # remaining_i = (encoding_fe[i] - partial_sum) * inv(2^24) = (partial_sum - encoding_fe[i]) * 127 - remaining_i = (partial_sum - encoding_fe[i]) * 127 - assert remaining_i < 127 # ensures uniformity + prevent overflow + # pub_key layout: merkle_root(DIGEST_LEN) | public_param(PUBLIC_PARAM_LEN) + merkle_root = pub_key + public_param = pub_key + DIGEST_LEN + + # All tweaks live in the preamble memory at TWEAK_TABLE_ADDR (4-FE slots). + # Each non-encoding tweak occupies a TWEAK_SLOT_SIZE=4 slot, but only the + # first TWEAK_LEN=2 elements are read here (slot stride differs from LEN). 
+ encoding_tweak = TWEAK_TABLE_ADDR + chain_tweaks_base = TWEAK_TABLE_ADDR + TWEAK_SLOT_SIZE + leaf_tweak = TWEAK_TABLE_ADDR + TWEAK_SLOT_SIZE * (1 + V * CHAIN_LENGTH) + merkle_tweaks_base = TWEAK_TABLE_ADDR + TWEAK_SLOT_SIZE * (1 + V * CHAIN_LENGTH + 1) + + # XMSS signature: randomness | chain_tips | merkle_path + signature = Array(SIG_SIZE) + hint_witness("xmss_signature", signature) + randomness = signature + chain_starts = signature + RANDOMNESS_LEN + merkle_path = chain_starts + V * DIGEST_LEN + + # 1) Encode: poseidon24_compress_0_9(message(9) || pp(5) || slot(2) || randomness(7) || 0) + enc_rate = Array(15) + copy_5(public_param, enc_rate) + enc_rate[5] = encoding_tweak[0] + enc_rate[6] = encoding_tweak[1] + copy_7(randomness, enc_rate + 7) + enc_rate[14] = 0 + + encoding_fe = Array(POSEIDON24_CAP) + poseidon24_compress_0_9(message, enc_rate, encoding_fe) + + # 2) Decompose encoding_fe into chain indices (only first NUM_ENCODING_FE elements) + encoding = Array(NUM_ENCODING_FE * CHUNKS_PER_FE) + remaining = Array(NUM_ENCODING_FE) + hint_decompose_bits_xmss(encoding, remaining, encoding_fe, NUM_ENCODING_FE, W) - debug_assert(V % 2 == 0) - wots_public_key = Array((V / 2) * WOTS_PK_PAIR_STRIDE) + for i in unroll(0, NUM_ENCODING_FE): + for j in unroll(0, CHUNKS_PER_FE): + assert encoding[i * CHUNKS_PER_FE + j] < CHAIN_LENGTH + assert remaining[i] < Q + partial_sum: Mut = remaining[i] + for j in unroll(0, CHUNKS_PER_FE): + partial_sum += encoding[i * CHUNKS_PER_FE + j] * (Q * CHAIN_LENGTH ** j) + assert partial_sum == encoding_fe[i] + + # 3) Chain hashing with Poseidon16 + pre-computed tweaks target_sum: Mut = 0 - for i in unroll(0, V / 2): - chain_start_a = chain_starts + (2 * i) * XMSS_DIGEST_LEN - chain_start_b = chain_starts + (2 * i + 1) * XMSS_DIGEST_LEN - chain_end_a = wots_public_key + i * WOTS_PK_PAIR_STRIDE + 1 - chain_end_b = chain_end_a + XMSS_DIGEST_LEN - tweaks_a = TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i) * CHAIN_LENGTH * TWEAK_LEN - tweaks_b 
= TWEAK_TABLE_ADDR + TWEAK_CHAIN_OFFSET + (2 * i + 1) * CHAIN_LENGTH * TWEAK_LEN - pair_sum_ptr = Array(1) + wots_public_key = Array(V * DIGEST_LEN) + for i in unroll(0, V): + chain_start = chain_starts + i * DIGEST_LEN + chain_end = wots_public_key + i * DIGEST_LEN + enc_val_ptr = encoding + i + chain_sum_ptr = Array(1) match_range( - encoding[i], - range(0, CHAIN_LENGTH**2), - lambda n: chain_hash_pair( - chain_start_a, - chain_start_b, - n, - chain_end_a, - chain_end_b, - tweaks_a, - tweaks_b, - public_params_paded, - pair_sum_ptr, - ), + enc_val_ptr[0], range(0, CHAIN_LENGTH), + lambda n: chain_hash(chain_start, n, chain_end, chain_sum_ptr, public_param, chain_tweaks_base, i) ) - target_sum += pair_sum_ptr[0] + target_sum += chain_sum_ptr[0] assert target_sum == TARGET_SUM - merkle_leaf = wots_pk_hash(wots_public_key, public_param) + # 4) WOTS PK hash with Poseidon24 sponge (parameter + leaf_tweak prefix) + wots_pk_hashed = wots_pk_hash_p24(wots_public_key, public_param, leaf_tweak) + + # 5) Merkle verify with Poseidon24 + pre-computed tweaks + xmss_merkle_verify_p24(wots_pk_hashed, merkle_path, merkle_chunks, merkle_root, public_param, merkle_tweaks_base) - merkle_tweaks = TWEAK_TABLE_ADDR + TWEAK_MERKLE_OFFSET - xmss_merkle_verify(merkle_leaf, merkle_chunks, pub_key, public_param, merkle_tweaks) return @inline -def chain_hash_pa(input, n, output, chain_i_tweaks, chain_right): - starting_step = CHAIN_LENGTH - 1 - n - if n == 1: - first_tweak = chain_i_tweaks + starting_step * TWEAK_LEN - poseidon16_compress_half_hardcoded_left(input, chain_right, output, first_tweak) - else: - digests = Array(n * XMSS_DIGEST_LEN) - - # Hash 0: input → digests[0..4] - first_tweak = chain_i_tweaks + starting_step * TWEAK_LEN - poseidon16_compress_half_hardcoded_left(input, chain_right, digests, first_tweak) - - # Hashes 1..n-2: digests[(j-1)*4..j*4] → digests[j*4..(j+1)*4] - for j in unroll(1, n - 1): - cur_tweak = chain_i_tweaks + (starting_step + j) * TWEAK_LEN - 
poseidon16_compress_half_hardcoded_left( - digests + (j - 1) * XMSS_DIGEST_LEN, - chain_right, - digests + j * XMSS_DIGEST_LEN, - cur_tweak, - ) - - # Final hash: digests[(n-2)*4..(n-1)*4] → output - last_tweak = chain_i_tweaks + (starting_step + n - 1) * TWEAK_LEN - poseidon16_compress_half_hardcoded_left(digests + (n - 2) * XMSS_DIGEST_LEN, chain_right, output, last_tweak) - return +def make_chain_right(public_param, chain_tweaks, chain_index, step): + right = Array(DIGEST_LEN) + # chain_tweaks lives in the 4-FE-stride padded tweak table. + tweak_idx = (chain_index * CHAIN_LENGTH + step) * TWEAK_SLOT_SIZE + copy_5(public_param, right) + right[5] = chain_tweaks[tweak_idx] + right[6] = chain_tweaks[tweak_idx + 1] + right[7] = 0 + return right @inline -def chain_hash_pair( - input_a, - input_b, - n, - output_a, - output_b, - tweaks_a, - tweaks_b, - chain_right, - pair_sum_ptr, -): - # Pair-encoded chain hash. `n` is a compile-time constant in [0, CHAIN_LENGTH^2) - raw_a = n % CHAIN_LENGTH - raw_b = (n - raw_a) / CHAIN_LENGTH - num_hashes_a = (CHAIN_LENGTH - 1) - raw_a - num_hashes_b = (CHAIN_LENGTH - 1) - raw_b - - if num_hashes_a == 0: - copy_5(input_a - 1, output_a - 1) +def chain_hash(input_ptr, n, output_ptr, chain_sum_ptr, public_param, chain_tweaks, chain_index): + num_hashes = (CHAIN_LENGTH - 1) - n + start_step = n + 1 + + if num_hashes == 0: + copy_8(input_ptr, output_ptr) + elif num_hashes == 1: + right = make_chain_right(public_param, chain_tweaks, chain_index, start_step) + poseidon16_compress(input_ptr, right, output_ptr) else: - chain_hash_pa(input_a, num_hashes_a, output_a, tweaks_a, chain_right) - - if num_hashes_b == 0: - copy_5(input_b, output_b) - else: - chain_hash_pa(input_b, num_hashes_b, output_b, tweaks_b, chain_right) - - pair_sum_ptr[0] = raw_a + raw_b + states = Array((num_hashes - 1) * DIGEST_LEN) + right0 = make_chain_right(public_param, chain_tweaks, chain_index, start_step) + poseidon16_compress(input_ptr, right0, states) + for j in 
unroll(1, num_hashes - 1): + right_j = make_chain_right(public_param, chain_tweaks, chain_index, start_step + j) + poseidon16_compress(states + (j - 1) * DIGEST_LEN, right_j, states + j * DIGEST_LEN) + right_last = make_chain_right(public_param, chain_tweaks, chain_index, start_step + num_hashes - 1) + poseidon16_compress(states + (num_hashes - 2) * DIGEST_LEN, right_last, output_ptr) + + chain_sum_ptr[0] = n return @inline -def wots_pk_hash(wots_public_key, public_param): - N_CHUNKS = V / 2 - states = Array((N_CHUNKS + 1) * DIGEST_LEN) - poseidon16_compress_hardcoded_left(public_param, ZERO_VEC_PTR, states, TWEAK_TABLE_ADDR + TWEAK_WOTS_PK_OFFSET) - for i in unroll(0, N_CHUNKS): - poseidon16_compress( - states + i * DIGEST_LEN, - wots_public_key + i * WOTS_PK_PAIR_STRIDE + 1, - states + (i + 1) * DIGEST_LEN, - ) - - return states + N_CHUNKS * DIGEST_LEN +def wots_pk_hash_p24(wots_pk, public_param, leaf_tweak): + # Sponge input: parameter(5) | leaf_tweak(2) | chain_ends(V*8) + PREFIX_LEN = PUBLIC_PARAM_LEN + TWEAK_LEN # 7 + capacity: Mut = Array(POSEIDON24_CAP) + for i in unroll(0, POSEIDON24_CAP): + capacity[i] = WOTS_PUBKET_SPONGE_DOMAIN_SEP[i] + # First chunk: parameter(5) | leaf_tweak(2) | wots_pk[0..8] + first_rate = Array(POSEIDON24_RATE) + copy_5(public_param, first_rate) + first_rate[5] = leaf_tweak[0] + first_rate[6] = leaf_tweak[1] + copy_8(wots_pk, first_rate + PREFIX_LEN) + new_capacity = Array(POSEIDON24_CAP) + poseidon24_permute_0_9(capacity, first_rate, new_capacity) + capacity = new_capacity + # Remaining data: wots_pk[8..] 
= V*DIGEST_LEN - 8 elements + WOTS_PK_OFFSET = POSEIDON24_RATE - PREFIX_LEN # 8 + REMAINING = V * DIGEST_LEN - WOTS_PK_OFFSET + REMAINDER = REMAINING % POSEIDON24_RATE + N_FULL_STEPS = div_floor(REMAINING, POSEIDON24_RATE) + for step in unroll(0, N_FULL_STEPS): + src = wots_pk + WOTS_PK_OFFSET + step * POSEIDON24_RATE + new_capacity = Array(POSEIDON24_CAP) + if step == N_FULL_STEPS - 1: + if REMAINDER == 0: + poseidon24_permute_9_18(capacity, src, new_capacity) + else: + poseidon24_permute_0_9(capacity, src, new_capacity) + else: + poseidon24_permute_0_9(capacity, src, new_capacity) + capacity = new_capacity + if REMAINDER != 0: + src = wots_pk + WOTS_PK_OFFSET + N_FULL_STEPS * POSEIDON24_RATE + remainder_rate = Array(POSEIDON24_RATE) + for i in unroll(0, REMAINDER): + remainder_rate[i] = src[i] + for i in unroll(REMAINDER, POSEIDON24_RATE): + remainder_rate[i] = 0 + remainder_capacity = Array(POSEIDON24_CAP) + poseidon24_permute_9_18(capacity, remainder_rate, remainder_capacity) + capacity = remainder_capacity + return capacity @inline -def set_buf_prefix_right(buf, public_param): - # Writes [pp(4)] to buf[0..4] — the RIGHT-input prefix. 
- for k in unroll(0, PP_IN_LEFT): - buf[k] = public_param[k] +def xmss_merkle_verify_p24(leaf_digest, merkle_path, merkle_chunks, expected_root, public_param, merkle_tweaks): + states = Array((N_MERKLE_CHUNKS - 1) * DIGEST_LEN) + + match_range(merkle_chunks[0], range(0, 16), lambda b: + do_4_merkle_levels_p24(b, leaf_digest, merkle_path, states, public_param, merkle_tweaks, 0)) + + for j in unroll(1, N_MERKLE_CHUNKS - 1): + match_range(merkle_chunks[j], range(0, 16), lambda b: + do_4_merkle_levels_p24( + b, states + (j - 1) * DIGEST_LEN, + merkle_path + j * MERKLE_LEVELS_PER_CHUNK * DIGEST_LEN, + states + j * DIGEST_LEN, + public_param, merkle_tweaks, j * MERKLE_LEVELS_PER_CHUNK)) + + match_range(merkle_chunks[N_MERKLE_CHUNKS - 1], range(0, 16), lambda b: + do_4_merkle_levels_p24( + b, states + (N_MERKLE_CHUNKS - 2) * DIGEST_LEN, + merkle_path + (N_MERKLE_CHUNKS - 1) * MERKLE_LEVELS_PER_CHUNK * DIGEST_LEN, + expected_root, + public_param, merkle_tweaks, (N_MERKLE_CHUNKS - 1) * MERKLE_LEVELS_PER_CHUNK)) return @inline -def do_4_merkle_levels(b, state_in, state_out, public_param, merkle_tweaks_chunk): +def do_4_merkle_levels_p24(b, state_in, path_chunk, state_out, public_param, merkle_tweaks, base_level): b0 = b % 2 r1 = (b - b0) / 2 b1 = r1 % 2 @@ -210,90 +224,41 @@ def do_4_merkle_levels(b, state_in, state_out, public_param, merkle_tweaks_chunk r3 = (r2 - b2) / 2 b3 = r3 % 2 - buf0_alloc = Array(XMSS_DIGEST_LEN * 2 + 2) - buf0 = buf0_alloc + 1 - if b0 == 1: - # state_in is the LEFT child → state_in[0..4] lands at buf0[0..4]. - copy_5(state_in - 1, buf0 - 1) - hint_witness("xmss_merkle_node", buf0 + XMSS_DIGEST_LEN) - else: - # state_in is the RIGHT child → state_in[0..4] lands at buf0[4..8]. 
- hint_witness("xmss_merkle_node", buf0) - copy_5(state_in, buf0 + XMSS_DIGEST_LEN) - - # Level 0 hash - buf1 = Array(XMSS_DIGEST_LEN * 2) - if b1 == 1: - poseidon16_compress_half_hardcoded_left(public_param, buf0, buf1, merkle_tweaks_chunk) - hint_witness("xmss_merkle_node", buf1 + XMSS_DIGEST_LEN) - else: - poseidon16_compress_half_hardcoded_left(public_param, buf0, buf1 + XMSS_DIGEST_LEN, merkle_tweaks_chunk) - hint_witness("xmss_merkle_node", buf1) - - # Level 1 hash → buf2 - buf2 = Array(XMSS_DIGEST_LEN * 2) - if b2 == 1: - poseidon16_compress_half_hardcoded_left(public_param, buf1, buf2, merkle_tweaks_chunk + 1 * TWEAK_LEN) - hint_witness("xmss_merkle_node", buf2 + XMSS_DIGEST_LEN) - else: - poseidon16_compress_half_hardcoded_left( - public_param, buf1, buf2 + XMSS_DIGEST_LEN, merkle_tweaks_chunk + 1 * TWEAK_LEN - ) - hint_witness("xmss_merkle_node", buf2) - - # Level 2 hash → buf3 - buf3 = Array(XMSS_DIGEST_LEN * 2) - if b3 == 1: - poseidon16_compress_half_hardcoded_left(public_param, buf2, buf3, merkle_tweaks_chunk + 2 * TWEAK_LEN) - hint_witness("xmss_merkle_node", buf3 + XMSS_DIGEST_LEN) - else: - poseidon16_compress_half_hardcoded_left( - public_param, buf2, buf3 + XMSS_DIGEST_LEN, merkle_tweaks_chunk + 2 * TWEAK_LEN - ) - hint_witness("xmss_merkle_node", buf3) + temps = Array(3 * DIGEST_LEN) - poseidon16_compress_half_hardcoded_left(public_param, buf3, state_out, merkle_tweaks_chunk + 3 * TWEAK_LEN) + merkle_p24_one_level(b0, state_in, path_chunk, temps, public_param, merkle_tweaks, base_level) + merkle_p24_one_level(b1, temps, path_chunk + DIGEST_LEN, temps + DIGEST_LEN, public_param, merkle_tweaks, base_level + 1) + merkle_p24_one_level(b2, temps + DIGEST_LEN, path_chunk + 2 * DIGEST_LEN, temps + 2 * DIGEST_LEN, public_param, merkle_tweaks, base_level + 2) + merkle_p24_one_level(b3, temps + 2 * DIGEST_LEN, path_chunk + 3 * DIGEST_LEN, state_out, public_param, merkle_tweaks, base_level + 3) return @inline -def xmss_merkle_verify(leaf_digest, 
merkle_chunks, expected_root, public_param, merkle_tweaks): - states_alloc = Array(DIM * N_MERKLE_CHUNKS) - states = states_alloc + 1 - - # First chunk - match_range( - merkle_chunks[0], - range(0, 16), - lambda b: do_4_merkle_levels(b, leaf_digest, states, public_param, merkle_tweaks), - ) - - state_indexes = Array(N_MERKLE_CHUNKS - 1) - state_indexes[0] = states - for j in unroll(1, N_MERKLE_CHUNKS - 1): - state_indexes[j] = state_indexes[j - 1] + DIM - match_range( - merkle_chunks[j], - range(0, 16), - lambda b: do_4_merkle_levels( - b, - state_indexes[j - 1], - state_indexes[j], - public_param, - merkle_tweaks + j * MERKLE_LEVELS_PER_CHUNK * TWEAK_LEN, - ), - ) +def merkle_p24_one_level(is_left_bit, current, neighbour, output, public_param, merkle_tweaks, child_level): + # merkle_tweaks lives in the 4-FE-stride padded tweak table. + tweak_ptr = merkle_tweaks + child_level * TWEAK_SLOT_SIZE + + input_buf = Array(24) + copy_5(public_param, input_buf) + input_buf[5] = tweak_ptr[0] + input_buf[6] = tweak_ptr[1] + if is_left_bit == 0: + copy_8(neighbour, input_buf + 7) + copy_8(current, input_buf + 15) + else: + copy_8(current, input_buf + 7) + copy_8(neighbour, input_buf + 15) + input_buf[23] = 0 + + merkle_output = Array(POSEIDON24_CAP) + poseidon24_compress_0_9(input_buf, input_buf + 9, merkle_output) + for k in unroll(0, DIGEST_LEN): + output[k] = merkle_output[k] + return - # last chunk → write directly to expected_root - match_range( - merkle_chunks[N_MERKLE_CHUNKS - 1], - range(0, 16), - lambda b: do_4_merkle_levels( - b, - state_indexes[N_MERKLE_CHUNKS - 2], - expected_root, - public_param, - merkle_tweaks + (N_MERKLE_CHUNKS - 1) * MERKLE_LEVELS_PER_CHUNK * TWEAK_LEN, - ), - ) + +@inline +def copy_7(x, y): + dot_product_ee(x, ONE_EF_PTR, y) + dot_product_ee(x + (7 - DIM), ONE_EF_PTR, y + (7 - DIM)) return diff --git a/crates/sub_protocols/Cargo.toml b/crates/sub_protocols/Cargo.toml index 38c8663a6..47e50f50a 100644 --- a/crates/sub_protocols/Cargo.toml +++ 
b/crates/sub_protocols/Cargo.toml @@ -9,6 +9,7 @@ workspace = true [dependencies] tracing.workspace = true utils.workspace = true +parallel.workspace = true lean_vm.workspace = true backend.workspace = true diff --git a/crates/sub_protocols/src/air_sumcheck.rs b/crates/sub_protocols/src/air_sumcheck.rs index 0f536d7fa..eb9ca0a02 100644 --- a/crates/sub_protocols/src/air_sumcheck.rs +++ b/crates/sub_protocols/src/air_sumcheck.rs @@ -89,22 +89,20 @@ where let _span = info_span!("chunk-bit-reversing columns").entered(); let chunk_size = 1usize << pivot; let shift = usize::BITS as usize - pivot; - let bit_reversed = cols - .par_iter() - .map(|&src| { - let mut dst: Vec> = unsafe { uninitialized_vec(src.len()) }; - let src_u = PFPacking::::unpack_slice(src); - let dst_u = PFPacking::::unpack_slice_mut(&mut dst); - for (src_chunk, dst_chunk) in - src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) - { - for (p, slot) in dst_chunk.iter_mut().enumerate() { - *slot = src_chunk[p.reverse_bits() >> shift]; - } + let mut bit_reversed: Vec>> = vec![ArenaVec::new(); cols.len()]; + parallel::par_chunks_mut(&mut bit_reversed, 1, |i, out_slot| { + let src = cols[i]; + let mut dst: ArenaVec> = unsafe { ArenaVec::uninitialized(src.len()) }; + let src_u = PFPacking::::unpack_slice(src); + let dst_u = PFPacking::::unpack_slice_mut(&mut dst); + for (src_chunk, dst_chunk) in src_u.chunks_exact(chunk_size).zip(dst_u.chunks_exact_mut(chunk_size)) + { + for (p, slot) in dst_chunk.iter_mut().enumerate() { + *slot = src_chunk[p.reverse_bits() >> shift]; } - dst - }) - .collect(); + } + out_slot[0] = dst; + }); MleGroup::Owned(MleGroupOwned::BasePacked(bit_reversed)) } _ => unreachable!(), @@ -438,120 +436,112 @@ where let hi_zs_halved: Vec<_> = hi_zs.iter().map(|&tz| tz.halve()).collect(); let lagrange_coeffs = lagrange_basis_evals(&low_zs, &hi_zs); - let acc = (0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFPacking::::ZERO; degree], - 
Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - vec![EFPacking::::ZERO; n_full], - Vec::::new(), - Vec::::new(), - Vec::::new(), - ) - }, - |(mut acc, mut point, mut diff, mut low_evals, mut state_0, mut state_2, mut cached_buf), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - - // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. - // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, - // so advancing z by 1 means `point[k] += diff[k]` for all k. - point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || { + ( + Vec::::with_capacity(n_cols), + Vec::::with_capacity(n_cols), + vec![EFPacking::::ZERO; n_full], + Vec::::new(), + Vec::::new(), + Vec::::new(), + ) + }, + || vec![EFPacking::::ZERO; degree], + |(point, diff, low_evals, state_0, state_2, cached_buf), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + + // `point` holds column values at z=0; `diff[k] = col_k[i1] - col_k[i0]`. + // Invariant for the rest of this closure: `col_k(z) = point[k] + z · diff[k]`, + // so advancing z by 1 means `point[k] += diff[k]` for all k. + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } - // Phase 1: full AIR constraints + // Phase 1: full AIR constraints - // z = 0: full eval, capture post-block state. 
- { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_0); - Air::eval(computation, &mut folder, extra_data); - acc[0] += folder.accumulator * partial_eq; - low_evals[0] = folder.accumulator_low; - state_0 = folder.cached_state.unwrap(); - } + // z = 0: full eval, capture post-block state. + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_0)); + Air::eval(computation, &mut folder, extra_data); + acc[0] += folder.accumulator * partial_eq; + low_evals[0] = folder.accumulator_low; + *state_0 = folder.cached_state.unwrap(); + } - // z = 2: advance `point` by 2·diff, full eval, capture post-block state. - // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + // z = 2: advance `point` by 2·diff, full eval, capture post-block state. + // Together with `state_0` this pins down the linear `state(z)` (linear when we "omit" the low degree constraints of the block) + for k in 0..n_cols { + point[k] += diff[k].double(); + } + { + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.cached_state = Some(std::mem::take(state_2)); + Air::eval(computation, &mut folder, extra_data); + acc[1] += folder.accumulator * partial_eq; + low_evals[1] = folder.accumulator_low; + *state_2 = folder.cached_state.unwrap(); + } + + // z = 3, …, d_low+1: still doing full eval + for z_idx in 2..n_full { for k in 0..n_cols { - point[k] += diff[k].double(); - } - { - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.cached_state = Some(state_2); - Air::eval(computation, &mut folder, extra_data); - acc[1] += folder.accumulator * partial_eq; - low_evals[1] = folder.accumulator_low; - state_2 = folder.cached_state.unwrap(); + point[k] += diff[k]; } + 
let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + Air::eval(computation, &mut folder, extra_data); + acc[z_idx] += folder.accumulator * partial_eq; + low_evals[z_idx] = folder.accumulator_low; + } - // z = 3, …, d_low+1: still doing full eval - for z_idx in 2..n_full { - for k in 0..n_cols { - point[k] += diff[k]; - } - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - Air::eval(computation, &mut folder, extra_data); - acc[z_idx] += folder.accumulator * partial_eq; - low_evals[z_idx] = folder.accumulator_low; + // Phase 2: skip the low degree constraints of the block + // For each skipped point, assemble Constraints(z) = high(z) + low(z): + // -high(z): run folder with `skip_low = true` + // -low(z): deduce it via Lagrange-interpolation from previous computations + for t in 0..n_skip { + for k in 0..n_cols { + point[k] += diff[k]; } - // Phase 2: skip the low degree constraints of the block - // For each skipped point, assemble Constraints(z) = high(z) + low(z): - // -high(z): run folder with `skip_low = true` - // -low(z): deduce it via Lagrange-interpolation from previous computations - for t in 0..n_skip { - for k in 0..n_cols { - point[k] += diff[k]; - } - - cached_buf.clear(); - for i in 0..state_0.len() { - cached_buf - .push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); - } - - let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); - folder.skip_low = true; - folder.cached_state = Some(cached_buf); - folder.low_ci_count = low_n_constraints; - Air::eval(computation, &mut folder, extra_data); - cached_buf = folder.cached_state.unwrap(); - - // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) - let mut low_interpolated = EFPacking::::ZERO; - for (i, lc) in lagrange_coeffs[t].iter().enumerate() { - low_interpolated += low_evals[i] * PFPacking::::from(*lc); - } - - acc[n_full + t] += 
(folder.accumulator + low_interpolated) * partial_eq; + cached_buf.clear(); + for i in 0..state_0.len() { + cached_buf.push(state_0[i] + (state_2[i] - state_0[i]) * PFPacking::::from(hi_zs_halved[t])); } - (acc, point, diff, low_evals, state_0, state_2, cached_buf) - }, - ) - .map(|(acc, ..)| acc) - .reduce( - || vec![EFPacking::::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; + let mut folder = ConstraintFolderPacked::new(&point[..n_flat], &point[n_flat..], extra_data); + folder.skip_low = true; + folder.cached_state = Some(std::mem::take(cached_buf)); + folder.low_ci_count = low_n_constraints; + Air::eval(computation, &mut folder, extra_data); + *cached_buf = folder.cached_state.unwrap(); + + // low(hi_zs[t]) = Σ_i L_i(hi_zs[t]) · low(low_zs[i]) + let mut low_interpolated = EFPacking::::ZERO; + for (i, lc) in lagrange_coeffs[t].iter().enumerate() { + low_interpolated += low_evals[i] * PFPacking::::from(*lc); } - a - }, - ); + + acc[n_full + t] += (folder.accumulator + low_interpolated) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(&unpack_sum).collect() } @@ -581,54 +571,43 @@ where let stride = 1usize << fold_bit; let lo_mask = stride - 1; - let acc = (0..active_count_pairs) - .into_par_iter() - .fold( - || { - ( - vec![EFT::ZERO; degree], - Vec::::with_capacity(n_cols), - Vec::::with_capacity(n_cols), - ) - }, - |(mut acc, mut point, mut diff), new_j| { - let i_hi = new_j >> fold_bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (fold_bit + 1)) | i_lo; - let i1 = i0 | stride; - let partial_eq = get_split_eq(new_j); - point.clear(); - diff.clear(); - for c in cols { - let lo = c[i0]; - let hi = c[i1]; - point.push(lo); - diff.push(hi - lo); - } - // z = 0 then (skip z = 1) z = 2, 3, …, degree. 
- acc[0] += eval_fn(computation, &point, extra_data) * partial_eq; + let acc = parallel::map_reduce_with_state( + active_count_pairs, + || (Vec::::with_capacity(n_cols), Vec::::with_capacity(n_cols)), + || vec![EFT::ZERO; degree], + |(point, diff), acc, new_j| { + let i_hi = new_j >> fold_bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (fold_bit + 1)) | i_lo; + let i1 = i0 | stride; + let partial_eq = get_split_eq(new_j); + point.clear(); + diff.clear(); + for c in cols { + let lo = c[i0]; + let hi = c[i1]; + point.push(lo); + diff.push(hi - lo); + } + // z = 0 then (skip z = 1) z = 2, 3, …, degree. + acc[0] += eval_fn(computation, point, extra_data) * partial_eq; + for k in 0..n_cols { + point[k] += diff[k]; + } + for acc_z in &mut acc[1..] { for k in 0..n_cols { point[k] += diff[k]; } - for acc_z in &mut acc[1..] { - for k in 0..n_cols { - point[k] += diff[k]; - } - *acc_z += eval_fn(computation, &point, extra_data) * partial_eq; - } - (acc, point, diff) - }, - ) - .map(|(acc, _, _)| acc) - .reduce( - || vec![EFT::ZERO; degree], - |mut a, b| { - for i in 0..degree { - a[i] += b[i]; - } - a - }, - ); + *acc_z += eval_fn(computation, point, extra_data) * partial_eq; + } + }, + |mut a, b| { + for i in 0..degree { + a[i] += b[i]; + } + a + }, + ); acc.into_iter().map(unpack_sum).collect() } @@ -678,17 +657,17 @@ pub fn prove_batched_air_sumcheck<'a, EF: ExtensionField>>( MultilinearPoint(challenges) } -pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { +pub fn compute_shifted_columns(n_shift_columns: usize, columns: &[&[F]]) -> Vec> { // Convention: the first `n_shift_columns` columns are the ones that get shifted. 
- columns[..n_shift_columns] - .par_iter() - .map(|column| { - let mut shifted = unsafe { uninitialized_vec(column.len()) }; - shifted[..column.len() - 1].copy_from_slice(&column[1..]); - shifted[column.len() - 1] = column[column.len() - 1]; - shifted - }) - .collect() + let mut out: Vec> = (0..n_shift_columns).map(|_| ArenaVec::new()).collect(); + parallel::par_chunks_mut(&mut out, 1, |i, slot| { + let column = columns[i]; + let mut shifted = unsafe { ArenaVec::::uninitialized(column.len()) }; + shifted[..column.len() - 1].copy_from_slice(&column[1..]); + shifted[column.len() - 1] = column[column.len() - 1]; + slot[0] = shifted; + }); + out } pub fn natural_ordering_point_for_session(sumcheck_air_point: &[EF], log_n_rows: usize) -> Vec { diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 55af0a320..e680d273e 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -1,10 +1,11 @@ use crate::{ENDIANNESS_PIVOT_GKR, prove_gkr_quotient, verify_gkr_quotient}; use backend::*; use lean_vm::*; +use parallel::par_fill; use std::collections::BTreeMap; use tracing::instrument; +use utils::VarCount; use utils::ansi::Colorize; -use utils::*; #[derive(Debug, PartialEq, Hash, Clone)] pub struct GenericLogupStatements { @@ -48,9 +49,10 @@ pub fn prove_generic_logup( &tables_log_heights_sorted, ); let total_gkr_n_vars = log2_ceil_usize(total_active_len); - let mut numerators: Vec = unsafe { uninitialized_vec(total_active_len) }; + let mut numerators: ArenaVec = unsafe { ArenaVec::::uninitialized(total_active_len) }; let width = packing_width::(); - let mut denominators: Vec> = unsafe { uninitialized_vec(total_active_len / width) }; + let mut denominators: ArenaVec> = + unsafe { ArenaVec::>::uninitialized(total_active_len / width) }; let c_packed = EFPacking::::from(c); let alphas_packed: Vec> = alphas_eq_poly.iter().map(|a| EFPacking::::from(*a)).collect(); let memory_domainsep_packed = 
PFPacking::::from(F::from_usize(LOGUP_MEMORY_DOMAINSEP)); @@ -72,15 +74,13 @@ pub fn prove_generic_logup( }; let fill_num_from = |dst: &mut [F], src: &[F], neg: bool| { - dst.par_chunks_exact_mut(chunk_size) - .enumerate() - .for_each(|(c, dst_chunk)| { - let src_chunk = &src[c * chunk_size..][..chunk_size]; - for (i, slot) in dst_chunk.iter_mut().enumerate() { - let v = src_chunk[i.reverse_bits() >> chunk_shift]; - *slot = if neg { -v } else { v }; - } - }); + parallel::par_chunks_mut(dst, chunk_size, |c, dst_chunk| { + let src_chunk = &src[c * chunk_size..][..chunk_size]; + for (i, slot) in dst_chunk.iter_mut().enumerate() { + let v = src_chunk[i.reverse_bits() >> chunk_shift]; + *slot = if neg { -v } else { v }; + } + }); }; let mut offset = 0; @@ -88,7 +88,7 @@ pub fn prove_generic_logup( // Memory section. assert_eq!(memory.len(), memory_acc.len()); fill_num_from(&mut numerators[offset..][..memory.len()], memory_acc, true); - fill_denoms(&mut denominators[offset / width..][..memory.len() / width], |p| { + par_fill(&mut denominators[offset / width..][..memory.len() / width], |p| { c_packed - finger_print_packed::( memory_domainsep_packed, @@ -105,7 +105,7 @@ pub fn prove_generic_logup( assert_eq!(1 << log_bytecode, bytecode_acc.len()); fill_num_from(&mut numerators[offset..][..bytecode_acc.len()], bytecode_acc, true); let bytecode_stride = N_INSTRUCTION_COLUMNS.next_power_of_two(); - fill_denoms( + par_fill( &mut denominators[offset / width..][..(1 << log_bytecode) / width], |p| { let mut data = [PFPacking::::ZERO; N_INSTRUCTION_COLUMNS + 1]; @@ -118,12 +118,14 @@ pub fn prove_generic_logup( ); if 1 << log_bytecode < max_table_height { // padding - numerators[offset + (1 << log_bytecode)..offset + max_table_height] - .par_iter_mut() - .for_each(|n| *n = F::ZERO); - denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width] - .par_iter_mut() - .for_each(|d| *d = EFPacking::::ONE); + par_fill( + &mut numerators[offset + (1 << 
log_bytecode)..offset + max_table_height], + |_| F::ZERO, + ); + par_fill( + &mut denominators[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width], + |_| EFPacking::::ONE, + ); } offset += max_table_height.max(1 << log_bytecode); @@ -142,17 +144,15 @@ pub fn prove_generic_logup( let col_index = &trace.columns[group.idx_col]; let packed_chunk_size = (1 << log_n_rows) / width; - numerators[offset..][..group_len << log_n_rows] - .par_iter_mut() - .for_each(|n| *n = F::ONE); + par_fill(&mut numerators[offset..][..group_len << log_n_rows], |_| F::ONE); - denominators[offset / width..][..group_len * packed_chunk_size] - .par_chunks_exact_mut(packed_chunk_size) - .enumerate() - .for_each(|(i, denom_chunk)| { + parallel::par_chunks_mut( + &mut denominators[offset / width..][..group_len * packed_chunk_size], + packed_chunk_size, + |i, denom_chunk| { let i_field = F::from_usize(i); let col_value = &trace.columns[group.value_cols[i]]; - denom_chunk.par_iter_mut().enumerate().for_each(|(p, slot)| { + for (p, slot) in denom_chunk.iter_mut().enumerate() { *slot = c_packed - finger_print_packed::( memory_domainsep_packed, @@ -162,8 +162,9 @@ pub fn prove_generic_logup( ], &alphas_packed, ); - }); - }); + } + }, + ); offset += group_len << log_n_rows; bus_idx += group_len; next_group += 1; @@ -175,7 +176,7 @@ pub fn prove_generic_logup( match bus.multiplicity { BusMultiplicity::One => { let val = bus.direction.to_field_flag(); - slice.par_iter_mut().for_each(|n| *n = val); + par_fill(slice, |_| val); } BusMultiplicity::Column(col) => { fill_num_from(slice, &trace.columns[col], matches!(bus.direction, BusDirection::Pull)); @@ -204,7 +205,7 @@ pub fn prove_generic_logup( _ => PFPacking::::ZERO, }; - fill_denoms(denom_slot, |p| { + par_fill(denom_slot, |p| { let mut data_buf = [PFPacking::::ZERO; MAX_BUS_WIDTH]; for k in 0..n_data { let col = data_cols[k]; @@ -526,11 +527,3 @@ fn compute_total_active_len( .map(|(table, log_n_rows)| offset_for_table(table, 
*log_n_rows)) .sum::() } - -#[inline] -fn fill_denoms(dst: &mut [EFPacking], build: Build) -where - Build: Fn(usize) -> EFPacking + Sync, -{ - dst.par_iter_mut().enumerate().for_each(|(p, slot)| *slot = build(p)); -} diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index 0ff9e1663..7a19da588 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -1,22 +1,21 @@ use backend::PackedValue; -use std::borrow::Cow; use backend::*; pub(super) enum LayerStorage<'a, EF: ExtensionField>> { Initial { - nums: Cow<'a, [PFPacking]>, - dens: Cow<'a, [EFPacking]>, + nums: ArenaCow<'a, PFPacking>, + dens: ArenaCow<'a, EFPacking>, chunk_log: usize, }, PackedBr { - nums: Cow<'a, [EFPacking]>, - dens: Cow<'a, [EFPacking]>, + nums: ArenaCow<'a, EFPacking>, + dens: ArenaCow<'a, EFPacking>, chunk_log: usize, }, Natural { - nums: Cow<'a, [EF]>, - dens: Cow<'a, [EF]>, + nums: ArenaCow<'a, EF>, + dens: ArenaCow<'a, EF>, }, } @@ -24,24 +23,24 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { pub(super) fn convert_to_natural(&self) -> Self { match self { Self::Initial { nums, dens, chunk_log } => { - let n_nat_base: Vec = unpack_base_and_unreverse_active::(nums.as_ref(), *chunk_log); + let n_nat_base: ArenaVec = unpack_base_and_unreverse_active::(nums.as_ref(), *chunk_log); let d_nat = unpack_and_unreverse_active::(dens.as_ref(), *chunk_log); Self::Natural { - nums: Cow::Owned(n_nat_base), - dens: Cow::Owned(d_nat), + nums: ArenaCow::Owned(n_nat_base), + dens: ArenaCow::Owned(d_nat), } } Self::PackedBr { nums, dens, chunk_log } => { let n_nat = unpack_and_unreverse_active::(nums.as_ref(), *chunk_log); let d_nat = unpack_and_unreverse_active::(dens.as_ref(), *chunk_log); Self::Natural { - nums: Cow::Owned(n_nat), - dens: Cow::Owned(d_nat), + nums: ArenaCow::Owned(n_nat), + dens: ArenaCow::Owned(d_nat), } } Self::Natural { nums, dens } => Self::Natural { - nums: 
Cow::Owned(nums.to_vec()), - dens: Cow::Owned(dens.to_vec()), + nums: ArenaCow::Owned(ArenaVec::from_slice(nums.as_ref())), + dens: ArenaCow::Owned(ArenaVec::from_slice(dens.as_ref())), }, } } @@ -52,8 +51,8 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { let (new_nums, new_dens) = sum_quotients_2_by_2_packed_br::(nums.as_ref(), dens.as_ref(), *chunk_log); Self::PackedBr { - nums: Cow::Owned(new_nums), - dens: Cow::Owned(new_dens), + nums: ArenaCow::Owned(new_nums), + dens: ArenaCow::Owned(new_dens), chunk_log: *chunk_log - 1, } } @@ -61,16 +60,16 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { let (new_nums, new_dens) = sum_quotients_2_by_2_packed_br::(nums.as_ref(), dens.as_ref(), *chunk_log); Self::PackedBr { - nums: Cow::Owned(new_nums), - dens: Cow::Owned(new_dens), + nums: ArenaCow::Owned(new_nums), + dens: ArenaCow::Owned(new_dens), chunk_log: *chunk_log - 1, } } Self::Natural { nums, dens } => { let (nn, nd) = sum_quotients_2_by_2(nums.as_ref(), dens.as_ref()); Self::Natural { - nums: Cow::Owned(nn), - dens: Cow::Owned(nd), + nums: ArenaCow::Owned(nn), + dens: ArenaCow::Owned(nd), } } } @@ -84,7 +83,7 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } - pub fn materialise_in_full(self) -> (Vec, Vec) { + pub fn materialise_in_full(self) -> (ArenaVec, ArenaVec) { let natural = match self { Self::Natural { .. 
} => self, other => other.convert_to_natural(), @@ -101,47 +100,46 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } -pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usize) -> Vec { +pub(super) fn bit_reverse_chunks(v: &[T], chunk_log: usize) -> ArenaVec { let n = v.len(); let chunk_size = 1usize << chunk_log; debug_assert!(n.is_multiple_of(chunk_size)); - let mut out: Vec = unsafe { uninitialized_vec(n) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(n) }; if chunk_log == 0 { out.copy_from_slice(v); return out; } let shift = usize::BITS as usize - chunk_log; - out.par_chunks_exact_mut(chunk_size) - .zip(v.par_chunks_exact(chunk_size)) - .for_each(|(dst, src)| { - for (p, slot) in dst.iter_mut().enumerate() { - *slot = src[p.reverse_bits() >> shift]; - } - }); + parallel::par_chunks_mut(&mut out, chunk_size, |c, dst| { + let src = &v[c * chunk_size..][..chunk_size]; + for (p, slot) in dst.iter_mut().enumerate() { + *slot = src[p.reverse_bits() >> shift]; + } + }); out } -fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> (Vec, Vec) { +fn sum_quotients_2_by_2>>(nums: &[EF], dens: &[EF]) -> (ArenaVec, ArenaVec) { assert_eq!(nums.len(), dens.len()); let active_len = nums.len(); let new_active = active_len.div_ceil(2); let full_pairs = active_len / 2; - let mut new_nums: Vec = unsafe { uninitialized_vec(new_active) }; - let mut new_dens: Vec = unsafe { uninitialized_vec(new_active) }; + let mut new_nums: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; + let mut new_dens: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; - new_nums[..full_pairs] - .par_iter_mut() - .zip(new_dens[..full_pairs].par_iter_mut()) - .enumerate() - .for_each(|(i, (num, den))| { + parallel::par_for_each_mut2( + &mut new_nums[..full_pairs], + &mut new_dens[..full_pairs], + |i, num, den| { let n0 = nums[2 * i]; let n1 = nums[2 * i + 1]; let d0 = dens[2 * i]; let d1 = dens[2 * i + 1]; *num = d1 * n0 + d0 * n1; *den = d0 * d1; - }); + }, + ); // 
Boundary (at most one pair: a/b + 0/1 = a/b). if full_pairs < new_active { @@ -156,7 +154,7 @@ fn sum_quotients_2_by_2_packed_br>, N>( nums: &[N], dens: &[EFPacking], chunk_log: usize, -) -> (Vec>, Vec>) +) -> (ArenaVec>, ArenaVec>) where N: Copy + Send + Sync, EFPacking: Algebra, @@ -169,21 +167,17 @@ where let stride = 1usize << bit; let lo_mask = stride - 1; - let mut new_nums: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; - let mut new_dens: Vec> = unsafe { uninitialized_vec(nums.len() >> 1) }; - - new_nums - .par_iter_mut() - .zip(new_dens.par_iter_mut()) - .enumerate() - .for_each(|(new_j, (num_out, den_out))| { - let i_hi = new_j >> bit; - let i_lo = new_j & lo_mask; - let i0 = (i_hi << (bit + 1)) | i_lo; - let i1 = i0 | stride; - *num_out = dens[i1] * nums[i0] + dens[i0] * nums[i1]; - *den_out = dens[i0] * dens[i1]; - }); + let mut new_nums: ArenaVec> = unsafe { ArenaVec::uninitialized(nums.len() >> 1) }; + let mut new_dens: ArenaVec> = unsafe { ArenaVec::uninitialized(nums.len() >> 1) }; + + parallel::par_for_each_mut2(&mut new_nums, &mut new_dens, |new_j, num_out, den_out| { + let i_hi = new_j >> bit; + let i_lo = new_j & lo_mask; + let i0 = (i_hi << (bit + 1)) | i_lo; + let i1 = i0 | stride; + *num_out = dens[i1] * nums[i0] + dens[i0] * nums[i1]; + *den_out = dens[i0] * dens[i1]; + }); (new_nums, new_dens) } @@ -191,11 +185,11 @@ where pub(super) fn unpack_and_unreverse_active>>( v: &[EFPacking], chunk_log: usize, -) -> Vec { - bit_reverse_chunks(&unpack_extension::(v), chunk_log) +) -> ArenaVec { + bit_reverse_chunks(&unpack_extension::>(v), chunk_log) } -fn unpack_base_and_unreverse_active>>(v: &[PFPacking], chunk_log: usize) -> Vec { - let active_unpacked: Vec = PFPacking::::unpack_slice(v).iter().map(|x| EF::from(*x)).collect(); +fn unpack_base_and_unreverse_active>>(v: &[PFPacking], chunk_log: usize) -> ArenaVec { + let active_unpacked: ArenaVec = PFPacking::::unpack_slice(v).iter().map(|x| EF::from(*x)).collect(); 
bit_reverse_chunks(&active_unpacked, chunk_log) } diff --git a/crates/sub_protocols/src/quotient_gkr/mod.rs b/crates/sub_protocols/src/quotient_gkr/mod.rs index 26fa25a65..0ad1b4185 100644 --- a/crates/sub_protocols/src/quotient_gkr/mod.rs +++ b/crates/sub_protocols/src/quotient_gkr/mod.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use backend::*; use tracing::instrument; @@ -41,8 +39,8 @@ pub fn prove_gkr_quotient<'a, EF: ExtensionField>>( assert_eq!(nums_br.len(), dens_br.len()); let initial = LayerStorage::Initial { - nums: Cow::Borrowed(nums_br), - dens: Cow::Borrowed(dens_br), + nums: ArenaCow::Borrowed(nums_br), + dens: ArenaCow::Borrowed(dens_br), chunk_log: pivot, }; @@ -207,7 +205,7 @@ mod tests { type EF = QuinticExtensionFieldKB; fn sum_all_quotients(nums: &[F], den: &[EF]) -> EF { - nums.par_iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() + nums.iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() } fn bit_reverse_chunks_and_pack_ext>>(v: &[EF], chunk_log: usize) -> Vec> { diff --git a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs index 27afd58f7..c771d6e04 100644 --- a/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs +++ b/crates/sub_protocols/src/quotient_gkr/sumcheck_utils.rs @@ -1,13 +1,10 @@ -use std::{ - borrow::Cow, - ops::{Add, AddAssign, Mul}, -}; +use std::ops::{Add, AddAssign, Mul}; use backend::*; use crate::quotient_gkr::layers::unpack_and_unreverse_active; -pub(super) fn even_odd_split(v: &[T]) -> (Vec, Vec) { +pub(super) fn even_odd_split(v: &[T]) -> (ArenaVec, ArenaVec) { ( v.iter().step_by(2).copied().collect(), v.iter().skip(1).step_by(2).copied().collect(), @@ -67,14 +64,15 @@ where N: PrimeCharacteristicRing + Copy, T: Algebra + Algebra + Copy, { - let (c0_den, c2_den) = sumcheck_quadratic(((&dl.0, &dl.1), (&dr.0, &dr.1))); - let (c0_a, c2_a) = sumcheck_quadratic(((&nl.0, &nl.1), (&dr.0, &dr.1))); - let (c0_b, c2_b) = sumcheck_quadratic(((&nr.0, &nr.1), 
(&dl.0, &dl.1))); + let ddl = dl.1 - dl.0; + let ddr = dr.1 - dr.0; + let dnl = nl.1 - nl.0; + let dnr = nr.1 - nr.0; RoundCoeffs { - c0_den, - c2_den, - c0_num: c0_a + c0_b, - c2_num: c2_a + c2_b, + c0_den: dr.0 * dl.0, + c2_den: ddl * ddr, + c0_num: dr.0 * nl.0 + dl.0 * nr.0, + c2_num: ddr * dnl + ddl * dnr, } } @@ -131,12 +129,12 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> let mut sum = expected_sum; let outer_point = remaining_eq[..head_len].to_vec(); - let eq_outer = eval_eq(&outer_point); + let eq_outer: ArenaVec = eval_eq(&outer_point); let padding_sum = alpha * mle_of_zeros_then_ones(active_chunks, &outer_point); let eq_alpha_0 = *remaining_eq.last().unwrap(); - let eq_within_0 = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within_0: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let coeffs_0 = compute_round_packed::(packed_nums, packed_dens, parent_chunk_log, &eq_outer, &eq_within_0); let r0 = finalize_round( prover_state, @@ -151,7 +149,7 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> remaining_eq.pop(); let eq_alpha_1 = *remaining_eq.last().unwrap(); - let eq_within_1 = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within_1: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let (nums_ext, dens_ext, coeffs_1) = fold_and_compute_round_packed::(packed_nums, packed_dens, parent_chunk_log, r0, &eq_outer, &eq_within_1); let r1 = finalize_round( @@ -168,8 +166,8 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> run_phase1_sumcheck( prover_state, - Cow::Owned(nums_ext), - Cow::Owned(dens_ext), + ArenaCow::Owned(nums_ext), + ArenaCow::Owned(dens_ext), parent_chunk_log - 2, remaining_eq, q_natural, @@ -185,15 +183,15 @@ pub(super) fn quotient_sumcheck_prove_packed_br_base>> #[allow(clippy::too_many_arguments)] pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( prover_state: &mut impl FSProver, - mut nums: Cow<'a, [EFPacking]>, - mut dens: Cow<'a, 
[EFPacking]>, + mut nums: ArenaCow<'a, EFPacking>, + mut dens: ArenaCow<'a, EFPacking>, mut layer_chunk_log: usize, mut remaining_eq: Vec, mut q_natural: Vec, alpha: EF, mut sum: EF, mut mmf: EF, - precomputed_eq_outer: Option>, + precomputed_eq_outer: Option>, initial_pending_r: Option, ) -> (Vec, [EF; 4]) { let w = packing_log_width::(); @@ -218,7 +216,7 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( let head_len = (remaining_eq.len() + 1).saturating_sub(layer_chunk_log); let outer_point: Vec = remaining_eq[..head_len].to_vec(); - let eq_outer: Vec = precomputed_eq_outer.unwrap_or_else(|| eval_eq(&outer_point)); + let eq_outer: ArenaVec = precomputed_eq_outer.unwrap_or_else(|| eval_eq(&outer_point)); let active_chunks = (nums.len() << w) >> (layer_chunk_log + usize::from(initial_pending_r.is_some())); @@ -227,7 +225,7 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( let mut pending_r: Option = initial_pending_r; while layer_chunk_log > w + 1 && remaining_eq.len() > w + 1 { let eq_alpha = *remaining_eq.last().unwrap(); - let eq_within = eval_eq_packed(&within_pt(&remaining_eq, head_len)); + let eq_within: ArenaVec<_> = eval_eq_packed(&within_pt(&remaining_eq, head_len)); let coeffs = if let Some(prev_r) = pending_r.take() { let (new_nums, new_dens, c) = fold_and_compute_round_packed::( @@ -238,8 +236,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( &eq_outer, &eq_within, ); - nums = Cow::Owned(new_nums); - dens = Cow::Owned(new_dens); + nums = ArenaCow::Owned(new_nums); + dens = ArenaCow::Owned(new_dens); c } else { compute_round_packed::(nums.as_ref(), dens.as_ref(), layer_chunk_log, &eq_outer, &eq_within) @@ -255,8 +253,8 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( if let Some(prev_r) = pending_r { let prev_bit = layer_chunk_log - 1 - w; let mul = |x: EFPacking, a: EF| x * a; - nums = Cow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul)); - dens = 
Cow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul)); + nums = ArenaCow::Owned(fold_multilinear_at_bit(nums.as_ref(), prev_r, prev_bit, &mul, false)); + dens = ArenaCow::Owned(fold_multilinear_at_bit(dens.as_ref(), prev_r, prev_bit, &mul, false)); } let nums_nat = unpack_and_unreverse_active::(nums.as_ref(), layer_chunk_log); @@ -281,10 +279,10 @@ pub(super) fn run_phase1_sumcheck<'a, EF: ExtensionField>>( #[allow(clippy::too_many_arguments)] pub(super) fn run_phase2_sumcheck>>( prover_state: &mut impl FSProver, - mut num_l: Vec, - mut num_r: Vec, - mut den_l: Vec, - mut den_r: Vec, + mut num_l: ArenaVec, + mut num_r: ArenaVec, + mut den_l: ArenaVec, + mut den_r: ArenaVec, mut remaining_eq: Vec, mut q_natural: Vec, alpha: EF, @@ -292,7 +290,7 @@ pub(super) fn run_phase2_sumcheck>>( mut mmf: EF, ) -> (Vec, [EF; 4]) { let eq_prefix_init = &remaining_eq[..remaining_eq.len().saturating_sub(1)]; - let mut eq_table = eval_eq(eq_prefix_init); + let mut eq_table: ArenaVec = eval_eq(eq_prefix_init); for _round in 0..remaining_eq.len() { let eq_alpha = *remaining_eq.last().unwrap(); @@ -328,10 +326,7 @@ pub(super) fn run_phase2_sumcheck>>( }; let acc: RoundCoeffs = if active_pairs > PARALLEL_THRESHOLD { - (0..active_pairs) - .into_par_iter() - .map(term) - .reduce(RoundCoeffs::zero, Add::add) + parallel::map_reduce(active_pairs, RoundCoeffs::zero, term, Add::add) } else { (0..active_pairs).map(term).fold(RoundCoeffs::::zero(), Add::add) }; @@ -362,7 +357,7 @@ pub(super) fn run_phase2_sumcheck>>( if new_eq_len > 0 { let fold_eq = |i: usize| eq_table[2 * i] + eq_table[2 * i + 1]; eq_table = if new_eq_len >= PARALLEL_THRESHOLD { - (0..new_eq_len).into_par_iter().map(fold_eq).collect() + ArenaVec::par_collect(new_eq_len, fold_eq) } else { (0..new_eq_len).map(fold_eq).collect() }; @@ -377,25 +372,22 @@ pub(super) fn run_phase2_sumcheck>>( (q_natural, evals) } -fn fold_normal_with_padding>>(m: &[EF], r: EF, pad_value: EF) -> Vec { +fn 
fold_normal_with_padding>>(m: &[EF], r: EF, pad_value: EF) -> ArenaVec { let active = m.len(); let new_active = active.div_ceil(2); assert!(new_active != 0); - let mut out = unsafe { uninitialized_vec(new_active) }; + let mut out: ArenaVec = unsafe { ArenaVec::uninitialized(new_active) }; - let compute = |(i, slot): (usize, &mut EF)| { + let compute = |i: usize, slot: &mut EF| { let a = m[2 * i]; let b = if 2 * i + 1 < active { m[2 * i + 1] } else { pad_value }; *slot = a + (b - a) * r; }; if new_active < PARALLEL_THRESHOLD { - out.iter_mut().enumerate().for_each(compute); + out.iter_mut().enumerate().for_each(|(i, slot)| compute(i, slot)); } else { - out.par_iter_mut() - .with_min_len(PARALLEL_THRESHOLD) - .enumerate() - .for_each(compute); + parallel::par_for_each_mut(&mut out, compute); } out } @@ -420,10 +412,13 @@ where debug_assert_eq!(dens.len(), nums.len()); debug_assert_eq!(eq_within.len(), quarter); - nums.par_chunks_exact(layer_packed) - .zip(dens.par_chunks_exact(layer_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (n_c, d_c))| { + let n_chunks = nums.len() / layer_packed; + parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * layer_packed..][..layer_packed]; + let d_c = &dens[c * layer_packed..][..layer_packed]; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = RoundCoeffs::>::zero(); for inner in 0..quarter { @@ -435,10 +430,10 @@ where ); local += coeffs * eq_within[inner]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add) + local * eq_o + }, + Add::add, + ) } #[allow(clippy::type_complexity)] @@ -449,7 +444,11 @@ fn fold_and_compute_round_packed>, N>( prev_r: EF, eq_outer: &[EF], eq_within: &[EFPacking], -) -> (Vec>, Vec>, RoundCoeffs>) +) -> ( + ArenaVec>, + ArenaVec>, + RoundCoeffs>, +) where N: PrimeCharacteristicRing + Copy + Send + Sync, EFPacking: Algebra, @@ -468,17 +467,21 @@ where debug_assert_eq!(eq_within.len(), in_eighth); let 
active_out_packed = nums.len() / 2; - let mut new_nums: Vec> = unsafe { uninitialized_vec(active_out_packed) }; - let mut new_dens: Vec> = unsafe { uninitialized_vec(active_out_packed) }; + let mut new_nums: ArenaVec> = unsafe { ArenaVec::uninitialized(active_out_packed) }; + let mut new_dens: ArenaVec> = unsafe { ArenaVec::uninitialized(active_out_packed) }; let prev_r_packed: EFPacking = as From>::from(prev_r); - let coeffs = nums - .par_chunks_exact(in_packed) - .zip(dens.par_chunks_exact(in_packed)) - .zip(new_nums.par_chunks_exact_mut(out_packed)) - .zip(new_dens.par_chunks_exact_mut(out_packed)) - .enumerate() - .fold(RoundCoeffs::zero, |mut acc, (c, (((n_c, d_c), nn_c), nd_c))| { + let n_chunks = nums.len() / in_packed; + let nn = parallel::SendPtr(new_nums.as_mut_ptr()); + let nd = parallel::SendPtr(new_dens.as_mut_ptr()); + let coeffs = parallel::map_reduce( + n_chunks, + RoundCoeffs::zero, + |c| { + let n_c = &nums[c * in_packed..][..in_packed]; + let d_c = &dens[c * in_packed..][..in_packed]; + let nn_c = unsafe { nn.slice(c * out_packed, out_packed) }; + let nd_c = unsafe { nd.slice(c * out_packed, out_packed) }; let eq_o: EF = eq_outer.get(c).copied().unwrap_or(EF::ONE); let mut local = RoundCoeffs::>::zero(); for i in 0..in_eighth { @@ -499,10 +502,10 @@ where ); local += round * eq_within[i]; } - acc += local * eq_o; - acc - }) - .reduce(RoundCoeffs::zero, Add::add); + local * eq_o + }, + Add::add, + ); (new_nums, new_dens, coeffs) } diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 207418caa..c7a40a754 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -119,7 +119,7 @@ pub fn stack_polynomials_and_commit( log2_strict_usize(bytecode_acc.len()), &tables_heights_sorted.iter().cloned().collect(), ); - let mut global_polynomial = F::zero_vec(1 << stacked_n_vars); // TODO avoid cloning all witness data + let mut global_polynomial = unsafe { 
ArenaVec::::zeroed(1 << stacked_n_vars) }; global_polynomial[..memory.len()].copy_from_slice(memory); let mut offset = memory.len(); global_polynomial[offset..][..memory_acc.len()].copy_from_slice(memory_acc); diff --git a/crates/sub_protocols/tests/prove_poseidon.rs b/crates/sub_protocols/tests/prove_poseidon.rs index cc584009d..a9070b1a3 100644 --- a/crates/sub_protocols/tests/prove_poseidon.rs +++ b/crates/sub_protocols/tests/prove_poseidon.rs @@ -26,13 +26,13 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let n_rows = 1 << log_n_rows; let mut rng = StdRng::seed_from_u64(0); let n_cols = num_cols_poseidon_16(); - let mut trace = vec![vec![F::ZERO; n_rows]; n_cols]; + let mut trace: Vec> = (0..n_cols).map(|_| ArenaVec::filled(F::ZERO, n_rows)).collect(); for t in trace.iter_mut().skip(POSEIDON_COL_INPUT_START).take(WIDTH) { - *t = (0..n_rows).map(|_| rng.random()).collect(); + *t = ArenaVec::from_iter((0..n_rows).map(|_| rng.random())); } - trace[POSEIDON_COL_MULTIPLICITY] = vec![F::ONE; n_rows]; - trace[POSEIDON_COL_ADDR_LEFT_LO] = vec![F::ZERO; n_rows]; - trace[POSEIDON_COL_ADDR_LEFT_HI] = vec![F::from_usize(HALF_DIGEST_LEN); n_rows]; + trace[POSEIDON_COL_MULTIPLICITY] = ArenaVec::filled(F::ONE, n_rows); + trace[POSEIDON_COL_ADDR_LEFT_LO] = ArenaVec::filled(F::ZERO, n_rows); + trace[POSEIDON_COL_ADDR_LEFT_HI] = ArenaVec::filled(F::from_usize(HALF_DIGEST_LEN), n_rows); fill_trace_poseidon_16(&mut trace); let air = Poseidon16Precompile::; @@ -56,7 +56,7 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let time = Instant::now(); - let mut commitmed_pol = F::zero_vec((n_cols << log_n_rows).next_power_of_two()); + let mut commitmed_pol = ArenaVec::filled(F::ZERO, (n_cols << log_n_rows).next_power_of_two()); for (i, col) in trace.iter().enumerate() { commitmed_pol[i << log_n_rows..(i + 1) << log_n_rows].copy_from_slice(col); } @@ -66,10 +66,10 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let alpha = prover_state.sample(); let air_alpha_powers: Vec = 
alpha.powers().collect_n(n_constraints); // BUS=false => `logup_alphas_eq_poly` is unused; only `alpha_powers` matter. - let extra_data = ExtraDataForBuses::new(Vec::new(), air_alpha_powers); + let extra_data = ExtraDataForBuses::new(&[], air_alpha_powers); prover_state.duplex(); let eq_factor: Vec = prover_state.sample_vec(log_n_rows); - let column_refs: Vec<&[F]> = trace.iter().map(Vec::as_slice).collect(); + let column_refs: Vec<&[F]> = trace.iter().map(|c| c.as_slice()).collect(); let packed = MleGroupRef::::Base(column_refs).pack(); let mut sessions: Vec + '_>> = vec![Box::new(AirSumcheckSession::new( @@ -109,7 +109,7 @@ fn prove_air_poseidon_16(log_n_rows: usize) { let alpha = verifier_state.sample(); let air_alpha_powers: Vec = alpha.powers().collect_n(n_constraints); - let extra_data = ExtraDataForBuses::new(Vec::new(), air_alpha_powers); + let extra_data = ExtraDataForBuses::new(&[], air_alpha_powers); verifier_state.duplex(); let eq_factor_v: Vec = verifier_state.sample_vec(log_n_rows); diff --git a/crates/utils/src/misc.rs b/crates/utils/src/misc.rs index ff317ae4c..a92e22bed 100644 --- a/crates/utils/src/misc.rs +++ b/crates/utils/src/misc.rs @@ -1,5 +1,3 @@ -use std::sync::atomic::{AtomicPtr, Ordering}; - use backend::*; pub fn from_end(slice: &[A], n: usize) -> &[A] { @@ -7,15 +5,18 @@ pub fn from_end(slice: &[A], n: usize) -> &[A] { &slice[slice.len() - n..] 
} -pub fn transposed_par_iter_mut( - array: &mut [Vec; N], // all vectors must have the same length -) -> impl IndexedParallelIterator + '_ { +pub fn transposed_par_for_each_mut( + array: &mut [ArenaVec; N], // all vectors must have the same length + f: impl Fn(usize, [&mut A; N]) + Sync, +) { let len = array[0].len(); - let data_ptrs: [AtomicPtr; N] = array.each_mut().map(|v| AtomicPtr::new(v.as_mut_ptr())); - - (0..len) - .into_par_iter() - .map(move |i| unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].load(Ordering::Relaxed).add(i)) }) + let data_ptrs: [parallel::SendPtr; N] = std::array::from_fn(|j| parallel::SendPtr(array[j].as_mut_ptr())); + parallel::for_each_index(len, |i| { + // SAFETY: distinct `i` index disjoint rows across all N columns; the column base pointers + // stay valid for the whole call (`array` is borrowed mutably for its duration). + let row: [&mut A; N] = unsafe { std::array::from_fn(|j| &mut *data_ptrs[j].0.add(i)) }; + f(i, row); + }); } pub fn collect_refs(vecs: &[Vec]) -> Vec<&[T]> { @@ -36,3 +37,11 @@ impl Counter { Self(0) } } + +pub fn decode_hex(s: &str) -> Vec { + let s = s.strip_prefix("0x").unwrap_or(s); + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap()) + .collect() +} diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index 2030ca5f4..8a972782e 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -13,10 +13,9 @@ pub fn multilinears_linear_combination, P: Borro assert_eq!(pols.len(), scalars.len()); let n_vars = log2_strict_usize(pols[0].borrow().len()); assert!(pols.iter().all(|p| log2_strict_usize(p.borrow().len()) == n_vars)); - (0..1 << n_vars) - .into_par_iter() - .map(|i| dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i]))) - .collect::>() + parallel::par_map_collect(1 << n_vars, |i| { + dot_product(scalars.iter().copied(), pols.iter().map(|p| p.borrow()[i])) + }) } pub fn 
multilinear_eval_constants_at_right(limit: usize, point: &[F]) -> F { diff --git a/crates/utils/src/poseidon.rs b/crates/utils/src/poseidon.rs index dcd3fa85a..d4c569624 100644 --- a/crates/utils/src/poseidon.rs +++ b/crates/utils/src/poseidon.rs @@ -3,12 +3,18 @@ use backend::*; use std::sync::OnceLock; pub type Poseidon16 = Poseidon1KoalaBear16; +pub type Poseidon24 = Poseidon1KoalaBear24; pub const HALF_FULL_ROUNDS_16: usize = POSEIDON1_HALF_FULL_ROUNDS; pub const PARTIAL_ROUNDS_16: usize = POSEIDON1_PARTIAL_ROUNDS; +pub const HALF_FULL_ROUNDS_24: usize = POSEIDON1_HALF_FULL_ROUNDS_24; +pub const PARTIAL_ROUNDS_24: usize = POSEIDON1_PARTIAL_ROUNDS_24; + static POSEIDON_16_INSTANCE: OnceLock = OnceLock::new(); static POSEIDON_16_OF_ZERO: OnceLock<[KoalaBear; 8]> = OnceLock::new(); +static POSEIDON_24_INSTANCE: OnceLock = OnceLock::new(); +static POSEIDON_24_OF_ZERO: OnceLock<[KoalaBear; 9]> = OnceLock::new(); #[inline(always)] pub fn get_poseidon16() -> &'static Poseidon16 { @@ -37,6 +43,43 @@ pub fn poseidon16_compress_pair(left: &[KoalaBear; 8], right: &[KoalaBear; 8]) - poseidon16_compress(input) } +#[inline(always)] +pub fn get_poseidon24() -> &'static Poseidon24 { + POSEIDON_24_INSTANCE.get_or_init(default_koalabear_poseidon1_24) +} + +#[inline(always)] +pub fn get_poseidon_24_of_zero() -> &'static [KoalaBear; 9] { + POSEIDON_24_OF_ZERO.get_or_init(|| poseidon24_compress_0_9([KoalaBear::default(); 24])) +} + +#[inline(always)] +pub fn poseidon24_compress_0_9(input: [KoalaBear; 24]) -> [KoalaBear; 9] { + get_poseidon24().compress(input)[0..9].try_into().unwrap() +} + +#[inline(always)] +pub fn poseidon24_compress_9_18(input: [KoalaBear; 24]) -> [KoalaBear; 9] { + get_poseidon24().compress(input)[9..18].try_into().unwrap() +} + +#[inline(always)] +pub fn poseidon24_permute_0_9(input: [KoalaBear; 24]) -> [KoalaBear; 9] { + get_poseidon24().permute(input)[0..9].try_into().unwrap() +} + +#[inline(always)] +pub fn poseidon24_permute_9_18(input: [KoalaBear; 24]) -> 
[KoalaBear; 9] { + get_poseidon24().permute(input)[9..18].try_into().unwrap() +} + +pub fn poseidon24_compress_0_9_pair(left: [KoalaBear; 9], right: [KoalaBear; 15]) -> [KoalaBear; 9] { + let mut input = [KoalaBear::default(); 24]; + input[..9].copy_from_slice(&left); + input[9..].copy_from_slice(&right); + poseidon24_compress_0_9(input) +} + /// Absorbs `data` in rate-mode chunks of 8, starting from the IV `[data.len(), 0, ..., 0]`. pub fn poseidon_compress_slice(data: &[KoalaBear]) -> [KoalaBear; 8] { assert!(!data.is_empty()); @@ -51,3 +94,18 @@ pub fn poseidon_compress_slice(data: &[KoalaBear]) -> [KoalaBear; 8] { } hash } + +/// Sponge hash starting from the all-zero IV (capacity), absorbing `data` in rate-mode +/// chunks of 8; the final partial chunk (if any) is zero-padded. Handles arbitrary length +/// (no 8-alignment requirement). Matches the zkDSL `slice_hash_with_iv_dynamic_unroll`. +pub fn poseidon_compress_slice_zero_iv(data: &[KoalaBear]) -> [KoalaBear; 8] { + assert!(!data.is_empty()); + let mut hash = [KoalaBear::default(); 8]; + for chunk in data.chunks(8) { + let mut block = [KoalaBear::default(); 16]; + block[..8].copy_from_slice(&hash); + block[8..8 + chunk.len()].copy_from_slice(chunk); + hash = poseidon16_compress(block); + } + hash +} diff --git a/crates/whir/Cargo.toml b/crates/whir/Cargo.toml index 1c2a2b0a7..d2b2e0b37 100644 --- a/crates/whir/Cargo.toml +++ b/crates/whir/Cargo.toml @@ -12,9 +12,10 @@ fiat-shamir = { path = "../backend/fiat-shamir", package = "mt-fiat-shamir" } utils = { path = "../backend/utils", package = "mt-utils" } symetric = { path = "../backend/symetric", package = "mt-symetric" } system-info.workspace = true +zk-alloc.workspace = true +parallel.workspace = true itertools.workspace = true -rayon.workspace = true rand.workspace = true tracing.workspace = true diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb3502..f74bcb7fe 100644 --- a/crates/whir/src/commit.rs +++ 
b/crates/whir/src/commit.rs @@ -4,6 +4,7 @@ use fiat_shamir::FSProver; use field::{ExtensionField, TwoAdicField}; use poly::*; use tracing::{info_span, instrument}; +use zk_alloc::ArenaVec; use crate::*; @@ -35,11 +36,11 @@ impl>> MerkleData { match self { MerkleData::Base(prover_data) => { let (leaf, proof) = merkle_open::, PF>(prover_data, index); - (MleOwned::Base(leaf), proof) + (MleOwned::Base(ArenaVec::from_slice(&leaf)), proof) } MerkleData::Extension(prover_data) => { let (leaf, proof) = merkle_open::, EF>(prover_data, index); - (MleOwned::Extension(leaf), proof) + (MleOwned::Extension(ArenaVec::from_slice(&leaf)), proof) } } } diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 277597eb8..8a29d7591 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -29,7 +29,6 @@ use field::PackedValue; use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; use itertools::Itertools; -use rayon::prelude::*; use tracing::instrument; use utils::{as_base_slice, log2_strict_usize}; @@ -164,7 +163,7 @@ where /// also divide by the height. 
#[inline] fn par_initial_layers(mat: &mut [F], chunk_size: usize, root_table: &[Vec], width: usize) { - mat.par_chunks_exact_mut(chunk_size).for_each(|chunk| { + parallel::par_chunks_mut(mat, chunk_size, |_, chunk| { initial_layers(chunk, root_table, width); }); } @@ -197,14 +196,15 @@ fn dft_layer>(vec: &mut [F], twiddles: &[B], width: us #[inline] fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize) { - vec.par_chunks_exact_mut(twiddles.len() * 2 * width).for_each(|block| { + parallel::par_chunks_mut(vec, twiddles.len() * 2 * width, |_, block| { let (left, right) = block.split_at_mut(twiddles.len() * width); - left.par_chunks_exact_mut(width) - .zip(right.par_chunks_exact_mut(width)) - .zip(twiddles.par_iter()) - .for_each(|((hi_chunk, lo_chunk), twiddle)| { - twiddle.apply_to_rows(hi_chunk, lo_chunk); - }); + for ((hi_chunk, lo_chunk), twiddle) in left + .chunks_exact_mut(width) + .zip(right.chunks_exact_mut(width)) + .zip(twiddles.iter()) + { + twiddle.apply_to_rows(hi_chunk, lo_chunk); + } }); } @@ -235,39 +235,20 @@ fn dft_layer_par_double, M: MultiLayerButterfly> assert_eq!(twiddles_large.len(), twiddles_small.len() * 2); // TODO optimal workload size with L1 cache - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - // (0..twiddles_small.len()).into_par_iter().for_each(|ind| { - // let hi_hi = slice_ref_mut(block, ind * width, width); - // let hi_lo = slice_ref_mut(block, (ind + twiddles_small.len()) * width, width); - // let lo_hi = slice_ref_mut(block, (ind + 2 * twiddles_small.len()) * width, width); - // let lo_lo = slice_ref_mut(block, (ind + 3 * twiddles_small.len()) * width, width); - // multi_butterfly.apply_2_layers( - // ((hi_hi, hi_lo), (lo_hi, lo_lo)), - // ind, - // twiddles_small, - // twiddles_large, - // ); - // }); - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width); - 
let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width); - hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_blocks.par_chunks_exact_mut(width)) - .enumerate() - .for_each(|(ind, (((hi_hi, hi_lo), lo_hi), lo_lo))| { - multi_butterfly.apply_2_layers( - ((hi_hi, hi_lo), (lo_hi, lo_lo)), - ind, - twiddles_small, - twiddles_large, - ); - }); - }); + parallel::par_chunks_mut(&mut *mat.values, twiddles_large.len() * 2 * width, |_, block| { + let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 2); + let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width); + let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width); + for (ind, (((hi_hi, hi_lo), lo_hi), lo_lo)) in hi_hi_blocks + .chunks_exact_mut(width) + .zip(hi_lo_blocks.chunks_exact_mut(width)) + .zip(lo_hi_blocks.chunks_exact_mut(width)) + .zip(lo_lo_blocks.chunks_exact_mut(width)) + .enumerate() + { + multi_butterfly.apply_2_layers(((hi_hi, hi_lo), (lo_hi, lo_lo)), ind, twiddles_small, twiddles_large); + } + }); } /// Applies three layers of a Radix-2 FFT butterfly network making use of parallelization. 
@@ -303,44 +284,38 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> // let inner_chunk_size = // (workload_size::().next_power_of_two() / 8).min(eighth_outer_block_size); - mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { - let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); - let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); - let (hi_hi_hi_blocks, hi_hi_lo_blocks) = hi_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (hi_lo_hi_blocks, hi_lo_lo_blocks) = hi_lo_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_hi_hi_blocks, lo_hi_lo_blocks) = lo_hi_blocks.split_at_mut(twiddles_small.len() * width); - let (lo_lo_hi_blocks, lo_lo_lo_blocks) = lo_lo_blocks.split_at_mut(twiddles_small.len() * width); + parallel::par_chunks_mut(&mut *mat.values, twiddles_large.len() * 2 * width, |_, block| { + let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); + let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); + let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); + let (hi_hi_hi_blocks, hi_hi_lo_blocks) = hi_hi_blocks.split_at_mut(twiddles_small.len() * width); + let (hi_lo_hi_blocks, hi_lo_lo_blocks) = hi_lo_blocks.split_at_mut(twiddles_small.len() * width); + let (lo_hi_hi_blocks, lo_hi_lo_blocks) = lo_hi_blocks.split_at_mut(twiddles_small.len() * width); + let (lo_lo_hi_blocks, lo_lo_lo_blocks) = lo_lo_blocks.split_at_mut(twiddles_small.len() * width); + for (ind, (((((((hi_hi_hi, hi_hi_lo), hi_lo_hi), hi_lo_lo), lo_hi_hi), lo_hi_lo), lo_lo_hi), lo_lo_lo)) in hi_hi_hi_blocks - .par_chunks_exact_mut(width) - .zip(hi_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(hi_lo_hi_blocks.par_chunks_exact_mut(width)) - 
.zip(hi_lo_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_hi_lo_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_hi_blocks.par_chunks_exact_mut(width)) - .zip(lo_lo_lo_blocks.par_chunks_exact_mut(width)) + .chunks_exact_mut(width) + .zip(hi_hi_lo_blocks.chunks_exact_mut(width)) + .zip(hi_lo_hi_blocks.chunks_exact_mut(width)) + .zip(hi_lo_lo_blocks.chunks_exact_mut(width)) + .zip(lo_hi_hi_blocks.chunks_exact_mut(width)) + .zip(lo_hi_lo_blocks.chunks_exact_mut(width)) + .zip(lo_lo_hi_blocks.chunks_exact_mut(width)) + .zip(lo_lo_lo_blocks.chunks_exact_mut(width)) .enumerate() - .for_each( - |( - ind, - (((((((hi_hi_hi, hi_hi_lo), hi_lo_hi), hi_lo_lo), lo_hi_hi), lo_hi_lo), lo_lo_hi), lo_lo_lo), - )| { - multi_butterfly.apply_3_layers( - ( - ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), - ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), - ), - ind, - twiddles_small, - twiddles_med, - twiddles_large, - ); - }, - ); - }); + { + multi_butterfly.apply_3_layers( + ( + ((hi_hi_hi, hi_hi_lo), (hi_lo_hi, hi_lo_lo)), + ((lo_hi_hi, lo_hi_lo), (lo_lo_hi, lo_lo_lo)), + ), + ind, + twiddles_small, + twiddles_med, + twiddles_large, + ); + } + }); } /// Applies the remaining layers of the Radix-2 FFT butterfly network in parallel. @@ -444,8 +419,8 @@ fn fft_triple_layer_quad_twiddle>( /// Estimates the optimal workload size for `T` to fit in L1 cache. #[must_use] -const fn workload_size() -> usize { - system_info::L1_CACHE_SIZE / size_of::() +fn workload_size() -> usize { + system_info::l1_cache_size() / size_of::() } /// Estimates the optimal number of rows of a `RowMajorMatrix` to take in each parallel chunk. 
diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 43446714a..d038629c2 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -12,7 +12,6 @@ use field::PrimeCharacteristicRing; use koala_bear::{KoalaBear, QuinticExtensionFieldKB, default_koalabear_poseidon1_16}; use poly::*; -use rayon::prelude::*; use symetric::Compression; use symetric::merkle::unpack_array; use tracing::instrument; @@ -228,22 +227,19 @@ where let mut digests = unsafe { uninitialized_vec(height) }; - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( - perm, - rtl_iter, - packed_initial_state, - ); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); + parallel::par_chunks_mut(&mut digests, width, |i, digests_chunk| { + let first_row = i * width; + let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); + let packed_digest: [P; DIGEST_ELEMS] = + symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( + perm, + rtl_iter, + packed_initial_state, + ); + for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { + *dst = src; + } + }); digests } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 6636b77c7..0137a6860 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -5,9 +5,9 @@ use fiat_shamir::{FSProver, MerklePath, ProofResult}; use field::PrimeCharacteristicRing; use field::{ExtensionField, Field, TwoAdicField}; use poly::*; -use rayon::prelude::*; use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; +use zk_alloc::ArenaVec; use crate::{config::WhirConfig, *}; @@ -188,7 +188,7 @@ where // Convert evaluations to coefficient form and send to the verifier. let mut coeffs = match &round_state.sumcheck_prover.evals { MleOwned::Extension(evals) => evals.clone(), - MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), + MleOwned::ExtensionPacked(evals) => unpack_extension::>(evals), _ => unreachable!(), }; evals_to_coeffs(&mut coeffs); @@ -212,14 +212,14 @@ where match answer { MleOwned::Base(leaf) => { base_paths.push(MerklePath { - leaf_data: leaf, + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); } MleOwned::Extension(leaf) => { ext_paths.push(MerklePath { - leaf_data: leaf, + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); @@ -292,14 +292,14 @@ fn open_merkle_tree_at_challenges>>( match &answer { MleOwned::Base(leaf) => { base_paths.push(MerklePath { - leaf_data: leaf.clone(), + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); } MleOwned::Extension(leaf) => { ext_paths.push(MerklePath { - leaf_data: leaf.clone(), + leaf_data: leaf.to_vec(), sibling_hashes, leaf_index: challenge, }); @@ -422,7 +422,7 @@ 
where let (weights, sum) = combine_statement::(statement, combination_randomness); let mut evals = evals.pack(); - let mut weights = Mle::Owned(MleOwned::ExtensionPacked(weights)); + let mut weights = Mle::Owned(MleOwned::ExtensionPacked(ArenaVec::from_slice(&weights))); let (challengess, new_sum, new_evals, new_weights) = run_product_sumcheck( &evals.by_ref(), &weights.by_ref(), @@ -594,17 +594,12 @@ where for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { combined_sum += e.value * scalar; } - chunks_mut - .into_par_iter() - .zip(&indexed_smt_values) - .for_each(|(out_buff, &(origin_index, _))| { - out_buff[..1 << shift] - .par_iter_mut() - .zip(&inner_poly) - .for_each(|(out_elem, &poly_elem)| { - *out_elem += poly_elem * next_gamma_powers[origin_index]; - }); - }); + parallel::par_for_each_mut(&mut chunks_mut, |i, out_buff| { + let (origin_index, _) = indexed_smt_values[i]; + for (out_elem, &poly_elem) in out_buff[..1 << shift].iter_mut().zip(&inner_poly[..]) { + *out_elem += poly_elem * next_gamma_powers[origin_index]; + } + }); gamma_pow = *next_gamma_powers.last().unwrap() * gamma; } } diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index e64799149..84f724a8c 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -6,7 +6,6 @@ use field::Field; use field::PackedValue; use field::{ExtensionField, TwoAdicField}; use poly::*; -use rayon::prelude::*; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; @@ -138,15 +137,12 @@ fn prepare_evals_for_fft_unpacked( let log_block_size = log2_strict_usize(block_size); let out_len = block_size * dft_n_cols; - (0..out_len) - .into_par_iter() - .map(|i| { - let block_index = i % dft_n_cols; - let offset_in_block = i / dft_n_cols; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - unsafe { *evals.get_unchecked(src_index) } - }) - .collect() + parallel::par_map_collect(out_len, |i| { + let 
block_index = i % dft_n_cols; + let offset_in_block = i / dft_n_cols; + let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; + unsafe { *evals.get_unchecked(src_index) } + }) } fn prepare_evals_for_fft_packed_extension>>( @@ -163,22 +159,19 @@ fn prepare_evals_for_fft_packed_extension>>( let n_blocks_mask = n_blocks - 1; let packing_mask = (1 << log_packing) - 1; - (0..full_len) - .into_par_iter() - .map(|i| { - let block_index = i & n_blocks_mask; - let offset_in_block = i >> folding_factor; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - EF::from_basis_coefficients_fn(|i| unsafe { - let u: &PFPacking = unpacked.get_unchecked(i); - *u.as_slice().get_unchecked(offset_in_packing) - }) + parallel::par_map_collect(full_len, |i| { + let block_index = i & n_blocks_mask; + let offset_in_block = i >> folding_factor; + let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; + let packed_src_index = src_index >> log_packing; + let offset_in_packing = src_index & packing_mask; + let packed = unsafe { evals.get_unchecked(packed_src_index) }; + let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); + EF::from_basis_coefficients_fn(|i| unsafe { + let u: &PFPacking = unpacked.get_unchecked(i); + *u.as_slice().get_unchecked(offset_in_packing) }) - .collect() + }) } type CacheKey = TypeId; diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 217a19c5d..908cda416 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -10,6 +10,7 @@ use poly::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, 
Registry, layer::SubscriberExt, util::SubscriberInitExt}; +use zk_alloc::ArenaVec; type F = KoalaBear; type EF = QuinticExtensionFieldKB; @@ -104,7 +105,7 @@ fn test_run_whir() { precompute_dft_twiddles::(1 << F::TWO_ADICITY); - let polynomial: MleOwned = MleOwned::Base(polynomial); + let polynomial: MleOwned = MleOwned::Base(ArenaVec::from_slice(&polynomial)); let time = Instant::now(); let witness = params.commit(&mut prover_state, &polynomial, num_coeffs); diff --git a/crates/xmss/Cargo.toml b/crates/xmss/Cargo.toml deleted file mode 100644 index 86c6ed3b7..000000000 --- a/crates/xmss/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "xmss" -version.workspace = true -edition.workspace = true - -[lints] -workspace = true - -[dependencies] - -rand.workspace = true -utils.workspace = true -backend.workspace = true -serde.workspace = true -lz4_flex.workspace = true -postcard.workspace = true -sha3.workspace = true - -[dev-dependencies] -postcard.workspace = true \ No newline at end of file diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs deleted file mode 100644 index 941eff17c..000000000 --- a/crates/xmss/src/lib.rs +++ /dev/null @@ -1,84 +0,0 @@ -#![cfg_attr(not(test), warn(unused_crate_dependencies))] -use backend::PrimeCharacteristicRing; -use backend::{DIGEST_LEN_FE, KoalaBear, POSEIDON1_WIDTH}; - -pub mod signers_cache; -mod wots; -pub use wots::*; -mod xmss; -pub use xmss::*; - -pub const XMSS_DIGEST_LEN: usize = 4; -pub(crate) const TWEAK_LEN: usize = 2; - -type F = KoalaBear; -type Digest = [F; XMSS_DIGEST_LEN]; -type PublicParam = [F; PUBLIC_PARAM_LEN_FE]; -type Randomness = [F; RANDOMNESS_LEN_FE]; - -// WOTS -pub const V: usize = 42; -pub const W: usize = 3; -pub const CHAIN_LENGTH: usize = 1 << W; -pub const NUM_CHAIN_HASHES: usize = 110; -pub const TARGET_SUM: usize = V * (CHAIN_LENGTH - 1) - NUM_CHAIN_HASHES; -pub const NUM_ENCODING_FE: usize = V.div_ceil(24 / W); -pub const RANDOMNESS_LEN_FE: usize = 6; -pub const 
MESSAGE_LEN_FE: usize = 8; -pub const PUBLIC_PARAM_LEN_FE: usize = 4; -pub const PUB_KEY_FLAT_SIZE: usize = XMSS_DIGEST_LEN + PUBLIC_PARAM_LEN_FE; -pub const WOTS_SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + V * XMSS_DIGEST_LEN; - -// XMSS -pub const LOG_LIFETIME: usize = 32; - -// Tweak: domain separation within each hash. -pub const TWEAK_TYPE_CHAIN: usize = 0; -pub const TWEAK_TYPE_WOTS_PK: usize = 1; -pub const TWEAK_TYPE_MERKLE: usize = 2; -pub const TWEAK_TYPE_ENCODING: usize = 3; - -const _: () = assert!(V.is_multiple_of(2)); // For efficiency of the snark (we can batch chains in pairs) - -/// index = slot or node_index in Merkle tree -pub fn make_tweak(tweak_type: usize, sub_position: usize, index: u32) -> [F; TWEAK_LEN] { - assert!(tweak_type < 4); - assert!(sub_position < 1 << 10); - let index_lo = (index & 0xFFFF) as usize; - let index_hi = (index >> 16) as usize; - [ - F::from_usize((tweak_type << 26) + (index_hi << 10) + sub_position), - F::from_usize(index_lo), - ] -} - -/// [tweak(2) | zeros(2) | public_param(4) | left_child(4) | right_child(4)] -pub(crate) fn build_merkle_data( - tweak: [F; TWEAK_LEN], - public_param: &PublicParam, - left_child: &Digest, - right_child: &Digest, -) -> [F; POSEIDON1_WIDTH] { - let mut data = [F::default(); POSEIDON1_WIDTH]; - data[..TWEAK_LEN].copy_from_slice(&tweak); - // data[2..4] = zeros (default) - data[DIGEST_LEN_FE - PUBLIC_PARAM_LEN_FE..][..PUBLIC_PARAM_LEN_FE].copy_from_slice(public_param); - data[DIGEST_LEN_FE..][..XMSS_DIGEST_LEN].copy_from_slice(left_child); - data[DIGEST_LEN_FE + XMSS_DIGEST_LEN..].copy_from_slice(right_child); - data -} - -/// [tweak(2) | zeros(2) | data(4)] -pub(crate) fn build_left_chain_input(tweak: [F; TWEAK_LEN], data: &Digest) -> [F; DIGEST_LEN_FE] { - let mut left = [F::default(); DIGEST_LEN_FE]; - left[..TWEAK_LEN].copy_from_slice(&tweak); - left[DIGEST_LEN_FE - XMSS_DIGEST_LEN..].copy_from_slice(data); - left -} - -/// [public_param(4) | zeros(4)] -pub(crate) fn 
build_right_chain_input(public_param: &PublicParam) -> [F; DIGEST_LEN_FE] { - let mut right = [F::default(); DIGEST_LEN_FE]; - right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(public_param); - right -} diff --git a/crates/xmss/src/wots.rs b/crates/xmss/src/wots.rs deleted file mode 100644 index 30aeed579..000000000 --- a/crates/xmss/src/wots.rs +++ /dev/null @@ -1,198 +0,0 @@ -use backend::*; -use rand::{CryptoRng, RngExt}; -use serde::{Deserialize, Serialize}; -use utils::{ToUsize, poseidon16_compress_pair}; - -use crate::*; - -#[derive(Debug)] -pub struct WotsSecretKey { - pub pre_images: [Digest; V], - public_key: WotsPublicKey, -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct WotsPublicKey(pub [Digest; V]); - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct WotsSignature { - #[serde( - with = "backend::array_serialization", - bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>") - )] - pub chain_tips: [Digest; V], - pub randomness: Randomness, -} - -impl WotsSecretKey { - pub fn random(rng: &mut impl CryptoRng, public_param: PublicParam, slot: u32) -> Self { - Self::new(rng.random(), public_param, slot) - } - - pub fn new(pre_images: [Digest; V], public_param: PublicParam, slot: u32) -> Self { - Self { - pre_images, - public_key: WotsPublicKey(std::array::from_fn(|i| { - iterate_hash(&pre_images[i], CHAIN_LENGTH - 1, public_param, slot, i, 0) - })), - } - } - - pub const fn public_key(&self) -> &WotsPublicKey { - &self.public_key - } - - pub fn sign_with_randomness( - &self, - message: &[F; MESSAGE_LEN_FE], - slot: u32, - xmss_pub_key: &XmssPublicKey, - randomness: Randomness, - ) -> Option { - let encoding = wots_encode(message, slot, xmss_pub_key, &randomness)?; - Some(self.sign_with_encoding(randomness, &encoding, xmss_pub_key.public_param, slot)) - } - - fn sign_with_encoding( - &self, - randomness: Randomness, - encoding: &[u8; V], - public_param: PublicParam, - slot: u32, - ) -> 
WotsSignature { - WotsSignature { - chain_tips: std::array::from_fn(|i| { - iterate_hash(&self.pre_images[i], encoding[i] as usize, public_param, slot, i, 0) - }), - randomness, - } - } -} - -impl WotsSignature { - pub fn recover_public_key( - &self, - message: &[F; MESSAGE_LEN_FE], - slot: u32, - xmss_pub_key: &XmssPublicKey, - ) -> Option { - let encoding = wots_encode(message, slot, xmss_pub_key, &self.randomness)?; - Some(WotsPublicKey(std::array::from_fn(|i| { - iterate_hash( - &self.chain_tips[i], - CHAIN_LENGTH - 1 - encoding[i] as usize, - xmss_pub_key.public_param, - slot, - i, - encoding[i] as usize, - ) - }))) - } -} - -impl WotsPublicKey { - // We use a T-Sponge with replacement, i.e. we use Poseidon in compression mode + replace (instead of modular addition) when ingesting 8 new field elements. - pub fn hash(&self, public_param: PublicParam, slot: u32) -> Digest { - // IV: [tweak(2) | 00 | pp(4)] - let tweak = make_tweak(TWEAK_TYPE_WOTS_PK, 0, slot); - let mut state = [F::default(); 8]; - state[..TWEAK_LEN].copy_from_slice(&tweak); - // state[2..4] = 00 (default) - state[4..4 + PUBLIC_PARAM_LEN_FE].copy_from_slice(&public_param); - - let zeros = [F::ZERO; 8]; // for snark-friendliless (not necessary for security) - state = poseidon16_compress_pair(&state, &zeros); - - for i in (0..V).step_by(2) { - let mut chunk = [F::default(); 8]; - chunk[..XMSS_DIGEST_LEN].copy_from_slice(&self.0[i]); - chunk[XMSS_DIGEST_LEN..].copy_from_slice(&self.0[i + 1]); - state = poseidon16_compress_pair(&state, &chunk); - } - state[..XMSS_DIGEST_LEN].try_into().unwrap() - } -} - -pub fn iterate_hash( - a: &Digest, - n: usize, - public_param: PublicParam, - slot: u32, - chain_index: usize, - start_step: usize, -) -> Digest { - // Chain hash layout: left = [tweak (2) | zeros (2) | data (4)], right = [public_param(4) | zeros(4)]. 
- let right = build_right_chain_input(&public_param); - (0..n).fold(*a, |acc, j| { - let tweak = make_tweak(TWEAK_TYPE_CHAIN, chain_index * CHAIN_LENGTH + start_step + j, slot); - let left = build_left_chain_input(tweak, &acc); - poseidon16_compress_pair(&left, &right)[..XMSS_DIGEST_LEN] - .try_into() - .unwrap() - }) -} - -pub fn find_randomness_for_wots_encoding( - message: &[F; MESSAGE_LEN_FE], - slot: u32, - xmss_pub_key: &XmssPublicKey, - rng: &mut impl CryptoRng, -) -> (Randomness, [u8; V], usize) { - let mut num_iters = 0; - loop { - num_iters += 1; - let randomness = rng.random(); - if let Some(encoding) = wots_encode(message, slot, xmss_pub_key, &randomness) { - return (randomness, encoding, num_iters); - } - } -} - -pub fn wots_encode( - message: &[F; MESSAGE_LEN_FE], - slot: u32, - xmss_pub_key: &XmssPublicKey, - randomness: &Randomness, -) -> Option<[u8; V]> { - let first_input_left = message; - let mut first_input_right = [F::default(); DIGEST_LEN_FE]; - first_input_right[..RANDOMNESS_LEN_FE].copy_from_slice(randomness); - first_input_right[RANDOMNESS_LEN_FE..][..TWEAK_LEN].copy_from_slice(&make_tweak(TWEAK_TYPE_ENCODING, 0, slot)); - let pre_compressed = poseidon16_compress_pair(first_input_left, &first_input_right); - - let mut second_input_right = [F::default(); DIGEST_LEN_FE]; - second_input_right[..PUBLIC_PARAM_LEN_FE].copy_from_slice(&xmss_pub_key.public_param); - let compressed = poseidon16_compress_pair(&pre_compressed, &second_input_right); - - if compressed[..NUM_ENCODING_FE].iter().any(|&kb| kb == -F::ONE) { - // ensures uniformity of encoding - return None; - } - let all_indices: Vec<_> = compressed[..NUM_ENCODING_FE] - .iter() - .flat_map(|kb| to_little_endian_bits(kb.to_usize(), 24)) - .collect::>() - .chunks_exact(W) - .take(V) - .map(|chunk| { - chunk - .iter() - .enumerate() - .fold(0u8, |acc, (i, &bit)| acc | (u8::from(bit) << i)) - }) - .collect(); - is_valid_encoding(&all_indices).then(|| all_indices[..V].try_into().unwrap()) -} - 
-fn is_valid_encoding(encoding: &[u8]) -> bool { - if encoding.len() != V { - return false; - } - if !encoding.iter().all(|&x| (x as usize) < CHAIN_LENGTH) { - return false; - } - if encoding.iter().map(|&x| x as usize).sum::() != TARGET_SUM { - return false; - } - true -} diff --git a/crates/xmss/src/xmss.rs b/crates/xmss/src/xmss.rs deleted file mode 100644 index d5f69f445..000000000 --- a/crates/xmss/src/xmss.rs +++ /dev/null @@ -1,238 +0,0 @@ -use backend::*; -use rand::{CryptoRng, RngExt, SeedableRng, rngs::StdRng}; -use serde::{Deserialize, Serialize}; -use sha3::{Digest as Sha3Digest, Keccak256}; -use utils::poseidon16_compress; - -use crate::*; - -#[derive(Debug)] -pub struct XmssSecretKey { - pub(crate) slot_start: u32, // inclusive - pub(crate) slot_end: u32, // inclusive - pub(crate) public_param: PublicParam, - pub(crate) seed: [u8; 32], - // At level l, stored indices go from (slot_start >> l) to (slot_end >> l). - pub(crate) merkle_tree: Vec>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct XmssSignature { - pub wots_signature: WotsSignature, - #[serde( - with = "backend::array_serialization", - bound(serialize = "F: Serialize", deserialize = "F: Deserialize<'de>") - )] - pub merkle_proof: [Digest; LOG_LIFETIME], -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct XmssPublicKey { - pub merkle_root: Digest, - pub public_param: PublicParam, -} - -impl XmssPublicKey { - pub fn flaten(&self) -> [F; PUB_KEY_FLAT_SIZE] { - let mut output = [F::default(); PUB_KEY_FLAT_SIZE]; - output[..XMSS_DIGEST_LEN].copy_from_slice(&self.merkle_root); - output[XMSS_DIGEST_LEN..].copy_from_slice(&self.public_param); - output - } -} - -fn gen_wots_secret_key(seed: &[u8; 32], slot: u32, public_param: PublicParam) -> WotsSecretKey { - let mut hasher = Keccak256::new(); - hasher.update(b"wots_secret_key"); - hasher.update(seed); - hasher.update(slot.to_le_bytes()); - let 
mut rng = StdRng::from_seed(hasher.finalize().into()); - WotsSecretKey::random(&mut rng, public_param, slot) -} - -fn gen_public_param(seed: &[u8; 32]) -> PublicParam { - let mut hasher = Keccak256::new(); - hasher.update(b"public_param"); - hasher.update(seed); - let mut rng = StdRng::from_seed(hasher.finalize().into()); - rng.random() -} - -/// Deterministic pseudo-random digest for an out-of-range tree node. -fn gen_random_node(seed: &[u8; 32], level: usize, index: u64) -> Digest { - let mut hasher = Keccak256::new(); - hasher.update(b"random_node"); - hasher.update(seed); - hasher.update((level as u64).to_le_bytes()); - hasher.update(index.to_le_bytes()); - let mut rng = StdRng::from_seed(hasher.finalize().into()); - rng.random() -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum XmssKeyGenError { - InvalidRange, -} - -pub fn xmss_key_gen( - seed: [u8; 32], - slot_start: u32, - slot_end: u32, -) -> Result<(XmssSecretKey, XmssPublicKey), XmssKeyGenError> { - if slot_start > slot_end || slot_end as u64 >= (1 << LOG_LIFETIME) { - return Err(XmssKeyGenError::InvalidRange); - } - let public_param: PublicParam = gen_public_param(&seed); - // Level 0: WOTS leaf hashes for slots in [slot_start, slot_end] - let leaves: Vec = (slot_start..=slot_end) - .into_par_iter() - .map(|slot| { - let wots = gen_wots_secret_key(&seed, slot, public_param); - wots.public_key().hash(public_param, slot) - }) - .collect(); - let mut merkle_tree = vec![leaves]; - // Build levels 1..=LOG_LIFETIME. - // At level l, we store nodes with index in [(slot_start >> l), (slot_end >> l)]. - // Children outside [slot_start, slot_end]'s subtree are replaced by gen_random_node. 
- for level in 1..=LOG_LIFETIME { - let base: u64 = (slot_start as u64) >> level; - let top: u64 = (slot_end as u64) >> level; - let prev_base: u64 = (slot_start as u64) >> (level - 1); - let prev_top: u64 = (slot_end as u64) >> (level - 1); - let nodes: Vec = { - let prev = &merkle_tree[level - 1]; - (base..=top) - .into_par_iter() - .map(|i| { - let left_idx = 2 * i; - let right_idx = 2 * i + 1; - let left = if left_idx >= prev_base && left_idx <= prev_top { - prev[(left_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, left_idx) - }; - let right = if right_idx >= prev_base && right_idx <= prev_top { - prev[(right_idx - prev_base) as usize] - } else { - gen_random_node(&seed, level - 1, right_idx) - }; - let merkle_data = build_merkle_data( - make_tweak(TWEAK_TYPE_MERKLE, level, i as u32), - &public_param, - &left, - &right, - ); - poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap() - }) - .collect() - }; - merkle_tree.push(nodes); - } - let pub_key = XmssPublicKey { - merkle_root: merkle_tree.last().unwrap()[0], - public_param, - }; - let secret_key = XmssSecretKey { - slot_start, - slot_end, - public_param, - seed, - merkle_tree, - }; - Ok((secret_key, pub_key)) -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum XmssSignatureError { - SlotOutOfRange, - InvalidRandomness, -} - -pub fn xmss_sign( - rng: &mut R, - secret_key: &XmssSecretKey, - message: &[F; MESSAGE_LEN_FE], - slot: u32, -) -> Result { - let (randomness, _, _) = find_randomness_for_wots_encoding(message, slot, &secret_key.public_key(), rng); - xmss_sign_with_randomness(secret_key, message, slot, randomness) -} - -pub fn xmss_sign_with_randomness( - secret_key: &XmssSecretKey, - message: &[F; MESSAGE_LEN_FE], - slot: u32, - randomness: [F; RANDOMNESS_LEN_FE], -) -> Result { - if slot < secret_key.slot_start || slot > secret_key.slot_end { - return Err(XmssSignatureError::SlotOutOfRange); - } - let wots_secret_key = 
gen_wots_secret_key(&secret_key.seed, slot, secret_key.public_param); - let wots_signature = wots_secret_key - .sign_with_randomness(message, slot, &secret_key.public_key(), randomness) - .ok_or(XmssSignatureError::InvalidRandomness)?; - let merkle_proof = std::array::from_fn(|level| { - let neighbour_index = ((slot as u64) >> level) ^ 1; - let base = (secret_key.slot_start as u64) >> level; - let top = (secret_key.slot_end as u64) >> level; - if neighbour_index >= base && neighbour_index <= top { - secret_key.merkle_tree[level][(neighbour_index - base) as usize] - } else { - gen_random_node(&secret_key.seed, level, neighbour_index) - } - }); - Ok(XmssSignature { - wots_signature, - merkle_proof, - }) -} - -impl XmssSecretKey { - pub fn public_key(&self) -> XmssPublicKey { - XmssPublicKey { - merkle_root: self.merkle_tree.last().unwrap()[0], - public_param: self.public_param, - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum XmssVerifyError { - InvalidWots, - InvalidMerklePath, -} - -pub fn xmss_verify( - pub_key: &XmssPublicKey, - message: &[F; MESSAGE_LEN_FE], - signature: &XmssSignature, - slot: u32, -) -> Result<(), XmssVerifyError> { - let wots_public_key = signature - .wots_signature - .recover_public_key(message, slot, pub_key) - .ok_or(XmssVerifyError::InvalidWots)?; - let mut current_hash = wots_public_key.hash(pub_key.public_param, slot); - for (level, neighbour) in signature.merkle_proof.iter().enumerate() { - let is_left = (((slot as u64) >> level) & 1) == 0; - let parent_index = ((slot as u64) >> (level + 1)) as u32; - let (left_child, right_child) = if is_left { - (current_hash, *neighbour) - } else { - (*neighbour, current_hash) - }; - let merkle_data = build_merkle_data( - make_tweak(TWEAK_TYPE_MERKLE, level + 1, parent_index), - &pub_key.public_param, - &left_child, - &right_child, - ); - current_hash = poseidon16_compress(merkle_data)[..XMSS_DIGEST_LEN].try_into().unwrap(); - } - if current_hash == pub_key.merkle_root { - 
Ok(()) - } else { - Err(XmssVerifyError::InvalidMerklePath) - } -} diff --git a/crates/xmss/tests/xmss_tests.rs b/crates/xmss/tests/xmss_tests.rs deleted file mode 100644 index 0fb08e01d..000000000 --- a/crates/xmss/tests/xmss_tests.rs +++ /dev/null @@ -1,62 +0,0 @@ -use backend::*; -use rand::{SeedableRng, rngs::StdRng}; -use xmss::*; - -type F = KoalaBear; - -#[test] -fn test_xmss_serialize_deserialize() { - let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); - let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); - let slot_start = 100; - let slot_end = 115; - let slot = 110; - - let (sk, pk) = xmss_key_gen(keygen_seed, slot_start, slot_end).unwrap(); - let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); - - let pk_bytes = postcard::to_allocvec(&pk).unwrap(); - let pk2: XmssPublicKey = postcard::from_bytes(&pk_bytes).unwrap(); - assert_eq!(pk, pk2); - - let sig_bytes = postcard::to_allocvec(&sig).unwrap(); - let sig2: XmssSignature = postcard::from_bytes(&sig_bytes).unwrap(); - assert_eq!(sig, sig2); - - xmss_verify(&pk2, &message, &sig2, slot).unwrap(); -} - -#[test] -fn keygen_sign_verify() { - let keygen_seed: [u8; 32] = std::array::from_fn(|i| i as u8); - let message: [F; MESSAGE_LEN_FE] = std::array::from_fn(|i| F::from_usize(i * 3 + 7)); - - for slot in [0, 1234, u32::MAX] { - let (sk, pk) = xmss_key_gen(keygen_seed, slot.saturating_sub(1), slot.saturating_add(2)).unwrap(); - let sig = xmss_sign(&mut StdRng::seed_from_u64(slot as u64), &sk, &message, slot).unwrap(); - xmss_verify(&pk, &message, &sig, slot).unwrap(); - } -} - -#[test] -#[ignore] -fn encoding_grinding_bits() { - let n = 100; - let xmss_pub_key = XmssPublicKey { - merkle_root: Default::default(), - public_param: Default::default(), - }; - let total_iters = (0..n) - .into_par_iter() - .map(|i| { - let message: [F; MESSAGE_LEN_FE] = Default::default(); - let slot = i as u32; - let mut rng = StdRng::seed_from_u64(i 
as u64); - let (_randomness, _encoding, num_iters) = - find_randomness_for_wots_encoding(&message, slot, &xmss_pub_key, &mut rng); - num_iters - }) - .sum::(); - let grinding = ((total_iters as f64) / (n as f64)).log2(); - println!("Average grinding bits: {:.1}", grinding); -} diff --git a/crates/xmss/xmss.md b/crates/xmss/xmss.md deleted file mode 100644 index e1538e1ce..000000000 --- a/crates/xmss/xmss.md +++ /dev/null @@ -1,48 +0,0 @@ -# XMSS high-level specification - -## Field - -KoalaBear (p = 2^31 - 2^24 + 1). - -## Hash function - -[Poseidon](https://eprint.iacr.org/2019/458), in compression mode (feedforward addition). Input: 16 field elements. Output: 8 field elements. We denote it `H`. Chain hashes, Merkle hashes, and the final WOTS-pubkey hash truncate the output to 4 field elements (`n`); the encoding step and the intermediate WOTS-pubkey sponge states keep the full 8 elements. - -## Sizes (in field elements) - -- `n = 4`: digest size -- `|pp| = 4`: public parameter -- `|randomness| = 6`: signature randomness -- `|msg| = 8`: message size -- `|tweak| = 2`: tweak (domain separation: `encoding`, `chain`, `wots_pk`, `merkle`) - -## WOTS (Winternitz One Time Signature) - -- `v = 42`: number of hash chains -- `w = 3`, `chain_length = 2^w = 8` -- `target_sum = 184`: a WOTS encoding `(e_0, ..., e_{v-1})` is valid iff each `e_i < chain_length` and `sum(e_i) = target_sum`. The signer grinds `randomness` until the encoding is valid (avoids checksum chains). - -## XMSS - -`log_lifetime = 32`: a key is valid for up to `2^32` slots. `log_lifetime` corresponds to the Merkle tree height. - -## Verification - -Inputs: public key `(merkle_root, pp)`, message `msg`, slot `s`, signature `(randomness, chain_tips, merkle_proof)`. - -1. **Encode**: compute the 8-limb digest `D = H(H(msg | randomness | tweak_encoding(s)) | pp | 0000)`. 
For each limb `D_i`, take the canonical representative `D_i = low + 2^24 · high` (with `low ∈ [0, 2^24)`, `high ∈ [0, 128)`) and reject if `high == 127` (equivalently `D_i == −1`). This guarantees an uniform encoding. Concatenate the 24-bit `low` parts of the 8 limbs in little-endian order to get 192 bits, then take the first `v · w = 126` bits split into `v = 42` little-endian chunks of `w = 3` bits → encoding `(e_0, ..., e_{v-1})` with each `e_i ∈ [0, chain_length)`. Reject if `sum(e_i) ≠ target_sum`. -2. **Recover WOTS public key**: for each `i`, walk chain `i` from `chain_tips[i]` for `chain_length - 1 - e_i` steps, where each step is `H(tweak_chain(i, step, s) | 00 | previous_value | pp | 0000)` truncated to `n`. -3. **Hash WOTS public key**: T-sponge with replacement over the `v` recovered chain ends, with IV `[tweak_wots_pk(s) | 00 | pp]`, ingesting two chain end digests at a time. Output is the Merkle leaf. -4. **Walk Merkle path**: for `level = 0..log_lifetime`, combine the current node with `merkle_proof[level]` (left/right determined by bit `level` of `s`) via `H(tweak_merkle(level+1, parent_index) | 00 | pp | left | right)` truncated to `n`. -5. **Check root**: accept iff the final hash equals `merkle_root`. - - -## Security - -target = 123,9 ≈ 124 bits of classical security in the ROM, and ≈ 62 bits of quantum security in the QROM, with an analysis inspired by the section 3.1 of [Tight adaptive reprogramming in the QROM](https://arxiv.org/pdf/2010.15103). TODO write the complete proof. 
- -## Signature size - -**1171 bytes** `log2(p).(|randomness| + n.(v + log_lifetime))` - -below IPv6 [MTU](https://fr.wikipedia.org/wiki/Maximum_transmission_unit) (1280 bytes) diff --git a/src/lib.rs b/src/lib.rs index 0c982fd82..72e725214 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,26 @@ use backend::*; pub use backend::ProofError; +pub use leansig_wrapper::{ + MESSAGE_LENGTH, XmssPublicKey, XmssSignature, xmss_keygen_fast, xmss_sign_fast, xmss_verify, +}; pub use rec_aggregation::{ - AggregationError, MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, ProverError, TypeOneInfo, - TypeOneMultiSignature, TypeTwoMultiSignature, aggregate_type_1, merge_many_type_1, split_type_2, verify_type_1, - verify_type_2, + AggregatedXMSS, AggregatedXMSSInfo, AggregationError, AggregationTopology, MAX_RECURSIONS, MAX_XMSS_AGGREGATED, + MAX_XMSS_DUPLICATES, ProverError, xmss_aggregate, xmss_verify_aggregation, }; -pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss_key_gen, xmss_sign, xmss_verify}; pub type F = KoalaBear; -/// Call once before proving. Compiles the aggregation program and precomputes DFT twiddles. +/// Call once before proving. Enables the proving arena, compiles the aggregation program, and +/// precomputes DFT twiddles. +/// +/// # Safety +/// Never generate two proofs concurrently in one process. (The arena allocator has a single shared +/// region per process, so concurrent proving corrupts each proof's buffers.) Use separate processes +/// to parallelize. pub fn setup_prover() { + zk_alloc::enable_arena(); + parallel::init(); rec_aggregation::init_aggregation_bytecode(); precompute_dft_twiddles::(1 << 24); } @@ -20,16 +29,3 @@ pub fn setup_prover() { pub fn setup_verifier() { rec_aggregation::init_aggregation_bytecode(); } - -/// Bump-arena allocator. -/// -/// **Optional.** -/// -/// To enable, set it as the `#[global_allocator]` in your binary and call -/// [`init_allocator`] once at startup. 
Then bracket each proving call with -/// [`begin_phase`] / [`end_phase`] and **clone the outputs after -/// [`end_phase`]** so the cloned copy lands in the system allocator before the -/// next [`begin_phase`] resets the arena slabs. -/// -/// See `tests/test_zk_alloc.rs` for a runnable end-to-end example. -pub use zk_alloc::{ZkAllocator, begin_phase, end_phase, init as init_allocator}; diff --git a/src/main.rs b/src/main.rs index 646fc6f64..a0cc00f14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,6 @@ use clap::Parser; use rec_aggregation::benchmark::{AggregationTopology, biggest_leaf, run_aggregation_benchmark}; -#[cfg(not(feature = "standard-alloc"))] -#[global_allocator] -static ALLOC: zk_alloc::ZkAllocator = zk_alloc::ZkAllocator; - #[derive(Parser)] enum Cli { #[command(about = "Aggregate XMSS")] @@ -67,8 +63,8 @@ fn run_with_warmup(topology: &AggregationTopology, tracing: bool, json: bool, re #[allow(clippy::too_many_lines)] fn main() { - #[cfg(not(feature = "standard-alloc"))] - zk_alloc::init(); + zk_alloc::enable_arena(); + parallel::init(); let cli = Cli::parse(); @@ -99,7 +95,7 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 775, + raw_xmss: 700, children: vec![], log_inv_rate, overlap: 0, @@ -124,13 +120,13 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 1550, + raw_xmss: 1400, children: vec![], log_inv_rate: 1, overlap: 0, }, AggregationTopology { - raw_xmss: 508, + raw_xmss: 658, children: vec![], log_inv_rate: 2, overlap: 0, @@ -143,13 +139,13 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 1550, + raw_xmss: 1400, children: vec![], log_inv_rate: 2, overlap: 0, }, AggregationTopology { - raw_xmss: 508, + raw_xmss: 658, children: vec![], log_inv_rate: 2, overlap: 0, @@ -166,7 +162,7 @@ fn main() { raw_xmss: 0, children: vec![ AggregationTopology { - raw_xmss: 775, + raw_xmss: 700, children: vec![], log_inv_rate: 2, overlap: 0, diff --git 
a/tests/test_multisignatures.rs b/tests/test_multisignatures.rs index e90719f13..086b13cd8 100644 --- a/tests/test_multisignatures.rs +++ b/tests/test_multisignatures.rs @@ -1,30 +1,20 @@ -use std::time::Instant; - -use lean_multisig::{ - TypeOneMultiSignature, TypeTwoMultiSignature, aggregate_type_1, merge_many_type_1, setup_prover, split_type_2, - verify_type_1, verify_type_2, -}; -use rand::{RngExt, SeedableRng, rngs::StdRng}; -use rec_aggregation::{ - benchmark::{AggregationTopology, run_aggregation_benchmark}, - split_type_2_by_msg, -}; -use xmss::{ - signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}, - xmss_key_gen, xmss_sign, xmss_verify, -}; +use lean_multisig::{AggregatedXMSS, setup_prover, xmss_aggregate, xmss_verify_aggregation}; +use leansig_wrapper::{xmss_keygen_fast, xmss_sign_fast, xmss_verify}; +use rand::{SeedableRng, rngs::StdRng}; +use rec_aggregation::benchmark::{AggregationTopology, run_aggregation_benchmark}; +use rec_aggregation::signatures_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; #[test] fn test_xmss_signature() { - let start_slot = 111; - let end_slot = 200; + let activation_epoch = 111; + let num_active_epochs = 39; let slot: u32 = 124; let mut rng: StdRng = StdRng::seed_from_u64(0); - let msg = rng.random(); + let msg = [42u8; leansig_wrapper::MESSAGE_LENGTH]; - let (secret_key, pub_key) = xmss_key_gen(rng.random(), start_slot, end_slot).unwrap(); - let signature = xmss_sign(&mut rng, &secret_key, &msg, slot).unwrap(); - xmss_verify(&pub_key, &msg, &signature, slot).unwrap(); + let (secret_key, pub_key) = xmss_keygen_fast(&mut rng, activation_epoch, num_active_epochs); + let signature = xmss_sign_fast(&secret_key, &msg, slot).unwrap(); + xmss_verify(&pub_key, slot, &msg, &signature).unwrap(); } #[test] @@ -41,7 +31,7 @@ fn test_aggregation() { } #[test] -fn test_type_1_aggregation() { +fn test_xmss_aggregate() { setup_prover(); let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = 
faster but bigger proofs) @@ -50,79 +40,67 @@ fn test_type_1_aggregation() { let signatures = get_benchmark_signatures(); let raws_a = signatures[0..3].to_vec(); - let type1_a = aggregate_type_1(&[], raws_a, message, slot, log_inv_rate).unwrap(); + let (_, type1_a) = xmss_aggregate(&[], raws_a, &message, slot, log_inv_rate).unwrap(); let raws_b = signatures[3..5].to_vec(); - let type1_b = aggregate_type_1(&[], raws_b, message, slot, log_inv_rate).unwrap(); + let (_, type1_b) = xmss_aggregate(&[], raws_b, &message, slot, log_inv_rate).unwrap(); let raws_c = signatures[5..6].to_vec(); - let final_sig = aggregate_type_1(&[type1_a, type1_b], raws_c, message, slot, log_inv_rate).unwrap(); + let pks_a = type1_a.info.pubkeys.clone(); + let pks_b = type1_b.info.pubkeys.clone(); + let (_, final_sig) = xmss_aggregate( + &[(&pks_a, type1_a), (&pks_b, type1_b)], + raws_c, + &message, + slot, + log_inv_rate, + ) + .unwrap(); let serialized_proof = final_sig.compress(); println!("Serialized aggregated final: {} KiB", serialized_proof.len() / 1024); - let recovered = TypeOneMultiSignature::decompress(&serialized_proof).unwrap(); + let recovered = AggregatedXMSS::decompress(&serialized_proof).unwrap(); - verify_type_1(&recovered).unwrap(); + xmss_verify_aggregation(recovered.info.pubkeys.clone(), &recovered, &message, slot).unwrap(); } #[test] -fn test_type_2_aggregation() { +fn test_type1_compression() { setup_prover(); - let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs) - let slot_a = BENCHMARK_SLOT; - let message_a = message_for_benchmark(); + let log_inv_rate = 2; + let message = message_for_benchmark(); + let slot = BENCHMARK_SLOT; let signatures = get_benchmark_signatures(); - let raws_a = signatures[0..3].to_vec(); - let slot_b = BENCHMARK_SLOT + 1; - let mut rng_b: StdRng = StdRng::seed_from_u64(17); - let message_b: [_; 8] = std::array::from_fn(|_| rng_b.random()); - - assert!(message_b != message_a && slot_b != slot_a); - - let raws_b: Vec<_> = 
(0..2) - .map(|_| { - let (sk, pk) = xmss_key_gen(rng_b.random(), slot_b, slot_b).unwrap(); - let sig = xmss_sign(&mut rng_b, &sk, &message_b, slot_b).unwrap(); - (pk, sig) - }) - .collect(); - - let type1_a = aggregate_type_1(&[], raws_a, message_a, slot_a, log_inv_rate).unwrap(); - let type1_b = aggregate_type_1(&[], raws_b, message_b, slot_b, log_inv_rate).unwrap(); - - verify_type_1(&type1_a).unwrap(); - verify_type_1(&type1_b).unwrap(); - - let info_a = type1_a.info.clone(); - let info_b = type1_b.info.clone(); - - let time = Instant::now(); - let type2 = merge_many_type_1(vec![type1_a, type1_b], log_inv_rate).unwrap(); - println!("merge_many_type_1: {:.2}s", time.elapsed().as_secs_f64()); - assert_eq!(type2.info.len(), 2); - assert_eq!(type2.info[0], info_a); - assert_eq!(type2.info[1], info_b); - - let compressed_type2 = type2.compress(); - let type2 = TypeTwoMultiSignature::decompress(&compressed_type2).unwrap(); - verify_type_2(&type2).unwrap(); - - let time = Instant::now(); - let split_a = split_type_2(type2.clone(), 0, log_inv_rate).unwrap(); - println!("split index 0: {:.2}s", time.elapsed().as_secs_f64()); - let time = Instant::now(); - let split_b = split_type_2_by_msg(type2, message_b, log_inv_rate).unwrap(); - println!("split index 1: {:.2}s", time.elapsed().as_secs_f64()); - assert_eq!( - (split_a.info.message, &split_a.info.slot, &split_a.info.pubkeys), - (info_a.message, &info_a.slot, &info_a.pubkeys) - ); - assert_eq!( - (split_b.info.message, &split_b.info.slot, &split_b.info.pubkeys), - (info_b.message, &info_b.slot, &info_b.pubkeys) - ); - verify_type_1(&split_a).expect("split index 0 failed verify_type_1"); - verify_type_1(&split_b).expect("split index 1 failed verify_type_1"); + // The pubkey set is shared between prover and verifier. 
+ let raws_a = signatures[..3].to_vec(); + let shared_pubkeys_a = raws_a.iter().map(|(pk, _)| pk.clone()).collect::>(); + let (_, type1_a) = xmss_aggregate(&[], raws_a, &message, slot, log_inv_rate).unwrap(); + + let type1_a_compressed_compact = type1_a.compress_without_pubkeys(); + let type1_a_compact_recovered = + AggregatedXMSS::decompress_without_pubkeys(&type1_a_compressed_compact, shared_pubkeys_a) + .expect("type-1 round-trip"); + xmss_verify_aggregation( + type1_a_compact_recovered.info.pubkeys.clone(), + &type1_a_compact_recovered, + &message, + slot, + ) + .expect("recovered type-1 must verify"); + assert_eq!(type1_a_compact_recovered.info.pubkeys, type1_a.info.pubkeys); + + let type1_a_compressed_full = type1_a.compress(); + let type1_a_full_recovered = AggregatedXMSS::decompress(&type1_a_compressed_full).expect("type-1 round-trip"); + xmss_verify_aggregation( + type1_a_full_recovered.info.pubkeys.clone(), + &type1_a_full_recovered, + &message, + slot, + ) + .expect("recovered type-1 must verify"); + assert_eq!(type1_a_full_recovered.info.pubkeys, type1_a.info.pubkeys); + + assert!(type1_a_compressed_compact.len() < type1_a_compressed_full.len()); } diff --git a/tests/test_vectors/xmss_prod_test_vector.json b/tests/test_vectors/xmss_prod_test_vector.json new file mode 100644 index 000000000..5c799872c --- /dev/null +++ b/tests/test_vectors/xmss_prod_test_vector.json @@ -0,0 +1,6 @@ +{ + "public_key": "0x5caf047c20695a512c5b606a1c702761cb92fd505d63345b126b8a36b54dc3313fd11056749de15f276ca24ee799d921e1ab9c37", + "slot": 5, + "message": "0xabababababababababababababababababababababababababababababababab", + "signature_ssz": 
"0x2400000075871b720d712619164d6966f8066d125d40c34a35df5f006da060672804000004000000c94d0d18492f4e728e435927a595163716bbcf6150aa8b5c9d48367d7e0d911f13b89b4bbffe9a597dc64839b641836e0662114051879e1d41285919b6e4fe7db2c36665ab4b461d10fbb4288b918610f0c3273b2d44760e74091c642729ed0c525d3e35e61bd84df378710c9bce8928314ef56b8e24715b3a07d401a791d8094a88dd47a70af0563cd0960e04d3261a8d898c5fb1d96a5d9a890d6cb48a0311884d9844787b2a35cd938c2ca3d0b41e1cf30e275760367efdfd237985d5296b3d76d32c965ff872cff01a411df6912b9836d34b1e13324abac09d243cc963497056ac7abbbb9d65d82001707cae060b8f873d37be7c1e1788978a260796f7663de1fc0ef2281c1f1be2a55728ad7255228186557344473dd54fd20a6b59e4004982d665df81e46f4edc5a4f2b6aba48283d9872ab07f1289fc74c73051cfb34a8047d384dd3574a0c564b0b3274fe3b4f08e607bfa9020ced67d13954b2c4694a01e84cad2d3e06d7f00f175e6a06347337605ddb6a3d73999cd63ecf61cc6894d83467cba5f45ce9a4755f6b671e07f3e4b8531e1d773c301ee63e06558a1a7791e46ec4769548f49f75002d6a3c2ff468b078699f711fbf204d6671e42c5829b69e3931939e222250fd37a16c8530d268986a57bc3c0aac5f2a3d9c050851f134af1c93212e4cf7647f7763a3da30d3e3dc6e69057c3596e8c96fbea02337ab273f6e4b4d65343cdeb8793e1b0230766e6774b0052a1676482c02319c5603a035505acb89094b10a3810f1bd27034abdfd95a9429d92ab5b59856997e3f2d1709c6794b67ed13b16a091d74ac0c626686d94143fc3a2fb37d9443d3b1bc06f18e6b79a950fc123631b959cda7cb3a33c4d17c6eba386169b6f87559ec3466ceccf01152c4491624a9470b51c37269d3fc0e0040c2bb189a1e48651499de046a20ae7d211fd961ec834025375a0b43a19498483b3ac4768618384a24182101bf2bc67cf649b757e657cf7a12e6ef09e46d311b96a9f64151d76770ec65fb347bec7c1dce9edc4917edd5285711be2d124ad7603b31e54863647474f82ab56c4c954b1d9dd8ba1dafbba61b61aed803124b106a0240bb4a566681314eb44d294d4aa64d3a3d7b1dd0b18739893d284a53f4e913d227617a6373ba436be7051994ab1c077d003f509d588f3ceac96f76ed1110151f526032b21a0c1f00734f466e83060d90715818f7248f6b30372953b2dfbb6990104c26f05d441d866a507ca07a3b2ee9c3130aacd9b3326b9681476299613eeb43313512897743d9367137a772b8457e287e4c4356502bb0cedc3c7d02527b5f594e7a50acd00ed8c79
44b6ead6d6e32daec29f2f08a212b7d6416b36dcc71a40b8d39d7ab9844f3624c5f2896185c2af21857d5fe6a6d1d995743c9a497707dde7b2c2a7a9866fa36150825c18f308b71424321aed4711d789e022156b84b9becf42d645b137b0cd6bb3d320db77ef7afe92a9c906f04c8fcf865a49617150e59992dad75366e0bd48165a241ab2ea1534b2d2d4ac53ffa77e2505fef0f77fc86d23d4ea0a56fa8b5d86cecb9c21c4b71c6616539e268ba9bde1920efca7eb47c671e245bb144a4d05d49a21a760dba6e971aa1af1856484da976ad1494451734f51400b1a433c3e52b2df423cd1642236c6757cbc62296c58e487d9dde78e7c0c7591ee1c84729910725bad3bc5015fb8b0a4cd1371eb56ab66eb676af344c61ef17f54e3c300ff2c305120b4455e3b35572f869dc7b394bae67c1ee5a04af3ecc5063d6d55e3344c60bfc5d444c589628139974c4457956360f3ef0a75b2cc77327918018088ef8fa3e55b7f45fc937780da2cd481bd6a6ae373a382e190ed87a26eaa8a76a1c7090493024f913ed01b83e9cdde0667d28cf095561a056f298003262a6172268503d53e0984a3f602f8930b6283d337d730b4c05b639413767306f0bbee362373f95222639927182d302023d1a527c567f9e3731f9bf0e24d3ac2add41c1030ead1559df0f7e6eef674a5dfd1e817cecaec91c3483890975ebf74b5ff8c05145706038cade10328f1fc65b35828357a195a971b32cb92e94d331614810b20a8c3178610fa72e69261cca4631404c79822e8e36a6e9e0042d8be530759c5a63b4d89116da07b066be038577f6e9122a7eb5a868a02db7565807a026470e7b2c0ccfa63cb314ad04ea1d9016d5bded4019a61467f6778e59b8083204762c716d1416ce070c01bc2469429a160b7bc82a56aa56508ed4195ce386c520f141316ea76b261eddca9303aa379e77d16a6d65cc6c7c2e4415b670050c332fc7d2d36c1be2800d39b1f306d2144775ca0d544d93d0ce5e18daf03982b0c2408c655f4f3c48d851d1726873471b99165f2e30063151803f9dcd2537436407480a8cb95cdb2cfc643deb393ea4021e76fbae5c0bea38a51891a2e320d032ad619067262fc1c44f506823da10a3e38f55d69cd95fead33d2a7c42f912c03ae64ad65db07ddcf95f728fe06838ebdb4b4cf3a65f3061819920efcd5115e3fa5e0fa63cff4b2baf371c85226360639e1709c6625e0288671f263ed84a09a95cce2aa826bc534e959e2162ee596512510c456abd6d3a57cf8763f709c24e05bc223544bea6473b3d526ac76fed313cbc3251d3c3ad1f58f9de3d2bd3ef3d9b55a24caacf600b2c232a398f9918610118fd2ef0a8691a406c5a5094ffb42e0ad1f164b2f0f60117aa2246acba444222365
c34c7884d2c19d0e14204c3ab6254cc36043b5d1114fc6722250485031d63aa974eacee2e46c0c7ea23fed3c30774edc360db1b5529bed9f9080ee061174d75f930e9a7ef6a903b96062d7db1628baf5c1ad4f7171156466c33ff0b446bcd722a5e025ffe3024b688209082da36f899a12b781ab6574383502aaff4504161564f282fb6fd26151c1d2b7d2a9e53da491e3de92a0168b3f17c76337952792d8c75318a70f33fd8db77731961f4297fbb446dce9efb6ebff6175d69ffac4c89c0985d1682d252d1f18d74e1a4306c13741149c8728d2337cd621a76075b4582947c7bd3d948409def8739b9fcc52ddfff5472b5f5130171c40d200f0db706a85ec7054cea874f7491a815ab8f8229339940724fb754489ff1fd6cec3a997e1465e15cafb564786a64563183e4dd1b86e55c32b4e6cc7a7a09372ce4be746f3af52b44dde43e587f34036a30b41d48e5969142b0d5b3009f9de8682a4b081ceee2c50296c3fc4ddee73475293c5650fc34103b3e1c205480b0080d6a1c4277372f7e40a995af5626b1004a0ac4485a0dc87f5bb6a0455413392814cb3adf39ffb3c616b94b3e3a343d27189a179e17c47a556d03cffb54be101d265703e60f729f9e34d3c09d2cf9d98647ab0a5d4e262771682e1d0e21dfd28e7bbf462950e0561a79ace93b3884aeba64a515f402250b1744355c0f2f418f0d2f3a69db46983042316691fd7a4303701026a4aa727f75ef245e15ec047fc7e209" +} \ No newline at end of file diff --git a/tests/test_vectors/xmss_test_vector.json b/tests/test_vectors/xmss_test_vector.json new file mode 100644 index 000000000..e4d36a216 --- /dev/null +++ b/tests/test_vectors/xmss_test_vector.json @@ -0,0 +1,6 @@ +{ + "public_key": "0xa952704f32175406c216be5af21834765545ef4aa13c06633059954c34aff83d68e86d059be8f01d2f57e73427654d192973d93f", + "slot": 4, + "message": "0xabababababababababababababababababababababababababababababababab", + "signature_ssz": 
"0x24000000d4aefd0aa1d828289c6f6544f9b7d35a4fd45a2833ef934d38b9bf2128010000040000008d234c504472c07a0d972037fa8d4b713991ff0724aa0c42bbb3f015f14a753f9e53203de989140f1235826fdf907a3067c98a6a2b967722fef1e573fe153a24f9f6232c0b4931757e849d3267bf476484b033740bd1242fb136b8761b2e5152e3bf951519061b3dd7975a0083e67975857964095b1da40971eeed3061e9160bb815f7734a1cf071ee704966e36e577bf3489256e16e070899119327ea567963600ab260d16adb63a417ff177eb8f02fb46c66132fcaf169e225847745d1fd67913a207da6af054c13ad7d44878fcb08bc902c73f26d644925b8063c6a99361e68fa3b06815c8325cc16b119a346681637a66440804e0849c0828770bfcc3a4436884e5deb882949677c700e2fe53d6ae42c752f7a7eea26a68d122293e0830d31377c197854fc6c9186bb5261adc86e7c17fa25a1fec51965fe060ffbf52a7e8390742a78f98d26d4ba825459cc4d786a5bb14a6ceefc20bad1fe55cee81c35d0ce0a37b0d1961bb09a0f0cadb1d24b216b5d069a5d1f6e73438e680797d219" +} diff --git a/tests/test_zk_alloc.rs b/tests/test_zk_alloc.rs deleted file mode 100644 index d826ed80f..000000000 --- a/tests/test_zk_alloc.rs +++ /dev/null @@ -1,25 +0,0 @@ -use lean_multisig::{ZkAllocator, aggregate_type_1, begin_phase, end_phase, setup_prover, verify_type_1}; -use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark}; - -#[global_allocator] -static ALLOC: ZkAllocator = ZkAllocator; - -#[test] -#[allow(clippy::redundant_clone)] -fn test_aggregation_with_zk_alloc() { - setup_prover(); - - let log_inv_rate = 2; - let message = message_for_benchmark(); - let slot: u32 = BENCHMARK_SLOT; - let signatures = get_benchmark_signatures(); - let raw_xmss = signatures[0..6].to_vec(); - - begin_phase(); - let aggregated = aggregate_type_1(&[], raw_xmss, message, slot, log_inv_rate).unwrap(); - end_phase(); - // IMPORTANT: clone to move the data out of the arena memory - let aggregated = aggregated.clone(); - - verify_type_1(&aggregated).unwrap(); -}