diff --git a/src-tauri/src/drivers/mysql/mod.rs b/src-tauri/src/drivers/mysql/mod.rs index bd46e6d0..bd2ce962 100644 --- a/src-tauri/src/drivers/mysql/mod.rs +++ b/src-tauri/src/drivers/mysql/mod.rs @@ -1313,8 +1313,16 @@ impl DatabaseDriver for MysqlDriver { let connect_timeout = mysql_numeric_setting("connectTimeout", DEFAULT_MYSQL_CONNECT_TIMEOUT_MS); let timezone = mysql_string_setting("timezone", DEFAULT_MYSQL_TIMEZONE); + let ssl_mode = match params.ssl_mode.as_deref() { + Some("disabled") | Some("disable") => "disabled", + Some("preferred") | Some("prefer") => "preferred", + Some("required") | Some("require") => "required", + Some("verify_ca") => "verify_ca", + Some("verify_identity") => "verify_identity", + _ => "required", + }; Ok(format!( - "mysql://{}@{}:{}/{}?maxAllowedPacket={}&socketTimeout={}&connectTimeout={}&timezone={}", + "mysql://{}@{}:{}/{}?maxAllowedPacket={}&socketTimeout={}&connectTimeout={}&timezone={}&ssl-mode={}", credentials, params.host.as_deref().unwrap_or("localhost"), params.port.unwrap_or(3306), @@ -1323,6 +1331,7 @@ impl DatabaseDriver for MysqlDriver { socket_timeout, connect_timeout, encode(&timezone), + ssl_mode, )) } diff --git a/src-tauri/src/drivers/mysql/tests.rs b/src-tauri/src/drivers/mysql/tests.rs index 7207fbec..985560df 100644 --- a/src-tauri/src/drivers/mysql/tests.rs +++ b/src-tauri/src/drivers/mysql/tests.rs @@ -1,5 +1,40 @@ use super::explain::parse_mysql_query_block; +use super::MysqlDriver; +use crate::drivers::driver_trait::DatabaseDriver; use crate::models::ExplainNode; +use crate::models::{ConnectionParams, DatabaseSelection}; + +#[test] +fn build_connection_url_includes_disabled_ssl_mode() { + let driver = MysqlDriver::new(); + let params = ConnectionParams { + driver: "mysql".to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(3306), + username: Some("root".to_string()), + password: Some("secret".to_string()), + database: DatabaseSelection::Single("dec".to_string()), + ssl_mode: Some("disabled".to_string()), + ssl_ca: None, + ssl_cert: None, + ssl_key: None, + ssh_enabled: None, + ssh_connection_id: None, + ssh_host: None, + ssh_port: None, + ssh_user: None, + ssh_password: None, + ssh_key_file: None, + ssh_key_passphrase: None, + save_in_keychain: None, + connection_id: None, + ..Default::default() + }; + + let url = driver.build_connection_url(¶ms).unwrap(); + + assert!(url.contains("ssl-mode=disabled"), "url was: {url}"); +} /// Helper: parse a MariaDB ANALYZE FORMAT=JSON string and return the root node. fn parse_json(json: &str) -> ExplainNode { diff --git a/src-tauri/src/pool_manager.rs b/src-tauri/src/pool_manager.rs index c5130048..37d084cf 100644 --- a/src-tauri/src/pool_manager.rs +++ b/src-tauri/src/pool_manager.rs @@ -79,8 +79,21 @@ fn mysql_numeric_setting(key: &str, default: u64) -> u64 { /// Build a stable connection key that works with SSH tunnels. /// If connection_id is provided (from saved connections), use it for stable pooling. /// Otherwise fall back to host:port:database (for ad-hoc connections). -fn build_connection_key(params: &ConnectionParams, connection_id: Option<&str>) -> String { - if let Some(conn_id) = connection_id { +pub(crate) fn build_connection_key( + params: &ConnectionParams, + connection_id: Option<&str>, +) -> String { + let tls_key = (params.driver == "mysql").then(|| { + format!( + "ssl:{}:{}:{}:{}", + params.ssl_mode.as_deref().unwrap_or("default"), + params.ssl_ca.as_deref().unwrap_or(""), + params.ssl_cert.as_deref().unwrap_or(""), + params.ssl_key.as_deref().unwrap_or("") + ) + }); + + let base_key = if let Some(conn_id) = connection_id { // Include database in key so different databases on the same connection use separate pools format!("{}:conn:{}:{}", params.driver, conn_id, params.database) } else { @@ -92,10 +105,16 @@ fn build_connection_key(params: &ConnectionParams, connection_id: Option<&str>) params.port.unwrap_or(0), params.database ) + }; + + if let Some(tls_key) = tls_key { + format!("{base_key}:{tls_key}") + } else { + base_key } } -fn build_mysql_options( +pub(crate) fn build_mysql_options( params: &ConnectionParams, override_db: Option<&str>, ) -> Result { diff --git a/src-tauri/src/pool_manager_tests.rs b/src-tauri/src/pool_manager_tests.rs index 8713fa75..bfb23c74 100644 --- a/src-tauri/src/pool_manager_tests.rs +++ b/src-tauri/src/pool_manager_tests.rs @@ -1,6 +1,42 @@ #[cfg(test)] mod tests { - use crate::pool_manager::format_error_chain; + use crate::models::{ConnectionParams, DatabaseSelection}; + use crate::pool_manager::{build_connection_key, build_mysql_options, format_error_chain}; + use sqlx::mysql::MySqlSslMode; + + fn connection_params(driver: &str, ssl_mode: Option<&str>) -> ConnectionParams { + ConnectionParams { + driver: driver.to_string(), + host: Some("127.0.0.1".to_string()), + port: Some(match driver { + "postgres" => 5432, + "mysql" => 3306, + _ => 0, + }), + username: Some("dec".to_string()), + password: Some("secret".to_string()), + database: DatabaseSelection::Single("dec".to_string()), + ssl_mode: ssl_mode.map(ToOwned::to_owned), + ssl_ca: None, + ssl_cert: None, + ssl_key: None, + ssh_enabled: Some(true), + ssh_connection_id: Some("ssh-1".to_string()), + ssh_host: Some("149.202.85.42".to_string()), + ssh_port: Some(2222), + ssh_user: Some("julien".to_string()), + ssh_password: None, + ssh_key_file: Some("/Users/julienbarbe/.ssh/id_rsa".to_string()), + ssh_key_passphrase: None, + save_in_keychain: None, + connection_id: Some("conn-1".to_string()), + ..Default::default() + } + } + + fn mysql_params(ssl_mode: &str) -> ConnectionParams { + connection_params("mysql", Some(ssl_mode)) + } #[test] fn format_error_chain_walks_sources() { @@ -34,6 +70,58 @@ mod tests { "outer message -> inner cause" ); } + + #[test] + fn mysql_pool_key_changes_when_ssl_mode_changes() { + let required = mysql_params("required"); + let disabled = mysql_params("disabled"); + + assert_ne!( + build_connection_key(&required, Some("conn-1")), + build_connection_key(&disabled, Some("conn-1")) + ); + } + + #[test] + fn postgres_pool_key_ignores_mysql_ssl_key_fields() { + let required = connection_params("postgres", Some("required")); + let disabled = connection_params("postgres", Some("disabled")); + + assert_eq!( + build_connection_key(&required, Some("conn-1")), + build_connection_key(&disabled, Some("conn-1")) + ); + } + + #[test] + fn sqlite_pool_key_ignores_mysql_ssl_key_fields() { + let required = connection_params("sqlite", Some("required")); + let disabled = connection_params("sqlite", Some("disabled")); + + assert_eq!( + build_connection_key(&required, Some("conn-1")), + build_connection_key(&disabled, Some("conn-1")) + ); + } + + #[test] + fn mysql_options_accept_snake_case_verify_ssl_modes() { + let verify_ca = mysql_params("verify_ca"); + let verify_identity = mysql_params("verify_identity"); + + assert!(matches!( + build_mysql_options(&verify_ca, None) + .unwrap() + .get_ssl_mode(), + MySqlSslMode::VerifyCa + )); + assert!(matches!( + build_mysql_options(&verify_identity, None) + .unwrap() + .get_ssl_mode(), + MySqlSslMode::VerifyIdentity + )); + } } #[cfg(test)]