Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions dto/claude_prompt_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package dto

import (
"encoding/json"
"strings"

"github.com/QuantumNous/new-api/common"
"github.com/tidwall/sjson"
)

// Gateway-facing names used to configure Anthropic prompt caching.
const (
	// AnthropicPromptCacheTTLEnv is the environment variable that supplies the
	// default prompt-cache TTL policy.
	AnthropicPromptCacheTTLEnv = "ANTHROPIC_PROMPT_CACHE_TTL"

	// AnthropicPromptCacheTTLHeader is the request header carrying a
	// per-request prompt-cache TTL policy.
	AnthropicPromptCacheTTLHeader = "x-anthropic-prompt-cache-ttl"

	// AnthropicPromptCacheWorkloadHeader is the request header carrying the
	// workload hint consulted when the TTL policy is "auto".
	AnthropicPromptCacheWorkloadHeader = "x-anthropic-prompt-cache-workload"
)

// anthropicPromptCacheControlType is the "type" value written into every
// cache_control object generated by this file.
const anthropicPromptCacheControlType = "ephemeral"

// ApplyAnthropicPromptCacheControlToClaudeRequest stores a generated
// cache_control value in req.CacheControl for the given ttl.
//
// The request is left untouched and false is returned when req is nil, ttl is
// empty, or ClaudeRequestHasCacheControl reports that the request already
// carries a cache_control entry. It returns true when the field was written.
func ApplyAnthropicPromptCacheControlToClaudeRequest(req *ClaudeRequest, ttl string) bool {
	switch {
	case req == nil, ttl == "":
		// Nothing to annotate, or caching is disabled.
		return false
	case ClaudeRequestHasCacheControl(req):
		// Never override a cache policy already present on the request.
		return false
	}

	req.CacheControl = newAnthropicPromptCacheControlRaw(ttl)
	return true
}

// ApplyAnthropicPromptCacheControlToRawClaudeBody sets the root-level
// "cache_control" key of a raw JSON body to a generated value for ttl.
//
// It returns the resulting body, whether the body was modified, and any error
// from the JSON rewrite. The input body is returned unmodified when ttl is
// empty, when RawClaudeBodyHasCacheControl reports an existing cache_control
// entry, or when the rewrite fails.
func ApplyAnthropicPromptCacheControlToRawClaudeBody(body []byte, ttl string) ([]byte, bool, error) {
	skip := ttl == "" || RawClaudeBodyHasCacheControl(body)
	if skip {
		return body, false, nil
	}

	marker := newAnthropicPromptCacheControlRaw(ttl)
	updated, err := sjson.SetRawBytes(body, "cache_control", marker)
	if err == nil {
		return updated, true, nil
	}
	// On failure hand back the caller's original bytes alongside the error.
	return body, false, err
}

// NormalizeAnthropicPromptCacheTTL maps a user-supplied TTL policy to one of
// the concrete values "5m" or "1h", or to "" when no TTL should be applied.
//
// value is trimmed and compared case-insensitively:
//   - "5m" and "1h" are returned as-is (lower-cased);
//   - "auto" yields "1h" when IsLongRunningAnthropicWorkload(workload) is
//     true and "5m" otherwise;
//   - everything else, including "", "off", "false", "none", "disabled" and
//     any unrecognized text, yields "".
func NormalizeAnthropicPromptCacheTTL(value, workload string) string {
	policy := strings.ToLower(strings.TrimSpace(value))

	if policy == "5m" || policy == "1h" {
		return policy
	}
	if policy != "auto" {
		// Explicitly disabled or unrecognized policies both switch caching off.
		return ""
	}

	if IsLongRunningAnthropicWorkload(workload) {
		return "1h"
	}
	return "5m"
}

// IsLongRunningAnthropicWorkload reports whether workload names one of the
// workload kinds treated as long-running. The comparison ignores surrounding
// whitespace and letter case.
func IsLongRunningAnthropicWorkload(workload string) bool {
	name := strings.ToLower(strings.TrimSpace(workload))

	longRunning := [...]string{
		"eval", "evaluation",
		"benchmark", "bench",
		"batch", "pipeline",
		"long", "long-running",
	}
	for _, candidate := range longRunning {
		if name == candidate {
			return true
		}
	}
	return false
}

// ClaudeRequestHasCacheControl reports whether req already carries a
// cache_control entry, either in its own CacheControl field or anywhere in
// its serialized form.
//
// A nil request, or a request that cannot be marshaled, is reported as having
// no cache_control.
func ClaudeRequestHasCacheControl(req *ClaudeRequest) bool {
	switch {
	case req == nil:
		return false
	case len(req.CacheControl) > 0:
		// Fast path: the request-level field is already populated.
		return true
	}

	// Serialize the request and scan the resulting JSON for nested entries.
	encoded, err := common.Marshal(req)
	return err == nil && RawClaudeBodyHasCacheControl(encoded)
}

// RawClaudeBodyHasCacheControl reports whether the JSON document in body
// contains a non-null "cache_control" key at any nesting depth. A body that
// fails to decode is reported as having none.
func RawClaudeBodyHasCacheControl(body []byte) bool {
	var decoded any
	err := common.Unmarshal(body, &decoded)
	if err == nil {
		return hasCacheControlKey(decoded)
	}
	return false
}

// hasCacheControlKey walks a decoded JSON value (nested map[string]any and
// []any containers) and reports whether any object in it has a
// "cache_control" key whose value is not nil. Scalars and nil yield false.
func hasCacheControlKey(value any) bool {
	if object, isObject := value.(map[string]any); isObject {
		// Direct hit on this object.
		if marker, present := object["cache_control"]; present && marker != nil {
			return true
		}
		// Otherwise look inside every member value.
		for _, member := range object {
			if hasCacheControlKey(member) {
				return true
			}
		}
		return false
	}

	if list, isList := value.([]any); isList {
		for i := range list {
			if hasCacheControlKey(list[i]) {
				return true
			}
		}
	}
	return false
}

// newAnthropicPromptCacheControlRaw builds the raw JSON cache_control object
// {"type":<anthropicPromptCacheControlType>,"ttl":<ttl>} for the given ttl.
//
// The ttl is encoded as a JSON string instead of being spliced in verbatim,
// so a value containing quotes, backslashes, or control characters can
// neither break the document nor inject extra keys. For plain values such as
// "5m" or "1h" the output is byte-for-byte what simple concatenation yields.
func newAnthropicPromptCacheControlRaw(ttl string) json.RawMessage {
	quotedTTL, err := json.Marshal(ttl)
	if err != nil {
		// Encoding a string is not expected to fail; fall back to an empty
		// JSON string so the result is still a well-formed object.
		quotedTTL = []byte(`""`)
	}
	return json.RawMessage(`{"type":"` + anthropicPromptCacheControlType + `","ttl":` + string(quotedTTL) + `}`)
}
30 changes: 24 additions & 6 deletions dto/openai_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,9 @@ func (m *Message) ParseContent() []MediaContent {
case ContentTypeText:
if text, ok := contentItem["text"].(string); ok {
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: text,
Type: ContentTypeText,
Text: text,
CacheControl: mediaContentCacheControl(contentItem),
})
}

Expand All @@ -596,8 +597,9 @@ func (m *Message) ParseContent() []MediaContent {
}
}
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: temp,
Type: ContentTypeImageURL,
ImageUrl: temp,
CacheControl: mediaContentCacheControl(contentItem),
})

case ContentTypeInputAudio:
Expand All @@ -610,8 +612,9 @@ func (m *Message) ParseContent() []MediaContent {
Format: format,
}
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: temp,
Type: ContentTypeInputAudio,
InputAudio: temp,
CacheControl: mediaContentCacheControl(contentItem),
})
}
}
Expand All @@ -624,6 +627,7 @@ func (m *Message) ParseContent() []MediaContent {
File: &MessageFile{
FileId: fileId,
},
CacheControl: mediaContentCacheControl(contentItem),
})
} else {
fileName, ok1 := fileData["filename"].(string)
Expand All @@ -635,6 +639,7 @@ func (m *Message) ParseContent() []MediaContent {
FileName: fileName,
FileData: fileDataStr,
},
CacheControl: mediaContentCacheControl(contentItem),
})
}
}
Expand All @@ -646,6 +651,7 @@ func (m *Message) ParseContent() []MediaContent {
VideoUrl: &MessageVideoUrl{
Url: videoUrl,
},
CacheControl: mediaContentCacheControl(contentItem),
})
}
}
Expand All @@ -657,6 +663,18 @@ func (m *Message) ParseContent() []MediaContent {
return contentList
}

// mediaContentCacheControl returns the JSON encoding of the "cache_control"
// entry of a decoded content item. It returns nil when the entry is missing,
// is nil, or cannot be marshaled.
func mediaContentCacheControl(contentItem map[string]any) json.RawMessage {
	// A missing key reads back as nil, so one check covers both cases.
	entry := contentItem["cache_control"]
	if entry == nil {
		return nil
	}

	encoded, err := common.Marshal(entry)
	if err != nil {
		return nil
	}
	return encoded
}

// old code
/*func (m *Message) StringContent() string {
if m.parsedStringContent != nil {
Expand Down
5 changes: 5 additions & 0 deletions relay/channel/api_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ var passthroughSkipHeaderNamesLower = map[string]struct{}{
"x-api-key": {},
"x-goog-api-key": {},

// Gateway-only Anthropic prompt-cache policy headers. These control New API's
// request mutation and should never be forwarded to upstream providers.
"x-anthropic-prompt-cache-ttl": {},
"x-anthropic-prompt-cache-workload": {},

// WebSocket handshake headers are generated by the client/dialer.
"sec-websocket-key": {},
"sec-websocket-version": {},
Expand Down
30 changes: 30 additions & 0 deletions relay/channel/api_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,36 @@ func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) {
require.False(t, hasAcceptEncoding)
}

// TestProcessHeaderOverride_PassthroughSkipsAnthropicPromptCachePolicyHeaders
// checks that, with a "*" header override, an ordinary request header is
// present in the processed headers while the two Anthropic prompt-cache
// policy headers are absent.
func TestProcessHeaderOverride_PassthroughSkipsAnthropicPromptCachePolicyHeaders(t *testing.T) {
	t.Parallel()

	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())

	req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
	req.Header.Set("X-Trace-Id", "trace-123")
	req.Header.Set("X-Anthropic-Prompt-Cache-Ttl", "1h")
	req.Header.Set("X-Anthropic-Prompt-Cache-Workload", "eval")
	ctx.Request = req

	// NOTE(review): "*" presumably enables pass-through of all client
	// headers — confirm against processHeaderOverride.
	override := map[string]any{"*": ""}
	info := &relaycommon.RelayInfo{
		IsChannelTest: false,
		ChannelMeta:   &relaycommon.ChannelMeta{HeadersOverride: override},
	}

	headers, err := processHeaderOverride(info, ctx)
	require.NoError(t, err)

	// The unrelated header is present in the result...
	require.Equal(t, "trace-123", headers["x-trace-id"])

	// ...while the gateway-only policy headers are not.
	_, ttlPresent := headers["x-anthropic-prompt-cache-ttl"]
	require.False(t, ttlPresent)
	_, workloadPresent := headers["x-anthropic-prompt-cache-workload"]
	require.False(t, workloadPresent)
}

func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) {
t.Parallel()

Expand Down
19 changes: 18 additions & 1 deletion relay/channel/claude/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net/http"
"net/url"
"os"

"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
Expand Down Expand Up @@ -95,7 +96,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
return RequestOpenAI2ClaudeMessage(c, *request)
claudeReq, err := RequestOpenAI2ClaudeMessage(c, *request)
if err != nil {
return nil, err
}
ttl := os.Getenv(dto.AnthropicPromptCacheTTLEnv)
workload := ""
if c != nil {
if headerTTL := c.GetHeader(dto.AnthropicPromptCacheTTLHeader); headerTTL != "" {
ttl = headerTTL
}
workload = c.GetHeader(dto.AnthropicPromptCacheWorkloadHeader)
}
dto.ApplyAnthropicPromptCacheControlToClaudeRequest(
claudeReq,
dto.NormalizeAnthropicPromptCacheTTL(ttl, workload),
)
return claudeReq, nil
}

func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
Expand Down
11 changes: 7 additions & 4 deletions relay/channel/claude/relay-claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,9 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
for _, ctx := range message.ParseContent() {
if ctx.Type == "text" && ctx.Text != "" {
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](ctx.Text),
Type: "text",
Text: common.GetPointer[string](ctx.Text),
CacheControl: ctx.CacheControl,
})
}
// Support for images and other content types can be added here in the future.
Expand Down Expand Up @@ -376,8 +377,9 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
case "text":
if mediaMessage.Text != "" {
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](mediaMessage.Text),
Type: "text",
Text: common.GetPointer[string](mediaMessage.Text),
CacheControl: mediaMessage.CacheControl,
})
}
default:
Expand All @@ -390,6 +392,7 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
return nil, fmt.Errorf("get file data failed: %s", err.Error())
}
claudeMediaMessage := dto.ClaudeMediaMessage{
CacheControl: mediaMessage.CacheControl,
Source: &dto.ClaudeMessageSource{
Type: "base64",
},
Expand Down
Loading