22#include < cmath>
33#include < filesystem>
44#include < fstream>
5- #include < iostream>
65#include < memory>
76#include < vector>
87
@@ -55,7 +54,7 @@ void process_and_verify(nam::DSP* dsp, int num_buffers, int buffer_size)
5554
5655void test_loads_from_file ()
5756{
58- std::cout << " test_slimmable_wavenet_loads_from_file " << std::endl;
57+
5958
6059 std::filesystem::path path (" example_models/slimmable_wavenet.nam" );
6160 auto dsp = nam::get_dsp (path);
@@ -64,7 +63,7 @@ void test_loads_from_file()
6463
6564void test_implements_slimmable ()
6665{
67- std::cout << " test_slimmable_wavenet_implements_slimmable " << std::endl;
66+
6867
6968 auto dsp = nam::get_dsp (std::filesystem::path (" example_models/slimmable_wavenet.nam" ));
7069 assert (dsp != nullptr );
@@ -75,7 +74,7 @@ void test_implements_slimmable()
7574
7675void test_processes_audio ()
7776{
78- std::cout << " test_slimmable_wavenet_processes_audio " << std::endl;
77+
7978
8079 auto dsp = nam::get_dsp (std::filesystem::path (" example_models/slimmable_wavenet.nam" ));
8180 assert (dsp != nullptr );
@@ -84,7 +83,7 @@ void test_processes_audio()
8483
8584void test_slimming_changes_output ()
8685{
87- std::cout << " test_slimmable_wavenet_slimming_changes_output " << std::endl;
86+
8887
8988 auto dsp = nam::get_dsp (std::filesystem::path (" example_models/slimmable_wavenet.nam" ));
9089 assert (dsp != nullptr );
@@ -129,7 +128,7 @@ void test_slimming_changes_output()
129128
130129void test_boundary_values ()
131130{
132- std::cout << " test_slimmable_wavenet_boundary_values " << std::endl;
131+
133132
134133 auto dsp = nam::get_dsp (std::filesystem::path (" example_models/slimmable_wavenet.nam" ));
135134 assert (dsp != nullptr );
@@ -163,7 +162,7 @@ void test_boundary_values()
163162
164163void test_default_is_max_size ()
165164{
166- std::cout << " test_slimmable_wavenet_default_is_max_size " << std::endl;
165+
167166
168167 auto dsp = nam::get_dsp (std::filesystem::path (" example_models/slimmable_wavenet.nam" ));
169168 assert (dsp != nullptr );
@@ -197,7 +196,7 @@ void test_default_is_max_size()
197196
198197void test_ratio_mapping ()
199198{
200- std::cout << " test_slimmable_wavenet_ratio_mapping " << std::endl;
199+
201200
202201 // With allowed_channels [1, 2, 3] (len=3):
203202 // idx = min(floor(ratio * 3), 2)
@@ -247,34 +246,19 @@ void test_ratio_mapping()
247246
248247void test_from_json ()
249248{
250- std::cout << " test_slimmable_wavenet_from_json" << std::endl;
251-
252- // Build a SlimmableWavenet JSON from an existing WaveNet
253- auto wavenet_json = load_nam_json (" example_models/wavenet_3ch.nam" );
254249
255- nlohmann::json j;
256- j[" version" ] = " 0.7.0" ;
257- j[" architecture" ] = " SlimmableWavenet" ;
258-
259- // Copy the WaveNet config and add slimmable field to the first layer
260- j[" config" ][" model" ] = wavenet_json[" config" ];
261- j[" config" ][" model" ][" layers" ][0 ][" slimmable" ] = {
262- {" method" , " slice_channels_uniform" },
263- {" kwargs" , {{" allowed_channels" , {2 , 3 }}}}
264- };
265- j[" weights" ] = wavenet_json[" weights" ];
266- j[" sample_rate" ] = wavenet_json[" sample_rate" ];
267250
251+ auto j = load_nam_json (" example_models/slimmable_wavenet.nam" );
268252 auto dsp = nam::get_dsp (j);
269253 assert (dsp != nullptr );
270254 process_and_verify (dsp.get (), 3 , 64 );
271255}
272256
273257void test_no_slimmable_layers_throws ()
274258{
275- std::cout << " test_slimmable_wavenet_no_slimmable_layers_throws" << std::endl;
276259
277- auto wavenet_json = load_nam_json (" example_models/wavenet_3ch.nam" );
260+
261+ auto wavenet_json = load_nam_json (" example_models/wavenet.nam" );
278262
279263 nlohmann::json j;
280264 j[" version" ] = " 0.7.0" ;
@@ -298,18 +282,16 @@ void test_no_slimmable_layers_throws()
298282
299283void test_unsupported_method_throws ()
300284{
301- std::cout << " test_slimmable_wavenet_unsupported_method_throws" << std::endl;
302285
303- auto wavenet_json = load_nam_json (" example_models/wavenet_3ch.nam" );
286+
287+ auto wavenet_json = load_nam_json (" example_models/wavenet.nam" );
304288
305289 nlohmann::json j;
306290 j[" version" ] = " 0.7.0" ;
307291 j[" architecture" ] = " SlimmableWavenet" ;
308292 j[" config" ][" model" ] = wavenet_json[" config" ];
309293 j[" config" ][" model" ][" layers" ][0 ][" slimmable" ] = {
310- {" method" , " some_future_method" },
311- {" kwargs" , {{" allowed_channels" , {2 , 3 }}}}
312- };
294+ {" method" , " some_future_method" }, {" kwargs" , {{" allowed_channels" , {2 , 3 }}}}};
313295 j[" weights" ] = wavenet_json[" weights" ];
314296 j[" sample_rate" ] = wavenet_json[" sample_rate" ];
315297
@@ -327,7 +309,7 @@ void test_unsupported_method_throws()
327309
328310void test_slimmed_matches_small_model ()
329311{
330- std::cout << " test_slimmable_wavenet_slimmed_matches_small_model " << std::endl;
312+
331313
332314 // Build a minimal WaveNet config: 1 layer array, 2 layers (dilations [1,2]),
333315 // kernel_size=3, no gating, no layer1x1, no head1x1, no FiLM, Tanh activation.
@@ -368,10 +350,10 @@ void test_slimmed_matches_small_model()
368350 for (int l = 0 ; l < num_layers; l++)
369351 {
370352 n += ch * ch * kernel_size + ch; // conv
371- n += condition_size * ch; // input_mixin
353+ n += condition_size * ch; // input_mixin
372354 }
373355 n += ch * head_size; // head_rechannel
374- n += 1 ; // head_scale
356+ n += 1 ; // head_scale
375357 return n;
376358 };
377359
@@ -389,7 +371,7 @@ void test_slimmed_matches_small_model()
389371
390372 // Helper: embed Conv1x1(small_in, small_out) into Conv1x1(full_in, full_out)
391373 auto embed_conv1x1 = [](std::vector<float >::const_iterator& src, int small_in, int small_out, int full_in,
392- int full_out, bool bias, std::vector<float >& dst) {
374+ int full_out, bool bias, std::vector<float >& dst) {
393375 for (int i = 0 ; i < full_out; i++)
394376 for (int j = 0 ; j < full_in; j++)
395377 {
@@ -410,7 +392,7 @@ void test_slimmed_matches_small_model()
410392
411393 // Helper: embed Conv1D(small_in, small_out) into Conv1D(full_in, full_out)
412394 auto embed_conv1d = [](std::vector<float >::const_iterator& src, int small_in, int small_out, int full_in,
413- int full_out, int ks, std::vector<float >& dst) {
395+ int full_out, int ks, std::vector<float >& dst) {
414396 for (int i = 0 ; i < full_out; i++)
415397 for (int j = 0 ; j < full_in; j++)
416398 for (int k = 0 ; k < ks; k++)
@@ -465,8 +447,8 @@ void test_slimmed_matches_small_model()
465447 large_json[" version" ] = " 0.7.0" ;
466448 large_json[" architecture" ] = " SlimmableWavenet" ;
467449 auto large_layer_config = make_layer_config (large_ch);
468- large_layer_config[" slimmable" ] = {{ " method " , " slice_channels_uniform " },
469- {" kwargs" , {{" allowed_channels" , {small_ch, large_ch}}}}};
450+ large_layer_config[" slimmable" ] = {
451+ { " method " , " slice_channels_uniform " }, {" kwargs" , {{" allowed_channels" , {small_ch, large_ch}}}}};
470452 large_json[" config" ][" model" ][" layers" ] = nlohmann::json::array ({large_layer_config});
471453 large_json[" config" ][" model" ][" head_scale" ] = 1.0 ;
472454 large_json[" weights" ] = large_weights;
@@ -510,12 +492,7 @@ void test_slimmed_matches_small_model()
510492 {
511493 assert (std::isfinite (out_small[i]));
512494 assert (std::isfinite (out_large[i]));
513- if (std::abs (out_small[i] - out_large[i]) > 1e-6 )
514- {
515- std::cerr << " MISMATCH at buffer " << buf << " sample " << i << " : small=" << out_small[i]
516- << " slimmed=" << out_large[i] << " diff=" << std::abs (out_small[i] - out_large[i]) << std::endl;
517- assert (false );
518- }
495+ assert (std::abs (out_small[i] - out_large[i]) <= 1e-6 );
519496 }
520497 }
521498}
0 commit comments