diff --git a/src/anthropic/lib/bedrock/_stream_decoder.py b/src/anthropic/lib/bedrock/_stream_decoder.py index 02e81a3c..66dd658c 100644 --- a/src/anthropic/lib/bedrock/_stream_decoder.py +++ b/src/anthropic/lib/bedrock/_stream_decoder.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterator, AsyncIterator +import json +from typing import TYPE_CHECKING, Any, Dict, Iterator, AsyncIterator, cast from ..._utils import lru_cache from ..._streaming import ServerSentEvent @@ -35,9 +36,9 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: for chunk in iterator: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - yield ServerSentEvent(data=message, event="completion") + sse = self._parse_message_from_event(event) + if sse: + yield sse async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: """Given an async iterator that yields lines, iterate over it & yield every event encountered""" @@ -47,11 +48,11 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser async for chunk in iterator: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - yield ServerSentEvent(data=message, event="completion") + sse = self._parse_message_from_event(event) + if sse: + yield sse - def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: + def _parse_message_from_event(self, event: EventStreamMessage) -> ServerSentEvent | None: response_dict = event.to_response_dict() parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: @@ -61,4 +62,23 @@ def _parse_message_from_event(self, event: EventStreamMessage) -> str | None: if not chunk: return None - return chunk.get("bytes").decode() # type: ignore[no-any-return] + return _chunk_bytes_to_sse(chunk.get("bytes")) + + +def _chunk_bytes_to_sse(raw: bytes) -> ServerSentEvent | None: + decoded = raw.decode() + data: Any + try: + data = json.loads(decoded) + except Exception: + data = None + + if not isinstance(data, dict): + return ServerSentEvent(data=decoded, event="completion") + + payload = cast("Dict[str, Any]", data) + event_type = payload.get("type") + if not isinstance(event_type, str): + event_type = "completion" + + return ServerSentEvent(data=decoded, event=event_type) diff --git a/tests/lib/test_bedrock.py b/tests/lib/test_bedrock.py index 6e45c27f..f8aefef3 100644 --- a/tests/lib/test_bedrock.py +++ b/tests/lib/test_bedrock.py @@ -9,6 +9,7 @@ from respx import MockRouter from anthropic import AnthropicBedrock, AsyncAnthropicBedrock +from anthropic.lib.bedrock._stream_decoder import _chunk_bytes_to_sse sync_client = AnthropicBedrock( aws_region="us-east-1", @@ -275,3 +276,32 @@ def test_region_infer_from_specified_profile( client = AnthropicBedrock() assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"] + + +def test_chunk_bytes_to_sse_typed_event() -> None: + raw = ( + b'{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant",' + b'"content":[],"model":"claude-x","stop_reason":null,"stop_sequence":null,' + b'"usage":{"input_tokens":1,"output_tokens":1}}}' + ) + sse = _chunk_bytes_to_sse(raw) + assert sse is not None + assert sse.event == "message_start" + assert sse.data == raw.decode() + + +def test_chunk_bytes_to_sse_legacy_completion() -> None: + raw = b'{"completion":" Hello","stop_reason":null,"model":"claude-2"}' + sse = _chunk_bytes_to_sse(raw) + assert sse is not None + assert sse.event == "completion" + + +def test_chunk_bytes_to_sse_legacy_completion_with_metrics() -> None: + raw = ( + b'{"completion":" Hello","stop_reason":"stop_sequence","model":"claude-2",' + b'"amazon-bedrock-invocationMetrics":{"inputTokenCount":1,"outputTokenCount":1}}' + ) + sse = _chunk_bytes_to_sse(raw) + assert sse is not None + assert sse.event == "completion"