diff --git a/evolution/core/hermes_provider.py b/evolution/core/hermes_provider.py index 0a794dba..9642628e 100644 --- a/evolution/core/hermes_provider.py +++ b/evolution/core/hermes_provider.py @@ -721,24 +721,16 @@ def _maybe_resolve_nous_lm( target_model: str, role: Role, ) -> Optional[ResolvedLM]: - """Build a NousLM-backed ResolvedLM when the auth.json pool entry - looks OAuth-managed; return None to let the caller fall through to - the generic OpenAI-wire handler when the entry is just an env-var- - style API key. - - Nous uses a two-stage credential model: an OAuth access_token - (long-lived) is exchanged for a short-lived agent_key that's the - actual inference Bearer. NousLM handles both: refresh access_token - in-memory when expiring, mint a fresh agent_key from it, re-mint on - inference 401. See evolution/core/nous_lm.py. - - The "looks OAuth-managed" signal: pool entry has a refresh_token. A - pool entry without refresh_token is either env-var-only (NOUS_API_KEY - set, no real OAuth state) or hand-edited; let the caller fall - through to direct pass-through so we don't break that setup. - - The CodexLM-equivalent NousLM import is lazy to avoid a circular - dependency: nous_lm imports HermesProviderError from this module. + """Build a NousLM-backed ResolvedLM when the pool entry has a + refresh_token (the OAuth-managed signal). Returns None for env-var + or hand-edited entries with an agent_key already present (caller + falls through to the generic OpenAI-wire handler), and raises for + partial OAuth setups (access_token without refresh_token or + agent_key) so the operator gets a `hermes model` recovery hint + instead of a silent inference 401. + + See ``evolution/core/nous_lm.py`` for the two-stage credential + model and the in-memory refresh + mint flow. """ pool_entry = _pick_pool_entry(auth_store, "nous") if pool_entry is None: diff --git a/evolution/core/nous_lm.py b/evolution/core/nous_lm.py index 8e7feca5..6e05e9ef 100644 --- a/evolution/core/nous_lm.py +++ b/evolution/core/nous_lm.py @@ -197,8 +197,8 @@ def __init__( agent_key_expires_at=agent_key_expires_at, ) - # Initial mint if the constructor-supplied agent_key is missing or - # already expiring. Cheap on the happy path; one POST otherwise. + # Pay the mint cost at construction so the first forward() doesn't + # see a synchronous round-trip surprise. self._ensure_credentials() # ------------------------------------------------------------------ @@ -207,8 +207,13 @@ def __init__( def _oauth_needs_refresh(self) -> bool: if self._shared_state.oauth_expires_at is None: - # Unknown expiry → don't speculatively refresh; let the mint - # call surface a 401 if the access_token is actually dead. + # Unknown expiry → don't speculatively refresh; the mint + # call's own 401-triggers-refresh-retry path catches a + # genuinely-dead access_token. Note _agent_key_needs_mint + # makes the opposite choice (defaults True on unknown + # expiry) because there's no equivalent recovery for a + # missing agent_key — inference would just 401 with no + # built-in retry. return False return ( time.time() + OAUTH_REFRESH_SKEW_SECONDS @@ -219,8 +224,8 @@ def _agent_key_needs_mint(self) -> bool: if not self._shared_state.agent_key: return True if self._shared_state.agent_key_expires_at is None: - # Have a key but no expiry — treat as needing re-mint to be - # safe. Cheaper than letting it 401 mid-run. + # Have a key but no expiry → re-mint defensively. See + # _oauth_needs_refresh for the asymmetric reasoning. return True return ( time.time() + AGENT_KEY_REFRESH_SKEW_SECONDS @@ -228,7 +233,6 @@ def _agent_key_needs_mint(self) -> bool: ) def _sync_from_shared_state(self) -> None: - """Pull the latest agent_key out of shared state into self.kwargs.""" self.kwargs["api_key"] = self._shared_state.agent_key or "" def _ensure_credentials(self) -> None: @@ -254,8 +258,16 @@ def _force_remint(self) -> None: """Skip skew check and re-mint immediately. Called when an inference call returned 401 — the cached agent_key is bad and we don't want to wait for the skew window. + + Pre-checks the OAuth expiry too. Without this, a stale OAuth + + revoked agent_key combo takes three round-trips (mint→401→ + refresh→mint); with the pre-check it's two (refresh→mint). The + mint's 401-triggers-refresh path still backstops the case where + OAuth looks fresh by skew but the portal has revoked it. """ with self._shared_state.lock: + if self._oauth_needs_refresh(): + self._refresh_oauth() self._mint_agent_key(allow_oauth_retry=True) self._sync_from_shared_state() @@ -383,6 +395,15 @@ def _mint_agent_key(self, *, allow_oauth_retry: bool) -> None: raise HermesProviderError(_format_mint_error(response)) def _absorb_mint_response(self, response: httpx.Response) -> None: + """Parse a 200 mint response into shared state. + + Tolerates both the current ``api_key`` field and the older + ``agent_key`` shape, and prefers a server-supplied ``expires_at`` + ISO 8601 timestamp over the relative ``expires_in``. When neither + expiry field is parseable, falls back to the requested floor TTL + with a warning so portal protocol drift doesn't silently cache a + key for longer than the server intended. + """ try: payload = response.json() except ValueError as exc: @@ -391,9 +412,6 @@ def _absorb_mint_response(self, response: httpx.Response) -> None: "Run `hermes model` to re-authenticate." ) from exc - # Hermes uses both ``api_key`` (current portal field) and falls back - # to ``agent_key`` (older shape). Mirror both so a portal protocol - # rev doesn't break us. agent_key = payload.get("api_key") or payload.get("agent_key") if not isinstance(agent_key, str) or not agent_key.strip(): raise HermesProviderError( @@ -489,7 +507,10 @@ def _format_oauth_error(response: httpx.Response) -> str: """ code, detail = _parse_error_body(response) - if code == "refresh_token_reused" or "reuse" in detail.lower(): + # Match the explicit code field, not the free-form detail string — + # a substring search on detail would false-positive on unrelated + # portal messages like "this is not a reusable connection". + if "reused" in code.lower(): return ( "Nous Portal refresh token was already consumed by another " "client (the portal enforces single-use refresh-token rotation). " @@ -529,27 +550,36 @@ def _format_mint_error(response: httpx.Response) -> str: def _parse_error_body(response: httpx.Response) -> tuple[str, str]: """Best-effort parse of OAuth-style error JSON. Returns (code, detail) with sensible defaults when the body is missing or malformed. + + On JSON parse failure (e.g., a CDN returning an HTML error page, + or a portal outage returning text), ``detail`` falls back to a + truncated snippet of the raw body so the operator can correlate + the failure with what the upstream actually sent. """ code = "unknown" detail = f"status {response.status_code}" try: body = response.json() - if isinstance(body, dict): - err = body.get("error") - if isinstance(err, dict): - # OpenAI shape: {"error": {"code": ..., "message": ...}} - nested_code = err.get("code") or err.get("type") - if isinstance(nested_code, str) and nested_code.strip(): - code = nested_code.strip() - nested_msg = err.get("message") - if isinstance(nested_msg, str) and nested_msg.strip(): - detail = nested_msg.strip() - elif isinstance(err, str) and err.strip(): - # OAuth-spec shape: {"error": "code", "error_description": "..."} - code = err.strip() - desc = body.get("error_description") or body.get("message") - if isinstance(desc, str) and desc.strip(): - detail = desc.strip() except ValueError: - pass + snippet = (response.text or "").strip() + if snippet: + detail = f"status {response.status_code}: {snippet[:512]}" + return code, detail + + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + # OpenAI shape: {"error": {"code": ..., "message": ...}} + nested_code = err.get("code") or err.get("type") + if isinstance(nested_code, str) and nested_code.strip(): + code = nested_code.strip() + nested_msg = err.get("message") + if isinstance(nested_msg, str) and nested_msg.strip(): + detail = nested_msg.strip() + elif isinstance(err, str) and err.strip(): + # OAuth-spec shape: {"error": "code", "error_description": "..."} + code = err.strip() + desc = body.get("error_description") or body.get("message") + if isinstance(desc, str) and desc.strip(): + detail = desc.strip() return code, detail diff --git a/tests/core/test_nous_lm.py b/tests/core/test_nous_lm.py index 25c16caa..f6660fb6 100644 --- a/tests/core/test_nous_lm.py +++ b/tests/core/test_nous_lm.py @@ -753,6 +753,116 @@ def test_aforward_propagates_second_401_as_hermes_provider_error(self): ) +class TestForceRemintPreChecksOAuth: + """When inference 401s and we force a re-mint, an OAuth that's also + expiring should be refreshed FIRST, not re-discovered via mint→401→ + refresh→mint (three round-trips). Saves one hop on the rare double- + stale path; the mint's own 401-retry still backstops the case where + OAuth looks fresh by skew but the portal has revoked it. + """ + + def test_force_remint_refreshes_oauth_when_also_expiring(self): + # Build LM with both creds expiring — initial mint already fires. + with patch("evolution.core.nous_lm.httpx.Client") as mock_cls: + mock_cls.return_value = _mock_httpx_post( + [ + # Initial _ensure_credentials path: OAuth is stale, refresh first. + _mock_response(json_body={"access_token": "init-refresh", "expires_in": 86400}), + # Then mint with the refreshed access_token. + _mock_response(json_body={"api_key": "init-mint", "expires_in": 1800}), + ] + ) + lm = NousLM( + model="openai/test-model", + access_token="stale", + refresh_token="r", + oauth_expires_at=time.time() + 30, # forces refresh + ) + + # Now manually expire the OAuth again and call _force_remint. + # Expect: refresh POST + mint POST (2 calls), NOT mint→401→refresh→mint (3+). + lm._shared_state.oauth_expires_at = time.time() + 30 # stale again + with patch("evolution.core.nous_lm.httpx.Client") as mock_cls: + mock_cls.return_value = _mock_httpx_post( + [ + _mock_response(json_body={"access_token": "force-refresh", "expires_in": 86400}), + _mock_response(json_body={"api_key": "force-mint", "expires_in": 1800}), + ] + ) + lm._force_remint() + client = mock_cls.return_value + # Exactly 2 calls, in order: refresh THEN mint with the + # fresh access_token. + assert client.post.call_count == 2 + paths = [c.args[0] for c in client.post.call_args_list] + assert paths[0].endswith("/api/oauth/token") + assert paths[1].endswith("/api/oauth/agent-key") + # Mint Bearer is the freshly-refreshed token, not the stale one. + assert ( + client.post.call_args_list[1].kwargs["headers"]["Authorization"] + == "Bearer force-refresh" + ) + + +class TestParseErrorBodyTextFallback: + """When the OAuth/mint response body isn't JSON (CDN HTML page, + portal outage HTML, etc.), the error message should include a + snippet of what the upstream actually sent — not just a generic + 'unknown: status N'. + """ + + def test_html_body_appears_in_error_detail(self): + from evolution.core.nous_lm import _parse_error_body + + mock = MagicMock(spec=httpx.Response) + mock.status_code = 502 + mock.json = MagicMock(side_effect=ValueError("not json")) + mock.text = "Cloudflare 1020 Access Denied" + + code, detail = _parse_error_body(mock) + assert code == "unknown" + assert "Cloudflare 1020" in detail + assert "status 502" in detail + + def test_empty_body_falls_back_to_status_only(self): + from evolution.core.nous_lm import _parse_error_body + + mock = MagicMock(spec=httpx.Response) + mock.status_code = 503 + mock.json = MagicMock(side_effect=ValueError("not json")) + mock.text = "" + + code, detail = _parse_error_body(mock) + assert code == "unknown" + assert detail == "status 503" + + +class TestReuseSubstringMatchesCodeNotDetail: + """Regression guard: the 'reused' check is on the code field, not on + the free-form detail string. A portal returning a server-error body + like 'this is not a reusable connection' must NOT trigger the + refresh_token_reused user-facing message. + """ + + def test_reusable_in_detail_does_not_trigger_reuse_message(self): + from evolution.core.nous_lm import _format_oauth_error + + mock = MagicMock(spec=httpx.Response) + mock.status_code = 500 + mock.json = MagicMock( + return_value={ + "error": { + "code": "internal_error", + "message": "this is not a reusable connection", + } + } + ) + + msg = _format_oauth_error(mock) + assert "another client" not in msg + assert "single-use refresh-token rotation" not in msg + + class TestSharedStateInvariants: def test_post_init_rejects_partial_agent_key_state(self): # _SharedNousState __post_init__ catches the construction-time