Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/anthropic/lib/bedrock/_stream_decoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Iterator, AsyncIterator
import json
from typing import TYPE_CHECKING, Any, Dict, Iterator, AsyncIterator, cast

from ..._utils import lru_cache
from ..._streaming import ServerSentEvent
Expand Down Expand Up @@ -35,9 +36,9 @@ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
sse = self._parse_message_from_event(event)
if sse:
yield sse

async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
Expand All @@ -47,11 +48,11 @@ async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[Ser
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
yield ServerSentEvent(data=message, event="completion")
sse = self._parse_message_from_event(event)
if sse:
yield sse

def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
def _parse_message_from_event(self, event: EventStreamMessage) -> ServerSentEvent | None:
response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
Expand All @@ -61,4 +62,23 @@ def _parse_message_from_event(self, event: EventStreamMessage) -> str | None:
if not chunk:
return None

return chunk.get("bytes").decode() # type: ignore[no-any-return]
return _chunk_bytes_to_sse(chunk.get("bytes"))


def _chunk_bytes_to_sse(raw: bytes) -> ServerSentEvent | None:
decoded = raw.decode()
data: Any
try:
data = json.loads(decoded)
except Exception:
data = None

if not isinstance(data, dict):
return ServerSentEvent(data=decoded, event="completion")

payload = cast("Dict[str, Any]", data)
event_type = payload.get("type")
if not isinstance(event_type, str):
event_type = "completion"

return ServerSentEvent(data=decoded, event=event_type)
30 changes: 30 additions & 0 deletions tests/lib/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from respx import MockRouter

from anthropic import AnthropicBedrock, AsyncAnthropicBedrock
from anthropic.lib.bedrock._stream_decoder import _chunk_bytes_to_sse

sync_client = AnthropicBedrock(
aws_region="us-east-1",
Expand Down Expand Up @@ -275,3 +276,32 @@ def test_region_infer_from_specified_profile(
client = AnthropicBedrock()

assert client.aws_region == next(profile for profile in profiles if profile["name"] == aws_profile)["region"]


def test_chunk_bytes_to_sse_typed_event() -> None:
raw = (
b'{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant",'
b'"content":[],"model":"claude-x","stop_reason":null,"stop_sequence":null,'
b'"usage":{"input_tokens":1,"output_tokens":1}}}'
)
sse = _chunk_bytes_to_sse(raw)
assert sse is not None
assert sse.event == "message_start"
assert sse.data == raw.decode()


def test_chunk_bytes_to_sse_legacy_completion() -> None:
raw = b'{"completion":" Hello","stop_reason":null,"model":"claude-2"}'
sse = _chunk_bytes_to_sse(raw)
assert sse is not None
assert sse.event == "completion"


def test_chunk_bytes_to_sse_legacy_completion_with_metrics() -> None:
raw = (
b'{"completion":" Hello","stop_reason":"stop_sequence","model":"claude-2",'
b'"amazon-bedrock-invocationMetrics":{"inputTokenCount":1,"outputTokenCount":1}}'
)
sse = _chunk_bytes_to_sse(raw)
assert sse is not None
assert sse.event == "completion"
Loading