Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions NAM/slimmable_wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ std::vector<float> extract_slimmed_weights(const std::vector<wavenet::LayerArray

// head1x1 (optional): Conv1x1(B -> head1x1_out, bias=true)
if (p.head1x1_params.active)
extract_conv1x1(src, full_bn, p.head1x1_params.out_channels, slim_bn, p.head1x1_params.out_channels, true,
slim);
extract_conv1x1(
src, full_bn, p.head1x1_params.out_channels, slim_bn, p.head1x1_params.out_channels, true, slim);

// ---- FiLM objects (8, in set_weights_ order) ----

Expand Down Expand Up @@ -260,9 +260,8 @@ std::vector<wavenet::LayerArrayParams> modify_params_for_channels(
new_input_size, p.condition_size, new_head_size, new_ch, new_bottleneck, p.kernel_size,
std::vector<int>(p.dilations), std::vector<activations::ActivationConfig>(p.activation_configs),
std::vector<wavenet::GatingMode>(p.gating_modes), p.head_bias, p.groups_input, p.groups_input_mixin,
p.layer1x1_params, p.head1x1_params,
std::vector<activations::ActivationConfig>(p.secondary_activation_configs), p.conv_pre_film_params,
p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params,
p.layer1x1_params, p.head1x1_params, std::vector<activations::ActivationConfig>(p.secondary_activation_configs),
p.conv_pre_film_params, p.conv_post_film_params, p.input_mixin_pre_film_params, p.input_mixin_post_film_params,
p.activation_pre_film_params, p.activation_post_film_params, p._layer1x1_post_film_params,
p.head1x1_post_film_params));
}
Expand Down Expand Up @@ -360,8 +359,8 @@ void SlimmableWavenet::_rebuild_model(const std::vector<int>& target_channels)
condition_dsp = get_dsp(_condition_dsp_json);

double sampleRate = _current_sample_rate > 0 ? _current_sample_rate : GetExpectedSampleRate();
_active_model = std::make_unique<wavenet::WaveNet>(_in_channels, *params_ptr, _head_scale, _with_head,
std::move(weights), std::move(condition_dsp), sampleRate);
_active_model = std::make_unique<wavenet::WaveNet>(
_in_channels, *params_ptr, _head_scale, _with_head, std::move(weights), std::move(condition_dsp), sampleRate);
_current_channels = target_channels;

if (_current_buffer_size > 0)
Expand Down Expand Up @@ -411,8 +410,8 @@ void SlimmableWavenet::SetSlimmableSize(const double val)

std::unique_ptr<DSP> SlimmableWavenetConfig::create(std::vector<float> weights, double sampleRate)
{
// Parse the WaveNet model config into typed params
nlohmann::json model_json = raw_config["model"];
// Parse the WaveNet model config — support both wrapped {"model": {...}} and flat config
nlohmann::json model_json = raw_config.contains("model") ? raw_config["model"] : raw_config;
auto wc = wavenet::parse_config_json(model_json, sampleRate);

// Extract per-array allowed_channels from slimmable config fields
Expand All @@ -428,23 +427,31 @@ std::unique_ptr<DSP> SlimmableWavenetConfig::create(std::vector<float> weights,
const std::string method = slim_cfg.value("method", "");
if (method != "slice_channels_uniform")
throw std::runtime_error("SlimmableWavenet: unsupported slimmable method '" + method + "'");
if (slim_cfg.find("kwargs") != slim_cfg.end() && slim_cfg["kwargs"].find("allowed_channels") != slim_cfg["kwargs"].end())
if (slim_cfg.find("kwargs") != slim_cfg.end()
&& slim_cfg["kwargs"].find("allowed_channels") != slim_cfg["kwargs"].end())
{
for (const auto& ch : slim_cfg["kwargs"]["allowed_channels"])
allowed.push_back(ch.get<int>());
}
else
{
// Missing allowed_channels: assume [1, 2, ..., channels] for slice_channels_uniform
const int channels = lc["channels"].get<int>();
for (int c = 1; c <= channels; c++)
allowed.push_back(c);
}
}
per_array_allowed.push_back(std::move(allowed));
}

// Extract condition_dsp JSON for future rebuilds
nlohmann::json condition_dsp_json = nullptr;
// Extract condition_dsp JSON for future rebuilds (in model config)
nlohmann::json condition_dsp_json;
if (model_json.find("condition_dsp") != model_json.end() && !model_json["condition_dsp"].is_null())
condition_dsp_json = model_json["condition_dsp"];

return std::make_unique<SlimmableWavenet>(std::move(wc.layer_array_params), std::move(per_array_allowed),
wc.in_channels, wc.head_scale, wc.with_head,
std::move(condition_dsp_json), std::move(weights), sampleRate);
wc.in_channels, wc.head_scale, wc.with_head, std::move(condition_dsp_json),
std::move(weights), sampleRate);
}

std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double sampleRate)
Expand All @@ -455,11 +462,5 @@ std::unique_ptr<ModelConfig> create_config(const nlohmann::json& config, double
return sc;
}

// Auto-register with the config parser registry
//
// Defines a file-local object that is constructed with the architecture name
// "SlimmableWavenet" and the nam::slimmable_wavenet::create_config factory.
// NOTE(review): presumably ConfigParserHelper's constructor performs the actual
// registration during static initialization — confirm against registry.h.
// NOTE(review): nam::wavenet::create_config also forwards slimmable configs to
// nam::slimmable_wavenet::create_config; verify whether this separate
// "SlimmableWavenet" registration is still intended.
namespace
{
static ConfigParserHelper _register_SlimmableWavenet("SlimmableWavenet", nam::slimmable_wavenet::create_config);
}

} // namespace slimmable_wavenet
} // namespace nam
38 changes: 34 additions & 4 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
#include <iostream>
#include <math.h>
#include <sstream>
#include <stdexcept>

#include <Eigen/Dense>

#include "get_dsp.h"
#include "registry.h"
#include "slimmable_wavenet.h"
#include "wavenet.h"

// Layer ======================================================================
Expand Down Expand Up @@ -420,8 +422,8 @@ void nam::wavenet::_LayerArray::ProcessInner(const Eigen::MatrixXf& layer_inputs
#ifdef NAM_USE_INLINE_GEMM
{
const int total = (int)this->_get_channels() * num_frames;
std::memcpy(this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(),
total * sizeof(float));
std::memcpy(
this->_layer_outputs.data(), this->_layers[last_layer].GetOutputNextLayer().data(), total * sizeof(float));
}
#else
this->_layer_outputs.leftCols(num_frames).noalias() =
Expand Down Expand Up @@ -947,13 +949,41 @@ nam::wavenet::WaveNetConfig nam::wavenet::parse_config_json(const nlohmann::json
// WaveNetConfig::create()
//
// Builds a nam::wavenet::WaveNet and returns it as a std::unique_ptr<nam::DSP>.
// The constructor receives in_channels, layer_array_params, head_scale,
// with_head and condition_dsp (presumably members of WaveNetConfig — confirm in
// wavenet.h) together with the caller-supplied weights and sampleRate.
// `weights` and `condition_dsp` are passed through std::move, so they should be
// treated as consumed after this call.
//
// NOTE(review): two return statements follow that differ only in line wrapping;
// the second is unreachable. This looks like the before/after formatting of the
// same statement — confirm that only one of them exists in the real file.
std::unique_ptr<nam::DSP> nam::wavenet::WaveNetConfig::create(std::vector<float> weights, double sampleRate)
{
return std::make_unique<nam::wavenet::WaveNet>(in_channels, layer_array_params, head_scale, with_head,
std::move(weights), std::move(condition_dsp), sampleRate);
return std::make_unique<nam::wavenet::WaveNet>(
in_channels, layer_array_params, head_scale, with_head, std::move(weights), std::move(condition_dsp), sampleRate);
}

namespace
{
// The only value of a layer's slimmable "method" that is accepted below.
const std::string SLIMMABLE_METHOD = "slice_channels_uniform";

// Decides whether a WaveNet config describes a slimmable model.
//
// Walks config["layers"] in order and returns true at the first layer whose
// "slimmable" entry is an object naming SLIMMABLE_METHOD as its "method".
// Layers with no "slimmable" object, or with a missing/empty method, are ignored.
// Returns false when "layers" is absent, is not an array, or no layer qualifies.
//
// Throws std::runtime_error when a layer names a non-empty method other than
// SLIMMABLE_METHOD and is reached before any supported layer.
bool config_is_slimmable_wavenet(const nlohmann::json& config)
{
  const auto layers_it = config.find("layers");
  if (layers_it == config.end())
    return false;
  if (!layers_it->is_array())
    return false;

  for (const auto& layer : *layers_it)
  {
    const auto slim_it = layer.find("slimmable");
    const bool has_slimmable = slim_it != layer.end() && slim_it->is_object();
    if (!has_slimmable)
      continue;

    const std::string requested = slim_it->value("method", "");
    if (requested == SLIMMABLE_METHOD)
      return true;
    if (requested.empty())
      continue; // slimmable object without a method: keep scanning later layers
    throw std::runtime_error("SlimmableWavenet: unsupported slimmable method '" + requested + "'");
  }
  return false;
}
} // namespace

// Config parser for ConfigParserRegistry
std::unique_ptr<nam::ModelConfig> nam::wavenet::create_config(const nlohmann::json& config, double sampleRate)
{
if (config_is_slimmable_wavenet(config))
return nam::slimmable_wavenet::create_config(config, sampleRate);

auto wc = std::make_unique<WaveNetConfig>();
auto parsed = parse_config_json(config, sampleRate);
*wc = std::move(parsed);
Expand Down
10 changes: 4 additions & 6 deletions example_models/slimmable_wavenet.nam
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
{
"version": "0.7.0",
"metadata": {},
"architecture": "SlimmableWavenet",
"architecture": "WaveNet",
"config": {
"model": {
"layers": [
"layers": [
{
"input_size": 1,
"condition_size": 1,
Expand Down Expand Up @@ -34,9 +33,8 @@
}
}
],
"head": null,
"head_scale": 0.02
}
"head": null,
"head_scale": 0.02
},
"weights": [
0.2788535969157675,
Expand Down
2 changes: 1 addition & 1 deletion tools/run_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ int main()
test_slimmable_wavenet::test_default_is_max_size();
test_slimmable_wavenet::test_ratio_mapping();
test_slimmable_wavenet::test_from_json();
test_slimmable_wavenet::test_no_slimmable_layers_throws();
test_slimmable_wavenet::test_wavenet_without_slimmable_loads_as_regular();
test_slimmable_wavenet::test_unsupported_method_throws();
test_slimmable_wavenet::test_slimmed_matches_small_model();

Expand Down
40 changes: 15 additions & 25 deletions tools/test/test_slimmable_wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,43 +254,33 @@ void test_from_json()
process_and_verify(dsp.get(), 3, 64);
}

void test_no_slimmable_layers_throws()
void test_wavenet_without_slimmable_loads_as_regular()
{


// WaveNet config without slimmable allowed_channels loads as regular WaveNet
auto wavenet_json = load_nam_json("example_models/wavenet.nam");

nlohmann::json j;
j["version"] = "0.7.0";
j["architecture"] = "SlimmableWavenet";
j["config"]["model"] = wavenet_json["config"];
// No slimmable field on any layer -> all allowed_channels empty -> should throw
j["architecture"] = "WaveNet";
j["config"] = wavenet_json["config"];
j["weights"] = wavenet_json["weights"];
j["sample_rate"] = wavenet_json["sample_rate"];

bool threw = false;
try
{
auto dsp = nam::get_dsp(j);
}
catch (const std::runtime_error&)
{
threw = true;
}
assert(threw);
auto dsp = nam::get_dsp(j);
assert(dsp != nullptr);
assert(dynamic_cast<nam::SlimmableModel*>(dsp.get()) == nullptr);
}

void test_unsupported_method_throws()
{


// WaveNet with slimmable but unsupported method -> throws
auto wavenet_json = load_nam_json("example_models/wavenet.nam");

nlohmann::json j;
j["version"] = "0.7.0";
j["architecture"] = "SlimmableWavenet";
j["config"]["model"] = wavenet_json["config"];
j["config"]["model"]["layers"][0]["slimmable"] = {
j["architecture"] = "WaveNet";
j["config"] = wavenet_json["config"];
j["config"]["layers"][0]["slimmable"] = {
{"method", "some_future_method"}, {"kwargs", {{"allowed_channels", {2, 3}}}}};
j["weights"] = wavenet_json["weights"];
j["sample_rate"] = wavenet_json["sample_rate"];
Expand Down Expand Up @@ -442,15 +432,15 @@ void test_slimmed_matches_small_model()
auto small_dsp = nam::get_dsp(small_json);
assert(small_dsp != nullptr);

// --- Build the 4ch SlimmableWavenet ---
// --- Build the 4ch slimmable WaveNet (detected from config) ---
nlohmann::json large_json;
large_json["version"] = "0.7.0";
large_json["architecture"] = "SlimmableWavenet";
large_json["architecture"] = "WaveNet";
auto large_layer_config = make_layer_config(large_ch);
large_layer_config["slimmable"] = {
{"method", "slice_channels_uniform"}, {"kwargs", {{"allowed_channels", {small_ch, large_ch}}}}};
large_json["config"]["model"]["layers"] = nlohmann::json::array({large_layer_config});
large_json["config"]["model"]["head_scale"] = 1.0;
large_json["config"]["layers"] = nlohmann::json::array({large_layer_config});
large_json["config"]["head_scale"] = 1.0;
large_json["weights"] = large_weights;
large_json["sample_rate"] = 48000;

Expand Down
Loading