From 5efd77ed998d12ce4f4ae3e10f1965cb97cb5797 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 21 Feb 2023 17:37:39 +0100 Subject: [PATCH 001/130] Allow setting to return found rows when writing to the database --- src/conn/mod.rs | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ src/opts/mod.rs | 20 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 04c1e5f7..7aaeedf2 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1144,6 +1144,58 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> { + let opts = get_opts().writes_return_found_rows(true); + let mut conn = Conn::new(opts).await.unwrap(); + + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" + .ignore(&mut conn) + .await?; + + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; + + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); + + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; + + // The query doesn't affect any rows, but due to us wanting FOUND rows, + // this has to return one. + assert_eq!(conn.affected_rows(), 1); + + Ok(()) + } + + #[tokio::test] + async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await.unwrap(); + + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" + .ignore(&mut conn) + .await?; + + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; + + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); + + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; + + // The query doesn't affect any rows. + assert_eq!(conn.affected_rows(), 0); + + Ok(()) + } + async fn read_binlog_streams_and_close_their_connections( pool: Option<&Pool>, binlog_server_ids: (u32, u32, u32), diff --git a/src/opts/mod.rs b/src/opts/mod.rs index a74a6fc0..10be818d 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -404,6 +404,10 @@ pub(crate) struct MysqlOpts { /// /// Available via `secure_auth` connection url parameter. secure_auth: bool, + + /// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc). + /// It makes MySQL return the FOUND rows instead of the AFFECTED rows. + client_found_rows: bool, } /// Mysql connection options. @@ -721,6 +725,11 @@ impl Opts { self.inner.mysql_opts.secure_auth } + /// If true, write queries return found rows and if false, affected rows. + pub fn writes_return_found_rows(&self) -> bool { + self.inner.mysql_opts.client_found_rows + } + pub(crate) fn get_capabilities(&self) -> CapabilityFlags { let mut out = CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SECURE_CONNECTION @@ -742,6 +751,9 @@ impl Opts { if self.inner.mysql_opts.compression.is_some() { out |= CapabilityFlags::CLIENT_COMPRESS; } + if self.writes_return_found_rows() { + out |= CapabilityFlags::CLIENT_FOUND_ROWS; + } out } @@ -767,6 +779,7 @@ impl Default for MysqlOpts { max_allowed_packet: None, wait_timeout: None, secure_auth: true, + client_found_rows: false, } } } @@ -1017,6 +1030,13 @@ impl OptsBuilder { self.opts.secure_auth = secure_auth; self } + + /// Changes the behavior of the affected count returned for writes. + /// See [`Opts::writes_return_found_rows`]. + pub fn writes_return_found_rows(mut self, client_found_rows: bool) -> Self { + self.opts.client_found_rows = client_found_rows; + self + } } impl From for Opts { From 75ade0d0780f83683527e5fbc117122cb4626a1a Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 22 Feb 2023 11:30:36 +0300 Subject: [PATCH 002/130] Rename writes_return_found_rows > client_found_rows. Add connection URL support --- src/conn/mod.rs | 2 +- src/opts/mod.rs | 40 ++++++++++++++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 7aaeedf2..b29cdf41 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1146,7 +1146,7 @@ mod test { #[tokio::test] async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> { - let opts = get_opts().writes_return_found_rows(true); + let opts = get_opts().client_found_rows(true); let mut conn = Conn::new(opts).await.unwrap(); "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 10be818d..560b1cb9 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -405,6 +405,8 @@ pub(crate) struct MysqlOpts { /// Available via `secure_auth` connection url parameter. secure_auth: bool, + /// Enables `CLIENT_FOUND_ROWS` capability (defaults to `false`). + /// /// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc). /// It makes MySQL return the FOUND rows instead of the AFFECTED rows. client_found_rows: bool, @@ -725,8 +727,23 @@ impl Opts { self.inner.mysql_opts.secure_auth } - /// If true, write queries return found rows and if false, affected rows. - pub fn writes_return_found_rows(&self) -> bool { + /// Returns `true` if `CLIENT_FOUND_ROWS` capability is enabled (defaults to `false`). + /// + /// `CLIENT_FOUND_ROWS` changes the behavior of the affected count returned for writes + /// (UPDATE/INSERT etc). It makes MySQL return the FOUND rows instead of the AFFECTED rows. + /// + /// # Connection URL + /// + /// Use `client_found_rows` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?client_found_rows=true")?; + /// assert!(opts.client_found_rows()); + /// # Ok(()) } + /// ``` + pub fn client_found_rows(&self) -> bool { self.inner.mysql_opts.client_found_rows } @@ -751,7 +768,7 @@ impl Opts { if self.inner.mysql_opts.compression.is_some() { out |= CapabilityFlags::CLIENT_COMPRESS; } - if self.writes_return_found_rows() { + if self.client_found_rows() { out |= CapabilityFlags::CLIENT_FOUND_ROWS; } @@ -1031,9 +1048,8 @@ impl OptsBuilder { self } - /// Changes the behavior of the affected count returned for writes. - /// See [`Opts::writes_return_found_rows`]. - pub fn writes_return_found_rows(mut self, client_found_rows: bool) -> Self { + /// Enables or disables `CLIENT_FOUND_ROWS` capability. See [`Opts::client_found_rows`]. + pub fn client_found_rows(mut self, client_found_rows: bool) -> Self { self.opts.client_found_rows = client_found_rows; self } @@ -1265,6 +1281,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "client_found_rows" { + match bool::from_str(&*value) { + Ok(client_found_rows) => { + opts.client_found_rows = client_found_rows; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "client_found_rows".into(), + value, + }); + } + } } else if key == "socket" { opts.socket = Some(value) } else if key == "compression" { From 6ca8becdf678b49287feba8820afd399e679d2f7 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Sat, 11 Mar 2023 18:27:27 -0500 Subject: [PATCH 003/130] Update to lru v0.10.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ce7e3958..e18c40c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" lazy_static = "1" -lru = "0.8.1" +lru = "0.10.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } mysql_common = { version = "0.29.2", default-features = false } once_cell = "1.7.2" From 2842af3595c2e30fe97bf2208ce31bb9cc5ed114 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sun, 9 Apr 2023 18:54:44 +0300 Subject: [PATCH 004/130] Bump mysql_common to 0.30, add cleartext plugin support --- Cargo.toml | 23 ++++++++------ src/conn/mod.rs | 51 +++++++++++++++++++++++-------- src/conn/pool/mod.rs | 2 +- src/error/mod.rs | 3 ++ src/lib.rs | 9 ++++-- src/opts/mod.rs | 73 ++++++++++++++++++++++++++++++++++++++++++++ tests/exports.rs | 14 +++++---- 7 files changed, 143 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ce7e3958..5c339b7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ edition = "2018" categories = ["asynchronous", "database"] [dependencies] -bytes = "1.0" +bytes = "1.4" crossbeam = "0.8.1" flate2 = { version = "1.0", default-features = false } futures-core = "0.3" @@ -22,7 +22,9 @@ futures-sink = "0.3" lazy_static = "1" lru = "0.8.1" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.29.2", default-features = false } +mysql_common = { version = "0.30", default-features = false, features = [ + "derive", +] } once_cell = "1.7.2" pem = "1.0.1" percent-encoding = "2.1.0" @@ -34,7 +36,9 @@ socket2 = "0.4.2" thiserror = "1.0.4" tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } tokio-util = { version = "0.7.2", features = ["codec", "io"] } -tracing = { version = "0.1.37", default-features = false, features = ["attributes"], optional = true } +tracing = { version = "0.1.37", default-features = false, features = [ + "attributes", +], optional = true } twox-hash = "1" url = "2.1" @@ -76,20 +80,20 @@ rand = "0.8.0" [features] default = [ "flate2/zlib", - "mysql_common/bigdecimal03", + "mysql_common/bigdecimal", "mysql_common/rust_decimal", - "mysql_common/time03", - "mysql_common/uuid", + "mysql_common/time", "mysql_common/frunk", + "derive", "native-tls-tls", ] default-rustls = [ "flate2/zlib", - "mysql_common/bigdecimal03", + "mysql_common/bigdecimal", "mysql_common/rust_decimal", - "mysql_common/time03", - "mysql_common/uuid", + "mysql_common/time", "mysql_common/frunk", + "derive", "rustls-tls", ] minimal = ["flate2/zlib"] @@ -102,6 +106,7 @@ rustls-tls = [ "rustls-pemfile", ] tracing = ["dep:tracing"] +derive = ["mysql_common/derive"] nightly = [] [lib] diff --git a/src/conn/mod.rs b/src/conn/mod.rs index b29cdf41..0df5a28c 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -473,17 +473,15 @@ impl Conn { .unwrap_or((0, 0, 0)); self.inner.id = handshake.connection_id(); self.inner.status = handshake.status_flags(); + + // Allow only CachingSha2Password and MysqlNativePassword here + // because sha256_password is deprecated and other plugins won't + // appear here. self.inner.auth_plugin = match handshake.auth_plugin() { - Some(AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword) => { - AuthPlugin::MysqlNativePassword - } Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password, - Some(AuthPlugin::Other(ref name)) => { - let name = String::from_utf8_lossy(name).into(); - return Err(DriverError::UnknownAuthPlugin { name }.into()); - } - None => AuthPlugin::MysqlNativePassword, + _ => AuthPlugin::MysqlNativePassword, }; + Ok(()) } @@ -567,13 +565,32 @@ impl Conn { self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned(); - let plugin_data = self - .inner - .auth_plugin - .gen_data(self.inner.opts.pass(), &*self.inner.nonce); + let plugin_data = match &self.inner.auth_plugin { + x @ AuthPlugin::CachingSha2Password => { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + x @ AuthPlugin::MysqlNativePassword => { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + x @ AuthPlugin::MysqlOldPassword => { + if self.inner.opts.secure_auth() { + return Err(DriverError::MysqlOldPasswordDisabled.into()); + } else { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + } + x @ AuthPlugin::MysqlClearPassword => { + if self.inner.opts.enable_cleartext_plugin() { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } else { + return Err(DriverError::CleartextPluginDisabled.into()); + } + } + x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce), + }; if let Some(plugin_data) = plugin_data { - self.write_struct(&plugin_data).await?; + self.write_struct(&plugin_data.into_owned()).await?; } else { self.write_packet(crate::BUFFER_POOL.get()).await?; } @@ -599,6 +616,14 @@ impl Conn { self.continue_caching_sha2_password_auth().await?; Ok(()) } + AuthPlugin::MysqlClearPassword => { + if self.inner.opts.enable_cleartext_plugin() { + self.continue_mysql_native_password_auth().await?; + Ok(()) + } else { + Err(DriverError::CleartextPluginDisabled.into()) + } + } AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin { name: String::from_utf8_lossy(name.as_ref()).to_string(), } diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 9fa107be..fd38affe 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -461,7 +461,7 @@ mod test { #[tokio::test] async fn should_connect() -> super::Result<()> { - let pool = Pool::new(get_opts()); + let pool = Pool::new(crate::Opts::from(get_opts())); pool.get_conn().await?.ping().await?; pool.disconnect().await?; Ok(()) diff --git a/src/error/mod.rs b/src/error/mod.rs index 81087260..3b1235c2 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -163,6 +163,9 @@ pub enum DriverError { #[error("Client asked for SSL but server does not have this capability")] NoClientSslFlagFromServer, + + #[error("mysql_clear_password must be enabled on the client side")] + CleartextPluginDisabled, } #[derive(Debug, Error)] diff --git a/src/lib.rs b/src/lib.rs index f28b8ce9..2f638e78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -418,6 +418,9 @@ #[cfg(feature = "nightly")] extern crate test; +#[cfg(feature = "derive")] +extern crate mysql_common; + pub use mysql_common::{constants as consts, params}; use std::sync::Arc; @@ -481,7 +484,7 @@ pub use mysql_common::packets::{ Gtids, Schema, SessionStateChange, SystemVariable, TransactionCharacteristics, TransactionState, Unsupported, }, - BinlogDumpFlags, Column, Interval, OkPacket, SessionStateInfo, Sid, + BinlogDumpFlags, Column, GnoInterval, OkPacket, SessionStateInfo, Sid, }; pub mod binlog { @@ -541,9 +544,9 @@ pub mod prelude { #[doc(inline)] pub use crate::queryable::Queryable; #[doc(inline)] - pub use mysql_common::row::convert::FromRow; + pub use mysql_common::prelude::FromRow; #[doc(inline)] - pub use mysql_common::value::convert::{ConvIr, FromValue, ToValue}; + pub use mysql_common::prelude::{FromValue, ToValue}; /// Everything that is a statement. /// diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 560b1cb9..18953ae8 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -410,6 +410,17 @@ pub(crate) struct MysqlOpts { /// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc). /// It makes MySQL return the FOUND rows instead of the AFFECTED rows. client_found_rows: bool, + + /// Enables Client-Side Cleartext Pluggable Authentication (defaults to `false`). + /// + /// Enables client to send passwords to the server as cleartext, without hashing or encryption + /// (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + enable_cleartext_plugin: bool, } /// Mysql connection options. @@ -747,6 +758,31 @@ impl Opts { self.inner.mysql_opts.client_found_rows } + /// Returns `true` if `mysql_clear_password` plugin support is enabled (defaults to `false`). + /// + /// `mysql_clear_password` enables client to send passwords to the server as cleartext, without + /// hashing or encryption (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + /// + /// # Connection URL + /// + /// Use `enable_cleartext_plugin` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?enable_cleartext_plugin=true")?; + /// assert!(opts.enable_cleartext_plugin()); + /// # Ok(()) } + /// ``` + pub fn enable_cleartext_plugin(&self) -> bool { + self.inner.mysql_opts.enable_cleartext_plugin + } + pub(crate) fn get_capabilities(&self) -> CapabilityFlags { let mut out = CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SECURE_CONNECTION @@ -797,6 +833,7 @@ impl Default for MysqlOpts { wait_timeout: None, secure_auth: true, client_found_rows: false, + enable_cleartext_plugin: false, } } } @@ -1053,6 +1090,32 @@ impl OptsBuilder { self.opts.client_found_rows = client_found_rows; self } + + /// Enables Client-Side Cleartext Pluggable Authentication (defaults to `false`). + /// + /// Enables client to send passwords to the server as cleartext, without hashing or encryption + /// (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + /// + /// # Connection URL + /// + /// Use `enable_cleartext_plugin` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?enable_cleartext_plugin=true")?; + /// assert!(opts.enable_cleartext_plugin()); + /// # Ok(()) } + /// ``` + pub fn enable_cleartext_plugin(mut self, enable_cleartext_plugin: bool) -> Self { + self.opts.enable_cleartext_plugin = enable_cleartext_plugin; + self + } } impl From for Opts { @@ -1235,6 +1298,16 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "enable_cleartext_plugin" { + match bool::from_str(&*value) { + Ok(parsed) => opts.enable_cleartext_plugin = parsed, + Err(_) => { + return Err(UrlError::InvalidParamValue { + param: key.to_string(), + value, + }); + } + } } else if key == "tcp_nodelay" { match bool::from_str(&*value) { Ok(value) => opts.tcp_nodelay = value, diff --git a/tests/exports.rs b/tests/exports.rs index c8b13137..92255dbb 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -4,12 +4,14 @@ use mysql_async::{ futures::{DisconnectPool, GetConn}, params, prelude::{ - BatchQuery, ConvIr, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, - StatementLike, ToValue, + BatchQuery, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, StatementLike, + ToValue, }, - BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError, - IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, PoolConstraints, - PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol, - Transaction, TxOpts, UrlError, Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, + BinaryProtocol, BinlogDumpFlags, BinlogRequest, Column, Conn, Deserialized, DriverError, Error, + FromRowError, FromValueError, GnoInterval, Gtids, IoError, IsolationLevel, OkPacket, Opts, + OptsBuilder, Params, ParseError, Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, + Schema, Serialized, ServerError, SessionStateChange, SessionStateInfo, Sid, SslOpts, Statement, + SystemVariable, TextProtocol, Transaction, TransactionCharacteristics, TransactionState, + TxOpts, Unsupported, UrlError, Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, }; From ad90c5254a6f421fb88e3210f14420600aa93e40 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sun, 9 Apr 2023 18:56:11 +0300 Subject: [PATCH 005/130] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 5c339b7b..8a068cf2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.31.3" +version = "0.32.0" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] From 134cbf8b5f31c9c91e1559f9c596b98da56234f3 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 12 Apr 2023 15:09:36 +0300 Subject: [PATCH 006/130] Implement Conn::change_user --- src/conn/mod.rs | 233 ++++++++++++++++++++++++++---- src/conn/pool/futures/get_conn.rs | 14 +- src/conn/pool/mod.rs | 25 +++- src/conn/pool/recycler.rs | 45 ++++-- src/conn/routines/change_user.rs | 58 ++++++++ src/conn/routines/mod.rs | 3 +- src/io/mod.rs | 7 +- src/lib.rs | 7 +- src/opts/mod.rs | 113 +++++++++++++++ 9 files changed, 451 insertions(+), 54 deletions(-) create mode 100644 src/conn/routines/change_user.rs diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 0df5a28c..3715a468 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -44,7 +44,7 @@ use crate::{ transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, }, - BinlogStream, InfileData, OptsBuilder, + BinlogStream, ChangeUserOpts, InfileData, OptsBuilder, }; use self::routines::Routine; @@ -102,6 +102,7 @@ struct ConnInner { pool: Option, pending_result: std::result::Result, ServerError>, tx_status: TxStatus, + reset_upon_returning_to_a_pool: bool, opts: Opts, last_io: Instant, wait_timeout: Duration, @@ -109,6 +110,7 @@ struct ConnInner { nonce: Vec, auth_plugin: AuthPlugin<'static>, auth_switched: bool, + server_key: Option>, /// Connection is already disconnected. pub(crate) disconnected: bool, /// One-time connection-level infile handler. @@ -126,6 +128,8 @@ impl fmt::Debug for ConnInner { .field("tx_status", &self.tx_status) .field("stream", &self.stream) .field("options", &self.opts) + .field("server_key", &self.server_key) + .field("auth_plugin", &self.auth_plugin) .finish() } } @@ -154,7 +158,9 @@ impl ConnInner { auth_plugin: AuthPlugin::MysqlNativePassword, auth_switched: false, disconnected: false, + server_key: None, infile_handler: None, + reset_upon_returning_to_a_pool: false, } } @@ -416,16 +422,33 @@ impl Conn { /// Returns true if io stream is encrypted. fn is_secure(&self) -> bool { #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] - if let Some(ref stream) = self.inner.stream { - stream.is_secure() - } else { - false + { + self.inner + .stream + .as_ref() + .map(|x| x.is_secure()) + .unwrap_or_default() } #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))] false } + /// Returns true if io stream is socket. + fn is_socket(&self) -> bool { + #[cfg(unix)] + { + self.inner + .stream + .as_ref() + .map(|x| x.is_socket()) + .unwrap_or_default() + } + + #[cfg(not(unix))] + false + } + /// Hacky way to move connection through &mut. `self` becomes unusable. fn take(&mut self) -> Conn { mem::replace(self, Conn::empty(Default::default())) @@ -663,16 +686,21 @@ impl Conn { let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes()); pass.as_mut().push(0); - if self.is_secure() { + if self.is_secure() || self.is_socket() { self.write_packet(pass).await?; } else { - self.write_bytes(&[0x02][..]).await?; - let packet = self.read_packet().await?; - let key = &packet[1..]; + if self.inner.server_key.is_none() { + self.write_bytes(&[0x02][..]).await?; + let packet = self.read_packet().await?; + self.inner.server_key = Some(packet[1..].to_vec()); + } for (i, byte) in pass.as_mut().iter_mut().enumerate() { *byte ^= self.inner.nonce[i % self.inner.nonce.len()]; } - let encrypted_pass = crypto::encrypt(&*pass, key); + let encrypted_pass = crypto::encrypt( + &*pass, + self.inner.server_key.as_deref().expect("unreachable"), + ); self.write_bytes(&*encrypted_pass).await?; }; self.drop_packet().await?; @@ -958,12 +986,13 @@ impl Conn { self.inner.last_io.elapsed() } - /// Executes `COM_RESET_CONNECTION` on `self`. + /// Executes [`COM_RESET_CONNECTION`][1]. /// - /// If server version is older than 5.7.2, then it'll reconnect. - pub async fn reset(&mut self) -> Result<()> { - let pool = self.inner.pool.clone(); - + /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3). + /// For older versions consider using [`Conn::change_user`]. + /// + /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html + pub async fn reset(&mut self) -> Result { let supports_com_reset_connection = if self.inner.is_mariadb { self.inner.version >= (10, 2, 4) } else { @@ -973,19 +1002,62 @@ impl Conn { if supports_com_reset_connection { self.routine(routines::ResetRoutine).await?; - } else { - let opts = self.inner.opts.clone(); - let old_conn = std::mem::replace(self, Conn::new(opts).await?); - // tidy up the old connection - old_conn.close_conn().await?; - }; + self.inner.stmt_cache.clear(); + self.inner.infile_handler = None; + } + + Ok(supports_com_reset_connection) + } + /// Executes [`COM_CHANGE_USER`][1]. + /// + /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that + /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4). + /// + /// ## Note + /// + /// * Using non-default `opts` for a pooled connection is discouraging. + /// * Connection options will be permanently updated. + /// + /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html + pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> { + // We'll kick this connection from a pool if opts are changed. + if opts != ChangeUserOpts::default() { + let mut opts_changed = false; + if let Some(user) = opts.user() { + opts_changed |= user != self.opts().user() + }; + if let Some(pass) = opts.pass() { + opts_changed |= pass != self.opts().pass() + }; + if let Some(db_name) = opts.db_name() { + opts_changed |= db_name != self.opts().db_name() + }; + if opts_changed { + if let Some(pool) = self.inner.pool.take() { + pool.cancel_connection(); + } + } + } + + let conn_opts = &mut self.inner.opts; + opts.update_opts(conn_opts); + self.routine(routines::ChangeUser).await?; self.inner.stmt_cache.clear(); self.inner.infile_handler = None; - self.inner.pool = pool; Ok(()) } + /// Resets the connection upon returning it to a pool. + /// + /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported. + async fn reset_for_pool(mut self) -> Result { + if !self.reset().await? { + self.change_user(Default::default()).await?; + } + Ok(self) + } + /// Requires that `self.inner.tx_status != TxStatus::None` async fn rollback_transaction(&mut self) -> Result<()> { debug_assert_ne!(self.inner.tx_status, TxStatus::None); @@ -1094,13 +1166,14 @@ mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN}; + use rand::Fill; use tokio::time::timeout; use std::time::Duration; use crate::{ - from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn, - Error, OptsBuilder, Pool, WhiteListFsHandler, + from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, + ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler, }; async fn gen_dummy_data() -> super::Result<()> { @@ -1471,9 +1544,115 @@ mod test { #[tokio::test] async fn should_reset_the_connection() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; - conn.exec_drop("SELECT ?", (1_u8,)).await?; - conn.reset().await?; - conn.exec_drop("SELECT ?", (1_u8,)).await?; + let max_execution_time = conn + .query_first::("SELECT @@max_execution_time") + .await? + .unwrap(); + + conn.exec_drop( + "SET SESSION max_execution_time = ?", + (max_execution_time + 1,), + ) + .await?; + + assert_eq!( + conn.query_first::("SELECT @@max_execution_time") + .await?, + Some(max_execution_time + 1) + ); + + if conn.reset().await? { + assert_eq!( + conn.query_first::("SELECT @@max_execution_time") + .await?, + Some(max_execution_time) + ); + } else { + assert_eq!( + conn.query_first::("SELECT @@max_execution_time") + .await?, + Some(max_execution_time + 1) + ); + } + + conn.disconnect().await?; + Ok(()) + } + + #[tokio::test] + async fn should_change_user() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await?; + let max_execution_time = conn + .query_first::("SELECT @@max_execution_time") + .await? + .unwrap(); + + conn.exec_drop( + "SET SESSION max_execution_time = ?", + (max_execution_time + 1,), + ) + .await?; + + assert_eq!( + conn.query_first::("SELECT @@max_execution_time") + .await?, + Some(max_execution_time + 1) + ); + + conn.change_user(Default::default()).await?; + assert_eq!( + conn.query_first::("SELECT @@max_execution_time") + .await?, + Some(max_execution_time) + ); + + let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) { + &["mysql_native_password", "caching_sha2_password"] + } else { + &["mysql_native_password"] + }; + + for plugin in plugins { + let mut conn2 = Conn::new(get_opts()).await.unwrap(); + + let mut rng = rand::thread_rng(); + let mut pass = [0u8; 10]; + pass.try_fill(&mut rng).unwrap(); + let pass: String = IntoIterator::into_iter(pass) + .map(|x| ((x % (123 - 97)) + 97) as char) + .collect(); + conn.query_drop("DROP USER IF EXISTS __mysql_async_test_user") + .await + .unwrap(); + conn.query_drop(format!( + "CREATE USER '__mysql_async_test_user'@'%' IDENTIFIED WITH {} BY {}", + plugin, + Value::from(pass.clone()).as_sql(false) + )) + .await + .unwrap(); + conn.query_drop("FLUSH PRIVILEGES").await.unwrap(); + + conn2 + .change_user( + ChangeUserOpts::default() + .with_db_name(None) + .with_user(Some("__mysql_async_test_user".into())) + .with_pass(Some(pass)), + ) + .await + .unwrap(); + assert_eq!( + conn2 + .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") + .await + .unwrap(), + Some((None, String::from("__mysql_async_test_user@localhost"))), + ); + + conn2.disconnect().await.unwrap(); + } + conn.disconnect().await?; Ok(()) } diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 73e8a999..8b21e685 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -69,16 +69,18 @@ pub struct GetConn { pub(crate) queue_id: Option, pub(crate) pool: Option, pub(crate) inner: GetConnInner, + reset_upon_returning_to_a_pool: bool, #[cfg(feature = "tracing")] span: Arc, } impl GetConn { - pub(crate) fn new(pool: &Pool) -> GetConn { + pub(crate) fn new(pool: &Pool, reset_upon_returning_to_a_pool: bool) -> GetConn { GetConn { queue_id: None, pool: Some(pool.clone()), inner: GetConnInner::New, + reset_upon_returning_to_a_pool, #[cfg(feature = "tracing")] span: Arc::new(debug_span!("mysql_async::get_conn")), } @@ -141,6 +143,8 @@ impl Future for GetConn { return match result { Ok(mut c) => { c.inner.pool = Some(pool); + c.inner.reset_upon_returning_to_a_pool = + self.reset_upon_returning_to_a_pool; Poll::Ready(Ok(c)) } Err(e) => { @@ -152,12 +156,14 @@ impl Future for GetConn { GetConnInner::Checking(ref mut f) => { let result = ready!(Pin::new(f).poll(cx)); match result { - Ok(mut checked_conn) => { + Ok(mut c) => { self.inner = GetConnInner::Done; let pool = self.pool_take(); - checked_conn.inner.pool = Some(pool); - return Poll::Ready(Ok(checked_conn)); + c.inner.pool = Some(pool); + c.inner.reset_upon_returning_to_a_pool = + self.reset_upon_returning_to_a_pool; + return Poll::Ready(Ok(c)); } Err(_) => { // Idling connection is broken. We'll drop it and try again. diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index fd38affe..d00c8157 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -232,7 +232,7 @@ impl Pool { /// Async function that resolves to `Conn`. pub fn get_conn(&self) -> GetConn { - GetConn::new(self) + GetConn::new(self, true) } /// Starts a new transaction. @@ -296,7 +296,7 @@ impl Pool { /// /// Decreases the exist counter since a broken or dropped connection should not count towards /// the total. - fn cancel_connection(&self) { + pub(super) fn cancel_connection(&self) { let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= 1; // we just enabled the creation of a new connection! @@ -573,20 +573,29 @@ mod test { let pool = Pool::new(opts); - "CREATE TEMPORARY TABLE tmp(id int)".ignore(&pool).await?; + "CREATE TABLE IF NOT EXISTS mysql.tmp(id int)" + .ignore(&pool) + .await?; + "DELETE FROM mysql.tmp".ignore(&pool).await?; let mut tx = pool.start_transaction(TxOpts::default()).await?; - tx.exec_batch("INSERT INTO tmp (id) VALUES (?)", vec![(1_u8,), (2_u8,)]) - .await?; - tx.exec_drop("SELECT * FROM tmp", ()).await?; + tx.exec_batch( + "INSERT INTO mysql.tmp (id) VALUES (?)", + vec![(1_u8,), (2_u8,)], + ) + .await?; + tx.exec_drop("SELECT * FROM mysql.tmp", ()).await?; drop(tx); let row_opt = pool .get_conn() .await? - .query_first("SELECT COUNT(*) FROM tmp") + .query_first("SELECT COUNT(*) FROM mysql.tmp") .await?; assert_eq!(row_opt, Some((0u8,))); - pool.get_conn().await?.query_drop("DROP TABLE tmp").await?; + pool.get_conn() + .await? + .query_drop("DROP TABLE mysql.tmp") + .await?; pool.disconnect().await?; Ok(()) } diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 2a704dbc..5a705868 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -28,6 +28,7 @@ pub(crate) struct Recycler { discard: FuturesUnordered>, discarded: usize, cleaning: FuturesUnordered>, + reset: FuturesUnordered>, // Option so that we have a way to send a "I didn't make a Conn after all" signal dropped: mpsc::UnboundedReceiver>, @@ -47,6 +48,7 @@ impl Recycler { discard: FuturesUnordered::new(), discarded: 0, cleaning: FuturesUnordered::new(), + reset: FuturesUnordered::new(), dropped, pool_opts, eof: false, @@ -60,6 +62,21 @@ impl Future for Recycler { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut close = self.inner.close.load(Ordering::Acquire); + macro_rules! conn_return { + ($self:ident, $conn:ident) => {{ + let mut exchange = $self.inner.exchange.lock().unwrap(); + if exchange.available.len() >= $self.pool_opts.active_bound() { + drop(exchange); + $self.discard.push($conn.close_conn().boxed()); + } else { + exchange.available.push_back($conn.into()); + if let Some(w) = exchange.waiting.pop() { + w.wake(); + } + } + }}; + } + macro_rules! conn_decision { ($self:ident, $conn:ident) => { if $conn.inner.stream.is_none() || $conn.inner.disconnected { @@ -69,17 +86,10 @@ impl Future for Recycler { $self.cleaning.push($conn.cleanup_for_pool().boxed()); } else if $conn.expired() || close { $self.discard.push($conn.close_conn().boxed()); + } else if $conn.inner.reset_upon_returning_to_a_pool { + $self.reset.push($conn.reset_for_pool().boxed()); } else { - let mut exchange = $self.inner.exchange.lock().unwrap(); - if exchange.available.len() >= $self.pool_opts.active_bound() { - drop(exchange); - $self.discard.push($conn.close_conn().boxed()); - } else { - exchange.available.push_back($conn.into()); - if let Some(w) = exchange.waiting.pop() { - w.wake(); - } - } + conn_return!($self, $conn); } }; } @@ -138,6 +148,21 @@ impl Future for Recycler { } } + // let's iterate through connections being successfully reset + loop { + match Pin::new(&mut self.reset).poll_next(cx) { + Poll::Pending | Poll::Ready(None) => break, + Poll::Ready(Some(Ok(conn))) => conn_return!(self, conn), + Poll::Ready(Some(Err(e))) => { + // an error during reset. + // replace with a new connection + self.discarded += 1; + // NOTE: we're discarding the error here + let _ = e; + } + } + } + // are there any torn-down connections for us to deal with? loop { match Pin::new(&mut self.discard).poll_next(cx) { diff --git a/src/conn/routines/change_user.rs b/src/conn/routines/change_user.rs new file mode 100644 index 00000000..2a110fd8 --- /dev/null +++ b/src/conn/routines/change_user.rs @@ -0,0 +1,58 @@ +use futures_core::future::BoxFuture; +use futures_util::FutureExt; +use mysql_common::{ + constants::{UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI}, + packets::{ComChangeUser, ComChangeUserMoreData}, +}; +#[cfg(feature = "tracing")] +use tracing::debug_span; + +use crate::Conn; + +use super::Routine; + +/// A routine that performs `COM_RESET_CONNECTION`. +#[derive(Debug, Copy, Clone)] +pub struct ChangeUser; + +impl Routine<()> for ChangeUser { + fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> { + #[cfg(feature = "tracing")] + let span = debug_span!( + "mysql_async::change_user", + mysql_async.connection.id = conn.id() + ); + + let com_change_user = ComChangeUser::new() + .with_user(conn.opts().user().map(|x| x.as_bytes())) + .with_database(conn.opts().db_name().map(|x| x.as_bytes())) + .with_auth_plugin_data( + conn.inner + .auth_plugin + .gen_data(conn.opts().pass(), &conn.inner.nonce) + .as_deref(), + ) + .with_more_data(Some( + ComChangeUserMoreData::new(if conn.inner.version >= (5, 5, 3) { + UTF8MB4_GENERAL_CI + } else { + UTF8_GENERAL_CI + }) + .with_auth_plugin(Some(conn.inner.auth_plugin.clone())) + .with_connect_attributes(None), + )) + .into_owned(); + + let fut = async move { + conn.write_command(&com_change_user).await?; + conn.inner.auth_switched = false; + conn.continue_auth().await?; + Ok(()) + }; + + #[cfg(feature = "tracing")] + let fut = instrument_result!(fut, span); + + fut.boxed() + } +} diff --git a/src/conn/routines/mod.rs b/src/conn/routines/mod.rs index 928bd771..80ecf1a5 100644 --- a/src/conn/routines/mod.rs +++ b/src/conn/routines/mod.rs @@ -2,8 +2,9 @@ use futures_core::future::BoxFuture; use crate::Conn; -pub use self::{exec::*, next_set::*, ping::*, prepare::*, query::*, reset::*}; +pub use self::{change_user::*, exec::*, next_set::*, ping::*, prepare::*, query::*, reset::*}; +mod change_user; mod exec; mod next_set; mod ping; diff --git a/src/io/mod.rs b/src/io/mod.rs index 6498b33e..d46b2dc3 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -153,7 +153,7 @@ impl Future for CheckTcpStream<'_> { } impl Endpoint { - #[cfg(all(any(feature = "native-tls-tls", feature = "rustls-tls"), unix))] + #[cfg(unix)] fn is_socket(&self) -> bool { match self { Self::Socket(_) => true, @@ -419,6 +419,11 @@ impl Stream { self.codec.as_ref().unwrap().get_ref().is_secure() } + #[cfg(unix)] + pub(crate) fn is_socket(&self) -> bool { + self.codec.as_ref().unwrap().get_ref().is_socket() + } + pub(crate) fn reset_seq_id(&mut self) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().reset_seq_id(); diff --git a/src/lib.rs b/src/lib.rs index 2f638e78..5d6d78b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,7 +191,7 @@ //! * [`Pool`] is a smart pointer – each clone will point to the same pool instance. //! * [`Pool`] is `Send + Sync + 'static` – feel free to pass it around. //! * use [`Pool::disconnect`] to gracefuly close the pool. -//! * [`Pool::new`] is lazy and won't assert server availability. +//! * ⚠️ [`Pool::new`] is lazy and won't assert server availability. //! //! # Transaction //! @@ -470,8 +470,9 @@ pub use self::opts::ClientIdentity; #[doc(inline)] pub use self::opts::{ - Opts, OptsBuilder, PoolConstraints, PoolOpts, SslOpts, DEFAULT_INACTIVE_CONNECTION_TTL, - DEFAULT_POOL_CONSTRAINTS, DEFAULT_STMT_CACHE_SIZE, DEFAULT_TTL_CHECK_INTERVAL, + ChangeUserOpts, Opts, OptsBuilder, PoolConstraints, PoolOpts, SslOpts, + DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_POOL_CONSTRAINTS, DEFAULT_STMT_CACHE_SIZE, + DEFAULT_TTL_CHECK_INTERVAL, }; #[doc(inline)] diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 18953ae8..3d7a2800 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -21,6 +21,7 @@ use url::{Host, Url}; use std::{ borrow::Cow, convert::TryFrom, + fmt, net::{Ipv4Addr, Ipv6Addr}, path::Path, str::FromStr, @@ -1132,6 +1133,118 @@ impl From for Opts { } } +/// [`COM_CHANGE_USER`][1] options. +/// +/// Connection [`Opts`] are going to be updated accordingly upon `COM_CHANGE_USER`. +/// +/// [`Opts`] won't be updated by default, because default `ChangeUserOpts` will reuse +/// connection's `user`, `pass` and `db_name`. +/// +/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html +#[derive(Clone, Eq, PartialEq)] +pub struct ChangeUserOpts { + user: Option>, + pass: Option>, + db_name: Option>, +} + +impl ChangeUserOpts { + pub(crate) fn update_opts(self, opts: &mut Opts) { + if self.user.is_none() && self.pass.is_none() && self.db_name.is_none() { + return; + } + + let mut builder = OptsBuilder::from_opts(opts.clone()); + + if let Some(user) = self.user { + builder = builder.user(user); + } + + if let Some(pass) = self.pass { + builder = builder.pass(pass); + } + + if let Some(db_name) = self.db_name { + builder = builder.db_name(db_name); + } + + *opts = Opts::from(builder); + } + + /// Creates change user options that'll reuse connection options. + pub fn new() -> Self { + Self { + user: None, + pass: None, + db_name: None, + } + } + + /// Set [`Opts::user`] to the given value. + pub fn with_user(mut self, user: Option) -> Self { + self.user = Some(user); + self + } + + /// Set [`Opts::pass`] to the given value. + pub fn with_pass(mut self, pass: Option) -> Self { + self.pass = Some(pass); + self + } + + /// Set [`Opts::db_name`] to the given value. + pub fn with_db_name(mut self, db_name: Option) -> Self { + self.db_name = Some(db_name); + self + } + + /// Returns user. + /// + /// * if `None` then `self` does not meant to change user + /// * if `Some(None)` then `self` will clear user + /// * if `Some(Some(_))` then `self` will change user + pub fn user(&self) -> Option> { + self.user.as_ref().map(|x| x.as_deref()) + } + + /// Returns password. + /// + /// * if `None` then `self` does not meant to change password + /// * if `Some(None)` then `self` will clear password + /// * if `Some(Some(_))` then `self` will change password + pub fn pass(&self) -> Option> { + self.pass.as_ref().map(|x| x.as_deref()) + } + + /// Returns database name. + /// + /// * if `None` then `self` does not meant to change database name + /// * if `Some(None)` then `self` will clear database name + /// * if `Some(Some(_))` then `self` will change database name + pub fn db_name(&self) -> Option> { + self.db_name.as_ref().map(|x| x.as_deref()) + } +} + +impl Default for ChangeUserOpts { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for ChangeUserOpts { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChangeUserOpts") + .field("user", &self.user) + .field( + "pass", + &self.pass.as_ref().map(|x| x.as_ref().map(|_| "...")), + ) + .field("db_name", &self.db_name) + .finish() + } +} + fn get_opts_user_from_url(url: &Url) -> Option { let user = url.username(); if !user.is_empty() { From ec9a15b655698fb4a404585b180579fc79b46a10 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 12 Apr 2023 15:32:18 +0300 Subject: [PATCH 007/130] Fix tests --- src/conn/mod.rs | 115 ++++++++++++++++++++++++------------------- src/conn/pool/mod.rs | 2 +- 2 files changed, 65 insertions(+), 52 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 3715a468..4f0500f1 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1544,34 +1544,28 @@ mod test { #[tokio::test] async fn should_reset_the_connection() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; - let max_execution_time = conn - .query_first::("SELECT @@max_execution_time") - .await? - .unwrap(); - conn.exec_drop( - "SET SESSION max_execution_time = ?", - (max_execution_time + 1,), - ) - .await?; + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + + conn.query_drop("SET @foo = 'foo'").await?; assert_eq!( - conn.query_first::("SELECT @@max_execution_time") - .await?, - Some(max_execution_time + 1) + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", ); if conn.reset().await? { assert_eq!( - conn.query_first::("SELECT @@max_execution_time") - .await?, - Some(max_execution_time) + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL ); } else { assert_eq!( - conn.query_first::("SELECT @@max_execution_time") - .await?, - Some(max_execution_time + 1) + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", ); } @@ -1582,28 +1576,22 @@ mod test { #[tokio::test] async fn should_change_user() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; - let max_execution_time = conn - .query_first::("SELECT @@max_execution_time") - .await? - .unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); - conn.exec_drop( - "SET SESSION max_execution_time = ?", - (max_execution_time + 1,), - ) - .await?; + conn.query_drop("SET @foo = 'foo'").await?; assert_eq!( - conn.query_first::("SELECT @@max_execution_time") - .await?, - Some(max_execution_time + 1) + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", ); conn.change_user(Default::default()).await?; assert_eq!( - conn.query_first::("SELECT @@max_execution_time") - .await?, - Some(max_execution_time) + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL ); let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) { @@ -1613,42 +1601,67 @@ mod test { }; for plugin in plugins { - let mut conn2 = Conn::new(get_opts()).await.unwrap(); - let mut rng = rand::thread_rng(); let mut pass = [0u8; 10]; pass.try_fill(&mut rng).unwrap(); let pass: String = IntoIterator::into_iter(pass) .map(|x| ((x % (123 - 97)) + 97) as char) .collect(); - conn.query_drop("DROP USER IF EXISTS __mysql_async_test_user") + + conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'") .await .unwrap(); - conn.query_drop(format!( - "CREATE USER '__mysql_async_test_user'@'%' IDENTIFIED WITH {} BY {}", - plugin, - Value::from(pass.clone()).as_sql(false) - )) - .await - .unwrap(); conn.query_drop("FLUSH PRIVILEGES").await.unwrap(); + if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) { + if matches!(conn.server_version(), (5, 6, _)) { + conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password") + .await + .unwrap(); + conn.query_drop(format!( + "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})", + Value::from(pass.clone()).as_sql(false) + )) + .await + .unwrap(); + } else { + conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap(); + conn.query_drop(format!( + "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})", + Value::from(pass.clone()).as_sql(false) + )) + .await + .unwrap(); + } + } else { + conn.query_drop(format!( + "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}", + plugin, + Value::from(pass.clone()).as_sql(false) + )) + .await + .unwrap(); + }; + + conn.query_drop("FLUSH PRIVILEGES").await.unwrap(); + + let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap(); conn2 .change_user( ChangeUserOpts::default() .with_db_name(None) - .with_user(Some("__mysql_async_test_user".into())) + .with_user(Some("__mats".into())) .with_pass(Some(pass)), ) .await .unwrap(); - assert_eq!( - conn2 - .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") - .await - .unwrap(), - Some((None, String::from("__mysql_async_test_user@localhost"))), - ); + let (db, user) = conn2 + .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") + .await + .unwrap() + .unwrap(); + assert_eq!(db, None); + assert!(user.starts_with("__mats")); conn2.disconnect().await.unwrap(); } diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index d00c8157..84182e8e 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -666,7 +666,7 @@ mod test { let _ = conns.pop(); // then, wait for a bit to let the connection be reclaimed - sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(500)).await; // now check that we have the expected # of connections // this may look a little funky, but think of it this way: From 3763e929f873d4efd22ee24088ba605736119be4 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 12 Apr 2023 17:32:40 +0300 Subject: [PATCH 008/130] Bump mysql_common to 0.30.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c2d7f5d5..bd55f046 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ futures-sink = "0.3" lazy_static = "1" lru = "0.10.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.30", default-features = false, features = [ +mysql_common = { version = "0.30.1", default-features = false, features = [ "derive", ] } once_cell = "1.7.2" From bbf24b0006bae6364714cfdad5f20ffe13b6af51 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 12 Apr 2023 18:34:13 +0300 Subject: [PATCH 009/130] Bump dependencies --- Cargo.toml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bd55f046..1f99ad9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,13 +26,13 @@ mysql_common = { version = "0.30.1", default-features = false, features = [ "derive", ] } once_cell = "1.7.2" -pem = "1.0.1" +pem = "2.0.1" percent-encoding = "2.1.0" pin-project = "1.0.2" priority-queue = "1" serde = "1" serde_json = "1" -socket2 = "0.4.2" +socket2 = "0.5.2" thiserror = "1.0.4" tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } tokio-util = { version = "0.7.2", features = ["codec", "io"] } @@ -43,7 +43,7 @@ twox-hash = "1" url = "2.1" [dependencies.tokio-rustls] -version = "0.23.4" +version = "0.24.0" optional = true [dependencies.tokio-native-tls] @@ -55,7 +55,7 @@ version = "0.2" optional = true [dependencies.rustls] -version = "0.20.0" +version = "0.21.0" features = ["dangerous_configuration"] optional = true @@ -65,15 +65,16 @@ optional = true [dependencies.webpki] version = "0.22.0" +features = ["std"] optional = true [dependencies.webpki-roots] -version = "0.22.1" +version = "0.23.0" optional = true [dev-dependencies] tempfile = "3.1.0" -socket2 = { version = "0.4.0", features = ["all"] } +socket2 = { version = "0.5.2", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } rand = "0.8.0" From 411a84a61176ea56f85011e29f8583bb77df279b Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 13 Apr 2023 09:59:07 +0300 Subject: [PATCH 010/130] pool: Fix some connections not being properly reset --- src/conn/pool/mod.rs | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 84182e8e..dd2b8991 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -253,25 +253,6 @@ impl Pool { fn return_conn(&mut self, conn: Conn) { // NOTE: we're not in async context here, so we can't block or return NotReady // any and all cleanup work _has_ to be done in the spawned recycler - - // fast-path for when the connection is immediately ready to be reused - if conn.inner.stream.is_some() - && !conn.inner.disconnected - && !conn.expired() - && conn.inner.tx_status == TxStatus::None - && !conn.has_pending_result() - && !self.inner.close.load(atomic::Ordering::Acquire) - { - let mut exchange = self.inner.exchange.lock().unwrap(); - if exchange.available.len() < self.opts.pool_opts().active_bound() { - exchange.available.push_back(conn.into()); - if let Some(w) = exchange.waiting.pop() { - w.wake(); - } - return; - } - } - self.send_to_recycler(conn); } @@ -492,6 +473,9 @@ mod test { .map(|conn| conn.id()) .collect::>(); + // give some time to reset connections + sleep(Duration::from_millis(500)).await; + // get_conn should work if connection is available and alive pool.get_conn().await?; From 8c4b72a64b0710bb5f04af47cd28315c16d20e01 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 13 Apr 2023 09:59:42 +0300 Subject: [PATCH 011/130] recycler: assert that reset queue is exhausted on eof --- src/conn/pool/recycler.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 5a705868..fca5be91 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -201,7 +201,11 @@ impl Future for Recycler { // races on .exist let effectively_eof = close && self.inner.exchange.lock().unwrap().exist == 0; - if (self.eof || effectively_eof) && self.cleaning.is_empty() && self.discard.is_empty() { + if (self.eof || effectively_eof) + && self.cleaning.is_empty() + && self.discard.is_empty() + && self.reset.is_empty() + { // we know that all Pool handles have been dropped (self.dropped.poll returned None). // if this assertion fails, where are the remaining connections? From f23dca01289715e26d75722b481123c56fdd9ed8 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 13 Apr 2023 13:40:41 +0300 Subject: [PATCH 012/130] recycler: check for closed pool in conn_return! macro --- src/conn/pool/mod.rs | 2 +- src/conn/pool/recycler.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index dd2b8991..b974c491 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -26,7 +26,7 @@ use crate::{ conn::{pool::futures::*, Conn}, error::*, opts::{Opts, PoolOpts}, - queryable::transaction::{Transaction, TxOpts, TxStatus}, + queryable::transaction::{Transaction, TxOpts}, }; mod recycler; diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index fca5be91..1ea855c0 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -63,9 +63,9 @@ impl Future for Recycler { let mut close = self.inner.close.load(Ordering::Acquire); macro_rules! conn_return { - ($self:ident, $conn:ident) => {{ + ($self:ident, $conn:ident, $pool_is_closed: expr) => {{ let mut exchange = $self.inner.exchange.lock().unwrap(); - if exchange.available.len() >= $self.pool_opts.active_bound() { + if $pool_is_closed || exchange.available.len() >= $self.pool_opts.active_bound() { drop(exchange); $self.discard.push($conn.close_conn().boxed()); } else { @@ -89,7 +89,7 @@ impl Future for Recycler { } else if $conn.inner.reset_upon_returning_to_a_pool { $self.reset.push($conn.reset_for_pool().boxed()); } else { - conn_return!($self, $conn); + conn_return!($self, $conn, false); } }; } @@ -152,7 +152,7 @@ impl Future for Recycler { loop { match Pin::new(&mut self.reset).poll_next(cx) { Poll::Pending | Poll::Ready(None) => break, - Poll::Ready(Some(Ok(conn))) => conn_return!(self, conn), + Poll::Ready(Some(Ok(conn))) => conn_return!(self, conn, close), Poll::Ready(Some(Err(e))) => { // an error during reset. // replace with a new connection From 32c6f2a986789f97108502c2d0c755a089411b66 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 13 Apr 2023 14:18:07 +0300 Subject: [PATCH 013/130] Bump micro version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 1f99ad9d..efa1f5f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.32.0" +version = "0.32.1" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] From 7c6572d0d63fb7b19524f94258ea9e67c6b55674 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 14 Apr 2023 09:16:37 +0300 Subject: [PATCH 014/130] Add a way to opt-out of pooled connection reset --- src/conn/mod.rs | 7 ++++ src/conn/pool/mod.rs | 55 +++++++++++++++++++++++++-- src/conn/routines/change_user.rs | 2 +- src/opts/mod.rs | 65 ++++++++++++++++++++++++++++++-- 4 files changed, 122 insertions(+), 7 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 4f0500f1..bcd2d3ef 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -229,6 +229,13 @@ impl Conn { self.inner.last_ok_packet.as_ref() } + /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]). + /// + /// Only makes sense for pooled connections. + pub fn reset_connection(&mut self, reset_connection: bool) { + self.inner.reset_upon_returning_to_a_pool = reset_connection; + } + pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> { self.inner.stream_mut() } diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index b974c491..2dd0ca00 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -232,7 +232,8 @@ impl Pool { /// Async function that resolves to `Conn`. pub fn get_conn(&self) -> GetConn { - GetConn::new(self, true) + let reset_connection = self.opts.pool_opts().reset_connection(); + GetConn::new(self, reset_connection) } /// Starts a new transaction. @@ -382,7 +383,6 @@ mod test { future::{join_all, select, select_all, try_join_all}, try_join, FutureExt, }; - use mysql_common::row::Row; use tokio::time::{sleep, timeout}; use std::{ @@ -396,7 +396,7 @@ mod test { opts::PoolOpts, prelude::*, test_misc::get_opts, - PoolConstraints, TxOpts, + PoolConstraints, Row, TxOpts, Value, }; macro_rules! conn_ex_field { @@ -411,6 +411,55 @@ mod test { }; } + #[tokio::test] + async fn should_opt_out_of_connection_reset() -> super::Result<()> { + let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap()); + let opts = get_opts().pool_opts(pool_opts.clone()); + + let pool = Pool::new(opts.clone()); + + let mut conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + conn.query_drop("SET @foo = 'foo'").await?; + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + conn.query_drop("SET @foo = 'foo'").await?; + conn.reset_connection(false); + drop(conn); + + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + pool.disconnect().await.unwrap(); + + let pool = Pool::new(opts.pool_opts(pool_opts.with_reset_connection(false))); + conn = pool.get_conn().await.unwrap(); + conn.query_drop("SET @foo = 'foo'").await?; + drop(conn); + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + pool.disconnect().await + } + #[test] fn should_not_hang() -> super::Result<()> { pub struct Database { diff --git a/src/conn/routines/change_user.rs b/src/conn/routines/change_user.rs index 2a110fd8..28b51d4e 100644 --- a/src/conn/routines/change_user.rs +++ b/src/conn/routines/change_user.rs @@ -11,7 +11,7 @@ use crate::Conn; use super::Routine; -/// A routine that performs `COM_RESET_CONNECTION`. +/// A routine that performs `COM_CHANGE_USER`. #[derive(Debug, Copy, Clone)] pub struct ChangeUser; diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 3d7a2800..af7183fd 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -209,9 +209,15 @@ pub struct PoolOpts { constraints: PoolConstraints, inactive_connection_ttl: Duration, ttl_check_interval: Duration, + reset_connection: bool, } impl PoolOpts { + /// Calls `Self::default`. + pub fn new() -> Self { + Self::default() + } + /// Creates the default [`PoolOpts`] with the given constraints. pub fn with_constraints(mut self, constraints: PoolConstraints) -> Self { self.constraints = constraints; @@ -223,6 +229,50 @@ impl PoolOpts { self.constraints } + /// Sets whether to reset connection upon returning it to a pool (defaults to `true`). + /// + /// Default behavior increases reliability but comes with cons: + /// + /// * reset procedure removes all prepared statements, i.e. kills prepared statements cache + /// * connection reset is quite fast but requires additional client-server roundtrip + /// (might also requires requthentication for older servers) + /// + /// The purpose of the reset procedure is to: + /// + /// * rollback any opened transactions (`mysql_async` is able to do this without explicit reset) + /// * reset transaction isolation level + /// * reset session variables + /// * delete user variables + /// * remove temporary tables + /// * remove all PREPARE statement (this action kills prepared statements cache) + /// + /// So to encrease overall performance you can safely opt-out of the default behavior + /// if you are not willing to change the session state in an unpleasant way. + /// + /// It is also possible to selectively opt-in/out using [`Conn::reset_connection`]. + /// + /// # Connection URL + /// + /// You can use `reset_connection` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?reset_connection=false")?; + /// assert_eq!(opts.pool_opts().reset_connection(), false); + /// # Ok(()) } + /// ``` + pub fn with_reset_connection(mut self, reset_connection: bool) -> Self { + self.reset_connection = reset_connection; + self + } + + /// Returns the `reset_connection` value (see [`PoolOpts::with_reset_connection`]). + pub fn reset_connection(&self) -> bool { + self.reset_connection + } + /// Pool will recycle inactive connection if it is outside of the lower bound of the pool /// and if it is idling longer than this value (defaults to /// [`DEFAULT_INACTIVE_CONNECTION_TTL`]). @@ -309,6 +359,7 @@ impl Default for PoolOpts { constraints: DEFAULT_POOL_CONSTRAINTS, inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL, ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL, + reset_connection: true, } } } @@ -1340,7 +1391,6 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { Ok(value) => { opts.pool_opts = opts .pool_opts - .clone() .with_inactive_connection_ttl(Duration::from_secs(value)) } _ => { @@ -1355,7 +1405,6 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { Ok(value) => { opts.pool_opts = opts .pool_opts - .clone() .with_ttl_check_interval(Duration::from_secs(value)) } _ => { @@ -1421,6 +1470,16 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "reset_connection" { + match bool::from_str(&*value) { + Ok(parsed) => opts.pool_opts = opts.pool_opts.with_reset_connection(parsed), + Err(_) => { + return Err(UrlError::InvalidParamValue { + param: key.to_string(), + value, + }); + } + } } else if key == "tcp_nodelay" { match bool::from_str(&*value) { Ok(value) => opts.tcp_nodelay = value, @@ -1538,7 +1597,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } if let Some(pool_constraints) = PoolConstraints::new(pool_min, pool_max) { - opts.pool_opts = opts.pool_opts.clone().with_constraints(pool_constraints); + opts.pool_opts = opts.pool_opts.with_constraints(pool_constraints); } else { return Err(UrlError::InvalidPoolConstraints { min: pool_min, From 73dbb96185539154a7098b3d10d82f556e349799 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 14 Apr 2023 09:34:59 +0300 Subject: [PATCH 015/130] Add `Opts::setup` and `OptsBuilder::setup` --- src/conn/mod.rs | 37 +++++++++++++++++++++++++++++++++++++ src/opts/mod.rs | 22 ++++++++++++++++++++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index bcd2d3ef..3fbc6331 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -881,6 +881,16 @@ impl Conn { Ok(()) } + async fn run_setup_commands(&mut self) -> Result<()> { + let mut setup = self.inner.opts.setup().to_vec(); + + while let Some(query) = setup.pop() { + self.query_drop(query).await?; + } + + Ok(()) + } + /// Returns a future that resolves to [`Conn`]. pub fn new>(opts: T) -> crate::BoxFuture<'static, Conn> { let opts = opts.into(); @@ -913,6 +923,7 @@ impl Conn { conn.read_max_allowed_packet().await?; conn.read_wait_timeout().await?; conn.run_init_commands().await?; + conn.run_setup_commands().await?; Ok(conn) } @@ -1011,6 +1022,7 @@ impl Conn { self.routine(routines::ResetRoutine).await?; self.inner.stmt_cache.clear(); self.inner.infile_handler = None; + self.run_setup_commands().await?; } Ok(supports_com_reset_connection) @@ -1052,6 +1064,7 @@ impl Conn { self.routine(routines::ChangeUser).await?; self.inner.stmt_cache.clear(); self.inner.infile_handler = None; + self.run_setup_commands().await?; Ok(()) } @@ -1548,6 +1561,30 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_execute_setup_queries_on_reset() -> super::Result<()> { + let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]); + let mut conn = Conn::new(opts).await?; + + // initial run + let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + + // after reset + if conn.reset().await? { + result = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + } + + // after change user + conn.change_user(Default::default()).await?; + result = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + + conn.disconnect().await?; + Ok(()) + } + #[tokio::test] async fn should_reset_the_connection() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; diff --git a/src/opts/mod.rs b/src/opts/mod.rs index af7183fd..9d1e849b 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -403,9 +403,13 @@ pub(crate) struct MysqlOpts { /// (defaults to `wait_timeout`). conn_ttl: Option, - /// Commands to execute on each new database connection. + /// Commands to execute once new connection is established. init: Vec, + /// Commands to execute on new connection and every time + /// [`Conn::reset`] or [`Conn::change_user`] is invoked. + setup: Vec, + /// Number of prepared statements cached on the client side (per connection). Defaults to `10`. stmt_cache_size: usize, @@ -577,11 +581,17 @@ impl Opts { self.inner.mysql_opts.db_name.as_ref().map(AsRef::as_ref) } - /// Commands to execute on each new database connection. + /// Commands to execute once new connection is established. pub fn init(&self) -> &[String] { self.inner.mysql_opts.init.as_ref() } + /// Commands to execute on new connection and every time + /// [`Conn::reset`] or [`Conn::change_user`] is invoked. + pub fn setup(&self) -> &[String] { + self.inner.mysql_opts.setup.as_ref() + } + /// TCP keep alive timeout in milliseconds (defaults to `None`). /// /// # Connection URL @@ -871,6 +881,7 @@ impl Default for MysqlOpts { pass: None, db_name: None, init: vec![], + setup: vec![], tcp_keepalive: None, tcp_nodelay: true, local_infile_handler: None, @@ -1037,6 +1048,12 @@ impl OptsBuilder { self } + /// Defines setup queries. See [`Opts::setup`]. + pub fn setup>(mut self, setup: Vec) -> Self { + self.opts.setup = setup.into_iter().map(Into::into).collect(); + self + } + /// Defines `tcp_keepalive` option. See [`Opts::tcp_keepalive`]. pub fn tcp_keepalive>(mut self, tcp_keepalive: Option) -> Self { self.opts.tcp_keepalive = tcp_keepalive.map(Into::into); @@ -1654,6 +1671,7 @@ mod test { assert_eq!(url_opts.pass(), builder_opts.pass()); assert_eq!(url_opts.db_name(), builder_opts.db_name()); assert_eq!(url_opts.init(), builder_opts.init()); + assert_eq!(url_opts.setup(), builder_opts.setup()); assert_eq!(url_opts.tcp_keepalive(), builder_opts.tcp_keepalive()); assert_eq!(url_opts.tcp_nodelay(), builder_opts.tcp_nodelay()); assert_eq!(url_opts.pool_opts(), builder_opts.pool_opts()); From cd1ae0421eb80591f156273ecab651dfea7315ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Fern=C3=A1ndez?= Date: Thu, 20 Apr 2023 09:34:50 +0200 Subject: [PATCH 016/130] Inline ops that read settings (#2) Measures taken on 100 coldstart connections to a remote RDS mysql 5.7 in us-east-1, accessed from Spain. Units are ms. Before patch: Average: 222.01605959 Standard deviation: 20.28249939282523 After patch: Average: 152.8317912 Standard deviation: 22.382571070107467 Inlining queries make connections that require reading server settings run ~30% faster. --- src/conn/mod.rs | 60 ++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 3fbc6331..7289d21e 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -19,6 +19,7 @@ use mysql_common::{ OldEofPacket, ResultSetTerminator, SslRequest, }, proto::MySerialize, + row::Row, }; use std::{ @@ -918,10 +919,8 @@ impl Conn { conn.do_handshake_response().await?; conn.continue_auth().await?; conn.switch_to_compression()?; - conn.read_socket().await?; + conn.read_settings().await?; conn.reconnect_via_socket_if_needed().await?; - conn.read_max_allowed_packet().await?; - conn.read_wait_timeout().await?; conn.run_init_commands().await?; conn.run_setup_commands().await?; @@ -953,38 +952,53 @@ impl Conn { Ok(()) } - /// Reads and stores socket address inside the connection. + /// Configures the connection based on server settings. In particular: /// - /// Do nothing if socket address is already in [`Opts`] or if `prefer_socket` is `false`. - async fn read_socket(&mut self) -> Result<()> { - if self.inner.opts.prefer_socket() && self.inner.socket.is_none() { - let row_opt = self.query_internal("SELECT @@socket").await?; - self.inner.socket = row_opt.unwrap_or(None); + /// * It reads and stores socket address inside the connection unless if socket address is + /// already in [`Opts`] or if `prefer_socket` is `false`. + /// + /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`] + /// + /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`] + /// + async fn read_settings(&mut self) -> Result<()> { + let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none(); + let read_max_allowed_packet = self.opts().max_allowed_packet().is_none(); + let read_wait_timeout = self.opts().wait_timeout().is_none(); + + let settings: Option = if read_socket || read_max_allowed_packet || read_wait_timeout { + self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout") + .await? + } else { + None + }; + + // set socket inside the connection + if read_socket { + self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None); } - Ok(()) - } - /// Reads and stores `max_allowed_packet` in the connection. - async fn read_max_allowed_packet(&mut self) -> Result<()> { - let max_allowed_packet = if let Some(value) = self.opts().max_allowed_packet() { - Some(value) + // set max_allowed_packet + let max_allowed_packet = if read_max_allowed_packet { + settings + .as_ref() + .map(|s| s.get("@@max_allowed_packet")) + .unwrap() } else { - self.query_internal("SELECT @@max_allowed_packet").await? + self.opts().max_allowed_packet() }; if let Some(stream) = self.inner.stream.as_mut() { stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET)); } - Ok(()) - } - /// Reads and stores `wait_timeout` in the connection. - async fn read_wait_timeout(&mut self) -> Result<()> { - let wait_timeout = if let Some(value) = self.opts().wait_timeout() { - Some(value) + // set read_wait_timeout + let wait_timeout = if read_wait_timeout { + settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap() } else { - self.query_internal("SELECT @@wait_timeout").await? + self.opts().wait_timeout() }; self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64); + Ok(()) } From 48608a103379f3ee71c7b5fd7ba46e11a0241340 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 20 Apr 2023 20:42:15 +0300 Subject: [PATCH 017/130] Use rust flate2 backend on default-rustls feature (fix #244) --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index efa1f5f8..1d11b6b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,7 +89,7 @@ default = [ "native-tls-tls", ] default-rustls = [ - "flate2/zlib", + "flate2/rust_backend", "mysql_common/bigdecimal", "mysql_common/rust_decimal", "mysql_common/time", From 485c7b716a4abceacf2d0113451a1816920d3172 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 20 Apr 2023 20:43:00 +0300 Subject: [PATCH 018/130] Fix pool::test::should_reconnect --- src/conn/pool/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 2dd0ca00..ec6e301b 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -523,7 +523,7 @@ mod test { .collect::>(); // give some time to reset connections - sleep(Duration::from_millis(500)).await; + sleep(Duration::from_millis(1000)).await; // get_conn should work if connection is available and alive pool.get_conn().await?; From c1c8081256319572f18af529b9aa35bad17c19b9 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 22 Apr 2023 12:43:46 +0300 Subject: [PATCH 019/130] Remove mysql_common/derive from the set of enabled mysql_common features --- Cargo.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1d11b6b5..29cbfe7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,9 +22,7 @@ futures-sink = "0.3" lazy_static = "1" lru = "0.10.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.30.1", default-features = false, features = [ - "derive", -] } +mysql_common = { version = "0.30", default-features = false } once_cell = "1.7.2" pem = "2.0.1" percent-encoding = "2.1.0" From a9e2278f3852b4dcf2da2a908d013e643ee114f3 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 22 Apr 2023 12:44:44 +0300 Subject: [PATCH 020/130] Update README.md --- README.md | 4 +++- src/lib.rs | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 215bb4b4..d469ddb7 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,8 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", features = ["tracing"] } ``` +* `derive` – enables `mysql_commom/derive` feature + [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features ## TLS/SSL Support @@ -190,7 +192,7 @@ Please note: * [`Pool`] is a smart pointer – each clone will point to the same pool instance. * [`Pool`] is `Send + Sync + 'static` – feel free to pass it around. * use [`Pool::disconnect`] to gracefuly close the pool. -* [`Pool::new`] is lazy and won't assert server availability. +* ⚠️ [`Pool::new`] is lazy and won't assert server availability. ## Transaction diff --git a/src/lib.rs b/src/lib.rs index 5d6d78b7..9c5ff836 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,8 @@ //! mysql_async = { version = "*", features = ["tracing"] } //! ``` //! +//! * `derive` – enables `mysql_commom/derive` feature +//! //! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features //! //! # TLS/SSL Support From e6bbf7c776374d0067c0164245e9158b5c75f7e7 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 22 Apr 2023 12:44:54 +0300 Subject: [PATCH 021/130] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 29cbfe7f..42c9447c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.32.1" +version = "0.32.2" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] From bf4fe8cd1501056d0c926cb7a7715d1e650f0c7a Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Mon, 24 Apr 2023 19:20:59 +0200 Subject: [PATCH 022/130] Replace crate priority_queue with keyed_priority_queue Fixes #247 --- Cargo.toml | 2 +- src/conn/pool/mod.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 42c9447c..251ef374 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ flate2 = { version = "1.0", default-features = false } futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" +keyed_priority_queue = "0.4" lazy_static = "1" lru = "0.10.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } @@ -27,7 +28,6 @@ once_cell = "1.7.2" pem = "2.0.1" percent-encoding = "2.1.0" pin-project = "1.0.2" -priority-queue = "1" serde = "1" serde_json = "1" socket2 = "0.5.2" diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index ec6e301b..9fc29e71 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -7,7 +7,7 @@ // modified, or distributed except according to those terms. use futures_util::FutureExt; -use priority_queue::PriorityQueue; +use keyed_priority_queue::KeyedPriorityQueue; use tokio::sync::mpsc; use std::{ @@ -92,7 +92,7 @@ impl Exchange { #[derive(Default, Debug)] struct Waitlist { - queue: PriorityQueue, + queue: KeyedPriorityQueue, } impl Waitlist { From 668a7e48d30584810dc55a37f8acf8b19afc0095 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 4 Aug 2023 15:36:53 +0300 Subject: [PATCH 023/130] Do not read unnecessary settings in Conn::read_settings --- src/conn/mod.rs | 139 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 111 insertions(+), 28 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 7289d21e..356c5658 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -55,6 +55,8 @@ pub mod pool; pub mod routines; pub mod stmt_cache; +const DEFAULT_WAIT_TIMEOUT: usize = 28800; + /// Helper that asynchronously disconnects the givent connection on the default tokio executor. fn disconnect(mut conn: Conn) { let disconnected = conn.inner.disconnected; @@ -962,42 +964,123 @@ impl Conn { /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`] /// async fn read_settings(&mut self) -> Result<()> { - let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none(); - let read_max_allowed_packet = self.opts().max_allowed_packet().is_none(); - let read_wait_timeout = self.opts().wait_timeout().is_none(); + enum Action { + Load(Cfg), + Apply(CfgData), + } - let settings: Option = if read_socket || read_max_allowed_packet || read_wait_timeout { - self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout") - .await? - } else { - None - }; + enum CfgData { + MaxAllowedPacket(usize), + WaitTimeout(usize), + } - // set socket inside the connection - if read_socket { - self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None); + impl CfgData { + fn apply(&self, conn: &mut Conn) { + match self { + Self::MaxAllowedPacket(value) => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet(*value); + } + } + Self::WaitTimeout(value) => { + conn.inner.wait_timeout = Duration::from_secs(*value as u64); + } + } + } } - // set max_allowed_packet - let max_allowed_packet = if read_max_allowed_packet { - settings - .as_ref() - .map(|s| s.get("@@max_allowed_packet")) - .unwrap() - } else { - self.opts().max_allowed_packet() - }; - if let Some(stream) = self.inner.stream.as_mut() { - stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET)); + enum Cfg { + Socket, + MaxAllowedPacket, + WaitTimeout, } - // set read_wait_timeout - let wait_timeout = if read_wait_timeout { - settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap() + impl Cfg { + const fn name(&self) -> &'static str { + match self { + Self::Socket => "@@socket", + Self::MaxAllowedPacket => "@@max_allowed_packet", + Self::WaitTimeout => "@@wait_timeout", + } + } + + fn apply(&self, conn: &mut Conn, value: Option) { + match self { + Cfg::Socket => { + conn.inner.socket = value.map(crate::from_value).flatten(); + } + Cfg::MaxAllowedPacket => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet( + value + .map(crate::from_value) + .flatten() + .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET), + ); + } + } + Cfg::WaitTimeout => { + conn.inner.wait_timeout = Duration::from_secs( + value + .map(crate::from_value) + .flatten() + .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64, + ); + } + } + } + } + + let mut actions = vec![ + if let Some(x) = self.opts().max_allowed_packet() { + Action::Apply(CfgData::MaxAllowedPacket(x)) + } else { + Action::Load(Cfg::MaxAllowedPacket) + }, + if let Some(x) = self.opts().wait_timeout() { + Action::Apply(CfgData::WaitTimeout(x)) + } else { + Action::Load(Cfg::WaitTimeout) + }, + ]; + + if self.inner.opts.prefer_socket() && self.inner.socket.is_none() { + actions.push(Action::Load(Cfg::Socket)) + } + + let loads = actions + .iter() + .filter_map(|x| match x { + Action::Load(x) => Some(x), + Action::Apply(_) => None, + }) + .collect::>(); + + let loaded = if !loads.is_empty() { + let query = loads + .iter() + .zip(std::iter::once(' ').chain(std::iter::repeat(','))) + .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| { + acc.push(prefix); + acc.push_str(cfg.name()); + acc + }); + + self.query_internal::(query) + .await? + .map(|row| row.unwrap()) + .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()]) } else { - self.opts().wait_timeout() + vec![] }; - self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64); + let mut loaded = loaded.into_iter(); + + for action in actions { + match action { + Action::Load(cfg) => cfg.apply(self, loaded.next()), + Action::Apply(cfg) => cfg.apply(self), + } + } Ok(()) } From 421ad9d47b6bf5a7256c7fbc5bab9cca05681f74 Mon Sep 17 00:00:00 2001 From: David Krasnitsky Date: Sun, 20 Aug 2023 13:41:16 +0300 Subject: [PATCH 024/130] Fixed spelling mistake in doc-comments I found a spelling-mistake in the doc comments where `Executes` is misspelled as `Exectues` of the following methods: `exec_batch()` `exec()` `exec_first()` `exec_map()` `exec_fold()` `exec_drop()` --- src/queryable/mod.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 79062ec6..40edb0f6 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -273,7 +273,7 @@ pub trait Queryable: Send { async move { self.query_iter(query).await?.drop_result().await }.boxed() } - /// Exectues the given statement for each item in the given params iterator. + /// Executes the given statement for each item in the given params iterator. /// /// It'll prepare `stmt` (once), if necessary. fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> @@ -283,7 +283,7 @@ pub trait Queryable: Send { I::IntoIter: Send, P: Into + Send; - /// Exectues the given statement and collects the first result set. + /// Executes the given statement and collects the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -307,7 +307,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given statement and returns the first row of the first result set. + /// Executes the given statement and returns the first row of the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -335,7 +335,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given stmt and maps each row of the first result set. + /// Executes the given stmt and maps each row of the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -367,7 +367,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given stmt and folds the first result set to a signel value. + /// Executes the given stmt and folds the first result set to a signel value. /// /// It'll prepare `stmt`, if necessary. /// @@ -399,7 +399,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given statement and drops the result. + /// Executes the given statement and drops the result. fn exec_drop<'a: 'b, 'b, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, ()> where S: StatementLike + 'b, From 2a716a6521428d62ab3e25dfae54a72257428f42 Mon Sep 17 00:00:00 2001 From: Marcelo Altmann Date: Fri, 8 Sep 2023 15:51:04 -0300 Subject: [PATCH 025/130] Usability improvement: leading forward slash empty database A leading forward slash in the connection URL makes get_opts_db_name_from_url set db_name to Some("") rather than None. This makes code pulling the db from config requiring extra checks as db_name.is_some() will evaluate to true. Empty database parameter is invalid for MySQL. --- src/opts/mod.rs | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 9d1e849b..976fa9e0 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -1340,11 +1340,14 @@ fn get_opts_pass_from_url(url: &Url) -> Option { fn get_opts_db_name_from_url(url: &Url) -> Option { if let Some(mut segments) = url.path_segments() { - segments.next().map(|db_name| { - percent_decode(db_name.as_ref()) - .decode_utf8_lossy() - .into_owned() - }) + segments + .next() + .map(|db_name| { + percent_decode(db_name.as_ref()) + .decode_utf8_lossy() + .into_owned() + }) + .and_then(|db| if db.is_empty() { None } else { Some(db) }) } else { None } @@ -1813,4 +1816,18 @@ mod test { let opts = Opts::from_url("mysql://localhost/foo?compression=9").unwrap(); assert_eq!(opts.compression(), Some(crate::Compression::new(9))); } + + #[test] + fn test_builder_eq_url_empty_db() { + let builder = super::OptsBuilder::default(); + let builder_opts = Opts::from(builder); + + let url: &str = "mysql://iq-controller@localhost"; + let url_opts = super::Opts::from_str(url).unwrap(); + assert_eq!(url_opts.db_name(), builder_opts.db_name()); + + let url: &str = "mysql://iq-controller@localhost/"; + let url_opts = super::Opts::from_str(url).unwrap(); + assert_eq!(url_opts.db_name(), builder_opts.db_name()); + } } From 31d040f89078aa8bf973300823b0166859ac72a5 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 14 Sep 2023 11:16:50 +0300 Subject: [PATCH 026/130] get_opts_db_name_from_url: use Option::filter instead of and_then --- src/opts/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 976fa9e0..2b506650 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -1347,7 +1347,7 @@ fn get_opts_db_name_from_url(url: &Url) -> Option { .decode_utf8_lossy() .into_owned() }) - .and_then(|db| if db.is_empty() { None } else { Some(db) }) + .filter(|db| !db.is_empty()) } else { None } From da2193b3dfcc502734e0f46f55c57822cdb4c7c3 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 14 Sep 2023 11:30:40 +0300 Subject: [PATCH 027/130] Bump webpki (fix #256) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 251ef374..1e9f8500 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,12 +62,12 @@ version = "1.0.1" optional = true [dependencies.webpki] -version = "0.22.0" +version = ">=0.22.1" features = ["std"] optional = true [dependencies.webpki-roots] -version = "0.23.0" +version = "0.25.0" optional = true [dev-dependencies] From c5f620efa2292e3b11b98c4db595e6432575c2bc Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 14 Sep 2023 11:33:40 +0300 Subject: [PATCH 028/130] Bump `lru` and `pem` deps --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1e9f8500..69b16ac9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,11 +21,11 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lazy_static = "1" -lru = "0.10.0" +lru = "0.11.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } mysql_common = { version = "0.30", default-features = false } once_cell = "1.7.2" -pem = "2.0.1" +pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" serde = "1" From 815971fdf03102c15b944b770af35309b88333cc Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 14 Sep 2023 16:14:39 +0300 Subject: [PATCH 029/130] Fix build for updated webpki-roots --- src/io/tls/rustls_io.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index 654581d3..a976b961 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -23,7 +23,7 @@ impl Endpoint { } let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { OwnedTrustAnchor::from_subject_spki_name_constraints( ta.subject, ta.spki, @@ -56,7 +56,7 @@ impl Endpoint { let mut config = if let Some(identity) = ssl_opts.client_identity() { let (cert_chain, priv_key) = identity.load()?; - config_builder.with_single_cert(cert_chain, priv_key)? + config_builder.with_client_auth_cert(cert_chain, priv_key)? } else { config_builder.with_no_client_auth() }; From ec1a698e3590715a40ef198c41ee0affd2acfd61 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Thu, 31 Aug 2023 18:02:55 +0200 Subject: [PATCH 030/130] Option to set an absolute TTL for connections * The TTL, if set, forces connections to be disconnected, even overriding the minimum pool size constraint. * The TTL can be combined with a random jitter value to prevent all connections closing at the same time. * This enables gradual connection migration and regular connection recycling when CONN_RESET on return to pool is not desired. --- Cargo.toml | 2 +- src/conn/mod.rs | 8 ++ src/conn/pool/mod.rs | 51 +++++++++++- src/conn/pool/ttl_check_inerval.rs | 48 +++++++---- src/opts/mod.rs | 126 ++++++++++++++++++++++++++++- 5 files changed, 216 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 69b16ac9..06cc3bb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ once_cell = "1.7.2" pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" +rand = "0.8.5" serde = "1" serde_json = "1" socket2 = "0.5.2" @@ -74,7 +75,6 @@ optional = true tempfile = "3.1.0" socket2 = { version = "0.5.2", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } -rand = "0.8.0" [features] default = [ diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 356c5658..9dbbd5ed 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -107,6 +107,7 @@ struct ConnInner { tx_status: TxStatus, reset_upon_returning_to_a_pool: bool, opts: Opts, + ttl_deadline: Option, last_io: Instant, wait_timeout: Duration, stmt_cache: StmtCache, @@ -140,6 +141,7 @@ impl fmt::Debug for ConnInner { impl ConnInner { /// Constructs an empty connection. fn empty(opts: Opts) -> ConnInner { + let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline(); ConnInner { capabilities: opts.get_capabilities(), status: StatusFlags::empty(), @@ -157,6 +159,7 @@ impl ConnInner { stmt_cache: StmtCache::new(opts.stmt_cache_size()), socket: opts.socket().map(Into::into), opts, + ttl_deadline, nonce: Vec::default(), auth_plugin: AuthPlugin::MysqlNativePassword, auth_switched: false, @@ -1088,6 +1091,11 @@ impl Conn { /// Returns true if time since last IO exceeds `wait_timeout` /// (or `conn_ttl` if specified in opts). fn expired(&self) -> bool { + if let Some(deadline) = self.inner.ttl_deadline { + if Instant::now() > deadline { + return true; + } + } let ttl = self .inner .opts diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 9fc29e71..a984bee5 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -44,6 +44,15 @@ struct IdlingConn { } impl IdlingConn { + /// Returns true when this connection has a TTL and it elapsed. + fn expired(&self) -> bool { + self.conn + .inner + .ttl_deadline + .map(|t| Instant::now() > t) + .unwrap_or_default() + } + /// Returns duration elapsed since this connection is idling. fn elapsed(&self) -> Duration { self.since.elapsed() @@ -82,8 +91,11 @@ impl Exchange { // Spawn the Recycler. tokio::spawn(Recycler::new(pool_opts.clone(), inner.clone(), dropped)); - // Spawn the ttl check interval if `inactive_connection_ttl` isn't `0` - if pool_opts.inactive_connection_ttl() > Duration::from_secs(0) { + // Spawn the ttl check interval if `inactive_connection_ttl` isn't `0` or + // connections have an absolute TTL. + if pool_opts.inactive_connection_ttl() > Duration::ZERO + || pool_opts.abs_conn_ttl().is_some() + { tokio::spawn(TtlCheckInterval::new(pool_opts, inner.clone())); } } @@ -1012,6 +1024,41 @@ mod test { assert_eq!(0, waitlist.queue.len()); } + #[tokio::test] + async fn check_absolute_connection_ttl() -> super::Result<()> { + let constraints = PoolConstraints::new(1, 3).unwrap(); + let pool_opts = PoolOpts::default() + .with_constraints(constraints) + .with_inactive_connection_ttl(Duration::from_secs(99)) + .with_ttl_check_interval(Duration::from_secs(1)) + .with_abs_conn_ttl(Some(Duration::from_secs(2))); + + let pool = Pool::new(get_opts().pool_opts(pool_opts)); + + let conn_ttl0 = pool.get_conn().await?; + sleep(Duration::from_millis(1000)).await; + let conn_ttl1 = pool.get_conn().await?; + sleep(Duration::from_millis(1000)).await; + let conn_ttl2 = pool.get_conn().await?; + + drop(conn_ttl0); + drop(conn_ttl1); + drop(conn_ttl2); + assert_eq!(ex_field!(pool, exist), 3); + + sleep(Duration::from_millis(1500)).await; + assert_eq!(ex_field!(pool, exist), 2); + + sleep(Duration::from_millis(1000)).await; + assert_eq!(ex_field!(pool, exist), 1); + + // Go even below min pool size. + sleep(Duration::from_millis(1000)).await; + assert_eq!(ex_field!(pool, exist), 0); + + Ok(()) + } + #[cfg(feature = "nightly")] mod bench { use futures_util::future::{FutureExt, TryFutureExt}; diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index 0cb4f5f4..dde8e529 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -11,6 +11,7 @@ use pin_project::pin_project; use tokio::time::{self, Interval}; use std::{ + collections::VecDeque, future::Future, sync::{atomic::Ordering, Arc}, }; @@ -46,24 +47,41 @@ impl TtlCheckInterval { /// Perform the check. pub fn check_ttl(&self) { - let mut exchange = self.inner.exchange.lock().unwrap(); + let to_be_dropped = { + let mut exchange = self.inner.exchange.lock().unwrap(); - let num_idling = exchange.available.len(); - let num_to_drop = num_idling.saturating_sub(self.pool_opts.constraints().min()); + let num_to_drop = exchange + .available + .len() + .saturating_sub(self.pool_opts.constraints().min()); - for _ in 0..num_to_drop { - let idling_conn = exchange.available.pop_front().unwrap(); - if idling_conn.elapsed() > self.pool_opts.inactive_connection_ttl() { - assert!(idling_conn.conn.inner.pool.is_none()); - let inner = self.inner.clone(); - tokio::spawn(idling_conn.conn.disconnect().then(move |_| { - let mut exchange = inner.exchange.lock().unwrap(); - exchange.exist -= 1; - ok::<_, ()>(()) - })); - } else { - exchange.available.push_back(idling_conn); + let mut to_be_dropped = Vec::<_>::with_capacity(exchange.available.len()); + let mut kept_available = + VecDeque::<_>::with_capacity(self.pool_opts.constraints().max()); + + while let Some(conn) = exchange.available.pop_front() { + if conn.expired() { + to_be_dropped.push(conn); + } else if to_be_dropped.len() < num_to_drop + && conn.elapsed() > self.pool_opts.inactive_connection_ttl() + { + to_be_dropped.push(conn); + } else { + kept_available.push_back(conn); + } } + exchange.available = kept_available; + to_be_dropped + }; + + for idling_conn in to_be_dropped { + assert!(idling_conn.conn.inner.pool.is_none()); + let inner = self.inner.clone(); + tokio::spawn(idling_conn.conn.disconnect().then(move |_| { + let mut exchange = inner.exchange.lock().unwrap(); + exchange.exist -= 1; + ok::<_, ()>(()) + })); } } } diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 2b506650..c7bd7052 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -16,6 +16,7 @@ pub use native_tls_opts::ClientIdentity; pub use rustls_opts::ClientIdentity; use percent_encoding::percent_decode; +use rand::Rng; use url::{Host, Url}; use std::{ @@ -26,7 +27,7 @@ use std::{ path::Path, str::FromStr, sync::Arc, - time::Duration, + time::{Duration, Instant}, vec, }; @@ -209,6 +210,8 @@ pub struct PoolOpts { constraints: PoolConstraints, inactive_connection_ttl: Duration, ttl_check_interval: Duration, + abs_conn_ttl: Option, + abs_conn_ttl_jitter: Option, reset_connection: bool, } @@ -273,6 +276,49 @@ impl PoolOpts { self.reset_connection } + /// Sets an absolute TTL after which a connection is removed from the pool. + /// This may push the pool below the requested minimum pool size and is indepedent of the + /// idle TTL. + /// The absolute TTL is disabled by default. + /// Fractions of seconds are ignored. + pub fn with_abs_conn_ttl(mut self, ttl: Option) -> Self { + self.abs_conn_ttl = ttl; + self + } + + /// Optionally, the absolute TTL can be extended by a per-connection random amount + /// bounded by `jitter`. + /// Setting `abs_conn_ttl_jitter` without `abs_conn_ttl` has no effect. + /// Fractions of seconds are ignored. + pub fn with_abs_conn_ttl_jitter(mut self, jitter: Option) -> Self { + self.abs_conn_ttl_jitter = jitter; + self + } + + /// Returns the absolute TTL, if set. + pub fn abs_conn_ttl(&self) -> Option { + self.abs_conn_ttl + } + + /// Returns the absolute TTL's jitter bound, if set. + pub fn abs_conn_ttl_jitter(&self) -> Option { + self.abs_conn_ttl_jitter + } + + /// Returns a new deadline that's TTL (+ random jitter) in the future. + pub(crate) fn new_connection_ttl_deadline(&self) -> Option { + if let Some(ttl) = self.abs_conn_ttl { + let jitter = if let Some(jitter) = self.abs_conn_ttl_jitter { + Duration::from_secs(rand::thread_rng().gen_range(0..=jitter.as_secs())) + } else { + Duration::ZERO + }; + Some(Instant::now() + ttl + jitter) + } else { + None + } + } + /// Pool will recycle inactive connection if it is outside of the lower bound of the pool /// and if it is idling longer than this value (defaults to /// [`DEFAULT_INACTIVE_CONNECTION_TTL`]). @@ -359,6 +405,8 @@ impl Default for PoolOpts { constraints: DEFAULT_POOL_CONSTRAINTS, inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL, ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL, + abs_conn_ttl: None, + abs_conn_ttl_jitter: None, reset_connection: true, } } @@ -662,6 +710,49 @@ impl Opts { self.inner.mysql_opts.conn_ttl } + /// The pool will close a connection when this absolute TTL has elapsed. + /// Disabled by default. + /// + /// Enables forced recycling and migration of connections in a guaranteed timeframe. + /// This TTL bypasses pool constraints and an idle pool can go below the min size. + /// + /// # Connection URL + /// + /// You can use `abs_conn_ttl` URL parameter to set this value (in seconds). E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?abs_conn_ttl=86400")?; + /// assert_eq!(opts.abs_conn_ttl(), Some(Duration::from_secs(24 * 60 * 60))); + /// # Ok(()) } + /// ``` + pub fn abs_conn_ttl(&self) -> Option { + self.inner.mysql_opts.pool_opts.abs_conn_ttl + } + + /// Upper bound of a random value added to the absolute TTL, if enabled. + /// Disabled by default. + /// + /// Should be used to prevent connections from closing at the same time. + /// + /// # Connection URL + /// + /// You can use `abs_conn_ttl_jitter` URL parameter to set this value (in seconds). E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?abs_conn_ttl=7200&abs_conn_ttl_jitter=3600")?; + /// assert_eq!(opts.abs_conn_ttl_jitter(), Some(Duration::from_secs(60 * 60))); + /// # Ok(()) } + /// ``` + pub fn abs_conn_ttl_jitter(&self) -> Option { + self.inner.mysql_opts.pool_opts.abs_conn_ttl_jitter + } + /// Number of prepared statements cached on the client side (per connection). Defaults to /// [`DEFAULT_STMT_CACHE_SIZE`]. /// @@ -1444,6 +1535,34 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "abs_conn_ttl" { + match u64::from_str(&*value) { + Ok(value) => { + opts.pool_opts = opts + .pool_opts + .with_abs_conn_ttl(Some(Duration::from_secs(value))) + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "abs_conn_ttl".into(), + value, + }); + } + } + } else if key == "abs_conn_ttl_jitter" { + match u64::from_str(&*value) { + Ok(value) => { + opts.pool_opts = opts + .pool_opts + .with_abs_conn_ttl_jitter(Some(Duration::from_secs(value))) + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "abs_conn_ttl_jitter".into(), + value, + }); + } + } } else if key == "tcp_keepalive" { match u32::from_str(&*value) { Ok(value) => opts.tcp_keepalive = Some(value), @@ -1679,6 +1798,11 @@ mod test { assert_eq!(url_opts.tcp_nodelay(), builder_opts.tcp_nodelay()); assert_eq!(url_opts.pool_opts(), builder_opts.pool_opts()); assert_eq!(url_opts.conn_ttl(), builder_opts.conn_ttl()); + assert_eq!(url_opts.abs_conn_ttl(), builder_opts.abs_conn_ttl()); + assert_eq!( + url_opts.abs_conn_ttl_jitter(), + builder_opts.abs_conn_ttl_jitter() + ); assert_eq!(url_opts.stmt_cache_size(), builder_opts.stmt_cache_size()); assert_eq!(url_opts.ssl_opts(), builder_opts.ssl_opts()); assert_eq!(url_opts.prefer_socket(), builder_opts.prefer_socket()); From 439fec4cae870ec768d47a0b2564c32ea84e0ffe Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 23 Oct 2023 18:32:02 +0300 Subject: [PATCH 031/130] Introduce `BinlogStreamRequest` --- src/conn/mod.rs | 130 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 109 insertions(+), 21 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 9dbbd5ed..d2a11646 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -14,9 +14,10 @@ use mysql_common::{ crypto, io::ParseBuf, packets::{ - binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, - HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, - OldEofPacket, ResultSetTerminator, SslRequest, + binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, BinlogDumpFlags, + ComRegisterSlave, CommonOkPacket, ErrPacket, HandshakePacket, HandshakeResponse, OkPacket, + OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, Sid, + SslRequest, }, proto::MySerialize, row::Row, @@ -1260,12 +1261,9 @@ impl Conn { Ok(self) } - async fn register_as_slave(&mut self, server_id: u32) -> Result<()> { - use mysql_common::packets::ComRegisterSlave; - + async fn register_as_slave(&mut self, com_register_slave: ComRegisterSlave<'_>) -> Result<()> { self.query_drop("SET @master_binlog_checksum='ALL'").await?; - self.write_command(&ComRegisterSlave::new(server_id)) - .await?; + self.write_command(&com_register_slave).await?; // Server will respond with OK. self.read_packet().await?; @@ -1273,19 +1271,109 @@ impl Conn { Ok(()) } - async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> { - self.register_as_slave(request.server_id()).await?; - self.write_command(&request.as_cmd()).await?; + async fn request_binlog(&mut self, request: BinlogStreamRequest<'_>) -> Result<()> { + self.register_as_slave(request.register_slave).await?; + self.write_command(&request.binlog_request.as_cmd()).await?; Ok(()) } - pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result { + /// Turns this connection into a binlog stream. + /// + /// You can use SHOW BINARY LOGS to get the current logfile and position from the master. + /// If the request’s filename is empty, the server will send the binlog-stream of the first known binlog. + pub async fn get_binlog_stream( + mut self, + request: BinlogStreamRequest<'_>, + ) -> Result { self.request_binlog(request).await?; Ok(BinlogStream::new(self)) } } +/// Binlog stream request builder. +pub struct BinlogStreamRequest<'a> { + binlog_request: BinlogRequest<'a>, + register_slave: ComRegisterSlave<'a>, +} + +impl<'a> BinlogStreamRequest<'a> { + /// Creates a new request with the given slave server id. + pub fn new(server_id: u32) -> Self { + Self { + binlog_request: BinlogRequest::new(server_id), + register_slave: ComRegisterSlave::new(server_id), + } + } + + /// Enables GTID-based replication (disabled by default). + pub fn with_gtid(mut self) -> Self { + self.binlog_request = self.binlog_request.with_use_gtid(true); + self + } + + /// Enables `NON_BLOCK` flag. Stream will be terminated as soon as there are no events. + pub fn with_non_blocking(mut self) -> Self { + self.binlog_request = self + .binlog_request + .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK); + self + } + + /// Sets the filename of the binlog on the master (try `SHOW BINARY LOGS`). + pub fn with_filename(mut self, filename: &'a [u8]) -> Self { + self.binlog_request = self.binlog_request.with_filename(filename); + self + } + + /// Sets the start position (defaults to `4`). + pub fn with_pos(mut self, position: u64) -> Self { + self.binlog_request = self.binlog_request.with_pos(position); + self + } + + /// Adds the given set of GTIDs to the request (ignored if not GTID-based). + pub fn with_gtid_set(mut self, set: T) -> Self + where + T: IntoIterator>, + { + self.binlog_request = self.binlog_request.with_sids(set); + self + } + + /// This hostname will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_hostname(mut self, hostname: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_hostname(hostname); + self + } + + /// This username will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_user(mut self, user: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_user(user); + self + } + + /// This password will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_password(mut self, password: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_password(password); + self + } + + /// This port number will be reported to the server (defaults to `0`). + /// + /// Usually left default. + pub fn with_port(mut self, port: u16) -> Self { + self.register_slave = self.register_slave.with_port(port); + self + } +} + #[cfg(test)] mod test { use bytes::Bytes; @@ -1297,7 +1385,7 @@ mod test { use std::time::Duration; use crate::{ - from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, + conn::BinlogStreamRequest, from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler, }; @@ -1429,8 +1517,8 @@ mod test { let mut binlog_stream = conn .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.0) - .with_filename(filename) + BinlogStreamRequest::new(binlog_server_ids.0) + .with_filename(&filename) .with_pos(pos), ) .await @@ -1467,9 +1555,9 @@ mod test { let mut binlog_stream = conn .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.1) - .with_use_gtid(true) - .with_filename(filename) + BinlogStreamRequest::new(binlog_server_ids.1) + .with_gtid() + .with_filename(&filename) .with_pos(pos), ) .await @@ -1507,10 +1595,10 @@ mod test { let mut binlog_stream = conn .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.2) - .with_filename(filename) + BinlogStreamRequest::new(binlog_server_ids.2) + .with_filename(&filename) .with_pos(pos) - .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK), + .with_non_blocking(), ) .await .unwrap(); From 829774f20739621a9481c3775b92eae069239fed Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 24 Oct 2023 17:18:51 +0300 Subject: [PATCH 032/130] Hide binlog behind a feature --- Cargo.toml | 3 + src/conn/binlog_stream.rs | 120 ----------- src/conn/binlog_stream/mod.rs | 348 ++++++++++++++++++++++++++++++ src/conn/binlog_stream/request.rs | 86 ++++++++ src/conn/mod.rs | 309 +------------------------- src/lib.rs | 10 +- tests/exports.rs | 17 +- 7 files changed, 462 insertions(+), 431 deletions(-) delete mode 100644 src/conn/binlog_stream.rs create mode 100644 src/conn/binlog_stream/mod.rs create mode 100644 src/conn/binlog_stream/request.rs diff --git a/Cargo.toml b/Cargo.toml index 06cc3bb3..ad40fae1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ default = [ "mysql_common/frunk", "derive", "native-tls-tls", + "binlog", ] default-rustls = [ "flate2/rust_backend", @@ -94,6 +95,7 @@ default-rustls = [ "mysql_common/frunk", "derive", "rustls-tls", + "binlog", ] minimal = ["flate2/zlib"] native-tls-tls = ["native-tls", "tokio-native-tls"] @@ -107,6 +109,7 @@ rustls-tls = [ tracing = ["dep:tracing"] derive = ["mysql_common/derive"] nightly = [] +binlog = [] [lib] name = "mysql_async" diff --git a/src/conn/binlog_stream.rs b/src/conn/binlog_stream.rs deleted file mode 100644 index 6aca3305..00000000 --- a/src/conn/binlog_stream.rs +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2020 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_core::ready; -use mysql_common::{ - binlog::{ - consts::BinlogVersion::Version4, - events::{Event, TableMapEvent}, - EventStreamReader, - }, - io::ParseBuf, - packets::{ErrPacket, NetworkStreamTerminator, OkPacketDeserializer}, -}; - -use std::{ - future::Future, - io::ErrorKind, - pin::Pin, - task::{Context, Poll}, -}; - -use crate::connection_like::Connection; -use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result}; - -/// Binlog event stream. -/// -/// Stream initialization is lazy, i.e. binlog won't be requested until this stream is polled. -pub struct BinlogStream { - read_packet: ReadPacket<'static, 'static>, - esr: EventStreamReader, -} - -impl BinlogStream { - /// `conn` is a `Conn` with `request_binlog` executed on it. - pub(super) fn new(conn: Conn) -> Self { - BinlogStream { - read_packet: ReadPacket::new(conn), - esr: EventStreamReader::new(Version4), - } - } - - /// Returns a table map event for the given table id. - pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> { - self.esr.get_tme(table_id) - } - - /// Closes the stream's `Conn`. Additionally, the connection is dropped, so its associated - /// pool (if any) will regain a connection slot. - pub async fn close(self) -> Result<()> { - match self.read_packet.0 { - // `close_conn` requires ownership of `Conn`. That's okay, because - // `BinLogStream`'s connection is always owned. - Connection::Conn(conn) => { - if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await { - // If the binlog was requested with the flag BINLOG_DUMP_NON_BLOCK, - // the connection's file handler will already have been closed (EOF). - if error.kind() == ErrorKind::BrokenPipe { - return Ok(()); - } - } - } - Connection::ConnMut(_) => {} - Connection::Tx(_) => {} - } - - Ok(()) - } -} - -impl futures_core::stream::Stream for BinlogStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { - Ok(packet) => packet, - Err(err) => return Poll::Ready(Some(Err(err.into()))), - }; - - let first_byte = packet.get(0).copied(); - - if first_byte == Some(255) { - if let Ok(ErrPacket::Error(err)) = - ParseBuf(&*packet).parse(self.read_packet.conn_ref().capabilities()) - { - return Poll::Ready(Some(Err(From::from(err)))); - } - } - - if first_byte == Some(254) && packet.len() < 8 { - if ParseBuf(&*packet) - .parse::>( - self.read_packet.conn_ref().capabilities(), - ) - .is_ok() - { - return Poll::Ready(None); - } - } - - if first_byte == Some(0) { - let event_data = &packet[1..]; - match self.esr.read(event_data) { - Ok(event) => { - return Poll::Ready(Some(Ok(event))); - } - Err(err) => return Poll::Ready(Some(Err(err.into()))), - } - } else { - return Poll::Ready(Some(Err(DriverError::UnexpectedPacket { - payload: packet.to_vec(), - } - .into()))); - } - } -} diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs new file mode 100644 index 00000000..b0e6a02a --- /dev/null +++ b/src/conn/binlog_stream/mod.rs @@ -0,0 +1,348 @@ +// Copyright (c) 2020 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use futures_core::ready; +use mysql_common::{ + binlog::{ + consts::BinlogVersion::Version4, + events::{Event, TableMapEvent}, + EventStreamReader, + }, + io::ParseBuf, + packets::{ComRegisterSlave, ErrPacket, NetworkStreamTerminator, OkPacketDeserializer}, +}; + +use std::{ + future::Future, + io::ErrorKind, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{connection_like::Connection, queryable::Queryable}; +use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result}; + +use self::request::BinlogStreamRequest; + +pub mod request; + +impl super::Conn { + /// Turns this connection into a binlog stream. + /// + /// You can use SHOW BINARY LOGS to get the current logfile and position from the master. + /// If the request’s filename is empty, the server will send the binlog-stream of the first known binlog. + pub async fn get_binlog_stream( + mut self, + request: BinlogStreamRequest<'_>, + ) -> Result { + self.request_binlog(request).await?; + + Ok(BinlogStream::new(self)) + } + + async fn register_as_slave( + &mut self, + com_register_slave: ComRegisterSlave<'_>, + ) -> crate::Result<()> { + self.query_drop("SET @master_binlog_checksum='ALL'").await?; + self.write_command(&com_register_slave).await?; + + // Server will respond with OK. + self.read_packet().await?; + + Ok(()) + } + + async fn request_binlog(&mut self, request: BinlogStreamRequest<'_>) -> crate::Result<()> { + self.register_as_slave(request.register_slave).await?; + self.write_command(&request.binlog_request.as_cmd()).await?; + Ok(()) + } +} + +/// Binlog event stream. +/// +/// Stream initialization is lazy, i.e. binlog won't be requested until this stream is polled. +pub struct BinlogStream { + read_packet: ReadPacket<'static, 'static>, + esr: EventStreamReader, +} + +impl BinlogStream { + /// `conn` is a `Conn` with `request_binlog` executed on it. + pub(super) fn new(conn: Conn) -> Self { + BinlogStream { + read_packet: ReadPacket::new(conn), + esr: EventStreamReader::new(Version4), + } + } + + /// Returns a table map event for the given table id. + pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> { + self.esr.get_tme(table_id) + } + + /// Closes the stream's `Conn`. Additionally, the connection is dropped, so its associated + /// pool (if any) will regain a connection slot. + pub async fn close(self) -> Result<()> { + match self.read_packet.0 { + // `close_conn` requires ownership of `Conn`. That's okay, because + // `BinLogStream`'s connection is always owned. + Connection::Conn(conn) => { + if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await { + // If the binlog was requested with the flag BINLOG_DUMP_NON_BLOCK, + // the connection's file handler will already have been closed (EOF). + if error.kind() == ErrorKind::BrokenPipe { + return Ok(()); + } + } + } + Connection::ConnMut(_) => {} + Connection::Tx(_) => {} + } + + Ok(()) + } +} + +impl futures_core::stream::Stream for BinlogStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { + Ok(packet) => packet, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + }; + + let first_byte = packet.get(0).copied(); + + if first_byte == Some(255) { + if let Ok(ErrPacket::Error(err)) = + ParseBuf(&*packet).parse(self.read_packet.conn_ref().capabilities()) + { + return Poll::Ready(Some(Err(From::from(err)))); + } + } + + if first_byte == Some(254) && packet.len() < 8 { + if ParseBuf(&*packet) + .parse::>( + self.read_packet.conn_ref().capabilities(), + ) + .is_ok() + { + return Poll::Ready(None); + } + } + + if first_byte == Some(0) { + let event_data = &packet[1..]; + match self.esr.read(event_data) { + Ok(event) => { + return Poll::Ready(Some(Ok(event))); + } + Err(err) => return Poll::Ready(Some(Err(err.into()))), + } + } else { + return Poll::Ready(Some(Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()))); + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures_util::StreamExt; + use mysql_common::binlog::events::EventData; + use tokio::time::timeout; + + use crate::prelude::*; + use crate::{test_misc::get_opts, *}; + + async fn gen_dummy_data() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await?; + + "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" + .ignore(&mut conn) + .await?; + + for i in 0_u8..100 { + "INSERT INTO customers(customer_id) VALUES (?)" + .with((i,)) + .ignore(&mut conn) + .await?; + } + + "DROP TABLE customers".ignore(&mut conn).await?; + + Ok(()) + } + + async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec, u64)> { + let mut conn = match pool { + None => Conn::new(get_opts()).await.unwrap(), + Some(pool) => pool.get_conn().await.unwrap(), + }; + + if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" + .first::(&mut conn) + .await + { + if !gtid_mode.starts_with("ON") { + panic!( + "GTID_MODE is disabled \ + (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)" + ); + } + } + + let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap(); + let filename = row.get(0).unwrap(); + let position = row.get(1).unwrap(); + + gen_dummy_data().await.unwrap(); + Ok((conn, filename, position)) + } + + #[tokio::test] + async fn should_read_binlog() -> super::Result<()> { + read_binlog_streams_and_close_their_connections(None, (12, 13, 14)) + .await + .unwrap(); + + let pool = Pool::new(get_opts()); + read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17)) + .await + .unwrap(); + + // Disconnecting the pool verifies that closing the binlog connections + // left the pool in a sane state. + timeout(Duration::from_secs(10), pool.disconnect()) + .await + .unwrap() + .unwrap(); + + Ok(()) + } + + async fn read_binlog_streams_and_close_their_connections( + pool: Option<&Pool>, + binlog_server_ids: (u32, u32, u32), + ) -> super::Result<()> { + // iterate using COM_BINLOG_DUMP + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + let is_mariadb = conn.inner.is_mariadb; + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.0) + .with_filename(&filename) + .with_pos(pos), + ) + .await + .unwrap(); + + let mut events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + match event.read_data()?.unwrap() { + EventData::RowsEvent(re) => { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + _ => (), + } + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + + if !is_mariadb { + // iterate using COM_BINLOG_DUMP_GTID + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.1) + .with_gtid() + .with_filename(&filename) + .with_pos(pos), + ) + .await + .unwrap(); + + events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await + { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + match event.read_data()?.unwrap() { + EventData::RowsEvent(re) => { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + _ => (), + } + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + } + + // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.2) + .with_filename(&filename) + .with_pos(pos) + .with_non_blocking(), + ) + .await + .unwrap(); + + events_num = 0; + while let Some(event) = binlog_stream.next().await { + let event = event.unwrap(); + events_num += 1; + event.header().event_type().unwrap(); + event.read_data().unwrap(); + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + + Ok(()) + } +} diff --git a/src/conn/binlog_stream/request.rs b/src/conn/binlog_stream/request.rs new file mode 100644 index 00000000..23265b7c --- /dev/null +++ b/src/conn/binlog_stream/request.rs @@ -0,0 +1,86 @@ +use mysql_common::packets::{ + binlog_request::BinlogRequest, BinlogDumpFlags, ComRegisterSlave, Sid, +}; + +/// Binlog stream request builder. +pub struct BinlogStreamRequest<'a> { + pub(crate) binlog_request: BinlogRequest<'a>, + pub(crate) register_slave: ComRegisterSlave<'a>, +} + +impl<'a> BinlogStreamRequest<'a> { + /// Creates a new request with the given slave server id. + pub fn new(server_id: u32) -> Self { + Self { + binlog_request: BinlogRequest::new(server_id), + register_slave: ComRegisterSlave::new(server_id), + } + } + + /// Enables GTID-based replication (disabled by default). + pub fn with_gtid(mut self) -> Self { + self.binlog_request = self.binlog_request.with_use_gtid(true); + self + } + + /// Enables `NON_BLOCK` flag. Stream will be terminated as soon as there are no events. + pub fn with_non_blocking(mut self) -> Self { + self.binlog_request = self + .binlog_request + .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK); + self + } + + /// Sets the filename of the binlog on the master (try `SHOW BINARY LOGS`). + pub fn with_filename(mut self, filename: &'a [u8]) -> Self { + self.binlog_request = self.binlog_request.with_filename(filename); + self + } + + /// Sets the start position (defaults to `4`). + pub fn with_pos(mut self, position: u64) -> Self { + self.binlog_request = self.binlog_request.with_pos(position); + self + } + + /// Adds the given set of GTIDs to the request (ignored if not GTID-based). + pub fn with_gtid_set(mut self, set: T) -> Self + where + T: IntoIterator>, + { + self.binlog_request = self.binlog_request.with_sids(set); + self + } + + /// This hostname will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_hostname(mut self, hostname: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_hostname(hostname); + self + } + + /// This username will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_user(mut self, user: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_user(user); + self + } + + /// This password will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_password(mut self, password: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_password(password); + self + } + + /// This port number will be reported to the server (defaults to `0`). + /// + /// Usually left default. + pub fn with_port(mut self, port: u16) -> Self { + self.register_slave = self.register_slave.with_port(port); + self + } +} diff --git a/src/conn/mod.rs b/src/conn/mod.rs index d2a11646..04b28cad 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -14,10 +14,9 @@ use mysql_common::{ crypto, io::ParseBuf, packets::{ - binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, BinlogDumpFlags, - ComRegisterSlave, CommonOkPacket, ErrPacket, HandshakePacket, HandshakeResponse, OkPacket, - OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket, ResultSetTerminator, Sid, - SslRequest, + AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket, + HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket, + ResultSetTerminator, SslRequest, }, proto::MySerialize, row::Row, @@ -46,11 +45,12 @@ use crate::{ transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, }, - BinlogStream, ChangeUserOpts, InfileData, OptsBuilder, + ChangeUserOpts, InfileData, OptsBuilder, }; use self::routines::Routine; +#[cfg(feature = "binlog")] pub mod binlog_stream; pub mod pool; pub mod routines; @@ -1260,201 +1260,20 @@ impl Conn { } Ok(self) } - - async fn register_as_slave(&mut self, com_register_slave: ComRegisterSlave<'_>) -> Result<()> { - self.query_drop("SET @master_binlog_checksum='ALL'").await?; - self.write_command(&com_register_slave).await?; - - // Server will respond with OK. - self.read_packet().await?; - - Ok(()) - } - - async fn request_binlog(&mut self, request: BinlogStreamRequest<'_>) -> Result<()> { - self.register_as_slave(request.register_slave).await?; - self.write_command(&request.binlog_request.as_cmd()).await?; - Ok(()) - } - - /// Turns this connection into a binlog stream. - /// - /// You can use SHOW BINARY LOGS to get the current logfile and position from the master. - /// If the request’s filename is empty, the server will send the binlog-stream of the first known binlog. - pub async fn get_binlog_stream( - mut self, - request: BinlogStreamRequest<'_>, - ) -> Result { - self.request_binlog(request).await?; - - Ok(BinlogStream::new(self)) - } -} - -/// Binlog stream request builder. -pub struct BinlogStreamRequest<'a> { - binlog_request: BinlogRequest<'a>, - register_slave: ComRegisterSlave<'a>, -} - -impl<'a> BinlogStreamRequest<'a> { - /// Creates a new request with the given slave server id. - pub fn new(server_id: u32) -> Self { - Self { - binlog_request: BinlogRequest::new(server_id), - register_slave: ComRegisterSlave::new(server_id), - } - } - - /// Enables GTID-based replication (disabled by default). - pub fn with_gtid(mut self) -> Self { - self.binlog_request = self.binlog_request.with_use_gtid(true); - self - } - - /// Enables `NON_BLOCK` flag. Stream will be terminated as soon as there are no events. - pub fn with_non_blocking(mut self) -> Self { - self.binlog_request = self - .binlog_request - .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK); - self - } - - /// Sets the filename of the binlog on the master (try `SHOW BINARY LOGS`). - pub fn with_filename(mut self, filename: &'a [u8]) -> Self { - self.binlog_request = self.binlog_request.with_filename(filename); - self - } - - /// Sets the start position (defaults to `4`). - pub fn with_pos(mut self, position: u64) -> Self { - self.binlog_request = self.binlog_request.with_pos(position); - self - } - - /// Adds the given set of GTIDs to the request (ignored if not GTID-based). - pub fn with_gtid_set(mut self, set: T) -> Self - where - T: IntoIterator>, - { - self.binlog_request = self.binlog_request.with_sids(set); - self - } - - /// This hostname will be reported to the server (max len 255, default to an empty string). - /// - /// Usually left default. - pub fn with_hostname(mut self, hostname: &'a [u8]) -> Self { - self.register_slave = self.register_slave.with_hostname(hostname); - self - } - - /// This username will be reported to the server (max len 255, default to an empty string). - /// - /// Usually left default. - pub fn with_user(mut self, user: &'a [u8]) -> Self { - self.register_slave = self.register_slave.with_user(user); - self - } - - /// This password will be reported to the server (max len 255, default to an empty string). - /// - /// Usually left default. - pub fn with_password(mut self, password: &'a [u8]) -> Self { - self.register_slave = self.register_slave.with_password(password); - self - } - - /// This port number will be reported to the server (defaults to `0`). - /// - /// Usually left default. - pub fn with_port(mut self, port: u16) -> Self { - self.register_slave = self.register_slave.with_port(port); - self - } } #[cfg(test)] mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; - use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN}; + use mysql_common::constants::MAX_PAYLOAD_LEN; use rand::Fill; - use tokio::time::timeout; - - use std::time::Duration; use crate::{ - conn::BinlogStreamRequest, from_row, params, prelude::*, test_misc::get_opts, - ChangeUserOpts, Conn, Error, OptsBuilder, Pool, Value, WhiteListFsHandler, + from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error, + OptsBuilder, Pool, Value, WhiteListFsHandler, }; - async fn gen_dummy_data() -> super::Result<()> { - let mut conn = Conn::new(get_opts()).await?; - - "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" - .ignore(&mut conn) - .await?; - - for i in 0_u8..100 { - "INSERT INTO customers(customer_id) VALUES (?)" - .with((i,)) - .ignore(&mut conn) - .await?; - } - - "DROP TABLE customers".ignore(&mut conn).await?; - - Ok(()) - } - - async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec, u64)> { - let mut conn = match pool { - None => Conn::new(get_opts()).await.unwrap(), - Some(pool) => pool.get_conn().await.unwrap(), - }; - - if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" - .first::(&mut conn) - .await - { - if !gtid_mode.starts_with("ON") { - panic!( - "GTID_MODE is disabled \ - (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)" - ); - } - } - - let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap(); - let filename = row.get(0).unwrap(); - let position = row.get(1).unwrap(); - - gen_dummy_data().await.unwrap(); - Ok((conn, filename, position)) - } - - #[tokio::test] - async fn should_read_binlog() -> super::Result<()> { - read_binlog_streams_and_close_their_connections(None, (12, 13, 14)) - .await - .unwrap(); - - let pool = Pool::new(get_opts()); - read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17)) - .await - .unwrap(); - - // Disconnecting the pool verifies that closing the binlog connections - // left the pool in a sane state. - timeout(Duration::from_secs(10), pool.disconnect()) - .await - .unwrap() - .unwrap(); - - Ok(()) - } - #[tokio::test] async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> { let opts = get_opts().client_found_rows(true); @@ -1507,118 +1326,6 @@ mod test { Ok(()) } - async fn read_binlog_streams_and_close_their_connections( - pool: Option<&Pool>, - binlog_server_ids: (u32, u32, u32), - ) -> super::Result<()> { - // iterate using COM_BINLOG_DUMP - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); - let is_mariadb = conn.inner.is_mariadb; - - let mut binlog_stream = conn - .get_binlog_stream( - BinlogStreamRequest::new(binlog_server_ids.0) - .with_filename(&filename) - .with_pos(pos), - ) - .await - .unwrap(); - - let mut events_num = 0; - while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await { - let event = event.unwrap(); - events_num += 1; - - // assert that event type is known - event.header().event_type().unwrap(); - - // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } - } - _ => (), - } - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); - - if !is_mariadb { - // iterate using COM_BINLOG_DUMP_GTID - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); - - let mut binlog_stream = conn - .get_binlog_stream( - BinlogStreamRequest::new(binlog_server_ids.1) - .with_gtid() - .with_filename(&filename) - .with_pos(pos), - ) - .await - .unwrap(); - - events_num = 0; - while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await - { - let event = event.unwrap(); - events_num += 1; - - // assert that event type is known - event.header().event_type().unwrap(); - - // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } - } - _ => (), - } - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); - } - - // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); - - let mut binlog_stream = conn - .get_binlog_stream( - BinlogStreamRequest::new(binlog_server_ids.2) - .with_filename(&filename) - .with_pos(pos) - .with_non_blocking(), - ) - .await - .unwrap(); - - events_num = 0; - while let Some(event) = binlog_stream.next().await { - let event = event.unwrap(); - events_num += 1; - event.header().event_type().unwrap(); - event.read_data().unwrap(); - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); - - Ok(()) - } - #[test] fn opts_should_satisfy_send_and_sync() { struct A(T); diff --git a/src/lib.rs b/src/lib.rs index 9c5ff836..4f5fc568 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -449,8 +449,12 @@ type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; static BUFFER_POOL: once_cell::sync::Lazy> = once_cell::sync::Lazy::new(|| Default::default()); +#[cfg(feature = "binlog")] #[doc(inline)] -pub use self::conn::{binlog_stream::BinlogStream, Conn}; +pub use self::conn::binlog_stream::{request::BinlogStreamRequest, BinlogStream}; + +#[doc(inline)] +pub use self::conn::Conn; #[doc(inline)] pub use self::conn::pool::Pool; @@ -482,14 +486,14 @@ pub use self::local_infile_handler::{builtin::WhiteListFsHandler, InfileData}; #[doc(inline)] pub use mysql_common::packets::{ - binlog_request::BinlogRequest, session_state_change::{ Gtids, Schema, SessionStateChange, SystemVariable, TransactionCharacteristics, TransactionState, Unsupported, }, - BinlogDumpFlags, Column, GnoInterval, OkPacket, SessionStateInfo, Sid, + Column, GnoInterval, OkPacket, SessionStateInfo, Sid, }; +#[cfg(feature = "binlog")] pub mod binlog { #[doc(inline)] pub use mysql_common::binlog::consts::*; diff --git a/tests/exports.rs b/tests/exports.rs index 92255dbb..6f9feef8 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -7,11 +7,14 @@ use mysql_async::{ BatchQuery, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, StatementLike, ToValue, }, - BinaryProtocol, BinlogDumpFlags, BinlogRequest, Column, Conn, Deserialized, DriverError, Error, - FromRowError, FromValueError, GnoInterval, Gtids, IoError, IsolationLevel, OkPacket, Opts, - OptsBuilder, Params, ParseError, Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, - Schema, Serialized, ServerError, SessionStateChange, SessionStateInfo, Sid, SslOpts, Statement, - SystemVariable, TextProtocol, Transaction, TransactionCharacteristics, TransactionState, - TxOpts, Unsupported, UrlError, Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, - DEFAULT_TTL_CHECK_INTERVAL, + BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError, + GnoInterval, Gtids, IoError, IsolationLevel, OkPacket, Opts, OptsBuilder, Params, ParseError, + Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, Schema, Serialized, ServerError, + SessionStateChange, SessionStateInfo, Sid, SslOpts, Statement, SystemVariable, TextProtocol, + Transaction, TransactionCharacteristics, TransactionState, TxOpts, Unsupported, UrlError, + Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, }; + +#[cfg(feature = "binlog")] +#[allow(unused_imports)] +use mysql_async::{binlog, BinlogStream, BinlogStreamRequest}; From 0e84bd169088d9ba467b7c11ea076f00d099325e Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 2 Nov 2023 17:15:13 +0300 Subject: [PATCH 033/130] Bump mysql_common --- Cargo.toml | 4 +-- src/conn/binlog_stream/mod.rs | 56 ++++++++++++++++++++++++++++------- src/conn/routines/exec.rs | 5 ++-- src/io/read_packet.rs | 5 ++-- src/queryable/stmt.rs | 25 +++++++++++----- 5 files changed, 70 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ad40fae1..8a0d8302 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ keyed_priority_queue = "0.4" lazy_static = "1" lru = "0.11.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.30", default-features = false } +mysql_common = { version = "0.31", default-features = false } once_cell = "1.7.2" pem = "3.0" percent-encoding = "2.1.0" @@ -109,7 +109,7 @@ rustls-tls = [ tracing = ["dep:tracing"] derive = ["mysql_common/derive"] nightly = [] -binlog = [] +binlog = ["mysql_common/binlog"] [lib] name = "mysql_async" diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs index b0e6a02a..c4749f01 100644 --- a/src/conn/binlog_stream/mod.rs +++ b/src/conn/binlog_stream/mod.rs @@ -9,8 +9,8 @@ use futures_core::ready; use mysql_common::{ binlog::{ - consts::BinlogVersion::Version4, - events::{Event, TableMapEvent}, + consts::{BinlogVersion::Version4, EventType}, + events::{Event, TableMapEvent, TransactionPayloadEvent}, EventStreamReader, }, io::ParseBuf, @@ -19,7 +19,7 @@ use mysql_common::{ use std::{ future::Future, - io::ErrorKind, + io::{Cursor, ErrorKind}, pin::Pin, task::{Context, Poll}, }; @@ -71,6 +71,9 @@ impl super::Conn { pub struct BinlogStream { read_packet: ReadPacket<'static, 'static>, esr: EventStreamReader, + // TODO: Use 'static reader here (requires impl on the mysql_common side). + /// Uncompressed Transaction_payload_event we are iterating over (if any). + tpe: Option>>, } impl BinlogStream { @@ -79,6 +82,7 @@ impl BinlogStream { BinlogStream { read_packet: ReadPacket::new(conn), esr: EventStreamReader::new(Version4), + tpe: None, } } @@ -114,6 +118,22 @@ impl futures_core::stream::Stream for BinlogStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + { + let Self { + ref mut tpe, + ref mut esr, + .. + } = *self; + + if let Some(tpe) = tpe.as_mut() { + match esr.read_decompressed(tpe) { + Ok(Some(event)) => return Poll::Ready(Some(Ok(event))), + Ok(None) => self.tpe = None, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + } + } + } + let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { Ok(packet) => packet, Err(err) => return Poll::Ready(Some(Err(err.into()))), @@ -143,9 +163,17 @@ impl futures_core::stream::Stream for BinlogStream { if first_byte == Some(0) { let event_data = &packet[1..]; match self.esr.read(event_data) { - Ok(event) => { + Ok(Some(event)) => { + if event.header().event_type_raw() == EventType::TRANSACTION_PAYLOAD_EVENT as u8 + { + match event.read_event::>() { + Ok(e) => self.tpe = Some(Cursor::new(e.danger_decompress())), + Err(_) => (/* TODO: Log the error */), + } + } return Poll::Ready(Some(Ok(event))); } + Ok(None) => return Poll::Ready(None), Err(err) => return Poll::Ready(Some(Err(err.into()))), } } else { @@ -168,21 +196,21 @@ mod tests { use crate::prelude::*; use crate::{test_misc::get_opts, *}; - async fn gen_dummy_data() -> super::Result<()> { - let mut conn = Conn::new(get_opts()).await?; - + async fn gen_dummy_data(conn: &mut Conn) -> super::Result<()> { "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" - .ignore(&mut conn) + .ignore(&mut *conn) .await?; + let mut tx = conn.start_transaction(Default::default()).await?; for i in 0_u8..100 { "INSERT INTO customers(customer_id) VALUES (?)" .with((i,)) - .ignore(&mut conn) + .ignore(&mut tx) .await?; } + tx.commit().await?; - "DROP TABLE customers".ignore(&mut conn).await?; + "DROP TABLE customers".ignore(conn).await?; Ok(()) } @@ -193,6 +221,12 @@ mod tests { Some(pool) => pool.get_conn().await.unwrap(), }; + if conn.server_version() >= (8, 0, 31) && conn.server_version() < (9, 0, 0) { + let _ = "SET binlog_transaction_compression=ON" + .ignore(&mut conn) + .await; + } + if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" .first::(&mut conn) .await @@ -209,7 +243,7 @@ mod tests { let filename = row.get(0).unwrap(); let position = row.get(1).unwrap(); - gen_dummy_data().await.unwrap(); + gen_dummy_data(&mut conn).await.unwrap(); Ok((conn, filename, position)) } diff --git a/src/conn/routines/exec.rs b/src/conn/routines/exec.rs index 262a90c9..feabb8fd 100644 --- a/src/conn/routines/exec.rs +++ b/src/conn/routines/exec.rs @@ -71,14 +71,13 @@ impl Routine<()> for ExecRoutine<'_> { break; } Params::Named(_) => { - if self.stmt.named_params.is_none() { + if self.stmt.named_params.is_empty() { let error = DriverError::NamedParamsForPositionalQuery.into(); return Err(error); } let named = mem::replace(&mut self.params, Params::Empty); - self.params = - named.into_positional(self.stmt.named_params.as_ref().unwrap())?; + self.params = named.into_positional(&self.stmt.named_params)?; continue; } diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs index 7e14fca0..226e9ab6 100644 --- a/src/io/read_packet.rs +++ b/src/io/read_packet.rs @@ -15,7 +15,7 @@ use std::{ task::{Context, Poll}, }; -use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError, Conn}; +use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError}; /// Reads a packet. #[derive(Debug)] @@ -27,7 +27,8 @@ impl<'a, 't> ReadPacket<'a, 't> { Self(conn.into()) } - pub(crate) fn conn_ref(&self) -> &Conn { + #[cfg(feature = "binlog")] + pub(crate) fn conn_ref(&self) -> &crate::Conn { &*self.0 } } diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 6b4a2eb8..0624fb50 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -9,7 +9,7 @@ use futures_util::FutureExt; use mysql_common::{ io::ParseBuf, - named_params::parse_named_params, + named_params::ParsedNamedParams, packets::{ComStmtClose, StmtPacket}, }; @@ -45,12 +45,22 @@ fn to_statement_move<'a, T: AsQuery + 'a>( ) -> ToStatementResult<'a> { let fut = async move { let query = stmt.as_query(); - let (named_params, raw_query) = parse_named_params(query.as_ref())?; - let inner_stmt = match conn.get_cached_stmt(&*raw_query) { + let parsed = ParsedNamedParams::parse(query.as_ref())?; + let inner_stmt = match conn.get_cached_stmt(parsed.query()) { Some(inner_stmt) => inner_stmt, - None => conn.prepare_statement(raw_query).await?, + None => { + conn.prepare_statement(Cow::Borrowed(parsed.query())) + .await? + } }; - Ok(Statement::new(inner_stmt, named_params)) + Ok(Statement::new( + inner_stmt, + parsed + .params() + .iter() + .map(|x| x.as_ref().to_vec()) + .collect::>(), + )) } .boxed(); ToStatementResult::Mediate(fut) @@ -240,11 +250,12 @@ impl StmtInner { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Statement { pub(crate) inner: Arc, - pub(crate) named_params: Option>>, + /// An empty vector in case of no named params. + pub(crate) named_params: Vec>, } impl Statement { - pub(crate) fn new(inner: Arc, named_params: Option>>) -> Self { + pub(crate) fn new(inner: Arc, named_params: Vec>) -> Self { Self { inner, named_params, From 5e94f56648eb987160b69998b3011762611ead16 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 2 Nov 2023 17:21:16 +0300 Subject: [PATCH 034/130] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8a0d8302..e1939be7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.32.2" +version = "0.33.0" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] From fcb25d9b84cb554f4a0486fccaa043b996831265 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 7 Nov 2023 01:11:42 +0800 Subject: [PATCH 035/130] add ResultSetStream types export --- src/queryable/query_result/result_set_stream.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/queryable/query_result/result_set_stream.rs b/src/queryable/query_result/result_set_stream.rs index 9210d9b8..6c42890d 100644 --- a/src/queryable/query_result/result_set_stream.rs +++ b/src/queryable/query_result/result_set_stream.rs @@ -102,6 +102,14 @@ impl<'r, 'a: 'r, 't: 'a, T, P> ResultSetStream<'r, 'a, 't, T, P> { .unwrap_or_default() } + /// Returns type result set. + /// + /// In order to know the type of the returned result in advance, it is helpful to process the type conversion of the data. + pub fn get_columns(&self) -> Arc<[Column]> { + self.columns.clone() + } + + /// See [`Conn::info`][1]. /// /// [1]: crate::Conn::info From 4023e5e3860f427a89571d0aa7bdbf56f27f9d75 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 7 Nov 2023 15:03:30 +0300 Subject: [PATCH 036/130] Style ResultSetStream::{columns, columns_ref} in accordance with the rest of the project --- src/queryable/query_result/result_set_stream.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/queryable/query_result/result_set_stream.rs b/src/queryable/query_result/result_set_stream.rs index 6c42890d..d9eecc7d 100644 --- a/src/queryable/query_result/result_set_stream.rs +++ b/src/queryable/query_result/result_set_stream.rs @@ -102,13 +102,15 @@ impl<'r, 'a: 'r, 't: 'a, T, P> ResultSetStream<'r, 'a, 't, T, P> { .unwrap_or_default() } - /// Returns type result set. - /// - /// In order to know the type of the returned result in advance, it is helpful to process the type conversion of the data. - pub fn get_columns(&self) -> Arc<[Column]> { - self.columns.clone() + /// See [`QueryResult::columns_ref`]. + pub fn columns_ref(&self) -> &[Column] { + &self.columns[..] } + /// See [`QueryResult::columns`]. + pub fn columns(&self) -> Arc<[Column]> { + self.columns.clone() + } /// See [`Conn::info`][1]. /// From 9d0883703c45b91e485c76b558bc742125539615 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 8 Nov 2023 10:44:11 +0300 Subject: [PATCH 037/130] Fix docstrings --- src/local_infile_handler/mod.rs | 2 +- src/opts/mod.rs | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/local_infile_handler/mod.rs b/src/local_infile_handler/mod.rs index 7d222a2b..39515124 100644 --- a/src/local_infile_handler/mod.rs +++ b/src/local_infile_handler/mod.rs @@ -34,7 +34,7 @@ pub type InfileData = BoxStream<'static, std::io::Result>; /// **Warning:** You should be aware of [Security Considerations for LOAD DATA LOCAL][1]. /// /// The purpose of the handler is to emit infile data in response to a file name. -/// This handler will be called if there is no [`LocalHandler`] installed for the connection. +/// This handler will be called if there is no local handler installed for the connection. /// /// The library will call this handler in response to a LOCAL INFILE request from the server. /// The server, in its turn, will emit LOCAL INFILE requests in response to a `LOAD DATA LOCAL` diff --git a/src/opts/mod.rs b/src/opts/mod.rs index c7bd7052..346805ff 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -252,7 +252,7 @@ impl PoolOpts { /// So to encrease overall performance you can safely opt-out of the default behavior /// if you are not willing to change the session state in an unpleasant way. /// - /// It is also possible to selectively opt-in/out using [`Conn::reset_connection`]. + /// It is also possible to selectively opt-in/out using [`Conn::reset_connection`][1]. /// /// # Connection URL /// @@ -266,6 +266,8 @@ impl PoolOpts { /// assert_eq!(opts.pool_opts().reset_connection(), false); /// # Ok(()) } /// ``` + /// + /// [1]: crate::Conn::reset_connection pub fn with_reset_connection(mut self, reset_connection: bool) -> Self { self.reset_connection = reset_connection; self @@ -635,7 +637,10 @@ impl Opts { } /// Commands to execute on new connection and every time - /// [`Conn::reset`] or [`Conn::change_user`] is invoked. + /// [`Conn::reset`][1] or [`Conn::change_user`][2] is invoked. + /// + /// [1]: crate::Conn::reset + /// [2]: crate::Conn::change_user pub fn setup(&self) -> &[String] { self.inner.mysql_opts.setup.as_ref() } From f8e0bd598e634c91d5f9ac0719ac828d08ce5c02 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 8 Nov 2023 11:51:09 +0300 Subject: [PATCH 038/130] Clippy --- src/buffer_pool.rs | 4 +- src/conn/binlog_stream/mod.rs | 46 ++++++++++----------- src/conn/mod.rs | 61 +++++++++++++--------------- src/conn/pool/mod.rs | 2 +- src/conn/pool/ttl_check_inerval.rs | 7 ++-- src/conn/routines/exec.rs | 2 +- src/conn/routines/helpers.rs | 8 ++-- src/conn/routines/prepare.rs | 2 +- src/conn/stmt_cache.rs | 6 +-- src/error/mod.rs | 2 +- src/io/mod.rs | 7 +--- src/io/read_packet.rs | 2 +- src/io/tls/native_tls_io.rs | 6 +-- src/lib.rs | 10 +++-- src/opts/mod.rs | 62 ++++++++++++++--------------- src/queryable/query_result/mod.rs | 6 +-- src/queryable/query_result/tests.rs | 4 +- src/queryable/stmt.rs | 2 +- src/queryable/transaction.rs | 2 +- tests/generic.rs | 4 +- 20 files changed, 113 insertions(+), 132 deletions(-) diff --git a/src/buffer_pool.rs b/src/buffer_pool.rs index 03b1a4cc..7e89cfc4 100644 --- a/src/buffer_pool.rs +++ b/src/buffer_pool.rs @@ -7,7 +7,7 @@ // modified, or distributed except according to those terms. use crossbeam::queue::ArrayQueue; -use std::{mem::replace, ops::Deref, sync::Arc}; +use std::{mem::take, ops::Deref, sync::Arc}; #[derive(Debug)] pub struct BufferPool { @@ -93,6 +93,6 @@ impl Deref for PooledBuf { impl Drop for PooledBuf { fn drop(&mut self) { - self.1.put(replace(&mut self.0, vec![])) + self.1.put(take(&mut self.0)) } } diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs index c4749f01..24af9202 100644 --- a/src/conn/binlog_stream/mod.rs +++ b/src/conn/binlog_stream/mod.rs @@ -139,25 +139,25 @@ impl futures_core::stream::Stream for BinlogStream { Err(err) => return Poll::Ready(Some(Err(err.into()))), }; - let first_byte = packet.get(0).copied(); + let first_byte = packet.first().copied(); if first_byte == Some(255) { if let Ok(ErrPacket::Error(err)) = - ParseBuf(&*packet).parse(self.read_packet.conn_ref().capabilities()) + ParseBuf(&packet).parse(self.read_packet.conn_ref().capabilities()) { return Poll::Ready(Some(Err(From::from(err)))); } } - if first_byte == Some(254) && packet.len() < 8 { - if ParseBuf(&*packet) + if first_byte == Some(254) + && packet.len() < 8 + && ParseBuf(&packet) .parse::>( self.read_packet.conn_ref().capabilities(), ) .is_ok() - { - return Poll::Ready(None); - } + { + return Poll::Ready(None); } if first_byte == Some(0) { @@ -171,16 +171,16 @@ impl futures_core::stream::Stream for BinlogStream { Err(_) => (/* TODO: Log the error */), } } - return Poll::Ready(Some(Ok(event))); + Poll::Ready(Some(Ok(event))) } - Ok(None) => return Poll::Ready(None), - Err(err) => return Poll::Ready(Some(Err(err.into()))), + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err.into()))), } } else { - return Poll::Ready(Some(Err(DriverError::UnexpectedPacket { + Poll::Ready(Some(Err(DriverError::UnexpectedPacket { payload: packet.to_vec(), } - .into()))); + .into()))) } } } @@ -294,14 +294,11 @@ mod tests { event.header().event_type().unwrap(); // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } + if let EventData::RowsEvent(re) = event.read_data()?.unwrap() { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); } - _ => (), } } assert!(events_num > 0); @@ -334,14 +331,11 @@ mod tests { event.header().event_type().unwrap(); // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } + if let EventData::RowsEvent(re) = event.read_data()?.unwrap() { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); } - _ => (), } } assert!(events_num > 0); diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 04b28cad..252fbce7 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -309,7 +309,7 @@ impl Conn { if let Err(ref e) = self.inner.pending_result { let e = e.clone(); self.inner.pending_result = Ok(None); - return Err(e); + Err(e) } else { Ok(self.inner.pending_result.as_ref().unwrap().as_ref()) } @@ -322,8 +322,7 @@ impl Conn { } pub(crate) fn has_pending_result(&self) -> bool { - matches!(self.inner.pending_result, Err(_)) - || matches!(self.inner.pending_result, Ok(Some(_))) + self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_))) } /// Sets the given pening result metadata for this connection. Returns the previous value. @@ -487,7 +486,7 @@ impl Conn { async fn handle_handshake(&mut self) -> Result<()> { let packet = self.read_packet().await?; - let handshake = ParseBuf(&*packet).parse::(())?; + let handshake = ParseBuf(&packet).parse::(())?; // Handshake scramble is always 21 bytes length (20 + zero terminator) self.inner.nonce = { @@ -563,7 +562,7 @@ impl Conn { let auth_data = self .inner .auth_plugin - .gen_data(self.inner.opts.pass(), &*self.inner.nonce); + .gen_data(self.inner.opts.pass(), &self.inner.nonce); let handshake_response = HandshakeResponse::new( auth_data.as_deref(), @@ -594,10 +593,9 @@ impl Conn { if matches!( auth_switch_request.auth_plugin(), AuthPlugin::MysqlOldPassword - ) { - if self.inner.opts.secure_auth() { - return Err(DriverError::MysqlOldPasswordDisabled.into()); - } + ) && self.inner.opts.secure_auth() + { + return Err(DriverError::MysqlOldPasswordDisabled.into()); } self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned(); @@ -685,7 +683,7 @@ impl Conn { async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> { let packet = self.read_packet().await?; - match packet.get(0) { + match packet.first() { Some(0x00) => { // ok packet for empty password Ok(()) @@ -712,10 +710,10 @@ impl Conn { *byte ^= self.inner.nonce[i % self.inner.nonce.len()]; } let encrypted_pass = crypto::encrypt( - &*pass, + &pass, self.inner.server_key.as_deref().expect("unreachable"), ); - self.write_bytes(&*encrypted_pass).await?; + self.write_bytes(&encrypted_pass).await?; }; self.drop_packet().await?; Ok(()) @@ -726,7 +724,7 @@ impl Conn { .into()), }, Some(0xfe) if !self.inner.auth_switched => { - let auth_switch_request = ParseBuf(&*packet).parse::(())?; + let auth_switch_request = ParseBuf(&packet).parse::(())?; self.perform_auth_switch(auth_switch_request).await?; Ok(()) } @@ -739,13 +737,13 @@ impl Conn { async fn continue_mysql_native_password_auth(&mut self) -> Result<()> { let packet = self.read_packet().await?; - match packet.get(0) { + match packet.first() { Some(0x00) => Ok(()), Some(0xfe) if !self.inner.auth_switched => { let auth_switch = if packet.len() > 1 { - ParseBuf(&*packet).parse(())? + ParseBuf(&packet).parse(())? } else { - let _ = ParseBuf(&*packet).parse::(())?; + let _ = ParseBuf(&packet).parse::(())?; // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin AuthSwitchRequest::new( "mysql_old_password".as_bytes(), @@ -768,16 +766,16 @@ impl Conn { .capabilities() .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) } else { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) } } else { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) }; @@ -785,7 +783,7 @@ impl Conn { if let Ok(ok_packet) = ok_packet { self.handle_ok(ok_packet.into_owned()); } else { - let err_packet = ParseBuf(&*packet).parse::(self.capabilities()); + let err_packet = ParseBuf(packet).parse::(self.capabilities()); if let Ok(err_packet) = err_packet { self.handle_err(err_packet)?; return Ok(true); @@ -1011,14 +1009,13 @@ impl Conn { fn apply(&self, conn: &mut Conn, value: Option) { match self { Cfg::Socket => { - conn.inner.socket = value.map(crate::from_value).flatten(); + conn.inner.socket = value.and_then(crate::from_value); } Cfg::MaxAllowedPacket => { if let Some(stream) = conn.inner.stream.as_mut() { stream.set_max_allowed_packet( value - .map(crate::from_value) - .flatten() + .and_then(crate::from_value) .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET), ); } @@ -1026,8 +1023,7 @@ impl Conn { Cfg::WaitTimeout => { conn.inner.wait_timeout = Duration::from_secs( value - .map(crate::from_value) - .flatten() + .and_then(crate::from_value) .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64, ); } @@ -1329,6 +1325,7 @@ mod test { #[test] fn opts_should_satisfy_send_and_sync() { struct A(T); + #[allow(clippy::unnecessary_operation)] A(get_opts()); } @@ -1672,7 +1669,7 @@ mod test { async fn should_perform_queries() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) { - let long_string = ::std::iter::repeat('A').take(x).collect::(); + let long_string = "A".repeat(x); let result: Vec<(String, u8)> = conn .query(format!(r"SELECT '{}', 231", long_string)) .await?; @@ -1724,15 +1721,11 @@ mod test { #[tokio::test] async fn should_execute_statement() -> super::Result<()> { - let long_string = ::std::iter::repeat('A') - .take(18 * 1024 * 1024) - .collect::(); + let long_string = "A".repeat(18 * 1024 * 1024); let mut conn = Conn::new(get_opts()).await?; let stmt = conn.prep(r"SELECT ?").await?; let result = conn.exec_iter(&stmt, (&long_string,)).await?; - let mut mapped = result - .map_and_drop(|row| from_row::<(String,)>(row)) - .await?; + let mut mapped = result.map_and_drop(from_row::<(String,)>).await?; assert_eq!(mapped.len(), 1); assert_eq!(mapped.pop(), Some((long_string,))); let result = conn.exec_iter(&stmt, (42_u8,)).await?; @@ -1755,7 +1748,7 @@ mod test { .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" }) .await?; let mut mapped = result - .map_and_drop(|row| from_row::<(String, String, String, u8)>(row)) + .map_and_drop(from_row::<(String, String, String, u8)>) .await?; assert_eq!(mapped.len(), 1); assert_eq!( @@ -1847,7 +1840,7 @@ mod test { let result = conn.query_iter(q).await?; let loaded_structs = result - .map_and_drop(|row| crate::from_row::<(Vec, Vec, u64, Vec)>(row)) + .map_and_drop(crate::from_row::<(Vec, Vec, u64, Vec)>) .await?; conn.disconnect().await?; diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index a984bee5..df3ac251 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -606,7 +606,7 @@ mod test { } drop(tx); // see that all the tx's eventually complete - while let Some(_) = rx.recv().await {} + while (rx.recv().await).is_some() {} } drop(pool); } diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index dde8e529..a0caa1a9 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -60,10 +60,9 @@ impl TtlCheckInterval { VecDeque::<_>::with_capacity(self.pool_opts.constraints().max()); while let Some(conn) = exchange.available.pop_front() { - if conn.expired() { - to_be_dropped.push(conn); - } else if to_be_dropped.len() < num_to_drop - && conn.elapsed() > self.pool_opts.inactive_connection_ttl() + if conn.expired() + || (to_be_dropped.len() < num_to_drop + && conn.elapsed() > self.pool_opts.inactive_connection_ttl()) { to_be_dropped.push(conn); } else { diff --git a/src/conn/routines/exec.rs b/src/conn/routines/exec.rs index feabb8fd..5a53e00e 100644 --- a/src/conn/routines/exec.rs +++ b/src/conn/routines/exec.rs @@ -60,7 +60,7 @@ impl Routine<()> for ExecRoutine<'_> { } let (body, as_long_data) = - ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&*params); + ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(params); if as_long_data { conn.send_long_data(self.stmt.id(), params.iter()).await?; diff --git a/src/conn/routines/helpers.rs b/src/conn/routines/helpers.rs index eed0f38e..fdaf5fef 100644 --- a/src/conn/routines/helpers.rs +++ b/src/conn/routines/helpers.rs @@ -65,14 +65,14 @@ impl Conn { } }; - match packet.get(0) { + match packet.first() { Some(0x00) => { self.set_pending_result(Some(P::result_set_meta(Arc::from( Vec::new().into_boxed_slice(), ))))?; } - Some(0xFB) => self.handle_local_infile::

(&*packet).await?, - _ => self.handle_result_set::

(&*packet).await?, + Some(0xFB) => self.handle_local_infile::

(&packet).await?, + _ => self.handle_result_set::

(&packet).await?, } Ok(()) @@ -98,7 +98,7 @@ impl Conn { match bytes { Ok(bytes) => { // We'll skip empty chunks to stay compliant with the protocol. - if bytes.len() > 0 { + if !bytes.is_empty() { self.write_bytes(&bytes).await?; } } diff --git a/src/conn/routines/prepare.rs b/src/conn/routines/prepare.rs index 33970e58..1fa1d300 100644 --- a/src/conn/routines/prepare.rs +++ b/src/conn/routines/prepare.rs @@ -47,7 +47,7 @@ impl Routine> for PrepareRoutine { .await?; let packet = conn.read_packet().await?; - let mut inner_stmt = StmtInner::from_payload(&*packet, conn.id(), self.query.clone())?; + let mut inner_stmt = StmtInner::from_payload(&packet, conn.id(), self.query.clone())?; #[cfg(feature = "tracing")] Span::current().record("mysql_async.statement.id", inner_stmt.id()); diff --git a/src/conn/stmt_cache.rs b/src/conn/stmt_cache.rs index 595836ac..f72240f9 100644 --- a/src/conn/stmt_cache.rs +++ b/src/conn/stmt_cache.rs @@ -23,13 +23,13 @@ pub struct QueryString(pub Arc<[u8]>); impl Borrow<[u8]> for QueryString { fn borrow(&self) -> &[u8] { - &*self.0.as_ref() + self.0.as_ref() } } impl PartialEq<[u8]> for QueryString { fn eq(&self, other: &[u8]) -> bool { - &*self.0.as_ref() == other + self.0.as_ref() == other } } @@ -80,7 +80,7 @@ impl StmtCache { if self.cache.len() > self.cap { if let Some((_, entry)) = self.cache.pop_lru() { - self.query_map.remove(&*entry.query.0.as_ref()); + self.query_map.remove(entry.query.0.as_ref()); return Some(entry.stmt); } } diff --git a/src/error/mod.rs b/src/error/mod.rs index 3b1235c2..e14a2578 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -109,7 +109,7 @@ pub enum DriverError { #[error("Error converting from mysql row.")] FromRow { row: Row }, - #[error("Missing named parameter `{}'.", String::from_utf8_lossy(&name))] + #[error("Missing named parameter `{}'.", String::from_utf8_lossy(name))] MissingNamedParam { name: Vec }, #[error("Named and positional parameters mixed in one statement.")] diff --git a/src/io/mod.rs b/src/io/mod.rs index d46b2dc3..f5ffc5c6 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -155,10 +155,7 @@ impl Future for CheckTcpStream<'_> { impl Endpoint { #[cfg(unix)] fn is_socket(&self) -> bool { - match self { - Self::Socket(_) => true, - _ => false, - } + matches!(self, Self::Socket(_)) } /// Checks, that connection is alive. @@ -182,7 +179,7 @@ impl Endpoint { } #[cfg(unix)] Endpoint::Socket(socket) => { - socket.write(&[]).await?; + let _ = socket.write(&[]).await?; Ok(()) } Endpoint::Plain(None) => unreachable!(), diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs index 226e9ab6..6c4646ee 100644 --- a/src/io/read_packet.rs +++ b/src/io/read_packet.rs @@ -29,7 +29,7 @@ impl<'a, 't> ReadPacket<'a, 't> { #[cfg(feature = "binlog")] pub(crate) fn conn_ref(&self) -> &crate::Conn { - &*self.0 + &self.0 } } diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index 910387d7..04121d37 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -21,7 +21,7 @@ impl Endpoint { let mut root_cert_file = File::open(root_cert_path)?; root_cert_file.read_to_end(&mut root_cert_data)?; - let root_certs = Certificate::from_der(&*root_cert_data) + let root_certs = Certificate::from_der(&root_cert_data) .map(|x| vec![x]) .or_else(|_| { pem::parse_many(&*root_cert_data) @@ -42,7 +42,7 @@ impl Endpoint { let password = client_identity.password().unwrap_or(""); let der = std::fs::read(pkcs12_path)?; - let identity = Identity::from_pkcs12(&*der, password)?; + let identity = Identity::from_pkcs12(&der, password)?; builder.identity(identity); } builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); @@ -52,7 +52,7 @@ impl Endpoint { *self = match self { Endpoint::Plain(ref mut stream) => { let stream = stream.take().unwrap(); - let tls_stream = tls_connector.connect(&*domain, stream).await?; + let tls_stream = tls_connector.connect(&domain, stream).await?; Endpoint::Secure(tls_stream) } Endpoint::Secure(_) => unreachable!(), diff --git a/src/lib.rs b/src/lib.rs index 4f5fc568..9d94ceaf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -447,7 +447,7 @@ mod queryable; type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; static BUFFER_POOL: once_cell::sync::Lazy> = - once_cell::sync::Lazy::new(|| Default::default()); + once_cell::sync::Lazy::new(Default::default); #[cfg(feature = "binlog")] #[doc(inline)] @@ -582,7 +582,7 @@ pub mod prelude { pub trait ToConnection<'a, 't: 'a>: crate::connection_like::ToConnection<'a, 't> {} // explicitly implemented because of rusdoc impl<'a> ToConnection<'a, 'static> for &'a crate::Pool {} - impl<'a> ToConnection<'static, 'static> for crate::Pool {} + impl ToConnection<'static, 'static> for crate::Pool {} impl ToConnection<'static, 'static> for crate::Conn {} impl<'a> ToConnection<'a, 'static> for &'a mut crate::Conn {} impl<'a, 't> ToConnection<'a, 't> for &'a mut crate::Transaction<'t> {} @@ -607,7 +607,9 @@ pub mod test_misc { #[allow(unreachable_code)] fn error_should_implement_send_and_sync() { fn _dummy(_: T) {} - _dummy(panic!()); + #[allow(unused_variables)] + let err: crate::Error = panic!(); + _dummy(err); } lazy_static! { @@ -629,7 +631,7 @@ pub mod test_misc { } pub fn get_opts() -> OptsBuilder { - let mut builder = OptsBuilder::from_opts(Opts::from_url(&**DATABASE_URL).unwrap()); + let mut builder = OptsBuilder::from_opts(Opts::from_url(&DATABASE_URL).unwrap()); if test_ssl() { let ssl_opts = SslOpts::default() .with_danger_skip_domain_validation(true) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 346805ff..2ecb0e3a 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -1104,7 +1104,7 @@ impl OptsBuilder { OptsBuilder { tcp_port: opts.inner.address.get_tcp_port(), ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(), - opts: (*opts.inner).mysql_opts.clone(), + opts: opts.inner.mysql_opts.clone(), } } @@ -1423,15 +1423,11 @@ fn get_opts_user_from_url(url: &Url) -> Option { } fn get_opts_pass_from_url(url: &Url) -> Option { - if let Some(pass) = url.password() { - Some( - percent_decode(pass.as_ref()) - .decode_utf8_lossy() - .into_owned(), - ) - } else { - None - } + url.password().map(|pass| { + percent_decode(pass.as_ref()) + .decode_utf8_lossy() + .into_owned() + }) } fn get_opts_db_name_from_url(url: &Url) -> Option { @@ -1458,9 +1454,9 @@ fn from_url_basic(url: &Url) -> std::result::Result<(MysqlOpts, Vec<(String, Str if url.cannot_be_a_base() || !url.has_host() { return Err(UrlError::Invalid); } - let user = get_opts_user_from_url(&url); - let pass = get_opts_pass_from_url(&url); - let db_name = get_opts_db_name_from_url(&url); + let user = get_opts_user_from_url(url); + let pass = get_opts_pass_from_url(url); + let db_name = get_opts_db_name_from_url(url); let query_pairs = url.query_pairs().into_owned().collect(); let opts = MysqlOpts { @@ -1483,7 +1479,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { for (key, value) in query_pairs { if key == "pool_min" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(value) => pool_min = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1493,7 +1489,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "pool_max" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(value) => pool_max = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1503,7 +1499,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "inactive_connection_ttl" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts @@ -1517,7 +1513,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "ttl_check_interval" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts @@ -1531,7 +1527,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "conn_ttl" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => opts.conn_ttl = Some(Duration::from_secs(value)), _ => { return Err(UrlError::InvalidParamValue { @@ -1541,7 +1537,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "abs_conn_ttl" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts @@ -1555,7 +1551,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "abs_conn_ttl_jitter" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts @@ -1569,7 +1565,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "tcp_keepalive" { - match u32::from_str(&*value) { + match u32::from_str(&value) { Ok(value) => opts.tcp_keepalive = Some(value), _ => { return Err(UrlError::InvalidParamValue { @@ -1579,7 +1575,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "max_allowed_packet" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(value) => { opts.max_allowed_packet = Some(std::cmp::max(1024, std::cmp::min(1073741824, value))) @@ -1592,7 +1588,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "wait_timeout" { - match usize::from_str(&*value) { + match usize::from_str(&value) { #[cfg(windows)] Ok(value) => opts.wait_timeout = Some(std::cmp::min(2147483, value)), #[cfg(not(windows))] @@ -1605,7 +1601,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "enable_cleartext_plugin" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(parsed) => opts.enable_cleartext_plugin = parsed, Err(_) => { return Err(UrlError::InvalidParamValue { @@ -1615,7 +1611,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "reset_connection" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(parsed) => opts.pool_opts = opts.pool_opts.with_reset_connection(parsed), Err(_) => { return Err(UrlError::InvalidParamValue { @@ -1625,7 +1621,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "tcp_nodelay" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(value) => opts.tcp_nodelay = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1635,7 +1631,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "stmt_cache_size" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(stmt_cache_size) => { opts.stmt_cache_size = stmt_cache_size; } @@ -1647,7 +1643,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "prefer_socket" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(prefer_socket) => { opts.prefer_socket = prefer_socket; } @@ -1659,7 +1655,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "secure_auth" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(secure_auth) => { opts.secure_auth = secure_auth; } @@ -1671,7 +1667,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "client_found_rows" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(client_found_rows) => { opts.client_found_rows = client_found_rows; } @@ -1702,7 +1698,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } else if key == "require_ssl" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(x) => opts.ssl_opts = x.then(SslOpts::default), _ => { return Err(UrlError::InvalidParamValue { @@ -1712,7 +1708,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "verify_ca" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(x) => { accept_invalid_certs = !x; } @@ -1724,7 +1720,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "verify_identity" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(x) => { skip_domain_validation = !x; } diff --git a/src/queryable/query_result/mod.rs b/src/queryable/query_result/mod.rs index 21f45f96..94a3de96 100644 --- a/src/queryable/query_result/mod.rs +++ b/src/queryable/query_result/mod.rs @@ -174,16 +174,16 @@ where columns: Arc<[Column]>, ) -> crate::Result> { if let Some(row) = self.next_row(columns).await? { - return Ok(Some(row)); + Ok(Some(row)) } else { self.next_set().await?; - return Ok(None); + Ok(None) } } /// Skips the taken result set. async fn skip_taken(&mut self, meta: Arc) -> crate::Result<()> { - while let Some(_) = self.next_row_or_next_set((*meta).clone()).await? {} + while (self.next_row_or_next_set((*meta).clone()).await?).is_some() {} Ok(()) } diff --git a/src/queryable/query_result/tests.rs b/src/queryable/query_result/tests.rs index 24c870cf..3aa80081 100644 --- a/src/queryable/query_result/tests.rs +++ b/src/queryable/query_result/tests.rs @@ -162,7 +162,7 @@ async fn should_map_resultset() -> super::Result<()> { ) .await?; - let rows_1 = result.map(|row| from_row::<(String, u8)>(row)).await?; + let rows_1 = result.map(from_row::<(String, u8)>).await?; let rows_2 = result.map_and_drop(from_row).await?; conn.disconnect().await?; @@ -219,7 +219,7 @@ async fn should_handle_multi_result_sets_where_some_results_have_no_output() -> r.for_each_and_drop(|x| assert_eq!(from_row::(x), 1)) .await?; let r = t.query_iter(QUERY).await?; - let out = r.map_and_drop(|row| from_row::(row)).await?; + let out = r.map_and_drop(from_row::).await?; assert_eq!(vec![1], out); let r = t.query_iter(QUERY).await?; let out = r diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 0624fb50..0ce7294b 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -306,7 +306,7 @@ impl crate::Conn { let packets = self.read_packets(num).await?; let defs = packets .into_iter() - .map(|x| ParseBuf(&*x).parse(())) + .map(|x| ParseBuf(&x).parse(())) .collect::, _>>() .map_err(Error::from)?; diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index 53b38ed8..eb13a000 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -195,7 +195,7 @@ impl Deref for Transaction<'_> { type Target = Conn; fn deref(&self) -> &Self::Target { - &*self.0 + &self.0 } } diff --git a/tests/generic.rs b/tests/generic.rs index 3f14891d..b0451f20 100644 --- a/tests/generic.rs +++ b/tests/generic.rs @@ -33,7 +33,7 @@ where TupleType: FromRow + Send + 'static, P: Protocol + Send + 'static, { - Ok(result.collect().await?) + result.collect().await } pub async fn get_single_result(result: QueryResult<'_, '_, P>) -> Result @@ -51,7 +51,7 @@ where #[tokio::test] async fn use_generic_code() { - let pool = Pool::new(Opts::from_url(&*get_url()).unwrap()); + let pool = Pool::new(Opts::from_url(&get_url()).unwrap()); let mut conn = pool.get_conn().await.unwrap(); let result = conn.query_iter("SELECT 1, 2, 3").await.unwrap(); let result = get_single_result::<(u8, u8, u8), _>(result).await.unwrap(); From e5b07df965e0c3b29781b1138ebb8165f7ebbb5e Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 8 Nov 2023 12:20:59 +0300 Subject: [PATCH 039/130] Bump lru version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e1939be7..9ce0912a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lazy_static = "1" -lru = "0.11.0" +lru = "0.12.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } mysql_common = { version = "0.31", default-features = false } once_cell = "1.7.2" From 514d6dba101f6ef80cfc7f711303586060d52f71 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 8 Nov 2023 18:39:50 +0300 Subject: [PATCH 040/130] Document the binlog feature --- README.md | 5 +++++ src/lib.rs | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/README.md b/README.md index d469ddb7..e4bc948e 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ as well as `native-tls`-based TLS support. - `mysql_common/time03` - `mysql_common/uuid` - `mysql_common/frunk` + - `binlog` * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. @@ -96,6 +97,10 @@ as well as `native-tls`-based TLS support. * `derive` – enables `mysql_commom/derive` feature +* `binlog` - enables binlog-related functionality. Enables: + + - `mysql_common/binlog" + [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features ## TLS/SSL Support diff --git a/src/lib.rs b/src/lib.rs index 9d94ceaf..af44f67f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,6 +53,7 @@ //! - `mysql_common/time03` //! - `mysql_common/uuid` //! - `mysql_common/frunk` +//! - `binlog` //! //! * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. //! @@ -95,6 +96,10 @@ //! //! * `derive` – enables `mysql_commom/derive` feature //! +//! * `binlog` - enables binlog-related functionality. Enables: +//! +//! - `mysql_common/binlog" +//! //! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features //! //! # TLS/SSL Support From eb735aac8d2af647831d20b8fa42e13a33c0878e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 16 Nov 2023 01:58:29 -0100 Subject: [PATCH 041/130] Implement Borrow for QueuedWaker This allows using self.queue.remove(&id) directly, which then allows other simplifications. --- src/conn/pool/mod.rs | 41 +++++++++++++---------------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index df3ac251..926ae7db 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -11,7 +11,8 @@ use keyed_priority_queue::KeyedPriorityQueue; use tokio::sync::mpsc; use std::{ - cmp::{Ordering, Reverse}, + borrow::Borrow, + cmp::Reverse, collections::VecDeque, convert::TryFrom, hash::{Hash, Hasher}, @@ -108,29 +109,19 @@ struct Waitlist { } impl Waitlist { - fn push(&mut self, w: Waker, queue_id: QueueId) { - self.queue.push( - QueuedWaker { - queue_id, - waker: Some(w), - }, - queue_id, - ); + fn push(&mut self, waker: Waker, queue_id: QueueId) { + self.queue.push(QueuedWaker { queue_id, waker }, queue_id); } fn pop(&mut self) -> Option { match self.queue.pop() { - Some((qw, _)) => Some(qw.waker.unwrap()), + Some((qw, _)) => Some(qw.waker), None => None, } } fn remove(&mut self, id: QueueId) { - let tmp = QueuedWaker { - queue_id: id, - waker: None, - }; - self.queue.remove(&tmp); + self.queue.remove(&id); } fn is_empty(&self) -> bool { @@ -154,26 +145,20 @@ impl QueueId { #[derive(Debug)] struct QueuedWaker { queue_id: QueueId, - waker: Option, + waker: Waker, } impl Eq for QueuedWaker {} -impl PartialEq for QueuedWaker { - fn eq(&self, other: &Self) -> bool { - self.queue_id == other.queue_id +impl Borrow for QueuedWaker { + fn borrow(&self) -> &QueueId { + &self.queue_id } } -impl Ord for QueuedWaker { - fn cmp(&self, other: &Self) -> Ordering { - self.queue_id.cmp(&other.queue_id) - } -} - -impl PartialOrd for QueuedWaker { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) +impl PartialEq for QueuedWaker { + fn eq(&self, other: &Self) -> bool { + self.queue_id == other.queue_id } } From e0f7517b6ac38fb516caa46c0d00f4dfb01c0d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 16 Nov 2023 12:50:11 -0100 Subject: [PATCH 042/130] Refactor pool creation in tests --- src/conn/pool/mod.rs | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index df3ac251..3949ef03 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -423,6 +423,12 @@ mod test { }; } + fn pool_with_one_connection() -> Pool { + let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap()); + let opts = get_opts().pool_opts(pool_opts.clone()); + Pool::new(opts) + } + #[tokio::test] async fn should_opt_out_of_connection_reset() -> super::Result<()> { let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap()); @@ -571,10 +577,7 @@ mod test { #[tokio::test] async fn should_reuse_connections() -> super::Result<()> { - let constraints = PoolConstraints::new(1, 1).unwrap(); - let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints)); - - let pool = Pool::new(opts); + let pool = pool_with_one_connection(); let mut conn = pool.get_conn().await?; let server_version = conn.server_version(); @@ -613,10 +616,7 @@ mod test { #[tokio::test] async fn should_start_transaction() -> super::Result<()> { - let constraints = PoolConstraints::new(1, 1).unwrap(); - let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints)); - - let pool = Pool::new(opts); + let pool = pool_with_one_connection(); "CREATE TABLE IF NOT EXISTS mysql.tmp(id int)" .ignore(&pool) @@ -909,10 +909,7 @@ mod test { #[tokio::test] async fn should_ignore_non_fatal_errors_while_returning_to_a_pool() -> super::Result<()> { - let pool_constraints = PoolConstraints::new(1, 1).unwrap(); - let pool_opts = PoolOpts::default().with_constraints(pool_constraints); - - let pool = Pool::new(get_opts().pool_opts(pool_opts)); + let pool = pool_with_one_connection(); let id = pool.get_conn().await?.id(); // non-fatal errors are ignored @@ -927,10 +924,7 @@ mod test { #[tokio::test] async fn should_remove_waker_of_cancelled_task() { - let pool_constraints = PoolConstraints::new(1, 1).unwrap(); - let pool_opts = PoolOpts::default().with_constraints(pool_constraints); - - let pool = Pool::new(get_opts().pool_opts(pool_opts)); + let pool = pool_with_one_connection(); let only_conn = pool.get_conn().await.unwrap(); let join_handle = tokio::spawn(timeout(Duration::from_secs(1), pool.get_conn())); From 906c453393c5b20220cf24a2274b2d4468ad8a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 14 Dec 2023 15:04:37 -0100 Subject: [PATCH 043/130] Add a failing test The next patch fixes it, but by adding a failing test we get a documentation of what the bug was. --- src/conn/pool/mod.rs | 68 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 3949ef03..59a67ea3 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -392,14 +392,14 @@ impl Drop for Conn { #[cfg(test)] mod test { use futures_util::{ - future::{join_all, select, select_all, try_join_all}, - try_join, FutureExt, + future::{join_all, select, select_all, try_join_all, Either}, + poll, try_join, FutureExt, }; use tokio::time::{sleep, timeout}; use std::{ cmp::Reverse, - task::{RawWaker, RawWakerVTable, Waker}, + task::{Poll, RawWaker, RawWakerVTable, Waker}, time::Duration, }; @@ -1053,6 +1053,68 @@ mod test { Ok(()) } + #[tokio::test] + async fn check_priorities() -> super::Result<()> { + let pool = pool_with_one_connection(); + + let queue_len = || { + let exchange = pool.inner.exchange.lock().unwrap(); + exchange.waiting.queue.len() + }; + + // Get a connection, so we know the next futures will be + // queued. + let conn = pool.get_conn().await.unwrap(); + + let get_pending = || async { + let fut = async { + pool.get_conn().await.unwrap(); + } + .shared(); + let p = poll!(fut.clone()); + assert!(matches!(p, Poll::Pending)); + fut + }; + + let fut1 = get_pending().await; + let fut2 = get_pending().await; + + // Both futures are queued + assert_eq!(queue_len(), 2); + + drop(conn); // This will pop fut1 from the queue, making it [2] + while queue_len() != 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + + // We called wake on fut1, but with the select fut2 will + // resolve first + let Either::Left((_, fut1)) = select(fut2, fut1).await else { + panic!("wrong future"); + }; + + // We dropped the connection of fut2, but very likely hasn't + // made it through the recycler yet. + assert_eq!(queue_len(), 1); + + let p = poll!(fut1.clone()); + assert!(matches!(p, Poll::Pending)); + assert_eq!(queue_len(), 2); // Now fut1 is queued again + + // The connection will pass by the recycler and unblock fut1 + // and pop it from the queue. + fut1.await; + assert_eq!(queue_len(), 1); + + // Since the queue is not empty, a new future will be pending + let fut3 = get_pending().await; + assert_eq!(queue_len(), 2); + + println!("we get here"); + fut3.await; + panic!("we never get here"); + } + #[cfg(feature = "nightly")] mod bench { use futures_util::future::{FutureExt, TryFutureExt}; From be693e0798f9b72d1189fd69d494922fcce5bcf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 16 Nov 2023 12:19:41 -0100 Subject: [PATCH 044/130] Use an explicit priority check This fixes the case where wake is called for one future, but another future gets the connection. --- src/conn/pool/futures/get_conn.rs | 4 +-- src/conn/pool/mod.rs | 45 ++++++++++++++++++------------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 8b21e685..22543fad 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -112,10 +112,8 @@ impl Future for GetConn { loop { match self.inner { GetConnInner::New => { - let queued = self.queue_id.is_some(); let queue_id = *self.queue_id.get_or_insert_with(QueueId::next); - let next = - ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queued, queue_id))?; + let next = ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queue_id))?; match next { GetConnInner::Connecting(conn_fut) => { self.inner = GetConnInner::Connecting(conn_fut); diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 59a67ea3..f13a5a2c 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -133,8 +133,8 @@ impl Waitlist { self.queue.remove(&tmp); } - fn is_empty(&self) -> bool { - self.queue.is_empty() + fn peek_id(&mut self) -> Option { + self.queue.peek().map(|(qw, _)| qw.queue_id) } } @@ -303,16 +303,14 @@ impl Pool { fn poll_new_conn( self: Pin<&mut Self>, cx: &mut Context<'_>, - queued: bool, queue_id: QueueId, ) -> Poll> { - self.poll_new_conn_inner(cx, queued, queue_id) + self.poll_new_conn_inner(cx, queue_id) } fn poll_new_conn_inner( self: Pin<&mut Self>, cx: &mut Context<'_>, - queued: bool, queue_id: QueueId, ) -> Poll> { let mut exchange = self.inner.exchange.lock().unwrap(); @@ -326,8 +324,15 @@ impl Pool { exchange.spawn_futures_if_needed(&self.inner); - // Check if others are waiting and we're not queued. - if !exchange.waiting.is_empty() && !queued { + // Check if we are higher priority than anything current + let highest = if let Some(cur) = exchange.waiting.peek_id() { + queue_id > cur + } else { + true + }; + + // If we are not, just queue + if !highest { exchange.waiting.push(cx.waker().clone(), queue_id); return Poll::Pending; } @@ -1087,32 +1092,34 @@ mod test { tokio::time::sleep(Duration::from_millis(100)).await; } - // We called wake on fut1, but with the select fut2 will + // We called wake on fut1, and even with the select fut1 will // resolve first - let Either::Left((_, fut1)) = select(fut2, fut1).await else { + let Either::Right((_, fut2)) = select(fut2, fut1).await else { panic!("wrong future"); }; - // We dropped the connection of fut2, but very likely hasn't + // We dropped the connection of fut1, but very likely hasn't // made it through the recycler yet. assert_eq!(queue_len(), 1); - let p = poll!(fut1.clone()); + let p = poll!(fut2.clone()); assert!(matches!(p, Poll::Pending)); - assert_eq!(queue_len(), 2); // Now fut1 is queued again + assert_eq!(queue_len(), 1); // The queue still has fut2 - // The connection will pass by the recycler and unblock fut1 + // The connection will pass by the recycler and unblock fut2 // and pop it from the queue. - fut1.await; - assert_eq!(queue_len(), 1); + fut2.await; + assert_eq!(queue_len(), 0); - // Since the queue is not empty, a new future will be pending + // The recycler is probably not done, so a new future will be + // pending. let fut3 = get_pending().await; - assert_eq!(queue_len(), 2); + assert_eq!(queue_len(), 1); - println!("we get here"); + // It is OK to await it. fut3.await; - panic!("we never get here"); + + Ok(()) } #[cfg(feature = "nightly")] From 5b525ee11e66a0888a84a2c94f27c9626a5a3b89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 16 Nov 2023 11:36:51 -0100 Subject: [PATCH 045/130] Inline poll_new_conn_inner into only caller --- src/conn/pool/mod.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index f13a5a2c..ae89a414 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -304,14 +304,6 @@ impl Pool { self: Pin<&mut Self>, cx: &mut Context<'_>, queue_id: QueueId, - ) -> Poll> { - self.poll_new_conn_inner(cx, queue_id) - } - - fn poll_new_conn_inner( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - queue_id: QueueId, ) -> Poll> { let mut exchange = self.inner.exchange.lock().unwrap(); From 9ef8aea07f354d113ded3ea32905cec3fcac230f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Fri, 17 Nov 2023 01:28:13 -0100 Subject: [PATCH 046/130] Make queue_id non-optional This is just a small simplification that avoids having to think about when it is None. --- src/conn/pool/futures/get_conn.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 22543fad..166bf7a7 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -66,7 +66,7 @@ impl GetConnInner { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct GetConn { - pub(crate) queue_id: Option, + pub(crate) queue_id: QueueId, pub(crate) pool: Option, pub(crate) inner: GetConnInner, reset_upon_returning_to_a_pool: bool, @@ -77,7 +77,7 @@ pub struct GetConn { impl GetConn { pub(crate) fn new(pool: &Pool, reset_upon_returning_to_a_pool: bool) -> GetConn { GetConn { - queue_id: None, + queue_id: QueueId::next(), pool: Some(pool.clone()), inner: GetConnInner::New, reset_upon_returning_to_a_pool, @@ -112,7 +112,7 @@ impl Future for GetConn { loop { match self.inner { GetConnInner::New => { - let queue_id = *self.queue_id.get_or_insert_with(QueueId::next); + let queue_id = self.queue_id; let next = ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queue_id))?; match next { GetConnInner::Connecting(conn_fut) => { @@ -185,9 +185,7 @@ impl Drop for GetConn { if let Some(pool) = self.pool.take() { // Remove the waker from the pool's waitlist in case this task was // woken by another waker, like from tokio::time::timeout. - if let Some(queue_id) = self.queue_id { - pool.unqueue(queue_id); - } + pool.unqueue(self.queue_id); if let GetConnInner::Connecting(..) = self.inner.take() { pool.cancel_connection(); } From 31affa38a6f7177c64dd2c633b88c0069b3fdfb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Tue, 26 Dec 2023 11:16:29 -0100 Subject: [PATCH 047/130] Silence clippy warnings In these cases I think the code is better as is, so I just silenced clippy. --- src/conn/binlog_stream/mod.rs | 1 + src/conn/pool/mod.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs index 24af9202..60aa6659 100644 --- a/src/conn/binlog_stream/mod.rs +++ b/src/conn/binlog_stream/mod.rs @@ -166,6 +166,7 @@ impl futures_core::stream::Stream for BinlogStream { Ok(Some(event)) => { if event.header().event_type_raw() == EventType::TRANSACTION_PAYLOAD_EVENT as u8 { + #[allow(clippy::single_match)] match event.read_event::>() { Ok(e) => self.tpe = Some(Cursor::new(e.danger_decompress())), Err(_) => (/* TODO: Log the error */), diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 92412b22..42ee589f 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -1048,6 +1048,7 @@ mod test { // queued. let conn = pool.get_conn().await.unwrap(); + #[allow(clippy::async_yields_async)] let get_pending = || async { let fut = async { pool.get_conn().await.unwrap(); From 3de12e3d61fc6511c4117e2193c6b8bc35053b9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Tue, 26 Dec 2023 13:01:25 -0100 Subject: [PATCH 048/130] Drop unnecessary use of Pin --- src/conn/pool/futures/get_conn.rs | 2 +- src/conn/pool/mod.rs | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 166bf7a7..76dedb20 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -113,7 +113,7 @@ impl Future for GetConn { match self.inner { GetConnInner::New => { let queue_id = self.queue_id; - let next = ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queue_id))?; + let next = ready!(self.pool_mut().poll_new_conn(cx, queue_id))?; match next { GetConnInner::Connecting(conn_fut) => { self.inner = GetConnInner::Connecting(conn_fut); diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 92412b22..630dcbf4 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -16,7 +16,6 @@ use std::{ collections::VecDeque, convert::TryFrom, hash::{Hash, Hasher}, - pin::Pin, str::FromStr, sync::{atomic, Arc, Mutex}, task::{Context, Poll, Waker}, @@ -286,7 +285,7 @@ impl Pool { /// Poll the pool for an available connection. fn poll_new_conn( - self: Pin<&mut Self>, + &mut self, cx: &mut Context<'_>, queue_id: QueueId, ) -> Poll> { From ef20c7e888c1a7ea36c0542ae096130975560be8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Thu, 16 Nov 2023 14:09:35 -0100 Subject: [PATCH 049/130] Always save the most recent waker --- src/conn/pool/mod.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 92412b22..9d00f364 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -110,6 +110,18 @@ struct Waitlist { impl Waitlist { fn push(&mut self, waker: Waker, queue_id: QueueId) { + // The documentation of Future::poll says: + // Note that on multiple calls to poll, only the Waker from + // the Context passed to the most recent call should be + // scheduled to receive a wakeup. + // + // But the the documentation of KeyedPriorityQueue::push says: + // Adds new element to queue if missing key or replace its + // priority if key exists. In second case doesn’t replace key. + // + // This means we have to remove first to have the most recent + // waker in the queue. + self.remove(queue_id); self.queue.push(QueuedWaker { queue_id, waker }, queue_id); } From 553357c3150eb9531a980f0d010ef0c95d264b0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Tue, 26 Dec 2023 16:57:29 -0100 Subject: [PATCH 050/130] Delete GetConnInner::take It is only used in GetConn::drop and not actually needed in there. --- src/conn/pool/futures/get_conn.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 166bf7a7..ce7faa78 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -55,13 +55,6 @@ impl fmt::Debug for GetConnInner { } } -impl GetConnInner { - /// Take the value of the inner connection, resetting it to `New`. - pub fn take(&mut self) -> GetConnInner { - std::mem::replace(self, GetConnInner::New) - } -} - /// This future will take connection from a pool and resolve to [`Conn`]. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] @@ -186,7 +179,7 @@ impl Drop for GetConn { // Remove the waker from the pool's waitlist in case this task was // woken by another waker, like from tokio::time::timeout. pool.unqueue(self.queue_id); - if let GetConnInner::Connecting(..) = self.inner.take() { + if let GetConnInner::Connecting(..) = self.inner { pool.cancel_connection(); } } From fce95539518bbfc0fada4a10b4ba147e45762c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafael=20=C3=81vila=20de=20Esp=C3=ADndola?= Date: Mon, 25 Dec 2023 16:51:08 -0100 Subject: [PATCH 051/130] Add a test for the previous patch --- Cargo.toml | 1 + src/conn/pool/mod.rs | 46 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9ce0912a..e409e126 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -72,6 +72,7 @@ version = "0.25.0" optional = true [dev-dependencies] +waker-fn = "1" tempfile = "3.1.0" socket2 = { version = "0.5.2", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 9d00f364..ffdb26bc 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -390,10 +390,14 @@ mod test { poll, try_join, FutureExt, }; use tokio::time::{sleep, timeout}; + use waker_fn::waker_fn; use std::{ cmp::Reverse, - task::{Poll, RawWaker, RawWakerVTable, Waker}, + future::Future, + pin::pin, + sync::{Arc, OnceLock}, + task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, time::Duration, }; @@ -1047,6 +1051,46 @@ mod test { Ok(()) } + #[tokio::test] + async fn save_last_waker() { + // Test that if passed multiple wakers, we call the last one. + + let pool = pool_with_one_connection(); + + // Get a connection, so we know the next future will be + // queued. + let conn = pool.get_conn().await.unwrap(); + let mut pending_fut = pin!(pool.get_conn()); + + let build_waker = || { + let called = Arc::new(OnceLock::new()); + let called2 = called.clone(); + let waker = waker_fn(move || called2.set(()).unwrap()); + (called, waker) + }; + + let mut assert_pending = |waker| { + let mut context = Context::from_waker(&waker); + let p = pending_fut.as_mut().poll(&mut context); + assert!(matches!(p, Poll::Pending)); + }; + + let (first_called, waker) = build_waker(); + assert_pending(waker); + + let (second_called, waker) = build_waker(); + assert_pending(waker); + + drop(conn); + + while second_called.get().is_none() { + assert!(first_called.get().is_none()); + tokio::time::sleep(Duration::from_millis(100)).await; + } + + assert!(first_called.get().is_none()); + } + #[tokio::test] async fn check_priorities() -> super::Result<()> { let pool = pool_with_one_connection(); From d74888d06edaf2a8c7697c47cac2517810eeb486 Mon Sep 17 00:00:00 2001 From: Alessandro Chitolina Date: Wed, 10 Jan 2024 10:45:56 +0100 Subject: [PATCH 052/130] fixed wrong string check (fix #283) --- src/io/tls/rustls_io.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index a976b961..ceccc539 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -136,7 +136,7 @@ impl ServerCertVerifier for DangerousVerifier { ) { Ok(assertion) => Ok(assertion), Err(ref e) - if e.to_string().contains("CertNotValidForName") + if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { Ok(rustls::client::ServerCertVerified::assertion()) From ecc7f3fc57952bed7a26b17f601c2c91f1710337 Mon Sep 17 00:00:00 2001 From: Roshan Jobanputra Date: Mon, 5 Feb 2024 15:47:01 -0500 Subject: [PATCH 053/130] Allow providing the root CA cert as bytes in addition to from a file --- src/io/tls/native_tls_io.rs | 28 +++++++++++++++++----------- src/io/tls/rustls_io.rs | 18 ++++++++++++++++-- src/opts/mod.rs | 18 ++++++++++++++++++ 3 files changed, 51 insertions(+), 13 deletions(-) diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index 04121d37..bc421be9 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -21,18 +21,13 @@ impl Endpoint { let mut root_cert_file = File::open(root_cert_path)?; root_cert_file.read_to_end(&mut root_cert_data)?; - let root_certs = Certificate::from_der(&root_cert_data) - .map(|x| vec![x]) - .or_else(|_| { - pem::parse_many(&*root_cert_data) - .unwrap_or_default() - .iter() - .map(pem::encode) - .map(|s| Certificate::from_pem(s.as_bytes())) - .collect() - })?; + for root_cert in parse_certs(&root_cert_data)? { + builder.add_root_certificate(root_cert); + } + } - for root_cert in root_certs { + if let Some(root_cert_data) = ssl_opts.root_cert() { + for root_cert in parse_certs(root_cert_data)? { builder.add_root_certificate(root_cert); } } @@ -63,3 +58,14 @@ impl Endpoint { Ok(()) } } + +fn parse_certs(buf: &[u8]) -> Result> { + Ok(Certificate::from_der(buf).map(|x| vec![x]).or_else(|_| { + pem::parse_many(buf) + .unwrap_or_default() + .iter() + .map(pem::encode) + .map(|s| Certificate::from_pem(s.as_bytes())) + .collect() + })?) +} diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index ceccc539..feb857dc 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -50,6 +50,21 @@ impl Endpoint { } } + if let Some(root_cert_data) = ssl_opts.root_cert() { + let mut root_certs = Vec::new(); + for cert in certs(&mut &*root_cert_data)? { + root_certs.push(Certificate(cert)); + } + + if root_certs.is_empty() && !root_cert_data.is_empty() { + root_certs.push(Certificate(root_cert_data.to_vec())); + } + + for cert in &root_certs { + root_store.add(cert)?; + } + } + let config_builder = ClientConfig::builder() .with_safe_defaults() .with_root_certificates(root_store.clone()); @@ -136,8 +151,7 @@ impl ServerCertVerifier for DangerousVerifier { ) { Ok(assertion) => Ok(assertion), Err(ref e) - if e.to_string().contains("NotValidForName") - && self.skip_domain_validation => + if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { Ok(rustls::client::ServerCertVerified::assertion()) } diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 2ecb0e3a..06effb2b 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -142,6 +142,7 @@ pub struct SslOpts { #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] client_identity: Option, root_cert_path: Option>, + root_cert: Option>, skip_domain_validation: bool, accept_invalid_certs: bool, } @@ -156,6 +157,8 @@ impl SslOpts { /// Sets path to a `pem` or `der` certificate of the root that connector will trust. /// /// Multiple certs are allowed in .pem files. + /// + /// Will be merged with any certs provided by `with_root_cert`. pub fn with_root_cert_path>>( mut self, root_cert_path: Option, @@ -164,6 +167,17 @@ impl SslOpts { self } + /// Sets the bytes for a `pem` or `der` encoded certificate of the root that the connector + /// will trust. + /// + /// Multiple certs are allowed in .pem format. + /// + /// Will be merged with any certs provided by `with_root_cert_path`. + pub fn with_root_cert(mut self, cert: Option>) -> Self { + self.root_cert = cert; + self + } + /// The way to not validate the server's domain /// name against its certificate (defaults to `false`). pub fn with_danger_skip_domain_validation(mut self, value: bool) -> Self { @@ -187,6 +201,10 @@ impl SslOpts { self.root_cert_path.as_ref().map(AsRef::as_ref) } + pub fn root_cert(&self) -> Option<&[u8]> { + self.root_cert.as_ref().map(AsRef::as_ref) + } + pub fn skip_domain_validation(&self) -> bool { self.skip_domain_validation } From ecc4908e7c98ce1050fc536739c41ec444d03349 Mon Sep 17 00:00:00 2001 From: Roshan Jobanputra Date: Mon, 5 Feb 2024 16:21:54 -0500 Subject: [PATCH 054/130] Allow specifying the client identity pkcs12 archive as bytes for native-tls --- src/io/tls/native_tls_io.rs | 12 +++++++----- src/opts/native_tls_opts.rs | 31 ++++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index bc421be9..83848ca5 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -33,11 +33,13 @@ impl Endpoint { } if let Some(client_identity) = ssl_opts.client_identity() { - let pkcs12_path = client_identity.pkcs12_path(); - let password = client_identity.password().unwrap_or(""); - - let der = std::fs::read(pkcs12_path)?; - let identity = Identity::from_pkcs12(&der, password)?; + let identity = if let Some(data) = client_identity.pkcs12_data() { + Identity::from_pkcs12(data, client_identity.password().unwrap_or(""))? + } else { + let path = client_identity.pkcs12_path(); + let der = std::fs::read(path)?; + Identity::from_pkcs12(&der, client_identity.password().unwrap_or(""))? + }; builder.identity(identity); } builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); diff --git a/src/opts/native_tls_opts.rs b/src/opts/native_tls_opts.rs index 49eb4c46..91c51624 100644 --- a/src/opts/native_tls_opts.rs +++ b/src/opts/native_tls_opts.rs @@ -2,9 +2,15 @@ use std::{borrow::Cow, path::Path}; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) enum Pkcs12Archive { + Path(Cow<'static, Path>), + Data(Vec), +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClientIdentity { - pkcs12_path: Cow<'static, Path>, + pkcs12_archive: Pkcs12Archive, password: Option>, } @@ -15,7 +21,15 @@ impl ClientIdentity { T: Into>, { Self { - pkcs12_path: pkcs12_path.into(), + pkcs12_archive: Pkcs12Archive::Path(pkcs12_path.into()), + password: None, + } + } + + /// Creates new identity with the given bytes for a pkcs12 archive. + pub fn new_from_bytes(pkcs12_data: Vec) -> Self { + Self { + pkcs12_archive: Pkcs12Archive::Data(pkcs12_data), password: None, } } @@ -31,7 +45,18 @@ impl ClientIdentity { /// Returns the pkcs12 archive path. pub fn pkcs12_path(&self) -> &Path { - self.pkcs12_path.as_ref() + match &self.pkcs12_archive { + Pkcs12Archive::Path(path) => path.as_ref(), + Pkcs12Archive::Data(_) => panic!("pkcs12 archive path is not set"), + } + } + + /// Returns the pkcs12 archive data, if set. + pub fn pkcs12_data(&self) -> Option<&[u8]> { + match &self.pkcs12_archive { + Pkcs12Archive::Data(data) => Some(data.as_ref()), + Pkcs12Archive::Path(_) => None, + } } /// Returns the archive password. From d78dfde1610358e53c0ef9bfbb6d4139adbe7826 Mon Sep 17 00:00:00 2001 From: Roshan Jobanputra Date: Wed, 7 Feb 2024 16:13:52 -0500 Subject: [PATCH 055/130] Allow overriding domain used for TLS hostname verification --- src/conn/mod.rs | 5 ++++- src/io/tls/rustls_io.rs | 3 +-- src/opts/mod.rs | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 252fbce7..577dd836 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -550,7 +550,10 @@ impl Conn { self.write_struct(&ssl_request).await?; let conn = self; let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable"); - let domain = conn.opts().ip_or_hostname().into(); + let domain = ssl_opts + .tls_hostname_override() + .unwrap_or_else(|| conn.opts().ip_or_hostname()) + .into(); conn.stream_mut()?.make_secure(domain, ssl_opts).await?; Ok(()) } else { diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index ceccc539..5edce5c9 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -136,8 +136,7 @@ impl ServerCertVerifier for DangerousVerifier { ) { Ok(assertion) => Ok(assertion), Err(ref e) - if e.to_string().contains("NotValidForName") - && self.skip_domain_validation => + if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { Ok(rustls::client::ServerCertVerified::assertion()) } diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 2ecb0e3a..b50aa4b0 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -144,6 +144,7 @@ pub struct SslOpts { root_cert_path: Option>, skip_domain_validation: bool, accept_invalid_certs: bool, + tls_hostname_override: Option>, } impl SslOpts { @@ -178,6 +179,18 @@ impl SslOpts { self } + /// If set, will override the hostname used to verify the server's certificate. + /// + /// This is useful when connecting to a server via a tunnel, where the server hostname + /// name is different from the hostname used to connect to the tunnel. + pub fn with_tls_hostname_override>>( + mut self, + domain: Option, + ) -> Self { + self.tls_hostname_override = domain.map(Into::into); + self + } + #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub fn client_identity(&self) -> Option<&ClientIdentity> { self.client_identity.as_ref() @@ -194,6 +207,10 @@ impl SslOpts { pub fn accept_invalid_certs(&self) -> bool { self.accept_invalid_certs } + + pub fn tls_hostname_override(&self) -> Option<&str> { + self.tls_hostname_override.as_ref().map(AsRef::as_ref) + } } /// Connection pool options. From 3dc7cfc415c9ae2bdfaa5b5fb889102421f9acdf Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Fri, 16 Feb 2024 13:03:28 +0200 Subject: [PATCH 056/130] annotate `Transaction` with `must_use` It turns out it's easy to get confused and thing that you're running queries inside a transaction when in fact the transaction object gets dropped and will be implicitly rolled back when the connection terminates. This PR adds a `must_use` annotation as an aid to developers that they need to do something with the return value of `start_transaction`. Signed-off-by: Petros Angelatos --- src/queryable/transaction.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index eb13a000..6ec64b87 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -126,6 +126,7 @@ impl fmt::Display for IsolationLevel { /// /// You should always call either `commit` or `rollback`, otherwise transaction will be rolled /// back implicitly when corresponding connection is dropped or queried. +#[must_use = "transaction object must be committed or rolled back explicitly"] #[derive(Debug)] pub struct Transaction<'a>(pub(crate) Connection<'a, 'static>); From 31e2d180d6496c1212b01f1ada8ef514a84e25b9 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 24 Feb 2024 14:23:27 +0300 Subject: [PATCH 057/130] Introduce PathOrBuf for certificate and key data --- src/io/tls/native_tls_io.rs | 42 +++++++---------- src/io/tls/rustls_io.rs | 59 ++++++++++-------------- src/opts/mod.rs | 92 +++++++++++++++++++++++++------------ src/opts/native_tls_opts.rs | 54 +++++++++------------- src/opts/rustls_opts.rs | 62 +++++++++++-------------- 5 files changed, 152 insertions(+), 157 deletions(-) diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index 83848ca5..b4d5f751 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -1,12 +1,23 @@ #![cfg(feature = "native-tls")] -use std::{fs::File, io::Read}; - -use native_tls::{Certificate, Identity, TlsConnector}; +use native_tls::{Certificate, TlsConnector}; use crate::io::Endpoint; use crate::{Result, SslOpts}; +impl SslOpts { + async fn load_root_certs(&self) -> crate::Result> { + let mut output = Vec::new(); + + for root_cert in self.root_certs() { + let root_cert_data = root_cert.read().await?; + output.extend(parse_certs(root_cert_data.as_ref())?); + } + + Ok(output) + } +} + impl Endpoint { pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { #[cfg(unix)] @@ -16,31 +27,12 @@ impl Endpoint { } let mut builder = TlsConnector::builder(); - if let Some(root_cert_path) = ssl_opts.root_cert_path() { - let mut root_cert_data = vec![]; - let mut root_cert_file = File::open(root_cert_path)?; - root_cert_file.read_to_end(&mut root_cert_data)?; - - for root_cert in parse_certs(&root_cert_data)? { - builder.add_root_certificate(root_cert); - } - } - - if let Some(root_cert_data) = ssl_opts.root_cert() { - for root_cert in parse_certs(root_cert_data)? { - builder.add_root_certificate(root_cert); - } + for root_cert in ssl_opts.load_root_certs().await? { + builder.add_root_certificate(root_cert); } if let Some(client_identity) = ssl_opts.client_identity() { - let identity = if let Some(data) = client_identity.pkcs12_data() { - Identity::from_pkcs12(data, client_identity.password().unwrap_or(""))? - } else { - let path = client_identity.pkcs12_path(); - let der = std::fs::read(path)?; - Identity::from_pkcs12(&der, client_identity.password().unwrap_or(""))? - }; - builder.identity(identity); + builder.identity(client_identity.load().await?); } builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index feb857dc..e8757d0b 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -7,13 +7,32 @@ use rustls::{ Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, }; -use tokio::{fs::File, io::AsyncReadExt}; - use rustls_pemfile::certs; use tokio_rustls::TlsConnector; use crate::{io::Endpoint, Result, SslOpts}; +impl SslOpts { + async fn load_root_certs(&self) -> crate::Result> { + let mut output = Vec::new(); + + for root_cert in self.root_certs() { + let root_cert_data = root_cert.read().await?; + let mut seen = false; + for cert in certs(&mut &*root_cert_data)? { + seen = true; + output.push(Certificate(cert)); + } + + if !seen && !root_cert_data.is_empty() { + output.push(Certificate(root_cert_data.into_owned())); + } + } + + Ok(output) + } +} + impl Endpoint { pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { #[cfg(unix)] @@ -31,38 +50,8 @@ impl Endpoint { ) })); - if let Some(root_cert_path) = ssl_opts.root_cert_path() { - let mut root_cert_data = vec![]; - let mut root_cert_file = File::open(root_cert_path).await?; - root_cert_file.read_to_end(&mut root_cert_data).await?; - - let mut root_certs = Vec::new(); - for cert in certs(&mut &*root_cert_data)? { - root_certs.push(Certificate(cert)); - } - - if root_certs.is_empty() && !root_cert_data.is_empty() { - root_certs.push(Certificate(root_cert_data)); - } - - for cert in &root_certs { - root_store.add(cert)?; - } - } - - if let Some(root_cert_data) = ssl_opts.root_cert() { - let mut root_certs = Vec::new(); - for cert in certs(&mut &*root_cert_data)? { - root_certs.push(Certificate(cert)); - } - - if root_certs.is_empty() && !root_cert_data.is_empty() { - root_certs.push(Certificate(root_cert_data.to_vec())); - } - - for cert in &root_certs { - root_store.add(cert)?; - } + for cert in ssl_opts.load_root_certs().await? { + root_store.add(&cert)?; } let config_builder = ClientConfig::builder() @@ -70,7 +59,7 @@ impl Endpoint { .with_root_certificates(root_store.clone()); let mut config = if let Some(identity) = ssl_opts.client_identity() { - let (cert_chain, priv_key) = identity.load()?; + let (cert_chain, priv_key) = identity.load().await?; config_builder.with_client_auth_cert(cert_chain, priv_key)? } else { config_builder.with_no_client_auth() diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 06effb2b..3d0ff7e2 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -22,9 +22,9 @@ use url::{Host, Url}; use std::{ borrow::Cow, convert::TryFrom, - fmt, + fmt, io, net::{Ipv4Addr, Ipv6Addr}, - path::Path, + path::{Path, PathBuf}, str::FromStr, sync::Arc, time::{Duration, Instant}, @@ -115,6 +115,57 @@ impl HostPortOrUrl { } } +/// Represents data that is either on-disk or in the buffer. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +pub enum PathOrBuf<'a> { + Path(Cow<'a, Path>), + Buf(Cow<'a, [u8]>), +} + +#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] +impl<'a> PathOrBuf<'a> { + /// Will either read data from disk or return the buffered data. + pub async fn read(&self) -> io::Result> { + match self { + PathOrBuf::Path(x) => tokio::fs::read(x.as_ref()).await.map(Cow::Owned), + PathOrBuf::Buf(x) => Ok(Cow::Borrowed(x.as_ref())), + } + } + + /// Borrows `self`. + pub fn borrow(&self) -> PathOrBuf<'_> { + match self { + PathOrBuf::Path(path) => PathOrBuf::Path(Cow::Borrowed(path.as_ref())), + PathOrBuf::Buf(data) => PathOrBuf::Buf(Cow::Borrowed(data.as_ref())), + } + } +} + +impl From for PathOrBuf<'static> { + fn from(value: PathBuf) -> Self { + Self::Path(Cow::Owned(value)) + } +} + +impl<'a> From<&'a Path> for PathOrBuf<'a> { + fn from(value: &'a Path) -> Self { + Self::Path(Cow::Borrowed(value)) + } +} + +impl From> for PathOrBuf<'static> { + fn from(value: Vec) -> Self { + Self::Buf(Cow::Owned(value)) + } +} + +impl<'a> From<&'a [u8]> for PathOrBuf<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Buf(Cow::Borrowed(value)) + } +} + /// Ssl Options. /// /// ``` @@ -125,7 +176,7 @@ impl HostPortOrUrl { /// // With native-tls /// # #[cfg(feature = "native-tls-tls")] /// let ssl_opts = SslOpts::default() -/// .with_client_identity(Some(ClientIdentity::new(Path::new("/path")) +/// .with_client_identity(Some(ClientIdentity::new(Path::new("/path").into()) /// .with_password("******") /// )); /// @@ -133,16 +184,15 @@ impl HostPortOrUrl { /// # #[cfg(feature = "rustls-tls")] /// let ssl_opts = SslOpts::default() /// .with_client_identity(Some(ClientIdentity::new( -/// Path::new("/path/to/chain"), -/// Path::new("/path/to/priv_key") +/// Path::new("/path/to/chain").into(), +/// Path::new("/path/to/priv_key").into(), /// ))); /// ``` #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct SslOpts { #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] client_identity: Option, - root_cert_path: Option>, - root_cert: Option>, + root_certs: Vec>, skip_domain_validation: bool, accept_invalid_certs: bool, } @@ -158,23 +208,9 @@ impl SslOpts { /// /// Multiple certs are allowed in .pem files. /// - /// Will be merged with any certs provided by `with_root_cert`. - pub fn with_root_cert_path>>( - mut self, - root_cert_path: Option, - ) -> Self { - self.root_cert_path = root_cert_path.map(Into::into); - self - } - - /// Sets the bytes for a `pem` or `der` encoded certificate of the root that the connector - /// will trust. - /// - /// Multiple certs are allowed in .pem format. - /// - /// Will be merged with any certs provided by `with_root_cert_path`. - pub fn with_root_cert(mut self, cert: Option>) -> Self { - self.root_cert = cert; + /// All the elements in `root_certs` will be merged. + pub fn with_root_certs(mut self, root_certs: Vec>) -> Self { + self.root_certs = root_certs; self } @@ -197,12 +233,8 @@ impl SslOpts { self.client_identity.as_ref() } - pub fn root_cert_path(&self) -> Option<&Path> { - self.root_cert_path.as_ref().map(AsRef::as_ref) - } - - pub fn root_cert(&self) -> Option<&[u8]> { - self.root_cert.as_ref().map(AsRef::as_ref) + pub fn root_certs(&self) -> &[PathOrBuf<'static>] { + &self.root_certs } pub fn skip_domain_validation(&self) -> bool { diff --git a/src/opts/native_tls_opts.rs b/src/opts/native_tls_opts.rs index 91c51624..7ab5be19 100644 --- a/src/opts/native_tls_opts.rs +++ b/src/opts/native_tls_opts.rs @@ -1,37 +1,30 @@ #![cfg(feature = "native-tls")] -use std::{borrow::Cow, path::Path}; +use std::borrow::Cow; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) enum Pkcs12Archive { - Path(Cow<'static, Path>), - Data(Vec), -} +use native_tls::Identity; + +use super::PathOrBuf; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClientIdentity { - pkcs12_archive: Pkcs12Archive, + pkcs12_archive: PathOrBuf<'static>, password: Option>, } impl ClientIdentity { - /// Creates new identity with the given path to the pkcs12 archive. - pub fn new(pkcs12_path: T) -> Self - where - T: Into>, - { + /// Creates new identity with the given pkcs12 archive. + pub fn new(pkcs12_archive: PathOrBuf<'static>) -> Self { Self { - pkcs12_archive: Pkcs12Archive::Path(pkcs12_path.into()), + pkcs12_archive, password: None, } } - /// Creates new identity with the given bytes for a pkcs12 archive. - pub fn new_from_bytes(pkcs12_data: Vec) -> Self { - Self { - pkcs12_archive: Pkcs12Archive::Data(pkcs12_data), - password: None, - } + /// Sets the pkcs12 archive. + pub fn with_pkcs12_archive(mut self, pkcs12_archive: PathOrBuf<'static>) -> Self { + self.pkcs12_archive = pkcs12_archive; + self } /// Sets the archive password. @@ -43,24 +36,19 @@ impl ClientIdentity { self } - /// Returns the pkcs12 archive path. - pub fn pkcs12_path(&self) -> &Path { - match &self.pkcs12_archive { - Pkcs12Archive::Path(path) => path.as_ref(), - Pkcs12Archive::Data(_) => panic!("pkcs12 archive path is not set"), - } - } - - /// Returns the pkcs12 archive data, if set. - pub fn pkcs12_data(&self) -> Option<&[u8]> { - match &self.pkcs12_archive { - Pkcs12Archive::Data(data) => Some(data.as_ref()), - Pkcs12Archive::Path(_) => None, - } + /// Returns the pkcs12 archive. + pub fn pkcs12_archive(&self) -> PathOrBuf<'_> { + self.pkcs12_archive.borrow() } /// Returns the archive password. pub fn password(&self) -> Option<&str> { self.password.as_ref().map(AsRef::as_ref) } + + pub(crate) async fn load(&self) -> crate::Result { + let der = self.pkcs12_archive.read().await?; + let password = self.password().unwrap_or_default(); + Ok(Identity::from_pkcs12(der.as_ref(), password)?) + } } diff --git a/src/opts/rustls_opts.rs b/src/opts/rustls_opts.rs index 143dc62d..562846d8 100644 --- a/src/opts/rustls_opts.rs +++ b/src/opts/rustls_opts.rs @@ -5,79 +5,73 @@ use rustls_pemfile::{certs, rsa_private_keys}; use std::{borrow::Cow, path::Path}; +use super::PathOrBuf; + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClientIdentity { - cert_chain_path: Cow<'static, Path>, - priv_key_path: Cow<'static, Path>, + cert_chain: PathOrBuf<'static>, + priv_key: PathOrBuf<'static>, } impl ClientIdentity { /// Creates new identity. /// - /// `cert_chain_path` - path to a certificate chain (in PEM or DER) - /// `priv_key_path` - path to a private key (in DER or PEM) (it'll take the first one) - pub fn new(cert_chain_path: T, priv_key_path: U) -> Self - where - T: Into>, - U: Into>, - { + /// `cert_chain` - certificate chain (in PEM or DER) + /// `priv_key` - private key (in DER or PEM) (it'll take the first one) + pub fn new(cert_chain: PathOrBuf<'static>, priv_key: PathOrBuf<'static>) -> Self { Self { - cert_chain_path: cert_chain_path.into(), - priv_key_path: priv_key_path.into(), + cert_chain, + priv_key, } } /// Sets the certificate chain path (in DER or PEM). - pub fn with_cert_chain_path(mut self, cert_chain_path: T) -> Self - where - T: Into>, - { - self.cert_chain_path = cert_chain_path.into(); + pub fn with_cert_chain(mut self, cert_chain: PathOrBuf<'static>) -> Self { + self.cert_chain = cert_chain; self } /// Sets the private key path (in DER or PEM) (it'll take the first one). - pub fn with_priv_key_path(mut self, priv_key_path: T) -> Self + pub fn with_priv_key(mut self, priv_key: PathOrBuf<'static>) -> Self where T: Into>, { - self.priv_key_path = priv_key_path.into(); + self.priv_key = priv_key; self } - /// Returns the certificate chain path. - pub fn cert_chain_path(&self) -> &Path { - self.cert_chain_path.as_ref() + /// Returns the certificate chain. + pub fn cert_chain(&self) -> PathOrBuf<'_> { + self.cert_chain.borrow() } - /// Returns the private key path. - pub fn priv_key_path(&self) -> &Path { - self.priv_key_path.as_ref() + /// Returns the private key. + pub fn priv_key(&self) -> PathOrBuf<'_> { + self.priv_key.borrow() } - pub(crate) fn load(&self) -> crate::Result<(Vec, PrivateKey)> { - let cert_data = std::fs::read(self.cert_chain_path.as_ref())?; - let key_data = std::fs::read(self.priv_key_path.as_ref())?; + pub(crate) async fn load(&self) -> crate::Result<(Vec, PrivateKey)> { + let cert_data = self.cert_chain.read().await?; + let key_data = self.priv_key.read().await?; let mut cert_chain = Vec::new(); if std::str::from_utf8(&cert_data).is_err() { - cert_chain.push(Certificate(cert_data)); + cert_chain.push(Certificate(cert_data.into_owned())); } else { for cert in certs(&mut &*cert_data)? { cert_chain.push(Certificate(cert)); } } - let priv_key; - if std::str::from_utf8(&key_data).is_err() { - priv_key = Some(PrivateKey(key_data)); + let priv_key = if std::str::from_utf8(&key_data).is_err() { + Some(PrivateKey(key_data.into_owned())) } else { - priv_key = rsa_private_keys(&mut &*key_data)? + rsa_private_keys(&mut &*key_data)? .into_iter() .take(1) .map(PrivateKey) - .next(); - } + .next() + }; Ok(( cert_chain, From 7390b733d88d2494f35a7c59a62840ad5f1e62af Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 26 Feb 2024 11:53:32 +0300 Subject: [PATCH 058/130] Rename with_tls_hostname_override -> with_danger_tls_hostname_override --- src/opts/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index b50aa4b0..3e023a91 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -181,9 +181,9 @@ impl SslOpts { /// If set, will override the hostname used to verify the server's certificate. /// - /// This is useful when connecting to a server via a tunnel, where the server hostname - /// name is different from the hostname used to connect to the tunnel. - pub fn with_tls_hostname_override>>( + /// This is useful when connecting to a server via a tunnel, where the server hostname is + /// different from the hostname used to connect to the tunnel. + pub fn with_danger_tls_hostname_override>>( mut self, domain: Option, ) -> Self { @@ -209,7 +209,7 @@ impl SslOpts { } pub fn tls_hostname_override(&self) -> Option<&str> { - self.tls_hostname_override.as_ref().map(AsRef::as_ref) + self.tls_hostname_override.as_deref() } } From 35e4c24dd1b905042d796d1b15a660ee64c54e17 Mon Sep 17 00:00:00 2001 From: Daniel Black Date: Wed, 13 Mar 2024 16:43:51 +1100 Subject: [PATCH 059/130] Use DROP USER in test DROP USER existed in 5.6 http://www.asktheway.org/official-documents/mysql/refman-5.6-en.html-chapter/sql-statements.html#drop-user 5.7 and MariaDB versions support an IF EXISTS option. https://jira.mariadb.org/browse/MDEV-7288 Using executable comment syntax we can make it compatible. When using DROP USER, like all user modifications, FLUSH PRIVILEGES isn't required. --- src/conn/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 577dd836..00de30f6 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1552,10 +1552,9 @@ mod test { .map(|x| ((x % (123 - 97)) + 97) as char) .collect(); - conn.query_drop("DELETE FROM mysql.user WHERE user = '__mats'") + conn.query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") .await .unwrap(); - conn.query_drop("FLUSH PRIVILEGES").await.unwrap(); if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) { if matches!(conn.server_version(), (5, 6, _)) { @@ -1587,8 +1586,6 @@ mod test { .unwrap(); }; - conn.query_drop("FLUSH PRIVILEGES").await.unwrap(); - let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap(); conn2 .change_user( From b975a726b474949939d4ad7cc4a3c1eca4119029 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Sat, 24 Feb 2024 12:07:02 +0300 Subject: [PATCH 060/130] Bump dependencies --- Cargo.toml | 12 ++++++------ src/conn/mod.rs | 5 ++++- src/error/mod.rs | 5 ++++- src/lib.rs | 4 +--- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e409e126..23b55fc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ keyed_priority_queue = "0.4" lazy_static = "1" lru = "0.12.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.31", default-features = false } +mysql_common = { version = "0.32", default-features = false } once_cell = "1.7.2" pem = "3.0" percent-encoding = "2.1.0" @@ -42,7 +42,7 @@ twox-hash = "1" url = "2.1" [dependencies.tokio-rustls] -version = "0.24.0" +version = "0.25" optional = true [dependencies.tokio-native-tls] @@ -54,12 +54,12 @@ version = "0.2" optional = true [dependencies.rustls] -version = "0.21.0" -features = ["dangerous_configuration"] +version = "0.22.2" +features = [] optional = true [dependencies.rustls-pemfile] -version = "1.0.1" +version = "2.1.0" optional = true [dependencies.webpki] @@ -68,7 +68,7 @@ features = ["std"] optional = true [dependencies.webpki-roots] -version = "0.25.0" +version = "0.26.1" optional = true [dev-dependencies] diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 00de30f6..8d58f72b 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -7,7 +7,6 @@ // modified, or distributed except according to those terms. use futures_util::FutureExt; -pub use mysql_common::named_params; use mysql_common::{ constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI}, @@ -575,6 +574,10 @@ impl Conn { Some(self.inner.auth_plugin.borrow()), self.capabilities(), Default::default(), // TODO: Add support + self.inner + .opts + .max_allowed_packet() + .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32, ); // Serialize here to satisfy borrow checker. diff --git a/src/error/mod.rs b/src/error/mod.rs index e14a2578..983f274e 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -243,7 +243,10 @@ impl From> for ServerError { ServerError { code: packet.error_code(), message: packet.message_str().into(), - state: packet.sql_state_str().into(), + state: packet + .sql_state_ref() + .map(|s| s.as_str().into_owned()) + .unwrap_or_else(|| "HY000".to_owned()), } } } diff --git a/src/lib.rs b/src/lib.rs index af44f67f..f0605910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -610,10 +610,8 @@ pub mod test_misc { #[allow(dead_code)] #[allow(unreachable_code)] - fn error_should_implement_send_and_sync() { + fn error_should_implement_send_and_sync(err: crate::Error) { fn _dummy(_: T) {} - #[allow(unused_variables)] - let err: crate::Error = panic!(); _dummy(err); } From 0540dc118a69f567c69133023117d7946589cc5a Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 18 Mar 2024 23:35:28 +0300 Subject: [PATCH 061/130] Fix ci for MySql 5.6 --- azure-pipelines.yml | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 359eb6cb..c3954334 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -94,9 +94,9 @@ jobs: maxParallel: 10 matrix: v80: - DB_VERSION: "8-debian" + DB_VERSION: "8.0-debian" v57: - DB_VERSION: "5-debian" + DB_VERSION: "5.7-debian" v56: DB_VERSION: "5.6" steps: @@ -114,9 +114,13 @@ jobs: displayName: Run MySql in Docker - bash: | docker exec container bash -l -c "mysql -uroot -ppassword -e \"SET old_passwords = 1; GRANT ALL PRIVILEGES ON *.* TO 'root2'@'%' IDENTIFIED WITH mysql_old_password AS 'password'; SET PASSWORD FOR 'root2'@'%' = OLD_PASSWORD('password')\""; + docker exec container bash -l -c "echo 'deb [trusted=yes] http://archive.debian.org/debian/ stretch main non-free contrib' > /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb-src [trusted=yes] http://archive.debian.org/debian/ stretch main non-free contrib ' >> /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb [trusted=yes] http://archive.debian.org/debian-security/ stretch/updates main non-free contrib' >> /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb [trusted=yes] http://repo.mysql.com/apt/debian/ stretch mysql-5.6' > /etc/apt/sources.list.d/mysql.list" condition: eq(variables['DB_VERSION'], '5.6') - bash: | - docker exec container bash -l -c "apt-get update" + docker exec container bash -l -c "apt-get --allow-unauthenticated -y update" docker exec container bash -l -c "apt-get install -y curl clang libssl-dev pkg-config build-essential" docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" displayName: Install Rust in docker From d71a75a0b16473db2a600c299481413a802d2d60 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 18 Mar 2024 23:36:25 +0300 Subject: [PATCH 062/130] ci: reduce matrix for TestMariaDb --- azure-pipelines.yml | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c3954334..387b5c4b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -142,20 +142,16 @@ jobs: strategy: maxParallel: 10 matrix: - v107: - DB_VERSION: "10.7" + v113: + DB_VERSION: "11.3" + v1011: + DB_VERSION: "10.11" v106: DB_VERSION: "10.6" v105: DB_VERSION: "10.5" v104: DB_VERSION: "10.4" - v103: - DB_VERSION: "10.3" - v102: - DB_VERSION: "10.2" - v101: - DB_VERSION: "10.1" steps: - bash: | sudo apt-get update From 414266537ba329d7853879e152a3e4a65a9b15ab Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 18 Mar 2024 23:39:16 +0300 Subject: [PATCH 063/130] Proxy mysql_common features --- Cargo.toml | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 23b55fc4..695bb843 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,26 +80,35 @@ tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } [features] default = [ "flate2/zlib", - "mysql_common/bigdecimal", - "mysql_common/rust_decimal", - "mysql_common/time", - "mysql_common/frunk", + "bigdecimal", + "rust_decimal", + "time", + "frunk", "derive", "native-tls-tls", "binlog", ] + default-rustls = [ "flate2/rust_backend", - "mysql_common/bigdecimal", - "mysql_common/rust_decimal", - "mysql_common/time", - "mysql_common/frunk", + "bigdecimal", + "rust_decimal", + "time", + "frunk", "derive", "rustls-tls", "binlog", ] + +# minimal feature set with system flate2 impl minimal = ["flate2/zlib"] +# minimal feature set with rust flate2 impl +minimal-rust = ["flate2/rust_backend"] + +# native-tls based TLS support native-tls-tls = ["native-tls", "tokio-native-tls"] + +# rustls based TLS support rustls-tls = [ "rustls", "tokio-rustls", @@ -107,11 +116,20 @@ rustls-tls = [ "webpki-roots", "rustls-pemfile", ] -tracing = ["dep:tracing"] + +# mysql_common features derive = ["mysql_common/derive"] -nightly = [] +chrono = ["mysql_common/chrono"] +time = ["mysql_common/time"] +bigdecimal = ["mysql_common/bigdecimal"] +rust_decimal = ["mysql_common/rust_decimal"] +frunk = ["mysql_common/frunk"] binlog = ["mysql_common/binlog"] +# other features +tracing = ["dep:tracing"] +nightly = [] + [lib] name = "mysql_async" path = "src/lib.rs" From 174f40f672ad1dfc54fcb25e439e46ea2bbeed88 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 18 Mar 2024 23:39:59 +0300 Subject: [PATCH 064/130] Bump version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 695bb843..2e0f56de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.33.0" +version = "0.34.0" exclude = ["test/*"] edition = "2018" categories = ["asynchronous", "database"] From c5aa660e5174157be89b397a340a4677159ddc03 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 19 Mar 2024 00:13:15 +0300 Subject: [PATCH 065/130] Fix rustls feature --- Cargo.toml | 2 +- src/error/mod.rs | 2 +- src/error/tls/rustls_error.rs | 11 ++++ src/io/tls/rustls_io.rs | 120 ++++++++++++++++++++++++---------- src/lib.rs | 3 +- src/opts/rustls_opts.rs | 26 ++++---- 6 files changed, 115 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2e0f56de..da62f97d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" version = "0.34.0" exclude = ["test/*"] -edition = "2018" +edition = "2021" categories = ["asynchronous", "database"] [dependencies] diff --git a/src/error/mod.rs b/src/error/mod.rs index 983f274e..ffb788f8 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -8,7 +8,7 @@ pub use url::ParseError; -mod tls; +pub mod tls; use mysql_common::{ named_params::MixedParamsError, params::MissingNamedParameterError, diff --git a/src/error/tls/rustls_error.rs b/src/error/tls/rustls_error.rs index 2ee67d39..faae88db 100644 --- a/src/error/tls/rustls_error.rs +++ b/src/error/tls/rustls_error.rs @@ -2,11 +2,14 @@ use std::fmt::Display; +use rustls::server::VerifierBuilderError; + #[derive(Debug)] pub enum TlsError { Tls(rustls::Error), Pki(webpki::Error), InvalidDnsName(webpki::InvalidDnsNameError), + VerifierBuilderError(VerifierBuilderError), } impl From for crate::Error { @@ -15,6 +18,12 @@ impl From for crate::Error { } } +impl From for TlsError { + fn from(e: VerifierBuilderError) -> Self { + TlsError::VerifierBuilderError(e) + } +} + impl From for TlsError { fn from(e: rustls::Error) -> Self { TlsError::Tls(e) @@ -57,6 +66,7 @@ impl std::error::Error for TlsError { TlsError::Tls(e) => Some(e), TlsError::Pki(e) => Some(e), TlsError::InvalidDnsName(e) => Some(e), + TlsError::VerifierBuilderError(e) => Some(e), } } } @@ -67,6 +77,7 @@ impl Display for TlsError { TlsError::Tls(e) => e.fmt(f), TlsError::Pki(e) => e.fmt(f), TlsError::InvalidDnsName(e) => e.fmt(f), + TlsError::VerifierBuilderError(e) => e.fmt(f), } } } diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index e8757d0b..76080ff2 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -1,31 +1,35 @@ #![cfg(feature = "rustls-tls")] -use std::{convert::TryInto, sync::Arc}; +use std::sync::Arc; use rustls::{ - client::{ServerCertVerifier, WebPkiVerifier}, - Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, + client::{ + danger::{ServerCertVerified, ServerCertVerifier}, + WebPkiServerVerifier, + }, + pki_types::{CertificateDer, ServerName}, + ClientConfig, RootCertStore, }; use rustls_pemfile::certs; use tokio_rustls::TlsConnector; -use crate::{io::Endpoint, Result, SslOpts}; +use crate::{io::Endpoint, Result, SslOpts, TlsError}; impl SslOpts { - async fn load_root_certs(&self) -> crate::Result> { + async fn load_root_certs(&self) -> crate::Result>> { let mut output = Vec::new(); for root_cert in self.root_certs() { let root_cert_data = root_cert.read().await?; let mut seen = false; - for cert in certs(&mut &*root_cert_data)? { + for cert in certs(&mut &*root_cert_data) { seen = true; - output.push(Certificate(cert)); + output.push(cert?); } if !seen && !root_cert_data.is_empty() { - output.push(Certificate(root_cert_data.into_owned())); + output.push(CertificateDer::from(root_cert_data.into_owned())); } } @@ -42,21 +46,13 @@ impl Endpoint { } let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); for cert in ssl_opts.load_root_certs().await? { - root_store.add(&cert)?; + root_store.add(cert)?; } - let config_builder = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store.clone()); + let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone()); let mut config = if let Some(identity) = ssl_opts.client_identity() { let (cert_chain, priv_key) = identity.load().await?; @@ -65,12 +61,13 @@ impl Endpoint { config_builder.with_no_client_auth() }; - let server_name = domain - .as_str() - .try_into() - .map_err(|_| webpki::InvalidDnsNameError)?; + let server_name = ServerName::try_from(domain.as_str()) + .map_err(|_| webpki::InvalidDnsNameError)? + .to_owned(); let mut dangerous = config.dangerous(); - let web_pki_verifier = WebPkiVerifier::new(root_store, None); + let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .map_err(TlsError::from)?; let dangerous_verifier = DangerousVerifier::new( ssl_opts.accept_invalid_certs(), ssl_opts.skip_domain_validation(), @@ -97,17 +94,18 @@ impl Endpoint { } } +#[derive(Debug)] struct DangerousVerifier { accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, } impl DangerousVerifier { fn new( accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, ) -> Self { Self { accept_invalid_certs, @@ -118,23 +116,51 @@ impl DangerousVerifier { } impl ServerCertVerifier for DangerousVerifier { + // fn verify_server_cert( + // &self, + // end_entity: &Certificate, + // intermediates: &[Certificate], + // server_name: &rustls::ServerName, + // scts: &mut dyn Iterator, + // ocsp_response: &[u8], + // now: std::time::SystemTime, + // ) -> std::result::Result { + // if self.accept_invalid_certs { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } else { + // match self.verifier.verify_server_cert( + // end_entity, + // intermediates, + // server_name, + // scts, + // ocsp_response, + // now, + // ) { + // Ok(assertion) => Ok(assertion), + // Err(ref e) + // if e.to_string().contains("NotValidForName") && self.skip_domain_validation => + // { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } + // Err(e) => Err(e), + // } + // } + // } fn verify_server_cert( &self, - end_entity: &Certificate, - intermediates: &[Certificate], - server_name: &rustls::ServerName, - scts: &mut dyn Iterator, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, ocsp_response: &[u8], - now: std::time::SystemTime, - ) -> std::result::Result { + now: rustls::pki_types::UnixTime, + ) -> std::prelude::v1::Result { if self.accept_invalid_certs { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } else { match self.verifier.verify_server_cert( end_entity, intermediates, server_name, - scts, ocsp_response, now, ) { @@ -142,10 +168,34 @@ impl ServerCertVerifier for DangerousVerifier { Err(ref e) if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } Err(e) => Err(e), } } } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.verifier.supported_verify_schemes() + } } diff --git a/src/lib.rs b/src/lib.rs index f0605910..3e2c7983 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -466,7 +466,8 @@ pub use self::conn::pool::Pool; #[doc(inline)] pub use self::error::{ - DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError, + tls::TlsError, DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, + UrlError, }; #[doc(inline)] diff --git a/src/opts/rustls_opts.rs b/src/opts/rustls_opts.rs index 562846d8..ef954ea9 100644 --- a/src/opts/rustls_opts.rs +++ b/src/opts/rustls_opts.rs @@ -1,6 +1,6 @@ #![cfg(feature = "rustls-tls")] -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer}; use rustls_pemfile::{certs, rsa_private_keys}; use std::{borrow::Cow, path::Path}; @@ -50,27 +50,31 @@ impl ClientIdentity { self.priv_key.borrow() } - pub(crate) async fn load(&self) -> crate::Result<(Vec, PrivateKey)> { + pub(crate) async fn load( + &self, + ) -> crate::Result<(Vec>, PrivateKeyDer<'static>)> { let cert_data = self.cert_chain.read().await?; let key_data = self.priv_key.read().await?; let mut cert_chain = Vec::new(); if std::str::from_utf8(&cert_data).is_err() { - cert_chain.push(Certificate(cert_data.into_owned())); + cert_chain.push(CertificateDer::from(cert_data.into_owned())); } else { - for cert in certs(&mut &*cert_data)? { - cert_chain.push(Certificate(cert)); + for cert in certs(&mut &*cert_data) { + cert_chain.push(cert?); } } let priv_key = if std::str::from_utf8(&key_data).is_err() { - Some(PrivateKey(key_data.into_owned())) + Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from( + key_data.into_owned(), + ))) } else { - rsa_private_keys(&mut &*key_data)? - .into_iter() - .take(1) - .map(PrivateKey) - .next() + let mut priv_key = None; + for key in rsa_private_keys(&mut &*key_data).take(1) { + priv_key = Some(PrivateKeyDer::Pkcs1(key?.clone_key())); + } + priv_key }; Ok(( From 9d5ced305afaa5c2e6decb5d190db4107f6bceea Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 19 Mar 2024 10:03:58 +0300 Subject: [PATCH 066/130] Fix should_change_user for mysql 5.6 --- src/conn/mod.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 8d58f72b..7457f3fd 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1547,7 +1547,7 @@ mod test { &["mysql_native_password"] }; - for plugin in plugins { + for (i, plugin) in plugins.iter().enumerate() { let mut rng = rand::thread_rng(); let mut pass = [0u8; 10]; pass.try_fill(&mut rng).unwrap(); @@ -1555,9 +1555,15 @@ mod test { .map(|x| ((x % (123 - 97)) + 97) as char) .collect(); - conn.query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") - .await - .unwrap(); + let result = conn + .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") + .await; + if matches!(conn.server_version(), (5, 6, _)) && i == 0 { + // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration + drop(result); + } else { + result.unwrap(); + } if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) { if matches!(conn.server_version(), (5, 6, _)) { From 59b1035fb17671492efac0421bb3272c3030502d Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 19 Mar 2024 22:30:54 +0300 Subject: [PATCH 067/130] docs: update the Crate Features section --- Cargo.toml | 3 ++- README.md | 32 +++++++++++++----------- src/error/tls/mod.rs | 4 +-- src/error/tls/native_tls_error.rs | 2 +- src/io/mod.rs | 4 +-- src/io/tls/mod.rs | 2 +- src/io/tls/native_tls_io.rs | 2 +- src/lib.rs | 41 +++++++++++++++++-------------- src/opts/mod.rs | 10 +++----- src/opts/native_tls_opts.rs | 2 +- 10 files changed, 54 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index da62f97d..d0fb5571 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,6 +117,8 @@ rustls-tls = [ "rustls-pemfile", ] +binlog = ["mysql_common/binlog"] + # mysql_common features derive = ["mysql_common/derive"] chrono = ["mysql_common/chrono"] @@ -124,7 +126,6 @@ time = ["mysql_common/time"] bigdecimal = ["mysql_common/bigdecimal"] rust_decimal = ["mysql_common/rust_decimal"] frunk = ["mysql_common/frunk"] -binlog = ["mysql_common/binlog"] # other features tracing = ["dep:tracing"] diff --git a/README.md b/README.md index e4bc948e..04ad34c6 100644 --- a/README.md +++ b/README.md @@ -37,23 +37,18 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", default-features = false, features = ["minimal"]} ``` - **Note:* it is possible to use another `flate2` backend by directly choosing it: +* `minimal-rust` - same as `minimal` but rust-based flate2 backend is chosen. Enables: - ```toml - [dependencies] - mysql_async = { version = "*", default-features = false } - flate2 = { version = "*", default-features = false, features = ["rust_backend"] } - ``` + - `flate2/rust_backend` -* `default` – enables the following set of crate's and dependencies' features: +* `default` – enables the following set of features: + - `minimal` - `native-tls-tls` - - `flate2/zlib" - - `mysql_common/bigdecimal03` - - `mysql_common/rust_decimal` - - `mysql_common/time03` - - `mysql_common/uuid` - - `mysql_common/frunk` + - `bigdecimal` + - `rust_decimal` + - `time` + - `frunk` - `binlog` * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. @@ -95,12 +90,19 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", features = ["tracing"] } ``` -* `derive` – enables `mysql_commom/derive` feature - * `binlog` - enables binlog-related functionality. Enables: - `mysql_common/binlog" +#### Proxied features + +* `derive` – enables `mysql_common/derive` feature +* `chrono` = enables `mysql_common/chrono` feature +* `time` = enables `mysql_common/time` feature +* `bigdecimal` = enables `mysql_common/bigdecimal` feature +* `rust_decimal` = enables `mysql_common/rust_decimal` feature +* `frunk` = enables `mysql_common/frunk` feature + [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features ## TLS/SSL Support diff --git a/src/error/tls/mod.rs b/src/error/tls/mod.rs index 220ed850..4048fbfe 100644 --- a/src/error/tls/mod.rs +++ b/src/error/tls/mod.rs @@ -1,9 +1,9 @@ -#![cfg(any(feature = "native-tls", feature = "rustls-tls"))] +#![cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub mod native_tls_error; pub mod rustls_error; -#[cfg(feature = "native-tls")] +#[cfg(feature = "native-tls-tls")] pub use native_tls_error::TlsError; #[cfg(feature = "rustls")] diff --git a/src/error/tls/native_tls_error.rs b/src/error/tls/native_tls_error.rs index 8ca8b6cb..6b324011 100644 --- a/src/error/tls/native_tls_error.rs +++ b/src/error/tls/native_tls_error.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "native-tls")] +#![cfg(feature = "native-tls-tls")] use std::fmt::Display; diff --git a/src/io/mod.rs b/src/io/mod.rs index f5ffc5c6..273acf1e 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -191,7 +191,7 @@ impl Endpoint { matches!(self, Endpoint::Secure(_)) } - #[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))] + #[cfg(all(not(feature = "native-tls-tls"), not(feature = "rustls")))] pub async fn make_secure( &mut self, _domain: String, @@ -499,7 +499,7 @@ mod test { super::Endpoint::Plain(Some(stream)) => stream, #[cfg(feature = "rustls-tls")] super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0, - #[cfg(feature = "native-tls")] + #[cfg(feature = "native-tls-tls")] super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(), _ => unreachable!(), }; diff --git a/src/io/tls/mod.rs b/src/io/tls/mod.rs index 92f5e7c2..623e28ae 100644 --- a/src/io/tls/mod.rs +++ b/src/io/tls/mod.rs @@ -1,4 +1,4 @@ -#![cfg(any(feature = "native-tls", feature = "rustls"))] +#![cfg(any(feature = "native-tls-tls", feature = "rustls"))] mod native_tls_io; mod rustls_io; diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index b4d5f751..adc8d980 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "native-tls")] +#![cfg(feature = "native-tls-tls")] use native_tls::{Certificate, TlsConnector}; diff --git a/src/lib.rs b/src/lib.rs index 3e2c7983..01e01457 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,23 +36,18 @@ //! mysql_async = { version = "*", default-features = false, features = ["minimal"]} //! ``` //! -//! **Note:* it is possible to use another `flate2` backend by directly choosing it: +//! * `minimal-rust` - same as `minimal` but rust-based flate2 backend is chosen. Enables: //! -//! ```toml -//! [dependencies] -//! mysql_async = { version = "*", default-features = false } -//! flate2 = { version = "*", default-features = false, features = ["rust_backend"] } -//! ``` +//! - `flate2/rust_backend` //! -//! * `default` – enables the following set of crate's and dependencies' features: +//! * `default` – enables the following set of features: //! +//! - `minimal` //! - `native-tls-tls` -//! - `flate2/zlib" -//! - `mysql_common/bigdecimal03` -//! - `mysql_common/rust_decimal` -//! - `mysql_common/time03` -//! - `mysql_common/uuid` -//! - `mysql_common/frunk` +//! - `bigdecimal` +//! - `rust_decimal` +//! - `time` +//! - `frunk` //! - `binlog` //! //! * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. @@ -94,12 +89,19 @@ //! mysql_async = { version = "*", features = ["tracing"] } //! ``` //! -//! * `derive` – enables `mysql_commom/derive` feature -//! //! * `binlog` - enables binlog-related functionality. Enables: //! //! - `mysql_common/binlog" //! +//! ### Proxied features +//! +//! * `derive` – enables `mysql_common/derive` feature +//! * `chrono` = enables `mysql_common/chrono` feature +//! * `time` = enables `mysql_common/time` feature +//! * `bigdecimal` = enables `mysql_common/bigdecimal` feature +//! * `rust_decimal` = enables `mysql_common/rust_decimal` feature +//! * `frunk` = enables `mysql_common/frunk` feature +//! //! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features //! //! # TLS/SSL Support @@ -464,10 +466,13 @@ pub use self::conn::Conn; #[doc(inline)] pub use self::conn::pool::Pool; +#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] +#[doc(inline)] +pub use self::error::tls::TlsError; + #[doc(inline)] pub use self::error::{ - tls::TlsError, DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, - UrlError, + DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError, }; #[doc(inline)] @@ -477,7 +482,7 @@ pub use self::query::QueryWithParams; pub use self::queryable::transaction::IsolationLevel; #[doc(inline)] -#[cfg(any(feature = "rustls", feature = "native-tls"))] +#[cfg(any(feature = "rustls", feature = "native-tls-tls"))] pub use self::opts::ClientIdentity; #[doc(inline)] diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 96cea077..84e1c81b 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -9,7 +9,7 @@ mod native_tls_opts; mod rustls_opts; -#[cfg(feature = "native-tls")] +#[cfg(feature = "native-tls-tls")] pub use native_tls_opts::ClientIdentity; #[cfg(feature = "rustls-tls")] @@ -117,13 +117,11 @@ impl HostPortOrUrl { /// Represents data that is either on-disk or in the buffer. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] pub enum PathOrBuf<'a> { Path(Cow<'a, Path>), Buf(Cow<'a, [u8]>), } -#[cfg(any(feature = "native-tls", feature = "rustls-tls"))] impl<'a> PathOrBuf<'a> { /// Will either read data from disk or return the buffered data. pub async fn read(&self) -> io::Result> { @@ -190,7 +188,7 @@ impl<'a> From<&'a [u8]> for PathOrBuf<'a> { /// ``` #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct SslOpts { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] client_identity: Option, root_certs: Vec>, skip_domain_validation: bool, @@ -199,7 +197,7 @@ pub struct SslOpts { } impl SslOpts { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub fn with_client_identity(mut self, identity: Option) -> Self { self.client_identity = identity; self @@ -241,7 +239,7 @@ impl SslOpts { self } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub fn client_identity(&self) -> Option<&ClientIdentity> { self.client_identity.as_ref() } diff --git a/src/opts/native_tls_opts.rs b/src/opts/native_tls_opts.rs index 7ab5be19..a02c6041 100644 --- a/src/opts/native_tls_opts.rs +++ b/src/opts/native_tls_opts.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "native-tls")] +#![cfg(feature = "native-tls-tls")] use std::borrow::Cow; From d8fe6600ed586889036f7938d3e0cf2a12d06f9f Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 19 Mar 2024 22:36:49 +0300 Subject: [PATCH 068/130] ci: fix warnings, check minimal featureset --- azure-pipelines.yml | 3 +++ src/conn/pool/mod.rs | 1 - src/opts/mod.rs | 1 - src/queryable/mod.rs | 1 - 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 387b5c4b..04e90187 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -45,6 +45,9 @@ jobs: SSL=false COMPRESS=true cargo test SSL=true COMPRESS=true cargo test SSL=true COMPRESS=false cargo test --no-default-features --features default-rustls + + SSL=true COMPRESS=false cargo check --no-default-features --features minimal + SSL=true COMPRESS=false cargo check --no-default-features --features minimal-rust env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:root@127.0.0.1:3306/mysql diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 465ea8e4..433395eb 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -14,7 +14,6 @@ use std::{ borrow::Borrow, cmp::Reverse, collections::VecDeque, - convert::TryFrom, hash::{Hash, Hasher}, str::FromStr, sync::{atomic, Arc, Mutex}, diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 84e1c81b..1f62e136 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -21,7 +21,6 @@ use url::{Host, Url}; use std::{ borrow::Cow, - convert::TryFrom, fmt, io, net::{Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 40edb0f6..ebef3668 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -587,7 +587,6 @@ impl Queryable for Transaction<'_> { #[cfg(test)] mod tests { - use super::Queryable; use crate::{error::Result, prelude::*, test_misc::get_opts, Conn}; #[tokio::test] From 46fdb25bfd1dbd68400b0c523d2b0a76b7d4f837 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 20 Mar 2024 09:47:33 +0300 Subject: [PATCH 069/130] Bump micro version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d0fb5571..ee0cc932 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.34.0" +version = "0.34.1" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From 56b319ca85f7cd6c7edc1fe0afaf47573c0344aa Mon Sep 17 00:00:00 2001 From: Roshan Jobanputra Date: Mon, 25 Mar 2024 13:48:01 -0400 Subject: [PATCH 070/130] Add support to specify pre-resolved IP addresses and avoid additional DNS lookup --- src/io/mod.rs | 18 +++++++-- src/opts/mod.rs | 102 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 108 insertions(+), 12 deletions(-) diff --git a/src/io/mod.rs b/src/io/mod.rs index 273acf1e..9a879ec6 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -31,6 +31,7 @@ use std::{ ErrorKind::{BrokenPipe, NotConnected, Other}, }, mem::replace, + net::SocketAddr, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, @@ -357,9 +358,20 @@ impl Stream { keepalive: Option, ) -> io::Result { let tcp_stream = match addr { - HostPortOrUrl::HostPort(host, port) => { - TcpStream::connect((host.as_str(), *port)).await? - } + HostPortOrUrl::HostPort { + host, + port, + resolved_ips, + } => match resolved_ips { + Some(ips) => { + let addrs = ips + .iter() + .map(|ip| SocketAddr::new(*ip, *port)) + .collect::>(); + TcpStream::connect(&*addrs).await? + } + None => TcpStream::connect((host.as_str(), *port)).await?, + }, HostPortOrUrl::Url(url) => { let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?; TcpStream::connect(&*addrs).await? diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 1f62e136..e9044450 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -22,7 +22,7 @@ use url::{Host, Url}; use std::{ borrow::Cow, fmt, io, - net::{Ipv4Addr, Ipv6Addr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::{Path, PathBuf}, str::FromStr, sync::Arc, @@ -66,37 +66,61 @@ pub const DEFAULT_TTL_CHECK_INTERVAL: Duration = Duration::from_secs(30); /// into socket addresses using to_socket_addrs. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) enum HostPortOrUrl { - HostPort(String, u16), + HostPort { + host: String, + port: u16, + /// The resolved IP addresses to use for the TCP connection. If empty, + /// DNS resolution of `host` will be performed. + resolved_ips: Option>, + }, Url(Url), } impl Default for HostPortOrUrl { fn default() -> Self { - HostPortOrUrl::HostPort("127.0.0.1".to_string(), DEFAULT_PORT) + HostPortOrUrl::HostPort { + host: "127.0.0.1".to_string(), + port: DEFAULT_PORT, + resolved_ips: None, + } } } impl HostPortOrUrl { pub fn get_ip_or_hostname(&self) -> &str { match self { - Self::HostPort(host, _) => host, + Self::HostPort { host, .. } => host, Self::Url(url) => url.host_str().unwrap_or("127.0.0.1"), } } pub fn get_tcp_port(&self) -> u16 { match self { - Self::HostPort(_, port) => *port, + Self::HostPort { port, .. } => *port, Self::Url(url) => url.port().unwrap_or(DEFAULT_PORT), } } + pub fn get_resolved_ips(&self) -> &Option> { + match self { + Self::HostPort { resolved_ips, .. } => resolved_ips, + Self::Url(_) => &None, + } + } + pub fn is_loopback(&self) -> bool { match self { - Self::HostPort(host, _) => { + Self::HostPort { + host, resolved_ips, .. + } => { let v4addr: Option = FromStr::from_str(host).ok(); let v6addr: Option = FromStr::from_str(host).ok(); - if let Some(addr) = v4addr { + if resolved_ips + .as_ref() + .is_some_and(|s| s.iter().any(|ip| ip.is_loopback())) + { + true + } else if let Some(addr) = v4addr { addr.is_loopback() } else if let Some(addr) = v6addr { addr.is_loopback() @@ -644,6 +668,11 @@ impl Opts { self.inner.address.get_tcp_port() } + /// The resolved IPs for the mysql server, if provided. + pub fn resolved_ips(&self) -> &Option> { + self.inner.address.get_resolved_ips() + } + /// User (defaults to `None`). /// /// # Connection URL @@ -1139,6 +1168,7 @@ pub struct OptsBuilder { opts: MysqlOpts, ip_or_hostname: String, tcp_port: u16, + resolved_ips: Option>, } impl Default for OptsBuilder { @@ -1148,6 +1178,7 @@ impl Default for OptsBuilder { opts: MysqlOpts::default(), ip_or_hostname: address.get_ip_or_hostname().into(), tcp_port: address.get_tcp_port(), + resolved_ips: None, } } } @@ -1168,6 +1199,7 @@ impl OptsBuilder { OptsBuilder { tcp_port: opts.inner.address.get_tcp_port(), ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(), + resolved_ips: opts.inner.address.get_resolved_ips().clone(), opts: opts.inner.mysql_opts.clone(), } } @@ -1184,6 +1216,14 @@ impl OptsBuilder { self } + /// Defines already-resolved IPs to use for the connection. When provided + /// the connection will not perform DNS resolution and the hostname will be + /// used only for TLS identity verification purposes. + pub fn resolved_ips>>(mut self, ips: Option) -> Self { + self.resolved_ips = ips.map(Into::into); + self + } + /// Defines user name. See [`Opts::user`]. pub fn user>(mut self, user: Option) -> Self { self.opts.user = user.map(Into::into); @@ -1349,7 +1389,11 @@ impl OptsBuilder { impl From for Opts { fn from(builder: OptsBuilder) -> Opts { - let address = HostPortOrUrl::HostPort(builder.ip_or_hostname, builder.tcp_port); + let address = HostPortOrUrl::HostPort { + host: builder.ip_or_hostname, + port: builder.tcp_port, + resolved_ips: builder.resolved_ips, + }; let inner_opts = InnerOpts { mysql_opts: builder.opts, address, @@ -1838,7 +1882,7 @@ mod test { use super::{HostPortOrUrl, MysqlOpts, Opts, Url}; use crate::{error::UrlError::InvalidParamValue, SslOpts}; - use std::str::FromStr; + use std::{net::IpAddr, net::Ipv4Addr, net::Ipv6Addr, str::FromStr}; #[test] fn test_builder_eq_url() { @@ -2019,4 +2063,44 @@ mod test { let url_opts = super::Opts::from_str(url).unwrap(); assert_eq!(url_opts.db_name(), builder_opts.db_name()); } + + #[test] + fn test_builder_update_port_host_resolved_ips() { + let builder = super::OptsBuilder::default() + .ip_or_hostname("foo") + .tcp_port(33306); + + let resolved = vec![ + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 7)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)), + ]; + let builder2 = builder + .clone() + .tcp_port(55223) + .resolved_ips(Some(resolved.clone())); + + let builder_opts = Opts::from(builder); + assert_eq!(builder_opts.ip_or_hostname(), "foo"); + assert_eq!(builder_opts.tcp_port(), 33306); + assert_eq!( + builder_opts.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 33306, + resolved_ips: None + } + ); + + let builder_opts2 = Opts::from(builder2); + assert_eq!(builder_opts2.ip_or_hostname(), "foo"); + assert_eq!(builder_opts2.tcp_port(), 55223); + assert_eq!( + builder_opts2.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 55223, + resolved_ips: Some(resolved), + } + ); + } } From 55e7b70b21997c834faf5ad515d645dbeb4857d6 Mon Sep 17 00:00:00 2001 From: csh <458761603@qq.com> Date: Tue, 9 Apr 2024 18:22:01 +0800 Subject: [PATCH 071/130] Compatibility with non-Unix and non-Windows platforms. Signed-off-by: csh <458761603@qq.com> --- src/conn/mod.rs | 2 +- src/io/mod.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 7457f3fd..1e8027cf 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -913,7 +913,7 @@ impl Conn { { Stream::connect_socket(_path.to_owned()).await? } - #[cfg(target_os = "windows")] + #[cfg(not(unix))] return Err(crate::DriverError::NamedPipesDisabled.into()); } else { let keepalive = opts diff --git a/src/io/mod.rs b/src/io/mod.rs index 9a879ec6..da2be2fb 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -12,6 +12,7 @@ use bytes::BytesMut; use futures_core::{ready, stream}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; use pin_project::pin_project; +#[cfg(any(unix, windows))] use socket2::{Socket as Socket2Socket, TcpKeepalive}; #[cfg(unix)] use tokio::io::AsyncWriteExt; @@ -378,6 +379,7 @@ impl Stream { } }; + #[cfg(any(unix, windows))] if let Some(duration) = keepalive { #[cfg(unix)] let socket = { From d34b756baff0d39a3203d27a0123a64534a3cea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 17 Apr 2024 18:48:00 +0000 Subject: [PATCH 072/130] Replace lazy_static with std::cell:OnceLock --- Cargo.toml | 1 - src/lib.rs | 26 ++++++++------------------ 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ee0cc932..ac414a18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" -lazy_static = "1" lru = "0.12.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } mysql_common = { version = "0.32", default-features = false } diff --git a/src/lib.rs b/src/lib.rs index 01e01457..744ac8af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -608,9 +608,8 @@ pub mod prelude { #[doc(hidden)] pub mod test_misc { - use lazy_static::lazy_static; - use std::env; + use std::sync::OnceLock; use crate::opts::{Opts, OptsBuilder, SslOpts}; @@ -621,26 +620,17 @@ pub mod test_misc { _dummy(err); } - lazy_static! { - pub static ref DATABASE_URL: String = { + pub fn get_opts() -> OptsBuilder { + static DATABASE_OPTS: OnceLock = OnceLock::new(); + let database_opts = DATABASE_OPTS.get_or_init(|| { if let Ok(url) = env::var("DATABASE_URL") { - let opts = Opts::from_url(&url).expect("DATABASE_URL invalid"); - if opts - .db_name() - .expect("a database name is required") - .is_empty() - { - panic!("database name is empty"); - } - url + Opts::from_url(&url).expect("DATABASE_URL invalid") } else { - "mysql://root:password@localhost:3307/mysql".into() + Opts::from_url("mysql://root:password@localhost:3307/mysql").unwrap() } - }; - } + }); - pub fn get_opts() -> OptsBuilder { - let mut builder = OptsBuilder::from_opts(Opts::from_url(&DATABASE_URL).unwrap()); + let mut builder = OptsBuilder::from_opts(database_opts.clone()); if test_ssl() { let ssl_opts = SslOpts::default() .with_danger_skip_domain_validation(true) From 11a86f17f3fd54338b0687699a7acaed537b9129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 17 Apr 2024 19:15:16 +0000 Subject: [PATCH 073/130] Replace once_cell with std::sync::OnceLock --- Cargo.toml | 1 - src/conn/mod.rs | 12 ++++++------ src/io/mod.rs | 4 ++-- src/lib.rs | 7 +++++-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ac414a18..9d4efd26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ keyed_priority_queue = "0.4" lru = "0.12.0" mio = { version = "0.8.0", features = ["os-poll", "net"] } mysql_common = { version = "0.32", default-features = false } -once_cell = "1.7.2" pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 1e8027cf..df6bc020 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -581,7 +581,7 @@ impl Conn { ); // Serialize here to satisfy borrow checker. - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); handshake_response.serialize(buf.as_mut()); self.write_packet(buf).await?; @@ -633,7 +633,7 @@ impl Conn { if let Some(plugin_data) = plugin_data { self.write_struct(&plugin_data.into_owned()).await?; } else { - self.write_packet(crate::BUFFER_POOL.get()).await?; + self.write_packet(crate::buffer_pool().get()).await?; } self.continue_auth().await?; @@ -701,7 +701,7 @@ impl Conn { } Some(0x04) => { let pass = self.inner.opts.pass().unwrap_or_default(); - let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes()); + let mut pass = crate::buffer_pool().get_with(pass.as_bytes()); pass.as_mut().push(0); if self.is_secure() || self.is_socket() { @@ -838,13 +838,13 @@ impl Conn { /// Writes bytes to a server. pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> { - let buf = crate::BUFFER_POOL.get_with(bytes); + let buf = crate::buffer_pool().get_with(bytes); self.write_packet(buf).await } /// Sends a serializable structure to a server. pub(crate) async fn write_struct(&mut self, x: &T) -> Result<()> { - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); x.serialize(buf.as_mut()); self.write_packet(buf).await } @@ -870,7 +870,7 @@ impl Conn { T: AsRef<[u8]>, { let cmd_data = cmd_data.as_ref(); - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); let body = buf.as_mut(); body.push(cmd as u8); body.extend_from_slice(cmd_data); diff --git a/src/io/mod.rs b/src/io/mod.rs index da2be2fb..4bb27e9e 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -75,7 +75,7 @@ impl Default for PacketCodec { fn default() -> Self { Self { inner: Default::default(), - decode_buf: crate::BUFFER_POOL.get(), + decode_buf: crate::buffer_pool().get(), } } } @@ -100,7 +100,7 @@ impl Decoder for PacketCodec { fn decode(&mut self, src: &mut BytesMut) -> std::result::Result, IoError> { if self.inner.decode(src, self.decode_buf.as_mut())? { - let new_buf = crate::BUFFER_POOL.get(); + let new_buf = crate::buffer_pool().get(); Ok(Some(replace(&mut self.decode_buf, new_buf))) } else { Ok(None) diff --git a/src/lib.rs b/src/lib.rs index 744ac8af..679d08e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -453,8 +453,11 @@ mod queryable; type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; -static BUFFER_POOL: once_cell::sync::Lazy> = - once_cell::sync::Lazy::new(Default::default); +fn buffer_pool() -> &'static Arc { + static BUFFER_POOL: std::sync::OnceLock> = + std::sync::OnceLock::new(); + BUFFER_POOL.get_or_init(Default::default) +} #[cfg(feature = "binlog")] #[doc(inline)] From d5969785c614012cfeddaa240c08f97862ab0bae Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 23 May 2024 13:08:58 +0300 Subject: [PATCH 074/130] conn: add test for initial error packet handling Signed-off-by: Petros Angelatos --- src/conn/mod.rs | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index df6bc020..f730e63b 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1270,10 +1270,11 @@ mod test { use futures_util::stream::{self, StreamExt}; use mysql_common::constants::MAX_PAYLOAD_LEN; use rand::Fill; + use tokio::{io::AsyncWriteExt, net::TcpListener}; use crate::{ from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error, - OptsBuilder, Pool, Value, WhiteListFsHandler, + OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler, }; #[tokio::test] @@ -2189,6 +2190,45 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_handle_initial_error_packet() { + let header = [ + 0x68, 0x00, 0x00, // packet_length + 0x00, // sequence + 0xff, // error_header + 0x69, 0x04, // error_code + ]; + let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'"; + + // Create a fake MySQL server that immediately replies with an error packet. + let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap(); + + let listen_addr = listener.local_addr().unwrap(); + + tokio::task::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + stream.write_all(&header).await.unwrap(); + stream.write_all(error_message.as_bytes()).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let opts = OptsBuilder::default() + .ip_or_hostname(listen_addr.ip().to_string()) + .tcp_port(listen_addr.port()); + let server_err = match Conn::new(opts).await { + Err(Error::Server(server_err)) => server_err, + other => panic!("expected server error but got: {:?}", other), + }; + assert_eq!( + server_err, + ServerError { + code: 1129, + state: "HY000".to_owned(), + message: error_message.to_owned(), + } + ); + } + #[cfg(feature = "nightly")] mod bench { use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts}; From 4bf929a080d91d6b846d9b3abe0d44480d06d691 Mon Sep 17 00:00:00 2001 From: Petros Angelatos Date: Thu, 23 May 2024 13:09:30 +0300 Subject: [PATCH 075/130] conn: handle initial error packet correctly If we haven't completed the hashshake the server will not be aware of our capabilities and so its will packets behave as if we have none. This is necessary to correcly parse an initial error packet which never contains an SQL State field even if the client capabilities will eventually contain CLIENT_PROTOCOL_41. https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html Signed-off-by: Petros Angelatos --- src/conn/mod.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index f730e63b..05240404 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -102,6 +102,7 @@ struct ConnInner { status: StatusFlags, last_ok_packet: Option>, last_err_packet: Option>, + handshake_complete: bool, pool: Option, pending_result: std::result::Result, ServerError>, tx_status: TxStatus, @@ -147,6 +148,7 @@ impl ConnInner { status: StatusFlags::empty(), last_ok_packet: None, last_err_packet: None, + handshake_complete: false, stream: None, is_mariadb: false, version: (0, 0, 0), @@ -585,6 +587,7 @@ impl Conn { handshake_response.serialize(buf.as_mut()); self.write_packet(buf).await?; + self.inner.handshake_complete = true; Ok(()) } @@ -789,7 +792,19 @@ impl Conn { if let Ok(ok_packet) = ok_packet { self.handle_ok(ok_packet.into_owned()); } else { - let err_packet = ParseBuf(packet).parse::(self.capabilities()); + // If we haven't completed the handshake the server will not be aware of our + // capabilities and so it will behave as if we have none. In particular, the error + // packet will not contain a SQL State field even if our capabilities do contain the + // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet + // with no capability assumptions if we have not completed the handshake. + // + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html + let capabilities = if self.inner.handshake_complete { + self.capabilities() + } else { + CapabilityFlags::empty() + }; + let err_packet = ParseBuf(packet).parse::(capabilities); if let Ok(err_packet) = err_packet { self.handle_err(err_packet)?; return Ok(true); From a16384f2c692643b1da8d43943d6dc605497e340 Mon Sep 17 00:00:00 2001 From: Seth Westphal Date: Tue, 23 Jul 2024 09:05:34 -0500 Subject: [PATCH 076/130] Update dependencies. --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9d4efd26..485f303d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.12.0" -mio = { version = "0.8.0", features = ["os-poll", "net"] } +mio = { version = "1", features = ["os-poll", "net"] } mysql_common = { version = "0.32", default-features = false } pem = "3.0" percent-encoding = "2.1.0" @@ -40,7 +40,7 @@ twox-hash = "1" url = "2.1" [dependencies.tokio-rustls] -version = "0.25" +version = "0.26" optional = true [dependencies.tokio-native-tls] @@ -52,7 +52,7 @@ version = "0.2" optional = true [dependencies.rustls] -version = "0.22.2" +version = "0.23" features = [] optional = true From 601bb9ea7150865e506f6d022c32084a72fb6664 Mon Sep 17 00:00:00 2001 From: jaumelopez Date: Sat, 3 Aug 2024 16:00:19 +0200 Subject: [PATCH 077/130] Re-export ColumnIndex --- src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 679d08e2..cbe5be55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -567,6 +567,8 @@ pub mod prelude { #[doc(inline)] pub use mysql_common::prelude::FromRow; #[doc(inline)] + pub use mysql_common::prelude::ColumnIndex; + #[doc(inline)] pub use mysql_common::prelude::{FromValue, ToValue}; /// Everything that is a statement. From 944b530a537a9f3787749a040143776e1576dfb2 Mon Sep 17 00:00:00 2001 From: jaumelopez Date: Sat, 3 Aug 2024 16:06:56 +0200 Subject: [PATCH 078/130] Ran rustfmt --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cbe5be55..0d637963 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -565,10 +565,10 @@ pub mod prelude { #[doc(inline)] pub use crate::queryable::Queryable; #[doc(inline)] - pub use mysql_common::prelude::FromRow; - #[doc(inline)] pub use mysql_common::prelude::ColumnIndex; #[doc(inline)] + pub use mysql_common::prelude::FromRow; + #[doc(inline)] pub use mysql_common::prelude::{FromValue, ToValue}; /// Everything that is a statement. From 5bd5c1e4ad599d3986fa8d3c1bd775427a135448 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 13 Aug 2024 10:33:21 +0300 Subject: [PATCH 079/130] Remove explicit mio dependency --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 485f303d..e9090e50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.12.0" -mio = { version = "1", features = ["os-poll", "net"] } mysql_common = { version = "0.32", default-features = false } pem = "3.0" percent-encoding = "2.1.0" From 091286cd92b2541f1522e0eb0ba478e98027fdf8 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 13 Aug 2024 10:34:27 +0300 Subject: [PATCH 080/130] Bump micro version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index e9090e50..f169707e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.34.1" +version = "0.34.2" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From 02fdc1a78b1a6fd356b05dcc7d2a0af5c1c352df Mon Sep 17 00:00:00 2001 From: Emilio Wuerges Date: Tue, 5 Nov 2024 15:04:38 -0300 Subject: [PATCH 081/130] Lower event log level of non-fatal server errors --- src/tracing_utils.rs | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index b32170c0..467389a2 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -44,7 +44,25 @@ macro_rules! instrument_result { ($fut:expr, $span:expr) => {{ let fut = async { $fut.await.or_else(|e| { - tracing::error!(error = %e); + match &e { + $crate::error::Error::Server(server_error) => { + match server_error.code { + 1062 => { + tracing::warn!(error = %e, "duplicated entry for key") + } + 1451 => { + tracing::warn!(error = %e, "foreign key violation") + } + 1644 => { + tracing::warn!(error = %e, "user defined exception condition"); + } + _ => tracing::error!(error = %e), + } + }, + e => { + tracing::error!(error = %e); + } + } Err(e) }) }; From 8095ad13016377bd893f5b9f67cde287fc5f0939 Mon Sep 17 00:00:00 2001 From: Emilio Wuerges Date: Tue, 5 Nov 2024 15:24:02 -0300 Subject: [PATCH 082/130] moved instrumented text into comments --- src/tracing_utils.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index 467389a2..190e5d33 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -47,14 +47,17 @@ macro_rules! instrument_result { match &e { $crate::error::Error::Server(server_error) => { match server_error.code { + // Duplicated entry for key 1062 => { - tracing::warn!(error = %e, "duplicated entry for key") + tracing::warn!(error = %e) } + // Foreign key violation 1451 => { - tracing::warn!(error = %e, "foreign key violation") + tracing::warn!(error = %e) } + // User defined exception condition 1644 => { - tracing::warn!(error = %e, "user defined exception condition"); + tracing::warn!(error = %e); } _ => tracing::error!(error = %e), } From edd8bed7427b80f2d4c2addc69023f4beb2690a9 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 8 Nov 2024 11:34:57 +0300 Subject: [PATCH 083/130] Bump dependencies --- Cargo.toml | 6 +++--- src/conn/stmt_cache.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f169707e..b612531d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.12.0" -mysql_common = { version = "0.32", default-features = false } +mysql_common = { version = "0.33", default-features = false } pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" @@ -29,13 +29,13 @@ rand = "0.8.5" serde = "1" serde_json = "1" socket2 = "0.5.2" -thiserror = "1.0.4" +thiserror = "2" tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } tokio-util = { version = "0.7.2", features = ["codec", "io"] } tracing = { version = "0.1.37", default-features = false, features = [ "attributes", ], optional = true } -twox-hash = "1" +twox-hash = "2" url = "2.1" [dependencies.tokio-rustls] diff --git a/src/conn/stmt_cache.rs b/src/conn/stmt_cache.rs index f72240f9..84a99c20 100644 --- a/src/conn/stmt_cache.rs +++ b/src/conn/stmt_cache.rs @@ -7,7 +7,7 @@ // modified, or distributed except according to those terms. use lru::LruCache; -use twox_hash::XxHash; +use twox_hash::XxHash64; use std::{ borrow::Borrow, @@ -42,7 +42,7 @@ pub struct Entry { pub struct StmtCache { cap: usize, cache: LruCache, - query_map: HashMap>, + query_map: HashMap>, } impl StmtCache { From 9bb5dae8703f3c1ab9a7d3cdfcf6cd1a94b48c02 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 8 Nov 2024 11:35:13 +0300 Subject: [PATCH 084/130] Fix tests for mysql v9 --- src/conn/mod.rs | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 05240404..dbf7e5df 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1416,16 +1416,18 @@ mod test { .filter(|variant| plugins.iter().any(|p| p == variant.0)); for (plug, val, pass) in variants { + dbg!((plug, val, pass, conn.inner.version)); + + if plug == "mysql_native_password" && conn.inner.version >= (9, 0, 0) { + continue; + } + let _ = conn.query_drop("DROP USER 'test_user'@'%'").await; let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug); conn.query_drop(query).await.unwrap(); - if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) { - conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass)) - .await - .unwrap(); - } else { + if conn.inner.version <= (8, 0, 11) { conn.query_drop(format!("SET old_passwords = {}", val)) .await .unwrap(); @@ -1435,6 +1437,10 @@ mod test { )) .await .unwrap(); + } else { + conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass)) + .await + .unwrap(); }; let opts = get_opts() @@ -1564,6 +1570,10 @@ mod test { }; for (i, plugin) in plugins.iter().enumerate() { + if *plugin == "mysql_native_password" && conn.server_version() >= (9, 0, 0) { + continue; + } + let mut rng = rand::thread_rng(); let mut pass = [0u8; 10]; pass.try_fill(&mut rng).unwrap(); From 1c47a9fcebaba407bf1e73adae7657767752a69b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 24 Apr 2024 02:09:00 +0000 Subject: [PATCH 085/130] Update rustls rustls replaced ring with aws-lc-rs as default crypto backend, expose features to select between the two, along with a feature on whether to enable tls 1.2 --- Cargo.toml | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b612531d..2486c6b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,7 @@ url = "2.1" [dependencies.tokio-rustls] version = "0.26" +default-features = false optional = true [dependencies.tokio-native-tls] @@ -52,7 +53,7 @@ optional = true [dependencies.rustls] version = "0.23" -features = [] +default-features = false optional = true [dependencies.rustls-pemfile] @@ -87,6 +88,16 @@ default = [ ] default-rustls = [ + "default-rustls-no-provider", + "aws-lc-rs", +] + +default-rustls-ring = [ + "default-rustls-no-provider", + "ring", +] + +default-rustls-no-provider = [ "flate2/rust_backend", "bigdecimal", "rust_decimal", @@ -95,6 +106,7 @@ default-rustls = [ "derive", "rustls-tls", "binlog", + "tls12", ] # minimal feature set with system flate2 impl @@ -114,6 +126,10 @@ rustls-tls = [ "rustls-pemfile", ] +aws-lc-rs = ["rustls/aws_lc_rs", "tokio-rustls/aws_lc_rs"] +ring = ["rustls/ring", "tokio-rustls/ring"] +tls12 = ["rustls/tls12", "tokio-rustls/tls12"] + binlog = ["mysql_common/binlog"] # mysql_common features From 1c03bf6e25a258813ce210109c849ea0af723142 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Mon, 11 Nov 2024 14:27:09 +0000 Subject: [PATCH 086/130] Add metrics for pool internals --- Cargo.toml | 1 + src/conn/mod.rs | 2 + src/conn/pool/metrics.rs | 145 ++++++++++++++++++++++++++++++++++++++ src/conn/pool/mod.rs | 96 ++++++++++++++++++++++--- src/conn/pool/recycler.rs | 60 ++++++++++++++++ 5 files changed, 295 insertions(+), 9 deletions(-) create mode 100644 src/conn/pool/metrics.rs diff --git a/Cargo.toml b/Cargo.toml index b612531d..6ada4949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ tracing = { version = "0.1.37", default-features = false, features = [ ], optional = true } twox-hash = "2" url = "2.1" +hdrhistogram = { version = "7.5", optional = true } [dependencies.tokio-rustls] version = "0.26" diff --git a/src/conn/mod.rs b/src/conn/mod.rs index dbf7e5df..be778752 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -116,6 +116,7 @@ struct ConnInner { auth_plugin: AuthPlugin<'static>, auth_switched: bool, server_key: Option>, + active_since: Instant, /// Connection is already disconnected. pub(crate) disconnected: bool, /// One-time connection-level infile handler. @@ -169,6 +170,7 @@ impl ConnInner { server_key: None, infile_handler: None, reset_upon_returning_to_a_pool: false, + active_since: Instant::now(), } } diff --git a/src/conn/pool/metrics.rs b/src/conn/pool/metrics.rs new file mode 100644 index 00000000..c5c6a08a --- /dev/null +++ b/src/conn/pool/metrics.rs @@ -0,0 +1,145 @@ +use std::sync::atomic::AtomicUsize; + +use serde::Serialize; + +#[derive(Default, Debug, Serialize)] +#[non_exhaustive] +pub struct Metrics { + /// Guage of active connections to the database server, this includes both connections that have belong + /// to the pool, and connections currently owned by the application. + pub connection_count: AtomicUsize, + /// Guage of active connections that currently belong to the pool. + pub connections_in_pool: AtomicUsize, + /// Guage of GetConn requests that are currently active. + pub active_wait_requests: AtomicUsize, + /// Counter of connections that failed to be created. + pub create_failed: AtomicUsize, + /// Counter of connections discarded due to pool constraints. + pub discarded_superfluous_connection: AtomicUsize, + /// Counter of connections discarded due to being closed upon return to the pool. + pub discarded_unestablished_connection: AtomicUsize, + /// Counter of connections that have been returned to the pool dirty that needed to be cleaned + /// (ie. open transactions, pending queries, etc). + pub dirty_connection_return: AtomicUsize, + /// Counter of connections that have been discarded as they were expired by the pool constraints. + pub discarded_expired_connection: AtomicUsize, + /// Counter of connections that have been reset. + pub resetting_connection: AtomicUsize, + /// Counter of connections that have been discarded as they returned an error during cleanup. + pub discarded_error_during_cleanup: AtomicUsize, + /// Counter of connections that have been returned to the pool. + pub connection_returned_to_pool: AtomicUsize, + /// Histogram of times connections have spent outside of the pool. + #[cfg(feature = "hdrhistogram")] + pub connection_active_duration: MetricsHistogram, + /// Histogram of times connections have spent inside of the pool. + #[cfg(feature = "hdrhistogram")] + pub connection_idle_duration: MetricsHistogram, + /// Histogram of times connections have spent being checked for health. + #[cfg(feature = "hdrhistogram")] + pub check_duration: MetricsHistogram, + /// Histogram of time spent waiting to connect to the server. + #[cfg(feature = "hdrhistogram")] + pub connect_duration: MetricsHistogram, +} + +impl Metrics { + /// Resets all histograms to allow for histograms to be bound to a period of time (ie. between metric scrapes) + #[cfg(feature = "hdrhistogram")] + pub fn clear_histograms(&self) { + self.connection_active_duration.reset(); + self.connection_idle_duration.reset(); + self.check_duration.reset(); + self.connect_duration.reset(); + } +} + +#[cfg(feature = "hdrhistogram")] +#[derive(Debug)] +pub struct MetricsHistogram(std::sync::Mutex>); + +#[cfg(feature = "hdrhistogram")] +impl MetricsHistogram { + pub fn reset(&self) { + self.lock().unwrap().reset(); + } +} + +#[cfg(feature = "hdrhistogram")] +impl Default for MetricsHistogram { + fn default() -> Self { + let hdr = hdrhistogram::Histogram::new_with_bounds(1, 30 * 1_000_000, 2).unwrap(); + Self(std::sync::Mutex::new(hdr)) + } +} + +#[cfg(feature = "hdrhistogram")] +impl Serialize for MetricsHistogram { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let hdr = self.0.lock().unwrap(); + + /// A percentile of this histogram - for supporting serializers this + /// will ignore the key (such as `90%ile`) and instead add a + /// dimension to the metrics (such as `quantile=0.9`). + macro_rules! ile { + ($e:expr) => { + &MetricAlias(concat!("!|quantile=", $e), hdr.value_at_quantile($e)) + }; + } + + /// A 'qualified' metric name - for supporting serializers such as + /// serde_prometheus, this will prepend the metric name to this key, + /// outputting `response_time_count`, for example rather than just + /// `count`. + macro_rules! qual { + ($e:expr) => { + &MetricAlias("<|", $e) + }; + } + + use serde::ser::SerializeMap; + + let mut tup = serializer.serialize_map(Some(10))?; + tup.serialize_entry("samples", qual!(hdr.len()))?; + tup.serialize_entry("min", qual!(hdr.min()))?; + tup.serialize_entry("max", qual!(hdr.max()))?; + tup.serialize_entry("mean", qual!(hdr.mean()))?; + tup.serialize_entry("stdev", qual!(hdr.stdev()))?; + tup.serialize_entry("90%ile", ile!(0.9))?; + tup.serialize_entry("95%ile", ile!(0.95))?; + tup.serialize_entry("99%ile", ile!(0.99))?; + tup.serialize_entry("99.9%ile", ile!(0.999))?; + tup.serialize_entry("99.99%ile", ile!(0.9999))?; + tup.end() + } +} + +/// This is a mocked 'newtype' (eg. `A(u64)`) that instead allows us to +/// define our own type name that doesn't have to abide by Rust's constraints +/// on type names. This allows us to do some manipulation of our metrics, +/// allowing us to add dimensionality to our metrics via key=value pairs, or +/// key manipulation on serializers that support it. +#[cfg(feature = "hdrhistogram")] +struct MetricAlias(&'static str, T); + +#[cfg(feature = "hdrhistogram")] +impl Serialize for MetricAlias { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + serializer.serialize_newtype_struct(self.0, &self.1) + } +} + +#[cfg(feature = "hdrhistogram")] +impl std::ops::Deref for MetricsHistogram { + type Target = std::sync::Mutex>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 433395eb..b1f96d03 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -28,9 +28,12 @@ use crate::{ queryable::transaction::{Transaction, TxOpts}, }; +pub use metrics::Metrics; + mod recycler; // this is a really unfortunate name for a module pub mod futures; +mod metrics; mod ttl_check_inerval; /// Connection that is idling in the pool. @@ -107,7 +110,7 @@ struct Waitlist { } impl Waitlist { - fn push(&mut self, waker: Waker, queue_id: QueueId) { + fn push(&mut self, waker: Waker, queue_id: QueueId) -> bool { // The documentation of Future::poll says: // Note that on multiple calls to poll, only the Waker from // the Context passed to the most recent call should be @@ -120,7 +123,9 @@ impl Waitlist { // This means we have to remove first to have the most recent // waker in the queue. self.remove(queue_id); - self.queue.push(QueuedWaker { queue_id, waker }, queue_id); + self.queue + .push(QueuedWaker { queue_id, waker }, queue_id) + .is_none() } fn pop(&mut self) -> Option { @@ -130,8 +135,8 @@ impl Waitlist { } } - fn remove(&mut self, id: QueueId) { - self.queue.remove(&id); + fn remove(&mut self, id: QueueId) -> bool { + self.queue.remove(&id).is_some() } fn peek_id(&mut self) -> Option { @@ -181,6 +186,7 @@ impl Hash for QueuedWaker { /// Connection pool data. #[derive(Debug)] pub struct Inner { + metrics: Arc, close: atomic::AtomicBool, closed: atomic::AtomicBool, exchange: Mutex, @@ -220,6 +226,7 @@ impl Pool { inner: Arc::new(Inner { close: false.into(), closed: false.into(), + metrics: Arc::new(Metrics::default()), exchange: Mutex::new(Exchange { available: VecDeque::with_capacity(pool_opts.constraints().max()), waiting: Waitlist::default(), @@ -231,6 +238,11 @@ impl Pool { } } + /// Returns metrics for the connection pool. + pub fn metrics(&self) -> Arc { + self.inner.metrics.clone() + } + /// Creates a new pool of connections. pub fn from_url>(url: T) -> Result { let opts = Opts::from_str(url.as_ref())?; @@ -288,6 +300,10 @@ impl Pool { pub(super) fn cancel_connection(&self) { let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= 1; + self.inner + .metrics + .create_failed + .fetch_add(1, atomic::Ordering::Relaxed); // we just enabled the creation of a new connection! if let Some(w) = exchange.waiting.pop() { w.wake(); @@ -320,15 +336,44 @@ impl Pool { // If we are not, just queue if !highest { - exchange.waiting.push(cx.waker().clone(), queue_id); + if exchange.waiting.push(cx.waker().clone(), queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_add(1, atomic::Ordering::Relaxed); + } return Poll::Pending; } - while let Some(IdlingConn { mut conn, .. }) = exchange.available.pop_back() { + #[allow(unused_variables)] // `since` is only used when `hdrhistogram` is enabled + while let Some(IdlingConn { mut conn, since }) = exchange.available.pop_back() { + self.inner + .metrics + .connections_in_pool + .fetch_sub(1, atomic::Ordering::Relaxed); + if !conn.expired() { + #[cfg(feature = "hdrhistogram")] + self.inner + .metrics + .connection_idle_duration + .lock() + .unwrap() + .saturating_record(since.elapsed().as_micros() as u64); + #[cfg(feature = "hdrhistogram")] + let metrics = self.metrics(); + conn.inner.active_since = Instant::now(); return Poll::Ready(Ok(GetConnInner::Checking( async move { conn.stream_mut()?.check().await?; + #[cfg(feature = "hdrhistogram")] + metrics + .check_duration + .lock() + .unwrap() + .saturating_record( + conn.inner.active_since.elapsed().as_micros() as u64 + ); Ok(conn) } .boxed(), @@ -344,19 +389,52 @@ impl Pool { // we are allowed to make a new connection, so we will! exchange.exist += 1; + self.inner + .metrics + .connection_count + .fetch_add(1, atomic::Ordering::Relaxed); + + let opts = self.opts.clone(); + #[cfg(feature = "hdrhistogram")] + let metrics = self.metrics(); + return Poll::Ready(Ok(GetConnInner::Connecting( - Conn::new(self.opts.clone()).boxed(), + async move { + let conn = Conn::new(opts).await; + #[cfg(feature = "hdrhistogram")] + if let Ok(conn) = &conn { + metrics + .connect_duration + .lock() + .unwrap() + .saturating_record( + conn.inner.active_since.elapsed().as_micros() as u64 + ); + } + conn + } + .boxed(), ))); } // Polled, but no conn available? Back into the queue. - exchange.waiting.push(cx.waker().clone(), queue_id); + if exchange.waiting.push(cx.waker().clone(), queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_add(1, atomic::Ordering::Relaxed); + } Poll::Pending } fn unqueue(&self, queue_id: QueueId) { let mut exchange = self.inner.exchange.lock().unwrap(); - exchange.waiting.remove(queue_id); + if exchange.waiting.remove(queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_sub(1, atomic::Ordering::Relaxed); + } } } diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 1ea855c0..7a257443 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -67,8 +67,31 @@ impl Future for Recycler { let mut exchange = $self.inner.exchange.lock().unwrap(); if $pool_is_closed || exchange.available.len() >= $self.pool_opts.active_bound() { drop(exchange); + $self + .inner + .metrics + .discarded_superfluous_connection + .fetch_add(1, Ordering::Relaxed); $self.discard.push($conn.close_conn().boxed()); } else { + $self + .inner + .metrics + .connection_returned_to_pool + .fetch_add(1, Ordering::Relaxed); + $self + .inner + .metrics + .connections_in_pool + .fetch_add(1, Ordering::Relaxed); + #[cfg(feature = "hdrhistogram")] + $self + .inner + .metrics + .connection_active_duration + .lock() + .unwrap() + .saturating_record($conn.inner.active_since.elapsed().as_micros() as u64); exchange.available.push_back($conn.into()); if let Some(w) = exchange.waiting.pop() { w.wake(); @@ -81,12 +104,32 @@ impl Future for Recycler { ($self:ident, $conn:ident) => { if $conn.inner.stream.is_none() || $conn.inner.disconnected { // drop unestablished connection + $self + .inner + .metrics + .discarded_unestablished_connection + .fetch_add(1, Ordering::Relaxed); $self.discard.push(futures_util::future::ok(()).boxed()); } else if $conn.inner.tx_status != TxStatus::None || $conn.has_pending_result() { + $self + .inner + .metrics + .dirty_connection_return + .fetch_add(1, Ordering::Relaxed); $self.cleaning.push($conn.cleanup_for_pool().boxed()); } else if $conn.expired() || close { + $self + .inner + .metrics + .discarded_expired_connection + .fetch_add(1, Ordering::Relaxed); $self.discard.push($conn.close_conn().boxed()); } else if $conn.inner.reset_upon_returning_to_a_pool { + $self + .inner + .metrics + .resetting_connection + .fetch_add(1, Ordering::Relaxed); $self.reset.push($conn.reset_for_pool().boxed()); } else { conn_return!($self, $conn, false); @@ -142,6 +185,10 @@ impl Future for Recycler { // anything that comes through .dropped we know has .pool.is_none(). // therefore, dropping the conn won't decrement .exist, so we need to do that. self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); // NOTE: we're discarding the error here let _ = e; } @@ -157,6 +204,10 @@ impl Future for Recycler { // an error during reset. // replace with a new connection self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); // NOTE: we're discarding the error here let _ = e; } @@ -177,6 +228,10 @@ impl Future for Recycler { // an error occurred while closing a connection. // what do we do? we still replace it with a new connection.. self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); // NOTE: we're discarding the error here let _ = e; } @@ -184,6 +239,11 @@ impl Future for Recycler { } if self.discarded != 0 { + self.inner + .metrics + .connection_count + .fetch_sub(self.discarded, Ordering::Relaxed); + // we need to open up slots for new connctions to be established! let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= self.discarded; From c097beacefcdc6e7355b238fedd47278e71ccbef Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Tue, 17 Dec 2024 22:08:59 -0800 Subject: [PATCH 087/130] Add SslOpts::disable_built_in_roots flag --- src/io/tls/native_tls_io.rs | 1 + src/io/tls/rustls_io.rs | 4 +++- src/opts/mod.rs | 12 ++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index adc8d980..9478303b 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -36,6 +36,7 @@ impl Endpoint { } builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); + builder.disable_built_in_roots(ssl_opts.disable_built_in_roots()); let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into(); *self = match self { diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index 76080ff2..c4da0fcf 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -46,7 +46,9 @@ impl Endpoint { } let mut root_store = RootCertStore::empty(); - root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); + if !ssl_opts.disable_built_in_roots() { + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); + } for cert in ssl_opts.load_root_certs().await? { root_store.add(cert)?; diff --git a/src/opts/mod.rs b/src/opts/mod.rs index e9044450..c8e47596 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -214,6 +214,7 @@ pub struct SslOpts { #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] client_identity: Option, root_certs: Vec>, + disable_built_in_roots: bool, skip_domain_validation: bool, accept_invalid_certs: bool, tls_hostname_override: Option>, @@ -236,6 +237,13 @@ impl SslOpts { self } + /// If `true`, use only the root certificates configured via `with_root_certs`, + /// not any system or built-in certs. + pub fn with_disable_built_in_roots(mut self, disable_built_in_roots: bool) -> Self { + self.disable_built_in_roots = disable_built_in_roots; + self + } + /// The way to not validate the server's domain /// name against its certificate (defaults to `false`). pub fn with_danger_skip_domain_validation(mut self, value: bool) -> Self { @@ -271,6 +279,10 @@ impl SslOpts { &self.root_certs } + pub fn disable_built_in_roots(&self) -> bool { + self.disable_built_in_roots + } + pub fn skip_domain_validation(&self) -> bool { self.skip_domain_validation } From 26e48988abd93c6047aa6863906dae5e4c5fc291 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 20 Dec 2024 11:27:11 +0300 Subject: [PATCH 088/130] Add `built_in_roots` parameter to connection URL. Update docs --- src/opts/mod.rs | 73 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 62 insertions(+), 11 deletions(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index c8e47596..7fac74af 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -237,22 +237,61 @@ impl SslOpts { self } - /// If `true`, use only the root certificates configured via `with_root_certs`, - /// not any system or built-in certs. + /// If `true`, use only the root certificates configured via [`SslOpts::with_root_certs`], + /// not any system or built-in certs. By default system built-in certs _will be_ used. + /// + /// # Connection URL + /// + /// Use `built_in_roots` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&built_in_roots=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().disable_built_in_roots(), true); + /// # Ok(()) } + /// ``` pub fn with_disable_built_in_roots(mut self, disable_built_in_roots: bool) -> Self { self.disable_built_in_roots = disable_built_in_roots; self } - /// The way to not validate the server's domain - /// name against its certificate (defaults to `false`). + /// The way to not validate the server's domain name against its certificate. + /// By default domain name _will be_ validated. + /// + /// # Connection URL + /// + /// Use `built_in_roots` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&verify_identity=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().skip_domain_validation(), true); + /// # Ok(()) } + /// ``` pub fn with_danger_skip_domain_validation(mut self, value: bool) -> Self { self.skip_domain_validation = value; self } - /// If `true` then client will accept invalid certificate (expired, not trusted, ..) - /// (defaults to `false`). + /// If `true` then client will accept invalid certificate (expired, not trusted, ..). + /// Invalid certificates _won't get_ accepted by default. + /// + /// # Connection URL + /// + /// Use `verify_ca` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&verify_ca=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().accept_invalid_certs(), true); + /// # Ok(()) } + /// ``` pub fn with_danger_accept_invalid_certs(mut self, value: bool) -> Self { self.accept_invalid_certs = value; self @@ -1596,6 +1635,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { let mut skip_domain_validation = false; let mut accept_invalid_certs = false; + let mut disable_built_in_roots = false; for (key, value) in query_pairs { if key == "pool_min" { @@ -1696,10 +1736,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } else if key == "max_allowed_packet" { match usize::from_str(&value) { - Ok(value) => { - opts.max_allowed_packet = - Some(std::cmp::max(1024, std::cmp::min(1073741824, value))) - } + Ok(value) => opts.max_allowed_packet = Some(value.clamp(1024, 1073741824)), _ => { return Err(UrlError::InvalidParamValue { param: "max_allowed_packet".into(), @@ -1851,6 +1888,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "built_in_roots" { + match bool::from_str(&value) { + Ok(x) => { + disable_built_in_roots = !x; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "built_in_roots".into(), + value, + }); + } + } } else { return Err(UrlError::UnknownParameter { param: key }); } @@ -1868,6 +1917,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { if let Some(ref mut ssl_opts) = opts.ssl_opts.as_mut() { ssl_opts.accept_invalid_certs = accept_invalid_certs; ssl_opts.skip_domain_validation = skip_domain_validation; + ssl_opts.disable_built_in_roots = disable_built_in_roots; } Ok(opts) @@ -1985,7 +2035,7 @@ mod test { ); const URL4: &str = - "mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false"; + "mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false&built_in_roots=false"; let opts = Opts::from_url(URL4).unwrap(); assert_eq!( opts.ssl_opts(), @@ -1993,6 +2043,7 @@ mod test { &SslOpts::default() .with_danger_accept_invalid_certs(true) .with_danger_skip_domain_validation(true) + .with_disable_built_in_roots(true) ) ); From add2a3970dd8747210ad715c8396f9aec3e2b76a Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Fri, 20 Dec 2024 13:00:08 -0800 Subject: [PATCH 089/130] verify_identity --- src/opts/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 7fac74af..04118846 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -262,7 +262,7 @@ impl SslOpts { /// /// # Connection URL /// - /// Use `built_in_roots` URL parameter to set this value: + /// Use `verify_identity` URL parameter to set this value: /// /// ``` /// # use mysql_async::*; From 53c39f94eee8308aeaf558e810b439b107b214b6 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 13 Dec 2024 12:06:46 +0300 Subject: [PATCH 090/130] clippy --- src/conn/mod.rs | 9 +++------ src/conn/pool/futures/disconnect_pool.rs | 3 ++- src/opts/mod.rs | 11 ++++++----- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index be778752..dc5e232b 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -504,10 +504,7 @@ impl Conn { self.inner.capabilities = handshake.capabilities() & self.inner.opts.get_capabilities(); self.inner.version = handshake .maria_db_server_version_parsed() - .map(|version| { - self.inner.is_mariadb = true; - version - }) + .inspect(|_| self.inner.is_mariadb = true) .or_else(|| handshake.server_version_parsed()) .unwrap_or((0, 0, 0)); self.inner.id = handshake.connection_id(); @@ -982,7 +979,7 @@ impl Conn { /// Configures the connection based on server settings. In particular: /// /// * It reads and stores socket address inside the connection unless if socket address is - /// already in [`Opts`] or if `prefer_socket` is `false`. + /// already in [`Opts`] or if `prefer_socket` is `false`. /// /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`] /// @@ -1429,7 +1426,7 @@ mod test { let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug); conn.query_drop(query).await.unwrap(); - if conn.inner.version <= (8, 0, 11) { + if conn.inner.version < (8, 0, 11) { conn.query_drop(format!("SET old_passwords = {}", val)) .await .unwrap(); diff --git a/src/conn/pool/futures/disconnect_pool.rs b/src/conn/pool/futures/disconnect_pool.rs index 4e5c4f4d..b8a07d4d 100644 --- a/src/conn/pool/futures/disconnect_pool.rs +++ b/src/conn/pool/futures/disconnect_pool.rs @@ -60,7 +60,8 @@ impl Future for DisconnectPool { Some(drop) => match drop.send(None) { Ok(_) => { // Recycler is alive. Waiting for it to finish. - Poll::Ready(Ok(ready!(Box::pin(drop.closed()).as_mut().poll(cx)))) + ready!(Box::pin(drop.closed()).as_mut().poll(cx)); + Poll::Ready(Ok(())) } Err(_) => { // Recycler seem dead. No one will wake us. diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 04118846..f5730fc7 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -1376,8 +1376,7 @@ impl OptsBuilder { /// Note that it'll saturate to proper minimum and maximum values /// for this parameter (see MySql documentation). pub fn max_allowed_packet(mut self, max_allowed_packet: Option) -> Self { - self.opts.max_allowed_packet = - max_allowed_packet.map(|x| std::cmp::max(1024, std::cmp::min(1073741824, x))); + self.opts.max_allowed_packet = max_allowed_packet.map(|x| x.clamp(1024, 1073741824)); self } @@ -1931,7 +1930,7 @@ impl FromStr for Opts { } } -impl<'a> TryFrom<&'a str> for Opts { +impl TryFrom<&str> for Opts { type Error = UrlError; fn try_from(s: &str) -> std::result::Result { @@ -1991,13 +1990,15 @@ mod test { #[test] fn should_convert_url_into_opts() { - let url = "mysql://usr:pw@192.168.1.1:3309/dbname"; - let parsed_url = Url::parse("mysql://usr:pw@192.168.1.1:3309/dbname").unwrap(); + let url = "mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true"; + let parsed_url = + Url::parse("mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true").unwrap(); let mysql_opts = MysqlOpts { user: Some("usr".to_string()), pass: Some("pw".to_string()), db_name: Some("dbname".to_string()), + prefer_socket: true, ..MysqlOpts::default() }; let host = HostPortOrUrl::Url(parsed_url); From d5ccdcc45f2285d71dcc429264df063fbf63115f Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 13 Dec 2024 12:16:44 +0300 Subject: [PATCH 091/130] Turn off twox-hash features --- Cargo.toml | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6808b5fd..8ebff90a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ tokio-util = { version = "0.7.2", features = ["codec", "io"] } tracing = { version = "0.1.37", default-features = false, features = [ "attributes", ], optional = true } -twox-hash = "2" +twox-hash = { version = "2", default-features = false, features = ["xxhash64"] } url = "2.1" hdrhistogram = { version = "7.5", optional = true } @@ -88,15 +88,9 @@ default = [ "binlog", ] -default-rustls = [ - "default-rustls-no-provider", - "aws-lc-rs", -] +default-rustls = ["default-rustls-no-provider", "aws-lc-rs"] -default-rustls-ring = [ - "default-rustls-no-provider", - "ring", -] +default-rustls-ring = ["default-rustls-no-provider", "ring"] default-rustls-no-provider = [ "flate2/rust_backend", From 4d7c33bc9f32e660456d84f4b1e6897e44d48135 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 24 Dec 2024 21:23:28 +0300 Subject: [PATCH 092/130] Shrink default features --- Cargo.toml | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8ebff90a..11ac9752 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,16 +77,7 @@ socket2 = { version = "0.5.2", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } [features] -default = [ - "flate2/zlib", - "bigdecimal", - "rust_decimal", - "time", - "frunk", - "derive", - "native-tls-tls", - "binlog", -] +default = ["flate2/zlib", "derive"] default-rustls = ["default-rustls-no-provider", "aws-lc-rs"] @@ -94,13 +85,8 @@ default-rustls-ring = ["default-rustls-no-provider", "ring"] default-rustls-no-provider = [ "flate2/rust_backend", - "bigdecimal", - "rust_decimal", - "time", - "frunk", "derive", "rustls-tls", - "binlog", "tls12", ] From 13a269aee055e538cb613e521d58e54cf3e2d1be Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 24 Dec 2024 21:35:53 +0300 Subject: [PATCH 093/130] Update README.md --- README.md | 34 +++++++++++++--------------------- src/lib.rs | 34 +++++++++++++--------------------- 2 files changed, 26 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 04ad34c6..4c5b9466 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ mysql_async = "" ## Crate Features -Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] -as well as `native-tls`-based TLS support. +By default there are only two features enabled: + +* `flate2/zlib` — choosing flate2 backend is mandatory +* `derive` — see ["Derive Macros" section in `mysql_common` docs][mysqlcommonderive] ### List Of Features @@ -37,21 +39,18 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", default-features = false, features = ["minimal"]} ``` -* `minimal-rust` - same as `minimal` but rust-based flate2 backend is chosen. Enables: +* `minimal-rust` - same as `minimal` but with rust-based flate2 backend. Enables: - `flate2/rust_backend` * `default` – enables the following set of features: - - `minimal` - - `native-tls-tls` - - `bigdecimal` - - `rust_decimal` - - `time` - - `frunk` - - `binlog` + - `flate2/zlib` + - `derive` + +* `default-rustls` – default set of features with TLS via `rustls/aws-lc-rs` -* `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. +* `default-rustls-ring` – default set of features with TLS via `rustls/ring` **Example:** @@ -60,7 +59,7 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } ``` -* `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ +* `native-tls-tls` – enables TLS via `native-tls` **Example:** @@ -68,14 +67,6 @@ as well as `native-tls`-based TLS support. [dependencies] mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } -* `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ - - **Example:** - - ```toml - [dependencies] - mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } - * `tracing` – enables instrumentation via `tracing` package. Primary operations (`query`, `prepare`, `exec`) are instrumented at `INFO` level. @@ -94,7 +85,7 @@ as well as `native-tls`-based TLS support. - `mysql_common/binlog" -#### Proxied features +#### Proxied features (see [`mysql_common`` fatures][myslqcommonfeatures]) * `derive` – enables `mysql_common/derive` feature * `chrono` = enables `mysql_common/chrono` feature @@ -104,6 +95,7 @@ as well as `native-tls`-based TLS support. * `frunk` = enables `mysql_common/frunk` feature [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features +[mysqlcommonderive]: https://github.com/blackbeam/rust_mysql_common?tab=readme-ov-file#derive-macros ## TLS/SSL Support diff --git a/src/lib.rs b/src/lib.rs index 0d637963..a91d3fda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,8 +19,10 @@ //! //! # Crate Features //! -//! Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] -//! as well as `native-tls`-based TLS support. +//! By default there are only two features enabled: +//! +//! * `flate2/zlib` — choosing flate2 backend is mandatory +//! * `derive` — see ["Derive Macros" section in `mysql_common` docs][mysqlcommonderive] //! //! ## List Of Features //! @@ -36,21 +38,18 @@ //! mysql_async = { version = "*", default-features = false, features = ["minimal"]} //! ``` //! -//! * `minimal-rust` - same as `minimal` but rust-based flate2 backend is chosen. Enables: +//! * `minimal-rust` - same as `minimal` but with rust-based flate2 backend. Enables: //! //! - `flate2/rust_backend` //! //! * `default` – enables the following set of features: //! -//! - `minimal` -//! - `native-tls-tls` -//! - `bigdecimal` -//! - `rust_decimal` -//! - `time` -//! - `frunk` -//! - `binlog` +//! - `flate2/zlib` +//! - `derive` +//! +//! * `default-rustls` – default set of features with TLS via `rustls/aws-lc-rs` //! -//! * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. +//! * `default-rustls-ring` – default set of features with TLS via `rustls/ring` //! //! **Example:** //! @@ -59,7 +58,7 @@ //! mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } //! ``` //! -//! * `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ +//! * `native-tls-tls` – enables TLS via `native-tls` //! //! **Example:** //! @@ -67,14 +66,6 @@ //! [dependencies] //! mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } //! -//! * `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ -//! -//! **Example:** -//! -//! ```toml -//! [dependencies] -//! mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } -//! //! * `tracing` – enables instrumentation via `tracing` package. //! //! Primary operations (`query`, `prepare`, `exec`) are instrumented at `INFO` level. @@ -93,7 +84,7 @@ //! //! - `mysql_common/binlog" //! -//! ### Proxied features +//! ### Proxied features (see [`mysql_common`` fatures][myslqcommonfeatures]) //! //! * `derive` – enables `mysql_common/derive` feature //! * `chrono` = enables `mysql_common/chrono` feature @@ -103,6 +94,7 @@ //! * `frunk` = enables `mysql_common/frunk` feature //! //! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features +//! [mysqlcommonderive]: https://github.com/blackbeam/rust_mysql_common?tab=readme-ov-file#derive-macros //! //! # TLS/SSL Support //! From 1f7941e978ffc75919ba997f30379e6cc9994b69 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 26 Dec 2024 13:25:39 +0300 Subject: [PATCH 094/130] ci: update mariadb and tidb versions --- azure-pipelines.yml | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 04e90187..2e8d41de 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -84,7 +84,6 @@ jobs: SSL=false COMPRESS=false cargo test SSL=true COMPRESS=false cargo test SSL=false COMPRESS=true cargo test - SSL=true COMPRESS=true cargo test env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:password@127.0.0.1/mysql @@ -96,6 +95,12 @@ jobs: strategy: maxParallel: 10 matrix: + v91: + DB_VERSION: "9.1" + v90: + DB_VERSION: "9.0" + v84: + DB_VERSION: "8.4" v80: DB_VERSION: "8.0-debian" v57: @@ -126,7 +131,14 @@ jobs: docker exec container bash -l -c "apt-get --allow-unauthenticated -y update" docker exec container bash -l -c "apt-get install -y curl clang libssl-dev pkg-config build-essential" docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" - displayName: Install Rust in docker + displayName: Install Rust in docker (Debian) + condition: or(eq(variables['DB_VERSION'], '5.6'), eq(variables['DB_VERSION'], '5.7-debian'), eq(variables['DB_VERSION'], '8.0-debian')) + - bash: | + docker exec container bash -l -c "microdnf install dnf" + docker exec container bash -l -c "dnf group install \"Development Tools\"" + docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" + displayName: Install Rust in docker (RedHat) + condition: not(or(eq(variables['DB_VERSION'], '5.6'), eq(variables['DB_VERSION'], '5.7-debian'), eq(variables['DB_VERSION'], '8.0-debian'))) - bash: | if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; else DATABASE_URL="mysql://root2:password@127.0.0.1/mysql?secure_auth=false"; fi docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" @@ -145,16 +157,16 @@ jobs: strategy: maxParallel: 10 matrix: + v1162: + DB_VERSION: "11.6.2" + v1152: + DB_VERSION: "11.5.2" + v1144: + DB_VERSION: "11.4.2" v113: - DB_VERSION: "11.3" + DB_VERSION: "11.3.2" v1011: - DB_VERSION: "10.11" - v106: - DB_VERSION: "10.6" - v105: - DB_VERSION: "10.5" - v104: - DB_VERSION: "10.4" + DB_VERSION: "10.11.10" steps: - bash: | sudo apt-get update @@ -207,10 +219,14 @@ jobs: vmImage: "ubuntu-latest" strategy: matrix: - v5.3.0: - DB_VERSION: "v5.3.0" - v5.0.6: - DB_VERSION: "v5.0.6" + v8.5.0: + DB_VERSION: "v8.5.0" + v7.6.0: + DB_VERSION: "v7.6.0" + v6.6.0: + DB_VERSION: "v6.6.0" + v5.4.3: + DB_VERSION: "v5.4.3" steps: - bash: | curl --proto '=https' --tlsv1.2 -sSf https://tiup-mirrors.pingcap.com/install.sh | sh From f31373ed2a794ecf3feadd69fe2bfd03461540ce Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Wed, 1 Jan 2025 14:03:49 +0300 Subject: [PATCH 095/130] Update `mysql_common` dependency to 0.34 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 11ac9752..241791a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.12.0" -mysql_common = { version = "0.33", default-features = false } +mysql_common = { version = "0.34", default-features = false } pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" From ff7f8960d449cbc10c120751153fc41a2f9acb82 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 2 Jan 2025 13:18:59 +0300 Subject: [PATCH 096/130] Fix `should_change_user` test --- azure-pipelines.yml | 14 ++++++++------ src/conn/mod.rs | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 2e8d41de..0e886c9c 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -41,13 +41,15 @@ jobs: - bash: | cargo +nightly build -Zfeatures=dev_dep SSL=false COMPRESS=false cargo test - SSL=true COMPRESS=false cargo test + SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test - SSL=true COMPRESS=true cargo test - SSL=true COMPRESS=false cargo test --no-default-features --features default-rustls + SSL=true COMPRESS=true cargo test --features rustls-tls + SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls + SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls-ring SSL=true COMPRESS=false cargo check --no-default-features --features minimal SSL=true COMPRESS=false cargo check --no-default-features --features minimal-rust + SSL=true COMPRESS=false cargo check --no-default-features --features tracing env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:root@127.0.0.1:3306/mysql @@ -82,7 +84,7 @@ jobs: displayName: Install Rust (Windows) - bash: | SSL=false COMPRESS=false cargo test - SSL=true COMPRESS=false cargo test + SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test env: RUST_BACKTRACE: 1 @@ -143,7 +145,7 @@ jobs: if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; else DATABASE_URL="mysql://root2:password@127.0.0.1/mysql?secure_auth=false"; fi docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test --features native-tls-tls" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL COMPRESS=true cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL COMPRESS=true cargo test --no-default-features --features default-rustls" env: @@ -206,7 +208,7 @@ jobs: - bash: | docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --features native-tls-tls" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true COMPRESS=true cargo test" if [[ "10.1" != "$(DB_VERSION)" ]]; then docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --no-default-features --features default-rustls"; fi env: diff --git a/src/conn/mod.rs b/src/conn/mod.rs index dc5e232b..51072f27 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1417,7 +1417,7 @@ mod test { for (plug, val, pass) in variants { dbg!((plug, val, pass, conn.inner.version)); - if plug == "mysql_native_password" && conn.inner.version >= (9, 0, 0) { + if plug == "mysql_native_password" && conn.inner.version >= (8, 4, 0) { continue; } @@ -1569,7 +1569,7 @@ mod test { }; for (i, plugin) in plugins.iter().enumerate() { - if *plugin == "mysql_native_password" && conn.server_version() >= (9, 0, 0) { + if *plugin == "mysql_native_password" && conn.server_version() >= (8, 4, 0) { continue; } From 40a13cd06455757cde7cdb63edd3477f098b0a19 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 2 Jan 2025 13:51:29 +0300 Subject: [PATCH 097/130] README: describre `rustls-tls` feature --- Cargo.toml | 1 + README.md | 11 ++++++++++- azure-pipelines.yml | 4 ++-- src/lib.rs | 11 ++++++++++- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 241791a0..ef0538dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ optional = true [dependencies.rustls] version = "0.23" default-features = false +features = ["std"] optional = true [dependencies.rustls-pemfile] diff --git a/README.md b/README.md index 4c5b9466..2d2bec3c 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,16 @@ By default there are only two features enabled: ```toml [dependencies] - mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } + mysql_async = { version = "*", default-features = false, features = ["minimal", "native-tls-tls"] } + +* `rustls-tls` - enables rustls TLS backend with no provider. You should enable one + of existing providers using `aws-lc-rs` or `ring` features: + + **Example:** + + ```toml + [dependencies] + mysql_async = { version = "*", default-features = false, features = ["minimal-rust", "rustls-tls", "ring"] } * `tracing` – enables instrumentation via `tracing` package. diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 0e886c9c..03e7ce99 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -43,13 +43,13 @@ jobs: SSL=false COMPRESS=false cargo test SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test - SSL=true COMPRESS=true cargo test --features rustls-tls + SSL=true COMPRESS=true cargo test --features rustls-tls,ring SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls-ring SSL=true COMPRESS=false cargo check --no-default-features --features minimal SSL=true COMPRESS=false cargo check --no-default-features --features minimal-rust - SSL=true COMPRESS=false cargo check --no-default-features --features tracing + SSL=true COMPRESS=false cargo check --no-default-features --features minimal,tracing env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:root@127.0.0.1:3306/mysql diff --git a/src/lib.rs b/src/lib.rs index a91d3fda..4b8600c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,16 @@ //! //! ```toml //! [dependencies] -//! mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } +//! mysql_async = { version = "*", default-features = false, features = ["minimal", "native-tls-tls"] } +//! +//! * `rustls-tls` - enables rustls TLS backend with no provider. You should enable one +//! of existing providers using `aws-lc-rs` or `ring` features: +//! +//! **Example:** +//! +//! ```toml +//! [dependencies] +//! mysql_async = { version = "*", default-features = false, features = ["minimal-rust", "rustls-tls", "ring"] } //! //! * `tracing` – enables instrumentation via `tracing` package. //! From 320e1512faf3e5bd43e8aa6cdee9a8756369fbee Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 2 Jan 2025 15:09:28 +0300 Subject: [PATCH 098/130] Bump minor version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ef0538dc..2d5ef837 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.34.2" +version = "0.35.0" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From 4a36cd8079f5c8848ad028821474e39756589117 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Thu, 9 Jan 2025 14:50:23 +0700 Subject: [PATCH 099/130] Export metrics from crate root --- src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 4b8600c3..0ef17567 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -550,6 +550,9 @@ pub use self::queryable::{BinaryProtocol, TextProtocol}; #[doc(inline)] pub use self::queryable::stmt::Statement; +#[doc(inline)] +pub use self::conn::pool::Metrics; + /// Futures used in this crate pub mod futures { pub use crate::conn::pool::futures::{DisconnectPool, GetConn}; From 2a91a2738aba037c7cc10ecf2e5adeba9aaff633 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 10 Jan 2025 08:18:09 +0300 Subject: [PATCH 100/130] Bump micro version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 2d5ef837..00040182 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT/Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.35.0" +version = "0.35.1" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From c5757082d6bfd3345f63ed5263dddc8e8d180242 Mon Sep 17 00:00:00 2001 From: Stephen Wong Date: Tue, 21 Jan 2025 17:04:30 -0800 Subject: [PATCH 101/130] Add `DriverError::StmtParamsNumberExceedsLimit` When comparing the number of params in a statement against the number of params supplied, if they do not match, the error is supposed to provide what these two numbers are for users to understand what might have gone wrong. This is achieved via following code path in `ExecRoutine`: ```rust if self.stmt.num_params() as usize != params.len() { Err(DriverError::StmtParamsMismatch { required: self.stmt.num_params(), supplied: params.len() as u16, })? } ``` The maximum number of bind variables in a statement supported by MySQL is `65,535` and this is guarded at the server side which means the aforementioned comparison in `ExecRoutine` is hit before that. But at this point, both of these number could go over `65,535`, which is the `u16::MAX`. In that case, when the error is constructed, the conversion from `params.len()`'s `usize` to `u16` would result in a loss of precision, producing a confusing error. Because `self.stmt.num_params()` returns a `u16` while `params.len()` returns a `usize`. For example, if both the number of the required params and the supplied params are `65,539`, which is greater than `u16::MAX`, `self.stmt.num_params()` would have returned a truncated number `3` and even with `as usize` would not have recovered the loss precision. Meanwhile, `params.len()` still returns `65,539` because it is of `usize`. The comparison then produces an incorrect result, saying that they do not match, even though they do. What makes it even more confusing is that in the error message there are two identical numbers because `params.len() as u16` converts the `usize` to `u16`, causing the same loss of precision. Because the `self.stmt.num_params()` returns a `u16`, which intrinsically prevents the detection of mismatched param numbers if they go over `u16::MAX`, and changing that type would be rather intrusive. Therefore, this commit fixes merely the symptom by providing a new error variant `StmtParamsNumberExceedsLimit` that represents the case when the number of provided params exceeds the limit. Ref: https://stackoverflow.com/questions/4922345/how-many-bind-variables-can-i-use-in-a-sql-query-in-mysql-5#comment136409462_11131824 --- src/conn/routines/exec.rs | 10 ++++++++-- src/error/mod.rs | 13 ++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/conn/routines/exec.rs b/src/conn/routines/exec.rs index 5a53e00e..30b24286 100644 --- a/src/conn/routines/exec.rs +++ b/src/conn/routines/exec.rs @@ -6,7 +6,7 @@ use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params}; #[cfg(feature = "tracing")] use tracing::{field, info_span, Level, Span}; -use crate::{BinaryProtocol, Conn, DriverError, Statement}; +use crate::{conn::MAX_STATEMENT_PARAMS, BinaryProtocol, Conn, DriverError, Statement}; use super::Routine; @@ -52,10 +52,16 @@ impl Routine<()> for ExecRoutine<'_> { Span::current().record("mysql_async.query.params", ps); } + if params.len() > MAX_STATEMENT_PARAMS { + Err(DriverError::StmtParamsNumberExceedsLimit { + supplied: params.len(), + })? + } + if self.stmt.num_params() as usize != params.len() { Err(DriverError::StmtParamsMismatch { required: self.stmt.num_params(), - supplied: params.len() as u16, + supplied: params.len(), })? } diff --git a/src/error/mod.rs b/src/error/mod.rs index ffb788f8..2f8de211 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -21,6 +21,10 @@ use std::{io, result}; /// Result type alias for this library. pub type Result = result::Result; +/// The maximum number of bind variables supported by MySQL. +/// https://stackoverflow.com/questions/4922345/how-many-bind-variables-can-i-use-in-a-sql-query-in-mysql-5#comment136409462_11131824 +pub(crate) const MAX_STATEMENT_PARAMS: usize = u16::MAX as usize; + /// This type enumerates library errors. #[derive(Debug, Error)] pub enum Error { @@ -135,7 +139,14 @@ pub enum DriverError { required, supplied )] - StmtParamsMismatch { required: u16, supplied: u16 }, + StmtParamsMismatch { required: u16, supplied: usize }, + + #[error( + "MySQL supports up to {} parameters but {} was supplied.", + MAX_STATEMENT_PARAMS, + supplied + )] + StmtParamsNumberExceedsLimit { supplied: usize }, #[error("Unexpected packet.")] UnexpectedPacket { payload: Vec }, From 1d6dc7dddec0466909a17ce51b355537d1c9e3f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 17 Apr 2024 21:47:48 +0000 Subject: [PATCH 102/130] Simplify socket2 keep alive logic Unfortunately SockRef::from only implemented for unix/windows, so this doesn't allow making the feature more portable --- src/io/mod.rs | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/io/mod.rs b/src/io/mod.rs index 4bb27e9e..f454ef96 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -12,8 +12,6 @@ use bytes::BytesMut; use futures_core::{ready, stream}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; use pin_project::pin_project; -#[cfg(any(unix, windows))] -use socket2::{Socket as Socket2Socket, TcpKeepalive}; #[cfg(unix)] use tokio::io::AsyncWriteExt; use tokio::{ @@ -381,20 +379,8 @@ impl Stream { #[cfg(any(unix, windows))] if let Some(duration) = keepalive { - #[cfg(unix)] - let socket = { - use std::os::unix::prelude::*; - let fd = tcp_stream.as_raw_fd(); - unsafe { Socket2Socket::from_raw_fd(fd) } - }; - #[cfg(windows)] - let socket = { - use std::os::windows::prelude::*; - let sock = tcp_stream.as_raw_socket(); - unsafe { Socket2Socket::from_raw_socket(sock) } - }; - socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?; - std::mem::forget(socket); + socket2::SockRef::from(&tcp_stream) + .set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(duration))?; } Ok(Stream { From 85609cac04a26436b76a79c6f84261ccf314cc90 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Thu, 19 Dec 2024 15:56:05 -0800 Subject: [PATCH 103/130] Cache the TlsConnector built from SslOpts --- Cargo.toml | 2 +- src/conn/mod.rs | 8 +++-- src/io/mod.rs | 20 +++--------- src/io/tls/mod.rs | 13 ++++++-- src/io/tls/native_tls_io.rs | 41 ++++++++++++++----------- src/io/tls/no_tls.rs | 24 +++++++++++++++ src/io/tls/rustls_io.rs | 48 ++++++++++++++++------------- src/opts/mod.rs | 61 ++++++++++++++++++++++++++++++++++--- 8 files changed, 153 insertions(+), 64 deletions(-) create mode 100644 src/io/tls/no_tls.rs diff --git a/Cargo.toml b/Cargo.toml index 00040182..bf334691 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,7 @@ serde = "1" serde_json = "1" socket2 = "0.5.2" thiserror = "2" -tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } +tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt", "sync"] } tokio-util = { version = "0.7.2", features = ["codec", "io"] } tracing = { version = "0.1.37", default-features = false, features = [ "attributes", diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 51072f27..2815802f 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -549,12 +549,16 @@ impl Conn { ); self.write_struct(&ssl_request).await?; let conn = self; - let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable"); + let ssl_opts = conn.opts().ssl_opts_and_connector().expect("unreachable"); let domain = ssl_opts + .ssl_opts() .tls_hostname_override() .unwrap_or_else(|| conn.opts().ip_or_hostname()) .into(); - conn.stream_mut()?.make_secure(domain, ssl_opts).await?; + let tls_connector = ssl_opts.build_tls_connector().await?; + conn.stream_mut()? + .make_secure(domain, &tls_connector) + .await?; Ok(()) } else { Ok(()) diff --git a/src/io/mod.rs b/src/io/mod.rs index 4bb27e9e..574dd49e 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -42,7 +42,7 @@ use std::{ use crate::{ buffer_pool::PooledBuf, error::IoError, - opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT}, + opts::{HostPortOrUrl, DEFAULT_PORT}, }; #[cfg(unix)] @@ -50,6 +50,8 @@ use crate::io::socket::Socket; mod tls; +pub(crate) use self::tls::TlsConnector; + macro_rules! with_interrupted { ($e:expr) => { loop { @@ -193,18 +195,6 @@ impl Endpoint { matches!(self, Endpoint::Secure(_)) } - #[cfg(all(not(feature = "native-tls-tls"), not(feature = "rustls")))] - pub async fn make_secure( - &mut self, - _domain: String, - _ssl_opts: crate::SslOpts, - ) -> crate::error::Result<()> { - panic!( - "Client had asked for TLS connection but TLS support is disabled. \ - Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]" - ) - } - pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> { match *self { Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?, @@ -415,11 +405,11 @@ impl Stream { pub(crate) async fn make_secure( &mut self, domain: String, - ssl_opts: SslOpts, + tls_connector: &TlsConnector, ) -> crate::error::Result<()> { let codec = self.codec.take().unwrap(); let FramedParts { mut io, codec, .. } = codec.into_parts(); - io.make_secure(domain, ssl_opts).await?; + io.make_secure(domain, tls_connector).await?; let codec = Framed::new(io, codec); self.codec = Some(Box::new(codec)); Ok(()) diff --git a/src/io/tls/mod.rs b/src/io/tls/mod.rs index 623e28ae..ca2f71e2 100644 --- a/src/io/tls/mod.rs +++ b/src/io/tls/mod.rs @@ -1,4 +1,13 @@ -#![cfg(any(feature = "native-tls-tls", feature = "rustls"))] - +#[cfg(feature = "native-tls-tls")] mod native_tls_io; +#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))] +mod no_tls; +#[cfg(feature = "rustls-tls")] mod rustls_io; + +#[cfg(feature = "native-tls-tls")] +pub(crate) use self::native_tls_io::TlsConnector; +#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))] +pub(crate) use self::no_tls::TlsConnector; +#[cfg(feature = "rustls-tls")] +pub(crate) use self::rustls_io::TlsConnector; diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index 9478303b..e085be7c 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -1,10 +1,10 @@ -#![cfg(feature = "native-tls-tls")] - -use native_tls::{Certificate, TlsConnector}; +use tokio_native_tls::native_tls::{self, Certificate}; use crate::io::Endpoint; use crate::{Result, SslOpts}; +pub use tokio_native_tls::TlsConnector; + impl SslOpts { async fn load_root_certs(&self) -> crate::Result> { let mut output = Vec::new(); @@ -16,29 +16,36 @@ impl SslOpts { Ok(output) } + + pub(crate) async fn build_tls_connector(&self) -> Result { + let mut builder = native_tls::TlsConnector::builder(); + for root_cert in self.load_root_certs().await? { + builder.add_root_certificate(root_cert); + } + + if let Some(client_identity) = self.client_identity() { + builder.identity(client_identity.load().await?); + } + builder.danger_accept_invalid_hostnames(self.skip_domain_validation()); + builder.danger_accept_invalid_certs(self.accept_invalid_certs()); + builder.disable_built_in_roots(self.disable_built_in_roots()); + let tls_connector: TlsConnector = builder.build()?.into(); + Ok(tls_connector) + } } impl Endpoint { - pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { + pub async fn make_secure( + &mut self, + domain: String, + tls_connector: &TlsConnector, + ) -> Result<()> { #[cfg(unix)] if self.is_socket() { // won't secure socket connection return Ok(()); } - let mut builder = TlsConnector::builder(); - for root_cert in ssl_opts.load_root_certs().await? { - builder.add_root_certificate(root_cert); - } - - if let Some(client_identity) = ssl_opts.client_identity() { - builder.identity(client_identity.load().await?); - } - builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); - builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); - builder.disable_built_in_roots(ssl_opts.disable_built_in_roots()); - let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into(); - *self = match self { Endpoint::Plain(ref mut stream) => { let stream = stream.take().unwrap(); diff --git a/src/io/tls/no_tls.rs b/src/io/tls/no_tls.rs new file mode 100644 index 00000000..325585c9 --- /dev/null +++ b/src/io/tls/no_tls.rs @@ -0,0 +1,24 @@ +use crate::io::Endpoint; +use crate::{Result, SslOpts}; + +#[derive(Clone, Debug)] +pub(crate) struct TlsConnector; + +impl SslOpts { + pub(crate) async fn build_tls_connector(&self) -> Result { + panic!( + "Client had asked for TLS connection but TLS support is disabled. \ + Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]" + ) + } +} + +impl Endpoint { + pub async fn make_secure( + &mut self, + _domain: String, + _tls_connector: &TlsConnector, + ) -> Result<()> { + unreachable!(); + } +} diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index c4da0fcf..ad2ec672 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -1,5 +1,3 @@ -#![cfg(feature = "rustls-tls")] - use std::sync::Arc; use rustls::{ @@ -12,7 +10,7 @@ use rustls::{ }; use rustls_pemfile::certs; -use tokio_rustls::TlsConnector; +pub(crate) use tokio_rustls::TlsConnector; use crate::{io::Endpoint, Result, SslOpts, TlsError}; @@ -35,54 +33,60 @@ impl SslOpts { Ok(output) } -} - -impl Endpoint { - pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { - #[cfg(unix)] - if self.is_socket() { - // won't secure socket connection - return Ok(()); - } + pub(crate) async fn build_tls_connector(&self) -> Result { let mut root_store = RootCertStore::empty(); - if !ssl_opts.disable_built_in_roots() { + if !self.disable_built_in_roots() { root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); } - for cert in ssl_opts.load_root_certs().await? { + for cert in self.load_root_certs().await? { root_store.add(cert)?; } let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone()); - let mut config = if let Some(identity) = ssl_opts.client_identity() { + let mut config = if let Some(identity) = self.client_identity() { let (cert_chain, priv_key) = identity.load().await?; config_builder.with_client_auth_cert(cert_chain, priv_key)? } else { config_builder.with_no_client_auth() }; - let server_name = ServerName::try_from(domain.as_str()) - .map_err(|_| webpki::InvalidDnsNameError)? - .to_owned(); let mut dangerous = config.dangerous(); let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) .build() .map_err(TlsError::from)?; let dangerous_verifier = DangerousVerifier::new( - ssl_opts.accept_invalid_certs(), - ssl_opts.skip_domain_validation(), + self.accept_invalid_certs(), + self.skip_domain_validation(), web_pki_verifier, ); dangerous.set_certificate_verifier(Arc::new(dangerous_verifier)); + let client_config = Arc::new(config); + Ok(TlsConnector::from(client_config)) + } +} + +impl Endpoint { + pub async fn make_secure( + &mut self, + domain: String, + tls_connector: &TlsConnector, + ) -> Result<()> { + #[cfg(unix)] + if self.is_socket() { + // won't secure socket connection + return Ok(()); + } *self = match self { Endpoint::Plain(ref mut stream) => { let stream = stream.take().unwrap(); - let client_config = Arc::new(config); - let tls_connector = TlsConnector::from(client_config); + let server_name = ServerName::try_from(domain.as_str()) + .map_err(|_| webpki::InvalidDnsNameError)? + .to_owned(); let connection = tls_connector.connect(server_name, stream).await?; Endpoint::Secure(connection) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index f5730fc7..ff860b35 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -17,6 +17,7 @@ pub use rustls_opts::ClientIdentity; use percent_encoding::percent_decode; use rand::Rng; +use tokio::sync::OnceCell; use url::{Host, Url}; use std::{ @@ -603,7 +604,7 @@ pub(crate) struct MysqlOpts { stmt_cache_size: usize, /// Driver will require SSL connection if this option isn't `None` (default to `None`). - ssl_opts: Option, + ssl_opts: Option, /// Prefer socket connection (defaults to `true`). /// @@ -949,7 +950,7 @@ impl Opts { /// /// pub fn ssl_opts(&self) -> Option<&SslOpts> { - self.inner.mysql_opts.ssl_opts.as_ref() + self.inner.mysql_opts.ssl_opts.as_ref().map(|o| &o.ssl_opts) } /// Prefer socket connection (defaults to `true` **temporary `false` on Windows platform**). @@ -1112,6 +1113,10 @@ impl Opts { out } + + pub(crate) fn ssl_opts_and_connector(&self) -> Option<&SslOptsAndCachedConnector> { + self.inner.mysql_opts.ssl_opts.as_ref() + } } impl Default for MysqlOpts { @@ -1141,6 +1146,47 @@ impl Default for MysqlOpts { } } +#[derive(Clone)] +pub(crate) struct SslOptsAndCachedConnector { + ssl_opts: SslOpts, + tls_connector: Arc>, +} + +impl fmt::Debug for SslOptsAndCachedConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SslOptsAndCachedConnector") + .field("ssl_opts", &self.ssl_opts) + .finish() + } +} + +impl SslOptsAndCachedConnector { + fn new(ssl_opts: SslOpts) -> Self { + Self { + ssl_opts, + tls_connector: Arc::new(OnceCell::new()), + } + } + + pub(crate) fn ssl_opts(&self) -> &SslOpts { + &self.ssl_opts + } + + pub(crate) async fn build_tls_connector(&self) -> Result { + self.tls_connector + .get_or_try_init(move || self.ssl_opts.build_tls_connector()) + .await + .cloned() + } +} + +impl PartialEq for SslOptsAndCachedConnector { + fn eq(&self, other: &Self) -> bool { + self.ssl_opts == other.ssl_opts + } +} +impl Eq for SslOptsAndCachedConnector {} + /// Connection pool constraints. /// /// This type stores `min` and `max` constraints for [`crate::Pool`] and ensures that `min <= max`. @@ -1349,7 +1395,7 @@ impl OptsBuilder { /// Defines SSL options. See [`Opts::ssl_opts`]. pub fn ssl_opts>>(mut self, ssl_opts: T) -> Self { - self.opts.ssl_opts = ssl_opts.into(); + self.opts.ssl_opts = ssl_opts.into().map(SslOptsAndCachedConnector::new); self } @@ -1632,6 +1678,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { let mut pool_min = DEFAULT_POOL_CONSTRAINTS.min; let mut pool_max = DEFAULT_POOL_CONSTRAINTS.max; + let mut ssl_opts = None; let mut skip_domain_validation = false; let mut accept_invalid_certs = false; let mut disable_built_in_roots = false; @@ -1855,7 +1902,9 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } else if key == "require_ssl" { match bool::from_str(&value) { - Ok(x) => opts.ssl_opts = x.then(SslOpts::default), + Ok(x) => { + ssl_opts = x.then(SslOpts::default); + } _ => { return Err(UrlError::InvalidParamValue { param: "require_ssl".into(), @@ -1913,12 +1962,14 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } - if let Some(ref mut ssl_opts) = opts.ssl_opts.as_mut() { + if let Some(ref mut ssl_opts) = ssl_opts { ssl_opts.accept_invalid_certs = accept_invalid_certs; ssl_opts.skip_domain_validation = skip_domain_validation; ssl_opts.disable_built_in_roots = disable_built_in_roots; } + opts.ssl_opts = ssl_opts.map(SslOptsAndCachedConnector::new); + Ok(opts) } From 75861eeabd7f3047fde1780bc0e03b00eb6b5f99 Mon Sep 17 00:00:00 2001 From: Moritz Zwerger Date: Sun, 16 Mar 2025 16:19:01 +0100 Subject: [PATCH 104/130] Explicitly rollback when transaction commit fails The mysql server might return an error ("Transactions couldn't be nested"), because the previous commit failed and is not correctly rolled back. See https://github.com/blackbeam/mysql_async/issues/332 --- src/queryable/transaction.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index 6ec64b87..bc44ab4a 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -175,14 +175,26 @@ impl<'a> Transaction<'a> { Ok(Transaction(conn)) } - /// Performs `COMMIT` query. - pub async fn commit(mut self) -> Result<()> { + /// Performs `COMMIT` query or returns an error + async fn try_commit(&mut self) -> Result<()> { let result = self.0.query_iter("COMMIT").await?; result.drop_result().await?; self.0.set_tx_status(TxStatus::None); Ok(()) } + /// Performs `COMMIT` query or rollbacks when any error occurs and returns the original error. + pub async fn commit(mut self) -> Result<()> { + match self.try_commit().await { + Ok(..) => Ok(()), + Err(e) => { + self.0.query_drop("ROLLBACK").await.unwrap_or(()); + self.0.set_tx_status(TxStatus::None); + Err(e) + }, + } + } + /// Performs `ROLLBACK` query. pub async fn rollback(mut self) -> Result<()> { let result = self.0.query_iter("ROLLBACK").await?; From d8180be0a8ba200af8033ca70bfd390bcb3d06c8 Mon Sep 17 00:00:00 2001 From: Moritz Zwerger Date: Sun, 16 Mar 2025 17:07:12 +0100 Subject: [PATCH 105/130] fix fmt linting --- src/queryable/transaction.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index bc44ab4a..f4956db6 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -187,11 +187,11 @@ impl<'a> Transaction<'a> { pub async fn commit(mut self) -> Result<()> { match self.try_commit().await { Ok(..) => Ok(()), - Err(e) => { + Err(e) => { self.0.query_drop("ROLLBACK").await.unwrap_or(()); self.0.set_tx_status(TxStatus::None); Err(e) - }, + } } } From 74c71022f35dfcf703183f3fc80ad1d77450afc6 Mon Sep 17 00:00:00 2001 From: Moritz Zwerger Date: Tue, 25 Mar 2025 15:25:16 +0100 Subject: [PATCH 106/130] new transaction: explicitly rollback when previous connection was dropped dirty Fixes https://github.com/stalwartlabs/mail-server/issues/1271 --- src/conn/mod.rs | 2 +- src/queryable/mod.rs | 3 +-- src/queryable/transaction.rs | 10 ++++------ 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 2815802f..0debcb04 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1205,7 +1205,7 @@ impl Conn { } /// Requires that `self.inner.tx_status != TxStatus::None` - async fn rollback_transaction(&mut self) -> Result<()> { + pub(crate) async fn rollback_transaction(&mut self) -> Result<()> { debug_assert_ne!(self.inner.tx_status, TxStatus::None); self.inner.tx_status = TxStatus::None; self.query_drop("ROLLBACK").await diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index ebef3668..870abe43 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -96,8 +96,7 @@ impl Conn { pub(crate) async fn clean_dirty(&mut self) -> Result<()> { self.drop_result().await?; if self.get_tx_status() == TxStatus::RequiresRollback { - self.set_tx_status(TxStatus::None); - self.exec_drop("ROLLBACK", ()).await?; + self.rollback_transaction().await?; } Ok(()) } diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index f4956db6..7346d393 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -143,6 +143,8 @@ impl<'a> Transaction<'a> { let mut conn = conn.into(); + conn.clean_dirty().await?; + if conn.get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); } @@ -188,8 +190,7 @@ impl<'a> Transaction<'a> { match self.try_commit().await { Ok(..) => Ok(()), Err(e) => { - self.0.query_drop("ROLLBACK").await.unwrap_or(()); - self.0.set_tx_status(TxStatus::None); + self.0.rollback_transaction().await.unwrap_or(()); Err(e) } } @@ -197,10 +198,7 @@ impl<'a> Transaction<'a> { /// Performs `ROLLBACK` query. pub async fn rollback(mut self) -> Result<()> { - let result = self.0.query_iter("ROLLBACK").await?; - result.drop_result().await?; - self.0.set_tx_status(TxStatus::None); - Ok(()) + self.0.rollback_transaction().await } } From 79852269aeee694b6fc8794282be25ef753cfd41 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 9 Apr 2025 19:18:00 -0700 Subject: [PATCH 107/130] Fix connection count metrics --- src/conn/pool/mod.rs | 16 ++++++++++------ src/conn/pool/recycler.rs | 19 +++++++++---------- src/conn/pool/ttl_check_inerval.rs | 8 ++++++++ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index b1f96d03..9106764a 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -304,6 +304,10 @@ impl Pool { .metrics .create_failed .fetch_add(1, atomic::Ordering::Relaxed); + self.inner + .metrics + .connection_count + .store(exchange.exist, atomic::Ordering::Relaxed); // we just enabled the creation of a new connection! if let Some(w) = exchange.waiting.pop() { w.wake(); @@ -347,11 +351,6 @@ impl Pool { #[allow(unused_variables)] // `since` is only used when `hdrhistogram` is enabled while let Some(IdlingConn { mut conn, since }) = exchange.available.pop_back() { - self.inner - .metrics - .connections_in_pool - .fetch_sub(1, atomic::Ordering::Relaxed); - if !conn.expired() { #[cfg(feature = "hdrhistogram")] self.inner @@ -383,6 +382,11 @@ impl Pool { } } + self.inner + .metrics + .connections_in_pool + .store(exchange.available.len(), atomic::Ordering::Relaxed); + // we didn't _immediately_ get one -- try to make one // we first try to just do a load so we don't do an unnecessary add then sub if exchange.exist < self.opts.pool_opts().constraints().max() { @@ -392,7 +396,7 @@ impl Pool { self.inner .metrics .connection_count - .fetch_add(1, atomic::Ordering::Relaxed); + .store(exchange.exist, atomic::Ordering::Relaxed); let opts = self.opts.clone(); #[cfg(feature = "hdrhistogram")] diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 7a257443..2809dc0b 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -79,11 +79,6 @@ impl Future for Recycler { .metrics .connection_returned_to_pool .fetch_add(1, Ordering::Relaxed); - $self - .inner - .metrics - .connections_in_pool - .fetch_add(1, Ordering::Relaxed); #[cfg(feature = "hdrhistogram")] $self .inner @@ -93,6 +88,11 @@ impl Future for Recycler { .unwrap() .saturating_record($conn.inner.active_since.elapsed().as_micros() as u64); exchange.available.push_back($conn.into()); + $self + .inner + .metrics + .connections_in_pool + .store(exchange.available.len(), Ordering::Relaxed); if let Some(w) = exchange.waiting.pop() { w.wake(); } @@ -239,14 +239,13 @@ impl Future for Recycler { } if self.discarded != 0 { - self.inner - .metrics - .connection_count - .fetch_sub(self.discarded, Ordering::Relaxed); - // we need to open up slots for new connctions to be established! let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= self.discarded; + self.inner + .metrics + .connection_count + .store(exchange.exist, Ordering::Relaxed); for _ in 0..self.discarded { if let Some(w) = exchange.waiting.pop() { w.wake(); diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index a0caa1a9..95b624aa 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -70,6 +70,10 @@ impl TtlCheckInterval { } } exchange.available = kept_available; + self.inner + .metrics + .connections_in_pool + .store(exchange.available.len(), Ordering::Relaxed); to_be_dropped }; @@ -79,6 +83,10 @@ impl TtlCheckInterval { tokio::spawn(idling_conn.conn.disconnect().then(move |_| { let mut exchange = inner.exchange.lock().unwrap(); exchange.exist -= 1; + inner + .metrics + .connection_count + .store(exchange.exist, Ordering::Relaxed); ok::<_, ()>(()) })); } From 9f47061d57e562f5ccc391c06907b7a0405aaf19 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Fri, 11 Apr 2025 13:39:33 +0200 Subject: [PATCH 108/130] Bump `lru` --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index bf334691..8d2618e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" -lru = "0.12.0" +lru = "0.13.0" mysql_common = { version = "0.34", default-features = false } pem = "3.0" percent-encoding = "2.1.0" From d7335cf66416e26d3dcb0bfacba8c4b368596639 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Fri, 11 Apr 2025 13:39:44 +0200 Subject: [PATCH 109/130] Make license metadata SPDX compliant --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 8d2618e4..d7b6d8c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ authors = ["blackbeam "] description = "Tokio based asynchronous MySql client library." documentation = "https://docs.rs/mysql_async" keywords = ["mysql", "database", "asynchronous", "async"] -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" From 8713a84ce7276ebe779b6a5a028c21cb483dc7db Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Fri, 11 Apr 2025 13:41:16 +0200 Subject: [PATCH 110/130] Replace `futures_util::future::poll_fn` with `std::future::poll_fn` --- src/io/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/mod.rs b/src/io/mod.rs index 4f3ae4f3..56269862 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -447,7 +447,7 @@ impl Stream { self.closed = true; if let Some(mut codec) = self.codec { use futures_sink::Sink; - futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) { + std::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) { Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => { Poll::Ready(Ok(())) } From 0abcf71190ad88ba40962fac58a9b0d7f6eb2a0d Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 11:03:40 +0300 Subject: [PATCH 111/130] make `connection_like::{ToConnection, Connection, ToConnectionResult}` public --- src/lib.rs | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0ef17567..0c8151cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -569,6 +569,8 @@ pub mod prelude { #[doc(inline)] pub use crate::queryable::Queryable; #[doc(inline)] + pub use crate::connection_like::{Connection, ToConnection, ToConnectionResult}; + #[doc(inline)] pub use mysql_common::prelude::ColumnIndex; #[doc(inline)] pub use mysql_common::prelude::FromRow; @@ -596,17 +598,6 @@ pub mod prelude { pub trait StatementLike: crate::queryable::stmt::StatementLike {} impl StatementLike for T {} - /// Everything that is a connection. - /// - /// Note that you could obtain a `'static` connection by giving away `Conn` or `Pool`. - pub trait ToConnection<'a, 't: 'a>: crate::connection_like::ToConnection<'a, 't> {} - // explicitly implemented because of rusdoc - impl<'a> ToConnection<'a, 'static> for &'a crate::Pool {} - impl ToConnection<'static, 'static> for crate::Pool {} - impl ToConnection<'static, 'static> for crate::Conn {} - impl<'a> ToConnection<'a, 'static> for &'a mut crate::Conn {} - impl<'a, 't> ToConnection<'a, 't> for &'a mut crate::Transaction<'t> {} - /// Trait for protocol markers [`crate::TextProtocol`] and [`crate::BinaryProtocol`]. pub trait Protocol: crate::queryable::Protocol {} impl Protocol for crate::BinaryProtocol {} From de1b5392fdd93feeaeca247c016f025aa024f569 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Apr 2025 11:44:43 +0300 Subject: [PATCH 112/130] fix fmt --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0c8151cd..d0c15324 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -560,6 +560,8 @@ pub mod futures { /// Traits used in this crate pub mod prelude { + #[doc(inline)] + pub use crate::connection_like::{Connection, ToConnection, ToConnectionResult}; #[doc(inline)] pub use crate::local_infile_handler::GlobalHandler; #[doc(inline)] @@ -569,8 +571,6 @@ pub mod prelude { #[doc(inline)] pub use crate::queryable::Queryable; #[doc(inline)] - pub use crate::connection_like::{Connection, ToConnection, ToConnectionResult}; - #[doc(inline)] pub use mysql_common::prelude::ColumnIndex; #[doc(inline)] pub use mysql_common::prelude::FromRow; From 526ba8d9cc55e8758add7f8967d3a36ab96ebac7 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Fri, 18 Apr 2025 13:25:10 +0200 Subject: [PATCH 113/130] Upgrade `lru` to v0.14 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d7b6d8c2..cfc052f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" -lru = "0.13.0" +lru = "0.14.0" mysql_common = { version = "0.34", default-features = false } pem = "3.0" percent-encoding = "2.1.0" From 89bebf0bc6cdddc9566ee003b46f3313cc504073 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Sat, 19 Apr 2025 04:49:28 +0200 Subject: [PATCH 114/130] Drop unused `webpki` dependency --- Cargo.toml | 6 ------ src/error/tls/rustls_error.rs | 27 +++------------------------ src/io/tls/rustls_io.rs | 2 +- 3 files changed, 4 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cfc052f9..713b3d50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,11 +62,6 @@ optional = true version = "2.1.0" optional = true -[dependencies.webpki] -version = ">=0.22.1" -features = ["std"] -optional = true - [dependencies.webpki-roots] version = "0.26.1" optional = true @@ -103,7 +98,6 @@ native-tls-tls = ["native-tls", "tokio-native-tls"] rustls-tls = [ "rustls", "tokio-rustls", - "webpki", "webpki-roots", "rustls-pemfile", ] diff --git a/src/error/tls/rustls_error.rs b/src/error/tls/rustls_error.rs index faae88db..d88f3b18 100644 --- a/src/error/tls/rustls_error.rs +++ b/src/error/tls/rustls_error.rs @@ -7,8 +7,7 @@ use rustls::server::VerifierBuilderError; #[derive(Debug)] pub enum TlsError { Tls(rustls::Error), - Pki(webpki::Error), - InvalidDnsName(webpki::InvalidDnsNameError), + InvalidDnsName(rustls::pki_types::InvalidDnsNameError), VerifierBuilderError(VerifierBuilderError), } @@ -30,41 +29,22 @@ impl From for TlsError { } } -impl From for TlsError { - fn from(e: webpki::InvalidDnsNameError) -> Self { +impl From for TlsError { + fn from(e: rustls::pki_types::InvalidDnsNameError) -> Self { TlsError::InvalidDnsName(e) } } -impl From for TlsError { - fn from(e: webpki::Error) -> Self { - TlsError::Pki(e) - } -} - impl From for crate::Error { fn from(e: rustls::Error) -> Self { crate::Error::Io(crate::error::IoError::Tls(e.into())) } } -impl From for crate::Error { - fn from(e: webpki::Error) -> Self { - crate::Error::Io(crate::error::IoError::Tls(e.into())) - } -} - -impl From for crate::Error { - fn from(e: webpki::InvalidDnsNameError) -> Self { - crate::Error::Io(crate::error::IoError::Tls(e.into())) - } -} - impl std::error::Error for TlsError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { TlsError::Tls(e) => Some(e), - TlsError::Pki(e) => Some(e), TlsError::InvalidDnsName(e) => Some(e), TlsError::VerifierBuilderError(e) => Some(e), } @@ -75,7 +55,6 @@ impl Display for TlsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TlsError::Tls(e) => e.fmt(f), - TlsError::Pki(e) => e.fmt(f), TlsError::InvalidDnsName(e) => e.fmt(f), TlsError::VerifierBuilderError(e) => e.fmt(f), } diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index ad2ec672..143cf066 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -85,7 +85,7 @@ impl Endpoint { let stream = stream.take().unwrap(); let server_name = ServerName::try_from(domain.as_str()) - .map_err(|_| webpki::InvalidDnsNameError)? + .map_err(TlsError::InvalidDnsName)? .to_owned(); let connection = tls_connector.connect(server_name, stream).await?; From 7bbd74a49d527b948847f50389919e91053ffd75 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Sat, 19 Apr 2025 08:56:27 +0200 Subject: [PATCH 115/130] Drop `pin-project` dependency --- Cargo.toml | 1 - src/conn/pool/ttl_check_inerval.rs | 5 +--- src/io/mod.rs | 48 ++++++++++++++---------------- src/io/socket.rs | 23 +++++--------- 4 files changed, 32 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cfc052f9..81c9449d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,6 @@ lru = "0.14.0" mysql_common = { version = "0.34", default-features = false } pem = "3.0" percent-encoding = "2.1.0" -pin-project = "1.0.2" rand = "0.8.5" serde = "1" serde_json = "1" diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index a0caa1a9..ec5ffe87 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -7,7 +7,6 @@ // modified, or distributed except according to those terms. use futures_util::future::{ok, FutureExt}; -use pin_project::pin_project; use tokio::time::{self, Interval}; use std::{ @@ -26,10 +25,8 @@ use std::pin::Pin; /// The purpose of this interval is to remove idling connections that both: /// * overflows min bound of the pool; /// * idles longer then `inactive_connection_ttl`. -#[pin_project] pub(crate) struct TtlCheckInterval { inner: Arc, - #[pin] interval: Interval, pool_opts: PoolOpts, } @@ -90,7 +87,7 @@ impl Future for TtlCheckInterval { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { - let _ = futures_core::ready!(self.as_mut().project().interval.poll_tick(cx)); + let _ = futures_core::ready!(Pin::new(&mut self.interval).poll_tick(cx)); let close = self.inner.close.load(Ordering::Acquire); if !close { diff --git a/src/io/mod.rs b/src/io/mod.rs index 56269862..e5705a9f 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -11,7 +11,6 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket}; use bytes::BytesMut; use futures_core::{ready, stream}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; -use pin_project::pin_project; #[cfg(unix)] use tokio::io::AsyncWriteExt; use tokio::{ @@ -116,16 +115,15 @@ impl Encoder for PacketCodec { } } -#[pin_project(project = EndpointProj)] #[derive(Debug)] pub(crate) enum Endpoint { Plain(Option), #[cfg(feature = "native-tls-tls")] - Secure(#[pin] tokio_native_tls::TlsStream), + Secure(tokio_native_tls::TlsStream), #[cfg(feature = "rustls-tls")] - Secure(#[pin] tokio_rustls::client::TlsStream), + Secure(tokio_rustls::client::TlsStream), #[cfg(unix)] - Socket(#[pin] Socket), + Socket(Socket), } /// This future will check that TcpStream is live. @@ -243,17 +241,17 @@ impl AsyncRead for Endpoint { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_read(cx, buf), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_read(cx, buf), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Socket(stream) => Pin::new(stream).poll_read(cx, buf), }) } } @@ -264,17 +262,17 @@ impl AsyncWrite for Endpoint { cx: &mut Context, buf: &[u8], ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_write(cx, buf), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_write(cx, buf), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Socket(stream) => Pin::new(stream).poll_write(cx, buf), }) } @@ -282,17 +280,17 @@ impl AsyncWrite for Endpoint { self: Pin<&mut Self>, cx: &mut Context, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_flush(cx) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Secure(stream) => Pin::new(stream).poll_flush(cx), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Secure(stream) => Pin::new(stream).poll_flush(cx), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Socket(stream) => Pin::new(stream).poll_flush(cx), }) } @@ -300,17 +298,17 @@ impl AsyncWrite for Endpoint { self: Pin<&mut Self>, cx: &mut Context, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Secure(stream) => Pin::new(stream).poll_shutdown(cx), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Secure(stream) => Pin::new(stream).poll_shutdown(cx), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Socket(stream) => Pin::new(stream).poll_shutdown(cx), }) } } diff --git a/src/io/socket.rs b/src/io/socket.rs index 42aae58f..c8a4aaec 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -8,7 +8,6 @@ #![cfg(unix)] -use pin_project::pin_project; use tokio::io::{Error, ErrorKind::Interrupted, ReadBuf}; use std::{ @@ -20,10 +19,8 @@ use std::{ use tokio::io::{AsyncRead, AsyncWrite}; /// Unix domain socket connection on unix, or named pipe connection on windows. -#[pin_project] #[derive(Debug)] pub(crate) struct Socket { - #[pin] #[cfg(unix)] inner: tokio::net::UnixStream, } @@ -40,32 +37,28 @@ impl Socket { impl AsyncRead for Socket { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_read(cx, buf)) + with_interrupted!(Pin::new(&mut self.inner).poll_read(cx, buf)) } } impl AsyncWrite for Socket { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_write(cx, buf)) + with_interrupted!(Pin::new(&mut self.inner).poll_write(cx, buf)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_flush(cx)) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + with_interrupted!(Pin::new(&mut self.inner).poll_flush(cx)) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_shutdown(cx)) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + with_interrupted!(Pin::new(&mut self.inner).poll_shutdown(cx)) } } From 880b83050383c317a64e4c2d048c04da9d1d6572 Mon Sep 17 00:00:00 2001 From: crai0n <52631259+crai0n@users.noreply.github.com> Date: Mon, 21 Apr 2025 12:00:42 +0200 Subject: [PATCH 116/130] add glue code for ed25519 auth method in mysql_common (#329) * add glue code for ed25519 auth method in mysql_common * handle 0xfe case * Add ed25519 plugin to `should_change_user` test --------- Co-authored-by: Anatoly Ikorsky --- Cargo.toml | 3 +- azure-pipelines.yml | 16 ++-- src/conn/mod.rs | 214 +++++++++++++++++++++++++++++--------------- 3 files changed, 156 insertions(+), 77 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cfc052f9..6a016496 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.14.0" -mysql_common = { version = "0.34", default-features = false } +mysql_common = { version = "0.35", default-features = false } pem = "3.0" percent-encoding = "2.1.0" pin-project = "1.0.2" @@ -121,6 +121,7 @@ time = ["mysql_common/time"] bigdecimal = ["mysql_common/bigdecimal"] rust_decimal = ["mysql_common/rust_decimal"] frunk = ["mysql_common/frunk"] +client_ed25519 = ["mysql_common/client_ed25519"] # other features tracing = ["dep:tracing"] diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 03e7ce99..db33fd28 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -183,6 +183,7 @@ jobs: git checkout 901a7de displayName: Clone rust-mysql-simple (for ssl certs) - bash: | + if [[ "11.6.2" == "$(DB_VERSION)" ]]; then ARG=" --plugin-load-add=auth_ed25519"; fi docker run --rm -d \ --name container \ -v `pwd`:/root \ @@ -197,7 +198,9 @@ jobs: --ssl \ --ssl-ca=/root/rust-mysql-simple/tests/ca.crt \ --ssl-cert=/root/rust-mysql-simple/tests/server.crt \ - --ssl-key=/root/rust-mysql-simple/tests/server-key.pem & + --ssl-key=/root/rust-mysql-simple/tests/server-key.pem \ + --secure-auth=OFF \ + $ARG & while ! nc -W 1 localhost 3307 | grep -q -P '.+'; do sleep 1; done displayName: Run MariaDb in Docker - bash: | @@ -206,11 +209,12 @@ jobs: docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" displayName: Install Rust in docker - bash: | - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --features native-tls-tls" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true COMPRESS=true cargo test" - if [[ "10.1" != "$(DB_VERSION)" ]]; then docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --no-default-features --features default-rustls"; fi + if [[ "11.6.2" == "$(DB_VERSION)" ]]; then FEATURES="client_ed25519"; fi + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test --features $FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test --features $FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --features native-tls-tls,$FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true COMPRESS=true cargo test --features $FEATURES" + if [[ "10.1" != "$(DB_VERSION)" ]]; then docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --no-default-features --features default-rustls,$FEATURES"; fi env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:password@127.0.0.1/mysql diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 0debcb04..b3a17889 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -633,6 +633,7 @@ impl Conn { return Err(DriverError::CleartextPluginDisabled.into()); } } + x @ AuthPlugin::Ed25519 => x.gen_data(self.inner.opts.pass(), &self.inner.nonce), x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce), }; @@ -671,6 +672,10 @@ impl Conn { Err(DriverError::CleartextPluginDisabled.into()) } } + AuthPlugin::Ed25519 => { + self.continue_ed25519_auth().await?; + Ok(()) + } AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin { name: String::from_utf8_lossy(name.as_ref()).to_string(), } @@ -693,6 +698,24 @@ impl Conn { Ok(()) } + async fn continue_ed25519_auth(&mut self) -> Result<()> { + let packet = self.read_packet().await?; + match packet.first() { + Some(0x00) => { + // ok packet for empty password + Ok(()) + } + Some(0xfe) if !self.inner.auth_switched => { + let auth_switch_request = ParseBuf(&packet).parse::(())?; + self.perform_auth_switch(auth_switch_request).await + } + _ => Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()), + } + } + async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> { let packet = self.read_packet().await?; match packet.first() { @@ -1547,7 +1570,92 @@ mod test { #[tokio::test] async fn should_change_user() -> super::Result<()> { + /// Whether particular authentication plugin should be tested on the current database. + type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool; + /// Generates `CREATE USER` and `SET PASSWORD` statements + type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec; + + #[allow(clippy::type_complexity)] + const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [ + ( + "mysql_old_password", + |is_mariadb, version| is_mariadb || version < (5, 7, 0), + |is_mariadb, version, pass| { + if is_mariadb { + vec![ + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(), + "SET old_passwords=1".into(), + format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"), + "SET old_passwords=0".into(), + ] + } else if matches!(version, (5, 6, _)) { + vec![ + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(), + format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"), + ] + } else { + vec![ + "CREATE USER '__mats'@'%'".into(), + format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"), + ] + } + }, + ), + ( + "mysql_native_password", + |is_mariadb, version| is_mariadb || version < (8, 4, 0), + |is_mariadb, version, pass| { + if is_mariadb { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')") + ] + } else if version < (8, 0, 0) { + vec![ + format!( + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password" + ), + format!("SET old_passwords = 0"), + format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"), + ] + } else { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'") + ] + } + }, + ), + ( + "caching_sha2_password", + |is_mariadb, version| !is_mariadb && version >= (5, 8, 0), + |_is_mariadb, _version, pass| { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'") + ] + }, + ), + ( + "client_ed25519", + |is_mariadb, version| is_mariadb && version >= (11, 6, 2), + |_is_mariadb, _version, pass| { + vec![format!( + "CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')" + )] + }, + ), + ]; + + fn random_pass() -> String { + let mut rng = rand::thread_rng(); + let mut pass = [0u8; 10]; + pass.try_fill(&mut rng).unwrap(); + + IntoIterator::into_iter(pass) + .map(|x| ((x % (123 - 97)) + 97) as char) + .collect() + } + let mut conn = Conn::new(get_opts()).await?; + assert_eq!( conn.query_first::("SELECT @foo").await?.unwrap(), Value::NULL @@ -1566,86 +1674,52 @@ mod test { Value::NULL ); - let plugins: &[&str] = if !conn.inner.is_mariadb && conn.server_version() >= (5, 8, 0) { - &["mysql_native_password", "caching_sha2_password"] - } else { - &["mysql_native_password"] - }; + for (i, (plugin, should_run, create_statements)) in TEST_MATRIX.iter().enumerate() { + dbg!(plugin); + let is_mariadb = conn.inner.is_mariadb; + let version = conn.server_version(); - for (i, plugin) in plugins.iter().enumerate() { - if *plugin == "mysql_native_password" && conn.server_version() >= (8, 4, 0) { - continue; - } + if should_run(is_mariadb, version) { + let pass = random_pass(); - let mut rng = rand::thread_rng(); - let mut pass = [0u8; 10]; - pass.try_fill(&mut rng).unwrap(); - let pass: String = IntoIterator::into_iter(pass) - .map(|x| ((x % (123 - 97)) + 97) as char) - .collect(); - - let result = conn - .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") - .await; - if matches!(conn.server_version(), (5, 6, _)) && i == 0 { - // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration - drop(result); - } else { - result.unwrap(); - } + let result = conn + .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") + .await; - if conn.inner.is_mariadb || conn.server_version() < (5, 7, 0) { - if matches!(conn.server_version(), (5, 6, _)) { - conn.query_drop("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password") - .await - .unwrap(); - conn.query_drop(format!( - "SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD({})", - Value::from(pass.clone()).as_sql(false) - )) - .await - .unwrap(); + if matches!(version, (5, 6, _)) && i == 0 { + // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration + drop(result); } else { - conn.query_drop("CREATE USER '__mats'@'%'").await.unwrap(); - conn.query_drop(format!( - "SET PASSWORD FOR '__mats'@'%' = PASSWORD({})", - Value::from(pass.clone()).as_sql(false) - )) + result.unwrap(); + } + + for statement in create_statements(is_mariadb, version, &pass) { + conn.query_drop(dbg!(statement)).await.unwrap(); + } + + let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap(); + conn2 + .change_user( + ChangeUserOpts::default() + .with_db_name(None) + .with_user(Some("__mats".into())) + .with_pass(Some(pass)), + ) .await .unwrap(); - } - } else { - conn.query_drop(format!( - "CREATE USER '__mats'@'%' IDENTIFIED WITH {} BY {}", - plugin, - Value::from(pass.clone()).as_sql(false) - )) - .await - .unwrap(); - }; - let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap(); - conn2 - .change_user( - ChangeUserOpts::default() - .with_db_name(None) - .with_user(Some("__mats".into())) - .with_pass(Some(pass)), - ) - .await - .unwrap(); - let (db, user) = conn2 - .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") - .await - .unwrap() - .unwrap(); - assert_eq!(db, None); - assert!(user.starts_with("__mats")); + let (db, user) = conn2 + .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") + .await + .unwrap() + .unwrap(); + assert_eq!(db, None); + assert!(user.starts_with("__mats")); - conn2.disconnect().await.unwrap(); + conn2.disconnect().await.unwrap(); + } } - conn.disconnect().await?; Ok(()) } From b216e74a9fd0c96743c7c42e1b86dbf229f49e13 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Mon, 21 Apr 2025 12:01:00 +0200 Subject: [PATCH 117/130] Replace `crossbeam` with `crossbeam-queue` (#342) --- Cargo.toml | 2 +- src/buffer_pool.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6a016496..4e12f066 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ categories = ["asynchronous", "database"] [dependencies] bytes = "1.4" -crossbeam = "0.8.1" +crossbeam-queue = "0.3" flate2 = { version = "1.0", default-features = false } futures-core = "0.3" futures-util = "0.3" diff --git a/src/buffer_pool.rs b/src/buffer_pool.rs index 7e89cfc4..d9391c6a 100644 --- a/src/buffer_pool.rs +++ b/src/buffer_pool.rs @@ -6,7 +6,7 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use crossbeam::queue::ArrayQueue; +use crossbeam_queue::ArrayQueue; use std::{mem::take, ops::Deref, sync::Arc}; #[derive(Debug)] From ca6e8ea2055eaf0b101dc4675222da6e0881cba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 11 Feb 2025 13:24:43 +0000 Subject: [PATCH 118/130] rand 0.9 --- Cargo.toml | 2 +- src/conn/mod.rs | 7 +++---- src/opts/mod.rs | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 695d88cf..c19bc0c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ lru = "0.14.0" mysql_common = { version = "0.35", default-features = false } pem = "3.0" percent-encoding = "2.1.0" -rand = "0.8.5" +rand = "0.9" serde = "1" serde_json = "1" socket2 = "0.5.2" diff --git a/src/conn/mod.rs b/src/conn/mod.rs index b3a17889..767bcd17 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1310,7 +1310,7 @@ mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; use mysql_common::constants::MAX_PAYLOAD_LEN; - use rand::Fill; + use rand::Rng; use tokio::{io::AsyncWriteExt, net::TcpListener}; use crate::{ @@ -1645,9 +1645,8 @@ mod test { ]; fn random_pass() -> String { - let mut rng = rand::thread_rng(); - let mut pass = [0u8; 10]; - pass.try_fill(&mut rng).unwrap(); + let mut rng = rand::rng(); + let pass: [u8; 10] = rng.gen(); IntoIterator::into_iter(pass) .map(|x| ((x % (123 - 97)) + 97) as char) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index ff860b35..9593df2e 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -451,7 +451,7 @@ impl PoolOpts { pub(crate) fn new_connection_ttl_deadline(&self) -> Option { if let Some(ttl) = self.abs_conn_ttl { let jitter = if let Some(jitter) = self.abs_conn_ttl_jitter { - Duration::from_secs(rand::thread_rng().gen_range(0..=jitter.as_secs())) + Duration::from_secs(rand::rng().random_range(0..=jitter.as_secs())) } else { Duration::ZERO }; From 576d17a01590fbd91bede2d88b0d46e02d7cb221 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 25 Apr 2025 08:28:03 +0300 Subject: [PATCH 119/130] Hide inner `Connection` structure. impl `Queryable` for `Connection`. --- src/connection_like/mod.rs | 114 ++++++++++++++---- src/io/read_packet.rs | 4 +- src/io/write_packet.rs | 2 +- src/lib.rs | 5 +- src/query.rs | 27 ++--- src/queryable/mod.rs | 65 ++++++++-- src/queryable/query_result/mod.rs | 15 ++- .../query_result/result_set_stream.rs | 4 +- src/queryable/transaction.rs | 29 +++-- 9 files changed, 195 insertions(+), 70 deletions(-) diff --git a/src/connection_like/mod.rs b/src/connection_like/mod.rs index b47adfaa..d86a313d 100644 --- a/src/connection_like/mod.rs +++ b/src/connection_like/mod.rs @@ -10,9 +10,9 @@ use futures_util::FutureExt; use crate::{BoxFuture, Pool}; -/// Connection. +/// Inner [`Connection`] representation. #[derive(Debug)] -pub enum Connection<'a, 't: 'a> { +pub(crate) enum ConnectionInner<'a, 't: 'a> { /// Just a connection. Conn(crate::Conn), /// Mutable reference to a connection. @@ -21,21 +21,92 @@ pub enum Connection<'a, 't: 'a> { Tx(&'a mut crate::Transaction<'t>), } +impl std::ops::Deref for ConnectionInner<'_, '_> { + type Target = crate::Conn; + + fn deref(&self) -> &Self::Target { + match self { + ConnectionInner::Conn(ref conn) => conn, + ConnectionInner::ConnMut(conn) => conn, + ConnectionInner::Tx(tx) => tx.0.deref(), + } + } +} + +impl std::ops::DerefMut for ConnectionInner<'_, '_> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + ConnectionInner::Conn(conn) => conn, + ConnectionInner::ConnMut(conn) => conn, + ConnectionInner::Tx(tx) => tx.0.inner.deref_mut(), + } + } +} + +/// Some connection. +/// +/// This could at least be queried. +#[derive(Debug)] +pub struct Connection<'a, 't: 'a> { + pub(crate) inner: ConnectionInner<'a, 't>, +} + +impl Connection<'_, '_> { + #[inline] + pub(crate) fn as_mut(&mut self) -> &mut crate::Conn { + &mut self.inner + } +} + +impl<'a, 't: 'a> Connection<'a, 't> { + /// Borrows a [`Connection`] rather than consuming it. + /// + /// This is useful to allow calling [`Query`] methods while still retaining + /// ownership of the original connection. + /// + /// # Examples + /// + /// ```no_run + /// # use mysql_async::Connection; + /// # use mysql_async::prelude::Query; + /// async fn connection_by_ref(mut connection: Connection<'_, '_>) { + /// // Perform some query + /// "SELECT 1".ignore(connection.by_ref()).await.unwrap(); + /// // Perform another query. + /// // We can only do this because we used `by_ref` earlier. + /// "SELECT 2".ignore(connection).await.unwrap(); + /// } + /// ``` + /// + /// [`Query`]: crate::prelude::Query + pub fn by_ref(&mut self) -> Connection<'_, '_> { + Connection { + inner: ConnectionInner::ConnMut(self.as_mut()), + } + } +} + impl From for Connection<'static, 'static> { fn from(conn: crate::Conn) -> Self { - Connection::Conn(conn) + Self { + inner: ConnectionInner::Conn(conn), + } } } impl<'a> From<&'a mut crate::Conn> for Connection<'a, 'static> { fn from(conn: &'a mut crate::Conn) -> Self { - Connection::ConnMut(conn) + Self { + inner: ConnectionInner::ConnMut(conn), + } } } impl<'a, 't> From<&'a mut crate::Transaction<'t>> for Connection<'a, 't> { fn from(tx: &'a mut crate::Transaction<'t>) -> Self { - Connection::Tx(tx) + Self { + inner: ConnectionInner::Tx(tx), + } } } @@ -43,25 +114,11 @@ impl std::ops::Deref for Connection<'_, '_> { type Target = crate::Conn; fn deref(&self) -> &Self::Target { - match self { - Connection::Conn(ref conn) => conn, - Connection::ConnMut(conn) => conn, - Connection::Tx(tx) => tx.0.deref(), - } - } -} - -impl std::ops::DerefMut for Connection<'_, '_> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - Connection::Conn(conn) => conn, - Connection::ConnMut(conn) => conn, - Connection::Tx(tx) => tx.0.deref_mut(), - } + &self.inner } } -/// Result of `ToConnection::to_connection` call. +/// Result of a [`ToConnection::to_connection`] call. pub enum ToConnectionResult<'a, 't: 'a> { /// Connection is immediately available. Immediate(Connection<'a, 't>), @@ -69,7 +126,22 @@ pub enum ToConnectionResult<'a, 't: 'a> { Mediate(BoxFuture<'a, Connection<'a, 't>>), } +impl<'a, 't: 'a> ToConnectionResult<'a, 't> { + /// Resolves `self` to a connection. + #[inline] + pub async fn resolve(self) -> crate::Result> { + match self { + ToConnectionResult::Immediate(immediate) => Ok(immediate), + ToConnectionResult::Mediate(mediate) => mediate.await, + } + } +} + +/// Everything that can be given in exchange to a connection. +/// +/// Note that you could obtain a `'static` connection by giving away `Conn` or `Pool`. pub trait ToConnection<'a, 't: 'a>: Send { + /// Converts self to a connection. fn to_connection(self) -> ToConnectionResult<'a, 't>; } diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs index 6c4646ee..62cbe76e 100644 --- a/src/io/read_packet.rs +++ b/src/io/read_packet.rs @@ -37,7 +37,7 @@ impl Future for ReadPacket<'_, '_> { type Output = std::result::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let packet_opt = match self.0.stream_mut() { + let packet_opt = match self.0.as_mut().stream_mut() { Ok(stream) => ready!(Pin::new(stream).poll_next(cx)).transpose()?, // `ConnectionClosed` error. Err(_) => None, @@ -45,7 +45,7 @@ impl Future for ReadPacket<'_, '_> { match packet_opt { Some(packet) => { - self.0.touch(); + self.0.as_mut().touch(); Poll::Ready(Ok(packet)) } None => Poll::Ready(Err(Error::new( diff --git a/src/io/write_packet.rs b/src/io/write_packet.rs index 0449edb1..5d06f1d6 100644 --- a/src/io/write_packet.rs +++ b/src/io/write_packet.rs @@ -44,7 +44,7 @@ impl Future for WritePacket<'_, '_> { ref mut data, } = *self; - match conn.stream_mut() { + match conn.as_mut().stream_mut() { Ok(stream) => { if data.is_some() { let codec = Pin::new(stream.codec.as_mut().expect("must be here")); diff --git a/src/lib.rs b/src/lib.rs index d0c15324..4ad9d735 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -553,6 +553,9 @@ pub use self::queryable::stmt::Statement; #[doc(inline)] pub use self::conn::pool::Metrics; +#[doc(inline)] +pub use crate::connection_like::{Connection, ToConnectionResult}; + /// Futures used in this crate pub mod futures { pub use crate::conn::pool::futures::{DisconnectPool, GetConn}; @@ -561,7 +564,7 @@ pub mod futures { /// Traits used in this crate pub mod prelude { #[doc(inline)] - pub use crate::connection_like::{Connection, ToConnection, ToConnectionResult}; + pub use crate::connection_like::ToConnection; #[doc(inline)] pub use crate::local_infile_handler::GlobalHandler; #[doc(inline)] diff --git a/src/query.rs b/src/query.rs index 976c1724..75ea1ca9 100644 --- a/src/query.rs +++ b/src/query.rs @@ -11,7 +11,6 @@ use std::borrow::Cow; use futures_util::FutureExt; use crate::{ - connection_like::ToConnectionResult, from_row, prelude::{FromRow, StatementLike, ToConnection}, tracing_utils::LevelInfo, @@ -217,11 +216,8 @@ impl Query for Q { C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; - conn.raw_query::<'_, _, LevelInfo>(self).await?; + let mut conn = conn.to_connection().resolve().await?; + conn.as_mut().raw_query::<'_, _, LevelInfo>(self).await?; Ok(QueryResult::new(conn)) } .boxed() @@ -264,14 +260,12 @@ where C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; + let mut conn = conn.to_connection().resolve().await?; - let statement = conn.get_statement(self.query).await?; + let statement = conn.as_mut().get_statement(self.query).await?; - conn.execute_statement(&statement, self.params.into()) + conn.as_mut() + .execute_statement(&statement, self.params.into()) .await?; Ok(QueryResult::new(conn)) @@ -324,15 +318,12 @@ where C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; + let mut conn = conn.to_connection().resolve().await?; - let statement = conn.get_statement(self.query).await?; + let statement = conn.as_mut().get_statement(self.query).await?; for params in self.params { - conn.execute_statement(&statement, params).await?; + conn.as_mut().execute_statement(&statement, params).await?; } Ok(()) diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 870abe43..5a6d93aa 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -31,7 +31,7 @@ use crate::{ query::AsQuery, queryable::query_result::ResultSetMeta, tracing_utils::{LevelInfo, LevelTrace, TracingLevel}, - BoxFuture, Column, Conn, Params, ResultSetStream, Row, + BoxFuture, Column, Conn, Connection, Params, ResultSetStream, Row, }; pub mod query_result; @@ -537,7 +537,7 @@ impl Queryable for Conn { impl Queryable for Transaction<'_> { fn ping(&mut self) -> BoxFuture<'_, ()> { - self.0.ping() + self.0.as_mut().ping() } fn query_iter<'a, Q>( @@ -547,18 +547,18 @@ impl Queryable for Transaction<'_> { where Q: AsQuery + 'a, { - self.0.query_iter(query) + self.0.as_mut().query_iter(query) } fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement> where Q: AsQuery + 'a, { - self.0.prep(query) + self.0.as_mut().prep(query) } fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { - self.0.close(stmt) + self.0.as_mut().close(stmt) } fn exec_iter<'a: 's, 's, Q, P>( @@ -570,7 +570,7 @@ impl Queryable for Transaction<'_> { Q: StatementLike + 'a, P: Into, { - self.0.exec_iter(stmt, params) + self.0.as_mut().exec_iter(stmt, params) } fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> @@ -580,7 +580,58 @@ impl Queryable for Transaction<'_> { I::IntoIter: Send, P: Into + Send, { - self.0.exec_batch(stmt, params_iter) + self.0.as_mut().exec_batch(stmt, params_iter) + } +} + +impl<'c, 't: 'c> Queryable for Connection<'c, 't> { + #[inline] + fn ping(&mut self) -> BoxFuture<'_, ()> { + self.as_mut().ping() + } + + #[inline] + fn query_iter<'a, Q>( + &'a mut self, + query: Q, + ) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>> + where + Q: AsQuery + 'a, + { + self.as_mut().query_iter(query) + } + + fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement> + where + Q: AsQuery + 'a, + { + self.as_mut().prep(query) + } + + fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { + self.as_mut().close(stmt) + } + + fn exec_iter<'a: 's, 's, Q, P>( + &'a mut self, + stmt: Q, + params: P, + ) -> BoxFuture<'s, QueryResult<'a, 'static, BinaryProtocol>> + where + Q: StatementLike + 'a, + P: Into, + { + self.as_mut().exec_iter(stmt, params) + } + + fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> + where + S: StatementLike + 'b, + I: IntoIterator + Send + 'b, + I::IntoIter: Send, + P: Into + Send, + { + self.as_mut().exec_batch(stmt, params_iter) } } diff --git a/src/queryable/query_result/mod.rs b/src/queryable/query_result/mod.rs index 94a3de96..a83ebd6a 100644 --- a/src/queryable/query_result/mod.rs +++ b/src/queryable/query_result/mod.rs @@ -124,21 +124,21 @@ where if columns.is_empty() { // Empty, but not yet consumed result set. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; } else { // Not yet consumed non-empty result set. - let packet = match self.conn.read_packet().await { + let packet = match self.conn.as_mut().read_packet().await { Ok(packet) => packet, Err(err) => { // Next row contained an error. No more data will follow. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; return Err(err); } }; if P::is_last_result_set_packet(self.conn.capabilities(), &packet) { // `packet` is a result set terminator. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; } else { // `packet` is a result set row. row = Some(P::read_result_set_row(&packet, columns)?); @@ -154,7 +154,10 @@ where async fn next_set(&mut self) -> crate::Result { if self.conn.more_results_exists() { // More data will follow. - self.conn.routine(NextSetRoutine::

::new()).await?; + self.conn + .as_mut() + .routine(NextSetRoutine::

::new()) + .await?; } Ok(self.conn.has_pending_result()) } @@ -190,7 +193,7 @@ where #[doc(hidden)] pub async fn next(&mut self) -> Result> { loop { - match self.conn.use_pending_result()?.cloned() { + match self.conn.as_mut().use_pending_result()?.cloned() { Some(PendingResult::Pending(meta)) => return self.next_row_or_next_set(meta).await, Some(PendingResult::Taken(meta)) => self.skip_taken(meta).await?, None => return Ok(None), diff --git a/src/queryable/query_result/result_set_stream.rs b/src/queryable/query_result/result_set_stream.rs index d9eecc7d..b779e93c 100644 --- a/src/queryable/query_result/result_set_stream.rs +++ b/src/queryable/query_result/result_set_stream.rs @@ -189,7 +189,7 @@ where async fn setup_stream( &mut self, ) -> crate::Result>, Arc<[Column]>)>> { - match self.conn.use_pending_result()? { + match self.conn.as_mut().use_pending_result()? { Some(PendingResult::Taken(meta)) => { let meta = (*meta).clone(); self.skip_taken(meta).await?; @@ -199,7 +199,7 @@ where } let ok_packet = self.conn.last_ok_packet().cloned(); - let columns = match self.conn.take_pending_result()? { + let columns = match self.conn.as_mut().take_pending_result()? { Some(meta) => meta.columns().clone(), None => return Ok(None), }; diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index 7346d393..b857f1a7 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -143,7 +143,7 @@ impl<'a> Transaction<'a> { let mut conn = conn.into(); - conn.clean_dirty().await?; + conn.as_mut().clean_dirty().await?; if conn.get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); @@ -155,33 +155,38 @@ impl<'a> Transaction<'a> { if let Some(isolation_level) = isolation_level { let query = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation_level); - conn.query_drop(query).await?; + conn.as_mut().query_drop(query).await?; } if let Some(readonly) = readonly { if readonly { - conn.query_drop("SET TRANSACTION READ ONLY").await?; + conn.as_mut() + .query_drop("SET TRANSACTION READ ONLY") + .await?; } else { - conn.query_drop("SET TRANSACTION READ WRITE").await?; + conn.as_mut() + .query_drop("SET TRANSACTION READ WRITE") + .await?; } } if consistent_snapshot { - conn.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT") + conn.as_mut() + .query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT") .await? } else { - conn.query_drop("START TRANSACTION").await? + conn.as_mut().query_drop("START TRANSACTION").await? }; - conn.set_tx_status(TxStatus::InTransaction); + conn.as_mut().set_tx_status(TxStatus::InTransaction); Ok(Transaction(conn)) } /// Performs `COMMIT` query or returns an error async fn try_commit(&mut self) -> Result<()> { - let result = self.0.query_iter("COMMIT").await?; + let result = self.0.as_mut().query_iter("COMMIT").await?; result.drop_result().await?; - self.0.set_tx_status(TxStatus::None); + self.0.as_mut().set_tx_status(TxStatus::None); Ok(()) } @@ -190,7 +195,7 @@ impl<'a> Transaction<'a> { match self.try_commit().await { Ok(..) => Ok(()), Err(e) => { - self.0.rollback_transaction().await.unwrap_or(()); + self.0.as_mut().rollback_transaction().await.unwrap_or(()); Err(e) } } @@ -198,7 +203,7 @@ impl<'a> Transaction<'a> { /// Performs `ROLLBACK` query. pub async fn rollback(mut self) -> Result<()> { - self.0.rollback_transaction().await + self.0.as_mut().rollback_transaction().await } } @@ -213,7 +218,7 @@ impl Deref for Transaction<'_> { impl Drop for Transaction<'_> { fn drop(&mut self) { if self.0.get_tx_status() == TxStatus::InTransaction { - self.0.set_tx_status(TxStatus::RequiresRollback); + self.0.as_mut().set_tx_status(TxStatus::RequiresRollback); } } } From cd8458dfcab884cc5aa4f03b13e93340bbdb05d4 Mon Sep 17 00:00:00 2001 From: Paolo Barbolini Date: Wed, 23 Apr 2025 10:07:01 +0200 Subject: [PATCH 120/130] Release `mysql_async` v0.36.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c19bc0c4..816816f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT OR Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.35.1" +version = "0.36.0" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From fca540cabab1e3af457cba34a41f8749daa1a25f Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 24 Apr 2025 12:54:50 +0300 Subject: [PATCH 121/130] Require pool constraints `max` bound to be greater than `0` --- src/opts/mod.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index 9593df2e..a938a0d9 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -43,7 +43,8 @@ pub const DEFAULT_POOL_CONSTRAINTS: PoolConstraints = PoolConstraints { min: 10, // const_assert!( _DEFAULT_POOL_CONSTRAINTS_ARE_CORRECT, - DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max, + DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max + && 0 < DEFAULT_POOL_CONSTRAINTS.max, ); /// Each connection will cache up to this number of statements by default. @@ -1210,8 +1211,8 @@ impl PoolConstraints { /// assert_eq!(opts.pool_opts().constraints(), PoolConstraints::new(0, 151).unwrap()); /// # Ok(()) } /// ``` - pub fn new(min: usize, max: usize) -> Option { - if min <= max { + pub const fn new(min: usize, max: usize) -> Option { + if min <= max && 0 < max { Some(PoolConstraints { min, max }) } else { None From 24bab52940a51268dc5af96512904f27a145a87c Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 24 Apr 2025 12:55:06 +0300 Subject: [PATCH 122/130] Fix deprecation warning with rand 0.9 --- src/conn/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 767bcd17..91889412 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1646,7 +1646,7 @@ mod test { fn random_pass() -> String { let mut rng = rand::rng(); - let pass: [u8; 10] = rng.gen(); + let pass: [u8; 10] = rng.random(); IntoIterator::into_iter(pass) .map(|x| ((x % (123 - 97)) + 97) as char) From 576e7d55fdfc3f994622877c1ab15cc6f5591f47 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Thu, 24 Apr 2025 13:01:35 +0300 Subject: [PATCH 123/130] Bump `mysql_common` to v0.35.3 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 816816f8..c684f78b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.14.0" -mysql_common = { version = "0.35", default-features = false } +mysql_common = { version = "0.35.3", default-features = false } pem = "3.0" percent-encoding = "2.1.0" rand = "0.9" From 7656f438968ccb0a5abf26f2d076e8c52db09c1a Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Fri, 25 Apr 2025 11:25:54 +0300 Subject: [PATCH 124/130] Properly update `active_wait_requests` metric (fix #335) --- src/conn/pool/mod.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 9106764a..53688d9b 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -110,6 +110,7 @@ struct Waitlist { } impl Waitlist { + /// Returns `true` if pushed. fn push(&mut self, waker: Waker, queue_id: QueueId) -> bool { // The documentation of Future::poll says: // Note that on multiple calls to poll, only the Waker from @@ -122,10 +123,9 @@ impl Waitlist { // // This means we have to remove first to have the most recent // waker in the queue. - self.remove(queue_id); - self.queue - .push(QueuedWaker { queue_id, waker }, queue_id) - .is_none() + let occupied = self.remove(queue_id); + self.queue.push(QueuedWaker { queue_id, waker }, queue_id); + !occupied } fn pop(&mut self) -> Option { @@ -135,6 +135,7 @@ impl Waitlist { } } + /// Returns `true` if removed. fn remove(&mut self, id: QueueId) -> bool { self.queue.remove(&id).is_some() } @@ -1016,6 +1017,13 @@ mod test { drop(only_conn); assert_eq!(0, pool.inner.exchange.lock().unwrap().waiting.queue.len()); + // metrics should catch up with waiting queue (see #335) + assert_eq!( + 0, + pool.metrics() + .active_wait_requests + .load(std::sync::atomic::Ordering::Relaxed) + ); } #[tokio::test] From 2c347b0a0d7c860f22f0febf2eb52bb0bb6f0be4 Mon Sep 17 00:00:00 2001 From: Anatoly I Date: Fri, 25 Apr 2025 11:37:40 +0300 Subject: [PATCH 125/130] Apply suggestions from code review Co-authored-by: Paolo Barbolini --- src/opts/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/opts/mod.rs b/src/opts/mod.rs index a938a0d9..1923756a 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -1212,7 +1212,7 @@ impl PoolConstraints { /// # Ok(()) } /// ``` pub const fn new(min: usize, max: usize) -> Option { - if min <= max && 0 < max { + if min <= max && max > 0 { Some(PoolConstraints { min, max }) } else { None From 0fe65d4cc8ccfa11538b20b792bd1d0dbb813d81 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 5 May 2025 13:03:03 +0300 Subject: [PATCH 126/130] Fix `binlog` feature build --- src/conn/binlog_stream/mod.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs index 60aa6659..b5433f77 100644 --- a/src/conn/binlog_stream/mod.rs +++ b/src/conn/binlog_stream/mod.rs @@ -24,7 +24,7 @@ use std::{ task::{Context, Poll}, }; -use crate::{connection_like::Connection, queryable::Queryable}; +use crate::{connection_like::ConnectionInner, queryable::Queryable}; use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result}; use self::request::BinlogStreamRequest; @@ -94,10 +94,10 @@ impl BinlogStream { /// Closes the stream's `Conn`. Additionally, the connection is dropped, so its associated /// pool (if any) will regain a connection slot. pub async fn close(self) -> Result<()> { - match self.read_packet.0 { + match self.read_packet.0.inner { // `close_conn` requires ownership of `Conn`. That's okay, because // `BinLogStream`'s connection is always owned. - Connection::Conn(conn) => { + ConnectionInner::Conn(conn) => { if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await { // If the binlog was requested with the flag BINLOG_DUMP_NON_BLOCK, // the connection's file handler will already have been closed (EOF). @@ -106,8 +106,8 @@ impl BinlogStream { } } } - Connection::ConnMut(_) => {} - Connection::Tx(_) => {} + ConnectionInner::ConnMut(_) => {} + ConnectionInner::Tx(_) => {} } Ok(()) From d2ea1b68ff46d8b85d958712e8808489bd6a9dcf Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 5 May 2025 13:32:55 +0300 Subject: [PATCH 127/130] Update azure-pipelines.yml --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index db33fd28..276530fc 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -40,7 +40,7 @@ jobs: displayName: cargo fmt - bash: | cargo +nightly build -Zfeatures=dev_dep - SSL=false COMPRESS=false cargo test + SSL=false COMPRESS=false cargo test --features binlog,derive,chrono,time,bigdecimal,rust_decimal,frunk,client_ed25519,tracing SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test SSL=true COMPRESS=true cargo test --features rustls-tls,ring From cafaf30009d6aa63bc7ba7fd42a27f1607c48197 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Mon, 5 May 2025 13:33:39 +0300 Subject: [PATCH 128/130] Bump micro version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c684f78b..ae67e4fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "MIT OR Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.36.0" +version = "0.36.1" exclude = ["test/*"] edition = "2021" categories = ["asynchronous", "database"] From 74d451a3f9ede63d8439fc65d33fd15db9d1ed10 Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 6 May 2025 10:09:39 +0300 Subject: [PATCH 129/130] Bump mysql_common version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ae67e4fb..98d09bac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ futures-util = "0.3" futures-sink = "0.3" keyed_priority_queue = "0.4" lru = "0.14.0" -mysql_common = { version = "0.35.3", default-features = false } +mysql_common = { version = "0.35.4", default-features = false } pem = "3.0" percent-encoding = "2.1.0" rand = "0.9" From 75c51d9b8015d7a2b15d820da061f70cc1421dab Mon Sep 17 00:00:00 2001 From: Anatoly Ikorsky Date: Tue, 6 May 2025 10:10:30 +0300 Subject: [PATCH 130/130] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d9642034..27d67128 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target Cargo.lock .idea mysql_async.iml +.vscode