Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion rustortion-core/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ use rustortion_core::amp::stages::{
compressor::CompressorStage,
filter::{FilterStage, FilterType},
level::LevelStage,
nam::NamConfig,
noise_gate::NoiseGateStage,
poweramp::{PowerAmpStage, PowerAmpType},
preamp::PreampStage,
tonestack::{ToneStackModel, ToneStackStage},
};
use rustortion_core::nam::{NamLoader, registry};
use std::hint::black_box;
use std::path::Path;

/// Engine sample rate in Hz used by the benchmarks; it is cast to `f32` and handed to
/// the chain builder and to `NamConfig::to_stage`.
const SAMPLE_RATE: usize = 48000;
/// Number of samples in each benchmark input buffer (one processing block).
const BUFFER_SIZE: usize = 128;
Expand Down Expand Up @@ -110,5 +113,108 @@ fn bench_sample_vs_block(c: &mut Criterion) {
group.finish();
}

criterion_group!(benches, bench_sample_vs_block);
/// Populate the global NAM registry from the crate's committed `tests/fixtures`
/// directory and hand back the first model name the registry reports.
///
/// Returns `None` when a loader cannot be created for that directory, or when the
/// registry lists no models afterwards. Because the fixture is checked in, the NAM
/// benches do not depend on a user's gitignored `nam/` models.
fn load_first_nam_model() -> Option<String> {
    let fixtures_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
    let Ok(loader) = NamLoader::new(&fixtures_dir) else {
        return None;
    };
    registry::init_from_loader(&loader);

    // The fixture is a 48 kHz model and the chain runs at SAMPLE_RATE (48 kHz, 1x),
    // so the stage stays active rather than bypassing on a rate mismatch.
    let mut names = registry::available_names().into_iter();
    names.next()
}

/// Benchmark the amp chain with a NAM stage appended, comparing per-sample
/// `process` calls against one `process_block` call over the same block.
///
/// The bench is skipped (with a message on stderr) when no fixture model can be
/// loaded, or when the stage built from it reports itself inactive.
fn bench_nam_sample_vs_block(c: &mut Criterion) {
    let Some(model_name) = load_first_nam_model() else {
        eprintln!("skipping NAM bench: no .nam model found in tests/fixtures");
        return;
    };

    let config = NamConfig {
        model_name: Some(model_name),
        ..NamConfig::default()
    };

    // Sanity-check the model actually loaded (rate matches 48 kHz); if it bypassed we
    // would be benchmarking a passthrough, which is meaningless here.
    if !config.to_stage(SAMPLE_RATE as f32).is_active() {
        eprintln!("skipping NAM bench: model bypassed (sample-rate mismatch at 48 kHz)");
        return;
    }

    let mut group = c.benchmark_group("NAM Chain Sample vs Block");
    // NAM runs at the model's native rate (no oversampling), so benchmark at 1x only.

    group.bench_function(BenchmarkId::new("sample-by-sample", "1x"), |b| {
        let mut chain = build_chain(SAMPLE_RATE as f32);
        chain.add_stage(Box::new(config.to_stage(SAMPLE_RATE as f32)));
        let input = vec![0.5f32; BUFFER_SIZE];

        b.iter(|| {
            for &sample in &input {
                black_box(chain.process(black_box(sample)));
            }
        });
    });

    group.bench_function(BenchmarkId::new("block", "1x"), |b| {
        let mut chain = build_chain(SAMPLE_RATE as f32);
        chain.add_stage(Box::new(config.to_stage(SAMPLE_RATE as f32)));
        let input = vec![0.5f32; BUFFER_SIZE];
        let mut buffer = input.clone();

        b.iter(|| {
            // `process_block` works in place. Refill the buffer every iteration so the
            // block path sees the same constant input as the per-sample path instead of
            // feeding its own previous output back in.
            buffer.copy_from_slice(&input);
            chain.process_block(black_box(&mut buffer));
            black_box(&buffer);
        });
    });

    group.finish();
}

/// Isolated ceiling: raw nam-rs `process_buffer` (batched) vs a `process_sample`
/// loop on the same model, no chain, no gain/mix. This is the maximum speedup a
/// `NamStage::process_block` override could capture by calling `process_buffer`.
///
/// Every skip path (no fixture, registry miss, model build failure) reports its
/// reason on stderr so a missing benchmark group is never silent.
fn bench_nam_buffer_vs_sample(c: &mut Criterion) {
    let Some(model_name) = load_first_nam_model() else {
        eprintln!("skipping NAM ceiling bench: no .nam model found");
        return;
    };
    let Some(parsed) = registry::get(&model_name) else {
        eprintln!("skipping NAM ceiling bench: model '{model_name}' missing from registry");
        return;
    };
    let Ok(mut model) = nam_rs::Model::from_nam(&parsed) else {
        eprintln!("skipping NAM ceiling bench: model failed to build");
        return;
    };

    let mut group = c.benchmark_group("NAM Model Buffer vs Sample");

    group.bench_function(BenchmarkId::new("process_sample-loop", "1x"), |b| {
        let input = vec![0.5f32; BUFFER_SIZE];
        b.iter(|| {
            for &sample in &input {
                black_box(model.process_sample(black_box(sample)));
            }
        });
    });

    group.bench_function(BenchmarkId::new("process_buffer", "1x"), |b| {
        let input = vec![0.5f32; BUFFER_SIZE];
        let mut buffer = input.clone();
        b.iter(|| {
            // `process_buffer` overwrites the buffer in place. Restore the input each
            // iteration so both benches feed the model the same constant signal rather
            // than the model's own previous output.
            buffer.copy_from_slice(&input);
            model.process_buffer(black_box(&mut buffer));
            black_box(&buffer);
        });
    });

    group.finish();
}

// Register the chain benchmark plus the two NAM benchmarks under the `benches`
// group, then let Criterion generate the bench entry point for that group.
criterion_group!(
benches,
bench_sample_vs_block,
bench_nam_sample_vs_block,
bench_nam_buffer_vs_sample
);
criterion_main!(benches);
97 changes: 97 additions & 0 deletions rustortion-core/src/amp/stages/nam.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ pub struct NamStage {
native_sample_rate: f32,
/// True if the model's native rate differs from the engine rate.
sample_rate_mismatch: bool,
/// Scratch buffer holding the dry signal during block processing, so the
/// in-place `process_buffer` output can be blended back with `mix`. Grown on
/// demand (first block of a given size); steady-state processing never allocates.
dry: Vec<f32>,
}

impl NamStage {
Expand All @@ -35,6 +39,7 @@ impl NamStage {
mix,
native_sample_rate: 0.0,
sample_rate_mismatch: false,
dry: Vec::new(),
}
}

Expand All @@ -53,8 +58,15 @@ impl NamStage {
mix,
native_sample_rate,
sample_rate_mismatch: true,
dry: Vec::new(),
}
}

/// True when a model is loaded and running (not a passthrough or rate-mismatch bypass).
///
/// This is a plain check that `self.model` is `Some`. Callers use it to detect a
/// stage that would leave audio untouched: `process_block` returns early in that case.
#[must_use]
pub const fn is_active(&self) -> bool {
self.model.is_some()
}
}

impl Stage for NamStage {
Expand All @@ -66,6 +78,35 @@ impl Stage for NamStage {
self.mix.mul_add(wet - input, input)
}

/// Process a whole block in place through the model's batched `process_buffer`.
///
/// Applies the same per-sample formula as `process`: the input is scaled by
/// `input_gain`, run through the model, scaled by `output_gain`, and blended with
/// the untouched dry signal according to `mix`. With no model the block is left
/// unchanged.
fn process_block(&mut self, input: &mut [f32]) {
    // No model → dry passthrough (matches `process`'s early return). Binding the
    // model here avoids a later `expect`, so there is no panic path in this method.
    let Some(model) = self.model.as_mut() else {
        return;
    };

    // Grow the scratch buffer only when a larger block than any seen so far arrives;
    // steady-state processing never allocates.
    let len = input.len();
    if self.dry.len() < len {
        self.dry.resize(len, 0.0);
    }

    // Stash the dry signal, then apply input gain in place so the model runs over
    // the gained signal.
    let dry = &mut self.dry[..len];
    dry.copy_from_slice(input);
    let input_gain = self.input_gain;
    for x in input.iter_mut() {
        *x *= input_gain;
    }

    model.process_buffer(input);

    // Apply output gain and blend wet/dry per sample — same formula as `process`.
    let (output_gain, mix) = (self.output_gain, self.mix);
    for (x, &d) in input.iter_mut().zip(dry.iter()) {
        let wet = *x * output_gain;
        *x = mix.mul_add(wet - d, d);
    }
}

fn set_parameter(&mut self, name: &str, value: f32) -> Result<(), &'static str> {
match name {
"input_gain_db" => {
Expand Down Expand Up @@ -178,6 +219,7 @@ impl NamConfig {
native_sample_rate,
// Rates match (mismatch returned early above).
sample_rate_mismatch: false,
dry: Vec::new(),
},
Err(e) => {
warn!("Failed to build NAM model '{name}': {e}; using passthrough");
Expand Down Expand Up @@ -236,4 +278,59 @@ mod tests {
assert!(stage.set_parameter("output_gain_db", -30.0).is_err());
assert!(stage.set_parameter("input_gain_db", f32::NAN).is_err());
}

/// `process_block` (batched `process_buffer` + gain/mix wrapper) must agree with the
/// per-sample `process` path to within an absolute tolerance of 1e-5 per sample.
/// Uses the vendored MIT reference model in `tests/fixtures/`, so this runs in CI.
/// If the fixture cannot be loaded or the stage is inactive, the test skips and says
/// why on stderr instead of passing silently.
#[test]
fn block_matches_per_sample_with_real_model() {
    use crate::nam::{NamLoader, registry};
    use std::path::Path;

    let dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
    let Ok(loader) = NamLoader::new(&dir) else {
        eprintln!("skipping NAM parity test: could not create loader for tests/fixtures");
        return;
    };
    registry::init_from_loader(&loader);
    let Some(name) = registry::available_names().into_iter().next() else {
        eprintln!("skipping NAM parity test: no model available");
        return;
    };

    // Non-default gains and a 50% mix so an error in the gain/mix wrapper is visible.
    let config = NamConfig {
        model_name: Some(name),
        input_gain_db: 6.0,
        output_gain_db: -3.0,
        mix: 0.5,
        bypassed: false,
    };

    // Two stages from the same config evolve identical internal state given the
    // same input, so per-sample and block paths should agree.
    let mut per_sample = config.to_stage(48_000.0);
    let mut block = config.to_stage(48_000.0);
    if !per_sample.is_active() {
        eprintln!("skipping NAM parity test: model bypassed at 48 kHz");
        return;
    }

    // A non-trivial signal so gain/mix differences would show up.
    let input: Vec<f32> = (0..256)
        .map(|i| {
            let t = i as f32;
            0.3f32.mul_add((t * 0.05).sin(), 0.1 * (t * 0.31).cos())
        })
        .collect();

    let expected: Vec<f32> = input.iter().map(|&x| per_sample.process(x)).collect();
    let mut got = input; // moved: input is not needed after this
    block.process_block(&mut got);

    for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() {
        assert!(
            (e - g).abs() < 1e-5,
            "mismatch at {i}: per-sample={e}, block={g}"
        );
    }
}
}
10 changes: 2 additions & 8 deletions rustortion-core/src/audio/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ impl Engine {
}

/// Run the amp chain over `output` in place, at the engine's own rate.
///
/// The buffer is handed to the chain exactly once as a whole block. Running a
/// per-sample `process` loop as well would apply the chain twice to the same audio.
fn process_without_upsampling(&mut self, output: &mut [f32]) -> Result<()> {
    self.chain.as_mut().process_block(output);

    Ok(())
}
Expand All @@ -215,10 +212,7 @@ impl Engine {

let upsampled = self.samplers.upsample()?;

let chain = self.chain.as_mut();
for s in upsampled.iter_mut() {
*s = chain.process(*s);
}
self.chain.as_mut().process_block(upsampled);

let downsampled = self.samplers.downsample()?;

Expand Down
18 changes: 18 additions & 0 deletions rustortion-core/tests/fixtures/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Test fixtures

## `reference_standard.nam`

A standard-architecture WaveNet NAM model, vendored from the
[`nam-rs`](https://github.com/OpenSauce/nam-rs) test fixtures
(`tests/fixtures/reference_standard.nam`).

It is used by the NAM parity test (`block_matches_per_sample_with_real_model`) and
the `chain` benchmark's NAM groups, so both run deterministically in CI without
depending on a user's personal (gitignored) `nam/` models.

### License / attribution

`nam-rs` is distributed under the MIT License (Copyright (c) 2026 Leigh). The `.nam`
weight/config layout is a derivative of the Neural Amp Modeler ecosystem
(neural-amp-modeler / NeuralAmpModelerCore, Copyright (c) 2019-2025 Steven Atkinson,
MIT). See the `nam-rs` `LICENSE` and `NOTICE` files for full terms.
Loading