diff --git a/src/kimi_cli/auth/oauth.py b/src/kimi_cli/auth/oauth.py index 5000b532c..f5b7e5c5f 100644 --- a/src/kimi_cli/auth/oauth.py +++ b/src/kimi_cli/auth/oauth.py @@ -452,10 +452,12 @@ def delete_tokens(ref: OAuthRef) -> None: async def request_device_authorization() -> DeviceAuthorization: + oauth_host = _oauth_host().rstrip("/") + logger.info("Requesting device authorization from {host}", host=oauth_host) async with ( new_client_session() as session, session.post( - f"{_oauth_host().rstrip('/')}/api/oauth/device_authorization", + f"{oauth_host}/api/oauth/device_authorization", data={"client_id": KIMI_CODE_CLIENT_ID}, headers=_common_headers(), ) as response, @@ -463,7 +465,13 @@ async def request_device_authorization() -> DeviceAuthorization: data = await response.json(content_type=None) status = response.status if status != 200: + logger.error("Device authorization failed (HTTP {status}): {data}", status=status, data=data) raise OAuthError(f"Device authorization failed: {data}") + logger.info( + "Device authorization obtained: user_code={user_code}, interval={interval}s", + user_code=data.get("user_code"), + interval=data.get("interval", 5), + ) return DeviceAuthorization( user_code=str(data["user_code"]), device_code=str(data["device_code"]), @@ -475,6 +483,7 @@ async def request_device_authorization() -> DeviceAuthorization: async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[str, Any]]: + logger.debug("Polling for device token (user_code={user_code})", user_code=auth.user_code) try: async with ( new_client_session() as session, @@ -491,6 +500,7 @@ async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[st data_any: Any = await response.json(content_type=None) status = response.status except aiohttp.ClientError as exc: + logger.warning("Token polling request failed: {error}", error=exc) raise OAuthError("Token polling request failed.") from exc if not isinstance(data_any, dict): raise OAuthError("Unexpected token polling response.") @@ -501,6 +511,7 @@ async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[st async def refresh_token(refresh_token: str, *, max_retries: int = 3) -> OAuthToken: + logger.debug("Refreshing OAuth access token") last_exc: Exception | None = None for attempt in range(max_retries): try: @@ -531,7 +542,9 @@ async def refresh_token(refresh_token: str, *, max_retries: int = 3) -> OAuthTok if status in _RETRYABLE_REFRESH_STATUSES: raise _RetryableRefreshError(desc) raise OAuthError(desc) - return OAuthToken.from_response(data) + token = OAuthToken.from_response(data) + logger.info("Token refreshed successfully (expires_in={expires_in}s)", expires_in=token.expires_in) + return token except OAuthUnauthorized: raise except (aiohttp.ClientError, TimeoutError, OSError, _RetryableRefreshError) as exc: @@ -622,15 +635,21 @@ async def login_kimi_code( yield OAuthEvent("error", "Kimi Code platform is unavailable.") return + logger.info("Starting Kimi Code device OAuth login flow") auth: DeviceAuthorization token: OAuthToken | None = None while True: try: auth = await request_device_authorization() except Exception as exc: + logger.error("Device authorization request failed: {error}", error=exc) yield OAuthEvent("error", f"Login failed: {exc}") return + logger.info( + "Device code obtained, waiting for user to authorize at {url}", + url=auth.verification_uri_complete, + ) yield OAuthEvent( "info", "Please visit the following URL to finish authorization.", @@ -656,6 +675,7 @@ async def login_kimi_code( status, data = await _request_device_token(auth) if status == 200 and "access_token" in data: token = OAuthToken.from_response(data) + logger.info("Device token obtained successfully (scope={scope})", scope=token.scope) break error_code = str(data.get("error") or "unknown_error") if error_code == "expired_token": @@ -682,9 +702,11 @@ async def login_kimi_code( assert token is not None - oauth_ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY) - oauth_ref = save_tokens(oauth_ref, token) - + # Validate everything we can in memory BEFORE persisting credentials. + # If we wrote the token to disk first and then list_models / save_config + # failed, the next launch would see valid credentials but no default_model + # (banner stuck on "Model: not set") with no way to recover except a + # blind /login retry. try: models = await list_models(platform, token.access_token) except Exception as exc: @@ -692,23 +714,50 @@ async def login_kimi_code( yield OAuthEvent("error", f"Failed to get models: {exc}") return + logger.info("Fetched {count} models for platform", count=len(models)) + if not models: yield OAuthEvent("error", "No models available for the selected platform.") return selection = _select_default_model_and_thinking(models) if selection is None: + yield OAuthEvent("error", "Failed to select a default model from the returned list.") return selected_model, thinking = selection - _apply_kimi_code_config( - config, - models=models, - selected_model=selected_model, + # All validation passed — now persist credentials + config. + oauth_ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY) + try: + oauth_ref = save_tokens(oauth_ref, token) + except OSError as exc: + logger.error("Failed to save credentials: {error}", error=exc) + yield OAuthEvent("error", f"Failed to save credentials: {exc}") + return + + try: + _apply_kimi_code_config( + config, + models=models, + selected_model=selected_model, + thinking=thinking, + oauth_ref=oauth_ref, + ) + save_config(config) + except Exception as exc: + # Roll back the credentials write so we don't leave the user in the + # zombie state (token on disk, default_model missing). + logger.error("Failed to save config; rolling back credentials: {error}", error=exc) + with suppress(Exception): + delete_tokens(oauth_ref) + yield OAuthEvent("error", f"Failed to save config: {exc}") + return + + logger.info( + "Login successful: default_model={model}, thinking={thinking}", + model=config.default_model, thinking=thinking, - oauth_ref=oauth_ref, ) - save_config(config) yield OAuthEvent("success", "Logged in successfully.") return diff --git a/tests/auth/test_login_kimi_code.py b/tests/auth/test_login_kimi_code.py new file mode 100644 index 000000000..02ca02d90 --- /dev/null +++ b/tests/auth/test_login_kimi_code.py @@ -0,0 +1,213 @@ +"""Tests for the login_kimi_code flow ordering. + +Regression guard for: token was previously persisted to disk before list_models +was called, leaving credentials valid but `default_model` unset when list_models +failed (banner stuck on "Model: not set"). +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from kimi_cli.auth.oauth import DeviceAuthorization, login_kimi_code +from kimi_cli.config import Config, Services + + +def _empty_config() -> Config: + cfg = Config( + default_model="", + providers={}, + models={}, + services=Services(), + ) + cfg.is_from_default_location = True + return cfg + + +def _device_auth() -> DeviceAuthorization: + return DeviceAuthorization( + user_code="ABCD-1234", + device_code="dev-code", + verification_uri="https://kimi.test/verify", + verification_uri_complete="https://kimi.test/verify?user_code=ABCD-1234", + expires_in=600, + interval=1, + ) + + +def _token_payload() -> dict[str, object]: + return { + "access_token": "acc", + "refresh_token": "ref", + "expires_in": 900, + "scope": "kimi-code", + "token_type": "Bearer", + } + + +@pytest.mark.asyncio +async def test_list_models_failure_does_not_persist_token(): + """If list_models raises, save_tokens MUST NOT be called. + + Otherwise we leave a valid token on disk while config.default_model stays + empty → banner permanently shows "Model: not set, send /login to login". + """ + config = _empty_config() + save_tokens_mock = MagicMock() + save_config_mock = MagicMock() + delete_tokens_mock = MagicMock() + apply_config_mock = MagicMock() + platform = MagicMock(id="kimi-code", base_url="https://api.test") + + with ( + patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())), + patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))), + patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform), + patch("kimi_cli.auth.oauth.list_models", AsyncMock(side_effect=RuntimeError("boom"))), + patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock), + patch("kimi_cli.auth.oauth.save_config", save_config_mock), + patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock), + patch("kimi_cli.auth.oauth._apply_kimi_code_config", apply_config_mock), + patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()), + ): + events = [e async for e in login_kimi_code(config, open_browser=False)] + + error_events = [e for e in events if e.type == "error"] + assert error_events, f"expected error event, got: {[e.type for e in events]}" + assert "boom" in error_events[-1].message or "models" in error_events[-1].message.lower() + save_tokens_mock.assert_not_called() + save_config_mock.assert_not_called() + apply_config_mock.assert_not_called() + assert config.default_model == "" + + +@pytest.mark.asyncio +async def test_empty_model_list_does_not_persist_token(): + """An empty model list must emit an error event AND not persist the token.""" + config = _empty_config() + save_tokens_mock = MagicMock() + save_config_mock = MagicMock() + platform = MagicMock(id="kimi-code", base_url="https://api.test") + + with ( + patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())), + patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))), + patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform), + patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[])), + patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock), + patch("kimi_cli.auth.oauth.save_config", save_config_mock), + patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()), + ): + events = [e async for e in login_kimi_code(config, open_browser=False)] + + assert any(e.type == "error" for e in events) + save_tokens_mock.assert_not_called() + save_config_mock.assert_not_called() + + +@pytest.mark.asyncio +async def test_save_config_failure_rolls_back_credentials(): + """If save_config raises after save_tokens succeeded, the credentials must + be deleted to avoid the zombie state (token on disk, no default_model).""" + config = _empty_config() + save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref) + save_config_mock = MagicMock(side_effect=OSError("disk full")) + delete_tokens_mock = MagicMock() + platform = MagicMock(id="kimi-code", base_url="https://api.test") + model_info = MagicMock( + id="kimi-k2", + context_length=200_000, + capabilities=set(), + display_name="Kimi K2", + ) + + with ( + patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())), + patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))), + patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform), + patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])), + patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock), + patch("kimi_cli.auth.oauth.save_config", save_config_mock), + patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock), + patch("kimi_cli.auth.oauth._apply_kimi_code_config", MagicMock()), + patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()), + ): + events = [e async for e in login_kimi_code(config, open_browser=False)] + + save_tokens_mock.assert_called_once() + delete_tokens_mock.assert_called_once() # rollback + assert any(e.type == "error" for e in events) + assert not any(e.type == "success" for e in events) + + +@pytest.mark.asyncio +async def test_apply_config_failure_rolls_back_credentials(): + """If in-memory config application fails after save_tokens, rollback credentials.""" + config = _empty_config() + save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref) + save_config_mock = MagicMock() + delete_tokens_mock = MagicMock() + platform = MagicMock(id="kimi-code", base_url="https://api.test") + model_info = MagicMock( + id="kimi-k2", + context_length=200_000, + capabilities=set(), + display_name="Kimi K2", + ) + + with ( + patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())), + patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))), + patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform), + patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])), + patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock), + patch("kimi_cli.auth.oauth.save_config", save_config_mock), + patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock), + patch("kimi_cli.auth.oauth._apply_kimi_code_config", MagicMock(side_effect=ValueError("bad config"))), + patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()), + ): + events = [e async for e in login_kimi_code(config, open_browser=False)] + + save_tokens_mock.assert_called_once() + delete_tokens_mock.assert_called_once() + save_config_mock.assert_not_called() + assert any(e.type == "error" for e in events) + assert not any(e.type == "success" for e in events) + + +@pytest.mark.asyncio +async def test_happy_path_persists_token_and_config(): + config = _empty_config() + save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref) + save_config_mock = MagicMock() + apply_config_mock = MagicMock( + side_effect=lambda config, **_kw: setattr(config, "default_model", "managed:kimi-code/kimi-k2") + ) + platform = MagicMock(id="kimi-code", base_url="https://api.test") + model_info = MagicMock( + id="kimi-k2", + context_length=200_000, + capabilities=set(), + display_name="Kimi K2", + ) + + with ( + patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())), + patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))), + patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform), + patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])), + patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock), + patch("kimi_cli.auth.oauth.save_config", save_config_mock), + patch("kimi_cli.auth.oauth._apply_kimi_code_config", apply_config_mock), + patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()), + ): + events = [e async for e in login_kimi_code(config, open_browser=False)] + + save_tokens_mock.assert_called_once() + apply_config_mock.assert_called_once() + save_config_mock.assert_called_once() + assert any(e.type == "success" for e in events) + assert config.default_model == "managed:kimi-code/kimi-k2"