From c1f901135b9a465ba65986bae5c09fe980778420 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 29 May 2026 05:53:37 -0700 Subject: [PATCH] feat: add persistent PQ key storage --- pyproject.toml | 1 + switchboard/pq_keys.py | 184 +++++++++++++++++++++++++++++++++++++++++ tests/test_pq_keys.py | 102 +++++++++++++++++++++++ 3 files changed, 287 insertions(+) create mode 100644 switchboard/pq_keys.py create mode 100644 tests/test_pq_keys.py diff --git a/pyproject.toml b/pyproject.toml index d7bb061..8c65540 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ ] dependencies = [ "sortedcontainers>=2.4.0", + "cryptography>=42.0.0", ] [project.optional-dependencies] diff --git a/switchboard/pq_keys.py b/switchboard/pq_keys.py new file mode 100644 index 0000000..7473167 --- /dev/null +++ b/switchboard/pq_keys.py @@ -0,0 +1,184 @@ +"""Persistent PQ key management for switchboard. + +This module layers durable storage and key IDs on top of ``switchboard.pq``. +It is import-safe without ``liboqs`` installed: loading and saving existing +keys works without PQ runtime support, while ``generate()`` / ``sign()`` / +``verify()`` defer to ``switchboard.pq`` and raise there when liboqs is absent. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import ClassVar + +from . import pq + +__all__ = ["PQKeyPair", "verify"] + +_PEM_BEGIN = "-----BEGIN SWITCHBOARD PQ KEY-----" +_PEM_END = "-----END SWITCHBOARD PQ KEY-----" +_PUB_BEGIN = "-----BEGIN SWITCHBOARD PQ PUBLIC KEY-----" +_PUB_END = "-----END SWITCHBOARD PQ PUBLIC KEY-----" +_SCRYPT_N = 32768 +_SCRYPT_R = 8 +_SCRYPT_P = 1 +_SALT_BYTES = 16 +_NONCE_BYTES = 12 + + +def _require_crypto(): + try: + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + except ImportError as exc: # pragma: no cover - env-dependent + raise RuntimeError( + "cryptography is required for encrypted PQ key storage" + ) from exc + return ChaCha20Poly1305 + + +def _b64e(data: bytes) -> str: + return base64.b64encode(data).decode("ascii") + + +def _b64d(data: str) -> bytes: + return base64.b64decode(data.encode("ascii"), validate=True) + + +def _derive_key(passphrase: bytes, salt: bytes, n: int, r: int, p: int) -> bytes: + try: + from cryptography.hazmat.primitives.kdf.scrypt import Scrypt + + return Scrypt(salt=salt, length=32, n=n, r=r, p=p).derive(passphrase) + except ImportError: # pragma: no cover - env-dependent fallback + return hashlib.scrypt(passphrase, salt=salt, n=n, r=r, p=p, dklen=32, maxmem=0) + + +def _key_id(pk: bytes) -> str: + return hashlib.sha256(pk).digest()[:16].hex() + + +def _parse_pem_lines(path: Path, begin: str, end: str) -> tuple[dict[str, str], str]: + lines = path.read_text().splitlines() + if not lines or lines[0].strip() != begin or lines[-1].strip() != end: + raise ValueError(f"invalid key envelope in {path}") + headers: dict[str, str] = {} + payload_lines: list[str] = [] + for line in lines[1:-1]: + if ": " in line and not payload_lines: + k, v = line.split(": ", 1) + headers[k.strip()] = v.strip() + elif line.strip(): + payload_lines.append(line.strip()) + if not payload_lines: + raise ValueError(f"missing key payload in {path}") + return headers, "".join(payload_lines) + + +@dataclass(slots=True) +class PQKeyPair: + alg: str + sk: bytes + pk: bytes + key_id: str = field(init=False) + + DEFAULT_ALG: ClassVar[str] = "ml-dsa-65" + + def __post_init__(self) -> None: + if not isinstance(self.sk, bytes) or not isinstance(self.pk, bytes): + raise TypeError("sk and pk must be raw bytes") + pq._check_alg(self.alg) + self.key_id = _key_id(self.pk) + + @classmethod + def generate(cls, alg: str = DEFAULT_ALG) -> "PQKeyPair": + pk, sk = pq.generate(alg) + return cls(alg=alg, sk=sk, pk=pk) + + @classmethod + def load(cls, path: str, passphrase: bytes | None = None) -> "PQKeyPair": + key_path = Path(path) + headers, payload = _parse_pem_lines(key_path, _PEM_BEGIN, _PEM_END) + alg = headers["alg"] + pq._check_alg(alg) + + cipher = headers.get("cipher") + raw = _b64d(payload) + + if cipher != "chacha20-poly1305" or headers.get("kdf") != "scrypt": + raise ValueError("unsupported PQ key envelope") + passphrase = passphrase or b"" + salt = bytes.fromhex(headers["kdf-salt"]) + nonce = bytes.fromhex(headers["cipher-nonce"]) + tag = bytes.fromhex(headers["cipher-tag"]) + n = int(headers["kdf-n"]) + r = int(headers["kdf-r"]) + p = int(headers["kdf-p"]) + key = _derive_key(passphrase, salt, n, r, p) + ChaCha20Poly1305 = _require_crypto() + try: + sk = ChaCha20Poly1305(key).decrypt(nonce, raw + tag, None) + except Exception as exc: # noqa: BLE001 + raise ValueError("invalid passphrase or corrupted PQ key file") from exc + + pub_headers, pub_payload = _parse_pem_lines(key_path.with_suffix(key_path.suffix + ".pub"), _PUB_BEGIN, _PUB_END) + if pub_headers.get("alg") != alg: + raise ValueError("public key algorithm does not match private key") + pk = _b64d(pub_payload) + key = cls(alg=alg, sk=sk, pk=pk) + expected_key_id = pub_headers.get("key-id") + if expected_key_id and expected_key_id != key.key_id: + raise ValueError("public key key-id does not match contents") + return key + + def save(self, path: str, passphrase: bytes | None = None) -> None: + key_path = Path(path) + key_path.parent.mkdir(parents=True, exist_ok=True) + + passphrase = passphrase or b"" + salt = os.urandom(_SALT_BYTES) + nonce = os.urandom(_NONCE_BYTES) + key = _derive_key(passphrase, salt, _SCRYPT_N, _SCRYPT_R, _SCRYPT_P) + ChaCha20Poly1305 = _require_crypto() + sealed = ChaCha20Poly1305(key).encrypt(nonce, self.sk, None) + payload, tag = sealed[:-16], sealed[-16:] + headers = { + "alg": self.alg, + "kdf": "scrypt", + "kdf-salt": salt.hex(), + "kdf-n": str(_SCRYPT_N), + "kdf-r": str(_SCRYPT_R), + "kdf-p": str(_SCRYPT_P), + "cipher": "chacha20-poly1305", + "cipher-nonce": nonce.hex(), + "cipher-tag": tag.hex(), + } + + key_text = "\n".join( + [_PEM_BEGIN] + + [f"{k}: {v}" for k, v in headers.items()] + + [_b64e(payload), _PEM_END, ""] + ) + key_path.write_text(key_text) + + pub_text = "\n".join( + [ + _PUB_BEGIN, + f"alg: {self.alg}", + f"key-id: {self.key_id}", + _b64e(self.pk), + _PUB_END, + "", + ] + ) + key_path.with_suffix(key_path.suffix + ".pub").write_text(pub_text) + + def sign(self, transcript: bytes) -> bytes: + return pq.sign(self.alg, self.sk, transcript) + + +def verify(alg: str, pk: bytes, transcript: bytes, sig: bytes) -> bool: + return pq.verify(alg, pk, transcript, sig) diff --git a/tests/test_pq_keys.py b/tests/test_pq_keys.py new file mode 100644 index 0000000..b795f00 --- /dev/null +++ b/tests/test_pq_keys.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import importlib + +import pytest + +from switchboard import pq +from switchboard import pq_keys + + +def test_load_save_roundtrip_with_passphrase(tmp_path) -> None: + pair = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"secret-key-material", pk=b"public-key-material") + path = tmp_path / "agent-pq.key" + + pair.save(str(path), passphrase=b"correct horse battery staple") + loaded = pq_keys.PQKeyPair.load(str(path), passphrase=b"correct horse battery staple") + + assert loaded.alg == pair.alg + assert loaded.sk == pair.sk + assert loaded.pk == pair.pk + assert loaded.key_id == pair.key_id + assert path.with_suffix(path.suffix + ".pub").exists() + + +def test_wrong_passphrase_fails_loudly(tmp_path) -> None: + pair = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"secret-key-material", pk=b"public-key-material") + path = tmp_path / "agent-pq.key" + pair.save(str(path), passphrase=b"right-passphrase") + + with pytest.raises(ValueError, match="invalid passphrase|corrupted"): + pq_keys.PQKeyPair.load(str(path), passphrase=b"wrong-passphrase") + + +def test_key_id_is_deterministic() -> None: + pk = b"same-public-key" + a = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"one", pk=pk) + b = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"two", pk=pk) + + assert a.key_id == b.key_id + assert len(a.key_id) == 32 + + +def test_generate_uses_pq_module(monkeypatch) -> None: + def fake_generate(alg: str): + assert alg == "ml-dsa-44" + return (b"pub", b"sec") + + monkeypatch.setattr(pq_keys.pq, "generate", fake_generate) + pair = pq_keys.PQKeyPair.generate("ml-dsa-44") + assert pair.pk == b"pub" + assert pair.sk == b"sec" + assert pair.alg == "ml-dsa-44" + + +def test_sign_and_verify_delegate(monkeypatch) -> None: + pair = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"sec", pk=b"pub") + calls = [] + + def fake_sign(alg: str, sk: bytes, transcript: bytes) -> bytes: + calls.append(("sign", alg, sk, transcript)) + return b"sig" + + def fake_verify(alg: str, pk: bytes, transcript: bytes, sig: bytes) -> bool: + calls.append(("verify", alg, pk, transcript, sig)) + return True + + monkeypatch.setattr(pq_keys.pq, "sign", fake_sign) + monkeypatch.setattr(pq_keys.pq, "verify", fake_verify) + + sig = pair.sign(b"transcript") + ok = pq_keys.verify(pair.alg, pair.pk, b"transcript", sig) + + assert sig == b"sig" + assert ok is True + assert calls == [ + ("sign", "ml-dsa-65", b"sec", b"transcript"), + ("verify", "ml-dsa-65", b"pub", b"transcript", b"sig"), + ] + + +def test_save_load_without_passphrase_still_uses_envelope(tmp_path) -> None: + pair = pq_keys.PQKeyPair(alg="ml-dsa-65", sk=b"secret-key-material", pk=b"public-key-material") + path = tmp_path / "agent-pq.key" + + pair.save(str(path)) + loaded = pq_keys.PQKeyPair.load(str(path)) + + assert loaded.sk == pair.sk + assert loaded.pk == pair.pk + text = path.read_text() + assert "kdf: scrypt" in text + assert "cipher: chacha20-poly1305" in text + assert "cipher-tag:" in text + + +def test_import_safe_without_oqs(monkeypatch) -> None: + monkeypatch.setattr(pq, "HAS_OQS", False) + reloaded = importlib.reload(pq_keys) + pair = reloaded.PQKeyPair(alg="ml-dsa-65", sk=b"secret", pk=b"public") + assert hasattr(reloaded, "PQKeyPair") + assert hasattr(reloaded, "verify") + assert pair.key_id