Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions internal/util/claude_tool_id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package util

import (
"fmt"
"regexp"
"sync/atomic"
"time"
)

var (
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
claudeToolUseIDCounter uint64
)

// SanitizeClaudeToolID ensures the given id conforms to Claude's
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
// replaced with '_'; an empty result gets a generated fallback.
func SanitizeClaudeToolID(id string) string {
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
if s == "" {
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
}
return s
}
23 changes: 23 additions & 0 deletions internal/util/claude_tool_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package util

import (
"regexp"
"testing"
)

func TestSanitizeClaudeToolID_ReplacesInvalidCharacters(t *testing.T) {
got := SanitizeClaudeToolID("fs.readFile:temp@1")
if got != "fs_readFile_temp_1" {
t.Fatalf("SanitizeClaudeToolID returned %q", got)
}
}

func TestSanitizeClaudeToolID_GeneratesFallbackForEmptyResult(t *testing.T) {
got := SanitizeClaudeToolID("!!!")
if got == "" {
t.Fatal("expected non-empty fallback id")
}
if !regexp.MustCompile(`^[a-zA-Z0-9_-]+$`).MatchString(got) {
t.Fatalf("fallback id %q does not match Claude regex", got)
}
}
Comment on lines +8 to +23
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The tests are good, but they could be improved by using a single table-driven test. This would make it easier to add more test cases in the future, would cover more edge cases (like empty input or already-valid IDs), and would avoid recompiling the regex in a loop.

func TestSanitizeClaudeToolID(t *testing.T) {
	claudeToolIDRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)

	tests := []struct {
		name             string
		input            string
		expected         string
		fallbackExpected bool
	}{
		{"replaces invalid characters", "fs.readFile:temp@1", "fs_readFile_temp_1", false},
		{"valid id is unchanged", "valid-id_123", "valid-id_123", false},
		{"generates fallback for all-invalid input", "!!!", "", true},
		{"generates fallback for empty input", "", "", true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := SanitizeClaudeToolID(tt.input)

			if tt.fallbackExpected {
				if got == "" {
					t.Error("expected fallback ID, but got empty string")
				}
			} else {
				if got != tt.expected {
					t.Errorf("got %q, want %q", got, tt.expected)
				}
			}

			if !claudeToolIDRegex.MatchString(got) {
				t.Errorf("output %q does not match Claude tool ID regex", got)
			}
		})
	}
}

44 changes: 44 additions & 0 deletions internal/util/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,50 @@ func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
return out
}

// ToolUseNameMapFromClaudeRequest returns a tool_use.id -> tool_use.name map extracted from Claude messages.
// It is used by request translators to recover the original tool name when tool_result only carries tool_use_id.
func ToolUseNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
return nil
}

messages := gjson.GetBytes(rawJSON, "messages")
if !messages.Exists() || !messages.IsArray() {
return nil
}

out := map[string]string{}
messages.ForEach(func(_, message gjson.Result) bool {
contents := message.Get("content")
if !contents.IsArray() {
return true
}

contents.ForEach(func(_, content gjson.Result) bool {
if content.Get("type").String() != "tool_use" {
return true
}

toolUseID := strings.TrimSpace(content.Get("id").String())
toolName := strings.TrimSpace(content.Get("name").String())
if toolUseID == "" || toolName == "" {
return true
}

if _, exists := out[toolUseID]; !exists {
out[toolUseID] = toolName
}
return true
})
return true
})
Comment on lines +272 to +301
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation with nested ForEach loops can be simplified by using a gjson path query to directly select all tool_use content parts. This makes the code more concise and idiomatic when using the gjson library.

	out := map[string]string{}
	gjson.GetBytes(rawJSON, `messages.#.content.#[type=="tool_use"]`).ForEach(func(_, toolUse gjson.Result) bool {
		toolUseID := strings.TrimSpace(toolUse.Get("id").String())
		toolName := strings.TrimSpace(toolUse.Get("name").String())
		if toolUseID == "" || toolName == "" {
			return true
		}

		if _, exists := out[toolUseID]; !exists {
			out[toolUseID] = toolName
		}
		return true
	})


if len(out) == 0 {
return nil
}
return out
Comment on lines +268 to +306
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function can be simplified by using a gjson path query to directly select the tool_use objects. This avoids nested loops and several checks, making the code more concise and potentially more performant.

    if !gjson.ValidBytes(rawJSON) {
        return nil
    }

    out := map[string]string{}
    gjson.GetBytes(rawJSON, `messages.#.content.#[type=="tool_use"]`).ForEach(func(_, content gjson.Result) bool {
        toolUseID := strings.TrimSpace(content.Get("id").String())
        toolName := strings.TrimSpace(content.Get("name").String())
        if toolUseID == "" || toolName == "" {
            return true
        }

        if _, exists := out[toolUseID]; !exists {
            out[toolUseID] = toolName
        }
        return true
    })

    if len(out) == 0 {
        return nil
    }
    return out

}

func MapToolName(toolNameMap map[string]string, name string) string {
if name == "" || toolNameMap == nil {
return name
Expand Down
45 changes: 45 additions & 0 deletions internal/util/translator_tool_use_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package util

import "testing"

func TestToolUseNameMapFromClaudeRequest(t *testing.T) {
raw := []byte(`{
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "toolu_1", "name": "Read_File"},
{"type": "text", "text": "ignored"},
{"type": "tool_use", "id": "toolu_2", "name": "Bash"},
{"type": "tool_use", "id": "toolu_1", "name": "ignored-duplicate"}
]
}
]
}`)

got := ToolUseNameMapFromClaudeRequest(raw)
if len(got) != 2 {
t.Fatalf("expected 2 tool_use mappings, got %d", len(got))
}
if got["toolu_1"] != "Read_File" {
t.Fatalf("toolu_1 = %q, want %q", got["toolu_1"], "Read_File")
}
if got["toolu_2"] != "Bash" {
t.Fatalf("toolu_2 = %q, want %q", got["toolu_2"], "Bash")
}
}

func TestToolUseNameMapFromClaudeRequest_InvalidOrMissingMessages(t *testing.T) {
tests := [][]byte{
nil,
[]byte(`not-json`),
[]byte(`{"messages": {}}`),
[]byte(`{"messages": []}`),
}

for _, raw := range tests {
if got := ToolUseNameMapFromClaudeRequest(raw); got != nil {
t.Fatalf("expected nil map for %q, got %#v", string(raw), got)
}
}
}