Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 110 additions & 15 deletions internal/pireview/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -263,7 +264,6 @@ func (r *Runner) runModelPanel(ctx context.Context, runID string, pkt *ContextPa
}

output, err := r.invokePi(ctx, slot, pkt, evidence)
result.LatencyMs = time.Since(start).Milliseconds()

if err != nil {
result.Status = "failed"
Expand All @@ -274,6 +274,7 @@ func (r *Runner) runModelPanel(ctx context.Context, runID string, pkt *ContextPa
fbOutput, fbErr := r.invokeFallback(ctx, slot, pkt, evidence)
if fbErr == nil {
result.Status = "ok"
result.Error = ""
result.ArtifactPath = writeModelArtifact(r.cfg.ProjectRoot, runID, slot.Slot, fbOutput)
result.Provider = "openrouter"
result.Model = fallbackModel(slot)
Expand All @@ -293,6 +294,7 @@ func (r *Runner) runModelPanel(ctx context.Context, runID string, pkt *ContextPa
}
}

result.LatencyMs = time.Since(start).Milliseconds()
results = append(results, result)
}

Expand All @@ -301,32 +303,125 @@ func (r *Runner) runModelPanel(ctx context.Context, runID string, pkt *ContextPa

// invokePi calls the local pi binary with the review context, using the
// slot's own provider and model.
//
// All of the work is delegated to invokePiProvider with fallback set to
// false, so the primary path and the OpenRouter fallback share one
// implementation: the prompt is handed to pi through a temp file rather than
// as a raw argument, the timeout comes from effectiveSlotTimeout, and any
// failure is reported as a sanitized piInvokeError. Running the command
// directly here would bypass all three.
//
// It returns pi's combined output as a string, or a non-nil error when the
// prompt file cannot be written or the pi command fails.
func (r *Runner) invokePi(ctx context.Context, slot ReviewerSlot, pkt *ContextPacket, evidence *TestEvidence) (string, error) {
	return r.invokePiProvider(ctx, slot, slot.Provider, slot.Model, false, pkt, evidence)
}

// invokeFallback re-runs the review through the "openrouter" provider, using
// the model that fallbackModel selects for the slot. The call is marked as a
// fallback so that a failure is reported under the "openrouter fallback"
// phase by piInvokeError.
func (r *Runner) invokeFallback(ctx context.Context, slot ReviewerSlot, pkt *ContextPacket, evidence *TestEvidence) (string, error) {
	const provider = "openrouter"
	model := fallbackModel(slot)
	return r.invokePiProvider(ctx, slot, provider, model, true, pkt, evidence)
}

func (r *Runner) invokePiProvider(
ctx context.Context,
slot ReviewerSlot,
provider string,
model string,
fallback bool,
pkt *ContextPacket,
evidence *TestEvidence,
) (string, error) {
reviewPrompt := buildReviewPrompt(slot, pkt, evidence)
modelCtx, cancel := context.WithTimeout(ctx, r.cfg.effectiveModelTimeout())
promptArg, cleanup, err := writePromptTemp(reviewPrompt)
if err != nil {
return "", fmt.Errorf("pi prompt temp file: %w", err)
}
defer cleanup()

modelCtx, cancel := context.WithTimeout(ctx, r.effectiveSlotTimeout(slot))
defer cancel()
out, err := r.runner.CombinedOutput(modelCtx, r.cfg.ProjectRoot,
"pi", "--provider", "openrouter", "--model", fallbackModel(slot),
"--no-tools", "--no-context-files", "--no-session", "-p", reviewPrompt)
"pi", "--provider", provider, "--model", model,
"--no-tools", "--no-context-files", "--no-session", "-p", promptArg)
if err != nil {
return "", fmt.Errorf("openrouter fallback: %w", err)
return "", newPiInvokeError(provider, model, len(reviewPrompt), fallback, classifyPiError(modelCtx, err))
}
return string(out), nil
}

// writePromptTemp stores prompt in a new file in the default temp directory
// (name pattern "sdp-pi-review-*.prompt.md") and closes it.
//
// On success it returns the file path prefixed with "@" together with a
// cleanup function that removes the file (removal errors are ignored).
// The "@" prefix is presumably pi's syntax for reading the prompt from a
// file; confirm against the pi CLI.
//
// On failure it returns an empty string, a no-op cleanup function, and the
// error; a file that was created but could not be fully written or closed is
// removed before returning.
func writePromptTemp(prompt string) (string, func(), error) {
	noop := func() {}

	tmp, err := os.CreateTemp("", "sdp-pi-review-*.prompt.md")
	if err != nil {
		return "", noop, err
	}
	path := tmp.Name()
	remove := func() {
		_ = os.Remove(path)
	}

	_, writeErr := tmp.WriteString(prompt)
	if writeErr != nil {
		_ = tmp.Close()
		remove()
		return "", noop, writeErr
	}

	closeErr := tmp.Close()
	if closeErr != nil {
		remove()
		return "", noop, closeErr
	}

	return "@" + path, remove, nil
}

// classifyPiError reduces a failed pi invocation to a context sentinel.
//
// If ctx is already done, its own error is returned regardless of err.
// Otherwise err is matched with errors.Is against context.DeadlineExceeded
// and then context.Canceled, and the matching sentinel itself is returned.
// Any other error, including nil, yields nil, so none of err's message text
// reaches the caller.
func classifyPiError(ctx context.Context, err error) error {
	if cause := ctx.Err(); cause != nil {
		return cause
	}
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return context.DeadlineExceeded
	case errors.Is(err, context.Canceled):
		return context.Canceled
	default:
		return nil
	}
}

// kimiModelTimeoutFloor is the minimum timeout that effectiveSlotTimeout
// returns for a slot named "kimi" or a slot whose provider starts with
// "kimi".
const kimiModelTimeoutFloor = 6 * time.Minute

// effectiveSlotTimeout returns the timeout to apply to one pi invocation for
// slot. It starts from the configured model timeout and, for a slot named
// "kimi" or one whose provider starts with "kimi", raises it to
// kimiModelTimeoutFloor when the configured value is lower. All other slots
// get the configured value unchanged.
func (r *Runner) effectiveSlotTimeout(slot ReviewerSlot) time.Duration {
	base := r.cfg.effectiveModelTimeout()
	isKimi := slot.Slot == "kimi" || strings.HasPrefix(slot.Provider, "kimi")
	if isKimi && base < kimiModelTimeoutFloor {
		return kimiModelTimeoutFloor
	}
	return base
}

// piInvokeError describes a failed pi invocation without carrying the prompt
// text: it records which provider and model were called, how large the prompt
// was, and a classified cause.
type piInvokeError struct {
	provider    string // provider passed to pi
	model       string // model passed to pi
	promptBytes int    // len() of the prompt, reported for debugging
	fallback    bool   // true when this was the fallback attempt
	err         error  // classified cause; may be nil
}

// newPiInvokeError builds a *piInvokeError from its parts and returns it as
// an error. err may be nil, in which case the failure is reported as a
// generic provider error.
func newPiInvokeError(provider, model string, promptBytes int, fallback bool, err error) error {
	failure := piInvokeError{
		provider:    provider,
		model:       model,
		promptBytes: promptBytes,
		fallback:    fallback,
		err:         err,
	}
	return &failure
}

// Error formats the failure as
// "<phase> <provider>/<model> failed: <reason> (prompt_bytes=<n>)".
// The phase is "openrouter fallback" for fallback attempts and "pi run"
// otherwise; the reason is "timeout" or "canceled" when the cause matches the
// corresponding context sentinel, and "provider_error" for anything else.
func (e *piInvokeError) Error() string {
	var phase string
	if e.fallback {
		phase = "openrouter fallback"
	} else {
		phase = "pi run"
	}

	var reason string
	switch {
	case errors.Is(e.err, context.DeadlineExceeded):
		reason = "timeout"
	case errors.Is(e.err, context.Canceled):
		reason = "canceled"
	default:
		reason = "provider_error"
	}

	return fmt.Sprintf("%s %s/%s failed: %s (prompt_bytes=%d)", phase, e.provider, e.model, reason, e.promptBytes)
}

// Unwrap exposes the classified cause so errors.Is and errors.As can see it.
func (e *piInvokeError) Unwrap() error {
	return e.err
}

func fallbackModel(slot ReviewerSlot) string {
switch slot.Slot {
case "zai":
Expand Down
127 changes: 126 additions & 1 deletion internal/pireview/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pireview
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -302,11 +303,123 @@ func TestRunModelPanelAttemptsFallbackAfterTimeout(t *testing.T) {
if results[0].Provider != "openrouter" {
t.Fatalf("provider = %q, want openrouter", results[0].Provider)
}
if results[0].Error != "" {
t.Fatalf("fallback success should clear primary error, got %q", results[0].Error)
}
if fr.calls != 2 {
t.Fatalf("calls = %d, want primary + fallback", fr.calls)
}
}

// leakingErrorRunner is a test double whose CombinedOutput always fails with
// an error message that echoes every argument it was given, wrapped around
// context.DeadlineExceeded. Tests use it to check that callers do not
// propagate that message.
type leakingErrorRunner struct{}

// Output is a stub that reports success with no output.
func (leakingErrorRunner) Output(_ context.Context, _ string, _ string, _ ...string) ([]byte, error) {
	return nil, nil
}

// Run is a stub that reports success.
func (leakingErrorRunner) Run(_ context.Context, _ string, _ string, _ ...string) error {
	return nil
}

// CombinedOutput returns no output and an error of the form
// "pi <args joined by spaces>: <deadline exceeded>".
func (leakingErrorRunner) CombinedOutput(_ context.Context, _ string, _ string, args ...string) ([]byte, error) {
	echoed := strings.Join(args, " ")
	leak := fmt.Errorf("pi %s: %w", echoed, context.DeadlineExceeded)
	return nil, leak
}

// killedAfterContextRunner is a test double whose CombinedOutput blocks until
// the supplied context is done and then fails with a plain "signal: killed"
// error that does not wrap any context sentinel.
type killedAfterContextRunner struct{}

// Output is a stub that reports success with no output.
func (killedAfterContextRunner) Output(_ context.Context, _ string, _ string, _ ...string) ([]byte, error) {
	return nil, nil
}

// Run is a stub that reports success.
func (killedAfterContextRunner) Run(_ context.Context, _ string, _ string, _ ...string) error {
	return nil
}

// CombinedOutput waits for ctx to finish, then returns no output and a
// "signal: killed" error.
func (killedAfterContextRunner) CombinedOutput(ctx context.Context, _ string, _ string, _ ...string) ([]byte, error) {
	done := ctx.Done()
	<-done
	killed := errors.New("signal: killed")
	return nil, killed
}

// TestInvokePiSanitizesPromptFromError drives invokePi with a runner whose
// error message echoes its arguments, and checks that the returned error
// still matches context.DeadlineExceeded, contains neither the secret from
// the diff nor the raw "-p" flag, unwraps to exactly the sentinel, and
// reports the prompt size.
func TestInvokePiSanitizesPromptFromError(t *testing.T) {
	fake := leakingErrorRunner{}
	r := &Runner{
		cfg: Config{
			ProjectRoot:  t.TempDir(),
			Scope:        ScopeWorkingTree,
			ModelTimeout: time.Second,
			Runner:       fake,
		},
		runner: fake,
	}

	slot := ReviewerSlot{
		Slot:     "zai",
		Provider: "zai",
		Model:    "glm",
		Role:     "reviewer",
	}
	pkt := &ContextPacket{
		ReviewedFiles: []string{"main.go"},
		UnifiedDiff:   "+ const token = \"SECRET_TOKEN_123\"",
	}

	_, err := r.invokePi(context.Background(), slot, pkt, &TestEvidence{})
	if err == nil {
		t.Fatal("expected error")
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected deadline exceeded wrapper, got %v", err)
	}

	msg := err.Error()
	leaksSecret := strings.Contains(msg, "SECRET_TOKEN_123")
	leaksArgs := strings.Contains(msg, "-p")
	if leaksSecret || leaksArgs {
		t.Fatalf("error leaked prompt or raw pi args: %s", err)
	}

	// The only cause visible through Unwrap must be the bare sentinel.
	if unwrapped := errors.Unwrap(err); unwrapped != context.DeadlineExceeded {
		t.Fatalf("unwrap should expose only sanitized classification, got %v", unwrapped)
	}
	if !strings.Contains(msg, "prompt_bytes=") {
		t.Fatalf("error should include prompt size for debugging: %s", err)
	}
}

// TestInvokePiClassifiesCommandContextKillAsTimeout uses a runner that blocks
// until its context ends and then fails with a plain "signal: killed" error.
// With a 10ms model timeout, invokePi must still report the failure as a
// deadline-exceeded timeout.
func TestInvokePiClassifiesCommandContextKillAsTimeout(t *testing.T) {
	fake := killedAfterContextRunner{}
	r := &Runner{
		cfg: Config{
			ProjectRoot:  t.TempDir(),
			Scope:        ScopeWorkingTree,
			ModelTimeout: 10 * time.Millisecond,
			Runner:       fake,
		},
		runner: fake,
	}

	slot := ReviewerSlot{
		Slot:     "zai",
		Provider: "zai",
		Model:    "glm",
		Role:     "reviewer",
	}

	_, err := r.invokePi(context.Background(), slot, &ContextPacket{}, &TestEvidence{})

	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected deadline exceeded wrapper, got %v", err)
	}
	if msg := err.Error(); !strings.Contains(msg, "timeout") {
		t.Fatalf("error should classify context kill as timeout: %s", err)
	}
}

// TestEffectiveSlotTimeoutUsesKimiFloor configures a 3-minute model timeout
// and checks that a kimi slot is raised to kimiModelTimeoutFloor while a
// non-kimi slot keeps the configured 3 minutes.
func TestEffectiveSlotTimeoutUsesKimiFloor(t *testing.T) {
	r := &Runner{
		cfg: Config{
			ProjectRoot:  t.TempDir(),
			Scope:        ScopeWorkingTree,
			ModelTimeout: 3 * time.Minute,
			Runner:       &fakeRunner{},
		},
	}

	kimiSlot := ReviewerSlot{Slot: "kimi", Provider: "kimi-coding", Model: "k2p6"}
	if kimiTimeout := r.effectiveSlotTimeout(kimiSlot); kimiTimeout != kimiModelTimeoutFloor {
		t.Fatalf("kimi timeout = %s, want %s", kimiTimeout, kimiModelTimeoutFloor)
	}

	zaiSlot := ReviewerSlot{Slot: "zai", Provider: "zai", Model: "glm-5.1"}
	if zaiTimeout := r.effectiveSlotTimeout(zaiSlot); zaiTimeout != 3*time.Minute {
		t.Fatalf("zai timeout = %s, want 3m", zaiTimeout)
	}
}

func TestInvokePiHonorsModelTimeout(t *testing.T) {
r := &Runner{
cfg: Config{
Expand Down Expand Up @@ -351,7 +464,9 @@ func TestInvokePiUsesPiPrintContract(t *testing.T) {
Provider: "kimi-coding",
Model: "k2p6",
Role: "reviewer",
}, &ContextPacket{}, &TestEvidence{})
}, &ContextPacket{
UnifiedDiff: "+ const token = \"SECRET_TOKEN_123\"",
}, &TestEvidence{})
if err != nil {
t.Fatalf("invokePi: %v", err)
}
Expand All @@ -371,6 +486,16 @@ func TestInvokePiUsesPiPrintContract(t *testing.T) {
if len(call.args) > 0 && call.args[0] == "run" {
t.Fatalf("pi invocation must not use removed run subcommand: %v", call.args)
}
promptArg := call.args[len(call.args)-1]
if !strings.HasPrefix(promptArg, "@") {
t.Fatalf("prompt must be passed as @file, got %q", promptArg)
}
if strings.Contains(got, "SECRET_TOKEN_123") {
t.Fatalf("pi args leaked raw prompt: %v", call.args)
}
if _, err := os.Stat(strings.TrimPrefix(promptArg, "@")); !os.IsNotExist(err) {
t.Fatalf("temp prompt file should be removed, stat error = %v", err)
}
}

func TestHashString(t *testing.T) {
Expand Down
Loading