diff --git a/services/protocol/openai_v1_chat_complete.py b/services/protocol/openai_v1_chat_complete.py index 96beb7b..d6c4c25 100644 --- a/services/protocol/openai_v1_chat_complete.py +++ b/services/protocol/openai_v1_chat_complete.py @@ -65,32 +65,94 @@ def completion_response( } -def stream_text_chat_completion(backend, messages: list[dict[str, Any]], model: str) -> Iterator[dict[str, Any]]: +def stream_include_usage(body: dict[str, Any]) -> bool: + stream_options = body.get("stream_options") + return isinstance(stream_options, dict) and stream_options.get("include_usage") is True + + +def completion_usage( + messages: list[dict[str, Any]], + model: str, + content: str, + reasoning_content: str = "", +) -> dict[str, int]: + prompt_tokens = count_message_tokens(messages, model) if messages else 0 + completion_tokens = count_text_tokens(content, model) if messages else 0 + reasoning_tokens = count_text_tokens(reasoning_content, model) if messages and reasoning_content else 0 + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens + reasoning_tokens, + "total_tokens": prompt_tokens + completion_tokens + reasoning_tokens, + } + + +def completion_usage_chunk(model: str, completion_id: str, created: int, usage: dict[str, int]) -> dict[str, Any]: + return { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [], + "usage": usage, + } + + +def stream_chunk(chunk: dict[str, Any], include_usage: bool) -> dict[str, Any]: + if include_usage: + chunk = dict(chunk) + chunk["usage"] = None + return chunk + + +def stream_text_chat_completion( + backend, + messages: list[dict[str, Any]], + model: str, + include_usage: bool = False, +) -> Iterator[dict[str, Any]]: completion_id = f"chatcmpl-{uuid.uuid4().hex}" created = int(time.time()) sent_role = False + content_parts: list[str] = [] request = ConversationRequest(model=model, messages=messages) for delta_text in stream_text_deltas(backend, request): + content_parts.append(delta_text) if not sent_role: sent_role = True - yield completion_chunk(model, {"role": "assistant", "content": delta_text}, None, completion_id, created) + chunk = completion_chunk(model, {"role": "assistant", "content": delta_text}, None, completion_id, created) else: - yield completion_chunk(model, {"content": delta_text}, None, completion_id, created) + chunk = completion_chunk(model, {"content": delta_text}, None, completion_id, created) + yield stream_chunk(chunk, include_usage) if not sent_role: - yield completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) - yield completion_chunk(model, {}, "stop", completion_id, created) + chunk = completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) + yield stream_chunk(chunk, include_usage) + yield stream_chunk(completion_chunk(model, {}, "stop", completion_id, created), include_usage) + if include_usage: + yield completion_usage_chunk( + model, + completion_id, + created, + completion_usage(messages, model, "".join(content_parts)), + ) def stream_grok_app_chat_completion(body: dict[str, Any], spec, messages: list[dict[str, Any]], model: str) -> Iterator[dict[str, Any]]: + include_usage = stream_include_usage(body) completion_id = f"chatcmpl-{uuid.uuid4().hex}" created = int(time.time()) sent_role = False + content_parts: list[str] = [] + reasoning_parts: list[str] = [] for event in grok.app_chat_completion_events(body, spec, messages): token, thinking = grok.extract_app_chat_token(event) if not token: if grok.is_app_chat_final_event(event): break continue + if thinking: + reasoning_parts.append(token) + else: + content_parts.append(token) if not sent_role: sent_role = True delta: dict[str, Any] = {"role": "assistant"} @@ -98,30 +160,55 @@ def stream_grok_app_chat_completion(body: dict[str, Any], spec, messages: list[d delta["reasoning_content"] = token else: delta["content"] = token - yield completion_chunk(model, delta, None, completion_id, created) + yield stream_chunk(completion_chunk(model, delta, None, completion_id, created), include_usage) continue if thinking: - yield completion_chunk(model, {"reasoning_content": token}, None, completion_id, created) + chunk = completion_chunk(model, {"reasoning_content": token}, None, completion_id, created) else: - yield completion_chunk(model, {"content": token}, None, completion_id, created) + chunk = completion_chunk(model, {"content": token}, None, completion_id, created) + yield stream_chunk(chunk, include_usage) if not sent_role: - yield completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) - yield completion_chunk(model, {}, "stop", completion_id, created) + chunk = completion_chunk(model, {"role": "assistant", "content": ""}, None, completion_id, created) + yield stream_chunk(chunk, include_usage) + yield stream_chunk(completion_chunk(model, {}, "stop", completion_id, created), include_usage) + if include_usage: + yield completion_usage_chunk( + model, + completion_id, + created, + completion_usage(messages, model, "".join(content_parts), "".join(reasoning_parts)), + ) def stream_grok_chat_completion(body: dict[str, Any], spec, messages: list[dict[str, Any]], model: str) -> Iterator[dict[str, Any]]: if is_grok_app_chat_model(spec): yield from stream_grok_app_chat_completion(body, spec, messages, model) return + include_usage = stream_include_usage(body) completion_id = f"chatcmpl-{uuid.uuid4().hex}" created = int(time.time()) response = grok.console_chat_completion(body, spec, messages) if response.reasoning_content: - yield completion_chunk(model, {"role": "assistant", "reasoning_content": response.reasoning_content}, None, completion_id, created) - yield completion_chunk(model, {"content": response.content}, None, completion_id, created) + chunk = completion_chunk( + model, + {"role": "assistant", "reasoning_content": response.reasoning_content}, + None, + completion_id, + created, + ) + yield stream_chunk(chunk, include_usage) + yield stream_chunk(completion_chunk(model, {"content": response.content}, None, completion_id, created), include_usage) else: - yield completion_chunk(model, {"role": "assistant", "content": response.content}, None, completion_id, created) - yield completion_chunk(model, {}, "stop", completion_id, created) + chunk = completion_chunk(model, {"role": "assistant", "content": response.content}, None, completion_id, created) + yield stream_chunk(chunk, include_usage) + yield stream_chunk(completion_chunk(model, {}, "stop", completion_id, created), include_usage) + if include_usage: + yield completion_usage_chunk( + model, + completion_id, + created, + completion_usage(messages, model, response.content, response.reasoning_content), + ) def collect_chat_content(chunks: Iterable[dict[str, Any]]) -> str: @@ -243,7 +330,7 @@ def handle(body: dict[str, Any]) -> dict[str, Any] | Iterator[dict[str, Any]]: spec = resolve_model(model) if spec.provider == GROK_PROVIDER: return stream_grok_chat_completion(body, spec, messages, model) - return stream_text_chat_completion(text_backend(), messages, model) + return stream_text_chat_completion(text_backend(), messages, model, stream_include_usage(body)) if is_image_chat_request(body): return image_chat_response(body) model, messages = text_chat_parts(body) diff --git a/test/test_grok_provider.py b/test/test_grok_provider.py index d3f80fd..26a8bdf 100644 --- a/test/test_grok_provider.py +++ b/test/test_grok_provider.py @@ -406,6 +406,29 @@ def test_streaming_grok_console_completion_emits_reasoning_content(self) -> None self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": "Hi"}) self.assertEqual(chunks[2]["choices"][0]["finish_reason"], "stop") + def test_streaming_grok_console_completion_includes_usage_when_requested(self) -> None: + body = { + "model": "grok-4.20-reasoning", + "stream": True, + "stream_options": {"include_usage": True}, + "messages": [{"role": "user", "content": "Hello"}], + } + with ( + mock.patch.object( + grok, + "console_chat_completion", + return_value=grok.GrokConsoleCompletion(content="Hi", reasoning_content="think"), + ), + mock.patch.object(openai_v1_chat_complete, "count_message_tokens", return_value=11), + mock.patch.object(openai_v1_chat_complete, "count_text_tokens", side_effect=[2, 3]), + ): + chunks = list(openai_v1_chat_complete.handle(body)) + + self.assertIsNone(chunks[0]["usage"]) + self.assertEqual(chunks[-2]["choices"][0]["finish_reason"], "stop") + self.assertEqual(chunks[-1]["choices"], []) + self.assertEqual(chunks[-1]["usage"], {"prompt_tokens": 11, "completion_tokens": 5, "total_tokens": 16}) + def test_streaming_grok_app_chat_completion_emits_reasoning_content(self) -> None: body = { "model": "grok-4.20-heavy", @@ -423,6 +446,51 @@ def test_streaming_grok_app_chat_completion_emits_reasoning_content(self) -> Non self.assertEqual(chunks[1]["choices"][0]["delta"], {"content": "Hi"}) self.assertEqual(chunks[-1]["choices"][0]["finish_reason"], "stop") + def test_streaming_grok_app_chat_completion_includes_usage_when_requested(self) -> None: + body = { + "model": "grok-4.20-heavy", + "stream": True, + "stream_options": {"include_usage": True}, + "messages": [{"role": "user", "content": "Hello"}], + } + events = [ + {"result": {"response": {"token": "think", "isThinking": True}}}, + {"result": {"response": {"token": "Hi"}}}, + {"result": {"response": {"token": " there", "messageTag": "final"}}}, + ] + with ( + mock.patch.object(grok, "app_chat_completion_events", return_value=iter(events)), + mock.patch.object(openai_v1_chat_complete, "count_message_tokens", return_value=7), + mock.patch.object(openai_v1_chat_complete, "count_text_tokens", side_effect=[4, 2]), + ): + chunks = list(openai_v1_chat_complete.handle(body)) + + self.assertIsNone(chunks[0]["usage"]) + self.assertEqual(chunks[-2]["choices"][0]["finish_reason"], "stop") + self.assertEqual(chunks[-1]["choices"], []) + self.assertEqual(chunks[-1]["usage"], {"prompt_tokens": 7, "completion_tokens": 6, "total_tokens": 13}) + + def test_streaming_text_completion_includes_usage_when_requested(self) -> None: + body = { + "model": "auto", + "stream": True, + "stream_options": {"include_usage": True}, + "messages": [{"role": "user", "content": "Hello"}], + } + with ( + mock.patch.object(openai_v1_chat_complete, "ConversationRequest", return_value=object()), + mock.patch.object(openai_v1_chat_complete, "stream_text_deltas", return_value=iter(["Hi", " there"])), + mock.patch.object(openai_v1_chat_complete, "count_message_tokens", return_value=5), + mock.patch.object(openai_v1_chat_complete, "count_text_tokens", return_value=8), + ): + chunks = list(openai_v1_chat_complete.handle(body)) + + self.assertEqual(chunks[0]["choices"][0]["delta"], {"role": "assistant", "content": "Hi"}) + self.assertIsNone(chunks[0]["usage"]) + self.assertEqual(chunks[-2]["choices"][0]["finish_reason"], "stop") + self.assertEqual(chunks[-1]["choices"], []) + self.assertEqual(chunks[-1]["usage"], {"prompt_tokens": 5, "completion_tokens": 8, "total_tokens": 13}) + def test_non_streaming_grok_app_chat_completion_includes_reasoning_content(self) -> None: body = { "model": "grok-4.20-heavy",