diff --git a/sqlx-sqlite/src/options/parse.rs b/sqlx-sqlite/src/options/parse.rs index 0530f4204c..96f6a148a3 100644 --- a/sqlx-sqlite/src/options/parse.rs +++ b/sqlx-sqlite/src/options/parse.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; use percent_encoding::{percent_decode_str, percent_encode, AsciiSet}; use url::Url; @@ -9,6 +10,8 @@ use url::Url; use crate::error::Error; use crate::SqliteConnectOptions; +use super::SqliteJournalMode; + // https://www.sqlite.org/uri.html static IN_MEMORY_DB_SEQ: AtomicUsize = AtomicUsize::new(0); @@ -102,6 +105,31 @@ impl SqliteConnectOptions { "vfs" => options.vfs = Some(Cow::Owned(value.into_owned())), + // The journal_mode query parameter specifies the journal mode to use for the database. + // The default is DELETE, but WAL is recommended for most use cases. + // See https://www.sqlite.org/pragma.html#pragma_journal_mode + // as journal_mode is not a standard query parameter, we prefix it with `_` + "_journal_mode" => { + if SqliteJournalMode::from_str(&value).is_err() { + return Err(Error::Configuration( + format!("unknown value {value:?} for `journal_mode`").into(), + )); + } + options + .pragmas + .insert("journal_mode".into(), Some(value.into_owned().into())); + } + + // The busy_timeout query parameter specifies the timeout to use when the database is busy. + // the default is 5 seconds, but this can be changed to a shorter or longer duration. + // See https://www.sqlite.org/pragma.html#pragma_busy_timeout + // as busy_timeout is not a standard query parameter, we prefix it with `_` + "_busy_timeout" => { + if let Some(timeout) = parse_duration(&value) { + options.busy_timeout = timeout; + } + } + _ => { return Err(Error::Configuration( format!("unknown query parameter `{key}` while parsing connection URL") @@ -149,7 +177,6 @@ impl SqliteConnectOptions { false => "private", }; url.query_pairs_mut().append_pair("cache", cache); - if self.immutable { url.query_pairs_mut().append_pair("immutable", "true"); } @@ -158,6 +185,16 @@ impl SqliteConnectOptions { url.query_pairs_mut().append_pair("vfs", vfs); } + if let Some(Some(journal_mode)) = self.pragmas.get("journal_mode") { + url.query_pairs_mut() + .append_pair("_journal_mode", journal_mode); + } + + if !self.busy_timeout.is_zero() { + url.query_pairs_mut() + .append_pair("_busy_timeout", &format_duration(self.busy_timeout)); + } + url } } @@ -180,6 +217,30 @@ impl FromStr for SqliteConnectOptions { } } +type TimeUnitParser = (&'static str, fn(u64) -> Duration); + +// This function is used to parse the busy timeout from the URL query parameters. +// as busy timeout should be short, we only support milliseconds and seconds +fn parse_duration(s: &str) -> Option { + static UNITS: [TimeUnitParser; 2] = [("ms", Duration::from_millis), ("s", Duration::from_secs)]; + for (suffix, func) in UNITS.iter() { + let Some(suffix) = s.strip_suffix(suffix) else { + continue; + }; + let value = suffix.parse::().ok()?; + return Some(func(value)); + } + None +} + +fn format_duration(duration: Duration) -> String { + if duration.subsec_millis() == 0 { + format!("{}s", duration.as_secs()) + } else { + format!("{}ms", duration.as_millis()) + } +} + #[test] fn test_parse_in_memory() -> Result<(), Error> { let options: SqliteConnectOptions = "sqlite::memory:".parse()?; @@ -221,7 +282,7 @@ fn test_parse_shared_in_memory() -> Result<(), Error> { #[test] fn it_returns_the_parsed_url() -> Result<(), Error> { - let url = "sqlite://test.db?mode=rw&cache=shared"; + let url = "sqlite://test.db?mode=rw&cache=shared&_busy_timeout=5s"; let options: SqliteConnectOptions = url.parse()?; let expected_url = Url::parse(url).unwrap(); @@ -229,3 +290,47 @@ fn it_returns_the_parsed_url() -> Result<(), Error> { Ok(()) } + +#[test] +fn it_parse_journal_mode() -> Result<(), Error> { + let url = "sqlite://test.db?_journal_mode=WAL"; + let options: SqliteConnectOptions = url.parse()?; + + let val = options.pragmas.get("journal_mode").cloned().flatten(); + assert_eq!(val, Some(Cow::Owned("WAL".to_string()))); + + let format_url = options.build_url(); + assert_eq!( + format_url.as_str(), + "sqlite://test.db?mode=rw&cache=private&_journal_mode=WAL&_busy_timeout=5s" + ); + Ok(()) +} + +#[test] +fn it_should_return_error_for_invalid_journal_mode() -> Result<(), Error> { + let url = "sqlite://test.db?_journal_mode=invalid"; + let options: Result = url.parse(); + + assert!(options.is_err()); + assert_eq!( + options.unwrap_err().to_string(), + "error with configuration: unknown value \"invalid\" for `journal_mode`" + ); + Ok(()) +} + +#[test] +fn it_should_parse_busy_timeout() -> Result<(), Error> { + let url = "sqlite://test.db?_busy_timeout=1000ms"; + let options: SqliteConnectOptions = url.parse()?; + + assert_eq!(options.busy_timeout, Duration::from_millis(1000)); + + let format_url = options.build_url(); + assert_eq!( + format_url.as_str(), + "sqlite://test.db?mode=rw&cache=private&_busy_timeout=1s" + ); + Ok(()) +}