From 11af49466c0e911f265cc4c8103b97d8213d2c66 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:19:40 -0400 Subject: [PATCH 1/3] tls test ete --- src/opengradient/client/tee_connection.py | 6 +- tests/tee_connection_test.py | 135 ++++++++++++++++++++++ 2 files changed, 137 insertions(+), 4 deletions(-) diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index 2a7682d..e7317a1 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -142,12 +142,10 @@ def _connect(self) -> ActiveTEE: """Resolve TEE from registry and create a secure HTTP client.""" tee = self._resolve_tee() - ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) if tee.tls_cert_der else None - tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True - + ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) return ActiveTEE( endpoint=tee.endpoint, - http_client=x402HttpxClient(self._x402_client, verify=tls_verify), + http_client=x402HttpxClient(self._x402_client, verify=ssl_ctx), tee_id=tee.tee_id, payment_address=tee.payment_address, ) diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 3f01912..ea96455 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -1,15 +1,25 @@ """Tests for RegistryTEEConnection and ActiveTEE.""" import asyncio +import datetime +import os import ssl +import tempfile from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID +from x402 import x402Client from src.opengradient.client.tee_connection import ( ActiveTEE, RegistryTEEConnection, ) +from src.opengradient.client.tee_registry import build_ssl_context_from_der # ── Helpers ────────────────────────────────────────────────────────── @@ -327,3 +337,128 @@ async def test_close_without_refresh_task(self): conn = _make_registry_connection(registry=mock_reg) await conn.close() # should not raise + + +# ── TLS certificate verification (real handshake) ──────────────────── + + +def _make_self_signed_cert(): + """Generate a self-signed cert. Returns (der_bytes, pem_cert_bytes, pem_key_bytes).""" + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")]) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after(datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1)) + .sign(key, hashes.SHA256()) + ) + return ( + cert.public_bytes(serialization.Encoding.DER), + cert.public_bytes(serialization.Encoding.PEM), + key.private_bytes(serialization.Encoding.PEM, serialization.PrivateFormat.TraditionalOpenSSL, serialization.NoEncryption()), + ) + + +@pytest.fixture +async def tls_server(): + """Spin up a local TLS server with a self-signed cert.""" + der, pem_cert, pem_key = _make_self_signed_cert() + + cert_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False) + try: + cert_file.write(pem_cert) + cert_file.close() + key_file.write(pem_key) + key_file.close() + + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_ctx.load_cert_chain(cert_file.name, key_file.name) + + async def handler(reader, writer): + await reader.read(4096) + writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok") + await writer.drain() + writer.close() + + server = await asyncio.start_server(handler, "127.0.0.1", 0, ssl=server_ctx) + port = server.sockets[0].getsockname()[1] + + yield {"port": port, "der": der} + + server.close() + await server.wait_closed() + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + + +def _registry_with_real_cert(tls_server): + """Return a mock registry that serves the local TLS server's real DER cert.""" + return _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=tls_server["der"], + tee_id="tee-real", + payment_address="0xRealPay", + ) + + +@pytest.mark.asyncio +class TestTlsCertVerification: + """End-to-end TLS handshake tests through RegistryTEEConnection. + + A real local TLS server is started with a self-signed cert. The registry + mock returns that cert's DER bytes. RegistryTEEConnection._connect() runs + its real code (build_ssl_context_from_der → x402HttpxClient(verify=ctx)) + so the full cert-pinning path is exercised with an actual TLS handshake. + """ + + async def test_connect_succeeds_with_matching_cert(self, tls_server): + mock_reg = _registry_with_real_cert(tls_server) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + assert resp.status_code == 200 + assert conn.get().tee_id == "tee-real" + assert conn.get().payment_address == "0xRealPay" + await conn.close() + + async def test_connect_fails_with_wrong_cert(self, tls_server): + wrong_der, _, _ = _make_self_signed_cert() # different key pair + mock_reg = _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=wrong_der, + ) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + with pytest.raises(httpx.ConnectError): + await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + await conn.close() + + async def test_connect_fails_with_no_cert_pinning(self, tls_server): + """Without a pinned cert (tls_cert_der=None), system CAs are used + which won't trust our self-signed server cert.""" + mock_reg = _mock_registry_with_tee( + endpoint=f"https://127.0.0.1:{tls_server['port']}", + tls_cert_der=None, + ) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + with pytest.raises(httpx.ConnectError): + await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + await conn.close() + + async def test_reconnect_picks_up_new_cert(self, tls_server): + """After reconnect, the connection uses the freshly-resolved cert.""" + mock_reg = _registry_with_real_cert(tls_server) + conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) + + await conn.reconnect() + + resp = await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") + assert resp.status_code == 200 + await conn.close() From 077ca37cba19ab9f71b9fa8879d02611d6523075 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:23:26 -0400 Subject: [PATCH 2/3] fix test --- tests/client_test.py | 7 +++++-- tests/tee_connection_test.py | 19 +++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/client_test.py b/tests/client_test.py index 2e2bf14..a2ef531 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -22,10 +22,13 @@ @pytest.fixture def mock_tee_registry(): """Mock the TEE registry so LLM.__init__ doesn't need a live registry.""" - with patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry: + with patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry, patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(), + ): mock_tee = MagicMock() mock_tee.endpoint = "https://test.tee.server" - mock_tee.tls_cert_der = None + mock_tee.tls_cert_der = b"fake-der" mock_tee.tee_id = "test-tee-id" mock_tee.payment_address = "0xTestPaymentAddress" mock_tee_registry.return_value.get_llm_tee.return_value = mock_tee diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index ea96455..530bc9f 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -45,6 +45,9 @@ def _make_registry_connection(*, registry=None, http_factory=None): with patch( "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=factory, + ), patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), ): return RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -52,7 +55,7 @@ def _make_registry_connection(*, registry=None, http_factory=None): ) -def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=None, tee_id="tee-1", payment_address="0xPay"): +def _mock_registry_with_tee(endpoint="https://tee.endpoint", tls_cert_der=b"fake-der", tee_id="tee-1", payment_address="0xPay"): mock_reg = MagicMock() mock_tee = MagicMock() mock_tee.endpoint = endpoint @@ -182,6 +185,9 @@ def make_client(*args, **kwargs): with patch( "src.opengradient.client.tee_connection.x402HttpxClient", side_effect=make_client, + ), patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), ): conn = RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -440,17 +446,14 @@ async def test_connect_fails_with_wrong_cert(self, tls_server): await conn.close() async def test_connect_fails_with_no_cert_pinning(self, tls_server): - """Without a pinned cert (tls_cert_der=None), system CAs are used - which won't trust our self-signed server cert.""" + """Without a pinned cert (tls_cert_der=None), build_ssl_context_from_der + rejects the None value and connection construction fails.""" mock_reg = _mock_registry_with_tee( endpoint=f"https://127.0.0.1:{tls_server['port']}", tls_cert_der=None, ) - conn = RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) - - with pytest.raises(httpx.ConnectError): - await conn.get().http_client.get(f"https://127.0.0.1:{tls_server['port']}/") - await conn.close() + with pytest.raises(TypeError): + RegistryTEEConnection(x402_client=x402Client(), registry=mock_reg) async def test_reconnect_picks_up_new_cert(self, tls_server): """After reconnect, the connection uses the freshly-resolved cert.""" From eaa7528258d8dbbb63ef45a17634dea9ef4f2dd0 Mon Sep 17 00:00:00 2001 From: "balogh.adam@icloud.com" Date: Thu, 26 Mar 2026 22:25:02 -0400 Subject: [PATCH 3/3] fix all tests --- src/opengradient/client/tee_connection.py | 2 +- tests/client_test.py | 9 ++++-- tests/tee_connection_test.py | 39 +++++++++++++++-------- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/src/opengradient/client/tee_connection.py b/src/opengradient/client/tee_connection.py index e7317a1..2807b88 100644 --- a/src/opengradient/client/tee_connection.py +++ b/src/opengradient/client/tee_connection.py @@ -142,7 +142,7 @@ def _connect(self) -> ActiveTEE: """Resolve TEE from registry and create a secure HTTP client.""" tee = self._resolve_tee() - ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) + ssl_ctx = build_ssl_context_from_der(tee.tls_cert_der) return ActiveTEE( endpoint=tee.endpoint, http_client=x402HttpxClient(self._x402_client, verify=ssl_ctx), diff --git a/tests/client_test.py b/tests/client_test.py index a2ef531..d7d70f8 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -22,9 +22,12 @@ @pytest.fixture def mock_tee_registry(): """Mock the TEE registry so LLM.__init__ doesn't need a live registry.""" - with patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry, patch( - "src.opengradient.client.tee_connection.build_ssl_context_from_der", - return_value=MagicMock(), + with ( + patch("src.opengradient.client.llm.TEERegistry") as mock_tee_registry, + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(), + ), ): mock_tee = MagicMock() mock_tee.endpoint = "https://test.tee.server" diff --git a/tests/tee_connection_test.py b/tests/tee_connection_test.py index 530bc9f..2d54c77 100644 --- a/tests/tee_connection_test.py +++ b/tests/tee_connection_test.py @@ -42,12 +42,15 @@ def _mock_x402_client(): def _make_registry_connection(*, registry=None, http_factory=None): """Build a RegistryTEEConnection with patched externals.""" factory = http_factory or FakeHTTPClient - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=factory, - ), patch( - "src.opengradient.client.tee_connection.build_ssl_context_from_der", - return_value=MagicMock(spec=ssl.SSLContext), + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=factory, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), ): return RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -182,12 +185,15 @@ def make_client(*args, **kwargs): mock_reg = _mock_registry_with_tee() - with patch( - "src.opengradient.client.tee_connection.x402HttpxClient", - side_effect=make_client, - ), patch( - "src.opengradient.client.tee_connection.build_ssl_context_from_der", - return_value=MagicMock(spec=ssl.SSLContext), + with ( + patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=make_client, + ), + patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), ): conn = RegistryTEEConnection( x402_client=_mock_x402_client(), @@ -199,7 +205,6 @@ def make_client(*args, **kwargs): assert conn.get().http_client is not old_client assert len(clients_created) == 2 - async def test_reconnect_swallows_close_failure(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) @@ -224,7 +229,13 @@ def slow_connect(self): mock_reg = _mock_registry_with_tee() conn = _make_registry_connection(registry=mock_reg) - with patch.object(RegistryTEEConnection, "_connect", slow_connect): + with patch.object(RegistryTEEConnection, "_connect", slow_connect), patch( + "src.opengradient.client.tee_connection.build_ssl_context_from_der", + return_value=MagicMock(spec=ssl.SSLContext), + ), patch( + "src.opengradient.client.tee_connection.x402HttpxClient", + side_effect=FakeHTTPClient, + ): await asyncio.gather(conn.reconnect(), conn.reconnect()) assert call_order == ["start", "end", "start", "end"]