Skip to content

Commit fcca386

Browse files
committed
Fix CommandCode tool event mapping for OpenAI clients
1 parent df91dd8 commit fcca386

2 files changed

Lines changed: 226 additions & 99 deletions

File tree

internal/api/commandcode.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@ type CCRequestBody struct {
4949
}
5050

5151
type CCStreamEvent struct {
52-
Type string `json:"type"`
53-
Text string `json:"text"`
54-
ToolCallID string `json:"toolCallId"`
55-
ToolName string `json:"toolName"`
56-
FinishReason string `json:"finishReason"`
52+
Type string `json:"type"`
53+
Text string `json:"text"`
54+
ID string `json:"id"`
55+
Delta string `json:"delta"`
56+
Input map[string]any `json:"input"`
57+
ToolCallID string `json:"toolCallId"`
58+
ToolName string `json:"toolName"`
59+
FinishReason string `json:"finishReason"`
5760
Error *struct {
5861
Message string `json:"message"`
5962
StatusCode *int `json:"statusCode"`

internal/proxy/proxy.go

Lines changed: 218 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ func (p *Proxy) StreamResponse(w http.ResponseWriter, r *http.Request, ccResp *h
219219
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
220220
sentRole := false
221221
toolCallIndex := 0
222+
toolCallIndexes := map[string]int{}
222223

223224
for scanner.Scan() {
224225
select {
@@ -238,77 +239,154 @@ func (p *Proxy) StreamResponse(w http.ResponseWriter, r *http.Request, ccResp *h
238239
continue
239240
}
240241

241-
switch event.Type {
242-
case "text-delta":
243-
delta := api.OpenAIDelta{Content: event.Text}
244-
if !sentRole {
245-
delta.Role = "assistant"
246-
sentRole = true
242+
switch event.Type {
243+
case "text-delta":
244+
delta := api.OpenAIDelta{Content: event.Text}
245+
if !sentRole {
246+
delta.Role = "assistant"
247+
sentRole = true
248+
}
249+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
250+
ID: requestID,
251+
Object: "chat.completion.chunk",
252+
Created: created,
253+
Model: model,
254+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
255+
})
256+
257+
case "tool-use":
258+
toolCalls := []api.OpenAIDeltaToolCall{{
259+
Index: toolCallIndex,
260+
ID: event.ToolCallID,
261+
Type: "function",
262+
Function: &api.OpenAIDeltaFunction{Name: event.ToolName},
263+
}}
264+
delta := api.OpenAIDelta{ToolCalls: toolCalls}
265+
if !sentRole {
266+
delta.Role = "assistant"
267+
sentRole = true
268+
}
269+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
270+
ID: requestID,
271+
Object: "chat.completion.chunk",
272+
Created: created,
273+
Model: model,
274+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
275+
})
276+
toolCallIndex++
277+
278+
case "tool-delta":
279+
toolCalls := []api.OpenAIDeltaToolCall{{
280+
Index: toolCallIndex - 1,
281+
Function: &api.OpenAIDeltaFunction{Arguments: event.Text},
282+
}}
283+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
284+
ID: requestID,
285+
Object: "chat.completion.chunk",
286+
Created: created,
287+
Model: model,
288+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &api.OpenAIDelta{ToolCalls: toolCalls}}},
289+
})
290+
291+
case "tool-input-start":
292+
if _, ok := toolCallIndexes[event.ID]; !ok {
293+
toolCallIndexes[event.ID] = toolCallIndex
294+
toolCallIndex++
295+
}
296+
delta := api.OpenAIDelta{ToolCalls: []api.OpenAIDeltaToolCall{{
297+
Index: toolCallIndexes[event.ID],
298+
ID: event.ID,
299+
Type: "function",
300+
Function: &api.OpenAIDeltaFunction{
301+
Name: event.ToolName,
302+
},
303+
}}}
304+
if !sentRole {
305+
delta.Role = "assistant"
306+
sentRole = true
307+
}
308+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
309+
ID: requestID,
310+
Object: "chat.completion.chunk",
311+
Created: created,
312+
Model: model,
313+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
314+
})
315+
316+
case "tool-input-delta":
317+
idx, ok := toolCallIndexes[event.ID]
318+
if !ok {
319+
idx = toolCallIndex
320+
toolCallIndexes[event.ID] = idx
321+
toolCallIndex++
322+
}
323+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
324+
ID: requestID,
325+
Object: "chat.completion.chunk",
326+
Created: created,
327+
Model: model,
328+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &api.OpenAIDelta{ToolCalls: []api.OpenAIDeltaToolCall{{
329+
Index: idx,
330+
Function: &api.OpenAIDeltaFunction{Arguments: event.Delta},
331+
}}}}},
332+
})
333+
334+
case "tool-call":
335+
idx, ok := toolCallIndexes[event.ToolCallID]
336+
if !ok {
337+
idx = toolCallIndex
338+
toolCallIndexes[event.ToolCallID] = idx
339+
toolCallIndex++
340+
}
341+
args := ""
342+
if event.Input != nil {
343+
if data, err := json.Marshal(event.Input); err == nil {
344+
args = string(data)
345+
}
346+
}
347+
delta := api.OpenAIDelta{ToolCalls: []api.OpenAIDeltaToolCall{{
348+
Index: idx,
349+
ID: event.ToolCallID,
350+
Type: "function",
351+
Function: &api.OpenAIDeltaFunction{
352+
Name: event.ToolName,
353+
Arguments: args,
354+
},
355+
}}}
356+
if !sentRole {
357+
delta.Role = "assistant"
358+
sentRole = true
359+
}
360+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
361+
ID: requestID,
362+
Object: "chat.completion.chunk",
363+
Created: created,
364+
Model: model,
365+
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
366+
})
367+
368+
case "finish":
369+
reason := "stop"
370+
if event.FinishReason == "tool_calls" || event.FinishReason == "tool-calls" {
371+
reason = "tool_calls"
372+
}
373+
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
374+
ID: requestID,
375+
Object: "chat.completion.chunk",
376+
Created: created,
377+
Model: model,
378+
Choices: []api.OpenAIChoice{{
379+
Index: 0,
380+
Delta: &api.OpenAIDelta{},
381+
FinishReason: &reason,
382+
}},
383+
})
384+
fmt.Fprintf(w, "data: [DONE]\n\n")
385+
flusher.Flush()
386+
387+
case "error":
388+
log.Printf("[ERROR] Stream error: %v", event.Error)
247389
}
248-
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
249-
ID: requestID,
250-
Object: "chat.completion.chunk",
251-
Created: created,
252-
Model: model,
253-
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
254-
})
255-
256-
case "tool-use":
257-
toolCalls := []api.OpenAIDeltaToolCall{{
258-
Index: toolCallIndex,
259-
ID: event.ToolCallID,
260-
Type: "function",
261-
Function: &api.OpenAIDeltaFunction{Name: event.ToolName},
262-
}}
263-
delta := api.OpenAIDelta{ToolCalls: toolCalls}
264-
if !sentRole {
265-
delta.Role = "assistant"
266-
sentRole = true
267-
}
268-
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
269-
ID: requestID,
270-
Object: "chat.completion.chunk",
271-
Created: created,
272-
Model: model,
273-
Choices: []api.OpenAIChoice{{Index: 0, Delta: &delta}},
274-
})
275-
toolCallIndex++
276-
277-
case "tool-delta":
278-
toolCalls := []api.OpenAIDeltaToolCall{{
279-
Index: toolCallIndex - 1,
280-
Function: &api.OpenAIDeltaFunction{Arguments: event.Text},
281-
}}
282-
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
283-
ID: requestID,
284-
Object: "chat.completion.chunk",
285-
Created: created,
286-
Model: model,
287-
Choices: []api.OpenAIChoice{{Index: 0, Delta: &api.OpenAIDelta{ToolCalls: toolCalls}}},
288-
})
289-
290-
case "finish":
291-
reason := "stop"
292-
if event.FinishReason == "tool_calls" {
293-
reason = "tool_calls"
294-
}
295-
p.WriteSSE(w, flusher, api.OpenAIChatResponse{
296-
ID: requestID,
297-
Object: "chat.completion.chunk",
298-
Created: created,
299-
Model: model,
300-
Choices: []api.OpenAIChoice{{
301-
Index: 0,
302-
Delta: &api.OpenAIDelta{},
303-
FinishReason: &reason,
304-
}},
305-
})
306-
fmt.Fprintf(w, "data: [DONE]\n\n")
307-
flusher.Flush()
308-
309-
case "error":
310-
log.Printf("[ERROR] Stream error: %v", event.Error)
311-
}
312390
}
313391

314392
if err := scanner.Err(); err != nil && err != io.EOF {
@@ -332,6 +410,8 @@ func (p *Proxy) NonStreamResponse(w http.ResponseWriter, ccResp *http.Response,
332410
var inputTokens, outputTokens int
333411
var hasToolCalls bool
334412
var toolCalls []api.ToolCall
413+
toolCallByID := map[string]int{}
414+
toolInputBuffers := map[string]*strings.Builder{}
335415

336416
for scanner.Scan() {
337417
line := strings.TrimSpace(scanner.Text())
@@ -345,31 +425,75 @@ func (p *Proxy) NonStreamResponse(w http.ResponseWriter, ccResp *http.Response,
345425
continue
346426
}
347427

348-
switch event.Type {
349-
case "text-delta":
350-
content.WriteString(event.Text)
351-
case "tool-use":
352-
hasToolCalls = true
353-
toolCalls = append(toolCalls, api.ToolCall{
354-
ID: event.ToolCallID,
355-
Type: "function",
356-
Function: api.FunctionCall{
357-
Name: event.ToolName,
358-
Arguments: "",
359-
},
360-
})
361-
case "tool-delta":
362-
if len(toolCalls) > 0 {
363-
toolCalls[len(toolCalls)-1].Function.Arguments += event.Text
428+
switch event.Type {
429+
case "text-delta":
430+
content.WriteString(event.Text)
431+
case "tool-use":
432+
hasToolCalls = true
433+
toolCallByID[event.ToolCallID] = len(toolCalls)
434+
toolCalls = append(toolCalls, api.ToolCall{
435+
ID: event.ToolCallID,
436+
Type: "function",
437+
Function: api.FunctionCall{
438+
Name: event.ToolName,
439+
Arguments: "",
440+
},
441+
})
442+
case "tool-delta":
443+
if len(toolCalls) > 0 {
444+
toolCalls[len(toolCalls)-1].Function.Arguments += event.Text
445+
}
446+
case "tool-input-start":
447+
hasToolCalls = true
448+
toolCallByID[event.ID] = len(toolCalls)
449+
toolInputBuffers[event.ID] = &strings.Builder{}
450+
toolCalls = append(toolCalls, api.ToolCall{
451+
ID: event.ID,
452+
Type: "function",
453+
Function: api.FunctionCall{
454+
Name: event.ToolName,
455+
Arguments: "",
456+
},
457+
})
458+
case "tool-input-delta":
459+
if b := toolInputBuffers[event.ID]; b != nil {
460+
b.WriteString(event.Delta)
461+
}
462+
if idx, ok := toolCallByID[event.ID]; ok {
463+
toolCalls[idx].Function.Arguments += event.Delta
464+
}
465+
case "tool-call":
466+
hasToolCalls = true
467+
args := ""
468+
if event.Input != nil {
469+
if data, err := json.Marshal(event.Input); err == nil {
470+
args = string(data)
471+
}
472+
}
473+
if idx, ok := toolCallByID[event.ToolCallID]; ok {
474+
toolCalls[idx].Function.Name = event.ToolName
475+
if args != "" {
476+
toolCalls[idx].Function.Arguments = args
477+
}
478+
} else {
479+
toolCallByID[event.ToolCallID] = len(toolCalls)
480+
toolCalls = append(toolCalls, api.ToolCall{
481+
ID: event.ToolCallID,
482+
Type: "function",
483+
Function: api.FunctionCall{
484+
Name: event.ToolName,
485+
Arguments: args,
486+
},
487+
})
488+
}
489+
case "finish":
490+
if event.TotalUsage != nil {
491+
inputTokens = event.TotalUsage.InputTokens
492+
outputTokens = event.TotalUsage.OutputTokens
493+
}
494+
case "error":
495+
log.Printf("[ERROR] Stream error: %v", event.Error)
364496
}
365-
case "finish":
366-
if event.TotalUsage != nil {
367-
inputTokens = event.TotalUsage.InputTokens
368-
outputTokens = event.TotalUsage.OutputTokens
369-
}
370-
case "error":
371-
log.Printf("[ERROR] Stream error: %v", event.Error)
372-
}
373497
}
374498

375499
msg := &api.OpenAIMessage{

0 commit comments

Comments
 (0)