diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index 143c4ebe77d..f17be1c8fa5 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -108,32 +108,49 @@ if(EXECUTORCH_BUILD_VULKAN) executorch_target_link_options_shared_lib(vulkan_backend) endif() -add_executable(parakeet_runner main.cpp timestamp_utils.cpp tokenizer_utils.cpp) -if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") - target_link_options_gc_sections(parakeet_runner) - if(NOT APPLE AND NOT MSVC) - target_link_options(parakeet_runner PRIVATE "LINKER:-s") - endif() -endif() +set(parakeet_shared_sources parakeet_transcriber.cpp timestamp_utils.cpp + tokenizer_utils.cpp +) -# Copy MLX metallib for runtime if MLX delegate is enabled -if(TARGET mlxdelegate) - executorch_target_copy_mlx_metallib(parakeet_runner) -endif() +set(parakeet_common_include_directories + ${_common_include_directories} ${EXECUTORCH_ROOT}/third-party/json/include +) -target_include_directories( - parakeet_runner PUBLIC ${_common_include_directories} +add_executable(parakeet_runner main.cpp ${parakeet_shared_sources}) +add_executable( + parakeet_helper parakeet_helper.cpp parakeet_helper_protocol.cpp + ${parakeet_shared_sources} ) -target_link_libraries(parakeet_runner PUBLIC ${link_libraries}) -target_compile_options(parakeet_runner PUBLIC ${_common_compile_options}) + +foreach(parakeet_target parakeet_runner parakeet_helper) + if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(${parakeet_target}) + if(NOT APPLE AND NOT MSVC) + target_link_options(${parakeet_target} PRIVATE "LINKER:-s") + endif() + endif() + + if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(${parakeet_target}) + endif() + + target_include_directories( + ${parakeet_target} PUBLIC ${parakeet_common_include_directories} + ) + target_link_libraries(${parakeet_target} PUBLIC ${link_libraries}) + target_compile_options(${parakeet_target} PUBLIC ${_common_compile_options}) +endforeach() # On Windows, copy required DLLs to the executable directory if(MSVC AND EXECUTORCH_BUILD_CUDA) - add_custom_command( - TARGET parakeet_runner - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ - $ - COMMENT "Copying aoti_cuda_shims.dll to parakeet_runner directory" - ) + foreach(parakeet_target parakeet_runner parakeet_helper) + add_custom_command( + TARGET ${parakeet_target} + POST_BUILD + COMMAND + ${CMAKE_COMMAND} -E copy_if_different $ + $ + COMMENT "Copying aoti_cuda_shims.dll to ${parakeet_target} directory" + ) + endforeach() endif() diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json index 87ace61e315..90a90fbbdf5 100644 --- a/examples/models/parakeet/CMakePresets.json +++ b/examples/models/parakeet/CMakePresets.json @@ -89,42 +89,42 @@ "displayName": "Build Parakeet runner (CPU)", "configurePreset": "parakeet-cpu", "configuration": "Release", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] }, { "name": "parakeet-cuda", "displayName": "Build Parakeet runner (CUDA)", "configurePreset": "parakeet-cuda", "configuration": "Release", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] }, { "name": "parakeet-cuda-debug", "displayName": "Build Parakeet runner (CUDA, Debug)", "configurePreset": "parakeet-cuda-debug", "configuration": "Debug", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] }, { "name": "parakeet-metal", "displayName": "Build Parakeet runner (Metal)", "configurePreset": "parakeet-metal", "configuration": "Release", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] }, { "name": "parakeet-mlx", "displayName": "Build Parakeet runner (MLX)", "configurePreset": "parakeet-mlx", "configuration": "Release", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] }, { "name": "parakeet-vulkan", "displayName": "Build Parakeet runner (Vulkan)", "configurePreset": "parakeet-vulkan", "configuration": "Release", - "targets": ["parakeet_runner"] + "targets": ["parakeet_runner", "parakeet_helper"] } ], "workflowPresets": [ diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 512e2796e63..288034a0e04 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -224,6 +224,11 @@ make parakeet-cuda make parakeet-mlx ``` +Each Parakeet build now produces both: + +- `parakeet_runner` for one-shot CLI transcription from an audio file +- `parakeet_helper` for long-lived host integrations that keep the model warm and stream PCM requests over stdin/stdout + On Windows (PowerShell), use CMake workflow presets directly: ```powershell @@ -286,6 +291,26 @@ If your generator is single-config, the runner may be at `.\cmake-out\examples\m | `--data_path` | Path to data file (.ptd) for delegate data (required for CUDA/CUDA-Windows) | | `--timestamps` | Timestamp output mode: `none\|token\|word\|segment\|all` (default: `segment`) | +### Persistent Helper + +The helper binary uses the same Parakeet transcription stack as `parakeet_runner`, +but keeps the model loaded across multiple requests so host apps can avoid repeated +startup and model load overhead. + +Example: + +```bash +# Metal +DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_helper \ + --model_path examples/models/parakeet/parakeet_metal/model.pte \ + --tokenizer_path examples/models/parakeet/parakeet_metal/tokenizer.model +``` + +The helper accepts framed requests over stdin, validates 16 kHz mono float32 PCM +payloads, and returns status/result messages over stdout. It is intended for app +integrations such as the macOS `ExecuWhisper` frontend in the separate +`executorch-examples` repository. + ### Mobile App Check out a [demo Android app](https://github.com/meta-pytorch/executorch-examples/tree/main/parakeet/android/ParakeetApp) for Parakeet in the separate `executorch-examples` repository. diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 87768cec38b..410ba6cea62 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -6,34 +6,14 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include -#include -#include +#include + #include -#include #include -#include -#include #include -#include -#include - -#include -#include "timestamp_utils.h" -#include "tokenizer_utils.h" -#include "types.h" +#include "parakeet_transcriber.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #ifdef ET_BUILD_METAL #include @@ -53,299 +33,17 @@ DEFINE_string( timestamps, "segment", "Timestamp output mode: none|token|word|segment|all"); - -using ::executorch::extension::from_blob; -using ::executorch::extension::Module; -using ::executorch::runtime::Error; -using ::executorch::runtime::EValue; - -using ::parakeet::TextWithOffsets; -using ::parakeet::Token; -using ::parakeet::TokenId; -using ::parakeet::TokenWithTextInfo; - -namespace { -// TDT duration values -const std::vector DURATIONS = {0, 1, 2, 3, 4}; - -struct TimestampOutputMode { - bool token = false; - bool word = false; - bool segment = false; - - bool enabled() const { - return token || word || segment; - } -}; - -std::string to_lower_ascii(std::string s) { - for (char& ch : s) { - ch = static_cast(std::tolower(static_cast(ch))); - } - return s; -} - -TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) { - if (raw_arg.empty()) { - throw std::invalid_argument( - "Invalid --timestamps value (empty). Expected: token, word, segment, all."); - } - const std::string mode = to_lower_ascii(raw_arg); - if (mode == "none") { - return {false, false, false}; - } - if (mode == "token") { - return {true, false, false}; - } - if (mode == "word") { - return {false, true, false}; - } - if (mode == "segment") { - return {false, false, true}; - } - if (mode == "all") { - return {true, true, true}; - } - throw std::invalid_argument( - "Invalid --timestamps value '" + raw_arg + - "'. Expected: token, word, segment, all."); -} - -// Helper to get expected scalar type for a method input -::executorch::runtime::Result<::executorch::aten::ScalarType> -get_input_scalar_type( - Module& model, - const char* method_name, - size_t input_index) { - auto method_meta_result = model.method_meta(method_name); - if (!method_meta_result.ok()) { - ET_LOG(Error, "Failed to get method metadata for %s", method_name); - return method_meta_result.error(); - } - auto method_meta = method_meta_result.get(); - if (method_meta.num_inputs() <= input_index) { - ET_LOG( - Error, - "Method %s has %zu inputs, but requested index %zu", - method_name, - method_meta.num_inputs(), - input_index); - return ::executorch::runtime::Error::InvalidArgument; - } - auto input_meta_result = method_meta.input_tensor_meta(input_index); - if (input_meta_result.error() != ::executorch::runtime::Error::Ok) { - ET_LOG( - Error, - "Failed to get input tensor metadata for %s[%zu]", - method_name, - input_index); - return input_meta_result.error(); - } - return input_meta_result.get().scalar_type(); -} - -std::vector greedy_decode_executorch( - Module& model, - const ::executorch::aten::Tensor& f_proj, - int64_t encoder_len, - int64_t blank_id, - int64_t num_rnn_layers = 2, - int64_t pred_hidden = 640, - int64_t max_symbols_per_step = 10, - ::executorch::extension::llm::Stats* stats = nullptr) { - std::vector hypothesis; - - // Shape: [1, T, joint_hidden] - size_t proj_dim = static_cast(f_proj.sizes()[2]); - - // Get expected dtype for decoder_step h and c inputs (indices 1 and 2) - auto h_dtype_result = get_input_scalar_type(model, "decoder_step", 1); - if (!h_dtype_result.ok()) { - return hypothesis; - } - auto c_dtype_result = get_input_scalar_type(model, "decoder_step", 2); - if (!c_dtype_result.ok()) { - return hypothesis; - } - auto h_dtype = h_dtype_result.get(); - auto c_dtype = c_dtype_result.get(); - - ET_LOG( - Info, - "Decoder h dtype: %s, c dtype: %s", - ::executorch::runtime::toString(h_dtype), - ::executorch::runtime::toString(c_dtype)); - - // Calculate buffer sizes based on dtype - size_t h_elem_size = ::executorch::runtime::elementSize(h_dtype); - size_t c_elem_size = ::executorch::runtime::elementSize(c_dtype); - size_t num_elements = - static_cast(num_rnn_layers) * static_cast(pred_hidden); - - // Initialize LSTM state with zeros (using byte buffers for dtype flexibility) - std::vector h_data(num_elements * h_elem_size, 0); - std::vector c_data(num_elements * c_elem_size, 0); - - auto h = from_blob( - h_data.data(), - {static_cast<::executorch::aten::SizesType>(num_rnn_layers), - 1, - static_cast<::executorch::aten::SizesType>(pred_hidden)}, - h_dtype); - auto c = from_blob( - c_data.data(), - {static_cast<::executorch::aten::SizesType>(num_rnn_layers), - 1, - static_cast<::executorch::aten::SizesType>(pred_hidden)}, - c_dtype); - - // Prime the decoder with SOS (= blank_id) to match NeMo TDT label-looping: - // - SOS is defined as blank: - // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L1063 - // - Predictor priming with SOS: - // https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L1122-L1127 - std::vector sos_token_data = {blank_id}; - auto sos_token = from_blob( - sos_token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); - auto decoder_init_result = model.execute( - "decoder_step", - std::vector<::executorch::runtime::EValue>{sos_token, h, c}); - if (!decoder_init_result.ok()) { - ET_LOG(Error, "decoder_step (SOS) failed"); - return hypothesis; - } - auto& init_outputs = decoder_init_result.get(); - auto g_proj_init = init_outputs[0].toTensor(); - auto new_h_init = init_outputs[1].toTensor(); - auto new_c_init = init_outputs[2].toTensor(); - std::memcpy(h_data.data(), new_h_init.const_data_ptr(), h_data.size()); - std::memcpy(c_data.data(), new_c_init.const_data_ptr(), c_data.size()); - - // Get expected dtype for joint inputs (f and g at indices 0 and 1) - auto f_dtype_result = get_input_scalar_type(model, "joint", 0); - if (!f_dtype_result.ok()) { - return hypothesis; - } - auto g_dtype_result = get_input_scalar_type(model, "joint", 1); - if (!g_dtype_result.ok()) { - return hypothesis; - } - auto f_dtype = f_dtype_result.get(); - auto g_dtype = g_dtype_result.get(); - - ET_LOG( - Info, - "Joint f dtype: %s, g dtype: %s", - ::executorch::runtime::toString(f_dtype), - ::executorch::runtime::toString(g_dtype)); - - size_t f_elem_size = ::executorch::runtime::elementSize(f_dtype); - size_t g_elem_size = ::executorch::runtime::elementSize(g_dtype); - - // Copy g_proj data for reuse (using byte buffer for dtype flexibility) - size_t g_proj_num_bytes = - static_cast(g_proj_init.numel()) * g_elem_size; - std::vector g_proj_data(g_proj_num_bytes); - std::memcpy( - g_proj_data.data(), g_proj_init.const_data_ptr(), g_proj_num_bytes); - - int64_t t = 0; - int64_t symbols_on_frame = 0; - const uint8_t* f_proj_ptr = - static_cast(f_proj.const_data_ptr()); - size_t f_t_num_bytes = proj_dim * f_elem_size; - - // Scan over encoder output - while (t < encoder_len) { - // Get encoder frame at time t: f_proj[:, t:t+1, :] - std::vector f_t_data(f_t_num_bytes); - std::memcpy( - f_t_data.data(), - f_proj_ptr + static_cast(t) * f_t_num_bytes, - f_t_num_bytes); - - auto f_t = from_blob( - f_t_data.data(), - {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, - f_dtype); - - auto g_proj = from_blob( - g_proj_data.data(), - {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, - g_dtype); - - auto joint_result = model.execute( - "joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj}); - if (!joint_result.ok()) { - ET_LOG(Error, "joint failed at t=%lld", static_cast(t)); - return hypothesis; - } - - int64_t k = joint_result.get()[0].toTensor().const_data_ptr()[0]; - int64_t dur_idx = - joint_result.get()[1].toTensor().const_data_ptr()[0]; - int64_t dur = DURATIONS[dur_idx]; - - if (k == blank_id) { - t += std::max(dur, static_cast(1)); - symbols_on_frame = 0; - } else { - if (hypothesis.empty() && stats) { - stats->first_token_ms = ::executorch::extension::llm::time_in_ms(); - } - hypothesis.push_back({static_cast(k), t, dur}); - - std::vector token_data = {k}; - auto token = from_blob( - token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); - - auto decoder_result = model.execute( - "decoder_step", - std::vector<::executorch::runtime::EValue>{token, h, c}); - if (!decoder_result.ok()) { - ET_LOG(Error, "decoder_step failed"); - return hypothesis; - } - auto& outputs = decoder_result.get(); - auto new_g_proj = outputs[0].toTensor(); - auto new_h = outputs[1].toTensor(); - auto new_c = outputs[2].toTensor(); - - // Update h, c, and g_proj - std::memcpy(h_data.data(), new_h.const_data_ptr(), h_data.size()); - std::memcpy(c_data.data(), new_c.const_data_ptr(), c_data.size()); - std::memcpy( - g_proj_data.data(), new_g_proj.const_data_ptr(), g_proj_data.size()); - - t += dur; - - if (dur == 0) { - symbols_on_frame++; - if (symbols_on_frame >= max_symbols_per_step) { - t++; - symbols_on_frame = 0; - } - } else { - symbols_on_frame = 0; - } - } - } - - return hypothesis; -} - -} // namespace +DEFINE_bool( + runtime_profile, + false, + "Print a detailed runtime profile for preprocessor, encoder, and decode-loop execution."); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - // Initialize stats for benchmarking - ::executorch::extension::llm::Stats stats; - stats.model_load_start_ms = ::executorch::extension::llm::time_in_ms(); - - TimestampOutputMode timestamp_mode; + parakeet::TimestampOutputMode timestamp_mode; try { - timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps); + timestamp_mode = parakeet::parse_timestamp_output_mode(FLAGS_timestamps); } catch (const std::invalid_argument& e) { ET_LOG(Error, "%s", e.what()); return 1; @@ -356,242 +54,57 @@ int main(int argc, char** argv) { return 1; } - // Load model (which includes the bundled preprocessor) - ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str()); - std::unique_ptr model; - if (!FLAGS_data_path.empty()) { - ET_LOG(Info, "Loading data from: %s", FLAGS_data_path.c_str()); - model = std::make_unique( - FLAGS_model_path, FLAGS_data_path, Module::LoadMode::Mmap); - } else { - model = std::make_unique(FLAGS_model_path, Module::LoadMode::Mmap); - } - auto model_load_error = model->load(); - if (model_load_error != Error::Ok) { - ET_LOG(Error, "Failed to load model."); - return 1; - } - - // Load all methods upfront so model_load_time captures the real cost. - // With Mmap load mode, model->load() only sets up memory mappings; - // the actual data is paged in lazily when methods are first loaded. - const std::vector required_methods = { - "preprocessor", "encoder", "decoder_step", "joint"}; - for (const auto& method : required_methods) { - auto method_load_error = model->load_method(method); - if (method_load_error != Error::Ok) { - ET_LOG(Error, "Failed to load method: %s", method.c_str()); - return 1; + try { + parakeet::ParakeetTranscriber transcriber( + FLAGS_model_path, FLAGS_tokenizer_path, FLAGS_data_path); + const auto result = transcriber.transcribe_wav_path( + FLAGS_audio_path, + parakeet::TranscribeConfig{timestamp_mode, FLAGS_runtime_profile}); + + std::cout << "Transcribed text: " << result.text << std::endl; + if (!result.stats_json.empty()) { + std::cout << "PyTorchObserver " << result.stats_json << std::endl; + } + if (result.runtime_profile_report.has_value()) { + std::cout << *result.runtime_profile_report; } - } - stats.model_load_end_ms = ::executorch::extension::llm::time_in_ms(); - stats.inference_start_ms = ::executorch::extension::llm::time_in_ms(); - - // Load audio - ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str()); - std::vector audio_data = - ::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path); - ET_LOG(Info, "Loaded %zu audio samples", audio_data.size()); - - auto audio_tensor = from_blob( - audio_data.data(), - {static_cast<::executorch::aten::SizesType>(audio_data.size())}, - ::executorch::aten::ScalarType::Float); - std::vector audio_len_data = { - static_cast(audio_data.size())}; - auto audio_len_tensor = from_blob( - audio_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); - - ET_LOG(Info, "Running preprocessor..."); - auto proc_result = model->execute( - "preprocessor", - std::vector<::executorch::runtime::EValue>{ - audio_tensor, audio_len_tensor}); - if (!proc_result.ok()) { - ET_LOG(Error, "Preprocessor forward failed."); - return 1; - } - auto& proc_outputs = proc_result.get(); - auto mel = proc_outputs[0].toTensor(); - auto mel_len_tensor_out = proc_outputs[1].toTensor(); - int64_t mel_len_value = mel_len_tensor_out.const_data_ptr()[0]; - - // Create mel_len tensor for encoder - std::vector mel_len_data = {mel_len_value}; - auto mel_len = - from_blob(mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); - - ET_LOG( - Info, - "Mel spectrogram shape: [%ld, %ld, %ld], mel_len: %lld", - static_cast(mel.sizes()[0]), - static_cast(mel.sizes()[1]), - static_cast(mel.sizes()[2]), - static_cast(mel_len_value)); - - ET_LOG(Info, "Running encoder..."); - auto enc_result = model->execute( - "encoder", std::vector<::executorch::runtime::EValue>{mel, mel_len}); - if (!enc_result.ok()) { - ET_LOG(Error, "Encoder forward failed."); - return 1; - } - stats.prompt_eval_end_ms = ::executorch::extension::llm::time_in_ms(); - - auto& enc_outputs = enc_result.get(); - auto f_proj = enc_outputs[0].toTensor(); // [B, T, joint_hidden] - int64_t encoded_len = enc_outputs[1].toTensor().const_data_ptr()[0]; - - ET_LOG( - Info, - "Encoder output (f_proj) shape: [%ld, %ld, %ld], len=%ld", - static_cast(f_proj.sizes()[0]), - static_cast(f_proj.sizes()[1]), - static_cast(f_proj.sizes()[2]), - static_cast(encoded_len)); - - // Query model metadata from constant_methods - std::vector<::executorch::runtime::EValue> empty_inputs; - auto num_rnn_layers_result = model->execute("num_rnn_layers", empty_inputs); - auto pred_hidden_result = model->execute("pred_hidden", empty_inputs); - auto vocab_size_result = model->execute("vocab_size", empty_inputs); - auto blank_id_result = model->execute("blank_id", empty_inputs); - auto sample_rate_result = model->execute("sample_rate", empty_inputs); - auto window_stride_result = model->execute("window_stride", empty_inputs); - auto encoder_subsampling_factor_result = - model->execute("encoder_subsampling_factor", empty_inputs); - - if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || - !vocab_size_result.ok() || !blank_id_result.ok() || - !sample_rate_result.ok() || !window_stride_result.ok() || - !encoder_subsampling_factor_result.ok()) { - ET_LOG( - Error, - "Failed to query model metadata. Make sure the model was exported with constant_methods."); - return 1; - } - - int64_t vocab_size = vocab_size_result.get()[0].toInt(); - int64_t blank_id = blank_id_result.get()[0].toInt(); - int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt(); - int64_t pred_hidden = pred_hidden_result.get()[0].toInt(); - int64_t sample_rate = sample_rate_result.get()[0].toInt(); - double window_stride = window_stride_result.get()[0].toDouble(); - int64_t encoder_subsampling_factor = - encoder_subsampling_factor_result.get()[0].toInt(); - - ET_LOG( - Info, - "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld", - static_cast(vocab_size), - static_cast(blank_id), - static_cast(num_rnn_layers), - static_cast(pred_hidden), - static_cast(sample_rate), - window_stride, - static_cast(encoder_subsampling_factor)); - - ET_LOG(Info, "Running TDT greedy decode..."); - auto decoded_tokens = greedy_decode_executorch( - *model, - f_proj, - encoded_len, - blank_id, - num_rnn_layers, - pred_hidden, - 10, - &stats); - - ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); - - // Load tokenizer - ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); - auto tokenizer = - ::executorch::extension::llm::load_tokenizer(FLAGS_tokenizer_path); - if (!tokenizer || !tokenizer->is_loaded()) { - ET_LOG( - Error, - "Failed to load tokenizer from: %s", - FLAGS_tokenizer_path.c_str()); - return 1; - } - - // Convert tokens to text - std::string text = parakeet::tokenizer_utils::decode_token_sequence( - decoded_tokens, *tokenizer); - std::cout << "Transcribed text: " << text << std::endl; - - // Record inference end time and token counts - stats.inference_end_ms = ::executorch::extension::llm::time_in_ms(); - stats.num_prompt_tokens = - encoded_len; // Use encoder output length as "prompt" tokens - stats.num_generated_tokens = static_cast(decoded_tokens.size()); - - // Print PyTorchObserver stats for benchmarking - ::executorch::extension::llm::print_report(stats); #ifdef ET_BUILD_METAL - executorch::backends::metal::print_metal_backend_stats(); -#endif // ET_BUILD_METAL - - if (!timestamp_mode.enabled()) { - return 0; - } - - ET_LOG(Info, "Computing timestamps..."); - std::unordered_set supported_punctuation = - parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer); - ET_LOG( - Info, - "Derived supported_punctuation size=%zu", - supported_punctuation.size()); - - // for simplicity, compute all levels of timestamps regardless of mode - std::vector tokens_with_text_info; - try { - tokens_with_text_info = - parakeet::timestamp_utils::get_tokens_with_text_info( - decoded_tokens, *tokenizer, supported_punctuation); - } catch (const std::exception& e) { - ET_LOG(Error, "Failed to get tokens with text info: %s", e.what()); - return 1; - } - const auto word_offsets = parakeet::timestamp_utils::get_words_offsets( - tokens_with_text_info, *tokenizer, supported_punctuation); - const auto segment_offsets = - parakeet::timestamp_utils::get_segment_offsets(word_offsets); - - const double frame_to_seconds = - window_stride * static_cast(encoder_subsampling_factor); + executorch::backends::metal::print_metal_backend_stats(); +#endif - if (timestamp_mode.segment) { - std::cout << "\nSegment timestamps:" << std::endl; - for (const auto& segment : segment_offsets) { - const double start = segment.start_offset * frame_to_seconds; - const double end = segment.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << segment.text - << std::endl; + if (timestamp_mode.segment) { + std::cout << "\nSegment timestamps:" << std::endl; + for (const auto& segment : result.segment_offsets) { + const double start = segment.start_offset * result.frame_to_seconds; + const double end = segment.end_offset * result.frame_to_seconds; + std::cout << start << "s - " << end << "s : " << segment.text + << std::endl; + } } - } - if (timestamp_mode.word) { - std::cout << "\nWord timestamps:" << std::endl; - for (const auto& word : word_offsets) { - const double start = word.start_offset * frame_to_seconds; - const double end = word.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << word.text << std::endl; + if (timestamp_mode.word) { + std::cout << "\nWord timestamps:" << std::endl; + for (const auto& word : result.word_offsets) { + const double start = word.start_offset * result.frame_to_seconds; + const double end = word.end_offset * result.frame_to_seconds; + std::cout << start << "s - " << end << "s : " << word.text << std::endl; + } } - } - if (timestamp_mode.token) { - std::cout << "\nToken timestamps:" << std::endl; - for (const auto& token : tokens_with_text_info) { - const double start = token.start_offset * frame_to_seconds; - const double end = token.end_offset * frame_to_seconds; - std::cout << start << "s - " << end << "s : " << token.decoded_text - << std::endl; + if (timestamp_mode.token) { + std::cout << "\nToken timestamps:" << std::endl; + for (const auto& token : result.token_offsets) { + const double start = token.start_offset * result.frame_to_seconds; + const double end = token.end_offset * result.frame_to_seconds; + std::cout << start << "s - " << end << "s : " << token.decoded_text + << std::endl; + } } - } - return 0; + return 0; + } catch (const std::exception& e) { + ET_LOG(Error, "%s", e.what()); + return 1; + } } diff --git a/examples/models/parakeet/parakeet_helper.cpp b/examples/models/parakeet/parakeet_helper.cpp new file mode 100644 index 00000000000..483aa3c8b2e --- /dev/null +++ b/examples/models/parakeet/parakeet_helper.cpp @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "parakeet_helper_protocol.h" +#include "parakeet_transcriber.h" + +DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); +DEFINE_string( + tokenizer_path, + "tokenizer.model", + "Path to SentencePiece tokenizer model file."); +DEFINE_string( + data_path, + "", + "Path to data file (.ptd) for delegate data (optional, required for CUDA)."); + +namespace { + +constexpr int kExpectedSampleRate = 16000; +constexpr int kExpectedChannelCount = 1; +constexpr const char* kExpectedEncoding = "f32le"; + +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + try { + parakeet::ParakeetTranscriber transcriber( + FLAGS_model_path, FLAGS_tokenizer_path, FLAGS_data_path); + if (!parakeet::helper_protocol::write_message( + std::cout, parakeet::helper_protocol::encode_ready_message())) { + std::cerr << "Failed to write helper ready message." << std::endl; + return 1; + } + + while (true) { + parakeet::helper_protocol::Request request; + std::string request_error; + if (!parakeet::helper_protocol::read_request( + std::cin, &request, &request_error)) { + if (request_error.empty()) { + return 0; + } + parakeet::helper_protocol::write_message( + std::cout, + parakeet::helper_protocol::encode_error_message( + std::nullopt, "Failed to read helper request", request_error)); + return 1; + } + + if (request.type == parakeet::helper_protocol::Request::Type::Shutdown) { + return 0; + } + + const auto& transcribe_request = *request.transcribe; + try { + if (transcribe_request.audio.encoding != kExpectedEncoding) { + throw std::runtime_error("Unsupported audio encoding."); + } + if (transcribe_request.audio.sample_rate != kExpectedSampleRate) { + throw std::runtime_error("Unsupported audio sample rate."); + } + if (transcribe_request.audio.channel_count != kExpectedChannelCount) { + throw std::runtime_error("Unsupported audio channel count."); + } + if (transcribe_request.audio.payload_byte_count % sizeof(float) != 0) { + throw std::runtime_error("Audio payload must be float32-aligned."); + } + + std::string payload_bytes; + std::string payload_error; + if (!parakeet::helper_protocol::read_audio_payload( + std::cin, + transcribe_request.audio.payload_byte_count, + &payload_bytes, + &payload_error)) { + throw std::runtime_error(payload_error); + } + + std::vector audio( + transcribe_request.audio.payload_byte_count / sizeof(float)); + std::memcpy( + audio.data(), + payload_bytes.data(), + transcribe_request.audio.payload_byte_count); + + const auto result = transcriber.transcribe_audio( + audio.data(), + static_cast(audio.size()), + parakeet::TranscribeConfig{ + parakeet::parse_timestamp_output_mode("none"), + transcribe_request.enable_runtime_profile, + }, + [&](const std::string& status) { + std::string phase = "status"; + if (status == "Loading recording...") { + phase = "loading_recording"; + } else if (status == "Running preprocessor...") { + phase = "running_preprocessor"; + } else if (status == "Running encoder...") { + phase = "running_encoder"; + } else if (status == "Decoding final transcript...") { + phase = "decoding_final_transcript"; + } else if (status == "Computing timestamps...") { + phase = "computing_timestamps"; + } + parakeet::helper_protocol::write_message( + std::cout, + parakeet::helper_protocol::encode_status_message( + transcribe_request.request_id, phase, status)); + }); + + const std::string stdout_payload = result.stats_json.empty() + ? std::string() + : "PyTorchObserver " + result.stats_json; + const auto runtime_profile_line = + parakeet::extract_runtime_profile_line( + result.runtime_profile_report); + parakeet::helper_protocol::write_message( + std::cout, + parakeet::helper_protocol::encode_result_message( + transcribe_request.request_id, + result.text, + stdout_payload, + "", + runtime_profile_line)); + } catch (const std::exception& e) { + parakeet::helper_protocol::write_message( + std::cout, + parakeet::helper_protocol::encode_error_message( + transcribe_request.request_id, + "Helper transcription failed", + e.what())); + } + } + } catch (const std::exception& e) { + parakeet::helper_protocol::write_message( + std::cout, + parakeet::helper_protocol::encode_error_message( + std::nullopt, "Failed to start Parakeet helper", e.what())); + return 1; + } +} diff --git a/examples/models/parakeet/parakeet_helper_protocol.cpp b/examples/models/parakeet/parakeet_helper_protocol.cpp new file mode 100644 index 00000000000..b44bf07009a --- /dev/null +++ b/examples/models/parakeet/parakeet_helper_protocol.cpp @@ -0,0 +1,176 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "parakeet_helper_protocol.h" + +#include + +#include +#include +#include +#include +#include + +namespace parakeet::helper_protocol { +namespace { + +using json = nlohmann::json; + +} // namespace + +bool read_request( + std::istream& input, + Request* request, + std::string* error_message) { + std::string header_line; + if (!std::getline(input, header_line)) { + return false; + } + if (header_line.empty()) { + if (error_message) { + *error_message = "Received empty helper request header."; + } + return false; + } + + json payload; + try { + payload = json::parse(header_line); + } catch (const std::exception& e) { + if (error_message) { + *error_message = + std::string("Failed to parse helper request: ") + e.what(); + } + return false; + } + + const std::string type = payload.value("type", ""); + if (payload.value("version", -1) != kProtocolVersion) { + if (error_message) { + *error_message = "Unsupported helper protocol version."; + } + return false; + } + + if (type == "shutdown") { + request->type = Request::Type::Shutdown; + request->transcribe.reset(); + return true; + } + + if (type != "transcribe") { + if (error_message) { + *error_message = "Unsupported helper request type: " + type; + } + return false; + } + + if (!payload.contains("audio") || !payload["audio"].is_object()) { + if (error_message) { + *error_message = "Missing helper audio descriptor."; + } + return false; + } + + const auto& audio = payload["audio"]; + TranscribeRequest transcribe_request; + transcribe_request.request_id = payload.value("request_id", ""); + transcribe_request.enable_runtime_profile = + payload.value("enable_runtime_profile", false); + transcribe_request.audio.encoding = audio.value("encoding", ""); + transcribe_request.audio.sample_rate = audio.value("sample_rate", 0); + transcribe_request.audio.channel_count = audio.value("channel_count", 0); + transcribe_request.audio.payload_byte_count = + audio.value("payload_byte_count", static_cast(0)); + + request->type = Request::Type::Transcribe; + request->transcribe = transcribe_request; + return true; +} + +bool read_audio_payload( + std::istream& input, + std::size_t payload_byte_count, + std::string* payload, + std::string* error_message) { + payload->assign(payload_byte_count, '\0'); + input.read(payload->data(), static_cast(payload_byte_count)); + if (input.gcount() != static_cast(payload_byte_count)) { + if (error_message) { + *error_message = "Failed to read full helper payload."; + } + return false; + } + return true; +} + +std::string encode_ready_message() { + return json{{"type", "ready"}, {"version", kProtocolVersion}}.dump(); +} + +std::string encode_status_message( + const std::optional& request_id, + const std::string& phase, + const std::string& message) { + json payload = { + {"type", "status"}, + {"version", kProtocolVersion}, + {"phase", phase}, + {"message", message}, + }; + if (request_id.has_value()) { + payload["request_id"] = *request_id; + } + return payload.dump(); +} + +std::string encode_result_message( + const std::string& request_id, + const std::string& text, + const std::string& stdout, + const std::string& stderr, + const std::optional& runtime_profile) { + json payload = { + {"type", "result"}, + {"version", kProtocolVersion}, + {"request_id", request_id}, + {"text", text}, + {"stdout", stdout}, + {"stderr", stderr}, + }; + if (runtime_profile.has_value()) { + payload["runtime_profile"] = *runtime_profile; + } + return payload.dump(); +} + +std::string encode_error_message( + const std::optional& request_id, + const std::string& message, + const std::optional& details) { + json payload = { + {"type", "error"}, + {"version", kProtocolVersion}, + {"message", message}, + }; + if (request_id.has_value()) { + payload["request_id"] = *request_id; + } + if (details.has_value()) { + payload["details"] = *details; + } + return payload.dump(); +} + +bool write_message(std::ostream& output, const std::string& line) { + output << line << '\n'; + output.flush(); + return output.good(); +} + +} // namespace parakeet::helper_protocol diff --git a/examples/models/parakeet/parakeet_helper_protocol.h b/examples/models/parakeet/parakeet_helper_protocol.h new file mode 100644 index 00000000000..69976c438b1 --- /dev/null +++ b/examples/models/parakeet/parakeet_helper_protocol.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace parakeet::helper_protocol { + +constexpr int kProtocolVersion = 1; + +struct AudioDescriptor { + std::string encoding; + int sample_rate = 0; + int channel_count = 0; + std::size_t payload_byte_count = 0; +}; + +struct TranscribeRequest { + std::string request_id; + AudioDescriptor audio; + bool enable_runtime_profile = false; +}; + +struct Request { + enum class Type { + Transcribe, + Shutdown, + }; + + Type type = Type::Shutdown; + std::optional transcribe; +}; + +bool read_request( + std::istream& input, + Request* request, + std::string* error_message); + +bool read_audio_payload( + std::istream& input, + std::size_t payload_byte_count, + std::string* payload, + std::string* error_message); + +std::string encode_ready_message(); +std::string encode_status_message( + const std::optional& request_id, + const std::string& phase, + const std::string& message); +std::string encode_result_message( + const std::string& request_id, + const std::string& text, + const std::string& stdout, + const std::string& stderr, + const std::optional& runtime_profile); +std::string encode_error_message( + const std::optional& request_id, + const std::string& message, + const std::optional& details); + +bool write_message(std::ostream& output, const std::string& line); + +} // namespace parakeet::helper_protocol diff --git a/examples/models/parakeet/parakeet_transcriber.cpp b/examples/models/parakeet/parakeet_transcriber.cpp new file mode 100644 index 00000000000..f8dc22c4c66 --- /dev/null +++ b/examples/models/parakeet/parakeet_transcriber.cpp @@ -0,0 +1,693 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "parakeet_transcriber.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace parakeet { +namespace { + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::llm::Stats; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SteadyClock = std::chrono::steady_clock; + +const std::vector kDurations = {0, 1, 2, 3, 4}; + +double elapsed_us( + const SteadyClock::time_point& start, + const SteadyClock::time_point& end) { + return std::chrono::duration(end - start).count(); +} + +struct MethodTiming { + double total_us = 0.0; + double max_us = 0.0; + int64_t calls = 0; + + void add(double sample_us) { + total_us += sample_us; + max_us = std::max(max_us, sample_us); + ++calls; + } + + double total_ms() const { + return total_us / 1000.0; + } + + double avg_us() const { + return calls == 0 ? 0.0 : total_us / static_cast(calls); + } +}; + +struct DecodeLoopProfile { + double total_us = 0.0; + double frame_copy_us = 0.0; + double state_copy_us = 0.0; + int64_t blank_steps = 0; + int64_t emitted_tokens = 0; + MethodTiming joint; + MethodTiming decoder_step; + + double accounted_us() const { + return joint.total_us + decoder_step.total_us + frame_copy_us + + state_copy_us; + } + + double host_overhead_us() const { + return std::max(0.0, total_us - accounted_us()); + } +}; + +std::string format_method_profile( + const char* name, + const MethodTiming& timing, + const std::string& indent = " ") { + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << indent << name << ": " + << timing.total_ms() << " ms"; + if (timing.calls > 0) { + oss << " (" << timing.calls << " calls, " << timing.avg_us() << " us avg, " + << timing.max_us << " us max)"; + } + return oss.str(); +} + +std::string build_runtime_profile_report( + double preprocessor_us, + double encoder_us, + double metadata_us, + const DecodeLoopProfile& decode_profile) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(3); + oss << "\nRuntime profile:\n"; + oss << " preprocessor: " << (preprocessor_us / 1000.0) << " ms\n"; + oss << " encoder: " << (encoder_us / 1000.0) << " ms\n"; + oss << " metadata: " << (metadata_us / 1000.0) << " ms\n"; + oss << " decode_loop: " << (decode_profile.total_us / 1000.0) << " ms\n"; + oss << format_method_profile("joint", decode_profile.joint) << "\n"; + oss << format_method_profile("decoder_step", decode_profile.decoder_step) + << "\n"; + oss << " frame_copy: " << (decode_profile.frame_copy_us / 1000.0) << " ms\n"; + oss << " state_copy: " << (decode_profile.state_copy_us / 1000.0) << " ms\n"; + oss << " host_overhead: " << (decode_profile.host_overhead_us() / 1000.0) + << " ms\n"; + oss << " blank_steps: " << decode_profile.blank_steps << "\n"; + oss << " emitted_tokens: " << decode_profile.emitted_tokens << "\n"; + oss << "RUNTIME_PROFILE" << " preprocessor_ms=" << (preprocessor_us / 1000.0) + << " encoder_ms=" << (encoder_us / 1000.0) + << " metadata_ms=" << (metadata_us / 1000.0) + << " decode_loop_ms=" << (decode_profile.total_us / 1000.0) + << " joint_ms=" << decode_profile.joint.total_ms() + << " joint_calls=" << decode_profile.joint.calls + << " joint_avg_us=" << decode_profile.joint.avg_us() + << " decoder_step_ms=" << decode_profile.decoder_step.total_ms() + << " decoder_step_calls=" << decode_profile.decoder_step.calls + << " decoder_step_avg_us=" << decode_profile.decoder_step.avg_us() + << " frame_copy_ms=" << (decode_profile.frame_copy_us / 1000.0) + << " state_copy_ms=" << (decode_profile.state_copy_us / 1000.0) + << " host_overhead_ms=" << (decode_profile.host_overhead_us() / 1000.0) + << " blank_steps=" << decode_profile.blank_steps + << " emitted_tokens=" << decode_profile.emitted_tokens << "\n"; + return oss.str(); +} + +std::string to_lower_ascii(std::string s) { + for (char& ch : s) { + ch = static_cast(std::tolower(static_cast(ch))); + } + return s; +} + +[[noreturn]] void throw_runtime_error(const std::string& message) { + ET_LOG(Error, "%s", message.c_str()); + throw std::runtime_error(message); +} + +void emit_status( + const StatusCallback& status_callback, + const std::string& status) { + if (status_callback) { + status_callback(status); + } +} + +::executorch::runtime::Result<::executorch::aten::ScalarType> +get_input_scalar_type( + Module& model, + const char* method_name, + size_t input_index) { + auto method_meta_result = model.method_meta(method_name); + if (!method_meta_result.ok()) { + ET_LOG(Error, "Failed to get method metadata for %s", method_name); + return method_meta_result.error(); + } + auto method_meta = method_meta_result.get(); + if (method_meta.num_inputs() <= input_index) { + ET_LOG( + Error, + "Method %s has %zu inputs, but requested index %zu", + method_name, + method_meta.num_inputs(), + input_index); + return ::executorch::runtime::Error::InvalidArgument; + } + auto input_meta_result = method_meta.input_tensor_meta(input_index); + if (input_meta_result.error() != ::executorch::runtime::Error::Ok) { + ET_LOG( + Error, + "Failed to get input tensor metadata for %s[%zu]", + method_name, + input_index); + return input_meta_result.error(); + } + return input_meta_result.get().scalar_type(); +} + +int64_t execute_int_constant(Module& model, const char* method_name) { + std::vector empty_inputs; + auto result = model.execute(method_name, empty_inputs); + if (!result.ok()) { + throw_runtime_error( + std::string("Failed to query model metadata method: ") + method_name); + } + return result.get()[0].toInt(); +} + +double execute_double_constant(Module& model, const char* method_name) { + std::vector empty_inputs; + auto result = model.execute(method_name, empty_inputs); + if (!result.ok()) { + throw_runtime_error( + std::string("Failed to query model metadata method: ") + method_name); + } + return result.get()[0].toDouble(); +} + +std::vector greedy_decode_executorch( + Module& model, + const ::executorch::aten::Tensor& f_proj, + int64_t encoder_len, + int64_t blank_id, + int64_t num_rnn_layers, + int64_t pred_hidden, + int64_t max_symbols_per_step, + Stats* stats = nullptr, + DecodeLoopProfile* decode_profile = nullptr) { + std::vector hypothesis; + const auto decode_loop_start = SteadyClock::now(); + const auto finalize_profile = [&]() { + if (!decode_profile) { + return; + } + decode_profile->total_us = + elapsed_us(decode_loop_start, SteadyClock::now()); + decode_profile->emitted_tokens = static_cast(hypothesis.size()); + }; + + const size_t proj_dim = static_cast(f_proj.sizes()[2]); + + auto h_dtype_result = get_input_scalar_type(model, "decoder_step", 1); + if (!h_dtype_result.ok()) { + finalize_profile(); + throw_runtime_error("Failed to inspect decoder_step hidden-state dtype."); + } + auto c_dtype_result = get_input_scalar_type(model, "decoder_step", 2); + if (!c_dtype_result.ok()) { + finalize_profile(); + throw_runtime_error("Failed to inspect decoder_step cell-state dtype."); + } + auto h_dtype = h_dtype_result.get(); + auto c_dtype = c_dtype_result.get(); + + ET_LOG( + Info, + "Decoder h dtype: %s, c dtype: %s", + ::executorch::runtime::toString(h_dtype), + ::executorch::runtime::toString(c_dtype)); + + const size_t h_elem_size = ::executorch::runtime::elementSize(h_dtype); + const size_t c_elem_size = ::executorch::runtime::elementSize(c_dtype); + const size_t num_elements = + static_cast(num_rnn_layers) * static_cast(pred_hidden); + + std::vector h_data(num_elements * h_elem_size, 0); + std::vector c_data(num_elements * c_elem_size, 0); + + auto h = from_blob( + h_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + h_dtype); + auto c = from_blob( + c_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + c_dtype); + + std::vector sos_token_data = {blank_id}; + auto sos_token = from_blob( + sos_token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); + const auto decoder_init_start = SteadyClock::now(); + auto decoder_init_result = + model.execute("decoder_step", std::vector{sos_token, h, c}); + if (decode_profile) { + decode_profile->decoder_step.add( + elapsed_us(decoder_init_start, SteadyClock::now())); + } + if (!decoder_init_result.ok()) { + finalize_profile(); + throw_runtime_error("decoder_step (SOS) failed"); + } + auto& init_outputs = decoder_init_result.get(); + auto g_proj_init = init_outputs[0].toTensor(); + auto new_h_init = init_outputs[1].toTensor(); + auto new_c_init = init_outputs[2].toTensor(); + const auto init_state_copy_start = SteadyClock::now(); + std::memcpy(h_data.data(), new_h_init.const_data_ptr(), h_data.size()); + std::memcpy(c_data.data(), new_c_init.const_data_ptr(), c_data.size()); + if (decode_profile) { + decode_profile->state_copy_us += + elapsed_us(init_state_copy_start, SteadyClock::now()); + } + + auto f_dtype_result = get_input_scalar_type(model, "joint", 0); + if (!f_dtype_result.ok()) { + finalize_profile(); + throw_runtime_error("Failed to inspect joint f dtype."); + } + auto g_dtype_result = get_input_scalar_type(model, "joint", 1); + if (!g_dtype_result.ok()) { + finalize_profile(); + throw_runtime_error("Failed to inspect joint g dtype."); + } + auto f_dtype = f_dtype_result.get(); + auto g_dtype = g_dtype_result.get(); + + ET_LOG( + Info, + "Joint f dtype: %s, g dtype: %s", + ::executorch::runtime::toString(f_dtype), + ::executorch::runtime::toString(g_dtype)); + + const size_t f_elem_size = ::executorch::runtime::elementSize(f_dtype); + const size_t g_elem_size = ::executorch::runtime::elementSize(g_dtype); + + const size_t g_proj_num_bytes = + static_cast(g_proj_init.numel()) * g_elem_size; + std::vector g_proj_data(g_proj_num_bytes); + std::memcpy( + g_proj_data.data(), g_proj_init.const_data_ptr(), g_proj_num_bytes); + + int64_t t = 0; + int64_t symbols_on_frame = 0; + const uint8_t* f_proj_ptr = + static_cast(f_proj.const_data_ptr()); + const size_t f_t_num_bytes = proj_dim * f_elem_size; + + while (t < encoder_len) { + std::vector f_t_data(f_t_num_bytes); + const auto frame_copy_start = SteadyClock::now(); + std::memcpy( + f_t_data.data(), + f_proj_ptr + static_cast(t) * f_t_num_bytes, + f_t_num_bytes); + if (decode_profile) { + decode_profile->frame_copy_us += + elapsed_us(frame_copy_start, SteadyClock::now()); + } + + auto f_t = from_blob( + f_t_data.data(), + {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, + f_dtype); + + auto g_proj = from_blob( + g_proj_data.data(), + {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, + g_dtype); + + const auto joint_start = SteadyClock::now(); + auto joint_result = + model.execute("joint", std::vector{f_t, g_proj}); + if (decode_profile) { + decode_profile->joint.add(elapsed_us(joint_start, SteadyClock::now())); + } + if (!joint_result.ok()) { + finalize_profile(); + throw_runtime_error( + "joint failed at t=" + std::to_string(static_cast(t))); + } + + const int64_t k = + joint_result.get()[0].toTensor().const_data_ptr()[0]; + const int64_t dur_idx = + joint_result.get()[1].toTensor().const_data_ptr()[0]; + const int64_t dur = kDurations[dur_idx]; + + if (k == blank_id) { + if (decode_profile) { + ++decode_profile->blank_steps; + } + t += std::max(dur, static_cast(1)); + symbols_on_frame = 0; + } else { + if (hypothesis.empty() && stats) { + stats->first_token_ms = ::executorch::extension::llm::time_in_ms(); + } + hypothesis.push_back({static_cast(k), t, dur}); + + std::vector token_data = {k}; + auto token = from_blob( + token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); + + const auto decoder_step_start = SteadyClock::now(); + auto decoder_result = + model.execute("decoder_step", std::vector{token, h, c}); + if (decode_profile) { + decode_profile->decoder_step.add( + elapsed_us(decoder_step_start, SteadyClock::now())); + } + if (!decoder_result.ok()) { + finalize_profile(); + throw_runtime_error("decoder_step failed"); + } + auto& outputs = decoder_result.get(); + auto new_g_proj = outputs[0].toTensor(); + auto new_h = outputs[1].toTensor(); + auto new_c = outputs[2].toTensor(); + + const auto state_copy_start = SteadyClock::now(); + std::memcpy(h_data.data(), new_h.const_data_ptr(), h_data.size()); + std::memcpy(c_data.data(), new_c.const_data_ptr(), c_data.size()); + std::memcpy( + g_proj_data.data(), new_g_proj.const_data_ptr(), g_proj_data.size()); + if (decode_profile) { + decode_profile->state_copy_us += + elapsed_us(state_copy_start, SteadyClock::now()); + } + + t += dur; + + if (dur == 0) { + ++symbols_on_frame; + if (symbols_on_frame >= max_symbols_per_step) { + ++t; + symbols_on_frame = 0; + } + } else { + symbols_on_frame = 0; + } + } + } + + finalize_profile(); + return hypothesis; +} + +} // namespace + +TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) { + if (raw_arg.empty()) { + throw std::invalid_argument( + "Invalid --timestamps value (empty). Expected: token, word, segment, all."); + } + const std::string mode = to_lower_ascii(raw_arg); + if (mode == "none") { + return {false, false, false}; + } + if (mode == "token") { + return {true, false, false}; + } + if (mode == "word") { + return {false, true, false}; + } + if (mode == "segment") { + return {false, false, true}; + } + if (mode == "all") { + return {true, true, true}; + } + throw std::invalid_argument( + "Invalid --timestamps value '" + raw_arg + + "'. Expected: token, word, segment, all."); +} + +ParakeetTranscriber::ParakeetTranscriber( + const std::string& model_path, + const std::string& tokenizer_path, + const std::string& data_path) { + model_load_start_ms_ = ::executorch::extension::llm::time_in_ms(); + ET_LOG(Info, "Loading model from: %s", model_path.c_str()); + if (!data_path.empty()) { + ET_LOG(Info, "Loading data from: %s", data_path.c_str()); + model_ = + std::make_unique(model_path, data_path, Module::LoadMode::Mmap); + } else { + model_ = std::make_unique(model_path, Module::LoadMode::Mmap); + } + + auto model_load_error = model_->load(); + if (model_load_error != Error::Ok) { + throw_runtime_error("Failed to load model."); + } + + const std::vector required_methods = { + "preprocessor", "encoder", "decoder_step", "joint"}; + for (const auto& method : required_methods) { + auto method_load_error = model_->load_method(method); + if (method_load_error != Error::Ok) { + throw_runtime_error("Failed to load method: " + method); + } + } + + model_load_end_ms_ = ::executorch::extension::llm::time_in_ms(); + + num_rnn_layers_ = execute_int_constant(*model_, "num_rnn_layers"); + pred_hidden_ = execute_int_constant(*model_, "pred_hidden"); + vocab_size_ = execute_int_constant(*model_, "vocab_size"); + blank_id_ = execute_int_constant(*model_, "blank_id"); + sample_rate_ = execute_int_constant(*model_, "sample_rate"); + window_stride_ = execute_double_constant(*model_, "window_stride"); + encoder_subsampling_factor_ = + execute_int_constant(*model_, "encoder_subsampling_factor"); + frame_to_seconds_ = + window_stride_ * static_cast(encoder_subsampling_factor_); + + ET_LOG( + Info, + "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld, window_stride=%.6f, encoder_subsampling_factor=%lld", + static_cast(vocab_size_), + static_cast(blank_id_), + static_cast(num_rnn_layers_), + static_cast(pred_hidden_), + static_cast(sample_rate_), + window_stride_, + static_cast(encoder_subsampling_factor_)); + + ET_LOG(Info, "Loading tokenizer from: %s", tokenizer_path.c_str()); + tokenizer_ = ::executorch::extension::llm::load_tokenizer(tokenizer_path); + if (!tokenizer_ || !tokenizer_->is_loaded()) { + throw_runtime_error("Failed to load tokenizer from: " + tokenizer_path); + } + + supported_punctuation_ = + parakeet::tokenizer_utils::derive_supported_punctuation(*tokenizer_); + ET_LOG( + Info, + "Derived supported_punctuation size=%zu", + supported_punctuation_.size()); +} + +TranscribeResult ParakeetTranscriber::transcribe_wav_path( + const std::string& audio_path, + const TranscribeConfig& config, + StatusCallback status_callback) { + ET_LOG(Info, "Loading audio from: %s", audio_path.c_str()); + emit_status(status_callback, "Loading recording..."); + std::vector audio_data = + ::executorch::extension::llm::load_wav_audio_data(audio_path); + ET_LOG(Info, "Loaded %zu audio samples", audio_data.size()); + return transcribe_audio( + audio_data.data(), + static_cast(audio_data.size()), + config, + std::move(status_callback)); +} + +TranscribeResult ParakeetTranscriber::transcribe_audio( + const float* audio_data, + int64_t num_samples, + const TranscribeConfig& config, + StatusCallback status_callback) { + Stats stats; + stats.model_load_start_ms = model_load_start_ms_; + stats.model_load_end_ms = model_load_end_ms_; + stats.inference_start_ms = ::executorch::extension::llm::time_in_ms(); + + auto audio_tensor = from_blob( + const_cast(audio_data), + {static_cast<::executorch::aten::SizesType>(num_samples)}, + ::executorch::aten::ScalarType::Float); + std::vector audio_len_data = {num_samples}; + auto audio_len_tensor = from_blob( + audio_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); + + ET_LOG(Info, "Running preprocessor..."); + emit_status(status_callback, "Running preprocessor..."); + double preprocessor_us = 0.0; + const auto preprocessor_start = SteadyClock::now(); + auto proc_result = model_->execute( + "preprocessor", std::vector{audio_tensor, audio_len_tensor}); + preprocessor_us = elapsed_us(preprocessor_start, SteadyClock::now()); + if (!proc_result.ok()) { + throw_runtime_error("Preprocessor forward failed."); + } + auto& proc_outputs = proc_result.get(); + auto mel = proc_outputs[0].toTensor(); + auto mel_len_tensor_out = proc_outputs[1].toTensor(); + int64_t mel_len_value = mel_len_tensor_out.const_data_ptr()[0]; + + std::vector mel_len_data = {mel_len_value}; + auto mel_len = + from_blob(mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); + + ET_LOG( + Info, + "Mel spectrogram shape: [%ld, %ld, %ld], mel_len: %lld", + static_cast(mel.sizes()[0]), + static_cast(mel.sizes()[1]), + static_cast(mel.sizes()[2]), + static_cast(mel_len_value)); + + ET_LOG(Info, "Running encoder..."); + emit_status(status_callback, "Running encoder..."); + double encoder_us = 0.0; + const auto encoder_start = SteadyClock::now(); + auto enc_result = + model_->execute("encoder", std::vector{mel, mel_len}); + encoder_us = elapsed_us(encoder_start, SteadyClock::now()); + if (!enc_result.ok()) { + throw_runtime_error("Encoder forward failed."); + } + stats.prompt_eval_end_ms = ::executorch::extension::llm::time_in_ms(); + + auto& enc_outputs = enc_result.get(); + auto f_proj = enc_outputs[0].toTensor(); + const int64_t encoded_len = + enc_outputs[1].toTensor().const_data_ptr()[0]; + + ET_LOG( + Info, + "Encoder output (f_proj) shape: [%ld, %ld, %ld], len=%ld", + static_cast(f_proj.sizes()[0]), + static_cast(f_proj.sizes()[1]), + static_cast(f_proj.sizes()[2]), + static_cast(encoded_len)); + + ET_LOG(Info, "Running TDT greedy decode..."); + emit_status(status_callback, "Decoding final transcript..."); + DecodeLoopProfile decode_profile; + auto decoded_tokens = greedy_decode_executorch( + *model_, + f_proj, + encoded_len, + blank_id_, + num_rnn_layers_, + pred_hidden_, + 10, + &stats, + config.runtime_profile ? &decode_profile : nullptr); + + ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size()); + + const std::string text = parakeet::tokenizer_utils::decode_token_sequence( + decoded_tokens, *tokenizer_); + + stats.inference_end_ms = ::executorch::extension::llm::time_in_ms(); + stats.num_prompt_tokens = encoded_len; + stats.num_generated_tokens = static_cast(decoded_tokens.size()); + + double metadata_us = 0.0; + if (config.runtime_profile) { + metadata_us = 0.0; + } + + TranscribeResult result; + result.text = text; + result.stats_json = ::executorch::extension::llm::stats_to_json_string(stats); + result.frame_to_seconds = frame_to_seconds_; + if (config.runtime_profile) { + result.runtime_profile_report = build_runtime_profile_report( + preprocessor_us, encoder_us, metadata_us, decode_profile); + } + + if (!config.timestamp_output_mode.enabled()) { + return result; + } + + ET_LOG(Info, "Computing timestamps..."); + emit_status(status_callback, "Computing timestamps..."); + auto tokens_with_text_info = + parakeet::timestamp_utils::get_tokens_with_text_info( + decoded_tokens, *tokenizer_, supported_punctuation_); + auto word_offsets = parakeet::timestamp_utils::get_words_offsets( + tokens_with_text_info, *tokenizer_, supported_punctuation_); + auto segment_offsets = + parakeet::timestamp_utils::get_segment_offsets(word_offsets); + + result.token_offsets = std::move(tokens_with_text_info); + result.word_offsets = std::move(word_offsets); + result.segment_offsets = std::move(segment_offsets); + return result; +} + +std::optional extract_runtime_profile_line( + const std::optional& report) { + if (!report.has_value()) { + return std::nullopt; + } + + std::istringstream stream(*report); + std::string line; + while (std::getline(stream, line)) { + if (line.rfind("RUNTIME_PROFILE", 0) == 0) { + return line; + } + } + return std::nullopt; +} + +} // namespace parakeet diff --git a/examples/models/parakeet/parakeet_transcriber.h b/examples/models/parakeet/parakeet_transcriber.h new file mode 100644 index 00000000000..f44c4c113c0 --- /dev/null +++ b/examples/models/parakeet/parakeet_transcriber.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "timestamp_utils.h" +#include "tokenizer_utils.h" +#include "types.h" + +namespace parakeet { + +struct TimestampOutputMode { + bool token = false; + bool word = false; + bool segment = false; + + bool enabled() const { + return token || word || segment; + } +}; + +TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg); + +struct TranscribeConfig { + TimestampOutputMode timestamp_output_mode; + bool runtime_profile = false; +}; + +struct TranscribeResult { + std::string text; + std::string stats_json; + std::optional runtime_profile_report; + double frame_to_seconds = 0.0; + std::vector token_offsets; + std::vector word_offsets; + std::vector segment_offsets; +}; + +using StatusCallback = std::function; + +class ParakeetTranscriber { + public: + ParakeetTranscriber( + const std::string& model_path, + const std::string& tokenizer_path, + const std::string& data_path = ""); + + TranscribeResult transcribe_audio( + const float* audio_data, + int64_t num_samples, + const TranscribeConfig& config, + StatusCallback status_callback = {}); + + TranscribeResult transcribe_wav_path( + const std::string& audio_path, + const TranscribeConfig& config, + StatusCallback status_callback = {}); + + private: + std::unique_ptr<::executorch::extension::Module> model_; + std::unique_ptr tokenizer_; + + int64_t vocab_size_ = 0; + int64_t blank_id_ = 0; + int64_t num_rnn_layers_ = 0; + int64_t pred_hidden_ = 0; + int64_t sample_rate_ = 0; + double window_stride_ = 0.0; + int64_t encoder_subsampling_factor_ = 0; + double frame_to_seconds_ = 0.0; + + long model_load_start_ms_ = 0; + long model_load_end_ms_ = 0; + + std::unordered_set supported_punctuation_; +}; + +std::optional extract_runtime_profile_line( + const std::optional& report); + +} // namespace parakeet