From ed0ad2025219d0536d3fb349a7f0f9dd831187c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:34:02 +0100 Subject: [PATCH 1/7] refactor(arch): do more early returns --- src/arch/aarch64/kernel/interrupts.rs | 157 +++++++++--------- src/arch/aarch64/kernel/mmio.rs | 228 +++++++++++++------------- src/arch/aarch64/kernel/pci.rs | 98 +++++------ src/arch/aarch64/kernel/processor.rs | 31 ++-- src/arch/aarch64/mm/paging.rs | 42 ++--- src/arch/riscv64/kernel/devicetree.rs | 52 +++--- src/arch/riscv64/kernel/mod.rs | 51 +++--- src/arch/riscv64/kernel/processor.rs | 19 ++- src/arch/riscv64/mm/paging.rs | 12 +- src/arch/x86_64/kernel/acpi.rs | 86 +++++----- src/arch/x86_64/kernel/apic.rs | 74 +++++---- src/arch/x86_64/kernel/mod.rs | 15 +- src/arch/x86_64/kernel/processor.rs | 90 +++++----- 13 files changed, 482 insertions(+), 473 deletions(-) diff --git a/src/arch/aarch64/kernel/interrupts.rs b/src/arch/aarch64/kernel/interrupts.rs index c07ff1017b..ac20ff9151 100644 --- a/src/arch/aarch64/kernel/interrupts.rs +++ b/src/arch/aarch64/kernel/interrupts.rs @@ -129,54 +129,54 @@ pub(crate) fn install_handlers() { #[unsafe(no_mangle)] pub(crate) extern "C" fn do_fiq(_state: &State) -> *mut usize { - if let Some(irqid) = GicV3::get_and_acknowledge_interrupt(InterruptGroup::Group1) { - let vector: u8 = u32::from(irqid).try_into().unwrap(); - - debug!("Receive fiq {vector}"); - increment_irq_counter(vector); + let Some(irqid) = GicV3::get_and_acknowledge_interrupt(InterruptGroup::Group1) else { + return ptr::null_mut(); + }; - if let Some(handlers) = INTERRUPT_HANDLERS.get() - && let Some(queue) = handlers.get(&vector) - { - for handler in queue.iter() { - handler(); - } - } - crate::executor::run(); - core_scheduler().handle_waiting_tasks(); + let vector: u8 = u32::from(irqid).try_into().unwrap(); - GicV3::end_interrupt(irqid, InterruptGroup::Group1); + debug!("Receive fiq {vector}"); + increment_irq_counter(vector); - return core_scheduler().scheduler().unwrap_or_default(); + if let Some(handlers) = INTERRUPT_HANDLERS.get() + && let Some(queue) = handlers.get(&vector) + { + for handler in queue.iter() { + handler(); + } } + crate::executor::run(); + core_scheduler().handle_waiting_tasks(); + + GicV3::end_interrupt(irqid, InterruptGroup::Group1); - ptr::null_mut() + core_scheduler().scheduler().unwrap_or_default() } #[unsafe(no_mangle)] pub(crate) extern "C" fn do_irq(_state: &State) -> *mut usize { - if let Some(irqid) = GicV3::get_and_acknowledge_interrupt(InterruptGroup::Group1) { - let vector: u8 = u32::from(irqid).try_into().unwrap(); - - debug!("Receive interrupt {vector}"); - increment_irq_counter(vector); + let Some(irqid) = GicV3::get_and_acknowledge_interrupt(InterruptGroup::Group1) else { + return ptr::null_mut(); + }; - if let Some(handlers) = INTERRUPT_HANDLERS.get() - && let Some(queue) = handlers.get(&vector) - { - for handler in queue.iter() { - handler(); - } - } - crate::executor::run(); - core_scheduler().handle_waiting_tasks(); + let vector: u8 = u32::from(irqid).try_into().unwrap(); - GicV3::end_interrupt(irqid, InterruptGroup::Group1); + debug!("Receive interrupt {vector}"); + increment_irq_counter(vector); - return core_scheduler().scheduler().unwrap_or_default(); + if let Some(handlers) = INTERRUPT_HANDLERS.get() + && let Some(queue) = handlers.get(&vector) + { + for handler in queue.iter() { + handler(); + } } + crate::executor::run(); + core_scheduler().handle_waiting_tasks(); + + GicV3::end_interrupt(irqid, InterruptGroup::Group1); - ptr::null_mut() + core_scheduler().scheduler().unwrap_or_default() } #[unsafe(no_mangle)] @@ -436,51 +436,54 @@ pub(crate) fn init() { pub fn init_cpu() { let cpu_id: usize = core_id().try_into().unwrap(); - if let Some(ref mut gic) = *GIC.lock() { - debug!("Mark cpu {cpu_id} as awake"); - - gic.setup(cpu_id); - GicV3::set_priority_mask(0xff); - - let fdt = env::fdt().unwrap(); - - if let Some(timer_node) = fdt.find_compatible(&["arm,armv8-timer", "arm,armv7-timer"]) { - let irq_slice = timer_node.property("interrupts").unwrap().value; - /* Secure Phys IRQ */ - let (_irqtype, irq_slice) = irq_slice.split_at(mem::size_of::()); - let (_irq, irq_slice) = irq_slice.split_at(mem::size_of::()); - let (_irqflags, irq_slice) = irq_slice.split_at(mem::size_of::()); - /* Non-secure Phys IRQ */ - let (irqtype, irq_slice) = irq_slice.split_at(mem::size_of::()); - let (irq, irq_slice) = irq_slice.split_at(mem::size_of::()); - let (irqflags, _irq_slice) = irq_slice.split_at(mem::size_of::()); - let irqtype = u32::from_be_bytes(irqtype.try_into().unwrap()); - let irq = u32::from_be_bytes(irq.try_into().unwrap()); - let irqflags = u32::from_be_bytes(irqflags.try_into().unwrap()); - - // enable timer interrupt - let timer_irqid = if irqtype == 1 { - IntId::ppi(irq) - } else if irqtype == 0 { - IntId::spi(irq) - } else { - panic!("Invalid interrupt type"); - }; - gic.set_interrupt_priority(timer_irqid, Some(cpu_id), 0x00); - if (irqflags & 0xf) == 4 || (irqflags & 0xf) == 8 { - gic.set_trigger(timer_irqid, Some(cpu_id), Trigger::Level); - } else if (irqflags & 0xf) == 2 || (irqflags & 0xf) == 1 { - gic.set_trigger(timer_irqid, Some(cpu_id), Trigger::Edge); - } else { - panic!("Invalid interrupt level!"); - } - gic.enable_interrupt(timer_irqid, Some(cpu_id), true); - } + let mut gic = GIC.lock(); + let Some(gic) = &mut *gic else { + return; + }; + + debug!("Mark cpu {cpu_id} as awake"); + + gic.setup(cpu_id); + GicV3::set_priority_mask(0xff); + + let fdt = env::fdt().unwrap(); + + if let Some(timer_node) = fdt.find_compatible(&["arm,armv8-timer", "arm,armv7-timer"]) { + let irq_slice = timer_node.property("interrupts").unwrap().value; + /* Secure Phys IRQ */ + let (_irqtype, irq_slice) = irq_slice.split_at(mem::size_of::()); + let (_irq, irq_slice) = irq_slice.split_at(mem::size_of::()); + let (_irqflags, irq_slice) = irq_slice.split_at(mem::size_of::()); + /* Non-secure Phys IRQ */ + let (irqtype, irq_slice) = irq_slice.split_at(mem::size_of::()); + let (irq, irq_slice) = irq_slice.split_at(mem::size_of::()); + let (irqflags, _irq_slice) = irq_slice.split_at(mem::size_of::()); + let irqtype = u32::from_be_bytes(irqtype.try_into().unwrap()); + let irq = u32::from_be_bytes(irq.try_into().unwrap()); + let irqflags = u32::from_be_bytes(irqflags.try_into().unwrap()); - let reschedid = IntId::sgi(SGI_RESCHED.into()); - gic.set_interrupt_priority(reschedid, Some(cpu_id), 0x01); - gic.enable_interrupt(reschedid, Some(cpu_id), true); + // enable timer interrupt + let timer_irqid = if irqtype == 1 { + IntId::ppi(irq) + } else if irqtype == 0 { + IntId::spi(irq) + } else { + panic!("Invalid interrupt type"); + }; + gic.set_interrupt_priority(timer_irqid, Some(cpu_id), 0x00); + if (irqflags & 0xf) == 4 || (irqflags & 0xf) == 8 { + gic.set_trigger(timer_irqid, Some(cpu_id), Trigger::Level); + } else if (irqflags & 0xf) == 2 || (irqflags & 0xf) == 1 { + gic.set_trigger(timer_irqid, Some(cpu_id), Trigger::Edge); + } else { + panic!("Invalid interrupt level!"); + } + gic.enable_interrupt(timer_irqid, Some(cpu_id), true); } + + let reschedid = IntId::sgi(SGI_RESCHED.into()); + gic.set_interrupt_priority(reschedid, Some(cpu_id), 0x01); + gic.enable_interrupt(reschedid, Some(cpu_id), true); } static IRQ_NAMES: InterruptTicketMutex> = diff --git a/src/arch/aarch64/kernel/mmio.rs b/src/arch/aarch64/kernel/mmio.rs index e6bcdc6b46..b94f85ab1b 100644 --- a/src/arch/aarch64/kernel/mmio.rs +++ b/src/arch/aarch64/kernel/mmio.rs @@ -84,124 +84,124 @@ pub(crate) fn get_filesystem_driver() -> Option<&'static InterruptTicketMutex( - virtio_region_start.align_down(paging::BasePageSize::SIZE), - ); - - // Verify the first register value to find out if this is really an MMIO magic-value. - let ptr = virtio_region.starting_address as *mut DeviceRegisters; - let mmio = unsafe { VolatileRef::new(NonNull::new(ptr).unwrap()) }; - - let magic = mmio.as_ptr().magic_value().read().to_ne(); - let version = mmio.as_ptr().version().read().to_ne(); - - const MMIO_MAGIC_VALUE: u32 = 0x7472_6976; - if magic != MMIO_MAGIC_VALUE { - error!("It's not a MMIO-device at {mmio:p}"); - } - - if version != 2 { - warn!("Found a legacy device, which isn't supported"); - } - - // We found a MMIO-device (whose 512-bit address in this structure). - trace!("Found a MMIO-device at {mmio:p}"); - - // Verify the device-ID to find the network card - let id = mmio.as_ptr().device_id().read(); - let cpu_id: usize = 0; - - if id == virtio::Id::Reserved { - continue; - } - - debug!( - "Found {id:?} card at {mmio:p}, irq: {irq}, type: {irqtype}, flags: {irqflags}" - ); - - let drv = match mmio_virtio::init_device(mmio, irq.try_into().unwrap()) - { - Ok(drv) => drv, - Err(err) => { - error!("{err}"); - continue; - } - }; - - let mut gic = GIC.lock(); - let Some(gic) = gic.as_mut() else { - error!("No GIC found"); - continue; - }; - - // enable timer interrupt - let virtio_irqid = if irqtype == 1 { - IntId::ppi(irq) - } else if irqtype == 0 { - IntId::spi(irq) - } else { - panic!("Invalid interrupt type"); - }; - gic.set_interrupt_priority(virtio_irqid, Some(cpu_id), 0x00); - if (irqflags & 0xf) == 4 || (irqflags & 0xf) == 8 { - gic.set_trigger(virtio_irqid, Some(cpu_id), Trigger::Level); - } else if (irqflags & 0xf) == 2 || (irqflags & 0xf) == 1 { - gic.set_trigger(virtio_irqid, Some(cpu_id), Trigger::Edge); - } else { - panic!("Invalid interrupt level!"); - } - gic.enable_interrupt(virtio_irqid, Some(cpu_id), true); - - match drv { - #[cfg(feature = "virtio-console")] - VirtioDriver::Console(drv) => register_driver(MmioDriver::VirtioConsole( - InterruptTicketMutex::new(*drv), - )), - #[cfg(feature = "virtio-fs")] - VirtioDriver::FileSystem(drv) => register_driver(MmioDriver::VirtioFs( - hermit_sync::InterruptTicketMutex::new(*drv), - )), - #[cfg(feature = "virtio-net")] - VirtioDriver::Net(drv) => *NETWORK_DEVICE.lock() = Some(*drv), - } + let Some(fdt) = crate::env::fdt() else { + error!("No device tree found, cannot initialize MMIO drivers"); + return; + }; + + for node in fdt.find_all_nodes("/virtio_mmio") { + let Some(compatible) = node.compatible() else { + continue; + }; + + for i in compatible.all() { + if i == "virtio,mmio" { + let virtio_region = node + .reg() + .expect("reg property for virtio mmio not found in FDT") + .next() + .unwrap(); + let mut irq = 0; + let mut irqtype = 0; + let mut irqflags = 0; + + for prop in node.properties() { + if prop.name == "interrupts" { + irqtype = u32::from_be_bytes(prop.value[0..4].try_into().unwrap()); + irq = u32::from_be_bytes(prop.value[4..8].try_into().unwrap()); + irqflags = u32::from_be_bytes(prop.value[8..12].try_into().unwrap()); + break; } } + + let virtio_region_start = + PhysAddr::from(virtio_region.starting_address.expose_provenance()); + + assert!( + virtio_region.size.unwrap() + < usize::try_from(paging::BasePageSize::SIZE).unwrap() + ); + paging::identity_map::( + virtio_region_start.align_down(paging::BasePageSize::SIZE), + ); + + // Verify the first register value to find out if this is really an MMIO magic-value. + let ptr = virtio_region.starting_address as *mut DeviceRegisters; + let mmio = unsafe { VolatileRef::new(NonNull::new(ptr).unwrap()) }; + + let magic = mmio.as_ptr().magic_value().read().to_ne(); + let version = mmio.as_ptr().version().read().to_ne(); + + const MMIO_MAGIC_VALUE: u32 = 0x7472_6976; + if magic != MMIO_MAGIC_VALUE { + error!("It's not a MMIO-device at {mmio:p}"); + } + + if version != 2 { + warn!("Found a legacy device, which isn't supported"); + } + + // We found a MMIO-device (whose 512-bit address in this structure). + trace!("Found a MMIO-device at {mmio:p}"); + + // Verify the device-ID to find the network card + let id = mmio.as_ptr().device_id().read(); + let cpu_id: usize = 0; + + if id == virtio::Id::Reserved { + continue; + } + + debug!( + "Found {id:?} card at {mmio:p}, irq: {irq}, type: {irqtype}, flags: {irqflags}" + ); + + let drv = match mmio_virtio::init_device(mmio, irq.try_into().unwrap()) { + Ok(drv) => drv, + Err(err) => { + error!("{err}"); + continue; + } + }; + + let mut gic = GIC.lock(); + let Some(gic) = gic.as_mut() else { + error!("No GIC found"); + continue; + }; + + // enable timer interrupt + let virtio_irqid = if irqtype == 1 { + IntId::ppi(irq) + } else if irqtype == 0 { + IntId::spi(irq) + } else { + panic!("Invalid interrupt type"); + }; + gic.set_interrupt_priority(virtio_irqid, Some(cpu_id), 0x00); + if (irqflags & 0xf) == 4 || (irqflags & 0xf) == 8 { + gic.set_trigger(virtio_irqid, Some(cpu_id), Trigger::Level); + } else if (irqflags & 0xf) == 2 || (irqflags & 0xf) == 1 { + gic.set_trigger(virtio_irqid, Some(cpu_id), Trigger::Edge); + } else { + panic!("Invalid interrupt level!"); + } + gic.enable_interrupt(virtio_irqid, Some(cpu_id), true); + + match drv { + #[cfg(feature = "virtio-console")] + VirtioDriver::Console(drv) => register_driver(MmioDriver::VirtioConsole( + InterruptTicketMutex::new(*drv), + )), + #[cfg(feature = "virtio-fs")] + VirtioDriver::FileSystem(drv) => register_driver(MmioDriver::VirtioFs( + hermit_sync::InterruptTicketMutex::new(*drv), + )), + #[cfg(feature = "virtio-net")] + VirtioDriver::Net(drv) => *NETWORK_DEVICE.lock() = Some(*drv), + } } } - } else { - error!("No device tree found, cannot initialize MMIO drivers"); } }); diff --git a/src/arch/aarch64/kernel/pci.rs b/src/arch/aarch64/kernel/pci.rs index a430b2f6c1..3b84f83738 100644 --- a/src/arch/aarch64/kernel/pci.rs +++ b/src/arch/aarch64/kernel/pci.rs @@ -269,54 +269,56 @@ pub fn init() { let mut cmd = CommandRegister::empty(); let mut range_iter = 0..MAX_BARS; while let Some(i) = range_iter.next() { - if let Some(bar) = dev.get_bar(i.try_into().unwrap()) { - match bar { - Bar::Io { .. } => { - dev.set_bar( - i.try_into().unwrap(), - Bar::Io { - port: io_start.try_into().unwrap(), - }, - ); - io_start += 0x20; - cmd |= CommandRegister::IO_ENABLE - | CommandRegister::BUS_MASTER_ENABLE; - } - Bar::Memory32 { - address: _, - size, - prefetchable, - } => { - dev.set_bar( - i.try_into().unwrap(), - Bar::Memory32 { - address: mem32_start.try_into().unwrap(), - size, - prefetchable, - }, - ); - mem32_start += u64::from(size); - cmd |= CommandRegister::MEMORY_ENABLE - | CommandRegister::BUS_MASTER_ENABLE; - } - Bar::Memory64 { - address: _, - size, - prefetchable, - } => { - dev.set_bar( - i.try_into().unwrap(), - Bar::Memory64 { - address: mem64_start, - size, - prefetchable, - }, - ); - mem64_start += size; - cmd |= CommandRegister::MEMORY_ENABLE - | CommandRegister::BUS_MASTER_ENABLE; - range_iter.next(); // Skip 32-bit bar that is part of the 64-bit bar - } + let Some(bar) = dev.get_bar(i.try_into().unwrap()) else { + continue; + }; + + match bar { + Bar::Io { .. } => { + dev.set_bar( + i.try_into().unwrap(), + Bar::Io { + port: io_start.try_into().unwrap(), + }, + ); + io_start += 0x20; + cmd |= + CommandRegister::IO_ENABLE | CommandRegister::BUS_MASTER_ENABLE; + } + Bar::Memory32 { + address: _, + size, + prefetchable, + } => { + dev.set_bar( + i.try_into().unwrap(), + Bar::Memory32 { + address: mem32_start.try_into().unwrap(), + size, + prefetchable, + }, + ); + mem32_start += u64::from(size); + cmd |= CommandRegister::MEMORY_ENABLE + | CommandRegister::BUS_MASTER_ENABLE; + } + Bar::Memory64 { + address: _, + size, + prefetchable, + } => { + dev.set_bar( + i.try_into().unwrap(), + Bar::Memory64 { + address: mem64_start, + size, + prefetchable, + }, + ); + mem64_start += size; + cmd |= CommandRegister::MEMORY_ENABLE + | CommandRegister::BUS_MASTER_ENABLE; + range_iter.next(); // Skip 32-bit bar that is part of the 64-bit bar } } } diff --git a/src/arch/aarch64/kernel/processor.rs b/src/arch/aarch64/kernel/processor.rs index f52f5562da..b1580b8299 100644 --- a/src/arch/aarch64/kernel/processor.rs +++ b/src/arch/aarch64/kernel/processor.rs @@ -157,13 +157,13 @@ impl CpuFrequency { ) -> Result<(), ()> { //The clock frequency must never be set to zero, otherwise a division by zero will //occur during runtime - if khz > 0 { - self.khz = khz; - self.source = source; - Ok(()) - } else { - Err(()) + if khz == 0 { + return Err(()); } + + self.khz = khz; + self.source = source; + Ok(()) } unsafe fn detect_from_cmdline(&mut self) -> Result<(), ()> { @@ -306,18 +306,19 @@ pub fn detect_frequency() { #[inline] fn __set_oneshot_timer(wakeup_time: Option) { - if let Some(wt) = wakeup_time { - // wt is the absolute wakeup time in microseconds based on processor::get_timer_ticks. - let freq: u64 = CPU_FREQUENCY.get().into(); // frequency in KHz - let deadline = (wt / 1000) * freq; - - CNTP_CVAL_EL0.set(deadline); - CNTP_CTL_EL0.write(CNTP_CTL_EL0::ENABLE::SET); - } else { + let Some(wt) = wakeup_time else { // disable timer CNTP_CVAL_EL0.set(0); CNTP_CTL_EL0.write(CNTP_CTL_EL0::ENABLE::CLEAR); - } + return; + }; + + // wt is the absolute wakeup time in microseconds based on processor::get_timer_ticks. + let freq: u64 = CPU_FREQUENCY.get().into(); // frequency in KHz + let deadline = (wt / 1000) * freq; + + CNTP_CVAL_EL0.set(deadline); + CNTP_CTL_EL0.write(CNTP_CTL_EL0::ENABLE::SET); } pub fn set_oneshot_timer(wakeup_time: Option) { diff --git a/src/arch/aarch64/mm/paging.rs b/src/arch/aarch64/mm/paging.rs index 6f786d81e5..de157b7fda 100644 --- a/src/arch/aarch64/mm/paging.rs +++ b/src/arch/aarch64/mm/paging.rs @@ -301,13 +301,13 @@ impl Iterator for PageIter { type Item = Page; fn next(&mut self) -> Option> { - if self.current.virtual_address <= self.last.virtual_address { - let p = self.current; - self.current.virtual_address += S::SIZE; - Some(p) - } else { - None + if self.last.virtual_address < self.current.virtual_address { + return None; } + + let p = self.current; + self.current.virtual_address += S::SIZE; + Some(p) } } @@ -434,11 +434,11 @@ impl PageTableMethods for PageTable { assert_eq!(L::LEVEL, S::MAP_LEVEL); let index = page.table_index::(); - if self.entries[index].is_present() { - Some(self.entries[index]) - } else { - None + if !self.entries[index].is_present() { + return None; } + + Some(self.entries[index]) } /// Maps a single page to the given physical address. @@ -467,15 +467,15 @@ where assert!(L::LEVEL <= S::MAP_LEVEL); let index = page.table_index::(); - if self.entries[index].is_present() { - if L::LEVEL < S::MAP_LEVEL { - let subtable = self.subtable::(page); - subtable.get_page_table_entry::(page) - } else { - Some(self.entries[index]) - } + if !self.entries[index].is_present() { + return None; + } + + if L::LEVEL < S::MAP_LEVEL { + let subtable = self.subtable::(page); + subtable.get_page_table_entry::(page) } else { - None + Some(self.entries[index]) } } @@ -685,10 +685,10 @@ pub fn map_heap(virt_addr: VirtAddr, nr_pages: usize) -> Result<(), } if map_counter < nr_pages { - Err(map_counter) - } else { - Ok(()) + return Err(map_counter); } + + Ok(()) } pub fn identity_map(phys_addr: PhysAddr) { diff --git a/src/arch/riscv64/kernel/devicetree.rs b/src/arch/riscv64/kernel/devicetree.rs index 185e03199e..5ad9da5a42 100644 --- a/src/arch/riscv64/kernel/devicetree.rs +++ b/src/arch/riscv64/kernel/devicetree.rs @@ -58,32 +58,34 @@ enum Model { /// This function should only be called once pub fn init() { debug!("Init devicetree"); - if let Some(fdt) = env::fdt() { - let model = fdt - .find_node("/") - .unwrap() - .property("compatible") - .expect("compatible not found in FDT") - .as_str() - .unwrap(); - - let platform_model = if model.contains("riscv-virtio") { - Model::Virt - } else if model.contains("sifive,hifive-unmatched-a00") - || model.contains("sifive,hifive-unleashed-a00") - || model.contains("sifive,fu740") - || model.contains("sifive,fu540") - { - Model::Fux40 - } else { - warn!("Unknown platform, guessing PLIC context 1"); - Model::Unknown - }; - unsafe { - PLATFORM_MODEL = platform_model; - } - info!("Model: {model}"); + let Some(fdt) = env::fdt() else { + return; + }; + + let model = fdt + .find_node("/") + .unwrap() + .property("compatible") + .expect("compatible not found in FDT") + .as_str() + .unwrap(); + + let platform_model = if model.contains("riscv-virtio") { + Model::Virt + } else if model.contains("sifive,hifive-unmatched-a00") + || model.contains("sifive,hifive-unleashed-a00") + || model.contains("sifive,fu740") + || model.contains("sifive,fu540") + { + Model::Fux40 + } else { + warn!("Unknown platform, guessing PLIC context 1"); + Model::Unknown + }; + unsafe { + PLATFORM_MODEL = platform_model; } + info!("Model: {model}"); } /// Inits drivers based on the device tree diff --git a/src/arch/riscv64/kernel/mod.rs b/src/arch/riscv64/kernel/mod.rs index cdbb6ce5b5..44c66100a7 100644 --- a/src/arch/riscv64/kernel/mod.rs +++ b/src/arch/riscv64/kernel/mod.rs @@ -150,33 +150,34 @@ pub fn boot_next_processor() { let next_hart_index = lsb(new_hart_mask); - if let Some(next_hart_id) = next_hart_index { - { - debug!("Allocating stack for hard_id {next_hart_id}"); - let frame_layout = PageLayout::from_size(KERNEL_STACK_SIZE).unwrap(); - let frame_range = FrameAlloc::allocate(frame_layout) - .expect("Failed to allocate boot stack for new core"); - let stack = ptr::with_exposed_provenance_mut(frame_range.start()); - CURRENT_STACK_ADDRESS.store(stack, Ordering::Relaxed); - } - - info!( - "Starting CPU {} with hart_id {}", - core_id() + 1, - next_hart_id - ); - - // TODO: Old: Changing cpu_online will cause uhyve to start the next processor - CPU_ONLINE.fetch_add(1, Ordering::Release); - - //When running bare-metal/QEMU we use the firmware to start the next hart - if !env::is_uhyve() { - let start_addr = (start::_start as *const ()).expose_provenance(); - sbi_rt::hart_start(next_hart_id as usize, start_addr, 0).unwrap(); - } - } else { + let Some(next_hart_id) = next_hart_index else { info!("All processors are initialized"); CPU_ONLINE.fetch_add(1, Ordering::Release); + return; + }; + + { + debug!("Allocating stack for hard_id {next_hart_id}"); + let frame_layout = PageLayout::from_size(KERNEL_STACK_SIZE).unwrap(); + let frame_range = + FrameAlloc::allocate(frame_layout).expect("Failed to allocate boot stack for new core"); + let stack = ptr::with_exposed_provenance_mut(frame_range.start()); + CURRENT_STACK_ADDRESS.store(stack, Ordering::Relaxed); + } + + info!( + "Starting CPU {} with hart_id {}", + core_id() + 1, + next_hart_id + ); + + // TODO: Old: Changing cpu_online will cause uhyve to start the next processor + CPU_ONLINE.fetch_add(1, Ordering::Release); + + //When running bare-metal/QEMU we use the firmware to start the next hart + if !env::is_uhyve() { + let start_addr = (start::_start as *const ()).expose_provenance(); + sbi_rt::hart_start(next_hart_id as usize, start_addr, 0).unwrap(); } } diff --git a/src/arch/riscv64/kernel/processor.rs b/src/arch/riscv64/kernel/processor.rs index 608494a99e..5c9b1afc96 100644 --- a/src/arch/riscv64/kernel/processor.rs +++ b/src/arch/riscv64/kernel/processor.rs @@ -272,19 +272,20 @@ pub fn supports_2mib_pages() -> bool { } pub fn set_oneshot_timer(wakeup_time: Option) { - if let Some(wt) = wakeup_time { - debug!("Starting Timer: {:x}", get_timestamp()); - unsafe { - sie::set_stimer(); - } - let next_time = wt * u64::from(get_frequency()); - - sbi_rt::set_timer(next_time); - } else { + let Some(wt) = wakeup_time else { // Disable the Timer (and clear a pending interrupt) debug!("Stopping Timer"); sbi_rt::set_timer(u64::MAX); + return; + }; + + debug!("Starting Timer: {:x}", get_timestamp()); + unsafe { + sie::set_stimer(); } + let next_time = wt * u64::from(get_frequency()); + + sbi_rt::set_timer(next_time); } pub fn wakeup_core(core_to_wakeup: CoreId) { diff --git a/src/arch/riscv64/mm/paging.rs b/src/arch/riscv64/mm/paging.rs index d52943acf8..c65b5ab64e 100644 --- a/src/arch/riscv64/mm/paging.rs +++ b/src/arch/riscv64/mm/paging.rs @@ -271,13 +271,13 @@ impl Iterator for PageIter { type Item = Page; fn next(&mut self) -> Option> { - if self.current.virtual_address <= self.last.virtual_address { - let p = self.current; - self.current.virtual_address += S::SIZE; - Some(p) - } else { - None + if self.last.virtual_address < self.current.virtual_address { + return None; } + + let p = self.current; + self.current.virtual_address += S::SIZE; + Some(p) } } diff --git a/src/arch/x86_64/kernel/acpi.rs b/src/arch/x86_64/kernel/acpi.rs index bf44eea404..29bf7dbb1f 100644 --- a/src/arch/x86_64/kernel/acpi.rs +++ b/src/arch/x86_64/kernel/acpi.rs @@ -392,40 +392,41 @@ fn search_s5_in_table(table: AcpiTable<'_>) { // Find the "_S5_" object in the bytecode. let s5 = [b'_', b'S', b'5', b'_', AML_PACKAGEOP]; let s5_position = aml.windows(s5.len()).position(|window| window == s5); - if let Some(i) = s5_position { - // We have found an "_S5_" object that looks valid. - // To be sure, verify that it begins with an AML_NAMEOP or an AML_NAMEOP and a backslash. - if i > 2 && (aml[i - 1] == AML_NAMEOP || (aml[i - 2] == AML_NAMEOP && aml[i - 1] == b'\\')) - { - // This is a valid "_S5_" object. - // It should be followed by this structure: - // - single byte for PkgLength (index 5) - // - single byte for NumElements (index 6) - let pkg_length = aml[i + 5]; - let num_elements = aml[i + 6]; - - // Bits 6-7 of PkgLength are non-zero for larger packages, resulting in a different structure. - // This mustn't be the case for the "_S5_" object. - if pkg_length & 0b1100_0000 == 0 && num_elements > 0 { - // The next byte is an opcode describing the data. - // It is usually the byte prefix, indicating that the actual data is the single byte following the opcode. - // However, if the data is a zero or one byte, this may also be indicated by the opcode. - let op = aml[i + 7]; - let slp_typa = match op { - AML_ZEROOP => 0, - AML_ONEOP => 1, - AML_BYTEPREFIX => aml[i + 8], - _ => return, - }; - - // All assumptions are correct, so slp_typa is supposed to contain valid information. - // Now we have all information we need for powering off through ACPI. - // - // Note that Power Off may also be controlled through PM1B_CNT_BLK / SLP_TYPB - // according to the ACPI Specification. However, this has not yet been observed on real computers - // and therefore not implemented. - SLP_TYPA.set(slp_typa).unwrap(); - } + let Some(i) = s5_position else { + return; + }; + + // We have found an "_S5_" object that looks valid. + // To be sure, verify that it begins with an AML_NAMEOP or an AML_NAMEOP and a backslash. + if i > 2 && (aml[i - 1] == AML_NAMEOP || (aml[i - 2] == AML_NAMEOP && aml[i - 1] == b'\\')) { + // This is a valid "_S5_" object. + // It should be followed by this structure: + // - single byte for PkgLength (index 5) + // - single byte for NumElements (index 6) + let pkg_length = aml[i + 5]; + let num_elements = aml[i + 6]; + + // Bits 6-7 of PkgLength are non-zero for larger packages, resulting in a different structure. + // This mustn't be the case for the "_S5_" object. + if pkg_length & 0b1100_0000 == 0 && num_elements > 0 { + // The next byte is an opcode describing the data. + // It is usually the byte prefix, indicating that the actual data is the single byte following the opcode. + // However, if the data is a zero or one byte, this may also be indicated by the opcode. + let op = aml[i + 7]; + let slp_typa = match op { + AML_ZEROOP => 0, + AML_ONEOP => 1, + AML_BYTEPREFIX => aml[i + 8], + _ => return, + }; + + // All assumptions are correct, so slp_typa is supposed to contain valid information. + // Now we have all information we need for powering off through ACPI. + // + // Note that Power Off may also be controlled through PM1B_CNT_BLK / SLP_TYPB + // according to the ACPI Specification. However, this has not yet been observed on real computers + // and therefore not implemented. + SLP_TYPA.set(slp_typa).unwrap(); } } } @@ -498,15 +499,16 @@ pub fn get_mcfg_table() -> Option<&'static AcpiTable<'static>> { } pub fn poweroff() { - if let (Some(mut pm1a_cnt_blk), Some(&slp_typa)) = (PM1A_CNT_BLK.get().cloned(), SLP_TYPA.get()) - { - let bits = (u16::from(slp_typa) << 10) | SLP_EN; - debug!("Powering Off through ACPI (port {pm1a_cnt_blk:?}, bitmask {bits:#X})"); - unsafe { - pm1a_cnt_blk.write(bits); - } - } else { + let (Some(mut pm1a_cnt_blk), Some(&slp_typa)) = (PM1A_CNT_BLK.get().cloned(), SLP_TYPA.get()) + else { warn!("ACPI Power Off is not available"); + return; + }; + + let bits = (u16::from(slp_typa) << 10) | SLP_EN; + debug!("Powering Off through ACPI (port {pm1a_cnt_blk:?}, bitmask {bits:#X})"); + unsafe { + pm1a_cnt_blk.write(bits); } } diff --git a/src/arch/x86_64/kernel/apic.rs b/src/arch/x86_64/kernel/apic.rs index f06fd95645..09e788e226 100644 --- a/src/arch/x86_64/kernel/apic.rs +++ b/src/arch/x86_64/kernel/apic.rs @@ -670,45 +670,47 @@ fn calibrate_timer() { } fn __set_oneshot_timer(wakeup_time: Option) { - if let Some(wt) = wakeup_time { - if processor::supports_tsc_deadline() { - // wt is the absolute wakeup time in microseconds based on processor::get_timer_ticks. - // We can simply multiply it by the processor frequency to get the absolute Time-Stamp Counter deadline - // (see processor::get_timer_ticks). - let tsc_deadline = wt * (u64::from(processor::get_frequency())); - - // Enable the APIC Timer in TSC-Deadline Mode and let it start by writing to the respective MSR. - local_apic_write( - IA32_X2APIC_LVT_TIMER, - APIC_LVT_TIMER_TSC_DEADLINE | u64::from(TIMER_INTERRUPT_NUMBER), - ); - let mut ia32_tsc_deadline = IA32_TSC_DEADLINE; - unsafe { - ia32_tsc_deadline.write(tsc_deadline); - } - } else { - // Calculate the relative timeout from the absolute wakeup time. - // Maintain a minimum value of one tick, otherwise the timer interrupt does not fire at all. - // The Timer Counter Register is also a 32-bit register, which we must not overflow for longer timeouts. - let current_time = processor::get_timer_ticks(); - let ticks = if wt > current_time { - wt - current_time - } else { - 1 - }; - let init_count = cmp::min( - CALIBRATED_COUNTER_VALUE.get().unwrap() * ticks, - u64::from(u32::MAX), - ); - - // Enable the APIC Timer in One-Shot Mode and let it start by setting the initial counter value. - local_apic_write(IA32_X2APIC_LVT_TIMER, u64::from(TIMER_INTERRUPT_NUMBER)); - local_apic_write(IA32_X2APIC_INIT_COUNT, init_count); - } - } else { + let Some(wt) = wakeup_time else { // Disable the APIC Timer. local_apic_write(IA32_X2APIC_LVT_TIMER, APIC_LVT_MASK); + return; + }; + + if processor::supports_tsc_deadline() { + // wt is the absolute wakeup time in microseconds based on processor::get_timer_ticks. + // We can simply multiply it by the processor frequency to get the absolute Time-Stamp Counter deadline + // (see processor::get_timer_ticks). + let tsc_deadline = wt * (u64::from(processor::get_frequency())); + + // Enable the APIC Timer in TSC-Deadline Mode and let it start by writing to the respective MSR. + local_apic_write( + IA32_X2APIC_LVT_TIMER, + APIC_LVT_TIMER_TSC_DEADLINE | u64::from(TIMER_INTERRUPT_NUMBER), + ); + let mut ia32_tsc_deadline = IA32_TSC_DEADLINE; + unsafe { + ia32_tsc_deadline.write(tsc_deadline); + } + return; } + + // Calculate the relative timeout from the absolute wakeup time. + // Maintain a minimum value of one tick, otherwise the timer interrupt does not fire at all. + // The Timer Counter Register is also a 32-bit register, which we must not overflow for longer timeouts. + let current_time = processor::get_timer_ticks(); + let ticks = if wt > current_time { + wt - current_time + } else { + 1 + }; + let init_count = cmp::min( + CALIBRATED_COUNTER_VALUE.get().unwrap() * ticks, + u64::from(u32::MAX), + ); + + // Enable the APIC Timer in One-Shot Mode and let it start by setting the initial counter value. + local_apic_write(IA32_X2APIC_LVT_TIMER, u64::from(TIMER_INTERRUPT_NUMBER)); + local_apic_write(IA32_X2APIC_INIT_COUNT, init_count); } pub fn set_oneshot_timer(wakeup_time: Option) { diff --git a/src/arch/x86_64/kernel/mod.rs b/src/arch/x86_64/kernel/mod.rs index dba7cb3430..3c74007148 100644 --- a/src/arch/x86_64/kernel/mod.rs +++ b/src/arch/x86_64/kernel/mod.rs @@ -324,16 +324,13 @@ pub unsafe fn jump_to_user_land(entry_point: usize, code_size: usize, arg: &[&st let mut pos: usize = 0; for (i, s) in arg.iter().enumerate() { - if let Ok(s) = CString::new(*s) { - let bytes = s.as_bytes_with_nul(); - argv[i] = ptr::with_exposed_provenance_mut::(stack_pointer + pos); - pos += bytes.len(); + let s = CString::new(*s).unwrap(); + let bytes = s.as_bytes_with_nul(); + argv[i] = ptr::with_exposed_provenance_mut::(stack_pointer + pos); + pos += bytes.len(); - unsafe { - argv[i].copy_from_nonoverlapping(bytes.as_ptr(), bytes.len()); - } - } else { - panic!("Unable to create C string!"); + unsafe { + argv[i].copy_from_nonoverlapping(bytes.as_ptr(), bytes.len()); } } diff --git a/src/arch/x86_64/kernel/processor.rs b/src/arch/x86_64/kernel/processor.rs index 71aba2fd29..0915d6365c 100644 --- a/src/arch/x86_64/kernel/processor.rs +++ b/src/arch/x86_64/kernel/processor.rs @@ -275,13 +275,13 @@ impl CpuFrequency { ) -> Result<(), ()> { //The clock frequency must never be set to zero, otherwise a division by zero will //occur during runtime - if mhz > 0 { - self.mhz = mhz; - self.source = source; - Ok(()) - } else { - Err(()) + if mhz == 0 { + return Err(()); } + + self.mhz = mhz; + self.source = source; + Ok(()) } unsafe fn detect_from_cmdline(&mut self) -> Result<(), ()> { @@ -327,31 +327,28 @@ impl CpuFrequency { &mut self, cpuid: &CpuId, ) -> Result<(), ()> { - if let Some(processor_brand) = cpuid.get_processor_brand_string() { - let brand_string = processor_brand.as_str(); - let ghz_find = brand_string.find("GHz"); - - if let Some(ghz_find) = ghz_find { - let index = ghz_find - 4; - let thousand_char = brand_string.chars().nth(index).unwrap(); - let decimal_char = brand_string.chars().nth(index + 1).unwrap(); - let hundred_char = brand_string.chars().nth(index + 2).unwrap(); - let ten_char = brand_string.chars().nth(index + 3).unwrap(); - - if let (Some(thousand), '.', Some(hundred), Some(ten)) = ( - thousand_char.to_digit(10), - decimal_char, - hundred_char.to_digit(10), - ten_char.to_digit(10), - ) { - let mhz = (thousand * 1000 + hundred * 100 + ten * 10) as u16; - return self - .set_detected_cpu_frequency(mhz, CpuFrequencySources::CpuIdBrandString); - } - } + let processor_brand = cpuid.get_processor_brand_string().ok_or(())?; + + let brand_string = processor_brand.as_str(); + let ghz_find = brand_string.find("GHz"); + + let ghz_find = ghz_find.ok_or(())?; + + let index = ghz_find - 4; + let thousand_char = brand_string.chars().nth(index).unwrap(); + let decimal_char = brand_string.chars().nth(index + 1).unwrap(); + let hundred_char = brand_string.chars().nth(index + 2).unwrap(); + let ten_char = brand_string.chars().nth(index + 3).unwrap(); + + let thousand = thousand_char.to_digit(10).ok_or(())?; + if decimal_char != '.' { + return Err(()); } + let hundred = hundred_char.to_digit(10).ok_or(())?; + let ten = ten_char.to_digit(10).ok_or(())?; - Err(()) + let mhz = (thousand * 1000 + hundred * 100 + ten * 10) as u16; + self.set_detected_cpu_frequency(mhz, CpuFrequencySources::CpuIdBrandString) } fn detect_from_fdt(&mut self) -> Result<(), ()> { @@ -967,27 +964,28 @@ pub fn print_information() { pub fn seed_entropy() -> Option<[u8; 32]> { let mut buf = [0; 32]; - if FEATURES.supports_rdseed { - for word in buf.chunks_mut(8) { - let mut value = 0; - - // Some RDRAND implementations on AMD CPUs have had bugs where the carry - // flag was incorrectly set without there actually being a random value - // available. Even though no bugs are known for RDSEED, we should not - // consider the default values random for extra security. - while unsafe { _rdseed64_step(&mut value) != 1 } || value == 0 || value == u64::MAX { - // Spin as per the recommendation in the - // IntelĀ® Digital Random Number Generator (DRNG) implementation guide - spin_loop(); - } - word.copy_from_slice(&value.to_ne_bytes()); + if !FEATURES.supports_rdseed { + return None; + } + + for word in buf.chunks_mut(8) { + let mut value = 0; + + // Some RDRAND implementations on AMD CPUs have had bugs where the carry + // flag was incorrectly set without there actually being a random value + // available. Even though no bugs are known for RDSEED, we should not + // consider the default values random for extra security. + while unsafe { _rdseed64_step(&mut value) != 1 } || value == 0 || value == u64::MAX { + // Spin as per the recommendation in the + // IntelĀ® Digital Random Number Generator (DRNG) implementation guide + spin_loop(); } - Some(buf) - } else { - None + word.copy_from_slice(&value.to_ne_bytes()); } + + Some(buf) } #[inline] From 907e825bea4e68156895d661dc083339fb6e0a84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:34:21 +0100 Subject: [PATCH 2/7] refactor(drivers): do more early returns --- src/drivers/console/mod.rs | 114 +++++++++++++------------ src/drivers/mmio.rs | 16 ++-- src/drivers/net/gem.rs | 40 ++++----- src/drivers/net/loopback.rs | 22 ++--- src/drivers/net/mod.rs | 10 +-- src/drivers/net/rtl8139.rs | 68 ++++++++------- src/drivers/net/virtio/mod.rs | 58 +++++++------ src/drivers/pci.rs | 31 ++++--- src/drivers/virtio/transport/mmio.rs | 12 +-- src/drivers/virtio/transport/pci.rs | 18 ++-- src/drivers/virtio/virtqueue/mod.rs | 55 ++++++------ src/drivers/virtio/virtqueue/packed.rs | 99 +++++++++++---------- src/drivers/virtio/virtqueue/split.rs | 6 +- src/drivers/vsock/mod.rs | 108 ++++++++++++----------- 14 files changed, 338 insertions(+), 319 deletions(-) diff --git a/src/drivers/console/mod.rs b/src/drivers/console/mod.rs index 0b7b633f46..6cef7157c8 100644 --- a/src/drivers/console/mod.rs +++ b/src/drivers/console/mod.rs @@ -83,21 +83,19 @@ impl ErrorType for VirtioUART { impl Read for VirtioUART { fn read(&mut self, buf: &mut [u8]) -> Result { - if let Some(drv) = get_console_driver() { - drv.lock().read(buf) - } else { - Err(Errno::Io) - } + let drv = get_console_driver().ok_or(Errno::Io)?; + + drv.lock().read(buf) } } impl ReadReady for VirtioUART { fn read_ready(&mut self) -> Result { - if let Some(drv) = get_console_driver() { - Ok(drv.lock().has_packet()) - } else { - Ok(false) - } + let Some(drv) = get_console_driver() else { + return Ok(false); + }; + + Ok(drv.lock().has_packet()) } } @@ -138,15 +136,19 @@ impl RxQueue { } pub fn enable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.enable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.enable_notifs(); } pub fn disable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.disable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.disable_notifs(); } fn has_packet(&self) -> bool { @@ -161,21 +163,17 @@ impl RxQueue { where F: FnMut(&[u8]) -> usize, { - if let Some(mut buffer_tkn) = self.get_next() { - let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); - - if let Some(ref mut vq) = self.vq { - let result = f(&packet[..]); + let Some(mut buffer_tkn) = self.get_next() else { + return Ok(0); + }; - fill_queue(vq, 1, self.packet_size); + let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); + let vq = self.vq.as_mut().unwrap(); + let result = f(&packet[..]); - return Ok(result); - } else { - panic!("Invalid length of receive queue"); - } - } + fill_queue(vq, 1, self.packet_size); - Ok(0) + Ok(result) } } @@ -199,21 +197,27 @@ impl TxQueue { } pub fn enable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.enable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.enable_notifs(); } pub fn disable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.disable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.disable_notifs(); } fn poll(&mut self) { - if let Some(ref mut vq) = self.vq { - while vq.try_recv().is_ok() {} - } + let Some(vq) = &mut self.vq else { + return; + }; + + while vq.try_recv().is_ok() {} } /// Provides a slice to copy the packet and transfer the packet @@ -223,25 +227,23 @@ impl TxQueue { // We need to poll to get the queue to remove elements from the table and make space for // what we are about to add self.poll(); - if let Some(ref mut vq) = self.vq { - assert!(buf.len() < usize::try_from(self.packet_length).unwrap()); - let mut packet = Vec::with_capacity_in(buf.len(), DeviceAlloc); - packet.extend_from_slice(buf); - - let buff_tkn = AvailBufferToken::new( - { - let mut vec = SmallVec::new(); - vec.push(BufferElem::Vector(packet)); - vec - }, - SmallVec::new(), - ) - .unwrap(); - - vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); - } else { - panic!("Unable to get send queue"); - } + let vq = self.vq.as_mut().unwrap(); + + assert!(buf.len() < usize::try_from(self.packet_length).unwrap()); + let mut packet = Vec::with_capacity_in(buf.len(), DeviceAlloc); + packet.extend_from_slice(buf); + + let buff_tkn = AvailBufferToken::new( + { + let mut vec = SmallVec::new(); + vec.push(BufferElem::Vector(packet)); + vec + }, + SmallVec::new(), + ) + .unwrap(); + + vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); } } diff --git a/src/drivers/mmio.rs b/src/drivers/mmio.rs index 7d5926e8b4..fafd7cd372 100644 --- a/src/drivers/mmio.rs +++ b/src/drivers/mmio.rs @@ -39,9 +39,11 @@ pub(crate) fn get_interrupt_handlers() -> HashMap HashMap Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - if let Some(rx_index) = self.next_rx_index() - && let Some(tx_index) = self.next_tx_index() - { - self.reserve_tx_index(tx_index); - - // Starting point to search for next frame - self.rx_counter = (rx_index + 1) % RX_BUF_NUM; - self.rx_fields.rxbuffer_reserved[usize::try_from(rx_index).unwrap()] = true; - - Some(( - RxToken { - buffer_index: rx_index, - rx_fields: &mut self.rx_fields, - }, - TxToken { - buffer_index: tx_index, - tx_fields: &mut self.tx_fields, - }, - )) - } else { - None - } + let rx_index = self.next_rx_index()?; + let tx_index = self.next_tx_index()?; + + self.reserve_tx_index(tx_index); + + // Starting point to search for next frame + self.rx_counter = (rx_index + 1) % RX_BUF_NUM; + self.rx_fields.rxbuffer_reserved[usize::try_from(rx_index).unwrap()] = true; + + let rx_token = RxToken { + buffer_index: rx_index, + rx_fields: &mut self.rx_fields, + }; + let tx_token = TxToken { + buffer_index: tx_index, + tx_fields: &mut self.tx_fields, + }; + Some((rx_token, tx_token)) } fn transmit(&mut self, timestamp: smoltcp::time::Instant) -> Option> { self.handle_interrupt(); diff --git a/src/drivers/net/loopback.rs b/src/drivers/net/loopback.rs index 9a2ab1be67..516c49ba7a 100644 --- a/src/drivers/net/loopback.rs +++ b/src/drivers/net/loopback.rs @@ -79,18 +79,18 @@ impl smoltcp::phy::Device for LoopbackDriver { type TxToken<'a> = TxToken<'a>; fn receive(&mut self, _: Instant) -> Option<(RxToken<'_>, TxToken<'_>)> { - if self.queue.lock().len() > self.reserved_receives.load(Ordering::Relaxed) { - self.reserved_receives.fetch_add(1, Ordering::Relaxed); - Some(( - RxToken { - queue: &self.queue, - reserved_receives: &self.reserved_receives, - }, - TxToken { queue: &self.queue }, - )) - } else { - None + if self.queue.lock().len() <= self.reserved_receives.load(Ordering::Relaxed) { + return None; } + + self.reserved_receives.fetch_add(1, Ordering::Relaxed); + Some(( + RxToken { + queue: &self.queue, + reserved_receives: &self.reserved_receives, + }, + TxToken { queue: &self.queue }, + )) } fn transmit(&mut self, _: Instant) -> Option> { diff --git a/src/drivers/net/mod.rs b/src/drivers/net/mod.rs index d8c4379af5..7a30e65ead 100644 --- a/src/drivers/net/mod.rs +++ b/src/drivers/net/mod.rs @@ -50,11 +50,11 @@ pub(crate) fn mtu() -> u16 { /// This is 1500 IP MTU and a 14-byte ethernet header. const DEFAULT_MTU: u16 = DEFAULT_IP_MTU + 14; - if let Some(my_mtu) = hermit_var!("HERMIT_MTU") { - u16::from_str(&my_mtu).unwrap() - } else { - DEFAULT_MTU - } + let Some(my_mtu) = hermit_var!("HERMIT_MTU") else { + return DEFAULT_MTU; + }; + + u16::from_str(&my_mtu).unwrap() } cfg_if::cfg_if! { diff --git a/src/drivers/net/rtl8139.rs b/src/drivers/net/rtl8139.rs index 732e3ef48e..1afcf757c1 100644 --- a/src/drivers/net/rtl8139.rs +++ b/src/drivers/net/rtl8139.rs @@ -578,42 +578,46 @@ impl smoltcp::phy::Device for RTL8139Driver { type TxToken<'a> = TxToken<'a>; fn receive(&mut self, _: smoltcp::time::Instant) -> Option<(RxToken<'_>, TxToken<'_>)> { - if !self.rx_fields.rx_in_use && self.has_packet() { - self.rx_fields.rx_in_use = true; - let regs = self.regs.as_mut_ptr(); - - Some(( - RxToken { - capr: map_field!(regs.capr), - rx_fields: &mut self.rx_fields, - }, - TxToken { - tsd0: map_field!(regs.tsd0), - tsd1: map_field!(regs.tsd1), - tsd2: map_field!(regs.tsd2), - tsd3: map_field!(regs.tsd3), - tx_fields: &mut self.tx_fields, - }, - )) - } else { - None + if self.rx_fields.rx_in_use { + return None; } - } - fn transmit(&mut self, _: smoltcp::time::Instant) -> Option> { - if self.tx_fields.remaining_bufs > 0 { - let regs = self.regs.as_mut_ptr(); + if !self.has_packet() { + return None; + } - Some(TxToken { + self.rx_fields.rx_in_use = true; + let regs = self.regs.as_mut_ptr(); + + Some(( + RxToken { + capr: map_field!(regs.capr), + rx_fields: &mut self.rx_fields, + }, + TxToken { tsd0: map_field!(regs.tsd0), tsd1: map_field!(regs.tsd1), tsd2: map_field!(regs.tsd2), tsd3: map_field!(regs.tsd3), tx_fields: &mut self.tx_fields, - }) - } else { - None + }, + )) + } + + fn transmit(&mut self, _: smoltcp::time::Instant) -> Option> { + if self.tx_fields.remaining_bufs == 0 { + return None; } + + let regs = self.regs.as_mut_ptr(); + + Some(TxToken { + tsd0: map_field!(regs.tsd0), + tsd1: map_field!(regs.tsd1), + tsd2: map_field!(regs.tsd2), + tsd3: map_field!(regs.tsd3), + tx_fields: &mut self.tx_fields, + }) } fn capabilities(&self) -> smoltcp::phy::DeviceCapabilities { @@ -746,11 +750,13 @@ pub(crate) fn init_device( let mut regs = None; for i in 0..MAX_BARS { - if let Some(Bar::Memory32 { .. }) = device.get_bar(i.try_into().unwrap()) { - let (addr, _size) = device.memory_map_bar(i.try_into().unwrap(), true).unwrap(); + let Some(Bar::Memory32 { .. }) = device.get_bar(i.try_into().unwrap()) else { + continue; + }; - regs = Some(unsafe { VolatileRef::new(NonNull::new(addr.as_mut_ptr()).unwrap()) }); - } + let (addr, _size) = device.memory_map_bar(i.try_into().unwrap(), true).unwrap(); + + regs = Some(unsafe { VolatileRef::new(NonNull::new(addr.as_mut_ptr()).unwrap()) }); } let mut regs = regs.ok_or(DriverError::InitRTL8139DevFail(RTL8139Error::Unknown))?; diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index d2301b3fcf..2769d948ee 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -439,39 +439,43 @@ impl smoltcp::phy::Device for VirtioNetDriver { &mut self, _timestamp: smoltcp::time::Instant, ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - if self.inner.recv_vqs.has_packet() && { - self.free_up_send_capacity(); - self.inner.send_capacity >= u32::from(BUFF_PER_PACKET) - } { - self.inner.send_capacity -= u32::from(BUFF_PER_PACKET); - Some(( - RxToken { - recv_vqs: &mut self.inner.recv_vqs, - is_mrg_rxbuf_enabled: self.dev_cfg.features.contains(virtio::net::F::MRG_RXBUF), - }, - TxToken { - send_vqs: &mut self.inner.send_vqs, - checksums: self.checksums.clone(), - send_capacity: &mut self.inner.send_capacity, - }, - )) - } else { - None + if !self.inner.recv_vqs.has_packet() { + return None; } - } - fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { self.free_up_send_capacity(); - if self.inner.send_capacity >= u32::from(BUFF_PER_PACKET) { - self.inner.send_capacity -= u32::from(BUFF_PER_PACKET); - Some(TxToken { + + self.inner.send_capacity = self + .inner + .send_capacity + .checked_sub(u32::from(BUFF_PER_PACKET))?; + + Some(( + RxToken { + recv_vqs: &mut self.inner.recv_vqs, + is_mrg_rxbuf_enabled: self.dev_cfg.features.contains(virtio::net::F::MRG_RXBUF), + }, + TxToken { send_vqs: &mut self.inner.send_vqs, checksums: self.checksums.clone(), send_capacity: &mut self.inner.send_capacity, - }) - } else { - None - } + }, + )) + } + + fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + self.free_up_send_capacity(); + + self.inner.send_capacity = self + .inner + .send_capacity + .checked_sub(u32::from(BUFF_PER_PACKET))?; + + Some(TxToken { + send_vqs: &mut self.inner.send_vqs, + checksums: self.checksums.clone(), + send_capacity: &mut self.inner.send_capacity, + }) } } diff --git a/src/drivers/pci.rs b/src/drivers/pci.rs index d3881ba49c..cdf99eef10 100644 --- a/src/drivers/pci.rs +++ b/src/drivers/pci.rs @@ -75,11 +75,8 @@ impl PciDevice { /// Returns the bar at bar-register `slot`. pub fn get_bar(&self, slot: u8) -> Option { let header = self.header(); - if let Some(endpoint) = EndpointHeader::from_header(header, &self.access) { - return endpoint.bar(slot, &self.access); - } - - None + let endpoint = EndpointHeader::from_header(header, &self.access)?; + endpoint.bar(slot, &self.access) } /// Configure the bar at register `slot` @@ -375,9 +372,11 @@ impl PciDriver { #[cfg(feature = "virtio-vsock")] Self::VirtioVsock(drv) => { fn vsock_handler() { - if let Some(driver) = get_vsock_driver() { - driver.lock().handle_interrupt(); - } + let Some(driver) = get_vsock_driver() else { + return; + }; + + driver.lock().handle_interrupt(); } let irq_number = drv.lock().get_interrupt_number(); @@ -387,9 +386,11 @@ impl PciDriver { #[cfg(feature = "virtio-fs")] Self::VirtioFs(drv) => { fn fuse_handler() { - if let Some(driver) = get_filesystem_driver() { - driver.lock().handle_interrupt(); - } + let Some(driver) = get_filesystem_driver() else { + return; + }; + + driver.lock().handle_interrupt(); } let irq_number = drv.lock().get_interrupt_number(); @@ -399,9 +400,11 @@ impl PciDriver { #[cfg(feature = "virtio-console")] Self::VirtioConsole(drv) => { fn console_handler() { - if let Some(driver) = get_console_driver() { - driver.lock().handle_interrupt(); - } + let Some(driver) = get_console_driver() else { + return; + }; + + driver.lock().handle_interrupt(); } let irq_number = drv.lock().get_interrupt_number(); diff --git a/src/drivers/virtio/transport/mmio.rs b/src/drivers/virtio/transport/mmio.rs index b5e8b72cfd..90ef28f6c7 100644 --- a/src/drivers/virtio/transport/mmio.rs +++ b/src/drivers/virtio/transport/mmio.rs @@ -120,13 +120,13 @@ impl ComCfg { /// INFO: The queue size is automatically bounded by constant `src::config:VIRTIO_MAX_QUEUE_SIZE`. pub fn select_vq(&mut self, index: u16) -> Option> { if self.get_max_queue_size(index) == 0 { - None - } else { - Some(VqCfgHandler { - vq_index: index, - raw: self.com_cfg.borrow_mut(), - }) + return None; } + + Some(VqCfgHandler { + vq_index: index, + raw: self.com_cfg.borrow_mut(), + }) } pub fn get_max_queue_size(&mut self, sel: u16) -> u16 { diff --git a/src/drivers/virtio/transport/pci.rs b/src/drivers/virtio/transport/pci.rs index ae4a345a53..a16aba68a6 100644 --- a/src/drivers/virtio/transport/pci.rs +++ b/src/drivers/virtio/transport/pci.rs @@ -281,13 +281,13 @@ impl ComCfg { self.com_cfg.as_mut_ptr().queue_select().write(index.into()); if self.com_cfg.as_mut_ptr().queue_size().read().to_ne() == 0 { - None - } else { - Some(VqCfgHandler { - vq_index: index, - raw: self.com_cfg.borrow_mut(), - }) + return None; } + + Some(VqCfgHandler { + vq_index: index, + raw: self.com_cfg.borrow_mut(), + }) } #[allow(dead_code)] @@ -540,10 +540,10 @@ fn read_caps(device: &PciDevice) -> Result, PciErro if capabilities.is_empty() { error!("No virtio capability found for device {device_id:x}"); - Err(PciError::NoVirtioCaps(device_id)) - } else { - Ok(capabilities) + return Err(PciError::NoVirtioCaps(device_id)); } + + Ok(capabilities) } pub(crate) fn map_caps(device: &PciDevice) -> Result { diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index 00a0c5cd3d..b02c35f7e3 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -392,19 +392,16 @@ impl UsedDeviceWritableBuffer { self.elems.remove(0) }; - if let BufferElem::Sized(sized) = elem { - match sized.downcast::>() { - Ok(cast) => { - self.remaining_written_len -= u32::try_from(size_of::()).unwrap(); - Some(unsafe { cast.assume_init() }) - } - Err(_) => { - panic!("Attempted to downcast element to wrong type"); - } - } - } else { + let BufferElem::Sized(sized) = elem else { panic!("Attempted to pop elements in order different from insertion"); - } + }; + + let cast = sized + .downcast::>() + .expect("Attempted to downcast element to wrong type"); + + self.remaining_written_len -= u32::try_from(size_of::()).unwrap(); + Some(unsafe { cast.assume_init() }) } pub fn pop_front_vec(&mut self) -> Option> { @@ -418,17 +415,17 @@ impl UsedDeviceWritableBuffer { self.elems.remove(0) }; - if let BufferElem::Vector(mut vector) = elem { - let new_len = u32::min( - vector.capacity().try_into().unwrap(), - self.remaining_written_len, - ); - self.remaining_written_len -= new_len; - unsafe { vector.set_len(new_len.try_into().unwrap()) }; - Some(vector) - } else { + let BufferElem::Vector(mut vector) = elem else { panic!("Attempted to pop elements in order different from insertion"); - } + }; + + let new_len = u32::min( + vector.capacity().try_into().unwrap(), + self.remaining_written_len, + ); + self.remaining_written_len -= new_len; + unsafe { vector.set_len(new_len.try_into().unwrap()) }; + Some(vector) } /// It is possible for devices to use descriptors for a type other than what they were meant. @@ -445,14 +442,14 @@ impl UsedDeviceWritableBuffer { self.elems.remove(0) }; - if let BufferElem::Sized(sized) = elem { - let capacity = u32::try_from(size_of_val(sized.as_ref())).unwrap(); - let len = u32::min(capacity, self.remaining_written_len); - self.remaining_written_len -= len; - Some((sized, len.try_into().unwrap())) - } else { + let BufferElem::Sized(sized) = elem else { panic!("This function is meant for the Sized variant of the BufferElem enum."); - } + }; + + let capacity = u32::try_from(size_of_val(sized.as_ref())).unwrap(); + let len = u32::min(capacity, self.remaining_written_len); + self.remaining_written_len -= len; + Some((sized, len.try_into().unwrap())) } } diff --git a/src/drivers/virtio/virtqueue/packed.rs b/src/drivers/virtio/virtqueue/packed.rs index 068e07e6ee..d9d11de87d 100644 --- a/src/drivers/virtio/virtqueue/packed.rs +++ b/src/drivers/virtio/virtqueue/packed.rs @@ -120,19 +120,16 @@ impl DescriptorRing { // Catch empty push, in order to allow zero initialized first_ctrl_settings struct // which will be overwritten in the first iteration of the for-loop - let first_ctrl_settings; - let first_buffer; - let mut ctrl; - let mut tkn_iterator = tkn_lst.into_iter(); - if let Some(first_tkn) = tkn_iterator.next() { - ctrl = self.push_without_making_available(&first_tkn)?; - first_ctrl_settings = (ctrl.start, ctrl.buff_id, ctrl.first_flags); - first_buffer = first_tkn; - } else { + let Some(first_tkn) = tkn_iterator.next() else { // Empty batches are an error return Err(VirtqError::BufferNotSpecified); - } + }; + + let mut ctrl = self.push_without_making_available(&first_tkn)?; + let first_ctrl_settings = (ctrl.start, ctrl.buff_id, ctrl.first_flags); + let first_buffer = first_tkn; + // Push the remaining tokens (if any) for tkn in tkn_iterator { ctrl.make_avail(tkn); @@ -276,46 +273,46 @@ impl ReadCtrl<'_> { fn poll_next(&mut self) -> Option<(TransferToken, u32)> { // Check if descriptor has been marked used. let desc = &self.desc_ring.ring[usize::from(self.position)]; - if self.desc_ring.is_marked_used(desc.flags) { - let buff_id = desc.id.to_ne(); - let tkn = self.desc_ring.tkn_ref_ring[usize::from(buff_id)] - .take() - .expect( - "The buff_id is incorrect or the reference to the TransferToken was misplaced.", - ); - - // Retrieve if any has been written to the queue. If this is the case, we calculate the overall length - // This is necessary in order to provide the drivers with the correct access, to usable data. - // - // According to the standard the device signals solely via the first written descriptor if anything has been written to - // the write descriptors of a buffer. - // See Virtio specification v1.1. - 2.7.4 - // - 2.7.5 - // - 2.7.6 - // let mut write_len = if self.desc_ring.ring[self.position].flags & DescrFlags::VIRTQ_DESC_F_WRITE == DescrFlags::VIRTQ_DESC_F_WRITE { - // self.desc_ring.ring[self.position].len - // } else { - // 0 - // }; - // - // INFO: - // Due to the behavior of the currently used devices and the virtio code from the linux kernel, we assume, that device do NOT set this - // flag correctly upon writes. Hence we omit it, in order to receive data. - - // We need to read the written length before advancing the position. - let write_len = desc.len.to_ne(); - - for _ in 0..tkn.num_consuming_descr() { - self.incrmt(); - } - unsafe { - self.desc_ring.indexes.deallocate(buff_id.into()); - } + if !self.desc_ring.is_marked_used(desc.flags) { + return None; + } - Some((tkn, write_len)) - } else { - None + let buff_id = desc.id.to_ne(); + let tkn = self.desc_ring.tkn_ref_ring[usize::from(buff_id)] + .take() + .expect( + "The buff_id is incorrect or the reference to the TransferToken was misplaced.", + ); + + // Retrieve if any has been written to the queue. If this is the case, we calculate the overall length + // This is necessary in order to provide the drivers with the correct access, to usable data. + // + // According to the standard the device signals solely via the first written descriptor if anything has been written to + // the write descriptors of a buffer. + // See Virtio specification v1.1. - 2.7.4 + // - 2.7.5 + // - 2.7.6 + // let mut write_len = if self.desc_ring.ring[self.position].flags & DescrFlags::VIRTQ_DESC_F_WRITE == DescrFlags::VIRTQ_DESC_F_WRITE { + // self.desc_ring.ring[self.position].len + // } else { + // 0 + // }; + // + // INFO: + // Due to the behavior of the currently used devices and the virtio code from the linux kernel, we assume, that device do NOT set this + // flag correctly upon writes. Hence we omit it, in order to receive data. + + // We need to read the written length before advancing the position. + let write_len = desc.len.to_ne(); + + for _ in 0..tkn.num_consuming_descr() { + self.incrmt(); + } + unsafe { + self.desc_ring.indexes.deallocate(buff_id.into()); } + + Some((tkn, write_len)) } fn incrmt(&mut self) { @@ -670,9 +667,9 @@ impl PackedVq { } // Get a handler to the queues configuration area. - let Some(mut vq_handler) = com_cfg.select_vq(index) else { - return Err(VirtqError::QueueNotExisting(index)); - }; + let mut vq_handler = com_cfg + .select_vq(index) + .ok_or(VirtqError::QueueNotExisting(index))?; // Must catch zero size as it is not allowed for packed queues. // Must catch size larger 0x8000 (2^15) as it is not allowed for packed queues. diff --git a/src/drivers/virtio/virtqueue/split.rs b/src/drivers/virtio/virtqueue/split.rs index c6892d56f8..cc584053bb 100644 --- a/src/drivers/virtio/virtqueue/split.rs +++ b/src/drivers/virtio/virtqueue/split.rs @@ -250,9 +250,9 @@ impl SplitVq { features: virtio::F, ) -> Result { // Get a handler to the queues configuration area. - let Some(mut vq_handler) = com_cfg.select_vq(index) else { - return Err(VirtqError::QueueNotExisting(index)); - }; + let mut vq_handler = com_cfg + .select_vq(index) + .ok_or(VirtqError::QueueNotExisting(index))?; let size = vq_handler.set_vq_size(max_size); diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 33a254b2aa..83253c4d7d 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -79,15 +79,19 @@ impl RxQueue { } pub fn enable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.enable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.enable_notifs(); } pub fn disable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.disable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.disable_notifs(); } fn get_next(&mut self) -> Option { @@ -107,13 +111,11 @@ impl RxQueue { }; let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); - if let Some(ref mut vq) = self.vq { - f(&header, &packet[..]); + let vq = self.vq.as_mut().expect("Invalid length of receive queue"); - fill_queue(vq, 1, self.packet_size); - } else { - panic!("Invalid length of receive queue"); - } + f(&header, &packet[..]); + + fill_queue(vq, 1, self.packet_size); } } } @@ -138,21 +140,27 @@ impl TxQueue { } pub fn enable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.enable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.enable_notifs(); } pub fn disable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.disable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.disable_notifs(); } fn poll(&mut self) { - if let Some(ref mut vq) = self.vq { - while vq.try_recv().is_ok() {} - } + let Some(vq) = &mut self.vq else { + return; + }; + + while vq.try_recv().is_ok() {} } /// Provides a slice to copy the packet and transfer the packet @@ -165,31 +173,29 @@ impl TxQueue { // We need to poll to get the queue to remove elements from the table and make space for // what we are about to add self.poll(); - if let Some(ref mut vq) = self.vq { - assert!(len < usize::try_from(self.packet_length).unwrap()); - let mut packet = Vec::with_capacity_in(len, DeviceAlloc); - let result = unsafe { - let result = f(packet.spare_capacity_mut().assume_init_mut()); - packet.set_len(len); - result - }; + let vq = self.vq.as_mut().expect("Unable to get send queue"); - let buff_tkn = AvailBufferToken::new( - { - let mut vec = SmallVec::new(); - vec.push(BufferElem::Vector(packet)); - vec - }, - SmallVec::new(), - ) - .unwrap(); + assert!(len < usize::try_from(self.packet_length).unwrap()); + let mut packet = Vec::with_capacity_in(len, DeviceAlloc); + let result = unsafe { + let result = f(packet.spare_capacity_mut().assume_init_mut()); + packet.set_len(len); + result + }; - vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); + let buff_tkn = AvailBufferToken::new( + { + let mut vec = SmallVec::new(); + vec.push(BufferElem::Vector(packet)); + vec + }, + SmallVec::new(), + ) + .unwrap(); - result - } else { - panic!("Unable to get send queue"); - } + vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); + + result } } @@ -217,15 +223,19 @@ impl EventQueue { } pub fn enable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.enable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.enable_notifs(); } pub fn disable_notifs(&mut self) { - if let Some(ref mut vq) = self.vq { - vq.disable_notifs(); - } + let Some(vq) = &mut self.vq else { + return; + }; + + vq.disable_notifs(); } } From 897939e7dc635ad68dc8c66d25c6b9299cc30956 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:34:54 +0100 Subject: [PATCH 3/7] refactor(fs): do more early returns --- src/fs/fuse.rs | 557 ++++++++++++++++++++++++------------------------ src/fs/mem.rs | 297 ++++++++++++-------------- src/fs/mod.rs | 50 ++--- src/fs/uhyve.rs | 89 ++++---- 4 files changed, 490 insertions(+), 503 deletions(-) diff --git a/src/fs/fuse.rs b/src/fs/fuse.rs index 1bfce7bef3..9577c86aaa 100644 --- a/src/fs/fuse.rs +++ b/src/fs/fuse.rs @@ -691,32 +691,36 @@ impl FuseFileHandleInner { let kh = KH.fetch_add(1, Ordering::SeqCst); future::poll_fn(|cx| { - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let (cmd, rsp_payload_len) = ops::Poll::create(nid, fh, kh, events); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; + let Some(nid) = self.fuse_nid else { + return Poll::Ready(Ok(PollEvent::POLLERR)); + }; - if rsp.headers.out_header.error < 0 { - Poll::Ready(Err(Errno::Io)) - } else { - let revents = - PollEvent::from_bits(i16::try_from(rsp.headers.op_header.revents).unwrap()) - .unwrap(); - if !revents.intersects(events) - && !revents.intersects( - PollEvent::POLLERR | PollEvent::POLLNVAL | PollEvent::POLLHUP, - ) { - // the current implementation use polling to wait for an event - // consequently, we have to wakeup the waker, if the the event doesn't arrive - cx.waker().wake_by_ref(); - } - Poll::Ready(Ok(revents)) - } - } else { - Poll::Ready(Ok(PollEvent::POLLERR)) + let Some(fh) = self.fuse_fh else { + return Poll::Ready(Ok(PollEvent::POLLERR)); + }; + + let (cmd, rsp_payload_len) = ops::Poll::create(nid, fh, kh, events); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + + if rsp.headers.out_header.error < 0 { + return Poll::Ready(Err(Errno::Io)); + } + + let revents = + PollEvent::from_bits(i16::try_from(rsp.headers.op_header.revents).unwrap()) + .unwrap(); + if !revents.intersects(events) + && !revents + .intersects(PollEvent::POLLERR | PollEvent::POLLNVAL | PollEvent::POLLHUP) + { + // the current implementation use polling to wait for an event + // consequently, we have to wakeup the waker, if the the event doesn't arrive + cx.waker().wake_by_ref(); } + Poll::Ready(Ok(revents)) }) .await } @@ -731,24 +735,23 @@ impl FuseFileHandleInner { // position instead. match whence { SeekWhence::End | SeekWhence::Data | SeekWhence::Hole => { - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let (cmd, rsp_payload_len) = ops::Lseek::create(nid, fh, offset, whence); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - - if rsp.headers.out_header.error < 0 { - return Err(Errno::Io); - } - - let rsp_offset = rsp.headers.op_header.offset; - self.offset = rsp.headers.op_header.offset.try_into().unwrap(); - - Ok(rsp_offset.try_into().unwrap()) - } else { - Err(Errno::Io) + let nid = self.fuse_nid.ok_or(Errno::Io)?; + let fh = self.fuse_fh.ok_or(Errno::Io)?; + + let (cmd, rsp_payload_len) = ops::Lseek::create(nid, fh, offset, whence); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + + if rsp.headers.out_header.error < 0 { + return Err(Errno::Io); } + + let rsp_offset = rsp.headers.op_header.offset; + self.offset = rsp.headers.op_header.offset.try_into().unwrap(); + + Ok(rsp_offset.try_into().unwrap()) } SeekWhence::Set => { self.offset = offset.try_into().map_err(|_e| Errno::Inval)?; @@ -765,36 +768,40 @@ impl FuseFileHandleInner { fn fstat(&mut self) -> io::Result { debug!("FUSE getattr"); - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let (cmd, rsp_payload_len) = ops::Getattr::create(nid, fh, FUSE_GETATTR_FH); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - if rsp.headers.out_header.error < 0 { - return Err(Errno::Io); - } - Ok(rsp.headers.op_header.attr.into()) - } else { - Err(Errno::Io) + + let nid = self.fuse_nid.ok_or(Errno::Io)?; + let fh = self.fuse_fh.ok_or(Errno::Io)?; + + let (cmd, rsp_payload_len) = ops::Getattr::create(nid, fh, FUSE_GETATTR_FH); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + + if rsp.headers.out_header.error < 0 { + return Err(Errno::Io); } + + Ok(rsp.headers.op_header.attr.into()) } fn set_attr(&mut self, attr: FileAttr, valid: SetAttrValidFields) -> io::Result { debug!("FUSE setattr"); - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let (cmd, rsp_payload_len) = ops::Setattr::create(nid, fh, attr, valid); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - if rsp.headers.out_header.error < 0 { - return Err(Errno::Io); - } - Ok(rsp.headers.op_header.attr.into()) - } else { - Err(Errno::Io) + + let nid = self.fuse_nid.ok_or(Errno::Io)?; + let fh = self.fuse_fh.ok_or(Errno::Io)?; + + let (cmd, rsp_payload_len) = ops::Setattr::create(nid, fh, attr, valid); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + + if rsp.headers.out_header.error < 0 { + return Err(Errno::Io); } + + Ok(rsp.headers.op_header.attr.into()) } } @@ -809,29 +816,27 @@ impl Read for FuseFileHandleInner { debug!("Reading longer than max_read_len: {len}"); len = MAX_READ_LEN; } - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let (cmd, rsp_payload_len) = - ops::Read::create(nid, fh, len.try_into().unwrap(), self.offset as u64); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - let len: usize = - if (rsp.headers.out_header.len as usize) - mem::size_of::() >= len - { - len - } else { - (rsp.headers.out_header.len as usize) - mem::size_of::() - }; - self.offset += len; - - buf[..len].copy_from_slice(&rsp.payload.unwrap()[..len]); - - Ok(len) - } else { - debug!("File not open, cannot read!"); - Err(Errno::Noent) - } + + let nid = self.fuse_nid.ok_or(Errno::Io)?; + let fh = self.fuse_fh.ok_or(Errno::Io)?; + + let (cmd, rsp_payload_len) = + ops::Read::create(nid, fh, len.try_into().unwrap(), self.offset as u64); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + let len: usize = + if (rsp.headers.out_header.len as usize) - mem::size_of::() >= len { + len + } else { + (rsp.headers.out_header.len as usize) - mem::size_of::() + }; + self.offset += len; + + buf[..len].copy_from_slice(&rsp.payload.unwrap()[..len]); + + Ok(len) } } @@ -847,31 +852,29 @@ impl Write for FuseFileHandleInner { ); truncated_len = MAX_WRITE_LEN; } - if let (Some(nid), Some(fh)) = (self.fuse_nid, self.fuse_fh) { - let truncated_buf = Box::<[u8]>::from(&buf[..truncated_len]); - let (cmd, rsp_payload_len) = - ops::Write::create(nid, fh, truncated_buf, self.offset as u64); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - if rsp.headers.out_header.error < 0 { - return Err(Errno::Io); - } + let nid = self.fuse_nid.ok_or(Errno::Io)?; + let fh = self.fuse_fh.ok_or(Errno::Io)?; - let rsp_size = rsp.headers.op_header.size; - let rsp_len: usize = if rsp_size > u32::try_from(truncated_len).unwrap() { - truncated_len - } else { - rsp_size.try_into().unwrap() - }; - self.offset += rsp_len; - Ok(rsp_len) - } else { - warn!("File not open, cannot read!"); - Err(Errno::Noent) + let truncated_buf = Box::<[u8]>::from(&buf[..truncated_len]); + let (cmd, rsp_payload_len) = ops::Write::create(nid, fh, truncated_buf, self.offset as u64); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + + if rsp.headers.out_header.error < 0 { + return Err(Errno::Io); } + + let rsp_size = rsp.headers.op_header.size; + let rsp_len: usize = if rsp_size > u32::try_from(truncated_len).unwrap() { + truncated_len + } else { + rsp_size.try_into().unwrap() + }; + self.offset += rsp_len; + Ok(rsp_len) } fn flush(&mut self) -> Result<(), Self::Error> { @@ -881,16 +884,20 @@ impl Write for FuseFileHandleInner { impl Drop for FuseFileHandleInner { fn drop(&mut self) { - if let Some(fuse_nid) = self.fuse_nid - && let Some(fuse_fh) = self.fuse_fh - { - let (cmd, rsp_payload_len) = ops::Release::create(fuse_nid, fuse_fh); - get_filesystem_driver() - .unwrap() - .lock() - .send_command(cmd, rsp_payload_len) - .unwrap(); - } + let Some(fuse_nid) = self.fuse_nid else { + return; + }; + + let Some(fuse_fh) = self.fuse_fh else { + return; + }; + + let (cmd, rsp_payload_len) = ops::Release::create(fuse_nid, fuse_fh); + get_filesystem_driver() + .unwrap() + .lock() + .send_command(cmd, rsp_payload_len) + .unwrap(); } } @@ -1289,58 +1296,58 @@ impl VfsNode for FuseDirectory { .send_command(cmd, rsp_payload_len)?; let attr = FileAttr::from(rsp.headers.op_header.attr); - if attr.st_mode.contains(AccessPermission::S_IFDIR) { - let mut path = path.into_string().unwrap(); - path.remove(0); - Ok(Arc::new(async_lock::RwLock::new(FuseDirectoryHandle::new( - Some(path), - )))) - } else { - Err(Errno::Notdir) + if !attr.st_mode.contains(AccessPermission::S_IFDIR) { + return Err(Errno::Notdir); } - } else { - let file = FuseFileHandle::new(); - // 1.FUSE_INIT to create session - // Already done - let mut file_guard = block_on(async { Ok(file.0.lock().await) }, None)?; + let mut path = path.into_string().unwrap(); + path.remove(0); + return Ok(Arc::new(async_lock::RwLock::new(FuseDirectoryHandle::new( + Some(path), + )))); + } - // Differentiate between opening and creating new file, since fuse does not support O_CREAT on open. - if opt.contains(OpenOption::O_CREAT) { - // Create file (opens implicitly, returns results from both lookup and open calls) - let (cmd, rsp_payload_len) = - ops::Create::create(path, opt.bits().try_into().unwrap(), mode.bits()); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; + let file = FuseFileHandle::new(); - let inner = rsp.headers.op_header; - file_guard.fuse_nid = Some(inner.entry.nodeid); - file_guard.fuse_fh = Some(inner.open.fh); - } else { - // 2.FUSE_LOOKUP(FUSE_ROOT_ID, ā€œfooā€) -> nodeid - file_guard.fuse_nid = lookup(path); + // 1.FUSE_INIT to create session + // Already done + let mut file_guard = block_on(async { Ok(file.0.lock().await) }, None)?; - if file_guard.fuse_nid.is_none() { - warn!("Fuse lookup seems to have failed!"); - return Err(Errno::Noent); - } + // Differentiate between opening and creating new file, since fuse does not support O_CREAT on open. + if opt.contains(OpenOption::O_CREAT) { + // Create file (opens implicitly, returns results from both lookup and open calls) + let (cmd, rsp_payload_len) = + ops::Create::create(path, opt.bits().try_into().unwrap(), mode.bits()); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; - // 3.FUSE_OPEN(nodeid, O_RDONLY) -> fh - let (cmd, rsp_payload_len) = - ops::Open::create(file_guard.fuse_nid.unwrap(), opt.bits().try_into().unwrap()); - let rsp = get_filesystem_driver() - .ok_or(Errno::Nosys)? - .lock() - .send_command(cmd, rsp_payload_len)?; - file_guard.fuse_fh = Some(rsp.headers.op_header.fh); - } + let inner = rsp.headers.op_header; + file_guard.fuse_nid = Some(inner.entry.nodeid); + file_guard.fuse_fh = Some(inner.open.fh); + } else { + // 2.FUSE_LOOKUP(FUSE_ROOT_ID, ā€œfooā€) -> nodeid + file_guard.fuse_nid = lookup(path); - drop(file_guard); + if file_guard.fuse_nid.is_none() { + warn!("Fuse lookup seems to have failed!"); + return Err(Errno::Noent); + } - Ok(Arc::new(async_lock::RwLock::new(file))) + // 3.FUSE_OPEN(nodeid, O_RDONLY) -> fh + let (cmd, rsp_payload_len) = + ops::Open::create(file_guard.fuse_nid.unwrap(), opt.bits().try_into().unwrap()); + let rsp = get_filesystem_driver() + .ok_or(Errno::Nosys)? + .lock() + .send_command(cmd, rsp_payload_len)?; + file_guard.fuse_fh = Some(rsp.headers.op_header.fh); } + + drop(file_guard); + + Ok(Arc::new(async_lock::RwLock::new(file))) } fn traverse_unlink(&self, components: &mut Vec<&str>) -> io::Result<()> { @@ -1377,138 +1384,138 @@ impl VfsNode for FuseDirectory { .ok_or(Errno::Nosys)? .lock() .send_command(cmd, rsp_payload_len)?; - if rsp.headers.out_header.error == 0 { - Ok(()) - } else { - Err(Errno::try_from(-rsp.headers.out_header.error).unwrap()) + if rsp.headers.out_header.error != 0 { + return Err(Errno::try_from(-rsp.headers.out_header.error).unwrap()); } + + Ok(()) } } pub(crate) fn init() { debug!("Try to initialize fuse filesystem"); - if let Some(driver) = get_filesystem_driver() { - let (cmd, rsp_payload_len) = ops::Init::create(); - let rsp = driver.lock().send_command(cmd, rsp_payload_len).unwrap(); - trace!("fuse init answer: {rsp:?}"); - - let mount_point = driver.lock().get_mount_point(); - if mount_point == "/" { - let fuse_nid = lookup(c"/".to_owned()).unwrap(); - // Opendir - // Flag 0x10000 for O_DIRECTORY might not be necessary - let (mut cmd, rsp_payload_len) = ops::Open::create(fuse_nid, 0x10000); - cmd.headers.in_header.opcode = fuse_opcode::FUSE_OPENDIR as u32; - let rsp = get_filesystem_driver() - .unwrap() - .lock() - .send_command(cmd, rsp_payload_len) - .unwrap(); - let fuse_fh = rsp.headers.op_header.fh; + let Some(driver) = get_filesystem_driver() else { + return; + }; - // Linux seems to allocate a single page to store the dirfile - let len = MAX_READ_LEN as u32; - let mut offset: usize = 0; + let (cmd, rsp_payload_len) = ops::Init::create(); + let rsp = driver.lock().send_command(cmd, rsp_payload_len).unwrap(); + trace!("fuse init answer: {rsp:?}"); - // read content of the directory - let (mut cmd, rsp_payload_len) = ops::Read::create(fuse_nid, fuse_fh, len, 0); - cmd.headers.in_header.opcode = fuse_opcode::FUSE_READDIR as u32; - let rsp = get_filesystem_driver() - .unwrap() - .lock() - .send_command(cmd, rsp_payload_len) - .unwrap(); + let mount_point = driver.lock().get_mount_point(); + if mount_point != "/" { + let mount_point = if mount_point.starts_with('/') { + mount_point + } else { + "/".to_owned() + &mount_point + }; - let len: usize = if rsp.headers.out_header.len as usize - - mem::size_of::() - >= usize::try_from(len).unwrap() - { - len.try_into().unwrap() - } else { - (rsp.headers.out_header.len as usize) - mem::size_of::() - }; + info!("Mounting virtio-fs at {mount_point}"); + fs::FILESYSTEM + .get() + .unwrap() + .mount(mount_point.as_str(), Box::new(FuseDirectory::new(None))) + .expect("Mount failed. Invalid mount_point?"); + return; + } - assert!(len > mem::size_of::(), "FUSE no new dirs"); - - let mut entries: Vec = Vec::new(); - while (rsp.headers.out_header.len as usize) - offset > mem::size_of::() { - let dirent = unsafe { - &*rsp - .payload - .as_ref() - .unwrap() - .as_ptr() - .byte_add(offset) - .cast::() - }; - - offset += mem::size_of::() + dirent.namelen as usize; - // Align to dirent struct - offset = ((offset) + U64_SIZE - 1) & (!(U64_SIZE - 1)); - - let name: &'static [u8] = unsafe { - slice::from_raw_parts( - dirent.name.as_ptr().cast(), - dirent.namelen.try_into().unwrap(), - ) - }; - entries.push(unsafe { core::str::from_utf8_unchecked(name).to_owned() }); - } + let fuse_nid = lookup(c"/".to_owned()).unwrap(); + // Opendir + // Flag 0x10000 for O_DIRECTORY might not be necessary + let (mut cmd, rsp_payload_len) = ops::Open::create(fuse_nid, 0x10000); + cmd.headers.in_header.opcode = fuse_opcode::FUSE_OPENDIR as u32; + let rsp = get_filesystem_driver() + .unwrap() + .lock() + .send_command(cmd, rsp_payload_len) + .unwrap(); + let fuse_fh = rsp.headers.op_header.fh; + + // Linux seems to allocate a single page to store the dirfile + let len = MAX_READ_LEN as u32; + let mut offset: usize = 0; + + // read content of the directory + let (mut cmd, rsp_payload_len) = ops::Read::create(fuse_nid, fuse_fh, len, 0); + cmd.headers.in_header.opcode = fuse_opcode::FUSE_READDIR as u32; + let rsp = get_filesystem_driver() + .unwrap() + .lock() + .send_command(cmd, rsp_payload_len) + .unwrap(); + + let len: usize = if rsp.headers.out_header.len as usize - mem::size_of::() + >= usize::try_from(len).unwrap() + { + len.try_into().unwrap() + } else { + (rsp.headers.out_header.len as usize) - mem::size_of::() + }; - let (cmd, rsp_payload_len) = ops::Release::create(fuse_nid, fuse_fh); - get_filesystem_driver() + assert!(len > mem::size_of::(), "FUSE no new dirs"); + + let mut entries: Vec = Vec::new(); + while (rsp.headers.out_header.len as usize) - offset > mem::size_of::() { + let dirent = unsafe { + &*rsp + .payload + .as_ref() .unwrap() - .lock() - .send_command(cmd, rsp_payload_len) - .unwrap(); - - // remove predefined directories - entries.retain(|x| x != "."); - entries.retain(|x| x != ".."); - entries.retain(|x| x != "tmp"); - entries.retain(|x| x != "proc"); - warn!( - "Fuse don't mount the host directories 'tmp' and 'proc' into the guest file system!" - ); + .as_ptr() + .byte_add(offset) + .cast::() + }; - for i in entries { - let i_cstr = CString::new(i.as_str()).unwrap(); - let (cmd, rsp_payload_len) = ops::Lookup::create(i_cstr); - let rsp = get_filesystem_driver() - .unwrap() - .lock() - .send_command(cmd, rsp_payload_len) - .unwrap(); + offset += mem::size_of::() + dirent.namelen as usize; + // Align to dirent struct + offset = ((offset) + U64_SIZE - 1) & (!(U64_SIZE - 1)); - let attr = FileAttr::from(rsp.headers.op_header.attr); - if attr.st_mode.contains(AccessPermission::S_IFDIR) { - info!("Fuse mount {i} to /{i}"); - fs::FILESYSTEM - .get() - .unwrap() - .mount( - &("/".to_owned() + i.as_str()), - Box::new(FuseDirectory::new(Some(i))), - ) - .expect("Mount failed. Invalid mount_point?"); - } else { - warn!("Fuse don't mount {i}. It isn't a directory!"); - } - } - } else { - let mount_point = if mount_point.starts_with('/') { - mount_point - } else { - "/".to_owned() + &mount_point - }; + let name: &'static [u8] = unsafe { + slice::from_raw_parts( + dirent.name.as_ptr().cast(), + dirent.namelen.try_into().unwrap(), + ) + }; + entries.push(unsafe { core::str::from_utf8_unchecked(name).to_owned() }); + } + + let (cmd, rsp_payload_len) = ops::Release::create(fuse_nid, fuse_fh); + get_filesystem_driver() + .unwrap() + .lock() + .send_command(cmd, rsp_payload_len) + .unwrap(); + + // remove predefined directories + entries.retain(|x| x != "."); + entries.retain(|x| x != ".."); + entries.retain(|x| x != "tmp"); + entries.retain(|x| x != "proc"); + warn!("Fuse don't mount the host directories 'tmp' and 'proc' into the guest file system!"); + + for i in entries { + let i_cstr = CString::new(i.as_str()).unwrap(); + let (cmd, rsp_payload_len) = ops::Lookup::create(i_cstr); + let rsp = get_filesystem_driver() + .unwrap() + .lock() + .send_command(cmd, rsp_payload_len) + .unwrap(); - info!("Mounting virtio-fs at {mount_point}"); + let attr = FileAttr::from(rsp.headers.op_header.attr); + if attr.st_mode.contains(AccessPermission::S_IFDIR) { + info!("Fuse mount {i} to /{i}"); fs::FILESYSTEM .get() .unwrap() - .mount(mount_point.as_str(), Box::new(FuseDirectory::new(None))) + .mount( + &("/".to_owned() + i.as_str()), + Box::new(FuseDirectory::new(Some(i))), + ) .expect("Mount failed. Invalid mount_point?"); + } else { + warn!("Fuse don't mount {i}. It isn't a directory!"); } } } diff --git a/src/fs/mem.rs b/src/fs/mem.rs index 0bb470844a..242f9d4b83 100644 --- a/src/fs/mem.rs +++ b/src/fs/mem.rs @@ -96,12 +96,12 @@ impl ObjectInterface for RomFileInterface { _ => return Err(Errno::Inval), }; - if (0..=data_len).contains(&new_pos) { - *pos_guard = new_pos as usize; - Ok(new_pos) - } else { - Err(Errno::Inval) + if !(0..=data_len).contains(&new_pos) { + return Err(Errno::Inval); } + + *pos_guard = new_pos as usize; + Ok(new_pos) } async fn fstat(&self) -> io::Result { @@ -283,19 +283,19 @@ impl VfsNode for RomFile { } fn traverse_lstat(&self, components: &mut Vec<&str>) -> io::Result { - if components.is_empty() { - self.get_file_attributes() - } else { - Err(Errno::Badf) + if !components.is_empty() { + return Err(Errno::Badf); } + + self.get_file_attributes() } fn traverse_stat(&self, components: &mut Vec<&str>) -> io::Result { - if components.is_empty() { - self.get_file_attributes() - } else { - Err(Errno::Badf) + if !components.is_empty() { + return Err(Errno::Badf); } + + self.get_file_attributes() } } @@ -339,19 +339,19 @@ impl VfsNode for RamFile { } fn traverse_lstat(&self, components: &mut Vec<&str>) -> io::Result { - if components.is_empty() { - self.get_file_attributes() - } else { - Err(Errno::Badf) + if !components.is_empty() { + return Err(Errno::Badf); } + + self.get_file_attributes() } fn traverse_stat(&self, components: &mut Vec<&str>) -> io::Result { - if components.is_empty() { - self.get_file_attributes() - } else { - Err(Errno::Badf) + if !components.is_empty() { + return Err(Errno::Badf); } + + self.get_file_attributes() } } @@ -474,38 +474,35 @@ impl MemDirectory { opt: OpenOption, mode: AccessPermission, ) -> io::Result>> { - if let Some(component) = components.pop() { - if components.is_empty() { - let mut guard = self.inner.write().await; - if let Some(file) = guard.get(component) { - if opt.contains(OpenOption::O_DIRECTORY) - && file.get_kind() != NodeKind::Directory - { - return Err(Errno::Notdir); - } - - if file.get_kind() == NodeKind::File || file.get_kind() == NodeKind::Directory { - return file.get_object(); - } else { - return Err(Errno::Noent); - } - } else if opt.contains(OpenOption::O_CREAT) { - let file = Box::new(RamFile::new(mode)); - guard.insert(component.to_owned(), file.clone()); - return Ok(Arc::new(RwLock::new(RamFileInterface::new( - file.data.clone(), - )))); - } else { - return Err(Errno::Noent); - } - } + let component = components.pop().ok_or(Errno::Noent)?; - if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_open(components, opt, mode); + if !components.is_empty() { + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Noent)?; + return directory.traverse_open(components, opt, mode); + } + + let mut inner = self.inner.write().await; + let Some(file) = inner.get(component) else { + if opt.contains(OpenOption::O_CREAT) { + let file = Box::new(RamFile::new(mode)); + inner.insert(component.to_owned(), file.clone()); + let file = Arc::new(RwLock::new(RamFileInterface::new(file.data.clone()))); + return Ok(file); } + + return Err(Errno::Noent); + }; + + if opt.contains(OpenOption::O_DIRECTORY) && file.get_kind() != NodeKind::Directory { + return Err(Errno::Notdir); + } + + if file.get_kind() != NodeKind::File && file.get_kind() != NodeKind::Directory { + return Err(Errno::Noent); } - Err(Errno::Noent) + file.get_object() } } @@ -527,21 +524,21 @@ impl VfsNode for MemDirectory { fn traverse_mkdir(&self, components: &mut Vec<&str>, mode: AccessPermission) -> io::Result<()> { block_on( async { - if let Some(component) = components.pop() { - if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_mkdir(components, mode); - } - - if components.is_empty() { - self.inner - .write() - .await - .insert(component.to_owned(), Box::new(MemDirectory::new(mode))); - return Ok(()); - } + let component = components.pop().ok_or(Errno::Badf)?; + + if let Some(directory) = self.inner.read().await.get(component) { + return directory.traverse_mkdir(components, mode); + } + + if !components.is_empty() { + return Err(Errno::Badf); } - Err(Errno::Badf) + self.inner + .write() + .await + .insert(component.to_owned(), Box::new(MemDirectory::new(mode))); + Ok(()) }, None, ) @@ -550,23 +547,23 @@ impl VfsNode for MemDirectory { fn traverse_rmdir(&self, components: &mut Vec<&str>) -> io::Result<()> { block_on( async { - if let Some(component) = components.pop() { - if components.is_empty() { - let mut guard = self.inner.write().await; - - let obj = guard.remove(component).ok_or(Errno::Noent)?; - if obj.get_kind() == NodeKind::Directory { - return Ok(()); - } else { - guard.insert(component.to_owned(), obj); - return Err(Errno::Notdir); - } - } else if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_rmdir(components); - } + let component = components.pop().ok_or(Errno::Badf)?; + + if !components.is_empty() { + let inner = &*self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Badf)?; + return directory.traverse_rmdir(components); + } + + let mut guard = self.inner.write().await; + + let obj = guard.remove(component).ok_or(Errno::Noent)?; + if obj.get_kind() != NodeKind::Directory { + guard.insert(component.to_owned(), obj); + return Err(Errno::Notdir); } - Err(Errno::Badf) + Ok(()) }, None, ) @@ -575,23 +572,23 @@ impl VfsNode for MemDirectory { fn traverse_unlink(&self, components: &mut Vec<&str>) -> io::Result<()> { block_on( async { - if let Some(component) = components.pop() { - if components.is_empty() { - let mut guard = self.inner.write().await; - - let obj = guard.remove(component).ok_or(Errno::Noent)?; - if obj.get_kind() == NodeKind::File { - return Ok(()); - } else { - guard.insert(component.to_owned(), obj); - return Err(Errno::Isdir); - } - } else if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_unlink(components); - } + let component = components.pop().ok_or(Errno::Badf)?; + + if !components.is_empty() { + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Badf)?; + return directory.traverse_unlink(components); + } + + let mut guard = self.inner.write().await; + + let obj = guard.remove(component).ok_or(Errno::Noent)?; + if obj.get_kind() != NodeKind::File { + guard.insert(component.to_owned(), obj); + return Err(Errno::Isdir); } - Err(Errno::Badf) + Ok(()) }, None, ) @@ -601,19 +598,17 @@ impl VfsNode for MemDirectory { block_on( async { if let Some(component) = components.pop() { - if let Some(directory) = self.inner.read().await.get(component) { - directory.traverse_readdir(components) - } else { - Err(Errno::Badf) - } - } else { - let mut entries: Vec = Vec::new(); - for name in self.inner.read().await.keys() { - entries.push(DirectoryEntry::new(name.clone())); - } - - Ok(entries) + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Badf)?; + return directory.traverse_readdir(components); + }; + + let mut entries = Vec::new(); + for name in self.inner.read().await.keys() { + entries.push(DirectoryEntry::new(name.clone())); } + + Ok(entries) }, None, ) @@ -622,21 +617,17 @@ impl VfsNode for MemDirectory { fn traverse_lstat(&self, components: &mut Vec<&str>) -> io::Result { block_on( async { - if let Some(component) = components.pop() { - if components.is_empty() - && let Some(node) = self.inner.read().await.get(component) - { - return node.get_file_attributes(); - } - - if let Some(directory) = self.inner.read().await.get(component) { - directory.traverse_lstat(components) - } else { - Err(Errno::Badf) - } - } else { - Err(Errno::Nosys) + let component = components.pop().ok_or(Errno::Nosys)?; + + if !components.is_empty() { + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Badf)?; + return directory.traverse_lstat(components); } + + let inner = self.inner.read().await; + let node = inner.get(component).ok_or(Errno::Badf)?; + node.get_file_attributes() }, None, ) @@ -645,21 +636,17 @@ impl VfsNode for MemDirectory { fn traverse_stat(&self, components: &mut Vec<&str>) -> io::Result { block_on( async { - if let Some(component) = components.pop() { - if components.is_empty() - && let Some(node) = self.inner.read().await.get(component) - { - return node.get_file_attributes(); - } - - if let Some(directory) = self.inner.read().await.get(component) { - directory.traverse_stat(components) - } else { - Err(Errno::Badf) - } - } else { - Err(Errno::Nosys) + let component = components.pop().ok_or(Errno::Nosys)?; + + if !components.is_empty() { + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Badf)?; + return directory.traverse_stat(components); } + + let inner = self.inner.read().await; + let node = inner.get(component).ok_or(Errno::Badf)?; + node.get_file_attributes() }, None, ) @@ -668,18 +655,18 @@ impl VfsNode for MemDirectory { fn traverse_mount(&self, components: &mut Vec<&str>, obj: Box) -> io::Result<()> { block_on( async { - if let Some(component) = components.pop() { - if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_mount(components, obj); - } - - if components.is_empty() { - self.inner.write().await.insert(component.to_owned(), obj); - return Ok(()); - } + let component = components.pop().ok_or(Errno::Badf)?; + + if let Some(directory) = self.inner.read().await.get(component) { + return directory.traverse_mount(components, obj); + } + + if !components.is_empty() { + return Err(Errno::Badf); } - Err(Errno::Badf) + self.inner.write().await.insert(component.to_owned(), obj); + Ok(()) }, None, ) @@ -702,22 +689,20 @@ impl VfsNode for MemDirectory { ) -> io::Result<()> { block_on( async { - if let Some(component) = components.pop() { - if components.is_empty() { - let file = RomFile::new(data, mode); - self.inner - .write() - .await - .insert(component.to_owned(), Box::new(file)); - return Ok(()); - } - - if let Some(directory) = self.inner.read().await.get(component) { - return directory.traverse_create_file(components, data, mode); - } + let component = components.pop().ok_or(Errno::Noent)?; + + if !components.is_empty() { + let inner = self.inner.read().await; + let directory = inner.get(component).ok_or(Errno::Noent)?; + return directory.traverse_create_file(components, data, mode); } - Err(Errno::Noent) + let file = RomFile::new(data, mode); + self.inner + .write() + .await + .insert(component.to_owned(), Box::new(file)); + Ok(()) }, None, ) diff --git a/src/fs/mod.rs b/src/fs/mod.rs index b5a14ba0e0..a8d3927609 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -442,31 +442,32 @@ where F: FnOnce(&str) -> io::Result, { if name.starts_with("/") { - callback(name) - } else { - let cwd = WORKING_DIRECTORY.lock(); - if let Some(cwd) = cwd.as_ref() { - let mut path = String::with_capacity(cwd.len() + name.len() + 1); - path.push_str(cwd); - path.push('/'); - path.push_str(name); - - callback(&path) - } else { - // Relative path with no CWD, this is weird/impossible - Err(Errno::Badf) - } + return callback(name); } + + let cwd = WORKING_DIRECTORY.lock(); + + let Some(cwd) = cwd.as_ref() else { + // Relative path with no CWD, this is weird/impossible + return Err(Errno::Badf); + }; + + let mut path = String::with_capacity(cwd.len() + name.len() + 1); + path.push_str(cwd); + path.push('/'); + path.push_str(name); + + callback(&path) } pub fn truncate(name: &str, size: usize) -> io::Result<()> { with_relative_filename(name, |name| { let fs = FILESYSTEM.get().ok_or(Errno::Inval)?; - if let Ok(file) = fs.open(name, OpenOption::O_TRUNC, AccessPermission::empty()) { - block_on(async { file.read().await.truncate(size).await }, None) - } else { - Err(Errno::Badf) - } + let file = fs + .open(name, OpenOption::O_TRUNC, AccessPermission::empty()) + .map_err(|_| Errno::Badf)?; + + block_on(async { file.read().await.truncate(size).await }, None) }) } @@ -488,11 +489,8 @@ pub fn open(name: &str, flags: OpenOption, mode: AccessPermission) -> io::Result pub fn get_cwd() -> io::Result { let cwd = WORKING_DIRECTORY.lock(); - if let Some(cwd) = cwd.as_ref() { - Ok(cwd.clone()) - } else { - Err(Errno::Noent) - } + let cwd = cwd.as_ref().ok_or(Errno::Noent)?; + Ok(cwd.clone()) } pub fn set_cwd(cwd: &str) -> io::Result<()> { @@ -502,9 +500,7 @@ pub fn set_cwd(cwd: &str) -> io::Result<()> { if cwd.starts_with("/") { *working_dir = Some(cwd.to_owned()); } else { - let Some(working_dir) = working_dir.as_mut() else { - return Err(Errno::Badf); - }; + let working_dir = working_dir.as_mut().ok_or(Errno::Badf)?; working_dir.push('/'); working_dir.push_str(cwd); } diff --git a/src/fs/uhyve.rs b/src/fs/uhyve.rs index 91108e1ba7..36abc36b07 100644 --- a/src/fs/uhyve.rs +++ b/src/fs/uhyve.rs @@ -40,11 +40,11 @@ impl UhyveFileHandleInner { }; uhyve_hypercall(Hypercall::FileLseek(&mut lseek_params)); - if lseek_params.offset >= 0 { - Ok(lseek_params.offset) - } else { - Err(Errno::Inval) + if lseek_params.offset < 0 { + return Err(Errno::Inval); } + + Ok(lseek_params.offset) } } @@ -62,11 +62,11 @@ impl Read for UhyveFileHandleInner { }; uhyve_hypercall(Hypercall::FileRead(&mut read_params)); - if read_params.ret >= 0 { - Ok(read_params.ret.try_into().unwrap()) - } else { - Err(Errno::Io) + if read_params.ret < 0 { + return Err(Errno::Io); } + + Ok(read_params.ret.try_into().unwrap()) } } @@ -187,13 +187,13 @@ impl VfsNode for UhyveDirectory { }; uhyve_hypercall(Hypercall::FileOpen(&mut open_params)); - if open_params.ret > 0 { - Ok(Arc::new(async_lock::RwLock::new(UhyveFileHandle::new( - open_params.ret, - )))) - } else { - Err(Errno::Io) + if open_params.ret <= 0 { + return Err(Errno::Io); } + + Ok(Arc::new(async_lock::RwLock::new(UhyveFileHandle::new( + open_params.ret, + )))) } fn traverse_unlink(&self, components: &mut Vec<&str>) -> io::Result<()> { @@ -209,11 +209,11 @@ impl VfsNode for UhyveDirectory { }; uhyve_hypercall(Hypercall::FileUnlink(&mut unlink_params)); - if unlink_params.ret == 0 { - Ok(()) - } else { - Err(Errno::Io) + if unlink_params.ret != 0 { + return Err(Errno::Io); } + + Ok(()) } fn traverse_rmdir(&self, _components: &mut Vec<&str>) -> io::Result<()> { @@ -231,38 +231,14 @@ impl VfsNode for UhyveDirectory { pub(crate) fn init() { info!("Try to initialize uhyve filesystem"); + let mount_str = fdt().and_then(|fdt| { fdt.find_node("/uhyve,mounts") .and_then(|node| node.property("mounts")) .and_then(|property| property.as_str()) }); - if let Some(mount_str) = mount_str { - assert_ne!(mount_str.len(), 0, "Invalid /uhyve,mounts node in FDT"); - for mount_point in mount_str.split('\0') { - info!("Mounting uhyve filesystem at {mount_point}"); - - if let Err(errno) = fs::FILESYSTEM.get().unwrap().mount( - mount_point, - Box::new(UhyveDirectory::new(Some(mount_point.to_owned()))), - ) { - assert_eq!(errno, Errno::Badf); - debug!( - "Mounting of {mount_point} failed with {errno:?}. Creating missing parent folders" - ); - let (parent_path, _file_name) = mount_point.rsplit_once('/').unwrap(); - create_dir_recursive(parent_path, AccessPermission::S_IRWXU).unwrap(); - - fs::FILESYSTEM - .get() - .unwrap() - .mount( - mount_point, - Box::new(UhyveDirectory::new(Some(mount_point.to_owned()))), - ) - .unwrap(); - } - } - } else { + + let Some(mount_str) = mount_str else { // No FDT -> Uhyve legacy mounting (to /root) let mount_point = hermit_var_or!("UHYVE_MOUNT", "/root").to_owned(); info!("Mounting uhyve filesystem at {mount_point}"); @@ -274,5 +250,28 @@ pub(crate) fn init() { Box::new(UhyveDirectory::new(Some(mount_point.clone()))), ) .expect("Mount failed. Duplicate mount_point?"); + return; + }; + + assert_ne!(mount_str.len(), 0, "Invalid /uhyve,mounts node in FDT"); + for mount_point in mount_str.split('\0') { + info!("Mounting uhyve filesystem at {mount_point}"); + + let obj = Box::new(UhyveDirectory::new(Some(mount_point.to_owned()))); + let Err(errno) = fs::FILESYSTEM.get().unwrap().mount(mount_point, obj) else { + return; + }; + + assert_eq!(errno, Errno::Badf); + debug!("Mounting of {mount_point} failed with {errno:?}. Creating missing parent folders"); + let (parent_path, _file_name) = mount_point.rsplit_once('/').unwrap(); + create_dir_recursive(parent_path, AccessPermission::S_IRWXU).unwrap(); + + let obj = Box::new(UhyveDirectory::new(Some(mount_point.to_owned()))); + fs::FILESYSTEM + .get() + .unwrap() + .mount(mount_point, obj) + .unwrap(); } } From a4e9419022e98643b1acde34705e4ea8ec7fbcbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:35:08 +0100 Subject: [PATCH 4/7] refactor(fd): do more early returns --- src/fd/eventfd.rs | 8 +++---- src/fd/mod.rs | 18 ++++++++------- src/fd/socket/tcp.rs | 52 ++++++++++++++++++++++---------------------- src/fd/socket/udp.rs | 50 ++++++++++++++++++++---------------------- 4 files changed, 64 insertions(+), 64 deletions(-) diff --git a/src/fd/eventfd.rs b/src/fd/eventfd.rs index e7c030fcbb..945d09d4f0 100644 --- a/src/fd/eventfd.rs +++ b/src/fd/eventfd.rs @@ -106,11 +106,11 @@ impl ObjectInterface for EventFd { guard.counter += c; if self.flags.contains(EventFlags::EFD_SEMAPHORE) { for _i in 0..c { - if let Some(cx) = guard.read_queue.pop_front() { - cx.wake_by_ref(); - } else { + let Some(cx) = guard.read_queue.pop_front() else { break; - } + }; + + cx.wake_by_ref(); } } else if let Some(cx) = guard.read_queue.pop_front() { cx.wake_by_ref(); diff --git a/src/fd/mod.rs b/src/fd/mod.rs index c2d3ef3a90..6f6fc99ddc 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -377,14 +377,16 @@ async fn poll_fds(fds: &mut [PollFd]) -> io::Result { for i in &mut *fds { let fd = i.fd; i.revents = PollEvent::empty(); - if let Ok(obj) = core_scheduler().get_object(fd) { - let mut pinned = pin!(async { obj.read().await.poll(i.events).await }); - if let Ready(Ok(e)) = pinned.as_mut().poll(cx) - && !e.is_empty() - { - counter += 1; - i.revents = e; - } + let Ok(obj) = core_scheduler().get_object(fd) else { + continue; + }; + + let mut pinned = pin!(async { obj.read().await.poll(i.events).await }); + if let Ready(Ok(e)) = pinned.as_mut().poll(cx) + && !e.is_empty() + { + counter += 1; + i.revents = e; } } diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 423ae8f870..689f6090cf 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -266,38 +266,38 @@ impl ObjectInterface for Socket { async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] - if let ListenEndpoint::Ip(endpoint) = endpoint { - self.endpoint.port = endpoint.port; - if let Some(addr) = endpoint.addr { - self.endpoint.addr = addr; - } - Ok(()) - } else { - Err(Errno::Io) + let ListenEndpoint::Ip(endpoint) = endpoint else { + return Err(Errno::Io); + }; + + self.endpoint.port = endpoint.port; + if let Some(addr) = endpoint.addr { + self.endpoint.addr = addr; } + Ok(()) } async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] - if let Endpoint::Ip(endpoint) = endpoint { - self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port())) - .map_err(|_| Errno::Io)?; - - future::poll_fn(|cx| { - self.with(|socket| match socket.state() { - tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(Errno::Fault)), - tcp::State::Listen => Poll::Ready(Err(Errno::Io)), - tcp::State::SynSent | tcp::State::SynReceived => { - socket.register_send_waker(cx.waker()); - Poll::Pending - } - _ => Poll::Ready(Ok(())), - }) + let Endpoint::Ip(endpoint) = endpoint else { + return Err(Errno::Io); + }; + + self.with_context(|socket, cx| socket.connect(cx, endpoint, get_ephemeral_port())) + .map_err(|_| Errno::Io)?; + + future::poll_fn(|cx| { + self.with(|socket| match socket.state() { + tcp::State::Closed | tcp::State::TimeWait => Poll::Ready(Err(Errno::Fault)), + tcp::State::Listen => Poll::Ready(Err(Errno::Io)), + tcp::State::SynSent | tcp::State::SynReceived => { + socket.register_send_waker(cx.waker()); + Poll::Pending + } + _ => Poll::Ready(Ok(())), }) - .await - } else { - Err(Errno::Io) - } + }) + .await } async fn accept( diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 8db1a6780c..80af99c664 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -124,35 +124,35 @@ impl ObjectInterface for Socket { async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] - if let ListenEndpoint::Ip(endpoint) = endpoint { - self.local_endpoint.port = endpoint.port; - if let Some(addr) = endpoint.addr { - self.local_endpoint.addr = addr; - } - self.with(|socket| socket.bind(endpoint).map_err(|_| Errno::Addrinuse)) - } else { - Err(Errno::Io) + let ListenEndpoint::Ip(endpoint) = endpoint else { + return Err(Errno::Io); + }; + + self.local_endpoint.port = endpoint.port; + if let Some(addr) = endpoint.addr { + self.local_endpoint.addr = addr; } + self.with(|socket| socket.bind(endpoint).map_err(|_| Errno::Addrinuse)) } async fn connect(&mut self, endpoint: Endpoint) -> io::Result<()> { #[allow(irrefutable_let_patterns)] - if let Endpoint::Ip(endpoint) = endpoint { - self.remote_endpoint = Some(endpoint); - Ok(()) - } else { - Err(Errno::Io) - } + let Endpoint::Ip(endpoint) = endpoint else { + return Err(Errno::Io); + }; + + self.remote_endpoint = Some(endpoint); + Ok(()) } async fn sendto(&self, buf: &[u8], endpoint: Endpoint) -> io::Result { #[allow(irrefutable_let_patterns)] - if let Endpoint::Ip(endpoint) = endpoint { - let meta = UdpMetadata::from(endpoint); - self.write_with_meta(buf, &meta).await - } else { - Err(Errno::Io) - } + let Endpoint::Ip(endpoint) = endpoint else { + return Err(Errno::Io); + }; + + let meta = UdpMetadata::from(endpoint); + self.write_with_meta(buf, &meta).await } async fn recvfrom(&self, buffer: &mut [MaybeUninit]) -> io::Result<(usize, Endpoint)> { @@ -219,12 +219,10 @@ impl ObjectInterface for Socket { } async fn write(&self, buf: &[u8]) -> io::Result { - if let Some(endpoint) = self.remote_endpoint { - let meta = UdpMetadata::from(endpoint); - self.write_with_meta(buf, &meta).await - } else { - Err(Errno::Inval) - } + let endpoint = self.remote_endpoint.ok_or(Errno::Inval)?; + + let meta = UdpMetadata::from(endpoint); + self.write_with_meta(buf, &meta).await } async fn status_flags(&self) -> io::Result { From 0e7823bfb61a7ec14c2cb54f422a5310533f3f4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:49:04 +0100 Subject: [PATCH 5/7] refactor(executor): do more early returns --- src/executor/mod.rs | 8 +- src/executor/network.rs | 51 ++++++------ src/executor/vsock.rs | 177 ++++++++++++++++++++-------------------- 3 files changed, 119 insertions(+), 117 deletions(-) diff --git a/src/executor/mod.rs b/src/executor/mod.rs index ca250bfa28..1eee8fe743 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -58,9 +58,11 @@ impl WakerRegistration { /// Wake the registered waker, if any. #[allow(dead_code)] pub fn wake(&mut self) { - if let Some(w) = self.waker.take() { - w.wake(); - } + let Some(w) = self.waker.take() else { + return; + }; + + w.wake(); } } diff --git a/src/executor/network.rs b/src/executor/network.rs index 3fc25abcb1..5ace81971d 100644 --- a/src/executor/network.rs +++ b/src/executor/network.rs @@ -191,22 +191,21 @@ async fn dhcpv4_run() { async fn network_run() { future::poll_fn(|cx| { - if let Some(mut guard) = NIC.try_lock() { - match &mut *guard { - NetworkState::Initialized(nic) => { - nic.poll_common(now()); - // FIXME: only wake when progress can be made - cx.waker().wake_by_ref(); - Poll::Pending - } - _ => Poll::Ready(()), - } - } else { + let Some(mut guard) = NIC.try_lock() else { // FIXME: only wake when progress can be made cx.waker().wake_by_ref(); // another task is already using the NIC => don't check - Poll::Pending - } + return Poll::Pending; + }; + + let NetworkState::Initialized(nic) = &mut *guard else { + return Poll::Ready(()); + }; + + nic.poll_common(now()); + // FIXME: only wake when progress can be made + cx.waker().wake_by_ref(); + Poll::Pending }) .await; } @@ -254,18 +253,20 @@ pub(crate) fn init() { *guard = NetworkInterface::create(); - if let NetworkState::Initialized(nic) = &mut *guard { - let time = now(); - nic.poll_common(time); - let wakeup_time = nic - .poll_delay(time) - .map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros()); - crate::core_scheduler().add_network_timer(wakeup_time); - - spawn(network_run()); - #[cfg(feature = "dhcpv4")] - spawn(dhcpv4_run()); - } + let NetworkState::Initialized(nic) = &mut *guard else { + return; + }; + + let time = now(); + nic.poll_common(time); + let wakeup_time = nic + .poll_delay(time) + .map(|d| crate::arch::processor::get_timer_ticks() + d.total_micros()); + crate::core_scheduler().add_network_timer(wakeup_time); + + spawn(network_run()); + #[cfg(feature = "dhcpv4")] + spawn(dhcpv4_run()); } impl<'a> NetworkInterface<'a> { diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 8d8beea7f2..f992a478d5 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -62,98 +62,97 @@ impl RawSocket { async fn vsock_run() { future::poll_fn(|cx| { - if let Some(driver) = hardware::get_vsock_driver() { - const HEADER_SIZE: usize = mem::size_of::(); - let mut driver_guard = driver.lock(); - let mut hdr: Option = None; - let mut fwd_cnt: u32 = 0; - - driver_guard.process_packet(|header, data| { - let op = Op::try_from(header.op.to_ne()).unwrap(); - let port = header.dst_port.to_ne(); - let type_ = Type::try_from(header.type_.to_ne()).unwrap(); - let mut vsock_guard = VSOCK_MAP.lock(); - let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); - - if let Some(raw) = vsock_guard.get_mut_socket(port) { - if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream - { - raw.state = VsockState::ReceiveRequest; - raw.remote_cid = header_cid; - raw.remote_port = header.src_port.to_ne(); - raw.peer_buf_alloc = header.buf_alloc.to_ne(); - raw.rx_waker.wake(); - } else if (raw.state == VsockState::Connected - || raw.state == VsockState::Shutdown) - && type_ == Type::Stream - && op == Op::Rw - { - if raw.remote_cid == header_cid { - raw.buffer.extend_from_slice(data); - raw.fwd_cnt = - raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); - raw.rx_waker.wake(); - hdr = Some(*header); - fwd_cnt = raw.fwd_cnt; - } else { - trace!("Receive message from invalid source {header_cid}"); - } - } else if op == Op::CreditUpdate { - if raw.remote_cid == header_cid { - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); - } else { - trace!("Receive message from invalid source {header_cid}"); - } - } else if op == Op::Shutdown { - if raw.remote_cid == header_cid { - raw.state = VsockState::Shutdown; - } else { - trace!("Receive message from invalid source {header_cid}"); - } - } else if op == Op::Response && type_ == Type::Stream { - if raw.remote_cid == header_cid && raw.state == VsockState::Connecting { - raw.state = VsockState::Connected; - } - } else if raw.remote_cid == header_cid { - hdr = Some(*header); - fwd_cnt = raw.fwd_cnt; - } + let Some(driver) = hardware::get_vsock_driver() else { + return Poll::Ready(()); + }; + + const HEADER_SIZE: usize = mem::size_of::(); + let mut driver_guard = driver.lock(); + let mut hdr: Option = None; + let mut fwd_cnt: u32 = 0; + + driver_guard.process_packet(|header, data| { + let op = Op::try_from(header.op.to_ne()).unwrap(); + let port = header.dst_port.to_ne(); + let type_ = Type::try_from(header.type_.to_ne()).unwrap(); + let mut vsock_guard = VSOCK_MAP.lock(); + let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); + + let Some(raw) = vsock_guard.get_mut_socket(port) else { + return; + }; + + if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream { + raw.state = VsockState::ReceiveRequest; + raw.remote_cid = header_cid; + raw.remote_port = header.src_port.to_ne(); + raw.peer_buf_alloc = header.buf_alloc.to_ne(); + raw.rx_waker.wake(); + } else if (raw.state == VsockState::Connected || raw.state == VsockState::Shutdown) + && type_ == Type::Stream + && op == Op::Rw + { + if raw.remote_cid == header_cid { + raw.buffer.extend_from_slice(data); + raw.fwd_cnt = raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + raw.rx_waker.wake(); + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } else { + trace!("Receive message from invalid source {header_cid}"); } - }); - - if let Some(hdr) = hdr { - driver_guard.send_packet(HEADER_SIZE, |buffer| { - let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; - - response.src_cid = hdr.dst_cid; - response.dst_cid = hdr.src_cid; - response.src_port = hdr.dst_port; - response.dst_port = hdr.src_port; - response.len = le32::from_ne(0); - response.type_ = hdr.type_; - if hdr.op.to_ne() == u16::from(Op::CreditRequest) - || hdr.op.to_ne() == u16::from(Op::Rw) - { - response.op = le16::from_ne(Op::CreditUpdate.into()); - } else { - // reset connection - response.op = le16::from_ne(Op::Rst.into()); - } - response.flags = le32::from_ne(0); - response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); - response.fwd_cnt = le32::from_ne(fwd_cnt); - }); + } else if op == Op::CreditUpdate { + if raw.remote_cid == header_cid { + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + } else { + trace!("Receive message from invalid source {header_cid}"); + } + } else if op == Op::Shutdown { + if raw.remote_cid == header_cid { + raw.state = VsockState::Shutdown; + } else { + trace!("Receive message from invalid source {header_cid}"); + } + } else if op == Op::Response && type_ == Type::Stream { + if raw.remote_cid == header_cid && raw.state == VsockState::Connecting { + raw.state = VsockState::Connected; + } + } else if raw.remote_cid == header_cid { + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; } - - // FIXME: only wake when progress can be made - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(()) + }); + + if let Some(hdr) = hdr { + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; + + response.src_cid = hdr.dst_cid; + response.dst_cid = hdr.src_cid; + response.src_port = hdr.dst_port; + response.dst_port = hdr.src_port; + response.len = le32::from_ne(0); + response.type_ = hdr.type_; + if hdr.op.to_ne() == u16::from(Op::CreditRequest) + || hdr.op.to_ne() == u16::from(Op::Rw) + { + response.op = le16::from_ne(Op::CreditUpdate.into()); + } else { + // reset connection + response.op = le16::from_ne(Op::Rst.into()); + } + response.flags = le32::from_ne(0); + response.buf_alloc = le32::from_ne(RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(fwd_cnt); + }); } + + // FIXME: only wake when progress can be made + cx.waker().wake_by_ref(); + Poll::Pending }) .await; } From 13cb8099c4967407dae1f6ebc02335e2f760c1ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:49:26 +0100 Subject: [PATCH 6/7] refactor(syscalls): do more early returns --- src/syscalls/mod.rs | 44 ++++----- src/syscalls/socket/mod.rs | 197 +++++++++++++++++++------------------ src/syscalls/tasks.rs | 8 +- 3 files changed, 126 insertions(+), 123 deletions(-) diff --git a/src/syscalls/mod.rs b/src/syscalls/mod.rs index 82f162cb15..cf3b73b0e0 100644 --- a/src/syscalls/mod.rs +++ b/src/syscalls/mod.rs @@ -323,11 +323,11 @@ pub unsafe extern "C" fn sys_fstat(fd: RawFd, stat: *mut FileAttr) -> i32 { #[hermit_macro::system(errno)] #[unsafe(no_mangle)] pub unsafe extern "C" fn sys_opendir(name: *const c_char) -> RawFd { - if let Ok(name) = unsafe { CStr::from_ptr(name) }.to_str() { - crate::fs::opendir(name).unwrap_or_else(|e| -i32::from(e)) - } else { - -i32::from(Errno::Inval) - } + let Ok(name) = unsafe { CStr::from_ptr(name) }.to_str() else { + return -i32::from(Errno::Inval); + }; + + crate::fs::opendir(name).unwrap_or_else(|e| -i32::from(e)) } #[hermit_macro::system(errno)] @@ -340,11 +340,11 @@ pub unsafe extern "C" fn sys_open(name: *const c_char, flags: i32, mode: u32) -> return -i32::from(Errno::Inval); }; - if let Ok(name) = unsafe { CStr::from_ptr(name) }.to_str() { - crate::fs::open(name, flags, mode).unwrap_or_else(|e| -i32::from(e)) - } else { - -i32::from(Errno::Inval) - } + let Ok(name) = unsafe { CStr::from_ptr(name) }.to_str() else { + return -i32::from(Errno::Inval); + }; + + crate::fs::open(name, flags, mode).unwrap_or_else(|e| -i32::from(e)) } #[hermit_macro::system] @@ -395,13 +395,13 @@ pub extern "C" fn sys_fchdir(_fd: RawFd) -> i32 { #[hermit_macro::system(errno)] #[unsafe(no_mangle)] pub unsafe extern "C" fn sys_chdir(path: *mut c_char) -> i32 { - if let Ok(name) = unsafe { CStr::from_ptr(path) }.to_str() { - crate::fs::set_cwd(name) - .map(|()| 0) - .unwrap_or_else(|e| -i32::from(e)) - } else { - -i32::from(Errno::Inval) - } + let Ok(name) = unsafe { CStr::from_ptr(path) }.to_str() else { + return -i32::from(Errno::Inval); + }; + + crate::fs::set_cwd(name) + .map(|()| 0) + .unwrap_or_else(|e| -i32::from(e)) } #[hermit_macro::system] @@ -851,11 +851,11 @@ pub unsafe extern "C" fn sys_poll(fds: *mut PollFd, nfds: usize, timeout: i32) - #[hermit_macro::system(errno)] #[unsafe(no_mangle)] pub extern "C" fn sys_eventfd(initval: u64, flags: i16) -> i32 { - if let Some(flags) = EventFlags::from_bits(flags) { - crate::fd::eventfd(initval, flags).unwrap_or_else(|e| -i32::from(e)) - } else { - -i32::from(Errno::Inval) - } + let Some(flags) = EventFlags::from_bits(flags) else { + return -i32::from(Errno::Inval); + }; + + crate::fd::eventfd(initval, flags).unwrap_or_else(|e| -i32::from(e)) } #[hermit_macro::system] diff --git a/src/syscalls/socket/mod.rs b/src/syscalls/socket/mod.rs index 34cf16791c..7fecea9dbe 100644 --- a/src/syscalls/socket/mod.rs +++ b/src/syscalls/socket/mod.rs @@ -548,11 +548,10 @@ pub unsafe extern "C" fn sys_getaddrbyname( }; let name = unsafe { core::ffi::CStr::from_ptr(name) }; - let name = if let Ok(name) = name.to_str() { - name.to_owned() - } else { + let Ok(name) = name.to_str() else { return -i32::from(Errno::Inval); }; + let name = name.to_owned(); let query = { let mut guard = NIC.lock(); @@ -865,53 +864,53 @@ pub unsafe extern "C" fn sys_getsockname( obj.map_or_else( |e| -i32::from(e), |v| { - if let Ok(Some(endpoint)) = block_on(async { v.read().await.getsockname().await }, None) - { - if !addr.is_null() && !addrlen.is_null() { - let addrlen = unsafe { &mut *addrlen }; - - match endpoint { - #[cfg(feature = "net")] - Endpoint::Ip(endpoint) => match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - let addr = unsafe { &mut *addr.cast() }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - - 0 - } else { - -i32::from(Errno::Inval) - } - } - #[cfg(feature = "net")] - IpAddress::Ipv6(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - let addr = unsafe { &mut *addr.cast() }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - - 0 - } else { - -i32::from(Errno::Inval) - } - } - }, - #[cfg(feature = "virtio-vsock")] - Endpoint::Vsock(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - warn!("unsupported device"); - 0 - } else { - -i32::from(Errno::Inval) - } + let Ok(Some(endpoint)) = block_on(async { v.read().await.getsockname().await }, None) + else { + return -i32::from(Errno::Inval); + }; + + if addr.is_null() || addrlen.is_null() { + return -i32::from(Errno::Inval); + } + + let addrlen = unsafe { &mut *addrlen }; + + match endpoint { + #[cfg(feature = "net")] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + let addr = unsafe { &mut *addr.cast() }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + + 0 + } else { + -i32::from(Errno::Inval) } } - } else { - -i32::from(Errno::Inval) + #[cfg(feature = "net")] + IpAddress::Ipv6(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + let addr = unsafe { &mut *addr.cast() }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + + 0 + } else { + -i32::from(Errno::Inval) + } + } + }, + #[cfg(feature = "virtio-vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + warn!("unsupported device"); + 0 + } else { + -i32::from(Errno::Inval) + } } - } else { - -i32::from(Errno::Inval) } }, ) @@ -1027,44 +1026,46 @@ pub unsafe extern "C" fn sys_getpeername( obj.map_or_else( |e| -i32::from(e), |v| { - if let Ok(Some(endpoint)) = block_on(async { v.read().await.getpeername().await }, None) - { - if !addr.is_null() && !addrlen.is_null() { - let addrlen = unsafe { &mut *addrlen }; - - match endpoint { - #[cfg(feature = "net")] - Endpoint::Ip(endpoint) => match endpoint.addr { - IpAddress::Ipv4(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - let addr = unsafe { &mut *addr.cast() }; - *addr = sockaddr_in::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -i32::from(Errno::Inval); - } - } - IpAddress::Ipv6(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - let addr = unsafe { &mut *addr.cast() }; - *addr = sockaddr_in6::from(endpoint); - *addrlen = size_of::().try_into().unwrap(); - } else { - return -i32::from(Errno::Inval); - } - } - }, - #[cfg(feature = "virtio-vsock")] - Endpoint::Vsock(_) => { - if *addrlen >= u32::try_from(size_of::()).unwrap() { - warn!("unsupported device"); - } else { - return -i32::from(Errno::Inval); - } + let Ok(Some(endpoint)) = block_on(async { v.read().await.getpeername().await }, None) + else { + return 0; + }; + + if addr.is_null() || addrlen.is_null() { + return -i32::from(Errno::Inval); + } + + let addrlen = unsafe { &mut *addrlen }; + + match endpoint { + #[cfg(feature = "net")] + Endpoint::Ip(endpoint) => match endpoint.addr { + IpAddress::Ipv4(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + let addr = unsafe { &mut *addr.cast() }; + *addr = sockaddr_in::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -i32::from(Errno::Inval); } } - } else { - return -i32::from(Errno::Inval); + IpAddress::Ipv6(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + let addr = unsafe { &mut *addr.cast() }; + *addr = sockaddr_in6::from(endpoint); + *addrlen = size_of::().try_into().unwrap(); + } else { + return -i32::from(Errno::Inval); + } + } + }, + #[cfg(feature = "virtio-vsock")] + Endpoint::Vsock(_) => { + if *addrlen >= u32::try_from(size_of::()).unwrap() { + warn!("unsupported device"); + } else { + return -i32::from(Errno::Inval); + } } } @@ -1158,22 +1159,22 @@ pub unsafe extern "C" fn sys_sendto( } } - if let Some(endpoint) = endpoint { - let slice = unsafe { slice::from_raw_parts(buf, len) }; - let obj = get_object(fd); + let Some(endpoint) = endpoint else { + return (-i32::from(Errno::Inval)).try_into().unwrap(); + }; - obj.map_or_else( - |e| isize::try_from(-i32::from(e)).unwrap(), - |v| { - block_on(async { v.read().await.sendto(slice, endpoint).await }, None).map_or_else( - |e| isize::try_from(-i32::from(e)).unwrap(), - |v| v.try_into().unwrap(), - ) - }, - ) - } else { - (-i32::from(Errno::Inval)).try_into().unwrap() - } + let slice = unsafe { slice::from_raw_parts(buf, len) }; + let obj = get_object(fd); + + obj.map_or_else( + |e| isize::try_from(-i32::from(e)).unwrap(), + |v| { + block_on(async { v.read().await.sendto(slice, endpoint).await }, None).map_or_else( + |e| isize::try_from(-i32::from(e)).unwrap(), + |v| v.try_into().unwrap(), + ) + }, + ) } #[hermit_macro::system(errno)] diff --git a/src/syscalls/tasks.rs b/src/syscalls/tasks.rs index 6e6b0f1dff..58b65d0885 100644 --- a/src/syscalls/tasks.rs +++ b/src/syscalls/tasks.rs @@ -231,9 +231,11 @@ pub extern "C" fn sys_block_current_task_with_timeout(timeout: u64) { pub extern "C" fn sys_wakeup_task(id: Tid) { let task_id = TaskId::from(id); - if let Some(handle) = BLOCKED_TASKS.lock().remove(&task_id) { - core_scheduler().custom_wakeup(handle); - } + let Some(handle) = BLOCKED_TASKS.lock().remove(&task_id) else { + return; + }; + + core_scheduler().custom_wakeup(handle); } /// Determine the priority of the current thread From 2226f6bc3c53833d93b0f6c9930a5b7bd1e69dd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Kr=C3=B6ning?= Date: Wed, 11 Feb 2026 15:49:40 +0100 Subject: [PATCH 7/7] refactor: do more early returns --- src/entropy.rs | 14 ++-- src/mm/physicalmem.rs | 18 ++--- src/scheduler/mod.rs | 147 +++++++++++++++++++------------------- src/scheduler/task/mod.rs | 84 ++++++++++------------ src/synch/recmutex.rs | 38 +++++----- src/synch/semaphore.rs | 16 +++-- 6 files changed, 155 insertions(+), 162 deletions(-) diff --git a/src/entropy.rs b/src/entropy.rs index 0ac9469854..f286b40b4f 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -35,14 +35,14 @@ pub fn read(buf: &mut [u8], _flags: Flags) -> isize { let pool = match pool { Some(pool) if now.saturating_sub(pool.last_reseed) <= RESEED_INTERVAL => pool, pool => { - if let Some(seed) = seed_entropy() { - pool.insert(Pool { - rng: ChaCha20Rng::from_seed(seed), - last_reseed: now, - }) - } else { + let Some(seed) = seed_entropy() else { return -i32::from(Errno::Nosys) as isize; - } + }; + + pool.insert(Pool { + rng: ChaCha20Rng::from_seed(seed), + last_reseed: now, + }) } }; diff --git a/src/mm/physicalmem.rs b/src/mm/physicalmem.rs index 05997a4649..e8a1372b29 100644 --- a/src/mm/physicalmem.rs +++ b/src/mm/physicalmem.rs @@ -230,14 +230,16 @@ unsafe fn init() { paging::unmap::(start, count); } - if let Err(_err) = unsafe { detect_from_fdt() } { - cfg_if::cfg_if! { - if #[cfg(any(target_arch = "aarch64", target_arch = "riscv64"))] { - error!("Could not detect physical memory from FDT"); - unsafe { detect_from_limits().unwrap(); } - } else { - panic!("Could not detect physical memory from FDT"); - } + if unsafe { detect_from_fdt().is_ok() } { + return; + } + + cfg_if::cfg_if! { + if #[cfg(any(target_arch = "aarch64", target_arch = "riscv64"))] { + error!("Could not detect physical memory from FDT"); + unsafe { detect_from_limits().unwrap(); } + } else { + panic!("Could not detect physical memory from FDT"); } } } diff --git a/src/scheduler/mod.rs b/src/scheduler/mod.rs index 7597d33ed2..0af4d2e6ba 100644 --- a/src/scheduler/mod.rs +++ b/src/scheduler/mod.rs @@ -108,26 +108,25 @@ impl PerCoreSchedulerExt for &mut PerCoreScheduler { #[cfg(target_arch = "x86_64")] fn reschedule(self) { without_interrupts(|| { - if let Some(last_stack_pointer) = self.scheduler() { - let (new_stack_pointer, is_idle) = { - let borrowed = self.current_task.borrow(); - ( - borrowed.last_stack_pointer, - borrowed.status == TaskStatus::Idle, - ) - }; - - if is_idle || Rc::ptr_eq(&self.current_task, &self.fpu_owner) { - unsafe { - switch_to_fpu_owner( - last_stack_pointer, - new_stack_pointer.as_u64() as usize, - ); - } - } else { - unsafe { - switch_to_task(last_stack_pointer, new_stack_pointer.as_u64() as usize); - } + let Some(last_stack_pointer) = self.scheduler() else { + return; + }; + + let (new_stack_pointer, is_idle) = { + let borrowed = self.current_task.borrow(); + ( + borrowed.last_stack_pointer, + borrowed.status == TaskStatus::Idle, + ) + }; + + if is_idle || Rc::ptr_eq(&self.current_task, &self.fpu_owner) { + unsafe { + switch_to_fpu_owner(last_stack_pointer, new_stack_pointer.as_u64() as usize); + } + } else { + unsafe { + switch_to_task(last_stack_pointer, new_stack_pointer.as_u64() as usize); } } }); @@ -795,60 +794,60 @@ impl PerCoreScheduler { } } - if let Some(task) = new_task { - // There is a new task we want to switch to. + let task = new_task?; + // There is a new task we want to switch to. + + // Handle the current task. + if status == TaskStatus::Running { + // Mark the running task as ready again and add it back to the queue. + self.current_task.borrow_mut().status = TaskStatus::Ready; + self.ready_queue.push(self.current_task.clone()); + } - // Handle the current task. - if status == TaskStatus::Running { - // Mark the running task as ready again and add it back to the queue. - self.current_task.borrow_mut().status = TaskStatus::Ready; - self.ready_queue.push(self.current_task.clone()); + // Handle the new task and get information about it. + let (new_id, new_stack_pointer) = { + let mut borrowed = task.borrow_mut(); + if borrowed.status != TaskStatus::Idle { + // Mark the new task as running. + borrowed.status = TaskStatus::Running; } - // Handle the new task and get information about it. - let (new_id, new_stack_pointer) = { - let mut borrowed = task.borrow_mut(); - if borrowed.status != TaskStatus::Idle { - // Mark the new task as running. - borrowed.status = TaskStatus::Running; - } + (borrowed.id, borrowed.last_stack_pointer) + }; - (borrowed.id, borrowed.last_stack_pointer) - }; + if id == new_id { + return None; + } - if id != new_id { - // Tell the scheduler about the new task. - debug!( - "Switching task from {} to {} (stack {:#X} => {:p})", - id, - new_id, - unsafe { *last_stack_pointer }, - new_stack_pointer - ); - #[cfg(not(target_arch = "riscv64"))] - { - self.current_task = task; - } + // Tell the scheduler about the new task. + debug!( + "Switching task from {} to {} (stack {:#X} => {:p})", + id, + new_id, + unsafe { *last_stack_pointer }, + new_stack_pointer + ); + #[cfg(not(target_arch = "riscv64"))] + { + self.current_task = task; + } - // Finally return the context of the new task. - #[cfg(not(target_arch = "riscv64"))] - return Some(last_stack_pointer); + // Finally return the context of the new task. + #[cfg(not(target_arch = "riscv64"))] + return Some(last_stack_pointer); - #[cfg(target_arch = "riscv64")] - { - if sstatus::read().fs() == sstatus::FS::Dirty { - self.current_task.borrow_mut().last_fpu_state.save(); - } - task.borrow().last_fpu_state.restore(); - self.current_task = task; - unsafe { - switch_to_task(last_stack_pointer, new_stack_pointer.as_usize()); - } - } + #[cfg(target_arch = "riscv64")] + { + if sstatus::read().fs() == sstatus::FS::Dirty { + self.current_task.borrow_mut().last_fpu_state.save(); + } + task.borrow().last_fpu_state.restore(); + self.current_task = task; + unsafe { + switch_to_task(last_stack_pointer, new_stack_pointer.as_usize()); } + None } - - None } } @@ -960,16 +959,16 @@ pub fn join(id: TaskId) -> Result<(), ()> { loop { let mut waiting_tasks_guard = WAITING_TASKS.lock(); - if let Some(queue) = waiting_tasks_guard.get_mut(&id) { - queue.push_back(core_scheduler.get_current_task_handle()); - core_scheduler.block_current_task(None); - - // Switch to the next task. - drop(waiting_tasks_guard); - core_scheduler.reschedule(); - } else { + let Some(queue) = waiting_tasks_guard.get_mut(&id) else { return Ok(()); - } + }; + + queue.push_back(core_scheduler.get_current_task_handle()); + core_scheduler.block_current_task(None); + + // Switch to the next task. + drop(waiting_tasks_guard); + core_scheduler.reschedule(); } } diff --git a/src/scheduler/task/mod.rs b/src/scheduler/task/mod.rs index 5b71f00ddf..bcb890c215 100644 --- a/src/scheduler/task/mod.rs +++ b/src/scheduler/task/mod.rs @@ -196,26 +196,22 @@ impl TaskHandlePriorityQueue { } fn pop_from_queue(&mut self, queue_index: usize) -> Option { - if let Some(queue) = &mut self.queues[queue_index] { - let task = queue.pop_front(); + let queue = self.queues[queue_index].as_mut()?; - if queue.is_empty() { - *self.prio_bitmap &= !(1 << queue_index as u64); - } + let task = queue.pop_front(); - task - } else { - None + if queue.is_empty() { + *self.prio_bitmap &= !(1 << queue_index as u64); } + + task } /// Pop the task handle with the highest priority from the queue pub fn pop(&mut self) -> Option { - if let Some(i) = msb(self.prio_bitmap.into_inner()) { - return self.pop_from_queue(i as usize); - } + let i = msb(self.prio_bitmap.into_inner())?; - None + self.pop_from_queue(i as usize) } /// Remove a specific task handle from the priority queue. Returns `true` if @@ -290,18 +286,18 @@ impl PriorityTaskQueue { //assert!(prio < NO_PRIORITIES, "Priority {} is too high", prio); let queue = &mut self.queues[queue_index]; - if task_index <= queue.len() { - // Calling remove is unstable: https://github.com/rust-lang/rust/issues/69210 - let mut split_list = queue.split_off(task_index); - let element = split_list.pop_front(); - queue.append(&mut split_list); - if queue.is_empty() { - self.prio_bitmap &= !(1 << queue_index as u64); - } - element - } else { - None + if queue.len() < task_index { + return None; } + + // Calling remove is unstable: https://github.com/rust-lang/rust/issues/69210 + let mut split_list = queue.split_off(task_index); + let element = split_list.pop_front(); + queue.append(&mut split_list); + if queue.is_empty() { + self.prio_bitmap &= !(1 << queue_index as u64); + } + element } /// Returns true if the queue is empty. @@ -318,50 +314,44 @@ impl PriorityTaskQueue { /// Pop the task with the highest priority from the queue pub fn pop(&mut self) -> Option>> { - if let Some(i) = msb(self.prio_bitmap) { - return self.pop_from_queue(i as usize); - } + let i = msb(self.prio_bitmap)?; - None + self.pop_from_queue(i as usize) } /// Pop the next task, which has a higher or the same priority as `prio` pub fn pop_with_prio(&mut self, prio: Priority) -> Option>> { - if let Some(i) = msb(self.prio_bitmap) - && i >= u32::from(prio.into()) - { - return self.pop_from_queue(i as usize); + let i = msb(self.prio_bitmap)?; + + if i < u32::from(prio.into()) { + return None; } - None + self.pop_from_queue(i as usize) } /// Returns the highest priority of all available task #[cfg(all(any(target_arch = "x86_64", target_arch = "riscv64"), feature = "smp"))] pub fn get_highest_priority(&self) -> Priority { - if let Some(i) = msb(self.prio_bitmap) { - Priority::from(i.try_into().unwrap()) - } else { - IDLE_PRIO - } + let Some(i) = msb(self.prio_bitmap) else { + return IDLE_PRIO; + }; + + Priority::from(i.try_into().unwrap()) } /// Change priority of specific task pub fn set_priority(&mut self, handle: TaskHandle, prio: Priority) -> Result<(), ()> { let old_priority = handle.get_priority().into() as usize; - if let Some(index) = self.queues[old_priority] + let index = self.queues[old_priority] .iter() .position(|current_task| current_task.borrow().id == handle.id) - { - let Some(task) = self.remove_from_queue(index, old_priority) else { - return Err(()); - }; - task.borrow_mut().prio = prio; - self.push(task); - return Ok(()); - } + .ok_or(())?; - Err(()) + let task = self.remove_from_queue(index, old_priority).ok_or(())?; + task.borrow_mut().prio = prio; + self.push(task); + Ok(()) } } diff --git a/src/synch/recmutex.rs b/src/synch/recmutex.rs index 07688594cd..434e6eab57 100644 --- a/src/synch/recmutex.rs +++ b/src/synch/recmutex.rs @@ -63,29 +63,29 @@ impl RecursiveMutex { } pub fn release(&self) { - if let Some(task) = { - let mut locked_state = self.state.lock(); + let mut locked_state = self.state.lock(); - // We could do a sanity check here whether the RecursiveMutex is actually held by the current task. - // But let's just trust our code using this function for the sake of simplicity and performance. + // We could do a sanity check here whether the RecursiveMutex is actually held by the current task. + // But let's just trust our code using this function for the sake of simplicity and performance. - // Decrement the counter (recursive mutex behavior). - if locked_state.count > 0 { - locked_state.count -= 1; - } - - if locked_state.count == 0 { - // Release the entire recursive mutex. - locked_state.current_tid = None; + // Decrement the counter (recursive mutex behavior). + if locked_state.count > 0 { + locked_state.count -= 1; + } - locked_state.queue.pop() - } else { - None - } - } { - // Wake up any task that has been waiting for this mutex. - core_scheduler().custom_wakeup(task); + if locked_state.count > 0 { + return; } + + // Release the entire recursive mutex. + locked_state.current_tid = None; + + let Some(task) = locked_state.queue.pop() else { + return; + }; + + // Wake up any task that has been waiting for this mutex. + core_scheduler().custom_wakeup(task); } } diff --git a/src/synch/semaphore.rs b/src/synch/semaphore.rs index c32ea673b3..69fcab4b2c 100644 --- a/src/synch/semaphore.rs +++ b/src/synch/semaphore.rs @@ -138,13 +138,15 @@ impl Semaphore { /// This will increment the number of resources in this semaphore by 1 and /// will notify any pending waiters in `acquire` or `access` if necessary. pub fn release(&self) { - if let Some(task) = { - let mut locked_state = self.state.lock(); - locked_state.count += 1; - locked_state.queue.pop() - } { - // Wake up any task that has been waiting for this semaphore. - core_scheduler().custom_wakeup(task); + let mut locked_state = self.state.lock(); + locked_state.count += 1; + let task = locked_state.queue.pop(); + + let Some(task) = task else { + return; }; + + // Wake up any task that has been waiting for this semaphore. + core_scheduler().custom_wakeup(task); } }