Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions crates/audio-device/src/macos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl MacOSBackend {
.unwrap_or(TransportType::Unknown);

let is_default = default_device_id
.map(|id| device.0.0 == id)
.map(|id| device.0 .0 == id)
.unwrap_or(false);

let mut audio_device = AudioDevice {
Expand Down Expand Up @@ -102,7 +102,11 @@ impl MacOSBackend {
})
});

if detected { Some(true) } else { None }
if detected {
Some(true)
} else {
None
}
}

fn is_external_from_device(device: Option<ca::Device>) -> bool {
Expand All @@ -120,8 +124,8 @@ impl AudioDeviceBackend for MacOSBackend {
let ca_devices =
ca::System::devices().map_err(|e| Error::EnumerationFailed(format!("{:?}", e)))?;

let default_input_id = ca::System::default_input_device().ok().map(|d| d.0.0);
let default_output_id = ca::System::default_output_device().ok().map(|d| d.0.0);
let default_input_id = ca::System::default_input_device().ok().map(|d| d.0 .0);
let default_output_id = ca::System::default_output_device().ok().map(|d| d.0 .0);

let mut devices = Vec::new();

Expand Down Expand Up @@ -161,7 +165,7 @@ impl AudioDeviceBackend for MacOSBackend {
Ok(Self::create_audio_device(
&ca_device,
AudioDirection::Input,
Some(ca_device.0.0),
Some(ca_device.0 .0),
))
}

Expand All @@ -178,7 +182,7 @@ impl AudioDeviceBackend for MacOSBackend {
Ok(Self::create_audio_device(
&ca_device,
AudioDirection::Output,
Some(ca_device.0.0),
Some(ca_device.0 .0),
))
}

Expand Down
8 changes: 4 additions & 4 deletions crates/audio-device/src/windows.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use crate::{AudioDevice, AudioDeviceBackend, AudioDirection, DeviceId, Error, TransportType};
use std::ffi::OsString;
use std::os::windows::ffi::OsStringExt;
use windows::core::{Interface, GUID, PCWSTR, PWSTR};
use windows::Win32::Devices::FunctionDiscovery::PKEY_Device_FriendlyName;
use windows::Win32::Media::Audio::Endpoints::IAudioEndpointVolume;
use windows::Win32::Media::Audio::{
DEVICE_STATE_ACTIVE, IMMDevice, IMMDeviceEnumerator, MMDeviceEnumerator, eAll, eCapture,
eConsole, eRender,
eAll, eCapture, eConsole, eRender, IMMDevice, IMMDeviceEnumerator, MMDeviceEnumerator,
DEVICE_STATE_ACTIVE,
};
use windows::Win32::System::Com::{
CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize, STGM_READ,
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, STGM_READ,
};
use windows::Win32::UI::Shell::PropertiesSystem::IPropertyStore;
use windows::core::{GUID, Interface, PCWSTR, PWSTR};

pub struct WindowsBackend;

Expand Down
89 changes: 87 additions & 2 deletions crates/owhisper-client/src/adapter/openai/live.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,18 @@ impl RealtimeSttAdapter for OpenAIAdapter {
false
}

fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url {
let (mut url, existing_params) = Self::build_ws_url_from_base(api_base);
fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url {
// Detect Azure from the base URL and store flag for initial_message
if let Ok(parsed) = api_base.parse::<url::Url>() {
if let Some(host) = parsed.host_str() {
if Self::is_azure_host(host) {
self.set_azure(true);
}
}
}

let model = params.model.as_deref();
let (mut url, existing_params) = Self::build_ws_url_from_base_with_model(api_base, model);

if !existing_params.is_empty() {
let mut query_pairs = url.query_pairs_mut();
Expand Down Expand Up @@ -78,6 +88,11 @@ impl RealtimeSttAdapter for OpenAIAdapter {
None => default,
};

// Use the Azure flag set during build_ws_url (detected from api_base URL)
if self.is_azure() {
return self.build_azure_initial_message(model, language);
}

let session_config = SessionUpdateEvent {
event_type: "session.update".to_string(),
session: SessionConfig {
Expand Down Expand Up @@ -227,6 +242,76 @@ impl RealtimeSttAdapter for OpenAIAdapter {
}
}

impl OpenAIAdapter {
    /// Construct the first websocket message for an Azure OpenAI session.
    ///
    /// Azure's realtime transcription endpoint expects a
    /// `transcription_session.update` event (instead of the standard OpenAI
    /// `session.update`), carrying the deployment/model name, an optional
    /// language, and server-side VAD settings.
    ///
    /// Returns `None` when the payload fails to serialize.
    fn build_azure_initial_message(
        &self,
        model: &str,
        language: Option<String>,
    ) -> Option<Message> {
        let turn_detection = AzureTurnDetection {
            detection_type: VAD_DETECTION_TYPE.to_string(),
            threshold: Some(VAD_THRESHOLD),
            prefix_padding_ms: Some(VAD_PREFIX_PADDING_MS),
            silence_duration_ms: Some(VAD_SILENCE_DURATION_MS),
        };

        let transcription = AzureTranscriptionConfig {
            model: model.to_string(),
            prompt: None,
            language,
        };

        let event = AzureTranscriptionSessionUpdate {
            event_type: "transcription_session.update".to_string(),
            session: AzureSessionConfig {
                input_audio_format: "pcm16".to_string(),
                input_audio_transcription: transcription,
                turn_detection: Some(turn_detection),
            },
        };

        let json = serde_json::to_string(&event).ok()?;
        tracing::debug!(payload = %json, "azure_openai_session_update_payload");
        Some(Message::Text(json.into()))
    }
}

// Azure OpenAI specific session message types

/// Top-level Azure realtime event, serialized as
/// `{"type": "transcription_session.update", "session": {...}}`.
#[derive(Debug, Serialize)]
struct AzureTranscriptionSessionUpdate {
    /// Serialized under the JSON key `type`; set to
    /// "transcription_session.update" by the builder.
    #[serde(rename = "type")]
    event_type: String,
    session: AzureSessionConfig,
}

/// Session body of an Azure `transcription_session.update` event.
#[derive(Debug, Serialize)]
struct AzureSessionConfig {
    /// Audio wire format; the builder always sends "pcm16".
    input_audio_format: String,
    /// Model/deployment and language settings for transcription.
    input_audio_transcription: AzureTranscriptionConfig,
    /// Server-side VAD settings; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    turn_detection: Option<AzureTurnDetection>,
}

/// Transcription settings inside the Azure session config.
#[derive(Debug, Serialize)]
struct AzureTranscriptionConfig {
    /// Azure deployment/model name (sent here rather than in the URL).
    model: String,
    /// Optional transcription prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt: Option<String>,
    /// Optional language hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    language: Option<String>,
}

/// Server-side voice-activity-detection configuration for Azure.
#[derive(Debug, Serialize)]
struct AzureTurnDetection {
    /// Serialized under the JSON key `type` (e.g. the VAD_DETECTION_TYPE const).
    #[serde(rename = "type")]
    detection_type: String,
    /// Activation threshold; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    threshold: Option<f32>,
    /// Audio to include before detected speech, in ms; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    prefix_padding_ms: Option<u32>,
    /// Silence required to end a turn, in ms; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    silence_duration_ms: Option<u32>,
}

#[derive(Debug, Serialize)]
struct SessionUpdateEvent {
#[serde(rename = "type")]
Expand Down
134 changes: 129 additions & 5 deletions crates/owhisper-client/src/adapter/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,33 @@ use crate::providers::Provider;

use super::{LanguageQuality, LanguageSupport};

#[derive(Clone, Default)]
pub struct OpenAIAdapter;
const AZURE_API_VERSION: &str = "2025-04-01-preview";

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

/// OpenAI realtime STT adapter.
///
/// `is_azure` is interior-mutable shared state: `build_ws_url` (which only
/// has `&self`) detects an Azure host from the API base and flips this flag
/// so that `initial_message` can later select the Azure-specific session
/// payload.
///
/// NOTE(review): the flag is shared across clones of the adapter — confirm a
/// single adapter instance is never reused for both Azure and non-Azure
/// endpoints concurrently.
#[derive(Clone)]
pub struct OpenAIAdapter {
    is_azure: Arc<AtomicBool>,
}

impl Default for OpenAIAdapter {
    /// New adapters start in non-Azure mode; `build_ws_url` may flip the
    /// flag once it sees an Azure host.
    fn default() -> Self {
        let is_azure = Arc::new(AtomicBool::new(false));
        Self { is_azure }
    }
}

impl OpenAIAdapter {
    /// Record whether the configured endpoint was detected as Azure OpenAI.
    ///
    /// Takes `&self`: the flag lives behind an `Arc<AtomicBool>` so it can be
    /// set from `build_ws_url`, which only has a shared reference.
    pub fn set_azure(&self, value: bool) {
        self.is_azure.store(value, Ordering::SeqCst);
    }

    /// True once an Azure host has been recorded via `set_azure(true)`.
    pub fn is_azure(&self) -> bool {
        self.is_azure.load(Ordering::SeqCst)
    }
}

impl OpenAIAdapter {
pub fn language_support_live(_languages: &[hypr_language::Language]) -> LanguageSupport {
Expand All @@ -27,7 +52,18 @@ impl OpenAIAdapter {
Self::language_support_batch(languages).is_supported()
}

/// Returns true when `host` is an Azure OpenAI endpoint
/// (e.g. `my-resource.openai.azure.com`).
///
/// DNS host names are case-insensitive, so the suffix match ignores ASCII
/// case. Hosts obtained from `url::Url::host_str()` are already lowercased
/// by the `url` crate, but raw strings from config are handled too.
pub fn is_azure_host(host: &str) -> bool {
    const AZURE_SUFFIX: &str = ".openai.azure.com";
    host.to_ascii_lowercase().ends_with(AZURE_SUFFIX)
}

pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) {
Self::build_ws_url_from_base_with_model(api_base, None)
}

pub(crate) fn build_ws_url_from_base_with_model(
api_base: &str,
model: Option<&str>,
) -> (url::Url, Vec<(String, String)>) {
if api_base.is_empty() {
return (
Provider::OpenAI
Expand All @@ -43,15 +79,20 @@ impl OpenAIAdapter {
}

let parsed: url::Url = api_base.parse().expect("invalid_api_base");
let host = parsed
.host_str()
.unwrap_or(Provider::OpenAI.default_ws_host());

if Self::is_azure_host(host) {
return Self::build_azure_ws_url(&parsed, host, model);
}

let mut existing_params = super::extract_query_params(&parsed);

if !existing_params.iter().any(|(k, _)| k == "intent") {
existing_params.push(("intent".to_string(), "transcription".to_string()));
}

let host = parsed
.host_str()
.unwrap_or(Provider::OpenAI.default_ws_host());
let mut url: url::Url = format!("wss://{}{}", host, Provider::OpenAI.ws_path())
.parse()
.expect("invalid_ws_url");
Expand All @@ -60,6 +101,28 @@ impl OpenAIAdapter {

(url, existing_params)
}

/// Build the websocket URL and query params for the Azure OpenAI
/// transcription Realtime endpoint.
///
/// For Azure transcription Realtime API:
/// - deployment/model must NOT be in the URL (causes a 400 error); it is
///   sent in the `transcription_session.update` message instead
/// - only `api-version` and `intent` go in the URL query string
///
/// `_parsed` and `_model` are accepted for signature parity with the
/// non-Azure path but are intentionally unused (leading `_` silences the
/// unused-parameter warning the original `parsed` name produced).
fn build_azure_ws_url(
    _parsed: &url::Url,
    host: &str,
    _model: Option<&str>,
) -> (url::Url, Vec<(String, String)>) {
    let url: url::Url = format!("wss://{}/openai/realtime", host)
        .parse()
        .expect("invalid_azure_ws_url");

    let params = vec![
        ("api-version".to_string(), AZURE_API_VERSION.to_string()),
        ("intent".to_string(), "transcription".to_string()),
    ];

    (url, params)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -98,4 +161,65 @@ mod tests {
assert!(Provider::OpenAI.is_host("openai.com"));
assert!(!Provider::OpenAI.is_host("api.deepgram.com"));
}

#[test]
fn test_is_azure_host() {
    // Any subdomain of openai.azure.com counts as Azure.
    for azure in ["my-resource.openai.azure.com", "eastus.openai.azure.com"] {
        assert!(OpenAIAdapter::is_azure_host(azure), "{azure} should be Azure");
    }
    // Public OpenAI and bare azure.com hosts must not match.
    for other in ["api.openai.com", "openai.com", "azure.com"] {
        assert!(
            !OpenAIAdapter::is_azure_host(other),
            "{other} should not be Azure"
        );
    }
}

#[test]
fn test_build_ws_url_azure() {
    // Azure transcription: the deployment never appears in the URL; only
    // api-version and intent are carried as query params.
    let (url, params) = OpenAIAdapter::build_ws_url_from_base_with_model(
        "https://my-resource.openai.azure.com",
        Some("gpt-4o-realtime-preview"),
    );

    assert_eq!(
        url.as_str(),
        "wss://my-resource.openai.azure.com/openai/realtime"
    );

    let has = |key: &str, value: &str| params.iter().any(|(k, v)| k == key && v == value);
    assert!(has("api-version", "2025-04-01-preview"));
    assert!(has("intent", "transcription"));

    // deployment should NOT be in URL params
    assert!(params.iter().all(|(k, _)| k != "deployment"));
}

#[test]
fn test_build_ws_url_azure_deployment_not_in_url() {
    // A deployment present in the input URL must be stripped: it belongs in
    // the session message, never in the websocket URL.
    let (url, params) = OpenAIAdapter::build_ws_url_from_base_with_model(
        "https://my-resource.openai.azure.com?deployment=my-deployment",
        None,
    );

    assert_eq!(
        url.as_str(),
        "wss://my-resource.openai.azure.com/openai/realtime"
    );

    let has = |key: &str, value: &str| params.iter().any(|(k, v)| k == key && v == value);

    // deployment should NOT be in URL params
    assert!(params.iter().all(|(k, _)| k != "deployment"));
    assert!(has("api-version", "2025-04-01-preview"));
    assert!(has("intent", "transcription"));
}
}
16 changes: 13 additions & 3 deletions crates/owhisper-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ pub use providers::{Auth, Provider, is_meta_model};

use std::marker::PhantomData;

#[cfg(feature = "argmax")]
pub use adapter::StreamingBatchConfig;
pub use adapter::deepgram::DeepgramModel;
pub use adapter::{
AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, BatchSttAdapter, CactusAdapter, CallbackResult,
Expand All @@ -26,7 +24,15 @@ pub use adapter::{
documented_language_codes_batch, documented_language_codes_live, is_hyprnote_proxy,
is_local_host, normalize_languages,
};
pub use adapter::{StreamingBatchEvent, StreamingBatchStream};

/// True when `api_base` parses to a URL whose host is an Azure OpenAI
/// endpoint (`*.openai.azure.com`); false for unparseable or host-less URLs.
fn is_azure_openai(api_base: &str) -> bool {
    match url::Url::parse(api_base) {
        Ok(parsed) => parsed
            .host_str()
            .map_or(false, OpenAIAdapter::is_azure_host),
        Err(_) => false,
    }
}
#[cfg(feature = "argmax")]
pub use adapter::{StreamingBatchConfig, StreamingBatchEvent, StreamingBatchStream};

pub use batch::{BatchClient, BatchClientBuilder};
pub use error::Error;
Expand Down Expand Up @@ -128,6 +134,10 @@ impl<A: RealtimeSttAdapter> ListenClientBuilder<A> {
for (name, value) in &self.extra_headers {
request = request.with_header(name, value);
}
} else if is_azure_openai(original_api_base) {
if let Some(api_key) = self.api_key.as_deref() {
request = request.with_header("api-key", api_key);
}
} else if let Some((header_name, header_value)) =
adapter.build_auth_header(self.api_key.as_deref())
{
Expand Down
Loading