diff --git a/pkg/agent/loop_iteration.go b/pkg/agent/loop_iteration.go index c00541d..49bb201 100644 --- a/pkg/agent/loop_iteration.go +++ b/pkg/agent/loop_iteration.go @@ -2,6 +2,7 @@ package agent import ( "context" + "errors" "fmt" "sync" @@ -54,6 +55,17 @@ func (al *AgentLoop) runIteration(ctx context.Context, sessionKey string, stream finalContent, result, err := al.callLLMWithRetry(ctx, st, msgsForLLM) if err != nil { al.saveSession(ctx, sessionKey, *msgs) + // A cancel that lands while the provider stream is in flight surfaces + // as a provider error. Classify it as cancellation (not an LLM + // failure) so adopters' errors.Is(err, ErrContextCancelled) checks + // fire — mirroring the pre-call cancel path above. + if cause := ctx.Err(); cause != nil || errors.Is(err, context.Canceled) { + if cause == nil { + cause = err + } + al.emit(ctx, sessionKey, streamChan, errEvent(fmt.Errorf("%w: %w", ErrContextCancelled, cause))) + return 0, true + } al.emit(ctx, sessionKey, streamChan, errEvent(&LLMFailureError{Cause: err})) return 0, true } diff --git a/pkg/agent/loop_stream_test.go b/pkg/agent/loop_stream_test.go index 5211075..30cc527 100644 --- a/pkg/agent/loop_stream_test.go +++ b/pkg/agent/loop_stream_test.go @@ -195,6 +195,35 @@ func TestRunIterationStream_ForwardsErrorOnCancelledCtx(t *testing.T) { } } +func TestRunIteration_MidStreamCancelClassifiedAsContextCancelled(t *testing.T) { + // A cancel landing while the provider stream is in flight surfaces as a + // provider error wrapping context.Canceled. It must be classified as + // ErrContextCancelled (not LLMFailureError) so adopters' documented + // errors.Is(err, ErrContextCancelled) terminal check fires. The ctx itself + // is not cancelled here, exercising the errors.Is(err, context.Canceled) + // race branch specifically. + provider := &errorProvider{err: fmt.Errorf("provider stream error: %w", context.Canceled)} + loop, _ := setup(provider) + + var sawErr bool + for ev := range loop.RunText(context.Background(), "s1", "hello") { + p, ok := ev.Payload.(ErrorEvent) + if !ok { + continue + } + sawErr = true + if !errors.Is(p.Err, ErrContextCancelled) { + t.Fatalf("expected ErrContextCancelled, got %v", p.Err) + } + if errors.Is(p.Err, ErrLLMFailure) { + t.Fatalf("must not classify cancellation as LLM failure: %v", p.Err) + } + } + if !sawErr { + t.Fatal("expected an error event") + } +} + func TestRunIteration_ParallelToolCalls(t *testing.T) { provider := &scriptProvider{turns: []LLMResult{ {ToolCalls: []PendingToolCall{