Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,13 @@ directories = "6.0.0"
ureq = "2.9.6"
rodio = "0.17"
default-device-sink = "0.1"
whisper-rs = "0.11.1"
num_cpus = "1.16"

[features]
# Enable CUDA support by building with `--features cuda`
cuda = ["whisper-rs/cuda"]
# Enable OpenCL support by building with `--features opencl`
opencl = ["whisper-rs/opencl"]
# Enable Metal support by building with `--features metal`
metal = ["whisper-rs/metal"]
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ To run transcription locally without the OpenAI API, specify a model size with
desk-talk --ptt-key scroll-lock --local --model tiny
```

To run the local model on the GPU (when built with GPU support), add `--use-gpu`:

```
desk-talk --ptt-key scroll-lock --local --model tiny --use-gpu
```

When compiling from source, GPU support requires enabling one of the
`cuda`, `opencl`, or `metal` features. For example, to build with CUDA support
run:

```
cargo build --release --features cuda
```

Replace `cuda` with the appropriate feature for your hardware.

Available models include `tiny`, `base`, `small`, `medium`, and the large
variants `large-v1`, `large-v2`, or `large-v3`.

Expand Down
6 changes: 5 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ struct Opt {
#[arg(long, value_enum, requires = "local")]
model: Option<LocalModel>,

/// Use the GPU for local transcription when available.
#[arg(long, requires = "local")]
use_gpu: bool,

/// Ensures the first letter of the transcription is capitalized.
#[arg(short, long)]
cap_first: bool,
Expand Down Expand Up @@ -318,7 +322,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let model = opt
.model
.expect("--model required when --local is used");
trans::transcribe_local(&voice_tmp_path, model.into())
trans::transcribe_local(&voice_tmp_path, model.into(), opt.use_gpu)
} else {
runtime.block_on(trans::transcribe_with_retry(
&client,
Expand Down
71 changes: 59 additions & 12 deletions src/transcribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ pub mod trans {
use async_openai::{config::OpenAIConfig, types::CreateTranscriptionRequestArgs, Client};
use async_std::future;
use directories::ProjectDirs;
use mutter::{Model, ModelType};
use mutter::ModelType;
use rodio::{source::UniformSourceIterator, Decoder, Source};
use std::io::Cursor;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
use num_cpus;
use std::fs;
use std::time::Duration;
use std::{
Expand Down Expand Up @@ -137,34 +141,77 @@ pub mod trans {
Ok(cache_dir.join(filename))
}

fn load_or_download_model(model: &ModelType) -> Result<Model, Box<dyn Error>> {
fn load_or_download_context(
model: &ModelType,
use_gpu: bool,
) -> Result<WhisperContext, Box<dyn Error>> {
use std::io::Read;

let path = get_model_path(model)?;
let mut params = WhisperContextParameters::default();
params.use_gpu(use_gpu);

if path.exists() {
let path_str = path.to_str().ok_or_else(|| anyhow!("Invalid model path"))?;
Ok(Model::new(path_str).map_err(|e| anyhow!("{:?}", e))?)
Ok(WhisperContext::new_with_params(path_str, params).map_err(|e| anyhow!("{:?}", e))?)
} else {
let resp = ureq::get(&model.to_string())
.call()
.map_err(|e| anyhow!("Download error: {:?}", e))?;
let mut bytes = Vec::new();
resp.into_reader().read_to_end(&mut bytes)?;
std::fs::write(&path, &bytes)?;
let path_str = path.to_str().ok_or_else(|| anyhow!("Invalid model path"))?;
Ok(Model::new(path_str).map_err(|e| anyhow!("{:?}", e))?)
Ok(WhisperContext::new_from_buffer_with_params(&bytes, params).map_err(|e| anyhow!("{:?}", e))?)
}
}

pub fn transcribe_local(input: &Path, model_type: ModelType) -> Result<String, Box<dyn Error>> {
let model = load_or_download_model(&model_type)?;
fn decode_audio(bytes: Vec<u8>) -> Result<Vec<f32>, Box<dyn Error>> {
let input = Cursor::new(bytes);
let source = Decoder::new(input).unwrap();
let output_sample_rate = 16000;
let channels = 1;
let resample = UniformSourceIterator::new(source, channels, output_sample_rate);
let pass_filter = resample.low_pass(3000).high_pass(200).convert_samples();
let samples: Vec<i16> = pass_filter.collect::<Vec<i16>>();
let mut output: Vec<f32> = vec![0.0f32; samples.len()];
whisper_rs::convert_integer_to_float_audio(&samples, &mut output)
.map(|()| output)
.map_err(|e| anyhow!("{:?}", e).into())
}

/// Transcribes the audio file at `input` using a local Whisper model.
///
/// The model is loaded from the cache (downloaded on first use), the audio is
/// decoded to 16 kHz mono `f32`, and the segment texts are concatenated with
/// newlines collapsed to spaces.
///
/// `use_gpu` requests GPU inference; it is a no-op unless the binary was
/// built with a GPU feature enabled.
///
/// Returns an error if the model cannot be loaded, the audio cannot be
/// decoded, or whisper fails at any stage.
pub fn transcribe_local(
    input: &Path,
    model_type: ModelType,
    use_gpu: bool,
) -> Result<String, Box<dyn Error>> {
    let ctx = load_or_download_context(&model_type, use_gpu)?;
    let bytes = fs::read(input)?;
    let samples = decode_audio(bytes)?;

    // Beam search trades some speed for noticeably better accuracy than
    // greedy decoding.
    let mut params = FullParams::new(SamplingStrategy::BeamSearch {
        beam_size: 5,
        patience: 1.0,
    });
    params.set_translate(false);
    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    params.set_token_timestamps(false);
    params.set_split_on_word(true);
    // Use every available core for inference.
    params.set_n_threads(num_cpus::get() as i32);

    // Propagate whisper failures to the caller instead of panicking; this
    // function already returns Result and its caller handles errors.
    let mut state = ctx
        .create_state()
        .map_err(|e| anyhow!("failed to create state: {:?}", e))?;
    state
        .full(params, &samples)
        .map_err(|e| anyhow!("failed to transcribe: {:?}", e))?;

    let num_segments = state
        .full_n_segments()
        .map_err(|e| anyhow!("failed to get segments: {:?}", e))?;
    let mut result = String::new();
    for i in 0..num_segments {
        let segment = state
            .full_get_segment_text(i)
            .map_err(|e| anyhow!("{:?}", e))?;
        result.push_str(&segment);
        result.push(' ');
    }

    // Collapse any newlines whisper emitted and drop the trailing separator.
    Ok(result.replace('\n', " ").trim().to_string())
}
Expand Down
Loading