Skip to content

Commit 431f181

Browse files
authored
Merge pull request #25 from aws-samples/fix/cache-ttl-and-batch-names
fix: cache_control ttl model check + batch create with explicit names
2 parents c407d62 + e9f18cf commit 431f181

7 files changed

Lines changed: 178 additions & 84 deletions

File tree

backend/app/api/admin/endpoints/tokens.py

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from typing import List
88

99
from fastapi import APIRouter, Depends, HTTPException, status
10-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel
1111
from sqlalchemy import func, select
1212
from sqlalchemy.ext.asyncio import AsyncSession
1313

@@ -58,16 +58,27 @@ class CreateTokenRequest(BaseModel):
5858

5959

6060
class BatchCreateTokenRequest(BaseModel):
61-
"""Batch create tokens request."""
61+
"""Batch create tokens request.
6262
63-
count: int = Field(ge=1, le=100)
64-
name_prefix: str
63+
``names`` is a comma-separated string of token names (e.g. "alice, bob, charlie").
64+
Whitespace around each name is automatically trimmed and empty entries are ignored.
65+
"""
66+
67+
names: str
6568
expires_at: datetime | None = None
6669
quota_usd: Decimal | None = None
6770
allowed_ips: List[str] | None = None
6871
token_metadata: dict | None = None
6972
model_names: List[str] | None = None
7073

74+
def parsed_names(self) -> List[str]:
75+
"""Parse comma-separated names. Supports ASCII comma, Chinese comma, semicolons, and newlines."""
76+
import re
77+
78+
return [
79+
n for n in (s.strip() for s in re.split(r"[,,;;\n]+", self.names)) if n
80+
]
81+
7182

7283
class UpdateTokenRequest(BaseModel):
7384
"""Update token request."""
@@ -217,19 +228,29 @@ async def batch_create_tokens(
217228
"""
218229
Batch create API tokens with optional shared model list.
219230
220-
- **count**: Number of tokens to create (1-100)
221-
- **name_prefix**: Name prefix, tokens named {prefix}-001, {prefix}-002, ...
231+
- **names**: Comma-separated token names (e.g. "alice, bob, charlie")
222232
- **model_names**: Optional list of model names to assign to all tokens
223233
"""
234+
names = request.parsed_names()
235+
if not names:
236+
raise HTTPException(
237+
status_code=status.HTTP_400_BAD_REQUEST,
238+
detail="No valid names provided",
239+
)
240+
if len(names) > 100:
241+
raise HTTPException(
242+
status_code=status.HTTP_400_BAD_REQUEST,
243+
detail=f"Too many names ({len(names)}), maximum is 100",
244+
)
245+
224246
try:
225247
validated_meta = validate_token_metadata(request.token_metadata)
226248
except ValueError as e:
227249
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
228250

229251
results = await token_service.create_tokens_batch(
230252
user_id=current_user.id,
231-
count=request.count,
232-
name_prefix=request.name_prefix,
253+
names=names,
233254
expires_at=request.expires_at,
234255
quota_usd=request.quota_usd,
235256
allowed_ips=request.allowed_ips,

backend/app/services/bedrock.py

Lines changed: 46 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -683,7 +683,7 @@ async def _try_stream_with_content_timeout(
683683
if use_converse:
684684
converse_params = self._build_converse_params(request, model_id)
685685
else:
686-
body = self._build_anthropic_body(request)
686+
body = self._build_anthropic_body(request, model_id=model_id)
687687
invoke_kwargs = self._build_invoke_kwargs(request, model_id)
688688

689689
content_received = False
@@ -806,12 +806,34 @@ async def _try_stream_with_content_timeout(
806806

807807
# ------------------------------------------------------------------
808808

809+
# Models that support the extended 1-hour cache TTL.
810+
# Only Claude 4.5 family models support ``ttl`` in ``cache_control``;
811+
# older/newer families (Claude 4, etc.) reject it with
812+
# ``Extra inputs are not permitted``.
813+
_EXTENDED_TTL_MODEL_PATTERNS = (
814+
"claude-opus-4-5",
815+
"claude-sonnet-4-5",
816+
"claude-haiku-4-5",
817+
)
818+
819+
@classmethod
820+
def _model_supports_cache_ttl(cls, model_id: str | None) -> bool:
821+
"""Check if a model supports the extended ``ttl`` field in cache_control."""
822+
if not model_id:
823+
return False
824+
return any(pat in model_id for pat in cls._EXTENDED_TTL_MODEL_PATTERNS)
825+
809826
@staticmethod
810-
def _new_cache_marker(ttl: str | None = None) -> dict:
811-
"""Create a cache_control marker with configured TTL."""
827+
def _new_cache_marker(ttl: str | None = None, model_id: str | None = None) -> dict:
828+
"""Create a cache_control marker with configured TTL.
829+
830+
The ``ttl`` field is only supported by Claude 4.5 family models.
831+
For unsupported models the field must be omitted, otherwise Bedrock
832+
returns ``Extra inputs are not permitted``.
833+
"""
812834
cache_ttl = ttl or get_settings().PROMPT_CACHE_TTL
813835
marker: dict = {"type": "ephemeral"}
814-
if cache_ttl != "5m":
836+
if cache_ttl != "5m" and BedrockClient._model_supports_cache_ttl(model_id):
815837
marker["ttl"] = cache_ttl
816838
return marker
817839

@@ -846,7 +868,9 @@ def _body_has_cache_control(body: dict) -> bool:
846868
return len(BedrockClient._collect_cache_blocks(body)) > 0
847869

848870
@staticmethod
849-
def _inject_prompt_cache_breakpoints(body: dict, ttl: str | None = None) -> None:
871+
def _inject_prompt_cache_breakpoints(
872+
body: dict, ttl: str | None = None, model_id: str | None = None
873+
) -> None:
850874
"""Inject up to 4 cache_control breakpoints into the request body.
851875
852876
Strategy aligned with claudecode-bedrock-proxy:
@@ -860,17 +884,24 @@ def _inject_prompt_cache_breakpoints(body: dict, ttl: str | None = None) -> None
860884
count against this budget.
861885
"""
862886
cache_ttl = ttl or get_settings().PROMPT_CACHE_TTL
863-
marker = BedrockClient._new_cache_marker(ttl=cache_ttl)
887+
supports_ttl = BedrockClient._model_supports_cache_ttl(model_id)
888+
marker = BedrockClient._new_cache_marker(ttl=cache_ttl, model_id=model_id)
864889

865890
# --- Step 1: Upgrade TTL on pre-existing breakpoints ---
866891
existing_blocks = BedrockClient._collect_cache_blocks(body)
867892
upgraded = 0
868-
if cache_ttl != "5m":
893+
if cache_ttl != "5m" and supports_ttl:
869894
for block in existing_blocks:
870895
cc = block.get("cache_control")
871896
if isinstance(cc, dict):
872897
cc["ttl"] = cache_ttl
873898
upgraded += 1
899+
elif not supports_ttl:
900+
# Strip ttl from pre-existing breakpoints for unsupported models
901+
for block in existing_blocks:
902+
cc = block.get("cache_control")
903+
if isinstance(cc, dict) and "ttl" in cc:
904+
del cc["ttl"]
874905

875906
existing = len(existing_blocks)
876907
budget = BedrockClient.MAX_CACHE_BREAKPOINTS - existing
@@ -951,7 +982,9 @@ def _inject_prompt_cache_breakpoints(body: dict, ttl: str | None = None) -> None
951982
)
952983

953984
@staticmethod
954-
def _build_anthropic_body(request: BedrockRequest) -> dict:
985+
def _build_anthropic_body(
986+
request: BedrockRequest, model_id: str | None = None
987+
) -> dict:
955988
"""
956989
Build an Anthropic Messages API request body from a BedrockRequest.
957990
@@ -1019,7 +1052,9 @@ def _build_anthropic_body(request: BedrockRequest) -> dict:
10191052
BedrockClient._body_has_cache_control(body) if should_inject else False
10201053
)
10211054
if should_inject and not has_cache:
1022-
BedrockClient._inject_prompt_cache_breakpoints(body, ttl=request.cache_ttl)
1055+
BedrockClient._inject_prompt_cache_breakpoints(
1056+
body, ttl=request.cache_ttl, model_id=model_id
1057+
)
10231058

10241059
# --- effort parameter: requires beta flag + output_config wrapper ---
10251060
# Users may pass "effort" as a top-level field (via additional_model_request_fields).
@@ -1425,7 +1460,7 @@ async def _invoke_inner(
14251460
},
14261461
)
14271462
else:
1428-
body = self._build_anthropic_body(request)
1463+
body = self._build_anthropic_body(request, model_id=model_id)
14291464
invoke_kwargs = self._build_invoke_kwargs(request, model_id)
14301465

14311466
max_retries = 3
@@ -1745,7 +1780,7 @@ async def _invoke_stream_inner(
17451780
},
17461781
)
17471782
else:
1748-
body = self._build_anthropic_body(request)
1783+
body = self._build_anthropic_body(request, model_id=model_id)
17491784
invoke_kwargs = self._build_invoke_kwargs(request, model_id)
17501785

17511786
max_retries = 4

backend/app/services/token.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -87,16 +87,15 @@ async def create_token(
8787
async def create_tokens_batch(
8888
self,
8989
user_id: UUID,
90-
count: int,
91-
name_prefix: str,
90+
names: List[str],
9291
expires_at: Optional[datetime] = None,
9392
quota_usd: Optional[Decimal] = None,
9493
allowed_ips: Optional[List[str]] = None,
9594
token_metadata: Optional[dict] = None,
9695
model_names: Optional[List[str]] = None,
9796
) -> List[tuple[APIToken, str]]:
9897
"""
99-
Batch create API tokens with optional shared model list.
98+
Batch create API tokens with explicit names and optional shared model list.
10099
101100
All tokens are inserted in a single transaction (atomic).
102101
@@ -105,14 +104,14 @@ async def create_tokens_batch(
105104
"""
106105
tokens_and_keys: List[tuple[APIToken, str]] = []
107106

108-
for i in range(1, count + 1):
107+
for name in names:
109108
plain_token = generate_api_token()
110109
token_hash = hash_token(plain_token)
111110
encrypted = encrypt_token(plain_token)
112111

113112
token = APIToken(
114113
user_id=user_id,
115-
name=f"{name_prefix}-{i:03d}",
114+
name=name,
116115
token_hash=token_hash,
117116
encrypted_token=encrypted,
118117
expires_at=expires_at,

0 commit comments

Comments
 (0)