diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 38daa618..8ef399e9 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -27,6 +27,7 @@ overload, ) +import orjson from pydantic import BaseModel from ._callbacks import CallbackManager @@ -2760,6 +2761,8 @@ def _submit_turns( if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()): raise ValueError("Cannot use async tools in a synchronous chat") + stream = self._resolve_stream(stream, data_model) + def emit(text: str | Content): self._echo_content(str(text)) @@ -2791,7 +2794,9 @@ def emit(text: str | Content): content = self.provider.stream_content(chunk) if content is not None: text = self.provider.stream_text(chunk) - yield from acc.process_content(content, text, content_mode, emit) + yield from acc.process_content( + content, text, content_mode, emit + ) result = self.provider.stream_merge_chunks(result, chunk) yield from acc.flush_thinking(content_mode, emit) @@ -2821,9 +2826,10 @@ def emit(text: str | Content): turn = self.provider.value_turn( response, has_data_model=data_model is not None ) - if turn.text: - emit(turn.text) - yield turn.text + text = nonstream_yield_text(turn, data_model) + if text: + emit(text) + yield text if echo == "all": emit_other_contents(turn, emit) @@ -2868,6 +2874,8 @@ async def _submit_turns_async( *, controller: StreamController, ) -> AsyncGenerator[str | Content, None]: + stream = self._resolve_stream(stream, data_model) + def emit(text: str | Content): self._echo_content(str(text)) @@ -2899,7 +2907,9 @@ def emit(text: str | Content): content = self.provider.stream_content(chunk) if content is not None: text = self.provider.stream_text(chunk) - for item in acc.process_content(content, text, content_mode, emit): + for item in acc.process_content( + content, text, content_mode, emit + ): yield item result = self.provider.stream_merge_chunks(result, chunk) @@ -2931,9 +2941,10 @@ def emit(text: str | Content): turn = self.provider.value_turn( response, has_data_model=data_model is not None ) - if turn.text: - emit(turn.text) - yield turn.text + text = nonstream_yield_text(turn, data_model) + if text: + emit(text) + yield text if echo == "all": emit_other_contents(turn, emit) @@ -2941,6 +2952,23 @@ def emit(text: str | Content): turn = finalize_assistant_turn(self.provider, turn) self._turns.extend([user_turn, turn]) + def _resolve_stream( + self, + stream: bool, + data_model: type[BaseModel] | None, + ) -> bool: + """Downgrade a streaming request when the provider can't stream it.""" + if not stream or self.provider.can_stream(data_model): + return stream + warnings.warn( + "This provider does not support streaming structured data " + "extraction; falling back to a non-streaming request. The full " + "result is still yielded (as a single chunk) and recorded in the " + "conversation.", + stacklevel=2, + ) + return False + def _collect_all_kwargs( self, kwargs: Optional[SubmitInputArgsT], @@ -3402,6 +3430,22 @@ async def aclose_response(response: Any) -> None: response.close() +def nonstream_yield_text(turn: Turn, data_model: type[BaseModel] | None) -> str: + """ + Text to yield for a turn produced by a non-streaming request. + + For structured data extraction, this is the JSON serialization of the + extracted object, so that `"".join(chunks)` honors the same contract as a + streamed response (see [](`~chatlas.Chat.stream`)). Otherwise it's the + turn's plain text. + """ + if data_model is not None: + for content in turn.contents: + if isinstance(content, ContentJson): + return orjson.dumps(content.value).decode("utf-8") + return turn.text + + def finalize_assistant_turn(provider: Provider, turn: Turn) -> AssistantTurn: """Validate turn type, compute tokens and cost, and log usage.""" if not isinstance(turn, AssistantTurn): diff --git a/chatlas/_provider.py b/chatlas/_provider.py index 65dbd50c..9b767ce0 100644 --- a/chatlas/_provider.py +++ b/chatlas/_provider.py @@ -229,6 +229,18 @@ async def chat_perform_async( kwargs: SubmitInputArgsT, ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ... + def can_stream(self, data_model: Optional[type[BaseModel]]) -> bool: + """ + Whether a streaming request is supported for the given inputs. + + Returning `False` causes [](`~chatlas.Chat`) to transparently fall back + to a non-streaming request (emitting the full response as a single + chunk). Providers override this when a particular feature—most notably + structured data extraction via certain `data_model` strategies—cannot be + streamed. + """ + return True + @abstractmethod def stream_content(self, chunk: ChatCompletionChunkT) -> Optional["Content"]: ... diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py index b762fefa..b658e993 100644 --- a/chatlas/_provider_anthropic.py +++ b/chatlas/_provider_anthropic.py @@ -175,8 +175,9 @@ def ChatAnthropic( `"auto"` (default) uses Anthropic's native `output_config` API for models that support it and falls back to a tool-based approach for older models. `"native"` forces the `output_config` API (which - supports streaming). `"tool"` forces the legacy tool-based approach - (which does not support streaming). + streams the extracted JSON incrementally). `"tool"` forces the legacy + tool-based approach; this can't stream incrementally, so streaming + requests transparently fall back to a single non-streaming request. api_key The API key to use for authentication. You generally should not supply this directly, but instead set the `ANTHROPIC_API_KEY` environment @@ -425,6 +426,21 @@ async def chat_perform_async( kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs) return await self._async_client.messages.create(**kwargs) # type: ignore + def can_stream(self, data_model: Optional[type[BaseModel]]) -> bool: + # The tool-based structured output strategy can't stream (Anthropic + # streams tool input as `{"data": ...}`-wrapped JSON, which doesn't + # match the extracted-object contract). The native `output_config` + # strategy streams JSON as ordinary text deltas, so it's fine. + if data_model is None: + return True + return self._use_native_structured_output() + + def _use_native_structured_output(self) -> bool: + mode = self._structured_output_mode + return mode == "native" or ( + mode == "auto" and supports_structured_outputs(self.model) + ) + def _chat_perform_args( self, stream: bool, @@ -435,10 +451,7 @@ def _chat_perform_args( ) -> "SubmitInputArgs": tool_schemas = [self._anthropic_tool_schema(tool) for tool in tools.values()] - mode = self._structured_output_mode - use_native = mode == "native" or ( - mode == "auto" and supports_structured_outputs(self.model) - ) + use_native = self._use_native_structured_output() if data_model is not None and use_native: from anthropic import transform_schema @@ -459,12 +472,6 @@ def _chat_perform_args( elif data_model is not None: data_model_tool = self.create_data_model_tool(data_model) tool_schemas.append(self._anthropic_tool_schema(data_model_tool)) - if stream: - stream = False - warnings.warn( - "Anthropic does not support structured data extraction in streaming mode.", - stacklevel=2, - ) kwargs_full: "SubmitInputArgs" = { "stream": stream, diff --git a/docs/get-started/structured-data.qmd b/docs/get-started/structured-data.qmd index 0e2e6f76..26e24d0a 100644 --- a/docs/get-started/structured-data.qmd +++ b/docs/get-started/structured-data.qmd @@ -218,4 +218,63 @@ height='tall' facial_hair='beard' scars='scar on left cheek' voice='deep voice' This example only works with Claude, not GPT or Gemini, because only Claude supports adding arbitrary additional properties. That said, you could prompt an LLM to suggest a `BaseModel` for you from the unstructured input, and then use that to extract the data. This is a bit more work, but it can be done. +::: + + +## Streaming structured data + +`.chat_structured()` waits for the full response before returning a validated model instance. +If you instead want to consume the response as it's generated (for example, to drive a progress indicator or a live UI), pass `data_model` to [`.stream()`](../reference/Chat.qmd#stream) (or `.stream_async()`). + +When `data_model` is provided, `.stream()` yields the model's **JSON, one fragment at a time**. +Each chunk is a piece of a JSON string; only the *concatenation* of all the chunks is valid JSON. +After the stream finishes, parse the joined chunks with `data_model.model_validate_json()`: + +```python +import chatlas as ctl +from pydantic import BaseModel + +class Person(BaseModel): + name: str + age: int + +chat = ctl.ChatOpenAI() + +chunks = [] +for chunk in chat.stream("John, age 15, won first prize", data_model=Person): + chunks.append(chunk) + +person = Person.model_validate_json("".join(chunks)) +``` + +The streamed `chat` also records the result as a `ContentJson` in conversation history, so `chat.get_last_turn()` reflects the extracted data just as it would after `.chat_structured()`. + +### The partial JSON problem + +Because chunks arrive as raw JSON fragments, an individual chunk (or any prefix of the stream) is almost never valid JSON on its own. +For the example above, the chunks look something like: + +```python +['{"', 'name', '":"', 'John', '","', 'age', '":', '15', '}'] +``` + +This means you **cannot** call `Person.model_validate_json()` on a partial accumulation and expect it to succeed — it will raise until the closing `}` arrives. +Keep this in mind when deciding what to do with the stream: + +- **You just want the final object.** Accumulate every chunk and parse once at the end (as above), or simply use `.chat_structured()`, which does this for you. +- **You want to show progress as it streams.** Display the raw text as it accumulates, but only attempt to *parse* it once the stream is complete. If you need to render *structured* partial state (e.g. populate fields as they arrive), you'll need a tolerant/partial JSON parser — chatlas does not attempt partial parsing for you, because there's no general, lossless way to interpret a half-written JSON value (is `"Jo` the final value or just the start of `"John"`?). + +::: callout-note +### Provider support + +Streaming structured data works with providers that emit the JSON as ordinary +output text, including `ChatOpenAI()`, `ChatGoogle()`, and recent Claude models +via `ChatAnthropic()`. + +Older Claude models extract structured data through a *tool call* rather than +plain text, which cannot be streamed incrementally. In that case chatlas +transparently falls back to a single, non-streaming request: you'll get one +chunk containing the complete JSON (and a warning noting the fallback). You can +force this behavior with `ChatAnthropic(structured_output_mode="tool")`, or +force native streaming with `structured_output_mode="native"`. ::: \ No newline at end of file diff --git a/tests/test_chat.py b/tests/test_chat.py index d37a574f..6c1e5ee7 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -168,6 +168,12 @@ class Person(BaseModel): person = Person.model_validate_json(result) assert person == Person(name="John", age=15) + # The JSON arrives incrementally: chunks are fragments that are only valid + # JSON once fully concatenated (the "partial JSON" contract). + assert len(chunks) > 1 + with pytest.raises(ValueError): + Person.model_validate_json("".join(chunks[:-1])) + # Verify the last turn contains ContentJson with the structured data turn = chat.get_last_turn() assert turn is not None diff --git a/tests/test_provider_anthropic.py b/tests/test_provider_anthropic.py index 55b34dd8..21d4ebe3 100644 --- a/tests/test_provider_anthropic.py +++ b/tests/test_provider_anthropic.py @@ -188,6 +188,96 @@ class Person(BaseModel): assert turn.contents[0].value == {"name": "John", "age": 15} +def _fake_tool_mode_message(model: str): + from anthropic.types import Message, ToolUseBlock, Usage + + return Message( + id="msg_x", + content=[ + ToolUseBlock( + id="toolu_x", + name="_structured_tool_call", + input={"data": {"name": "John", "age": 15}}, + type="tool_use", + ) + ], + model=model, + role="assistant", + stop_reason="tool_use", + stop_sequence=None, + type="message", + usage=Usage(input_tokens=1, output_tokens=1), + ) + + +def test_stream_with_data_model_tool_mode_downgrades_gracefully(): + from unittest.mock import patch + + from chatlas._content import ContentJson + + model = "claude-3-5-sonnet-20241022" + chat = ChatAnthropic(model=model, structured_output_mode="tool", api_key="x") + + class Person(BaseModel): + name: str + age: int + + fake = _fake_tool_mode_message(model) + provider = cast(AnthropicProvider, chat.provider) + + with patch.object(provider._client.messages, "create", return_value=fake): + with pytest.warns(UserWarning, match="streaming"): + chunks = list(chat.stream("John, age 15", data_model=Person)) + + person = Person.model_validate_json("".join(chunks)) + assert person == Person(name="John", age=15) + + turn = chat.get_last_turn() + assert turn is not None + assert len(turn.contents) == 1 + assert isinstance(turn.contents[0], ContentJson) + assert turn.contents[0].value == {"name": "John", "age": 15} + + +@pytest.mark.asyncio +async def test_stream_async_with_data_model_tool_mode_downgrades_gracefully(): + from unittest.mock import AsyncMock, patch + + from chatlas._content import ContentJson + + model = "claude-3-5-sonnet-20241022" + chat = ChatAnthropic(model=model, structured_output_mode="tool", api_key="x") + + class Person(BaseModel): + name: str + age: int + + fake = _fake_tool_mode_message(model) + provider = cast(AnthropicProvider, chat.provider) + + with patch.object( + provider._async_client.messages, + "create", + new=AsyncMock(return_value=fake), + ): + with pytest.warns(UserWarning, match="streaming"): + chunks = [ + chunk + async for chunk in await chat.stream_async( + "John, age 15", data_model=Person + ) + ] + + person = Person.model_validate_json("".join(chunks)) + assert person == Person(name="John", age=15) + + turn = chat.get_last_turn() + assert turn is not None + assert len(turn.contents) == 1 + assert isinstance(turn.contents[0], ContentJson) + assert turn.contents[0].value == {"name": "John", "age": 15} + + @pytest.mark.vcr @retry_api_call def test_anthropic_images():