Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion lib/commands/sweep.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import { formatPrunedItemsList } from "../ui/utils"
import { getCurrentParams, getTotalToolTokens } from "../strategies/utils"
import { buildToolIdList, isIgnoredUserMessage } from "../messages/utils"
import { saveSessionState } from "../state/persistence"
import { isMessageCompacted } from "../shared-utils"
import { getLastUserMessage, isMessageCompacted } from "../shared-utils"
import { getFilePathsFromParameters, isProtected } from "../protected-file-patterns"
import { syncToolCache } from "../state/tool-cache"

Expand Down Expand Up @@ -215,11 +215,21 @@ export async function handleSweepCommand(ctx: SweepCommandContext): Promise<void
}

const tokensSaved = getTotalToolTokens(state, newToolIds)
const originMessageId = getLastUserMessage(messages)?.info.id || ""
if (!originMessageId) {
logger.warn("Sweep prune origin unavailable - missing user message")
}

// Add to prune list
for (const id of newToolIds) {
const entry = state.toolParameters.get(id)
state.prune.tools.set(id, entry?.tokenCount ?? 0)
if (originMessageId) {
state.prune.origins.set(id, {
source: "sweep",
originMessageId,
})
}
}
state.stats.pruneTokenCounter += tokensSaved
state.stats.totalPruneTokens += state.stats.pruneTokenCounter
Expand Down
2 changes: 2 additions & 0 deletions lib/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ const DEFAULT_PROTECTED_TOOLS = [
"batch",
"plan_enter",
"plan_exit",
"write",
"edit",
]

// Valid config keys for validation against user config
Expand Down
3 changes: 2 additions & 1 deletion lib/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type { PluginConfig } from "./config"
import { assignMessageRefs } from "./message-ids"
import { syncToolCache } from "./state/tool-cache"
import { deduplicate, supersedeWrites, purgeErrors } from "./strategies"
import { prune, insertPruneToolContext, insertMessageIdContext } from "./messages"
import { prune, syncToolOrigins, insertPruneToolContext, insertMessageIdContext } from "./messages"
import { buildToolIdList, isIgnoredUserMessage } from "./messages/utils"
import { checkSession } from "./state"
import { renderSystemPrompt } from "./prompts"
Expand Down Expand Up @@ -113,6 +113,7 @@ export function createChatMessageTransformHandler(

syncToolCache(state, config, logger, output.messages)
buildToolIdList(state, output.messages, logger)
syncToolOrigins(state, logger, output.messages)

deduplicate(state, logger, config, output.messages)
supersedeWrites(state, logger, config, output.messages)
Expand Down
1 change: 1 addition & 0 deletions lib/messages/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { prune } from "./prune"
export { syncToolOrigins } from "./sync"
export { insertPruneToolContext } from "./inject"
export { insertMessageIdContext } from "./inject"
15 changes: 8 additions & 7 deletions lib/messages/prune.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,14 @@ const pruneToolInputs = (state: SessionState, logger: Logger, messages: WithPart
if (part.type !== "tool") {
continue
}
if (part.tool === "compress" && part.state.status === "completed") {
const content = part.state.input?.content
if (content && typeof content === "object" && "summary" in content) {
content.summary = PRUNED_COMPRESS_SUMMARY_REPLACEMENT
}
continue
}

// if (part.tool === "compress" && part.state.status === "completed") {
// const content = part.state.input?.content
// if (content && typeof content === "object" && "summary" in content) {
// content.summary = PRUNED_COMPRESS_SUMMARY_REPLACEMENT
// }
// continue
// }

if (!state.prune.tools.has(part.callID)) {
continue
Expand Down
39 changes: 39 additions & 0 deletions lib/messages/sync.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import type { SessionState, WithParts } from "../state"
import type { Logger } from "../logger"

export const syncToolOrigins = (
state: SessionState,
logger: Logger,
messages: WithParts[],
): void => {
if (!state.prune.origins?.size) {
return
}

const messageIds = new Set(messages.map((msg) => msg.info.id))
let removedToolCount = 0
let removedOriginCount = 0

for (const [toolId, origin] of state.prune.origins.entries()) {
if (!state.prune.tools.has(toolId)) {
state.prune.origins.delete(toolId)
removedOriginCount++
continue
}

if (!messageIds.has(origin.originMessageId)) {
state.prune.origins.delete(toolId)
removedOriginCount++
if (state.prune.tools.delete(toolId)) {
removedToolCount++
}
}
}

if (removedToolCount > 0 || removedOriginCount > 0) {
logger.info("Synced prune origins", {
removedToolCount,
removedOriginCount,
})
}
}
16 changes: 11 additions & 5 deletions lib/messages/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,15 @@ export const createSyntheticToolPart = (
const partId = generateStableId("prt_dcp_tool", deterministicSeed)
const callId = generateStableId("call_dcp_tool", deterministicSeed)

// Gemini requires thoughtSignature bypass to accept synthetic tool parts
// Gemini requires a thought signature on synthetic function calls.
// Keep this metadata both on the part and on state so whichever
// conversion path is used can forward it to providerOptions.
const toolPartMetadata = isGeminiModel(modelID)
? { google: { thoughtSignature: "skip_thought_signature_validator" } }
: {}
? {
google: { thoughtSignature: "skip_thought_signature_validator" },
vertex: { thoughtSignature: "skip_thought_signature_validator" },
}
: undefined

return {
id: partId,
Expand All @@ -96,14 +101,15 @@ export const createSyntheticToolPart = (
type: "tool" as const,
callID: callId,
tool: "context_info",
...(toolPartMetadata ? { metadata: toolPartMetadata } : {}),
state: {
status: "completed" as const,
input: {},
output: content,
title: "Context Info",
metadata: toolPartMetadata,
...(toolPartMetadata ? { metadata: toolPartMetadata } : {}),
time: { start: now, end: now },
},
} as any,
}
}

Expand Down
4 changes: 3 additions & 1 deletion lib/state/persistence.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import * as fs from "fs/promises"
import { existsSync } from "fs"
import { homedir } from "os"
import { join } from "path"
import type { SessionState, SessionStats, CompressSummary } from "./types"
import type { SessionState, SessionStats, CompressSummary, PruneOrigin } from "./types"
import type { Logger } from "../logger"

/** Prune state as stored on disk */
export interface PersistedPrune {
// New format: tool/message IDs with token counts
tools?: Record<string, number>
messages?: Record<string, number>
origins?: Record<string, PruneOrigin>
// Legacy format: plain ID arrays (backward compatibility)
toolIds?: string[]
messageIds?: string[]
Expand Down Expand Up @@ -64,6 +65,7 @@ export async function saveSessionState(
prune: {
tools: Object.fromEntries(sessionState.prune.tools),
messages: Object.fromEntries(sessionState.prune.messages),
origins: Object.fromEntries(sessionState.prune.origins),
},
compressSummaries: sessionState.compressSummaries,
stats: sessionState.stats,
Expand Down
4 changes: 4 additions & 0 deletions lib/state/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
countTurns,
resetOnCompaction,
loadPruneMap,
loadPruneOriginMap,
} from "./utils"
import { getLastUserMessage } from "../shared-utils"

Expand Down Expand Up @@ -67,6 +68,7 @@ export function createSessionState(): SessionState {
prune: {
tools: new Map<string, number>(),
messages: new Map<string, number>(),
origins: new Map(),
},
compressSummaries: [],
stats: {
Expand Down Expand Up @@ -97,6 +99,7 @@ export function resetSessionState(state: SessionState): void {
state.prune = {
tools: new Map<string, number>(),
messages: new Map<string, number>(),
origins: new Map(),
}
state.compressSummaries = []
state.stats = {
Expand Down Expand Up @@ -151,6 +154,7 @@ export async function ensureSessionInitialized(

state.prune.tools = loadPruneMap(persisted.prune.tools, persisted.prune.toolIds)
state.prune.messages = loadPruneMap(persisted.prune.messages, persisted.prune.messageIds)
state.prune.origins = loadPruneOriginMap(persisted.prune.origins)
state.compressSummaries = persisted.compressSummaries || []
state.stats = {
pruneTokenCounter: persisted.stats?.pruneTokenCounter || 0,
Expand Down
14 changes: 14 additions & 0 deletions lib/state/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,23 @@ export interface CompressSummary {
summary: string
}

export type PruneOriginSource =
| "prune"
| "distill"
| "sweep"
| "deduplication"
| "supersedeWrites"
| "purgeErrors"

export interface PruneOrigin {
source: PruneOriginSource
originMessageId: string
}

export interface Prune {
tools: Map<string, number>
messages: Map<string, number>
origins: Map<string, PruneOrigin>
}

export interface PendingManualTrigger {
Expand Down
24 changes: 23 additions & 1 deletion lib/state/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { SessionState, WithParts } from "./types"
import type { PruneOrigin, SessionState, WithParts } from "./types"
import { isMessageCompacted } from "../shared-utils"

export async function isSubAgentSession(client: any, sessionID: string): Promise<boolean> {
Expand Down Expand Up @@ -45,10 +45,32 @@ export function loadPruneMap(
return new Map()
}

export function loadPruneOriginMap(obj?: Record<string, PruneOrigin>): Map<string, PruneOrigin> {
if (!obj || typeof obj !== "object") {
return new Map()
}

const entries: [string, PruneOrigin][] = []
for (const [toolId, origin] of Object.entries(obj)) {
if (
origin &&
typeof origin === "object" &&
typeof origin.source === "string" &&
typeof origin.originMessageId === "string" &&
origin.originMessageId.length > 0
) {
entries.push([toolId, origin])
}
}

return new Map(entries)
}

export function resetOnCompaction(state: SessionState): void {
state.toolParameters.clear()
state.prune.tools = new Map<string, number>()
state.prune.messages = new Map<string, number>()
state.prune.origins = new Map()
state.compressSummaries = []
state.messageIds = {
byRawId: new Map<string, string>(),
Expand Down
11 changes: 11 additions & 0 deletions lib/strategies/deduplication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { PluginConfig } from "../config"
import { Logger } from "../logger"
import type { SessionState, WithParts } from "../state"
import { getFilePathsFromParameters, isProtected } from "../protected-file-patterns"
import { getLastUserMessage } from "../shared-utils"
import { getTotalToolTokens } from "./utils"

/**
Expand Down Expand Up @@ -79,11 +80,21 @@ export const deduplicate = (
}

state.stats.totalPruneTokens += getTotalToolTokens(state, newPruneIds)
const decisionMessageId = getLastUserMessage(messages)?.info.id || ""

if (newPruneIds.length > 0) {
if (!decisionMessageId) {
logger.warn("Deduplication prune origin unavailable - missing user message")
}
for (const id of newPruneIds) {
const entry = state.toolParameters.get(id)
state.prune.tools.set(id, entry?.tokenCount ?? 0)
if (decisionMessageId) {
state.prune.origins.set(id, {
source: "deduplication",
originMessageId: decisionMessageId,
})
}
}
logger.debug(`Marked ${newPruneIds.length} duplicate tool calls for pruning`)
}
Expand Down
11 changes: 11 additions & 0 deletions lib/strategies/purge-errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { PluginConfig } from "../config"
import { Logger } from "../logger"
import type { SessionState, WithParts } from "../state"
import { getFilePathsFromParameters, isProtected } from "../protected-file-patterns"
import { getLastUserMessage } from "../shared-utils"
import { getTotalToolTokens } from "./utils"

/**
Expand Down Expand Up @@ -72,10 +73,20 @@ export const purgeErrors = (
}

if (newPruneIds.length > 0) {
const decisionMessageId = getLastUserMessage(messages)?.info.id || ""
if (!decisionMessageId) {
logger.warn("Purge errors prune origin unavailable - missing user message")
}
state.stats.totalPruneTokens += getTotalToolTokens(state, newPruneIds)
for (const id of newPruneIds) {
const entry = state.toolParameters.get(id)
state.prune.tools.set(id, entry?.tokenCount ?? 0)
if (decisionMessageId) {
state.prune.origins.set(id, {
source: "purgeErrors",
originMessageId: decisionMessageId,
})
}
}
logger.debug(
`Marked ${newPruneIds.length} error tool calls for pruning (older than ${turnThreshold} turns)`,
Expand Down
11 changes: 11 additions & 0 deletions lib/strategies/supersede-writes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { PluginConfig } from "../config"
import { Logger } from "../logger"
import type { SessionState, WithParts } from "../state"
import { getFilePathsFromParameters, isProtected } from "../protected-file-patterns"
import { getLastUserMessage } from "../shared-utils"
import { getTotalToolTokens } from "./utils"

/**
Expand Down Expand Up @@ -105,10 +106,20 @@ export const supersedeWrites = (
}

if (newPruneIds.length > 0) {
const decisionMessageId = getLastUserMessage(messages)?.info.id || ""
if (!decisionMessageId) {
logger.warn("Supersede writes prune origin unavailable - missing user message")
}
state.stats.totalPruneTokens += getTotalToolTokens(state, newPruneIds)
for (const id of newPruneIds) {
const entry = state.toolParameters.get(id)
state.prune.tools.set(id, entry?.tokenCount ?? 0)
if (decisionMessageId) {
state.prune.origins.set(id, {
source: "supersedeWrites",
originMessageId: decisionMessageId,
})
}
}
logger.debug(`Marked ${newPruneIds.length} superseded write tool calls for pruning`)
}
Expand Down
1 change: 1 addition & 0 deletions lib/tools/distill.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ export function createDistillTool(ctx: PruneToolContext): ReturnType<typeof tool
ids,
"extraction" as PruneReason,
"Distill",
"distill",
distillations,
)
},
Expand Down
Loading