Skip to content
Merged
23 changes: 23 additions & 0 deletions internal/agent/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,29 @@ func NewOpenAIProfile(model string) ProviderProfile {
}
}

// NewCodexAppServerProfile constructs the provider profile for the
// codex-app-server backend. The profile enables parallel tool calls,
// declares a 1,047,576-token context window, reuses the OpenAI base
// prompt, and reads AGENTS.md plus .codex/instructions.md for repo docs.
func NewCodexAppServerProfile(model string) ProviderProfile {
	// NOTE: the slice order is the order tools are advertised to the
	// provider; keep it stable.
	tools := []llm.ToolDefinition{
		defReadFile(),
		defApplyPatch(),
		defWriteFile(),
		defShell(),
		defGrep(),
		defGlob(),
		defSpawnAgent(),
		defSendInput(),
		defWait(),
		defCloseAgent(),
	}
	return &baseProfile{
		id:            "codex-app-server",
		model:         strings.TrimSpace(model),
		parallel:      true,
		contextWindow: 1_047_576,
		basePrompt:    openAIProfileBasePrompt,
		docFiles:      []string{"AGENTS.md", ".codex/instructions.md"},
		toolDefs:      tools,
	}
}

func NewAnthropicProfile(model string) ProviderProfile {
return &baseProfile{
id: "anthropic",
Expand Down
8 changes: 5 additions & 3 deletions internal/agent/profile_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
var (
profileFactoriesMu sync.RWMutex
profileFactories = map[string]func(string) ProviderProfile{
"openai": NewOpenAIProfile,
"anthropic": NewAnthropicProfile,
"google": NewGeminiProfile,
"openai": NewOpenAIProfile,
"anthropic": NewAnthropicProfile,
"google": NewGeminiProfile,
"codex-app-server": NewCodexAppServerProfile,
"codex": NewCodexAppServerProfile,
}
)

Expand Down
44 changes: 44 additions & 0 deletions internal/agent/profile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,19 @@ func TestProviderProfiles_ToolsetsAndDocSelection(t *testing.T) {
assertHasTool(t, gemini, "read_many_files")
assertHasTool(t, gemini, "list_dir")
assertMissingTool(t, gemini, "apply_patch")

codex := NewCodexAppServerProfile("gpt-5-codex")
if codex.ID() != "codex-app-server" {
t.Fatalf("codex id: %q", codex.ID())
}
if !codex.SupportsParallelToolCalls() {
t.Fatalf("codex profile should support parallel tool calls")
}
if codex.ContextWindowSize() != 1_047_576 {
t.Fatalf("codex context window: got %d want %d", codex.ContextWindowSize(), 1_047_576)
}
assertHasTool(t, codex, "apply_patch")
assertMissingTool(t, codex, "edit_file")
}

func TestProviderProfiles_ToolLists_MatchSpec(t *testing.T) {
Expand Down Expand Up @@ -92,6 +105,21 @@ func TestProviderProfiles_ToolLists_MatchSpec(t *testing.T) {
"close_agent",
})
})
t.Run("codex-app-server", func(t *testing.T) {
p := NewCodexAppServerProfile("gpt-5-codex")
assertToolListExact(t, p, []string{
"read_file",
"apply_patch",
"write_file",
"shell",
"grep",
"glob",
"spawn_agent",
"send_input",
"wait",
"close_agent",
})
})
}

func TestProviderProfiles_BuildSystemPrompt_IncludesProviderSpecificBaseInstructions(t *testing.T) {
Expand Down Expand Up @@ -181,4 +209,20 @@ func TestNewProfileForFamily_DefaultFamiliesAndRegistration(t *testing.T) {
if _, err := NewProfileForFamily("missing-family", "m3"); err == nil {
t.Fatalf("expected unsupported family error")
}

codex, err := NewProfileForFamily("codex-app-server", "gpt-5-codex")
if err != nil {
t.Fatalf("NewProfileForFamily(codex-app-server): %v", err)
}
if codex.ID() != "codex-app-server" {
t.Fatalf("codex profile id=%q want codex-app-server", codex.ID())
}

codexAlias, err := NewProfileForFamily("codex", "gpt-5-codex")
if err != nil {
t.Fatalf("NewProfileForFamily(codex): %v", err)
}
if codexAlias.ID() != "codex-app-server" {
t.Fatalf("codex alias profile id=%q want codex-app-server", codexAlias.ID())
}
}
209 changes: 196 additions & 13 deletions internal/agent/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"

"github.com/oklog/ulid/v2"

Expand Down Expand Up @@ -293,16 +294,11 @@ func (s *Session) execTool(ctx context.Context, call llm.ToolCallData) ToolExecR
// Emit output deltas (best-effort). Even for non-streaming tools, this gives consumers a uniform
// incremental event pattern that mirrors provider LLM streaming.
full := res.FullOutput
const chunk = 4000
for i := 0; i < len(full); i += chunk {
j := i + chunk
if j > len(full) {
j = len(full)
}
for _, delta := range utf8Chunk(full, 4000) {
s.emit(EventToolCallOutputDelta, map[string]any{
"tool_name": res.ToolName,
"call_id": res.CallID,
"delta": full[i:j],
"delta": delta,
})
}

Expand Down Expand Up @@ -494,8 +490,8 @@ func (s *Session) processOneInput(ctx context.Context, input string) (string, er
if s.cfg.LLMRetryPolicy != nil {
policy = *s.cfg.LLMRetryPolicy
}
resp, err := llm.Retry(ctx, policy, s.cfg.LLMSleep, nil, func() (llm.Response, error) {
return s.client.Complete(ctx, req)
stream, err := llm.Retry(ctx, policy, s.cfg.LLMSleep, nil, func() (llm.Stream, error) {
return s.client.Stream(ctx, req)
})
if err != nil {
s.emit(EventError, map[string]any{"error": err.Error()})
Expand All @@ -519,15 +515,175 @@ func (s *Session) processOneInput(ctx context.Context, input string) (string, er
}
}

acc := llm.NewStreamAccumulator()
var resp *llm.Response
var streamErr error
providerToolCallCount := 0
seenProviderToolCalls := map[string]struct{}{}
seenProviderOutputDeltas := map[string]struct{}{}
seenProviderToolEnds := map[string]struct{}{}
providerToolNameByCallID := map[string]string{}
providerOutputByCallID := map[string]string{}
assistantTextStarted := false
assistantTextDelta := false
emitAssistantTextStart := func() {
if assistantTextStarted {
return
}
assistantTextStarted = true
s.emit(EventAssistantTextStart, map[string]any{})
}
emitToolOutputDeltas := func(toolName, callID, fullOutput string) {
for _, delta := range utf8Chunk(fullOutput, 4000) {
s.emit(EventToolCallOutputDelta, map[string]any{
"tool_name": toolName,
"call_id": callID,
"delta": delta,
"source": "provider",
})
}
}
for ev := range stream.Events() {
acc.Process(ev)
switch ev.Type {
case llm.StreamEventTextStart:
emitAssistantTextStart()
case llm.StreamEventTextDelta:
emitAssistantTextStart()
if ev.Delta != "" {
assistantTextDelta = true
s.emit(EventAssistantTextDelta, map[string]any{"delta": ev.Delta})
}
case llm.StreamEventFinish:
if ev.Response != nil {
cp := *ev.Response
resp = &cp
}
case llm.StreamEventError:
if ev.Err != nil {
streamErr = ev.Err
} else {
streamErr = llm.NewStreamError(req.Provider, "stream error")
}
case llm.StreamEventProviderEvent:
if lifecycle, ok := llm.ParseCodexAppServerToolLifecycle(ev); ok {
callID := strings.TrimSpace(lifecycle.CallID)
if callID == "" {
if !lifecycle.Completed {
providerToolCallCount++
}
} else {
if _, exists := seenProviderToolCalls[callID]; !exists {
seenProviderToolCalls[callID] = struct{}{}
providerToolCallCount++
}
if tn := strings.TrimSpace(lifecycle.ToolName); tn != "" {
providerToolNameByCallID[callID] = tn
}
}
if lifecycle.Completed {
if callID != "" {
if _, ended := seenProviderToolEnds[callID]; ended {
continue
}
}
if callID == "" {
emitToolOutputDeltas(lifecycle.ToolName, lifecycle.CallID, lifecycle.FullOutput)
} else if _, seen := seenProviderOutputDeltas[callID]; !seen {
emitToolOutputDeltas(lifecycle.ToolName, lifecycle.CallID, lifecycle.FullOutput)
seenProviderOutputDeltas[callID] = struct{}{}
providerOutputByCallID[callID] = lifecycle.FullOutput
} else if lifecycle.FullOutput != "" && providerOutputByCallID[callID] != lifecycle.FullOutput {
// Reconcile mismatch: provider completion output is authoritative.
emitToolOutputDeltas(lifecycle.ToolName, lifecycle.CallID, lifecycle.FullOutput)
providerOutputByCallID[callID] = lifecycle.FullOutput
}
s.emit(EventToolCallEnd, map[string]any{
"tool_name": lifecycle.ToolName,
"call_id": lifecycle.CallID,
"is_error": lifecycle.IsError,
"full_output": lifecycle.FullOutput,
"source": "provider",
})
if callID != "" {
seenProviderToolEnds[callID] = struct{}{}
}
} else {
data := map[string]any{
"tool_name": lifecycle.ToolName,
"call_id": lifecycle.CallID,
"arguments_json": lifecycle.ArgumentsJSON,
"source": "provider",
}
s.emit(EventToolCallStart, data)
}
} else if outputDelta, ok := llm.ParseCodexAppServerToolOutputDelta(ev); ok {
callID := strings.TrimSpace(outputDelta.CallID)
toolName := strings.TrimSpace(outputDelta.ToolName)
if callID != "" {
if _, exists := seenProviderToolCalls[callID]; !exists {
seenProviderToolCalls[callID] = struct{}{}
providerToolCallCount++
}
if mappedToolName := strings.TrimSpace(providerToolNameByCallID[callID]); mappedToolName != "" {
toolName = mappedToolName
} else if toolName != "" {
providerToolNameByCallID[callID] = toolName
}
seenProviderOutputDeltas[callID] = struct{}{}
providerOutputByCallID[callID] += outputDelta.Delta
}
s.emit(EventToolCallOutputDelta, map[string]any{
"tool_name": toolName,
"call_id": callID,
"delta": outputDelta.Delta,
"source": "provider",
})
}
}
}
_ = stream.Close()

if streamErr != nil {
s.emit(EventError, map[string]any{"error": streamErr.Error()})
// Spec: context overflow should emit a warning (no automatic compaction).
var cle *llm.ContextLengthError
if errors.As(streamErr, &cle) {
s.emit(EventWarning, map[string]any{"message": "Context length exceeded"})
}
// Spec: non-retryable/unrecoverable errors transition the session to CLOSED.
var le llm.Error
if errors.As(streamErr, &le) && !le.Retryable() {
s.Close()
}
return "", streamErr
}

if resp == nil {
resp = acc.Response()
}
if resp == nil {
err := llm.NewStreamError(req.Provider, "stream ended without finish event")
s.emit(EventError, map[string]any{"error": err.Error()})
return "", err
}

calls := resp.ToolCalls()
turnToolCallCount := len(calls)
if providerToolCallCount > turnToolCallCount {
turnToolCallCount = providerToolCallCount
}
txt := resp.Text()
s.emit(EventAssistantTextStart, map[string]any{})
emitAssistantTextStart()
s.appendTurn(TurnAssistant, resp.Message)
if strings.TrimSpace(txt) != "" {
if !assistantTextDelta && strings.TrimSpace(txt) != "" {
s.emit(EventAssistantTextDelta, map[string]any{"delta": txt})
}
s.emit(EventAssistantTextEnd, map[string]any{"text": txt})
s.emit(EventAssistantTextEnd, map[string]any{
"text": txt,
"tool_call_count": turnToolCallCount,
})

calls := resp.ToolCalls()
if len(calls) == 0 {
return txt, nil
}
Expand Down Expand Up @@ -629,6 +785,33 @@ func (s *Session) processOneInput(ctx context.Context, input string) (string, er
return "", fmt.Errorf("max tool rounds reached")
}

// utf8Chunk splits full into consecutive substrings of at most maxBytes
// bytes each, never cutting a UTF-8 rune across two chunks. A single rune
// wider than maxBytes is emitted as its own oversized chunk rather than
// being broken apart. It returns nil when maxBytes is non-positive or
// full is empty.
func utf8Chunk(full string, maxBytes int) []string {
	if maxBytes <= 0 || full == "" {
		return nil
	}
	out := make([]string, 0, len(full)/maxBytes+1)
	for rest := full; rest != ""; {
		if len(rest) <= maxBytes {
			// Remainder fits in one chunk; emit and stop.
			out = append(out, rest)
			break
		}
		// Pull the cut point back to the nearest rune boundary.
		cut := maxBytes
		for cut > 0 && !utf8.RuneStart(rest[cut]) {
			cut--
		}
		if cut == 0 {
			// The leading rune alone exceeds maxBytes: emit it whole.
			_, size := utf8.DecodeRuneInString(rest)
			if size <= 0 {
				size = 1 // defensive; rest is non-empty so size >= 1
			}
			cut = size
		}
		out = append(out, rest[:cut])
		rest = rest[cut:]
	}
	return out
}

func (s *Session) drainSteering() []string {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down
10 changes: 7 additions & 3 deletions internal/agent/session_dod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,8 @@ func (a *errAdapter) Complete(ctx context.Context, req llm.Request) (llm.Respons
func (a *errAdapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, error) {
_ = ctx
_ = req
return nil, fmt.Errorf("stream not implemented in errAdapter")
a.calls++
return nil, a.err
}

type flaky429Adapter struct {
Expand All @@ -894,9 +895,12 @@ func (a *flaky429Adapter) Complete(ctx context.Context, req llm.Request) (llm.Re
return llm.Response{Message: llm.Assistant("ok")}, nil
}
func (a *flaky429Adapter) Stream(ctx context.Context, req llm.Request) (llm.Stream, error) {
_ = ctx
_ = req
return nil, fmt.Errorf("stream not implemented in flaky429Adapter")
a.calls++
if a.calls <= a.failCount {
return nil, llm.ErrorFromHTTPStatus(a.name, 429, "rate limited", nil, nil)
}
return streamFromResponse(ctx, llm.Response{Provider: a.name, Model: req.Model, Message: llm.Assistant("ok")}), nil
}

func TestSession_AuthenticationError_ClosesSession(t *testing.T) {
Expand Down
Loading