diff --git a/services/protocol/openai_v1_chat_complete.py b/services/protocol/openai_v1_chat_complete.py index 96beb7b..c1b207c 100644 --- a/services/protocol/openai_v1_chat_complete.py +++ b/services/protocol/openai_v1_chat_complete.py @@ -115,12 +115,26 @@ def stream_grok_chat_completion(body: dict[str, Any], spec, messages: list[dict[ return completion_id = f"chatcmpl-{uuid.uuid4().hex}" created = int(time.time()) - response = grok.console_chat_completion(body, spec, messages) - if response.reasoning_content: - yield completion_chunk(model, {"role": "assistant", "reasoning_content": response.reasoning_content}, None, completion_id, created) - yield completion_chunk(model, {"content": response.content}, None, completion_id, created) - else: - yield completion_chunk(model, {"role": "assistant", "content": response.content}, None, completion_id, created) + sent_role = False + for event in grok.console_chat_completion_events(body, spec, messages): + delta = grok.extract_console_stream_delta(event) + if not delta.content and not delta.reasoning_content: + continue + if not sent_role: + sent_role = True + first_delta: dict[str, Any] = {"role": "assistant"} + if delta.reasoning_content: + first_delta["reasoning_content"] = delta.reasoning_content + else: + first_delta["content"] = delta.content + yield completion_chunk(model, first_delta, None, completion_id, created) + continue + if delta.reasoning_content: + yield completion_chunk(model, {"reasoning_content": delta.reasoning_content}, None, completion_id, created) + else: + yield completion_chunk(model, {"content": delta.content}, None, completion_id, created) + if not sent_role: + yield completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) yield completion_chunk(model, {}, "stop", completion_id, created) diff --git a/services/providers/grok.py b/services/providers/grok.py index 0848abc..ede4107 100644 --- a/services/providers/grok.py +++ b/services/providers/grok.py @@ -49,6 +49,12 @@ class GrokConsoleCompletion: raw_response: dict[str, Any] | None = None +@dataclass(frozen=True) +class GrokConsoleStreamDelta: + content: str = "" + reasoning_content: str = "" + + _THINKING_SUMMARY_RE = re.compile( r"^\s*(?:\*\*)?\s*(?:思考摘要|思考总结|thinking\s+summary|thought\s+summary|reasoning\s+summary|thinking|reasoning)\s*(?:\*\*\s*[::]|[::]\s*(?:\*\*)?)\s*(.*)$", re.IGNORECASE, @@ -209,6 +215,125 @@ def extract_console_completion(payload: dict[str, Any]) -> GrokConsoleCompletion ) +def _text_field(value: object) -> str: + if isinstance(value, str): + return value + if not isinstance(value, dict): + return "" + for key in ("text", "content", "output_text", "reasoning_content", "summary_text"): + text = value.get(key) + if isinstance(text, str) and text: + return text + return "" + + +def extract_console_stream_delta(event: dict[str, Any]) -> GrokConsoleStreamDelta: + event_type = str(event.get("type") or "").lower() + if event_type and "delta" not in event_type: + return GrokConsoleStreamDelta() + text = _text_field(event.get("delta")) + if not text: + text = _text_field(event) + if not text: + return GrokConsoleStreamDelta() + if "reasoning" in event_type or "thinking" in event_type: + return GrokConsoleStreamDelta(reasoning_content=text) + return GrokConsoleStreamDelta(content=text) + + +def _parse_console_stream_payload(payload: str, current_event: str) -> dict[str, Any] | None: + if not payload: + return None + try: + event = json.loads(payload) + except json.JSONDecodeError: + logger.warning({"event": "grok_console_stream_invalid_json"}) + return None + if not isinstance(event, dict): + return None + if current_event and not event.get("type"): + event = {"type": current_event, **event} + return event + + +def _iter_console_stream_events(lines: Iterable[object]) -> Iterator[dict[str, Any]]: + current_event = "" + data_lines: list[str] = [] + + def flush_data() -> dict[str, Any] | None: + nonlocal current_event + if not data_lines: + current_event = "" + return None + payload = "\n".join(data_lines).strip() + data_lines.clear() + event = _parse_console_stream_payload(payload, current_event) + current_event = "" + return event + + for raw_line in lines: + if raw_line is None: + continue + line = raw_line.decode("utf-8", errors="replace") if isinstance(raw_line, bytes) else str(raw_line) + line = line.rstrip("\r\n") + if not line.strip(): + event = flush_data() + if event is not None: + yield event + continue + if line.startswith(":"): + continue + if line.startswith("event:"): + event = flush_data() + if event is not None: + yield event + current_event = line[6:] + if current_event.startswith(" "): + current_event = current_event[1:] + current_event = current_event.strip() + continue + if line.startswith("data:"): + payload = line[5:] + if payload.startswith(" "): + payload = payload[1:] + if payload.strip() == "[DONE]": + event = flush_data() + if event is not None: + yield event + break + data_lines.append(payload) + continue + line = line.strip() + if line.startswith("{"): + event = flush_data() + if event is not None: + yield event + event = _parse_console_stream_payload(line, current_event) + if event is not None: + yield event + + event = flush_data() + if event is not None: + yield event + + +def _raise_for_console_stream_event(event: dict[str, Any]) -> None: + event_type = str(event.get("type") or "").lower() + if event_type not in {"error", "response.failed", "response.error", "response.incomplete", "response.cancelled"}: + return + error = event.get("error") + response = event.get("response") + if not error and isinstance(response, dict): + error = response.get("error") or response.get("incomplete_details") + if isinstance(error, dict): + message = str(error.get("message") or error.get("code") or error.get("reason") or event_type) + elif error: + message = str(error) + else: + message = event_type + raise GrokConsoleError(f"Grok upstream stream error: {message}", 502) + + def _grok_console_profile(): return build_grok_console_profile(config.data) @@ -260,6 +385,50 @@ def _feedback_status(upstream_status: int) -> str | None: return None +def _console_upstream_error_detail(response: object | None) -> str: + if response is None: + return "" + json_data: object = None + json_method = getattr(response, "json", None) + if callable(json_method): + try: + json_data = json_method() + except Exception: + json_data = None + if isinstance(json_data, dict): + error = json_data.get("error") + if isinstance(error, dict): + for key in ("message", "code", "reason", "type"): + value = error.get(key) + if value: + return str(value) + elif error: + return str(error) + for key in ("message", "detail", "code", "reason"): + value = json_data.get(key) + if value: + return str(value) + text = getattr(response, "text", "") or "" + if not text: + content = getattr(response, "content", b"") + if isinstance(content, bytes): + text = content.decode("utf-8", errors="replace") + return str(text).strip()[:400] + + +def _raise_console_upstream_error(access_token: str, upstream_status: int, response: object | None = None) -> None: + feedback_status = _feedback_status(upstream_status) + if feedback_status: + from services.account_service import account_service + + account_service.update_account(access_token, {"status": feedback_status}) + message = f"Grok upstream error (HTTP {upstream_status})" + detail = _console_upstream_error_detail(response) + if detail: + message = f"{message}: {detail}" + raise GrokConsoleError(message, _openai_status(upstream_status), upstream_status) + + class GrokConsoleClient: def __init__(self, access_token: str) -> None: self.access_token = access_token @@ -310,19 +479,39 @@ def create_response(self, payload: dict[str, Any]) -> dict[str, Any]: except requests.exceptions.RequestException as exc: raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc if response.status_code >= 400: - status = int(response.status_code) - feedback_status = _feedback_status(status) - if feedback_status: - from services.account_service import account_service - - account_service.update_account(self.access_token, {"status": feedback_status}) - message = f"Grok upstream error (HTTP {status})" - raise GrokConsoleError(message, _openai_status(status), status) + _raise_console_upstream_error(self.access_token, int(response.status_code), response) data = response.json() if not isinstance(data, dict): raise GrokConsoleError("Grok upstream returned an invalid response", 502) return data + def stream_response(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: + stream_payload = dict(payload) + stream_payload["stream"] = True + try: + response = self._call_with_retry( + lambda: self.session.post( + CONSOLE_RESPONSES_URL, + headers=_headers(self.access_token), + json=stream_payload, + timeout=self.network_profile.timeout, + stream=True, + ), + context="stream_response", + ) + except requests.exceptions.RequestException as exc: + raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc + if response.status_code >= 400: + _raise_console_upstream_error(self.access_token, int(response.status_code), response) + try: + for event in _iter_console_stream_events(response.iter_lines()): + _raise_for_console_stream_event(event) + yield event + finally: + close = getattr(response, "close", None) + if callable(close): + close() + def _cookie_items(cookie_header: str) -> list[tuple[str, str]]: items: list[tuple[str, str]] = [] @@ -830,5 +1019,26 @@ def console_chat_completion(body: dict[str, Any], spec: ModelSpec, messages: lis return completion +def console_chat_completion_events(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> Iterator[dict[str, Any]]: + from services.account_service import account_service + + access_token = account_service.get_text_access_token(provider=GROK_PROVIDER) + if not access_token: + raise HTTPException(status_code=503, detail={"error": "no available Grok account"}) + payload = build_console_payload(spec, body, messages) + mark_used = False + try: + with GrokConsoleClient(access_token) as client: + for event in client.stream_response(payload): + mark_used = True + yield event + mark_used = True + except GrokConsoleError as exc: + raise HTTPException(status_code=exc.status_code, detail={"error": str(exc)}) from exc + finally: + if mark_used: + account_service.mark_text_used(access_token) + + def chat_completion(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> str: return console_chat_completion(body, spec, messages).content diff --git a/test/test_grok_provider.py b/test/test_grok_provider.py index d3f80fd..b78aba3 100644 --- a/test/test_grok_provider.py +++ b/test/test_grok_provider.py @@ -152,6 +152,24 @@ def test_extract_console_completion_keeps_plain_content_unchanged(self) -> None: self.assertEqual(response.content, "plain answer") self.assertEqual(response.reasoning_content, "") + def test_extract_console_stream_delta_from_output_text_delta(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.output_text.delta", "delta": "hello"}) + + self.assertEqual(delta.content, "hello") + self.assertEqual(delta.reasoning_content, "") + + def test_extract_console_stream_delta_from_reasoning_delta(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.reasoning_summary_text.delta", "delta": "think"}) + + self.assertEqual(delta.content, "") + self.assertEqual(delta.reasoning_content, "think") + + def test_extract_console_stream_delta_ignores_completed_snapshot(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.completed", "output_text": "complete text"}) + + self.assertEqual(delta.content, "") + self.assertEqual(delta.reasoning_content, "") + def test_app_chat_headers_use_grok_app_shape_with_plain_token(self) -> None: with ( mock.patch.object(grok, "_grok_app_chat_profile", return_value=types.SimpleNamespace( @@ -371,16 +389,27 @@ def test_streaming_grok_chat_completion_returns_openai_chunks(self) -> None: "stream": True, "messages": [{"role": "user", "content": "Hello"}], } - with mock.patch.object(grok, "console_chat_completion", return_value=grok.GrokConsoleCompletion(content="Hi there")): + events = [ + {"type": "response.output_text.delta", "delta": "Hi"}, + {"type": "response.output_text.delta", "delta": " there"}, + {"type": "response.completed"}, + ] + with ( + mock.patch.object(grok, "console_chat_completion_events", return_value=iter(events)) as patched_stream, + mock.patch.object(grok, "console_chat_completion") as patched_blocking, + ): chunks = list(openai_v1_chat_complete.handle(body)) - self.assertEqual(len(chunks), 2) + patched_stream.assert_called_once() + patched_blocking.assert_not_called() + self.assertEqual(len(chunks), 3) self.assertEqual(chunks[0]["object"], "chat.completion.chunk") self.assertEqual(chunks[0]["model"], "grok-4.20-multi-agent") - self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "content": "Hi there"}) + self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "content": "Hi"}) self.assertIsNone(chunks[0]["choices"][0]["finish_reason"]) - self.assertEqual(chunks[1]["choices"][0]["delta"], {}) - self.assertEqual(chunks[1]["choices"][0]["finish_reason"], "stop") + self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": " there"}) + self.assertEqual(chunks[2]["choices"][0]["delta"], {}) + self.assertEqual(chunks[2]["choices"][0]["finish_reason"], "stop") def test_console_grok_reasoning_model_uses_console_path(self) -> None: spec = resolve_model("grok-4.20-reasoning") @@ -393,15 +422,20 @@ def test_streaming_grok_console_completion_emits_reasoning_content(self) -> None "stream": True, "messages": [{"role": "user", "content": "Hello"}], } - with mock.patch.object( - grok, - "console_chat_completion", - return_value=grok.GrokConsoleCompletion(content="Hi", reasoning_content="think"), - ) as patched_console, mock.patch.object(grok, "app_chat_completion_events") as patched_app_chat: + events = [ + {"type": "response.reasoning_summary_text.delta", "delta": "think"}, + {"type": "response.output_text.delta", "delta": "Hi"}, + ] + with ( + mock.patch.object(grok, "console_chat_completion_events", return_value=iter(events)) as patched_console, + mock.patch.object(grok, "app_chat_completion_events") as patched_app_chat, + mock.patch.object(grok, "console_chat_completion") as patched_blocking, + ): chunks = list(openai_v1_chat_complete.handle(body)) patched_console.assert_called_once() patched_app_chat.assert_not_called() + patched_blocking.assert_not_called() self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "reasoning_content": "think"}) self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": "Hi"}) self.assertEqual(chunks[2]["choices"][0]["finish_reason"], "stop") @@ -605,6 +639,321 @@ def close(self) -> None: self.assertEqual(created, [{"impersonate": "edge101", "verify": True}]) self.assertEqual(client.network_profile.timeout, 60) + def test_grok_console_stream_response_parses_sse_lines(self) -> None: + calls: list[dict[str, object]] = [] + closed: list[bool] = [] + + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b": keepalive", + b"event: response.output_text.delta", + b'data: {"type":"response.output_text.delta","delta":"Hi"}', + b"data: [DONE]", + b'data: {"type":"response.output_text.delta","delta":" ignored"}', + ]) + + def close(self) -> None: + closed.append(True) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + calls.append({"url": url, **kwargs}) + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events, [{"type": "response.output_text.delta", "delta": "Hi"}]) + self.assertEqual(calls[0]["url"], grok.CONSOLE_RESPONSES_URL) + self.assertTrue(calls[0]["stream"]) + self.assertEqual(calls[0]["json"]["stream"], True) + self.assertEqual(closed, [True]) + + def test_grok_console_stream_response_uses_sse_event_name_when_data_has_no_type(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.reasoning_summary_text.delta", + b'data: {"delta":"think"}', + b"event: response.output_text.delta", + b'data: {"delta":"Hi"}', + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual( + events, + [ + {"type": "response.reasoning_summary_text.delta", "delta": "think"}, + {"type": "response.output_text.delta", "delta": "Hi"}, + ], + ) + self.assertEqual(grok.extract_console_stream_delta(events[0]).reasoning_content, "think") + self.assertEqual(grok.extract_console_stream_delta(events[1]).content, "Hi") + + def test_grok_console_stream_response_aggregates_multiline_sse_data(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.output_text.delta", + b'data: {"delta":', + b'data: "Hi"}', + b"", + b"data:", + b"", + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events, [{"type": "response.output_text.delta", "delta": "Hi"}]) + self.assertEqual(grok.extract_console_stream_delta(events[0]).content, "Hi") + + def test_grok_console_stream_response_resets_sse_event_after_dispatch(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.reasoning_summary_text.delta", + b'data: {"delta":"think"}', + b"", + b'data: {"delta":"plain"}', + b"", + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events[0], {"type": "response.reasoning_summary_text.delta", "delta": "think"}) + self.assertEqual(events[1], {"delta": "plain"}) + self.assertEqual(grok.extract_console_stream_delta(events[0]).reasoning_content, "think") + self.assertEqual(grok.extract_console_stream_delta(events[1]).content, "plain") + + def test_grok_console_stream_marks_account_used_when_generator_is_closed(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + yield {"type": "response.output_text.delta", "delta": "Hi"} + yield {"type": "response.output_text.delta", "delta": " later"} + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + ) + self.assertEqual(next(events), {"type": "response.output_text.delta", "delta": "Hi"}) + events.close() + + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_marks_account_used_when_stream_completes_without_events(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + return iter(()) + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = list(grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + )) + + self.assertEqual(events, []) + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_marks_account_used_after_partial_stream_error(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + yield {"type": "response.output_text.delta", "delta": "Hi"} + raise grok.GrokConsoleError("stream failed", 502) + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + ) + self.assertEqual(next(events), {"type": "response.output_text.delta", "delta": "Hi"}) + with self.assertRaises(grok.HTTPException): + next(events) + + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_response_raises_stream_errors(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b'data: {"type":"response.failed","error":{"message":"upstream failed"}}', + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + with self.assertRaises(grok.GrokConsoleError) as ctx: + list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertIn("upstream failed", str(ctx.exception)) + + def test_grok_console_stream_response_includes_upstream_error_detail(self) -> None: + account_service = types.SimpleNamespace(update_account=mock.Mock()) + + class FakeResponse: + status_code = 402 + + def json(self): + return {"error": {"message": "quota exhausted"}} + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok.config, "data", {}), + mock.patch("curl_cffi.requests.Session", FakeSession), + ): + client = grok.GrokConsoleClient("token-value") + with self.assertRaises(grok.GrokConsoleError) as ctx: + list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertIn("quota exhausted", str(ctx.exception)) + account_service.update_account.assert_called_once_with("token-value", {"status": "限流"}) + def test_grok_console_uses_configured_network_profile(self) -> None: settings = { "network_profiles": {