From dfe3741b6d1c5117e6a4b9d5d0e7c2e09c084ffa Mon Sep 17 00:00:00 2001 From: Crauzer <0xcrauzer@proton.me> Date: Sat, 9 May 2026 13:25:07 +0200 Subject: [PATCH] refactor(ltk_shader): improve public shader loader API --- crates/ltk_shader/src/error.rs | 37 ++++++ crates/ltk_shader/src/lib.rs | 3 + crates/ltk_shader/src/loader.rs | 210 ++++++++++++++++++++++++-------- crates/ltk_shader/src/toc.rs | 43 +++---- 4 files changed, 220 insertions(+), 73 deletions(-) create mode 100644 crates/ltk_shader/src/error.rs diff --git a/crates/ltk_shader/src/error.rs b/crates/ltk_shader/src/error.rs new file mode 100644 index 00000000..b1a99889 --- /dev/null +++ b/crates/ltk_shader/src/error.rs @@ -0,0 +1,37 @@ +use std::io; + +use ltk_wad::WadError; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ShaderError { + #[error("io error")] + Io(#[from] io::Error), + + #[error("wad error")] + Wad(#[from] WadError), + + #[error("shader object not found: {path}")] + ShaderObjectNotFound { path: String }, + + #[error("shader bundle not found: {path}")] + ShaderBundleNotFound { path: String }, + + #[error("shader for defines `{defines}` not found in TOC")] + DefinesNotFound { defines: String }, + + #[error("invalid TOC magic: expected {expected:?}, got {actual:?}")] + InvalidTocMagic { expected: String, actual: String }, + + #[error("invalid section magic: expected {expected:?}, got {actual:?}")] + InvalidSectionMagic { expected: String, actual: String }, + + #[error("TOC vector length mismatch: expected {expected}, got {hashes} hashes and {ids} ids")] + TocLengthMismatch { + expected: usize, + hashes: usize, + ids: usize, + }, +} + +pub type Result = std::result::Result; diff --git a/crates/ltk_shader/src/lib.rs b/crates/ltk_shader/src/lib.rs index 3c9a689a..38864dca 100644 --- a/crates/ltk_shader/src/lib.rs +++ b/crates/ltk_shader/src/lib.rs @@ -1,7 +1,10 @@ pub mod defines; +pub mod error; pub mod loader; pub mod toc; +pub use error::{Result, ShaderError}; + use byteorder::{ReadBytesExt, LE}; use std::io::{self, Read}; diff --git a/crates/ltk_shader/src/loader.rs b/crates/ltk_shader/src/loader.rs index 77ad3c44..971e3885 100644 --- a/crates/ltk_shader/src/loader.rs +++ b/crates/ltk_shader/src/loader.rs @@ -3,6 +3,7 @@ use std::io::{Cursor, Read, Seek, SeekFrom}; use xxhash_rust::xxh64::xxh64; use crate::defines::ShaderMacroDefinition; +use crate::error::{Result, ShaderError}; use crate::toc::ShaderToc; use crate::{create_shader_bundle_path, create_shader_object_path, GraphicsPlatform, ShaderType}; use ltk_wad::Wad; @@ -12,15 +13,13 @@ const SHADERS_PER_BUNDLE: u32 = 100; pub struct ShaderLoader; impl ShaderLoader { - /// Loads the bytecode for a shader object from a WAD file. + /// Loads the bytecode for a shader object from a WAD file by matching defines. /// # Arguments /// * `shader_object_path` - The path of the shader object to load. /// * `shader_type` - The type of the shader to load. /// * `platform` - The platform of the shader to load. /// * `defines` - The defines to use for the shader. /// * `wad` - The WAD file to load the shader from. - /// # Returns - /// A vector of bytes containing the bytecode for the shader object. /// # Errors /// Returns an error if the shader object is not found or if the shader object data cannot be read. pub fn load_bytecode( @@ -29,73 +28,100 @@ impl ShaderLoader { platform: GraphicsPlatform, defines: &[ShaderMacroDefinition], wad: &mut Wad, - ) -> Result, Box> { + ) -> Result> { + let toc = Self::load_toc(shader_object_path, shader_type, platform, wad)?; + + let filtered_defines_formatted = Self::filter_defines(defines, &toc.base_defines); + let filtered_defines_hash = xxh64(filtered_defines_formatted.as_bytes(), 0); + + let shader_index = toc + .shader_hashes + .iter() + .position(|&h| h == filtered_defines_hash) + .ok_or(ShaderError::DefinesNotFound { + defines: filtered_defines_formatted, + })?; + + let shader_id = toc.shader_ids[shader_index]; + Self::load_bytecode_by_id(shader_object_path, shader_type, platform, shader_id, wad) + } + + /// Loads the TOC for a shader object from a WAD file. + pub fn load_toc( + shader_object_path: &str, + shader_type: ShaderType, + platform: GraphicsPlatform, + wad: &mut Wad, + ) -> Result { let full_shader_object_path = create_shader_object_path(shader_object_path, shader_type, platform); let path_hash = xxh64(full_shader_object_path.as_bytes(), 0); - let chunk = *wad - .chunks() - .get(path_hash) - .ok_or_else(|| format!("Shader object not found: {}", full_shader_object_path))?; + let chunk = + *wad.chunks() + .get(path_hash) + .ok_or_else(|| ShaderError::ShaderObjectNotFound { + path: full_shader_object_path.clone(), + })?; let shader_object_data = wad.load_chunk_decompressed(&chunk)?; let mut shader_object_reader = Cursor::new(shader_object_data); - let shader_toc = ShaderToc::read(&mut shader_object_reader)?; - - let filtered_defines_formatted = Self::filter_defines(defines, &shader_toc.base_defines); - let filtered_defines_hash = xxh64(filtered_defines_formatted.as_bytes(), 0); + ShaderToc::read(&mut shader_object_reader) + } - let shader_index = shader_toc - .shader_hashes - .iter() - .position(|&h| h == filtered_defines_hash); - - let shader_index = match shader_index { - Some(idx) => idx, - None => { - return Err(format!( - "Shader not found for defines: {}", - filtered_defines_formatted - ) - .into()) - } - }; + /// Loads the bytecode for a single shader by its `shader_id` (as recorded in the TOC). + pub fn load_bytecode_by_id( + shader_object_path: &str, + shader_type: ShaderType, + platform: GraphicsPlatform, + shader_id: u32, + wad: &mut Wad, + ) -> Result> { + let full_shader_object_path = + create_shader_object_path(shader_object_path, shader_type, platform); - let shader_id = shader_toc.shader_ids[shader_index]; let shader_bundle_id = SHADERS_PER_BUNDLE * (shader_id / SHADERS_PER_BUNDLE); let shader_index_in_bundle = shader_id % SHADERS_PER_BUNDLE; - let shader_bundle_path = - create_shader_bundle_path(&full_shader_object_path, shader_bundle_id); - let bundle_path_hash = xxh64(shader_bundle_path.as_bytes(), 0); - let bundle_chunk = *wad - .chunks() - .get(bundle_path_hash) - .ok_or_else(|| format!("Shader bundle not found: {}", shader_bundle_path))?; + let bundle_data = Self::load_bundle_data(&full_shader_object_path, shader_bundle_id, wad)?; + parse_bundle_entry_at(&bundle_data, shader_index_in_bundle) + } - let shader_bundle_data = wad.load_chunk_decompressed(&bundle_chunk)?; - let mut shader_bundle_reader = Cursor::new(shader_bundle_data); + /// Read a single bundle file and yield each entry's bytecode in order. + /// Hot path for "dump every shader for one (object, type, platform)": + /// reads the bundle chunk once instead of seeking N times. + /// + /// `full_shader_object_path` is the already-formatted path (e.g. from + /// [`create_shader_object_path`]); `shader_bundle_id` is a multiple of 100. + pub fn read_bundle( + full_shader_object_path: &str, + shader_bundle_id: u32, + wad: &mut Wad, + ) -> Result>> { + let bundle_data = Self::load_bundle_data(full_shader_object_path, shader_bundle_id, wad)?; + parse_bundle_entries(&bundle_data) + } - for _ in 0..shader_index_in_bundle { - let shader_size = shader_bundle_reader.read_u32::()?; - shader_bundle_reader.seek(SeekFrom::Current(shader_size as i64))?; - } + fn load_bundle_data( + full_shader_object_path: &str, + shader_bundle_id: u32, + wad: &mut Wad, + ) -> Result> { + let shader_bundle_path = + create_shader_bundle_path(full_shader_object_path, shader_bundle_id); - let requested_shader_size = shader_bundle_reader.read_u32::()? as usize; - let mut bytecode = Vec::with_capacity(requested_shader_size); - shader_bundle_reader.read_exact(&mut bytecode)?; + let bundle_path_hash = xxh64(shader_bundle_path.as_bytes(), 0); + let bundle_chunk = *wad.chunks().get(bundle_path_hash).ok_or_else(|| { + ShaderError::ShaderBundleNotFound { + path: shader_bundle_path.clone(), + } + })?; - Ok(bytecode) + Ok(wad.load_chunk_decompressed(&bundle_chunk)?) } /// Filters the defines to only include the defines that are in the base defines. - /// # Arguments - /// * `defines` - The defines to filter. - /// * `base_defines` - The base defines to filter the defines against. - /// # Returns - /// A string containing the filtered defines. fn filter_defines( defines: &[ShaderMacroDefinition], base_defines: &[ShaderMacroDefinition], @@ -112,3 +138,87 @@ impl ShaderLoader { filtered.iter().map(|d| d.to_string()).collect() } } + +/// Parse a shader bundle blob into its constituent bytecode entries. +/// +/// A bundle is a sequence of `(u32 length, [u8; length])` entries packed back-to-back. +fn parse_bundle_entries(data: &[u8]) -> Result>> { + let total = data.len() as u64; + let mut reader = Cursor::new(data); + let mut entries = Vec::new(); + while reader.position() < total { + let size = reader.read_u32::()? as usize; + let mut bytecode = vec![0u8; size]; + reader.read_exact(&mut bytecode)?; + entries.push(bytecode); + } + Ok(entries) +} + +/// Parse a shader bundle blob and return only the entry at `index_in_bundle`. +fn parse_bundle_entry_at(data: &[u8], index_in_bundle: u32) -> Result> { + let mut reader = Cursor::new(data); + for _ in 0..index_in_bundle { + let size = reader.read_u32::()?; + reader.seek(SeekFrom::Current(i64::from(size)))?; + } + let size = reader.read_u32::()? as usize; + let mut bytecode = vec![0u8; size]; + reader.read_exact(&mut bytecode)?; + Ok(bytecode) +} + +#[cfg(test)] +mod tests { + use super::*; + use byteorder::WriteBytesExt; + + fn build_bundle(entries: &[&[u8]]) -> Vec { + let mut out = Vec::new(); + for entry in entries { + out.write_u32::(entry.len() as u32).unwrap(); + out.extend_from_slice(entry); + } + out + } + + #[test] + fn parse_bundle_entries_reads_all_in_order() { + let bundle = build_bundle(&[b"\x01\x02\x03", b"hello", b"DXBCmagicpayload"]); + let entries = parse_bundle_entries(&bundle).expect("parse should succeed"); + assert_eq!(entries.len(), 3); + assert_eq!(entries[0], b"\x01\x02\x03"); + assert_eq!(entries[1], b"hello"); + assert_eq!(entries[2], b"DXBCmagicpayload"); + } + + #[test] + fn parse_bundle_entry_at_returns_correct_entry() { + let bundle = build_bundle(&[b"first", b"second-entry", b"third"]); + assert_eq!(parse_bundle_entry_at(&bundle, 0).unwrap(), b"first"); + assert_eq!(parse_bundle_entry_at(&bundle, 1).unwrap(), b"second-entry"); + assert_eq!(parse_bundle_entry_at(&bundle, 2).unwrap(), b"third"); + } + + /// Regression: `Vec::with_capacity(n) + read_exact(&mut v)` reads zero bytes + /// because `len() == 0`. Both helpers must use `vec![0u8; n]` instead. + #[test] + fn parse_bundle_does_not_return_empty_entries_for_nonempty_input() { + let payload = b"non-empty-bytecode"; + let bundle = build_bundle(&[payload]); + let entries = parse_bundle_entries(&bundle).unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].len(), payload.len()); + assert_eq!(entries[0], payload); + + let single = parse_bundle_entry_at(&bundle, 0).unwrap(); + assert_eq!(single.len(), payload.len()); + assert_eq!(single, payload); + } + + #[test] + fn parse_bundle_entries_empty_input_yields_no_entries() { + let entries = parse_bundle_entries(&[]).unwrap(); + assert!(entries.is_empty()); + } +} diff --git a/crates/ltk_shader/src/toc.rs b/crates/ltk_shader/src/toc.rs index 88e38a7a..784cfd37 100644 --- a/crates/ltk_shader/src/toc.rs +++ b/crates/ltk_shader/src/toc.rs @@ -1,7 +1,8 @@ use crate::defines::ShaderMacroDefinition; +use crate::error::{Result, ShaderError}; use crate::read_sized_string; use byteorder::{ReadBytesExt, LE}; -use std::io::{self, Read}; +use std::io::Read; #[derive(Debug)] pub struct ShaderToc { @@ -23,13 +24,13 @@ impl ShaderToc { } } - pub fn read(reader: &mut R) -> io::Result { + pub fn read(reader: &mut R) -> Result { let toc_magic = read_sized_string(reader)?; if toc_magic != "TOC3.0" { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Invalid TOC magic: expected TOC3.0, got {}", toc_magic), - )); + return Err(ShaderError::InvalidTocMagic { + expected: "TOC3.0".to_string(), + actual: toc_magic, + }); } let shader_count = reader.read_u32::()? as usize; @@ -39,10 +40,10 @@ impl ShaderToc { let base_defines_section_magic = read_sized_string(reader)?; if base_defines_section_magic != "baseDefines" { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid baseDefines section magic", - )); + return Err(ShaderError::InvalidSectionMagic { + expected: "baseDefines".to_string(), + actual: base_defines_section_magic, + }); } let mut base_defines = Vec::with_capacity(base_defines_count); @@ -52,10 +53,10 @@ impl ShaderToc { let shaders_section_magic = read_sized_string(reader)?; if shaders_section_magic != "shaders" { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid shaders section magic", - )); + return Err(ShaderError::InvalidSectionMagic { + expected: "shaders".to_string(), + actual: shaders_section_magic, + }); } let mut shader_hashes = vec![0u64; shader_count]; @@ -65,15 +66,11 @@ impl ShaderToc { reader.read_u32_into::(&mut shader_ids)?; if shader_hashes.len() != shader_count || shader_ids.len() != shader_count { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "Mismatch in shader vector lengths: expected {}, got {} hashes and {} ids", - shader_count, - shader_hashes.len(), - shader_ids.len() - ), - )); + return Err(ShaderError::TocLengthMismatch { + expected: shader_count, + hashes: shader_hashes.len(), + ids: shader_ids.len(), + }); } Ok(Self {