From 3c7c8821e637895b03a99ec4ae934e34bf7fda75 Mon Sep 17 00:00:00 2001 From: Robert Haist Date: Sun, 24 May 2026 15:35:27 +0200 Subject: [PATCH] [PERFORMANCE] Templated kernels for grouped Conv1x1/Conv1D Compile-time-specialized GEMM kernels for the (out_channels, in_channels, groups) shapes used by WaveNet models. Generalizes the depthwise-only fast path from #217 to all grouped (and small dense) cases, addressing #215. Both the default Eigen path and NAM_USE_INLINE_GEMM build benefit; unknown shapes fall through to existing behavior. Render output is bit-identical to main on 33 production models including the v4 baseline a1-{pico,nano,feather,lite,standard} set. --- NAM/conv1d.cpp | 132 ++++++++++++++++++++++++++ NAM/conv1d.h | 11 +++ NAM/dsp.cpp | 120 ++++++++++++++++++++++- NAM/dsp.h | 10 ++ tools/CMakeLists.txt | 36 +++++++ tools/bench_conv1x1_groups.cpp | 167 +++++++++++++++++++++++++++++++++ tools/check_conv1d_grouped.cpp | 137 +++++++++++++++++++++++++++ 7 files changed, 611 insertions(+), 2 deletions(-) create mode 100644 tools/bench_conv1x1_groups.cpp create mode 100644 tools/check_conv1d_grouped.cpp diff --git a/NAM/conv1d.cpp b/NAM/conv1d.cpp index b561786c..e7fd3082 100644 --- a/NAM/conv1d.cpp +++ b/NAM/conv1d.cpp @@ -4,6 +4,104 @@ namespace nam { +namespace +{ +// Templated per-tap accumulating kernel for Conv1D. +// OutCh, InCh, Groups are compile-time constants so the compiler unrolls every loop +// and folds all index arithmetic. Off-block-diagonal zeros are never visited. +// Weight memory layout is col-major (out_channels rows x in_channels cols), matching +// Eigen::MatrixXf default storage in nam::Conv1D::_weight[k]. +// Input layout is assumed contiguous (channels rows x num_frames cols, col-major), as +// the existing inline-GEMM cascade also assumes. +template +void templated_conv1d_tap_kernel(const float* __restrict__ weight, const float* __restrict__ in, + float* __restrict__ out, int num_frames) +{ + static_assert(OutCh % Groups == 0, "OutCh must be divisible by Groups"); + static_assert(InCh % Groups == 0, "InCh must be divisible by Groups"); + constexpr int OutPerG = OutCh / Groups; + constexpr int InPerG = InCh / Groups; + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = in + f * InCh; + float* __restrict__ out_col = out + f * OutCh; + for (int g = 0; g < Groups; g++) + { + const int o_base = g * OutPerG; + const int i_base = g * InPerG; + for (int o = 0; o < OutPerG; o++) + { + float sum = 0.0f; + for (int i = 0; i < InPerG; i++) + { + sum += weight[(i_base + i) * OutCh + (o_base + o)] * in_col[i_base + i]; + } + out_col[o_base + o] += sum; + } + } + } +} + +// Map (out_channels, in_channels, groups) -> templated tap-kernel function pointer. +// Returns nullptr for unregistered shapes; caller falls back to existing inline / +// Eigen GEMM cascade. Depthwise (groups == channels) is handled by Conv1D's existing +// _is_depthwise path and is intentionally not registered here. +nam::Conv1D::TapKernel pick_conv1d_tap_kernel(int out_channels, int in_channels, int groups) +{ + using K = nam::Conv1D::TapKernel; + if (out_channels == 4 && in_channels == 4) + { + if (groups == 1) + return static_cast(&templated_conv1d_tap_kernel<4, 4, 1>); + if (groups == 2) + return static_cast(&templated_conv1d_tap_kernel<4, 4, 2>); + } + if (out_channels == 6 && in_channels == 6) + { + if (groups == 1) + return static_cast(&templated_conv1d_tap_kernel<6, 6, 1>); + if (groups == 2) + return static_cast(&templated_conv1d_tap_kernel<6, 6, 2>); + if (groups == 3) + return static_cast(&templated_conv1d_tap_kernel<6, 6, 3>); + } + if (out_channels == 8 && in_channels == 8) + { + if (groups == 1) + return static_cast(&templated_conv1d_tap_kernel<8, 8, 1>); + if (groups == 2) + return static_cast(&templated_conv1d_tap_kernel<8, 8, 2>); + if (groups == 4) + return static_cast(&templated_conv1d_tap_kernel<8, 8, 4>); + } + if (out_channels == 12 && in_channels == 12) + { + if (groups == 1) + return static_cast(&templated_conv1d_tap_kernel<12, 12, 1>); + if (groups == 2) + return static_cast(&templated_conv1d_tap_kernel<12, 12, 2>); + if (groups == 3) + return static_cast(&templated_conv1d_tap_kernel<12, 12, 3>); + if (groups == 4) + return static_cast(&templated_conv1d_tap_kernel<12, 12, 4>); + if (groups == 6) + return static_cast(&templated_conv1d_tap_kernel<12, 12, 6>); + } + if (out_channels == 16 && in_channels == 16) + { + if (groups == 1) + return static_cast(&templated_conv1d_tap_kernel<16, 16, 1>); + if (groups == 2) + return static_cast(&templated_conv1d_tap_kernel<16, 16, 2>); + if (groups == 4) + return static_cast(&templated_conv1d_tap_kernel<16, 16, 4>); + if (groups == 8) + return static_cast(&templated_conv1d_tap_kernel<16, 16, 8>); + } + return nullptr; +} +} // namespace + // Conv1D ===================================================================== void Conv1D::set_weights_(std::vector::iterator& weights) @@ -86,6 +184,7 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int this->_depthwise_weight[i].setZero(); } this->_weight.clear(); // Not used for depthwise + this->_tap_kernel = nullptr; } else { @@ -99,6 +198,10 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int } this->_depthwise_weight.clear(); // Not used for non-depthwise this->_channels = 0; + // Look up a shape-specialized templated per-tap kernel. Skips zeros for grouped + // cases and bypasses Eigen GEMM for small dense cases. nullptr -> fall back to + // existing inline / Eigen GEMM cascade. + this->_tap_kernel = pick_conv1d_tap_kernel(out_channels, in_channels, groups); } if (do_bias) @@ -251,6 +354,21 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames) } #endif } + else if (this->_tap_kernel != nullptr) + { + // Shape-specialized templated per-tap kernel (constexpr-unrolled, skips off-diagonal + // zeros for grouped cases). Accumulates across taps so output must be zeroed first. + _output.leftCols(num_frames).setZero(); + const size_t kernel_size = this->_weight.size(); + float* __restrict__ output_ptr = _output.data(); + for (size_t k = 0; k < kernel_size; k++) + { + const long offset = this->_dilation * (k + 1 - (long)kernel_size); + const long lookback = -offset; + auto input_block = _input_buffer.Read(num_frames, lookback); + this->_tap_kernel(this->_weight[k].data(), input_block.data(), output_ptr, num_frames); + } + } else { #ifdef NAM_USE_INLINE_GEMM @@ -736,6 +854,20 @@ void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, con this->_depthwise_weight[k].asDiagonal() * input.middleCols(i_start + offset, ncols); } } + else if (this->_tap_kernel != nullptr && input.outerStride() == input.rows() && output.outerStride() == output.rows()) + { + // Shape-specialized templated per-tap kernel; accumulates so zero the output slice first. + // Guarded by the stride check because the kernel assumes contiguous column-major storage. + output.middleCols(j_start, ncols).setZero(); + float* __restrict__ out_ptr = output.data() + j_start * output.rows(); + const size_t kernel_size = this->_weight.size(); + for (size_t k = 0; k < kernel_size; k++) + { + const long offset = this->_dilation * (k + 1 - (long)kernel_size); + const float* __restrict__ in_ptr = input.data() + (i_start + offset) * input.rows(); + this->_tap_kernel(this->_weight[k].data(), in_ptr, out_ptr, (int)ncols); + } + } else { // Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal), diff --git a/NAM/conv1d.h b/NAM/conv1d.h index 8f006864..ced9e440 100644 --- a/NAM/conv1d.h +++ b/NAM/conv1d.h @@ -117,6 +117,11 @@ class Conv1D /// \return true if bias is present, false otherwise bool has_bias() const { return this->_bias.size() > 0; }; + // Function pointer to a shape-specialized per-tap GEMM kernel that accumulates into + // the output buffer (out += weight * in). Public so the dispatch table in conv1d.cpp + // can return values of this type without exposing internals. + using TapKernel = void (*)(const float* weight, const float* in, float* out, int num_frames); + protected: // conv[kernel](cout, cin) - used for non-depthwise convolutions std::vector _weight; @@ -129,6 +134,12 @@ class Conv1D int _dilation; int _num_groups; + // Set at construction time when (in_channels, out_channels, groups) matches a + // registered template specialization. When non-null, the non-depthwise Process / + // process_ paths invoke this per tap instead of running a dense Eigen / inline GEMM + // through the block-diagonal zero structure. nullptr -> fall back to generic. + TapKernel _tap_kernel = nullptr; + private: RingBuffer _input_buffer; // Ring buffer for input (channels x buffer_size) Eigen::MatrixXf _output; // Pre-allocated output buffer (out_channels x maxBufferSize) diff --git a/NAM/dsp.cpp b/NAM/dsp.cpp index e975001b..802579e0 100644 --- a/NAM/dsp.cpp +++ b/NAM/dsp.cpp @@ -339,6 +339,106 @@ static nam::ConfigParserHelper _register_Linear("Linear", nam::linear::create_co // Conv1x1 ==================================================================== +namespace +{ +// Templated dense/grouped 1x1 kernel. +// OutCh, InCh, Groups are compile-time constants so the compiler unrolls every loop +// and folds all index arithmetic. Off-block-diagonal zeros are never visited. +// Weight memory layout is col-major (out_channels rows x in_channels cols) - +// matching Eigen::MatrixXf default storage in nam::Conv1x1::_weight. +template +void templated_conv1x1_kernel(const float* __restrict__ weight, const float* __restrict__ in, float* __restrict__ out, + int num_frames, int in_stride) +{ + static_assert(OutCh % Groups == 0, "OutCh must be divisible by Groups"); + static_assert(InCh % Groups == 0, "InCh must be divisible by Groups"); + constexpr int OutPerG = OutCh / Groups; + constexpr int InPerG = InCh / Groups; + for (int f = 0; f < num_frames; f++) + { + const float* __restrict__ in_col = in + f * in_stride; + float* __restrict__ out_col = out + f * OutCh; + for (int g = 0; g < Groups; g++) + { + constexpr int row_offset_per_group = OutPerG; + constexpr int col_offset_per_group = InPerG; + const int o_base = g * row_offset_per_group; + const int i_base = g * col_offset_per_group; + for (int o = 0; o < OutPerG; o++) + { + float sum = 0.0f; + for (int i = 0; i < InPerG; i++) + { + sum += weight[(i_base + i) * OutCh + (o_base + o)] * in_col[i_base + i]; + } + out_col[o_base + o] = sum; + } + } + } +} + +// Map (out_channels, in_channels, groups) -> templated kernel function pointer. +// Returns nullptr when no specialization is registered; caller falls back to the +// generic Eigen / inline-GEMM path. +nam::Conv1x1::ProcessKernel pick_conv1x1_kernel(int out_channels, int in_channels, int groups) +{ + using K = nam::Conv1x1::ProcessKernel; + // Square shapes (the layer1x1 / head1x1 / FiLM cases that dominate WaveNet). + // Depthwise (groups == channels) is handled by the dedicated _is_depthwise path + // and is intentionally not registered here. + if (out_channels == 4 && in_channels == 4) + { + if (groups == 1) + return static_cast(&templated_conv1x1_kernel<4, 4, 1>); + if (groups == 2) + return static_cast(&templated_conv1x1_kernel<4, 4, 2>); + } + if (out_channels == 6 && in_channels == 6) + { + if (groups == 1) + return static_cast(&templated_conv1x1_kernel<6, 6, 1>); + if (groups == 2) + return static_cast(&templated_conv1x1_kernel<6, 6, 2>); + if (groups == 3) + return static_cast(&templated_conv1x1_kernel<6, 6, 3>); + } + if (out_channels == 8 && in_channels == 8) + { + if (groups == 1) + return static_cast(&templated_conv1x1_kernel<8, 8, 1>); + if (groups == 2) + return static_cast(&templated_conv1x1_kernel<8, 8, 2>); + if (groups == 4) + return static_cast(&templated_conv1x1_kernel<8, 8, 4>); + } + if (out_channels == 12 && in_channels == 12) + { + if (groups == 1) + return static_cast(&templated_conv1x1_kernel<12, 12, 1>); + if (groups == 2) + return static_cast(&templated_conv1x1_kernel<12, 12, 2>); + if (groups == 3) + return static_cast(&templated_conv1x1_kernel<12, 12, 3>); + if (groups == 4) + return static_cast(&templated_conv1x1_kernel<12, 12, 4>); + if (groups == 6) + return static_cast(&templated_conv1x1_kernel<12, 12, 6>); + } + if (out_channels == 16 && in_channels == 16) + { + if (groups == 1) + return static_cast(&templated_conv1x1_kernel<16, 16, 1>); + if (groups == 2) + return static_cast(&templated_conv1x1_kernel<16, 16, 2>); + if (groups == 4) + return static_cast(&templated_conv1x1_kernel<16, 16, 4>); + if (groups == 8) + return static_cast(&templated_conv1x1_kernel<16, 16, 8>); + } + return nullptr; +} +} // namespace + nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool _bias, const int groups) { // Validate that channels divide evenly by groups @@ -376,6 +476,9 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool this->_weight.resize(out_channels, in_channels); this->_weight.setZero(); this->_channels = 0; + // Look up a shape-specialized templated kernel. Skips zeros for grouped cases and + // bypasses Eigen GEMM for small dense cases. nullptr -> fall back to generic kernel. + this->_kernel = pick_conv1x1_kernel(out_channels, in_channels, groups); } if (_bias) @@ -452,9 +555,14 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu // Each channel is scaled by its corresponding weight result.noalias() = this->_depthwise_weight.asDiagonal() * input.leftCols(num_frames); } + else if (this->_kernel != nullptr) + { + // Shape-specialized templated kernel (constexpr-unrolled, skips off-diagonal zeros). + this->_kernel(this->_weight.data(), input.data(), result.data(), num_frames, (int)input.outerStride()); + } else { - // Single GEMM for all cases - block-diagonal zero structure handles grouping + // Generic fallback: single dense GEMM through the block-diagonal zero structure. result.noalias() = this->_weight * input.leftCols(num_frames); } @@ -477,6 +585,12 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons // Each channel is scaled by its corresponding weight _output.leftCols(num_frames).noalias() = this->_depthwise_weight.asDiagonal() * input.leftCols(num_frames); } + else if (this->_kernel != nullptr) + { + // Shape-specialized templated kernel (constexpr-unrolled, skips off-diagonal zeros + // for grouped cases). Bias is applied after this block by the shared bias path. + this->_kernel(this->_weight.data(), input.data(), _output.data(), num_frames, (int)input.outerStride()); + } else { #ifdef NAM_USE_INLINE_GEMM @@ -745,7 +859,9 @@ void nam::Conv1x1::process_(const Eigen::Ref& input, cons } } #else - // Single GEMM for all cases - block-diagonal zero structure handles grouping + // Single GEMM for all cases - block-diagonal zero structure handles grouping. + // Per-group Eigen blocks were tried but small-block GEMM overhead dominates; + // see the inline-GEMM path above for grouped-specific kernels. _output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames); #endif } diff --git a/NAM/dsp.h b/NAM/dsp.h index 1fadcf70..da5469af 100644 --- a/NAM/dsp.h +++ b/NAM/dsp.h @@ -352,6 +352,10 @@ class Conv1x1 long get_out_channels() const; long get_in_channels() const; + // Function pointer to a shape-specialized GEMM kernel. Public so the dispatch table + // in dsp.cpp can return values of this type without exposing internals. + using ProcessKernel = void (*)(const float* weight, const float* in, float* out, int num_frames, int in_stride); + protected: // Non-depthwise: full weight matrix (out_channels x in_channels) Eigen::MatrixXf _weight; @@ -363,6 +367,12 @@ class Conv1x1 Eigen::VectorXf _bias; int _num_groups; + // Set at construction time when (in_channels, out_channels, groups) matches a + // registered template specialization. When non-null, used by both the Eigen and + // inline-GEMM process_ paths in preference to the generic dense kernel. + // nullptr -> fall back to generic. + ProcessKernel _kernel = nullptr; + private: Eigen::MatrixXf _output; bool _do_bias; diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 492fb676..f5279e80 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -50,6 +50,42 @@ set_target_properties(bench_a2_fast PROPERTIES INTERPROCEDURAL_OPTIMIZATION TRUE PREFIX "" ) +add_executable(bench_conv1x1_groups bench_conv1x1_groups.cpp ${NAM_SOURCES}) +target_compile_features(bench_conv1x1_groups PUBLIC cxx_std_20) +set_target_properties(bench_conv1x1_groups PROPERTIES + CXX_VISIBILITY_PRESET hidden + INTERPROCEDURAL_OPTIMIZATION TRUE + PREFIX "" +) +if (MSVC) + target_compile_options(bench_conv1x1_groups PRIVATE + "$<$:/W4>" + "$<$:/O2>" + ) +else() + target_compile_options(bench_conv1x1_groups PRIVATE + -Wall -Wextra -Wno-unused-parameter + "$<$:-Ofast>" + ) +endif() +add_executable(check_conv1d_grouped check_conv1d_grouped.cpp ${NAM_SOURCES}) +target_compile_features(check_conv1d_grouped PUBLIC cxx_std_20) +set_target_properties(check_conv1d_grouped PROPERTIES + CXX_VISIBILITY_PRESET hidden + INTERPROCEDURAL_OPTIMIZATION TRUE + PREFIX "" +) +if (MSVC) + target_compile_options(check_conv1d_grouped PRIVATE + "$<$:/W4>" + "$<$:/O2>" + ) +else() + target_compile_options(check_conv1d_grouped PRIVATE + -Wall -Wextra -Wno-unused-parameter + "$<$:-O2>" + ) +endif() if (MSVC) target_compile_options(bench_a2_fast PRIVATE "$<$:/W4>" diff --git a/tools/bench_conv1x1_groups.cpp b/tools/bench_conv1x1_groups.cpp new file mode 100644 index 00000000..f63db8f7 --- /dev/null +++ b/tools/bench_conv1x1_groups.cpp @@ -0,0 +1,167 @@ +// Microbenchmark and correctness check for grouped Conv1x1. +// Sweeps registered (channels, groups) combinations. +// Mirrors the "1x1_groups" plot in issue #215. + +#include +#include +#include +#include +#include + +#include "NAM/dsp.h" + +using clk = std::chrono::high_resolution_clock; + +// Reference implementation: dense block-diagonal GEMM. Should match templated kernel +// bit-for-bit when groups==1, and within float-rounding tolerance otherwise (different +// accumulation order). Used by check_correctness(). +static Eigen::MatrixXf reference_conv1x1(int channels, int groups, const std::vector& weights, + const Eigen::MatrixXf& input) +{ + Eigen::MatrixXf W(channels, channels); + W.setZero(); + const int per_group = channels / groups; + size_t idx = 0; + for (int g = 0; g < groups; g++) + { + for (int i = 0; i < per_group; i++) + { + for (int j = 0; j < per_group; j++) + { + W(g * per_group + i, g * per_group + j) = weights[idx++]; + } + } + } + return W * input; +} + +// Verify templated kernel produces identical output (within tolerance) to a plain +// dense GEMM reference for the given (channels, groups). Returns max abs diff. +static double check_correctness(int channels, int groups, int frames, std::mt19937& rng) +{ + std::uniform_real_distribution dist(-1.0f, 1.0f); + const int per_group = channels / groups; + std::vector weights(groups * per_group * per_group); + for (auto& w : weights) + w = dist(rng); + + nam::Conv1x1 conv(channels, channels, /*bias=*/false, groups); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(frames); + + Eigen::MatrixXf input(channels, frames); + for (int r = 0; r < channels; r++) + for (int c = 0; c < frames; c++) + input(r, c) = dist(rng); + + conv.process_(input, frames); + const Eigen::MatrixXf& got = conv.GetOutput(); + Eigen::MatrixXf want = reference_conv1x1(channels, groups, weights, input); + + double max_diff = 0.0; + for (int r = 0; r < channels; r++) + for (int c = 0; c < frames; c++) + max_diff = std::max(max_diff, (double)std::abs(got(r, c) - want(r, c))); + return max_diff; +} + +static double bench_one(int channels, int groups, int frames, int iters) +{ + std::vector weights; + // Per-group block has (channels/groups) * (channels/groups) weights. + const int per_group = channels / groups; + weights.resize(groups * per_group * per_group, 0.123f); + + nam::Conv1x1 conv(channels, channels, /*bias=*/false, groups); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(frames); + + Eigen::MatrixXf input(channels, frames); + input.setRandom(); + + // Warmup + for (int i = 0; i < 100; i++) + conv.process_(input, frames); + + auto t1 = clk::now(); + for (int i = 0; i < iters; i++) + conv.process_(input, frames); + auto t2 = clk::now(); + + // Read output to defeat dead-store elimination. + volatile float sink = conv.GetOutput().sum(); + (void)sink; + + return std::chrono::duration(t2 - t1).count() / iters; +} + +int main(int argc, char** argv) +{ + const int frames = 64; + const int iters = 2'000'000; + + std::cout << "Conv1x1 microbench: frames=" << frames << " iters=" << iters << "\n"; +#ifdef NAM_USE_INLINE_GEMM + std::cout << "Build: NAM_USE_INLINE_GEMM\n"; +#else + std::cout << "Build: standard (Eigen GEMM)\n"; +#endif + + struct Shape + { + int channels; + std::vector groups; + }; + const std::vector shapes = { + {4, {1, 2, 4}}, + {6, {1, 2, 3, 6}}, + {8, {1, 2, 4, 8}}, + {12, {1, 2, 3, 4, 6, 12}}, + {16, {1, 2, 4, 8, 16}}, + }; + + // Correctness gate: compare templated kernel against reference dense GEMM for every + // registered shape with random weights and random input. Bail with non-zero exit if + // any shape fails the tolerance check. + std::cout << "\n== Correctness check (templated vs reference dense GEMM) ==\n"; + std::mt19937 rng(0xC0FFEE); + const double tol = 1e-4; // accounts for accumulation-order rounding + bool ok = true; + for (const auto& s : shapes) + { + for (int g : s.groups) + { + double diff = check_correctness(s.channels, g, frames, rng); + const bool pass = diff < tol; + std::cout << " ch=" << s.channels << " G=" << g << " max_abs_diff=" << diff << (pass ? " OK" : " FAIL") + << "\n"; + if (!pass) + ok = false; + } + } + if (!ok) + { + std::cerr << "FAIL: at least one shape exceeded tolerance " << tol << "\n"; + return 2; + } + + for (const auto& s : shapes) + { + std::cout << "\n-- channels=" << s.channels << " --\n"; + for (int g : s.groups) + { + double best = 1e18; + for (int r = 0; r < 3; r++) + { + double ns = bench_one(s.channels, g, frames, iters); + if (ns < best) + best = ns; + } + std::cout << "groups=" << g << " per_call=" << best << " ns" + << " per_frame=" << (best / frames) << " ns\n"; + } + } + return 0; +} diff --git a/tools/check_conv1d_grouped.cpp b/tools/check_conv1d_grouped.cpp new file mode 100644 index 00000000..a5a144c3 --- /dev/null +++ b/tools/check_conv1d_grouped.cpp @@ -0,0 +1,137 @@ +// Correctness check for grouped Conv1D templated tap kernel. +// Compares Conv1D::Process output against a reference dense block-diagonal GEMM +// for every (channels, groups, kernel_size, dilation) shape that the templated +// dispatch registers. Exits non-zero if any shape exceeds float-rounding tolerance. + +#include +#include +#include +#include +#include + +#include "NAM/conv1d.h" + +namespace +{ +Eigen::MatrixXf reference_conv1d(int channels, int groups, int kernel_size, int dilation, + const std::vector& weights, const Eigen::MatrixXf& padded_input, + int num_frames, int lookback_max) +{ + // Reconstruct dense block-diagonal weight matrices, one per kernel tap. + const int per_group = channels / groups; + std::vector W(kernel_size, Eigen::MatrixXf::Zero(channels, channels)); + // Conv1D weight layout (set_weights_): for each group, for each out-in pair, + // for each kernel position, one weight. + size_t idx = 0; + for (int g = 0; g < groups; g++) + { + for (int i = 0; i < per_group; i++) + { + for (int j = 0; j < per_group; j++) + { + for (int k = 0; k < kernel_size; k++) + { + W[k](g * per_group + i, g * per_group + j) = weights[idx++]; + } + } + } + } + Eigen::MatrixXf out(channels, num_frames); + out.setZero(); + // padded_input has lookback_max columns of history prepended. + for (int k = 0; k < kernel_size; k++) + { + const long offset = dilation * (k + 1 - kernel_size); // <= 0 + const long lookback = -offset; + out.noalias() += W[k] * padded_input.middleCols(lookback_max - lookback, num_frames); + } + return out; +} + +double check_one(int channels, int groups, int kernel_size, int dilation, int frames, std::mt19937& rng) +{ + std::uniform_real_distribution dist(-1.0f, 1.0f); + const int per_group = channels / groups; + std::vector weights(groups * per_group * per_group * kernel_size + channels); + for (auto& w : weights) + w = dist(rng); + + nam::Conv1D conv(channels, channels, kernel_size, /*bias=*/true, dilation, groups); + auto it = weights.begin(); + conv.set_weights_(it); + conv.SetMaxBufferSize(frames); + + // Generate an input matrix; also build a padded version with lookback history (zeros) + // so the reference can do the same dilated taps. + Eigen::MatrixXf input(channels, frames); + for (int r = 0; r < channels; r++) + for (int c = 0; c < frames; c++) + input(r, c) = dist(rng); + + conv.Process(input, frames); + const Eigen::MatrixXf& got = conv.GetOutput(); + + const int lookback_max = (kernel_size - 1) * dilation; + Eigen::MatrixXf padded(channels, lookback_max + frames); + padded.setZero(); + padded.rightCols(frames) = input; + + Eigen::MatrixXf want = reference_conv1d(channels, groups, kernel_size, dilation, weights, padded, frames, lookback_max); + // Add bias. + const float* bias = weights.data() + (weights.size() - channels); + for (int r = 0; r < channels; r++) + for (int c = 0; c < frames; c++) + want(r, c) += bias[r]; + + double max_diff = 0.0; + for (int r = 0; r < channels; r++) + for (int c = 0; c < frames; c++) + max_diff = std::max(max_diff, (double)std::abs(got(r, c) - want(r, c))); + return max_diff; +} +} // namespace + +int main() +{ + struct Shape + { + int channels; + std::vector groups; + }; + const std::vector shapes = { + {4, {1, 2, 4}}, + {6, {1, 2, 3, 6}}, + {8, {1, 2, 4, 8}}, + {12, {1, 2, 3, 4, 6, 12}}, + {16, {1, 2, 4, 8, 16}}, + }; + const std::vector kernel_sizes = {1, 2, 3}; + const std::vector dilations = {1, 2, 7}; + const int frames = 64; + const double tol = 1e-4; + + std::mt19937 rng(0xBADCAB); + bool ok = true; + int n = 0; + for (const auto& s : shapes) + { + for (int g : s.groups) + { + for (int K : kernel_sizes) + { + for (int D : dilations) + { + const double diff = check_one(s.channels, g, K, D, frames, rng); + const bool pass = diff < tol; + std::cout << "ch=" << s.channels << " G=" << g << " K=" << K << " D=" << D << " diff=" << diff + << (pass ? " OK" : " FAIL") << "\n"; + if (!pass) + ok = false; + n++; + } + } + } + } + std::cout << (ok ? "ALL OK" : "FAIL") << " (" << n << " shapes)\n"; + return ok ? 0 : 1; +}