From fc61725e910b816f640e8a29092580e606a9d84a Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 20:55:30 +0700 Subject: [PATCH 01/58] Scaffold local Rust backend --- desktop/local-backend/.gitignore | 2 + desktop/local-backend/Cargo.lock | 852 ++++++++++++++++++++++++++++ desktop/local-backend/Cargo.toml | 18 + desktop/local-backend/README.md | 34 ++ desktop/local-backend/src/config.rs | 49 ++ desktop/local-backend/src/health.rs | 23 + desktop/local-backend/src/main.rs | 85 +++ 7 files changed, 1063 insertions(+) create mode 100644 desktop/local-backend/.gitignore create mode 100644 desktop/local-backend/Cargo.lock create mode 100644 desktop/local-backend/Cargo.toml create mode 100644 desktop/local-backend/README.md create mode 100644 desktop/local-backend/src/config.rs create mode 100644 desktop/local-backend/src/health.rs create mode 100644 desktop/local-backend/src/main.rs diff --git a/desktop/local-backend/.gitignore b/desktop/local-backend/.gitignore new file mode 100644 index 00000000000..556a76756c3 --- /dev/null +++ b/desktop/local-backend/.gitignore @@ -0,0 +1,2 @@ +/target/ +!Cargo.lock diff --git a/desktop/local-backend/Cargo.lock b/desktop/local-backend/Cargo.lock new file mode 100644 index 00000000000..eea9126b087 --- /dev/null +++ b/desktop/local-backend/Cargo.lock @@ -0,0 +1,852 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "bitflags" +version = 
"2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "directories" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + 
+[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "libredox" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +dependencies = [ + "libc", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = 
"matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "omi-local-backend" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum", + "directories", + "serde", + "tokio", + "tower-http", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", 
+] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tokio" +version = "1.52.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + 
"signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + 
"quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/desktop/local-backend/Cargo.toml b/desktop/local-backend/Cargo.toml new file mode 100644 index 00000000000..a181c57c7b8 --- /dev/null +++ b/desktop/local-backend/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "omi-local-backend" +version = "0.1.0" +edition = "2021" +license = "MIT" +publish = false + +[dependencies] +anyhow = "1" +axum = "0.7" +directories = "5" +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["macros", "net", "rt-multi-thread", "signal"] } +tower-http = { version = "0.5", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] } + +[workspace] diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md new file mode 100644 index 00000000000..ea3cc18239b --- /dev/null +++ b/desktop/local-backend/README.md @@ -0,0 +1,34 @@ +# Omi Local Backend + +This is the lean local-first daemon scaffold for Omi Desktop. It is separate from +`desktop/Backend-Rust` so the local daemon can build and run without Omi cloud +credentials, Firebase, Firestore, Redis, GCS, pusher, paywall, or agent-proxy +dependencies. 
+ +## Run Locally + +```bash +cd desktop/local-backend +cargo run +``` + +The daemon listens on `127.0.0.1:8765` by default and stores local data under the +platform app data directory. + +Configuration is environment-based: + +```bash +OMI_LOCAL_BACKEND_HOST=127.0.0.1 \ +OMI_LOCAL_BACKEND_PORT=8777 \ +OMI_LOCAL_BACKEND_DATA_DIR=/tmp/omi-local-backend \ +cargo run +``` + +Verify the health endpoint: + +```bash +curl http://127.0.0.1:8765/health +``` + +The response includes the service name, local mode, package version, bind +address, and resolved data directory. diff --git a/desktop/local-backend/src/config.rs b/desktop/local-backend/src/config.rs new file mode 100644 index 00000000000..926e0c7ac8f --- /dev/null +++ b/desktop/local-backend/src/config.rs @@ -0,0 +1,49 @@ +use std::env; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use directories::ProjectDirs; + +const DEFAULT_HOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); +const DEFAULT_PORT: u16 = 8765; + +#[derive(Clone, Debug)] +pub struct Config { + pub bind_addr: SocketAddr, + pub data_dir: PathBuf, +} + +impl Config { + pub fn from_env() -> Result { + let host = env::var("OMI_LOCAL_BACKEND_HOST") + .ok() + .map(|value| value.parse::()) + .transpose() + .context("OMI_LOCAL_BACKEND_HOST must be an IP address")? + .unwrap_or(DEFAULT_HOST); + + let port = env::var("OMI_LOCAL_BACKEND_PORT") + .ok() + .map(|value| value.parse::()) + .transpose() + .context("OMI_LOCAL_BACKEND_PORT must be a valid TCP port")? 
+ .unwrap_or(DEFAULT_PORT); + + let data_dir = match env::var("OMI_LOCAL_BACKEND_DATA_DIR") { + Ok(value) => PathBuf::from(value), + Err(_) => default_data_dir()?, + }; + + Ok(Self { + bind_addr: SocketAddr::new(host, port), + data_dir, + }) + } +} + +fn default_data_dir() -> Result { + let project_dirs = ProjectDirs::from("com", "omi", "Omi Local Backend") + .context("could not resolve a local data directory for this platform")?; + Ok(project_dirs.data_local_dir().to_path_buf()) +} diff --git a/desktop/local-backend/src/health.rs b/desktop/local-backend/src/health.rs new file mode 100644 index 00000000000..84d329d6bfd --- /dev/null +++ b/desktop/local-backend/src/health.rs @@ -0,0 +1,23 @@ +use axum::{extract::State, Json}; +use serde::Serialize; + +use crate::AppState; + +#[derive(Serialize)] +pub struct HealthResponse { + service: &'static str, + mode: &'static str, + version: &'static str, + bind_addr: String, + data_dir: String, +} + +pub async fn health(State(state): State) -> Json { + Json(HealthResponse { + service: "omi-local-backend", + mode: "local", + version: env!("CARGO_PKG_VERSION"), + bind_addr: state.config.bind_addr.to_string(), + data_dir: state.config.data_dir.display().to_string(), + }) +} diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs new file mode 100644 index 00000000000..d2b202ee5c6 --- /dev/null +++ b/desktop/local-backend/src/main.rs @@ -0,0 +1,85 @@ +use std::fs; +use std::sync::Arc; + +use anyhow::Result; +use axum::{routing::get, Router}; +use tokio::net::TcpListener; +use tower_http::trace::TraceLayer; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +mod config; +mod health; + +use config::Config; +use health::health; + +#[derive(Clone)] +pub struct AppState { + pub config: Arc, +} + +#[tokio::main] +async fn main() -> Result<()> { + init_tracing(); + + let config = Config::from_env()?; + fs::create_dir_all(&config.data_dir)?; + + let bind_addr = config.bind_addr; + let 
state = AppState { + config: Arc::new(config), + }; + + let app = Router::new() + .route("/health", get(health)) + .layer(TraceLayer::new_for_http()) + .with_state(state); + + let listener = TcpListener::bind(bind_addr).await?; + tracing::info!( + service = "omi-local-backend", + mode = "local", + %bind_addr, + "listening" + ); + + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +fn init_tracing() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "omi_local_backend=info,tower_http=info".into()), + ) + .with(tracing_subscriber::fmt::layer().json()) + .init(); +} + +async fn shutdown_signal() { + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +} From 8f606862c2d40f1dca8e26df749bca5d8b83ae33 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:00:32 +0700 Subject: [PATCH 02/58] Add local backend SQLite storage --- desktop/local-backend/Cargo.lock | 664 ++++++++++++++++++- desktop/local-backend/Cargo.toml | 7 + desktop/local-backend/src/main.rs | 5 + desktop/local-backend/src/storage.rs | 927 +++++++++++++++++++++++++++ 4 files changed, 1602 insertions(+), 1 deletion(-) create mode 100644 desktop/local-backend/src/storage.rs diff --git a/desktop/local-backend/Cargo.lock b/desktop/local-backend/Cargo.lock index eea9126b087..6cdd0cafa78 100644 --- a/desktop/local-backend/Cargo.lock +++ b/desktop/local-backend/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. 
version = 4 +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -11,6 +23,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -34,6 +55,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "axum" version = "0.7.9" @@ -95,18 +122,90 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + [[package]] name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cc" +version = "1.2.62" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "num-traits", + "serde", + "windows-link", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "directories" version = "5.0.1" @@ -128,6 +227,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "errno" version = "0.3.14" @@ -138,6 +243,36 @@ 
dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -180,6 +315,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -191,6 +336,58 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" 
+dependencies = [ + "ahash", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "http" version = "1.4.0" @@ -271,18 +468,78 @@ dependencies = [ "tower-service", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.1", + "serde", + "serde_core", +] + [[package]] name = "itoa" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "js-sys" +version = "0.3.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +dependencies = [ + "cfg-if", + "futures-util", + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.186" @@ -298,6 +555,23 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "log" version = "0.4.29" @@ -351,14 +625,28 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "omi-local-backend" version = "0.1.0" dependencies = [ 
"anyhow", "axum", + "chrono", "directories", + "rusqlite", "serde", + "serde_json", + "sha2", + "tempfile", "tokio", "tower-http", "tracing", @@ -389,6 +677,22 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -407,13 +711,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "redox_users" version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.17", "libredox", "thiserror", ] @@ -435,6 +745,34 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags", + "chrono", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -447,6 +785,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + [[package]] name = "serde" version = "1.0.228" @@ -513,6 +857,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -522,6 +877,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -571,6 +932,19 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -746,30 +1120,204 @@ dependencies = [ "tracing-serde", ] +[[package]] +name 
= "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + [[package]] name = "unicode-ident" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "valuable" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.121" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -845,6 +1393,120 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen" 
+version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zmij" version = "1.0.21" diff --git a/desktop/local-backend/Cargo.toml b/desktop/local-backend/Cargo.toml index a181c57c7b8..ed99d975ea3 100644 --- a/desktop/local-backend/Cargo.toml +++ b/desktop/local-backend/Cargo.toml @@ -8,11 +8,18 @@ publish = false [dependencies] anyhow = "1" axum = "0.7" +chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } directories = "5" +rusqlite = { version = "0.32", features = ["bundled", "chrono"] } serde = { version = "1", features = ["derive"] } +serde_json = "1" +sha2 = "0.10" tokio = { version = "1", features = ["macros", "net", "rt-multi-thread", "signal"] } tower-http = { version = "0.5", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] } +[dev-dependencies] +tempfile = "3" + [workspace] diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index d2b202ee5c6..5f5617f463f 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -9,13 +9,16 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod config; mod health; +mod storage; use config::Config; use health::health; +use storage::Store; #[derive(Clone)] pub struct AppState { pub config: Arc, + pub store: Store, } #[tokio::main] @@ -24,10 +27,12 @@ async fn main() -> Result<()> { let config = Config::from_env()?; fs::create_dir_all(&config.data_dir)?; + let store = Store::open(config.data_dir.join("omi-local-backend.sqlite"))?; let bind_addr = config.bind_addr; let state = AppState { config: 
Arc::new(config), + store, }; let app = Router::new() diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs new file mode 100644 index 00000000000..691e87c899d --- /dev/null +++ b/desktop/local-backend/src/storage.rs @@ -0,0 +1,927 @@ +use std::path::Path; +use std::sync::{Arc, Mutex}; + +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; + +const MIGRATIONS: &[Migration] = &[Migration { + version: 1, + name: "initial_local_storage", + sql: r#" + CREATE TABLE conversations ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + title TEXT NOT NULL DEFAULT '', + overview TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'open', + started_at TEXT NOT NULL, + ended_at TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}' + ); + + CREATE INDEX idx_conversations_session_id ON conversations(session_id); + CREATE INDEX idx_conversations_updated_at ON conversations(updated_at); + CREATE INDEX idx_conversations_deleted_at ON conversations(deleted_at); + + CREATE TABLE transcript_segments ( + id TEXT PRIMARY KEY, + conversation_id TEXT NOT NULL REFERENCES conversations(id) ON DELETE CASCADE, + session_id TEXT NOT NULL, + speaker_id TEXT, + speaker_label TEXT, + text TEXT NOT NULL, + start_ms INTEGER NOT NULL, + end_ms INTEGER NOT NULL, + segment_index INTEGER NOT NULL, + source TEXT NOT NULL DEFAULT 'local', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}', + UNIQUE(conversation_id, segment_index) + ); + + CREATE INDEX idx_transcript_segments_conversation ON 
transcript_segments(conversation_id, segment_index); + CREATE INDEX idx_transcript_segments_session ON transcript_segments(session_id); + + CREATE TABLE memories ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + category TEXT, + conversation_id TEXT REFERENCES conversations(id) ON DELETE SET NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}' + ); + + CREATE TABLE action_items ( + id TEXT PRIMARY KEY, + conversation_id TEXT REFERENCES conversations(id) ON DELETE SET NULL, + title TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'open', + due_at TEXT, + completed_at TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}' + ); + + CREATE TABLE local_settings ( + key TEXT PRIMARY KEY, + value_json TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local' + ); + + CREATE TABLE local_profiles ( + id TEXT PRIMARY KEY, + display_name TEXT NOT NULL DEFAULT '', + timezone TEXT, + locale TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}' + ); + + CREATE TABLE processing_jobs ( + id TEXT PRIMARY KEY, + kind TEXT NOT NULL, + status TEXT NOT NULL CHECK(status IN ('queued', 'running', 'completed', 'failed')), + target_conversation_id TEXT REFERENCES conversations(id) ON DELETE CASCADE, + retry_count INTEGER NOT NULL DEFAULT 0, + max_retries INTEGER NOT NULL DEFAULT 3, + last_error TEXT, + payload_json TEXT NOT NULL 
DEFAULT '{}', + result_json TEXT NOT NULL DEFAULT '{}', + queued_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT, + failed_at TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local' + ); + + CREATE INDEX idx_processing_jobs_status ON processing_jobs(status, queued_at); + CREATE INDEX idx_processing_jobs_conversation ON processing_jobs(target_conversation_id); + + CREATE TABLE sync_outbox ( + id TEXT PRIMARY KEY, + entity_type TEXT NOT NULL, + entity_id TEXT NOT NULL, + operation TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'pending', + attempt_count INTEGER NOT NULL DEFAULT 0, + last_error TEXT, + payload_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + next_attempt_at TEXT, + completed_at TEXT + ); + + CREATE INDEX idx_sync_outbox_status ON sync_outbox(status, next_attempt_at, created_at); + CREATE INDEX idx_sync_outbox_entity ON sync_outbox(entity_type, entity_id); + + CREATE TABLE local_files ( + id TEXT PRIMARY KEY, + conversation_id TEXT REFERENCES conversations(id) ON DELETE SET NULL, + kind TEXT NOT NULL, + path TEXT NOT NULL, + media_type TEXT, + byte_size INTEGER, + checksum TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT, + cloud_id TEXT, + sync_version INTEGER NOT NULL DEFAULT 0, + sync_state TEXT NOT NULL DEFAULT 'local', + metadata_json TEXT NOT NULL DEFAULT '{}' + ); + + CREATE INDEX idx_local_files_conversation ON local_files(conversation_id); + + CREATE VIRTUAL TABLE conversation_search USING fts5( + conversation_id UNINDEXED, + source_type UNINDEXED, + source_id UNINDEXED, + title, + overview, + transcript_text, + tokenize = 'unicode61' + ); + + CREATE TRIGGER conversations_ai AFTER INSERT ON conversations BEGIN + INSERT INTO conversation_search(conversation_id, source_type, source_id, title, overview, transcript_text) + VALUES (new.id, 
'conversation', new.id, new.title, new.overview, ''); + END; + + CREATE TRIGGER conversations_au AFTER UPDATE OF title, overview, deleted_at ON conversations BEGIN + DELETE FROM conversation_search WHERE source_type = 'conversation' AND source_id = old.id; + INSERT INTO conversation_search(conversation_id, source_type, source_id, title, overview, transcript_text) + SELECT new.id, 'conversation', new.id, new.title, new.overview, '' + WHERE new.deleted_at IS NULL; + END; + + CREATE TRIGGER conversations_ad AFTER DELETE ON conversations BEGIN + DELETE FROM conversation_search WHERE conversation_id = old.id; + END; + + CREATE TRIGGER transcript_segments_ai AFTER INSERT ON transcript_segments BEGIN + INSERT INTO conversation_search(conversation_id, source_type, source_id, title, overview, transcript_text) + SELECT new.conversation_id, 'segment', new.id, c.title, c.overview, new.text + FROM conversations c + WHERE c.id = new.conversation_id AND new.deleted_at IS NULL AND c.deleted_at IS NULL; + END; + + CREATE TRIGGER transcript_segments_au AFTER UPDATE OF text, deleted_at ON transcript_segments BEGIN + DELETE FROM conversation_search WHERE source_type = 'segment' AND source_id = old.id; + INSERT INTO conversation_search(conversation_id, source_type, source_id, title, overview, transcript_text) + SELECT new.conversation_id, 'segment', new.id, c.title, c.overview, new.text + FROM conversations c + WHERE c.id = new.conversation_id AND new.deleted_at IS NULL AND c.deleted_at IS NULL; + END; + + CREATE TRIGGER transcript_segments_ad AFTER DELETE ON transcript_segments BEGIN + DELETE FROM conversation_search WHERE source_type = 'segment' AND source_id = old.id; + END; + "#, +}]; + +#[derive(Clone)] +pub struct Store { + conn: Arc>, +} + +struct Migration { + version: i64, + name: &'static str, + sql: &'static str, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Conversation { + pub id: String, + pub session_id: String, + pub title: String, + 
pub overview: String, + pub status: String, + pub started_at: DateTime, + pub ended_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, + pub metadata_json: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TranscriptSegment { + pub id: String, + pub conversation_id: String, + pub session_id: String, + pub speaker_id: Option, + pub speaker_label: Option, + pub text: String, + pub start_ms: i64, + pub end_ms: i64, + pub segment_index: i64, + pub source: String, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, + pub metadata_json: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SearchResult { + pub conversation_id: String, + pub title: String, + pub overview: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProcessingJob { + pub id: String, + pub kind: String, + pub status: ProcessingJobStatus, + pub target_conversation_id: Option, + pub retry_count: i64, + pub max_retries: i64, + pub last_error: Option, + pub payload_json: String, + pub result_json: String, + pub queued_at: DateTime, + pub started_at: Option>, + pub completed_at: Option>, + pub failed_at: Option>, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ProcessingJobStatus { + Queued, + Running, + Completed, + Failed, +} + +impl Store { + pub fn open(path: impl AsRef) -> Result { + let conn = Connection::open(path.as_ref()).with_context(|| { + format!("failed to open SQLite store at {}", path.as_ref().display()) + })?; + 
configure_connection(&conn)?; + run_migrations(&conn)?; + + Ok(Self { + conn: Arc::new(Mutex::new(conn)), + }) + } + + #[cfg(test)] + fn open_in_memory() -> Result { + let conn = Connection::open_in_memory().context("failed to open in-memory SQLite store")?; + configure_connection(&conn)?; + run_migrations(&conn)?; + + Ok(Self { + conn: Arc::new(Mutex::new(conn)), + }) + } + + pub fn conversations(&self) -> ConversationRepository { + ConversationRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn transcripts(&self) -> TranscriptRepository { + TranscriptRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn processing_jobs(&self) -> ProcessingJobRepository { + ProcessingJobRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn search(&self) -> SearchRepository { + SearchRepository { + conn: Arc::clone(&self.conn), + } + } +} + +pub struct ConversationRepository { + conn: Arc>, +} + +impl ConversationRepository { + pub fn create(&self, new: NewConversation) -> Result { + let now = Utc::now(); + let conversation = Conversation { + id: new.id, + session_id: new.session_id, + title: new.title, + overview: new.overview, + status: "open".to_string(), + started_at: new.started_at.unwrap_or(now), + ended_at: None, + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + metadata_json: json_or_empty_object(new.metadata)?, + }; + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO conversations ( + id, session_id, title, overview, status, started_at, ended_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14) + "#, + params![ + conversation.id, + conversation.session_id, + conversation.title, + conversation.overview, + conversation.status, + conversation.started_at, + conversation.ended_at, + 
conversation.created_at, + conversation.updated_at, + conversation.deleted_at, + conversation.cloud_id, + conversation.sync_version, + conversation.sync_state, + conversation.metadata_json + ], + ) + .context("failed to insert conversation")?; + + Ok(conversation) + } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + FROM conversations + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_conversation, + ) + .optional() + .context("failed to fetch conversation") + } +} + +pub struct TranscriptRepository { + conn: Arc>, +} + +impl TranscriptRepository { + pub fn append(&self, new: NewTranscriptSegment) -> Result { + let now = Utc::now(); + let segment = TranscriptSegment { + id: new.id, + conversation_id: new.conversation_id, + session_id: new.session_id, + speaker_id: new.speaker_id, + speaker_label: new.speaker_label, + text: new.text, + start_ms: new.start_ms, + end_ms: new.end_ms, + segment_index: new.segment_index, + source: new.source.unwrap_or_else(|| "local".to_string()), + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + metadata_json: json_or_empty_object(new.metadata)?, + }; + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO transcript_segments ( + id, conversation_id, session_id, speaker_id, speaker_label, text, start_ms, + end_ms, segment_index, source, created_at, updated_at, deleted_at, cloud_id, + sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17) + "#, + params![ + segment.id, + segment.conversation_id, + segment.session_id, + segment.speaker_id, + segment.speaker_label, 
+ segment.text, + segment.start_ms, + segment.end_ms, + segment.segment_index, + segment.source, + segment.created_at, + segment.updated_at, + segment.deleted_at, + segment.cloud_id, + segment.sync_version, + segment.sync_state, + segment.metadata_json + ], + ) + .context("failed to insert transcript segment")?; + + Ok(segment) + } + + pub fn list_for_conversation(&self, conversation_id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, conversation_id, session_id, speaker_id, speaker_label, text, start_ms, + end_ms, segment_index, source, created_at, updated_at, deleted_at, cloud_id, + sync_version, sync_state, metadata_json + FROM transcript_segments + WHERE conversation_id = ?1 AND deleted_at IS NULL + ORDER BY segment_index ASC + "#, + ) + .context("failed to prepare transcript segment list query")?; + + let rows = stmt + .query_map(params![conversation_id], map_transcript_segment) + .context("failed to list transcript segments")?; + + collect_rows(rows) + } +} + +pub struct ProcessingJobRepository { + conn: Arc>, +} + +impl ProcessingJobRepository { + pub fn enqueue(&self, new: NewProcessingJob) -> Result { + let now = Utc::now(); + let job = ProcessingJob { + id: new.id, + kind: new.kind, + status: ProcessingJobStatus::Queued, + target_conversation_id: new.target_conversation_id, + retry_count: 0, + max_retries: new.max_retries.unwrap_or(3), + last_error: None, + payload_json: json_or_empty_object(new.payload)?, + result_json: "{}".to_string(), + queued_at: now, + started_at: None, + completed_at: None, + failed_at: None, + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + }; + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO processing_jobs ( + id, kind, status, target_conversation_id, retry_count, max_retries, last_error, + 
payload_json, result_json, queued_at, started_at, completed_at, failed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18, ?19) + "#, + params![ + job.id, + job.kind, + job.status.as_str(), + job.target_conversation_id, + job.retry_count, + job.max_retries, + job.last_error, + job.payload_json, + job.result_json, + job.queued_at, + job.started_at, + job.completed_at, + job.failed_at, + job.created_at, + job.updated_at, + job.deleted_at, + job.cloud_id, + job.sync_version, + job.sync_state + ], + ) + .context("failed to enqueue processing job")?; + + Ok(job) + } +} + +pub struct SearchRepository { + conn: Arc>, +} + +impl SearchRepository { + pub fn conversations(&self, query: &str, limit: i64) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT c.id, c.title, c.overview + FROM conversations c + JOIN ( + SELECT DISTINCT conversation_id + FROM conversation_search + WHERE conversation_search MATCH ?1 + ) matches ON matches.conversation_id = c.id + WHERE c.deleted_at IS NULL + ORDER BY c.updated_at DESC + LIMIT ?2 + "#, + ) + .context("failed to prepare conversation search query")?; + + let rows = stmt + .query_map(params![query, limit], |row| { + Ok(SearchResult { + conversation_id: row.get(0)?, + title: row.get(1)?, + overview: row.get(2)?, + }) + }) + .context("failed to search conversations")?; + + collect_rows(rows) + } +} + +#[derive(Debug, Clone)] +pub struct NewConversation { + pub id: String, + pub session_id: String, + pub title: String, + pub overview: String, + pub started_at: Option>, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct NewTranscriptSegment { + pub id: String, + pub conversation_id: String, + pub session_id: String, + pub speaker_id: Option, + pub speaker_label: Option, + pub text: String, + pub start_ms: i64, + pub end_ms: 
i64, + pub segment_index: i64, + pub source: Option, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct NewProcessingJob { + pub id: String, + pub kind: String, + pub target_conversation_id: Option, + pub max_retries: Option, + pub payload: Option, +} + +pub fn deterministic_id(prefix: &str, parts: &[&str]) -> String { + let mut hasher = Sha256::new(); + hasher.update(prefix.as_bytes()); + for part in parts { + hasher.update([0]); + hasher.update(part.as_bytes()); + } + let digest = hasher.finalize(); + format!("{prefix}_{digest:x}") + .chars() + .take(prefix.len() + 1 + 32) + .collect() +} + +fn configure_connection(conn: &Connection) -> Result<()> { + conn.pragma_update(None, "foreign_keys", "ON") + .context("failed to enable SQLite foreign keys")?; + conn.pragma_update(None, "journal_mode", "WAL") + .context("failed to enable SQLite WAL journal mode")?; + conn.pragma_update(None, "busy_timeout", 5000) + .context("failed to set SQLite busy timeout")?; + Ok(()) +} + +fn run_migrations(conn: &Connection) -> Result<()> { + conn.execute( + "CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + name TEXT NOT NULL, + applied_at TEXT NOT NULL + )", + [], + ) + .context("failed to create schema_migrations table")?; + + for migration in MIGRATIONS { + let applied = conn + .query_row( + "SELECT 1 FROM schema_migrations WHERE version = ?1", + params![migration.version], + |_| Ok(()), + ) + .optional() + .context("failed to check migration state")? 
+ .is_some(); + + if applied { + continue; + } + + let tx = conn + .unchecked_transaction() + .context("failed to start migration transaction")?; + tx.execute_batch(migration.sql) + .with_context(|| format!("failed to apply migration {}", migration.name))?; + tx.execute( + "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)", + params![migration.version, migration.name, Utc::now()], + ) + .context("failed to record migration")?; + tx.commit().context("failed to commit migration")?; + } + + Ok(()) +} + +fn json_or_empty_object(value: Option) -> Result { + serde_json::to_string(&value.unwrap_or_else(|| serde_json::json!({}))) + .context("failed to serialize JSON metadata") +} + +fn collect_rows( + rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result>, +) -> Result> { + rows.collect::>>() + .context("failed to collect SQLite rows") +} + +fn map_conversation(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(Conversation { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + overview: row.get(3)?, + status: row.get(4)?, + started_at: row.get(5)?, + ended_at: row.get(6)?, + created_at: row.get(7)?, + updated_at: row.get(8)?, + deleted_at: row.get(9)?, + cloud_id: row.get(10)?, + sync_version: row.get(11)?, + sync_state: row.get(12)?, + metadata_json: row.get(13)?, + }) +} + +fn map_transcript_segment(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(TranscriptSegment { + id: row.get(0)?, + conversation_id: row.get(1)?, + session_id: row.get(2)?, + speaker_id: row.get(3)?, + speaker_label: row.get(4)?, + text: row.get(5)?, + start_ms: row.get(6)?, + end_ms: row.get(7)?, + segment_index: row.get(8)?, + source: row.get(9)?, + created_at: row.get(10)?, + updated_at: row.get(11)?, + deleted_at: row.get(12)?, + cloud_id: row.get(13)?, + sync_version: row.get(14)?, + sync_state: row.get(15)?, + metadata_json: row.get(16)?, + }) +} + +impl ProcessingJobStatus { + fn as_str(&self) -> &'static str { + 
match self { + Self::Queued => "queued", + Self::Running => "running", + Self::Completed => "completed", + Self::Failed => "failed", + } + } +} + +#[cfg(test)] +mod tests { + use tempfile::tempdir; + + use super::*; + + #[test] + fn migrations_create_expected_tables_and_pragmas() -> Result<()> { + let store = Store::open_in_memory()?; + let conn = store.conn.lock().expect("SQLite connection mutex poisoned"); + + let foreign_keys: i64 = conn.query_row("PRAGMA foreign_keys", [], |row| row.get(0))?; + assert_eq!(foreign_keys, 1); + + for table in [ + "conversations", + "transcript_segments", + "memories", + "action_items", + "local_settings", + "local_profiles", + "processing_jobs", + "sync_outbox", + "local_files", + "conversation_search", + ] { + let exists: i64 = conn.query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE name = ?1", + params![table], + |row| row.get(0), + )?; + assert_eq!(exists, 1, "missing table {table}"); + } + + Ok(()) + } + + #[test] + fn conversation_and_segments_persist_after_reopen() -> Result<()> { + let temp = tempdir()?; + let db_path = temp.path().join("local.sqlite"); + let conversation_id = deterministic_id("conv", &["session-a"]); + + { + let store = Store::open(&db_path)?; + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-a".to_string(), + title: "Planning sync".to_string(), + overview: "MVP storage discussion".to_string(), + started_at: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id: conversation_id.clone(), + session_id: "session-a".to_string(), + speaker_id: Some("speaker-1".to_string()), + speaker_label: Some("Alice".to_string()), + text: "We need local persistence.".to_string(), + start_ms: 0, + end_ms: 1500, + segment_index: 0, + source: None, + metadata: None, + })?; + } + + let reopened = Store::open(&db_path)?; + let conversation = 
reopened.conversations().get(&conversation_id)?; + let segments = reopened + .transcripts() + .list_for_conversation(&conversation_id)?; + + assert_eq!( + conversation.expect("conversation should persist").title, + "Planning sync" + ); + assert_eq!(segments.len(), 1); + assert_eq!(segments[0].text, "We need local persistence."); + + Ok(()) + } + + #[test] + fn fts_search_matches_conversation_and_transcript_text() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-search"]); + + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-search".to_string(), + title: "Weekly design review".to_string(), + overview: "Discuss local backend schema".to_string(), + started_at: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id: conversation_id.clone(), + session_id: "session-search".to_string(), + speaker_id: None, + speaker_label: None, + text: "The transcript mentions vector clocks and durable outbox sync.".to_string(), + start_ms: 0, + end_ms: 3000, + segment_index: 0, + source: None, + metadata: None, + })?; + + let title_results = store.search().conversations("design", 10)?; + assert_eq!(title_results.len(), 1); + assert_eq!(title_results[0].conversation_id, conversation_id); + + let transcript_results = store.search().conversations("durable", 10)?; + assert_eq!(transcript_results.len(), 1); + assert_eq!(transcript_results[0].title, "Weekly design review"); + + Ok(()) + } + + #[test] + fn processing_jobs_start_queued_with_retry_metadata() -> Result<()> { + let store = Store::open_in_memory()?; + let job = store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["summarize", "conversation-1"]), + kind: "summarize_conversation".to_string(), + target_conversation_id: None, + max_retries: Some(5), + payload: 
Some(serde_json::json!({"conversation_id": "conversation-1"})), + })?; + + assert_eq!(job.status, ProcessingJobStatus::Queued); + assert_eq!(job.retry_count, 0); + assert_eq!(job.max_retries, 5); + assert!(job.last_error.is_none()); + + Ok(()) + } +} From 491246a71e1fa2a24fc960f5dccba6479128766d Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:07:57 +0700 Subject: [PATCH 03/58] Expose local backend MVP APIs --- desktop/local-backend/Cargo.lock | 1 + desktop/local-backend/Cargo.toml | 1 + desktop/local-backend/README.md | 37 ++ desktop/local-backend/src/main.rs | 203 ++++++- desktop/local-backend/src/routes.rs | 667 ++++++++++++++++++++++ desktop/local-backend/src/storage.rs | 798 ++++++++++++++++++++++++++- 6 files changed, 1701 insertions(+), 6 deletions(-) create mode 100644 desktop/local-backend/src/routes.rs diff --git a/desktop/local-backend/Cargo.lock b/desktop/local-backend/Cargo.lock index 6cdd0cafa78..4f61b0e16b4 100644 --- a/desktop/local-backend/Cargo.lock +++ b/desktop/local-backend/Cargo.lock @@ -648,6 +648,7 @@ dependencies = [ "sha2", "tempfile", "tokio", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/desktop/local-backend/Cargo.toml b/desktop/local-backend/Cargo.toml index ed99d975ea3..6ab5d434f1e 100644 --- a/desktop/local-backend/Cargo.toml +++ b/desktop/local-backend/Cargo.toml @@ -21,5 +21,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] [dev-dependencies] tempfile = "3" +tower = { version = "0.5", features = ["util"] } [workspace] diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index ea3cc18239b..360b298c45b 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -32,3 +32,40 @@ curl http://127.0.0.1:8765/health The response includes the service name, local mode, package version, bind address, and resolved data directory. 
+ +## MVP HTTP API + +The local daemon exposes JSON endpoints for the desktop MVP: + +- `GET /health`, `GET /version`, `GET /profile/status` +- `GET|POST /v1/conversations` +- `GET|PATCH|DELETE /v1/conversations/:id` +- `POST /v1/conversations/:id/transcript-segments` +- `POST /v1/conversations/:id/finalize-transcript` +- `GET /v1/search/conversations?q=<query>` +- `GET|POST /v1/memories` +- `GET|PATCH|DELETE /v1/memories/:id` +- `GET|POST /v1/action-items` +- `GET|PATCH|DELETE /v1/action-items/:id` +- `GET|PUT /v1/profile` +- `GET|PUT /v1/settings` +- `GET /v1/processing-jobs` +- `GET /v1/processing-jobs/:id` +- `GET /v1/processing-jobs/status` + +Finalizing transcript ingestion currently enqueues a local `finalize_transcript` +processing job. Later processing workers can consume the same durable +`processing_jobs` rows and update queued/running/completed/failed state. + +## Differences From The Cloud API + +The local daemon is intentionally unauthenticated on loopback for the MVP. It +does not require Firebase ID tokens, does not return paywall errors, does not +create GCS signed URLs, and does not depend on Redis, Firestore, pusher, or +agent-proxy coordination. + +Local responses use explicit JSON errors with `source: "local_daemon"`. Profile +status reports `authenticated: false` because local-first mode does not imply an +Omi cloud account. Cloud IDs and sync fields are retained in storage models so a +future sync adapter can map local records to cloud records without making cloud +state the source of truth. 
diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 5f5617f463f..d131cbba8b6 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -2,13 +2,14 @@ use std::fs; use std::sync::Arc; use anyhow::Result; -use axum::{routing::get, Router}; +use axum::Router; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod config; mod health; +mod routes; mod storage; use config::Config; @@ -35,10 +36,7 @@ async fn main() -> Result<()> { store, }; - let app = Router::new() - .route("/health", get(health)) - .layer(TraceLayer::new_for_http()) - .with_state(state); + let app = app(state); let listener = TcpListener::bind(bind_addr).await?; tracing::info!( @@ -55,6 +53,14 @@ async fn main() -> Result<()> { Ok(()) } +fn app(state: AppState) -> Router { + Router::new() + .merge(routes::router()) + .route("/health", axum::routing::get(health)) + .layer(TraceLayer::new_for_http()) + .with_state(state) +} + fn init_tracing() { tracing_subscriber::registry() .with( @@ -88,3 +94,190 @@ async fn shutdown_signal() { _ = terminate => {}, } } + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + use std::sync::Arc; + + use anyhow::Result; + use axum::{ + body::{to_bytes, Body}, + http::{Method, Request, StatusCode}, + }; + use serde_json::{json, Value}; + use tower::ServiceExt; + + use super::*; + + #[tokio::test] + async fn mvp_routes_support_local_desktop_flow() -> Result<()> { + let app = test_app()?; + + let created = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "session_id": "session-route", + "title": "Local route test", + "overview": "Exercise the MVP API" + })), + ) + .await?; + let conversation_id = created["conversation"]["id"] + .as_str() + .expect("conversation id") + .to_string(); + + request_json( + app.clone(), + Method::POST, + 
&format!("/v1/conversations/{conversation_id}/transcript-segments"), + Some(json!({ + "text": "The local daemon stores transcript text for search.", + "start_ms": 0, + "end_ms": 1200 + })), + ) + .await?; + + let job = request_json( + app.clone(), + Method::POST, + &format!("/v1/conversations/{conversation_id}/finalize-transcript"), + None, + ) + .await?; + assert_eq!(job["processing_job"]["status"], "queued"); + + let conversation = request_json( + app.clone(), + Method::GET, + &format!("/v1/conversations/{conversation_id}"), + None, + ) + .await?; + assert_eq!( + conversation["transcript_segments"] + .as_array() + .unwrap() + .len(), + 1 + ); + + let search = request_json( + app.clone(), + Method::GET, + "/v1/search/conversations?q=daemon", + None, + ) + .await?; + assert_eq!(search["results"].as_array().unwrap().len(), 1); + + let status = + request_json(app.clone(), Method::GET, "/v1/processing-jobs/status", None).await?; + assert_eq!(status["queued"], 1); + + let memory = request_json( + app.clone(), + Method::POST, + "/v1/memories", + Some(json!({"content": "Prefers local-first desktop mode"})), + ) + .await?; + assert!(memory["memory"]["id"].is_string()); + + let memories = request_json(app.clone(), Method::GET, "/v1/memories", None).await?; + assert_eq!(memories["memories"].as_array().unwrap().len(), 1); + + let action_item = request_json( + app.clone(), + Method::POST, + "/v1/action-items", + Some(json!({"title": "Review local processing status"})), + ) + .await?; + assert_eq!(action_item["action_item"]["status"], "open"); + + let action_items = request_json(app, Method::GET, "/v1/action-items", None).await?; + assert_eq!(action_items["action_items"].as_array().unwrap().len(), 1); + + Ok(()) + } + + #[tokio::test] + async fn profile_and_settings_routes_are_local_without_auth() -> Result<()> { + let app = test_app()?; + + let status = request_json(app.clone(), Method::GET, "/profile/status", None).await?; + assert_eq!(status["mode"], "local"); + 
assert_eq!(status["authenticated"], false); + + let profile = request_json( + app.clone(), + Method::PUT, + "/v1/profile", + Some(json!({ + "display_name": "Local User", + "timezone": "UTC", + "locale": "en" + })), + ) + .await?; + assert_eq!(profile["profile"]["display_name"], "Local User"); + + let settings = request_json( + app, + Method::PUT, + "/v1/settings", + Some(json!({ + "provider": {"kind": "openai"}, + "local_first": true + })), + ) + .await?; + assert_eq!(settings["settings"].as_array().unwrap().len(), 2); + + Ok(()) + } + + fn test_app() -> Result { + let config = Config { + bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + data_dir: std::env::temp_dir().join("omi-local-backend-route-tests"), + }; + let state = AppState { + config: Arc::new(config), + store: Store::open_in_memory()?, + }; + Ok(app(state)) + } + + async fn request_json( + app: Router, + method: Method, + uri: &str, + body: Option, + ) -> Result { + let request_body = match body { + Some(value) => Body::from(serde_json::to_vec(&value)?), + None => Body::empty(), + }; + let request = Request::builder() + .method(method) + .uri(uri) + .header("content-type", "application/json") + .body(request_body)?; + + let response = app.oneshot(request).await?; + let status = response.status(); + let bytes = to_bytes(response.into_body(), 1024 * 1024).await?; + assert!( + status == StatusCode::OK || status == StatusCode::CREATED, + "unexpected status {status}: {}", + String::from_utf8_lossy(&bytes) + ); + Ok(serde_json::from_slice(&bytes)?) 
+ } +} diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs new file mode 100644 index 00000000000..4a874c1fffc --- /dev/null +++ b/desktop/local-backend/src/routes.rs @@ -0,0 +1,667 @@ +use axum::{ + extract::{Path, Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::{get, post}, + Json, Router, +}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map, Value}; + +use crate::{ + storage::{ + deterministic_id, NewActionItem, NewConversation, NewMemory, NewProcessingJob, + NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, UpdateProfile, + }, + AppState, +}; + +pub fn router() -> Router { + Router::new() + .route("/version", get(version)) + .route("/profile/status", get(profile_status)) + .route("/v1/profile", get(get_profile).put(update_profile)) + .route("/v1/settings", get(list_settings).put(update_settings)) + .route( + "/v1/conversations", + get(list_conversations).post(create_conversation), + ) + .route( + "/v1/conversations/:id", + get(get_conversation) + .patch(update_conversation) + .delete(delete_conversation), + ) + .route( + "/v1/conversations/:id/transcript-segments", + post(append_transcript_segment), + ) + .route( + "/v1/conversations/:id/finalize-transcript", + post(finalize_transcript), + ) + .route("/v1/search/conversations", get(search_conversations)) + .route("/v1/memories", get(list_memories).post(create_memory)) + .route( + "/v1/memories/:id", + get(get_memory).patch(update_memory).delete(delete_memory), + ) + .route( + "/v1/action-items", + get(list_action_items).post(create_action_item), + ) + .route( + "/v1/action-items/:id", + get(get_action_item) + .patch(update_action_item) + .delete(delete_action_item), + ) + .route("/v1/processing-jobs", get(list_processing_jobs)) + .route("/v1/processing-jobs/status", get(processing_status)) + .route("/v1/processing-jobs/:id", get(get_processing_job)) +} + +#[derive(Debug)] 
+struct ApiError { + status: StatusCode, + message: String, +} + +impl ApiError { + fn bad_request(message: impl Into) -> Self { + Self { + status: StatusCode::BAD_REQUEST, + message: message.into(), + } + } + + fn not_found(entity: &str) -> Self { + Self { + status: StatusCode::NOT_FOUND, + message: format!("{entity} not found"), + } + } + + fn internal(error: anyhow::Error) -> Self { + Self { + status: StatusCode::INTERNAL_SERVER_ERROR, + message: error.to_string(), + } + } +} + +impl IntoResponse for ApiError { + fn into_response(self) -> Response { + ( + self.status, + Json(json!({ + "error": { + "code": self.status.as_u16(), + "message": self.message, + "source": "local_daemon" + } + })), + ) + .into_response() + } +} + +type ApiResult = Result, ApiError>; + +#[derive(Serialize)] +struct VersionResponse { + service: &'static str, + mode: &'static str, + version: &'static str, +} + +async fn version() -> Json { + Json(VersionResponse { + service: "omi-local-backend", + mode: "local", + version: env!("CARGO_PKG_VERSION"), + }) +} + +async fn profile_status(State(state): State) -> ApiResult { + let profile = state + .store + .profile() + .get_or_create_default() + .map_err(ApiError::internal)?; + Ok(Json(json!({ + "mode": "local", + "authenticated": false, + "profile": profile, + "backend": { + "service": "omi-local-backend", + "version": env!("CARGO_PKG_VERSION"), + "data_dir": state.config.data_dir + } + }))) +} + +#[derive(Deserialize)] +struct ListQuery { + limit: Option, +} + +async fn list_conversations( + State(state): State, + Query(query): Query, +) -> ApiResult { + let conversations = state + .store + .conversations() + .list(limit_or_default(query.limit)) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "conversations": conversations }))) +} + +#[derive(Deserialize)] +struct CreateConversationRequest { + id: Option, + session_id: Option, + title: Option, + overview: Option, + started_at: Option>, + metadata: Option, +} + +async fn 
create_conversation( + State(state): State, + Json(request): Json, +) -> ApiResult { + let session_id = request.session_id.unwrap_or_else(|| local_id("session")); + let id = request + .id + .unwrap_or_else(|| deterministic_id("conv", &[&session_id])); + let conversation = state + .store + .conversations() + .create(NewConversation { + id, + session_id, + title: request.title.unwrap_or_default(), + overview: request.overview.unwrap_or_default(), + started_at: request.started_at, + metadata: request.metadata, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "conversation": conversation }))) +} + +async fn get_conversation( + State(state): State, + Path(id): Path, +) -> ApiResult { + let conversation = state + .store + .conversations() + .get(&id) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("conversation"))?; + let transcript_segments = state + .store + .transcripts() + .list_for_conversation(&id) + .map_err(ApiError::internal)?; + Ok(Json(json!({ + "conversation": conversation, + "transcript_segments": transcript_segments + }))) +} + +#[derive(Deserialize)] +struct UpdateConversationRequest { + title: Option, + overview: Option, + status: Option, + ended_at: Option>, + metadata: Option, +} + +async fn update_conversation( + State(state): State, + Path(id): Path, + Json(request): Json, +) -> ApiResult { + let conversation = state + .store + .conversations() + .update( + &id, + UpdateConversation { + title: request.title, + overview: request.overview, + status: request.status, + ended_at: request.ended_at.map(Some), + metadata: request.metadata, + }, + ) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("conversation"))?; + Ok(Json(json!({ "conversation": conversation }))) +} + +async fn delete_conversation( + State(state): State, + Path(id): Path, +) -> Result { + if state + .store + .conversations() + .soft_delete(&id) + .map_err(ApiError::internal)? 
+ { + Ok(StatusCode::NO_CONTENT) + } else { + Err(ApiError::not_found("conversation")) + } +} + +#[derive(Deserialize)] +struct AppendSegmentRequest { + id: Option, + session_id: Option, + speaker_id: Option, + speaker_label: Option, + text: String, + start_ms: i64, + end_ms: i64, + segment_index: Option, + source: Option, + metadata: Option, +} + +async fn append_transcript_segment( + State(state): State, + Path(conversation_id): Path, + Json(request): Json, +) -> ApiResult { + let conversation = state + .store + .conversations() + .get(&conversation_id) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("conversation"))?; + if request.text.trim().is_empty() { + return Err(ApiError::bad_request("transcript segment text is required")); + } + let segment_index = match request.segment_index { + Some(index) => index, + None => state + .store + .transcripts() + .next_segment_index(&conversation_id) + .map_err(ApiError::internal)?, + }; + let id = request.id.unwrap_or_else(|| { + deterministic_id("seg", &[&conversation_id, &segment_index.to_string()]) + }); + let segment = state + .store + .transcripts() + .append(NewTranscriptSegment { + id, + conversation_id, + session_id: request.session_id.unwrap_or(conversation.session_id), + speaker_id: request.speaker_id, + speaker_label: request.speaker_label, + text: request.text, + start_ms: request.start_ms, + end_ms: request.end_ms, + segment_index, + source: request.source, + metadata: request.metadata, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "transcript_segment": segment }))) +} + +async fn finalize_transcript( + State(state): State, + Path(conversation_id): Path, +) -> ApiResult { + state + .store + .conversations() + .get(&conversation_id) + .map_err(ApiError::internal)? 
+ .ok_or_else(|| ApiError::not_found("conversation"))?; + let job = state + .store + .processing_jobs() + .enqueue(NewProcessingJob { + id: local_id("job"), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(3), + payload: Some(json!({ "conversation_id": conversation_id })), + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "processing_job": job }))) +} + +#[derive(Deserialize)] +struct SearchQuery { + q: String, + limit: Option, +} + +async fn search_conversations( + State(state): State, + Query(query): Query, +) -> ApiResult { + if query.q.trim().is_empty() { + return Err(ApiError::bad_request("search query is required")); + } + let results = state + .store + .search() + .conversations(&query.q, limit_or_default(query.limit)) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "results": results }))) +} + +async fn list_memories(State(state): State) -> ApiResult { + let memories = state.store.memories().list().map_err(ApiError::internal)?; + Ok(Json(json!({ "memories": memories }))) +} + +#[derive(Deserialize)] +struct CreateMemoryRequest { + id: Option, + content: String, + category: Option, + conversation_id: Option, + metadata: Option, +} + +async fn create_memory( + State(state): State, + Json(request): Json, +) -> ApiResult { + let memory = state + .store + .memories() + .create(NewMemory { + id: request.id.unwrap_or_else(|| local_id("mem")), + content: request.content, + category: request.category, + conversation_id: request.conversation_id, + metadata: request.metadata, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "memory": memory }))) +} + +async fn get_memory(State(state): State, Path(id): Path) -> ApiResult { + let memory = state + .store + .memories() + .get(&id) + .map_err(ApiError::internal)? 
+ .ok_or_else(|| ApiError::not_found("memory"))?; + Ok(Json(json!({ "memory": memory }))) +} + +#[derive(Deserialize)] +struct UpdateMemoryRequest { + content: Option, + category: Option, + conversation_id: Option, + metadata: Option, +} + +async fn update_memory( + State(state): State, + Path(id): Path, + Json(request): Json, +) -> ApiResult { + let memory = state + .store + .memories() + .update( + &id, + UpdateMemory { + content: request.content, + category: request.category.map(Some), + conversation_id: request.conversation_id.map(Some), + metadata: request.metadata, + }, + ) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("memory"))?; + Ok(Json(json!({ "memory": memory }))) +} + +async fn delete_memory( + State(state): State, + Path(id): Path, +) -> Result { + if state + .store + .memories() + .soft_delete(&id) + .map_err(ApiError::internal)? + { + Ok(StatusCode::NO_CONTENT) + } else { + Err(ApiError::not_found("memory")) + } +} + +async fn list_action_items(State(state): State) -> ApiResult { + let action_items = state + .store + .action_items() + .list() + .map_err(ApiError::internal)?; + Ok(Json(json!({ "action_items": action_items }))) +} + +#[derive(Deserialize)] +struct CreateActionItemRequest { + id: Option, + conversation_id: Option, + title: String, + description: Option, + status: Option, + due_at: Option>, + metadata: Option, +} + +async fn create_action_item( + State(state): State, + Json(request): Json, +) -> ApiResult { + let action_item = state + .store + .action_items() + .create(NewActionItem { + id: request.id.unwrap_or_else(|| local_id("act")), + conversation_id: request.conversation_id, + title: request.title, + description: request.description, + status: request.status, + due_at: request.due_at, + metadata: request.metadata, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "action_item": action_item }))) +} + +async fn get_action_item( + State(state): State, + Path(id): Path, +) -> ApiResult { + let action_item 
= state + .store + .action_items() + .get(&id) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("action item"))?; + Ok(Json(json!({ "action_item": action_item }))) +} + +#[derive(Deserialize)] +struct UpdateActionItemRequest { + conversation_id: Option, + title: Option, + description: Option, + status: Option, + due_at: Option>, + metadata: Option, +} + +async fn update_action_item( + State(state): State, + Path(id): Path, + Json(request): Json, +) -> ApiResult { + let action_item = state + .store + .action_items() + .update( + &id, + UpdateActionItem { + conversation_id: request.conversation_id.map(Some), + title: request.title, + description: request.description, + status: request.status, + due_at: request.due_at.map(Some), + metadata: request.metadata, + }, + ) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("action item"))?; + Ok(Json(json!({ "action_item": action_item }))) +} + +async fn delete_action_item( + State(state): State, + Path(id): Path, +) -> Result { + if state + .store + .action_items() + .soft_delete(&id) + .map_err(ApiError::internal)? 
+ { + Ok(StatusCode::NO_CONTENT) + } else { + Err(ApiError::not_found("action item")) + } +} + +async fn get_profile(State(state): State) -> ApiResult { + let profile = state + .store + .profile() + .get_or_create_default() + .map_err(ApiError::internal)?; + Ok(Json(json!({ "profile": profile }))) +} + +#[derive(Deserialize)] +struct UpdateProfileRequest { + display_name: Option, + timezone: Option, + locale: Option, + metadata: Option, +} + +async fn update_profile( + State(state): State, + Json(request): Json, +) -> ApiResult { + let profile = state + .store + .profile() + .upsert(UpdateProfile { + display_name: request.display_name, + timezone: request.timezone, + locale: request.locale, + metadata: request.metadata, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "profile": profile }))) +} + +async fn list_settings(State(state): State) -> ApiResult { + let settings = state.store.settings().list().map_err(ApiError::internal)?; + Ok(Json(json!({ "settings": settings }))) +} + +async fn update_settings( + State(state): State, + Json(values): Json>, +) -> ApiResult { + let settings = state + .store + .settings() + .upsert_many(values) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "settings": settings }))) +} + +async fn list_processing_jobs(State(state): State) -> ApiResult { + let jobs = state + .store + .processing_jobs() + .list() + .map_err(ApiError::internal)?; + Ok(Json(json!({ "processing_jobs": jobs }))) +} + +async fn get_processing_job( + State(state): State, + Path(id): Path, +) -> ApiResult { + let job = state + .store + .processing_jobs() + .get(&id) + .map_err(ApiError::internal)? 
+ .ok_or_else(|| ApiError::not_found("processing job"))?; + Ok(Json(json!({ "processing_job": job }))) +} + +async fn processing_status(State(state): State) -> ApiResult { + let jobs = state + .store + .processing_jobs() + .list() + .map_err(ApiError::internal)?; + let mut queued = 0; + let mut running = 0; + let mut completed = 0; + let mut failed = 0; + for job in jobs { + match job.status { + crate::storage::ProcessingJobStatus::Queued => queued += 1, + crate::storage::ProcessingJobStatus::Running => running += 1, + crate::storage::ProcessingJobStatus::Completed => completed += 1, + crate::storage::ProcessingJobStatus::Failed => failed += 1, + } + } + Ok(Json(json!({ + "queued": queued, + "running": running, + "completed": completed, + "failed": failed + }))) +} + +fn limit_or_default(limit: Option) -> i64 { + limit.unwrap_or(50).clamp(1, 200) +} + +fn local_id(prefix: &str) -> String { + let now = Utc::now() + .timestamp_nanos_opt() + .unwrap_or_else(|| Utc::now().timestamp_micros() * 1000); + deterministic_id(prefix, &[&now.to_string()]) +} diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 691e87c899d..091285da326 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -300,6 +300,65 @@ pub struct ProcessingJob { pub sync_state: String, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Memory { + pub id: String, + pub content: String, + pub category: Option, + pub conversation_id: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, + pub metadata_json: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ActionItem { + pub id: String, + pub conversation_id: Option, + pub title: String, + pub description: String, + pub status: String, + pub due_at: Option>, + pub completed_at: Option>, + pub created_at: 
DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, + pub metadata_json: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct LocalProfile { + pub id: String, + pub display_name: String, + pub timezone: Option, + pub locale: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, + pub metadata_json: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct LocalSetting { + pub key: String, + pub value_json: String, + pub updated_at: DateTime, + pub deleted_at: Option>, + pub cloud_id: Option, + pub sync_version: i64, + pub sync_state: String, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ProcessingJobStatus { @@ -323,7 +382,7 @@ impl Store { } #[cfg(test)] - fn open_in_memory() -> Result { + pub(crate) fn open_in_memory() -> Result { let conn = Connection::open_in_memory().context("failed to open in-memory SQLite store")?; configure_connection(&conn)?; run_migrations(&conn)?; @@ -356,6 +415,30 @@ impl Store { conn: Arc::clone(&self.conn), } } + + pub fn memories(&self) -> MemoryRepository { + MemoryRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn action_items(&self) -> ActionItemRepository { + ActionItemRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn profile(&self) -> ProfileRepository { + ProfileRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn settings(&self) -> SettingsRepository { + SettingsRepository { + conn: Arc::clone(&self.conn), + } + } } pub struct ConversationRepository { @@ -428,6 +511,83 @@ impl ConversationRepository { .optional() .context("failed to fetch conversation") } + + pub fn list(&self, limit: i64) -> Result> { + let conn = self.conn.lock().expect("SQLite connection 
mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + FROM conversations + WHERE deleted_at IS NULL + ORDER BY updated_at DESC + LIMIT ?1 + "#, + ) + .context("failed to prepare conversation list query")?; + let rows = stmt + .query_map(params![limit], map_conversation) + .context("failed to list conversations")?; + collect_rows(rows) + } + + pub fn update(&self, id: &str, update: UpdateConversation) -> Result> { + let Some(mut conversation) = self.get(id)? else { + return Ok(None); + }; + if let Some(title) = update.title { + conversation.title = title; + } + if let Some(overview) = update.overview { + conversation.overview = overview; + } + if let Some(status) = update.status { + conversation.status = status; + } + if let Some(ended_at) = update.ended_at { + conversation.ended_at = ended_at; + } + if let Some(metadata) = update.metadata { + conversation.metadata_json = json_or_empty_object(Some(metadata))?; + } + conversation.updated_at = Utc::now(); + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + UPDATE conversations + SET title = ?2, overview = ?3, status = ?4, ended_at = ?5, updated_at = ?6, + metadata_json = ?7, sync_version = sync_version + 1 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![ + conversation.id, + conversation.title, + conversation.overview, + conversation.status, + conversation.ended_at, + conversation.updated_at, + conversation.metadata_json + ], + ) + .context("failed to update conversation")?; + + drop(conn); + self.get(id) + } + + pub fn soft_delete(&self, id: &str) -> Result { + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn + .execute( + "UPDATE conversations SET deleted_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL", + params![id, 
now], + ) + .context("failed to delete conversation")?; + Ok(changed > 0) + } } pub struct TranscriptRepository { @@ -513,6 +673,16 @@ impl TranscriptRepository { collect_rows(rows) } + + pub fn next_segment_index(&self, conversation_id: &str) -> Result { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + "SELECT COALESCE(MAX(segment_index) + 1, 0) FROM transcript_segments WHERE conversation_id = ?1", + params![conversation_id], + |row| row.get(0), + ) + .context("failed to fetch next segment index") + } } pub struct ProcessingJobRepository { @@ -580,12 +750,467 @@ impl ProcessingJobRepository { Ok(job) } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, kind, status, target_conversation_id, retry_count, max_retries, last_error, + payload_json, result_json, queued_at, started_at, completed_at, failed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM processing_jobs + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_processing_job, + ) + .optional() + .context("failed to fetch processing job") + } + + pub fn list(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, kind, status, target_conversation_id, retry_count, max_retries, last_error, + payload_json, result_json, queued_at, started_at, completed_at, failed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM processing_jobs + WHERE deleted_at IS NULL + ORDER BY queued_at DESC + "#, + ) + .context("failed to prepare processing job list query")?; + let rows = stmt + .query_map([], map_processing_job) + .context("failed to list processing jobs")?; + collect_rows(rows) + } } pub struct SearchRepository { conn: Arc>, } +pub struct MemoryRepository { + conn: Arc>, +} + +impl MemoryRepository { + pub 
fn create(&self, new: NewMemory) -> Result { + let now = Utc::now(); + let memory = Memory { + id: new.id, + content: new.content, + category: new.category, + conversation_id: new.conversation_id, + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + metadata_json: json_or_empty_object(new.metadata)?, + }; + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO memories ( + id, content, category, conversation_id, created_at, updated_at, deleted_at, + cloud_id, sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11) + "#, + params![ + memory.id, + memory.content, + memory.category, + memory.conversation_id, + memory.created_at, + memory.updated_at, + memory.deleted_at, + memory.cloud_id, + memory.sync_version, + memory.sync_state, + memory.metadata_json + ], + ) + .context("failed to create memory")?; + Ok(memory) + } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, content, category, conversation_id, created_at, updated_at, deleted_at, + cloud_id, sync_version, sync_state, metadata_json + FROM memories + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_memory, + ) + .optional() + .context("failed to fetch memory") + } + + pub fn list(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, content, category, conversation_id, created_at, updated_at, deleted_at, + cloud_id, sync_version, sync_state, metadata_json + FROM memories + WHERE deleted_at IS NULL + ORDER BY updated_at DESC + "#, + ) + .context("failed to prepare memory list query")?; + let rows = stmt + .query_map([], map_memory) + .context("failed to list memories")?; + collect_rows(rows) + } + + pub fn update(&self, id: &str, 
update: UpdateMemory) -> Result> { + let Some(mut memory) = self.get(id)? else { + return Ok(None); + }; + if let Some(content) = update.content { + memory.content = content; + } + if let Some(category) = update.category { + memory.category = category; + } + if let Some(conversation_id) = update.conversation_id { + memory.conversation_id = conversation_id; + } + if let Some(metadata) = update.metadata { + memory.metadata_json = json_or_empty_object(Some(metadata))?; + } + memory.updated_at = Utc::now(); + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + UPDATE memories + SET content = ?2, category = ?3, conversation_id = ?4, updated_at = ?5, + metadata_json = ?6, sync_version = sync_version + 1 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![ + memory.id, + memory.content, + memory.category, + memory.conversation_id, + memory.updated_at, + memory.metadata_json + ], + ) + .context("failed to update memory")?; + drop(conn); + self.get(id) + } + + pub fn soft_delete(&self, id: &str) -> Result { + soft_delete_by_id(&self.conn, "memories", id, "memory") + } +} + +pub struct ActionItemRepository { + conn: Arc>, +} + +impl ActionItemRepository { + pub fn create(&self, new: NewActionItem) -> Result { + let now = Utc::now(); + let action_item = ActionItem { + id: new.id, + conversation_id: new.conversation_id, + title: new.title, + description: new.description.unwrap_or_default(), + status: new.status.unwrap_or_else(|| "open".to_string()), + due_at: new.due_at, + completed_at: None, + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + metadata_json: json_or_empty_object(new.metadata)?, + }; + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO action_items ( + id, conversation_id, title, description, status, due_at, completed_at, created_at, + updated_at, deleted_at, cloud_id, 
sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14) + "#, + params![ + action_item.id, + action_item.conversation_id, + action_item.title, + action_item.description, + action_item.status, + action_item.due_at, + action_item.completed_at, + action_item.created_at, + action_item.updated_at, + action_item.deleted_at, + action_item.cloud_id, + action_item.sync_version, + action_item.sync_state, + action_item.metadata_json + ], + ) + .context("failed to create action item")?; + Ok(action_item) + } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, conversation_id, title, description, status, due_at, completed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state, + metadata_json + FROM action_items + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_action_item, + ) + .optional() + .context("failed to fetch action item") + } + + pub fn list(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, conversation_id, title, description, status, due_at, completed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state, + metadata_json + FROM action_items + WHERE deleted_at IS NULL + ORDER BY updated_at DESC + "#, + ) + .context("failed to prepare action item list query")?; + let rows = stmt + .query_map([], map_action_item) + .context("failed to list action items")?; + collect_rows(rows) + } + + pub fn update(&self, id: &str, update: UpdateActionItem) -> Result> { + let Some(mut action_item) = self.get(id)? 
else { + return Ok(None); + }; + if let Some(title) = update.title { + action_item.title = title; + } + if let Some(description) = update.description { + action_item.description = description; + } + if let Some(status) = update.status { + action_item.status = status; + if action_item.status == "completed" && action_item.completed_at.is_none() { + action_item.completed_at = Some(Utc::now()); + } + } + if let Some(due_at) = update.due_at { + action_item.due_at = due_at; + } + if let Some(conversation_id) = update.conversation_id { + action_item.conversation_id = conversation_id; + } + if let Some(metadata) = update.metadata { + action_item.metadata_json = json_or_empty_object(Some(metadata))?; + } + action_item.updated_at = Utc::now(); + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + UPDATE action_items + SET conversation_id = ?2, title = ?3, description = ?4, status = ?5, due_at = ?6, + completed_at = ?7, updated_at = ?8, metadata_json = ?9, + sync_version = sync_version + 1 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![ + action_item.id, + action_item.conversation_id, + action_item.title, + action_item.description, + action_item.status, + action_item.due_at, + action_item.completed_at, + action_item.updated_at, + action_item.metadata_json + ], + ) + .context("failed to update action item")?; + drop(conn); + self.get(id) + } + + pub fn soft_delete(&self, id: &str) -> Result { + soft_delete_by_id(&self.conn, "action_items", id, "action item") + } +} + +pub struct ProfileRepository { + conn: Arc>, +} + +impl ProfileRepository { + pub fn get_or_create_default(&self) -> Result { + if let Some(profile) = self.get("local")? 
{ + return Ok(profile); + } + self.upsert(UpdateProfile { + display_name: Some(String::new()), + timezone: None, + locale: None, + metadata: None, + }) + } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, display_name, timezone, locale, created_at, updated_at, deleted_at, + cloud_id, sync_version, sync_state, metadata_json + FROM local_profiles + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_local_profile, + ) + .optional() + .context("failed to fetch local profile") + } + + pub fn upsert(&self, update: UpdateProfile) -> Result { + let now = Utc::now(); + let current = self.get("local")?; + let display_name = update + .display_name + .or_else(|| current.as_ref().map(|profile| profile.display_name.clone())) + .unwrap_or_default(); + let timezone = update.timezone.or_else(|| { + current + .as_ref() + .and_then(|profile| profile.timezone.clone()) + }); + let locale = update + .locale + .or_else(|| current.as_ref().and_then(|profile| profile.locale.clone())); + let metadata_json = match update.metadata { + Some(metadata) => json_or_empty_object(Some(metadata))?, + None => current + .as_ref() + .map(|profile| profile.metadata_json.clone()) + .unwrap_or_else(|| "{}".to_string()), + }; + let created_at = current + .as_ref() + .map(|profile| profile.created_at) + .unwrap_or(now); + + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO local_profiles ( + id, display_name, timezone, locale, created_at, updated_at, deleted_at, cloud_id, + sync_version, sync_state, metadata_json + ) + VALUES ('local', ?1, ?2, ?3, ?4, ?5, NULL, NULL, 0, 'local', ?6) + ON CONFLICT(id) DO UPDATE SET + display_name = excluded.display_name, + timezone = excluded.timezone, + locale = excluded.locale, + updated_at = excluded.updated_at, + metadata_json = excluded.metadata_json, + sync_version = local_profiles.sync_version 
+ 1 + "#, + params![ + display_name, + timezone, + locale, + created_at, + now, + metadata_json + ], + ) + .context("failed to upsert local profile")?; + drop(conn); + self.get("local")? + .context("local profile missing after upsert") + } +} + +pub struct SettingsRepository { + conn: Arc>, +} + +impl SettingsRepository { + pub fn list(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT key, value_json, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM local_settings + WHERE deleted_at IS NULL + ORDER BY key ASC + "#, + ) + .context("failed to prepare settings list query")?; + let rows = stmt + .query_map([], map_local_setting) + .context("failed to list settings")?; + collect_rows(rows) + } + + pub fn upsert_many( + &self, + values: serde_json::Map, + ) -> Result> { + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + for (key, value) in values { + let value_json = + serde_json::to_string(&value).context("failed to serialize setting value")?; + conn.execute( + r#" + INSERT INTO local_settings (key, value_json, updated_at, deleted_at, cloud_id, sync_version, sync_state) + VALUES (?1, ?2, ?3, NULL, NULL, 0, 'local') + ON CONFLICT(key) DO UPDATE SET + value_json = excluded.value_json, + updated_at = excluded.updated_at, + deleted_at = NULL, + sync_version = local_settings.sync_version + 1 + "#, + params![key, value_json, now], + ) + .context("failed to upsert local setting")?; + } + drop(conn); + self.list() + } +} + impl SearchRepository { pub fn conversations(&self, query: &str, limit: i64) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); @@ -654,6 +1279,61 @@ pub struct NewProcessingJob { pub payload: Option, } +#[derive(Debug, Clone)] +pub struct UpdateConversation { + pub title: Option, + pub overview: Option, + pub status: Option, + pub ended_at: Option>>, + pub metadata: 
Option, +} + +#[derive(Debug, Clone)] +pub struct NewMemory { + pub id: String, + pub content: String, + pub category: Option, + pub conversation_id: Option, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateMemory { + pub content: Option, + pub category: Option>, + pub conversation_id: Option>, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct NewActionItem { + pub id: String, + pub conversation_id: Option, + pub title: String, + pub description: Option, + pub status: Option, + pub due_at: Option>, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateActionItem { + pub conversation_id: Option>, + pub title: Option, + pub description: Option, + pub status: Option, + pub due_at: Option>>, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct UpdateProfile { + pub display_name: Option, + pub timezone: Option, + pub locale: Option, + pub metadata: Option, +} + pub fn deterministic_id(prefix: &str, parts: &[&str]) -> String { let mut hasher = Sha256::new(); hasher.update(prefix.as_bytes()); @@ -773,6 +1453,113 @@ fn map_transcript_segment(row: &rusqlite::Row<'_>) -> rusqlite::Result) -> rusqlite::Result { + let status: String = row.get(2)?; + Ok(ProcessingJob { + id: row.get(0)?, + kind: row.get(1)?, + status: ProcessingJobStatus::from_db(&status), + target_conversation_id: row.get(3)?, + retry_count: row.get(4)?, + max_retries: row.get(5)?, + last_error: row.get(6)?, + payload_json: row.get(7)?, + result_json: row.get(8)?, + queued_at: row.get(9)?, + started_at: row.get(10)?, + completed_at: row.get(11)?, + failed_at: row.get(12)?, + created_at: row.get(13)?, + updated_at: row.get(14)?, + deleted_at: row.get(15)?, + cloud_id: row.get(16)?, + sync_version: row.get(17)?, + sync_state: row.get(18)?, + }) +} + +fn map_memory(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(Memory { + id: row.get(0)?, + content: row.get(1)?, + category: row.get(2)?, + conversation_id: row.get(3)?, + created_at: 
row.get(4)?, + updated_at: row.get(5)?, + deleted_at: row.get(6)?, + cloud_id: row.get(7)?, + sync_version: row.get(8)?, + sync_state: row.get(9)?, + metadata_json: row.get(10)?, + }) +} + +fn map_action_item(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(ActionItem { + id: row.get(0)?, + conversation_id: row.get(1)?, + title: row.get(2)?, + description: row.get(3)?, + status: row.get(4)?, + due_at: row.get(5)?, + completed_at: row.get(6)?, + created_at: row.get(7)?, + updated_at: row.get(8)?, + deleted_at: row.get(9)?, + cloud_id: row.get(10)?, + sync_version: row.get(11)?, + sync_state: row.get(12)?, + metadata_json: row.get(13)?, + }) +} + +fn map_local_profile(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(LocalProfile { + id: row.get(0)?, + display_name: row.get(1)?, + timezone: row.get(2)?, + locale: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + deleted_at: row.get(6)?, + cloud_id: row.get(7)?, + sync_version: row.get(8)?, + sync_state: row.get(9)?, + metadata_json: row.get(10)?, + }) +} + +fn map_local_setting(row: &rusqlite::Row<'_>) -> rusqlite::Result { + Ok(LocalSetting { + key: row.get(0)?, + value_json: row.get(1)?, + updated_at: row.get(2)?, + deleted_at: row.get(3)?, + cloud_id: row.get(4)?, + sync_version: row.get(5)?, + sync_state: row.get(6)?, + }) +} + +fn soft_delete_by_id( + conn: &Arc>, + table: &str, + id: &str, + entity_name: &str, +) -> Result { + let now = Utc::now(); + let conn = conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn + .execute( + &format!( + "UPDATE {table} SET deleted_at = ?2, updated_at = ?2 WHERE id = ?1 AND deleted_at IS NULL" + ), + params![id, now], + ) + .with_context(|| format!("failed to delete {entity_name}"))?; + Ok(changed > 0) +} + impl ProcessingJobStatus { fn as_str(&self) -> &'static str { match self { @@ -782,6 +1569,15 @@ impl ProcessingJobStatus { Self::Failed => "failed", } } + + fn from_db(value: &str) -> Self { + match value { + "running" => 
Self::Running, + "completed" => Self::Completed, + "failed" => Self::Failed, + _ => Self::Queued, + } + } } #[cfg(test)] From 8cb9ac035b833195568ab51d0137c248bac97966 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:24:36 +0700 Subject: [PATCH 04/58] feat(desktop): harden APIClient with structured routing and transcription retry - Major APIClient.swift refactor: structured routing layer (+609/-1) - Add DesktopBackendEnvironment with configurable endpoint resolution (+62) - Harden TranscriptionRetryService with exponential backoff (+35/-19) - Add APIClientRoutingTests with comprehensive route coverage (+91) - Update .env.example with new backend environment variables --- desktop/.env.example | 8 + desktop/Desktop/Sources/APIClient.swift | 609 +++++++++++++++++- .../Sources/DesktopBackendEnvironment.swift | 62 ++ .../Sources/TranscriptionRetryService.swift | 35 +- .../Desktop/Tests/APIClientRoutingTests.swift | 91 +++ 5 files changed, 803 insertions(+), 2 deletions(-) diff --git a/desktop/.env.example b/desktop/.env.example index 5cfaed873bc..1a37940c8a7 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -24,6 +24,14 @@ OMI_DESKTOP_API_URL=http://localhost:10201 # WARNING: Do NOT set this to OMI_DESKTOP_API_URL — that points to the Rust desktop-backend. OMI_PYTHON_API_URL=https://api.omi.me +# Desktop app data backend mode for MVP local-first flows. 
+# cloud: Omi-hosted Python backend (default) +# local: local daemon; start it with `cd desktop/local-backend && cargo run`, +# then verify `curl http://127.0.0.1:8765/health` +# custom: custom remote URL from OMI_PYTHON_API_URL +# OMI_DESKTOP_BACKEND_MODE=cloud +# OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 + # Firebase Web API key — fetched from backend via /v1/config/api-keys # Only set this for local dev without a backend running # FIREBASE_API_KEY= diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 3af96ffa442..ddeb52005b7 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -20,6 +20,14 @@ actor APIClient { return "" } + var selectedBackendTarget: DesktopBackendEnvironment.BackendTarget { + DesktopBackendEnvironment.selectedBackendTarget + } + + var isUsingLocalDaemon: Bool { + selectedBackendTarget.mode == .localDaemon + } + let session: URLSession private let decoder: JSONDecoder @@ -148,6 +156,22 @@ actor APIClient { return try await performRequest(request) } + func put( + _ endpoint: String, + body: B, + requireAuth: Bool = true, + customBaseURL: String? = nil + ) async throws -> T { + let base = customBaseURL ?? baseURL + let url = URL(string: base + endpoint)! 
+ var request = URLRequest(url: url) + request.httpMethod = "PUT" + request.allHTTPHeaderFields = try await buildHeaders(requireAuth: requireAuth) + request.httpBody = try JSONEncoder().encode(body) + + return try await performRequest(request) + } + func delete( _ endpoint: String, requireAuth: Bool = true, @@ -174,6 +198,38 @@ actor APIClient { } } + func checkSelectedBackendHealth() async throws -> LocalDaemonHealth { + let target = selectedBackendTarget + return try await get("health", requireAuth: target.requiresAuth, customBaseURL: target.baseURL) + } + + func getSelectedBackendSettings() async throws -> [LocalDaemonSetting] { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + return [] + } + let response: LocalDaemonSettingsResponse = try await get( + "v1/settings", + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.settings + } + + func updateSelectedBackendSettings(_ values: [String: String]) async throws -> [LocalDaemonSetting] { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + return [] + } + let response: LocalDaemonSettingsResponse = try await put( + "v1/settings", + body: values, + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.settings + } + // MARK: - Request Execution private func performRequest(_ request: URLRequest) async throws -> T { @@ -264,10 +320,54 @@ enum APIError: LocalizedError { } } +struct LocalDaemonHealth: Decodable, Equatable { + let service: String + let mode: String + let version: String + let bindAddress: String? + let dataDir: String? 
+ + enum CodingKeys: String, CodingKey { + case service, mode, version + case bindAddress = "bind_addr" + case dataDir = "data_dir" + } +} + +struct LocalDaemonSettingsResponse: Decodable { + let settings: [LocalDaemonSetting] +} + +struct LocalDaemonSetting: Decodable, Equatable { + let key: String + let valueJson: String + let updatedAt: Date + + enum CodingKeys: String, CodingKey { + case key + case valueJson = "value_json" + case updatedAt = "updated_at" + } +} + // MARK: - Conversation API extension APIClient { + private var mvpBackendTarget: DesktopBackendEnvironment.BackendTarget { + selectedBackendTarget + } + + private static func isoString(_ date: Date) -> String { + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + return formatter.string(from: date) + } + + private static func queryValue(_ value: String) -> String { + value.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? value + } + /// Fetches conversations from the API with optional filtering func getConversations( limit: Int = 50, @@ -279,6 +379,17 @@ extension APIClient { folderId: String? = nil, starred: Bool? 
= nil ) async throws -> [ServerConversation] { + let target = mvpBackendTarget + if target.mode == .localDaemon { + let endpoint = "v1/conversations?limit=\(limit)" + let response: LocalConversationsResponse = try await get( + endpoint, + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.conversations.map { $0.toServerConversation(transcriptSegments: []) } + } + var queryItems: [String] = [ "limit=\(limit)", "offset=\(offset)", @@ -314,11 +425,29 @@ extension APIClient { /// Fetches a single conversation by ID func getConversation(id: String) async throws -> ServerConversation { + let target = mvpBackendTarget + if target.mode == .localDaemon { + let response: LocalConversationResponse = try await get( + "v1/conversations/\(id)", + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.conversation.toServerConversation( + transcriptSegments: response.transcriptSegments.map { $0.toTranscriptSegment() } + ) + } + return try await get("v1/conversations/\(id)") } /// Deletes a conversation by ID func deleteConversation(id: String) async throws { + let target = mvpBackendTarget + if target.mode == .localDaemon { + try await delete("v1/conversations/\(id)", requireAuth: false, customBaseURL: target.baseURL) + return + } + try await delete("v1/conversations/\(id)") } @@ -373,6 +502,17 @@ extension APIClient { let title: String } + let target = mvpBackendTarget + if target.mode == .localDaemon { + let _: LocalConversationEnvelope = try await patch( + "v1/conversations/\(id)", + body: TitleUpdate(title: title), + requireAuth: false, + customBaseURL: target.baseURL + ) + return + } + let url = URL(string: baseURL + "v1/conversations/\(id)")! 
var request = URLRequest(url: url) request.httpMethod = "PATCH" @@ -394,6 +534,20 @@ extension APIClient { perPage: Int = 10, includeDiscarded: Bool = false ) async throws -> ConversationSearchResult { + let target = mvpBackendTarget + if target.mode == .localDaemon { + let response: LocalConversationSearchResponse = try await get( + "v1/search/conversations?q=\(Self.queryValue(query))&limit=\(perPage)", + requireAuth: false, + customBaseURL: target.baseURL + ) + return ConversationSearchResult( + items: response.results.map { $0.toServerConversation() }, + currentPage: 1, + totalPages: 1 + ) + } + struct SearchRequest: Encodable { let query: String let page: Int @@ -422,6 +576,16 @@ extension APIClient { includeDiscarded: Bool = false, statuses: [ConversationStatus] = [.completed, .processing] ) async throws -> Int { + let target = mvpBackendTarget + if target.mode == .localDaemon { + let response: LocalConversationsResponse = try await get( + "v1/conversations?limit=10000", + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.conversations.count + } + if let cache = conversationsCountCache, let time = conversationsCountCacheTime, Date().timeIntervalSince(time) < 5 { @@ -449,6 +613,103 @@ extension APIClient { return response.count } + func createLocalDaemonConversation( + sessionId: String? = nil, + title: String? = nil, + overview: String? = nil, + startedAt: Date? = nil + ) async throws -> ServerConversation { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.httpError(statusCode: 400) + } + + struct Request: Encodable { + let sessionId: String? + let title: String? + let overview: String? + let startedAt: String? 
+ + enum CodingKeys: String, CodingKey { + case title, overview + case sessionId = "session_id" + case startedAt = "started_at" + } + } + + let response: LocalConversationEnvelope = try await post( + "v1/conversations", + body: Request( + sessionId: sessionId, + title: title, + overview: overview, + startedAt: startedAt.map(Self.isoString) + ), + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.conversation.toServerConversation(transcriptSegments: []) + } + + func appendLocalDaemonTranscriptSegment( + conversationId: String, + segment: TranscriptionSegmentRecord + ) async throws { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.httpError(statusCode: 400) + } + + struct Request: Encodable { + let id: String? + let speakerId: String? + let speakerLabel: String? + let text: String + let startMs: Int64 + let endMs: Int64 + let segmentIndex: Int + let source: String + + enum CodingKeys: String, CodingKey { + case id, text, source + case speakerId = "speaker_id" + case speakerLabel = "speaker_label" + case startMs = "start_ms" + case endMs = "end_ms" + case segmentIndex = "segment_index" + } + } + + let _: LocalTranscriptSegmentEnvelope = try await post( + "v1/conversations/\(conversationId)/transcript-segments", + body: Request( + id: segment.segmentId, + speakerId: String(segment.speaker), + speakerLabel: segment.speakerLabel, + text: segment.text, + startMs: Int64((segment.startTime * 1000).rounded()), + endMs: Int64((segment.endTime * 1000).rounded()), + segmentIndex: segment.segmentOrder, + source: "desktop" + ), + requireAuth: false, + customBaseURL: target.baseURL + ) + } + + func finalizeLocalDaemonTranscript(conversationId: String) async throws { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.httpError(statusCode: 400) + } + + let _: LocalProcessingJobEnvelope = try await post( + "v1/conversations/\(conversationId)/finalize-transcript", + 
requireAuth: false, + customBaseURL: target.baseURL + ) + } + /// Gets the count of AI chat messages from PostHog func getChatMessageCount() async throws -> Int { struct CountResponse: Decodable { @@ -961,6 +1222,170 @@ struct ConversationSearchResult: Codable { } } +private struct LocalConversationsResponse: Decodable { + let conversations: [LocalConversation] +} + +private struct LocalConversationEnvelope: Decodable { + let conversation: LocalConversation +} + +private struct LocalConversationResponse: Decodable { + let conversation: LocalConversation + let transcriptSegments: [LocalTranscriptSegment] + + enum CodingKeys: String, CodingKey { + case conversation + case transcriptSegments = "transcript_segments" + } +} + +private struct LocalConversationSearchResponse: Decodable { + let results: [LocalConversationSearchResult] +} + +private struct LocalTranscriptSegmentEnvelope: Decodable { + let transcriptSegment: LocalTranscriptSegment + + enum CodingKeys: String, CodingKey { + case transcriptSegment = "transcript_segment" + } +} + +private struct LocalProcessingJobEnvelope: Decodable { + let processingJob: LocalProcessingJob + + enum CodingKeys: String, CodingKey { + case processingJob = "processing_job" + } +} + +private struct LocalProcessingJob: Decodable { + let id: String +} + +private struct LocalConversation: Decodable { + let id: String + let sessionId: String + let title: String + let overview: String + let status: String + let startedAt: Date + let endedAt: Date? + let createdAt: Date + let deletedAt: Date? 
+ + enum CodingKeys: String, CodingKey { + case id, title, overview, status + case sessionId = "session_id" + case startedAt = "started_at" + case endedAt = "ended_at" + case createdAt = "created_at" + case deletedAt = "deleted_at" + } + + func toServerConversation(transcriptSegments: [TranscriptSegment]) -> ServerConversation { + ServerConversation( + id: id, + createdAt: createdAt, + startedAt: startedAt, + finishedAt: endedAt, + structured: Structured( + title: title, + overview: overview, + emoji: "", + category: "other", + actionItems: [], + events: [] + ), + transcriptSegments: transcriptSegments, + geolocation: nil, + photos: [], + appsResults: [], + source: .desktop, + language: nil, + status: status == "open" ? .inProgress : (ConversationStatus(rawValue: status) ?? .completed), + discarded: false, + deleted: deletedAt != nil, + isLocked: false, + starred: false, + folderId: nil, + inputDeviceName: nil + ) + } +} + +private struct LocalTranscriptSegment: Decodable { + let id: String + let speakerId: String? + let speakerLabel: String? + let text: String + let startMs: Int64 + let endMs: Int64 + + enum CodingKeys: String, CodingKey { + case id, text + case speakerId = "speaker_id" + case speakerLabel = "speaker_label" + case startMs = "start_ms" + case endMs = "end_ms" + } + + func toTranscriptSegment() -> TranscriptSegment { + TranscriptSegment( + id: id, + backendId: id, + text: text, + speaker: speakerLabel ?? 
speakerId, + isUser: false, + personId: nil, + start: Double(startMs) / 1000, + end: Double(endMs) / 1000 + ) + } +} + +private struct LocalConversationSearchResult: Decodable { + let conversationId: String + let title: String + let overview: String + + enum CodingKeys: String, CodingKey { + case title, overview + case conversationId = "conversation_id" + } + + func toServerConversation() -> ServerConversation { + ServerConversation( + id: conversationId, + createdAt: Date(), + startedAt: nil, + finishedAt: nil, + structured: Structured( + title: title, + overview: overview, + emoji: "", + category: "other", + actionItems: [], + events: [] + ), + transcriptSegments: [], + geolocation: nil, + photos: [], + appsResults: [], + source: .desktop, + language: nil, + status: .completed, + discarded: false, + deleted: false, + isLocked: false, + starred: false, + folderId: nil, + inputDeviceName: nil + ) + } +} + // MARK: - Merge Response /// Response from merge conversations API @@ -1139,6 +1564,56 @@ struct ServerMemory: Codable, Identifiable { headline = try container.decodeIfPresent(String.self, forKey: .headline) } + init( + id: String, + content: String, + category: MemoryCategory, + createdAt: Date, + updatedAt: Date, + conversationId: String? = nil, + reviewed: Bool = false, + userReview: Bool? = nil, + visibility: String = "private", + manuallyAdded: Bool = false, + scoring: String? = nil, + source: String? = nil, + confidence: Double? = nil, + sourceApp: String? = nil, + contextSummary: String? = nil, + isRead: Bool = false, + isDismissed: Bool = false, + tags: [String] = [], + reasoning: String? = nil, + currentActivity: String? = nil, + inputDeviceName: String? = nil, + windowTitle: String? = nil, + headline: String? 
= nil + ) { + self.id = id + self.content = content + self.category = category + self.createdAt = createdAt + self.updatedAt = updatedAt + self.conversationId = conversationId + self.reviewed = reviewed + self.userReview = userReview + self.visibility = visibility + self.manuallyAdded = manuallyAdded + self.scoring = scoring + self.source = source + self.confidence = confidence + self.sourceApp = sourceApp + self.contextSummary = contextSummary + self.isRead = isRead + self.isDismissed = isDismissed + self.tags = tags + self.reasoning = reasoning + self.currentActivity = currentActivity + self.inputDeviceName = inputDeviceName + self.windowTitle = windowTitle + self.headline = headline + } + var isPublic: Bool { visibility == "public" } @@ -1260,6 +1735,16 @@ extension APIClient { tags: [String]? = nil, includeDismissed: Bool = false ) async throws -> [ServerMemory] { + let target = selectedBackendTarget + if target.mode == .localDaemon { + let response: LocalMemoriesResponse = try await get( + "v1/memories", + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.memories.map { $0.toServerMemory() } + } + var endpoint = "v3/memories?limit=\(limit)&offset=\(offset)" if let category = category { endpoint += "&category=\(category)" @@ -1288,6 +1773,21 @@ extension APIClient { windowTitle: String? = nil, headline: String? = nil ) async throws -> CreateMemoryResponse { + let target = selectedBackendTarget + if target.mode == .localDaemon { + struct LocalCreateRequest: Encodable { + let content: String + let category: String? 
+ } + let response: LocalMemoryEnvelope = try await post( + "v1/memories", + body: LocalCreateRequest(content: content, category: category?.rawValue), + requireAuth: false, + customBaseURL: target.baseURL + ) + return CreateMemoryResponse(id: response.memory.id, message: nil) + } + struct CreateRequest: Encodable { let content: String let visibility: String @@ -1450,6 +1950,44 @@ struct CreateMemoryResponse: Codable { let message: String? } +private struct LocalMemoriesResponse: Decodable { + let memories: [LocalMemory] +} + +private struct LocalMemoryEnvelope: Decodable { + let memory: LocalMemory +} + +private struct LocalMemory: Decodable { + let id: String + let content: String + let category: String? + let conversationId: String? + let createdAt: Date + let updatedAt: Date + + enum CodingKeys: String, CodingKey { + case id, content, category + case conversationId = "conversation_id" + case createdAt = "created_at" + case updatedAt = "updated_at" + } + + func toServerMemory() -> ServerMemory { + ServerMemory( + id: id, + content: content, + category: category.flatMap(MemoryCategory.init(rawValue:)) ?? .system, + createdAt: createdAt, + updatedAt: updatedAt, + conversationId: conversationId, + visibility: "private", + manuallyAdded: true, + source: "local_daemon" + ) + } +} + /// One item in a POST /v3/memories/batch payload. Mirrors the `Memory` model /// in `backend/models/memories.py` (the server hardcodes category=manual on /// batch creation, so we intentionally don't send it). @@ -1512,7 +2050,15 @@ struct ActionItemsListResponse: Decodable { } else { self.items = try container.decode([TaskActionItem].self, forKey: .items) } - self.hasMore = try container.decode(Bool.self, forKey: .hasMore) + self.hasMore = try container.decodeIfPresent(Bool.self, forKey: .hasMore) ?? 
false + } +} + +private struct LocalActionItemEnvelope: Decodable { + let actionItem: TaskActionItem + + enum CodingKeys: String, CodingKey { + case actionItem = "action_item" } } @@ -1530,6 +2076,11 @@ extension APIClient { sortBy: String? = nil, deleted: Bool? = nil ) async throws -> ActionItemsListResponse { + let target = selectedBackendTarget + if target.mode == .localDaemon { + return try await get("v1/action-items", requireAuth: false, customBaseURL: target.baseURL) + } + let formatter = ISO8601DateFormatter() formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] @@ -1651,11 +2202,42 @@ extension APIClient { recurrenceRule: recurrenceRule ) + let target = selectedBackendTarget + if target.mode == .localDaemon { + struct LocalUpdateRequest: Encodable { + let title: String? + let description: String? + let dueAt: String? + + enum CodingKeys: String, CodingKey { + case title, description + case dueAt = "due_at" + } + } + let response: LocalActionItemEnvelope = try await patch( + "v1/action-items/\(id)", + body: LocalUpdateRequest( + title: description, + description: description, + dueAt: dueAt.map { formatter.string(from: $0) } + ), + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.actionItem + } + return try await patch("v1/action-items/\(id)", body: request) } /// Deletes an action item func deleteActionItem(id: String) async throws { + let target = selectedBackendTarget + if target.mode == .localDaemon { + try await delete("v1/action-items/\(id)", requireAuth: false, customBaseURL: target.baseURL) + return + } + try await delete("v1/action-items/\(id)") } @@ -1716,6 +2298,31 @@ extension APIClient { recurrenceParentId: recurrenceParentId ) + let target = selectedBackendTarget + if target.mode == .localDaemon { + struct LocalCreateRequest: Encodable { + let title: String + let description: String? + let dueAt: String? 
+ + enum CodingKeys: String, CodingKey { + case title, description + case dueAt = "due_at" + } + } + let response: LocalActionItemEnvelope = try await post( + "v1/action-items", + body: LocalCreateRequest( + title: description, + description: description, + dueAt: dueAt.map { formatter.string(from: $0) } + ), + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.actionItem + } + return try await post("v1/action-items", body: request) } diff --git a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift index ca22c27d2e4..fa56c4477d6 100644 --- a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift +++ b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift @@ -1,9 +1,61 @@ import Foundation enum DesktopBackendEnvironment { + enum BackendMode: Equatable { + case cloud + case localDaemon + case customRemote + } + + struct BackendTarget: Equatable { + let mode: BackendMode + let baseURL: String + let requiresAuth: Bool + } + static let productionPythonAPIURL = "https://api.omi.me/" static let developmentPythonAPIURL = "https://api.omiapi.com/" static let developmentRustBackendURL = "https://desktop-backend-dt5lrfkkoa-uc.a.run.app/" + static let defaultLocalDaemonURL = "http://127.0.0.1:8765/" + + static var selectedBackendTarget: BackendTarget { + selectedBackendTarget( + modeValue: currentEnvironmentValue("OMI_DESKTOP_BACKEND_MODE") + ?? currentEnvironmentValue("OMI_BACKEND_MODE"), + pythonEnvironmentValue: currentEnvironmentValue("OMI_PYTHON_API_URL"), + localDaemonEnvironmentValue: currentEnvironmentValue("OMI_LOCAL_DAEMON_URL") + ) + } + + static func selectedBackendTarget( + modeValue: String?, + pythonEnvironmentValue: String?, + localDaemonEnvironmentValue: String? 
+ ) -> BackendTarget { + switch normalizedMode(modeValue) { + case "local", "local-daemon", "local_daemon", "daemon": + return BackendTarget( + mode: .localDaemon, + baseURL: localDaemonBaseURL(environmentValue: localDaemonEnvironmentValue), + requiresAuth: false + ) + case "custom", "remote", "custom-remote", "custom_remote": + return BackendTarget( + mode: .customRemote, + baseURL: pythonBaseURL( + useDevelopmentBackends: false, + environmentValue: pythonEnvironmentValue + ), + requiresAuth: true + ) + default: + return BackendTarget( + mode: .cloud, + baseURL: pythonBaseURL(environmentValue: pythonEnvironmentValue), + requiresAuth: true + ) + } + } static var shouldUseDevelopmentBackends: Bool { shouldUseDevelopmentBackends( @@ -90,6 +142,12 @@ enum DesktopBackendEnvironment { return "" } + static func localDaemonBaseURL( + environmentValue: String? = currentEnvironmentValue("OMI_LOCAL_DAEMON_URL") + ) -> String { + normalizedURL(environmentValue) ?? defaultLocalDaemonURL + } + static func applyReleaseChannelDefaults() { guard shouldUseDevelopmentBackends else { return } @@ -110,6 +168,10 @@ enum DesktopBackendEnvironment { return trimmed.hasSuffix("/") ? trimmed : trimmed + "/" } + private static func normalizedMode(_ raw: String?) -> String { + raw?.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() ?? "cloud" + } + private static func currentEnvironmentValue(_ key: String) -> String? 
{ guard let value = getenv(key), let string = String(validatingUTF8: value) else { return nil diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift b/desktop/Desktop/Sources/TranscriptionRetryService.swift index 469d9473073..0bf03317755 100644 --- a/desktop/Desktop/Sources/TranscriptionRetryService.swift +++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift @@ -115,7 +115,7 @@ class TranscriptionRetryService { /// Process the retry queue (called periodically by timer) private func processRetryQueue() async { // Skip if user is signed out (tokens are cleared) - guard await AuthState.shared.isSignedIn else { return } + guard await APIClient.shared.isUsingLocalDaemon || await AuthState.shared.isSignedIn else { return } guard !isProcessing else { log("TranscriptionRetryService: Already processing, skipping") return @@ -214,6 +214,11 @@ class TranscriptionRetryService { log("TranscriptionRetryService: Reconciling session \(sessionId) (retryCount: \(session.retryCount))") do { + if await APIClient.shared.isUsingLocalDaemon { + try await uploadSessionToLocalDaemon(session, sessionId: sessionId) + return + } + // Check if backend already has a conversation for this time window let finishedAt = session.finishedAt ?? session.startedAt.addingTimeInterval(1) if let existing = try? 
await APIClient.shared.getConversations( @@ -253,4 +258,32 @@ class TranscriptionRetryService { } } + private func uploadSessionToLocalDaemon(_ session: TranscriptionSessionRecord, sessionId: Int64) async throws { + try await TranscriptionStorage.shared.markSessionUploading(id: sessionId) + let segments = try await TranscriptionStorage.shared.getSegments(sessionId: sessionId) + guard !segments.isEmpty else { + try await TranscriptionStorage.shared.markSessionFailed( + id: sessionId, error: "No transcript segments to upload to local daemon") + return + } + + let conversation = try await APIClient.shared.createLocalDaemonConversation( + sessionId: "desktop-\(sessionId)", + title: session.title, + overview: session.overview, + startedAt: session.startedAt + ) + + for segment in segments { + try await APIClient.shared.appendLocalDaemonTranscriptSegment( + conversationId: conversation.id, + segment: segment + ) + } + + try await APIClient.shared.finalizeLocalDaemonTranscript(conversationId: conversation.id) + try await TranscriptionStorage.shared.markSessionCompleted(id: sessionId, backendId: conversation.id) + log("TranscriptionRetryService: Session \(sessionId) stored in local daemon as \(conversation.id)") + } + } diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 788c9cd0b4e..80cb6a8d3f5 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -243,6 +243,39 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertEqual(url, "") } + func testSelectedBackendTargetDefaultsToCloudPython() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: nil, + pythonEnvironmentValue: "https://api.example.test", + localDaemonEnvironmentValue: nil + ) + XCTAssertEqual(target.mode, .cloud) + XCTAssertEqual(target.baseURL, "https://api.example.test/") + XCTAssertTrue(target.requiresAuth) + } + + func 
testSelectedBackendTargetSupportsLocalDaemonDefault() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: "local", + pythonEnvironmentValue: "https://api.example.test", + localDaemonEnvironmentValue: nil + ) + XCTAssertEqual(target.mode, .localDaemon) + XCTAssertEqual(target.baseURL, "http://127.0.0.1:8765/") + XCTAssertFalse(target.requiresAuth) + } + + func testSelectedBackendTargetSupportsCustomRemote() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: "custom", + pythonEnvironmentValue: "http://custom-backend:7777", + localDaemonEnvironmentValue: "http://127.0.0.1:8765" + ) + XCTAssertEqual(target.mode, .customRemote) + XCTAssertEqual(target.baseURL, "http://custom-backend:7777/") + XCTAssertTrue(target.requiresAuth) + } + func testBaseURLAndRustBackendURLAreIndependent() async { setenv("OMI_PYTHON_API_URL", "http://python:8080", 1) setenv("OMI_DESKTOP_API_URL", "http://rust:8787", 1) @@ -272,11 +305,15 @@ final class APIClientRoutingTests: XCTestCase { URLCapture.reset() setenv("OMI_PYTHON_API_URL", "http://python-test:9001", 1) setenv("OMI_DESKTOP_API_URL", "http://rust-test:9002", 1) + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") } override func tearDown() { unsetenv("OMI_PYTHON_API_URL") unsetenv("OMI_DESKTOP_API_URL") + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") URLCapture.reset() super.tearDown() } @@ -291,6 +328,18 @@ final class APIClientRoutingTests: XCTestCase { label: "getConversation") } + func testLocalModeGetConversationRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? 
await client.getConversation(id: "local-123") as ServerConversation + + assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations/local-123", method: "GET", + label: "local getConversation") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + func testDeleteConversationRoutesToPython() async { let client = await makeTestClient() try? await client.deleteConversation(id: "conv-456") @@ -299,6 +348,48 @@ final class APIClientRoutingTests: XCTestCase { label: "deleteConversation") } + func testLocalModeCreateConversationRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.createLocalDaemonConversation( + sessionId: "desktop-1", + title: "Local", + overview: "Local daemon", + startedAt: Date(timeIntervalSince1970: 0) + ) + + let requests = URLCapture.capturedRequests + assertRoutes(requests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations", method: "POST", + label: "local createConversation") + XCTAssertNil(requests.first?.headers["Authorization"]) + } + + func testLocalModeHealthCheckRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.checkSelectedBackendHealth() + + assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "health", method: "GET", + label: "local health") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testLocalModeSettingsRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? 
await client.updateSelectedBackendSettings(["profile_name": "Local"]) + + assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/settings", method: "PUT", + label: "local settings") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { From 9f5d91826a3c71dbb3c3648b847ef0fea9d61d5c Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:38:59 +0700 Subject: [PATCH 05/58] Fix desktop local daemon Swift validation --- desktop/Desktop/Sources/APIClient.swift | 64 ++++--------------- .../Sources/TranscriptionRetryService.swift | 4 +- 2 files changed, 16 insertions(+), 52 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index ddeb52005b7..ebc69fa0e95 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1564,56 +1564,6 @@ struct ServerMemory: Codable, Identifiable { headline = try container.decodeIfPresent(String.self, forKey: .headline) } - init( - id: String, - content: String, - category: MemoryCategory, - createdAt: Date, - updatedAt: Date, - conversationId: String? = nil, - reviewed: Bool = false, - userReview: Bool? = nil, - visibility: String = "private", - manuallyAdded: Bool = false, - scoring: String? = nil, - source: String? = nil, - confidence: Double? = nil, - sourceApp: String? = nil, - contextSummary: String? = nil, - isRead: Bool = false, - isDismissed: Bool = false, - tags: [String] = [], - reasoning: String? = nil, - currentActivity: String? = nil, - inputDeviceName: String? = nil, - windowTitle: String? = nil, - headline: String? 
= nil - ) { - self.id = id - self.content = content - self.category = category - self.createdAt = createdAt - self.updatedAt = updatedAt - self.conversationId = conversationId - self.reviewed = reviewed - self.userReview = userReview - self.visibility = visibility - self.manuallyAdded = manuallyAdded - self.scoring = scoring - self.source = source - self.confidence = confidence - self.sourceApp = sourceApp - self.contextSummary = contextSummary - self.isRead = isRead - self.isDismissed = isDismissed - self.tags = tags - self.reasoning = reasoning - self.currentActivity = currentActivity - self.inputDeviceName = inputDeviceName - self.windowTitle = windowTitle - self.headline = headline - } - var isPublic: Bool { visibility == "public" } @@ -1981,9 +1931,21 @@ private struct LocalMemory: Decodable { createdAt: createdAt, updatedAt: updatedAt, conversationId: conversationId, + reviewed: false, + userReview: nil, visibility: "private", manuallyAdded: true, - source: "local_daemon" + scoring: nil, + source: "local_daemon", + confidence: nil, + sourceApp: nil, + contextSummary: nil, + isRead: false, + isDismissed: false, + tags: [], + reasoning: nil, + currentActivity: nil, + inputDeviceName: nil ) } } diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift b/desktop/Desktop/Sources/TranscriptionRetryService.swift index 0bf03317755..c91ad832e40 100644 --- a/desktop/Desktop/Sources/TranscriptionRetryService.swift +++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift @@ -115,7 +115,9 @@ class TranscriptionRetryService { /// Process the retry queue (called periodically by timer) private func processRetryQueue() async { // Skip if user is signed out (tokens are cleared) - guard await APIClient.shared.isUsingLocalDaemon || await AuthState.shared.isSignedIn else { return } + let usingLocalDaemon = await APIClient.shared.isUsingLocalDaemon + let isSignedIn = await MainActor.run { AuthState.shared.isSignedIn } + guard usingLocalDaemon || isSignedIn else { 
return } guard !isProcessing else { log("TranscriptionRetryService: Already processing, skipping") return From 7c26d4f9e7cdd40f2699b924407717b67538e452 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:44:05 +0700 Subject: [PATCH 06/58] Add local conversation processing jobs --- desktop/local-backend/Cargo.lock | 825 +++++++++++++++++++++++- desktop/local-backend/Cargo.toml | 3 +- desktop/local-backend/src/main.rs | 5 +- desktop/local-backend/src/processing.rs | 359 +++++++++++ desktop/local-backend/src/providers.rs | 171 +++++ desktop/local-backend/src/routes.rs | 9 + desktop/local-backend/src/storage.rs | 107 +++ 7 files changed, 1465 insertions(+), 14 deletions(-) create mode 100644 desktop/local-backend/src/processing.rs create mode 100644 desktop/local-backend/src/providers.rs diff --git a/desktop/local-backend/Cargo.lock b/desktop/local-backend/Cargo.lock index 4f61b0e16b4..f4b4f307bf8 100644 --- a/desktop/local-backend/Cargo.lock +++ b/desktop/local-backend/Cargo.lock @@ -116,6 +116,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.11.1" @@ -159,6 +165,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.44" @@ -227,6 +239,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" 
+dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -332,8 +355,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi 5.3.0", + "wasip2", + "wasm-bindgen", ] [[package]] @@ -344,7 +383,7 @@ checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 6.0.0", "wasip2", "wasip3", ] @@ -451,6 +490,23 @@ dependencies = [ "pin-project-lite", "smallvec", "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", ] [[package]] @@ -459,13 +515,21 @@ version = "0.1.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" dependencies = [ + "base64", "bytes", + "futures-channel", + "futures-util", "http", "http-body", "hyper", + "ipnet", + "libc", + "percent-encoding", "pin-project-lite", + "socket2", "tokio", "tower-service", + "tracing", ] [[package]] @@ -492,12 +556,115 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + 
"potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + [[package]] name = "id-arena" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" +[[package]] +name = 
"idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "indexmap" version = "2.14.0" @@ -510,6 +677,12 @@ dependencies = [ "serde_core", ] +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + [[package]] name = "itoa" version = "1.0.18" @@ -572,12 +745,24 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + [[package]] name = "log" version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchers" version = "0.2.0" @@ -642,6 +827,7 @@ dependencies = [ "axum", "chrono", "directories", + "reqwest", "rusqlite", "serde", "serde_json", @@ -649,7 +835,7 @@ dependencies = [ "tempfile", "tokio", "tower", - "tower-http", + "tower-http 0.5.2", "tracing", "tracing-subscriber", ] @@ -684,6 +870,24 @@ version = "0.3.33" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -703,6 +907,61 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.45" @@ -712,12 +971,47 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "r-efi" version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "redox_users" version = "0.4.6" @@ -726,7 +1020,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom 0.2.17", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -746,6 +1040,58 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + 
"serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http 0.6.11", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rusqlite" version = "0.32.1" @@ -761,6 +1107,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + [[package]] name = "rustix" version = "1.1.4" @@ -774,6 +1126,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -916,6 +1303,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.117" @@ -932,6 +1331,20 @@ name = "sync_wrapper" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "tempfile" @@ -952,7 +1365,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", ] [[package]] @@ -966,6 +1388,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.9" @@ -975,12 +1408,38 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ + "bytes", "libc", "mio", "pin-project-lite", @@ -1001,6 +1460,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tower" version = "0.5.3" @@ -1034,6 +1503,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "url", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -1121,6 +1608,12 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "typenum" version = "1.20.0" @@ -1139,6 +1632,30 @@ version = "0.2.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "valuable" version = "0.1.1" @@ -1157,6 +1674,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -1194,6 +1720,16 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "wasm-bindgen-macro" version = "0.2.121" @@ -1260,6 +1796,35 @@ dependencies = [ "semver", ] +[[package]] +name = "web-sys" +version = "0.3.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +dependencies = [ + 
"js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "windows-core" version = "0.62.2" @@ -1325,7 +1890,25 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", ] [[package]] @@ -1343,13 +1926,46 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" 
+version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -1358,42 +1974,138 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + 
+[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + [[package]] name = "wit-bindgen" version = "0.51.0" @@ -1488,6 +2200,35 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.8.48" @@ -1508,6 +2249,66 @@ dependencies = [ "syn", ] +[[package]] +name = "zerofrom" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" 
+version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zmij" version = "1.0.21" diff --git a/desktop/local-backend/Cargo.toml b/desktop/local-backend/Cargo.toml index 6ab5d434f1e..e1a0da82639 100644 --- a/desktop/local-backend/Cargo.toml +++ b/desktop/local-backend/Cargo.toml @@ -10,11 +10,12 @@ anyhow = "1" axum = "0.7" chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] } directories = "5" +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } rusqlite = { version = "0.32", features = ["bundled", "chrono"] } serde = { version = "1", features = ["derive"] } serde_json = "1" sha2 = "0.10" -tokio = { version = "1", features = ["macros", "net", "rt-multi-thread", "signal"] } +tokio = { version = "1", features = ["macros", "net", "rt-multi-thread", "signal", "time"] } tower-http = { version = "0.5", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json"] } diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index d131cbba8b6..f9b447ce5d2 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -9,6 +9,8 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod config; mod health; +mod processing; +mod providers; mod routes; mod storage; @@ -33,8 +35,9 @@ async fn main() -> Result<()> { let bind_addr = config.bind_addr; let state = AppState { config: Arc::new(config), - store, + store: store.clone(), }; + processing::spawn_worker(store); let app = app(state); diff --git a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs new file mode 100644 index 00000000000..76489158c05 --- /dev/null +++ b/desktop/local-backend/src/processing.rs @@ 
-0,0 +1,359 @@ +use std::time::Duration; + +use anyhow::{anyhow, Context, Result}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use tokio::time; + +use crate::{ + providers::{configured_openai_provider, ChatMessage}, + storage::{ + deterministic_id, NewActionItem, NewMemory, ProcessingJob, ProcessingJobStatus, Store, + UpdateConversation, + }, +}; + +const TITLE_WORD_LIMIT: usize = 8; +const TITLE_CHAR_LIMIT: usize = 80; +const OVERVIEW_CHAR_LIMIT: usize = 240; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProcessingOutput { + pub title: String, + pub overview: String, + pub action_items: Vec, + pub memories: Vec, + pub provider: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ExtractedActionItem { + pub title: String, + pub description: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ExtractedMemory { + pub content: String, + pub category: Option, +} + +pub fn spawn_worker(store: Store) { + tokio::spawn(async move { + let mut interval = time::interval(Duration::from_secs(1)); + loop { + interval.tick().await; + if let Err(error) = process_next_job(&store).await { + tracing::warn!(error = %error, "local processing worker iteration failed"); + } + } + }); +} + +pub async fn process_next_job(store: &Store) -> Result> { + let Some(job) = store.processing_jobs().claim_next_queued()? 
else { + return Ok(None); + }; + + match process_claimed_job(store, &job).await { + Ok(result) => store + .processing_jobs() + .complete(&job.id, result) + .with_context(|| format!("failed to complete job {}", job.id)), + Err(error) => { + let message = error.to_string(); + store + .processing_jobs() + .fail(&job.id, &message) + .with_context(|| format!("failed to fail job {}", job.id)) + } + } +} + +async fn process_claimed_job(store: &Store, job: &ProcessingJob) -> Result { + if job.status != ProcessingJobStatus::Running { + return Err(anyhow!("processing job must be running before execution")); + } + + match job.kind.as_str() { + "finalize_transcript" | "process_conversation" => { + process_conversation_job(store, job).await + } + other => Err(anyhow!("unsupported processing job kind: {other}")), + } +} + +async fn process_conversation_job(store: &Store, job: &ProcessingJob) -> Result { + let conversation_id = job + .target_conversation_id + .as_ref() + .ok_or_else(|| anyhow!("processing job missing target conversation id"))?; + let segments = store.transcripts().list_for_conversation(conversation_id)?; + let transcript = segments + .iter() + .map(|segment| segment.text.as_str()) + .collect::>() + .join(" "); + let output = if let Some(provider) = configured_openai_provider(store)? 
{ + match provider + .complete_json(processing_prompt(&transcript)) + .await + .and_then(parse_provider_output) + { + Ok(mut output) => { + output.provider = "openai_compatible".to_string(); + output + } + Err(error) => { + tracing::warn!(error = %error, "provider processing failed; using deterministic fallback"); + fallback_output(&transcript) + } + } + } else { + fallback_output(&transcript) + }; + + persist_processing_output(store, conversation_id, &output)?; + + Ok(json!({ + "conversation_id": conversation_id, + "title": output.title, + "overview": output.overview, + "action_items": output.action_items, + "memories": output.memories, + "provider": output.provider + })) +} + +fn processing_prompt(transcript: &str) -> Vec { + vec![ + ChatMessage::system( + "Return compact JSON with title, overview, action_items, and memories. \ + action_items must be an array of {title, description}. \ + memories must be an array of {content, category}.", + ), + ChatMessage::user(format!("Transcript:\n{transcript}")), + ] +} + +fn parse_provider_output(value: Value) -> Result { + let title = value["title"] + .as_str() + .unwrap_or_default() + .trim() + .to_string(); + let overview = value["overview"] + .as_str() + .unwrap_or_default() + .trim() + .to_string(); + let action_items = value["action_items"] + .as_array() + .into_iter() + .flatten() + .filter_map(|item| { + let title = item["title"].as_str()?.trim().to_string(); + if title.is_empty() { + return None; + } + Some(ExtractedActionItem { + title, + description: item["description"] + .as_str() + .unwrap_or_default() + .trim() + .to_string(), + }) + }) + .collect(); + let memories = value["memories"] + .as_array() + .into_iter() + .flatten() + .filter_map(|item| { + let content = item["content"].as_str()?.trim().to_string(); + if content.is_empty() { + return None; + } + Some(ExtractedMemory { + content, + category: item["category"] + .as_str() + .map(|category| category.to_string()), + }) + }) + .collect(); + + 
Ok(ProcessingOutput { + title, + overview, + action_items, + memories, + provider: "openai_compatible".to_string(), + }) +} + +pub fn fallback_output(transcript: &str) -> ProcessingOutput { + let normalized = normalize_whitespace(transcript); + ProcessingOutput { + title: fallback_title(&normalized), + overview: clip_chars(&normalized, OVERVIEW_CHAR_LIMIT), + action_items: Vec::new(), + memories: Vec::new(), + provider: "fallback".to_string(), + } +} + +fn fallback_title(normalized_transcript: &str) -> String { + if normalized_transcript.is_empty() { + return "Untitled conversation".to_string(); + } + let words = normalized_transcript + .split_whitespace() + .take(TITLE_WORD_LIMIT) + .collect::>() + .join(" "); + clip_chars(&words, TITLE_CHAR_LIMIT) +} + +fn normalize_whitespace(value: &str) -> String { + value.split_whitespace().collect::>().join(" ") +} + +fn clip_chars(value: &str, limit: usize) -> String { + value.chars().take(limit).collect() +} + +fn persist_processing_output( + store: &Store, + conversation_id: &str, + output: &ProcessingOutput, +) -> Result<()> { + store + .conversations() + .update( + conversation_id, + UpdateConversation { + title: Some(output.title.clone()), + overview: Some(output.overview.clone()), + status: Some("processed".to_string()), + ended_at: None, + metadata: None, + }, + )? 
+ .ok_or_else(|| anyhow!("conversation missing while persisting processing output"))?; + + for (index, item) in output.action_items.iter().enumerate() { + store.action_items().create(NewActionItem { + id: deterministic_id("act", &[conversation_id, &index.to_string(), &item.title]), + conversation_id: Some(conversation_id.to_string()), + title: item.title.clone(), + description: Some(item.description.clone()), + status: Some("open".to_string()), + due_at: None, + metadata: Some(json!({"source": "local_processing"})), + })?; + } + + for (index, memory) in output.memories.iter().enumerate() { + store.memories().create(NewMemory { + id: deterministic_id( + "mem", + &[conversation_id, &index.to_string(), &memory.content], + ), + content: memory.content.clone(), + category: memory.category.clone(), + conversation_id: Some(conversation_id.to_string()), + metadata: Some(json!({"source": "local_processing"})), + })?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::storage::{NewConversation, NewProcessingJob, NewTranscriptSegment}; + + use super::*; + + #[test] + fn fallback_processing_is_deterministic_and_empty_for_items_and_memories() { + let output = fallback_output( + " Discuss launch planning.\nNext we assign owners and review the demo checklist. ", + ); + + assert_eq!( + output.title, + "Discuss launch planning. Next we assign owners and" + ); + assert_eq!( + output.overview, + "Discuss launch planning. Next we assign owners and review the demo checklist." 
+ ); + assert_eq!(output.action_items, Vec::new()); + assert_eq!(output.memories, Vec::new()); + assert_eq!(output.provider, "fallback"); + } + + #[tokio::test] + async fn processing_job_lifecycle_persists_outputs() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-processing"]); + + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-processing".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id: conversation_id.clone(), + session_id: "session-processing".to_string(), + speaker_id: None, + speaker_label: None, + text: "Plan the desktop local backend MVP and verify deterministic processing." + .to_string(), + start_ms: 0, + end_ms: 2000, + segment_index: 0, + source: None, + metadata: None, + })?; + store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["process", &conversation_id]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(3), + payload: Some(json!({"conversation_id": conversation_id})), + })?; + + let job = process_next_job(&store) + .await? + .expect("queued job should be processed"); + assert_eq!(job.status, ProcessingJobStatus::Completed); + + let conversation = store + .conversations() + .get(job.target_conversation_id.as_ref().unwrap())? 
+ .expect("conversation should exist"); + assert_eq!( + conversation.title, + "Plan the desktop local backend MVP and verify" + ); + assert_eq!(conversation.status, "processed"); + assert!(conversation + .overview + .starts_with("Plan the desktop local backend MVP")); + assert!(store.action_items().list()?.is_empty()); + assert!(store.memories().list()?.is_empty()); + + let result: Value = serde_json::from_str(&job.result_json)?; + assert_eq!(result["provider"], "fallback"); + + Ok(()) + } +} diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs new file mode 100644 index 00000000000..c83ac8402ee --- /dev/null +++ b/desktop/local-backend/src/providers.rs @@ -0,0 +1,171 @@ +use anyhow::{anyhow, Context, Result}; +use reqwest::{Client, Method}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::storage::Store; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +impl ChatMessage { + pub fn system(content: impl Into) -> Self { + Self { + role: "system".to_string(), + content: content.into(), + } + } + + pub fn user(content: impl Into) -> Self { + Self { + role: "user".to_string(), + content: content.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProviderHttpRequest { + pub method: Method, + pub url: String, + pub authorization: String, + pub body: Value, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct OpenAiCompatibleConfig { + pub base_url: String, + pub model: String, + pub api_key: String, +} + +#[derive(Clone)] +pub struct OpenAiCompatibleProvider { + config: OpenAiCompatibleConfig, + client: Client, +} + +impl OpenAiCompatibleProvider { + pub fn new(config: OpenAiCompatibleConfig) -> Self { + Self { + config, + client: Client::new(), + } + } + + pub fn build_chat_completions_request( + &self, + messages: Vec, + ) -> ProviderHttpRequest { + 
ProviderHttpRequest { + method: Method::POST, + url: format!( + "{}/chat/completions", + self.config.base_url.trim_end_matches('/') + ), + authorization: format!("Bearer {}", self.config.api_key), + body: json!({ + "model": self.config.model, + "messages": messages, + "temperature": 0, + "response_format": {"type": "json_object"} + }), + } + } + + pub async fn complete_json(&self, messages: Vec) -> Result { + let request = self.build_chat_completions_request(messages); + let response: Value = self + .client + .request(request.method, request.url) + .header("authorization", request.authorization) + .json(&request.body) + .send() + .await + .context("failed to send OpenAI-compatible chat completion request")? + .error_for_status() + .context("OpenAI-compatible chat completion request failed")? + .json() + .await + .context("failed to decode OpenAI-compatible chat completion response")?; + + let content = response["choices"][0]["message"]["content"] + .as_str() + .ok_or_else(|| anyhow!("OpenAI-compatible response did not include message content"))?; + serde_json::from_str(content).context("provider message content was not valid JSON") + } +} + +pub fn configured_openai_provider(store: &Store) -> Result> { + let Some(config) = load_openai_config(store)? else { + return Ok(None); + }; + Ok(Some(OpenAiCompatibleProvider::new(config))) +} + +pub fn load_openai_config(store: &Store) -> Result> { + for key in ["ai_provider", "provider"] { + let Some(setting) = store.settings().get(key)? 
else { + continue; + }; + let value: Value = serde_json::from_str(&setting.value_json) + .with_context(|| format!("failed to parse {key} provider setting"))?; + let kind = value["kind"].as_str().unwrap_or_default(); + if kind != "openai" && kind != "openai_compatible" { + continue; + } + + let base_url = value["base_url"] + .as_str() + .unwrap_or("https://api.openai.com/v1") + .to_string(); + let model = value["model"].as_str().unwrap_or("gpt-4o-mini").to_string(); + let api_key = value["api_key"] + .as_str() + .or_else(|| value["key"].as_str()) + .unwrap_or_default() + .to_string(); + + if api_key.trim().is_empty() { + continue; + } + + return Ok(Some(OpenAiCompatibleConfig { + base_url, + model, + api_key, + })); + } + + Ok(None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn openai_compatible_request_uses_configured_endpoint_model_and_key() { + let provider = OpenAiCompatibleProvider::new(OpenAiCompatibleConfig { + base_url: "http://localhost:11434/v1/".to_string(), + model: "local-model".to_string(), + api_key: "test-key".to_string(), + }); + + let request = provider.build_chat_completions_request(vec![ + ChatMessage::system("Return JSON."), + ChatMessage::user("Summarize."), + ]); + + assert_eq!(request.method, Method::POST); + assert_eq!(request.url, "http://localhost:11434/v1/chat/completions"); + assert_eq!(request.authorization, "Bearer test-key"); + assert_eq!(request.body["model"], "local-model"); + assert_eq!(request.body["temperature"], 0); + assert_eq!(request.body["messages"][0]["role"], "system"); + assert_eq!(request.body["messages"][1]["content"], "Summarize."); + } +} diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 4a874c1fffc..0d0be46bb4a 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use crate::{ + processing, storage::{ deterministic_id, 
NewActionItem, NewConversation, NewMemory, NewProcessingJob, NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, UpdateProfile, @@ -58,6 +59,7 @@ pub fn router() -> Router { .delete(delete_action_item), ) .route("/v1/processing-jobs", get(list_processing_jobs)) + .route("/v1/processing-jobs/process-next", post(process_next_job)) .route("/v1/processing-jobs/status", get(processing_status)) .route("/v1/processing-jobs/:id", get(get_processing_job)) } @@ -629,6 +631,13 @@ async fn get_processing_job( Ok(Json(json!({ "processing_job": job }))) } +async fn process_next_job(State(state): State) -> ApiResult { + let job = processing::process_next_job(&state.store) + .await + .map_err(ApiError::internal)?; + Ok(Json(json!({ "processing_job": job }))) +} + async fn processing_status(State(state): State) -> ApiResult { let jobs = state .store diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 091285da326..859efe30c62 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -787,6 +787,98 @@ impl ProcessingJobRepository { .context("failed to list processing jobs")?; collect_rows(rows) } + + pub fn claim_next_queued(&self) -> Result> { + let Some(job) = self.next_queued()? 
else { + return Ok(None); + }; + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn + .execute( + r#" + UPDATE processing_jobs + SET status = 'running', started_at = ?2, updated_at = ?2, last_error = NULL + WHERE id = ?1 AND status = 'queued' AND deleted_at IS NULL + "#, + params![job.id, now], + ) + .context("failed to claim processing job")?; + drop(conn); + + if changed == 0 { + Ok(None) + } else { + self.get(&job.id) + } + } + + pub fn complete(&self, id: &str, result: serde_json::Value) -> Result> { + let now = Utc::now(); + let result_json = + serde_json::to_string(&result).context("failed to serialize job result")?; + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn + .execute( + r#" + UPDATE processing_jobs + SET status = 'completed', result_json = ?2, completed_at = ?3, updated_at = ?3, + last_error = NULL, sync_version = sync_version + 1 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id, result_json, now], + ) + .context("failed to complete processing job")?; + drop(conn); + + if changed == 0 { + Ok(None) + } else { + self.get(id) + } + } + + pub fn fail(&self, id: &str, error: &str) -> Result> { + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn + .execute( + r#" + UPDATE processing_jobs + SET status = 'failed', last_error = ?2, failed_at = ?3, updated_at = ?3, + retry_count = retry_count + 1, sync_version = sync_version + 1 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id, error, now], + ) + .context("failed to mark processing job failed")?; + drop(conn); + + if changed == 0 { + Ok(None) + } else { + self.get(id) + } + } + + fn next_queued(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, kind, status, target_conversation_id, retry_count, max_retries, last_error, + payload_json, 
result_json, queued_at, started_at, completed_at, failed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM processing_jobs + WHERE status = 'queued' AND deleted_at IS NULL + ORDER BY queued_at ASC + LIMIT 1 + "#, + [], + map_processing_job, + ) + .optional() + .context("failed to fetch next queued processing job") + } } pub struct SearchRepository { @@ -1183,6 +1275,21 @@ impl SettingsRepository { collect_rows(rows) } + pub fn get(&self, key: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT key, value_json, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM local_settings + WHERE key = ?1 AND deleted_at IS NULL + "#, + params![key], + map_local_setting, + ) + .optional() + .context("failed to fetch local setting") + } + pub fn upsert_many( &self, values: serde_json::Map, From 25c4baed1d70cc6e1c9d9046f5a0ad8c6e3815f9 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:49:32 +0700 Subject: [PATCH 07/58] Add local-first capability boundaries --- desktop/Desktop/Sources/APIClient.swift | 51 +++++++++ desktop/Desktop/Sources/AgentVMService.swift | 16 +++ .../Sources/DesktopBackendEnvironment.swift | 74 +++++++++++++ .../Sources/MainWindow/CrispManager.swift | 9 ++ .../Sources/TranscriptionService.swift | 36 +++++++ .../Desktop/Tests/APIClientRoutingTests.swift | 101 ++++++++++++++++++ desktop/local-backend/README.md | 17 +++ 7 files changed, 304 insertions(+) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index ebc69fa0e95..1d4f56a7dda 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -305,6 +305,7 @@ enum APIError: LocalizedError { case unauthorized case httpError(statusCode: Int) case decodingError(Error) + case featureUnavailable(feature: String, reason: String) var errorDescription: String? 
{ switch self { @@ -316,6 +317,8 @@ enum APIError: LocalizedError { return "HTTP error: \(statusCode)" case .decodingError(let error): return "Failed to decode response: \(error.localizedDescription)" + case .featureUnavailable(let feature, let reason): + return "\(feature) is unavailable: \(reason)" } } } @@ -358,6 +361,17 @@ extension APIClient { selectedBackendTarget } + private func requireCapability(_ capability: DesktopBackendEnvironment.Capability) throws { + let target = selectedBackendTarget + guard DesktopBackendEnvironment.isCapability(capability, availableIn: target.mode) else { + throw APIError.featureUnavailable( + feature: capability.rawValue, + reason: DesktopBackendEnvironment.unavailableReason(for: capability, in: target.mode) + ?? "Unavailable in the selected backend mode." + ) + } + } + private static func isoString(_ date: Date) -> String { let formatter = ISO8601DateFormatter() formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] @@ -471,6 +485,8 @@ extension APIClient { /// - id: The conversation ID /// - visibility: The visibility level ("shared", "public", or "private") func setConversationVisibility(id: String, visibility: String = "shared") async throws { + try requireCapability(.publicSharing) + let url = URL( string: baseURL + "v1/conversations/\(id)/visibility?value=\(visibility)&visibility=\(visibility)")! 
@@ -490,6 +506,8 @@ extension APIClient { /// - Parameter id: The conversation ID /// - Returns: The shareable URL for the conversation func getConversationShareLink(id: String) async throws -> String { + try requireCapability(.publicSharing) + // Set visibility to shared try await setConversationVisibility(id: id, visibility: "shared") // Return the web URL for the shared conversation @@ -2331,6 +2349,8 @@ extension APIClient { /// Shares tasks and returns a shareable URL func shareTasks(taskIds: [String]) async throws -> ShareTasksResponse { + try requireCapability(.publicSharing) + struct ShareRequest: Encodable { let taskIds: [String] enum CodingKeys: String, CodingKey { @@ -3935,11 +3955,15 @@ extension APIClient { /// Fetches private cloud sync setting func getPrivateCloudSync() async throws -> PrivateCloudSyncResponse { + try requireCapability(.cloudSync) + return try await get("v1/users/private-cloud-sync") } /// Sets private cloud sync func setPrivateCloudSync(enabled: Bool) async throws { + try requireCapability(.cloudSync) + let url = URL(string: baseURL + "v1/users/private-cloud-sync?value=\(enabled)")! var request = URLRequest(url: url) request.httpMethod = "POST" @@ -4680,6 +4704,8 @@ extension APIClient { /// Share chat messages and get a shareable URL func shareChatMessages(messageIds: [String]) async throws -> ShareChatResponse { + try requireCapability(.publicSharing) + struct ShareRequest: Encodable { let message_ids: [String] } @@ -4994,6 +5020,8 @@ extension APIClient { /// Provision a cloud agent VM for the current user (fire-and-forget) func provisionAgentVM() async throws -> AgentProvisionResponse { + try requireCapability(.managedAgentVM) + return try await post("v2/agent/provision", customBaseURL: rustBackendURL) } @@ -5009,6 +5037,8 @@ extension APIClient { /// Get current agent VM status func getAgentStatus() async throws -> AgentStatusResponse? 
{ + try requireCapability(.managedAgentVM) + return try await get("v2/agent/status", customBaseURL: rustBackendURL) } } @@ -5033,18 +5063,26 @@ struct Person: Codable, Identifiable { extension APIClient { func getUserSubscription() async throws -> UserSubscriptionResponse { + try requireCapability(.payments) + return try await get("v1/users/me/subscription") } func getAvailablePlans() async throws -> AvailablePlansResponse { + try requireCapability(.payments) + return try await get("v1/payments/available-plans") } func getOverageInfo() async throws -> OverageInfoResponse { + try requireCapability(.payments) + return try await get("v1/payments/overage-info") } func createCheckoutSession(priceId: String) async throws -> CheckoutSessionResponse { + try requireCapability(.payments) + struct Request: Encodable { let priceId: String @@ -5057,6 +5095,8 @@ extension APIClient { } func upgradeSubscription(priceId: String) async throws -> UpgradeSubscriptionResponse { + try requireCapability(.payments) + struct Request: Encodable { let priceId: String @@ -5069,6 +5109,8 @@ extension APIClient { } func createCustomerPortalSession() async throws -> CustomerPortalResponse { + try requireCapability(.payments) + return try await post("v1/payments/customer-portal") } @@ -5221,6 +5263,11 @@ extension APIClient { } func fetchChatUsageQuota() async -> ChatUsageQuota? 
{ + guard DesktopBackendEnvironment.isCapability(.payments, availableIn: selectedBackendTarget.mode) else { + log("APIClient: Chat quota disabled in local daemon mode") + return nil + } + do { let res: ChatUsageQuota = try await get("v1/users/me/usage-quota") log( @@ -5250,6 +5297,8 @@ extension APIClient { } func fetchApiKeys() async throws -> ApiKeysResponse { + try requireCapability(.omiBackendProviderProxy) + return try await get("v1/config/api-keys", customBaseURL: rustBackendURL) } @@ -5266,6 +5315,8 @@ extension APIClient { } func synthesizeSpeech(request body: TtsSynthesizeRequest) async throws -> Data { + try requireCapability(.omiBackendProviderProxy) + let base = rustBackendURL guard !base.isEmpty, let url = URL(string: base + "v1/tts/synthesize") else { throw APIError.invalidResponse diff --git a/desktop/Desktop/Sources/AgentVMService.swift b/desktop/Desktop/Sources/AgentVMService.swift index ea8ff3c82ec..20188d7ec1a 100644 --- a/desktop/Desktop/Sources/AgentVMService.swift +++ b/desktop/Desktop/Sources/AgentVMService.swift @@ -10,6 +10,14 @@ actor AgentVMService { /// Check backend for existing VM — if none exists, run the full pipeline. /// Call this on every app launch for signed-in users. func ensureProvisioned() { + guard DesktopBackendEnvironment.isCapability( + .managedAgentVM, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + log("AgentVMService: disabled in local daemon mode") + return + } + guard !isRunning else { log("AgentVMService: Pipeline already running, skipping") return @@ -58,6 +66,14 @@ actor AgentVMService { /// Kick off the full VM setup pipeline: provision → poll status → upload DB. /// Safe to call multiple times — only one pipeline runs at a time. 
func startPipeline() { + guard DesktopBackendEnvironment.isCapability( + .managedAgentVM, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + log("AgentVMService: disabled in local daemon mode") + return + } + guard !isRunning else { log("AgentVMService: Pipeline already running, skipping") return diff --git a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift index fa56c4477d6..0c55fa858ef 100644 --- a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift +++ b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift @@ -7,6 +7,24 @@ enum DesktopBackendEnvironment { case customRemote } + enum Capability: String, CaseIterable, Equatable { + case localConversationData + case firebaseSignIn + case managedAgentVM + case omiBackendProviderProxy + case publicSharing + case cloudSync + case payments + case crispSupport + case hostedTranscription + } + + struct CapabilityState: Equatable { + let capability: Capability + let available: Bool + let reason: String? + } + struct BackendTarget: Equatable { let mode: BackendMode let baseURL: String @@ -148,6 +166,62 @@ enum DesktopBackendEnvironment { normalizedURL(environmentValue) ?? 
defaultLocalDaemonURL } + static func capabilities(for mode: BackendMode) -> [CapabilityState] { + Capability.allCases.map { capability in + CapabilityState( + capability: capability, + available: isCapability(capability, availableIn: mode), + reason: unavailableReason(for: capability, in: mode) + ) + } + } + + static func isCapability(_ capability: Capability, availableIn mode: BackendMode) -> Bool { + guard mode == .localDaemon else { + return true + } + + switch capability { + case .localConversationData: + return true + case .firebaseSignIn: + return true + case .managedAgentVM, + .omiBackendProviderProxy, + .publicSharing, + .cloudSync, + .payments, + .crispSupport, + .hostedTranscription: + return false + } + } + + static func unavailableReason(for capability: Capability, in mode: BackendMode) -> String? { + guard !isCapability(capability, availableIn: mode) else { + return nil + } + + switch capability { + case .managedAgentVM: + return "Managed agent VMs are cloud-only and are disabled in local daemon mode." + case .omiBackendProviderProxy: + return "Omi backend provider proxies are not used in local daemon mode. Configure direct local provider settings instead." + case .publicSharing: + return "Public sharing requires Omi cloud-hosted URLs and is unavailable in local daemon mode." + case .cloudSync: + return "Cloud sync is intentionally disabled while local data is the source of truth." + case .payments: + return "Subscription and payment-gated features require Omi cloud services." + case .crispSupport: + return "Crisp support messaging is cloud-bound and is disabled in local daemon mode." + case .hostedTranscription: + return "Hosted transcription endpoints are disabled in local daemon mode; local transcripts are stored through the local daemon." 
+ case .localConversationData, .firebaseSignIn: + return nil + } + } + static func applyReleaseChannelDefaults() { guard shouldUseDevelopmentBackends else { return } diff --git a/desktop/Desktop/Sources/MainWindow/CrispManager.swift b/desktop/Desktop/Sources/MainWindow/CrispManager.swift index a0e9928f8a5..091af179021 100644 --- a/desktop/Desktop/Sources/MainWindow/CrispManager.swift +++ b/desktop/Desktop/Sources/MainWindow/CrispManager.swift @@ -67,6 +67,15 @@ class CrispManager: ObservableObject { /// from lifecycle unit tests that want to exercise observer registration /// without touching the network, auth state, or firing real notifications. func start(performInitialPoll: Bool = true) { + guard DesktopBackendEnvironment.isCapability( + .crispSupport, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + stop() + log("CrispManager: disabled in local daemon mode") + return + } + guard !isStarted else { return } isStarted = true diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 33e30c91220..4510b8ce6a6 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -163,6 +163,18 @@ class TranscriptionService { /// - language: Language code for transcription (e.g., "en", "uk", "ru", "multi" for auto-detect) /// - mode: Streaming mode — `.conversation` for `/v4/listen` (default), `.ptt` for `/v2/voice-message/transcribe-stream` init(language: String = "en", mode: StreamingMode = .conversation, contextKeywords: [String] = []) throws { + guard DesktopBackendEnvironment.isCapability( + .hostedTranscription, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + throw TranscriptionError.webSocketError( + DesktopBackendEnvironment.unavailableReason( + for: .hostedTranscription, + in: DesktopBackendEnvironment.selectedBackendTarget.mode + ) ?? 
"Hosted transcription is unavailable in local daemon mode" + ) + } + self.apiKey = "" // Not needed — Python backend uses Firebase auth self.language = language self.streamingMode = mode @@ -179,6 +191,18 @@ class TranscriptionService { guard forBatchOnly else { throw TranscriptionError.webSocketError("Use init(language:) for streaming mode") } + guard DesktopBackendEnvironment.isCapability( + .hostedTranscription, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + throw TranscriptionError.webSocketError( + DesktopBackendEnvironment.unavailableReason( + for: .hostedTranscription, + in: DesktopBackendEnvironment.selectedBackendTarget.mode + ) ?? "Hosted transcription is unavailable in local daemon mode" + ) + } + // Batch mode uses Firebase auth + Python backend — no DG key needed self.apiKey = "" self.language = language @@ -581,6 +605,18 @@ extension TranscriptionService { apiKey: String? = nil, contextKeywords: [String] = [] ) async throws -> String? { + guard DesktopBackendEnvironment.isCapability( + .hostedTranscription, + availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode + ) else { + throw TranscriptionError.webSocketError( + DesktopBackendEnvironment.unavailableReason( + for: .hostedTranscription, + in: DesktopBackendEnvironment.selectedBackendTarget.mode + ) ?? 
"Hosted transcription is unavailable in local daemon mode" + ) + } + // Always use Firebase auth + Python backend let authService = await MainActor.run { AuthService.shared } let authHeader = try await authService.getAuthHeader() diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 80cb6a8d3f5..64f50e0d267 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -107,6 +107,19 @@ private func assertRoutes( XCTAssertEqual(req.method, method, "\(label): wrong HTTP method", file: file, line: line) } +private func assertUnavailable( + _ error: Error?, + capability: DesktopBackendEnvironment.Capability, + file: StaticString = #filePath, + line: UInt = #line +) { + guard let error, case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for \(capability.rawValue), got \(String(describing: error))", file: file, line: line) + return + } + XCTAssertEqual(feature, capability.rawValue, file: file, line: line) +} + // MARK: - Tests final class APIClientRoutingTests: XCTestCase { @@ -276,6 +289,31 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertTrue(target.requiresAuth) } + func testLocalDaemonCapabilityMatrixDisablesCloudBoundFeatures() { + let capabilities = Dictionary( + uniqueKeysWithValues: DesktopBackendEnvironment.capabilities(for: .localDaemon) + .map { ($0.capability, $0) } + ) + + XCTAssertEqual(capabilities[.localConversationData]?.available, true) + XCTAssertEqual(capabilities[.firebaseSignIn]?.available, true) + XCTAssertEqual(capabilities[.managedAgentVM]?.available, false) + XCTAssertEqual(capabilities[.omiBackendProviderProxy]?.available, false) + XCTAssertEqual(capabilities[.publicSharing]?.available, false) + XCTAssertEqual(capabilities[.cloudSync]?.available, false) + XCTAssertEqual(capabilities[.payments]?.available, false) + 
XCTAssertEqual(capabilities[.crispSupport]?.available, false) + XCTAssertEqual(capabilities[.hostedTranscription]?.available, false) + XCTAssertNotNil(capabilities[.managedAgentVM]?.reason) + } + + func testCloudCapabilityMatrixAllowsCloudBoundFeatures() { + for state in DesktopBackendEnvironment.capabilities(for: .cloud) { + XCTAssertTrue(state.available, "\(state.capability.rawValue) should be available in cloud mode") + XCTAssertNil(state.reason) + } + } + func testBaseURLAndRustBackendURLAreIndependent() async { setenv("OMI_PYTHON_API_URL", "http://python:8080", 1) setenv("OMI_DESKTOP_API_URL", "http://rust:8787", 1) @@ -390,6 +428,69 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) } + func testLocalModeMVPConversationFlowsIgnoreInvalidCloudURLs() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "http://omi-cloud-invalid:9001", 1) + setenv("OMI_DESKTOP_API_URL", "http://omi-rust-invalid:9002", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:9876", 1) + let client = await makeTestClient() + + _ = try? await client.getConversations() + _ = try? await client.getConversation(id: "local-123") as ServerConversation + _ = try? await client.searchConversations(query: "offline") + try? await client.updateConversationTitle(id: "local-123", title: "Offline") + _ = try? 
await client.updateSelectedBackendSettings(["profile_name": "Offline"]) + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 5) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertFalse(requests.contains { $0.url.host == "omi-cloud-invalid" || $0.url.host == "omi-rust-invalid" }) + } + + func testLocalModeCloudOnlyFeaturesFailBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + do { + _ = try await client.provisionAgentVM() + XCTFail("expected managed agent VM to be unavailable") + } catch { + assertUnavailable(error, capability: .managedAgentVM) + } + + do { + _ = try await client.fetchApiKeys() + XCTFail("expected backend provider proxy to be unavailable") + } catch { + assertUnavailable(error, capability: .omiBackendProviderProxy) + } + + do { + _ = try await client.getUserSubscription() + XCTFail("expected payments to be unavailable") + } catch { + assertUnavailable(error, capability: .payments) + } + + do { + _ = try await client.shareChatMessages(messageIds: ["m1"]) + XCTFail("expected public sharing to be unavailable") + } catch { + assertUnavailable(error, capability: .publicSharing) + } + + do { + try await client.setPrivateCloudSync(enabled: true) + XCTFail("expected cloud sync to be unavailable") + } catch { + assertUnavailable(error, capability: .cloudSync) + } + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + // -- Conversations: manual URL(string: baseURL + ...) 
paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index 360b298c45b..f347e727394 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -69,3 +69,20 @@ status reports `authenticated: false` because local-first mode does not imply an Omi cloud account. Cloud IDs and sync fields are retained in storage models so a future sync adapter can map local records to cloud records without making cloud state the source of truth. + +## Local-First Capability Boundaries + +Desktop local daemon mode uses `DesktopBackendEnvironment.Capability` as the +capability matrix for deciding what UI/API flows may call into cloud-bound +services. In local daemon mode: + +- Available: local conversation data, local transcript ingestion, local search, + local memories, local action items, local settings, and optional Firebase + sign-in as an account-only feature. +- Unavailable: managed agent VM provisioning/sync, Omi backend provider proxies, + hosted transcription endpoints, public sharing links, Omi cloud sync, + subscriptions/payments/quotas, and Crisp support messaging. + +Unavailable capabilities fail before building a request to Omi-hosted services. +Local conversation CRUD/search/settings flows continue to use the configured +loopback daemon URL and do not require Firebase auth. 
From 19776985fb4a9a684f51aac5a67b7a890e3edce1 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 21:54:41 +0700 Subject: [PATCH 08/58] Validate local backend MVP end to end --- desktop/local-backend/README.md | 16 ++ desktop/local-backend/docs/architecture.md | 144 +++++++++++++++ desktop/local-backend/tools/e2e_smoke.sh | 193 +++++++++++++++++++++ 3 files changed, 353 insertions(+) create mode 100644 desktop/local-backend/docs/architecture.md create mode 100755 desktop/local-backend/tools/e2e_smoke.sh diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index f347e727394..6edea30df84 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -86,3 +86,19 @@ services. In local daemon mode: Unavailable capabilities fail before building a request to Omi-hosted services. Local conversation CRUD/search/settings flows continue to use the configured loopback daemon URL and do not require Firebase auth. + +## Architecture And E2E Validation + +The durable MVP architecture note and validation checklist live in +`docs/architecture.md`. + +Run the local daemon API smoke test: + +```bash +desktop/local-backend/tools/e2e_smoke.sh +``` + +The smoke starts the daemon on a temporary loopback port, creates and updates a +conversation, appends/finalizes transcript segments, waits for fallback +processing, checks job status, restarts the daemon, and verifies persisted +conversation/search output without Omi cloud credentials. diff --git a/desktop/local-backend/docs/architecture.md b/desktop/local-backend/docs/architecture.md new file mode 100644 index 00000000000..43bafa13c0b --- /dev/null +++ b/desktop/local-backend/docs/architecture.md @@ -0,0 +1,144 @@ +# Omi Local Backend Architecture + +This note describes the backend-free desktop MVP shape that should survive past +the prototype stage. + +## Crate Layout + +`desktop/local-backend/` is a separate Rust workspace from +`desktop/Backend-Rust`. 
Keeping it separate prevents Firebase, Firestore, Redis, +GCS, pusher, paywall, agent-proxy, and Cloud Run assumptions from entering the +local daemon's critical path. + +- `src/main.rs` owns startup, tracing, config loading, SQLite open, worker + startup, and Axum router assembly. +- `src/config.rs` resolves the loopback bind address and data directory from + environment variables. +- `src/health.rs` exposes health metadata for desktop startup checks. +- `src/routes.rs` is the HTTP API boundary for desktop MVP flows. +- `src/storage.rs` owns migrations, SQLite pragmas, repositories, normalized + transcript rows, FTS, sync metadata fields, and local profile/settings. +- `src/processing.rs` owns durable job execution, deterministic fallback + processing, and output persistence. +- `src/providers.rs` owns direct provider adapters. The current adapter is + OpenAI-compatible chat completions and is configured only through local + settings. + +## Data Directory And Database + +The daemon stores data under the platform app data directory by default, or +under `OMI_LOCAL_BACKEND_DATA_DIR` when set. The SQLite database is named +`omi-local-backend.sqlite`. + +SQLite is the local source of truth. Startup creates the data directory, opens +the database, runs migrations, enables WAL and foreign keys, and creates FTS +indexes over conversation title, overview, and transcript segment text. +Transcript segments are stored as normalized rows keyed by conversation and +session. JSON exists at the API boundary for client compatibility, not as the +canonical transcript representation. + +Tables include local IDs, timestamps, soft-delete fields, sync state/version, +and optional cloud IDs so later sync can map local records without making cloud +state authoritative. + +## Desktop Backend Mode Selection + +Desktop selects a backend through `DesktopBackendEnvironment`. + +- `cloud` is the default and routes to the configured Omi cloud Python backend + with auth. 
+- `local` / `local-daemon` routes MVP local flows to + `OMI_LOCAL_DAEMON_URL`, defaulting to `http://127.0.0.1:8765/`, without + Firebase auth. +- `custom` preserves the existing custom remote URL path for developer use. + +Local daemon mode has an explicit capability matrix. Local conversation data, +transcript ingestion, search, memories, action items, settings, and optional +Firebase sign-in remain available. Managed agent VM, Omi backend provider +proxies, hosted transcription endpoints, public sharing, cloud sync, payments, +quotas, and Crisp support are unavailable before request construction. + +## Direct Provider Adapters + +Remote AI/STT providers are allowed only when explicitly configured by the user +or developer. The daemon talks directly to configured providers; it does not use +Omi backend provider proxies. + +The MVP includes an OpenAI-compatible chat completions adapter with local +settings for base URL, model, and API key. The processing pipeline still works +without any provider key by using deterministic fallbacks: + +- title: first meaningful transcript words, bounded length +- overview: clipped transcript excerpt +- action items: empty list +- memories: empty list + +Provider keys stay in local daemon settings and are not sent to Omi-hosted +services. + +## Cloud Sync Boundary + +Cloud sync is a future optional adapter. The local database remains the source +of truth in local daemon mode. Sync metadata fields and cloud IDs are present to +support mapping, conflict handling, and outbox-style work later, but the MVP +does not require Omi cloud credentials or services for local read/write/search +or processing fallback. 
+ +## End-To-End Validation Checklist + +Run the local daemon API smoke: + +```bash +desktop/local-backend/tools/e2e_smoke.sh +``` + +Run local daemon tests: + +```bash +cd desktop/local-backend +cargo test +``` + +Run focused desktop routing checks: + +```bash +cd desktop/Desktop +swift test --filter APIClientRoutingTests +``` + +Manual desktop local mode check: + +```bash +cd desktop/local-backend +OMI_LOCAL_BACKEND_PORT=8765 cargo run +``` + +In another terminal, launch the dev desktop app with: + +```bash +cd desktop +OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +./run.sh +``` + +The app-facing local MVP routes should use the loopback daemon and should not +require Firebase, Omi backend, Redis, Firestore, GCS, pusher, or agent-proxy +credentials. Cloud mode routing basics are covered by `APIClientRoutingTests`, +which verifies default cloud URL selection, custom URL selection, and local +daemon routing without auth. + +## Known Limitations And Follow-Up Work + +- The desktop app currently has a documented dev launch contract for the daemon; + production supervision/packaging is not implemented. +- Hosted transcription is intentionally unavailable in local daemon mode. The + MVP validates transcript import/append/finalize, not direct local STT parity. +- Existing desktop GRDB/Rewind stores are not migrated into the local daemon + database yet. +- Local provider configuration exists at the daemon API/settings layer, but the + user-facing settings workflow is still thin. +- Cloud sync remains disabled until a dedicated optional sync adapter is + designed and tested. 
diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh new file mode 100755 index 00000000000..b2e8dc30cd1 --- /dev/null +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -0,0 +1,193 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +DATA_DIR="${OMI_LOCAL_BACKEND_SMOKE_DATA_DIR:-$(mktemp -d /tmp/omi-local-backend-smoke.XXXXXX)}" +PORT="${OMI_LOCAL_BACKEND_SMOKE_PORT:-$(python3 - <<'PY' +import socket +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + print(s.getsockname()[1]) +PY +)}" +BASE_URL="http://127.0.0.1:${PORT}" +LOG_FILE="${DATA_DIR}/daemon.log" + +DAEMON_PID="" + +cleanup() { + if [[ -n "${DAEMON_PID}" ]] && kill -0 "${DAEMON_PID}" >/dev/null 2>&1; then + kill "${DAEMON_PID}" >/dev/null 2>&1 || true + wait "${DAEMON_PID}" >/dev/null 2>&1 || true + fi +} +trap cleanup EXIT + +json_value() { + python3 - "$1" "$2" <<'PY' +import json +import sys + +path = sys.argv[1].split(".") +with open(sys.argv[2], "r", encoding="utf-8") as handle: + value = json.load(handle) +for part in path: + if part.isdigit(): + value = value[int(part)] + else: + value = value[part] +print(value) +PY +} + +request() { + local method="$1" + local path="$2" + local body="${3:-}" + local output + output="$(mktemp)" + if [[ -n "${body}" ]]; then + curl -fsS -X "${method}" "${BASE_URL}${path}" \ + -H "Content-Type: application/json" \ + --data "${body}" \ + -o "${output}" + else + curl -fsS -X "${method}" "${BASE_URL}${path}" -o "${output}" + fi + printf '%s\n' "${output}" +} + +assert_json_value() { + local file="$1" + local path="$2" + local expected="$3" + local actual + actual="$(json_value "${path}" "${file}")" + if [[ "${actual}" != "${expected}" ]]; then + echo "Expected ${path}=${expected}, got ${actual}" >&2 + echo "Response file: ${file}" >&2 + exit 1 + fi +} + +start_daemon() { + OMI_LOCAL_BACKEND_HOST=127.0.0.1 \ + 
OMI_LOCAL_BACKEND_PORT="${PORT}" \ + OMI_LOCAL_BACKEND_DATA_DIR="${DATA_DIR}" \ + cargo run --quiet --manifest-path "${ROOT_DIR}/Cargo.toml" >"${LOG_FILE}" 2>&1 & + DAEMON_PID="$!" + + for _ in $(seq 1 80); do + if curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + return + fi + if ! kill -0 "${DAEMON_PID}" >/dev/null 2>&1; then + echo "Daemon exited during startup. Log:" >&2 + sed -n '1,160p' "${LOG_FILE}" >&2 || true + exit 1 + fi + sleep 0.25 + done + + echo "Timed out waiting for daemon health at ${BASE_URL}/health. Log:" >&2 + sed -n '1,160p' "${LOG_FILE}" >&2 || true + exit 1 +} + +stop_daemon() { + if [[ -n "${DAEMON_PID}" ]] && kill -0 "${DAEMON_PID}" >/dev/null 2>&1; then + kill "${DAEMON_PID}" >/dev/null 2>&1 || true + wait "${DAEMON_PID}" >/dev/null 2>&1 || true + fi + DAEMON_PID="" +} + +wait_for_completed_job() { + local job_id="$1" + local job_file + for _ in $(seq 1 30); do + job_file="$(request GET "/v1/processing-jobs/${job_id}")" + if [[ "$(json_value "processing_job.status" "${job_file}")" == "completed" ]]; then + printf '%s\n' "${job_file}" + return + fi + request POST "/v1/processing-jobs/process-next" >/dev/null || true + sleep 0.25 + done + echo "Processing job ${job_id} did not complete" >&2 + exit 1 +} + +echo "Starting local daemon smoke on ${BASE_URL}" +echo "Data dir: ${DATA_DIR}" + +start_daemon + +health_file="$(request GET /health)" +assert_json_value "${health_file}" "service" "omi-local-backend" +assert_json_value "${health_file}" "mode" "local" + +profile_file="$(request GET /profile/status)" +assert_json_value "${profile_file}" "mode" "local" +assert_json_value "${profile_file}" "authenticated" "False" + +conversation_file="$(request POST /v1/conversations '{ + "id": "conv-e2e-smoke", + "session_id": "session-e2e-smoke", + "title": "Smoke seed", + "overview": "Created by local smoke" +}')" +assert_json_value "${conversation_file}" "conversation.id" "conv-e2e-smoke" + +segment_file="$(request POST 
/v1/conversations/conv-e2e-smoke/transcript-segments '{ + "id": "seg-e2e-smoke-0", + "text": "Plan the backend free desktop MVP and verify deterministic local processing.", + "start_ms": 0, + "end_ms": 2400, + "segment_index": 0, + "source": "smoke" +}')" +assert_json_value "${segment_file}" "transcript_segment.id" "seg-e2e-smoke-0" + +updated_file="$(request PATCH /v1/conversations/conv-e2e-smoke '{ + "title": "Smoke updated", + "overview": "Updated before processing" +}')" +assert_json_value "${updated_file}" "conversation.title" "Smoke updated" + +list_file="$(request GET /v1/conversations)" +assert_json_value "${list_file}" "conversations.0.id" "conv-e2e-smoke" + +search_file="$(request GET '/v1/search/conversations?q=deterministic')" +assert_json_value "${search_file}" "results.0.conversation_id" "conv-e2e-smoke" + +job_file="$(request POST /v1/conversations/conv-e2e-smoke/finalize-transcript)" +job_id="$(json_value "processing_job.id" "${job_file}")" +completed_job_file="$(wait_for_completed_job "${job_id}")" +assert_json_value "${completed_job_file}" "processing_job.status" "completed" + +status_file="$(request GET /v1/processing-jobs/status)" +assert_json_value "${status_file}" "failed" "0" + +processed_file="$(request GET /v1/conversations/conv-e2e-smoke)" +assert_json_value "${processed_file}" "conversation.status" "processed" +assert_json_value "${processed_file}" "conversation.title" "Plan the backend free desktop MVP and verify" + +settings_file="$(request PUT /v1/settings '{ + "local_first": true, + "provider.kind": "fallback" +}')" +assert_json_value "${settings_file}" "settings.0.key" "local_first" + +stop_daemon +start_daemon + +persisted_file="$(request GET /v1/conversations/conv-e2e-smoke)" +assert_json_value "${persisted_file}" "conversation.status" "processed" +assert_json_value "${persisted_file}" "transcript_segments.0.text" "Plan the backend free desktop MVP and verify deterministic local processing." 
+ +persisted_search_file="$(request GET '/v1/search/conversations?q=backend')" +assert_json_value "${persisted_search_file}" "results.0.conversation_id" "conv-e2e-smoke" + +echo "Local backend E2E smoke passed." +echo "OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_URL=${BASE_URL}" From 1d7cef1c5721b09333aa9f866ef79352e808b701 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 22:01:45 +0700 Subject: [PATCH 09/58] Add local MVP transcript import runbook --- desktop/local-backend/README.md | 17 ++ .../local-backend/docs/local-mvp-runbook.md | 188 ++++++++++++++++ .../local-backend/tools/import_transcript.py | 211 ++++++++++++++++++ 3 files changed, 416 insertions(+) create mode 100644 desktop/local-backend/docs/local-mvp-runbook.md create mode 100755 desktop/local-backend/tools/import_transcript.py diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index 6edea30df84..7e363db1a1f 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -7,6 +7,9 @@ dependencies. ## Run Locally +For a complete user-test walkthrough, including desktop launch environment and +transcript import, see `docs/local-mvp-runbook.md`. + ```bash cd desktop/local-backend cargo run @@ -52,6 +55,7 @@ The local daemon exposes JSON endpoints for the desktop MVP: - `GET /v1/processing-jobs` - `GET /v1/processing-jobs/:id` - `GET /v1/processing-jobs/status` +- `POST /v1/processing-jobs/process-next` Finalizing transcript ingestion currently enqueues a local `finalize_transcript` processing job. Later processing workers can consume the same durable @@ -87,6 +91,19 @@ Unavailable capabilities fail before building a request to Omi-hosted services. Local conversation CRUD/search/settings flows continue to use the configured loopback daemon URL and do not require Firebase auth. +Hosted transcription endpoints are intentionally unavailable in local daemon +mode. 
Direct live STT parity is not part of this MVP unless a future direct +provider path is added. For user testing, import transcript text or JSON +fixtures through the supported helper: + +```bash +desktop/local-backend/tools/import_transcript.py /path/to/transcript.txt +``` + +The helper creates a conversation, appends transcript segments, finalizes +ingestion, waits for local processing, verifies search, and prints read/search +commands for the imported conversation. + ## Architecture And E2E Validation The durable MVP architecture note and validation checklist live in diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md new file mode 100644 index 00000000000..76d5d53e936 --- /dev/null +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -0,0 +1,188 @@ +# Local MVP User-Test Runbook + +This runbook starts the local daemon and imports transcript data without Omi +hosted backend services. It is the supported MVP path for testing local +conversation storage, transcript ingestion, processing fallback, and search. + +## Prerequisites + +- macOS with Xcode command line tools installed. +- Rust toolchain with `cargo`. +- Python 3 for the import helper. +- `curl` for health and API checks. +- For desktop app testing: the normal desktop development prerequisites from + `desktop/README.md` and `desktop/run.sh`. + +No Firebase, Omi Python backend, Rust cloud backend, Redis, Firestore, GCS, +pusher, or agent-proxy credentials are required for the local daemon path. + +## Start The Local Daemon + +From the repo root: + +```bash +cd desktop/local-backend +cargo run +``` + +The daemon listens on `127.0.0.1:8765` by default. 
To keep test data isolated: + +```bash +cd desktop/local-backend +OMI_LOCAL_BACKEND_DATA_DIR=/tmp/omi-local-mvp \ +OMI_LOCAL_BACKEND_PORT=8765 \ +cargo run +``` + +Verify health from another terminal: + +```bash +curl http://127.0.0.1:8765/health +``` + +Expected signals: + +- `service` is `omi-local-backend`. +- `mode` is `local`. +- `data_dir` points at the daemon data directory. + +If health does not respond, check that the daemon terminal is still running and +that no other process is already using the selected port. + +## Launch Desktop In Local Daemon Mode + +The desktop app selects local mode through environment variables: + +```bash +cd desktop +OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +./run.sh +``` + +The invalid cloud URLs make accidental cloud routing obvious during a user test. +Local conversation, transcript, memory, action item, settings, and search flows +should use `OMI_LOCAL_DAEMON_URL`. + +## Import Transcript Data + +Hosted transcription endpoints are intentionally disabled in local daemon mode. +Direct live STT parity is not part of the current MVP unless a future direct +provider path is added. For local MVP testing, import or append transcript text +and then finalize processing. + +Create a plain text fixture: + +```bash +cat >/tmp/omi-local-transcript.txt <<'EOF' +We reviewed the local-first desktop MVP and confirmed it can store transcript +segments without Firebase or Firestore. + +The next action is to test local search and processing status from the desktop +app. 
+EOF +``` + +Import it: + +```bash +desktop/local-backend/tools/import_transcript.py /tmp/omi-local-transcript.txt +``` + +The helper: + +- creates a local conversation through `POST /v1/conversations` +- appends transcript segment rows through + `POST /v1/conversations/:id/transcript-segments` +- finalizes ingestion through `POST /v1/conversations/:id/finalize-transcript` +- waits for the local processing job to complete +- verifies search finds the imported transcript text +- prints the conversation ID plus read and search `curl` commands + +JSON fixtures are also supported. The file may be a list of segment strings, a +list of segment objects, or an object with conversation fields plus `segments` +or `transcript_segments`: + +```json +{ + "title": "Local MVP Fixture", + "overview": "Imported for a local daemon user test", + "segments": [ + { + "speaker_label": "Alex", + "text": "Local daemon mode keeps transcript data on this machine.", + "start_ms": 0, + "end_ms": 2400 + }, + { + "speaker_label": "Sam", + "text": "Search should find this imported fixture without cloud credentials.", + "start_ms": 2400, + "end_ms": 5200 + } + ] +} +``` + +Useful helper options: + +```bash +desktop/local-backend/tools/import_transcript.py fixture.json \ + --base-url http://127.0.0.1:8765 \ + --title "User Test Import" \ + --search-query "imported fixture" +``` + +## Read And Search Imported Conversations + +Use the commands printed by the import helper, or run them directly: + +```bash +curl http://127.0.0.1:8765/v1/conversations/<conversation-id> +curl 'http://127.0.0.1:8765/v1/search/conversations?q=local+first' +curl http://127.0.0.1:8765/v1/processing-jobs/status +``` + +After finalization, the conversation status should become `processed`. Search +results should include the imported conversation when the query appears in the +title, overview, or transcript text. + +## What Works Without Omi-Hosted Services + +- Local daemon startup on loopback. 
+- SQLite-backed conversation create/read/update/delete. +- Transcript segment append and finalize. +- Local fallback processing for title and overview. +- Local full-text search over conversation and transcript text. +- Local memories, action items, profile, and settings endpoints. +- Desktop routing for local MVP flows without Firebase auth. + +## What Still Needs Provider Keys Or Cloud Mode + +- Live hosted transcription and Deepgram/Omi transcription endpoints require + cloud mode today. +- Omi backend provider proxies, quota checks, subscriptions, payments, public + sharing, Crisp support, managed agent VMs, and cloud sync require cloud mode. +- Remote AI provider calls from the local daemon require explicit local provider + settings/API keys. Without them, processing uses deterministic fallback output. +- Fully offline local LLM/STT support is outside the current MVP. + +## Known Environment Blockers + +- `curl /health` cannot connect: the daemon is not running, the port is wrong, + or another process is bound to the port. Restart with + `OMI_LOCAL_BACKEND_PORT=<port>` and update `OMI_LOCAL_DAEMON_URL`. +- `cargo run` fails before listening: inspect the daemon terminal output for + Rust build errors or data directory permission errors. +- Import helper reports HTTP 404 for local routes: verify `--base-url` points to + the local daemon, not the cloud backend. +- Import helper times out waiting for processing: run + `curl http://127.0.0.1:8765/v1/processing-jobs/status` and check the daemon + log for processing errors. +- Desktop still calls cloud endpoints: confirm the app was launched with + `OMI_DESKTOP_BACKEND_MODE=local` and that `OMI_LOCAL_DAEMON_URL` includes the + daemon port. +- Desktop launch/auth callback issues in custom test builds: keep the app name + and bundle suffix aligned as described in the repo desktop agent rules. 
diff --git a/desktop/local-backend/tools/import_transcript.py b/desktop/local-backend/tools/import_transcript.py new file mode 100755 index 00000000000..9fee82ccd33 --- /dev/null +++ b/desktop/local-backend/tools/import_transcript.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python3 +"""Import a transcript fixture into the Omi local backend MVP.""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import time +import urllib.error +import urllib.parse +import urllib.request +from pathlib import Path +from typing import Any + + +DEFAULT_BASE_URL = "http://127.0.0.1:8765" + + +def request_json(method: str, base_url: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]: + data = None + headers = {"Accept": "application/json"} + if body is not None: + data = json.dumps(body).encode("utf-8") + headers["Content-Type"] = "application/json" + + request = urllib.request.Request( + f"{base_url.rstrip('/')}{path}", + data=data, + headers=headers, + method=method, + ) + try: + with urllib.request.urlopen(request, timeout=10) as response: + payload = response.read().decode("utf-8") + except urllib.error.HTTPError as error: + payload = error.read().decode("utf-8", errors="replace") + raise RuntimeError(f"{method} {path} failed with HTTP {error.code}: {payload}") from error + except urllib.error.URLError as error: + raise RuntimeError(f"{method} {path} failed: {error.reason}") from error + + return json.loads(payload) if payload else {} + + +def load_fixture(path: Path) -> tuple[dict[str, Any], list[dict[str, Any]]]: + raw = path.read_text(encoding="utf-8") + if path.suffix.lower() == ".json": + return load_json_fixture(json.loads(raw), path) + return ( + {"title": path.stem.replace("_", " ").replace("-", " ").strip().title() or "Imported Transcript"}, + plain_text_segments(raw), + ) + + +def load_json_fixture(value: Any, path: Path) -> tuple[dict[str, Any], list[dict[str, Any]]]: + conversation: dict[str, Any] = { + "title": 
path.stem.replace("_", " ").replace("-", " ").strip().title() or "Imported Transcript" + } + segment_values: Any = value + + if isinstance(value, dict): + conversation = { + key: value[key] + for key in ("id", "session_id", "title", "overview", "started_at", "metadata") + if key in value + } + segment_values = value.get("segments", value.get("transcript_segments", value.get("transcript", []))) + + if not isinstance(segment_values, list): + raise ValueError("JSON fixture must be a list, or an object with segments/transcript_segments") + + segments = [] + for index, item in enumerate(segment_values): + if isinstance(item, str): + text = item.strip() + segment = {"text": text} + elif isinstance(item, dict): + text = str(item.get("text", "")).strip() + segment = dict(item) + segment["text"] = text + else: + raise ValueError(f"Segment {index} must be a string or object") + + if text: + segment.setdefault("start_ms", index * 2_000) + segment.setdefault("end_ms", segment["start_ms"] + 2_000) + segment.setdefault("segment_index", index) + segment.setdefault("source", "local_import") + segments.append(segment) + + return conversation, segments + + +def plain_text_segments(raw: str) -> list[dict[str, Any]]: + chunks = [chunk.strip() for chunk in raw.replace("\r\n", "\n").split("\n\n") if chunk.strip()] + if not chunks: + chunks = [line.strip() for line in raw.splitlines() if line.strip()] + + return [ + { + "text": chunk, + "start_ms": index * 2_000, + "end_ms": (index + 1) * 2_000, + "segment_index": index, + "source": "local_import", + } + for index, chunk in enumerate(chunks) + ] + + +def wait_for_job(base_url: str, job_id: str, timeout_seconds: float) -> dict[str, Any]: + deadline = time.monotonic() + timeout_seconds + last_job: dict[str, Any] = {} + while time.monotonic() < deadline: + last_job = request_json("GET", base_url, f"/v1/processing-jobs/{urllib.parse.quote(job_id)}")[ + "processing_job" + ] + if last_job["status"] in {"completed", "failed"}: + return last_job 
+ request_json("POST", base_url, "/v1/processing-jobs/process-next") + time.sleep(0.25) + raise TimeoutError(f"Processing job {job_id} did not finish; last status: {last_job.get('status')}") + + +def default_search_query(segments: list[dict[str, Any]]) -> str: + for segment in segments: + words = [word.strip(".,:;!?()[]{}\"'") for word in segment["text"].split()] + words = [word for word in words if len(word) > 3] + if words: + return " ".join(words[:3]) + return segments[0]["text"][:24] + + +def search_safe_query(raw: str) -> str: + query = re.sub(r"[^\w\s]", " ", raw, flags=re.UNICODE) + query = " ".join(query.split()) + return query or raw + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("transcript", type=Path, help="Plain text or JSON transcript fixture") + parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help=f"Local daemon URL, default {DEFAULT_BASE_URL}") + parser.add_argument("--title", help="Override the conversation title") + parser.add_argument("--overview", help="Override the conversation overview") + parser.add_argument("--conversation-id", help="Use a specific conversation ID") + parser.add_argument("--session-id", help="Use a specific session ID") + parser.add_argument("--search-query", help="Search query to verify after import") + parser.add_argument("--wait-timeout", type=float, default=15.0, help="Seconds to wait for finalize processing") + args = parser.parse_args() + + conversation, segments = load_fixture(args.transcript) + if not segments: + raise ValueError("Transcript fixture did not contain any non-empty segments") + + for key, value in ( + ("id", args.conversation_id), + ("session_id", args.session_id), + ("title", args.title), + ("overview", args.overview), + ): + if value: + conversation[key] = value + + request_json("GET", args.base_url, "/health") + created = request_json("POST", args.base_url, "/v1/conversations", conversation)["conversation"] + conversation_id = 
created["id"] + + for index, segment in enumerate(segments): + payload = dict(segment) + payload.setdefault("segment_index", index) + request_json( + "POST", + args.base_url, + f"/v1/conversations/{urllib.parse.quote(conversation_id)}/transcript-segments", + payload, + ) + + queued_job = request_json( + "POST", + args.base_url, + f"/v1/conversations/{urllib.parse.quote(conversation_id)}/finalize-transcript", + )["processing_job"] + completed_job = wait_for_job(args.base_url, queued_job["id"], args.wait_timeout) + if completed_job["status"] != "completed": + raise RuntimeError(f"Finalize processing failed for {conversation_id}: {completed_job}") + + search_query = search_safe_query(args.search_query or default_search_query(segments)) + encoded_query = urllib.parse.urlencode({"q": search_query}) + search = request_json("GET", args.base_url, f"/v1/search/conversations?{encoded_query}") + if not any(result.get("conversation_id") == conversation_id for result in search.get("results", [])): + raise RuntimeError(f"Search for {search_query!r} did not find imported conversation {conversation_id}") + + print(f"Imported conversation: {conversation_id}") + print(f"Segments imported: {len(segments)}") + print(f"Finalize job: {completed_job['id']} ({completed_job['status']})") + print(f"Read command: curl {args.base_url.rstrip('/')}/v1/conversations/{conversation_id}") + print( + "Search command: " + f"curl '{args.base_url.rstrip('/')}/v1/search/conversations?{urllib.parse.urlencode({'q': search_query})}'" + ) + return 0 + + +if __name__ == "__main__": + try: + raise SystemExit(main()) + except Exception as error: + print(f"import_transcript.py: {error}", file=sys.stderr) + raise SystemExit(1) From dfc3f3593c6fd5347be209403403e5d9cf1f143c Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 22:09:18 +0700 Subject: [PATCH 10/58] Add local daemon dev supervision --- desktop/.env.example | 7 +- desktop/README.md | 6 + .../local-backend/docs/local-mvp-runbook.md | 22 
++- desktop/run.sh | 143 ++++++++++++++++-- 4 files changed, 163 insertions(+), 15 deletions(-) diff --git a/desktop/.env.example b/desktop/.env.example index 1a37940c8a7..25f78bbcd58 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -26,11 +26,14 @@ OMI_PYTHON_API_URL=https://api.omi.me # Desktop app data backend mode for MVP local-first flows. # cloud: Omi-hosted Python backend (default) -# local: local daemon; start it with `cd desktop/local-backend && cargo run`, -# then verify `curl http://127.0.0.1:8765/health` +# local: local daemon; use `OMI_LOCAL_DAEMON_SUPERVISE=1 ./run.sh` to let +# desktop/run.sh start/check it, or start it manually with +# `cd desktop/local-backend && cargo run` and verify +# `curl http://127.0.0.1:8765/health` # custom: custom remote URL from OMI_PYTHON_API_URL # OMI_DESKTOP_BACKEND_MODE=cloud # OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 +# OMI_LOCAL_DAEMON_SUPERVISE=1 # Firebase Web API key — fetched from backend via /v1/config/api-keys # Only set this for local dev without a backend running diff --git a/desktop/README.md b/desktop/README.md index 3c6865219f5..0317ee95127 100644 --- a/desktop/README.md +++ b/desktop/README.md @@ -7,6 +7,7 @@ macOS app for OMI — always-on AI companion. 
Swift/SwiftUI frontend, Rust backe ``` Desktop/ Swift/SwiftUI macOS app (SPM package) Backend-Rust/ Rust API server (Firestore, Redis, auth, LLM) +local-backend/ Local-first daemon for desktop MVP testing agent/ Agent runtime for multi-provider chat (TypeScript) agent-cloud/ Cloud agent service dmg-assets/ DMG installer resources @@ -22,10 +23,15 @@ Requires macOS 14.0+, Rust toolchain, and code signing with an Apple Developer I # Run with the prod backend (skips local Rust + tunnel) ./run.sh --yolo + +# Run in local daemon mode and let the dev launcher start/check the daemon +OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_SUPERVISE=1 ./run.sh ``` `run.sh` auto-detects an `Apple Development` or `Developer ID Application` signing identity from your login keychain. Override with `OMI_SIGN_IDENTITY="..." ./run.sh`. +Local daemon mode uses `http://127.0.0.1:8765` by default. To manage the daemon yourself, run `cd desktop/local-backend && cargo run`, verify `curl http://127.0.0.1:8765/health`, then launch desktop with `OMI_DESKTOP_BACKEND_MODE=local ./run.sh`. The launcher only targets the dev app bundle (`Omi Dev.app` / `com.omi.desktop-dev`) and does not modify the production app. + ## License MIT diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 76d5d53e936..d2504da60c8 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -51,11 +51,15 @@ that no other process is already using the selected port. ## Launch Desktop In Local Daemon Mode -The desktop app selects local mode through environment variables: +For developer/user-test runs, the desktop launcher can supervise the local +daemon. 
It checks `/health`, starts `desktop/local-backend` only if the daemon +is unreachable, and stops only the daemon process it started when the launcher +exits: ```bash cd desktop OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_SUPERVISE=1 \ OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ @@ -64,7 +68,21 @@ OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ The invalid cloud URLs make accidental cloud routing obvious during a user test. Local conversation, transcript, memory, action item, settings, and search flows -should use `OMI_LOCAL_DAEMON_URL`. +should use `OMI_LOCAL_DAEMON_URL`. The launcher path targets the development app +bundle only (`Omi Dev.app` / `com.omi.desktop-dev`) and must not be used to +manage `/Applications/omi.app`. + +To keep using a manually managed daemon, start it first and launch the desktop +app with the same local-mode environment: + +```bash +cd desktop +OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +./run.sh +``` ## Import Transcript Data diff --git a/desktop/run.sh b/desktop/run.sh index f4f3d58d3a7..0db58257756 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -21,8 +21,11 @@ Options (via environment variables): OMI_PYTHON_API_URL="..." Python backend URL (subscriptions, payments, etc; default: https://api.omi.me) OMI_SIGN_IDENTITY="..." Code signing identity (auto-detected if not set) OMI_ENABLE_LOCAL_AUTOMATION=1 Enable agent-swift automation bridge + OMI_DESKTOP_BACKEND_MODE=local Route MVP data flows to the local daemon + OMI_LOCAL_DAEMON_SUPERVISE=1 In local mode, start desktop/local-backend if /health is unreachable + OMI_LOCAL_DAEMON_URL="..." 
Local daemon URL (default: http://127.0.0.1:8765) -Required files: +Required files for cloud backend mode: Backend-Rust/.env Environment variables (copy from ../.env.example) Backend-Rust/google-credentials.json GCP service account key @@ -36,6 +39,8 @@ Examples: ./run.sh # Full local dev (backend + tunnel + app) OMI_SKIP_BACKEND=1 ./run.sh # App only (backend running elsewhere) OMI_SKIP_TUNNEL=1 ./run.sh # No Cloudflare tunnel (use direct URL) + OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_SUPERVISE=1 ./run.sh + # Local daemon mode with dev supervisor ./run.sh --yolo # Quick start: use prod backend, no local services USAGE exit 0 @@ -141,11 +146,57 @@ fi # Backend configuration (Rust) BACKEND_DIR="$(cd "$(dirname "$0")/Backend-Rust" && pwd)" +LOCAL_DAEMON_DIR="$(cd "$(dirname "$0")/local-backend" && pwd)" BACKEND_PID="" +LOCAL_DAEMON_PID="" TUNNEL_PID="" TUNNEL_URL="${TUNNEL_URL:-}" -# Cleanup function to stop backend, auth, and tunnel on exit +is_local_daemon_mode() { + local mode + mode="$(printf '%s' "${OMI_DESKTOP_BACKEND_MODE:-${OMI_BACKEND_MODE:-}}" | tr '[:upper:]' '[:lower:]')" + case "$mode" in + local|local-daemon|local_daemon|daemon) return 0 ;; + *) return 1 ;; + esac +} + +normalize_local_daemon_url() { + local url="${OMI_LOCAL_DAEMON_URL:-http://127.0.0.1:8765}" + url="${url%/}" + export OMI_LOCAL_DAEMON_URL="$url" +} + +configure_local_daemon_mode() { + if ! 
is_local_daemon_mode; then + return + fi + + normalize_local_daemon_url + export OMI_SKIP_BACKEND="${OMI_SKIP_BACKEND:-1}" + export OMI_SKIP_TUNNEL="${OMI_SKIP_TUNNEL:-1}" + + if [ -z "${OMI_LOCAL_BACKEND_PORT:-}" ]; then + OMI_LOCAL_BACKEND_PORT="$(python3 - "$OMI_LOCAL_DAEMON_URL" <<'PY' +from urllib.parse import urlparse +import sys + +parsed = urlparse(sys.argv[1]) +print(parsed.port or 8765) +PY +)" + export OMI_LOCAL_BACKEND_PORT + fi + export OMI_LOCAL_BACKEND_HOST="${OMI_LOCAL_BACKEND_HOST:-127.0.0.1}" +} + +local_daemon_health_ok() { + curl -fsS "${OMI_LOCAL_DAEMON_URL}/health" >/dev/null 2>&1 +} + +configure_local_daemon_mode + +# Cleanup function to stop only services started by this dev launcher. cleanup() { if [ -n "$TUNNEL_PID" ] && kill -0 "$TUNNEL_PID" 2>/dev/null; then echo "Stopping tunnel (PID: $TUNNEL_PID)..." @@ -155,6 +206,11 @@ cleanup() { echo "Stopping backend (PID: $BACKEND_PID)..." kill "$BACKEND_PID" 2>/dev/null || true fi + if [ -n "$LOCAL_DAEMON_PID" ] && kill -0 "$LOCAL_DAEMON_PID" 2>/dev/null; then + echo "Stopping local daemon (PID: $LOCAL_DAEMON_PID)..." + kill "$LOCAL_DAEMON_PID" 2>/dev/null || true + wait "$LOCAL_DAEMON_PID" 2>/dev/null || true + fi } trap cleanup EXIT @@ -169,7 +225,9 @@ auth_debug "BEFORE pkill: ALL_KEYS=$(defaults read "$BUNDLE_ID" 2>&1 | grep -E ' # Only kill the dev app — never touch Omi Beta (production) pkill -f "$APP_NAME.app" 2>/dev/null || true # Note: don't pkill cloudflared here — other agents may have tunnels running on this machine -# Kill any old Rust backend by process name (port-agnostic) +# Kill any old dev cloud-shaped Rust backend by process name (port-agnostic). +# Do not kill omi-local-backend here; local daemon mode may use a manually +# managed daemon that this launcher should only detect. pgrep -f "omi-desktop-backend" 2>/dev/null | while read pid; do substep "Killing old backend (PID: $pid)" kill -9 "$pid" 2>/dev/null || true @@ -240,7 +298,7 @@ if [ ! 
-f ".env" ] && [ -f "../backend/.env" ]; then elif [ ! -f ".env" ] && [ -f "../Backend/.env" ]; then cp "../Backend/.env" ".env" fi -if [ ! -f ".env" ] && [ "$1" != "--yolo" ]; then +if [ ! -f ".env" ] && [ "$1" != "--yolo" ] && ! is_local_daemon_mode; then echo "" echo "=== First-time setup ===" echo "No .env file found at $BACKEND_DIR/.env" @@ -278,6 +336,7 @@ fi if [ -f "$BACKEND_DIR/.env" ]; then set -a; source "$BACKEND_DIR/.env"; set +a fi +configure_local_daemon_mode # Read backend PORT from env (default: 10201, never use 8080) BACKEND_PORT="${PORT:-10201}" @@ -308,12 +367,60 @@ if [ -z "$FIREBASE_PROJECT_ID" ] && [ "${OMI_SKIP_BACKEND:-0}" != "1" ]; then echo " FIREBASE_PROJECT_ID=based-hardware-dev # dev Firestore" exit 1 fi -if [ -n "$FIREBASE_AUTH_PROJECT_ID" ]; then +if is_local_daemon_mode; then + substep "Local daemon mode: skipping cloud backend credential requirements" +elif [ -n "$FIREBASE_AUTH_PROJECT_ID" ]; then substep "Auth project: tokens validated against $FIREBASE_AUTH_PROJECT_ID, Firestore on $FIREBASE_PROJECT_ID" + substep "Firebase project: $FIREBASE_PROJECT_ID | Backend port: $BACKEND_PORT" +else + substep "Firebase project: $FIREBASE_PROJECT_ID | Backend port: $BACKEND_PORT" fi -substep "Firebase project: $FIREBASE_PROJECT_ID | Backend port: $BACKEND_PORT" cd - > /dev/null +# ─── Local daemon health preflight / dev supervision ────────────────── +if is_local_daemon_mode; then + step "Checking local daemon health..." + if local_daemon_health_ok; then + substep "Local daemon is ready at $OMI_LOCAL_DAEMON_URL" + elif [ "${OMI_LOCAL_DAEMON_SUPERVISE:-0}" = "1" ]; then + LOCAL_DAEMON_LOG="${OMI_LOCAL_DAEMON_LOG:-/tmp/omi-local-backend-dev.log}" + substep "Starting local daemon from $LOCAL_DAEMON_DIR" + cd "$LOCAL_DAEMON_DIR" + cargo run --quiet > "$LOCAL_DAEMON_LOG" 2>&1 & + LOCAL_DAEMON_PID=$! 
+ cd - > /dev/null + + for i in {1..80}; do + if local_daemon_health_ok; then + substep "Local daemon is ready at $OMI_LOCAL_DAEMON_URL (PID: $LOCAL_DAEMON_PID)" + break + fi + if ! kill -0 "$LOCAL_DAEMON_PID" 2>/dev/null; then + echo "ERROR: Local daemon exited during startup. Log:" + sed -n '1,160p' "$LOCAL_DAEMON_LOG" 2>/dev/null || true + exit 1 + fi + sleep 0.25 + done + + if ! local_daemon_health_ok; then + echo "ERROR: Timed out waiting for local daemon health at $OMI_LOCAL_DAEMON_URL/health" + echo "Log: $LOCAL_DAEMON_LOG" + sed -n '1,160p' "$LOCAL_DAEMON_LOG" 2>/dev/null || true + exit 1 + fi + else + echo "ERROR: Local daemon mode is enabled, but $OMI_LOCAL_DAEMON_URL/health is unreachable." + echo "" + echo "Start it manually:" + echo " cd desktop/local-backend && cargo run" + echo "" + echo "Or let this dev launcher supervise it:" + echo " OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_SUPERVISE=1 ./run.sh" + exit 1 + fi +fi + # ─── Start Rust backend ─────────────────────────────────────────────── if [ "${OMI_SKIP_BACKEND:-0}" != "1" ]; then step "Starting Rust backend..." 
@@ -481,6 +588,18 @@ elif [ -f ".env.app" ]; then else touch "$APP_BUNDLE/Contents/Resources/.env" fi + +set_bundle_env() { + local key="$1" + local value="$2" + local env_file="$APP_BUNDLE/Contents/Resources/.env" + if grep -q "^${key}=" "$env_file"; then + sed -i '' "s|^${key}=.*|${key}=${value}|" "$env_file" + else + echo "${key}=${value}" >> "$env_file" + fi +} + # Set OMI_DESKTOP_API_URL: tunnel URL if available, otherwise from .env or local backend if [ -n "$TUNNEL_URL" ]; then EFFECTIVE_API_URL="$TUNNEL_URL" @@ -489,12 +608,14 @@ elif [ -n "$OMI_DESKTOP_API_URL" ]; then else EFFECTIVE_API_URL="http://localhost:$BACKEND_PORT" fi -if grep -q "^OMI_DESKTOP_API_URL=" "$APP_BUNDLE/Contents/Resources/.env"; then - sed -i '' "s|^OMI_DESKTOP_API_URL=.*|OMI_DESKTOP_API_URL=$EFFECTIVE_API_URL|" "$APP_BUNDLE/Contents/Resources/.env" -else - echo "OMI_DESKTOP_API_URL=$EFFECTIVE_API_URL" >> "$APP_BUNDLE/Contents/Resources/.env" -fi +set_bundle_env "OMI_DESKTOP_API_URL" "$EFFECTIVE_API_URL" substep "OMI_DESKTOP_API_URL=$EFFECTIVE_API_URL" +if is_local_daemon_mode; then + set_bundle_env "OMI_DESKTOP_BACKEND_MODE" "local" + set_bundle_env "OMI_LOCAL_DAEMON_URL" "$OMI_LOCAL_DAEMON_URL" + substep "OMI_DESKTOP_BACKEND_MODE=local" + substep "OMI_LOCAL_DAEMON_URL=$OMI_LOCAL_DAEMON_URL" +fi # Bootstrap FIREBASE_API_KEY — check env var first (yolo mode), then backend .env if ! grep -q "^FIREBASE_API_KEY=" "$APP_BUNDLE/Contents/Resources/.env"; then FIREBASE_KEY="${FIREBASE_API_KEY:-}" From d28259e0961d36d5904be30ae7e22f417ad01b9d Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 18 May 2026 15:17:02 +0000 Subject: [PATCH 11/58] Fix local daemon action-item completion and conversation count API Map completed to status for PATCH v1/action-items in local mode, and add GET v1/conversations/count in the local Rust backend so the desktop client avoids loading full conversation rows just to count them. 
--- desktop/Desktop/Sources/APIClient.swift | 20 +++++++++++--------- desktop/local-backend/src/routes.rs | 10 ++++++++++ desktop/local-backend/src/storage.rs | 10 ++++++++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 1d4f56a7dda..f14ccc3c5c5 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -594,14 +594,18 @@ extension APIClient { includeDiscarded: Bool = false, statuses: [ConversationStatus] = [.completed, .processing] ) async throws -> Int { + struct CountResponse: Decodable { + let count: Int + } + let target = mvpBackendTarget if target.mode == .localDaemon { - let response: LocalConversationsResponse = try await get( - "v1/conversations?limit=10000", + let response: CountResponse = try await get( + "v1/conversations/count", requireAuth: false, customBaseURL: target.baseURL ) - return response.conversations.count + return response.count } if let cache = conversationsCountCache, let time = conversationsCountCacheTime, @@ -621,10 +625,6 @@ extension APIClient { let endpoint = "v1/conversations/count?\(queryItems.joined(separator: "&"))" - struct CountResponse: Decodable { - let count: Int - } - let response: CountResponse = try await get(endpoint) conversationsCountCache = response.count conversationsCountCacheTime = Date() @@ -2188,9 +2188,10 @@ extension APIClient { let title: String? let description: String? let dueAt: String? + let status: String? enum CodingKeys: String, CodingKey { - case title, description + case title, description, status case dueAt = "due_at" } } @@ -2199,7 +2200,8 @@ extension APIClient { body: LocalUpdateRequest( title: description, description: description, - dueAt: dueAt.map { formatter.string(from: $0) } + dueAt: dueAt.map { formatter.string(from: $0) }, + status: completed.map { $0 ? 
"completed" : "open" } ), requireAuth: false, customBaseURL: target.baseURL diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 0d0be46bb4a..1d3a5bc1426 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -28,6 +28,7 @@ pub fn router() -> Router { "/v1/conversations", get(list_conversations).post(create_conversation), ) + .route("/v1/conversations/count", get(count_conversations)) .route( "/v1/conversations/:id", get(get_conversation) @@ -161,6 +162,15 @@ async fn list_conversations( Ok(Json(json!({ "conversations": conversations }))) } +async fn count_conversations(State(state): State) -> ApiResult { + let count = state + .store + .conversations() + .count() + .map_err(ApiError::internal)?; + Ok(Json(json!({ "count": count }))) +} + #[derive(Deserialize)] struct CreateConversationRequest { id: Option, diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 859efe30c62..b3c41182859 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -532,6 +532,16 @@ impl ConversationRepository { collect_rows(rows) } + pub fn count(&self) -> Result { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + "SELECT COUNT(*) FROM conversations WHERE deleted_at IS NULL", + [], + |row| row.get(0), + ) + .context("failed to count conversations") + } + pub fn update(&self, id: &str, update: UpdateConversation) -> Result> { let Some(mut conversation) = self.get(id)? 
else { return Ok(None); From b5dd990d337cc3f6276d13bdad9e1ec5ff5af64e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 22:28:25 +0700 Subject: [PATCH 12/58] Close local-mode conversation cloud leaks --- desktop/Desktop/Sources/APIClient.swift | 113 +- desktop/Desktop/Sources/AppState.swift | 167 +- .../MainWindow/Pages/ConversationsPage.swift | 29 +- .../Desktop/Tests/APIClientRoutingTests.swift | 1673 +++++++++-------- desktop/local-backend/src/main.rs | 10 + desktop/local-backend/src/processing.rs | 1 + desktop/local-backend/src/routes.rs | 2 + desktop/local-backend/src/storage.rs | 84 +- 8 files changed, 1219 insertions(+), 860 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index f14ccc3c5c5..f805b2dac44 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -122,7 +122,7 @@ actor APIClient { request.httpMethod = "GET" request.allHTTPHeaderFields = try await buildHeaders(requireAuth: requireAuth) - return try await performRequest(request) + return try await performRequest(request, retryOnUnauthorized: requireAuth) } func post( @@ -139,7 +139,7 @@ actor APIClient { request.allHTTPHeaderFields = try await buildHeaders(requireAuth: requireAuth) request.httpBody = try JSONEncoder().encode(body) - return try await performRequest(request) + return try await performRequest(request, retryOnUnauthorized: requireAuth) } func post( @@ -153,7 +153,7 @@ actor APIClient { request.httpMethod = "POST" request.allHTTPHeaderFields = try await buildHeaders(requireAuth: requireAuth) - return try await performRequest(request) + return try await performRequest(request, retryOnUnauthorized: requireAuth) } func put( @@ -169,7 +169,7 @@ actor APIClient { request.allHTTPHeaderFields = try await buildHeaders(requireAuth: requireAuth) request.httpBody = try JSONEncoder().encode(body) - return try await performRequest(request) + return try await 
performRequest(request, retryOnUnauthorized: requireAuth) } func delete( @@ -216,7 +216,9 @@ actor APIClient { return response.settings } - func updateSelectedBackendSettings(_ values: [String: String]) async throws -> [LocalDaemonSetting] { + func updateSelectedBackendSettings(_ values: [String: String]) async throws + -> [LocalDaemonSetting] + { let target = selectedBackendTarget guard target.mode == .localDaemon else { return [] @@ -232,7 +234,9 @@ actor APIClient { // MARK: - Request Execution - private func performRequest(_ request: URLRequest) async throws -> T { + private func performRequest(_ request: URLRequest, retryOnUnauthorized: Bool) + async throws -> T + { let (data, response) = try await session.data(for: request) guard let httpResponse = response as? HTTPURLResponse else { @@ -241,6 +245,9 @@ actor APIClient { // Handle 401 - token might be expired if httpResponse.statusCode == 401 { + guard retryOnUnauthorized else { + throw APIError.unauthorized + } // Try to refresh token and retry once let authService = await MainActor.run { AuthService.shared } _ = try await authService.getIdToken(forceRefresh: true) @@ -467,6 +474,21 @@ extension APIClient { /// Updates the starred status of a conversation func setConversationStarred(id: String, starred: Bool) async throws { + struct StarredUpdate: Encodable { + let starred: Bool + } + + let target = mvpBackendTarget + if target.mode == .localDaemon { + let _: LocalConversationEnvelope = try await patch( + "v1/conversations/\(id)", + body: StarredUpdate(starred: starred), + requireAuth: false, + customBaseURL: target.baseURL + ) + return + } + let url = URL(string: baseURL + "v1/conversations/\(id)/starred?starred=\(starred)")! 
var request = URLRequest(url: url) request.httpMethod = "PATCH" @@ -742,6 +764,13 @@ extension APIClient { func mergeConversations(ids: [String], reprocess: Bool = true) async throws -> MergeConversationsResponse { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_merge", + reason: "Conversation merge is not implemented in local daemon mode yet." + ) + } + struct MergeRequest: Encodable { let conversationIds: [String] let reprocess: Bool @@ -760,6 +789,13 @@ extension APIClient { /// Gets all folders for the user func getFolders() async throws -> [Folder] { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_folders", + reason: "Conversation folders are not implemented in local daemon mode yet." + ) + } + return try await get("v1/folders") } @@ -767,6 +803,13 @@ extension APIClient { func createFolder(name: String, description: String? = nil, color: String? = nil) async throws -> Folder { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_folders", + reason: "Conversation folders are not implemented in local daemon mode yet." + ) + } + let body = CreateFolderRequest(name: name, description: description, color: color) return try await post("v1/folders", body: body) } @@ -776,12 +819,26 @@ extension APIClient { id: String, name: String? = nil, description: String? = nil, color: String? = nil, order: Int? = nil ) async throws -> Folder { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_folders", + reason: "Conversation folders are not implemented in local daemon mode yet." + ) + } + let body = UpdateFolderRequest(name: name, description: description, color: color, order: order) return try await patch("v1/folders/\(id)", body: body) } /// Deletes a folder func deleteFolder(id: String, moveToFolderId: String? 
= nil) async throws { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_folders", + reason: "Conversation folders are not implemented in local daemon mode yet." + ) + } + var endpoint = "v1/folders/\(id)" if let moveToId = moveToFolderId { endpoint += "?move_to_folder_id=\(moveToId)" @@ -791,6 +848,13 @@ extension APIClient { /// Moves a conversation to a folder func moveConversationToFolder(conversationId: String, folderId: String?) async throws { + if mvpBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "conversation_folders", + reason: "Conversation folders are not implemented in local daemon mode yet." + ) + } + let body = MoveToFolderRequest(folderId: folderId) let url = URL(string: baseURL + "v1/conversations/\(conversationId)/folder")! var request = URLRequest(url: url) @@ -1292,9 +1356,10 @@ private struct LocalConversation: Decodable { let endedAt: Date? let createdAt: Date let deletedAt: Date? + let starred: Bool enum CodingKeys: String, CodingKey { - case id, title, overview, status + case id, title, overview, status, starred case sessionId = "session_id" case startedAt = "started_at" case endedAt = "ended_at" @@ -1326,7 +1391,7 @@ private struct LocalConversation: Decodable { discarded: false, deleted: deletedAt != nil, isLocked: false, - starred: false, + starred: starred, folderId: nil, inputDeviceName: nil ) @@ -1977,8 +2042,9 @@ struct MemoryBatchItem: Encodable { let tags: [String] let headline: String? - init(content: String, visibility: String = "private", tags: [String] = [], headline: String? = nil) - { + init( + content: String, visibility: String = "private", tags: [String] = [], headline: String? 
= nil + ) { self.content = content self.visibility = visibility self.tags = tags @@ -4246,11 +4312,11 @@ struct NotificationSettingsResponse: Codable { } enum SubscriptionPlanType: String, Codable { - case basic // display "Free" + case basic // display "Free" case unlimited // legacy — display "Unlimited (legacy)" case architect // display "Architect" ($400/mo, cost_usd quota) - case pro // backward compat: old Firestore docs may still say "pro" - case `operator` // new — display "Operator" + case pro // backward compat: old Firestore docs may still say "pro" + case `operator` // new — display "Operator" } enum SubscriptionStatusType: String, Codable { @@ -4315,7 +4381,10 @@ struct SubscriptionPlanOption: Codable, Identifiable { let features: [String] let prices: [SubscriptionPriceOption] - init(id: String, title: String, subtitle: String? = nil, description: String? = nil, eyebrow: String? = nil, features: [String] = [], prices: [SubscriptionPriceOption] = []) { + init( + id: String, title: String, subtitle: String? = nil, description: String? = nil, + eyebrow: String? = nil, features: [String] = [], prices: [SubscriptionPriceOption] = [] + ) { self.id = id self.title = title self.subtitle = subtitle @@ -4732,7 +4801,8 @@ extension APIClient { var request = URLRequest(url: url) request.httpMethod = "POST" request.allHTTPHeaderFields = try await buildHeaders(requireAuth: true) - request.setValue("multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") + request.setValue( + "multipart/form-data; boundary=\(boundary)", forHTTPHeaderField: "Content-Type") var body = Data() let lineBreak = "\r\n" @@ -5243,14 +5313,14 @@ extension APIClient { /// Current-month chat usage + the plan's cap. Backed by Python backend /// endpoint `/v1/users/me/usage-quota` which reads `users/{uid}/llm_usage/*`. 
struct ChatUsageQuota: Decodable { - let plan: String // display name: "Free" | "Plus" | "Pro" - let planType: String // internal id: "basic" | "unlimited" | "architect" - let unit: String // "questions" | "cost_usd" + let plan: String // display name: "Free" | "Plus" | "Pro" + let planType: String // internal id: "basic" | "unlimited" | "architect" + let unit: String // "questions" | "cost_usd" let used: Double - let limit: Double? // nil means unlimited + let limit: Double? // nil means unlimited let percent: Double let allowed: Bool - let resetAt: Int? // unix seconds — start of next UTC month + let resetAt: Int? // unix seconds — start of next UTC month enum CodingKeys: String, CodingKey { case plan @@ -5265,7 +5335,8 @@ extension APIClient { } func fetchChatUsageQuota() async -> ChatUsageQuota? { - guard DesktopBackendEnvironment.isCapability(.payments, availableIn: selectedBackendTarget.mode) else { + guard DesktopBackendEnvironment.isCapability(.payments, availableIn: selectedBackendTarget.mode) + else { log("APIClient: Chat quota disabled in local daemon mode") return nil } diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index b63c8420040..261bcff9074 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -15,13 +15,13 @@ struct SegmentTranslation: Identifiable { struct SpeakerSegment: Identifiable { /// Stable identity — uses backend segment ID when available, otherwise speaker + start time var id: String { segmentId ?? "\(speaker)-\(start)" } - var segmentId: String? // Backend-assigned UUID + var segmentId: String? // Backend-assigned UUID var speaker: Int var text: String var start: Double var end: Double var isUser: Bool = false - var personId: String? // Backend-assigned person ID from speaker identification + var personId: String? 
// Backend-assigned person ID from speaker identification var translations: [SegmentTranslation] = [] } @@ -813,73 +813,73 @@ class AppState: ObservableObject { DispatchQueue.main.async { [weak self] in guard let self else { return } UNUserNotificationCenter.current().getNotificationSettings { settings in - DispatchQueue.main.async { - let isNowGranted = settings.authorizationStatus == .authorized - self.hasNotificationPermission = isNowGranted - self.notificationAlertStyle = settings.alertStyle - - // Log the current notification settings - let authStatus = - switch settings.authorizationStatus { - case .notDetermined: "notDetermined" - case .denied: "denied" - case .authorized: "authorized" - case .provisional: "provisional" - case .ephemeral: "ephemeral" - @unknown default: "unknown" - } - let alertStyleName = - switch settings.alertStyle { - case .none: "NONE (no banners)" - case .banner: "BANNER" - case .alert: "ALERT" - @unknown default: "unknown" - } - log( - "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" - ) - - // Track notification settings in analytics only when they change - let soundEnabled = settings.soundSetting == .enabled - let badgeEnabled = settings.badgeSetting == .enabled - let settingsChanged = - authStatus != self.lastNotificationAuthStatus - || alertStyleName != self.lastNotificationAlertStyle - || soundEnabled != self.lastNotificationSoundEnabled - || badgeEnabled != self.lastNotificationBadgeEnabled - - if settingsChanged { - AnalyticsManager.shared.notificationSettingsChecked( - authStatus: authStatus, - alertStyle: alertStyleName, - soundEnabled: soundEnabled, - badgeEnabled: badgeEnabled, - bannersDisabled: settings.alertStyle == .none + DispatchQueue.main.async { + let isNowGranted = settings.authorizationStatus == .authorized + self.hasNotificationPermission = isNowGranted + self.notificationAlertStyle = settings.alertStyle + + 
// Log the current notification settings + let authStatus = + switch settings.authorizationStatus { + case .notDetermined: "notDetermined" + case .denied: "denied" + case .authorized: "authorized" + case .provisional: "provisional" + case .ephemeral: "ephemeral" + @unknown default: "unknown" + } + let alertStyleName = + switch settings.alertStyle { + case .none: "NONE (no banners)" + case .banner: "BANNER" + case .alert: "ALERT" + @unknown default: "unknown" + } + log( + "Notification settings: auth=\(authStatus), alertStyle=\(alertStyleName), sound=\(settings.soundSetting.rawValue), badge=\(settings.badgeSetting.rawValue)" ) - // Detect regression: was authorized, now reverted to notDetermined - // This happens on macOS 26+ where the OS silently revokes notification permission - if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { - log( - "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" - ) - AnalyticsManager.shared.notificationRepairTriggered( - reason: "auth_regression", - previousStatus: "authorized", - currentStatus: "notDetermined" + // Track notification settings in analytics only when they change + let soundEnabled = settings.soundSetting == .enabled + let badgeEnabled = settings.badgeSetting == .enabled + let settingsChanged = + authStatus != self.lastNotificationAuthStatus + || alertStyleName != self.lastNotificationAlertStyle + || soundEnabled != self.lastNotificationSoundEnabled + || badgeEnabled != self.lastNotificationBadgeEnabled + + if settingsChanged { + AnalyticsManager.shared.notificationSettingsChecked( + authStatus: authStatus, + alertStyle: alertStyleName, + soundEnabled: soundEnabled, + badgeEnabled: badgeEnabled, + bannersDisabled: settings.alertStyle == .none ) - self.repairNotificationRegistrationAndRetry() + + // Detect regression: was authorized, now reverted to notDetermined + // This happens on macOS 26+ where the OS silently revokes notification permission 
+ if self.lastNotificationAuthStatus == "authorized" && authStatus == "notDetermined" { + log( + "Notification permission REGRESSED from authorized to notDetermined — triggering auto-repair" + ) + AnalyticsManager.shared.notificationRepairTriggered( + reason: "auth_regression", + previousStatus: "authorized", + currentStatus: "notDetermined" + ) + self.repairNotificationRegistrationAndRetry() + } + + // Update last known state + self.lastNotificationAuthStatus = authStatus + self.lastNotificationAlertStyle = alertStyleName + self.lastNotificationSoundEnabled = soundEnabled + self.lastNotificationBadgeEnabled = badgeEnabled } - // Update last known state - self.lastNotificationAuthStatus = authStatus - self.lastNotificationAlertStyle = alertStyleName - self.lastNotificationSoundEnabled = soundEnabled - self.lastNotificationBadgeEnabled = badgeEnabled } - } - } } // end DispatchQueue.main.async } @@ -1519,7 +1519,9 @@ class AppState: ObservableObject { silentMicFallbackInProgress = true guard let builtInID = AudioCaptureService.findBuiltInMicDeviceID() else { - log("Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is") + log( + "Transcription: silent-mic detected but no built-in microphone available — leaving capture as-is" + ) silentMicFallbackInProgress = false return } @@ -1645,7 +1647,9 @@ class AppState: ObservableObject { // finalize the NEW conversation instead of the one we just stopped. // The retry service will reconcile the old session by timestamp matching. guard self.recordingGeneration == generationAtStop else { - log("Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? "nil")") + log( + "Transcription: New recording started during delay, skipping force-process for session \(capturedSessionId.map(String.init) ?? 
"nil")" + ) return } @@ -1653,15 +1657,20 @@ class AppState: ObservableObject { if let conversation = try await APIClient.shared.forceProcessConversation() { // Validate the returned conversation matches the session we just stopped if let sessionId = capturedSessionId, let startTime = capturedStartTime, - let convStarted = conversation.startedAt, - abs(convStarted.timeIntervalSince(startTime)) < 10, - conversation.source == .desktop { + let convStarted = conversation.startedAt, + abs(convStarted.timeIntervalSince(startTime)) < 10, + conversation.source == .desktop + { try? await TranscriptionStorage.shared.markSessionCompleted( id: sessionId, backendId: conversation.id) - log("Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed") + log( + "Transcription: Force-processed conversation \(conversation.id), session \(sessionId) completed" + ) } else if let sessionId = capturedSessionId, let startTime = capturedStartTime { // Force-process returned a different conversation — fall back to reconciliation - log("Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp") + log( + "Transcription: Force-processed conversation \(conversation.id) does not match session \(sessionId), reconciling by timestamp" + ) await reconcileSession(sessionId: sessionId, startTime: startTime) } } else { @@ -1700,7 +1709,9 @@ class AppState: ObservableObject { id: sessionId, backendId: match.id) log("Transcription: Reconciled session \(sessionId) → backend conversation \(match.id)") } else { - log("Transcription: No matching backend conversation found for session \(sessionId), leaving for retry") + log( + "Transcription: No matching backend conversation found for session \(sessionId), leaving for retry" + ) } } catch { logError("Transcription: Reconciliation failed for session \(sessionId)", error: error) @@ -1715,7 +1726,9 @@ class AppState: ObservableObject { return .discarded } - 
log("Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing") + log( + "Transcription: Finishing conversation — disconnecting WebSocket to trigger backend processing" + ) // Capture state before rotation — memory_created event for this conversation // may arrive on the new WebSocket after currentSessionId and recordingStartTime have changed. @@ -2194,6 +2207,12 @@ class AppState: ObservableObject { func loadFolders() async { guard !isLoadingFolders else { return } + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + folders = [] + selectedFolderId = nil + return + } + isLoadingFolders = true do { @@ -2330,7 +2349,7 @@ class AppState: ObservableObject { let idSet = Set(segmentIds) if let idx = conversations.firstIndex(where: { $0.id == conversationId }) { for segIdx in conversations[idx].transcriptSegments.indices - where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { + where idSet.contains(conversations[idx].transcriptSegments[segIdx].id) { let old = conversations[idx].transcriptSegments[segIdx] conversations[idx].transcriptSegments[segIdx] = TranscriptSegment( id: old.id, @@ -2600,7 +2619,9 @@ class AppState: ObservableObject { // Always persist to SQLite — even if the segment was trimmed from // the in-memory window, the event payload has all fields needed if let sessionId = currentSessionId { - let mapped = newTranslations.map { TranscriptTranslation(lang: $0.lang, text: $0.text) } + let mapped = newTranslations.map { + TranscriptTranslation(lang: $0.lang, text: $0.text) + } var translationsJson: String? if let jsonData = try? 
JSONEncoder().encode(mapped) { translationsJson = String(data: jsonData, encoding: .utf8) diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift index 3f25b8e9990..5e776ad46c7 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift @@ -61,6 +61,10 @@ struct ConversationsPage: View { @State private var isMerging: Bool = false @State private var mergeError: String? = nil + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + var body: some View { Group { if let selected = selectedConversation { @@ -127,7 +131,9 @@ struct ConversationsPage: View { } } } - .onReceive(NotificationCenter.default.publisher(for: .desktopAutomationOpenConversationRequested)) { + .onReceive( + NotificationCenter.default.publisher(for: .desktopAutomationOpenConversationRequested) + ) { notification in handleAutomationOpenConversation(notification) } @@ -177,7 +183,8 @@ struct ConversationsPage: View { Task { await appState.refreshConversations() await MainActor.run { - guard let conversation = appState.conversations.first(where: { $0.id == conversationId }) else { + guard let conversation = appState.conversations.first(where: { $0.id == conversationId }) + else { log("Desktop automation: conversation \(conversationId) not found") return } @@ -283,14 +290,16 @@ struct ConversationsPage: View { .padding(.vertical, 12) // Folder tabs strip - FolderTabsStrip( - appState: appState, - onCreateFolder: { showCreateFolderSheet = true }, - onEditFolder: { folder in editingFolder = folder }, - onDeleteFolder: { folder in deletingFolder = folder } - ) - .padding(.horizontal, 24) - .padding(.bottom, 12) + if !isLocalDaemonMode { + FolderTabsStrip( + appState: appState, + onCreateFolder: { showCreateFolderSheet = true }, + onEditFolder: { folder in editingFolder = 
folder }, + onDeleteFolder: { folder in deletingFolder = folder } + ) + .padding(.horizontal, 24) + .padding(.bottom, 12) + } // List - show search results or regular conversations if !searchQuery.isEmpty { diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 64f50e0d267..d28c1f00000 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -1,794 +1,979 @@ import XCTest + @testable import Omi_Computer // MARK: - Request-capturing protocol for routing verification /// Captured request info: URL + HTTP method. private struct CapturedRequest { - let url: URL - let method: String - let headers: [String: String] - let body: Data? + let url: URL + let method: String + let headers: [String: String] + let body: Data? } /// Intercepts HTTP requests, records their URL and method, then returns 403 /// so APIClient throws .httpError (not 401, which triggers AuthService refresh). private final class URLCapture: URLProtocol, @unchecked Sendable { - private static let lock = NSLock() - private static var _requests: [CapturedRequest] = [] - - static var capturedRequests: [CapturedRequest] { - lock.lock() - defer { lock.unlock() } - return _requests - } - - static func reset() { - lock.lock() - _requests.removeAll() - lock.unlock() - } - - private static func record(_ req: CapturedRequest) { - lock.lock() - _requests.append(req) - lock.unlock() - } - - private static func bodyData(from request: URLRequest) -> Data? 
{ - if let body = request.httpBody { - return body - } - - guard let stream = request.httpBodyStream else { - return nil - } - - stream.open() - defer { stream.close() } - - var data = Data() - let bufferSize = 4096 - let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) - defer { buffer.deallocate() } - - while stream.hasBytesAvailable { - let readCount = stream.read(buffer, maxLength: bufferSize) - if readCount > 0 { - data.append(buffer, count: readCount) - } else if readCount < 0 { - return nil - } else { - break - } - } - - return data.isEmpty ? nil : data - } - - override class func canInit(with request: URLRequest) -> Bool { true } - override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } - - override func startLoading() { - if let url = request.url { - URLCapture.record(CapturedRequest( - url: url, - method: request.httpMethod ?? "GET", - headers: request.allHTTPHeaderFields ?? [:], - body: Self.bodyData(from: request) - )) - } - let response = HTTPURLResponse(url: request.url!, statusCode: 403, httpVersion: nil, headerFields: nil)! 
- client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) - client?.urlProtocol(self, didLoad: Data("{\"detail\":\"test\"}".utf8)) - client?.urlProtocolDidFinishLoading(self) + private static let lock = NSLock() + private static var _requests: [CapturedRequest] = [] + private static var _statusCode = 403 + + static var capturedRequests: [CapturedRequest] { + lock.lock() + defer { lock.unlock() } + return _requests + } + + static var statusCode: Int { + lock.lock() + defer { lock.unlock() } + return _statusCode + } + + static func reset() { + lock.lock() + _requests.removeAll() + _statusCode = 403 + lock.unlock() + } + + static func setStatusCode(_ statusCode: Int) { + lock.lock() + _statusCode = statusCode + lock.unlock() + } + + private static func record(_ req: CapturedRequest) { + lock.lock() + _requests.append(req) + lock.unlock() + } + + private static func bodyData(from request: URLRequest) -> Data? { + if let body = request.httpBody { + return body + } + + guard let stream = request.httpBodyStream else { + return nil + } + + stream.open() + defer { stream.close() } + + var data = Data() + let bufferSize = 4096 + let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) + defer { buffer.deallocate() } + + while stream.hasBytesAvailable { + let readCount = stream.read(buffer, maxLength: bufferSize) + if readCount > 0 { + data.append(buffer, count: readCount) + } else if readCount < 0 { + return nil + } else { + break + } + } + + return data.isEmpty ? nil : data + } + + override class func canInit(with request: URLRequest) -> Bool { true } + override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } + + override func startLoading() { + if let url = request.url { + URLCapture.record( + CapturedRequest( + url: url, + method: request.httpMethod ?? "GET", + headers: request.allHTTPHeaderFields ?? 
[:], + body: Self.bodyData(from: request) + )) } + let response = HTTPURLResponse( + url: request.url!, statusCode: Self.statusCode, httpVersion: nil, headerFields: nil)! + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: Data("{\"detail\":\"test\"}".utf8)) + client?.urlProtocolDidFinishLoading(self) + } - override func stopLoading() {} + override func stopLoading() {} } // MARK: - Assertion helpers private func assertRoutes( - _ reqs: [CapturedRequest], - host: String, - port: Int, - pathContains: String, - method: String, - label: String, - file: StaticString = #filePath, - line: UInt = #line + _ reqs: [CapturedRequest], + host: String, + port: Int, + pathContains: String, + method: String, + label: String, + file: StaticString = #filePath, + line: UInt = #line ) { - XCTAssertEqual(reqs.count, 1, "\(label): expected 1 request, got \(reqs.count)", file: file, line: line) - guard let req = reqs.first else { return } - XCTAssertEqual(req.url.host, host, "\(label): wrong host", file: file, line: line) - XCTAssertEqual(req.url.port, port, "\(label): wrong port", file: file, line: line) - XCTAssertTrue(req.url.absoluteString.contains(pathContains), "\(label): path should contain '\(pathContains)', got \(req.url.absoluteString)", file: file, line: line) - XCTAssertEqual(req.method, method, "\(label): wrong HTTP method", file: file, line: line) + XCTAssertEqual( + reqs.count, 1, "\(label): expected 1 request, got \(reqs.count)", file: file, line: line) + guard let req = reqs.first else { return } + XCTAssertEqual(req.url.host, host, "\(label): wrong host", file: file, line: line) + XCTAssertEqual(req.url.port, port, "\(label): wrong port", file: file, line: line) + XCTAssertTrue( + req.url.absoluteString.contains(pathContains), + "\(label): path should contain '\(pathContains)', got \(req.url.absoluteString)", file: file, + line: line) + XCTAssertEqual(req.method, method, "\(label): wrong HTTP method", 
file: file, line: line) } private func assertUnavailable( - _ error: Error?, - capability: DesktopBackendEnvironment.Capability, - file: StaticString = #filePath, - line: UInt = #line + _ error: Error?, + capability: DesktopBackendEnvironment.Capability, + file: StaticString = #filePath, + line: UInt = #line ) { - guard let error, case APIError.featureUnavailable(let feature, _) = error else { - XCTFail("expected featureUnavailable for \(capability.rawValue), got \(String(describing: error))", file: file, line: line) - return - } - XCTAssertEqual(feature, capability.rawValue, file: file, line: line) + guard let error, case APIError.featureUnavailable(let feature, _) = error else { + XCTFail( + "expected featureUnavailable for \(capability.rawValue), got \(String(describing: error))", + file: file, line: line) + return + } + XCTAssertEqual(feature, capability.rawValue, file: file, line: line) } // MARK: - Tests final class APIClientRoutingTests: XCTestCase { - // MARK: - URL property tests - - func testBaseURLDefaultsToPythonBackend() async { - unsetenv("OMI_PYTHON_API_URL") - let client = APIClient() - let url = await client.baseURL - XCTAssertEqual(url, "https://api.omi.me/") - } - - func testBetaProductionBundleUsesDevelopmentPythonBackend() { - let url = DesktopBackendEnvironment.pythonBaseURL( - useDevelopmentBackends: true, - environmentValue: "https://api.omi.me" - ) - XCTAssertEqual(url, "https://api.omiapi.com/") - } - - func testStableProductionBundleKeepsProductionPythonBackend() { - let url = DesktopBackendEnvironment.pythonBaseURL( - useDevelopmentBackends: false, - environmentValue: "https://api.omi.me" - ) - XCTAssertEqual(url, "https://api.omi.me/") - } - - func testBetaProductionBundleUsesDevelopmentRustBackend() { - let url = DesktopBackendEnvironment.rustBackendURL( - useDevelopmentBackends: true, - environmentValue: "https://desktop-backend-hhibjajaja-uc.a.run.app", - launchEnvironmentValue: nil - ) - XCTAssertEqual(url, 
"https://desktop-backend-dt5lrfkkoa-uc.a.run.app/") - } - - func testStableProductionBundleKeepsConfiguredRustBackend() { - let url = DesktopBackendEnvironment.rustBackendURL( - useDevelopmentBackends: false, - environmentValue: "https://desktop-backend-hhibjajaja-uc.a.run.app", - launchEnvironmentValue: nil - ) - XCTAssertEqual(url, "https://desktop-backend-hhibjajaja-uc.a.run.app/") - } - - func testBetaProductionBundleRoutesToDevelopmentBackends() { - XCTAssertTrue(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.computer-macos", - updateChannel: "beta" - )) - // "staging" is normalized to "beta" — same routing. - XCTAssertTrue(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.computer-macos", - updateChannel: "staging" - )) - } - - func testStableProductionBundleKeepsProductionBackends() { - XCTAssertFalse(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.computer-macos", - updateChannel: "stable" - )) - } - - func testNonProductionBundleSkipsAutomaticBetaRouting() { - // Dev bundle and named test bundles never trigger beta-to-dev routing - // automatically. They must opt in via OMI_FORCE_DEV_BACKENDS or env URLs. 
- XCTAssertFalse(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.desktop-dev", - updateChannel: "beta" - )) - XCTAssertFalse(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.omi-beta-dev-test", - updateChannel: "beta" - )) - } - - func testForceOverrideEnablesDevelopmentBackendsForAnyBundle() { - XCTAssertTrue(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.desktop-dev", - updateChannel: "stable", - forceOverride: "1" - )) - XCTAssertTrue(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.omi-beta-dev-test", - updateChannel: "stable", - forceOverride: "true" - )) - XCTAssertFalse(DesktopBackendEnvironment.shouldUseDevelopmentBackends( - bundleIdentifier: "com.omi.computer-macos", - updateChannel: "stable", - forceOverride: "0" - )) - } - - func testBaseURLReadsFromPythonEnvVar() async { - setenv("OMI_PYTHON_API_URL", "http://localhost:8080", 1) - defer { unsetenv("OMI_PYTHON_API_URL") } - let client = APIClient() - let url = await client.baseURL - XCTAssertEqual(url, "http://localhost:8080/") - } - - func testBaseURLAddsTrailingSlash() async { - setenv("OMI_PYTHON_API_URL", "http://localhost:8080", 1) - defer { unsetenv("OMI_PYTHON_API_URL") } - let client = APIClient() - let url = await client.baseURL - XCTAssertTrue(url.hasSuffix("/")) - } - - func testBaseURLPreservesExistingTrailingSlash() async { - setenv("OMI_PYTHON_API_URL", "http://localhost:8080/", 1) - defer { unsetenv("OMI_PYTHON_API_URL") } - let client = APIClient() - let url = await client.baseURL - XCTAssertEqual(url, "http://localhost:8080/") - } - - func testRustBackendURLReadsFromApiUrlEnvVar() async { - setenv("OMI_DESKTOP_API_URL", "http://localhost:8787", 1) - defer { unsetenv("OMI_DESKTOP_API_URL") } - let client = APIClient() - let url = await client.rustBackendURL - XCTAssertEqual(url, "http://localhost:8787/") - } - - func 
testRustBackendURLReturnsEmptyWhenNotSet() async { - unsetenv("OMI_DESKTOP_API_URL") - let client = APIClient() - let url = await client.rustBackendURL - XCTAssertEqual(url, "") - } - - func testSelectedBackendTargetDefaultsToCloudPython() { - let target = DesktopBackendEnvironment.selectedBackendTarget( - modeValue: nil, - pythonEnvironmentValue: "https://api.example.test", - localDaemonEnvironmentValue: nil - ) - XCTAssertEqual(target.mode, .cloud) - XCTAssertEqual(target.baseURL, "https://api.example.test/") - XCTAssertTrue(target.requiresAuth) - } - - func testSelectedBackendTargetSupportsLocalDaemonDefault() { - let target = DesktopBackendEnvironment.selectedBackendTarget( - modeValue: "local", - pythonEnvironmentValue: "https://api.example.test", - localDaemonEnvironmentValue: nil - ) - XCTAssertEqual(target.mode, .localDaemon) - XCTAssertEqual(target.baseURL, "http://127.0.0.1:8765/") - XCTAssertFalse(target.requiresAuth) - } - - func testSelectedBackendTargetSupportsCustomRemote() { - let target = DesktopBackendEnvironment.selectedBackendTarget( - modeValue: "custom", - pythonEnvironmentValue: "http://custom-backend:7777", - localDaemonEnvironmentValue: "http://127.0.0.1:8765" - ) - XCTAssertEqual(target.mode, .customRemote) - XCTAssertEqual(target.baseURL, "http://custom-backend:7777/") - XCTAssertTrue(target.requiresAuth) - } - - func testLocalDaemonCapabilityMatrixDisablesCloudBoundFeatures() { - let capabilities = Dictionary( - uniqueKeysWithValues: DesktopBackendEnvironment.capabilities(for: .localDaemon) - .map { ($0.capability, $0) } - ) - - XCTAssertEqual(capabilities[.localConversationData]?.available, true) - XCTAssertEqual(capabilities[.firebaseSignIn]?.available, true) - XCTAssertEqual(capabilities[.managedAgentVM]?.available, false) - XCTAssertEqual(capabilities[.omiBackendProviderProxy]?.available, false) - XCTAssertEqual(capabilities[.publicSharing]?.available, false) - XCTAssertEqual(capabilities[.cloudSync]?.available, false) - 
XCTAssertEqual(capabilities[.payments]?.available, false) - XCTAssertEqual(capabilities[.crispSupport]?.available, false) - XCTAssertEqual(capabilities[.hostedTranscription]?.available, false) - XCTAssertNotNil(capabilities[.managedAgentVM]?.reason) - } - - func testCloudCapabilityMatrixAllowsCloudBoundFeatures() { - for state in DesktopBackendEnvironment.capabilities(for: .cloud) { - XCTAssertTrue(state.available, "\(state.capability.rawValue) should be available in cloud mode") - XCTAssertNil(state.reason) - } - } - - func testBaseURLAndRustBackendURLAreIndependent() async { - setenv("OMI_PYTHON_API_URL", "http://python:8080", 1) - setenv("OMI_DESKTOP_API_URL", "http://rust:8787", 1) - defer { unsetenv("OMI_PYTHON_API_URL"); unsetenv("OMI_DESKTOP_API_URL") } - - let client = APIClient() - let base = await client.baseURL - let rust = await client.rustBackendURL - XCTAssertEqual(base, "http://python:8080/") - XCTAssertEqual(rust, "http://rust:8787/") - XCTAssertNotEqual(base, rust) - } - - // MARK: - Routing behavior: Python-routed endpoints (default baseURL) - - private func makeTestClient() async -> APIClient { - let config = URLSessionConfiguration.ephemeral - config.protocolClasses = [URLCapture.self] - let session = URLSession(configuration: config) - let client = APIClient(session: session) - await client.setTestAuthHeader("Bearer test-token") - return client - } - - override func setUp() { - super.setUp() - URLCapture.reset() - setenv("OMI_PYTHON_API_URL", "http://python-test:9001", 1) - setenv("OMI_DESKTOP_API_URL", "http://rust-test:9002", 1) - unsetenv("OMI_DESKTOP_BACKEND_MODE") - unsetenv("OMI_LOCAL_DAEMON_URL") - } - - override func tearDown() { - unsetenv("OMI_PYTHON_API_URL") - unsetenv("OMI_DESKTOP_API_URL") - unsetenv("OMI_DESKTOP_BACKEND_MODE") - unsetenv("OMI_LOCAL_DAEMON_URL") - URLCapture.reset() - super.tearDown() - } - - // -- Conversations (GET, DELETE → Python) -- - - func testGetConversationRoutesToPython() async { - let client = await 
makeTestClient() - _ = try? await client.getConversation(id: "test-123") as ServerConversation - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/test-123", method: "GET", - label: "getConversation") - } - - func testLocalModeGetConversationRoutesToLocalDaemonWithoutAuth() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) - let client = await makeTestClient() - _ = try? await client.getConversation(id: "local-123") as ServerConversation - - assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, - pathContains: "v1/conversations/local-123", method: "GET", - label: "local getConversation") - XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) - } - - func testDeleteConversationRoutesToPython() async { - let client = await makeTestClient() - try? await client.deleteConversation(id: "conv-456") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/conv-456", method: "DELETE", - label: "deleteConversation") - } - - func testLocalModeCreateConversationRoutesToLocalDaemonWithoutAuth() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) - let client = await makeTestClient() - _ = try? 
await client.createLocalDaemonConversation( - sessionId: "desktop-1", - title: "Local", - overview: "Local daemon", - startedAt: Date(timeIntervalSince1970: 0) - ) - - let requests = URLCapture.capturedRequests - assertRoutes(requests, host: "127.0.0.1", port: 8765, - pathContains: "v1/conversations", method: "POST", - label: "local createConversation") - XCTAssertNil(requests.first?.headers["Authorization"]) - } - - func testLocalModeHealthCheckRoutesToLocalDaemonWithoutAuth() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) - let client = await makeTestClient() - _ = try? await client.checkSelectedBackendHealth() - - assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, - pathContains: "health", method: "GET", - label: "local health") - XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) - } - - func testLocalModeSettingsRoutesToLocalDaemonWithoutAuth() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) - let client = await makeTestClient() - _ = try? await client.updateSelectedBackendSettings(["profile_name": "Local"]) - - assertRoutes(URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, - pathContains: "v1/settings", method: "PUT", - label: "local settings") - XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) - } - - func testLocalModeMVPConversationFlowsIgnoreInvalidCloudURLs() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_PYTHON_API_URL", "http://omi-cloud-invalid:9001", 1) - setenv("OMI_DESKTOP_API_URL", "http://omi-rust-invalid:9002", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:9876", 1) - let client = await makeTestClient() - - _ = try? await client.getConversations() - _ = try? await client.getConversation(id: "local-123") as ServerConversation - _ = try? await client.searchConversations(query: "offline") - try? 
await client.updateConversationTitle(id: "local-123", title: "Offline") - _ = try? await client.updateSelectedBackendSettings(["profile_name": "Offline"]) - - let requests = URLCapture.capturedRequests - XCTAssertEqual(requests.count, 5) - XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) - XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) - XCTAssertFalse(requests.contains { $0.url.host == "omi-cloud-invalid" || $0.url.host == "omi-rust-invalid" }) - } - - func testLocalModeCloudOnlyFeaturesFailBeforeNetworkRequests() async { - setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) - let client = await makeTestClient() - - do { - _ = try await client.provisionAgentVM() - XCTFail("expected managed agent VM to be unavailable") - } catch { - assertUnavailable(error, capability: .managedAgentVM) - } - - do { - _ = try await client.fetchApiKeys() - XCTFail("expected backend provider proxy to be unavailable") - } catch { - assertUnavailable(error, capability: .omiBackendProviderProxy) - } - - do { - _ = try await client.getUserSubscription() - XCTFail("expected payments to be unavailable") - } catch { - assertUnavailable(error, capability: .payments) - } - - do { - _ = try await client.shareChatMessages(messageIds: ["m1"]) - XCTFail("expected public sharing to be unavailable") - } catch { - assertUnavailable(error, capability: .publicSharing) - } - - do { - try await client.setPrivateCloudSync(enabled: true) - XCTFail("expected cloud sync to be unavailable") - } catch { - assertUnavailable(error, capability: .cloudSync) - } - - XCTAssertTrue(URLCapture.capturedRequests.isEmpty) - } - - // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- - - func testSetConversationStarredRoutesToPython() async { - let client = await makeTestClient() - try? 
await client.setConversationStarred(id: "c1", starred: true) - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/c1/starred", method: "PATCH", - label: "setConversationStarred") - } - - func testUpdateConversationTitleRoutesToPython() async { - let client = await makeTestClient() - try? await client.updateConversationTitle(id: "c2", title: "New") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/c2", method: "PATCH", - label: "updateConversationTitle") - } - - // -- Folders (GET → Python) -- - - func testGetFoldersRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getFolders() as [Folder] - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/folders", method: "GET", - label: "getFolders") - } - - // -- Memories (POST → Python) -- - - func testCreateMemoryRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.createMemory(content: "test memory") as CreateMemoryResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v3/memories", method: "POST", - label: "createMemory") - } - - // -- Goals: manual URL path (PATCH → Python) -- - - func testUpdateGoalProgressRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.updateGoalProgress(goalId: "g1", currentValue: 42.0) as Goal - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/goals/g1/progress", method: "PATCH", - label: "updateGoalProgress") - } - - // -- Apps (GET → Python) -- - - func testGetAppsRoutesToPython() async { - let client = await makeTestClient() - _ = try? 
await client.getApps() as [OmiApp] - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/apps", method: "GET", - label: "getApps") - } - - // -- Personas (GET → Python) -- - - func testGetPersonaRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getPersona() as Persona? - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/personas", method: "GET", - label: "getPersona") - } - - // -- User settings (GET → Python) -- - - func testGetDailySummarySettingsRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getDailySummarySettings() as DailySummarySettings - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/daily-summary-settings", method: "GET", - label: "getDailySummarySettings") - } - - // -- Subscription/payments (GET → Python, was explicit pythonBackendURL, now default) -- - - func testGetUserSubscriptionRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getUserSubscription() as UserSubscriptionResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/me/subscription", method: "GET", - label: "getUserSubscription") - } - - // MARK: - Routing behavior: Rust-routed endpoints (customBaseURL: rustBackendURL) - - // -- Config/API keys (GET → Rust) -- - - func testFetchApiKeysRoutesToRust() async { - let client = await makeTestClient() - _ = try? await client.fetchApiKeys() as APIClient.ApiKeysResponse - assertRoutes(URLCapture.capturedRequests, host: "rust-test", port: 9002, - pathContains: "v1/config/api-keys", method: "GET", - label: "fetchApiKeys") - } - - func testSynthesizeSpeechRoutesToRust() async { - let client = await makeTestClient() - _ = try? 
await client.synthesizeSpeech( - request: APIClient.TtsSynthesizeRequest( - text: "Hello", - voiceId: "onyx", - instructions: "Speak naturally" - ) - ) - - let requests = URLCapture.capturedRequests - assertRoutes(requests, host: "rust-test", port: 9002, - pathContains: "v1/tts/synthesize", method: "POST", - label: "synthesizeSpeech") - - let body = requests.first?.body.flatMap { try? JSONSerialization.jsonObject(with: $0) as? [String: Any] } - XCTAssertEqual(body?["text"] as? String, "Hello") - XCTAssertEqual(body?["voice_id"] as? String, "onyx") - XCTAssertEqual(body?["instructions"] as? String, "Speak naturally") - } - - // -- Assistant settings (GET → Python, migrated from Rust) -- - - func testGetAssistantSettingsRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getAssistantSettings() as AssistantSettingsResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/assistant-settings", method: "GET", - label: "getAssistantSettings") - } - - // -- Notification settings (GET → Python, migrated from Rust) -- - - func testGetNotificationSettingsRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getNotificationSettings() as NotificationSettingsResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/notification-settings", method: "GET", - label: "getNotificationSettings") - } - - // -- Staged tasks (GET, DELETE → Python, migrated from Rust) -- - - func testGetStagedTasksRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getStagedTasks() as ActionItemsListResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/staged-tasks", method: "GET", - label: "getStagedTasks") - } - - func testDeleteStagedTaskRoutesToPython() async { - let client = await makeTestClient() - try? 
await client.deleteStagedTask(id: "st-1") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/staged-tasks/st-1", method: "DELETE", - label: "deleteStagedTask") - } - - // -- Chat sessions (GET, POST, DELETE → Python, migrated from Rust) -- - - func testGetChatSessionsRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getChatSessions() as [ChatSession] - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/chat-sessions", method: "GET", - label: "getChatSessions") - } - - func testCreateChatSessionRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.createChatSession(title: "test") as ChatSession - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/chat-sessions", method: "POST", - label: "createChatSession") - } - - func testDeleteChatSessionRoutesToPython() async { - let client = await makeTestClient() - try? await client.deleteChatSession(sessionId: "sess-1") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/chat-sessions/sess-1", method: "DELETE", - label: "deleteChatSession") - } - - // -- Desktop messages (DELETE → Python, path changed to v2/desktop/messages) -- - - func testDeleteMessagesRoutesToPython() async { - let client = await makeTestClient() - _ = try? 
await client.deleteMessages() as MessageDeleteResponse - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/desktop/messages", method: "DELETE", - label: "deleteMessages") - } - - // -- LLM usage (GET → Python, migrated from Rust) -- - - func testFetchTotalOmiAICostRoutesToPython() async { - let client = await makeTestClient() - _ = await client.fetchTotalOmiAICost() - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/me/llm-usage/total", method: "GET", - label: "fetchTotalOmiAICost") - } - - // MARK: - Python-routed: remaining manual URL builders - - // -- setConversationVisibility: manual URL(string: baseURL + ...) PATCH → Python -- - - func testSetConversationVisibilityRoutesToPython() async { - let client = await makeTestClient() - try? await client.setConversationVisibility(id: "c3") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/c3/visibility", method: "PATCH", - label: "setConversationVisibility") + // MARK: - URL property tests + + func testBaseURLDefaultsToPythonBackend() async { + unsetenv("OMI_PYTHON_API_URL") + let client = APIClient() + let url = await client.baseURL + XCTAssertEqual(url, "https://api.omi.me/") + } + + func testBetaProductionBundleUsesDevelopmentPythonBackend() { + let url = DesktopBackendEnvironment.pythonBaseURL( + useDevelopmentBackends: true, + environmentValue: "https://api.omi.me" + ) + XCTAssertEqual(url, "https://api.omiapi.com/") + } + + func testStableProductionBundleKeepsProductionPythonBackend() { + let url = DesktopBackendEnvironment.pythonBaseURL( + useDevelopmentBackends: false, + environmentValue: "https://api.omi.me" + ) + XCTAssertEqual(url, "https://api.omi.me/") + } + + func testBetaProductionBundleUsesDevelopmentRustBackend() { + let url = DesktopBackendEnvironment.rustBackendURL( + useDevelopmentBackends: true, + environmentValue: 
"https://desktop-backend-hhibjajaja-uc.a.run.app", + launchEnvironmentValue: nil + ) + XCTAssertEqual(url, "https://desktop-backend-dt5lrfkkoa-uc.a.run.app/") + } + + func testStableProductionBundleKeepsConfiguredRustBackend() { + let url = DesktopBackendEnvironment.rustBackendURL( + useDevelopmentBackends: false, + environmentValue: "https://desktop-backend-hhibjajaja-uc.a.run.app", + launchEnvironmentValue: nil + ) + XCTAssertEqual(url, "https://desktop-backend-hhibjajaja-uc.a.run.app/") + } + + func testBetaProductionBundleRoutesToDevelopmentBackends() { + XCTAssertTrue( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.computer-macos", + updateChannel: "beta" + )) + // "staging" is normalized to "beta" — same routing. + XCTAssertTrue( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.computer-macos", + updateChannel: "staging" + )) + } + + func testStableProductionBundleKeepsProductionBackends() { + XCTAssertFalse( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.computer-macos", + updateChannel: "stable" + )) + } + + func testNonProductionBundleSkipsAutomaticBetaRouting() { + // Dev bundle and named test bundles never trigger beta-to-dev routing + // automatically. They must opt in via OMI_FORCE_DEV_BACKENDS or env URLs. 
+ XCTAssertFalse( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.desktop-dev", + updateChannel: "beta" + )) + XCTAssertFalse( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.omi-beta-dev-test", + updateChannel: "beta" + )) + } + + func testForceOverrideEnablesDevelopmentBackendsForAnyBundle() { + XCTAssertTrue( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.desktop-dev", + updateChannel: "stable", + forceOverride: "1" + )) + XCTAssertTrue( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.omi-beta-dev-test", + updateChannel: "stable", + forceOverride: "true" + )) + XCTAssertFalse( + DesktopBackendEnvironment.shouldUseDevelopmentBackends( + bundleIdentifier: "com.omi.computer-macos", + updateChannel: "stable", + forceOverride: "0" + )) + } + + func testBaseURLReadsFromPythonEnvVar() async { + setenv("OMI_PYTHON_API_URL", "http://localhost:8080", 1) + defer { unsetenv("OMI_PYTHON_API_URL") } + let client = APIClient() + let url = await client.baseURL + XCTAssertEqual(url, "http://localhost:8080/") + } + + func testBaseURLAddsTrailingSlash() async { + setenv("OMI_PYTHON_API_URL", "http://localhost:8080", 1) + defer { unsetenv("OMI_PYTHON_API_URL") } + let client = APIClient() + let url = await client.baseURL + XCTAssertTrue(url.hasSuffix("/")) + } + + func testBaseURLPreservesExistingTrailingSlash() async { + setenv("OMI_PYTHON_API_URL", "http://localhost:8080/", 1) + defer { unsetenv("OMI_PYTHON_API_URL") } + let client = APIClient() + let url = await client.baseURL + XCTAssertEqual(url, "http://localhost:8080/") + } + + func testRustBackendURLReadsFromApiUrlEnvVar() async { + setenv("OMI_DESKTOP_API_URL", "http://localhost:8787", 1) + defer { unsetenv("OMI_DESKTOP_API_URL") } + let client = APIClient() + let url = await client.rustBackendURL + XCTAssertEqual(url, "http://localhost:8787/") + } + + func 
testRustBackendURLReturnsEmptyWhenNotSet() async { + unsetenv("OMI_DESKTOP_API_URL") + let client = APIClient() + let url = await client.rustBackendURL + XCTAssertEqual(url, "") + } + + func testSelectedBackendTargetDefaultsToCloudPython() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: nil, + pythonEnvironmentValue: "https://api.example.test", + localDaemonEnvironmentValue: nil + ) + XCTAssertEqual(target.mode, .cloud) + XCTAssertEqual(target.baseURL, "https://api.example.test/") + XCTAssertTrue(target.requiresAuth) + } + + func testSelectedBackendTargetSupportsLocalDaemonDefault() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: "local", + pythonEnvironmentValue: "https://api.example.test", + localDaemonEnvironmentValue: nil + ) + XCTAssertEqual(target.mode, .localDaemon) + XCTAssertEqual(target.baseURL, "http://127.0.0.1:8765/") + XCTAssertFalse(target.requiresAuth) + } + + func testSelectedBackendTargetSupportsCustomRemote() { + let target = DesktopBackendEnvironment.selectedBackendTarget( + modeValue: "custom", + pythonEnvironmentValue: "http://custom-backend:7777", + localDaemonEnvironmentValue: "http://127.0.0.1:8765" + ) + XCTAssertEqual(target.mode, .customRemote) + XCTAssertEqual(target.baseURL, "http://custom-backend:7777/") + XCTAssertTrue(target.requiresAuth) + } + + func testLocalDaemonCapabilityMatrixDisablesCloudBoundFeatures() { + let capabilities = Dictionary( + uniqueKeysWithValues: DesktopBackendEnvironment.capabilities(for: .localDaemon) + .map { ($0.capability, $0) } + ) + + XCTAssertEqual(capabilities[.localConversationData]?.available, true) + XCTAssertEqual(capabilities[.firebaseSignIn]?.available, true) + XCTAssertEqual(capabilities[.managedAgentVM]?.available, false) + XCTAssertEqual(capabilities[.omiBackendProviderProxy]?.available, false) + XCTAssertEqual(capabilities[.publicSharing]?.available, false) + XCTAssertEqual(capabilities[.cloudSync]?.available, false) + 
XCTAssertEqual(capabilities[.payments]?.available, false) + XCTAssertEqual(capabilities[.crispSupport]?.available, false) + XCTAssertEqual(capabilities[.hostedTranscription]?.available, false) + XCTAssertNotNil(capabilities[.managedAgentVM]?.reason) + } + + func testCloudCapabilityMatrixAllowsCloudBoundFeatures() { + for state in DesktopBackendEnvironment.capabilities(for: .cloud) { + XCTAssertTrue( + state.available, "\(state.capability.rawValue) should be available in cloud mode") + XCTAssertNil(state.reason) + } + } + + func testBaseURLAndRustBackendURLAreIndependent() async { + setenv("OMI_PYTHON_API_URL", "http://python:8080", 1) + setenv("OMI_DESKTOP_API_URL", "http://rust:8787", 1) + defer { + unsetenv("OMI_PYTHON_API_URL") + unsetenv("OMI_DESKTOP_API_URL") + } + + let client = APIClient() + let base = await client.baseURL + let rust = await client.rustBackendURL + XCTAssertEqual(base, "http://python:8080/") + XCTAssertEqual(rust, "http://rust:8787/") + XCTAssertNotEqual(base, rust) + } + + // MARK: - Routing behavior: Python-routed endpoints (default baseURL) + + private func makeTestClient() async -> APIClient { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [URLCapture.self] + let session = URLSession(configuration: config) + let client = APIClient(session: session) + await client.setTestAuthHeader("Bearer test-token") + return client + } + + override func setUp() { + super.setUp() + URLCapture.reset() + setenv("OMI_PYTHON_API_URL", "http://python-test:9001", 1) + setenv("OMI_DESKTOP_API_URL", "http://rust-test:9002", 1) + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + + override func tearDown() { + unsetenv("OMI_PYTHON_API_URL") + unsetenv("OMI_DESKTOP_API_URL") + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + URLCapture.reset() + super.tearDown() + } + + // -- Conversations (GET, DELETE → Python) -- + + func testGetConversationRoutesToPython() async { + let client = 
await makeTestClient() + _ = try? await client.getConversation(id: "test-123") as ServerConversation + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/test-123", method: "GET", + label: "getConversation") + } + + func testLocalModeGetConversationRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.getConversation(id: "local-123") as ServerConversation + + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations/local-123", method: "GET", + label: "local getConversation") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testDeleteConversationRoutesToPython() async { + let client = await makeTestClient() + try? await client.deleteConversation(id: "conv-456") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/conv-456", method: "DELETE", + label: "deleteConversation") + } + + func testLocalModeCreateConversationRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? 
await client.createLocalDaemonConversation( + sessionId: "desktop-1", + title: "Local", + overview: "Local daemon", + startedAt: Date(timeIntervalSince1970: 0) + ) + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations", method: "POST", + label: "local createConversation") + XCTAssertNil(requests.first?.headers["Authorization"]) + } + + func testLocalModeHealthCheckRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.checkSelectedBackendHealth() + + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "health", method: "GET", + label: "local health") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testLocalModeSettingsRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.updateSelectedBackendSettings(["profile_name": "Local"]) + + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/settings", method: "PUT", + label: "local settings") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testLocalModeMVPConversationFlowsIgnoreInvalidCloudURLs() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "http://omi-cloud-invalid:9001", 1) + setenv("OMI_DESKTOP_API_URL", "http://omi-rust-invalid:9002", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:9876", 1) + let client = await makeTestClient() + + _ = try? await client.getConversations() + _ = try? await client.getConversation(id: "local-123") as ServerConversation + _ = try? 
await client.searchConversations(query: "offline") + try? await client.updateConversationTitle(id: "local-123", title: "Offline") + try? await client.setConversationStarred(id: "local-123", starred: true) + _ = try? await client.updateSelectedBackendSettings(["profile_name": "Offline"]) + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 6) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertFalse( + requests.contains { $0.url.host == "omi-cloud-invalid" || $0.url.host == "omi-rust-invalid" }) + } + + func testLocalModeSetConversationStarredRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + try? await client.setConversationStarred(id: "local-star", starred: true) + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations/local-star", method: "PATCH", + label: "local setConversationStarred") + XCTAssertNil(requests.first?.headers["Authorization"]) + + let body = requests.first?.body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? [String: Any] + } + XCTAssertEqual(body?["starred"] as? 
Bool, true) + } + + func testLocalModeMergeAndFolderActionsFailBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + do { + _ = try await client.mergeConversations(ids: ["c1", "c2"]) + XCTFail("expected merge to be unavailable") + } catch { + guard case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for merge, got \(error)") + return + } + XCTAssertEqual(feature, "conversation_merge") } - // -- moveConversationToFolder: manual URL PATCH → Python -- - - func testMoveConversationToFolderRoutesToPython() async { - let client = await makeTestClient() - try? await client.moveConversationToFolder(conversationId: "c4", folderId: "f1") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/c4/folder", method: "PATCH", - label: "moveConversationToFolder") + do { + _ = try await client.getFolders() + XCTFail("expected folders to be unavailable") + } catch { + guard case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for folders, got \(error)") + return + } + XCTAssertEqual(feature, "conversation_folders") } - // -- setRecordingPermission: manual URL POST → Python -- - - func testSetRecordingPermissionRoutesToPython() async { - let client = await makeTestClient() - try? 
await client.setRecordingPermission(enabled: true) - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/store-recording-permission", method: "POST", - label: "setRecordingPermission") + do { + _ = try await client.createFolder(name: "Work") + XCTFail("expected folder creation to be unavailable") + } catch { + guard case APIError.featureUnavailable = error else { + XCTFail("expected featureUnavailable for folder creation, got \(error)") + return + } } - // -- setPrivateCloudSync: manual URL POST → Python -- - - func testSetPrivateCloudSyncRoutesToPython() async { - let client = await makeTestClient() - try? await client.setPrivateCloudSync(enabled: false) - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/private-cloud-sync", method: "POST", - label: "setPrivateCloudSync") + do { + _ = try await client.updateFolder(id: "f1", name: "Renamed") + XCTFail("expected folder update to be unavailable") + } catch { + guard case APIError.featureUnavailable = error else { + XCTFail("expected featureUnavailable for folder update, got \(error)") + return + } } - // -- completeGoal: manual URL PATCH → Python -- - - func testCompleteGoalRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.completeGoal(id: "g2") as Goal - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/goals/g2", method: "PATCH", - label: "completeGoal") + do { + try await client.deleteFolder(id: "f1") + XCTFail("expected folder deletion to be unavailable") + } catch { + guard case APIError.featureUnavailable = error else { + XCTFail("expected featureUnavailable for folder deletion, got \(error)") + return + } } - // -- assignSegmentsBulk: manual URL PATCH → Python -- - - func testAssignSegmentsBulkRoutesToPython() async { - let client = await makeTestClient() - try? 
await client.assignSegmentsBulk(conversationId: "c5", segmentIds: ["s1"], isUser: true, personId: nil) - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/conversations/c5/segments/assign-bulk", method: "PATCH", - label: "assignSegmentsBulk") + do { + try await client.moveConversationToFolder(conversationId: "c1", folderId: "f1") + XCTFail("expected move-to-folder to be unavailable") + } catch { + guard case APIError.featureUnavailable = error else { + XCTFail("expected featureUnavailable for move-to-folder, got \(error)") + return + } } - // -- Chat AI endpoints (migrated from Rust to Python) -- - - func testGetInitialMessageRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.getInitialMessage(sessionId: "s1") - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/chat/initial-message", method: "POST", - label: "getInitialMessage") - } + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } - func testGenerateSessionTitleRoutesToPython() async { - let client = await makeTestClient() - _ = try? await client.generateSessionTitle(sessionId: "s1", messages: [("hi", "human")]) - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v2/chat/generate-title", method: "POST", - label: "generateSessionTitle") - } + func testLocalUnauthenticatedRequestDoesNotRefreshAuthOn401() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + URLCapture.setStatusCode(401) + let client = await makeTestClient() - func testGetChatMessageCountRoutesToPython() async { - let client = await makeTestClient() - _ = try? 
await client.getChatMessageCount() - assertRoutes(URLCapture.capturedRequests, host: "python-test", port: 9001, - pathContains: "v1/users/stats/chat-messages", method: "GET", - label: "getChatMessageCount") - } + do { + _ = try await client.getSelectedBackendSettings() + XCTFail("expected unauthorized") + } catch { + guard case APIError.unauthorized = error else { + XCTFail("expected unauthorized, got \(error)") + return + } + } + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 1) + XCTAssertNil(requests.first?.headers["Authorization"]) + } + + func testLocalModeCloudOnlyFeaturesFailBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + do { + _ = try await client.provisionAgentVM() + XCTFail("expected managed agent VM to be unavailable") + } catch { + assertUnavailable(error, capability: .managedAgentVM) + } + + do { + _ = try await client.fetchApiKeys() + XCTFail("expected backend provider proxy to be unavailable") + } catch { + assertUnavailable(error, capability: .omiBackendProviderProxy) + } + + do { + _ = try await client.getUserSubscription() + XCTFail("expected payments to be unavailable") + } catch { + assertUnavailable(error, capability: .payments) + } + + do { + _ = try await client.shareChatMessages(messageIds: ["m1"]) + XCTFail("expected public sharing to be unavailable") + } catch { + assertUnavailable(error, capability: .publicSharing) + } + + do { + try await client.setPrivateCloudSync(enabled: true) + XCTFail("expected cloud sync to be unavailable") + } catch { + assertUnavailable(error, capability: .cloudSync) + } + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + + // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- + + func testSetConversationStarredRoutesToPython() async { + let client = await makeTestClient() + try? 
await client.setConversationStarred(id: "c1", starred: true) + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/c1/starred", method: "PATCH", + label: "setConversationStarred") + } + + func testUpdateConversationTitleRoutesToPython() async { + let client = await makeTestClient() + try? await client.updateConversationTitle(id: "c2", title: "New") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/c2", method: "PATCH", + label: "updateConversationTitle") + } + + // -- Folders (GET → Python) -- + + func testGetFoldersRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getFolders() as [Folder] + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/folders", method: "GET", + label: "getFolders") + } + + // -- Memories (POST → Python) -- + + func testCreateMemoryRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.createMemory(content: "test memory") as CreateMemoryResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v3/memories", method: "POST", + label: "createMemory") + } + + // -- Goals: manual URL path (PATCH → Python) -- + + func testUpdateGoalProgressRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.updateGoalProgress(goalId: "g1", currentValue: 42.0) as Goal + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/goals/g1/progress", method: "PATCH", + label: "updateGoalProgress") + } + + // -- Apps (GET → Python) -- + + func testGetAppsRoutesToPython() async { + let client = await makeTestClient() + _ = try? 
await client.getApps() as [OmiApp] + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/apps", method: "GET", + label: "getApps") + } + + // -- Personas (GET → Python) -- + + func testGetPersonaRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getPersona() as Persona? + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/personas", method: "GET", + label: "getPersona") + } + + // -- User settings (GET → Python) -- + + func testGetDailySummarySettingsRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getDailySummarySettings() as DailySummarySettings + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/daily-summary-settings", method: "GET", + label: "getDailySummarySettings") + } + + // -- Subscription/payments (GET → Python, was explicit pythonBackendURL, now default) -- + + func testGetUserSubscriptionRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getUserSubscription() as UserSubscriptionResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/me/subscription", method: "GET", + label: "getUserSubscription") + } + + // MARK: - Routing behavior: Rust-routed endpoints (customBaseURL: rustBackendURL) + + // -- Config/API keys (GET → Rust) -- + + func testFetchApiKeysRoutesToRust() async { + let client = await makeTestClient() + _ = try? await client.fetchApiKeys() as APIClient.ApiKeysResponse + assertRoutes( + URLCapture.capturedRequests, host: "rust-test", port: 9002, + pathContains: "v1/config/api-keys", method: "GET", + label: "fetchApiKeys") + } + + func testSynthesizeSpeechRoutesToRust() async { + let client = await makeTestClient() + _ = try? 
await client.synthesizeSpeech( + request: APIClient.TtsSynthesizeRequest( + text: "Hello", + voiceId: "onyx", + instructions: "Speak naturally" + ) + ) + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "rust-test", port: 9002, + pathContains: "v1/tts/synthesize", method: "POST", + label: "synthesizeSpeech") + + let body = requests.first?.body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? [String: Any] + } + XCTAssertEqual(body?["text"] as? String, "Hello") + XCTAssertEqual(body?["voice_id"] as? String, "onyx") + XCTAssertEqual(body?["instructions"] as? String, "Speak naturally") + } + + // -- Assistant settings (GET → Python, migrated from Rust) -- + + func testGetAssistantSettingsRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getAssistantSettings() as AssistantSettingsResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/assistant-settings", method: "GET", + label: "getAssistantSettings") + } + + // -- Notification settings (GET → Python, migrated from Rust) -- + + func testGetNotificationSettingsRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getNotificationSettings() as NotificationSettingsResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/notification-settings", method: "GET", + label: "getNotificationSettings") + } + + // -- Staged tasks (GET, DELETE → Python, migrated from Rust) -- + + func testGetStagedTasksRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getStagedTasks() as ActionItemsListResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/staged-tasks", method: "GET", + label: "getStagedTasks") + } + + func testDeleteStagedTaskRoutesToPython() async { + let client = await makeTestClient() + try? 
await client.deleteStagedTask(id: "st-1") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/staged-tasks/st-1", method: "DELETE", + label: "deleteStagedTask") + } + + // -- Chat sessions (GET, POST, DELETE → Python, migrated from Rust) -- + + func testGetChatSessionsRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getChatSessions() as [ChatSession] + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/chat-sessions", method: "GET", + label: "getChatSessions") + } + + func testCreateChatSessionRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.createChatSession(title: "test") as ChatSession + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/chat-sessions", method: "POST", + label: "createChatSession") + } + + func testDeleteChatSessionRoutesToPython() async { + let client = await makeTestClient() + try? await client.deleteChatSession(sessionId: "sess-1") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/chat-sessions/sess-1", method: "DELETE", + label: "deleteChatSession") + } + + // -- Desktop messages (DELETE → Python, path changed to v2/desktop/messages) -- + + func testDeleteMessagesRoutesToPython() async { + let client = await makeTestClient() + _ = try? 
await client.deleteMessages() as MessageDeleteResponse + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/desktop/messages", method: "DELETE", + label: "deleteMessages") + } + + // -- LLM usage (GET → Python, migrated from Rust) -- + + func testFetchTotalOmiAICostRoutesToPython() async { + let client = await makeTestClient() + _ = await client.fetchTotalOmiAICost() + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/me/llm-usage/total", method: "GET", + label: "fetchTotalOmiAICost") + } + + // MARK: - Python-routed: remaining manual URL builders + + // -- setConversationVisibility: manual URL(string: baseURL + ...) PATCH → Python -- + + func testSetConversationVisibilityRoutesToPython() async { + let client = await makeTestClient() + try? await client.setConversationVisibility(id: "c3") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/c3/visibility", method: "PATCH", + label: "setConversationVisibility") + } + + // -- moveConversationToFolder: manual URL PATCH → Python -- + + func testMoveConversationToFolderRoutesToPython() async { + let client = await makeTestClient() + try? await client.moveConversationToFolder(conversationId: "c4", folderId: "f1") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/c4/folder", method: "PATCH", + label: "moveConversationToFolder") + } + + // -- setRecordingPermission: manual URL POST → Python -- + + func testSetRecordingPermissionRoutesToPython() async { + let client = await makeTestClient() + try? 
await client.setRecordingPermission(enabled: true) + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/store-recording-permission", method: "POST", + label: "setRecordingPermission") + } + + // -- setPrivateCloudSync: manual URL POST → Python -- + + func testSetPrivateCloudSyncRoutesToPython() async { + let client = await makeTestClient() + try? await client.setPrivateCloudSync(enabled: false) + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/private-cloud-sync", method: "POST", + label: "setPrivateCloudSync") + } + + // -- completeGoal: manual URL PATCH → Python -- + + func testCompleteGoalRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.completeGoal(id: "g2") as Goal + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/goals/g2", method: "PATCH", + label: "completeGoal") + } + + // -- assignSegmentsBulk: manual URL PATCH → Python -- + + func testAssignSegmentsBulkRoutesToPython() async { + let client = await makeTestClient() + try? await client.assignSegmentsBulk( + conversationId: "c5", segmentIds: ["s1"], isUser: true, personId: nil) + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/conversations/c5/segments/assign-bulk", method: "PATCH", + label: "assignSegmentsBulk") + } + + // -- Chat AI endpoints (migrated from Rust to Python) -- + + func testGetInitialMessageRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getInitialMessage(sessionId: "s1") + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/chat/initial-message", method: "POST", + label: "getInitialMessage") + } + + func testGenerateSessionTitleRoutesToPython() async { + let client = await makeTestClient() + _ = try? 
await client.generateSessionTitle(sessionId: "s1", messages: [("hi", "human")]) + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v2/chat/generate-title", method: "POST", + label: "generateSessionTitle") + } + + func testGetChatMessageCountRoutesToPython() async { + let client = await makeTestClient() + _ = try? await client.getChatMessageCount() + assertRoutes( + URLCapture.capturedRequests, host: "python-test", port: 9001, + pathContains: "v1/users/stats/chat-messages", method: "GET", + label: "getChatMessageCount") + } } // MARK: - Helper extension to set testAuthHeader from async context extension APIClient { - func setTestAuthHeader(_ header: String) async { - self.testAuthHeader = header - } + func setTestAuthHeader(_ header: String) async { + self.testAuthHeader = header + } } diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index f9b447ce5d2..5a38ef25a9b 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -154,6 +154,15 @@ mod tests { .await?; assert_eq!(job["processing_job"]["status"], "queued"); + let starred = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/conversations/{conversation_id}"), + Some(json!({"starred": true})), + ) + .await?; + assert_eq!(starred["conversation"]["starred"], true); + let conversation = request_json( app.clone(), Method::GET, @@ -161,6 +170,7 @@ mod tests { None, ) .await?; + assert_eq!(conversation["conversation"]["starred"], true); assert_eq!( conversation["transcript_segments"] .as_array() diff --git a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs index 76489158c05..b5be70416e6 100644 --- a/desktop/local-backend/src/processing.rs +++ b/desktop/local-backend/src/processing.rs @@ -239,6 +239,7 @@ fn persist_processing_output( status: Some("processed".to_string()), ended_at: None, metadata: None, + starred: None, }, )? 
.ok_or_else(|| anyhow!("conversation missing while persisting processing output"))?; diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 1d3a5bc1426..97590972f88 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -232,6 +232,7 @@ struct UpdateConversationRequest { status: Option, ended_at: Option>, metadata: Option, + starred: Option, } async fn update_conversation( @@ -250,6 +251,7 @@ async fn update_conversation( status: request.status, ended_at: request.ended_at.map(Some), metadata: request.metadata, + starred: request.starred, }, ) .map_err(ApiError::internal)? diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index b3c41182859..4f4841fa4f4 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -7,10 +7,11 @@ use rusqlite::{params, Connection, OptionalExtension}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -const MIGRATIONS: &[Migration] = &[Migration { - version: 1, - name: "initial_local_storage", - sql: r#" +const MIGRATIONS: &[Migration] = &[ + Migration { + version: 1, + name: "initial_local_storage", + sql: r#" CREATE TABLE conversations ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL, @@ -218,7 +219,16 @@ const MIGRATIONS: &[Migration] = &[Migration { DELETE FROM conversation_search WHERE source_type = 'segment' AND source_id = old.id; END; "#, -}]; + }, + Migration { + version: 2, + name: "conversation_starred", + sql: r#" + ALTER TABLE conversations ADD COLUMN starred INTEGER NOT NULL DEFAULT 0; + CREATE INDEX idx_conversations_starred ON conversations(starred, updated_at); + "#, + }, +]; #[derive(Clone)] pub struct Store { @@ -247,6 +257,7 @@ pub struct Conversation { pub sync_version: i64, pub sync_state: String, pub metadata_json: String, + pub starred: bool, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -463,6 +474,7 @@ impl 
ConversationRepository { sync_version: 0, sync_state: "local".to_string(), metadata_json: json_or_empty_object(new.metadata)?, + starred: false, }; let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); @@ -470,9 +482,9 @@ impl ConversationRepository { r#" INSERT INTO conversations ( id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred ) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15) "#, params![ conversation.id, @@ -488,7 +500,8 @@ impl ConversationRepository { conversation.cloud_id, conversation.sync_version, conversation.sync_state, - conversation.metadata_json + conversation.metadata_json, + conversation.starred ], ) .context("failed to insert conversation")?; @@ -501,7 +514,7 @@ impl ConversationRepository { conn.query_row( r#" SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred FROM conversations WHERE id = ?1 AND deleted_at IS NULL "#, @@ -518,7 +531,7 @@ impl ConversationRepository { .prepare( r#" SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred FROM conversations WHERE deleted_at IS NULL ORDER BY updated_at DESC @@ -561,6 +574,9 @@ impl ConversationRepository { if let Some(metadata) = update.metadata { conversation.metadata_json = json_or_empty_object(Some(metadata))?; } + if let Some(starred) = update.starred { + conversation.starred = starred; + } conversation.updated_at = 
Utc::now(); let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); @@ -568,7 +584,7 @@ impl ConversationRepository { r#" UPDATE conversations SET title = ?2, overview = ?3, status = ?4, ended_at = ?5, updated_at = ?6, - metadata_json = ?7, sync_version = sync_version + 1 + metadata_json = ?7, starred = ?8, sync_version = sync_version + 1 WHERE id = ?1 AND deleted_at IS NULL "#, params![ @@ -578,7 +594,8 @@ impl ConversationRepository { conversation.status, conversation.ended_at, conversation.updated_at, - conversation.metadata_json + conversation.metadata_json, + conversation.starred ], ) .context("failed to update conversation")?; @@ -1403,6 +1420,7 @@ pub struct UpdateConversation { pub status: Option, pub ended_at: Option>>, pub metadata: Option, + pub starred: Option, } #[derive(Debug, Clone)] @@ -1545,6 +1563,7 @@ fn map_conversation(row: &rusqlite::Row<'_>) -> rusqlite::Result { sync_version: row.get(11)?, sync_state: row.get(12)?, metadata_json: row.get(13)?, + starred: row.get(14)?, }) } @@ -1781,6 +1800,47 @@ mod tests { Ok(()) } + #[test] + fn conversation_starred_updates_persist() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-starred"]); + + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-starred".to_string(), + title: "Starred conversation".to_string(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + + let updated = store + .conversations() + .update( + &conversation_id, + UpdateConversation { + title: None, + overview: None, + status: None, + ended_at: None, + metadata: None, + starred: Some(true), + }, + )? + .expect("conversation should update"); + + assert!(updated.starred); + assert!( + store + .conversations() + .get(&conversation_id)? 
+ .expect("conversation should persist") + .starred + ); + + Ok(()) + } + #[test] fn fts_search_matches_conversation_and_transcript_text() -> Result<()> { let store = Store::open_in_memory()?; From c8722d949fdf36062e174cb84f961e93110b4b7d Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 22:39:39 +0700 Subject: [PATCH 13/58] Make local ingestion retries idempotent --- .../local-backend/docs/local-mvp-runbook.md | 6 + desktop/local-backend/src/main.rs | 65 ++++ desktop/local-backend/src/processing.rs | 98 +++++- desktop/local-backend/src/routes.rs | 33 +- desktop/local-backend/src/storage.rs | 289 +++++++++++++++++- .../local-backend/tools/import_transcript.py | 27 +- 6 files changed, 505 insertions(+), 13 deletions(-) diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index d2504da60c8..e5314c4a8da 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -119,6 +119,12 @@ The helper: - verifies search finds the imported transcript text - prints the conversation ID plus read and search `curl` commands +For retry tests, pass a stable `--conversation-id` and run the same command +again. The helper reuses the existing conversation, exact duplicate transcript +segments return the existing row, and finalize returns an already active or +current completed processing job instead of piling up duplicate queued work. A +different segment body at an existing `segment_index` returns HTTP 409. + JSON fixtures are also supported. 
The file may be a list of segment strings, a list of segment objects, or an object with conversation fields plus `segments` or `transcript_segments`: diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 5a38ef25a9b..049f24d264c 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -255,6 +255,71 @@ mod tests { Ok(()) } + #[tokio::test] + async fn duplicate_finalize_reuses_active_and_current_completed_job() -> Result<()> { + let app = test_app()?; + + let created = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({"session_id": "session-finalize-retry"})), + ) + .await?; + let conversation_id = created["conversation"]["id"].as_str().unwrap(); + request_json( + app.clone(), + Method::POST, + &format!("/v1/conversations/{conversation_id}/transcript-segments"), + Some(json!({ + "text": "Finalize should be retry safe.", + "start_ms": 0, + "end_ms": 1000, + "segment_index": 0 + })), + ) + .await?; + + let first = request_json( + app.clone(), + Method::POST, + &format!("/v1/conversations/{conversation_id}/finalize-transcript"), + None, + ) + .await?; + let second = request_json( + app.clone(), + Method::POST, + &format!("/v1/conversations/{conversation_id}/finalize-transcript"), + None, + ) + .await?; + assert_eq!( + first["processing_job"]["id"], + second["processing_job"]["id"] + ); + assert_eq!(second["processing_job"]["status"], "queued"); + + request_json( + app.clone(), + Method::POST, + "/v1/processing-jobs/process-next", + None, + ) + .await?; + let third = request_json( + app, + Method::POST, + &format!("/v1/conversations/{conversation_id}/finalize-transcript"), + None, + ) + .await?; + assert_eq!(first["processing_job"]["id"], third["processing_job"]["id"]); + assert_eq!(third["processing_job"]["status"], "completed"); + + Ok(()) + } + fn test_app() -> Result { let config = Config { bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), diff --git 
a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs index b5be70416e6..79cded852a2 100644 --- a/desktop/local-backend/src/processing.rs +++ b/desktop/local-backend/src/processing.rs @@ -244,9 +244,21 @@ fn persist_processing_output( )? .ok_or_else(|| anyhow!("conversation missing while persisting processing output"))?; + let action_item_ids = output + .action_items + .iter() + .enumerate() + .map(|(index, item)| { + deterministic_id("act", &[conversation_id, &index.to_string(), &item.title]) + }) + .collect::>(); + store + .action_items() + .soft_delete_local_processing_except(conversation_id, &action_item_ids)?; + for (index, item) in output.action_items.iter().enumerate() { - store.action_items().create(NewActionItem { - id: deterministic_id("act", &[conversation_id, &index.to_string(), &item.title]), + store.action_items().upsert(NewActionItem { + id: action_item_ids[index].clone(), conversation_id: Some(conversation_id.to_string()), title: item.title.clone(), description: Some(item.description.clone()), @@ -256,12 +268,24 @@ fn persist_processing_output( })?; } - for (index, memory) in output.memories.iter().enumerate() { - store.memories().create(NewMemory { - id: deterministic_id( + let memory_ids = output + .memories + .iter() + .enumerate() + .map(|(index, memory)| { + deterministic_id( "mem", &[conversation_id, &index.to_string(), &memory.content], - ), + ) + }) + .collect::>(); + store + .memories() + .soft_delete_local_processing_except(conversation_id, &memory_ids)?; + + for (index, memory) in output.memories.iter().enumerate() { + store.memories().upsert(NewMemory { + id: memory_ids[index].clone(), content: memory.content.clone(), category: memory.category.clone(), conversation_id: Some(conversation_id.to_string()), @@ -357,4 +381,66 @@ mod tests { Ok(()) } + + #[test] + fn provider_style_processing_outputs_are_retry_safe() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = 
deterministic_id("conv", &["session-provider-retry"]); + + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-provider-retry".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + + let output = ProcessingOutput { + title: "Provider summary".to_string(), + overview: "Provider overview".to_string(), + action_items: vec![ExtractedActionItem { + title: "Review retry behavior".to_string(), + description: "Confirm local processing upserts deterministic rows.".to_string(), + }], + memories: vec![ExtractedMemory { + content: "User prefers retry-safe local imports.".to_string(), + category: Some("preference".to_string()), + }], + provider: "openai_compatible".to_string(), + }; + + persist_processing_output(&store, &conversation_id, &output)?; + persist_processing_output(&store, &conversation_id, &output)?; + + let action_items = store.action_items().list()?; + let memories = store.memories().list()?; + assert_eq!(action_items.len(), 1); + assert_eq!(action_items[0].title, "Review retry behavior"); + assert_eq!(memories.len(), 1); + assert_eq!( + memories[0].content, + "User prefers retry-safe local imports." 
+ ); + + let replacement = ProcessingOutput { + title: "Provider summary".to_string(), + overview: "Provider overview".to_string(), + action_items: vec![ExtractedActionItem { + title: "Ship retry behavior".to_string(), + description: "Replace stale local processing rows.".to_string(), + }], + memories: Vec::new(), + provider: "openai_compatible".to_string(), + }; + persist_processing_output(&store, &conversation_id, &replacement)?; + + let action_items = store.action_items().list()?; + let memories = store.memories().list()?; + assert_eq!(action_items.len(), 1); + assert_eq!(action_items[0].title, "Ship retry behavior"); + assert!(memories.is_empty()); + + Ok(()) + } } diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 97590972f88..fb7a7fb1417 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -12,8 +12,9 @@ use serde_json::{json, Map, Value}; use crate::{ processing, storage::{ - deterministic_id, NewActionItem, NewConversation, NewMemory, NewProcessingJob, - NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, UpdateProfile, + deterministic_id, AppendTranscriptResult, NewActionItem, NewConversation, NewMemory, + NewProcessingJob, NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, + UpdateProfile, }, AppState, }; @@ -79,6 +80,13 @@ impl ApiError { } } + fn conflict(message: impl Into) -> Self { + Self { + status: StatusCode::CONFLICT, + message: message.into(), + } + } + fn not_found(entity: &str) -> Self { Self { status: StatusCode::NOT_FOUND, @@ -314,7 +322,7 @@ async fn append_transcript_segment( let id = request.id.unwrap_or_else(|| { deterministic_id("seg", &[&conversation_id, &segment_index.to_string()]) }); - let segment = state + let append_result = state .store .transcripts() .append(NewTranscriptSegment { @@ -331,6 +339,16 @@ async fn append_transcript_segment( metadata: request.metadata, }) .map_err(ApiError::internal)?; + let 
segment = match append_result { + AppendTranscriptResult::Inserted(segment) | AppendTranscriptResult::Existing(segment) => { + segment + } + AppendTranscriptResult::Conflict(_) => { + return Err(ApiError::conflict( + "transcript segment already exists with different content at this index", + )); + } + }; Ok(Json(json!({ "transcript_segment": segment }))) } @@ -344,6 +362,15 @@ async fn finalize_transcript( .get(&conversation_id) .map_err(ApiError::internal)? .ok_or_else(|| ApiError::not_found("conversation"))?; + if let Some(job) = state + .store + .processing_jobs() + .reusable_for_conversation("finalize_transcript", &conversation_id) + .map_err(ApiError::internal)? + { + return Ok(Json(json!({ "processing_job": job }))); + } + let job = state .store .processing_jobs() diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 4f4841fa4f4..656ea7c6ca8 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -622,7 +622,17 @@ pub struct TranscriptRepository { } impl TranscriptRepository { - pub fn append(&self, new: NewTranscriptSegment) -> Result { + pub fn append(&self, new: NewTranscriptSegment) -> Result { + if let Some(existing) = + self.get_by_conversation_index(&new.conversation_id, new.segment_index)? + { + return if transcript_matches_new(&existing, &new)? 
{ + Ok(AppendTranscriptResult::Existing(existing)) + } else { + Ok(AppendTranscriptResult::Conflict(existing)) + }; + } + let now = Utc::now(); let segment = TranscriptSegment { id: new.id, @@ -676,7 +686,7 @@ impl TranscriptRepository { ) .context("failed to insert transcript segment")?; - Ok(segment) + Ok(AppendTranscriptResult::Inserted(segment)) } pub fn list_for_conversation(&self, conversation_id: &str) -> Result> { @@ -701,6 +711,27 @@ impl TranscriptRepository { collect_rows(rows) } + pub fn get_by_conversation_index( + &self, + conversation_id: &str, + segment_index: i64, + ) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, conversation_id, session_id, speaker_id, speaker_label, text, start_ms, + end_ms, segment_index, source, created_at, updated_at, deleted_at, cloud_id, + sync_version, sync_state, metadata_json + FROM transcript_segments + WHERE conversation_id = ?1 AND segment_index = ?2 AND deleted_at IS NULL + "#, + params![conversation_id, segment_index], + map_transcript_segment, + ) + .optional() + .context("failed to fetch transcript segment by index") + } + pub fn next_segment_index(&self, conversation_id: &str) -> Result { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); conn.query_row( @@ -712,6 +743,32 @@ impl TranscriptRepository { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AppendTranscriptResult { + Inserted(TranscriptSegment), + Existing(TranscriptSegment), + Conflict(TranscriptSegment), +} + +fn transcript_matches_new( + existing: &TranscriptSegment, + new: &NewTranscriptSegment, +) -> Result { + let source = new.source.as_deref().unwrap_or("local"); + let metadata_json = json_or_empty_object(new.metadata.clone())?; + Ok(existing.id == new.id + && existing.conversation_id == new.conversation_id + && existing.session_id == new.session_id + && existing.speaker_id == new.speaker_id + && existing.speaker_label == new.speaker_label + && 
existing.text == new.text + && existing.start_ms == new.start_ms + && existing.end_ms == new.end_ms + && existing.segment_index == new.segment_index + && existing.source == source + && existing.metadata_json == metadata_json) +} + pub struct ProcessingJobRepository { conn: Arc>, } @@ -795,6 +852,53 @@ impl ProcessingJobRepository { .context("failed to fetch processing job") } + pub fn reusable_for_conversation( + &self, + kind: &str, + conversation_id: &str, + ) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT id, kind, status, target_conversation_id, retry_count, max_retries, last_error, + payload_json, result_json, queued_at, started_at, completed_at, failed_at, + created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state + FROM processing_jobs + WHERE kind = ?1 + AND target_conversation_id = ?2 + AND deleted_at IS NULL + AND ( + status IN ('queued', 'running') + OR ( + status = 'completed' + AND julianday(completed_at) >= COALESCE( + ( + SELECT MAX(julianday(updated_at)) + FROM transcript_segments + WHERE conversation_id = ?2 AND deleted_at IS NULL + ), + julianday(completed_at) + ) + ) + ) + ORDER BY + CASE status + WHEN 'running' THEN 0 + WHEN 'queued' THEN 1 + ELSE 2 + END, + updated_at DESC + LIMIT 1 + "#, + ) + .context("failed to prepare reusable processing job query")?; + + stmt.query_row(params![kind, conversation_id], map_processing_job) + .optional() + .context("failed to fetch reusable processing job") + } + pub fn list(&self) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let mut stmt = conn @@ -959,6 +1063,50 @@ impl MemoryRepository { Ok(memory) } + pub fn upsert(&self, new: NewMemory) -> Result { + let now = Utc::now(); + let metadata_json = json_or_empty_object(new.metadata)?; + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO memories ( + id, content, 
category, conversation_id, created_at, updated_at, deleted_at, + cloud_id, sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?5, NULL, NULL, 0, 'local', ?6) + ON CONFLICT(id) DO UPDATE SET + content = excluded.content, + category = excluded.category, + conversation_id = excluded.conversation_id, + updated_at = excluded.updated_at, + deleted_at = NULL, + metadata_json = excluded.metadata_json, + sync_version = memories.sync_version + 1, + sync_state = 'local' + "#, + params![ + new.id, + new.content, + new.category, + new.conversation_id, + now, + metadata_json + ], + ) + .context("failed to upsert memory")?; + drop(conn); + self.get(&new.id)? + .ok_or_else(|| anyhow::anyhow!("memory missing after upsert")) + } + + pub fn soft_delete_local_processing_except( + &self, + conversation_id: &str, + keep_ids: &[String], + ) -> Result { + soft_delete_local_processing_except(&self.conn, "memories", conversation_id, keep_ids) + } + pub fn get(&self, id: &str) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); conn.query_row( @@ -1092,6 +1240,60 @@ impl ActionItemRepository { Ok(action_item) } + pub fn upsert(&self, new: NewActionItem) -> Result { + let now = Utc::now(); + let description = new.description.unwrap_or_default(); + let status = new.status.unwrap_or_else(|| "open".to_string()); + let metadata_json = json_or_empty_object(new.metadata)?; + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO action_items ( + id, conversation_id, title, description, status, due_at, completed_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, NULL, ?7, ?7, NULL, NULL, 0, 'local', ?8) + ON CONFLICT(id) DO UPDATE SET + conversation_id = excluded.conversation_id, + title = excluded.title, + description = excluded.description, + status = excluded.status, + due_at = excluded.due_at, + completed_at = 
CASE + WHEN excluded.status = 'completed' THEN COALESCE(action_items.completed_at, excluded.updated_at) + ELSE NULL + END, + updated_at = excluded.updated_at, + deleted_at = NULL, + metadata_json = excluded.metadata_json, + sync_version = action_items.sync_version + 1, + sync_state = 'local' + "#, + params![ + new.id, + new.conversation_id, + new.title, + description, + status, + new.due_at, + now, + metadata_json + ], + ) + .context("failed to upsert action item")?; + drop(conn); + self.get(&new.id)? + .ok_or_else(|| anyhow::anyhow!("action item missing after upsert")) + } + + pub fn soft_delete_local_processing_except( + &self, + conversation_id: &str, + keep_ids: &[String], + ) -> Result { + soft_delete_local_processing_except(&self.conn, "action_items", conversation_id, keep_ids) + } + pub fn get(&self, id: &str) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); conn.query_row( @@ -1696,6 +1898,34 @@ fn soft_delete_by_id( Ok(changed > 0) } +fn soft_delete_local_processing_except( + conn: &Arc>, + table: &str, + conversation_id: &str, + keep_ids: &[String], +) -> Result { + let now = Utc::now(); + let conn = conn.lock().expect("SQLite connection mutex poisoned"); + let keep_ids_json = + serde_json::to_string(keep_ids).context("failed to serialize local processing ids")?; + let changed = conn + .execute( + &format!( + r#" + UPDATE {table} + SET deleted_at = ?3, updated_at = ?3, sync_version = sync_version + 1 + WHERE conversation_id = ?1 + AND deleted_at IS NULL + AND json_extract(metadata_json, '$.source') = 'local_processing' + AND id NOT IN (SELECT value FROM json_each(?2)) + "# + ), + params![conversation_id, keep_ids_json, now], + ) + .with_context(|| format!("failed to delete stale local processing rows from {table}"))?; + Ok(changed) +} + impl ProcessingJobStatus { fn as_str(&self) -> &'static str { match self { @@ -1800,6 +2030,61 @@ mod tests { Ok(()) } + #[test] + fn 
duplicate_transcript_append_is_existing_or_conflict() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-duplicate-segment"]); + + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-duplicate-segment".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + + let new_segment = NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id: conversation_id.clone(), + session_id: "session-duplicate-segment".to_string(), + speaker_id: Some("speaker-1".to_string()), + speaker_label: Some("Alice".to_string()), + text: "Retry-safe transcript append.".to_string(), + start_ms: 0, + end_ms: 1200, + segment_index: 0, + source: None, + metadata: Some(serde_json::json!({"source": "test"})), + }; + + assert!(matches!( + store.transcripts().append(new_segment.clone())?, + AppendTranscriptResult::Inserted(_) + )); + assert!(matches!( + store.transcripts().append(new_segment)?, + AppendTranscriptResult::Existing(_) + )); + + let conflict = store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id, + session_id: "session-duplicate-segment".to_string(), + speaker_id: Some("speaker-1".to_string()), + speaker_label: Some("Alice".to_string()), + text: "Different content at the same segment index.".to_string(), + start_ms: 0, + end_ms: 1200, + segment_index: 0, + source: None, + metadata: Some(serde_json::json!({"source": "test"})), + })?; + assert!(matches!(conflict, AppendTranscriptResult::Conflict(_))); + + Ok(()) + } + #[test] fn conversation_starred_updates_persist() -> Result<()> { let store = Store::open_in_memory()?; diff --git a/desktop/local-backend/tools/import_transcript.py b/desktop/local-backend/tools/import_transcript.py index 9fee82ccd33..37dc3dc985c 100755 --- 
a/desktop/local-backend/tools/import_transcript.py +++ b/desktop/local-backend/tools/import_transcript.py @@ -18,9 +18,16 @@ DEFAULT_BASE_URL = "http://127.0.0.1:8765" -def request_json(method: str, base_url: str, path: str, body: dict[str, Any] | None = None) -> dict[str, Any]: +def request_json( + method: str, + base_url: str, + path: str, + body: dict[str, Any] | None = None, + ok_statuses: set[int] | None = None, +) -> dict[str, Any]: data = None headers = {"Accept": "application/json"} + ok_statuses = ok_statuses or {200} if body is not None: data = json.dumps(body).encode("utf-8") headers["Content-Type"] = "application/json" @@ -33,9 +40,14 @@ def request_json(method: str, base_url: str, path: str, body: dict[str, Any] | N ) try: with urllib.request.urlopen(request, timeout=10) as response: + if response.status not in ok_statuses: + payload = response.read().decode("utf-8", errors="replace") + raise RuntimeError(f"{method} {path} failed with HTTP {response.status}: {payload}") payload = response.read().decode("utf-8") except urllib.error.HTTPError as error: payload = error.read().decode("utf-8", errors="replace") + if error.code in ok_statuses: + return json.loads(payload) if payload else {} raise RuntimeError(f"{method} {path} failed with HTTP {error.code}: {payload}") from error except urllib.error.URLError as error: raise RuntimeError(f"{method} {path} failed: {error.reason}") from error @@ -164,7 +176,18 @@ def main() -> int: conversation[key] = value request_json("GET", args.base_url, "/health") - created = request_json("POST", args.base_url, "/v1/conversations", conversation)["conversation"] + if "id" in conversation: + existing = request_json( + "GET", + args.base_url, + f"/v1/conversations/{urllib.parse.quote(conversation['id'])}", + ok_statuses={200, 404}, + ) + created = existing.get("conversation") + else: + created = None + if created is None: + created = request_json("POST", args.base_url, "/v1/conversations", conversation)["conversation"] 
conversation_id = created["id"] for index, segment in enumerate(segments): From 0d14e9fb65b567023a1b88d13ca8c7f650d85b00 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Mon, 18 May 2026 22:53:26 +0700 Subject: [PATCH 14/58] feat(desktop): expand conversation detail view and settings with local-mode UI - Major ConversationDetailView.swift refactor (+2,378 lines changed) - Add local-mode toggle to SettingsPage - Update ConversationsPage with local-first indicators - Expand .env.example with local backend config - Update local-backend README and MVP runbook --- desktop/.env.example | 12 + .../Pages/ConversationDetailView.swift | 2378 +++++++++-------- .../MainWindow/Pages/ConversationsPage.swift | 18 + .../MainWindow/Pages/SettingsPage.swift | 232 +- desktop/local-backend/README.md | 19 + .../local-backend/docs/local-mvp-runbook.md | 114 +- 6 files changed, 1543 insertions(+), 1230 deletions(-) diff --git a/desktop/.env.example b/desktop/.env.example index 25f78bbcd58..0af05c97cfa 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -32,6 +32,18 @@ OMI_PYTHON_API_URL=https://api.omi.me # `curl http://127.0.0.1:8765/health` # custom: custom remote URL from OMI_PYTHON_API_URL # OMI_DESKTOP_BACKEND_MODE=cloud +# +# Primary local MVP user-test launch: +# OMI_DESKTOP_BACKEND_MODE=local \ +# OMI_LOCAL_DAEMON_SUPERVISE=1 \ +# OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +# OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +# OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +# ./run.sh +# +# Required for local mode: OMI_DESKTOP_BACKEND_MODE=local. +# Optional: OMI_LOCAL_DAEMON_SUPERVISE=1, OMI_LOCAL_DAEMON_URL, +# OMI_LOCAL_BACKEND_DATA_DIR, OMI_LOCAL_BACKEND_PORT, OMI_LOCAL_DAEMON_LOG. 
# OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 # OMI_LOCAL_DAEMON_SUPERVISE=1 diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift index 8d59d6eb43d..bade04d45c5 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationDetailView.swift @@ -2,1303 +2,1331 @@ import SwiftUI /// Full detail view for a single conversation struct ConversationDetailView: View { - let conversation: ServerConversation - let onBack: () -> Void - var folders: [Folder] = [] - var onMoveToFolder: ((String, String?) async -> Void)? - var onDelete: (() -> Void)? - var onTitleUpdated: ((String) -> Void)? - - // People (speaker naming) - var people: [Person] = [] - var onFetchPeople: (() async -> Void)? - var onCreatePerson: ((String) async -> Person?)? - var onAssignSpeaker: ((String, [String], String?, Bool) async -> Bool)? - - @StateObject private var appProvider = AppProvider() - @State private var showAppSelector = false - @State private var isReprocessing = false - @State private var selectedAppForReprocess: OmiApp? - - // Transcript drawer state (replaces tab system) - @State private var showTranscriptDrawer = false - - // Entry animation - @State private var hasAppeared = false - - // Full conversation loaded from API (with transcript segments) - @State private var loadedConversation: ServerConversation? - @State private var isLoadingConversation = false - - // Action states - @State private var showDeleteConfirmation = false - @State private var showEditDialog = false - @State private var editedTitle = "" - @State private var isUpdatingTitle = false - @State private var isCopyingLink = false - @State private var isDeleting = false - - // Speaker naming state - @State private var selectedSegmentForNaming: TranscriptSegment? 
= nil - - static func assignmentMetadata( - for segmentIndices: [Int], - in segments: [TranscriptSegment] - ) -> (targets: [String], backendIds: [String], fallbackOrders: [Int]) { - let validIndices = segmentIndices.filter { segments.indices.contains($0) } - let targets = validIndices.map { index in - segments[index].backendId ?? "#index:\(index)" - } - let backendIds = validIndices.compactMap { index in - segments[index].backendId - } - let fallbackOrders = validIndices.filter { index in - segments[index].backendId == nil - } - return (targets, backendIds, fallbackOrders) - } - - /// The conversation to display - use loaded version if available, otherwise use prop - private var displayConversation: ServerConversation { - loadedConversation ?? conversation - } - - /// The date to display (prefer startedAt, fall back to createdAt) - private var displayDate: Date { - displayConversation.startedAt ?? displayConversation.createdAt + let conversation: ServerConversation + let onBack: () -> Void + var folders: [Folder] = [] + var onMoveToFolder: ((String, String?) async -> Void)? + var onDelete: (() -> Void)? + var onTitleUpdated: ((String) -> Void)? + + // People (speaker naming) + var people: [Person] = [] + var onFetchPeople: (() async -> Void)? + var onCreatePerson: ((String) async -> Person?)? + var onAssignSpeaker: ((String, [String], String?, Bool) async -> Bool)? + + @StateObject private var appProvider = AppProvider() + @State private var showAppSelector = false + @State private var isReprocessing = false + @State private var selectedAppForReprocess: OmiApp? + + // Transcript drawer state (replaces tab system) + @State private var showTranscriptDrawer = false + + // Entry animation + @State private var hasAppeared = false + + // Full conversation loaded from API (with transcript segments) + @State private var loadedConversation: ServerConversation? 
+ @State private var isLoadingConversation = false + + // Action states + @State private var showDeleteConfirmation = false + @State private var showEditDialog = false + @State private var editedTitle = "" + @State private var isUpdatingTitle = false + @State private var isCopyingLink = false + @State private var isDeleting = false + + // Speaker naming state + @State private var selectedSegmentForNaming: TranscriptSegment? = nil + + static func assignmentMetadata( + for segmentIndices: [Int], + in segments: [TranscriptSegment] + ) -> (targets: [String], backendIds: [String], fallbackOrders: [Int]) { + let validIndices = segmentIndices.filter { segments.indices.contains($0) } + let targets = validIndices.map { index in + segments[index].backendId ?? "#index:\(index)" } - - // Static date formatters — creating DateFormatter is expensive, avoid per-render allocation - private static let dayDateFormatter: DateFormatter = { - let f = DateFormatter() - f.dateFormat = "EEEE, MMM d, yyyy" - return f - }() - private static let timeOnlyFormatter: DateFormatter = { - let f = DateFormatter() - f.dateFormat = "h:mm a" - return f - }() - private static let shortDateFormatter: DateFormatter = { - let f = DateFormatter() - f.dateFormat = "MMM d, yyyy" - return f - }() - - /// Format date for display - private var formattedDate: String { - Self.dayDateFormatter.string(from: displayDate) + let backendIds = validIndices.compactMap { index in + segments[index].backendId } - - /// Format time for display - private var formattedTime: String { - Self.timeOnlyFormatter.string(from: displayDate) + let fallbackOrders = validIndices.filter { index in + segments[index].backendId == nil } - - /// Format time range for header subtitle (e.g., "Jan 15, 2025 from 2:30 PM to 3:15 PM") - private var formattedTimeRange: String { - let dateStr = Self.shortDateFormatter.string(from: displayDate) - let startStr = Self.timeOnlyFormatter.string(from: displayDate) - - if let finishedAt = 
displayConversation.finishedAt { - let endStr = Self.timeOnlyFormatter.string(from: finishedAt) - return "\(dateStr) from \(startStr) to \(endStr)" - } - return "\(dateStr) at \(startStr)" + return (targets, backendIds, fallbackOrders) + } + + /// The conversation to display - use loaded version if available, otherwise use prop + private var displayConversation: ServerConversation { + loadedConversation ?? conversation + } + + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + + /// The date to display (prefer startedAt, fall back to createdAt) + private var displayDate: Date { + displayConversation.startedAt ?? displayConversation.createdAt + } + + // Static date formatters — creating DateFormatter is expensive, avoid per-render allocation + private static let dayDateFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "EEEE, MMM d, yyyy" + return f + }() + private static let timeOnlyFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "h:mm a" + return f + }() + private static let shortDateFormatter: DateFormatter = { + let f = DateFormatter() + f.dateFormat = "MMM d, yyyy" + return f + }() + + /// Format date for display + private var formattedDate: String { + Self.dayDateFormatter.string(from: displayDate) + } + + /// Format time for display + private var formattedTime: String { + Self.timeOnlyFormatter.string(from: displayDate) + } + + /// Format time range for header subtitle (e.g., "Jan 15, 2025 from 2:30 PM to 3:15 PM") + private var formattedTimeRange: String { + let dateStr = Self.shortDateFormatter.string(from: displayDate) + let startStr = Self.timeOnlyFormatter.string(from: displayDate) + + if let finishedAt = displayConversation.finishedAt { + let endStr = Self.timeOnlyFormatter.string(from: finishedAt) + return "\(dateStr) from \(startStr) to \(endStr)" } - - var body: some View { - HStack(spacing: 0) { - // Main content (always visible) - 
VStack(alignment: .leading, spacing: 0) { - headerView - - ScrollView { - // Card container wrapping summary content - VStack(alignment: .leading, spacing: 0) { - // Card header bar - HStack(spacing: 8) { - Image(systemName: "doc.text") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - Text("Conversation Details") - .scaledFont(size: 13, weight: .medium) - .foregroundColor(OmiColors.textSecondary) - Spacer() - } - .padding(.horizontal, 16) - .padding(.vertical, 10) - .background(OmiColors.backgroundTertiary.opacity(0.4)) - - VStack(alignment: .leading, spacing: 24) { - summaryContent - } - .padding(24) - } - .background( - RoundedRectangle(cornerRadius: 16) - .fill(OmiColors.backgroundSecondary.opacity(0.6)) - ) - .clipShape(RoundedRectangle(cornerRadius: 16)) - .overlay( - RoundedRectangle(cornerRadius: 16) - .stroke(OmiColors.backgroundTertiary.opacity(0.3), lineWidth: 1) - ) - .shadow(color: Color.black.opacity(0.1), radius: 20, x: 0, y: 8) - .padding(24) - } + return "\(dateStr) at \(startStr)" + } + + var body: some View { + HStack(spacing: 0) { + // Main content (always visible) + VStack(alignment: .leading, spacing: 0) { + headerView + + ScrollView { + // Card container wrapping summary content + VStack(alignment: .leading, spacing: 0) { + // Card header bar + HStack(spacing: 8) { + Image(systemName: "doc.text") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + Text("Conversation Details") + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textSecondary) + Spacer() } - .frame(maxWidth: .infinity) - - // Transcript drawer (slides in from right) - if showTranscriptDrawer { - Rectangle() - .fill(OmiColors.border) - .frame(width: 1) + .padding(.horizontal, 16) + .padding(.vertical, 10) + .background(OmiColors.backgroundTertiary.opacity(0.4)) - transcriptDrawerView - .frame(width: 450) - .transition(.move(edge: .trailing)) - } - } - .opacity(hasAppeared ? 1 : 0) - .offset(y: hasAppeared ? 
0 : 20) - .onAppear { - withAnimation(.easeOut(duration: 0.5)) { - hasAppeared = true + VStack(alignment: .leading, spacing: 24) { + summaryContent } + .padding(24) + } + .background( + RoundedRectangle(cornerRadius: 16) + .fill(OmiColors.backgroundSecondary.opacity(0.6)) + ) + .clipShape(RoundedRectangle(cornerRadius: 16)) + .overlay( + RoundedRectangle(cornerRadius: 16) + .stroke(OmiColors.backgroundTertiary.opacity(0.3), lineWidth: 1) + ) + .shadow(color: Color.black.opacity(0.1), radius: 20, x: 0, y: 8) + .padding(24) } - .task { - await appProvider.fetchApps() - await onFetchPeople?() - AnalyticsManager.shared.conversationDetailOpened(conversationId: conversation.id) - - // Load segments from local database if not already present - // Segments are stored locally but not loaded with the list view for performance - if conversation.transcriptSegments.isEmpty { - isLoadingConversation = true - do { - // First try local database (faster, works offline) - if let session = try await TranscriptionStorage.shared.getSessionByBackendId(conversation.id) { - let segmentRecords = try await TranscriptionStorage.shared.getSegments(sessionId: session.id!) 
- if !segmentRecords.isEmpty { - // Convert local records to TranscriptSegments and update conversation - let segments = segmentRecords.map { $0.toTranscriptSegment() } - var updatedConversation = conversation - updatedConversation.transcriptSegments = segments - loadedConversation = updatedConversation - log("ConversationDetail: Loaded \(segments.count) segments from local database") - } else { - // No local segments, fetch from API - let fullConversation = try await APIClient.shared.getConversation(id: conversation.id) - loadedConversation = fullConversation - log("ConversationDetail: Loaded \(fullConversation.transcriptSegments.count) segments from API") - } - } else { - // No local session found, fetch from API - let fullConversation = try await APIClient.shared.getConversation(id: conversation.id) - loadedConversation = fullConversation - log("ConversationDetail: Loaded \(fullConversation.transcriptSegments.count) segments from API (no local session)") - } - } catch { - logError("ConversationDetail: Failed to load conversation segments", error: error) - } - isLoadingConversation = false - } - } - .onReceive( - NotificationCenter.default.publisher(for: .desktopAutomationShowConversationTranscriptRequested) - ) { notification in - guard let conversationId = notification.userInfo?["conversationId"] as? String, - conversationId == displayConversation.id - else { return } - withAnimation(.easeInOut(duration: 0.2)) { - showTranscriptDrawer = true + } + .frame(maxWidth: .infinity) + + // Transcript drawer (slides in from right) + if showTranscriptDrawer { + Rectangle() + .fill(OmiColors.border) + .frame(width: 1) + + transcriptDrawerView + .frame(width: 450) + .transition(.move(edge: .trailing)) + } + } + .opacity(hasAppeared ? 1 : 0) + .offset(y: hasAppeared ? 
0 : 20) + .onAppear { + withAnimation(.easeOut(duration: 0.5)) { + hasAppeared = true + } + } + .task { + await appProvider.fetchApps() + await onFetchPeople?() + AnalyticsManager.shared.conversationDetailOpened(conversationId: conversation.id) + + // Load segments from local database if not already present + // Segments are stored locally but not loaded with the list view for performance + if conversation.transcriptSegments.isEmpty { + isLoadingConversation = true + do { + // First try local database (faster, works offline) + if let session = try await TranscriptionStorage.shared.getSessionByBackendId( + conversation.id) + { + let segmentRecords = try await TranscriptionStorage.shared.getSegments( + sessionId: session.id!) + if !segmentRecords.isEmpty { + // Convert local records to TranscriptSegments and update conversation + let segments = segmentRecords.map { $0.toTranscriptSegment() } + var updatedConversation = conversation + updatedConversation.transcriptSegments = segments + loadedConversation = updatedConversation + log("ConversationDetail: Loaded \(segments.count) segments from local database") + } else { + // No local segments, fetch from API + let fullConversation = try await APIClient.shared.getConversation(id: conversation.id) + loadedConversation = fullConversation + log( + "ConversationDetail: Loaded \(fullConversation.transcriptSegments.count) segments from API" + ) } - } - .dismissableSheet(isPresented: $showAppSelector) { - AppSelectorSheet( - apps: appProvider.apps.filter { $0.capabilities.contains("memories") }, - isLoading: isReprocessing, - onSelect: { app in - selectedAppForReprocess = app - Task { - await reprocessWithApp(app) - } - }, - onDismiss: { showAppSelector = false } + } else { + // No local session found, fetch from API + let fullConversation = try await APIClient.shared.getConversation(id: conversation.id) + loadedConversation = fullConversation + log( + "ConversationDetail: Loaded \(fullConversation.transcriptSegments.count) 
segments from API (no local session)" ) - .frame(width: 400, height: 500) + } + } catch { + logError("ConversationDetail: Failed to load conversation segments", error: error) } - .dismissableSheet(item: $selectedSegmentForNaming) { segment in - NameSpeakerSheet( - segment: segment, - allSegments: displayConversation.transcriptSegments, - people: people, - onSave: { personId, isUser, segmentIndices in - Task { - let assignment = Self.assignmentMetadata( - for: segmentIndices, - in: displayConversation.transcriptSegments - ) - let success = await onAssignSpeaker?( - conversation.id, - assignment.targets, - personId, - isUser - ) ?? false - if success { - await persistSpeakerAssignment( - conversationId: conversation.id, - backendSegmentIds: assignment.backendIds, - fallbackSegmentOrders: assignment.fallbackOrders, - isUser: isUser, - personId: personId - ) - await updateDisplayedConversation(segmentIndices: segmentIndices, isUser: isUser, personId: personId) - } - selectedSegmentForNaming = nil - } - }, - onCreatePerson: { name in - await onCreatePerson?(name) - }, - onDismiss: { - selectedSegmentForNaming = nil - } + isLoadingConversation = false + } + } + .onReceive( + NotificationCenter.default.publisher( + for: .desktopAutomationShowConversationTranscriptRequested) + ) { notification in + guard let conversationId = notification.userInfo?["conversationId"] as? 
String, + conversationId == displayConversation.id + else { return } + withAnimation(.easeInOut(duration: 0.2)) { + showTranscriptDrawer = true + } + } + .dismissableSheet(isPresented: $showAppSelector) { + AppSelectorSheet( + apps: appProvider.apps.filter { $0.capabilities.contains("memories") }, + isLoading: isReprocessing, + onSelect: { app in + selectedAppForReprocess = app + Task { + await reprocessWithApp(app) + } + }, + onDismiss: { showAppSelector = false } + ) + .frame(width: 400, height: 500) + } + .dismissableSheet(item: $selectedSegmentForNaming) { segment in + NameSpeakerSheet( + segment: segment, + allSegments: displayConversation.transcriptSegments, + people: people, + onSave: { personId, isUser, segmentIndices in + Task { + let assignment = Self.assignmentMetadata( + for: segmentIndices, + in: displayConversation.transcriptSegments ) + let success = + await onAssignSpeaker?( + conversation.id, + assignment.targets, + personId, + isUser + ) ?? false + if success { + await persistSpeakerAssignment( + conversationId: conversation.id, + backendSegmentIds: assignment.backendIds, + fallbackSegmentOrders: assignment.fallbackOrders, + isUser: isUser, + personId: personId + ) + await updateDisplayedConversation( + segmentIndices: segmentIndices, isUser: isUser, personId: personId) + } + selectedSegmentForNaming = nil + } + }, + onCreatePerson: { name in + await onCreatePerson?(name) + }, + onDismiss: { + selectedSegmentForNaming = nil } + ) } + } - // MARK: - Header - - private var headerView: some View { - HStack(spacing: 12) { - // Back button - Button(action: onBack) { - HStack(spacing: 6) { - Image(systemName: "chevron.left") - .scaledFont(size: 14, weight: .medium) - Text("Back") - .scaledFont(size: 14, weight: .medium) - } - .foregroundColor(OmiColors.purplePrimary) - } - .buttonStyle(.plain) - - // Emoji - Text(displayConversation.structured.emoji.isEmpty ? 
"\u{1F4AC}" : displayConversation.structured.emoji) - .scaledFont(size: 28) - - // Title + timestamp subtitle - VStack(alignment: .leading, spacing: 2) { - HStack(spacing: 8) { - Text(displayConversation.title) - .scaledFont(size: 18, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - .lineLimit(1) - - // Edit title button (inline with title) - Button(action: { - editedTitle = displayConversation.title - showEditDialog = true - }) { - Image(systemName: "pencil") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textTertiary) - } - .buttonStyle(.plain) - .help("Edit title") - } + // MARK: - Header - Text(formattedTimeRange) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - } - - Spacer() + private var headerView: some View { + HStack(spacing: 12) { + // Back button + Button(action: onBack) { + HStack(spacing: 6) { + Image(systemName: "chevron.left") + .scaledFont(size: 14, weight: .medium) + Text("Back") + .scaledFont(size: 14, weight: .medium) + } + .foregroundColor(OmiColors.purplePrimary) + } + .buttonStyle(.plain) + + // Emoji + Text( + displayConversation.structured.emoji.isEmpty + ? 
"\u{1F4AC}" : displayConversation.structured.emoji + ) + .scaledFont(size: 28) + + // Title + timestamp subtitle + VStack(alignment: .leading, spacing: 2) { + HStack(spacing: 8) { + Text(displayConversation.title) + .scaledFont(size: 18, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + .lineLimit(1) + + // Edit title button (inline with title) + Button(action: { + editedTitle = displayConversation.title + showEditDialog = true + }) { + Image(systemName: "pencil") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textTertiary) + } + .buttonStyle(.plain) + .help("Edit title") + } - // Status badge - if displayConversation.status != .completed { - statusBadge - } + Text(formattedTimeRange) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } - // View Transcript pill button - viewTranscriptButton + Spacer() - // Inline action buttons - inlineActionButtons - } - .padding(.horizontal, 24) - .padding(.vertical, 16) - .background(OmiColors.backgroundTertiary.opacity(0.5)) - .alert("Edit Conversation Title", isPresented: $showEditDialog) { - TextField("Title", text: $editedTitle) - Button("Cancel", role: .cancel) { } - Button("Save") { - Task { await updateTitle() } - } - .disabled(editedTitle.isEmpty || isUpdatingTitle) - } message: { - Text("Enter a new title for this conversation") - } - .alert("Delete Conversation", isPresented: $showDeleteConfirmation) { - Button("Cancel", role: .cancel) { } - Button("Delete", role: .destructive) { - Task { await deleteConversation() } - } - } message: { - Text("Are you sure you want to delete this conversation? 
This action cannot be undone.") - } - } + // Status badge + if displayConversation.status != .completed { + statusBadge + } - // MARK: - View Transcript Button + // View Transcript pill button + viewTranscriptButton - private var viewTranscriptButton: some View { - Button(action: { - withAnimation(.easeInOut(duration: 0.25)) { - showTranscriptDrawer.toggle() - } - }) { - HStack(spacing: 6) { - Image(systemName: "text.quote") - .scaledFont(size: 12) - Text(showTranscriptDrawer ? "Hide Transcript" : "View Transcript") - .scaledFont(size: 12, weight: .medium) - } - .foregroundColor(showTranscriptDrawer ? .white : OmiColors.textSecondary) - .padding(.horizontal, 12) - .padding(.vertical, 6) + // Inline action buttons + inlineActionButtons + } + .padding(.horizontal, 24) + .padding(.vertical, 16) + .background(OmiColors.backgroundTertiary.opacity(0.5)) + .alert("Edit Conversation Title", isPresented: $showEditDialog) { + TextField("Title", text: $editedTitle) + Button("Cancel", role: .cancel) {} + Button("Save") { + Task { await updateTitle() } + } + .disabled(editedTitle.isEmpty || isUpdatingTitle) + } message: { + Text("Enter a new title for this conversation") + } + .alert("Delete Conversation", isPresented: $showDeleteConfirmation) { + Button("Cancel", role: .cancel) {} + Button("Delete", role: .destructive) { + Task { await deleteConversation() } + } + } message: { + Text("Are you sure you want to delete this conversation? This action cannot be undone.") + } + } + + // MARK: - View Transcript Button + + private var viewTranscriptButton: some View { + Button(action: { + withAnimation(.easeInOut(duration: 0.25)) { + showTranscriptDrawer.toggle() + } + }) { + HStack(spacing: 6) { + Image(systemName: "text.quote") + .scaledFont(size: 12) + Text(showTranscriptDrawer ? "Hide Transcript" : "View Transcript") + .scaledFont(size: 12, weight: .medium) + } + .foregroundColor(showTranscriptDrawer ? 
.white : OmiColors.textSecondary) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background( + Capsule() + .fill(showTranscriptDrawer ? OmiColors.purplePrimary : OmiColors.backgroundTertiary) + ) + } + .buttonStyle(.plain) + } + + // MARK: - Inline Action Buttons + + private var inlineActionButtons: some View { + HStack(spacing: 8) { + // Public share links are cloud-only. + if !isLocalDaemonMode { + Button(action: { Task { await copyLink() } }) { + Image(systemName: isCopyingLink ? "arrow.triangle.2.circlepath" : "link") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 28, height: 28) .background( - Capsule() - .fill(showTranscriptDrawer ? OmiColors.purplePrimary : OmiColors.backgroundTertiary) + Circle() + .fill(OmiColors.backgroundTertiary) ) } .buttonStyle(.plain) - } - - // MARK: - Inline Action Buttons - - private var inlineActionButtons: some View { - HStack(spacing: 8) { - // Copy link button - Button(action: { Task { await copyLink() } }) { - Image(systemName: isCopyingLink ? 
"arrow.triangle.2.circlepath" : "link") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textSecondary) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) - } - .buttonStyle(.plain) - .disabled(isCopyingLink) - .help("Copy link") - - // Copy transcript button - Button(action: copyTranscript) { - Image(systemName: "doc.on.doc") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textSecondary) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) + .disabled(isCopyingLink) + .help("Copy link") + } + + // Copy transcript button + Button(action: copyTranscript) { + Image(systemName: "doc.on.doc") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 28, height: 28) + .background( + Circle() + .fill(OmiColors.backgroundTertiary) + ) + } + .buttonStyle(.plain) + .help("Copy transcript") + + // Move to folder button (menu) + if !folders.isEmpty { + Menu { + if displayConversation.folderId != nil { + Button(action: { + Task { await onMoveToFolder?(conversation.id, nil) } + }) { + Label("Remove from Folder", systemImage: "folder.badge.minus") } - .buttonStyle(.plain) - .help("Copy transcript") - - // Move to folder button (menu) - if !folders.isEmpty { - Menu { - if displayConversation.folderId != nil { - Button(action: { - Task { await onMoveToFolder?(conversation.id, nil) } - }) { - Label("Remove from Folder", systemImage: "folder.badge.minus") - } - Divider() - } - - ForEach(folders) { folder in - Button(action: { - Task { await onMoveToFolder?(conversation.id, folder.id) } - }) { - HStack { - Text(folder.name) - if displayConversation.folderId == folder.id { - Image(systemName: "checkmark") - } - } - } - .disabled(displayConversation.folderId == folder.id) - } - } label: { - Image(systemName: displayConversation.folderId != nil ? "folder.fill" : "folder") - .scaledFont(size: 14) - .foregroundColor(displayConversation.folderId != nil ? 
OmiColors.purplePrimary : OmiColors.textSecondary) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) + Divider() + } + + ForEach(folders) { folder in + Button(action: { + Task { await onMoveToFolder?(conversation.id, folder.id) } + }) { + HStack { + Text(folder.name) + if displayConversation.folderId == folder.id { + Image(systemName: "checkmark") } - .menuStyle(.borderlessButton) - .frame(width: 28) - .help("Move to folder") - } - - // Delete button - Button(action: { showDeleteConfirmation = true }) { - Image(systemName: "trash") - .scaledFont(size: 14) - .foregroundColor(OmiColors.error) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) + } } - .buttonStyle(.plain) - .help("Delete conversation") + .disabled(displayConversation.folderId == folder.id) + } + } label: { + Image(systemName: displayConversation.folderId != nil ? "folder.fill" : "folder") + .scaledFont(size: 14) + .foregroundColor( + displayConversation.folderId != nil + ? 
OmiColors.purplePrimary : OmiColors.textSecondary + ) + .frame(width: 28, height: 28) + .background( + Circle() + .fill(OmiColors.backgroundTertiary) + ) } + .menuStyle(.borderlessButton) + .frame(width: 28) + .help("Move to folder") + } + + // Delete button + Button(action: { showDeleteConfirmation = true }) { + Image(systemName: "trash") + .scaledFont(size: 14) + .foregroundColor(OmiColors.error) + .frame(width: 28, height: 28) + .background( + Circle() + .fill(OmiColors.backgroundTertiary) + ) + } + .buttonStyle(.plain) + .help("Delete conversation") } - - // MARK: - Actions - - private func copyTranscript() { - let peopleDict = Dictionary(uniqueKeysWithValues: people.map { ($0.id, $0) }) - let transcript: String = displayConversation.transcriptSegments.map { segment -> String in - let speakerName: String - if segment.isUser { - speakerName = "You" - } else if let personId = segment.personId, let person = peopleDict[personId] { - speakerName = person.name - } else { - speakerName = "Speaker \(segment.speaker ?? "Unknown")" - } - return "[\(speakerName)]: \(segment.text)" - }.joined(separator: "\n\n") - + } + + // MARK: - Actions + + private func copyTranscript() { + let peopleDict = Dictionary(uniqueKeysWithValues: people.map { ($0.id, $0) }) + let transcript: String = displayConversation.transcriptSegments.map { segment -> String in + let speakerName: String + if segment.isUser { + speakerName = "You" + } else if let personId = segment.personId, let person = peopleDict[personId] { + speakerName = person.name + } else { + speakerName = "Speaker \(segment.speaker ?? 
"Unknown")" + } + return "[\(speakerName)]: \(segment.text)" + }.joined(separator: "\n\n") + + NSPasteboard.general.clearContents() + NSPasteboard.general.setString(transcript, forType: .string) + } + + private func copyLink() async { + isCopyingLink = true + defer { isCopyingLink = false } + + do { + let shareableUrl = try await APIClient.shared.getConversationShareLink(id: conversation.id) + await MainActor.run { NSPasteboard.general.clearContents() - NSPasteboard.general.setString(transcript, forType: .string) + NSPasteboard.general.setString(shareableUrl, forType: .string) + } + AnalyticsManager.shared.shareAction( + category: "conversation", properties: ["conversation_id": conversation.id]) + } catch { + logError("Failed to get share link", error: error) } + } + + private func updateTitle() async { + guard !editedTitle.isEmpty else { return } + isUpdatingTitle = true + defer { isUpdatingTitle = false } + + do { + try await APIClient.shared.updateConversationTitle(id: conversation.id, title: editedTitle) + onTitleUpdated?(editedTitle) + } catch { + logError("Failed to update title", error: error) + } + } + + private func deleteConversation() async { + isDeleting = true + defer { isDeleting = false } + + do { + try await APIClient.shared.deleteConversation(id: conversation.id) + await MainActor.run { + onDelete?() + onBack() + } + } catch { + logError("Failed to delete conversation", error: error) + } + } + + private var statusBadge: some View { + Text(displayConversation.status.rawValue.replacingOccurrences(of: "_", with: " ").capitalized) + .scaledFont(size: 11, weight: .medium) + .foregroundColor(statusColor) + .padding(.horizontal, 8) + .padding(.vertical, 4) + .background( + Capsule() + .fill(statusColor.opacity(0.2)) + ) + } + + private var statusColor: Color { + switch displayConversation.status { + case .completed: + return OmiColors.success + case .processing, .merging: + return OmiColors.info + case .inProgress: + return OmiColors.warning + case 
.failed: + return OmiColors.error + } + } - private func copyLink() async { - isCopyingLink = true - defer { isCopyingLink = false } + // MARK: - Summary Content (always visible, no tabs) - do { - let shareableUrl = try await APIClient.shared.getConversationShareLink(id: conversation.id) - await MainActor.run { - NSPasteboard.general.clearContents() - NSPasteboard.general.setString(shareableUrl, forType: .string) - } - AnalyticsManager.shared.shareAction(category: "conversation", properties: ["conversation_id": conversation.id]) - } catch { - logError("Failed to get share link", error: error) - } + @ViewBuilder + private var summaryContent: some View { + // Overview section + if !displayConversation.overview.isEmpty { + overviewSection } - private func updateTitle() async { - guard !editedTitle.isEmpty else { return } - isUpdatingTitle = true - defer { isUpdatingTitle = false } + // Metadata chips + metadataSection - do { - try await APIClient.shared.updateConversationTitle(id: conversation.id, title: editedTitle) - onTitleUpdated?(editedTitle) - } catch { - logError("Failed to update title", error: error) - } + // App Results section + if !displayConversation.appsResults.isEmpty { + appResultsSection } - private func deleteConversation() async { - isDeleting = true - defer { isDeleting = false } + // Suggested apps section + suggestedAppsSection - do { - try await APIClient.shared.deleteConversation(id: conversation.id) - await MainActor.run { - onDelete?() - onBack() - } - } catch { - logError("Failed to delete conversation", error: error) - } + // Action items section + if !displayConversation.structured.actionItems.isEmpty { + actionItemsSection } + } + + // MARK: - Transcript Drawer + + @ViewBuilder + private var transcriptDrawerView: some View { + VStack(alignment: .leading, spacing: 0) { + // Drawer header + HStack(spacing: 10) { + Image(systemName: "text.quote") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textSecondary) + + Text("Transcript") + 
.scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + // Segment count badge + Text("\(displayConversation.transcriptSegments.count)") + .scaledFont(size: 11, weight: .medium) + .foregroundColor(OmiColors.purplePrimary) + .padding(.horizontal, 8) + .padding(.vertical, 2) + .background( + Capsule() + .fill(OmiColors.purplePrimary.opacity(0.15)) + ) - private var statusBadge: some View { - Text(displayConversation.status.rawValue.replacingOccurrences(of: "_", with: " ").capitalized) - .scaledFont(size: 11, weight: .medium) - .foregroundColor(statusColor) - .padding(.horizontal, 8) - .padding(.vertical, 4) + Spacer() + + // Copy button + Button(action: copyTranscript) { + Image(systemName: "doc.on.doc") + .scaledFont(size: 13) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 28, height: 28) .background( - Capsule() - .fill(statusColor.opacity(0.2)) + Circle() + .fill(OmiColors.backgroundTertiary) ) - } - - private var statusColor: Color { - switch displayConversation.status { - case .completed: - return OmiColors.success - case .processing, .merging: - return OmiColors.info - case .inProgress: - return OmiColors.warning - case .failed: - return OmiColors.error } - } - - // MARK: - Summary Content (always visible, no tabs) + .buttonStyle(.plain) + .help("Copy transcript") - @ViewBuilder - private var summaryContent: some View { - // Overview section - if !displayConversation.overview.isEmpty { - overviewSection + // Close button + Button(action: { + withAnimation(.easeInOut(duration: 0.25)) { + showTranscriptDrawer = false + } + }) { + Image(systemName: "xmark") + .scaledFont(size: 13) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 28, height: 28) + .background( + Circle() + .fill(OmiColors.backgroundTertiary) + ) } - - // Metadata chips - metadataSection - - // App Results section - if !displayConversation.appsResults.isEmpty { - appResultsSection + .buttonStyle(.plain) + .help("Close transcript") + } + 
.padding(.horizontal, 20) + .padding(.vertical, 14) + .background(OmiColors.backgroundTertiary.opacity(0.5)) + + // Drawer content + if displayConversation.transcriptSegments.isEmpty && !isLoadingConversation { + // Empty state + VStack(spacing: 12) { + Image(systemName: "text.quote") + .scaledFont(size: 40) + .foregroundColor(OmiColors.textTertiary.opacity(0.5)) + + Text("No transcript available") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textTertiary) } - - // Suggested apps section - suggestedAppsSection - - // Action items section - if !displayConversation.structured.actionItems.isEmpty { - actionItemsSection + .frame(maxWidth: .infinity, maxHeight: .infinity) + } else if isLoadingConversation { + // Loading state + VStack(spacing: 12) { + ProgressView() + .scaleEffect(0.8) + + Text("Loading transcript...") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textTertiary) } - } - - // MARK: - Transcript Drawer - - @ViewBuilder - private var transcriptDrawerView: some View { - VStack(alignment: .leading, spacing: 0) { - // Drawer header - HStack(spacing: 10) { - Image(systemName: "text.quote") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textSecondary) - - Text("Transcript") - .scaledFont(size: 15, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - - // Segment count badge - Text("\(displayConversation.transcriptSegments.count)") - .scaledFont(size: 11, weight: .medium) - .foregroundColor(OmiColors.purplePrimary) - .padding(.horizontal, 8) - .padding(.vertical, 2) - .background( - Capsule() - .fill(OmiColors.purplePrimary.opacity(0.15)) - ) - - Spacer() - - // Copy button - Button(action: copyTranscript) { - Image(systemName: "doc.on.doc") - .scaledFont(size: 13) - .foregroundColor(OmiColors.textSecondary) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) - } - .buttonStyle(.plain) - .help("Copy transcript") - - // Close button - Button(action: { - 
withAnimation(.easeInOut(duration: 0.25)) { - showTranscriptDrawer = false - } - }) { - Image(systemName: "xmark") - .scaledFont(size: 13) - .foregroundColor(OmiColors.textSecondary) - .frame(width: 28, height: 28) - .background( - Circle() - .fill(OmiColors.backgroundTertiary) - ) - } - .buttonStyle(.plain) - .help("Close transcript") - } - .padding(.horizontal, 20) - .padding(.vertical, 14) - .background(OmiColors.backgroundTertiary.opacity(0.5)) - - // Drawer content - if displayConversation.transcriptSegments.isEmpty && !isLoadingConversation { - // Empty state - VStack(spacing: 12) { - Image(systemName: "text.quote") - .scaledFont(size: 40) - .foregroundColor(OmiColors.textTertiary.opacity(0.5)) - - Text("No transcript available") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textTertiary) - } - .frame(maxWidth: .infinity, maxHeight: .infinity) - } else if isLoadingConversation { - // Loading state - VStack(spacing: 12) { - ProgressView() - .scaleEffect(0.8) - - Text("Loading transcript...") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textTertiary) - } - .frame(maxWidth: .infinity, maxHeight: .infinity) - } else { - // LazyVStack is a DIRECT child of ScrollView so it gets bounded proposed height - // and only materializes visible children. - ScrollView { - LazyVStack(alignment: .leading, spacing: 12) { - transcriptBubblesContent - } - .padding(16) - } - } + .frame(maxWidth: .infinity, maxHeight: .infinity) + } else { + // LazyVStack is a DIRECT child of ScrollView so it gets bounded proposed height + // and only materializes visible children. + ScrollView { + LazyVStack(alignment: .leading, spacing: 12) { + transcriptBubblesContent + } + .padding(16) } - .background(OmiColors.backgroundPrimary) + } } - - // MARK: - Transcript Bubbles (shared) - - /// Flat content intended to be placed inside a parent LazyVStack. - /// Do NOT wrap this in another LazyVStack or VStack — it emits ForEach items directly. 
- @ViewBuilder - private var transcriptBubblesContent: some View { - let peopleDict = Dictionary(uniqueKeysWithValues: people.map { ($0.id, $0) }) - ForEach(displayConversation.transcriptSegments) { segment in - SpeakerBubbleView( - segment: segment, - isUser: segment.isUser, - personName: segment.personId.flatMap { peopleDict[$0]?.name }, - onSpeakerTapped: segment.isUser ? nil : { - selectedSegmentForNaming = segment - } - ) - .padding(.horizontal, 16) - } + .background(OmiColors.backgroundPrimary) + } + + // MARK: - Transcript Bubbles (shared) + + /// Flat content intended to be placed inside a parent LazyVStack. + /// Do NOT wrap this in another LazyVStack or VStack — it emits ForEach items directly. + @ViewBuilder + private var transcriptBubblesContent: some View { + let peopleDict = Dictionary(uniqueKeysWithValues: people.map { ($0.id, $0) }) + ForEach(displayConversation.transcriptSegments) { segment in + SpeakerBubbleView( + segment: segment, + isUser: segment.isUser, + personName: segment.personId.flatMap { peopleDict[$0]?.name }, + onSpeakerTapped: segment.isUser + ? nil + : { + selectedSegmentForNaming = segment + } + ) + .padding(.horizontal, 16) } - - @MainActor - private func updateDisplayedConversation(segmentIndices: [Int], isUser: Bool, personId: String?) { - var updatedConversation = displayConversation - for index in segmentIndices where updatedConversation.transcriptSegments.indices.contains(index) { - let oldSegment = updatedConversation.transcriptSegments[index] - updatedConversation.transcriptSegments[index] = TranscriptSegment( - id: oldSegment.id, - backendId: oldSegment.backendId, - text: oldSegment.text, - speaker: oldSegment.speaker, - isUser: isUser, - personId: isUser ? 
nil : personId, - start: oldSegment.start, - end: oldSegment.end, - translations: oldSegment.translations - ) - } - loadedConversation = updatedConversation + } + + @MainActor + private func updateDisplayedConversation(segmentIndices: [Int], isUser: Bool, personId: String?) { + var updatedConversation = displayConversation + for index in segmentIndices where updatedConversation.transcriptSegments.indices.contains(index) + { + let oldSegment = updatedConversation.transcriptSegments[index] + updatedConversation.transcriptSegments[index] = TranscriptSegment( + id: oldSegment.id, + backendId: oldSegment.backendId, + text: oldSegment.text, + speaker: oldSegment.speaker, + isUser: isUser, + personId: isUser ? nil : personId, + start: oldSegment.start, + end: oldSegment.end, + translations: oldSegment.translations + ) } - - private func persistSpeakerAssignment( - conversationId: String, - backendSegmentIds: [String], - fallbackSegmentOrders: [Int], - isUser: Bool, - personId: String? - ) async { - do { - try await TranscriptionStorage.shared.updateSpeakerAssignmentByBackendId( - conversationId, - segmentIds: backendSegmentIds, - fallbackSegmentOrders: fallbackSegmentOrders, - isUser: isUser, - personId: isUser ? nil : personId - ) - } catch { - logError("ConversationDetail: Failed to persist speaker assignment locally", error: error) - } + loadedConversation = updatedConversation + } + + private func persistSpeakerAssignment( + conversationId: String, + backendSegmentIds: [String], + fallbackSegmentOrders: [Int], + isUser: Bool, + personId: String? + ) async { + do { + try await TranscriptionStorage.shared.updateSpeakerAssignmentByBackendId( + conversationId, + segmentIds: backendSegmentIds, + fallbackSegmentOrders: fallbackSegmentOrders, + isUser: isUser, + personId: isUser ? 
nil : personId + ) + } catch { + logError("ConversationDetail: Failed to persist speaker assignment locally", error: error) } - - // MARK: - Overview Section - - private var overviewSection: some View { - VStack(alignment: .leading, spacing: 8) { - HStack(spacing: 6) { - Image(systemName: "star.fill") - .scaledFont(size: 13) - .foregroundColor(Color(red: 0.95, green: 0.75, blue: 0.15)) - - Text("Summary") - .scaledFont(size: 14, weight: .semibold) - .foregroundColor(OmiColors.textSecondary) - } - - SelectableMarkdown(text: displayConversation.overview, sender: .ai) - .textSelection(.enabled) - .environment(\.colorScheme, .dark) - .frame(maxWidth: .infinity, alignment: .leading) - } + } + + // MARK: - Overview Section + + private var overviewSection: some View { + VStack(alignment: .leading, spacing: 8) { + HStack(spacing: 6) { + Image(systemName: "star.fill") + .scaledFont(size: 13) + .foregroundColor(Color(red: 0.95, green: 0.75, blue: 0.15)) + + Text("Summary") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + } + + SelectableMarkdown(text: displayConversation.overview, sender: .ai) + .textSelection(.enabled) + .environment(\.colorScheme, .dark) + .frame(maxWidth: .infinity, alignment: .leading) } + } - // MARK: - Metadata Section + // MARK: - Metadata Section - private var metadataSection: some View { - HStack(spacing: 12) { - // Source chip (device indicator) - sourceChip + private var metadataSection: some View { + HStack(spacing: 12) { + // Source chip (device indicator) + sourceChip - // Duration chip - metadataChip(icon: "hourglass", text: displayConversation.formattedDuration) + // Duration chip + metadataChip(icon: "hourglass", text: displayConversation.formattedDuration) - // Category chip - if !displayConversation.structured.category.isEmpty && displayConversation.structured.category != "other" { - metadataChip(icon: "tag", text: displayConversation.structured.category.capitalized) - } + // Category chip + if 
!displayConversation.structured.category.isEmpty + && displayConversation.structured.category != "other" + { + metadataChip(icon: "tag", text: displayConversation.structured.category.capitalized) + } - Spacer() - } + Spacer() } - - private var sourceChip: some View { - metadataChip(icon: "dot.radiowaves.left.and.right", text: sourceLabel) + } + + private var sourceChip: some View { + metadataChip(icon: "dot.radiowaves.left.and.right", text: sourceLabel) + } + + private var sourceLabel: String { + switch displayConversation.source { + case .desktop: return "Desktop" + case .omi: return "omi" + case .phone: return "Phone" + case .appleWatch: return "Apple Watch" + case .workflow: return "Workflow" + case .screenpipe: return "Screenpipe" + case .friend, .friendCom: return "Friend" + case .openglass: return "OpenGlass" + case .frame: return "Frame" + case .bee: return "Bee" + case .limitless: return "Limitless" + case .plaud: return "Plaud" + default: return "Unknown" } + } - private var sourceLabel: String { - switch displayConversation.source { - case .desktop: return "Desktop" - case .omi: return "omi" - case .phone: return "Phone" - case .appleWatch: return "Apple Watch" - case .workflow: return "Workflow" - case .screenpipe: return "Screenpipe" - case .friend, .friendCom: return "Friend" - case .openglass: return "OpenGlass" - case .frame: return "Frame" - case .bee: return "Bee" - case .limitless: return "Limitless" - case .plaud: return "Plaud" - default: return "Unknown" - } - } - - private func metadataChip(icon: String, text: String) -> some View { - HStack(spacing: 6) { - Image(systemName: icon) - .scaledFont(size: 11) - .foregroundColor(OmiColors.textTertiary) + private func metadataChip(icon: String, text: String) -> some View { + HStack(spacing: 6) { + Image(systemName: icon) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) - Text(text) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textSecondary) + Text(text) + .scaledFont(size: 
12) + .foregroundColor(OmiColors.textSecondary) + } + .padding(.horizontal, 10) + .padding(.vertical, 6) + .background( + Capsule() + .fill(OmiColors.backgroundTertiary) + ) + } + + // MARK: - App Results Section + + private var appResultsSection: some View { + VStack(alignment: .leading, spacing: 12) { + HStack { + Text("App Insights") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + + Spacer() + + Button(action: { showAppSelector = true }) { + HStack(spacing: 4) { + Image(systemName: "arrow.triangle.2.circlepath") + .scaledFont(size: 11) + Text("Reprocess") + .scaledFont(size: 12) + } + .foregroundColor(OmiColors.purplePrimary) } - .padding(.horizontal, 10) - .padding(.vertical, 6) - .background( - Capsule() - .fill(OmiColors.backgroundTertiary) + .buttonStyle(.plain) + .disabled(isReprocessing) + } + + ForEach(displayConversation.appsResults) { result in + AppResultCard( + result: result, + app: appProvider.apps.first { $0.id == result.appId } ) + } } - - // MARK: - App Results Section - - private var appResultsSection: some View { - VStack(alignment: .leading, spacing: 12) { - HStack { - Text("App Insights") - .scaledFont(size: 14, weight: .semibold) - .foregroundColor(OmiColors.textSecondary) - - Spacer() - - Button(action: { showAppSelector = true }) { - HStack(spacing: 4) { - Image(systemName: "arrow.triangle.2.circlepath") - .scaledFont(size: 11) - Text("Reprocess") - .scaledFont(size: 12) - } - .foregroundColor(OmiColors.purplePrimary) + } + + // MARK: - Suggested Apps Section + + private var suggestedAppsSection: some View { + VStack(alignment: .leading, spacing: 12) { + HStack { + Text("Try with Apps") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + + Spacer() + } + + let memoryApps = appProvider.apps.filter { + $0.capabilities.contains("memories") + && !displayConversation.appsResults.contains(where: { $0.appId == $0.id }) + }.prefix(4) + + if memoryApps.isEmpty && 
!appProvider.isLoading { + Text("Enable apps with memory capability to get additional insights") + .scaledFont(size: 13) + .foregroundColor(OmiColors.textTertiary) + .padding() + .frame(maxWidth: .infinity) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundSecondary) + ) + } else { + ScrollView(.horizontal, showsIndicators: false) { + HStack(spacing: 12) { + ForEach(Array(memoryApps)) { app in + SuggestedAppCard( + app: app, + isLoading: selectedAppForReprocess?.id == app.id && isReprocessing, + onTap: { + selectedAppForReprocess = app + Task { + await reprocessWithApp(app) + } } - .buttonStyle(.plain) - .disabled(isReprocessing) - } - - ForEach(displayConversation.appsResults) { result in - AppResultCard( - result: result, - app: appProvider.apps.first { $0.id == result.appId } - ) + ) } + } } + } } + } - // MARK: - Suggested Apps Section - - private var suggestedAppsSection: some View { - VStack(alignment: .leading, spacing: 12) { - HStack { - Text("Try with Apps") - .scaledFont(size: 14, weight: .semibold) - .foregroundColor(OmiColors.textSecondary) - - Spacer() - } + // MARK: - Reprocess - let memoryApps = appProvider.apps.filter { - $0.capabilities.contains("memories") && - !displayConversation.appsResults.contains(where: { $0.appId == $0.id }) - }.prefix(4) - - if memoryApps.isEmpty && !appProvider.isLoading { - Text("Enable apps with memory capability to get additional insights") - .scaledFont(size: 13) - .foregroundColor(OmiColors.textTertiary) - .padding() - .frame(maxWidth: .infinity) - .background( - RoundedRectangle(cornerRadius: 8) - .fill(OmiColors.backgroundSecondary) - ) - } else { - ScrollView(.horizontal, showsIndicators: false) { - HStack(spacing: 12) { - ForEach(Array(memoryApps)) { app in - SuggestedAppCard( - app: app, - isLoading: selectedAppForReprocess?.id == app.id && isReprocessing, - onTap: { - selectedAppForReprocess = app - Task { - await reprocessWithApp(app) - } - } - ) - } - } - } - } - } + private func 
reprocessWithApp(_ app: OmiApp) async { + isReprocessing = true + defer { + isReprocessing = false + selectedAppForReprocess = nil + showAppSelector = false } - // MARK: - Reprocess - - private func reprocessWithApp(_ app: OmiApp) async { - isReprocessing = true - defer { - isReprocessing = false - selectedAppForReprocess = nil - showAppSelector = false - } - - // Track reprocess - AnalyticsManager.shared.conversationReprocessed(conversationId: conversation.id, appId: app.id) + // Track reprocess + AnalyticsManager.shared.conversationReprocessed(conversationId: conversation.id, appId: app.id) - do { - try await APIClient.shared.reprocessConversation( - conversationId: conversation.id, - appId: app.id - ) - } catch { - logError("Failed to reprocess conversation", error: error) - } + do { + try await APIClient.shared.reprocessConversation( + conversationId: conversation.id, + appId: app.id + ) + } catch { + logError("Failed to reprocess conversation", error: error) } - - // MARK: - Action Items Section - - private var actionItemsSection: some View { - let activeItems = displayConversation.structured.actionItems.filter { !$0.deleted } - return VStack(alignment: .leading, spacing: 12) { - HStack(spacing: 8) { - Image(systemName: "checklist") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textSecondary) - - Text("Action Items") - .scaledFont(size: 16, weight: .semibold) - .foregroundColor(OmiColors.textSecondary) - - // Count badge - Text("\(activeItems.count)") - .scaledFont(size: 11, weight: .medium) - .foregroundColor(OmiColors.purplePrimary) - .padding(.horizontal, 8) - .padding(.vertical, 2) - .background( - Capsule() - .fill(OmiColors.purplePrimary.opacity(0.15)) - ) - - Spacer() - } - - VStack(alignment: .leading, spacing: 8) { - ForEach(activeItems) { item in - HStack(alignment: .top, spacing: 10) { - Image(systemName: item.completed ? "checkmark.circle.fill" : "circle") - .scaledFont(size: 16) - .foregroundColor(item.completed ? 
OmiColors.success : OmiColors.textTertiary) - - Text(item.description) - .scaledFont(size: 14) - .foregroundColor(item.completed ? OmiColors.textTertiary : OmiColors.textPrimary) - .textSelection(.enabled) - .strikethrough(item.completed, color: OmiColors.textTertiary) - } - .padding(12) - .frame(maxWidth: .infinity, alignment: .leading) - .background( - RoundedRectangle(cornerRadius: 12) - .fill(OmiColors.backgroundTertiary) - ) - .overlay( - RoundedRectangle(cornerRadius: 12) - .stroke(OmiColors.backgroundTertiary.opacity(0.3), lineWidth: 1) - ) - } - } + } + + // MARK: - Action Items Section + + private var actionItemsSection: some View { + let activeItems = displayConversation.structured.actionItems.filter { !$0.deleted } + return VStack(alignment: .leading, spacing: 12) { + HStack(spacing: 8) { + Image(systemName: "checklist") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textSecondary) + + Text("Action Items") + .scaledFont(size: 16, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + + // Count badge + Text("\(activeItems.count)") + .scaledFont(size: 11, weight: .medium) + .foregroundColor(OmiColors.purplePrimary) + .padding(.horizontal, 8) + .padding(.vertical, 2) + .background( + Capsule() + .fill(OmiColors.purplePrimary.opacity(0.15)) + ) + + Spacer() + } + + VStack(alignment: .leading, spacing: 8) { + ForEach(activeItems) { item in + HStack(alignment: .top, spacing: 10) { + Image(systemName: item.completed ? "checkmark.circle.fill" : "circle") + .scaledFont(size: 16) + .foregroundColor(item.completed ? OmiColors.success : OmiColors.textTertiary) + + Text(item.description) + .scaledFont(size: 14) + .foregroundColor(item.completed ? 
OmiColors.textTertiary : OmiColors.textPrimary) + .textSelection(.enabled) + .strikethrough(item.completed, color: OmiColors.textTertiary) + } + .padding(12) + .frame(maxWidth: .infinity, alignment: .leading) + .background( + RoundedRectangle(cornerRadius: 12) + .fill(OmiColors.backgroundTertiary) + ) + .overlay( + RoundedRectangle(cornerRadius: 12) + .stroke(OmiColors.backgroundTertiary.opacity(0.3), lineWidth: 1) + ) } + } } + } } #Preview { - ConversationDetailView( - conversation: ServerConversation.preview, - onBack: { } - ) - .frame(width: 600, height: 800) - .background(OmiColors.backgroundPrimary) + ConversationDetailView( + conversation: ServerConversation.preview, + onBack: {} + ) + .frame(width: 600, height: 800) + .background(OmiColors.backgroundPrimary) } // Preview helper extension ServerConversation { - static var preview: ServerConversation { - // This would need to be implemented with a proper initializer - // For now, previews won't work without mock data - fatalError("Preview not implemented") - } + static var preview: ServerConversation { + // This would need to be implemented with a proper initializer + // For now, previews won't work without mock data + fatalError("Preview not implemented") + } } // MARK: - App Result Card struct AppResultCard: View { - let result: AppResponse - let app: OmiApp? 
- - @State private var isExpanded = false - - var body: some View { - VStack(alignment: .leading, spacing: 10) { - // Header - HStack(spacing: 10) { - if let app = app { - AsyncImage(url: URL(string: app.image)) { phase in - switch phase { - case .success(let image): - image - .resizable() - .aspectRatio(contentMode: .fill) - default: - RoundedRectangle(cornerRadius: 8) - .fill(OmiColors.backgroundTertiary) - } - } - .frame(width: 32, height: 32) - .clipShape(RoundedRectangle(cornerRadius: 8)) - - VStack(alignment: .leading, spacing: 2) { - Text(app.name) - .scaledFont(size: 13, weight: .medium) - .foregroundColor(OmiColors.textPrimary) - - Text(app.author) - .scaledFont(size: 11) - .foregroundColor(OmiColors.textTertiary) - } - } else { - Image(systemName: "app.fill") - .scaledFont(size: 16) - .foregroundColor(OmiColors.textTertiary) - .frame(width: 32, height: 32) - .background(OmiColors.backgroundTertiary) - .clipShape(RoundedRectangle(cornerRadius: 8)) - - Text("App") - .scaledFont(size: 13, weight: .medium) - .foregroundColor(OmiColors.textPrimary) - } - - Spacer() - - Button(action: { withAnimation { isExpanded.toggle() } }) { - Image(systemName: isExpanded ? "chevron.up" : "chevron.down") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - } - .buttonStyle(.plain) + let result: AppResponse + let app: OmiApp? 
+ + @State private var isExpanded = false + + var body: some View { + VStack(alignment: .leading, spacing: 10) { + // Header + HStack(spacing: 10) { + if let app = app { + AsyncImage(url: URL(string: app.image)) { phase in + switch phase { + case .success(let image): + image + .resizable() + .aspectRatio(contentMode: .fill) + default: + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundTertiary) } + } + .frame(width: 32, height: 32) + .clipShape(RoundedRectangle(cornerRadius: 8)) + + VStack(alignment: .leading, spacing: 2) { + Text(app.name) + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text(app.author) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } + } else { + Image(systemName: "app.fill") + .scaledFont(size: 16) + .foregroundColor(OmiColors.textTertiary) + .frame(width: 32, height: 32) + .background(OmiColors.backgroundTertiary) + .clipShape(RoundedRectangle(cornerRadius: 8)) + + Text("App") + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + } - // Content - if isExpanded || result.content.count < 200 { - Text(result.content) - .scaledFont(size: 13) - .foregroundColor(OmiColors.textSecondary) - .textSelection(.enabled) - .lineSpacing(4) - } else { - Text(result.content.prefix(200) + "...") - .scaledFont(size: 13) - .foregroundColor(OmiColors.textSecondary) - .textSelection(.enabled) - .lineSpacing(4) - } + Spacer() - // "Generated by" footer - if let app = app { - HStack(spacing: 6) { - AsyncImage(url: URL(string: app.image)) { phase in - switch phase { - case .success(let image): - image - .resizable() - .aspectRatio(contentMode: .fill) - default: - RoundedRectangle(cornerRadius: 4) - .fill(OmiColors.backgroundTertiary) - } - } - .frame(width: 16, height: 16) - .clipShape(RoundedRectangle(cornerRadius: 4)) - - Text("Generated by \(app.name)") - .scaledFont(size: 11) - .foregroundColor(OmiColors.textTertiary) - } - .padding(.horizontal, 10) - 
.padding(.vertical, 5) - .background( - Capsule() - .fill(OmiColors.backgroundTertiary.opacity(0.6)) - ) + Button(action: { withAnimation { isExpanded.toggle() } }) { + Image(systemName: isExpanded ? "chevron.up" : "chevron.down") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } + .buttonStyle(.plain) + } + + // Content + if isExpanded || result.content.count < 200 { + Text(result.content) + .scaledFont(size: 13) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + .lineSpacing(4) + } else { + Text(result.content.prefix(200) + "...") + .scaledFont(size: 13) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + .lineSpacing(4) + } + + // "Generated by" footer + if let app = app { + HStack(spacing: 6) { + AsyncImage(url: URL(string: app.image)) { phase in + switch phase { + case .success(let image): + image + .resizable() + .aspectRatio(contentMode: .fill) + default: + RoundedRectangle(cornerRadius: 4) + .fill(OmiColors.backgroundTertiary) } + } + .frame(width: 16, height: 16) + .clipShape(RoundedRectangle(cornerRadius: 4)) + + Text("Generated by \(app.name)") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) } - .padding(14) + .padding(.horizontal, 10) + .padding(.vertical, 5) .background( - RoundedRectangle(cornerRadius: 12) - .fill(OmiColors.backgroundSecondary) + Capsule() + .fill(OmiColors.backgroundTertiary.opacity(0.6)) ) + } } + .padding(14) + .background( + RoundedRectangle(cornerRadius: 12) + .fill(OmiColors.backgroundSecondary) + ) + } } // MARK: - Suggested App Card struct SuggestedAppCard: View { - let app: OmiApp - let isLoading: Bool - let onTap: () -> Void - - @State private var isHovering = false - - var body: some View { - Button(action: onTap) { - VStack(spacing: 8) { - ZStack { - AsyncImage(url: URL(string: app.image)) { phase in - switch phase { - case .success(let image): - image - .resizable() - .aspectRatio(contentMode: .fill) - default: - 
RoundedRectangle(cornerRadius: 12) - .fill(OmiColors.backgroundTertiary) - } - } - .frame(width: 56, height: 56) - .clipShape(RoundedRectangle(cornerRadius: 12)) - - if isLoading { - RoundedRectangle(cornerRadius: 12) - .fill(Color.black.opacity(0.5)) - .frame(width: 56, height: 56) - - ProgressView() - .scaleEffect(0.7) - .tint(.white) - } - } - - Text(app.name) - .scaledFont(size: 11, weight: .medium) - .foregroundColor(OmiColors.textPrimary) - .lineLimit(1) + let app: OmiApp + let isLoading: Bool + let onTap: () -> Void + + @State private var isHovering = false + + var body: some View { + Button(action: onTap) { + VStack(spacing: 8) { + ZStack { + AsyncImage(url: URL(string: app.image)) { phase in + switch phase { + case .success(let image): + image + .resizable() + .aspectRatio(contentMode: .fill) + default: + RoundedRectangle(cornerRadius: 12) + .fill(OmiColors.backgroundTertiary) } - .frame(width: 80) - .padding(.vertical, 10) - .padding(.horizontal, 8) - .background( - RoundedRectangle(cornerRadius: 12) - .fill(isHovering ? OmiColors.backgroundTertiary : OmiColors.backgroundSecondary) - ) + } + .frame(width: 56, height: 56) + .clipShape(RoundedRectangle(cornerRadius: 12)) + + if isLoading { + RoundedRectangle(cornerRadius: 12) + .fill(Color.black.opacity(0.5)) + .frame(width: 56, height: 56) + + ProgressView() + .scaleEffect(0.7) + .tint(.white) + } } - .buttonStyle(.plain) - .disabled(isLoading) - .onHover { isHovering = $0 } + + Text(app.name) + .scaledFont(size: 11, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + .lineLimit(1) + } + .frame(width: 80) + .padding(.vertical, 10) + .padding(.horizontal, 8) + .background( + RoundedRectangle(cornerRadius: 12) + .fill(isHovering ? 
OmiColors.backgroundTertiary : OmiColors.backgroundSecondary) + ) } + .buttonStyle(.plain) + .disabled(isLoading) + .onHover { isHovering = $0 } + } } // MARK: - App Selector Sheet struct AppSelectorSheet: View { - let apps: [OmiApp] - let isLoading: Bool - let onSelect: (OmiApp) -> Void - let onDismiss: () -> Void - - @State private var selectedAppId: String? - - var body: some View { - VStack(spacing: 0) { - // Header - HStack { - Text("Select App") - .scaledFont(size: 16, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - - Spacer() - - Button(action: onDismiss) { - Image(systemName: "xmark.circle.fill") - .scaledFont(size: 20) - .foregroundColor(OmiColors.textTertiary) - } - .buttonStyle(.plain) - } - .padding() - - Divider() - .background(OmiColors.backgroundTertiary) - - // Apps list - if apps.isEmpty { - VStack(spacing: 12) { - Image(systemName: "square.grid.2x2") - .scaledFont(size: 40) - .foregroundColor(OmiColors.textTertiary) - - Text("No Apps Available") - .scaledFont(size: 14, weight: .medium) - .foregroundColor(OmiColors.textSecondary) - - Text("Enable apps with memory capability to reprocess conversations") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - .multilineTextAlignment(.center) - } - .frame(maxWidth: .infinity, maxHeight: .infinity) - .padding() - } else { - ScrollView { - LazyVStack(spacing: 2) { - ForEach(apps) { app in - AppSelectorRow( - app: app, - isSelected: selectedAppId == app.id, - isLoading: isLoading && selectedAppId == app.id - ) { - selectedAppId = app.id - onSelect(app) - } - } - } - .padding(.horizontal, 8) - .padding(.vertical, 8) - } + let apps: [OmiApp] + let isLoading: Bool + let onSelect: (OmiApp) -> Void + let onDismiss: () -> Void + + @State private var selectedAppId: String? 
+ + var body: some View { + VStack(spacing: 0) { + // Header + HStack { + Text("Select App") + .scaledFont(size: 16, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + Spacer() + + Button(action: onDismiss) { + Image(systemName: "xmark.circle.fill") + .scaledFont(size: 20) + .foregroundColor(OmiColors.textTertiary) + } + .buttonStyle(.plain) + } + .padding() + + Divider() + .background(OmiColors.backgroundTertiary) + + // Apps list + if apps.isEmpty { + VStack(spacing: 12) { + Image(systemName: "square.grid.2x2") + .scaledFont(size: 40) + .foregroundColor(OmiColors.textTertiary) + + Text("No Apps Available") + .scaledFont(size: 14, weight: .medium) + .foregroundColor(OmiColors.textSecondary) + + Text("Enable apps with memory capability to reprocess conversations") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .multilineTextAlignment(.center) + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + .padding() + } else { + ScrollView { + LazyVStack(spacing: 2) { + ForEach(apps) { app in + AppSelectorRow( + app: app, + isSelected: selectedAppId == app.id, + isLoading: isLoading && selectedAppId == app.id + ) { + selectedAppId = app.id + onSelect(app) + } } + } + .padding(.horizontal, 8) + .padding(.vertical, 8) } - .frame(width: 320, height: 400) - .background(OmiColors.backgroundPrimary) + } } + .frame(width: 320, height: 400) + .background(OmiColors.backgroundPrimary) + } } struct AppSelectorRow: View { - let app: OmiApp - let isSelected: Bool - let isLoading: Bool - let onSelect: () -> Void - - @State private var isHovering = false - - var body: some View { - Button(action: onSelect) { - HStack(spacing: 12) { - AsyncImage(url: URL(string: app.image)) { phase in - switch phase { - case .success(let image): - image - .resizable() - .aspectRatio(contentMode: .fill) - default: - RoundedRectangle(cornerRadius: 10) - .fill(OmiColors.backgroundTertiary) - } - } - .frame(width: 44, height: 44) - 
.clipShape(RoundedRectangle(cornerRadius: 10)) + let app: OmiApp + let isSelected: Bool + let isLoading: Bool + let onSelect: () -> Void + + @State private var isHovering = false + + var body: some View { + Button(action: onSelect) { + HStack(spacing: 12) { + AsyncImage(url: URL(string: app.image)) { phase in + switch phase { + case .success(let image): + image + .resizable() + .aspectRatio(contentMode: .fill) + default: + RoundedRectangle(cornerRadius: 10) + .fill(OmiColors.backgroundTertiary) + } + } + .frame(width: 44, height: 44) + .clipShape(RoundedRectangle(cornerRadius: 10)) - VStack(alignment: .leading, spacing: 2) { - Text(app.name) - .scaledFont(size: 13, weight: .medium) - .foregroundColor(OmiColors.textPrimary) + VStack(alignment: .leading, spacing: 2) { + Text(app.name) + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) - Text(app.author) - .scaledFont(size: 11) - .foregroundColor(OmiColors.textTertiary) - } + Text(app.author) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } - Spacer() + Spacer() - if isLoading { - ProgressView() - .scaleEffect(0.7) - } else if isSelected { - Image(systemName: "checkmark.circle.fill") - .scaledFont(size: 18) - .foregroundColor(OmiColors.purplePrimary) - } - } - .padding(.horizontal, 12) - .padding(.vertical, 10) - .background( - RoundedRectangle(cornerRadius: 10) - .fill(isSelected || isHovering ? OmiColors.backgroundTertiary : Color.clear) - ) + if isLoading { + ProgressView() + .scaleEffect(0.7) + } else if isSelected { + Image(systemName: "checkmark.circle.fill") + .scaledFont(size: 18) + .foregroundColor(OmiColors.purplePrimary) } - .buttonStyle(.plain) - .disabled(isLoading) - .onHover { isHovering = $0 } + } + .padding(.horizontal, 12) + .padding(.vertical, 10) + .background( + RoundedRectangle(cornerRadius: 10) + .fill(isSelected || isHovering ? 
OmiColors.backgroundTertiary : Color.clear) + ) } + .buttonStyle(.plain) + .disabled(isLoading) + .onHover { isHovering = $0 } + } } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift index 5e776ad46c7..245660d5d7b 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift @@ -202,6 +202,24 @@ struct ConversationsPage: View { .scaledFont(size: 18, weight: .semibold) .foregroundColor(OmiColors.textPrimary) + if isLocalDaemonMode { + HStack(spacing: 5) { + Circle() + .fill(OmiColors.success) + .frame(width: 6, height: 6) + Text("Local") + .scaledFont(size: 11, weight: .semibold) + } + .foregroundColor(OmiColors.success) + .padding(.horizontal, 8) + .padding(.vertical, 4) + .background( + Capsule() + .fill(OmiColors.success.opacity(0.12)) + ) + .help("Using local daemon mode") + } + Spacer() quickNoteButton diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index a151d03b53a..4a1a2f137df 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -224,6 +224,10 @@ struct SettingsContentView: View { // Loading states @State private var isLoadingSettings: Bool = false + @State private var backendHealth: LocalDaemonHealth? + @State private var backendSettings: [LocalDaemonSetting] = [] + @State private var backendStatusError: String? + @State private var isLoadingBackendStatus: Bool = false @State private var userSubscription: UserSubscriptionResponse? @State private var isLoadingSubscription: Bool = false @State private var subscriptionError: String? 
@@ -479,6 +483,7 @@ struct SettingsContentView: View { selectedSection = .advanced } loadBackendSettings() + refreshSelectedBackendStatus() loadSubscriptionInfo() // Sync transcription state with appState isTranscribing = appState.isTranscribing @@ -1874,18 +1879,23 @@ struct SettingsContentView: View { settingsCard(settingId: "planusage.overage") { VStack(alignment: .leading, spacing: 10) { HStack(spacing: 10) { - Image(systemName: info.excessQuestions > 0 - ? "dollarsign.circle.fill" - : "checkmark.circle.fill") - .scaledFont(size: 18) - .foregroundColor(info.excessQuestions > 0 + Image( + systemName: info.excessQuestions > 0 + ? "dollarsign.circle.fill" + : "checkmark.circle.fill" + ) + .scaledFont(size: 18) + .foregroundColor( + info.excessQuestions > 0 ? OmiColors.warning : OmiColors.success) - Text(info.excessQuestions > 0 - ? "Usage-based overage" - : "No overage yet this cycle") - .scaledFont(size: 14, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) + Text( + info.excessQuestions > 0 + ? "Usage-based overage" + : "No overage yet this cycle" + ) + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) Spacer() if info.excessQuestions > 0 { Text(String(format: "$%.2f", info.overageUsd)) @@ -1980,7 +1990,9 @@ struct SettingsContentView: View { .frame(minWidth: 440, minHeight: 360) } - private func overageExplainerRow(_ label: String, value: String, emphasized: Bool = false) -> some View { + private func overageExplainerRow(_ label: String, value: String, emphasized: Bool = false) + -> some View + { HStack { Text(label) .scaledFont(size: 12) @@ -2072,13 +2084,17 @@ struct SettingsContentView: View { // only show the hard "upgrade" copy on Free and other hard-capped // plans. 
if let info = overageInfo, info.isOveragePlan { - Text("You're past your included limit — extra usage is billed as overage at end of cycle.") - .scaledFont(size: 12) - .foregroundColor(OmiColors.warning) + Text( + "You're past your included limit — extra usage is billed as overage at end of cycle." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) } else { - Text("You've reached this month's limit. Upgrade your plan or wait until the next reset.") - .scaledFont(size: 12) - .foregroundColor(OmiColors.warning) + Text( + "You've reached this month's limit. Upgrade your plan or wait until the next reset." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) } } else if quota.percent >= 80.0 { Text("You're close to your monthly limit.") @@ -3095,9 +3111,11 @@ struct SettingsContentView: View { Text("Chat Prompt Lab") .scaledFont(size: 15, weight: .semibold) .foregroundColor(OmiColors.textPrimary) - Text("Iterate on chat system prompts with real questions, AI grading, and production ratings") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) + Text( + "Iterate on chat system prompts with real questions, AI grading, and production ratings" + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) } Spacer() Button("Open") { @@ -5751,6 +5769,8 @@ struct SettingsContentView: View { private var aboutSection: some View { VStack(spacing: 20) { + backendStatusCard + settingsCard(settingId: "about.version") { VStack(spacing: 16) { // App info @@ -5957,6 +5977,120 @@ struct SettingsContentView: View { } } + private var backendStatusCard: some View { + let target = DesktopBackendEnvironment.selectedBackendTarget + return settingsCard(settingId: "about.backend") { + VStack(alignment: .leading, spacing: 14) { + HStack(spacing: 10) { + Circle() + .fill(backendStatusColor(for: target.mode)) + .frame(width: 10, height: 10) + + VStack(alignment: .leading, spacing: 2) { + Text(backendModeTitle(for: target.mode)) + 
.scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + Text(target.baseURL) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .textSelection(.enabled) + } + + Spacer() + + if isLoadingBackendStatus { + ProgressView() + .controlSize(.small) + } else { + Button("Refresh") { + refreshSelectedBackendStatus() + } + .buttonStyle(.bordered) + } + } + + Divider() + .background(OmiColors.backgroundQuaternary) + + VStack(spacing: 8) { + backendStatusRow( + title: "Authentication", + value: target.requiresAuth ? "Firebase token required" : "No auth for loopback MVP" + ) + + if target.mode == .localDaemon { + backendStatusRow( + title: "Health", + value: backendHealth.map { "\($0.service) \($0.version)" } + ?? backendStatusError + ?? "Not checked" + ) + + backendStatusRow( + title: "Data Directory", + value: backendHealth?.dataDir ?? "Unavailable until /health responds" + ) + + backendStatusRow( + title: "Processing Provider", + value: localProcessingProviderStatus + ) + } else { + backendStatusRow(title: "Mode", value: "Omi-hosted backend services enabled") + } + } + } + } + } + + private var localProcessingProviderStatus: String { + let providerSetting = backendSettings.first { $0.key == "ai_provider" || $0.key == "provider" } + guard let providerSetting else { + return "Deterministic fallback" + } + if providerSetting.valueJson.contains("\"api_key\"") + || providerSetting.valueJson.contains("\"key\"") + { + return "OpenAI-compatible provider configured" + } + return "Deterministic fallback" + } + + private func backendStatusRow(title: String, value: String) -> some View { + HStack(alignment: .top, spacing: 12) { + Text(title) + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + .frame(width: 130, alignment: .leading) + Text(value) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + Spacer() + } + } + + private func backendModeTitle(for mode: 
DesktopBackendEnvironment.BackendMode) -> String { + switch mode { + case .cloud: + return "Backend Mode: Cloud" + case .localDaemon: + return "Backend Mode: Local Daemon" + case .customRemote: + return "Backend Mode: Custom Remote" + } + } + + private func backendStatusColor(for mode: DesktopBackendEnvironment.BackendMode) -> Color { + switch mode { + case .localDaemon: + return backendHealth == nil ? OmiColors.warning : OmiColors.success + case .cloud, .customRemote: + return OmiColors.purplePrimary + } + } + // MARK: - Helper Views private func fontShortcutRow(label: String, keys: String) -> some View { @@ -6125,7 +6259,8 @@ struct SettingsContentView: View { // on the right. Hide the user's current plan — they already see it above. // Neo ($20) | Operator ($49) | Architect ($200) — cheapest to premium let order = ["unlimited": 0, "operator": 1, "architect": 2] - return mergedPlanCatalog + return + mergedPlanCatalog .filter { !isCurrentSubscriptionPlan($0) } .sorted { lhs, rhs in let lhsOrder = order[lhs.id, default: Int.max] @@ -6171,7 +6306,7 @@ struct SettingsContentView: View { /// Operator→Unlimited remapping in `/v1/users/me/subscription`. private func isCurrentSubscriptionOperator() -> Bool { guard let subscription = userSubscription?.subscription, - let currentPriceId = subscription.currentPriceId + let currentPriceId = subscription.currentPriceId else { return false } for plan in subscriptionPlansForDisplay { guard plan.title == "Operator" else { continue } @@ -6805,7 +6940,8 @@ struct SettingsContentView: View { notificationsEnabled = notifications.enabled notificationFrequency = notifications.frequency // Mirror to UserDefaults so NotificationService can throttle without a backend roundtrip. 
- UserDefaults.standard.set(notifications.frequency, forKey: NotificationService.frequencyDefaultsKey) + UserDefaults.standard.set( + notifications.frequency, forKey: NotificationService.frequencyDefaultsKey) userLanguage = language.language recordingPermissionEnabled = recording.enabled privateCloudSyncEnabled = cloudSync.enabled @@ -6840,6 +6976,43 @@ struct SettingsContentView: View { } } + private func refreshSelectedBackendStatus() { + guard !isLoadingBackendStatus else { return } + isLoadingBackendStatus = true + backendStatusError = nil + + Task { + let target = DesktopBackendEnvironment.selectedBackendTarget + guard target.mode == .localDaemon else { + await MainActor.run { + backendHealth = nil + backendSettings = [] + isLoadingBackendStatus = false + } + return + } + + do { + async let health = APIClient.shared.checkSelectedBackendHealth() + async let settings = APIClient.shared.getSelectedBackendSettings() + let (resolvedHealth, resolvedSettings) = try await (health, settings) + await MainActor.run { + backendHealth = resolvedHealth + backendSettings = resolvedSettings + backendStatusError = nil + isLoadingBackendStatus = false + } + } catch { + await MainActor.run { + backendHealth = nil + backendSettings = [] + backendStatusError = error.localizedDescription + isLoadingBackendStatus = false + } + } + } + } + private func loadSubscriptionInfo() { guard !isLoadingSubscription else { return } isLoadingSubscription = true @@ -6864,10 +7037,13 @@ struct SettingsContentView: View { // the trial cache) — without this they'd stay paywalled until the // next app restart even after their Operator/Architect plan is active. 
if subscription.subscription.plan != .basic, - subscription.subscription.status == .active, - AppState.current?.isPaywalled == true { + subscription.subscription.status == .active, + AppState.current?.isPaywalled == true + { AppState.current?.isPaywalled = false - log("Paywall: cleared sticky flag — subscription \(subscription.subscription.plan.rawValue) is active") + log( + "Paywall: cleared sticky flag — subscription \(subscription.subscription.plan.rawValue) is active" + ) } isLoadingSubscription = false } @@ -6923,8 +7099,8 @@ struct SettingsContentView: View { // If user already has an active paid subscription (not canceled), use upgrade endpoint // to schedule the plan change at end of billing period (no double-charging) if hasPaidSubscription, - let subscription = userSubscription?.subscription, - !subscription.cancelAtPeriodEnd + let subscription = userSubscription?.subscription, + !subscription.cancelAtPeriodEnd { Task { do { diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index 7e363db1a1f..691216cd094 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -10,6 +10,21 @@ dependencies. For a complete user-test walkthrough, including desktop launch environment and transcript import, see `docs/local-mvp-runbook.md`. +Primary desktop local-mode user-test command: + +```bash +cd desktop +OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_SUPERVISE=1 \ +OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +./run.sh +``` + +This targets only the development app bundle and supervises the local daemon if +`/health` is not already reachable. + ```bash cd desktop/local-backend cargo run @@ -104,6 +119,10 @@ The helper creates a conversation, appends transcript segments, finalizes ingestion, waits for local processing, verifies search, and prints read/search commands for the imported conversation. 
+Local processing uses deterministic fallback unless an OpenAI-compatible +provider is configured through `PUT /v1/settings`; see the runbook for set and +clear commands. + ## Architecture And E2E Validation The durable MVP architecture note and validation checklist live in diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index e5314c4a8da..cb07e2eb5c0 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -16,16 +16,52 @@ conversation storage, transcript ingestion, processing fallback, and search. No Firebase, Omi Python backend, Rust cloud backend, Redis, Firestore, GCS, pusher, or agent-proxy credentials are required for the local daemon path. -## Start The Local Daemon +## Primary Desktop Launch -From the repo root: +For user-test runs, use one command from the repo root. This launches the +development app bundle only (`Omi Dev.app` / `com.omi.desktop-dev`), checks the +local daemon health endpoint, starts `desktop/local-backend` if needed, and +keeps Omi-hosted backend URLs deliberately invalid so accidental cloud routing +is obvious: ```bash -cd desktop/local-backend -cargo run +cd desktop +OMI_DESKTOP_BACKEND_MODE=local \ +OMI_LOCAL_DAEMON_SUPERVISE=1 \ +OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ +OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ +OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ +./run.sh ``` -The daemon listens on `127.0.0.1:8765` by default. To keep test data isolated: +Required environment: + +- `OMI_DESKTOP_BACKEND_MODE=local` selects the local daemon profile. +- `OMI_LOCAL_DAEMON_SUPERVISE=1` lets `desktop/run.sh` start the daemon when + `/health` is not already reachable. + +Recommended test-boundary environment: + +- `OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765` makes the daemon URL explicit. +- `OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001` makes accidental Python + backend calls fail locally. 
+- `OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002` makes accidental cloud Rust + backend calls fail locally. + +Optional daemon environment: + +- `OMI_LOCAL_BACKEND_DATA_DIR=/tmp/omi-local-mvp` isolates SQLite data. +- `OMI_LOCAL_BACKEND_PORT=` chooses a free loopback port. Keep + `OMI_LOCAL_DAEMON_URL` in sync with it. +- `OMI_LOCAL_DAEMON_LOG=/tmp/omi-local-backend-dev.log` changes the supervised + daemon log path. + +Do not use this launcher to manage `/Applications/omi.app`. + +## Manual Daemon Launch + +For API-only testing or when you do not want `desktop/run.sh` to supervise the +daemon, start it manually from the repo root: ```bash cd desktop/local-backend @@ -49,40 +85,28 @@ Expected signals: If health does not respond, check that the daemon terminal is still running and that no other process is already using the selected port. -## Launch Desktop In Local Daemon Mode - -For developer/user-test runs, the desktop launcher can supervise the local -daemon. It checks `/health`, starts `desktop/local-backend` only if the daemon -is unreachable, and stops only the daemon process it started when the launcher -exits: +To keep using a manually managed daemon, start it first and launch the desktop +app with the same local-mode environment: ```bash cd desktop OMI_DESKTOP_BACKEND_MODE=local \ -OMI_LOCAL_DAEMON_SUPERVISE=1 \ OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ ./run.sh ``` -The invalid cloud URLs make accidental cloud routing obvious during a user test. -Local conversation, transcript, memory, action item, settings, and search flows -should use `OMI_LOCAL_DAEMON_URL`. The launcher path targets the development app -bundle only (`Omi Dev.app` / `com.omi.desktop-dev`) and must not be used to -manage `/Applications/omi.app`. 
+## Confirm Local Mode In The App -To keep using a manually managed daemon, start it first and launch the desktop -app with the same local-mode environment: +The Conversations header shows a `Local` chip when the app is using local daemon +mode. Settings → About also includes a Backend Mode card with the selected +daemon URL, auth requirement, `/health` result, data directory, and whether +processing is using deterministic fallback or an OpenAI-compatible provider. -```bash -cd desktop -OMI_DESKTOP_BACKEND_MODE=local \ -OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 \ -OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001 \ -OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ -./run.sh -``` +Cloud-only folder and public share controls are hidden in local mode. Merge and +folder API calls fail locally before building an Omi-hosted request if they are +reached from another surface. ## Import Transcript Data @@ -173,6 +197,42 @@ After finalization, the conversation status should become `processed`. Search results should include the imported conversation when the query appears in the title, overview, or transcript text. +## Local Provider Configuration + +Processing works without provider keys by using deterministic fallback. 
To force +that path, clear provider settings: + +```bash +curl -X PUT http://127.0.0.1:8765/v1/settings \ + -H 'content-type: application/json' \ + -d '{"ai_provider": null, "provider": null}' +``` + +To test a direct OpenAI-compatible provider without editing source code, store +the provider configuration in the local daemon settings: + +```bash +curl -X PUT http://127.0.0.1:8765/v1/settings \ + -H 'content-type: application/json' \ + -d '{ + "ai_provider": { + "kind": "openai_compatible", + "base_url": "https://api.openai.com/v1", + "model": "gpt-4o-mini", + "api_key": "'"$OPENAI_API_KEY"'" + } + }' +``` + +Inspect the active settings: + +```bash +curl http://127.0.0.1:8765/v1/settings +``` + +Provider keys remain in the local daemon SQLite settings table and are sent +directly to the configured provider, not to Omi-hosted backend services. + ## What Works Without Omi-Hosted Services - Local daemon startup on loopback. From dc88c4222822d94eac59f744bee0c221fb61b06f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 10:16:30 +0700 Subject: [PATCH 15/58] Polish local daemon launch path --- desktop/.env.example | 3 ++- desktop/local-backend/docs/local-mvp-runbook.md | 3 +++ desktop/run.sh | 14 ++++++++++---- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/desktop/.env.example b/desktop/.env.example index 0af05c97cfa..f00efd9adb1 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -43,7 +43,8 @@ OMI_PYTHON_API_URL=https://api.omi.me # # Required for local mode: OMI_DESKTOP_BACKEND_MODE=local. # Optional: OMI_LOCAL_DAEMON_SUPERVISE=1, OMI_LOCAL_DAEMON_URL, -# OMI_LOCAL_BACKEND_DATA_DIR, OMI_LOCAL_BACKEND_PORT, OMI_LOCAL_DAEMON_LOG. +# OMI_LOCAL_BACKEND_DATA_DIR, OMI_LOCAL_BACKEND_PORT, OMI_LOCAL_DAEMON_LOG, +# OMI_CLEAN_STALE_CLONES=1. 
# OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 # OMI_LOCAL_DAEMON_SUPERVISE=1 diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index cb07e2eb5c0..5ad5ddff58c 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -55,6 +55,9 @@ Optional daemon environment: `OMI_LOCAL_DAEMON_URL` in sync with it. - `OMI_LOCAL_DAEMON_LOG=/tmp/omi-local-backend-dev.log` changes the supervised daemon log path. +- `OMI_CLEAN_STALE_CLONES=1` enables the broad home-directory cleanup for stale + `Omi Dev.app` clones. Local daemon mode skips that scan by default so the + primary launch command reaches the daemon preflight quickly. Do not use this launcher to manage `/Applications/omi.app`. diff --git a/desktop/run.sh b/desktop/run.sh index 0db58257756..570254b6d88 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -260,10 +260,16 @@ done find "$(dirname "$0")/../app/build" -name "$APP_NAME.app" -type d -exec rm -rf {} + 2>/dev/null || true # Kill stale app bundles from other repo clones (e.g. ~/omi-desktop/) # These confuse LaunchServices and get launched instead of the /Applications copy. -find "$HOME" -maxdepth 4 -name "$APP_NAME.app" -type d -not -path "$APP_BUNDLE" -not -path "$APP_PATH" 2>/dev/null | while read stale; do - substep "Removing stale clone: $stale" - rm -rf "$stale" -done +# In local daemon mode, keep the primary user-test command fast and avoid broad +# home-directory scans unless explicitly requested. +if ! 
is_local_daemon_mode || [ "${OMI_CLEAN_STALE_CLONES:-0}" = "1" ]; then + find "$HOME" -maxdepth 4 -name "$APP_NAME.app" -type d -not -path "$APP_BUNDLE" -not -path "$APP_PATH" 2>/dev/null | while read stale; do + substep "Removing stale clone: $stale" + rm -rf "$stale" + done +else + substep "Local daemon mode: skipping stale clone scan (set OMI_CLEAN_STALE_CLONES=1 to enable)" +fi if [ "${OMI_SKIP_TUNNEL:-0}" != "1" ]; then step "Starting Cloudflare quick tunnel..." From 0e4dddf4b3e10b5d6f502e6acbac361a4f00cb8e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 10:21:18 +0700 Subject: [PATCH 16/58] Add local-only MVP self-test --- .../Desktop/Tests/APIClientRoutingTests.swift | 67 ++++++++++++++++++- .../local-backend/docs/local-mvp-runbook.md | 15 +++++ desktop/local-backend/tools/e2e_smoke.sh | 27 +++++++- .../tools/local_only_self_test.sh | 24 +++++++ 4 files changed, 128 insertions(+), 5 deletions(-) create mode 100755 desktop/local-backend/tools/local_only_self_test.sh diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index d28c1f00000..6587c1226a0 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -128,6 +128,32 @@ private func assertRoutes( XCTAssertEqual(req.method, method, "\(label): wrong HTTP method", file: file, line: line) } +private func assertNoOmiHostedBackendRequests( + _ reqs: [CapturedRequest], + file: StaticString = #filePath, + line: UInt = #line +) { + let forbiddenHosts: Set = [ + "api.omi.me", + "api.omiapi.com", + "desktop-backend-hhibjajaja-uc.a.run.app", + "desktop-backend-dt5lrfkkoa-uc.a.run.app", + "omi-cloud-invalid", + "omi-rust-invalid", + ] + + XCTAssertFalse( + reqs.contains { request in + guard let host = request.url.host else { return false } + return forbiddenHosts.contains(host) || host.contains("firebase") + || host.hasSuffix(".firebaseio.com") || 
host.hasSuffix(".googleapis.com") + }, + "local-mode routing should not call Omi-hosted backend or Firebase endpoints", + file: file, + line: line + ) +} + private func assertUnavailable( _ error: Error?, capability: DesktopBackendEnvironment.Capability, @@ -482,13 +508,48 @@ final class APIClientRoutingTests: XCTestCase { try? await client.updateConversationTitle(id: "local-123", title: "Offline") try? await client.setConversationStarred(id: "local-123", starred: true) _ = try? await client.updateSelectedBackendSettings(["profile_name": "Offline"]) + try? await client.deleteConversation(id: "local-123") let requests = URLCapture.capturedRequests - XCTAssertEqual(requests.count, 6) + XCTAssertEqual(requests.count, 7) XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) - XCTAssertFalse( - requests.contains { $0.url.host == "omi-cloud-invalid" || $0.url.host == "omi-rust-invalid" }) + assertNoOmiHostedBackendRequests(requests) + } + + func testLocalModeTranscriptImportRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:9876", 1) + let client = await makeTestClient() + let segment = TranscriptionSegmentRecord( + sessionId: 42, + speaker: 0, + text: "Local transcript import should stay on loopback.", + startTime: 0, + endTime: 2.5, + segmentOrder: 0, + segmentId: "seg-local-0", + speakerLabel: "Speaker 1" + ) + + try? await client.appendLocalDaemonTranscriptSegment( + conversationId: "local-123", + segment: segment + ) + try? 
await client.finalizeLocalDaemonTranscript(conversationId: "local-123") + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 2) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertEqual(requests.map(\.method), ["POST", "POST"]) + XCTAssertTrue( + requests[0].url.path.contains("/v1/conversations/local-123/transcript-segments")) + XCTAssertTrue( + requests[1].url.path.contains("/v1/conversations/local-123/finalize-transcript")) + assertNoOmiHostedBackendRequests(requests) } func testLocalModeSetConversationStarredRoutesToLocalDaemonWithoutAuth() async { diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 5ad5ddff58c..6a9cd81b00e 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -16,6 +16,21 @@ conversation storage, transcript ingestion, processing fallback, and search. No Firebase, Omi Python backend, Rust cloud backend, Redis, Firestore, GCS, pusher, or agent-proxy credentials are required for the local daemon path. +## Automated Local-Only Self-Test + +Run the unattended MVP check from the repo root: + +```bash +desktop/local-backend/tools/local_only_self_test.sh +``` + +The self-test creates an isolated temp data directory, starts the local daemon +on a free loopback port, verifies health/profile/settings, conversation +create/read/update/delete, transcript append/finalize, search, processing +status, restart persistence, and then runs `APIClientRoutingTests` to assert +local-mode desktop actions stay on the local daemon without Firebase auth or +Omi-hosted backend requests. It prints a concise pass/fail summary at the end. + ## Primary Desktop Launch For user-test runs, use one command from the repo root. 
This launches the diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index b2e8dc30cd1..606489ffab9 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -189,5 +189,28 @@ assert_json_value "${persisted_file}" "transcript_segments.0.text" "Plan the bac persisted_search_file="$(request GET '/v1/search/conversations?q=backend')" assert_json_value "${persisted_search_file}" "results.0.conversation_id" "conv-e2e-smoke" -echo "Local backend E2E smoke passed." -echo "OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_URL=${BASE_URL}" +request DELETE /v1/conversations/conv-e2e-smoke >/dev/null + +deleted_list_file="$(request GET /v1/conversations)" +deleted_count="$(python3 - "${deleted_list_file}" <<'PY' +import json +import sys + +with open(sys.argv[1], "r", encoding="utf-8") as handle: + data = json.load(handle) +print(len(data["conversations"])) +PY +)" +if [[ "${deleted_count}" != "0" ]]; then + echo "Expected deleted conversation to be hidden from list, got ${deleted_count} conversations" >&2 + echo "Response file: ${deleted_list_file}" >&2 + exit 1 +fi + +cat < ${label}" + "$@" + summary+=("PASS ${label}") +} + +run_step "local daemon MVP API flow" \ + "${ROOT_DIR}/desktop/local-backend/tools/e2e_smoke.sh" + +run_step "desktop local-mode routing boundary" \ + xcrun swift test --package-path "${ROOT_DIR}/desktop/Desktop" --filter APIClientRoutingTests + +echo +echo "Local-only MVP self-test passed:" +printf -- "- %s\n" "${summary[@]}" From 2638675d15ee5507df17114fb90ceb7045e22fe3 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 10:26:52 +0700 Subject: [PATCH 17/58] Document desktop routing test command --- desktop/local-backend/docs/architecture.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/desktop/local-backend/docs/architecture.md b/desktop/local-backend/docs/architecture.md index 43bafa13c0b..6e705d4701e 100644 --- 
a/desktop/local-backend/docs/architecture.md +++ b/desktop/local-backend/docs/architecture.md @@ -102,8 +102,8 @@ cargo test Run focused desktop routing checks: ```bash -cd desktop/Desktop -swift test --filter APIClientRoutingTests +cd desktop +xcrun swift test --package-path Desktop --filter APIClientRoutingTests ``` Manual desktop local mode check: From 1d8d95bb41a44c7f4c7e2410826352e7bbbb55fe Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 12:06:47 +0700 Subject: [PATCH 18/58] Harden local daemon retry contract --- desktop/Desktop/Sources/APIClient.swift | 66 +++++++ desktop/Desktop/Sources/AppState.swift | 7 + .../Desktop/Tests/APIClientRoutingTests.swift | 51 ++++++ desktop/local-backend/README.md | 12 +- .../local-backend/docs/local-mvp-runbook.md | 20 +- desktop/local-backend/src/main.rs | 171 ++++++++++++++++++ desktop/local-backend/src/providers.rs | 63 +++++++ desktop/local-backend/src/routes.rs | 134 +++++++++++--- desktop/local-backend/tools/e2e_smoke.sh | 34 +++- .../local-backend/tools/import_transcript.py | 13 +- 10 files changed, 523 insertions(+), 48 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index f805b2dac44..9d5b0a8f8dd 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -218,6 +218,13 @@ actor APIClient { func updateSelectedBackendSettings(_ values: [String: String]) async throws -> [LocalDaemonSetting] + { + try await updateSelectedBackendSettings( + values.mapValues { LocalDaemonSettingUpdateValue.string($0) }) + } + + func updateSelectedBackendSettings(_ values: [String: LocalDaemonSettingUpdateValue]) async throws + -> [LocalDaemonSetting] { let target = selectedBackendTarget guard target.mode == .localDaemon else { @@ -360,6 +367,54 @@ struct LocalDaemonSetting: Decodable, Equatable { } } +enum LocalDaemonSettingUpdateValue: Encodable, Equatable, ExpressibleByStringLiteral, + ExpressibleByBooleanLiteral, 
ExpressibleByIntegerLiteral, ExpressibleByFloatLiteral +{ + case string(String) + case bool(Bool) + case int(Int) + case double(Double) + case object([String: LocalDaemonSettingUpdateValue]) + case array([LocalDaemonSettingUpdateValue]) + case null + + init(stringLiteral value: String) { + self = .string(value) + } + + init(booleanLiteral value: Bool) { + self = .bool(value) + } + + init(integerLiteral value: Int) { + self = .int(value) + } + + init(floatLiteral value: Double) { + self = .double(value) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .string(let value): + try container.encode(value) + case .bool(let value): + try container.encode(value) + case .int(let value): + try container.encode(value) + case .double(let value): + try container.encode(value) + case .object(let value): + try container.encode(value) + case .array(let value): + try container.encode(value) + case .null: + try container.encodeNil() + } + } +} + // MARK: - Conversation API extension APIClient { @@ -1740,6 +1795,17 @@ extension APIClient { /// Returns the processed conversation on success, nil on 404 (already processed). /// Throws on other errors. func forceProcessConversation() async throws -> ServerConversation? { + let target = selectedBackendTarget + guard target.mode != .localDaemon else { + throw APIError.featureUnavailable( + feature: "force_process_conversation", + reason: DesktopBackendEnvironment.unavailableReason( + for: .hostedTranscription, + in: target.mode + ) ?? "Force-processing is only available for hosted transcription sessions." 
+ ) + } + struct EmptyBody: Encodable {} do { diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 261bcff9074..f8725a1ce20 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1653,6 +1653,13 @@ class AppState: ObservableObject { return } + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + log( + "Transcription: Local daemon mode stopped capture; leaving session \(capturedSessionId.map(String.init) ?? "nil") for local transcript retry/finalize" + ) + return + } + do { if let conversation = try await APIClient.shared.forceProcessConversation() { // Validate the returned conversation matches the session we just stopped diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 6587c1226a0..70500759df2 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -495,6 +495,38 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) } + func testLocalModeStructuredProviderSettingsRouteToLocalDaemon() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + _ = try? await client.updateSelectedBackendSettings([ + "ai_provider": .object([ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:43210/v1", + "model": "stub-model", + "api_key": "local-test-key", + ]), + "local_first": true, + ]) + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8765, + pathContains: "v1/settings", method: "PUT", + label: "local structured provider settings") + XCTAssertNil(requests.first?.headers["Authorization"]) + + let body = requests.first?.body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? 
[String: Any] + } + let provider = body?["ai_provider"] as? [String: Any] + XCTAssertEqual(provider?["kind"] as? String, "openai_compatible") + XCTAssertEqual(provider?["base_url"] as? String, "http://127.0.0.1:43210/v1") + XCTAssertEqual(provider?["model"] as? String, "stub-model") + XCTAssertEqual(provider?["api_key"] as? String, "local-test-key") + XCTAssertEqual(body?["local_first"] as? Bool, true) + } + func testLocalModeMVPConversationFlowsIgnoreInvalidCloudURLs() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_PYTHON_API_URL", "http://omi-cloud-invalid:9001", 1) @@ -705,6 +737,25 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertTrue(URLCapture.capturedRequests.isEmpty) } + func testLocalModeForceProcessConversationFailsBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + do { + _ = try await client.forceProcessConversation() + XCTFail("expected force-process to be unavailable") + } catch { + guard case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for force-process, got \(error)") + return + } + XCTAssertEqual(feature, "force_process_conversation") + } + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + // -- Conversations: manual URL(string: baseURL + ...) paths (PATCH → Python) -- func testSetConversationStarredRoutesToPython() async { diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index 691216cd094..af354c7d59d 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -109,7 +109,9 @@ loopback daemon URL and do not require Firebase auth. Hosted transcription endpoints are intentionally unavailable in local daemon mode. Direct live STT parity is not part of this MVP unless a future direct provider path is added. 
For user testing, import transcript text or JSON -fixtures through the supported helper: +fixtures through the supported helper. Desktop local mode blocks hosted live +capture before network and leaves stopped local sessions for the retry/import +finalize path instead of calling Python force-process. ```bash desktop/local-backend/tools/import_transcript.py /path/to/transcript.txt @@ -117,11 +119,13 @@ desktop/local-backend/tools/import_transcript.py /path/to/transcript.txt The helper creates a conversation, appends transcript segments, finalizes ingestion, waits for local processing, verifies search, and prints read/search -commands for the imported conversation. +commands for the imported conversation. Stable client conversation, memory, and +action-item IDs are idempotent: exact replay returns the existing row, while a +conflicting replay returns HTTP 409. Local processing uses deterministic fallback unless an OpenAI-compatible -provider is configured through `PUT /v1/settings`; see the runbook for set and -clear commands. +provider is configured through structured `PUT /v1/settings` JSON; see the +runbook for local-stub set and clear commands. ## Architecture And E2E Validation diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 6a9cd81b00e..cf1445af08f 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -133,6 +133,11 @@ Direct live STT parity is not part of the current MVP unless a future direct provider path is added. For local MVP testing, import or append transcript text and then finalize processing. +The desktop app refuses hosted live capture before opening a WebSocket when +`OMI_DESKTOP_BACKEND_MODE=local`. Stopping a capture session in local mode does +not call Python force-process; any locally stored session data is left for the +local retry/import/finalize path with a log entry. 
+ Create a plain text fixture: ```bash @@ -165,7 +170,8 @@ For retry tests, pass a stable `--conversation-id` and run the same command again. The helper reuses the existing conversation, exact duplicate transcript segments return the existing row, and finalize returns an already active or current completed processing job instead of piling up duplicate queued work. A -different segment body at an existing `segment_index` returns HTTP 409. +different conversation payload for an existing `id`, or a different segment body +at an existing `segment_index`, returns HTTP 409. JSON fixtures are also supported. The file may be a list of segment strings, a list of segment objects, or an object with conversation fields plus `segments` @@ -226,8 +232,10 @@ curl -X PUT http://127.0.0.1:8765/v1/settings \ -d '{"ai_provider": null, "provider": null}' ``` -To test a direct OpenAI-compatible provider without editing source code, store -the provider configuration in the local daemon settings: +To test a direct OpenAI-compatible provider without editing source code, point +the provider configuration at a local stub or a user-managed endpoint. 
For local +MVP validation, prefer a loopback stub so the test cannot reach hosted Omi or +OpenAI services by accident: ```bash curl -X PUT http://127.0.0.1:8765/v1/settings \ @@ -235,9 +243,9 @@ curl -X PUT http://127.0.0.1:8765/v1/settings \ -d '{ "ai_provider": { "kind": "openai_compatible", - "base_url": "https://api.openai.com/v1", - "model": "gpt-4o-mini", - "api_key": "'"$OPENAI_API_KEY"'" + "base_url": "http://127.0.0.1:43210/v1", + "model": "local-stub", + "api_key": "local-test-key" } }' ``` diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 049f24d264c..c2a5c5b3f1c 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -320,6 +320,148 @@ mod tests { Ok(()) } + #[tokio::test] + async fn create_routes_are_idempotent_for_client_supplied_ids() -> Result<()> { + let app = test_app()?; + + let conversation_body = json!({ + "id": "conv-idempotent", + "session_id": "session-idempotent", + "title": "Replay safe", + "overview": "Same client payload", + "metadata": {"source": "test"} + }); + let first_conversation = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(conversation_body.clone()), + ) + .await?; + let second_conversation = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(conversation_body), + ) + .await?; + assert_eq!( + first_conversation["conversation"]["id"], + second_conversation["conversation"]["id"] + ); + assert_eq!( + first_conversation["conversation"]["created_at"], + second_conversation["conversation"]["created_at"] + ); + + let memory_body = json!({ + "id": "mem-idempotent", + "content": "User prefers local retries.", + "category": "preference", + "conversation_id": "conv-idempotent", + "metadata": {"source": "test"} + }); + request_json( + app.clone(), + Method::POST, + "/v1/memories", + Some(memory_body.clone()), + ) + .await?; + let replayed_memory = + request_json(app.clone(), Method::POST, "/v1/memories", 
Some(memory_body)).await?; + assert_eq!(replayed_memory["memory"]["id"], "mem-idempotent"); + + let action_item_body = json!({ + "id": "act-idempotent", + "conversation_id": "conv-idempotent", + "title": "Check retry path", + "description": "Replay the create call", + "status": "open", + "metadata": {"source": "test"} + }); + request_json( + app.clone(), + Method::POST, + "/v1/action-items", + Some(action_item_body.clone()), + ) + .await?; + let replayed_action_item = request_json( + app, + Method::POST, + "/v1/action-items", + Some(action_item_body), + ) + .await?; + assert_eq!(replayed_action_item["action_item"]["id"], "act-idempotent"); + + Ok(()) + } + + #[tokio::test] + async fn create_routes_return_conflict_for_same_id_different_payload() -> Result<()> { + let app = test_app()?; + + request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "id": "conv-conflict", + "session_id": "session-conflict", + "title": "Original" + })), + ) + .await?; + request_status( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "id": "conv-conflict", + "session_id": "session-conflict", + "title": "Changed" + })), + StatusCode::CONFLICT, + ) + .await?; + + request_json( + app.clone(), + Method::POST, + "/v1/memories", + Some(json!({"id": "mem-conflict", "content": "Original"})), + ) + .await?; + request_status( + app.clone(), + Method::POST, + "/v1/memories", + Some(json!({"id": "mem-conflict", "content": "Changed"})), + StatusCode::CONFLICT, + ) + .await?; + + request_json( + app.clone(), + Method::POST, + "/v1/action-items", + Some(json!({"id": "act-conflict", "title": "Original"})), + ) + .await?; + request_status( + app, + Method::POST, + "/v1/action-items", + Some(json!({"id": "act-conflict", "title": "Changed"})), + StatusCode::CONFLICT, + ) + .await?; + + Ok(()) + } + fn test_app() -> Result { let config = Config { bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), @@ -358,4 +500,33 @@ mod tests { ); 
Ok(serde_json::from_slice(&bytes)?) } + + async fn request_status( + app: Router, + method: Method, + uri: &str, + body: Option, + expected_status: StatusCode, + ) -> Result { + let request_body = match body { + Some(value) => Body::from(serde_json::to_vec(&value)?), + None => Body::empty(), + }; + let request = Request::builder() + .method(method) + .uri(uri) + .header("content-type", "application/json") + .body(request_body)?; + + let response = app.oneshot(request).await?; + let status = response.status(); + let bytes = to_bytes(response.into_body(), 1024 * 1024).await?; + assert_eq!( + status, + expected_status, + "unexpected status {status}: {}", + String::from_utf8_lossy(&bytes) + ); + Ok(serde_json::from_slice(&bytes)?) + } } diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index c83ac8402ee..5f915c985f7 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -146,6 +146,9 @@ pub fn load_openai_config(store: &Store) -> Result Result<()> { + let store = Store::open_in_memory()?; + let mut settings = Map::new(); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:43210/v1", + "model": "stub-model", + "api_key": "local-test-key" + }), + ); + store.settings().upsert_many(settings)?; + + let config = load_openai_config(&store)?.expect("provider should be configured"); + assert_eq!(config.base_url, "http://127.0.0.1:43210/v1"); + assert_eq!(config.model, "stub-model"); + assert_eq!(config.api_key, "local-test-key"); + + Ok(()) + } + + #[tokio::test] + async fn openai_compatible_provider_uses_local_stub_endpoint() -> Result<()> { + let app = Router::new().route( + "/v1/chat/completions", + post(|| async { + Json(json!({ + "choices": [{ + "message": { + "content": "{\"title\":\"Stub title\",\"overview\":\"Stub overview\",\"action_items\":[],\"memories\":[]}" + } + }] + })) + }), + ); + let listener = 
TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + tokio::spawn(async move { + axum::serve(listener, app) + .await + .expect("stub server failed"); + }); + + let provider = OpenAiCompatibleProvider::new(OpenAiCompatibleConfig { + base_url: format!("http://{addr}/v1"), + model: "stub-model".to_string(), + api_key: "local-test-key".to_string(), + }); + + let response = provider + .complete_json(vec![ChatMessage::user("Summarize locally.")]) + .await?; + assert_eq!(response["title"], "Stub title"); + assert_eq!(response["overview"], "Stub overview"); + + Ok(()) + } } diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index fb7a7fb1417..6df7575ce0a 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -197,17 +197,31 @@ async fn create_conversation( let id = request .id .unwrap_or_else(|| deterministic_id("conv", &[&session_id])); + let new_conversation = NewConversation { + id: id.clone(), + session_id, + title: request.title.unwrap_or_default(), + overview: request.overview.unwrap_or_default(), + started_at: request.started_at, + metadata: request.metadata, + }; + if let Some(existing) = state + .store + .conversations() + .get(&id) + .map_err(ApiError::internal)? + { + if conversation_matches_new(&existing, &new_conversation).map_err(ApiError::internal)? 
{ + return Ok(Json(json!({ "conversation": existing }))); + } + return Err(ApiError::conflict( + "conversation already exists with different content", + )); + } let conversation = state .store .conversations() - .create(NewConversation { - id, - session_id, - title: request.title.unwrap_or_default(), - overview: request.overview.unwrap_or_default(), - started_at: request.started_at, - metadata: request.metadata, - }) + .create(new_conversation) .map_err(ApiError::internal)?; Ok(Json(json!({ "conversation": conversation }))) } @@ -424,16 +438,31 @@ async fn create_memory( State(state): State, Json(request): Json, ) -> ApiResult { + let id = request.id.unwrap_or_else(|| local_id("mem")); + let new_memory = NewMemory { + id: id.clone(), + content: request.content, + category: request.category, + conversation_id: request.conversation_id, + metadata: request.metadata, + }; + if let Some(existing) = state + .store + .memories() + .get(&id) + .map_err(ApiError::internal)? + { + if memory_matches_new(&existing, &new_memory).map_err(ApiError::internal)? 
{ + return Ok(Json(json!({ "memory": existing }))); + } + return Err(ApiError::conflict( + "memory already exists with different content", + )); + } let memory = state .store .memories() - .create(NewMemory { - id: request.id.unwrap_or_else(|| local_id("mem")), - content: request.content, - category: request.category, - conversation_id: request.conversation_id, - metadata: request.metadata, - }) + .create(new_memory) .map_err(ApiError::internal)?; Ok(Json(json!({ "memory": memory }))) } @@ -518,18 +547,33 @@ async fn create_action_item( State(state): State, Json(request): Json, ) -> ApiResult { + let id = request.id.unwrap_or_else(|| local_id("act")); + let new_action_item = NewActionItem { + id: id.clone(), + conversation_id: request.conversation_id, + title: request.title, + description: request.description, + status: request.status, + due_at: request.due_at, + metadata: request.metadata, + }; + if let Some(existing) = state + .store + .action_items() + .get(&id) + .map_err(ApiError::internal)? + { + if action_item_matches_new(&existing, &new_action_item).map_err(ApiError::internal)? 
{ + return Ok(Json(json!({ "action_item": existing }))); + } + return Err(ApiError::conflict( + "action item already exists with different content", + )); + } let action_item = state .store .action_items() - .create(NewActionItem { - id: request.id.unwrap_or_else(|| local_id("act")), - conversation_id: request.conversation_id, - title: request.title, - description: request.description, - status: request.status, - due_at: request.due_at, - metadata: request.metadata, - }) + .create(new_action_item) .map_err(ApiError::internal)?; Ok(Json(json!({ "action_item": action_item }))) } @@ -713,3 +757,45 @@ fn local_id(prefix: &str) -> String { .unwrap_or_else(|| Utc::now().timestamp_micros() * 1000); deterministic_id(prefix, &[&now.to_string()]) } + +fn conversation_matches_new( + existing: &crate::storage::Conversation, + new: &NewConversation, +) -> anyhow::Result { + let mutable_fields_match = existing.status != "open" + || (existing.title == new.title && existing.overview == new.overview); + Ok(existing.id == new.id + && existing.session_id == new.session_id + && mutable_fields_match + && new + .started_at + .map(|started_at| existing.started_at == started_at) + .unwrap_or(true) + && json_matches_optional(&existing.metadata_json, &new.metadata)?) +} + +fn memory_matches_new(existing: &crate::storage::Memory, new: &NewMemory) -> anyhow::Result { + Ok(existing.id == new.id + && existing.content == new.content + && existing.category == new.category + && existing.conversation_id == new.conversation_id + && json_matches_optional(&existing.metadata_json, &new.metadata)?) 
+} + +fn action_item_matches_new( + existing: &crate::storage::ActionItem, + new: &NewActionItem, +) -> anyhow::Result { + Ok(existing.id == new.id + && existing.conversation_id == new.conversation_id + && existing.title == new.title + && existing.description == new.description.clone().unwrap_or_default() + && existing.status == new.status.clone().unwrap_or_else(|| "open".to_string()) + && existing.due_at == new.due_at + && json_matches_optional(&existing.metadata_json, &new.metadata)?) +} + +fn json_matches_optional(existing_json: &str, new: &Option) -> anyhow::Result { + let existing: Value = serde_json::from_str(existing_json)?; + Ok(existing == new.clone().unwrap_or_else(|| json!({}))) +} diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index 606489ffab9..960e4fca3a0 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -40,6 +40,24 @@ print(value) PY } +json_embedded_value() { + python3 - "$1" "$2" "$3" <<'PY' +import json +import sys + +path = sys.argv[1].split(".") +with open(sys.argv[2], "r", encoding="utf-8") as handle: + value = json.load(handle) +for part in path[:-1]: + if part.isdigit(): + value = value[int(part)] + else: + value = value[part] +embedded = json.loads(value[path[-1]]) +print(embedded[sys.argv[3]]) +PY +} + request() { local method="$1" local path="$2" @@ -165,6 +183,12 @@ job_file="$(request POST /v1/conversations/conv-e2e-smoke/finalize-transcript)" job_id="$(json_value "processing_job.id" "${job_file}")" completed_job_file="$(wait_for_completed_job "${job_id}")" assert_json_value "${completed_job_file}" "processing_job.status" "completed" +fallback_provider="$(json_embedded_value "processing_job.result_json" "${completed_job_file}" provider)" +if [[ "${fallback_provider}" != "fallback" ]]; then + echo "Expected fallback processing provider, got ${fallback_provider}" >&2 + echo "Response file: ${completed_job_file}" >&2 + exit 1 +fi 
status_file="$(request GET /v1/processing-jobs/status)" assert_json_value "${status_file}" "failed" "0" @@ -175,9 +199,15 @@ assert_json_value "${processed_file}" "conversation.title" "Plan the backend fre settings_file="$(request PUT /v1/settings '{ "local_first": true, - "provider.kind": "fallback" + "ai_provider": { + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:43210/v1", + "model": "local-stub", + "api_key": "local-test-key" + } }')" -assert_json_value "${settings_file}" "settings.0.key" "local_first" +assert_json_value "${settings_file}" "settings.0.key" "ai_provider" +assert_json_value "${settings_file}" "settings.1.key" "local_first" stop_daemon start_daemon diff --git a/desktop/local-backend/tools/import_transcript.py b/desktop/local-backend/tools/import_transcript.py index 37dc3dc985c..666276805ab 100755 --- a/desktop/local-backend/tools/import_transcript.py +++ b/desktop/local-backend/tools/import_transcript.py @@ -176,18 +176,7 @@ def main() -> int: conversation[key] = value request_json("GET", args.base_url, "/health") - if "id" in conversation: - existing = request_json( - "GET", - args.base_url, - f"/v1/conversations/{urllib.parse.quote(conversation['id'])}", - ok_statuses={200, 404}, - ) - created = existing.get("conversation") - else: - created = None - if created is None: - created = request_json("POST", args.base_url, "/v1/conversations", conversation)["conversation"] + created = request_json("POST", args.base_url, "/v1/conversations", conversation)["conversation"] conversation_id = created["id"] for index, segment in enumerate(segments): From d0ea4d435add71c61e999324990e3d39b13aab0a Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 12:13:56 +0700 Subject: [PATCH 19/58] Close desktop local-mode cloud leaks --- desktop/Desktop/Sources/APIClient.swift | 175 ++++++++++++++++-- desktop/Desktop/Sources/APIKeyService.swift | 12 +- desktop/Desktop/Sources/AuthService.swift | 58 ++++-- 
.../Desktop/Sources/Chat/AgentBridge.swift | 5 + .../FloatingControlBar/AgentPill.swift | 8 + desktop/Desktop/Sources/OmiApp.swift | 28 ++- .../Core/GeminiClient.swift | 3 + .../Services/EmbeddingService.swift | 3 + .../Services/SettingsSyncManager.swift | 12 ++ .../Desktop/Tests/APIClientRoutingTests.swift | 43 +++++ desktop/local-backend/README.md | 8 + .../local-backend/docs/local-mvp-runbook.md | 7 + 12 files changed, 326 insertions(+), 36 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 9d5b0a8f8dd..e70b7852edf 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -4016,6 +4016,10 @@ extension APIClient { /// Fetches daily summary settings func getDailySummarySettings() async throws -> DailySummarySettings { + if isUsingLocalDaemon { + return DailySummarySettings.localDefault + } + return try await get("v1/users/daily-summary-settings") } @@ -4023,6 +4027,18 @@ extension APIClient { func updateDailySummarySettings(enabled: Bool? = nil, hour: Int? = nil) async throws -> DailySummarySettings { + if isUsingLocalDaemon { + var settings = DailySummarySettings.localDefault + if let enabled { + settings.enabled = enabled + } + if let hour { + settings.hour = hour + } + settings.saveLocalDefault() + return settings + } + struct UpdateRequest: Encodable { let enabled: Bool? let hour: Int? @@ -4033,6 +4049,10 @@ extension APIClient { /// Fetches transcription preferences func getTranscriptionPreferences() async throws -> TranscriptionPreferences { + if isUsingLocalDaemon { + return TranscriptionPreferences.localDefault + } + return try await get("v1/users/transcription-preferences") } @@ -4040,6 +4060,18 @@ extension APIClient { func updateTranscriptionPreferences(singleLanguageMode: Bool? = nil, vocabulary: [String]? 
= nil) async throws -> TranscriptionPreferences { + if isUsingLocalDaemon { + var preferences = TranscriptionPreferences.localDefault + if let singleLanguageMode { + preferences.singleLanguageMode = singleLanguageMode + } + if let vocabulary { + preferences.vocabulary = vocabulary + } + preferences.saveLocalDefault() + return preferences + } + struct UpdateRequest: Encodable { let singleLanguageMode: Bool? let vocabulary: [String]? @@ -4055,11 +4087,22 @@ extension APIClient { /// Fetches user language preference func getUserLanguage() async throws -> UserLanguageResponse { + if isUsingLocalDaemon { + return UserLanguageResponse( + language: UserDefaults.standard.string(forKey: "transcriptionLanguage") ?? "en" + ) + } + return try await get("v1/users/language") } /// Updates user language preference func updateUserLanguage(_ language: String) async throws -> UserLanguageResponse { + if isUsingLocalDaemon { + UserDefaults.standard.set(language, forKey: "transcriptionLanguage") + return UserLanguageResponse(language: language) + } + struct UpdateRequest: Encodable { let language: String } @@ -4069,11 +4112,20 @@ extension APIClient { /// Fetches recording permission status func getRecordingPermission() async throws -> RecordingPermissionResponse { + if isUsingLocalDaemon { + return RecordingPermissionResponse.localDefault + } + return try await get("v1/users/store-recording-permission") } /// Sets recording permission func setRecordingPermission(enabled: Bool) async throws { + if isUsingLocalDaemon { + UserDefaults.standard.set(enabled, forKey: RecordingPermissionResponse.localDefaultsKey) + return + } + let url = URL(string: baseURL + "v1/users/store-recording-permission?value=\(enabled)")! 
var request = URLRequest(url: url) request.httpMethod = "POST" @@ -4089,6 +4141,10 @@ extension APIClient { /// Fetches private cloud sync setting func getPrivateCloudSync() async throws -> PrivateCloudSyncResponse { + if isUsingLocalDaemon { + return PrivateCloudSyncResponse(enabled: false) + } + try requireCapability(.cloudSync) return try await get("v1/users/private-cloud-sync") @@ -4113,6 +4169,10 @@ extension APIClient { /// Fetches notification settings func getNotificationSettings() async throws -> NotificationSettingsResponse { + if isUsingLocalDaemon { + return NotificationSettingsResponse.localDefault + } + return try await get("v1/users/notification-settings") } @@ -4120,6 +4180,18 @@ extension APIClient { func updateNotificationSettings(enabled: Bool? = nil, frequency: Int? = nil) async throws -> NotificationSettingsResponse { + if isUsingLocalDaemon { + var settings = NotificationSettingsResponse.localDefault + if let enabled { + settings.enabled = enabled + } + if let frequency { + settings.frequency = frequency + } + settings.saveLocalDefault() + return settings + } + struct UpdateRequest: Encodable { let enabled: Bool? let frequency: Int? @@ -4130,6 +4202,13 @@ extension APIClient { /// Fetches user profile func getUserProfile() async throws -> UserProfileResponse { + if isUsingLocalDaemon { + throw APIError.featureUnavailable( + feature: "user_profile", + reason: "Cloud profile sync is disabled in local daemon mode." + ) + } + return try await get("v1/users/profile") } @@ -4138,6 +4217,10 @@ extension APIClient { name: String? = nil, motivation: String? = nil, useCase: String? = nil, job: String? = nil, company: String? = nil ) async throws { + if isUsingLocalDaemon { + return + } + struct UpdateRequest: Encodable { let name: String? let motivation: String? 
@@ -4159,6 +4242,10 @@ extension APIClient { /// Fetches assistant settings from the backend func getAssistantSettings() async throws -> AssistantSettingsResponse { + if isUsingLocalDaemon { + return AssistantSettingsResponse() + } + return try await get("v1/users/assistant-settings") } @@ -4166,6 +4253,10 @@ extension APIClient { func updateAssistantSettings(_ settings: AssistantSettingsResponse) async throws -> AssistantSettingsResponse { + if isUsingLocalDaemon { + return settings + } + return try await patch("v1/users/assistant-settings", body: settings) } @@ -4313,26 +4404,60 @@ struct RebuildGraphResponse: Codable { /// Daily summary notification settings struct DailySummarySettings: Codable { - let enabled: Bool - let hour: Int + var enabled: Bool + var hour: Int + + static let enabledDefaultsKey = "local_daily_summary_enabled" + static let hourDefaultsKey = "local_daily_summary_hour" + + static var localDefault: DailySummarySettings { + DailySummarySettings( + enabled: UserDefaults.standard.object(forKey: enabledDefaultsKey) as? Bool ?? true, + hour: UserDefaults.standard.object(forKey: hourDefaultsKey) as? Int ?? 22 + ) + } + + func saveLocalDefault() { + UserDefaults.standard.set(enabled, forKey: Self.enabledDefaultsKey) + UserDefaults.standard.set(hour, forKey: Self.hourDefaultsKey) + } } /// Transcription preferences struct TranscriptionPreferences: Codable { - let singleLanguageMode: Bool - let vocabulary: [String] + var singleLanguageMode: Bool + var vocabulary: [String] enum CodingKeys: String, CodingKey { case singleLanguageMode = "single_language_mode" case vocabulary } + init(singleLanguageMode: Bool, vocabulary: [String]) { + self.singleLanguageMode = singleLanguageMode + self.vocabulary = vocabulary + } + init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) singleLanguageMode = try container.decodeIfPresent(Bool.self, forKey: .singleLanguageMode) ?? 
false vocabulary = try container.decodeIfPresent([String].self, forKey: .vocabulary) ?? [] } + + static var localDefault: TranscriptionPreferences { + let autoDetect = + UserDefaults.standard.object(forKey: "transcriptionAutoDetect") as? Bool ?? true + return TranscriptionPreferences( + singleLanguageMode: !autoDetect, + vocabulary: UserDefaults.standard.stringArray(forKey: "transcriptionVocabulary") ?? [] + ) + } + + func saveLocalDefault() { + UserDefaults.standard.set(!singleLanguageMode, forKey: "transcriptionAutoDetect") + UserDefaults.standard.set(vocabulary, forKey: "transcriptionVocabulary") + } } /// User language response @@ -4342,7 +4467,14 @@ struct UserLanguageResponse: Codable { /// Recording permission response struct RecordingPermissionResponse: Codable { - let enabled: Bool + var enabled: Bool + static let localDefaultsKey = "local_store_recording_permission" + + static var localDefault: RecordingPermissionResponse { + RecordingPermissionResponse( + enabled: UserDefaults.standard.object(forKey: localDefaultsKey) as? Bool ?? false + ) + } enum CodingKeys: String, CodingKey { case enabled = "store_recording_permission" @@ -4360,8 +4492,23 @@ struct PrivateCloudSyncResponse: Codable { /// Notification settings response struct NotificationSettingsResponse: Codable { - let enabled: Bool - let frequency: Int + var enabled: Bool + var frequency: Int + + static let enabledDefaultsKey = "local_notifications_enabled" + static let frequencyDefaultsKey = "notification_frequency" + + static var localDefault: NotificationSettingsResponse { + NotificationSettingsResponse( + enabled: UserDefaults.standard.object(forKey: enabledDefaultsKey) as? Bool ?? true, + frequency: UserDefaults.standard.object(forKey: frequencyDefaultsKey) as? Int ?? 
3 + ) + } + + func saveLocalDefault() { + UserDefaults.standard.set(enabled, forKey: Self.enabledDefaultsKey) + UserDefaults.standard.set(frequency, forKey: Self.frequencyDefaultsKey) + } /// Frequency level description var frequencyDescription: String { @@ -4703,13 +4850,13 @@ struct FloatingBarSettingsResponse: Codable { } struct AssistantSettingsResponse: Codable { - var shared: SharedAssistantSettingsResponse? - var focus: FocusSettingsResponse? - var task: TaskSettingsResponse? - var insight: InsightSettingsResponse? - var memory: MemorySettingsResponse? - var floatingBar: FloatingBarSettingsResponse? - var updateChannel: String? + var shared: SharedAssistantSettingsResponse? = nil + var focus: FocusSettingsResponse? = nil + var task: TaskSettingsResponse? = nil + var insight: InsightSettingsResponse? = nil + var memory: MemorySettingsResponse? = nil + var floatingBar: FloatingBarSettingsResponse? = nil + var updateChannel: String? = nil enum CodingKeys: String, CodingKey { case shared, focus, task diff --git a/desktop/Desktop/Sources/APIKeyService.swift b/desktop/Desktop/Sources/APIKeyService.swift index 60f28f19a16..2c0b1ccba08 100644 --- a/desktop/Desktop/Sources/APIKeyService.swift +++ b/desktop/Desktop/Sources/APIKeyService.swift @@ -94,6 +94,13 @@ final class APIKeyService: ObservableObject { func fetchKeys() async { loadError = nil + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + isLoaded = true + applyToEnvironment() + log("APIKeyService: Local daemon mode skips backend API key fetch") + return + } + // Retry up to 3 times with backoff for attempt in 1...3 { do { @@ -168,7 +175,10 @@ final class APIKeyService: ObservableObject { /// True when the app has enough configuration to start transcription and screen analysis. /// In proxy mode (OMI_DESKTOP_API_URL set), no client-side Deepgram/Gemini keys are needed. 
nonisolated static var keysAvailable: Bool { - getenv("GEMINI_API_KEY") != nil || getenv("OMI_DESKTOP_API_URL") != nil + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + return currentGeminiKey != nil || byokKey(.deepgram) != nil + } + return getenv("GEMINI_API_KEY") != nil || getenv("OMI_DESKTOP_API_URL") != nil } private nonisolated static func nonEmptyStatic(_ s: String?) -> String? { diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index 2daa308e366..3cf9656cfac 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -47,6 +47,9 @@ class AuthService { private var apiBaseURL: String { DesktopBackendEnvironment.pythonBaseURL() } + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } private var redirectURI: String { return "\(urlScheme)://auth/callback" } @@ -209,10 +212,15 @@ class AuthService { if let uid = user?.uid { Task { await RewindDatabase.shared.configure(userId: uid) } } - // Load name from backend profile (Firestore), then Firebase Auth as fallback - self?.loadNameFromBackendIfNeeded() - // Sync assistant settings from backend (fire-and-forget) - Task { await SettingsSyncManager.shared.syncFromServer() } + if self?.isLocalDaemonMode == true { + log("AUTH_LISTENER: local daemon mode skips backend profile and settings sync") + self?.loadNameFromFirebaseIfNeeded() + } else { + // Load name from backend profile (Firestore), then Firebase Auth as fallback + self?.loadNameFromBackendIfNeeded() + // Sync assistant settings from backend (fire-and-forget) + Task { await SettingsSyncManager.shared.syncFromServer() } + } } else { // Firebase has no user - check if we have a saved session (for dev builds where Keychain doesn't persist) let savedSignedIn = UserDefaults.standard.bool(forKey: self?.kAuthIsSignedIn ?? 
"") @@ -325,13 +333,19 @@ class AuthService { saveAuthState(isSignedIn: true, email: AuthState.shared.userEmail, userId: userId) Task { await RewindDatabase.shared.configure(userId: userId) } - if givenName.isEmpty { + if givenName.isEmpty && !isLocalDaemonMode { loadNameFromBackendIfNeeded() + } else if givenName.isEmpty { + loadNameFromFirebaseIfNeeded() } AnalyticsManager.shared.identify() AnalyticsManager.shared.signInCompleted(provider: "apple") - APIKeyService.shared.startFetchingKeys() + if isLocalDaemonMode { + log("OMI AUTH: Local daemon mode skips backend API key fetch after Apple sign-in") + } else { + APIKeyService.shared.startFetchingKeys() + } if !AnalyticsManager.isDevBuild { let sentryUser = User(userId: userId) @@ -341,7 +355,11 @@ class AuthService { } NSLog("OMI AUTH: Apple Sign in complete!") - fetchConversations() + if isLocalDaemonMode { + log("OMI AUTH: Local daemon mode skips cloud conversation fetch after Apple sign-in") + } else { + fetchConversations() + } } // MARK: - Sign in with Google (Web OAuth Flow) @@ -439,15 +457,21 @@ class AuthService { await RewindDatabase.shared.configure(userId: userId) // Try to load name from backend profile (Firestore), then Firebase Auth as fallback - if givenName.isEmpty { + if givenName.isEmpty && !isLocalDaemonMode { loadNameFromBackendIfNeeded() + } else if givenName.isEmpty { + loadNameFromFirebaseIfNeeded() } // Identify user first, then track sign-in completed // (identify must happen before events for PostHog person profiles to work) AnalyticsManager.shared.identify() AnalyticsManager.shared.signInCompleted(provider: provider) - APIKeyService.shared.startFetchingKeys() + if isLocalDaemonMode { + log("OMI AUTH: Local daemon mode skips backend API key fetch after sign-in") + } else { + APIKeyService.shared.startFetchingKeys() + } // Set Sentry user context for error tracking (skip in dev builds) if !AnalyticsManager.isDevBuild { @@ -460,7 +484,11 @@ class AuthService { NSLog("OMI AUTH: Sign in 
complete!") // Fetch conversations after successful sign-in - fetchConversations() + if isLocalDaemonMode { + log("OMI AUTH: Local daemon mode skips cloud conversation fetch after sign-in") + } else { + fetchConversations() + } } catch AuthError.cancelled { // User-initiated cancel: clear any stale error and stay silent. @@ -705,7 +733,7 @@ class AuthService { let isImpersonating = UserDefaults.standard.bool(forKey: "auth_isImpersonating") if isImpersonating { NSLog("OMI AUTH: Skipping Firebase displayName update (impersonation mode)") - } else if let user = Auth.auth().currentUser { + } else if !isLocalDaemonMode, let user = Auth.auth().currentUser { do { let changeRequest = user.createProfileChangeRequest() changeRequest.displayName = trimmedName @@ -717,13 +745,15 @@ class AuthService { } // Also save to backend profile (Firestore) so it persists across sign-in methods - if !isImpersonating { + if !isImpersonating && !isLocalDaemonMode { do { try await APIClient.shared.updateUserProfile(name: trimmedName) NSLog("OMI AUTH: Updated backend profile name to: %@", trimmedName) } catch { NSLog("OMI AUTH: Failed to update backend profile name (non-fatal): %@", error.localizedDescription) } + } else if isLocalDaemonMode { + NSLog("OMI AUTH: Local daemon mode persisted name locally only") } } @@ -750,6 +780,10 @@ class AuthService { /// but the user already has a name stored in Firestore from a previous sign-up. 
func loadNameFromBackendIfNeeded() { guard givenName.isEmpty else { return } + guard !isLocalDaemonMode else { + loadNameFromFirebaseIfNeeded() + return + } Task { do { let profile = try await APIClient.shared.getUserProfile() diff --git a/desktop/Desktop/Sources/Chat/AgentBridge.swift b/desktop/Desktop/Sources/Chat/AgentBridge.swift index 9b0dbfca507..423cfb4cf3b 100644 --- a/desktop/Desktop/Sources/Chat/AgentBridge.swift +++ b/desktop/Desktop/Sources/Chat/AgentBridge.swift @@ -159,6 +159,11 @@ actor AgentBridge { // SECURITY: if we can't get a Firebase token, refuse to start. The bridge // must NEVER fall back to ANTHROPIC_API_KEY as the Omi backend credential. if harnessMode == "piMono" { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("AgentBridge: pi-mono disabled in local daemon mode") + throw BridgeError.authMissing + } + let authService = await MainActor.run { AuthService.shared } let token: String do { diff --git a/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift b/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift index 3f43de984bc..59fa2f5a58c 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift @@ -115,6 +115,10 @@ final class AgentPillsManager: ObservableObject { } private static func runRouterCall(for query: String) async -> RouterDecision? { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("AgentPill: router skipped in local daemon mode") + return nil + } let baseURL = await APIClient.shared.rustBackendURL guard !baseURL.isEmpty else { log("AgentPill: router skipped — rustBackendURL empty, defaulting to chat") @@ -522,6 +526,10 @@ final class AgentPillsManager: ObservableObject { } fileprivate static func generateTitleAndAck(for query: String) async -> (title: String, ack: String)? 
{ + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("AgentPill: title gen skipped in local daemon mode") + return nil + } // Route through the desktop-backend's OpenAI-compatible proxy at // /v2/chat/completions instead of hitting api.anthropic.com directly. // This way we don't need a BYOK key (no partial-BYOK 403 risk), and diff --git a/desktop/Desktop/Sources/OmiApp.swift b/desktop/Desktop/Sources/OmiApp.swift index 7cd51f15d61..cd86871e01d 100644 --- a/desktop/Desktop/Sources/OmiApp.swift +++ b/desktop/Desktop/Sources/OmiApp.swift @@ -388,6 +388,8 @@ class AppDelegate: NSObject, NSApplicationDelegate, NSMenuDelegate { // Identify user if already signed in if AuthState.shared.isSignedIn { + let isLocalDaemonMode = + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon AnalyticsManager.shared.identify() // Set Sentry user context (now enabled for dev builds too) if let email = AuthState.shared.userEmail { @@ -397,18 +399,22 @@ class AppDelegate: NSObject, NSApplicationDelegate, NSMenuDelegate { AuthService.shared.displayName.isEmpty ? 
nil : AuthService.shared.displayName SentrySDK.setUser(sentryUser) } - // Fetch conversations on startup - AuthService.shared.fetchConversations() + if isLocalDaemonMode { + log("AppDelegate: signed-in local daemon launch skips cloud startup fetches") + } else { + // Fetch conversations on startup + AuthService.shared.fetchConversations() - // Fetch API keys from backend (keys are not bundled in the app) - APIKeyService.shared.startFetchingKeys() + // Fetch API keys from backend (keys are not bundled in the app) + APIKeyService.shared.startFetchingKeys() - // Fetch subscription plan for floating bar usage limits - Task { await FloatingBarUsageLimiter.shared.fetchPlan() } + // Fetch subscription plan for floating bar usage limits + Task { await FloatingBarUsageLimiter.shared.fetchPlan() } - // Check tier eligibility (at most once per day) - Task { - await TierManager.shared.checkTierIfNeeded() + // Check tier eligibility (at most once per day) + Task { + await TierManager.shared.checkTierIfNeeded() + } } // Report comprehensive settings state (at most once per day) @@ -1253,6 +1259,10 @@ class AppDelegate: NSObject, NSApplicationDelegate, NSMenuDelegate { func applicationDidBecomeActive(_ notification: Notification) { AnalyticsManager.shared.appBecameActive() // Sync remote assistant settings so server-side changes take effect promptly + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("AppDelegate: skipped activation settings sync in local daemon mode") + return + } Task { await SettingsSyncManager.shared.syncFromServer() } } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index f9d48a6c237..7b4d507cc0f 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -185,6 +185,9 @@ actor GeminiClient { /// Backend proxy base URL 
(from OMI_DESKTOP_API_URL env var) private static var proxyBaseURL: String { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + return "" + } if let cString = getenv("OMI_DESKTOP_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { return url.hasSuffix("/") ? url : url + "/" } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift index 403b1a845c1..4f54a776c3e 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift @@ -18,6 +18,9 @@ actor EmbeddingService { /// Backend proxy base URL (from OMI_DESKTOP_API_URL env var) private static var proxyBaseURL: String { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + return "" + } if let cString = getenv("OMI_DESKTOP_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { return url.hasSuffix("/") ? url : url + "/" } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/SettingsSyncManager.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/SettingsSyncManager.swift index ff5e68d060f..6b5dd9e436b 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/SettingsSyncManager.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/SettingsSyncManager.swift @@ -10,6 +10,10 @@ class SettingsSyncManager { /// Pull settings from server and apply non-nil values to local singletons. 
func syncFromServer() async { guard AuthService.shared.isSignedIn else { return } + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("SettingsSyncManager: skipped server sync in local daemon mode") + return + } do { let remote = try await APIClient.shared.getAssistantSettings() applyRemoteSettings(remote) @@ -21,6 +25,10 @@ class SettingsSyncManager { /// Push all current local settings to the server. func syncToServer() async { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("SettingsSyncManager: skipped server push in local daemon mode") + return + } let settings = buildFromLocal() do { let _ = try await APIClient.shared.updateAssistantSettings(settings) @@ -32,6 +40,10 @@ class SettingsSyncManager { /// Fire-and-forget partial update to server. func pushPartialUpdate(_ settings: AssistantSettingsResponse) { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("SettingsSyncManager: skipped partial server push in local daemon mode") + return + } Task { do { let _ = try await APIClient.shared.updateAssistantSettings(settings) diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 70500759df2..f0b2f27c513 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -737,6 +737,49 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertTrue(URLCapture.capturedRequests.isEmpty) } + func testLocalModeUserSettingsReturnLocalDefaultsBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? await client.getDailySummarySettings() + _ = try? 
await client.updateDailySummarySettings(enabled: false, hour: 8) + _ = try? await client.getNotificationSettings() + _ = try? await client.updateNotificationSettings(enabled: false, frequency: 1) + _ = try? await client.getUserLanguage() + _ = try? await client.updateUserLanguage("en") + _ = try? await client.getRecordingPermission() + try? await client.setRecordingPermission(enabled: true) + _ = try? await client.getTranscriptionPreferences() + _ = try? await client.updateTranscriptionPreferences( + singleLanguageMode: true, + vocabulary: ["omi", "local"] + ) + _ = try? await client.getAssistantSettings() + _ = try? await client.updateAssistantSettings(AssistantSettingsResponse()) + + let cloudSync = try? await client.getPrivateCloudSync() + XCTAssertEqual(cloudSync?.enabled, false) + + assertNoOmiHostedBackendRequests(URLCapture.capturedRequests) + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + + @MainActor + func testLocalModeAPIKeyServiceSkipsBackendFetch() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + + await APIKeyService.shared.fetchKeys() + + assertNoOmiHostedBackendRequests(URLCapture.capturedRequests) + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + func testLocalModeForceProcessConversationFailsBeforeNetworkRequests() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index af354c7d59d..de3ab5fa0e3 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -106,6 +106,14 @@ Unavailable capabilities fail before building a request to Omi-hosted services. 
Local conversation CRUD/search/settings flows continue to use the configured loopback daemon URL and do not require Firebase auth. +Signed-in local daemon mode keeps the signed-in identity for UI/account context, +but it does not run cloud startup sync. App launch, activation, Settings load, +profile-name edits, assistant settings sync, backend API-key fetch, subscription +refresh, quota refresh, managed agent VM setup, and Crisp support polling are +skipped, answered from local defaults, or fail before network. Assistant/chat +features that depend on the Omi Gemini/Anthropic/provider proxy remain +unavailable until direct local provider support is added for that surface. + Hosted transcription endpoints are intentionally unavailable in local daemon mode. Direct live STT parity is not part of this MVP unless a future direct provider path is added. For user testing, import transcript text or JSON diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index cf1445af08f..2bf1d92f569 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -268,6 +268,10 @@ directly to the configured provider, not to Omi-hosted backend services. - Local full-text search over conversation and transcript text. - Local memories, action items, profile, and settings endpoints. - Desktop routing for local MVP flows without Firebase auth. +- Signed-in local daemon sessions keep account UI state without cloud startup + sync: launch/activation do not fetch cloud conversations, assistant settings, + backend API keys, subscription state, quotas, profile data, managed agent VM + state, or Crisp support messages. ## What Still Needs Provider Keys Or Cloud Mode @@ -275,6 +279,9 @@ directly to the configured provider, not to Omi-hosted backend services. cloud mode today. 
- Omi backend provider proxies, quota checks, subscriptions, payments, public sharing, Crisp support, managed agent VMs, and cloud sync require cloud mode. +- Proactive assistant and chat paths that currently depend on Omi-hosted + Gemini/Anthropic/provider proxy endpoints are disabled in local daemon mode + unless the path has direct local provider configuration. - Remote AI provider calls from the local daemon require explicit local provider settings/API keys. Without them, processing uses deterministic fallback output. - Fully offline local LLM/STT support is outside the current MVP. From 75bb204b21620aefa69c9aa1c82f2b0515e80d2e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 12:22:52 +0700 Subject: [PATCH 20/58] Localize visible local-mode surfaces --- desktop/Desktop/Sources/APIClient.swift | 150 ++++++++++++++++++ .../Sources/MainWindow/Pages/TasksPage.swift | 28 ++-- .../Sources/Rewind/Core/GoalStorage.swift | 41 +++++ .../Sources/Rewind/Core/RewindDatabase.swift | 14 +- .../Desktop/Tests/APIClientRoutingTests.swift | 133 ++++++++++++++++ desktop/local-backend/src/main.rs | 67 +++++++- desktop/local-backend/src/routes.rs | 7 +- 7 files changed, 423 insertions(+), 17 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index e70b7852edf..bb282c4a414 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1950,11 +1950,31 @@ extension APIClient { /// Deletes a memory by ID func deleteMemory(id: String) async throws { + let target = selectedBackendTarget + if target.mode == .localDaemon { + try await delete("v1/memories/\(id)", requireAuth: false, customBaseURL: target.baseURL) + return + } + try await delete("v3/memories/\(id)") } /// Edits a memory's content func editMemory(id: String, content: String) async throws { + let target = selectedBackendTarget + if target.mode == .localDaemon { + struct LocalEditRequest: Encodable { + let content: 
String + } + let _: LocalMemoryEnvelope = try await patch( + "v1/memories/\(id)", + body: LocalEditRequest(content: content), + requireAuth: false, + customBaseURL: target.baseURL + ) + return + } + struct EditRequest: Encodable { let value: String } @@ -1964,6 +1984,13 @@ extension APIClient { /// Updates a memory's visibility func updateMemoryVisibility(id: String, visibility: String) async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "memory_visibility", + reason: "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." + ) + } + struct VisibilityRequest: Encodable { let value: String } @@ -1990,11 +2017,25 @@ extension APIClient { /// Marks all memories as read func markAllMemoriesRead() async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "memory_read_status", + reason: "Bulk memory read status is cloud-only metadata and is disabled in local daemon mode." + ) + } + let _: MemoryStatusResponse = try await post("v3/memories/mark-all-read", body: EmptyBody()) } /// Updates visibility of all memories func updateAllMemoriesVisibility(visibility: String) async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "memory_visibility", + reason: "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." + ) + } + struct VisibilityRequest: Encodable { let value: String } @@ -2004,6 +2045,13 @@ extension APIClient { /// Deletes all memories func deleteAllMemories() async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "memory_bulk_delete", + reason: "Bulk memory deletion is not exposed by the local daemon yet. Delete individual local memories instead." + ) + } + try await delete("v3/memories") } @@ -2320,11 +2368,29 @@ extension APIClient { let title: String? 
let description: String? let dueAt: String? + let includeDueAt: Bool + let clearDueAt: Bool let status: String? enum CodingKeys: String, CodingKey { case title, description, status case dueAt = "due_at" + case clearDueAt = "clear_due_at" + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encodeIfPresent(title, forKey: .title) + try container.encodeIfPresent(description, forKey: .description) + if includeDueAt { + if clearDueAt { + try container.encodeNil(forKey: .dueAt) + try container.encode(true, forKey: .clearDueAt) + } else { + try container.encodeIfPresent(dueAt, forKey: .dueAt) + } + } + try container.encodeIfPresent(status, forKey: .status) } } let response: LocalActionItemEnvelope = try await patch( @@ -2333,6 +2399,8 @@ extension APIClient { title: description, description: description, dueAt: dueAt.map { formatter.string(from: $0) }, + includeDueAt: clearDueAt || dueAt != nil, + clearDueAt: clearDueAt, status: completed.map { $0 ? "completed" : "open" } ), requireAuth: false, @@ -2442,6 +2510,10 @@ extension APIClient { /// Batch update relevance scores for multiple action items func batchUpdateScores(_ scores: [(id: String, score: Int)]) async throws { + if selectedBackendTarget.mode == .localDaemon { + return + } + struct ScoreUpdate: Encodable { let id: String let relevance_score: Int @@ -2461,6 +2533,10 @@ extension APIClient { func batchUpdateSortOrders(_ updates: [(id: String, sortOrder: Int, indentLevel: Int)]) async throws { + if selectedBackendTarget.mode == .localDaemon { + return + } + struct SortUpdate: Encodable { let id: String let sort_order: Int @@ -2625,6 +2701,10 @@ extension APIClient { /// Fetches all active goals (up to 4). Uses 5-second cache to deduplicate parallel calls. 
func getGoals() async throws -> [Goal] { + if selectedBackendTarget.mode == .localDaemon { + return try await GoalStorage.shared.getLocalGoals() + } + if let cache = goalsCache, let time = goalsCacheTime, Date().timeIntervalSince(time) < 5 { return cache } @@ -2679,6 +2759,33 @@ extension APIClient { source: source ) + if selectedBackendTarget.mode == .localDaemon { + let now = Date() + let record = GoalRecord( + backendId: "local_goal_\(UUID().uuidString)", + backendSynced: false, + title: title, + goalDescription: description, + goalType: goalType.rawValue, + targetValue: targetValue, + currentValue: currentValue, + minValue: minValue, + maxValue: maxValue, + unit: unit, + isActive: true, + completedAt: nil, + deleted: false, + createdAt: now, + updatedAt: now + ) + let inserted = try await GoalStorage.shared.insertLocalGoal(record) + guard let goal = inserted.toGoal() else { + throw APIError.invalidResponse + } + goalsCache = nil + return goal + } + let goal: Goal = try await post("v1/goals", body: request) goalsCache = nil return goal @@ -2686,6 +2793,15 @@ extension APIClient { /// Updates a goal's progress func updateGoalProgress(goalId: String, currentValue: Double) async throws -> Goal { + if selectedBackendTarget.mode == .localDaemon { + try await GoalStorage.shared.updateProgress(backendId: goalId, currentValue: currentValue) + guard let goal = try await GoalStorage.shared.getGoal(backendId: goalId) else { + throw APIError.invalidResponse + } + goalsCache = nil + return goal + } + let url = URL(string: baseURL + "v1/goals/\(goalId)/progress?current_value=\(currentValue)")! 
var request = URLRequest(url: url) request.httpMethod = "PATCH" @@ -2725,6 +2841,20 @@ extension APIClient { targetValue: targetValue ) + if selectedBackendTarget.mode == .localDaemon { + try await GoalStorage.shared.updateGoal( + backendId: goalId, + title: title, + currentValue: currentValue, + targetValue: targetValue + ) + guard let goal = try await GoalStorage.shared.getGoal(backendId: goalId) else { + throw APIError.invalidResponse + } + goalsCache = nil + return goal + } + let goal: Goal = try await patch("v1/goals/\(goalId)", body: request) goalsCache = nil return goal @@ -2732,12 +2862,26 @@ extension APIClient { /// Gets completed goals for history func getCompletedGoals() async throws -> [Goal] { + if selectedBackendTarget.mode == .localDaemon { + return try await GoalStorage.shared.getLocalGoals(activeOnly: false) + .filter { !$0.isActive || $0.completedAt != nil } + } + let goals: [Goal] = try await get("v1/goals/completed") return goals } /// Completes a goal (marks as inactive with completed_at) func completeGoal(id: String) async throws -> Goal { + if selectedBackendTarget.mode == .localDaemon { + try await GoalStorage.shared.markCompleted(backendId: id) + guard let goal = try await GoalStorage.shared.getGoal(backendId: id, activeOnly: false) else { + throw APIError.invalidResponse + } + goalsCache = nil + return goal + } + struct CompleteGoalRequest: Encodable { let is_active: Bool let completed_at: String @@ -2770,6 +2914,12 @@ extension APIClient { /// Deletes a goal func deleteGoal(id: String) async throws { + if selectedBackendTarget.mode == .localDaemon { + try await GoalStorage.shared.softDelete(backendId: id) + goalsCache = nil + return + } + try await delete("v1/goals/\(id)") goalsCache = nil } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/TasksPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/TasksPage.swift index 2ce80da7e52..5e6b89cbc18 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/TasksPage.swift +++ 
b/desktop/Desktop/Sources/MainWindow/Pages/TasksPage.swift @@ -1104,6 +1104,7 @@ class TasksViewModel: ObservableObject { /// Collect current sort orders from all categories and write to SQLite + backend private func syncSortOrders() async { var updates: [(id: String, sortOrder: Int, indentLevel: Int)] = [] + let isLocalDaemon = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon for category in TaskCategory.allCases { let orderedTasks = getOrderedTasks(for: category) @@ -1128,6 +1129,11 @@ class TasksViewModel: ObservableObject { log("TasksVM: Failed to write sort orders to SQLite: \(error)") } + if isLocalDaemon { + log("TasksVM: Saved \(updates.count) sort orders locally; backend batch sync disabled in local daemon mode") + return + } + // Sync to backend API do { try await APIClient.shared.batchUpdateSortOrders(updates) @@ -4347,17 +4353,19 @@ struct TaskRow: View { } // Share link button - Button { - Task { await copyShareLink() } - } label: { - Image(systemName: isCopyingLink ? "arrow.triangle.2.circlepath" : "arrowshape.turn.up.right.fill") - .scaledFont(size: 14) - .foregroundColor(OmiColors.textTertiary) - .frame(width: 24, height: 24) + if DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon { + Button { + Task { await copyShareLink() } + } label: { + Image(systemName: isCopyingLink ? 
"arrow.triangle.2.circlepath" : "arrowshape.turn.up.right.fill") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textTertiary) + .frame(width: 24, height: 24) + } + .buttonStyle(.plain) + .disabled(isCopyingLink) + .help("Copy share link") } - .buttonStyle(.plain) - .disabled(isCopyingLink) - .help("Copy share link") // Delete button Button { diff --git a/desktop/Desktop/Sources/Rewind/Core/GoalStorage.swift b/desktop/Desktop/Sources/Rewind/Core/GoalStorage.swift index 37908f323db..1700b20e601 100644 --- a/desktop/Desktop/Sources/Rewind/Core/GoalStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/GoalStorage.swift @@ -74,6 +74,23 @@ actor GoalStorage { } } + /// Get one local goal by backendId/local goal id. + func getGoal(backendId: String, activeOnly: Bool = true) async throws -> Goal? { + let db = try await ensureInitialized() + + return try await db.read { database in + var query = GoalRecord + .filter(Column("backendId") == backendId) + .filter(Column("deleted") == false) + + if activeOnly { + query = query.filter(Column("isActive") == true) + } + + return try query.fetchOne(database)?.toGoal() + } + } + // MARK: - Sync Operations /// Batch upsert from API response and reconcile. @@ -220,6 +237,30 @@ actor GoalStorage { } } + /// Update editable local goal fields by backendId/local goal id. 
+ func updateGoal( + backendId: String, + title: String, + currentValue: Double, + targetValue: Double + ) async throws { + let db = try await ensureInitialized() + + try await db.write { database in + guard var record = try GoalRecord + .filter(Column("backendId") == backendId) + .fetchOne(database) else { + throw GoalStorageError.recordNotFound + } + + record.title = title + record.currentValue = currentValue + record.targetValue = targetValue + record.updatedAt = Date() + try record.update(database) + } + } + /// Soft-delete a goal by backendId func softDelete(backendId: String) async throws { let db = try await ensureInitialized() diff --git a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift index 7381d242b73..9a34eb88938 100644 --- a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift +++ b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift @@ -122,7 +122,7 @@ actor RewindDatabase { /// Falls back to the static currentUserId (set synchronously at app start) when /// configure() hasn't been called yet (e.g., TierManager triggers init early). private func userBaseDirectory() -> URL { - let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + let appSupport = Self.applicationSupportDirectory() let userId = configuredUserId ?? RewindDatabase.currentUserId ?? "anonymous" return appSupport .appendingPathComponent("Omi", isDirectory: true) @@ -132,7 +132,7 @@ actor RewindDatabase { /// Static version of userBaseDirectory for nonisolated markCleanShutdown private static func staticUserBaseDirectory() -> URL { - let appSupport = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + let appSupport = applicationSupportDirectory() let userId = currentUserId ?? 
"anonymous" return appSupport .appendingPathComponent("Omi", isDirectory: true) @@ -140,6 +140,14 @@ actor RewindDatabase { .appendingPathComponent(userId, isDirectory: true) } + private static func applicationSupportDirectory() -> URL { + if let override = ProcessInfo.processInfo.environment["OMI_REWIND_DATABASE_ROOT"], + !override.isEmpty { + return URL(fileURLWithPath: override, isDirectory: true) + } + return FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + } + /// Mark a clean shutdown by removing the running flag file. /// Call from applicationWillTerminate to avoid unnecessary integrity checks on next launch. /// This is nonisolated so it can be called synchronously from the main thread during termination. @@ -351,7 +359,7 @@ actor RewindDatabase { /// Handles both first-time migration (DB move) and partial re-runs (directory merges). private func migrateFromLegacyPathIfNeeded(to userDir: URL) { let fileManager = FileManager.default - let appSupport = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! + let appSupport = Self.applicationSupportDirectory() let omiDir = appSupport.appendingPathComponent("Omi", isDirectory: true) // Determine migration source: prefer legacy root (Omi/omi.db), fall back to anonymous dir. 
diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index f0b2f27c513..aa8000fa549 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -406,6 +406,7 @@ final class APIClientRoutingTests: XCTestCase { setenv("OMI_DESKTOP_API_URL", "http://rust-test:9002", 1) unsetenv("OMI_DESKTOP_BACKEND_MODE") unsetenv("OMI_LOCAL_DAEMON_URL") + unsetenv("OMI_REWIND_DATABASE_ROOT") } override func tearDown() { @@ -413,6 +414,7 @@ final class APIClientRoutingTests: XCTestCase { unsetenv("OMI_DESKTOP_API_URL") unsetenv("OMI_DESKTOP_BACKEND_MODE") unsetenv("OMI_LOCAL_DAEMON_URL") + unsetenv("OMI_REWIND_DATABASE_ROOT") URLCapture.reset() super.tearDown() } @@ -841,6 +843,109 @@ final class APIClientRoutingTests: XCTestCase { label: "createMemory") } + func testLocalModeMemoryMutationsRouteToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? await client.createMemory(content: "local memory") as CreateMemoryResponse + try? await client.editMemory(id: "mem-local", content: "updated local memory") + try? 
await client.deleteMemory(id: "mem-local") + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 3) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 8765 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertEqual(requests.map(\.method), ["POST", "PATCH", "DELETE"]) + XCTAssertTrue(requests[0].url.path.contains("/v1/memories")) + XCTAssertTrue(requests[1].url.path.contains("/v1/memories/mem-local")) + XCTAssertTrue(requests[2].url.path.contains("/v1/memories/mem-local")) + } + + func testLocalModeCloudOnlyMemoryBulkOperationsFailBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + var errors: [Error] = [] + do { + try await client.updateMemoryVisibility(id: "mem-local", visibility: "public") + } catch { + errors.append(error) + } + do { + try await client.markAllMemoriesRead() + } catch { + errors.append(error) + } + do { + try await client.updateAllMemoriesVisibility(visibility: "private") + } catch { + errors.append(error) + } + do { + try await client.deleteAllMemories() + } catch { + errors.append(error) + } + + XCTAssertEqual(errors.count, 4) + XCTAssertTrue(errors.allSatisfy { + if case APIError.featureUnavailable = $0 { return true } + return false + }) + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + + func testLocalModeActionItemMutationsRouteToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? await client.getActionItems() + _ = try? await client.createActionItem(description: "local task", dueAt: nil) + _ = try? await client.updateActionItem(id: "act-local", completed: true, clearDueAt: true) + try? 
await client.deleteActionItem(id: "act-local") + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 4) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 8765 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertEqual(requests.map(\.method), ["GET", "POST", "PATCH", "DELETE"]) + XCTAssertTrue(requests[0].url.path.contains("/v1/action-items")) + XCTAssertTrue(requests[1].url.path.contains("/v1/action-items")) + XCTAssertTrue(requests[2].url.path.contains("/v1/action-items/act-local")) + XCTAssertTrue(requests[3].url.path.contains("/v1/action-items/act-local")) + + let body = requests[2].body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? [String: Any] + } + XCTAssertEqual(body?["status"] as? String, "completed") + XCTAssertTrue(body?.keys.contains("due_at") == true) + XCTAssertEqual(body?["clear_due_at"] as? Bool, true) + } + + func testLocalModeActionItemBatchAndShareDoNotCallCloud() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + try? await client.batchUpdateScores([(id: "act-local", score: 10)]) + try? 
await client.batchUpdateSortOrders([(id: "act-local", sortOrder: 1, indentLevel: 0)]) + do { + _ = try await client.shareTasks(taskIds: ["act-local"]) + XCTFail("expected task sharing to be unavailable") + } catch { + guard case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for sharing, got \(error)") + return + } + XCTAssertEqual(feature, DesktopBackendEnvironment.Capability.publicSharing.rawValue) + } + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + // -- Goals: manual URL path (PATCH → Python) -- func testUpdateGoalProgressRoutesToPython() async { @@ -852,6 +957,34 @@ final class APIClientRoutingTests: XCTestCase { label: "updateGoalProgress") } + func testLocalModeGoalAPIsUseLocalStorageBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let testUserId = "api-client-routing-goals-\(UUID().uuidString)" + let testRoot = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) + .appendingPathComponent("omi-rewind-routing-\(UUID().uuidString)", isDirectory: true) + setenv("OMI_REWIND_DATABASE_ROOT", testRoot.path, 1) + await RewindDatabase.shared.close() + await RewindDatabase.shared.configure(userId: testUserId) + await GoalStorage.shared.invalidateCache() + let client = await makeTestClient() + + _ = try? await client.getGoals() + _ = try? await client.createGoal(title: "Local goal", targetValue: 1) + _ = try? await client.updateGoalProgress(goalId: "missing-local-goal", currentValue: 1) + _ = try? await client.updateGoal(goalId: "missing-local-goal", title: "Updated", currentValue: 1, targetValue: 2) + _ = try? await client.getCompletedGoals() + _ = try? await client.completeGoal(id: "missing-local-goal") + try? 
await client.deleteGoal(id: "missing-local-goal") + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + + await RewindDatabase.shared.close() + await GoalStorage.shared.invalidateCache() + try? FileManager.default.removeItem(at: testRoot) + unsetenv("OMI_REWIND_DATABASE_ROOT") + } + // -- Apps (GET → Python) -- func testGetAppsRoutesToPython() async { diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index c2a5c5b3f1c..b045adfd799 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -200,6 +200,19 @@ mod tests { ) .await?; assert!(memory["memory"]["id"].is_string()); + let memory_id = memory["memory"]["id"].as_str().expect("memory id"); + + let updated_memory = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/memories/{memory_id}"), + Some(json!({"content": "Prefers local-only desktop mode"})), + ) + .await?; + assert_eq!( + updated_memory["memory"]["content"], + "Prefers local-only desktop mode" + ); let memories = request_json(app.clone(), Method::GET, "/v1/memories", None).await?; assert_eq!(memories["memories"].as_array().unwrap().len(), 1); @@ -212,9 +225,57 @@ mod tests { ) .await?; assert_eq!(action_item["action_item"]["status"], "open"); + let action_item_id = action_item["action_item"]["id"] + .as_str() + .expect("action item id"); + let updated_action_item = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/action-items/{action_item_id}"), + Some(json!({ + "status": "completed", + "due_at": "2026-05-19T12:00:00Z" + })), + ) + .await?; + assert_eq!(updated_action_item["action_item"]["status"], "completed"); + assert!(updated_action_item["action_item"]["completed_at"].is_string()); + assert_eq!( + updated_action_item["action_item"]["due_at"], + "2026-05-19T12:00:00Z" + ); + + let cleared_due_at = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/action-items/{action_item_id}"), + Some(json!({"clear_due_at": true, "due_at": null})), + ) + .await?; + 
assert!(cleared_due_at["action_item"]["due_at"].is_null()); + + request_status( + app.clone(), + Method::DELETE, + &format!("/v1/memories/{memory_id}"), + None, + StatusCode::NO_CONTENT, + ) + .await?; + let memories = request_json(app.clone(), Method::GET, "/v1/memories", None).await?; + assert!(memories["memories"].as_array().unwrap().is_empty()); + + request_status( + app.clone(), + Method::DELETE, + &format!("/v1/action-items/{action_item_id}"), + None, + StatusCode::NO_CONTENT, + ) + .await?; let action_items = request_json(app, Method::GET, "/v1/action-items", None).await?; - assert_eq!(action_items["action_items"].as_array().unwrap().len(), 1); + assert!(action_items["action_items"].as_array().unwrap().is_empty()); Ok(()) } @@ -507,7 +568,7 @@ mod tests { uri: &str, body: Option, expected_status: StatusCode, - ) -> Result { + ) -> Result<()> { let request_body = match body { Some(value) => Body::from(serde_json::to_vec(&value)?), None => Body::empty(), @@ -527,6 +588,6 @@ mod tests { "unexpected status {status}: {}", String::from_utf8_lossy(&bytes) ); - Ok(serde_json::from_slice(&bytes)?) 
+ Ok(()) } } diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 6df7575ce0a..b6b35fd9593 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -598,6 +598,7 @@ struct UpdateActionItemRequest { description: Option, status: Option, due_at: Option>, + clear_due_at: Option, metadata: Option, } @@ -616,7 +617,11 @@ async fn update_action_item( title: request.title, description: request.description, status: request.status, - due_at: request.due_at.map(Some), + due_at: if request.clear_due_at.unwrap_or(false) { + Some(None) + } else { + request.due_at.map(Some) + }, metadata: request.metadata, }, ) From 93aff7052553be47e86a3548fabb405a256112d1 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 12:31:07 +0700 Subject: [PATCH 21/58] Expand local MVP boundary verification --- desktop/Desktop/Sources/APIClient.swift | 17 +++ .../Sources/ScreenActivitySyncService.swift | 9 ++ .../Desktop/Sources/ViewModelContainer.swift | 9 +- .../Desktop/Tests/APIClientRoutingTests.swift | 25 +++- desktop/local-backend/tools/e2e_smoke.sh | 133 +++++++++++++++++- 5 files changed, 188 insertions(+), 5 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index bb282c4a414..b39b9c8522e 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -2926,6 +2926,10 @@ extension APIClient { /// Get all scores (daily, weekly, overall) with default tab selection func getScores(date: Date? = nil) async throws -> ScoreResponse { + if selectedBackendTarget.mode == .localDaemon { + return ScoreResponse.emptyLocal(date: date) + } + var endpoint = "v1/scores" if let date = date { let formatter = DateFormatter() @@ -3540,6 +3544,19 @@ struct ScoreResponse: Codable { case daily, weekly, overall, date case defaultTab = "default_tab" } + + static func emptyLocal(date: Date? 
= nil) -> ScoreResponse { + let empty = ScoreData(score: 0, completedTasks: 0, totalTasks: 0) + let formatter = DateFormatter() + formatter.dateFormat = "yyyy-MM-dd" + return ScoreResponse( + daily: empty, + weekly: empty, + overall: empty, + defaultTab: "daily", + date: formatter.string(from: date ?? Date()) + ) + } } // MARK: - App Models diff --git a/desktop/Desktop/Sources/ScreenActivitySyncService.swift b/desktop/Desktop/Sources/ScreenActivitySyncService.swift index 0e7a428fb77..8623127021a 100644 --- a/desktop/Desktop/Sources/ScreenActivitySyncService.swift +++ b/desktop/Desktop/Sources/ScreenActivitySyncService.swift @@ -28,6 +28,11 @@ actor ScreenActivitySyncService { /// Start the sync loop. Call after auth is established and database is ready. func start() { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + stop() + log("ScreenActivitySync: disabled in local daemon mode") + return + } guard !isRunning else { log("ScreenActivitySync: already running") return @@ -149,6 +154,10 @@ actor ScreenActivitySyncService { // MARK: - HTTP push private func pushRows(_ rows: [[String: Any]]) async -> Bool { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("ScreenActivitySync: refusing backend push in local daemon mode") + return false + } let payload: [String: Any] = ["rows": rows] guard let jsonData = try? 
JSONSerialization.data(withJSONObject: payload) else { diff --git a/desktop/Desktop/Sources/ViewModelContainer.swift b/desktop/Desktop/Sources/ViewModelContainer.swift index fdf7f90da18..ad3d3d74d01 100644 --- a/desktop/Desktop/Sources/ViewModelContainer.swift +++ b/desktop/Desktop/Sources/ViewModelContainer.swift @@ -110,8 +110,13 @@ class ViewModelContainer: ObservableObject { // Wire task chat coordinator to view model for delete/purge operations tasksViewModel.chatCoordinator = taskChatCoordinator - // Start screen activity sync to backend (Firestore + Pinecone) - await ScreenActivitySyncService.shared.start() + // Start cloud screen activity sync only when cloud backends are selected. + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + await ScreenActivitySyncService.shared.stop() + log("DATA LOAD: Skipping ScreenActivitySyncService in local daemon mode") + } else { + await ScreenActivitySyncService.shared.start() + } } isLoading = false diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index aa8000fa549..aefe55f2eaa 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -537,6 +537,7 @@ final class APIClientRoutingTests: XCTestCase { let client = await makeTestClient() _ = try? await client.getConversations() + _ = try? await client.getConversationsCount() _ = try? await client.getConversation(id: "local-123") as ServerConversation _ = try? await client.searchConversations(query: "offline") try? await client.updateConversationTitle(id: "local-123", title: "Offline") @@ -545,9 +546,11 @@ final class APIClientRoutingTests: XCTestCase { try? 
await client.deleteConversation(id: "local-123") let requests = URLCapture.capturedRequests - XCTAssertEqual(requests.count, 7) + XCTAssertEqual(requests.count, 8) XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 9876 }) + XCTAssertTrue(requests.allSatisfy { $0.url.scheme == "http" }) XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertTrue(requests.contains { $0.url.path == "/v1/conversations/count" }) assertNoOmiHostedBackendRequests(requests) } @@ -778,6 +781,26 @@ final class APIClientRoutingTests: XCTestCase { await APIKeyService.shared.fetchKeys() + XCTAssertTrue(APIKeyService.shared.isLoaded) + XCTAssertNil(APIKeyService.shared.loadError) + assertNoOmiHostedBackendRequests(URLCapture.capturedRequests) + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + + func testLocalModeDashboardScoresReturnLocalDefaultBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + let scores = try? 
await client.getScores(date: Date(timeIntervalSince1970: 0)) + + XCTAssertEqual(scores?.daily.score, 0) + XCTAssertEqual(scores?.weekly.totalTasks, 0) + XCTAssertEqual(scores?.overall.completedTasks, 0) + XCTAssertEqual(scores?.defaultTab, "daily") + XCTAssertEqual(scores?.date, "1970-01-01") assertNoOmiHostedBackendRequests(URLCapture.capturedRequests) XCTAssertTrue(URLCapture.capturedRequests.isEmpty) } diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index 960e4fca3a0..c1fe66221e4 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -10,12 +10,25 @@ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: print(s.getsockname()[1]) PY )}" +PROVIDER_PORT="${OMI_LOCAL_PROVIDER_STUB_PORT:-$(python3 - <<'PY' +import socket +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + print(s.getsockname()[1]) +PY +)}" BASE_URL="http://127.0.0.1:${PORT}" LOG_FILE="${DATA_DIR}/daemon.log" +PROVIDER_LOG_FILE="${DATA_DIR}/provider-requests.jsonl" DAEMON_PID="" +PROVIDER_PID="" cleanup() { + if [[ -n "${PROVIDER_PID}" ]] && kill -0 "${PROVIDER_PID}" >/dev/null 2>&1; then + kill "${PROVIDER_PID}" >/dev/null 2>&1 || true + wait "${PROVIDER_PID}" >/dev/null 2>&1 || true + fi if [[ -n "${DAEMON_PID}" ]] && kill -0 "${DAEMON_PID}" >/dev/null 2>&1; then kill "${DAEMON_PID}" >/dev/null 2>&1 || true wait "${DAEMON_PID}" >/dev/null 2>&1 || true @@ -120,6 +133,71 @@ stop_daemon() { DAEMON_PID="" } +start_provider_stub() { + local port="$1" + : >"${PROVIDER_LOG_FILE}" + PROVIDER_LOG_FILE="${PROVIDER_LOG_FILE}" python3 - "$port" <<'PY' & +import json +import os +import sys +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +port = int(sys.argv[1]) +log_file = os.environ["PROVIDER_LOG_FILE"] + +class Handler(BaseHTTPRequestHandler): + def do_POST(self): + length = int(self.headers.get("content-length", "0")) + body = 
self.rfile.read(length).decode("utf-8") if length else "" + with open(log_file, "a", encoding="utf-8") as handle: + handle.write(json.dumps({ + "path": self.path, + "authorization": self.headers.get("authorization"), + "body": json.loads(body) if body else {}, + }) + "\n") + payload = { + "choices": [{ + "message": { + "content": json.dumps({ + "title": "Stub provider title", + "overview": "Stub provider overview", + "action_items": [{ + "title": "Review local provider egress", + "description": "Confirm processing only calls the configured loopback provider." + }], + "memories": [{ + "content": "Local provider smoke used loopback only.", + "category": "validation" + }] + }) + } + }] + } + encoded = json.dumps(payload).encode("utf-8") + self.send_response(200) + self.send_header("content-type", "application/json") + self.send_header("content-length", str(len(encoded))) + self.end_headers() + self.wfile.write(encoded) + + def log_message(self, *_): + return + +ThreadingHTTPServer(("127.0.0.1", port), Handler).serve_forever() +PY + PROVIDER_PID="$!" 
+ + for _ in $(seq 1 40); do + if kill -0 "${PROVIDER_PID}" >/dev/null 2>&1; then + return + fi + sleep 0.1 + done + + echo "Provider stub failed to start" >&2 + exit 1 +} + wait_for_completed_job() { local job_id="$1" local job_file @@ -201,7 +279,7 @@ settings_file="$(request PUT /v1/settings '{ "local_first": true, "ai_provider": { "kind": "openai_compatible", - "base_url": "http://127.0.0.1:43210/v1", + "base_url": "http://127.0.0.1:'"${PROVIDER_PORT}"'/v1", "model": "local-stub", "api_key": "local-test-key" } @@ -209,6 +287,56 @@ settings_file="$(request PUT /v1/settings '{ assert_json_value "${settings_file}" "settings.0.key" "ai_provider" assert_json_value "${settings_file}" "settings.1.key" "local_first" +start_provider_stub "${PROVIDER_PORT}" + +provider_conversation_file="$(request POST /v1/conversations '{ + "id": "conv-provider-smoke", + "session_id": "session-provider-smoke", + "title": "", + "overview": "" +}')" +assert_json_value "${provider_conversation_file}" "conversation.id" "conv-provider-smoke" + +provider_segment_file="$(request POST /v1/conversations/conv-provider-smoke/transcript-segments '{ + "id": "seg-provider-smoke-0", + "text": "Use the configured loopback provider for local processing.", + "start_ms": 0, + "end_ms": 1600, + "segment_index": 0, + "source": "smoke" +}')" +assert_json_value "${provider_segment_file}" "transcript_segment.id" "seg-provider-smoke-0" + +provider_job_file="$(request POST /v1/conversations/conv-provider-smoke/finalize-transcript)" +provider_job_id="$(json_value "processing_job.id" "${provider_job_file}")" +completed_provider_job_file="$(wait_for_completed_job "${provider_job_id}")" +assert_json_value "${completed_provider_job_file}" "processing_job.status" "completed" +configured_provider="$(json_embedded_value "processing_job.result_json" "${completed_provider_job_file}" provider)" +if [[ "${configured_provider}" != "openai_compatible" ]]; then + echo "Expected openai_compatible processing provider, got 
${configured_provider}" >&2 + echo "Response file: ${completed_provider_job_file}" >&2 + exit 1 +fi +provider_processed_file="$(request GET /v1/conversations/conv-provider-smoke)" +assert_json_value "${provider_processed_file}" "conversation.title" "Stub provider title" + +provider_request_count="$(wc -l <"${PROVIDER_LOG_FILE}" | tr -d ' ')" +if [[ "${provider_request_count}" != "1" ]]; then + echo "Expected one local provider request, got ${provider_request_count}" >&2 + cat "${PROVIDER_LOG_FILE}" >&2 || true + exit 1 +fi +provider_request_path="$(json_value "path" "${PROVIDER_LOG_FILE}")" +provider_request_auth="$(json_value "authorization" "${PROVIDER_LOG_FILE}")" +if [[ "${provider_request_path}" != "/v1/chat/completions" ]]; then + echo "Expected local provider chat completions path, got ${provider_request_path}" >&2 + exit 1 +fi +if [[ "${provider_request_auth}" != "Bearer local-test-key" ]]; then + echo "Expected configured local provider bearer token, got ${provider_request_auth}" >&2 + exit 1 +fi + stop_daemon start_daemon @@ -220,6 +348,7 @@ persisted_search_file="$(request GET '/v1/search/conversations?q=backend')" assert_json_value "${persisted_search_file}" "results.0.conversation_id" "conv-e2e-smoke" request DELETE /v1/conversations/conv-e2e-smoke >/dev/null +request DELETE /v1/conversations/conv-provider-smoke >/dev/null deleted_list_file="$(request GET /v1/conversations)" deleted_count="$(python3 - "${deleted_list_file}" <<'PY' @@ -241,6 +370,6 @@ cat < Date: Tue, 19 May 2026 12:42:35 +0700 Subject: [PATCH 22/58] Harden local daemon processing contracts --- desktop/local-backend/src/main.rs | 106 +++++++++++++++++ desktop/local-backend/src/processing.rs | 99 +++++++++++++--- desktop/local-backend/src/providers.rs | 96 ++++++++++++++++ desktop/local-backend/src/routes.rs | 13 ++- desktop/local-backend/src/storage.rs | 139 ++++++++++++++++++++++- desktop/local-backend/tools/e2e_smoke.sh | 99 +++++++++++++++- 6 files changed, 529 insertions(+), 23 
deletions(-) diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index b045adfd799..5ace2206b26 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -523,6 +523,112 @@ mod tests { Ok(()) } + #[tokio::test] + async fn processed_conversation_replay_conflicts_when_create_payload_changes() -> Result<()> { + let app = test_app()?; + let body = json!({ + "id": "conv-processed-conflict", + "session_id": "session-processed-conflict", + "title": "Original title", + "overview": "Original overview", + "metadata": {"source": "test"} + }); + request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(body.clone()), + ) + .await?; + request_json( + app.clone(), + Method::PATCH, + "/v1/conversations/conv-processed-conflict", + Some(json!({"status": "processed"})), + ) + .await?; + + let replayed = + request_json(app.clone(), Method::POST, "/v1/conversations", Some(body)).await?; + assert_eq!(replayed["conversation"]["id"], "conv-processed-conflict"); + + request_status( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "id": "conv-processed-conflict", + "session_id": "session-processed-conflict", + "title": "Changed title", + "overview": "Original overview", + "metadata": {"source": "test"} + })), + StatusCode::CONFLICT, + ) + .await?; + request_status( + app, + Method::POST, + "/v1/conversations", + Some(json!({ + "id": "conv-processed-conflict", + "session_id": "session-processed-conflict", + "title": "Original title", + "overview": "Original overview", + "metadata": {"source": "changed"} + })), + StatusCode::CONFLICT, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn settings_reject_omi_firebase_and_google_provider_hosts() -> Result<()> { + let app = test_app()?; + + for base_url in [ + "https://api.omi.me/v1", + "https://api.omiapi.com/v1", + "https://desktop-backend-dt5lrfkkoa-uc.a.run.app/v1", + "https://identitytoolkit.googleapis.com/v1", + 
"https://based-hardware.firebaseapp.com/v1", + ] { + request_status( + app.clone(), + Method::PUT, + "/v1/settings", + Some(json!({ + "ai_provider": { + "kind": "openai_compatible", + "base_url": base_url, + "api_key": "blocked" + } + })), + StatusCode::BAD_REQUEST, + ) + .await?; + } + + let allowed = request_json( + app, + Method::PUT, + "/v1/settings", + Some(json!({ + "ai_provider": { + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "api_key": "local" + } + })), + ) + .await?; + assert_eq!(allowed["settings"][0]["key"], "ai_provider"); + + Ok(()) + } + fn test_app() -> Result { let config = Config { bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), diff --git a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs index 79cded852a2..a384b3b1abe 100644 --- a/desktop/local-backend/src/processing.rs +++ b/desktop/local-backend/src/processing.rs @@ -64,8 +64,8 @@ pub async fn process_next_job(store: &Store) -> Result> { let message = error.to_string(); store .processing_jobs() - .fail(&job.id, &message) - .with_context(|| format!("failed to fail job {}", job.id)) + .fail_or_requeue(&job.id, &message) + .with_context(|| format!("failed to fail or requeue job {}", job.id)) } } } @@ -95,20 +95,12 @@ async fn process_conversation_job(store: &Store, job: &ProcessingJob) -> Result< .collect::>() .join(" "); let output = if let Some(provider) = configured_openai_provider(store)? 
{ - match provider + let mut output = provider .complete_json(processing_prompt(&transcript)) .await - .and_then(parse_provider_output) - { - Ok(mut output) => { - output.provider = "openai_compatible".to_string(); - output - } - Err(error) => { - tracing::warn!(error = %error, "provider processing failed; using deterministic fallback"); - fallback_output(&transcript) - } - } + .and_then(parse_provider_output)?; + output.provider = "openai_compatible".to_string(); + output } else { fallback_output(&transcript) }; @@ -301,6 +293,8 @@ mod tests { use crate::storage::{NewConversation, NewProcessingJob, NewTranscriptSegment}; use super::*; + use serde_json::Map; + use std::net::TcpListener; #[test] fn fallback_processing_is_deterministic_and_empty_for_items_and_memories() { @@ -382,6 +376,83 @@ mod tests { Ok(()) } + #[tokio::test] + async fn provider_failures_requeue_until_retry_limit() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-provider-failure"]); + let listener = TcpListener::bind("127.0.0.1:0")?; + let unused_addr = listener.local_addr()?; + drop(listener); + + let mut settings = Map::new(); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": format!("http://{unused_addr}/v1"), + "model": "offline-model", + "api_key": "local-test-key" + }), + ); + store.settings().upsert_many(settings)?; + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-provider-failure".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[&conversation_id, "0"]), + conversation_id: conversation_id.clone(), + session_id: "session-provider-failure".to_string(), + speaker_id: None, + speaker_label: None, + text: "Provider failures should retry instead of falling back.".to_string(), + 
start_ms: 0, + end_ms: 1000, + segment_index: 0, + source: None, + metadata: None, + })?; + let enqueued = store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["provider-failure", &conversation_id]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(2), + payload: Some(json!({"conversation_id": conversation_id})), + })?; + + let first = process_next_job(&store) + .await? + .expect("failed job should be returned"); + assert_eq!(first.id, enqueued.id); + assert_eq!(first.status, ProcessingJobStatus::Queued); + assert_eq!(first.retry_count, 1); + assert!(first.last_error.as_deref().unwrap_or("").contains("failed")); + + let second = process_next_job(&store) + .await? + .expect("exhausted job should be returned"); + assert_eq!(second.id, enqueued.id); + assert_eq!(second.status, ProcessingJobStatus::Failed); + assert_eq!(second.retry_count, 2); + assert!(second + .last_error + .as_deref() + .unwrap_or("") + .contains("failed")); + + let conversation = store + .conversations() + .get(second.target_conversation_id.as_ref().unwrap())? 
+ .expect("conversation should exist"); + assert_eq!(conversation.status, "open"); + + Ok(()) + } + #[test] fn provider_style_processing_outputs_are_retry_safe() -> Result<()> { let store = Store::open_in_memory()?; diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 5f915c985f7..5ebc2a09047 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -122,6 +122,7 @@ pub fn load_openai_config(store: &Store) -> Result Result Result<()> { + let kind = value["kind"].as_str().unwrap_or_default(); + if kind != "openai" && kind != "openai_compatible" { + return Ok(()); + } + + let base_url = value["base_url"] + .as_str() + .unwrap_or("https://api.openai.com/v1"); + validate_provider_base_url(base_url) +} + +fn validate_provider_base_url(base_url: &str) -> Result<()> { + let url = reqwest::Url::parse(base_url) + .with_context(|| format!("provider base_url is not a valid URL: {base_url}"))?; + match url.scheme() { + "http" | "https" => {} + scheme => return Err(anyhow!("provider base_url scheme is not allowed: {scheme}")), + } + + let host = url + .host_str() + .ok_or_else(|| anyhow!("provider base_url must include a host"))? 
+ .trim_end_matches('.') + .to_ascii_lowercase(); + if is_denied_provider_host(&host) { + return Err(anyhow!( + "provider base_url host is not allowed in local daemon mode: {host}" + )); + } + Ok(()) +} + +fn is_denied_provider_host(host: &str) -> bool { + matches!(host, "api.omi.me" | "api.omiapi.com") + || (host.starts_with("desktop-backend-") && host.ends_with(".a.run.app")) + || host == "firebase.google.com" + || host.ends_with(".firebase.google.com") + || host.ends_with(".firebaseio.com") + || host.ends_with(".firebaseapp.com") + || host.ends_with(".firebasestorage.app") + || matches!( + host, + "googleapis.com" + | "identitytoolkit.googleapis.com" + | "securetoken.googleapis.com" + | "firestore.googleapis.com" + | "firebasestorage.googleapis.com" + | "firebase.googleapis.com" + | "www.googleapis.com" + | "oauth2.googleapis.com" + ) +} + #[cfg(test)] mod tests { use super::*; @@ -195,6 +250,47 @@ mod tests { Ok(()) } + #[test] + fn provider_validation_denies_omi_firebase_and_google_hosts() { + for base_url in [ + "https://api.omi.me/v1", + "https://api.omiapi.com/v1", + "https://desktop-backend-dt5lrfkkoa-uc.a.run.app/v1", + "https://identitytoolkit.googleapis.com/v1", + "https://based-hardware.firebaseio.com", + "https://based-hardware.firebaseapp.com", + "https://based-hardware.firebasestorage.app", + ] { + assert!( + validate_provider_setting(&json!({ + "kind": "openai_compatible", + "base_url": base_url, + "api_key": "key" + })) + .is_err(), + "{base_url} should be denied" + ); + } + } + + #[test] + fn provider_validation_allows_direct_provider_and_loopback_hosts() -> Result<()> { + for base_url in [ + "https://api.openai.com/v1", + "https://api.anthropic.com/v1", + "https://generativelanguage.googleapis.com/v1beta", + "http://127.0.0.1:11434/v1", + "http://localhost:43210/v1", + ] { + validate_provider_setting(&json!({ + "kind": "openai_compatible", + "base_url": base_url, + "api_key": "key" + }))?; + } + Ok(()) + } + #[tokio::test] async fn 
openai_compatible_provider_uses_local_stub_endpoint() -> Result<()> { let app = Router::new().route( diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index b6b35fd9593..5e05216c3c3 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{json, Map, Value}; use crate::{ - processing, + processing, providers, storage::{ deterministic_id, AppendTranscriptResult, NewActionItem, NewConversation, NewMemory, NewProcessingJob, NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, @@ -689,6 +689,12 @@ async fn update_settings( State(state): State, Json(values): Json>, ) -> ApiResult { + for key in ["ai_provider", "provider"] { + if let Some(value) = values.get(key) { + providers::validate_provider_setting(value) + .map_err(|error| ApiError::bad_request(error.to_string()))?; + } + } let settings = state .store .settings() @@ -767,11 +773,10 @@ fn conversation_matches_new( existing: &crate::storage::Conversation, new: &NewConversation, ) -> anyhow::Result { - let mutable_fields_match = existing.status != "open" - || (existing.title == new.title && existing.overview == new.overview); Ok(existing.id == new.id && existing.session_id == new.session_id - && mutable_fields_match + && existing.title == new.title + && existing.overview == new.overview && new .started_at .map(|started_at| existing.started_at == started_at) diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 656ea7c6ca8..ffcbad323de 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -881,11 +881,24 @@ impl ProcessingJobRepository { julianday(completed_at) ) ) + OR ( + status = 'failed' + AND retry_count >= max_retries + AND julianday(failed_at) >= COALESCE( + ( + SELECT MAX(julianday(updated_at)) + FROM transcript_segments + WHERE conversation_id = ?2 AND deleted_at IS 
NULL + ), + julianday(failed_at) + ) + ) ) ORDER BY CASE status WHEN 'running' THEN 0 WHEN 'queued' THEN 1 + WHEN 'failed' THEN 2 ELSE 2 END, updated_at DESC @@ -969,20 +982,38 @@ impl ProcessingJobRepository { } } - pub fn fail(&self, id: &str, error: &str) -> Result> { + pub fn fail_or_requeue(&self, id: &str, error: &str) -> Result> { let now = Utc::now(); let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let changed = conn .execute( r#" UPDATE processing_jobs - SET status = 'failed', last_error = ?2, failed_at = ?3, updated_at = ?3, - retry_count = retry_count + 1, sync_version = sync_version + 1 + SET status = CASE + WHEN retry_count + 1 < max_retries THEN 'queued' + ELSE 'failed' + END, + last_error = ?2, + failed_at = CASE + WHEN retry_count + 1 < max_retries THEN NULL + ELSE ?3 + END, + queued_at = CASE + WHEN retry_count + 1 < max_retries THEN ?3 + ELSE queued_at + END, + started_at = CASE + WHEN retry_count + 1 < max_retries THEN NULL + ELSE started_at + END, + updated_at = ?3, + retry_count = retry_count + 1, + sync_version = sync_version + 1 WHERE id = ?1 AND deleted_at IS NULL "#, params![id, error, now], ) - .context("failed to mark processing job failed")?; + .context("failed to fail or requeue processing job")?; drop(conn); if changed == 0 { @@ -2182,4 +2213,104 @@ mod tests { Ok(()) } + + #[test] + fn processing_job_failure_requeues_until_max_retries() -> Result<()> { + let store = Store::open_in_memory()?; + store.conversations().create(NewConversation { + id: "conversation-1".to_string(), + session_id: "session-retry".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + let repository = store.processing_jobs(); + let enqueued = repository.enqueue(NewProcessingJob { + id: deterministic_id("job", &["retry", "conversation-1"]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some("conversation-1".to_string()), + max_retries: Some(2), + payload: 
Some(serde_json::json!({"conversation_id": "conversation-1"})), + })?; + + let claimed = repository + .claim_next_queued()? + .expect("queued job should be claimed"); + assert_eq!(claimed.id, enqueued.id); + assert_eq!(claimed.status, ProcessingJobStatus::Running); + + let retryable = repository + .fail_or_requeue(&claimed.id, "provider timeout")? + .expect("job should still exist"); + assert_eq!(retryable.status, ProcessingJobStatus::Queued); + assert_eq!(retryable.retry_count, 1); + assert_eq!(retryable.last_error.as_deref(), Some("provider timeout")); + assert!(retryable.failed_at.is_none()); + assert!(retryable.started_at.is_none()); + + let reclaimed = repository + .claim_next_queued()? + .expect("retryable job should be claimable again"); + assert_eq!(reclaimed.id, enqueued.id); + assert_eq!(reclaimed.status, ProcessingJobStatus::Running); + assert_eq!(reclaimed.retry_count, 1); + assert!(reclaimed.last_error.is_none()); + + let exhausted = repository + .fail_or_requeue(&reclaimed.id, "provider still unavailable")? 
+ .expect("job should still exist"); + assert_eq!(exhausted.status, ProcessingJobStatus::Failed); + assert_eq!(exhausted.retry_count, 2); + assert_eq!( + exhausted.last_error.as_deref(), + Some("provider still unavailable") + ); + assert!(exhausted.failed_at.is_some()); + + Ok(()) + } + + #[test] + fn exhausted_failed_finalize_job_is_reusable_for_duplicate_finalize() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-failed-finalize"]); + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: "session-failed-finalize".to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["failed-finalize", &conversation_id]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(1), + payload: Some(serde_json::json!({"conversation_id": conversation_id})), + })?; + + let claimed = store + .processing_jobs() + .claim_next_queued()? + .expect("queued job should be claimed"); + let failed = store + .processing_jobs() + .fail_or_requeue(&claimed.id, "exhausted")? + .expect("job should still exist"); + assert_eq!(failed.status, ProcessingJobStatus::Failed); + + let reusable = store + .processing_jobs() + .reusable_for_conversation( + "finalize_transcript", + failed.target_conversation_id.as_ref().unwrap(), + )? 
+ .expect("failed exhausted job should be reusable"); + assert_eq!(reusable.id, failed.id); + assert_eq!(reusable.status, ProcessingJobStatus::Failed); + + Ok(()) + } } diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index c1fe66221e4..0210ef04b7b 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -17,6 +17,13 @@ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: print(s.getsockname()[1]) PY )}" +UNUSED_PROVIDER_PORT="${OMI_LOCAL_UNUSED_PROVIDER_PORT:-$(python3 - <<'PY' +import socket +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + print(s.getsockname()[1]) +PY +)}" BASE_URL="http://127.0.0.1:${PORT}" LOG_FILE="${DATA_DIR}/daemon.log" PROVIDER_LOG_FILE="${DATA_DIR}/provider-requests.jsonl" @@ -88,6 +95,25 @@ request() { printf '%s\n' "${output}" } +request_status() { + local method="$1" + local path="$2" + local expected_status="$3" + local body="${4:-}" + local status + if [[ -n "${body}" ]]; then + status="$(curl -sS -o /dev/null -w "%{http_code}" -X "${method}" "${BASE_URL}${path}" \ + -H "Content-Type: application/json" \ + --data "${body}")" + else + status="$(curl -sS -o /dev/null -w "%{http_code}" -X "${method}" "${BASE_URL}${path}")" + fi + if [[ "${status}" != "${expected_status}" ]]; then + echo "Expected ${method} ${path} to return HTTP ${expected_status}, got ${status}" >&2 + exit 1 + fi +} + assert_json_value() { local file="$1" local path="$2" @@ -214,6 +240,22 @@ wait_for_completed_job() { exit 1 } +wait_for_failed_job() { + local job_id="$1" + local job_file + for _ in $(seq 1 40); do + job_file="$(request GET "/v1/processing-jobs/${job_id}")" + if [[ "$(json_value "processing_job.status" "${job_file}")" == "failed" ]]; then + printf '%s\n' "${job_file}" + return + fi + request POST "/v1/processing-jobs/process-next" >/dev/null || true + sleep 0.25 + done + echo "Processing job ${job_id} did not 
fail after retry exhaustion" >&2 + exit 1 +} + echo "Starting local daemon smoke on ${BASE_URL}" echo "Data dir: ${DATA_DIR}" @@ -275,6 +317,22 @@ processed_file="$(request GET /v1/conversations/conv-e2e-smoke)" assert_json_value "${processed_file}" "conversation.status" "processed" assert_json_value "${processed_file}" "conversation.title" "Plan the backend free desktop MVP and verify" +request_status POST /v1/conversations 409 '{ + "id": "conv-e2e-smoke", + "session_id": "session-e2e-smoke", + "title": "Changed replay title", + "overview": "Created by local smoke" +}' + +request_status PUT /v1/settings 400 '{ + "ai_provider": { + "kind": "openai_compatible", + "base_url": "https://api.omi.me/v1", + "model": "blocked", + "api_key": "blocked" + } +}' + settings_file="$(request PUT /v1/settings '{ "local_first": true, "ai_provider": { @@ -337,6 +395,44 @@ if [[ "${provider_request_auth}" != "Bearer local-test-key" ]]; then exit 1 fi +retry_settings_file="$(request PUT /v1/settings '{ + "ai_provider": { + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:'"${UNUSED_PROVIDER_PORT}"'/v1", + "model": "offline-stub", + "api_key": "local-test-key" + } +}')" +assert_json_value "${retry_settings_file}" "settings.0.key" "ai_provider" + +retry_conversation_file="$(request POST /v1/conversations '{ + "id": "conv-provider-retry-smoke", + "session_id": "session-provider-retry-smoke", + "title": "", + "overview": "" +}')" +assert_json_value "${retry_conversation_file}" "conversation.id" "conv-provider-retry-smoke" + +retry_segment_file="$(request POST /v1/conversations/conv-provider-retry-smoke/transcript-segments '{ + "id": "seg-provider-retry-smoke-0", + "text": "A transient provider failure should retry and eventually fail usefully.", + "start_ms": 0, + "end_ms": 1600, + "segment_index": 0, + "source": "smoke" +}')" +assert_json_value "${retry_segment_file}" "transcript_segment.id" "seg-provider-retry-smoke-0" + +retry_job_file="$(request POST 
/v1/conversations/conv-provider-retry-smoke/finalize-transcript)" +retry_job_id="$(json_value "processing_job.id" "${retry_job_file}")" +failed_retry_job_file="$(wait_for_failed_job "${retry_job_id}")" +assert_json_value "${failed_retry_job_file}" "processing_job.status" "failed" +assert_json_value "${failed_retry_job_file}" "processing_job.retry_count" "3" + +duplicate_failed_job_file="$(request POST /v1/conversations/conv-provider-retry-smoke/finalize-transcript)" +assert_json_value "${duplicate_failed_job_file}" "processing_job.id" "${retry_job_id}" +assert_json_value "${duplicate_failed_job_file}" "processing_job.status" "failed" + stop_daemon start_daemon @@ -349,6 +445,7 @@ assert_json_value "${persisted_search_file}" "results.0.conversation_id" "conv-e request DELETE /v1/conversations/conv-e2e-smoke >/dev/null request DELETE /v1/conversations/conv-provider-smoke >/dev/null +request DELETE /v1/conversations/conv-provider-retry-smoke >/dev/null deleted_list_file="$(request GET /v1/conversations)" deleted_count="$(python3 - "${deleted_list_file}" <<'PY' @@ -370,6 +467,6 @@ cat < Date: Tue, 19 May 2026 12:49:01 +0700 Subject: [PATCH 23/58] Polish local-mode desktop surfaces --- desktop/Desktop/Sources/APIClient.swift | 15 +++- desktop/Desktop/Sources/AppState.swift | 18 +++++ .../MainWindow/Pages/MemoriesPage.swift | 78 +++++++++++-------- .../Sources/TranscriptionRetryService.swift | 19 ++++- .../Desktop/Tests/APIClientRoutingTests.swift | 36 +++++++++ desktop/local-backend/src/routes.rs | 12 ++- desktop/local-backend/src/storage.rs | 27 ++++++- desktop/local-backend/tools/e2e_smoke.sh | 23 +++++- 8 files changed, 187 insertions(+), 41 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index b39b9c8522e..e34a5ce8b00 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -457,7 +457,20 @@ extension APIClient { ) async throws -> [ServerConversation] { let 
target = mvpBackendTarget if target.mode == .localDaemon { - let endpoint = "v1/conversations?limit=\(limit)" + var queryItems: [String] = [ + "limit=\(limit)", + "offset=\(offset)", + ] + if let startDate = startDate { + queryItems.append("start_date=\(Self.queryValue(Self.isoString(startDate)))") + } + if let endDate = endDate { + queryItems.append("end_date=\(Self.queryValue(Self.isoString(endDate)))") + } + if let starred = starred { + queryItems.append("starred=\(starred)") + } + let endpoint = "v1/conversations?\(queryItems.joined(separator: "&"))" let response: LocalConversationsResponse = try await get( endpoint, requireAuth: false, diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index f8725a1ce20..890fa05f330 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1741,6 +1741,7 @@ class AppState: ObservableObject { // may arrive on the new WebSocket after currentSessionId and recordingStartTime have changed. 
finishedSessionId = currentSessionId finishedRecordingStartTime = recordingStartTime + let sessionIdToFinalizeLocally = currentSessionId // Mark current DB session as finished before stopping // (backend will process it; memory_created event may arrive on the new session's WebSocket) @@ -1841,6 +1842,23 @@ class AppState: ObservableObject { } } + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon, + let sessionId = sessionIdToFinalizeLocally + { + do { + if let conversationId = try await TranscriptionRetryService.shared + .finalizeLocalDaemonSessionNow(sessionId: sessionId) + { + log( + "Transcription: Local daemon finalized session \(sessionId) as conversation \(conversationId)" + ) + } + } catch { + logError("Transcription: Local daemon finalize failed for session \(sessionId)", error: error) + return .error(error.localizedDescription) + } + } + // Refresh the conversations list to show the new conversation await loadConversations() diff --git a/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift index f661fcd6618..d92548b3efe 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift @@ -939,6 +939,10 @@ struct MemoriesPage: View { @State private var pendingSelectedTags: Set = [] @State private var showManagementMenu = false + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + var body: some View { Group { if let conversation = viewModel.linkedConversation { @@ -1147,20 +1151,22 @@ struct MemoriesPage: View { .buttonStyle(.plain) .help("Add Memory") - // Management menu - Button { - showManagementMenu = true - } label: { - Image(systemName: "chevron.down") - .scaledFont(size: 12, weight: .medium) - .foregroundColor(.black) - .frame(width: 42, height: 42) - .background(OmiColors.textPrimary) - 
.clipShape(RoundedRectangle(cornerRadius: 16, style: .continuous)) - } - .buttonStyle(.plain) - .popover(isPresented: $showManagementMenu, arrowEdge: .bottom) { - managementMenuPopover + if !isLocalDaemonMode { + // Management menu + Button { + showManagementMenu = true + } label: { + Image(systemName: "chevron.down") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(.black) + .frame(width: 42, height: 42) + .background(OmiColors.textPrimary) + .clipShape(RoundedRectangle(cornerRadius: 16, style: .continuous)) + } + .buttonStyle(.plain) + .popover(isPresented: $showManagementMenu, arrowEdge: .bottom) { + managementMenuPopover + } } } .padding(.horizontal, 28) @@ -1952,6 +1958,10 @@ struct MemoryDetailSheet: View { @State private var isEditingContent = false @State private var editContentText = "" + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + private func dismissSheet() { if let onDismiss = onDismiss { onDismiss() @@ -1978,26 +1988,28 @@ struct MemoryDetailSheet: View { Spacer() - // Public toggle - HStack(spacing: 6) { - Text("Public") - .scaledFont(size: 13) - .foregroundColor(OmiColors.textSecondary) - if viewModel.isTogglingVisibility { - ProgressView() - .scaleEffect(0.7) - } else { - Toggle( - "", - isOn: Binding( - get: { memory.isPublic }, - set: { _ in - Task { await viewModel.toggleVisibility(memory) } - } + if !isLocalDaemonMode { + // Public toggle + HStack(spacing: 6) { + Text("Public") + .scaledFont(size: 13) + .foregroundColor(OmiColors.textSecondary) + if viewModel.isTogglingVisibility { + ProgressView() + .scaleEffect(0.7) + } else { + Toggle( + "", + isOn: Binding( + get: { memory.isPublic }, + set: { _ in + Task { await viewModel.toggleVisibility(memory) } + } + ) ) - ) - .toggleStyle(.switch) - .labelsHidden() + .toggleStyle(.switch) + .labelsHidden() + } } } diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift 
b/desktop/Desktop/Sources/TranscriptionRetryService.swift index c91ad832e40..a96923c73e2 100644 --- a/desktop/Desktop/Sources/TranscriptionRetryService.swift +++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift @@ -14,6 +14,19 @@ class TranscriptionRetryService { private init() {} + /// Immediately uploads and finalizes a finished local session instead of waiting for the retry timer. + func finalizeLocalDaemonSessionNow(sessionId: Int64) async throws -> String? { + guard await APIClient.shared.isUsingLocalDaemon else { return nil } + guard let session = try await TranscriptionStorage.shared.getSession(id: sessionId) else { + return nil + } + if session.status == .completed { + return session.backendId + } + let conversation = try await uploadSessionToLocalDaemon(session, sessionId: sessionId) + return conversation.id + } + // MARK: - Service Lifecycle /// Start the retry service (call on app launch) @@ -260,13 +273,14 @@ class TranscriptionRetryService { } } - private func uploadSessionToLocalDaemon(_ session: TranscriptionSessionRecord, sessionId: Int64) async throws { + @discardableResult + private func uploadSessionToLocalDaemon(_ session: TranscriptionSessionRecord, sessionId: Int64) async throws -> ServerConversation { try await TranscriptionStorage.shared.markSessionUploading(id: sessionId) let segments = try await TranscriptionStorage.shared.getSegments(sessionId: sessionId) guard !segments.isEmpty else { try await TranscriptionStorage.shared.markSessionFailed( id: sessionId, error: "No transcript segments to upload to local daemon") - return + throw APIError.httpError(statusCode: 400) } let conversation = try await APIClient.shared.createLocalDaemonConversation( @@ -286,6 +300,7 @@ class TranscriptionRetryService { try await APIClient.shared.finalizeLocalDaemonTranscript(conversationId: conversation.id) try await TranscriptionStorage.shared.markSessionCompleted(id: sessionId, backendId: conversation.id) log("TranscriptionRetryService: Session 
\(sessionId) stored in local daemon as \(conversation.id)") + return conversation } } diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index aefe55f2eaa..ee8ce943d88 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -554,6 +554,42 @@ final class APIClientRoutingTests: XCTestCase { assertNoOmiHostedBackendRequests(requests) } + func testLocalModeGetConversationsPreservesVisibleFilters() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + let startDate = Date(timeIntervalSince1970: 1_700_000_000) + let endDate = Date(timeIntervalSince1970: 1_700_086_400) + + _ = try? await client.getConversations( + limit: 25, + offset: 10, + startDate: startDate, + endDate: endDate, + starred: true + ) + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8765, + pathContains: "v1/conversations", method: "GET", + label: "local filtered getConversations") + XCTAssertNil(requests.first?.headers["Authorization"]) + let queryItems = Dictionary( + uniqueKeysWithValues: URLComponents(url: requests[0].url, resolvingAgainstBaseURL: false)! + .queryItems! 
+ .compactMap { item in item.value.map { (item.name, $0) } } + ) + XCTAssertEqual(queryItems["limit"], "25") + XCTAssertEqual(queryItems["offset"], "10") + XCTAssertEqual(queryItems["starred"], "true") + XCTAssertNotNil(queryItems["start_date"]) + XCTAssertNotNil(queryItems["end_date"]) + assertNoOmiHostedBackendRequests(requests) + } + func testLocalModeTranscriptImportRoutesToLocalDaemonWithoutAuth() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 5e05216c3c3..3d439139970 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -156,6 +156,10 @@ async fn profile_status(State(state): State) -> ApiResult { #[derive(Deserialize)] struct ListQuery { limit: Option, + offset: Option, + start_date: Option>, + end_date: Option>, + starred: Option, } async fn list_conversations( @@ -165,7 +169,13 @@ async fn list_conversations( let conversations = state .store .conversations() - .list(limit_or_default(query.limit)) + .list_filtered( + limit_or_default(query.limit), + query.offset.unwrap_or(0).max(0), + query.start_date, + query.end_date, + query.starred, + ) .map_err(ApiError::internal)?; Ok(Json(json!({ "conversations": conversations }))) } diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index ffcbad323de..e8424ea183a 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -526,6 +526,17 @@ impl ConversationRepository { } pub fn list(&self, limit: i64) -> Result> { + self.list_filtered(limit, 0, None, None, None) + } + + pub fn list_filtered( + &self, + limit: i64, + offset: i64, + start_date: Option>, + end_date: Option>, + starred: Option, + ) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let mut stmt = conn .prepare( @@ -534,13 +545,25 @@ impl 
ConversationRepository { updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred FROM conversations WHERE deleted_at IS NULL + AND (?1 IS NULL OR started_at >= ?1) + AND (?2 IS NULL OR started_at < ?2) + AND (?3 IS NULL OR starred = ?3) ORDER BY updated_at DESC - LIMIT ?1 + LIMIT ?4 OFFSET ?5 "#, ) .context("failed to prepare conversation list query")?; let rows = stmt - .query_map(params![limit], map_conversation) + .query_map( + params![ + start_date, + end_date, + starred.map(|value| if value { 1 } else { 0 }), + limit, + offset + ], + map_conversation, + ) .context("failed to list conversations")?; collect_rows(rows) } diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index 0210ef04b7b..55c221833f4 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ b/desktop/local-backend/tools/e2e_smoke.sh @@ -289,12 +289,31 @@ assert_json_value "${segment_file}" "transcript_segment.id" "seg-e2e-smoke-0" updated_file="$(request PATCH /v1/conversations/conv-e2e-smoke '{ "title": "Smoke updated", - "overview": "Updated before processing" + "overview": "Updated before processing", + "starred": true }')" assert_json_value "${updated_file}" "conversation.title" "Smoke updated" +assert_json_value "${updated_file}" "conversation.starred" "True" list_file="$(request GET /v1/conversations)" assert_json_value "${list_file}" "conversations.0.id" "conv-e2e-smoke" +filtered_list_file="$(request GET '/v1/conversations?limit=10&offset=0&starred=true&start_date=2000-01-01T00:00:00Z&end_date=2100-01-01T00:00:00Z')" +assert_json_value "${filtered_list_file}" "conversations.0.id" "conv-e2e-smoke" +unstarred_list_file="$(request GET '/v1/conversations?starred=false')" +unstarred_count="$(python3 - "${unstarred_list_file}" <<'PY' +import json +import sys + +with open(sys.argv[1], "r", encoding="utf-8") as handle: + data = json.load(handle) +print(len(data["conversations"])) +PY +)" +if [[ "${unstarred_count}" != "0" 
]]; then + echo "Expected starred=false filter to hide starred smoke conversation, got ${unstarred_count}" >&2 + echo "Response file: ${unstarred_list_file}" >&2 + exit 1 +fi search_file="$(request GET '/v1/search/conversations?q=deterministic')" assert_json_value "${search_file}" "results.0.conversation_id" "conv-e2e-smoke" @@ -467,6 +486,6 @@ cat < Date: Tue, 19 May 2026 12:55:42 +0700 Subject: [PATCH 24/58] Harden desktop auth secret boundaries --- desktop/Desktop/Sources/APIClient.swift | 14 ++- desktop/Desktop/Sources/APIKeyService.swift | 59 +++++++++- desktop/Desktop/Sources/AuthService.swift | 27 ++++- .../FloatingControlBar/AgentPill.swift | 4 +- .../Core/GeminiClient.swift | 4 + .../Services/EmbeddingService.swift | 53 +++++++-- .../Desktop/Tests/APIClientRoutingTests.swift | 110 ++++++++++++++++++ 7 files changed, 246 insertions(+), 25 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index e34a5ce8b00..c28fd4a0a6b 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -83,7 +83,9 @@ actor APIClient { // MARK: - Request Building - func buildHeaders(requireAuth: Bool = true) async throws -> [String: String] { + func buildHeaders(requireAuth: Bool = true, includeBYOK: Bool = false) async throws + -> [String: String] + { var headers: [String: String] = [ "Content-Type": "application/json", "X-App-Platform": "macos", @@ -100,10 +102,10 @@ actor APIClient { } } - // BYOK: attach user-provided keys so the backend uses them for LLM/STT - // calls this request triggers. Sent per-request; never stored server-side. - for (provider, entry) in APIKeyService.byokSnapshot { - headers[provider.headerName] = entry.key + // BYOK keys are raw provider credentials. Only attach them to explicit + // provider/STT work, never to ordinary account, settings, session, or support requests. 
+ if includeBYOK { + headers.merge(APIKeyService.byokHeaders()) { _, new in new } } return headers @@ -5790,7 +5792,7 @@ extension APIClient { var request = URLRequest(url: url) request.httpMethod = "POST" request.timeoutInterval = 60 - request.allHTTPHeaderFields = try await buildHeaders() + request.allHTTPHeaderFields = try await buildHeaders(includeBYOK: true) request.httpBody = try JSONEncoder().encode(body) let (data, response) = try await session.data(for: request) diff --git a/desktop/Desktop/Sources/APIKeyService.swift b/desktop/Desktop/Sources/APIKeyService.swift index 2c0b1ccba08..75b5869f87c 100644 --- a/desktop/Desktop/Sources/APIKeyService.swift +++ b/desktop/Desktop/Sources/APIKeyService.swift @@ -6,7 +6,7 @@ import Foundation /// /// Also hosts the Bring-Your-Own-Key (BYOK) free-plan flow: when the user supplies /// their own OpenAI, Anthropic, Gemini, and Deepgram keys, the app sends them along -/// with every request and the backend skips subscription billing. Keys live in +/// with provider-work requests and the backend skips subscription billing. Keys live in /// UserDefaults (reusing the existing dev-override AppStorage pattern); the backend /// only ever sees SHA-256 fingerprints for state tracking. /// @@ -83,7 +83,7 @@ final class APIKeyService: ObservableObject { } var effectiveFirebaseApiKey: String? { - firebaseApiKey + firebaseApiKey ?? Self.bootstrapFirebaseApiKey } var effectiveGoogleCalendarApiKey: String? { @@ -163,6 +163,43 @@ final class APIKeyService: ObservableObject { return s } + nonisolated static var bootstrapFirebaseApiKey: String? { + nonEmptyStatic(getenv("FIREBASE_API_KEY").flatMap { String(validatingUTF8: $0) }) + ?? bundledFirebaseApiKey() + } + + nonisolated static func bundledFirebaseApiKey( + bundle: Bundle = .main, + resourceName: String = "GoogleService-Info" + ) -> String? 
{ + if let key = firebaseApiKeyFromBundle(bundle, resourceName: resourceName) { + return key + } + #if SWIFT_PACKAGE + return firebaseApiKeyFromBundle(.module, resourceName: resourceName) + #else + return nil + #endif + } + + private nonisolated static func firebaseApiKeyFromBundle( + _ bundle: Bundle, + resourceName: String + ) -> String? { + guard let url = bundle.url(forResource: resourceName, withExtension: "plist"), + let data = try? Data(contentsOf: url), + let plist = try? PropertyListSerialization.propertyList( + from: data, + options: [], + format: nil + ) as? [String: Any], + let key = plist["API_KEY"] as? String + else { + return nil + } + return nonEmptyStatic(key) + } + // MARK: - Thread-safe key access (for non-MainActor contexts) // These read from UserDefaults (thread-safe) and getenv() (set by applyToEnvironment). // Use these from actors, nonisolated inits, and background threads. @@ -217,4 +254,22 @@ final class APIKeyService: ObservableObject { } return out } + + nonisolated static func byokHeaders(providers: Set? = nil) -> [String: String] { + byokSnapshot.reduce(into: [:]) { headers, element in + let (provider, entry) = element + if providers == nil || providers?.contains(provider) == true { + headers[provider.headerName] = entry.key + } + } + } + + nonisolated static func applyBYOKHeaders( + to request: inout URLRequest, + providers: Set? 
= nil + ) { + for (header, value) in byokHeaders(providers: providers) { + request.setValue(value, forHTTPHeaderField: header) + } + } } diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index 3cf9656cfac..9b1f20b8805 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -79,13 +79,21 @@ class AuthService { private let kAuthTokenExpiry = "auth_tokenExpiry" private let kAuthTokenUserId = "auth_tokenUserId" // User ID that owns the stored token - // Firebase Web API key — fetched from backend via APIKeyService, set as env var. - // No hardcoded fallback — if the key isn't available, auth operations will fail - // with a clear error instead of silently using a potentially wrong key. + // Firebase Web API key. Clean web OAuth needs this before the user is + // authenticated, so prefer env and bundled Firebase options before any + // post-auth backend key fetch can run. private var firebaseApiKey: String { if let envKey = getenv("FIREBASE_API_KEY"), let key = String(validatingUTF8: envKey), !key.isEmpty { return key } + if let key = APIKeyService.shared.effectiveFirebaseApiKey, !key.isEmpty { + setenv("FIREBASE_API_KEY", key, 1) + return key + } + if let key = APIKeyService.bootstrapFirebaseApiKey { + setenv("FIREBASE_API_KEY", key, 1) + return key + } log("AuthService: FIREBASE_API_KEY not set — auth operations will fail") return "" } @@ -532,7 +540,7 @@ class AuthService { /// Called by AppDelegate when the app receives an OAuth callback URL @MainActor func handleOAuthCallback(url: URL) { - NSLog("OMI AUTH: Received OAuth callback: %@", url.absoluteString) + NSLog("OMI AUTH: Received OAuth callback: %@", Self.sanitizedOAuthCallbackLogDetails(url: url)) guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { NSLog("OMI AUTH: Failed to parse callback URL") @@ -581,6 +589,17 @@ class AuthService { oauthContinuation = nil } + nonisolated 
static func sanitizedOAuthCallbackLogDetails(url: URL) -> String { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + return "scheme=\(url.scheme ?? "nil") host=\(url.host ?? "nil") path=\(url.path) has_code=false has_state=false has_error=false" + } + let queryItems = components.queryItems ?? [] + func hasParameter(_ name: String) -> Bool { + queryItems.contains { $0.name == name } + } + return "scheme=\(components.scheme ?? "nil") host=\(components.host ?? "nil") path=\(components.path) has_code=\(hasParameter("code")) has_state=\(hasParameter("state")) has_error=\(hasParameter("error"))" + } + /// Cancel an in-flight web OAuth sign-in so the user can retry from a clean /// state. The recovery path we care about: the user fails on the web side /// (closed the tab, denied, or just walked away) and comes back to a diff --git a/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift b/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift index 59fa2f5a58c..ea3b9e0352b 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/AgentPill.swift @@ -131,7 +131,7 @@ final class AgentPillsManager: ObservableObject { request.httpMethod = "POST" request.timeoutInterval = 4 do { - let headers = try await APIClient.shared.buildHeaders(requireAuth: true) + let headers = try await APIClient.shared.buildHeaders(requireAuth: true, includeBYOK: true) for (k, v) in headers { request.setValue(v, forHTTPHeaderField: k) } } catch { log("AgentPill: router skipped — auth header unavailable (\(error.localizedDescription))") @@ -549,7 +549,7 @@ final class AgentPillsManager: ObservableObject { request.httpMethod = "POST" request.timeoutInterval = 8 do { - let headers = try await APIClient.shared.buildHeaders(requireAuth: true) + let headers = try await APIClient.shared.buildHeaders(requireAuth: true, includeBYOK: true) for (k, v) in headers { request.setValue(v, forHTTPHeaderField: 
k) } } catch { log("AgentPill: title gen skipped — auth header unavailable (\(error.localizedDescription))") diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index 7b4d507cc0f..f85f2a10aee 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -388,6 +388,7 @@ actor GeminiClient { urlRequest.httpMethod = "POST" urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &urlRequest, providers: [.gemini]) urlRequest.timeoutInterval = 300 urlRequest.httpBody = requestBody @@ -462,6 +463,7 @@ actor GeminiClient { urlRequest.httpMethod = "POST" urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &urlRequest, providers: [.gemini]) urlRequest.timeoutInterval = timeout urlRequest.httpBody = try JSONEncoder().encode(request) @@ -533,6 +535,7 @@ actor GeminiClient { urlRequest.httpMethod = "POST" urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &urlRequest, providers: [.gemini]) urlRequest.timeoutInterval = 300 urlRequest.httpBody = try JSONEncoder().encode(request) @@ -832,6 +835,7 @@ extension GeminiClient { urlRequest.httpMethod = "POST" urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") urlRequest.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &urlRequest, providers: [.gemini]) urlRequest.timeoutInterval = 300 urlRequest.httpBody = requestBody diff --git 
a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift index 4f54a776c3e..8e0c9f9c7ee 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift @@ -4,6 +4,7 @@ import Foundation /// Actor-based service for embeddings using Gemini (3072-dim) actor EmbeddingService { static let shared = EmbeddingService() + typealias AuthHeaderProvider = @Sendable () async throws -> String /// Gemini embedding-001 outputs 3072 dimensions by default static let embeddingDimension = 3072 @@ -15,6 +16,8 @@ actor EmbeddingService { /// Cap in-memory embeddings to limit memory (~12KB each, 5000 = ~60MB max) private let maxIndexSize = 5000 + private let urlSession: URLSession + private let authHeaderProvider: AuthHeaderProvider? /// Backend proxy base URL (from OMI_DESKTOP_API_URL env var) private static var proxyBaseURL: String { @@ -29,11 +32,28 @@ actor EmbeddingService { /// Get Firebase auth header for proxy requests private func authHeader() async throws -> String { + if let authHeaderProvider { + return try await authHeaderProvider() + } let authService = await MainActor.run { AuthService.shared } return try await authService.getAuthHeader() } - private init() {} + init( + urlSession: URLSession = .shared, + authHeaderProvider: AuthHeaderProvider? = nil + ) { + self.urlSession = urlSession + self.authHeaderProvider = authHeaderProvider + } + + private func checkHTTPStatus(_ response: URLResponse, data: Data) throws { + guard let httpResponse = response as? HTTPURLResponse else { return } + guard (200..<300).contains(httpResponse.statusCode) else { + let body = String(data: data.prefix(200), encoding: .utf8) ?? 
"" + throw EmbeddingError.serverError(statusCode: httpResponse.statusCode, body: body) + } + } // MARK: - Embedding API @@ -64,17 +84,12 @@ actor EmbeddingService { request.httpMethod = "POST" request.setValue("application/json", forHTTPHeaderField: "Content-Type") request.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &request, providers: [.gemini]) request.timeoutInterval = 30 request.httpBody = try JSONSerialization.data(withJSONObject: requestBody) - let (data, response) = try await URLSession.shared.data(for: request) - - // Check HTTP status before parsing — non-JSON error bodies (HTML 401/500) - // cause "data couldn't be read" errors that mask the real problem. - if let httpResponse = response as? HTTPURLResponse, httpResponse.statusCode != 200 { - let body = String(data: data.prefix(200), encoding: .utf8) ?? "" - throw EmbeddingError.serverError(statusCode: httpResponse.statusCode, body: body) - } + let (data, response) = try await urlSession.data(for: request) + try checkHTTPStatus(response, data: data) guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], let embedding = json["embedding"] as? [String: Any], @@ -119,10 +134,12 @@ actor EmbeddingService { request.httpMethod = "POST" request.setValue("application/json", forHTTPHeaderField: "Content-Type") request.setValue(try await authHeader(), forHTTPHeaderField: "Authorization") + APIKeyService.applyBYOKHeaders(to: &request, providers: [.gemini]) request.timeoutInterval = 60 request.httpBody = try JSONSerialization.data(withJSONObject: requestBody) - let (data, _) = try await URLSession.shared.data(for: request) + let (data, response) = try await urlSession.data(for: request) + try checkHTTPStatus(response, data: data) guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], let embeddings = json["embeddings"] as? 
[[String: Any]] @@ -324,7 +341,21 @@ actor EmbeddingService { switch self { case .missingAPIKey: return "AI features are not configured. Please update the app." case .invalidResponse: return "AI service returned an unexpected response. Please try again." - case .serverError(let statusCode, let body): return "Embedding API error (HTTP \(statusCode)): \(body)" + case .serverError(let statusCode, let body): + let detail = body.trimmingCharacters(in: .whitespacesAndNewlines) + let suffix = detail.isEmpty ? "" : ": \(detail)" + switch statusCode { + case 401: + return "Embedding proxy rejected the current sign-in. Please sign in again\(suffix)" + case 403: + return "Embedding proxy access is not allowed for this account\(suffix)" + case 429: + return "Embedding proxy rate limit exceeded. Please try again later\(suffix)" + case 500...599: + return "Embedding proxy is temporarily unavailable (HTTP \(statusCode))\(suffix)" + default: + return "Embedding proxy failed (HTTP \(statusCode))\(suffix)" + } } } } diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index ee8ce943d88..2fe30b1a0f2 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -154,6 +154,26 @@ private func assertNoOmiHostedBackendRequests( ) } +private func assertNoBYOKHeaders( + _ request: CapturedRequest?, + file: StaticString = #filePath, + line: UInt = #line +) { + guard let headers = request?.headers else { + XCTFail("expected captured request", file: file, line: line) + return + } + for provider in BYOKProvider.allCases { + XCTAssertNil(headers[provider.headerName], "unexpected \(provider.headerName)", file: file, line: line) + } +} + +private func clearBYOKDefaults() { + for provider in BYOKProvider.allCases { + UserDefaults.standard.removeObject(forKey: provider.storageKey) + } +} + private func assertUnavailable( _ error: Error?, capability: 
DesktopBackendEnvironment.Capability, @@ -402,6 +422,7 @@ final class APIClientRoutingTests: XCTestCase { override func setUp() { super.setUp() URLCapture.reset() + clearBYOKDefaults() setenv("OMI_PYTHON_API_URL", "http://python-test:9001", 1) setenv("OMI_DESKTOP_API_URL", "http://rust-test:9002", 1) unsetenv("OMI_DESKTOP_BACKEND_MODE") @@ -410,6 +431,7 @@ final class APIClientRoutingTests: XCTestCase { } override func tearDown() { + clearBYOKDefaults() unsetenv("OMI_PYTHON_API_URL") unsetenv("OMI_DESKTOP_API_URL") unsetenv("OMI_DESKTOP_BACKEND_MODE") @@ -419,6 +441,94 @@ final class APIClientRoutingTests: XCTestCase { super.tearDown() } + func testBundledFirebaseApiKeyBootstrapsWhenEnvIsMissing() { + unsetenv("FIREBASE_API_KEY") + let key = APIKeyService.bootstrapFirebaseApiKey + XCTAssertNotNil(key) + XCTAssertFalse(key?.isEmpty ?? true) + } + + func testOAuthCallbackLogDetailsAreSanitized() { + let url = URL(string: "omi-computer://auth/callback?code=secret-code&state=secret-state&extra=visible")! + let details = AuthService.sanitizedOAuthCallbackLogDetails(url: url) + XCTAssertTrue(details.contains("scheme=omi-computer")) + XCTAssertTrue(details.contains("host=auth")) + XCTAssertTrue(details.contains("path=/callback")) + XCTAssertTrue(details.contains("has_code=true")) + XCTAssertTrue(details.contains("has_state=true")) + XCTAssertFalse(details.contains("secret-code")) + XCTAssertFalse(details.contains("secret-state")) + XCTAssertFalse(details.contains("extra=visible")) + XCTAssertFalse(details.contains(url.absoluteString)) + } + + func testOrdinaryRequestsDoNotAttachBYOKHeaders() async { + for provider in BYOKProvider.allCases { + UserDefaults.standard.set("test-\(provider.rawValue)-key", forKey: provider.storageKey) + } + let client = await makeTestClient() + + _ = try? await client.getAssistantSettings() as AssistantSettingsResponse + assertNoBYOKHeaders(URLCapture.capturedRequests.first) + + URLCapture.reset() + _ = try? 
await client.getChatSessions() as [ChatSession] + assertNoBYOKHeaders(URLCapture.capturedRequests.first) + + URLCapture.reset() + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + _ = try? await client.getSelectedBackendSettings() + assertNoBYOKHeaders(URLCapture.capturedRequests.first) + } + + func testExplicitProviderRequestsAttachBYOKHeaders() async { + for provider in BYOKProvider.allCases { + UserDefaults.standard.set("test-\(provider.rawValue)-key", forKey: provider.storageKey) + } + let client = await makeTestClient() + + _ = try? await client.synthesizeSpeech( + request: APIClient.TtsSynthesizeRequest( + text: "Hello", + voiceId: "onyx", + instructions: nil + )) + + let headers = URLCapture.capturedRequests.first?.headers ?? [:] + for provider in BYOKProvider.allCases { + XCTAssertEqual(headers[provider.headerName], "test-\(provider.rawValue)-key") + } + } + + func testEmbeddingBatchHandlesNon2xxProxyResponses() async { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [URLCapture.self] + let session = URLSession(configuration: config) + let service = EmbeddingService( + urlSession: session, + authHeaderProvider: { "Bearer test-token" } + ) + + for status in [401, 403, 429, 503] { + URLCapture.reset() + URLCapture.setStatusCode(status) + do { + _ = try await service.embedBatch(texts: ["hello"]) + XCTFail("expected serverError for HTTP \(status)") + } catch let error as EmbeddingService.EmbeddingError { + guard case .serverError(let statusCode, let body) = error else { + XCTFail("expected serverError, got \(error)") + continue + } + XCTAssertEqual(statusCode, status) + XCTAssertTrue(body.contains("detail")) + } catch { + XCTFail("expected EmbeddingError.serverError, got \(error)") + } + } + } + // -- Conversations (GET, DELETE → Python) -- func testGetConversationRoutesToPython() async { From 2792f492fd476eec08c0b2ce6edde623cf52a629 Mon Sep 17 00:00:00 2001 From: David 
Zhang Date: Tue, 19 May 2026 20:21:03 +0700 Subject: [PATCH 25/58] Close local memory import and listing gaps --- desktop/Desktop/Sources/APIClient.swift | 24 ++++- .../MainWindow/Pages/MemoriesPage.swift | 15 +++- .../Desktop/Tests/APIClientRoutingTests.swift | 51 +++++++++++ desktop/local-backend/src/main.rs | 78 ++++++++++++++++ desktop/local-backend/src/routes.rs | 88 ++++++++++++++++++- desktop/local-backend/src/storage.rs | 53 ++++++++++- 6 files changed, 299 insertions(+), 10 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index c28fd4a0a6b..74990b61919 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -1851,8 +1851,21 @@ extension APIClient { ) async throws -> [ServerMemory] { let target = selectedBackendTarget if target.mode == .localDaemon { + var queryItems: [String] = [ + "limit=\(limit)", + "offset=\(offset)", + ] + if let category = category { + queryItems.append("category=\(Self.queryValue(category))") + } + if let tags = tags, !tags.isEmpty { + queryItems.append("tags=\(Self.queryValue(tags.joined(separator: ",")))") + } + if includeDismissed { + queryItems.append("include_dismissed=true") + } let response: LocalMemoriesResponse = try await get( - "v1/memories", + "v1/memories?\(queryItems.joined(separator: "&"))", requireAuth: false, customBaseURL: target.baseURL ) @@ -1960,6 +1973,15 @@ extension APIClient { let memories: [MemoryBatchItem] } let body = BatchRequest(memories: memories) + let target = selectedBackendTarget + if target.mode == .localDaemon { + return try await post( + "v1/memories/batch", + body: body, + requireAuth: false, + customBaseURL: target.baseURL + ) + } return try await post("v3/memories/batch", body: body) } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift index d92548b3efe..05145321662 100644 --- 
a/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/MemoriesPage.swift @@ -178,6 +178,13 @@ class MemoriesViewModel: ObservableObject { private var currentOffset = 0 private let pageSize = 100 // Reduced from 500 for better performance + private func appendUniqueMemories(_ newMemories: [ServerMemory]) -> Int { + let existingIds = Set(memories.map(\.id)) + let uniqueMemories = newMemories.filter { !existingIds.contains($0.id) } + memories.append(contentsOf: uniqueMemories) + return uniqueMemories.count + } + // Bulk operations state @Published var showingDeleteAllConfirmation = false @Published var isBulkOperationInProgress = false @@ -658,11 +665,11 @@ class MemoriesViewModel: ObservableObject { ) if !moreFromCache.isEmpty { - memories.append(contentsOf: moreFromCache) + let appendedCount = appendUniqueMemories(moreFromCache) currentOffset += moreFromCache.count hasMoreMemories = moreFromCache.count >= pageSize log( - "MemoriesViewModel: Loaded \(moreFromCache.count) more from local cache (total: \(memories.count))" + "MemoriesViewModel: Loaded \(appendedCount) unique memories from local cache (total: \(memories.count))" ) isLoadingMore = false return @@ -680,10 +687,10 @@ class MemoriesViewModel: ObservableObject { try await MemoryStorage.shared.syncServerMemories(newMemories) // Then append to display - memories.append(contentsOf: newMemories) + let appendedCount = appendUniqueMemories(newMemories) currentOffset += newMemories.count hasMoreMemories = newMemories.count >= pageSize - log("MemoriesViewModel: Loaded \(newMemories.count) more from API (total: \(memories.count))") + log("MemoriesViewModel: Loaded \(appendedCount) unique memories from API (total: \(memories.count))") } catch { logError("Failed to load more memories", error: error) } diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 2fe30b1a0f2..f13eae53108 100644 --- 
a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -1031,6 +1031,57 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertTrue(requests[2].url.path.contains("/v1/memories/mem-local")) } + func testLocalModeGetMemoriesPreservesQueryParametersWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? await client.getMemories( + limit: 25, + offset: 50, + category: "manual", + tags: ["focus", "health"], + includeDismissed: true + ) + + let request = URLCapture.capturedRequests.first + XCTAssertEqual(URLCapture.capturedRequests.count, 1) + XCTAssertEqual(request?.url.host, "127.0.0.1") + XCTAssertEqual(request?.url.port, 8765) + XCTAssertEqual(request?.method, "GET") + XCTAssertEqual(request?.headers["Authorization"], nil) + XCTAssertEqual(request?.url.path, "/v1/memories") + let components = URLComponents(url: request!.url, resolvingAgainstBaseURL: false) + let query = Dictionary(uniqueKeysWithValues: (components?.queryItems ?? []).map { ($0.name, $0.value ?? "") }) + XCTAssertEqual(query["limit"], "25") + XCTAssertEqual(query["offset"], "50") + XCTAssertEqual(query["category"], "manual") + XCTAssertEqual(query["tags"], "focus,health") + XCTAssertEqual(query["include_dismissed"], "true") + } + + func testLocalModeCreateMemoriesBatchRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? 
await client.createMemoriesBatch([ + MemoryBatchItem(content: "local imported memory", tags: ["focus"], headline: "Focus") + ]) + + let request = URLCapture.capturedRequests.first + XCTAssertEqual(URLCapture.capturedRequests.count, 1) + XCTAssertEqual(request?.url.host, "127.0.0.1") + XCTAssertEqual(request?.url.port, 8765) + XCTAssertEqual(request?.url.path, "/v1/memories/batch") + XCTAssertEqual(request?.method, "POST") + XCTAssertEqual(request?.headers["Authorization"], nil) + let body = try? JSONSerialization.jsonObject(with: request?.body ?? Data()) as? [String: Any] + let memories = body?["memories"] as? [[String: Any]] + XCTAssertEqual(memories?.first?["content"] as? String, "local imported memory") + XCTAssertEqual(memories?.first?["tags"] as? [String], ["focus"]) + } + func testLocalModeCloudOnlyMemoryBulkOperationsFailBeforeNetworkRequests() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 5ace2206b26..71de6206bc7 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -584,6 +584,84 @@ mod tests { Ok(()) } + #[tokio::test] + async fn memory_list_supports_pagination_category_and_tags() -> Result<()> { + let app = test_app()?; + + let batch = request_json( + app.clone(), + Method::POST, + "/v1/memories/batch", + Some(json!({ + "memories": [ + {"content": "Manual focus memory", "tags": ["focus", "productivity"]}, + {"content": "Manual health memory", "tags": ["health"]}, + {"content": "Manual focus followup", "tags": ["focus"]} + ] + })), + ) + .await?; + assert_eq!(batch["created_count"], 3); + + request_json( + app.clone(), + Method::POST, + "/v1/memories", + Some(json!({ + "id": "mem-system", + "content": "System memory", + "category": "system", + "metadata": {"tags": ["focus"]} + })), + ) + .await?; + + let page_one = request_json( + app.clone(), + Method::GET, + 
"/v1/memories?limit=2&offset=0", + None, + ) + .await?; + let page_two = request_json( + app.clone(), + Method::GET, + "/v1/memories?limit=2&offset=2", + None, + ) + .await?; + let page_one_ids: Vec<&str> = page_one["memories"] + .as_array() + .unwrap() + .iter() + .map(|memory| memory["id"].as_str().unwrap()) + .collect(); + let page_two_ids: Vec<&str> = page_two["memories"] + .as_array() + .unwrap() + .iter() + .map(|memory| memory["id"].as_str().unwrap()) + .collect(); + assert_eq!(page_one_ids.len(), 2); + assert_eq!(page_two_ids.len(), 2); + assert!(page_one_ids.iter().all(|id| !page_two_ids.contains(id))); + + let system_focus = request_json( + app.clone(), + Method::GET, + "/v1/memories?category=system&tags=focus", + None, + ) + .await?; + assert_eq!(system_focus["memories"].as_array().unwrap().len(), 1); + assert_eq!(system_focus["memories"][0]["id"], "mem-system"); + + let focus = request_json(app, Method::GET, "/v1/memories?tags=focus", None).await?; + assert_eq!(focus["memories"].as_array().unwrap().len(), 3); + + Ok(()) + } + #[tokio::test] async fn settings_reject_omi_firebase_and_google_provider_hosts() -> Result<()> { let app = test_app()?; diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 3d439139970..7904353f82a 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -46,6 +46,7 @@ pub fn router() -> Router { ) .route("/v1/search/conversations", get(search_conversations)) .route("/v1/memories", get(list_memories).post(create_memory)) + .route("/v1/memories/batch", post(create_memories_batch)) .route( "/v1/memories/:id", get(get_memory).patch(update_memory).delete(delete_memory), @@ -430,8 +431,38 @@ async fn search_conversations( Ok(Json(json!({ "results": results }))) } -async fn list_memories(State(state): State) -> ApiResult { - let memories = state.store.memories().list().map_err(ApiError::internal)?; +#[derive(Deserialize)] +struct ListMemoriesQuery { + limit: 
Option, + offset: Option, + category: Option, + tags: Option, + #[serde(rename = "include_dismissed")] + _include_dismissed: Option, +} + +async fn list_memories( + State(state): State, + Query(query): Query, +) -> ApiResult { + let tags = query.tags.map(|tags| { + tags.split(',') + .map(str::trim) + .filter(|tag| !tag.is_empty()) + .map(ToString::to_string) + .collect::>() + }); + let memories = state + .store + .memories() + .list_filtered(crate::storage::MemoryListOptions { + limit: Some(limit_or_default(query.limit) as usize), + offset: query.offset.unwrap_or(0), + category: query.category, + tags: tags.unwrap_or_default(), + include_deleted: false, + }) + .map_err(ApiError::internal)?; Ok(Json(json!({ "memories": memories }))) } @@ -477,6 +508,59 @@ async fn create_memory( Ok(Json(json!({ "memory": memory }))) } +#[derive(Deserialize)] +struct CreateMemoriesBatchRequest { + memories: Vec, +} + +#[derive(Deserialize)] +struct CreateMemoryBatchItem { + content: String, + tags: Option>, + headline: Option, +} + +async fn create_memories_batch( + State(state): State, + Json(request): Json, +) -> ApiResult { + if request.memories.len() > 100 { + return Err(ApiError::bad_request("memory batch size exceeds 100")); + } + + let mut created = Vec::with_capacity(request.memories.len()); + for item in request.memories { + let mut metadata = Map::new(); + if let Some(tags) = item.tags { + metadata.insert("tags".to_string(), json!(tags)); + } + if let Some(headline) = item.headline { + metadata.insert("headline".to_string(), json!(headline)); + } + let memory = state + .store + .memories() + .create(NewMemory { + id: local_id("mem"), + content: item.content, + category: Some("manual".to_string()), + conversation_id: None, + metadata: Some(Value::Object(metadata)), + }) + .map_err(ApiError::internal)?; + created.push(json!({ + "id": memory.id, + "content": memory.content, + })); + } + + let created_count = created.len(); + Ok(Json(json!({ + "memories": created, + 
"created_count": created_count, + }))) +} + async fn get_memory(State(state): State, Path(id): Path) -> ApiResult { let memory = state .store diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index e8424ea183a..1f5d2801085 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -5,6 +5,7 @@ use anyhow::{Context, Result}; use chrono::{DateTime, Utc}; use rusqlite::{params, Connection, OptionalExtension}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use sha2::{Digest, Sha256}; const MIGRATIONS: &[Migration] = &[ @@ -1074,6 +1075,15 @@ pub struct MemoryRepository { conn: Arc>, } +#[derive(Debug, Clone, Default)] +pub struct MemoryListOptions { + pub limit: Option, + pub offset: usize, + pub category: Option, + pub tags: Vec, + pub include_deleted: bool, +} + impl MemoryRepository { pub fn create(&self, new: NewMemory) -> Result { let now = Utc::now(); @@ -1178,6 +1188,10 @@ impl MemoryRepository { } pub fn list(&self) -> Result> { + self.list_filtered(MemoryListOptions::default()) + } + + pub fn list_filtered(&self, options: MemoryListOptions) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let mut stmt = conn .prepare( @@ -1185,15 +1199,22 @@ impl MemoryRepository { SELECT id, content, category, conversation_id, created_at, updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json FROM memories - WHERE deleted_at IS NULL + WHERE (?1 OR deleted_at IS NULL) ORDER BY updated_at DESC "#, ) .context("failed to prepare memory list query")?; let rows = stmt - .query_map([], map_memory) + .query_map(params![options.include_deleted], map_memory) .context("failed to list memories")?; - collect_rows(rows) + let memories = collect_rows(rows)?; + let filtered = memories + .into_iter() + .filter(|memory| memory_matches_list_options(memory, &options)) + .skip(options.offset) + .take(options.limit.unwrap_or(usize::MAX)) + .collect(); + 
Ok(filtered) } pub fn update(&self, id: &str, update: UpdateMemory) -> Result> { @@ -1886,6 +1907,32 @@ fn map_memory(row: &rusqlite::Row<'_>) -> rusqlite::Result { }) } +fn memory_matches_list_options(memory: &Memory, options: &MemoryListOptions) -> bool { + if let Some(category) = &options.category { + if memory.category.as_deref() != Some(category.as_str()) { + return false; + } + } + + if options.tags.is_empty() { + return true; + } + + let Ok(metadata) = serde_json::from_str::(&memory.metadata_json) else { + return false; + }; + let Some(memory_tags) = metadata.get("tags").and_then(Value::as_array) else { + return false; + }; + + options.tags.iter().all(|requested_tag| { + memory_tags + .iter() + .filter_map(Value::as_str) + .any(|memory_tag| memory_tag == requested_tag) + }) +} + fn map_action_item(row: &rusqlite::Row<'_>) -> rusqlite::Result { Ok(ActionItem { id: row.get(0)?, From b3005d39e05baaafccc1381eec3095ed5bf5513e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 20:34:52 +0700 Subject: [PATCH 26/58] Close local mode proactive cloud leaks --- desktop/Desktop/Sources/APIClient.swift | 141 ++++++++++++++++++ .../MainWindow/Pages/PersonaPage.swift | 35 +++++ .../TaskExtraction/TaskAssistant.swift | 5 + .../Services/AIUserProfileService.swift | 126 +++++++++------- .../Rewind/Core/StagedTaskStorage.swift | 64 +++++++- .../Desktop/Tests/APIClientRoutingTests.swift | 96 ++++++++++++ 6 files changed, 408 insertions(+), 59 deletions(-) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 74990b61919..1eeab59c5a5 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -2239,6 +2239,11 @@ struct ActionItemsListResponse: Decodable { case hasMore = "has_more" } + init(items: [TaskActionItem], hasMore: Bool = false) { + self.items = items + self.hasMore = hasMore + } + init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: 
CodingKeys.self) if let actionItems = try container.decodeIfPresent([TaskActionItem].self, forKey: .actionItems) @@ -2667,17 +2672,64 @@ extension APIClient { relevanceScore: relevanceScore ) + if selectedBackendTarget.mode == .localDaemon { + let tagsJson: String? + if let tags = metadata?["tags"] as? [String], + let data = try? JSONEncoder().encode(tags), + let json = String(data: data, encoding: .utf8) + { + tagsJson = json + } else { + tagsJson = nil + } + let record = StagedTaskRecord( + backendId: "local_staged_\(UUID().uuidString)", + backendSynced: false, + description: description, + source: source, + priority: priority, + category: category, + tagsJson: tagsJson, + dueAt: dueAt, + confidence: metadata?["confidence"] as? Double, + sourceApp: metadata?["source_app"] as? String, + windowTitle: metadata?["window_title"] as? String, + contextSummary: metadata?["context_summary"] as? String, + currentActivity: metadata?["current_activity"] as? String, + metadataJson: metadataString, + relevanceScore: relevanceScore, + scoredAt: relevanceScore == nil ? 
nil : Date() + ) + let inserted: StagedTaskRecord + if relevanceScore != nil { + inserted = try await StagedTaskStorage.shared.insertWithScoreShift(record) + } else { + inserted = try await StagedTaskStorage.shared.insertLocalStagedTask(record) + } + return inserted.toTaskActionItem() + } + return try await post("v1/staged-tasks", body: request) } /// Fetches staged tasks ordered by relevance score func getStagedTasks(limit: Int = 100, offset: Int = 0) async throws -> ActionItemsListResponse { + if selectedBackendTarget.mode == .localDaemon { + let items = try await StagedTaskStorage.shared.getScoredStagedTasks(limit: limit, offset: offset) + return ActionItemsListResponse(items: items, hasMore: items.count == limit) + } + let params = "limit=\(limit)&offset=\(offset)" return try await get("v1/staged-tasks?\(params)") } /// Hard-deletes a staged task func deleteStagedTask(id: String) async throws { + if selectedBackendTarget.mode == .localDaemon { + try await StagedTaskStorage.shared.deleteByTaskId(id) + return + } + try await delete("v1/staged-tasks/\(id)") } @@ -2695,22 +2747,62 @@ extension APIClient { } let request = BatchRequest( scores: scores.map { ScoreUpdate(id: $0.id, relevance_score: $0.score) }) + if selectedBackendTarget.mode == .localDaemon { + try await StagedTaskStorage.shared.updateScores(scores) + return + } + let _: StatusResponse = try await patch("v1/staged-tasks/batch-scores", body: request) } /// Promotes the top-ranked staged task to action_items func promoteTopStagedTask() async throws -> PromoteResponse { + if selectedBackendTarget.mode == .localDaemon { + guard let staged = try await StagedTaskStorage.shared.promoteTopLocalStagedTask() else { + return PromoteResponse(promoted: false, reason: "no staged tasks", promotedTask: nil) + } + let promoted = TaskActionItem( + id: "local_action_\(UUID().uuidString)", + description: staged.description, + completed: false, + createdAt: staged.createdAt, + updatedAt: Date(), + dueAt: staged.dueAt, + 
conversationId: staged.conversationId, + source: staged.source, + priority: staged.priority, + metadata: staged.metadata, + category: staged.category, + deleted: false, + fromStaged: true, + relevanceScore: staged.relevanceScore, + contextSummary: staged.contextSummary, + currentActivity: staged.currentActivity + ) + return PromoteResponse(promoted: true, reason: nil, promotedTask: promoted) + } + return try await post("v1/staged-tasks/promote") } /// One-time migration of existing AI tasks to staged_tasks func migrateStagedTasks() async throws { + if selectedBackendTarget.mode == .localDaemon { + log("APIClient: staged-task backend migration skipped in local daemon mode") + return + } + struct StatusResponse: Decodable { let status: String } let _: StatusResponse = try await post("v1/staged-tasks/migrate") } /// Migrate conversation-extracted action items (no source field) to staged_tasks func migrateConversationItemsToStaged() async throws { + if selectedBackendTarget.mode == .localDaemon { + log("APIClient: conversation-to-staged backend migration skipped in local daemon mode") + return + } + struct MigrateResponse: Decodable { let status: String let migrated: Int @@ -4062,11 +4154,25 @@ extension APIClient { /// Fetches user's persona (if exists) func getPersona() async throws -> Persona? { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." + ) + } + return try await get("v1/personas") } /// Creates a new persona func createPersona(name: String, username: String? = nil) async throws -> Persona { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." + ) + } + struct CreateRequest: Encodable { let name: String let username: String? 
@@ -4082,6 +4188,13 @@ extension APIClient { personaPrompt: String? = nil, image: String? = nil ) async throws -> Persona { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." + ) + } + struct UpdateRequest: Encodable { let name: String? let description: String? @@ -4100,17 +4213,38 @@ extension APIClient { /// Deletes user's persona func deletePersona() async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." + ) + } + try await delete("v1/personas") } /// Regenerates persona prompt from current public memories func regeneratePersonaPrompt() async throws -> GeneratePromptResponse { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." + ) + } + struct EmptyRequest: Encodable {} return try await post("v1/personas/generate-prompt", body: EmptyRequest()) } /// Checks if a username is available func checkPersonaUsername(_ username: String) async throws -> UsernameAvailableResponse { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: "persona", + reason: "AI Persona is an Omi cloud account feature and is disabled in local daemon mode." 
+ ) + } + return try await get("v1/personas/check-username?username=\(username)") } } @@ -5479,6 +5613,13 @@ extension APIClient { /// Sync AI-generated user profile to backend func syncAIUserProfile(profileText: String, generatedAt: Date, dataSourcesUsed: Int) async throws { + if selectedBackendTarget.mode == .localDaemon { + throw APIError.featureUnavailable( + feature: DesktopBackendEnvironment.Capability.cloudSync.rawValue, + reason: "AI user profile cloud sync is disabled in local daemon mode." + ) + } + struct SyncRequest: Encodable { let profile_text: String let generated_at: String diff --git a/desktop/Desktop/Sources/MainWindow/Pages/PersonaPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/PersonaPage.swift index 94947243ccf..6680e1c71e9 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/PersonaPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/PersonaPage.swift @@ -36,6 +36,10 @@ struct PersonaPage: View { } } + private var isLocalDaemonMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + var body: some View { VStack(spacing: 0) { // Sheet header with close button @@ -520,6 +524,13 @@ struct PersonaPage: View { // MARK: - API Calls private func loadPersona() async { + guard !isLocalDaemonMode else { + persona = nil + errorMessage = "AI Persona is unavailable in local daemon mode." + isLoading = false + return + } + isLoading = true errorMessage = nil @@ -533,6 +544,11 @@ struct PersonaPage: View { } private func createPersona() async { + guard !isLocalDaemonMode else { + errorMessage = "AI Persona is unavailable in local daemon mode." + return + } + isCreating = true do { @@ -550,6 +566,11 @@ struct PersonaPage: View { } private func deletePersona() async { + guard !isLocalDaemonMode else { + errorMessage = "AI Persona is unavailable in local daemon mode." 
+ return + } + do { try await APIClient.shared.deletePersona() persona = nil @@ -560,6 +581,10 @@ struct PersonaPage: View { private func saveEdits() async { guard let currentPersona = persona else { return } + guard !isLocalDaemonMode else { + errorMessage = "AI Persona is unavailable in local daemon mode." + return + } do { let updated = try await APIClient.shared.updatePersona( @@ -574,6 +599,11 @@ struct PersonaPage: View { } private func regeneratePrompt() async { + guard !isLocalDaemonMode else { + errorMessage = "AI Persona is unavailable in local daemon mode." + return + } + isRegenerating = true do { @@ -589,6 +619,11 @@ struct PersonaPage: View { } private func checkUsername() async { + guard !isLocalDaemonMode else { + usernameAvailable = nil + return + } + guard newPersonaUsername.count >= 3 else { usernameAvailable = nil return diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 694186f80ee..90357260f76 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -485,6 +485,11 @@ actor TaskAssistant: ProactiveAssistant { /// Sync task to backend API, returns backend ID if successful private func syncTaskToBackend(task: ExtractedTask, taskResult: TaskExtractionResult, windowTitle: String? = nil) async -> String? 
{ + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { + log("Task: Skipped staged task backend sync in local daemon mode") + return nil + } + do { var metadata: [String: Any] = [ "source_app": task.sourceApp, diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift index 356728481f6..6492aa609c6 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AIUserProfileService.swift @@ -109,28 +109,32 @@ actor AIUserProfileService { arguments: [newText, id] ) } - // Sync updated profile to backend (fire-and-forget) - let text = newText - Task { - do { - // Fetch the record to get generatedAt and dataSourcesUsed - let record = try? await db.read { database in - try AIUserProfileRecord.fetchOne(database, key: id) - } - try await APIClient.shared.syncAIUserProfile( - profileText: text, - generatedAt: record?.generatedAt ?? Date(), - dataSourcesUsed: record?.dataSourcesUsed ?? 0 - ) - _ = try? await db.write { database in - try database.execute( - sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", - arguments: [id] + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + log("AIUserProfileService: skipped backend profile sync in local daemon mode") + } else { + // Sync updated profile to backend (fire-and-forget) + let text = newText + Task { + do { + // Fetch the record to get generatedAt and dataSourcesUsed + let record = try? await db.read { database in + try AIUserProfileRecord.fetchOne(database, key: id) + } + try await APIClient.shared.syncAIUserProfile( + profileText: text, + generatedAt: record?.generatedAt ?? Date(), + dataSourcesUsed: record?.dataSourcesUsed ?? 0 ) + _ = try? 
await db.write { database in + try database.execute( + sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", + arguments: [id] + ) + } + log("AIUserProfileService: Synced updated profile to backend") + } catch { + log("AIUserProfileService: Failed to sync updated profile to backend: \(error.localizedDescription)") } - log("AIUserProfileService: Synced updated profile to backend") - } catch { - log("AIUserProfileService: Failed to sync updated profile to backend: \(error.localizedDescription)") } } return true @@ -161,27 +165,31 @@ actor AIUserProfileService { } log("AIUserProfileService: Saved exploration as new profile (\(record.profileText.count) chars)") - // Sync to backend (fire-and-forget) - let profileText = record.profileText - let recordId = insertedId - Task { - do { - try await APIClient.shared.syncAIUserProfile( - profileText: profileText, - generatedAt: generatedAt, - dataSourcesUsed: 1 - ) - if let id = recordId, let db = try? await self.ensureDB() { - _ = try? await db.write { database in - try database.execute( - sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", - arguments: [id] - ) + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + log("AIUserProfileService: skipped backend profile sync in local daemon mode") + } else { + // Sync to backend (fire-and-forget) + let profileText = record.profileText + let recordId = insertedId + Task { + do { + try await APIClient.shared.syncAIUserProfile( + profileText: profileText, + generatedAt: generatedAt, + dataSourcesUsed: 1 + ) + if let id = recordId, let db = try? await self.ensureDB() { + _ = try? 
await db.write { database in + try database.execute( + sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", + arguments: [id] + ) + } } + log("AIUserProfileService: Synced exploration profile to backend") + } catch { + log("AIUserProfileService: Failed to sync exploration profile to backend: \(error.localizedDescription)") } - log("AIUserProfileService: Synced exploration profile to backend") - } catch { - log("AIUserProfileService: Failed to sync exploration profile to backend: \(error.localizedDescription)") } } return true @@ -334,26 +342,30 @@ actor AIUserProfileService { } // 7. Sync to backend (fire-and-forget) - let recordId = record.id - Task { - do { - try await APIClient.shared.syncAIUserProfile( - profileText: truncated, - generatedAt: generatedAt, - dataSourcesUsed: dataSourcesUsed - ) - // Mark as synced - if let id = recordId, let db = try? await self.ensureDB() { - _ = try? await db.write { database in - try database.execute( - sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", - arguments: [id] - ) + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + log("AIUserProfileService: skipped backend profile sync in local daemon mode") + } else { + let recordId = record.id + Task { + do { + try await APIClient.shared.syncAIUserProfile( + profileText: truncated, + generatedAt: generatedAt, + dataSourcesUsed: dataSourcesUsed + ) + // Mark as synced + if let id = recordId, let db = try? await self.ensureDB() { + _ = try? 
await db.write { database in + try database.execute( + sql: "UPDATE ai_user_profiles SET backendSynced = 1 WHERE id = ?", + arguments: [id] + ) + } } + log("AIUserProfileService: Synced profile to backend") + } catch { + log("AIUserProfileService: Failed to sync profile to backend: \(error.localizedDescription)") } - log("AIUserProfileService: Synced profile to backend") - } catch { - log("AIUserProfileService: Failed to sync profile to backend: \(error.localizedDescription)") } } diff --git a/desktop/Desktop/Sources/Rewind/Core/StagedTaskStorage.swift b/desktop/Desktop/Sources/Rewind/Core/StagedTaskStorage.swift index 1a8b3923a9e..71aefbed20b 100644 --- a/desktop/Desktop/Sources/Rewind/Core/StagedTaskStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/StagedTaskStorage.swift @@ -131,7 +131,7 @@ actor StagedTaskStorage { } /// Get staged tasks ordered by relevance score (best first) - func getScoredStagedTasks(limit: Int = 100) async throws -> [TaskActionItem] { + func getScoredStagedTasks(limit: Int = 100, offset: Int = 0) async throws -> [TaskActionItem] { let db = try await ensureInitialized() return try await db.read { database in @@ -139,7 +139,7 @@ actor StagedTaskStorage { .filter(Column("deleted") == false) .filter(Column("completed") == false) .order(sql: "COALESCE(relevanceScore, 999999) ASC") - .limit(limit) + .limit(limit, offset: offset) .fetchAll(database) return records.map { $0.toTaskActionItem() } @@ -193,6 +193,66 @@ actor StagedTaskStorage { log("StagedTaskStorage: Hard-deleted staged task with id \(id)") } + /// Hard-delete a staged task by API-facing ID. Accepts backend IDs and + /// local fallback IDs in the form "staged_". + func deleteByTaskId(_ taskId: String) async throws { + if taskId.hasPrefix("staged_"), + let localId = Int64(taskId.dropFirst("staged_".count)) { + try await deleteById(localId) + return + } + + try await deleteByBackendId(taskId) + } + + /// Update local relevance scores by API-facing ID. 
+ func updateScores(_ scores: [(id: String, score: Int)]) async throws { + guard !scores.isEmpty else { return } + let db = try await ensureInitialized() + let now = Date() + + try await db.write { database in + for score in scores { + if score.id.hasPrefix("staged_"), + let localId = Int64(score.id.dropFirst("staged_".count)) { + try database.execute( + sql: "UPDATE staged_tasks SET relevanceScore = ?, scoredAt = ?, updatedAt = ? WHERE id = ?", + arguments: [score.score, now, now, localId] + ) + } else { + try database.execute( + sql: "UPDATE staged_tasks SET relevanceScore = ?, scoredAt = ?, updatedAt = ? WHERE backendId = ?", + arguments: [score.score, now, now, score.id] + ) + } + } + } + } + + /// Remove and return the current top local staged task for promotion. + func promoteTopLocalStagedTask() async throws -> TaskActionItem? { + let db = try await ensureInitialized() + + return try await db.write { database in + guard let record = try StagedTaskRecord + .filter(Column("deleted") == false) + .filter(Column("completed") == false) + .order(sql: "COALESCE(relevanceScore, 999999) ASC, createdAt ASC") + .fetchOne(database) else { + return nil + } + + if let id = record.id { + try database.execute( + sql: "DELETE FROM staged_tasks WHERE id = ?", + arguments: [id] + ) + } + + return record.toTaskActionItem() + } + } + // MARK: - Re-ranking /// Apply selective re-ranking from Gemini response diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index f13eae53108..877eae33ac2 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -1227,6 +1227,29 @@ final class APIClientRoutingTests: XCTestCase { label: "getPersona") } + func testLocalModePersonaAPIsFailBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", 
"https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + var errors: [Error] = [] + do { _ = try await client.getPersona() as Persona? } catch { errors.append(error) } + do { _ = try await client.createPersona(name: "Local") } catch { errors.append(error) } + do { _ = try await client.updatePersona(name: "Updated") } catch { errors.append(error) } + do { try await client.deletePersona() } catch { errors.append(error) } + do { _ = try await client.regeneratePersonaPrompt() } catch { errors.append(error) } + do { _ = try await client.checkPersonaUsername("local") } catch { errors.append(error) } + + XCTAssertEqual(errors.count, 6) + XCTAssertTrue(errors.allSatisfy { + guard case APIError.featureUnavailable(let feature, _) = $0 else { return false } + return feature == "persona" + }) + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + // -- User settings (GET → Python) -- func testGetDailySummarySettingsRoutesToPython() async { @@ -1328,6 +1351,79 @@ final class APIClientRoutingTests: XCTestCase { label: "deleteStagedTask") } + func testLocalModeStagedTaskAPIsUseLocalStorageBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let testUserId = "api-client-routing-staged-\(UUID().uuidString)" + let testRoot = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) + .appendingPathComponent("omi-rewind-routing-\(UUID().uuidString)", isDirectory: true) + setenv("OMI_REWIND_DATABASE_ROOT", testRoot.path, 1) + await RewindDatabase.shared.close() + await RewindDatabase.shared.configure(userId: testUserId) + await StagedTaskStorage.shared.invalidateCache() + await ActionItemStorage.shared.invalidateCache() + let client = await 
makeTestClient() + + let first = try? await client.createStagedTask( + description: "local staged one", + source: "screenshot", + priority: "high", + category: "work", + metadata: ["source_app": "XCTest", "tags": ["work"]], + relevanceScore: 2 + ) + let second = try? await client.createStagedTask( + description: "local staged two", + source: "screenshot", + relevanceScore: 1 + ) + let listed = try? await client.getStagedTasks(limit: 10) + try? await client.batchUpdateStagedScores( + [first, second].compactMap { item in item.map { (id: $0.id, score: 5) } } + ) + if let first { + try? await client.deleteStagedTask(id: first.id) + } + let promoted = try? await client.promoteTopStagedTask() + try? await client.migrateStagedTasks() + try? await client.migrateConversationItemsToStaged() + + XCTAssertEqual(listed?.items.count, 2) + XCTAssertEqual(listed?.items.first?.description, "local staged two") + XCTAssertEqual(promoted?.promoted, true) + XCTAssertEqual(promoted?.promotedTask?.description, "local staged two") + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + + await RewindDatabase.shared.close() + await StagedTaskStorage.shared.invalidateCache() + await ActionItemStorage.shared.invalidateCache() + try? 
FileManager.default.removeItem(at: testRoot) + unsetenv("OMI_REWIND_DATABASE_ROOT") + } + + func testLocalModeAIUserProfileSyncFailsBeforeNetworkRequests() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_PYTHON_API_URL", "https://api.omi.me", 1) + setenv("OMI_DESKTOP_API_URL", "https://desktop-backend-hhibjajaja-uc.a.run.app", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + do { + try await client.syncAIUserProfile(profileText: "local profile", generatedAt: Date(), dataSourcesUsed: 1) + XCTFail("expected AI profile cloud sync to be unavailable") + } catch { + guard case APIError.featureUnavailable(let feature, _) = error else { + XCTFail("expected featureUnavailable for AI profile sync, got \(error)") + return + } + XCTAssertEqual(feature, DesktopBackendEnvironment.Capability.cloudSync.rawValue) + } + + XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + } + // -- Chat sessions (GET, POST, DELETE → Python, migrated from Rust) -- func testGetChatSessionsRoutesToPython() async { From f16e16f41186c150e24a126fa2d4d01d7141ac54 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 19:39:56 -0400 Subject: [PATCH 27/58] Add hybrid provider parity for local daemon desktop mode. Desktop routes chat, embeddings, vision, and Apple Speech STT through configurable local providers; the daemon gains folders, merge, chat sessions, and provider test/seed tooling for end-to-end hybrid dev. 
Co-authored-by: Cursor --- Makefile | 19 + desktop/Desktop/Info.plist | 2 + desktop/Desktop/Sources/APIClient.swift | 216 +++- desktop/Desktop/Sources/AppState.swift | 44 +- desktop/Desktop/Sources/AuthService.swift | 28 + .../Sources/DesktopBackendEnvironment.swift | 48 +- .../Desktop/Sources/HybridChatClient.swift | 221 ++++ .../Sources/HybridEmbeddingClient.swift | 140 +++ desktop/Desktop/Sources/HybridLLMClient.swift | 490 ++++++++ .../Sources/HybridProviderBootstrap.swift | 48 + .../Sources/HybridProviderReadiness.swift | 120 ++ .../Sources/HybridVisionProvider.swift | 22 + .../LocalSpeechTranscriptionAdapter.swift | 252 ++++ .../Sources/MainWindow/DesktopHomeView.swift | 5 +- .../MainWindow/Pages/ConversationsPage.swift | 49 +- .../MainWindow/Pages/SettingsPage.swift | 474 +++++++- .../Assistants/Insight/InsightAssistant.swift | 152 ++- .../Core/GeminiClient.swift | 246 +++- .../Services/EmbeddingService.swift | 15 + .../Sources/Providers/ChatProvider.swift | 254 ++-- .../Sources/Rewind/Core/RewindDatabase.swift | 12 + .../Rewind/Services/OCREmbeddingService.swift | 40 + .../Sources/Rewind/UI/RewindViewModel.swift | 2 +- desktop/Desktop/Sources/SignInView.swift | 39 +- .../Sources/TranscriptionService.swift | 164 ++- .../Desktop/Tests/APIClientRoutingTests.swift | 286 ++++- .../Tests/HybridEmbeddingClientTests.swift | 39 + .../Tests/HybridLLMProviderConfigTests.swift | 52 + .../Tests/HybridVisionProviderTests.swift | 47 + .../LocalSpeechTranscriptionLocaleTests.swift | 55 + ...LocalSpeechTranscriptionMappingTests.swift | 38 + desktop/README.md | 3 +- desktop/local-backend/docs/architecture.md | 6 +- .../docs/hybrid-embedding-versioning.md | 37 + .../docs/hybrid-provider-settings.md | 74 ++ .../local-backend/docs/local-mvp-runbook.md | 84 +- desktop/local-backend/src/main.rs | 200 +++ desktop/local-backend/src/processing.rs | 1 + desktop/local-backend/src/providers.rs | 98 +- desktop/local-backend/src/routes.rs | 386 +++++- desktop/local-backend/src/storage.rs 
| 1069 ++++++++++++++++- .../tools/seed_hybrid_defaults.sh | 91 ++ desktop/run.sh | 37 + 43 files changed, 5400 insertions(+), 305 deletions(-) create mode 100644 Makefile create mode 100644 desktop/Desktop/Sources/HybridChatClient.swift create mode 100644 desktop/Desktop/Sources/HybridEmbeddingClient.swift create mode 100644 desktop/Desktop/Sources/HybridLLMClient.swift create mode 100644 desktop/Desktop/Sources/HybridProviderBootstrap.swift create mode 100644 desktop/Desktop/Sources/HybridProviderReadiness.swift create mode 100644 desktop/Desktop/Sources/HybridVisionProvider.swift create mode 100644 desktop/Desktop/Sources/LocalSpeechTranscriptionAdapter.swift create mode 100644 desktop/Desktop/Tests/HybridEmbeddingClientTests.swift create mode 100644 desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift create mode 100644 desktop/Desktop/Tests/HybridVisionProviderTests.swift create mode 100644 desktop/Desktop/Tests/LocalSpeechTranscriptionLocaleTests.swift create mode 100644 desktop/Desktop/Tests/LocalSpeechTranscriptionMappingTests.swift create mode 100644 desktop/local-backend/docs/hybrid-embedding-versioning.md create mode 100644 desktop/local-backend/docs/hybrid-provider-settings.md create mode 100755 desktop/local-backend/tools/seed_hybrid_defaults.sh diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000..0e49da40ed6 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +# Hybrid local development (desktop local daemon + Omi Dev app). 
+# See desktop/local-backend/docs/local-mvp-runbook.md + +.PHONY: help serve-local down-local + +help: + @echo "Hybrid local development targets:" + @echo " make serve-local Start omi-local-backend + Omi Dev (tmux when available)" + @echo " make down-local Stop tmux session, daemon, and dev desktop app" + @echo "" + @echo "Optional env: OMI_LOCAL_DAEMON_URL, OMI_LOCAL_BACKEND_DATA_DIR," + @echo " OMI_HYBRID_LOCAL_TMUX_SESSION (default: omi-hybrid-local)" + @echo " OMI_HYBRID_LOCAL_ATTACH=0 start tmux detached" + +serve-local: + @bash "$(CURDIR)/scripts/hybrid-local.sh" up + +down-local: + @bash "$(CURDIR)/scripts/hybrid-local.sh" down diff --git a/desktop/Desktop/Info.plist b/desktop/Desktop/Info.plist index 6ec015c2e19..3a5b0503228 100644 --- a/desktop/Desktop/Info.plist +++ b/desktop/Desktop/Info.plist @@ -39,6 +39,8 @@ Omi needs permission to detect which application is currently active. NSMicrophoneUsageDescription Omi needs microphone access to transcribe your conversations in real-time. + NSSpeechRecognitionUsageDescription + Omi uses speech recognition to transcribe your voice while local conversations use on-device daemon mode. NSAudioCaptureUsageDescription Omi needs permission to capture system audio for transcription of calls and meetings. NSBluetoothAlwaysUsageDescription diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 1eeab59c5a5..717125e61d1 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -28,6 +28,9 @@ actor APIClient { selectedBackendTarget.mode == .localDaemon } + /// Matches `LOCAL_DEFAULT_CHAT_SESSION_ID` in `desktop/local-backend` (implicit default-chat thread). 
+ static let localDaemonDefaultChatSessionId = "00000000-0000-4000-8000-000000000001" + let session: URLSession private let decoder: JSONDecoder @@ -241,6 +244,22 @@ actor APIClient { return response.settings } + func testHybridProvider(key: String) async throws -> LocalDaemonTestProviderResponse { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.featureUnavailable( + feature: "test_hybrid_provider", + reason: "Provider tests are only available in local daemon mode." + ) + } + return try await post( + "v1/settings/test-provider", + body: LocalDaemonTestProviderRequest(key: key), + requireAuth: false, + customBaseURL: target.baseURL + ) + } + // MARK: - Request Execution private func performRequest(_ request: URLRequest, retryOnUnauthorized: Bool) @@ -357,6 +376,16 @@ struct LocalDaemonSettingsResponse: Decodable { let settings: [LocalDaemonSetting] } +struct LocalDaemonTestProviderRequest: Encodable { + let key: String +} + +struct LocalDaemonTestProviderResponse: Decodable, Equatable { + let ok: Bool + let key: String + let message: String +} + struct LocalDaemonSetting: Decodable, Equatable { let key: String let valueJson: String @@ -472,6 +501,9 @@ extension APIClient { if let starred = starred { queryItems.append("starred=\(starred)") } + if let folderId = folderId { + queryItems.append("folder_id=\(Self.queryValue(folderId))") + } let endpoint = "v1/conversations?\(queryItems.joined(separator: "&"))" let response: LocalConversationsResponse = try await get( endpoint, @@ -834,13 +866,6 @@ extension APIClient { func mergeConversations(ids: [String], reprocess: Bool = true) async throws -> MergeConversationsResponse { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_merge", - reason: "Conversation merge is not implemented in local daemon mode yet." 
- ) - } - struct MergeRequest: Encodable { let conversationIds: [String] let reprocess: Bool @@ -851,6 +876,17 @@ extension APIClient { } } + let target = mvpBackendTarget + if target.mode == .localDaemon { + let body = MergeRequest(conversationIds: ids, reprocess: reprocess) + return try await post( + "v1/conversations/merge", + body: body, + requireAuth: false, + customBaseURL: target.baseURL + ) + } + let body = MergeRequest(conversationIds: ids, reprocess: reprocess) return try await post("v1/conversations/merge", body: body) } @@ -859,11 +895,17 @@ extension APIClient { /// Gets all folders for the user func getFolders() async throws -> [Folder] { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_folders", - reason: "Conversation folders are not implemented in local daemon mode yet." + let target = mvpBackendTarget + if target.mode == .localDaemon { + struct FoldersPayload: Decodable { + let folders: [Folder] + } + let response: FoldersPayload = try await get( + "v1/conversation-folders", + requireAuth: false, + customBaseURL: target.baseURL ) + return response.folders } return try await get("v1/folders") @@ -873,14 +915,21 @@ extension APIClient { func createFolder(name: String, description: String? = nil, color: String? = nil) async throws -> Folder { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_folders", - reason: "Conversation folders are not implemented in local daemon mode yet." 
+ let target = mvpBackendTarget + let body = CreateFolderRequest(name: name, description: description, color: color) + if target.mode == .localDaemon { + struct FolderPayload: Decodable { + let folder: Folder + } + let response: FolderPayload = try await post( + "v1/conversation-folders", + body: body, + requireAuth: false, + customBaseURL: target.baseURL ) + return response.folder } - let body = CreateFolderRequest(name: name, description: description, color: color) return try await post("v1/folders", body: body) } @@ -889,40 +938,60 @@ extension APIClient { id: String, name: String? = nil, description: String? = nil, color: String? = nil, order: Int? = nil ) async throws -> Folder { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_folders", - reason: "Conversation folders are not implemented in local daemon mode yet." + let target = mvpBackendTarget + let body = UpdateFolderRequest(name: name, description: description, color: color, order: order) + if target.mode == .localDaemon { + struct FolderPayload: Decodable { + let folder: Folder + } + let response: FolderPayload = try await patch( + "v1/conversation-folders/\(id)", + body: body, + requireAuth: false, + customBaseURL: target.baseURL ) + return response.folder } - let body = UpdateFolderRequest(name: name, description: description, color: color, order: order) return try await patch("v1/folders/\(id)", body: body) } /// Deletes a folder func deleteFolder(id: String, moveToFolderId: String? = nil) async throws { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_folders", - reason: "Conversation folders are not implemented in local daemon mode yet." 
- ) + let target = mvpBackendTarget + if target.mode == .localDaemon { + var endpoint = "v1/conversation-folders/\(id)" + if let moveToId = moveToFolderId { + endpoint += "?move_to_folder_id=\(Self.queryValue(moveToId))" + } + try await delete(endpoint, requireAuth: false, customBaseURL: target.baseURL) + return } var endpoint = "v1/folders/\(id)" if let moveToId = moveToFolderId { - endpoint += "?move_to_folder_id=\(moveToId)" + endpoint += "?move_to_folder_id=\(Self.queryValue(moveToId))" } try await delete(endpoint) } /// Moves a conversation to a folder func moveConversationToFolder(conversationId: String, folderId: String?) async throws { - if mvpBackendTarget.mode == .localDaemon { - throw APIError.featureUnavailable( - feature: "conversation_folders", - reason: "Conversation folders are not implemented in local daemon mode yet." + let target = mvpBackendTarget + if target.mode == .localDaemon { + struct FolderIdBody: Encodable { + let folderId: String? + enum CodingKeys: String, CodingKey { + case folderId = "folder_id" + } + } + let _: LocalConversationEnvelope = try await patch( + "v1/conversations/\(conversationId)", + body: FolderIdBody(folderId: folderId), + requireAuth: false, + customBaseURL: target.baseURL ) + return } let body = MoveToFolderRequest(folderId: folderId) @@ -1427,6 +1496,7 @@ private struct LocalConversation: Decodable { let createdAt: Date let deletedAt: Date? let starred: Bool + let folderId: String? 
enum CodingKeys: String, CodingKey { case id, title, overview, status, starred @@ -1435,6 +1505,7 @@ private struct LocalConversation: Decodable { case endedAt = "ended_at" case createdAt = "created_at" case deletedAt = "deleted_at" + case folderId = "folder_id" } func toServerConversation(transcriptSegments: [TranscriptSegment]) -> ServerConversation { @@ -1462,7 +1533,7 @@ private struct LocalConversation: Decodable { deleted: deletedAt != nil, isLocked: false, starred: starred, - folderId: nil, + folderId: folderId, inputDeviceName: nil ) } @@ -3677,6 +3748,8 @@ struct ScoreResponse: Codable { static func emptyLocal(date: Date? = nil) -> ScoreResponse { let empty = ScoreData(score: 0, completedTasks: 0, totalTasks: 0) let formatter = DateFormatter() + formatter.calendar = Calendar(identifier: .gregorian) + formatter.timeZone = TimeZone(secondsFromGMT: 0) formatter.dateFormat = "yyyy-MM-dd" return ScoreResponse( daily: empty, @@ -5233,14 +5306,33 @@ extension APIClient { sessionId: String? = nil, metadata: String? = nil ) async throws -> SaveMessageResponse { - struct SaveRequest: Encodable { + let target = selectedBackendTarget + if target.mode == .localDaemon { + struct LocalSaveRequest: Encodable { + let text: String + let sender: String + let app_id: String? + let metadata: String? + } + let sid = sessionId ?? Self.localDaemonDefaultChatSessionId + let body = LocalSaveRequest( + text: text, sender: sender, app_id: appId, metadata: metadata) + return try await post( + "v2/chat-sessions/\(sid)/messages", + body: body, + requireAuth: false, + customBaseURL: target.baseURL + ) + } + + struct CloudSaveRequest: Encodable { let text: String let sender: String let app_id: String? let session_id: String? let metadata: String? 
} - let body = SaveRequest( + let body = CloudSaveRequest( text: text, sender: sender, app_id: appId, session_id: sessionId, metadata: metadata) return try await post("v2/desktop/messages", body: body) } @@ -5251,6 +5343,20 @@ extension APIClient { limit: Int = 100, offset: Int = 0 ) async throws -> [ChatMessageDB] { + let target = selectedBackendTarget + if target.mode == .localDaemon { + var queryItems: [String] = [ + "limit=\(limit)", + "offset=\(offset)", + ] + if let appId = appId { + queryItems.append("app_id=\(appId)") + } + let sid = Self.localDaemonDefaultChatSessionId + let endpoint = "v2/chat-sessions/\(sid)/messages?\(queryItems.joined(separator: "&"))" + return try await get(endpoint, requireAuth: false, customBaseURL: target.baseURL) + } + var queryItems: [String] = [ "limit=\(limit)", "offset=\(offset)", @@ -5299,13 +5405,20 @@ extension APIClient { limit: Int = 100, offset: Int = 0 ) async throws -> [ChatMessageDB] { + let target = selectedBackendTarget let queryItems: [String] = [ - "session_id=\(sessionId)", "limit=\(limit)", "offset=\(offset)", ] - let endpoint = "v2/desktop/messages?\(queryItems.joined(separator: "&"))" + if target.mode == .localDaemon { + let endpoint = + "v2/chat-sessions/\(sessionId)/messages?\(queryItems.joined(separator: "&"))" + return try await get(endpoint, requireAuth: false, customBaseURL: target.baseURL) + } + + let endpoint = + "v2/desktop/messages?session_id=\(sessionId)&\(queryItems.joined(separator: "&"))" return try await get(endpoint) } @@ -5424,6 +5537,15 @@ extension APIClient { let app_id: String? 
} let body = CreateRequest(title: title, app_id: appId) + let target = selectedBackendTarget + if target.mode == .localDaemon { + return try await post( + "v2/chat-sessions", + body: body, + requireAuth: false, + customBaseURL: target.baseURL + ) + } return try await post("v2/chat-sessions", body: body) } @@ -5447,6 +5569,10 @@ extension APIClient { } let endpoint = "v2/chat-sessions?\(queryItems.joined(separator: "&"))" + let target = selectedBackendTarget + if target.mode == .localDaemon { + return try await get(endpoint, requireAuth: false, customBaseURL: target.baseURL) + } return try await get(endpoint) } @@ -5461,11 +5587,29 @@ extension APIClient { let starred: Bool? } let body = UpdateRequest(title: title, starred: starred) + let target = selectedBackendTarget + if target.mode == .localDaemon { + return try await patch( + "v2/chat-sessions/\(sessionId)", + body: body, + requireAuth: false, + customBaseURL: target.baseURL + ) + } return try await patch("v2/chat-sessions/\(sessionId)", body: body) } /// Delete a chat session and its messages func deleteChatSession(sessionId: String) async throws { + let target = selectedBackendTarget + if target.mode == .localDaemon { + try await delete( + "v2/chat-sessions/\(sessionId)", + requireAuth: false, + customBaseURL: target.baseURL + ) + return + } try await delete("v2/chat-sessions/\(sessionId)") } diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 890fa05f330..c9e73445189 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1640,7 +1640,38 @@ class AppState: ObservableObject { // After WS close, the Python backend processes the conversation automatically. // Call force-process to ensure finalization and get the backend conversation ID. // This prevents the retry service from picking up the pendingUpload session. 
+ // Post-stop: cloud needs time for WS teardown + force-process; local daemon uploads assembled SQLite transcripts. Task { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + guard let sid = capturedSessionId else { + log( + "Transcription: Local daemon stopped with no persisted session id — nothing to upload" + ) + return + } + // Allow SQLite segment upserts and Apple Speech finals to settle before upload. + try? await Task.sleep(nanoseconds: 550_000_000) + + guard self.recordingGeneration == generationAtStop else { + log( + "Transcription: Skipping local daemon finalize — new recording started during upload delay" + ) + return + } + + do { + _ = try await TranscriptionRetryService.shared.finalizeLocalDaemonSessionNow( + sessionId: sid) + log( + "Transcription: Local daemon finalized stop session \(sid) via upload pipeline (or queued retry)" + ) + } catch { + logError("Transcription: Local daemon finalize-after-stop failed for \(sid)", error: error) + } + await loadConversations() + return + } + try? await Task.sleep(nanoseconds: 3_000_000_000) // 3s for backend to process after WS close // If a new recording started during the delay, skip force-process — it would @@ -1653,13 +1684,6 @@ class AppState: ObservableObject { return } - if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { - log( - "Transcription: Local daemon mode stopped capture; leaving session \(capturedSessionId.map(String.init) ?? 
"nil") for local transcript retry/finalize" - ) - return - } - do { if let conversation = try await APIClient.shared.forceProcessConversation() { // Validate the returned conversation matches the session we just stopped @@ -2232,12 +2256,6 @@ class AppState: ObservableObject { func loadFolders() async { guard !isLoadingFolders else { return } - if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { - folders = [] - selectedFolderId = nil - return - } - isLoadingFolders = true do { diff --git a/desktop/Desktop/Sources/AuthService.swift b/desktop/Desktop/Sources/AuthService.swift index 9b1f20b8805..f36386cb51a 100644 --- a/desktop/Desktop/Sources/AuthService.swift +++ b/desktop/Desktop/Sources/AuthService.swift @@ -50,6 +50,10 @@ class AuthService { private var isLocalDaemonMode: Bool { DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon } + + /// Stable identity for hybrid local daemon mode (no Firebase / OAuth). + static let localGuestUserId = "local-hybrid-guest" + static let localGuestEmail = "local@omi.local" private var redirectURI: String { return "\(urlScheme)://auth/callback" } @@ -137,6 +141,7 @@ class AuthService { isConfigured = true restoreAuthState() setupAuthStateListener() + establishLocalGuestSessionIfNeeded() // Timeout: if auth isn't restored within 5 seconds, stop showing loading DispatchQueue.main.asyncAfter(deadline: .now() + 5.0) { @@ -147,6 +152,29 @@ class AuthService { } } + /// Hybrid local daemon mode does not use the Python OAuth backend or Firebase for data APIs. + /// Enter a stable on-device guest session so the main UI is reachable without cloud sign-in. 
+ @MainActor + func establishLocalGuestSessionIfNeeded() { + guard isLocalDaemonMode else { return } + guard !isSignedIn else { + isLoading = false + AuthState.shared.isRestoringAuth = false + return + } + + NSLog("OMI AUTH: Local daemon mode — establishing offline guest session (no cloud login)") + isSignedIn = true + isLoading = false + AuthState.shared.isRestoringAuth = false + AuthState.shared.userEmail = Self.localGuestEmail + saveAuthState(isSignedIn: true, email: Self.localGuestEmail, userId: Self.localGuestUserId) + Task { + await RewindDatabase.shared.configure(userId: Self.localGuestUserId) + await HybridProviderBootstrap.ensureDefaultsIfNeeded() + } + } + // MARK: - Auth Persistence (UserDefaults for dev builds) private func saveAuthState(isSignedIn: Bool, email: String?, userId: String?) { diff --git a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift index 0c55fa858ef..f8ec41c38a9 100644 --- a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift +++ b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift @@ -10,6 +10,11 @@ enum DesktopBackendEnvironment { enum Capability: String, CaseIterable, Equatable { case localConversationData case firebaseSignIn + case directSTT + case directChat + case directEmbeddings + case optionalCloudSTT + case optionalCloudChat case managedAgentVM case omiBackendProviderProxy case publicSharing @@ -182,10 +187,24 @@ enum DesktopBackendEnvironment { } switch capability { - case .localConversationData: - return true - case .firebaseSignIn: + case .localConversationData, .firebaseSignIn: return true + case .directSTT: + if hybridDirectSTTExplicitlyDisabled() { + return false + } + if isAffirmative(currentEnvironmentValue("OMI_HYBRID_DIRECT_STT_ENABLED")) { + return true + } + return LocalSpeechTranscriptionAdapter.isRecognitionEngineAvailableForPreferredSystemLanguages() + case .directChat: + return 
isAffirmative(currentEnvironmentValue("OMI_HYBRID_DIRECT_CHAT_ENABLED")) + case .directEmbeddings: + return isAffirmative(currentEnvironmentValue("OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED")) + case .optionalCloudSTT: + return isAffirmative(currentEnvironmentValue("OMI_HYBRID_OPTIONAL_CLOUD_STT")) + case .optionalCloudChat: + return isAffirmative(currentEnvironmentValue("OMI_HYBRID_OPTIONAL_CLOUD_CHAT")) case .managedAgentVM, .omiBackendProviderProxy, .publicSharing, @@ -203,6 +222,22 @@ enum DesktopBackendEnvironment { } switch capability { + case .directSTT: + if hybridDirectSTTExplicitlyDisabled() { + return "Direct local speech-to-text is disabled (OMI_HYBRID_DIRECT_STT_ENABLED is off)." + } + return + "Apple Speech is not available for this Mac’s preferred languages (or Speech Recognition is off in System Settings). Set OMI_HYBRID_DIRECT_STT_ENABLED=1 to opt in when the engine is available." + case .directChat: + return + "Direct local chat requires OMI_HYBRID_DIRECT_CHAT_ENABLED=1 and a chat_provider in hybrid settings." + case .directEmbeddings: + return + "Direct local embeddings require OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1 and an embedding_provider in hybrid settings." + case .optionalCloudSTT: + return "Optional cloud speech-to-text is off. Set OMI_HYBRID_OPTIONAL_CLOUD_STT=1 to allow hosted Listen." + case .optionalCloudChat: + return "Optional cloud chat is off. Set OMI_HYBRID_OPTIONAL_CLOUD_CHAT=1 to allow Omi-hosted chat." case .managedAgentVM: return "Managed agent VMs are cloud-only and are disabled in local daemon mode." case .omiBackendProviderProxy: @@ -258,4 +293,11 @@ enum DesktopBackendEnvironment { let normalized = value.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() return normalized == "1" || normalized == "true" || normalized == "yes" } + + /// Treats `OMI_HYBRID_DIRECT_STT_ENABLED=0|false|no|off` as an explicit hybrid direct-STT kill switch. 
+ private static func hybridDirectSTTExplicitlyDisabled() -> Bool { + guard let raw = currentEnvironmentValue("OMI_HYBRID_DIRECT_STT_ENABLED") else { return false } + let n = raw.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() + return n == "0" || n == "false" || n == "no" || n == "off" + } } diff --git a/desktop/Desktop/Sources/HybridChatClient.swift b/desktop/Desktop/Sources/HybridChatClient.swift new file mode 100644 index 00000000000..5b4a830827a --- /dev/null +++ b/desktop/Desktop/Sources/HybridChatClient.swift @@ -0,0 +1,221 @@ +import Foundation + +/// Direct OpenAI-compatible chat completions for hybrid local daemon mode (no pi-mono proxy). +enum HybridChatClient { + + struct ProviderConfig: Equatable { + let baseURL: String + let model: String + let apiKey: String + } + + struct CompletionResult: Equatable { + let text: String + let model: String + let inputTokens: Int + let outputTokens: Int + } + + enum ClientError: LocalizedError { + case notConfigured + case invalidSettings + case invalidResponse + + var errorDescription: String? { + switch self { + case .notConfigured: + return + "Hybrid direct chat is not configured. Set chat_provider or ai_provider in Settings → Plan and Usage (or run a local LLM at the default Ollama URL)." + case .invalidSettings: + return "chat_provider settings are invalid." + case .invalidResponse: + return "Chat provider returned an unexpected response." + } + } + } + + static func isEnabled() -> Bool { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return false + } + guard DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon) else { + return false + } + return true + } + + /// Resolves chat_provider → ai_provider / provider (matches HybridLLMClient). + static func resolveEffectiveChatConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? 
{ + if let chat = loadProviderConfig(from: settings, key: "chat_provider") { + return chat + } + if let ai = loadProviderConfig(from: settings, keys: ["ai_provider", "provider"]) { + return ai + } + return byokOpenAIConfig() + } + + static func loadProviderConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? { + loadProviderConfig(from: settings, key: "chat_provider") + } + + private static func loadProviderConfig( + from settings: [LocalDaemonSetting], + key: String + ) -> ProviderConfig? { + loadProviderConfig(from: settings, keys: [key]) + } + + private static func loadProviderConfig( + from settings: [LocalDaemonSetting], + keys: [String] + ) -> ProviderConfig? { + guard let raw = settings.first(where: { keys.contains($0.key) })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return nil + } + return parseOpenAICompatible(json: json) + } + + private static func parseOpenAICompatible(json: [String: Any]) -> ProviderConfig? { + let kind = (json["kind"] as? String)?.lowercased() ?? "" + guard kind == "openai_compatible" || kind == "openai" else { + return nil + } + guard let baseURL = json["base_url"] as? String, !baseURL.isEmpty else { + return nil + } + let model = (json["model"] as? String) ?? HybridProviderReadiness.defaultModel() + let apiKey = + (json["api_key"] as? String) ?? (json["key"] as? String) ?? "" + return ProviderConfig(baseURL: baseURL, model: model, apiKey: apiKey) + } + + private static func byokOpenAIConfig() -> ProviderConfig? { + guard let key = APIKeyService.byokKey(.openai), !key.isEmpty else { + return nil + } + let model = + ProcessInfo.processInfo.environment["OMI_HYBRID_BYOK_OPENAI_MODEL"].flatMap { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + }.flatMap { $0.isEmpty ? nil : $0 } ?? 
"gpt-4o-mini" + return ProviderConfig(baseURL: "https://api.openai.com/v1", model: model, apiKey: key) + } + + /// Loads daemon hybrid settings and completes one chat turn (non-streaming). + static func completeFromDaemonSettings( + systemPrompt: String, + conversationMessages: [(role: String, text: String)], + userMessage: String + ) async throws -> CompletionResult { + let settings = try await APIClient.shared.getSelectedBackendSettings() + return try await complete( + systemPrompt: systemPrompt, + conversationMessages: conversationMessages, + userMessage: userMessage, + settings: settings + ) + } + + static func complete( + systemPrompt: String, + conversationMessages: [(role: String, text: String)], + userMessage: String, + settings: [LocalDaemonSetting] + ) async throws -> CompletionResult { + guard let config = resolveEffectiveChatConfig(from: settings) else { + throw ClientError.notConfigured + } + return try await completeOpenAICompatible( + config: config, + systemPrompt: systemPrompt, + conversationMessages: conversationMessages, + userMessage: userMessage + ) + } + + private struct ChatCompletionMessage: Encodable { + let role: String + let content: String + } + + private struct ChatCompletionRequest: Encodable { + let model: String + let messages: [ChatCompletionMessage] + let temperature: Double + } + + private static func completeOpenAICompatible( + config: ProviderConfig, + systemPrompt: String, + conversationMessages: [(role: String, text: String)], + userMessage: String + ) async throws -> CompletionResult { + let base = config.baseURL.hasSuffix("/") ? 
String(config.baseURL.dropLast()) : config.baseURL + guard let url = URL(string: "\(base)/chat/completions") else { + throw ClientError.invalidSettings + } + + var apiMessages: [ChatCompletionMessage] = [ + ChatCompletionMessage(role: "system", content: systemPrompt) + ] + for turn in conversationMessages { + apiMessages.append(ChatCompletionMessage(role: turn.role, content: turn.text)) + } + apiMessages.append(ChatCompletionMessage(role: "user", content: userMessage)) + + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !config.apiKey.isEmpty { + request.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization") + } + request.timeoutInterval = 120 + let payload = ChatCompletionRequest( + model: config.model, + messages: apiMessages, + temperature: 0.2 + ) + request.httpBody = try JSONEncoder().encode(payload) + + let (data, response) = try await URLSession.shared.data(for: request) + guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else { + throw ClientError.invalidResponse + } + guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], + let choices = json["choices"] as? [[String: Any]], + let first = choices.first, + let msg = first["message"] as? [String: Any] + else { + throw ClientError.invalidResponse + } + + let content: String + if let str = msg["content"] as? String { + content = str + } else if let parts = msg["content"] as? [[String: Any]] { + let texts = parts.compactMap { $0["text"] as? String } + content = texts.joined(separator: "\n") + } else { + throw ClientError.invalidResponse + } + + let returnedModel = (json["model"] as? String) ?? config.model + var inputTokens = 0 + var outputTokens = 0 + if let usage = json["usage"] as? [String: Any] { + inputTokens = usage["prompt_tokens"] as? Int ?? usage["input_tokens"] as? Int ?? 0 + outputTokens = + usage["completion_tokens"] as? Int ?? 
usage["output_tokens"] as? Int ?? 0 + } + + return CompletionResult( + text: content.trimmingCharacters(in: .whitespacesAndNewlines), + model: returnedModel, + inputTokens: inputTokens, + outputTokens: outputTokens + ) + } +} diff --git a/desktop/Desktop/Sources/HybridEmbeddingClient.swift b/desktop/Desktop/Sources/HybridEmbeddingClient.swift new file mode 100644 index 00000000000..b420c7d7020 --- /dev/null +++ b/desktop/Desktop/Sources/HybridEmbeddingClient.swift @@ -0,0 +1,140 @@ +import Foundation + +/// Direct embedding client for hybrid mode (no Omi Gemini proxy). +enum HybridEmbeddingClient { + static let legacyGeminiModelId = "gemini-embedding-001" + static let legacyGeminiDimension = 3072 + + struct ProviderConfig: Equatable { + let baseURL: String + let model: String + let apiKey: String + } + + struct EmbeddingResult: Equatable { + let vector: [Float] + let model: String + let dimension: Int + } + + enum ClientError: LocalizedError { + case notConfigured + case invalidSettings + case invalidResponse + case dimensionMismatch(expected: Int, got: Int) + + var errorDescription: String? { + switch self { + case .notConfigured: + return "Hybrid embeddings are not configured. Set embedding_provider in Settings and enable OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED." + case .invalidSettings: + return "embedding_provider settings are invalid." + case .invalidResponse: + return "Embedding provider returned an unexpected response." + case .dimensionMismatch(let expected, let got): + return "Embedding dimension mismatch (expected \(expected), got \(got))." + } + } + } + + static func isEnabled() -> Bool { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return false + } + return DesktopBackendEnvironment.isCapability(.directEmbeddings, availableIn: .localDaemon) + } + + static func loadProviderConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? 
{ + guard let raw = settings.first(where: { $0.key == "embedding_provider" })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return nil + } + let kind = (json["kind"] as? String) ?? "" + guard kind == "openai_compatible" || kind == "openai" else { + return nil + } + guard let baseURL = json["base_url"] as? String, !baseURL.isEmpty else { + return nil + } + let model = (json["model"] as? String) ?? "text-embedding-3-small" + let apiKey = + (json["api_key"] as? String) ?? (json["key"] as? String) ?? "" + return ProviderConfig(baseURL: baseURL, model: model, apiKey: apiKey) + } + + static func embed(text: String, settings: [LocalDaemonSetting]) async throws -> EmbeddingResult { + guard let config = loadProviderConfig(from: settings) else { + throw ClientError.notConfigured + } + return try await embedOpenAICompatible(text: text, config: config) + } + + static func embedFromDaemonSettings(text: String) async throws -> EmbeddingResult { + let settings = try await APIClient.shared.getSelectedBackendSettings() + return try await embed(text: text, settings: settings) + } + + private static func embedOpenAICompatible(text: String, config: ProviderConfig) async throws + -> EmbeddingResult + { + let base = config.baseURL.hasSuffix("/") ? String(config.baseURL.dropLast()) : config.baseURL + guard let url = URL(string: "\(base)/embeddings") else { + throw ClientError.invalidSettings + } + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !config.apiKey.isEmpty { + request.setValue("Bearer \(config.apiKey)", forHTTPHeaderField: "Authorization") + } + request.timeoutInterval = 60 + request.httpBody = try JSONSerialization.data(withJSONObject: [ + "model": config.model, + "input": text, + ]) + + let (data, response) = try await URLSession.shared.data(for: request) + guard let http = response as? 
HTTPURLResponse, (200..<300).contains(http.statusCode) else { + throw ClientError.invalidResponse + } + guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], + let items = json["data"] as? [[String: Any]], + let first = items.first, + let embedding = first["embedding"] as? [Double] + else { + throw ClientError.invalidResponse + } + let floats = embedding.map { Float($0) } + let normalized = normalize(floats) + return EmbeddingResult( + vector: normalized, + model: config.model, + dimension: normalized.count + ) + } + + private static func normalize(_ vector: [Float]) -> [Float] { + var sum: Float = 0 + for value in vector { + sum += value * value + } + let magnitude = sqrt(sum) + guard magnitude > 0 else { return vector } + return vector.map { $0 / magnitude } + } + + static func isCompatibleEmbedding( + storedModel: String?, + storedDim: Int?, + activeModel: String, + activeDim: Int + ) -> Bool { + guard let storedModel, let storedDim else { + // Legacy rows without metadata: only match legacy Gemini size when using cloud defaults. + return activeModel == legacyGeminiModelId && activeDim == legacyGeminiDimension + } + return storedModel == activeModel && storedDim == activeDim + } +} diff --git a/desktop/Desktop/Sources/HybridLLMClient.swift b/desktop/Desktop/Sources/HybridLLMClient.swift new file mode 100644 index 00000000000..e741bcc7593 --- /dev/null +++ b/desktop/Desktop/Sources/HybridLLMClient.swift @@ -0,0 +1,490 @@ +import Foundation +import Vision + +// MARK: - Settings cache + +/// Short-TTL cache for local daemon hybrid settings (avoids hitting /v1/settings on every capture frame). +actor HybridDaemonSettingsCache { + static let shared = HybridDaemonSettingsCache() + + private var cached: [LocalDaemonSetting]? + private var fetchedAt: Date? 
+ private let ttlSeconds: TimeInterval = 45 + + func settings() async throws -> [LocalDaemonSetting] { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return [] + } + if let cached, let fetchedAt, Date().timeIntervalSince(fetchedAt) < ttlSeconds { + return cached + } + let fresh = try await APIClient.shared.getSelectedBackendSettings() + cached = fresh + fetchedAt = Date() + return fresh + } +} + +// MARK: - Hybrid LLM (OpenAI-compatible chat completions) + +/// Direct OpenAI-compatible `/v1/chat/completions` for hybrid (local daemon) mode. +enum HybridLLMClient { + + struct ProviderConfig: Equatable { + let baseURL: String + let model: String + let apiKey: String + } + + enum ClientError: LocalizedError { + case notConfigured + case invalidSettings + case invalidResponse + case httpFailure(status: Int, body: String) + + var errorDescription: String? { + switch self { + case .notConfigured: + return "Hybrid AI is not configured. Set ai_provider or chat_provider in Settings, or add a BYOK OpenAI key." + case .invalidSettings: + return "Hybrid provider settings are invalid." + case .invalidResponse: + return "Hybrid AI provider returned an unexpected response." + case .httpFailure(let status, _): + return "Hybrid AI request failed (HTTP \(status))." + } + } + } + + // MARK: Provider loading + + /// Vision / multimodal routing — optional separate provider (see ``HybridVisionProvider``). + static func loadVisionProviderConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? { + guard HybridVisionProvider.isConfigured(settings: settings) else { + return nil + } + return loadOpenAICompatibleProvider(forKeys: ["vision_provider"], settings: settings) + } + + /// Primary chat routing for assistants: prefers chat_provider, then legacy ai_provider / provider. + static func resolveEffectiveChatConfig(settings: [LocalDaemonSetting]) -> ProviderConfig? 
{ + if let c = loadOpenAICompatibleProvider(forKeys: ["chat_provider"], settings: settings) { + return c + } + return loadOpenAICompatibleProvider(forKeys: ["ai_provider", "provider"], settings: settings) + ?? byokOpenAIConfig() + } + + /// BYOK OpenAI → vendor endpoint (desktop hybrid escape hatch when daemon JSON is unset). + private static func byokOpenAIConfig() -> ProviderConfig? { + guard let key = APIKeyService.byokKey(.openai), !key.isEmpty else { + return nil + } + let model = + ProcessInfo.processInfo.environment["OMI_HYBRID_BYOK_OPENAI_MODEL"].flatMap { + $0.trimmingCharacters(in: .whitespacesAndNewlines) + }.flatMap { $0.isEmpty ? nil : $0 } ?? "gpt-4o-mini" + return ProviderConfig(baseURL: "https://api.openai.com/v1", model: model, apiKey: key) + } + + private static func loadOpenAICompatibleProvider( + forKeys keys: [String], + settings: [LocalDaemonSetting] + ) -> ProviderConfig? { + guard let raw = settings.first(where: { keys.contains($0.key) })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return nil + } + return parseOpenAICompatible(json: json) + } + + private static func parseOpenAICompatible(json: [String: Any]) -> ProviderConfig? { + let kind = (json["kind"] as? String)?.lowercased() ?? "" + guard kind == "openai_compatible" || kind == "openai" else { + return nil + } + guard let baseURL = json["base_url"] as? String, !baseURL.isEmpty else { + return nil + } + let model = (json["model"] as? String) ?? "gpt-4o-mini" + let apiKey = + (json["api_key"] as? String) ?? (json["key"] as? String) ?? "" + return ProviderConfig(baseURL: baseURL, model: model, apiKey: apiKey) + } + + // MARK: HTTP helpers + + private static func completionsURL(config: ProviderConfig) throws -> URL { + let trimmed = config.baseURL.trimmingCharacters(in: .whitespacesAndNewlines) + let base = trimmed.hasSuffix("/") ? 
String(trimmed.dropLast()) : trimmed + guard let url = URL(string: "\(base)/chat/completions") else { + throw ClientError.invalidSettings + } + return url + } + + private static func postJSON(url: URL, body: [String: Any], apiKey: String, timeout: TimeInterval) async throws + -> [String: Any] + { + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + if !apiKey.isEmpty { + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + } + request.timeoutInterval = timeout + request.httpBody = try JSONSerialization.data(withJSONObject: body) + + let (data, response) = try await URLSession.shared.data(for: request) + guard let http = response as? HTTPURLResponse else { + throw ClientError.invalidResponse + } + guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { + throw ClientError.invalidResponse + } + guard (200..<300).contains(http.statusCode) else { + let body = String(data: data.prefix(512), encoding: .utf8) ?? 
"" + throw ClientError.httpFailure(status: http.statusCode, body: body) + } + return json + } + + // MARK: Chat (no tools) + + static func chatCompletionText( + config: ProviderConfig, + systemPrompt: String, + userText: String, + jsonMode: Bool, + timeout: TimeInterval = 300 + ) async throws -> String { + let content: [[String: Any]] = [ + ["type": "text", "text": userText] + ] + let messages: [[String: Any]] = [ + ["role": "system", "content": systemPrompt], + ["role": "user", "content": content], + ] + return try await chatCompletionRaw( + config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, timeout: timeout) + } + + static func chatCompletionMultimodalJPEG( + config: ProviderConfig, + systemPrompt: String, + userText: String, + jpegData: Data, + jsonMode: Bool, + timeout: TimeInterval = 300 + ) async throws -> String { + let b64 = jpegData.base64EncodedString() + let dataUrl = "data:image/jpeg;base64,\(b64)" + let content: [[String: Any]] = [ + ["type": "text", "text": userText], + ["type": "image_url", "image_url": ["url": dataUrl]], + ] + let messages: [[String: Any]] = [ + ["role": "system", "content": systemPrompt], + ["role": "user", "content": content], + ] + return try await chatCompletionRaw( + config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, timeout: timeout) + } + + private static func chatCompletionRaw( + config: ProviderConfig, + messages: [[String: Any]], + jsonMode: Bool, + tools: [[String: Any]]?, + toolChoice: Any?, + timeout: TimeInterval + ) async throws -> String { + var body: [String: Any] = [ + "model": config.model, + "messages": messages, + "temperature": 0.4, + ] + if jsonMode { + body["response_format"] = ["type": "json_object"] + } + if let tools { + body["tools"] = tools + } + if let toolChoice { + body["tool_choice"] = toolChoice + } + + let json = try await postJSON(url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) + 
return try extractAssistantText(from: json) + } + + private static func extractAssistantText(from json: [String: Any]) throws -> String { + guard let choices = json["choices"] as? [[String: Any]], + let first = choices.first, + let message = first["message"] as? [String: Any] + else { + throw ClientError.invalidResponse + } + if let toolCalls = message["tool_calls"] as? [[String: Any]], !toolCalls.isEmpty { + throw ClientError.invalidResponse + } + if let str = message["content"] as? String { + return str + } + if let parts = message["content"] as? [[String: Any]] { + // Some providers use array-of-parts content + let texts = parts.compactMap { $0["text"] as? String } + return texts.joined(separator: "\n") + } + throw ClientError.invalidResponse + } + + // MARK: Tool loop (GeminiImageToolRequest → OpenAI chat) + + /// One round of a tool-calling dialog (caller appends turns between rounds). + static func performGeminiCompatibleToolRound( + config: ProviderConfig, + systemPrompt: String, + contents: [GeminiImageToolRequest.Content], + tools: [GeminiTool], + forceToolCall: Bool, + allowVisionInlineJPEG: Bool, + timeout: TimeInterval = 300 + ) async throws -> ToolChatResult { + let messages = try openAIMessages(from: contents, allowVisionInlineJPEG: allowVisionInlineJPEG) + let openAITools = openAITools(from: tools) + + var toolChoice: Any = "auto" + if forceToolCall { + toolChoice = "required" + } + + var body: [String: Any] = [ + "model": config.model, + "messages": [["role": "system", "content": systemPrompt]] + messages, + "tools": openAITools, + "tool_choice": toolChoice, + "temperature": 0.4, + ] + + let json = try await postJSON(url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) + return try parseToolChatResult(from: json) + } + + private static func parseToolChatResult(from json: [String: Any]) throws -> ToolChatResult { + guard let choices = json["choices"] as? 
[[String: Any]], + let first = choices.first, + let message = first["message"] as? [String: Any] + else { + throw ClientError.invalidResponse + } + + var toolCalls: [ToolCall] = [] + if let rawCalls = message["tool_calls"] as? [[String: Any]] { + for tc in rawCalls { + guard let fn = tc["function"] as? [String: Any], + let name = fn["name"] as? String + else { + continue + } + let argsStr = fn["arguments"] as? String ?? "{}" + let argsAny = + (try? JSONSerialization.jsonObject(with: Data(argsStr.utf8))) as? [String: Any] ?? [:] + toolCalls.append( + ToolCall(name: name, arguments: argsAny, thoughtSignature: nil)) + } + } + + var textResponse = "" + if let content = message["content"] as? String { + textResponse = content + } else if let parts = message["content"] as? [[String: Any]] { + textResponse = parts.compactMap { $0["text"] as? String }.joined(separator: "\n") + } + + return ToolChatResult( + text: textResponse, + toolCalls: toolCalls, + requiresToolExecution: !toolCalls.isEmpty + ) + } + + // MARK: OpenAI message building + + private static func openAIMessages( + from contents: [GeminiImageToolRequest.Content], + allowVisionInlineJPEG: Bool + ) throws -> [[String: Any]] { + var out: [[String: Any]] = [] + /// Pair assistant `tool_calls[].id` with subsequent tool result rows (Gemini omits ids). + var pendingToolCallIds: [String] = [] + + for content in contents { + let role = content.role + if role == "user" { + var userParts: [[String: Any]] = [] + var toolResults: [[String: Any]] = [] + + for part in content.parts { + if let fr = part.functionResponse { + let toolCallId = + pendingToolCallIds.isEmpty ? 
"call_hybrid_fallback_\(fr.name)" : pendingToolCallIds.removeFirst() + toolResults.append([ + "role": "tool", + "tool_call_id": toolCallId, + "content": fr.response.result, + ]) + continue + } + if let t = part.text, !t.isEmpty { + userParts.append(["type": "text", "text": t]) + } + if let img = part.inlineData, allowVisionInlineJPEG { + let mime = img.mimeType + let dataUrl = "data:\(mime);base64,\(img.data)" + userParts.append(["type": "image_url", "image_url": ["url": dataUrl]]) + } + } + + if !userParts.isEmpty { + out.append(["role": "user", "content": userParts]) + } + for tr in toolResults { + out.append(tr) + } + } else if role == "model" { + var textAccum = "" + var oaToolCalls: [[String: Any]] = [] + + for part in content.parts { + if let t = part.text, !t.isEmpty { + textAccum += t + } + if let fc = part.functionCall { + let id = "call_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + pendingToolCallIds.append(id) + let argData = try JSONSerialization.data(withJSONObject: fc.args, options: []) + let argStr = String(data: argData, encoding: .utf8) ?? "{}" + oaToolCalls.append([ + "id": id, + "type": "function", + "function": ["name": fc.name, "arguments": argStr], + ]) + } + } + + var msg: [String: Any] = ["role": "assistant"] + if !textAccum.isEmpty { + msg["content"] = textAccum + } else if oaToolCalls.isEmpty { + msg["content"] = "" + } + if !oaToolCalls.isEmpty { + msg["tool_calls"] = oaToolCalls + } + out.append(msg) + } else { + // Unknown role — skip + } + } + return out + } + + private static func openAITools(from tools: [GeminiTool]) -> [[String: Any]] { + tools.flatMap(\.functionDeclarations).compactMap { fd in + guard let schema = jsonSchema(from: fd.parameters) else { return nil } + return [ + "type": "function", + "function": [ + "name": fd.name, + "description": fd.description, + "parameters": schema, + ], + ] + } + } + + private static func jsonSchema(from params: GeminiTool.FunctionDeclaration.Parameters) -> [String: Any]? 
{ + var properties: [String: Any] = [:] + for (name, prop) in params.properties { + if let nested = propJSONSchema(prop) { + properties[name] = nested + } + } + var schema: [String: Any] = [ + "type": params.type, + "properties": properties, + "required": params.required, + ] + return schema + } + + private static func propJSONSchema(_ prop: GeminiTool.FunctionDeclaration.Parameters.Property) -> [String: Any]? { + if let nested = prop.nestedProperties, let req = prop.nestedRequired { + var childProps: [String: Any] = [:] + for (k, v) in nested { + if let sch = propJSONSchema(v) { + childProps[k] = sch + } + } + var obj: [String: Any] = [ + "type": "object", + "properties": childProps, + "required": req, + ] + if let d = prop.description, !d.isEmpty { + obj["description"] = d + } + return obj + } + + var out: [String: Any] = [ + "type": prop.type + ] + if let d = prop.description, !d.isEmpty { + out["description"] = d + } + if let `enum` = prop.`enum` { + out["enum"] = `enum` + } + if let items = prop.items { + out["items"] = ["type": items.type] + } + return out + } + + // MARK: OCR (on-device) for hybrid without vision_provider + + enum ScreenOCR { + static func recognizeTextFromJPEG(_ jpegData: Data) async throws -> String { + try await Task.detached(priority: .userInitiated) { + try await Self.recognizeTextFromJPEGSync(jpegData) + }.value + } + + private static func recognizeTextFromJPEGSync(_ jpegData: Data) async throws -> String { + try await withCheckedThrowingContinuation { continuation in + let request = VNRecognizeTextRequest { request, error in + if let error { + continuation.resume(throwing: error) + return + } + let observations = (request.results as? [VNRecognizedTextObservation]) ?? 
[] + let text = observations.compactMap { $0.topCandidates(1).first?.string }.joined(separator: "\n") + continuation.resume(returning: text) + } + request.recognitionLevel = .accurate + request.usesLanguageCorrection = true + + let handler = VNImageRequestHandler(data: jpegData, options: [:]) + do { + try handler.perform([request]) + } catch { + continuation.resume(throwing: error) + } + } + } + } +} diff --git a/desktop/Desktop/Sources/HybridProviderBootstrap.swift b/desktop/Desktop/Sources/HybridProviderBootstrap.swift new file mode 100644 index 00000000000..72d4b443a6c --- /dev/null +++ b/desktop/Desktop/Sources/HybridProviderBootstrap.swift @@ -0,0 +1,48 @@ +import Foundation + +/// Idempotent default hybrid provider keys for local daemon dev (Ollama loopback). +enum HybridProviderBootstrap { + + static func defaultProviderObject() -> [String: LocalDaemonSettingUpdateValue] { + var object: [String: LocalDaemonSettingUpdateValue] = [ + "kind": "openai_compatible", + "base_url": .string(HybridProviderReadiness.defaultBaseURL()), + "model": .string(HybridProviderReadiness.defaultModel()), + ] + return object + } + + /// Writes `ai_provider` and `chat_provider` when absent. Does not overwrite existing keys. 
+ @MainActor + static func ensureDefaultsIfNeeded() async { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return + } + + do { + let settings = try await APIClient.shared.getSelectedBackendSettings() + var updates: [String: LocalDaemonSettingUpdateValue] = [:] + let provider = defaultProviderObject() + + if !HybridProviderReadiness.hasOpenAICompatibleProvider( + in: settings, keys: ["ai_provider", "provider"]) + { + updates["ai_provider"] = .object(provider) + } + if !HybridProviderReadiness.hasOpenAICompatibleProvider( + in: settings, keys: ["chat_provider"]) + { + updates["chat_provider"] = .object(provider) + } + + guard !updates.isEmpty else { return } + + _ = try await APIClient.shared.updateSelectedBackendSettings(updates) + log( + "HybridProviderBootstrap: seeded \(updates.keys.sorted().joined(separator: ", ")) at \(HybridProviderReadiness.defaultBaseURL())" + ) + } catch { + logError("HybridProviderBootstrap: failed to seed defaults", error: error) + } + } +} diff --git a/desktop/Desktop/Sources/HybridProviderReadiness.swift b/desktop/Desktop/Sources/HybridProviderReadiness.swift new file mode 100644 index 00000000000..a19cc8fa095 --- /dev/null +++ b/desktop/Desktop/Sources/HybridProviderReadiness.swift @@ -0,0 +1,120 @@ +import Foundation + +/// Checklist rows for hybrid local daemon provider setup (Plan & Usage, About). 
+enum HybridProviderReadiness { + + enum RowStatus: Equatable { + case configured + case optionalFallback + case missing + case capabilityOff + } + + struct Row: Identifiable, Equatable { + let id: String + let label: String + let status: RowStatus + let detail: String + } + + static func defaultBaseURL() -> String { + if let raw = getenv("OMI_HYBRID_DEFAULT_CHAT_BASE_URL"), + let value = String(validatingUTF8: raw), + !value.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + { + return value.trimmingCharacters(in: .whitespacesAndNewlines) + } + return "http://127.0.0.1:11434/v1" + } + + static func defaultModel() -> String { + if let raw = getenv("OMI_HYBRID_DEFAULT_CHAT_MODEL"), + let value = String(validatingUTF8: raw), + !value.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + { + return value.trimmingCharacters(in: .whitespacesAndNewlines) + } + return "llama3.2" + } + + static func rows(from settings: [LocalDaemonSetting]) -> [Row] { + var result: [Row] = [] + + let aiConfigured = hasOpenAICompatibleProvider( + in: settings, keys: ["ai_provider", "provider"]) + result.append( + Row( + id: "ai_provider", + label: "Processing (ai_provider)", + status: aiConfigured ? .configured : .optionalFallback, + detail: aiConfigured + ? "OpenAI-compatible provider configured" + : "Optional — deterministic fallback when unset" + )) + + let sttAvailable = DesktopBackendEnvironment.isCapability(.directSTT, availableIn: .localDaemon) + result.append( + Row( + id: "stt", + label: "Live transcription", + status: sttAvailable ? .configured : .capabilityOff, + detail: sttAvailable + ? "On-device Apple Speech (no daemon key)" + : (DesktopBackendEnvironment.unavailableReason(for: .directSTT, in: .localDaemon) + ?? 
"Direct STT unavailable") + )) + + let chatResolvable = HybridChatClient.resolveEffectiveChatConfig(from: settings) != nil + let chatCap = DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon) + result.append( + Row( + id: "chat_provider", + label: "Chat (chat_provider)", + status: chatResolvable && chatCap + ? .configured + : (chatCap ? .missing : .capabilityOff), + detail: chatResolvable + ? "Direct chat endpoint configured" + : (chatCap + ? "Set chat_provider or ai_provider in hybrid settings" + : (DesktopBackendEnvironment.unavailableReason(for: .directChat, in: .localDaemon) + ?? "Direct chat disabled")) + )) + + let embedConfigured = HybridEmbeddingClient.loadProviderConfig(from: settings) != nil + let embedCap = DesktopBackendEnvironment.isCapability( + .directEmbeddings, availableIn: .localDaemon) + result.append( + Row( + id: "embedding_provider", + label: "Embeddings (embedding_provider)", + status: embedConfigured && embedCap + ? .configured + : (embedCap ? .missing : .capabilityOff), + detail: embedConfigured + ? "Embedding provider configured" + : (embedCap + ? "Optional for Rewind semantic search" + : (DesktopBackendEnvironment.unavailableReason( + for: .directEmbeddings, in: .localDaemon) ?? "Direct embeddings disabled")) + )) + + return result + } + + static func hasOpenAICompatibleProvider( + in settings: [LocalDaemonSetting], + keys: [String] + ) -> Bool { + guard let raw = settings.first(where: { keys.contains($0.key) })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return false + } + let kind = ((json["kind"] as? String) ?? "").lowercased() + guard kind == "openai_compatible" || kind == "openai" else { return false } + guard let base = json["base_url"] as? 
String, !base.isEmpty else { return false } + return true + } +} diff --git a/desktop/Desktop/Sources/HybridVisionProvider.swift b/desktop/Desktop/Sources/HybridVisionProvider.swift new file mode 100644 index 00000000000..1e79e40be69 --- /dev/null +++ b/desktop/Desktop/Sources/HybridVisionProvider.swift @@ -0,0 +1,22 @@ +import Foundation + +/// Vision provider configuration for hybrid local-daemon mode (screenshot / multimodal APIs). +enum HybridVisionProvider { + /// Whether `vision_provider` is set to a supported OpenAI-compatible provider entry (`kind` is compared case-insensitively). + static func isConfigured(settings: [LocalDaemonSetting]) -> Bool { + guard let raw = settings.first(where: { $0.key == "vision_provider" })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { + return false + } + let kind = ((json["kind"] as? String) ?? "").lowercased() + guard kind == "openai_compatible" || kind == "openai" else { + return false + } + guard let baseURL = json["base_url"] as? String, !baseURL.isEmpty else { + return false + } + return true + } +} diff --git a/desktop/Desktop/Sources/LocalSpeechTranscriptionAdapter.swift b/desktop/Desktop/Sources/LocalSpeechTranscriptionAdapter.swift new file mode 100644 index 00000000000..5e323e779a0 --- /dev/null +++ b/desktop/Desktop/Sources/LocalSpeechTranscriptionAdapter.swift @@ -0,0 +1,252 @@ +import AVFoundation +import Foundation +import Speech + +/// Live transcription using Apple's Speech framework (buffer-based recognition). +/// Emits consolidated segments compatible with `TranscriptionService.BackendSegment`. +final class LocalSpeechTranscriptionAdapter: @unchecked Sendable { + + /// Stable pseudo backend ID so SQLite upserts update the rolling transcript row. 
+ static let pseudoBackendSegmentId = "apple-hybrid-live" + + private let languageCode: String + private let audioSerialQueue = DispatchQueue(label: "omi.hybrid.localspeech.audio") + private var recognitionRequest: SFSpeechAudioBufferRecognitionRequest? + private var recognitionTask: SFSpeechRecognitionTask? + + /// Samples captured before authorization / task creation land here, then replay. + private var pendingPCM = Data() + + private var terminated = false + private let sessionWallClockBegin = CFAbsoluteTimeGetCurrent() + + init(languageCode: String) { + self.languageCode = languageCode + } + + /// Hybrid capability probe: whether Apple's Speech engine reports availability for the given assistant language code. + static func isRecognitionEngineAvailable(forAssistantLanguageCode code: String) -> Bool { + let primary = normalizedLocaleIdentifier(forAssistantLanguageCode: code) + if let r = SFSpeechRecognizer(locale: Locale(identifier: primary)), r.isAvailable { + return true + } + return SFSpeechRecognizer(locale: Locale(identifier: "en-US"))?.isAvailable == true + } + + /// Uses the user's preferred macOS language list (fallback `en-US`). + static func isRecognitionEngineAvailableForPreferredSystemLanguages() -> Bool { + let code = Locale.preferredLanguages.first ?? "en-US" + return isRecognitionEngineAvailable(forAssistantLanguageCode: code) + } + + /// Normalize assistant language tokens (matches `effectiveTranscriptionLanguage`) to `Locale` identifiers Speech accepts. + static func normalizedLocaleIdentifier(forAssistantLanguageCode code: String) -> String { + let trimmed = code.trimmingCharacters(in: .whitespacesAndNewlines) + let lower = trimmed.lowercased() + switch lower { + case "", "multi", "auto": + return Locale.preferredLanguages.first ?? 
"en-US" + case "zh", "cn": + return "zh-CN" + case _ where trimmed.contains("_"): + let parts = trimmed.split(separator: "_", maxSplits: 1).map(String.init) + guard parts.count == 2 else { return lower } + return bcp47Locale(language: parts[0], region: parts[1]) + case _ where trimmed.contains("-"): + let parts = trimmed.split(separator: "-", maxSplits: 1).map(String.init) + guard parts.count == 2 else { return lower } + return bcp47Locale(language: parts[0], region: parts[1]) + default: + return "\(lower)-\(Locale.current.region?.identifier ?? "US")" + } + } + + private static func bcp47Locale(language: String, region: String) -> String { + let lang = language.lowercased() + let regionPart: String + if region.count == 2, region == region.lowercased() { + regionPart = region.lowercased() + } else if region.count == 2 { + regionPart = region.uppercased() + } else { + regionPart = region + } + return "\(lang)-\(regionPart)" + } + + /// Start Speech authorization then begin buffer recognition. `onReady` runs when PCM may be appended. + func start( + onSegments: @escaping ([TranscriptionService.BackendSegment]) -> Void, + onError: ((Error) -> Void)?, + onReady: @escaping () -> Void + ) { + terminated = false + pendingPCM = Data() + + SFSpeechRecognizer.requestAuthorization { [weak self] status in + guard let self else { return } + switch status { + case .authorized: + self.audioSerialQueue.async { + self.beginRecognitionLocked(onSegments: onSegments, onError: onError, onReady: onReady) + } + case .denied, .restricted, .notDetermined: + onError?( + TranscriptionService.TranscriptionError.webSocketError( + "Speech recognition authorization denied")) + @unknown default: + onError?( + TranscriptionService.TranscriptionError.webSocketError( + "Speech recognition authorization unavailable")) + } + } + } + + /// Append microphone linear16 PCM (16 kHz, mono — same codec as `/v4/listen` streaming input). 
+ func appendLinear16PCMSamples(_ pcm: Data) { + guard !pcm.isEmpty else { return } + audioSerialQueue.async { [weak self] in + guard let self, !self.terminated else { return } + guard let rr = recognitionRequest else { + self.pendingPCM.append(pcm) + return + } + guard let buf = Self.makePCMBuffer(fromLinear16PCM: pcm) else { return } + rr.append(buf) + } + } + + /// Tell Speech there will be no more audio (allows final results — used on PTT `finishStream` and on `stop`). + func endAudioInput() { + audioSerialQueue.async { [weak self] in + guard let self, !self.terminated else { return } + self.flushPendingPCMSamplesLocked() + self.recognitionRequest?.endAudio() + } + } + + /// Tear down streaming recognition immediately. + func cancel() { + audioSerialQueue.async { [weak self] in + guard let self else { return } + self.terminated = true + self.recognitionTask?.cancel() + self.recognitionTask = nil + self.recognitionRequest = nil + self.pendingPCM = Data() + } + } + + // MARK: - Private + + private func beginRecognitionLocked( + onSegments: @escaping ([TranscriptionService.BackendSegment]) -> Void, + onError: ((Error) -> Void)?, + onReady: @escaping () -> Void + ) { + terminated = false + let localeIdentifier = Self.normalizedLocaleIdentifier(forAssistantLanguageCode: languageCode) + let locale = Locale(identifier: localeIdentifier) + guard let recognizer = SFSpeechRecognizer(locale: locale), recognizer.isAvailable else { + DispatchQueue.main.async { + onError?( + TranscriptionService.TranscriptionError.webSocketError( + "Speech recognition unavailable for locale \(localeIdentifier)" + )) + } + return + } + + recognitionRequest = SFSpeechAudioBufferRecognitionRequest() + + guard let request = recognitionRequest else { return } + + request.shouldReportPartialResults = true + request.taskHint = .dictation + + recognitionTask = recognizer.recognitionTask(with: request) { [weak self] result, error in + guard let self else { return } + + self.audioSerialQueue.async { + 
guard !self.terminated else { return } + if let error { + DispatchQueue.main.async { + guard !self.terminated else { return } + onError?( + TranscriptionService.TranscriptionError.webSocketError( + error.localizedDescription)) + } + return + } + + guard let transcription = result?.bestTranscription else { return } + let trimmed = transcription.formattedString.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { return } + + let elapsedSeconds = CFAbsoluteTimeGetCurrent() - self.sessionWallClockBegin + let segments = Self.makeHybridRollingSegments(text: trimmed, elapsedSeconds: elapsedSeconds) + + DispatchQueue.main.async { + onSegments(segments) + } + } + } + + flushPendingPCMSamplesLocked() + DispatchQueue.main.async { [weak self] in + guard let self, !self.terminated else { return } + onReady() + } + } + + private func flushPendingPCMSamplesLocked() { + guard let rr = recognitionRequest, !pendingPCM.isEmpty else { return } + let chunk = pendingPCM + pendingPCM = Data() + guard let buf = Self.makePCMBuffer(fromLinear16PCM: chunk) else { return } + rr.append(buf) + } + + private static func makePCMBuffer(fromLinear16PCM pcm: Data) -> AVAudioPCMBuffer? { + let frameCount = pcm.count / MemoryLayout<Int16>.size // one frame per 16-bit mono sample + guard frameCount > 0 else { return nil } + + guard + let format = AVAudioFormat( + commonFormat: .pcmFormatInt16, + sampleRate: 16_000, + channels: 1, + interleaved: false + ), + let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: AVAudioFrameCount(frameCount)) + else { return nil } + + buffer.frameLength = AVAudioFrameCount(frameCount) + pcm.withUnsafeBytes { raw in + guard let base = raw.baseAddress?.assumingMemoryBound(to: Int16.self), + let channel = buffer.int16ChannelData?[0] + else { return } + channel.update(from: base, count: frameCount) + } + return buffer + } + + /// Single rolling segment with a stable pseudo backend id — `TranscriptionStorage.upsertSegment` updates one row per session. 
+ static func makeHybridRollingSegments(text: String, elapsedSeconds: TimeInterval) + -> [TranscriptionService.BackendSegment] + { + let end = max(0, elapsedSeconds) + let seg = TranscriptionService.BackendSegment( + id: pseudoBackendSegmentId, + text: text, + speaker: "SPEAKER_00", + speaker_id: 0, + is_user: true, + person_id: nil, + start: 0, + end: end, + translations: nil + ) + return [seg] + } +} diff --git a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift index b060df86d46..fcbae586adc 100644 --- a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift +++ b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift @@ -74,7 +74,9 @@ struct DesktopHomeView: View { } } else if !appState.hasCompletedOnboarding { // State 2: Signed in but onboarding not complete - if shouldSkipOnboarding() { + if shouldSkipOnboarding() + || DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + { Color.clear.onAppear { log("DesktopHomeView: --skip-onboarding flag detected, skipping onboarding") appState.hasCompletedOnboarding = true @@ -363,6 +365,7 @@ struct DesktopHomeView: View { .preferredColorScheme(.dark) .tint(OmiColors.purplePrimary) .onAppear { + AuthService.shared.establishLocalGuestSessionIfNeeded() log( "DesktopHomeView: View appeared - isSignedIn=\(authState.isSignedIn), hasCompletedOnboarding=\(appState.hasCompletedOnboarding)" ) diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift index 245660d5d7b..814c9ee7f83 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ConversationsPage.swift @@ -308,16 +308,14 @@ struct ConversationsPage: View { .padding(.vertical, 12) // Folder tabs strip - if !isLocalDaemonMode { - FolderTabsStrip( - appState: appState, - onCreateFolder: { showCreateFolderSheet = true }, - onEditFolder: { 
folder in editingFolder = folder }, - onDeleteFolder: { folder in deletingFolder = folder } - ) - .padding(.horizontal, 24) - .padding(.bottom, 12) - } + FolderTabsStrip( + appState: appState, + onCreateFolder: { showCreateFolderSheet = true }, + onEditFolder: { folder in editingFolder = folder }, + onDeleteFolder: { folder in deletingFolder = folder } + ) + .padding(.horizontal, 24) + .padding(.bottom, 12) // List - show search results or regular conversations if !searchQuery.isEmpty { @@ -530,6 +528,37 @@ struct ConversationsPage: View { .buttonStyle(.plain) .disabled(isFilteringStarred) + // Multi-select to merge conversations (local daemon + cloud) + Button(action: { + withAnimation(.easeInOut(duration: 0.2)) { + if isMultiSelectMode { + isMultiSelectMode = false + selectedConversationIds.removeAll() + } else { + isMultiSelectMode = true + } + } + }) { + HStack(spacing: 6) { + Image(systemName: isMultiSelectMode ? "xmark.circle" : "checkmark.circle") + .scaledFont(size: 12) + Text(isMultiSelectMode ? "Cancel" : "Select") + .scaledFont(size: 12, weight: .medium) + } + .foregroundColor(isMultiSelectMode ? OmiColors.purplePrimary : OmiColors.textSecondary) + .padding(.horizontal, 14) + .padding(.vertical, 9) + .omiControlSurface( + fill: isMultiSelectMode + ? OmiColors.purplePrimary.opacity(0.12) + : OmiColors.backgroundSecondary, + radius: 16, + stroke: isMultiSelectMode + ? 
OmiColors.purplePrimary.opacity(0.28) + : OmiColors.border.opacity(0.14)) + } + .buttonStyle(.plain) + // Date filter button Button(action: { showDatePicker.toggle() diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 4a1a2f137df..17cadad10ac 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -228,6 +228,18 @@ struct SettingsContentView: View { @State private var backendSettings: [LocalDaemonSetting] = [] @State private var backendStatusError: String? @State private var isLoadingBackendStatus: Bool = false + @State private var hybridAiBaseURL: String = HybridProviderReadiness.defaultBaseURL() + @State private var hybridAiModel: String = HybridProviderReadiness.defaultModel() + @State private var hybridAiApiKey: String = "" + @State private var hybridChatBaseURL: String = HybridProviderReadiness.defaultBaseURL() + @State private var hybridChatModel: String = HybridProviderReadiness.defaultModel() + @State private var hybridChatApiKey: String = "" + @State private var hybridEmbedBaseURL: String = HybridProviderReadiness.defaultBaseURL() + @State private var hybridEmbedModel: String = HybridProviderReadiness.defaultModel() + @State private var hybridEmbedApiKey: String = "" + @State private var hybridProviderStatus: String? + @State private var isSavingHybridProvider: Bool = false + @State private var isTestingHybridProvider: Bool = false @State private var userSubscription: UserSubscriptionResponse? @State private var isLoadingSubscription: Bool = false @State private var subscriptionError: String? 
@@ -484,7 +496,11 @@ struct SettingsContentView: View { } loadBackendSettings() refreshSelectedBackendStatus() - loadSubscriptionInfo() + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + loadLocalHybridPlanUsage() + } else { + loadSubscriptionInfo() + } // Sync transcription state with appState isTranscribing = appState.isTranscribing // Sync floating bar state with persisted preference (not transient visibility) @@ -510,7 +526,11 @@ struct SettingsContentView: View { return } if newValue == .planUsage { - loadSubscriptionInfo() + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + loadLocalHybridPlanUsage() + } else { + loadSubscriptionInfo() + } } } .onReceive(NotificationCenter.default.publisher(for: .navigateToTaskSettings)) { _ in @@ -1740,6 +1760,188 @@ struct SettingsContentView: View { // MARK: - Plan and Usage Section private var planUsageSection: some View { + Group { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + localHybridPlanUsageSection + } else { + cloudPlanUsageSection + } + } + } + + private var localHybridPlanUsageSection: some View { + VStack(spacing: 20) { + settingsCard(settingId: "planusage.local") { + VStack(alignment: .leading, spacing: 14) { + HStack(spacing: 16) { + Image(systemName: "desktopcomputer") + .scaledFont(size: 28) + .foregroundColor(OmiColors.purplePrimary) + + VStack(alignment: .leading, spacing: 4) { + Text("Local") + .scaledFont(size: 16, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + Text( + "Data and transcripts stay on this Mac via the local daemon. No Omi subscription or usage metering." 
+ ) + .scaledFont(size: 13) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + + if isLoadingBackendStatus { + ProgressView() + .controlSize(.small) + } else { + Button("Refresh") { + loadLocalHybridPlanUsage() + } + .buttonStyle(.bordered) + } + } + + if let health = backendHealth { + Divider() + .overlay(OmiColors.backgroundQuaternary) + Text("\(health.service) \(health.version) · \(health.dataDir ?? "local data")") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + } else if let err = backendStatusError { + Text(err) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + } + } + } + + settingsCard(settingId: "planusage.local.checklist") { + VStack(alignment: .leading, spacing: 12) { + Text("Hybrid provider setup") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + Text( + "Bring your own AI endpoints. Keys are stored in the local daemon SQLite database on this Mac and sent only to URLs you configure—not Omi cloud proxies." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .fixedSize(horizontal: false, vertical: true) + + ForEach(HybridProviderReadiness.rows(from: backendSettings)) { row in + HStack(alignment: .top, spacing: 10) { + Image( + systemName: row.status == .configured || row.status == .optionalFallback + ? "checkmark.circle.fill" : "circle" + ) + .foregroundColor( + row.status == .configured + ? OmiColors.success + : (row.status == .optionalFallback + ? 
OmiColors.textTertiary : OmiColors.warning)) + VStack(alignment: .leading, spacing: 2) { + Text(row.label) + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + Text(row.detail) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } + Spacer() + } + } + + Button(action: applyLocalHybridProviderDefaults) { + Text("Apply local defaults") + .scaledFont(size: 13, weight: .semibold) + } + .buttonStyle(.borderedProminent) + .disabled(isSavingHybridProvider) + } + } + + localHybridProvidersEditorCard + } + } + + private var localHybridProvidersEditorCard: some View { + settingsCard(settingId: "planusage.local.providers") { + VStack(alignment: .leading, spacing: 18) { + hybridProviderEditorBlock( + title: "Processing (ai_provider)", + baseURL: $hybridAiBaseURL, + model: $hybridAiModel, + apiKey: $hybridAiApiKey, + settingKey: "ai_provider" + ) + + Divider().background(OmiColors.backgroundQuaternary) + + hybridProviderEditorBlock( + title: "Chat (chat_provider)", + baseURL: $hybridChatBaseURL, + model: $hybridChatModel, + apiKey: $hybridChatApiKey, + settingKey: "chat_provider" + ) + + Divider().background(OmiColors.backgroundQuaternary) + + hybridProviderEditorBlock( + title: "Embeddings (embedding_provider)", + baseURL: $hybridEmbedBaseURL, + model: $hybridEmbedModel, + apiKey: $hybridEmbedApiKey, + settingKey: "embedding_provider" + ) + + if let hybridProviderStatus { + Text(hybridProviderStatus) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + } + } + } + } + + private func hybridProviderEditorBlock( + title: String, + baseURL: Binding, + model: Binding, + apiKey: Binding, + settingKey: String + ) -> some View { + VStack(alignment: .leading, spacing: 8) { + Text(title) + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + TextField("Base URL", text: baseURL) + .textFieldStyle(.roundedBorder) + TextField("Model", text: model) + 
.textFieldStyle(.roundedBorder) + SecureField("API key (optional on loopback)", text: apiKey) + .textFieldStyle(.roundedBorder) + HStack(spacing: 10) { + Button("Save") { + saveHybridProvider(key: settingKey, baseURL: baseURL.wrappedValue, model: model.wrappedValue, apiKey: apiKey.wrappedValue) + } + .buttonStyle(.bordered) + .disabled(isSavingHybridProvider) + Button("Test") { + testHybridProvider(key: settingKey, baseURL: baseURL.wrappedValue, model: model.wrappedValue, apiKey: apiKey.wrappedValue) + } + .buttonStyle(.bordered) + .disabled(isTestingHybridProvider) + } + } + } + + private var cloudPlanUsageSection: some View { VStack(spacing: 20) { settingsCard(settingId: "planusage.current") { VStack(alignment: .leading, spacing: 14) { @@ -5770,6 +5972,7 @@ struct SettingsContentView: View { private var aboutSection: some View { VStack(spacing: 20) { backendStatusCard + hybridProvidersCard settingsCard(settingId: "about.version") { VStack(spacing: 16) { @@ -6044,6 +6247,120 @@ struct SettingsContentView: View { } } + private var hybridProvidersCard: some View { + let target = DesktopBackendEnvironment.selectedBackendTarget + return Group { + if target.mode == .localDaemon { + settingsCard(settingId: "about.hybrid_providers") { + VStack(alignment: .leading, spacing: 14) { + Text("Hybrid providers") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + + Text( + "API keys are stored in the local daemon SQLite database on this Mac and sent only to the endpoints you configure—not to Omi cloud proxies." 
+ ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .fixedSize(horizontal: false, vertical: true) + + VStack(alignment: .leading, spacing: 8) { + Text("Processing (ai_provider)") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + TextField("Base URL", text: $hybridAiBaseURL) + .textFieldStyle(.roundedBorder) + TextField("Model", text: $hybridAiModel) + .textFieldStyle(.roundedBorder) + SecureField("API key (optional on loopback)", text: $hybridAiApiKey) + .textFieldStyle(.roundedBorder) + } + + HStack(spacing: 10) { + Button("Save") { + saveHybridAiProvider() + } + .buttonStyle(.borderedProminent) + .disabled(isSavingHybridProvider) + + Button("Test connection") { + testHybridAiProvider() + } + .buttonStyle(.bordered) + .disabled(isTestingHybridProvider) + + if isSavingHybridProvider || isTestingHybridProvider { + ProgressView().controlSize(.small) + } + } + + if let hybridProviderStatus { + Text(hybridProviderStatus) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + } + + Divider().background(OmiColors.backgroundQuaternary) + + Text("Capabilities") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + + ForEach(DesktopBackendEnvironment.capabilities(for: .localDaemon), id: \.capability) { + state in + HStack(alignment: .top, spacing: 8) { + Image(systemName: state.available ? "checkmark.circle.fill" : "xmark.circle") + .foregroundColor(state.available ? 
OmiColors.success : OmiColors.textTertiary) + VStack(alignment: .leading, spacing: 2) { + Text(state.capability.rawValue) + .scaledFont(size: 11, weight: .medium) + .foregroundColor(OmiColors.textSecondary) + if let reason = state.reason { + Text(reason) + .scaledFont(size: 10) + .foregroundColor(OmiColors.textTertiary) + } + } + Spacer() + } + } + } + } + .onAppear { + syncAllHybridProviderFieldsFromBackendSettings() + } + .onChange(of: backendSettings) { _, _ in + syncAllHybridProviderFieldsFromBackendSettings() + } + } + } + } + + private func syncHybridProviderFieldsFromBackendSettings() { + syncHybridProviderFields( + forKey: "ai_provider", alsoKeys: ["provider"], + intoBaseURL: &hybridAiBaseURL, model: &hybridAiModel, apiKey: &hybridAiApiKey) + } + + private func saveHybridAiProvider() { + saveHybridProvider( + key: "ai_provider", + baseURL: hybridAiBaseURL, + model: hybridAiModel, + apiKey: hybridAiApiKey + ) + } + + private func testHybridAiProvider() { + testHybridProvider( + key: "ai_provider", + baseURL: hybridAiBaseURL, + model: hybridAiModel, + apiKey: hybridAiApiKey + ) + } + private var localProcessingProviderStatus: String { let providerSetting = backendSettings.first { $0.key == "ai_provider" || $0.key == "provider" } guard let providerSetting else { @@ -7013,7 +7330,159 @@ struct SettingsContentView: View { } } + private func loadLocalHybridPlanUsage() { + refreshSelectedBackendStatus() + Task { + await HybridProviderBootstrap.ensureDefaultsIfNeeded() + await MainActor.run { + syncAllHybridProviderFieldsFromBackendSettings() + } + } + } + + private func syncAllHybridProviderFieldsFromBackendSettings() { + syncHybridProviderFields(forKey: "ai_provider", alsoKeys: ["provider"], intoBaseURL: &hybridAiBaseURL, model: &hybridAiModel, apiKey: &hybridAiApiKey) + syncHybridProviderFields(forKey: "chat_provider", alsoKeys: [], intoBaseURL: &hybridChatBaseURL, model: &hybridChatModel, apiKey: &hybridChatApiKey) + syncHybridProviderFields(forKey: 
"embedding_provider", alsoKeys: [], intoBaseURL: &hybridEmbedBaseURL, model: &hybridEmbedModel, apiKey: &hybridEmbedApiKey) + } + + private func syncHybridProviderFields( + forKey key: String, + alsoKeys: [String], + intoBaseURL baseURL: inout String, + model: inout String, + apiKey: inout String + ) { + let keys = [key] + alsoKeys + guard let raw = backendSettings.first(where: { keys.contains($0.key) })?.valueJson, + let data = raw.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { return } + if let base = json["base_url"] as? String, !base.isEmpty { + baseURL = base + } + if let m = json["model"] as? String { + model = m + } + if let k = json["api_key"] as? String ?? json["key"] as? String { + apiKey = k + } + } + + private func applyLocalHybridProviderDefaults() { + applyLocalHybridProviderDefaultsToUI() + guard !isSavingHybridProvider else { return } + isSavingHybridProvider = true + hybridProviderStatus = nil + let provider = HybridProviderBootstrap.defaultProviderObject() + Task { + do { + backendSettings = try await APIClient.shared.updateSelectedBackendSettings([ + "ai_provider": .object(provider), + "chat_provider": .object(provider), + ]) + await MainActor.run { + applyLocalHybridProviderDefaultsToUI() + syncAllHybridProviderFieldsFromBackendSettings() + hybridProviderStatus = "Applied local defaults (ai_provider + chat_provider)." + isSavingHybridProvider = false + } + do { + let test = try await APIClient.shared.testHybridProvider(key: "chat_provider") + await MainActor.run { + hybridProviderStatus = test.message + } + } catch { + await MainActor.run { + hybridProviderStatus = + "Saved defaults. 
Test connection failed: \(error.localizedDescription)" + } + } + } catch { + await MainActor.run { + hybridProviderStatus = error.localizedDescription + isSavingHybridProvider = false + } + } + } + } + + private func applyLocalHybridProviderDefaultsToUI() { + hybridAiBaseURL = HybridProviderReadiness.defaultBaseURL() + hybridAiModel = HybridProviderReadiness.defaultModel() + hybridChatBaseURL = hybridAiBaseURL + hybridChatModel = hybridAiModel + hybridEmbedBaseURL = hybridAiBaseURL + hybridEmbedModel = hybridAiModel + } + + private func saveHybridProvider(key: String, baseURL: String, model: String, apiKey: String) { + guard !isSavingHybridProvider else { return } + isSavingHybridProvider = true + hybridProviderStatus = nil + Task { + do { + var provider: [String: LocalDaemonSettingUpdateValue] = [ + "kind": "openai_compatible", + "base_url": .string(baseURL.trimmingCharacters(in: .whitespacesAndNewlines)), + ] + let trimmedModel = model.trimmingCharacters(in: .whitespacesAndNewlines) + if !trimmedModel.isEmpty { + provider["model"] = .string(trimmedModel) + } + let trimmedKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines) + if !trimmedKey.isEmpty { + provider["api_key"] = .string(trimmedKey) + } + backendSettings = try await APIClient.shared.updateSelectedBackendSettings([ + key: .object(provider) + ]) + await MainActor.run { + hybridProviderStatus = "Saved \(key) to the local daemon." 
+ isSavingHybridProvider = false + syncAllHybridProviderFieldsFromBackendSettings() + } + } catch { + await MainActor.run { + hybridProviderStatus = error.localizedDescription + isSavingHybridProvider = false + } + } + } + } + + private func testHybridProvider(key: String, baseURL: String, model: String, apiKey: String) { + guard !isTestingHybridProvider else { return } + isTestingHybridProvider = true + hybridProviderStatus = nil + Task { + do { + var provider: [String: LocalDaemonSettingUpdateValue] = [ + "kind": "openai_compatible", + "base_url": .string(baseURL.trimmingCharacters(in: .whitespacesAndNewlines)), + "model": .string(model.trimmingCharacters(in: .whitespacesAndNewlines)), + ] + let trimmedKey = apiKey.trimmingCharacters(in: .whitespacesAndNewlines) + if !trimmedKey.isEmpty { + provider["api_key"] = .string(trimmedKey) + } + _ = try await APIClient.shared.updateSelectedBackendSettings([key: .object(provider)]) + let result = try await APIClient.shared.testHybridProvider(key: key) + await MainActor.run { + hybridProviderStatus = result.message + isTestingHybridProvider = false + } + } catch { + await MainActor.run { + hybridProviderStatus = error.localizedDescription + isTestingHybridProvider = false + } + } + } + } + private func loadSubscriptionInfo() { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { return } guard !isLoadingSubscription else { return } isLoadingSubscription = true subscriptionError = nil @@ -7072,6 +7541,7 @@ struct SettingsContentView: View { } private func loadOverageInfo() { + guard DesktopBackendEnvironment.selectedBackendTarget.mode != .localDaemon else { return } guard !isLoadingOverage else { return } isLoadingOverage = true Task { diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift index 0497da584d7..20daca71fab 100644 --- 
a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift @@ -496,8 +496,7 @@ actor InsightAssistant: ProactiveAssistant { /// Two-phase insight extraction: /// Phase 1 (text-only): Activity summary + SQL investigation loop. Model investigates via /// execute_sql, then calls `request_screenshot` with an ID and its findings so far. - /// Phase 2 (single vision call): Load the chosen screenshot + Phase 1 findings → single - /// Gemini call with image → provide_advice or no_advice. + /// Phase 2: screenshot pixels when vision_provider is configured in hybrid (otherwise OCR-only text). /// Returns (result, sqlQueryCount). private func runAdviceExtraction( jpegData: Data?, @@ -509,6 +508,11 @@ actor InsightAssistant: ProactiveAssistant { ) async throws -> (InsightExtractionResult?, Int) { var sqlCount = 0 + let hybridLocalDaemon = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + let daemonSettings = (try? await APIClient.shared.getSelectedBackendSettings()) ?? 
[] + let insightUsesScreenshotImage = + !hybridLocalDaemon || HybridVisionProvider.isConfigured(settings: daemonSettings) + // Build prompt with current context let timeFormatter = DateFormatter() timeFormatter.dateFormat = "h:mm a, EEEE" @@ -672,62 +676,128 @@ actor InsightAssistant: ProactiveAssistant { } // ============================================= - // PHASE 2: Single vision call with chosen screenshot + // PHASE 2: Vision (cloud / hybrid + vision_provider) or OCR-only (hybrid without vision_provider) // ============================================= - log("Insight: Phase 2 — loading screenshot \(screenshotId)") + let phase2Tools = buildPhase2Tools() + let phase2InitialContents: [GeminiImageToolRequest.Content] + + if insightUsesScreenshotImage { + log("Insight: Phase 2 — loading screenshot \(screenshotId)") + + let imageData: Data + do { + guard let screenshot = try await RewindDatabase.shared.getScreenshot(id: screenshotId) else { + log("Insight: P2 screenshot not in DB: \(screenshotId)") + return (nil, sqlCount) + } + if screenshot.usesVideoStorage, let chunk = screenshot.videoChunkPath { + let activeChunk = await VideoChunkEncoder.shared.currentChunkPath + if chunk == activeChunk { + log("Insight: P2 screenshot is in active chunk, skipping") + return (nil, sqlCount) + } + } + let rawData = try await RewindStorage.shared.loadScreenshotData(for: screenshot) + imageData = Self.compressForGemini(rawData) ?? rawData + log("Insight: P2 loaded \(imageData.count) bytes (\(rawData.count) raw) from \(screenshot.appName)") + } catch { + log("Insight: P2 screenshot load failed: \(error.localizedDescription)") + return (nil, sqlCount) + } + + let phase2Prompt = """ + INVESTIGATION FINDINGS: + \(findings) + + The screenshot below is from the app/window identified during investigation. 
+ + Before giving insight, CROSS-REFERENCE your findings: + - Use execute_sql to check if this issue was resolved in later screenshots + - Check if the user moved on to something else (the issue may be stale) + - Verify the context is still relevant by looking at nearby timestamps + + Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant. + """ + + let base64 = imageData.base64EncodedString() + phase2InitialContents = [ + GeminiImageToolRequest.Content( + role: "user", + parts: [ + GeminiImageToolRequest.Part(text: phase2Prompt), + GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64), + ] + ), + ] + } else { + log("Insight: Phase 2 (hybrid OCR-only) — screenshot \(screenshotId)") - // Load the screenshot image - let imageData: Data - do { guard let screenshot = try await RewindDatabase.shared.getScreenshot(id: screenshotId) else { - log("Insight: P2 screenshot not in DB: \(screenshotId)") + log("Insight: P2 OCR path: screenshot not in DB: \(screenshotId)") return (nil, sqlCount) } - // Check active chunk if screenshot.usesVideoStorage, let chunk = screenshot.videoChunkPath { let activeChunk = await VideoChunkEncoder.shared.currentChunkPath if chunk == activeChunk { - log("Insight: P2 screenshot is in active chunk, skipping") + log("Insight: P2 OCR path: screenshot is in active chunk, skipping") return (nil, sqlCount) } } - let rawData = try await RewindStorage.shared.loadScreenshotData(for: screenshot) - imageData = Self.compressForGemini(rawData) ?? 
rawData - log("Insight: P2 loaded \(imageData.count) bytes (\(rawData.count) raw) from \(screenshot.appName)") - } catch { - log("Insight: P2 screenshot load failed: \(error.localizedDescription)") - return (nil, sqlCount) - } - // Build Phase 2 prompt — compact findings + image + cross-reference instruction - let phase2Prompt = """ - INVESTIGATION FINDINGS: - \(findings) + let ocrBody = screenshot.ocrText?.trimmingCharacters(in: .whitespacesAndNewlines) ?? "" + guard !ocrBody.isEmpty else { + log("Insight: P2 OCR path: no OCR text for screenshot \(screenshotId)") + return (nil, sqlCount) + } - The screenshot below is from the app/window identified during investigation. + let phase2Prompt = """ + INVESTIGATION FINDINGS: + \(findings) - Before giving insight, CROSS-REFERENCE your findings: - - Use execute_sql to check if this issue was resolved in later screenshots - - Check if the user moved on to something else (the issue may be stale) - - Verify the context is still relevant by looking at nearby timestamps + Hybrid local mode without vision_provider: no screenshot image is available. Use only the OCR text below plus execute_sql cross-checks. - Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant. - """ + Screenshot id \(screenshotId), app: \(screenshot.appName), window: \(screenshot.windowTitle ?? "(none)"). 
- let phase2Tools = buildPhase2Tools() - let base64 = imageData.base64EncodedString() - var phase2Contents: [GeminiImageToolRequest.Content] = [ - GeminiImageToolRequest.Content( - role: "user", - parts: [ - GeminiImageToolRequest.Part(text: phase2Prompt), - GeminiImageToolRequest.Part(mimeType: "image/jpeg", data: base64), - ] - ) - ] + OCR TEXT: + \(ocrBody) + + Before giving insight, CROSS-REFERENCE your findings: + - Use execute_sql to check if this issue was resolved in later screenshots + - Check if the user moved on to something else (the issue may be stale) + - Verify the context is still relevant by looking at nearby timestamps + + Then call provide_advice if the insight is still valid, or no_advice if it was resolved or is no longer relevant. + """ + + phase2InitialContents = [ + GeminiImageToolRequest.Content( + role: "user", + parts: [GeminiImageToolRequest.Part(text: phase2Prompt)] + ), + ] + } + + return try await runInsightPhase2ToolLoop( + client: client, + phase2Contents: phase2InitialContents, + currentSystemPrompt: currentSystemPrompt, + phase2Tools: phase2Tools, + sqlCount: sqlCount + ) + } + + /// Phase 2 tool loop shared by vision (image + tools) and hybrid OCR-only paths. 
+ private func runInsightPhase2ToolLoop( + client: GeminiClient, + phase2Contents initialContents: [GeminiImageToolRequest.Content], + currentSystemPrompt: String, + phase2Tools: GeminiTool, + sqlCount initialSqlCount: Int + ) async throws -> (InsightExtractionResult?, Int) { + var phase2Contents = initialContents + var sqlCount = initialSqlCount - // Phase 2 loop — model can cross-reference via SQL before deciding for p2Iteration in 0..<5 { let p2Contents = phase2Contents let p2SystemPrompt = currentSystemPrompt @@ -799,7 +869,7 @@ actor InsightAssistant: ProactiveAssistant { log("Insight: P2 unexpected tool: \(toolCall.name)") break } - break // Break on unexpected tool + break } return (nil, sqlCount) } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index f85f2a10aee..21e7ae002d7 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -1,4 +1,5 @@ import Foundation +import ImageIO // MARK: - Thinking Budget Configuration @@ -180,7 +181,17 @@ struct GeminiResponse: Decodable { /// Low-level client for communicating with the Gemini API via backend proxy. /// All requests route through the Rust backend (/v1/proxy/gemini/*) which adds /// the Gemini API key server-side. Auth uses Firebase Bearer token. +/// +/// In `localDaemon` hybrid mode there is no Gemini proxy — when `ai_provider` / +/// `chat_provider` (or BYOK OpenAI) is configured, requests use +/// ``HybridLLMClient`` (OpenAI-compatible chat completions) instead. actor GeminiClient { + private enum Transport { + case geminiProxy + case hybridOpenAICompatible + } + + private let transport: Transport private let model: String /// Backend proxy base URL (from OMI_DESKTOP_API_URL env var) @@ -251,14 +262,96 @@ actor GeminiClient { } init(apiKey: String? 
= nil, model: String = ModelQoS.Gemini.proactive) throws { - // BREAKING CHANGE (issue #5861): apiKey parameter is ignored. - // All Gemini requests now route through the backend proxy which supplies - // the key server-side. Defaults to production when OMI_DESKTOP_API_URL is absent - // so installed test bundles launched from Finder still have AI features. + // BREAKING CHANGE (issue #5861): apiKey parameter is ignored for cloud proxy mode. + self.model = model + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + self.transport = .hybridOpenAICompatible + return + } guard !Self.proxyBaseURL.isEmpty else { throw GeminiClientError.missingAPIKey } - self.model = model + self.transport = .geminiProxy + } + + private func mapHybridError(_ error: HybridLLMClient.ClientError) -> GeminiClientError { + switch error { + case .notConfigured: + return .missingAPIKey + case .invalidSettings: + return .apiError(error.localizedDescription) + case .invalidResponse: + return .invalidResponse + case .httpFailure(let status, let body): + return .apiError("HTTP \(status): \(body)") + } + } + + /// Text chat completion via hybrid OpenAI-compatible provider (or BYOK OpenAI). + private func hybridChatText( + systemPrompt: String, + userText: String, + jsonMode: Bool, + timeout: TimeInterval = 300 + ) async throws -> String { + let settings = try await HybridDaemonSettingsCache.shared.settings() + guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { + throw GeminiClientError.missingAPIKey + } + do { + return try await HybridLLMClient.chatCompletionText( + config: config, + systemPrompt: systemPrompt, + userText: userText, + jsonMode: jsonMode, + timeout: timeout + ) + } catch let error as HybridLLMClient.ClientError { + throw mapHybridError(error) + } + } + + /// Multimodal when `vision_provider` is set; otherwise macOS Vision OCR + text JSON. 
+ private func hybridChatImageOrOCR( + prompt: String, + imageData: Data, + systemPrompt: String, + jsonMode: Bool, + timeout: TimeInterval = 300 + ) async throws -> String { + let settings = try await HybridDaemonSettingsCache.shared.settings() + guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { + throw GeminiClientError.missingAPIKey + } + let visionConfig = HybridLLMClient.loadVisionProviderConfig(from: settings) + do { + if let visionConfig { + return try await HybridLLMClient.chatCompletionMultimodalJPEG( + config: visionConfig, + systemPrompt: systemPrompt, + userText: prompt, + jpegData: imageData, + jsonMode: jsonMode, + timeout: timeout + ) + } + let ocr = try await HybridLLMClient.ScreenOCR.recognizeTextFromJPEG(imageData) + let user = + prompt + + "\n\n--- ON-SCREEN TEXT (macOS OCR) ---\n\(ocr)\n--- END OCR ---\n" + let sysp = + systemPrompt + + "\n\nReturn a single JSON object only (no prose or markdown fences)." + return try await HybridLLMClient.chatCompletionText( + config: config, + systemPrompt: sysp, + userText: user, + jsonMode: jsonMode, + timeout: timeout + ) + } catch let error as HybridLLMClient.ClientError { + throw mapHybridError(error) + } } /// Get Firebase auth header for proxy requests @@ -326,6 +419,12 @@ actor GeminiClient { return false } } + if let hybridError = error as? HybridLLMClient.ClientError { + if case .httpFailure(let status, _) = hybridError { + return status == 429 || status == 503 + } + return false + } // URLSession network errors are transient return (error as NSError).domain == NSURLErrorDomain } @@ -357,6 +456,15 @@ actor GeminiClient { for attempt in 0...maxRetries { do { + if transport == .hybridOpenAICompatible { + return try await hybridChatImageOrOCR( + prompt: prompt, + imageData: imageData, + systemPrompt: systemPrompt, + jsonMode: true + ) + } + // Wrap base64 encoding + JSON serialization in autoreleasepool. 
// These create bridged Obj-C objects (NSString, NSData) that accumulate // in Swift concurrency's cooperative thread pool without being drained. @@ -442,6 +550,15 @@ actor GeminiClient { for attempt in 0...maxRetries { do { + if transport == .hybridOpenAICompatible { + return try await hybridChatText( + systemPrompt: systemPrompt, + userText: prompt, + jsonMode: false, + timeout: timeout + ) + } + let request = GeminiRequest( contents: [ GeminiRequest.Content(parts: [ @@ -514,6 +631,14 @@ actor GeminiClient { for attempt in 0...maxRetries { do { + if transport == .hybridOpenAICompatible { + return try await hybridChatText( + systemPrompt: systemPrompt, + userText: prompt, + jsonMode: true + ) + } + let request = GeminiRequest( contents: [ GeminiRequest.Content(parts: [ @@ -568,6 +693,60 @@ actor GeminiClient { throw lastError! } + /// When hybrid mode has no `vision_provider`, strip inline image bytes and attach macOS OCR text instead. + private func hybridContentsWithOCRInsteadOfImages( + _ contents: [GeminiImageToolRequest.Content] + ) async throws -> [GeminiImageToolRequest.Content] { + var results: [GeminiImageToolRequest.Content] = [] + for content in contents { + var newParts: [GeminiImageToolRequest.Part] = [] + for part in content.parts { + if let inline = part.inlineData { + let mime = inline.mimeType.lowercased() + guard let rawImageData = Data(base64Encoded: inline.data, options: [.ignoreUnknownCharacters]) + else { + continue + } + let jpegData: Data + if mime.contains("webp"), let converted = Self.webpDataAsJPEG(rawImageData) { + jpegData = converted + } else if mime.contains("jpeg") || mime.contains("jpg") { + jpegData = rawImageData + } else { + continue + } + let ocr = try await HybridLLMClient.ScreenOCR.recognizeTextFromJPEG(jpegData) + let banner = + "\n\n--- ON-SCREEN TEXT (macOS OCR; hybrid mode without vision_provider) ---\n\(ocr)\n--- END OCR ---\n" + newParts.append(GeminiImageToolRequest.Part(text: banner)) + continue + } + 
newParts.append(part) + } + results.append(GeminiImageToolRequest.Content(role: content.role, parts: newParts)) + } + return results + } + + /// Best-effort WebP → JPEG for Vision OCR when assistants embed WebP inline data. + private nonisolated static func webpDataAsJPEG(_ webpData: Data) -> Data? { + guard let src = CGImageSourceCreateWithData(webpData as CFData, nil), + let cgImage = CGImageSourceCreateImageAtIndex(src, 0, nil) + else { + return nil + } + let destData = NSMutableData() + guard let dest = CGImageDestinationCreateWithData(destData, "public.jpeg" as CFString, 1, nil) + else { + return nil + } + CGImageDestinationAddImage(dest, cgImage, nil) + guard CGImageDestinationFinalize(dest) else { + return nil + } + return destData as Data + } + } @@ -627,15 +806,39 @@ struct GeminiTool: Encodable { struct Property: Encodable { let type: String - let description: String + let description: String? let `enum`: [String]? let items: Items? + let nestedProperties: [String: Property]? + let nestedRequired: [String]? - init(type: String, description: String, enumValues: [String]? = nil, items: Items? = nil) { + enum CodingKeys: String, CodingKey { + case type + case `enum` + case description + case items + case nestedProperties = "properties" + case nestedRequired = "required" + } + + init(type: String, description: String = "", enumValues: [String]? = nil, items: Items? = nil) { self.type = type - self.description = description + self.description = description.isEmpty ? nil : description self.enum = enumValues self.items = items + self.nestedProperties = nil + self.nestedRequired = nil + } + + init( + type: String, description: String = "", properties: [String: Property], required: [String] + ) { + self.type = type + self.description = description.isEmpty ? 
nil : description + self.enum = nil + self.items = nil + self.nestedProperties = properties + self.nestedRequired = required } struct Items: Encodable { @@ -806,6 +1009,33 @@ extension GeminiClient { for attempt in 0...maxRetries { do { + if transport == .hybridOpenAICompatible { + let settings = try await HybridDaemonSettingsCache.shared.settings() + guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { + throw GeminiClientError.missingAPIKey + } + let allowVision = HybridVisionProvider.isConfigured(settings: settings) + do { + let contentsForHybrid: [GeminiImageToolRequest.Content] + if allowVision { + contentsForHybrid = contents + } else { + contentsForHybrid = try await hybridContentsWithOCRInsteadOfImages(contents) + } + return try await HybridLLMClient.performGeminiCompatibleToolRound( + config: config, + systemPrompt: systemPrompt, + contents: contentsForHybrid, + tools: tools, + forceToolCall: forceToolCall, + allowVisionInlineJPEG: allowVision, + timeout: 300 + ) + } catch let error as HybridLLMClient.ClientError { + throw mapHybridError(error) + } + } + // Wrap JSON serialization in autoreleasepool (contents may include // large base64 image data that creates bridged Obj-C intermediaries). let requestBody: Data = try autoreleasepool { diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift index 8e0c9f9c7ee..b4cff6d8c37 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift @@ -62,6 +62,11 @@ actor EmbeddingService { /// - text: Text to embed /// - taskType: Optional Gemini task type (e.g. "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY") func embed(text: String, taskType: String? 
= nil) async throws -> [Float] { + if HybridEmbeddingClient.isEnabled() { + let result = try await HybridEmbeddingClient.embedFromDaemonSettings(text: text) + return result.vector + } + guard !Self.proxyBaseURL.isEmpty else { throw EmbeddingError.missingAPIKey } @@ -107,6 +112,16 @@ actor EmbeddingService { /// - texts: Texts to embed /// - taskType: Optional Gemini task type (e.g. "RETRIEVAL_DOCUMENT", "RETRIEVAL_QUERY") func embedBatch(texts: [String], taskType: String? = nil) async throws -> [[Float]] { + if HybridEmbeddingClient.isEnabled() { + var results: [[Float]] = [] + results.reserveCapacity(texts.count) + for text in texts { + let result = try await HybridEmbeddingClient.embedFromDaemonSettings(text: text) + results.append(result.vector) + } + return results + } + guard !Self.proxyBaseURL.isEmpty else { throw EmbeddingError.missingAPIKey } diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index 83115d69e62..aa1094c6e7b 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -2460,10 +2460,16 @@ A screenshot may be attached — use it silently only if relevant. Never mention usageLimiter.recordQuery() } - // Ensure bridge is running - guard await ensureBridgeStarted() else { - errorMessage = "AI not available" - return + let localDaemon = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + let mayUseHybridDirectChat = localDaemon && HybridChatClient.isEnabled() + + // Ensure Claude / ACP bridge when not using hybrid direct chat. Hybrid path may + // skip the bridge until multimodal attachments require ACP. + if !mayUseHybridDirectChat { + guard await ensureBridgeStarted() else { + errorMessage = "AI not available" + return + } } // Show upgrade prompt if over threshold but don't block the message @@ -2602,8 +2608,16 @@ A screenshot may be attached — use it silently only if relevant. 
Never mention var toolStartTimes: [String: Date] = [:] var sqlRowsReturned = 0 var sqlQueryCount = 0 + var hybridResolvedModel: String? do { + if mayUseHybridDirectChat { + await preparePromptContextIfNeeded() + if !isOnboarding { + cachedMainSystemPrompt = buildSystemPrompt(contextString: formatMemoriesSection()) + } + } + // Use the system prompt built at warmup. The agent bridge applies it only // at session/new; for the normal reused-session path it is ignored. // Passing it here ensures it is applied if the session was invalidated @@ -2645,109 +2659,161 @@ A screenshot may be attached — use it silently only if relevant. Never mention } } - // Query the active bridge with streaming - // Callbacks for agent bridge - let textDeltaHandler: AgentBridge.TextDeltaHandler = { [weak self] delta in - Task { @MainActor [weak self] in - self?.appendToMessage(id: aiMessageId, text: delta) + let useHybridNow = + mayUseHybridDirectChat && effectiveImageData == nil + && !attachmentsForMessage.contains(where: { $0.isImage }) + + let queryResult: AgentBridge.QueryResult + + if useHybridNow { + let historyPairs: [(role: String, text: String)] = + messages + .dropLast(2) + .filter { !$0.isStreaming && !$0.text.isEmpty } + .map { msg in + ( + role: msg.sender == .user ? 
"user" : "assistant", + text: msg.text + ) + } + let hybrid = try await HybridChatClient.completeFromDaemonSettings( + systemPrompt: systemPrompt, + conversationMessages: historyPairs, + userMessage: trimmedText + ) + hybridResolvedModel = hybrid.model + let normalized = normalizeAssistantSentenceSpacing(hybrid.text) + queryResult = AgentBridge.QueryResult( + text: normalized, + costUsd: 0, + sessionId: "", + inputTokens: hybrid.inputTokens, + outputTokens: hybrid.outputTokens, + cacheReadTokens: 0, + cacheWriteTokens: 0 + ) + } else { + if mayUseHybridDirectChat && !useHybridNow { + guard await ensureBridgeStarted() else { + throw BridgeError.notRunning + } } - } - let toolCallHandler: AgentBridge.ToolCallHandler = { callId, name, input in - let toolCall = ToolCall(name: name, arguments: input, thoughtSignature: nil) - let result = await ChatToolExecutor.execute(toolCall) - log("OMI tool \(name) executed for callId=\(callId)") - // Track SQL query stats for metadata - if name == "execute_sql" { - sqlQueryCount += 1 - // Parse row count from result (format: "\nN row(s)" at end) - if let match = result.range(of: #"(\d+) row\(s\)"#, options: .regularExpression) { - let numStr = result[match].components(separatedBy: " ").first ?? "0" - sqlRowsReturned += Int(numStr) ?? 0 + + // Query the active bridge with streaming + // Callbacks for agent bridge + let textDeltaHandler: AgentBridge.TextDeltaHandler = { [weak self] delta in + Task { @MainActor [weak self] in + self?.appendToMessage(id: aiMessageId, text: delta) } } - return result - } - let toolActivityHandler: AgentBridge.ToolActivityHandler = { [weak self] name, status, toolUseId, input in - Task { @MainActor [weak self] in - self?.addToolActivity( - messageId: aiMessageId, - toolName: name, - status: status == "started" ? 
.running : .completed, - toolUseId: toolUseId, - input: input - ) - if status == "started" { - toolNames.append(name) - toolStartTimes[name] = Date() - if (name.contains("browser") || name.contains("playwright")) { - let token = UserDefaults.standard.string(forKey: "playwrightExtensionToken") ?? "" - if token.isEmpty { - log("ChatProvider: Browser tool \(name) called without extension token — aborting query and prompting setup") - self?.needsBrowserExtensionSetup = true - self?.stopAgent() - // Keep floating-bar sessions non-intrusive: do not foreground - // the main window when the query originated from the floating bar. - if sessionKey != "floating" { - // Bring the app to the foreground so the setup sheet is visible - // (the failed browser attempt may have opened Chrome, stealing focus) - NSApp.activate() - for window in NSApp.windows where window.title.hasPrefix("Omi") { - window.makeKeyAndOrderFront(nil) + let toolCallHandler: AgentBridge.ToolCallHandler = { callId, name, input in + let toolCall = ToolCall(name: name, arguments: input, thoughtSignature: nil) + let result = await ChatToolExecutor.execute(toolCall) + log("OMI tool \(name) executed for callId=\(callId)") + // Track SQL query stats for metadata + if name == "execute_sql" { + sqlQueryCount += 1 + // Parse row count from result (format: "\nN row(s)" at end) + if let match = result.range(of: #"(\d+) row\(s\)"#, options: .regularExpression) { + let numStr = result[match].components(separatedBy: " ").first ?? "0" + sqlRowsReturned += Int(numStr) ?? 0 + } + } + return result + } + let toolActivityHandler: AgentBridge.ToolActivityHandler = { + [weak self] name, status, toolUseId, input in + Task { @MainActor [weak self] in + self?.addToolActivity( + messageId: aiMessageId, + toolName: name, + status: status == "started" ? 
.running : .completed, + toolUseId: toolUseId, + input: input + ) + if status == "started" { + toolNames.append(name) + toolStartTimes[name] = Date() + if (name.contains("browser") || name.contains("playwright")) { + let token = + UserDefaults.standard.string(forKey: "playwrightExtensionToken") ?? "" + if token.isEmpty { + log( + "ChatProvider: Browser tool \(name) called without extension token — aborting query and prompting setup" + ) + self?.needsBrowserExtensionSetup = true + self?.stopAgent() + // Keep floating-bar sessions non-intrusive: do not foreground + // the main window when the query originated from the floating bar. + if sessionKey != "floating" { + // Bring the app to the foreground so the setup sheet is visible + // (the failed browser attempt may have opened Chrome, stealing focus) + NSApp.activate() + for window in NSApp.windows where window.title.hasPrefix("Omi") { + window.makeKeyAndOrderFront(nil) + } } } + // Show the floating bar so the user has an always-on-top UI + // when Chrome takes focus (important on small screens) + if !FloatingControlBarManager.shared.isVisible { + log( + "ChatProvider: Browser tool active — showing floating bar so it stays above Chrome" + ) + FloatingControlBarManager.shared.showTemporarily() + } } - // Show the floating bar so the user has an always-on-top UI - // when Chrome takes focus (important on small screens) - if !FloatingControlBarManager.shared.isVisible { - log("ChatProvider: Browser tool active — showing floating bar so it stays above Chrome") - FloatingControlBarManager.shared.showTemporarily() - } + } else if status == "completed", + let startTime = toolStartTimes.removeValue(forKey: name) + { + let durationMs = Int(Date().timeIntervalSince(startTime) * 1000) + AnalyticsManager.shared.chatToolCallCompleted( + toolName: name, durationMs: durationMs) } - } else if status == "completed", let startTime = toolStartTimes.removeValue(forKey: name) { - let durationMs = Int(Date().timeIntervalSince(startTime) 
* 1000) - AnalyticsManager.shared.chatToolCallCompleted(toolName: name, durationMs: durationMs) } } - } - let thinkingDeltaHandler: AgentBridge.ThinkingDeltaHandler = { [weak self] text in - Task { @MainActor [weak self] in - self?.appendThinking(messageId: aiMessageId, text: text) - } - } - let toolResultDisplayHandler: AgentBridge.ToolResultDisplayHandler = { [weak self] toolUseId, name, output in - Task { @MainActor [weak self] in - self?.addToolResult(messageId: aiMessageId, toolUseId: toolUseId, name: name, output: output) - } - } - - let queryResult = try await agentBridge.query( - prompt: trimmedText, - systemPrompt: systemPrompt, - sessionKey: isOnboarding ? "onboarding" : (sessionKey ?? "main"), - cwd: workingDirectory, - mode: chatMode.rawValue, - model: model ?? modelOverride, - resume: resume, - imageData: effectiveImageData, - onTextDelta: textDeltaHandler, - onToolCall: toolCallHandler, - onToolActivity: toolActivityHandler, - onThinkingDelta: thinkingDeltaHandler, - onToolResultDisplay: toolResultDisplayHandler, - onAuthRequired: { [weak self] methods, authUrl in + let thinkingDeltaHandler: AgentBridge.ThinkingDeltaHandler = { [weak self] text in Task { @MainActor [weak self] in - self?.claudeAuthMethods = methods - self?.claudeAuthUrl = authUrl - self?.isClaudeAuthRequired = true + self?.appendThinking(messageId: aiMessageId, text: text) } - }, - onAuthSuccess: { [weak self] in + } + let toolResultDisplayHandler: AgentBridge.ToolResultDisplayHandler = { + [weak self] toolUseId, name, output in Task { @MainActor [weak self] in - self?.isClaudeAuthRequired = false - self?.checkClaudeConnectionStatus() + self?.addToolResult( + messageId: aiMessageId, toolUseId: toolUseId, name: name, output: output) } } - ) + + queryResult = try await agentBridge.query( + prompt: trimmedText, + systemPrompt: systemPrompt, + sessionKey: isOnboarding ? "onboarding" : (sessionKey ?? "main"), + cwd: workingDirectory, + mode: chatMode.rawValue, + model: model ?? 
modelOverride, + resume: resume, + imageData: effectiveImageData, + onTextDelta: textDeltaHandler, + onToolCall: toolCallHandler, + onToolActivity: toolActivityHandler, + onThinkingDelta: thinkingDeltaHandler, + onToolResultDisplay: toolResultDisplayHandler, + onAuthRequired: { [weak self] methods, authUrl in + Task { @MainActor [weak self] in + self?.claudeAuthMethods = methods + self?.claudeAuthUrl = authUrl + self?.isClaudeAuthRequired = true + } + }, + onAuthSuccess: { [weak self] in + Task { @MainActor [weak self] in + self?.isClaudeAuthRequired = false + self?.checkClaudeConnectionStatus() + } + } + ) + } // Flush any remaining buffered streaming text before finalizing streamingFlushWorkItem?.cancel() @@ -2762,7 +2828,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention messages[index].text = messageText messages[index].isStreaming = false messages[index].metadata = MessageMetadata( - model: model ?? modelOverride, + model: hybridResolvedModel ?? model ?? 
modelOverride, inputTokens: queryResult.inputTokens, outputTokens: queryResult.outputTokens, cacheReadTokens: queryResult.cacheReadTokens, diff --git a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift index 9a34eb88938..d246d6ab2f0 100644 --- a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift +++ b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift @@ -2164,6 +2164,18 @@ actor RewindDatabase { } } + migrator.registerMigration("hybridEmbeddingMetadata") { db in + try db.alter(table: "screenshots") { t in + t.add(column: "embeddingModel", .text) + t.add(column: "embeddingDim", .integer) + } + try db.alter(table: "staged_tasks") { t in + t.add(column: "embeddingModel", .text) + t.add(column: "embeddingDim", .integer) + } + print("[RewindDatabase] Migration: Added embedding_model/embedding_dim metadata for hybrid embedders") + } + try migrator.migrate(queue) } diff --git a/desktop/Desktop/Sources/Rewind/Services/OCREmbeddingService.swift b/desktop/Desktop/Sources/Rewind/Services/OCREmbeddingService.swift index b92af576063..f2e5c3d6059 100644 --- a/desktop/Desktop/Sources/Rewind/Services/OCREmbeddingService.swift +++ b/desktop/Desktop/Sources/Rewind/Services/OCREmbeddingService.swift @@ -37,6 +37,27 @@ actor OCREmbeddingService { private init() {} + /// Hybrid local daemon: skip OCR embedding API work when direct embeddings are off or `embedding_provider` is missing. + private func hybridDaemonOCREmbeddingUnavailable() async -> Bool { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return false + } + if !HybridEmbeddingClient.isEnabled() { + return true + } + let settings = (try? await APIClient.shared.getSelectedBackendSettings()) ?? 
[] + return HybridEmbeddingClient.loadProviderConfig(from: settings) == nil + } + + private enum HybridOCREmbedLog { + nonisolated(unsafe) static var didLogSkip = false + static func logSkipOnce(_ message: @autoclosure () -> String) { + guard !didLogSkip else { return } + didLogSkip = true + log(message()) + } + } + // MARK: - Text Formatting /// Format screenshot text for embedding: prepend app context for better retrieval @@ -99,6 +120,14 @@ actor OCREmbeddingService { guard !pendingItems.isEmpty else { return } + if await hybridDaemonOCREmbeddingUnavailable() { + HybridOCREmbedLog.logSkipOnce( + "OCREmbeddingService: Skipping OCR embedding work in hybrid mode until OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED is on and embedding_provider is configured." + ) + startFlushTimerIfNeeded() + return + } + // Take the current batch and clear the buffer let batch = pendingItems pendingItems = [] @@ -164,6 +193,13 @@ actor OCREmbeddingService { /// Capped at 5000 items per launch to prevent cost spikes. func backfillIfNeeded() async { do { + if await hybridDaemonOCREmbeddingUnavailable() { + HybridOCREmbedLog.logSkipOnce( + "OCREmbeddingService: Skipping OCR embedding work in hybrid mode until OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED is on and embedding_provider is configured." 
+ ) + return + } + let status = try await RewindDatabase.shared.getScreenshotEmbeddingBackfillStatus() if status.completed { log("OCREmbeddingService: Backfill already complete, skipping") @@ -243,6 +279,10 @@ actor OCREmbeddingService { // Flush any pending embeddings before searching so recent screenshots are findable await flushPendingEmbeddings() + if await hybridDaemonOCREmbeddingUnavailable() { + return [] + } + // Embed the query with RETRIEVAL_QUERY task type for asymmetric search let queryEmbedding = try await EmbeddingService.shared.embed(text: query, taskType: "RETRIEVAL_QUERY") diff --git a/desktop/Desktop/Sources/Rewind/UI/RewindViewModel.swift b/desktop/Desktop/Sources/Rewind/UI/RewindViewModel.swift index 4ee3bff3617..940dc7036ff 100644 --- a/desktop/Desktop/Sources/Rewind/UI/RewindViewModel.swift +++ b/desktop/Desktop/Sources/Rewind/UI/RewindViewModel.swift @@ -226,7 +226,7 @@ class RewindViewModel: ObservableObject { searchTask = Task { do { - // Run FTS and vector search in parallel + // Keyword search uses local GRDB FTS only (screenshots_fts); hybrid daemon does not route this. async let ftsResults = RewindDatabase.shared.search( query: trimmedQuery, appFilter: selectedApp, diff --git a/desktop/Desktop/Sources/SignInView.swift b/desktop/Desktop/Sources/SignInView.swift index ec7a53e7d96..139b1c9be80 100644 --- a/desktop/Desktop/Sources/SignInView.swift +++ b/desktop/Desktop/Sources/SignInView.swift @@ -3,6 +3,10 @@ import SwiftUI struct SignInView: View { @ObservedObject var authState: AuthState + private var isLocalHybridMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + var body: some View { ZStack { // Full background @@ -28,16 +32,38 @@ struct SignInView: View { .scaledFont(size: 48, weight: .bold) .foregroundColor(OmiColors.textPrimary) - Text("Sign in to continue") + Text(isLocalHybridMode ? 
"Local hybrid mode" : "Sign in to continue") .font(.title3) .foregroundColor(OmiColors.textTertiary) + + if isLocalHybridMode { + Text("Cloud sign-in is optional. Your data stays on this Mac.") + .font(.subheadline) + .foregroundColor(OmiColors.textTertiary) + .multilineTextAlignment(.center) + .padding(.horizontal, 24) + } } Spacer() // Sign in buttons VStack(spacing: 12) { - // Sign in with Apple + if isLocalHybridMode { + Button(action: { + AuthService.shared.establishLocalGuestSessionIfNeeded() + }) { + Text("Continue without signing in") + .scaledFont(size: 17, weight: .medium) + .foregroundColor(.black) + .frame(maxWidth: .infinity) + .frame(height: 50) + .background(Color.white) + .cornerRadius(10) + } + .buttonStyle(.plain) + } + // Sign in with Apple (cloud mode only — hybrid local uses guest session) Button(action: { Task { do { @@ -66,7 +92,7 @@ struct SignInView: View { .cornerRadius(10) } .buttonStyle(.plain) - .disabled(authState.isLoading) + .disabled(authState.isLoading || isLocalHybridMode) // Sign in with Google Button(action: { @@ -101,7 +127,7 @@ struct SignInView: View { ) } .buttonStyle(.plain) - .disabled(authState.isLoading) + .disabled(authState.isLoading || isLocalHybridMode) // Loading overlay for both buttons if authState.isLoading { @@ -139,6 +165,11 @@ struct SignInView: View { } .frame(maxWidth: .infinity, maxHeight: .infinity) } + .onAppear { + if isLocalHybridMode { + AuthService.shared.establishLocalGuestSessionIfNeeded() + } + } } } diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 4510b8ce6a6..6d9f0068f23 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -4,6 +4,8 @@ import Foundation /// Conversation capture: Python backend `/v4/listen` WebSocket (speech profiles, speaker assignment, memory events). 
/// PTT live streaming: Python backend `/v2/voice-message/transcribe-stream` WebSocket (transcription only). /// PTT batch: Python backend `/v2/voice-message/transcribe` REST API. +/// Local daemon hybrid (`OMI_HYBRID_DIRECT_STT_ENABLED` + Speech availability): buffered Apple Speech streaming, +/// uploads finished sessions through the Rust local daemon transcript pipeline (`TranscriptionRetryService`). /// Full stereo batch: removed (formerly Rust proxy Deepgram, now dead code). class TranscriptionService { @@ -37,6 +39,29 @@ class TranscriptionService { let start: Double let end: Double let translations: [BackendTranslation]? + + /// Construct segments for adapters (Apple Speech hybrid path) without JSON decoding. + internal init( + id: String?, + text: String, + speaker: String?, + speaker_id: Int?, + is_user: Bool, + person_id: String?, + start: Double, + end: Double, + translations: [BackendTranslation]? + ) { + self.id = id + self.text = text + self.speaker = speaker + self.speaker_id = speaker_id + self.is_user = is_user + self.person_id = person_id + self.start = start + self.end = end + self.translations = translations + } } /// Message event (from `/v4/listen` only — not used by PTT transcribe-stream) @@ -99,6 +124,10 @@ class TranscriptionService { private let streamingMode: StreamingMode private let contextKeywords: [String] + /// When true, PCM is recognized with Apple Speech (local daemon hybrid) instead of Python WebSockets. + private let useHybridLocalSpeech: Bool + private var localSpeechAdapter: LocalSpeechTranscriptionAdapter? + /// Python backend base URL for transcription endpoints. 
/// Resolution order: beta release channel → OMI_PYTHON_API_URL → https://api.omi.me/ /// NOTE: Do NOT fall back to OMI_DESKTOP_API_URL — that points to the Rust desktop-backend @@ -163,23 +192,36 @@ class TranscriptionService { /// - language: Language code for transcription (e.g., "en", "uk", "ru", "multi" for auto-detect) /// - mode: Streaming mode — `.conversation` for `/v4/listen` (default), `.ptt` for `/v2/voice-message/transcribe-stream` init(language: String = "en", mode: StreamingMode = .conversation, contextKeywords: [String] = []) throws { - guard DesktopBackendEnvironment.isCapability( - .hostedTranscription, - availableIn: DesktopBackendEnvironment.selectedBackendTarget.mode - ) else { + let backendMode = DesktopBackendEnvironment.selectedBackendTarget.mode + // Local daemon: stream PCM through Apple Speech when `directSTT` is on — either explicitly + // (`OMI_HYBRID_DIRECT_STT_ENABLED=1`) or when the Speech engine is available for preferred languages. + let hybridSpeech = backendMode == .localDaemon + && DesktopBackendEnvironment.isCapability(.directSTT, availableIn: backendMode) + + guard + DesktopBackendEnvironment.isCapability( + .hostedTranscription, + availableIn: backendMode + ) + || hybridSpeech + else { throw TranscriptionError.webSocketError( DesktopBackendEnvironment.unavailableReason( for: .hostedTranscription, - in: DesktopBackendEnvironment.selectedBackendTarget.mode + in: backendMode ) ?? "Hosted transcription is unavailable in local daemon mode" ) } + self.useHybridLocalSpeech = hybridSpeech self.apiKey = "" // Not needed — Python backend uses Firebase auth self.language = language self.streamingMode = mode self.contextKeywords = Self.sanitizedContextKeywords(contextKeywords) - log("TranscriptionService: Initialized for \(mode == .conversation ? 
"/v4/listen" : "/v2/voice-message/transcribe-stream"), language=\(language), contextKeywords=\(self.contextKeywords.count)") + + log( + "TranscriptionService: Initialized for \(Self.endpointLabel(mode: mode, hybridLocal: hybridSpeech)), language=\(language), contextKeywords=\(self.contextKeywords.count)" + ) } /// Initialize for batch (PTT) mode only — uses Python backend `/v2/voice-message/transcribe` @@ -204,6 +246,7 @@ class TranscriptionService { } // Batch mode uses Firebase auth + Python backend — no DG key needed + self.useHybridLocalSpeech = false self.apiKey = "" self.language = language self.streamingMode = .ptt // Batch doesn't stream, but PTT is the correct context @@ -228,6 +271,11 @@ class TranscriptionService { func finishStream() { flushAudioBuffer() + if useHybridLocalSpeech { + localSpeechAdapter?.endAudioInput() + return + } + // Only PTT mode uses the "finalize" protocol — conversation mode (/v4/listen) doesn't support it guard streamingMode == .ptt else { return } guard isConnected, let webSocketTask = webSocketTask else { return } @@ -254,9 +302,14 @@ class TranscriptionService { self.onError = onError self.onConnected = onConnected self.onDisconnected = onDisconnected - self.shouldReconnect = true + self.shouldReconnect = !useHybridLocalSpeech self.reconnectAttempts = 0 + if useHybridLocalSpeech { + startHybridLocalSpeechStreaming() + return + } + connect() } @@ -271,12 +324,16 @@ class TranscriptionService { // Flush any remaining audio flushAudioBuffer() + if useHybridLocalSpeech { + localSpeechAdapter?.endAudioInput() + } + disconnect() } /// Send audio data to the backend (buffered for efficiency) func sendAudio(_ data: Data) { - guard isConnected else { return } + guard useHybridLocalSpeech || isConnected else { return } audioBufferLock.lock() audioBuffer.append(data) @@ -304,9 +361,16 @@ class TranscriptionService { } } - /// Actually send an audio chunk to the backend + /// Actually send an audio chunk to the backend (or Apple Speech 
hybrid path). private func sendAudioChunk(_ data: Data) { - guard isConnected, let webSocketTask = webSocketTask else { return } + guard useHybridLocalSpeech || isConnected else { return } + + if useHybridLocalSpeech { + localSpeechAdapter?.appendLinear16PCMSamples(data) + return + } + + guard let webSocketTask = webSocketTask else { return } let message = URLSessionWebSocketTask.Message.data(data) webSocketTask.send(message) { [weak self] error in @@ -322,9 +386,77 @@ class TranscriptionService { return isConnected } + /// Label for diagnostics (hosted WebSocket endpoints vs hybrid local Apple Speech). + private static func endpointLabel(mode: StreamingMode, hybridLocal: Bool) -> String { + if hybridLocal { + switch mode { + case .conversation: return "/v4/listen (Apple Speech hybrid local)" + case .ptt: return "/v2/voice-message/transcribe-stream (Apple Speech hybrid local)" + } + } + switch mode { + case .conversation: return "/v4/listen" + case .ptt: return "/v2/voice-message/transcribe-stream" + } + } + + /// Apple Speech buffered recognition for local daemon hybrid mode (`OMI_HYBRID_DIRECT_STT_ENABLED=1`). 
+ private func startHybridLocalSpeechStreaming() { + reconnectTask?.cancel() + reconnectTask = nil + watchdogTask?.cancel() + watchdogTask = nil + + let adapter = LocalSpeechTranscriptionAdapter(languageCode: language) + localSpeechAdapter = adapter + + localSpeechAdapter?.start( + onSegments: { [weak self] segments in + guard let self else { return } + self.lastDataReceivedAt = Date() + if !segments.isEmpty { + self.onBackendSegments?(segments) + } + }, + onError: { [weak self] error in + guard let self else { return } + self.shouldReconnect = false + self.isConnected = false + self.watchdogTask?.cancel() + self.watchdogTask = nil + self.cleanupLocalSpeechAdapterImmediately() + self.onError?(error) + }, + onReady: { [weak self] in + guard let self else { return } + self.isConnected = true + self.reconnectAttempts = 0 + self.lastDataReceivedAt = Date() + log("TranscriptionService: Apple Speech hybrid ready") + self.onConnected?() + } + ) + } + + /// Cancel hybrid recognition synchronously — used after fatal Speech errors before WebSocket teardown. + private func cleanupLocalSpeechAdapterImmediately() { + localSpeechAdapter?.cancel() + localSpeechAdapter = nil + } + + /// Tear down Speech with a brief delay after `endAudioInput()` so the task can flush finals. 
+ private func scheduleLocalSpeechAdapterCancellation() { + guard let adapter = localSpeechAdapter else { return } + localSpeechAdapter = nil + DispatchQueue.global(qos: .utility).asyncAfter(deadline: .now() + 0.55) { + adapter.cancel() + } + } + // MARK: - Private Methods (Connection) private func connect() { + guard !useHybridLocalSpeech else { return } // Always use Firebase auth for Python backend Task { [weak self] in guard let self = self else { return } @@ -459,6 +591,16 @@ class TranscriptionService { } private func disconnect() { + if useHybridLocalSpeech { + isConnected = false + watchdogTask?.cancel() + watchdogTask = nil + scheduleLocalSpeechAdapterCancellation() + log("TranscriptionService: Disconnected (hybrid Apple Speech)") + onDisconnected?() + return + } + isConnected = false watchdogTask?.cancel() watchdogTask = nil @@ -471,6 +613,7 @@ class TranscriptionService { } func handleDisconnection() { + guard !useHybridLocalSpeech else { return } guard isConnected else { return } isConnected = false @@ -501,6 +644,7 @@ class TranscriptionService { /// Cleanup a failed/pending connection and schedule reconnect. /// Unlike handleDisconnection(), this works even when isConnected is false (pre-handshake failures). 
func cleanupAndReconnect() { + guard !useHybridLocalSpeech else { return } webSocketTask?.cancel(with: .abnormalClosure, reason: nil) webSocketTask = nil urlSession?.invalidateAndCancel() diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 877eae33ac2..cf92c663873 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -367,6 +367,10 @@ final class APIClientRoutingTests: XCTestCase { } func testLocalDaemonCapabilityMatrixDisablesCloudBoundFeatures() { + setenv("OMI_HYBRID_DIRECT_STT_ENABLED", "0", 1) + defer { + unsetenv("OMI_HYBRID_DIRECT_STT_ENABLED") + } let capabilities = Dictionary( uniqueKeysWithValues: DesktopBackendEnvironment.capabilities(for: .localDaemon) .map { ($0.capability, $0) } @@ -381,7 +385,29 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertEqual(capabilities[.payments]?.available, false) XCTAssertEqual(capabilities[.crispSupport]?.available, false) XCTAssertEqual(capabilities[.hostedTranscription]?.available, false) + XCTAssertEqual(capabilities[.directSTT]?.available, false) + XCTAssertEqual(capabilities[.directChat]?.available, false) + XCTAssertEqual(capabilities[.directEmbeddings]?.available, false) + XCTAssertEqual(capabilities[.optionalCloudSTT]?.available, false) + XCTAssertEqual(capabilities[.optionalCloudChat]?.available, false) XCTAssertNotNil(capabilities[.managedAgentVM]?.reason) + XCTAssertNotNil(capabilities[.directSTT]?.reason) + } + + func testLocalDaemonDirectSTTCanBeOptedInViaEnvRegardlessOfEngine() { + setenv("OMI_HYBRID_DIRECT_STT_ENABLED", "1", 1) + defer { + unsetenv("OMI_HYBRID_DIRECT_STT_ENABLED") + } + XCTAssertTrue(DesktopBackendEnvironment.isCapability(.directSTT, availableIn: .localDaemon)) + } + + func testLocalDaemonDirectChatCanBeOptedInViaEnv() { + setenv("OMI_HYBRID_DIRECT_CHAT_ENABLED", "1", 1) + defer { + unsetenv("OMI_HYBRID_DIRECT_CHAT_ENABLED") + } + 
XCTAssertTrue(DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon)) } func testCloudCapabilityMatrixAllowsCloudBoundFeatures() { @@ -428,6 +454,9 @@ final class APIClientRoutingTests: XCTestCase { unsetenv("OMI_DESKTOP_BACKEND_MODE") unsetenv("OMI_LOCAL_DAEMON_URL") unsetenv("OMI_REWIND_DATABASE_ROOT") + unsetenv("OMI_HYBRID_DIRECT_STT_ENABLED") + unsetenv("OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED") + unsetenv("OMI_HYBRID_DIRECT_CHAT_ENABLED") } override func tearDown() { @@ -437,6 +466,9 @@ final class APIClientRoutingTests: XCTestCase { unsetenv("OMI_DESKTOP_BACKEND_MODE") unsetenv("OMI_LOCAL_DAEMON_URL") unsetenv("OMI_REWIND_DATABASE_ROOT") + unsetenv("OMI_HYBRID_DIRECT_STT_ENABLED") + unsetenv("OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED") + unsetenv("OMI_HYBRID_DIRECT_CHAT_ENABLED") URLCapture.reset() super.tearDown() } @@ -754,74 +786,105 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertEqual(body?["starred"] as? Bool, true) } - func testLocalModeMergeAndFolderActionsFailBeforeNetworkRequests() async { + func testLocalModeMergeRoutesToLocalDaemonWithoutAuth() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) - setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:9765", 1) let client = await makeTestClient() - do { - _ = try await client.mergeConversations(ids: ["c1", "c2"]) - XCTFail("expected merge to be unavailable") - } catch { - guard case APIError.featureUnavailable(let feature, _) = error else { - XCTFail("expected featureUnavailable for merge, got \(error)") - return - } - XCTAssertEqual(feature, "conversation_merge") - } + _ = try? 
await client.mergeConversations(ids: ["conv-1", "conv-2"], reprocess: false) - do { - _ = try await client.getFolders() - XCTFail("expected folders to be unavailable") - } catch { - guard case APIError.featureUnavailable(let feature, _) = error else { - XCTFail("expected featureUnavailable for folders, got \(error)") - return - } - XCTAssertEqual(feature, "conversation_folders") - } + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 9765, + pathContains: "v1/conversations/merge", method: "POST", + label: "local mergeConversations") + XCTAssertNil(requests.first?.headers["Authorization"]) - do { - _ = try await client.createFolder(name: "Work") - XCTFail("expected folder creation to be unavailable") - } catch { - guard case APIError.featureUnavailable = error else { - XCTFail("expected featureUnavailable for folder creation, got \(error)") - return - } + let body = requests.first?.body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? [String: Any] } + XCTAssertEqual(body?["conversation_ids"] as? [String], ["conv-1", "conv-2"]) + XCTAssertEqual(body?["reprocess"] as? Bool, false) + } - do { - _ = try await client.updateFolder(id: "f1", name: "Renamed") - XCTFail("expected folder update to be unavailable") - } catch { - guard case APIError.featureUnavailable = error else { - XCTFail("expected featureUnavailable for folder update, got \(error)") - return - } - } + func testLocalModeGetFoldersRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8877", 1) + let client = await makeTestClient() - do { - try await client.deleteFolder(id: "f1") - XCTFail("expected folder deletion to be unavailable") - } catch { - guard case APIError.featureUnavailable = error else { - XCTFail("expected featureUnavailable for folder deletion, got \(error)") - return - } - } + _ = try? 
await client.getFolders() - do { - try await client.moveConversationToFolder(conversationId: "c1", folderId: "f1") - XCTFail("expected move-to-folder to be unavailable") - } catch { - guard case APIError.featureUnavailable = error else { - XCTFail("expected featureUnavailable for move-to-folder, got \(error)") - return - } - } + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8877, + pathContains: "v1/conversation-folders", method: "GET", + label: "local getFolders") + XCTAssertNil(requests.first?.headers["Authorization"]) + } - XCTAssertTrue(URLCapture.capturedRequests.isEmpty) + func testLocalModeCreateFolderRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8878", 1) + let client = await makeTestClient() + + _ = try? await client.createFolder(name: "Work", description: "Notes") + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8878, + pathContains: "v1/conversation-folders", method: "POST", + label: "local createFolder") + XCTAssertNil(requests.first?.headers["Authorization"]) + } + + func testLocalModeUpdateFolderRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8879", 1) + let client = await makeTestClient() + + _ = try? await client.updateFolder(id: "fld-local-1", name: "Renamed") + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8879, + pathContains: "v1/conversation-folders/fld-local-1", method: "PATCH", + label: "local updateFolder") + XCTAssertNil(requests.first?.headers["Authorization"]) + } + + func testLocalModeDeleteFolderRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8880", 1) + let client = await makeTestClient() + + try? 
await client.deleteFolder(id: "fld-del", moveToFolderId: "fld-other") + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8880, + pathContains: "v1/conversation-folders/fld-del", method: "DELETE", + label: "local deleteFolder") + XCTAssertNil(requests.first?.headers["Authorization"]) + XCTAssertTrue(requests.first?.url.absoluteString.contains("move_to_folder_id=fld-other") ?? false) + } + + func testLocalModeMoveConversationToFolderPatchesLocalConversationWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8881", 1) + let client = await makeTestClient() + + try? await client.moveConversationToFolder(conversationId: "c-local", folderId: "f-local") + + let requests = URLCapture.capturedRequests + assertRoutes( + requests, host: "127.0.0.1", port: 8881, + pathContains: "v1/conversations/c-local", method: "PATCH", + label: "local moveConversationToFolder") + XCTAssertNil(requests.first?.headers["Authorization"]) + let body = requests.first?.body.flatMap { + try? JSONSerialization.jsonObject(with: $0) as? [String: Any] + } + XCTAssertEqual(body?["folder_id"] as? String, "f-local") } func testLocalUnauthenticatedRequestDoesNotRefreshAuthOn401() async { @@ -1001,6 +1064,49 @@ final class APIClientRoutingTests: XCTestCase { label: "getFolders") } + func testLocalModeConversationFoldersRouteToLocalDaemon() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? await client.getFolders() as [Folder] + try? await client.createFolder(name: "Work", description: nil, color: "#111111") + _ = try? await client.updateFolder( + id: "f1", name: "Renamed", description: nil, color: nil, order: 1) + try? 
await client.deleteFolder(id: "f1", moveToFolderId: "f2") + + let requests = URLCapture.capturedRequests + XCTAssertEqual(requests.count, 4) + XCTAssertTrue(requests.allSatisfy { $0.url.host == "127.0.0.1" && $0.url.port == 8765 }) + XCTAssertTrue(requests.allSatisfy { $0.headers["Authorization"] == nil }) + XCTAssertEqual(requests.map(\.method), ["GET", "POST", "PATCH", "DELETE"]) + XCTAssertTrue(requests[0].url.path.contains("/v1/conversation-folders")) + XCTAssertEqual(requests[0].method, "GET") + XCTAssertTrue(requests[1].url.path.contains("/v1/conversation-folders")) + XCTAssertEqual(requests[1].method, "POST") + XCTAssertTrue(requests[2].url.path.contains("/v1/conversation-folders/f1")) + XCTAssertEqual(requests[2].method, "PATCH") + XCTAssertTrue(requests[3].url.path.contains("/v1/conversation-folders/f1")) + XCTAssertEqual(requests[3].method, "DELETE") + XCTAssertTrue(requests[3].url.query?.contains("move_to_folder_id=f2") == true) + } + + func testLocalModeMergeConversationsRoutesToLocalDaemon() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + let client = await makeTestClient() + + _ = try? 
await client.mergeConversations(ids: ["a", "b"], reprocess: false) + + XCTAssertEqual(URLCapture.capturedRequests.count, 1) + let request = URLCapture.capturedRequests.first + XCTAssertEqual(request?.url.host, "127.0.0.1") + XCTAssertEqual(request?.url.port, 8765) + XCTAssertEqual(request?.method, "POST") + XCTAssertTrue(request?.url.path.contains("/v1/conversations/merge") == true) + XCTAssertEqual(request?.headers["Authorization"], nil) + } + // -- Memories (POST → Python) -- func testCreateMemoryRoutesToPython() async { @@ -1453,6 +1559,70 @@ final class APIClientRoutingTests: XCTestCase { label: "deleteChatSession") } + func testLocalModeGetChatSessionsRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? await client.getChatSessions() as [ChatSession] + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v2/chat-sessions", method: "GET", + label: "local getChatSessions") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testLocalModeCreateChatSessionRoutesToLocalDaemon() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? 
await client.createChatSession(title: "local-title") as ChatSession + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v2/chat-sessions", method: "POST", + label: "local createChatSession") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + + func testLocalModeSaveMessageUsesDefaultDaemonSessionWhenSessionIdNil() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? await client.saveMessage(text: "hello", sender: "human", sessionId: nil) + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v2/chat-sessions/\(APIClient.localDaemonDefaultChatSessionId)/messages", + method: "POST", + label: "local saveMessage default session") + } + + func testLocalModeGetMessagesForSessionRoutesToLocalDaemon() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? 
await client.getMessages(sessionId: "sess-local", limit: 10, offset: 2) + as [ChatMessageDB] + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v2/chat-sessions/sess-local/messages", method: "GET", + label: "local getMessages(sessionId:)") + } + // -- Desktop messages (DELETE → Python, path changed to v2/desktop/messages) -- func testDeleteMessagesRoutesToPython() async { diff --git a/desktop/Desktop/Tests/HybridEmbeddingClientTests.swift b/desktop/Desktop/Tests/HybridEmbeddingClientTests.swift new file mode 100644 index 00000000000..387b4de71fd --- /dev/null +++ b/desktop/Desktop/Tests/HybridEmbeddingClientTests.swift @@ -0,0 +1,39 @@ +import XCTest + +@testable import Omi_Computer + +final class HybridEmbeddingClientTests: XCTestCase { + func testLoadProviderConfigParsesOpenAICompatible() { + let settings = [ + LocalDaemonSetting( + key: "embedding_provider", + valueJson: """ + {"kind":"openai_compatible","base_url":"http://127.0.0.1:11434/v1","model":"nomic-embed-text","api_key":"k"} + """, + updatedAt: Date() + ) + ] + let config = HybridEmbeddingClient.loadProviderConfig(from: settings) + XCTAssertEqual(config?.baseURL, "http://127.0.0.1:11434/v1") + XCTAssertEqual(config?.model, "nomic-embed-text") + } + + func testCompatibilityRejectsMixedDimensions() { + XCTAssertFalse( + HybridEmbeddingClient.isCompatibleEmbedding( + storedModel: "nomic-embed-text", + storedDim: 768, + activeModel: "nomic-embed-text", + activeDim: 384 + ) + ) + XCTAssertTrue( + HybridEmbeddingClient.isCompatibleEmbedding( + storedModel: nil, + storedDim: nil, + activeModel: HybridEmbeddingClient.legacyGeminiModelId, + activeDim: HybridEmbeddingClient.legacyGeminiDimension + ) + ) + } +} diff --git a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift new file mode 100644 index 00000000000..fce16ae3b5d --- /dev/null +++ b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift @@ 
-0,0 +1,52 @@ +import XCTest + +@testable import Omi_Computer + +final class HybridLLMProviderConfigTests: XCTestCase { + + func testResolveEffectivePrefersChatProviderOverAiProvider() throws { + let payload = """ + [{"key":"chat_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://chat.local/v1\\",\\"model\\":\\"m-chat\\"}","updated_at":"2026-05-19T12:00:00Z"},{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ + let settings = try decodeSettings(payload) + let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) + XCTAssertEqual(config?.baseURL, "http://chat.local/v1") + XCTAssertEqual(config?.model, "m-chat") + } + + func testResolveEffectiveFallsBackToProviderAlias() throws { + let payload = """ + [{"key":"provider","value_json":"{\\"kind\\":\\"openai\\",\\"base_url\\":\\"http://legacy.local/v1\\",\\"model\\":\\"legacy\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ + let settings = try decodeSettings(payload) + let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) + XCTAssertEqual(config?.baseURL, "http://legacy.local/v1") + XCTAssertEqual(config?.model, "legacy") + } + + func testVisionProviderLoadsWhenConfigured() throws { + let payload = """ + [{"key":"vision_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://vision.local/v1\\",\\"model\\":\\"vlm\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ + let settings = try decodeSettings(payload) + let config = HybridLLMClient.loadVisionProviderConfig(from: settings) + XCTAssertEqual(config?.baseURL, "http://vision.local/v1") + XCTAssertEqual(config?.model, "vlm") + } + + func testHybridChatClientFallsBackToAiProvider() throws { + let payload = """ + 
[{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ + let settings = try decodeSettings(payload) + let config = HybridChatClient.resolveEffectiveChatConfig(from: settings) + XCTAssertEqual(config?.baseURL, "http://ai.local/v1") + XCTAssertEqual(config?.model, "m-ai") + } + + private func decodeSettings(_ jsonArray: String) throws -> [LocalDaemonSetting] { + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + return try decoder.decode([LocalDaemonSetting].self, from: Data(jsonArray.utf8)) + } +} diff --git a/desktop/Desktop/Tests/HybridVisionProviderTests.swift b/desktop/Desktop/Tests/HybridVisionProviderTests.swift new file mode 100644 index 00000000000..ab25d0fabbb --- /dev/null +++ b/desktop/Desktop/Tests/HybridVisionProviderTests.swift @@ -0,0 +1,47 @@ +import XCTest + +@testable import Omi_Computer + +final class HybridVisionProviderTests: XCTestCase { + + func testIsConfiguredWhenVisionProviderOpenAICompatible() throws { + let settings = try Self.decodeSettings( + key: "vision_provider", + value: ["kind": "openai_compatible", "base_url": "https://api.example.com/v1"] + ) + XCTAssertTrue(HybridVisionProvider.isConfigured(settings: settings)) + } + + func testIsNotConfiguredWithoutVisionProvider() throws { + let settings = try Self.decodeSettings( + key: "embedding_provider", + value: ["kind": "openai_compatible", "base_url": "https://api.example.com/v1"] + ) + XCTAssertFalse(HybridVisionProvider.isConfigured(settings: settings)) + } + + func testIsNotConfiguredWhenBaseUrlMissing() throws { + let settings = try Self.decodeSettings( + key: "vision_provider", + value: ["kind": "openai_compatible"] + ) + XCTAssertFalse(HybridVisionProvider.isConfigured(settings: settings)) + } + + private static func decodeSettings(key: String, value: [String: Any]) throws -> [LocalDaemonSetting] { + let valueJsonData = try 
JSONSerialization.data(withJSONObject: value) + guard let valueJson = String(data: valueJsonData, encoding: .utf8) else { + struct EncodeError: Error {} + throw EncodeError() + } + let row: [String: Any] = [ + "key": key, + "value_json": valueJson, + "updated_at": "2026-05-19T12:00:00Z", + ] + let payload = try JSONSerialization.data(withJSONObject: [row]) + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + return try decoder.decode([LocalDaemonSetting].self, from: payload) + } +} diff --git a/desktop/Desktop/Tests/LocalSpeechTranscriptionLocaleTests.swift b/desktop/Desktop/Tests/LocalSpeechTranscriptionLocaleTests.swift new file mode 100644 index 00000000000..19b71e660b8 --- /dev/null +++ b/desktop/Desktop/Tests/LocalSpeechTranscriptionLocaleTests.swift @@ -0,0 +1,55 @@ +import XCTest + +@testable import Omi_Computer + +final class LocalSpeechTranscriptionLocaleTests: XCTestCase { + + func testNormalizedLocaleUnderscoresBecomeHyphens() { + XCTAssertEqual( + LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier(forAssistantLanguageCode: "en_US"), + "en-US") + } + + func testNormalizedLocaleLowercaseUnderscores() { + XCTAssertEqual( + LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier(forAssistantLanguageCode: "en_au"), + "en-au") + } + + func testNormalizedLocalePreservesExistingHyphens() { + XCTAssertEqual( + LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier(forAssistantLanguageCode: "en-US"), + "en-US") + } + + func testNormalizedLocaleZhAlias() { + XCTAssertEqual( + LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier(forAssistantLanguageCode: "zh"), + "zh-CN") + } + + func testBackendSegmentHybridInitializerMatchesDecodedShape() throws { + let json = """ + [{"id":"apple-hybrid-live","text":"hello","speaker":"SPEAKER_00","speaker_id":0,"is_user":true,"person_id":null,"start":0,"end":12.5,"translations":null}] + """ + let decoded = try JSONDecoder().decode( + [TranscriptionService.BackendSegment].self, from: 
XCTUnwrap(json.data(using: .utf8))) + XCTAssertEqual(decoded.count, 1) + + let built = TranscriptionService.BackendSegment( + id: LocalSpeechTranscriptionAdapter.pseudoBackendSegmentId, + text: "hello", + speaker: "SPEAKER_00", + speaker_id: 0, + is_user: true, + person_id: nil, + start: 0, + end: 12.5, + translations: nil + ) + XCTAssertEqual(built.text, decoded[0].text) + XCTAssertEqual(built.id, decoded[0].id) + XCTAssertEqual(built.speaker_id, decoded[0].speaker_id) + XCTAssertEqual(built.is_user, decoded[0].is_user) + } +} diff --git a/desktop/Desktop/Tests/LocalSpeechTranscriptionMappingTests.swift b/desktop/Desktop/Tests/LocalSpeechTranscriptionMappingTests.swift new file mode 100644 index 00000000000..bd7e168052d --- /dev/null +++ b/desktop/Desktop/Tests/LocalSpeechTranscriptionMappingTests.swift @@ -0,0 +1,38 @@ +import XCTest + +@testable import Omi_Computer + +final class LocalSpeechTranscriptionMappingTests: XCTestCase { + + func testHybridRollingSegmentsMatchTranscriptionStorageContract() { + let segments = LocalSpeechTranscriptionAdapter.makeHybridRollingSegments( + text: "hello world", + elapsedSeconds: 12.5 + ) + XCTAssertEqual(segments.count, 1) + let segment = segments[0] + XCTAssertEqual(segment.id, LocalSpeechTranscriptionAdapter.pseudoBackendSegmentId) + XCTAssertEqual(segment.text, "hello world") + XCTAssertEqual(segment.speaker, "SPEAKER_00") + XCTAssertEqual(segment.speaker_id, 0) + XCTAssertTrue(segment.is_user) + XCTAssertNil(segment.person_id) + XCTAssertEqual(segment.start, 0) + XCTAssertEqual(segment.end, 12.5) + XCTAssertNil(segment.translations) + } + + func testNormalizedLocaleIdentifierHandlesMultiAndTwoLetterCodes() { + let multi = LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier( + forAssistantLanguageCode: "multi") + XCTAssertFalse(multi.isEmpty) + + let uk = LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier( + forAssistantLanguageCode: "uk") + XCTAssertTrue(uk.lowercased().contains("uk")) + + let zh = 
LocalSpeechTranscriptionAdapter.normalizedLocaleIdentifier( + forAssistantLanguageCode: "zh") + XCTAssertTrue(zh.lowercased().contains("zh")) + } +} diff --git a/desktop/README.md b/desktop/README.md index 0317ee95127..00e768afa36 100644 --- a/desktop/README.md +++ b/desktop/README.md @@ -15,7 +15,8 @@ dmg-assets/ DMG installer resources ## Development -Requires macOS 14.0+, Rust toolchain, and code signing with an Apple Developer ID. +Requires macOS 14.0+, Rust toolchain, code signing with an Apple Developer ID, and +Homebrew `webp` for the Swift app (`brew install webp`). ```bash # Run (builds Swift app, starts Rust backend, launches app) diff --git a/desktop/local-backend/docs/architecture.md b/desktop/local-backend/docs/architecture.md index 6e705d4701e..d6239cd4934 100644 --- a/desktop/local-backend/docs/architecture.md +++ b/desktop/local-backend/docs/architecture.md @@ -130,14 +130,16 @@ credentials. Cloud mode routing basics are covered by `APIClientRoutingTests`, which verifies default cloud URL selection, custom URL selection, and local daemon routing without auth. +## Rewind And The Local Daemon + +Rewind timeline, OCR, **GRDB FTS** over screenshot text, and on-disk vector search all use the desktop app’s existing local **GRDB/SQLite** database and files. They do not depend on the local daemon process. The daemon’s SQLite store is focused on conversations, transcripts, and hybrid provider settings; **migrating Rewind’s indices or media into the daemon database is deferred** until a sync and storage story is defined, so hybrid mode keeps Rewind fully local while MVP data still flows through the loopback API. + ## Known Limitations And Follow-Up Work - The desktop app currently has a documented dev launch contract for the daemon; production supervision/packaging is not implemented. - Hosted transcription is intentionally unavailable in local daemon mode. The MVP validates transcript import/append/finalize, not direct local STT parity. 
-- Existing desktop GRDB/Rewind stores are not migrated into the local daemon - database yet. - Local provider configuration exists at the daemon API/settings layer, but the user-facing settings workflow is still thin. - Cloud sync remains disabled until a dedicated optional sync adapter is diff --git a/desktop/local-backend/docs/hybrid-embedding-versioning.md b/desktop/local-backend/docs/hybrid-embedding-versioning.md new file mode 100644 index 00000000000..cdb382ce0aa --- /dev/null +++ b/desktop/local-backend/docs/hybrid-embedding-versioning.md @@ -0,0 +1,37 @@ +# Hybrid embedding versioning (ADR) + +## Problem + +Cloud desktop uses **3072-dimensional** `gemini-embedding-001` vectors stored in GRDB +(`staged_tasks.embedding`, `screenshots.embedding`). Hybrid mode cannot use the Omi +Gemini proxy (`EmbeddingService.proxyBaseURL` is empty in local daemon mode). + +Switching embedders changes vector dimension and semantic space. Mixed indexes produce +garbage similarity scores. + +## Schema (desktop GRDB) + +After migration `hybridEmbeddingMetadata`: + +| Table | Columns | +|-------|---------| +| `screenshots` | `embedding_model TEXT`, `embedding_dim INTEGER` | +| `staged_tasks` | `embedding_model TEXT`, `embedding_dim INTEGER` | + +Null `embedding_model` means legacy Gemini 3072-d (cloud-era rows). + +## Rules + +1. **Never search** across rows with different `(embedding_model, embedding_dim)`. +2. On embedder change, set `embedding = NULL` and reset backfill flags for affected tables. +3. `HybridEmbeddingClient` records model id + dimension on each write. +4. Default hybrid embedder: OpenAI-compatible `/embeddings` from `embedding_provider` in daemon settings (see [hybrid-provider-settings.md](hybrid-provider-settings.md)). + +## Enabling hybrid embeddings + +Desktop: `OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1` and configure `embedding_provider` on the daemon. + +## Backfill + +- `screenshot_embedding_backfill` migration_status row controls Rewind OCR embeddings. 
+- Task staged-task backfill: `TaskAssistant` / `EmbeddingService.backfillIfNeeded` after provider change. diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md new file mode 100644 index 00000000000..551bee47c52 --- /dev/null +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -0,0 +1,74 @@ +# Hybrid provider settings (ADR) + +## Context + +Desktop hybrid mode (`OMI_DESKTOP_BACKEND_MODE=local`) stores provider credentials in the +local daemon SQLite `settings` table. Requests go **directly** to configured endpoints, +never through Omi Python/Rust proxies. + +## Settings keys + +| Key | Purpose | `kind` values (v1) | +|-----|---------|-------------------| +| `ai_provider` | Post-transcript processing (title/overview JSON) | `openai_compatible` | +| `provider` | Legacy alias for `ai_provider` | same | +| `stt_provider` | Live speech-to-text (epic 02+) | `openai_compatible`, `deepgram_direct` (reserved) | +| `chat_provider` | Chat / agent completions (epic 04+) | `openai_compatible`, `anthropic_direct` (reserved) | +| `embedding_provider` | Vector embeddings (epic 01+) | `openai_compatible`, `gemini_direct` (reserved) | +| `vision_provider` | Multimodal / screenshot models (epic 05+) | `openai_compatible`, `gemini_direct` (reserved) | + +Set a key to JSON `null` to clear it. + +## OpenAI-compatible object shape + +```json +{ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "model": "local-model", + "api_key": "optional-for-ollama" +} +``` + +## Host policy + +`base_url` must be `http` or `https`. Hosts matching Omi cloud, Firebase, and Google +identity/Firestore endpoints are **denied** (see `is_denied_provider_host` in +`src/providers.rs`). + +Loopback and direct vendor APIs (OpenAI, Anthropic, Deepgram, etc.) are allowed. 
+ +## Optional cloud tiers (desktop only) + +Environment flags (desktop process, not daemon): + +- `OMI_HYBRID_DIRECT_STT_ENABLED=1` — enables hybrid live transcription via Apple Speech in local daemon mode (also enabled by default when `desktop/run.sh` configures local mode; launcher writes this into bundled `.env` so GUI launches see it). +- `OMI_HYBRID_DIRECT_CHAT_ENABLED=1` — enables hybrid OpenAI-compatible chat (`HybridChatClient`) when combined with `chat_provider`; sessions/messages persist via daemon SQLite (`run.sh` defaults this on in local mode). +- `OMI_HYBRID_OPTIONAL_CLOUD_STT=1` — exposes `optionalCloudSTT` capability +- `OMI_HYBRID_OPTIONAL_CLOUD_CHAT=1` — exposes `optionalCloudChat` capability + +Default hybrid optional tiers: both cloud toggles off. `run.sh` local mode defaults direct STT/embeddings/chat capability env flags on for GUI launches; hosted Listen and pi-mono remain disabled without explicit optional-cloud flags / cloud backends. + +## Local dev defaults (seed) + +When the daemon starts via `make serve-local` or `desktop/run.sh` in local mode, +`desktop/local-backend/tools/seed_hybrid_defaults.sh` runs idempotently: + +- If `ai_provider` / `provider` is unset → sets OpenAI-compatible defaults. +- If `chat_provider` is unset → sets the same defaults. + +| Variable | Default | +|----------|---------| +| `OMI_HYBRID_DEFAULT_CHAT_BASE_URL` | `http://127.0.0.1:11434/v1` | +| `OMI_HYBRID_DEFAULT_CHAT_MODEL` | `llama3.2` | + +The desktop app also calls `HybridProviderBootstrap.ensureDefaultsIfNeeded()` on +local guest session startup. Chat resolves `chat_provider` → `ai_provider` → BYOK OpenAI +(see `HybridChatClient`). + +Configure or override in **Settings → Plan and Usage** (local mode) or via `PUT /v1/settings`. + +## Test connection + +`POST /v1/settings/test-provider` with body `{ "key": "ai_provider" }` runs a minimal +request against the configured provider (chat completions ping for `openai_compatible`). 
diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 2bf1d92f569..5dab8d0d8b0 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -10,6 +10,8 @@ conversation storage, transcript ingestion, processing fallback, and search. - Rust toolchain with `cargo`. - Python 3 for the import helper. - `curl` for health and API checks. +- Homebrew **`webp`** for the desktop Swift build (`CWebP` / screen capture): + `brew install webp` (verify with `pkg-config --exists libwebp`). - For desktop app testing: the normal desktop development prerequisites from `desktop/README.md` and `desktop/run.sh`. @@ -33,8 +35,49 @@ Omi-hosted backend requests. It prints a concise pass/fail summary at the end. ## Primary Desktop Launch -For user-test runs, use one command from the repo root. This launches the -development app bundle only (`Omi Dev.app` / `com.omi.desktop-dev`), checks the +### One command from repo root (recommended) + +`make serve-local` starts the local daemon and desktop dev app in a tmux session +(`omi-hybrid-local`): top pane runs `cargo run` in `desktop/local-backend`, bottom +pane runs `desktop/run.sh` in hybrid local mode. Teardown: `make down-local`. + +```bash +make serve-local # attach or switch tmux client to omi-hybrid-local +make down-local # stop tmux session, daemon, and Omi Dev.app +``` + +If you are already inside another tmux session, `make serve-local` switches your +client to `omi-hybrid-local` (it does not nest sessions). To start without +attaching: + +```bash +OMI_HYBRID_LOCAL_ATTACH=0 make serve-local +tmux attach -t omi-hybrid-local +``` + +The first desktop build can take several minutes while SwiftPM resolves +packages; `run.sh` waits if another `swift-build` is already running on the +machine. 
+ +**Troubleshooting `make serve-local`:** + +- You still see a different tmux session (for example `local-supergemma`): run + `tmux attach -t omi-hybrid-local` or `tmux switch-client -t omi-hybrid-local`. +- Bottom pane stuck on `Waiting for other SwiftPM instance`: another + `swift-build` is running (often Firebase/GRDB package resolve). Wait for it to + finish or stop that build, then re-run `make serve-local`. +- Daemon port in use after a crash: `make down-local`, then `make serve-local`. +- Verify daemon only: `curl http://127.0.0.1:8765/health` should return + `"service":"omi-local-backend"`. +- Hybrid providers: `make serve-local` and `desktop/run.sh` (local mode) run + `desktop/local-backend/tools/seed_hybrid_defaults.sh` when the daemon is healthy, + seeding `ai_provider` and `chat_provider` to `http://127.0.0.1:11434/v1` (Ollama) if unset. + Override with `OMI_HYBRID_DEFAULT_CHAT_BASE_URL` and `OMI_HYBRID_DEFAULT_CHAT_MODEL`. + +### Manual `run.sh` launch + +For user-test runs without tmux, use one command from the repo root. This launches +the development app bundle only (`Omi Dev.app` / `com.omi.desktop-dev`), checks the local daemon health endpoint, starts `desktop/local-backend` if needed, and keeps Omi-hosted backend URLs deliberately invalid so accidental cloud routing is obvious: @@ -57,6 +100,9 @@ Required environment: Recommended test-boundary environment: +- Live transcription in local daemon mode uses on-device Apple Speech when hybrid direct STT is enabled. + `./run.sh` injects `OMI_HYBRID_DIRECT_STT_ENABLED=1` into the bundled app `.env` for local daemon mode by default (and sets the same in the launcher environment). Disable with `OMI_HYBRID_DIRECT_STT_ENABLED=0` if you need to turn it off. + - `OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765` makes the daemon URL explicit. - `OMI_PYTHON_API_URL=http://omi-cloud-invalid:9001` makes accidental Python backend calls fail locally. 
@@ -117,6 +163,18 @@ OMI_DESKTOP_API_URL=http://omi-rust-invalid:9002 \ ## Confirm Local Mode In The App +Hybrid local mode does **not** require Apple/Google sign-in for daily use. The app +enters an on-device guest session automatically (`local-hybrid-guest`). Cloud +sign-in remains optional for account UI only; OAuth uses the Python backend and +will not work when `OMI_PYTHON_API_URL` is set to an invalid host (intentional +for hybrid testing). + +**Settings → Plan and Usage** shows a **Local** plan (not cloud Free/Neo tiers). +Use that section to configure hybrid providers (`ai_provider`, `chat_provider`, +`embedding_provider`). Keys are stored in the local daemon SQLite database on this Mac. +Cloud subscription, usage quotas, and the Advanced “BYOK free forever” flow are hidden +in local mode. + The Conversations header shows a `Local` chip when the app is using local daemon mode. Settings → About also includes a Backend Mode card with the selected daemon URL, auth requirement, `/health` result, data directory, and whether @@ -223,6 +281,10 @@ title, overview, or transcript text. ## Local Provider Configuration +See [hybrid-provider-settings.md](hybrid-provider-settings.md) for the full settings schema +(`ai_provider`, `stt_provider`, `chat_provider`, `embedding_provider`, `vision_provider`) and +`POST /v1/settings/test-provider`. + Processing works without provider keys by using deterministic fallback. To force that path, clear provider settings: @@ -303,3 +365,21 @@ directly to the configured provider, not to Omi-hosted backend services. daemon port. - Desktop launch/auth callback issues in custom test builds: keep the app name and bundle suffix aligned as described in the repo desktop agent rules. +- `swift build` fails with `cannot change to .../firebase-ios-sdk-...: No such file + or directory`: Swift Package Manager has a missing or partial git mirror for + `firebase-ios-sdk` (or `GRDB.swift`) in + `~/Library/Caches/org.swift.swiftpm/repositories/`. 
This is unrelated to local + daemon mode — the desktop app still resolves Firebase for auth even in local + mode. Stop competing builds, clear the broken mirrors, pre-resolve, then + re-run `./run.sh`: + + ```bash + pkill -f 'swift-build|swift-package' 2>/dev/null || true + rm -rf ~/Library/Caches/org.swift.swiftpm/repositories/firebase-ios-sdk-* + rm -rf ~/Library/Caches/org.swift.swiftpm/repositories/GRDB.swift-* + cd desktop + xcrun swift package resolve --package-path Desktop + ``` + + The first resolve can take several minutes. Do not interrupt it; a partial clone + leaves SPM pointing at a cache path that does not exist yet. diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 71de6206bc7..1db6eeac00f 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -707,6 +707,206 @@ mod tests { Ok(()) } + #[tokio::test] + async fn hybrid_v2_chat_sessions_and_messages_routes() -> Result<()> { + let app = test_app()?; + + let empty_list = request_json(app.clone(), Method::GET, "/v2/chat-sessions", None).await?; + assert!( + empty_list.as_array().map(|a| a.is_empty()).unwrap_or(false), + "expected no user sessions yet" + ); + + let created = request_json( + app.clone(), + Method::POST, + "/v2/chat-sessions", + Some(json!({"title": "Route test"})), + ) + .await?; + let sid = created["id"].as_str().expect("session id"); + + let listed = request_json(app.clone(), Method::GET, "/v2/chat-sessions", None).await?; + assert_eq!(listed.as_array().expect("sessions").len(), 1); + + let one = request_json( + app.clone(), + Method::GET, + &format!("/v2/chat-sessions/{sid}"), + None, + ) + .await?; + assert_eq!(one["title"], "Route test"); + + request_json( + app.clone(), + Method::PATCH, + &format!("/v2/chat-sessions/{sid}"), + Some(json!({"starred": true})), + ) + .await?; + + let saved = request_json( + app.clone(), + Method::POST, + &format!("/v2/chat-sessions/{sid}/messages"), + Some(json!({"text": "hello 
daemon", "sender": "human"})), + ) + .await?; + assert!(saved["id"].as_str().is_some()); + + let msgs = request_json( + app.clone(), + Method::GET, + &format!("/v2/chat-sessions/{sid}/messages"), + None, + ) + .await?; + assert_eq!(msgs.as_array().expect("messages").len(), 1); + + let default_sid = "00000000-0000-4000-8000-000000000001"; + request_json( + app.clone(), + Method::POST, + &format!("/v2/chat-sessions/{default_sid}/messages"), + Some(json!({"text": "default thread", "sender": "human"})), + ) + .await?; + + let default_msgs = request_json( + app.clone(), + Method::GET, + &format!("/v2/chat-sessions/{default_sid}/messages"), + None, + ) + .await?; + assert_eq!(default_msgs.as_array().expect("default msgs").len(), 1); + + request_status( + app.clone(), + Method::DELETE, + &format!("/v2/chat-sessions/{sid}"), + None, + StatusCode::NO_CONTENT, + ) + .await?; + + Ok(()) + } + + #[tokio::test] + async fn conversation_folders_merge_and_folder_assignment_routes() -> Result<()> { + let app = test_app()?; + + let folder = request_json( + app.clone(), + Method::POST, + "/v1/conversation-folders", + Some(json!({ + "name": "Desk", + "description": "papers", + "color": "#111111", + })), + ) + .await?; + let folder_id = folder["folder"]["id"].as_str().expect("folder id"); + + let folders = + request_json(app.clone(), Method::GET, "/v1/conversation-folders", None).await?; + assert_eq!(folders["folders"].as_array().expect("folders array").len(), 1); + + let conv_a = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "session_id": "s-merge-a", + "title": "A", + "overview": "", + })), + ) + .await?; + let id_a = conv_a["conversation"]["id"].as_str().unwrap(); + + let conv_b = request_json( + app.clone(), + Method::POST, + "/v1/conversations", + Some(json!({ + "session_id": "s-merge-b", + "title": "B", + "overview": "", + })), + ) + .await?; + let id_b = conv_b["conversation"]["id"].as_str().unwrap(); + + request_json( + app.clone(), + 
Method::POST, + &format!("/v1/conversations/{id_a}/transcript-segments"), + Some(json!({"text": "one", "start_ms": 0, "end_ms": 50})), + ) + .await?; + request_json( + app.clone(), + Method::POST, + &format!("/v1/conversations/{id_b}/transcript-segments"), + Some(json!({"text": "two", "start_ms": 0, "end_ms": 60})), + ) + .await?; + + let merge_resp = request_json( + app.clone(), + Method::POST, + "/v1/conversations/merge", + Some(json!({ + "conversation_ids": [id_a, id_b], + "reprocess": false, + })), + ) + .await?; + assert_eq!(merge_resp["status"], "completed"); + let merged_id = merge_resp["new_conversation_id"] + .as_str() + .expect("merged id") + .to_string(); + + let folder_update = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/conversation-folders/{folder_id}"), + Some(json!({"name": "Desk2"})), + ) + .await?; + assert_eq!(folder_update["folder"]["name"], "Desk2"); + + let assigned = request_json( + app.clone(), + Method::PATCH, + &format!("/v1/conversations/{merged_id}"), + Some(json!({"folder_id": folder_id})), + ) + .await?; + assert_eq!(assigned["conversation"]["folder_id"], folder_id); + + request_status( + app.clone(), + Method::DELETE, + &format!("/v1/conversation-folders/{folder_id}"), + None, + StatusCode::NO_CONTENT, + ) + .await?; + + let unfiled = + request_json(app.clone(), Method::GET, &format!("/v1/conversations/{merged_id}"), None) + .await?; + assert!(unfiled["conversation"]["folder_id"].is_null()); + + Ok(()) + } + fn test_app() -> Result { let config = Config { bind_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), diff --git a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs index a384b3b1abe..a1141d7501c 100644 --- a/desktop/local-backend/src/processing.rs +++ b/desktop/local-backend/src/processing.rs @@ -232,6 +232,7 @@ fn persist_processing_output( ended_at: None, metadata: None, starred: None, + folder_id: None, }, )? 
.ok_or_else(|| anyhow!("conversation missing while persisting processing output"))?; diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 5ebc2a09047..2b90937f5da 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -107,7 +107,7 @@ pub fn configured_openai_provider(store: &Store) -> Result Result> { - for key in ["ai_provider", "provider"] { + for key in ["ai_provider", "provider", "chat_provider"] { let Some(setting) = store.settings().get(key)? else { continue; }; @@ -144,7 +144,20 @@ pub fn load_openai_config(store: &Store) -> Result Result<()> { + if value.is_null() { + return Ok(()); + } let kind = value["kind"].as_str().unwrap_or_default(); if kind != "openai" && kind != "openai_compatible" { return Ok(()); @@ -156,6 +169,34 @@ pub fn validate_provider_setting(value: &Value) -> Result<()> { validate_provider_base_url(base_url) } +pub fn validate_hybrid_provider_setting(key: &str, value: &Value) -> Result<()> { + if value.is_null() { + return Ok(()); + } + if !HYBRID_PROVIDER_SETTING_KEYS.contains(&key) { + return Ok(()); + } + validate_provider_setting(value) +} + +pub fn is_provider_configured(value: &Value) -> bool { + if value.is_null() { + return false; + } + let kind = value["kind"].as_str().unwrap_or_default(); + if kind != "openai" && kind != "openai_compatible" { + return false; + } + let api_key = value["api_key"] + .as_str() + .or_else(|| value["key"].as_str()) + .unwrap_or_default(); + !api_key.trim().is_empty() + || value["base_url"] + .as_str() + .is_some_and(|url| url.contains("127.0.0.1") || url.contains("localhost")) +} + fn validate_provider_base_url(base_url: &str) -> Result<()> { let url = reqwest::Url::parse(base_url) .with_context(|| format!("provider base_url is not a valid URL: {base_url}"))?; @@ -177,7 +218,60 @@ fn validate_provider_base_url(base_url: &str) -> Result<()> { Ok(()) } -fn is_denied_provider_host(host: &str) -> bool { +pub async 
fn test_configured_provider(store: &Store, key: &str) -> Result { + let Some(setting) = store.settings().get(key)? else { + return Err(anyhow!("setting {key} is not configured")); + }; + let value: Value = serde_json::from_str(&setting.value_json) + .with_context(|| format!("failed to parse {key}"))?; + if value.is_null() { + return Err(anyhow!("setting {key} is not configured")); + } + let kind = value["kind"].as_str().unwrap_or_default(); + if kind != "openai" && kind != "openai_compatible" { + return Err(anyhow!("test connection supports openai_compatible providers only")); + } + let provider = load_openai_config_from_value(&value)?; + let client = OpenAiCompatibleProvider::new(provider); + let _ = client + .complete_json(vec![ + ChatMessage::system("Reply with JSON only: {\"ok\":true}"), + ChatMessage::user("ping"), + ]) + .await?; + Ok(format!("{key} responded successfully")) +} + +fn load_openai_config_from_value(value: &Value) -> Result { + let kind = value["kind"].as_str().unwrap_or_default(); + if kind != "openai" && kind != "openai_compatible" { + return Err(anyhow!("unsupported provider kind: {kind}")); + } + let base_url = value["base_url"] + .as_str() + .unwrap_or("https://api.openai.com/v1") + .to_string(); + validate_provider_base_url(&base_url)?; + let model = value["model"].as_str().unwrap_or("gpt-4o-mini").to_string(); + let api_key = value["api_key"] + .as_str() + .or_else(|| value["key"].as_str()) + .unwrap_or_default() + .to_string(); + if api_key.trim().is_empty() + && !base_url.contains("127.0.0.1") + && !base_url.contains("localhost") + { + return Err(anyhow!("api_key is required for test connection")); + } + Ok(OpenAiCompatibleConfig { + base_url, + model, + api_key, + }) +} + +pub fn is_denied_provider_host(host: &str) -> bool { matches!(host, "api.omi.me" | "api.omiapi.com") || (host.starts_with("desktop-backend-") && host.ends_with(".a.run.app")) || host == "firebase.google.com" diff --git a/desktop/local-backend/src/routes.rs 
b/desktop/local-backend/src/routes.rs index 7904353f82a..b3e484ad79f 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -2,7 +2,7 @@ use axum::{ extract::{Path, Query, State}, http::StatusCode, response::{IntoResponse, Response}, - routing::{get, post}, + routing::{get, patch, post}, Json, Router, }; use chrono::{DateTime, Utc}; @@ -12,9 +12,9 @@ use serde_json::{json, Map, Value}; use crate::{ processing, providers, storage::{ - deterministic_id, AppendTranscriptResult, NewActionItem, NewConversation, NewMemory, - NewProcessingJob, NewTranscriptSegment, UpdateActionItem, UpdateConversation, UpdateMemory, - UpdateProfile, + deterministic_id, AppendTranscriptResult, ChatMessageDto, ChatSessionDto, NewActionItem, + NewConversation, NewFolder, NewMemory, NewProcessingJob, NewTranscriptSegment, + UpdateActionItem, UpdateConversation, UpdateFolder, UpdateMemory, UpdateProfile, }, AppState, }; @@ -25,11 +25,13 @@ pub fn router() -> Router { .route("/profile/status", get(profile_status)) .route("/v1/profile", get(get_profile).put(update_profile)) .route("/v1/settings", get(list_settings).put(update_settings)) + .route("/v1/settings/test-provider", post(test_provider)) .route( "/v1/conversations", get(list_conversations).post(create_conversation), ) .route("/v1/conversations/count", get(count_conversations)) + .route("/v1/conversations/merge", post(merge_conversations)) .route( "/v1/conversations/:id", get(get_conversation) @@ -64,7 +66,29 @@ pub fn router() -> Router { .route("/v1/processing-jobs", get(list_processing_jobs)) .route("/v1/processing-jobs/process-next", post(process_next_job)) .route("/v1/processing-jobs/status", get(processing_status)) + .route( + "/v1/conversation-folders", + get(list_conversation_folders).post(create_conversation_folder), + ) + .route( + "/v1/conversation-folders/:id", + patch(update_conversation_folder).delete(delete_conversation_folder), + ) .route("/v1/processing-jobs/:id", 
get(get_processing_job)) + .route( + "/v2/chat-sessions", + get(list_chat_sessions).post(create_chat_session), + ) + .route( + "/v2/chat-sessions/:id", + get(get_chat_session) + .patch(update_chat_session) + .delete(delete_chat_session), + ) + .route( + "/v2/chat-sessions/:session_id/messages", + get(list_chat_messages).post(create_chat_message), + ) } #[derive(Debug)] @@ -161,6 +185,7 @@ struct ListQuery { start_date: Option>, end_date: Option>, starred: Option, + folder_id: Option, } async fn list_conversations( @@ -176,6 +201,7 @@ async fn list_conversations( query.start_date, query.end_date, query.starred, + query.folder_id.as_deref(), ) .map_err(ApiError::internal)?; Ok(Json(json!({ "conversations": conversations }))) @@ -266,6 +292,8 @@ struct UpdateConversationRequest { ended_at: Option>, metadata: Option, starred: Option, + #[serde(default)] + folder_id: Option, } async fn update_conversation( @@ -273,6 +301,27 @@ async fn update_conversation( Path(id): Path, Json(request): Json, ) -> ApiResult { + let folder_id = match request.folder_id { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s)), + Some(_) => { + return Err(ApiError::bad_request( + "folder_id must be a string or null", + )); + } + }; + if let Some(Some(ref fid)) = folder_id { + if !state + .store + .folders() + .exists_active(fid) + .map_err(ApiError::internal)? + { + return Err(ApiError::bad_request("folder not found")); + } + } + let conversation = state .store .conversations() @@ -285,6 +334,7 @@ async fn update_conversation( ended_at: request.ended_at.map(Some), metadata: request.metadata, starred: request.starred, + folder_id, }, ) .map_err(ApiError::internal)? 
@@ -783,9 +833,9 @@ async fn update_settings( State(state): State, Json(values): Json>, ) -> ApiResult { - for key in ["ai_provider", "provider"] { - if let Some(value) = values.get(key) { - providers::validate_provider_setting(value) + for key in providers::HYBRID_PROVIDER_SETTING_KEYS { + if let Some(value) = values.get(*key) { + providers::validate_hybrid_provider_setting(key, value) .map_err(|error| ApiError::bad_request(error.to_string()))?; } } @@ -797,6 +847,31 @@ async fn update_settings( Ok(Json(json!({ "settings": settings }))) } +#[derive(Debug, Deserialize)] +struct TestProviderRequest { + key: String, +} + +async fn test_provider( + State(state): State, + Json(request): Json, +) -> ApiResult { + if !providers::HYBRID_PROVIDER_SETTING_KEYS.contains(&request.key.as_str()) { + return Err(ApiError::bad_request(format!( + "unsupported provider setting key: {}", + request.key + ))); + } + let message = providers::test_configured_provider(&state.store, &request.key) + .await + .map_err(|error| ApiError::bad_request(error.to_string()))?; + Ok(Json(json!({ + "ok": true, + "key": request.key, + "message": message + }))) +} + async fn list_processing_jobs(State(state): State) -> ApiResult { let jobs = state .store @@ -852,6 +927,303 @@ async fn processing_status(State(state): State) -> ApiResult { }))) } +#[derive(Deserialize)] +struct MergeConversationsBody { + #[serde(rename = "conversation_ids")] + conversation_ids: Vec, + #[serde(default = "merge_default_reprocess")] + reprocess: bool, +} + +fn merge_default_reprocess() -> bool { + true +} + +async fn merge_conversations( + State(state): State, + Json(body): Json, +) -> ApiResult { + if body.conversation_ids.len() < 2 { + return Err(ApiError::bad_request( + "at least two conversation_ids are required", + )); + } + + let merged = state + .store + .merge_conversations(&body.conversation_ids, body.reprocess) + .map_err(|e| ApiError::bad_request(e.to_string()))?; + + if body.reprocess { + if state + .store + 
.processing_jobs() + .reusable_for_conversation("finalize_transcript", &merged.id) + .map_err(ApiError::internal)? + .is_none() + { + state + .store + .processing_jobs() + .enqueue(NewProcessingJob { + id: local_id("job"), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(merged.id.clone()), + max_retries: Some(3), + payload: Some(json!({ "conversation_id": merged.id })), + }) + .map_err(ApiError::internal)?; + } + } + + Ok(Json(json!({ + "status": "completed", + "message": "Conversations merged locally", + "warning": serde_json::Value::Null, + "conversation_ids": body.conversation_ids, + "new_conversation_id": merged.id, + }))) +} + +#[derive(Deserialize)] +struct CreateConversationFolderBody { + name: String, + description: Option, + color: Option, +} + +async fn list_conversation_folders(State(state): State) -> ApiResult { + let folders = state.store.folders().list().map_err(ApiError::internal)?; + Ok(Json(json!({ "folders": folders }))) +} + +async fn create_conversation_folder( + State(state): State, + Json(body): Json, +) -> ApiResult { + let id = local_id("fld"); + let folder = state + .store + .folders() + .create(NewFolder { + id, + name: body.name, + description: body.description, + color: body.color, + }) + .map_err(ApiError::internal)?; + Ok(Json(json!({ "folder": folder }))) +} + +#[derive(Deserialize)] +struct UpdateConversationFolderBody { + name: Option, + description: Option, + color: Option, + order: Option, +} + +async fn update_conversation_folder( + State(state): State, + Path(id): Path, + Json(body): Json, +) -> ApiResult { + let folder = state + .store + .folders() + .update( + &id, + UpdateFolder { + name: body.name, + description: body.description, + color: body.color, + sort_order: body.order, + }, + ) + .map_err(ApiError::internal)? 
+ .ok_or_else(|| ApiError::not_found("folder"))?; + Ok(Json(json!({ "folder": folder }))) +} + +#[derive(Deserialize)] +struct DeleteConversationFolderQuery { + move_to_folder_id: Option, +} + +async fn delete_conversation_folder( + State(state): State, + Path(id): Path, + Query(query): Query, +) -> Result { + let deleted = state + .store + .folders() + .soft_delete(&id, query.move_to_folder_id.as_deref()) + .map_err(|e| ApiError::bad_request(e.to_string()))?; + if deleted { + Ok(StatusCode::NO_CONTENT) + } else { + Err(ApiError::not_found("folder")) + } +} + +#[derive(Deserialize)] +struct ChatSessionsListQuery { + limit: Option, + offset: Option, + app_id: Option, + starred: Option, +} + +async fn list_chat_sessions( + State(state): State, + Query(query): Query, +) -> ApiResult> { + let limit = query.limit.unwrap_or(50).clamp(1, 500); + let offset = query.offset.unwrap_or(0).max(0); + let sessions = state + .store + .chat_sessions() + .list_sessions( + limit, + offset, + query.app_id.as_deref(), + query.starred, + ) + .map_err(ApiError::internal)?; + Ok(Json(sessions)) +} + +#[derive(Deserialize)] +struct CreateChatSessionBody { + title: Option, + #[serde(rename = "app_id")] + app_id: Option, +} + +async fn create_chat_session( + State(state): State, + Json(body): Json, +) -> Result<(StatusCode, Json), ApiError> { + let session = state + .store + .chat_sessions() + .create_session(body.title.as_deref(), body.app_id.as_deref()) + .map_err(ApiError::internal)?; + Ok((StatusCode::CREATED, Json(session))) +} + +async fn get_chat_session( + State(state): State, + Path(id): Path, +) -> ApiResult { + let session = state + .store + .chat_sessions() + .get_session(&id) + .map_err(ApiError::internal)? 
+ .ok_or_else(|| ApiError::not_found("chat session"))?; + Ok(Json(session)) +} + +#[derive(Deserialize)] +struct PatchChatSessionBody { + title: Option, + starred: Option, +} + +async fn update_chat_session( + State(state): State, + Path(id): Path, + Json(body): Json, +) -> ApiResult { + let session = state + .store + .chat_sessions() + .update_session(&id, body.title.as_deref(), body.starred) + .map_err(ApiError::internal)? + .ok_or_else(|| ApiError::not_found("chat session"))?; + Ok(Json(session)) +} + +async fn delete_chat_session( + State(state): State, + Path(id): Path, +) -> Result { + match state.store.chat_sessions().delete_session(&id) { + Ok(true) => Ok(StatusCode::NO_CONTENT), + Ok(false) => Err(ApiError::not_found("chat session")), + Err(e) => Err(ApiError::bad_request(e.to_string())), + } +} + +#[derive(Deserialize)] +struct ChatMessagesListQuery { + limit: Option, + offset: Option, + #[serde(rename = "app_id")] + app_id: Option, +} + +async fn list_chat_messages( + State(state): State, + Path(session_id): Path, + Query(query): Query, +) -> ApiResult> { + let limit = query.limit.unwrap_or(100).clamp(1, 500); + let offset = query.offset.unwrap_or(0).max(0); + let messages = state + .store + .chat_sessions() + .list_messages(&session_id, query.app_id.as_deref(), limit, offset) + .map_err(ApiError::internal)?; + Ok(Json(messages)) +} + +#[derive(Deserialize)] +struct CreateChatMessageBody { + text: String, + sender: String, + #[serde(rename = "app_id")] + app_id: Option, + metadata: Option, +} + +#[derive(Serialize)] +struct SaveChatMessageResponse { + id: String, + #[serde(rename = "created_at")] + created_at: DateTime, +} + +async fn create_chat_message( + State(state): State, + Path(session_id): Path, + Json(body): Json, +) -> Result<(StatusCode, Json), ApiError> { + let (id, created_at) = state + .store + .chat_sessions() + .append_message( + &session_id, + &body.text, + &body.sender, + body.app_id.as_deref(), + body.metadata.as_deref(), + ) + 
.map_err(|e| { + if e.to_string().contains("not found") { + ApiError::not_found("chat session") + } else { + ApiError::internal(e) + } + })?; + Ok(( + StatusCode::CREATED, + Json(SaveChatMessageResponse { id, created_at }), + )) +} + fn limit_or_default(limit: Option) -> i64 { limit.unwrap_or(50).clamp(1, 200) } diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 1f5d2801085..f1b88044317 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -229,6 +229,78 @@ const MIGRATIONS: &[Migration] = &[ CREATE INDEX idx_conversations_starred ON conversations(starred, updated_at); "#, }, + Migration { + version: 3, + name: "conversation_folders", + sql: r#" + CREATE TABLE folders ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL DEFAULT '', + description TEXT, + color TEXT NOT NULL DEFAULT '#6B7280', + sort_order INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + deleted_at TEXT + ); + + CREATE INDEX idx_folders_updated ON folders(updated_at); + + ALTER TABLE conversations ADD COLUMN folder_id TEXT REFERENCES folders(id) ON DELETE SET NULL; + CREATE INDEX idx_conversations_folder_id ON conversations(folder_id); + "#, + }, + Migration { + version: 4, + name: "folder_metadata_columns", + sql: r#" + ALTER TABLE folders ADD COLUMN is_default INTEGER NOT NULL DEFAULT 0; + ALTER TABLE folders ADD COLUMN is_system INTEGER NOT NULL DEFAULT 0; + ALTER TABLE folders ADD COLUMN category_mapping TEXT; + "#, + }, + Migration { + version: 5, + name: "hybrid_chat_sessions", + sql: r#" + CREATE TABLE chat_sessions ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL DEFAULT 'New Chat', + preview TEXT, + app_id TEXT, + starred INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE INDEX idx_chat_sessions_updated ON chat_sessions(updated_at DESC); + + CREATE TABLE chat_messages ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES 
chat_sessions(id) ON DELETE CASCADE, + text TEXT NOT NULL DEFAULT '', + sender TEXT NOT NULL DEFAULT 'human', + app_id TEXT, + metadata TEXT, + created_at TEXT NOT NULL, + rating INTEGER, + reported INTEGER NOT NULL DEFAULT 0 + ); + + CREATE INDEX idx_chat_messages_session_created ON chat_messages(session_id, created_at); + + INSERT INTO chat_sessions (id, title, preview, app_id, starred, created_at, updated_at) + VALUES ( + '00000000-0000-4000-8000-000000000001', + 'Default Chat', + NULL, + NULL, + 0, + '2020-01-01T00:00:00Z', + '2020-01-01T00:00:00Z' + ); + "#, + }, ]; #[derive(Clone)] @@ -259,6 +331,50 @@ pub struct Conversation { pub sync_state: String, pub metadata_json: String, pub starred: bool, + pub folder_id: Option, +} + +/// Serialized folder shape consumed by desktop clients (Swift `Folder`). +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ConversationFolder { + pub id: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub color: String, + #[serde(rename = "created_at")] + pub created_at: DateTime, + #[serde(rename = "updated_at")] + pub updated_at: DateTime, + #[serde(rename = "order")] + pub sort_order: i64, + #[serde(rename = "is_default")] + pub is_default: bool, + #[serde(rename = "is_system")] + pub is_system: bool, + #[serde( + rename = "category_mapping", + skip_serializing_if = "Option::is_none" + )] + pub category_mapping: Option, + #[serde(rename = "conversation_count")] + pub conversation_count: i64, +} + +#[derive(Debug, Clone)] +pub struct NewFolder { + pub id: String, + pub name: String, + pub description: Option, + pub color: Option, +} + +#[derive(Debug, Clone, Default)] +pub struct UpdateFolder { + pub name: Option, + pub description: Option, + pub color: Option, + pub sort_order: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -451,6 +567,443 @@ impl Store { conn: Arc::clone(&self.conn), } } + + pub fn folders(&self) -> 
FolderRepository { + FolderRepository { + conn: Arc::clone(&self.conn), + } + } + + pub fn chat_sessions(&self) -> ChatSessionsRepository { + ChatSessionsRepository { + conn: Arc::clone(&self.conn), + } + } + + /// Merge two or more conversations (`source_ids`) into one new conversation. + pub fn merge_conversations(&self, source_ids: &[String], _reprocess: bool) -> Result { + if source_ids.len() < 2 { + anyhow::bail!("at least two conversation_ids required"); + } + let mut sorted: Vec = source_ids.to_vec(); + sorted.sort(); + let merge_key = sorted.join("|"); + let new_id = deterministic_id("merge", &[&merge_key]); + + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let tx = conn + .unchecked_transaction() + .context("failed to start merge transaction")?; + + let mut conversations: Vec = Vec::new(); + for id in &sorted { + let row = tx + .query_row( + r#" + SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred, + folder_id + FROM conversations + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_conversation, + ) + .optional() + .context("merge: load conversation")?; + let c = row.ok_or_else(|| anyhow::anyhow!("conversation not found: {id}"))?; + conversations.push(c); + } + + conversations.sort_by_key(|c| c.started_at); + let primary = &conversations[0]; + let title = format!("{} (merged)", primary.title); + + let new_row = Conversation { + id: new_id.clone(), + session_id: primary.session_id.clone(), + title, + overview: primary.overview.clone(), + status: "open".to_string(), + started_at: primary.started_at, + ended_at: None, + created_at: now, + updated_at: now, + deleted_at: None, + cloud_id: None, + sync_version: 0, + sync_state: "local".to_string(), + metadata_json: primary.metadata_json.clone(), + starred: false, + folder_id: primary.folder_id.clone(), + }; + + let dupe_exists: Option
= tx + .query_row( + "SELECT 1 FROM conversations WHERE id = ?1 AND deleted_at IS NULL", + params![&new_id], + |row| row.get(0), + ) + .optional() + .context("merge: check duplicate merged id")?; + if dupe_exists.is_some() { + anyhow::bail!( + "merged conversation already exists for this set; delete it before merging again" + ); + } + + tx.execute( + r#" + INSERT INTO conversations ( + id, session_id, title, overview, status, started_at, ended_at, created_at, + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred, + folder_id + ) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16) + "#, + params![ + new_row.id, + new_row.session_id, + new_row.title, + new_row.overview, + new_row.status, + new_row.started_at, + new_row.ended_at, + new_row.created_at, + new_row.updated_at, + new_row.deleted_at, + new_row.cloud_id, + new_row.sync_version, + new_row.sync_state, + new_row.metadata_json, + new_row.starred, + new_row.folder_id, + ], + ) + .context("merge: insert merged conversation")?; + + #[derive(Clone)] + #[allow(dead_code)] + struct SegmentRow { + id: String, + session_id: String, + speaker_id: Option, + speaker_label: Option, + text: String, + start_ms: i64, + end_ms: i64, + started_at: DateTime, + segment_index: i64, + source: String, + metadata_json: String, + } + + let mut ordered: Vec = Vec::new(); + for conv in &conversations { + let mut stmt = tx + .prepare( + r#" + SELECT id, session_id, speaker_id, speaker_label, text, start_ms, end_ms, segment_index, + source, metadata_json + FROM transcript_segments + WHERE conversation_id = ?1 AND deleted_at IS NULL + ORDER BY segment_index ASC + "#, + ) + .context("merge: prepare list segments")?; + let rows = stmt.query_map(params![conv.id], |row| { + Ok(SegmentRow { + id: row.get(0)?, + session_id: row.get(1)?, + speaker_id: row.get(2)?, + speaker_label: row.get(3)?, + text: row.get(4)?, + start_ms: row.get(5)?, + end_ms: row.get(6)?, + segment_index: 
row.get(7)?, + source: row.get(8)?, + started_at: conv.started_at, + metadata_json: row.get(9)?, + }) + })?; + for r in rows { + ordered.push(r.context("merge: read segment")?); + } + } + ordered.sort_by(|a, b| { + a.started_at + .cmp(&b.started_at) + .then(a.segment_index.cmp(&b.segment_index)) + }); + + for (idx, seg) in ordered.iter().enumerate() { + let idx_i = idx as i64; + tx.execute( + r#" + UPDATE transcript_segments + SET conversation_id = ?1, session_id = ?2, segment_index = ?3, updated_at = ?4, + sync_version = sync_version + 1 + WHERE id = ?5 AND deleted_at IS NULL + "#, + params![ + new_id, + seg.session_id, + idx_i, + now, + seg.id, + ], + ) + .context("merge: reattach transcript segment")?; + } + + for sid in &sorted { + tx.execute( + "UPDATE memories SET conversation_id = ?1, updated_at = ?2, sync_version = sync_version + 1 WHERE conversation_id = ?3 AND deleted_at IS NULL", + params![&new_id, now, sid], + ) + .context("merge: repoint memories")?; + tx.execute( + "UPDATE action_items SET conversation_id = ?1, updated_at = ?2, sync_version = sync_version + 1 WHERE conversation_id = ?3 AND deleted_at IS NULL", + params![&new_id, now, sid], + ) + .context("merge: repoint action items")?; + tx.execute( + "UPDATE local_files SET conversation_id = ?1, updated_at = ?2, sync_version = sync_version + 1 WHERE conversation_id = ?3 AND deleted_at IS NULL", + params![&new_id, now, sid], + ) + .context("merge: repoint local files")?; + } + + for sid in &sorted { + tx.execute( + "UPDATE conversations SET deleted_at = ?1, updated_at = ?1 WHERE id = ?2 AND deleted_at IS NULL", + params![now, sid], + ) + .context("merge: soft-delete source")?; + } + + tx.commit().context("merge: commit")?; + + drop(conn); + self.conversations() + .get(&new_id) + .context("merge: reload")? 
+ .ok_or_else(|| anyhow::anyhow!("merged conversation missing after commit")) + } +} + +pub struct FolderRepository { + conn: Arc>, +} + +impl FolderRepository { + pub fn next_sort_order(&self) -> Result { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + "SELECT COALESCE(MAX(sort_order), -1) + 1 FROM folders WHERE deleted_at IS NULL", + [], + |row| row.get(0), + ) + .context("folder sort order") + } + + pub fn exists_active(&self, id: &str) -> Result { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let count: i64 = conn.query_row( + "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", + params![id], + |row| row.get(0), + ) + .context("folder exists")?; + Ok(count > 0) + } + + pub fn list(&self) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let mut stmt = conn + .prepare( + r#" + SELECT f.id, f.name, f.description, f.color, f.created_at, f.updated_at, f.sort_order, + f.is_default, f.is_system, f.category_mapping, + (SELECT COUNT(*) FROM conversations c + WHERE c.folder_id = f.id AND c.deleted_at IS NULL) + FROM folders f + WHERE f.deleted_at IS NULL + ORDER BY f.sort_order ASC, f.name ASC + "#, + ) + .context("prepare list folders")?; + let rows = stmt.query_map([], |row| { + let is_default_i: i64 = row.get(7)?; + let is_system_i: i64 = row.get(8)?; + Ok(ConversationFolder { + id: row.get(0)?, + name: row.get(1)?, + description: row.get(2)?, + color: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + sort_order: row.get(6)?, + is_default: is_default_i != 0, + is_system: is_system_i != 0, + category_mapping: row.get(9)?, + conversation_count: row.get(10)?, + }) + })?; + collect_rows(rows) + } + + pub fn get(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT f.id, f.name, f.description, f.color, f.created_at, f.updated_at, 
f.sort_order, + f.is_default, f.is_system, f.category_mapping, + (SELECT COUNT(*) FROM conversations c + WHERE c.folder_id = f.id AND c.deleted_at IS NULL) + FROM folders f + WHERE f.id = ?1 AND f.deleted_at IS NULL + "#, + params![id], + |row| { + let is_default_i: i64 = row.get(7)?; + let is_system_i: i64 = row.get(8)?; + Ok(ConversationFolder { + id: row.get(0)?, + name: row.get(1)?, + description: row.get(2)?, + color: row.get(3)?, + created_at: row.get(4)?, + updated_at: row.get(5)?, + sort_order: row.get(6)?, + is_default: is_default_i != 0, + is_system: is_system_i != 0, + category_mapping: row.get(9)?, + conversation_count: row.get(10)?, + }) + }, + ) + .optional() + .context("get folder") + } + + pub fn create(&self, new: NewFolder) -> Result { + let now = Utc::now(); + let sort_order = self.next_sort_order()?; + let color = new.color.unwrap_or_else(|| "#6B7280".to_string()); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + INSERT INTO folders (id, name, description, color, sort_order, created_at, updated_at, deleted_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, NULL) + "#, + params![ + new.id, + new.name, + new.description, + color, + sort_order, + now, + now, + ], + ) + .context("insert folder")?; + drop(conn); + self.get(&new.id)?.ok_or_else(|| anyhow::anyhow!("folder missing after insert")) + } + + pub fn update(&self, id: &str, update: UpdateFolder) -> Result> { + let Some(mut row) = self.get(id)? 
else { + return Ok(None); + }; + if let Some(name) = update.name { + row.name = name; + } + if let Some(description) = update.description { + row.description = if description.is_empty() { + None + } else { + Some(description) + }; + } + if let Some(color) = update.color { + row.color = color; + } + if let Some(order) = update.sort_order { + row.sort_order = order; + } + row.updated_at = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.execute( + r#" + UPDATE folders + SET name = ?2, description = ?3, color = ?4, sort_order = ?5, updated_at = ?6 + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![ + id, + row.name, + row.description, + row.color, + row.sort_order, + row.updated_at, + ], + ) + .context("update folder")?; + drop(conn); + self.get(id) + } + + pub fn soft_delete(&self, id: &str, move_to_folder_id: Option<&str>) -> Result { + let folder = self + .get(id)? + .ok_or_else(|| anyhow::anyhow!("folder not found"))?; + if folder.is_system { + anyhow::bail!("cannot delete system folders"); + } + let now = Utc::now(); + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + if let Some(target) = move_to_folder_id { + if target == id { + return Err(anyhow::anyhow!("move_to_folder_id must differ from deleted folder")); + } + let exists: i64 = conn.query_row( + "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", + params![target], + |row| row.get(0), + )?; + if exists == 0 { + anyhow::bail!("move_to_folder_id not found"); + } + conn.execute( + r#" + UPDATE conversations + SET folder_id = ?1, updated_at = ?2, sync_version = sync_version + 1 + WHERE folder_id = ?3 AND deleted_at IS NULL + "#, + params![target, now, id], + ) + .context("reassign conversations on folder delete")?; + } else { + conn.execute( + r#" + UPDATE conversations + SET folder_id = NULL, updated_at = ?1, sync_version = sync_version + 1 + WHERE folder_id = ?2 AND deleted_at IS NULL + "#, + params![now, id], + ) + 
.context("unfile conversations on folder delete")?; + } + let changed = conn + .execute( + "UPDATE folders SET deleted_at = ?1, updated_at = ?1 WHERE id = ?2 AND deleted_at IS NULL", + params![now, id], + ) + .context("soft-delete folder")?; + Ok(changed > 0) + } } pub struct ConversationRepository { @@ -476,6 +1029,7 @@ impl ConversationRepository { sync_state: "local".to_string(), metadata_json: json_or_empty_object(new.metadata)?, starred: false, + folder_id: None, }; let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); @@ -483,9 +1037,10 @@ impl ConversationRepository { r#" INSERT INTO conversations ( id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred, + folder_id ) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16) "#, params![ conversation.id, @@ -502,7 +1057,8 @@ impl ConversationRepository { conversation.sync_version, conversation.sync_state, conversation.metadata_json, - conversation.starred + conversation.starred, + conversation.folder_id, ], ) .context("failed to insert conversation")?; @@ -515,7 +1071,8 @@ impl ConversationRepository { conn.query_row( r#" SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred, + folder_id FROM conversations WHERE id = ?1 AND deleted_at IS NULL "#, @@ -527,7 +1084,7 @@ impl ConversationRepository { } pub fn list(&self, limit: i64) -> Result> { - self.list_filtered(limit, 0, None, None, None) + self.list_filtered(limit, 0, None, None, None, None) } pub fn list_filtered( @@ -537,20 +1094,23 @@ impl 
ConversationRepository { start_date: Option>, end_date: Option>, starred: Option, + folder_id: Option<&str>, ) -> Result> { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let mut stmt = conn .prepare( r#" SELECT id, session_id, title, overview, status, started_at, ended_at, created_at, - updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred + updated_at, deleted_at, cloud_id, sync_version, sync_state, metadata_json, starred, + folder_id FROM conversations WHERE deleted_at IS NULL AND (?1 IS NULL OR started_at >= ?1) AND (?2 IS NULL OR started_at < ?2) AND (?3 IS NULL OR starred = ?3) + AND (?4 IS NULL OR folder_id = ?4) ORDER BY updated_at DESC - LIMIT ?4 OFFSET ?5 + LIMIT ?5 OFFSET ?6 "#, ) .context("failed to prepare conversation list query")?; @@ -560,6 +1120,7 @@ impl ConversationRepository { start_date, end_date, starred.map(|value| if value { 1 } else { 0 }), + folder_id, limit, offset ], @@ -601,14 +1162,30 @@ impl ConversationRepository { if let Some(starred) = update.starred { conversation.starred = starred; } + if let Some(folder_id) = update.folder_id { + conversation.folder_id = folder_id; + } conversation.updated_at = Utc::now(); let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + if let Some(fid) = conversation.folder_id.as_ref() { + let folder_count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", + params![fid], + |row| row.get(0), + ) + .context("validate folder assignment")?; + if folder_count == 0 { + anyhow::bail!("unknown folder id: {}", fid); + } + } + conn.execute( r#" UPDATE conversations SET title = ?2, overview = ?3, status = ?4, ended_at = ?5, updated_at = ?6, - metadata_json = ?7, starred = ?8, sync_version = sync_version + 1 + metadata_json = ?7, starred = ?8, folder_id = ?9, sync_version = sync_version + 1 WHERE id = ?1 AND deleted_at IS NULL "#, params![ @@ -619,7 +1196,8 @@ impl ConversationRepository { 
conversation.ended_at, conversation.updated_at, conversation.metadata_json, - conversation.starred + if conversation.starred { 1 } else { 0 }, + conversation.folder_id, ], ) .context("failed to update conversation")?; @@ -1698,6 +2276,8 @@ pub struct UpdateConversation { pub ended_at: Option>>, pub metadata: Option, pub starred: Option, + /// `None` = omit field, `Some(None)` = clear folder, `Some(Some(id))` = set. + pub folder_id: Option>, } #[derive(Debug, Clone)] @@ -1840,7 +2420,8 @@ fn map_conversation(row: &rusqlite::Row<'_>) -> rusqlite::Result { sync_version: row.get(11)?, sync_state: row.get(12)?, metadata_json: row.get(13)?, - starred: row.get(14)?, + starred: row.get::<_, i64>(14)? != 0, + folder_id: row.get(15)?, }) } @@ -2027,6 +2608,361 @@ fn soft_delete_local_processing_except( Ok(changed) } +/// Reserved session row for desktop “default chat” (no explicit multi-chat session). +pub const LOCAL_DEFAULT_CHAT_SESSION_ID: &str = "00000000-0000-4000-8000-000000000001"; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub struct ChatSessionDto { + pub id: String, + pub title: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub preview: Option, + pub created_at: DateTime, + pub updated_at: DateTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub app_id: Option, + pub message_count: i64, + pub starred: bool, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub struct ChatMessageDto { + pub id: String, + pub text: String, + pub created_at: DateTime, + pub sender: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub app_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rating: Option, + pub reported: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +pub struct ChatSessionsRepository { + conn: Arc>, +} + +impl ChatSessionsRepository { 
+ pub fn list_sessions( + &self, + limit: i64, + offset: i64, + app_id: Option<&str>, + starred: Option, + ) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let limit = limit.max(1).min(500); + let offset = offset.max(0); + + let map_row = |row: &rusqlite::Row<'_>| -> rusqlite::Result { + let starred_i: i64 = row.get(4)?; + Ok(ChatSessionDto { + id: row.get(0)?, + title: row.get(1)?, + preview: row.get(2)?, + app_id: row.get(3)?, + starred: starred_i != 0, + created_at: row.get(5)?, + updated_at: row.get(6)?, + message_count: row.get(7)?, + }) + }; + + let sessions = match (app_id, starred) { + (None, None) => { + let mut stmt = conn.prepare( + r#" + SELECT s.id, s.title, s.preview, s.app_id, s.starred, s.created_at, s.updated_at, + (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id = s.id) + FROM chat_sessions s + WHERE s.id != ?1 + ORDER BY s.updated_at DESC + LIMIT ?2 OFFSET ?3 + "#, + )?; + let rows = stmt + .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, limit, offset], map_row)? + .collect::, _>>()?; + rows + } + (Some(aid), None) => { + let mut stmt = conn.prepare( + r#" + SELECT s.id, s.title, s.preview, s.app_id, s.starred, s.created_at, s.updated_at, + (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id = s.id) + FROM chat_sessions s + WHERE s.id != ?1 AND (s.app_id IS NOT DISTINCT FROM ?2) + ORDER BY s.updated_at DESC + LIMIT ?3 OFFSET ?4 + "#, + )?; + let rows = stmt + .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, aid, limit, offset], map_row)? 
+ .collect::, _>>()?; + rows + } + (None, Some(st)) => { + let st_i = if st { 1 } else { 0 }; + let mut stmt = conn.prepare( + r#" + SELECT s.id, s.title, s.preview, s.app_id, s.starred, s.created_at, s.updated_at, + (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id = s.id) + FROM chat_sessions s + WHERE s.id != ?1 AND s.starred = ?2 + ORDER BY s.updated_at DESC + LIMIT ?3 OFFSET ?4 + "#, + )?; + let rows = stmt + .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, st_i, limit, offset], map_row)? + .collect::, _>>()?; + rows + } + (Some(aid), Some(st)) => { + let st_i = if st { 1 } else { 0 }; + let mut stmt = conn.prepare( + r#" + SELECT s.id, s.title, s.preview, s.app_id, s.starred, s.created_at, s.updated_at, + (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id = s.id) + FROM chat_sessions s + WHERE s.id != ?1 AND (s.app_id IS NOT DISTINCT FROM ?2) AND s.starred = ?3 + ORDER BY s.updated_at DESC + LIMIT ?4 OFFSET ?5 + "#, + )?; + let rows = stmt + .query_map( + params![LOCAL_DEFAULT_CHAT_SESSION_ID, aid, st_i, limit, offset], + map_row, + )? 
+ .collect::, _>>()?; + rows + } + }; + + Ok(sessions) + } + + pub fn create_session(&self, title: Option<&str>, app_id: Option<&str>) -> Result { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let now = Utc::now(); + let id = deterministic_id( + "chat_sess", + &[&now.timestamp_nanos_opt().unwrap_or(0).to_string()], + ); + let title_str = title.unwrap_or("New Chat"); + conn.execute( + r#" + INSERT INTO chat_sessions (id, title, preview, app_id, starred, created_at, updated_at) + VALUES (?1, ?2, NULL, ?3, 0, ?4, ?5) + "#, + params![id, title_str, app_id, now, now], + ) + .context("insert chat_session")?; + + Ok(ChatSessionDto { + id, + title: title_str.to_string(), + preview: None, + app_id: app_id.map(|s| s.to_string()), + starred: false, + created_at: now, + updated_at: now, + message_count: 0, + }) + } + + pub fn get_session(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let row = conn + .query_row( + r#" + SELECT s.id, s.title, s.preview, s.app_id, s.starred, s.created_at, s.updated_at, + (SELECT COUNT(*) FROM chat_messages m WHERE m.session_id = s.id) + FROM chat_sessions s + WHERE s.id = ?1 + "#, + params![id], + |row| { + let starred_i: i64 = row.get(4)?; + Ok(ChatSessionDto { + id: row.get(0)?, + title: row.get(1)?, + preview: row.get(2)?, + app_id: row.get(3)?, + starred: starred_i != 0, + created_at: row.get(5)?, + updated_at: row.get(6)?, + message_count: row.get(7)?, + }) + }, + ) + .optional() + .context("get chat_session")?; + Ok(row) + } + + pub fn update_session( + &self, + id: &str, + title: Option<&str>, + starred: Option, + ) -> Result> { + if id == LOCAL_DEFAULT_CHAT_SESSION_ID { + anyhow::bail!("cannot update reserved default chat session"); + } + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let exists: i64 = conn.query_row( + "SELECT COUNT(*) FROM chat_sessions WHERE id = ?1", + params![id], + |row| row.get(0), + )?; + if exists 
== 0 { + return Ok(None); + } + + let now = Utc::now(); + if let Some(t) = title { + conn.execute( + "UPDATE chat_sessions SET title = ?1, updated_at = ?2 WHERE id = ?3", + params![t, now, id], + )?; + } + if let Some(st) = starred { + conn.execute( + "UPDATE chat_sessions SET starred = ?1, updated_at = ?2 WHERE id = ?3", + params![if st { 1 } else { 0 }, now, id], + )?; + } + + drop(conn); + self.get_session(id) + } + + pub fn delete_session(&self, id: &str) -> Result { + if id == LOCAL_DEFAULT_CHAT_SESSION_ID { + anyhow::bail!("cannot delete reserved default chat session"); + } + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let changed = conn.execute("DELETE FROM chat_sessions WHERE id = ?1", params![id])?; + Ok(changed > 0) + } + + pub fn list_messages( + &self, + session_id: &str, + app_id: Option<&str>, + limit: i64, + offset: i64, + ) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let limit = limit.max(1).min(500); + let offset = offset.max(0); + + let rows = match app_id { + None => { + let mut stmt = conn.prepare( + r#" + SELECT id, text, created_at, sender, app_id, session_id, rating, reported, metadata + FROM chat_messages + WHERE session_id = ?1 + ORDER BY created_at DESC + LIMIT ?2 OFFSET ?3 + "#, + )?; + let rows = stmt.query_map(params![session_id, limit, offset], Self::map_message_row)? + .collect::, _>>()?; + rows + } + Some(aid) => { + let mut stmt = conn.prepare( + r#" + SELECT id, text, created_at, sender, app_id, session_id, rating, reported, metadata + FROM chat_messages + WHERE session_id = ?1 AND (app_id IS NOT DISTINCT FROM ?2) + ORDER BY created_at DESC + LIMIT ?3 OFFSET ?4 + "#, + )?; + let rows = + stmt.query_map(params![session_id, aid, limit, offset], Self::map_message_row)? 
+ .collect::, _>>()?; + rows + } + }; + + Ok(rows) + } + + fn map_message_row(row: &rusqlite::Row<'_>) -> rusqlite::Result { + let reported_i: i64 = row.get(7)?; + Ok(ChatMessageDto { + id: row.get(0)?, + text: row.get(1)?, + created_at: row.get(2)?, + sender: row.get(3)?, + app_id: row.get(4)?, + session_id: row.get(5)?, + rating: row.get(6)?, + reported: reported_i != 0, + metadata: row.get(8)?, + }) + } + + pub fn append_message( + &self, + session_id: &str, + text: &str, + sender: &str, + app_id: Option<&str>, + metadata: Option<&str>, + ) -> Result<(String, DateTime)> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + let exists: i64 = conn.query_row( + "SELECT COUNT(*) FROM chat_sessions WHERE id = ?1", + params![session_id], + |row| row.get(0), + )?; + if exists == 0 { + anyhow::bail!("chat session not found"); + } + + let now = Utc::now(); + let msg_id = deterministic_id( + "chat_msg", + &[ + session_id, + &now.timestamp_nanos_opt().unwrap_or(0).to_string(), + sender, + text, + ], + ); + + conn.execute( + r#" + INSERT INTO chat_messages (id, session_id, text, sender, app_id, metadata, created_at, rating, reported) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, NULL, 0) + "#, + params![msg_id, session_id, text, sender, app_id, metadata, now], + ) + .context("insert chat_message")?; + + let preview: String = text.chars().take(240).collect(); + conn.execute( + "UPDATE chat_sessions SET updated_at = ?1, preview = ?2 WHERE id = ?3", + params![now, preview, session_id], + )?; + + Ok((msg_id, now)) + } +} + impl ProcessingJobStatus { fn as_str(&self) -> &'static str { match self { @@ -2072,6 +3008,9 @@ mod tests { "sync_outbox", "local_files", "conversation_search", + "folders", + "chat_sessions", + "chat_messages", ] { let exists: i64 = conn.query_row( "SELECT COUNT(*) FROM sqlite_master WHERE name = ?1", @@ -2211,6 +3150,7 @@ mod tests { ended_at: None, metadata: None, starred: Some(true), + folder_id: None, }, )? 
.expect("conversation should update"); @@ -2383,4 +3323,113 @@ mod tests { Ok(()) } + + #[test] + fn folders_crud_assign_conversations_and_soft_delete_unfiles() -> Result<()> { + let store = Store::open_in_memory()?; + let folder = store.folders().create(NewFolder { + id: "fld-crm".to_string(), + name: "Work".to_string(), + description: Some("desc".into()), + color: Some("#112233".into()), + })?; + assert_eq!(folder.name, "Work"); + assert_eq!(store.folders().list()?.len(), 1); + + store.conversations().create(NewConversation { + id: "conv-fld".into(), + session_id: "s-fld".into(), + title: "Tagged".into(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store + .conversations() + .update( + "conv-fld", + UpdateConversation { + title: None, + overview: None, + status: None, + ended_at: None, + metadata: None, + starred: None, + folder_id: Some(Some("fld-crm".into())), + }, + )? + .expect("update"); + + let conv = store.conversations().get("conv-fld")?.expect("conversation"); + assert_eq!(conv.folder_id.as_deref(), Some("fld-crm")); + + store.folders().soft_delete("fld-crm", None)?; + assert!(store.folders().get("fld-crm")?.is_none()); + + let unfiled = store.conversations().get("conv-fld")?.expect("conversation"); + assert!(unfiled.folder_id.is_none()); + + Ok(()) + } + + #[test] + fn merge_conversations_moves_segments_and_soft_deletes_sources() -> Result<()> { + let store = Store::open_in_memory()?; + + store.conversations().create(NewConversation { + id: "merge-a".into(), + session_id: "s-a".into(), + title: "A".into(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store.conversations().create(NewConversation { + id: "merge-b".into(), + session_id: "s-b".into(), + title: "B".into(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + + store.transcripts().append(NewTranscriptSegment { + id: "seg-a0".into(), + conversation_id: "merge-a".into(), + session_id: "s-a".into(), + speaker_id: 
None, + speaker_label: Some("Sp1".into()), + text: "alpha".into(), + start_ms: 0, + end_ms: 50, + segment_index: 0, + source: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: "seg-b0".into(), + conversation_id: "merge-b".into(), + session_id: "s-b".into(), + speaker_id: None, + speaker_label: Some("Sp2".into()), + text: "beta".into(), + start_ms: 0, + end_ms: 60, + segment_index: 0, + source: None, + metadata: None, + })?; + + let merged = store.merge_conversations(&["merge-b".into(), "merge-a".into()], false)?; + let merged_segs = store.transcripts().list_for_conversation(&merged.id)?; + let texts: Vec<&str> = merged_segs.iter().map(|s| s.text.as_str()).collect(); + assert_eq!(texts.len(), 2); + assert!(texts.contains(&"alpha")); + assert!(texts.contains(&"beta")); + + assert!(store.conversations().get("merge-a")?.is_none()); + assert!(store.conversations().get("merge-b")?.is_none()); + + Ok(()) + } } diff --git a/desktop/local-backend/tools/seed_hybrid_defaults.sh b/desktop/local-backend/tools/seed_hybrid_defaults.sh new file mode 100755 index 00000000000..b4b6a9cd776 --- /dev/null +++ b/desktop/local-backend/tools/seed_hybrid_defaults.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Idempotent: seed ai_provider + chat_provider on the local daemon when unset. +set -euo pipefail + +BASE_URL="${OMI_LOCAL_DAEMON_URL:-http://127.0.0.1:8765}" +BASE_URL="${BASE_URL%/}" +PROVIDER_BASE="${OMI_HYBRID_DEFAULT_CHAT_BASE_URL:-http://127.0.0.1:11434/v1}" +MODEL="${OMI_HYBRID_DEFAULT_CHAT_MODEL:-llama3.2}" + +if ! 
curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + echo "seed_hybrid_defaults: daemon not healthy at ${BASE_URL}/health" >&2 + exit 1 +fi + +settings_json="$(curl -fsS "${BASE_URL}/v1/settings")" + +has_key() { + local key="$1" + echo "$settings_json" | python3 -c " +import json, sys +key = sys.argv[1] +data = json.load(sys.stdin) +for s in data.get('settings', []): + if s.get('key') != key: + continue + raw = s.get('value_json') or '' + if not raw or raw == 'null': + sys.exit(1) + try: + v = json.loads(raw) + except json.JSONDecodeError: + sys.exit(1) + if isinstance(v, dict) and v.get('base_url'): + sys.exit(0) +sys.exit(1) +" "$key" +} + +provider_payload="$(python3 - < ${PROVIDER_BASE}" +fi + +if ! has_key chat_provider; then + if [ "$updated" -eq 1 ]; then + body="$(echo "$body" | python3 -c " +import json, sys +p = json.loads(sys.stdin.read()) +chat = json.loads('''${provider_payload}''') +p['chat_provider'] = chat +print(json.dumps(p)) +")" + else + body="$(echo "$provider_payload" | python3 -c " +import json, sys +p = json.load(sys.stdin) +print(json.dumps({'chat_provider': p})) +")" + updated=1 + fi + echo "seed_hybrid_defaults: will set chat_provider -> ${PROVIDER_BASE}" +fi + +if [ "$updated" -eq 0 ]; then + echo "seed_hybrid_defaults: ai_provider and chat_provider already configured" + exit 0 +fi + +curl -fsS -X PUT "${BASE_URL}/v1/settings" \ + -H 'content-type: application/json' \ + -d "$body" >/dev/null + +echo "seed_hybrid_defaults: done" diff --git a/desktop/run.sh b/desktop/run.sh index 570254b6d88..a174c1cf8c9 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -24,6 +24,9 @@ Options (via environment variables): OMI_DESKTOP_BACKEND_MODE=local Route MVP data flows to the local daemon OMI_LOCAL_DAEMON_SUPERVISE=1 In local mode, start desktop/local-backend if /health is unreachable OMI_LOCAL_DAEMON_URL="..." 
Local daemon URL (default: http://127.0.0.1:8765) + OMI_HYBRID_DIRECT_STT_ENABLED Hybrid Apple Speech live transcription in local daemon (default 1 in configure_local_daemon_mode when unset) + OMI_HYBRID_DIRECT_CHAT_ENABLED Hybrid OpenAI-compatible chat + daemon-backed sessions/messages (default 1 in configure_local_daemon_mode when unset) + OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED Hybrid direct embeddings for Rewind/proactive features (default 1 in local bundle; requires embedding_provider) Required files for cloud backend mode: Backend-Rust/.env Environment variables (copy from ../.env.example) @@ -188,6 +191,18 @@ PY export OMI_LOCAL_BACKEND_PORT fi export OMI_LOCAL_BACKEND_HOST="${OMI_LOCAL_BACKEND_HOST:-127.0.0.1}" + # Default hybrid on-device STT for local daemon dev (Apple Speech). Set OMI_HYBRID_DIRECT_STT_ENABLED=0 to disable. + if [ -z "${OMI_HYBRID_DIRECT_STT_ENABLED+x}" ]; then + export OMI_HYBRID_DIRECT_STT_ENABLED=1 + fi + # Hybrid direct chat capability for GUI launches (requires chat_provider in daemon settings). + if [ -z "${OMI_HYBRID_DIRECT_CHAT_ENABLED+x}" ]; then + export OMI_HYBRID_DIRECT_CHAT_ENABLED=1 + fi + # Optional direct embeddings for Rewind OCR vectors etc.: requires embedding_provider in daemon settings. 
+ if [ -z "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED+x}" ]; then + export OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1 + fi } local_daemon_health_ok() { @@ -425,6 +440,14 @@ if is_local_daemon_mode; then echo " OMI_DESKTOP_BACKEND_MODE=local OMI_LOCAL_DAEMON_SUPERVISE=1 ./run.sh" exit 1 fi + + if local_daemon_health_ok; then + SEED_SCRIPT="$(cd "$(dirname "$0")/local-backend/tools" && pwd)/seed_hybrid_defaults.sh" + if [ -x "$SEED_SCRIPT" ]; then + substep "Seeding hybrid provider defaults (if unset)" + "$SEED_SCRIPT" || substep "Warning: hybrid provider seed failed (non-fatal)" + fi + fi fi # ─── Start Rust backend ─────────────────────────────────────────────── @@ -487,6 +510,13 @@ if [ -f scripts/check_schema_docs.sh ]; then bash scripts/check_schema_docs.sh || substep "Schema docs check failed (non-fatal)" fi +if ! pkg-config --exists libwebp 2>/dev/null; then + echo "ERROR: libwebp headers not found (required by CWebP for screen capture)." + echo " brew install webp" + echo " Then re-run ./run.sh" + exit 1 +fi + step "Building Swift app (swift build -c debug)..." xcrun swift build -c debug --package-path Desktop @@ -619,8 +649,15 @@ substep "OMI_DESKTOP_API_URL=$EFFECTIVE_API_URL" if is_local_daemon_mode; then set_bundle_env "OMI_DESKTOP_BACKEND_MODE" "local" set_bundle_env "OMI_LOCAL_DAEMON_URL" "$OMI_LOCAL_DAEMON_URL" + # GUI launches via `open` do not inherit shell exports — AppState.loadEnvironment() reads bundled .env. 
+ set_bundle_env "OMI_HYBRID_DIRECT_STT_ENABLED" "${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" + set_bundle_env "OMI_HYBRID_DIRECT_CHAT_ENABLED" "${OMI_HYBRID_DIRECT_CHAT_ENABLED:-1}" + set_bundle_env "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED" "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-1}" substep "OMI_DESKTOP_BACKEND_MODE=local" substep "OMI_LOCAL_DAEMON_URL=$OMI_LOCAL_DAEMON_URL" + substep "OMI_HYBRID_DIRECT_STT_ENABLED=${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" + substep "OMI_HYBRID_DIRECT_CHAT_ENABLED=${OMI_HYBRID_DIRECT_CHAT_ENABLED:-1}" + substep "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-1}" fi # Bootstrap FIREBASE_API_KEY — check env var first (yolo mode), then backend .env if ! grep -q "^FIREBASE_API_KEY=" "$APP_BUNDLE/Contents/Resources/.env"; then From 149d1a6b56ca60240417b019e1b6f9f626d9805f Mon Sep 17 00:00:00 2001 From: David Zhang Date: Tue, 19 May 2026 19:49:09 -0400 Subject: [PATCH 28/58] Update .gitignore to include cursor and desktop build artifacts --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 13d9ed256b2..abeb8ab59d5 100644 --- a/.gitignore +++ b/.gitignore @@ -221,3 +221,7 @@ app/macos/Runner/LocalDev.entitlements # Generated browser snapshots (contain PII) builds-snapshot.md *-snapshot.md + +# Cursor / agent local artifacts +.cursor/plans/ +desktop/Desktop/.build-*/ From 373a37cfe7da333c6a13e08888acab390dc59688 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 00:23:42 -0400 Subject: [PATCH 29/58] Add ChatGPT/Codex plan support via local loopback proxy. Route hybrid chat and proactive LLM through a localhost Codex proxy using ~/.codex auth, add Settings enrollment UX, local memory wiki + FTS for search, backend tier activation, and bundle the proxy in run.sh. 
Co-authored-by: Cursor --- .gitignore | 2 + backend/database/users.py | 49 + backend/routers/users.py | 69 + backend/tests/unit/test_chatgpt_enrollment.py | 35 + backend/utils/subscription.py | 13 +- desktop/Desktop/Sources/APIClient.swift | 17 + .../Desktop/Sources/CodexAuthService.swift | 110 ++ .../Sources/CodexEnrollmentCoordinator.swift | 134 ++ .../Sources/CodexProviderBootstrap.swift | 52 + .../Desktop/Sources/CodexProxyService.swift | 186 +++ .../Sources/DesktopBackendEnvironment.swift | 5 + .../Desktop/Sources/HybridChatClient.swift | 44 +- desktop/Desktop/Sources/HybridLLMClient.swift | 18 +- .../MainWindow/Pages/SettingsPage.swift | 91 ++ .../Desktop/Sources/MemoryWikiStorage.swift | 181 +++ desktop/Desktop/Sources/OmiApp.swift | 3 + .../MemoryExtraction/MemoryAssistant.swift | 15 + .../TaskExtraction/TaskAssistant.swift | 63 +- .../Core/GeminiClient.swift | 10 +- .../Sources/Providers/ChatProvider.swift | 3 +- .../Sources/Rewind/Core/RewindDatabase.swift | 52 + .../Desktop/Tests/CodexAuthServiceTests.swift | 126 ++ desktop/codex-proxy/Cargo.lock | 1443 +++++++++++++++++ desktop/codex-proxy/Cargo.toml | 13 + desktop/codex-proxy/README.md | 61 + desktop/codex-proxy/e2e_smoke.sh | 63 + desktop/codex-proxy/src/main.rs | 799 +++++++++ .../docs/hybrid-provider-settings.md | 11 + desktop/run.sh | 11 + 29 files changed, 3667 insertions(+), 12 deletions(-) create mode 100644 backend/tests/unit/test_chatgpt_enrollment.py create mode 100644 desktop/Desktop/Sources/CodexAuthService.swift create mode 100644 desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift create mode 100644 desktop/Desktop/Sources/CodexProviderBootstrap.swift create mode 100644 desktop/Desktop/Sources/CodexProxyService.swift create mode 100644 desktop/Desktop/Sources/MemoryWikiStorage.swift create mode 100644 desktop/Desktop/Tests/CodexAuthServiceTests.swift create mode 100644 desktop/codex-proxy/Cargo.lock create mode 100644 desktop/codex-proxy/Cargo.toml create mode 100644 
desktop/codex-proxy/README.md create mode 100755 desktop/codex-proxy/e2e_smoke.sh create mode 100644 desktop/codex-proxy/src/main.rs diff --git a/.gitignore b/.gitignore index abeb8ab59d5..087560242aa 100644 --- a/.gitignore +++ b/.gitignore @@ -103,7 +103,9 @@ web/app/public/firebase-messaging-sw.js !app/pubspec.lock !app/ios/Podfile.lock !mcp/uv.lock +desktop/codex-proxy/target/ *.lock +!desktop/codex-proxy/Cargo.lock *.log *.swo *.swp diff --git a/backend/database/users.py b/backend/database/users.py index 448b4d28b44..b3fbf6fe539 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -215,6 +215,55 @@ def clear_byok_active(uid: str): ) +def get_chatgpt_state(uid: str) -> dict: + user_ref = db.collection('users').document(uid) + data = user_ref.get().to_dict() or {} + return data.get('chatgpt', {}) + + +def is_chatgpt_active(uid: str) -> bool: + """True if user enrolled ChatGPT/Codex tier (LLM-only; separate from four-key BYOK).""" + state = get_chatgpt_state(uid) + if not state.get('active'): + return False + last_seen = state.get('last_seen_at') + if not last_seen: + return False + if isinstance(last_seen, datetime): + age = (datetime.now(timezone.utc) - last_seen).total_seconds() + else: + return False + return age <= BYOK_HEARTBEAT_TTL_SECONDS + + +def set_chatgpt_active(uid: str, fingerprint: str): + user_ref = db.collection('users').document(uid) + user_ref.set( + { + 'chatgpt': { + 'active': True, + 'fingerprint': fingerprint, + 'last_seen_at': datetime.now(timezone.utc), + } + }, + merge=True, + ) + + +def clear_chatgpt_active(uid: str): + user_ref = db.collection('users').document(uid) + user_ref.set( + { + 'chatgpt': { + 'active': False, + 'fingerprint': '', + 'last_seen_at': datetime.now(timezone.utc), + } + }, + merge=True, + ) + + def set_user_deletion_feedback(uid: str, reason: Optional[str], reason_details: Optional[str] = None): # Stored in a top-level collection so it survives the user record being deleted. 
db.collection('account_deletions').document(uid).set( diff --git a/backend/routers/users.py b/backend/routers/users.py index 1079529d3b2..2190cb28e65 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -813,6 +813,47 @@ def deactivate_byok_endpoint(uid: str = Depends(auth.get_current_user_uid_no_byo return {"active": False} +class ChatGPTActivateRequest(BaseModel): + fingerprint: str + + +@router.post('/v1/users/me/chatgpt-active', tags=['v1']) +def activate_chatgpt_endpoint( + data: ChatGPTActivateRequest, uid: str = Depends(auth.get_current_user_uid_no_byok_validation) +): + """Enroll ChatGPT / Codex subscription tier (LLM workloads only; no provider keys stored).""" + if not _SHA256_HEX_RE.match(data.fingerprint): + raise HTTPException( + status_code=400, + detail='Invalid fingerprint: expected lowercase hex SHA-256 (64 chars)', + ) + users_db.set_chatgpt_active(uid, data.fingerprint) + clear_trial_paywall_cache(uid) + return {"active": True} + + +@router.delete('/v1/users/me/chatgpt-active', tags=['v1']) +def deactivate_chatgpt_endpoint(uid: str = Depends(auth.get_current_user_uid_no_byok_validation)): + """Drop ChatGPT / Codex tier enrollment.""" + users_db.clear_chatgpt_active(uid) + clear_trial_paywall_cache(uid) + return {"active": False} + + +def _chatgpt_unlimited_subscription() -> Subscription: + return Subscription( + plan=PlanType.unlimited, + status=SubscriptionStatus.active, + features=["chatgpt"], + limits=PlanLimits( + transcription_seconds=None, + words_transcribed=None, + insights_gained=None, + memories_created=None, + ), + ) + + def _byok_unlimited_subscription() -> Subscription: """BYOK free plan: unlimited limits, marked with the `byok` feature flag.""" return Subscription( @@ -844,6 +885,22 @@ def get_user_subscription_endpoint( # these users aren't surprised by a disabled phone-call feature. 
unlimited_phone_quota = PhoneCallQuota(has_access=True, is_paid=True) + if users_db.is_chatgpt_active(uid): + return UserSubscriptionResponse( + subscription=_chatgpt_unlimited_subscription(), + transcription_seconds_used=0, + transcription_seconds_limit=0, + words_transcribed_used=0, + words_transcribed_limit=0, + insights_gained_used=0, + insights_gained_limit=0, + memories_created_used=0, + memories_created_limit=0, + available_plans=[], + show_subscription_ui=False, + phone_call_quota=unlimited_phone_quota, + ) + if users_db.is_byok_active(uid) and has_byok_keys(): return UserSubscriptionResponse( subscription=_byok_unlimited_subscription(), @@ -1053,6 +1110,18 @@ def get_user_chat_usage_quota( # BYOK free plan: user brings their own keys, so there's no Omi-side cost # to meter. Only return unlimited when BYOK headers are on the request (desktop). # Mobile (no headers) should see real quota. + if users_db.is_chatgpt_active(uid): + return ChatUsageQuota( + plan='Free (ChatGPT)', + plan_type=PlanType.unlimited.value, + unit=ChatQuotaUnit.questions, + used=0.0, + limit=None, + percent=0.0, + allowed=True, + reset_at=None, + ) + if users_db.is_byok_active(uid) and has_byok_keys(): return ChatUsageQuota( plan='Free (BYOK)', diff --git a/backend/tests/unit/test_chatgpt_enrollment.py b/backend/tests/unit/test_chatgpt_enrollment.py new file mode 100644 index 00000000000..217f1e02751 --- /dev/null +++ b/backend/tests/unit/test_chatgpt_enrollment.py @@ -0,0 +1,35 @@ +"""Unit tests for ChatGPT / Codex tier enrollment (standalone, no Firestore imports).""" + +import re +from datetime import datetime, timedelta, timezone + +_SHA256_HEX_RE = re.compile(r'^[a-f0-9]{64}$') +_SHA256 = 'a' * 64 +_CHATGPT_TTL_SECONDS = 7 * 24 * 60 * 60 + + +def _is_chatgpt_active_state(state: dict) -> bool: + if not state.get('active'): + return False + last_seen = state.get('last_seen_at') + if not isinstance(last_seen, datetime): + return False + age = (datetime.now(timezone.utc) - 
last_seen).total_seconds() + return age <= _CHATGPT_TTL_SECONDS + + +def test_fingerprint_must_be_sha256_hex(): + assert _SHA256_HEX_RE.match(_SHA256) + assert not _SHA256_HEX_RE.match('not-hex') + assert not _SHA256_HEX_RE.match('A' * 64) + + +def test_chatgpt_active_ttl(): + fresh = {'active': True, 'last_seen_at': datetime.now(timezone.utc) - timedelta(days=1)} + assert _is_chatgpt_active_state(fresh) is True + + stale = {'active': True, 'last_seen_at': datetime.now(timezone.utc) - timedelta(days=30)} + assert _is_chatgpt_active_state(stale) is False + + inactive = {'active': False, 'last_seen_at': datetime.now(timezone.utc)} + assert _is_chatgpt_active_state(inactive) is False diff --git a/backend/utils/subscription.py b/backend/utils/subscription.py index 6d8ca62835c..d3b176d56dc 100644 --- a/backend/utils/subscription.py +++ b/backend/utils/subscription.py @@ -86,6 +86,8 @@ def _is_trial_expired_uncached(uid: str) -> bool: return False if users_db.is_byok_active(uid): return False + if users_db.is_chatgpt_active(uid): + return False user_record = firebase_auth.get_user(uid) creation_ms = user_record.user_metadata.creation_timestamp if not creation_ms: @@ -157,7 +159,12 @@ def get_trial_metadata(uid: str) -> TrialMetadata: # Same request-level escape hatch as `_is_trial_expired_cached`: a request # carrying all 4 BYOK provider headers is treated as BYOK-active even if # Firestore hasn't caught up yet. - if plan != PlanType.basic or users_db.is_byok_active(uid) or _request_has_all_byok_keys(): + if ( + plan != PlanType.basic + or users_db.is_byok_active(uid) + or users_db.is_chatgpt_active(uid) + or _request_has_all_byok_keys() + ): return TrialMetadata( trial_expired=False, trial_duration_seconds=TRIAL_LENGTH_SECONDS, @@ -534,6 +541,10 @@ def enforce_chat_quota(uid: str, platform: Optional[str] = None) -> None: ) # BYOK users pay their own LLM provider — no Omi-side cost to cap. + # ChatGPT/Codex tier: user pays OpenAI via subscription; LLM quota bypass only. 
+ if users_db.is_chatgpt_active(uid): + return + # Require an LLM provider key on this request (not just any BYOK header) # so a user can't activate with fake fingerprints or send only x-byok-deepgram # to bypass chat quota while chat falls back to Omi's OpenAI/Anthropic keys. diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 717125e61d1..4bf5053ba71 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -5906,6 +5906,23 @@ extension APIClient { try await delete("v1/users/me/byok-active") } + /// Activate ChatGPT / Codex subscription tier (LLM only; fingerprint of account_id). + func activateChatGPT(fingerprint: String) async throws { + guard selectedBackendTarget.mode != .localDaemon else { return } + struct Request: Encodable { + let fingerprint: String + } + struct Empty: Decodable {} + let _: Empty = try await post( + "v1/users/me/chatgpt-active", body: Request(fingerprint: fingerprint) + ) + } + + func deactivateChatGPT() async throws { + guard selectedBackendTarget.mode != .localDaemon else { return } + try await delete("v1/users/me/chatgpt-active") + } + /// Fetches all people for the current user func getPeople() async throws -> [Person] { return try await get("v1/users/people") diff --git a/desktop/Desktop/Sources/CodexAuthService.swift b/desktop/Desktop/Sources/CodexAuthService.swift new file mode 100644 index 00000000000..03209ea36d4 --- /dev/null +++ b/desktop/Desktop/Sources/CodexAuthService.swift @@ -0,0 +1,110 @@ +import CryptoKit +import Foundation + +/// ChatGPT / Codex subscription auth via local `~/.codex/auth.json` (same cache as Codex CLI). +/// Tokens never leave this Mac except through the loopback Codex proxy to OpenAI. 
+enum CodexAuthService { + private static let enrolledKey = "codex_auth_enrolled" + private static let preferredModelKey = "codex_preferred_model" + private static let defaultModel = "gpt-5.4" + + struct AuthSnapshot: Equatable { + let accessToken: String + let accountId: String + let refreshToken: String? + let authFilePath: URL + } + + /// User opted in via Settings (distinct from merely having auth.json from Codex CLI). + static var isEnrolled: Bool { + UserDefaults.standard.bool(forKey: enrolledKey) + } + + static var preferredModel: String { + let stored = UserDefaults.standard.string(forKey: preferredModelKey)? + .trimmingCharacters(in: .whitespacesAndNewlines) + if let stored, !stored.isEmpty { return stored } + return defaultModel + } + + static func setPreferredModel(_ model: String) { + UserDefaults.standard.set(model, forKey: preferredModelKey) + } + + /// SHA-256 fingerprint of account_id for backend enrollment (never stores tokens server-side). + static func enrollmentFingerprint(for accountId: String) -> String { + let digest = SHA256.hash(data: Data(accountId.utf8)) + return digest.map { String(format: "%02x", $0) }.joined() + } + + static func resolveAuthFilePath() -> URL { + if let codexHome = ProcessInfo.processInfo.environment["CODEX_HOME"], + !codexHome.isEmpty + { + return URL(fileURLWithPath: codexHome).appendingPathComponent("auth.json") + } + return FileManager.default.homeDirectoryForCurrentUser + .appendingPathComponent(".codex/auth.json") + } + + static func loadSnapshot() -> AuthSnapshot? { + let url = resolveAuthFilePath() + guard let data = try? Data(contentsOf: url), + let json = try? JSONSerialization.jsonObject(with: data) as? 
[String: Any], + let fields = parseAuthFields(from: json) + else { + return nil + } + return AuthSnapshot( + accessToken: fields.accessToken, + accountId: fields.accountId, + refreshToken: fields.refreshToken, + authFilePath: url + ) + } + + /// Codex CLI stores tokens at the top level (legacy) or under `tokens` (current format). + private static func parseAuthFields(from json: [String: Any]) -> ( + accessToken: String, accountId: String, refreshToken: String? + )? { + if let parsed = parseAuthFields(fromTokenContainer: json) { + return parsed + } + if let tokens = json["tokens"] as? [String: Any] { + return parseAuthFields(fromTokenContainer: tokens) + } + return nil + } + + private static func parseAuthFields(fromTokenContainer json: [String: Any]) -> ( + accessToken: String, accountId: String, refreshToken: String? + )? { + guard let accessToken = json["access_token"] as? String, + !accessToken.isEmpty, + let accountId = json["account_id"] as? String, + !accountId.isEmpty + else { + return nil + } + let refresh = (json["refresh_token"] as? String).flatMap { $0.isEmpty ? nil : $0 } + return (accessToken, accountId, refresh) + } + + /// True when enrolled and a valid auth file is present. + static var isActive: Bool { + isEnrolled && loadSnapshot() != nil + } + + static func markEnrolled() { + UserDefaults.standard.set(true, forKey: enrolledKey) + } + + static func clearEnrollment() { + UserDefaults.standard.set(false, forKey: enrolledKey) + } + + static func enrollmentFingerprintIfActive() -> String? 
{ + guard let snap = loadSnapshot(), isEnrolled else { return nil } + return enrollmentFingerprint(for: snap.accountId) + } +} diff --git a/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift b/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift new file mode 100644 index 00000000000..eb9bafb5bb2 --- /dev/null +++ b/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift @@ -0,0 +1,134 @@ +import AppKit +import Foundation + +/// Connects ChatGPT / Codex subscription: runs `codex login`, validates auth.json, enrolls backend. +@MainActor +enum CodexEnrollmentCoordinator { + enum EnrollmentError: LocalizedError { + case authFileMissing + case proxyFailed(String) + case backendFailed(String) + + var errorDescription: String? { + switch self { + case .authFileMissing: + return + "Sign-in timed out. Complete login in Terminal, then try again." + case .proxyFailed(let msg): + return "Codex proxy failed: \(msg)" + case .backendFailed(let msg): + return "Could not activate ChatGPT plan on account: \(msg)" + } + } + } + + private static var connectInFlight = false + private static var loginTerminalLaunched = false + + /// Opens Codex login in Terminal (user completes browser flow). + private static func launchCodexLogin() { + if loginTerminalLaunched { + activateTerminal() + return + } + loginTerminalLaunched = true + + let command = "npx @openai/codex login" + // `do script` alone always opens a new window when Terminal is already running. + // Reuse the front window (new tab) so one click does not spawn a second window. + let script = """ + tell application "Terminal" + if not running then + do script "\(command)" + else + activate + if (count of windows) is 0 then + do script "\(command)" + else + do script "\(command)" in front window + end if + end if + end tell + """ + if let appleScript = NSAppleScript(source: script) { + var error: NSDictionary? 
+ appleScript.executeAndReturnError(&error) + if error != nil { + loginTerminalLaunched = false + NSWorkspace.shared.open(URL(string: "https://developers.openai.com/codex/auth")!) + } + } + } + + private static func activateTerminal() { + let script = """ + tell application "Terminal" + activate + end tell + """ + if let appleScript = NSAppleScript(source: script) { + var error: NSDictionary? + appleScript.executeAndReturnError(&error) + } + } + + /// Sign in: use existing auth if present, otherwise open Codex login then poll. + static func connect(pollSeconds: Int = 120) async throws { + guard !connectInFlight else { return } + connectInFlight = true + defer { + connectInFlight = false + loginTerminalLaunched = false + } + + if let snap = CodexAuthService.loadSnapshot() { + try await finalizeEnrollment(snapshot: snap) + return + } + launchCodexLogin() + try await connectAfterLogin(pollSeconds: pollSeconds) + } + + /// Poll for auth.json after login, then enroll + start proxy. + private static func connectAfterLogin(pollSeconds: Int = 120) async throws { + let deadline = Date().addingTimeInterval(TimeInterval(pollSeconds)) + while Date() < deadline { + if let snap = CodexAuthService.loadSnapshot() { + try await finalizeEnrollment(snapshot: snap) + return + } + try await Task.sleep(nanoseconds: 2_000_000_000) + } + throw EnrollmentError.authFileMissing + } + + static func disconnect() async { + CodexAuthService.clearEnrollment() + await CodexProxyService.shared.stop() + await CodexProviderBootstrap.clearDaemonProviders() + try? await APIClient.shared.deactivateChatGPT() + await FloatingBarUsageLimiter.shared.fetchPlan() + } + + private static func finalizeEnrollment(snapshot: CodexAuthService.AuthSnapshot) async throws { + CodexAuthService.markEnrolled() + await CodexProxyService.shared.ensureRunning() + guard CodexProxyService.shared.isRunning else { + CodexAuthService.clearEnrollment() + throw EnrollmentError.proxyFailed(CodexProxyService.shared.lastError ?? 
"unknown") + } + + let fingerprint = CodexAuthService.enrollmentFingerprint(for: snapshot.accountId) + do { + try await APIClient.shared.activateChatGPT(fingerprint: fingerprint) + } catch { + CodexAuthService.clearEnrollment() + await CodexProxyService.shared.stop() + throw EnrollmentError.backendFailed(error.localizedDescription) + } + + await CodexProviderBootstrap.applyIfNeeded() + await FloatingBarUsageLimiter.shared.fetchPlan() + AppState.current?.isPaywalled = false + } +} diff --git a/desktop/Desktop/Sources/CodexProviderBootstrap.swift b/desktop/Desktop/Sources/CodexProviderBootstrap.swift new file mode 100644 index 00000000000..2c6e2d5814b --- /dev/null +++ b/desktop/Desktop/Sources/CodexProviderBootstrap.swift @@ -0,0 +1,52 @@ +import Foundation + +/// Wires ChatGPT/Codex loopback proxy into hybrid daemon provider settings when enrolled. +enum CodexProviderBootstrap { + + static func codexProviderObject(model: String? = nil) -> [String: LocalDaemonSettingUpdateValue] { + [ + "kind": "openai_compatible", + "base_url": .string(CodexProxyEndpoints.baseURL), + "model": .string(model ?? CodexAuthService.preferredModel), + "api_key": .string(""), + ] + } + + /// After successful ChatGPT connect: start proxy and set chat + ai providers (not embeddings). 
+ @MainActor + static func applyIfNeeded() async { + guard CodexAuthService.isActive else { return } + await CodexProxyService.shared.ensureRunning() + guard CodexProxyService.shared.isRunning else { return } + + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return + } + + do { + let provider = codexProviderObject() + let updates: [String: LocalDaemonSettingUpdateValue] = [ + "chat_provider": .object(provider), + "ai_provider": .object(provider), + ] + _ = try await APIClient.shared.updateSelectedBackendSettings(updates) + log("CodexProviderBootstrap: applied chat_provider + ai_provider → Codex loopback") + } catch { + logError("CodexProviderBootstrap: failed to update daemon settings", error: error) + } + } + + /// Clear Codex provider keys from daemon (logout). + @MainActor + static func clearDaemonProviders() async { + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { return } + do { + _ = try await APIClient.shared.updateSelectedBackendSettings([ + "chat_provider": .null, + "ai_provider": .null, + ]) + } catch { + logError("CodexProviderBootstrap: failed to clear providers", error: error) + } + } +} diff --git a/desktop/Desktop/Sources/CodexProxyService.swift b/desktop/Desktop/Sources/CodexProxyService.swift new file mode 100644 index 00000000000..6eecf3570af --- /dev/null +++ b/desktop/Desktop/Sources/CodexProxyService.swift @@ -0,0 +1,186 @@ +import Foundation + +/// Loopback Codex proxy endpoints (nonisolated constants). 
+enum CodexProxyEndpoints { + static let defaultPort: Int = 10531 + + static var baseURL: String { + let port: Int + if let raw = ProcessInfo.processInfo.environment["OMI_CODEX_PROXY_PORT"], + let value = Int(raw), value > 0 + { + port = value + } else { + port = defaultPort + } + return "http://127.0.0.1:\(port)/v1" + } + + static var healthURL: String { + baseURL.replacingOccurrences(of: "/v1", with: "") + "/health" + } +} + +/// Manages the loopback Codex OpenAI-compatible proxy (`desktop/codex-proxy`). +@MainActor +final class CodexProxyService: ObservableObject { + static let shared = CodexProxyService() + + static var defaultBaseURL: String { CodexProxyEndpoints.baseURL } + + private static var port: Int { + if let raw = ProcessInfo.processInfo.environment["OMI_CODEX_PROXY_PORT"], + let value = Int(raw), value > 0 + { + return value + } + return CodexProxyEndpoints.defaultPort + } + + @Published private(set) var isRunning = false + @Published private(set) var lastError: String? + + private var process: Process? + private var healthTask: Task? + + private init() {} + + /// Start proxy when ChatGPT tier is active. Idempotent. + func ensureRunning() async { + guard CodexAuthService.isActive else { + await stop() + return + } + if isRunning, await healthCheck() { return } + await stop() + guard let executable = resolveExecutableURL() else { + lastError = + "Codex proxy binary not found. Build with: cd desktop/codex-proxy && cargo build --release" + isRunning = false + return + } + guard CodexAuthService.loadSnapshot() != nil else { + lastError = "Sign in with ChatGPT first (run Codex login or connect in Settings)." 
+ isRunning = false + return + } + + let proc = Process() + proc.executableURL = executable + proc.arguments = [] + var env = ProcessInfo.processInfo.environment + env["OMI_CODEX_PROXY_PORT"] = String(Self.port) + proc.environment = env + proc.standardOutput = FileHandle.nullDevice + proc.standardError = FileHandle.nullDevice + + do { + try proc.run() + process = proc + for _ in 0..<30 { + try? await Task.sleep(nanoseconds: 100_000_000) + if await healthCheck() { + isRunning = true + lastError = nil + startHealthMonitor() + log("CodexProxyService: proxy running at \(Self.defaultBaseURL)") + return + } + } + lastError = "Codex proxy failed to start (health check timeout)." + await stop() + } catch { + lastError = error.localizedDescription + await stop() + } + } + + func stop() async { + healthTask?.cancel() + healthTask = nil + if let process, process.isRunning { + process.terminate() + } + process = nil + isRunning = false + } + + private func startHealthMonitor() { + healthTask?.cancel() + healthTask = Task { + while !Task.isCancelled { + try? await Task.sleep(nanoseconds: 15_000_000_000) + guard CodexAuthService.isActive else { + await stop() + return + } + if !(await healthCheck()) { + log("CodexProxyService: health failed — restarting") + await ensureRunning() + return + } + } + } + } + + private func healthCheck() async -> Bool { + guard let url = URL(string: CodexProxyEndpoints.healthURL) else { + return false + } + var request = URLRequest(url: url) + request.timeoutInterval = 2 + do { + let (_, response) = try await URLSession.shared.data(for: request) + return (response as? HTTPURLResponse)?.statusCode == 200 + } catch { + return false + } + } + + private func resolveExecutableURL() -> URL? 
{ + let names = ["omi-codex-proxy", "codex-proxy"] + if let resource = Bundle.main.resourceURL { + for name in names { + let candidate = resource.appendingPathComponent(name) + if FileManager.default.isExecutableFile(atPath: candidate.path) { + return candidate + } + } + } + let repoRelative = URL(fileURLWithPath: #filePath) + .deletingLastPathComponent() + .deletingLastPathComponent() + .deletingLastPathComponent() + .appendingPathComponent("codex-proxy/target/release/omi-codex-proxy") + if FileManager.default.isExecutableFile(atPath: repoRelative.path) { + return repoRelative + } + for name in names { + if let path = which(name) { + return URL(fileURLWithPath: path) + } + } + return nil + } + + private func which(_ name: String) -> String? { + let proc = Process() + proc.executableURL = URL(fileURLWithPath: "/usr/bin/which") + proc.arguments = [name] + let pipe = Pipe() + proc.standardOutput = pipe + proc.standardError = FileHandle.nullDevice + do { + try proc.run() + proc.waitUntilExit() + guard proc.terminationStatus == 0 else { return nil } + let data = pipe.fileHandleForReading.readDataToEndOfFile() + let path = String(data: data, encoding: .utf8)? 
+ .trimmingCharacters(in: .whitespacesAndNewlines) + guard let path, !path.isEmpty else { return nil } + return path + } catch { + return nil + } + } +} diff --git a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift index f8ec41c38a9..54077175fb9 100644 --- a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift +++ b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift @@ -13,6 +13,7 @@ enum DesktopBackendEnvironment { case directSTT case directChat case directEmbeddings + case localMemoryWiki case optionalCloudSTT case optionalCloudChat case managedAgentVM @@ -201,6 +202,8 @@ enum DesktopBackendEnvironment { return isAffirmative(currentEnvironmentValue("OMI_HYBRID_DIRECT_CHAT_ENABLED")) case .directEmbeddings: return isAffirmative(currentEnvironmentValue("OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED")) + case .localMemoryWiki: + return CodexAuthService.isActive || !MemorySearchMode.usesVectorEmbeddings case .optionalCloudSTT: return isAffirmative(currentEnvironmentValue("OMI_HYBRID_OPTIONAL_CLOUD_STT")) case .optionalCloudChat: @@ -234,6 +237,8 @@ enum DesktopBackendEnvironment { case .directEmbeddings: return "Direct local embeddings require OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1 and an embedding_provider in hybrid settings." + case .localMemoryWiki: + return "Local memory wiki (FTS) is off while vector embeddings mode is enabled." case .optionalCloudSTT: return "Optional cloud speech-to-text is off. Set OMI_HYBRID_OPTIONAL_CLOUD_STT=1 to allow hosted Listen." case .optionalCloudChat: diff --git a/desktop/Desktop/Sources/HybridChatClient.swift b/desktop/Desktop/Sources/HybridChatClient.swift index 5b4a830827a..7dca650bf67 100644 --- a/desktop/Desktop/Sources/HybridChatClient.swift +++ b/desktop/Desktop/Sources/HybridChatClient.swift @@ -20,6 +20,7 @@ enum HybridChatClient { case notConfigured case invalidSettings case invalidResponse + case providerError(String) var errorDescription: String? 
{ switch self { @@ -30,11 +31,16 @@ enum HybridChatClient { return "chat_provider settings are invalid." case .invalidResponse: return "Chat provider returned an unexpected response." + case .providerError(let message): + return message } } } static func isEnabled() -> Bool { + if CodexAuthService.isActive { + return true + } guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { return false } @@ -44,8 +50,15 @@ enum HybridChatClient { return true } - /// Resolves chat_provider → ai_provider / provider (matches HybridLLMClient). + /// Resolves Codex → chat_provider → ai_provider / provider (matches HybridLLMClient). static func resolveEffectiveChatConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? { + if let codex = HybridLLMClient.codexProviderConfig() { + return ProviderConfig( + baseURL: codex.baseURL, + model: codex.model, + apiKey: codex.apiKey + ) + } if let chat = loadProviderConfig(from: settings, key: "chat_provider") { return chat } @@ -110,6 +123,9 @@ enum HybridChatClient { conversationMessages: [(role: String, text: String)], userMessage: String ) async throws -> CompletionResult { + if CodexAuthService.isActive { + await CodexProxyService.shared.ensureRunning() + } let settings = try await APIClient.shared.getSelectedBackendSettings() return try await complete( systemPrompt: systemPrompt, @@ -181,9 +197,12 @@ enum HybridChatClient { request.httpBody = try JSONEncoder().encode(payload) let (data, response) = try await URLSession.shared.data(for: request) - guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else { + guard let http = response as? HTTPURLResponse else { throw ClientError.invalidResponse } + guard (200..<300).contains(http.statusCode) else { + throw ClientError.providerError(parseProviderErrorBody(data: data, statusCode: http.statusCode)) + } guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], let choices = json["choices"] as? 
[[String: Any]], let first = choices.first, @@ -218,4 +237,25 @@ enum HybridChatClient { outputTokens: outputTokens ) } + + private static func parseProviderErrorBody(data: Data, statusCode: Int) -> String { + if let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] { + if let error = json["error"] as? [String: Any], + let message = error["message"] as? String, + !message.isEmpty + { + return message + } + if let detail = json["detail"] as? String, !detail.isEmpty { + return detail + } + } + let snippet = + String(data: data.prefix(400), encoding: .utf8)? + .trimmingCharacters(in: .whitespacesAndNewlines) ?? "" + if snippet.isEmpty { + return "Chat provider request failed (HTTP \(statusCode))." + } + return "Chat provider request failed (HTTP \(statusCode)): \(snippet)" + } } diff --git a/desktop/Desktop/Sources/HybridLLMClient.swift b/desktop/Desktop/Sources/HybridLLMClient.swift index e741bcc7593..f820d6dda43 100644 --- a/desktop/Desktop/Sources/HybridLLMClient.swift +++ b/desktop/Desktop/Sources/HybridLLMClient.swift @@ -12,6 +12,9 @@ actor HybridDaemonSettingsCache { private let ttlSeconds: TimeInterval = 45 func settings() async throws -> [LocalDaemonSetting] { + if CodexAuthService.isActive { + return [] + } guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { return [] } @@ -66,8 +69,21 @@ enum HybridLLMClient { return loadOpenAICompatibleProvider(forKeys: ["vision_provider"], settings: settings) } - /// Primary chat routing for assistants: prefers chat_provider, then legacy ai_provider / provider. + /// ChatGPT / Codex subscription loopback proxy (when enrolled). + static func codexProviderConfig() -> ProviderConfig? { + guard CodexAuthService.isActive else { return nil } + return ProviderConfig( + baseURL: CodexProxyEndpoints.baseURL, + model: CodexAuthService.preferredModel, + apiKey: "" + ) + } + + /// Primary chat routing for assistants: Codex → chat_provider → ai_provider → BYOK OpenAI. 
static func resolveEffectiveChatConfig(settings: [LocalDaemonSetting]) -> ProviderConfig? { + if let codex = codexProviderConfig() { + return codex + } if let c = loadOpenAICompatibleProvider(forKeys: ["chat_provider"], settings: settings) { return c } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 17cadad10ac..7f820a4b7a4 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -395,6 +395,8 @@ struct SettingsContentView: View { @AppStorage("dev_deepgram_api_key") private var devDeepgramKey: String = "" @State private var byokKeyStatuses: [BYOKProvider: BYOKValidator.Status] = [:] @State private var byokActivationError: String? + @State private var codexEnrollmentError: String? + @State private var codexEnrollmentBusy = false init( appState: AppState, @@ -3292,6 +3294,8 @@ struct SettingsContentView: View { preferencesSubsection advancedCategoryHeader(title: "Troubleshooting", icon: "wrench.and.screwdriver") troubleshootingSubsection + advancedCategoryHeader(title: "ChatGPT plan", icon: "bubble.left.and.bubble.right") + chatGPTPlanSubsection advancedCategoryHeader(title: "Developer API Keys", icon: "key") developerKeysSubsection @@ -5374,6 +5378,93 @@ struct SettingsContentView: View { return formatter } + // MARK: - ChatGPT / Codex plan + + private var chatGPTPlanSubsection: some View { + VStack(spacing: 20) { + settingsCard(settingId: "advanced.chatgpt.info") { + VStack(alignment: .leading, spacing: 10) { + HStack(spacing: 10) { + Image(systemName: CodexAuthService.isActive ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark") + .foregroundColor(CodexAuthService.isActive ? OmiColors.success : OmiColors.textTertiary) + Text(CodexAuthService.isActive ? 
"ChatGPT plan active" : "Use your ChatGPT subscription") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + } + Text( + CodexAuthService.isActive + ? "LLM features use your ChatGPT/Codex subscription via a local proxy on this Mac. Memory search uses local wiki + keyword search (no embedding API). Live transcription is unchanged." + : "Sign in with ChatGPT to route chat and proactive AI through your subscription via a local proxy on this Mac. A Terminal window opens for Codex login — complete sign-in there and Omi will connect automatically. Unofficial community integration — use at your own risk per OpenAI terms. Tokens stay on this Mac." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + if CodexProxyService.shared.isRunning { + Text("Proxy: \(CodexProxyService.defaultBaseURL)") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } + } + } + + if let codexEnrollmentError { + settingsCard(settingId: "advanced.chatgpt.error") { + HStack(spacing: 10) { + Image(systemName: "exclamationmark.triangle.fill") + .foregroundColor(OmiColors.warning) + Text(codexEnrollmentError) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + Spacer() + } + } + } + + HStack(spacing: 12) { + if CodexAuthService.isActive { + Button(action: disconnectChatGPTPlan) { + Text("Disconnect") + .frame(maxWidth: .infinity) + } + .buttonStyle(.bordered) + } else { + Button(action: signInChatGPTPlan) { + Text(codexEnrollmentBusy ? 
"Signing in…" : "Sign in with ChatGPT") + .frame(maxWidth: .infinity) + } + .buttonStyle(.borderedProminent) + .disabled(codexEnrollmentBusy) + } + } + } + } + + private func signInChatGPTPlan() { + guard !codexEnrollmentBusy else { return } + codexEnrollmentBusy = true + codexEnrollmentError = nil + Task { + do { + try await CodexEnrollmentCoordinator.connect() + await MainActor.run { + codexEnrollmentBusy = false + codexEnrollmentError = nil + } + } catch { + await MainActor.run { + codexEnrollmentBusy = false + codexEnrollmentError = error.localizedDescription + } + } + } + } + + private func disconnectChatGPTPlan() { + Task { + await CodexEnrollmentCoordinator.disconnect() + await MainActor.run { codexEnrollmentError = nil } + } + } + // MARK: - Developer API Keys Subsection private var developerKeysSubsection: some View { diff --git a/desktop/Desktop/Sources/MemoryWikiStorage.swift b/desktop/Desktop/Sources/MemoryWikiStorage.swift new file mode 100644 index 00000000000..dde7bcf9f8e --- /dev/null +++ b/desktop/Desktop/Sources/MemoryWikiStorage.swift @@ -0,0 +1,181 @@ +import Foundation +import GRDB + +// MARK: - Memory wiki page + +struct MemoryWikiPageRecord: Codable, FetchableRecord, PersistableRecord, Identifiable { + var id: Int64? + var slug: String + var title: String + var body: String + var tagsJson: String? + var linksJson: String? + var category: String + var sourceType: String? + var sourceId: String? + var createdAt: Date + var updatedAt: Date + + static let databaseTableName = "memory_pages" +} + +struct MemoryWikiSearchHit: Identifiable, Equatable { + let id: Int64 + let slug: String + let title: String + let snippet: String + let category: String + let rank: Double +} + +/// Local structured wiki + FTS5 search (no embedding API). +actor MemoryWikiStorage { + static let shared = MemoryWikiStorage() + + private var dbQueue: DatabasePool? 
+ + private init() {} + + func invalidateCache() { + dbQueue = nil + } + + private func ensureDB() async throws -> DatabasePool { + if let dbQueue { return dbQueue } + try await RewindDatabase.shared.initialize() + guard let queue = await RewindDatabase.shared.getDatabaseQueue() else { + throw MemoryWikiError.databaseNotInitialized + } + dbQueue = queue + return queue + } + + func upsertPage( + slug: String, + title: String, + body: String, + tags: [String] = [], + links: [String] = [], + category: String = "system", + sourceType: String? = nil, + sourceId: String? = nil + ) async throws -> Int64 { + let db = try await ensureDB() + let now = Date() + let tagsJson = tags.isEmpty ? nil : String(data: try JSONEncoder().encode(tags), encoding: .utf8) + let linksJson = links.isEmpty ? nil : String(data: try JSONEncoder().encode(links), encoding: .utf8) + + return try await db.write { database in + if let existing = try MemoryWikiPageRecord + .filter(Column("slug") == slug) + .fetchOne(database) + { + var row = existing + row.title = title + row.body = body + row.tagsJson = tagsJson + row.linksJson = linksJson + row.category = category + row.sourceType = sourceType + row.sourceId = sourceId + row.updatedAt = now + try row.update(database) + return existing.id ?? 0 + } + var row = MemoryWikiPageRecord( + id: nil, + slug: slug, + title: title, + body: body, + tagsJson: tagsJson, + linksJson: linksJson, + category: category, + sourceType: sourceType, + sourceId: sourceId, + createdAt: now, + updatedAt: now + ) + try row.insert(database) + return row.id ?? 
0 + } + } + + func search(query: String, limit: Int = 20) async throws -> [MemoryWikiSearchHit] { + let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { return [] } + + let words = trimmed.components(separatedBy: .whitespaces) + .map { $0.filter { $0.isLetter || $0.isNumber } } + .filter { $0.count >= 2 } + guard !words.isEmpty else { return [] } + let ftsQuery = words.map { "\($0)*" }.joined(separator: " OR ") + + let db = try await ensureDB() + return try await db.read { database in + let rows = try Row.fetchAll( + database, + sql: """ + SELECT memory_pages.id, memory_pages.slug, memory_pages.title, memory_pages.category, + snippet(memory_pages_fts, 1, '', '', '…', 12) AS snippet, + bm25(memory_pages_fts) AS rank + FROM memory_pages_fts + JOIN memory_pages ON memory_pages.id = memory_pages_fts.rowid + WHERE memory_pages_fts MATCH ? + ORDER BY rank + LIMIT ? + """, + arguments: [ftsQuery, limit] + ) + return rows.compactMap { row -> MemoryWikiSearchHit? in + guard let id: Int64 = row["id"], + let slug: String = row["slug"], + let title: String = row["title"], + let category: String = row["category"] + else { return nil } + let snippet: String = row["snippet"] ?? title + let rank: Double = row["rank"] ?? 0 + return MemoryWikiSearchHit( + id: id, slug: slug, title: title, snippet: snippet, category: category, rank: rank + ) + } + } + } + + static func slugify(_ title: String) -> String { + let lowered = title.lowercased() + let allowed = lowered.map { char -> Character in + if char.isLetter || char.isNumber { return char } + if char == " " || char == "-" || char == "_" { return "-" } + return "-" + } + let collapsed = String(allowed) + .replacingOccurrences(of: "--+", with: "-", options: .regularExpression) + .trimmingCharacters(in: CharacterSet(charactersIn: "-")) + return collapsed.isEmpty ? 
"page-\(UUID().uuidString.prefix(8))" : collapsed + } +} + +enum MemoryWikiError: Error { + case databaseNotInitialized +} + +/// Feature flag: local wiki search instead of vector embeddings. +enum MemorySearchMode { + case localWiki + case vectorEmbeddings + + static var current: MemorySearchMode { + if CodexAuthService.isActive { + return .localWiki + } + if HybridEmbeddingClient.isEnabled() { + return .vectorEmbeddings + } + let raw = UserDefaults.standard.string(forKey: "memory_search_mode") ?? "local_wiki" + return raw == "vector" ? .vectorEmbeddings : .localWiki + } + + static var usesVectorEmbeddings: Bool { + current == .vectorEmbeddings + } +} diff --git a/desktop/Desktop/Sources/OmiApp.swift b/desktop/Desktop/Sources/OmiApp.swift index cd86871e01d..744cbf433f1 100644 --- a/desktop/Desktop/Sources/OmiApp.swift +++ b/desktop/Desktop/Sources/OmiApp.swift @@ -107,6 +107,9 @@ struct OMIApp: App { .withFontScaling() .onAppear { log("OmiApp: Main window content appeared (mode: \(Self.launchMode.rawValue))") + if CodexAuthService.isActive { + Task { await CodexProxyService.shared.ensureRunning() } + } } } .windowStyle(.titleBar) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift index 82c2e838607..bd2f3402a75 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift @@ -245,6 +245,21 @@ actor MemoryAssistant: ProactiveAssistant { do { let inserted = try await MemoryStorage.shared.insertLocalMemory(record) log("Memory: Saved to SQLite (id: \(inserted.id ?? -1))") + if !MemorySearchMode.usesVectorEmbeddings { + let title = memory.content.prefix(80).trimmingCharacters(in: .whitespacesAndNewlines) + let slug = MemoryWikiStorage.slugify(String(title)) + Task { + _ = try? 
await MemoryWikiStorage.shared.upsertPage( + slug: slug, + title: String(title), + body: memory.content, + tags: ["memory", category], + category: category, + sourceType: "memory", + sourceId: inserted.id.map(String.init) + ) + } + } return inserted } catch { logError("Memory: Failed to save to SQLite", error: error) diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 90357260f76..dff5d22be78 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -177,8 +177,12 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Embedding Lifecycle - /// Load embedding index and kick off backfill + /// Load embedding index and kick off backfill (skipped when using local wiki search). private func initializeEmbeddings() async { + guard MemorySearchMode.usesVectorEmbeddings else { + log("Task: Skipping embedding index — local wiki / FTS memory search mode") + return + } await EmbeddingService.shared.loadIndex() // Backfill in background Task { @@ -367,10 +371,16 @@ actor TaskAssistant: ProactiveAssistant { windowTitle: windowTitle ) - // Generate embedding for new staged task in background + // Generate embedding or wiki index for new staged task in background if let recordId = extractionRecord?.id { - Task { - await self.generateEmbeddingForTask(id: recordId, text: task.title) + if MemorySearchMode.usesVectorEmbeddings { + Task { + await self.generateEmbeddingForTask(id: recordId, text: task.title) + } + } else { + Task { + await self.indexStagedTaskInWiki(id: recordId, description: task.title) + } } } @@ -409,6 +419,23 @@ actor TaskAssistant: ProactiveAssistant { } } + private func indexStagedTaskInWiki(id: Int64, description: String) async { + let slug = "task-\(id)" + do { + _ = try 
await MemoryWikiStorage.shared.upsertPage( + slug: slug, + title: description, + body: description, + tags: ["task", "staged"], + category: "task", + sourceType: "staged_task", + sourceId: String(id) + ) + } catch { + logError("Task: Failed to index staged task in memory wiki", error: error) + } + } + /// Save extracted task to staged_tasks SQLite table private func saveTaskToSQLite( task: ExtractedTask, @@ -1249,8 +1276,12 @@ actor TaskAssistant: ProactiveAssistant { ) } - /// Execute vector similarity search + /// Execute vector similarity search (or FTS + wiki when embeddings disabled). private func executeVectorSearch(query: String) async -> [TaskSearchResult] { + guard MemorySearchMode.usesVectorEmbeddings else { + return await executeKeywordAndWikiSearch(query: query) + } + var results: [TaskSearchResult] = [] do { @@ -1296,6 +1327,28 @@ actor TaskAssistant: ProactiveAssistant { return results.sorted { ($0.similarity ?? 0) > ($1.similarity ?? 0) } } + /// Keyword FTS across tasks plus memory wiki (used when vector embeddings are off). 
+ private func executeKeywordAndWikiSearch(query: String) async -> [TaskSearchResult] { + var results = await executeKeywordSearch(query: query) + do { + let wikiHits = try await MemoryWikiStorage.shared.search(query: query, limit: 8) + for hit in wikiHits { + results.append( + TaskSearchResult( + id: hit.id, + description: "\(hit.title): \(hit.snippet)", + status: "active", + similarity: nil, + matchType: "wiki_fts", + relevanceScore: nil + )) + } + } catch { + logError("Task: Wiki search failed", error: error) + } + return results + } + /// Execute FTS5 keyword search (searches both action_items and staged_tasks) private func executeKeywordSearch(query: String) async -> [TaskSearchResult] { var results: [TaskSearchResult] = [] diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index 21e7ae002d7..7cb9d89b56f 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -264,7 +264,9 @@ actor GeminiClient { init(apiKey: String? = nil, model: String = ModelQoS.Gemini.proactive) throws { // BREAKING CHANGE (issue #5861): apiKey parameter is ignored for cloud proxy mode. 
self.model = model - if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + || CodexAuthService.isActive + { self.transport = .hybridOpenAICompatible return } @@ -294,6 +296,9 @@ actor GeminiClient { jsonMode: Bool, timeout: TimeInterval = 300 ) async throws -> String { + if CodexAuthService.isActive { + await CodexProxyService.shared.ensureRunning() + } let settings = try await HybridDaemonSettingsCache.shared.settings() guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { throw GeminiClientError.missingAPIKey @@ -319,6 +324,9 @@ actor GeminiClient { jsonMode: Bool, timeout: TimeInterval = 300 ) async throws -> String { + if CodexAuthService.isActive { + await CodexProxyService.shared.ensureRunning() + } let settings = try await HybridDaemonSettingsCache.shared.settings() guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { throw GeminiClientError.missingAPIKey diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index aa1094c6e7b..6ff7977da6c 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -2460,8 +2460,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention usageLimiter.recordQuery() } - let localDaemon = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon - let mayUseHybridDirectChat = localDaemon && HybridChatClient.isEnabled() + let mayUseHybridDirectChat = HybridChatClient.isEnabled() // Ensure Claude / ACP bridge when not using hybrid direct chat. Hybrid path may // skip the bridge until multimodal attachments require ACP. 
diff --git a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift index d246d6ab2f0..c741bcf2fa9 100644 --- a/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift +++ b/desktop/Desktop/Sources/Rewind/Core/RewindDatabase.swift @@ -2176,6 +2176,58 @@ actor RewindDatabase { print("[RewindDatabase] Migration: Added embedding_model/embedding_dim metadata for hybrid embedders") } + migrator.registerMigration("createMemoryWikiPages") { db in + try db.create(table: "memory_pages") { t in + t.autoIncrementedPrimaryKey("id") + t.column("slug", .text).notNull().unique() + t.column("title", .text).notNull() + t.column("body", .text).notNull() + t.column("tagsJson", .text) + t.column("linksJson", .text) + t.column("category", .text).notNull().defaults(to: "system") + t.column("sourceType", .text) + t.column("sourceId", .text) + t.column("createdAt", .datetime).notNull() + t.column("updatedAt", .datetime).notNull() + } + try db.create(index: "idx_memory_pages_slug", on: "memory_pages", columns: ["slug"]) + try db.create(index: "idx_memory_pages_updated", on: "memory_pages", columns: ["updatedAt"]) + + try db.execute(sql: """ + CREATE VIRTUAL TABLE memory_pages_fts USING fts5( + title, + body, + tagsJson, + content='memory_pages', + content_rowid='id', + tokenize='unicode61' + ) + """) + + try db.execute(sql: """ + CREATE TRIGGER memory_pages_fts_ai AFTER INSERT ON memory_pages BEGIN + INSERT INTO memory_pages_fts(rowid, title, body, tagsJson) + VALUES (new.id, new.title, new.body, COALESCE(new.tagsJson, '')); + END + """) + + try db.execute(sql: """ + CREATE TRIGGER memory_pages_fts_ad AFTER DELETE ON memory_pages BEGIN + INSERT INTO memory_pages_fts(memory_pages_fts, rowid, title, body, tagsJson) + VALUES ('delete', old.id, old.title, old.body, COALESCE(old.tagsJson, '')); + END + """) + + try db.execute(sql: """ + CREATE TRIGGER memory_pages_fts_au AFTER UPDATE ON memory_pages BEGIN + INSERT INTO 
memory_pages_fts(memory_pages_fts, rowid, title, body, tagsJson) + VALUES ('delete', old.id, old.title, old.body, COALESCE(old.tagsJson, '')); + INSERT INTO memory_pages_fts(rowid, title, body, tagsJson) + VALUES (new.id, new.title, new.body, COALESCE(new.tagsJson, '')); + END + """) + } + try migrator.migrate(queue) } diff --git a/desktop/Desktop/Tests/CodexAuthServiceTests.swift b/desktop/Desktop/Tests/CodexAuthServiceTests.swift new file mode 100644 index 00000000000..d27aac4e249 --- /dev/null +++ b/desktop/Desktop/Tests/CodexAuthServiceTests.swift @@ -0,0 +1,126 @@ +import XCTest + +@testable import Omi_Computer + +final class CodexAuthServiceTests: XCTestCase { + + override func setUp() { + super.setUp() + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + UserDefaults.standard.removeObject(forKey: "codex_preferred_model") + } + + override func tearDown() { + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + UserDefaults.standard.removeObject(forKey: "codex_preferred_model") + super.tearDown() + } + + func testEnrollmentFingerprintIsStableHex() { + let fp = CodexAuthService.enrollmentFingerprint(for: "account-123") + XCTAssertEqual(fp.count, 64) + XCTAssertTrue(fp.allSatisfy { $0.isHexDigit }) + XCTAssertEqual(fp, CodexAuthService.enrollmentFingerprint(for: "account-123")) + } + + func testIsActiveRequiresEnrollmentAndSnapshot() { + let tempAuth = makeTempCodexHomeWithoutAuth() + defer { tempAuth.cleanup() } + + XCTAssertFalse(CodexAuthService.isActive) + CodexAuthService.markEnrolled() + XCTAssertFalse(CodexAuthService.isActive) + } + + func testLoadSnapshotParsesNestedTokensFormat() throws { + let dir = FileManager.default.temporaryDirectory + .appendingPathComponent("codex-auth-test-\(UUID().uuidString)") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + defer { try? 
FileManager.default.removeItem(at: dir) } + + let authURL = dir.appendingPathComponent("auth.json") + let payload = """ + { + "auth_mode": "chatgpt", + "tokens": { + "access_token": "test-access", + "refresh_token": "test-refresh", + "account_id": "acct-nested" + } + } + """ + try payload.write(to: authURL, atomically: true, encoding: .utf8) + + let previous = ProcessInfo.processInfo.environment["CODEX_HOME"] + setenv("CODEX_HOME", dir.path, 1) + defer { + if let previous { + setenv("CODEX_HOME", previous, 1) + } else { + unsetenv("CODEX_HOME") + } + } + + let snap = CodexAuthService.loadSnapshot() + XCTAssertEqual(snap?.accessToken, "test-access") + XCTAssertEqual(snap?.accountId, "acct-nested") + XCTAssertEqual(snap?.refreshToken, "test-refresh") + } + + func testMemorySearchModeDefaultsToWikiWhenCodexEnrolled() { + CodexAuthService.markEnrolled() + XCTAssertEqual(MemorySearchMode.current, .localWiki) + } +} + +final class CodexProxyConfigTests: XCTestCase { + + override func setUp() { + super.setUp() + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + } + + override func tearDown() { + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + super.tearDown() + } + + func testHybridLLMUsesDaemonSettingsWithoutAuthSnapshot() throws { + let tempAuth = makeTempCodexHomeWithoutAuth() + defer { tempAuth.cleanup() } + + CodexAuthService.markEnrolled() + let payload = """ + [{"key":"chat_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://chat.local/v1\\",\\"model\\":\\"m-chat\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + let settings = try decoder.decode([LocalDaemonSetting].self, from: Data(payload.utf8)) + let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) + XCTAssertEqual(config?.baseURL, "http://chat.local/v1") + XCTAssertNil(HybridLLMClient.codexProviderConfig()) + } +} + +private struct TempCodexHome { 
+ let path: String + let previous: String? + + func cleanup() { + if let previous { + setenv("CODEX_HOME", previous, 1) + } else { + unsetenv("CODEX_HOME") + } + try? FileManager.default.removeItem(atPath: path) + } +} + +private func makeTempCodexHomeWithoutAuth() -> TempCodexHome { + let dir = FileManager.default.temporaryDirectory + .appendingPathComponent("codex-auth-empty-\(UUID().uuidString)") + try? FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + let previous = ProcessInfo.processInfo.environment["CODEX_HOME"] + setenv("CODEX_HOME", dir.path, 1) + return TempCodexHome(path: dir.path, previous: previous) +} diff --git a/desktop/codex-proxy/Cargo.lock b/desktop/codex-proxy/Cargo.lock new file mode 100644 index 00000000000..f1acce4a8e4 --- /dev/null +++ b/desktop/codex-proxy/Cargo.lock @@ -0,0 +1,1443 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + 
"http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "r-efi", + "wasip2", + "wasm-bindgen", +] + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" 
+dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "js-sys" +version = "0.3.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" +dependencies = [ + "cfg-if", + "futures-util", + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "omi-codex-proxy" +version = "0.1.0" +dependencies = [ + "axum", + "reqwest", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinyvec" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" +dependencies = [ + 
"tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.52.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "url", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = 
"tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] 
+name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4b572dff8bcf38bad0fa19729c89bb5748b2b9b1d8be70cf90df697e3a8f32aa" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 
0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" 
+version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/desktop/codex-proxy/Cargo.toml b/desktop/codex-proxy/Cargo.toml new file mode 100644 index 00000000000..a435f949738 --- /dev/null +++ b/desktop/codex-proxy/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "omi-codex-proxy" +version = "0.1.0" +edition = "2021" +description = "Minimal loopback proxy: OpenAI chat/completions ↔ ChatGPT Codex /responses" +license = "MIT" + +[dependencies] +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "net"] } diff --git a/desktop/codex-proxy/README.md b/desktop/codex-proxy/README.md new file mode 100644 index 00000000000..a21b24d7ad1 --- /dev/null +++ 
b/desktop/codex-proxy/README.md @@ -0,0 +1,61 @@ +# omi-codex-proxy + +Small **localhost-only** HTTP proxy that turns `POST /v1/chat/completions` (OpenAI-style JSON) into `POST https://chatgpt.com/backend-api/codex/responses`, using OAuth tokens stored in **`~/.codex/auth.json`**. + +## Prerequisites + +Rust toolchain (`cargo`). HTTPS uses **rustls** (no OpenSSL dependency). + +Create `~/.codex/auth.json` (minimal fields): + +```json +{ + "access_token": "<access_token>", + "account_id": "<account_id>", + "refresh_token": "<refresh_token>" +} +``` + +## Run + +```bash +cd desktop/codex-proxy +# Optional: defaults to 10531 when unset +export OMI_CODEX_PROXY_PORT=10531 +cargo run --release +``` + +## Endpoints + +- `GET /health` → `200 OK` (`ok`). +- `POST /v1/chat/completions` → upstream Codex `responses`; **non-stream only** (`"stream": true` returns `501`). + +Forwarded headers: + +- `Authorization: Bearer <access_token>` +- `ChatGPT-Account-Id: <account_id>` +- `Content-Type: application/json` + +On **`401`** from Codex (and if `refresh_token` is present), the proxy refreshes via `POST https://auth.openai.com/oauth/token` (`client_id=app_EMoamEEZ73f0CkXaXp7hrann`, `grant_type=refresh_token`), persists updated tokens back to `auth.json`, and retries once. + +## Request / response mapping (basic) + +**OpenAI → Codex** + +- Copies `model` and maps OpenAI chat `messages` into a Codex Responses-style `input`: + - String `content` becomes `[{ "type": "input_text", "text": "..." }]`. + - Array `content` is passed through. + +**Codex → OpenAI** + +Parses common Responses payloads: **`output[].content[]`**, looking for **`output_text` / `text`** (or string `content`). If the upstream body already resembles `choices`, it is echoed. + +If your Codex revision uses a slightly different envelope, extend `extract_assistant_text` / `codex_payload_from_openai_chat` in `src/main.rs`. 
+ +## Example curl + +```bash +curl -sS http://127.0.0.1:${OMI_CODEX_PROXY_PORT:-10531}/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{"model":"gpt-5","messages":[{"role":"user","content":"Say hi in one word."}]}' +``` diff --git a/desktop/codex-proxy/e2e_smoke.sh b/desktop/codex-proxy/e2e_smoke.sh new file mode 100755 index 00000000000..d1b53a1be8a --- /dev/null +++ b/desktop/codex-proxy/e2e_smoke.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash +# End-to-end smoke tests for omi-codex-proxy (requires valid ~/.codex/auth.json). +set -euo pipefail + +PORT="${OMI_CODEX_PROXY_PORT:-10531}" +BASE="http://127.0.0.1:${PORT}" +PROXY_BIN="$(cd "$(dirname "$0")" && pwd)/target/release/omi-codex-proxy" + +if ! curl -fsS "${BASE}/health" >/dev/null 2>&1; then + if [[ ! -x "$PROXY_BIN" ]]; then + echo "Building proxy..." + (cd "$(dirname "$0")" && cargo build --release) + fi + echo "Starting proxy on ${PORT}..." + "$PROXY_BIN" & + PROXY_PID=$! + trap 'kill "$PROXY_PID" 2>/dev/null || true' EXIT + for _ in $(seq 1 30); do + curl -fsS "${BASE}/health" >/dev/null 2>&1 && break + sleep 0.2 + done +fi + +python3 <<'PY' +import json, urllib.request, textwrap, sys + +BASE = f"http://127.0.0.1:{__import__('os').environ.get('OMI_CODEX_PROXY_PORT', '10531')}/v1/chat/completions" + +def post(messages, label): + payload = {"model": "gpt-5.4", "messages": messages, "temperature": 0.2, "stream": False} + req = urllib.request.Request(BASE, data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}) + try: + with urllib.request.urlopen(req, timeout=120) as resp: + data = json.load(resp) + text = data["choices"][0]["message"]["content"].strip() + assert text, f"empty response for {label}" + print(f"PASS {label}: {text[:80]!r}") + return text + except Exception as e: + body = e.read().decode()[:400] if hasattr(e, "read") else str(e) + print(f"FAIL {label}: {body}", file=sys.stderr) + raise + +post([{"role": "system", "content": "You are Omi."}, {"role": 
"user", "content": "Reply with exactly: alpha"}], "single-turn") +post([ + {"role": "system", "content": "You are Omi."}, + {"role": "user", "content": "hey"}, + {"role": "assistant", "content": "Hey!"}, + {"role": "user", "content": "What is 2+2? Reply with just the number."}, +], "multi-turn") +post([ + {"role": "system", "content": "You are Omi.\n" + ("Context line.\n" * 20)}, + {"role": "user", "content": "Reply with exactly: ready"}, +], "large-system") +post([{"role": "user", "content": "Reply with exactly: default"}], "default-instructions") +post([ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": [{"type": "text", "text": "Reply with exactly: array-ok"}]}, +], "hybrid-llm-array-content") +print("ALL PROXY E2E CHECKS PASSED") +PY + +echo "e2e_smoke: OK" diff --git a/desktop/codex-proxy/src/main.rs b/desktop/codex-proxy/src/main.rs new file mode 100644 index 00000000000..5cfc047d2be --- /dev/null +++ b/desktop/codex-proxy/src/main.rs @@ -0,0 +1,799 @@ +//! Local OpenAI-compat proxy for ChatGPT Codex `/backend-api/codex/responses`. 
+ +use std::{ + fs, io, + net::SocketAddr, + path::{Path, PathBuf}, + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; + +use axum::{ + body::Body, + extract::{Json, State}, + http::{header, StatusCode}, + response::{IntoResponse, Response}, + routing::{get, post}, + Router, +}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use serde::Deserialize; +use serde_json::{json, Value}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +const DEFAULT_PORT: u16 = 10531; +const CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses"; +const OPENAI_AUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +const OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; +const DEFAULT_INSTRUCTIONS: &str = "You are a helpful assistant."; + +#[derive(Clone, Deserialize)] +struct AuthCore { + access_token: String, + account_id: String, + #[serde(default)] + refresh_token: Option, +} + +impl AuthCore { + fn from_doc(doc: &Value) -> Result { + if doc.get("access_token").is_some() { + return serde_json::from_value(doc.clone()).map_err(|e| e.to_string()); + } + if let Some(tokens) = doc.get("tokens") { + return serde_json::from_value(tokens.clone()).map_err(|e| e.to_string()); + } + Err("missing access_token (expected top-level or tokens.access_token)".into()) + } +} + +struct AuthDisk { + path: PathBuf, + doc: Value, +} + +struct AppState { + http: reqwest::Client, + auth: Mutex, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let auth_path = default_auth_path()?; + let doc = load_auth_doc(&auth_path)?; + AuthCore::from_doc(&doc).map_err(|e| format!("invalid {}: {}", auth_path.display(), e))?; + + let port = std::env::var("OMI_CODEX_PROXY_PORT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_PORT); + + let state = Arc::new(AppState { + http: reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .user_agent(format!("omi-codex-proxy/{}", 
env!("CARGO_PKG_VERSION"))) + .build()?, + auth: Mutex::new(AuthDisk { + path: auth_path.clone(), + doc, + }), + }); + + let app = Router::new() + .route("/health", get(health_ok)) + .route("/v1/chat/completions", post(chat_completions)) + .with_state(state.clone()); + + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + println!( + "omi-codex-proxy listening on http://{}", + listener.local_addr()? + ); + println!("auth file {}", auth_path.display()); + + axum::serve(listener, app).await?; + Ok(()) +} + +async fn health_ok() -> &'static str { + "ok" +} + +async fn chat_completions(State(state): State>, Json(body): Json) -> Response { + if body.get("stream").and_then(|v| v.as_bool()) == Some(true) { + return json_error( + StatusCode::NOT_IMPLEMENTED, + "stream=true is not implemented; send non-stream chat/completions.", + ) + .into_response(); + } + + let upstream_payload = match codex_payload_from_openai_chat(&body) { + Ok(v) => v, + Err(msg) => return json_error(StatusCode::BAD_REQUEST, &msg).into_response(), + }; + + let requested_model_hint = body + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("") + .to_owned(); + + match invoke_codex(&state, &upstream_payload, requested_model_hint).await { + Ok(resp) => resp, + Err(msg) => json_error(StatusCode::INTERNAL_SERVER_ERROR, &msg).into_response(), + } +} + +async fn invoke_codex( + state: &AppState, + upstream_payload: &Value, + requested_model_hint: String, +) -> Result { + let bytes = encode_codex_request(upstream_payload)?; + + let mut refreshed = false; + loop { + let hdrs = { + let g = state.auth.lock().await; + let core = AuthCore::from_doc(&g.doc)?; + codex_headers(&core)? 
+ }; + + let upstream = state + .http + .post(CODEX_RESPONSES_URL) + .headers(hdrs.clone()) + .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) + .header(header::ACCEPT, HeaderValue::from_static("text/event-stream")) + .body(bytes.clone()) + .send() + .await + .map_err(|e| e.to_string())?; + + let status = upstream.status(); + let upstream_bytes = upstream.bytes().await.map_err(|e| e.to_string())?; + + if status.as_u16() == 401 { + let has_refresh_token = { + let g = state.auth.lock().await; + AuthCore::from_doc(&g.doc)? + .refresh_token + .as_ref() + .map(|t| !t.is_empty()) + .unwrap_or(false) + }; + + let should_retry = !refreshed && has_refresh_token; + + if should_retry { + let refresh_token_owned = { + let g = state.auth.lock().await; + AuthCore::from_doc(&g.doc)? + .refresh_token + .filter(|t| !t.is_empty()) + .ok_or_else(|| { + "refresh_token vanished between checks — cannot refresh access token" + .to_string() + })? + }; + + let refresh_envelope = + oauth_refresh_access_token(&state.http, refresh_token_owned).await?; + + { + let mut g = state.auth.lock().await; + apply_refresh_to_doc(&mut *g, refresh_envelope)?; + } + + refreshed = true; + continue; + } + } + + if !status.is_success() { + return Ok((status, Body::from(upstream_bytes)).into_response()); + } + + let sse_body = String::from_utf8(upstream_bytes.to_vec()) + .map_err(|e| format!("upstream SSE is not valid UTF-8: {e}"))?; + let assistant_text = collect_text_from_codex_sse(&sse_body)?; + if assistant_text.trim().is_empty() { + return Err("upstream SSE contained no assistant text".into()); + } + + let openai_completion = json!({ + "id": new_chat_completion_id(), + "object": "chat.completion", + "created": unix_secs(), + "model": if requested_model_hint.trim().is_empty() { Value::Null } else { Value::String(requested_model_hint.clone()) }, + "choices": [{ + "index": 0, + "message": { "role": "assistant", "content": assistant_text }, + "logprobs": null, + "finish_reason": "stop", + 
}], + "usage": Value::Null, + }); + return Ok(JsonResponse { + status: StatusCode::OK, + json: openai_completion, + } + .into_response()); + } +} + +fn encode_codex_request(payload: &Value) -> Result, String> { + serde_json::to_vec(payload).map_err(|e| e.to_string()) +} + +#[derive(Clone)] +struct JsonResponse { + status: StatusCode, + json: Value, +} + +impl IntoResponse for JsonResponse { + fn into_response(self) -> Response { + let body = serde_json::to_vec(&self.json).unwrap_or_else(|_| { + br#"{"error":{"message":"failed to serialize upstream json envelope","type":"omi_codex_proxy_error"}}"#.to_vec() + }); + + Response::builder() + .status(self.status) + .header(header::CONTENT_TYPE, "application/json") + .body(Body::from(body)) + .unwrap() + } +} + +fn json_error(status: StatusCode, message: impl AsRef) -> JsonResponse { + JsonResponse { + status, + json: json!({ + "error": { + "message": message.as_ref(), + "type": "omi_codex_proxy_error", + }, + }), + } +} + +#[derive(Debug, Deserialize)] +struct RefreshEnvelope { + access_token: Option, + refresh_token: Option, +} + +async fn oauth_refresh_access_token( + http: &reqwest::Client, + refresh_token: String, +) -> Result { + let response = http + .post(OPENAI_AUTH_TOKEN_URL) + .header( + CONTENT_TYPE, + HeaderValue::from_static("application/x-www-form-urlencoded"), + ) + .form(&[ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token.as_str()), + ("client_id", OAUTH_CLIENT_ID), + ]) + .send() + .await + .map_err(|e| format!("oauth refresh transport error: {e}"))?; + + let status = response.status(); + let body_text = response + .text() + .await + .map_err(|e| format!("oauth refresh read error: {e}"))?; + + if !status.is_success() { + return Err(format!( + "oauth refresh failed ({status}): {body_text}", + status = status, + body_text = body_text + )); + } + + let env: RefreshEnvelope = serde_json::from_str(&body_text) + .map_err(|e| format!("oauth refresh json decode error ({e}): {body_text}"))?; + 
+ Ok(env) +} + +fn apply_refresh_to_doc(disk: &mut AuthDisk, mut env: RefreshEnvelope) -> Result<(), String> { + let new_access = env + .access_token + .take() + .filter(|t| !t.is_empty()) + .ok_or_else(|| "oauth refresh succeeded but omitted access_token".to_string())?; + + if disk.doc.get("tokens").map(|t| t.is_object()).unwrap_or(false) { + if let Some(tokens) = disk.doc.get_mut("tokens").and_then(Value::as_object_mut) { + tokens.insert("access_token".to_string(), Value::String(new_access)); + if let Some(new_refresh) = env.refresh_token.take().filter(|t| !t.is_empty()) { + tokens.insert("refresh_token".to_string(), Value::String(new_refresh)); + } + } + } else { + disk.doc["access_token"] = Value::String(new_access); + if let Some(new_refresh) = env.refresh_token.take().filter(|t| !t.is_empty()) { + disk.doc["refresh_token"] = Value::String(new_refresh); + } + } + + persist_auth(&disk.path, &disk.doc)?; + println!( + "oauth: refreshed access_token (persisted {})", + disk.path.display() + ); + Ok(()) +} + +fn codex_headers(core: &AuthCore) -> Result { + let mut map = HeaderMap::new(); + let bearer = HeaderValue::from_str(format!("Bearer {}", core.access_token).as_str()) + .map_err(|e| e.to_string())?; + map.insert(AUTHORIZATION, bearer); + map.insert( + HeaderName::from_static("chatgpt-account-id"), + HeaderValue::from_str(&core.account_id).map_err(|e| e.to_string())?, + ); + map.insert( + HeaderName::from_static("originator"), + HeaderValue::from_static("pi"), + ); + Ok(map) +} + +fn default_auth_path() -> Result { + if let Ok(codex_home) = std::env::var("CODEX_HOME") { + let trimmed = codex_home.trim(); + if !trimmed.is_empty() { + return Ok(PathBuf::from(trimmed).join("auth.json")); + } + } + let home = std::env::var_os("HOME") + .filter(|v| !v.is_empty()) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "HOME is not set; cannot resolve ~/.codex/auth.json", + ) + })?; + Ok(PathBuf::from(home).join(".codex").join("auth.json")) +} + +fn 
load_auth_doc(path: &Path) -> Result { + let raw = fs::read_to_string(path).map_err(|e| format!("read {}: {e}", path.display()))?; + serde_json::from_str(&raw).map_err(|e| format!("parse {}: {e}", path.display())) +} + +fn persist_auth(path: &Path, doc: &Value) -> Result<(), String> { + let serialized = + serde_json::to_vec_pretty(doc).map_err(|e| format!("serialize auth doc: {e}"))?; + fs::write(path, serialized).map_err(|e| format!("write {}: {e}", path.display())) +} + +fn codex_payload_from_openai_chat(openai_body: &Value) -> Result { + let model = openai_body + .get("model") + .and_then(|m| m.as_str()) + .filter(|s| !s.trim().is_empty()) + .unwrap_or("gpt-5.4"); + + let messages = openai_body + .get("messages") + .and_then(Value::as_array) + .ok_or_else(|| "missing array `messages`".to_string())?; + + if messages.is_empty() { + return Err("`messages` must be non-empty".into()); + } + + let mut instructions_parts = Vec::new(); + let mut input_items = Vec::new(); + + for (idx, msg) in messages.iter().enumerate() { + let role = msg + .get("role") + .and_then(|r| r.as_str()) + .ok_or_else(|| format!("messages[{idx}].role missing"))?; + + if role == "system" { + if let Some(text) = message_content_as_string(msg.get("content").unwrap_or(&Value::Null))? 
+ { + if !text.is_empty() { + instructions_parts.push(text); + } + } + continue; + } + + let content_parts = + normalize_message_content(msg.get("content").unwrap_or(&Value::Null), role)?; + let parts_array = content_parts.as_array().cloned().unwrap_or_default(); + if parts_array.is_empty() { + continue; + } + input_items.push(json!({ + "type": "message", + "role": role, + "content": parts_array, + })); + } + + if input_items.is_empty() { + return Err("`messages` must include at least one non-system message".into()); + } + + let instructions = if instructions_parts.is_empty() { + DEFAULT_INSTRUCTIONS.to_string() + } else { + instructions_parts.join("\n") + }; + + Ok(json!({ + "model": model, + "store": false, + "stream": true, + "instructions": instructions, + "input": input_items, + "text": { "verbosity": "medium" }, + "include": ["reasoning.encrypted_content"], + "tool_choice": "auto", + "parallel_tool_calls": true, + })) +} + +fn message_content_as_string(raw: &Value) -> Result, String> { + Ok(match raw { + Value::String(s) => Some(s.clone()), + Value::Array(items) => { + let mut out = String::new(); + for it in items { + if let Some(t) = it.get("text").and_then(Value::as_str) { + out.push_str(t); + } + } + if out.is_empty() { + None + } else { + Some(out) + } + } + Value::Null => None, + other => Err(format!( + "unsupported message content type `{}` — expected string or array", + serde_json::to_string(other).unwrap_or_else(|_| "unknown".into()) + ))?, + }) +} + +fn collect_text_from_codex_sse(body: &str) -> Result { + let mut text = String::new(); + for line in body.lines() { + let data = line + .strip_prefix("data:") + .map(str::trim) + .filter(|s| !s.is_empty() && *s != "[DONE]"); + let Some(data) = data else { + continue; + }; + + let event: Value = + serde_json::from_str(data).map_err(|e| format!("invalid SSE data json: {e}"))?; + match event.get("type").and_then(Value::as_str) { + Some("response.output_text.delta") => { + if let Some(delta) = 
event.get("delta").and_then(Value::as_str) { + text.push_str(delta); + } + } + Some("response.output_text.done") => { + if text.is_empty() { + if let Some(done) = event.get("text").and_then(Value::as_str) { + text.push_str(done); + } + } + } + Some("error") => { + let message = event + .pointer("/error/message") + .and_then(Value::as_str) + .or_else(|| event.get("message").and_then(Value::as_str)) + .unwrap_or("Codex backend returned an error event"); + return Err(message.to_string()); + } + _ => {} + } + } + + Ok(text) +} + +fn normalize_message_content(raw: &Value, role: &str) -> Result { + let text_type = if role == "assistant" { + "output_text" + } else { + "input_text" + }; + Ok(match raw { + Value::String(s) => json!([ + {"type": text_type, "text": s}, + ]), + Value::Array(parts) => { + if parts.is_empty() { + Value::Array(vec![]) + } else { + Value::Array( + parts + .iter() + .map(|part| normalize_content_part(part, text_type)) + .collect(), + ) + } + } + Value::Null => Value::Array(vec![json!({ "type": text_type, "text": "" })]), + other => Err(format!( + "unsupported message content type `{}` — expected string or array", + serde_json::to_string(other).unwrap_or_else(|_| "unknown".into()) + ))?, + }) +} + +fn normalize_content_part(part: &Value, default_type: &str) -> Value { + match part { + Value::Object(map) => { + let mut out = map.clone(); + if !out.contains_key("type") { + out.insert("type".to_string(), Value::String(default_type.to_string())); + } else if let Some(Value::String(kind)) = out.get("type") { + if kind == "text" { + out.insert("type".to_string(), Value::String(default_type.to_string())); + } + } + Value::Object(out) + } + Value::String(s) => json!({ "type": default_type, "text": s }), + other => other.clone(), + } +} + +fn codex_body_to_chat_completion(model_fallback: &str, bytes: &[u8]) -> Result { + let v: Value = serde_json::from_slice(bytes).map_err(|e| format!("upstream json: {e}"))?; + + if v.get("choices").is_some() { + let mut 
enriched = v; + if enriched.get("id").and_then(Value::as_str).is_none() + || enriched.get("id") == Some(&Value::Null) + { + enriched["id"] = Value::String(new_chat_completion_id()); + } + if enriched.get("object").and_then(Value::as_str).is_none() + || enriched.get("object") == Some(&Value::Null) + { + enriched["object"] = Value::from("chat.completion"); + } + if enriched.get("created").and_then(Value::as_i64).is_none() + || enriched.get("created") == Some(&Value::Null) + { + enriched["created"] = Value::Number(unix_secs().into()); + } + Ok(enriched) + } else { + let text = extract_assistant_text(&v) + .ok_or_else(|| serde_json::to_string(&v).unwrap_or_else(|_| "(unprintable)".into()))?; + let model = chat_model_choice(&v, model_fallback)?; + Ok(json!({ + "id": new_chat_completion_id(), + "object": "chat.completion", + "created": unix_secs(), + "model": model, + "choices": [{ + "index": 0, + "message": { "role": "assistant", "content": text}, + "logprobs": null, + "finish_reason": infer_finish_reason(&v), + }], + "usage": v.get("usage").cloned().unwrap_or(Value::Null), + })) + } +} + +fn infer_finish_reason(v: &Value) -> Value { + v.pointer("/choices/0/finish_reason") + .cloned() + .unwrap_or_else(|| Value::from("stop")) +} + +fn chat_model_choice(v: &Value, fallback: &str) -> Result { + if let Some(m) = v + .get("model") + .and_then(Value::as_str) + .filter(|s| !s.is_empty()) + { + return Ok(Value::String(m.to_owned())); + } + if !fallback.trim().is_empty() { + return Ok(Value::String(fallback.to_owned())); + } + Err("upstream response missing model and original request lacked model hint".into()) +} + +fn new_chat_completion_id() -> String { + format!("chatcmpl-{}", now_millis()) +} + +fn unix_secs() -> i64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0) +} + +fn now_millis() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or(0) +} + +/// Best-effort assistant text 
extractor for Responses-style payloads (`output`, etc.). +fn extract_assistant_text(v: &Value) -> Option { + if let Some(Value::Array(choices)) = v.get("choices") { + if let Some(first) = choices.first() { + let text = openai_choice_text(first); + if text.is_some() { + return text; + } + } + } + + if let Some(output) = v.get("output") { + if let Some(text) = flatten_output_chunks(output) { + return Some(text); + } + } + + let mut chunks = Vec::new(); + visit_collect_output_text(v, &mut chunks); + if chunks.is_empty() { + None + } else { + Some(chunks.join("")) + } +} + +fn openai_choice_text(choice: &Value) -> Option { + let msg = choice.get("message")?; + extract_message_content_as_string(msg) +} + +fn extract_message_content_as_string(msg: &Value) -> Option { + match msg.get("content") { + Some(Value::String(s)) => Some(s.clone()), + Some(Value::Array(items)) => { + let mut out = String::new(); + for it in items { + if let Some(t) = it.get("text").and_then(Value::as_str) { + out.push_str(t); + } else if let Some(inner) = it.get("content").and_then(Value::as_str) { + out.push_str(inner); + } + } + if !out.is_empty() { + Some(out) + } else { + None + } + } + Some(Value::Null) | None => None, + _ => None, + } +} + +fn flatten_output_chunks(outputs: &Value) -> Option { + let mut combined = Vec::new(); + if let Value::Array(items) = outputs { + for item in items { + visit_collect_output_text(item, &mut combined); + } + } else { + visit_collect_output_text(outputs, &mut combined); + } + + (!combined.is_empty()).then(|| combined.join("")) +} + +fn push_output_text_piece(map: &serde_json::Map, bucket: &mut Vec) { + if map.get("type").and_then(Value::as_str) != Some("output_text") { + return; + } + let Some(raw) = map.get("text").and_then(Value::as_str) else { + return; + }; + let trimmed = raw.trim(); + if !trimmed.is_empty() { + bucket.push(trimmed.to_owned()); + } +} + +fn visit_collect_output_text(v: &Value, bucket: &mut Vec) { + match v { + Value::Object(map) => { + 
push_output_text_piece(map, bucket); + for child in map.values() { + visit_collect_output_text(child, bucket); + } + } + Value::Array(items) => { + for item in items { + visit_collect_output_text(item, bucket); + } + } + _ => {} + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn maps_openai_messages_to_codex_input() { + let openai = json!({ + "model": "gpt-test", + "messages": [ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"input_text","text":"hi"}]}, + {"role":"assistant","content":"Hello!"}, + {"role":"user","content":"again"} + ] + }); + let out = codex_payload_from_openai_chat(&openai).expect("mapping"); + assert_eq!(out["instructions"], json!("You are helpful.")); + assert_eq!(out["stream"], json!(true)); + assert_eq!(out["input"].as_array().unwrap().len(), 3); + assert_eq!(out["input"][0]["type"], json!("message")); + assert_eq!(out["input"][0]["role"], json!("user")); + assert_eq!( + out["input"][0]["content"], + json!([{"type":"input_text","text":"hi"}]) + ); + assert_eq!( + out["input"][1]["content"], + json!([{"type":"output_text","text":"Hello!"}]) + ); + assert_eq!( + out["input"][2]["content"], + json!([{"type":"input_text","text":"again"}]) + ); + } + + #[test] + fn maps_responses_like_output_message() { + let upstream = json!({ + "model": "gpt-output", + "output": [{ + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "Hello"} + ] + }] + }); + + let out = + codex_body_to_chat_completion("", &serde_json::to_vec(&upstream).unwrap()).unwrap(); + assert_eq!( + out["choices"][0]["message"]["content"], + Value::from("Hello") + ); + assert_eq!(out["model"], Value::from("gpt-output")); + } +} diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index 551bee47c52..719e8f6394d 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ 
b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -68,6 +68,17 @@ local guest session startup. Chat resolves `chat_provider` → `ai_provider` → Configure or override in **Settings → Plan and Usage** (local mode) or via `PUT /v1/settings`. +## ChatGPT / Codex subscription (desktop) + +When the user connects **ChatGPT plan** in Settings → Advanced: + +- A loopback proxy (`desktop/codex-proxy`, default `http://127.0.0.1:10531/v1`) uses `~/.codex/auth.json` from Codex CLI login. +- Daemon `chat_provider` and `ai_provider` are set to that URL (not `embedding_provider`). +- Memory search uses **local wiki + FTS5** instead of vector embeddings unless `OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1`. +- Deepgram / live transcription behavior is unchanged. + +Build proxy: `cd desktop/codex-proxy && cargo build --release` + ## Test connection `POST /v1/settings/test-provider` with body `{ "key": "ai_provider" }` runs a minimal diff --git a/desktop/run.sh b/desktop/run.sh index a174c1cf8c9..ce5561c4e9a 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -616,6 +616,17 @@ else echo "Warning: pi-mono-extension not found at $PI_MONO_EXT_DIR" fi +substep "Building Codex proxy (omi-codex-proxy)" +CODEX_PROXY_DIR="$(dirname "$0")/codex-proxy" +if [ -d "$CODEX_PROXY_DIR" ]; then + (cd "$CODEX_PROXY_DIR" && cargo build --release --quiet) + mkdir -p "$APP_BUNDLE/Contents/Resources" + cp -f "$CODEX_PROXY_DIR/target/release/omi-codex-proxy" "$APP_BUNDLE/Contents/Resources/omi-codex-proxy" + chmod +x "$APP_BUNDLE/Contents/Resources/omi-codex-proxy" +else + echo "Warning: codex-proxy not found at $CODEX_PROXY_DIR" +fi + substep "Copying .env.app" if [ -f ".env.app.dev" ]; then cp -f .env.app.dev "$APP_BUNDLE/Contents/Resources/.env" From 6e74de8665b3ff5d4edafc17cf29c403d9b5ed25 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 01:18:05 -0400 Subject: [PATCH 30/58] Fix ChatGPT plan connect/disconnect UX and proxy startup reliability. 
Reuse a healthy loopback proxy instead of respawning, surface proxy stderr on failure, skip cloud enrollment in local daemon mode, and refresh Settings immediately via CodexAuthStore. Co-authored-by: Cursor --- .../Desktop/Sources/CodexAuthService.swift | 4 +++ desktop/Desktop/Sources/CodexAuthStore.swift | 17 +++++++++ .../Sources/CodexEnrollmentCoordinator.swift | 15 +++++++- .../Desktop/Sources/CodexProxyService.swift | 27 +++++++++++--- .../MainWindow/Pages/SettingsPage.swift | 36 ++++++++++++++----- .../Desktop/Tests/CodexAuthServiceTests.swift | 3 ++ 6 files changed, 88 insertions(+), 14 deletions(-) create mode 100644 desktop/Desktop/Sources/CodexAuthStore.swift diff --git a/desktop/Desktop/Sources/CodexAuthService.swift b/desktop/Desktop/Sources/CodexAuthService.swift index 03209ea36d4..b8ea2f29abd 100644 --- a/desktop/Desktop/Sources/CodexAuthService.swift +++ b/desktop/Desktop/Sources/CodexAuthService.swift @@ -95,12 +95,16 @@ enum CodexAuthService { isEnrolled && loadSnapshot() != nil } + @MainActor static func markEnrolled() { UserDefaults.standard.set(true, forKey: enrolledKey) + CodexAuthStore.shared.notifyEnrollmentChanged() } + @MainActor static func clearEnrollment() { UserDefaults.standard.set(false, forKey: enrolledKey) + CodexAuthStore.shared.notifyEnrollmentChanged() } static func enrollmentFingerprintIfActive() -> String? { diff --git a/desktop/Desktop/Sources/CodexAuthStore.swift b/desktop/Desktop/Sources/CodexAuthStore.swift new file mode 100644 index 00000000000..0b8206bc192 --- /dev/null +++ b/desktop/Desktop/Sources/CodexAuthStore.swift @@ -0,0 +1,17 @@ +import Foundation + +/// Publishes ChatGPT / Codex enrollment changes so Settings can refresh without re-navigation. 
+@MainActor +final class CodexAuthStore: ObservableObject { + static let shared = CodexAuthStore() + + var isEnrolled: Bool { CodexAuthService.isEnrolled } + + var isActive: Bool { CodexAuthService.isActive } + + private init() {} + + func notifyEnrollmentChanged() { + objectWillChange.send() + } +} diff --git a/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift b/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift index eb9bafb5bb2..ebe7b46fb06 100644 --- a/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift +++ b/desktop/Desktop/Sources/CodexEnrollmentCoordinator.swift @@ -5,12 +5,15 @@ import Foundation @MainActor enum CodexEnrollmentCoordinator { enum EnrollmentError: LocalizedError { + case alreadyInProgress case authFileMissing case proxyFailed(String) case backendFailed(String) var errorDescription: String? { switch self { + case .alreadyInProgress: + return "Sign-in is already in progress. Wait a moment or complete login in Terminal." case .authFileMissing: return "Sign-in timed out. Complete login in Terminal, then try again." @@ -74,7 +77,9 @@ enum CodexEnrollmentCoordinator { /// Sign in: use existing auth if present, otherwise open Codex login then poll. static func connect(pollSeconds: Int = 120) async throws { - guard !connectInFlight else { return } + guard !connectInFlight else { + throw EnrollmentError.alreadyInProgress + } connectInFlight = true defer { connectInFlight = false @@ -118,6 +123,14 @@ enum CodexEnrollmentCoordinator { throw EnrollmentError.proxyFailed(CodexProxyService.shared.lastError ?? 
"unknown") } + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + log("CodexEnrollmentCoordinator: local daemon mode — skipping cloud chatgpt-active enrollment") + await CodexProviderBootstrap.applyIfNeeded() + await FloatingBarUsageLimiter.shared.fetchPlan() + AppState.current?.isPaywalled = false + return + } + let fingerprint = CodexAuthService.enrollmentFingerprint(for: snapshot.accountId) do { try await APIClient.shared.activateChatGPT(fingerprint: fingerprint) diff --git a/desktop/Desktop/Sources/CodexProxyService.swift b/desktop/Desktop/Sources/CodexProxyService.swift index 6eecf3570af..220ad204d1e 100644 --- a/desktop/Desktop/Sources/CodexProxyService.swift +++ b/desktop/Desktop/Sources/CodexProxyService.swift @@ -51,8 +51,16 @@ final class CodexProxyService: ObservableObject { await stop() return } - if isRunning, await healthCheck() { return } - await stop() + // Reuse an already-healthy proxy (e.g. from a prior session or manual start). + if await healthCheck() { + isRunning = true + lastError = nil + startHealthMonitor() + return + } + if isRunning { + await stop() + } guard let executable = resolveExecutableURL() else { lastError = "Codex proxy binary not found. Build with: cd desktop/codex-proxy && cargo build --release" @@ -71,13 +79,14 @@ final class CodexProxyService: ObservableObject { var env = ProcessInfo.processInfo.environment env["OMI_CODEX_PROXY_PORT"] = String(Self.port) proc.environment = env + let stderrPipe = Pipe() + proc.standardError = stderrPipe proc.standardOutput = FileHandle.nullDevice - proc.standardError = FileHandle.nullDevice do { try proc.run() process = proc - for _ in 0..<30 { + for _ in 0..<50 { try? await Task.sleep(nanoseconds: 100_000_000) if await healthCheck() { isRunning = true @@ -87,7 +96,15 @@ final class CodexProxyService: ObservableObject { return } } - lastError = "Codex proxy failed to start (health check timeout)." 
+ let stderrData = stderrPipe.fileHandleForReading.readDataToEndOfFile() + let stderrHint = String(data: stderrData, encoding: .utf8)? + .trimmingCharacters(in: .whitespacesAndNewlines) + let detail = + (stderrHint?.isEmpty == false) + ? stderrHint! + : "Codex proxy failed to start (health check timeout)." + lastError = detail + logError("CodexProxyService: failed to start — \(detail)") await stop() } catch { lastError = error.localizedDescription diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 7f820a4b7a4..2e1e05201f1 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -397,6 +397,8 @@ struct SettingsContentView: View { @State private var byokActivationError: String? @State private var codexEnrollmentError: String? @State private var codexEnrollmentBusy = false + @ObservedObject private var codexAuthStore = CodexAuthStore.shared + @ObservedObject private var codexProxyService = CodexProxyService.shared init( appState: AppState, @@ -5385,21 +5387,29 @@ struct SettingsContentView: View { settingsCard(settingId: "advanced.chatgpt.info") { VStack(alignment: .leading, spacing: 10) { HStack(spacing: 10) { - Image(systemName: CodexAuthService.isActive ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark") - .foregroundColor(CodexAuthService.isActive ? OmiColors.success : OmiColors.textTertiary) - Text(CodexAuthService.isActive ? "ChatGPT plan active" : "Use your ChatGPT subscription") + Image(systemName: codexAuthStore.isActive ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark") + .foregroundColor(codexAuthStore.isActive ? OmiColors.success : OmiColors.textTertiary) + Text(codexAuthStore.isActive ? 
"ChatGPT plan active" : "Use your ChatGPT subscription") .scaledFont(size: 14, weight: .semibold) .foregroundColor(OmiColors.textPrimary) } Text( - CodexAuthService.isActive + codexAuthStore.isActive ? "LLM features use your ChatGPT/Codex subscription via a local proxy on this Mac. Memory search uses local wiki + keyword search (no embedding API). Live transcription is unchanged." : "Sign in with ChatGPT to route chat and proactive AI through your subscription via a local proxy on this Mac. A Terminal window opens for Codex login — complete sign-in there and Omi will connect automatically. Unofficial community integration — use at your own risk per OpenAI terms. Tokens stay on this Mac." ) .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) - if CodexProxyService.shared.isRunning { + if codexProxyService.isRunning { Text("Proxy: \(CodexProxyService.defaultBaseURL)") + .scaledFont(size: 11) + .foregroundColor(OmiColors.success) + } else if codexAuthStore.isEnrolled, let proxyError = codexProxyService.lastError { + Text("Proxy not running: \(proxyError)") + .scaledFont(size: 11) + .foregroundColor(OmiColors.warning) + } else if CodexAuthService.loadSnapshot() == nil { + Text("Complete Codex login in Terminal, then try again.") .scaledFont(size: 11) .foregroundColor(OmiColors.textTertiary) } @@ -5420,7 +5430,7 @@ struct SettingsContentView: View { } HStack(spacing: 12) { - if CodexAuthService.isActive { + if codexAuthStore.isActive { Button(action: disconnectChatGPTPlan) { Text("Disconnect") .frame(maxWidth: .infinity) @@ -5445,9 +5455,16 @@ struct SettingsContentView: View { Task { do { try await CodexEnrollmentCoordinator.connect() + await CodexProxyService.shared.ensureRunning() await MainActor.run { codexEnrollmentBusy = false - codexEnrollmentError = nil + codexAuthStore.notifyEnrollmentChanged() + if codexAuthStore.isActive, !codexProxyService.isRunning { + codexEnrollmentError = + codexProxyService.lastError ?? "Codex proxy did not start." 
+ } else { + codexEnrollmentError = nil + } } } catch { await MainActor.run { @@ -5461,7 +5478,10 @@ struct SettingsContentView: View { private func disconnectChatGPTPlan() { Task { await CodexEnrollmentCoordinator.disconnect() - await MainActor.run { codexEnrollmentError = nil } + await MainActor.run { + codexEnrollmentError = nil + codexAuthStore.notifyEnrollmentChanged() + } } } diff --git a/desktop/Desktop/Tests/CodexAuthServiceTests.swift b/desktop/Desktop/Tests/CodexAuthServiceTests.swift index d27aac4e249..0ae87f70bd2 100644 --- a/desktop/Desktop/Tests/CodexAuthServiceTests.swift +++ b/desktop/Desktop/Tests/CodexAuthServiceTests.swift @@ -23,6 +23,7 @@ final class CodexAuthServiceTests: XCTestCase { XCTAssertEqual(fp, CodexAuthService.enrollmentFingerprint(for: "account-123")) } + @MainActor func testIsActiveRequiresEnrollmentAndSnapshot() { let tempAuth = makeTempCodexHomeWithoutAuth() defer { tempAuth.cleanup() } @@ -67,6 +68,7 @@ final class CodexAuthServiceTests: XCTestCase { XCTAssertEqual(snap?.refreshToken, "test-refresh") } + @MainActor func testMemorySearchModeDefaultsToWikiWhenCodexEnrolled() { CodexAuthService.markEnrolled() XCTAssertEqual(MemorySearchMode.current, .localWiki) @@ -85,6 +87,7 @@ final class CodexProxyConfigTests: XCTestCase { super.tearDown() } + @MainActor func testHybridLLMUsesDaemonSettingsWithoutAuthSnapshot() throws { let tempAuth = makeTempCodexHomeWithoutAuth() defer { tempAuth.cleanup() } From 22e2787bacdcbd0662c38434f4642f5bd20c6960 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:13:41 +0700 Subject: [PATCH 31/58] feat(local-backend): add hybrid provider settings, routes, and storage for local profiles --- .../docs/hybrid-provider-settings.md | 93 ++- desktop/local-backend/src/main.rs | 81 ++- desktop/local-backend/src/providers.rs | 562 +++++++++++++++++- desktop/local-backend/src/routes.rs | 42 +- desktop/local-backend/src/storage.rs | 82 ++- 5 files changed, 782 insertions(+), 78 deletions(-) diff 
--git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index 719e8f6394d..aaa781f0a98 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -2,10 +2,81 @@ ## Context -Desktop hybrid mode (`OMI_DESKTOP_BACKEND_MODE=local`) stores provider credentials in the +Desktop hybrid mode (`OMI_DESKTOP_BACKEND_MODE=local`) stores provider policy in the local daemon SQLite `settings` table. Requests go **directly** to configured endpoints, never through Omi Python/Rust proxies. +The current durable policy key is `provider_policy`. Older raw provider keys are still +read as a compatibility bridge while desktop UI migrates to the typed policy API. + +## Provider policy + +`provider_policy` is versioned JSON: + +```json +{ + "version": 1, + "provider_accounts": [ + { + "id": "local-ollama", + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "api_key": null, + "display_name": "Local Ollama", + "capabilities": { + "chat_completions": true, + "json_mode": true, + "tool_calls": false, + "vision": false, + "speech_to_text": false + }, + "subscription_integration": null + } + ], + "model_slots": { + "post_transcript": { + "provider_account_id": "local-ollama", + "model_id": "llama3.2", + "options": { + "json_mode": true, + "tool_support": false + } + }, + "memory_search": { + "provider_account_id": null, + "model_id": "local_wiki", + "options": {} + } + } +} +``` + +Stable slot names are: + +| Slot | Purpose | +|------|---------| +| `chat` | Ask Omi / local chat completions | +| `post_transcript` | Conversation title, overview, memories, and action items | +| `proactive` | Proactive local intelligence | +| `vision` | Screenshot / OCR-adjacent multimodal model calls | +| `stt` | Speech-to-text provider policy | +| `memory_search` | Local memory retrieval policy | + +`memory_search` is `local_wiki` for this local profile 
and does not require +`embedding_provider` or embeddings readiness. + +## Provider policy API + +Desktop clients should prefer these daemon APIs over manual JSON editing: + +- `GET /v1/provider-policy` returns the typed policy plus legacy-derived slots. +- `PUT /v1/provider-policy` validates and persists the typed policy. +- `GET /v1/provider-policy/resolve/{slot}` resolves one slot to its provider account, + model, options, and source (`provider_policy`, `legacy_setting`, or `default`). + +Callers should resolve `post_transcript`, `proactive`, and `chat` slots explicitly +instead of duplicating the legacy setting-key scan order. + ## Settings keys | Key | Purpose | `kind` values (v1) | @@ -19,6 +90,17 @@ never through Omi Python/Rust proxies. Set a key to JSON `null` to clear it. +Legacy mapping: + +| Legacy key | Typed slot | +|------------|------------| +| `ai_provider` | `post_transcript` | +| `provider` | `post_transcript` fallback alias | +| `chat_provider` | `chat` | +| `vision_provider` | `vision` | +| `stt_provider` | `stt` | +| `embedding_provider` | accepted as legacy data only; not required for this profile | + ## OpenAI-compatible object shape ```json @@ -36,7 +118,10 @@ Set a key to JSON `null` to clear it. identity/Firestore endpoints are **denied** (see `is_denied_provider_host` in `src/providers.rs`). -Loopback and direct vendor APIs (OpenAI, Anthropic, Deepgram, etc.) are allowed. +Loopback providers (`localhost`, `127.0.0.1`, `::1`) may omit an API key. Non-loopback +providers must include `api_key` or an explicit `subscription_integration` value. +Direct vendor APIs (OpenAI, Anthropic, Deepgram, etc.) are allowed when configured this +way. 
## Optional cloud tiers (desktop only) @@ -82,4 +167,6 @@ Build proxy: `cd desktop/codex-proxy && cargo build --release` ## Test connection `POST /v1/settings/test-provider` with body `{ "key": "ai_provider" }` runs a minimal -request against the configured provider (chat completions ping for `openai_compatible`). +request against a legacy configured provider (chat completions ping for +`openai_compatible`). New UI should read and write policy through `/v1/provider-policy` +and use `/v1/provider-policy/resolve/{slot}` before making task-specific calls. diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 1db6eeac00f..314c0b3ff1d 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -306,7 +306,10 @@ mod tests { Method::PUT, "/v1/settings", Some(json!({ - "provider": {"kind": "openai"}, + "provider": { + "kind": "openai", + "base_url": "http://127.0.0.1:11434/v1" + }, "local_first": true })), ) @@ -316,6 +319,67 @@ mod tests { Ok(()) } + #[tokio::test] + async fn provider_policy_routes_read_update_and_resolve_slots() -> Result<()> { + let app = test_app()?; + + let updated = request_json( + app.clone(), + Method::PUT, + "/v1/provider-policy", + Some(json!({ + "version": 1, + "provider_accounts": [{ + "id": "local-openai", + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "api_key": null, + "display_name": "Local OpenAI-compatible", + "capabilities": { + "chat_completions": true, + "json_mode": true, + "tool_calls": false, + "vision": false, + "speech_to_text": false + }, + "subscription_integration": null + }], + "model_slots": { + "post_transcript": { + "provider_account_id": "local-openai", + "model_id": "llama3.2", + "options": {"json_mode": true, "tool_support": false} + }, + "memory_search": { + "provider_account_id": null, + "model_id": "local_wiki", + "options": {} + } + } + })), + ) + .await?; + assert_eq!( + 
updated["provider_policy"]["model_slots"]["post_transcript"]["model_id"], + "llama3.2" + ); + + let policy = request_json(app.clone(), Method::GET, "/v1/provider-policy", None).await?; + assert_eq!(policy["provider_policy"]["version"], 1); + + let resolved = request_json( + app, + Method::GET, + "/v1/provider-policy/resolve/post_transcript", + None, + ) + .await?; + assert_eq!(resolved["resolved"]["model_id"], "llama3.2"); + assert_eq!(resolved["resolved"]["source"], "provider_policy"); + + Ok(()) + } + #[tokio::test] async fn duplicate_finalize_reuses_active_and_current_completed_job() -> Result<()> { let app = test_app()?; @@ -813,7 +877,10 @@ mod tests { let folders = request_json(app.clone(), Method::GET, "/v1/conversation-folders", None).await?; - assert_eq!(folders["folders"].as_array().expect("folders array").len(), 1); + assert_eq!( + folders["folders"].as_array().expect("folders array").len(), + 1 + ); let conv_a = request_json( app.clone(), @@ -899,9 +966,13 @@ mod tests { ) .await?; - let unfiled = - request_json(app.clone(), Method::GET, &format!("/v1/conversations/{merged_id}"), None) - .await?; + let unfiled = request_json( + app.clone(), + Method::GET, + &format!("/v1/conversations/{merged_id}"), + None, + ) + .await?; assert!(unfiled["conversation"]["folder_id"].is_null()); Ok(()) diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 2b90937f5da..c64665d850f 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -1,5 +1,7 @@ use anyhow::{anyhow, Context, Result}; use reqwest::{Client, Method}; +use std::collections::BTreeMap; + use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -42,6 +44,81 @@ pub struct OpenAiCompatibleConfig { pub api_key: String, } +pub const PROVIDER_POLICY_SETTING_KEY: &str = "provider_policy"; +pub const PROVIDER_POLICY_VERSION: u32 = 1; + +pub const SLOT_CHAT: &str = "chat"; +pub const SLOT_POST_TRANSCRIPT: &str = 
"post_transcript"; +pub const SLOT_PROACTIVE: &str = "proactive"; +pub const SLOT_VISION: &str = "vision"; +pub const SLOT_STT: &str = "stt"; +pub const SLOT_MEMORY_SEARCH: &str = "memory_search"; + +const LEGACY_SLOT_KEYS: &[(&str, &[&str])] = &[ + (SLOT_POST_TRANSCRIPT, &["ai_provider", "provider"]), + (SLOT_CHAT, &["chat_provider"]), + (SLOT_VISION, &["vision_provider"]), + (SLOT_STT, &["stt_provider"]), +]; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProviderPolicy { + pub version: u32, + #[serde(default)] + pub provider_accounts: Vec, + #[serde(default)] + pub model_slots: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProviderAccount { + pub id: String, + pub kind: String, + pub base_url: Option, + pub api_key: Option, + pub display_name: Option, + #[serde(default)] + pub capabilities: ProviderCapabilities, + pub subscription_integration: Option, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct ProviderCapabilities { + #[serde(default)] + pub chat_completions: bool, + #[serde(default)] + pub json_mode: bool, + #[serde(default)] + pub tool_calls: bool, + #[serde(default)] + pub vision: bool, + #[serde(default)] + pub speech_to_text: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelSlotTarget { + pub provider_account_id: Option, + pub model_id: String, + #[serde(default)] + pub options: ModelSlotOptions, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelSlotOptions { + pub json_mode: Option, + pub tool_support: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ResolvedModelSlot { + pub slot: String, + pub provider_account: Option, + pub model_id: String, + pub options: ModelSlotOptions, + pub source: String, +} + #[derive(Clone)] pub struct OpenAiCompatibleProvider { config: OpenAiCompatibleConfig, @@ -107,41 +184,146 
@@ pub fn configured_openai_provider(store: &Store) -> Result Result> { - for key in ["ai_provider", "provider", "chat_provider"] { - let Some(setting) = store.settings().get(key)? else { - continue; - }; - let value: Value = serde_json::from_str(&setting.value_json) - .with_context(|| format!("failed to parse {key} provider setting"))?; - let kind = value["kind"].as_str().unwrap_or_default(); - if kind != "openai" && kind != "openai_compatible" { - continue; - } - - let base_url = value["base_url"] - .as_str() - .unwrap_or("https://api.openai.com/v1") - .to_string(); - validate_provider_base_url(&base_url)?; - let model = value["model"].as_str().unwrap_or("gpt-4o-mini").to_string(); - let api_key = value["api_key"] - .as_str() - .or_else(|| value["key"].as_str()) - .unwrap_or_default() - .to_string(); + let Some(resolved) = resolve_model_slot(store, SLOT_POST_TRANSCRIPT)? else { + return Ok(None); + }; + let Some(account) = resolved.provider_account else { + return Ok(None); + }; + if !is_openai_compatible_kind(&account.kind) { + return Ok(None); + } + let base_url = account + .base_url + .unwrap_or_else(|| "https://api.openai.com/v1".to_string()); + validate_provider_base_url(&base_url)?; + Ok(Some(OpenAiCompatibleConfig { + base_url, + model: resolved.model_id, + api_key: account.api_key.unwrap_or_default(), + })) +} - if api_key.trim().is_empty() { - continue; +pub fn load_provider_policy(store: &Store) -> Result { + let mut policy = if let Some(setting) = store.settings().get(PROVIDER_POLICY_SETTING_KEY)? 
{ + let policy: ProviderPolicy = serde_json::from_str(&setting.value_json) + .context("failed to parse provider_policy setting")?; + if policy.version != PROVIDER_POLICY_VERSION { + return Err(anyhow!( + "unsupported provider_policy version: {}", + policy.version + )); + } + policy + } else { + ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: Vec::new(), + model_slots: BTreeMap::new(), } + }; + add_legacy_policy_bridge(store, &mut policy)?; + Ok(policy) +} - return Ok(Some(OpenAiCompatibleConfig { - base_url, - model, - api_key, +pub fn save_provider_policy(store: &Store, policy: ProviderPolicy) -> Result { + validate_provider_policy(&policy)?; + let value = serde_json::to_value(&policy).context("failed to serialize provider policy")?; + let mut settings = serde_json::Map::new(); + settings.insert(PROVIDER_POLICY_SETTING_KEY.to_string(), value); + store.settings().upsert_many(settings)?; + load_provider_policy(store) +} + +pub fn resolve_model_slot(store: &Store, slot: &str) -> Result> { + if slot == SLOT_MEMORY_SEARCH { + return Ok(Some(ResolvedModelSlot { + slot: SLOT_MEMORY_SEARCH.to_string(), + provider_account: None, + model_id: "local_wiki".to_string(), + options: ModelSlotOptions::default(), + source: "default".to_string(), })); } - Ok(None) + let policy = load_provider_policy(store)?; + let Some(target) = policy.model_slots.get(slot) else { + return Ok(None); + }; + let account = match target.provider_account_id.as_deref() { + Some(account_id) => Some( + policy + .provider_accounts + .iter() + .find(|account| account.id == account_id) + .cloned() + .ok_or_else(|| { + anyhow!("model slot {slot} references missing provider account: {account_id}") + })?, + ), + None => None, + }; + Ok(Some(ResolvedModelSlot { + slot: slot.to_string(), + provider_account: account, + model_id: target.model_id.clone(), + options: target.options.clone(), + source: if target + .provider_account_id + .as_deref() + .is_some_and(|id| id.starts_with("legacy-")) 
+ { + "legacy_setting".to_string() + } else { + "provider_policy".to_string() + }, + })) +} + +pub fn validate_provider_policy(policy: &ProviderPolicy) -> Result<()> { + if policy.version != PROVIDER_POLICY_VERSION { + return Err(anyhow!( + "provider_policy version must be {}", + PROVIDER_POLICY_VERSION + )); + } + let mut account_ids = std::collections::BTreeSet::new(); + for account in &policy.provider_accounts { + if account.id.trim().is_empty() { + return Err(anyhow!("provider account id is required")); + } + if !account_ids.insert(account.id.as_str()) { + return Err(anyhow!("duplicate provider account id: {}", account.id)); + } + validate_provider_account(account)?; + } + for slot in [ + SLOT_CHAT, + SLOT_POST_TRANSCRIPT, + SLOT_PROACTIVE, + SLOT_VISION, + SLOT_STT, + SLOT_MEMORY_SEARCH, + ] { + if let Some(target) = policy.model_slots.get(slot) { + validate_model_slot_target(slot, target, &account_ids)?; + } + } + for slot in policy.model_slots.keys() { + if ![ + SLOT_CHAT, + SLOT_POST_TRANSCRIPT, + SLOT_PROACTIVE, + SLOT_VISION, + SLOT_STT, + SLOT_MEMORY_SEARCH, + ] + .contains(&slot.as_str()) + { + return Err(anyhow!("unsupported model slot: {slot}")); + } + } + Ok(()) } /// Settings keys validated on `PUT /v1/settings` for hybrid direct providers. @@ -154,6 +336,135 @@ pub const HYBRID_PROVIDER_SETTING_KEYS: &[&str] = &[ "vision_provider", ]; +fn add_legacy_policy_bridge(store: &Store, policy: &mut ProviderPolicy) -> Result<()> { + for (slot, keys) in LEGACY_SLOT_KEYS { + if policy.model_slots.contains_key(*slot) { + continue; + } + for key in *keys { + let Some(setting) = store.settings().get(key)? else { + continue; + }; + let value: Value = serde_json::from_str(&setting.value_json) + .with_context(|| format!("failed to parse {key} provider setting"))?; + if value.is_null() { + continue; + } + let Some(account) = legacy_provider_account(slot, key, &value)? 
else { + continue; + }; + let model_id = value["model"] + .as_str() + .unwrap_or("gpt-5.4-mini") + .to_string(); + policy.provider_accounts.push(account.clone()); + policy.model_slots.insert( + (*slot).to_string(), + ModelSlotTarget { + provider_account_id: Some(account.id), + model_id, + options: ModelSlotOptions { + json_mode: Some(matches!(*slot, SLOT_POST_TRANSCRIPT | SLOT_PROACTIVE)), + tool_support: None, + }, + }, + ); + break; + } + } + Ok(()) +} + +fn legacy_provider_account( + slot: &str, + key: &str, + value: &Value, +) -> Result> { + let kind = value["kind"].as_str().unwrap_or_default(); + if !is_openai_compatible_kind(kind) { + return Ok(None); + } + let base_url = value["base_url"] + .as_str() + .unwrap_or("https://api.openai.com/v1") + .to_string(); + let account = ProviderAccount { + id: format!("legacy-{slot}"), + kind: "openai_compatible".to_string(), + base_url: Some(base_url), + api_key: value["api_key"] + .as_str() + .or_else(|| value["key"].as_str()) + .map(ToString::to_string), + display_name: Some(format!("Legacy {key}")), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: false, + vision: slot == SLOT_VISION, + speech_to_text: slot == SLOT_STT, + }, + subscription_integration: value["subscription_integration"] + .as_str() + .map(ToString::to_string), + }; + validate_provider_account(&account)?; + Ok(Some(account)) +} + +fn validate_provider_account(account: &ProviderAccount) -> Result<()> { + if !is_openai_compatible_kind(&account.kind) { + return Ok(()); + } + let base_url = account + .base_url + .as_deref() + .unwrap_or("https://api.openai.com/v1"); + validate_provider_base_url(base_url)?; + let has_api_key = account + .api_key + .as_deref() + .is_some_and(|api_key| !api_key.trim().is_empty()); + let has_subscription_integration = account + .subscription_integration + .as_deref() + .is_some_and(|value| !value.trim().is_empty()); + if !has_api_key && !has_subscription_integration && 
!is_loopback_provider_base_url(base_url)? { + return Err(anyhow!( + "api_key or subscription_integration is required for non-loopback provider account {}", + account.id + )); + } + Ok(()) +} + +fn validate_model_slot_target( + slot: &str, + target: &ModelSlotTarget, + account_ids: &std::collections::BTreeSet<&str>, +) -> Result<()> { + if target.model_id.trim().is_empty() { + return Err(anyhow!("model slot {slot} requires model_id")); + } + if slot == SLOT_MEMORY_SEARCH && target.model_id != "local_wiki" { + return Err(anyhow!( + "memory_search must use local_wiki in this local profile" + )); + } + if let Some(account_id) = target.provider_account_id.as_deref() { + if !account_ids.contains(account_id) { + return Err(anyhow!( + "model slot {slot} references missing provider account: {account_id}" + )); + } + } else if slot != SLOT_MEMORY_SEARCH { + return Err(anyhow!( + "model slot {slot} requires provider_account_id unless it is memory_search" + )); + } + Ok(()) +} + pub fn validate_provider_setting(value: &Value) -> Result<()> { if value.is_null() { return Ok(()); @@ -166,7 +477,23 @@ pub fn validate_provider_setting(value: &Value) -> Result<()> { let base_url = value["base_url"] .as_str() .unwrap_or("https://api.openai.com/v1"); - validate_provider_base_url(base_url) + validate_provider_base_url(base_url)?; + let api_key = value["api_key"] + .as_str() + .or_else(|| value["key"].as_str()) + .unwrap_or_default(); + let subscription_integration = value["subscription_integration"] + .as_str() + .unwrap_or_default(); + if api_key.trim().is_empty() + && subscription_integration.trim().is_empty() + && !is_loopback_provider_base_url(base_url)? 
+ { + return Err(anyhow!( + "api_key or subscription_integration is required for non-loopback provider" + )); + } + Ok(()) } pub fn validate_hybrid_provider_setting(key: &str, value: &Value) -> Result<()> { @@ -197,6 +524,10 @@ pub fn is_provider_configured(value: &Value) -> bool { .is_some_and(|url| url.contains("127.0.0.1") || url.contains("localhost")) } +fn is_openai_compatible_kind(kind: &str) -> bool { + kind == "openai" || kind == "openai_compatible" +} + fn validate_provider_base_url(base_url: &str) -> Result<()> { let url = reqwest::Url::parse(base_url) .with_context(|| format!("provider base_url is not a valid URL: {base_url}"))?; @@ -218,6 +549,15 @@ fn validate_provider_base_url(base_url: &str) -> Result<()> { Ok(()) } +fn is_loopback_provider_base_url(base_url: &str) -> Result { + let url = reqwest::Url::parse(base_url) + .with_context(|| format!("provider base_url is not a valid URL: {base_url}"))?; + let Some(host) = url.host_str() else { + return Ok(false); + }; + Ok(matches!(host, "localhost" | "127.0.0.1" | "::1")) +} + pub async fn test_configured_provider(store: &Store, key: &str) -> Result { let Some(setting) = store.settings().get(key)? 
else { return Err(anyhow!("setting {key} is not configured")); @@ -229,7 +569,9 @@ pub async fn test_configured_provider(store: &Store, key: &str) -> Result Result<()> { + let store = Store::open_in_memory()?; + let mut slots = BTreeMap::new(); + slots.insert( + SLOT_POST_TRANSCRIPT.to_string(), + ModelSlotTarget { + provider_account_id: Some("local-ollama".to_string()), + model_id: "llama3.2".to_string(), + options: ModelSlotOptions { + json_mode: Some(true), + tool_support: Some(false), + }, + }, + ); + slots.insert( + SLOT_MEMORY_SEARCH.to_string(), + ModelSlotTarget { + provider_account_id: None, + model_id: "local_wiki".to_string(), + options: ModelSlotOptions::default(), + }, + ); + let policy = ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![ProviderAccount { + id: "local-ollama".to_string(), + kind: "openai_compatible".to_string(), + base_url: Some("http://127.0.0.1:11434/v1".to_string()), + api_key: None, + display_name: Some("Local Ollama".to_string()), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: false, + vision: false, + speech_to_text: false, + }, + subscription_integration: None, + }], + model_slots: slots, + }; + + let saved = save_provider_policy(&store, policy.clone())?; + assert_eq!(saved, policy); + + let loaded = load_provider_policy(&store)?; + assert_eq!(loaded, policy); + + let resolved = resolve_model_slot(&store, SLOT_POST_TRANSCRIPT)?.expect("slot"); + assert_eq!(resolved.model_id, "llama3.2"); + assert_eq!(resolved.source, "provider_policy"); + + Ok(()) + } + + #[test] + fn legacy_settings_resolve_to_typed_slots() -> Result<()> { + let store = Store::open_in_memory()?; + let mut settings = Map::new(); + settings.insert( + "chat_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "model": "chat-local" + }), + ); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + 
"base_url": "http://127.0.0.1:11434/v1", + "model": "post-local" + }), + ); + settings.insert( + "embedding_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "model": "legacy-embedding" + }), + ); + store.settings().upsert_many(settings)?; + + let chat = resolve_model_slot(&store, SLOT_CHAT)?.expect("chat slot"); + assert_eq!(chat.model_id, "chat-local"); + assert_eq!(chat.source, "legacy_setting"); + + let post_transcript = resolve_model_slot(&store, SLOT_POST_TRANSCRIPT)?.expect("post slot"); + assert_eq!(post_transcript.model_id, "post-local"); + + let memory_search = resolve_model_slot(&store, SLOT_MEMORY_SEARCH)?.expect("memory"); + assert_eq!(memory_search.provider_account, None); + assert_eq!(memory_search.model_id, "local_wiki"); + + Ok(()) + } + + #[test] + fn unresolved_slots_return_none() -> Result<()> { + let store = Store::open_in_memory()?; + + assert!(resolve_model_slot(&store, SLOT_PROACTIVE)?.is_none()); + assert!(resolve_model_slot(&store, SLOT_VISION)?.is_none()); + + Ok(()) + } + + #[test] + fn local_loopback_provider_does_not_require_api_key() -> Result<()> { + let store = Store::open_in_memory()?; + let mut settings = Map::new(); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://localhost:11434/v1", + "model": "local-no-key" + }), + ); + store.settings().upsert_many(settings)?; + + let config = load_openai_config(&store)?.expect("loopback provider should resolve"); + assert_eq!(config.base_url, "http://localhost:11434/v1"); + assert_eq!(config.model, "local-no-key"); + assert_eq!(config.api_key, ""); + + Ok(()) + } + + #[test] + fn non_loopback_provider_requires_key_or_subscription_integration() { + assert!(validate_provider_setting(&json!({ + "kind": "openai_compatible", + "base_url": "https://api.openai.com/v1", + "model": "gpt-5.4-mini" + })) + .is_err()); + + validate_provider_setting(&json!({ + "kind": 
"openai_compatible", + "base_url": "https://api.openai.com/v1", + "model": "gpt-5.4-mini", + "api_key": "key" + })) + .expect("api key should satisfy remote provider policy"); + + validate_provider_setting(&json!({ + "kind": "openai_compatible", + "base_url": "https://api.openai.com/v1", + "model": "gpt-5.4-mini", + "subscription_integration": "chatgpt_plan" + })) + .expect("subscription integration should satisfy remote provider policy"); + } + #[test] fn provider_validation_denies_omi_firebase_and_google_hosts() { for base_url in [ diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index b3e484ad79f..a0cdb19fe08 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -26,6 +26,14 @@ pub fn router() -> Router { .route("/v1/profile", get(get_profile).put(update_profile)) .route("/v1/settings", get(list_settings).put(update_settings)) .route("/v1/settings/test-provider", post(test_provider)) + .route( + "/v1/provider-policy", + get(get_provider_policy).put(update_provider_policy), + ) + .route( + "/v1/provider-policy/resolve/:slot", + get(resolve_provider_slot), + ) .route( "/v1/conversations", get(list_conversations).post(create_conversation), @@ -306,9 +314,7 @@ async fn update_conversation( Some(Value::Null) => Some(None), Some(Value::String(s)) => Some(Some(s)), Some(_) => { - return Err(ApiError::bad_request( - "folder_id must be a string or null", - )); + return Err(ApiError::bad_request("folder_id must be a string or null")); } }; if let Some(Some(ref fid)) = folder_id { @@ -872,6 +878,29 @@ async fn test_provider( }))) } +async fn get_provider_policy(State(state): State) -> ApiResult { + let policy = providers::load_provider_policy(&state.store).map_err(ApiError::internal)?; + Ok(Json(json!({ "provider_policy": policy }))) +} + +async fn update_provider_policy( + State(state): State, + Json(policy): Json, +) -> ApiResult { + let policy = providers::save_provider_policy(&state.store, 
policy) + .map_err(|error| ApiError::bad_request(error.to_string()))?; + Ok(Json(json!({ "provider_policy": policy }))) +} + +async fn resolve_provider_slot( + State(state): State, + Path(slot): Path, +) -> ApiResult { + let resolved = + providers::resolve_model_slot(&state.store, &slot).map_err(ApiError::internal)?; + Ok(Json(json!({ "resolved": resolved }))) +} + async fn list_processing_jobs(State(state): State) -> ApiResult { let jobs = state .store @@ -1084,12 +1113,7 @@ async fn list_chat_sessions( let sessions = state .store .chat_sessions() - .list_sessions( - limit, - offset, - query.app_id.as_deref(), - query.starred, - ) + .list_sessions(limit, offset, query.app_id.as_deref(), query.starred) .map_err(ApiError::internal)?; Ok(Json(sessions)) } diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index f1b88044317..35e4882cfe5 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -352,10 +352,7 @@ pub struct ConversationFolder { pub is_default: bool, #[serde(rename = "is_system")] pub is_system: bool, - #[serde( - rename = "category_mapping", - skip_serializing_if = "Option::is_none" - )] + #[serde(rename = "category_mapping", skip_serializing_if = "Option::is_none")] pub category_mapping: Option, #[serde(rename = "conversation_count")] pub conversation_count: i64, @@ -581,7 +578,11 @@ impl Store { } /// Merge [`source_ids.len()`] ≥ 2 conversations into one new conversation. 
- pub fn merge_conversations(&self, source_ids: &[String], _reprocess: bool) -> Result { + pub fn merge_conversations( + &self, + source_ids: &[String], + _reprocess: bool, + ) -> Result { if source_ids.len() < 2 { anyhow::bail!("at least two conversation_ids required"); } @@ -746,13 +747,7 @@ impl Store { sync_version = sync_version + 1 WHERE id = ?5 AND deleted_at IS NULL "#, - params![ - new_id, - seg.session_id, - idx_i, - now, - seg.id, - ], + params![new_id, seg.session_id, idx_i, now, seg.id,], ) .context("merge: reattach transcript segment")?; } @@ -810,12 +805,13 @@ impl FolderRepository { pub fn exists_active(&self, id: &str) -> Result { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); - let count: i64 = conn.query_row( - "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", - params![id], - |row| row.get(0), - ) - .context("folder exists")?; + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", + params![id], + |row| row.get(0), + ) + .context("folder exists")?; Ok(count > 0) } @@ -910,7 +906,8 @@ impl FolderRepository { ) .context("insert folder")?; drop(conn); - self.get(&new.id)?.ok_or_else(|| anyhow::anyhow!("folder missing after insert")) + self.get(&new.id)? 
+ .ok_or_else(|| anyhow::anyhow!("folder missing after insert")) } pub fn update(&self, id: &str, update: UpdateFolder) -> Result> { @@ -966,7 +963,9 @@ impl FolderRepository { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); if let Some(target) = move_to_folder_id { if target == id { - return Err(anyhow::anyhow!("move_to_folder_id must differ from deleted folder")); + return Err(anyhow::anyhow!( + "move_to_folder_id must differ from deleted folder" + )); } let exists: i64 = conn.query_row( "SELECT COUNT(*) FROM folders WHERE id = ?1 AND deleted_at IS NULL", @@ -2687,7 +2686,10 @@ impl ChatSessionsRepository { "#, )?; let rows = stmt - .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, limit, offset], map_row)? + .query_map( + params![LOCAL_DEFAULT_CHAT_SESSION_ID, limit, offset], + map_row, + )? .collect::, _>>()?; rows } @@ -2703,7 +2705,10 @@ impl ChatSessionsRepository { "#, )?; let rows = stmt - .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, aid, limit, offset], map_row)? + .query_map( + params![LOCAL_DEFAULT_CHAT_SESSION_ID, aid, limit, offset], + map_row, + )? .collect::, _>>()?; rows } @@ -2720,7 +2725,10 @@ impl ChatSessionsRepository { "#, )?; let rows = stmt - .query_map(params![LOCAL_DEFAULT_CHAT_SESSION_ID, st_i, limit, offset], map_row)? + .query_map( + params![LOCAL_DEFAULT_CHAT_SESSION_ID, st_i, limit, offset], + map_row, + )? .collect::, _>>()?; rows } @@ -2749,7 +2757,11 @@ impl ChatSessionsRepository { Ok(sessions) } - pub fn create_session(&self, title: Option<&str>, app_id: Option<&str>) -> Result { + pub fn create_session( + &self, + title: Option<&str>, + app_id: Option<&str>, + ) -> Result { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); let now = Utc::now(); let id = deterministic_id( @@ -2876,7 +2888,8 @@ impl ChatSessionsRepository { LIMIT ?2 OFFSET ?3 "#, )?; - let rows = stmt.query_map(params![session_id, limit, offset], Self::map_message_row)? 
+ let rows = stmt + .query_map(params![session_id, limit, offset], Self::map_message_row)? .collect::, _>>()?; rows } @@ -2890,9 +2903,12 @@ impl ChatSessionsRepository { LIMIT ?3 OFFSET ?4 "#, )?; - let rows = - stmt.query_map(params![session_id, aid, limit, offset], Self::map_message_row)? - .collect::, _>>()?; + let rows = stmt + .query_map( + params![session_id, aid, limit, offset], + Self::map_message_row, + )? + .collect::, _>>()?; rows } }; @@ -3360,13 +3376,19 @@ mod tests { )? .expect("update"); - let conv = store.conversations().get("conv-fld")?.expect("conversation"); + let conv = store + .conversations() + .get("conv-fld")? + .expect("conversation"); assert_eq!(conv.folder_id.as_deref(), Some("fld-crm")); store.folders().soft_delete("fld-crm", None)?; assert!(store.folders().get("fld-crm")?.is_none()); - let unfiled = store.conversations().get("conv-fld")?.expect("conversation"); + let unfiled = store + .conversations() + .get("conv-fld")? + .expect("conversation"); assert!(unfiled.folder_id.is_none()); Ok(()) From f07100db0855955eccda288cf1be3763ba5d6c46 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:35:36 +0700 Subject: [PATCH 32/58] Add local model catalog defaults --- desktop/local-backend/docs/architecture.md | 17 +- .../docs/hybrid-provider-settings.md | 39 +- .../local-backend/docs/local-mvp-runbook.md | 6 +- desktop/local-backend/src/main.rs | 30 +- desktop/local-backend/src/providers.rs | 527 ++++++++++++++++-- desktop/local-backend/src/routes.rs | 15 +- .../tools/seed_hybrid_defaults.sh | 130 +++-- 7 files changed, 646 insertions(+), 118 deletions(-) diff --git a/desktop/local-backend/docs/architecture.md b/desktop/local-backend/docs/architecture.md index d6239cd4934..0762ce2821e 100644 --- a/desktop/local-backend/docs/architecture.md +++ b/desktop/local-backend/docs/architecture.md @@ -20,9 +20,9 @@ local daemon's critical path. transcript rows, FTS, sync metadata fields, and local profile/settings. 
- `src/processing.rs` owns durable job execution, deterministic fallback processing, and output persistence. -- `src/providers.rs` owns direct provider adapters. The current adapter is - OpenAI-compatible chat completions and is configured only through local - settings. +- `src/providers.rs` owns the model catalog, provider-account/model-slot policy, + and direct provider adapters. The current adapter is OpenAI-compatible chat + completions and is configured only through local settings/policy. ## Data Directory And Database @@ -64,9 +64,14 @@ Remote AI/STT providers are allowed only when explicitly configured by the user or developer. The daemon talks directly to configured providers; it does not use Omi backend provider proxies. -The MVP includes an OpenAI-compatible chat completions adapter with local -settings for base URL, model, and API key. The processing pipeline still works -without any provider key by using deterministic fallbacks: +The MVP includes a subscription/profile-aware model catalog plus an +OpenAI-compatible chat completions adapter with local settings for base URL, +model, and API key. `post_transcript` and `proactive` default to the designated +small model `gpt-5.4-mini`; `chat` uses the profile/subscription default and is +configurable; `memory_search` defaults to `local_wiki` and requires no embedding +provider. A default model slot can exist without a usable provider account, in +which case resolution returns a readable reason and the processing pipeline uses +deterministic fallbacks: - title: first meaningful transcript words, bounded length - overview: clipped transcript excerpt diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index aaa781f0a98..9cb3ee940f6 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -9,6 +9,31 @@ never through Omi Python/Rust proxies. 
The current durable policy key is `provider_policy`. Older raw provider keys are still read as a compatibility bridge while desktop UI migrates to the typed policy API. +## Model catalog and provider policy + +The local daemon exposes a subscription/profile-aware catalog at +`GET /v1/model-catalog`. Each entry reports: + +- model ID and display name; +- compatible provider kinds and configured account IDs; +- allowed task slots; +- availability for the local profile/subscription state; +- capability flags for JSON mode, tool calls, multimodal input, streaming, and + local/remote origin. + +Local profile defaults are centralized in this catalog/policy layer: + +| Slot | Default model | Account behavior | +|------|---------------|------------------| +| `post_transcript` | `gpt-5.4-mini` | selected by default; unusable until a provider account is configured | +| `proactive` | `gpt-5.4-mini` | selected by default; unusable until a provider account is configured | +| `chat` | profile/subscription default (`gpt-5.4-mini` for this profile) | configurable | +| `memory_search` | `local_wiki` | always local; no embedding provider required | + +Slot resolution returns both the selected model and a readable reason. A slot can +therefore report a default model while also explaining that it cannot run because no +provider account or subscription integration is configured. + ## Provider policy `provider_policy` is versioned JSON: @@ -36,7 +61,7 @@ read as a compatibility bridge while desktop UI migrates to the typed policy API "model_slots": { "post_transcript": { "provider_account_id": "local-ollama", - "model_id": "llama3.2", + "model_id": "gpt-5.4-mini", "options": { "json_mode": true, "tool_support": false @@ -72,7 +97,9 @@ Desktop clients should prefer these daemon APIs over manual JSON editing: - `GET /v1/provider-policy` returns the typed policy plus legacy-derived slots. - `PUT /v1/provider-policy` validates and persists the typed policy. 
- `GET /v1/provider-policy/resolve/{slot}` resolves one slot to its provider account, - model, options, and source (`provider_policy`, `legacy_setting`, or `default`). + model, options, source (`provider_policy`, `legacy_setting`, or `default`), and a + readable success/failure reason. +- `GET /v1/model-catalog` returns the local model catalog and availability. Callers should resolve `post_transcript`, `proactive`, and `chat` slots explicitly instead of duplicating the legacy setting-key scan order. @@ -139,13 +166,15 @@ Default hybrid optional tiers: both cloud toggles off. `run.sh` local mode defau When the daemon starts via `make serve-local` or `desktop/run.sh` in local mode, `desktop/local-backend/tools/seed_hybrid_defaults.sh` runs idempotently: -- If `ai_provider` / `provider` is unset → sets OpenAI-compatible defaults. -- If `chat_provider` is unset → sets the same defaults. +- If `post_transcript`, `proactive`, or `chat` lacks a provider account, the script + creates/reuses a local OpenAI-compatible account and points those slots at it. +- `memory_search` remains `local_wiki`. | Variable | Default | |----------|---------| | `OMI_HYBRID_DEFAULT_CHAT_BASE_URL` | `http://127.0.0.1:11434/v1` | -| `OMI_HYBRID_DEFAULT_CHAT_MODEL` | `llama3.2` | +| `OMI_HYBRID_DEFAULT_CHAT_MODEL` | `gpt-5.4-mini` | +| `OMI_HYBRID_DEFAULT_PROVIDER_ACCOUNT_ID` | `local-openai-compatible` | The desktop app also calls `HybridProviderBootstrap.ensureDefaultsIfNeeded()` on local guest session startup. Chat resolves `chat_provider` → `ai_provider` → BYOK OpenAI diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 5dab8d0d8b0..5ee08116acf 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -71,8 +71,10 @@ machine. `"service":"omi-local-backend"`. 
- Hybrid providers: `make serve-local` and `desktop/run.sh` (local mode) run `desktop/local-backend/tools/seed_hybrid_defaults.sh` when the daemon is healthy, - seeding `ai_provider` and `chat_provider` to `http://127.0.0.1:11434/v1` (Ollama) if unset. - Override with `OMI_HYBRID_DEFAULT_CHAT_BASE_URL` and `OMI_HYBRID_DEFAULT_CHAT_MODEL`. + seeding `post_transcript`, `proactive`, and `chat` model slots to a local + OpenAI-compatible account if those slots lack provider accounts. Override with + `OMI_HYBRID_DEFAULT_CHAT_BASE_URL`, `OMI_HYBRID_DEFAULT_CHAT_MODEL`, and + `OMI_HYBRID_DEFAULT_PROVIDER_ACCOUNT_ID`. ### Manual `run.sh` launch diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 314c0b3ff1d..8232a67d3a7 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -366,9 +366,23 @@ mod tests { let policy = request_json(app.clone(), Method::GET, "/v1/provider-policy", None).await?; assert_eq!(policy["provider_policy"]["version"], 1); + assert_eq!( + policy["provider_policy"]["model_slots"]["proactive"]["model_id"], + "gpt-5.4-mini" + ); + + let catalog = request_json(app.clone(), Method::GET, "/v1/model-catalog", None).await?; + assert!( + catalog["models"] + .as_array() + .unwrap() + .iter() + .any(|model| model["id"] == "local_wiki" + && model["availability"]["available"] == true) + ); let resolved = request_json( - app, + app.clone(), Method::GET, "/v1/provider-policy/resolve/post_transcript", None, @@ -376,6 +390,20 @@ mod tests { .await?; assert_eq!(resolved["resolved"]["model_id"], "llama3.2"); assert_eq!(resolved["resolved"]["source"], "provider_policy"); + assert_eq!(resolved["resolution"]["ok"], true); + + let proactive = request_json( + app, + Method::GET, + "/v1/provider-policy/resolve/proactive", + None, + ) + .await?; + assert_eq!(proactive["resolution"]["ok"], false); + assert!(proactive["resolution"]["reason"] + .as_str() + .unwrap() + .contains("no provider account")); Ok(()) } 
diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index c64665d850f..1d8729476d3 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -54,6 +54,12 @@ pub const SLOT_VISION: &str = "vision"; pub const SLOT_STT: &str = "stt"; pub const SLOT_MEMORY_SEARCH: &str = "memory_search"; +pub const MODEL_GPT_5_4: &str = "gpt-5.4"; +pub const MODEL_GPT_5_4_MINI: &str = "gpt-5.4-mini"; +pub const MODEL_LLAMA_3_2: &str = "llama3.2"; +pub const MODEL_LOCAL_WIKI: &str = "local_wiki"; +pub const MODEL_WHISPER_1: &str = "whisper-1"; + const LEGACY_SLOT_KEYS: &[(&str, &[&str])] = &[ (SLOT_POST_TRANSCRIPT, &["ai_provider", "provider"]), (SLOT_CHAT, &["chat_provider"]), @@ -70,6 +76,39 @@ pub struct ProviderPolicy { pub model_slots: BTreeMap, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelCatalogEntry { + pub id: String, + pub display_name: String, + pub compatible_provider_kinds: Vec, + pub compatible_provider_account_ids: Vec, + pub allowed_slots: Vec, + pub availability: ModelAvailability, + pub capabilities: ModelCapabilities, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelAvailability { + pub local_profile: bool, + pub subscription_required: Option, + pub configured_account_required: bool, + pub available: bool, + pub reason: String, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelCapabilities { + #[serde(default)] + pub json_mode: bool, + #[serde(default)] + pub tool_calls: bool, + #[serde(default)] + pub multimodal_input: bool, + #[serde(default)] + pub streaming: bool, + pub origin: String, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ProviderAccount { pub id: String, @@ -119,6 +158,14 @@ pub struct ResolvedModelSlot { pub source: String, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ModelSlotResolution { 
+ pub slot: String, + pub ok: bool, + pub resolved: Option, + pub reason: String, +} + #[derive(Clone)] pub struct OpenAiCompatibleProvider { config: OpenAiCompatibleConfig, @@ -223,6 +270,7 @@ pub fn load_provider_policy(store: &Store) -> Result { } }; add_legacy_policy_bridge(store, &mut policy)?; + add_local_profile_defaults(&mut policy)?; Ok(policy) } @@ -236,19 +284,23 @@ pub fn save_provider_policy(store: &Store, policy: ProviderPolicy) -> Result Result> { - if slot == SLOT_MEMORY_SEARCH { - return Ok(Some(ResolvedModelSlot { - slot: SLOT_MEMORY_SEARCH.to_string(), - provider_account: None, - model_id: "local_wiki".to_string(), - options: ModelSlotOptions::default(), - source: "default".to_string(), - })); + let resolution = resolve_model_slot_result(store, slot)?; + if resolution.ok { + Ok(resolution.resolved) + } else { + Ok(None) } +} +pub fn resolve_model_slot_result(store: &Store, slot: &str) -> Result { let policy = load_provider_policy(store)?; let Some(target) = policy.model_slots.get(slot) else { - return Ok(None); + return Ok(ModelSlotResolution { + slot: slot.to_string(), + ok: false, + resolved: None, + reason: format!("no model slot configured for {slot}"), + }); }; let account = match target.provider_account_id.as_deref() { Some(account_id) => Some( @@ -263,7 +315,7 @@ pub fn resolve_model_slot(store: &Store, slot: &str) -> Result None, }; - Ok(Some(ResolvedModelSlot { + let resolved = ResolvedModelSlot { slot: slot.to_string(), provider_account: account, model_id: target.model_id.clone(), @@ -277,7 +329,28 @@ pub fn resolve_model_slot(store: &Store, slot: &str) -> Result Result<()> { @@ -307,6 +380,7 @@ pub fn validate_provider_policy(policy: &ProviderPolicy) -> Result<()> { ] { if let Some(target) = policy.model_slots.get(slot) { validate_model_slot_target(slot, target, &account_ids)?; + validate_model_allowed_for_slot(policy, slot, target)?; } } for slot in policy.model_slots.keys() { @@ -355,7 +429,7 @@ fn add_legacy_policy_bridge(store: 
&Store, policy: &mut ProviderPolicy) -> Resul }; let model_id = value["model"] .as_str() - .unwrap_or("gpt-5.4-mini") + .unwrap_or(MODEL_GPT_5_4_MINI) .to_string(); policy.provider_accounts.push(account.clone()); policy.model_slots.insert( @@ -375,6 +449,257 @@ fn add_legacy_policy_bridge(store: &Store, policy: &mut ProviderPolicy) -> Resul Ok(()) } +pub fn model_catalog(store: &Store) -> Result> { + let policy = load_provider_policy(store)?; + Ok(model_catalog_for_policy(&policy)) +} + +pub fn model_catalog_for_policy(policy: &ProviderPolicy) -> Vec { + let account_ids: Vec = policy + .provider_accounts + .iter() + .map(|account| account.id.clone()) + .collect(); + let mut entries = vec![ + catalog_entry( + MODEL_GPT_5_4_MINI, + "GPT-5.4 mini", + &["openai", "openai_compatible"], + &account_ids, + &[SLOT_CHAT, SLOT_POST_TRANSCRIPT, SLOT_PROACTIVE], + ModelAvailability { + local_profile: true, + subscription_required: Some("chatgpt_plan".to_string()), + configured_account_required: true, + available: !account_ids.is_empty(), + reason: account_availability_reason( + &account_ids, + "requires a configured OpenAI-compatible account or ChatGPT plan integration", + ), + }, + ModelCapabilities { + json_mode: true, + tool_calls: true, + multimodal_input: false, + streaming: true, + origin: "remote".to_string(), + }, + ), + catalog_entry( + MODEL_GPT_5_4, + "GPT-5.4", + &["openai", "openai_compatible"], + &account_ids, + &[SLOT_CHAT], + ModelAvailability { + local_profile: true, + subscription_required: Some("chatgpt_plan".to_string()), + configured_account_required: true, + available: !account_ids.is_empty(), + reason: account_availability_reason( + &account_ids, + "requires a configured OpenAI-compatible account or ChatGPT plan integration", + ), + }, + ModelCapabilities { + json_mode: true, + tool_calls: true, + multimodal_input: true, + streaming: true, + origin: "remote".to_string(), + }, + ), + catalog_entry( + MODEL_LLAMA_3_2, + "Llama 3.2", + 
&["openai_compatible"], + &account_ids, + &[SLOT_CHAT, SLOT_POST_TRANSCRIPT, SLOT_PROACTIVE], + ModelAvailability { + local_profile: true, + subscription_required: None, + configured_account_required: true, + available: !account_ids.is_empty(), + reason: account_availability_reason( + &account_ids, + "requires a configured OpenAI-compatible local account", + ), + }, + ModelCapabilities { + json_mode: true, + tool_calls: false, + multimodal_input: false, + streaming: true, + origin: "local".to_string(), + }, + ), + catalog_entry( + MODEL_WHISPER_1, + "Whisper", + &["openai", "openai_compatible"], + &account_ids, + &[SLOT_STT], + ModelAvailability { + local_profile: true, + subscription_required: None, + configured_account_required: true, + available: !account_ids.is_empty(), + reason: account_availability_reason( + &account_ids, + "requires a configured speech-to-text account", + ), + }, + ModelCapabilities { + json_mode: false, + tool_calls: false, + multimodal_input: false, + streaming: false, + origin: "remote".to_string(), + }, + ), + catalog_entry( + MODEL_LOCAL_WIKI, + "Local wiki search", + &["local"], + &[], + &[SLOT_MEMORY_SEARCH], + ModelAvailability { + local_profile: true, + subscription_required: None, + configured_account_required: false, + available: true, + reason: + "available from local SQLite/FTS memory search; no embedding provider required" + .to_string(), + }, + ModelCapabilities { + json_mode: false, + tool_calls: false, + multimodal_input: false, + streaming: false, + origin: "local".to_string(), + }, + ), + ]; + for (slot, target) in &policy.model_slots { + if entries.iter().any(|entry| entry.id == target.model_id) { + continue; + } + let Some(account_id) = target.provider_account_id.as_deref() else { + continue; + }; + let Some(account) = policy + .provider_accounts + .iter() + .find(|account| account.id == account_id) + else { + continue; + }; + if account + .base_url + .as_deref() + .is_some_and(|base_url| 
is_loopback_provider_base_url(base_url).unwrap_or(false)) + { + entries.push(catalog_entry( + &target.model_id, + &target.model_id, + &[account.kind.as_str()], + std::slice::from_ref(&account.id), + &[slot.as_str()], + ModelAvailability { + local_profile: true, + subscription_required: None, + configured_account_required: true, + available: true, + reason: "available through configured loopback provider account".to_string(), + }, + ModelCapabilities { + json_mode: account.capabilities.json_mode, + tool_calls: account.capabilities.tool_calls, + multimodal_input: account.capabilities.vision, + streaming: account.capabilities.chat_completions, + origin: "local".to_string(), + }, + )); + } + } + entries +} + +fn catalog_entry( + id: &str, + display_name: &str, + compatible_provider_kinds: &[&str], + compatible_provider_account_ids: &[String], + allowed_slots: &[&str], + availability: ModelAvailability, + capabilities: ModelCapabilities, +) -> ModelCatalogEntry { + ModelCatalogEntry { + id: id.to_string(), + display_name: display_name.to_string(), + compatible_provider_kinds: compatible_provider_kinds + .iter() + .map(|value| value.to_string()) + .collect(), + compatible_provider_account_ids: compatible_provider_account_ids.to_vec(), + allowed_slots: allowed_slots + .iter() + .map(|value| value.to_string()) + .collect(), + availability, + capabilities, + } +} + +fn account_availability_reason(account_ids: &[String], unavailable_reason: &str) -> String { + if account_ids.is_empty() { + unavailable_reason.to_string() + } else { + "available through configured provider policy account".to_string() + } +} + +fn add_local_profile_defaults(policy: &mut ProviderPolicy) -> Result<()> { + policy + .model_slots + .entry(SLOT_POST_TRANSCRIPT.to_string()) + .or_insert_with(|| default_model_slot_target(None, MODEL_GPT_5_4_MINI, true, false)); + policy + .model_slots + .entry(SLOT_PROACTIVE.to_string()) + .or_insert_with(|| default_model_slot_target(None, MODEL_GPT_5_4_MINI, true, 
false)); + policy + .model_slots + .entry(SLOT_CHAT.to_string()) + .or_insert_with(|| default_model_slot_target(None, MODEL_GPT_5_4_MINI, false, false)); + policy + .model_slots + .entry(SLOT_MEMORY_SEARCH.to_string()) + .or_insert_with(|| ModelSlotTarget { + provider_account_id: None, + model_id: MODEL_LOCAL_WIKI.to_string(), + options: ModelSlotOptions::default(), + }); + validate_provider_policy(policy) +} + +fn default_model_slot_target( + provider_account_id: Option<&str>, + model_id: &str, + json_mode: bool, + tool_support: bool, +) -> ModelSlotTarget { + ModelSlotTarget { + provider_account_id: provider_account_id.map(ToString::to_string), + model_id: model_id.to_string(), + options: ModelSlotOptions { + json_mode: Some(json_mode), + tool_support: Some(tool_support), + }, + } +} + fn legacy_provider_account( slot: &str, key: &str, @@ -457,11 +782,60 @@ fn validate_model_slot_target( "model slot {slot} references missing provider account: {account_id}" )); } - } else if slot != SLOT_MEMORY_SEARCH { + } + Ok(()) +} + +fn validate_model_allowed_for_slot( + policy: &ProviderPolicy, + slot: &str, + target: &ModelSlotTarget, +) -> Result<()> { + let catalog = model_catalog_for_policy(policy); + let entry = catalog + .iter() + .find(|entry| entry.id == target.model_id) + .ok_or_else(|| anyhow!("model slot {slot} uses unknown model: {}", target.model_id))?; + if !entry.allowed_slots.iter().any(|allowed| allowed == slot) { return Err(anyhow!( - "model slot {slot} requires provider_account_id unless it is memory_search" + "model {} is not allowed for slot {slot}", + target.model_id )); } + if let Some(account_id) = target.provider_account_id.as_deref() { + let account = policy + .provider_accounts + .iter() + .find(|account| account.id == account_id) + .ok_or_else(|| { + anyhow!("model slot {slot} references missing provider account: {account_id}") + })?; + if !entry + .compatible_provider_kinds + .iter() + .any(|kind| kind == &account.kind) + { + return Err(anyhow!( 
+ "model {} is not compatible with provider account {}", + target.model_id, + account.id + )); + } + if target.options.json_mode == Some(true) + && (!entry.capabilities.json_mode || !account.capabilities.json_mode) + { + return Err(anyhow!( + "model slot {slot} requires JSON mode but model/account does not support it" + )); + } + if target.options.tool_support == Some(true) + && (!entry.capabilities.tool_calls || !account.capabilities.tool_calls) + { + return Err(anyhow!( + "model slot {slot} requires tool calls but model/account does not support them" + )); + } + } Ok(()) } @@ -506,24 +880,6 @@ pub fn validate_hybrid_provider_setting(key: &str, value: &Value) -> Result<()> validate_provider_setting(value) } -pub fn is_provider_configured(value: &Value) -> bool { - if value.is_null() { - return false; - } - let kind = value["kind"].as_str().unwrap_or_default(); - if kind != "openai" && kind != "openai_compatible" { - return false; - } - let api_key = value["api_key"] - .as_str() - .or_else(|| value["key"].as_str()) - .unwrap_or_default(); - !api_key.trim().is_empty() - || value["base_url"] - .as_str() - .is_some_and(|url| url.contains("127.0.0.1") || url.contains("localhost")) -} - fn is_openai_compatible_kind(kind: &str) -> bool { kind == "openai" || kind == "openai_compatible" } @@ -730,10 +1086,17 @@ mod tests { }; let saved = save_provider_policy(&store, policy.clone())?; - assert_eq!(saved, policy); + assert_eq!( + saved.model_slots[SLOT_POST_TRANSCRIPT], + policy.model_slots[SLOT_POST_TRANSCRIPT] + ); + assert_eq!( + saved.model_slots[SLOT_PROACTIVE].model_id, + MODEL_GPT_5_4_MINI + ); let loaded = load_provider_policy(&store)?; - assert_eq!(loaded, policy); + assert_eq!(loaded, saved); let resolved = resolve_model_slot(&store, SLOT_POST_TRANSCRIPT)?.expect("slot"); assert_eq!(resolved.model_id, "llama3.2"); @@ -787,11 +1150,105 @@ mod tests { } #[test] - fn unresolved_slots_return_none() -> Result<()> { + fn 
default_slots_select_small_models_but_need_accounts() -> Result<()> { let store = Store::open_in_memory()?; assert!(resolve_model_slot(&store, SLOT_PROACTIVE)?.is_none()); assert!(resolve_model_slot(&store, SLOT_VISION)?.is_none()); + let proactive = resolve_model_slot_result(&store, SLOT_PROACTIVE)?; + assert!(!proactive.ok); + assert_eq!( + proactive + .resolved + .expect("default proactive model") + .model_id, + MODEL_GPT_5_4_MINI + ); + assert!(proactive.reason.contains("no provider account")); + + let memory_search = resolve_model_slot(&store, SLOT_MEMORY_SEARCH)?.expect("memory"); + assert_eq!(memory_search.model_id, MODEL_LOCAL_WIKI); + assert_eq!(memory_search.provider_account, None); + + Ok(()) + } + + #[test] + fn allowed_override_and_disallowed_slot_are_validated() -> Result<()> { + let account = ProviderAccount { + id: "openai-plan".to_string(), + kind: "openai_compatible".to_string(), + base_url: Some("https://api.openai.com/v1".to_string()), + api_key: None, + display_name: Some("OpenAI plan".to_string()), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: true, + vision: true, + speech_to_text: false, + }, + subscription_integration: Some("chatgpt_plan".to_string()), + }; + let mut slots = BTreeMap::new(); + slots.insert( + SLOT_CHAT.to_string(), + ModelSlotTarget { + provider_account_id: Some(account.id.clone()), + model_id: MODEL_GPT_5_4.to_string(), + options: ModelSlotOptions { + json_mode: Some(false), + tool_support: Some(true), + }, + }, + ); + let policy = ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![account.clone()], + model_slots: slots, + }; + validate_provider_policy(&policy)?; + + let mut bad_slots = BTreeMap::new(); + bad_slots.insert( + SLOT_POST_TRANSCRIPT.to_string(), + ModelSlotTarget { + provider_account_id: Some(account.id.clone()), + model_id: MODEL_WHISPER_1.to_string(), + options: ModelSlotOptions::default(), + }, + ); + let bad = ProviderPolicy { 
+ version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![account], + model_slots: bad_slots, + }; + assert!(validate_provider_policy(&bad) + .unwrap_err() + .to_string() + .contains("not allowed")); + + Ok(()) + } + + #[test] + fn model_catalog_reports_no_account_and_memory_search_needs_no_embeddings() -> Result<()> { + let store = Store::open_in_memory()?; + let catalog = model_catalog(&store)?; + let mini = catalog + .iter() + .find(|entry| entry.id == MODEL_GPT_5_4_MINI) + .expect("mini model"); + assert!(!mini.availability.available); + assert!(mini.availability.configured_account_required); + + let memory_search = catalog + .iter() + .find(|entry| entry.id == MODEL_LOCAL_WIKI) + .expect("local wiki model"); + assert!(memory_search.availability.available); + assert!(!memory_search.availability.configured_account_required); + assert_eq!(memory_search.capabilities.origin, "local"); Ok(()) } diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index a0cdb19fe08..76f898fa603 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -26,6 +26,7 @@ pub fn router() -> Router { .route("/v1/profile", get(get_profile).put(update_profile)) .route("/v1/settings", get(list_settings).put(update_settings)) .route("/v1/settings/test-provider", post(test_provider)) + .route("/v1/model-catalog", get(get_model_catalog)) .route( "/v1/provider-policy", get(get_provider_policy).put(update_provider_policy), @@ -883,6 +884,11 @@ async fn get_provider_policy(State(state): State) -> ApiResult Ok(Json(json!({ "provider_policy": policy }))) } +async fn get_model_catalog(State(state): State) -> ApiResult { + let catalog = providers::model_catalog(&state.store).map_err(ApiError::internal)?; + Ok(Json(json!({ "models": catalog }))) +} + async fn update_provider_policy( State(state): State, Json(policy): Json, @@ -896,9 +902,12 @@ async fn resolve_provider_slot( State(state): State, Path(slot): Path, ) -> ApiResult { - 
let resolved = - providers::resolve_model_slot(&state.store, &slot).map_err(ApiError::internal)?; - Ok(Json(json!({ "resolved": resolved }))) + let resolution = + providers::resolve_model_slot_result(&state.store, &slot).map_err(ApiError::internal)?; + Ok(Json(json!({ + "resolved": resolution.resolved, + "resolution": resolution + }))) } async fn list_processing_jobs(State(state): State) -> ApiResult { diff --git a/desktop/local-backend/tools/seed_hybrid_defaults.sh b/desktop/local-backend/tools/seed_hybrid_defaults.sh index b4b6a9cd776..433518b2286 100755 --- a/desktop/local-backend/tools/seed_hybrid_defaults.sh +++ b/desktop/local-backend/tools/seed_hybrid_defaults.sh @@ -1,91 +1,89 @@ #!/usr/bin/env bash -# Idempotent: seed ai_provider + chat_provider on the local daemon when unset. +# Idempotent: seed local model slots on the local daemon when they lack accounts. set -euo pipefail BASE_URL="${OMI_LOCAL_DAEMON_URL:-http://127.0.0.1:8765}" BASE_URL="${BASE_URL%/}" PROVIDER_BASE="${OMI_HYBRID_DEFAULT_CHAT_BASE_URL:-http://127.0.0.1:11434/v1}" -MODEL="${OMI_HYBRID_DEFAULT_CHAT_MODEL:-llama3.2}" +MODEL="${OMI_HYBRID_DEFAULT_CHAT_MODEL:-gpt-5.4-mini}" +ACCOUNT_ID="${OMI_HYBRID_DEFAULT_PROVIDER_ACCOUNT_ID:-local-openai-compatible}" if ! 
curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then echo "seed_hybrid_defaults: daemon not healthy at ${BASE_URL}/health" >&2 exit 1 fi -settings_json="$(curl -fsS "${BASE_URL}/v1/settings")" +policy_json="$(curl -fsS "${BASE_URL}/v1/provider-policy")" -has_key() { - local key="$1" - echo "$settings_json" | python3 -c " -import json, sys -key = sys.argv[1] -data = json.load(sys.stdin) -for s in data.get('settings', []): - if s.get('key') != key: +seed_body="$(POLICY_JSON="$policy_json" python3 - "$PROVIDER_BASE" "$MODEL" "$ACCOUNT_ID" <<'PY' +import json +import os +import sys + +provider_base, model, account_id = sys.argv[1:4] +data = json.loads(os.environ["POLICY_JSON"]) +policy = data.get("provider_policy") or {"version": 1} +accounts = policy.setdefault("provider_accounts", []) +slots = policy.setdefault("model_slots", {}) + +account = next((a for a in accounts if a.get("id") == account_id), None) +changed = False +if account is None: + accounts.append({ + "id": account_id, + "kind": "openai_compatible", + "base_url": provider_base, + "api_key": None, + "display_name": "Local OpenAI-compatible", + "capabilities": { + "chat_completions": True, + "json_mode": True, + "tool_calls": False, + "vision": False, + "speech_to_text": False, + }, + "subscription_integration": None, + }) + changed = True + +for slot, json_mode in ( + ("post_transcript", True), + ("proactive", True), + ("chat", False), +): + current = slots.get(slot) or {} + if current.get("provider_account_id"): continue - raw = s.get('value_json') or '' - if not raw or raw == 'null': - sys.exit(1) - try: - v = json.loads(raw) - except json.JSONDecodeError: - sys.exit(1) - if isinstance(v, dict) and v.get('base_url'): - sys.exit(0) -sys.exit(1) -" "$key" -} + slots[slot] = { + "provider_account_id": account_id, + "model_id": model, + "options": { + "json_mode": json_mode, + "tool_support": False, + }, + } + changed = True -provider_payload="$(python3 - < ${PROVIDER_BASE}" -fi - -if ! 
has_key chat_provider; then - if [ "$updated" -eq 1 ]; then - body="$(echo "$body" | python3 -c " -import json, sys -p = json.loads(sys.stdin.read()) -chat = json.loads('''${provider_payload}''') -p['chat_provider'] = chat -print(json.dumps(p)) -")" - else - body="$(echo "$provider_payload" | python3 -c " -import json, sys -p = json.load(sys.stdin) -print(json.dumps({'chat_provider': p})) -")" - updated=1 - fi - echo "seed_hybrid_defaults: will set chat_provider -> ${PROVIDER_BASE}" -fi +changed="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.load(sys.stdin)["changed"])')" -if [ "$updated" -eq 0 ]; then - echo "seed_hybrid_defaults: ai_provider and chat_provider already configured" +if [ "$changed" != "True" ]; then + echo "seed_hybrid_defaults: model slots already have provider accounts" exit 0 fi -curl -fsS -X PUT "${BASE_URL}/v1/settings" \ +body="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.dumps(json.load(sys.stdin)["policy"]))')" +curl -fsS -X PUT "${BASE_URL}/v1/provider-policy" \ -H 'content-type: application/json' \ -d "$body" >/dev/null -echo "seed_hybrid_defaults: done" +echo "seed_hybrid_defaults: seeded ${ACCOUNT_ID} -> ${PROVIDER_BASE} (${MODEL})" From cb6df9f897d80814bcefc272f82d9f964cb4009e Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:42:53 +0700 Subject: [PATCH 33/58] Run local processing through post transcript slot --- desktop/local-backend/README.md | 7 +- .../docs/hybrid-provider-settings.md | 63 +++ .../local-backend/docs/local-mvp-runbook.md | 9 +- desktop/local-backend/src/processing.rs | 419 ++++++++++++++++-- desktop/local-backend/src/providers.rs | 49 +- desktop/local-backend/tools/e2e_smoke.sh | 5 +- 6 files changed, 506 insertions(+), 46 deletions(-) diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index de3ab5fa0e3..a0f8c2e6f65 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -131,9 +131,10 @@ commands for the 
imported conversation. Stable client conversation, memory, and action-item IDs are idempotent: exact replay returns the existing row, while a conflicting replay returns HTTP 409. -Local processing uses deterministic fallback unless an OpenAI-compatible -provider is configured through structured `PUT /v1/settings` JSON; see the -runbook for local-stub set and clear commands. +Local processing resolves the `post_transcript` model slot from +`/v1/provider-policy`. Without a resolved provider account, it records +deterministic fallback metadata; with an OpenAI-compatible account it persists +model-derived title, overview, memories, action items, and provenance. ## Architecture And E2E Validation diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index 9cb3ee940f6..ce02d4a99e3 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -104,6 +104,69 @@ Desktop clients should prefer these daemon APIs over manual JSON editing: Callers should resolve `post_transcript`, `proactive`, and `chat` slots explicitly instead of duplicating the legacy setting-key scan order. +## Post-transcript processing + +Finalized transcript jobs resolve `/v1/provider-policy/resolve/post_transcript` +inside the daemon. When the slot resolves to an `openai_compatible` account, the +daemon asks that model for strict JSON containing: + +- `title` +- `overview` +- `action_items`: array of `{ "title", "description" }` +- `memories`: array of `{ "content", "category" }` + +Successful jobs persist title/overview on the conversation, replace prior +local-processing memories/action items for that conversation, and record +`local_processing` provenance metadata with the job ID, slot source, provider +account, and model ID. 
Malformed model JSON is treated as a provider failure, so +the durable job retry counter advances and eventually leaves an inspectable +failed job. + +When the slot has no usable provider account, the daemon completes with +deterministic fallback title/overview, conversation status `processed_fallback`, +and job result metadata containing `mode: "fallback"` plus the slot resolution +reason. + +Minimal local stub policy: + +```bash +curl -fsS -X PUT http://127.0.0.1:8765/v1/provider-policy \ + -H 'content-type: application/json' \ + -d '{ + "version": 1, + "provider_accounts": [{ + "id": "local-openai-compatible", + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "api_key": null, + "display_name": "Local OpenAI-compatible", + "capabilities": { + "chat_completions": true, + "json_mode": true, + "tool_calls": false, + "vision": false, + "speech_to_text": false + }, + "subscription_integration": null + }], + "model_slots": { + "post_transcript": { + "provider_account_id": "local-openai-compatible", + "model_id": "gpt-5.4-mini", + "options": { "json_mode": true, "tool_support": false } + } + } + }' +``` + +After importing a transcript, inspect: + +```bash +curl -fsS http://127.0.0.1:8765/v1/provider-policy/resolve/post_transcript +curl -fsS http://127.0.0.1:8765/v1/processing-jobs/status +curl -fsS http://127.0.0.1:8765/v1/conversations/ +``` + ## Settings keys | Key | Purpose | `kind` values (v1) | diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 5ee08116acf..3fe2426c5cc 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -328,7 +328,8 @@ directly to the configured provider, not to Omi-hosted backend services. - Local daemon startup on loopback. - SQLite-backed conversation create/read/update/delete. - Transcript segment append and finalize. -- Local fallback processing for title and overview. 
+- Local post-transcript processing through the `post_transcript` slot, with + deterministic fallback metadata when no provider account is configured. - Local full-text search over conversation and transcript text. - Local memories, action items, profile, and settings endpoints. - Desktop routing for local MVP flows without Firebase auth. @@ -346,8 +347,10 @@ directly to the configured provider, not to Omi-hosted backend services. - Proactive assistant and chat paths that currently depend on Omi-hosted Gemini/Anthropic/provider proxy endpoints are disabled in local daemon mode unless the path has direct local provider configuration. -- Remote AI provider calls from the local daemon require explicit local provider - settings/API keys. Without them, processing uses deterministic fallback output. +- Remote AI provider calls from the local daemon require an explicit provider + account and model slot in `/v1/provider-policy` or compatible legacy settings. + Without a resolved `post_transcript` slot provider, processing uses + deterministic fallback output and records the fallback reason in job metadata. - Fully offline local LLM/STT support is outside the current MVP. 
## Known Environment Blockers diff --git a/desktop/local-backend/src/processing.rs b/desktop/local-backend/src/processing.rs index a1141d7501c..334f13239ef 100644 --- a/desktop/local-backend/src/processing.rs +++ b/desktop/local-backend/src/processing.rs @@ -6,7 +6,10 @@ use serde_json::{json, Value}; use tokio::time; use crate::{ - providers::{configured_openai_provider, ChatMessage}, + providers::{ + configured_openai_provider_for_slot, post_transcript_slot_resolution, ChatMessage, + ResolvedModelSlot, SLOT_POST_TRANSCRIPT, + }, storage::{ deterministic_id, NewActionItem, NewMemory, ProcessingJob, ProcessingJobStatus, Store, UpdateConversation, @@ -94,18 +97,34 @@ async fn process_conversation_job(store: &Store, job: &ProcessingJob) -> Result< .map(|segment| segment.text.as_str()) .collect::>() .join(" "); - let output = if let Some(provider) = configured_openai_provider(store)? { + let resolution = post_transcript_slot_resolution(store)?; + let (output, metadata) = if resolution.ok { + let provider = configured_openai_provider_for_slot(store, SLOT_POST_TRANSCRIPT)? 
+ .ok_or_else(|| anyhow!("post_transcript slot resolved to an unsupported provider"))?; let mut output = provider + .provider .complete_json(processing_prompt(&transcript)) .await .and_then(parse_provider_output)?; - output.provider = "openai_compatible".to_string(); - output + let account = provider + .slot + .provider_account + .as_ref() + .ok_or_else(|| anyhow!("post_transcript slot missing provider account"))?; + output.provider = account.kind.clone(); + ( + output, + provider_metadata(job, conversation_id, &provider.slot), + ) } else { - fallback_output(&transcript) + let output = fallback_output(&transcript); + ( + output, + fallback_metadata(job, conversation_id, &resolution.reason), + ) }; - persist_processing_output(store, conversation_id, &output)?; + persist_processing_output(store, conversation_id, &output, &metadata)?; Ok(json!({ "conversation_id": conversation_id, @@ -113,7 +132,8 @@ async fn process_conversation_job(store: &Store, job: &ProcessingJob) -> Result< "overview": output.overview, "action_items": output.action_items, "memories": output.memories, - "provider": output.provider + "provider": output.provider, + "metadata": metadata })) } @@ -129,20 +149,12 @@ fn processing_prompt(transcript: &str) -> Vec { } fn parse_provider_output(value: Value) -> Result { - let title = value["title"] - .as_str() - .unwrap_or_default() - .trim() - .to_string(); - let overview = value["overview"] - .as_str() - .unwrap_or_default() - .trim() - .to_string(); + let title = required_string(&value, "title")?; + let overview = required_string(&value, "overview")?; let action_items = value["action_items"] .as_array() - .into_iter() - .flatten() + .ok_or_else(|| anyhow!("provider output missing action_items array"))? 
+ .iter() .filter_map(|item| { let title = item["title"].as_str()?.trim().to_string(); if title.is_empty() { @@ -160,8 +172,8 @@ fn parse_provider_output(value: Value) -> Result { .collect(); let memories = value["memories"] .as_array() - .into_iter() - .flatten() + .ok_or_else(|| anyhow!("provider output missing memories array"))? + .iter() .filter_map(|item| { let content = item["content"].as_str()?.trim().to_string(); if content.is_empty() { @@ -185,6 +197,18 @@ fn parse_provider_output(value: Value) -> Result { }) } +fn required_string(value: &Value, key: &str) -> Result { + let value = value[key] + .as_str() + .ok_or_else(|| anyhow!("provider output missing {key}"))? + .trim() + .to_string(); + if value.is_empty() { + return Err(anyhow!("provider output {key} was empty")); + } + Ok(value) +} + pub fn fallback_output(transcript: &str) -> ProcessingOutput { let normalized = normalize_whitespace(transcript); ProcessingOutput { @@ -220,7 +244,19 @@ fn persist_processing_output( store: &Store, conversation_id: &str, output: &ProcessingOutput, + processing_metadata: &Value, ) -> Result<()> { + let conversation = store + .conversations() + .get(conversation_id)? 
+ .ok_or_else(|| anyhow!("conversation missing while persisting processing output"))?; + let conversation_metadata = + merge_processing_metadata(&conversation.metadata_json, processing_metadata); + let status = if output.provider == "fallback" { + "processed_fallback" + } else { + "processed" + }; store .conversations() .update( @@ -228,9 +264,9 @@ fn persist_processing_output( UpdateConversation { title: Some(output.title.clone()), overview: Some(output.overview.clone()), - status: Some("processed".to_string()), + status: Some(status.to_string()), ended_at: None, - metadata: None, + metadata: Some(conversation_metadata), starred: None, folder_id: None, }, @@ -257,7 +293,7 @@ fn persist_processing_output( description: Some(item.description.clone()), status: Some("open".to_string()), due_at: None, - metadata: Some(json!({"source": "local_processing"})), + metadata: Some(row_processing_metadata(processing_metadata)), })?; } @@ -282,20 +318,76 @@ fn persist_processing_output( content: memory.content.clone(), category: memory.category.clone(), conversation_id: Some(conversation_id.to_string()), - metadata: Some(json!({"source": "local_processing"})), + metadata: Some(row_processing_metadata(processing_metadata)), })?; } Ok(()) } +fn provider_metadata( + job: &ProcessingJob, + conversation_id: &str, + slot: &ResolvedModelSlot, +) -> Value { + let account = slot.provider_account.as_ref(); + json!({ + "source": "local_processing", + "mode": "model", + "conversation_id": conversation_id, + "job_id": job.id, + "job_kind": job.kind, + "slot": slot.slot, + "slot_source": slot.source, + "model_id": slot.model_id, + "provider_account_id": account.map(|account| account.id.clone()), + "provider_kind": account.map(|account| account.kind.clone()), + "options": slot.options.clone() + }) +} + +fn fallback_metadata(job: &ProcessingJob, conversation_id: &str, reason: &str) -> Value { + json!({ + "source": "local_processing", + "mode": "fallback", + "conversation_id": conversation_id, + 
"job_id": job.id, + "job_kind": job.kind, + "slot": SLOT_POST_TRANSCRIPT, + "fallback_reason": reason + }) +} + +fn row_processing_metadata(processing_metadata: &Value) -> Value { + json!({ + "source": "local_processing", + "local_processing": processing_metadata + }) +} + +fn merge_processing_metadata(existing_metadata_json: &str, processing_metadata: &Value) -> Value { + let mut metadata = + serde_json::from_str::(existing_metadata_json).unwrap_or_else(|_| json!({})); + if !metadata.is_object() { + metadata = json!({}); + } + metadata["local_processing"] = processing_metadata.clone(); + metadata +} + #[cfg(test)] mod tests { + use crate::providers::{ + save_provider_policy, ModelSlotOptions, ModelSlotTarget, ProviderAccount, + ProviderCapabilities, ProviderPolicy, PROVIDER_POLICY_VERSION, + }; use crate::storage::{NewConversation, NewProcessingJob, NewTranscriptSegment}; use super::*; + use axum::{routing::post, Json, Router}; use serde_json::Map; - use std::net::TcpListener; + use std::{collections::BTreeMap, net::TcpListener}; + use tokio::net::TcpListener as TokioTcpListener; #[test] fn fallback_processing_is_deterministic_and_empty_for_items_and_memories() { @@ -364,7 +456,7 @@ mod tests { conversation.title, "Plan the desktop local backend MVP and verify" ); - assert_eq!(conversation.status, "processed"); + assert_eq!(conversation.status, "processed_fallback"); assert!(conversation .overview .starts_with("Plan the desktop local backend MVP")); @@ -373,6 +465,137 @@ mod tests { let result: Value = serde_json::from_str(&job.result_json)?; assert_eq!(result["provider"], "fallback"); + assert_eq!(result["metadata"]["mode"], "fallback"); + assert_eq!(result["metadata"]["slot"], SLOT_POST_TRANSCRIPT); + + let conversation_metadata: Value = serde_json::from_str(&conversation.metadata_json)?; + assert_eq!( + conversation_metadata["local_processing"]["mode"], + "fallback" + ); + + Ok(()) + } + + #[tokio::test] + async fn 
post_transcript_slot_provider_persists_model_outputs() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-slot-processing"]); + configure_post_transcript_stub_provider( + &store, + &spawn_processing_stub(json!({ + "title": "Slot generated title", + "overview": "Slot generated overview", + "action_items": [{ + "title": "Review slot wiring", + "description": "Confirm post transcript processing uses provider policy." + }], + "memories": [{ + "content": "User wants local post transcript processing through slots.", + "category": "preference" + }] + })) + .await?, + "slot-model", + )?; + + seed_conversation_with_segment( + &store, + &conversation_id, + "session-slot-processing", + "Use the configured slot provider for local processing.", + )?; + store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["slot-processing", &conversation_id]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(3), + payload: Some(json!({"conversation_id": conversation_id})), + })?; + + let job = process_next_job(&store) + .await? + .expect("queued job should be processed"); + assert_eq!(job.status, ProcessingJobStatus::Completed); + + let conversation = store + .conversations() + .get(job.target_conversation_id.as_ref().unwrap())? 
+ .expect("conversation should exist"); + assert_eq!(conversation.title, "Slot generated title"); + assert_eq!(conversation.overview, "Slot generated overview"); + assert_eq!(conversation.status, "processed"); + + let action_items = store.action_items().list()?; + assert_eq!(action_items.len(), 1); + assert_eq!(action_items[0].title, "Review slot wiring"); + let action_metadata: Value = serde_json::from_str(&action_items[0].metadata_json)?; + assert_eq!( + action_metadata["local_processing"]["model_id"], + "slot-model" + ); + + let memories = store.memories().list()?; + assert_eq!(memories.len(), 1); + assert_eq!( + memories[0].content, + "User wants local post transcript processing through slots." + ); + + let result: Value = serde_json::from_str(&job.result_json)?; + assert_eq!(result["metadata"]["mode"], "model"); + assert_eq!(result["metadata"]["slot"], SLOT_POST_TRANSCRIPT); + assert_eq!(result["metadata"]["model_id"], "slot-model"); + assert_eq!(result["metadata"]["slot_source"], "provider_policy"); + + Ok(()) + } + + #[tokio::test] + async fn malformed_provider_json_fails_and_is_retry_safe() -> Result<()> { + let store = Store::open_in_memory()?; + let conversation_id = deterministic_id("conv", &["session-malformed-provider"]); + configure_post_transcript_stub_provider( + &store, + &spawn_processing_stub(json!({ + "title": "Missing arrays", + "overview": "This should not be accepted." 
+ })) + .await?, + "slot-model", + )?; + + seed_conversation_with_segment( + &store, + &conversation_id, + "session-malformed-provider", + "Malformed model JSON should fail instead of empty-success processing.", + )?; + store.processing_jobs().enqueue(NewProcessingJob { + id: deterministic_id("job", &["malformed-provider", &conversation_id]), + kind: "finalize_transcript".to_string(), + target_conversation_id: Some(conversation_id.clone()), + max_retries: Some(1), + payload: Some(json!({"conversation_id": conversation_id})), + })?; + + let job = process_next_job(&store) + .await? + .expect("failed job should be returned"); + assert_eq!(job.status, ProcessingJobStatus::Failed); + assert!(job + .last_error + .as_deref() + .unwrap_or("") + .contains("action_items array")); + + let conversation = store + .conversations() + .get(job.target_conversation_id.as_ref().unwrap())? + .expect("conversation should exist"); + assert_eq!(conversation.status, "open"); + assert!(store.action_items().list()?.is_empty()); + assert!(store.memories().list()?.is_empty()); Ok(()) } @@ -482,8 +705,21 @@ mod tests { provider: "openai_compatible".to_string(), }; - persist_processing_output(&store, &conversation_id, &output)?; - persist_processing_output(&store, &conversation_id, &output)?; + let metadata = json!({ + "source": "local_processing", + "mode": "model", + "conversation_id": conversation_id, + "job_id": "job-provider-retry-1", + "job_kind": "finalize_transcript", + "slot": SLOT_POST_TRANSCRIPT, + "slot_source": "provider_policy", + "model_id": "first-model", + "provider_account_id": "local-openai", + "provider_kind": "openai_compatible" + }); + + persist_processing_output(&store, &conversation_id, &output, &metadata)?; + persist_processing_output(&store, &conversation_id, &output, &metadata)?; let action_items = store.action_items().list()?; let memories = store.memories().list()?; @@ -505,7 +741,24 @@ mod tests { memories: Vec::new(), provider: "openai_compatible".to_string(), }; 
- persist_processing_output(&store, &conversation_id, &replacement)?; + let replacement_metadata = json!({ + "source": "local_processing", + "mode": "model", + "conversation_id": conversation_id, + "job_id": "job-provider-retry-2", + "job_kind": "finalize_transcript", + "slot": SLOT_POST_TRANSCRIPT, + "slot_source": "provider_policy", + "model_id": "replacement-model", + "provider_account_id": "local-openai", + "provider_kind": "openai_compatible" + }); + persist_processing_output( + &store, + &conversation_id, + &replacement, + &replacement_metadata, + )?; let action_items = store.action_items().list()?; let memories = store.memories().list()?; @@ -513,6 +766,112 @@ mod tests { assert_eq!(action_items[0].title, "Ship retry behavior"); assert!(memories.is_empty()); + let action_metadata: Value = serde_json::from_str(&action_items[0].metadata_json)?; + assert_eq!( + action_metadata["local_processing"]["model_id"], + "replacement-model" + ); + Ok(()) } + + fn seed_conversation_with_segment( + store: &Store, + conversation_id: &str, + session_id: &str, + text: &str, + ) -> Result<()> { + store.conversations().create(NewConversation { + id: conversation_id.to_string(), + session_id: session_id.to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + store.transcripts().append(NewTranscriptSegment { + id: deterministic_id("seg", &[conversation_id, "0"]), + conversation_id: conversation_id.to_string(), + session_id: session_id.to_string(), + speaker_id: None, + speaker_label: None, + text: text.to_string(), + start_ms: 0, + end_ms: 1000, + segment_index: 0, + source: None, + metadata: None, + })?; + Ok(()) + } + + fn configure_post_transcript_stub_provider( + store: &Store, + base_url: &str, + model_id: &str, + ) -> Result<()> { + let account = ProviderAccount { + id: "slot-stub".to_string(), + kind: "openai_compatible".to_string(), + base_url: Some(base_url.to_string()), + api_key: Some("local-test-key".to_string()), 
+ display_name: Some("Slot Stub".to_string()), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: false, + vision: false, + speech_to_text: false, + }, + subscription_integration: None, + }; + let mut slots = BTreeMap::new(); + slots.insert( + SLOT_POST_TRANSCRIPT.to_string(), + ModelSlotTarget { + provider_account_id: Some(account.id.clone()), + model_id: model_id.to_string(), + options: ModelSlotOptions { + json_mode: Some(true), + tool_support: Some(false), + }, + }, + ); + save_provider_policy( + store, + ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![account], + model_slots: slots, + }, + )?; + Ok(()) + } + + async fn spawn_processing_stub(content: Value) -> Result { + let content = serde_json::to_string(&content)?; + let app = Router::new().route( + "/v1/chat/completions", + post(move || { + let content = content.clone(); + async move { + Json(json!({ + "choices": [{ + "message": { + "content": content + } + }] + })) + } + }), + ); + let listener = TokioTcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + tokio::spawn(async move { + axum::serve(listener, app) + .await + .expect("stub server failed"); + }); + Ok(format!("http://{addr}/v1")) + } } diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 1d8729476d3..7a987ae0dd5 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -172,6 +172,11 @@ pub struct OpenAiCompatibleProvider { client: Client, } +pub struct ResolvedOpenAiCompatibleProvider { + pub provider: OpenAiCompatibleProvider, + pub slot: ResolvedModelSlot, +} + impl OpenAiCompatibleProvider { pub fn new(config: OpenAiCompatibleConfig) -> Self { Self { @@ -223,32 +228,60 @@ impl OpenAiCompatibleProvider { } } -pub fn configured_openai_provider(store: &Store) -> Result> { - let Some(config) = load_openai_config(store)? 
else { +pub fn configured_openai_provider_for_slot( + store: &Store, + slot: &str, +) -> Result> { + let Some(resolved) = resolve_model_slot(store, slot)? else { + return Ok(None); + }; + let Some(account) = resolved.provider_account.as_ref() else { return Ok(None); }; - Ok(Some(OpenAiCompatibleProvider::new(config))) + if !is_openai_compatible_kind(&account.kind) { + return Ok(None); + } + let config = openai_config_from_resolved_slot(&resolved)?; + Ok(Some(ResolvedOpenAiCompatibleProvider { + provider: OpenAiCompatibleProvider::new(config), + slot: resolved, + })) } pub fn load_openai_config(store: &Store) -> Result> { let Some(resolved) = resolve_model_slot(store, SLOT_POST_TRANSCRIPT)? else { return Ok(None); }; - let Some(account) = resolved.provider_account else { + let Some(account) = resolved.provider_account.as_ref() else { return Ok(None); }; if !is_openai_compatible_kind(&account.kind) { return Ok(None); } + openai_config_from_resolved_slot(&resolved).map(Some) +} + +fn openai_config_from_resolved_slot( + resolved: &ResolvedModelSlot, +) -> Result { + let account = resolved + .provider_account + .as_ref() + .ok_or_else(|| anyhow!("resolved model slot missing provider account"))?; let base_url = account .base_url + .clone() .unwrap_or_else(|| "https://api.openai.com/v1".to_string()); validate_provider_base_url(&base_url)?; - Ok(Some(OpenAiCompatibleConfig { + Ok(OpenAiCompatibleConfig { base_url, - model: resolved.model_id, - api_key: account.api_key.unwrap_or_default(), - })) + model: resolved.model_id.clone(), + api_key: account.api_key.clone().unwrap_or_default(), + }) +} + +pub fn post_transcript_slot_resolution(store: &Store) -> Result { + resolve_model_slot_result(store, SLOT_POST_TRANSCRIPT) } pub fn load_provider_policy(store: &Store) -> Result { diff --git a/desktop/local-backend/tools/e2e_smoke.sh b/desktop/local-backend/tools/e2e_smoke.sh index 55c221833f4..bf9d6c8e239 100755 --- a/desktop/local-backend/tools/e2e_smoke.sh +++ 
b/desktop/local-backend/tools/e2e_smoke.sh @@ -333,7 +333,7 @@ status_file="$(request GET /v1/processing-jobs/status)" assert_json_value "${status_file}" "failed" "0" processed_file="$(request GET /v1/conversations/conv-e2e-smoke)" -assert_json_value "${processed_file}" "conversation.status" "processed" +assert_json_value "${processed_file}" "conversation.status" "processed_fallback" assert_json_value "${processed_file}" "conversation.title" "Plan the backend free desktop MVP and verify" request_status POST /v1/conversations 409 '{ @@ -395,6 +395,7 @@ if [[ "${configured_provider}" != "openai_compatible" ]]; then exit 1 fi provider_processed_file="$(request GET /v1/conversations/conv-provider-smoke)" +assert_json_value "${provider_processed_file}" "conversation.status" "processed" assert_json_value "${provider_processed_file}" "conversation.title" "Stub provider title" provider_request_count="$(wc -l <"${PROVIDER_LOG_FILE}" | tr -d ' ')" @@ -456,7 +457,7 @@ stop_daemon start_daemon persisted_file="$(request GET /v1/conversations/conv-e2e-smoke)" -assert_json_value "${persisted_file}" "conversation.status" "processed" +assert_json_value "${persisted_file}" "conversation.status" "processed_fallback" assert_json_value "${persisted_file}" "transcript_segments.0.text" "Plan the backend free desktop MVP and verify deterministic local processing." 
persisted_search_file="$(request GET '/v1/search/conversations?q=backend')" From 0720be77e6fac9be2aec3f312dda2932b09462b3 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:50:57 +0700 Subject: [PATCH 34/58] Route proactive assistants through local slots --- desktop/Desktop/Sources/APIClient.swift | 33 ++++- desktop/Desktop/Sources/HybridLLMClient.swift | 71 +++++++-- .../Sources/HybridProviderPolicy.swift | 121 ++++++++++++++++ .../Sources/HybridVisionProvider.swift | 18 +++ .../Core/GeminiClient.swift | 130 ++++++++++++----- .../Tests/HybridLLMProviderConfigTests.swift | 135 ++++++++++++++++-- desktop/local-backend/README.md | 6 + desktop/local-backend/docs/architecture.md | 9 ++ .../docs/hybrid-provider-settings.md | 18 +++ 9 files changed, 484 insertions(+), 57 deletions(-) create mode 100644 desktop/Desktop/Sources/HybridProviderPolicy.swift diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 4bf5053ba71..da158357255 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -260,6 +260,24 @@ actor APIClient { ) } + func resolveSelectedBackendProviderSlot(_ slot: String) async throws + -> HybridProviderPolicy.SlotResolutionResponse + { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.featureUnavailable( + feature: "resolve_provider_slot", + reason: "Provider policy slots are only available in local daemon mode." + ) + } + let encodedSlot = slot.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? 
slot + return try await get( + "v1/provider-policy/resolve/\(encodedSlot)", + requireAuth: false, + customBaseURL: target.baseURL + ) + } + // MARK: - Request Execution private func performRequest(_ request: URLRequest, retryOnUnauthorized: Bool) @@ -2095,7 +2113,8 @@ extension APIClient { if selectedBackendTarget.mode == .localDaemon { throw APIError.featureUnavailable( feature: "memory_visibility", - reason: "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." + reason: + "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." ) } @@ -2128,7 +2147,8 @@ extension APIClient { if selectedBackendTarget.mode == .localDaemon { throw APIError.featureUnavailable( feature: "memory_read_status", - reason: "Bulk memory read status is cloud-only metadata and is disabled in local daemon mode." + reason: + "Bulk memory read status is cloud-only metadata and is disabled in local daemon mode." ) } @@ -2140,7 +2160,8 @@ extension APIClient { if selectedBackendTarget.mode == .localDaemon { throw APIError.featureUnavailable( feature: "memory_visibility", - reason: "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." + reason: + "Memory visibility controls are cloud-sharing metadata and are disabled in local daemon mode." ) } @@ -2156,7 +2177,8 @@ extension APIClient { if selectedBackendTarget.mode == .localDaemon { throw APIError.featureUnavailable( feature: "memory_bulk_delete", - reason: "Bulk memory deletion is not exposed by the local daemon yet. Delete individual local memories instead." + reason: + "Bulk memory deletion is not exposed by the local daemon yet. Delete individual local memories instead." 
) } @@ -2786,7 +2808,8 @@ extension APIClient { /// Fetches staged tasks ordered by relevance score func getStagedTasks(limit: Int = 100, offset: Int = 0) async throws -> ActionItemsListResponse { if selectedBackendTarget.mode == .localDaemon { - let items = try await StagedTaskStorage.shared.getScoredStagedTasks(limit: limit, offset: offset) + let items = try await StagedTaskStorage.shared.getScoredStagedTasks( + limit: limit, offset: offset) return ActionItemsListResponse(items: items, hasMore: items.count == limit) } diff --git a/desktop/Desktop/Sources/HybridLLMClient.swift b/desktop/Desktop/Sources/HybridLLMClient.swift index f820d6dda43..6ded1c0b0a7 100644 --- a/desktop/Desktop/Sources/HybridLLMClient.swift +++ b/desktop/Desktop/Sources/HybridLLMClient.swift @@ -9,6 +9,8 @@ actor HybridDaemonSettingsCache { private var cached: [LocalDaemonSetting]? private var fetchedAt: Date? + private var slotResolutions: + [String: (response: HybridProviderPolicy.SlotResolutionResponse, fetchedAt: Date)] = [:] private let ttlSeconds: TimeInterval = 45 func settings() async throws -> [LocalDaemonSetting] { @@ -26,6 +28,21 @@ actor HybridDaemonSettingsCache { fetchedAt = Date() return fresh } + + func slotResolution(_ slot: String) async throws -> HybridProviderPolicy.SlotResolutionResponse? { + if CodexAuthService.isActive { + return nil + } + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return nil + } + if let cached = slotResolutions[slot], Date().timeIntervalSince(cached.fetchedAt) < ttlSeconds { + return cached.response + } + let fresh = try await APIClient.shared.resolveSelectedBackendProviderSlot(slot) + slotResolutions[slot] = (fresh, Date()) + return fresh + } } // MARK: - Hybrid LLM (OpenAI-compatible chat completions) @@ -48,7 +65,8 @@ enum HybridLLMClient { var errorDescription: String? { switch self { case .notConfigured: - return "Hybrid AI is not configured. 
Set ai_provider or chat_provider in Settings, or add a BYOK OpenAI key." + return + "Hybrid AI is not configured. Set ai_provider or chat_provider in Settings, or add a BYOK OpenAI key." case .invalidSettings: return "Hybrid provider settings are invalid." case .invalidResponse: @@ -91,6 +109,28 @@ enum HybridLLMClient { ?? byokOpenAIConfig() } + /// Local proactive routing must be governed by the daemon's `proactive` slot. + static func resolveEffectiveProactiveConfig( + slotResolution: HybridProviderPolicy.SlotResolutionResponse? + ) -> ProviderConfig? { + if let codex = codexProviderConfig() { + return codex + } + guard let slotResolution else { + return nil + } + return HybridProviderPolicy.providerConfig(from: slotResolution) + } + + static func resolveEffectiveProactiveConfig(settings: [LocalDaemonSetting]) -> ProviderConfig? { + resolveEffectiveProactiveConfig( + slotResolution: HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.proactiveSlot, + settings: settings + ) + ) + } + /// BYOK OpenAI → vendor endpoint (desktop hybrid escape hatch when daemon JSON is unset). private static func byokOpenAIConfig() -> ProviderConfig? 
{ guard let key = APIKeyService.byokKey(.openai), !key.isEmpty else { @@ -141,7 +181,8 @@ enum HybridLLMClient { return url } - private static func postJSON(url: URL, body: [String: Any], apiKey: String, timeout: TimeInterval) async throws + private static func postJSON(url: URL, body: [String: Any], apiKey: String, timeout: TimeInterval) + async throws -> [String: Any] { var request = URLRequest(url: url) @@ -184,7 +225,8 @@ enum HybridLLMClient { ["role": "user", "content": content], ] return try await chatCompletionRaw( - config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, timeout: timeout) + config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, + timeout: timeout) } static func chatCompletionMultimodalJPEG( @@ -206,7 +248,8 @@ enum HybridLLMClient { ["role": "user", "content": content], ] return try await chatCompletionRaw( - config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, timeout: timeout) + config: config, messages: messages, jsonMode: jsonMode, tools: nil, toolChoice: nil, + timeout: timeout) } private static func chatCompletionRaw( @@ -232,7 +275,8 @@ enum HybridLLMClient { body["tool_choice"] = toolChoice } - let json = try await postJSON(url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) + let json = try await postJSON( + url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) return try extractAssistantText(from: json) } @@ -285,7 +329,8 @@ enum HybridLLMClient { "temperature": 0.4, ] - let json = try await postJSON(url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) + let json = try await postJSON( + url: try completionsURL(config: config), body: body, apiKey: config.apiKey, timeout: timeout) return try parseToolChatResult(from: json) } @@ -346,7 +391,8 @@ enum HybridLLMClient { for part in content.parts { if let fr = 
part.functionResponse { let toolCallId = - pendingToolCallIds.isEmpty ? "call_hybrid_fallback_\(fr.name)" : pendingToolCallIds.removeFirst() + pendingToolCallIds.isEmpty + ? "call_hybrid_fallback_\(fr.name)" : pendingToolCallIds.removeFirst() toolResults.append([ "role": "tool", "tool_call_id": toolCallId, @@ -422,7 +468,9 @@ enum HybridLLMClient { } } - private static func jsonSchema(from params: GeminiTool.FunctionDeclaration.Parameters) -> [String: Any]? { + private static func jsonSchema(from params: GeminiTool.FunctionDeclaration.Parameters) -> [String: + Any]? + { var properties: [String: Any] = [:] for (name, prop) in params.properties { if let nested = propJSONSchema(prop) { @@ -437,7 +485,9 @@ enum HybridLLMClient { return schema } - private static func propJSONSchema(_ prop: GeminiTool.FunctionDeclaration.Parameters.Property) -> [String: Any]? { + private static func propJSONSchema(_ prop: GeminiTool.FunctionDeclaration.Parameters.Property) + -> [String: Any]? + { if let nested = prop.nestedProperties, let req = prop.nestedRequired { var childProps: [String: Any] = [:] for (k, v) in nested { @@ -488,7 +538,8 @@ enum HybridLLMClient { return } let observations = (request.results as? [VNRecognizedTextObservation]) ?? 
[] - let text = observations.compactMap { $0.topCandidates(1).first?.string }.joined(separator: "\n") + let text = observations.compactMap { $0.topCandidates(1).first?.string }.joined( + separator: "\n") continuation.resume(returning: text) } request.recognitionLevel = .accurate diff --git a/desktop/Desktop/Sources/HybridProviderPolicy.swift b/desktop/Desktop/Sources/HybridProviderPolicy.swift new file mode 100644 index 00000000000..07ab92ec2c8 --- /dev/null +++ b/desktop/Desktop/Sources/HybridProviderPolicy.swift @@ -0,0 +1,121 @@ +import Foundation + +enum HybridProviderPolicy { + static let proactiveSlot = "proactive" + static let visionSlot = "vision" + + struct ProviderAccount: Decodable, Equatable { + let id: String + let kind: String + let baseURL: String? + let apiKey: String? + + enum CodingKeys: String, CodingKey { + case id + case kind + case baseURL = "base_url" + case apiKey = "api_key" + } + } + + struct ResolvedSlot: Decodable, Equatable { + let slot: String + let providerAccount: ProviderAccount? + let modelID: String + let source: String + + enum CodingKeys: String, CodingKey { + case slot + case providerAccount = "provider_account" + case modelID = "model_id" + case source + } + } + + struct SlotResolution: Decodable, Equatable { + let slot: String + let ok: Bool + let resolved: ResolvedSlot? + let reason: String + } + + struct SlotResolutionResponse: Decodable, Equatable { + let resolved: ResolvedSlot? + let resolution: SlotResolution + } + + struct Policy: Decodable { + let version: Int + let providerAccounts: [ProviderAccount] + let modelSlots: [String: ModelSlotTarget] + + enum CodingKeys: String, CodingKey { + case version + case providerAccounts = "provider_accounts" + case modelSlots = "model_slots" + } + } + + struct ModelSlotTarget: Decodable { + let providerAccountID: String? 
+ let modelID: String + + enum CodingKeys: String, CodingKey { + case providerAccountID = "provider_account_id" + case modelID = "model_id" + } + } + + static func providerConfig(from response: SlotResolutionResponse) -> HybridLLMClient + .ProviderConfig? + { + guard response.resolution.ok, + let resolved = response.resolved, + let account = resolved.providerAccount, + let baseURL = account.baseURL?.trimmingCharacters(in: .whitespacesAndNewlines), + !baseURL.isEmpty, + isOpenAICompatible(kind: account.kind) + else { + return nil + } + return HybridLLMClient.ProviderConfig( + baseURL: baseURL, + model: resolved.modelID, + apiKey: account.apiKey ?? "" + ) + } + + static func resolveSlotFromSettings( + _ slot: String, + settings: [LocalDaemonSetting] + ) -> SlotResolutionResponse? { + guard let raw = settings.first(where: { $0.key == "provider_policy" })?.valueJson, + let data = raw.data(using: .utf8), + let policy = try? JSONDecoder().decode(Policy.self, from: data), + let target = policy.modelSlots[slot] + else { + return nil + } + let account = target.providerAccountID.flatMap { accountID in + policy.providerAccounts.first(where: { $0.id == accountID }) + } + let resolved = ResolvedSlot( + slot: slot, + providerAccount: account, + modelID: target.modelID, + source: "provider_policy" + ) + let ok = account != nil || slot == "memory_search" + let reason = + ok + ? 
"\(slot) resolved to \(target.modelID) from provider_policy" + : "model slot \(slot) selects \(target.modelID) but no provider account is configured" + let resolution = SlotResolution(slot: slot, ok: ok, resolved: resolved, reason: reason) + return SlotResolutionResponse(resolved: resolved, resolution: resolution) + } + + static func isOpenAICompatible(kind: String) -> Bool { + let normalized = kind.lowercased() + return normalized == "openai_compatible" || normalized == "openai" + } +} diff --git a/desktop/Desktop/Sources/HybridVisionProvider.swift b/desktop/Desktop/Sources/HybridVisionProvider.swift index 1e79e40be69..875d09e2ffc 100644 --- a/desktop/Desktop/Sources/HybridVisionProvider.swift +++ b/desktop/Desktop/Sources/HybridVisionProvider.swift @@ -19,4 +19,22 @@ enum HybridVisionProvider { } return true } + + static func providerConfig(from response: HybridProviderPolicy.SlotResolutionResponse?) + -> HybridLLMClient.ProviderConfig? + { + guard let response else { + return nil + } + return HybridProviderPolicy.providerConfig(from: response) + } + + static func providerConfig(settings: [LocalDaemonSetting]) -> HybridLLMClient.ProviderConfig? 
{ + providerConfig( + from: HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.visionSlot, + settings: settings + ) + ) + } } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index 7cb9d89b56f..aed80795346 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -199,7 +199,9 @@ actor GeminiClient { if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { return "" } - if let cString = getenv("OMI_DESKTOP_API_URL"), let url = String(validatingUTF8: cString), !url.isEmpty { + if let cString = getenv("OMI_DESKTOP_API_URL"), let url = String(validatingUTF8: cString), + !url.isEmpty + { return url.hasSuffix("/") ? url : url + "/" } return "https://api.omi.me/" @@ -207,6 +209,7 @@ actor GeminiClient { enum GeminiClientError: LocalizedError { case missingAPIKey + case providerNotReady(String) case networkError(Error) case invalidResponse case apiError(String) @@ -221,6 +224,8 @@ actor GeminiClient { switch self { case .missingAPIKey: return "AI features are not configured. Please update the app." + case .providerNotReady(let reason): + return "Local proactive AI is not ready: \(reason)" case .networkError: return "Could not reach AI service. Check your internet connection and try again." case .invalidResponse: @@ -289,6 +294,70 @@ actor GeminiClient { } } + private func proactiveProviderConfig() async throws -> HybridLLMClient.ProviderConfig { + if CodexAuthService.isActive { + await CodexProxyService.shared.ensureRunning() + if let config = HybridLLMClient.codexProviderConfig() { + return config + } + } + let slotResolution: HybridProviderPolicy.SlotResolutionResponse? 
+ do { + slotResolution = try await HybridDaemonSettingsCache.shared.slotResolution( + HybridProviderPolicy.proactiveSlot + ) + } catch { + logError("GeminiClient: failed to resolve local proactive slot", error: error) + slotResolution = nil + } + if let resolution = slotResolution { + if let config = HybridLLMClient.resolveEffectiveProactiveConfig(slotResolution: resolution) { + return config + } + throw GeminiClientError.providerNotReady(resolution.resolution.reason) + } + let settings = try await HybridDaemonSettingsCache.shared.settings() + if let config = HybridLLMClient.resolveEffectiveProactiveConfig(settings: settings) { + return config + } + if let resolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.proactiveSlot, + settings: settings + ) { + throw GeminiClientError.providerNotReady(resolution.resolution.reason) + } + throw GeminiClientError.providerNotReady( + "no proactive model slot is configured in provider_policy" + ) + } + + private func visionProviderConfig() async throws -> HybridLLMClient.ProviderConfig? { + let slotResolution: HybridProviderPolicy.SlotResolutionResponse? 
+ do { + slotResolution = try await HybridDaemonSettingsCache.shared.slotResolution( + HybridProviderPolicy.visionSlot + ) + } catch { + logError("GeminiClient: failed to resolve local vision slot", error: error) + slotResolution = nil + } + if let resolution = slotResolution { + if let config = HybridVisionProvider.providerConfig(from: resolution) { + return config + } + log( + "GeminiClient: local vision slot unavailable; using OCR-only screen context (\(resolution.resolution.reason))" + ) + return nil + } + let settings = try await HybridDaemonSettingsCache.shared.settings() + if let config = HybridVisionProvider.providerConfig(settings: settings) { + return config + } + log("GeminiClient: no local vision slot configured; using OCR-only screen context") + return nil + } + /// Text chat completion via hybrid OpenAI-compatible provider (or BYOK OpenAI). private func hybridChatText( systemPrompt: String, @@ -296,13 +365,7 @@ actor GeminiClient { jsonMode: Bool, timeout: TimeInterval = 300 ) async throws -> String { - if CodexAuthService.isActive { - await CodexProxyService.shared.ensureRunning() - } - let settings = try await HybridDaemonSettingsCache.shared.settings() - guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { - throw GeminiClientError.missingAPIKey - } + let config = try await proactiveProviderConfig() do { return try await HybridLLMClient.chatCompletionText( config: config, @@ -316,7 +379,7 @@ actor GeminiClient { } } - /// Multimodal when `vision_provider` is set; otherwise macOS Vision OCR + text JSON. + /// Multimodal when the daemon `vision` slot resolves; otherwise macOS Vision OCR + text JSON. 
private func hybridChatImageOrOCR( prompt: String, imageData: Data, @@ -324,14 +387,8 @@ actor GeminiClient { jsonMode: Bool, timeout: TimeInterval = 300 ) async throws -> String { - if CodexAuthService.isActive { - await CodexProxyService.shared.ensureRunning() - } - let settings = try await HybridDaemonSettingsCache.shared.settings() - guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { - throw GeminiClientError.missingAPIKey - } - let visionConfig = HybridLLMClient.loadVisionProviderConfig(from: settings) + let config = try await proactiveProviderConfig() + let visionConfig = try await visionProviderConfig() do { if let visionConfig { return try await HybridLLMClient.chatCompletionMultimodalJPEG( @@ -343,6 +400,7 @@ actor GeminiClient { timeout: timeout ) } + log("GeminiClient: local proactive image request using OCR-only screen context") let ocr = try await HybridLLMClient.ScreenOCR.recognizeTextFromJPEG(imageData) let user = prompt @@ -373,7 +431,6 @@ actor GeminiClient { URL(string: "\(Self.proxyBaseURL)v1/proxy/gemini/models/\(model):\(action)")! } - /// Log the raw API error message for debugging and throw a sanitized error. /// The `errorDescription` on GeminiClientError is user-friendly; this log preserves the raw detail. private func throwAPIError(_ rawMessage: String) throws -> Never { @@ -423,7 +480,7 @@ actor GeminiClient { || lower.contains("internal error") case .networkError: return true - case .invalidResponse, .missingAPIKey: + case .invalidResponse, .missingAPIKey, .providerNotReady: return false } } @@ -440,7 +497,9 @@ actor GeminiClient { /// Sleep with exponential backoff (2s, 8s) and log the retry attempt. 
private func retryBackoff(attempt: Int, error: Error) async { let delaySec = [2, 8][min(attempt, 1)] - log("GeminiClient: transient error, retrying in \(delaySec)s (attempt \(attempt + 2)/3): \(error.localizedDescription)") + log( + "GeminiClient: transient error, retrying in \(delaySec)s (attempt \(attempt + 2)/3): \(error.localizedDescription)" + ) try? await Task.sleep(nanoseconds: UInt64(delaySec) * 1_000_000_000) } @@ -492,7 +551,8 @@ actor GeminiClient { generationConfig: GeminiRequest.GenerationConfig( responseMimeType: "application/json", responseSchema: responseSchema, - thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + thinkingConfig: ThinkingConfig( + thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ) ) @@ -579,7 +639,8 @@ actor GeminiClient { generationConfig: GeminiRequest.GenerationConfig( responseMimeType: nil, responseSchema: nil, - thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + thinkingConfig: ThinkingConfig( + thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ) ) @@ -659,7 +720,8 @@ actor GeminiClient { generationConfig: GeminiRequest.GenerationConfig( responseMimeType: "application/json", responseSchema: responseSchema, - thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + thinkingConfig: ThinkingConfig( + thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ) ) @@ -711,7 +773,8 @@ actor GeminiClient { for part in content.parts { if let inline = part.inlineData { let mime = inline.mimeType.lowercased() - guard let rawImageData = Data(base64Encoded: inline.data, options: [.ignoreUnknownCharacters]) + guard + let rawImageData = Data(base64Encoded: inline.data, options: [.ignoreUnknownCharacters]) else { continue } @@ -757,7 +820,6 @@ actor GeminiClient { } - // MARK: - Tool Calling 
Support /// Wrapper for dynamic JSON values in function arguments @@ -829,7 +891,9 @@ struct GeminiTool: Encodable { case nestedRequired = "required" } - init(type: String, description: String = "", enumValues: [String]? = nil, items: Items? = nil) { + init( + type: String, description: String = "", enumValues: [String]? = nil, items: Items? = nil + ) { self.type = type self.description = description.isEmpty ? nil : description self.enum = enumValues @@ -857,7 +921,6 @@ struct GeminiTool: Encodable { } } - /// Result of a tool-enabled chat (may include tool calls) struct ToolChatResult { let text: String @@ -1018,11 +1081,9 @@ extension GeminiClient { for attempt in 0...maxRetries { do { if transport == .hybridOpenAICompatible { - let settings = try await HybridDaemonSettingsCache.shared.settings() - guard let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) else { - throw GeminiClientError.missingAPIKey - } - let allowVision = HybridVisionProvider.isConfigured(settings: settings) + let config = try await proactiveProviderConfig() + let visionConfig = try await visionProviderConfig() + let allowVision = visionConfig != nil do { let contentsForHybrid: [GeminiImageToolRequest.Content] if allowVision { @@ -1031,7 +1092,7 @@ extension GeminiClient { contentsForHybrid = try await hybridContentsWithOCRInsteadOfImages(contents) } return try await HybridLLMClient.performGeminiCompatibleToolRound( - config: config, + config: visionConfig ?? 
config, systemPrompt: systemPrompt, contents: contentsForHybrid, tools: tools, @@ -1059,7 +1120,8 @@ extension GeminiClient { parts: [.init(text: systemPrompt)] ), generationConfig: GeminiImageToolRequest.GenerationConfig( - thinkingConfig: ThinkingConfig(thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) + thinkingConfig: ThinkingConfig( + thinkingBudget: max(thinkingBudget, ThinkingConfig.minimumBudget(for: model))) ), tools: tools, toolConfig: toolConfig diff --git a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift index fce16ae3b5d..8a6c041dcf7 100644 --- a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift +++ b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift @@ -6,8 +6,8 @@ final class HybridLLMProviderConfigTests: XCTestCase { func testResolveEffectivePrefersChatProviderOverAiProvider() throws { let payload = """ - [{"key":"chat_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://chat.local/v1\\",\\"model\\":\\"m-chat\\"}","updated_at":"2026-05-19T12:00:00Z"},{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] - """ + [{"key":"chat_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://chat.local/v1\\",\\"model\\":\\"m-chat\\"}","updated_at":"2026-05-19T12:00:00Z"},{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ let settings = try decodeSettings(payload) let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) XCTAssertEqual(config?.baseURL, "http://chat.local/v1") @@ -16,8 +16,8 @@ final class HybridLLMProviderConfigTests: XCTestCase { func testResolveEffectiveFallsBackToProviderAlias() throws { let payload = """ - 
[{"key":"provider","value_json":"{\\"kind\\":\\"openai\\",\\"base_url\\":\\"http://legacy.local/v1\\",\\"model\\":\\"legacy\\"}","updated_at":"2026-05-19T12:00:00Z"}] - """ + [{"key":"provider","value_json":"{\\"kind\\":\\"openai\\",\\"base_url\\":\\"http://legacy.local/v1\\",\\"model\\":\\"legacy\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ let settings = try decodeSettings(payload) let config = HybridLLMClient.resolveEffectiveChatConfig(settings: settings) XCTAssertEqual(config?.baseURL, "http://legacy.local/v1") @@ -26,8 +26,8 @@ final class HybridLLMProviderConfigTests: XCTestCase { func testVisionProviderLoadsWhenConfigured() throws { let payload = """ - [{"key":"vision_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://vision.local/v1\\",\\"model\\":\\"vlm\\"}","updated_at":"2026-05-19T12:00:00Z"}] - """ + [{"key":"vision_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://vision.local/v1\\",\\"model\\":\\"vlm\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ let settings = try decodeSettings(payload) let config = HybridLLMClient.loadVisionProviderConfig(from: settings) XCTAssertEqual(config?.baseURL, "http://vision.local/v1") @@ -36,17 +36,136 @@ final class HybridLLMProviderConfigTests: XCTestCase { func testHybridChatClientFallsBackToAiProvider() throws { let payload = """ - [{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] - """ + [{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] + """ let settings = try decodeSettings(payload) let config = HybridChatClient.resolveEffectiveChatConfig(from: settings) XCTAssertEqual(config?.baseURL, "http://ai.local/v1") XCTAssertEqual(config?.model, "m-ai") } + func testProactiveUsesProviderPolicySlot() 
throws { + let settings = try makeSettings([ + ( + "provider_policy", + [ + "version": 1, + "provider_accounts": [ + [ + "id": "local-proactive", + "kind": "openai_compatible", + "base_url": "http://proactive.local/v1", + "api_key": "test-key", + ] + ], + "model_slots": [ + "proactive": [ + "provider_account_id": "local-proactive", + "model_id": "gpt-5.4-mini", + "options": ["json_mode": true], + ] + ], + ] + ) + ]) + + let config = HybridLLMClient.resolveEffectiveProactiveConfig(settings: settings) + + XCTAssertEqual(config?.baseURL, "http://proactive.local/v1") + XCTAssertEqual(config?.model, "gpt-5.4-mini") + XCTAssertEqual(config?.apiKey, "test-key") + } + + func testProactiveDoesNotFallBackToChatProvider() throws { + let settings = try makeSettings([ + ( + "chat_provider", + [ + "kind": "openai_compatible", + "base_url": "http://chat.local/v1", + "model": "chat-model", + ] + ) + ]) + + XCTAssertNil(HybridLLMClient.resolveEffectiveProactiveConfig(settings: settings)) + } + + func testProactiveResolutionSurfacesMissingProviderReason() throws { + let settings = try makeSettings([ + ( + "provider_policy", + [ + "version": 1, + "provider_accounts": [], + "model_slots": [ + "proactive": [ + "provider_account_id": NSNull(), + "model_id": "gpt-5.4-mini", + "options": ["json_mode": true], + ] + ], + ] + ) + ]) + + let response = HybridProviderPolicy.resolveSlotFromSettings("proactive", settings: settings) + + XCTAssertEqual(response?.resolution.ok, false) + XCTAssertTrue(response?.resolution.reason.contains("no provider account") == true) + XCTAssertEqual(response?.resolved?.modelID, "gpt-5.4-mini") + XCTAssertNil(HybridProviderPolicy.providerConfig(from: response!)) + } + + func testVisionFallsBackToOCRWhenNoVisionSlotExists() throws { + let settings = try makeSettings([ + ( + "provider_policy", + [ + "version": 1, + "provider_accounts": [ + [ + "id": "local-proactive", + "kind": "openai_compatible", + "base_url": "http://proactive.local/v1", + ] + ], + "model_slots": 
[ + "proactive": [ + "provider_account_id": "local-proactive", + "model_id": "gpt-5.4-mini", + "options": ["json_mode": true], + ] + ], + ] + ) + ]) + + XCTAssertNil(HybridVisionProvider.providerConfig(settings: settings)) + } + private func decodeSettings(_ jsonArray: String) throws -> [LocalDaemonSetting] { let decoder = JSONDecoder() decoder.dateDecodingStrategy = .iso8601 return try decoder.decode([LocalDaemonSetting].self, from: Data(jsonArray.utf8)) } + + private func makeSettings(_ rows: [(String, [String: Any])]) throws -> [LocalDaemonSetting] { + let payloadRows = try rows.map { key, value in + let valueJsonData = try JSONSerialization.data(withJSONObject: value) + guard let valueJson = String(data: valueJsonData, encoding: .utf8) else { + struct EncodeError: Error {} + throw EncodeError() + } + return [ + "key": key, + "value_json": valueJson, + "updated_at": "2026-05-20T12:00:00Z", + ] + } + let payload = try JSONSerialization.data(withJSONObject: payloadRows) + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + return try decoder.decode([LocalDaemonSetting].self, from: payload) + } } diff --git a/desktop/local-backend/README.md b/desktop/local-backend/README.md index a0f8c2e6f65..eda9d84967c 100644 --- a/desktop/local-backend/README.md +++ b/desktop/local-backend/README.md @@ -136,6 +136,12 @@ Local processing resolves the `post_transcript` model slot from deterministic fallback metadata; with an OpenAI-compatible account it persists model-derived title, overview, memories, action items, and provenance. +Desktop proactive assistants resolve `/v1/provider-policy/resolve/proactive` +for local daemon AI calls and `/v1/provider-policy/resolve/vision` only for +optional screenshot multimodal input. If `vision` is unavailable, assistants use +local Rewind/macOS OCR text. This local path does not require Omi-hosted provider +proxies, cloud screen-activity sync, or embedding provider readiness. 
+ ## Architecture And E2E Validation The durable MVP architecture note and validation checklist live in diff --git a/desktop/local-backend/docs/architecture.md b/desktop/local-backend/docs/architecture.md index 0762ce2821e..be4ed243522 100644 --- a/desktop/local-backend/docs/architecture.md +++ b/desktop/local-backend/docs/architecture.md @@ -81,6 +81,15 @@ deterministic fallbacks: Provider keys stay in local daemon settings and are not sent to Omi-hosted services. +Proactive assistants in desktop local daemon mode resolve the daemon +`proactive` model slot before making AI calls. They do not construct Omi-hosted +Gemini or chat-completion proxy URLs in this mode, and they do not treat +`chat_provider` or embeddings as readiness requirements for proactive work. +Screenshot analysis resolves the optional `vision` slot separately; without a +configured vision provider the desktop app uses local Rewind/macOS OCR text and +records the OCR-only path in logs. Memory retrieval remains local wiki/FTS over +the daemon/Rewind SQLite stores. + ## Cloud Sync Boundary Cloud sync is a future optional adapter. The local database remains the source diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index ce02d4a99e3..960f082dfcf 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -104,6 +104,24 @@ Desktop clients should prefer these daemon APIs over manual JSON editing: Callers should resolve `post_transcript`, `proactive`, and `chat` slots explicitly instead of duplicating the legacy setting-key scan order. +## Proactive assistants + +In desktop local daemon mode, proactive assistant model calls resolve +`GET /v1/provider-policy/resolve/proactive` and use the returned +OpenAI-compatible provider account/model. 
The slot defaults to `gpt-5.4-mini`, +but resolution is not actionable until the policy includes a provider account +for that slot. The desktop client must surface the daemon's resolution reason +instead of silently falling back to `chat_provider`, `ai_provider`, Omi-hosted +Gemini proxy URLs, or Omi-hosted chat completion proxies. + +Screenshot-aware assistants resolve `GET /v1/provider-policy/resolve/vision` +separately. When that slot resolves to an allowed provider, the assistant may +send multimodal screenshot input to that provider. When the slot is missing or +unavailable, the desktop app uses local macOS OCR/Rewind text from the captured +screen and logs the OCR-only path. Local proactive memory/task/conversation +context comes from local daemon/Rewind SQLite data and FTS-backed local wiki +search; embedding provider readiness is not required. + ## Post-transcript processing Finalized transcript jobs resolve `/v1/provider-policy/resolve/post_transcript` From 58f06ad77c69418daddaf63679af9d603559a243 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 14:00:52 +0700 Subject: [PATCH 35/58] Make local Ask Omi use chat slot policy --- .../Sources/DesktopBackendEnvironment.swift | 2 +- .../Desktop/Sources/HybridChatClient.swift | 131 ++++------ .../Sources/HybridProviderPolicy.swift | 24 ++ .../Sources/HybridProviderReadiness.swift | 10 +- .../Sources/Providers/ChatProvider.swift | 75 +++++- .../Desktop/Tests/APIClientRoutingTests.swift | 18 ++ .../Desktop/Tests/HybridChatClientTests.swift | 242 ++++++++++++++++++ .../Tests/HybridLLMProviderConfigTests.swift | 36 ++- .../docs/hybrid-provider-settings.md | 8 +- desktop/local-backend/src/providers.rs | 97 +++++++ desktop/local-backend/src/storage.rs | 51 ++++ 11 files changed, 594 insertions(+), 100 deletions(-) create mode 100644 desktop/Desktop/Tests/HybridChatClientTests.swift diff --git a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift 
b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift index 54077175fb9..b05b7a7d904 100644 --- a/desktop/Desktop/Sources/DesktopBackendEnvironment.swift +++ b/desktop/Desktop/Sources/DesktopBackendEnvironment.swift @@ -233,7 +233,7 @@ enum DesktopBackendEnvironment { "Apple Speech is not available for this Mac’s preferred languages (or Speech Recognition is off in System Settings). Set OMI_HYBRID_DIRECT_STT_ENABLED=1 to opt in when the engine is available." case .directChat: return - "Direct local chat requires OMI_HYBRID_DIRECT_CHAT_ENABLED=1 and a chat_provider in hybrid settings." + "Direct local chat requires OMI_HYBRID_DIRECT_CHAT_ENABLED=1 and a resolved chat slot in local provider policy." case .directEmbeddings: return "Direct local embeddings require OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1 and an embedding_provider in hybrid settings." diff --git a/desktop/Desktop/Sources/HybridChatClient.swift b/desktop/Desktop/Sources/HybridChatClient.swift index 7dca650bf67..5abe5d9e5bb 100644 --- a/desktop/Desktop/Sources/HybridChatClient.swift +++ b/desktop/Desktop/Sources/HybridChatClient.swift @@ -7,28 +7,56 @@ enum HybridChatClient { let baseURL: String let model: String let apiKey: String + let providerAccountID: String? + let providerKind: String? + let slotSource: String? + let resolutionReason: String? + + init( + baseURL: String, + model: String, + apiKey: String, + providerAccountID: String? = nil, + providerKind: String? = nil, + slotSource: String? = nil, + resolutionReason: String? = nil + ) { + self.baseURL = baseURL + self.model = model + self.apiKey = apiKey + self.providerAccountID = providerAccountID + self.providerKind = providerKind + self.slotSource = slotSource + self.resolutionReason = resolutionReason + } } struct CompletionResult: Equatable { let text: String let model: String + let providerAccountID: String? + let providerKind: String? + let slotSource: String? + let resolutionReason: String? 
let inputTokens: Int let outputTokens: Int } enum ClientError: LocalizedError { - case notConfigured + case notConfigured(String) case invalidSettings case invalidResponse case providerError(String) var errorDescription: String? { switch self { - case .notConfigured: - return - "Hybrid direct chat is not configured. Set chat_provider or ai_provider in Settings → Plan and Usage (or run a local LLM at the default Ollama URL)." + case .notConfigured(let reason): + if reason.isEmpty { + return "Chat model slot is not configured. Configure the chat slot in local provider policy." + } + return "Chat model slot is not configured: \(reason)" case .invalidSettings: - return "chat_provider settings are invalid." + return "Chat provider policy settings are invalid." case .invalidResponse: return "Chat provider returned an unexpected response." case .providerError(let message): @@ -50,74 +78,13 @@ enum HybridChatClient { return true } - /// Resolves Codex → chat_provider → ai_provider / provider (matches HybridLLMClient). - static func resolveEffectiveChatConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? { - if let codex = HybridLLMClient.codexProviderConfig() { - return ProviderConfig( - baseURL: codex.baseURL, - model: codex.model, - apiKey: codex.apiKey - ) - } - if let chat = loadProviderConfig(from: settings, key: "chat_provider") { - return chat - } - if let ai = loadProviderConfig(from: settings, keys: ["ai_provider", "provider"]) { - return ai - } - return byokOpenAIConfig() - } - - static func loadProviderConfig(from settings: [LocalDaemonSetting]) -> ProviderConfig? { - loadProviderConfig(from: settings, key: "chat_provider") - } - - private static func loadProviderConfig( - from settings: [LocalDaemonSetting], - key: String + static func resolveEffectiveChatConfig( + from response: HybridProviderPolicy.SlotResolutionResponse ) -> ProviderConfig? 
{ - loadProviderConfig(from: settings, keys: [key]) + HybridProviderPolicy.chatProviderConfig(from: response) } - private static func loadProviderConfig( - from settings: [LocalDaemonSetting], - keys: [String] - ) -> ProviderConfig? { - guard let raw = settings.first(where: { keys.contains($0.key) })?.valueJson, - let data = raw.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { - return nil - } - return parseOpenAICompatible(json: json) - } - - private static func parseOpenAICompatible(json: [String: Any]) -> ProviderConfig? { - let kind = (json["kind"] as? String)?.lowercased() ?? "" - guard kind == "openai_compatible" || kind == "openai" else { - return nil - } - guard let baseURL = json["base_url"] as? String, !baseURL.isEmpty else { - return nil - } - let model = (json["model"] as? String) ?? HybridProviderReadiness.defaultModel() - let apiKey = - (json["api_key"] as? String) ?? (json["key"] as? String) ?? "" - return ProviderConfig(baseURL: baseURL, model: model, apiKey: apiKey) - } - - private static func byokOpenAIConfig() -> ProviderConfig? { - guard let key = APIKeyService.byokKey(.openai), !key.isEmpty else { - return nil - } - let model = - ProcessInfo.processInfo.environment["OMI_HYBRID_BYOK_OPENAI_MODEL"].flatMap { - $0.trimmingCharacters(in: .whitespacesAndNewlines) - }.flatMap { $0.isEmpty ? nil : $0 } ?? "gpt-4o-mini" - return ProviderConfig(baseURL: "https://api.openai.com/v1", model: model, apiKey: key) - } - - /// Loads daemon hybrid settings and completes one chat turn (non-streaming). + /// Resolves the daemon chat slot and completes one chat turn (non-streaming). 
static func completeFromDaemonSettings( systemPrompt: String, conversationMessages: [(role: String, text: String)], @@ -126,12 +93,13 @@ enum HybridChatClient { if CodexAuthService.isActive { await CodexProxyService.shared.ensureRunning() } - let settings = try await APIClient.shared.getSelectedBackendSettings() + let resolution = try await APIClient.shared.resolveSelectedBackendProviderSlot( + HybridProviderPolicy.chatSlot) return try await complete( systemPrompt: systemPrompt, conversationMessages: conversationMessages, userMessage: userMessage, - settings: settings + slotResolution: resolution ) } @@ -139,16 +107,18 @@ enum HybridChatClient { systemPrompt: String, conversationMessages: [(role: String, text: String)], userMessage: String, - settings: [LocalDaemonSetting] + slotResolution: HybridProviderPolicy.SlotResolutionResponse, + session: URLSession = .shared ) async throws -> CompletionResult { - guard let config = resolveEffectiveChatConfig(from: settings) else { - throw ClientError.notConfigured + guard let config = resolveEffectiveChatConfig(from: slotResolution) else { + throw ClientError.notConfigured(slotResolution.resolution.reason) } return try await completeOpenAICompatible( config: config, systemPrompt: systemPrompt, conversationMessages: conversationMessages, - userMessage: userMessage + userMessage: userMessage, + session: session ) } @@ -167,7 +137,8 @@ enum HybridChatClient { config: ProviderConfig, systemPrompt: String, conversationMessages: [(role: String, text: String)], - userMessage: String + userMessage: String, + session: URLSession ) async throws -> CompletionResult { let base = config.baseURL.hasSuffix("/") ? 
String(config.baseURL.dropLast()) : config.baseURL guard let url = URL(string: "\(base)/chat/completions") else { @@ -196,7 +167,7 @@ enum HybridChatClient { ) request.httpBody = try JSONEncoder().encode(payload) - let (data, response) = try await URLSession.shared.data(for: request) + let (data, response) = try await session.data(for: request) guard let http = response as? HTTPURLResponse else { throw ClientError.invalidResponse } @@ -233,6 +204,10 @@ enum HybridChatClient { return CompletionResult( text: content.trimmingCharacters(in: .whitespacesAndNewlines), model: returnedModel, + providerAccountID: config.providerAccountID, + providerKind: config.providerKind, + slotSource: config.slotSource, + resolutionReason: config.resolutionReason, inputTokens: inputTokens, outputTokens: outputTokens ) diff --git a/desktop/Desktop/Sources/HybridProviderPolicy.swift b/desktop/Desktop/Sources/HybridProviderPolicy.swift index 07ab92ec2c8..60ce6fe04ac 100644 --- a/desktop/Desktop/Sources/HybridProviderPolicy.swift +++ b/desktop/Desktop/Sources/HybridProviderPolicy.swift @@ -1,6 +1,7 @@ import Foundation enum HybridProviderPolicy { + static let chatSlot = "chat" static let proactiveSlot = "proactive" static let visionSlot = "vision" @@ -85,6 +86,29 @@ enum HybridProviderPolicy { ) } + static func chatProviderConfig(from response: SlotResolutionResponse) -> HybridChatClient + .ProviderConfig? + { + guard response.resolution.ok, + let resolved = response.resolved, + let account = resolved.providerAccount, + let baseURL = account.baseURL?.trimmingCharacters(in: .whitespacesAndNewlines), + !baseURL.isEmpty, + isOpenAICompatible(kind: account.kind) + else { + return nil + } + return HybridChatClient.ProviderConfig( + baseURL: baseURL, + model: resolved.modelID, + apiKey: account.apiKey ?? 
"", + providerAccountID: account.id, + providerKind: account.kind, + slotSource: resolved.source, + resolutionReason: response.resolution.reason + ) + } + static func resolveSlotFromSettings( _ slot: String, settings: [LocalDaemonSetting] diff --git a/desktop/Desktop/Sources/HybridProviderReadiness.swift b/desktop/Desktop/Sources/HybridProviderReadiness.swift index a19cc8fa095..bc659b816cd 100644 --- a/desktop/Desktop/Sources/HybridProviderReadiness.swift +++ b/desktop/Desktop/Sources/HybridProviderReadiness.swift @@ -64,7 +64,11 @@ enum HybridProviderReadiness { ?? "Direct STT unavailable") )) - let chatResolvable = HybridChatClient.resolveEffectiveChatConfig(from: settings) != nil + let chatResolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.chatSlot, + settings: settings + ) + let chatResolvable = chatResolution.flatMap(HybridChatClient.resolveEffectiveChatConfig) != nil let chatCap = DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon) result.append( Row( @@ -74,9 +78,9 @@ enum HybridProviderReadiness { ? .configured : (chatCap ? .missing : .capabilityOff), detail: chatResolvable - ? "Direct chat endpoint configured" + ? "Direct chat endpoint configured through provider policy" : (chatCap - ? "Set chat_provider or ai_provider in hybrid settings" + ? (chatResolution?.resolution.reason ?? "Configure the chat model slot") : (DesktopBackendEnvironment.unavailableReason(for: .directChat, in: .localDaemon) ?? "Direct chat disabled")) )) diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index 6ff7977da6c..b19a1ca56d9 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -2608,6 +2608,10 @@ A screenshot may be attached — use it silently only if relevant. Never mention var sqlRowsReturned = 0 var sqlQueryCount = 0 var hybridResolvedModel: String? 
+ var hybridProviderAccountId: String? + var hybridProviderKind: String? + var hybridSlotSource: String? + var hybridSlotReason: String? do { if mayUseHybridDirectChat { @@ -2681,6 +2685,10 @@ A screenshot may be attached — use it silently only if relevant. Never mention userMessage: trimmedText ) hybridResolvedModel = hybrid.model + hybridProviderAccountId = hybrid.providerAccountID + hybridProviderKind = hybrid.providerKind + hybridSlotSource = hybrid.slotSource + hybridSlotReason = hybrid.resolutionReason let normalized = normalizeAssistantSentenceSpacing(hybrid.text) queryResult = AgentBridge.QueryResult( text: normalized, @@ -2870,7 +2878,14 @@ A screenshot may be attached — use it silently only if relevant. Never mention let textToSave = queryResult.text.isEmpty ? messageText : queryResult.text if !textToSave.isEmpty { do { - let toolMetadata = serializeToolCallMetadata(messageId: aiMessageId) + let toolMetadata = serializeAIMessageMetadata( + messageId: aiMessageId, + hybridModel: hybridResolvedModel, + hybridProviderAccountId: hybridProviderAccountId, + hybridProviderKind: hybridProviderKind, + hybridSlotSource: hybridSlotSource, + hybridSlotReason: hybridSlotReason + ) let response = try await APIClient.shared.saveMessage( text: textToSave, sender: "ai", @@ -3269,11 +3284,41 @@ A screenshot may be attached — use it silently only if relevant. Never mention } } - /// Serialize tool calls from a message's contentBlocks into a JSON metadata string. - /// Returns nil if there are no tool calls. - private func serializeToolCallMetadata(messageId: String) -> String? { - guard let index = messages.firstIndex(where: { $0.id == messageId }) else { return nil } + /// Serialize tool calls and resolved local-provider metadata into the persisted message JSON. + /// Returns nil when there is no extra metadata to store. 
+ private func serializeAIMessageMetadata( + messageId: String, + hybridModel: String?, + hybridProviderAccountId: String?, + hybridProviderKind: String?, + hybridSlotSource: String?, + hybridSlotReason: String? + ) -> String? { + var metadata: [String: Any] = [:] + + if let hybridModel = hybridModel { + var provider: [String: Any] = [ + "slot": "chat", + "model": hybridModel, + ] + if let hybridProviderAccountId = hybridProviderAccountId { + provider["provider_account_id"] = hybridProviderAccountId + } + if let hybridProviderKind = hybridProviderKind { + provider["provider_kind"] = hybridProviderKind + } + if let hybridSlotSource = hybridSlotSource { + provider["source"] = hybridSlotSource + } + if let hybridSlotReason = hybridSlotReason { + provider["reason"] = hybridSlotReason + } + metadata["provider_policy"] = provider + } + guard let index = messages.firstIndex(where: { $0.id == messageId }) else { + return metadata.isEmpty ? nil : Self.encodeMetadata(metadata) + } var toolCalls: [[String: Any]] = [] for block in messages[index].contentBlocks { if case .toolCall(_, let name, _, let toolUseId, let input, let output) = block { @@ -3291,14 +3336,30 @@ A screenshot may be attached — use it silently only if relevant. Never mention } } - guard !toolCalls.isEmpty else { return nil } + if !toolCalls.isEmpty { + metadata["tool_calls"] = toolCalls + } + + return metadata.isEmpty ? nil : Self.encodeMetadata(metadata) + } - let metadata: [String: Any] = ["tool_calls": toolCalls] + private static func encodeMetadata(_ metadata: [String: Any]) -> String? { guard let data = try? JSONSerialization.data(withJSONObject: metadata), let json = String(data: data, encoding: .utf8) else { return nil } return json } + private func serializeToolCallMetadata(messageId: String) -> String? 
{ + serializeAIMessageMetadata( + messageId: messageId, + hybridModel: nil, + hybridProviderAccountId: nil, + hybridProviderKind: nil, + hybridSlotSource: nil, + hybridSlotReason: nil + ) + } + // MARK: - Message Rating /// Rate a message (thumbs up/down) diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index cf92c663873..59a65a27117 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -1607,6 +1607,24 @@ final class APIClientRoutingTests: XCTestCase { label: "local saveMessage default session") } + func testLocalModeResolveChatSlotRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? await client.resolveSelectedBackendProviderSlot("chat") + as HybridProviderPolicy.SlotResolutionResponse + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/provider-policy/resolve/chat", + method: "GET", + label: "local resolve chat slot") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + func testLocalModeGetMessagesForSessionRoutesToLocalDaemon() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) diff --git a/desktop/Desktop/Tests/HybridChatClientTests.swift b/desktop/Desktop/Tests/HybridChatClientTests.swift new file mode 100644 index 00000000000..88acb33abfa --- /dev/null +++ b/desktop/Desktop/Tests/HybridChatClientTests.swift @@ -0,0 +1,242 @@ +import XCTest + +@testable import Omi_Computer + +private final class ChatProviderCapture: URLProtocol, @unchecked Sendable { + private static let lock = NSLock() + private static var _requests: [URLRequest] = [] + private static var 
_bodies: [Data] = [] + private static var responseModel = "stub-model" + + static var bodies: [Data] { + lock.lock() + defer { lock.unlock() } + return _bodies + } + + static var requests: [URLRequest] { + lock.lock() + defer { lock.unlock() } + return _requests + } + + static func reset(model: String = "stub-model") { + lock.lock() + _requests.removeAll() + _bodies.removeAll() + responseModel = model + lock.unlock() + } + + private static func record(request: URLRequest, body: Data) { + lock.lock() + _requests.append(request) + _bodies.append(body) + lock.unlock() + } + + private static func bodyData(from request: URLRequest) -> Data { + if let body = request.httpBody { + return body + } + guard let stream = request.httpBodyStream else { + return Data() + } + stream.open() + defer { stream.close() } + + var data = Data() + let bufferSize = 4096 + let buffer = UnsafeMutablePointer.allocate(capacity: bufferSize) + defer { buffer.deallocate() } + while stream.hasBytesAvailable { + let count = stream.read(buffer, maxLength: bufferSize) + if count > 0 { + data.append(buffer, count: count) + } else { + break + } + } + return data + } + + override class func canInit(with request: URLRequest) -> Bool { true } + override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } + + override func startLoading() { + let body = Self.bodyData(from: request) + Self.record(request: request, body: body) + let model = Self.responseModel + let response = HTTPURLResponse( + url: request.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! 
+ let payload = Data( + """ + {"model":"\(model)","choices":[{"message":{"content":"answer"}}],"usage":{"prompt_tokens":7,"completion_tokens":3}} + """.utf8) + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: payload) + client?.urlProtocolDidFinishLoading(self) + } + + override func stopLoading() {} +} + +final class HybridChatClientTests: XCTestCase { + override func setUp() { + super.setUp() + ChatProviderCapture.reset() + } + + func testChatSlotResolutionBuildsProviderConfig() { + let response = slotResolution( + accountID: "local-a", + baseURL: "http://127.0.0.1:11434/v1", + model: "llama3.2" + ) + + let config = HybridChatClient.resolveEffectiveChatConfig(from: response) + + XCTAssertEqual(config?.baseURL, "http://127.0.0.1:11434/v1") + XCTAssertEqual(config?.model, "llama3.2") + XCTAssertEqual(config?.providerAccountID, "local-a") + XCTAssertEqual(config?.slotSource, "provider_policy") + } + + func testRequestBodyUsesSelectedChatSlotModel() async throws { + let session = capturedSession() + let response = slotResolution( + accountID: "local-chat", + baseURL: "http://127.0.0.1:11434/v1", + model: "selected-model" + ) + + let result = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [(role: "assistant", text: "prior")], + userMessage: "hello", + slotResolution: response, + session: session + ) + + let body = try XCTUnwrap(ChatProviderCapture.bodies.first) + let json = try JSONSerialization.jsonObject(with: body) as? [String: Any] + XCTAssertEqual(json?["model"] as? 
String, "selected-model") + XCTAssertEqual(result.providerAccountID, "local-chat") + XCTAssertEqual(result.model, "stub-model") + } + + func testProviderAccountSwitchChangesRequestTargetAndModel() async throws { + let session = capturedSession() + + _ = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [], + userMessage: "one", + slotResolution: slotResolution( + accountID: "provider-a", + baseURL: "http://127.0.0.1:11434/v1", + model: "model-a" + ), + session: session + ) + _ = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [], + userMessage: "two", + slotResolution: slotResolution( + accountID: "provider-b", + baseURL: "http://localhost:43210/v1", + model: "model-b" + ), + session: session + ) + + let requests = ChatProviderCapture.requests + XCTAssertEqual(requests.map { $0.url?.absoluteString }, [ + "http://127.0.0.1:11434/v1/chat/completions", + "http://localhost:43210/v1/chat/completions", + ]) + let models = try ChatProviderCapture.bodies.map { body in + let json = try XCTUnwrap(JSONSerialization.jsonObject(with: body) as? [String: Any]) + return json["model"] as? 
String + } + XCTAssertEqual(models, ["model-a", "model-b"]) + } + + func testMissingChatSlotReasonIsSurfaced() async { + let response = HybridProviderPolicy.SlotResolutionResponse( + resolved: nil, + resolution: HybridProviderPolicy.SlotResolution( + slot: "chat", + ok: false, + resolved: nil, + reason: "model slot chat selects gpt-5.4-mini but no provider account is configured" + ) + ) + + do { + _ = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [], + userMessage: "hello", + slotResolution: response, + session: capturedSession() + ) + XCTFail("expected missing provider error") + } catch { + XCTAssertTrue(error.localizedDescription.contains("no provider account")) + XCTAssertTrue(ChatProviderCapture.requests.isEmpty) + } + } + + func testLocalDaemonChatDoesNotUseOmiHostedEndpoints() async throws { + _ = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [], + userMessage: "hello", + slotResolution: slotResolution( + accountID: "local-only", + baseURL: "http://127.0.0.1:9999/v1", + model: "local-model" + ), + session: capturedSession() + ) + + let url = try XCTUnwrap(ChatProviderCapture.requests.first?.url?.absoluteString) + XCTAssertTrue(url.hasPrefix("http://127.0.0.1:9999/v1/chat/completions")) + XCTAssertFalse(url.contains("omi.me")) + XCTAssertFalse(url.contains("omiapi.com")) + } + + private func capturedSession() -> URLSession { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [ChatProviderCapture.self] + return URLSession(configuration: config) + } + + private func slotResolution( + accountID: String, + baseURL: String, + model: String + ) -> HybridProviderPolicy.SlotResolutionResponse { + let account = HybridProviderPolicy.ProviderAccount( + id: accountID, + kind: "openai_compatible", + baseURL: baseURL, + apiKey: "test-key" + ) + let resolved = HybridProviderPolicy.ResolvedSlot( + slot: "chat", + providerAccount: account, + modelID: model, + source: 
"provider_policy" + ) + let resolution = HybridProviderPolicy.SlotResolution( + slot: "chat", + ok: true, + resolved: resolved, + reason: "chat resolved to \(model) from provider_policy" + ) + return HybridProviderPolicy.SlotResolutionResponse(resolved: resolved, resolution: resolution) + } +} diff --git a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift index 8a6c041dcf7..79531c9afb7 100644 --- a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift +++ b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift @@ -34,14 +34,34 @@ final class HybridLLMProviderConfigTests: XCTestCase { XCTAssertEqual(config?.model, "vlm") } - func testHybridChatClientFallsBackToAiProvider() throws { - let payload = """ - [{"key":"ai_provider","value_json":"{\\"kind\\":\\"openai_compatible\\",\\"base_url\\":\\"http://ai.local/v1\\",\\"model\\":\\"m-ai\\"}","updated_at":"2026-05-19T12:00:00Z"}] - """ - let settings = try decodeSettings(payload) - let config = HybridChatClient.resolveEffectiveChatConfig(from: settings) - XCTAssertEqual(config?.baseURL, "http://ai.local/v1") - XCTAssertEqual(config?.model, "m-ai") + func testHybridChatClientUsesChatProviderPolicySlot() throws { + let settings = try makeSettings([ + ( + "provider_policy", + [ + "version": 1, + "provider_accounts": [ + [ + "id": "local-chat", + "kind": "openai_compatible", + "base_url": "http://chat.local/v1", + "api_key": "test-key", + ] + ], + "model_slots": [ + "chat": [ + "provider_account_id": "local-chat", + "model_id": "chat-model", + ] + ], + ] + ) + ]) + let response = try XCTUnwrap(HybridProviderPolicy.resolveSlotFromSettings("chat", settings: settings)) + let config = HybridChatClient.resolveEffectiveChatConfig(from: response) + XCTAssertEqual(config?.baseURL, "http://chat.local/v1") + XCTAssertEqual(config?.model, "chat-model") + XCTAssertEqual(config?.providerAccountID, "local-chat") } func testProactiveUsesProviderPolicySlot() throws { 
diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index 960f082dfcf..cb4fd4da199 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -236,7 +236,7 @@ way. Environment flags (desktop process, not daemon): - `OMI_HYBRID_DIRECT_STT_ENABLED=1` — enables hybrid live transcription via Apple Speech in local daemon mode (also enabled by default when `desktop/run.sh` configures local mode; launcher writes this into bundled `.env` so GUI launches see it). -- `OMI_HYBRID_DIRECT_CHAT_ENABLED=1` — enables hybrid OpenAI-compatible chat (`HybridChatClient`) when combined with `chat_provider`; sessions/messages persist via daemon SQLite (`run.sh` defaults this on in local mode). +- `OMI_HYBRID_DIRECT_CHAT_ENABLED=1` — enables hybrid OpenAI-compatible chat (`HybridChatClient`) when the daemon `chat` slot resolves to a provider account; sessions/messages persist via daemon SQLite (`run.sh` defaults this on in local mode). - `OMI_HYBRID_OPTIONAL_CLOUD_STT=1` — exposes `optionalCloudSTT` capability - `OMI_HYBRID_OPTIONAL_CLOUD_CHAT=1` — exposes `optionalCloudChat` capability @@ -258,8 +258,10 @@ When the daemon starts via `make serve-local` or `desktop/run.sh` in local mode, | `OMI_HYBRID_DEFAULT_PROVIDER_ACCOUNT_ID` | `local-openai-compatible` | The desktop app also calls `HybridProviderBootstrap.ensureDefaultsIfNeeded()` on -local guest session startup. Chat resolves `chat_provider` → `ai_provider` → BYOK OpenAI -(see `HybridChatClient`). +local guest session startup. Ask Omi resolves +`GET /v1/provider-policy/resolve/chat` before each local direct chat request and +uses the returned provider account/model. Legacy `chat_provider` rows are only +read by the daemon compatibility bridge when constructing the typed policy. Configure or override in **Settings → Plan and Usage** (local mode) or via `PUT /v1/settings`. 
diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 7a987ae0dd5..2cf2d9299b5 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -1187,6 +1187,7 @@ mod tests { let store = Store::open_in_memory()?; assert!(resolve_model_slot(&store, SLOT_PROACTIVE)?.is_none()); + assert!(resolve_model_slot(&store, SLOT_CHAT)?.is_none()); assert!(resolve_model_slot(&store, SLOT_VISION)?.is_none()); let proactive = resolve_model_slot_result(&store, SLOT_PROACTIVE)?; assert!(!proactive.ok); @@ -1206,6 +1207,102 @@ mod tests { Ok(()) } + #[test] + fn chat_slot_resolution_switches_provider_account_and_model() -> Result<()> { + let store = Store::open_in_memory()?; + let account_a = ProviderAccount { + id: "local-a".to_string(), + kind: "openai_compatible".to_string(), + base_url: Some("http://127.0.0.1:11434/v1".to_string()), + api_key: None, + display_name: Some("Local A".to_string()), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: false, + vision: false, + speech_to_text: false, + }, + subscription_integration: None, + }; + let account_b = ProviderAccount { + id: "local-b".to_string(), + kind: "openai_compatible".to_string(), + base_url: Some("http://localhost:43210/v1".to_string()), + api_key: None, + display_name: Some("Local B".to_string()), + capabilities: ProviderCapabilities { + chat_completions: true, + json_mode: true, + tool_calls: false, + vision: false, + speech_to_text: false, + }, + subscription_integration: None, + }; + let mut slots = BTreeMap::new(); + slots.insert( + SLOT_CHAT.to_string(), + ModelSlotTarget { + provider_account_id: Some(account_a.id.clone()), + model_id: "model-a".to_string(), + options: ModelSlotOptions::default(), + }, + ); + save_provider_policy( + &store, + ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![account_a.clone(), account_b.clone()], + model_slots: slots, + }, + 
)?; + + let first = configured_openai_provider_for_slot(&store, SLOT_CHAT)? + .expect("chat provider should resolve"); + let first_request = first + .provider + .build_chat_completions_request(vec![ChatMessage::user("hello")]); + assert_eq!(first.slot.provider_account.as_ref().unwrap().id, "local-a"); + assert_eq!( + first_request.url, + "http://127.0.0.1:11434/v1/chat/completions" + ); + assert_eq!(first_request.body["model"], "model-a"); + + let mut slots = BTreeMap::new(); + slots.insert( + SLOT_CHAT.to_string(), + ModelSlotTarget { + provider_account_id: Some(account_b.id.clone()), + model_id: "model-b".to_string(), + options: ModelSlotOptions::default(), + }, + ); + save_provider_policy( + &store, + ProviderPolicy { + version: PROVIDER_POLICY_VERSION, + provider_accounts: vec![account_a, account_b], + model_slots: slots, + }, + )?; + + let second = configured_openai_provider_for_slot(&store, SLOT_CHAT)? + .expect("switched chat provider should resolve"); + let second_request = second + .provider + .build_chat_completions_request(vec![ChatMessage::user("hello")]); + assert_eq!(second.slot.provider_account.as_ref().unwrap().id, "local-b"); + assert_eq!( + second_request.url, + "http://localhost:43210/v1/chat/completions" + ); + assert_eq!(second_request.body["model"], "model-b"); + + Ok(()) + } + #[test] fn allowed_override_and_disallowed_slot_are_validated() -> Result<()> { let account = ProviderAccount { diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index 35e4882cfe5..a87e5bcfc7a 100644 --- a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -3086,6 +3086,57 @@ mod tests { Ok(()) } + #[test] + fn chat_sessions_and_messages_persist_after_reopen() -> Result<()> { + let temp = tempdir()?; + let db_path = temp.path().join("local.sqlite"); + let session_id; + + { + let store = Store::open(&db_path)?; + let session = store + .chat_sessions() + .create_session(Some("Local chat"), 
Some("desktop"))?; + session_id = session.id; + store.chat_sessions().append_message( + &session_id, + "hello", + "human", + Some("desktop"), + Some("{\"provider_policy\":{\"slot\":\"chat\",\"model\":\"model-a\"}}"), + )?; + store.chat_sessions().append_message( + &session_id, + "answer", + "ai", + Some("desktop"), + Some("{\"provider_policy\":{\"slot\":\"chat\",\"model\":\"model-a\",\"provider_account_id\":\"local-a\"}}"), + )?; + } + + let reopened = Store::open(&db_path)?; + let session = reopened + .chat_sessions() + .get_session(&session_id)? + .expect("chat session should persist"); + let messages = + reopened + .chat_sessions() + .list_messages(&session_id, Some("desktop"), 10, 0)?; + + assert_eq!(session.title, "Local chat"); + assert_eq!(session.message_count, 2); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].text, "answer"); + assert!(messages[0] + .metadata + .as_deref() + .unwrap() + .contains("\"provider_account_id\":\"local-a\"")); + + Ok(()) + } + #[test] fn duplicate_transcript_append_is_existing_or_conflict() -> Result<()> { let store = Store::open_in_memory()?; From 90eb0b233ed4035d96a6acf64831a00bd01efed1 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 22:57:41 +0700 Subject: [PATCH 36/58] Expose local provider slot readiness --- desktop/Desktop/Sources/APIClient.swift | 68 ++++ .../Sources/HybridProviderPolicy.swift | 54 ++- .../Sources/HybridProviderReadiness.swift | 116 ++++-- .../MainWindow/Pages/SettingsPage.swift | 357 +++++++++++++----- .../Desktop/Tests/APIClientRoutingTests.swift | 18 + .../Desktop/Tests/HybridChatClientTests.swift | 5 +- .../Tests/HybridLLMProviderConfigTests.swift | 49 +++ .../docs/hybrid-provider-settings.md | 29 +- .../local-backend/docs/local-mvp-runbook.md | 65 +++- desktop/local-backend/src/main.rs | 17 +- desktop/local-backend/src/providers.rs | 38 ++ desktop/local-backend/src/routes.rs | 18 + 12 files changed, 673 insertions(+), 161 deletions(-) diff --git 
a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index da158357255..ca0f75fb5b6 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -260,6 +260,58 @@ actor APIClient { ) } + func getSelectedBackendProviderPolicy() async throws -> HybridProviderPolicy.Policy { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.featureUnavailable( + feature: "get_provider_policy", + reason: "Provider policy is only available in local daemon mode." + ) + } + let response: LocalDaemonProviderPolicyResponse = try await get( + "v1/provider-policy", + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.providerPolicy + } + + func updateSelectedBackendProviderPolicy(_ policy: HybridProviderPolicy.Policy) async throws + -> HybridProviderPolicy.Policy + { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.featureUnavailable( + feature: "update_provider_policy", + reason: "Provider policy is only available in local daemon mode." + ) + } + let response: LocalDaemonProviderPolicyResponse = try await put( + "v1/provider-policy", + body: policy, + requireAuth: false, + customBaseURL: target.baseURL + ) + return response.providerPolicy + } + + func testSelectedBackendProviderSlot(_ slot: String) async throws -> LocalDaemonTestSlotResponse { + let target = selectedBackendTarget + guard target.mode == .localDaemon else { + throw APIError.featureUnavailable( + feature: "test_provider_slot", + reason: "Provider slot tests are only available in local daemon mode." + ) + } + let encodedSlot = slot.addingPercentEncoding(withAllowedCharacters: .urlPathAllowed) ?? 
slot + return try await post( + "v1/provider-policy/test-slot/\(encodedSlot)", + body: EmptyRequest(), + requireAuth: false, + customBaseURL: target.baseURL + ) + } + func resolveSelectedBackendProviderSlot(_ slot: String) async throws -> HybridProviderPolicy.SlotResolutionResponse { @@ -394,6 +446,16 @@ struct LocalDaemonSettingsResponse: Decodable { let settings: [LocalDaemonSetting] } +struct LocalDaemonProviderPolicyResponse: Decodable { + let providerPolicy: HybridProviderPolicy.Policy + + enum CodingKeys: String, CodingKey { + case providerPolicy = "provider_policy" + } +} + +struct EmptyRequest: Encodable {} + struct LocalDaemonTestProviderRequest: Encodable { let key: String } @@ -404,6 +466,12 @@ struct LocalDaemonTestProviderResponse: Decodable, Equatable { let message: String } +struct LocalDaemonTestSlotResponse: Decodable, Equatable { + let ok: Bool + let slot: String + let message: String +} + struct LocalDaemonSetting: Decodable, Equatable { let key: String let valueJson: String diff --git a/desktop/Desktop/Sources/HybridProviderPolicy.swift b/desktop/Desktop/Sources/HybridProviderPolicy.swift index 60ce6fe04ac..bc68a871afb 100644 --- a/desktop/Desktop/Sources/HybridProviderPolicy.swift +++ b/desktop/Desktop/Sources/HybridProviderPolicy.swift @@ -2,20 +2,57 @@ import Foundation enum HybridProviderPolicy { static let chatSlot = "chat" + static let postTranscriptSlot = "post_transcript" static let proactiveSlot = "proactive" static let visionSlot = "vision" + static let sttSlot = "stt" + static let memorySearchSlot = "memory_search" + static let localProviderAccountID = "local-openai-compatible" + static let localWikiModel = "local_wiki" - struct ProviderAccount: Decodable, Equatable { + struct ProviderCapabilities: Codable, Equatable { + let chatCompletions: Bool + let jsonMode: Bool + let toolCalls: Bool + let vision: Bool + let speechToText: Bool + + enum CodingKeys: String, CodingKey { + case chatCompletions = "chat_completions" + case jsonMode = 
"json_mode" + case toolCalls = "tool_calls" + case vision + case speechToText = "speech_to_text" + } + } + + struct ModelSlotOptions: Codable, Equatable { + let jsonMode: Bool? + let toolSupport: Bool? + + enum CodingKeys: String, CodingKey { + case jsonMode = "json_mode" + case toolSupport = "tool_support" + } + } + + struct ProviderAccount: Codable, Equatable { let id: String let kind: String let baseURL: String? let apiKey: String? + let displayName: String? + let capabilities: ProviderCapabilities? + let subscriptionIntegration: String? enum CodingKeys: String, CodingKey { case id case kind case baseURL = "base_url" case apiKey = "api_key" + case displayName = "display_name" + case capabilities + case subscriptionIntegration = "subscription_integration" } } @@ -45,7 +82,7 @@ enum HybridProviderPolicy { let resolution: SlotResolution } - struct Policy: Decodable { + struct Policy: Codable { let version: Int let providerAccounts: [ProviderAccount] let modelSlots: [String: ModelSlotTarget] @@ -57,13 +94,15 @@ enum HybridProviderPolicy { } } - struct ModelSlotTarget: Decodable { + struct ModelSlotTarget: Codable { let providerAccountID: String? let modelID: String + let options: ModelSlotOptions? enum CodingKeys: String, CodingKey { case providerAccountID = "provider_account_id" case modelID = "model_id" + case options } } @@ -142,4 +181,13 @@ enum HybridProviderPolicy { let normalized = kind.lowercased() return normalized == "openai_compatible" || normalized == "openai" } + + static func policyFromSettings(_ settings: [LocalDaemonSetting]) -> Policy? { + guard let raw = settings.first(where: { $0.key == "provider_policy" })?.valueJson, + let data = raw.data(using: .utf8) + else { + return nil + } + return try? 
JSONDecoder().decode(Policy.self, from: data) + } } diff --git a/desktop/Desktop/Sources/HybridProviderReadiness.swift b/desktop/Desktop/Sources/HybridProviderReadiness.swift index bc659b816cd..737305b98b8 100644 --- a/desktop/Desktop/Sources/HybridProviderReadiness.swift +++ b/desktop/Desktop/Sources/HybridProviderReadiness.swift @@ -37,31 +37,42 @@ enum HybridProviderReadiness { return "llama3.2" } + static func defaultSmallModel() -> String { + "gpt-5.4-mini" + } + static func rows(from settings: [LocalDaemonSetting]) -> [Row] { var result: [Row] = [] + let policy = HybridProviderPolicy.policyFromSettings(settings) - let aiConfigured = hasOpenAICompatibleProvider( - in: settings, keys: ["ai_provider", "provider"]) + let postResolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.postTranscriptSlot, + settings: settings + ) + let postConfigured = hasResolvedOpenAICompatibleSlot(postResolution) result.append( Row( - id: "ai_provider", - label: "Processing (ai_provider)", - status: aiConfigured ? .configured : .optionalFallback, - detail: aiConfigured - ? "OpenAI-compatible provider configured" - : "Optional — deterministic fallback when unset" + id: HybridProviderPolicy.postTranscriptSlot, + label: "Post-transcript processing", + status: postConfigured ? .configured : .optionalFallback, + detail: postConfigured + ? slotDetail(postResolution) + : "Defaults to \(defaultSmallModel()); deterministic fallback when no provider account is configured" )) - let sttAvailable = DesktopBackendEnvironment.isCapability(.directSTT, availableIn: .localDaemon) + let proactiveResolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.proactiveSlot, + settings: settings + ) + let proactiveConfigured = hasResolvedOpenAICompatibleSlot(proactiveResolution) result.append( Row( - id: "stt", - label: "Live transcription", - status: sttAvailable ? .configured : .capabilityOff, - detail: sttAvailable - ? 
"On-device Apple Speech (no daemon key)" - : (DesktopBackendEnvironment.unavailableReason(for: .directSTT, in: .localDaemon) - ?? "Direct STT unavailable") + id: HybridProviderPolicy.proactiveSlot, + label: "Proactive assistants", + status: proactiveConfigured ? .configured : .missing, + detail: proactiveConfigured + ? slotDetail(proactiveResolution) + : "Defaults to \(defaultSmallModel()); configure a provider account to enable proactive AI calls" )) let chatResolution = HybridProviderPolicy.resolveSlotFromSettings( @@ -72,35 +83,53 @@ enum HybridProviderReadiness { let chatCap = DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon) result.append( Row( - id: "chat_provider", - label: "Chat (chat_provider)", + id: HybridProviderPolicy.chatSlot, + label: "Chat", status: chatResolvable && chatCap ? .configured : (chatCap ? .missing : .capabilityOff), detail: chatResolvable - ? "Direct chat endpoint configured through provider policy" + ? slotDetail(chatResolution) : (chatCap ? (chatResolution?.resolution.reason ?? "Configure the chat model slot") : (DesktopBackendEnvironment.unavailableReason(for: .directChat, in: .localDaemon) ?? "Direct chat disabled")) )) - let embedConfigured = HybridEmbeddingClient.loadProviderConfig(from: settings) != nil - let embedCap = DesktopBackendEnvironment.isCapability( - .directEmbeddings, availableIn: .localDaemon) + let visionResolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.visionSlot, + settings: settings + ) + let visionConfigured = hasResolvedOpenAICompatibleSlot(visionResolution) result.append( Row( - id: "embedding_provider", - label: "Embeddings (embedding_provider)", - status: embedConfigured && embedCap - ? .configured - : (embedCap ? .missing : .capabilityOff), - detail: embedConfigured - ? "Embedding provider configured" - : (embedCap - ? 
"Optional for Rewind semantic search" - : (DesktopBackendEnvironment.unavailableReason( - for: .directEmbeddings, in: .localDaemon) ?? "Direct embeddings disabled")) + id: HybridProviderPolicy.visionSlot, + label: "Vision, optional", + status: visionConfigured ? .configured : .optionalFallback, + detail: visionConfigured + ? slotDetail(visionResolution) + : "Optional; screenshot assistants use local OCR text when no vision slot is configured" + )) + + let sttAvailable = DesktopBackendEnvironment.isCapability(.directSTT, availableIn: .localDaemon) + let sttModel = policy?.modelSlots[HybridProviderPolicy.sttSlot]?.modelID + result.append( + Row( + id: HybridProviderPolicy.sttSlot, + label: "STT/local transcription", + status: sttAvailable ? .configured : .capabilityOff, + detail: sttAvailable + ? "On-device Apple Speech\(sttModel.map { " / \($0)" } ?? ""); no daemon provider key required" + : (DesktopBackendEnvironment.unavailableReason(for: .directSTT, in: .localDaemon) + ?? "Direct STT unavailable") + )) + + result.append( + Row( + id: HybridProviderPolicy.memorySearchSlot, + label: "Memory search", + status: .configured, + detail: "Local wiki/FTS search using \(policy?.modelSlots[HybridProviderPolicy.memorySearchSlot]?.modelID ?? HybridProviderPolicy.localWikiModel); no embeddings required" )) return result @@ -121,4 +150,25 @@ enum HybridProviderReadiness { guard let base = json["base_url"] as? String, !base.isEmpty else { return false } return true } + + private static func hasResolvedOpenAICompatibleSlot( + _ response: HybridProviderPolicy.SlotResolutionResponse? + ) -> Bool { + guard response?.resolution.ok == true, + let account = response?.resolved?.providerAccount, + let baseURL = account.baseURL?.trimmingCharacters(in: .whitespacesAndNewlines), + !baseURL.isEmpty + else { + return false + } + return HybridProviderPolicy.isOpenAICompatible(kind: account.kind) + } + + private static func slotDetail(_ response: HybridProviderPolicy.SlotResolutionResponse?) 
-> String { + guard let resolved = response?.resolved else { + return "Slot not configured" + } + let account = resolved.providerAccount?.displayName ?? resolved.providerAccount?.id ?? "no provider account" + return "\(account) / \(resolved.modelID)" + } } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 2e1e05201f1..3a2d633d3ba 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -237,6 +237,7 @@ struct SettingsContentView: View { @State private var hybridEmbedBaseURL: String = HybridProviderReadiness.defaultBaseURL() @State private var hybridEmbedModel: String = HybridProviderReadiness.defaultModel() @State private var hybridEmbedApiKey: String = "" + @State private var hybridVisionModel: String = "" @State private var hybridProviderStatus: String? @State private var isSavingHybridProvider: Bool = false @State private var isTestingHybridProvider: Bool = false @@ -1830,7 +1831,7 @@ struct SettingsContentView: View { .foregroundColor(OmiColors.textPrimary) Text( - "Bring your own AI endpoints. Keys are stored in the local daemon SQLite database on this Mac and sent only to URLs you configure—not Omi cloud proxies." + "Bring your own AI endpoint, then assign models to local task slots. Keys are stored in the local daemon SQLite database on this Mac and sent only to URLs you configure—not Omi cloud proxies." 
) .scaledFont(size: 12) .foregroundColor(OmiColors.textSecondary) @@ -1875,34 +1876,64 @@ struct SettingsContentView: View { private var localHybridProvidersEditorCard: some View { settingsCard(settingId: "planusage.local.providers") { VStack(alignment: .leading, spacing: 18) { - hybridProviderEditorBlock( - title: "Processing (ai_provider)", - baseURL: $hybridAiBaseURL, - model: $hybridAiModel, - apiKey: $hybridAiApiKey, - settingKey: "ai_provider" - ) + VStack(alignment: .leading, spacing: 8) { + Text("Provider account") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + TextField("Base URL", text: $hybridAiBaseURL) + .textFieldStyle(.roundedBorder) + SecureField("API key (optional on loopback)", text: $hybridAiApiKey) + .textFieldStyle(.roundedBorder) + } Divider().background(OmiColors.backgroundQuaternary) - hybridProviderEditorBlock( - title: "Chat (chat_provider)", - baseURL: $hybridChatBaseURL, + hybridSlotEditorBlock( + title: "Chat", + subtitle: "Direct Ask Omi chat", model: $hybridChatModel, - apiKey: $hybridChatApiKey, - settingKey: "chat_provider" + slot: HybridProviderPolicy.chatSlot + ) + + Divider().background(OmiColors.backgroundQuaternary) + + hybridSlotEditorBlock( + title: "Post-transcript processing", + subtitle: "Titles, summaries, memories, and action items", + model: $hybridAiModel, + slot: HybridProviderPolicy.postTranscriptSlot ) Divider().background(OmiColors.backgroundQuaternary) - hybridProviderEditorBlock( - title: "Embeddings (embedding_provider)", - baseURL: $hybridEmbedBaseURL, + hybridSlotEditorBlock( + title: "Proactive assistants", + subtitle: "Local assistant jobs; defaults to \(HybridProviderReadiness.defaultSmallModel())", model: $hybridEmbedModel, - apiKey: $hybridEmbedApiKey, - settingKey: "embedding_provider" + slot: HybridProviderPolicy.proactiveSlot + ) + + Divider().background(OmiColors.backgroundQuaternary) + + hybridSlotEditorBlock( + title: "Vision, optional", + subtitle: "Leave blank 
to use local OCR text only", + model: $hybridVisionModel, + slot: HybridProviderPolicy.visionSlot, + optional: true ) + Divider().background(OmiColors.backgroundQuaternary) + + VStack(alignment: .leading, spacing: 4) { + Text("Memory search") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + Text("Local wiki/FTS search uses local_wiki and does not require embeddings.") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textSecondary) + } + if let hybridProviderStatus { Text(hybridProviderStatus) .scaledFont(size: 12) @@ -1913,34 +1944,33 @@ struct SettingsContentView: View { } } - private func hybridProviderEditorBlock( + private func hybridSlotEditorBlock( title: String, - baseURL: Binding, + subtitle: String, model: Binding, - apiKey: Binding, - settingKey: String + slot: String, + optional: Bool = false ) -> some View { VStack(alignment: .leading, spacing: 8) { Text(title) .scaledFont(size: 12, weight: .medium) .foregroundColor(OmiColors.textTertiary) - TextField("Base URL", text: baseURL) - .textFieldStyle(.roundedBorder) + Text(subtitle) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textSecondary) TextField("Model", text: model) .textFieldStyle(.roundedBorder) - SecureField("API key (optional on loopback)", text: apiKey) - .textFieldStyle(.roundedBorder) HStack(spacing: 10) { Button("Save") { - saveHybridProvider(key: settingKey, baseURL: baseURL.wrappedValue, model: model.wrappedValue, apiKey: apiKey.wrappedValue) + saveHybridProviderPolicy() } .buttonStyle(.bordered) .disabled(isSavingHybridProvider) Button("Test") { - testHybridProvider(key: settingKey, baseURL: baseURL.wrappedValue, model: model.wrappedValue, apiKey: apiKey.wrappedValue) + testHybridProviderSlot(slot) } .buttonStyle(.bordered) - .disabled(isTestingHybridProvider) + .disabled(isTestingHybridProvider || (optional && model.wrappedValue.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)) } } } @@ -6364,57 +6394,43 @@ struct 
SettingsContentView: View { if target.mode == .localDaemon { settingsCard(settingId: "about.hybrid_providers") { VStack(alignment: .leading, spacing: 14) { - Text("Hybrid providers") + Text("Local provider slots") .scaledFont(size: 15, weight: .semibold) .foregroundColor(OmiColors.textPrimary) Text( - "API keys are stored in the local daemon SQLite database on this Mac and sent only to the endpoints you configure—not to Omi cloud proxies." + "Local mode resolves task slots through the daemon provider policy. Memory search uses local wiki/FTS search and does not require embeddings." ) .scaledFont(size: 12) .foregroundColor(OmiColors.textSecondary) .fixedSize(horizontal: false, vertical: true) - VStack(alignment: .leading, spacing: 8) { - Text("Processing (ai_provider)") - .scaledFont(size: 12, weight: .medium) - .foregroundColor(OmiColors.textTertiary) - TextField("Base URL", text: $hybridAiBaseURL) - .textFieldStyle(.roundedBorder) - TextField("Model", text: $hybridAiModel) - .textFieldStyle(.roundedBorder) - SecureField("API key (optional on loopback)", text: $hybridAiApiKey) - .textFieldStyle(.roundedBorder) - } - - HStack(spacing: 10) { - Button("Save") { - saveHybridAiProvider() - } - .buttonStyle(.borderedProminent) - .disabled(isSavingHybridProvider) - - Button("Test connection") { - testHybridAiProvider() - } - .buttonStyle(.bordered) - .disabled(isTestingHybridProvider) - - if isSavingHybridProvider || isTestingHybridProvider { - ProgressView().controlSize(.small) + ForEach(HybridProviderReadiness.rows(from: backendSettings)) { row in + HStack(alignment: .top, spacing: 8) { + Image( + systemName: row.status == .configured || row.status == .optionalFallback + ? "checkmark.circle.fill" : "xmark.circle" + ) + .foregroundColor( + row.status == .configured + ? OmiColors.success + : (row.status == .optionalFallback + ? 
OmiColors.textTertiary : OmiColors.warning)) + VStack(alignment: .leading, spacing: 2) { + Text(row.label) + .scaledFont(size: 11, weight: .medium) + .foregroundColor(OmiColors.textSecondary) + Text(row.detail) + .scaledFont(size: 10) + .foregroundColor(OmiColors.textTertiary) + } + Spacer() } } - if let hybridProviderStatus { - Text(hybridProviderStatus) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textSecondary) - .textSelection(.enabled) - } - Divider().background(OmiColors.backgroundQuaternary) - Text("Capabilities") + Text("Local daemon capabilities") .scaledFont(size: 12, weight: .medium) .foregroundColor(OmiColors.textTertiary) @@ -6449,40 +6465,29 @@ struct SettingsContentView: View { } private func syncHybridProviderFieldsFromBackendSettings() { - syncHybridProviderFields( - forKey: "ai_provider", alsoKeys: ["provider"], - intoBaseURL: &hybridAiBaseURL, model: &hybridAiModel, apiKey: &hybridAiApiKey) + syncAllHybridProviderFieldsFromBackendSettings() } private func saveHybridAiProvider() { - saveHybridProvider( - key: "ai_provider", - baseURL: hybridAiBaseURL, - model: hybridAiModel, - apiKey: hybridAiApiKey - ) + saveHybridProviderPolicy() } private func testHybridAiProvider() { - testHybridProvider( - key: "ai_provider", - baseURL: hybridAiBaseURL, - model: hybridAiModel, - apiKey: hybridAiApiKey - ) + testHybridProviderSlot(HybridProviderPolicy.postTranscriptSlot) } private var localProcessingProviderStatus: String { - let providerSetting = backendSettings.first { $0.key == "ai_provider" || $0.key == "provider" } - guard let providerSetting else { - return "Deterministic fallback" - } - if providerSetting.valueJson.contains("\"api_key\"") - || providerSetting.valueJson.contains("\"key\"") + if let resolution = HybridProviderPolicy.resolveSlotFromSettings( + HybridProviderPolicy.postTranscriptSlot, + settings: backendSettings + ), + resolution.resolution.ok, + let resolved = resolution.resolved, + let account = resolved.providerAccount { - return 
"OpenAI-compatible provider configured" + return "\(account.displayName ?? account.id) / \(resolved.modelID)" } - return "Deterministic fallback" + return "Deterministic fallback; post-transcript defaults to \(HybridProviderReadiness.defaultSmallModel()) until a provider account is configured" } private func backendStatusRow(title: String, value: String) -> some View { @@ -7452,9 +7457,56 @@ struct SettingsContentView: View { } private func syncAllHybridProviderFieldsFromBackendSettings() { - syncHybridProviderFields(forKey: "ai_provider", alsoKeys: ["provider"], intoBaseURL: &hybridAiBaseURL, model: &hybridAiModel, apiKey: &hybridAiApiKey) - syncHybridProviderFields(forKey: "chat_provider", alsoKeys: [], intoBaseURL: &hybridChatBaseURL, model: &hybridChatModel, apiKey: &hybridChatApiKey) - syncHybridProviderFields(forKey: "embedding_provider", alsoKeys: [], intoBaseURL: &hybridEmbedBaseURL, model: &hybridEmbedModel, apiKey: &hybridEmbedApiKey) + if syncHybridProviderPolicyFieldsFromBackendSettings() { + return + } + syncHybridProviderFields( + forKey: "ai_provider", alsoKeys: ["provider"], intoBaseURL: &hybridAiBaseURL, + model: &hybridAiModel, apiKey: &hybridAiApiKey) + syncHybridProviderFields( + forKey: "chat_provider", alsoKeys: [], intoBaseURL: &hybridChatBaseURL, + model: &hybridChatModel, apiKey: &hybridChatApiKey) + syncHybridProviderFields( + forKey: "vision_provider", alsoKeys: [], intoBaseURL: &hybridAiBaseURL, + model: &hybridVisionModel, apiKey: &hybridAiApiKey) + } + + private func syncHybridProviderPolicyFieldsFromBackendSettings() -> Bool { + guard let policy = HybridProviderPolicy.policyFromSettings(backendSettings) else { + return false + } + let accountID = + policy.modelSlots[HybridProviderPolicy.chatSlot]?.providerAccountID + ?? policy.modelSlots[HybridProviderPolicy.postTranscriptSlot]?.providerAccountID + ?? policy.modelSlots[HybridProviderPolicy.proactiveSlot]?.providerAccountID + ?? 
policy.modelSlots[HybridProviderPolicy.visionSlot]?.providerAccountID + if let accountID, + let account = policy.providerAccounts.first(where: { $0.id == accountID }) + { + if let baseURL = account.baseURL, !baseURL.isEmpty { + hybridAiBaseURL = baseURL + hybridChatBaseURL = baseURL + hybridEmbedBaseURL = baseURL + } + if let apiKey = account.apiKey { + hybridAiApiKey = apiKey + hybridChatApiKey = apiKey + hybridEmbedApiKey = apiKey + } + } + hybridChatModel = + policy.modelSlots[HybridProviderPolicy.chatSlot]?.modelID + ?? hybridChatModel + hybridAiModel = + policy.modelSlots[HybridProviderPolicy.postTranscriptSlot]?.modelID + ?? hybridAiModel + hybridEmbedModel = + policy.modelSlots[HybridProviderPolicy.proactiveSlot]?.modelID + ?? hybridEmbedModel + hybridVisionModel = + policy.modelSlots[HybridProviderPolicy.visionSlot]?.modelID + ?? "" + return true } private func syncHybridProviderFields( @@ -7485,21 +7537,20 @@ struct SettingsContentView: View { guard !isSavingHybridProvider else { return } isSavingHybridProvider = true hybridProviderStatus = nil - let provider = HybridProviderBootstrap.defaultProviderObject() Task { do { - backendSettings = try await APIClient.shared.updateSelectedBackendSettings([ - "ai_provider": .object(provider), - "chat_provider": .object(provider), - ]) + let policy = currentHybridProviderPolicyFromUI() + _ = try await APIClient.shared.updateSelectedBackendProviderPolicy(policy) + backendSettings = try await APIClient.shared.getSelectedBackendSettings() await MainActor.run { applyLocalHybridProviderDefaultsToUI() syncAllHybridProviderFieldsFromBackendSettings() - hybridProviderStatus = "Applied local defaults (ai_provider + chat_provider)." + hybridProviderStatus = "Applied local provider-policy defaults." 
isSavingHybridProvider = false } do { - let test = try await APIClient.shared.testHybridProvider(key: "chat_provider") + let test = try await APIClient.shared.testSelectedBackendProviderSlot( + HybridProviderPolicy.chatSlot) await MainActor.run { hybridProviderStatus = test.message } @@ -7520,11 +7571,113 @@ struct SettingsContentView: View { private func applyLocalHybridProviderDefaultsToUI() { hybridAiBaseURL = HybridProviderReadiness.defaultBaseURL() - hybridAiModel = HybridProviderReadiness.defaultModel() + hybridAiModel = HybridProviderReadiness.defaultSmallModel() hybridChatBaseURL = hybridAiBaseURL - hybridChatModel = hybridAiModel + hybridChatModel = HybridProviderReadiness.defaultModel() hybridEmbedBaseURL = hybridAiBaseURL - hybridEmbedModel = hybridAiModel + hybridEmbedModel = HybridProviderReadiness.defaultSmallModel() + hybridVisionModel = "" + } + + private func currentHybridProviderPolicyFromUI() -> HybridProviderPolicy.Policy { + let account = HybridProviderPolicy.ProviderAccount( + id: HybridProviderPolicy.localProviderAccountID, + kind: "openai_compatible", + baseURL: hybridAiBaseURL.trimmingCharacters(in: .whitespacesAndNewlines), + apiKey: hybridAiApiKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + ? 
nil : hybridAiApiKey.trimmingCharacters(in: .whitespacesAndNewlines), + displayName: "Local OpenAI-compatible", + capabilities: HybridProviderPolicy.ProviderCapabilities( + chatCompletions: true, + jsonMode: true, + toolCalls: false, + vision: !hybridVisionModel.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty, + speechToText: false + ), + subscriptionIntegration: nil + ) + var slots: [String: HybridProviderPolicy.ModelSlotTarget] = [ + HybridProviderPolicy.chatSlot: HybridProviderPolicy.ModelSlotTarget( + providerAccountID: account.id, + modelID: hybridChatModel.trimmingCharacters(in: .whitespacesAndNewlines), + options: HybridProviderPolicy.ModelSlotOptions(jsonMode: false, toolSupport: false) + ), + HybridProviderPolicy.postTranscriptSlot: HybridProviderPolicy.ModelSlotTarget( + providerAccountID: account.id, + modelID: hybridAiModel.trimmingCharacters(in: .whitespacesAndNewlines), + options: HybridProviderPolicy.ModelSlotOptions(jsonMode: true, toolSupport: false) + ), + HybridProviderPolicy.proactiveSlot: HybridProviderPolicy.ModelSlotTarget( + providerAccountID: account.id, + modelID: hybridEmbedModel.trimmingCharacters(in: .whitespacesAndNewlines), + options: HybridProviderPolicy.ModelSlotOptions(jsonMode: true, toolSupport: false) + ), + HybridProviderPolicy.memorySearchSlot: HybridProviderPolicy.ModelSlotTarget( + providerAccountID: nil, + modelID: HybridProviderPolicy.localWikiModel, + options: HybridProviderPolicy.ModelSlotOptions(jsonMode: nil, toolSupport: nil) + ), + ] + let visionModel = hybridVisionModel.trimmingCharacters(in: .whitespacesAndNewlines) + if !visionModel.isEmpty { + slots[HybridProviderPolicy.visionSlot] = HybridProviderPolicy.ModelSlotTarget( + providerAccountID: account.id, + modelID: visionModel, + options: HybridProviderPolicy.ModelSlotOptions(jsonMode: true, toolSupport: false) + ) + } + return HybridProviderPolicy.Policy( + version: 1, + providerAccounts: [account], + modelSlots: slots + ) + } + + private func 
saveHybridProviderPolicy() { + guard !isSavingHybridProvider else { return } + isSavingHybridProvider = true + hybridProviderStatus = nil + let policy = currentHybridProviderPolicyFromUI() + Task { + do { + _ = try await APIClient.shared.updateSelectedBackendProviderPolicy(policy) + backendSettings = try await APIClient.shared.getSelectedBackendSettings() + await MainActor.run { + hybridProviderStatus = "Saved provider account and model slots to the local daemon." + isSavingHybridProvider = false + syncAllHybridProviderFieldsFromBackendSettings() + } + } catch { + await MainActor.run { + hybridProviderStatus = error.localizedDescription + isSavingHybridProvider = false + } + } + } + } + + private func testHybridProviderSlot(_ slot: String) { + guard !isTestingHybridProvider else { return } + isTestingHybridProvider = true + hybridProviderStatus = nil + let policy = currentHybridProviderPolicyFromUI() + Task { + do { + _ = try await APIClient.shared.updateSelectedBackendProviderPolicy(policy) + let result = try await APIClient.shared.testSelectedBackendProviderSlot(slot) + backendSettings = try await APIClient.shared.getSelectedBackendSettings() + await MainActor.run { + hybridProviderStatus = result.message + isTestingHybridProvider = false + syncAllHybridProviderFieldsFromBackendSettings() + } + } catch { + await MainActor.run { + hybridProviderStatus = error.localizedDescription + isTestingHybridProvider = false + } + } + } } private func saveHybridProvider(key: String, baseURL: String, model: String, apiKey: String) { diff --git a/desktop/Desktop/Tests/APIClientRoutingTests.swift b/desktop/Desktop/Tests/APIClientRoutingTests.swift index 59a65a27117..2e9c11ae670 100644 --- a/desktop/Desktop/Tests/APIClientRoutingTests.swift +++ b/desktop/Desktop/Tests/APIClientRoutingTests.swift @@ -1625,6 +1625,24 @@ final class APIClientRoutingTests: XCTestCase { XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) } + func 
testLocalModeTestProviderSlotRoutesToLocalDaemonWithoutAuth() async { + setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) + setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) + defer { + unsetenv("OMI_DESKTOP_BACKEND_MODE") + unsetenv("OMI_LOCAL_DAEMON_URL") + } + let client = await makeTestClient() + _ = try? await client.testSelectedBackendProviderSlot("post_transcript") + as LocalDaemonTestSlotResponse + assertRoutes( + URLCapture.capturedRequests, host: "127.0.0.1", port: 8765, + pathContains: "v1/provider-policy/test-slot/post_transcript", + method: "POST", + label: "local test provider slot") + XCTAssertNil(URLCapture.capturedRequests.first?.headers["Authorization"]) + } + func testLocalModeGetMessagesForSessionRoutesToLocalDaemon() async { setenv("OMI_DESKTOP_BACKEND_MODE", "local", 1) setenv("OMI_LOCAL_DAEMON_URL", "http://127.0.0.1:8765", 1) diff --git a/desktop/Desktop/Tests/HybridChatClientTests.swift b/desktop/Desktop/Tests/HybridChatClientTests.swift index 88acb33abfa..021cb4d700c 100644 --- a/desktop/Desktop/Tests/HybridChatClientTests.swift +++ b/desktop/Desktop/Tests/HybridChatClientTests.swift @@ -223,7 +223,10 @@ final class HybridChatClientTests: XCTestCase { id: accountID, kind: "openai_compatible", baseURL: baseURL, - apiKey: "test-key" + apiKey: "test-key", + displayName: nil, + capabilities: nil, + subscriptionIntegration: nil ) let resolved = HybridProviderPolicy.ResolvedSlot( slot: "chat", diff --git a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift index 79531c9afb7..5b36f420234 100644 --- a/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift +++ b/desktop/Desktop/Tests/HybridLLMProviderConfigTests.swift @@ -164,6 +164,55 @@ final class HybridLLMProviderConfigTests: XCTestCase { XCTAssertNil(HybridVisionProvider.providerConfig(settings: settings)) } + func testReadinessRowsExposeSlotsAndLocalWikiMemory() throws { + let settings = try makeSettings([ + ( + 
"provider_policy", + [ + "version": 1, + "provider_accounts": [ + [ + "id": "local-openai-compatible", + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "api_key": "test-key", + "display_name": "Local OpenAI-compatible", + ] + ], + "model_slots": [ + "chat": [ + "provider_account_id": "local-openai-compatible", + "model_id": "chat-local", + ], + "post_transcript": [ + "provider_account_id": NSNull(), + "model_id": "gpt-5.4-mini", + ], + "proactive": [ + "provider_account_id": NSNull(), + "model_id": "gpt-5.4-mini", + ], + "memory_search": [ + "provider_account_id": NSNull(), + "model_id": "local_wiki", + ], + ], + ] + ) + ]) + + let rows = HybridProviderReadiness.rows(from: settings) + + XCTAssertEqual(rows.map(\.id).contains("embedding_provider"), false) + XCTAssertEqual(rows.first(where: { $0.id == "chat" })?.detail, "Local OpenAI-compatible / chat-local") + XCTAssertTrue( + rows.first(where: { $0.id == "post_transcript" })?.detail.contains("gpt-5.4-mini") + == true) + XCTAssertTrue( + rows.first(where: { $0.id == "memory_search" })?.detail.contains("no embeddings required") + == true) + } + private func decodeSettings(_ jsonArray: String) throws -> [LocalDaemonSetting] { let decoder = JSONDecoder() decoder.dateDecodingStrategy = .iso8601 diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index cb4fd4da199..15d9cfd3226 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -99,6 +99,9 @@ Desktop clients should prefer these daemon APIs over manual JSON editing: - `GET /v1/provider-policy/resolve/{slot}` resolves one slot to its provider account, model, options, source (`provider_policy`, `legacy_setting`, or `default`), and a readable success/failure reason. +- `POST /v1/provider-policy/test-slot/{slot}` validates the actual resolved task + path. 
OpenAI-compatible slots run a minimal chat-completions JSON ping; `memory_search` + reports local wiki readiness without requiring embeddings. - `GET /v1/model-catalog` returns the local model catalog and availability. Callers should resolve `post_transcript`, `proactive`, and `chat` slots explicitly @@ -263,7 +266,10 @@ local guest session startup. Ask Omi resolves uses the returned provider account/model. Legacy `chat_provider` rows are only read by the daemon compatibility bridge when constructing the typed policy. -Configure or override in **Settings → Plan and Usage** (local mode) or via `PUT /v1/settings`. +Configure or override in **Settings → Plan and Usage** (local mode). The app writes +`/v1/provider-policy` so users can inspect one provider account and slot model +choices without hand-editing JSON. Advanced callers may still use `PUT +/v1/provider-policy` directly. ## ChatGPT / Codex subscription (desktop) @@ -271,14 +277,25 @@ When the user connects **ChatGPT plan** in Settings → Advanced: - A loopback proxy (`desktop/codex-proxy`, default `http://127.0.0.1:10531/v1`) uses `~/.codex/auth.json` from Codex CLI login. - Daemon `chat_provider` and `ai_provider` are set to that URL (not `embedding_provider`). -- Memory search uses **local wiki + FTS5** instead of vector embeddings unless `OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1`. +- Memory search readiness uses **local wiki + FTS5** for this profile, not vector + embeddings. - Deepgram / live transcription behavior is unchanged. Build proxy: `cd desktop/codex-proxy && cargo build --release` ## Test connection -`POST /v1/settings/test-provider` with body `{ "key": "ai_provider" }` runs a minimal -request against a legacy configured provider (chat completions ping for -`openai_compatible`). New UI should read and write policy through `/v1/provider-policy` -and use `/v1/provider-policy/resolve/{slot}` before making task-specific calls. 
+`POST /v1/provider-policy/test-slot/{slot}` validates the active slot path and is +the preferred readiness check for Settings UI. Example: + +```bash +curl -fsS -X POST http://127.0.0.1:8765/v1/provider-policy/test-slot/chat \ + -H 'content-type: application/json' \ + -d '{}' +``` + +`POST /v1/settings/test-provider` with body `{ "key": "ai_provider" }` remains as +a legacy helper for raw setting-key providers. New UI should read and write +policy through `/v1/provider-policy`, use `/v1/provider-policy/resolve/{slot}` +before task-specific calls, and use `/v1/provider-policy/test-slot/{slot}` for +actionable readiness. diff --git a/desktop/local-backend/docs/local-mvp-runbook.md b/desktop/local-backend/docs/local-mvp-runbook.md index 3fe2426c5cc..b8d0b681ff4 100644 --- a/desktop/local-backend/docs/local-mvp-runbook.md +++ b/desktop/local-backend/docs/local-mvp-runbook.md @@ -172,9 +172,12 @@ will not work when `OMI_PYTHON_API_URL` is set to an invalid host (intentional for hybrid testing). **Settings → Plan and Usage** shows a **Local** plan (not cloud Free/Neo tiers). -Use that section to configure hybrid providers (`ai_provider`, `chat_provider`, -`embedding_provider`). Keys are stored in the local daemon SQLite database on this Mac. -Cloud subscription, usage quotas, and the Advanced “BYOK free forever” flow are hidden +Use that section to configure a local provider account and task slots: Chat, +Post-transcript processing, Proactive assistants, optional Vision, STT/local +transcription, and Memory search: Local wiki. Keys are stored in the local daemon +SQLite database on this Mac. Memory search uses local wiki/FTS search and does +not require `embedding_provider` or vector embeddings for this profile. Cloud +subscription, usage quotas, and the Advanced “BYOK free forever” flow are hidden in local mode. The Conversations header shows a `Local` chip when the app is using local daemon @@ -283,17 +286,17 @@ title, overview, or transcript text. 
## Local Provider Configuration -See [hybrid-provider-settings.md](hybrid-provider-settings.md) for the full settings schema -(`ai_provider`, `stt_provider`, `chat_provider`, `embedding_provider`, `vision_provider`) and -`POST /v1/settings/test-provider`. +See [hybrid-provider-settings.md](hybrid-provider-settings.md) for the full +provider-policy schema, default slots, legacy setting-key bridge, and +`POST /v1/provider-policy/test-slot/{slot}`. Processing works without provider keys by using deterministic fallback. To force -that path, clear provider settings: +that path, clear the typed provider policy and legacy provider settings: ```bash curl -X PUT http://127.0.0.1:8765/v1/settings \ -H 'content-type: application/json' \ - -d '{"ai_provider": null, "provider": null}' + -d '{"provider_policy": null, "ai_provider": null, "provider": null, "chat_provider": null, "vision_provider": null}' ``` To test a direct OpenAI-compatible provider without editing source code, point @@ -302,22 +305,56 @@ MVP validation, prefer a loopback stub so the test cannot reach hosted Omi or OpenAI services by accident: ```bash -curl -X PUT http://127.0.0.1:8765/v1/settings \ +curl -X PUT http://127.0.0.1:8765/v1/provider-policy \ -H 'content-type: application/json' \ -d '{ - "ai_provider": { + "version": 1, + "provider_accounts": [{ + "id": "local-openai-compatible", "kind": "openai_compatible", "base_url": "http://127.0.0.1:43210/v1", - "model": "local-stub", - "api_key": "local-test-key" + "api_key": "local-test-key", + "display_name": "Local stub", + "capabilities": { + "chat_completions": true, + "json_mode": true, + "tool_calls": false, + "vision": false, + "speech_to_text": false + }, + "subscription_integration": null + }], + "model_slots": { + "chat": { + "provider_account_id": "local-openai-compatible", + "model_id": "local-stub", + "options": {"json_mode": false, "tool_support": false} + }, + "post_transcript": { + "provider_account_id": "local-openai-compatible", + 
"model_id": "gpt-5.4-mini", + "options": {"json_mode": true, "tool_support": false} + }, + "proactive": { + "provider_account_id": "local-openai-compatible", + "model_id": "gpt-5.4-mini", + "options": {"json_mode": true, "tool_support": false} + }, + "memory_search": { + "provider_account_id": null, + "model_id": "local_wiki", + "options": {} + } } }' ``` -Inspect the active settings: +Inspect and validate the active policy: ```bash -curl http://127.0.0.1:8765/v1/settings +curl http://127.0.0.1:8765/v1/provider-policy +curl http://127.0.0.1:8765/v1/provider-policy/resolve/post_transcript +curl -X POST http://127.0.0.1:8765/v1/provider-policy/test-slot/memory_search -d '{}' ``` Provider keys remain in the local daemon SQLite settings table and are sent diff --git a/desktop/local-backend/src/main.rs b/desktop/local-backend/src/main.rs index 8232a67d3a7..a5c4bb6f4f1 100644 --- a/desktop/local-backend/src/main.rs +++ b/desktop/local-backend/src/main.rs @@ -302,7 +302,7 @@ mod tests { assert_eq!(profile["profile"]["display_name"], "Local User"); let settings = request_json( - app, + app.clone(), Method::PUT, "/v1/settings", Some(json!({ @@ -393,7 +393,7 @@ mod tests { assert_eq!(resolved["resolution"]["ok"], true); let proactive = request_json( - app, + app.clone(), Method::GET, "/v1/provider-policy/resolve/proactive", None, @@ -405,6 +405,19 @@ mod tests { .unwrap() .contains("no provider account")); + let memory_test = request_json( + app.clone(), + Method::POST, + "/v1/provider-policy/test-slot/memory_search", + Some(json!({})), + ) + .await?; + assert_eq!(memory_test["ok"], true); + assert!(memory_test["message"] + .as_str() + .unwrap() + .contains("does not require embeddings")); + Ok(()) } diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index 2cf2d9299b5..b124164d308 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -973,6 +973,44 @@ pub async fn 
test_configured_provider(store: &Store, key: &str) -> Result Result { + let resolution = resolve_model_slot_result(store, slot)?; + let Some(resolved) = resolution.resolved.as_ref() else { + return Err(anyhow!(resolution.reason)); + }; + if slot == SLOT_MEMORY_SEARCH { + return Ok(format!( + "memory_search uses {}; local wiki search does not require embeddings", + resolved.model_id + )); + } + if slot == SLOT_STT && resolved.provider_account.is_none() { + return Ok("stt uses local transcription; no provider account test required".to_string()); + } + if !resolution.ok { + return Err(anyhow!(resolution.reason)); + } + let Some(account) = resolved.provider_account.as_ref() else { + return Err(anyhow!("model slot {slot} has no provider account to test")); + }; + if !is_openai_compatible_kind(&account.kind) { + return Err(anyhow!( + "test-slot supports openai_compatible provider accounts only" + )); + } + let provider = OpenAiCompatibleProvider::new(openai_config_from_resolved_slot(resolved)?); + let _ = provider + .complete_json(vec![ + ChatMessage::system("Reply with JSON only: {\"ok\":true}"), + ChatMessage::user("ping"), + ]) + .await?; + Ok(format!( + "{slot} resolved to {} on {} and responded successfully", + resolved.model_id, account.id + )) +} + fn load_openai_config_from_value(value: &Value) -> Result { let kind = value["kind"].as_str().unwrap_or_default(); if kind != "openai" && kind != "openai_compatible" { diff --git a/desktop/local-backend/src/routes.rs b/desktop/local-backend/src/routes.rs index 76f898fa603..36e0c9139a8 100644 --- a/desktop/local-backend/src/routes.rs +++ b/desktop/local-backend/src/routes.rs @@ -35,6 +35,10 @@ pub fn router() -> Router { "/v1/provider-policy/resolve/:slot", get(resolve_provider_slot), ) + .route( + "/v1/provider-policy/test-slot/:slot", + post(test_provider_slot), + ) .route( "/v1/conversations", get(list_conversations).post(create_conversation), @@ -910,6 +914,20 @@ async fn resolve_provider_slot( }))) } +async fn 
test_provider_slot( + State(state): State, + Path(slot): Path, +) -> ApiResult { + let message = providers::test_configured_slot(&state.store, &slot) + .await + .map_err(|error| ApiError::bad_request(error.to_string()))?; + Ok(Json(json!({ + "ok": true, + "slot": slot, + "message": message + }))) +} + async fn list_processing_jobs(State(state): State) -> ApiResult { let jobs = state .store From 37e59b2f65a21a3b345c56b40a773d707cc4a901 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 23:03:16 +0700 Subject: [PATCH 37/58] Keep hybrid local embeddings opt-in --- desktop/local-backend/docs/hybrid-provider-settings.md | 2 +- desktop/run.sh | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index 15d9cfd3226..fdb42fca1f1 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -243,7 +243,7 @@ Environment flags (desktop process, not daemon): - `OMI_HYBRID_OPTIONAL_CLOUD_STT=1` — exposes `optionalCloudSTT` capability - `OMI_HYBRID_OPTIONAL_CLOUD_CHAT=1` — exposes `optionalCloudChat` capability -Default hybrid optional tiers: both cloud toggles off. `run.sh` local mode defaults direct STT/embeddings/chat capability env flags on for GUI launches; hosted Listen and pi-mono remain disabled without explicit optional-cloud flags / cloud backends. +Default hybrid optional tiers: both cloud toggles off. `run.sh` local mode defaults direct STT/chat capability env flags on for GUI launches and keeps direct embeddings off; this local profile uses local wiki/FTS memory search. Hosted Listen and pi-mono remain disabled without explicit optional-cloud flags / cloud backends. 
## Local dev defaults (seed) diff --git a/desktop/run.sh b/desktop/run.sh index ce5561c4e9a..783fe19f3b5 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -26,7 +26,7 @@ Options (via environment variables): OMI_LOCAL_DAEMON_URL="..." Local daemon URL (default: http://127.0.0.1:8765) OMI_HYBRID_DIRECT_STT_ENABLED Hybrid Apple Speech live transcription in local daemon (default 1 in configure_local_daemon_mode when unset) OMI_HYBRID_DIRECT_CHAT_ENABLED Hybrid OpenAI-compatible chat + daemon-backed sessions/messages (default 1 in configure_local_daemon_mode when unset) - OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED Hybrid direct embeddings for Rewind/proactive features (default 1 in local bundle; requires embedding_provider) + OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED Optional hybrid direct embeddings for vector search (default 0 in local bundle; local wiki search does not require embeddings) Required files for cloud backend mode: Backend-Rust/.env Environment variables (copy from ../.env.example) @@ -199,9 +199,9 @@ PY if [ -z "${OMI_HYBRID_DIRECT_CHAT_ENABLED+x}" ]; then export OMI_HYBRID_DIRECT_CHAT_ENABLED=1 fi - # Optional direct embeddings for Rewind OCR vectors etc.: requires embedding_provider in daemon settings. + # Optional direct embeddings for vector search. Default off: this local profile uses local wiki/FTS memory search. if [ -z "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED+x}" ]; then - export OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=1 + export OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=0 fi } @@ -663,12 +663,12 @@ if is_local_daemon_mode; then # GUI launches via `open` do not inherit shell exports — AppState.loadEnvironment() reads bundled .env. 
set_bundle_env "OMI_HYBRID_DIRECT_STT_ENABLED" "${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" set_bundle_env "OMI_HYBRID_DIRECT_CHAT_ENABLED" "${OMI_HYBRID_DIRECT_CHAT_ENABLED:-1}" - set_bundle_env "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED" "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-1}" + set_bundle_env "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED" "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-0}" substep "OMI_DESKTOP_BACKEND_MODE=local" substep "OMI_LOCAL_DAEMON_URL=$OMI_LOCAL_DAEMON_URL" substep "OMI_HYBRID_DIRECT_STT_ENABLED=${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" substep "OMI_HYBRID_DIRECT_CHAT_ENABLED=${OMI_HYBRID_DIRECT_CHAT_ENABLED:-1}" - substep "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-1}" + substep "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED=${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-0}" fi # Bootstrap FIREBASE_API_KEY — check env var first (yolo mode), then backend .env if ! grep -q "^FIREBASE_API_KEY=" "$APP_BUNDLE/Contents/Resources/.env"; then From a82c52960e629bc7901faf76a19bee113141ecba Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 12:08:01 -0400 Subject: [PATCH 38/58] Ignore local-asr-helper Rust build artifacts. 
Co-authored-by: Cursor --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 087560242aa..40b5e503cbf 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ web/app/public/firebase-messaging-sw.js !app/ios/Podfile.lock !mcp/uv.lock desktop/codex-proxy/target/ +desktop/local-asr-helper/target/ *.lock !desktop/codex-proxy/Cargo.lock *.log From 818e2dd5209e5bcb12369e163f10da4d17034a79 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 10:39:06 +0700 Subject: [PATCH 39/58] Add desktop transcription provider architecture (cherry picked from commit aad2f58d96a44f54b8841ac73c2059cfccd8907a) --- desktop/Desktop/Sources/AppState.swift | 78 +-- .../Services/AssistantSettings.swift | 569 +++++++++--------- .../Sources/TranscriptionProvider.swift | 428 +++++++++++++ .../Sources/TranscriptionService.swift | 2 +- .../TranscriptionProviderPolicyTests.swift | 153 +++++ 5 files changed, 906 insertions(+), 324 deletions(-) create mode 100644 desktop/Desktop/Sources/TranscriptionProvider.swift create mode 100644 desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index c9e73445189..d9a68d0505e 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -217,6 +217,7 @@ class AppState: ObservableObject { private let maxInMemorySegments = 200 private var totalSegmentCount = 0 // Total segments created this session (including trimmed) private var totalWordCount = 0 // Running word count for analytics + private var speakerSegmentReducer = SpeakerSegmentReducer(maxInMemorySegments: 200) // Conversation tracking for auto-save private var recordingStartTime: Date? 
@@ -1376,6 +1377,7 @@ class AppState: ObservableObject { speakerSegments = [] totalSegmentCount = 0 totalWordCount = 0 + speakerSegmentReducer.reset() liveSpeakerPersonMap = [:] LiveTranscriptMonitor.shared.clear() recordingStartTime = Date() @@ -1791,6 +1793,7 @@ class AppState: ObservableObject { speakerSegments = [] totalSegmentCount = 0 totalWordCount = 0 + speakerSegmentReducer.reset() liveSpeakerPersonMap = [:] LiveTranscriptMonitor.shared.clear() LiveNotesMonitor.shared.endSession() @@ -1965,6 +1968,7 @@ class AppState: ObservableObject { AnalyticsManager.shared.transcriptionStopped(wordCount: totalWordCount) totalSegmentCount = 0 totalWordCount = 0 + speakerSegmentReducer.reset() currentTranscript = "" log("Transcription: Stopped") @@ -2426,57 +2430,29 @@ class AppState: ObservableObject { /// Handle incoming transcript segments from Python backend `/v4/listen`. /// Backend sends pre-merged segments with speaker attribution — no client-side word merging needed. private func handleBackendSegments(_ segments: [TranscriptionService.BackendSegment]) { - for segment in segments { - guard !segment.text.isEmpty else { continue } - - // Extract speaker_id from backend (e.g. "SPEAKER_00" → 0) - let speakerId = segment.speaker_id ?? 0 - - // Convert backend segment to local SpeakerSegment - let translations = (segment.translations ?? []).map { - SegmentTranslation(lang: $0.lang, text: $0.text) - } - let newSeg = SpeakerSegment( + let incomingSegments = segments.compactMap { segment -> SpeakerSegment? in + guard !segment.text.isEmpty else { return nil } + return SpeakerSegment( segmentId: segment.id, - speaker: speakerId, + speaker: segment.speaker_id ?? 
0, text: segment.text, start: segment.start, end: segment.end, isUser: segment.is_user, personId: segment.person_id, - translations: translations - ) - - // Upsert: if we already have a segment with this ID, update it; otherwise append - if let segId = segment.id, - let existingIdx = speakerSegments.firstIndex(where: { $0.segmentId == segId }) - { - // Adjust word count: subtract old words, add new words - let oldWords = speakerSegments[existingIdx].text.split(separator: " ").count - totalWordCount += newSeg.text.split(separator: " ").count - oldWords - // Preserve existing translations if the backend didn't send new ones - var updatedSeg = newSeg - if translations.isEmpty && !speakerSegments[existingIdx].translations.isEmpty { - updatedSeg.translations = speakerSegments[existingIdx].translations + translations: (segment.translations ?? []).map { + SegmentTranslation(lang: $0.lang, text: $0.text) } - speakerSegments[existingIdx] = updatedSeg - log( - "Transcript [UPDATE] Speaker \(speakerId) [\(String(format: "%.1f", segment.start))s-\(String(format: "%.1f", segment.end))s]: \(segment.text.prefix(80))" - ) - } else { - totalWordCount += newSeg.text.split(separator: " ").count - speakerSegments.append(newSeg) - totalSegmentCount += 1 - log( - "Transcript [ADD] Speaker \(speakerId) [\(String(format: "%.1f", segment.start))s-\(String(format: "%.1f", segment.end))s]: \(segment.text.prefix(80))" - ) - } + ) } - // Sliding window: trim old segments from memory (they're already persisted in SQLite) - if speakerSegments.count > maxInMemorySegments { - let excess = speakerSegments.count - maxInMemorySegments - speakerSegments.removeFirst(excess) + let applyResult = speakerSegmentReducer.apply(incomingSegments) + speakerSegments = speakerSegmentReducer.segments + totalSegmentCount = speakerSegmentReducer.totalSegmentCount + totalWordCount = speakerSegmentReducer.totalWordCount + + if applyResult.added > 0 || applyResult.updated > 0 { + log("Transcript [UPSERT] Added: 
\(applyResult.added), updated: \(applyResult.updated)") } log( @@ -2594,19 +2570,10 @@ class AppState: ObservableObject { case "segments_deleted": if let segmentIds = event.raw["segment_ids"] as? [String] { log("Transcription: Backend deleted \(segmentIds.count) segments") - // Decrement counters for deleted segments - let deletedSegments = speakerSegments.filter { seg in - guard let segId = seg.segmentId else { return false } - return segmentIds.contains(segId) - } - let deletedWords = deletedSegments.reduce(0) { $0 + $1.text.split(separator: " ").count } - totalWordCount = max(0, totalWordCount - deletedWords) - totalSegmentCount = max(0, totalSegmentCount - deletedSegments.count) - - speakerSegments.removeAll { seg in - guard let segId = seg.segmentId else { return false } - return segmentIds.contains(segId) - } + _ = speakerSegmentReducer.deleteSegmentIds(segmentIds) + speakerSegments = speakerSegmentReducer.segments + totalSegmentCount = speakerSegmentReducer.totalSegmentCount + totalWordCount = speakerSegmentReducer.totalWordCount LiveTranscriptMonitor.shared.updateSegments(speakerSegments) // Also remove from DB @@ -2657,6 +2624,7 @@ class AppState: ObservableObject { // Update in-memory if the segment is still loaded if let idx = speakerSegments.firstIndex(where: { $0.segmentId == segId }) { speakerSegments[idx].translations = newTranslations + speakerSegmentReducer.replaceSegments(speakerSegments) } // Always persist to SQLite — even if the segment was trimmed from diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift index 38d8bf229fc..1a100218788 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/AssistantSettings.swift @@ -3,275 +3,144 @@ import Foundation /// Manages shared settings for all Proactive Assistants stored in UserDefaults 
@MainActor class AssistantSettings { - static let shared = AssistantSettings() - - // MARK: - UserDefaults Keys - - private let cooldownIntervalKey = "assistantsCooldownInterval" - private let glowOverlayEnabledKey = "assistantsGlowOverlayEnabled" - private let analysisDelayKey = "assistantsAnalysisDelay" - private let screenAnalysisEnabledKey = "screenAnalysisEnabled" - private let transcriptionEnabledKey = "transcriptionEnabled" - private let transcriptionLanguageKey = "transcriptionLanguage" - private let transcriptionAutoDetectKey = "transcriptionAutoDetect" - private let transcriptionVocabularyKey = "transcriptionVocabulary" - private let vadGateEnabledKey = "vadGateEnabled" - private let batchTranscriptionEnabledKey = "batchTranscriptionEnabled" - - // MARK: - Default Values - - private let defaultCooldownInterval = 10 // minutes - private let defaultGlowOverlayEnabled = false - private let defaultAnalysisDelay = 60 // seconds (1 minute) - private let defaultScreenAnalysisEnabled = true - private let defaultTranscriptionEnabled = true - private let defaultTranscriptionLanguage = "en" - private let defaultTranscriptionAutoDetect = true - private let defaultTranscriptionVocabulary: [String] = [] - private let defaultVadGateEnabled = false - private let defaultBatchTranscriptionEnabled = false - - private init() { - // Register defaults - UserDefaults.standard.register(defaults: [ - cooldownIntervalKey: defaultCooldownInterval, - glowOverlayEnabledKey: defaultGlowOverlayEnabled, - analysisDelayKey: defaultAnalysisDelay, - screenAnalysisEnabledKey: defaultScreenAnalysisEnabled, - transcriptionEnabledKey: defaultTranscriptionEnabled, - transcriptionLanguageKey: defaultTranscriptionLanguage, - transcriptionAutoDetectKey: defaultTranscriptionAutoDetect, - transcriptionVocabularyKey: defaultTranscriptionVocabulary, - vadGateEnabledKey: defaultVadGateEnabled, - batchTranscriptionEnabledKey: defaultBatchTranscriptionEnabled, - ]) + static let shared = 
AssistantSettings() + + // MARK: - UserDefaults Keys + + private let cooldownIntervalKey = "assistantsCooldownInterval" + private let glowOverlayEnabledKey = "assistantsGlowOverlayEnabled" + private let analysisDelayKey = "assistantsAnalysisDelay" + private let screenAnalysisEnabledKey = "screenAnalysisEnabled" + private let transcriptionEnabledKey = "transcriptionEnabled" + private let transcriptionLanguageKey = "transcriptionLanguage" + private let transcriptionAutoDetectKey = "transcriptionAutoDetect" + private let transcriptionVocabularyKey = "transcriptionVocabulary" + private let transcriptionProviderModeKey = "transcriptionProviderMode" + private let transcriptionQualityPresetKey = "transcriptionQualityPreset" + private let vadGateEnabledKey = "vadGateEnabled" + private let batchTranscriptionEnabledKey = "batchTranscriptionEnabled" + + // MARK: - Default Values + + private let defaultCooldownInterval = 10 // minutes + private let defaultGlowOverlayEnabled = false + private let defaultAnalysisDelay = 60 // seconds (1 minute) + private let defaultScreenAnalysisEnabled = true + private let defaultTranscriptionEnabled = true + private let defaultTranscriptionLanguage = "en" + private let defaultTranscriptionAutoDetect = true + private let defaultTranscriptionVocabulary: [String] = [] + private let defaultTranscriptionProviderMode = TranscriptionProviderKind.cloud.rawValue + private let defaultTranscriptionQualityPreset = TranscriptionQualityPreset.auto.rawValue + private let defaultVadGateEnabled = false + private let defaultBatchTranscriptionEnabled = false + + private init() { + // Register defaults + UserDefaults.standard.register(defaults: [ + cooldownIntervalKey: defaultCooldownInterval, + glowOverlayEnabledKey: defaultGlowOverlayEnabled, + analysisDelayKey: defaultAnalysisDelay, + screenAnalysisEnabledKey: defaultScreenAnalysisEnabled, + transcriptionEnabledKey: defaultTranscriptionEnabled, + transcriptionLanguageKey: defaultTranscriptionLanguage, + 
transcriptionAutoDetectKey: defaultTranscriptionAutoDetect, + transcriptionVocabularyKey: defaultTranscriptionVocabulary, + transcriptionProviderModeKey: defaultTranscriptionProviderMode, + transcriptionQualityPresetKey: defaultTranscriptionQualityPreset, + vadGateEnabledKey: defaultVadGateEnabled, + batchTranscriptionEnabledKey: defaultBatchTranscriptionEnabled, + ]) + } + + // MARK: - Properties + + /// Cooldown interval between notifications in minutes + var cooldownInterval: Int { + get { + let value = UserDefaults.standard.integer(forKey: cooldownIntervalKey) + return value > 0 ? value : defaultCooldownInterval } - - // MARK: - Properties - - /// Cooldown interval between notifications in minutes - var cooldownInterval: Int { - get { - let value = UserDefaults.standard.integer(forKey: cooldownIntervalKey) - return value > 0 ? value : defaultCooldownInterval - } - set { - UserDefaults.standard.set(newValue, forKey: cooldownIntervalKey) - NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) - } - } - - /// Cooldown interval in seconds (for NotificationService) - var cooldownIntervalSeconds: TimeInterval { - return TimeInterval(cooldownInterval * 60) - } - - /// Whether the glow overlay effect is enabled - var glowOverlayEnabled: Bool { - get { UserDefaults.standard.bool(forKey: glowOverlayEnabledKey) } - set { - UserDefaults.standard.set(newValue, forKey: glowOverlayEnabledKey) - NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) - } - } - - /// Delay in seconds before analyzing after an app switch (0 = instant, 60 = 1 min, 300 = 5 min) - var analysisDelay: Int { - get { - let value = UserDefaults.standard.integer(forKey: analysisDelayKey) - return value >= 0 ? 
value : defaultAnalysisDelay - } - set { - UserDefaults.standard.set(newValue, forKey: analysisDelayKey) - NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) - } + set { + UserDefaults.standard.set(newValue, forKey: cooldownIntervalKey) + NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) } - - /// Whether screen analysis (proactive monitoring) should be enabled - var screenAnalysisEnabled: Bool { - get { UserDefaults.standard.bool(forKey: screenAnalysisEnabledKey) } - set { - UserDefaults.standard.set(newValue, forKey: screenAnalysisEnabledKey) - NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) - } - } - - /// Whether transcription should be enabled - var transcriptionEnabled: Bool { - get { UserDefaults.standard.bool(forKey: transcriptionEnabledKey) } - set { - UserDefaults.standard.set(newValue, forKey: transcriptionEnabledKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) - } - } - - /// The language code for transcription (e.g., "en", "uk", "ru") - var transcriptionLanguage: String { - get { - let value = UserDefaults.standard.string(forKey: transcriptionLanguageKey) - return value ?? 
defaultTranscriptionLanguage - } - set { - UserDefaults.standard.set(newValue, forKey: transcriptionLanguageKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) - } - } - - /// Whether auto-detect (multi-language) mode is enabled - /// When true, DeepGram will auto-detect the language - /// When false, uses the specific language set in transcriptionLanguage - var transcriptionAutoDetect: Bool { - get { UserDefaults.standard.bool(forKey: transcriptionAutoDetectKey) } - set { - UserDefaults.standard.set(newValue, forKey: transcriptionAutoDetectKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) - } + } + + /// Cooldown interval in seconds (for NotificationService) + var cooldownIntervalSeconds: TimeInterval { + return TimeInterval(cooldownInterval * 60) + } + + /// Whether the glow overlay effect is enabled + var glowOverlayEnabled: Bool { + get { UserDefaults.standard.bool(forKey: glowOverlayEnabledKey) } + set { + UserDefaults.standard.set(newValue, forKey: glowOverlayEnabledKey) + NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) } + } - /// Returns the effective language to send to DeepGram - /// If auto-detect is enabled and the language supports multi-language mode, returns "multi" - /// Otherwise returns the specific language code - var effectiveTranscriptionLanguage: String { - if transcriptionAutoDetect { - // Languages that support multi-language detection in DeepGram Nova-3 - let multiLanguageSupported: Set = [ - "en", "en-US", "en-AU", "en-GB", "en-IN", "en-NZ", - "es", "es-419", - "fr", "fr-CA", - "de", - "hi", - "ru", - "pt", "pt-BR", "pt-PT", - "ja", - "it", - "nl" - ] - - // If the selected language supports multi-language mode, use "multi" - // Otherwise fall back to single language (e.g., Ukrainian doesn't support multi) - if multiLanguageSupported.contains(transcriptionLanguage) { - return "multi" - } - } - return transcriptionLanguage + /// 
Delay in seconds before analyzing after an app switch (0 = instant, 60 = 1 min, 300 = 5 min) + var analysisDelay: Int { + get { + let value = UserDefaults.standard.integer(forKey: analysisDelayKey) + return value >= 0 ? value : defaultAnalysisDelay } - - /// Custom vocabulary for improved transcription accuracy - /// Array of words/terms that DeepGram should recognize (Nova-3 limit: 500 tokens total) - var transcriptionVocabulary: [String] { - get { - let value = UserDefaults.standard.stringArray(forKey: transcriptionVocabularyKey) - return value ?? defaultTranscriptionVocabulary - } - set { - UserDefaults.standard.set(newValue, forKey: transcriptionVocabularyKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) - } + set { + UserDefaults.standard.set(newValue, forKey: analysisDelayKey) + NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) } - - /// Returns vocabulary as comma-separated string for display - var transcriptionVocabularyString: String { - get { - return transcriptionVocabulary.joined(separator: ", ") - } - set { - let terms = newValue - .split(separator: ",") - .map { $0.trimmingCharacters(in: .whitespaces) } - .filter { !$0.isEmpty } - transcriptionVocabulary = terms - } + } + + /// Whether screen analysis (proactive monitoring) should be enabled + var screenAnalysisEnabled: Bool { + get { UserDefaults.standard.bool(forKey: screenAnalysisEnabledKey) } + set { + UserDefaults.standard.set(newValue, forKey: screenAnalysisEnabledKey) + NotificationCenter.default.post(name: .assistantSettingsDidChange, object: nil) } - - /// Whether batch transcription mode is enabled (transcribes audio in chunks at silence boundaries) - var batchTranscriptionEnabled: Bool { - get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } - set { - UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, 
object: nil) - } + } + + /// Whether transcription should be enabled + var transcriptionEnabled: Bool { + get { UserDefaults.standard.bool(forKey: transcriptionEnabledKey) } + set { + UserDefaults.standard.set(newValue, forKey: transcriptionEnabledKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } + } - /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage - var vadGateEnabled: Bool { - get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } - set { - UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) - NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) - } + /// The language code for transcription (e.g., "en", "uk", "ru") + var transcriptionLanguage: String { + get { + let value = UserDefaults.standard.string(forKey: transcriptionLanguageKey) + return value ?? defaultTranscriptionLanguage } - - /// Returns vocabulary with "Omi" always included (for DeepGram) - var effectiveVocabulary: [String] { - var vocab = Set(transcriptionVocabulary) - vocab.insert("Omi") - return Array(vocab) + set { + UserDefaults.standard.set(newValue, forKey: transcriptionLanguageKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } - - /// Reset all settings to defaults - func resetToDefaults() { - cooldownInterval = defaultCooldownInterval - glowOverlayEnabled = defaultGlowOverlayEnabled - analysisDelay = defaultAnalysisDelay - screenAnalysisEnabled = defaultScreenAnalysisEnabled - transcriptionEnabled = defaultTranscriptionEnabled - transcriptionLanguage = defaultTranscriptionLanguage - transcriptionAutoDetect = defaultTranscriptionAutoDetect - transcriptionVocabulary = defaultTranscriptionVocabulary - vadGateEnabled = defaultVadGateEnabled - batchTranscriptionEnabled = defaultBatchTranscriptionEnabled + } + + /// Whether auto-detect (multi-language) mode is enabled + /// When true, DeepGram will auto-detect the language + /// When 
false, uses the specific language set in transcriptionLanguage + var transcriptionAutoDetect: Bool { + get { UserDefaults.standard.bool(forKey: transcriptionAutoDetectKey) } + set { + UserDefaults.standard.set(newValue, forKey: transcriptionAutoDetectKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } - - // MARK: - Supported Languages - - /// All languages supported by DeepGram Nova-3 for single-language transcription - static let supportedLanguages: [(code: String, name: String)] = [ - ("en", "English"), - ("en-US", "English (US)"), - ("en-GB", "English (UK)"), - ("en-AU", "English (Australia)"), - ("en-IN", "English (India)"), - ("en-NZ", "English (New Zealand)"), - ("bg", "Bulgarian"), - ("ca", "Catalan"), - ("cs", "Czech"), - ("da", "Danish"), - ("nl", "Dutch"), - ("nl-BE", "Dutch (Belgium)"), - ("et", "Estonian"), - ("fi", "Finnish"), - ("fr", "French"), - ("fr-CA", "French (Canada)"), - ("de", "German"), - ("de-CH", "German (Switzerland)"), - ("el", "Greek"), - ("hi", "Hindi"), - ("hu", "Hungarian"), - ("id", "Indonesian"), - ("it", "Italian"), - ("ja", "Japanese"), - ("ko", "Korean"), - ("lv", "Latvian"), - ("lt", "Lithuanian"), - ("ms", "Malay"), - ("no", "Norwegian"), - ("pl", "Polish"), - ("pt", "Portuguese"), - ("pt-BR", "Portuguese (Brazil)"), - ("pt-PT", "Portuguese (Portugal)"), - ("ro", "Romanian"), - ("ru", "Russian"), - ("sk", "Slovak"), - ("es", "Spanish"), - ("es-419", "Spanish (Latin America)"), - ("sv", "Swedish"), - ("tr", "Turkish"), - ("uk", "Ukrainian"), - ("vi", "Vietnamese"), - ] - - /// Languages that support multi-language (auto-detect) mode in DeepGram Nova-3 - static let multiLanguageSupported: Set = [ + } + + /// Returns the effective language to send to DeepGram + /// If auto-detect is enabled and the language supports multi-language mode, returns "multi" + /// Otherwise returns the specific language code + var effectiveTranscriptionLanguage: String { + if transcriptionAutoDetect { + // 
Languages that support multi-language detection in DeepGram Nova-3 + let multiLanguageSupported: Set = [ "en", "en-US", "en-AU", "en-GB", "en-IN", "en-NZ", "es", "es-419", "fr", "fr-CA", @@ -281,22 +150,186 @@ class AssistantSettings { "pt", "pt-BR", "pt-PT", "ja", "it", - "nl" - ] + "nl", + ] + + // If the selected language supports multi-language mode, use "multi" + // Otherwise fall back to single language (e.g., Ukrainian doesn't support multi) + if multiLanguageSupported.contains(transcriptionLanguage) { + return "multi" + } + } + return transcriptionLanguage + } + + /// Custom vocabulary for improved transcription accuracy + /// Array of words/terms that DeepGram should recognize (Nova-3 limit: 500 tokens total) + var transcriptionVocabulary: [String] { + get { + let value = UserDefaults.standard.stringArray(forKey: transcriptionVocabularyKey) + return value ?? defaultTranscriptionVocabulary + } + set { + UserDefaults.standard.set(newValue, forKey: transcriptionVocabularyKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) + } + } - /// Check if a language supports auto-detect mode - static func supportsAutoDetect(_ languageCode: String) -> Bool { - return multiLanguageSupported.contains(languageCode) + /// Returns vocabulary as comma-separated string for display + var transcriptionVocabularyString: String { + get { + return transcriptionVocabulary.joined(separator: ", ") + } + set { + let terms = + newValue + .split(separator: ",") + .map { $0.trimmingCharacters(in: .whitespaces) } + .filter { !$0.isEmpty } + transcriptionVocabulary = terms + } + } + + /// Provider selection policy for transcription. Defaults to cloud so existing + /// desktop behavior remains unchanged until local-first onboarding is wired. 
+ var transcriptionProviderSelection: TranscriptionProviderSelection { + get { + let modeValue = UserDefaults.standard.string(forKey: transcriptionProviderModeKey) + let qualityValue = UserDefaults.standard.string(forKey: transcriptionQualityPresetKey) + return TranscriptionProviderSelection( + mode: TranscriptionProviderKind(rawValue: modeValue ?? defaultTranscriptionProviderMode) + ?? .cloud, + quality: TranscriptionQualityPreset( + rawValue: qualityValue ?? defaultTranscriptionQualityPreset) ?? .auto + ) + } + set { + UserDefaults.standard.set(newValue.mode.rawValue, forKey: transcriptionProviderModeKey) + UserDefaults.standard.set(newValue.quality.rawValue, forKey: transcriptionQualityPresetKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) + } + } + + /// Whether batch transcription mode is enabled (transcribes audio in chunks at silence boundaries) + var batchTranscriptionEnabled: Bool { + get { UserDefaults.standard.bool(forKey: batchTranscriptionEnabledKey) } + set { + UserDefaults.standard.set(newValue, forKey: batchTranscriptionEnabledKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) + } + } + + /// Whether local VAD gate is enabled to skip silence and reduce Deepgram usage + var vadGateEnabled: Bool { + get { UserDefaults.standard.bool(forKey: vadGateEnabledKey) } + set { + UserDefaults.standard.set(newValue, forKey: vadGateEnabledKey) + NotificationCenter.default.post(name: .transcriptionSettingsDidChange, object: nil) } + } + + /// Returns vocabulary with "Omi" always included (for DeepGram) + var effectiveVocabulary: [String] { + var vocab = Set(transcriptionVocabulary) + vocab.insert("Omi") + return Array(vocab) + } + + /// Reset all settings to defaults + func resetToDefaults() { + cooldownInterval = defaultCooldownInterval + glowOverlayEnabled = defaultGlowOverlayEnabled + analysisDelay = defaultAnalysisDelay + screenAnalysisEnabled = defaultScreenAnalysisEnabled + 
transcriptionEnabled = defaultTranscriptionEnabled + transcriptionLanguage = defaultTranscriptionLanguage + transcriptionAutoDetect = defaultTranscriptionAutoDetect + transcriptionVocabulary = defaultTranscriptionVocabulary + transcriptionProviderSelection = TranscriptionProviderSelection( + mode: TranscriptionProviderKind(rawValue: defaultTranscriptionProviderMode) ?? .cloud, + quality: TranscriptionQualityPreset(rawValue: defaultTranscriptionQualityPreset) ?? .auto + ) + vadGateEnabled = defaultVadGateEnabled + batchTranscriptionEnabled = defaultBatchTranscriptionEnabled + } + + // MARK: - Supported Languages + + /// All languages supported by DeepGram Nova-3 for single-language transcription + static let supportedLanguages: [(code: String, name: String)] = [ + ("en", "English"), + ("en-US", "English (US)"), + ("en-GB", "English (UK)"), + ("en-AU", "English (Australia)"), + ("en-IN", "English (India)"), + ("en-NZ", "English (New Zealand)"), + ("bg", "Bulgarian"), + ("ca", "Catalan"), + ("cs", "Czech"), + ("da", "Danish"), + ("nl", "Dutch"), + ("nl-BE", "Dutch (Belgium)"), + ("et", "Estonian"), + ("fi", "Finnish"), + ("fr", "French"), + ("fr-CA", "French (Canada)"), + ("de", "German"), + ("de-CH", "German (Switzerland)"), + ("el", "Greek"), + ("hi", "Hindi"), + ("hu", "Hungarian"), + ("id", "Indonesian"), + ("it", "Italian"), + ("ja", "Japanese"), + ("ko", "Korean"), + ("lv", "Latvian"), + ("lt", "Lithuanian"), + ("ms", "Malay"), + ("no", "Norwegian"), + ("pl", "Polish"), + ("pt", "Portuguese"), + ("pt-BR", "Portuguese (Brazil)"), + ("pt-PT", "Portuguese (Portugal)"), + ("ro", "Romanian"), + ("ru", "Russian"), + ("sk", "Slovak"), + ("es", "Spanish"), + ("es-419", "Spanish (Latin America)"), + ("sv", "Swedish"), + ("tr", "Turkish"), + ("uk", "Ukrainian"), + ("vi", "Vietnamese"), + ] + + /// Languages that support multi-language (auto-detect) mode in DeepGram Nova-3 + static let multiLanguageSupported: Set = [ + "en", "en-US", "en-AU", "en-GB", "en-IN", "en-NZ", + 
"es", "es-419", + "fr", "fr-CA", + "de", + "hi", + "ru", + "pt", "pt-BR", "pt-PT", + "ja", + "it", + "nl", + ] + + /// Check if a language supports auto-detect mode + static func supportsAutoDetect(_ languageCode: String) -> Bool { + return multiLanguageSupported.contains(languageCode) + } } // MARK: - Notification Names extension Notification.Name { - static let assistantSettingsDidChange = Notification.Name("assistantSettingsDidChange") - static let assistantMonitoringStateDidChange = Notification.Name("assistantMonitoringStateDidChange") - static let assistantMonitoringToggleRequested = Notification.Name("assistantMonitoringToggleRequested") - static let transcriptionSettingsDidChange = Notification.Name("transcriptionSettingsDidChange") + static let assistantSettingsDidChange = Notification.Name("assistantSettingsDidChange") + static let assistantMonitoringStateDidChange = Notification.Name( + "assistantMonitoringStateDidChange") + static let assistantMonitoringToggleRequested = Notification.Name( + "assistantMonitoringToggleRequested") + static let transcriptionSettingsDidChange = Notification.Name("transcriptionSettingsDidChange") } // MARK: - Backward Compatibility @@ -305,6 +338,6 @@ extension Notification.Name { typealias FocusSettings = AssistantSettings extension Notification.Name { - static let focusSettingsDidChange = Notification.Name.assistantSettingsDidChange - static let focusMonitoringStateDidChange = Notification.Name.assistantMonitoringStateDidChange + static let focusSettingsDidChange = Notification.Name.assistantSettingsDidChange + static let focusMonitoringStateDidChange = Notification.Name.assistantMonitoringStateDidChange } diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift new file mode 100644 index 00000000000..55567049f16 --- /dev/null +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -0,0 +1,428 @@ +import Darwin +import Foundation + +enum 
TranscriptionProviderKind: String, CaseIterable, Codable, Equatable { + case auto + case local + case cloud +} + +enum TranscriptionQualityPreset: String, CaseIterable, Codable, Equatable { + case auto + case fast + case balanced + case accurate +} + +enum LocalTranscriptionEngine: String, CaseIterable, Codable, Equatable, Hashable { + case mlxWhisper = "mlx-whisper" + case fasterWhisper = "faster-whisper" +} + +struct TranscriptionProviderSelection: Codable, Equatable { + var mode: TranscriptionProviderKind + var quality: TranscriptionQualityPreset + + static let `default` = TranscriptionProviderSelection(mode: .auto, quality: .auto) +} + +struct LocalTranscriptionCapabilities: Equatable { + enum Processor: Equatable { + case nativeAppleSilicon + case rosettaOnAppleSilicon + case intel + case unknown + } + + var processor: Processor + var physicalMemoryBytes: UInt64 + var availableEngines: Set + + var isNativeAppleSilicon: Bool { + processor == .nativeAppleSilicon + } + + var canUseMLXWhisper: Bool { + isNativeAppleSilicon && availableEngines.contains(.mlxWhisper) + } + + var canUseFasterWhisper: Bool { + availableEngines.contains(.fasterWhisper) + } + + var canUseAnyLocalEngine: Bool { + canUseMLXWhisper || canUseFasterWhisper + } +} + +struct LocalTranscriptionCapabilityDetector { + var physicalMemoryBytes: () -> UInt64 = { ProcessInfo.processInfo.physicalMemory } + var isTranslatedProcess: () -> Bool = { + var translated: Int32 = 0 + var size = MemoryLayout.size + let result = sysctlbyname("sysctl.proc_translated", &translated, &size, nil, 0) + return result == 0 && translated == 1 + } + var availableEngines: () -> Set = { [] } + + func detect() -> LocalTranscriptionCapabilities { + LocalTranscriptionCapabilities( + processor: detectProcessor(), + physicalMemoryBytes: physicalMemoryBytes(), + availableEngines: availableEngines() + ) + } + + private func detectProcessor() -> LocalTranscriptionCapabilities.Processor { + #if arch(arm64) + return 
isTranslatedProcess() ? .rosettaOnAppleSilicon : .nativeAppleSilicon + #elseif arch(x86_64) + return isTranslatedProcess() ? .rosettaOnAppleSilicon : .intel + #else + return .unknown + #endif + } +} + +struct TranscriptionProviderPolicyResult: Equatable { + var provider: TranscriptionProviderKind + var quality: TranscriptionQualityPreset + var localEngine: LocalTranscriptionEngine? + var fallbackReason: String? + + var usesCloud: Bool { + provider == .cloud + } + + var usesLocal: Bool { + provider == .local + } +} + +struct TranscriptionProviderPolicy { + func resolve( + selection: TranscriptionProviderSelection, + capabilities: LocalTranscriptionCapabilities + ) -> TranscriptionProviderPolicyResult { + let quality = selection.quality + + switch selection.mode { + case .cloud: + return TranscriptionProviderPolicyResult( + provider: .cloud, quality: quality, localEngine: nil, fallbackReason: nil) + case .local: + if let engine = preferredLocalEngine(for: quality, capabilities: capabilities) { + return TranscriptionProviderPolicyResult( + provider: .local, quality: quality, localEngine: engine, fallbackReason: nil) + } + return TranscriptionProviderPolicyResult( + provider: .cloud, + quality: quality, + localEngine: nil, + fallbackReason: "No local transcription engine is available" + ) + case .auto: + if let engine = preferredLocalEngine(for: quality, capabilities: capabilities) { + return TranscriptionProviderPolicyResult( + provider: .local, quality: quality, localEngine: engine, fallbackReason: nil) + } + return TranscriptionProviderPolicyResult( + provider: .cloud, + quality: quality, + localEngine: nil, + fallbackReason: "Auto mode fell back to cloud because no local engine is available" + ) + } + } + + private func preferredLocalEngine( + for _: TranscriptionQualityPreset, + capabilities: LocalTranscriptionCapabilities + ) -> LocalTranscriptionEngine? 
{ + if capabilities.canUseMLXWhisper { + return .mlxWhisper + } + if capabilities.canUseFasterWhisper { + return .fasterWhisper + } + return nil + } +} + +struct NormalizedTranscriptTranslation: Codable, Equatable { + var lang: String + var text: String +} + +struct NormalizedTranscriptSegment: Codable, Equatable, Identifiable { + var id: String { segmentId ?? "\(speaker)-\(start)" } + var segmentId: String? + var speaker: Int + var speakerLabel: String? + var text: String + var start: Double + var end: Double + var isUser: Bool + var personId: String? + var translations: [NormalizedTranscriptTranslation] +} + +enum TranscriptionProviderConnectionState: Equatable { + case idle + case starting + case connected + case stopping + case stopped + case failed(String) +} + +struct TranscriptionProviderCapabilities: Equatable { + var provider: TranscriptionProviderKind + var supportsStreaming: Bool + var supportsBatch: Bool + var supportsLocalProcessing: Bool + var supportsSpeakerDiarization: Bool + var supportsTranslations: Bool + var localEngine: LocalTranscriptionEngine? 
+} + +struct TranscriptionProviderEvent { + var type: String + var raw: [String: Any] +} + +struct TranscriptionProviderCallbacks { + var onSegments: ([NormalizedTranscriptSegment]) -> Void + var onEvent: (TranscriptionProviderEvent) -> Void + var onError: (Error) -> Void + var onConnected: () -> Void + var onDisconnected: () -> Void +} + +struct TranscriptionProviderConfiguration: Equatable { + var language: String + var mode: TranscriptionService.StreamingMode + var contextKeywords: [String] + + static func conversation(language: String, contextKeywords: [String] = []) + -> TranscriptionProviderConfiguration + { + TranscriptionProviderConfiguration( + language: language, mode: .conversation, contextKeywords: contextKeywords) + } +} + +protocol TranscriptionProvider: AnyObject { + var status: TranscriptionProviderConnectionState { get } + var capabilities: TranscriptionProviderCapabilities { get } + var failureState: Error? { get } + + func start( + configuration: TranscriptionProviderConfiguration, callbacks: TranscriptionProviderCallbacks) + func sendAudio(_ data: Data) + func finalize() + func stop() +} + +final class CloudTranscriptionProvider: TranscriptionProvider { + private var service: TranscriptionService? + private(set) var status: TranscriptionProviderConnectionState = .idle + private(set) var failureState: Error? 
+ + let capabilities = TranscriptionProviderCapabilities( + provider: .cloud, + supportsStreaming: true, + supportsBatch: true, + supportsLocalProcessing: false, + supportsSpeakerDiarization: true, + supportsTranslations: true, + localEngine: nil + ) + + func start( + configuration: TranscriptionProviderConfiguration, callbacks: TranscriptionProviderCallbacks + ) { + do { + status = .starting + let service = try TranscriptionService( + language: configuration.language, + mode: configuration.mode, + contextKeywords: configuration.contextKeywords + ) + self.service = service + service.start( + onSegments: { callbacks.onSegments($0.map { $0.normalized }) }, + onEvent: { + callbacks.onEvent(TranscriptionProviderEvent(type: $0.type, raw: $0.raw)) + }, + onError: { [weak self] error in + self?.failureState = error + self?.status = .failed(error.localizedDescription) + callbacks.onError(error) + }, + onConnected: { [weak self] in + self?.status = .connected + callbacks.onConnected() + }, + onDisconnected: { [weak self] in + self?.status = .stopped + callbacks.onDisconnected() + } + ) + } catch { + failureState = error + status = .failed(error.localizedDescription) + callbacks.onError(error) + } + } + + func sendAudio(_ data: Data) { + service?.sendAudio(data) + } + + func finalize() { + service?.finishStream() + } + + func stop() { + status = .stopping + service?.stop() + service = nil + status = .stopped + } +} + +final class LocalWhisperTranscriptionProvider: TranscriptionProvider { + private(set) var status: TranscriptionProviderConnectionState = .idle + private(set) var failureState: Error? + let capabilities: TranscriptionProviderCapabilities + + init(engine: LocalTranscriptionEngine?) 
{ + self.capabilities = TranscriptionProviderCapabilities( + provider: .local, + supportsStreaming: false, + supportsBatch: true, + supportsLocalProcessing: true, + supportsSpeakerDiarization: false, + supportsTranslations: false, + localEngine: engine + ) + } + + func start( + configuration: TranscriptionProviderConfiguration, callbacks: TranscriptionProviderCallbacks + ) { + let error = TranscriptionService.TranscriptionError.webSocketError( + "Local Whisper provider helper is not implemented in this ticket" + ) + failureState = error + status = .failed(error.localizedDescription) + callbacks.onError(error) + } + + func sendAudio(_ data: Data) {} + + func finalize() {} + + func stop() { + status = .stopped + } +} + +struct SpeakerSegmentReducer { + struct ApplyResult: Equatable { + var added: Int = 0 + var updated: Int = 0 + var totalSegmentCount: Int = 0 + var totalWordCount: Int = 0 + } + + private(set) var segments: [SpeakerSegment] = [] + private(set) var totalSegmentCount: Int = 0 + private(set) var totalWordCount: Int = 0 + var maxInMemorySegments: Int + + init(maxInMemorySegments: Int) { + self.maxInMemorySegments = maxInMemorySegments + } + + mutating func reset() { + segments = [] + totalSegmentCount = 0 + totalWordCount = 0 + } + + mutating func replaceSegments(_ replacement: [SpeakerSegment]) { + segments = replacement + totalWordCount = replacement.reduce(0) { $0 + wordCount($1.text) } + } + + mutating func apply(_ incomingSegments: [SpeakerSegment]) -> ApplyResult { + var result = ApplyResult() + + for incoming in incomingSegments where !incoming.text.isEmpty { + if let segId = incoming.segmentId, + let existingIdx = segments.firstIndex(where: { $0.segmentId == segId }) + { + let oldWords = wordCount(segments[existingIdx].text) + var updated = incoming + if updated.translations.isEmpty && !segments[existingIdx].translations.isEmpty { + updated.translations = segments[existingIdx].translations + } + segments[existingIdx] = updated + totalWordCount += 
wordCount(updated.text) - oldWords + result.updated += 1 + } else { + segments.append(incoming) + totalSegmentCount += 1 + totalWordCount += wordCount(incoming.text) + result.added += 1 + } + } + + if segments.count > maxInMemorySegments { + segments.removeFirst(segments.count - maxInMemorySegments) + } + + result.totalSegmentCount = totalSegmentCount + result.totalWordCount = totalWordCount + return result + } + + mutating func deleteSegmentIds(_ segmentIds: [String]) -> Int { + let deletedSegments = segments.filter { segment in + guard let segmentId = segment.segmentId else { return false } + return segmentIds.contains(segmentId) + } + let deletedWords = deletedSegments.reduce(0) { $0 + wordCount($1.text) } + totalWordCount = max(0, totalWordCount - deletedWords) + totalSegmentCount = max(0, totalSegmentCount - deletedSegments.count) + segments.removeAll { segment in + guard let segmentId = segment.segmentId else { return false } + return segmentIds.contains(segmentId) + } + return deletedSegments.count + } + + private func wordCount(_ text: String) -> Int { + text.split(separator: " ").count + } +} + +extension TranscriptionService.BackendSegment { + var normalized: NormalizedTranscriptSegment { + NormalizedTranscriptSegment( + segmentId: id, + speaker: speaker_id ?? 0, + speakerLabel: speaker, + text: text, + start: start, + end: end, + isUser: is_user, + personId: person_id, + translations: (translations ?? []).map { + NormalizedTranscriptTranslation(lang: $0.lang, text: $0.text) + } + ) + } +} diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index 6d9f0068f23..d040e7668fd 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -12,7 +12,7 @@ class TranscriptionService { // MARK: - Types /// Streaming mode determines which backend endpoint and parameters are used. 
- enum StreamingMode { + enum StreamingMode: Equatable { /// Conversation capture via `/v4/listen` — full pipeline with speech profiles, /// speaker assignment, memory creation events, and conversation lifecycle. case conversation diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift new file mode 100644 index 00000000000..3108c5f0e48 --- /dev/null +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -0,0 +1,153 @@ +import XCTest + +@testable import Omi_Computer + +final class TranscriptionProviderPolicyTests: XCTestCase { + func testAutoPrefersMLXOnlyOnNativeAppleSilicon() { + let capabilities = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper, .fasterWhisper] + ) + + let result = TranscriptionProviderPolicy().resolve( + selection: TranscriptionProviderSelection(mode: .auto, quality: .balanced), + capabilities: capabilities + ) + + XCTAssertEqual(result.provider, .local) + XCTAssertEqual(result.localEngine, .mlxWhisper) + XCTAssertNil(result.fallbackReason) + } + + func testAutoFallsBackToFasterWhisperWhenMLXIsNotViable() { + let capabilities = LocalTranscriptionCapabilities( + processor: .rosettaOnAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper, .fasterWhisper] + ) + + let result = TranscriptionProviderPolicy().resolve( + selection: TranscriptionProviderSelection(mode: .auto, quality: .fast), + capabilities: capabilities + ) + + XCTAssertEqual(result.provider, .local) + XCTAssertEqual(result.localEngine, .fasterWhisper) + } + + func testAutoFallsBackToCloudWhenNoLocalEngineExists() { + let capabilities = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 8 * 1024 * 1024 * 1024, + availableEngines: [] + ) + + let result = TranscriptionProviderPolicy().resolve( + selection: 
TranscriptionProviderSelection(mode: .auto, quality: .auto), + capabilities: capabilities + ) + + XCTAssertEqual(result.provider, .cloud) + XCTAssertNil(result.localEngine) + XCTAssertNotNil(result.fallbackReason) + } + + func testCloudSelectionAlwaysUsesCloud() { + let capabilities = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + + let result = TranscriptionProviderPolicy().resolve( + selection: TranscriptionProviderSelection(mode: .cloud, quality: .accurate), + capabilities: capabilities + ) + + XCTAssertEqual(result.provider, .cloud) + XCTAssertNil(result.localEngine) + XCTAssertNil(result.fallbackReason) + } + + func testCapabilityDetectorDistinguishesRosettaFromNativeArm() { + let native = LocalTranscriptionCapabilityDetector( + physicalMemoryBytes: { 1 }, + isTranslatedProcess: { false }, + availableEngines: { [] } + ).detect() + let translated = LocalTranscriptionCapabilityDetector( + physicalMemoryBytes: { 1 }, + isTranslatedProcess: { true }, + availableEngines: { [] } + ).detect() + + #if arch(arm64) + XCTAssertEqual(native.processor, .nativeAppleSilicon) + XCTAssertEqual(translated.processor, .rosettaOnAppleSilicon) + #elseif arch(x86_64) + XCTAssertEqual(native.processor, .intel) + XCTAssertEqual(translated.processor, .rosettaOnAppleSilicon) + #else + XCTAssertEqual(native.processor, .unknown) + XCTAssertEqual(translated.processor, .unknown) + #endif + } +} + +final class SpeakerSegmentReducerTests: XCTestCase { + func testReducerAddsUpdatesAndPreservesTranslations() { + var reducer = SpeakerSegmentReducer(maxInMemorySegments: 10) + let initial = SpeakerSegment( + segmentId: "s1", + speaker: 0, + text: "hello world", + start: 0, + end: 1, + isUser: true, + personId: "p1", + translations: [SegmentTranslation(lang: "es", text: "hola mundo")] + ) + + let first = reducer.apply([initial]) + XCTAssertEqual(first.added, 1) + 
XCTAssertEqual(reducer.totalSegmentCount, 1) + XCTAssertEqual(reducer.totalWordCount, 2) + + let updateWithoutTranslations = SpeakerSegment( + segmentId: "s1", + speaker: 0, + text: "hello again world", + start: 0, + end: 1.5, + isUser: true, + personId: "p1", + translations: [] + ) + + let second = reducer.apply([updateWithoutTranslations]) + XCTAssertEqual(second.updated, 1) + XCTAssertEqual(reducer.totalSegmentCount, 1) + XCTAssertEqual(reducer.totalWordCount, 3) + XCTAssertEqual(reducer.segments.first?.translations.first?.text, "hola mundo") + } + + func testReducerTrimsInMemorySegmentsButKeepsTotalCount() { + var reducer = SpeakerSegmentReducer(maxInMemorySegments: 2) + + for index in 0..<3 { + _ = reducer.apply([ + SpeakerSegment( + segmentId: "s\(index)", + speaker: 0, + text: "word", + start: Double(index), + end: Double(index + 1) + ) + ]) + } + + XCTAssertEqual(reducer.totalSegmentCount, 3) + XCTAssertEqual(reducer.segments.map(\.segmentId), ["s1", "s2"]) + } +} From 1b6ff137f6b3bbf6b14d84dfe1be2bd03099d7da Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 10:39:24 +0700 Subject: [PATCH 40/58] Ignore nested target directories (cherry picked from commit 65bfaa66a32a0647c725068002c718a3862396d1) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 40b5e503cbf..1a9dd09a4d2 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ build/ dist/ .build/ .swiftpm/ +**/target/ # VS Code .vscode/* From 22bb456b40b7a97c9470dce91c4ccc0d2de67005 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 10:48:29 +0700 Subject: [PATCH 41/58] Add local ASR helper scaffold (cherry picked from commit a18d2bb1a5a6d19ac1013fb156b17b7d29201125) --- .../PushToTalkManager.swift | 58 +++- .../LocalTranscription/LocalASRRuntime.swift | 275 ++++++++++++++++++ .../Sources/TranscriptionProvider.swift | 55 +++- .../TranscriptionProviderPolicyTests.swift | 137 +++++++++ desktop/local-asr-helper/Cargo.toml | 9 + 
desktop/local-asr-helper/README.md | 13 + desktop/local-asr-helper/src/main.rs | 121 ++++++++ 7 files changed, 646 insertions(+), 22 deletions(-) create mode 100644 desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift create mode 100644 desktop/local-asr-helper/Cargo.toml create mode 100644 desktop/local-asr-helper/README.md create mode 100644 desktop/local-asr-helper/src/main.rs diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift index 9a78eabc066..2d44c04b11c 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift @@ -316,7 +316,9 @@ class PushToTalkManager: ObservableObject { private var finalizedMode: String = "hold" private func finalize() { - guard state == .listening || state == .lockedListening || state == .pendingLockDecision else { return } + guard state == .listening || state == .lockedListening || state == .pendingLockDecision else { + return + } lastOptionUpTime = 0 finalizedMode = state == .lockedListening ? 
"locked" : "hold" @@ -361,7 +363,9 @@ class PushToTalkManager: ObservableObject { await self.contextCaptureTask?.value let language = AssistantSettings.shared.effectiveTranscriptionLanguage let audioSeconds = Double(audioData.count) / (16000.0 * 2.0) - log("PushToTalkManager: batch audio \(audioData.count) bytes (\(String(format: "%.1f", audioSeconds))s), pttLanguage=\(language), selectedLanguage=\(AssistantSettings.shared.transcriptionLanguage), autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect)") + log( + "PushToTalkManager: batch audio \(audioData.count) bytes (\(String(format: "%.1f", audioSeconds))s), pttLanguage=\(language), selectedLanguage=\(AssistantSettings.shared.transcriptionLanguage), autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect)" + ) var transcript = try await TranscriptionService.batchTranscribe( audioData: audioData, @@ -369,8 +373,12 @@ class PushToTalkManager: ObservableObject { contextKeywords: self.currentContextSnapshot?.keywords ?? [] ) - if (transcript == nil || transcript?.isEmpty == true) && language != "en" && language != "multi" && audioSeconds < 5.0 { - log("PushToTalkManager: selected language returned empty on short audio, retrying with 'en'") + if (transcript == nil || transcript?.isEmpty == true) && language != "en" + && language != "multi" && audioSeconds < 5.0 + { + log( + "PushToTalkManager: selected language returned empty on short audio, retrying with 'en'" + ) transcript = try await TranscriptionService.batchTranscribe( audioData: audioData, language: "en", @@ -385,7 +393,9 @@ class PushToTalkManager: ObservableObject { } } catch { logError("PushToTalkManager: batch transcription failed", error: error) - let message = (error as? TranscriptionService.TranscriptionError)?.errorDescription ?? "Transcription failed" + let message = + (error as? TranscriptionService.TranscriptionError)?.errorDescription + ?? "Transcription failed" barState?.voiceTranscript = "⚠️ \(message)" try? 
await Task.sleep(nanoseconds: 3_000_000_000) barState?.voiceTranscript = "" @@ -449,7 +459,11 @@ class PushToTalkManager: ObservableObject { } Task { [weak self, query, contextKeywords, wasFollowUp] in - let cleanedQuery = await PTTTranscriptCleanupService.shared.cleanup(query, keywords: contextKeywords) + let cleanedQuery = await PTTTranscriptPostProcessor.process( + query, + keywords: contextKeywords, + provider: .cloud + ) await MainActor.run { self?.sendQuery(cleanedQuery, wasFollowUp: wasFollowUp) } @@ -473,10 +487,13 @@ class PushToTalkManager: ObservableObject { startAudioTranscription() let captureStartedAt = Date() contextCaptureTask = Task { [weak self] in - let snapshot = await PTTContextVocabularyProvider.capture(at: captureStartedAt, preOverlayImage: preOverlayImage) + let snapshot = await PTTContextVocabularyProvider.capture( + at: captureStartedAt, preOverlayImage: preOverlayImage) await MainActor.run { guard let self, !Task.isCancelled else { return } - guard self.state == .listening || self.state == .lockedListening || self.state == .finalizing else { return } + guard + self.state == .listening || self.state == .lockedListening || self.state == .finalizing + else { return } self.currentContextSnapshot = snapshot } } @@ -604,7 +621,8 @@ class PushToTalkManager: ObservableObject { log("PushToTalkManager: silent-mic detected but no built-in mic to fall back to") return } - log("PushToTalkManager: silent-mic fallback — switching to built-in mic (deviceID=\(builtInID))") + log( + "PushToTalkManager: silent-mic fallback — switching to built-in mic (deviceID=\(builtInID))") audioCaptureService?.stopCapture() audioCaptureService = nil startMicCapture(batchMode: batchMode, overrideDeviceID: builtInID) @@ -653,7 +671,9 @@ class PushToTalkManager: ObservableObject { // Skip resize when in follow-up mode, expanded AI conversation, or during onboarding // (during onboarding the floating bar shouldn't appear as a separate window) let isOnboarding = 
!UserDefaults.standard.bool(forKey: "hasCompletedOnboarding") - guard !skipResize && !barState.isVoiceFollowUp && !barState.showingAIConversation && !isOnboarding else { return } + guard + !skipResize && !barState.isVoiceFollowUp && !barState.showingAIConversation && !isOnboarding + else { return } if barState.isVoiceListening && !wasListening { FloatingControlBarManager.shared.resizeForPTT(expanded: true) } else if !barState.isVoiceListening && wasListening { @@ -661,3 +681,21 @@ class PushToTalkManager: ObservableObject { } } } + +enum PTTTranscriptPostProcessor { + typealias Cleanup = (String, [String]) async -> String + + static func process( + _ transcript: String, + keywords: [String], + provider: TranscriptionProviderKind, + cleanup: @escaping Cleanup = { transcript, keywords in + await PTTTranscriptCleanupService.shared.cleanup(transcript, keywords: keywords) + } + ) async -> String { + guard provider != .local else { + return transcript + } + return await cleanup(transcript, keywords) + } +} diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift new file mode 100644 index 00000000000..4f82ead05b3 --- /dev/null +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -0,0 +1,275 @@ +import Foundation + +enum LocalTranscriptionModel: String, CaseIterable, Codable, Equatable, Hashable { + case tiny + case base + case small + case medium + case largeV3Turbo = "large_v3_turbo" +} + +struct LocalTranscriptionPlan: Equatable { + var engine: LocalTranscriptionEngine + var model: LocalTranscriptionModel + var quality: TranscriptionQualityPreset +} + +struct LocalASRTranscriptionRequest: Codable, Equatable { + var requestId: String + var audioPath: String + var language: String + var sampleRate: Int + var channels: Int + var engine: LocalTranscriptionEngine + var model: LocalTranscriptionModel + var fixtureSegments: [LocalASRTranscriptSegment]? 
+ + enum CodingKeys: String, CodingKey { + case requestId = "request_id" + case audioPath = "audio_path" + case language + case sampleRate = "sample_rate" + case channels + case engine + case model + case fixtureSegments = "fixture_segments" + } +} + +struct LocalASRTranscriptionResponse: Codable, Equatable { + var requestId: String + var engine: LocalTranscriptionEngine + var model: LocalTranscriptionModel + var language: String + var segments: [LocalASRTranscriptSegment] + var fixture: Bool + + enum CodingKeys: String, CodingKey { + case requestId = "request_id" + case engine + case model + case language + case segments + case fixture + } +} + +struct LocalASRTranscriptSegment: Codable, Equatable { + var id: String? + var speaker: Int? + var text: String + var start: Double + var end: Double + + func normalized(defaultSpeaker: Int = 0) -> NormalizedTranscriptSegment { + NormalizedTranscriptSegment( + segmentId: id, + speaker: speaker ?? defaultSpeaker, + speakerLabel: nil, + text: text, + start: start, + end: end, + isUser: true, + personId: nil, + translations: [] + ) + } +} + +struct LocalASRHelperClient { + var executableURL: URL + var timeoutSeconds: TimeInterval = 60 + + func transcribe(_ request: LocalASRTranscriptionRequest) async throws + -> LocalASRTranscriptionResponse + { + let process = Process() + process.executableURL = executableURL + + let input = Pipe() + let output = Pipe() + let errors = Pipe() + process.standardInput = input + process.standardOutput = output + process.standardError = errors + + try process.run() + let requestData = try JSONEncoder.localASR.encode(request) + input.fileHandleForWriting.write(requestData) + try? 
input.fileHandleForWriting.close() + + return try await withTimeout(seconds: timeoutSeconds) { + process.waitUntilExit() + let outputData = output.fileHandleForReading.readDataToEndOfFile() + if process.terminationStatus != 0 { + let errorText = + String(data: errors.fileHandleForReading.readDataToEndOfFile(), encoding: .utf8) ?? "" + throw TranscriptionService.TranscriptionError.webSocketError( + "Local ASR helper exited with status \(process.terminationStatus): \(errorText)" + ) + } + return try JSONDecoder.localASR.decode(LocalASRTranscriptionResponse.self, from: outputData) + } + } + + private func withTimeout( + seconds: TimeInterval, + operation: @escaping @Sendable () throws -> T + ) async throws -> T { + try await withThrowingTaskGroup(of: T.self) { group in + group.addTask { + try operation() + } + group.addTask { + try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) + throw CancellationError() + } + guard let result = try await group.next() else { + throw CancellationError() + } + group.cancelAll() + return result + } + } +} + +struct LocalTranscriptMerger { + private(set) var segments: [NormalizedTranscriptSegment] = [] + private let duplicateOverlapThreshold: Double + + init(duplicateOverlapThreshold: Double = 0.8) { + self.duplicateOverlapThreshold = duplicateOverlapThreshold + } + + mutating func merge(_ incomingSegments: [NormalizedTranscriptSegment]) + -> [NormalizedTranscriptSegment] + { + for incoming in incomingSegments + where !incoming.text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty { + upsert(incoming) + } + segments.sort { lhs, rhs in + if lhs.start == rhs.start { + return lhs.end < rhs.end + } + return lhs.start < rhs.start + } + return segments + } + + private mutating func upsert(_ incoming: NormalizedTranscriptSegment) { + if let segmentId = incoming.segmentId, + let index = segments.firstIndex(where: { $0.segmentId == segmentId }) + { + segments[index] = preferredSegment(existing: segments[index], incoming: 
incoming) + return + } + + if let index = segments.firstIndex(where: { isDuplicate($0, incoming) }) { + segments[index] = preferredSegment(existing: segments[index], incoming: incoming) + return + } + + if let index = segments.firstIndex(where: { canMergeOverlap($0, incoming) }) { + segments[index] = mergedOverlap(segments[index], incoming) + return + } + + segments.append(incoming) + } + + private func isDuplicate( + _ existing: NormalizedTranscriptSegment, _ incoming: NormalizedTranscriptSegment + ) -> Bool { + guard normalizedText(existing.text) == normalizedText(incoming.text) else { return false } + let intersection = max(0, min(existing.end, incoming.end) - max(existing.start, incoming.start)) + let shorterDuration = max( + 0.001, min(existing.end - existing.start, incoming.end - incoming.start)) + return intersection / shorterDuration >= duplicateOverlapThreshold + } + + private func canMergeOverlap( + _ existing: NormalizedTranscriptSegment, + _ incoming: NormalizedTranscriptSegment + ) -> Bool { + guard existing.speaker == incoming.speaker else { return false } + guard min(existing.end, incoming.end) > max(existing.start, incoming.start) else { + return false + } + return edgeTokenOverlap(existing.text, incoming.text) > 0 + } + + private func mergedOverlap( + _ existing: NormalizedTranscriptSegment, + _ incoming: NormalizedTranscriptSegment + ) -> NormalizedTranscriptSegment { + let existingFirst = existing.start <= incoming.start + let first = existingFirst ? existing : incoming + let second = existingFirst ? incoming : existing + let overlap = edgeTokenOverlap(first.text, second.text) + let suffix = tokenized(second.text).dropFirst(overlap).joined(separator: " ") + + var merged = first + merged.segmentId = first.segmentId ?? second.segmentId + merged.start = min(first.start, second.start) + merged.end = max(first.end, second.end) + merged.text = suffix.isEmpty ? 
first.text : "\(first.text) \(suffix)" + return merged + } + + private func edgeTokenOverlap(_ first: String, _ second: String) -> Int { + let left = tokenized(first) + let right = tokenized(second) + guard !left.isEmpty, !right.isEmpty else { return 0 } + + let maxOverlap = min(left.count, right.count) + for count in stride(from: maxOverlap, through: 1, by: -1) { + if Array(left.suffix(count)) == Array(right.prefix(count)) { + return count + } + } + return 0 + } + + private func preferredSegment( + existing: NormalizedTranscriptSegment, + incoming: NormalizedTranscriptSegment + ) -> NormalizedTranscriptSegment { + if normalizedText(existing.text) == normalizedText(incoming.text), + existing.text.count <= incoming.text.count + { + return existing + } + if incoming.end - incoming.start > existing.end - existing.start { + return incoming + } + if incoming.text.count > existing.text.count { + return incoming + } + return existing + } + + private func normalizedText(_ value: String) -> String { + value.lowercased() + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func tokenized(_ value: String) -> [String] { + normalizedText(value).split(separator: " ").map(String.init) + } +} + +extension JSONEncoder { + fileprivate static var localASR: JSONEncoder { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + return encoder + } +} + +extension JSONDecoder { + fileprivate static var localASR: JSONDecoder { + JSONDecoder() + } +} diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift index 55567049f16..f5882691ebb 100644 --- a/desktop/Desktop/Sources/TranscriptionProvider.swift +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -88,6 +88,7 @@ struct TranscriptionProviderPolicyResult: Equatable { var provider: TranscriptionProviderKind var quality: TranscriptionQualityPreset var 
localEngine: LocalTranscriptionEngine? + var localPlan: LocalTranscriptionPlan? var fallbackReason: String? var usesCloud: Bool { @@ -109,43 +110,73 @@ struct TranscriptionProviderPolicy { switch selection.mode { case .cloud: return TranscriptionProviderPolicyResult( - provider: .cloud, quality: quality, localEngine: nil, fallbackReason: nil) + provider: .cloud, quality: quality, localEngine: nil, localPlan: nil, fallbackReason: nil) case .local: - if let engine = preferredLocalEngine(for: quality, capabilities: capabilities) { + if let plan = localPlan(for: quality, capabilities: capabilities) { return TranscriptionProviderPolicyResult( - provider: .local, quality: quality, localEngine: engine, fallbackReason: nil) + provider: .local, quality: quality, localEngine: plan.engine, localPlan: plan, + fallbackReason: nil) } return TranscriptionProviderPolicyResult( provider: .cloud, quality: quality, localEngine: nil, + localPlan: nil, fallbackReason: "No local transcription engine is available" ) case .auto: - if let engine = preferredLocalEngine(for: quality, capabilities: capabilities) { + if let plan = localPlan(for: quality, capabilities: capabilities) { return TranscriptionProviderPolicyResult( - provider: .local, quality: quality, localEngine: engine, fallbackReason: nil) + provider: .local, quality: quality, localEngine: plan.engine, localPlan: plan, + fallbackReason: nil) } return TranscriptionProviderPolicyResult( provider: .cloud, quality: quality, localEngine: nil, + localPlan: nil, fallbackReason: "Auto mode fell back to cloud because no local engine is available" ) } } - private func preferredLocalEngine( - for _: TranscriptionQualityPreset, + private func localPlan( + for quality: TranscriptionQualityPreset, capabilities: LocalTranscriptionCapabilities - ) -> LocalTranscriptionEngine? { + ) -> LocalTranscriptionPlan? 
{ + let engine: LocalTranscriptionEngine if capabilities.canUseMLXWhisper { - return .mlxWhisper + engine = .mlxWhisper + } else if capabilities.canUseFasterWhisper { + engine = .fasterWhisper + } else { + return nil } - if capabilities.canUseFasterWhisper { - return .fasterWhisper + + return LocalTranscriptionPlan( + engine: engine, + model: model(for: quality, engine: engine, memoryBytes: capabilities.physicalMemoryBytes), + quality: quality + ) + } + + private func model( + for quality: TranscriptionQualityPreset, + engine: LocalTranscriptionEngine, + memoryBytes: UInt64 + ) -> LocalTranscriptionModel { + let gib = memoryBytes / (1024 * 1024 * 1024) + switch quality { + case .fast: + return .base + case .balanced, .auto: + return gib >= 8 ? .small : .base + case .accurate: + if engine == .mlxWhisper, gib >= 24 { + return .largeV3Turbo + } + return gib >= 16 ? .medium : .small } - return nil } } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 3108c5f0e48..3982a1cd172 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -17,6 +17,7 @@ final class TranscriptionProviderPolicyTests: XCTestCase { XCTAssertEqual(result.provider, .local) XCTAssertEqual(result.localEngine, .mlxWhisper) + XCTAssertEqual(result.localPlan?.model, .small) XCTAssertNil(result.fallbackReason) } @@ -34,6 +35,7 @@ final class TranscriptionProviderPolicyTests: XCTestCase { XCTAssertEqual(result.provider, .local) XCTAssertEqual(result.localEngine, .fasterWhisper) + XCTAssertEqual(result.localPlan?.model, .base) } func testAutoFallsBackToCloudWhenNoLocalEngineExists() { @@ -50,6 +52,7 @@ final class TranscriptionProviderPolicyTests: XCTestCase { XCTAssertEqual(result.provider, .cloud) XCTAssertNil(result.localEngine) + XCTAssertNil(result.localPlan) XCTAssertNotNil(result.fallbackReason) } @@ -67,9 +70,36 @@ final 
class TranscriptionProviderPolicyTests: XCTestCase { XCTAssertEqual(result.provider, .cloud) XCTAssertNil(result.localEngine) + XCTAssertNil(result.localPlan) XCTAssertNil(result.fallbackReason) } + func testAccurateUsesLargerModelsOnlyWhenMemoryAllows() { + let lowMemory = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 8 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + let highMemory = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 24 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + + let policy = TranscriptionProviderPolicy() + let lowMemoryResult = policy.resolve( + selection: TranscriptionProviderSelection(mode: .local, quality: .accurate), + capabilities: lowMemory + ) + let highMemoryResult = policy.resolve( + selection: TranscriptionProviderSelection(mode: .local, quality: .accurate), + capabilities: highMemory + ) + + XCTAssertEqual(lowMemoryResult.localPlan?.model, .small) + XCTAssertEqual(highMemoryResult.localPlan?.model, .largeV3Turbo) + } + func testCapabilityDetectorDistinguishesRosettaFromNativeArm() { let native = LocalTranscriptionCapabilityDetector( physicalMemoryBytes: { 1 }, @@ -151,3 +181,110 @@ final class SpeakerSegmentReducerTests: XCTestCase { XCTAssertEqual(reducer.segments.map(\.segmentId), ["s1", "s2"]) } } + +final class LocalASRRuntimeTests: XCTestCase { + func testHelperContractRoundTripsFixtureSegments() throws { + let request = LocalASRTranscriptionRequest( + requestId: "fixture-1", + audioPath: "/tmp/audio.pcm", + language: "en", + sampleRate: 16000, + channels: 1, + engine: .mlxWhisper, + model: .small, + fixtureSegments: [ + LocalASRTranscriptSegment(id: "seg-1", speaker: 0, text: "hello local", start: 0, end: 1) + ] + ) + + let encoded = try JSONEncoder().encode(request) + let decoded = try JSONDecoder().decode(LocalASRTranscriptionRequest.self, from: encoded) + + XCTAssertEqual(decoded, request) + } + + func 
testDeterministicMergeDeduplicatesOverlappingChunksAndIsIdempotent() { + var merger = LocalTranscriptMerger() + let first = LocalASRTranscriptSegment( + id: nil, + speaker: 0, + text: "hello local whisper", + start: 0.0, + end: 2.0 + ).normalized() + let duplicate = LocalASRTranscriptSegment( + id: nil, + speaker: 0, + text: "hello local whisper", + start: 0.1, + end: 2.1 + ).normalized() + let next = LocalASRTranscriptSegment( + id: "s2", + speaker: 0, + text: "next segment", + start: 2.2, + end: 3.0 + ).normalized() + + XCTAssertEqual( + merger.merge([first, next]).map(\.text), ["hello local whisper", "next segment"]) + XCTAssertEqual( + merger.merge([duplicate, next]).map(\.text), ["hello local whisper", "next segment"]) + XCTAssertEqual(merger.merge([first, next]).count, 2) + } + + func testDeterministicMergeCombinesPartialOverlapByTokenBoundary() { + var merger = LocalTranscriptMerger() + let first = LocalASRTranscriptSegment( + id: "chunk-1", + speaker: 0, + text: "hello local whisper", + start: 0.0, + end: 2.0 + ).normalized() + let second = LocalASRTranscriptSegment( + id: "chunk-2", + speaker: 0, + text: "whisper works offline", + start: 1.8, + end: 3.2 + ).normalized() + + let result = merger.merge([first, second]) + + XCTAssertEqual(result.count, 1) + XCTAssertEqual(result[0].text, "hello local whisper works offline") + XCTAssertEqual(result[0].start, 0.0) + XCTAssertEqual(result[0].end, 3.2) + } +} + +final class PTTTranscriptPostProcessorTests: XCTestCase { + func testLocalModeBypassesLLMCleanup() async { + var cleanupCalls = 0 + let result = await PTTTranscriptPostProcessor.process( + "raw local transcript", + keywords: ["Omi"], + provider: .local, + cleanup: { transcript, _ in + cleanupCalls += 1 + return "\(transcript) cleaned" + } + ) + + XCTAssertEqual(result, "raw local transcript") + XCTAssertEqual(cleanupCalls, 0) + } + + func testCloudModeUsesCleanup() async { + let result = await PTTTranscriptPostProcessor.process( + "raw cloud transcript", + 
keywords: [], + provider: .cloud, + cleanup: { transcript, _ in "\(transcript) cleaned" } + ) + + XCTAssertEqual(result, "raw cloud transcript cleaned") + } +} diff --git a/desktop/local-asr-helper/Cargo.toml b/desktop/local-asr-helper/Cargo.toml new file mode 100644 index 00000000000..19dc460df0a --- /dev/null +++ b/desktop/local-asr-helper/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "omi-local-asr-helper" +version = "0.1.0" +edition = "2021" + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + diff --git a/desktop/local-asr-helper/README.md b/desktop/local-asr-helper/README.md new file mode 100644 index 00000000000..66683270ac8 --- /dev/null +++ b/desktop/local-asr-helper/README.md @@ -0,0 +1,13 @@ +# Omi Local ASR Helper + +This is the control-plane scaffold for local Whisper transcription. The desktop +app sends one JSON request on stdin and reads one JSON response on stdout. The +current implementation is fixture-backed so CI can exercise the same contract +before MLX Whisper and faster-whisper adapters are installed. 
+ +Smoke command: + +```bash +printf '{"request_id":"fixture-1","audio_path":"/tmp/sample.pcm","language":"en","sample_rate":16000,"channels":1,"engine":"mlx-whisper","model":"small","fixture_segments":[{"id":"s1","speaker":0,"text":"hello local whisper","start":0.0,"end":1.2}]}' | cargo run --quiet --manifest-path desktop/local-asr-helper/Cargo.toml +``` + diff --git a/desktop/local-asr-helper/src/main.rs b/desktop/local-asr-helper/src/main.rs new file mode 100644 index 00000000000..66d201e5bc6 --- /dev/null +++ b/desktop/local-asr-helper/src/main.rs @@ -0,0 +1,121 @@ +use serde::{Deserialize, Serialize}; +use std::io::{self, Read}; +use std::process; + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +struct TranscriptionRequest { + request_id: String, + audio_path: String, + language: String, + sample_rate: u32, + channels: u8, + engine: LocalEngine, + model: LocalModel, + fixture_segments: Option>, +} + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "kebab-case")] +enum LocalEngine { + MlxWhisper, + FasterWhisper, +} + +#[derive(Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +enum LocalModel { + Tiny, + Base, + Small, + Medium, + LargeV3Turbo, +} + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] +struct TranscriptSegment { + id: Option, + speaker: Option, + text: String, + start: f64, + end: f64, +} + +#[derive(Debug, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +struct TranscriptionResponse { + request_id: String, + engine: LocalEngine, + model: LocalModel, + language: String, + segments: Vec, + fixture: bool, +} + +fn main() { + if let Err(error) = run() { + eprintln!("{error}"); + process::exit(1); + } +} + +fn run() -> Result<(), String> { + let request = read_request()?; + let fixture_segments = request.fixture_segments.clone().unwrap_or_else(|| { + vec![TranscriptSegment { + id: Some(format!("{}-fixture-0", request.request_id)), 
+ speaker: Some(0), + text: "fixture local transcription".to_string(), + start: 0.0, + end: 1.0, + }] + }); + + let response = TranscriptionResponse { + request_id: request.request_id, + engine: request.engine, + model: request.model, + language: request.language, + segments: fixture_segments, + fixture: true, + }; + + let json = serde_json::to_string(&response).map_err(|error| error.to_string())?; + println!("{json}"); + Ok(()) +} + +fn read_request() -> Result { + let mut input = String::new(); + io::stdin() + .read_to_string(&mut input) + .map_err(|error| format!("failed to read stdin: {error}"))?; + serde_json::from_str(&input).map_err(|error| format!("invalid request json: {error}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decodes_fixture_request_contract() { + let json = r#"{ + "request_id": "req-1", + "audio_path": "/tmp/audio.pcm", + "language": "en", + "sample_rate": 16000, + "channels": 1, + "engine": "mlx-whisper", + "model": "small", + "fixture_segments": [ + {"id": "seg-1", "speaker": 0, "text": "hello", "start": 0.0, "end": 1.0} + ] + }"#; + + let request: TranscriptionRequest = serde_json::from_str(json).unwrap(); + + assert_eq!(request.engine, LocalEngine::MlxWhisper); + assert_eq!(request.model, LocalModel::Small); + assert_eq!(request.fixture_segments.unwrap()[0].text, "hello"); + } +} From 0e5cc29d8b9a448236b307098883aa71bbf9159b Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 10:56:52 +0700 Subject: [PATCH 42/58] Wire local PTT transcription routing (cherry picked from commit c1e5d0d3d299e119cc68325e734833a13dc5cc27) --- desktop/Desktop/Sources/AppState.swift | 18 ++ .../PushToTalkManager.swift | 25 +- .../LocalTranscription/LocalASRRuntime.swift | 236 ++++++++++++++++++ .../TranscriptionProviderPolicyTests.swift | 211 ++++++++++++++++ 4 files changed, 485 insertions(+), 5 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 
d9a68d0505e..bb2152a2a87 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1276,6 +1276,24 @@ class AppState: ObservableObject { // Use provided source or fall back to current setting let effectiveSource = source ?? audioSource + let backgroundRouting = BackgroundTranscriptionRoutingGuard().decide( + selection: AssistantSettings.shared.transcriptionProviderSelection, + capabilities: LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + ) + if !backgroundRouting.useCloudBackend { + let message = + backgroundRouting.unsupportedLocalReason + ?? "Local background transcription is not available yet." + log("Transcription: \(message)") + showAlert( + title: "Local Background Transcription Unavailable", + message: + "Local transcription is currently available for Push-to-Talk batch mode. Choose cloud transcription to use background capture." + ) + return + } // For BLE device, check if device is connected if effectiveSource == .bleDevice { diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift index 2d44c04b11c..9118e2a030d 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift @@ -47,6 +47,7 @@ class PushToTalkManager: ObservableObject { private var isCurrentSessionFollowUp = false private var currentContextSnapshot: PTTContextSnapshot? private var contextCaptureTask: Task? 
+ private var currentTranscriptionProvider: TranscriptionProviderKind = .cloud // Batch mode: accumulate raw audio for post-recording transcription private var batchAudioBuffer = Data() @@ -220,6 +221,7 @@ class PushToTalkManager: ObservableObject { transcriptSegments = [] lastInterimText = "" currentContextSnapshot = nil + currentTranscriptionProvider = .cloud finalizeWorkItem?.cancel() finalizeWorkItem = nil @@ -262,6 +264,7 @@ class PushToTalkManager: ObservableObject { transcriptSegments = [] lastInterimText = "" currentContextSnapshot = nil + currentTranscriptionProvider = .cloud let preOverlayImage = ScreenCaptureManager.captureScreenImage() captureContextAndStartAudio(preOverlayImage: preOverlayImage) } @@ -302,6 +305,7 @@ class PushToTalkManager: ObservableObject { batchAudioLock.lock() batchAudioBuffer = Data() batchAudioLock.unlock() + currentTranscriptionProvider = .cloud isCurrentSessionFollowUp = false updateBarState() } @@ -357,6 +361,7 @@ class PushToTalkManager: ObservableObject { } barState?.voiceTranscript = "Transcribing..." + let providerSelection = AssistantSettings.shared.transcriptionProviderSelection Task { do { @@ -367,26 +372,34 @@ class PushToTalkManager: ObservableObject { "PushToTalkManager: batch audio \(audioData.count) bytes (\(String(format: "%.1f", audioSeconds))s), pttLanguage=\(language), selectedLanguage=\(AssistantSettings.shared.transcriptionLanguage), autoDetect=\(AssistantSettings.shared.transcriptionAutoDetect)" ) - var transcript = try await TranscriptionService.batchTranscribe( + let router = PTTBatchTranscriptionRouter(selection: { providerSelection }) + var result = try await router.transcribe( audioData: audioData, language: language, contextKeywords: self.currentContextSnapshot?.keywords ?? 
[] ) + self.currentTranscriptionProvider = result.provider + if let reason = result.fallbackReason { + log("PushToTalkManager: provider policy fallback: \(reason)") + } - if (transcript == nil || transcript?.isEmpty == true) && language != "en" + if result.provider == .cloud + && (result.transcript == nil || result.transcript?.isEmpty == true) + && language != "en" && language != "multi" && audioSeconds < 5.0 { log( "PushToTalkManager: selected language returned empty on short audio, retrying with 'en'" ) - transcript = try await TranscriptionService.batchTranscribe( + result = try await router.transcribe( audioData: audioData, language: "en", contextKeywords: self.currentContextSnapshot?.keywords ?? [] ) + self.currentTranscriptionProvider = result.provider } - if let transcript, !transcript.isEmpty { + if let transcript = result.transcript, !transcript.isEmpty { self.transcriptSegments = [transcript] } else { log("PushToTalkManager: transcription returned empty after retry") @@ -451,6 +464,8 @@ class PushToTalkManager: ObservableObject { transcriptSegments = [] lastInterimText = "" currentContextSnapshot = nil + let provider = currentTranscriptionProvider + currentTranscriptionProvider = .cloud updateBarState(skipResize: hasQuery || wasFollowUp) guard hasQuery else { @@ -462,7 +477,7 @@ class PushToTalkManager: ObservableObject { let cleanedQuery = await PTTTranscriptPostProcessor.process( query, keywords: contextKeywords, - provider: .cloud + provider: provider ) await MainActor.run { self?.sendQuery(cleanedQuery, wasFollowUp: wasFollowUp) diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index 4f82ead05b3..908ccb0c308 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -133,6 +133,242 @@ struct LocalASRHelperClient { } } +enum LocalASRHelperLocator { + static let 
environmentKey = "OMI_LOCAL_ASR_HELPER_PATH" + + static func defaultExecutableURL( + environment: [String: String] = ProcessInfo.processInfo.environment, + bundle: Bundle = .main, + fileManager: FileManager = .default + ) -> URL? { + if let override = environment[environmentKey], !override.isEmpty { + let url = URL(fileURLWithPath: override) + return fileManager.isExecutableFile(atPath: url.path) ? url : nil + } + + let bundleCandidates = [ + bundle.url(forResource: "local-asr-helper", withExtension: nil), + bundle.resourceURL?.appendingPathComponent("local-asr-helper"), + ].compactMap { $0 } + + if let bundled = bundleCandidates.first(where: { + fileManager.isExecutableFile(atPath: $0.path) + }) { + return bundled + } + + #if DEBUG + let currentDirectory = URL(fileURLWithPath: fileManager.currentDirectoryPath) + let debugCandidates = [ + currentDirectory.appendingPathComponent("local-asr-helper/target/debug/local-asr-helper"), + currentDirectory.appendingPathComponent( + "../local-asr-helper/target/debug/local-asr-helper"), + currentDirectory.appendingPathComponent( + "../../local-asr-helper/target/debug/local-asr-helper"), + currentDirectory.appendingPathComponent( + "desktop/local-asr-helper/target/debug/local-asr-helper"), + ] + return debugCandidates.first { fileManager.isExecutableFile(atPath: $0.path) } + #else + return nil + #endif + } + + static func detectedEngines(executableURL: URL? 
= defaultExecutableURL()) + -> Set + { + guard executableURL != nil else { return [] } + return Set(LocalTranscriptionEngine.allCases) + } +} + +struct LocalASRBatchTranscriber { + typealias RequestHandler = (LocalASRTranscriptionRequest) async throws + -> LocalASRTranscriptionResponse + + var requestHandler: RequestHandler + var temporaryDirectory: URL + var fileManager: FileManager + var makeRequestId: () -> String + + init( + executableURL: URL, + timeoutSeconds: TimeInterval = 60, + temporaryDirectory: URL = FileManager.default.temporaryDirectory, + fileManager: FileManager = .default, + makeRequestId: @escaping () -> String = { UUID().uuidString } + ) { + let client = LocalASRHelperClient(executableURL: executableURL, timeoutSeconds: timeoutSeconds) + self.init( + requestHandler: { request in + try await client.transcribe(request) + }, + temporaryDirectory: temporaryDirectory, + fileManager: fileManager, + makeRequestId: makeRequestId + ) + } + + init( + requestHandler: @escaping RequestHandler, + temporaryDirectory: URL = FileManager.default.temporaryDirectory, + fileManager: FileManager = .default, + makeRequestId: @escaping () -> String = { UUID().uuidString } + ) { + self.requestHandler = requestHandler + self.temporaryDirectory = temporaryDirectory + self.fileManager = fileManager + self.makeRequestId = makeRequestId + } + + func transcribe( + audioData: Data, + language: String, + plan: LocalTranscriptionPlan + ) async throws -> [NormalizedTranscriptSegment] { + let requestId = makeRequestId() + let audioURL = temporaryDirectory.appendingPathComponent("\(requestId).pcm") + try audioData.write(to: audioURL, options: .atomic) + defer { try? 
fileManager.removeItem(at: audioURL) } + + let response = try await requestHandler( + LocalASRTranscriptionRequest( + requestId: requestId, + audioPath: audioURL.path, + language: language, + sampleRate: 16000, + channels: 1, + engine: plan.engine, + model: plan.model, + fixtureSegments: nil + ) + ) + + var merger = LocalTranscriptMerger() + return merger.merge(response.segments.map { $0.normalized() }) + } +} + +struct PTTBatchTranscriptionResult: Equatable { + var provider: TranscriptionProviderKind + var transcript: String? + var fallbackReason: String? +} + +struct PTTBatchTranscriptionRouter { + typealias CloudTranscriber = (Data, String, [String]) async throws -> String? + typealias LocalTranscriber = (Data, String, LocalTranscriptionPlan) async throws + -> [NormalizedTranscriptSegment] + + var selection: () -> TranscriptionProviderSelection + var capabilities: () -> LocalTranscriptionCapabilities + var cloudTranscribe: CloudTranscriber + var localTranscribe: LocalTranscriber + var policy: TranscriptionProviderPolicy + + init( + selection: @escaping () -> TranscriptionProviderSelection = { .default }, + capabilities: @escaping () -> LocalTranscriptionCapabilities = { + LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + }, + cloudTranscribe: @escaping CloudTranscriber = { audioData, language, keywords in + try await TranscriptionService.batchTranscribe( + audioData: audioData, + language: language, + contextKeywords: keywords + ) + }, + localTranscribe: @escaping LocalTranscriber = { audioData, language, plan in + guard let executableURL = LocalASRHelperLocator.defaultExecutableURL() else { + throw TranscriptionService.TranscriptionError.webSocketError( + "Local ASR helper is not available" + ) + } + return try await LocalASRBatchTranscriber(executableURL: executableURL).transcribe( + audioData: audioData, + language: language, + plan: plan + ) + }, + policy: TranscriptionProviderPolicy = 
TranscriptionProviderPolicy() + ) { + self.selection = selection + self.capabilities = capabilities + self.cloudTranscribe = cloudTranscribe + self.localTranscribe = localTranscribe + self.policy = policy + } + + func transcribe(audioData: Data, language: String, contextKeywords: [String]) async throws + -> PTTBatchTranscriptionResult + { + let currentSelection = selection() + let resolved = policy.resolve(selection: currentSelection, capabilities: capabilities()) + + if resolved.provider == .local, let plan = resolved.localPlan { + let segments = try await localTranscribe(audioData, language, plan) + let transcript = segments.map(\.text).joined(separator: " ") + .trimmingCharacters(in: .whitespacesAndNewlines) + return PTTBatchTranscriptionResult( + provider: .local, + transcript: transcript.isEmpty ? nil : transcript, + fallbackReason: resolved.fallbackReason + ) + } + + if currentSelection.mode == .local { + throw TranscriptionService.TranscriptionError.webSocketError( + resolved.fallbackReason ?? "No local transcription engine is available" + ) + } + + let transcript = try await cloudTranscribe(audioData, language, contextKeywords) + return PTTBatchTranscriptionResult( + provider: .cloud, + transcript: transcript, + fallbackReason: resolved.fallbackReason + ) + } +} + +struct BackgroundTranscriptionRoutingDecision: Equatable { + var useCloudBackend: Bool + var unsupportedLocalReason: String? +} + +struct BackgroundTranscriptionRoutingGuard { + var policy: TranscriptionProviderPolicy = TranscriptionProviderPolicy() + + func decide( + selection: TranscriptionProviderSelection, + capabilities: LocalTranscriptionCapabilities + ) -> BackgroundTranscriptionRoutingDecision { + let resolved = policy.resolve(selection: selection, capabilities: capabilities) + if selection.mode == .local, resolved.provider != .local { + return BackgroundTranscriptionRoutingDecision( + useCloudBackend: false, + unsupportedLocalReason: resolved.fallbackReason + ?? 
"No local transcription engine is available" + ) + } + + guard resolved.provider == .local else { + return BackgroundTranscriptionRoutingDecision( + useCloudBackend: true, + unsupportedLocalReason: resolved.fallbackReason + ) + } + + return BackgroundTranscriptionRoutingDecision( + useCloudBackend: false, + unsupportedLocalReason: + "Local background transcription is not available until local finalization can persist conversations without backend force-processing." + ) + } +} + struct LocalTranscriptMerger { private(set) var segments: [NormalizedTranscriptSegment] = [] private let duplicateOverlapThreshold: Double diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 3982a1cd172..fa836ee8ca3 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -258,6 +258,217 @@ final class LocalASRRuntimeTests: XCTestCase { XCTAssertEqual(result[0].start, 0.0) XCTAssertEqual(result[0].end, 3.2) } + + func testLocalBatchTranscriberWritesPCMAndNormalizesMergedSegments() async throws { + var capturedRequest: LocalASRTranscriptionRequest? + let tempDirectory = FileManager.default.temporaryDirectory.appendingPathComponent( + "LocalASRRuntimeTests-\(UUID().uuidString)", + isDirectory: true + ) + try FileManager.default.createDirectory( + at: tempDirectory, + withIntermediateDirectories: true + ) + defer { try? 
FileManager.default.removeItem(at: tempDirectory) } + + let transcriber = LocalASRBatchTranscriber( + requestHandler: { request in + capturedRequest = request + XCTAssertTrue(FileManager.default.fileExists(atPath: request.audioPath)) + return LocalASRTranscriptionResponse( + requestId: request.requestId, + engine: request.engine, + model: request.model, + language: request.language, + segments: [ + LocalASRTranscriptSegment( + id: "a", + speaker: 0, + text: "hello local", + start: 0, + end: 1 + ), + LocalASRTranscriptSegment( + id: "b", + speaker: 0, + text: "local whisper", + start: 0.9, + end: 2 + ), + ], + fixture: true + ) + }, + temporaryDirectory: tempDirectory, + makeRequestId: { "req-1" } + ) + + let result = try await transcriber.transcribe( + audioData: Data([1, 2, 3]), + language: "en", + plan: LocalTranscriptionPlan(engine: .mlxWhisper, model: .small, quality: .balanced) + ) + + XCTAssertEqual( + capturedRequest?.audioPath, tempDirectory.appendingPathComponent("req-1.pcm").path) + XCTAssertFalse( + FileManager.default.fileExists(atPath: tempDirectory.appendingPathComponent("req-1.pcm").path) + ) + XCTAssertEqual(result.map(\.text), ["hello local whisper"]) + } +} + +final class PTTBatchTranscriptionRouterTests: XCTestCase { + func testLocalProviderUsesHelperPathAndDoesNotCallCloud() async throws { + var cloudCalls = 0 + var localPlan: LocalTranscriptionPlan? 
+ let router = PTTBatchTranscriptionRouter( + selection: { TranscriptionProviderSelection(mode: .local, quality: .balanced) }, + capabilities: { + LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + }, + cloudTranscribe: { _, _, _ in + cloudCalls += 1 + return "cloud transcript" + }, + localTranscribe: { _, _, plan in + localPlan = plan + return [ + LocalASRTranscriptSegment( + id: "local-1", + speaker: 0, + text: "local transcript", + start: 0, + end: 1 + ).normalized() + ] + } + ) + + let result = try await router.transcribe( + audioData: Data([1]), + language: "en", + contextKeywords: ["Omi"] + ) + + XCTAssertEqual(result.provider, .local) + XCTAssertEqual(result.transcript, "local transcript") + XCTAssertEqual(localPlan?.engine, .mlxWhisper) + XCTAssertEqual(cloudCalls, 0) + } + + func testCloudProviderKeepsExistingBatchPath() async throws { + var localCalls = 0 + var capturedKeywords: [String] = [] + let router = PTTBatchTranscriptionRouter( + selection: { TranscriptionProviderSelection(mode: .cloud, quality: .auto) }, + capabilities: { + LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + }, + cloudTranscribe: { _, _, keywords in + capturedKeywords = keywords + return "cloud transcript" + }, + localTranscribe: { _, _, _ in + localCalls += 1 + return [] + } + ) + + let result = try await router.transcribe( + audioData: Data([1]), + language: "en", + contextKeywords: ["keyword"] + ) + + XCTAssertEqual(result.provider, .cloud) + XCTAssertEqual(result.transcript, "cloud transcript") + XCTAssertEqual(capturedKeywords, ["keyword"]) + XCTAssertEqual(localCalls, 0) + } + + func testExplicitLocalWithoutEngineDoesNotCallCloud() async { + var cloudCalls = 0 + let router = PTTBatchTranscriptionRouter( + selection: { TranscriptionProviderSelection(mode: .local, quality: 
.auto) }, + capabilities: { + LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [] + ) + }, + cloudTranscribe: { _, _, _ in + cloudCalls += 1 + return "cloud transcript" + }, + localTranscribe: { _, _, _ in [] } + ) + + do { + _ = try await router.transcribe( + audioData: Data([1]), + language: "en", + contextKeywords: [] + ) + XCTFail("Expected explicit local mode without an engine to fail") + } catch { + XCTAssertEqual(cloudCalls, 0) + XCTAssertTrue(error.localizedDescription.contains("No local transcription engine")) + } + } +} + +final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { + func testAutoCloudFallbackAllowsBackgroundCloud() { + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .auto, quality: .auto), + capabilities: LocalTranscriptionCapabilities( + processor: .intel, + physicalMemoryBytes: 8 * 1024 * 1024 * 1024, + availableEngines: [] + ) + ) + + XCTAssertTrue(decision.useCloudBackend) + XCTAssertNotNil(decision.unsupportedLocalReason) + } + + func testResolvedLocalBlocksBackgroundCaptureUntilLocalFinalizeExists() { + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .local, quality: .balanced), + capabilities: LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + ) + + XCTAssertFalse(decision.useCloudBackend) + XCTAssertTrue(decision.unsupportedLocalReason?.contains("backend force-processing") == true) + } + + func testExplicitLocalWithoutEngineDoesNotSilentlyUseCloudForBackground() { + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .local, quality: .balanced), + capabilities: LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 
1024 * 1024 * 1024, + availableEngines: [] + ) + ) + + XCTAssertFalse(decision.useCloudBackend) + XCTAssertEqual(decision.unsupportedLocalReason, "No local transcription engine is available") + } } final class PTTTranscriptPostProcessorTests: XCTestCase { From b2063dca2db27be99f55e572755a47f1e856347d Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 11:09:02 +0700 Subject: [PATCH 43/58] Add local transcription desktop choice (cherry picked from commit b739064b99d29d1806380462fea9c34d103dc5de) --- .../MainWindow/Pages/SettingsPage.swift | 183 ++++++++++++++++++ .../Sources/OnboardingBYOKStepView.swift | 148 ++++++++++++-- desktop/Desktop/Sources/OnboardingView.swift | 22 ++- .../Sources/TranscriptionProvider.swift | 86 ++++++++ .../TranscriptionProviderPolicyTests.swift | 30 +++ 5 files changed, 456 insertions(+), 13 deletions(-) diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 3a2d633d3ba..73fc7d1a613 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -278,6 +278,8 @@ struct SettingsContentView: View { @State private var transcriptionAutoDetect: Bool = true @State private var transcriptionLanguage: String = "en" @State private var vadGateEnabled: Bool = false + @State private var transcriptionProviderSelection: TranscriptionProviderSelection + @State private var localTranscriptionCapabilities: LocalTranscriptionCapabilities // Multi-chat mode setting @AppStorage("multiChatEnabled") private var multiChatEnabled = false @@ -447,6 +449,9 @@ struct SettingsContentView: View { _vadGateEnabled = State(initialValue: settings.vadGateEnabled) _transcriptionLanguage = State(initialValue: settings.transcriptionLanguage) _transcriptionAutoDetect = State(initialValue: settings.transcriptionAutoDetect) + _transcriptionProviderSelection = State(initialValue: 
settings.transcriptionProviderSelection) + _localTranscriptionCapabilities = State( + initialValue: SettingsContentView.detectLocalTranscriptionCapabilities()) } /// Computed status text for notifications @@ -515,6 +520,7 @@ struct SettingsContentView: View { chatProvider?.checkClaudeConnectionStatus() // Refresh notification permission state appState.checkNotificationPermission() + localTranscriptionCapabilities = SettingsContentView.detectLocalTranscriptionCapabilities() } .onReceive(NotificationCenter.default.publisher(for: .assistantMonitoringStateDidChange)) { notification in @@ -1036,6 +1042,97 @@ struct SettingsContentView: View { private var transcriptionSection: some View { VStack(spacing: 20) { + // Provider + settingsCard(settingId: "transcription.provider") { + VStack(alignment: .leading, spacing: 16) { + HStack { + Image(systemName: "waveform.and.magnifyingglass") + .scaledFont(size: 16) + .foregroundColor(OmiColors.purplePrimary) + + VStack(alignment: .leading, spacing: 4) { + Text("Transcription Provider") + .scaledFont(size: 15, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text(transcriptionProviderStatusText) + .scaledFont(size: 13) + .foregroundColor(transcriptionProviderStatusColor) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + + Picker( + "", + selection: Binding( + get: { transcriptionProviderSelection.quality }, + set: { newValue in + updateTranscriptionProviderSelection( + TranscriptionProviderSelection( + mode: transcriptionProviderSelection.mode, + quality: newValue + ) + ) + } + ) + ) { + ForEach(TranscriptionQualityPreset.allCases, id: \.rawValue) { quality in + Text(TranscriptionProviderOnboardingAdvisor.displayName(for: quality)) + .tag(quality) + } + } + .pickerStyle(.menu) + .frame(width: 130) + } + + VStack(spacing: 10) { + transcriptionProviderOption( + mode: .auto, + title: "Local First", + detail: "Use local Whisper when available; otherwise use cloud transcription.", + icon: 
"sparkle.magnifyingglass" + ) + + transcriptionProviderOption( + mode: .local, + title: "Local Whisper Only", + detail: "Use on-device batch transcription. Background capture is not available yet.", + icon: "desktopcomputer" + ) + + transcriptionProviderOption( + mode: .cloud, + title: "Cloud Transcription", + detail: + "Use the existing Omi cloud transcription path for live meetings and background capture.", + icon: "cloud.fill" + ) + } + + if !backgroundTranscriptionRouting.useCloudBackend { + HStack(alignment: .top, spacing: 8) { + Image(systemName: "info.circle.fill") + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + .padding(.top, 1) + + Text( + "Local background transcription is not available yet. Push-to-Talk can use local Whisper; choose Cloud Transcription for continuous background capture." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + .fixedSize(horizontal: false, vertical: true) + } + .padding(10) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.warning.opacity(0.1)) + ) + } + } + } + // Language Mode settingsCard(settingId: "transcription.languagemode") { VStack(alignment: .leading, spacing: 16) { @@ -1346,6 +1443,92 @@ struct SettingsContentView: View { } } + private static func detectLocalTranscriptionCapabilities() -> LocalTranscriptionCapabilities { + LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + } + + private var resolvedTranscriptionProvider: TranscriptionProviderPolicyResult { + TranscriptionProviderPolicy().resolve( + selection: transcriptionProviderSelection, + capabilities: localTranscriptionCapabilities + ) + } + + private var backgroundTranscriptionRouting: BackgroundTranscriptionRoutingDecision { + BackgroundTranscriptionRoutingGuard().decide( + selection: transcriptionProviderSelection, + capabilities: localTranscriptionCapabilities + ) + } + + private var transcriptionProviderStatusText: String { + 
TranscriptionProviderOnboardingAdvisor.statusText(for: resolvedTranscriptionProvider) + } + + private var transcriptionProviderStatusColor: Color { + resolvedTranscriptionProvider.usesLocal ? OmiColors.success : OmiColors.textTertiary + } + + private func transcriptionProviderOption( + mode: TranscriptionProviderKind, + title: String, + detail: String, + icon: String + ) -> some View { + let isSelected = transcriptionProviderSelection.mode == mode + + return Button(action: { + updateTranscriptionProviderSelection( + TranscriptionProviderSelection(mode: mode, quality: transcriptionProviderSelection.quality) + ) + }) { + HStack(alignment: .top, spacing: 12) { + Image(systemName: isSelected ? "checkmark.circle.fill" : "circle") + .scaledFont(size: 20) + .foregroundColor(isSelected ? OmiColors.purplePrimary : OmiColors.textTertiary) + + Image(systemName: icon) + .scaledFont(size: 15) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 18) + + VStack(alignment: .leading, spacing: 4) { + Text(title) + .scaledFont(size: 14, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text(detail) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer(minLength: 12) + } + .padding(12) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(isSelected ? OmiColors.purplePrimary.opacity(0.1) : Color.clear) + .overlay( + RoundedRectangle(cornerRadius: 8) + .stroke( + isSelected ? 
OmiColors.purplePrimary.opacity(0.3) : OmiColors.backgroundQuaternary, + lineWidth: 1) + ) + ) + } + .buttonStyle(.plain) + } + + private func updateTranscriptionProviderSelection(_ selection: TranscriptionProviderSelection) { + guard selection != transcriptionProviderSelection else { return } + transcriptionProviderSelection = selection + AssistantSettings.shared.transcriptionProviderSelection = selection + restartTranscriptionIfNeeded() + } + // MARK: - Notifications Section private var notificationsSection: some View { diff --git a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift index e4f1cb8dab8..7e9fbeb69a4 100644 --- a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift +++ b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift @@ -1,8 +1,8 @@ import SwiftUI /// Final step before Tasks: offer a free-forever plan if the user supplies their -/// own API keys for OpenAI, Anthropic, Gemini, and Deepgram. Keys live on the -/// device (UserDefaults); the backend receives only SHA-256 fingerprints. +/// own API keys. Keys live on the device (UserDefaults); the backend receives +/// only SHA-256 fingerprints. struct OnboardingBYOKStepView: View { @ObservedObject var graphViewModel: MemoryGraphViewModel let stepIndex: Int @@ -19,6 +19,12 @@ struct OnboardingBYOKStepView: View { @State private var isActivating = false @State private var activationError: String? 
@State private var keyStatuses: [BYOKProvider: BYOKValidator.Status] = [:] + @State private var localCapabilities = OnboardingBYOKStepView.detectLocalCapabilities() + @State private var providerSelection = TranscriptionProviderSelection.default + + private var recommendation: TranscriptionProviderOnboardingRecommendation { + TranscriptionProviderOnboardingAdvisor().recommendation(capabilities: localCapabilities) + } var body: some View { OnboardingStepScaffold( @@ -26,9 +32,9 @@ struct OnboardingBYOKStepView: View { stepIndex: stepIndex, totalSteps: totalSteps, eyebrow: "Free forever", - title: "Bring your own keys.", + title: "Choose transcription.", description: - "Paste your own API keys for OpenAI, Anthropic, Gemini, and Deepgram and Omi is free forever. Keys stay on this Mac — we never store them on our servers.", + "Use local Whisper when this Mac can support it, or keep the existing cloud transcription path. API keys are optional unless you want the free-forever plan.", showsSkip: true, onSkip: { AnalyticsManager.shared.onboardingStepCompleted(step: stepIndex, stepName: "BYOK_Skipped") @@ -37,10 +43,30 @@ struct OnboardingBYOKStepView: View { onForceComplete: onForceComplete ) { VStack(alignment: .leading, spacing: 18) { + transcriptionChoice + + Divider() + .background(Color.white.opacity(0.08)) + .frame(maxWidth: 560) + + VStack(alignment: .leading, spacing: 6) { + Text("Bring your own keys") + .font(.system(size: 14, weight: .semibold)) + .foregroundColor(OmiColors.textPrimary) + + Text( + "Add OpenAI, Anthropic, Gemini, and Deepgram keys to activate the free plan. Local Whisper does not require a Deepgram key." 
+ ) + .font(.system(size: 12)) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + .frame(maxWidth: 560, alignment: .leading) + keyField(provider: .openai, binding: $openaiKey, help: "Used for GPT calls.") keyField(provider: .anthropic, binding: $anthropicKey, help: "Used for Claude chat.") keyField(provider: .gemini, binding: $geminiKey, help: "Used for proactive AI.") - keyField(provider: .deepgram, binding: $deepgramKey, help: "Used for transcription.") + keyField(provider: .deepgram, binding: $deepgramKey, help: "Used for cloud transcription.") if let activationError { Text(activationError) @@ -64,6 +90,84 @@ struct OnboardingBYOKStepView: View { } .frame(maxWidth: .infinity, alignment: .leading) } + .onAppear { + localCapabilities = OnboardingBYOKStepView.detectLocalCapabilities() + providerSelection = AssistantSettings.shared.transcriptionProviderSelection + } + } + + private var transcriptionChoice: some View { + VStack(alignment: .leading, spacing: 12) { + HStack(alignment: .top, spacing: 12) { + Image(systemName: recommendation.canRecommendLocal ? "desktopcomputer" : "cloud.fill") + .font(.system(size: 18, weight: .semibold)) + .foregroundColor(OmiColors.textSecondary) + .frame(width: 24) + + VStack(alignment: .leading, spacing: 5) { + Text(recommendation.title) + .font(.system(size: 15, weight: .semibold)) + .foregroundColor(OmiColors.textPrimary) + + Text(recommendation.detail) + .font(.system(size: 12)) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + + Text(recommendation.status) + .font(.system(size: 11, weight: .medium)) + .foregroundColor( + recommendation.canRecommendLocal ? 
OmiColors.success : OmiColors.warning + ) + .fixedSize(horizontal: false, vertical: true) + } + } + .frame(maxWidth: 560, alignment: .leading) + .padding(14) + .background( + RoundedRectangle(cornerRadius: 12, style: .continuous) + .fill(OmiColors.backgroundSecondary) + .overlay( + RoundedRectangle(cornerRadius: 12, style: .continuous) + .stroke(Color.white.opacity(0.08), lineWidth: 1) + ) + ) + + HStack(spacing: 12) { + Button(recommendation.canRecommendLocal ? "Use Local Whisper" : "Use Cloud Transcription") { + saveProviderSelection(recommendation.recommendedSelection) + AnalyticsManager.shared.onboardingStepCompleted( + step: stepIndex, + stepName: recommendation.canRecommendLocal + ? "Transcription_Local" : "Transcription_Cloud" + ) + onContinue() + } + .buttonStyle(OnboardingCardButtonStyle(isPrimary: true)) + + Button("Transcribe in the Cloud") { + saveProviderSelection(TranscriptionProviderSelection(mode: .cloud, quality: .auto)) + AnalyticsManager.shared.onboardingStepCompleted( + step: stepIndex, + stepName: "Transcription_Cloud" + ) + onContinue() + } + .buttonStyle(OnboardingCloudChoiceButtonStyle()) + } + .frame(maxWidth: 560, alignment: .leading) + } + } + + private static func detectLocalCapabilities() -> LocalTranscriptionCapabilities { + LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + } + + private func saveProviderSelection(_ selection: TranscriptionProviderSelection) { + providerSelection = selection + AssistantSettings.shared.transcriptionProviderSelection = selection } private var allKeysProvided: Bool { @@ -150,18 +254,20 @@ struct OnboardingBYOKStepView: View { } if !failed.isEmpty { let names = failed.keys.map(\.displayName).sorted().joined(separator: ", ") - activationError = "These keys were rejected by their provider: \(names). Fix them to continue." + activationError = + "These keys were rejected by their provider: \(names). Fix them to continue." 
return } // Step 2: all four authenticate — flip the backend flag. do { - try await APIClient.shared.activateBYOK(fingerprints: BYOKProvider.allCases.reduce(into: [:]) { - acc, provider in - if let key = APIKeyService.byokKey(provider) { - acc[provider.rawValue] = APIKeyService.byokFingerprint(key) - } - }) + try await APIClient.shared.activateBYOK( + fingerprints: BYOKProvider.allCases.reduce(into: [:]) { + acc, provider in + if let key = APIKeyService.byokKey(provider) { + acc[provider.rawValue] = APIKeyService.byokFingerprint(key) + } + }) // Refresh the in-memory quota snapshot — otherwise the client keeps // blocking chat against the stale basic-tier 30-message cap. await FloatingBarUsageLimiter.shared.fetchPlan() @@ -173,3 +279,21 @@ struct OnboardingBYOKStepView: View { } } } + +private struct OnboardingCloudChoiceButtonStyle: ButtonStyle { + func makeBody(configuration: Configuration) -> some View { + configuration.label + .font(.system(size: 12, weight: .semibold)) + .foregroundColor(OmiColors.textTertiary) + .padding(.horizontal, 10) + .padding(.vertical, 8) + .background( + RoundedRectangle(cornerRadius: 10, style: .continuous) + .fill(Color.white.opacity(configuration.isPressed ? 
0.08 : 0.04)) + ) + .overlay( + RoundedRectangle(cornerRadius: 10, style: .continuous) + .stroke(Color.white.opacity(0.06), lineWidth: 1) + ) + } +} diff --git a/desktop/Desktop/Sources/OnboardingView.swift b/desktop/Desktop/Sources/OnboardingView.swift index 287d6558547..38781663e35 100644 --- a/desktop/Desktop/Sources/OnboardingView.swift +++ b/desktop/Desktop/Sources/OnboardingView.swift @@ -502,7 +502,7 @@ struct OnboardingView: View { AnalyticsManager.shared.launchAtLoginChanged(enabled: true, source: "onboarding_complete") } startMonitoringIfNeeded() - appState.startTranscription() + startBackgroundTranscriptionIfAvailable() // Create welcome task Task { @@ -525,6 +525,26 @@ struct OnboardingView: View { ProactiveAssistantsPlugin.shared.startMonitoring { _, _ in } } } + + private func startBackgroundTranscriptionIfAvailable() { + let selection = AssistantSettings.shared.transcriptionProviderSelection + let capabilities = LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + let routing = BackgroundTranscriptionRoutingGuard().decide( + selection: selection, + capabilities: capabilities + ) + + guard routing.useCloudBackend else { + log( + "OnboardingView: skipping automatic background transcription start because local background capture is unavailable" + ) + return + } + + appState.startTranscription() + } } struct OnboardingTrustPreviewCard: View { diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift index f5882691ebb..23fc98eee91 100644 --- a/desktop/Desktop/Sources/TranscriptionProvider.swift +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -26,6 +26,14 @@ struct TranscriptionProviderSelection: Codable, Equatable { static let `default` = TranscriptionProviderSelection(mode: .auto, quality: .auto) } +struct TranscriptionProviderOnboardingRecommendation: Equatable { + var recommendedSelection: 
TranscriptionProviderSelection + var canRecommendLocal: Bool + var title: String + var detail: String + var status: String +} + struct LocalTranscriptionCapabilities: Equatable { enum Processor: Equatable { case nativeAppleSilicon @@ -180,6 +188,84 @@ struct TranscriptionProviderPolicy { } } +struct TranscriptionProviderOnboardingAdvisor { + var policy: TranscriptionProviderPolicy = TranscriptionProviderPolicy() + + func recommendation( + capabilities: LocalTranscriptionCapabilities, + quality: TranscriptionQualityPreset = .auto + ) -> TranscriptionProviderOnboardingRecommendation { + let localFirst = TranscriptionProviderSelection(mode: .auto, quality: quality) + let result = policy.resolve(selection: localFirst, capabilities: capabilities) + + if result.usesLocal { + return TranscriptionProviderOnboardingRecommendation( + recommendedSelection: localFirst, + canRecommendLocal: true, + title: "Use Local Whisper", + detail: + "Recommended for this Mac. Voice notes stay on-device with local Whisper; background capture can use cloud when you choose it.", + status: Self.statusText(for: result) + ) + } + + return TranscriptionProviderOnboardingRecommendation( + recommendedSelection: TranscriptionProviderSelection(mode: .cloud, quality: quality), + canRecommendLocal: false, + title: "Use Cloud Transcription", + detail: + "Local Whisper is not available on this Mac yet. Cloud transcription keeps meetings and background capture working.", + status: result.fallbackReason ?? "Local Whisper is unavailable" + ) + } + + static func statusText(for result: TranscriptionProviderPolicyResult) -> String { + if result.usesCloud { + return result.fallbackReason ?? 
"Using Omi cloud transcription" + } + + guard let engine = result.localEngine, let plan = result.localPlan else { + return "Using local transcription" + } + + return "Using \(displayName(for: engine)) with \(displayName(for: plan.model))" + } + + static func displayName(for mode: TranscriptionProviderKind) -> String { + switch mode { + case .auto: return "Local First" + case .local: return "Local Whisper Only" + case .cloud: return "Cloud Transcription" + } + } + + static func displayName(for quality: TranscriptionQualityPreset) -> String { + switch quality { + case .auto: return "Auto" + case .fast: return "Fast" + case .balanced: return "Balanced" + case .accurate: return "Accurate" + } + } + + static func displayName(for engine: LocalTranscriptionEngine) -> String { + switch engine { + case .mlxWhisper: return "MLX Whisper" + case .fasterWhisper: return "faster-whisper" + } + } + + static func displayName(for model: LocalTranscriptionModel) -> String { + switch model { + case .tiny: return "Tiny" + case .base: return "Base" + case .small: return "Small" + case .medium: return "Medium" + case .largeV3Turbo: return "Large v3 Turbo" + } + } +} + struct NormalizedTranscriptTranslation: Codable, Equatable { var lang: String var text: String diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index fa836ee8ca3..f2e118cb6de 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -471,6 +471,36 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { } } +final class TranscriptionProviderOnboardingAdvisorTests: XCTestCase { + func testEligibleNativeAppleSiliconRecommendsLocalFirst() { + let recommendation = TranscriptionProviderOnboardingAdvisor().recommendation( + capabilities: LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 
1024, + availableEngines: [.mlxWhisper, .fasterWhisper] + ) + ) + + XCTAssertTrue(recommendation.canRecommendLocal) + XCTAssertEqual(recommendation.recommendedSelection.mode, .auto) + XCTAssertTrue(recommendation.status.contains("MLX Whisper")) + } + + func testUnavailableLocalEngineRecommendsCloudFallback() { + let recommendation = TranscriptionProviderOnboardingAdvisor().recommendation( + capabilities: LocalTranscriptionCapabilities( + processor: .intel, + physicalMemoryBytes: 8 * 1024 * 1024 * 1024, + availableEngines: [] + ) + ) + + XCTAssertFalse(recommendation.canRecommendLocal) + XCTAssertEqual(recommendation.recommendedSelection.mode, .cloud) + XCTAssertTrue(recommendation.status.contains("cloud")) + } +} + final class PTTTranscriptPostProcessorTests: XCTestCase { func testLocalModeBypassesLLMCleanup() async { var cleanupCalls = 0 From b6bcdd384a868ae3e4f60eb8df6110b2e2ccba60 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 11:24:01 +0700 Subject: [PATCH 44/58] Add real local ASR helper runtime (cherry picked from commit 5e7bc7be08c170130f3f6e30a9bfaec4e28ccc90) --- .../LocalTranscription/LocalASRRuntime.swift | 46 +- .../TranscriptionProviderPolicyTests.swift | 38 ++ desktop/local-asr-helper/Cargo.toml | 5 +- desktop/local-asr-helper/README.md | 44 +- desktop/local-asr-helper/src/main.rs | 433 +++++++++++++++++- desktop/run.sh | 19 +- 6 files changed, 554 insertions(+), 31 deletions(-) diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index 908ccb0c308..e18285f83e9 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -54,6 +54,16 @@ struct LocalASRTranscriptionResponse: Codable, Equatable { } } +struct LocalASRCapabilityResponse: Codable, Equatable { + var engines: [LocalASREngineCapability] +} + +struct LocalASREngineCapability: Codable, Equatable 
{ + var engine: LocalTranscriptionEngine + var available: Bool + var reason: String? +} + struct LocalASRTranscriptSegment: Codable, Equatable { var id: String? var speaker: Int? @@ -177,8 +187,40 @@ enum LocalASRHelperLocator { static func detectedEngines(executableURL: URL? = defaultExecutableURL()) -> Set { - guard executableURL != nil else { return [] } - return Set(LocalTranscriptionEngine.allCases) + guard let executableURL else { return [] } + let process = Process() + process.executableURL = executableURL + process.arguments = ["--capabilities"] + + let output = Pipe() + process.standardOutput = output + process.standardError = Pipe() + + do { + try process.run() + } catch { + return [] + } + + let deadline = Date().addingTimeInterval(8) + while process.isRunning && Date() < deadline { + Thread.sleep(forTimeInterval: 0.05) + } + if process.isRunning { + process.terminate() + return [] + } + guard process.terminationStatus == 0 else { return [] } + + let outputData = output.fileHandleForReading.readDataToEndOfFile() + guard + let response = try? 
JSONDecoder.localASR.decode( + LocalASRCapabilityResponse.self, from: outputData) + else { + return [] + } + + return Set(response.engines.filter(\.available).map(\.engine)) } } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index f2e118cb6de..85fdb85c055 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -316,6 +316,44 @@ final class LocalASRRuntimeTests: XCTestCase { ) XCTAssertEqual(result.map(\.text), ["hello local whisper"]) } + + func testDetectedEnginesUsesHelperCapabilityProbe() throws { + let helper = try makeExecutableHelper( + body: + #"printf '{"engines":[{"engine":"mlx-whisper","available":true},{"engine":"faster-whisper","available":false,"reason":"missing model"}]}'"# + ) + + let engines = LocalASRHelperLocator.detectedEngines(executableURL: helper) + + XCTAssertEqual(engines, [.mlxWhisper]) + } + + func testDetectedEnginesReturnsEmptyOnProbeFailure() throws { + let helper = try makeExecutableHelper(body: "exit 2") + + let engines = LocalASRHelperLocator.detectedEngines(executableURL: helper) + + XCTAssertTrue(engines.isEmpty) + } + + private func makeExecutableHelper(body: String) throws -> URL { + let directory = FileManager.default.temporaryDirectory.appendingPathComponent( + "LocalASRHelperLocatorTests-\(UUID().uuidString)", + isDirectory: true + ) + try FileManager.default.createDirectory(at: directory, withIntermediateDirectories: true) + let helper = directory.appendingPathComponent("local-asr-helper") + let script = "#!/bin/sh\n\(body)\n" + try script.write(to: helper, atomically: true, encoding: .utf8) + try FileManager.default.setAttributes( + [.posixPermissions: 0o755], + ofItemAtPath: helper.path + ) + addTeardownBlock { + try? 
FileManager.default.removeItem(at: directory) + } + return helper + } } final class PTTBatchTranscriptionRouterTests: XCTestCase { diff --git a/desktop/local-asr-helper/Cargo.toml b/desktop/local-asr-helper/Cargo.toml index 19dc460df0a..b75d4154fc9 100644 --- a/desktop/local-asr-helper/Cargo.toml +++ b/desktop/local-asr-helper/Cargo.toml @@ -3,7 +3,10 @@ name = "omi-local-asr-helper" version = "0.1.0" edition = "2021" +[[bin]] +name = "local-asr-helper" +path = "src/main.rs" + [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" - diff --git a/desktop/local-asr-helper/README.md b/desktop/local-asr-helper/README.md index 66683270ac8..18d8c60c266 100644 --- a/desktop/local-asr-helper/README.md +++ b/desktop/local-asr-helper/README.md @@ -1,13 +1,45 @@ # Omi Local ASR Helper -This is the control-plane scaffold for local Whisper transcription. The desktop -app sends one JSON request on stdin and reads one JSON response on stdout. The -current implementation is fixture-backed so CI can exercise the same contract -before MLX Whisper and faster-whisper adapters are installed. +The desktop app sends one JSON transcription request on stdin and reads one JSON +response on stdout. `fixture_segments` are still accepted for contract tests, but +normal requests now require a real MLX Whisper or faster-whisper runtime. -Smoke command: +Capability probe: ```bash -printf '{"request_id":"fixture-1","audio_path":"/tmp/sample.pcm","language":"en","sample_rate":16000,"channels":1,"engine":"mlx-whisper","model":"small","fixture_segments":[{"id":"s1","speaker":0,"text":"hello local whisper","start":0.0,"end":1.2}]}' | cargo run --quiet --manifest-path desktop/local-asr-helper/Cargo.toml +cargo run --quiet --manifest-path desktop/local-asr-helper/Cargo.toml -- --capabilities ``` +Runtime setup: + +- Set `OMI_LOCAL_ASR_PYTHON` to the Python executable that has the ASR runtime. + It defaults to `python3`. 
+- MLX Whisper requires native Apple Silicon, `mlx-whisper`, and either a cached + Hugging Face model or a local model directory set with + `OMI_MLX_WHISPER_MODEL_DIR` or `OMI_MLX_WHISPER_MODEL_DIR_SMALL`. +- faster-whisper requires `faster-whisper` and either a cached Hugging Face model + or a local model directory set with `OMI_FASTER_WHISPER_MODEL_DIR` or + `OMI_FASTER_WHISPER_MODEL_DIR_SMALL`. +- For development only, `OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD=1` lets the Python + runtime resolve the default Hugging Face model names. + +Real PCM smoke command: + +```bash +say -o /tmp/omi-asr-smoke.aiff "hello local whisper" +afconvert /tmp/omi-asr-smoke.aiff -f WAVE -d LEI16@16000 -c 1 /tmp/omi-asr-smoke.wav +python3 - <<'PY' +import wave +with wave.open("/tmp/omi-asr-smoke.wav", "rb") as wav: + data = wav.readframes(wav.getnframes()) +with open("/tmp/omi-asr-smoke.pcm", "wb") as f: + f.write(data) +PY + +printf '{"request_id":"smoke-1","audio_path":"/tmp/omi-asr-smoke.pcm","language":"en","sample_rate":16000,"channels":1,"engine":"mlx-whisper","model":"small"}' \ + | cargo run --quiet --manifest-path desktop/local-asr-helper/Cargo.toml +``` + +The dev desktop app build (`desktop/run.sh`) builds this helper and copies it to +`.app/Contents/Resources/local-asr-helper`, which is the bundled path used +by `LocalASRHelperLocator`. 
diff --git a/desktop/local-asr-helper/src/main.rs b/desktop/local-asr-helper/src/main.rs index 66d201e5bc6..9074ba8e633 100644 --- a/desktop/local-asr-helper/src/main.rs +++ b/desktop/local-asr-helper/src/main.rs @@ -1,8 +1,14 @@ use serde::{Deserialize, Serialize}; -use std::io::{self, Read}; -use std::process; +use std::env; +use std::fs::{self, File}; +use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; -#[derive(Debug, Deserialize, Serialize, PartialEq)] +const SAMPLE_RATE: u32 = 16_000; +const CHANNELS: u8 = 1; + +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] struct TranscriptionRequest { request_id: String, @@ -15,14 +21,23 @@ struct TranscriptionRequest { fixture_segments: Option>, } -#[derive(Debug, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq)] #[serde(rename_all = "kebab-case")] enum LocalEngine { MlxWhisper, FasterWhisper, } -#[derive(Debug, Deserialize, Serialize, PartialEq)] +impl LocalEngine { + fn as_str(self) -> &'static str { + match self { + Self::MlxWhisper => "mlx-whisper", + Self::FasterWhisper => "faster-whisper", + } + } +} + +#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq)] #[serde(rename_all = "snake_case")] enum LocalModel { Tiny, @@ -32,6 +47,18 @@ enum LocalModel { LargeV3Turbo, } +impl LocalModel { + fn as_str(self) -> &'static str { + match self { + Self::Tiny => "tiny", + Self::Base => "base", + Self::Small => "small", + Self::Medium => "medium", + Self::LargeV3Turbo => "large_v3_turbo", + } + } +} + #[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] struct TranscriptSegment { id: Option, @@ -52,36 +79,74 @@ struct TranscriptionResponse { fixture: bool, } +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +struct CapabilityResponse { + engines: Vec, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] 
+#[serde(rename_all = "snake_case")] +struct EngineCapability { + engine: LocalEngine, + available: bool, + reason: Option, +} + fn main() { if let Err(error) = run() { eprintln!("{error}"); - process::exit(1); + std::process::exit(1); } } fn run() -> Result<(), String> { + if env::args().any(|arg| arg == "--capabilities") { + let response = CapabilityResponse { + engines: vec![ + probe_engine(LocalEngine::MlxWhisper), + probe_engine(LocalEngine::FasterWhisper), + ], + }; + println!( + "{}", + serde_json::to_string(&response).map_err(|error| error.to_string())? + ); + return Ok(()); + } + let request = read_request()?; - let fixture_segments = request.fixture_segments.clone().unwrap_or_else(|| { - vec![TranscriptSegment { - id: Some(format!("{}-fixture-0", request.request_id)), - speaker: Some(0), - text: "fixture local transcription".to_string(), - start: 0.0, - end: 1.0, - }] - }); + let segments = match request.fixture_segments.clone() { + Some(segments) => { + let response = TranscriptionResponse { + request_id: request.request_id, + engine: request.engine, + model: request.model, + language: request.language, + segments, + fixture: true, + }; + println!( + "{}", + serde_json::to_string(&response).map_err(|error| error.to_string())? + ); + return Ok(()); + } + None => transcribe(&request)?, + }; let response = TranscriptionResponse { request_id: request.request_id, engine: request.engine, model: request.model, language: request.language, - segments: fixture_segments, - fixture: true, + segments, + fixture: false, }; - - let json = serde_json::to_string(&response).map_err(|error| error.to_string())?; - println!("{json}"); + println!( + "{}", + serde_json::to_string(&response).map_err(|error| error.to_string())? 
+ ); Ok(()) } @@ -93,6 +158,302 @@ fn read_request() -> Result { serde_json::from_str(&input).map_err(|error| format!("invalid request json: {error}")) } +fn transcribe(request: &TranscriptionRequest) -> Result, String> { + if request.sample_rate != SAMPLE_RATE || request.channels != CHANNELS { + return Err(format!( + "local ASR expects {SAMPLE_RATE} Hz mono PCM, got {} Hz with {} channel(s)", + request.sample_rate, request.channels + )); + } + + let capability = probe_engine(request.engine); + if !capability.available { + return Err(format!( + "{} is unavailable: {}", + request.engine.as_str(), + capability + .reason + .unwrap_or_else(|| "capability probe failed".to_string()) + )); + } + + let audio_path = Path::new(&request.audio_path); + if !audio_path.is_file() { + return Err(format!("audio file does not exist: {}", request.audio_path)); + } + + let wav_path = write_wav_copy(audio_path, &request.request_id)?; + let result = transcribe_with_python(request, &wav_path); + let _ = fs::remove_file(&wav_path); + result +} + +fn write_wav_copy(pcm_path: &Path, request_id: &str) -> Result { + let pcm = fs::read(pcm_path).map_err(|error| format!("failed to read PCM audio: {error}"))?; + if pcm.len() % 2 != 0 { + return Err("PCM audio must be signed 16-bit little-endian samples".to_string()); + } + + let path = env::temp_dir().join(format!("omi-local-asr-{request_id}.wav")); + let mut file = + File::create(&path).map_err(|error| format!("failed to create WAV file: {error}"))?; + write_wav_header(&mut file, pcm.len() as u32)?; + file.write_all(&pcm) + .map_err(|error| format!("failed to write WAV audio: {error}"))?; + Ok(path) +} + +fn write_wav_header(file: &mut File, data_len: u32) -> Result<(), String> { + let byte_rate = SAMPLE_RATE * CHANNELS as u32 * 2; + let block_align = CHANNELS as u16 * 2; + file.write_all(b"RIFF").map_err(|error| error.to_string())?; + file.write_all(&(36 + data_len).to_le_bytes()) + .map_err(|error| error.to_string())?; + 
file.write_all(b"WAVEfmt ") + .map_err(|error| error.to_string())?; + file.write_all(&16u32.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&1u16.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&(CHANNELS as u16).to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&SAMPLE_RATE.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&byte_rate.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&block_align.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(&16u16.to_le_bytes()) + .map_err(|error| error.to_string())?; + file.write_all(b"data").map_err(|error| error.to_string())?; + file.write_all(&data_len.to_le_bytes()) + .map_err(|error| error.to_string()) +} + +fn transcribe_with_python( + request: &TranscriptionRequest, + wav_path: &Path, +) -> Result, String> { + let model = model_argument(request.engine, request.model)?; + let output = Command::new(python_executable()) + .arg("-c") + .arg(PYTHON_TRANSCRIBE) + .arg(request.engine.as_str()) + .arg(model) + .arg(wav_path) + .arg(&request.language) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output() + .map_err(|error| format!("failed to start Python ASR adapter: {error}"))?; + + if !output.status.success() { + return Err(format!( + "Python ASR adapter failed: {}", + String::from_utf8_lossy(&output.stderr).trim() + )); + } + + serde_json::from_slice(&output.stdout).map_err(|error| { + format!( + "invalid Python ASR adapter response: {error}: {}", + String::from_utf8_lossy(&output.stdout).trim() + ) + }) +} + +fn probe_engine(engine: LocalEngine) -> EngineCapability { + if engine == LocalEngine::MlxWhisper && !is_native_apple_silicon() { + return EngineCapability { + engine, + available: false, + reason: Some("MLX Whisper requires native Apple Silicon".to_string()), + }; + } + + let output = Command::new(python_executable()) + .arg("-c") + .arg(PYTHON_PROBE) + .arg(engine.as_str()) 
+ .arg(model_argument_for_probe(engine).unwrap_or_default()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output(); + + match output { + Ok(output) if output.status.success() => EngineCapability { + engine, + available: true, + reason: None, + }, + Ok(output) => EngineCapability { + engine, + available: false, + reason: Some(String::from_utf8_lossy(&output.stderr).trim().to_string()), + }, + Err(error) => EngineCapability { + engine, + available: false, + reason: Some(format!("failed to start Python probe: {error}")), + }, + } +} + +fn model_argument_for_probe(engine: LocalEngine) -> Result { + model_argument(engine, LocalModel::Small).or_else(|_| model_argument(engine, LocalModel::Base)) +} + +fn model_argument(engine: LocalEngine, model: LocalModel) -> Result { + let specific_key = format!( + "OMI_{}_MODEL_DIR_{}", + engine_env_prefix(engine), + model.as_str().to_ascii_uppercase() + ); + if let Ok(value) = env::var(&specific_key) { + if !value.is_empty() && Path::new(&value).exists() { + return Ok(value); + } + } + + let general_key = format!("OMI_{}_MODEL_DIR", engine_env_prefix(engine)); + if let Ok(value) = env::var(&general_key) { + if !value.is_empty() && Path::new(&value).exists() { + return Ok(value); + } + } + + Ok(default_remote_model(engine, model).to_string()) +} + +fn engine_env_prefix(engine: LocalEngine) -> &'static str { + match engine { + LocalEngine::MlxWhisper => "MLX_WHISPER", + LocalEngine::FasterWhisper => "FASTER_WHISPER", + } +} + +fn default_remote_model(engine: LocalEngine, model: LocalModel) -> &'static str { + match engine { + LocalEngine::MlxWhisper => match model { + LocalModel::Tiny => "mlx-community/whisper-tiny-mlx", + LocalModel::Base => "mlx-community/whisper-base-mlx", + LocalModel::Small => "mlx-community/whisper-small-mlx", + LocalModel::Medium => "mlx-community/whisper-medium-mlx", + LocalModel::LargeV3Turbo => "mlx-community/whisper-large-v3-turbo", + }, + LocalEngine::FasterWhisper => match model { + 
LocalModel::Tiny => "Systran/faster-whisper-tiny", + LocalModel::Base => "Systran/faster-whisper-base", + LocalModel::Small => "Systran/faster-whisper-small", + LocalModel::Medium => "Systran/faster-whisper-medium", + LocalModel::LargeV3Turbo => "mobiuslabsgmbh/faster-whisper-large-v3-turbo", + }, + } +} + +fn python_executable() -> String { + env::var("OMI_LOCAL_ASR_PYTHON").unwrap_or_else(|_| "python3".to_string()) +} + +fn is_native_apple_silicon() -> bool { + if env::consts::ARCH != "aarch64" { + return false; + } + let translated = Command::new("sysctl") + .args(["-in", "sysctl.proc_translated"]) + .output() + .ok() + .and_then(|output| String::from_utf8(output.stdout).ok()) + .map(|value| value.trim() == "1") + .unwrap_or(false); + !translated +} + +const PYTHON_PROBE: &str = r#" +import os +import sys + +engine = sys.argv[1] +model = sys.argv[2] + +if not model: + raise SystemExit("no local model is configured for this engine") + +if engine == "mlx-whisper": + import mlx_whisper # noqa: F401 +elif engine == "faster-whisper": + import faster_whisper # noqa: F401 +else: + raise SystemExit(f"unknown engine: {engine}") + +if os.path.exists(model): + print("ok") +else: + from huggingface_hub import snapshot_download + snapshot_download( + repo_id=model, + local_files_only=os.environ.get("OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD") != "1", + ) + print("ok") +"#; + +const PYTHON_TRANSCRIBE: &str = r#" +import json +import os +import sys + +engine, model, audio_path, language = sys.argv[1:5] +language = None if language in ("", "auto") else language + +def resolve_model(model): + if os.path.exists(model): + return model + from huggingface_hub import snapshot_download + return snapshot_download( + repo_id=model, + local_files_only=os.environ.get("OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD") != "1", + ) + +def clean_segments(raw_segments): + segments = [] + for index, segment in enumerate(raw_segments): + if isinstance(segment, dict): + start = float(segment.get("start", 0.0)) + end 
= float(segment.get("end", start)) + text = str(segment.get("text", "")).strip() + sid = segment.get("id", index) + else: + start = float(getattr(segment, "start", 0.0)) + end = float(getattr(segment, "end", start)) + text = str(getattr(segment, "text", "")).strip() + sid = getattr(segment, "id", index) + if text: + segments.append({ + "id": f"local-{sid}", + "speaker": 0, + "text": text, + "start": start, + "end": end, + }) + return segments + +if engine == "mlx-whisper": + import mlx_whisper + model = resolve_model(model) + result = mlx_whisper.transcribe(audio_path, path_or_hf_repo=model, language=language, verbose=False) + raw_segments = result.get("segments", []) +elif engine == "faster-whisper": + from faster_whisper import WhisperModel + allow_download = os.environ.get("OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD") == "1" + model = resolve_model(model) + whisper = WhisperModel(model, device="auto", compute_type="auto", local_files_only=not allow_download) + raw_segments, _ = whisper.transcribe(audio_path, language=language, vad_filter=True) +else: + raise SystemExit(f"unknown engine: {engine}") + +print(json.dumps(clean_segments(raw_segments))) +"#; + #[cfg(test)] mod tests { use super::*; @@ -118,4 +479,36 @@ mod tests { assert_eq!(request.model, LocalModel::Small); assert_eq!(request.fixture_segments.unwrap()[0].text, "hello"); } + + #[test] + fn writes_pcm_as_wav() { + let source = env::temp_dir().join("omi-local-asr-test.pcm"); + fs::write(&source, [0u8, 0, 1, 0]).unwrap(); + + let wav = write_wav_copy(&source, "unit-test").unwrap(); + let bytes = fs::read(&wav).unwrap(); + + assert_eq!(&bytes[0..4], b"RIFF"); + assert_eq!(&bytes[8..12], b"WAVE"); + assert_eq!(&bytes[44..], &[0u8, 0, 1, 0]); + + let _ = fs::remove_file(source); + let _ = fs::remove_file(wav); + } + + #[test] + fn capability_response_contract_round_trips() { + let response = CapabilityResponse { + engines: vec![EngineCapability { + engine: LocalEngine::FasterWhisper, + available: false, + reason: 
Some("missing model".to_string()), + }], + }; + + let json = serde_json::to_string(&response).unwrap(); + let decoded: CapabilityResponse = serde_json::from_str(&json).unwrap(); + + assert_eq!(decoded, response); + } } diff --git a/desktop/run.sh b/desktop/run.sh index 783fe19f3b5..7c431df596a 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -27,6 +27,7 @@ Options (via environment variables): OMI_HYBRID_DIRECT_STT_ENABLED Hybrid Apple Speech live transcription in local daemon (default 1 in configure_local_daemon_mode when unset) OMI_HYBRID_DIRECT_CHAT_ENABLED Hybrid OpenAI-compatible chat + daemon-backed sessions/messages (default 1 in configure_local_daemon_mode when unset) OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED Optional hybrid direct embeddings for vector search (default 0 in local bundle; local wiki search does not require embeddings) + OMI_SKIP_STALE_BUNDLE_SCAN=1 Skip scanning $HOME for stale dev app bundles Required files for cloud backend mode: Backend-Rust/.env Environment variables (copy from ../.env.example) @@ -277,13 +278,17 @@ find "$(dirname "$0")/../app/build" -name "$APP_NAME.app" -type d -exec rm -rf { # These confuse LaunchServices and get launched instead of the /Applications copy. # In local daemon mode, keep the primary user-test command fast and avoid broad # home-directory scans unless explicitly requested. -if ! is_local_daemon_mode || [ "${OMI_CLEAN_STALE_CLONES:-0}" = "1" ]; then +if [ "${OMI_SKIP_STALE_BUNDLE_SCAN:-0}" != "1" ] && { ! 
is_local_daemon_mode || [ "${OMI_CLEAN_STALE_CLONES:-0}" = "1" ]; }; then find "$HOME" -maxdepth 4 -name "$APP_NAME.app" -type d -not -path "$APP_BUNDLE" -not -path "$APP_PATH" 2>/dev/null | while read stale; do substep "Removing stale clone: $stale" rm -rf "$stale" done else - substep "Local daemon mode: skipping stale clone scan (set OMI_CLEAN_STALE_CLONES=1 to enable)" + if is_local_daemon_mode && [ "${OMI_CLEAN_STALE_CLONES:-0}" != "1" ]; then + substep "Local daemon mode: skipping stale clone scan (set OMI_CLEAN_STALE_CLONES=1 to enable)" + else + substep "Skipping stale bundle scan (OMI_SKIP_STALE_BUNDLE_SCAN=1)" + fi fi if [ "${OMI_SKIP_TUNNEL:-0}" != "1" ]; then @@ -520,6 +525,9 @@ fi step "Building Swift app (swift build -c debug)..." xcrun swift build -c debug --package-path Desktop +step "Building local ASR helper (cargo build)..." +cargo build --manifest-path local-asr-helper/Cargo.toml + auth_debug "AFTER swift build: auth_isSignedIn=$(defaults read "$BUNDLE_ID" auth_isSignedIn 2>&1 || true)" step "Creating app bundle..." @@ -699,6 +707,9 @@ fi substep "Copying app icon" cp -f omi_icon.icns "$APP_BUNDLE/Contents/Resources/OmiIcon.icns" 2>/dev/null || true +substep "Copying local ASR helper" +cp -f "local-asr-helper/target/debug/local-asr-helper" "$APP_BUNDLE/Contents/Resources/local-asr-helper" + substep "Creating PkgInfo" echo -n "APPL????" 
> "$APP_BUNDLE/Contents/PkgInfo" @@ -767,6 +778,10 @@ if [ -n "$SIGN_IDENTITY" ]; then substep "Signing bundled node binary" codesign --force --options runtime --entitlements Desktop/Node.entitlements --sign "$SIGN_IDENTITY" "$NODE_BIN" fi + if [ -f "$APP_BUNDLE/Contents/Resources/local-asr-helper" ]; then + substep "Signing local ASR helper" + codesign --force --options runtime --sign "$SIGN_IDENTITY" "$APP_BUNDLE/Contents/Resources/local-asr-helper" + fi # If local signing identity doesn't match embedded profile team, macOS rejects # restricted entitlements (notably com.apple.developer.applesignin) and launch From 4f48a48f34a6fac822d1889438d4e4d9cc19fc11 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 11:27:02 +0700 Subject: [PATCH 45/58] Narrow local background transcription scope (cherry picked from commit 5892f09d3e97f258b49e893b5489d79c28992e5b) --- .../Sources/MainWindow/Pages/SettingsPage.swift | 3 ++- .../Desktop/Sources/OnboardingBYOKStepView.swift | 2 +- .../Desktop/Sources/TranscriptionProvider.swift | 2 +- .../Tests/TranscriptionProviderPolicyTests.swift | 14 ++++++++++++++ 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 73fc7d1a613..75f767d1106 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -1090,7 +1090,8 @@ struct SettingsContentView: View { transcriptionProviderOption( mode: .auto, title: "Local First", - detail: "Use local Whisper when available; otherwise use cloud transcription.", + detail: + "Use local Whisper for Push-to-Talk when available; use cloud when local is unavailable or for background capture.", icon: "sparkle.magnifyingglass" ) diff --git a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift index 7e9fbeb69a4..abd9c94153a 
100644 --- a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift +++ b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift @@ -34,7 +34,7 @@ struct OnboardingBYOKStepView: View { eyebrow: "Free forever", title: "Choose transcription.", description: - "Use local Whisper when this Mac can support it, or keep the existing cloud transcription path. API keys are optional unless you want the free-forever plan.", + "Use local Whisper for Push-to-Talk when this Mac can support it, or keep the existing cloud transcription path for continuous background capture. API keys are optional unless you want the free-forever plan.", showsSkip: true, onSkip: { AnalyticsManager.shared.onboardingStepCompleted(step: stepIndex, stepName: "BYOK_Skipped") diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift index 23fc98eee91..c9b41de8674 100644 --- a/desktop/Desktop/Sources/TranscriptionProvider.swift +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -204,7 +204,7 @@ struct TranscriptionProviderOnboardingAdvisor { canRecommendLocal: true, title: "Use Local Whisper", detail: - "Recommended for this Mac. Voice notes stay on-device with local Whisper; background capture can use cloud when you choose it.", + "Recommended for this Mac. 
Push-to-Talk transcription stays on-device with local Whisper; continuous background capture still requires cloud transcription.", status: Self.statusText(for: result) ) } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 85fdb85c055..cecc15988ee 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -494,6 +494,20 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { XCTAssertTrue(decision.unsupportedLocalReason?.contains("backend force-processing") == true) } + func testAutoResolvedLocalBlocksBackgroundCaptureUntilLocalFinalizeExists() { + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .auto, quality: .balanced), + capabilities: LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + ) + + XCTAssertFalse(decision.useCloudBackend) + XCTAssertTrue(decision.unsupportedLocalReason?.contains("backend force-processing") == true) + } + func testExplicitLocalWithoutEngineDoesNotSilentlyUseCloudForBackground() { let decision = BackgroundTranscriptionRoutingGuard().decide( selection: TranscriptionProviderSelection(mode: .local, quality: .balanced), From 9ee7db2c892ec9f0035df583db781dfa373be842 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 20 May 2026 04:32:14 +0000 Subject: [PATCH 46/58] Fix transcription word count corruption and main-thread engine probe blocking (cherry picked from commit c6490162262e1ec0ff7de2249e24edbaeed0925c) --- desktop/Desktop/Sources/AppState.swift | 6 ++++- .../LocalTranscription/LocalASRRuntime.swift | 22 ++++++++++++++----- .../Sources/TranscriptionProvider.swift | 1 - 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift 
b/desktop/Desktop/Sources/AppState.swift index bb2152a2a87..3cb62253b7e 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -2632,6 +2632,7 @@ class AppState: ObservableObject { let translatedSegments = try JSONDecoder().decode( [TranscriptionService.BackendSegment].self, from: data) log("Transcription: Translation event with \(translatedSegments.count) segments") + var updatedInMemorySegments = false for translated in translatedSegments { guard let segId = translated.id else { continue } let newTranslations = (translated.translations ?? []).map { @@ -2642,7 +2643,7 @@ class AppState: ObservableObject { // Update in-memory if the segment is still loaded if let idx = speakerSegments.firstIndex(where: { $0.segmentId == segId }) { speakerSegments[idx].translations = newTranslations - speakerSegmentReducer.replaceSegments(speakerSegments) + updatedInMemorySegments = true } // Always persist to SQLite — even if the segment was trimmed from @@ -2671,6 +2672,9 @@ class AppState: ObservableObject { } } } + if updatedInMemorySegments { + speakerSegmentReducer.replaceSegments(speakerSegments) + } LiveTranscriptMonitor.shared.updateSegments(speakerSegments) } catch { logError("Transcription: Failed to parse translation event", error: error) diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index e18285f83e9..edf712c1636 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -186,6 +186,16 @@ enum LocalASRHelperLocator { static func detectedEngines(executableURL: URL? 
= defaultExecutableURL()) -> Set + { + let probe = { detectedEnginesBlocking(executableURL: executableURL) } + if Thread.isMainThread { + return DispatchQueue.global(qos: .userInitiated).sync(execute: probe) + } + return probe() + } + + private static func detectedEnginesBlocking(executableURL: URL?) + -> Set { guard let executableURL else { return [] } let process = Process() @@ -202,13 +212,13 @@ enum LocalASRHelperLocator { return [] } - let deadline = Date().addingTimeInterval(8) - while process.isRunning && Date() < deadline { - Thread.sleep(forTimeInterval: 0.05) - } if process.isRunning { - process.terminate() - return [] + let finished = DispatchSemaphore(value: 0) + process.terminationHandler = { _ in finished.signal() } + if finished.wait(timeout: .now() + 8) == .timedOut { + process.terminate() + return [] + } } guard process.terminationStatus == 0 else { return [] } diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift index c9b41de8674..3e8bfe5183d 100644 --- a/desktop/Desktop/Sources/TranscriptionProvider.swift +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -471,7 +471,6 @@ struct SpeakerSegmentReducer { mutating func replaceSegments(_ replacement: [SpeakerSegment]) { segments = replacement - totalWordCount = replacement.reduce(0) { $0 + wordCount($1.text) } } mutating func apply(_ incomingSegments: [SpeakerSegment]) -> ApplyResult { From 99e267edaf2811e0f120d4763bece55d6fbd5d87 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 11:50:28 +0700 Subject: [PATCH 47/58] Reorient desktop local transcription settings (cherry picked from commit 3fcd52e5857971e9a12f83c3599bce84cb46343a) --- desktop/Desktop/Sources/AppState.swift | 6 +-- .../LocalTranscription/LocalASRRuntime.swift | 53 ++++++++++++++----- .../MainWindow/Pages/SettingsPage.swift | 24 +++++---- .../Sources/OnboardingBYOKStepView.swift | 4 +- desktop/Desktop/Sources/OnboardingView.swift | 2 +- 
.../Sources/TranscriptionProvider.swift | 8 +-- .../TranscriptionProviderPolicyTests.swift | 15 ++++-- 7 files changed, 77 insertions(+), 35 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 3cb62253b7e..188fc6a945d 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1285,12 +1285,12 @@ class AppState: ObservableObject { if !backgroundRouting.useCloudBackend { let message = backgroundRouting.unsupportedLocalReason - ?? "Local background transcription is not available yet." + ?? "Local background transcription is selected and will not use the cloud listen path." log("Transcription: \(message)") showAlert( - title: "Local Background Transcription Unavailable", + title: "Local Background Transcription", message: - "Local transcription is currently available for Push-to-Talk batch mode. Choose cloud transcription to use background capture." + "Local background transcription is selected, so Omi will not start a cloud /v4/listen session. The local background session manager is not enabled in this build yet." ) return } diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index edf712c1636..f1564b5ba56 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -386,8 +386,38 @@ struct PTTBatchTranscriptionRouter { } struct BackgroundTranscriptionRoutingDecision: Equatable { - var useCloudBackend: Bool - var unsupportedLocalReason: String? + enum Route: Equatable { + case cloudBackend(fallbackReason: String?) + case localWhisper(LocalTranscriptionPlan) + case unavailable(String) + } + + var route: Route + + var useCloudBackend: Bool { + if case .cloudBackend = route { + return true + } + return false + } + + var localPlan: LocalTranscriptionPlan? 
{ + if case .localWhisper(let plan) = route { + return plan + } + return nil + } + + var unsupportedLocalReason: String? { + switch route { + case .cloudBackend(let fallbackReason): + return fallbackReason + case .localWhisper: + return nil + case .unavailable(let reason): + return reason + } + } } struct BackgroundTranscriptionRoutingGuard { @@ -400,24 +430,23 @@ struct BackgroundTranscriptionRoutingGuard { let resolved = policy.resolve(selection: selection, capabilities: capabilities) if selection.mode == .local, resolved.provider != .local { return BackgroundTranscriptionRoutingDecision( - useCloudBackend: false, - unsupportedLocalReason: resolved.fallbackReason - ?? "No local transcription engine is available" + route: .unavailable(resolved.fallbackReason ?? "No local transcription engine is available") ) } guard resolved.provider == .local else { return BackgroundTranscriptionRoutingDecision( - useCloudBackend: true, - unsupportedLocalReason: resolved.fallbackReason + route: .cloudBackend(fallbackReason: resolved.fallbackReason) ) } - return BackgroundTranscriptionRoutingDecision( - useCloudBackend: false, - unsupportedLocalReason: - "Local background transcription is not available until local finalization can persist conversations without backend force-processing." 
- ) + guard let plan = resolved.localPlan else { + return BackgroundTranscriptionRoutingDecision( + route: .unavailable("No local transcription engine is available") + ) + } + + return BackgroundTranscriptionRoutingDecision(route: .localWhisper(plan)) } } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 75f767d1106..4adae640e19 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -1089,16 +1089,17 @@ struct SettingsContentView: View { VStack(spacing: 10) { transcriptionProviderOption( mode: .auto, - title: "Local First", + title: "Local Background First", detail: - "Use local Whisper for Push-to-Talk when available; use cloud when local is unavailable or for background capture.", + "Use local Whisper for continuous background transcription when available. If local is unavailable, this mode may use cloud transcription.", icon: "sparkle.magnifyingglass" ) transcriptionProviderOption( mode: .local, - title: "Local Whisper Only", - detail: "Use on-device batch transcription. Background capture is not available yet.", + title: "Local Background Only", + detail: + "Use on-device Whisper for continuous background transcription. 
If local ASR is unavailable, background transcription will fail instead of using cloud.", icon: "desktopcomputer" ) @@ -1106,21 +1107,19 @@ struct SettingsContentView: View { mode: .cloud, title: "Cloud Transcription", detail: - "Use the existing Omi cloud transcription path for live meetings and background capture.", + "Use the existing Omi cloud transcription path for live meetings and continuous background capture.", icon: "cloud.fill" ) } - if !backgroundTranscriptionRouting.useCloudBackend { + if let unavailableReason = backgroundTranscriptionUnavailableReason { HStack(alignment: .top, spacing: 8) { Image(systemName: "info.circle.fill") .scaledFont(size: 12) .foregroundColor(OmiColors.warning) .padding(.top, 1) - Text( - "Local background transcription is not available yet. Push-to-Talk can use local Whisper; choose Cloud Transcription for continuous background capture." - ) + Text(unavailableReason) .scaledFont(size: 12) .foregroundColor(OmiColors.warning) .fixedSize(horizontal: false, vertical: true) @@ -1464,6 +1463,13 @@ struct SettingsContentView: View { ) } + private var backgroundTranscriptionUnavailableReason: String? { + guard case .unavailable(let reason) = backgroundTranscriptionRouting.route else { + return nil + } + return reason + } + private var transcriptionProviderStatusText: String { TranscriptionProviderOnboardingAdvisor.statusText(for: resolvedTranscriptionProvider) } diff --git a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift index abd9c94153a..62f3c0056d7 100644 --- a/desktop/Desktop/Sources/OnboardingBYOKStepView.swift +++ b/desktop/Desktop/Sources/OnboardingBYOKStepView.swift @@ -34,7 +34,7 @@ struct OnboardingBYOKStepView: View { eyebrow: "Free forever", title: "Choose transcription.", description: - "Use local Whisper for Push-to-Talk when this Mac can support it, or keep the existing cloud transcription path for continuous background capture. 
API keys are optional unless you want the free-forever plan.", + "Use local Whisper for continuous background transcription when this Mac can support it, or keep the existing cloud transcription path. Push-to-Talk may still use cloud. API keys are optional unless you want the free-forever plan.", showsSkip: true, onSkip: { AnalyticsManager.shared.onboardingStepCompleted(step: stepIndex, stepName: "BYOK_Skipped") @@ -55,7 +55,7 @@ struct OnboardingBYOKStepView: View { .foregroundColor(OmiColors.textPrimary) Text( - "Add OpenAI, Anthropic, Gemini, and Deepgram keys to activate the free plan. Local Whisper does not require a Deepgram key." + "Add OpenAI, Anthropic, Gemini, and Deepgram keys to activate the free plan. Local background Whisper does not require a Deepgram key." ) .font(.system(size: 12)) .foregroundColor(OmiColors.textTertiary) diff --git a/desktop/Desktop/Sources/OnboardingView.swift b/desktop/Desktop/Sources/OnboardingView.swift index 38781663e35..27fd99a202a 100644 --- a/desktop/Desktop/Sources/OnboardingView.swift +++ b/desktop/Desktop/Sources/OnboardingView.swift @@ -538,7 +538,7 @@ struct OnboardingView: View { guard routing.useCloudBackend else { log( - "OnboardingView: skipping automatic background transcription start because local background capture is unavailable" + "OnboardingView: skipping automatic background transcription start because local background capture must not use the cloud listen path" ) return } diff --git a/desktop/Desktop/Sources/TranscriptionProvider.swift b/desktop/Desktop/Sources/TranscriptionProvider.swift index 3e8bfe5183d..493c2327eeb 100644 --- a/desktop/Desktop/Sources/TranscriptionProvider.swift +++ b/desktop/Desktop/Sources/TranscriptionProvider.swift @@ -204,7 +204,7 @@ struct TranscriptionProviderOnboardingAdvisor { canRecommendLocal: true, title: "Use Local Whisper", detail: - "Recommended for this Mac. 
Push-to-Talk transcription stays on-device with local Whisper; continuous background capture still requires cloud transcription.", + "Recommended for this Mac. Continuous background transcription can use local Whisper to reduce always-on cloud transcription cost; Push-to-Talk may still use the existing cloud path.", status: Self.statusText(for: result) ) } @@ -214,7 +214,7 @@ struct TranscriptionProviderOnboardingAdvisor { canRecommendLocal: false, title: "Use Cloud Transcription", detail: - "Local Whisper is not available on this Mac yet. Cloud transcription keeps meetings and background capture working.", + "Local Whisper is not available on this Mac yet. Cloud transcription keeps meetings and continuous background capture working.", status: result.fallbackReason ?? "Local Whisper is unavailable" ) } @@ -233,8 +233,8 @@ struct TranscriptionProviderOnboardingAdvisor { static func displayName(for mode: TranscriptionProviderKind) -> String { switch mode { - case .auto: return "Local First" - case .local: return "Local Whisper Only" + case .auto: return "Local Background First" + case .local: return "Local Background Only" case .cloud: return "Cloud Transcription" } } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index cecc15988ee..acc5731a46e 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -478,9 +478,10 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { XCTAssertTrue(decision.useCloudBackend) XCTAssertNotNil(decision.unsupportedLocalReason) + XCTAssertNil(decision.localPlan) } - func testResolvedLocalBlocksBackgroundCaptureUntilLocalFinalizeExists() { + func testResolvedLocalBackgroundRoutesToLocalWhisperAndNotCloudListen() { let decision = BackgroundTranscriptionRoutingGuard().decide( selection: TranscriptionProviderSelection(mode: .local, quality: .balanced), 
capabilities: LocalTranscriptionCapabilities( @@ -491,10 +492,11 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { ) XCTAssertFalse(decision.useCloudBackend) - XCTAssertTrue(decision.unsupportedLocalReason?.contains("backend force-processing") == true) + XCTAssertEqual(decision.localPlan?.engine, .mlxWhisper) + XCTAssertNil(decision.unsupportedLocalReason) } - func testAutoResolvedLocalBlocksBackgroundCaptureUntilLocalFinalizeExists() { + func testAutoResolvedLocalBackgroundRoutesToLocalWhisperAndNotCloudListen() { let decision = BackgroundTranscriptionRoutingGuard().decide( selection: TranscriptionProviderSelection(mode: .auto, quality: .balanced), capabilities: LocalTranscriptionCapabilities( @@ -505,7 +507,8 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { ) XCTAssertFalse(decision.useCloudBackend) - XCTAssertTrue(decision.unsupportedLocalReason?.contains("backend force-processing") == true) + XCTAssertEqual(decision.localPlan?.engine, .mlxWhisper) + XCTAssertNil(decision.unsupportedLocalReason) } func testExplicitLocalWithoutEngineDoesNotSilentlyUseCloudForBackground() { @@ -520,6 +523,7 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { XCTAssertFalse(decision.useCloudBackend) XCTAssertEqual(decision.unsupportedLocalReason, "No local transcription engine is available") + XCTAssertNil(decision.localPlan) } } @@ -535,6 +539,8 @@ final class TranscriptionProviderOnboardingAdvisorTests: XCTestCase { XCTAssertTrue(recommendation.canRecommendLocal) XCTAssertEqual(recommendation.recommendedSelection.mode, .auto) + XCTAssertTrue(recommendation.detail.contains("Continuous background transcription")) + XCTAssertFalse(recommendation.detail.contains("still requires cloud transcription")) XCTAssertTrue(recommendation.status.contains("MLX Whisper")) } @@ -549,6 +555,7 @@ final class TranscriptionProviderOnboardingAdvisorTests: XCTestCase { XCTAssertFalse(recommendation.canRecommendLocal) 
XCTAssertEqual(recommendation.recommendedSelection.mode, .cloud) + XCTAssertTrue(recommendation.detail.contains("continuous background capture")) XCTAssertTrue(recommendation.status.contains("cloud")) } } From 772d264b2eade7003a40de7ec31c2944facd8a55 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 11:56:30 +0700 Subject: [PATCH 48/58] Add local background ASR chunk queue (cherry picked from commit 938c01c9f07a1a752ab45c75a9ea367aede66255) --- .../LocalTranscription/LocalASRRuntime.swift | 333 ++++++++++++++++++ .../TranscriptionProviderPolicyTests.swift | 195 ++++++++++ 2 files changed, 528 insertions(+) diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index f1564b5ba56..866ebf1062d 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -301,6 +301,339 @@ struct LocalASRBatchTranscriber { } } +struct LocalBackgroundChunkerConfiguration: Equatable { + var sampleRate: Int = 16000 + var bytesPerSample: Int = 2 + var maxChunkDuration: TimeInterval = 15 + var minChunkDuration: TimeInterval = 1 + var overlapDuration: TimeInterval = 1 + var silenceWindowDuration: TimeInterval = 0.35 + var silenceAmplitudeThreshold: Int16 = 256 + var maxPendingChunks: Int = 4 + + var maxChunkSamples: Int { max(1, Int(maxChunkDuration * Double(sampleRate))) } + var minChunkSamples: Int { max(1, Int(minChunkDuration * Double(sampleRate))) } + var overlapSamples: Int { + min(max(0, Int(overlapDuration * Double(sampleRate))), max(0, maxChunkSamples - 1)) + } + var silenceWindowSamples: Int { max(1, Int(silenceWindowDuration * Double(sampleRate))) } +} + +struct LocalBackgroundAudioChunk: Equatable { + var sequence: Int + var audioData: Data + var startTime: Double + var endTime: Double + var sampleRate: Int + var overlappedStartTime: Double? 
+ + var duration: Double { + endTime - startTime + } +} + +struct LocalBackgroundIngestResult: Equatable { + var enqueuedChunks: [LocalBackgroundAudioChunk] + var droppedChunks: [LocalBackgroundAudioChunk] + var pendingChunkCount: Int +} + +struct LocalBackgroundASRRawChunkResult: Equatable { + var chunk: LocalBackgroundAudioChunk + var response: LocalASRTranscriptionResponse + var remappedSegments: [NormalizedTranscriptSegment] + var latencySeconds: TimeInterval + + var joinedText: String { + remappedSegments.map(\.text).joined(separator: " ") + .trimmingCharacters(in: .whitespacesAndNewlines) + } +} + +struct LocalBackgroundTranscriptSnapshot: Equatable { + var rawChunkResults: [LocalBackgroundASRRawChunkResult] + var mergedSegments: [NormalizedTranscriptSegment] + + var joinedTranscript: String { + mergedSegments.map(\.text).joined(separator: " ") + .trimmingCharacters(in: .whitespacesAndNewlines) + } +} + +struct LocalBackgroundAudioChunker { + private var configuration: LocalBackgroundChunkerConfiguration + private var buffer = Data() + private var bufferStartTime: Double? 
+ private var nextSequence = 0 + + init(configuration: LocalBackgroundChunkerConfiguration = LocalBackgroundChunkerConfiguration()) { + self.configuration = configuration + } + + mutating func append(pcmData: Data, startTime: Double) -> [LocalBackgroundAudioChunk] { + guard !pcmData.isEmpty else { return [] } + if buffer.isEmpty { + bufferStartTime = startTime + } + buffer.append(alignedPCM(pcmData)) + return emitBoundedChunks(flush: false) + } + + mutating func flush() -> [LocalBackgroundAudioChunk] { + emitBoundedChunks(flush: true) + } + + private mutating func emitBoundedChunks(flush: Bool) -> [LocalBackgroundAudioChunk] { + var chunks: [LocalBackgroundAudioChunk] = [] + while sampleCount(in: buffer) >= configuration.maxChunkSamples + || (flush && sampleCount(in: buffer) > 0) + { + let availableSamples = sampleCount(in: buffer) + let cutSamples: Int + if flush { + cutSamples = availableSamples + } else { + cutSamples = cutSample(availableSamples: availableSamples) + } + + guard cutSamples > 0, let startTime = bufferStartTime else { break } + let byteCount = cutSamples * configuration.bytesPerSample + let audioData = buffer.prefix(byteCount) + let endTime = startTime + Double(cutSamples) / Double(configuration.sampleRate) + let overlappedStartTime = nextSequence == 0 ? nil : startTime + chunks.append( + LocalBackgroundAudioChunk( + sequence: nextSequence, + audioData: Data(audioData), + startTime: startTime, + endTime: endTime, + sampleRate: configuration.sampleRate, + overlappedStartTime: overlappedStartTime + ) + ) + nextSequence += 1 + + if flush { + buffer.removeFirst(byteCount) + bufferStartTime = buffer.isEmpty ? 
nil : endTime + } else { + let retainFromSample = max(0, cutSamples - configuration.overlapSamples) + let retainFromByte = retainFromSample * configuration.bytesPerSample + buffer.removeFirst(retainFromByte) + bufferStartTime = startTime + Double(retainFromSample) / Double(configuration.sampleRate) + } + } + return chunks + } + + private func cutSample(availableSamples: Int) -> Int { + let boundedSamples = min(availableSamples, configuration.maxChunkSamples) + if let silenceBoundary = lastSilenceBoundary(before: boundedSamples) { + return silenceBoundary + } + return boundedSamples + } + + private func lastSilenceBoundary(before upperBound: Int) -> Int? { + let window = configuration.silenceWindowSamples + let minimum = configuration.minChunkSamples + guard upperBound >= minimum + window else { return nil } + + let samples = int16Samples(from: buffer) + guard samples.count >= upperBound else { return nil } + + var index = upperBound - window + while index >= minimum { + let range = index..<(index + window) + if range.allSatisfy({ + abs(Int32(samples[$0])) <= Int32(configuration.silenceAmplitudeThreshold) + }) { + return index + window + } + index -= window + } + return nil + } + + private func alignedPCM(_ data: Data) -> Data { + if data.count % configuration.bytesPerSample == 0 { + return data + } + return data.dropLast(data.count % configuration.bytesPerSample) + } + + private func sampleCount(in data: Data) -> Int { + data.count / configuration.bytesPerSample + } + + private func int16Samples(from data: Data) -> [Int16] { + data.withUnsafeBytes { rawBuffer in + Array(rawBuffer.bindMemory(to: Int16.self)) + } + } +} + +final class LocalBackgroundTranscriptionSession { + typealias RequestHandler = (LocalASRTranscriptionRequest) async throws + -> LocalASRTranscriptionResponse + + private let language: String + private let plan: LocalTranscriptionPlan + private let configuration: LocalBackgroundChunkerConfiguration + private let requestHandler: RequestHandler + 
private let temporaryDirectory: URL + private let fileManager: FileManager + private let makeRequestId: () -> String + private let now: () -> Date + private var chunker: LocalBackgroundAudioChunker + private var pendingChunks: [LocalBackgroundAudioChunk] = [] + private var merger = LocalTranscriptMerger() + private(set) var rawChunkResults: [LocalBackgroundASRRawChunkResult] = [] + private(set) var droppedChunkCount = 0 + + init( + language: String, + plan: LocalTranscriptionPlan, + configuration: LocalBackgroundChunkerConfiguration = LocalBackgroundChunkerConfiguration(), + executableURL: URL, + timeoutSeconds: TimeInterval = 60, + temporaryDirectory: URL = FileManager.default.temporaryDirectory, + fileManager: FileManager = .default, + makeRequestId: @escaping () -> String = { UUID().uuidString }, + now: @escaping () -> Date = Date.init + ) { + let client = LocalASRHelperClient(executableURL: executableURL, timeoutSeconds: timeoutSeconds) + self.language = language + self.plan = plan + self.configuration = configuration + self.requestHandler = { request in + try await client.transcribe(request) + } + self.temporaryDirectory = temporaryDirectory + self.fileManager = fileManager + self.makeRequestId = makeRequestId + self.now = now + self.chunker = LocalBackgroundAudioChunker(configuration: configuration) + } + + init( + language: String, + plan: LocalTranscriptionPlan, + configuration: LocalBackgroundChunkerConfiguration = LocalBackgroundChunkerConfiguration(), + requestHandler: @escaping RequestHandler, + temporaryDirectory: URL = FileManager.default.temporaryDirectory, + fileManager: FileManager = .default, + makeRequestId: @escaping () -> String = { UUID().uuidString }, + now: @escaping () -> Date = Date.init + ) { + self.language = language + self.plan = plan + self.configuration = configuration + self.requestHandler = requestHandler + self.temporaryDirectory = temporaryDirectory + self.fileManager = fileManager + self.makeRequestId = makeRequestId + self.now 
= now + self.chunker = LocalBackgroundAudioChunker(configuration: configuration) + } + + func append(pcmData: Data, startTime: Double) -> LocalBackgroundIngestResult { + enqueue(chunker.append(pcmData: pcmData, startTime: startTime)) + } + + func finishInput() -> LocalBackgroundIngestResult { + enqueue(chunker.flush()) + } + + func transcribeNext() async throws -> LocalBackgroundASRRawChunkResult? { + guard !pendingChunks.isEmpty else { return nil } + let chunk = pendingChunks.removeFirst() + let requestId = makeRequestId() + let audioURL = temporaryDirectory.appendingPathComponent("\(requestId).pcm") + try chunk.audioData.write(to: audioURL, options: .atomic) + defer { try? fileManager.removeItem(at: audioURL) } + + let started = now() + let response = try await requestHandler( + LocalASRTranscriptionRequest( + requestId: requestId, + audioPath: audioURL.path, + language: language, + sampleRate: configuration.sampleRate, + channels: 1, + engine: plan.engine, + model: plan.model, + fixtureSegments: nil + ) + ) + let latency = max(0, now().timeIntervalSince(started)) + let remapped = remap(response.segments, chunk: chunk) + let merged = merger.merge(remapped) + let rawResult = LocalBackgroundASRRawChunkResult( + chunk: chunk, + response: response, + remappedSegments: remapped, + latencySeconds: latency + ) + rawChunkResults.append(rawResult) + _ = merged + return rawResult + } + + func transcribePending() async throws -> [LocalBackgroundASRRawChunkResult] { + var results: [LocalBackgroundASRRawChunkResult] = [] + while let result = try await transcribeNext() { + results.append(result) + } + return results + } + + func snapshot() -> LocalBackgroundTranscriptSnapshot { + LocalBackgroundTranscriptSnapshot( + rawChunkResults: rawChunkResults, + mergedSegments: merger.segments + ) + } + + private func enqueue(_ chunks: [LocalBackgroundAudioChunk]) -> LocalBackgroundIngestResult { + guard !chunks.isEmpty else { + return LocalBackgroundIngestResult( + enqueuedChunks: [], + 
droppedChunks: [], + pendingChunkCount: pendingChunks.count + ) + } + + pendingChunks.append(contentsOf: chunks) + var dropped: [LocalBackgroundAudioChunk] = [] + if pendingChunks.count > configuration.maxPendingChunks { + let overflow = pendingChunks.count - configuration.maxPendingChunks + dropped = Array(pendingChunks.prefix(overflow)) + pendingChunks.removeFirst(overflow) + droppedChunkCount += overflow + } + + return LocalBackgroundIngestResult( + enqueuedChunks: chunks, + droppedChunks: dropped, + pendingChunkCount: pendingChunks.count + ) + } + + private func remap( + _ segments: [LocalASRTranscriptSegment], + chunk: LocalBackgroundAudioChunk + ) -> [NormalizedTranscriptSegment] { + segments.map { segment in + var normalized = segment.normalized() + normalized.start = chunk.startTime + segment.start + normalized.end = chunk.startTime + segment.end + normalized.segmentId = normalized.segmentId ?? "local-bg-\(chunk.sequence)-\(segment.start)" + return normalized + } + } +} + struct PTTBatchTranscriptionResult: Equatable { var provider: TranscriptionProviderKind var transcript: String? 
diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index acc5731a46e..11ef8339a06 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -317,6 +317,189 @@ final class LocalASRRuntimeTests: XCTestCase { XCTAssertEqual(result.map(\.text), ["hello local whisper"]) } + func testBackgroundChunkerUsesSilenceBoundaryAndOverlap() { + var chunker = LocalBackgroundAudioChunker( + configuration: LocalBackgroundChunkerConfiguration( + sampleRate: 10, + bytesPerSample: 2, + maxChunkDuration: 1.0, + minChunkDuration: 0.4, + overlapDuration: 0.2, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 2, + maxPendingChunks: 4 + ) + ) + let samples: [Int16] = [20, 20, 20, 20, 0, 0, 20, 20, 20, 20, 20, 20] + + let chunks = chunker.append(pcmData: pcm(samples), startTime: 3.0) + let final = chunker.flush() + + XCTAssertEqual(chunks.count, 1) + XCTAssertEqual(chunks[0].sequence, 0) + XCTAssertEqual(chunks[0].startTime, 3.0, accuracy: 0.001) + XCTAssertEqual(chunks[0].endTime, 3.6, accuracy: 0.001) + XCTAssertLessThanOrEqual(chunks[0].duration, 1.0) + let finalChunk = tryUnwrap(final.first) + XCTAssertEqual(finalChunk.startTime, 3.4, accuracy: 0.001) + XCTAssertEqual(tryUnwrap(finalChunk.overlappedStartTime), 3.4, accuracy: 0.001) + } + + func testBackgroundSessionAppliesBackpressureToPendingChunks() { + let session = LocalBackgroundTranscriptionSession( + language: "en", + plan: LocalTranscriptionPlan(engine: .mlxWhisper, model: .small, quality: .balanced), + configuration: LocalBackgroundChunkerConfiguration( + sampleRate: 10, + bytesPerSample: 2, + maxChunkDuration: 1, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 1, + maxPendingChunks: 2 + ), + requestHandler: { request in + LocalASRTranscriptionResponse( + requestId: request.requestId, + engine: 
request.engine, + model: request.model, + language: request.language, + segments: [], + fixture: false + ) + } + ) + + let result = session.append(pcmData: pcm(Array(repeating: 10, count: 31)), startTime: 0) + + XCTAssertEqual(result.enqueuedChunks.count, 3) + XCTAssertEqual(result.droppedChunks.map(\.sequence), [0]) + XCTAssertEqual(result.pendingChunkCount, 2) + XCTAssertEqual(session.droppedChunkCount, 1) + } + + func testBackgroundSessionRemapsRawChunkResultsAndMergesOverlap() async throws { + var capturedRequests: [LocalASRTranscriptionRequest] = [] + var dates = [ + Date(timeIntervalSince1970: 10), + Date(timeIntervalSince1970: 10.25), + Date(timeIntervalSince1970: 11), + Date(timeIntervalSince1970: 11.5), + ] + let tempDirectory = FileManager.default.temporaryDirectory.appendingPathComponent( + "LocalBackgroundSessionTests-\(UUID().uuidString)", + isDirectory: true + ) + try FileManager.default.createDirectory(at: tempDirectory, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: tempDirectory) } + + let session = LocalBackgroundTranscriptionSession( + language: "en", + plan: LocalTranscriptionPlan(engine: .mlxWhisper, model: .small, quality: .balanced), + configuration: LocalBackgroundChunkerConfiguration( + sampleRate: 10, + bytesPerSample: 2, + maxChunkDuration: 1, + minChunkDuration: 0.5, + overlapDuration: 0.2, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 1, + maxPendingChunks: 4 + ), + requestHandler: { request in + capturedRequests.append(request) + XCTAssertTrue(FileManager.default.fileExists(atPath: request.audioPath)) + let sequence = capturedRequests.count - 1 + let segments: [LocalASRTranscriptSegment] + if sequence == 0 { + segments = [ + LocalASRTranscriptSegment( + id: nil, + speaker: 0, + text: "hello local whisper", + start: 0, + end: 1.0 + ) + ] + } else { + segments = [ + LocalASRTranscriptSegment( + id: nil, + speaker: 0, + text: "whisper runs here", + start: 0.0, + end: 0.9 + ) + ] + } + 
return LocalASRTranscriptionResponse( + requestId: request.requestId, + engine: request.engine, + model: request.model, + language: request.language, + segments: segments, + fixture: false + ) + }, + temporaryDirectory: tempDirectory, + makeRequestId: { + "background-\(capturedRequests.count)" + }, + now: { + dates.removeFirst() + } + ) + + _ = session.append(pcmData: pcm(Array(repeating: 10, count: 17)), startTime: 5) + _ = session.finishInput() + let results = try await session.transcribePending() + let snapshot = session.snapshot() + + XCTAssertEqual(results.count, 2) + XCTAssertEqual(results[0].remappedSegments[0].start, 5.0, accuracy: 0.001) + XCTAssertEqual(results[1].chunk.startTime, 5.8, accuracy: 0.001) + XCTAssertEqual(results[1].remappedSegments[0].start, 5.8, accuracy: 0.001) + XCTAssertEqual(results[0].latencySeconds, 0.25, accuracy: 0.001) + XCTAssertEqual(snapshot.rawChunkResults.count, 2) + XCTAssertEqual(snapshot.joinedTranscript, "hello local whisper runs here") + XCTAssertEqual(capturedRequests.map(\.sampleRate), [10, 10]) + XCTAssertFalse( + FileManager.default.fileExists( + atPath: tempDirectory.appendingPathComponent("background-0.pcm").path) + ) + } + + func testBackgroundPipelineCanExerciseRealHelperWhenRuntimeAvailable() async throws { + guard let helperURL = LocalASRHelperLocator.defaultExecutableURL() else { + throw XCTSkip("Local ASR helper executable is not available") + } + let engines = LocalASRHelperLocator.detectedEngines(executableURL: helperURL) + guard let engine = engines.first else { + throw XCTSkip("No real local ASR engine/model is available") + } + + let session = LocalBackgroundTranscriptionSession( + language: "en", + plan: LocalTranscriptionPlan(engine: engine, model: .base, quality: .fast), + configuration: LocalBackgroundChunkerConfiguration( + maxChunkDuration: 0.5, + minChunkDuration: 0.25, + overlapDuration: 0, + maxPendingChunks: 2 + ), + executableURL: helperURL, + timeoutSeconds: 20 + ) + + _ = 
session.append(pcmData: pcm(Array(repeating: 0, count: 16000)), startTime: 0) + _ = session.finishInput() + let results = try await session.transcribePending() + + XCTAssertFalse(results.isEmpty) + XCTAssertTrue(results.allSatisfy { !$0.response.fixture }) + } + func testDetectedEnginesUsesHelperCapabilityProbe() throws { let helper = try makeExecutableHelper( body: @@ -354,6 +537,18 @@ final class LocalASRRuntimeTests: XCTestCase { } return helper } + + private func pcm(_ samples: [Int16]) -> Data { + samples.withUnsafeBufferPointer { Data(buffer: $0) } + } + + private func tryUnwrap(_ value: T?, file: StaticString = #filePath, line: UInt = #line) -> T { + guard let value else { + XCTFail("Expected non-nil value", file: file, line: line) + fatalError("Expected non-nil value") + } + return value + } } final class PTTBatchTranscriptionRouterTests: XCTestCase { From f0ab8060e5c19c5ac752e4c416e7efd885049212 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 12:05:07 +0700 Subject: [PATCH 49/58] feat(desktop): wire local transcription state, ASR runtime integration, and storage (cherry picked from commit 6cfc3f9d3c507b567d49a478eb911eaafd34e1d5) --- desktop/Desktop/Sources/AppState.swift | 359 +++++++++++++++++- .../LocalTranscription/LocalASRRuntime.swift | 20 + .../Rewind/Core/TranscriptionStorage.swift | 30 ++ .../TranscriptionProviderPolicyTests.swift | 61 +++ 4 files changed, 453 insertions(+), 17 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 188fc6a945d..324a4c72e37 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -208,6 +208,10 @@ class AppState: ObservableObject { // Transcription services private var audioCaptureService: AudioCaptureService? private var transcriptionService: TranscriptionService? + private var localBackgroundSession: LocalBackgroundTranscriptionSession? + private var localBackgroundASRTask: Task? 
+ private(set) var localBackgroundState: LocalBackgroundSessionState? + private var localBackgroundSampleCursor: Int64 = 0 private var systemAudioCaptureService: Any? // SystemAudioCaptureService (macOS 14.4+) private var audioMixer: AudioMixer? private var vadGateService: VADGateService? @@ -1282,19 +1286,23 @@ class AppState: ObservableObject { availableEngines: { LocalASRHelperLocator.detectedEngines() } ).detect() ) - if !backgroundRouting.useCloudBackend { + if case .unavailable = backgroundRouting.route { let message = backgroundRouting.unsupportedLocalReason ?? "Local background transcription is selected and will not use the cloud listen path." log("Transcription: \(message)") showAlert( title: "Local Background Transcription", - message: - "Local background transcription is selected, so Omi will not start a cloud /v4/listen session. The local background session manager is not enabled in this build yet." + message: message ) return } + if let localPlan = backgroundRouting.localPlan { + startLocalBackgroundTranscription(source: effectiveSource, plan: localPlan) + return + } + // For BLE device, check if device is connected if effectiveSource == .bleDevice { guard DeviceProvider.shared.isConnected else { @@ -1342,17 +1350,7 @@ class AppState: ObservableObject { // Initialize system audio capture if supported (macOS 14.4+) // Can be disabled via: defaults write com.omi.desktop-dev disableSystemAudioCapture -bool true // or: defaults write com.omi.computer-macos disableSystemAudioCapture -bool true - let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") - if systemAudioDisabled { - log( - "Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)" - ) - } else if #available(macOS 14.4, *) { - systemAudioCaptureService = SystemAudioCaptureService() - log("Transcription: System audio capture initialized (macOS 14.4+)") - } else { - log("Transcription: System audio capture not available (requires 
macOS 14.4+)") - } + initializeOptionalSystemAudioCapture() } // For BLE device, BleAudioService will be used in startAudioCapture @@ -1452,6 +1450,110 @@ class AppState: ObservableObject { } } + /// Start local background transcription without creating a backend `/v4/listen` session. + private func startLocalBackgroundTranscription(source: AudioSource, plan: LocalTranscriptionPlan) { + guard source == .microphone else { + showAlert( + title: "Local Background Transcription", + message: "Local background transcription currently supports the Mac microphone source only." + ) + return + } + + guard AudioCaptureService.checkPermission() else { + requestMicrophonePermission() + return + } + + guard let executableURL = LocalASRHelperLocator.defaultExecutableURL() else { + let message = "Local transcription helper is not available." + log("Transcription: \(message)") + localBackgroundState = .failed + showAlert(title: "Local Background Transcription", message: message) + return + } + + let effectiveLanguage = AssistantSettings.shared.effectiveTranscriptionLanguage + localBackgroundSession = LocalBackgroundTranscriptionSession( + language: effectiveLanguage, + plan: plan, + executableURL: executableURL + ) + localBackgroundState = .recording + localBackgroundSampleCursor = 0 + currentConversationSource = .desktop + recordingInputDeviceName = AudioCaptureService.getCurrentMicrophoneName() + + audioCaptureService = AudioCaptureService() + audioMixer = AudioMixer() + vadGateService = nil + initializeOptionalSystemAudioCapture() + + isTranscribing = true + recordingGeneration &+= 1 + AssistantSettings.shared.transcriptionEnabled = true + audioSource = source + currentTranscript = "" + speakerSegments = [] + totalSegmentCount = 0 + totalWordCount = 0 + speakerSegmentReducer.reset() + liveSpeakerPersonMap = [:] + LiveTranscriptMonitor.shared.clear() + recordingStartTime = Date() + AudioLevelMonitor.shared.reset() + RecordingTimer.shared.start() + + Task { + do { + let sessionId 
= try await TranscriptionStorage.shared.startSession( + source: currentConversationSource.rawValue, + language: effectiveLanguage, + timezone: TimeZone.current.identifier, + inputDeviceName: recordingInputDeviceName + ) + await MainActor.run { + self.currentSessionId = sessionId + LiveNotesMonitor.shared.startSession(sessionId: sessionId) + } + log("Transcription: Created local background DB session \(sessionId)") + } catch { + logError("Transcription: Failed to create local background DB session", error: error) + } + } + + maxRecordingTimer = Timer.scheduledTimer( + withTimeInterval: maxRecordingDuration, repeats: false + ) { [weak self] _ in + Task { @MainActor in + guard let self = self, self.isTranscribing else { return } + log("Transcription: 4-hour limit reached - stopping local background session") + self.stopTranscription() + } + } + + AnalyticsManager.shared.transcriptionStarted() + log( + "Transcription: Starting local background transcription with \(plan.engine.rawValue)/\(plan.model.rawValue)" + ) + + Task { @MainActor in + await self.startAudioCapture(source: source) + } + } + + private func initializeOptionalSystemAudioCapture() { + let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") + if systemAudioDisabled { + log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") + } else if #available(macOS 14.4, *) { + systemAudioCaptureService = SystemAudioCaptureService() + log("Transcription: System audio capture initialized (macOS 14.4+)") + } else { + log("Transcription: System audio capture not available (requires macOS 14.4+)") + } + } + /// Start audio capture and pipe to transcription service /// - Parameter source: Audio source to capture from private func startAudioCapture(source: AudioSource = .microphone) async { @@ -1481,9 +1583,11 @@ class AppState: ObservableObject { } // Start the mixer — it sums mic + system into a mono stream and forwards it to - // the 
transcription WebSocket. + // the active transcription provider. audioMixer?.start { [weak self] monoMixed in - self?.transcriptionService?.sendAudio(monoMixed) + Task { @MainActor in + self?.handleMixedBackgroundAudio(monoMixed) + } } do { @@ -1531,6 +1635,102 @@ class AppState: ObservableObject { } } + private func handleMixedBackgroundAudio(_ monoMixed: Data) { + if let localBackgroundSession { + let startTime = Double(localBackgroundSampleCursor) / 16_000.0 + localBackgroundSampleCursor += Int64(monoMixed.count / 2) + let ingest = localBackgroundSession.append(pcmData: monoMixed, startTime: startTime) + if !ingest.droppedChunks.isEmpty { + log("Transcription: Local background dropped \(ingest.droppedChunks.count) stale chunks") + } + drainLocalBackgroundASRQueue() + return + } + + transcriptionService?.sendAudio(monoMixed) + } + + private func drainLocalBackgroundASRQueue() { + guard localBackgroundASRTask == nil, let session = localBackgroundSession else { return } + localBackgroundASRTask = Task { [weak self, session] in + do { + while let result = try await session.transcribeNext() { + await MainActor.run { + self?.applyLocalBackgroundSegments(result.remappedSegments) + } + } + } catch { + await MainActor.run { + self?.localBackgroundState = .failed + logError("Transcription: Local background ASR failed", error: error) + } + } + await MainActor.run { + self?.localBackgroundASRTask = nil + } + } + } + + private func applyLocalBackgroundSegments(_ normalizedSegments: [NormalizedTranscriptSegment]) { + let incomingSegments = normalizedSegments.compactMap { segment -> SpeakerSegment? 
in + guard !segment.text.isEmpty else { return nil } + return SpeakerSegment( + segmentId: segment.segmentId, + speaker: segment.speaker, + text: segment.text, + start: segment.start, + end: segment.end, + isUser: segment.isUser, + personId: segment.personId, + translations: segment.translations.map { + SegmentTranslation(lang: $0.lang, text: $0.text) + } + ) + } + + let applyResult = speakerSegmentReducer.apply(incomingSegments) + speakerSegments = speakerSegmentReducer.segments + totalSegmentCount = speakerSegmentReducer.totalSegmentCount + totalWordCount = speakerSegmentReducer.totalWordCount + LiveTranscriptMonitor.shared.updateSegments(speakerSegments) + + if let sessionId = currentSessionId { + Task { + await self.persistLocalBackgroundSegments(normalizedSegments, sessionId: sessionId) + } + } + + if applyResult.added > 0 || applyResult.updated > 0 { + log( + "Transcript [LOCAL UPSERT] Added: \(applyResult.added), updated: \(applyResult.updated)" + ) + } + } + + private func persistLocalBackgroundSegments( + _ normalizedSegments: [NormalizedTranscriptSegment], + sessionId: Int64 + ) async { + for segment in normalizedSegments where !segment.text.isEmpty { + do { + try await TranscriptionStorage.shared.upsertSegment( + sessionId: sessionId, + backendSegmentId: segment.segmentId, + speaker: segment.speaker, + text: segment.text, + startTime: segment.start, + endTime: segment.end, + isUser: segment.isUser, + personId: segment.personId, + speakerLabel: segment.speakerLabel + ) + } catch { + logError("Transcription: Failed to persist local segment to DB", error: error) + await RewindDatabase.shared.reportQueryError(error) + } + } + } + /// Fall back from a silent Bluetooth mic to the built-in microphone. /// Triggered by `AudioCaptureService.onSilentMicDetected`. @MainActor @@ -1648,6 +1848,11 @@ class AppState: ObservableObject { /// triggers conversation processing on the backend side. 
We also call force-process to ensure /// the conversation is finalized, preventing the retry service from creating duplicates. func stopTranscription() { + if localBackgroundSession != nil { + stopLocalBackgroundTranscription() + return + } + // Capture session metadata BEFORE clearing state (clearTranscriptionState sets sessionId to nil) let capturedSessionId = currentSessionId let capturedStartTime = recordingStartTime @@ -1740,6 +1945,105 @@ class AppState: ObservableObject { } } + private func stopLocalBackgroundTranscription() { + let sessionId = currentSessionId + let generationAtStop = recordingGeneration + localBackgroundState = .transcribingBacklog + + stopAudioCapture() + _ = localBackgroundSession?.finishInput() + drainLocalBackgroundASRQueue() + + Task { + await waitForLocalBackgroundBacklog(timeoutSeconds: 20) + + guard self.recordingGeneration == generationAtStop else { + log("Transcription: New recording started while local background finalized") + return + } + + await MainActor.run { + self.localBackgroundState = .finalizing + self.isSavingConversation = true + } + + if let sessionId, let session = await MainActor.run(body: { self.localBackgroundSession }) { + let snapshot = session.snapshot() + await MainActor.run { + self.applyLocalBackgroundSegments(snapshot.mergedSegments) + } + + do { + await self.persistLocalBackgroundSegments(snapshot.mergedSegments, sessionId: sessionId) + let title = Self.localConversationTitle(from: snapshot.joinedTranscript) + try await TranscriptionStorage.shared.completeLocalSession( + id: sessionId, + title: title, + overview: snapshot.joinedTranscript + ) + log( + "Transcription: Finalized local background session \(sessionId) with \(snapshot.mergedSegments.count) segments" + ) + } catch { + await MainActor.run { + self.localBackgroundState = .failed + } + logError("Transcription: Failed to finalize local background session", error: error) + } + } + + await MainActor.run { + self.localBackgroundState = 
self.localBackgroundState == .failed ? .failed : .finalized + self.isSavingConversation = false + self.clearLocalBackgroundTranscriptionState() + } + + await loadConversations() + } + } + + private func waitForLocalBackgroundBacklog(timeoutSeconds: TimeInterval) async { + let deadline = Date().addingTimeInterval(timeoutSeconds) + while Date() < deadline { + drainLocalBackgroundASRQueue() + if localBackgroundASRTask == nil && (localBackgroundSession?.pendingChunkCount ?? 0) == 0 { + return + } + try? await Task.sleep(nanoseconds: 100_000_000) + } + log("Transcription: Local background ASR backlog timed out") + } + + private func clearLocalBackgroundTranscriptionState() { + log( + "Transcription: Final local segments count: \(totalSegmentCount) (in-memory: \(speakerSegments.count)), words: \(totalWordCount)" + ) + LiveNotesMonitor.shared.endSession() + LiveNotesMonitor.shared.clear() + LiveTranscriptMonitor.shared.clear() + speakerSegments = [] + liveSpeakerPersonMap = [:] + recordingStartTime = nil + currentSessionId = nil + localBackgroundSession = nil + localBackgroundASRTask = nil + localBackgroundSampleCursor = 0 + AnalyticsManager.shared.transcriptionStopped(wordCount: totalWordCount) + totalSegmentCount = 0 + totalWordCount = 0 + speakerSegmentReducer.reset() + currentTranscript = "" + log("Transcription: Stopped local background session") + } + + nonisolated static func localConversationTitle(from transcript: String) -> String { + let normalized = transcript + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + guard !normalized.isEmpty else { return "Local transcription" } + return String(normalized.prefix(60)) + } + /// Reconcile a local session by checking if a matching conversation exists on the backend. /// If found, marks the session as completed. Otherwise leaves it as pendingUpload for retry. 
private func reconcileSession(sessionId: Int64, startTime: Date) async { @@ -1772,6 +2076,11 @@ class AppState: ObservableObject { /// Finish the current conversation and keep recording for a new one. /// Disconnects the WebSocket (triggers backend conversation processing) then reconnects. func finishConversation() async -> FinishConversationResult { + if localBackgroundSession != nil { + stopLocalBackgroundTranscription() + return .saved + } + guard totalSegmentCount > 0 || !speakerSegments.isEmpty else { log("Transcription: No segments to finish") return .discarded @@ -2014,6 +2323,7 @@ class AppState: ObservableObject { isLoadingConversations = true conversationsError = nil + var cachedLocalConversations: [ServerConversation] = [] // Step 1: Load from local cache first (instant display) // Use timeout to avoid blocking UI if database is initializing (e.g. recovery) @@ -2037,6 +2347,7 @@ class AppState: ObservableObject { } if !cachedConversations.isEmpty { + cachedLocalConversations = cachedConversations conversations = cachedConversations log("Conversations: Loaded \(cachedConversations.count) from local cache (instant)") @@ -2083,7 +2394,7 @@ class AppState: ObservableObject { do { let fetchedConversations = try await conversationsTask - conversations = fetchedConversations + conversations = mergeLocalOnlyConversations(cachedLocalConversations, with: fetchedConversations) log( "Conversations: Refreshed \(fetchedConversations.count) from API (starred=\(showStarredOnly), date=\(selectedDateFilter?.description ?? 
"nil"))" ) @@ -2239,6 +2550,20 @@ class AppState: ObservableObject { return result } + private func mergeLocalOnlyConversations( + _ localConversations: [ServerConversation], + with fetchedConversations: [ServerConversation] + ) -> [ServerConversation] { + var merged = fetchedConversations + let fetchedIds = Set(fetchedConversations.map(\.id)) + let localOnly = localConversations.filter { + $0.id.hasPrefix("local-") && !fetchedIds.contains($0.id) + } + merged.append(contentsOf: localOnly) + merged.sort { $0.createdAt > $1.createdAt } + return merged + } + /// Update the starred status of a conversation locally func setConversationStarred(_ conversationId: String, starred: Bool) { if let index = conversations.firstIndex(where: { $0.id == conversationId }) { diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index 866ebf1062d..1dc4e439e78 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -360,6 +360,25 @@ struct LocalBackgroundTranscriptSnapshot: Equatable { } } +enum LocalBackgroundSessionState: String, Equatable { + case recording + case transcribingBacklog = "transcribing_backlog" + case finalizing + case finalized + case failed +} + +struct BackgroundConversationFinalizationPolicy { + enum Owner: Equatable { + case cloudBackend + case localBackground + } + + func shouldForceProcessBackend(owner: Owner) -> Bool { + owner == .cloudBackend + } +} + struct LocalBackgroundAudioChunker { private var configuration: LocalBackgroundChunkerConfiguration private var buffer = Data() @@ -490,6 +509,7 @@ final class LocalBackgroundTranscriptionSession { private var merger = LocalTranscriptMerger() private(set) var rawChunkResults: [LocalBackgroundASRRawChunkResult] = [] private(set) var droppedChunkCount = 0 + var pendingChunkCount: Int { pendingChunks.count } init( language: String, 
diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift index 7b585fd49c6..6cf2fddcabb 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift @@ -117,6 +117,36 @@ actor TranscriptionStorage { log("TranscriptionStorage: Completed session \(id) (backendId: \(backendId))") } + /// Mark a locally-transcribed session as completed without creating or reconciling a backend conversation. + func completeLocalSession( + id: Int64, + finishedAt: Date = Date(), + title: String, + overview: String, + category: String = "other" + ) async throws { + let db = try await ensureInitialized() + + try await db.write { database in + guard var record = try TranscriptionSessionRecord.fetchOne(database, key: id) else { + throw TranscriptionStorageError.sessionNotFound + } + + record.finishedAt = finishedAt + record.status = .completed + record.backendId = record.backendId ?? "local-\(id)" + record.backendSynced = true + record.title = title + record.overview = overview + record.category = category + record.conversationStatus = .completed + record.updatedAt = Date() + try record.update(database) + } + + log("TranscriptionStorage: Completed local session \(id)") + } + /// Mark session as failed with error. /// No-op if the session is already completed (prevents race with concurrent completion). 
func markSessionFailed(id: Int64, error: String) async throws { diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 11ef8339a06..5b3cbbb2524 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -722,6 +722,67 @@ final class BackgroundTranscriptionRoutingGuardTests: XCTestCase { } } +final class LocalBackgroundLifecycleTests: XCTestCase { + func testLocalBackgroundStatesExposeLifecycleForDebugging() { + XCTAssertEqual( + Set(LocalBackgroundSessionState.allCasesForTest.map(\.rawValue)), + ["recording", "transcribing_backlog", "finalizing", "finalized", "failed"] + ) + } + + func testLocalFinalizationDoesNotForceProcessBackend() { + let policy = BackgroundConversationFinalizationPolicy() + + XCTAssertFalse(policy.shouldForceProcessBackend(owner: .localBackground)) + XCTAssertTrue(policy.shouldForceProcessBackend(owner: .cloudBackend)) + } + + func testLocalConversationTitleUsesDeterministicTranscriptPrefix() { + XCTAssertEqual(AppState.localConversationTitle(from: " hello local background "), "hello local background") + XCTAssertEqual(AppState.localConversationTitle(from: " "), "Local transcription") + } + + func testLocalCompletedSessionConvertsToConversationPath() { + let session = TranscriptionSessionRecord( + id: 42, + startedAt: Date(timeIntervalSince1970: 100), + finishedAt: Date(timeIntervalSince1970: 130), + source: ConversationSource.desktop.rawValue, + language: "en", + status: .completed, + backendId: "local-42", + backendSynced: true, + createdAt: Date(timeIntervalSince1970: 100), + updatedAt: Date(timeIntervalSince1970: 130), + title: "Local background", + overview: "hello local background", + category: "other", + conversationStatus: .completed + ) + let segment = TranscriptionSegmentRecord( + sessionId: 42, + speaker: 0, + text: "hello local background", + startTime: 0, + 
endTime: 2, + segmentOrder: 0, + segmentId: "local-bg-0-0.0" + ) + + let conversation = session.toServerConversation(segments: [segment]) + + XCTAssertEqual(conversation?.id, "local-42") + XCTAssertEqual(conversation?.status, .completed) + XCTAssertEqual(conversation?.transcriptSegments.first?.text, "hello local background") + } +} + +extension LocalBackgroundSessionState { + fileprivate static var allCasesForTest: [LocalBackgroundSessionState] { + [.recording, .transcribingBacklog, .finalizing, .finalized, .failed] + } +} + final class TranscriptionProviderOnboardingAdvisorTests: XCTestCase { func testEligibleNativeAppleSiliconRecommendsLocalFirst() { let recommendation = TranscriptionProviderOnboardingAdvisor().recommendation( From e7f28734f2647561d87ce90b253a0ccd455f2f0c Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 12:20:41 +0700 Subject: [PATCH 50/58] feat(desktop): add local background ASR harness (cherry picked from commit 660155a70827eded246e333a77abe02ac1a2ee09) --- .../LocalTranscription/LocalASRRuntime.swift | 2 +- .../TranscriptComparison.swift | 64 ++++++ .../LocalBackgroundSmokeHarnessTests.swift | 208 ++++++++++++++++++ .../Tests/TranscriptComparisonTests.swift | 36 +++ .../TranscriptionProviderPolicyTests.swift | 2 + desktop/local-asr-helper/README.md | 40 ++++ .../local_background_asr_harness.py | 158 +++++++++++++ 7 files changed, 509 insertions(+), 1 deletion(-) create mode 100644 desktop/Desktop/Sources/LocalTranscription/TranscriptComparison.swift create mode 100644 desktop/Desktop/Tests/LocalBackgroundSmokeHarnessTests.swift create mode 100644 desktop/Desktop/Tests/TranscriptComparisonTests.swift create mode 100755 desktop/local-asr-helper/local_background_asr_harness.py diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index 1dc4e439e78..e8471db1744 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ 
b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -648,7 +648,7 @@ final class LocalBackgroundTranscriptionSession { var normalized = segment.normalized() normalized.start = chunk.startTime + segment.start normalized.end = chunk.startTime + segment.end - normalized.segmentId = normalized.segmentId ?? "local-bg-\(chunk.sequence)-\(segment.start)" + normalized.segmentId = "local-bg-\(chunk.sequence)-\(segment.id ?? "\(segment.start)")" return normalized } } diff --git a/desktop/Desktop/Sources/LocalTranscription/TranscriptComparison.swift b/desktop/Desktop/Sources/LocalTranscription/TranscriptComparison.swift new file mode 100644 index 00000000000..87ac201779a --- /dev/null +++ b/desktop/Desktop/Sources/LocalTranscription/TranscriptComparison.swift @@ -0,0 +1,64 @@ +import Foundation + +struct TranscriptComparison { + static func normalizedWords(_ value: String) -> [String] { + normalizedText(value).split(separator: " ").map(String.init) + } + + static func normalizedCharacters(_ value: String) -> [Character] { + Array(normalizedText(value).replacingOccurrences(of: " ", with: "")) + } + + static func normalizedText(_ value: String) -> String { + value.lowercased() + .replacingOccurrences( + of: #"[^a-z0-9\s]"#, + with: " ", + options: .regularExpression + ) + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + static func wordErrorRate(reference: String, hypothesis: String) -> Double { + errorRate(reference: normalizedWords(reference), hypothesis: normalizedWords(hypothesis)) + } + + static func characterErrorRate(reference: String, hypothesis: String) -> Double { + errorRate( + reference: normalizedCharacters(reference), + hypothesis: normalizedCharacters(hypothesis) + ) + } + + private static func errorRate(reference: [T], hypothesis: [T]) -> Double { + guard !reference.isEmpty else { + return hypothesis.isEmpty ? 
0 : 1 + } + return Double(editDistance(reference, hypothesis)) / Double(reference.count) + } + + private static func editDistance(_ reference: [T], _ hypothesis: [T]) -> Int { + var previous = Array(0...hypothesis.count) + var current = Array(repeating: 0, count: hypothesis.count + 1) + + for referenceIndex in 1...reference.count { + current[0] = referenceIndex + for hypothesisIndex in 1...hypothesis.count { + if reference[referenceIndex - 1] == hypothesis[hypothesisIndex - 1] { + current[hypothesisIndex] = previous[hypothesisIndex - 1] + } else { + current[hypothesisIndex] = + min( + previous[hypothesisIndex], + current[hypothesisIndex - 1], + previous[hypothesisIndex - 1] + ) + 1 + } + } + swap(&previous, ¤t) + } + + return previous[hypothesis.count] + } +} diff --git a/desktop/Desktop/Tests/LocalBackgroundSmokeHarnessTests.swift b/desktop/Desktop/Tests/LocalBackgroundSmokeHarnessTests.swift new file mode 100644 index 00000000000..d151177dfd1 --- /dev/null +++ b/desktop/Desktop/Tests/LocalBackgroundSmokeHarnessTests.swift @@ -0,0 +1,208 @@ +import Foundation +import XCTest + +@testable import Omi_Computer + +final class LocalBackgroundSmokeHarnessTests: XCTestCase { + func testRunHarness() async throws { + let environment = ProcessInfo.processInfo.environment + guard environment["OMI_LOCAL_BACKGROUND_ASR_HARNESS"] == "1" else { + throw XCTSkip("Set OMI_LOCAL_BACKGROUND_ASR_HARNESS=1 to run the harness") + } + guard let pcmPath = environment["OMI_LOCAL_BACKGROUND_ASR_PCM_PATH"], !pcmPath.isEmpty else { + XCTFail("OMI_LOCAL_BACKGROUND_ASR_PCM_PATH is required") + return + } + guard let outputPath = environment["OMI_LOCAL_BACKGROUND_ASR_OUTPUT_PATH"], !outputPath.isEmpty + else { + XCTFail("OMI_LOCAL_BACKGROUND_ASR_OUTPUT_PATH is required") + return + } + + let mode = environment["OMI_LOCAL_BACKGROUND_ASR_MODE"] ?? "fixture" + let sampleRate = Int(environment["OMI_LOCAL_BACKGROUND_ASR_SAMPLE_RATE"] ?? "16000") ?? 
16000 + let model = LocalTranscriptionModel(rawValue: environment["OMI_LOCAL_ASR_MODEL"] ?? "base") + ?? .base + let engine = LocalTranscriptionEngine(rawValue: environment["OMI_LOCAL_ASR_ENGINE"] ?? "") + ?? .mlxWhisper + let quality = TranscriptionQualityPreset(rawValue: environment["OMI_LOCAL_ASR_QUALITY"] ?? "fast") + ?? .fast + let language = environment["OMI_LOCAL_ASR_LANGUAGE"] ?? "en" + let audioData = try Data(contentsOf: URL(fileURLWithPath: pcmPath)) + let fixtureText = environment["OMI_LOCAL_BACKGROUND_ASR_FIXTURE_TEXT"] + ?? "hello local background transcription" + let start = Date() + var requestCounter = 0 + + let session: LocalBackgroundTranscriptionSession + let plan = LocalTranscriptionPlan(engine: engine, model: model, quality: quality) + let configuration = LocalBackgroundChunkerConfiguration( + sampleRate: sampleRate, + bytesPerSample: 2, + maxChunkDuration: Double(environment["OMI_LOCAL_BACKGROUND_ASR_MAX_CHUNK_SECONDS"] ?? "2") + ?? 2, + minChunkDuration: Double(environment["OMI_LOCAL_BACKGROUND_ASR_MIN_CHUNK_SECONDS"] ?? "0.4") + ?? 0.4, + overlapDuration: Double(environment["OMI_LOCAL_BACKGROUND_ASR_OVERLAP_SECONDS"] ?? "0.25") + ?? 0.25, + silenceWindowDuration: 0.25, + silenceAmplitudeThreshold: 256, + maxPendingChunks: 64 + ) + + if mode == "local" { + guard let helperURL = LocalASRHelperLocator.defaultExecutableURL() else { + XCTFail("Local ASR helper is unavailable; run fixture mode or set OMI_LOCAL_ASR_HELPER_PATH") + return + } + session = LocalBackgroundTranscriptionSession( + language: language, + plan: plan, + configuration: configuration, + executableURL: helperURL, + timeoutSeconds: Double(environment["OMI_LOCAL_ASR_TIMEOUT_SECONDS"] ?? "120") ?? 120 + ) + } else { + session = LocalBackgroundTranscriptionSession( + language: language, + plan: plan, + configuration: configuration, + requestHandler: { request in + let current = requestCounter + requestCounter += 1 + let bytes = (try? 
Data(contentsOf: URL(fileURLWithPath: request.audioPath)).count) ?? 0 + let duration = Double(bytes) / Double(max(1, sampleRate * 2)) + let words = fixtureText.split(separator: " ").map(String.init) + let word = words.isEmpty ? "fixture" : words[current % words.count] + return LocalASRTranscriptionResponse( + requestId: request.requestId, + engine: request.engine, + model: request.model, + language: request.language, + segments: [ + LocalASRTranscriptSegment( + id: "fixture-\(current)", + speaker: 0, + text: "\(word) chunk \(current)", + start: 0, + end: max(0.01, duration) + ) + ], + fixture: true + ) + } + ) + } + + let inputStepBytes = max(2, sampleRate * 2 / 2) + var offset = 0 + var startTime = 0.0 + var droppedChunkCount = 0 + while offset < audioData.count { + let end = min(audioData.count, offset + inputStepBytes) + let result = session.append(pcmData: audioData[offset.. 0 ? elapsed / audioDuration : nil, + droppedChunkCount: droppedChunkCount + session.droppedChunkCount, + chunkResults: snapshot.rawChunkResults.map(HarnessChunkResult.init), + joinedTranscript: joinedTranscript, + scores: scores + ) + let encoder = JSONEncoder() + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + let outputURL = URL(fileURLWithPath: outputPath) + try FileManager.default.createDirectory( + at: outputURL.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + try encoder.encode(report).write(to: outputURL, options: .atomic) + } +} + +private struct HarnessReport: Encodable { + var mode: String + var pcmPath: String + var language: String + var sampleRate: Int + var engine: String + var model: String + var quality: String + var audioDurationSeconds: Double + var elapsedSeconds: Double + var realTimeFactor: Double? + var droppedChunkCount: Int + var chunkResults: [HarnessChunkResult] + var joinedTranscript: String + var scores: HarnessScores? 
+} + +private struct HarnessChunkResult: Encodable { + var sequence: Int + var startTime: Double + var endTime: Double + var overlappedStartTime: Double? + var latencySeconds: Double + var helperEngine: String + var helperModel: String + var fixture: Bool + var rawSegments: [LocalASRTranscriptSegment] + var remappedSegments: [NormalizedTranscriptSegment] + var joinedText: String + + init(_ result: LocalBackgroundASRRawChunkResult) { + sequence = result.chunk.sequence + startTime = result.chunk.startTime + endTime = result.chunk.endTime + overlappedStartTime = result.chunk.overlappedStartTime + latencySeconds = result.latencySeconds + helperEngine = result.response.engine.rawValue + helperModel = result.response.model.rawValue + fixture = result.response.fixture + rawSegments = result.response.segments + remappedSegments = result.remappedSegments + joinedText = result.joinedText + } +} + +private struct HarnessScores: Encodable { + var reference: String + var normalizedReference: String + var normalizedHypothesis: String + var wordErrorRate: Double + var characterErrorRate: Double +} diff --git a/desktop/Desktop/Tests/TranscriptComparisonTests.swift b/desktop/Desktop/Tests/TranscriptComparisonTests.swift new file mode 100644 index 00000000000..960f9a2dbc8 --- /dev/null +++ b/desktop/Desktop/Tests/TranscriptComparisonTests.swift @@ -0,0 +1,36 @@ +import XCTest + +@testable import Omi_Computer + +final class TranscriptComparisonTests: XCTestCase { + func testNormalizationRemovesCasePunctuationAndRepeatedWhitespace() { + XCTAssertEqual( + TranscriptComparison.normalizedText(" Hello, LOCAL Whisper! 
"), + "hello local whisper" + ) + } + + func testWordErrorRateCountsSubstitutionInsertionAndDeletion() { + XCTAssertEqual( + TranscriptComparison.wordErrorRate( + reference: "hello local whisper", + hypothesis: "hello cloud whisper now" + ), + 2.0 / 3.0, + accuracy: 0.0001 + ) + } + + func testCharacterErrorRateUsesNormalizedCharactersWithoutSpaces() { + XCTAssertEqual( + TranscriptComparison.characterErrorRate(reference: "abc def", hypothesis: "abc dxf"), + 1.0 / 6.0, + accuracy: 0.0001 + ) + } + + func testEmptyReferenceScoringIsBounded() { + XCTAssertEqual(TranscriptComparison.wordErrorRate(reference: "", hypothesis: ""), 0) + XCTAssertEqual(TranscriptComparison.wordErrorRate(reference: "", hypothesis: "extra"), 1) + } +} diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 5b3cbbb2524..92a417457be 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -458,8 +458,10 @@ final class LocalASRRuntimeTests: XCTestCase { XCTAssertEqual(results.count, 2) XCTAssertEqual(results[0].remappedSegments[0].start, 5.0, accuracy: 0.001) + XCTAssertEqual(results[0].remappedSegments[0].segmentId, "local-bg-0-0.0") XCTAssertEqual(results[1].chunk.startTime, 5.8, accuracy: 0.001) XCTAssertEqual(results[1].remappedSegments[0].start, 5.8, accuracy: 0.001) + XCTAssertEqual(results[1].remappedSegments[0].segmentId, "local-bg-1-0.0") XCTAssertEqual(results[0].latencySeconds, 0.25, accuracy: 0.001) XCTAssertEqual(snapshot.rawChunkResults.count, 2) XCTAssertEqual(snapshot.joinedTranscript, "hello local whisper runs here") diff --git a/desktop/local-asr-helper/README.md b/desktop/local-asr-helper/README.md index 18d8c60c266..6570a56907f 100644 --- a/desktop/local-asr-helper/README.md +++ b/desktop/local-asr-helper/README.md @@ -40,6 +40,46 @@ printf 
'{"request_id":"smoke-1","audio_path":"/tmp/omi-asr-smoke.pcm","language" | cargo run --quiet --manifest-path desktop/local-asr-helper/Cargo.toml ``` +Background pipeline harness: + +```bash +# Deterministic fixture mode. Requires no local model or cloud credentials. +desktop/local-asr-helper/local_background_asr_harness.py \ + --generate-speech "hello local background transcription" \ + --reference "hello local background transcription" \ + --mode fixture \ + --max-chunk-seconds 2 \ + --output /tmp/omi-local-background-asr-report.json + +# Real local Whisper mode. Build the helper first or point at an existing helper. +cargo build --manifest-path desktop/local-asr-helper/Cargo.toml +OMI_LOCAL_ASR_HELPER_PATH="$PWD/desktop/local-asr-helper/target/debug/local-asr-helper" \ + desktop/local-asr-helper/local_background_asr_harness.py \ + --generate-speech "hello local background transcription" \ + --reference "hello local background transcription" \ + --mode local \ + --engine mlx-whisper \ + --model base \ + --output /tmp/omi-local-background-asr-local-report.json +``` + +The harness is intentionally focused on the desktop background transcription +path, not one-off helper invocation. It converts or generates 16 kHz mono PCM, +runs the Swift `LocalBackgroundTranscriptionSession`, and writes JSON with: + +- chunk boundaries and overlap start times; +- helper engine/model; +- raw per-chunk Whisper/helper segments; +- timestamp-remapped session-relative segments; +- deterministic joined transcript; +- latency and real-time-factor data; +- optional WER/CER scores when `--reference` is provided. + +Use `--audio path/to/file.wav` or `--audio path/to/file.pcm` instead of +`--generate-speech` to inspect a checked-in or manually recorded sample. The +`--deepgram-compare` flag is reserved as the explicit future extension point and +requires `DEEPGRAM_API_KEY`; it is not required for local smoke validation. 
+ The dev desktop app build (`desktop/run.sh`) builds this helper and copies it to `.app/Contents/Resources/local-asr-helper`, which is the bundled path used by `LocalASRHelperLocator`. diff --git a/desktop/local-asr-helper/local_background_asr_harness.py b/desktop/local-asr-helper/local_background_asr_harness.py new file mode 100755 index 00000000000..1907ab37e7b --- /dev/null +++ b/desktop/local-asr-helper/local_background_asr_harness.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import shutil +import subprocess +import sys +import tempfile +from pathlib import Path + + +def run(command, cwd=None, env=None): + result = subprocess.run(command, cwd=cwd, env=env, text=True) + if result.returncode != 0: + raise SystemExit(result.returncode) + + +def make_generated_audio(text, work_dir): + aiff_path = work_dir / "generated.aiff" + wav_path = work_dir / "generated.wav" + pcm_path = work_dir / "generated.pcm" + run(["say", "-o", str(aiff_path), text]) + run( + [ + "afconvert", + str(aiff_path), + "-f", + "WAVE", + "-d", + "LEI16@16000", + "-c", + "1", + str(wav_path), + ] + ) + extract_wav_pcm(wav_path, pcm_path) + return pcm_path + + +def convert_audio(audio_path, work_dir): + if audio_path.suffix.lower() == ".pcm": + return audio_path + wav_path = work_dir / "input.wav" + pcm_path = work_dir / "input.pcm" + run( + [ + "afconvert", + str(audio_path), + "-f", + "WAVE", + "-d", + "LEI16@16000", + "-c", + "1", + str(wav_path), + ] + ) + extract_wav_pcm(wav_path, pcm_path) + return pcm_path + + +def extract_wav_pcm(wav_path, pcm_path): + import wave + + with wave.open(str(wav_path), "rb") as wav: + if wav.getframerate() != 16000 or wav.getnchannels() != 1 or wav.getsampwidth() != 2: + raise SystemExit(f"{wav_path} is not 16 kHz mono 16-bit PCM") + pcm_path.write_bytes(wav.readframes(wav.getnframes())) + + +def main(): + parser = argparse.ArgumentParser( + description="Run desktop local background ASR through the Swift 
chunker/queue/merge harness." + ) + source = parser.add_mutually_exclusive_group(required=True) + source.add_argument( + "--audio", + type=Path, + help="Input audio file. .pcm is used directly; other formats use afconvert.", + ) + source.add_argument( + "--generate-speech", + help="Generate a deterministic macOS speech sample with `say`.", + ) + parser.add_argument("--mode", choices=["fixture", "local"], default="fixture") + parser.add_argument("--output", type=Path, default=Path("/tmp/omi-local-background-asr-report.json")) + parser.add_argument("--engine", choices=["mlx-whisper", "faster-whisper"], default="mlx-whisper") + parser.add_argument("--model", default="base") + parser.add_argument("--quality", default="fast") + parser.add_argument("--language", default="en") + parser.add_argument("--max-chunk-seconds", default="15") + parser.add_argument("--min-chunk-seconds", default="1") + parser.add_argument("--overlap-seconds", default="1") + parser.add_argument("--reference", help="Optional reference transcript for WER/CER scoring.") + parser.add_argument( + "--deepgram-compare", + action="store_true", + help="Reserved extension point. 
Requires DEEPGRAM_API_KEY and is not used for local smoke validation.", + ) + args = parser.parse_args() + + if args.deepgram_compare and not os.environ.get("DEEPGRAM_API_KEY"): + raise SystemExit("--deepgram-compare requires DEEPGRAM_API_KEY") + if args.deepgram_compare: + print( + "Deepgram comparison is an explicit future extension point; local smoke continues without cloud.", + file=sys.stderr, + ) + + repo_root = Path(__file__).resolve().parents[2] + desktop_package = repo_root / "desktop" / "Desktop" + with tempfile.TemporaryDirectory(prefix="omi-local-bg-asr-") as temp: + work_dir = Path(temp) + if args.generate_speech: + if not shutil.which("say") or not shutil.which("afconvert"): + raise SystemExit("Generated speech requires macOS `say` and `afconvert`.") + pcm_path = make_generated_audio(args.generate_speech, work_dir) + fixture_text = args.generate_speech + else: + if not args.audio.exists(): + raise SystemExit(f"Input audio does not exist: {args.audio}") + if args.audio.suffix.lower() != ".pcm" and not shutil.which("afconvert"): + raise SystemExit("Non-PCM input requires macOS `afconvert`.") + pcm_path = convert_audio(args.audio.resolve(), work_dir) + fixture_text = args.reference or "hello local background transcription" + + env = os.environ.copy() + env.update( + { + "OMI_LOCAL_BACKGROUND_ASR_HARNESS": "1", + "OMI_LOCAL_BACKGROUND_ASR_PCM_PATH": str(pcm_path), + "OMI_LOCAL_BACKGROUND_ASR_OUTPUT_PATH": str(args.output.resolve()), + "OMI_LOCAL_BACKGROUND_ASR_MODE": args.mode, + "OMI_LOCAL_BACKGROUND_ASR_FIXTURE_TEXT": fixture_text, + "OMI_LOCAL_ASR_ENGINE": args.engine, + "OMI_LOCAL_ASR_MODEL": args.model, + "OMI_LOCAL_ASR_QUALITY": args.quality, + "OMI_LOCAL_ASR_LANGUAGE": args.language, + "OMI_LOCAL_BACKGROUND_ASR_MAX_CHUNK_SECONDS": args.max_chunk_seconds, + "OMI_LOCAL_BACKGROUND_ASR_MIN_CHUNK_SECONDS": args.min_chunk_seconds, + "OMI_LOCAL_BACKGROUND_ASR_OVERLAP_SECONDS": args.overlap_seconds, + } + ) + if args.reference: + 
env["OMI_LOCAL_BACKGROUND_ASR_REFERENCE"] = args.reference + + run( + ["swift", "test", "--filter", "LocalBackgroundSmokeHarnessTests/testRunHarness"], + cwd=desktop_package, + env=env, + ) + + report = json.loads(args.output.read_text()) + print(json.dumps(report, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() From 6ffb6e03c6aca1c5fc1afbd7bb3c9d0d9610ec61 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:03:27 +0700 Subject: [PATCH 51/58] Fix local background transcription gating (cherry picked from commit c8c87a9607909abfbf1d278b71dd822cb4e956a5) --- desktop/Desktop/Sources/AppState.swift | 13 +- .../LocalTranscription/LocalASRRuntime.swift | 4 + .../Components/LiveTranscriptView.swift | 8 +- .../Sources/MainWindow/DesktopHomeView.swift | 12 +- .../MainWindow/Pages/SettingsPage.swift | 134 ++++++++++++++++++ .../Sources/MainWindow/SidebarView.swift | 2 +- .../Rewind/Core/TranscriptionStorage.swift | 23 +++ .../TranscriptionProviderPolicyTests.swift | 32 +++++ desktop/local-asr-helper/README.md | 4 + 9 files changed, 222 insertions(+), 10 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 324a4c72e37..dfa03b18976 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -1273,11 +1273,6 @@ class AppState: ObservableObject { func startTranscription(source: AudioSource? = nil) { guard !isTranscribing else { return } - // Paywall hard-stop: every code path that enables the mic + WS streaming - // funnels through here, including auto-restart from sleep and toggle - // shortcuts. Refuse to start and surface the upgrade popup. - if blockIfPaywalled() { return } - // Use provided source or fall back to current setting let effectiveSource = source ?? 
audioSource let backgroundRouting = BackgroundTranscriptionRoutingGuard().decide( @@ -1286,6 +1281,14 @@ class AppState: ObservableObject { availableEngines: { LocalASRHelperLocator.detectedEngines() } ).detect() ) + + // Paywall hard-stop applies only to the cloud listen path. Local background + // MLX/faster-whisper capture never opens `/v4/listen` and should keep + // working for users who selected a local provider. + if backgroundRouting.requiresCloudEntitlement && blockIfPaywalled(reason: "transcription") { + return + } + if case .unavailable = backgroundRouting.route { let message = backgroundRouting.unsupportedLocalReason diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index e8471db1744..d825bb9d3b2 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -754,6 +754,10 @@ struct BackgroundTranscriptionRoutingDecision: Equatable { return false } + var requiresCloudEntitlement: Bool { + useCloudBackend + } + var localPlan: LocalTranscriptionPlan? { if case .localWhisper(let plan) = route { return plan diff --git a/desktop/Desktop/Sources/MainWindow/Components/LiveTranscriptView.swift b/desktop/Desktop/Sources/MainWindow/Components/LiveTranscriptView.swift index decbf7767ae..2a9ad7baa0b 100644 --- a/desktop/Desktop/Sources/MainWindow/Components/LiveTranscriptView.swift +++ b/desktop/Desktop/Sources/MainWindow/Components/LiveTranscriptView.swift @@ -49,9 +49,11 @@ struct RecordingBarAudioLevels: View { HStack(spacing: 6) { Image(systemName: "mic.fill") .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) + .foregroundColor( + monitor.microphoneLevel > 0.01 ? 
OmiColors.success : OmiColors.textTertiary + ) AudioLevelWaveformView( - level: monitor.microphoneLevel, + level: max(monitor.microphoneLevel, 0.08), barCount: 8, isActive: true ) @@ -62,7 +64,7 @@ struct RecordingBarAudioLevels: View { .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) AudioLevelWaveformView( - level: monitor.systemLevel, + level: max(monitor.systemLevel, 0.04), barCount: 8, isActive: true ) diff --git a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift index fcbae586adc..e633c49cee3 100644 --- a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift +++ b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift @@ -151,7 +151,7 @@ struct DesktopHomeView: View { // Auto-start transcription if enabled in settings. // If API keys aren't loaded yet, onChange below retries. if settings.transcriptionEnabled && !appState.isTranscribing { - if APIKeyService.keysAvailable { + if APIKeyService.keysAvailable || !backgroundTranscriptionNeedsAPIKeys(settings: settings) { log("DesktopHomeView: Auto-starting transcription") appState.startTranscription() } else { @@ -414,6 +414,16 @@ struct DesktopHomeView: View { } } + private func backgroundTranscriptionNeedsAPIKeys(settings: AssistantSettings) -> Bool { + let routing = BackgroundTranscriptionRoutingGuard().decide( + selection: settings.transcriptionProviderSelection, + capabilities: LocalTranscriptionCapabilityDetector( + availableEngines: { LocalASRHelperLocator.detectedEngines() } + ).detect() + ) + return routing.requiresCloudEntitlement + } + /// Recursively find all NSHostingViews in a window and set sizingOptions to [], /// disabling ALL size computations to prevent full-tree sizeThatFits() traversals. /// Window min/max sizes are enforced at the AppKit level via NSWindow.minSize instead. 
diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 4adae640e19..421e7678b66 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -186,6 +186,9 @@ struct SettingsContentView: View { // Advanced stats @State private var advancedStats: UserStats? @State private var isLoadingStats = false + @State private var rawTranscriptionHistory: [TranscriptionSessionWithSegments] = [] + @State private var isLoadingRawTranscriptionHistory = false + @State private var rawTranscriptionHistoryError: String? @State private var chatMessageCount: Int? @State private var isLoadingChatMessages = false @State private var showProfileAndStats = false @@ -3557,6 +3560,98 @@ struct SettingsContentView: View { .clipShape(RoundedRectangle(cornerRadius: 8)) } } + + settingsCard(settingId: "advanced.devtools.rawtranscription") { + VStack(alignment: .leading, spacing: 14) { + HStack(spacing: 12) { + Image(systemName: "waveform.badge.magnifyingglass") + .scaledFont(size: 16) + .foregroundColor(OmiColors.purplePrimary) + VStack(alignment: .leading, spacing: 4) { + Text("Raw Transcription History") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text("Inspect locally persisted sessions and segment text from background capture") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } + Spacer() + Button { + Task { await loadRawTranscriptionHistory() } + } label: { + HStack(spacing: 6) { + if isLoadingRawTranscriptionHistory { + ProgressView().controlSize(.mini) + } else { + Image(systemName: "arrow.clockwise") + .scaledFont(size: 12, weight: .semibold) + } + Text("Refresh") + .scaledFont(size: 13, weight: .medium) + } + } + .buttonStyle(.plain) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background(OmiColors.backgroundSecondary) + 
.clipShape(RoundedRectangle(cornerRadius: 8)) + .disabled(isLoadingRawTranscriptionHistory) + } + + if let rawTranscriptionHistoryError { + Text(rawTranscriptionHistoryError) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + } else if rawTranscriptionHistory.isEmpty { + Text("No local transcription sessions found yet.") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } else { + VStack(alignment: .leading, spacing: 10) { + ForEach(rawTranscriptionHistory.indices, id: \.self) { index in + rawTranscriptionHistoryRow(rawTranscriptionHistory[index]) + } + } + } + } + .task { + if rawTranscriptionHistory.isEmpty && rawTranscriptionHistoryError == nil { + await loadRawTranscriptionHistory() + } + } + } + } + } + + private func rawTranscriptionHistoryRow(_ item: TranscriptionSessionWithSegments) -> some View { + VStack(alignment: .leading, spacing: 8) { + HStack(spacing: 8) { + Text("#\(item.session.id ?? -1)") + .scaledMonospacedFont(size: 11, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + Text(item.session.status.rawValue) + .scaledFont(size: 11, weight: .medium) + .foregroundColor( + item.session.status == .completed ? 
OmiColors.success : OmiColors.warning + ) + Text(rawTranscriptionDateFormatter.string(from: item.session.startedAt)) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + Spacer() + Text("\(item.segments.count) segments") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } + + Text(rawTranscriptionPreview(for: item)) + .scaledMonospacedFont(size: 11) + .foregroundColor(OmiColors.textSecondary) + .lineLimit(8) + .textSelection(.enabled) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(10) + .background(OmiColors.backgroundSecondary.opacity(0.8)) + .clipShape(RoundedRectangle(cornerRadius: 8)) } } @@ -5705,6 +5800,45 @@ struct SettingsContentView: View { } } + private var rawTranscriptionDateFormatter: DateFormatter { + let formatter = DateFormatter() + formatter.dateStyle = .short + formatter.timeStyle = .medium + return formatter + } + + private func rawTranscriptionPreview(for item: TranscriptionSessionWithSegments) -> String { + let lines = item.segments.prefix(20).map { segment in + let speaker = segment.speakerLabel ?? "speaker \(segment.speaker)" + return String( + format: "[%.2f-%.2f] %@: %@", + segment.startTime, + segment.endTime, + speaker, + segment.text + ) + } + if lines.isEmpty { + return "(no segments persisted)" + } + let suffix = + item.segments.count > lines.count ? "\n... 
\(item.segments.count - lines.count) more" : "" + return lines.joined(separator: "\n") + suffix + } + + private func loadRawTranscriptionHistory() async { + isLoadingRawTranscriptionHistory = true + rawTranscriptionHistoryError = nil + defer { isLoadingRawTranscriptionHistory = false } + + do { + rawTranscriptionHistory = try await TranscriptionStorage.shared.getRecentSessionsWithSegments( + limit: 8) + } catch { + rawTranscriptionHistoryError = error.localizedDescription + } + } + // MARK: - Developer API Keys Subsection private var developerKeysSubsection: some View { diff --git a/desktop/Desktop/Sources/MainWindow/SidebarView.swift b/desktop/Desktop/Sources/MainWindow/SidebarView.swift index 73f3b203177..6dc6eb30f13 100644 --- a/desktop/Desktop/Sources/MainWindow/SidebarView.swift +++ b/desktop/Desktop/Sources/MainWindow/SidebarView.swift @@ -1647,7 +1647,7 @@ struct SidebarAudioLevelIcon: View { /// Combined audio level (max of mic and system) private var combinedLevel: Float { - max(micLevel, systemLevel) + max(max(micLevel, systemLevel), 0.08) } var body: some View { diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift index 6cf2fddcabb..d31f950bbe4 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift @@ -589,6 +589,29 @@ actor TranscriptionStorage { } } + /// Get recent local transcription sessions with raw persisted segments for debugging. 
+ func getRecentSessionsWithSegments(limit: Int = 8) async throws -> [TranscriptionSessionWithSegments] { + let db = try await ensureInitialized() + + return try await db.read { database in + let sessions = try TranscriptionSessionRecord + .order(Column("startedAt").desc) + .limit(max(1, limit)) + .fetchAll(database) + + return try sessions.map { session in + guard let sessionId = session.id else { + return TranscriptionSessionWithSegments(session: session, segments: []) + } + let segments = try TranscriptionSegmentRecord + .filter(Column("sessionId") == sessionId) + .order(Column("segmentOrder").asc) + .fetchAll(database) + return TranscriptionSessionWithSegments(session: session, segments: segments) + } + } + } + /// Get all sessions needing recovery (crashed, pending, or failed with retries left) func getSessionsNeedingRecovery() async throws -> [TranscriptionSessionRecord] { let db = try await ensureInitialized() diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 92a417457be..72245b1c030 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -74,6 +74,38 @@ final class TranscriptionProviderPolicyTests: XCTestCase { XCTAssertNil(result.fallbackReason) } + func testLocalBackgroundRouteDoesNotRequireCloudEntitlement() { + let capabilities = LocalTranscriptionCapabilities( + processor: .nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [.mlxWhisper] + ) + + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .local, quality: .fast), + capabilities: capabilities + ) + + XCTAssertNotNil(decision.localPlan) + XCTAssertFalse(decision.requiresCloudEntitlement) + } + + func testAutoCloudFallbackRequiresCloudEntitlement() { + let capabilities = LocalTranscriptionCapabilities( + processor: 
.nativeAppleSilicon, + physicalMemoryBytes: 16 * 1024 * 1024 * 1024, + availableEngines: [] + ) + + let decision = BackgroundTranscriptionRoutingGuard().decide( + selection: TranscriptionProviderSelection(mode: .auto, quality: .fast), + capabilities: capabilities + ) + + XCTAssertTrue(decision.useCloudBackend) + XCTAssertTrue(decision.requiresCloudEntitlement) + } + func testAccurateUsesLargerModelsOnlyWhenMemoryAllows() { let lowMemory = LocalTranscriptionCapabilities( processor: .nativeAppleSilicon, diff --git a/desktop/local-asr-helper/README.md b/desktop/local-asr-helper/README.md index 6570a56907f..4a046bb95d0 100644 --- a/desktop/local-asr-helper/README.md +++ b/desktop/local-asr-helper/README.md @@ -80,6 +80,10 @@ Use `--audio path/to/file.wav` or `--audio path/to/file.pcm` instead of `--deepgram-compare` flag is reserved as the explicit future extension point and requires `DEEPGRAM_API_KEY`; it is not required for local smoke validation. +In the dev app, open Settings -> Advanced -> Dev Tools -> Raw Transcription +History to inspect the recent locally persisted sessions and raw segment text. +Use the Refresh button after stopping a local background recording. + The dev desktop app build (`desktop/run.sh`) builds this helper and copies it to `.app/Contents/Resources/local-asr-helper`, which is the bundled path used by `LocalASRHelperLocator`. 
From e85360b29b8a07cbe565acf2fe27e297af03a520 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 12:06:10 -0400 Subject: [PATCH 52/58] Update .gitignore files to exclude agent-generated plans and build artifacts (cherry picked from commit 85d5728b9f40c338caacec48f78f2da5506a5837) --- .gitignore | 4 ++++ desktop/.gitignore | 1 + 2 files changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 1a9dd09a4d2..de226b0f9e0 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,9 @@ dump/ node_modules yarn.lock +# Cursor (agent-generated plans; rules/hooks remain tracked) +.cursor/plans/ + # Coordination (AI agent collab manifests) .coordination/ @@ -30,6 +33,7 @@ yarn.lock build/ dist/ .build/ +**/.build-*/ .swiftpm/ **/target/ diff --git a/desktop/.gitignore b/desktop/.gitignore index 4cacc3c6802..887d2269ba0 100644 --- a/desktop/.gitignore +++ b/desktop/.gitignore @@ -1,5 +1,6 @@ # Build artifacts .build/ +**/.build-*/ build/ DerivedData/ *.xcodeproj/xcuserdata/ From b8d600de054436e07ada15c465eeb667e271e1ea Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 20 May 2026 05:03:00 +0000 Subject: [PATCH 53/58] Fix ChatGPT tier security and desktop Codex/wiki bugs Require X-ChatGPT-Fingerprint on requests for quota and subscription bypass, refresh enrollment heartbeat from desktop launch and throttled server updates, resolve Codex transport at call time in GeminiClient, label wiki search hits distinctly from tasks, and include memory id in wiki slugs to avoid collisions. 
(cherry picked from commit 63fde7bbe828f23150c5f087e939aade8be70f6a) --- backend/database/users.py | 9 +++ backend/main.py | 2 + backend/routers/users.py | 5 +- backend/utils/chatgpt.py | 74 +++++++++++++++++++ backend/utils/subscription.py | 5 +- desktop/Desktop/Sources/APIClient.swift | 4 + desktop/Desktop/Sources/OmiApp.swift | 7 +- .../MemoryExtraction/MemoryAssistant.swift | 8 +- .../TaskExtraction/TaskAssistant.swift | 6 +- .../Core/GeminiClient.swift | 19 ++--- 10 files changed, 121 insertions(+), 18 deletions(-) create mode 100644 backend/utils/chatgpt.py diff --git a/backend/database/users.py b/backend/database/users.py index b3fbf6fe539..b42f8d8e3bb 100644 --- a/backend/database/users.py +++ b/backend/database/users.py @@ -250,6 +250,15 @@ def set_chatgpt_active(uid: str, fingerprint: str): ) +def touch_chatgpt_heartbeat(uid: str): + """Refresh ChatGPT tier heartbeat (called when a valid fingerprint is on the request).""" + user_ref = db.collection('users').document(uid) + user_ref.set( + {'chatgpt': {'last_seen_at': datetime.now(timezone.utc)}}, + merge=True, + ) + + def clear_chatgpt_active(uid: str): user_ref = db.collection('users').document(uid) user_ref.set( diff --git a/backend/main.py b/backend/main.py index ab7be1dac2a..c9b823af16c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -144,8 +144,10 @@ app.add_middleware(TimeoutMiddleware, methods_timeout=methods_timeout) from utils.byok import BYOKMiddleware +from utils.chatgpt import ChatGPTMiddleware app.add_middleware(BYOKMiddleware) +app.add_middleware(ChatGPTMiddleware) @app.on_event("shutdown") diff --git a/backend/routers/users.py b/backend/routers/users.py index 2190cb28e65..e8654d1cdf7 100644 --- a/backend/routers/users.py +++ b/backend/routers/users.py @@ -98,6 +98,7 @@ from utils.webhooks import webhook_first_time_setup from database.action_items import get_action_items as get_standalone_action_items from utils.byok import has_byok_keys, invalidate_byok_state_cache +from utils.chatgpt 
import chatgpt_request_grants_bypass import logging logger = logging.getLogger(__name__) @@ -885,7 +886,7 @@ def get_user_subscription_endpoint( # these users aren't surprised by a disabled phone-call feature. unlimited_phone_quota = PhoneCallQuota(has_access=True, is_paid=True) - if users_db.is_chatgpt_active(uid): + if chatgpt_request_grants_bypass(uid): return UserSubscriptionResponse( subscription=_chatgpt_unlimited_subscription(), transcription_seconds_used=0, @@ -1110,7 +1111,7 @@ def get_user_chat_usage_quota( # BYOK free plan: user brings their own keys, so there's no Omi-side cost # to meter. Only return unlimited when BYOK headers are on the request (desktop). # Mobile (no headers) should see real quota. - if users_db.is_chatgpt_active(uid): + if chatgpt_request_grants_bypass(uid): return ChatUsageQuota( plan='Free (ChatGPT)', plan_type=PlanType.unlimited.value, diff --git a/backend/utils/chatgpt.py b/backend/utils/chatgpt.py new file mode 100644 index 00000000000..ea6a9f3f1ef --- /dev/null +++ b/backend/utils/chatgpt.py @@ -0,0 +1,74 @@ +"""Per-request ChatGPT / Codex tier fingerprint plumbing. + +Desktop sends ``X-ChatGPT-Fingerprint`` (SHA-256 of Codex account_id) on requests +while Codex is active. Quota and subscription bypass require a matching enrolled +fingerprint on the same request — enrollment alone is not enough (mirrors BYOK). +""" + +import logging +import re +from contextvars import ContextVar +from datetime import datetime, timezone +from typing import Optional + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request + +import database.users as users_db + +logger = logging.getLogger('chatgpt') + +CHATGPT_FINGERPRINT_HEADER = 'x-chatgpt-fingerprint' +_SHA256_HEX_RE = re.compile(r'^[a-f0-9]{64}$') +# Refresh Firestore heartbeat at most once per day when desktop sends a valid fingerprint. 
+_HEARTBEAT_REFRESH_INTERVAL_SECONDS = 24 * 60 * 60 + +_chatgpt_fp_ctx: ContextVar[Optional[str]] = ContextVar('chatgpt_fingerprint', default=None) + + +def get_chatgpt_fingerprint() -> Optional[str]: + return _chatgpt_fp_ctx.get() + + +def has_chatgpt_fingerprint() -> bool: + """True if the current request carries a ChatGPT enrollment fingerprint header.""" + return bool(_chatgpt_fp_ctx.get()) + + +def chatgpt_request_grants_bypass(uid: str) -> bool: + """True when enrolled and this request's fingerprint matches Firestore enrollment. + + Refreshes ``last_seen_at`` on success so desktop enrollment stays alive without + re-posting to ``/chatgpt-active`` every week. + """ + fp = _chatgpt_fp_ctx.get() + if not fp or not _SHA256_HEX_RE.match(fp): + return False + if not users_db.is_chatgpt_active(uid): + return False + state = users_db.get_chatgpt_state(uid) + if state.get('fingerprint') != fp: + return False + last_seen = state.get('last_seen_at') + if not isinstance(last_seen, datetime): + users_db.touch_chatgpt_heartbeat(uid) + else: + age = (datetime.now(timezone.utc) - last_seen).total_seconds() + if age >= _HEARTBEAT_REFRESH_INTERVAL_SECONDS: + users_db.touch_chatgpt_heartbeat(uid) + return True + + +class ChatGPTMiddleware(BaseHTTPMiddleware): + """Extract ChatGPT fingerprint header into a per-request contextvar.""" + + async def dispatch(self, request: Request, call_next): + raw = request.headers.get(CHATGPT_FINGERPRINT_HEADER) + fp = raw.strip() if raw else None + if fp and not _SHA256_HEX_RE.match(fp): + fp = None + token = _chatgpt_fp_ctx.set(fp) + try: + return await call_next(request) + finally: + _chatgpt_fp_ctx.reset(token) diff --git a/backend/utils/subscription.py b/backend/utils/subscription.py index d3b176d56dc..4483d9e3a95 100644 --- a/backend/utils/subscription.py +++ b/backend/utils/subscription.py @@ -13,6 +13,7 @@ from database.announcements import compare_versions from models.users import PlanType, SubscriptionStatus, Subscription, PlanLimits, 
TrialMetadata from utils.byok import get_byok_key, get_byok_keys +from utils.chatgpt import chatgpt_request_grants_bypass from utils.log_sanitizer import sanitize import logging @@ -541,8 +542,8 @@ def enforce_chat_quota(uid: str, platform: Optional[str] = None) -> None: ) # BYOK users pay their own LLM provider — no Omi-side cost to cap. - # ChatGPT/Codex tier: user pays OpenAI via subscription; LLM quota bypass only. - if users_db.is_chatgpt_active(uid): + # ChatGPT/Codex tier: bypass only when this request proves Codex enrollment (header). + if chatgpt_request_grants_bypass(uid): return # Require an LLM provider key on this request (not just any BYOK header) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index ca0f75fb5b6..2745ca85e28 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -111,6 +111,10 @@ actor APIClient { headers.merge(APIKeyService.byokHeaders()) { _, new in new } } + if let chatgptFingerprint = CodexAuthService.enrollmentFingerprintIfActive() { + headers["X-ChatGPT-Fingerprint"] = chatgptFingerprint + } + return headers } diff --git a/desktop/Desktop/Sources/OmiApp.swift b/desktop/Desktop/Sources/OmiApp.swift index 744cbf433f1..7b2cbb10a2d 100644 --- a/desktop/Desktop/Sources/OmiApp.swift +++ b/desktop/Desktop/Sources/OmiApp.swift @@ -108,7 +108,12 @@ struct OMIApp: App { .onAppear { log("OmiApp: Main window content appeared (mode: \(Self.launchMode.rawValue))") if CodexAuthService.isActive { - Task { await CodexProxyService.shared.ensureRunning() } + Task { + await CodexProxyService.shared.ensureRunning() + if let fingerprint = CodexAuthService.enrollmentFingerprintIfActive() { + try? 
await APIClient.shared.activateChatGPT(fingerprint: fingerprint) + } + } } } } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift index bd2f3402a75..7aac804e28a 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/MemoryExtraction/MemoryAssistant.swift @@ -247,7 +247,13 @@ actor MemoryAssistant: ProactiveAssistant { log("Memory: Saved to SQLite (id: \(inserted.id ?? -1))") if !MemorySearchMode.usesVectorEmbeddings { let title = memory.content.prefix(80).trimmingCharacters(in: .whitespacesAndNewlines) - let slug = MemoryWikiStorage.slugify(String(title)) + let baseSlug = MemoryWikiStorage.slugify(String(title)) + let slug: String + if let memoryId = inserted.id { + slug = "\(baseSlug)-\(memoryId)" + } else { + slug = baseSlug + } Task { _ = try? 
await MemoryWikiStorage.shared.upsertPage( slug: slug, diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index dff5d22be78..42e7f1f2653 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -1335,9 +1335,9 @@ actor TaskAssistant: ProactiveAssistant { for hit in wikiHits { results.append( TaskSearchResult( - id: hit.id, - description: "\(hit.title): \(hit.snippet)", - status: "active", + id: 0, + description: "[Memory wiki] \(hit.title): \(hit.snippet)", + status: "wiki", similarity: nil, matchType: "wiki_fts", relevanceScore: nil diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index aed80795346..e6972c304d8 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -191,9 +191,17 @@ actor GeminiClient { case hybridOpenAICompatible } - private let transport: Transport private let model: String + private var transport: Transport { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + || CodexAuthService.isActive + { + return .hybridOpenAICompatible + } + return .geminiProxy + } + /// Backend proxy base URL (from OMI_DESKTOP_API_URL env var) private static var proxyBaseURL: String { if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { @@ -269,16 +277,9 @@ actor GeminiClient { init(apiKey: String? = nil, model: String = ModelQoS.Gemini.proactive) throws { // BREAKING CHANGE (issue #5861): apiKey parameter is ignored for cloud proxy mode. 
self.model = model - if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon - || CodexAuthService.isActive - { - self.transport = .hybridOpenAICompatible - return - } - guard !Self.proxyBaseURL.isEmpty else { + if transport == .geminiProxy && Self.proxyBaseURL.isEmpty { throw GeminiClientError.missingAPIKey } - self.transport = .geminiProxy } private func mapHybridError(_ error: HybridLLMClient.ClientError) -> GeminiClientError { From 5585c016d6bbad6fe9da5b01c7377686a40e88e9 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 12:32:36 -0400 Subject: [PATCH 54/58] Resolve cherry-pick integration warnings --- desktop/Desktop/Sources/AppState.swift | 5 ++--- .../Sources/ProactiveAssistants/Core/GeminiClient.swift | 6 +++++- desktop/Desktop/Sources/TranscriptionService.swift | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index dfa03b18976..f80f1f33642 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -2980,9 +2980,8 @@ class AppState: ObservableObject { let mapped = newTranslations.map { TranscriptTranslation(lang: $0.lang, text: $0.text) } - var translationsJson: String? - if let jsonData = try? JSONEncoder().encode(mapped) { - translationsJson = String(data: jsonData, encoding: .utf8) + let translationsJson = (try? JSONEncoder().encode(mapped)).flatMap { + String(data: $0, encoding: .utf8) } Task { try? 
await TranscriptionStorage.shared.upsertSegment( diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index e6972c304d8..b0a443bcdc8 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -194,6 +194,10 @@ actor GeminiClient { private let model: String private var transport: Transport { + Self.resolveTransport() + } + + private static func resolveTransport() -> Transport { if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon || CodexAuthService.isActive { @@ -277,7 +281,7 @@ actor GeminiClient { init(apiKey: String? = nil, model: String = ModelQoS.Gemini.proactive) throws { // BREAKING CHANGE (issue #5861): apiKey parameter is ignored for cloud proxy mode. self.model = model - if transport == .geminiProxy && Self.proxyBaseURL.isEmpty { + if Self.resolveTransport() == .geminiProxy && Self.proxyBaseURL.isEmpty { throw GeminiClientError.missingAPIKey } } diff --git a/desktop/Desktop/Sources/TranscriptionService.swift b/desktop/Desktop/Sources/TranscriptionService.swift index d040e7668fd..9bd033e6e21 100644 --- a/desktop/Desktop/Sources/TranscriptionService.swift +++ b/desktop/Desktop/Sources/TranscriptionService.swift @@ -721,7 +721,7 @@ class TranscriptionService { do { let json = try JSONSerialization.jsonObject(with: data) - if let array = json as? 
[[String: Any]] { + if json is [[String: Any]] { // JSON array = transcript segments let segments = try JSONDecoder().decode([BackendSegment].self, from: data) if !segments.isEmpty { From dae1f80712691618385320458119a16a61e00bb7 Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:02:40 -0400 Subject: [PATCH 55/58] Fix local provider policy seeding --- .../docs/hybrid-provider-settings.md | 11 ++- desktop/local-backend/src/providers.rs | 97 +++++++++++++++++-- .../tools/seed_hybrid_defaults.sh | 63 ++++++++++-- desktop/run.sh | 2 +- 4 files changed, 155 insertions(+), 18 deletions(-) diff --git a/desktop/local-backend/docs/hybrid-provider-settings.md b/desktop/local-backend/docs/hybrid-provider-settings.md index fdb42fca1f1..f2b070ae139 100644 --- a/desktop/local-backend/docs/hybrid-provider-settings.md +++ b/desktop/local-backend/docs/hybrid-provider-settings.md @@ -250,8 +250,17 @@ Default hybrid optional tiers: both cloud toggles off. `run.sh` local mode defau When the daemon starts via `make serve-local` or `desktop/run.sh` in local mode, `desktop/local-backend/tools/seed_hybrid_defaults.sh` runs idempotently: -- If `post_transcript`, `proactive`, or `chat` lacks a provider account, the script +- The typed `provider_policy` is the source of truth once it exists. Legacy + `ai_provider` / `chat_provider` settings are bridged only before a typed policy + has been created. +- The seed script removes synthetic `legacy-*` accounts from the policy it writes, + so compatibility bridge output is not permanently materialized. +- If the default local OpenAI-compatible endpoint is reachable, and + `post_transcript`, `proactive`, or `chat` lacks a provider account, the script creates/reuses a local OpenAI-compatible account and points those slots at it. 
+- If the default endpoint is unavailable, the script leaves unconfigured slots + without provider accounts so post-transcript processing can use deterministic + fallback output instead of failing against a dead local model server. - `memory_search` remains `local_wiki`. | Variable | Default | diff --git a/desktop/local-backend/src/providers.rs b/desktop/local-backend/src/providers.rs index b124164d308..406d0858572 100644 --- a/desktop/local-backend/src/providers.rs +++ b/desktop/local-backend/src/providers.rs @@ -285,7 +285,8 @@ pub fn post_transcript_slot_resolution(store: &Store) -> Result Result { - let mut policy = if let Some(setting) = store.settings().get(PROVIDER_POLICY_SETTING_KEY)? { + let policy_setting = store.settings().get(PROVIDER_POLICY_SETTING_KEY)?; + let mut policy = if let Some(setting) = policy_setting.as_ref() { let policy: ProviderPolicy = serde_json::from_str(&setting.value_json) .context("failed to parse provider_policy setting")?; if policy.version != PROVIDER_POLICY_VERSION { @@ -302,7 +303,9 @@ pub fn load_provider_policy(store: &Store) -> Result { model_slots: BTreeMap::new(), } }; - add_legacy_policy_bridge(store, &mut policy)?; + if policy_setting.is_none() { + add_legacy_policy_bridge(store, &mut policy)?; + } add_local_profile_defaults(&mut policy)?; Ok(policy) } @@ -460,10 +463,7 @@ fn add_legacy_policy_bridge(store: &Store, policy: &mut ProviderPolicy) -> Resul let Some(account) = legacy_provider_account(slot, key, &value)? 
else { continue; }; - let model_id = value["model"] - .as_str() - .unwrap_or(MODEL_GPT_5_4_MINI) - .to_string(); + let model_id = legacy_model_for_slot(slot, value["model"].as_str()); policy.provider_accounts.push(account.clone()); policy.model_slots.insert( (*slot).to_string(), @@ -482,6 +482,14 @@ fn add_legacy_policy_bridge(store: &Store, policy: &mut ProviderPolicy) -> Resul Ok(()) } +fn legacy_model_for_slot(slot: &str, model: Option<&str>) -> String { + let model = model.unwrap_or(MODEL_GPT_5_4_MINI); + if matches!(slot, SLOT_POST_TRANSCRIPT | SLOT_PROACTIVE) && model == MODEL_GPT_5_4 { + return MODEL_GPT_5_4_MINI.to_string(); + } + model.to_string() +} + pub fn model_catalog(store: &Store) -> Result> { let policy = load_provider_policy(store)?; Ok(model_catalog_for_policy(&policy)) @@ -1220,6 +1228,83 @@ mod tests { Ok(()) } + #[test] + fn legacy_gpt_5_4_settings_are_sanitized_for_json_slots() -> Result<()> { + let store = Store::open_in_memory()?; + let mut settings = Map::new(); + settings.insert( + "chat_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "model": MODEL_GPT_5_4 + }), + ); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:11434/v1", + "model": MODEL_GPT_5_4 + }), + ); + store.settings().upsert_many(settings)?; + + let policy = load_provider_policy(&store)?; + assert_eq!(policy.model_slots[SLOT_CHAT].model_id, MODEL_GPT_5_4); + assert_eq!( + policy.model_slots[SLOT_POST_TRANSCRIPT].model_id, + MODEL_GPT_5_4_MINI + ); + + Ok(()) + } + + #[test] + fn typed_provider_policy_suppresses_legacy_settings_bridge() -> Result<()> { + let store = Store::open_in_memory()?; + let mut settings = Map::new(); + settings.insert( + "ai_provider".to_string(), + json!({ + "kind": "openai_compatible", + "base_url": "http://127.0.0.1:10531/v1", + "model": MODEL_GPT_5_4_MINI + }), + ); + settings.insert( + 
PROVIDER_POLICY_SETTING_KEY.to_string(), + json!({ + "version": PROVIDER_POLICY_VERSION, + "provider_accounts": [], + "model_slots": { + "memory_search": { + "provider_account_id": null, + "model_id": MODEL_LOCAL_WIKI, + "options": {} + } + } + }), + ); + store.settings().upsert_many(settings)?; + + let policy = load_provider_policy(&store)?; + assert!(policy + .provider_accounts + .iter() + .all(|account| !account.id.starts_with("legacy-"))); + assert_eq!( + policy.model_slots[SLOT_POST_TRANSCRIPT].provider_account_id, + None + ); + + let resolution = post_transcript_slot_resolution(&store)?; + assert!(!resolution.ok); + assert!(resolution.reason.contains("no provider account")); + + Ok(()) + } + #[test] fn default_slots_select_small_models_but_need_accounts() -> Result<()> { let store = Store::open_in_memory()?; diff --git a/desktop/local-backend/tools/seed_hybrid_defaults.sh b/desktop/local-backend/tools/seed_hybrid_defaults.sh index 433518b2286..57060e7bbe1 100755 --- a/desktop/local-backend/tools/seed_hybrid_defaults.sh +++ b/desktop/local-backend/tools/seed_hybrid_defaults.sh @@ -1,4 +1,4 @@ -#!/usr/bin/env bash +#!/bin/bash # Idempotent: seed local model slots on the local daemon when they lack accounts. 
set -euo pipefail @@ -15,19 +15,56 @@ fi policy_json="$(curl -fsS "${BASE_URL}/v1/provider-policy")" -seed_body="$(POLICY_JSON="$policy_json" python3 - "$PROVIDER_BASE" "$MODEL" "$ACCOUNT_ID" <<'PY' +provider_available="False" +if curl -fsS --connect-timeout 1 --max-time 2 "${PROVIDER_BASE%/}/models" >/dev/null 2>&1; then + provider_available="True" +fi + +seed_body="$(POLICY_JSON="$policy_json" PROVIDER_AVAILABLE="$provider_available" python3 - "$PROVIDER_BASE" "$MODEL" "$ACCOUNT_ID" <<'PY' import json import os import sys provider_base, model, account_id = sys.argv[1:4] +provider_available = os.environ["PROVIDER_AVAILABLE"] == "True" data = json.loads(os.environ["POLICY_JSON"]) policy = data.get("provider_policy") or {"version": 1} accounts = policy.setdefault("provider_accounts", []) slots = policy.setdefault("model_slots", {}) +changed = False + +legacy_account_ids = { + account.get("id") + for account in accounts + if str(account.get("id", "")).startswith("legacy-") +} +if legacy_account_ids: + accounts[:] = [ + account + for account in accounts + if account.get("id") not in legacy_account_ids + ] + for slot, current in list(slots.items()): + if current.get("provider_account_id") in legacy_account_ids: + del slots[slot] + changed = True + +slots.setdefault("memory_search", { + "provider_account_id": None, + "model_id": "local_wiki", + "options": {}, +}) + +if not provider_available: + print(json.dumps({ + "changed": changed, + "skipped": True, + "reason": f"default provider unavailable at {provider_base}", + "policy": policy, + })) + raise SystemExit(0) account = next((a for a in accounts if a.get("id") == account_id), None) -changed = False if account is None: accounts.append({ "id": account_id, @@ -64,19 +101,19 @@ for slot, json_mode in ( } changed = True -slots.setdefault("memory_search", { - "provider_account_id": None, - "model_id": "local_wiki", - "options": {}, -}) - -print(json.dumps({"changed": changed, "policy": policy})) 
+print(json.dumps({"changed": changed, "skipped": False, "reason": "", "policy": policy})) PY )" changed="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.load(sys.stdin)["changed"])')" +skipped="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.load(sys.stdin)["skipped"])')" if [ "$changed" != "True" ]; then + if [ "$skipped" = "True" ]; then + reason="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.load(sys.stdin)["reason"])')" + echo "seed_hybrid_defaults: ${reason}; leaving model slots unseeded" + exit 0 + fi echo "seed_hybrid_defaults: model slots already have provider accounts" exit 0 fi @@ -86,4 +123,10 @@ curl -fsS -X PUT "${BASE_URL}/v1/provider-policy" \ -H 'content-type: application/json' \ -d "$body" >/dev/null +if [ "$skipped" = "True" ]; then + reason="$(echo "$seed_body" | python3 -c 'import json, sys; print(json.load(sys.stdin)["reason"])')" + echo "seed_hybrid_defaults: cleaned legacy slots; ${reason}; leaving model slots unseeded" + exit 0 +fi + echo "seed_hybrid_defaults: seeded ${ACCOUNT_ID} -> ${PROVIDER_BASE} (${MODEL})" diff --git a/desktop/run.sh b/desktop/run.sh index 7c431df596a..7b66f07c35b 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -450,7 +450,7 @@ if is_local_daemon_mode; then SEED_SCRIPT="$(cd "$(dirname "$0")/local-backend/tools" && pwd)/seed_hybrid_defaults.sh" if [ -x "$SEED_SCRIPT" ]; then substep "Seeding hybrid provider defaults (if unset)" - "$SEED_SCRIPT" || substep "Warning: hybrid provider seed failed (non-fatal)" + bash "$SEED_SCRIPT" || substep "Warning: hybrid provider seed failed (non-fatal)" fi fi fi From 370e25e1b3e2ba6d276897c5b52dcb8ee1acbaeb Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 13:16:41 -0400 Subject: [PATCH 56/58] Track hybrid local tmux launcher --- .gitignore | 2 + scripts/hybrid-local.sh | 227 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 229 insertions(+) create mode 100755 scripts/hybrid-local.sh diff --git a/.gitignore 
b/.gitignore index de226b0f9e0..7c11b0372f3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,8 @@ venv/ .DS_Store dump/ /scripts/ +!/scripts/ +!/scripts/hybrid-local.sh *.zip *.wav node_modules diff --git a/scripts/hybrid-local.sh b/scripts/hybrid-local.sh new file mode 100755 index 00000000000..4998565527c --- /dev/null +++ b/scripts/hybrid-local.sh @@ -0,0 +1,227 @@ +#!/usr/bin/env bash +# Start or stop the hybrid local stack: omi-local-backend daemon + Omi Dev desktop. +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +TMUX_SESSION="${OMI_HYBRID_LOCAL_TMUX_SESSION:-omi-hybrid-local}" +DEV_APP_NAME="${OMI_APP_NAME:-Omi Dev}" +DAEMON_HEALTH_WAIT_SECS="${OMI_DAEMON_HEALTH_WAIT_SECS:-180}" + +export OMI_LOCAL_DAEMON_URL="${OMI_LOCAL_DAEMON_URL:-http://127.0.0.1:8765}" +export OMI_LOCAL_BACKEND_HOST="${OMI_LOCAL_BACKEND_HOST:-127.0.0.1}" +export OMI_LOCAL_BACKEND_PORT="${OMI_LOCAL_BACKEND_PORT:-8765}" +export OMI_LOCAL_BACKEND_DATA_DIR="${OMI_LOCAL_BACKEND_DATA_DIR:-/tmp/omi-local-mvp}" +export OMI_LOCAL_DAEMON_LOG="${OMI_LOCAL_DAEMON_LOG:-/tmp/omi-local-backend-dev.log}" + +# Shared hybrid desktop env (matches local-mvp-runbook.md). 
+hybrid_desktop_env() { + export OMI_DESKTOP_BACKEND_MODE=local + export OMI_LOCAL_DAEMON_SUPERVISE=0 + export OMI_LOCAL_DAEMON_URL + export OMI_LOCAL_BACKEND_PORT + export OMI_PYTHON_API_URL="${OMI_PYTHON_API_URL:-http://omi-cloud-invalid:9001}" + export OMI_DESKTOP_API_URL="${OMI_DESKTOP_API_URL:-http://omi-rust-invalid:9002}" +} + +local_daemon_health_ok() { + curl -fsS "${OMI_LOCAL_DAEMON_URL}/health" >/dev/null 2>&1 +} + +daemon_port() { + python3 - "$OMI_LOCAL_DAEMON_URL" <<'PY' +from urllib.parse import urlparse +import sys + +parsed = urlparse(sys.argv[1]) +print(parsed.port or 8765) +PY +} + +kill_daemon_listeners() { + local port + port="$(daemon_port)" + if command -v lsof >/dev/null 2>&1; then + local pids + pids="$(lsof -ti "tcp:${port}" -sTCP:LISTEN 2>/dev/null || true)" + if [ -n "$pids" ]; then + echo "Stopping local daemon listener(s) on port ${port}: ${pids}" + # shellcheck disable=SC2086 + kill ${pids} 2>/dev/null || true + fi + fi + if pgrep -x omi-local-backend >/dev/null 2>&1; then + echo "Stopping omi-local-backend process(es)" + pkill -x omi-local-backend 2>/dev/null || true + fi +} + +kill_dev_desktop() { + if pgrep -f "${DEV_APP_NAME}.app" >/dev/null 2>&1; then + echo "Stopping ${DEV_APP_NAME}.app" + pkill -f "${DEV_APP_NAME}.app" 2>/dev/null || true + fi +} + +attach_tmux_session() { + if [ "${OMI_HYBRID_LOCAL_ATTACH:-1}" = "0" ]; then + echo "tmux session '${TMUX_SESSION}' is running (attach skipped)." + echo " tmux attach -t ${TMUX_SESSION}" + return 0 + fi + + if [ -n "${TMUX:-}" ]; then + echo "Switching tmux client to session '${TMUX_SESSION}'..." 
+ tmux switch-client -t "$TMUX_SESSION" + return 0 + fi + + exec tmux attach -t "$TMUX_SESSION" +} + +daemon_pane_command() { + if local_daemon_health_ok; then + cat </dev/null 2>&1; do + if [ "\$elapsed" -ge "${DAEMON_HEALTH_WAIT_SECS}" ]; then + echo "ERROR: Timed out after ${DAEMON_HEALTH_WAIT_SECS}s waiting for \${OMI_LOCAL_DAEMON_URL}/health" + echo 'Fix the daemon pane above, then re-run: make serve-local' + exec bash -l + fi + sleep 1 + elapsed=\$((elapsed + 1)) +done +SEED_SCRIPT="${ROOT_DIR}/desktop/local-backend/tools/seed_hybrid_defaults.sh" +if [ -x "\$SEED_SCRIPT" ]; then + echo 'Seeding hybrid provider defaults (if unset)...' + bash "\$SEED_SCRIPT" || echo 'Warning: hybrid provider seed failed (non-fatal)' +fi +echo 'Daemon healthy — starting ./run.sh (first Swift build can take several minutes)...' +RUN_LOG="\${OMI_HYBRID_DESKTOP_RUN_LOG:-/tmp/omi-hybrid-desktop-run.log}" +echo "run.sh output: \$RUN_LOG" +exec ./run.sh 2>&1 | tee "\$RUN_LOG" +EOF +} + +serve_with_tmux() { + local daemon_cmd desktop_cmd + daemon_cmd="$(daemon_pane_command)" + desktop_cmd="$(desktop_pane_command)" + + tmux new-session -d -s "$TMUX_SESSION" -n hybrid -c "${ROOT_DIR}/desktop/local-backend" bash -lc "$daemon_cmd" + tmux split-window -v -t "${TMUX_SESSION}:0" -c "${ROOT_DIR}/desktop" bash -lc "$desktop_cmd" + tmux select-pane -t "${TMUX_SESSION}:0.0" + tmux set-option -t "$TMUX_SESSION" remain-on-exit on +} + +serve_single_process() { + echo "tmux not found; starting desktop/run.sh with daemon supervision." + echo "Daemon log: ${OMI_LOCAL_DAEMON_LOG}" + cd "${ROOT_DIR}/desktop" + hybrid_desktop_env + export OMI_LOCAL_DAEMON_SUPERVISE=1 + exec ./run.sh +} + +cmd_up() { + if command -v tmux >/dev/null 2>&1 && tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + echo "tmux session '${TMUX_SESSION}' is already running." 
+ attach_tmux_session + return 0 + fi + + if command -v tmux >/dev/null 2>&1; then + if local_daemon_health_ok; then + echo "Local daemon already healthy at ${OMI_LOCAL_DAEMON_URL} — starting desktop pane only in tmux." + else + echo "Starting hybrid local stack in tmux session '${TMUX_SESSION}'" + fi + echo " top pane: desktop/local-backend (or status if already running)" + echo " bottom pane: desktop/run.sh (Omi Dev, hybrid env)" + echo "Teardown: make down-local" + if [ -n "${TMUX:-}" ]; then + echo "" + echo "Note: you are already inside tmux; this will switch you to '${TMUX_SESSION}'." + echo " To start detached instead: OMI_HYBRID_LOCAL_ATTACH=0 make serve-local" + fi + serve_with_tmux + attach_tmux_session + return 0 + fi + + serve_single_process +} + +cmd_down() { + if command -v tmux >/dev/null 2>&1 && tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + echo "Killing tmux session '${TMUX_SESSION}'" + tmux kill-session -t "$TMUX_SESSION" + fi + kill_daemon_listeners + kill_dev_desktop + echo "Hybrid local stack stopped." 
+} + +usage() { + cat < + + up Start hybrid local backend + desktop (tmux when available) + down Stop tmux session, local daemon, and Omi Dev.app + +Environment (optional): + OMI_HYBRID_LOCAL_TMUX_SESSION tmux session name (default: omi-hybrid-local) + OMI_HYBRID_LOCAL_ATTACH=0 start tmux detached (do not attach/switch) + OMI_LOCAL_DAEMON_URL daemon base URL (default: http://127.0.0.1:8765) + OMI_LOCAL_BACKEND_DATA_DIR SQLite data dir (default: /tmp/omi-local-mvp) + OMI_DAEMON_HEALTH_WAIT_SECS wait for /health (default: 180) + OMI_APP_NAME desktop bundle name (default: Omi Dev) + +See desktop/local-backend/docs/local-mvp-runbook.md +EOF +} + +main() { + case "${1:-}" in + up) cmd_up ;; + down) cmd_down ;; + *) + usage >&2 + exit 1 + ;; + esac +} + +main "$@" From 04275055c0528332dafc303eac6de3fc13d4b6ec Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 15:35:19 -0400 Subject: [PATCH 57/58] Add local ASR add-on flow --- Makefile | 9 +- desktop/.env.example | 5 + desktop/Backend-Rust/src/routes/updates.rs | 300 +++++++- desktop/Desktop/Sources/APIClient.swift | 3 +- desktop/Desktop/Sources/AppState.swift | 70 +- .../PushToTalkManager.swift | 10 + .../Desktop/Sources/HybridChatClient.swift | 182 ++++- .../LocalASRAddonManager.swift | 691 ++++++++++++++++++ .../LocalTranscription/LocalASRRuntime.swift | 90 ++- .../TranscriptionComparisonHarness.swift | 439 +++++++++++ .../Sources/MainWindow/DesktopHomeView.swift | 21 +- .../MainWindow/Pages/SettingsPage.swift | 482 ++++++++++-- .../Pages/ShortcutsSettingsSection.swift | 17 +- .../Sources/MainWindow/SettingsSidebar.swift | 11 +- .../UI/InsightTestRunnerWindow.swift | 5 +- .../Sources/Providers/ChatProvider.swift | 69 +- .../Rewind/Core/TranscriptionStorage.swift | 5 +- .../Sources/TranscriptionRetryService.swift | 22 +- .../Desktop/Tests/HybridChatClientTests.swift | 109 ++- .../Tests/TranscriptComparisonTests.swift | 33 + .../TranscriptionProviderPolicyTests.swift | 102 ++- 
desktop/local-asr-addon/build_dev_fixture.sh | 133 ++++ .../local-asr-addon/build_model_artifact.sh | 33 + .../local-asr-addon/build_runtime_artifact.sh | 62 ++ desktop/local-asr-helper/README.md | 32 +- .../local-backend/docs/local-mvp-runbook.md | 20 + desktop/local-backend/src/routes.rs | 2 +- desktop/local-backend/src/storage.rs | 81 ++ desktop/run.sh | 5 + scripts/hybrid-local.sh | 32 + 30 files changed, 2888 insertions(+), 187 deletions(-) create mode 100644 desktop/Desktop/Sources/LocalTranscription/LocalASRAddonManager.swift create mode 100644 desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift create mode 100755 desktop/local-asr-addon/build_dev_fixture.sh create mode 100755 desktop/local-asr-addon/build_model_artifact.sh create mode 100755 desktop/local-asr-addon/build_runtime_artifact.sh diff --git a/Makefile b/Makefile index 0e49da40ed6..0b4bfb992e6 100644 --- a/Makefile +++ b/Makefile @@ -1,19 +1,26 @@ # Hybrid local development (desktop local daemon + Omi Dev app). 
# See desktop/local-backend/docs/local-mvp-runbook.md -.PHONY: help serve-local down-local +.PHONY: help serve-local down-local local-asr-fixture help: @echo "Hybrid local development targets:" @echo " make serve-local Start omi-local-backend + Omi Dev (tmux when available)" @echo " make down-local Stop tmux session, daemon, and dev desktop app" + @echo " make local-asr-fixture" + @echo " Build a production-shaped Local Whisper fixture manifest" @echo "" @echo "Optional env: OMI_LOCAL_DAEMON_URL, OMI_LOCAL_BACKEND_DATA_DIR," @echo " OMI_HYBRID_LOCAL_TMUX_SESSION (default: omi-hybrid-local)" @echo " OMI_HYBRID_LOCAL_ATTACH=0 start tmux detached" + @echo " OMI_LOCAL_ASR_PYTHON Python with mlx-whisper for local ASR fixture" + @echo " OMI_LOCAL_ASR_FIXTURE_DIR fixture output dir (default: /tmp/omi-local-asr-fixture)" serve-local: @bash "$(CURDIR)/scripts/hybrid-local.sh" up down-local: @bash "$(CURDIR)/scripts/hybrid-local.sh" down + +local-asr-fixture: + @bash "$(CURDIR)/desktop/local-asr-addon/build_dev_fixture.sh" diff --git a/desktop/.env.example b/desktop/.env.example index f00efd9adb1..571c88f5f2d 100644 --- a/desktop/.env.example +++ b/desktop/.env.example @@ -48,6 +48,11 @@ OMI_PYTHON_API_URL=https://api.omi.me # OMI_LOCAL_DAEMON_URL=http://127.0.0.1:8765 # OMI_LOCAL_DAEMON_SUPERVISE=1 +# Optional production-shaped Local Whisper add-on manifest override. +# Use this for local dev with fixture artifacts; Settings still runs the normal +# manifest/download/checksum/extract/activate install path. 
+# OMI_LOCAL_ASR_MANIFEST_URL=file:///path/to/local-asr/manifest.json + # Firebase Web API key — fetched from backend via /v1/config/api-keys # Only set this for local dev without a backend running # FIREBASE_API_KEY= diff --git a/desktop/Backend-Rust/src/routes/updates.rs b/desktop/Backend-Rust/src/routes/updates.rs index 7753b421cc4..0e8295cc6ba 100644 --- a/desktop/Backend-Rust/src/routes/updates.rs +++ b/desktop/Backend-Rust/src/routes/updates.rs @@ -41,6 +41,35 @@ pub struct ReleaseInfo { pub channel: Option, } +#[derive(Debug, Serialize)] +struct LocalAsrAddonManifest { + version: u32, + runtime: LocalAsrRuntimeArtifact, + models: Vec, +} + +#[derive(Debug, Serialize)] +struct LocalAsrRuntimeArtifact { + version: String, + platform: String, + arch: String, + url: String, + sha256: String, + size_bytes: i64, + minimum_app_version: Option, +} + +#[derive(Debug, Serialize)] +struct LocalAsrModelArtifact { + model: String, + version: String, + url: String, + sha256: String, + size_bytes: i64, +} + +const LOCAL_ASR_MODELS: &[&str] = &["tiny", "base", "small", "medium", "large_v3_turbo"]; + /// Generate Sparkle 2.0 appcast XML /// /// Picks the latest live release per channel (stable, beta, staging). @@ -48,13 +77,15 @@ pub struct ReleaseInfo { /// Releases with channel="stable" get no XML tag (Sparkle default = stable). /// Releases with channel=None (unpromoted) get `staging`. 
fn generate_appcast_xml(releases: &[ReleaseInfo], platform: &str) -> String { - let mut xml = String::from(r#" + let mut xml = String::from( + r#" Omi Desktop Updates Omi AI Desktop Application en -"#); +"#, + ); // Deduplicate: pick the latest live release per channel let mut seen_channels = std::collections::HashSet::new(); @@ -62,7 +93,10 @@ fn generate_appcast_xml(releases: &[ReleaseInfo], platform: &str) -> String { if !release.is_live { continue; } - let ch_key = release.channel.clone().unwrap_or_else(|| "staging".to_string()); + let ch_key = release + .channel + .clone() + .unwrap_or_else(|| "staging".to_string()); if !seen_channels.insert(ch_key) { continue; // already emitted an item for this channel } @@ -71,13 +105,16 @@ fn generate_appcast_xml(releases: &[ReleaseInfo], platform: &str) -> String { let changelog_html = if release.changelog.is_empty() { "

Bug fixes and improvements.

".to_string() } else { - let items: String = release.changelog.iter() + let items: String = release + .changelog + .iter() .map(|c| format!("
  • {}
  • ", c)) .collect(); format!("
      {}
    ", items) }; - xml.push_str(&format!(r#" + xml.push_str(&format!( + r#" Omi {} {} {} @@ -104,7 +141,10 @@ fn generate_appcast_xml(releases: &[ReleaseInfo], platform: &str) -> String { match release.channel.as_deref() { Some("stable") => {} // No tag = Sparkle default channel (stable) Some(ch) if !ch.is_empty() => { - xml.push_str(&format!(" {}\n", ch)); + xml.push_str(&format!( + " {}\n", + ch + )); } _ => { // None or empty = unpromoted, treat as staging @@ -124,15 +164,15 @@ fn generate_appcast_xml(releases: &[ReleaseInfo], platform: &str) -> String { } /// GET /appcast.xml - Sparkle appcast feed -async fn get_appcast( - State(state): State, - Query(query): Query, -) -> Response { +async fn get_appcast(State(state): State, Query(query): Query) -> Response { // Fetch all releases — generate_appcast_xml handles filtering and per-channel dedup let releases = match state.firestore.get_desktop_releases().await { Ok(releases) => releases, Err(e) => { - tracing::warn!("Failed to fetch releases from Firestore: {}, using fallback", e); + tracing::warn!( + "Failed to fetch releases from Firestore: {}, using fallback", + e + ); // Return empty appcast if no releases found vec![] } @@ -163,7 +203,11 @@ async fn get_latest_version(State(state): State) -> impl IntoResponse match state.firestore.get_desktop_releases().await { Ok(releases) => { // Return the latest live stable release (channel == "stable") - if let Some(latest) = releases.into_iter().filter(|r| r.is_live && r.channel.as_deref() == Some("stable")).next() { + if let Some(latest) = releases + .into_iter() + .filter(|r| r.is_live && r.channel.as_deref() == Some("stable")) + .next() + { axum::Json(LatestVersionResponse { version: latest.version, build_number: latest.build_number, @@ -172,11 +216,7 @@ async fn get_latest_version(State(state): State) -> impl IntoResponse }) .into_response() } else { - ( - axum::http::StatusCode::NOT_FOUND, - "No live releases found", - ) - .into_response() + 
(axum::http::StatusCode::NOT_FOUND, "No live releases found").into_response() } } Err(e) => { @@ -197,7 +237,11 @@ async fn download_redirect(State(state): State) -> impl IntoResponse { match state.firestore.get_desktop_releases().await { Ok(releases) => { // Return the latest live stable release for download (channel == "stable") - if let Some(latest) = releases.into_iter().filter(|r| r.is_live && r.channel.as_deref() == Some("stable")).next() { + if let Some(latest) = releases + .into_iter() + .filter(|r| r.is_live && r.channel.as_deref() == Some("stable")) + .next() + { // Serve from GCS bucket for direct download (avoids multi-hop GitHub redirects) let gcs_url = format!( "https://storage.googleapis.com/omi_macos_updates/releases/v{}/Omi.Beta.dmg", @@ -206,11 +250,7 @@ async fn download_redirect(State(state): State) -> impl IntoResponse { tracing::info!("Redirecting download to GCS: {}", gcs_url); axum::response::Redirect::temporary(&gcs_url).into_response() } else { - ( - StatusCode::NOT_FOUND, - "No live releases found", - ) - .into_response() + (StatusCode::NOT_FOUND, "No live releases found").into_response() } } Err(e) => { @@ -224,6 +264,102 @@ async fn download_redirect(State(state): State) -> impl IntoResponse { } } +/// GET /v1/local-asr/manifest - Local Whisper add-on artifact manifest +/// +/// The manifest is intentionally separate from Sparkle appcast so Local Whisper +/// runtime/model updates can ship independently from app updates. 
+async fn get_local_asr_manifest() -> Response { + match build_local_asr_manifest_from_env() { + Ok(manifest) => ( + [ + (header::CONTENT_TYPE, "application/json; charset=utf-8"), + (header::CACHE_CONTROL, "max-age=300"), + ], + Json(manifest), + ) + .into_response(), + Err(message) => (StatusCode::SERVICE_UNAVAILABLE, message).into_response(), + } +} + +fn build_local_asr_manifest_from_env() -> Result { + let base_url = std::env::var("LOCAL_ASR_ADDON_BASE_URL").unwrap_or_else(|_| { + "https://storage.googleapis.com/omi_macos_updates/local-asr".to_string() + }); + let runtime_version = required_env("LOCAL_ASR_RUNTIME_VERSION")?; + let runtime_sha256 = required_env("LOCAL_ASR_RUNTIME_SHA256")?; + let runtime_size = required_env_i64("LOCAL_ASR_RUNTIME_SIZE_BYTES")?; + let runtime_url = std::env::var("LOCAL_ASR_RUNTIME_URL").unwrap_or_else(|_| { + format!( + "{}/runtime-macos-arm64-{}.zip", + base_url.trim_end_matches('/'), + runtime_version + ) + }); + + let mut models = Vec::new(); + for model in LOCAL_ASR_MODELS { + let key_model = model.to_ascii_uppercase(); + let version_key = format!("LOCAL_ASR_MODEL_{}_VERSION", key_model); + let sha_key = format!("LOCAL_ASR_MODEL_{}_SHA256", key_model); + let size_key = format!("LOCAL_ASR_MODEL_{}_SIZE_BYTES", key_model); + + let Ok(version) = std::env::var(&version_key) else { + continue; + }; + let sha256 = required_env(&sha_key)?; + let size_bytes = required_env_i64(&size_key)?; + let url = + std::env::var(format!("LOCAL_ASR_MODEL_{}_URL", key_model)).unwrap_or_else(|_| { + format!( + "{}/model-{}-{}.zip", + base_url.trim_end_matches('/'), + model.replace('_', "-"), + version + ) + }); + models.push(LocalAsrModelArtifact { + model: (*model).to_string(), + version, + url, + sha256, + size_bytes, + }); + } + + if models.is_empty() { + return Err("Local ASR add-on manifest is not configured: no model artifacts".to_string()); + } + + Ok(LocalAsrAddonManifest { + version: 1, + runtime: LocalAsrRuntimeArtifact { + version: 
runtime_version, + platform: "macos".to_string(), + arch: "arm64".to_string(), + url: runtime_url, + sha256: runtime_sha256, + size_bytes: runtime_size, + minimum_app_version: std::env::var("LOCAL_ASR_MINIMUM_APP_VERSION").ok(), + }, + models, + }) +} + +fn required_env(key: &str) -> Result { + std::env::var(key) + .map(|value| value.trim().to_string()) + .ok() + .filter(|value| !value.is_empty()) + .ok_or_else(|| format!("Local ASR add-on manifest is not configured: missing {key}")) +} + +fn required_env_i64(key: &str) -> Result { + required_env(key)? + .parse::() + .map_err(|_| format!("Local ASR add-on manifest is not configured: invalid {key}")) +} + /// Request body for creating a release #[derive(Debug, Deserialize)] pub struct CreateReleaseRequest { @@ -358,11 +494,20 @@ async fn promote_release( ); } - match state.firestore.promote_desktop_release(&request.doc_id).await { + match state + .firestore + .promote_desktop_release(&request.doc_id) + .await + { Ok((old_channel, new_channel)) => { let old_display = old_channel.clone(); let new_display = new_channel.clone(); - tracing::info!("Promoted release {}: {} → {}", request.doc_id, old_display, new_display); + tracing::info!( + "Promoted release {}: {} → {}", + request.doc_id, + old_display, + new_display + ); let message = format!("Release promoted from {} to {}", old_display, new_display); ( StatusCode::OK, @@ -395,7 +540,12 @@ async fn promote_release( mod tests { use super::*; - fn make_release(version: &str, build: u32, channel: Option<&str>, is_live: bool) -> ReleaseInfo { + fn make_release( + version: &str, + build: u32, + channel: Option<&str>, + is_live: bool, + ) -> ReleaseInfo { ReleaseInfo { version: version.to_string(), build_number: build, @@ -409,49 +559,83 @@ mod tests { } } + fn clear_local_asr_manifest_env() { + for key in [ + "LOCAL_ASR_ADDON_BASE_URL", + "LOCAL_ASR_RUNTIME_VERSION", + "LOCAL_ASR_RUNTIME_SHA256", + "LOCAL_ASR_RUNTIME_SIZE_BYTES", + "LOCAL_ASR_RUNTIME_URL", + 
"LOCAL_ASR_MINIMUM_APP_VERSION", + ] { + std::env::remove_var(key); + } + for model in LOCAL_ASR_MODELS { + let key_model = model.to_ascii_uppercase(); + for suffix in ["VERSION", "SHA256", "SIZE_BYTES", "URL"] { + std::env::remove_var(format!("LOCAL_ASR_MODEL_{key_model}_{suffix}")); + } + } + } + #[test] fn test_null_channel_gets_staging_tag() { let releases = vec![make_release("0.1.0", 100, None, true)]; let xml = generate_appcast_xml(&releases, "macos"); - assert!(xml.contains("staging"), - "null channel should emit staging tag, got:\n{}", xml); + assert!( + xml.contains("staging"), + "null channel should emit staging tag, got:\n{}", + xml + ); } #[test] fn test_stable_channel_gets_no_tag() { let releases = vec![make_release("0.2.0", 200, Some("stable"), true)]; let xml = generate_appcast_xml(&releases, "macos"); - assert!(!xml.contains(""), - "stable channel should emit no channel tag, got:\n{}", xml); + assert!( + !xml.contains(""), + "stable channel should emit no channel tag, got:\n{}", + xml + ); } #[test] fn test_beta_channel_gets_beta_tag() { let releases = vec![make_release("0.3.0", 300, Some("beta"), true)]; let xml = generate_appcast_xml(&releases, "macos"); - assert!(xml.contains("beta"), - "beta channel should emit beta tag, got:\n{}", xml); + assert!( + xml.contains("beta"), + "beta channel should emit beta tag, got:\n{}", + xml + ); } #[test] fn test_staging_channel_gets_staging_tag() { let releases = vec![make_release("0.4.0", 400, Some("staging"), true)]; let xml = generate_appcast_xml(&releases, "macos"); - assert!(xml.contains("staging"), - "staging channel should emit staging tag, got:\n{}", xml); + assert!( + xml.contains("staging"), + "staging channel should emit staging tag, got:\n{}", + xml + ); } #[test] fn test_dedup_null_and_staging_same_group() { // null-channel and staging-channel should deduplicate together let releases = vec![ - make_release("0.5.0", 500, None, true), // null → staging group + make_release("0.5.0", 500, None, 
true), // null → staging group make_release("0.4.0", 400, Some("staging"), true), // explicit staging ]; let xml = generate_appcast_xml(&releases, "macos"); // Only the first (higher build) should appear assert!(xml.contains("0.5.0"), "first staging release should appear"); - assert!(!xml.contains("0.4.0"), "second staging release should be deduped"); + assert!( + !xml.contains("0.4.0"), + "second staging release should be deduped" + ); } #[test] @@ -462,7 +646,10 @@ mod tests { ]; let xml = generate_appcast_xml(&releases, "macos"); assert!(xml.contains("1.0.0"), "stable release should appear"); - assert!(xml.contains("0.9.0"), "null/staging release should also appear (different group)"); + assert!( + xml.contains("0.9.0"), + "null/staging release should also appear (different group)" + ); } #[test] @@ -471,11 +658,50 @@ mod tests { let xml = generate_appcast_xml(&releases, "macos"); assert!(!xml.contains("0.1.0"), "non-live release should not appear"); } + + #[test] + fn test_local_asr_manifest_uses_configured_artifacts() { + clear_local_asr_manifest_env(); + std::env::set_var( + "LOCAL_ASR_ADDON_BASE_URL", + "https://downloads.example.com/local-asr", + ); + std::env::set_var("LOCAL_ASR_RUNTIME_VERSION", "2026.05.20"); + std::env::set_var("LOCAL_ASR_RUNTIME_SHA256", "runtime-sha"); + std::env::set_var("LOCAL_ASR_RUNTIME_SIZE_BYTES", "123"); + std::env::set_var("LOCAL_ASR_MINIMUM_APP_VERSION", "0.2.0"); + std::env::set_var("LOCAL_ASR_MODEL_SMALL_VERSION", "mlx-2026.05.20"); + std::env::set_var("LOCAL_ASR_MODEL_SMALL_SHA256", "model-sha"); + std::env::set_var("LOCAL_ASR_MODEL_SMALL_SIZE_BYTES", "456"); + + let manifest = build_local_asr_manifest_from_env().expect("manifest should build"); + + assert_eq!(manifest.version, 1); + assert_eq!(manifest.runtime.version, "2026.05.20"); + assert_eq!( + manifest.runtime.url, + "https://downloads.example.com/local-asr/runtime-macos-arm64-2026.05.20.zip" + ); + assert_eq!(manifest.runtime.size_bytes, 123); + assert_eq!( + 
manifest.runtime.minimum_app_version.as_deref(), + Some("0.2.0") + ); + assert_eq!(manifest.models.len(), 1); + assert_eq!(manifest.models[0].model, "small"); + assert_eq!( + manifest.models[0].url, + "https://downloads.example.com/local-asr/model-small-mlx-2026.05.20.zip" + ); + assert_eq!(manifest.models[0].size_bytes, 456); + clear_local_asr_manifest_env(); + } } pub fn updates_routes() -> Router { Router::new() .route("/appcast.xml", get(get_appcast)) + .route("/v1/local-asr/manifest", get(get_local_asr_manifest)) .route("/updates/latest", get(get_latest_version)) .route("/updates/releases", post(create_release)) .route("/updates/releases/promote", patch(promote_release)) diff --git a/desktop/Desktop/Sources/APIClient.swift b/desktop/Desktop/Sources/APIClient.swift index 2745ca85e28..3ba1793643a 100644 --- a/desktop/Desktop/Sources/APIClient.swift +++ b/desktop/Desktop/Sources/APIClient.swift @@ -915,7 +915,8 @@ extension APIClient { let _: LocalTranscriptSegmentEnvelope = try await post( "v1/conversations/\(conversationId)/transcript-segments", body: Request( - id: segment.segmentId, + // Provider segment ids are local to their source; let the local daemon mint durable ids. + id: nil, speakerId: String(segment.speaker), speakerLabel: segment.speakerLabel, text: segment.text, diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index f80f1f33642..39d00287125 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -212,6 +212,9 @@ class AppState: ObservableObject { private var localBackgroundASRTask: Task? private(set) var localBackgroundState: LocalBackgroundSessionState? private var localBackgroundSampleCursor: Int64 = 0 + @Published private(set) var transcriptionComparisonSnapshot = + TranscriptionComparisonHarnessSnapshot.idle + private var transcriptionComparisonHarness: TranscriptionComparisonHarness? private var systemAudioCaptureService: Any? 
// SystemAudioCaptureService (macOS 14.4+) private var audioMixer: AudioMixer? private var vadGateService: VADGateService? @@ -1257,6 +1260,16 @@ class AppState: ObservableObject { alert.runModal() } + private func openLocalTranscriptionRepair(message: String) { + log("Transcription: \(message)") + NSApp.activate(ignoringOtherApps: true) + NotificationCenter.default.post( + name: .navigateToTranscriptionSettings, + object: nil, + userInfo: ["highlightedSettingId": "transcription.localWhisperAddon"] + ) + } + // MARK: - Transcription /// Toggle transcription on/off @@ -1293,11 +1306,7 @@ class AppState: ObservableObject { let message = backgroundRouting.unsupportedLocalReason ?? "Local background transcription is selected and will not use the cloud listen path." - log("Transcription: \(message)") - showAlert( - title: "Local Background Transcription", - message: message - ) + openLocalTranscriptionRepair(message: message) return } @@ -1454,7 +1463,8 @@ class AppState: ObservableObject { } /// Start local background transcription without creating a backend `/v4/listen` session. - private func startLocalBackgroundTranscription(source: AudioSource, plan: LocalTranscriptionPlan) { + private func startLocalBackgroundTranscription(source: AudioSource, plan: LocalTranscriptionPlan) + { guard source == .microphone else { showAlert( title: "Local Background Transcription", @@ -1472,7 +1482,7 @@ class AppState: ObservableObject { let message = "Local transcription helper is not available." 
log("Transcription: \(message)") localBackgroundState = .failed - showAlert(title: "Local Background Transcription", message: message) + openLocalTranscriptionRepair(message: message) return } @@ -1482,6 +1492,7 @@ class AppState: ObservableObject { plan: plan, executableURL: executableURL ) + startTranscriptionComparisonHarnessIfNeeded(language: effectiveLanguage) localBackgroundState = .recording localBackgroundSampleCursor = 0 currentConversationSource = .desktop @@ -1545,10 +1556,34 @@ class AppState: ObservableObject { } } + private func startTranscriptionComparisonHarnessIfNeeded(language: String) { + guard TranscriptionComparisonHarness.isEnabled else { + transcriptionComparisonHarness = nil + transcriptionComparisonSnapshot = .idle + return + } + + let harness = TranscriptionComparisonHarness( + language: language, + deepgramAPIKey: TranscriptionComparisonHarness.configuredDeepgramAPIKey(), + onSnapshot: { [weak self] snapshot in + Task { @MainActor in + self?.transcriptionComparisonSnapshot = snapshot + } + } + ) + transcriptionComparisonHarness = harness + transcriptionComparisonSnapshot = harness.snapshot + harness.start() + log("Transcription: Started development Whisper/Deepgram comparison harness") + } + private func initializeOptionalSystemAudioCapture() { let systemAudioDisabled = UserDefaults.standard.bool(forKey: "disableSystemAudioCapture") if systemAudioDisabled { - log("Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)") + log( + "Transcription: System audio capture DISABLED by user preference (disableSystemAudioCapture)" + ) } else if #available(macOS 14.4, *) { systemAudioCaptureService = SystemAudioCaptureService() log("Transcription: System audio capture initialized (macOS 14.4+)") @@ -1643,6 +1678,7 @@ class AppState: ObservableObject { let startTime = Double(localBackgroundSampleCursor) / 16_000.0 localBackgroundSampleCursor += Int64(monoMixed.count / 2) let ingest = 
localBackgroundSession.append(pcmData: monoMixed, startTime: startTime) + transcriptionComparisonHarness?.appendAudio(monoMixed) if !ingest.droppedChunks.isEmpty { log("Transcription: Local background dropped \(ingest.droppedChunks.count) stale chunks") } @@ -1696,6 +1732,7 @@ class AppState: ObservableObject { totalSegmentCount = speakerSegmentReducer.totalSegmentCount totalWordCount = speakerSegmentReducer.totalWordCount LiveTranscriptMonitor.shared.updateSegments(speakerSegments) + transcriptionComparisonHarness?.appendWhisperSegments(normalizedSegments) if let sessionId = currentSessionId { Task { @@ -1894,7 +1931,8 @@ class AppState: ObservableObject { "Transcription: Local daemon finalized stop session \(sid) via upload pipeline (or queued retry)" ) } catch { - logError("Transcription: Local daemon finalize-after-stop failed for \(sid)", error: error) + logError( + "Transcription: Local daemon finalize-after-stop failed for \(sid)", error: error) } await loadConversations() return @@ -1955,6 +1993,7 @@ class AppState: ObservableObject { stopAudioCapture() _ = localBackgroundSession?.finishInput() + transcriptionComparisonHarness?.finish() drainLocalBackgroundASRQueue() Task { @@ -2031,6 +2070,8 @@ class AppState: ObservableObject { localBackgroundSession = nil localBackgroundASRTask = nil localBackgroundSampleCursor = 0 + transcriptionComparisonHarness?.stop() + transcriptionComparisonHarness = nil AnalyticsManager.shared.transcriptionStopped(wordCount: totalWordCount) totalSegmentCount = 0 totalWordCount = 0 @@ -2040,7 +2081,8 @@ class AppState: ObservableObject { } nonisolated static func localConversationTitle(from transcript: String) -> String { - let normalized = transcript + let normalized = + transcript .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) .trimmingCharacters(in: .whitespacesAndNewlines) guard !normalized.isEmpty else { return "Local transcription" } @@ -2211,7 +2253,8 @@ class AppState: ObservableObject { ) } } 
catch { - logError("Transcription: Local daemon finalize failed for session \(sessionId)", error: error) + logError( + "Transcription: Local daemon finalize failed for session \(sessionId)", error: error) return .error(error.localizedDescription) } } @@ -2397,7 +2440,8 @@ class AppState: ObservableObject { do { let fetchedConversations = try await conversationsTask - conversations = mergeLocalOnlyConversations(cachedLocalConversations, with: fetchedConversations) + conversations = mergeLocalOnlyConversations( + cachedLocalConversations, with: fetchedConversations) log( "Conversations: Refreshed \(fetchedConversations.count) from API (starred=\(showStarredOnly), date=\(selectedDateFilter?.description ?? "nil"))" ) @@ -3493,6 +3537,8 @@ extension Notification.Name { static let navigateToFloatingBarSettings = Notification.Name("navigateToFloatingBarSettings") /// Posted to navigate to AI Chat settings static let navigateToAIChatSettings = Notification.Name("navigateToAIChatSettings") + /// Posted to navigate to Transcription settings + static let navigateToTranscriptionSettings = Notification.Name("navigateToTranscriptionSettings") /// Posted when a new Rewind frame is captured (for live frame count updates) static let rewindFrameCaptured = Notification.Name("rewindFrameCaptured") /// Posted when Rewind page finishes loading initial data diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift index 9118e2a030d..3c985780b22 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift @@ -48,6 +48,7 @@ class PushToTalkManager: ObservableObject { private var currentContextSnapshot: PTTContextSnapshot? private var contextCaptureTask: Task? 
private var currentTranscriptionProvider: TranscriptionProviderKind = .cloud + private var shortcutCaptureSuspended = false // Batch mode: accumulate raw audio for post-recording transcription private var batchAudioBuffer = Data() @@ -74,6 +75,14 @@ class PushToTalkManager: ObservableObject { log("PushToTalkManager: cleanup complete") } + func setShortcutCaptureSuspended(_ suspended: Bool) { + shortcutCaptureSuspended = suspended + if suspended, state != .idle { + stopListening() + state = .idle + } + } + // MARK: - Event Monitors private func installEventMonitors() { @@ -115,6 +124,7 @@ class PushToTalkManager: ObservableObject { // MARK: - Shortcut Handling private func handleShortcutEvent(_ event: NSEvent) { + guard !shortcutCaptureSuspended else { return } guard ShortcutSettings.shared.pttEnabled else { return } let shortcut = ShortcutSettings.shared.pttShortcut diff --git a/desktop/Desktop/Sources/HybridChatClient.swift b/desktop/Desktop/Sources/HybridChatClient.swift index 5abe5d9e5bb..743c315f736 100644 --- a/desktop/Desktop/Sources/HybridChatClient.swift +++ b/desktop/Desktop/Sources/HybridChatClient.swift @@ -1,6 +1,6 @@ import Foundation -/// Direct OpenAI-compatible chat completions for hybrid local daemon mode (no pi-mono proxy). +/// Direct OpenAI-compatible chat completions for desktop provider routes. enum HybridChatClient { struct ProviderConfig: Equatable { @@ -52,7 +52,8 @@ enum HybridChatClient { switch self { case .notConfigured(let reason): if reason.isEmpty { - return "Chat model slot is not configured. Configure the chat slot in local provider policy." + return + "Chat model slot is not configured. Configure the chat slot in local provider policy." 
} return "Chat model slot is not configured: \(reason)" case .invalidSettings: @@ -65,17 +66,61 @@ enum HybridChatClient { } } - static func isEnabled() -> Bool { - if CodexAuthService.isActive { + enum Route: Equatable { + case directCodex(ProviderConfig) + case directDaemonChatSlot + case agentBridge(reason: String) + + var usesDirectProvider: Bool { + switch self { + case .directCodex, .directDaemonChatSlot: return true + case .agentBridge: + return false } - guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + } + + var supportsInlineImages: Bool { + switch self { + case .directCodex: + return true + case .directDaemonChatSlot: return false + case .agentBridge: + return true + } + } + + var displayName: String { + switch self { + case .directCodex: + return "ChatGPT plan" + case .directDaemonChatSlot: + return "Local provider policy" + case .agentBridge: + return "Agent bridge" + } + } + } + + static func currentRoute() -> Route { + if CodexAuthService.isActive { + if let config = codexChatConfig() { + return .directCodex(config) + } + return .agentBridge( + reason: "ChatGPT plan is connected, but no Codex auth snapshot is available.") + } + guard DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon else { + return .agentBridge(reason: "Cloud backend mode uses the agent bridge.") } guard DesktopBackendEnvironment.isCapability(.directChat, availableIn: .localDaemon) else { - return false + return .agentBridge( + reason: DesktopBackendEnvironment.unavailableReason(for: .directChat, in: .localDaemon) + ?? "Direct local chat is unavailable." + ) } - return true + return .directDaemonChatSlot } static func resolveEffectiveChatConfig( @@ -84,23 +129,42 @@ enum HybridChatClient { HybridProviderPolicy.chatProviderConfig(from: response) } - /// Resolves the daemon chat slot and completes one chat turn (non-streaming). 
- static func completeFromDaemonSettings( + /// Completes one non-streaming turn through the active direct provider route. + static func completeWithActiveDirectProvider( systemPrompt: String, conversationMessages: [(role: String, text: String)], - userMessage: String + userMessage: String, + imageData: Data? = nil, + session: URLSession = .shared, + ensureCodexProxy: Bool = true ) async throws -> CompletionResult { - if CodexAuthService.isActive { + switch currentRoute() { + case .directCodex(let config): + if ensureCodexProxy { await CodexProxyService.shared.ensureRunning() } + return try await completeOpenAICompatible( + config: config, + systemPrompt: systemPrompt, + conversationMessages: conversationMessages, + userMessage: userMessage, + imageData: imageData, + session: session + ) + case .directDaemonChatSlot: let resolution = try await APIClient.shared.resolveSelectedBackendProviderSlot( HybridProviderPolicy.chatSlot) return try await complete( systemPrompt: systemPrompt, conversationMessages: conversationMessages, userMessage: userMessage, - slotResolution: resolution + slotResolution: resolution, + imageData: imageData, + session: session ) + case .agentBridge(let reason): + throw ClientError.notConfigured(reason) + } } static func complete( @@ -108,6 +172,7 @@ enum HybridChatClient { conversationMessages: [(role: String, text: String)], userMessage: String, slotResolution: HybridProviderPolicy.SlotResolutionResponse, + imageData: Data? = nil, session: URLSession = .shared ) async throws -> CompletionResult { guard let config = resolveEffectiveChatConfig(from: slotResolution) else { @@ -118,13 +183,82 @@ enum HybridChatClient { systemPrompt: systemPrompt, conversationMessages: conversationMessages, userMessage: userMessage, + imageData: imageData, session: session ) } + private static func codexChatConfig() -> ProviderConfig? 
{ + guard let config = HybridLLMClient.codexProviderConfig() else { return nil } + return ProviderConfig( + baseURL: config.baseURL, + model: config.model, + apiKey: config.apiKey, + providerAccountID: "chatgpt-plan", + providerKind: "openai_compatible", + slotSource: "chatgpt_plan", + resolutionReason: "ChatGPT plan subscription integration" + ) + } + + private enum ChatCompletionContent: Encodable { + case text(String) + case parts([ChatCompletionContentPart]) + + func encode(to encoder: Encoder) throws { + switch self { + case .text(let value): + var container = encoder.singleValueContainer() + try container.encode(value) + case .parts(let parts): + var container = encoder.singleValueContainer() + try container.encode(parts) + } + } + } + + private struct ChatCompletionContentPart: Encodable { + private struct ImageURL: Encodable { + let url: String + } + + private let type: String + private let text: String? + private let imageURL: ImageURL? + + enum CodingKeys: String, CodingKey { + case type + case text + case imageURL = "image_url" + } + + static func text(_ value: String) -> ChatCompletionContentPart { + ChatCompletionContentPart(type: "text", text: value, imageURL: nil) + } + + static func image(_ imageData: Data) -> ChatCompletionContentPart { + ChatCompletionContentPart( + type: "image_url", + text: nil, + imageURL: ImageURL(url: "data:image/png;base64,\(imageData.base64EncodedString())") + ) + } + + func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(type, forKey: .type) + if let text { + try container.encode(text, forKey: .text) + } + if let imageURL { + try container.encode(imageURL, forKey: .imageURL) + } + } + } + private struct ChatCompletionMessage: Encodable { let role: String - let content: String + let content: ChatCompletionContent } private struct ChatCompletionRequest: Encodable { @@ -138,6 +272,7 @@ enum HybridChatClient { systemPrompt: String, conversationMessages: 
[(role: String, text: String)], userMessage: String, + imageData: Data?, session: URLSession ) async throws -> CompletionResult { let base = config.baseURL.hasSuffix("/") ? String(config.baseURL.dropLast()) : config.baseURL @@ -146,12 +281,24 @@ enum HybridChatClient { } var apiMessages: [ChatCompletionMessage] = [ - ChatCompletionMessage(role: "system", content: systemPrompt) + ChatCompletionMessage(role: "system", content: .text(systemPrompt)) ] for turn in conversationMessages { - apiMessages.append(ChatCompletionMessage(role: turn.role, content: turn.text)) + apiMessages.append(ChatCompletionMessage(role: turn.role, content: .text(turn.text))) + } + if let imageData { + apiMessages.append( + ChatCompletionMessage( + role: "user", + content: .parts([ + .text(userMessage), + .image(imageData), + ]) + ) + ) + } else { + apiMessages.append(ChatCompletionMessage(role: "user", content: .text(userMessage))) } - apiMessages.append(ChatCompletionMessage(role: "user", content: userMessage)) var request = URLRequest(url: url) request.httpMethod = "POST" @@ -172,7 +319,8 @@ enum HybridChatClient { throw ClientError.invalidResponse } guard (200..<300).contains(http.statusCode) else { - throw ClientError.providerError(parseProviderErrorBody(data: data, statusCode: http.statusCode)) + throw ClientError.providerError( + parseProviderErrorBody(data: data, statusCode: http.statusCode)) } guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any], let choices = json["choices"] as? [[String: Any]], diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRAddonManager.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRAddonManager.swift new file mode 100644 index 00000000000..f2533844adc --- /dev/null +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRAddonManager.swift @@ -0,0 +1,691 @@ +import CryptoKit +import Foundation + +struct LocalASRAddonProgress: Equatable { + var label: String + var fraction: Double? 
+} + +struct LocalASRAddonStatus: Equatable { + enum State: Equatable { + case notInstalled + case installing(LocalASRAddonProgress) + case installed(version: String, models: Set) + case updateAvailable(installedVersion: String, latestVersion: String) + case repairRequired(reason: String) + case unsupported(reason: String) + } + + var state: State + var pythonPath: String? + var detail: String + + var isInstalled: Bool { + if case .installed = state { return true } + if case .updateAvailable = state { return true } + return false + } + + var isActionableInstall: Bool { + switch state { + case .notInstalled, .repairRequired, .updateAvailable: + return true + case .installing, .unsupported: + return false + case .installed: + return true + } + } +} + +struct LocalASRAddonRemoteManifest: Codable, Equatable { + var version: Int + var runtime: RuntimeArtifact + var models: [ModelArtifact] + + struct RuntimeArtifact: Codable, Equatable { + var version: String + var platform: String + var arch: String + var url: String + var sha256: String + var sizeBytes: Int64 + var minimumAppVersion: String? 
+ + enum CodingKeys: String, CodingKey { + case version, platform, arch, url, sha256 + case sizeBytes = "size_bytes" + case minimumAppVersion = "minimum_app_version" + } + } + + struct ModelArtifact: Codable, Equatable { + var model: LocalTranscriptionModel + var version: String + var url: String + var sha256: String + var sizeBytes: Int64 + + enum CodingKeys: String, CodingKey { + case model, version, url, sha256 + case sizeBytes = "size_bytes" + } + } +} + +struct LocalASRAddonInstalledManifest: Codable, Equatable { + var schemaVersion: Int + var runtimeVersion: String + var runtimeSha256: String + var pythonPath: String + var installedAt: Date + var models: [InstalledModel] + + struct InstalledModel: Codable, Equatable { + var model: LocalTranscriptionModel + var version: String + var sha256: String + var path: String + var installedAt: Date + } + + enum CodingKeys: String, CodingKey { + case schemaVersion = "schema_version" + case runtimeVersion = "runtime_version" + case runtimeSha256 = "runtime_sha256" + case pythonPath = "python_path" + case installedAt = "installed_at" + case models + } +} + +enum LocalASRAddonManager { + typealias ProgressHandler = @MainActor (LocalASRAddonProgress) -> Void + + private static let manifestFilename = "installed-manifest.json" + private static let schemaVersion = 1 + + static var rootDirectory: URL { + let appSupport = + FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask).first + ?? 
URL(fileURLWithPath: NSHomeDirectory()).appendingPathComponent( + "Library/Application Support") + return appSupport.appendingPathComponent("Omi/LocalASR", isDirectory: true) + } + + static var manifestURL: URL { + rootDirectory.appendingPathComponent(manifestFilename, isDirectory: false) + } + + static func status() -> LocalASRAddonStatus { + if let unsupported = unsupportedReason() { + return LocalASRAddonStatus( + state: .unsupported(reason: unsupported), pythonPath: nil, detail: unsupported) + } + + guard let installed = try? readInstalledManifest() else { + return LocalASRAddonStatus( + state: .notInstalled, + pythonPath: nil, + detail: "Install local Whisper to enable on-device transcription" + ) + } + + guard FileManager.default.isExecutableFile(atPath: installed.pythonPath) else { + return LocalASRAddonStatus( + state: .repairRequired(reason: "Installed runtime is missing"), + pythonPath: installed.pythonPath, + detail: "Installed runtime is missing" + ) + } + + if let missingModel = installed.models.first(where: { + !FileManager.default.fileExists(atPath: $0.path) + }) { + return LocalASRAddonStatus( + state: .repairRequired(reason: "Installed \(missingModel.model.rawValue) model is missing"), + pythonPath: installed.pythonPath, + detail: "Installed \(missingModel.model.rawValue) model is missing" + ) + } + + let installedModels = Set(installed.models.map(\.model)) + return LocalASRAddonStatus( + state: .installed(version: installed.runtimeVersion, models: installedModels), + pythonPath: installed.pythonPath, + detail: installedModels.isEmpty + ? "Local Whisper runtime installed; model install required" + : "Local Whisper add-on installed" + ) + } + + static func activateIfInstalled() { + guard let installed = try? 
readInstalledManifest() else { return } + guard FileManager.default.isExecutableFile(atPath: installed.pythonPath) else { return } + + setenv("OMI_LOCAL_ASR_PYTHON", installed.pythonPath, 1) + setenv("OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD", "0", 1) + + for model in installed.models where FileManager.default.fileExists(atPath: model.path) { + setenv(modelDirectoryEnvironmentKey(for: model.model), model.path, 1) + } + } + + static func install( + quality: TranscriptionQualityPreset = .auto, + progress: ProgressHandler? = nil + ) async throws -> LocalASRAddonStatus { + if let unsupported = unsupportedReason() { + throw addonError(unsupported) + } + + let model = initialModel(for: quality) + let remote = try await fetchRemoteManifest() + try await install(remote: remote, requiredModel: model, progress: progress) + activateIfInstalled() + return status() + } + + static func installModel( + for quality: TranscriptionQualityPreset, + progress: ProgressHandler? = nil + ) async throws -> LocalASRAddonStatus { + if let unsupported = unsupportedReason() { + throw addonError(unsupported) + } + + let model = initialModel(for: quality) + let remote = try await fetchRemoteManifest() + try await install(remote: remote, requiredModel: model, progress: progress) + activateIfInstalled() + return status() + } + + static func refreshStatusAgainstRemote() async -> LocalASRAddonStatus { + var current = status() + guard case .installed(let installedVersion, _) = current.state else { return current } + guard let remote = try? 
await fetchRemoteManifest() else { return current } + if remote.runtime.version != installedVersion { + current.state = .updateAvailable( + installedVersion: installedVersion, + latestVersion: remote.runtime.version + ) + current.detail = "Local Whisper update available" + } + return current + } + + static func status(afterCapabilityProbe engines: Set) + -> LocalASRAddonStatus + { + var current = status() + guard current.isInstalled, !engines.contains(.mlxWhisper) else { return current } + current.state = .repairRequired( + reason: "Installed runtime did not pass the MLX Whisper capability probe") + current.detail = "Installed runtime did not pass the MLX Whisper capability probe" + return current + } + + static func remove() throws -> LocalASRAddonStatus { + if FileManager.default.fileExists(atPath: rootDirectory.path) { + try FileManager.default.removeItem(at: rootDirectory) + } + + unsetenv("OMI_LOCAL_ASR_PYTHON") + unsetenv("OMI_LOCAL_ASR_ALLOW_MODEL_DOWNLOAD") + for model in LocalTranscriptionModel.allCases { + unsetenv(modelDirectoryEnvironmentKey(for: model)) + } + + return status() + } + + private static func install( + remote: LocalASRAddonRemoteManifest, + requiredModel: LocalTranscriptionModel, + progress: ProgressHandler? 
+ ) async throws { + try validate(remote: remote) + let runtimeArtifact = remote.runtime + let modelArtifact = try artifact(for: requiredModel, in: remote) + + let tempRoot = rootDirectory.appendingPathComponent( + "installing-\(UUID().uuidString)", isDirectory: true) + let runtimeTemp = tempRoot.appendingPathComponent("runtime", isDirectory: true) + let modelTemp = tempRoot.appendingPathComponent( + "model-\(requiredModel.rawValue)", isDirectory: true) + let runtimeActive = rootDirectory.appendingPathComponent("runtime", isDirectory: true) + let modelsActive = rootDirectory.appendingPathComponent("models", isDirectory: true) + let modelActive = modelsActive.appendingPathComponent(requiredModel.rawValue, isDirectory: true) + + try FileManager.default.createDirectory(at: tempRoot, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: tempRoot) } + + let existing = try? readInstalledManifest() + let existingRuntimeUsable = + existing?.runtimeVersion == runtimeArtifact.version + && (existing?.runtimeSha256.lowercased() == runtimeArtifact.sha256.lowercased()) + && FileManager.default.isExecutableFile(atPath: existing?.pythonPath ?? 
"") + + let activePython: URL + if existingRuntimeUsable, let existingPython = existing?.pythonPath { + activePython = URL(fileURLWithPath: existingPython) + } else { + await progress?(LocalASRAddonProgress(label: "Downloading runtime", fraction: nil)) + let runtimeZip = try await download( + artifactURL: runtimeArtifact.url, + expectedSHA256: runtimeArtifact.sha256, + expectedBytes: runtimeArtifact.sizeBytes, + destinationName: "runtime-\(runtimeArtifact.version).zip", + progressLabel: "Downloading runtime", + progress: progress + ) + try unzip(runtimeZip, to: runtimeTemp) + let runtimePayload = try singlePayloadDirectory(in: runtimeTemp) + let pythonPath = try findPython(in: runtimePayload) + + try FileManager.default.createDirectory(at: rootDirectory, withIntermediateDirectories: true) + if FileManager.default.fileExists(atPath: runtimeActive.path) { + try FileManager.default.removeItem(at: runtimeActive) + } + try FileManager.default.moveItem(at: runtimePayload, to: runtimeActive) + activePython = runtimeActive.appendingPathComponent( + relativePath(from: runtimePayload, to: pythonPath)) + } + + await progress?(LocalASRAddonProgress(label: "Downloading model", fraction: nil)) + let modelZip = try await download( + artifactURL: modelArtifact.url, + expectedSHA256: modelArtifact.sha256, + expectedBytes: modelArtifact.sizeBytes, + destinationName: "model-\(requiredModel.rawValue).zip", + progressLabel: "Downloading model", + progress: progress + ) + try unzip(modelZip, to: modelTemp) + + try FileManager.default.createDirectory(at: rootDirectory, withIntermediateDirectories: true) + try FileManager.default.createDirectory(at: modelsActive, withIntermediateDirectories: true) + + if FileManager.default.fileExists(atPath: modelActive.path) { + try FileManager.default.removeItem(at: modelActive) + } + try FileManager.default.moveItem(at: modelTemp, to: modelActive) + + let activeModelPath = try singlePayloadDirectory(in: modelActive).path + try writeInstalledManifest( 
+ runtimeVersion: runtimeArtifact.version, + runtimeSha256: runtimeArtifact.sha256, + pythonPath: activePython.path, + model: requiredModel, + modelVersion: modelArtifact.version, + modelSha256: modelArtifact.sha256, + modelPath: activeModelPath + ) + + activateIfInstalled() + await progress?(LocalASRAddonProgress(label: "Validating local Whisper", fraction: nil)) + let engines = LocalASRHelperLocator.detectedEngines() + guard engines.contains(.mlxWhisper) else { + throw addonError("Installed runtime did not pass the MLX Whisper capability probe") + } + } + + private static func fetchRemoteManifest() async throws -> LocalASRAddonRemoteManifest { + let manifestURL = try remoteManifestURL() + if manifestURL.isFileURL { + let data = try Data(contentsOf: manifestURL) + return try JSONDecoder.localASRAddon.decode(LocalASRAddonRemoteManifest.self, from: data) + } + + let (data, response): (Data, URLResponse) + do { + (data, response) = try await URLSession.shared.data(from: manifestURL) + } catch { + throw addonError("Local Whisper manifest request failed: \(error.localizedDescription)") + } + guard let http = response as? HTTPURLResponse, (200..<300).contains(http.statusCode) else { + throw addonError("Local Whisper manifest request failed") + } + return try JSONDecoder.localASRAddon.decode(LocalASRAddonRemoteManifest.self, from: data) + } + + private static func remoteManifestURL() throws -> URL { + let override = ProcessInfo.processInfo.environment["OMI_LOCAL_ASR_MANIFEST_URL"]? 
+ .trimmingCharacters(in: .whitespacesAndNewlines) + if let override, !override.isEmpty { + guard let url = URL(string: override) else { + throw addonError("Invalid OMI_LOCAL_ASR_MANIFEST_URL: \(override)") + } + return url + } + + let base = DesktopBackendEnvironment.rustBackendURL() + if let message = localDevManifestConfigurationMessage( + modeValue: ProcessInfo.processInfo.environment["OMI_DESKTOP_BACKEND_MODE"], + rustBackendURL: base, + manifestOverride: override + ) { + throw addonError(message) + } + + guard !base.isEmpty, let url = URL(string: base + "v1/local-asr/manifest") else { + throw addonError("Omi desktop backend is not configured for Local Whisper add-on downloads") + } + return url + } + + static func localDevManifestConfigurationMessage( + modeValue: String?, + rustBackendURL: String, + manifestOverride: String? + ) -> String? { + let mode = modeValue?.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() + let override = manifestOverride?.trimmingCharacters(in: .whitespacesAndNewlines) + guard override?.isEmpty ?? true else { return nil } + guard ["local", "local-daemon", "local_daemon", "daemon"].contains(mode ?? "") else { return nil } + guard let host = URL(string: rustBackendURL)?.host?.lowercased(), host == "omi-rust-invalid" else { + return nil + } + return + "Local Whisper add-on manifest is not configured for local dev. Run `make local-asr-fixture`, or set OMI_LOCAL_ASR_MANIFEST_URL." + } + + private static func download( + artifactURL: String, + expectedSHA256: String, + expectedBytes: Int64, + destinationName: String, + progressLabel: String, + progress: ProgressHandler? 
+ ) async throws -> URL { + guard let url = URL(string: artifactURL) else { + throw addonError("Invalid Local Whisper artifact URL: \(artifactURL)") + } + + let downloads = rootDirectory.appendingPathComponent("downloads", isDirectory: true) + try FileManager.default.createDirectory(at: downloads, withIntermediateDirectories: true) + let destination = downloads.appendingPathComponent(destinationName, isDirectory: false) + let partial = destination.appendingPathExtension("part") + + if url.isFileURL { + await progress?(LocalASRAddonProgress(label: progressLabel, fraction: nil)) + if FileManager.default.fileExists(atPath: partial.path) { + try FileManager.default.removeItem(at: partial) + } + try FileManager.default.copyItem(at: url, to: partial) + let actual = try sha256(of: partial) + guard actual.lowercased() == expectedSHA256.lowercased() else { + try? FileManager.default.removeItem(at: partial) + throw addonError("Local Whisper artifact checksum mismatch") + } + if FileManager.default.fileExists(atPath: destination.path) { + try FileManager.default.removeItem(at: destination) + } + try FileManager.default.moveItem(at: partial, to: destination) + await progress?(LocalASRAddonProgress(label: progressLabel, fraction: 1.0)) + return destination + } + + var existingBytes: Int64 = 0 + if FileManager.default.fileExists(atPath: partial.path), + let attrs = try? FileManager.default.attributesOfItem(atPath: partial.path), + let size = attrs[.size] as? NSNumber + { + existingBytes = size.int64Value + } + + var request = URLRequest(url: url) + if existingBytes > 0 { + request.setValue("bytes=\(existingBytes)-", forHTTPHeaderField: "Range") + } + + let (bytes, response) = try await URLSession.shared.bytes(for: request) + guard let http = response as? 
HTTPURLResponse else { + throw addonError("Invalid Local Whisper artifact response") + } + + let shouldAppend = existingBytes > 0 && http.statusCode == 206 + guard http.statusCode == 200 || shouldAppend else { + throw addonError("Local Whisper artifact download failed with HTTP \(http.statusCode)") + } + + if !shouldAppend { + existingBytes = 0 + try? FileManager.default.removeItem(at: partial) + FileManager.default.createFile(atPath: partial.path, contents: nil) + } + + let handle = try FileHandle(forWritingTo: partial) + try handle.seekToEnd() + defer { try? handle.close() } + + var downloaded = existingBytes + var buffer = Data() + buffer.reserveCapacity(64 * 1024) + for try await byte in bytes { + buffer.append(byte) + if buffer.count >= 64 * 1024 { + try handle.write(contentsOf: buffer) + buffer.removeAll(keepingCapacity: true) + } + downloaded += 1 + if expectedBytes > 0 && downloaded % 262_144 == 0 { + await progress?( + LocalASRAddonProgress( + label: progressLabel, + fraction: min(Double(downloaded) / Double(expectedBytes), 1.0) + )) + } + } + if !buffer.isEmpty { + try handle.write(contentsOf: buffer) + } + + try handle.close() + let actual = try sha256(of: partial) + guard actual.lowercased() == expectedSHA256.lowercased() else { + try? 
FileManager.default.removeItem(at: partial) + throw addonError("Local Whisper artifact checksum mismatch") + } + + if FileManager.default.fileExists(atPath: destination.path) { + try FileManager.default.removeItem(at: destination) + } + try FileManager.default.moveItem(at: partial, to: destination) + await progress?(LocalASRAddonProgress(label: progressLabel, fraction: 1.0)) + return destination + } + + private static func validate(remote: LocalASRAddonRemoteManifest) throws { + guard remote.version == 1 else { + throw addonError("Unsupported Local Whisper manifest version") + } + guard remote.runtime.platform == "macos", remote.runtime.arch == "arm64" else { + throw addonError("Local Whisper runtime does not support this Mac") + } + guard !remote.runtime.url.isEmpty, !remote.runtime.sha256.isEmpty else { + throw addonError("Local Whisper runtime manifest is incomplete") + } + if let minimumAppVersion = remote.runtime.minimumAppVersion, + compareVersion(currentAppVersion(), minimumAppVersion) == .orderedAscending + { + throw addonError("Local Whisper requires Omi \(minimumAppVersion) or newer") + } + } + + private static func artifact( + for model: LocalTranscriptionModel, + in remote: LocalASRAddonRemoteManifest + ) throws -> LocalASRAddonRemoteManifest.ModelArtifact { + guard let artifact = remote.models.first(where: { $0.model == model }) else { + throw addonError("No Local Whisper model artifact is available for \(model.rawValue)") + } + return artifact + } + + private static func readInstalledManifest() throws -> LocalASRAddonInstalledManifest { + let data = try Data(contentsOf: manifestURL) + return try JSONDecoder.localASRAddon.decode(LocalASRAddonInstalledManifest.self, from: data) + } + + private static func writeInstalledManifest( + runtimeVersion: String, + runtimeSha256: String, + pythonPath: String, + model: LocalTranscriptionModel, + modelVersion: String, + modelSha256: String, + modelPath: String + ) throws { + var installedModels: 
[LocalASRAddonInstalledManifest.InstalledModel] = [] + if let existing = try? readInstalledManifest() { + installedModels = existing.models.filter { $0.model != model } + } + installedModels.append( + LocalASRAddonInstalledManifest.InstalledModel( + model: model, + version: modelVersion, + sha256: modelSha256, + path: modelPath, + installedAt: Date() + )) + + let manifest = LocalASRAddonInstalledManifest( + schemaVersion: schemaVersion, + runtimeVersion: runtimeVersion, + runtimeSha256: runtimeSha256, + pythonPath: pythonPath, + installedAt: Date(), + models: installedModels.sorted { $0.model.rawValue < $1.model.rawValue } + ) + let data = try JSONEncoder.localASRAddon.encode(manifest) + try data.write(to: manifestURL, options: .atomic) + } + + private static func findPython(in directory: URL) throws -> URL { + let candidates = [ + directory.appendingPathComponent("bin/python3", isDirectory: false), + directory.appendingPathComponent("venv/bin/python3", isDirectory: false), + directory.appendingPathComponent("runtime/bin/python3", isDirectory: false), + directory.appendingPathComponent("runtime/venv/bin/python3", isDirectory: false), + ] + if let python = candidates.first(where: { + FileManager.default.isExecutableFile(atPath: $0.path) + }) { + return python + } + throw addonError("Local Whisper runtime artifact does not contain bin/python3") + } + + private static func singlePayloadDirectory(in directory: URL) throws -> URL { + let contents = try FileManager.default.contentsOfDirectory( + at: directory, + includingPropertiesForKeys: [.isDirectoryKey], + options: [.skipsHiddenFiles] + ) + if contents.count == 1, + let values = try? 
contents[0].resourceValues(forKeys: [.isDirectoryKey]), + values.isDirectory == true + { + return contents[0] + } + return directory + } + + private static func unzip(_ zipURL: URL, to destination: URL) throws { + try FileManager.default.createDirectory(at: destination, withIntermediateDirectories: true) + let process = Process() + process.executableURL = URL(fileURLWithPath: "/usr/bin/unzip") + process.arguments = ["-q", zipURL.path, "-d", destination.path] + process.standardOutput = Pipe() + let errors = Pipe() + process.standardError = errors + try process.run() + process.waitUntilExit() + guard process.terminationStatus == 0 else { + let stderr = + String(data: errors.fileHandleForReading.readDataToEndOfFile(), encoding: .utf8) ?? "" + throw addonError("Failed to extract Local Whisper artifact: \(stderr)") + } + } + + private static func sha256(of url: URL) throws -> String { + let handle = try FileHandle(forReadingFrom: url) + defer { try? handle.close() } + + var hasher = SHA256() + while true { + let data = try handle.read(upToCount: 1024 * 1024) ?? Data() + if data.isEmpty { break } + hasher.update(data: data) + } + return hasher.finalize().map { String(format: "%02x", $0) }.joined() + } + + private static func relativePath(from root: URL, to child: URL) -> String { + let rootPath = root.standardizedFileURL.path + let childPath = child.standardizedFileURL.path + guard childPath.hasPrefix(rootPath + "/") else { return child.lastPathComponent } + return String(childPath.dropFirst(rootPath.count + 1)) + } + + private static func compareVersion(_ lhs: String, _ rhs: String) -> ComparisonResult { + lhs.compare(rhs, options: [.numeric, .caseInsensitive]) + } + + private static func currentAppVersion() -> String { + Bundle.main.object(forInfoDictionaryKey: "CFBundleShortVersionString") as? String ?? 
"0.0.0" + } + + static func initialModel(for quality: TranscriptionQualityPreset) -> LocalTranscriptionModel { + let memory = ProcessInfo.processInfo.physicalMemory / (1024 * 1024 * 1024) + switch quality { + case .fast: + return .base + case .auto, .balanced: + return memory >= 8 ? .small : .base + case .accurate: + return memory >= 24 ? .largeV3Turbo : (memory >= 16 ? .medium : .small) + } + } + + private static func modelDirectoryEnvironmentKey(for model: LocalTranscriptionModel) -> String { + "OMI_MLX_WHISPER_MODEL_DIR_\(model.rawValue.uppercased())" + } + + private static func unsupportedReason() -> String? { + let capabilities = LocalTranscriptionCapabilityDetector().detect() + switch capabilities.processor { + case .nativeAppleSilicon: + return nil + case .rosettaOnAppleSilicon: + return "MLX Whisper requires running Omi natively, not under Rosetta." + case .intel: + return "MLX Whisper requires an Apple Silicon Mac." + case .unknown: + return "This Mac does not report a supported Apple Silicon processor." 
+ } + } + + private static func addonError(_ message: String) -> NSError { + NSError(domain: "LocalASRAddonManager", code: 1, userInfo: [NSLocalizedDescriptionKey: message]) + } +} + +extension JSONDecoder { + static var localASRAddon: JSONDecoder { + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .iso8601 + return decoder + } +} + +extension JSONEncoder { + static var localASRAddon: JSONEncoder { + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + encoder.outputFormatting = [.prettyPrinted, .sortedKeys] + return encoder + } +} diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index d825bb9d3b2..efc347c3123 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -93,6 +93,7 @@ struct LocalASRHelperClient { func transcribe(_ request: LocalASRTranscriptionRequest) async throws -> LocalASRTranscriptionResponse { + LocalASRAddonManager.activateIfInstalled() let process = Process() process.executableURL = executableURL @@ -145,6 +146,12 @@ struct LocalASRHelperClient { enum LocalASRHelperLocator { static let environmentKey = "OMI_LOCAL_ASR_HELPER_PATH" + private static let cacheLock = NSLock() + private static var cachedEngines: Set? + private static var cachedExecutablePath: String? + private static var cachedAt: Date? + private static var refreshInFlight = false + private static var refreshInFlightExecutablePath: String? static func defaultExecutableURL( environment: [String: String] = ProcessInfo.processInfo.environment, @@ -187,11 +194,36 @@ enum LocalASRHelperLocator { static func detectedEngines(executableURL: URL? 
= defaultExecutableURL()) -> Set { - let probe = { detectedEnginesBlocking(executableURL: executableURL) } + LocalASRAddonManager.activateIfInstalled() + let executablePath = executableURL?.standardizedFileURL.path if Thread.isMainThread { - return DispatchQueue.global(qos: .userInitiated).sync(execute: probe) + if let cached = cachedEnginesIfFresh(for: executablePath) { + return cached + } + if executablePath != defaultExecutableURL()?.standardizedFileURL.path { + let engines = detectedEnginesBlocking(executableURL: executableURL) + storeCachedEngines(engines, for: executablePath) + return engines + } + refreshDetectedEnginesInBackground(executableURL: executableURL) + return cachedEnginesValue(for: executablePath) ?? [] + } + let engines = detectedEnginesBlocking(executableURL: executableURL) + storeCachedEngines(engines, for: executablePath) + return engines + } + + static func refreshDetectedEngines(executableURL: URL? = defaultExecutableURL()) async + -> Set + { + LocalASRAddonManager.activateIfInstalled() + return await withCheckedContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + let engines = detectedEnginesBlocking(executableURL: executableURL) + storeCachedEngines(engines, for: executableURL?.standardizedFileURL.path) + continuation.resume(returning: engines) + } } - return probe() } private static func detectedEnginesBlocking(executableURL: URL?) @@ -232,6 +264,58 @@ enum LocalASRHelperLocator { return Set(response.engines.filter(\.available).map(\.engine)) } + + private static func cachedEnginesIfFresh( + for executablePath: String?, + maxAge: TimeInterval = 60 + ) -> Set? { + cacheLock.lock() + defer { cacheLock.unlock() } + guard cachedExecutablePath == executablePath, let cachedEngines, let cachedAt, + Date().timeIntervalSince(cachedAt) <= maxAge + else { + return nil + } + return cachedEngines + } + + private static func cachedEnginesValue(for executablePath: String?) -> Set< + LocalTranscriptionEngine + >? 
{ + cacheLock.lock() + defer { cacheLock.unlock() } + guard cachedExecutablePath == executablePath else { return nil } + return cachedEngines + } + + private static func storeCachedEngines( + _ engines: Set, for executablePath: String? + ) { + cacheLock.lock() + cachedEngines = engines + cachedExecutablePath = executablePath + cachedAt = Date() + refreshInFlight = false + refreshInFlightExecutablePath = nil + cacheLock.unlock() + } + + private static func refreshDetectedEnginesInBackground(executableURL: URL?) { + let executablePath = executableURL?.standardizedFileURL.path + cacheLock.lock() + if refreshInFlight, refreshInFlightExecutablePath == executablePath { + cacheLock.unlock() + return + } + refreshInFlight = true + refreshInFlightExecutablePath = executablePath + cacheLock.unlock() + + DispatchQueue.global(qos: .utility).async { + let engines = detectedEnginesBlocking(executableURL: executableURL) + storeCachedEngines(engines, for: executablePath) + } + } } struct LocalASRBatchTranscriber { diff --git a/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift b/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift new file mode 100644 index 00000000000..25082ff9527 --- /dev/null +++ b/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift @@ -0,0 +1,439 @@ +import Foundation + +struct TranscriptionComparisonProviderSnapshot: Equatable { + var title: String + var status: String + var transcript: String + var segmentCount: Int + var wordCount: Int + var error: String? + + static func empty(title: String, status: String = "Idle") -> Self { + Self(title: title, status: status, transcript: "", segmentCount: 0, wordCount: 0, error: nil) + } +} + +struct TranscriptionComparisonHarnessSnapshot: Equatable { + var isRunning: Bool + var startedAt: Date? + var whisper: TranscriptionComparisonProviderSnapshot + var deepgram: TranscriptionComparisonProviderSnapshot + var wordDifferenceRate: Double? 
+ var characterDifferenceRate: Double? + + static let idle = TranscriptionComparisonHarnessSnapshot( + isRunning: false, + startedAt: nil, + whisper: .empty(title: "Local Whisper"), + deepgram: .empty(title: "Local Deepgram"), + wordDifferenceRate: nil, + characterDifferenceRate: nil + ) +} + +@MainActor +final class TranscriptionComparisonHarness { + static let enabledDefaultsKey = "dev_transcription_comparison_harness_enabled" + + static var isEnabled: Bool { + #if DEBUG + return UserDefaults.standard.bool(forKey: enabledDefaultsKey) + #else + return false + #endif + } + + static func configuredDeepgramAPIKey() -> String? { + APIKeyService.byokKey(.deepgram) + ?? getenv("DEEPGRAM_API_KEY").flatMap { String(validatingUTF8: $0) } + } + + private let language: String + private var deepgramSession: DeepgramBackgroundTranscriptionSession? + private var whisperSegments: [NormalizedTranscriptSegment] = [] + private var deepgramSegments: [NormalizedTranscriptSegment] = [] + private var whisperStatus = "Waiting for Whisper" + private var deepgramStatus = "Waiting for Deepgram" + private var whisperError: String? + private var deepgramError: String? + private(set) var snapshot: TranscriptionComparisonHarnessSnapshot + private let onSnapshot: (TranscriptionComparisonHarnessSnapshot) -> Void + + init( + language: String, + deepgramAPIKey: String?, + onSnapshot: @escaping (TranscriptionComparisonHarnessSnapshot) -> Void + ) { + self.language = language + self.onSnapshot = onSnapshot + deepgramStatus = deepgramAPIKey == nil ? "Missing Deepgram API key" : "Waiting for Deepgram" + snapshot = TranscriptionComparisonHarnessSnapshot( + isRunning: true, + startedAt: Date(), + whisper: .empty(title: "Local Whisper", status: "Waiting for Whisper"), + deepgram: .empty( + title: "Local Deepgram", + status: deepgramAPIKey == nil ? 
"Missing Deepgram API key" : "Connecting" + ), + wordDifferenceRate: nil, + characterDifferenceRate: nil + ) + + if let deepgramAPIKey { + deepgramSession = DeepgramBackgroundTranscriptionSession( + language: language, + apiKey: deepgramAPIKey, + onSegments: { [weak self] segments in + Task { @MainActor in + self?.appendDeepgramSegments(segments) + } + }, + onStatus: { [weak self] status in + Task { @MainActor in + self?.deepgramStatus = status + self?.publish() + } + }, + onError: { [weak self] error in + Task { @MainActor in + self?.deepgramStatus = "Failed" + self?.deepgramError = error.localizedDescription + self?.publish() + } + } + ) + } else { + deepgramError = "Add a Deepgram key in Developer API Keys or set DEEPGRAM_API_KEY." + } + } + + func start() { + deepgramSession?.start() + publish() + } + + func appendAudio(_ pcmData: Data) { + deepgramSession?.appendAudio(pcmData) + } + + func appendWhisperSegments(_ segments: [NormalizedTranscriptSegment]) { + guard !segments.isEmpty else { return } + whisperStatus = "Receiving" + merge(segments, into: &whisperSegments) + publish() + } + + func finish() { + deepgramSession?.finish() + whisperStatus = whisperSegments.isEmpty ? "No transcript" : "Finalized" + deepgramStatus = + deepgramSegments.isEmpty && deepgramError == nil ? 
"Finalizing" : deepgramStatus + publish(isRunning: false) + } + + func stop() { + deepgramSession?.stop() + deepgramSession = nil + publish(isRunning: false) + } + + private func appendDeepgramSegments(_ segments: [NormalizedTranscriptSegment]) { + guard !segments.isEmpty else { return } + deepgramStatus = "Receiving" + merge(segments, into: &deepgramSegments) + publish() + } + + private func merge( + _ segments: [NormalizedTranscriptSegment], + into target: inout [NormalizedTranscriptSegment] + ) { + for segment in segments where !segment.text.isEmpty { + if let id = segment.segmentId, + let existingIndex = target.firstIndex(where: { $0.segmentId == id }) + { + target[existingIndex] = segment + } else { + target.append(segment) + } + } + target.sort { lhs, rhs in + if lhs.start == rhs.start { + return lhs.end < rhs.end + } + return lhs.start < rhs.start + } + } + + private func publish(isRunning: Bool? = nil) { + let whisperText = joinedTranscript(whisperSegments) + let deepgramText = joinedTranscript(deepgramSegments) + snapshot = TranscriptionComparisonHarnessSnapshot( + isRunning: isRunning ?? snapshot.isRunning, + startedAt: snapshot.startedAt, + whisper: providerSnapshot( + title: "Local Whisper", + status: whisperStatus, + transcript: whisperText, + segments: whisperSegments, + error: whisperError + ), + deepgram: providerSnapshot( + title: "Local Deepgram", + status: deepgramStatus, + transcript: deepgramText, + segments: deepgramSegments, + error: deepgramError + ), + wordDifferenceRate: comparisonRate( + reference: whisperText, + hypothesis: deepgramText, + scorer: TranscriptComparison.wordErrorRate + ), + characterDifferenceRate: comparisonRate( + reference: whisperText, + hypothesis: deepgramText, + scorer: TranscriptComparison.characterErrorRate + ) + ) + onSnapshot(snapshot) + } + + private func providerSnapshot( + title: String, + status: String, + transcript: String, + segments: [NormalizedTranscriptSegment], + error: String? 
+ ) -> TranscriptionComparisonProviderSnapshot { + TranscriptionComparisonProviderSnapshot( + title: title, + status: status, + transcript: transcript, + segmentCount: segments.count, + wordCount: TranscriptComparison.normalizedWords(transcript).count, + error: error + ) + } + + private func joinedTranscript(_ segments: [NormalizedTranscriptSegment]) -> String { + segments + .map(\.text) + .joined(separator: " ") + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func comparisonRate( + reference: String, + hypothesis: String, + scorer: (String, String) -> Double + ) -> Double? { + guard !TranscriptComparison.normalizedText(reference).isEmpty, + !TranscriptComparison.normalizedText(hypothesis).isEmpty + else { + return nil + } + return scorer(reference, hypothesis) + } +} + +final class DeepgramBackgroundTranscriptionSession { + private struct Response: Decodable { + struct Channel: Decodable { + struct Alternative: Decodable { + var transcript: String? + } + + var alternatives: [Alternative] + } + + var type: String? + var channel: Channel? + var isFinal: Bool? + var speechFinal: Bool? + var start: Double? + var duration: Double? + + enum CodingKeys: String, CodingKey { + case type + case channel + case isFinal = "is_final" + case speechFinal = "speech_final" + case start + case duration + } + } + + private let language: String + private let apiKey: String + private let onSegments: ([NormalizedTranscriptSegment]) -> Void + private let onStatus: (String) -> Void + private let onError: (Error) -> Void + private var webSocketTask: URLSessionWebSocketTask? + private var urlSession: URLSession? 
+ private var isConnected = false + private var pendingAudio = Data() + private let pendingAudioLimit = 16_000 * 2 * 5 + + init( + language: String, + apiKey: String, + onSegments: @escaping ([NormalizedTranscriptSegment]) -> Void, + onStatus: @escaping (String) -> Void, + onError: @escaping (Error) -> Void + ) { + self.language = language + self.apiKey = apiKey + self.onSegments = onSegments + self.onStatus = onStatus + self.onError = onError + } + + func start() { + guard webSocketTask == nil else { return } + guard var components = URLComponents(string: "wss://api.deepgram.com/v1/listen") else { + return + } + var queryItems = [ + URLQueryItem(name: "model", value: "nova-3"), + URLQueryItem(name: "encoding", value: "linear16"), + URLQueryItem(name: "sample_rate", value: "16000"), + URLQueryItem(name: "channels", value: "1"), + URLQueryItem(name: "interim_results", value: "false"), + URLQueryItem(name: "smart_format", value: "true"), + URLQueryItem(name: "punctuate", value: "true"), + ] + if language != "multi" { + queryItems.append(URLQueryItem(name: "language", value: language)) + } + components.queryItems = queryItems + + guard let url = components.url else { return } + var request = URLRequest(url: url) + request.setValue("Token \(apiKey)", forHTTPHeaderField: "Authorization") + + let configuration = URLSessionConfiguration.default + configuration.timeoutIntervalForRequest = 30 + configuration.timeoutIntervalForResource = 0 + let session = URLSession(configuration: configuration) + urlSession = session + webSocketTask = session.webSocketTask(with: request) + webSocketTask?.resume() + receiveMessage() + + DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in + guard let self else { return } + guard self.webSocketTask?.state == .running else { + self.onStatus("Failed to connect") + return + } + self.isConnected = true + self.onStatus("Connected") + self.flushPendingAudio() + } + } + + func appendAudio(_ data: Data) { + guard isConnected else { 
+ pendingAudio.append(data) + if pendingAudio.count > pendingAudioLimit { + pendingAudio.removeFirst(pendingAudio.count - pendingAudioLimit) + } + return + } + sendAudio(data) + } + + func finish() { + guard isConnected else { return } + sendString("{\"type\":\"CloseStream\"}") + onStatus("Finalizing") + } + + func stop() { + isConnected = false + webSocketTask?.cancel(with: .normalClosure, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + pendingAudio.removeAll() + onStatus("Stopped") + } + + private func flushPendingAudio() { + guard !pendingAudio.isEmpty else { return } + let audio = pendingAudio + pendingAudio.removeAll() + sendAudio(audio) + } + + private func sendAudio(_ data: Data) { + webSocketTask?.send(.data(data)) { [weak self] error in + if let error { + self?.onError(error) + } + } + } + + private func sendString(_ value: String) { + webSocketTask?.send(.string(value)) { [weak self] error in + if let error { + self?.onError(error) + } + } + } + + private func receiveMessage() { + webSocketTask?.receive { [weak self] result in + guard let self else { return } + switch result { + case .success(let message): + self.handleMessage(message) + self.receiveMessage() + case .failure(let error): + guard self.isConnected else { return } + self.isConnected = false + self.onError(error) + } + } + } + + private func handleMessage(_ message: URLSessionWebSocketTask.Message) { + let data: Data? + switch message { + case .string(let text): + data = text.data(using: .utf8) + case .data(let messageData): + data = messageData + @unknown default: + data = nil + } + + guard let data, + let response = try? JSONDecoder().decode(Response.self, from: data), + response.isFinal == true, + let transcript = response.channel?.alternatives.first?.transcript? + .trimmingCharacters(in: .whitespacesAndNewlines), + !transcript.isEmpty + else { + return + } + + let start = response.start ?? 0 + let end = start + (response.duration ?? 
0) + onSegments([ + NormalizedTranscriptSegment( + segmentId: "deepgram-bg-\(UUID().uuidString)", + speaker: 0, + speakerLabel: nil, + text: transcript, + start: start, + end: end, + isUser: true, + personId: nil, + translations: [] + ) + ]) + } +} diff --git a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift index e633c49cee3..5358ca085b7 100644 --- a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift +++ b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift @@ -128,7 +128,8 @@ struct DesktopHomeView: View { ) } } - .onReceive(NotificationCenter.default.publisher(for: .showUsageLimitPopup)) { notification in + .onReceive(NotificationCenter.default.publisher(for: .showUsageLimitPopup)) { + notification in let reason = notification.userInfo?["reason"] as? String ?? "" appState.triggerUsageLimitPopup(reason: reason) } @@ -151,7 +152,9 @@ struct DesktopHomeView: View { // Auto-start transcription if enabled in settings. // If API keys aren't loaded yet, onChange below retries. 
if settings.transcriptionEnabled && !appState.isTranscribing { - if APIKeyService.keysAvailable || !backgroundTranscriptionNeedsAPIKeys(settings: settings) { + if APIKeyService.keysAvailable + || !backgroundTranscriptionNeedsAPIKeys(settings: settings) + { log("DesktopHomeView: Auto-starting transcription") appState.startTranscription() } else { @@ -232,7 +235,9 @@ struct DesktopHomeView: View { ) { _ in // Cooldown: only refresh conversations if last activation was 60+ seconds ago let now = Date() - if PollingConfig.shouldAllowActivationRefresh(now: now, lastRefresh: lastActivationRefresh) { + if PollingConfig.shouldAllowActivationRefresh( + now: now, lastRefresh: lastActivationRefresh) + { lastActivationRefresh = now Task { await appState.refreshConversations() } } @@ -740,6 +745,16 @@ struct DesktopHomeView: View { selectedIndex = SidebarNavItem.settings.rawValue } } + .onReceive(NotificationCenter.default.publisher(for: .navigateToTranscriptionSettings)) { + notification in + selectedSettingsSection = .transcription + highlightedSettingId = + notification.userInfo?["highlightedSettingId"] as? String + ?? "transcription.localWhisperAddon" + withAnimation(.easeInOut(duration: 0.2)) { + selectedIndex = SidebarNavItem.settings.rawValue + } + } .onReceive(NotificationCenter.default.publisher(for: .navigateToRewind)) { _ in // Navigate to Rewind page (index 6) - triggered by global hotkey Cmd+Option+R log( diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 421e7678b66..9370be839c5 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -189,6 +189,8 @@ struct SettingsContentView: View { @State private var rawTranscriptionHistory: [TranscriptionSessionWithSegments] = [] @State private var isLoadingRawTranscriptionHistory = false @State private var rawTranscriptionHistoryError: String? 
+ @AppStorage(TranscriptionComparisonHarness.enabledDefaultsKey) + private var transcriptionComparisonHarnessEnabled = false @State private var chatMessageCount: Int? @State private var isLoadingChatMessages = false @State private var showProfileAndStats = false @@ -283,6 +285,9 @@ struct SettingsContentView: View { @State private var vadGateEnabled: Bool = false @State private var transcriptionProviderSelection: TranscriptionProviderSelection @State private var localTranscriptionCapabilities: LocalTranscriptionCapabilities + @State private var localASRAddonStatus: LocalASRAddonStatus + @State private var isInstallingLocalASRAddon = false + @State private var localASRAddonMessage: String? // Multi-chat mode setting @AppStorage("multiChatEnabled") private var multiChatEnabled = false @@ -455,6 +460,7 @@ struct SettingsContentView: View { _transcriptionProviderSelection = State(initialValue: settings.transcriptionProviderSelection) _localTranscriptionCapabilities = State( initialValue: SettingsContentView.detectLocalTranscriptionCapabilities()) + _localASRAddonStatus = State(initialValue: LocalASRAddonManager.status()) } /// Computed status text for notifications @@ -523,7 +529,8 @@ struct SettingsContentView: View { chatProvider?.checkClaudeConnectionStatus() // Refresh notification permission state appState.checkNotificationPermission() - localTranscriptionCapabilities = SettingsContentView.detectLocalTranscriptionCapabilities() + refreshLocalTranscriptionCapabilities() + refreshLocalASRAddonStatus() } .onReceive(NotificationCenter.default.publisher(for: .assistantMonitoringStateDidChange)) { notification in @@ -556,6 +563,13 @@ struct SettingsContentView: View { .onReceive(NotificationCenter.default.publisher(for: .navigateToFloatingBarSettings)) { _ in selectedSection = .floatingBar } + .onReceive(NotificationCenter.default.publisher(for: .navigateToTranscriptionSettings)) { + notification in + selectedSection = .transcription + highlightedSettingId = + 
notification.userInfo?["highlightedSettingId"] as? String + ?? "transcription.localWhisperAddon" + } .onReceive(NotificationCenter.default.publisher(for: NSApplication.didBecomeActiveNotification)) { _ in // Refresh notification permission when app becomes active (user may have changed it in System Settings) @@ -1133,9 +1147,16 @@ struct SettingsContentView: View { .fill(OmiColors.warning.opacity(0.1)) ) } + + localASRAddonControls } } + localTranscriptionDebugPanel + #if DEBUG + transcriptionComparisonHarnessPanel + #endif + // Language Mode settingsCard(settingId: "transcription.languagemode") { VStack(alignment: .leading, spacing: 16) { @@ -1452,6 +1473,38 @@ struct SettingsContentView: View { ).detect() } + private static func refreshLocalTranscriptionCapabilities() async + -> LocalTranscriptionCapabilities + { + let engines = await LocalASRHelperLocator.refreshDetectedEngines() + return LocalTranscriptionCapabilityDetector(availableEngines: { engines }).detect() + } + + private func refreshLocalTranscriptionCapabilities() { + Task { + let engines = await LocalASRHelperLocator.refreshDetectedEngines() + let capabilities = LocalTranscriptionCapabilityDetector(availableEngines: { engines }) + .detect() + let addonStatus = LocalASRAddonManager.status(afterCapabilityProbe: engines) + await MainActor.run { + localTranscriptionCapabilities = capabilities + localASRAddonStatus = addonStatus + } + } + } + + private func refreshLocalASRAddonStatus() { + Task { + let status = await LocalASRAddonManager.refreshStatusAgainstRemote() + await MainActor.run { + if case .repairRequired = localASRAddonStatus.state { + return + } + localASRAddonStatus = status + } + } + } + private var resolvedTranscriptionProvider: TranscriptionProviderPolicyResult { TranscriptionProviderPolicy().resolve( selection: transcriptionProviderSelection, @@ -1481,6 +1534,292 @@ struct SettingsContentView: View { resolvedTranscriptionProvider.usesLocal ? 
OmiColors.success : OmiColors.textTertiary } + private var localASRAddonControls: some View { + VStack(alignment: .leading, spacing: 10) { + HStack(alignment: .top, spacing: 10) { + Image( + systemName: localASRAddonStatus.isInstalled + ? "checkmark.seal.fill" : "arrow.down.circle" + ) + .scaledFont(size: 14) + .foregroundColor( + localASRAddonStatus.isInstalled ? OmiColors.success : OmiColors.textSecondary + ) + .frame(width: 18) + + VStack(alignment: .leading, spacing: 4) { + Text("Local Whisper Add-on") + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text(localASRAddonStatus.pythonPath ?? localASRAddonStatus.detail) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + .lineLimit(2) + .truncationMode(.middle) + } + + Spacer(minLength: 12) + + if isInstallingLocalASRAddon { + ProgressView() + .controlSize(.small) + } else { + Button(localASRAddonPrimaryActionTitle) { + installLocalASRAddon() + } + .buttonStyle(.borderedProminent) + .tint(OmiColors.purplePrimary) + .disabled(!localASRAddonStatus.isActionableInstall) + + if localASRAddonStatus.isInstalled { + Button("Remove") { + removeLocalASRAddon() + } + .buttonStyle(.bordered) + } + } + } + + if let localASRAddonMessage { + Text(localASRAddonMessage) + .scaledFont(size: 11) + .foregroundColor( + localASRAddonMessage.hasPrefix("Installed") ? 
OmiColors.success : OmiColors.warning + ) + .fixedSize(horizontal: false, vertical: true) + } + } + .modifier( + SettingHighlightModifier( + settingId: "transcription.localWhisperAddon", + highlightedSettingId: $highlightedSettingId + ) + ) + .padding(10) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundSecondary.opacity(0.7)) + .overlay( + RoundedRectangle(cornerRadius: 8) + .stroke(OmiColors.backgroundQuaternary, lineWidth: 1) + ) + ) + } + + private var localTranscriptionDebugPanel: some View { + settingsCard(settingId: "transcription.rawdebug") { + VStack(alignment: .leading, spacing: 14) { + HStack(alignment: .center, spacing: 12) { + Image(systemName: "text.magnifyingglass") + .scaledFont(size: 16) + .foregroundColor(OmiColors.purplePrimary) + + VStack(alignment: .leading, spacing: 4) { + Text("Raw Local Transcription") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text("Recent locally persisted sessions and raw segment text") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } + + Spacer() + + Button { + Task { await loadRawTranscriptionHistory() } + } label: { + HStack(spacing: 6) { + if isLoadingRawTranscriptionHistory { + ProgressView().controlSize(.mini) + } else { + Image(systemName: "arrow.clockwise") + .scaledFont(size: 12, weight: .semibold) + } + Text("Refresh") + .scaledFont(size: 13, weight: .medium) + } + } + .buttonStyle(.plain) + .padding(.horizontal, 12) + .padding(.vertical, 6) + .background(OmiColors.backgroundSecondary) + .clipShape(RoundedRectangle(cornerRadius: 8)) + .disabled(isLoadingRawTranscriptionHistory) + } + + if let rawTranscriptionHistoryError { + Text(rawTranscriptionHistoryError) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + } else if rawTranscriptionHistory.isEmpty { + Text( + "No local transcription sessions found yet. Start local background transcription, then refresh this panel." 
+ ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } else { + VStack(alignment: .leading, spacing: 10) { + ForEach(rawTranscriptionHistory.prefix(3).indices, id: \.self) { index in + rawTranscriptionHistoryRow(rawTranscriptionHistory[index]) + } + } + } + } + .task { + if rawTranscriptionHistory.isEmpty && rawTranscriptionHistoryError == nil { + await loadRawTranscriptionHistory() + } + } + } + } + + #if DEBUG + private var transcriptionComparisonHarnessPanel: some View { + settingsCard(settingId: "transcription.comparisonharness") { + VStack(alignment: .leading, spacing: 14) { + HStack(alignment: .center, spacing: 12) { + Image(systemName: "rectangle.split.2x1") + .scaledFont(size: 16) + .foregroundColor(OmiColors.purplePrimary) + + VStack(alignment: .leading, spacing: 4) { + Text("Whisper vs Deepgram Harness") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text("Development-only side-by-side output for local background capture") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + + Toggle("", isOn: $transcriptionComparisonHarnessEnabled) + .toggleStyle(.switch) + .disabled(appState.isTranscribing) + } + + if appState.isTranscribing && transcriptionComparisonHarnessEnabled { + Text("Stop and restart local background transcription after changing this toggle.") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + } + + let snapshot = appState.transcriptionComparisonSnapshot + if transcriptionComparisonHarnessEnabled { + HStack(spacing: 8) { + comparisonMetricChip( + title: "Word diff", + value: formattedComparisonRate(snapshot.wordDifferenceRate) + ) + comparisonMetricChip( + title: "Char diff", + value: formattedComparisonRate(snapshot.characterDifferenceRate) + ) + Spacer() + Text(snapshot.isRunning ? 
"Running" : "Idle") + .scaledFont(size: 11, weight: .medium) + .foregroundColor(snapshot.isRunning ? OmiColors.success : OmiColors.textTertiary) + } + + HStack(alignment: .top, spacing: 12) { + comparisonProviderColumn(snapshot.whisper) + comparisonProviderColumn(snapshot.deepgram) + } + } else { + Text( + "Enable this, select a local background transcription provider, then start recording." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + } + } + } + + private func comparisonMetricChip(title: String, value: String) -> some View { + HStack(spacing: 6) { + Text(title) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + Text(value) + .scaledMonospacedFont(size: 11, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + } + .padding(.horizontal, 8) + .padding(.vertical, 5) + .background(OmiColors.backgroundSecondary.opacity(0.8)) + .clipShape(RoundedRectangle(cornerRadius: 6)) + } + + private func comparisonProviderColumn( + _ provider: TranscriptionComparisonProviderSnapshot + ) -> some View { + VStack(alignment: .leading, spacing: 8) { + HStack(spacing: 8) { + Text(provider.title) + .scaledFont(size: 12, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Spacer() + Text(provider.status) + .scaledFont(size: 10, weight: .medium) + .foregroundColor(provider.error == nil ? OmiColors.textTertiary : OmiColors.warning) + } + + HStack(spacing: 8) { + Text("\(provider.segmentCount) segments") + Text("\(provider.wordCount) words") + } + .scaledFont(size: 10) + .foregroundColor(OmiColors.textTertiary) + + if let error = provider.error { + Text(error) + .scaledFont(size: 11) + .foregroundColor(OmiColors.warning) + .fixedSize(horizontal: false, vertical: true) + } + + Text(provider.transcript.isEmpty ? 
"(waiting for transcript)" : provider.transcript) + .scaledMonospacedFont(size: 11) + .foregroundColor(OmiColors.textSecondary) + .lineLimit(10) + .textSelection(.enabled) + .frame(maxWidth: .infinity, minHeight: 96, alignment: .topLeading) + .padding(10) + .background(OmiColors.backgroundSecondary.opacity(0.8)) + .clipShape(RoundedRectangle(cornerRadius: 8)) + } + .frame(maxWidth: .infinity, alignment: .topLeading) + } + + private func formattedComparisonRate(_ value: Double?) -> String { + guard let value else { return "n/a" } + return "\(Int((value * 100).rounded()))%" + } + #endif + + private var localASRAddonPrimaryActionTitle: String { + switch localASRAddonStatus.state { + case .updateAvailable: + return "Update" + case .repairRequired: + return "Repair" + case .installed(_, let models): + let required = LocalASRAddonManager.initialModel(for: transcriptionProviderSelection.quality) + return models.contains(required) ? "Repair" : "Install Model" + case .notInstalled: + return "Install" + case .installing: + return "Installing" + case .unsupported: + return "Unavailable" + } + } + private func transcriptionProviderOption( mode: TranscriptionProviderKind, title: String, @@ -1539,6 +1878,59 @@ struct SettingsContentView: View { restartTranscriptionIfNeeded() } + private func installLocalASRAddon() { + guard !isInstallingLocalASRAddon else { return } + isInstallingLocalASRAddon = true + localASRAddonMessage = "Installing local Whisper runtime..." + let startingStatus = localASRAddonStatus + let requiredModel = LocalASRAddonManager.initialModel( + for: transcriptionProviderSelection.quality) + + Task { + do { + let reportProgress: LocalASRAddonManager.ProgressHandler = { progress in + localASRAddonMessage = + progress.fraction.map { "\(progress.label) \(Int($0 * 100))%" } ?? 
progress.label + } + let status: LocalASRAddonStatus + if case .installed(_, let models) = startingStatus.state, !models.contains(requiredModel) { + status = try await LocalASRAddonManager.installModel( + for: transcriptionProviderSelection.quality, + progress: reportProgress + ) + } else { + status = try await LocalASRAddonManager.install( + quality: transcriptionProviderSelection.quality, + progress: reportProgress + ) + } + let capabilities = await SettingsContentView.refreshLocalTranscriptionCapabilities() + await MainActor.run { + localASRAddonStatus = status + localTranscriptionCapabilities = capabilities + localASRAddonMessage = "Installed local Whisper add-on." + isInstallingLocalASRAddon = false + } + } catch { + await MainActor.run { + localASRAddonStatus = LocalASRAddonManager.status() + localASRAddonMessage = error.localizedDescription + isInstallingLocalASRAddon = false + } + } + } + } + + private func removeLocalASRAddon() { + do { + localASRAddonStatus = try LocalASRAddonManager.remove() + refreshLocalTranscriptionCapabilities() + localASRAddonMessage = "Removed local Whisper add-on." 
+ } catch { + localASRAddonMessage = error.localizedDescription + } + } + // MARK: - Notifications Section private var notificationsSection: some View { @@ -2101,7 +2493,8 @@ struct SettingsContentView: View { hybridSlotEditorBlock( title: "Proactive assistants", - subtitle: "Local assistant jobs; defaults to \(HybridProviderReadiness.defaultSmallModel())", + subtitle: + "Local assistant jobs; defaults to \(HybridProviderReadiness.defaultSmallModel())", model: $hybridEmbedModel, slot: HybridProviderPolicy.proactiveSlot ) @@ -2163,7 +2556,11 @@ struct SettingsContentView: View { testHybridProviderSlot(slot) } .buttonStyle(.bordered) - .disabled(isTestingHybridProvider || (optional && model.wrappedValue.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)) + .disabled( + isTestingHybridProvider + || (optional + && model.wrappedValue.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) + ) } } } @@ -3519,8 +3916,6 @@ struct SettingsContentView: View { preferencesSubsection advancedCategoryHeader(title: "Troubleshooting", icon: "wrench.and.screwdriver") troubleshootingSubsection - advancedCategoryHeader(title: "ChatGPT plan", icon: "bubble.left.and.bubble.right") - chatGPTPlanSubsection advancedCategoryHeader(title: "Developer API Keys", icon: "key") developerKeysSubsection @@ -3561,65 +3956,6 @@ struct SettingsContentView: View { } } - settingsCard(settingId: "advanced.devtools.rawtranscription") { - VStack(alignment: .leading, spacing: 14) { - HStack(spacing: 12) { - Image(systemName: "waveform.badge.magnifyingglass") - .scaledFont(size: 16) - .foregroundColor(OmiColors.purplePrimary) - VStack(alignment: .leading, spacing: 4) { - Text("Raw Transcription History") - .scaledFont(size: 15, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - Text("Inspect locally persisted sessions and segment text from background capture") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - } - Spacer() - Button { - Task { await 
loadRawTranscriptionHistory() } - } label: { - HStack(spacing: 6) { - if isLoadingRawTranscriptionHistory { - ProgressView().controlSize(.mini) - } else { - Image(systemName: "arrow.clockwise") - .scaledFont(size: 12, weight: .semibold) - } - Text("Refresh") - .scaledFont(size: 13, weight: .medium) - } - } - .buttonStyle(.plain) - .padding(.horizontal, 12) - .padding(.vertical, 6) - .background(OmiColors.backgroundSecondary) - .clipShape(RoundedRectangle(cornerRadius: 8)) - .disabled(isLoadingRawTranscriptionHistory) - } - - if let rawTranscriptionHistoryError { - Text(rawTranscriptionHistoryError) - .scaledFont(size: 12) - .foregroundColor(OmiColors.warning) - } else if rawTranscriptionHistory.isEmpty { - Text("No local transcription sessions found yet.") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) - } else { - VStack(alignment: .leading, spacing: 10) { - ForEach(rawTranscriptionHistory.indices, id: \.self) { index in - rawTranscriptionHistoryRow(rawTranscriptionHistory[index]) - } - } - } - } - .task { - if rawTranscriptionHistory.isEmpty && rawTranscriptionHistoryError == nil { - await loadRawTranscriptionHistory() - } - } - } } } @@ -3728,6 +4064,8 @@ struct SettingsContentView: View { } } + chatGPTPlanSubsection + settingsCard(settingId: "aichat.workspace") { VStack(alignment: .leading, spacing: 12) { HStack { @@ -5699,10 +6037,13 @@ struct SettingsContentView: View { private var chatGPTPlanSubsection: some View { VStack(spacing: 20) { - settingsCard(settingId: "advanced.chatgpt.info") { + settingsCard(settingId: "aichat.chatgpt.info") { VStack(alignment: .leading, spacing: 10) { HStack(spacing: 10) { - Image(systemName: codexAuthStore.isActive ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark") + Image( + systemName: codexAuthStore.isActive + ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark" + ) .foregroundColor(codexAuthStore.isActive ? OmiColors.success : OmiColors.textTertiary) Text(codexAuthStore.isActive ? 
"ChatGPT plan active" : "Use your ChatGPT subscription") .scaledFont(size: 14, weight: .semibold) @@ -5716,7 +6057,9 @@ struct SettingsContentView: View { .scaledFont(size: 12) .foregroundColor(OmiColors.textTertiary) if codexProxyService.isRunning { - Text("Proxy: \(CodexProxyService.defaultBaseURL)") + Text( + "Active provider: \(HybridChatClient.currentRoute().displayName) via \(CodexProxyService.defaultBaseURL)" + ) .scaledFont(size: 11) .foregroundColor(OmiColors.success) } else if codexAuthStore.isEnrolled, let proxyError = codexProxyService.lastError { @@ -5732,7 +6075,7 @@ struct SettingsContentView: View { } if let codexEnrollmentError { - settingsCard(settingId: "advanced.chatgpt.error") { + settingsCard(settingId: "aichat.chatgpt.error") { HStack(spacing: 10) { Image(systemName: "exclamationmark.triangle.fill") .foregroundColor(OmiColors.warning) @@ -6811,7 +7154,8 @@ struct SettingsContentView: View { { return "\(account.displayName ?? account.id) / \(resolved.modelID)" } - return "Deterministic fallback; post-transcript defaults to \(HybridProviderReadiness.defaultSmallModel()) until a provider account is configured" + return + "Deterministic fallback; post-transcript defaults to \(HybridProviderReadiness.defaultSmallModel()) until a provider account is configured" } private func backendStatusRow(title: String, value: String) -> some View { diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ShortcutsSettingsSection.swift b/desktop/Desktop/Sources/MainWindow/Pages/ShortcutsSettingsSection.swift index 2c4c40af0d6..a316734a63c 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ShortcutsSettingsSection.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ShortcutsSettingsSection.swift @@ -387,6 +387,7 @@ struct ShortcutsSettingsSection: View { stopShortcutCapture() recordingTarget = target captureError = nil + PushToTalkManager.shared.setShortcutCaptureSuspended(true) localShortcutCaptureMonitor = NSEvent.addLocalMonitorForEvents(matching: [ 
.flagsChanged, .keyDown, @@ -400,6 +401,7 @@ struct ShortcutsSettingsSection: View { NSEvent.removeMonitor(monitor) localShortcutCaptureMonitor = nil } + PushToTalkManager.shared.setShortcutCaptureSuspended(false) recordingTarget = nil captureError = nil } @@ -419,8 +421,11 @@ struct ShortcutsSettingsSection: View { else { return false } - settings.askOmiEnabled = true - settings.askOmiShortcut = shortcut + DispatchQueue.main.async { + settings.askOmiEnabled = true + settings.askOmiShortcut = shortcut + stopShortcutCapture() + } case .pushToTalk: guard let shortcut = ShortcutSettings.KeyboardShortcut.fromRecordingEvent( @@ -428,11 +433,13 @@ struct ShortcutsSettingsSection: View { else { return false } - settings.pttEnabled = true - settings.pttShortcut = shortcut + DispatchQueue.main.async { + settings.pttEnabled = true + settings.pttShortcut = shortcut + stopShortcutCapture() + } } - stopShortcutCapture() return true } } diff --git a/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift b/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift index 55743324863..9f1457bfe00 100644 --- a/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift +++ b/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift @@ -165,7 +165,8 @@ struct SettingsSearchItem: Identifiable { settingId: "planusage.current"), SettingsSearchItem( name: "Upgrade Plan", subtitle: "Buy Operator or Architect", - keywords: ["upgrade", "buy", "pricing", "checkout", "architect", "operator", "unlimited"], section: .planUsage, + keywords: ["upgrade", "buy", "pricing", "checkout", "architect", "operator", "unlimited"], + section: .planUsage, icon: "creditcard", settingId: "planusage.purchase"), // About @@ -210,8 +211,14 @@ struct SettingsSearchItem: Identifiable { settingId: "advanced.stats"), SettingsSearchItem( name: "AI Provider", subtitle: "Choose between your omi account and Claude for desktop chat", - keywords: ["provider", "agent sdk", "claude code", "acp", "bridge mode"], section: 
.advanced, + keywords: ["provider", "agent sdk", "claude code", "acp", "bridge mode", "chatgpt", "codex"], + section: .advanced, icon: "cpu", settingId: "aichat.provider"), + SettingsSearchItem( + name: "ChatGPT Plan", subtitle: "Connect your ChatGPT subscription for desktop AI", + keywords: ["chatgpt", "codex", "subscription", "provider", "proxy", "openai"], + section: .advanced, + icon: "cpu", settingId: "aichat.chatgpt.info"), SettingsSearchItem( name: "Workspace", subtitle: "Set a project directory for desktop chat context", keywords: ["workspace", "project", "directory", "folder", "working directory"], diff --git a/desktop/Desktop/Sources/ProactiveAssistants/UI/InsightTestRunnerWindow.swift b/desktop/Desktop/Sources/ProactiveAssistants/UI/InsightTestRunnerWindow.swift index 16077ebcc05..3643f1ae3c1 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/UI/InsightTestRunnerWindow.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/UI/InsightTestRunnerWindow.swift @@ -719,9 +719,12 @@ enum InsightTestRunner { } // Pick the latest non-excluded screenshot + let builtInExcludedApps = await MainActor.run { + TaskAssistantSettings.builtInExcludedApps + } guard let anchor = screenshots.first(where: { ss in !ss.appName.isEmpty - && !TaskAssistantSettings.builtInExcludedApps.contains(ss.appName) + && !builtInExcludedApps.contains(ss.appName) && !excludedApps.contains(ss.appName) }) else { log("InsightTestCLI: \(label) \(timeFormatter.string(from: window.windowEnd)) — skipped (no non-excluded screenshots)") diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index b19a1ca56d9..2158ad1df2e 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -196,6 +196,26 @@ enum ToolCallStatus { case completed } +private actor SQLToolUsageStats { + private var rowsReturned = 0 + private var queryCount = 0 + + func record(toolName: String, 
result: String) { + guard toolName == "execute_sql" else { return } + + queryCount += 1 + // Parse row count from result (format: "\nN row(s)" at end) + if let match = result.range(of: #"(\d+) row\(s\)"#, options: .regularExpression) { + let numStr = result[match].components(separatedBy: " ").first ?? "0" + rowsReturned += Int(numStr) ?? 0 + } + } + + func snapshot() -> (rowsReturned: Int, queryCount: Int) { + (rowsReturned, queryCount) + } +} + // MARK: - Chat Message Model /// Metadata about the context and resources used to generate an AI response @@ -2443,10 +2463,13 @@ A screenshot may be attached — use it silently only if relevant. Never mention return } + let providerRoute = HybridChatClient.currentRoute() + let usesOmiMeteredBridge = isUsingOmiAccountProvider && !providerRoute.usesDirectProvider + // Monthly free-tier limit shared with the floating bar (30 messages/month). - // Block the send, surface the popup, and let the user upgrade. + // Block metered Omi bridge sends, surface the popup, and let the user upgrade. let usageLimiter = FloatingBarUsageLimiter.shared - if isUsingOmiAccountProvider { + if usesOmiMeteredBridge { if usageLimiter.isLimitReached { log("ChatProvider: sendMessage blocked — free-tier monthly chat limit reached") errorMessage = "You've reached \(usageLimiter.limitDescription). Upgrade to keep chatting." @@ -2460,11 +2483,9 @@ A screenshot may be attached — use it silently only if relevant. Never mention usageLimiter.recordQuery() } - let mayUseHybridDirectChat = HybridChatClient.isEnabled() - - // Ensure Claude / ACP bridge when not using hybrid direct chat. Hybrid path may + // Ensure Claude / ACP bridge when not using a direct provider. Direct paths may // skip the bridge until multimodal attachments require ACP. 
- if !mayUseHybridDirectChat { + if !providerRoute.usesDirectProvider { guard await ensureBridgeStarted() else { errorMessage = "AI not available" return @@ -2472,7 +2493,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention } // Show upgrade prompt if over threshold but don't block the message - if bridgeMode != BridgeMode.userClaude.rawValue && omiAICumulativeCostUsd >= 50.0 { + if usesOmiMeteredBridge && omiAICumulativeCostUsd >= 50.0 { showOmiThresholdAlert = true } @@ -2605,8 +2626,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention let queryStartTime = Date() var toolNames: [String] = [] var toolStartTimes: [String: Date] = [:] - var sqlRowsReturned = 0 - var sqlQueryCount = 0 + let sqlToolUsageStats = SQLToolUsageStats() var hybridResolvedModel: String? var hybridProviderAccountId: String? var hybridProviderKind: String? @@ -2614,7 +2634,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention var hybridSlotReason: String? do { - if mayUseHybridDirectChat { + if providerRoute.usesDirectProvider { await preparePromptContextIfNeeded() if !isOnboarding { cachedMainSystemPrompt = buildSystemPrompt(contextString: formatMemoriesSection()) @@ -2662,13 +2682,14 @@ A screenshot may be attached — use it silently only if relevant. Never mention } } - let useHybridNow = - mayUseHybridDirectChat && effectiveImageData == nil + let useDirectProviderNow = + providerRoute.usesDirectProvider + && (effectiveImageData == nil || providerRoute.supportsInlineImages) && !attachmentsForMessage.contains(where: { $0.isImage }) let queryResult: AgentBridge.QueryResult - if useHybridNow { + if useDirectProviderNow { let historyPairs: [(role: String, text: String)] = messages .dropLast(2) @@ -2679,10 +2700,11 @@ A screenshot may be attached — use it silently only if relevant. 
Never mention text: msg.text ) } - let hybrid = try await HybridChatClient.completeFromDaemonSettings( + let hybrid = try await HybridChatClient.completeWithActiveDirectProvider( systemPrompt: systemPrompt, conversationMessages: historyPairs, - userMessage: trimmedText + userMessage: trimmedText, + imageData: providerRoute.supportsInlineImages ? effectiveImageData : nil ) hybridResolvedModel = hybrid.model hybridProviderAccountId = hybrid.providerAccountID @@ -2700,7 +2722,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention cacheWriteTokens: 0 ) } else { - if mayUseHybridDirectChat && !useHybridNow { + if providerRoute.usesDirectProvider && !useDirectProviderNow { guard await ensureBridgeStarted() else { throw BridgeError.notRunning } @@ -2717,15 +2739,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention let toolCall = ToolCall(name: name, arguments: input, thoughtSignature: nil) let result = await ChatToolExecutor.execute(toolCall) log("OMI tool \(name) executed for callId=\(callId)") - // Track SQL query stats for metadata - if name == "execute_sql" { - sqlQueryCount += 1 - // Parse row count from result (format: "\nN row(s)" at end) - if let match = result.range(of: #"(\d+) row\(s\)"#, options: .regularExpression) { - let numStr = result[match].components(separatedBy: " ").first ?? "0" - sqlRowsReturned += Int(numStr) ?? 0 - } - } + await sqlToolUsageStats.record(toolName: name, result: result) return result } let toolActivityHandler: AgentBridge.ToolActivityHandler = { @@ -2829,6 +2843,7 @@ A screenshot may be attached — use it silently only if relevant. Never mention // Determine the final text to display and save let messageText: String + let sqlToolUsageSnapshot = await sqlToolUsageStats.snapshot() if let index = messages.firstIndex(where: { $0.id == aiMessageId }) { // Message still in memory — update it in-place messageText = messages[index].text.isEmpty ? 
queryResult.text : messages[index].text @@ -2845,8 +2860,8 @@ A screenshot may be attached — use it silently only if relevant. Never mention hasScreenshot: imageData != nil, screenshotSizeBytes: imageData?.count, toolNames: toolNames, - sqlRowsReturned: sqlRowsReturned, - sqlQueryCount: sqlQueryCount + sqlRowsReturned: sqlToolUsageSnapshot.rowsReturned, + sqlQueryCount: sqlToolUsageSnapshot.queryCount ) completeRemainingToolCalls(messageId: aiMessageId) } else { diff --git a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift index d31f950bbe4..0fa8edf1713 100644 --- a/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift +++ b/desktop/Desktop/Sources/Rewind/Core/TranscriptionStorage.swift @@ -149,7 +149,7 @@ actor TranscriptionStorage { /// Mark session as failed with error. /// No-op if the session is already completed (prevents race with concurrent completion). - func markSessionFailed(id: Int64, error: String) async throws { + func markSessionFailed(id: Int64, error: String, retryCount: Int? 
= nil) async throws { let db = try await ensureInitialized() try await db.write { database in @@ -165,6 +165,9 @@ actor TranscriptionStorage { record.status = .failed record.lastError = error + if let retryCount { + record.retryCount = retryCount + } record.updatedAt = Date() try record.update(database) } diff --git a/desktop/Desktop/Sources/TranscriptionRetryService.swift b/desktop/Desktop/Sources/TranscriptionRetryService.swift index a96923c73e2..51ca235cf17 100644 --- a/desktop/Desktop/Sources/TranscriptionRetryService.swift +++ b/desktop/Desktop/Sources/TranscriptionRetryService.swift @@ -23,6 +23,9 @@ class TranscriptionRetryService { if session.status == .completed { return session.backendId } + if try await markEmptyLocalDaemonSessionExhaustedIfNeeded(sessionId: sessionId) { + return nil + } let conversation = try await uploadSessionToLocalDaemon(session, sessionId: sessionId) return conversation.id } @@ -230,6 +233,9 @@ class TranscriptionRetryService { do { if await APIClient.shared.isUsingLocalDaemon { + if try await markEmptyLocalDaemonSessionExhaustedIfNeeded(sessionId: sessionId) { + return + } try await uploadSessionToLocalDaemon(session, sessionId: sessionId) return } @@ -279,7 +285,9 @@ class TranscriptionRetryService { let segments = try await TranscriptionStorage.shared.getSegments(sessionId: sessionId) guard !segments.isEmpty else { try await TranscriptionStorage.shared.markSessionFailed( - id: sessionId, error: "No transcript segments to upload to local daemon") + id: sessionId, + error: "No transcript segments to upload to local daemon", + retryCount: maxRetries) throw APIError.httpError(statusCode: 400) } @@ -303,4 +311,16 @@ class TranscriptionRetryService { return conversation } + private func markEmptyLocalDaemonSessionExhaustedIfNeeded(sessionId: Int64) async throws -> Bool { + let segmentCount = try await TranscriptionStorage.shared.getSegmentCount(sessionId: sessionId) + guard segmentCount == 0 else { return false } + + try await 
TranscriptionStorage.shared.markSessionFailed( + id: sessionId, + error: "No transcript segments to upload to local daemon", + retryCount: maxRetries) + log("TranscriptionRetryService: Exhausted empty local daemon session \(sessionId) without upload") + return true + } + } diff --git a/desktop/Desktop/Tests/HybridChatClientTests.swift b/desktop/Desktop/Tests/HybridChatClientTests.swift index 021cb4d700c..de5caf4396a 100644 --- a/desktop/Desktop/Tests/HybridChatClientTests.swift +++ b/desktop/Desktop/Tests/HybridChatClientTests.swift @@ -85,6 +85,14 @@ final class HybridChatClientTests: XCTestCase { override func setUp() { super.setUp() ChatProviderCapture.reset() + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + UserDefaults.standard.removeObject(forKey: "codex_preferred_model") + } + + override func tearDown() { + UserDefaults.standard.removeObject(forKey: "codex_auth_enrolled") + UserDefaults.standard.removeObject(forKey: "codex_preferred_model") + super.tearDown() } func testChatSlotResolutionBuildsProviderConfig() { @@ -125,6 +133,68 @@ final class HybridChatClientTests: XCTestCase { XCTAssertEqual(result.model, "stub-model") } + func testImagePayloadUsesChatCompletionsContentParts() async throws { + let session = capturedSession() + let response = slotResolution( + accountID: "chatgpt-plan", + baseURL: "http://127.0.0.1:10531/v1", + model: "gpt-5.4" + ) + + _ = try await HybridChatClient.complete( + systemPrompt: "system", + conversationMessages: [], + userMessage: "what is on screen?", + slotResolution: response, + imageData: Data([0x89, 0x50, 0x4E, 0x47]), + session: session + ) + + let body = try XCTUnwrap(ChatProviderCapture.bodies.first) + let json = try XCTUnwrap(JSONSerialization.jsonObject(with: body) as? [String: Any]) + let messages = try XCTUnwrap(json["messages"] as? [[String: Any]]) + let user = try XCTUnwrap(messages.last) + let content = try XCTUnwrap(user["content"] as? 
[[String: Any]]) + + XCTAssertEqual(content.first?["type"] as? String, "text") + XCTAssertEqual(content.first?["text"] as? String, "what is on screen?") + XCTAssertEqual(content.last?["type"] as? String, "image_url") + let imageURL = try XCTUnwrap(content.last?["image_url"] as? [String: Any]) + XCTAssertTrue((imageURL["url"] as? String)?.hasPrefix("data:image/png;base64,") == true) + } + + func testCodexActiveBypassesDaemonSlotResolution() async throws { + let tempAuth = try makeTempCodexHomeWithAuth() + defer { tempAuth.cleanup() } + UserDefaults.standard.set(true, forKey: "codex_auth_enrolled") + let session = capturedSession() + + let result = try await HybridChatClient.completeWithActiveDirectProvider( + systemPrompt: "system", + conversationMessages: [], + userMessage: "hello", + session: session, + ensureCodexProxy: false + ) + + let request = try XCTUnwrap(ChatProviderCapture.requests.first) + XCTAssertEqual(request.url?.absoluteString, "http://127.0.0.1:10531/v1/chat/completions") + XCTAssertEqual(result.providerAccountID, "chatgpt-plan") + XCTAssertEqual(result.slotSource, "chatgpt_plan") + } + + func testCurrentRouteUsesCodexWhenActive() throws { + let tempAuth = try makeTempCodexHomeWithAuth() + defer { tempAuth.cleanup() } + UserDefaults.standard.set(true, forKey: "codex_auth_enrolled") + + let route = HybridChatClient.currentRoute() + + XCTAssertEqual(route.displayName, "ChatGPT plan") + XCTAssertTrue(route.usesDirectProvider) + XCTAssertTrue(route.supportsInlineImages) + } + func testProviderAccountSwitchChangesRequestTargetAndModel() async throws { let session = capturedSession() @@ -152,7 +222,9 @@ final class HybridChatClientTests: XCTestCase { ) let requests = ChatProviderCapture.requests - XCTAssertEqual(requests.map { $0.url?.absoluteString }, [ + XCTAssertEqual( + requests.map { $0.url?.absoluteString }, + [ "http://127.0.0.1:11434/v1/chat/completions", "http://localhost:43210/v1/chat/completions", ]) @@ -243,3 +315,38 @@ final class 
HybridChatClientTests: XCTestCase { return HybridProviderPolicy.SlotResolutionResponse(resolved: resolved, resolution: resolution) } } + +private struct TempCodexHomeForHybridChat { + let path: String + let previous: String? + + func cleanup() { + if let previous { + setenv("CODEX_HOME", previous, 1) + } else { + unsetenv("CODEX_HOME") + } + try? FileManager.default.removeItem(atPath: path) + } +} + +private func makeTempCodexHomeWithAuth() throws -> TempCodexHomeForHybridChat { + let dir = FileManager.default.temporaryDirectory + .appendingPathComponent("hybrid-chat-codex-auth-\(UUID().uuidString)") + try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true) + let authURL = dir.appendingPathComponent("auth.json") + let payload = """ + { + "auth_mode": "chatgpt", + "tokens": { + "access_token": "test-access", + "refresh_token": "test-refresh", + "account_id": "acct-test" + } + } + """ + try payload.write(to: authURL, atomically: true, encoding: .utf8) + let previous = ProcessInfo.processInfo.environment["CODEX_HOME"] + setenv("CODEX_HOME", dir.path, 1) + return TempCodexHomeForHybridChat(path: dir.path, previous: previous) +} diff --git a/desktop/Desktop/Tests/TranscriptComparisonTests.swift b/desktop/Desktop/Tests/TranscriptComparisonTests.swift index 960f9a2dbc8..56ab85b61af 100644 --- a/desktop/Desktop/Tests/TranscriptComparisonTests.swift +++ b/desktop/Desktop/Tests/TranscriptComparisonTests.swift @@ -33,4 +33,37 @@ final class TranscriptComparisonTests: XCTestCase { XCTAssertEqual(TranscriptComparison.wordErrorRate(reference: "", hypothesis: ""), 0) XCTAssertEqual(TranscriptComparison.wordErrorRate(reference: "", hypothesis: "extra"), 1) } + + @MainActor + func testComparisonHarnessPublishesWhisperSnapshotWithoutDeepgramKey() { + var snapshots: [TranscriptionComparisonHarnessSnapshot] = [] + let harness = TranscriptionComparisonHarness( + language: "en", + deepgramAPIKey: nil, + onSnapshot: { snapshots.append($0) } + ) + + 
harness.start() + harness.appendWhisperSegments([ + NormalizedTranscriptSegment( + segmentId: "whisper-1", + speaker: 0, + speakerLabel: nil, + text: "hello local whisper", + start: 0, + end: 1, + isUser: true, + personId: nil, + translations: [] + ) + ]) + + let snapshot = try! XCTUnwrap(snapshots.last) + XCTAssertTrue(snapshot.isRunning) + XCTAssertEqual(snapshot.whisper.transcript, "hello local whisper") + XCTAssertEqual(snapshot.whisper.wordCount, 3) + XCTAssertEqual(snapshot.deepgram.status, "Missing Deepgram API key") + XCTAssertNotNil(snapshot.deepgram.error) + XCTAssertNil(snapshot.wordDifferenceRate) + } } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index 72245b1c030..e6415f1a569 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -157,6 +157,104 @@ final class TranscriptionProviderPolicyTests: XCTestCase { } } +final class LocalASRAddonManifestTests: XCTestCase { + func testRemoteManifestParsesRuntimeAndModels() throws { + let json = """ + { + "version": 1, + "runtime": { + "version": "2026.05.20", + "platform": "macos", + "arch": "arm64", + "url": "https://example.com/runtime.zip", + "sha256": "abc123", + "size_bytes": 123456, + "minimum_app_version": "0.2.0" + }, + "models": [ + { + "model": "small", + "version": "mlx-2026.05.20", + "url": "https://example.com/model-small.zip", + "sha256": "def456", + "size_bytes": 654321 + }, + { + "model": "large_v3_turbo", + "version": "mlx-2026.05.20", + "url": "https://example.com/model-large.zip", + "sha256": "fed654", + "size_bytes": 987654 + } + ] + } + """.data(using: .utf8)! 
+ + let manifest = try JSONDecoder.localASRAddon.decode( + LocalASRAddonRemoteManifest.self, from: json) + + XCTAssertEqual(manifest.version, 1) + XCTAssertEqual(manifest.runtime.platform, "macos") + XCTAssertEqual(manifest.runtime.arch, "arm64") + XCTAssertEqual(manifest.runtime.sizeBytes, 123456) + XCTAssertEqual(manifest.runtime.minimumAppVersion, "0.2.0") + XCTAssertEqual(manifest.models.map(\.model), [.small, .largeV3Turbo]) + } + + func testInstalledManifestRoundTripsManagedState() throws { + let installedAt = Date(timeIntervalSince1970: 1_800_000_000) + let manifest = LocalASRAddonInstalledManifest( + schemaVersion: 1, + runtimeVersion: "2026.05.20", + runtimeSha256: "runtime-sha", + pythonPath: "/tmp/Omi/LocalASR/runtime/bin/python3", + installedAt: installedAt, + models: [ + LocalASRAddonInstalledManifest.InstalledModel( + model: .small, + version: "mlx-2026.05.20", + sha256: "model-sha", + path: "/tmp/Omi/LocalASR/models/small", + installedAt: installedAt + ) + ] + ) + + let encoded = try JSONEncoder.localASRAddon.encode(manifest) + let decoded = try JSONDecoder.localASRAddon.decode( + LocalASRAddonInstalledManifest.self, from: encoded) + + XCTAssertEqual(decoded, manifest) + } + + func testInitialModelUsesBaseForFastPreset() { + XCTAssertEqual(LocalASRAddonManager.initialModel(for: .fast), .base) + } + + func testLocalDevManifestMessageExplainsSentinelRustBackend() { + let message = LocalASRAddonManager.localDevManifestConfigurationMessage( + modeValue: "local", + rustBackendURL: "http://omi-rust-invalid:9002/", + manifestOverride: nil + ) + + XCTAssertEqual( + message, + "Local Whisper add-on manifest is not configured for local dev. Run `make local-asr-fixture`, or set OMI_LOCAL_ASR_MANIFEST_URL." 
+ ) + } + + func testLocalDevManifestMessageAllowsExplicitOverride() { + let message = LocalASRAddonManager.localDevManifestConfigurationMessage( + modeValue: "local", + rustBackendURL: "http://omi-rust-invalid:9002/", + manifestOverride: "file:///tmp/omi-local-asr-fixture/manifest.json" + ) + + XCTAssertNil(message) + } +} + final class SpeakerSegmentReducerTests: XCTestCase { func testReducerAddsUpdatesAndPreservesTranslations() { var reducer = SpeakerSegmentReducer(maxInMemorySegments: 10) @@ -772,7 +870,9 @@ final class LocalBackgroundLifecycleTests: XCTestCase { } func testLocalConversationTitleUsesDeterministicTranscriptPrefix() { - XCTAssertEqual(AppState.localConversationTitle(from: " hello local background "), "hello local background") + XCTAssertEqual( + AppState.localConversationTitle(from: " hello local background "), + "hello local background") XCTAssertEqual(AppState.localConversationTitle(from: " "), "Local transcription") } diff --git a/desktop/local-asr-addon/build_dev_fixture.sh b/desktop/local-asr-addon/build_dev_fixture.sh new file mode 100755 index 00000000000..5474302871d --- /dev/null +++ b/desktop/local-asr-addon/build_dev_fixture.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." 
&& pwd)" +OUT_DIR="${OMI_LOCAL_ASR_FIXTURE_DIR:-/tmp/omi-local-asr-fixture}" +PYTHON_BIN="${OMI_LOCAL_ASR_PYTHON:-python3}" +MODEL="${OMI_LOCAL_ASR_FIXTURE_MODEL:-small}" +VERSION="${OMI_LOCAL_ASR_FIXTURE_VERSION:-dev-$(date +%Y%m%d%H%M%S)}" + +case "$MODEL" in + tiny|base|small|medium|large_v3_turbo) ;; + *) echo "Unsupported OMI_LOCAL_ASR_FIXTURE_MODEL: $MODEL" >&2; exit 2 ;; +esac + +repo_for_model() { + case "$1" in + tiny) echo "mlx-community/whisper-tiny-mlx" ;; + base) echo "mlx-community/whisper-base-mlx" ;; + small) echo "mlx-community/whisper-small-mlx" ;; + medium) echo "mlx-community/whisper-medium-mlx" ;; + large_v3_turbo) echo "mlx-community/whisper-large-v3-turbo" ;; + esac +} + +mkdir -p "$OUT_DIR" +rm -rf "$OUT_DIR/runtime" "$OUT_DIR/model-$MODEL" "$OUT_DIR"/*.zip "$OUT_DIR/manifest.json" + +PYTHON_ABS="$("$PYTHON_BIN" - <<'PY' +import os +import sys + +print(os.path.realpath(sys.executable)) +PY +)" + +"$PYTHON_ABS" - <<'PY' +import importlib.util +import platform +import sys + +missing = [name for name in ("mlx", "mlx_whisper", "huggingface_hub") if importlib.util.find_spec(name) is None] +if missing: + raise SystemExit( + "Python is missing required Local Whisper packages: " + + ", ".join(missing) + + "\nInstall them in the selected Python, or set OMI_LOCAL_ASR_PYTHON to a ready environment." + ) +if platform.machine() != "arm64": + raise SystemExit(f"MLX Whisper fixture requires arm64 Python, got {platform.machine()}") +PY + +RUNTIME_DIR="$OUT_DIR/runtime/runtime" +mkdir -p "$RUNTIME_DIR/bin" +cat > "$RUNTIME_DIR/bin/python3" <&2 + exit 1 +fi + +MODEL_STAGE="$OUT_DIR/model-$MODEL/model-$MODEL" +mkdir -p "$MODEL_STAGE" +# Hugging Face snapshots are symlink trees into the cache's blobs directory. +# The artifact must be self-contained after unzip, so copy dereferenced files. 
+rsync -aL --delete "$MODEL_SOURCE"/ "$MODEL_STAGE"/ + +SAFE_MODEL="${MODEL//_/-}" +MODEL_ZIP="$OUT_DIR/model-$SAFE_MODEL-$VERSION.zip" +(cd "$OUT_DIR/model-$MODEL" && /usr/bin/zip -qry "$MODEL_ZIP" "model-$MODEL") + +RUNTIME_SHA="$(shasum -a 256 "$RUNTIME_ZIP" | awk '{print $1}')" +RUNTIME_SIZE="$(stat -f%z "$RUNTIME_ZIP")" +MODEL_SHA="$(shasum -a 256 "$MODEL_ZIP" | awk '{print $1}')" +MODEL_SIZE="$(stat -f%z "$MODEL_ZIP")" + +"$PYTHON_ABS" - "$OUT_DIR/manifest.json" "$RUNTIME_ZIP" "$MODEL_ZIP" "$VERSION" "$RUNTIME_SHA" "$RUNTIME_SIZE" "$MODEL" "$MODEL_SHA" "$MODEL_SIZE" <<'PY' +import json +import pathlib +import sys + +manifest_path, runtime_zip, model_zip, version, runtime_sha, runtime_size, model, model_sha, model_size = sys.argv[1:] +manifest = { + "version": 1, + "runtime": { + "version": version, + "platform": "macos", + "arch": "arm64", + "url": pathlib.Path(runtime_zip).resolve().as_uri(), + "sha256": runtime_sha, + "size_bytes": int(runtime_size), + "minimum_app_version": None, + }, + "models": [ + { + "model": model, + "version": version, + "url": pathlib.Path(model_zip).resolve().as_uri(), + "sha256": model_sha, + "size_bytes": int(model_size), + } + ], +} +pathlib.Path(manifest_path).write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n") +PY + +cat <&2; exit 2 ;; +esac + +rm -rf "$STAGE" "$ZIP_PATH" +mkdir -p "$STAGE" "$OUT_DIR" +rsync -a --delete "$MODEL_DIR"/ "$STAGE"/ + +(cd "$OUT_DIR" && /usr/bin/zip -qry "$(basename "$ZIP_PATH")" "$(basename "$STAGE")") + +SHA256="$(shasum -a 256 "$ZIP_PATH" | awk '{print $1}')" +SIZE_BYTES="$(stat -f%z "$ZIP_PATH")" +ENV_MODEL="$(echo "$MODEL" | tr '[:lower:]' '[:upper:]')" + +cat < { return Err(ApiError::conflict( - "transcript segment already exists with different content at this index", + "transcript segment already exists with different content or id", )); } }; diff --git a/desktop/local-backend/src/storage.rs b/desktop/local-backend/src/storage.rs index a87e5bcfc7a..2fdc4e78ab6 100644 --- 
a/desktop/local-backend/src/storage.rs +++ b/desktop/local-backend/src/storage.rs @@ -1234,6 +1234,14 @@ impl TranscriptRepository { }; } + if let Some(existing) = self.get_by_id(&new.id)? { + return if transcript_matches_new(&existing, &new)? { + Ok(AppendTranscriptResult::Existing(existing)) + } else { + Ok(AppendTranscriptResult::Conflict(existing)) + }; + } + let now = Utc::now(); let segment = TranscriptSegment { id: new.id, @@ -1333,6 +1341,23 @@ impl TranscriptRepository { .context("failed to fetch transcript segment by index") } + pub fn get_by_id(&self, id: &str) -> Result> { + let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); + conn.query_row( + r#" + SELECT id, conversation_id, session_id, speaker_id, speaker_label, text, start_ms, + end_ms, segment_index, source, created_at, updated_at, deleted_at, cloud_id, + sync_version, sync_state, metadata_json + FROM transcript_segments + WHERE id = ?1 AND deleted_at IS NULL + "#, + params![id], + map_transcript_segment, + ) + .optional() + .context("failed to fetch transcript segment by id") + } + pub fn next_segment_index(&self, conversation_id: &str) -> Result { let conn = self.conn.lock().expect("SQLite connection mutex poisoned"); conn.query_row( @@ -3192,6 +3217,62 @@ mod tests { Ok(()) } + #[test] + fn duplicate_transcript_id_across_conversations_is_conflict() -> Result<()> { + let store = Store::open_in_memory()?; + let first_conversation_id = deterministic_id("conv", &["session-duplicate-id-a"]); + let second_conversation_id = deterministic_id("conv", &["session-duplicate-id-b"]); + + for (conversation_id, session_id) in [ + (&first_conversation_id, "session-duplicate-id-a"), + (&second_conversation_id, "session-duplicate-id-b"), + ] { + store.conversations().create(NewConversation { + id: conversation_id.clone(), + session_id: session_id.to_string(), + title: String::new(), + overview: String::new(), + started_at: None, + metadata: None, + })?; + } + + let shared_segment_id = 
"apple-hybrid-live".to_string(); + assert!(matches!( + store.transcripts().append(NewTranscriptSegment { + id: shared_segment_id.clone(), + conversation_id: first_conversation_id, + session_id: "session-duplicate-id-a".to_string(), + speaker_id: Some("speaker-1".to_string()), + speaker_label: Some("Alice".to_string()), + text: "First conversation text.".to_string(), + start_ms: 0, + end_ms: 1200, + segment_index: 0, + source: None, + metadata: None, + })?, + AppendTranscriptResult::Inserted(_) + )); + + let duplicate = store.transcripts().append(NewTranscriptSegment { + id: shared_segment_id, + conversation_id: second_conversation_id, + session_id: "session-duplicate-id-b".to_string(), + speaker_id: Some("speaker-1".to_string()), + speaker_label: Some("Alice".to_string()), + text: "Second conversation text.".to_string(), + start_ms: 0, + end_ms: 1200, + segment_index: 0, + source: None, + metadata: None, + })?; + assert!(matches!(duplicate, AppendTranscriptResult::Conflict(_))); + + Ok(()) + } + #[test] fn conversation_starred_updates_persist() -> Result<()> { let store = Store::open_in_memory()?; diff --git a/desktop/run.sh b/desktop/run.sh index 7b66f07c35b..605f87d0ffd 100755 --- a/desktop/run.sh +++ b/desktop/run.sh @@ -27,6 +27,7 @@ Options (via environment variables): OMI_HYBRID_DIRECT_STT_ENABLED Hybrid Apple Speech live transcription in local daemon (default 1 in configure_local_daemon_mode when unset) OMI_HYBRID_DIRECT_CHAT_ENABLED Hybrid OpenAI-compatible chat + daemon-backed sessions/messages (default 1 in configure_local_daemon_mode when unset) OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED Optional hybrid direct embeddings for vector search (default 0 in local bundle; local wiki search does not require embeddings) + OMI_LOCAL_ASR_MANIFEST_URL="..." 
Production-shaped Local Whisper add-on manifest URL OMI_SKIP_STALE_BUNDLE_SCAN=1 Skip scanning $HOME for stale dev app bundles Required files for cloud backend mode: @@ -672,6 +673,10 @@ if is_local_daemon_mode; then set_bundle_env "OMI_HYBRID_DIRECT_STT_ENABLED" "${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" set_bundle_env "OMI_HYBRID_DIRECT_CHAT_ENABLED" "${OMI_HYBRID_DIRECT_CHAT_ENABLED:-1}" set_bundle_env "OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED" "${OMI_HYBRID_DIRECT_EMBEDDINGS_ENABLED:-0}" + if [ -n "${OMI_LOCAL_ASR_MANIFEST_URL:-}" ]; then + set_bundle_env "OMI_LOCAL_ASR_MANIFEST_URL" "$OMI_LOCAL_ASR_MANIFEST_URL" + substep "OMI_LOCAL_ASR_MANIFEST_URL=$OMI_LOCAL_ASR_MANIFEST_URL" + fi substep "OMI_DESKTOP_BACKEND_MODE=local" substep "OMI_LOCAL_DAEMON_URL=$OMI_LOCAL_DAEMON_URL" substep "OMI_HYBRID_DIRECT_STT_ENABLED=${OMI_HYBRID_DIRECT_STT_ENABLED:-1}" diff --git a/scripts/hybrid-local.sh b/scripts/hybrid-local.sh index 4998565527c..99fcf8d06d1 100755 --- a/scripts/hybrid-local.sh +++ b/scripts/hybrid-local.sh @@ -12,6 +12,28 @@ export OMI_LOCAL_BACKEND_HOST="${OMI_LOCAL_BACKEND_HOST:-127.0.0.1}" export OMI_LOCAL_BACKEND_PORT="${OMI_LOCAL_BACKEND_PORT:-8765}" export OMI_LOCAL_BACKEND_DATA_DIR="${OMI_LOCAL_BACKEND_DATA_DIR:-/tmp/omi-local-mvp}" export OMI_LOCAL_DAEMON_LOG="${OMI_LOCAL_DAEMON_LOG:-/tmp/omi-local-backend-dev.log}" +export OMI_LOCAL_ASR_FIXTURE_DIR="${OMI_LOCAL_ASR_FIXTURE_DIR:-/tmp/omi-local-asr-fixture}" + +file_url_for_path() { + python3 - "$1" <<'PY' +from pathlib import Path +import sys + +print(Path(sys.argv[1]).resolve().as_uri()) +PY +} + +configure_local_asr_manifest() { + if [ -n "${OMI_LOCAL_ASR_MANIFEST_URL:-}" ]; then + return + fi + + local fixture_manifest="${OMI_LOCAL_ASR_FIXTURE_DIR}/manifest.json" + if [ -f "$fixture_manifest" ]; then + OMI_LOCAL_ASR_MANIFEST_URL="$(file_url_for_path "$fixture_manifest")" + export OMI_LOCAL_ASR_MANIFEST_URL + fi +} # Shared hybrid desktop env (matches local-mvp-runbook.md). 
hybrid_desktop_env() { @@ -21,6 +43,7 @@ hybrid_desktop_env() { export OMI_LOCAL_BACKEND_PORT export OMI_PYTHON_API_URL="${OMI_PYTHON_API_URL:-http://omi-cloud-invalid:9001}" export OMI_DESKTOP_API_URL="${OMI_DESKTOP_API_URL:-http://omi-rust-invalid:9002}" + configure_local_asr_manifest } local_daemon_health_ok() { @@ -102,16 +125,23 @@ EOF } desktop_pane_command() { + hybrid_desktop_env cat </dev/null 2>&1; do @@ -208,6 +238,8 @@ Environment (optional): OMI_LOCAL_BACKEND_DATA_DIR SQLite data dir (default: /tmp/omi-local-mvp) OMI_DAEMON_HEALTH_WAIT_SECS wait for /health (default: 180) OMI_APP_NAME desktop bundle name (default: Omi Dev) + OMI_LOCAL_ASR_MANIFEST_URL production-shaped local ASR add-on manifest URL + OMI_LOCAL_ASR_FIXTURE_DIR auto-detected fixture dir (default: /tmp/omi-local-asr-fixture) See desktop/local-backend/docs/local-mvp-runbook.md EOF From 00756cd4099fe110d4e090a8c649a941dd37035d Mon Sep 17 00:00:00 2001 From: David Zhang Date: Wed, 20 May 2026 17:08:31 -0400 Subject: [PATCH 58/58] Polish local desktop transcription setup --- desktop/Desktop/Sources/AppState.swift | 79 +++- .../Desktop/Sources/CodexProxyService.swift | 61 ++- .../PushToTalkManager.swift | 94 ++-- .../LocalTranscription/LocalASRRuntime.swift | 90 ++-- .../TranscriptionComparisonHarness.swift | 267 ++++++++++-- .../Sources/MainWindow/DesktopHomeView.swift | 127 +++++- .../Pages/HybridProviderSetupSheet.swift | 216 ++++++++++ .../MainWindow/Pages/SettingsPage.swift | 404 ++++++++++-------- .../Sources/MainWindow/SettingsSidebar.swift | 26 +- desktop/Desktop/Sources/SignInView.swift | 122 +++--- .../Tests/TranscriptComparisonTests.swift | 68 +++ .../TranscriptionProviderPolicyTests.swift | 96 +++++ desktop/codex-proxy/src/main.rs | 97 ++++- 13 files changed, 1388 insertions(+), 359 deletions(-) create mode 100644 desktop/Desktop/Sources/MainWindow/Pages/HybridProviderSetupSheet.swift diff --git a/desktop/Desktop/Sources/AppState.swift b/desktop/Desktop/Sources/AppState.swift index 
39d00287125..db0f61f55e9 100644 --- a/desktop/Desktop/Sources/AppState.swift +++ b/desktop/Desktop/Sources/AppState.swift @@ -140,7 +140,11 @@ class AppState: ObservableObject { /// Trigger the monthly-limit popup. Safe to call repeatedly — SwiftUI's /// `@Published` dedupes identical-value writes automatically. + /// No-op in local daemon mode: there is no cloud subscription to upsell. func triggerUsageLimitPopup(reason: String) { + if DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon { + return + } usageLimitReason = reason showUsageLimitPopup = true } @@ -210,6 +214,7 @@ class AppState: ObservableObject { private var transcriptionService: TranscriptionService? private var localBackgroundSession: LocalBackgroundTranscriptionSession? private var localBackgroundASRTask: Task? + private var localTranscriptionCapabilityProbeTask: Task? private(set) var localBackgroundState: LocalBackgroundSessionState? private var localBackgroundSampleCursor: Int64 = 0 @Published private(set) var transcriptionComparisonSnapshot = @@ -1288,13 +1293,22 @@ class AppState: ObservableObject { // Use provided source or fall back to current setting let effectiveSource = source ?? audioSource + let providerSelection = AssistantSettings.shared.transcriptionProviderSelection let backgroundRouting = BackgroundTranscriptionRoutingGuard().decide( - selection: AssistantSettings.shared.transcriptionProviderSelection, + selection: providerSelection, capabilities: LocalTranscriptionCapabilityDetector( availableEngines: { LocalASRHelperLocator.detectedEngines() } ).detect() ) + if shouldRefreshLocalTranscriptionCapabilities( + selection: providerSelection, + routing: backgroundRouting + ) { + refreshLocalTranscriptionCapabilitiesThenStart(source: effectiveSource) + return + } + // Paywall hard-stop applies only to the cloud listen path. Local background // MLX/faster-whisper capture never opens `/v4/listen` and should keep // working for users who selected a local provider. 
@@ -1462,6 +1476,46 @@ class AppState: ObservableObject { } } + private func shouldRefreshLocalTranscriptionCapabilities( + selection: TranscriptionProviderSelection, + routing: BackgroundTranscriptionRoutingDecision + ) -> Bool { + guard selection.mode != .cloud else { return false } + guard routing.localPlan == nil else { return false } + return LocalASRAddonManager.status().isInstalled + } + + private func refreshLocalTranscriptionCapabilitiesThenStart(source: AudioSource) { + guard localTranscriptionCapabilityProbeTask == nil else { + log("Transcription: Waiting for local transcription capability probe") + return + } + + log("Transcription: Refreshing local transcription capabilities before start") + localTranscriptionCapabilityProbeTask = Task { [weak self] in + let engines = await LocalASRHelperLocator.refreshDetectedEngines() + guard !Task.isCancelled else { return } + + await MainActor.run { + guard let self else { return } + self.localTranscriptionCapabilityProbeTask = nil + guard !Task.isCancelled else { return } + let selection = AssistantSettings.shared.transcriptionProviderSelection + let refreshedRouting = BackgroundTranscriptionRoutingGuard().decide( + selection: selection, + capabilities: LocalTranscriptionCapabilityDetector(availableEngines: { engines }).detect() + ) + + if case .unavailable(let reason) = refreshedRouting.route { + self.openLocalTranscriptionRepair(message: reason) + return + } + + self.startTranscription(source: source) + } + } + } + /// Start local background transcription without creating a backend `/v4/listen` session. 
private func startLocalBackgroundTranscription(source: AudioSource, plan: LocalTranscriptionPlan) { @@ -1691,21 +1745,21 @@ class AppState: ObservableObject { private func drainLocalBackgroundASRQueue() { guard localBackgroundASRTask == nil, let session = localBackgroundSession else { return } - localBackgroundASRTask = Task { [weak self, session] in + localBackgroundASRTask = Task { @MainActor [weak self, session] in + var didFail = false do { while let result = try await session.transcribeNext() { - await MainActor.run { - self?.applyLocalBackgroundSegments(result.remappedSegments) - } + self?.applyLocalBackgroundSegments(result.remappedSegments) } } catch { - await MainActor.run { - self?.localBackgroundState = .failed - logError("Transcription: Local background ASR failed", error: error) - } + didFail = true + self?.localBackgroundState = .failed + logError("Transcription: Local background ASR failed", error: error) } - await MainActor.run { - self?.localBackgroundASRTask = nil + guard let self else { return } + self.localBackgroundASRTask = nil + if !didFail && self.localBackgroundSession === session && session.pendingChunkCount > 0 { + self.drainLocalBackgroundASRQueue() } } } @@ -1888,6 +1942,9 @@ class AppState: ObservableObject { /// triggers conversation processing on the backend side. We also call force-process to ensure /// the conversation is finalized, preventing the retry service from creating duplicates. 
func stopTranscription() { + localTranscriptionCapabilityProbeTask?.cancel() + localTranscriptionCapabilityProbeTask = nil + if localBackgroundSession != nil { stopLocalBackgroundTranscription() return diff --git a/desktop/Desktop/Sources/CodexProxyService.swift b/desktop/Desktop/Sources/CodexProxyService.swift index 220ad204d1e..c119fc343d1 100644 --- a/desktop/Desktop/Sources/CodexProxyService.swift +++ b/desktop/Desktop/Sources/CodexProxyService.swift @@ -21,6 +21,34 @@ enum CodexProxyEndpoints { } } +private final class LockedCodexProxyStderrBuffer: @unchecked Sendable { + private let lock = NSLock() + private let limit: Int + private var data = Data() + + init(limit: Int = 16 * 1024) { + self.limit = limit + } + + func append(_ chunk: Data) { + guard !chunk.isEmpty else { return } + lock.lock() + data.append(chunk) + if data.count > limit { + data.removeFirst(data.count - limit) + } + lock.unlock() + } + + func snapshot() -> String? { + lock.lock() + let snapshot = data + lock.unlock() + return String(data: snapshot, encoding: .utf8)? + .trimmingCharacters(in: .whitespacesAndNewlines) + } +} + /// Manages the loopback Codex OpenAI-compatible proxy (`desktop/codex-proxy`). @MainActor final class CodexProxyService: ObservableObject { @@ -42,11 +70,28 @@ final class CodexProxyService: ObservableObject { private var process: Process? private var healthTask: Task? + private var ensureTask: Task? + private var stderrPipe: Pipe? private init() {} /// Start proxy when ChatGPT tier is active. Idempotent. 
func ensureRunning() async { + if let ensureTask { + await ensureTask.value + return + } + + let task = Task { @MainActor [weak self] in + guard let self else { return } + await self.ensureRunningOnce() + } + ensureTask = task + await task.value + ensureTask = nil + } + + private func ensureRunningOnce() async { guard CodexAuthService.isActive else { await stop() return @@ -80,14 +125,20 @@ final class CodexProxyService: ObservableObject { env["OMI_CODEX_PROXY_PORT"] = String(Self.port) proc.environment = env let stderrPipe = Pipe() + let stderrBuffer = LockedCodexProxyStderrBuffer() + stderrPipe.fileHandleForReading.readabilityHandler = { handle in + stderrBuffer.append(handle.availableData) + } proc.standardError = stderrPipe proc.standardOutput = FileHandle.nullDevice do { try proc.run() process = proc + self.stderrPipe = stderrPipe for _ in 0..<50 { try? await Task.sleep(nanoseconds: 100_000_000) + guard !Task.isCancelled else { return } if await healthCheck() { isRunning = true lastError = nil @@ -96,9 +147,7 @@ final class CodexProxyService: ObservableObject { return } } - let stderrData = stderrPipe.fileHandleForReading.readDataToEndOfFile() - let stderrHint = String(data: stderrData, encoding: .utf8)? - .trimmingCharacters(in: .whitespacesAndNewlines) + let stderrHint = stderrBuffer.snapshot() let detail = (stderrHint?.isEmpty == false) ? stderrHint! 
@@ -107,6 +156,7 @@ final class CodexProxyService: ObservableObject { logError("CodexProxyService: failed to start — \(detail)") await stop() } catch { + stderrPipe.fileHandleForReading.readabilityHandler = nil lastError = error.localizedDescription await stop() } @@ -115,6 +165,10 @@ final class CodexProxyService: ObservableObject { func stop() async { healthTask?.cancel() healthTask = nil + ensureTask?.cancel() + ensureTask = nil + stderrPipe?.fileHandleForReading.readabilityHandler = nil + stderrPipe = nil if let process, process.isRunning { process.terminate() } @@ -127,6 +181,7 @@ final class CodexProxyService: ObservableObject { healthTask = Task { while !Task.isCancelled { try? await Task.sleep(nanoseconds: 15_000_000_000) + guard !Task.isCancelled else { return } guard CodexAuthService.isActive else { await stop() return diff --git a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift index 3c985780b22..14d33f55ae9 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/PushToTalkManager.swift @@ -3,6 +3,31 @@ import Cocoa import Combine import CoreAudio +final class LockedPTTAudioBuffer: @unchecked Sendable { + private let lock = NSLock() + private var data = Data() + + func append(_ chunk: Data) { + lock.lock() + data.append(chunk) + lock.unlock() + } + + func takeAll() -> Data { + lock.lock() + let snapshot = data + data = Data() + lock.unlock() + return snapshot + } + + func clear() { + lock.lock() + data = Data() + lock.unlock() + } +} + /// Push-to-talk manager for voice input via the Option (⌥) key. 
/// /// State machine: @@ -51,8 +76,7 @@ class PushToTalkManager: ObservableObject { private var shortcutCaptureSuspended = false // Batch mode: accumulate raw audio for post-recording transcription - private var batchAudioBuffer = Data() - private let batchAudioLock = NSLock() + private let batchAudioBuffer = LockedPTTAudioBuffer() // Live mode: timeout for waiting on final transcript after CloseStream private var liveFinalizationTimeout: DispatchWorkItem? @@ -312,9 +336,7 @@ class PushToTalkManager: ObservableObject { transcriptSegments = [] lastInterimText = "" currentContextSnapshot = nil - batchAudioLock.lock() - batchAudioBuffer = Data() - batchAudioLock.unlock() + batchAudioBuffer.clear() currentTranscriptionProvider = .cloud isCurrentSessionFollowUp = false updateBarState() @@ -356,10 +378,7 @@ class PushToTalkManager: ObservableObject { if isBatchMode { // Batch mode: send accumulated audio to pre-recorded API log("PushToTalkManager: finalizing (batch) — mic stopped, transcribing recorded audio") - batchAudioLock.lock() - let audioData = batchAudioBuffer - batchAudioBuffer = Data() - batchAudioLock.unlock() + let audioData = batchAudioBuffer.takeAll() // Stop streaming service (was not used in batch mode, but clean up) stopAudioTranscription() @@ -547,15 +566,15 @@ class PushToTalkManager: ObservableObject { if isBatchMode { // Batch mode: just capture audio into buffer, no streaming connection - batchAudioLock.lock() - batchAudioBuffer = Data() - batchAudioLock.unlock() - startMicCapture(batchMode: true) + batchAudioBuffer.clear() + startMicCapture( + batchMode: true, + audioHandler: { [batchAudioBuffer] audioData in + batchAudioBuffer.append(audioData) + } + ) log("PushToTalkManager: started audio capture (batch mode)") } else { - // Live mode: start mic capture and stream to Deepgram - startMicCapture() - do { let language = AssistantSettings.shared.effectiveTranscriptionLanguage let service = try TranscriptionService( @@ -584,6 +603,12 @@ class 
PushToTalkManager: ObservableObject { } } ) + startMicCapture( + batchMode: false, + audioHandler: { [weak service] audioData in + service?.sendAudio(audioData) + } + ) } catch { logError("PushToTalkManager: failed to create TranscriptionService", error: error) stopListening() @@ -591,7 +616,11 @@ class PushToTalkManager: ObservableObject { } } - private func startMicCapture(batchMode: Bool = false, overrideDeviceID: AudioDeviceID? = nil) { + private func startMicCapture( + batchMode: Bool = false, + overrideDeviceID: AudioDeviceID? = nil, + audioHandler: @escaping @Sendable (Data) -> Void + ) { if audioCaptureService == nil { if let override = overrideDeviceID { audioCaptureService = AudioCaptureService(overrideDeviceID: override) @@ -613,17 +642,8 @@ class PushToTalkManager: ObservableObject { guard let self else { return } do { try await capture.startCapture( - onAudioChunk: { [weak self] audioData in - guard let self else { return } - if batchMode { - // Batch mode: accumulate audio in buffer - self.batchAudioLock.lock() - self.batchAudioBuffer.append(audioData) - self.batchAudioLock.unlock() - } else { - // Live mode: stream to Deepgram - self.transcriptionService?.sendAudio(audioData) - } + onAudioChunk: { audioData in + audioHandler(audioData) }, onAudioLevel: { _ in } ) @@ -650,7 +670,23 @@ class PushToTalkManager: ObservableObject { "PushToTalkManager: silent-mic fallback — switching to built-in mic (deviceID=\(builtInID))") audioCaptureService?.stopCapture() audioCaptureService = nil - startMicCapture(batchMode: batchMode, overrideDeviceID: builtInID) + if batchMode { + startMicCapture( + batchMode: true, + overrideDeviceID: builtInID, + audioHandler: { [batchAudioBuffer] audioData in + batchAudioBuffer.append(audioData) + } + ) + } else if let service = transcriptionService { + startMicCapture( + batchMode: false, + overrideDeviceID: builtInID, + audioHandler: { [weak service] audioData in + service?.sendAudio(audioData) + } + ) + } } private func 
stopAudioTranscription() { diff --git a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift index efc347c3123..f050fbf6576 100644 --- a/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift +++ b/desktop/Desktop/Sources/LocalTranscription/LocalASRRuntime.swift @@ -109,38 +109,34 @@ struct LocalASRHelperClient { input.fileHandleForWriting.write(requestData) try? input.fileHandleForWriting.close() - return try await withTimeout(seconds: timeoutSeconds) { - process.waitUntilExit() - let outputData = output.fileHandleForReading.readDataToEndOfFile() - if process.terminationStatus != 0 { - let errorText = - String(data: errors.fileHandleForReading.readDataToEndOfFile(), encoding: .utf8) ?? "" - throw TranscriptionService.TranscriptionError.webSocketError( - "Local ASR helper exited with status \(process.terminationStatus): \(errorText)" - ) + let didExit = await waitForExit(process, timeoutSeconds: timeoutSeconds) + guard didExit else { + if process.isRunning { + process.terminate() + _ = await waitForExit(process, timeoutSeconds: 2) } - return try JSONDecoder.localASR.decode(LocalASRTranscriptionResponse.self, from: outputData) + throw TranscriptionService.TranscriptionError.webSocketError( + "Local ASR helper timed out after \(Int(timeoutSeconds))s" + ) + } + + let outputData = output.fileHandleForReading.readDataToEndOfFile() + if process.terminationStatus != 0 { + let errorText = + String(data: errors.fileHandleForReading.readDataToEndOfFile(), encoding: .utf8) ?? 
"" + throw TranscriptionService.TranscriptionError.webSocketError( + "Local ASR helper exited with status \(process.terminationStatus): \(errorText)" + ) } + return try JSONDecoder.localASR.decode(LocalASRTranscriptionResponse.self, from: outputData) } - private func withTimeout( - seconds: TimeInterval, - operation: @escaping @Sendable () throws -> T - ) async throws -> T { - try await withThrowingTaskGroup(of: T.self) { group in - group.addTask { - try operation() - } - group.addTask { - try await Task.sleep(nanoseconds: UInt64(seconds * 1_000_000_000)) - throw CancellationError() - } - guard let result = try await group.next() else { - throw CancellationError() - } - group.cancelAll() - return result + private func waitForExit(_ process: Process, timeoutSeconds: TimeInterval) async -> Bool { + let deadline = Date().addingTimeInterval(timeoutSeconds) + while process.isRunning && Date() < deadline { + try? await Task.sleep(nanoseconds: 100_000_000) } + return !process.isRunning } } @@ -393,6 +389,8 @@ struct LocalBackgroundChunkerConfiguration: Equatable { var overlapDuration: TimeInterval = 1 var silenceWindowDuration: TimeInterval = 0.35 var silenceAmplitudeThreshold: Int16 = 256 + var speechPeakAmplitudeThreshold: Int16 = 512 + var speechRMSAmplitudeThreshold: Double = 64 var maxPendingChunks: Int = 4 var maxChunkSamples: Int { max(1, Int(maxChunkDuration * Double(sampleRate))) } @@ -652,6 +650,24 @@ final class LocalBackgroundTranscriptionSession { func transcribeNext() async throws -> LocalBackgroundASRRawChunkResult? 
{ guard !pendingChunks.isEmpty else { return nil } let chunk = pendingChunks.removeFirst() + guard hasSpeechEnergy(chunk.audioData) else { + let response = LocalASRTranscriptionResponse( + requestId: makeRequestId(), + engine: plan.engine, + model: plan.model, + language: language, + segments: [], + fixture: false + ) + let rawResult = LocalBackgroundASRRawChunkResult( + chunk: chunk, + response: response, + remappedSegments: [], + latencySeconds: 0 + ) + rawChunkResults.append(rawResult) + return rawResult + } let requestId = makeRequestId() let audioURL = temporaryDirectory.appendingPathComponent("\(requestId).pcm") try chunk.audioData.write(to: audioURL, options: .atomic) @@ -736,6 +752,26 @@ final class LocalBackgroundTranscriptionSession { return normalized } } + + private func hasSpeechEnergy(_ audioData: Data) -> Bool { + var peak = 0 + var sumSquares = 0.0 + var sampleCount = 0 + + audioData.withUnsafeBytes { rawBuffer in + for sample in rawBuffer.bindMemory(to: Int16.self) { + let magnitude = sample == Int16.min ? 
Int(Int16.max) : Int(abs(sample)) + peak = max(peak, magnitude) + sumSquares += Double(magnitude * magnitude) + sampleCount += 1 + } + } + + guard sampleCount > 0 else { return false } + let rms = sqrt(sumSquares / Double(sampleCount)) + return peak >= Int(configuration.speechPeakAmplitudeThreshold) + && rms >= configuration.speechRMSAmplitudeThreshold + } } struct PTTBatchTranscriptionResult: Equatable { diff --git a/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift b/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift index 25082ff9527..e4561ee9b47 100644 --- a/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift +++ b/desktop/Desktop/Sources/LocalTranscription/TranscriptionComparisonHarness.swift @@ -13,11 +13,26 @@ struct TranscriptionComparisonProviderSnapshot: Equatable { } } +struct TranscriptionComparisonTimeBucketSnapshot: Equatable, Identifiable { + var id: String { "\(Int(startTime))-\(Int(endTime))" } + var startTime: Double + var endTime: Double + var whisperText: String + var deepgramText: String + var whisperSegmentCount: Int + var deepgramSegmentCount: Int + + var hasContent: Bool { + !whisperText.isEmpty || !deepgramText.isEmpty + } +} + struct TranscriptionComparisonHarnessSnapshot: Equatable { var isRunning: Bool var startedAt: Date? var whisper: TranscriptionComparisonProviderSnapshot var deepgram: TranscriptionComparisonProviderSnapshot + var timeBuckets: [TranscriptionComparisonTimeBucketSnapshot] var wordDifferenceRate: Double? var characterDifferenceRate: Double? 
@@ -26,6 +41,7 @@ struct TranscriptionComparisonHarnessSnapshot: Equatable { startedAt: nil, whisper: .empty(title: "Local Whisper"), deepgram: .empty(title: "Local Deepgram"), + timeBuckets: [], wordDifferenceRate: nil, characterDifferenceRate: nil ) @@ -34,6 +50,10 @@ struct TranscriptionComparisonHarnessSnapshot: Equatable { @MainActor final class TranscriptionComparisonHarness { static let enabledDefaultsKey = "dev_transcription_comparison_harness_enabled" + private static let transcriptPreviewLimit = 12_000 + private static let bucketDuration: Double = 30 + private static let maxBuckets = 8 + private static let bucketTextLimit = 900 static var isEnabled: Bool { #if DEBUG @@ -75,6 +95,7 @@ final class TranscriptionComparisonHarness { title: "Local Deepgram", status: deepgramAPIKey == nil ? "Missing Deepgram API key" : "Connecting" ), + timeBuckets: [], wordDifferenceRate: nil, characterDifferenceRate: nil ) @@ -90,12 +111,14 @@ final class TranscriptionComparisonHarness { }, onStatus: { [weak self] status in Task { @MainActor in + log("TranscriptionComparison: Deepgram \(status)") self?.deepgramStatus = status self?.publish() } }, onError: { [weak self] error in Task { @MainActor in + logError("TranscriptionComparison: Deepgram error", error: error) self?.deepgramStatus = "Failed" self?.deepgramError = error.localizedDescription self?.publish() @@ -123,6 +146,12 @@ final class TranscriptionComparisonHarness { publish() } + #if DEBUG + func appendDeepgramSegmentsForTesting(_ segments: [NormalizedTranscriptSegment]) { + appendDeepgramSegments(segments) + } + #endif + func finish() { deepgramSession?.finish() whisperStatus = whisperSegments.isEmpty ? "No transcript" : "Finalized" @@ -168,23 +197,28 @@ final class TranscriptionComparisonHarness { private func publish(isRunning: Bool? 
= nil) { let whisperText = joinedTranscript(whisperSegments) let deepgramText = joinedTranscript(deepgramSegments) + let whisperPreviewText = transcriptPreview(whisperText) + let deepgramPreviewText = transcriptPreview(deepgramText) snapshot = TranscriptionComparisonHarnessSnapshot( isRunning: isRunning ?? snapshot.isRunning, startedAt: snapshot.startedAt, whisper: providerSnapshot( title: "Local Whisper", status: whisperStatus, - transcript: whisperText, + transcript: whisperPreviewText, segments: whisperSegments, + wordCount: TranscriptComparison.normalizedWords(whisperText).count, error: whisperError ), deepgram: providerSnapshot( title: "Local Deepgram", status: deepgramStatus, - transcript: deepgramText, + transcript: deepgramPreviewText, segments: deepgramSegments, + wordCount: TranscriptComparison.normalizedWords(deepgramText).count, error: deepgramError ), + timeBuckets: timeBuckets(whisper: whisperSegments, deepgram: deepgramSegments), wordDifferenceRate: comparisonRate( reference: whisperText, hypothesis: deepgramText, @@ -204,6 +238,7 @@ final class TranscriptionComparisonHarness { status: String, transcript: String, segments: [NormalizedTranscriptSegment], + wordCount: Int, error: String? 
) -> TranscriptionComparisonProviderSnapshot { TranscriptionComparisonProviderSnapshot( @@ -211,7 +246,7 @@ final class TranscriptionComparisonHarness { status: status, transcript: transcript, segmentCount: segments.count, - wordCount: TranscriptComparison.normalizedWords(transcript).count, + wordCount: wordCount, error: error ) } @@ -224,6 +259,64 @@ final class TranscriptionComparisonHarness { .trimmingCharacters(in: .whitespacesAndNewlines) } + private func transcriptPreview(_ joined: String) -> String { + guard joined.count > Self.transcriptPreviewLimit else { return joined } + return String(joined.suffix(Self.transcriptPreviewLimit)) + } + + private func timeBuckets( + whisper: [NormalizedTranscriptSegment], + deepgram: [NormalizedTranscriptSegment] + ) -> [TranscriptionComparisonTimeBucketSnapshot] { + let allSegments = whisper + deepgram + guard let maxEnd = allSegments.map(\.end).max(), maxEnd > 0 else { return [] } + + let lastBucketIndex = max(0, Int(maxEnd / Self.bucketDuration)) + let firstBucketIndex = max(0, lastBucketIndex - Self.maxBuckets + 1) + + return (firstBucketIndex...lastBucketIndex).compactMap { index in + let start = Double(index) * Self.bucketDuration + let end = start + Self.bucketDuration + let whisperInBucket = segments(whisper, overlappingStart: start, end: end) + let deepgramInBucket = segments(deepgram, overlappingStart: start, end: end) + let bucket = TranscriptionComparisonTimeBucketSnapshot( + startTime: start, + endTime: end, + whisperText: bucketText(whisperInBucket), + deepgramText: bucketText(deepgramInBucket), + whisperSegmentCount: whisperInBucket.count, + deepgramSegmentCount: deepgramInBucket.count + ) + return bucket.hasContent ? 
bucket : nil + } + } + + private func segments( + _ segments: [NormalizedTranscriptSegment], + overlappingStart start: Double, + end: Double + ) -> [NormalizedTranscriptSegment] { + segments + .filter { $0.end > start && $0.start < end } + .sorted { lhs, rhs in + if lhs.start == rhs.start { + return lhs.end < rhs.end + } + return lhs.start < rhs.start + } + } + + private func bucketText(_ segments: [NormalizedTranscriptSegment]) -> String { + let joined = + segments + .map(\.text) + .joined(separator: " ") + .replacingOccurrences(of: #"\s+"#, with: " ", options: .regularExpression) + .trimmingCharacters(in: .whitespacesAndNewlines) + guard joined.count > Self.bucketTextLimit else { return joined } + return String(joined.prefix(Self.bucketTextLimit)) + "..." + } + private func comparisonRate( reference: String, hypothesis: String, @@ -272,9 +365,16 @@ final class DeepgramBackgroundTranscriptionSession { private let onError: (Error) -> Void private var webSocketTask: URLSessionWebSocketTask? private var urlSession: URLSession? + private let queue = DispatchQueue(label: "com.omi.transcription.deepgram-comparison") private var isConnected = false - private var pendingAudio = Data() - private let pendingAudioLimit = 16_000 * 2 * 5 + private var pendingAudioChunks: [Data] = [] + private var pendingAudioBytes = 0 + private var isSendingAudio = false + private let pendingAudioLimit = 16_000 * 2 * 8 + private var keepAliveTimer: DispatchSourceTimer? 
+ private var shouldReconnect = false + private var isFinishing = false + private var reconnectAttempt = 0 init( language: String, @@ -291,7 +391,16 @@ final class DeepgramBackgroundTranscriptionSession { } func start() { - guard webSocketTask == nil else { return } + queue.async { + guard self.webSocketTask == nil else { return } + self.shouldReconnect = true + self.isFinishing = false + self.reconnectAttempt = 0 + self.openSocket() + } + } + + private func openSocket() { guard var components = URLComponents(string: "wss://api.deepgram.com/v1/listen") else { return } @@ -324,54 +433,86 @@ final class DeepgramBackgroundTranscriptionSession { DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) { [weak self] in guard let self else { return } - guard self.webSocketTask?.state == .running else { - self.onStatus("Failed to connect") - return + self.queue.async { + guard self.webSocketTask?.state == .running else { + self.handleConnectionFailure( + NSError( + domain: "DeepgramBackgroundTranscriptionSession", + code: 1, + userInfo: [NSLocalizedDescriptionKey: "Failed to connect to Deepgram"] + ) + ) + return + } + self.isConnected = true + self.reconnectAttempt = 0 + self.onStatus("Connected") + self.startKeepAliveTimer() + self.drainAudioQueue() } - self.isConnected = true - self.onStatus("Connected") - self.flushPendingAudio() } } func appendAudio(_ data: Data) { - guard isConnected else { - pendingAudio.append(data) - if pendingAudio.count > pendingAudioLimit { - pendingAudio.removeFirst(pendingAudio.count - pendingAudioLimit) - } - return + queue.async { + self.enqueueAudio(data) + self.drainAudioQueue() } - sendAudio(data) } func finish() { - guard isConnected else { return } - sendString("{\"type\":\"CloseStream\"}") - onStatus("Finalizing") + queue.async { + self.isFinishing = true + self.shouldReconnect = false + guard self.isConnected else { + self.stopKeepAliveTimer() + self.onStatus("Finalized") + return + } + self.sendString("{\"type\":\"CloseStream\"}") + 
self.onStatus("Finalizing") + } } func stop() { - isConnected = false - webSocketTask?.cancel(with: .normalClosure, reason: nil) - webSocketTask = nil - urlSession?.invalidateAndCancel() - urlSession = nil - pendingAudio.removeAll() - onStatus("Stopped") + queue.async { + self.shouldReconnect = false + self.isFinishing = false + self.isConnected = false + self.isSendingAudio = false + self.stopKeepAliveTimer() + self.webSocketTask?.cancel(with: .normalClosure, reason: nil) + self.webSocketTask = nil + self.urlSession?.invalidateAndCancel() + self.urlSession = nil + self.pendingAudioChunks.removeAll() + self.pendingAudioBytes = 0 + self.onStatus("Stopped") + } } - private func flushPendingAudio() { - guard !pendingAudio.isEmpty else { return } - let audio = pendingAudio - pendingAudio.removeAll() - sendAudio(audio) + private func enqueueAudio(_ data: Data) { + pendingAudioChunks.append(data) + pendingAudioBytes += data.count + while pendingAudioBytes > pendingAudioLimit, !pendingAudioChunks.isEmpty { + pendingAudioBytes -= pendingAudioChunks.removeFirst().count + } } - private func sendAudio(_ data: Data) { + private func drainAudioQueue() { + guard isConnected, !isSendingAudio, !pendingAudioChunks.isEmpty else { return } + let data = pendingAudioChunks.removeFirst() + pendingAudioBytes -= data.count + isSendingAudio = true webSocketTask?.send(.data(data)) { [weak self] error in - if let error { - self?.onError(error) + guard let self else { return } + self.queue.async { + self.isSendingAudio = false + if let error { + self.handleConnectionFailure(error) + return + } + self.drainAudioQueue() } } } @@ -379,7 +520,9 @@ final class DeepgramBackgroundTranscriptionSession { private func sendString(_ value: String) { webSocketTask?.send(.string(value)) { [weak self] error in if let error { - self?.onError(error) + self?.queue.async { + self?.handleConnectionFailure(error) + } } } } @@ -392,13 +535,55 @@ final class DeepgramBackgroundTranscriptionSession { 
self.handleMessage(message) self.receiveMessage() case .failure(let error): - guard self.isConnected else { return } - self.isConnected = false - self.onError(error) + self.queue.async { + self.handleConnectionFailure(error) + } } } } + private func startKeepAliveTimer() { + stopKeepAliveTimer() + let timer = DispatchSource.makeTimerSource(queue: queue) + timer.schedule(deadline: .now() + 5, repeating: 5) + timer.setEventHandler { [weak self] in + guard let self, self.isConnected else { return } + self.sendString("{\"type\":\"KeepAlive\"}") + } + keepAliveTimer = timer + timer.resume() + } + + private func stopKeepAliveTimer() { + keepAliveTimer?.cancel() + keepAliveTimer = nil + } + + private func handleConnectionFailure(_ error: Error) { + guard shouldReconnect || isConnected else { return } + isConnected = false + isSendingAudio = false + stopKeepAliveTimer() + webSocketTask?.cancel(with: .goingAway, reason: nil) + webSocketTask = nil + urlSession?.invalidateAndCancel() + urlSession = nil + + guard shouldReconnect, !isFinishing else { + onError(error) + return + } + + reconnectAttempt += 1 + let delay = min(10.0, pow(2.0, Double(min(reconnectAttempt, 3)))) + onStatus("Reconnecting in \(Int(delay))s") + queue.asyncAfter(deadline: .now() + delay) { [weak self] in + guard let self, self.shouldReconnect, self.webSocketTask == nil else { return } + self.onStatus("Reconnecting") + self.openSocket() + } + } + private func handleMessage(_ message: URLSessionWebSocketTask.Message) { let data: Data? 
switch message { diff --git a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift index 5358ca085b7..782fddecf6f 100644 --- a/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift +++ b/desktop/Desktop/Sources/MainWindow/DesktopHomeView.swift @@ -25,6 +25,7 @@ struct DesktopHomeView: View { @AppStorage("currentTierLevel") private var currentTierLevel = 0 @AppStorage("onboardingStep") private var onboardingStep = 0 @AppStorage("onboardingJustCompleted") private var onboardingJustCompleted = false + @AppStorage("hasSeenLocalWelcome") private var hasSeenLocalWelcome = false // Settings sidebar state @State private var selectedSettingsSection: SettingsContentView.SettingsSection = .general @@ -74,13 +75,38 @@ struct DesktopHomeView: View { } } else if !appState.hasCompletedOnboarding { // State 2: Signed in but onboarding not complete - if shouldSkipOnboarding() - || DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon - { + let isLocalDaemon = + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + + if shouldSkipOnboarding() { Color.clear.onAppear { log("DesktopHomeView: --skip-onboarding flag detected, skipping onboarding") appState.hasCompletedOnboarding = true } + } else if isLocalDaemon { + if hasSeenLocalWelcome { + Color.clear.onAppear { + log("DesktopHomeView: Local daemon mode, welcome seen — skipping onboarding") + appState.hasCompletedOnboarding = true + } + } else { + LocalWelcomeView( + onGetStarted: { + log("DesktopHomeView: Local welcome completed — navigating to Local Setup") + hasSeenLocalWelcome = true + appState.hasCompletedOnboarding = true + selectedSettingsSection = .planUsage + DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) { + withAnimation(.easeInOut(duration: 0.2)) { + selectedIndex = SidebarNavItem.settings.rawValue + } + } + } + ) + .onAppear { + log("DesktopHomeView: Showing LocalWelcomeView (first launch in local mode)") + } + } } 
else { OnboardingView( appState: appState, chatProvider: viewModelContainer.chatProvider, onComplete: nil @@ -891,6 +917,101 @@ private struct ConversationsPageHost: View { } } +// MARK: - Local Welcome + +/// One-screen orientation shown on the first launch in local-daemon mode. +struct LocalWelcomeView: View { + let onGetStarted: () -> Void + + var body: some View { + VStack(spacing: 0) { + Spacer() + + VStack(spacing: 18) { + Image(systemName: "desktopcomputer") + .scaledFont(size: 44) + .foregroundColor(OmiColors.purplePrimary) + + Text("Omi runs locally on this Mac") + .scaledFont(size: 26, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + .multilineTextAlignment(.center) + + Text("A quick orientation before you get started.") + .scaledFont(size: 14) + .foregroundColor(OmiColors.textTertiary) + .multilineTextAlignment(.center) + } + .padding(.bottom, 36) + + VStack(alignment: .leading, spacing: 18) { + LocalWelcomeBullet( + icon: "lock.shield.fill", + title: "Your data stays on this Mac", + detail: "Transcripts, memories, and AI calls run locally by default." + ) + LocalWelcomeBullet( + icon: "key.fill", + title: "Bring your own AI keys", + detail: "Point Omi at any OpenAI-compatible endpoint — Ollama, LM Studio, or a hosted API." + ) + LocalWelcomeBullet( + icon: "waveform", + title: "Background transcription is on-device", + detail: "Whisper runs in the background; no audio is sent to the cloud." 
+ ) + } + .padding(.horizontal, 32) + .frame(maxWidth: 520) + + Spacer() + + Button(action: onGetStarted) { + Text("Get started") + .scaledFont(size: 16, weight: .semibold) + .foregroundColor(.white) + .frame(width: 240, height: 48) + .background( + RoundedRectangle(cornerRadius: 10) + .fill(OmiColors.purplePrimary) + ) + } + .buttonStyle(.plain) + .padding(.bottom, 60) + } + .frame(maxWidth: .infinity, maxHeight: .infinity) + .background(OmiColors.backgroundPrimary) + } +} + +private struct LocalWelcomeBullet: View { + let icon: String + let title: String + let detail: String + + var body: some View { + HStack(alignment: .top, spacing: 14) { + Image(systemName: icon) + .scaledFont(size: 18) + .foregroundColor(OmiColors.purplePrimary) + .frame(width: 24) + + VStack(alignment: .leading, spacing: 4) { + Text(title) + .scaledFont(size: 15, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + + Text(detail) + .scaledFont(size: 13) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + } + } +} + #Preview { DesktopHomeView() } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/HybridProviderSetupSheet.swift b/desktop/Desktop/Sources/MainWindow/Pages/HybridProviderSetupSheet.swift new file mode 100644 index 00000000000..37f7f2ac8e7 --- /dev/null +++ b/desktop/Desktop/Sources/MainWindow/Pages/HybridProviderSetupSheet.swift @@ -0,0 +1,216 @@ +import SwiftUI + +/// Sheet that hosts the full BYOK provider editor (base URL, API key, per-slot models). +/// Lifted out of the Plan & Usage page to keep that surface as a calm status view. +struct HybridProviderSetupSheet: View { + @Binding var baseURL: String + @Binding var apiKey: String + @Binding var chatModel: String + @Binding var postTranscriptModel: String + @Binding var proactiveModel: String + @Binding var visionModel: String + + let status: String? 
+ let isSaving: Bool + let isTesting: Bool + let applyDefaults: () -> Void + let save: () -> Void + let test: (String) -> Void + let dismiss: () -> Void + + var body: some View { + VStack(spacing: 0) { + header + + Divider().overlay(OmiColors.backgroundQuaternary) + + ScrollView { + VStack(alignment: .leading, spacing: 18) { + providerAccountCard + + slotCard( + title: "Chat", + subtitle: "Powers Ask Omi chat replies.", + model: $chatModel, + slot: HybridProviderPolicy.chatSlot + ) + + slotCard( + title: "Post-transcript processing", + subtitle: "Titles, summaries, memories, and action items.", + model: $postTranscriptModel, + slot: HybridProviderPolicy.postTranscriptSlot + ) + + slotCard( + title: "Proactive assistants", + subtitle: "Local assistant jobs. Defaults to \(HybridProviderReadiness.defaultSmallModel()).", + model: $proactiveModel, + slot: HybridProviderPolicy.proactiveSlot + ) + + slotCard( + title: "Vision", + subtitle: "Optional. Leave blank to use local OCR text.", + model: $visionModel, + slot: HybridProviderPolicy.visionSlot, + optional: true + ) + + memorySearchNote + + if let status { + Text(status) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .textSelection(.enabled) + .frame(maxWidth: .infinity, alignment: .leading) + .padding(12) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.backgroundTertiary.opacity(0.5)) + ) + } + } + .padding(24) + } + + Divider().overlay(OmiColors.backgroundQuaternary) + + footer + } + .frame(width: 560, height: 640) + .background(OmiColors.backgroundPrimary) + } + + // MARK: - Sections + + private var header: some View { + HStack(alignment: .center) { + VStack(alignment: .leading, spacing: 4) { + Text("Configure providers") + .scaledFont(size: 18, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text( + "Bring your own AI endpoint, then assign models per task. Keys stay on this Mac." 
+ ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + + Spacer() + + Button(action: dismiss) { + Image(systemName: "xmark") + .scaledFont(size: 13, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + .padding(8) + } + .buttonStyle(.plain) + } + .padding(.horizontal, 24) + .padding(.vertical, 18) + } + + private var providerAccountCard: some View { + VStack(alignment: .leading, spacing: 10) { + Text("Provider account") + .scaledFont(size: 12, weight: .medium) + .foregroundColor(OmiColors.textTertiary) + + TextField("Base URL", text: $baseURL) + .textFieldStyle(.roundedBorder) + + SecureField("API key (optional on loopback)", text: $apiKey) + .textFieldStyle(.roundedBorder) + } + .padding(16) + .background( + RoundedRectangle(cornerRadius: 10) + .fill(OmiColors.backgroundTertiary.opacity(0.5)) + ) + } + + private func slotCard( + title: String, + subtitle: String, + model: Binding, + slot: String, + optional: Bool = false + ) -> some View { + VStack(alignment: .leading, spacing: 8) { + Text(title) + .scaledFont(size: 13, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text(subtitle) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + + TextField("Model", text: model) + .textFieldStyle(.roundedBorder) + + HStack(spacing: 8) { + Spacer() + Button("Test") { + test(slot) + } + .buttonStyle(.bordered) + .disabled( + isTesting + || (optional + && model.wrappedValue.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) + ) + } + } + .padding(16) + .background( + RoundedRectangle(cornerRadius: 10) + .fill(OmiColors.backgroundTertiary.opacity(0.5)) + ) + } + + private var memorySearchNote: some View { + HStack(alignment: .top, spacing: 8) { + Image(systemName: "info.circle.fill") + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .padding(.top, 1) + Text("Memory search uses the on-device local 
wiki / FTS index — no embeddings required.") + .scaledFont(size: 11) + .foregroundColor(OmiColors.textSecondary) + .fixedSize(horizontal: false, vertical: true) + Spacer() + } + .padding(.horizontal, 4) + } + + private var footer: some View { + HStack(spacing: 12) { + Button("Apply local defaults", action: applyDefaults) + .buttonStyle(.bordered) + .disabled(isSaving) + + Spacer() + + Button("Done", action: dismiss) + .buttonStyle(.bordered) + + Button { + save() + } label: { + if isSaving { + ProgressView().controlSize(.small) + } else { + Text("Save") + .scaledFont(size: 13, weight: .semibold) + } + } + .buttonStyle(.borderedProminent) + .disabled(isSaving) + } + .padding(.horizontal, 24) + .padding(.vertical, 14) + } +} diff --git a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift index 9370be839c5..1ad4face180 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/SettingsPage.swift @@ -10,13 +10,22 @@ struct SettingsPage: View { @Binding var highlightedSettingId: String? var chatProvider: ChatProvider? 
= nil + private var sectionDisplayTitle: String { + if selectedSection == .planUsage + && DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + { + return "Local Setup" + } + return selectedSection.rawValue + } + var body: some View { ScrollViewReader { proxy in ScrollView { VStack(spacing: 0) { // Section header HStack { - Text(selectedSection.rawValue) + Text(sectionDisplayTitle) .scaledFont(size: 28, weight: .bold) .foregroundColor(OmiColors.textPrimary) .id(selectedSection) @@ -378,6 +387,7 @@ struct SettingsContentView: View { @State private var showResetOnboardingAlert: Bool = false @State private var showRescanFilesAlert: Bool = false @State private var showDeleteAccountAlert: Bool = false + @State private var showHybridProviderSheet: Bool = false // Gmail Reader states @State private var gmailEmails: [GmailEmail] = [] @@ -1106,25 +1116,22 @@ struct SettingsContentView: View { VStack(spacing: 10) { transcriptionProviderOption( mode: .auto, - title: "Local Background First", - detail: - "Use local Whisper for continuous background transcription when available. If local is unavailable, this mode may use cloud transcription.", + title: "Automatic", + detail: "Use on-device transcription when possible, fall back to cloud.", icon: "sparkle.magnifyingglass" ) transcriptionProviderOption( mode: .local, - title: "Local Background Only", - detail: - "Use on-device Whisper for continuous background transcription. 
If local ASR is unavailable, background transcription will fail instead of using cloud.", + title: "On-device only", + detail: "Never send audio to the cloud.", icon: "desktopcomputer" ) transcriptionProviderOption( mode: .cloud, - title: "Cloud Transcription", - detail: - "Use the existing Omi cloud transcription path for live meetings and continuous background capture.", + title: "Cloud", + detail: "Use the Omi cloud for live meetings and background capture.", icon: "cloud.fill" ) } @@ -1137,9 +1144,9 @@ struct SettingsContentView: View { .padding(.top, 1) Text(unavailableReason) - .scaledFont(size: 12) - .foregroundColor(OmiColors.warning) - .fixedSize(horizontal: false, vertical: true) + .scaledFont(size: 12) + .foregroundColor(OmiColors.warning) + .fixedSize(horizontal: false, vertical: true) } .padding(10) .background( @@ -1619,9 +1626,12 @@ struct SettingsContentView: View { Text("Raw Local Transcription") .scaledFont(size: 15, weight: .semibold) .foregroundColor(OmiColors.textPrimary) - Text("Recent locally persisted sessions and raw segment text") - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) + Text( + "Developer tool. Recent locally persisted sessions and raw segment text — not synced to cloud." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) } Spacer() @@ -1724,9 +1734,27 @@ struct SettingsContentView: View { .foregroundColor(snapshot.isRunning ? 
OmiColors.success : OmiColors.textTertiary) } - HStack(alignment: .top, spacing: 12) { - comparisonProviderColumn(snapshot.whisper) - comparisonProviderColumn(snapshot.deepgram) + if snapshot.timeBuckets.isEmpty { + HStack(alignment: .top, spacing: 12) { + comparisonProviderColumn(snapshot.whisper) + comparisonProviderColumn(snapshot.deepgram) + } + } else { + VStack(alignment: .leading, spacing: 10) { + HStack { + Text("Time-aligned output") + .scaledFont(size: 12, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Spacer() + Text("Latest \(snapshot.timeBuckets.count) windows") + .scaledFont(size: 10) + .foregroundColor(OmiColors.textTertiary) + } + + ForEach(snapshot.timeBuckets) { bucket in + comparisonTimeBucketRow(bucket) + } + } } } else { Text( @@ -1796,10 +1824,64 @@ struct SettingsContentView: View { .frame(maxWidth: .infinity, alignment: .topLeading) } + private func comparisonTimeBucketRow( + _ bucket: TranscriptionComparisonTimeBucketSnapshot + ) -> some View { + VStack(alignment: .leading, spacing: 8) { + HStack(spacing: 8) { + Text("\(formatComparisonTime(bucket.startTime))-\(formatComparisonTime(bucket.endTime))") + .scaledMonospacedFont(size: 11, weight: .semibold) + .foregroundColor(OmiColors.textSecondary) + Spacer() + Text("\(bucket.whisperSegmentCount) W / \(bucket.deepgramSegmentCount) D segments") + .scaledFont(size: 10) + .foregroundColor(OmiColors.textTertiary) + } + + HStack(alignment: .top, spacing: 10) { + comparisonBucketTextBox( + title: "Whisper", + text: bucket.whisperText + ) + comparisonBucketTextBox( + title: "Deepgram", + text: bucket.deepgramText + ) + } + } + .padding(10) + .background(OmiColors.backgroundSecondary.opacity(0.45)) + .clipShape(RoundedRectangle(cornerRadius: 8)) + } + + private func comparisonBucketTextBox(title: String, text: String) -> some View { + VStack(alignment: .leading, spacing: 6) { + Text(title) + .scaledFont(size: 10, weight: .semibold) + .foregroundColor(OmiColors.textTertiary) + + 
Text(text.isEmpty ? "(no transcript)" : text) + .scaledMonospacedFont(size: 11) + .foregroundColor(text.isEmpty ? OmiColors.textTertiary : OmiColors.textSecondary) + .textSelection(.enabled) + .fixedSize(horizontal: false, vertical: true) + .frame(maxWidth: .infinity, alignment: .topLeading) + } + .padding(9) + .background(OmiColors.backgroundSecondary.opacity(0.8)) + .clipShape(RoundedRectangle(cornerRadius: 7)) + .frame(maxWidth: .infinity, alignment: .topLeading) + } + private func formattedComparisonRate(_ value: Double?) -> String { guard let value else { return "n/a" } return "\(Int((value * 100).rounded()))%" } + + private func formatComparisonTime(_ seconds: Double) -> String { + let totalSeconds = max(0, Int(seconds.rounded(.down))) + return String(format: "%d:%02d", totalSeconds / 60, totalSeconds % 60) + } #endif private var localASRAddonPrimaryActionTitle: String { @@ -2410,159 +2492,89 @@ struct SettingsContentView: View { } settingsCard(settingId: "planusage.local.checklist") { - VStack(alignment: .leading, spacing: 12) { - Text("Hybrid provider setup") - .scaledFont(size: 15, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - - Text( - "Bring your own AI endpoint, then assign models to local task slots. Keys are stored in the local daemon SQLite database on this Mac and sent only to URLs you configure—not Omi cloud proxies." - ) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textSecondary) - .fixedSize(horizontal: false, vertical: true) + VStack(alignment: .leading, spacing: 14) { + HStack(alignment: .top) { + VStack(alignment: .leading, spacing: 4) { + Text("Providers") + .scaledFont(size: 15, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) - ForEach(HybridProviderReadiness.rows(from: backendSettings)) { row in - HStack(alignment: .top, spacing: 10) { - Image( - systemName: row.status == .configured || row.status == .optionalFallback - ? "checkmark.circle.fill" : "circle" + Text( + "Bring your own AI endpoint. 
Keys are stored on this Mac and sent only to URLs you configure." ) - .foregroundColor( - row.status == .configured - ? OmiColors.success - : (row.status == .optionalFallback - ? OmiColors.textTertiary : OmiColors.warning)) - VStack(alignment: .leading, spacing: 2) { - Text(row.label) - .scaledFont(size: 13, weight: .medium) - .foregroundColor(OmiColors.textPrimary) - Text(row.detail) - .scaledFont(size: 11) - .foregroundColor(OmiColors.textTertiary) - } - Spacer() + .scaledFont(size: 12) + .foregroundColor(OmiColors.textSecondary) + .fixedSize(horizontal: false, vertical: true) } + + Spacer() + + Button { + showHybridProviderSheet = true + } label: { + Text("Configure…") + .scaledFont(size: 13, weight: .semibold) + } + .buttonStyle(.borderedProminent) } - Button(action: applyLocalHybridProviderDefaults) { - Text("Apply local defaults") - .scaledFont(size: 13, weight: .semibold) + VStack(spacing: 8) { + ForEach(HybridProviderReadiness.rows(from: backendSettings)) { row in + hybridProviderSummaryRow(row) + } } - .buttonStyle(.borderedProminent) - .disabled(isSavingHybridProvider) } } - - localHybridProvidersEditorCard } - } - - private var localHybridProvidersEditorCard: some View { - settingsCard(settingId: "planusage.local.providers") { - VStack(alignment: .leading, spacing: 18) { - VStack(alignment: .leading, spacing: 8) { - Text("Provider account") - .scaledFont(size: 12, weight: .medium) - .foregroundColor(OmiColors.textTertiary) - TextField("Base URL", text: $hybridAiBaseURL) - .textFieldStyle(.roundedBorder) - SecureField("API key (optional on loopback)", text: $hybridAiApiKey) - .textFieldStyle(.roundedBorder) - } - - Divider().background(OmiColors.backgroundQuaternary) - - hybridSlotEditorBlock( - title: "Chat", - subtitle: "Direct Ask Omi chat", - model: $hybridChatModel, - slot: HybridProviderPolicy.chatSlot - ) - - Divider().background(OmiColors.backgroundQuaternary) - - hybridSlotEditorBlock( - title: "Post-transcript processing", - subtitle: "Titles, 
summaries, memories, and action items", - model: $hybridAiModel, - slot: HybridProviderPolicy.postTranscriptSlot - ) - - Divider().background(OmiColors.backgroundQuaternary) - - hybridSlotEditorBlock( - title: "Proactive assistants", - subtitle: - "Local assistant jobs; defaults to \(HybridProviderReadiness.defaultSmallModel())", - model: $hybridEmbedModel, - slot: HybridProviderPolicy.proactiveSlot - ) - - Divider().background(OmiColors.backgroundQuaternary) - - hybridSlotEditorBlock( - title: "Vision, optional", - subtitle: "Leave blank to use local OCR text only", - model: $hybridVisionModel, - slot: HybridProviderPolicy.visionSlot, - optional: true - ) - - Divider().background(OmiColors.backgroundQuaternary) - - VStack(alignment: .leading, spacing: 4) { - Text("Memory search") - .scaledFont(size: 12, weight: .medium) - .foregroundColor(OmiColors.textTertiary) - Text("Local wiki/FTS search uses local_wiki and does not require embeddings.") - .scaledFont(size: 11) - .foregroundColor(OmiColors.textSecondary) - } - - if let hybridProviderStatus { - Text(hybridProviderStatus) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textSecondary) - .textSelection(.enabled) - } - } + .sheet(isPresented: $showHybridProviderSheet) { + HybridProviderSetupSheet( + baseURL: $hybridAiBaseURL, + apiKey: $hybridAiApiKey, + chatModel: $hybridChatModel, + postTranscriptModel: $hybridAiModel, + proactiveModel: $hybridEmbedModel, + visionModel: $hybridVisionModel, + status: hybridProviderStatus, + isSaving: isSavingHybridProvider, + isTesting: isTestingHybridProvider, + applyDefaults: applyLocalHybridProviderDefaults, + save: saveHybridProviderPolicy, + test: testHybridProviderSlot, + dismiss: { showHybridProviderSheet = false } + ) } } - private func hybridSlotEditorBlock( - title: String, - subtitle: String, - model: Binding, - slot: String, - optional: Bool = false - ) -> some View { - VStack(alignment: .leading, spacing: 8) { - Text(title) - .scaledFont(size: 12, weight: 
.medium) - .foregroundColor(OmiColors.textTertiary) - Text(subtitle) - .scaledFont(size: 11) - .foregroundColor(OmiColors.textSecondary) - TextField("Model", text: model) - .textFieldStyle(.roundedBorder) - HStack(spacing: 10) { - Button("Save") { - saveHybridProviderPolicy() - } - .buttonStyle(.bordered) - .disabled(isSavingHybridProvider) - Button("Test") { - testHybridProviderSlot(slot) - } - .buttonStyle(.bordered) - .disabled( - isTestingHybridProvider - || (optional - && model.wrappedValue.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty) - ) + @ViewBuilder + private func hybridProviderSummaryRow(_ row: HybridProviderReadiness.Row) -> some View { + let isConfigured = row.status == .configured + let isOptional = row.status == .optionalFallback + let chipColor: Color = + isConfigured + ? OmiColors.success + : (isOptional ? OmiColors.textTertiary : OmiColors.warning) + let chipIcon: String = + isConfigured + ? "checkmark.circle.fill" + : (isOptional ? "minus.circle.fill" : "exclamationmark.circle.fill") + + HStack(alignment: .top, spacing: 10) { + Image(systemName: chipIcon) + .scaledFont(size: 14) + .foregroundColor(chipColor) + .padding(.top, 1) + VStack(alignment: .leading, spacing: 2) { + Text(row.label) + .scaledFont(size: 13, weight: .medium) + .foregroundColor(OmiColors.textPrimary) + Text(row.detail) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) } + Spacer() } + .padding(.vertical, 4) } private var cloudPlanUsageSection: some View { @@ -6044,7 +6056,7 @@ struct SettingsContentView: View { systemName: codexAuthStore.isActive ? "checkmark.seal.fill" : "person.crop.circle.badge.checkmark" ) - .foregroundColor(codexAuthStore.isActive ? OmiColors.success : OmiColors.textTertiary) + .foregroundColor(codexAuthStore.isActive ? OmiColors.success : OmiColors.textTertiary) Text(codexAuthStore.isActive ? 
"ChatGPT plan active" : "Use your ChatGPT subscription") .scaledFont(size: 14, weight: .semibold) .foregroundColor(OmiColors.textPrimary) @@ -6060,8 +6072,8 @@ struct SettingsContentView: View { Text( "Active provider: \(HybridChatClient.currentRoute().displayName) via \(CodexProxyService.defaultBaseURL)" ) - .scaledFont(size: 11) - .foregroundColor(OmiColors.success) + .scaledFont(size: 11) + .foregroundColor(OmiColors.success) } else if codexAuthStore.isEnrolled, let proxyError = codexProxyService.lastError { Text("Proxy not running: \(proxyError)") .scaledFont(size: 11) @@ -6265,23 +6277,65 @@ struct SettingsContentView: View { @ViewBuilder private var byokStatusBanner: some View { + let isLocalMode = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + settingsCard(settingId: "advanced.devkeys.info") { - HStack(alignment: .top, spacing: 12) { - Image(systemName: hasAllBYOKKeys ? "checkmark.seal.fill" : "key.fill") - .foregroundColor(hasAllBYOKKeys ? OmiColors.success : OmiColors.textTertiary) - VStack(alignment: .leading, spacing: 4) { - Text(hasAllBYOKKeys ? "Free plan active" : "Use Omi free forever") - .scaledFont(size: 14, weight: .semibold) - .foregroundColor(OmiColors.textPrimary) - Text( - hasAllBYOKKeys - ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac." - : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. Keys stay on this Mac — we never store them on our servers." 
+ if isLocalMode { + VStack(alignment: .leading, spacing: 10) { + HStack(alignment: .top, spacing: 12) { + Image(systemName: "key.fill") + .foregroundColor(OmiColors.textTertiary) + VStack(alignment: .leading, spacing: 4) { + Text("Direct client keys") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text( + "Used by on-device features that bypass the local daemon — Gemini for embeddings, Deepgram for the transcription comparison harness, and OpenAI/Anthropic as fallbacks. Keys stay on this Mac." + ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + .fixedSize(horizontal: false, vertical: true) + } + Spacer() + } + + HStack(alignment: .top, spacing: 8) { + Image(systemName: "arrow.turn.down.right") + .scaledFont(size: 11) + .foregroundColor(OmiColors.purplePrimary) + .padding(.top, 2) + Text( + "For chat, post-transcript processing, and proactive assistant endpoints, configure providers in **Local Setup** instead." + ) + .scaledFont(size: 11) + .foregroundColor(OmiColors.textSecondary) + .fixedSize(horizontal: false, vertical: true) + Spacer() + } + .padding(10) + .background( + RoundedRectangle(cornerRadius: 8) + .fill(OmiColors.purplePrimary.opacity(0.08)) ) - .scaledFont(size: 12) - .foregroundColor(OmiColors.textTertiary) } - Spacer() + } else { + HStack(alignment: .top, spacing: 12) { + Image(systemName: hasAllBYOKKeys ? "checkmark.seal.fill" : "key.fill") + .foregroundColor(hasAllBYOKKeys ? OmiColors.success : OmiColors.textTertiary) + VStack(alignment: .leading, spacing: 4) { + Text(hasAllBYOKKeys ? "Free plan active" : "Use Omi free forever") + .scaledFont(size: 14, weight: .semibold) + .foregroundColor(OmiColors.textPrimary) + Text( + hasAllBYOKKeys + ? "You're paying your own providers. Omi skips the subscription charge. Keys stay on this Mac." + : "Provide all four keys (OpenAI, Anthropic, Gemini, Deepgram) to switch to the free plan. Keys stay on this Mac — we never store them on our servers." 
+ ) + .scaledFont(size: 12) + .foregroundColor(OmiColors.textTertiary) + } + Spacer() + } } } } diff --git a/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift b/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift index 9f1457bfe00..66721ecd1d5 100644 --- a/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift +++ b/desktop/Desktop/Sources/MainWindow/SettingsSidebar.swift @@ -159,6 +159,10 @@ struct SettingsSearchItem: Identifiable { name: "Plan and Usage", subtitle: "Subscription status and usage limits", keywords: ["subscription", "billing", "plan", "usage", "stripe", "architect", "unlimited"], section: .planUsage, icon: "creditcard", settingId: "planusage.overview"), + SettingsSearchItem( + name: "Local Setup", subtitle: "Local daemon and bring-your-own provider keys", + keywords: ["local", "hybrid", "byok", "provider", "daemon", "on-device", "offline"], + section: .planUsage, icon: "desktopcomputer", settingId: "planusage.overview"), SettingsSearchItem( name: "Current Plan", subtitle: "See your current subscription and renewal status", keywords: ["current plan", "renewal", "billing"], section: .planUsage, icon: "creditcard", @@ -341,7 +345,14 @@ struct SettingsSidebar: View { guard !searchQuery.isEmpty else { return [] } let words = searchQuery.lowercased().split(separator: " ").map(String.init) guard !words.isEmpty else { return [] } + let isLocalMode = DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon return SettingsSearchItem.allSearchableItems.filter { item in + // Hide the cloud "Plan and Usage" entry in local mode, and hide the + // "Local Setup" entry in cloud mode — they share the same section slot + // but represent different surfaces. 
+ if isLocalMode && item.name == "Plan and Usage" { return false } + if !isLocalMode && item.name == "Local Setup" { return false } + let nameLower = item.name.lowercased() let subtitleLower = item.subtitle.lowercased() let keywordsLower = item.keywords.map { $0.lowercased() } @@ -506,6 +517,10 @@ struct SettingsSidebarItem: View { @State private var isHovered = false + private var isLocalHybridMode: Bool { + DesktopBackendEnvironment.selectedBackendTarget.mode == .localDaemon + } + private var icon: String { switch section { case .general: return "gearshape" @@ -514,7 +529,7 @@ struct SettingsSidebarItem: View { case .notifications: return "bell" case .privacy: return "lock.shield" case .account: return "person.circle" - case .planUsage: return "creditcard" + case .planUsage: return isLocalHybridMode ? "desktopcomputer" : "creditcard" case .aiChat: return "cpu" case .floatingBar: return "sparkles" case .shortcuts: return "keyboard" @@ -523,6 +538,13 @@ struct SettingsSidebarItem: View { } } + private var displayName: String { + if section == .planUsage && isLocalHybridMode { + return "Local Setup" + } + return section.rawValue + } + var body: some View { Group { if section == .aiChat { @@ -535,7 +557,7 @@ struct SettingsSidebarItem: View { .foregroundColor(isSelected ? OmiColors.textPrimary : OmiColors.textTertiary) .frame(width: iconWidth) - Text(section.rawValue) + Text(displayName) .scaledFont(size: 14, weight: isSelected ? .medium : .regular) .foregroundColor(isSelected ? 
OmiColors.textPrimary : OmiColors.textSecondary) diff --git a/desktop/Desktop/Sources/SignInView.swift b/desktop/Desktop/Sources/SignInView.swift index 139b1c9be80..af0011479e7 100644 --- a/desktop/Desktop/Sources/SignInView.swift +++ b/desktop/Desktop/Sources/SignInView.swift @@ -63,71 +63,73 @@ struct SignInView: View { } .buttonStyle(.plain) } - // Sign in with Apple (cloud mode only — hybrid local uses guest session) - Button(action: { - Task { - do { - try await AuthService.shared.signInWithApple() - } catch is CancellationError { - // swallow — user initiated - } catch AuthError.cancelled { - // swallow — user initiated - } catch { - let errorMsg = "Error: \(error.localizedDescription)" - authState.error = errorMsg - NSLog("OMI Sign in error: %@", errorMsg) + if !isLocalHybridMode { + // Sign in with Apple + Button(action: { + Task { + do { + try await AuthService.shared.signInWithApple() + } catch is CancellationError { + // swallow — user initiated + } catch AuthError.cancelled { + // swallow — user initiated + } catch { + let errorMsg = "Error: \(error.localizedDescription)" + authState.error = errorMsg + NSLog("OMI Sign in error: %@", errorMsg) + } } - } - }) { - HStack(spacing: 8) { - Image(systemName: "applelogo") - .scaledFont(size: 18) - Text("Sign in with Apple") - .scaledFont(size: 17, weight: .medium) - } - .foregroundColor(.black) - .frame(maxWidth: .infinity) - .frame(height: 50) - .background(Color.white) - .cornerRadius(10) - } - .buttonStyle(.plain) - .disabled(authState.isLoading || isLocalHybridMode) - - // Sign in with Google - Button(action: { - Task { - do { - try await AuthService.shared.signInWithGoogle() - } catch is CancellationError { - // swallow — user initiated - } catch AuthError.cancelled { - // swallow — user initiated - } catch { - let errorMsg = "Error: \(error.localizedDescription)" - authState.error = errorMsg - NSLog("OMI Sign in error: %@", errorMsg) + }) { + HStack(spacing: 8) { + Image(systemName: "applelogo") + 
.scaledFont(size: 18) + Text("Sign in with Apple") + .scaledFont(size: 17, weight: .medium) } + .foregroundColor(.black) + .frame(maxWidth: .infinity) + .frame(height: 50) + .background(Color.white) + .cornerRadius(10) } - }) { - HStack(spacing: 8) { - GoogleLogo() - .frame(width: 18, height: 18) - Text("Sign in with Google") - .scaledFont(size: 17, weight: .medium) + .buttonStyle(.plain) + .disabled(authState.isLoading) + + // Sign in with Google + Button(action: { + Task { + do { + try await AuthService.shared.signInWithGoogle() + } catch is CancellationError { + // swallow — user initiated + } catch AuthError.cancelled { + // swallow — user initiated + } catch { + let errorMsg = "Error: \(error.localizedDescription)" + authState.error = errorMsg + NSLog("OMI Sign in error: %@", errorMsg) + } + } + }) { + HStack(spacing: 8) { + GoogleLogo() + .frame(width: 18, height: 18) + Text("Sign in with Google") + .scaledFont(size: 17, weight: .medium) + } + .foregroundColor(.black) + .frame(maxWidth: .infinity) + .frame(height: 50) + .background(Color.white) + .cornerRadius(10) + .overlay( + RoundedRectangle(cornerRadius: 10) + .stroke(Color.gray.opacity(0.3), lineWidth: 1) + ) } - .foregroundColor(.black) - .frame(maxWidth: .infinity) - .frame(height: 50) - .background(Color.white) - .cornerRadius(10) - .overlay( - RoundedRectangle(cornerRadius: 10) - .stroke(Color.gray.opacity(0.3), lineWidth: 1) - ) + .buttonStyle(.plain) + .disabled(authState.isLoading) } - .buttonStyle(.plain) - .disabled(authState.isLoading || isLocalHybridMode) // Loading overlay for both buttons if authState.isLoading { diff --git a/desktop/Desktop/Tests/TranscriptComparisonTests.swift b/desktop/Desktop/Tests/TranscriptComparisonTests.swift index 56ab85b61af..271910bbd36 100644 --- a/desktop/Desktop/Tests/TranscriptComparisonTests.swift +++ b/desktop/Desktop/Tests/TranscriptComparisonTests.swift @@ -66,4 +66,72 @@ final class TranscriptComparisonTests: XCTestCase { 
XCTAssertNotNil(snapshot.deepgram.error) XCTAssertNil(snapshot.wordDifferenceRate) } + + @MainActor + func testComparisonHarnessGroupsProviderOutputByTimeBucket() throws { + var snapshots: [TranscriptionComparisonHarnessSnapshot] = [] + let harness = TranscriptionComparisonHarness( + language: "en", + deepgramAPIKey: nil, + onSnapshot: { snapshots.append($0) } + ) + + harness.appendWhisperSegments([ + segment(id: "w1", text: "first whisper window", start: 2, end: 4), + segment(id: "w2", text: "second whisper window", start: 35, end: 38), + ]) + harness.appendDeepgramSegmentsForTesting([ + segment(id: "d1", text: "first deepgram window", start: 3, end: 5), + segment(id: "d2", text: "second deepgram window", start: 36, end: 39), + ]) + + let buckets = try XCTUnwrap(snapshots.last?.timeBuckets) + XCTAssertEqual(buckets.count, 2) + XCTAssertEqual(buckets[0].startTime, 0, accuracy: 0.001) + XCTAssertEqual(buckets[0].endTime, 30, accuracy: 0.001) + XCTAssertEqual(buckets[0].whisperText, "first whisper window") + XCTAssertEqual(buckets[0].deepgramText, "first deepgram window") + XCTAssertEqual(buckets[1].startTime, 30, accuracy: 0.001) + XCTAssertEqual(buckets[1].whisperText, "second whisper window") + XCTAssertEqual(buckets[1].deepgramText, "second deepgram window") + } + + @MainActor + func testComparisonHarnessScoresFullTranscriptWhenPreviewIsTruncated() throws { + var snapshots: [TranscriptionComparisonHarnessSnapshot] = [] + let harness = TranscriptionComparisonHarness( + language: "en", + deepgramAPIKey: nil, + onSnapshot: { snapshots.append($0) } + ) + let previewPadding = String(repeating: ".", count: 13_000) + let whisperText = "wrong \(previewPadding) same" + let deepgramText = "right \(previewPadding) same" + + harness.appendWhisperSegments([segment(id: "w-long", text: whisperText, start: 0, end: 1)]) + harness.appendDeepgramSegmentsForTesting([ + segment(id: "d-long", text: deepgramText, start: 0, end: 1) + ]) + + let snapshot = try XCTUnwrap(snapshots.last) + 
XCTAssertLessThan(snapshot.whisper.transcript.count, whisperText.count) + XCTAssertEqual(snapshot.whisper.wordCount, 2) + XCTAssertEqual(try XCTUnwrap(snapshot.wordDifferenceRate), 0.5, accuracy: 0.0001) + } + + private func segment(id: String, text: String, start: Double, end: Double) + -> NormalizedTranscriptSegment + { + NormalizedTranscriptSegment( + segmentId: id, + speaker: 0, + speakerLabel: nil, + text: text, + start: start, + end: end, + isUser: true, + personId: nil, + translations: [] + ) + } } diff --git a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift index e6415f1a569..a8abcaf0f81 100644 --- a/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift +++ b/desktop/Desktop/Tests/TranscriptionProviderPolicyTests.swift @@ -157,6 +157,54 @@ final class TranscriptionProviderPolicyTests: XCTestCase { } } +final class LocalASRHelperClientTests: XCTestCase { + func testTimeoutTerminatesHungHelperWithoutLeavingBlockingWaitTask() async throws { + let temporaryDirectory = FileManager.default.temporaryDirectory + .appendingPathComponent("LocalASRHelperClientTests-\(UUID().uuidString)", isDirectory: true) + try FileManager.default.createDirectory( + at: temporaryDirectory, withIntermediateDirectories: true) + defer { + try? 
FileManager.default.removeItem(at: temporaryDirectory) + } + + let helperURL = temporaryDirectory.appendingPathComponent("hung-helper.sh") + try """ + #!/bin/sh + sleep 5 + printf '{"request_id":"timeout-test","engine":"mlx_whisper","model":"small","language":"en","segments":[],"fixture":false}' + """.write(to: helperURL, atomically: true, encoding: .utf8) + try FileManager.default.setAttributes( + [.posixPermissions: NSNumber(value: Int16(0o755))], + ofItemAtPath: helperURL.path + ) + + let audioURL = temporaryDirectory.appendingPathComponent("audio.pcm") + try Data(repeating: 0, count: 320).write(to: audioURL) + + let client = LocalASRHelperClient(executableURL: helperURL, timeoutSeconds: 0.2) + let startedAt = Date() + + do { + _ = try await client.transcribe( + LocalASRTranscriptionRequest( + requestId: "timeout-test", + audioPath: audioURL.path, + language: "en", + sampleRate: 16_000, + channels: 1, + engine: .mlxWhisper, + model: .small, + fixtureSegments: nil + ) + ) + XCTFail("Expected hung helper to time out") + } catch { + XCTAssertLessThan(Date().timeIntervalSince(startedAt), 2) + XCTAssertTrue(error.localizedDescription.contains("timed out")) + } + } +} + final class LocalASRAddonManifestTests: XCTestCase { func testRemoteManifestParsesRuntimeAndModels() throws { let json = """ @@ -535,6 +583,8 @@ final class LocalASRRuntimeTests: XCTestCase { overlapDuration: 0.2, silenceWindowDuration: 0.2, silenceAmplitudeThreshold: 1, + speechPeakAmplitudeThreshold: 1, + speechRMSAmplitudeThreshold: 1, maxPendingChunks: 4 ), requestHandler: { request in @@ -602,6 +652,52 @@ final class LocalASRRuntimeTests: XCTestCase { ) } + func testBackgroundSessionSkipsLowEnergyChunksBeforeWhisper() async throws { + var requestCount = 0 + let session = LocalBackgroundTranscriptionSession( + language: "en", + plan: LocalTranscriptionPlan(engine: .mlxWhisper, model: .small, quality: .balanced), + configuration: LocalBackgroundChunkerConfiguration( + sampleRate: 10, + 
bytesPerSample: 2, + maxChunkDuration: 1, + minChunkDuration: 0.5, + overlapDuration: 0, + silenceWindowDuration: 0.2, + silenceAmplitudeThreshold: 1, + speechPeakAmplitudeThreshold: 100, + speechRMSAmplitudeThreshold: 50, + maxPendingChunks: 4 + ), + requestHandler: { request in + requestCount += 1 + return LocalASRTranscriptionResponse( + requestId: request.requestId, + engine: request.engine, + model: request.model, + language: request.language, + segments: [ + LocalASRTranscriptSegment( + id: nil, + speaker: 0, + text: "hallucinated silence", + start: 0, + end: 1 + ) + ], + fixture: false + ) + } + ) + + _ = session.append(pcmData: pcm(Array(repeating: 3, count: 10)), startTime: 0) + let result = try await tryUnwrap(session.transcribeNext()) + + XCTAssertEqual(requestCount, 0) + XCTAssertTrue(result.remappedSegments.isEmpty) + XCTAssertTrue(session.snapshot().joinedTranscript.isEmpty) + } + func testBackgroundPipelineCanExerciseRealHelperWhenRuntimeAvailable() async throws { guard let helperURL = LocalASRHelperLocator.defaultExecutableURL() else { throw XCTSkip("Local ASR helper executable is not available") diff --git a/desktop/codex-proxy/src/main.rs b/desktop/codex-proxy/src/main.rs index 5cfc047d2be..23cee5c271e 100644 --- a/desktop/codex-proxy/src/main.rs +++ b/desktop/codex-proxy/src/main.rs @@ -147,7 +147,10 @@ async fn invoke_codex( .post(CODEX_RESPONSES_URL) .headers(hdrs.clone()) .header(CONTENT_TYPE, HeaderValue::from_static("application/json")) - .header(header::ACCEPT, HeaderValue::from_static("text/event-stream")) + .header( + header::ACCEPT, + HeaderValue::from_static("text/event-stream"), + ) .body(bytes.clone()) .send() .await @@ -313,7 +316,12 @@ fn apply_refresh_to_doc(disk: &mut AuthDisk, mut env: RefreshEnvelope) -> Result .filter(|t| !t.is_empty()) .ok_or_else(|| "oauth refresh succeeded but omitted access_token".to_string())?; - if disk.doc.get("tokens").map(|t| t.is_object()).unwrap_or(false) { + if disk + .doc + .get("tokens") + 
.map(|t| t.is_object()) + .unwrap_or(false) + { if let Some(tokens) = disk.doc.get_mut("tokens").and_then(Value::as_object_mut) { tokens.insert("access_token".to_string(), Value::String(new_access)); if let Some(new_refresh) = env.refresh_token.take().filter(|t| !t.is_empty()) { @@ -406,7 +414,8 @@ fn codex_payload_from_openai_chat(openai_body: &Value) -> Result .ok_or_else(|| format!("messages[{idx}].role missing"))?; if role == "system" { - if let Some(text) = message_content_as_string(msg.get("content").unwrap_or(&Value::Null))? + if let Some(text) = + message_content_as_string(msg.get("content").unwrap_or(&Value::Null))? { if !text.is_empty() { instructions_parts.push(text); @@ -534,7 +543,7 @@ fn normalize_message_content(raw: &Value, role: &str) -> Result { parts .iter() .map(|part| normalize_content_part(part, text_type)) - .collect(), + .collect::, _>>()?, ) } } @@ -546,7 +555,7 @@ fn normalize_message_content(raw: &Value, role: &str) -> Result { }) } -fn normalize_content_part(part: &Value, default_type: &str) -> Value { +fn normalize_content_part(part: &Value, default_type: &str) -> Result { match part { Value::Object(map) => { let mut out = map.clone(); @@ -555,12 +564,45 @@ fn normalize_content_part(part: &Value, default_type: &str) -> Value { } else if let Some(Value::String(kind)) = out.get("type") { if kind == "text" { out.insert("type".to_string(), Value::String(default_type.to_string())); + } else if kind == "image_url" { + out.insert("type".to_string(), Value::String("input_image".to_string())); + let image = + out.get("image_url") + .and_then(Value::as_object) + .ok_or_else(|| { + "image_url content part must include an object".to_string() + })?; + let url = image + .get("url") + .and_then(Value::as_str) + .filter(|url| !url.is_empty()) + .ok_or_else(|| { + "image_url content part must include image_url.url".to_string() + })? 
+ .to_string(); + let detail = image + .get("detail") + .and_then(Value::as_str) + .unwrap_or("auto") + .to_string(); + if !matches!(detail.as_str(), "auto" | "low" | "high") { + return Err( + "image_url content part detail must be auto, low, or high".to_string() + ); + } + out.insert("image_url".to_string(), Value::String(url)); + out.insert("detail".to_string(), Value::String(detail)); + if let Some(text) = out.get("text").and_then(Value::as_str) { + if text.is_empty() { + out.remove("text"); + } + } } } - Value::Object(out) + Ok(Value::Object(out)) } - Value::String(s) => json!({ "type": default_type, "text": s }), - other => other.clone(), + Value::String(s) => Ok(json!({ "type": default_type, "text": s })), + other => Ok(other.clone()), } } @@ -796,4 +838,43 @@ mod tests { ); assert_eq!(out["model"], Value::from("gpt-output")); } + + #[test] + fn maps_openai_image_url_parts_to_codex_input_image() { + let openai = json!({ + "model": "gpt-test", + "messages": [{ + "role":"user", + "content":[ + {"type":"text","text":"describe this"}, + {"type":"image_url","image_url":{"url":"data:image/png;base64,abc","detail":"high"}} + ] + }] + }); + + let out = codex_payload_from_openai_chat(&openai).expect("mapping"); + assert_eq!( + out["input"][0]["content"], + json!([ + {"type":"input_text","text":"describe this"}, + {"type":"input_image","image_url":"data:image/png;base64,abc","detail":"high"} + ]) + ); + } + + #[test] + fn rejects_openai_image_url_parts_without_url() { + let openai = json!({ + "model": "gpt-test", + "messages": [{ + "role":"user", + "content":[ + {"type":"image_url","image_url":{"detail":"low"}} + ] + }] + }); + + let err = codex_payload_from_openai_chat(&openai).expect_err("missing URL should fail"); + assert!(err.contains("image_url.url")); + } }