diff --git a/tools/benchmodel.cpp b/tools/benchmodel.cpp index 5e8c45d..b2630cf 100644 --- a/tools/benchmodel.cpp +++ b/tools/benchmodel.cpp @@ -1,10 +1,14 @@ #include #include +#include #include #include +#include +#include #include "NAM/dsp.h" #include "NAM/get_dsp.h" +#include "NAM/slimmable.h" using std::chrono::duration; using std::chrono::duration_cast; @@ -18,22 +22,50 @@ double outputBuffer[AUDIO_BUFFER_SIZE]; int main(int argc, char* argv[]) { - if (argc < 2) + double slimValue = -1.0; + bool hasSlim = false; + bool useFastTanh = true; + std::vector positionalArgs; + positionalArgs.push_back(argv[0]); + + for (int i = 1; i < argc; i++) { - std::cerr << "Usage: benchmodel [--no-fast-tanh]\n"; - exit(1); + std::string arg(argv[i]); + if (arg == "--slim") + { + if (i + 1 >= argc) + { + std::cerr << "Error: --slim requires a value between 0.0 and 1.0\n"; + return 1; + } + char* end = nullptr; + slimValue = std::strtod(argv[i + 1], &end); + if (end == argv[i + 1] || *end != '\0' || slimValue < 0.0 || slimValue > 1.0) + { + std::cerr << "Error: --slim value must be a number between 0.0 and 1.0\n"; + return 1; + } + hasSlim = true; + i++; // skip the value + } + else if (arg == "--no-fast-tanh") + { + useFastTanh = false; + } + else + { + positionalArgs.push_back(argv[i]); + } } - const char* modelPath = argv[1]; - - // Check for --no-fast-tanh flag - bool useFastTanh = true; - for (int i = 2; i < argc; i++) + if (positionalArgs.size() < 2) { - if (std::strcmp(argv[i], "--no-fast-tanh") == 0) - useFastTanh = false; + std::cerr << "Usage: benchmodel [--slim <0.0-1.0>] [--no-fast-tanh] \n"; + return 1; } + const char* modelPath = positionalArgs[1]; + if (useFastTanh) { nam::activations::Activation::enable_fast_tanh(); @@ -53,7 +85,19 @@ int main(int argc, char* argv[]) if (model == nullptr) { std::cerr << "Failed to load model\n"; - exit(1); + return 1; + } + + if (hasSlim) + { + auto* slimmable = dynamic_cast(model.get()); + if (!slimmable) + { + std::cerr << "Error: --slim requires a model that implements the SlimmableModel interface\n"; + return 1; + } + std::cout << "Setting slimmable size to " << slimValue << "\n"; + slimmable->SetSlimmableSize(slimValue); } size_t bufferSize = AUDIO_BUFFER_SIZE;