diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 9fd2336..ade7595 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -2,9 +2,16 @@ name: Publish Docker Image on: push: + branches: + - main tags: - "v*" workflow_dispatch: + inputs: + platforms: + description: "Comma-separated Docker platforms to build" + required: false + default: "linux/amd64" env: IMAGE_NAME: webchat2api @@ -22,7 +29,44 @@ jobs: - name: Checkout uses: actions/checkout@v4 + - name: Prepare Docker build settings + id: build_settings + shell: bash + env: + REQUESTED_PLATFORMS: ${{ github.event.inputs.platforms || '' }} + run: | + if [[ "${GITHUB_EVENT_NAME}" == "workflow_dispatch" && -n "${REQUESTED_PLATFORMS}" ]]; then + platforms="${REQUESTED_PLATFORMS}" + elif [[ "${GITHUB_REF_TYPE}" == "tag" ]]; then + platforms="linux/amd64,linux/arm64" + else + platforms="linux/amd64" + fi + + platforms="${platforms//[[:space:]]/}" + if [[ -z "${platforms}" ]]; then + platforms="linux/amd64" + fi + + IFS=',' read -ra requested_platforms <<< "${platforms}" + for platform in "${requested_platforms[@]}"; do + if [[ "${platform}" != "linux/amd64" && "${platform}" != "linux/arm64" ]]; then + echo "Unsupported Docker platform: ${platform}. Supported platforms: linux/amd64,linux/arm64" >&2 + exit 1 + fi + done + + if [[ "${platforms}" == *,* ]]; then + cache_mode="max" + else + cache_mode="min" + fi + + echo "platforms=${platforms}" >> "$GITHUB_OUTPUT" + echo "cache_mode=${cache_mode}" >> "$GITHUB_OUTPUT" + - name: Set up QEMU + if: ${{ contains(steps.build_settings.outputs.platforms, 'arm') }} uses: docker/setup-qemu-action@v3 with: platforms: arm64 @@ -64,9 +108,9 @@ jobs: context: . file: ./Dockerfile target: app - platforms: linux/amd64,linux/arm64 + platforms: ${{ steps.build_settings.outputs.platforms }} push: true tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} cache-from: type=gha - cache-to: type=gha,mode=max + cache-to: type=gha,mode=${{ steps.build_settings.outputs.cache_mode }} diff --git a/.github/workflows/pr-target-branch.yml b/.github/workflows/pr-target-branch.yml index 902c78d..2d79ab6 100644 --- a/.github/workflows/pr-target-branch.yml +++ b/.github/workflows/pr-target-branch.yml @@ -113,9 +113,14 @@ jobs: name: wrongTargetBranchLabel, }); } catch (error) { - if (error.status !== 404) { - throw error; + if (error.status === 404) { + return; + } + if (error.status === 403) { + core.info(`Could not remove ${wrongTargetBranchLabel} label: ${error.message}`); + return; } + throw error; } }; diff --git a/README.md b/README.md index c3b282f..eb7b905 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,7 @@ npm run dev Grok Console 与 grok.com app-chat 是不同上游路径。本项目没有接入官方 xAI API,也不声称提供官方兼容能力。Console 路径可使用 `network_profiles.grok_console.cf_clearance` 附加手动 Cookie;app-chat 路径可使用 `network_profiles.grok_app_chat` 覆盖 UA、impersonate、`cf_clearance`、`cf_cookies`、`sec-ch-ua`、`x-statsig-id` 等字段。 -如配置 `flaresolverr_url`,直接 app-chat 请求遇到 Cloudflare 或 403 时会尝试通过 FlareSolverr 刷新 clearance 并重试。Browser Bridge 是独立浏览器路径,后端会优先使用 `browser_bridge_url`,未配置时会探测 `http://127.0.0.1:3080/health`。Browser Bridge 的接口是 `POST /api/chat {sso,payload}` 和 `GET /health`,请求会经真实 Chromium 页面发往 grok.com。 +如配置 `flaresolverr_url`,直接 app-chat 请求遇到 Cloudflare 或 403 时会尝试通过 FlareSolverr 刷新 clearance 并重试。Browser Bridge 是独立浏览器路径;显式配置 `browser_bridge_url` 时,后端会优先使用该 Bridge。未配置时,app-chat 默认先走直接请求;直接请求遇到 `403`、`408`、`502`、`503`、`504` 时,才会尝试探测并回退到 `http://127.0.0.1:3080/health` 对应的 Browser Bridge。Browser Bridge 的接口是 `POST /api/chat {sso,payload}` 和 `GET /health`,请求会经真实 Chromium 页面发往 grok.com。 > [!WARNING] > Cloudflare、WAF、账号风控和上游配额都可能变化。手动 clearance、FlareSolverr 和 Browser Bridge 都是尽力而为,不能保证长期可用。 diff --git a/services/protocol/openai_v1_chat_complete.py b/services/protocol/openai_v1_chat_complete.py index 96beb7b..c1b207c 100644 --- a/services/protocol/openai_v1_chat_complete.py +++ b/services/protocol/openai_v1_chat_complete.py @@ -115,12 +115,26 @@ def stream_grok_chat_completion(body: dict[str, Any], spec, messages: list[dict[ return completion_id = f"chatcmpl-{uuid.uuid4().hex}" created = int(time.time()) - response = grok.console_chat_completion(body, spec, messages) - if response.reasoning_content: - yield completion_chunk(model, {"role": "assistant", "reasoning_content": response.reasoning_content}, None, completion_id, created) - yield completion_chunk(model, {"content": response.content}, None, completion_id, created) - else: - yield completion_chunk(model, {"role": "assistant", "content": response.content}, None, completion_id, created) + sent_role = False + for event in grok.console_chat_completion_events(body, spec, messages): + delta = grok.extract_console_stream_delta(event) + if not delta.content and not delta.reasoning_content: + continue + if not sent_role: + sent_role = True + first_delta: dict[str, Any] = {"role": "assistant"} + if delta.reasoning_content: + first_delta["reasoning_content"] = delta.reasoning_content + else: + first_delta["content"] = delta.content + yield completion_chunk(model, first_delta, None, completion_id, created) + continue + if delta.reasoning_content: + yield completion_chunk(model, {"reasoning_content": delta.reasoning_content}, None, completion_id, created) + else: + yield completion_chunk(model, {"content": delta.content}, None, completion_id, created) + if not sent_role: + yield completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) yield completion_chunk(model, {}, "stop", completion_id, created) diff --git a/services/protocol/openai_v1_response.py b/services/protocol/openai_v1_response.py index 32cbf0a..533f66e 100644 --- a/services/protocol/openai_v1_response.py +++ b/services/protocol/openai_v1_response.py @@ -7,6 +7,8 @@ from fastapi import HTTPException +from services.models import GROK_PROVIDER, is_grok_app_chat_model, resolve_model +from services.providers import grok from services.protocol.conversation import ( ConversationRequest, ImageOutput, @@ -150,6 +152,29 @@ def stream_text_response(backend, body: dict[str, Any]) -> Iterator[dict[str, An yield response_completed(response_id, model, created, [item]) +def stream_grok_console_response(body: dict[str, Any]) -> Iterator[dict[str, Any]]: + model = str(body.get("model") or "auto").strip() or "auto" + spec = resolve_model(model) + if is_grok_app_chat_model(spec): + raise HTTPException(status_code=501, detail={"error": "Grok app-chat is not supported on /v1/responses"}) + messages = messages_from_input(body.get("input"), body.get("instructions")) + if not messages: + raise HTTPException(status_code=400, detail={"error": "input text is required"}) + response_id = f"resp_{uuid.uuid4().hex}" + item_id = f"msg_{uuid.uuid4().hex}" + created = int(time.time()) + yield response_created(response_id, model, created) + yield {"type": "response.output_item.added", "output_index": 0, "item": text_output_item("", item_id, "in_progress")} + completion = grok.console_chat_completion(body, spec, messages) + text = completion.content + if text: + yield {"type": "response.output_text.delta", "item_id": item_id, "output_index": 0, "content_index": 0, "delta": text} + yield {"type": "response.output_text.done", "item_id": item_id, "output_index": 0, "content_index": 0, "text": text} + item = text_output_item(text, item_id, "completed") + yield {"type": "response.output_item.done", "output_index": 0, "item": item} + yield response_completed(response_id, model, created, [item]) + + def stream_image_response(image_outputs: Iterable[ImageOutput], prompt: str, model: str) -> Iterator[dict[str, Any]]: response_id = f"resp_{uuid.uuid4().hex}" created = int(time.time()) @@ -186,6 +211,11 @@ def collect_response(events: Iterable[dict[str, Any]]) -> dict[str, Any]: def response_events(body: dict[str, Any]) -> Iterator[dict[str, Any]]: if is_text_response_request(body): + model = str(body.get("model") or "auto").strip() or "auto" + spec = resolve_model(model) + if spec.provider == GROK_PROVIDER: + yield from stream_grok_console_response(body) + return yield from stream_text_response(text_backend(), body) return diff --git a/services/providers/grok.py b/services/providers/grok.py index 0848abc..9515a4b 100644 --- a/services/providers/grok.py +++ b/services/providers/grok.py @@ -49,6 +49,12 @@ class GrokConsoleCompletion: raw_response: dict[str, Any] | None = None +@dataclass(frozen=True) +class GrokConsoleStreamDelta: + content: str = "" + reasoning_content: str = "" + + _THINKING_SUMMARY_RE = re.compile( r"^\s*(?:\*\*)?\s*(?:思考摘要|思考总结|thinking\s+summary|thought\s+summary|reasoning\s+summary|thinking|reasoning)\s*(?:\*\*\s*[::]|[::]\s*(?:\*\*)?)\s*(.*)$", re.IGNORECASE, @@ -57,6 +63,7 @@ class GrokConsoleCompletion: r"^\s*(?:\*\*)?\s*(?:答案|回答|answer|final\s+answer|response)\s*(?:\*\*\s*[::]|[::]\s*(?:\*\*)?)\s*(.*)$", re.IGNORECASE, ) +_CONSOLE_SEARCH_TOOL_TYPES = {"web_search", "x_search"} def split_visible_console_reasoning(text: str) -> tuple[str, str]: @@ -136,6 +143,19 @@ def build_console_input(messages: list[dict[str, Any]]) -> tuple[str, list[dict[ return "\n\n".join(instructions).strip(), input_items +def _console_search_tools(tools: object) -> list[dict[str, Any]]: + search_tools: list[dict[str, Any]] = [] + if isinstance(tools, list): + for tool in tools: + if not isinstance(tool, dict): + continue + if str(tool.get("type") or "") in _CONSOLE_SEARCH_TOOL_TYPES: + search_tools.append(dict(tool)) + if not any(tool.get("type") == "web_search" for tool in search_tools): + search_tools.append({"type": "web_search"}) + return search_tools + + def build_console_payload(spec: ModelSpec, body: dict[str, Any], messages: list[dict[str, Any]]) -> dict[str, Any]: instructions, input_items = build_console_input(messages) if not input_items: @@ -143,6 +163,7 @@ def build_console_payload(spec: ModelSpec, body: dict[str, Any], messages: list[ payload: dict[str, Any] = { "model": spec.upstream_model or spec.id, "input": input_items, + "tools": _console_search_tools(body.get("tools")), } request_instructions = str(body.get("instructions") or "").strip() merged_instructions = "\n\n".join(item for item in [request_instructions, instructions] if item) @@ -152,6 +173,9 @@ def build_console_payload(spec: ModelSpec, body: dict[str, Any], messages: list[ if body.get(key) is not None: target_key = "max_output_tokens" if key == "max_tokens" else key payload[target_key] = body[key] + for key in ("tool_choice", "parallel_tool_calls"): + if body.get(key) is not None: + payload[key] = body[key] reasoning_effort = str(body.get("reasoning_effort") or spec.default_reasoning_effort or "").strip().lower() if reasoning_effort and reasoning_effort != "none": payload["reasoning"] = {"effort": "high" if reasoning_effort == "xhigh" else reasoning_effort} @@ -209,6 +233,125 @@ def extract_console_completion(payload: dict[str, Any]) -> GrokConsoleCompletion ) +def _text_field(value: object) -> str: + if isinstance(value, str): + return value + if not isinstance(value, dict): + return "" + for key in ("text", "content", "output_text", "reasoning_content", "summary_text"): + text = value.get(key) + if isinstance(text, str) and text: + return text + return "" + + +def extract_console_stream_delta(event: dict[str, Any]) -> GrokConsoleStreamDelta: + event_type = str(event.get("type") or "").lower() + if event_type and "delta" not in event_type: + return GrokConsoleStreamDelta() + text = _text_field(event.get("delta")) + if not text: + text = _text_field(event) + if not text: + return GrokConsoleStreamDelta() + if "reasoning" in event_type or "thinking" in event_type: + return GrokConsoleStreamDelta(reasoning_content=text) + return GrokConsoleStreamDelta(content=text) + + +def _parse_console_stream_payload(payload: str, current_event: str) -> dict[str, Any] | None: + if not payload: + return None + try: + event = json.loads(payload) + except json.JSONDecodeError: + logger.warning({"event": "grok_console_stream_invalid_json"}) + return None + if not isinstance(event, dict): + return None + if current_event and not event.get("type"): + event = {"type": current_event, **event} + return event + + +def _iter_console_stream_events(lines: Iterable[object]) -> Iterator[dict[str, Any]]: + current_event = "" + data_lines: list[str] = [] + + def flush_data() -> dict[str, Any] | None: + nonlocal current_event + if not data_lines: + current_event = "" + return None + payload = "\n".join(data_lines).strip() + data_lines.clear() + event = _parse_console_stream_payload(payload, current_event) + current_event = "" + return event + + for raw_line in lines: + if raw_line is None: + continue + line = raw_line.decode("utf-8", errors="replace") if isinstance(raw_line, bytes) else str(raw_line) + line = line.rstrip("\r\n") + if not line.strip(): + event = flush_data() + if event is not None: + yield event + continue + if line.startswith(":"): + continue + if line.startswith("event:"): + event = flush_data() + if event is not None: + yield event + current_event = line[6:] + if current_event.startswith(" "): + current_event = current_event[1:] + current_event = current_event.strip() + continue + if line.startswith("data:"): + payload = line[5:] + if payload.startswith(" "): + payload = payload[1:] + if payload.strip() == "[DONE]": + event = flush_data() + if event is not None: + yield event + break + data_lines.append(payload) + continue + line = line.strip() + if line.startswith("{"): + event = flush_data() + if event is not None: + yield event + event = _parse_console_stream_payload(line, current_event) + if event is not None: + yield event + + event = flush_data() + if event is not None: + yield event + + +def _raise_for_console_stream_event(event: dict[str, Any]) -> None: + event_type = str(event.get("type") or "").lower() + if event_type not in {"error", "response.failed", "response.error", "response.incomplete", "response.cancelled"}: + return + error = event.get("error") + response = event.get("response") + if not error and isinstance(response, dict): + error = response.get("error") or response.get("incomplete_details") + if isinstance(error, dict): + message = str(error.get("message") or error.get("code") or error.get("reason") or event_type) + elif error: + message = str(error) + else: + message = event_type + raise GrokConsoleError(f"Grok upstream stream error: {message}", 502) + + def _grok_console_profile(): return build_grok_console_profile(config.data) @@ -260,6 +403,50 @@ def _feedback_status(upstream_status: int) -> str | None: return None +def _console_upstream_error_detail(response: object | None) -> str: + if response is None: + return "" + json_data: object = None + json_method = getattr(response, "json", None) + if callable(json_method): + try: + json_data = json_method() + except Exception: + json_data = None + if isinstance(json_data, dict): + error = json_data.get("error") + if isinstance(error, dict): + for key in ("message", "code", "reason", "type"): + value = error.get(key) + if value: + return str(value) + elif error: + return str(error) + for key in ("message", "detail", "code", "reason"): + value = json_data.get(key) + if value: + return str(value) + text = getattr(response, "text", "") or "" + if not text: + content = getattr(response, "content", b"") + if isinstance(content, bytes): + text = content.decode("utf-8", errors="replace") + return str(text).strip()[:400] + + +def _raise_console_upstream_error(access_token: str, upstream_status: int, response: object | None = None) -> None: + feedback_status = _feedback_status(upstream_status) + if feedback_status: + from services.account_service import account_service + + account_service.update_account(access_token, {"status": feedback_status}) + message = f"Grok upstream error (HTTP {upstream_status})" + detail = _console_upstream_error_detail(response) + if detail: + message = f"{message}: {detail}" + raise GrokConsoleError(message, _openai_status(upstream_status), upstream_status) + + class GrokConsoleClient: def __init__(self, access_token: str) -> None: self.access_token = access_token @@ -310,19 +497,39 @@ def create_response(self, payload: dict[str, Any]) -> dict[str, Any]: except requests.exceptions.RequestException as exc: raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc if response.status_code >= 400: - status = int(response.status_code) - feedback_status = _feedback_status(status) - if feedback_status: - from services.account_service import account_service - - account_service.update_account(self.access_token, {"status": feedback_status}) - message = f"Grok upstream error (HTTP {status})" - raise GrokConsoleError(message, _openai_status(status), status) + _raise_console_upstream_error(self.access_token, int(response.status_code), response) data = response.json() if not isinstance(data, dict): raise GrokConsoleError("Grok upstream returned an invalid response", 502) return data + def stream_response(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: + stream_payload = dict(payload) + stream_payload["stream"] = True + try: + response = self._call_with_retry( + lambda: self.session.post( + CONSOLE_RESPONSES_URL, + headers=_headers(self.access_token), + json=stream_payload, + timeout=self.network_profile.timeout, + stream=True, + ), + context="stream_response", + ) + except requests.exceptions.RequestException as exc: + raise GrokConsoleError(f"Grok upstream request failed: {exc}", 502) from exc + if response.status_code >= 400: + _raise_console_upstream_error(self.access_token, int(response.status_code), response) + try: + for event in _iter_console_stream_events(response.iter_lines()): + _raise_for_console_stream_event(event) + yield event + finally: + close = getattr(response, "close", None) + if callable(close): + close() + def _cookie_items(cookie_header: str) -> list[tuple[str, str]]: items: list[tuple[str, str]] = [] @@ -706,12 +913,7 @@ def _try_browser_bridge(self, payload: dict[str, Any]) -> list[str] | None: logger.warning({"event": "browser_bridge_unavailable"}) return None - def stream_events(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: - bridge_lines = self._try_browser_bridge(payload) - if bridge_lines is not None: - yield from app_chat_line_events(bridge_lines) - return - response = None + def _stream_direct_events(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: try: response = self._call_with_retry( lambda: self.session.post( @@ -743,6 +945,26 @@ def stream_events(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: raise classify_app_chat_upstream_error(int(response.status_code), self.access_token) yield from app_chat_line_events(response.iter_lines()) + def stream_events(self, payload: dict[str, Any]) -> Iterator[dict[str, Any]]: + bridge_first = bool(config.browser_bridge_url) + if bridge_first: + bridge_lines = self._try_browser_bridge(payload) + if bridge_lines is not None: + yield from app_chat_line_events(bridge_lines) + return + try: + yield from self._stream_direct_events(payload) + return + except GrokConsoleError as exc: + status = exc.upstream_status or exc.status_code + if bridge_first or status not in {403, 408, 502, 503, 504}: + raise + bridge_lines = self._try_browser_bridge(payload) + if bridge_lines is None: + raise + logger.info({"event": "grok_app_chat_direct_fallback_to_bridge", "status": status}) + yield from app_chat_line_events(bridge_lines) + def app_chat_completion_events(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> Iterator[dict[str, Any]]: from services.account_service import account_service @@ -830,5 +1052,26 @@ def console_chat_completion(body: dict[str, Any], spec: ModelSpec, messages: lis return completion +def console_chat_completion_events(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> Iterator[dict[str, Any]]: + from services.account_service import account_service + + access_token = account_service.get_text_access_token(provider=GROK_PROVIDER) + if not access_token: + raise HTTPException(status_code=503, detail={"error": "no available Grok account"}) + payload = build_console_payload(spec, body, messages) + mark_used = False + try: + with GrokConsoleClient(access_token) as client: + for event in client.stream_response(payload): + mark_used = True + yield event + mark_used = True + except GrokConsoleError as exc: + raise HTTPException(status_code=exc.status_code, detail={"error": str(exc)}) from exc + finally: + if mark_used: + account_service.mark_text_used(access_token) + + def chat_completion(body: dict[str, Any], spec: ModelSpec, messages: list[dict[str, Any]]) -> str: return console_chat_completion(body, spec, messages).content diff --git a/test/test_grok_provider.py b/test/test_grok_provider.py index d3f80fd..52faf87 100644 --- a/test/test_grok_provider.py +++ b/test/test_grok_provider.py @@ -79,9 +79,11 @@ def to_openai_error(self) -> dict[str, object]: conversation.text_backend = lambda: object() sys.modules["services.protocol.conversation"] = conversation +from fastapi import HTTPException + from services.models import resolve_model from services.network import flaresolverr -from services.protocol import openai_v1_chat_complete +from services.protocol import openai_v1_chat_complete, openai_v1_response from services.providers import grok @@ -106,6 +108,44 @@ def test_build_console_payload_converts_chat_messages(self) -> None: self.assertEqual(payload["input"][0]["content"], [{"type": "input_text", "text": "Hello"}]) self.assertEqual(payload["input"][1]["content"], [{"type": "output_text", "text": "Hi"}]) + def test_build_console_payload_defaults_web_search_tool(self) -> None: + spec = resolve_model("grok-4.3") + payload = grok.build_console_payload( + spec, + {}, + [{"role": "user", "content": "Search the web."}], + ) + + self.assertEqual(payload["tools"], [{"type": "web_search"}]) + + def test_build_console_payload_preserves_supported_search_tools(self) -> None: + spec = resolve_model("grok-4.3") + web_search = {"type": "web_search", "allowed_websites": ["example.com"]} + x_search = {"type": "x_search", "post_favorite_count": 10} + payload = grok.build_console_payload( + spec, + {"tools": [web_search, {"type": "image_generation"}, x_search]}, + [{"role": "user", "content": "Search the web."}], + ) + + self.assertEqual(payload["tools"], [web_search, x_search]) + + def test_build_console_payload_preserves_response_tool_controls(self) -> None: + spec = resolve_model("grok-4.3") + payload = grok.build_console_payload( + spec, + { + "tools": [{"type": "web_search"}], + "tool_choice": "auto", + "parallel_tool_calls": True, + }, + [{"role": "user", "content": "Search the web"}], + ) + + self.assertEqual(payload["tools"], [{"type": "web_search"}]) + self.assertEqual(payload["tool_choice"], "auto") + self.assertTrue(payload["parallel_tool_calls"]) + def test_extract_console_text_from_common_shapes(self) -> None: self.assertEqual(grok.extract_console_text({"output_text": "direct"}), "direct") self.assertEqual( @@ -152,6 +192,24 @@ def test_extract_console_completion_keeps_plain_content_unchanged(self) -> None: self.assertEqual(response.content, "plain answer") self.assertEqual(response.reasoning_content, "") + def test_extract_console_stream_delta_from_output_text_delta(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.output_text.delta", "delta": "hello"}) + + self.assertEqual(delta.content, "hello") + self.assertEqual(delta.reasoning_content, "") + + def test_extract_console_stream_delta_from_reasoning_delta(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.reasoning_summary_text.delta", "delta": "think"}) + + self.assertEqual(delta.content, "") + self.assertEqual(delta.reasoning_content, "think") + + def test_extract_console_stream_delta_ignores_completed_snapshot(self) -> None: + delta = grok.extract_console_stream_delta({"type": "response.completed", "output_text": "complete text"}) + + self.assertEqual(delta.content, "") + self.assertEqual(delta.reasoning_content, "") + def test_app_chat_headers_use_grok_app_shape_with_plain_token(self) -> None: with ( mock.patch.object(grok, "_grok_app_chat_profile", return_value=types.SimpleNamespace( @@ -371,16 +429,27 @@ def test_streaming_grok_chat_completion_returns_openai_chunks(self) -> None: "stream": True, "messages": [{"role": "user", "content": "Hello"}], } - with mock.patch.object(grok, "console_chat_completion", return_value=grok.GrokConsoleCompletion(content="Hi there")): + events = [ + {"type": "response.output_text.delta", "delta": "Hi"}, + {"type": "response.output_text.delta", "delta": " there"}, + {"type": "response.completed"}, + ] + with ( + mock.patch.object(grok, "console_chat_completion_events", return_value=iter(events)) as patched_stream, + mock.patch.object(grok, "console_chat_completion") as patched_blocking, + ): chunks = list(openai_v1_chat_complete.handle(body)) - self.assertEqual(len(chunks), 2) + patched_stream.assert_called_once() + patched_blocking.assert_not_called() + self.assertEqual(len(chunks), 3) self.assertEqual(chunks[0]["object"], "chat.completion.chunk") self.assertEqual(chunks[0]["model"], "grok-4.20-multi-agent") - self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "content": "Hi there"}) + self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "content": "Hi"}) self.assertIsNone(chunks[0]["choices"][0]["finish_reason"]) - self.assertEqual(chunks[1]["choices"][0]["delta"], {}) - self.assertEqual(chunks[1]["choices"][0]["finish_reason"], "stop") + self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": " there"}) + self.assertEqual(chunks[2]["choices"][0]["delta"], {}) + self.assertEqual(chunks[2]["choices"][0]["finish_reason"], "stop") def test_console_grok_reasoning_model_uses_console_path(self) -> None: spec = resolve_model("grok-4.20-reasoning") @@ -393,15 +462,20 @@ def test_streaming_grok_console_completion_emits_reasoning_content(self) -> None "stream": True, "messages": [{"role": "user", "content": "Hello"}], } - with mock.patch.object( - grok, - "console_chat_completion", - return_value=grok.GrokConsoleCompletion(content="Hi", reasoning_content="think"), - ) as patched_console, mock.patch.object(grok, "app_chat_completion_events") as patched_app_chat: + events = [ + {"type": "response.reasoning_summary_text.delta", "delta": "think"}, + {"type": "response.output_text.delta", "delta": "Hi"}, + ] + with ( + mock.patch.object(grok, "console_chat_completion_events", return_value=iter(events)) as patched_console, + mock.patch.object(grok, "app_chat_completion_events") as patched_app_chat, + mock.patch.object(grok, "console_chat_completion") as patched_blocking, + ): chunks = list(openai_v1_chat_complete.handle(body)) patched_console.assert_called_once() patched_app_chat.assert_not_called() + patched_blocking.assert_not_called() self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "reasoning_content": "think"}) self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": "Hi"}) self.assertEqual(chunks[2]["choices"][0]["finish_reason"], "stop") @@ -453,6 +527,73 @@ def test_non_streaming_grok_console_completion_includes_reasoning_content(self) self.assertEqual(message["content"], "Hi") self.assertEqual(message["reasoning_content"], "think") + def test_responses_grok_console_routes_to_console_completion(self) -> None: + body = { + "model": "grok-4.3", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + "tools": [{"type": "web_search"}], + } + with mock.patch.object( + grok, + "console_chat_completion", + return_value=grok.GrokConsoleCompletion(content="Hi from Grok"), + ) as patched_console: + response = openai_v1_response.handle(body) + + patched_console.assert_called_once() + self.assertEqual(patched_console.call_args.args[0]["tools"], [{"type": "web_search"}]) + self.assertEqual(response["object"], "response") + self.assertEqual(response["status"], "completed") + content = response["output"][0]["content"][0] + self.assertEqual(content["type"], "output_text") + self.assertEqual(content["text"], "Hi from Grok") + + def test_streaming_responses_grok_console_emits_response_events(self) -> None: + body = { + "model": "grok-4.3", + "input": "Hello", + "stream": True, + } + with mock.patch.object( + grok, + "console_chat_completion", + return_value=grok.GrokConsoleCompletion(content="Hi"), + ) as patched_console: + events = list(openai_v1_response.handle(body)) + + patched_console.assert_called_once() + event_types = [event.get("type") for event in events] + self.assertEqual(event_types[0], "response.created") + self.assertIn("response.output_text.delta", event_types) + self.assertEqual(event_types[-1], "response.completed") + + def test_responses_unknown_non_grok_model_uses_text_backend(self) -> None: + body = { + "model": "custom-text-model", + "input": "Hello", + } + with ( + mock.patch.object(openai_v1_response, "ConversationRequest", lambda **kwargs: kwargs), + mock.patch.object(openai_v1_response, "stream_text_deltas", return_value=iter(["generic"])) as patched_stream, + mock.patch.object(grok, "console_chat_completion") as patched_console, + ): + response = openai_v1_response.handle(body) + + patched_stream.assert_called_once() + patched_console.assert_not_called() + self.assertEqual(response["output"][0]["content"][0]["text"], "generic") + + def test_responses_grok_app_chat_returns_explicit_error(self) -> None: + body = { + "model": "grok-4.20-heavy", + "input": "Hello", + } + with self.assertRaises(HTTPException) as ctx: + list(openai_v1_response.handle(body)) + + self.assertEqual(getattr(ctx.exception, "status_code", None), 501) + self.assertIn("Grok app-chat is not supported", str(getattr(ctx.exception, "detail", ""))) + def test_grok_image_lite_chat_routes_to_app_chat_image_outputs(self) -> None: body = { "model": "grok-imagine-image-lite", @@ -605,6 +746,321 @@ def close(self) -> None: self.assertEqual(created, [{"impersonate": "edge101", "verify": True}]) self.assertEqual(client.network_profile.timeout, 60) + def test_grok_console_stream_response_parses_sse_lines(self) -> None: + calls: list[dict[str, object]] = [] + closed: list[bool] = [] + + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b": keepalive", + b"event: response.output_text.delta", + b'data: {"type":"response.output_text.delta","delta":"Hi"}', + b"data: [DONE]", + b'data: {"type":"response.output_text.delta","delta":" ignored"}', + ]) + + def close(self) -> None: + closed.append(True) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + calls.append({"url": url, **kwargs}) + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events, [{"type": "response.output_text.delta", "delta": "Hi"}]) + self.assertEqual(calls[0]["url"], grok.CONSOLE_RESPONSES_URL) + self.assertTrue(calls[0]["stream"]) + self.assertEqual(calls[0]["json"]["stream"], True) + self.assertEqual(closed, [True]) + + def test_grok_console_stream_response_uses_sse_event_name_when_data_has_no_type(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.reasoning_summary_text.delta", + b'data: {"delta":"think"}', + b"event: response.output_text.delta", + b'data: {"delta":"Hi"}', + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual( + events, + [ + {"type": "response.reasoning_summary_text.delta", "delta": "think"}, + {"type": "response.output_text.delta", "delta": "Hi"}, + ], + ) + self.assertEqual(grok.extract_console_stream_delta(events[0]).reasoning_content, "think") + self.assertEqual(grok.extract_console_stream_delta(events[1]).content, "Hi") + + def test_grok_console_stream_response_aggregates_multiline_sse_data(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.output_text.delta", + b'data: {"delta":', + b'data: "Hi"}', + b"", + b"data:", + b"", + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events, [{"type": "response.output_text.delta", "delta": "Hi"}]) + self.assertEqual(grok.extract_console_stream_delta(events[0]).content, "Hi") + + def test_grok_console_stream_response_resets_sse_event_after_dispatch(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b"event: response.reasoning_summary_text.delta", + b'data: {"delta":"think"}', + b"", + b'data: {"delta":"plain"}', + b"", + b"data: [DONE]", + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + events = list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertEqual(events[0], {"type": "response.reasoning_summary_text.delta", "delta": "think"}) + self.assertEqual(events[1], {"delta": "plain"}) + self.assertEqual(grok.extract_console_stream_delta(events[0]).reasoning_content, "think") + self.assertEqual(grok.extract_console_stream_delta(events[1]).content, "plain") + + def test_grok_console_stream_marks_account_used_when_generator_is_closed(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + yield {"type": "response.output_text.delta", "delta": "Hi"} + yield {"type": "response.output_text.delta", "delta": " later"} + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + ) + self.assertEqual(next(events), {"type": "response.output_text.delta", "delta": "Hi"}) + events.close() + + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_marks_account_used_when_stream_completes_without_events(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + return iter(()) + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = list(grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + )) + + self.assertEqual(events, []) + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_marks_account_used_after_partial_stream_error(self) -> None: + account_service = types.SimpleNamespace( + get_text_access_token=mock.Mock(return_value="selected-token"), + mark_text_used=mock.Mock(), + ) + spec = resolve_model("grok-4.3") + + class FakeClient: + def __init__(self, access_token: str) -> None: + self.access_token = access_token + + def __enter__(self) -> "FakeClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + pass + + def stream_response(self, payload): + yield {"type": "response.output_text.delta", "delta": "Hi"} + raise grok.GrokConsoleError("stream failed", 502) + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok, "GrokConsoleClient", FakeClient), + ): + events = grok.console_chat_completion_events( + {"model": "grok-4.3"}, + spec, + [{"role": "user", "content": "Hello"}], + ) + self.assertEqual(next(events), {"type": "response.output_text.delta", "delta": "Hi"}) + with self.assertRaises(grok.HTTPException): + next(events) + + account_service.mark_text_used.assert_called_once_with("selected-token") + + def test_grok_console_stream_response_raises_stream_errors(self) -> None: + class FakeResponse: + status_code = 200 + + def iter_lines(self): + return iter([ + b'data: {"type":"response.failed","error":{"message":"upstream failed"}}', + ]) + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with mock.patch.object(grok.config, "data", {}), mock.patch("curl_cffi.requests.Session", FakeSession): + client = grok.GrokConsoleClient("token-value") + with self.assertRaises(grok.GrokConsoleError) as ctx: + list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertIn("upstream failed", str(ctx.exception)) + + def test_grok_console_stream_response_includes_upstream_error_detail(self) -> None: + account_service = types.SimpleNamespace(update_account=mock.Mock()) + + class FakeResponse: + status_code = 402 + + def json(self): + return {"error": {"message": "quota exhausted"}} + + class FakeSession: + headers: dict[str, str] = {} + + def __init__(self, **kwargs: object) -> None: + pass + + def post(self, url: str, **kwargs: object) -> FakeResponse: + return FakeResponse() + + def close(self) -> None: + pass + + with ( + mock.patch.dict(sys.modules, {"services.account_service": types.SimpleNamespace(account_service=account_service)}), + mock.patch.object(grok.config, "data", {}), + mock.patch("curl_cffi.requests.Session", FakeSession), + ): + client = grok.GrokConsoleClient("token-value") + with self.assertRaises(grok.GrokConsoleError) as ctx: + list(client.stream_response({"model": "grok-4.3", "input": []})) + + self.assertIn("quota exhausted", str(ctx.exception)) + account_service.update_account.assert_called_once_with("token-value", {"status": "限流"}) + def test_grok_console_uses_configured_network_profile(self) -> None: settings = { "network_profiles": { @@ -927,6 +1383,48 @@ def test_try_browser_bridge_403_reports_tier_hint(self, _): client._try_browser_bridge({"message": "test"}) self.assertIn("account may lack required tier", str(ctx.exception)) + @mock.patch("services.providers.grok.config") + def test_app_chat_prefers_direct_when_bridge_is_auto_detected(self, mock_config): + from services.providers.grok import GrokAppChatClient + mock_config.browser_bridge_url = "" + client = GrokAppChatClient.__new__(GrokAppChatClient) + direct_event = {"result": {"response": {"token": "hi"}}} + with ( + mock.patch.object(client, "_stream_direct_events", return_value=iter([direct_event])) as direct, + mock.patch.object(client, "_try_browser_bridge") as bridge, + ): + self.assertEqual(list(client.stream_events({"message": "test"})), [direct_event]) + direct.assert_called_once() + bridge.assert_not_called() + + @mock.patch("services.providers.grok.config") + def test_app_chat_uses_explicit_bridge_before_direct(self, mock_config): + from services.providers.grok import GrokAppChatClient + mock_config.browser_bridge_url = "http://bridge.local" + client = GrokAppChatClient.__new__(GrokAppChatClient) + with ( + mock.patch.object(client, "_try_browser_bridge", return_value=['{"result":{"response":{"token":"hi"}}}']) as bridge, + mock.patch.object(client, "_stream_direct_events") as direct, + ): + events = list(client.stream_events({"message": "test"})) + self.assertEqual(events, [{"result": {"response": {"token": "hi"}}}]) + bridge.assert_called_once() + direct.assert_not_called() + + @mock.patch("services.providers.grok.config") + def test_app_chat_falls_back_to_bridge_after_direct_403(self, mock_config): + from services.providers.grok import GrokAppChatClient, GrokConsoleError + mock_config.browser_bridge_url = "" + client = GrokAppChatClient.__new__(GrokAppChatClient) + with ( + mock.patch.object(client, "_stream_direct_events", side_effect=GrokConsoleError("forbidden", 403, 403)) as direct, + mock.patch.object(client, "_try_browser_bridge", return_value=['{"result":{"response":{"token":"hi"}}}']) as bridge, + ): + events = list(client.stream_events({"message": "test"})) + self.assertEqual(events, [{"result": {"response": {"token": "hi"}}}]) + direct.assert_called_once() + bridge.assert_called_once() + @mock.patch("services.providers.grok.config") def test_detect_bridge_url_auto_probes(self, mock_config): import services.providers.grok as grok_mod