diff --git a/internal/postgres/push.go b/internal/postgres/push.go index 0df8b12e7..ff39164a8 100644 --- a/internal/postgres/push.go +++ b/internal/postgres/push.go @@ -286,7 +286,8 @@ func (s *Sync) Push( batch := sessions[i:end] batchResult, err := s.pushBatch( - ctx, batch, full, markerID, legacyMarkerMachines, &pushed, + ctx, batch, full, markerID, legacyMarkerMachines, + usageFingerprints, &pushed, ) if err != nil { return result, err @@ -301,7 +302,8 @@ func (s *Sync) Push( for _, sess := range batch { sr, retryErr := s.pushBatch( ctx, []db.Session{sess}, - full, markerID, legacyMarkerMachines, &pushed, + full, markerID, legacyMarkerMachines, + usageFingerprints, &pushed, ) if retryErr != nil { return result, retryErr @@ -554,6 +556,10 @@ type batchResult struct { skippedConflicts int } +var errPushComparisonPreload = errors.New( + "push comparison preload failed", +) + // pushBatch pushes a slice of sessions within a single // transaction. On success it appends to pushed and returns // ok=true with session/message counts. 
On a session-level @@ -566,7 +572,37 @@ func (s *Sync) pushBatch( full bool, markerID string, legacyMarkerMachines []string, + sessionUsageFingerprints map[string]string, + pushed *[]db.Session, +) (batchResult, error) { + preloadComparisons := len(batch) > 0 && !full + result, err := s.pushBatchAttempt( + ctx, batch, full, markerID, legacyMarkerMachines, + sessionUsageFingerprints, pushed, preloadComparisons, + ) + if err == nil || !errors.Is(err, errPushComparisonPreload) { + return result, err + } + log.Printf( + "pgsync: preloading pg comparison fingerprints failed, "+ + "retrying batch without preload: %v", + err, + ) + return s.pushBatchAttempt( + ctx, batch, full, markerID, legacyMarkerMachines, + sessionUsageFingerprints, pushed, false, + ) +} + +func (s *Sync) pushBatchAttempt( + ctx context.Context, + batch []db.Session, + full bool, + markerID string, + legacyMarkerMachines []string, + sessionUsageFingerprints map[string]string, pushed *[]db.Session, + preloadComparisons bool, ) (batchResult, error) { tx, err := s.pg.BeginTx(ctx, nil) if err != nil { @@ -578,6 +614,24 @@ func (s *Sync) pushBatch( n := 0 msgs := 0 skippedConflicts := 0 + sessionIDs := make([]string, 0, len(batch)) + for _, sess := range batch { + sessionIDs = append(sessionIDs, sess.ID) + } + comparisons := (*pushMessageComparison)(nil) + if preloadComparisons && len(sessionIDs) > 0 { + comparisonsBatch, err := readPushSessionMessageComparisons( + ctx, tx, sessionIDs, + ) + if err != nil { + _ = tx.Rollback() + return batchResult{}, fmt.Errorf( + "%w: %w", errPushComparisonPreload, err, + ) + } + comparisons = comparisonsBatch + } + for _, sess := range batch { if err := s.pushSession( ctx, tx, sess, markerID, legacyMarkerMachines, @@ -597,6 +651,7 @@ func (s *Sync) pushBatch( msgCount, err := s.pushMessages( ctx, tx, sess.ID, full, + sessionUsageFingerprints, comparisons, ) if err != nil { log.Printf( @@ -1241,6 +1296,8 @@ func (s *Sync) pushMessages( tx *sql.Tx, sessionID string, full 
bool, + sessionUsageFingerprints map[string]string, + comparisons *pushMessageComparison, ) (int, error) { localCount, err := s.local.MessageCount(sessionID) if err != nil { @@ -1289,168 +1346,198 @@ func (s *Sync) pushMessages( return 0, nil } - var pgCount int - var pgContentSum, pgContentMax, pgContentMin int64 - // Exact string fingerprint for the system-message ordinal set: - // STRING_AGG produces e.g. "0,2,5" — impossible to collide for - // distinct ordinal sets (unlike SUM or SUM+SUM-of-squares). - var pgSystemFP sql.NullString - var pgToolCallCount int - var pgTCContentSum int64 - if err := tx.QueryRowContext(ctx, - `SELECT COUNT(*), - COALESCE(SUM(content_length), 0), - COALESCE(MAX(content_length), 0), - COALESCE(MIN(content_length), 0), - STRING_AGG(ordinal::text, ',' ORDER BY ordinal) - FILTER (WHERE is_system) - FROM messages - WHERE session_id = $1`, - sessionID, - ).Scan( - &pgCount, &pgContentSum, - &pgContentMax, &pgContentMin, - &pgSystemFP, - ); err != nil { - return 0, fmt.Errorf( - "counting pg messages: %w", err, - ) - } - if err := tx.QueryRowContext(ctx, - `SELECT COUNT(*), - COALESCE(SUM(result_content_length), 0) - FROM tool_calls - WHERE session_id = $1`, - sessionID, - ).Scan(&pgToolCallCount, &pgTCContentSum); err != nil { - return 0, fmt.Errorf( - "counting pg tool_calls: %w", err, - ) - } - - if !full && pgCount == localCount && pgCount > 0 { - localSum, localMax, localMin, err := s.local.MessageContentFingerprint(sessionID) - if err != nil { + pgAgg, pgToolAgg, hasPreloadedComparisons := comparisonAggregates( + sessionID, comparisons, + ) + if !hasPreloadedComparisons { + if err := tx.QueryRowContext(ctx, + `SELECT COUNT(*), + COALESCE(SUM(content_length), 0), + COALESCE(MAX(content_length), 0), + COALESCE(MIN(content_length), 0), + COALESCE( + STRING_AGG(ordinal::text, ',' ORDER BY ordinal) + FILTER (WHERE is_system), + '' + ) + FROM messages + WHERE session_id = $1`, + sessionID, + ).Scan( + &pgAgg.Count, &pgAgg.Sum, + 
&pgAgg.Max, &pgAgg.Min, + &pgAgg.SysFP, + ); err != nil { return 0, fmt.Errorf( - "computing local content fingerprint: %w", - err, + "counting pg messages: %w", err, ) } - localContentHashFP, err := s.local.MessageContentHashFingerprint(sessionID) - if err != nil { + if err := tx.QueryRowContext(ctx, + `SELECT COUNT(*), + COALESCE(SUM(result_content_length), 0) + FROM tool_calls + WHERE session_id = $1`, + sessionID, + ).Scan(&pgToolAgg.Count, &pgToolAgg.Sum); err != nil { return 0, fmt.Errorf( - "computing local content hash fingerprint: %w", - err, + "counting pg tool_calls: %w", err, ) } - pgContentHashFP, err := pgMessageContentHashFingerprint( - ctx, tx, sessionID, + } + + if !full && pgAgg.Count == localCount && pgAgg.Count > 0 { + localFP := pushLocalMessageFingerprint{} + + localFP.Sum, localFP.Max, localFP.Min, err = s.local.MessageContentFingerprint( + sessionID, ) if err != nil { return 0, fmt.Errorf( - "computing pg content hash fingerprint: %w", + "computing local content fingerprint: %w", err, ) } - localRoleTimeFP, err := localMessageRoleTimePGFingerprint( - s.local, sessionID, + localFP.ContentHashFP, err = s.local.MessageContentHashFingerprint( + sessionID, ) if err != nil { return 0, fmt.Errorf( - "computing local role/time fingerprint: %w", + "computing local content hash fingerprint: %w", err, ) } - pgRoleTimeFP, err := pgMessageRoleTimeFingerprint( - ctx, tx, sessionID, + localFP.RoleTimeFP, err = localMessageRoleTimePGFingerprint( + s.local, sessionID, ) if err != nil { return 0, fmt.Errorf( - "computing pg role/time fingerprint: %w", + "computing local role/time fingerprint: %w", err, ) } - localFlagsFP, err := s.local.MessageFlagsFingerprint(sessionID) + localFP.FlagsFP, err = s.local.MessageFlagsFingerprint(sessionID) if err != nil { return 0, fmt.Errorf( "computing local message flags fingerprint: %w", err, ) } - pgFlagsFP, err := pgMessageFlagsFingerprint(ctx, tx, sessionID) - if err != nil { - return 0, fmt.Errorf( - "computing pg 
message flags fingerprint: %w", - err, - ) - } - localSysFP, err := s.local.SystemMessageFingerprint(sessionID) + localFP.SystemFP, err = s.local.SystemMessageFingerprint(sessionID) if err != nil { return 0, fmt.Errorf( "computing local system message fingerprint: %w", err, ) } - localTCCount, err := s.local.ToolCallCount(sessionID) + localFP.ToolCallCount, err = s.local.ToolCallCount(sessionID) if err != nil { return 0, fmt.Errorf( "counting local tool_calls: %w", err, ) } - localTCSum, err := s.local.ToolCallContentFingerprint(sessionID) + localFP.ToolCallSum, err = s.local.ToolCallContentFingerprint( + sessionID, + ) if err != nil { return 0, fmt.Errorf( - "computing local tool_call content "+ - "fingerprint: %w", err, + "computing local tool_call content fingerprint: %w", + err, ) } - localTCFP, err := s.local.ToolCallFingerprint(sessionID) + localFP.ToolCallFP, err = s.local.ToolCallFingerprint(sessionID) if err != nil { return 0, fmt.Errorf( "computing local tool_call fingerprint: %w", err, ) } - localTokenFP, err := s.local.MessageTokenFingerprint(sessionID) + localFP.TokenFP, err = s.local.MessageTokenFingerprint(sessionID) if err != nil { return 0, fmt.Errorf( - "computing local token fingerprint: %w", err, + "computing local token fingerprint: %w", + err, ) } - pgTokenFP, err := pgMessageTokenFingerprint(ctx, tx, sessionID) - if err != nil { - return 0, fmt.Errorf( - "computing pg token fingerprint: %w", err, - ) + + usageFromMap := false + if sessionUsageFingerprints != nil { + var ok bool + localFP.UsageEventFP, ok = sessionUsageFingerprints[sessionID] + usageFromMap = ok } - pgTCFP, err := pgToolCallFingerprint(ctx, tx, sessionID) - if err != nil { - return 0, fmt.Errorf( - "computing pg tool_call fingerprint: %w", err, - ) + if !usageFromMap { + localFP.UsageEventFP, err = s.local.UsageEventFingerprint(sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing local usage event fingerprint: %w", + err, + ) + } } - localUsageFP, err := 
s.local.UsageEventFingerprint(sessionID) - if err != nil { - return 0, fmt.Errorf( - "computing local usage event fingerprint: %w", err, + + if comparisons == nil { + pgContentHashFP, err := pgMessageContentHashFingerprint( + ctx, tx, sessionID, ) - } - pgUsageFP, err := pgUsageEventFingerprint(ctx, tx, sessionID) - if err != nil { - return 0, fmt.Errorf( - "computing pg usage event fingerprint: %w", err, + if err != nil { + return 0, fmt.Errorf( + "computing pg content hash fingerprint: %w", + err, + ) + } + pgRoleTimeFP, err := pgMessageRoleTimeFingerprint( + ctx, tx, sessionID, ) - } - if localSum == pgContentSum && - localMax == pgContentMax && - localMin == pgContentMin && - localContentHashFP == pgContentHashFP && - localRoleTimeFP == pgRoleTimeFP && - localFlagsFP == pgFlagsFP && - localSysFP == pgSystemFP.String && - localTCCount == pgToolCallCount && - localTCSum == pgTCContentSum && - localTCFP == pgTCFP && - localTokenFP == pgTokenFP && - localUsageFP == pgUsageFP { + if err != nil { + return 0, fmt.Errorf( + "computing pg role/time fingerprint: %w", + err, + ) + } + pgFlagsFP, err := pgMessageFlagsFingerprint(ctx, tx, sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing pg message flags fingerprint: %w", + err, + ) + } + pgTokenFP, err := pgMessageTokenFingerprint(ctx, tx, sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing pg token fingerprint: %w", + err, + ) + } + pgTCFP, err := pgToolCallFingerprint(ctx, tx, sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing pg tool_call fingerprint: %w", + err, + ) + } + pgUsageFP, err := pgUsageEventFingerprint(ctx, tx, sessionID) + if err != nil { + return 0, fmt.Errorf( + "computing pg usage event fingerprint: %w", + err, + ) + } + + if localFP.Sum == pgAgg.Sum && + localFP.Max == pgAgg.Max && + localFP.Min == pgAgg.Min && + localFP.ContentHashFP == pgContentHashFP && + localFP.RoleTimeFP == pgRoleTimeFP && + localFP.FlagsFP == pgFlagsFP && + localFP.SystemFP == 
pgAgg.SysFP && + localFP.ToolCallCount == pgToolAgg.Count && + localFP.ToolCallSum == pgToolAgg.Sum && + localFP.ToolCallFP == pgTCFP && + localFP.TokenFP == pgTokenFP && + localFP.UsageEventFP == pgUsageFP { + return 0, nil + } + } else if shouldSkipSessionMessages( + sessionID, localCount, localFP, full, comparisons, + ) { return 0, nil } } diff --git a/internal/postgres/push_fingerprint.go b/internal/postgres/push_fingerprint.go new file mode 100644 index 000000000..216b40b07 --- /dev/null +++ b/internal/postgres/push_fingerprint.go @@ -0,0 +1,618 @@ +package postgres + +import ( + "context" + "crypto/sha256" + "database/sql" + "fmt" + "strings" + + "go.kenn.io/agentsview/internal/db" +) + +const pushComparisonBatchSize = 900 + +type pushMessageAggregate struct { + Count int + Sum int64 + Max int64 + Min int64 + SysFP string +} + +type pushToolCallAggregate struct { + Count int + Sum int64 +} + +type pushMessageComparison struct { + MessageAggregates map[string]pushMessageAggregate + MessageContentHash map[string]string + MessageRoleTime map[string]string + MessageFlags map[string]string + MessageSystemOrdinals map[string]string + MessageTokenFingerprint map[string]string + ToolCallAggregates map[string]pushToolCallAggregate + ToolCallFingerprint map[string]string + UsageEventFingerprint map[string]string +} + +type pushLocalMessageFingerprint struct { + Sum int64 + Max int64 + Min int64 + ContentHashFP string + RoleTimeFP string + FlagsFP string + SystemFP string + ToolCallCount int + ToolCallSum int64 + ToolCallFP string + TokenFP string + UsageEventFP string +} + +func comparisonAggregates( + sessionID string, + comparisons *pushMessageComparison, +) (pushMessageAggregate, pushToolCallAggregate, bool) { + if comparisons == nil { + return pushMessageAggregate{}, pushToolCallAggregate{}, false + } + return comparisons.MessageAggregates[sessionID], + comparisons.ToolCallAggregates[sessionID], + true +} + +func readPushSessionMessageComparisons( + ctx 
context.Context, + tx *sql.Tx, + sessionIDs []string, +) (*pushMessageComparison, error) { + comparisons := &pushMessageComparison{ + MessageAggregates: make(map[string]pushMessageAggregate, len(sessionIDs)), + MessageContentHash: make(map[string]string, len(sessionIDs)), + MessageRoleTime: make(map[string]string, len(sessionIDs)), + MessageFlags: make(map[string]string, len(sessionIDs)), + MessageSystemOrdinals: make(map[string]string, len(sessionIDs)), + MessageTokenFingerprint: make(map[string]string, len(sessionIDs)), + ToolCallAggregates: make(map[string]pushToolCallAggregate, len(sessionIDs)), + ToolCallFingerprint: make(map[string]string, len(sessionIDs)), + UsageEventFingerprint: make(map[string]string, len(sessionIDs)), + } + + for i := 0; i < len(sessionIDs); i += pushComparisonBatchSize { + end := min(i+pushComparisonBatchSize, len(sessionIDs)) + chunk := sessionIDs[i:end] + + if err := loadPushMessageAggregates(ctx, tx, chunk, comparisons.MessageAggregates); err != nil { + return nil, err + } + if err := loadPushMessageContentHashFingerprints( + ctx, tx, chunk, comparisons.MessageContentHash, + ); err != nil { + return nil, err + } + if err := loadPushMessageRoleTimeFingerprints( + ctx, tx, chunk, comparisons.MessageRoleTime, + ); err != nil { + return nil, err + } + if err := loadPushMessageFlagFingerprints( + ctx, tx, chunk, comparisons.MessageFlags, + ); err != nil { + return nil, err + } + if err := loadPushMessageSystemOrdinals( + ctx, tx, chunk, comparisons.MessageSystemOrdinals, + ); err != nil { + return nil, err + } + if err := loadPushMessageTokenFingerprints( + ctx, tx, chunk, comparisons.MessageTokenFingerprint, + ); err != nil { + return nil, err + } + if err := loadPushToolCallAggregates( + ctx, tx, chunk, comparisons.ToolCallAggregates, + ); err != nil { + return nil, err + } + if err := loadPushToolCallFingerprints( + ctx, tx, chunk, comparisons.ToolCallFingerprint, + ); err != nil { + return nil, err + } + if err := 
loadPushUsageEventFingerprints( + ctx, tx, chunk, comparisons.UsageEventFingerprint, + ); err != nil { + return nil, err + } + } + + return comparisons, nil +} + +func loadPushMessageAggregates( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]pushMessageAggregate, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, COUNT(*), COALESCE(SUM(content_length), 0), + COALESCE(MAX(content_length), 0), COALESCE(MIN(content_length), 0), + COALESCE(STRING_AGG(ordinal::text, ',' ORDER BY ordinal) + FILTER (WHERE is_system), '') + FROM messages + WHERE session_id = ANY($1) + GROUP BY session_id + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var sessionID string + var count int64 + var agg pushMessageAggregate + if err := rows.Scan( + &sessionID, &count, &agg.Sum, &agg.Max, &agg.Min, &agg.SysFP, + ); err != nil { + return err + } + agg.Count = int(count) + out[sessionID] = agg + } + return rows.Err() +} + +func loadPushMessageContentHashFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, ordinal, COALESCE(content, ''), + content_length + FROM messages + WHERE session_id = ANY($1) + ORDER BY session_id, ordinal ASC + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var ordinal, contentLength int + var content string + if err := rows.Scan(&sessionID, &ordinal, &content, &contentLength); err != nil { + return err + } + sum := sha256.Sum256([]byte(db.SanitizeUTF8(content))) + b := builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + fmt.Fprintf(b, "%d|%d|%x;", ordinal, contentLength, sum) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = 
b.String() + } + return nil +} + +func loadPushMessageRoleTimeFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, ordinal, role, timestamp + FROM messages + WHERE session_id = ANY($1) + ORDER BY session_id, ordinal ASC + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var ordinal int + var role string + var timestamp sql.NullTime + if err := rows.Scan(&sessionID, &ordinal, &role, &timestamp); err != nil { + return err + } + timestampText := "" + if timestamp.Valid { + timestampText = pgPushTimestampFingerprintText( + FormatISO8601(timestamp.Time), + ) + } + b := builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + fmt.Fprintf( + b, "%d|%d:%s|%d:%s;", + ordinal, len(role), role, len(timestampText), timestampText, + ) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = b.String() + } + return nil +} + +func loadPushMessageFlagFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, ordinal, is_system, has_thinking, has_tool_use, + COALESCE(thinking_text, '') + FROM messages + WHERE session_id = ANY($1) + ORDER BY session_id, ordinal ASC + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var ordinal int + var isSystem, hasThinking, hasToolUse bool + var thinkingText string + if err := rows.Scan( + &sessionID, &ordinal, &isSystem, &hasThinking, &hasToolUse, + &thinkingText, + ); err != nil { + return err + } + sum := sha256.Sum256([]byte(db.SanitizeUTF8(thinkingText))) + b := 
builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + fmt.Fprintf( + b, "%d|%t|%t|%t|%x;", ordinal, isSystem, hasThinking, + hasToolUse, sum, + ) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = b.String() + } + return nil +} + +func loadPushMessageSystemOrdinals( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, + COALESCE( + STRING_AGG(ordinal::text, ',' ORDER BY ordinal) + FILTER (WHERE is_system), + '' + ) + FROM messages + WHERE session_id = ANY($1) + GROUP BY session_id + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var sessionID string + var systemOrdinals string + if err := rows.Scan(&sessionID, &systemOrdinals); err != nil { + return err + } + out[sessionID] = systemOrdinals + } + return rows.Err() +} + +func loadPushMessageTokenFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, ordinal, model, token_usage, context_tokens, + output_tokens, has_context_tokens, has_output_tokens, + claude_message_id, claude_request_id, + source_type, source_subtype, source_uuid, + source_parent_uuid, is_sidechain, is_compact_boundary + FROM messages + WHERE session_id = ANY($1) + ORDER BY session_id, ordinal ASC + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var ordinal, contextTokens, outputTokens int + var model, tokenUsage string + var hasContextTokens, hasOutputTokens bool + var claudeMsgID, claudeReqID string + var srcType, srcSubtype, srcUUID, srcParentUUID string + var isSidechain, isCompactBoundary bool + if err := rows.Scan( + &sessionID, &ordinal, &model, 
&tokenUsage, &contextTokens, + &outputTokens, &hasContextTokens, &hasOutputTokens, + &claudeMsgID, &claudeReqID, + &srcType, &srcSubtype, &srcUUID, &srcParentUUID, + &isSidechain, &isCompactBoundary, + ); err != nil { + return err + } + b := builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + fmt.Fprintf( + b, + "%d|%d:%s|%d:%s|%d|%d|%t|%t|%s|%s|"+ + "%d:%s|%d:%s|%d:%s|%d:%s|%t|%t;", + ordinal, + len(model), model, + len(tokenUsage), tokenUsage, + contextTokens, outputTokens, + hasContextTokens, hasOutputTokens, + claudeMsgID, claudeReqID, + len(srcType), srcType, + len(srcSubtype), srcSubtype, + len(srcUUID), srcUUID, + len(srcParentUUID), srcParentUUID, + isSidechain, isCompactBoundary, + ) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = b.String() + } + return nil +} + +func loadPushToolCallAggregates( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]pushToolCallAggregate, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, + COUNT(*), COALESCE(SUM(result_content_length), 0) + FROM tool_calls + WHERE session_id = ANY($1) + GROUP BY session_id + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var sessionID string + var agg pushToolCallAggregate + var count int64 + if err := rows.Scan(&sessionID, &count, &agg.Sum); err != nil { + return err + } + agg.Count = int(count) + out[sessionID] = agg + } + return rows.Err() +} + +func loadPushToolCallFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, message_ordinal, call_index, tool_name, category, + tool_use_id, COALESCE(input_json, ''), + COALESCE(skill_name, ''), COALESCE(subagent_session_id, ''), + COALESCE(result_content_length, 0), + COALESCE(result_content, '') + FROM tool_calls + WHERE session_id = 
ANY($1) + ORDER BY session_id, message_ordinal ASC, call_index ASC + `, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var messageOrdinal, callIndex, resultContentLength int + var toolName, category, toolUseID, inputJSON string + var skillName, subagentSessionID, resultContent string + if err := rows.Scan( + &sessionID, &messageOrdinal, &callIndex, &toolName, + &category, &toolUseID, &inputJSON, + &skillName, &subagentSessionID, &resultContentLength, + &resultContent, + ); err != nil { + return err + } + b := builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + fmt.Fprintf( + b, + "%d|%d|%d:%s|%d:%s|%d:%s|%d:%s|%d:%s|%d:%s|%d|%d:%s;", + messageOrdinal, callIndex, + len(toolName), toolName, + len(category), category, + len(toolUseID), toolUseID, + len(inputJSON), inputJSON, + len(skillName), skillName, + len(subagentSessionID), subagentSessionID, + resultContentLength, + len(resultContent), resultContent, + ) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = b.String() + } + return nil +} + +func loadPushUsageEventFingerprints( + ctx context.Context, + tx *sql.Tx, + sessionIDs []string, + out map[string]string, +) error { + rows, err := tx.QueryContext(ctx, ` + SELECT session_id, message_ordinal, source, model, + input_tokens, output_tokens, + cache_creation_input_tokens, cache_read_input_tokens, + reasoning_tokens, cost_usd, cost_status, cost_source, + occurred_at, dedup_key + FROM usage_events + WHERE session_id = ANY($1) + ORDER BY session_id, occurred_at NULLS FIRST, id +`, sessionIDs) + if err != nil { + return err + } + defer rows.Close() + + builders := make(map[string]*strings.Builder, len(sessionIDs)) + for rows.Next() { + var sessionID string + var ordinal sql.NullInt64 + var source, model, costStatus, costSource string + 
var inputTokens, outputTokens int + var cacheCreationInputTokens, cacheReadInputTokens int + var reasoningTokens int + var cost sql.NullFloat64 + var occurredAt sql.NullTime + var dedupKey sql.NullString + if err := rows.Scan( + &sessionID, &ordinal, &source, &model, + &inputTokens, &outputTokens, + &cacheCreationInputTokens, &cacheReadInputTokens, + &reasoningTokens, &cost, &costStatus, &costSource, + &occurredAt, &dedupKey, + ); err != nil { + return err + } + b := builders[sessionID] + if b == nil { + b = &strings.Builder{} + builders[sessionID] = b + } + occurred := "" + if occurredAt.Valid { + occurred = FormatISO8601(occurredAt.Time) + } + fmt.Fprintf( + b, + "%t|%d|%d:%s|%d:%s|%d|%d|%d|%d|%d|%t|%g|%d:%s|%d:%s|%d:%s|%d:%s;", + ordinal.Valid, + ordinal.Int64, + len(source), source, + len(model), model, + inputTokens, + outputTokens, + cacheCreationInputTokens, + cacheReadInputTokens, + reasoningTokens, + cost.Valid, + cost.Float64, + len(costStatus), costStatus, + len(costSource), costSource, + len(occurred), occurred, + len(dedupKey.String), dedupKey.String, + ) + } + if err := rows.Err(); err != nil { + return err + } + for sessionID, b := range builders { + out[sessionID] = b.String() + } + return nil +} + +func shouldSkipSessionMessages( + sessionID string, + localCount int, + localFP pushLocalMessageFingerprint, + full bool, + comparisons *pushMessageComparison, +) bool { + if full || localCount == 0 || comparisons == nil { + return false + } + pgAgg := comparisons.MessageAggregates[sessionID] + if pgAgg.Count != localCount || pgAgg.Count == 0 { + return false + } + + return localFP.Sum == pgAgg.Sum && + localFP.Max == pgAgg.Max && + localFP.Min == pgAgg.Min && + localFP.ContentHashFP == comparisons.MessageContentHash[sessionID] && + localFP.RoleTimeFP == comparisons.MessageRoleTime[sessionID] && + localFP.FlagsFP == comparisons.MessageFlags[sessionID] && + localFP.SystemFP == comparisons.MessageSystemOrdinals[sessionID] && + localFP.ToolCallCount == 
comparisons.ToolCallAggregates[sessionID].Count && + localFP.ToolCallSum == comparisons.ToolCallAggregates[sessionID].Sum && + localFP.ToolCallFP == comparisons.ToolCallFingerprint[sessionID] && + localFP.TokenFP == comparisons.MessageTokenFingerprint[sessionID] && + localFP.UsageEventFP == comparisons.UsageEventFingerprint[sessionID] +} diff --git a/internal/postgres/push_fingerprint_test.go b/internal/postgres/push_fingerprint_test.go new file mode 100644 index 000000000..13a5017e7 --- /dev/null +++ b/internal/postgres/push_fingerprint_test.go @@ -0,0 +1,78 @@ +package postgres + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadPushSessionMessageComparisonsNoSessions(t *testing.T) { + comparisons, err := readPushSessionMessageComparisons( + context.Background(), nil, nil, + ) + require.NoError(t, err) + require.NotNil(t, comparisons) + assert.Empty(t, comparisons.MessageAggregates) + assert.Empty(t, comparisons.MessageContentHash) + assert.Empty(t, comparisons.MessageRoleTime) + assert.Empty(t, comparisons.MessageFlags) + assert.Empty(t, comparisons.MessageSystemOrdinals) + assert.Empty(t, comparisons.MessageTokenFingerprint) + assert.Empty(t, comparisons.ToolCallAggregates) + assert.Empty(t, comparisons.ToolCallFingerprint) + assert.Empty(t, comparisons.UsageEventFingerprint) +} + +func TestShouldSkipSessionMessagesGuardsCountAndNilMaps(t *testing.T) { + comparisons := &pushMessageComparison{ + MessageAggregates: map[string]pushMessageAggregate{ + "sess": {Count: 1, Sum: 1, Max: 1, Min: 1}, + }, + MessageContentHash: map[string]string{"sess": ""}, + MessageRoleTime: map[string]string{"sess": ""}, + MessageFlags: map[string]string{"sess": ""}, + MessageSystemOrdinals: map[string]string{"sess": ""}, + MessageTokenFingerprint: map[string]string{"sess": ""}, + ToolCallAggregates: map[string]pushToolCallAggregate{"sess": {}}, + ToolCallFingerprint: map[string]string{"sess": ""}, + 
UsageEventFingerprint: map[string]string{"sess": ""}, + } + localFP := pushLocalMessageFingerprint{Sum: 1, Max: 1, Min: 1} + assert.False(t, shouldSkipSessionMessages( + "sess", 1, localFP, false, nil, + )) + assert.True(t, shouldSkipSessionMessages( + "sess", 1, localFP, false, comparisons, + )) + assert.False(t, shouldSkipSessionMessages( + "sess", 2, localFP, false, comparisons, + )) +} + +func TestComparisonAggregates(t *testing.T) { + msgAgg, toolAgg, ok := comparisonAggregates("missing", nil) + assert.False(t, ok) + assert.Equal(t, pushMessageAggregate{}, msgAgg) + assert.Equal(t, pushToolCallAggregate{}, toolAgg) + + comparisons := &pushMessageComparison{ + MessageAggregates: map[string]pushMessageAggregate{ + "sess": {Count: 3, Sum: 9, Max: 5, Min: 1, SysFP: "0,2"}, + }, + ToolCallAggregates: map[string]pushToolCallAggregate{ + "sess": {Count: 2, Sum: 11}, + }, + } + + msgAgg, toolAgg, ok = comparisonAggregates("sess", comparisons) + require.True(t, ok) + assert.Equal(t, + pushMessageAggregate{ + Count: 3, Sum: 9, Max: 5, Min: 1, SysFP: "0,2", + }, + msgAgg, + ) + assert.Equal(t, pushToolCallAggregate{Count: 2, Sum: 11}, toolAgg) +} diff --git a/internal/postgres/push_test.go b/internal/postgres/push_test.go index b447073d5..0bc8e4a9a 100644 --- a/internal/postgres/push_test.go +++ b/internal/postgres/push_test.go @@ -516,3 +516,64 @@ func TestNilStrSanitizes(t *testing.T) { nul := "\x00" assert.Nil(t, nilStr(&nul), "nilStr(\"\\x00\") should be nil") } + +func TestShouldSkipSessionMessagesInBatchedPush(t *testing.T) { + const sessionID = "sess-batched" + baseComparisons := &pushMessageComparison{ + MessageAggregates: map[string]pushMessageAggregate{ + sessionID: {Count: 2, Sum: 12, Max: 6, Min: 1}, + }, + MessageContentHash: map[string]string{ + sessionID: "abc", + }, + MessageRoleTime: map[string]string{ + sessionID: "role-time", + }, + MessageFlags: map[string]string{ + sessionID: "flags", + }, + MessageSystemOrdinals: map[string]string{ + sessionID: 
"0,1", + }, + MessageTokenFingerprint: map[string]string{ + sessionID: "tokens", + }, + ToolCallAggregates: map[string]pushToolCallAggregate{ + sessionID: {Count: 1, Sum: 99}, + }, + ToolCallFingerprint: map[string]string{ + sessionID: "toolcalls", + }, + UsageEventFingerprint: map[string]string{ + sessionID: "usage", + }, + } + unchangedFP := pushLocalMessageFingerprint{ + Sum: 12, + Max: 6, + Min: 1, + ContentHashFP: "abc", + RoleTimeFP: "role-time", + FlagsFP: "flags", + SystemFP: "0,1", + ToolCallCount: 1, + ToolCallSum: 99, + ToolCallFP: "toolcalls", + TokenFP: "tokens", + UsageEventFP: "usage", + } + + assert.True(t, shouldSkipSessionMessages( + sessionID, 2, unchangedFP, false, baseComparisons, + ), "unchanged sessions should be skipped as unchanged") + + changedFP := unchangedFP + changedFP.ToolCallSum = 100 + assert.False(t, shouldSkipSessionMessages( + sessionID, 2, changedFP, false, baseComparisons, + ), "tool-call sum mismatch should force push") + + assert.False(t, shouldSkipSessionMessages( + sessionID, 2, unchangedFP, true, baseComparisons, + ), "full mode should not skip by fingerprint check") +}