@@ -7,6 +7,7 @@ import { loadPrompt } from "../prompts"
77import { getCurrentParams , getTotalToolTokens , countMessageTextTokens } from "../strategies/utils"
88import { findStringInMessages , collectToolIdsInRange , collectMessageIdsInRange } from "./utils"
99import { sendCompressNotification } from "../ui/notification"
10+ import { prune as applyPruneTransforms } from "../messages/prune"
1011
1112const COMPRESS_TOOL_DESCRIPTION = loadPrompt ( "compress-tool-spec" )
1213
@@ -82,37 +83,56 @@ export function createCompressTool(ctx: PruneToolContext): ReturnType<typeof too
8283 ctx . config . manualMode . enabled ,
8384 )
8485
86+ const transformedMessages = structuredClone ( messages ) as WithParts [ ]
87+ applyPruneTransforms ( state , logger , ctx . config , transformedMessages )
88+
8589 const startResult = findStringInMessages (
86- messages ,
90+ transformedMessages ,
8791 startString ,
8892 logger ,
89- state . compressSummaries ,
9093 "startString" ,
9194 )
9295 const endResult = findStringInMessages (
93- messages ,
96+ transformedMessages ,
9497 endString ,
9598 logger ,
96- state . compressSummaries ,
9799 "endString" ,
98100 )
99101
100- if ( startResult . messageIndex > endResult . messageIndex ) {
102+ let rawStartIndex = messages . findIndex ( ( m ) => m . info . id === startResult . messageId )
103+ let rawEndIndex = messages . findIndex ( ( m ) => m . info . id === endResult . messageId )
104+
105+ // If a boundary matched inside a synthetic compress summary message,
106+ // resolve it back to the summary's anchor message in the raw messages
107+ if ( rawStartIndex === - 1 ) {
108+ const summary = state . compressSummaries . find ( ( s ) => s . summary . includes ( startString ) )
109+ if ( summary ) {
110+ rawStartIndex = messages . findIndex ( ( m ) => m . info . id === summary . anchorMessageId )
111+ }
112+ }
113+ if ( rawEndIndex === - 1 ) {
114+ const summary = state . compressSummaries . find ( ( s ) => s . summary . includes ( endString ) )
115+ if ( summary ) {
116+ rawEndIndex = messages . findIndex ( ( m ) => m . info . id === summary . anchorMessageId )
117+ }
118+ }
119+
120+ if ( rawStartIndex === - 1 || rawEndIndex === - 1 ) {
121+ throw new Error ( `Failed to map boundary matches back to raw messages` )
122+ }
123+
124+ if ( rawStartIndex > rawEndIndex ) {
101125 throw new Error (
102126 `startString appears after endString in the conversation. Start must come before end.` ,
103127 )
104128 }
105129
106- const containedToolIds = collectToolIdsInRange (
107- messages ,
108- startResult . messageIndex ,
109- endResult . messageIndex ,
110- )
130+ const containedToolIds = collectToolIdsInRange ( messages , rawStartIndex , rawEndIndex )
111131
112132 const containedMessageIds = collectMessageIdsInRange (
113133 messages ,
114- startResult . messageIndex ,
115- endResult . messageIndex ,
134+ rawStartIndex ,
135+ rawEndIndex ,
116136 )
117137
118138 // Remove any existing summaries whose anchors are now inside this range
@@ -132,38 +152,45 @@ export function createCompressTool(ctx: PruneToolContext): ReturnType<typeof too
132152 }
133153 state . compressSummaries . push ( compressSummary )
134154
155+ const compressedMessageIds = containedMessageIds . filter (
156+ ( id ) => ! state . prune . messages . has ( id ) ,
157+ )
158+ const compressedToolIds = containedToolIds . filter ( ( id ) => ! state . prune . tools . has ( id ) )
159+
135160 let textTokens = 0
136- for ( let i = startResult . messageIndex ; i <= endResult . messageIndex ; i ++ ) {
137- const msgId = messages [ i ] . info . id
138- if ( ! state . prune . messages . has ( msgId ) ) {
139- const tokens = countMessageTextTokens ( messages [ i ] )
161+ for ( const msgId of compressedMessageIds ) {
162+ const msg = messages . find ( ( m ) => m . info . id === msgId )
163+ if ( msg ) {
164+ const tokens = countMessageTextTokens ( msg )
140165 textTokens += tokens
141166 state . prune . messages . set ( msgId , tokens )
142167 }
143168 }
144- const newToolIds = containedToolIds . filter ( ( id ) => ! state . prune . tools . has ( id ) )
145- const toolTokens = getTotalToolTokens ( state , newToolIds )
146- for ( const id of newToolIds ) {
169+ const toolTokens = getTotalToolTokens ( state , compressedToolIds )
170+ for ( const id of compressedToolIds ) {
147171 const entry = state . toolParameters . get ( id )
148172 state . prune . tools . set ( id , entry ?. tokenCount ?? 0 )
149173 }
150174 const estimatedCompressedTokens = textTokens + toolTokens
151175
152176 state . stats . pruneTokenCounter += estimatedCompressedTokens
153177
178+ const rawStartResult = { messageId : startResult . messageId , messageIndex : rawStartIndex }
179+ const rawEndResult = { messageId : endResult . messageId , messageIndex : rawEndIndex }
180+
154181 const currentParams = getCurrentParams ( state , messages , logger )
155182 await sendCompressNotification (
156183 client ,
157184 logger ,
158185 ctx . config ,
159186 state ,
160187 sessionId ,
161- containedToolIds ,
162- containedMessageIds ,
188+ compressedToolIds ,
189+ compressedMessageIds ,
163190 topic ,
164191 summary ,
165- startResult ,
166- endResult ,
192+ rawStartResult ,
193+ rawEndResult ,
167194 messages . length ,
168195 currentParams ,
169196 )
@@ -184,8 +211,7 @@ export function createCompressTool(ctx: PruneToolContext): ReturnType<typeof too
184211 logger . error ( "Failed to persist state" , { error : err . message } ) ,
185212 )
186213
187- const messagesCompressed = endResult . messageIndex - startResult . messageIndex + 1
188- return `Compressed ${ messagesCompressed } messages (${ containedToolIds . length } tool calls) into summary. The content will be replaced with your summary.`
214+ return `Compressed ${ compressedMessageIds . length } messages (${ compressedToolIds . length } tool calls) into summary. The content will be replaced with your summary.`
189215 } ,
190216 } )
191217}
0 commit comments