Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions src/kimi_cli/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,26 @@ def delete_tokens(ref: OAuthRef) -> None:


async def request_device_authorization() -> DeviceAuthorization:
oauth_host = _oauth_host().rstrip("/")
logger.info("Requesting device authorization from {host}", host=oauth_host)
async with (
new_client_session() as session,
session.post(
f"{_oauth_host().rstrip('/')}/api/oauth/device_authorization",
f"{oauth_host}/api/oauth/device_authorization",
data={"client_id": KIMI_CODE_CLIENT_ID},
headers=_common_headers(),
) as response,
):
data = await response.json(content_type=None)
status = response.status
if status != 200:
logger.error("Device authorization failed (HTTP {status}): {data}", status=status, data=data)
raise OAuthError(f"Device authorization failed: {data}")
logger.info(
"Device authorization obtained: user_code={user_code}, interval={interval}s",
user_code=data.get("user_code"),
interval=data.get("interval", 5),
)
return DeviceAuthorization(
user_code=str(data["user_code"]),
device_code=str(data["device_code"]),
Expand All @@ -475,6 +483,7 @@ async def request_device_authorization() -> DeviceAuthorization:


async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[str, Any]]:
logger.debug("Polling for device token (user_code={user_code})", user_code=auth.user_code)
try:
async with (
new_client_session() as session,
Expand All @@ -491,6 +500,7 @@ async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[st
data_any: Any = await response.json(content_type=None)
status = response.status
except aiohttp.ClientError as exc:
logger.warning("Token polling request failed: {error}", error=exc)
raise OAuthError("Token polling request failed.") from exc
if not isinstance(data_any, dict):
raise OAuthError("Unexpected token polling response.")
Expand All @@ -501,6 +511,7 @@ async def _request_device_token(auth: DeviceAuthorization) -> tuple[int, dict[st


async def refresh_token(refresh_token: str, *, max_retries: int = 3) -> OAuthToken:
logger.debug("Refreshing OAuth access token")
last_exc: Exception | None = None
for attempt in range(max_retries):
try:
Expand Down Expand Up @@ -531,7 +542,9 @@ async def refresh_token(refresh_token: str, *, max_retries: int = 3) -> OAuthTok
if status in _RETRYABLE_REFRESH_STATUSES:
raise _RetryableRefreshError(desc)
raise OAuthError(desc)
return OAuthToken.from_response(data)
token = OAuthToken.from_response(data)
logger.info("Token refreshed successfully (expires_in={expires_in}s)", expires_in=token.expires_in)
return token
except OAuthUnauthorized:
raise
except (aiohttp.ClientError, TimeoutError, OSError, _RetryableRefreshError) as exc:
Expand Down Expand Up @@ -622,15 +635,21 @@ async def login_kimi_code(
yield OAuthEvent("error", "Kimi Code platform is unavailable.")
return

logger.info("Starting Kimi Code device OAuth login flow")
auth: DeviceAuthorization
token: OAuthToken | None = None
while True:
try:
auth = await request_device_authorization()
except Exception as exc:
logger.error("Device authorization request failed: {error}", error=exc)
yield OAuthEvent("error", f"Login failed: {exc}")
return

logger.info(
"Device code obtained, waiting for user to authorize at {url}",
url=auth.verification_uri_complete,
)
yield OAuthEvent(
"info",
"Please visit the following URL to finish authorization.",
Expand All @@ -656,6 +675,7 @@ async def login_kimi_code(
status, data = await _request_device_token(auth)
if status == 200 and "access_token" in data:
token = OAuthToken.from_response(data)
logger.info("Device token obtained successfully (scope={scope})", scope=token.scope)
break
error_code = str(data.get("error") or "unknown_error")
if error_code == "expired_token":
Expand All @@ -682,33 +702,62 @@ async def login_kimi_code(

assert token is not None

oauth_ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY)
oauth_ref = save_tokens(oauth_ref, token)

# Validate everything we can in memory BEFORE persisting credentials.
# If we wrote the token to disk first and then list_models / save_config
# failed, the next launch would see valid credentials but no default_model
# (banner stuck on "Model: not set") with no way to recover except a
# blind /login retry.
try:
models = await list_models(platform, token.access_token)
except Exception as exc:
logger.error("Failed to get models: {error}", error=exc)
yield OAuthEvent("error", f"Failed to get models: {exc}")
return

logger.info("Fetched {count} models for platform", count=len(models))

if not models:
yield OAuthEvent("error", "No models available for the selected platform.")
return

selection = _select_default_model_and_thinking(models)
if selection is None:
yield OAuthEvent("error", "Failed to select a default model from the returned list.")
return
selected_model, thinking = selection

_apply_kimi_code_config(
config,
models=models,
selected_model=selected_model,
# All validation passed — now persist credentials + config.
oauth_ref = OAuthRef(storage="file", key=KIMI_CODE_OAUTH_KEY)
try:
oauth_ref = save_tokens(oauth_ref, token)
except OSError as exc:
logger.error("Failed to save credentials: {error}", error=exc)
yield OAuthEvent("error", f"Failed to save credentials: {exc}")
return

try:
_apply_kimi_code_config(
config,
models=models,
selected_model=selected_model,
thinking=thinking,
oauth_ref=oauth_ref,
)
save_config(config)
except Exception as exc:
# Roll back the credentials write so we don't leave the user in the
# zombie state (token on disk, default_model missing).
logger.error("Failed to save config; rolling back credentials: {error}", error=exc)
with suppress(Exception):
delete_tokens(oauth_ref)
yield OAuthEvent("error", f"Failed to save config: {exc}")
return

logger.info(
"Login successful: default_model={model}, thinking={thinking}",
model=config.default_model,
thinking=thinking,
oauth_ref=oauth_ref,
)
save_config(config)
yield OAuthEvent("success", "Logged in successfully.")
return

Expand Down
213 changes: 213 additions & 0 deletions tests/auth/test_login_kimi_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Tests for the login_kimi_code flow ordering.

Regression guard for: token was previously persisted to disk before list_models
was called, leaving credentials valid but `default_model` unset when list_models
failed (banner stuck on "Model: not set").
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import SecretStr

from kimi_cli.auth.oauth import DeviceAuthorization, login_kimi_code
from kimi_cli.config import Config, Services


def _empty_config() -> Config:
cfg = Config(
default_model="",
providers={},
models={},
services=Services(),
)
cfg.is_from_default_location = True
return cfg


def _device_auth() -> DeviceAuthorization:
return DeviceAuthorization(
user_code="ABCD-1234",
device_code="dev-code",
verification_uri="https://kimi.test/verify",
verification_uri_complete="https://kimi.test/verify?user_code=ABCD-1234",
expires_in=600,
interval=1,
)


def _token_payload() -> dict[str, object]:
return {
"access_token": "acc",
"refresh_token": "ref",
"expires_in": 900,
"scope": "kimi-code",
"token_type": "Bearer",
}


@pytest.mark.asyncio
async def test_list_models_failure_does_not_persist_token():
"""If list_models raises, save_tokens MUST NOT be called.

Otherwise we leave a valid token on disk while config.default_model stays
empty → banner permanently shows "Model: not set, send /login to login".
"""
config = _empty_config()
save_tokens_mock = MagicMock()
save_config_mock = MagicMock()
delete_tokens_mock = MagicMock()
apply_config_mock = MagicMock()
platform = MagicMock(id="kimi-code", base_url="https://api.test")

with (
patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())),
patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))),
patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform),
patch("kimi_cli.auth.oauth.list_models", AsyncMock(side_effect=RuntimeError("boom"))),
patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock),
patch("kimi_cli.auth.oauth.save_config", save_config_mock),
patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock),
patch("kimi_cli.auth.oauth._apply_kimi_code_config", apply_config_mock),
patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()),
):
events = [e async for e in login_kimi_code(config, open_browser=False)]

error_events = [e for e in events if e.type == "error"]
assert error_events, f"expected error event, got: {[e.type for e in events]}"
assert "boom" in error_events[-1].message or "models" in error_events[-1].message.lower()
save_tokens_mock.assert_not_called()
save_config_mock.assert_not_called()
apply_config_mock.assert_not_called()
assert config.default_model == ""


@pytest.mark.asyncio
async def test_empty_model_list_does_not_persist_token():
"""An empty model list must emit an error event AND not persist the token."""
config = _empty_config()
save_tokens_mock = MagicMock()
save_config_mock = MagicMock()
platform = MagicMock(id="kimi-code", base_url="https://api.test")

with (
patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())),
patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))),
patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform),
patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[])),
patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock),
patch("kimi_cli.auth.oauth.save_config", save_config_mock),
patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()),
):
events = [e async for e in login_kimi_code(config, open_browser=False)]

assert any(e.type == "error" for e in events)
save_tokens_mock.assert_not_called()
save_config_mock.assert_not_called()


@pytest.mark.asyncio
async def test_save_config_failure_rolls_back_credentials():
"""If save_config raises after save_tokens succeeded, the credentials must
be deleted to avoid the zombie state (token on disk, no default_model)."""
config = _empty_config()
save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref)
save_config_mock = MagicMock(side_effect=OSError("disk full"))
delete_tokens_mock = MagicMock()
platform = MagicMock(id="kimi-code", base_url="https://api.test")
model_info = MagicMock(
id="kimi-k2",
context_length=200_000,
capabilities=set(),
display_name="Kimi K2",
)

with (
patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())),
patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))),
patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform),
patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])),
patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock),
patch("kimi_cli.auth.oauth.save_config", save_config_mock),
patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock),
patch("kimi_cli.auth.oauth._apply_kimi_code_config", MagicMock()),
patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()),
):
events = [e async for e in login_kimi_code(config, open_browser=False)]

save_tokens_mock.assert_called_once()
delete_tokens_mock.assert_called_once() # rollback
assert any(e.type == "error" for e in events)
assert not any(e.type == "success" for e in events)


@pytest.mark.asyncio
async def test_apply_config_failure_rolls_back_credentials():
"""If in-memory config application fails after save_tokens, rollback credentials."""
config = _empty_config()
save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref)
save_config_mock = MagicMock()
delete_tokens_mock = MagicMock()
platform = MagicMock(id="kimi-code", base_url="https://api.test")
model_info = MagicMock(
id="kimi-k2",
context_length=200_000,
capabilities=set(),
display_name="Kimi K2",
)

with (
patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())),
patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))),
patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform),
patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])),
patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock),
patch("kimi_cli.auth.oauth.save_config", save_config_mock),
patch("kimi_cli.auth.oauth.delete_tokens", delete_tokens_mock),
patch("kimi_cli.auth.oauth._apply_kimi_code_config", MagicMock(side_effect=ValueError("bad config"))),
patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()),
):
events = [e async for e in login_kimi_code(config, open_browser=False)]

save_tokens_mock.assert_called_once()
delete_tokens_mock.assert_called_once()
save_config_mock.assert_not_called()
assert any(e.type == "error" for e in events)
assert not any(e.type == "success" for e in events)


@pytest.mark.asyncio
async def test_happy_path_persists_token_and_config():
config = _empty_config()
save_tokens_mock = MagicMock(side_effect=lambda ref, _token: ref)
save_config_mock = MagicMock()
apply_config_mock = MagicMock(
side_effect=lambda config, **_kw: setattr(config, "default_model", "managed:kimi-code/kimi-k2")
)
platform = MagicMock(id="kimi-code", base_url="https://api.test")
model_info = MagicMock(
id="kimi-k2",
context_length=200_000,
capabilities=set(),
display_name="Kimi K2",
)

with (
patch("kimi_cli.auth.oauth.request_device_authorization", AsyncMock(return_value=_device_auth())),
patch("kimi_cli.auth.oauth._request_device_token", AsyncMock(return_value=(200, _token_payload()))),
patch("kimi_cli.auth.oauth.get_platform_by_id", return_value=platform),
patch("kimi_cli.auth.oauth.list_models", AsyncMock(return_value=[model_info])),
patch("kimi_cli.auth.oauth.save_tokens", save_tokens_mock),
patch("kimi_cli.auth.oauth.save_config", save_config_mock),
patch("kimi_cli.auth.oauth._apply_kimi_code_config", apply_config_mock),
patch("kimi_cli.auth.oauth.webbrowser.open", MagicMock()),
):
events = [e async for e in login_kimi_code(config, open_browser=False)]

save_tokens_mock.assert_called_once()
apply_config_mock.assert_called_once()
save_config_mock.assert_called_once()
assert any(e.type == "success" for e in events)
assert config.default_model == "managed:kimi-code/kimi-k2"