From 1d24a1c217bc63fce9d5707d4ab0088dca502444 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Thu, 12 Mar 2026 23:01:15 -0700 Subject: [PATCH 1/3] Get rid of SlimmableWaveNet architecture, detect slimmable functionality within WaveNet architecture --- NAM/slimmable_wavenet.cpp | 22 ++++++++------- NAM/wavenet.cpp | 23 +++++++++++++++ example_models/slimmable_wavenet.nam | 10 +++---- tools/run_tests.cpp | 2 +- tools/test/test_slimmable_wavenet.cpp | 40 ++++++++++----------------- 5 files changed, 55 insertions(+), 42 deletions(-) diff --git a/NAM/slimmable_wavenet.cpp b/NAM/slimmable_wavenet.cpp index e31e306..64f5c9a 100644 --- a/NAM/slimmable_wavenet.cpp +++ b/NAM/slimmable_wavenet.cpp @@ -411,8 +411,9 @@ void SlimmableWavenet::SetSlimmableSize(const double val) std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, double sampleRate) { - // Parse the WaveNet model config into typed params - nlohmann::json model_json = raw_config["model"]; + // Parse the WaveNet model config — support both wrapped {"model": {...}} and flat config + nlohmann::json model_json = + raw_config.contains("model") ? raw_config["model"] : raw_config; auto wc = wavenet::parse_config_json(model_json, sampleRate); // Extract per-array allowed_channels from slimmable config fields @@ -433,12 +434,19 @@ std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, for (const auto& ch : slim_cfg["kwargs"]["allowed_channels"]) allowed.push_back(ch.get()); } + else + { + // Missing allowed_channels: assume [1, 2, ..., channels] for slice_channels_uniform + const int channels = lc["channels"].get(); + for (int c = 1; c <= channels; c++) + allowed.push_back(c); + } } per_array_allowed.push_back(std::move(allowed)); } - // Extract condition_dsp JSON for future rebuilds - nlohmann::json condition_dsp_json = nullptr; + // Extract condition_dsp JSON for future rebuilds (in model config) + nlohmann::json condition_dsp_json; if (model_json.find("condition_dsp") != model_json.end() && !model_json["condition_dsp"].is_null()) condition_dsp_json = model_json["condition_dsp"]; @@ -455,11 +463,5 @@ std::unique_ptr create_config(const nlohmann::json& config, double return sc; } -// Auto-register with the config parser registry -namespace -{ -static ConfigParserHelper _register_SlimmableWavenet("SlimmableWavenet", nam::slimmable_wavenet::create_config); -} - } // namespace slimmable_wavenet } // namespace nam diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 522e304..6a3cbf5 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -8,6 +8,7 @@ #include "get_dsp.h" #include "registry.h" +#include "slimmable_wavenet.h" #include "wavenet.h" // Layer ====================================================================== @@ -951,9 +952,31 @@ std::unique_ptr nam::wavenet::WaveNetConfig::create(std::vector std::move(weights), std::move(condition_dsp), sampleRate); } +namespace +{ +bool config_is_slimmable_wavenet(const nlohmann::json& config) +{ + if (config.find("layers") == config.end() || !config["layers"].is_array()) + return false; + const std::string recognized_method = "slice_channels_uniform"; + for (const auto& lc : config["layers"]) + { + if (lc.find("slimmable") == lc.end() || !lc["slimmable"].is_object()) + continue; + const std::string method = lc["slimmable"].value("method", ""); + if (method == recognized_method) + return true; + } + return false; +} +} // namespace + // Config parser for ConfigParserRegistry std::unique_ptr nam::wavenet::create_config(const nlohmann::json& config, double sampleRate) { + if (config_is_slimmable_wavenet(config)) + return nam::slimmable_wavenet::create_config(config, sampleRate); + auto wc = std::make_unique(); auto parsed = parse_config_json(config, sampleRate); *wc = std::move(parsed); diff --git a/example_models/slimmable_wavenet.nam b/example_models/slimmable_wavenet.nam index 19eafd1..4075544 100644 --- a/example_models/slimmable_wavenet.nam +++ b/example_models/slimmable_wavenet.nam @@ -1,10 +1,9 @@ { "version": "0.7.0", "metadata": {}, - "architecture": "SlimmableWavenet", + "architecture": "WaveNet", "config": { - "model": { - "layers": [ + "layers": [ { "input_size": 1, "condition_size": 1, @@ -34,9 +33,8 @@ } } ], - "head": null, - "head_scale": 0.02 - } + "head": null, + "head_scale": 0.02 }, "weights": [ 0.2788535969157675, diff --git a/tools/run_tests.cpp b/tools/run_tests.cpp index 8bfd954..38aa5b5 100644 --- a/tools/run_tests.cpp +++ b/tools/run_tests.cpp @@ -298,7 +298,7 @@ int main() test_slimmable_wavenet::test_default_is_max_size(); test_slimmable_wavenet::test_ratio_mapping(); test_slimmable_wavenet::test_from_json(); - test_slimmable_wavenet::test_no_slimmable_layers_throws(); + test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular(); test_slimmable_wavenet::test_unsupported_method_throws(); test_slimmable_wavenet::test_slimmed_matches_small_model(); diff --git a/tools/test/test_slimmable_wavenet.cpp b/tools/test/test_slimmable_wavenet.cpp index 44b9b5f..388ceea 100644 --- a/tools/test/test_slimmable_wavenet.cpp +++ b/tools/test/test_slimmable_wavenet.cpp @@ -254,43 +254,33 @@ void test_from_json() process_and_verify(dsp.get(), 3, 64); } -void test_no_slimmable_layers_throws() +void test_wavenet_without_slimmable_loads_as_regular() { - - + // WaveNet config without slimmable allowed_channels loads as regular WaveNet auto wavenet_json = load_nam_json("example_models/wavenet.nam"); nlohmann::json j; j["version"] = "0.7.0"; - j["architecture"] = "SlimmableWavenet"; - j["config"]["model"] = wavenet_json["config"]; - // No slimmable field on any layer -> all allowed_channels empty -> should throw + j["architecture"] = "WaveNet"; + j["config"] = wavenet_json["config"]; j["weights"] = wavenet_json["weights"]; j["sample_rate"] = wavenet_json["sample_rate"]; - bool threw = false; - try - { - auto dsp = nam::get_dsp(j); - } - catch (const std::runtime_error&) - { - threw = true; - } - assert(threw); + auto dsp = nam::get_dsp(j); + assert(dsp != nullptr); + assert(dynamic_cast(dsp.get()) == nullptr); } void test_unsupported_method_throws() { - - + // WaveNet with slimmable but unsupported method -> throws auto wavenet_json = load_nam_json("example_models/wavenet.nam"); nlohmann::json j; j["version"] = "0.7.0"; - j["architecture"] = "SlimmableWavenet"; - j["config"]["model"] = wavenet_json["config"]; - j["config"]["model"]["layers"][0]["slimmable"] = { + j["architecture"] = "WaveNet"; + j["config"] = wavenet_json["config"]; + j["config"]["layers"][0]["slimmable"] = { {"method", "some_future_method"}, {"kwargs", {{"allowed_channels", {2, 3}}}}}; j["weights"] = wavenet_json["weights"]; j["sample_rate"] = wavenet_json["sample_rate"]; @@ -442,15 +432,15 @@ void test_slimmed_matches_small_model() auto small_dsp = nam::get_dsp(small_json); assert(small_dsp != nullptr); - // --- Build the 4ch SlimmableWavenet --- + // --- Build the 4ch slimmable WaveNet (detected from config) --- nlohmann::json large_json; large_json["version"] = "0.7.0"; - large_json["architecture"] = "SlimmableWavenet"; + large_json["architecture"] = "WaveNet"; auto large_layer_config = make_layer_config(large_ch); large_layer_config["slimmable"] = { {"method", "slice_channels_uniform"}, {"kwargs", {{"allowed_channels", {small_ch, large_ch}}}}}; - large_json["config"]["model"]["layers"] = nlohmann::json::array({large_layer_config}); - large_json["config"]["model"]["head_scale"] = 1.0; + large_json["config"]["layers"] = nlohmann::json::array({large_layer_config}); + large_json["config"]["head_scale"] = 1.0; large_json["weights"] = large_weights; large_json["sample_rate"] = 48000; From c2555e2143f177fab89ce96666532e9b6c116b35 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Thu, 12 Mar 2026 23:19:07 -0700 Subject: [PATCH 2/3] Throw if unsupported slimmable method given --- NAM/slimmable_wavenet.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/NAM/slimmable_wavenet.cpp b/NAM/slimmable_wavenet.cpp index 64f5c9a..99af681 100644 --- a/NAM/slimmable_wavenet.cpp +++ b/NAM/slimmable_wavenet.cpp @@ -159,8 +159,8 @@ std::vector extract_slimmed_weights(const std::vector head1x1_out, bias=true) if (p.head1x1_params.active) - extract_conv1x1(src, full_bn, p.head1x1_params.out_channels, slim_bn, p.head1x1_params.out_channels, true, - slim); + extract_conv1x1( + src, full_bn, p.head1x1_params.out_channels, slim_bn, p.head1x1_params.out_channels, true, slim); // ---- FiLM objects (8, in set_weights_ order) ---- @@ -260,9 +260,8 @@ std::vector modify_params_for_channels( new_input_size, p.condition_size, new_head_size, new_ch, new_bottleneck, p.kernel_size, std::vector(p.dilations), std::vector(p.activation_configs), std::vector(p.gating_modes), p.head_bias, p.groups_input, p.groups_input_mixin, - p.layer1x1_params, p.head1x1_params, - std::vector(p.secondary_activation_configs), p.conv_pre_film_params, - p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params, + p.layer1x1_params, p.head1x1_params, std::vector(p.secondary_activation_configs), + p.conv_pre_film_params, p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params, p.activation_pre_film_params, p.activation_post_film_params, p._layer1x1_post_film_params, p.head1x1_post_film_params)); } @@ -360,8 +359,8 @@ void SlimmableWavenet::_rebuild_model(const std::vector& target_channels) condition_dsp = get_dsp(_condition_dsp_json); double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate(); - _active_model = std::make_unique(_in_channels, *params_ptr, _head_scale, _with_head, - std::move(weights), std::move(condition_dsp), sampleRate); + _active_model = std::make_unique( + _in_channels, *params_ptr, _head_scale, _with_head, std::move(weights), std::move(condition_dsp), sampleRate); _current_channels = target_channels; if (_current_buffer_size > 0) @@ -412,8 +411,7 @@ void SlimmableWavenet::SetSlimmableSize(const double val) std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, double sampleRate) { // Parse the WaveNet model config — support both wrapped {"model": {...}} and flat config - nlohmann::json model_json = - raw_config.contains("model") ? raw_config["model"] : raw_config; + nlohmann::json model_json = raw_config.contains("model") ? raw_config["model"] : raw_config; auto wc = wavenet::parse_config_json(model_json, sampleRate); // Extract per-array allowed_channels from slimmable config fields @@ -429,7 +427,8 @@ std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, const std::string method = slim_cfg.value("method", ""); if (method != "slice_channels_uniform") throw std::runtime_error("SlimmableWavenet: unsupported slimmable method '" + method + "'"); - if (slim_cfg.find("kwargs") != slim_cfg.end() && slim_cfg["kwargs"].find("allowed_channels") != slim_cfg["kwargs"].end()) + if (slim_cfg.find("kwargs") != slim_cfg.end() + && slim_cfg["kwargs"].find("allowed_channels") != slim_cfg["kwargs"].end()) { for (const auto& ch : slim_cfg["kwargs"]["allowed_channels"]) allowed.push_back(ch.get()); @@ -451,8 +450,8 @@ std::unique_ptr SlimmableWavenetConfig::create(std::vector weights, condition_dsp_json = model_json["condition_dsp"]; return std::make_unique(std::move(wc.layer_array_params), std::move(per_array_allowed), - wc.in_channels, wc.head_scale, wc.with_head, - std::move(condition_dsp_json), std::move(weights), sampleRate); + wc.in_channels, wc.head_scale, wc.with_head, std::move(condition_dsp_json), + std::move(weights), sampleRate); } std::unique_ptr create_config(const nlohmann::json& config, double sampleRate) From 940758da01bfe48f2753629a1b260ad022292bce Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Thu, 12 Mar 2026 23:27:12 -0700 Subject: [PATCH 3/3] Assert recognized slimmable method if given --- NAM/wavenet.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/NAM/wavenet.cpp b/NAM/wavenet.cpp index 6a3cbf5..04faa45 100644 --- a/NAM/wavenet.cpp +++ b/NAM/wavenet.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -421,8 +422,8 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs #ifdef NAM_USE_INLINE_GEMM { const int total = (int)this->_get_channels() * num_frames; - std::memcpy(this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(), - total * sizeof(float)); + std::memcpy( + this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(), total * sizeof(float)); } #else this->_layer_outputs.leftCols(num_frames).noalias() = @@ -948,24 +949,30 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json // WaveNetConfig::create() std::unique_ptr nam::wavenet::WaveNetConfig::create(std::vector weights, double sampleRate) { - return std::make_unique(in_channels, layer_array_params, head_scale, with_head, - std::move(weights), std::move(condition_dsp), sampleRate); + return std::make_unique( + in_channels, layer_array_params, head_scale, with_head, std::move(weights), std::move(condition_dsp), sampleRate); } namespace { +const std::string SLIMMABLE_METHOD = "slice_channels_uniform"; + bool config_is_slimmable_wavenet(const nlohmann::json& config) { if (config.find("layers") == config.end() || !config["layers"].is_array()) return false; - const std::string recognized_method = "slice_channels_uniform"; for (const auto& lc : config["layers"]) { if (lc.find("slimmable") == lc.end() || !lc["slimmable"].is_object()) continue; const std::string method = lc["slimmable"].value("method", ""); - if (method == recognized_method) - return true; + if (method != SLIMMABLE_METHOD) + { + if (!method.empty()) + throw std::runtime_error("SlimmableWavenet: unsupported slimmable method '" + method + "'"); + continue; + } + return true; } return false; }