-
-
Notifications
You must be signed in to change notification settings - Fork 5k
feat(util): add Claude tool id and tool_use mapping helpers #1927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| package util | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "regexp" | ||
| "sync/atomic" | ||
| "time" | ||
| ) | ||
|
|
||
| var ( | ||
| claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`) | ||
| claudeToolUseIDCounter uint64 | ||
| ) | ||
|
|
||
| // SanitizeClaudeToolID ensures the given id conforms to Claude's | ||
| // tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are | ||
| // replaced with '_'; an empty result gets a generated fallback. | ||
| func SanitizeClaudeToolID(id string) string { | ||
| s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_") | ||
| if s == "" { | ||
| s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1)) | ||
| } | ||
| return s | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| package util | ||
|
|
||
| import ( | ||
| "regexp" | ||
| "testing" | ||
| ) | ||
|
|
||
| func TestSanitizeClaudeToolID_ReplacesInvalidCharacters(t *testing.T) { | ||
| got := SanitizeClaudeToolID("fs.readFile:temp@1") | ||
| if got != "fs_readFile_temp_1" { | ||
| t.Fatalf("SanitizeClaudeToolID returned %q", got) | ||
| } | ||
| } | ||
|
|
||
| func TestSanitizeClaudeToolID_GeneratesFallbackForEmptyResult(t *testing.T) { | ||
| got := SanitizeClaudeToolID("!!!") | ||
| if got == "" { | ||
| t.Fatal("expected non-empty fallback id") | ||
| } | ||
| if !regexp.MustCompile(`^[a-zA-Z0-9_-]+$`).MatchString(got) { | ||
| t.Fatalf("fallback id %q does not match Claude regex", got) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -262,6 +262,50 @@ func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string { | |
| return out | ||
| } | ||
|
|
||
| // ToolUseNameMapFromClaudeRequest returns a tool_use.id -> tool_use.name map extracted from Claude messages. | ||
| // It is used by request translators to recover the original tool name when tool_result only carries tool_use_id. | ||
| func ToolUseNameMapFromClaudeRequest(rawJSON []byte) map[string]string { | ||
| if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) { | ||
| return nil | ||
| } | ||
|
|
||
| messages := gjson.GetBytes(rawJSON, "messages") | ||
| if !messages.Exists() || !messages.IsArray() { | ||
| return nil | ||
| } | ||
|
|
||
| out := map[string]string{} | ||
| messages.ForEach(func(_, message gjson.Result) bool { | ||
| contents := message.Get("content") | ||
| if !contents.IsArray() { | ||
| return true | ||
| } | ||
|
|
||
| contents.ForEach(func(_, content gjson.Result) bool { | ||
| if content.Get("type").String() != "tool_use" { | ||
| return true | ||
| } | ||
|
|
||
| toolUseID := strings.TrimSpace(content.Get("id").String()) | ||
| toolName := strings.TrimSpace(content.Get("name").String()) | ||
| if toolUseID == "" || toolName == "" { | ||
| return true | ||
| } | ||
|
|
||
| if _, exists := out[toolUseID]; !exists { | ||
| out[toolUseID] = toolName | ||
| } | ||
| return true | ||
| }) | ||
| return true | ||
| }) | ||
|
Comment on lines
+272
to
+301
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation with nested out := map[string]string{}
gjson.GetBytes(rawJSON, `messages.#.content.#[type=="tool_use"]`).ForEach(func(_, toolUse gjson.Result) bool {
toolUseID := strings.TrimSpace(toolUse.Get("id").String())
toolName := strings.TrimSpace(toolUse.Get("name").String())
if toolUseID == "" || toolName == "" {
return true
}
if _, exists := out[toolUseID]; !exists {
out[toolUseID] = toolName
}
return true
}) |
||
|
|
||
| if len(out) == 0 { | ||
| return nil | ||
| } | ||
| return out | ||
|
Comment on lines
+268
to
+306
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be simplified by using a if !gjson.ValidBytes(rawJSON) {
return nil
}
out := map[string]string{}
gjson.GetBytes(rawJSON, `messages.#.content.#[type=="tool_use"]`).ForEach(func(_, content gjson.Result) bool {
toolUseID := strings.TrimSpace(content.Get("id").String())
toolName := strings.TrimSpace(content.Get("name").String())
if toolUseID == "" || toolName == "" {
return true
}
if _, exists := out[toolUseID]; !exists {
out[toolUseID] = toolName
}
return true
})
if len(out) == 0 {
return nil
}
return out |
||
| } | ||
|
|
||
| func MapToolName(toolNameMap map[string]string, name string) string { | ||
| if name == "" || toolNameMap == nil { | ||
| return name | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| package util | ||
|
|
||
| import "testing" | ||
|
|
||
| func TestToolUseNameMapFromClaudeRequest(t *testing.T) { | ||
| raw := []byte(`{ | ||
| "messages": [ | ||
| { | ||
| "role": "assistant", | ||
| "content": [ | ||
| {"type": "tool_use", "id": "toolu_1", "name": "Read_File"}, | ||
| {"type": "text", "text": "ignored"}, | ||
| {"type": "tool_use", "id": "toolu_2", "name": "Bash"}, | ||
| {"type": "tool_use", "id": "toolu_1", "name": "ignored-duplicate"} | ||
| ] | ||
| } | ||
| ] | ||
| }`) | ||
|
|
||
| got := ToolUseNameMapFromClaudeRequest(raw) | ||
| if len(got) != 2 { | ||
| t.Fatalf("expected 2 tool_use mappings, got %d", len(got)) | ||
| } | ||
| if got["toolu_1"] != "Read_File" { | ||
| t.Fatalf("toolu_1 = %q, want %q", got["toolu_1"], "Read_File") | ||
| } | ||
| if got["toolu_2"] != "Bash" { | ||
| t.Fatalf("toolu_2 = %q, want %q", got["toolu_2"], "Bash") | ||
| } | ||
| } | ||
|
|
||
| func TestToolUseNameMapFromClaudeRequest_InvalidOrMissingMessages(t *testing.T) { | ||
| tests := [][]byte{ | ||
| nil, | ||
| []byte(`not-json`), | ||
| []byte(`{"messages": {}}`), | ||
| []byte(`{"messages": []}`), | ||
| } | ||
|
|
||
| for _, raw := range tests { | ||
| if got := ToolUseNameMapFromClaudeRequest(raw); got != nil { | ||
| t.Fatalf("expected nil map for %q, got %#v", string(raw), got) | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests are good, but they could be improved by using a single table-driven test. This would make it easier to add more test cases in the future, would cover more edge cases (like empty input or already-valid IDs), and would avoid recompiling the regex in a loop.