Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ scripts/download-stats.sh
# IDE files
.vscode/
.cursor/
.codex
208 changes: 198 additions & 10 deletions agent/context/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package context

import (
"fmt"
"strings"
"sync"
"time"

"github.com/mindspore-lab/mindspore-cli/configs"
"github.com/mindspore-lab/mindspore-cli/integrations/llm"
)

Expand All @@ -23,8 +25,8 @@ type ManagerConfig struct {
// DefaultManagerConfig 返回默认配置
func DefaultManagerConfig() ManagerConfig {
return ManagerConfig{
ContextWindow: 200000,
ReserveTokens: 4000,
ContextWindow: configs.DefaultContextWindow,
ReserveTokens: configs.DefaultReserveTokens(configs.DefaultContextWindow),
CompactionThreshold: 0.9,
EnableSmartCompact: true,
CompactStrategy: CompactStrategyHybrid,
Expand All @@ -42,6 +44,13 @@ type Manager struct {
system *llm.Message
usage TokenUsage

exactSnapshotTokens int
exactSnapshotEstimate int
exactSnapshotProvider string
exactSnapshotScope ProviderTokenScope
exactSnapshotUsage llm.Usage
hasExactSnapshot bool

// 增强组件
tokenizer *Tokenizer
compactor *Compactor
Expand All @@ -59,6 +68,36 @@ type TokenUsage struct {
Available int
}

type TokenUsageSource string
type ProviderTokenScope string

const (
TokenUsageSourceLocalEstimate TokenUsageSource = "local_estimate"
TokenUsageSourceProvider TokenUsageSource = "provider_snapshot"

ProviderTokenScopePrompt ProviderTokenScope = "prompt"
ProviderTokenScopeTotal ProviderTokenScope = "total"
)

type TokenUsageDetails struct {
TokenUsage
Source TokenUsageSource
Provider string
ProviderSnapshotTokens int
ProviderTokenScope ProviderTokenScope
ProviderUsage llm.Usage
LocalEstimatedTotal int
LocalDelta int
}

type ProviderUsageSnapshot struct {
Provider string
TokenScope ProviderTokenScope
Tokens int
LocalDelta int
Usage llm.Usage
}

// Stats 上下文统计
type Stats struct {
MessageCount int
Expand All @@ -71,10 +110,10 @@ type Stats struct {
// NewManager creates a new context manager.
func NewManager(cfg ManagerConfig) *Manager {
if cfg.ContextWindow == 0 {
cfg.ContextWindow = 200000
cfg.ContextWindow = configs.DefaultContextWindow
}
if cfg.ReserveTokens == 0 {
cfg.ReserveTokens = 4000
cfg.ReserveTokens = configs.DefaultReserveTokens(cfg.ContextWindow)
}
if cfg.CompactionThreshold == 0 {
cfg.CompactionThreshold = 0.9
Expand Down Expand Up @@ -108,6 +147,7 @@ func (m *Manager) SetSystemPrompt(content string) {

msg := llm.NewSystemMessage(content)
m.system = &msg
m.clearProviderTokenUsageLocked()

m.recalculateUsage()
}
Expand Down Expand Up @@ -190,6 +230,7 @@ func (m *Manager) SetNonSystemMessages(msgs []llm.Message) {

m.messages = make([]llm.Message, len(msgs))
copy(m.messages, msgs)
m.clearProviderTokenUsageLocked()

m.stats.MessageCount = len(m.messages)
m.stats.ToolCallCount = 0
Expand All @@ -208,6 +249,7 @@ func (m *Manager) Clear() {
defer m.mu.Unlock()

m.messages = make([]llm.Message, 0)
m.clearProviderTokenUsageLocked()
m.recalculateUsage()
}

Expand All @@ -216,7 +258,7 @@ func (m *Manager) Compact() error {
m.mu.Lock()
defer m.mu.Unlock()

currentTokens := m.totalTokensLocked()
currentTokens := m.currentTokensLocked()
if currentTokens == 0 {
return nil
}
Expand Down Expand Up @@ -261,6 +303,128 @@ func (m *Manager) SetContextWindowLimits(contextWindow, reserveTokens int) error
return nil
}

// SetPromptTokenUsage records provider-reported prompt tokens for the current context.
// Values <= 0 clear the provider usage and fall back to local estimation.
func (m *Manager) SetPromptTokenUsage(provider string, promptTokens int) {
m.setProviderTokenUsage(provider, promptTokens, ProviderTokenScopePrompt, llm.Usage{
PromptTokens: promptTokens,
})
}

// SetTotalTokenUsage records provider-reported total tokens for the current context.
// Values <= 0 clear the provider usage and fall back to local estimation.
func (m *Manager) SetTotalTokenUsage(provider string, totalTokens int) {
m.setProviderTokenUsage(provider, totalTokens, ProviderTokenScopeTotal, llm.Usage{
TotalTokens: totalTokens,
})
}

// SetProviderTokenUsage records the best provider-reported usage snapshot for the current context.
// When total tokens are available they are preferred over prompt-only snapshots.
func (m *Manager) SetProviderTokenUsage(provider string, usage llm.Usage) {
switch {
case usage.TotalTokens > 0:
m.setProviderTokenUsage(provider, usage.TotalTokens, ProviderTokenScopeTotal, usage)
case usage.PromptTokens > 0:
m.setProviderTokenUsage(provider, usage.PromptTokens, ProviderTokenScopePrompt, usage)
default:
m.setProviderTokenUsage(provider, 0, "", llm.Usage{})
}
}

func (m *Manager) setProviderTokenUsage(provider string, tokens int, scope ProviderTokenScope, usage llm.Usage) {
m.mu.Lock()
defer m.mu.Unlock()

if tokens <= 0 {
m.clearProviderTokenUsageLocked()
} else {
m.exactSnapshotTokens = tokens
m.exactSnapshotEstimate = m.totalTokensLocked()
m.exactSnapshotProvider = strings.TrimSpace(provider)
m.exactSnapshotScope = scope
m.exactSnapshotUsage = usage.Clone()
m.hasExactSnapshot = true
}

m.recalculateUsage()
}

// RestoreProviderUsageSnapshot restores a persisted provider snapshot for the current context.
// The caller must restore the matching message set before invoking this method.
func (m *Manager) RestoreProviderUsageSnapshot(snapshot ProviderUsageSnapshot) {
m.mu.Lock()
defer m.mu.Unlock()

if snapshot.Tokens <= 0 {
m.clearProviderTokenUsageLocked()
m.recalculateUsage()
return
}

localTotal := m.totalTokensLocked()
localDelta := snapshot.LocalDelta
if localDelta < 0 {
localDelta = 0
}
exactEstimate := localTotal - localDelta
if exactEstimate < 0 {
exactEstimate = 0
}

scope := snapshot.TokenScope
if scope != ProviderTokenScopeTotal {
scope = ProviderTokenScopePrompt
}

m.exactSnapshotTokens = snapshot.Tokens
m.exactSnapshotEstimate = exactEstimate
m.exactSnapshotProvider = strings.TrimSpace(snapshot.Provider)
m.exactSnapshotScope = scope
usage := snapshot.Usage.Clone()
switch scope {
case ProviderTokenScopeTotal:
if usage.TotalTokens <= 0 {
usage.TotalTokens = snapshot.Tokens
}
default:
if usage.PromptTokens <= 0 {
usage.PromptTokens = snapshot.Tokens
}
}
m.exactSnapshotUsage = usage
m.hasExactSnapshot = true
m.recalculateUsage()
}

// TokenUsageDetails returns current token usage together with its source metadata.
func (m *Manager) TokenUsageDetails() TokenUsageDetails {
m.mu.RLock()
defer m.mu.RUnlock()

localEstimatedTotal := m.totalTokensLocked()
details := TokenUsageDetails{
TokenUsage: m.usage,
Source: TokenUsageSourceLocalEstimate,
LocalEstimatedTotal: localEstimatedTotal,
}
if !m.hasExactSnapshot {
return details
}

localDelta := localEstimatedTotal - m.exactSnapshotEstimate
if localDelta < 0 {
localDelta = 0
}
details.Source = TokenUsageSourceProvider
details.Provider = m.exactSnapshotProvider
details.ProviderSnapshotTokens = m.exactSnapshotTokens
details.ProviderTokenScope = m.exactSnapshotScope
details.ProviderUsage = m.exactSnapshotUsage.Clone()
details.LocalDelta = localDelta
return details
}

// EstimateTokens estimates token count for messages.
func (m *Manager) EstimateTokens(msgs []llm.Message) int {
return m.tokenizer.EstimateMessages(msgs)
Expand All @@ -271,7 +435,7 @@ func (m *Manager) IsWithinBudget(msg llm.Message) bool {
m.mu.RLock()
defer m.mu.RUnlock()

estimated := m.totalTokensLocked() + m.tokenizer.EstimateMessage(msg)
estimated := m.currentTokensLocked() + m.tokenizer.EstimateMessage(msg)
return estimated <= m.maxUsableTokensLocked()
}

Expand Down Expand Up @@ -318,21 +482,21 @@ func (m *Manager) GetDetailedStats() map[string]any {
// shouldCompactLocked checks if compaction is needed (must hold lock).
func (m *Manager) shouldCompactLocked(additionalTokens int) bool {
threshold := m.compactionThresholdPercentLocked()
estimatedTokens := m.totalTokensLocked() + additionalTokens
estimatedTokens := m.currentTokensLocked() + additionalTokens
return float64(estimatedTokens) >= float64(m.config.ContextWindow)*(threshold/100.0)
}

// compactLocked compacts the context (must hold lock).
func (m *Manager) compactLocked() error {
currentTokens := m.totalTokensLocked()
currentTokens := m.currentTokensLocked()
if currentTokens == 0 || !m.shouldCompactLocked(0) {
return nil
}
return m.compactToTargetLocked(m.compactionTargetTokensLocked())
}

func (m *Manager) compactToTargetLocked(targetTokens int) error {
currentTokens := m.totalTokensLocked()
currentTokens := m.currentTokensLocked()
if currentTokens == 0 {
return nil
}
Expand Down Expand Up @@ -364,6 +528,7 @@ func (m *Manager) compactToTargetLocked(targetTokens int) error {
}

m.messages = compacted
m.clearProviderTokenUsageLocked()
m.stats.CompactCount++
now := time.Now()
m.stats.LastCompactAt = &now
Expand Down Expand Up @@ -401,7 +566,7 @@ func (m *Manager) compactionTargetTokensLocked() int {

// recalculateUsage recalculates token usage (must hold lock).
func (m *Manager) recalculateUsage() {
total := m.totalTokensLocked()
total := m.currentTokensLocked()

m.usage = TokenUsage{
Current: total,
Expand All @@ -421,6 +586,28 @@ func (m *Manager) totalTokensLocked() int {
return total
}

func (m *Manager) currentTokensLocked() int {
total := m.totalTokensLocked()
if !m.hasExactSnapshot {
return total
}

delta := total - m.exactSnapshotEstimate
if delta < 0 {
return total
}
return m.exactSnapshotTokens + delta
}

func (m *Manager) clearProviderTokenUsageLocked() {
m.exactSnapshotTokens = 0
m.exactSnapshotEstimate = 0
m.exactSnapshotProvider = ""
m.exactSnapshotScope = ""
m.exactSnapshotUsage = llm.Usage{}
m.hasExactSnapshot = false
}

func (m *Manager) maxUsableTokensLocked() int {
return m.config.ContextWindow - m.config.ReserveTokens
}
Expand Down Expand Up @@ -472,6 +659,7 @@ func (m *Manager) TruncateTo(count int) {
}

m.messages = keepRecentMessages(m.messages, count)
m.clearProviderTokenUsageLocked()
m.recalculateUsage()
}

Expand Down
Loading