diff --git a/Cargo.lock b/Cargo.lock
index 06dd2f5..9c6c7da 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -863,11 +863,13 @@ dependencies = [
  "flume",
  "hound",
  "mutter",
+ "num_cpus",
  "rdev",
  "rodio 0.17.3",
  "tempfile",
  "tokio",
  "ureq",
+ "whisper-rs",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index a168f79..a2e9caa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -27,3 +27,13 @@ directories = "6.0.0"
 ureq = "2.9.6"
 rodio = "0.17"
 default-device-sink = "0.1"
+whisper-rs = "0.11.1"
+num_cpus = "1.16"
+
+[features]
+# Enable CUDA support by building with `--features cuda`
+cuda = ["whisper-rs/cuda"]
+# Enable OpenCL support by building with `--features opencl`
+opencl = ["whisper-rs/opencl"]
+# Enable Metal support by building with `--features metal`
+metal = ["whisper-rs/metal"]
diff --git a/README.md b/README.md
index d277141..2a93169 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,22 @@ To run transcription locally without the OpenAI API, specify a model size with
 ```
 desk-talk --ptt-key scroll-lock --local --model tiny
 ```
 
+To run the local model on the GPU (when built with GPU support), add `--use-gpu`:
+
+```
+desk-talk --ptt-key scroll-lock --local --model tiny --use-gpu
+```
+
+When compiling from source, GPU support requires enabling one of the
+`cuda`, `opencl`, or `metal` features. For example, to build with CUDA support
+run:
+
+```
+cargo build --release --features cuda
+```
+
+Replace `cuda` with the appropriate feature for your hardware.
+
 Available models include `tiny`, `base`, `small`, `medium`, and the large
 variants `large-v1`, `large-v2`, or `large-v3`.
diff --git a/src/main.rs b/src/main.rs
index e8730e7..76b03b8 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -48,6 +48,10 @@ struct Opt {
     #[arg(long, value_enum, requires = "local")]
     model: Option<Model>,
 
+    /// Use the GPU for local transcription when available.
+    #[arg(long, requires = "local")]
+    use_gpu: bool,
+
     /// Ensures the first letter of the transcription is capitalized.
     #[arg(short, long)]
     cap_first: bool,
@@ -318,7 +322,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         let model = opt
             .model
             .expect("--model required when --local is used");
-        trans::transcribe_local(&voice_tmp_path, model.into())
+        trans::transcribe_local(&voice_tmp_path, model.into(), opt.use_gpu)
     } else {
         runtime.block_on(trans::transcribe_with_retry(
             &client,
diff --git a/src/transcribe.rs b/src/transcribe.rs
index 26cd388..6e08f74 100644
--- a/src/transcribe.rs
+++ b/src/transcribe.rs
@@ -4,7 +4,11 @@ pub mod trans {
     use async_openai::{config::OpenAIConfig, types::CreateTranscriptionRequestArgs, Client};
     use async_std::future;
     use directories::ProjectDirs;
-    use mutter::{Model, ModelType};
+    use mutter::ModelType;
+    use rodio::{source::UniformSourceIterator, Decoder, Source};
+    use std::io::Cursor;
+    use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
+    use num_cpus;
     use std::fs;
     use std::time::Duration;
     use std::{
@@ -137,13 +141,19 @@
         Ok(cache_dir.join(filename))
     }
 
-    fn load_or_download_model(model: &ModelType) -> Result<Model, Box<dyn std::error::Error>> {
+    fn load_or_download_context(
+        model: &ModelType,
+        use_gpu: bool,
+    ) -> Result<WhisperContext, Box<dyn std::error::Error>> {
         use std::io::Read;
         let path = get_model_path(model)?;
 
+        let mut params = WhisperContextParameters::default();
+        params.use_gpu(use_gpu);
+
         if path.exists() {
             let path_str = path.to_str().ok_or_else(|| anyhow!("Invalid model path"))?;
-            Ok(Model::new(path_str).map_err(|e| anyhow!("{:?}", e))?)
+            Ok(WhisperContext::new_with_params(path_str, params).map_err(|e| anyhow!("{:?}", e))?)
         } else {
             let resp = ureq::get(&model.to_string())
                 .call()
@@ -151,20 +161,57 @@
             let mut bytes = Vec::new();
             resp.into_reader().read_to_end(&mut bytes)?;
             std::fs::write(&path, &bytes)?;
-            let path_str = path.to_str().ok_or_else(|| anyhow!("Invalid model path"))?;
-            Ok(Model::new(path_str).map_err(|e| anyhow!("{:?}", e))?)
+            Ok(WhisperContext::new_from_buffer_with_params(&bytes, params).map_err(|e| anyhow!("{:?}", e))?)
         }
     }
 
-    pub fn transcribe_local(input: &Path, model_type: ModelType) -> Result<String, Box<dyn std::error::Error>> {
-        let model = load_or_download_model(&model_type)?;
+    fn decode_audio(bytes: Vec<u8>) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
+        let input = Cursor::new(bytes);
+        let source = Decoder::new(input).unwrap();
+        let output_sample_rate = 16000;
+        let channels = 1;
+        let resample = UniformSourceIterator::new(source, channels, output_sample_rate);
+        let pass_filter = resample.low_pass(3000).high_pass(200).convert_samples();
+        let samples: Vec<i16> = pass_filter.collect::<Vec<i16>>();
+        let mut output: Vec<f32> = vec![0.0f32; samples.len()];
+        whisper_rs::convert_integer_to_float_audio(&samples, &mut output)
+            .map(|()| output)
+            .map_err(|e| anyhow!("{:?}", e).into())
+    }
+
+    pub fn transcribe_local(
+        input: &Path,
+        model_type: ModelType,
+        use_gpu: bool,
+    ) -> Result<String, Box<dyn std::error::Error>> {
+        let ctx = load_or_download_context(&model_type, use_gpu)?;
         let bytes = fs::read(input)?;
-        let res = model
-            .transcribe_audio(bytes, false, false, None)
-            .map_err(|e| anyhow!("{:?}", e))?;
+        let samples = decode_audio(bytes)?;
+
+        let mut params = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 5, patience: 1.0 });
+        params.set_translate(false);
+        params.set_print_special(false);
+        params.set_print_progress(false);
+        params.set_print_realtime(false);
+        params.set_print_timestamps(false);
+        params.set_token_timestamps(false);
+        params.set_split_on_word(true);
+        params.set_n_threads(num_cpus::get() as i32);
+
+        let mut state = ctx.create_state().expect("failed to create state");
+        state.full(params, &samples).expect("failed to transcribe");
+
+        let num_segments = state.full_n_segments().expect("failed to get segments");
+        let mut result = String::new();
+        for i in 0..num_segments {
+            let segment = state
+                .full_get_segment_text(i)
+                .map_err(|e| anyhow!("{:?}", e))?;
+            result.push_str(&segment);
+            result.push(' ');
+        }
 
-        let mut res = res.as_text();
-        res = res.replace("\n", " "); // Remove double spaces
+        let mut res = result.replace('\n', " ");
         res = res.trim().to_string();
         Ok(res)
     }