diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 3cf851e94..c71facbd5 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -177,7 +177,10 @@ dependencies { implementation(libs.jna) { artifact { type = "aar" } } implementation(libs.vosk.android) - // LiteRT / Tensorflow Lite + // ONNX Runtime for Android (runs the Parakeet encoder and TDT decoder) + implementation(libs.onnxruntime.android) + + // LiteRT / Tensorflow Lite (used by OpenWakeWord) implementation(libs.litert) // OkHttp diff --git a/app/src/main/kotlin/org/stypox/dicio/di/SttInputDeviceWrapper.kt b/app/src/main/kotlin/org/stypox/dicio/di/SttInputDeviceWrapper.kt index a6a9ae042..0e64cc8eb 100644 --- a/app/src/main/kotlin/org/stypox/dicio/di/SttInputDeviceWrapper.kt +++ b/app/src/main/kotlin/org/stypox/dicio/di/SttInputDeviceWrapper.kt @@ -23,10 +23,14 @@ import org.stypox.dicio.io.input.InputEvent import org.stypox.dicio.io.input.SttInputDevice import org.stypox.dicio.io.input.SttState import org.stypox.dicio.io.input.external_popup.ExternalPopupInputDevice +import org.stypox.dicio.io.input.parakeet.ParakeetInputDevice +import org.stypox.dicio.io.input.scribe.ScribeRealtimeInputDevice import org.stypox.dicio.io.input.vosk.VoskInputDevice import org.stypox.dicio.settings.datastore.InputDevice import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_EXTERNAL_POPUP import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_NOTHING +import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_PARAKEET +import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_SCRIBE_REALTIME import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_UNSET import org.stypox.dicio.settings.datastore.InputDevice.INPUT_DEVICE_VOSK import org.stypox.dicio.settings.datastore.InputDevice.UNRECOGNIZED @@ -60,6 +64,7 @@ class SttInputDeviceWrapperImpl( private var inputDeviceSetting: InputDevice private var sttPlaySoundSetting: SttPlaySound private val 
silencesBeforeStop: StateFlow + private val scribeApiKey: StateFlow private var sttInputDevice: SttInputDevice? // null means that the user has not enabled any STT input device @@ -77,7 +82,9 @@ class SttInputDeviceWrapperImpl( inputDeviceSetting = firstSettings.first sttPlaySoundSetting = firstSettings.second - silencesBeforeStop = dataStore.data.map(SttInputDevice::getSttSilenceDurationOrDefault) + silencesBeforeStop = MutableStateFlow(SttInputDevice.DEFAULT_STT_SILENCE_DURATION) + scribeApiKey = dataStore.data + .map { it.scribeApiKey.trim() } .toStateFlowDistinctBlockingFirst(scope) sttInputDevice = buildInputDevice(inputDeviceSetting) scope.launch { @@ -107,6 +114,13 @@ class SttInputDeviceWrapperImpl( UNRECOGNIZED, INPUT_DEVICE_UNSET, INPUT_DEVICE_VOSK -> VoskInputDevice(appContext, okHttpClient, localeManager, silencesBeforeStop) + INPUT_DEVICE_PARAKEET -> ParakeetInputDevice(appContext, okHttpClient, localeManager) + INPUT_DEVICE_SCRIBE_REALTIME -> ScribeRealtimeInputDevice( + okHttpClient, + localeManager, + scribeApiKey, + silencesBeforeStop, + ) INPUT_DEVICE_EXTERNAL_POPUP -> ExternalPopupInputDevice(appContext, activityForResultManager, localeManager) INPUT_DEVICE_NOTHING -> null diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/SttInputDevice.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/SttInputDevice.kt index 26703685a..650e0ee05 100644 --- a/app/src/main/kotlin/org/stypox/dicio/io/input/SttInputDevice.kt +++ b/app/src/main/kotlin/org/stypox/dicio/io/input/SttInputDevice.kt @@ -1,7 +1,6 @@ package org.stypox.dicio.io.input import kotlinx.coroutines.flow.StateFlow -import org.stypox.dicio.settings.datastore.UserSettings interface SttInputDevice { val uiState: StateFlow @@ -16,9 +15,5 @@ interface SttInputDevice { companion object { const val DEFAULT_STT_SILENCE_DURATION = 2 - fun getSttSilenceDurationOrDefault(settings: UserSettings): Int { - // unfortunately there is no way to tell protobuf to use "2" as the default value - return 
settings.sttSilenceDuration.takeIf { it > 0 } ?: DEFAULT_STT_SILENCE_DURATION - } } } diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/SttState.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/SttState.kt index 5ebc4aec1..70e8a72e4 100644 --- a/app/src/main/kotlin/org/stypox/dicio/io/input/SttState.kt +++ b/app/src/main/kotlin/org/stypox/dicio/io/input/SttState.kt @@ -97,6 +97,17 @@ sealed interface SttState { */ data object Listening : SttState + /** + * Speech has ended and silence was detected. This state is expected to be very short-lived + * and acts as user feedback before the final inference starts. + */ + data object SilenceDetected : SttState + + /** + * The model is processing recorded audio and generating the final recognition result. + */ + data object Thinking : SttState + /** * An external Android app has been asked to listen (e.g. through * `RecognizerIntent.ACTION_RECOGNIZE_SPEECH`), and may be listening but we don't know for diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetInputDevice.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetInputDevice.kt new file mode 100644 index 000000000..068d3223b --- /dev/null +++ b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetInputDevice.kt @@ -0,0 +1,560 @@ +/* + * Taken from /e/OS Assistant + * + * Copyright (C) 2024 MURENA SAS + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. 
If not, see . + */ + +package org.stypox.dicio.io.input.parakeet + +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import android.content.Context +import android.util.Log +import dagger.hilt.android.qualifiers.ApplicationContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.getAndUpdate +import kotlinx.coroutines.launch +import okhttp3.OkHttpClient +import org.stypox.dicio.di.LocaleManager +import org.stypox.dicio.io.input.InputEvent +import org.stypox.dicio.io.input.SttInputDevice +import org.stypox.dicio.io.input.SttState +import org.stypox.dicio.io.input.parakeet.ParakeetState.Downloaded +import org.stypox.dicio.io.input.parakeet.ParakeetState.Downloading +import org.stypox.dicio.io.input.parakeet.ParakeetState.ErrorDownloading +import org.stypox.dicio.io.input.parakeet.ParakeetState.ErrorLoading +import org.stypox.dicio.io.input.parakeet.ParakeetState.Listening +import org.stypox.dicio.io.input.parakeet.ParakeetState.Loaded +import org.stypox.dicio.io.input.parakeet.ParakeetState.Loading +import org.stypox.dicio.io.input.parakeet.ParakeetState.NotAvailable +import org.stypox.dicio.io.input.parakeet.ParakeetState.NotDownloaded +import org.stypox.dicio.io.input.parakeet.ParakeetState.NotInitialized +import org.stypox.dicio.io.input.parakeet.ParakeetState.NotLoaded +import org.stypox.dicio.ui.util.Progress +import org.stypox.dicio.util.FileToDownload +import org.stypox.dicio.util.LocaleUtils +import org.stypox.dicio.util.distinctUntilChangedBlockingFirst +import org.stypox.dicio.util.downloadBinaryFilesWithPartial +import java.io.BufferedReader +import java.io.File +import java.io.IOException +import java.io.InputStreamReader +import java.util.Locale + +class ParakeetInputDevice( + 
@ApplicationContext appContext: Context, + private val okHttpClient: OkHttpClient, + localeManager: LocaleManager, +) : SttInputDevice { + + private val _state: MutableStateFlow + private val _transientUiState = MutableStateFlow(null) + private val _uiState: MutableStateFlow + override val uiState: StateFlow + + private var operationsJob: Job? = null + private var listeningJob: Job? = null + @Volatile + private var activeListener: ParakeetListener? = null + private val scope = CoroutineScope(Dispatchers.Default) + + private val filesDir: File = appContext.filesDir + private val cacheDir: File = appContext.cacheDir + + // Model files on disk + private val encoderFile: File get() = File(filesDir, "parakeet-encoder.int8.onnx") + private val decoderJointFile: File get() = File(filesDir, "parakeet-decoder-joint.int8.onnx") + private val vocabFile: File get() = File(filesDir, "parakeet-vocab.txt") + private val preprocessorFile: File get() = File(filesDir, "parakeet-nemo128.onnx") + + // URL-check sentinel: contains the base URL that was last downloaded successfully + private val sameModelUrlCheck: File get() = File(filesDir, "parakeet-model-url") + + init { + // Run blocking, because the locale is always available right away since LocaleManager also + // initializes in a blocking way. Moreover, if ParakeetInputDevice were not initialized + // straight away, the tryLoad() call when MainActivity starts may do nothing. 
+ val (firstLocale, nextLocaleFlow) = localeManager.locale + .distinctUntilChangedBlockingFirst() + + val initialState = init(firstLocale) + _state = MutableStateFlow(initialState) + _uiState = MutableStateFlow(initialState.toUiState()) + uiState = _uiState + + scope.launch { + combine(_state, _transientUiState) { state, transientUiState -> + transientUiState ?: state.toUiState() + }.collect { _uiState.value = it } + } + + scope.launch { + // perform initialization again every time the locale changes + nextLocaleFlow.collect { reinit(it) } + } + } + + private fun init(locale: Locale): ParakeetState { + // choose the model url based on the locale + val modelUrl = LocaleUtils.resolveValueForSupportedLocale(locale, MODEL_URLS) + + // the model url may change if the user changes app language, or in case of model updates + val modelUrlChanged = try { + sameModelUrlCheck.readText() != modelUrl + } catch (_: IOException) { + // modelUrlCheck file does not exist + true + } + + return when { + // if the modelUrl is null, then the current locale is not supported by any Parakeet + // model + modelUrl == null -> NotAvailable + // if the model url changed, the model needs to be re-downloaded + modelUrlChanged -> NotDownloaded(modelUrl) + // if all model files exist, the model has been completely downloaded and should be + // ready to be loaded + encoderFile.exists() && decoderJointFile.exists() + && vocabFile.exists() && preprocessorFile.exists() -> NotLoaded + // if any model file is missing, the model has not been downloaded yet + else -> NotDownloaded(modelUrl) + } + } + + private suspend fun reinit(locale: Locale) { + // interrupt whatever was happening before + deinit() + + // reinitialize and emit the new state + val initialState = init(locale) + _state.emit(initialState) + } + + private suspend fun deinit() { + val prevState = _state.getAndUpdate { NotInitialized } + when (prevState) { + // either interrupt the current operation or wait for it to complete + is 
Downloading -> { + operationsJob?.cancel() + operationsJob?.join() + } + is Loading -> { + operationsJob?.join() + when (val s = _state.getAndUpdate { NotInitialized }) { + NotInitialized -> {} // everything is ok + is Loaded -> { + s.sessions.close() + } + is Listening -> { + stopActiveListenerAndWait() + stopListening(s.sessions, s.eventListener, false) + s.sessions.close() + } + else -> { + Log.w(TAG, "Unexpected state after loading: $s") + } + } + } + is Loaded -> { + prevState.sessions.close() + } + is Listening -> { + stopActiveListenerAndWait() + stopListening(prevState.sessions, prevState.eventListener, false) + prevState.sessions.close() + } + + // these states are all resting states, so there is nothing to interrupt + is NotInitialized, + is NotAvailable, + is NotDownloaded, + is ErrorDownloading, + is Downloaded, + is NotLoaded, + is ErrorLoading -> {} + } + } + + private suspend fun stopActiveListenerAndWait() { + activeListener?.stopAndDiscardCurrentAudio() + activeListener = null + listeningJob?.cancel() + listeningJob?.join() + listeningJob = null + } + + /** + * Loads the model with [thenStartListeningEventListener] if the model is already downloaded + * but not loaded in RAM (which will then start listening if [thenStartListeningEventListener] + * is not `null` and pass events there), or starts listening if the model is already ready + * and [thenStartListeningEventListener] is not `null` and passes events there. 
+ * + * @param thenStartListeningEventListener if not `null`, causes the [ParakeetInputDevice] to + * start listening after it has finished loading, and the received input events are sent there + * @return `true` if the input device will start listening (or be ready to do so in case + * `thenStartListeningEventListener == null`) at some point, + * `false` if manual user intervention is required to start listening + */ + override fun tryLoad(thenStartListeningEventListener: ((InputEvent) -> Unit)?): Boolean { + val s = _state.value + if (s == NotLoaded || s is ErrorLoading) { + load(thenStartListeningEventListener) + return true + } else if (thenStartListeningEventListener != null && s is Loaded) { + startListening(s.sessions, thenStartListeningEventListener) + return true + } else { + return false + } + } + + /** + * If the model is not being downloaded/loaded, or if there was an error in any of + * those steps, downloads/loads the model. If the model is already loaded (or is being + * loaded) toggles listening state. + * + * @param eventListener only used if this click causes Parakeet to start listening, will receive + * all updates for this run + */ + override fun onClick(eventListener: (InputEvent) -> Unit) { + // the state can only be changed in the background by the jobs corresponding to Downloading + // and Loading, but as can be seen below we don't do anything in case of Downloading. For + // Loading however, special measures are taken in toggleThenStartListening() and in load() + // to ensure the button click is not lost nor has any unwanted behavior if the state changes + // right after checking its value in this switch. 
+ when (val s = _state.value) { + is NotInitialized -> {} // wait for initialization to happen + is NotAvailable -> {} // nothing to do + is NotDownloaded -> download(s.modelUrl) + is Downloading -> {} // wait for download to finish + is ErrorDownloading -> download(s.modelUrl) // retry + is Downloaded -> load(eventListener) + is NotLoaded -> load(eventListener) + is Loading -> toggleThenStartListening(eventListener) // wait for loading to finish + is ErrorLoading -> load(eventListener) // retry + is Loaded -> startListening(s.sessions, eventListener) + is Listening -> stopListening(s.sessions, s.eventListener, true) + } + } + + /** + * If the recognizer is currently listening, stops listening. Otherwise does nothing. + */ + override fun stopListening() { + when (val s = _state.value) { + is Listening -> stopListening(s.sessions, s.eventListener, true) + else -> {} + } + } + + /** + * Downloads all model files (encoder, decoder+joint, preprocessor, vocab). Sets the state to + * [Downloading], and periodically updates it with downloading progress, until either + * [ErrorDownloading] or [NotLoaded] are set as state. 
+ */ + private fun download(modelUrl: String) { + _state.value = Downloading(Progress.UNKNOWN) + + operationsJob = scope.launch(Dispatchers.IO) { + try { + downloadBinaryFilesWithPartial( + urlsFiles = listOf( + FileToDownload( + "$modelUrl/resolve/main/encoder-model.int8.onnx", + encoderFile, + ), + FileToDownload( + "$modelUrl/resolve/main/decoder_joint-model.int8.onnx", + decoderJointFile, + ), + FileToDownload( + "$modelUrl/resolve/main/nemo128.onnx", + preprocessorFile, + ), + FileToDownload( + "$modelUrl/resolve/main/vocab.txt", + vocabFile, + ), + ), + httpClient = okHttpClient, + cacheDir = cacheDir, + ) { progress -> + _state.value = Downloading(progress) + } + + } catch (e: IOException) { + Log.e(TAG, "Can't download Parakeet model", e) + _state.value = ErrorDownloading(modelUrl, e) + return@launch + } + + // Write the base model URL so init() can detect the model is downloaded on + // next app restart (must match what init() compares against). + sameModelUrlCheck.writeText(modelUrl) + _state.value = NotLoaded + } + } + + /** + * Loads the ONNX Runtime sessions for the encoder, decoder+joint, and preprocessor models. + * Also reads the vocabulary file. Initially sets the state to [Loading] with + * [Loading.thenStartListening] = ([thenStartListeningEventListener] != `null`), and later + * either sets the state to [Loaded] or calls [startListening] by checking the current state's + * [Loading.thenStartListening] (which might have changed in the meantime, if the user clicked + * on the button while loading). + */ + private fun load(thenStartListeningEventListener: ((InputEvent) -> Unit)?) 
{ + _state.value = Loading(thenStartListeningEventListener) + + operationsJob = scope.launch { + val sessions: ParakeetSessions + try { + val env = OrtEnvironment.getEnvironment() + val sessionOptions = OrtSession.SessionOptions().apply { + // Use NNAPI on supported devices for hardware acceleration + try { + addNnapi() + } catch (_: Exception) { + Log.d(TAG, "NNAPI not available, using CPU") + } + setIntraOpNumThreads( + Runtime.getRuntime().availableProcessors().coerceAtMost(4) + ) + } + + val encoder = env.createSession( + encoderFile.absolutePath, sessionOptions + ) + val decoderJoint = env.createSession( + decoderJointFile.absolutePath, sessionOptions + ) + val preprocessor = env.createSession( + preprocessorFile.absolutePath, sessionOptions + ) + + // Read vocab.txt: each line is "token id", where U+2581 represents space + val vocab = loadVocab(vocabFile) + + sessions = ParakeetSessions( + encoder, decoderJoint, preprocessor, vocab, env + ) + } catch (e: Exception) { + Log.e(TAG, "Can't load Parakeet model", e) + _state.value = ErrorLoading(e) + return@launch + } + + if (!_state.compareAndSet(Loading(null), Loaded(sessions))) { + val state = _state.value + if (state is Loading && state.thenStartListening != null) { + // "state is Loading" will always be true except when the load() is being + // joined by init(). + // "state.thenStartListening" might be "null" if, in the brief moment between + // the compareAndSet() and reading _state.value, the state was changed by + // toggleThenStartListening(). + startListening(sessions, state.thenStartListening) + + } else if (!_state.compareAndSet(Loading(null, true), Loaded(sessions))) { + // The current state is not the Loading state, which is unexpected. This means + // that load() is being joined by init(), which is reinitializing everything, + // so we should drop the sessions. + sessions.close() + } + + } // else, the state was set to Loaded, so no need to do anything + } + } + + /** + * Reads the vocabulary file. 
Each line contains a token and its ID separated by a space. + * The Unicode block element U+2581 is replaced with a regular space. The special `` + * token is recorded as the blank index. + */ + private fun loadVocab(file: File): Map { + val vocab = mutableMapOf() + BufferedReader(InputStreamReader(file.inputStream(), Charsets.UTF_8)).use { reader -> + reader.forEachLine { line -> + val parts = line.trimEnd().split(" ", limit = 2) + if (parts.size == 2) { + val token = parts[0].replace("\u2581", " ") + val id = parts[1].toIntOrNull() ?: return@forEachLine + vocab[id] = token + } + } + } + return vocab + } + + /** + * Atomically handles toggling the [Loading.thenStartListening] state, making sure that if in + * the meantime the value is changed by [load], the user click is not wasted, and the state + * machine does not end up in an inconsistent state. + * + * @param eventListener used only if the model has finished loading in the brief moment between + * when the state is first checked, but if the state was switched to [Loaded] (and not + * [Listening]), which means that this click should start listening. + */ + private fun toggleThenStartListening(eventListener: (InputEvent) -> Unit) { + if ( + !_state.compareAndSet(Loading(null), Loading(eventListener)) && + !_state.compareAndSet(Loading(eventListener), Loading(null)) + ) { + // may happen if load() changes the state in the brief moment between when the state is + // first checked before calling this function, and when the checks above are performed + Log.w(TAG, "Cannot toggle thenStartListening") + when (val newValue = _state.value) { + is Loaded -> startListening(newValue.sessions, eventListener) + is Listening -> stopListening(newValue.sessions, newValue.eventListener, true) + is ErrorLoading -> {} // ignore the user's click + // the else should never happen, since load() only transitions from Loading(...) 
to + // one of Loaded, Listening or ErrorLoading + else -> Log.e(TAG, "State was none of Loading, Loaded or Listening") + } + } + } + + /** + * Starts listening for audio input, and changes the state to [Listening]. + */ + private fun startListening( + sessions: ParakeetSessions, + eventListener: (InputEvent) -> Unit, + ) { + clearTransientUiState() + activeListener?.stopAndDiscardCurrentAudio() + _state.value = Listening(sessions, eventListener) + val listener = ParakeetListener( + this@ParakeetInputDevice, + eventListener, + DEFAULT_SILENCES_BEFORE_STOP, + sessions, + ) + activeListener = listener + + val job = scope.launch { + try { + listener.startRecording() + } finally { + if (activeListener === listener) { + activeListener = null + } + } + } + listeningJob = job + job.invokeOnCompletion { + if (listeningJob === job) { + listeningJob = null + } + } + } + + /** + * Stops listening for audio input, and changes the state to [Loaded]. This is + * `internal` because it is used by [ParakeetListener]. + */ + internal fun stopListening( + sessions: ParakeetSessions, + eventListener: (InputEvent) -> Unit, + sendNoneEvent: Boolean, + ) { + if (sendNoneEvent) { + activeListener?.stopAndDiscardCurrentAudio() + activeListener = null + } + clearTransientUiState() + _state.value = Loaded(sessions) + if (sendNoneEvent) { + eventListener(InputEvent.None) + } + } + + internal fun setTransientUiState(state: SttState) { + _transientUiState.value = state + } + + internal fun clearTransientUiState() { + _transientUiState.value = null + } + + override suspend fun destroy() { + deinit() + // cancel everything + scope.cancel() + } + + companion object { + private val TAG = ParakeetInputDevice::class.simpleName + private const val DEFAULT_SILENCES_BEFORE_STOP = 1 + + /** + * Base URL for the pre-quantized ONNX model from + * [istupakov/parakeet-tdt-0.6b-v3-onnx](https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx). 
+ * Parakeet v3 is a single multilingual model that auto-detects the spoken language, + * so all locale keys point to the same HuggingFace repository. Individual model files + * (encoder, decoder_joint, preprocessor, vocab) are resolved relative to this URL + * during download. + * + * INT8-quantized files used: + * - `encoder-model.int8.onnx` (~652 MB) + * - `decoder_joint-model.int8.onnx` (~18 MB) + * - `nemo128.onnx` (~140 KB, mel-spectrogram preprocessor) + * - `vocab.txt` (~94 KB) + * + * Supported languages (25 European languages): + * bg, hr, cs, da, nl, en, et, fi, fr, de, el, hu, it, lv, lt, mt, pl, pt, ro, sk, + * sl, es, sv, ru, uk + * + * @see NVIDIA model card + * @see ONNX conversion + */ + private const val PARAKEET_MODEL_BASE_URL = + "https://huggingface.co/istupakov/parakeet-tdt-0.6b-v3-onnx" + + val MODEL_URLS = mapOf( + "bg" to PARAKEET_MODEL_BASE_URL, + "hr" to PARAKEET_MODEL_BASE_URL, + "cs" to PARAKEET_MODEL_BASE_URL, + "da" to PARAKEET_MODEL_BASE_URL, + "nl" to PARAKEET_MODEL_BASE_URL, + "en" to PARAKEET_MODEL_BASE_URL, + "et" to PARAKEET_MODEL_BASE_URL, + "fi" to PARAKEET_MODEL_BASE_URL, + "fr" to PARAKEET_MODEL_BASE_URL, + "de" to PARAKEET_MODEL_BASE_URL, + "el" to PARAKEET_MODEL_BASE_URL, + "hu" to PARAKEET_MODEL_BASE_URL, + "it" to PARAKEET_MODEL_BASE_URL, + "lv" to PARAKEET_MODEL_BASE_URL, + "lt" to PARAKEET_MODEL_BASE_URL, + "mt" to PARAKEET_MODEL_BASE_URL, + "pl" to PARAKEET_MODEL_BASE_URL, + "pt" to PARAKEET_MODEL_BASE_URL, + "ro" to PARAKEET_MODEL_BASE_URL, + "sk" to PARAKEET_MODEL_BASE_URL, + "sl" to PARAKEET_MODEL_BASE_URL, + "es" to PARAKEET_MODEL_BASE_URL, + "sv" to PARAKEET_MODEL_BASE_URL, + "ru" to PARAKEET_MODEL_BASE_URL, + "uk" to PARAKEET_MODEL_BASE_URL, + ) + } +} diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetListener.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetListener.kt new file mode 100644 index 000000000..23ecfbf05 --- /dev/null +++ 
b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetListener.kt @@ -0,0 +1,574 @@ +/* + * Taken from /e/OS Assistant + * + * Copyright (C) 2024 MURENA SAS + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package org.stypox.dicio.io.input.parakeet + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.TensorInfo +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.util.Log +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.isActive +import kotlinx.coroutines.withContext +import org.stypox.dicio.io.input.InputEvent +import org.stypox.dicio.io.input.SttState +import java.nio.FloatBuffer +import java.nio.IntBuffer +import java.nio.LongBuffer + +/** + * Handles audio recording and speech-to-text inference for the Parakeet TDT model using + * ONNX Runtime. The inference pipeline mirrors the reference implementation from + * [onnx-asr](https://github.com/istupakov/onnx-asr): + * + * 1. Record 16 kHz mono PCM audio from the microphone. + * 2. Run the NeMo 128-dim mel-spectrogram preprocessor (`nemo128.onnx`). + * 3. Run the FastConformer encoder (`encoder-model.int8.onnx`). + * 4. Run TDT (Token-and-Duration Transducer) greedy decoding with the joint decoder + * (`decoder_joint-model.int8.onnx`). + * 5. Map token IDs to text via `vocab.txt`. 
+ * + * @param parakeetInputDevice the parent input device for state management + * @param eventListener callback to receive transcription events + * @param silencesBeforeStop how many consecutive silence chunks before auto-stopping; must be >= 1 + * @param sessions the loaded ONNX Runtime sessions and vocabulary + */ +internal class ParakeetListener( + private val parakeetInputDevice: ParakeetInputDevice, + private val eventListener: (InputEvent) -> Unit, + private var silencesBeforeStop: Int, + private val sessions: ParakeetSessions, +) { + + private var audioRecord: AudioRecord? = null + @Volatile + private var isRecording = false + @Volatile + private var shouldProcessFinalAudio = true + @Volatile + private var stopRequestedByUser = false + + /** + * Starts recording audio and processing it with the Parakeet model. + * + * Parakeet does not support true streaming: we record until speech-end (silence) and then run + * a single final inference. To keep the user informed, the UI transitions through transient + * phases: [SttState.Listening] -> [SttState.SilenceDetected] -> [SttState.Thinking]. 
+ */ + suspend fun startRecording() = withContext(Dispatchers.IO) { + try { + shouldProcessFinalAudio = true + stopRequestedByUser = false + + val bufferSize = AudioRecord.getMinBufferSize( + SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) + + if (bufferSize == AudioRecord.ERROR || bufferSize == AudioRecord.ERROR_BAD_VALUE) { + Log.e(TAG, "Invalid buffer size: $bufferSize") + eventListener(InputEvent.Error(Exception("Invalid audio buffer size"))) + return@withContext + } + + audioRecord = AudioRecord( + MediaRecorder.AudioSource.MIC, + SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + bufferSize + ) + + if (audioRecord?.state != AudioRecord.STATE_INITIALIZED) { + Log.e(TAG, "AudioRecord initialization failed") + eventListener(InputEvent.Error(Exception("AudioRecord initialization failed"))) + return@withContext + } + + audioRecord?.startRecording() + isRecording = true + Log.i(TAG, "Recording started, bufferSize=$bufferSize") + + // Thread-safe audio accumulator: the recording coroutine appends, while the + // partial-inference coroutine takes snapshots. 
+ val audioData = mutableListOf() + val audioLock = Any() + var hasHeardSpeech = false + var silenceDetected = false + var speechSamples = 0 + + // Recording loop in the current (IO) coroutine + val audioBuffer = ShortArray(bufferSize / 2) + var consecutiveSilentReads = 0 + + while (isRecording && isActive) { + val readSize = audioRecord?.read(audioBuffer, 0, audioBuffer.size) ?: 0 + + if (readSize > 0) { + synchronized(audioLock) { + for (i in 0 until readSize) { + audioData.add(audioBuffer[i]) + } + } + + // Compute RMS amplitude of this chunk for silence detection + val rms = kotlin.math.sqrt( + audioBuffer.take(readSize) + .sumOf { it.toLong() * it.toLong() } + .toDouble() / readSize + ) + val isSilent = rms < SILENCE_RMS_THRESHOLD + + if (!isSilent) { + speechSamples += readSize + hasHeardSpeech = + speechSamples >= MIN_SPEECH_SAMPLES_BEFORE_AUTO_STOP + consecutiveSilentReads = 0 + } else { + consecutiveSilentReads++ + } + + // Only consider silence-based stop after speech has been heard + val samplesPerSilenceUnit = (SILENCE_DURATION_MS * SAMPLE_RATE / 1000) + val readsPerSilenceUnit = + (samplesPerSilenceUnit / readSize).coerceAtLeast(1) + val requiredSilentReads = readsPerSilenceUnit * silencesBeforeStop + + if (hasHeardSpeech && consecutiveSilentReads >= requiredSilentReads) { + val totalSamples = synchronized(audioLock) { audioData.size } + Log.i(TAG, "Silence detected after speech, " + + "$totalSamples samples " + + "(${totalSamples / SAMPLE_RATE.toFloat()}s)") + silenceDetected = true + break + } + + // Safety cap: stop after MAX_RECORDING_SECONDS + val totalSamples = synchronized(audioLock) { audioData.size } + if (totalSamples >= SAMPLE_RATE * MAX_RECORDING_SECONDS) { + Log.i(TAG, "Max recording duration reached") + break + } + } + } + + // Stop recording hardware + stopRecording() + val finalAudio: ShortArray + synchronized(audioLock) { + finalAudio = audioData.toShortArray() + } + Log.i(TAG, "Recording stopped, total samples: ${finalAudio.size} " + 
+ "(${finalAudio.size / SAMPLE_RATE.toFloat()}s), hasHeardSpeech=$hasHeardSpeech") + + if (!shouldProcessFinalAudio) { + Log.i(TAG, "Recording stopped by user, skipping final inference") + return@withContext + } + + if (silenceDetected) { + parakeetInputDevice.setTransientUiState(SttState.SilenceDetected) + delay(SILENCE_DETECTED_FEEDBACK_MS) + } + processAudio(finalAudio) + + } catch (e: Exception) { + if (stopRequestedByUser) { + Log.i(TAG, "Recording stopped by user while reading audio") + return@withContext + } + Log.e(TAG, "Error during recording", e) + eventListener(InputEvent.Error(e)) + stopRecording() + } + } + + /** + * Processes the final audio data and generates the final transcription result. + * After emitting the result, transitions the input device state back to [ParakeetState.Loaded] + * so the UI shows the mic button again (matching Vosk's behavior in onResult). + */ + private fun processAudio(audioData: ShortArray) { + // Transition state from Listening → Loaded *before* emitting the event, so the UI + // updates promptly. Pass sendNoneEvent = false because we emit our own event below. 
+ parakeetInputDevice.stopListening(sessions, eventListener, false) + parakeetInputDevice.setTransientUiState(SttState.Thinking) + + try { + if (audioData.isEmpty()) { + Log.w(TAG, "processAudio: empty audio data") + eventListener(InputEvent.None) + return + } + + Log.i(TAG, "processAudio: running inference on ${audioData.size} samples " + + "(${audioData.size / SAMPLE_RATE.toFloat()}s)") + val startTime = System.currentTimeMillis() + val result = runInference(audioData) + val elapsed = System.currentTimeMillis() - startTime + Log.i(TAG, "processAudio: inference took ${elapsed}ms, result=\"$result\"") + + if (result.isBlank()) { + eventListener(InputEvent.None) + } else { + eventListener(InputEvent.Final(listOf(Pair(result, 1.0f)))) + } + } catch (e: Exception) { + Log.e(TAG, "Error processing final audio", e) + eventListener(InputEvent.Error(e)) + } finally { + parakeetInputDevice.clearTransientUiState() + } + } + + /** + * Runs the full Parakeet inference pipeline on the given PCM audio data: + * preprocessor -> encoder -> TDT greedy decode -> vocab lookup. + */ + private fun runInference(audioData: ShortArray): String { + // 1. Convert PCM int16 to float32 normalized to [-1.0, 1.0] + val floatAudio = FloatArray(audioData.size) { i -> + audioData[i].toFloat() / Short.MAX_VALUE + } + + // 2. Run the nemo128 mel-spectrogram preprocessor + var stepStart = System.currentTimeMillis() + val (features, featuresLens) = runPreprocessor(floatAudio) + Log.i(TAG, "Preprocessor: ${System.currentTimeMillis() - stepStart}ms, " + + "features size=${features.size}") + + // 3. Run the FastConformer encoder + stepStart = System.currentTimeMillis() + val (encoderOut, encoderOutLens) = runEncoder(features, featuresLens) + Log.i(TAG, "Encoder: ${System.currentTimeMillis() - stepStart}ms, " + + "T=${encoderOut.size}, D=${encoderOut.firstOrNull()?.size ?: 0}, len=$encoderOutLens") + + // 4. 
Run TDT greedy decoding
        stepStart = System.currentTimeMillis()
        val tokenIds = tdtGreedyDecode(encoderOut, encoderOutLens)
        Log.i(TAG, "TDT decode: ${System.currentTimeMillis() - stepStart}ms, " +
            "tokens=${tokenIds.size}")

        // 5. Map token IDs to text
        return decodeTokens(tokenIds)
    }

    /**
     * Runs the NeMo 128-dim mel-spectrogram preprocessor.
     *
     * Input: `waveforms` [1, N] float32, `waveforms_lens` [1] int64
     * Output: `features` [1, 128, T] float32, `features_lens` [1] int64
     *
     * @param waveform mono PCM audio normalized to [-1.0, 1.0]
     * @return the flattened [1, 128, T] feature data and the `features_lens` output
     */
    private fun runPreprocessor(
        waveform: FloatArray
    ): Pair<FloatArray, LongArray> {
        val env = sessions.env
        val waveformTensor = OnnxTensor.createTensor(
            env,
            FloatBuffer.wrap(waveform),
            longArrayOf(1, waveform.size.toLong()),
        )
        val waveformLensTensor = OnnxTensor.createTensor(
            env,
            LongBuffer.wrap(longArrayOf(waveform.size.toLong())),
            longArrayOf(1),
        )

        try {
            // OrtSession.Result is AutoCloseable and owns the output tensors;
            // use {} frees them even when extraction throws.
            sessions.preprocessor.run(
                mapOf("waveforms" to waveformTensor, "waveforms_lens" to waveformLensTensor)
            ).use { result ->
                val featuresTensor = result["features"].get() as OnnxTensor
                val featuresLensTensor = result["features_lens"].get() as OnnxTensor

                val featuresFlat = featuresTensor.floatBuffer.let { buf ->
                    FloatArray(buf.remaining()).also { buf.get(it) }
                }
                val featuresLens = featuresLensTensor.longBuffer.let { buf ->
                    LongArray(buf.remaining()).also { buf.get(it) }
                }
                return Pair(featuresFlat, featuresLens)
            }
        } finally {
            // Close the input tensors even when inference throws: the previous version
            // only closed them on the success path, leaking native memory on errors.
            waveformTensor.close()
            waveformLensTensor.close()
        }
    }

    /**
     * Runs the FastConformer encoder.
     *
     * Input: `audio_signal` [1, 128, T] float32, `length` [1] int64
     * Output: `outputs` [1, D, T'] float32, `encoded_lengths` [1] int64
     *
     * The encoder output is transposed from [1, D, T'] to [T', D] for the decoder.
+ */ + private fun runEncoder( + features: FloatArray, + featuresLens: LongArray + ): Pair, Long> { + val env = sessions.env + + // Reconstruct shape: the preprocessor output is [1, 128, T] + val totalElements = features.size + val featureDim = 128L + val timeSteps = totalElements / featureDim + + val featuresTensor = OnnxTensor.createTensor( + env, + FloatBuffer.wrap(features), + longArrayOf(1, featureDim, timeSteps), + ) + val lengthTensor = OnnxTensor.createTensor( + env, + LongBuffer.wrap(featuresLens), + longArrayOf(1), + ) + + val result = sessions.encoder.run( + mapOf("audio_signal" to featuresTensor, "length" to lengthTensor) + ) + + val outputsTensor = result["outputs"].get() as OnnxTensor + val encodedLengthsTensor = result["encoded_lengths"].get() as OnnxTensor + + // outputs shape: [1, D, T'] — need to transpose to [T', D] + val outShape = outputsTensor.info.shape // [1, D, T'] + val outD = outShape[1].toInt() + val outT = outShape[2].toInt() + val outFlat = outputsTensor.floatBuffer.let { buf -> + FloatArray(buf.remaining()).also { buf.get(it) } + } + val encodedLength = encodedLengthsTensor.longBuffer.get(0) + + // Transpose from [1, D, T'] (row-major) to Array [T'][D] + val transposed = Array(outT) { t -> + FloatArray(outD) { d -> outFlat[d * outT + t] } + } + + featuresTensor.close() + lengthTensor.close() + result.close() + + return Pair(transposed, encodedLength) + } + + /** + * TDT (Token-and-Duration Transducer) greedy decoding. At each encoder time step, the + * decoder+joint network produces logits for both the token vocabulary and the duration + * (how many encoder frames to skip). This follows the reference implementation from + * `onnx-asr` (`NemoConformerTdt._decode` and `_AsrWithTransducerDecoding._decoding`). 
+ * + * @param encoderOut encoder output, shape [T, D] (already transposed) + * @param encodedLength number of valid encoder time steps + * @return list of decoded token IDs (excluding blanks) + */ + private fun tdtGreedyDecode( + encoderOut: Array, + encodedLength: Long, + ): List { + val env = sessions.env + val blankIdx = sessions.blankIdx + val vocabSize = sessions.vocabSize + + // Initialize decoder LSTM states to zero, following the reference implementation + // (NemoConformerRnnt._create_state): use dim[0] and dim[2] from the model's + // input metadata (which are static), and hardcode dim[1]=1 (the batch/sequence + // dimension, which is dynamic=-1 in the ONNX model and would cause + // NegativeArraySizeException if multiplied). + val decoderInputs = sessions.decoderJoint.inputInfo + val state1ModelShape = (decoderInputs["input_states_1"]?.info as? TensorInfo)?.shape + val state2ModelShape = (decoderInputs["input_states_2"]?.info as? TensorInfo)?.shape + + val state1Shape = if (state1ModelShape != null && state1ModelShape.size == 3) { + longArrayOf(state1ModelShape[0], 1, state1ModelShape[2]) + } else { + longArrayOf(2, 1, 640) + } + val state2Shape = if (state2ModelShape != null && state2ModelShape.size == 3) { + longArrayOf(state2ModelShape[0], 1, state2ModelShape[2]) + } else { + longArrayOf(2, 1, 640) + } + + var state1 = FloatArray((state1Shape[0] * state1Shape[1] * state1Shape[2]).toInt()) + var state2 = FloatArray((state2Shape[0] * state2Shape[1] * state2Shape[2]).toInt()) + + val tokens = mutableListOf() + var t = 0 + var emittedTokens = 0 + val maxT = encodedLength.toInt().coerceAtMost(encoderOut.size) + + while (t < maxT) { + // Prepare the encoder output for this time step: [1, D, 1] + val encoderFrame = encoderOut[t] + val d = encoderFrame.size + + val encoderOutputsTensor = OnnxTensor.createTensor( + env, + FloatBuffer.wrap(encoderFrame), + longArrayOf(1, d.toLong(), 1), + ) + + // Target: the last emitted token, or blank if none emitted yet + 
val targetToken = if (tokens.isNotEmpty()) tokens.last() else blankIdx + val targetsTensor = OnnxTensor.createTensor( + env, + IntBuffer.wrap(intArrayOf(targetToken)), + longArrayOf(1, 1), + ) + val targetLengthTensor = OnnxTensor.createTensor( + env, + IntBuffer.wrap(intArrayOf(1)), + longArrayOf(1), + ) + val states1Tensor = OnnxTensor.createTensor( + env, + FloatBuffer.wrap(state1), + state1Shape, + ) + val states2Tensor = OnnxTensor.createTensor( + env, + FloatBuffer.wrap(state2), + state2Shape, + ) + + val result = sessions.decoderJoint.run( + mapOf( + "encoder_outputs" to encoderOutputsTensor, + "targets" to targetsTensor, + "target_length" to targetLengthTensor, + "input_states_1" to states1Tensor, + "input_states_2" to states2Tensor, + ) + ) + + val outputsTensor = result["outputs"].get() as OnnxTensor + val outState1Tensor = result["output_states_1"].get() as OnnxTensor + val outState2Tensor = result["output_states_2"].get() as OnnxTensor + + val outputs = outputsTensor.floatBuffer.let { buf -> + FloatArray(buf.remaining()).also { buf.get(it) } + } + val newState1 = outState1Tensor.floatBuffer.let { buf -> + FloatArray(buf.remaining()).also { buf.get(it) } + } + val newState2 = outState2Tensor.floatBuffer.let { buf -> + FloatArray(buf.remaining()).also { buf.get(it) } + } + + // TDT: first vocabSize elements are token logits, rest are duration logits + val tokenLogits = outputs.copyOfRange(0, vocabSize) + val durationLogits = outputs.copyOfRange(vocabSize, outputs.size) + + // Greedy: pick the token with the highest logit + val token = tokenLogits.indices.maxByOrNull { tokenLogits[it] } ?: blankIdx + + // Greedy: pick the duration with the highest logit + val step = if (durationLogits.isNotEmpty()) { + durationLogits.indices.maxByOrNull { durationLogits[it] } ?: 0 + } else { + -1 // fallback for plain RNN-T (no duration head) + } + + if (token != blankIdx) { + // Non-blank emission: update decoder states and record the token + state1 = newState1 + state2 = 
newState2 + tokens.add(token) + emittedTokens++ + } + + // Advance the time step based on the duration prediction + if (step > 0) { + t += step + emittedTokens = 0 + } else if (token == blankIdx || emittedTokens >= MAX_TOKENS_PER_STEP) { + t += 1 + emittedTokens = 0 + } + + // Clean up ORT tensors for this iteration + encoderOutputsTensor.close() + targetsTensor.close() + targetLengthTensor.close() + states1Tensor.close() + states2Tensor.close() + result.close() + } + + return tokens + } + + /** + * Maps a list of token IDs to their string representations using the vocabulary, + * then joins and trims the result. + */ + private fun decodeTokens(tokenIds: List): String { + return tokenIds.mapNotNull { sessions.vocab[it] } + .joinToString("") + .trim() + } + + /** + * Stops the audio recording. + */ + fun stopRecording() { + isRecording = false + try { + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + } catch (e: Exception) { + Log.e(TAG, "Error stopping recording", e) + } + } + + fun stopAndDiscardCurrentAudio() { + stopRequestedByUser = true + shouldProcessFinalAudio = false + stopRecording() + } + + companion object { + private val TAG = ParakeetListener::class.simpleName + /** Parakeet models use 16 kHz sample rate. */ + private const val SAMPLE_RATE = 16000 + /** RMS amplitude threshold for silence detection (on int16 PCM samples). */ + private const val SILENCE_RMS_THRESHOLD = 300.0 + /** Duration in milliseconds that one "silence unit" represents. */ + private const val SILENCE_DURATION_MS = 1000 + /** Maximum recording duration in seconds (safety cap). */ + private const val MAX_RECORDING_SECONDS = 30 + /** Maximum tokens the TDT decoder may emit per encoder time step. */ + private const val MAX_TOKENS_PER_STEP = 10 + /** How long to show "Silence detected" before switching to "Thinking". */ + private const val SILENCE_DETECTED_FEEDBACK_MS = 800L + /** + * Minimum amount of detected speech before auto-stopping on silence. 
+ * Helps avoid false early stop from brief noise spikes. + */ + private const val MIN_SPEECH_SAMPLES_BEFORE_AUTO_STOP = SAMPLE_RATE * 2 / 5 + } +} diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetState.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetState.kt new file mode 100644 index 000000000..fc629bd0f --- /dev/null +++ b/app/src/main/kotlin/org/stypox/dicio/io/input/parakeet/ParakeetState.kt @@ -0,0 +1,162 @@ +/* + * Taken from /e/OS Assistant + * + * Copyright (C) 2024 MURENA SAS + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package org.stypox.dicio.io.input.parakeet + +import org.stypox.dicio.io.input.InputEvent +import org.stypox.dicio.io.input.SttState +import org.stypox.dicio.ui.util.Progress + +/** + * The internal state for [ParakeetInputDevice]. This is an enum with different fields depending on + * the current state, to avoid having nullable objects all over the place in [ParakeetInputDevice]. + * [SttState] is symmetrical to this enum, except that it does not expose implementation-defined + * fields to the UI. 
+ */ +sealed interface ParakeetState { + + /** + * The ParakeetInputDevice has not been initialized yet, or has just been deinitialized + */ + data object NotInitialized : ParakeetState + + /** + * The model is not available for the current locale + */ + data object NotAvailable : ParakeetState + + /** + * The model is not present on disk. + */ + data class NotDownloaded( + val modelUrl: String + ) : ParakeetState + + data class Downloading( + val progress: Progress, + ) : ParakeetState + + data class ErrorDownloading( + val modelUrl: String, + val throwable: Throwable + ) : ParakeetState + + data object Downloaded : ParakeetState + + /** + * The model is present on disk, but was not loaded in RAM yet. + */ + data object NotLoaded : ParakeetState + + /** + * The model is being loaded, and the nullity of [thenStartListening] indicates whether once + * loading is finished, the STT should start listening right away. + * [shouldEqualAnyLoading] is used just to create a [Loading] object with compares equal to any + * other [Loading], but [Loading] with [shouldEqualAnyLoading]` = true` will never appear as a + * state. + */ + data class Loading( + val thenStartListening: ((InputEvent) -> Unit)?, + val shouldEqualAnyLoading: Boolean = false, + ) : ParakeetState { + override fun equals(other: Any?): Boolean { + if (other !is Loading) + return false + if (shouldEqualAnyLoading || other.shouldEqualAnyLoading) + return true + return (this.thenStartListening == null) == (other.thenStartListening == null) + } + + override fun hashCode(): Int { + return if (thenStartListening == null) 0 else 1; + } + } + + data class ErrorLoading( + val throwable: Throwable + ) : ParakeetState + + /** + * The model is ready in RAM, and can start listening at any time. + */ + data class Loaded( + internal val sessions: ParakeetSessions + ) : ParakeetState + + /** + * The model is listening. 
+ */ + data class Listening( + internal val sessions: ParakeetSessions, + internal val eventListener: (InputEvent) -> Unit, + ) : ParakeetState + + /** + * Converts this [ParakeetState] to a [SttState], which is basically the same, except that + * implementation-defined fields (e.g. [ParakeetSessions]) are stripped away. + */ + fun toUiState(): SttState { + return when (this) { + NotInitialized -> SttState.NotInitialized + NotAvailable -> SttState.NotAvailable + is NotDownloaded -> SttState.NotDownloaded + is Downloading -> SttState.Downloading(progress) + is ErrorDownloading -> SttState.ErrorDownloading(throwable) + Downloaded -> SttState.Downloaded + NotLoaded -> SttState.NotLoaded + is Loading -> SttState.Loading(thenStartListening != null) + is ErrorLoading -> SttState.ErrorLoading(throwable) + is Loaded -> SttState.Loaded + is Listening -> SttState.Listening + } + } +} + +/** + * Holds all ONNX Runtime inference sessions needed for Parakeet TDT inference: + * the NeMo mel-spectrogram preprocessor, the FastConformer encoder, and the + * RNN-T/TDT joint decoder, plus the decoded vocabulary map. + * + * @param encoder the FastConformer encoder session (`encoder-model.int8.onnx`) + * @param decoderJoint the joint decoder session (`decoder_joint-model.int8.onnx`) + * @param preprocessor the mel-spectrogram preprocessor session (`nemo128.onnx`) + * @param vocab mapping from token ID to decoded string (from `vocab.txt`) + * @param env the shared ONNX Runtime environment + */ +data class ParakeetSessions( + internal val encoder: ai.onnxruntime.OrtSession, + internal val decoderJoint: ai.onnxruntime.OrtSession, + internal val preprocessor: ai.onnxruntime.OrtSession, + internal val vocab: Map, + internal val env: ai.onnxruntime.OrtEnvironment, +) { + /** The blank token ID, used as the initial decoder target and to detect blank emissions. 
*/ + val blankIdx: Int = vocab.entries.firstOrNull { it.value == "" }?.key + ?: (vocab.size - 1) + + /** Total vocabulary size (including the blank token). */ + val vocabSize: Int = vocab.size + + fun close() { + preprocessor.close() + encoder.close() + decoderJoint.close() + env.close() + } +} diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/scribe/ScribeRealtimeInputDevice.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/scribe/ScribeRealtimeInputDevice.kt new file mode 100644 index 000000000..922ce6ab1 --- /dev/null +++ b/app/src/main/kotlin/org/stypox/dicio/io/input/scribe/ScribeRealtimeInputDevice.kt @@ -0,0 +1,397 @@ +package org.stypox.dicio.io.input.scribe + +import android.media.AudioFormat +import android.media.AudioRecord +import android.media.MediaRecorder +import android.net.Uri +import android.util.Base64 +import android.util.Log +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.Response +import okhttp3.WebSocket +import okhttp3.WebSocketListener +import okio.ByteString +import org.json.JSONObject +import org.stypox.dicio.di.LocaleManager +import org.stypox.dicio.io.input.InputEvent +import org.stypox.dicio.io.input.SttInputDevice +import org.stypox.dicio.io.input.SttState +import org.stypox.dicio.util.distinctUntilChangedBlockingFirst +import java.util.Locale + +class ScribeRealtimeInputDevice( + private val okHttpClient: OkHttpClient, + localeManager: LocaleManager, + private val apiKey: StateFlow, + private val silencesBeforeStop: StateFlow, +) : SttInputDevice { + + private val scope = CoroutineScope(Dispatchers.Default) + + private val _uiState: MutableStateFlow + override val uiState: StateFlow + + @Volatile + 
private var currentLanguageCode: String + @Volatile + private var currentApiKey: String + + @Volatile + private var webSocket: WebSocket? = null + @Volatile + private var audioRecord: AudioRecord? = null + @Volatile + private var shouldRecordAudio = false + + private var activeEventListener: ((InputEvent) -> Unit)? = null + private var audioStreamingJob: Job? = null + + init { + val (firstLocale, nextLocaleFlow) = localeManager.locale + .distinctUntilChangedBlockingFirst() + currentLanguageCode = languageCodeFromLocale(firstLocale) + currentApiKey = apiKey.value.trim() + + val initialState = readyStateFromApiKey(currentApiKey) + _uiState = MutableStateFlow(initialState) + uiState = _uiState + + scope.launch { + nextLocaleFlow.collect { locale -> + currentLanguageCode = languageCodeFromLocale(locale) + } + } + + scope.launch { + apiKey.collect { key -> + currentApiKey = key.trim() + if (_uiState.value != SttState.Listening) { + _uiState.value = readyStateFromApiKey(currentApiKey) + } + } + } + } + + override fun tryLoad(thenStartListeningEventListener: ((InputEvent) -> Unit)?): Boolean { + if (currentApiKey.isBlank()) { + _uiState.value = SttState.NotAvailable + return false + } + if (thenStartListeningEventListener == null) { + _uiState.value = SttState.Loaded + return true + } + + startListening(thenStartListeningEventListener) + return true + } + + override fun onClick(eventListener: (InputEvent) -> Unit) { + when (_uiState.value) { + SttState.Listening -> stopListeningInternal(sendNoneEvent = true) + else -> { + if (currentApiKey.isBlank()) { + _uiState.value = SttState.NotAvailable + return + } + startListening(eventListener) + } + } + } + + override fun stopListening() { + stopListeningInternal(sendNoneEvent = true) + } + + private fun startListening(eventListener: (InputEvent) -> Unit) { + if (_uiState.value == SttState.Listening) { + return + } + + val apiKey = currentApiKey + if (apiKey.isBlank()) { + _uiState.value = SttState.NotAvailable + return + } + + 
activeEventListener = eventListener + _uiState.value = SttState.Listening + shouldRecordAudio = true + + try { + val request = Request.Builder() + .url(buildRealtimeUrl()) + .addHeader("xi-api-key", apiKey) + .build() + + webSocket = okHttpClient.newWebSocket(request, object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + Log.i(TAG, "Scribe realtime websocket opened") + } + + override fun onMessage(webSocket: WebSocket, text: String) { + handleServerMessage(webSocket, text) + } + + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + Log.w(TAG, "Ignoring unexpected binary message from Scribe realtime websocket") + } + + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + Log.e(TAG, "Scribe realtime websocket failure", t) + emitErrorAndStop(t) + } + + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + Log.i(TAG, "Scribe realtime websocket closed: code=$code reason=$reason") + if (_uiState.value == SttState.Listening) { + if (code == 1000) { + emitNoneAndStop() + } else { + emitErrorAndStop(Exception("Scribe connection closed: $code $reason")) + } + } + } + }) + } catch (e: Exception) { + Log.e(TAG, "Failed to initialize Scribe realtime websocket", e) + emitErrorAndStop(e) + } + } + + private fun handleServerMessage(webSocket: WebSocket, text: String) { + val message = try { + JSONObject(text) + } catch (e: Exception) { + Log.e(TAG, "Invalid Scribe JSON message: $text", e) + emitErrorAndStop(e) + return + } + + val messageType = message.optString("message_type", message.optString("type", "")) + when (messageType) { + "session_started" -> { + startAudioStreaming(webSocket) + } + "partial_transcript" -> { + val partial = message.optString("text", "") + if (partial.isNotBlank()) { + activeEventListener?.invoke(InputEvent.Partial(partial)) + } + } + "committed_transcript", + "committed_transcript_with_timestamps" -> { + val transcript = 
message.optString("text", "") + if (transcript.isNotBlank()) { + val eventListener = activeEventListener + stopListeningInternal(sendNoneEvent = false) + eventListener?.invoke(InputEvent.Final(listOf(Pair(transcript, 1.0f)))) + } + } + in ERROR_MESSAGE_TYPES -> { + val errorMessage = message.optString("error", message.optString("message", messageType)) + emitErrorAndStop(Exception("Scribe realtime error ($messageType): $errorMessage")) + } + } + } + + private fun startAudioStreaming(webSocket: WebSocket) { + if (audioStreamingJob?.isActive == true) { + return + } + + audioStreamingJob = scope.launch { + withContext(Dispatchers.IO) { + try { + val minBufferSize = AudioRecord.getMinBufferSize( + SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + ) + if (minBufferSize <= 0) { + throw Exception("Invalid AudioRecord min buffer size: $minBufferSize") + } + + val recordBufferSize = maxOf(minBufferSize, CHUNK_SAMPLES * BYTES_PER_SAMPLE * 4) + val recorder = AudioRecord( + MediaRecorder.AudioSource.MIC, + SAMPLE_RATE, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + recordBufferSize, + ) + audioRecord = recorder + + if (recorder.state != AudioRecord.STATE_INITIALIZED) { + throw Exception("AudioRecord initialization failed") + } + + recorder.startRecording() + val shortBuffer = ShortArray(CHUNK_SAMPLES) + + while (shouldRecordAudio && _uiState.value == SttState.Listening) { + val readSize = recorder.read(shortBuffer, 0, shortBuffer.size) + if (readSize <= 0) { + if (readSize == AudioRecord.ERROR_INVALID_OPERATION || + readSize == AudioRecord.ERROR_BAD_VALUE) { + throw Exception("AudioRecord read error: $readSize") + } + continue + } + + val pcmBytes = toLittleEndianPcm(shortBuffer, readSize) + val payload = JSONObject() + .put("message_type", "input_audio_chunk") + .put("audio_base_64", Base64.encodeToString(pcmBytes, Base64.NO_WRAP)) + .put("sample_rate", SAMPLE_RATE) + + if (!webSocket.send(payload.toString())) { + throw 
Exception("Failed to send audio chunk to Scribe realtime") + } + } + } catch (e: Exception) { + if (_uiState.value == SttState.Listening) { + emitErrorAndStop(e) + } + } finally { + releaseRecorder() + } + } + } + } + + private fun toLittleEndianPcm(samples: ShortArray, readSize: Int): ByteArray { + val bytes = ByteArray(readSize * BYTES_PER_SAMPLE) + for (i in 0 until readSize) { + val sample = samples[i].toInt() + bytes[i * 2] = (sample and 0xFF).toByte() + bytes[i * 2 + 1] = ((sample ushr 8) and 0xFF).toByte() + } + return bytes + } + + private fun emitErrorAndStop(throwable: Throwable) { + val eventListener = activeEventListener + stopListeningInternal(sendNoneEvent = false) + eventListener?.invoke(InputEvent.Error(throwable)) + } + + private fun emitNoneAndStop() { + val eventListener = activeEventListener + stopListeningInternal(sendNoneEvent = false) + eventListener?.invoke(InputEvent.None) + } + + @Synchronized + private fun stopListeningInternal(sendNoneEvent: Boolean) { + val eventListener = activeEventListener + val wasListening = _uiState.value == SttState.Listening + + if (wasListening) { + _uiState.value = readyStateFromApiKey(currentApiKey) + } + + activeEventListener = null + + shouldRecordAudio = false + audioStreamingJob?.cancel() + audioStreamingJob = null + + webSocket?.close(1000, "normal") + webSocket = null + + releaseRecorder() + + if (wasListening && sendNoneEvent) { + eventListener?.invoke(InputEvent.None) + } + } + + private fun releaseRecorder() { + try { + audioRecord?.stop() + } catch (_: Exception) { + } + try { + audioRecord?.release() + } catch (_: Exception) { + } + audioRecord = null + } + + private fun buildRealtimeUrl(): String { + val languageCode = currentLanguageCode.ifBlank { DEFAULT_LANGUAGE_CODE } + val vadSilenceThresholdSecs = silencesBeforeStop.value + .toDouble() + .coerceIn(MIN_VAD_SILENCE_THRESHOLD_SECS, MAX_VAD_SILENCE_THRESHOLD_SECS) + + return Uri.parse(REALTIME_BASE_URL) + .buildUpon() + 
.appendQueryParameter("model_id", MODEL_ID) + .appendQueryParameter("audio_format", AUDIO_FORMAT) + .appendQueryParameter("commit_strategy", COMMIT_STRATEGY) + .appendQueryParameter("vad_silence_threshold_secs", vadSilenceThresholdSecs.toString()) + .appendQueryParameter("language_code", languageCode) + .build() + .toString() + } + + private fun readyStateFromApiKey(apiKey: String): SttState { + return if (apiKey.isBlank()) { + SttState.NotAvailable + } else { + SttState.Loaded + } + } + + private fun languageCodeFromLocale(locale: Locale): String { + return locale.language.ifBlank { DEFAULT_LANGUAGE_CODE } + } + + override suspend fun destroy() { + stopListeningInternal(sendNoneEvent = false) + scope.cancel() + } + + companion object { + private val TAG = ScribeRealtimeInputDevice::class.simpleName + + private const val REALTIME_BASE_URL = "wss://api.elevenlabs.io/v1/speech-to-text/realtime" + private const val MODEL_ID = "scribe_v2_realtime" + private const val AUDIO_FORMAT = "pcm_16000" + private const val COMMIT_STRATEGY = "vad" + + private const val SAMPLE_RATE = 16000 + private const val CHUNK_SAMPLES = 1600 // 100 ms at 16 kHz + private const val BYTES_PER_SAMPLE = 2 + + private const val DEFAULT_LANGUAGE_CODE = "en" + private const val MIN_VAD_SILENCE_THRESHOLD_SECS = 0.3 + private const val MAX_VAD_SILENCE_THRESHOLD_SECS = 3.0 + + private val ERROR_MESSAGE_TYPES = setOf( + "error", + "auth_error", + "quota_exceeded", + "transcriber_error", + "input_error", + "commit_throttled", + "unaccepted_terms", + "rate_limited", + "queue_overflow", + "resource_exhausted", + "session_time_limit_exceeded", + "chunk_size_exceeded", + "insufficient_audio_activity", + ) + } +} diff --git a/app/src/main/kotlin/org/stypox/dicio/io/input/stt_popup/SttPopup.kt b/app/src/main/kotlin/org/stypox/dicio/io/input/stt_popup/SttPopup.kt index 7eaee386c..8e1803e31 100644 --- a/app/src/main/kotlin/org/stypox/dicio/io/input/stt_popup/SttPopup.kt +++ 
b/app/src/main/kotlin/org/stypox/dicio/io/input/stt_popup/SttPopup.kt @@ -141,7 +141,9 @@ private fun SttPopupBottomSheet( value = textFieldValue, onValueChange = onTextFieldChange, customHint = customHint, - enabled = sttState != SttState.Listening, + enabled = sttState != SttState.Listening && + sttState != SttState.SilenceDetected && + sttState != SttState.Thinking, modifier = Modifier .padding(horizontal = 16.dp) .fillMaxWidth(), diff --git a/app/src/main/kotlin/org/stypox/dicio/settings/Definitions.kt b/app/src/main/kotlin/org/stypox/dicio/settings/Definitions.kt index c55470e5e..21769a038 100644 --- a/app/src/main/kotlin/org/stypox/dicio/settings/Definitions.kt +++ b/app/src/main/kotlin/org/stypox/dicio/settings/Definitions.kt @@ -9,7 +9,6 @@ import androidx.compose.material.icons.filled.Cloud import androidx.compose.material.icons.filled.ColorLens import androidx.compose.material.icons.filled.DarkMode import androidx.compose.material.icons.filled.Hearing -import androidx.compose.material.icons.filled.HourglassEmpty import androidx.compose.material.icons.filled.InvertColors import androidx.compose.material.icons.filled.KeyboardAlt import androidx.compose.material.icons.filled.Language @@ -31,8 +30,8 @@ import org.stypox.dicio.settings.datastore.SttPlaySound import org.stypox.dicio.settings.datastore.Theme import org.stypox.dicio.settings.datastore.WakeDevice import org.stypox.dicio.settings.ui.BooleanSetting -import org.stypox.dicio.settings.ui.IntSetting import org.stypox.dicio.settings.ui.ListSetting +import org.stypox.dicio.settings.ui.StringSetting @Composable @@ -49,6 +48,7 @@ fun languageSetting() = ListSetting( ListSetting.Value(Language.LANGUAGE_ES, "Español"), ListSetting.Value(Language.LANGUAGE_EL, "Ελληνικά"), ListSetting.Value(Language.LANGUAGE_FR, "Français"), + ListSetting.Value(Language.LANGUAGE_HU, "Magyar"), ListSetting.Value(Language.LANGUAGE_IT, "Italiano"), ListSetting.Value(Language.LANGUAGE_NL, "Nederlands"), 
ListSetting.Value(Language.LANGUAGE_PL, "Polski"), @@ -114,6 +114,18 @@ fun inputDevice() = ListSetting( description = stringResource(R.string.pref_input_method_vosk_summary), icon = Icons.Default.Mic, ), + ListSetting.Value( + value = InputDevice.INPUT_DEVICE_PARAKEET, + name = stringResource(R.string.pref_input_method_parakeet), + description = stringResource(R.string.pref_input_method_parakeet_summary), + icon = Icons.Default.Mic, + ), + ListSetting.Value( + value = InputDevice.INPUT_DEVICE_SCRIBE_REALTIME, + name = stringResource(R.string.pref_input_method_scribe_realtime), + description = stringResource(R.string.pref_input_method_scribe_realtime_summary), + icon = Icons.Default.Cloud, + ), ListSetting.Value( value = InputDevice.INPUT_DEVICE_EXTERNAL_POPUP, name = stringResource(R.string.pref_input_method_external_popup), @@ -128,6 +140,14 @@ fun inputDevice() = ListSetting( ), ) +@Composable +fun scribeApiKeySetting() = StringSetting( + title = stringResource(R.string.pref_input_method_scribe_api_key), + icon = Icons.Default.Cloud, + descriptionWhenEmpty = stringResource(R.string.pref_input_method_scribe_api_key_description_when_empty), + description = stringResource(R.string.pref_input_method_scribe_api_key_summary), +) + @Composable fun wakeDevice() = ListSetting( title = stringResource(R.string.pref_wake_method), @@ -173,15 +193,6 @@ fun speechOutputDevice() = ListSetting( ), ) -@Composable -fun sttSilenceDuration() = IntSetting( - title = stringResource(R.string.pref_stt_silence_duration_title), - icon = Icons.Default.HourglassEmpty, - description = @Composable { stringResource(R.string.pref_stt_silence_duration_description, it) }, - minimum = 1, - maximum = 7, -) - @Composable fun sttAutoFinish() = BooleanSetting( title = stringResource(R.string.pref_stt_auto_finish_title), diff --git a/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsScreen.kt b/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsScreen.kt index 71b5ef4ea..b2254d813 100644 
--- a/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsScreen.kt +++ b/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsScreen.kt @@ -14,6 +14,7 @@ import androidx.compose.material.icons.automirrored.filled.ArrowBack import androidx.compose.material.icons.filled.DeleteSweep import androidx.compose.material.icons.filled.Extension import androidx.compose.material.icons.filled.UploadFile +import androidx.compose.material.icons.filled.Warning import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton @@ -32,7 +33,6 @@ import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel import org.stypox.dicio.R -import org.stypox.dicio.io.input.SttInputDevice import org.stypox.dicio.settings.datastore.InputDevice import org.stypox.dicio.settings.datastore.Language import org.stypox.dicio.settings.datastore.SpeechOutputDevice @@ -123,16 +123,34 @@ private fun MainSettingsScreen( /* INPUT AND OUTPUT METHODS */ item { SettingsCategoryTitle(stringResource(R.string.pref_io)) } + val selectedInputDevice = when (val inputDevice = settings.inputDevice) { + InputDevice.UNRECOGNIZED, + InputDevice.INPUT_DEVICE_UNSET -> InputDevice.INPUT_DEVICE_VOSK + else -> inputDevice + } item { inputDevice().Render( - when (val inputDevice = settings.inputDevice) { - InputDevice.UNRECOGNIZED, - InputDevice.INPUT_DEVICE_UNSET -> InputDevice.INPUT_DEVICE_VOSK - else -> inputDevice - }, + selectedInputDevice, viewModel::setInputDevice, ) } + if (selectedInputDevice == InputDevice.INPUT_DEVICE_SCRIBE_REALTIME) { + item { + scribeApiKeySetting().Render( + settings.scribeApiKey, + viewModel::setScribeApiKey, + ) + } + if (settings.scribeApiKey.isBlank()) { + item { + SettingsItem( + title = stringResource(R.string.pref_input_method_scribe_api_key_warning_title), + icon = Icons.Default.Warning, + description = 
stringResource(R.string.pref_input_method_scribe_api_key_warning_description), + ) + } + } + } val wakeDevice = when (val device = settings.wakeDevice) { WakeDevice.UNRECOGNIZED, WakeDevice.WAKE_DEVICE_UNSET -> WakeDevice.WAKE_DEVICE_OWW @@ -188,12 +206,6 @@ private fun MainSettingsScreen( viewModel::setSttPlaySound ) } - item { - sttSilenceDuration().Render( - SttInputDevice.getSttSilenceDurationOrDefault(settings), - viewModel::setSttSilenceDuration - ) - } item { sttAutoFinish().Render( settings.autoFinishSttPopup, diff --git a/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsViewModel.kt b/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsViewModel.kt index b4c130a12..192077fd3 100644 --- a/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsViewModel.kt +++ b/app/src/main/kotlin/org/stypox/dicio/settings/MainSettingsViewModel.kt @@ -65,14 +65,14 @@ class MainSettingsViewModel @Inject constructor( updateData { it.setDynamicColors(value) } fun setInputDevice(value: InputDevice) = updateData { it.setInputDevice(value) } + fun setScribeApiKey(value: String) = + updateData { it.setScribeApiKey(value.trim()) } fun setWakeDevice(value: WakeDevice) = updateData { it.setWakeDevice(value) } fun setSpeechOutputDevice(value: SpeechOutputDevice) = updateData { it.setSpeechOutputDevice(value) } fun setSttPlaySound(value: SttPlaySound) = updateData { it.setSttPlaySound(value) } - fun setSttSilenceDuration(value: Int) = - updateData { it.setSttSilenceDuration(value) } fun setAutoFinishSttPopup(value: Boolean) = updateData { it.setAutoFinishSttPopup(value) } } diff --git a/app/src/main/kotlin/org/stypox/dicio/skills/joke/JokeSkill.kt b/app/src/main/kotlin/org/stypox/dicio/skills/joke/JokeSkill.kt index 9675244b6..4c0302718 100644 --- a/app/src/main/kotlin/org/stypox/dicio/skills/joke/JokeSkill.kt +++ b/app/src/main/kotlin/org/stypox/dicio/skills/joke/JokeSkill.kt @@ -23,6 +23,13 @@ class JokeSkill(correspondingSkillInfo: SkillInfo, data: 
StandardRecognizerData< setup = joke.getString("setup"), delivery = joke.getString("punchline") ) + // Hungarian API uses "title" / "text" instead of "setup" / "delivery" + } else if (locale == "hu") { + val joke: JSONObject = ConnectionUtils.getPageJson(RANDOM_JOKE_URL_HU) + return JokeOutput.Success( + setup = joke.getString("title"), + delivery = joke.getString("text") + ) } else { val joke: JSONObject = ConnectionUtils.getPageJson( "$RANDOM_JOKE_URL?lang=$locale&safe-mode&type=twopart" @@ -37,8 +44,9 @@ class JokeSkill(correspondingSkillInfo: SkillInfo, data: StandardRecognizerData< companion object { private const val RANDOM_JOKE_URL = "https://v2.jokeapi.dev/joke/Any" private const val RANDOM_JOKE_URL_EN = "https://official-joke-api.appspot.com/random_joke" + private const val RANDOM_JOKE_URL_HU = "https://viccgyujt-api.anoim.workers.dev/random-vicc" val JOKE_SUPPORTED_LOCALES = listOf( - "cs", "de", "en", "es", "fr", "pt" + "cs", "de", "en", "es", "fr", "hu", "pt" ) } } diff --git a/app/src/main/kotlin/org/stypox/dicio/ui/home/SttButton.kt b/app/src/main/kotlin/org/stypox/dicio/ui/home/SttButton.kt index f5da083f5..e6a05b7b1 100644 --- a/app/src/main/kotlin/org/stypox/dicio/ui/home/SttButton.kt +++ b/app/src/main/kotlin/org/stypox/dicio/ui/home/SttButton.kt @@ -50,6 +50,8 @@ import org.stypox.dicio.io.input.SttState.NotAvailable import org.stypox.dicio.io.input.SttState.NotDownloaded import org.stypox.dicio.io.input.SttState.NotInitialized import org.stypox.dicio.io.input.SttState.NotLoaded +import org.stypox.dicio.io.input.SttState.SilenceDetected +import org.stypox.dicio.io.input.SttState.Thinking import org.stypox.dicio.io.input.SttState.Unzipping import org.stypox.dicio.io.input.SttState.WaitingForResult import org.stypox.dicio.ui.theme.AppTheme @@ -130,6 +132,8 @@ private fun sttFabText(state: SttState): String { is ErrorLoading -> stringResource(R.string.error_loading) is Loaded -> "" is Listening -> stringResource(R.string.listening) + is 
SilenceDetected -> stringResource(R.string.silence_detected) + is Thinking -> stringResource(R.string.thinking) is WaitingForResult -> stringResource(R.string.waiting) } } @@ -154,6 +158,8 @@ private fun SttFabIcon(state: SttState, contentDescription: String) { is ErrorLoading -> Icon(Icons.Default.Error, contentDescription) is Loaded -> Icon(Icons.Default.MicNone, stringResource(R.string.start_listening)) is Listening -> Icon(Icons.Default.Mic, contentDescription) + is SilenceDetected, + is Thinking, is WaitingForResult -> SmallCircularProgressIndicator() } } diff --git a/app/src/main/kotlin/org/stypox/dicio/ui/util/PreviewParameterProviders.kt b/app/src/main/kotlin/org/stypox/dicio/ui/util/PreviewParameterProviders.kt index 3419ecde0..fd08cbd40 100644 --- a/app/src/main/kotlin/org/stypox/dicio/ui/util/PreviewParameterProviders.kt +++ b/app/src/main/kotlin/org/stypox/dicio/ui/util/PreviewParameterProviders.kt @@ -138,6 +138,8 @@ class SttStatesPreviews : CollectionPreviewParameterProvider(listOf( SttState.ErrorLoading(Exception("ErrorLoading exception")), SttState.Loaded, SttState.Listening, + SttState.SilenceDetected, + SttState.Thinking, SttState.WaitingForResult, )) diff --git a/app/src/main/proto/input_device.proto b/app/src/main/proto/input_device.proto index 3f4cd7ca0..3709305d8 100644 --- a/app/src/main/proto/input_device.proto +++ b/app/src/main/proto/input_device.proto @@ -8,4 +8,6 @@ enum InputDevice { INPUT_DEVICE_NOTHING = 1; INPUT_DEVICE_VOSK = 2; INPUT_DEVICE_EXTERNAL_POPUP = 3; + INPUT_DEVICE_PARAKEET = 4; + INPUT_DEVICE_SCRIBE_REALTIME = 5; } diff --git a/app/src/main/proto/language.proto b/app/src/main/proto/language.proto index 19002e027..cbe591c56 100644 --- a/app/src/main/proto/language.proto +++ b/app/src/main/proto/language.proto @@ -17,6 +17,7 @@ enum Language { LANGUAGE_ES = 5; // Spanish LANGUAGE_EL = 6; // Greek LANGUAGE_FR = 7; // French + LANGUAGE_HU = 16; // Hungarian LANGUAGE_IT = 8; // Italian LANGUAGE_NL = 14; // Nederlands 
LANGUAGE_PL = 12; // Polish diff --git a/app/src/main/proto/user_settings.proto b/app/src/main/proto/user_settings.proto index 6c37b2bb8..32b568ec4 100644 --- a/app/src/main/proto/user_settings.proto +++ b/app/src/main/proto/user_settings.proto @@ -20,5 +20,7 @@ message UserSettings { map enabled_skills = 7; WakeDevice wake_device = 8; SttPlaySound stt_play_sound = 9; - int32 stt_silence_duration = 10; + reserved 10; + reserved "stt_silence_duration"; + string scribe_api_key = 11; } diff --git a/app/src/main/res/values-hu/strings.xml b/app/src/main/res/values-hu/strings.xml index b99d24ce6..2c8432afd 100644 --- a/app/src/main/res/values-hu/strings.xml +++ b/app/src/main/res/values-hu/strings.xml @@ -19,6 +19,9 @@ Beszédből szöveggé alakítás nem elérhető Beszédből szöveggé alakítása kitömörítése Figyelek… + Csend észlelve + Gondolkodom… + Várakozás… Letöltve Betöltve Ezt tudom nyújtani! @@ -32,7 +35,9 @@ Vosk modell kicsomagolása… Vosk modell letöltése meghiúsult Vosk modell kész + Beszédből szöveggé felugró ablak Beállítások + Névjegy Jelentés Hibabejelentés Sajnáljuk, hiba merült fel @@ -51,6 +56,21 @@ Hiba történt, lásd az értesítésben Megfigyelés Vosk modell letöltése… + NVIDIA Parakeet offline beszédfelismerés + Letölti és futtatja a kvantált NVIDIA Parakeet v3 többnyelvű beszédfelismerő modellt a Dición belül; 25 európai nyelvet támogat, beleértve a magyart + ElevenLabs Scribe v2 valós idejű + A mikrofon hangját az ElevenLabs felé streameli, és valós időben alacsony késleltetésű részleges és végleges átiratokat ad + ElevenLabs API kulcs + Szükséges a Scribe v2 valós idejű bevitelhez + A Scribe v2 valós idejű használatához állítsd be az ElevenLabs API kulcsot + A Scribe valós idejű még nincs kész + A bemeneti mód használata előtt add meg feljebb az ElevenLabs API kulcsot. 
+ Parakeet modell + Többnyelvű (25 nyelv) + Nincs elérhető Parakeet modell az adott nyelven + Parakeet modell letöltése… + Parakeet modell letöltése meghiúsult + Parakeet modell kész Az adott nyelv nem támogatott az Android TTS eszközében Hiba merült fel az Android TTS eszközében A Dicio hibát észlelt, kattintson ide diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index ea8db6663..943867f50 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -24,6 +24,8 @@ Download Speech to Text Unzip Speech to Text Listening… + Silence detected + Thinking… Waiting… Start listening Downloaded @@ -42,6 +44,12 @@ Downloading Vosk model failed Extracting Vosk model from Zip failed Vosk model ready + Parakeet model + Multilingual (25 languages) + No Parakeet model available for the current language + Downloading Parakeet model… + Downloading Parakeet model failed + Parakeet model ready The current language is not supported by the Android text to speech engine An error occurred while initializing the Android text to speech engine Settings @@ -80,7 +88,16 @@ Choose the service to use to talk to Dicio Only text box Vosk offline speech recognition - Downloads and runs a Vosk SpeechToText model offline inside Dicio + Downloads and runs a Vosk SpeechToText model offline inside Dicio. Not available for all languages (e.g. Hungarian). + NVIDIA Parakeet offline speech recognition + Downloads and runs a quantized NVIDIA Parakeet v3 multilingual SpeechToText model offline inside Dicio; supports 25 European languages including Hungarian + ElevenLabs Scribe v2 realtime + Streams microphone audio to ElevenLabs and receives low-latency partial and final transcriptions in realtime + ElevenLabs API key + Required for Scribe v2 realtime input + Set your ElevenLabs API key to use Scribe v2 realtime + Scribe realtime is not ready + Enter your ElevenLabs API key above before using this input method. 
External SpeechToText popup Opens the UI of another app through ACTION\u200B_\u200BRECOGNIZE\u200B_\u200BSPEECH, for example whisperIME, Kõnele, Futo Voice Input Speech output method @@ -105,8 +122,6 @@ DuckDuckGo Default city Set the city to use for weather when you do not explicitly say one. The current behaviour is to get the location from IP info. - Maximum silence before start of speech - Give up listening after %1$d silent intervals Directly send result of speech to text popup Automatically send speech result to requesting app when listening finishes Wait for manual confirmation before sending speech result to requesting app diff --git a/app/src/main/sentences/hu/calculator.yml b/app/src/main/sentences/hu/calculator.yml new file mode 100644 index 000000000..93d0698f7 --- /dev/null +++ b/app/src/main/sentences/hu/calculator.yml @@ -0,0 +1,3 @@ +calculate: + - számold ki|számítsd ki|mennyi|(mi az)|(mennyi az)|számolás .calculation. + - add meg az? eredményét .calculation. diff --git a/app/src/main/sentences/hu/calculator_operators.yml b/app/src/main/sentences/hu/calculator_operators.yml new file mode 100644 index 000000000..9ffe990d9 --- /dev/null +++ b/app/src/main/sentences/hu/calculator_operators.yml @@ -0,0 +1,17 @@ +addition: + - plusz|meg|(hozzáadva|hozzáadni|összeadva|összeadni|összeg) + +subtraction: + - mínusz|(kivonva|kivonni|különbség) + +multiplication: + - szorozva|szor|(szorozni|szorzat) + +division: + - osztva|per|(osztani|osztva|hányadosa?) + +power: + - a? hatvány + +square_root: + - négyzetgyök diff --git a/app/src/main/sentences/hu/current_time.yml b/app/src/main/sentences/hu/current_time.yml new file mode 100644 index 000000000..610b466c7 --- /dev/null +++ b/app/src/main/sentences/hu/current_time.yml @@ -0,0 +1,5 @@ +query: + - mennyi az? idő|óra + - hány óra van most? + - (mi az idő)|(most hány? óra van) + - mondd meg az?
idő diff --git a/app/src/main/sentences/hu/flashlight.yml b/app/src/main/sentences/hu/flashlight.yml new file mode 100644 index 000000000..3b88d698b --- /dev/null +++ b/app/src/main/sentences/hu/flashlight.yml @@ -0,0 +1,9 @@ +turn_on: + - kapcsold be|fel a? zseblámpa|világítás|villany + - (zseblámpa|világítás|villany) bekapcsol<ás|ni?> + - világíts|(világítás be) + +turn_off: + - kapcsold ki|le a? zseblámpa|világítás|villany + - (zseblámpa|világítás|villany) kikapcsol<ás|ni?> + - (világítás ki)|sötétítsd el diff --git a/app/src/main/sentences/hu/joke.yml b/app/src/main/sentences/hu/joke.yml new file mode 100644 index 000000000..1dd01807b --- /dev/null +++ b/app/src/main/sentences/hu/joke.yml @@ -0,0 +1,6 @@ +command: + - mondj|mesélj egy|még egy? jó|vicces? vicc|poént|tréfá + - (tudsz valami vicceset)|(tudsz egy viccet)|(van egy jó vicc) + - (nevettess meg)|(szórakoztass)|(vidíts fel) (egy|még egy? jó|vicces? vicc)? + - mondj valami vicces|mulatságos + - (kérlek egy vicc)|(hallgassunk egy vicc) \ No newline at end of file diff --git a/app/src/main/sentences/hu/listening.yml b/app/src/main/sentences/hu/listening.yml new file mode 100644 index 000000000..28ea71e3c --- /dev/null +++ b/app/src/main/sentences/hu/listening.yml @@ -0,0 +1,9 @@ +stop: + - (hagyd abba a hallgatást)|(ne figyeld az ébresztő szót) + - némítsd el a? mikrofon + - kapcsold ki a? mikrofon + +start: + - (kezdj el hallgatni)|(figyeld az ébresztő szót) + - (mikrofon be)|(mikrofon bekapcsol<ás?>) + - kapcsold be a? mikrofon diff --git a/app/src/main/sentences/hu/lyrics.yml b/app/src/main/sentences/hu/lyrics.yml new file mode 100644 index 000000000..bac4781d5 --- /dev/null +++ b/app/src/main/sentences/hu/lyrics.yml @@ -0,0 +1,6 @@ +query: + - keress|mutasd|töltsd be|jelenítsd meg a? dalszöveg (a dalhoz|ehhez a dalhoz)? .song. + - énekeld el (a dalt)? .song. + - .song. dalszöveg + - mi a dalszöveg (a dalhoz|ehhez)? .song. + - (mi a szövege)|(mi a szöveg) .song. 
diff --git a/app/src/main/sentences/hu/media.yml b/app/src/main/sentences/hu/media.yml new file mode 100644 index 000000000..2546bb24b --- /dev/null +++ b/app/src/main/sentences/hu/media.yml @@ -0,0 +1,16 @@ +play: + - játsz|indítsd|folytasd az? dal|zené|médiá|audió|videó|filmet + - kapcsold be a? lejátszó + - folytasd a lejátszás (az? dal|zené|médiá)? + +pause: + - szüneteltesd|állítsd meg|állj az? dal|zené|médiá|audió|videó|filmet + - tedd szünetre az? dal|zené|médiá + +previous: + - játszd újra|(menj|válts|lépj) az? előző dal|zené|médiá|audió|videó + - tekerj vissza + +next: + - (ugord át)|(lépj tovább) az? dal|zené|médiá + - játszd|(válts|ugorj|lépj) a? következő dal|zené|médiá|audió|videó diff --git a/app/src/main/sentences/hu/navigation.yml b/app/src/main/sentences/hu/navigation.yml new file mode 100644 index 000000000..8b68d607d --- /dev/null +++ b/app/src/main/sentences/hu/navigation.yml @@ -0,0 +1,3 @@ +query: + - navigálj|vigyél|(adj útvonalat)|(mutasd az utat)|(hogyan jut el) .where. + - hol van|található .where. diff --git a/app/src/main/sentences/hu/notify.yml b/app/src/main/sentences/hu/notify.yml new file mode 100644 index 000000000..0cd1d3c2a --- /dev/null +++ b/app/src/main/sentences/hu/notify.yml @@ -0,0 +1,4 @@ +notifications: + - olvasd fel|mondd el a? értesítés|értesítés + - mi az? (új)? értesítés + - (milyen értesítéseim vannak)|(van értesítésem) diff --git a/app/src/main/sentences/hu/open.yml b/app/src/main/sentences/hu/open.yml new file mode 100644 index 000000000..230e82e25 --- /dev/null +++ b/app/src/main/sentences/hu/open.yml @@ -0,0 +1,2 @@ +query: + - nyisd meg|indítsd el|futtasd az? (alkalmazás|appot? .what.)|(.what. alkalmazás|appot?) diff --git a/app/src/main/sentences/hu/search.yml b/app/src/main/sentences/hu/search.yml new file mode 100644 index 000000000..c530fecbd --- /dev/null +++ b/app/src/main/sentences/hu/search.yml @@ -0,0 +1,4 @@ +query: + - keress rá|keress|keress meg .what. online|(az interneten|a weben|a neten)?
+ - nézd meg|keress .what. + - (mi az a)|(mi az) .what. diff --git a/app/src/main/sentences/hu/telephone.yml b/app/src/main/sentences/hu/telephone.yml new file mode 100644 index 000000000..af7e639ee --- /dev/null +++ b/app/src/main/sentences/hu/telephone.yml @@ -0,0 +1,3 @@ +dial: + - hívd fel|hívd|telefonálj|tárcsázd .who. + - (csörgess rá)|(szólj rá telefonon) .who. diff --git a/app/src/main/sentences/hu/timer.yml b/app/src/main/sentences/hu/timer.yml new file mode 100644 index 000000000..003628a75 --- /dev/null +++ b/app/src/main/sentences/hu/timer.yml @@ -0,0 +1,11 @@ +set: + - időzítő|timer|(szólj nekem) .duration. + - állíts be|indíts (egy? (.duration. időzítő)|(időzítő (ennyi|időre .duration.)? (nevű|néven .name.)?)) + +cancel: + - töröld|állítsd le|kapcsold ki az? (.name.? időzítő)|(időzítő néven|nevű .name.) + - (halkítsd el)|(kapcsold ki)|(némítsd el) az? (.name.? időzítő|csengő|hangjelzés)|(időzítő|csengő nevű|néven .name.) + +query: + - mennyi (idő|van) (van még)? hátra az? (.name.? időzítő)|(időzítő néven|nevű .name.) + - mikor jár le az? (.name.? időzítő)|(időzítő néven|nevű .name.) diff --git a/app/src/main/sentences/hu/translation.yml b/app/src/main/sentences/hu/translation.yml new file mode 100644 index 000000000..488cdb8d0 --- /dev/null +++ b/app/src/main/sentences/hu/translation.yml @@ -0,0 +1,4 @@ +translate: + - kérlek? fordítsd le .query. ((.source. nyelvről?)? (.target. nyelvre)?)|(.target. nyelvre .source. nyelvről?) (nekem)? + - (szeretném tudni)|(mondd meg)? mit jelent .query. .target. nyelven|ül (.source. nyelvről|ből)? + - hogyan mondanád .query. .target.
nyelven|ül diff --git a/app/src/main/sentences/hu/util_yes_no.yml b/app/src/main/sentences/hu/util_yes_no.yml new file mode 100644 index 000000000..fe0a5c550 --- /dev/null +++ b/app/src/main/sentences/hu/util_yes_no.yml @@ -0,0 +1,5 @@ +yes: + - igen|persze|természetesen|rendben|oké|hogyne|(rajta)|(gyerünk)|(mehet) + +no: + - nem|ne|hagyd|állj|mégse|mégsem|semmiképp|stop diff --git a/app/src/main/sentences/hu/weather.yml b/app/src/main/sentences/hu/weather.yml new file mode 100644 index 000000000..8f07a40c5 --- /dev/null +++ b/app/src/main/sentences/hu/weather.yml @@ -0,0 +1,6 @@ +current: + - (milyen|mi|hogyan)? az? időjárás (most)? (.where.)? + - (.where. területén)? + - időjárás (.where.)? + - milyen az idő kint|odakint + - hideg|hűvös|meleg|forró|napos|esős van (most)? (.where.)|kint|odakint? diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c67125e29..87b86b568 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -36,9 +36,10 @@ unbescape = "1.1.6.RELEASE" # https://github.com/unicode-org/cldr repo to use as a source of data unicodeCldrGitCommit = "41283df11cce01751c29c400a8f94d1d8687210d" voskAndroid = "0.3.70" -litert = "1.4.0" # cannot update to 2.0.0 as its minSdk is 23 +onnxruntime = "1.22.0" +litert = "1.4.0" permissionFlow = "2.1.0" -minSdk = "21" +minSdk = "24" targetSdk = "36" compileSdk = "36" @@ -93,6 +94,7 @@ test-runner = { module = "androidx.test:runner", version.ref = "androidxTest" } test-ui-automator = { module = "androidx.test.uiautomator:uiautomator", version.ref = "androidxTestUiAutomator" } unbescape = { module = "org.unbescape:unbescape", version.ref = "unbescape" } vosk-android = { module = "com.alphacephei:vosk-android", version.ref = "voskAndroid" } +onnxruntime-android = { module = "com.microsoft.onnxruntime:onnxruntime-android", version.ref = "onnxruntime" } litert = { module = "com.google.ai.edge.litert:litert", version.ref = "litert" } permission-flow-android = { module = 
"dev.shreyaspatil.permission-flow:permission-flow-android", version.ref = "permissionFlow" } permission-flow-compose = { module = "dev.shreyaspatil.permission-flow:permission-flow-compose", version.ref = "permissionFlow" }