Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ pub trait API: Sync + Send {
/// Sets the default provider for all the agents
async fn set_default_provider(&self, provider_id: ProviderId) -> anyhow::Result<()>;

/// Updates the caller's default provider and model together, ensuring all
/// commands resolve a consistent pair without requiring a follow-up model
/// selection call.
async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()>;

/// Retrieves information about the currently authenticated user
async fn user_info(&self) -> anyhow::Result<Option<User>>;

Expand Down
13 changes: 13 additions & 0 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,19 @@ impl<A: Services, F: CommandInfra + EnvironmentInfra + SkillRepository + GrpcInf
result
}

async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
let result = self
.services
.set_default_provider_and_model(provider_id, model)
.await;
let _ = self.services.reload_agents().await;
result
}

async fn get_commit_config(&self) -> anyhow::Result<Option<CommitConfig>> {
self.services.get_commit_config().await
}
Expand Down
8 changes: 8 additions & 0 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@ mod tests {
Ok(())
}

async fn set_default_provider_and_model(
&self,
_provider_id: ProviderId,
_model: ModelId,
) -> anyhow::Result<()> {
Ok(())
}

async fn get_commit_config(&self) -> Result<Option<forge_domain::CommitConfig>> {
Ok(None)
}
Expand Down
18 changes: 18 additions & 0 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ pub trait AppConfigService: Send + Sync {
/// Returns an error if no default provider is configured.
async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()>;

/// Sets the user's default provider and default model in a single atomic
/// update so the persisted configuration never stores a mismatched pair.
async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()>;

/// Gets the commit configuration (provider and model for commit message
/// generation).
async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>>;
Expand Down Expand Up @@ -971,6 +979,16 @@ impl<I: Services> AppConfigService for I {
self.config_service().get_provider_model(provider_id).await
}

async fn set_default_provider_and_model(
&self,
provider_id: forge_domain::ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
self.config_service()
.set_default_provider_and_model(provider_id, model)
.await
}

async fn set_default_model(&self, model: ModelId) -> anyhow::Result<()> {
self.config_service().set_default_model(model).await
}
Expand Down
112 changes: 64 additions & 48 deletions crates/forge_main/src/ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use convert_case::{Case, Casing};
use forge_api::{
API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest,
ChatResponse, CodeRequest, Conversation, ConversationId, DeviceCodeRequest, Event,
InterruptionReason, Model, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt,
};
use forge_app::utils::{format_display_path, truncate_key};
use forge_app::{CommitResult, ToolResolver};
Expand Down Expand Up @@ -127,14 +127,6 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.spinner.ewrite_ln(title)
}

/// Retrieve available models
async fn get_models(&mut self) -> Result<Vec<Model>> {
self.spinner.start(Some("Loading"))?;
let models = self.api.get_models().await?;
self.spinner.stop(None)?;
Ok(models)
}

/// Helper to get provider for an optional agent, defaulting to the current
/// active agent's provider
async fn get_provider(&self, agent_id: Option<AgentId>) -> Result<Provider<Url>> {
Expand Down Expand Up @@ -649,6 +641,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
return Ok(());
}
TopLevelCommand::Commit(commit_group) => {
self.init_state(false).await?;
let preview = commit_group.preview;
let result = self.handle_commit_command(commit_group).await?;
if preview {
Expand Down Expand Up @@ -1899,7 +1892,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
self.on_custom_event(event.into()).await?;
}
SlashCommand::Model => {
self.on_model_selection(None).await?;
self.on_model_selection(None, None).await?;
}
SlashCommand::Provider => {
self.on_provider_selection().await?;
Expand Down Expand Up @@ -2074,15 +2067,11 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
provider_filter: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Check if provider is set otherwise first ask to select a provider
if self.api.get_default_provider().await.is_err() {
self.on_provider_selection().await?;

// Check if a model was already selected during provider activation
// Return None to signal the model selection is complete and message was already
// printed
if self.api.get_default_model().await.is_some() {
if provider_filter.is_none() && self.api.get_default_provider().await.is_err() {
if !self.on_provider_selection().await? {
return Ok(None);
}
return Ok(None);
}

// Fetch models from ALL configured providers (matches shell plugin's
Expand Down Expand Up @@ -2713,6 +2702,7 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
async fn on_model_selection(
&mut self,
provider_filter: Option<ProviderId>,
provider_to_activate: Option<ProviderId>,
) -> Result<Option<ModelId>> {
// Select a model
let model_option = self.select_model(provider_filter).await?;
Expand All @@ -2723,8 +2713,14 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
None => return Ok(None),
};

// Update the operating model via API
self.api.set_default_model(model.clone()).await?;
// If we have a provider to activate, write both atomically
if let Some(provider_id) = provider_to_activate {
self.api
.set_default_provider_and_model(provider_id, model.clone())
.await?;
} else {
self.api.set_default_model(model.clone()).await?;
}

// Update the UI state with the new model
self.update_model(Some(model.clone()));
Expand All @@ -2734,15 +2730,18 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
Ok(Some(model))
}

async fn on_provider_selection(&mut self) -> Result<()> {
async fn on_provider_selection(&mut self) -> Result<bool> {
// Select a provider
// If no provider was selected (user canceled), return early
let any_provider = match self.select_provider().await? {
Some(provider) => provider,
None => return Ok(()),
None => return Ok(false),
};

self.activate_provider(any_provider).await
self.activate_provider(any_provider).await?;
// Check if provider was actually saved — if user cancelled model selection
// inside activate_provider, nothing was written
Ok(self.api.get_default_provider().await.is_ok())
}

/// Activates a provider by configuring it if needed, setting it as default,
Expand Down Expand Up @@ -2789,21 +2788,19 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
provider: Provider<Url>,
model: Option<ModelId>,
) -> Result<()> {
// Set the provider via API
self.api.set_default_provider(provider.id.clone()).await?;

self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;

// If a model was pre-selected (e.g. from :model), validate and set it
// directly without prompting
if let Some(model) = model {
let model_id = self
.validate_model(model.as_str(), Some(&provider.id))
.await?;
self.api.set_default_model(model_id.clone()).await?;
self.api
.set_default_provider_and_model(provider.id.clone(), model_id.clone())
.await?;
self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;
self.writeln_title(
TitleFormat::action(model_id.as_str()).sub_title("is now the default model"),
)?;
Expand All @@ -2812,18 +2809,37 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {

// Check if the current model is available for the new provider
let current_model = self.api.get_default_model().await;
if let Some(current_model) = current_model {
let models = self.get_models().await?;
let model_available = models.iter().any(|m| m.id == current_model);
let needs_model_selection = match current_model {
None => true,
Some(current_model) => {
let provider_models = self.api.get_all_provider_models().await?;
let model_available = provider_models
.iter()
.find(|pm| pm.provider_id == provider.id)
.map(|pm| pm.models.iter().any(|m| m.id == current_model))
.unwrap_or(false);
!model_available
}
};

if !model_available {
// Prompt user to select a new model, scoped to the activated provider
self.writeln_title(TitleFormat::info("Please select a new model"))?;
self.on_model_selection(Some(provider.id.clone())).await?;
if needs_model_selection {
self.writeln_title(TitleFormat::info("Please select a new model"))?;
let selected = self
.on_model_selection(Some(provider.id.clone()), Some(provider.id.clone()))
.await?;
if selected.is_none() {
// User cancelled — preserve existing config untouched
return Ok(());
}
} else {
// No model set, select one now scoped to the activated provider
self.on_model_selection(Some(provider.id.clone())).await?;
// Set the provider via API
// Only reaches here if model is confirmed — safe to write provider now
self.api.set_default_provider(provider.id.clone()).await?;

self.writeln_title(
TitleFormat::action(format!("{}", provider.id))
.sub_title("is now the default provider"),
)?;
}

Ok(())
Expand Down Expand Up @@ -2931,17 +2947,17 @@ impl<A: API + ConsoleWriter + 'static, F: Fn() -> A + Send + Sync> UI<A, F> {
// Ensure we have a model selected before proceeding with initialization
let active_agent = self.api.get_active_agent().await;

let mut operating_model = self.get_agent_model(active_agent.clone()).await;
if operating_model.is_none() {
// Use the model returned from selection instead of re-fetching
operating_model = self.on_model_selection(None).await?;
}

// Validate provider is configured before loading agents
// If provider is set in config but not configured (no credentials), prompt user
// to login
if self.api.get_default_provider().await.is_err() {
self.on_provider_selection().await?;
if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? {
return Ok(());
}

let mut operating_model = self.get_agent_model(active_agent.clone()).await;
if operating_model.is_none() {
// Use the model returned from selection instead of re-fetching
operating_model = self.on_model_selection(None, None).await?;
}

if first {
Expand Down
9 changes: 9 additions & 0 deletions crates/forge_services/src/app_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ impl<F: ProviderRepository + EnvironmentInfra + Send + Sync> AppConfigService
.await
}

async fn set_default_provider_and_model(
&self,
provider_id: ProviderId,
model: ModelId,
) -> anyhow::Result<()> {
self.update(ConfigOperation::SetModel(provider_id, model))
.await
}

async fn get_commit_config(&self) -> anyhow::Result<Option<forge_domain::CommitConfig>> {
let config = self.infra.get_config();
Ok(config.commit.map(|mc| CommitConfig {
Expand Down
Loading