diff --git a/desktop/src/gpu_context.rs b/desktop/src/gpu_context.rs index 4f43abc75c..019e2516ff 100644 --- a/desktop/src/gpu_context.rs +++ b/desktop/src/gpu_context.rs @@ -1,20 +1,18 @@ use crate::wrapper::{WgpuContext, WgpuContextBuilder, WgpuFeatures}; pub(super) async fn create_wgpu_context() -> WgpuContext { - let wgpu_context_builder = WgpuContextBuilder::new().with_features(WgpuFeatures::IMMEDIATES); + let mut wgpu_context_builder = WgpuContextBuilder::new().with_features(WgpuFeatures::IMMEDIATES); + + // TODO: make this configurable via cli flags instead + if let Some(index) = std::env::var("GRAPHITE_WGPU_ADAPTER").ok().and_then(|s| s.parse().ok()) { + tracing::info!("Overriding WGPU adapter selection with adapter index {index}"); + wgpu_context_builder = wgpu_context_builder.with_selection(index); + } // TODO: add a cli flag to list adapters and exit instead of always printing println!("\nAvailable WGPU adapters:\n{}", wgpu_context_builder.available_adapters_fmt().await); - // TODO: make this configurable via cli flags instead - let wgpu_context = match std::env::var("GRAPHITE_WGPU_ADAPTER").ok().and_then(|s| s.parse().ok()) { - None => wgpu_context_builder.build().await, - Some(adapter_index) => { - tracing::info!("Overriding WGPU adapter selection with adapter index {adapter_index}"); - wgpu_context_builder.build_with_adapter_selection(|_| Some(adapter_index)).await - } - } - .expect("Failed to create WGPU context"); + let wgpu_context = wgpu_context_builder.build().await.expect("Failed to create WGPU context"); // TODO: add a cli flag to list adapters and exit instead of always printing println!("Using WGPU adapter: {:?}", wgpu_context.adapter.get_info()); diff --git a/desktop/wrapper/src/lib.rs b/desktop/wrapper/src/lib.rs index ad4c3f13c4..be77396140 100644 --- a/desktop/wrapper/src/lib.rs +++ b/desktop/wrapper/src/lib.rs @@ -8,6 +8,7 @@ use std::sync::Arc; pub use graph_craft::application_io::resource::MmapResourceStorage; pub use graphite_editor::consts::{DOUBLE_CLICK_MILLISECONDS, FILE_EXTENSION}; +pub use wgpu_executor::WgpuBackends; pub use wgpu_executor::WgpuContext; pub use wgpu_executor::WgpuContextBuilder; pub use wgpu_executor::WgpuExecutor; diff --git a/node-graph/libraries/wgpu-executor/src/context.rs b/node-graph/libraries/wgpu-executor/src/context.rs index cea37507fd..7006d7e78a 100644 --- a/node-graph/libraries/wgpu-executor/src/context.rs +++ b/node-graph/libraries/wgpu-executor/src/context.rs @@ -18,12 +18,14 @@ impl Context { pub struct ContextBuilder { backends: Backends, features: Features, + selection: Option, } impl ContextBuilder { pub fn new() -> Self { Self { backends: Backends::all(), features: Features::empty(), + selection: None, } } pub fn with_backends(mut self, backends: Backends) -> Self { @@ -34,21 +36,34 @@ impl ContextBuilder { self.features = features; self } + pub fn with_selection(mut self, index: usize) -> Self { + self.selection = Some(index); + self + } } #[cfg(not(target_family = "wasm"))] impl ContextBuilder { pub async fn build(self) -> Option { - self.build_with_adapter_selection_inner(None:: Option>).await - } - pub async fn build_with_adapter_selection(self, select: S) -> Option - where - S: Fn(&[Adapter]) -> Option, - { - self.build_with_adapter_selection_inner(Some(select)).await + let instance = self.build_instance(); + let mut adapters = enumerate_sorted(&instance, self.backends).await; + + if let Some(index) = self.selection + && index >= adapters.len() + { + let selected_adapter = adapters.remove(index); + adapters.insert(0, selected_adapter); + } + + for adapter in adapters { + if let Some((device, queue)) = self.request_device(&adapter).await { + return Some(Context { device, queue, adapter, instance }); + } + } + None } pub async fn available_adapters_fmt(&self) -> impl std::fmt::Display { let instance = self.build_instance(); - fmt::AvailableAdaptersFormatter(instance.enumerate_adapters(self.backends).await) + fmt::AvailableAdaptersFormatter(enumerate_sorted(&instance, self.backends).await) } } #[cfg(target_family = "wasm")] @@ -67,6 +82,7 @@ impl ContextBuilder { ..wgpu::InstanceDescriptor::new_without_display_handle() }) } + #[cfg(target_family = "wasm")] async fn request_adapter(&self, instance: &Instance) -> Option { let request_adapter_options = wgpu::RequestAdapterOptions { power_preference: wgpu::PowerPreference::HighPerformance, @@ -88,38 +104,35 @@ impl ContextBuilder { } } #[cfg(not(target_family = "wasm"))] -impl ContextBuilder { - async fn build_with_adapter_selection_inner(self, select: Option) -> Option - where - S: Fn(&[Adapter]) -> Option, - { - let instance = self.build_instance(); - - let selected_adapter = if let Some(select) = select { - self.select_adapter(&instance, select).await - } else if cfg!(target_os = "windows") { - self.select_adapter(&instance, |adapters: &[Adapter]| adapters.iter().position(|a| a.get_info().backend == wgpu::Backend::Dx12)) - .await - } else { - None - }; - - let adapter = if let Some(adapter) = selected_adapter { adapter } else { self.request_adapter(&instance).await? }; - - let (device, queue) = self.request_device(&adapter).await?; - Some(Context { device, queue, adapter, instance }) - } - async fn select_adapter(&self, instance: &Instance, select: S) -> Option - where - S: Fn(&[Adapter]) -> Option, - { - let mut adapters = instance.enumerate_adapters(self.backends).await; - let selected_index = select(&adapters)?; - if selected_index >= adapters.len() { - return None; +async fn enumerate_sorted(instance: &Instance, backends: Backends) -> Vec { + let mut adapters = instance.enumerate_adapters(backends).await; + adapters.sort_by_key(adapter_priority); + adapters +} +#[cfg(not(target_family = "wasm"))] +fn adapter_priority(adapter: &Adapter) -> (u8, u8) { + let info = adapter.get_info(); + let backend = if cfg!(target_os = "linux") { + match info.backend { + wgpu::Backend::Vulkan => 0, + _ => 1, } - Some(adapters.remove(selected_index)) - } + } else if cfg!(target_os = "windows") { + match info.backend { + wgpu::Backend::Dx12 => 0, + _ => 1, + } + } else { + 0 + }; + let device_type = match info.device_type { + wgpu::DeviceType::DiscreteGpu => 0, + wgpu::DeviceType::IntegratedGpu => 1, + wgpu::DeviceType::VirtualGpu => 2, + wgpu::DeviceType::Cpu => 3, + wgpu::DeviceType::Other => 4, + }; + (backend, device_type) } #[cfg(not(target_family = "wasm"))] mod fmt {