From 81ea55bc989c379df579ea2403649700c918af21 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 13 Feb 2026 22:57:15 +0100 Subject: [PATCH] test framework: refactor key management --- .../testing/src/consensus_testing/keys.py | 116 +++++++++--------- packages/testing/src/framework/cli/fill.py | 6 +- 2 files changed, 59 insertions(+), 63 deletions(-) diff --git a/packages/testing/src/consensus_testing/keys.py b/packages/testing/src/consensus_testing/keys.py index 8eb84497..ba9b75f2 100755 --- a/packages/testing/src/consensus_testing/keys.py +++ b/packages/testing/src/consensus_testing/keys.py @@ -1,6 +1,5 @@ """ XMSS Key Management for Consensus Testing -========================================== Management of XMSS key pairs for test validators. @@ -17,10 +16,11 @@ python -m consensus_testing.keys --count 20 # more validators python -m consensus_testing.keys --max-slot 200 # longer lifetime -File Format: - Each key pair is stored in a separate JSON file with hex-encoded SSZ. - Directory structure: test_keys/{scheme}_scheme/{index}.json - Each file contains: {"public": "0a1b...", "secret": "2c3d..."} +File format: + +- Each key pair is stored in a separate JSON file with hex-encoded SSZ. +- Directory structure: ``test_keys/{scheme}_scheme/{index}.json`` +- Each file contains: ``{"public": "0a1b...", "secret": "2c3d..."}`` """ from __future__ import annotations @@ -32,10 +32,10 @@ import tarfile import tempfile import urllib.request +from collections.abc import Iterator, Mapping from concurrent.futures import ProcessPoolExecutor from functools import partial from pathlib import Path -from typing import TYPE_CHECKING, Iterator from lean_spec.config import LEAN_ENV from lean_spec.subspecs.containers import AttestationData, ValidatorIndex @@ -57,23 +57,31 @@ ) from lean_spec.types import Uint64 -if TYPE_CHECKING: - from collections.abc import Mapping +__all__ = [ + "CLI_DEFAULT_MAX_SLOT", + "KEY_DOWNLOAD_URLS", + "LEAN_ENV_TO_SCHEMES", + "LazyKeyDict", + "NUM_VALIDATORS", + "XmssKeyManager", + "download_keys", + "get_keys_dir", + "get_shared_key_manager", +] -# Pre-generated key download URLs KEY_DOWNLOAD_URLS = { "test": "https://github.com/leanEthereum/leansig-test-keys/releases/download/leanSpec-77bde6b/test_scheme.tar.gz", "prod": "https://github.com/leanEthereum/leansig-test-keys/releases/download/leanSpec-77bde6b/prod_scheme.tar.gz", } """URLs for downloading pre-generated keys.""" -# Signature scheme definitions LEAN_ENV_TO_SCHEMES = { "test": TEST_SIGNATURE_SCHEME, "prod": PROD_SIGNATURE_SCHEME, } """ Mapping from short name to scheme objects. This mapping is useful for: + - The CLI argument for choosing the signature scheme to generate - Deriving the file name for the cached keys - Caching key managers in test fixtures @@ -82,14 +90,11 @@ _KEY_MANAGER_CACHE: dict[tuple[str, Slot], XmssKeyManager] = {} """Cache for key managers: {(scheme_name, max_slot): XmssKeyManager}""" -_VALIDATOR_INDEX_CACHE: dict[Uint64, ValidatorIndex] = {} -"""Cache for converting Uint64 to ValidatorIndex.""" - -_DEFAULT_MAX_SLOT: Slot = Slot(10) -"""Default number of max slots that the shared key manager is generated with""" +_SHARED_MANAGER_MAX_SLOT: Slot = Slot(10) +"""Default max slot for the shared key manager.""" -def get_shared_key_manager(max_slot: Slot = _DEFAULT_MAX_SLOT) -> XmssKeyManager: +def get_shared_key_manager(max_slot: Slot = _SHARED_MANAGER_MAX_SLOT) -> XmssKeyManager: """ Get a shared XMSS key manager for reusing keys across tests. @@ -116,43 +121,43 @@ def get_shared_key_manager(max_slot: Slot = _DEFAULT_MAX_SLOT) -> XmssKeyManager return manager -NUM_VALIDATORS = 12 +NUM_VALIDATORS: int = 12 """Default number of validator key pairs.""" -DEFAULT_MAX_SLOT = Slot(100) -"""Maximum slot for test signatures (inclusive).""" +CLI_DEFAULT_MAX_SLOT = Slot(100) +"""Maximum slot for CLI-generated test signatures (inclusive).""" -NUM_ACTIVE_EPOCHS = int(DEFAULT_MAX_SLOT) + 1 -"""Key lifetime in epochs (derived from DEFAULT_MAX_SLOT).""" - -def _get_keys_dir(scheme_name: str) -> Path: +def get_keys_dir(scheme_name: str) -> Path: """Get the keys directory path for the given scheme.""" return Path(__file__).parent / "test_keys" / f"{scheme_name}_scheme" -class LazyKeyDict: +class LazyKeyDict(Mapping[ValidatorIndex, KeyPair]): """Load pre-generated keys from disk (cached after first call).""" def __init__(self, scheme_name: str) -> None: """Initialize with scheme name for locating key files.""" self._scheme_name = scheme_name - self._keys_dir = _get_keys_dir(scheme_name) + self._keys_dir = get_keys_dir(scheme_name) self._cache: dict[ValidatorIndex, KeyPair] = {} - self._available_indices: set[int] | None = None + self._available_indices: set[ValidatorIndex] | None = None def _ensure_dir_exists(self) -> None: + """Raise FileNotFoundError if the keys directory does not exist.""" if not self._keys_dir.exists(): raise FileNotFoundError( f"Keys directory not found: {self._keys_dir} - " f"Run: python -m consensus_testing.keys --scheme {self._scheme_name}" ) - def _get_available_indices(self) -> set[int]: + def _get_available_indices(self) -> set[ValidatorIndex]: """Scan directory for available key indices (cached).""" if self._available_indices is None: self._ensure_dir_exists() - self._available_indices = {int(f.stem) for f in self._keys_dir.glob("*.json")} + self._available_indices = { + ValidatorIndex(int(f.stem)) for f in self._keys_dir.glob("*.json") + } if not self._available_indices: raise FileNotFoundError( f"No key files found in: {self._keys_dir} - " @@ -160,7 +165,7 @@ def _get_available_indices(self) -> set[int]: ) return self._available_indices - def _load_key(self, idx: int) -> KeyPair: + def _load_key(self, idx: ValidatorIndex) -> KeyPair: """Load a single key from disk.""" key_file = self._keys_dir / f"{idx}.json" if not key_file.exists(): @@ -171,12 +176,14 @@ def _load_key(self, idx: int) -> KeyPair: def __getitem__(self, idx: ValidatorIndex) -> KeyPair: """Get key pair by validator index, loading from disk if needed.""" if idx not in self._cache: - self._cache[idx] = self._load_key(int(idx)) + self._cache[idx] = self._load_key(idx) return self._cache[idx] - def __contains__(self, idx: ValidatorIndex) -> bool: + def __contains__(self, idx: object) -> bool: """Check if a key exists for the given validator index.""" - return int(idx) in self._get_available_indices() + if not isinstance(idx, ValidatorIndex): + return False + return idx in self._get_available_indices() def __len__(self) -> int: """Return the number of available keys.""" @@ -184,12 +191,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[ValidatorIndex]: """Iterate over available validator indices in sorted order.""" - return iter(ValidatorIndex(i) for i in sorted(self._get_available_indices())) - - def items(self) -> Iterator[tuple[ValidatorIndex, KeyPair]]: - """Iterate over all keys (loads all into memory).""" - for idx in self: - yield idx, self[idx] + return iter(sorted(self._get_available_indices())) _LAZY_KEY_CACHE: dict[str, LazyKeyDict] = {} @@ -203,16 +205,6 @@ class XmssKeyManager: Handles automatic key state advancement for the stateful XMSS scheme. Keys are lazily loaded from disk on first access. - - Args: - max_slot: Maximum slot for signatures. - scheme: XMSS scheme instance. - - Examples: - >>> mgr = XmssKeyManager() - >>> mgr[Uint64(0)] # Get key pair - >>> mgr.get_public_key(Uint64(1)) # Get public key only - >>> mgr.sign_attestation_data(validator_id, attestation_data) # Sign with auto-advancement """ def __init__( @@ -225,9 +217,12 @@ def __init__( self.scheme = scheme self._state: dict[ValidatorIndex, KeyPair] = {} - for scheme_name, scheme_obj in LEAN_ENV_TO_SCHEMES.items(): - if scheme_obj is scheme: - self.scheme_name = scheme_name + try: + self.scheme_name = next( + name for name, obj in LEAN_ENV_TO_SCHEMES.items() if obj is scheme + ) + except StopIteration: + raise ValueError(f"Unknown scheme: {scheme}") from None @property def keys(self) -> LazyKeyDict: @@ -244,8 +239,10 @@ def __getitem__(self, idx: ValidatorIndex) -> KeyPair: raise KeyError(f"Validator {idx} not found (max: {len(self.keys) - 1})") return self.keys[idx] - def __contains__(self, idx: ValidatorIndex) -> bool: + def __contains__(self, idx: object) -> bool: """Check if validator index exists.""" + if not isinstance(idx, ValidatorIndex): + return False return idx in self.keys def __len__(self) -> int: @@ -311,7 +308,7 @@ def build_attestation_signatures( signature_lookup: Mapping[SignatureKey, Signature] | None = None, ) -> AttestationSignatures: """ - Build `AttestationSignatures` for already-aggregated attestations. + Build attestation signatures for already-aggregated attestations. For each aggregated attestation, collect the participating validators' public keys and signatures, then produce a single leanVM aggregated signature proof. @@ -371,7 +368,7 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: max_slot: Maximum slot (key lifetime = max_slot + 1 epochs). """ scheme = LEAN_ENV_TO_SCHEMES[lean_env] - keys_dir = _get_keys_dir(lean_env) + keys_dir = get_keys_dir(lean_env) num_epochs = max_slot + 1 num_workers = os.cpu_count() or 1 @@ -392,8 +389,7 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: # Save each keypair to a separate file for idx, key_pair in enumerate(key_pairs): key_file = keys_dir / f"{idx}.json" - with open(key_file, "w") as f: - json.dump(key_pair, f, indent=2) + key_file.write_text(json.dumps(key_pair, indent=2)) print(f"Saved {len(key_pairs)} key pairs to {keys_dir}/") @@ -401,7 +397,7 @@ def _generate_keys(lean_env: str, count: int, max_slot: int) -> None: _LAZY_KEY_CACHE.clear() -def _download_keys(scheme: str) -> None: +def download_keys(scheme: str) -> None: """ Download pre-generated XMSS key pairs from GitHub releases. @@ -439,7 +435,7 @@ def _download_keys(scheme: str) -> None: # Extract tar.gz with tarfile.open(tmp_path, "r:gz") as tar: - tar.extractall(path=base_dir) + tar.extractall(path=base_dir, filter="data") print(f"Extracted {scheme} keys to {target_dir}/") @@ -480,14 +476,14 @@ def main() -> None: parser.add_argument( "--max-slot", type=int, - default=int(DEFAULT_MAX_SLOT), + default=int(CLI_DEFAULT_MAX_SLOT), help="Maximum slot (key lifetime = max_slot + 1)", ) args = parser.parse_args() # Download keys instead of generating if specified if args.download: - _download_keys(scheme=args.scheme) + download_keys(scheme=args.scheme) return _generate_keys(lean_env=args.scheme, count=args.count, max_slot=args.max_slot) diff --git a/packages/testing/src/framework/cli/fill.py b/packages/testing/src/framework/cli/fill.py index 242fec2f..6cc5478b 100644 --- a/packages/testing/src/framework/cli/fill.py +++ b/packages/testing/src/framework/cli/fill.py @@ -81,14 +81,14 @@ def fill( # Check and download keys if needed (only for consensus layer) if layer.lower() == "consensus": # Import here to avoid loading leanSpec modules before LEAN_ENV is set - from consensus_testing.keys import _download_keys, _get_keys_dir + from consensus_testing.keys import download_keys, get_keys_dir - keys_dir = _get_keys_dir(scheme.lower()) + keys_dir = get_keys_dir(scheme.lower()) # Check if keys already exist, if not, download them if not (keys_dir.exists() and any(keys_dir.glob("*.json"))): click.echo(f"Test keys for '{scheme}' scheme not found. Downloading...") - _download_keys(scheme.lower()) + download_keys(scheme.lower()) config_path = Path(__file__).parent / "pytest_ini_files" / "pytest-fill.ini" # Find project root by looking for pyproject.toml with [tool.uv.workspace]