diff --git a/crates/forge_app/src/agent.rs b/crates/forge_app/src/agent.rs index 30ce87198a..d999472f56 100644 --- a/crates/forge_app/src/agent.rs +++ b/crates/forge_app/src/agent.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use forge_config::ForgeConfig; +use forge_config::{ForgeConfig, ModelConfig, Preset}; use forge_domain::{ - Agent, ChatCompletionMessage, Compact, Context, Conversation, Effort, MaxTokens, ModelId, - ProviderId, ReasoningConfig, ResultStream, Temperature, ToolCallContext, ToolCallFull, - ToolResult, TopK, TopP, + Agent, ChatCompletionMessage, Compact, Context, Conversation, Effort, MaxTokens, + ModelId, ProviderId, ReasoningConfig, ResultStream, Temperature, ToolCallContext, + ToolCallFull, ToolResult, TopK, TopP, }; use merge::Merge; @@ -92,43 +92,123 @@ impl AgentExt for Agent { fn apply_config(self, config: &ForgeConfig) -> Agent { let mut agent = self; - if let Some(temperature) = config - .temperature - .and_then(|d| Temperature::new(d.0 as f32).ok()) - { - agent.temperature = Some(temperature); + // Resolve the agent-specific ModelConfig from ForgeConfig. + let agent_model_config: Option<&ModelConfig> = match agent.id.as_str() { + "forge" => config.agent_forge.as_ref(), + "muse" => config.agent_muse.as_ref(), + "sage" => config.agent_sage.as_ref(), + _ => None, + }; + + // Apply model/provider from agent-specific config. + if let Some(mc) = agent_model_config { + if let Some(ref model_id) = mc.model_id { + agent.model = ModelId::new(model_id); + } + if let Some(ref provider_id) = mc.provider_id { + agent.provider = ProviderId::from(provider_id.clone()); + } } - if let Some(top_p) = config.top_p.and_then(|d| TopP::new(d.0 as f32).ok()) { - agent.top_p = Some(top_p); + // Resolve the preset: agent-specific preset_id takes priority over + // nothing (there is no global preset_id on ForgeConfig). + let preset: Option<&Preset> = agent_model_config + .and_then(|mc| mc.preset_id.as_deref()) + .and_then(|id| config.presets.get(id)); + + // Helper: convert a config ReasoningConfig to a domain ReasoningConfig. + let to_domain_reasoning = + |r: &forge_config::ReasoningConfig| -> ReasoningConfig { + use forge_config::Effort as ConfigEffort; + ReasoningConfig { + effort: r.effort.as_ref().map(|e| match e { + ConfigEffort::None => Effort::None, + ConfigEffort::Minimal => Effort::Minimal, + ConfigEffort::Low => Effort::Low, + ConfigEffort::Medium => Effort::Medium, + ConfigEffort::High => Effort::High, + ConfigEffort::XHigh => Effort::XHigh, + ConfigEffort::Max => Effort::Max, + }), + max_tokens: r.max_tokens, + exclude: r.exclude, + enabled: r.enabled, + } + }; + + // --- Apply LLM settings in priority order (lowest → highest) --- + // 1. Config global settings + // 2. Preset settings (from agent-specific ModelConfig's preset_id) + // 3. Agent's own values (never overwritten) + + // temperature + if agent.temperature.is_none() { + let value = preset + .map(|p| p.temperature) + .or(config.temperature) + .and_then(|d| Temperature::new(d.0 as f32).ok()); + if let Some(v) = value { + agent.temperature = Some(v); + } } - if let Some(top_k) = config.top_k.and_then(|k| TopK::new(k).ok()) { - agent.top_k = Some(top_k); + // top_p + if agent.top_p.is_none() { + let value = preset + .map(|p| p.top_p) + .or(config.top_p) + .and_then(|d| TopP::new(d.0 as f32).ok()); + if let Some(v) = value { + agent.top_p = Some(v); + } } - if let Some(max_tokens) = config.max_tokens.and_then(|m| MaxTokens::new(m).ok()) { - agent.max_tokens = Some(max_tokens); + // top_k + if agent.top_k.is_none() { + let value = preset + .map(|p| Some(p.top_k)) + .unwrap_or(config.top_k) + .and_then(|k| TopK::new(k).ok()); + if let Some(v) = value { + agent.top_k = Some(v); + } } - if agent.max_tool_failure_per_turn.is_none() - && let Some(max_tool_failure_per_turn) = config.max_tool_failure_per_turn - { - agent.max_tool_failure_per_turn = Some(max_tool_failure_per_turn); + // max_tokens + if agent.max_tokens.is_none() { + let value = preset + .and_then(|p| p.max_tokens) + .or(config.max_tokens) + .and_then(|m| MaxTokens::new(m).ok()); + if let Some(v) = value { + agent.max_tokens = Some(v); + } } - agent.tool_supported = Some(config.tool_supported); + // tool_supported: preset > config global; agent's own value wins when set + if agent.tool_supported.is_none() { + let value = preset + .map(|p| p.tool_supported) + .unwrap_or(config.tool_supported); + agent.tool_supported = Some(value); + } + + // max_tool_failure_per_turn: agent's own value wins + if agent.max_tool_failure_per_turn.is_none() { + if let Some(v) = config.max_tool_failure_per_turn { + agent.max_tool_failure_per_turn = Some(v); + } + } - if agent.max_requests_per_turn.is_none() - && let Some(max_requests_per_turn) = config.max_requests_per_turn - { - agent.max_requests_per_turn = Some(max_requests_per_turn); + // max_requests_per_turn: agent's own value wins + if agent.max_requests_per_turn.is_none() { + if let Some(v) = config.max_requests_per_turn { + agent.max_requests_per_turn = Some(v); + } } - // Apply workflow compact configuration to agents + // compact: merge workflow config into agent (agent fields take priority) if let Some(ref workflow_compact) = config.compact { - // Convert forge_config::Compact to forge_domain::Compact, then merge. - // Agent settings take priority over workflow settings. let mut merged_compact = Compact { retention_window: workflow_compact.retention_window, eviction_window: workflow_compact.eviction_window.value(), @@ -143,27 +223,15 @@ impl AgentExt for Agent { agent.compact = merged_compact; } - // Apply workflow reasoning configuration to agents. - // Agent-level fields take priority; config fills in any unset fields. - if let Some(ref config_reasoning) = config.reasoning { - use forge_config::Effort as ConfigEffort; - let config_as_domain = ReasoningConfig { - effort: config_reasoning.effort.as_ref().map(|e| match e { - ConfigEffort::None => Effort::None, - ConfigEffort::Minimal => Effort::Minimal, - ConfigEffort::Low => Effort::Low, - ConfigEffort::Medium => Effort::Medium, - ConfigEffort::High => Effort::High, - ConfigEffort::XHigh => Effort::XHigh, - ConfigEffort::Max => Effort::Max, - }), - max_tokens: config_reasoning.max_tokens, - exclude: config_reasoning.exclude, - enabled: config_reasoning.enabled, - }; - // Start from the agent's own settings and fill unset fields from config. + // reasoning: preset > config global; agent fields take highest priority + let base_reasoning = preset + .and_then(|p| p.reasoning.as_ref()) + .or(config.reasoning.as_ref()) + .map(to_domain_reasoning); + + if let Some(base) = base_reasoning { let mut merged = agent.reasoning.clone().unwrap_or_default(); - merged.merge(config_as_domain); + merged.merge(base); agent.reasoning = Some(merged); } @@ -173,8 +241,10 @@ impl AgentExt for Agent { #[cfg(test)] mod tests { - use forge_config::{Effort as ConfigEffort, ReasoningConfig as ConfigReasoningConfig}; - use forge_domain::{AgentId, Effort, ModelId, ProviderId, ReasoningConfig}; + use forge_config::{ + Decimal, Effort as ConfigEffort, ModelConfig, Preset, ReasoningConfig as ConfigReasoningConfig, + }; + use forge_domain::{AgentId, Effort, ModelId, ProviderId, ReasoningConfig, Temperature, TopP}; use pretty_assertions::assert_eq; use super::*; @@ -187,6 +257,14 @@ mod tests { ) } + fn fixture_forge_agent() -> Agent { + Agent::new( + AgentId::FORGE, + ProviderId::ANTHROPIC, + ModelId::new("claude-3-5-sonnet-20241022"), + ) + } + /// When the agent has no reasoning config, the config's reasoning is /// applied in full. #[test] @@ -233,4 +311,103 @@ mod tests { assert_eq!(actual, expected); } + + /// agent_forge config overrides model and provider on a FORGE agent. + #[test] + fn test_agent_specific_model_and_provider_applied() { + let config = ForgeConfig::default().agent_forge( + ModelConfig::default() + .model_id("gpt-4o") + .provider_id("openai"), + ); + + let actual = fixture_forge_agent().apply_config(&config); + + assert_eq!(actual.model, ModelId::new("gpt-4o")); + assert_eq!(actual.provider, ProviderId::from("openai".to_string())); + } + + /// agent_forge config does not affect a non-FORGE agent. + #[test] + fn test_agent_specific_config_not_applied_to_other_agents() { + let config = ForgeConfig::default().agent_forge( + ModelConfig::default() + .model_id("gpt-4o") + .provider_id("openai"), + ); + + let actual = fixture_agent().apply_config(&config); + + // Model and provider remain unchanged. + assert_eq!(actual.model, ModelId::new("claude-3-5-sonnet-20241022")); + assert_eq!(actual.provider, ProviderId::ANTHROPIC); + } + + /// Preset LLM settings are applied when the agent-specific ModelConfig + /// references a preset_id that exists in config.presets. + #[test] + fn test_preset_settings_applied_via_agent_model_config() { + let mut presets = std::collections::HashMap::new(); + presets.insert( + "fast".to_string(), + Preset { temperature: Decimal(0.2), top_p: Decimal(0.8), ..Default::default() }, + ); + + let config = ForgeConfig { + presets, + agent_forge: Some(ModelConfig::default().preset_id("fast")), + ..Default::default() + }; + + let actual = fixture_forge_agent().apply_config(&config); + + assert_eq!(actual.temperature, Temperature::new(0.2).ok()); + assert_eq!(actual.top_p, TopP::new(0.8).ok()); + } + + /// Preset settings take priority over config global settings. + #[test] + fn test_preset_takes_priority_over_global_config() { + let mut presets = std::collections::HashMap::new(); + presets.insert( + "precise".to_string(), + Preset { temperature: Decimal(0.1), ..Default::default() }, + ); + + let config = ForgeConfig { + presets, + // Global temperature is higher; preset should win. + temperature: Some(Decimal(1.0)), + agent_forge: Some(ModelConfig::default().preset_id("precise")), + ..Default::default() + }; + + let actual = fixture_forge_agent().apply_config(&config); + + assert_eq!(actual.temperature, Temperature::new(0.1).ok()); + } + + /// Agent's own temperature takes priority over both preset and global config. + #[test] + fn test_agent_temperature_takes_priority_over_preset_and_global() { + let mut presets = std::collections::HashMap::new(); + presets.insert( + "fast".to_string(), + Preset { temperature: Decimal(0.2), ..Default::default() }, + ); + + let config = ForgeConfig { + presets, + temperature: Some(Decimal(1.0)), + agent_forge: Some(ModelConfig::default().preset_id("fast")), + ..Default::default() + }; + + let agent = + fixture_forge_agent().temperature(Temperature::new(0.5).unwrap()); + + let actual = agent.apply_config(&config); + + assert_eq!(actual.temperature, Temperature::new(0.5).ok()); + } } diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs index 2e2af49569..040d66b33c 100644 --- a/crates/forge_config/src/config.rs +++ b/crates/forge_config/src/config.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::PathBuf; use derive_setters::Setters; @@ -8,7 +9,8 @@ use serde::{Deserialize, Serialize}; use crate::reader::ConfigReader; use crate::writer::ConfigWriter; use crate::{ - AutoDumpFormat, Compact, Decimal, HttpConfig, ModelConfig, ReasoningConfig, RetryConfig, Update, + AutoDumpFormat, Compact, Decimal, HttpConfig, ModelConfig, Preset, ReasoningConfig, RetryConfig, + Update, }; /// Top-level Forge configuration merged from all sources (defaults, file, @@ -116,6 +118,15 @@ pub struct ForgeConfig { /// generation. #[serde(default, skip_serializing_if = "Option::is_none")] pub suggest: Option, + /// Model and provider configuration used for the Forge agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub agent_forge: Option, + /// Model and provider configuration used for the Muse agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub agent_muse: Option, + /// Model and provider configuration used for the Sage agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub agent_sage: Option, // --- Workflow fields --- /// Configuration for automatic Forge updates. @@ -170,6 +181,11 @@ pub struct ForgeConfig { /// token budget, and visibility of the model's thinking process. #[serde(default, skip_serializing_if = "Option::is_none")] pub reasoning: Option, + + /// Named LLM inference presets that can be referenced by id from a model + /// configuration. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub presets: HashMap, } impl ForgeConfig { diff --git a/crates/forge_config/src/legacy.rs b/crates/forge_config/src/legacy.rs index 22f35ce52b..af36d0f2e9 100644 --- a/crates/forge_config/src/legacy.rs +++ b/crates/forge_config/src/legacy.rs @@ -59,16 +59,16 @@ impl LegacyConfig { fn into_forge_config(self) -> ForgeConfig { let session = self.provider.as_deref().map(|provider_id| { let model_id = self.model.get(provider_id).cloned(); - ModelConfig { provider_id: Some(provider_id.to_string()), model_id } + ModelConfig { provider_id: Some(provider_id.to_string()), model_id, preset_id: None } }); let commit = self .commit - .map(|c| ModelConfig { provider_id: c.provider, model_id: c.model }); + .map(|c| ModelConfig { provider_id: c.provider, model_id: c.model, preset_id: None }); let suggest = self .suggest - .map(|s| ModelConfig { provider_id: s.provider, model_id: s.model }); + .map(|s| ModelConfig { provider_id: s.provider, model_id: s.model, preset_id: None }); ForgeConfig { session, commit, suggest, ..Default::default() } } diff --git a/crates/forge_config/src/lib.rs b/crates/forge_config/src/lib.rs index cc253277e4..3196c06a08 100644 --- a/crates/forge_config/src/lib.rs +++ b/crates/forge_config/src/lib.rs @@ -7,6 +7,7 @@ mod http; mod legacy; mod model; mod percentage; +mod preset; mod reader; mod reasoning; mod retry; @@ -20,6 +21,7 @@ pub use error::Error; pub use http::*; pub use model::*; pub use percentage::*; +pub use preset::*; pub use reader::*; pub use reasoning::*; pub use retry::*; diff --git a/crates/forge_config/src/model.rs b/crates/forge_config/src/model.rs index c993222700..bbb22fbdc4 100644 --- a/crates/forge_config/src/model.rs +++ b/crates/forge_config/src/model.rs @@ -18,4 +18,8 @@ pub struct ModelConfig { pub provider_id: Option, /// The model to use for this operation. pub model_id: Option, + /// The id of a preset defined in [`ForgeConfig::presets`] whose LLM + /// settings are applied to this operation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub preset_id: Option, } diff --git a/crates/forge_config/src/preset.rs b/crates/forge_config/src/preset.rs new file mode 100644 index 0000000000..c9cb65b853 --- /dev/null +++ b/crates/forge_config/src/preset.rs @@ -0,0 +1,40 @@ +use derive_setters::Setters; +use fake::Dummy; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use crate::{Decimal, ReasoningConfig}; + +/// A named collection of LLM inference settings that can be referenced by id +/// from a model configuration. +#[derive(Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, Dummy)] +#[serde(rename_all = "snake_case")] +pub struct Preset { + /// Output randomness for the model; lower values are deterministic, higher + /// values are creative (0.0–2.0). + #[serde(default)] + pub temperature: Decimal, + + /// Nucleus sampling threshold; limits token selection to the top cumulative + /// probability mass (0.0–1.0). + #[serde(default)] + pub top_p: Decimal, + + /// Top-k vocabulary cutoff; restricts sampling to the k + /// highest-probability tokens (1–1000). + #[serde(default)] + pub top_k: u32, + + /// Maximum tokens the model may generate per response (1–100,000). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// Reasoning configuration; controls effort level, token budget, and + /// visibility of the model's thinking process. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + + /// Whether tool use is supported; when false, all tool calls are disabled. + #[serde(default)] + pub tool_supported: bool, +} diff --git a/crates/forge_infra/src/env.rs b/crates/forge_infra/src/env.rs index 902198b447..9e3a11191d 100644 --- a/crates/forge_infra/src/env.rs +++ b/crates/forge_infra/src/env.rs @@ -48,7 +48,7 @@ fn apply_config_op(fc: &mut ForgeConfig, op: ConfigOperation) { session.model_id = Some(mid_str); } else { fc.session = - Some(ModelConfig { provider_id: Some(pid_str), model_id: Some(mid_str) }); + Some(ModelConfig { provider_id: Some(pid_str), model_id: Some(mid_str), preset_id: None }); } } ConfigOperation::SetCommitConfig(commit) => { @@ -59,12 +59,14 @@ fn apply_config_op(fc: &mut ForgeConfig, op: ConfigOperation) { .map(|(pid, mid)| ModelConfig { provider_id: Some(pid.as_ref().to_string()), model_id: Some(mid.to_string()), + preset_id: None, }); } ConfigOperation::SetSuggestConfig(suggest) => { fc.suggest = Some(ModelConfig { provider_id: Some(suggest.provider.as_ref().to_string()), model_id: Some(suggest.model.to_string()), + preset_id: None, }); } ConfigOperation::SetReasoningEffort(effort) => {