Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
overload,
)

import orjson
from pydantic import BaseModel

from ._callbacks import CallbackManager
Expand Down Expand Up @@ -2760,6 +2761,8 @@ def _submit_turns(
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

stream = self._resolve_stream(stream, data_model)

def emit(text: str | Content):
self._echo_content(str(text))

Expand Down Expand Up @@ -2791,7 +2794,9 @@ def emit(text: str | Content):
content = self.provider.stream_content(chunk)
if content is not None:
text = self.provider.stream_text(chunk)
yield from acc.process_content(content, text, content_mode, emit)
yield from acc.process_content(
content, text, content_mode, emit
)
result = self.provider.stream_merge_chunks(result, chunk)

yield from acc.flush_thinking(content_mode, emit)
Expand Down Expand Up @@ -2821,9 +2826,10 @@ def emit(text: str | Content):
turn = self.provider.value_turn(
response, has_data_model=data_model is not None
)
if turn.text:
emit(turn.text)
yield turn.text
text = nonstream_yield_text(turn, data_model)
if text:
emit(text)
yield text

if echo == "all":
emit_other_contents(turn, emit)
Expand Down Expand Up @@ -2868,6 +2874,8 @@ async def _submit_turns_async(
*,
controller: StreamController,
) -> AsyncGenerator[str | Content, None]:
stream = self._resolve_stream(stream, data_model)

def emit(text: str | Content):
self._echo_content(str(text))

Expand Down Expand Up @@ -2899,7 +2907,9 @@ def emit(text: str | Content):
content = self.provider.stream_content(chunk)
if content is not None:
text = self.provider.stream_text(chunk)
for item in acc.process_content(content, text, content_mode, emit):
for item in acc.process_content(
content, text, content_mode, emit
):
yield item
result = self.provider.stream_merge_chunks(result, chunk)

Expand Down Expand Up @@ -2931,16 +2941,34 @@ def emit(text: str | Content):
turn = self.provider.value_turn(
response, has_data_model=data_model is not None
)
if turn.text:
emit(turn.text)
yield turn.text
text = nonstream_yield_text(turn, data_model)
if text:
emit(text)
yield text

if echo == "all":
emit_other_contents(turn, emit)

turn = finalize_assistant_turn(self.provider, turn)
self._turns.extend([user_turn, turn])

def _resolve_stream(
    self,
    stream: bool,
    data_model: type[BaseModel] | None,
) -> bool:
    """
    Decide whether a request should really be sent as a streaming request.

    A non-streaming request is returned unchanged. A streaming request is
    kept when `self.provider.can_stream(data_model)` says it is supported;
    otherwise a warning is issued and `False` is returned so the caller
    performs a single non-streaming request instead.
    """
    # Only ask the provider when streaming was actually requested.
    if not stream:
        return stream
    if self.provider.can_stream(data_model):
        return stream

    fallback_notice = (
        "This provider does not support streaming structured data "
        "extraction; falling back to a non-streaming request. The full "
        "result is still yielded (as a single chunk) and recorded in the "
        "conversation."
    )
    warnings.warn(fallback_notice, stacklevel=2)
    return False

def _collect_all_kwargs(
self,
kwargs: Optional[SubmitInputArgsT],
Expand Down Expand Up @@ -3402,6 +3430,22 @@ async def aclose_response(response: Any) -> None:
response.close()


def nonstream_yield_text(turn: Turn, data_model: type[BaseModel] | None) -> str:
    """
    Pick the text to yield for a turn that came from a non-streaming request.

    When a `data_model` was requested and the turn holds a `ContentJson`, the
    first such content's value is serialized to a JSON string, so that
    `"".join(chunks)` honors the same contract as a streamed response (see
    [](`~chatlas.Chat.stream`)). In every other case the turn's plain text
    is returned.
    """
    if data_model is None:
        return turn.text

    # Only the first JSON content matters; later ones are ignored.
    extracted = next(
        (item for item in turn.contents if isinstance(item, ContentJson)),
        None,
    )
    if extracted is None:
        return turn.text
    return orjson.dumps(extracted.value).decode("utf-8")


def finalize_assistant_turn(provider: Provider, turn: Turn) -> AssistantTurn:
"""Validate turn type, compute tokens and cost, and log usage."""
if not isinstance(turn, AssistantTurn):
Expand Down
12 changes: 12 additions & 0 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,18 @@ async def chat_perform_async(
kwargs: SubmitInputArgsT,
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...

def can_stream(self, data_model: Optional[type[BaseModel]]) -> bool:
    """
    Report whether this provider can serve the request as a stream.

    The base implementation always answers `True`. A provider overrides
    this when some feature cannot be streamed, most notably structured data
    extraction via certain `data_model` strategies. When the answer is
    `False`, [](`~chatlas.Chat`) transparently falls back to a non-streaming
    request and emits the full response as a single chunk.
    """
    return True

@abstractmethod
def stream_content(self, chunk: ChatCompletionChunkT) -> Optional["Content"]: ...

Expand Down
31 changes: 19 additions & 12 deletions chatlas/_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ def ChatAnthropic(
`"auto"` (default) uses Anthropic's native `output_config` API for
models that support it and falls back to a tool-based approach for
older models. `"native"` forces the `output_config` API (which
supports streaming). `"tool"` forces the legacy tool-based approach
(which does not support streaming).
streams the extracted JSON incrementally). `"tool"` forces the legacy
tool-based approach; this can't stream incrementally, so streaming
requests transparently fall back to a single non-streaming request.
api_key
The API key to use for authentication. You generally should not supply
this directly, but instead set the `ANTHROPIC_API_KEY` environment
Expand Down Expand Up @@ -425,6 +426,21 @@ async def chat_perform_async(
kwargs = self._chat_perform_args(stream, turns, tools, data_model, kwargs)
return await self._async_client.messages.create(**kwargs) # type: ignore

def can_stream(self, data_model: Optional[type[BaseModel]]) -> bool:
    # Plain (non-structured) requests can always be streamed. With a data
    # model, only the native `output_config` strategy can stream: it emits
    # the JSON as ordinary text deltas. The tool-based strategy can't,
    # because Anthropic streams tool input as `{"data": ...}`-wrapped JSON,
    # which doesn't match the extracted-object contract.
    return data_model is None or self._use_native_structured_output()

def _use_native_structured_output(self) -> bool:
    """
    Tell whether structured output should use the native strategy.

    `"native"` mode always does. `"auto"` mode defers to
    `supports_structured_outputs(self.model)`. Any other mode does not.
    """
    strategy = self._structured_output_mode
    if strategy == "native":
        return True
    if strategy != "auto":
        return False
    return supports_structured_outputs(self.model)

def _chat_perform_args(
self,
stream: bool,
Expand All @@ -435,10 +451,7 @@ def _chat_perform_args(
) -> "SubmitInputArgs":
tool_schemas = [self._anthropic_tool_schema(tool) for tool in tools.values()]

mode = self._structured_output_mode
use_native = mode == "native" or (
mode == "auto" and supports_structured_outputs(self.model)
)
use_native = self._use_native_structured_output()

if data_model is not None and use_native:
from anthropic import transform_schema
Expand All @@ -459,12 +472,6 @@ def _chat_perform_args(
elif data_model is not None:
data_model_tool = self.create_data_model_tool(data_model)
tool_schemas.append(self._anthropic_tool_schema(data_model_tool))
if stream:
stream = False
warnings.warn(
"Anthropic does not support structured data extraction in streaming mode.",
stacklevel=2,
)

kwargs_full: "SubmitInputArgs" = {
"stream": stream,
Expand Down
59 changes: 59 additions & 0 deletions docs/get-started/structured-data.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,63 @@ height='tall' facial_hair='beard' scars='scar on left cheek' voice='deep voice'
This example only works with Claude, not GPT or Gemini, because only Claude supports adding arbitrary additional properties.

That said, you could prompt an LLM to suggest a `BaseModel` for you from the unstructured input, and then use that to extract the data. This is a bit more work, but it can be done.
:::


## Streaming structured data

`.chat_structured()` waits for the full response before returning a validated model instance.
If you instead want to consume the response as it's generated (for example, to drive a progress indicator or a live UI), pass `data_model` to [`.stream()`](../reference/Chat.qmd#stream) (or `.stream_async()`).

When `data_model` is provided, `.stream()` yields the model's **JSON, one fragment at a time**.
Each chunk is a piece of a JSON string; only the *concatenation* of all the chunks is valid JSON.
After the stream finishes, parse the joined chunks with `data_model.model_validate_json()`:

```python
import chatlas as ctl
from pydantic import BaseModel

class Person(BaseModel):
name: str
age: int

chat = ctl.ChatOpenAI()

chunks = []
for chunk in chat.stream("John, age 15, won first prize", data_model=Person):
chunks.append(chunk)

person = Person.model_validate_json("".join(chunks))
```

Streaming with `data_model` also records the result as a `ContentJson` in the conversation history, so `chat.get_last_turn()` reflects the extracted data just as it would after `.chat_structured()`.

### The partial JSON problem

Because chunks arrive as raw JSON fragments, an individual chunk (or any prefix of the stream) is almost never valid JSON on its own.
For the example above, the chunks look something like:

```python
['{"', 'name', '":"', 'John', '","', 'age', '":', '15', '}']
```

This means you **cannot** call `Person.model_validate_json()` on a partial accumulation and expect it to succeed — it will raise a validation error until the closing `}` arrives.
Keep this in mind when deciding what to do with the stream:

- **You just want the final object.** Accumulate every chunk and parse once at the end (as above), or simply use `.chat_structured()`, which does this for you.
- **You want to show progress as it streams.** Display the raw text as it accumulates, but only attempt to *parse* it once the stream is complete. If you need to render *structured* partial state (e.g. populate fields as they arrive), you'll need a tolerant/partial JSON parser — chatlas does not attempt partial parsing for you, because there's no general, lossless way to interpret a half-written JSON value (is `"Jo` the final value or just the start of `"John"`?).

::: callout-note
### Provider support

Streaming structured data works with providers that emit the JSON as ordinary
output text, including `ChatOpenAI()`, `ChatGoogle()`, and recent Claude models
via `ChatAnthropic()`.

Older Claude models extract structured data through a *tool call* rather than
plain text, which cannot be streamed incrementally. In that case chatlas
transparently falls back to a single, non-streaming request: you'll get one
chunk containing the complete JSON (and a warning noting the fallback). You can
force this behavior with `ChatAnthropic(structured_output_mode="tool")`, or
force native streaming with `structured_output_mode="native"`.
:::
6 changes: 6 additions & 0 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ class Person(BaseModel):
person = Person.model_validate_json(result)
assert person == Person(name="John", age=15)

# The JSON arrives incrementally: chunks are fragments that are only valid
# JSON once fully concatenated (the "partial JSON" contract).
assert len(chunks) > 1
with pytest.raises(ValueError):
Person.model_validate_json("".join(chunks[:-1]))

# Verify the last turn contains ContentJson with the structured data
turn = chat.get_last_turn()
assert turn is not None
Expand Down
90 changes: 90 additions & 0 deletions tests/test_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,96 @@ class Person(BaseModel):
assert turn.contents[0].value == {"name": "John", "age": 15}


def _fake_tool_mode_message(model: str):
    """
    Build a canned Anthropic `Message` shaped like a tool-mode extraction.

    The message holds a single `tool_use` block named
    `_structured_tool_call` whose input wraps the extracted object under a
    `"data"` key, and it stops with `stop_reason="tool_use"`.
    """
    from anthropic.types import Message, ToolUseBlock, Usage

    structured_call = ToolUseBlock(
        id="toolu_x",
        name="_structured_tool_call",
        input={"data": {"name": "John", "age": 15}},
        type="tool_use",
    )
    token_usage = Usage(input_tokens=1, output_tokens=1)

    return Message(
        id="msg_x",
        content=[structured_call],
        model=model,
        role="assistant",
        stop_reason="tool_use",
        stop_sequence=None,
        type="message",
        usage=token_usage,
    )


def test_stream_with_data_model_tool_mode_downgrades_gracefully():
    """
    Streaming with a data model in tool mode falls back to one request.

    The sync client's `messages.create` is patched to return a canned
    tool-mode message. The stream must warn about the fallback, still yield
    chunks that join into valid JSON for the model, and record a single
    `ContentJson` in the last turn.
    """
    from unittest.mock import patch

    from chatlas._content import ContentJson

    class Person(BaseModel):
        name: str
        age: int

    model_name = "claude-3-5-sonnet-20241022"
    chat = ChatAnthropic(
        model=model_name, structured_output_mode="tool", api_key="x"
    )
    canned = _fake_tool_mode_message(model_name)
    provider = cast(AnthropicProvider, chat.provider)

    create_patch = patch.object(
        provider._client.messages, "create", return_value=canned
    )
    with create_patch, pytest.warns(UserWarning, match="streaming"):
        chunks = list(chat.stream("John, age 15", data_model=Person))

    # The joined chunks must parse as the extracted object.
    parsed = Person.model_validate_json("".join(chunks))
    assert parsed == Person(name="John", age=15)

    last_turn = chat.get_last_turn()
    assert last_turn is not None
    assert len(last_turn.contents) == 1
    only_content = last_turn.contents[0]
    assert isinstance(only_content, ContentJson)
    assert only_content.value == {"name": "John", "age": 15}


@pytest.mark.asyncio
async def test_stream_async_with_data_model_tool_mode_downgrades_gracefully():
    """
    Async streaming with a data model in tool mode falls back to one request.

    The async client's `messages.create` is replaced by an `AsyncMock` that
    returns a canned tool-mode message. The stream must warn about the
    fallback, still yield chunks that join into valid JSON for the model,
    and record a single `ContentJson` in the last turn.
    """
    from unittest.mock import AsyncMock, patch

    from chatlas._content import ContentJson

    class Person(BaseModel):
        name: str
        age: int

    model_name = "claude-3-5-sonnet-20241022"
    chat = ChatAnthropic(
        model=model_name, structured_output_mode="tool", api_key="x"
    )
    canned = _fake_tool_mode_message(model_name)
    provider = cast(AnthropicProvider, chat.provider)

    create_patch = patch.object(
        provider._async_client.messages,
        "create",
        new=AsyncMock(return_value=canned),
    )
    with create_patch, pytest.warns(UserWarning, match="streaming"):
        response = await chat.stream_async("John, age 15", data_model=Person)
        chunks = [chunk async for chunk in response]

    # The joined chunks must parse as the extracted object.
    parsed = Person.model_validate_json("".join(chunks))
    assert parsed == Person(name="John", age=15)

    last_turn = chat.get_last_turn()
    assert last_turn is not None
    assert len(last_turn.contents) == 1
    only_content = last_turn.contents[0]
    assert isinstance(only_content, ContentJson)
    assert only_content.value == {"name": "John", "age": 15}


@pytest.mark.vcr
@retry_api_call
def test_anthropic_images():
Expand Down