Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions crates/ltk_shader/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use std::io;

use ltk_wad::WadError;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum ShaderError {
#[error("io error")]
Io(#[from] io::Error),

#[error("wad error")]
Wad(#[from] WadError),

#[error("shader object not found: {path}")]
ShaderObjectNotFound { path: String },

#[error("shader bundle not found: {path}")]
ShaderBundleNotFound { path: String },

#[error("shader for defines `{defines}` not found in TOC")]
DefinesNotFound { defines: String },

#[error("invalid TOC magic: expected {expected:?}, got {actual:?}")]
InvalidTocMagic { expected: String, actual: String },

#[error("invalid section magic: expected {expected:?}, got {actual:?}")]
InvalidSectionMagic { expected: String, actual: String },

#[error("TOC vector length mismatch: expected {expected}, got {hashes} hashes and {ids} ids")]
TocLengthMismatch {
expected: usize,
hashes: usize,
ids: usize,
},
}

pub type Result<T> = std::result::Result<T, ShaderError>;
3 changes: 3 additions & 0 deletions crates/ltk_shader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
pub mod defines;
pub mod error;
pub mod loader;
pub mod toc;

pub use error::{Result, ShaderError};

use byteorder::{ReadBytesExt, LE};
use std::io::{self, Read};

Expand Down
210 changes: 160 additions & 50 deletions crates/ltk_shader/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::io::{Cursor, Read, Seek, SeekFrom};
use xxhash_rust::xxh64::xxh64;

use crate::defines::ShaderMacroDefinition;
use crate::error::{Result, ShaderError};
use crate::toc::ShaderToc;
use crate::{create_shader_bundle_path, create_shader_object_path, GraphicsPlatform, ShaderType};
use ltk_wad::Wad;
Expand All @@ -12,15 +13,13 @@ const SHADERS_PER_BUNDLE: u32 = 100;
pub struct ShaderLoader;

impl ShaderLoader {
/// Loads the bytecode for a shader object from a WAD file.
/// Loads the bytecode for a shader object from a WAD file by matching defines.
/// # Arguments
/// * `shader_object_path` - The path of the shader object to load.
/// * `shader_type` - The type of the shader to load.
/// * `platform` - The platform of the shader to load.
/// * `defines` - The defines to use for the shader.
/// * `wad` - The WAD file to load the shader from.
/// # Returns
/// A vector of bytes containing the bytecode for the shader object.
/// # Errors
/// Returns an error if the shader object is not found or if the shader object data cannot be read.
pub fn load_bytecode<R: Read + Seek>(
Expand All @@ -29,73 +28,100 @@ impl ShaderLoader {
platform: GraphicsPlatform,
defines: &[ShaderMacroDefinition],
wad: &mut Wad<R>,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
) -> Result<Vec<u8>> {
let toc = Self::load_toc(shader_object_path, shader_type, platform, wad)?;

let filtered_defines_formatted = Self::filter_defines(defines, &toc.base_defines);
let filtered_defines_hash = xxh64(filtered_defines_formatted.as_bytes(), 0);

let shader_index = toc
.shader_hashes
.iter()
.position(|&h| h == filtered_defines_hash)
.ok_or(ShaderError::DefinesNotFound {
defines: filtered_defines_formatted,
})?;

let shader_id = toc.shader_ids[shader_index];
Self::load_bytecode_by_id(shader_object_path, shader_type, platform, shader_id, wad)
}

/// Loads the TOC for a shader object from a WAD file.
pub fn load_toc<R: Read + Seek>(
shader_object_path: &str,
shader_type: ShaderType,
platform: GraphicsPlatform,
wad: &mut Wad<R>,
) -> Result<ShaderToc> {
let full_shader_object_path =
create_shader_object_path(shader_object_path, shader_type, platform);

let path_hash = xxh64(full_shader_object_path.as_bytes(), 0);

let chunk = *wad
.chunks()
.get(path_hash)
.ok_or_else(|| format!("Shader object not found: {}", full_shader_object_path))?;
let chunk =
*wad.chunks()
.get(path_hash)
.ok_or_else(|| ShaderError::ShaderObjectNotFound {
path: full_shader_object_path.clone(),
})?;

let shader_object_data = wad.load_chunk_decompressed(&chunk)?;
let mut shader_object_reader = Cursor::new(shader_object_data);
let shader_toc = ShaderToc::read(&mut shader_object_reader)?;

let filtered_defines_formatted = Self::filter_defines(defines, &shader_toc.base_defines);
let filtered_defines_hash = xxh64(filtered_defines_formatted.as_bytes(), 0);
ShaderToc::read(&mut shader_object_reader)
}

let shader_index = shader_toc
.shader_hashes
.iter()
.position(|&h| h == filtered_defines_hash);

let shader_index = match shader_index {
Some(idx) => idx,
None => {
return Err(format!(
"Shader not found for defines: {}",
filtered_defines_formatted
)
.into())
}
};
/// Loads the bytecode for a single shader by its `shader_id` (as recorded in the TOC).
pub fn load_bytecode_by_id<R: Read + Seek>(
shader_object_path: &str,
shader_type: ShaderType,
platform: GraphicsPlatform,
shader_id: u32,
wad: &mut Wad<R>,
) -> Result<Vec<u8>> {
let full_shader_object_path =
create_shader_object_path(shader_object_path, shader_type, platform);

let shader_id = shader_toc.shader_ids[shader_index];
let shader_bundle_id = SHADERS_PER_BUNDLE * (shader_id / SHADERS_PER_BUNDLE);
let shader_index_in_bundle = shader_id % SHADERS_PER_BUNDLE;
let shader_bundle_path =
create_shader_bundle_path(&full_shader_object_path, shader_bundle_id);

let bundle_path_hash = xxh64(shader_bundle_path.as_bytes(), 0);
let bundle_chunk = *wad
.chunks()
.get(bundle_path_hash)
.ok_or_else(|| format!("Shader bundle not found: {}", shader_bundle_path))?;
let bundle_data = Self::load_bundle_data(&full_shader_object_path, shader_bundle_id, wad)?;
parse_bundle_entry_at(&bundle_data, shader_index_in_bundle)
}

let shader_bundle_data = wad.load_chunk_decompressed(&bundle_chunk)?;
let mut shader_bundle_reader = Cursor::new(shader_bundle_data);
/// Read a single bundle file and yield each entry's bytecode in order.
/// Hot path for "dump every shader for one (object, type, platform)":
/// reads the bundle chunk once instead of seeking N times.
///
/// `full_shader_object_path` is the already-formatted path (e.g. from
/// [`create_shader_object_path`]); `shader_bundle_id` is a multiple of 100.
pub fn read_bundle<R: Read + Seek>(
full_shader_object_path: &str,
shader_bundle_id: u32,
wad: &mut Wad<R>,
) -> Result<Vec<Vec<u8>>> {
let bundle_data = Self::load_bundle_data(full_shader_object_path, shader_bundle_id, wad)?;
parse_bundle_entries(&bundle_data)
}

for _ in 0..shader_index_in_bundle {
let shader_size = shader_bundle_reader.read_u32::<LE>()?;
shader_bundle_reader.seek(SeekFrom::Current(shader_size as i64))?;
}
fn load_bundle_data<R: Read + Seek>(
full_shader_object_path: &str,
shader_bundle_id: u32,
wad: &mut Wad<R>,
) -> Result<Box<[u8]>> {
let shader_bundle_path =
create_shader_bundle_path(full_shader_object_path, shader_bundle_id);

let requested_shader_size = shader_bundle_reader.read_u32::<LE>()? as usize;
let mut bytecode = Vec::with_capacity(requested_shader_size);
shader_bundle_reader.read_exact(&mut bytecode)?;
let bundle_path_hash = xxh64(shader_bundle_path.as_bytes(), 0);
let bundle_chunk = *wad.chunks().get(bundle_path_hash).ok_or_else(|| {
ShaderError::ShaderBundleNotFound {
path: shader_bundle_path.clone(),
}
})?;

Ok(bytecode)
Ok(wad.load_chunk_decompressed(&bundle_chunk)?)
}

/// Filters the defines to only include the defines that are in the base defines.
/// # Arguments
/// * `defines` - The defines to filter.
/// * `base_defines` - The base defines to filter the defines against.
/// # Returns
/// A string containing the filtered defines.
fn filter_defines(
defines: &[ShaderMacroDefinition],
base_defines: &[ShaderMacroDefinition],
Expand All @@ -112,3 +138,87 @@ impl ShaderLoader {
filtered.iter().map(|d| d.to_string()).collect()
}
}

/// Parse a shader bundle blob into its constituent bytecode entries.
///
/// A bundle is a sequence of `(u32 length, [u8; length])` entries packed back-to-back.
fn parse_bundle_entries(data: &[u8]) -> Result<Vec<Vec<u8>>> {
let total = data.len() as u64;
let mut reader = Cursor::new(data);
let mut entries = Vec::new();
while reader.position() < total {
let size = reader.read_u32::<LE>()? as usize;
let mut bytecode = vec![0u8; size];
reader.read_exact(&mut bytecode)?;
entries.push(bytecode);
}
Ok(entries)
}

/// Parse a shader bundle blob and return only the entry at `index_in_bundle`.
fn parse_bundle_entry_at(data: &[u8], index_in_bundle: u32) -> Result<Vec<u8>> {
let mut reader = Cursor::new(data);
for _ in 0..index_in_bundle {
let size = reader.read_u32::<LE>()?;
reader.seek(SeekFrom::Current(i64::from(size)))?;
}
let size = reader.read_u32::<LE>()? as usize;
let mut bytecode = vec![0u8; size];
reader.read_exact(&mut bytecode)?;
Ok(bytecode)
}

#[cfg(test)]
mod tests {
use super::*;
use byteorder::WriteBytesExt;

fn build_bundle(entries: &[&[u8]]) -> Vec<u8> {
let mut out = Vec::new();
for entry in entries {
out.write_u32::<LE>(entry.len() as u32).unwrap();
out.extend_from_slice(entry);
}
out
}

#[test]
fn parse_bundle_entries_reads_all_in_order() {
let bundle = build_bundle(&[b"\x01\x02\x03", b"hello", b"DXBCmagicpayload"]);
let entries = parse_bundle_entries(&bundle).expect("parse should succeed");
assert_eq!(entries.len(), 3);
assert_eq!(entries[0], b"\x01\x02\x03");
assert_eq!(entries[1], b"hello");
assert_eq!(entries[2], b"DXBCmagicpayload");
}

#[test]
fn parse_bundle_entry_at_returns_correct_entry() {
let bundle = build_bundle(&[b"first", b"second-entry", b"third"]);
assert_eq!(parse_bundle_entry_at(&bundle, 0).unwrap(), b"first");
assert_eq!(parse_bundle_entry_at(&bundle, 1).unwrap(), b"second-entry");
assert_eq!(parse_bundle_entry_at(&bundle, 2).unwrap(), b"third");
}

/// Regression: `Vec::with_capacity(n) + read_exact(&mut v)` reads zero bytes
/// because `len() == 0`. Both helpers must use `vec![0u8; n]` instead.
#[test]
fn parse_bundle_does_not_return_empty_entries_for_nonempty_input() {
let payload = b"non-empty-bytecode";
let bundle = build_bundle(&[payload]);
let entries = parse_bundle_entries(&bundle).unwrap();
assert_eq!(entries.len(), 1);
assert_eq!(entries[0].len(), payload.len());
assert_eq!(entries[0], payload);

let single = parse_bundle_entry_at(&bundle, 0).unwrap();
assert_eq!(single.len(), payload.len());
assert_eq!(single, payload);
}

#[test]
fn parse_bundle_entries_empty_input_yields_no_entries() {
let entries = parse_bundle_entries(&[]).unwrap();
assert!(entries.is_empty());
}
}
43 changes: 20 additions & 23 deletions crates/ltk_shader/src/toc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::defines::ShaderMacroDefinition;
use crate::error::{Result, ShaderError};
use crate::read_sized_string;
use byteorder::{ReadBytesExt, LE};
use std::io::{self, Read};
use std::io::Read;

#[derive(Debug)]
pub struct ShaderToc {
Expand All @@ -23,13 +24,13 @@ impl ShaderToc {
}
}

pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
pub fn read<R: Read>(reader: &mut R) -> Result<Self> {
let toc_magic = read_sized_string(reader)?;
if toc_magic != "TOC3.0" {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid TOC magic: expected TOC3.0, got {}", toc_magic),
));
return Err(ShaderError::InvalidTocMagic {
expected: "TOC3.0".to_string(),
actual: toc_magic,
});
}

let shader_count = reader.read_u32::<LE>()? as usize;
Expand All @@ -39,10 +40,10 @@ impl ShaderToc {

let base_defines_section_magic = read_sized_string(reader)?;
if base_defines_section_magic != "baseDefines" {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid baseDefines section magic",
));
return Err(ShaderError::InvalidSectionMagic {
expected: "baseDefines".to_string(),
actual: base_defines_section_magic,
});
}

let mut base_defines = Vec::with_capacity(base_defines_count);
Expand All @@ -52,10 +53,10 @@ impl ShaderToc {

let shaders_section_magic = read_sized_string(reader)?;
if shaders_section_magic != "shaders" {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid shaders section magic",
));
return Err(ShaderError::InvalidSectionMagic {
expected: "shaders".to_string(),
actual: shaders_section_magic,
});
}

let mut shader_hashes = vec![0u64; shader_count];
Expand All @@ -65,15 +66,11 @@ impl ShaderToc {
reader.read_u32_into::<LE>(&mut shader_ids)?;

if shader_hashes.len() != shader_count || shader_ids.len() != shader_count {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Mismatch in shader vector lengths: expected {}, got {} hashes and {} ids",
shader_count,
shader_hashes.len(),
shader_ids.len()
),
));
return Err(ShaderError::TocLengthMismatch {
expected: shader_count,
hashes: shader_hashes.len(),
ids: shader_ids.len(),
});
}

Ok(Self {
Expand Down
Loading