diff --git a/.gitignore b/.gitignore index 50cfd76f..e41e27fb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ Cargo.lock .idea mysql_async.iml .direnv +.vscode diff --git a/Cargo.toml b/Cargo.toml index 1c4c8de2..9db4d144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,45 +3,54 @@ authors = ["blackbeam "] description = "Tokio based asynchronous MySql client library." documentation = "https://docs.rs/mysql_async" keywords = ["mysql", "database", "asynchronous", "async"] -license = "MIT/Apache-2.0" +license = "MIT OR Apache-2.0" name = "mysql_async" readme = "README.md" repository = "https://github.com/blackbeam/mysql_async" -version = "0.31.3" +version = "0.36.1" exclude = ["test/*"] -edition = "2018" +edition = "2021" categories = ["asynchronous", "database"] [dependencies] -bytes = "1.0" -crossbeam = "0.8.1" +bytes = "1.4" +crossbeam-queue = "0.3" flate2 = { version = "1.0", default-features = false } futures-core = "0.3" futures-util = "0.3" futures-sink = "0.3" -lazy_static = "1" -lru = "0.8.1" -mio = { version = "0.8.0", features = ["os-poll", "net"] } -mysql_common = { version = "0.29.2", default-features = false } -once_cell = "1.7.2" -pem = "1.0.1" +keyed_priority_queue = "0.4" +lazy_static = "1.5" +lru = "0.14.0" +mysql_common = { version = "0.35.4", default-features = false } +pem = "3.0" percent-encoding = "2.1.0" -pin-project = "1.0.2" -priority-queue = "1" +rand = "0.9" serde = "1" serde_json = "1" -socket2 = "0.4.2" -thiserror = "1.0.4" -tokio = { version = "1.0", features = ["io-util", "fs", "net", "time", "rt"] } +socket2 = "0.5.2" +thiserror = "2" +tokio = { version = "1.0", features = [ + "io-util", + "fs", + "net", + "time", + "rt", + "sync", +] } tokio-util = { version = "0.7.2", features = ["codec", "io"] } -tracing = { version = "0.1.37", default-features = false, features = ["attributes"], optional = true } -twox-hash = "1" +tracing = { version = "0.1.37", default-features = false, features = [ + "attributes", +], optional = true } +twox-hash = { version = "2", default-features = false, features = ["xxhash64"] } url = "2.1" regex = "1.10.3" +hdrhistogram = { version = "7.5", optional = true } lexical = "6.1.0" [dependencies.tokio-rustls] -version = "0.23.4" +version = "0.26" +default-features = false optional = true [dependencies.tokio-native-tls] @@ -53,56 +62,66 @@ version = "0.2" optional = true [dependencies.rustls] -version = "0.20.0" -features = ["dangerous_configuration"] +version = "0.23" +default-features = false +features = ["std"] optional = true [dependencies.rustls-pemfile] -version = "1.0.1" -optional = true - -[dependencies.webpki] -version = "0.22.0" +version = "2.1.0" optional = true [dependencies.webpki-roots] -version = "0.22.1" +version = "0.26.1" optional = true [dev-dependencies] +waker-fn = "1" tempfile = "3.1.0" -socket2 = { version = "0.4.0", features = ["all"] } +socket2 = { version = "0.5.2", features = ["all"] } tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] } -rand = "0.8.0" [features] -default = [ - "flate2/zlib", - "mysql_common/bigdecimal03", - "mysql_common/rust_decimal", - "mysql_common/time03", - "mysql_common/uuid", - "mysql_common/frunk", - "native-tls-tls", -] -default-rustls = [ - "flate2/zlib", - "mysql_common/bigdecimal03", - "mysql_common/rust_decimal", - "mysql_common/time03", - "mysql_common/uuid", - "mysql_common/frunk", +default = ["flate2/zlib", "derive"] + +default-rustls = ["default-rustls-no-provider", "aws-lc-rs"] + +default-rustls-ring = ["default-rustls-no-provider", "ring"] + +default-rustls-no-provider = [ + "flate2/rust_backend", + "derive", "rustls-tls", + "tls12", ] + +# minimal feature set with system flate2 impl minimal = ["flate2/zlib"] +# minimal feature set with rust flate2 impl +minimal-rust = ["flate2/rust_backend"] + +# native-tls based TLS support native-tls-tls = ["native-tls", "tokio-native-tls"] -rustls-tls = [ - "rustls", - "tokio-rustls", - "webpki", - "webpki-roots", - "rustls-pemfile", -] + +# rustls based TLS support +rustls-tls = ["rustls", "tokio-rustls", "webpki-roots", "rustls-pemfile"] + +aws-lc-rs = ["rustls/aws_lc_rs", "tokio-rustls/aws_lc_rs"] +ring = ["rustls/ring", "tokio-rustls/ring"] +tls12 = ["rustls/tls12", "tokio-rustls/tls12"] + +binlog = ["mysql_common/binlog"] + +# mysql_common features +derive = ["mysql_common/derive"] +chrono = ["mysql_common/chrono"] +time = ["mysql_common/time"] +bigdecimal = ["mysql_common/bigdecimal"] +rust_decimal = ["mysql_common/rust_decimal"] +frunk = ["mysql_common/frunk"] +client_ed25519 = ["mysql_common/client_ed25519"] + +# other features tracing = ["dep:tracing"] nightly = [] vendored-openssl = ["tokio-native-tls/vendored", "native-tls/vendored"] diff --git a/README.md b/README.md index 215bb4b4..2d2bec3c 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ mysql_async = "" ## Crate Features -Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] -as well as `native-tls`-based TLS support. +By default there are only two features enabled: + +* `flate2/zlib` — choosing flate2 backend is mandatory +* `derive` — see ["Derive Macros" section in `mysql_common` docs][mysqlcommonderive] ### List Of Features @@ -37,25 +39,18 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", default-features = false, features = ["minimal"]} ``` - **Note:* it is possible to use another `flate2` backend by directly choosing it: +* `minimal-rust` - same as `minimal` but with rust-based flate2 backend. Enables: - ```toml - [dependencies] - mysql_async = { version = "*", default-features = false } - flate2 = { version = "*", default-features = false, features = ["rust_backend"] } - ``` + - `flate2/rust_backend` -* `default` – enables the following set of crate's and dependencies' features: +* `default` – enables the following set of features: - - `native-tls-tls` - - `flate2/zlib" - - `mysql_common/bigdecimal03` - - `mysql_common/rust_decimal` - - `mysql_common/time03` - - `mysql_common/uuid` - - `mysql_common/frunk` + - `flate2/zlib` + - `derive` -* `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. +* `default-rustls` – default set of features with TLS via `rustls/aws-lc-rs` + +* `default-rustls-ring` – default set of features with TLS via `rustls/ring` **Example:** @@ -64,21 +59,22 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } ``` -* `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ +* `native-tls-tls` – enables TLS via `native-tls` **Example:** ```toml [dependencies] - mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } + mysql_async = { version = "*", default-features = false, features = ["minimal", "native-tls-tls"] } -* `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ +* `rustls-tls` - enables rustls TLS backend with no provider. You should enable one + of existing providers using `aws-lc-rs` or `ring` features: **Example:** ```toml [dependencies] - mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } + mysql_async = { version = "*", default-features = false, features = ["minimal-rust", "rustls-tls", "ring"] } * `tracing` – enables instrumentation via `tracing` package. @@ -94,7 +90,21 @@ as well as `native-tls`-based TLS support. mysql_async = { version = "*", features = ["tracing"] } ``` +* `binlog` - enables binlog-related functionality. Enables: + + - `mysql_common/binlog" + +#### Proxied features (see [`mysql_common`` fatures][myslqcommonfeatures]) + +* `derive` – enables `mysql_common/derive` feature +* `chrono` = enables `mysql_common/chrono` feature +* `time` = enables `mysql_common/time` feature +* `bigdecimal` = enables `mysql_common/bigdecimal` feature +* `rust_decimal` = enables `mysql_common/rust_decimal` feature +* `frunk` = enables `mysql_common/frunk` feature + [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features +[mysqlcommonderive]: https://github.com/blackbeam/rust_mysql_common?tab=readme-ov-file#derive-macros ## TLS/SSL Support @@ -190,7 +200,7 @@ Please note: * [`Pool`] is a smart pointer – each clone will point to the same pool instance. * [`Pool`] is `Send + Sync + 'static` – feel free to pass it around. * use [`Pool::disconnect`] to gracefuly close the pool. -* [`Pool::new`] is lazy and won't assert server availability. +* ⚠️ [`Pool::new`] is lazy and won't assert server availability. ## Transaction diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 359eb6cb..276530fc 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -40,11 +40,16 @@ jobs: displayName: cargo fmt - bash: | cargo +nightly build -Zfeatures=dev_dep - SSL=false COMPRESS=false cargo test - SSL=true COMPRESS=false cargo test + SSL=false COMPRESS=false cargo test --features binlog,derive,chrono,time,bigdecimal,rust_decimal,frunk,client_ed25519,tracing + SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test - SSL=true COMPRESS=true cargo test - SSL=true COMPRESS=false cargo test --no-default-features --features default-rustls + SSL=true COMPRESS=true cargo test --features rustls-tls,ring + + SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls + SSL=true COMPRESS=false cargo check --no-default-features --features default-rustls-ring + SSL=true COMPRESS=false cargo check --no-default-features --features minimal + SSL=true COMPRESS=false cargo check --no-default-features --features minimal-rust + SSL=true COMPRESS=false cargo check --no-default-features --features minimal,tracing env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:root@127.0.0.1:3306/mysql @@ -79,9 +84,8 @@ jobs: displayName: Install Rust (Windows) - bash: | SSL=false COMPRESS=false cargo test - SSL=true COMPRESS=false cargo test + SSL=true COMPRESS=false cargo test --features native-tls-tls SSL=false COMPRESS=true cargo test - SSL=true COMPRESS=true cargo test env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:password@127.0.0.1/mysql @@ -93,10 +97,16 @@ jobs: strategy: maxParallel: 10 matrix: + v91: + DB_VERSION: "9.1" + v90: + DB_VERSION: "9.0" + v84: + DB_VERSION: "8.4" v80: - DB_VERSION: "8-debian" + DB_VERSION: "8.0-debian" v57: - DB_VERSION: "5-debian" + DB_VERSION: "5.7-debian" v56: DB_VERSION: "5.6" steps: @@ -114,17 +124,28 @@ jobs: displayName: Run MySql in Docker - bash: | docker exec container bash -l -c "mysql -uroot -ppassword -e \"SET old_passwords = 1; GRANT ALL PRIVILEGES ON *.* TO 'root2'@'%' IDENTIFIED WITH mysql_old_password AS 'password'; SET PASSWORD FOR 'root2'@'%' = OLD_PASSWORD('password')\""; + docker exec container bash -l -c "echo 'deb [trusted=yes] http://archive.debian.org/debian/ stretch main non-free contrib' > /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb-src [trusted=yes] http://archive.debian.org/debian/ stretch main non-free contrib ' >> /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb [trusted=yes] http://archive.debian.org/debian-security/ stretch/updates main non-free contrib' >> /etc/apt/sources.list" + docker exec container bash -l -c "echo 'deb [trusted=yes] http://repo.mysql.com/apt/debian/ stretch mysql-5.6' > /etc/apt/sources.list.d/mysql.list" condition: eq(variables['DB_VERSION'], '5.6') - bash: | - docker exec container bash -l -c "apt-get update" + docker exec container bash -l -c "apt-get --allow-unauthenticated -y update" docker exec container bash -l -c "apt-get install -y curl clang libssl-dev pkg-config build-essential" docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" - displayName: Install Rust in docker + displayName: Install Rust in docker (Debian) + condition: or(eq(variables['DB_VERSION'], '5.6'), eq(variables['DB_VERSION'], '5.7-debian'), eq(variables['DB_VERSION'], '8.0-debian')) + - bash: | + docker exec container bash -l -c "microdnf install dnf" + docker exec container bash -l -c "dnf group install \"Development Tools\"" + docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" + displayName: Install Rust in docker (RedHat) + condition: not(or(eq(variables['DB_VERSION'], '5.6'), eq(variables['DB_VERSION'], '5.7-debian'), eq(variables['DB_VERSION'], '8.0-debian'))) - bash: | if [[ "5.6" != "$(DB_VERSION)" ]]; then SSL=true; else DATABASE_URL="mysql://root2:password@127.0.0.1/mysql?secure_auth=false"; fi docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL cargo test --features native-tls-tls" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL COMPRESS=true cargo test" docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=$SSL COMPRESS=true cargo test --no-default-features --features default-rustls" env: @@ -138,20 +159,16 @@ jobs: strategy: maxParallel: 10 matrix: - v107: - DB_VERSION: "10.7" - v106: - DB_VERSION: "10.6" - v105: - DB_VERSION: "10.5" - v104: - DB_VERSION: "10.4" - v103: - DB_VERSION: "10.3" - v102: - DB_VERSION: "10.2" - v101: - DB_VERSION: "10.1" + v1162: + DB_VERSION: "11.6.2" + v1152: + DB_VERSION: "11.5.2" + v1144: + DB_VERSION: "11.4.2" + v113: + DB_VERSION: "11.3.2" + v1011: + DB_VERSION: "10.11.10" steps: - bash: | sudo apt-get update @@ -166,6 +183,7 @@ jobs: git checkout 901a7de displayName: Clone rust-mysql-simple (for ssl certs) - bash: | + if [[ "11.6.2" == "$(DB_VERSION)" ]]; then ARG=" --plugin-load-add=auth_ed25519"; fi docker run --rm -d \ --name container \ -v `pwd`:/root \ @@ -180,7 +198,9 @@ jobs: --ssl \ --ssl-ca=/root/rust-mysql-simple/tests/ca.crt \ --ssl-cert=/root/rust-mysql-simple/tests/server.crt \ - --ssl-key=/root/rust-mysql-simple/tests/server-key.pem & + --ssl-key=/root/rust-mysql-simple/tests/server-key.pem \ + --secure-auth=OFF \ + $ARG & while ! nc -W 1 localhost 3307 | grep -q -P '.+'; do sleep 1; done displayName: Run MariaDb in Docker - bash: | @@ -189,11 +209,12 @@ jobs: docker exec container bash -l -c "curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable" displayName: Install Rust in docker - bash: | - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test" - docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true COMPRESS=true cargo test" - if [[ "10.1" != "$(DB_VERSION)" ]]; then docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --no-default-features --features default-rustls"; fi + if [[ "11.6.2" == "$(DB_VERSION)" ]]; then FEATURES="client_ed25519"; fi + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL cargo test --features $FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL COMPRESS=true cargo test --features $FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --features native-tls-tls,$FEATURES" + docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true COMPRESS=true cargo test --features $FEATURES" + if [[ "10.1" != "$(DB_VERSION)" ]]; then docker exec container bash -l -c "cd \$HOME && DATABASE_URL=$DATABASE_URL SSL=true cargo test --no-default-features --features default-rustls,$FEATURES"; fi env: RUST_BACKTRACE: 1 DATABASE_URL: mysql://root:password@127.0.0.1/mysql @@ -204,10 +225,14 @@ jobs: vmImage: "ubuntu-latest" strategy: matrix: - v5.3.0: - DB_VERSION: "v5.3.0" - v5.0.6: - DB_VERSION: "v5.0.6" + v8.5.0: + DB_VERSION: "v8.5.0" + v7.6.0: + DB_VERSION: "v7.6.0" + v6.6.0: + DB_VERSION: "v6.6.0" + v5.4.3: + DB_VERSION: "v5.4.3" steps: - bash: | curl --proto '=https' --tlsv1.2 -sSf https://tiup-mirrors.pingcap.com/install.sh | sh diff --git a/src/buffer_pool.rs b/src/buffer_pool.rs index 03b1a4cc..d9391c6a 100644 --- a/src/buffer_pool.rs +++ b/src/buffer_pool.rs @@ -6,8 +6,8 @@ // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. -use crossbeam::queue::ArrayQueue; -use std::{mem::replace, ops::Deref, sync::Arc}; +use crossbeam_queue::ArrayQueue; +use std::{mem::take, ops::Deref, sync::Arc}; #[derive(Debug)] pub struct BufferPool { @@ -93,6 +93,6 @@ impl Deref for PooledBuf { impl Drop for PooledBuf { fn drop(&mut self) { - self.1.put(replace(&mut self.0, vec![])) + self.1.put(take(&mut self.0)) } } diff --git a/src/conn/binlog_stream.rs b/src/conn/binlog_stream.rs deleted file mode 100644 index 6aca3305..00000000 --- a/src/conn/binlog_stream.rs +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2020 Anatoly Ikorsky -// -// Licensed under the Apache License, Version 2.0 -// or the MIT -// license , at your -// option. All files in the project carrying such notice may not be copied, -// modified, or distributed except according to those terms. - -use futures_core::ready; -use mysql_common::{ - binlog::{ - consts::BinlogVersion::Version4, - events::{Event, TableMapEvent}, - EventStreamReader, - }, - io::ParseBuf, - packets::{ErrPacket, NetworkStreamTerminator, OkPacketDeserializer}, -}; - -use std::{ - future::Future, - io::ErrorKind, - pin::Pin, - task::{Context, Poll}, -}; - -use crate::connection_like::Connection; -use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result}; - -/// Binlog event stream. -/// -/// Stream initialization is lazy, i.e. binlog won't be requested until this stream is polled. -pub struct BinlogStream { - read_packet: ReadPacket<'static, 'static>, - esr: EventStreamReader, -} - -impl BinlogStream { - /// `conn` is a `Conn` with `request_binlog` executed on it. - pub(super) fn new(conn: Conn) -> Self { - BinlogStream { - read_packet: ReadPacket::new(conn), - esr: EventStreamReader::new(Version4), - } - } - - /// Returns a table map event for the given table id. - pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> { - self.esr.get_tme(table_id) - } - - /// Closes the stream's `Conn`. Additionally, the connection is dropped, so its associated - /// pool (if any) will regain a connection slot. - pub async fn close(self) -> Result<()> { - match self.read_packet.0 { - // `close_conn` requires ownership of `Conn`. That's okay, because - // `BinLogStream`'s connection is always owned. - Connection::Conn(conn) => { - if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await { - // If the binlog was requested with the flag BINLOG_DUMP_NON_BLOCK, - // the connection's file handler will already have been closed (EOF). - if error.kind() == ErrorKind::BrokenPipe { - return Ok(()); - } - } - } - Connection::ConnMut(_) => {} - Connection::Tx(_) => {} - } - - Ok(()) - } -} - -impl futures_core::stream::Stream for BinlogStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { - Ok(packet) => packet, - Err(err) => return Poll::Ready(Some(Err(err.into()))), - }; - - let first_byte = packet.get(0).copied(); - - if first_byte == Some(255) { - if let Ok(ErrPacket::Error(err)) = - ParseBuf(&*packet).parse(self.read_packet.conn_ref().capabilities()) - { - return Poll::Ready(Some(Err(From::from(err)))); - } - } - - if first_byte == Some(254) && packet.len() < 8 { - if ParseBuf(&*packet) - .parse::>( - self.read_packet.conn_ref().capabilities(), - ) - .is_ok() - { - return Poll::Ready(None); - } - } - - if first_byte == Some(0) { - let event_data = &packet[1..]; - match self.esr.read(event_data) { - Ok(event) => { - return Poll::Ready(Some(Ok(event))); - } - Err(err) => return Poll::Ready(Some(Err(err.into()))), - } - } else { - return Poll::Ready(Some(Err(DriverError::UnexpectedPacket { - payload: packet.to_vec(), - } - .into()))); - } - } -} diff --git a/src/conn/binlog_stream/mod.rs b/src/conn/binlog_stream/mod.rs new file mode 100644 index 00000000..b5433f77 --- /dev/null +++ b/src/conn/binlog_stream/mod.rs @@ -0,0 +1,377 @@ +// Copyright (c) 2020 Anatoly Ikorsky +// +// Licensed under the Apache License, Version 2.0 +// or the MIT +// license , at your +// option. All files in the project carrying such notice may not be copied, +// modified, or distributed except according to those terms. + +use futures_core::ready; +use mysql_common::{ + binlog::{ + consts::{BinlogVersion::Version4, EventType}, + events::{Event, TableMapEvent, TransactionPayloadEvent}, + EventStreamReader, + }, + io::ParseBuf, + packets::{ComRegisterSlave, ErrPacket, NetworkStreamTerminator, OkPacketDeserializer}, +}; + +use std::{ + future::Future, + io::{Cursor, ErrorKind}, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{connection_like::ConnectionInner, queryable::Queryable}; +use crate::{error::DriverError, io::ReadPacket, Conn, Error, IoError, Result}; + +use self::request::BinlogStreamRequest; + +pub mod request; + +impl super::Conn { + /// Turns this connection into a binlog stream. + /// + /// You can use SHOW BINARY LOGS to get the current logfile and position from the master. + /// If the request’s filename is empty, the server will send the binlog-stream of the first known binlog. + pub async fn get_binlog_stream( + mut self, + request: BinlogStreamRequest<'_>, + ) -> Result { + self.request_binlog(request).await?; + + Ok(BinlogStream::new(self)) + } + + async fn register_as_slave( + &mut self, + com_register_slave: ComRegisterSlave<'_>, + ) -> crate::Result<()> { + self.query_drop("SET @master_binlog_checksum='ALL'").await?; + self.write_command(&com_register_slave).await?; + + // Server will respond with OK. + self.read_packet().await?; + + Ok(()) + } + + async fn request_binlog(&mut self, request: BinlogStreamRequest<'_>) -> crate::Result<()> { + self.register_as_slave(request.register_slave).await?; + self.write_command(&request.binlog_request.as_cmd()).await?; + Ok(()) + } +} + +/// Binlog event stream. +/// +/// Stream initialization is lazy, i.e. binlog won't be requested until this stream is polled. +pub struct BinlogStream { + read_packet: ReadPacket<'static, 'static>, + esr: EventStreamReader, + // TODO: Use 'static reader here (requires impl on the mysql_common side). + /// Uncompressed Transaction_payload_event we are iterating over (if any). + tpe: Option>>, +} + +impl BinlogStream { + /// `conn` is a `Conn` with `request_binlog` executed on it. + pub(super) fn new(conn: Conn) -> Self { + BinlogStream { + read_packet: ReadPacket::new(conn), + esr: EventStreamReader::new(Version4), + tpe: None, + } + } + + /// Returns a table map event for the given table id. + pub fn get_tme(&self, table_id: u64) -> Option<&TableMapEvent<'static>> { + self.esr.get_tme(table_id) + } + + /// Closes the stream's `Conn`. Additionally, the connection is dropped, so its associated + /// pool (if any) will regain a connection slot. + pub async fn close(self) -> Result<()> { + match self.read_packet.0.inner { + // `close_conn` requires ownership of `Conn`. That's okay, because + // `BinLogStream`'s connection is always owned. + ConnectionInner::Conn(conn) => { + if let Err(Error::Io(IoError::Io(ref error))) = conn.close_conn().await { + // If the binlog was requested with the flag BINLOG_DUMP_NON_BLOCK, + // the connection's file handler will already have been closed (EOF). + if error.kind() == ErrorKind::BrokenPipe { + return Ok(()); + } + } + } + ConnectionInner::ConnMut(_) => {} + ConnectionInner::Tx(_) => {} + } + + Ok(()) + } +} + +impl futures_core::stream::Stream for BinlogStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + { + let Self { + ref mut tpe, + ref mut esr, + .. + } = *self; + + if let Some(tpe) = tpe.as_mut() { + match esr.read_decompressed(tpe) { + Ok(Some(event)) => return Poll::Ready(Some(Ok(event))), + Ok(None) => self.tpe = None, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + } + } + } + + let packet = match ready!(Pin::new(&mut self.read_packet).poll(cx)) { + Ok(packet) => packet, + Err(err) => return Poll::Ready(Some(Err(err.into()))), + }; + + let first_byte = packet.first().copied(); + + if first_byte == Some(255) { + if let Ok(ErrPacket::Error(err)) = + ParseBuf(&packet).parse(self.read_packet.conn_ref().capabilities()) + { + return Poll::Ready(Some(Err(From::from(err)))); + } + } + + if first_byte == Some(254) + && packet.len() < 8 + && ParseBuf(&packet) + .parse::>( + self.read_packet.conn_ref().capabilities(), + ) + .is_ok() + { + return Poll::Ready(None); + } + + if first_byte == Some(0) { + let event_data = &packet[1..]; + match self.esr.read(event_data) { + Ok(Some(event)) => { + if event.header().event_type_raw() == EventType::TRANSACTION_PAYLOAD_EVENT as u8 + { + #[allow(clippy::single_match)] + match event.read_event::>() { + Ok(e) => self.tpe = Some(Cursor::new(e.danger_decompress())), + Err(_) => (/* TODO: Log the error */), + } + } + Poll::Ready(Some(Ok(event))) + } + Ok(None) => Poll::Ready(None), + Err(err) => Poll::Ready(Some(Err(err.into()))), + } + } else { + Poll::Ready(Some(Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()))) + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures_util::StreamExt; + use mysql_common::binlog::events::EventData; + use tokio::time::timeout; + + use crate::prelude::*; + use crate::{test_misc::get_opts, *}; + + async fn gen_dummy_data(conn: &mut Conn) -> super::Result<()> { + "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" + .ignore(&mut *conn) + .await?; + + let mut tx = conn.start_transaction(Default::default()).await?; + for i in 0_u8..100 { + "INSERT INTO customers(customer_id) VALUES (?)" + .with((i,)) + .ignore(&mut tx) + .await?; + } + tx.commit().await?; + + "DROP TABLE customers".ignore(conn).await?; + + Ok(()) + } + + async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec, u64)> { + let mut conn = match pool { + None => Conn::new(get_opts()).await.unwrap(), + Some(pool) => pool.get_conn().await.unwrap(), + }; + + if conn.server_version() >= (8, 0, 31) && conn.server_version() < (9, 0, 0) { + let _ = "SET binlog_transaction_compression=ON" + .ignore(&mut conn) + .await; + } + + if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" + .first::(&mut conn) + .await + { + if !gtid_mode.starts_with("ON") { + panic!( + "GTID_MODE is disabled \ + (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)" + ); + } + } + + let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap(); + let filename = row.get(0).unwrap(); + let position = row.get(1).unwrap(); + + gen_dummy_data(&mut conn).await.unwrap(); + Ok((conn, filename, position)) + } + + #[tokio::test] + async fn should_read_binlog() -> super::Result<()> { + read_binlog_streams_and_close_their_connections(None, (12, 13, 14)) + .await + .unwrap(); + + let pool = Pool::new(get_opts()); + read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17)) + .await + .unwrap(); + + // Disconnecting the pool verifies that closing the binlog connections + // left the pool in a sane state. + timeout(Duration::from_secs(10), pool.disconnect()) + .await + .unwrap() + .unwrap(); + + Ok(()) + } + + async fn read_binlog_streams_and_close_their_connections( + pool: Option<&Pool>, + binlog_server_ids: (u32, u32, u32), + ) -> super::Result<()> { + // iterate using COM_BINLOG_DUMP + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + let is_mariadb = conn.inner.is_mariadb; + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.0) + .with_filename(&filename) + .with_pos(pos), + ) + .await + .unwrap(); + + let mut events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + if let EventData::RowsEvent(re) = event.read_data()?.unwrap() { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + + if !is_mariadb { + // iterate using COM_BINLOG_DUMP_GTID + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.1) + .with_gtid() + .with_filename(&filename) + .with_pos(pos), + ) + .await + .unwrap(); + + events_num = 0; + while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await + { + let event = event.unwrap(); + events_num += 1; + + // assert that event type is known + event.header().event_type().unwrap(); + + // iterate over rows of an event + if let EventData::RowsEvent(re) = event.read_data()?.unwrap() { + let tme = binlog_stream.get_tme(re.table_id()); + for row in re.rows(tme.unwrap()) { + row.unwrap(); + } + } + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + } + + // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag + let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + + let mut binlog_stream = conn + .get_binlog_stream( + BinlogStreamRequest::new(binlog_server_ids.2) + .with_filename(&filename) + .with_pos(pos) + .with_non_blocking(), + ) + .await + .unwrap(); + + events_num = 0; + while let Some(event) = binlog_stream.next().await { + let event = event.unwrap(); + events_num += 1; + event.header().event_type().unwrap(); + event.read_data().unwrap(); + } + assert!(events_num > 0); + timeout(Duration::from_secs(10), binlog_stream.close()) + .await + .unwrap() + .unwrap(); + + Ok(()) + } +} diff --git a/src/conn/binlog_stream/request.rs b/src/conn/binlog_stream/request.rs new file mode 100644 index 00000000..23265b7c --- /dev/null +++ b/src/conn/binlog_stream/request.rs @@ -0,0 +1,86 @@ +use mysql_common::packets::{ + binlog_request::BinlogRequest, BinlogDumpFlags, ComRegisterSlave, Sid, +}; + +/// Binlog stream request builder. +pub struct BinlogStreamRequest<'a> { + pub(crate) binlog_request: BinlogRequest<'a>, + pub(crate) register_slave: ComRegisterSlave<'a>, +} + +impl<'a> BinlogStreamRequest<'a> { + /// Creates a new request with the given slave server id. + pub fn new(server_id: u32) -> Self { + Self { + binlog_request: BinlogRequest::new(server_id), + register_slave: ComRegisterSlave::new(server_id), + } + } + + /// Enables GTID-based replication (disabled by default). + pub fn with_gtid(mut self) -> Self { + self.binlog_request = self.binlog_request.with_use_gtid(true); + self + } + + /// Enables `NON_BLOCK` flag. Stream will be terminated as soon as there are no events. + pub fn with_non_blocking(mut self) -> Self { + self.binlog_request = self + .binlog_request + .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK); + self + } + + /// Sets the filename of the binlog on the master (try `SHOW BINARY LOGS`). + pub fn with_filename(mut self, filename: &'a [u8]) -> Self { + self.binlog_request = self.binlog_request.with_filename(filename); + self + } + + /// Sets the start position (defaults to `4`). + pub fn with_pos(mut self, position: u64) -> Self { + self.binlog_request = self.binlog_request.with_pos(position); + self + } + + /// Adds the given set of GTIDs to the request (ignored if not GTID-based). + pub fn with_gtid_set(mut self, set: T) -> Self + where + T: IntoIterator>, + { + self.binlog_request = self.binlog_request.with_sids(set); + self + } + + /// This hostname will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_hostname(mut self, hostname: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_hostname(hostname); + self + } + + /// This username will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_user(mut self, user: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_user(user); + self + } + + /// This password will be reported to the server (max len 255, default to an empty string). + /// + /// Usually left default. + pub fn with_password(mut self, password: &'a [u8]) -> Self { + self.register_slave = self.register_slave.with_password(password); + self + } + + /// This port number will be reported to the server (defaults to `0`). + /// + /// Usually left default. + pub fn with_port(mut self, port: u16) -> Self { + self.register_slave = self.register_slave.with_port(port); + self + } +} diff --git a/src/conn/mod.rs b/src/conn/mod.rs index e1376f87..7eee74d4 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -7,16 +7,15 @@ // modified, or distributed except according to those terms. use futures_util::FutureExt; -pub use mysql_common::named_params; use mysql_common::{ constants::{DEFAULT_MAX_ALLOWED_PACKET, UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI}, crypto, io::ParseBuf, packets::{ - binlog_request::BinlogRequest, AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, - HandshakePacket, HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, - OldEofPacket, ResultSetTerminator, SslRequest, + AuthPlugin, AuthSwitchRequest, CommonOkPacket, ErrPacket, HandshakePacket, + HandshakeResponse, OkPacket, OkPacketDeserializer, OldAuthSwitchRequest, OldEofPacket, + ResultSetTerminator, SslRequest, }, proto::MySerialize, row::Row, @@ -45,18 +44,21 @@ use crate::{ transaction::TxStatus, BinaryProtocol, Queryable, TextProtocol, }, - BinlogStream, InfileData, OptsBuilder, + ChangeUserOpts, InfileData, OptsBuilder, }; use self::routines::Routine; use regex::bytes::Regex; +#[cfg(feature = "binlog")] pub mod binlog_stream; pub mod pool; pub mod routines; pub mod stmt_cache; +const DEFAULT_WAIT_TIMEOUT: usize = 28800; + lazy_static::lazy_static! { static ref FIXED_MARIADB_VERSION_RE: Regex = Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap(); @@ -111,16 +113,21 @@ struct ConnInner { status: StatusFlags, last_ok_packet: Option>, last_err_packet: Option>, + handshake_complete: bool, pool: Option, pending_result: std::result::Result, ServerError>, tx_status: TxStatus, + reset_upon_returning_to_a_pool: bool, opts: Opts, + ttl_deadline: Option, last_io: Instant, wait_timeout: Duration, stmt_cache: StmtCache, nonce: Vec, auth_plugin: AuthPlugin<'static>, auth_switched: bool, + server_key: Option>, + active_since: Instant, /// Connection is already disconnected. pub(crate) disconnected: bool, /// One-time connection-level infile handler. @@ -138,6 +145,8 @@ impl fmt::Debug for ConnInner { .field("tx_status", &self.tx_status) .field("stream", &self.stream) .field("options", &self.opts) + .field("server_key", &self.server_key) + .field("auth_plugin", &self.auth_plugin) .finish() } } @@ -145,11 +154,13 @@ impl fmt::Debug for ConnInner { impl ConnInner { /// Constructs an empty connection. fn empty(opts: Opts) -> ConnInner { + let ttl_deadline = opts.pool_opts().new_connection_ttl_deadline(); ConnInner { capabilities: opts.get_capabilities(), status: StatusFlags::empty(), last_ok_packet: None, last_err_packet: None, + handshake_complete: false, stream: None, is_mariadb: false, is_vitess: false, @@ -163,11 +174,15 @@ impl ConnInner { stmt_cache: StmtCache::new(opts.stmt_cache_size()), socket: opts.socket().map(Into::into), opts, + ttl_deadline, nonce: Vec::default(), auth_plugin: AuthPlugin::MysqlNativePassword, auth_switched: false, disconnected: false, + server_key: None, infile_handler: None, + reset_upon_returning_to_a_pool: false, + active_since: Instant::now(), } } @@ -236,6 +251,13 @@ impl Conn { self.inner.last_ok_packet.as_ref() } + /// Turns on/off automatic connection reset (see [`crate::PoolOpts::with_reset_connection`]). + /// + /// Only makes sense for pooled connections. + pub fn reset_connection(&mut self, reset_connection: bool) { + self.inner.reset_upon_returning_to_a_pool = reset_connection; + } + pub(crate) fn stream_mut(&mut self) -> Result<&mut Stream> { self.inner.stream_mut() } @@ -302,7 +324,7 @@ impl Conn { if let Err(ref e) = self.inner.pending_result { let e = e.clone(); self.inner.pending_result = Ok(None); - return Err(e); + Err(e) } else { Ok(self.inner.pending_result.as_ref().unwrap().as_ref()) } @@ -315,8 +337,7 @@ impl Conn { } pub(crate) fn has_pending_result(&self) -> bool { - matches!(self.inner.pending_result, Err(_)) - || matches!(self.inner.pending_result, Ok(Some(_))) + self.inner.pending_result.is_err() || matches!(self.inner.pending_result, Ok(Some(_))) } /// Sets the given pening result metadata for this connection. Returns the previous value. @@ -439,16 +460,33 @@ impl Conn { /// Returns true if io stream is encrypted. fn is_secure(&self) -> bool { #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] - if let Some(ref stream) = self.inner.stream { - stream.is_secure() - } else { - false + { + self.inner + .stream + .as_ref() + .map(|x| x.is_secure()) + .unwrap_or_default() } #[cfg(not(any(feature = "native-tls-tls", feature = "rustls-tls")))] false } + /// Returns true if io stream is socket. + fn is_socket(&self) -> bool { + #[cfg(unix)] + { + self.inner + .stream + .as_ref() + .map(|x| x.is_socket()) + .unwrap_or_default() + } + + #[cfg(not(unix))] + false + } + /// Hacky way to move connection through &mut. `self` becomes unusable. fn take(&mut self) -> Conn { mem::replace(self, Conn::empty(Default::default())) @@ -473,7 +511,7 @@ impl Conn { async fn handle_handshake(&mut self) -> Result<()> { let packet = self.read_packet().await?; - let handshake = ParseBuf(&*packet).parse::(())?; + let handshake = ParseBuf(&packet).parse::(())?; // Handshake scramble is always 21 bytes length (20 + zero terminator) self.inner.nonce = { @@ -497,17 +535,15 @@ impl Conn { .unwrap_or((0, 0, 0)); self.inner.id = handshake.connection_id(); self.inner.status = handshake.status_flags(); + + // Allow only CachingSha2Password and MysqlNativePassword here + // because sha256_password is deprecated and other plugins won't + // appear here. self.inner.auth_plugin = match handshake.auth_plugin() { - Some(AuthPlugin::MysqlNativePassword | AuthPlugin::MysqlOldPassword) => { - AuthPlugin::MysqlNativePassword - } Some(AuthPlugin::CachingSha2Password) => AuthPlugin::CachingSha2Password, - Some(AuthPlugin::Other(ref name)) => { - let name = String::from_utf8_lossy(name).into(); - return Err(DriverError::UnknownAuthPlugin { name }.into()); - } - None => AuthPlugin::MysqlNativePassword, + _ => AuthPlugin::MysqlNativePassword, }; + Ok(()) } @@ -563,9 +599,16 @@ impl Conn { ); self.write_struct(&ssl_request).await?; let conn = self; - let ssl_opts = conn.opts().ssl_opts().cloned().expect("unreachable"); - let domain = conn.opts().ip_or_hostname().into(); - conn.stream_mut()?.make_secure(domain, ssl_opts).await?; + let ssl_opts = conn.opts().ssl_opts_and_connector().expect("unreachable"); + let domain = ssl_opts + .ssl_opts() + .tls_hostname_override() + .unwrap_or_else(|| conn.opts().ip_or_hostname()) + .into(); + let tls_connector = ssl_opts.build_tls_connector().await?; + conn.stream_mut()? + .make_secure(domain, &tls_connector) + .await?; Ok(()) } else { Ok(()) @@ -576,7 +619,7 @@ impl Conn { let auth_data = self .inner .auth_plugin - .gen_data(self.inner.opts.pass(), &*self.inner.nonce); + .gen_data(self.inner.opts.pass(), &self.inner.nonce); let handshake_response = HandshakeResponse::new( auth_data.as_deref(), @@ -586,13 +629,18 @@ impl Conn { Some(self.inner.auth_plugin.borrow()), self.capabilities(), Default::default(), // TODO: Add support + self.inner + .opts + .max_allowed_packet() + .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET) as u32, ); // Serialize here to satisfy borrow checker. - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); handshake_response.serialize(buf.as_mut()); self.write_packet(buf).await?; + self.inner.handshake_complete = true; Ok(()) } @@ -607,23 +655,42 @@ impl Conn { if matches!( auth_switch_request.auth_plugin(), AuthPlugin::MysqlOldPassword - ) { - if self.inner.opts.secure_auth() { - return Err(DriverError::MysqlOldPasswordDisabled.into()); - } + ) && self.inner.opts.secure_auth() + { + return Err(DriverError::MysqlOldPasswordDisabled.into()); } self.inner.auth_plugin = auth_switch_request.auth_plugin().clone().into_owned(); - let plugin_data = self - .inner - .auth_plugin - .gen_data(self.inner.opts.pass(), &*self.inner.nonce); + let plugin_data = match &self.inner.auth_plugin { + x @ AuthPlugin::CachingSha2Password => { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + x @ AuthPlugin::MysqlNativePassword => { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + x @ AuthPlugin::MysqlOldPassword => { + if self.inner.opts.secure_auth() { + return Err(DriverError::MysqlOldPasswordDisabled.into()); + } else { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } + } + x @ AuthPlugin::MysqlClearPassword => { + if self.inner.opts.enable_cleartext_plugin() { + x.gen_data(self.inner.opts.pass(), &self.inner.nonce) + } else { + return Err(DriverError::CleartextPluginDisabled.into()); + } + } + x @ AuthPlugin::Ed25519 => x.gen_data(self.inner.opts.pass(), &self.inner.nonce), + x @ AuthPlugin::Other(_) => x.gen_data(self.inner.opts.pass(), &self.inner.nonce), + }; if let Some(plugin_data) = plugin_data { - self.write_struct(&plugin_data).await?; + self.write_struct(&plugin_data.into_owned()).await?; } else { - self.write_packet(crate::BUFFER_POOL.get()).await?; + self.write_packet(crate::buffer_pool().get()).await?; } self.continue_auth().await?; @@ -647,6 +714,18 @@ impl Conn { self.continue_caching_sha2_password_auth().await?; Ok(()) } + AuthPlugin::MysqlClearPassword => { + if self.inner.opts.enable_cleartext_plugin() { + self.continue_mysql_native_password_auth().await?; + Ok(()) + } else { + Err(DriverError::CleartextPluginDisabled.into()) + } + } + AuthPlugin::Ed25519 => { + self.continue_ed25519_auth().await?; + Ok(()) + } AuthPlugin::Other(ref name) => Err(DriverError::UnknownAuthPlugin { name: String::from_utf8_lossy(name.as_ref()).to_string(), } @@ -669,9 +748,27 @@ impl Conn { Ok(()) } + async fn continue_ed25519_auth(&mut self) -> Result<()> { + let packet = self.read_packet().await?; + match packet.first() { + Some(0x00) => { + // ok packet for empty password + Ok(()) + } + Some(0xfe) if !self.inner.auth_switched => { + let auth_switch_request = ParseBuf(&packet).parse::(())?; + self.perform_auth_switch(auth_switch_request).await + } + _ => Err(DriverError::UnexpectedPacket { + payload: packet.to_vec(), + } + .into()), + } + } + async fn continue_caching_sha2_password_auth(&mut self) -> Result<()> { let packet = self.read_packet().await?; - match packet.get(0) { + match packet.first() { Some(0x00) => { // ok packet for empty password Ok(()) @@ -683,20 +780,25 @@ impl Conn { } Some(0x04) => { let pass = self.inner.opts.pass().unwrap_or_default(); - let mut pass = crate::BUFFER_POOL.get_with(pass.as_bytes()); + let mut pass = crate::buffer_pool().get_with(pass.as_bytes()); pass.as_mut().push(0); - if self.is_secure() { + if self.is_secure() || self.is_socket() { self.write_packet(pass).await?; } else { - self.write_bytes(&[0x02][..]).await?; - let packet = self.read_packet().await?; - let key = &packet[1..]; + if self.inner.server_key.is_none() { + self.write_bytes(&[0x02][..]).await?; + let packet = self.read_packet().await?; + self.inner.server_key = Some(packet[1..].to_vec()); + } for (i, byte) in pass.as_mut().iter_mut().enumerate() { *byte ^= self.inner.nonce[i % self.inner.nonce.len()]; } - let encrypted_pass = crypto::encrypt(&*pass, key); - self.write_bytes(&*encrypted_pass).await?; + let encrypted_pass = crypto::encrypt( + &pass, + self.inner.server_key.as_deref().expect("unreachable"), + ); + self.write_bytes(&encrypted_pass).await?; }; self.drop_packet().await?; Ok(()) @@ -707,7 +809,7 @@ impl Conn { .into()), }, Some(0xfe) if !self.inner.auth_switched => { - let auth_switch_request = ParseBuf(&*packet).parse::(())?; + let auth_switch_request = ParseBuf(&packet).parse::(())?; self.perform_auth_switch(auth_switch_request).await?; Ok(()) } @@ -720,13 +822,13 @@ impl Conn { async fn continue_mysql_native_password_auth(&mut self) -> Result<()> { let packet = self.read_packet().await?; - match packet.get(0) { + match packet.first() { Some(0x00) => Ok(()), Some(0xfe) if !self.inner.auth_switched => { let auth_switch = if packet.len() > 1 { - ParseBuf(&*packet).parse(())? + ParseBuf(&packet).parse(())? } else { - let _ = ParseBuf(&*packet).parse::(())?; + let _ = ParseBuf(&packet).parse::(())?; // map OldAuthSwitch to AuthSwitch with mysql_old_password plugin AuthSwitchRequest::new( "mysql_old_password".as_bytes(), @@ -749,16 +851,16 @@ impl Conn { .capabilities() .contains(CapabilityFlags::CLIENT_DEPRECATE_EOF) { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) } else { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) } } else { - ParseBuf(&*packet) + ParseBuf(packet) .parse::>(self.capabilities()) .map(|x| x.into_inner()) }; @@ -766,7 +868,19 @@ impl Conn { if let Ok(ok_packet) = ok_packet { self.handle_ok(ok_packet.into_owned()); } else { - let err_packet = ParseBuf(&*packet).parse::(self.capabilities()); + // If we haven't completed the handshake the server will not be aware of our + // capabilities and so it will behave as if we have none. In particular, the error + // packet will not contain a SQL State field even if our capabilities do contain the + // `CLIENT_PROTOCOL_41` flag. Therefore it is necessary to parse an incoming packet + // with no capability assumptions if we have not completed the handshake. + // + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase.html + let capabilities = if self.inner.handshake_complete { + self.capabilities() + } else { + CapabilityFlags::empty() + }; + let err_packet = ParseBuf(packet).parse::(capabilities); if let Ok(err_packet) = err_packet { self.handle_err(err_packet)?; return Ok(true); @@ -815,13 +929,13 @@ impl Conn { /// Writes bytes to a server. pub(crate) async fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> { - let buf = crate::BUFFER_POOL.get_with(bytes); + let buf = crate::buffer_pool().get_with(bytes); self.write_packet(buf).await } /// Sends a serializable structure to a server. pub(crate) async fn write_struct(&mut self, x: &T) -> Result<()> { - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); x.serialize(buf.as_mut()); self.write_packet(buf).await } @@ -847,7 +961,7 @@ impl Conn { T: AsRef<[u8]>, { let cmd_data = cmd_data.as_ref(); - let mut buf = crate::BUFFER_POOL.get(); + let mut buf = crate::buffer_pool().get(); let body = buf.as_mut(); body.push(cmd as u8); body.extend_from_slice(cmd_data); @@ -869,6 +983,16 @@ impl Conn { Ok(()) } + async fn run_setup_commands(&mut self) -> Result<()> { + let mut setup = self.inner.opts.setup().to_vec(); + + while let Some(query) = setup.pop() { + self.query_drop(query).await?; + } + + Ok(()) + } + /// Returns a future that resolves to [`Conn`]. pub fn new>(opts: T) -> crate::BoxFuture<'static, Conn> { let opts = opts.into(); @@ -880,7 +1004,7 @@ impl Conn { { Stream::connect_socket(_path.to_owned()).await? } - #[cfg(target_os = "windows")] + #[cfg(not(unix))] return Err(crate::DriverError::NamedPipesDisabled.into()); } else { let keepalive = opts @@ -899,6 +1023,7 @@ impl Conn { conn.read_settings().await?; conn.reconnect_via_socket_if_needed().await?; conn.run_init_commands().await?; + conn.run_setup_commands().await?; Ok(conn) } @@ -931,49 +1056,128 @@ impl Conn { /// Configures the connection based on server settings. In particular: /// /// * It reads and stores socket address inside the connection unless if socket address is - /// already in [`Opts`] or if `prefer_socket` is `false`. + /// already in [`Opts`] or if `prefer_socket` is `false`. /// /// * It reads and stores `max_allowed_packet` in the connection unless it's already in [`Opts`] /// /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`] /// async fn read_settings(&mut self) -> Result<()> { - let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none(); - let read_max_allowed_packet = self.opts().max_allowed_packet().is_none(); - let read_wait_timeout = self.opts().wait_timeout().is_none(); + enum Action { + Load(Cfg), + Apply(CfgData), + } - let settings: Option = if read_socket || read_max_allowed_packet || read_wait_timeout { - self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout") - .await? - } else { - None - }; + enum CfgData { + MaxAllowedPacket(usize), + WaitTimeout(usize), + } - // set socket inside the connection - if read_socket { - self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None); + impl CfgData { + fn apply(&self, conn: &mut Conn) { + match self { + Self::MaxAllowedPacket(value) => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet(*value); + } + } + Self::WaitTimeout(value) => { + conn.inner.wait_timeout = Duration::from_secs(*value as u64); + } + } + } } - // set max_allowed_packet - let max_allowed_packet = if read_max_allowed_packet { - settings - .as_ref() - .map(|s| s.get("@@max_allowed_packet")) - .unwrap() - } else { - self.opts().max_allowed_packet() - }; - if let Some(stream) = self.inner.stream.as_mut() { - stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET)); + enum Cfg { + Socket, + MaxAllowedPacket, + WaitTimeout, + } + + impl Cfg { + const fn name(&self) -> &'static str { + match self { + Self::Socket => "@@socket", + Self::MaxAllowedPacket => "@@max_allowed_packet", + Self::WaitTimeout => "@@wait_timeout", + } + } + + fn apply(&self, conn: &mut Conn, value: Option) { + match self { + Cfg::Socket => { + conn.inner.socket = value.and_then(crate::from_value); + } + Cfg::MaxAllowedPacket => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet( + value + .and_then(crate::from_value) + .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET), + ); + } + } + Cfg::WaitTimeout => { + conn.inner.wait_timeout = Duration::from_secs( + value + .and_then(crate::from_value) + .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64, + ); + } + } + } } - // set read_wait_timeout - let wait_timeout = if read_wait_timeout { - settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap() + let mut actions = vec![ + if let Some(x) = self.opts().max_allowed_packet() { + Action::Apply(CfgData::MaxAllowedPacket(x)) + } else { + Action::Load(Cfg::MaxAllowedPacket) + }, + if let Some(x) = self.opts().wait_timeout() { + Action::Apply(CfgData::WaitTimeout(x)) + } else { + Action::Load(Cfg::WaitTimeout) + }, + ]; + + if self.inner.opts.prefer_socket() && self.inner.socket.is_none() { + actions.push(Action::Load(Cfg::Socket)) + } + + let loads = actions + .iter() + .filter_map(|x| match x { + Action::Load(x) => Some(x), + Action::Apply(_) => None, + }) + .collect::>(); + + let loaded = if !loads.is_empty() { + let query = loads + .iter() + .zip(std::iter::once(' ').chain(std::iter::repeat(','))) + .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| { + acc.push(prefix); + acc.push_str(cfg.name()); + acc + }); + + self.query_internal::(query) + .await? + .map(|row| row.unwrap()) + .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()]) } else { - self.opts().wait_timeout() + vec![] }; - self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64); + let mut loaded = loaded.into_iter(); + + for action in actions { + match action { + Action::Load(cfg) => cfg.apply(self, loaded.next()), + Action::Apply(cfg) => cfg.apply(self), + } + } Ok(()) } @@ -981,6 +1185,11 @@ impl Conn { /// Returns true if time since last IO exceeds `wait_timeout` /// (or `conn_ttl` if specified in opts). fn expired(&self) -> bool { + if let Some(deadline) = self.inner.ttl_deadline { + if Instant::now() > deadline { + return true; + } + } let ttl = self .inner .opts @@ -994,12 +1203,13 @@ impl Conn { self.inner.last_io.elapsed() } - /// Executes `COM_RESET_CONNECTION` on `self`. + /// Executes [`COM_RESET_CONNECTION`][1]. /// - /// If server version is older than 5.7.2, then it'll reconnect. - pub async fn reset(&mut self) -> Result<()> { - let pool = self.inner.pool.clone(); - + /// Returns `false` if command is not supported (requires MySql >5.7.2, MariaDb >10.2.3). + /// For older versions consider using [`Conn::change_user`]. + /// + /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-reset-connection.html + pub async fn reset(&mut self) -> Result { let supports_com_reset_connection = if self.inner.is_mariadb { self.inner.version >= (10, 2, 4) } else { @@ -1009,21 +1219,66 @@ impl Conn { if supports_com_reset_connection { self.routine(routines::ResetRoutine).await?; - } else { - let opts = self.inner.opts.clone(); - let old_conn = std::mem::replace(self, Conn::new(opts).await?); - // tidy up the old connection - old_conn.close_conn().await?; - }; + self.inner.stmt_cache.clear(); + self.inner.infile_handler = None; + self.run_setup_commands().await?; + } + Ok(supports_com_reset_connection) + } + + /// Executes [`COM_CHANGE_USER`][1]. + /// + /// This might be used as an older and slower alternative to `COM_RESET_CONNECTION` that + /// works on MySql prior to 5.7.3 (MariaDb prior ot 10.2.4). + /// + /// ## Note + /// + /// * Using non-default `opts` for a pooled connection is discouraging. + /// * Connection options will be permanently updated. + /// + /// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html + pub async fn change_user(&mut self, opts: ChangeUserOpts) -> Result<()> { + // We'll kick this connection from a pool if opts are changed. + if opts != ChangeUserOpts::default() { + let mut opts_changed = false; + if let Some(user) = opts.user() { + opts_changed |= user != self.opts().user() + }; + if let Some(pass) = opts.pass() { + opts_changed |= pass != self.opts().pass() + }; + if let Some(db_name) = opts.db_name() { + opts_changed |= db_name != self.opts().db_name() + }; + if opts_changed { + if let Some(pool) = self.inner.pool.take() { + pool.cancel_connection(); + } + } + } + + let conn_opts = &mut self.inner.opts; + opts.update_opts(conn_opts); + self.routine(routines::ChangeUser).await?; self.inner.stmt_cache.clear(); self.inner.infile_handler = None; - self.inner.pool = pool; + self.run_setup_commands().await?; Ok(()) } + /// Resets the connection upon returning it to a pool. + /// + /// Will invoke `COM_CHANGE_USER` if `COM_RESET_CONNECTION` is not supported. + async fn reset_for_pool(mut self) -> Result { + if !self.reset().await? { + self.change_user(Default::default()).await?; + } + Ok(self) + } + /// Requires that `self.inner.tx_status != TxStatus::None` - async fn rollback_transaction(&mut self) -> Result<()> { + pub(crate) async fn rollback_transaction(&mut self) -> Result<()> { debug_assert_ne!(self.inner.tx_status, TxStatus::None); self.inner.tx_status = TxStatus::None; self.query_drop("ROLLBACK").await @@ -1098,221 +1353,69 @@ impl Conn { } Ok(self) } - - async fn register_as_slave(&mut self, server_id: u32) -> Result<()> { - use mysql_common::packets::ComRegisterSlave; - - self.query_drop("SET @master_binlog_checksum='ALL'").await?; - self.write_command(&ComRegisterSlave::new(server_id)) - .await?; - - // Server will respond with OK. - self.read_packet().await?; - - Ok(()) - } - - async fn request_binlog(&mut self, request: BinlogRequest<'_>) -> Result<()> { - self.register_as_slave(request.server_id()).await?; - self.write_command(&request.as_cmd()).await?; - Ok(()) - } - - pub async fn get_binlog_stream(mut self, request: BinlogRequest<'_>) -> Result { - self.request_binlog(request).await?; - - Ok(BinlogStream::new(self)) - } } #[cfg(test)] mod test { use bytes::Bytes; use futures_util::stream::{self, StreamExt}; - use mysql_common::{binlog::events::EventData, constants::MAX_PAYLOAD_LEN}; - use tokio::time::timeout; - - use std::time::Duration; + use mysql_common::constants::MAX_PAYLOAD_LEN; + use rand::Rng; + use tokio::{io::AsyncWriteExt, net::TcpListener}; use crate::{ - from_row, params, prelude::*, test_misc::get_opts, BinlogDumpFlags, BinlogRequest, Conn, - Error, OptsBuilder, Pool, WhiteListFsHandler, + from_row, params, prelude::*, test_misc::get_opts, ChangeUserOpts, Conn, Error, + OptsBuilder, Pool, ServerError, Value, WhiteListFsHandler, }; - async fn gen_dummy_data() -> super::Result<()> { - let mut conn = Conn::new(get_opts()).await?; + #[tokio::test] + async fn should_return_found_rows_if_flag_is_set() -> super::Result<()> { + let opts = get_opts().client_found_rows(true); + let mut conn = Conn::new(opts).await.unwrap(); - "CREATE TABLE IF NOT EXISTS customers (customer_id int not null)" + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" .ignore(&mut conn) .await?; - for i in 0_u8..100 { - "INSERT INTO customers(customer_id) VALUES (?)" - .with((i,)) - .ignore(&mut conn) - .await?; - } - - "DROP TABLE customers".ignore(&mut conn).await?; - - Ok(()) - } - - async fn create_binlog_stream_conn(pool: Option<&Pool>) -> super::Result<(Conn, Vec, u64)> { - let mut conn = match pool { - None => Conn::new(get_opts()).await.unwrap(), - Some(pool) => pool.get_conn().await.unwrap(), - }; - - if let Ok(Some(gtid_mode)) = "SELECT @@GLOBAL.GTID_MODE" - .first::(&mut conn) - .await - { - if !gtid_mode.starts_with("ON") { - panic!( - "GTID_MODE is disabled \ - (enable using --gtid_mode=ON --enforce_gtid_consistency=ON)" - ); - } - } - - let row: crate::Row = "SHOW BINARY LOGS".first(&mut conn).await.unwrap().unwrap(); - let filename = row.get(0).unwrap(); - let position = row.get(1).unwrap(); + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; - gen_dummy_data().await.unwrap(); - Ok((conn, filename, position)) - } + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); - #[tokio::test] - async fn should_read_binlog() -> super::Result<()> { - read_binlog_streams_and_close_their_connections(None, (12, 13, 14)) - .await - .unwrap(); - - let pool = Pool::new(get_opts()); - read_binlog_streams_and_close_their_connections(Some(&pool), (15, 16, 17)) - .await - .unwrap(); + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; - // Disconnecting the pool verifies that closing the binlog connections - // left the pool in a sane state. - timeout(Duration::from_secs(10), pool.disconnect()) - .await - .unwrap() - .unwrap(); + // The query doesn't affect any rows, but due to us wanting FOUND rows, + // this has to return one. + assert_eq!(conn.affected_rows(), 1); Ok(()) } - async fn read_binlog_streams_and_close_their_connections( - pool: Option<&Pool>, - binlog_server_ids: (u32, u32, u32), - ) -> super::Result<()> { - // iterate using COM_BINLOG_DUMP - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); - let is_mariadb = conn.inner.is_mariadb; - - let mut binlog_stream = conn - .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.0) - .with_filename(filename) - .with_pos(pos), - ) - .await - .unwrap(); - - let mut events_num = 0; - while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await { - let event = event.unwrap(); - events_num += 1; - - // assert that event type is known - event.header().event_type().unwrap(); - - // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } - } - _ => (), - } - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); + #[tokio::test] + async fn should_not_return_found_rows_if_flag_is_not_set() -> super::Result<()> { + let mut conn = Conn::new(get_opts()).await.unwrap(); - if !is_mariadb { - // iterate using COM_BINLOG_DUMP_GTID - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); - - let mut binlog_stream = conn - .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.1) - .with_use_gtid(true) - .with_filename(filename) - .with_pos(pos), - ) - .await - .unwrap(); + "CREATE TEMPORARY TABLE mysql.found_rows (id INT PRIMARY KEY AUTO_INCREMENT, val INT)" + .ignore(&mut conn) + .await?; - events_num = 0; - while let Ok(Some(event)) = timeout(Duration::from_secs(10), binlog_stream.next()).await - { - let event = event.unwrap(); - events_num += 1; - - // assert that event type is known - event.header().event_type().unwrap(); - - // iterate over rows of an event - match event.read_data()?.unwrap() { - EventData::RowsEvent(re) => { - let tme = binlog_stream.get_tme(re.table_id()); - for row in re.rows(tme.unwrap()) { - row.unwrap(); - } - } - _ => (), - } - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); - } + "INSERT INTO mysql.found_rows (val) VALUES (1)" + .ignore(&mut conn) + .await?; - // iterate using COM_BINLOG_DUMP with BINLOG_DUMP_NON_BLOCK flag - let (conn, filename, pos) = create_binlog_stream_conn(pool).await.unwrap(); + // Inserted one row, affected should be one. + assert_eq!(conn.affected_rows(), 1); - let mut binlog_stream = conn - .get_binlog_stream( - BinlogRequest::new(binlog_server_ids.2) - .with_filename(filename) - .with_pos(pos) - .with_flags(BinlogDumpFlags::BINLOG_DUMP_NON_BLOCK), - ) - .await - .unwrap(); + "UPDATE mysql.found_rows SET val = 1 WHERE val = 1" + .ignore(&mut conn) + .await?; - events_num = 0; - while let Some(event) = binlog_stream.next().await { - let event = event.unwrap(); - events_num += 1; - event.header().event_type().unwrap(); - event.read_data().unwrap(); - } - assert!(events_num > 0); - timeout(Duration::from_secs(10), binlog_stream.close()) - .await - .unwrap() - .unwrap(); + // The query doesn't affect any rows. + assert_eq!(conn.affected_rows(), 0); Ok(()) } @@ -1320,6 +1423,7 @@ mod test { #[test] fn opts_should_satisfy_send_and_sync() { struct A(T); + #[allow(clippy::unnecessary_operation)] A(get_opts()); } @@ -1388,16 +1492,18 @@ mod test { .filter(|variant| plugins.iter().any(|p| p == variant.0)); for (plug, val, pass) in variants { + dbg!((plug, val, pass, conn.inner.version)); + + if plug == "mysql_native_password" && conn.inner.version >= (8, 4, 0) { + continue; + } + let _ = conn.query_drop("DROP USER 'test_user'@'%'").await; let query = format!("CREATE USER 'test_user'@'%' IDENTIFIED WITH {}", plug); conn.query_drop(query).await.unwrap(); - if (8, 0, 11) <= conn.inner.version && conn.inner.version <= (9, 0, 0) { - conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass)) - .await - .unwrap(); - } else { + if conn.inner.version < (8, 0, 11) { conn.query_drop(format!("SET old_passwords = {}", val)) .await .unwrap(); @@ -1407,6 +1513,10 @@ mod test { )) .await .unwrap(); + } else { + conn.query_drop(format!("SET PASSWORD FOR 'test_user'@'%' = '{}'", pass)) + .await + .unwrap(); }; let opts = get_opts() @@ -1452,16 +1562,216 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_execute_setup_queries_on_reset() -> super::Result<()> { + let opts = OptsBuilder::from_opts(get_opts()).setup(vec!["SET @a = 42", "SET @b = 'foo'"]); + let mut conn = Conn::new(opts).await?; + + // initial run + let mut result: Vec<(u8, String)> = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + + // after reset + if conn.reset().await? { + result = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + } + + // after change user + conn.change_user(Default::default()).await?; + result = conn.query("SELECT @a, @b").await?; + assert_eq!(result, vec![(42, "foo".into())]); + + conn.disconnect().await?; + Ok(()) + } + #[tokio::test] async fn should_reset_the_connection() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; - conn.exec_drop("SELECT ?", (1_u8,)).await?; - conn.reset().await?; - conn.exec_drop("SELECT ?", (1_u8,)).await?; + + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + + conn.query_drop("SET @foo = 'foo'").await?; + + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + + if conn.reset().await? { + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + } else { + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + } + conn.disconnect().await?; Ok(()) } + #[tokio::test] + async fn should_change_user() -> super::Result<()> { + /// Whether particular authentication plugin should be tested on the current database. + type ShouldRunFn = fn(bool, (u16, u16, u16)) -> bool; + /// Generates `CREATE USER` and `SET PASSWORD` statements + type CreateUserFn = fn(bool, (u16, u16, u16), &str) -> Vec; + + #[allow(clippy::type_complexity)] + const TEST_MATRIX: [(&str, ShouldRunFn, CreateUserFn); 4] = [ + ( + "mysql_old_password", + |is_mariadb, version| is_mariadb || version < (5, 7, 0), + |is_mariadb, version, pass| { + if is_mariadb { + vec![ + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(), + "SET old_passwords=1".into(), + format!("ALTER USER '__mats'@'%' IDENTIFIED BY '{pass}'"), + "SET old_passwords=0".into(), + ] + } else if matches!(version, (5, 6, _)) { + vec![ + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_old_password".into(), + format!("SET PASSWORD FOR '__mats'@'%' = OLD_PASSWORD('{pass}')"), + ] + } else { + vec![ + "CREATE USER '__mats'@'%'".into(), + format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"), + ] + } + }, + ), + ( + "mysql_native_password", + |is_mariadb, version| is_mariadb || version < (8, 4, 0), + |is_mariadb, version, pass| { + if is_mariadb { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password AS PASSWORD('{pass}')") + ] + } else if version < (8, 0, 0) { + vec![ + format!( + "CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password" + ), + format!("SET old_passwords = 0"), + format!("SET PASSWORD FOR '__mats'@'%' = PASSWORD('{pass}')"), + ] + } else { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH mysql_native_password BY '{pass}'") + ] + } + }, + ), + ( + "caching_sha2_password", + |is_mariadb, version| !is_mariadb && version >= (5, 8, 0), + |_is_mariadb, _version, pass| { + vec![ + format!("CREATE USER '__mats'@'%' IDENTIFIED WITH caching_sha2_password BY '{pass}'") + ] + }, + ), + ( + "client_ed25519", + |is_mariadb, version| is_mariadb && version >= (11, 6, 2), + |_is_mariadb, _version, pass| { + vec![format!( + "CREATE USER '__mats'@'%' IDENTIFIED WITH ed25519 AS PASSWORD('{pass}')" + )] + }, + ), + ]; + + fn random_pass() -> String { + let mut rng = rand::rng(); + let pass: [u8; 10] = rng.random(); + + IntoIterator::into_iter(pass) + .map(|x| ((x % (123 - 97)) + 97) as char) + .collect() + } + + let mut conn = Conn::new(get_opts()).await?; + + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + + conn.query_drop("SET @foo = 'foo'").await?; + + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + + conn.change_user(Default::default()).await?; + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + + for (i, (plugin, should_run, create_statements)) in TEST_MATRIX.iter().enumerate() { + dbg!(plugin); + let is_mariadb = conn.inner.is_mariadb; + let version = conn.server_version(); + + if should_run(is_mariadb, version) { + let pass = random_pass(); + + let result = conn + .query_drop("DROP USER /*!50700 IF EXISTS */ /*M!100103 IF EXISTS */ __mats") + .await; + + if matches!(version, (5, 6, _)) && i == 0 { + // IF EXISTS is not supported on 5.6 so the query will fail on the first iteration + drop(result); + } else { + result.unwrap(); + } + + for statement in create_statements(is_mariadb, version, &pass) { + conn.query_drop(dbg!(statement)).await.unwrap(); + } + + let mut conn2 = Conn::new(get_opts().secure_auth(false)).await.unwrap(); + conn2 + .change_user( + ChangeUserOpts::default() + .with_db_name(None) + .with_user(Some("__mats".into())) + .with_pass(Some(pass)), + ) + .await + .unwrap(); + + let (db, user) = conn2 + .query_first::<(Option, String), _>("SELECT DATABASE(), USER();") + .await + .unwrap() + .unwrap(); + assert_eq!(db, None); + assert!(user.starts_with("__mats")); + + conn2.disconnect().await.unwrap(); + } + } + + Ok(()) + } + #[tokio::test] async fn should_not_cache_statements_if_stmt_cache_size_is_zero() -> super::Result<()> { let opts = OptsBuilder::from_opts(get_opts()).stmt_cache_size(0); @@ -1520,7 +1830,7 @@ mod test { async fn should_perform_queries() -> super::Result<()> { let mut conn = Conn::new(get_opts()).await?; for x in (MAX_PAYLOAD_LEN - 2)..=(MAX_PAYLOAD_LEN + 2) { - let long_string = ::std::iter::repeat('A').take(x).collect::(); + let long_string = "A".repeat(x); let result: Vec<(String, u8)> = conn .query(format!(r"SELECT '{}', 231", long_string)) .await?; @@ -1572,15 +1882,11 @@ mod test { #[tokio::test] async fn should_execute_statement() -> super::Result<()> { - let long_string = ::std::iter::repeat('A') - .take(18 * 1024 * 1024) - .collect::(); + let long_string = "A".repeat(18 * 1024 * 1024); let mut conn = Conn::new(get_opts()).await?; let stmt = conn.prep(r"SELECT ?").await?; let result = conn.exec_iter(&stmt, (&long_string,)).await?; - let mut mapped = result - .map_and_drop(|row| from_row::<(String,)>(row)) - .await?; + let mut mapped = result.map_and_drop(from_row::<(String,)>).await?; assert_eq!(mapped.len(), 1); assert_eq!(mapped.pop(), Some((long_string,))); let result = conn.exec_iter(&stmt, (42_u8,)).await?; @@ -1603,7 +1909,7 @@ mod test { .exec_iter(&stmt, params! { "foo" => "quux", "bar" => "baz" }) .await?; let mut mapped = result - .map_and_drop(|row| from_row::<(String, String, String, u8)>(row)) + .map_and_drop(from_row::<(String, String, String, u8)>) .await?; assert_eq!(mapped.len(), 1); assert_eq!( @@ -1695,7 +2001,7 @@ mod test { let result = conn.query_iter(q).await?; let loaded_structs = result - .map_and_drop(|row| crate::from_row::<(Vec, Vec, u64, Vec)>(row)) + .map_and_drop(crate::from_row::<(Vec, Vec, u64, Vec)>) .await?; conn.disconnect().await?; @@ -2035,6 +2341,45 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_handle_initial_error_packet() { + let header = [ + 0x68, 0x00, 0x00, // packet_length + 0x00, // sequence + 0xff, // error_header + 0x69, 0x04, // error_code + ]; + let error_message = "Host '172.17.0.1' is blocked because of many connection errors; unblock with 'mysqladmin flush-hosts'"; + + // Create a fake MySQL server that immediately replies with an error packet. + let listener = TcpListener::bind("127.0.0.1:0000").await.unwrap(); + + let listen_addr = listener.local_addr().unwrap(); + + tokio::task::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + stream.write_all(&header).await.unwrap(); + stream.write_all(error_message.as_bytes()).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let opts = OptsBuilder::default() + .ip_or_hostname(listen_addr.ip().to_string()) + .tcp_port(listen_addr.port()); + let server_err = match Conn::new(opts).await { + Err(Error::Server(server_err)) => server_err, + other => panic!("expected server error but got: {:?}", other), + }; + assert_eq!( + server_err, + ServerError { + code: 1129, + state: "HY000".to_owned(), + message: error_message.to_owned(), + } + ); + } + #[cfg(feature = "nightly")] mod bench { use crate::{conn::Conn, queryable::Queryable, test_misc::get_opts}; diff --git a/src/conn/pool/futures/disconnect_pool.rs b/src/conn/pool/futures/disconnect_pool.rs index 4e5c4f4d..b8a07d4d 100644 --- a/src/conn/pool/futures/disconnect_pool.rs +++ b/src/conn/pool/futures/disconnect_pool.rs @@ -60,7 +60,8 @@ impl Future for DisconnectPool { Some(drop) => match drop.send(None) { Ok(_) => { // Recycler is alive. Waiting for it to finish. - Poll::Ready(Ok(ready!(Box::pin(drop.closed()).as_mut().poll(cx)))) + ready!(Box::pin(drop.closed()).as_mut().poll(cx)); + Poll::Ready(Ok(())) } Err(_) => { // Recycler seem dead. No one will wake us. diff --git a/src/conn/pool/futures/get_conn.rs b/src/conn/pool/futures/get_conn.rs index 73e8a999..b89f9bc6 100644 --- a/src/conn/pool/futures/get_conn.rs +++ b/src/conn/pool/futures/get_conn.rs @@ -55,30 +55,25 @@ impl fmt::Debug for GetConnInner { } } -impl GetConnInner { - /// Take the value of the inner connection, resetting it to `New`. - pub fn take(&mut self) -> GetConnInner { - std::mem::replace(self, GetConnInner::New) - } -} - /// This future will take connection from a pool and resolve to [`Conn`]. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct GetConn { - pub(crate) queue_id: Option, + pub(crate) queue_id: QueueId, pub(crate) pool: Option, pub(crate) inner: GetConnInner, + reset_upon_returning_to_a_pool: bool, #[cfg(feature = "tracing")] span: Arc, } impl GetConn { - pub(crate) fn new(pool: &Pool) -> GetConn { + pub(crate) fn new(pool: &Pool, reset_upon_returning_to_a_pool: bool) -> GetConn { GetConn { - queue_id: None, + queue_id: QueueId::next(), pool: Some(pool.clone()), inner: GetConnInner::New, + reset_upon_returning_to_a_pool, #[cfg(feature = "tracing")] span: Arc::new(debug_span!("mysql_async::get_conn")), } @@ -110,10 +105,8 @@ impl Future for GetConn { loop { match self.inner { GetConnInner::New => { - let queued = self.queue_id.is_some(); - let queue_id = *self.queue_id.get_or_insert_with(QueueId::next); - let next = - ready!(Pin::new(self.pool_mut()).poll_new_conn(cx, queued, queue_id))?; + let queue_id = self.queue_id; + let next = ready!(self.pool_mut().poll_new_conn(cx, queue_id))?; match next { GetConnInner::Connecting(conn_fut) => { self.inner = GetConnInner::Connecting(conn_fut); @@ -141,6 +134,8 @@ impl Future for GetConn { return match result { Ok(mut c) => { c.inner.pool = Some(pool); + c.inner.reset_upon_returning_to_a_pool = + self.reset_upon_returning_to_a_pool; Poll::Ready(Ok(c)) } Err(e) => { @@ -152,12 +147,14 @@ impl Future for GetConn { GetConnInner::Checking(ref mut f) => { let result = ready!(Pin::new(f).poll(cx)); match result { - Ok(mut checked_conn) => { + Ok(mut c) => { self.inner = GetConnInner::Done; let pool = self.pool_take(); - checked_conn.inner.pool = Some(pool); - return Poll::Ready(Ok(checked_conn)); + c.inner.pool = Some(pool); + c.inner.reset_upon_returning_to_a_pool = + self.reset_upon_returning_to_a_pool; + return Poll::Ready(Ok(c)); } Err(_) => { // Idling connection is broken. We'll drop it and try again. @@ -181,10 +178,8 @@ impl Drop for GetConn { if let Some(pool) = self.pool.take() { // Remove the waker from the pool's waitlist in case this task was // woken by another waker, like from tokio::time::timeout. - if let Some(queue_id) = self.queue_id { - pool.unqueue(queue_id); - } - if let GetConnInner::Connecting(..) = self.inner.take() { + pool.unqueue(self.queue_id); + if let GetConnInner::Connecting(..) = self.inner { pool.cancel_connection(); } } diff --git a/src/conn/pool/metrics.rs b/src/conn/pool/metrics.rs new file mode 100644 index 00000000..c5c6a08a --- /dev/null +++ b/src/conn/pool/metrics.rs @@ -0,0 +1,145 @@ +use std::sync::atomic::AtomicUsize; + +use serde::Serialize; + +#[derive(Default, Debug, Serialize)] +#[non_exhaustive] +pub struct Metrics { + /// Guage of active connections to the database server, this includes both connections that have belong + /// to the pool, and connections currently owned by the application. + pub connection_count: AtomicUsize, + /// Guage of active connections that currently belong to the pool. + pub connections_in_pool: AtomicUsize, + /// Guage of GetConn requests that are currently active. + pub active_wait_requests: AtomicUsize, + /// Counter of connections that failed to be created. + pub create_failed: AtomicUsize, + /// Counter of connections discarded due to pool constraints. + pub discarded_superfluous_connection: AtomicUsize, + /// Counter of connections discarded due to being closed upon return to the pool. + pub discarded_unestablished_connection: AtomicUsize, + /// Counter of connections that have been returned to the pool dirty that needed to be cleaned + /// (ie. open transactions, pending queries, etc). + pub dirty_connection_return: AtomicUsize, + /// Counter of connections that have been discarded as they were expired by the pool constraints. + pub discarded_expired_connection: AtomicUsize, + /// Counter of connections that have been reset. + pub resetting_connection: AtomicUsize, + /// Counter of connections that have been discarded as they returned an error during cleanup. + pub discarded_error_during_cleanup: AtomicUsize, + /// Counter of connections that have been returned to the pool. + pub connection_returned_to_pool: AtomicUsize, + /// Histogram of times connections have spent outside of the pool. + #[cfg(feature = "hdrhistogram")] + pub connection_active_duration: MetricsHistogram, + /// Histogram of times connections have spent inside of the pool. + #[cfg(feature = "hdrhistogram")] + pub connection_idle_duration: MetricsHistogram, + /// Histogram of times connections have spent being checked for health. + #[cfg(feature = "hdrhistogram")] + pub check_duration: MetricsHistogram, + /// Histogram of time spent waiting to connect to the server. + #[cfg(feature = "hdrhistogram")] + pub connect_duration: MetricsHistogram, +} + +impl Metrics { + /// Resets all histograms to allow for histograms to be bound to a period of time (ie. between metric scrapes) + #[cfg(feature = "hdrhistogram")] + pub fn clear_histograms(&self) { + self.connection_active_duration.reset(); + self.connection_idle_duration.reset(); + self.check_duration.reset(); + self.connect_duration.reset(); + } +} + +#[cfg(feature = "hdrhistogram")] +#[derive(Debug)] +pub struct MetricsHistogram(std::sync::Mutex>); + +#[cfg(feature = "hdrhistogram")] +impl MetricsHistogram { + pub fn reset(&self) { + self.lock().unwrap().reset(); + } +} + +#[cfg(feature = "hdrhistogram")] +impl Default for MetricsHistogram { + fn default() -> Self { + let hdr = hdrhistogram::Histogram::new_with_bounds(1, 30 * 1_000_000, 2).unwrap(); + Self(std::sync::Mutex::new(hdr)) + } +} + +#[cfg(feature = "hdrhistogram")] +impl Serialize for MetricsHistogram { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let hdr = self.0.lock().unwrap(); + + /// A percentile of this histogram - for supporting serializers this + /// will ignore the key (such as `90%ile`) and instead add a + /// dimension to the metrics (such as `quantile=0.9`). + macro_rules! ile { + ($e:expr) => { + &MetricAlias(concat!("!|quantile=", $e), hdr.value_at_quantile($e)) + }; + } + + /// A 'qualified' metric name - for supporting serializers such as + /// serde_prometheus, this will prepend the metric name to this key, + /// outputting `response_time_count`, for example rather than just + /// `count`. + macro_rules! qual { + ($e:expr) => { + &MetricAlias("<|", $e) + }; + } + + use serde::ser::SerializeMap; + + let mut tup = serializer.serialize_map(Some(10))?; + tup.serialize_entry("samples", qual!(hdr.len()))?; + tup.serialize_entry("min", qual!(hdr.min()))?; + tup.serialize_entry("max", qual!(hdr.max()))?; + tup.serialize_entry("mean", qual!(hdr.mean()))?; + tup.serialize_entry("stdev", qual!(hdr.stdev()))?; + tup.serialize_entry("90%ile", ile!(0.9))?; + tup.serialize_entry("95%ile", ile!(0.95))?; + tup.serialize_entry("99%ile", ile!(0.99))?; + tup.serialize_entry("99.9%ile", ile!(0.999))?; + tup.serialize_entry("99.99%ile", ile!(0.9999))?; + tup.end() + } +} + +/// This is a mocked 'newtype' (eg. `A(u64)`) that instead allows us to +/// define our own type name that doesn't have to abide by Rust's constraints +/// on type names. This allows us to do some manipulation of our metrics, +/// allowing us to add dimensionality to our metrics via key=value pairs, or +/// key manipulation on serializers that support it. +#[cfg(feature = "hdrhistogram")] +struct MetricAlias(&'static str, T); + +#[cfg(feature = "hdrhistogram")] +impl Serialize for MetricAlias { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + serializer.serialize_newtype_struct(self.0, &self.1) + } +} + +#[cfg(feature = "hdrhistogram")] +impl std::ops::Deref for MetricsHistogram { + type Target = std::sync::Mutex>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/conn/pool/mod.rs b/src/conn/pool/mod.rs index 9fa107be..53688d9b 100644 --- a/src/conn/pool/mod.rs +++ b/src/conn/pool/mod.rs @@ -7,15 +7,14 @@ // modified, or distributed except according to those terms. use futures_util::FutureExt; -use priority_queue::PriorityQueue; +use keyed_priority_queue::KeyedPriorityQueue; use tokio::sync::mpsc; use std::{ - cmp::{Ordering, Reverse}, + borrow::Borrow, + cmp::Reverse, collections::VecDeque, - convert::TryFrom, hash::{Hash, Hasher}, - pin::Pin, str::FromStr, sync::{atomic, Arc, Mutex}, task::{Context, Poll, Waker}, @@ -26,12 +25,15 @@ use crate::{ conn::{pool::futures::*, Conn}, error::*, opts::{Opts, PoolOpts}, - queryable::transaction::{Transaction, TxOpts, TxStatus}, + queryable::transaction::{Transaction, TxOpts}, }; +pub use metrics::Metrics; + mod recycler; // this is a really unfortunate name for a module pub mod futures; +mod metrics; mod ttl_check_inerval; /// Connection that is idling in the pool. @@ -44,6 +46,15 @@ struct IdlingConn { } impl IdlingConn { + /// Returns true when this connection has a TTL and it elapsed. + fn expired(&self) -> bool { + self.conn + .inner + .ttl_deadline + .map(|t| Instant::now() > t) + .unwrap_or_default() + } + /// Returns duration elapsed since this connection is idling. fn elapsed(&self) -> Duration { self.since.elapsed() @@ -82,8 +93,11 @@ impl Exchange { // Spawn the Recycler. tokio::spawn(Recycler::new(pool_opts.clone(), inner.clone(), dropped)); - // Spawn the ttl check interval if `inactive_connection_ttl` isn't `0` - if pool_opts.inactive_connection_ttl() > Duration::from_secs(0) { + // Spawn the ttl check interval if `inactive_connection_ttl` isn't `0` or + // connections have an absolute TTL. + if pool_opts.inactive_connection_ttl() > Duration::ZERO + || pool_opts.abs_conn_ttl().is_some() + { tokio::spawn(TtlCheckInterval::new(pool_opts, inner.clone())); } } @@ -92,37 +106,42 @@ impl Exchange { #[derive(Default, Debug)] struct Waitlist { - queue: PriorityQueue, + queue: KeyedPriorityQueue, } impl Waitlist { - fn push(&mut self, w: Waker, queue_id: QueueId) { - self.queue.push( - QueuedWaker { - queue_id, - waker: Some(w), - }, - queue_id, - ); + /// Returns `true` if pushed. + fn push(&mut self, waker: Waker, queue_id: QueueId) -> bool { + // The documentation of Future::poll says: + // Note that on multiple calls to poll, only the Waker from + // the Context passed to the most recent call should be + // scheduled to receive a wakeup. + // + // But the the documentation of KeyedPriorityQueue::push says: + // Adds new element to queue if missing key or replace its + // priority if key exists. In second case doesn’t replace key. + // + // This means we have to remove first to have the most recent + // waker in the queue. + let occupied = self.remove(queue_id); + self.queue.push(QueuedWaker { queue_id, waker }, queue_id); + !occupied } fn pop(&mut self) -> Option { match self.queue.pop() { - Some((qw, _)) => Some(qw.waker.unwrap()), + Some((qw, _)) => Some(qw.waker), None => None, } } - fn remove(&mut self, id: QueueId) { - let tmp = QueuedWaker { - queue_id: id, - waker: None, - }; - self.queue.remove(&tmp); + /// Returns `true` if removed. + fn remove(&mut self, id: QueueId) -> bool { + self.queue.remove(&id).is_some() } - fn is_empty(&self) -> bool { - self.queue.is_empty() + fn peek_id(&mut self) -> Option { + self.queue.peek().map(|(qw, _)| qw.queue_id) } } @@ -142,26 +161,20 @@ impl QueueId { #[derive(Debug)] struct QueuedWaker { queue_id: QueueId, - waker: Option, + waker: Waker, } impl Eq for QueuedWaker {} -impl PartialEq for QueuedWaker { - fn eq(&self, other: &Self) -> bool { - self.queue_id == other.queue_id +impl Borrow for QueuedWaker { + fn borrow(&self) -> &QueueId { + &self.queue_id } } -impl Ord for QueuedWaker { - fn cmp(&self, other: &Self) -> Ordering { - self.queue_id.cmp(&other.queue_id) - } -} - -impl PartialOrd for QueuedWaker { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) +impl PartialEq for QueuedWaker { + fn eq(&self, other: &Self) -> bool { + self.queue_id == other.queue_id } } @@ -174,6 +187,7 @@ impl Hash for QueuedWaker { /// Connection pool data. #[derive(Debug)] pub struct Inner { + metrics: Arc, close: atomic::AtomicBool, closed: atomic::AtomicBool, exchange: Mutex, @@ -213,6 +227,7 @@ impl Pool { inner: Arc::new(Inner { close: false.into(), closed: false.into(), + metrics: Arc::new(Metrics::default()), exchange: Mutex::new(Exchange { available: VecDeque::with_capacity(pool_opts.constraints().max()), waiting: Waitlist::default(), @@ -224,6 +239,11 @@ impl Pool { } } + /// Returns metrics for the connection pool. + pub fn metrics(&self) -> Arc { + self.inner.metrics.clone() + } + /// Creates a new pool of connections. pub fn from_url>(url: T) -> Result { let opts = Opts::from_str(url.as_ref())?; @@ -232,7 +252,8 @@ impl Pool { /// Async function that resolves to `Conn`. pub fn get_conn(&self) -> GetConn { - GetConn::new(self) + let reset_connection = self.opts.pool_opts().reset_connection(); + GetConn::new(self, reset_connection) } /// Starts a new transaction. @@ -253,25 +274,6 @@ impl Pool { fn return_conn(&mut self, conn: Conn) { // NOTE: we're not in async context here, so we can't block or return NotReady // any and all cleanup work _has_ to be done in the spawned recycler - - // fast-path for when the connection is immediately ready to be reused - if conn.inner.stream.is_some() - && !conn.inner.disconnected - && !conn.expired() - && conn.inner.tx_status == TxStatus::None - && !conn.has_pending_result() - && !self.inner.close.load(atomic::Ordering::Acquire) - { - let mut exchange = self.inner.exchange.lock().unwrap(); - if exchange.available.len() < self.opts.pool_opts().active_bound() { - exchange.available.push_back(conn.into()); - if let Some(w) = exchange.waiting.pop() { - w.wake(); - } - return; - } - } - self.send_to_recycler(conn); } @@ -296,9 +298,17 @@ impl Pool { /// /// Decreases the exist counter since a broken or dropped connection should not count towards /// the total. - fn cancel_connection(&self) { + pub(super) fn cancel_connection(&self) { let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= 1; + self.inner + .metrics + .create_failed + .fetch_add(1, atomic::Ordering::Relaxed); + self.inner + .metrics + .connection_count + .store(exchange.exist, atomic::Ordering::Relaxed); // we just enabled the creation of a new connection! if let Some(w) = exchange.waiting.pop() { w.wake(); @@ -307,18 +317,8 @@ impl Pool { /// Poll the pool for an available connection. fn poll_new_conn( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - queued: bool, - queue_id: QueueId, - ) -> Poll> { - self.poll_new_conn_inner(cx, queued, queue_id) - } - - fn poll_new_conn_inner( - self: Pin<&mut Self>, + &mut self, cx: &mut Context<'_>, - queued: bool, queue_id: QueueId, ) -> Poll> { let mut exchange = self.inner.exchange.lock().unwrap(); @@ -332,17 +332,48 @@ impl Pool { exchange.spawn_futures_if_needed(&self.inner); - // Check if others are waiting and we're not queued. - if !exchange.waiting.is_empty() && !queued { - exchange.waiting.push(cx.waker().clone(), queue_id); + // Check if we are higher priority than anything current + let highest = if let Some(cur) = exchange.waiting.peek_id() { + queue_id > cur + } else { + true + }; + + // If we are not, just queue + if !highest { + if exchange.waiting.push(cx.waker().clone(), queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_add(1, atomic::Ordering::Relaxed); + } return Poll::Pending; } - while let Some(IdlingConn { mut conn, .. }) = exchange.available.pop_back() { + #[allow(unused_variables)] // `since` is only used when `hdrhistogram` is enabled + while let Some(IdlingConn { mut conn, since }) = exchange.available.pop_back() { if !conn.expired() { + #[cfg(feature = "hdrhistogram")] + self.inner + .metrics + .connection_idle_duration + .lock() + .unwrap() + .saturating_record(since.elapsed().as_micros() as u64); + #[cfg(feature = "hdrhistogram")] + let metrics = self.metrics(); + conn.inner.active_since = Instant::now(); return Poll::Ready(Ok(GetConnInner::Checking( async move { conn.stream_mut()?.check().await?; + #[cfg(feature = "hdrhistogram")] + metrics + .check_duration + .lock() + .unwrap() + .saturating_record( + conn.inner.active_since.elapsed().as_micros() as u64 + ); Ok(conn) } .boxed(), @@ -352,25 +383,63 @@ impl Pool { } } + self.inner + .metrics + .connections_in_pool + .store(exchange.available.len(), atomic::Ordering::Relaxed); + // we didn't _immediately_ get one -- try to make one // we first try to just do a load so we don't do an unnecessary add then sub if exchange.exist < self.opts.pool_opts().constraints().max() { // we are allowed to make a new connection, so we will! exchange.exist += 1; + self.inner + .metrics + .connection_count + .store(exchange.exist, atomic::Ordering::Relaxed); + + let opts = self.opts.clone(); + #[cfg(feature = "hdrhistogram")] + let metrics = self.metrics(); + return Poll::Ready(Ok(GetConnInner::Connecting( - Conn::new(self.opts.clone()).boxed(), + async move { + let conn = Conn::new(opts).await; + #[cfg(feature = "hdrhistogram")] + if let Ok(conn) = &conn { + metrics + .connect_duration + .lock() + .unwrap() + .saturating_record( + conn.inner.active_since.elapsed().as_micros() as u64 + ); + } + conn + } + .boxed(), ))); } // Polled, but no conn available? Back into the queue. - exchange.waiting.push(cx.waker().clone(), queue_id); + if exchange.waiting.push(cx.waker().clone(), queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_add(1, atomic::Ordering::Relaxed); + } Poll::Pending } fn unqueue(&self, queue_id: QueueId) { let mut exchange = self.inner.exchange.lock().unwrap(); - exchange.waiting.remove(queue_id); + if exchange.waiting.remove(queue_id) { + self.inner + .metrics + .active_wait_requests + .fetch_sub(1, atomic::Ordering::Relaxed); + } } } @@ -398,15 +467,18 @@ impl Drop for Conn { #[cfg(test)] mod test { use futures_util::{ - future::{join_all, select, select_all, try_join_all}, - try_join, FutureExt, + future::{join_all, select, select_all, try_join_all, Either}, + poll, try_join, FutureExt, }; - use mysql_common::row::Row; use tokio::time::{sleep, timeout}; + use waker_fn::waker_fn; use std::{ cmp::Reverse, - task::{RawWaker, RawWakerVTable, Waker}, + future::Future, + pin::pin, + sync::{Arc, OnceLock}, + task::{Context, Poll, RawWaker, RawWakerVTable, Waker}, time::Duration, }; @@ -415,7 +487,7 @@ mod test { opts::PoolOpts, prelude::*, test_misc::get_opts, - PoolConstraints, TxOpts, + PoolConstraints, Row, TxOpts, Value, }; macro_rules! conn_ex_field { @@ -430,6 +502,61 @@ mod test { }; } + fn pool_with_one_connection() -> Pool { + let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap()); + let opts = get_opts().pool_opts(pool_opts.clone()); + Pool::new(opts) + } + + #[tokio::test] + async fn should_opt_out_of_connection_reset() -> super::Result<()> { + let pool_opts = PoolOpts::new().with_constraints(PoolConstraints::new(1, 1).unwrap()); + let opts = get_opts().pool_opts(pool_opts.clone()); + + let pool = Pool::new(opts.clone()); + + let mut conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + conn.query_drop("SET @foo = 'foo'").await?; + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + Value::NULL + ); + conn.query_drop("SET @foo = 'foo'").await?; + conn.reset_connection(false); + drop(conn); + + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + pool.disconnect().await.unwrap(); + + let pool = Pool::new(opts.pool_opts(pool_opts.with_reset_connection(false))); + conn = pool.get_conn().await.unwrap(); + conn.query_drop("SET @foo = 'foo'").await?; + drop(conn); + conn = pool.get_conn().await.unwrap(); + assert_eq!( + conn.query_first::("SELECT @foo").await?.unwrap(), + "foo", + ); + drop(conn); + pool.disconnect().await + } + #[test] fn should_not_hang() -> super::Result<()> { pub struct Database { @@ -461,7 +588,7 @@ mod test { #[tokio::test] async fn should_connect() -> super::Result<()> { - let pool = Pool::new(get_opts()); + let pool = Pool::new(crate::Opts::from(get_opts())); pool.get_conn().await?.ping().await?; pool.disconnect().await?; Ok(()) @@ -492,6 +619,9 @@ mod test { .map(|conn| conn.id()) .collect::>(); + // give some time to reset connections + sleep(Duration::from_millis(1000)).await; + // get_conn should work if connection is available and alive pool.get_conn().await?; @@ -526,10 +656,7 @@ mod test { #[tokio::test] async fn should_reuse_connections() -> super::Result<()> { - let constraints = PoolConstraints::new(1, 1).unwrap(); - let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints)); - - let pool = Pool::new(opts); + let pool = pool_with_one_connection(); let mut conn = pool.get_conn().await?; let server_version = conn.server_version(); @@ -561,32 +688,38 @@ mod test { } drop(tx); // see that all the tx's eventually complete - while let Some(_) = rx.recv().await {} + while (rx.recv().await).is_some() {} } drop(pool); } #[tokio::test] async fn should_start_transaction() -> super::Result<()> { - let constraints = PoolConstraints::new(1, 1).unwrap(); - let opts = get_opts().pool_opts(PoolOpts::default().with_constraints(constraints)); - - let pool = Pool::new(opts); + let pool = pool_with_one_connection(); - "CREATE TEMPORARY TABLE tmp(id int)".ignore(&pool).await?; + "CREATE TABLE IF NOT EXISTS mysql.tmp(id int)" + .ignore(&pool) + .await?; + "DELETE FROM mysql.tmp".ignore(&pool).await?; let mut tx = pool.start_transaction(TxOpts::default()).await?; - tx.exec_batch("INSERT INTO tmp (id) VALUES (?)", vec![(1_u8,), (2_u8,)]) - .await?; - tx.exec_drop("SELECT * FROM tmp", ()).await?; + tx.exec_batch( + "INSERT INTO mysql.tmp (id) VALUES (?)", + vec![(1_u8,), (2_u8,)], + ) + .await?; + tx.exec_drop("SELECT * FROM mysql.tmp", ()).await?; drop(tx); let row_opt = pool .get_conn() .await? - .query_first("SELECT COUNT(*) FROM tmp") + .query_first("SELECT COUNT(*) FROM mysql.tmp") .await?; assert_eq!(row_opt, Some((0u8,))); - pool.get_conn().await?.query_drop("DROP TABLE tmp").await?; + pool.get_conn() + .await? + .query_drop("DROP TABLE mysql.tmp") + .await?; pool.disconnect().await?; Ok(()) } @@ -657,7 +790,7 @@ mod test { let _ = conns.pop(); // then, wait for a bit to let the connection be reclaimed - sleep(Duration::from_millis(50)).await; + sleep(Duration::from_millis(500)).await; // now check that we have the expected # of connections // this may look a little funky, but think of it this way: @@ -855,10 +988,7 @@ mod test { #[tokio::test] async fn should_ignore_non_fatal_errors_while_returning_to_a_pool() -> super::Result<()> { - let pool_constraints = PoolConstraints::new(1, 1).unwrap(); - let pool_opts = PoolOpts::default().with_constraints(pool_constraints); - - let pool = Pool::new(get_opts().pool_opts(pool_opts)); + let pool = pool_with_one_connection(); let id = pool.get_conn().await?.id(); // non-fatal errors are ignored @@ -873,10 +1003,7 @@ mod test { #[tokio::test] async fn should_remove_waker_of_cancelled_task() { - let pool_constraints = PoolConstraints::new(1, 1).unwrap(); - let pool_opts = PoolOpts::default().with_constraints(pool_constraints); - - let pool = Pool::new(get_opts().pool_opts(pool_opts)); + let pool = pool_with_one_connection(); let only_conn = pool.get_conn().await.unwrap(); let join_handle = tokio::spawn(timeout(Duration::from_secs(1), pool.get_conn())); @@ -890,6 +1017,13 @@ mod test { drop(only_conn); assert_eq!(0, pool.inner.exchange.lock().unwrap().waiting.queue.len()); + // metrics should catch up with waiting queue (see #335) + assert_eq!( + 0, + pool.metrics() + .active_wait_requests + .load(std::sync::atomic::Ordering::Relaxed) + ); } #[tokio::test] @@ -970,6 +1104,146 @@ mod test { assert_eq!(0, waitlist.queue.len()); } + #[tokio::test] + async fn check_absolute_connection_ttl() -> super::Result<()> { + let constraints = PoolConstraints::new(1, 3).unwrap(); + let pool_opts = PoolOpts::default() + .with_constraints(constraints) + .with_inactive_connection_ttl(Duration::from_secs(99)) + .with_ttl_check_interval(Duration::from_secs(1)) + .with_abs_conn_ttl(Some(Duration::from_secs(2))); + + let pool = Pool::new(get_opts().pool_opts(pool_opts)); + + let conn_ttl0 = pool.get_conn().await?; + sleep(Duration::from_millis(1000)).await; + let conn_ttl1 = pool.get_conn().await?; + sleep(Duration::from_millis(1000)).await; + let conn_ttl2 = pool.get_conn().await?; + + drop(conn_ttl0); + drop(conn_ttl1); + drop(conn_ttl2); + assert_eq!(ex_field!(pool, exist), 3); + + sleep(Duration::from_millis(1500)).await; + assert_eq!(ex_field!(pool, exist), 2); + + sleep(Duration::from_millis(1000)).await; + assert_eq!(ex_field!(pool, exist), 1); + + // Go even below min pool size. + sleep(Duration::from_millis(1000)).await; + assert_eq!(ex_field!(pool, exist), 0); + + Ok(()) + } + + #[tokio::test] + async fn save_last_waker() { + // Test that if passed multiple wakers, we call the last one. + + let pool = pool_with_one_connection(); + + // Get a connection, so we know the next future will be + // queued. + let conn = pool.get_conn().await.unwrap(); + let mut pending_fut = pin!(pool.get_conn()); + + let build_waker = || { + let called = Arc::new(OnceLock::new()); + let called2 = called.clone(); + let waker = waker_fn(move || called2.set(()).unwrap()); + (called, waker) + }; + + let mut assert_pending = |waker| { + let mut context = Context::from_waker(&waker); + let p = pending_fut.as_mut().poll(&mut context); + assert!(matches!(p, Poll::Pending)); + }; + + let (first_called, waker) = build_waker(); + assert_pending(waker); + + let (second_called, waker) = build_waker(); + assert_pending(waker); + + drop(conn); + + while second_called.get().is_none() { + assert!(first_called.get().is_none()); + tokio::time::sleep(Duration::from_millis(100)).await; + } + + assert!(first_called.get().is_none()); + } + + #[tokio::test] + async fn check_priorities() -> super::Result<()> { + let pool = pool_with_one_connection(); + + let queue_len = || { + let exchange = pool.inner.exchange.lock().unwrap(); + exchange.waiting.queue.len() + }; + + // Get a connection, so we know the next futures will be + // queued. + let conn = pool.get_conn().await.unwrap(); + + #[allow(clippy::async_yields_async)] + let get_pending = || async { + let fut = async { + pool.get_conn().await.unwrap(); + } + .shared(); + let p = poll!(fut.clone()); + assert!(matches!(p, Poll::Pending)); + fut + }; + + let fut1 = get_pending().await; + let fut2 = get_pending().await; + + // Both futures are queued + assert_eq!(queue_len(), 2); + + drop(conn); // This will pop fut1 from the queue, making it [2] + while queue_len() != 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + + // We called wake on fut1, and even with the select fut1 will + // resolve first + let Either::Right((_, fut2)) = select(fut2, fut1).await else { + panic!("wrong future"); + }; + + // We dropped the connection of fut1, but very likely hasn't + // made it through the recycler yet. + assert_eq!(queue_len(), 1); + + let p = poll!(fut2.clone()); + assert!(matches!(p, Poll::Pending)); + assert_eq!(queue_len(), 1); // The queue still has fut2 + + // The connection will pass by the recycler and unblock fut2 + // and pop it from the queue. + fut2.await; + assert_eq!(queue_len(), 0); + + // The recycler is probably not done, so a new future will be + // pending. + let fut3 = get_pending().await; + assert_eq!(queue_len(), 1); + + // It is OK to await it. + fut3.await; + + Ok(()) + } + #[cfg(feature = "nightly")] mod bench { use futures_util::future::{FutureExt, TryFutureExt}; diff --git a/src/conn/pool/recycler.rs b/src/conn/pool/recycler.rs index 2a704dbc..2809dc0b 100644 --- a/src/conn/pool/recycler.rs +++ b/src/conn/pool/recycler.rs @@ -28,6 +28,7 @@ pub(crate) struct Recycler { discard: FuturesUnordered>, discarded: usize, cleaning: FuturesUnordered>, + reset: FuturesUnordered>, // Option so that we have a way to send a "I didn't make a Conn after all" signal dropped: mpsc::UnboundedReceiver>, @@ -47,6 +48,7 @@ impl Recycler { discard: FuturesUnordered::new(), discarded: 0, cleaning: FuturesUnordered::new(), + reset: FuturesUnordered::new(), dropped, pool_opts, eof: false, @@ -60,26 +62,77 @@ impl Future for Recycler { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut close = self.inner.close.load(Ordering::Acquire); + macro_rules! conn_return { + ($self:ident, $conn:ident, $pool_is_closed: expr) => {{ + let mut exchange = $self.inner.exchange.lock().unwrap(); + if $pool_is_closed || exchange.available.len() >= $self.pool_opts.active_bound() { + drop(exchange); + $self + .inner + .metrics + .discarded_superfluous_connection + .fetch_add(1, Ordering::Relaxed); + $self.discard.push($conn.close_conn().boxed()); + } else { + $self + .inner + .metrics + .connection_returned_to_pool + .fetch_add(1, Ordering::Relaxed); + #[cfg(feature = "hdrhistogram")] + $self + .inner + .metrics + .connection_active_duration + .lock() + .unwrap() + .saturating_record($conn.inner.active_since.elapsed().as_micros() as u64); + exchange.available.push_back($conn.into()); + $self + .inner + .metrics + .connections_in_pool + .store(exchange.available.len(), Ordering::Relaxed); + if let Some(w) = exchange.waiting.pop() { + w.wake(); + } + } + }}; + } + macro_rules! conn_decision { ($self:ident, $conn:ident) => { if $conn.inner.stream.is_none() || $conn.inner.disconnected { // drop unestablished connection + $self + .inner + .metrics + .discarded_unestablished_connection + .fetch_add(1, Ordering::Relaxed); $self.discard.push(futures_util::future::ok(()).boxed()); } else if $conn.inner.tx_status != TxStatus::None || $conn.has_pending_result() { + $self + .inner + .metrics + .dirty_connection_return + .fetch_add(1, Ordering::Relaxed); $self.cleaning.push($conn.cleanup_for_pool().boxed()); } else if $conn.expired() || close { + $self + .inner + .metrics + .discarded_expired_connection + .fetch_add(1, Ordering::Relaxed); $self.discard.push($conn.close_conn().boxed()); + } else if $conn.inner.reset_upon_returning_to_a_pool { + $self + .inner + .metrics + .resetting_connection + .fetch_add(1, Ordering::Relaxed); + $self.reset.push($conn.reset_for_pool().boxed()); } else { - let mut exchange = $self.inner.exchange.lock().unwrap(); - if exchange.available.len() >= $self.pool_opts.active_bound() { - drop(exchange); - $self.discard.push($conn.close_conn().boxed()); - } else { - exchange.available.push_back($conn.into()); - if let Some(w) = exchange.waiting.pop() { - w.wake(); - } - } + conn_return!($self, $conn, false); } }; } @@ -132,6 +185,29 @@ impl Future for Recycler { // anything that comes through .dropped we know has .pool.is_none(). // therefore, dropping the conn won't decrement .exist, so we need to do that. self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); + // NOTE: we're discarding the error here + let _ = e; + } + } + } + + // let's iterate through connections being successfully reset + loop { + match Pin::new(&mut self.reset).poll_next(cx) { + Poll::Pending | Poll::Ready(None) => break, + Poll::Ready(Some(Ok(conn))) => conn_return!(self, conn, close), + Poll::Ready(Some(Err(e))) => { + // an error during reset. + // replace with a new connection + self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); // NOTE: we're discarding the error here let _ = e; } @@ -152,6 +228,10 @@ impl Future for Recycler { // an error occurred while closing a connection. // what do we do? we still replace it with a new connection.. self.discarded += 1; + self.inner + .metrics + .discarded_error_during_cleanup + .fetch_add(1, Ordering::Relaxed); // NOTE: we're discarding the error here let _ = e; } @@ -162,6 +242,10 @@ impl Future for Recycler { // we need to open up slots for new connctions to be established! let mut exchange = self.inner.exchange.lock().unwrap(); exchange.exist -= self.discarded; + self.inner + .metrics + .connection_count + .store(exchange.exist, Ordering::Relaxed); for _ in 0..self.discarded { if let Some(w) = exchange.waiting.pop() { w.wake(); @@ -176,7 +260,11 @@ impl Future for Recycler { // races on .exist let effectively_eof = close && self.inner.exchange.lock().unwrap().exist == 0; - if (self.eof || effectively_eof) && self.cleaning.is_empty() && self.discard.is_empty() { + if (self.eof || effectively_eof) + && self.cleaning.is_empty() + && self.discard.is_empty() + && self.reset.is_empty() + { // we know that all Pool handles have been dropped (self.dropped.poll returned None). // if this assertion fails, where are the remaining connections? diff --git a/src/conn/pool/ttl_check_inerval.rs b/src/conn/pool/ttl_check_inerval.rs index 0cb4f5f4..2686e612 100644 --- a/src/conn/pool/ttl_check_inerval.rs +++ b/src/conn/pool/ttl_check_inerval.rs @@ -7,10 +7,10 @@ // modified, or distributed except according to those terms. use futures_util::future::{ok, FutureExt}; -use pin_project::pin_project; use tokio::time::{self, Interval}; use std::{ + collections::VecDeque, future::Future, sync::{atomic::Ordering, Arc}, }; @@ -25,10 +25,8 @@ use std::pin::Pin; /// The purpose of this interval is to remove idling connections that both: /// * overflows min bound of the pool; /// * idles longer then `inactive_connection_ttl`. -#[pin_project] pub(crate) struct TtlCheckInterval { inner: Arc, - #[pin] interval: Interval, pool_opts: PoolOpts, } @@ -46,24 +44,48 @@ impl TtlCheckInterval { /// Perform the check. pub fn check_ttl(&self) { - let mut exchange = self.inner.exchange.lock().unwrap(); + let to_be_dropped = { + let mut exchange = self.inner.exchange.lock().unwrap(); - let num_idling = exchange.available.len(); - let num_to_drop = num_idling.saturating_sub(self.pool_opts.constraints().min()); + let num_to_drop = exchange + .available + .len() + .saturating_sub(self.pool_opts.constraints().min()); - for _ in 0..num_to_drop { - let idling_conn = exchange.available.pop_front().unwrap(); - if idling_conn.elapsed() > self.pool_opts.inactive_connection_ttl() { - assert!(idling_conn.conn.inner.pool.is_none()); - let inner = self.inner.clone(); - tokio::spawn(idling_conn.conn.disconnect().then(move |_| { - let mut exchange = inner.exchange.lock().unwrap(); - exchange.exist -= 1; - ok::<_, ()>(()) - })); - } else { - exchange.available.push_back(idling_conn); + let mut to_be_dropped = Vec::<_>::with_capacity(exchange.available.len()); + let mut kept_available = + VecDeque::<_>::with_capacity(self.pool_opts.constraints().max()); + + while let Some(conn) = exchange.available.pop_front() { + if conn.expired() + || (to_be_dropped.len() < num_to_drop + && conn.elapsed() > self.pool_opts.inactive_connection_ttl()) + { + to_be_dropped.push(conn); + } else { + kept_available.push_back(conn); + } } + exchange.available = kept_available; + self.inner + .metrics + .connections_in_pool + .store(exchange.available.len(), Ordering::Relaxed); + to_be_dropped + }; + + for idling_conn in to_be_dropped { + assert!(idling_conn.conn.inner.pool.is_none()); + let inner = self.inner.clone(); + tokio::spawn(idling_conn.conn.disconnect().then(move |_| { + let mut exchange = inner.exchange.lock().unwrap(); + exchange.exist -= 1; + inner + .metrics + .connection_count + .store(exchange.exist, Ordering::Relaxed); + ok::<_, ()>(()) + })); } } } @@ -73,7 +95,7 @@ impl Future for TtlCheckInterval { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { - let _ = futures_core::ready!(self.as_mut().project().interval.poll_tick(cx)); + let _ = futures_core::ready!(Pin::new(&mut self.interval).poll_tick(cx)); let close = self.inner.close.load(Ordering::Acquire); if !close { diff --git a/src/conn/routines/change_user.rs b/src/conn/routines/change_user.rs new file mode 100644 index 00000000..28b51d4e --- /dev/null +++ b/src/conn/routines/change_user.rs @@ -0,0 +1,58 @@ +use futures_core::future::BoxFuture; +use futures_util::FutureExt; +use mysql_common::{ + constants::{UTF8MB4_GENERAL_CI, UTF8_GENERAL_CI}, + packets::{ComChangeUser, ComChangeUserMoreData}, +}; +#[cfg(feature = "tracing")] +use tracing::debug_span; + +use crate::Conn; + +use super::Routine; + +/// A routine that performs `COM_CHANGE_USER`. +#[derive(Debug, Copy, Clone)] +pub struct ChangeUser; + +impl Routine<()> for ChangeUser { + fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<()>> { + #[cfg(feature = "tracing")] + let span = debug_span!( + "mysql_async::change_user", + mysql_async.connection.id = conn.id() + ); + + let com_change_user = ComChangeUser::new() + .with_user(conn.opts().user().map(|x| x.as_bytes())) + .with_database(conn.opts().db_name().map(|x| x.as_bytes())) + .with_auth_plugin_data( + conn.inner + .auth_plugin + .gen_data(conn.opts().pass(), &conn.inner.nonce) + .as_deref(), + ) + .with_more_data(Some( + ComChangeUserMoreData::new(if conn.inner.version >= (5, 5, 3) { + UTF8MB4_GENERAL_CI + } else { + UTF8_GENERAL_CI + }) + .with_auth_plugin(Some(conn.inner.auth_plugin.clone())) + .with_connect_attributes(None), + )) + .into_owned(); + + let fut = async move { + conn.write_command(&com_change_user).await?; + conn.inner.auth_switched = false; + conn.continue_auth().await?; + Ok(()) + }; + + #[cfg(feature = "tracing")] + let fut = instrument_result!(fut, span); + + fut.boxed() + } +} diff --git a/src/conn/routines/exec.rs b/src/conn/routines/exec.rs index 262a90c9..30b24286 100644 --- a/src/conn/routines/exec.rs +++ b/src/conn/routines/exec.rs @@ -6,7 +6,7 @@ use mysql_common::{packets::ComStmtExecuteRequestBuilder, params::Params}; #[cfg(feature = "tracing")] use tracing::{field, info_span, Level, Span}; -use crate::{BinaryProtocol, Conn, DriverError, Statement}; +use crate::{conn::MAX_STATEMENT_PARAMS, BinaryProtocol, Conn, DriverError, Statement}; use super::Routine; @@ -52,15 +52,21 @@ impl Routine<()> for ExecRoutine<'_> { Span::current().record("mysql_async.query.params", ps); } + if params.len() > MAX_STATEMENT_PARAMS { + Err(DriverError::StmtParamsNumberExceedsLimit { + supplied: params.len(), + })? + } + if self.stmt.num_params() as usize != params.len() { Err(DriverError::StmtParamsMismatch { required: self.stmt.num_params(), - supplied: params.len() as u16, + supplied: params.len(), })? } let (body, as_long_data) = - ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(&*params); + ComStmtExecuteRequestBuilder::new(self.stmt.id()).build(params); if as_long_data { conn.send_long_data(self.stmt.id(), params.iter()).await?; @@ -71,14 +77,13 @@ impl Routine<()> for ExecRoutine<'_> { break; } Params::Named(_) => { - if self.stmt.named_params.is_none() { + if self.stmt.named_params.is_empty() { let error = DriverError::NamedParamsForPositionalQuery.into(); return Err(error); } let named = mem::replace(&mut self.params, Params::Empty); - self.params = - named.into_positional(self.stmt.named_params.as_ref().unwrap())?; + self.params = named.into_positional(&self.stmt.named_params)?; continue; } diff --git a/src/conn/routines/helpers.rs b/src/conn/routines/helpers.rs index eed0f38e..fdaf5fef 100644 --- a/src/conn/routines/helpers.rs +++ b/src/conn/routines/helpers.rs @@ -65,14 +65,14 @@ impl Conn { } }; - match packet.get(0) { + match packet.first() { Some(0x00) => { self.set_pending_result(Some(P::result_set_meta(Arc::from( Vec::new().into_boxed_slice(), ))))?; } - Some(0xFB) => self.handle_local_infile::

(&*packet).await?, - _ => self.handle_result_set::

(&*packet).await?, + Some(0xFB) => self.handle_local_infile::

(&packet).await?, + _ => self.handle_result_set::

(&packet).await?, } Ok(()) @@ -98,7 +98,7 @@ impl Conn { match bytes { Ok(bytes) => { // We'll skip empty chunks to stay compliant with the protocol. - if bytes.len() > 0 { + if !bytes.is_empty() { self.write_bytes(&bytes).await?; } } diff --git a/src/conn/routines/mod.rs b/src/conn/routines/mod.rs index 928bd771..80ecf1a5 100644 --- a/src/conn/routines/mod.rs +++ b/src/conn/routines/mod.rs @@ -2,8 +2,9 @@ use futures_core::future::BoxFuture; use crate::Conn; -pub use self::{exec::*, next_set::*, ping::*, prepare::*, query::*, reset::*}; +pub use self::{change_user::*, exec::*, next_set::*, ping::*, prepare::*, query::*, reset::*}; +mod change_user; mod exec; mod next_set; mod ping; diff --git a/src/conn/routines/prepare.rs b/src/conn/routines/prepare.rs index 33970e58..1fa1d300 100644 --- a/src/conn/routines/prepare.rs +++ b/src/conn/routines/prepare.rs @@ -47,7 +47,7 @@ impl Routine> for PrepareRoutine { .await?; let packet = conn.read_packet().await?; - let mut inner_stmt = StmtInner::from_payload(&*packet, conn.id(), self.query.clone())?; + let mut inner_stmt = StmtInner::from_payload(&packet, conn.id(), self.query.clone())?; #[cfg(feature = "tracing")] Span::current().record("mysql_async.statement.id", inner_stmt.id()); diff --git a/src/conn/stmt_cache.rs b/src/conn/stmt_cache.rs index 595836ac..84a99c20 100644 --- a/src/conn/stmt_cache.rs +++ b/src/conn/stmt_cache.rs @@ -7,7 +7,7 @@ // modified, or distributed except according to those terms. use lru::LruCache; -use twox_hash::XxHash; +use twox_hash::XxHash64; use std::{ borrow::Borrow, @@ -23,13 +23,13 @@ pub struct QueryString(pub Arc<[u8]>); impl Borrow<[u8]> for QueryString { fn borrow(&self) -> &[u8] { - &*self.0.as_ref() + self.0.as_ref() } } impl PartialEq<[u8]> for QueryString { fn eq(&self, other: &[u8]) -> bool { - &*self.0.as_ref() == other + self.0.as_ref() == other } } @@ -42,7 +42,7 @@ pub struct Entry { pub struct StmtCache { cap: usize, cache: LruCache, - query_map: HashMap>, + query_map: HashMap>, } impl StmtCache { @@ -80,7 +80,7 @@ impl StmtCache { if self.cache.len() > self.cap { if let Some((_, entry)) = self.cache.pop_lru() { - self.query_map.remove(&*entry.query.0.as_ref()); + self.query_map.remove(entry.query.0.as_ref()); return Some(entry.stmt); } } diff --git a/src/connection_like/mod.rs b/src/connection_like/mod.rs index b47adfaa..d86a313d 100644 --- a/src/connection_like/mod.rs +++ b/src/connection_like/mod.rs @@ -10,9 +10,9 @@ use futures_util::FutureExt; use crate::{BoxFuture, Pool}; -/// Connection. +/// Inner [`Connection`] representation. #[derive(Debug)] -pub enum Connection<'a, 't: 'a> { +pub(crate) enum ConnectionInner<'a, 't: 'a> { /// Just a connection. Conn(crate::Conn), /// Mutable reference to a connection. @@ -21,21 +21,92 @@ pub enum Connection<'a, 't: 'a> { Tx(&'a mut crate::Transaction<'t>), } +impl std::ops::Deref for ConnectionInner<'_, '_> { + type Target = crate::Conn; + + fn deref(&self) -> &Self::Target { + match self { + ConnectionInner::Conn(ref conn) => conn, + ConnectionInner::ConnMut(conn) => conn, + ConnectionInner::Tx(tx) => tx.0.deref(), + } + } +} + +impl std::ops::DerefMut for ConnectionInner<'_, '_> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + ConnectionInner::Conn(conn) => conn, + ConnectionInner::ConnMut(conn) => conn, + ConnectionInner::Tx(tx) => tx.0.inner.deref_mut(), + } + } +} + +/// Some connection. +/// +/// This could at least be queried. +#[derive(Debug)] +pub struct Connection<'a, 't: 'a> { + pub(crate) inner: ConnectionInner<'a, 't>, +} + +impl Connection<'_, '_> { + #[inline] + pub(crate) fn as_mut(&mut self) -> &mut crate::Conn { + &mut self.inner + } +} + +impl<'a, 't: 'a> Connection<'a, 't> { + /// Borrows a [`Connection`] rather than consuming it. + /// + /// This is useful to allow calling [`Query`] methods while still retaining + /// ownership of the original connection. + /// + /// # Examples + /// + /// ```no_run + /// # use mysql_async::Connection; + /// # use mysql_async::prelude::Query; + /// async fn connection_by_ref(mut connection: Connection<'_, '_>) { + /// // Perform some query + /// "SELECT 1".ignore(connection.by_ref()).await.unwrap(); + /// // Perform another query. + /// // We can only do this because we used `by_ref` earlier. + /// "SELECT 2".ignore(connection).await.unwrap(); + /// } + /// ``` + /// + /// [`Query`]: crate::prelude::Query + pub fn by_ref(&mut self) -> Connection<'_, '_> { + Connection { + inner: ConnectionInner::ConnMut(self.as_mut()), + } + } +} + impl From for Connection<'static, 'static> { fn from(conn: crate::Conn) -> Self { - Connection::Conn(conn) + Self { + inner: ConnectionInner::Conn(conn), + } } } impl<'a> From<&'a mut crate::Conn> for Connection<'a, 'static> { fn from(conn: &'a mut crate::Conn) -> Self { - Connection::ConnMut(conn) + Self { + inner: ConnectionInner::ConnMut(conn), + } } } impl<'a, 't> From<&'a mut crate::Transaction<'t>> for Connection<'a, 't> { fn from(tx: &'a mut crate::Transaction<'t>) -> Self { - Connection::Tx(tx) + Self { + inner: ConnectionInner::Tx(tx), + } } } @@ -43,25 +114,11 @@ impl std::ops::Deref for Connection<'_, '_> { type Target = crate::Conn; fn deref(&self) -> &Self::Target { - match self { - Connection::Conn(ref conn) => conn, - Connection::ConnMut(conn) => conn, - Connection::Tx(tx) => tx.0.deref(), - } - } -} - -impl std::ops::DerefMut for Connection<'_, '_> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - Connection::Conn(conn) => conn, - Connection::ConnMut(conn) => conn, - Connection::Tx(tx) => tx.0.deref_mut(), - } + &self.inner } } -/// Result of `ToConnection::to_connection` call. +/// Result of a [`ToConnection::to_connection`] call. pub enum ToConnectionResult<'a, 't: 'a> { /// Connection is immediately available. Immediate(Connection<'a, 't>), @@ -69,7 +126,22 @@ pub enum ToConnectionResult<'a, 't: 'a> { Mediate(BoxFuture<'a, Connection<'a, 't>>), } +impl<'a, 't: 'a> ToConnectionResult<'a, 't> { + /// Resolves `self` to a connection. + #[inline] + pub async fn resolve(self) -> crate::Result> { + match self { + ToConnectionResult::Immediate(immediate) => Ok(immediate), + ToConnectionResult::Mediate(mediate) => mediate.await, + } + } +} + +/// Everything that can be given in exchange to a connection. +/// +/// Note that you could obtain a `'static` connection by giving away `Conn` or `Pool`. pub trait ToConnection<'a, 't: 'a>: Send { + /// Converts self to a connection. fn to_connection(self) -> ToConnectionResult<'a, 't>; } diff --git a/src/error/mod.rs b/src/error/mod.rs index 81087260..2f8de211 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -8,7 +8,7 @@ pub use url::ParseError; -mod tls; +pub mod tls; use mysql_common::{ named_params::MixedParamsError, params::MissingNamedParameterError, @@ -21,6 +21,10 @@ use std::{io, result}; /// Result type alias for this library. pub type Result = result::Result; +/// The maximum number of bind variables supported by MySQL. +/// https://stackoverflow.com/questions/4922345/how-many-bind-variables-can-i-use-in-a-sql-query-in-mysql-5#comment136409462_11131824 +pub(crate) const MAX_STATEMENT_PARAMS: usize = u16::MAX as usize; + /// This type enumerates library errors. #[derive(Debug, Error)] pub enum Error { @@ -109,7 +113,7 @@ pub enum DriverError { #[error("Error converting from mysql row.")] FromRow { row: Row }, - #[error("Missing named parameter `{}'.", String::from_utf8_lossy(&name))] + #[error("Missing named parameter `{}'.", String::from_utf8_lossy(name))] MissingNamedParam { name: Vec }, #[error("Named and positional parameters mixed in one statement.")] @@ -135,7 +139,14 @@ pub enum DriverError { required, supplied )] - StmtParamsMismatch { required: u16, supplied: u16 }, + StmtParamsMismatch { required: u16, supplied: usize }, + + #[error( + "MySQL supports up to {} parameters but {} was supplied.", + MAX_STATEMENT_PARAMS, + supplied + )] + StmtParamsNumberExceedsLimit { supplied: usize }, #[error("Unexpected packet.")] UnexpectedPacket { payload: Vec }, @@ -163,6 +174,9 @@ pub enum DriverError { #[error("Client asked for SSL but server does not have this capability")] NoClientSslFlagFromServer, + + #[error("mysql_clear_password must be enabled on the client side")] + CleartextPluginDisabled, } #[derive(Debug, Error)] @@ -240,7 +254,10 @@ impl From> for ServerError { ServerError { code: packet.error_code(), message: packet.message_str().into(), - state: packet.sql_state_str().into(), + state: packet + .sql_state_ref() + .map(|s| s.as_str().into_owned()) + .unwrap_or_else(|| "HY000".to_owned()), } } } diff --git a/src/error/tls/mod.rs b/src/error/tls/mod.rs index 220ed850..4048fbfe 100644 --- a/src/error/tls/mod.rs +++ b/src/error/tls/mod.rs @@ -1,9 +1,9 @@ -#![cfg(any(feature = "native-tls", feature = "rustls-tls"))] +#![cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub mod native_tls_error; pub mod rustls_error; -#[cfg(feature = "native-tls")] +#[cfg(feature = "native-tls-tls")] pub use native_tls_error::TlsError; #[cfg(feature = "rustls")] diff --git a/src/error/tls/native_tls_error.rs b/src/error/tls/native_tls_error.rs index 8ca8b6cb..6b324011 100644 --- a/src/error/tls/native_tls_error.rs +++ b/src/error/tls/native_tls_error.rs @@ -1,4 +1,4 @@ -#![cfg(feature = "native-tls")] +#![cfg(feature = "native-tls-tls")] use std::fmt::Display; diff --git a/src/error/tls/rustls_error.rs b/src/error/tls/rustls_error.rs index 2ee67d39..d88f3b18 100644 --- a/src/error/tls/rustls_error.rs +++ b/src/error/tls/rustls_error.rs @@ -2,11 +2,13 @@ use std::fmt::Display; +use rustls::server::VerifierBuilderError; + #[derive(Debug)] pub enum TlsError { Tls(rustls::Error), - Pki(webpki::Error), - InvalidDnsName(webpki::InvalidDnsNameError), + InvalidDnsName(rustls::pki_types::InvalidDnsNameError), + VerifierBuilderError(VerifierBuilderError), } impl From for crate::Error { @@ -15,48 +17,36 @@ impl From for crate::Error { } } +impl From for TlsError { + fn from(e: VerifierBuilderError) -> Self { + TlsError::VerifierBuilderError(e) + } +} + impl From for TlsError { fn from(e: rustls::Error) -> Self { TlsError::Tls(e) } } -impl From for TlsError { - fn from(e: webpki::InvalidDnsNameError) -> Self { +impl From for TlsError { + fn from(e: rustls::pki_types::InvalidDnsNameError) -> Self { TlsError::InvalidDnsName(e) } } -impl From for TlsError { - fn from(e: webpki::Error) -> Self { - TlsError::Pki(e) - } -} - impl From for crate::Error { fn from(e: rustls::Error) -> Self { crate::Error::Io(crate::error::IoError::Tls(e.into())) } } -impl From for crate::Error { - fn from(e: webpki::Error) -> Self { - crate::Error::Io(crate::error::IoError::Tls(e.into())) - } -} - -impl From for crate::Error { - fn from(e: webpki::InvalidDnsNameError) -> Self { - crate::Error::Io(crate::error::IoError::Tls(e.into())) - } -} - impl std::error::Error for TlsError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { TlsError::Tls(e) => Some(e), - TlsError::Pki(e) => Some(e), TlsError::InvalidDnsName(e) => Some(e), + TlsError::VerifierBuilderError(e) => Some(e), } } } @@ -65,8 +55,8 @@ impl Display for TlsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { TlsError::Tls(e) => e.fmt(f), - TlsError::Pki(e) => e.fmt(f), TlsError::InvalidDnsName(e) => e.fmt(f), + TlsError::VerifierBuilderError(e) => e.fmt(f), } } } diff --git a/src/io/mod.rs b/src/io/mod.rs index 6498b33e..e5705a9f 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -11,8 +11,6 @@ pub use self::{read_packet::ReadPacket, write_packet::WritePacket}; use bytes::BytesMut; use futures_core::{ready, stream}; use mysql_common::proto::codec::PacketCodec as PacketCodecInner; -use pin_project::pin_project; -use socket2::{Socket as Socket2Socket, TcpKeepalive}; #[cfg(unix)] use tokio::io::AsyncWriteExt; use tokio::{ @@ -31,6 +29,7 @@ use std::{ ErrorKind::{BrokenPipe, NotConnected, Other}, }, mem::replace, + net::SocketAddr, ops::{Deref, DerefMut}, pin::Pin, task::{Context, Poll}, @@ -40,7 +39,7 @@ use std::{ use crate::{ buffer_pool::PooledBuf, error::IoError, - opts::{HostPortOrUrl, SslOpts, DEFAULT_PORT}, + opts::{HostPortOrUrl, DEFAULT_PORT}, }; #[cfg(unix)] @@ -48,6 +47,8 @@ use crate::io::socket::Socket; mod tls; +pub(crate) use self::tls::TlsConnector; + macro_rules! with_interrupted { ($e:expr) => { loop { @@ -73,7 +74,7 @@ impl Default for PacketCodec { fn default() -> Self { Self { inner: Default::default(), - decode_buf: crate::BUFFER_POOL.get(), + decode_buf: crate::buffer_pool().get(), } } } @@ -98,7 +99,7 @@ impl Decoder for PacketCodec { fn decode(&mut self, src: &mut BytesMut) -> std::result::Result, IoError> { if self.inner.decode(src, self.decode_buf.as_mut())? { - let new_buf = crate::BUFFER_POOL.get(); + let new_buf = crate::buffer_pool().get(); Ok(Some(replace(&mut self.decode_buf, new_buf))) } else { Ok(None) @@ -114,16 +115,15 @@ impl Encoder for PacketCodec { } } -#[pin_project(project = EndpointProj)] #[derive(Debug)] pub(crate) enum Endpoint { Plain(Option), #[cfg(feature = "native-tls-tls")] - Secure(#[pin] tokio_native_tls::TlsStream), + Secure(tokio_native_tls::TlsStream), #[cfg(feature = "rustls-tls")] - Secure(#[pin] tokio_rustls::client::TlsStream), + Secure(tokio_rustls::client::TlsStream), #[cfg(unix)] - Socket(#[pin] Socket), + Socket(Socket), } /// This future will check that TcpStream is live. @@ -153,12 +153,9 @@ impl Future for CheckTcpStream<'_> { } impl Endpoint { - #[cfg(all(any(feature = "native-tls-tls", feature = "rustls-tls"), unix))] + #[cfg(unix)] fn is_socket(&self) -> bool { - match self { - Self::Socket(_) => true, - _ => false, - } + matches!(self, Self::Socket(_)) } /// Checks, that connection is alive. @@ -182,7 +179,7 @@ impl Endpoint { } #[cfg(unix)] Endpoint::Socket(socket) => { - socket.write(&[]).await?; + let _ = socket.write(&[]).await?; Ok(()) } Endpoint::Plain(None) => unreachable!(), @@ -194,18 +191,6 @@ impl Endpoint { matches!(self, Endpoint::Secure(_)) } - #[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))] - pub async fn make_secure( - &mut self, - _domain: String, - _ssl_opts: crate::SslOpts, - ) -> crate::error::Result<()> { - panic!( - "Client had asked for TLS connection but TLS support is disabled. \ - Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]" - ) - } - pub fn set_tcp_nodelay(&self, val: bool) -> io::Result<()> { match *self { Endpoint::Plain(Some(ref stream)) => stream.set_nodelay(val)?, @@ -256,17 +241,17 @@ impl AsyncRead for Endpoint { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_read(cx, buf), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_read(cx, buf), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf), + Self::Socket(stream) => Pin::new(stream).poll_read(cx, buf), }) } } @@ -277,17 +262,17 @@ impl AsyncWrite for Endpoint { cx: &mut Context, buf: &[u8], ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_write(cx, buf), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Secure(stream) => Pin::new(stream).poll_write(cx, buf), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf), + Self::Socket(stream) => Pin::new(stream).poll_write(cx, buf), }) } @@ -295,17 +280,17 @@ impl AsyncWrite for Endpoint { self: Pin<&mut Self>, cx: &mut Context, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_flush(cx) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Secure(stream) => Pin::new(stream).poll_flush(cx), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Secure(stream) => Pin::new(stream).poll_flush(cx), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx), + Self::Socket(stream) => Pin::new(stream).poll_flush(cx), }) } @@ -313,17 +298,17 @@ impl AsyncWrite for Endpoint { self: Pin<&mut Self>, cx: &mut Context, ) -> Poll> { - let mut this = self.project(); + let this = self.get_mut(); with_interrupted!(match this { - EndpointProj::Plain(ref mut stream) => { + Self::Plain(stream) => { Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx) } #[cfg(feature = "native-tls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Secure(stream) => Pin::new(stream).poll_shutdown(cx), #[cfg(feature = "rustls-tls")] - EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Secure(stream) => Pin::new(stream).poll_shutdown(cx), #[cfg(unix)] - EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx), + Self::Socket(stream) => Pin::new(stream).poll_shutdown(cx), }) } } @@ -360,30 +345,30 @@ impl Stream { keepalive: Option, ) -> io::Result { let tcp_stream = match addr { - HostPortOrUrl::HostPort(host, port) => { - TcpStream::connect((host.as_str(), *port)).await? - } + HostPortOrUrl::HostPort { + host, + port, + resolved_ips, + } => match resolved_ips { + Some(ips) => { + let addrs = ips + .iter() + .map(|ip| SocketAddr::new(*ip, *port)) + .collect::>(); + TcpStream::connect(&*addrs).await? + } + None => TcpStream::connect((host.as_str(), *port)).await?, + }, HostPortOrUrl::Url(url) => { let addrs = url.socket_addrs(|| Some(DEFAULT_PORT))?; TcpStream::connect(&*addrs).await? } }; + #[cfg(any(unix, windows))] if let Some(duration) = keepalive { - #[cfg(unix)] - let socket = { - use std::os::unix::prelude::*; - let fd = tcp_stream.as_raw_fd(); - unsafe { Socket2Socket::from_raw_fd(fd) } - }; - #[cfg(windows)] - let socket = { - use std::os::windows::prelude::*; - let sock = tcp_stream.as_raw_socket(); - unsafe { Socket2Socket::from_raw_socket(sock) } - }; - socket.set_tcp_keepalive(&TcpKeepalive::new().with_time(duration))?; - std::mem::forget(socket); + socket2::SockRef::from(&tcp_stream) + .set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(duration))?; } Ok(Stream { @@ -404,11 +389,11 @@ impl Stream { pub(crate) async fn make_secure( &mut self, domain: String, - ssl_opts: SslOpts, + tls_connector: &TlsConnector, ) -> crate::error::Result<()> { let codec = self.codec.take().unwrap(); let FramedParts { mut io, codec, .. } = codec.into_parts(); - io.make_secure(domain, ssl_opts).await?; + io.make_secure(domain, tls_connector).await?; let codec = Framed::new(io, codec); self.codec = Some(Box::new(codec)); Ok(()) @@ -419,6 +404,11 @@ impl Stream { self.codec.as_ref().unwrap().get_ref().is_secure() } + #[cfg(unix)] + pub(crate) fn is_socket(&self) -> bool { + self.codec.as_ref().unwrap().get_ref().is_socket() + } + pub(crate) fn reset_seq_id(&mut self) { if let Some(codec) = self.codec.as_mut() { codec.codec_mut().reset_seq_id(); @@ -455,7 +445,7 @@ impl Stream { self.closed = true; if let Some(mut codec) = self.codec { use futures_sink::Sink; - futures_util::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) { + std::future::poll_fn(|cx| match Pin::new(&mut *codec).poll_close(cx) { Poll::Ready(Err(IoError::Io(err))) if err.kind() == NotConnected => { Poll::Ready(Ok(())) } @@ -497,7 +487,7 @@ mod test { super::Endpoint::Plain(Some(stream)) => stream, #[cfg(feature = "rustls-tls")] super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().0, - #[cfg(feature = "native-tls")] + #[cfg(feature = "native-tls-tls")] super::Endpoint::Secure(tls_stream) => tls_stream.get_ref().get_ref().get_ref(), _ => unreachable!(), }; diff --git a/src/io/read_packet.rs b/src/io/read_packet.rs index 7e14fca0..62cbe76e 100644 --- a/src/io/read_packet.rs +++ b/src/io/read_packet.rs @@ -15,7 +15,7 @@ use std::{ task::{Context, Poll}, }; -use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError, Conn}; +use crate::{buffer_pool::PooledBuf, connection_like::Connection, error::IoError}; /// Reads a packet. #[derive(Debug)] @@ -27,8 +27,9 @@ impl<'a, 't> ReadPacket<'a, 't> { Self(conn.into()) } - pub(crate) fn conn_ref(&self) -> &Conn { - &*self.0 + #[cfg(feature = "binlog")] + pub(crate) fn conn_ref(&self) -> &crate::Conn { + &self.0 } } @@ -36,7 +37,7 @@ impl Future for ReadPacket<'_, '_> { type Output = std::result::Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let packet_opt = match self.0.stream_mut() { + let packet_opt = match self.0.as_mut().stream_mut() { Ok(stream) => ready!(Pin::new(stream).poll_next(cx)).transpose()?, // `ConnectionClosed` error. Err(_) => None, @@ -44,7 +45,7 @@ impl Future for ReadPacket<'_, '_> { match packet_opt { Some(packet) => { - self.0.touch(); + self.0.as_mut().touch(); Poll::Ready(Ok(packet)) } None => Poll::Ready(Err(Error::new( diff --git a/src/io/socket.rs b/src/io/socket.rs index 42aae58f..c8a4aaec 100644 --- a/src/io/socket.rs +++ b/src/io/socket.rs @@ -8,7 +8,6 @@ #![cfg(unix)] -use pin_project::pin_project; use tokio::io::{Error, ErrorKind::Interrupted, ReadBuf}; use std::{ @@ -20,10 +19,8 @@ use std::{ use tokio::io::{AsyncRead, AsyncWrite}; /// Unix domain socket connection on unix, or named pipe connection on windows. -#[pin_project] #[derive(Debug)] pub(crate) struct Socket { - #[pin] #[cfg(unix)] inner: tokio::net::UnixStream, } @@ -40,32 +37,28 @@ impl Socket { impl AsyncRead for Socket { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_read(cx, buf)) + with_interrupted!(Pin::new(&mut self.inner).poll_read(cx, buf)) } } impl AsyncWrite for Socket { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_write(cx, buf)) + with_interrupted!(Pin::new(&mut self.inner).poll_write(cx, buf)) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_flush(cx)) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + with_interrupted!(Pin::new(&mut self.inner).poll_flush(cx)) } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - with_interrupted!(this.inner.as_mut().poll_shutdown(cx)) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + with_interrupted!(Pin::new(&mut self.inner).poll_shutdown(cx)) } } diff --git a/src/io/tls/mod.rs b/src/io/tls/mod.rs index 92f5e7c2..ca2f71e2 100644 --- a/src/io/tls/mod.rs +++ b/src/io/tls/mod.rs @@ -1,4 +1,13 @@ -#![cfg(any(feature = "native-tls", feature = "rustls"))] - +#[cfg(feature = "native-tls-tls")] mod native_tls_io; +#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))] +mod no_tls; +#[cfg(feature = "rustls-tls")] mod rustls_io; + +#[cfg(feature = "native-tls-tls")] +pub(crate) use self::native_tls_io::TlsConnector; +#[cfg(not(any(feature = "rustls-tls", feature = "native-tls-tls")))] +pub(crate) use self::no_tls::TlsConnector; +#[cfg(feature = "rustls-tls")] +pub(crate) use self::rustls_io::TlsConnector; diff --git a/src/io/tls/native_tls_io.rs b/src/io/tls/native_tls_io.rs index 910387d7..e085be7c 100644 --- a/src/io/tls/native_tls_io.rs +++ b/src/io/tls/native_tls_io.rs @@ -1,58 +1,55 @@ -#![cfg(feature = "native-tls")] - -use std::{fs::File, io::Read}; - -use native_tls::{Certificate, Identity, TlsConnector}; +use tokio_native_tls::native_tls::{self, Certificate}; use crate::io::Endpoint; use crate::{Result, SslOpts}; -impl Endpoint { - pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { - #[cfg(unix)] - if self.is_socket() { - // won't secure socket connection - return Ok(()); - } +pub use tokio_native_tls::TlsConnector; - let mut builder = TlsConnector::builder(); - if let Some(root_cert_path) = ssl_opts.root_cert_path() { - let mut root_cert_data = vec![]; - let mut root_cert_file = File::open(root_cert_path)?; - root_cert_file.read_to_end(&mut root_cert_data)?; +impl SslOpts { + async fn load_root_certs(&self) -> crate::Result> { + let mut output = Vec::new(); - let root_certs = Certificate::from_der(&*root_cert_data) - .map(|x| vec![x]) - .or_else(|_| { - pem::parse_many(&*root_cert_data) - .unwrap_or_default() - .iter() - .map(pem::encode) - .map(|s| Certificate::from_pem(s.as_bytes())) - .collect() - })?; + for root_cert in self.root_certs() { + let root_cert_data = root_cert.read().await?; + output.extend(parse_certs(root_cert_data.as_ref())?); + } - for root_cert in root_certs { - builder.add_root_certificate(root_cert); - } + Ok(output) + } + + pub(crate) async fn build_tls_connector(&self) -> Result { + let mut builder = native_tls::TlsConnector::builder(); + for root_cert in self.load_root_certs().await? { + builder.add_root_certificate(root_cert); } - if let Some(client_identity) = ssl_opts.client_identity() { - let pkcs12_path = client_identity.pkcs12_path(); - let password = client_identity.password().unwrap_or(""); + if let Some(client_identity) = self.client_identity() { + builder.identity(client_identity.load().await?); + } + builder.danger_accept_invalid_hostnames(self.skip_domain_validation()); + builder.danger_accept_invalid_certs(self.accept_invalid_certs()); + builder.disable_built_in_roots(self.disable_built_in_roots()); + let tls_connector: TlsConnector = builder.build()?.into(); + Ok(tls_connector) + } +} - let der = std::fs::read(pkcs12_path)?; - let identity = Identity::from_pkcs12(&*der, password)?; - builder.identity(identity); +impl Endpoint { + pub async fn make_secure( + &mut self, + domain: String, + tls_connector: &TlsConnector, + ) -> Result<()> { + #[cfg(unix)] + if self.is_socket() { + // won't secure socket connection + return Ok(()); } - builder.danger_accept_invalid_hostnames(ssl_opts.skip_domain_validation()); - builder.danger_accept_invalid_certs(ssl_opts.accept_invalid_certs()); - let tls_connector: tokio_native_tls::TlsConnector = builder.build()?.into(); *self = match self { Endpoint::Plain(ref mut stream) => { let stream = stream.take().unwrap(); - let tls_stream = tls_connector.connect(&*domain, stream).await?; + let tls_stream = tls_connector.connect(&domain, stream).await?; Endpoint::Secure(tls_stream) } Endpoint::Secure(_) => unreachable!(), @@ -63,3 +60,14 @@ impl Endpoint { Ok(()) } } + +fn parse_certs(buf: &[u8]) -> Result> { + Ok(Certificate::from_der(buf).map(|x| vec![x]).or_else(|_| { + pem::parse_many(buf) + .unwrap_or_default() + .iter() + .map(pem::encode) + .map(|s| Certificate::from_pem(s.as_bytes())) + .collect() + })?) +} diff --git a/src/io/tls/no_tls.rs b/src/io/tls/no_tls.rs new file mode 100644 index 00000000..325585c9 --- /dev/null +++ b/src/io/tls/no_tls.rs @@ -0,0 +1,24 @@ +use crate::io::Endpoint; +use crate::{Result, SslOpts}; + +#[derive(Clone, Debug)] +pub(crate) struct TlsConnector; + +impl SslOpts { + pub(crate) async fn build_tls_connector(&self) -> Result { + panic!( + "Client had asked for TLS connection but TLS support is disabled. \ + Please enable one of the following features: [\"native-tls-tls\", \"rustls-tls\"]" + ) + } +} + +impl Endpoint { + pub async fn make_secure( + &mut self, + _domain: String, + _tls_connector: &TlsConnector, + ) -> Result<()> { + unreachable!(); + } +} diff --git a/src/io/tls/rustls_io.rs b/src/io/tls/rustls_io.rs index 654581d3..143cf066 100644 --- a/src/io/tls/rustls_io.rs +++ b/src/io/tls/rustls_io.rs @@ -1,85 +1,92 @@ -#![cfg(feature = "rustls-tls")] - -use std::{convert::TryInto, sync::Arc}; +use std::sync::Arc; use rustls::{ - client::{ServerCertVerifier, WebPkiVerifier}, - Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, + client::{ + danger::{ServerCertVerified, ServerCertVerifier}, + WebPkiServerVerifier, + }, + pki_types::{CertificateDer, ServerName}, + ClientConfig, RootCertStore, }; -use tokio::{fs::File, io::AsyncReadExt}; - use rustls_pemfile::certs; -use tokio_rustls::TlsConnector; +pub(crate) use tokio_rustls::TlsConnector; -use crate::{io::Endpoint, Result, SslOpts}; +use crate::{io::Endpoint, Result, SslOpts, TlsError}; -impl Endpoint { - pub async fn make_secure(&mut self, domain: String, ssl_opts: SslOpts) -> Result<()> { - #[cfg(unix)] - if self.is_socket() { - // won't secure socket connection - return Ok(()); - } +impl SslOpts { + async fn load_root_certs(&self) -> crate::Result>> { + let mut output = Vec::new(); - let mut root_store = RootCertStore::empty(); - root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - })); - - if let Some(root_cert_path) = ssl_opts.root_cert_path() { - let mut root_cert_data = vec![]; - let mut root_cert_file = File::open(root_cert_path).await?; - root_cert_file.read_to_end(&mut root_cert_data).await?; - - let mut root_certs = Vec::new(); - for cert in certs(&mut &*root_cert_data)? { - root_certs.push(Certificate(cert)); + for root_cert in self.root_certs() { + let root_cert_data = root_cert.read().await?; + let mut seen = false; + for cert in certs(&mut &*root_cert_data) { + seen = true; + output.push(cert?); } - if root_certs.is_empty() && !root_cert_data.is_empty() { - root_certs.push(Certificate(root_cert_data)); + if !seen && !root_cert_data.is_empty() { + output.push(CertificateDer::from(root_cert_data.into_owned())); } + } - for cert in &root_certs { - root_store.add(cert)?; - } + Ok(output) + } + + pub(crate) async fn build_tls_connector(&self) -> Result { + let mut root_store = RootCertStore::empty(); + if !self.disable_built_in_roots() { + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|x| x.to_owned())); + } + + for cert in self.load_root_certs().await? { + root_store.add(cert)?; } - let config_builder = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_store.clone()); + let config_builder = ClientConfig::builder().with_root_certificates(root_store.clone()); - let mut config = if let Some(identity) = ssl_opts.client_identity() { - let (cert_chain, priv_key) = identity.load()?; - config_builder.with_single_cert(cert_chain, priv_key)? + let mut config = if let Some(identity) = self.client_identity() { + let (cert_chain, priv_key) = identity.load().await?; + config_builder.with_client_auth_cert(cert_chain, priv_key)? } else { config_builder.with_no_client_auth() }; - let server_name = domain - .as_str() - .try_into() - .map_err(|_| webpki::InvalidDnsNameError)?; let mut dangerous = config.dangerous(); - let web_pki_verifier = WebPkiVerifier::new(root_store, None); + let web_pki_verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .map_err(TlsError::from)?; let dangerous_verifier = DangerousVerifier::new( - ssl_opts.accept_invalid_certs(), - ssl_opts.skip_domain_validation(), + self.accept_invalid_certs(), + self.skip_domain_validation(), web_pki_verifier, ); dangerous.set_certificate_verifier(Arc::new(dangerous_verifier)); + let client_config = Arc::new(config); + Ok(TlsConnector::from(client_config)) + } +} + +impl Endpoint { + pub async fn make_secure( + &mut self, + domain: String, + tls_connector: &TlsConnector, + ) -> Result<()> { + #[cfg(unix)] + if self.is_socket() { + // won't secure socket connection + return Ok(()); + } *self = match self { Endpoint::Plain(ref mut stream) => { let stream = stream.take().unwrap(); - let client_config = Arc::new(config); - let tls_connector = TlsConnector::from(client_config); + let server_name = ServerName::try_from(domain.as_str()) + .map_err(TlsError::InvalidDnsName)? + .to_owned(); let connection = tls_connector.connect(server_name, stream).await?; Endpoint::Secure(connection) @@ -93,17 +100,18 @@ impl Endpoint { } } +#[derive(Debug)] struct DangerousVerifier { accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, } impl DangerousVerifier { fn new( accept_invalid_certs: bool, skip_domain_validation: bool, - verifier: WebPkiVerifier, + verifier: Arc, ) -> Self { Self { accept_invalid_certs, @@ -114,35 +122,86 @@ impl DangerousVerifier { } impl ServerCertVerifier for DangerousVerifier { + // fn verify_server_cert( + // &self, + // end_entity: &Certificate, + // intermediates: &[Certificate], + // server_name: &rustls::ServerName, + // scts: &mut dyn Iterator, + // ocsp_response: &[u8], + // now: std::time::SystemTime, + // ) -> std::result::Result { + // if self.accept_invalid_certs { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } else { + // match self.verifier.verify_server_cert( + // end_entity, + // intermediates, + // server_name, + // scts, + // ocsp_response, + // now, + // ) { + // Ok(assertion) => Ok(assertion), + // Err(ref e) + // if e.to_string().contains("NotValidForName") && self.skip_domain_validation => + // { + // Ok(rustls::client::ServerCertVerified::assertion()) + // } + // Err(e) => Err(e), + // } + // } + // } fn verify_server_cert( &self, - end_entity: &Certificate, - intermediates: &[Certificate], - server_name: &rustls::ServerName, - scts: &mut dyn Iterator, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &rustls::pki_types::ServerName<'_>, ocsp_response: &[u8], - now: std::time::SystemTime, - ) -> std::result::Result { + now: rustls::pki_types::UnixTime, + ) -> std::prelude::v1::Result { if self.accept_invalid_certs { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } else { match self.verifier.verify_server_cert( end_entity, intermediates, server_name, - scts, ocsp_response, now, ) { Ok(assertion) => Ok(assertion), Err(ref e) - if e.to_string().contains("CertNotValidForName") - && self.skip_domain_validation => + if e.to_string().contains("NotValidForName") && self.skip_domain_validation => { - Ok(rustls::client::ServerCertVerified::assertion()) + Ok(ServerCertVerified::assertion()) } Err(e) => Err(e), } } } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> std::prelude::v1::Result + { + self.verifier.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.verifier.supported_verify_schemes() + } } diff --git a/src/io/write_packet.rs b/src/io/write_packet.rs index 0449edb1..5d06f1d6 100644 --- a/src/io/write_packet.rs +++ b/src/io/write_packet.rs @@ -44,7 +44,7 @@ impl Future for WritePacket<'_, '_> { ref mut data, } = *self; - match conn.stream_mut() { + match conn.as_mut().stream_mut() { Ok(stream) => { if data.is_some() { let codec = Pin::new(stream.codec.as_mut().expect("must be here")); diff --git a/src/lib.rs b/src/lib.rs index f28b8ce9..4ad9d735 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,8 +19,10 @@ //! //! # Crate Features //! -//! Default feature set is wide – it includes all default [`mysql_common` features][myslqcommonfeatures] -//! as well as `native-tls`-based TLS support. +//! By default there are only two features enabled: +//! +//! * `flate2/zlib` — choosing flate2 backend is mandatory +//! * `derive` — see ["Derive Macros" section in `mysql_common` docs][mysqlcommonderive] //! //! ## List Of Features //! @@ -36,25 +38,18 @@ //! mysql_async = { version = "*", default-features = false, features = ["minimal"]} //! ``` //! -//! **Note:* it is possible to use another `flate2` backend by directly choosing it: +//! * `minimal-rust` - same as `minimal` but with rust-based flate2 backend. Enables: //! -//! ```toml -//! [dependencies] -//! mysql_async = { version = "*", default-features = false } -//! flate2 = { version = "*", default-features = false, features = ["rust_backend"] } -//! ``` +//! - `flate2/rust_backend` //! -//! * `default` – enables the following set of crate's and dependencies' features: +//! * `default` – enables the following set of features: //! -//! - `native-tls-tls` -//! - `flate2/zlib" -//! - `mysql_common/bigdecimal03` -//! - `mysql_common/rust_decimal` -//! - `mysql_common/time03` -//! - `mysql_common/uuid` -//! - `mysql_common/frunk` +//! - `flate2/zlib` +//! - `derive` +//! +//! * `default-rustls` – default set of features with TLS via `rustls/aws-lc-rs` //! -//! * `default-rustls` – same as default but with `rustls-tls` instead of `native-tls-tls`. +//! * `default-rustls-ring` – default set of features with TLS via `rustls/ring` //! //! **Example:** //! @@ -63,21 +58,22 @@ //! mysql_async = { version = "*", default-features = false, features = ["default-rustls"] } //! ``` //! -//! * `native-tls-tls` – enables `native-tls`-based TLS support _(conflicts with `rustls-tls`)_ +//! * `native-tls-tls` – enables TLS via `native-tls` //! //! **Example:** //! //! ```toml //! [dependencies] -//! mysql_async = { version = "*", default-features = false, features = ["native-tls-tls"] } +//! mysql_async = { version = "*", default-features = false, features = ["minimal", "native-tls-tls"] } //! -//! * `rustls-tls` – enables `native-tls`-based TLS support _(conflicts with `native-tls-tls`)_ +//! * `rustls-tls` - enables rustls TLS backend with no provider. You should enable one +//! of existing providers using `aws-lc-rs` or `ring` features: //! //! **Example:** //! //! ```toml //! [dependencies] -//! mysql_async = { version = "*", default-features = false, features = ["rustls-tls"] } +//! mysql_async = { version = "*", default-features = false, features = ["minimal-rust", "rustls-tls", "ring"] } //! //! * `tracing` – enables instrumentation via `tracing` package. //! @@ -93,7 +89,21 @@ //! mysql_async = { version = "*", features = ["tracing"] } //! ``` //! +//! * `binlog` - enables binlog-related functionality. Enables: +//! +//! - `mysql_common/binlog" +//! +//! ### Proxied features (see [`mysql_common`` fatures][myslqcommonfeatures]) +//! +//! * `derive` – enables `mysql_common/derive` feature +//! * `chrono` = enables `mysql_common/chrono` feature +//! * `time` = enables `mysql_common/time` feature +//! * `bigdecimal` = enables `mysql_common/bigdecimal` feature +//! * `rust_decimal` = enables `mysql_common/rust_decimal` feature +//! * `frunk` = enables `mysql_common/frunk` feature +//! //! [myslqcommonfeatures]: https://github.com/blackbeam/rust_mysql_common#crate-features +//! [mysqlcommonderive]: https://github.com/blackbeam/rust_mysql_common?tab=readme-ov-file#derive-macros //! //! # TLS/SSL Support //! @@ -191,7 +201,7 @@ //! * [`Pool`] is a smart pointer – each clone will point to the same pool instance. //! * [`Pool`] is `Send + Sync + 'static` – feel free to pass it around. //! * use [`Pool::disconnect`] to gracefuly close the pool. -//! * [`Pool::new`] is lazy and won't assert server availability. +//! * ⚠️ [`Pool::new`] is lazy and won't assert server availability. //! //! # Transaction //! @@ -418,6 +428,9 @@ #[cfg(feature = "nightly")] extern crate test; +#[cfg(feature = "derive")] +extern crate mysql_common; + pub use mysql_common::{constants as consts, params}; use std::sync::Arc; @@ -441,15 +454,26 @@ mod queryable; type BoxFuture<'a, T> = futures_core::future::BoxFuture<'a, Result>; -static BUFFER_POOL: once_cell::sync::Lazy> = - once_cell::sync::Lazy::new(|| Default::default()); +fn buffer_pool() -> &'static Arc { + static BUFFER_POOL: std::sync::OnceLock> = + std::sync::OnceLock::new(); + BUFFER_POOL.get_or_init(Default::default) +} + +#[cfg(feature = "binlog")] +#[doc(inline)] +pub use self::conn::binlog_stream::{request::BinlogStreamRequest, BinlogStream}; #[doc(inline)] -pub use self::conn::{binlog_stream::BinlogStream, Conn}; +pub use self::conn::Conn; #[doc(inline)] pub use self::conn::pool::Pool; +#[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] +#[doc(inline)] +pub use self::error::tls::TlsError; + #[doc(inline)] pub use self::error::{ DriverError, Error, IoError, LocalInfileError, ParseError, Result, ServerError, UrlError, @@ -462,13 +486,14 @@ pub use self::query::QueryWithParams; pub use self::queryable::transaction::IsolationLevel; #[doc(inline)] -#[cfg(any(feature = "rustls", feature = "native-tls"))] +#[cfg(any(feature = "rustls", feature = "native-tls-tls"))] pub use self::opts::ClientIdentity; #[doc(inline)] pub use self::opts::{ - Opts, OptsBuilder, PoolConstraints, PoolOpts, SslOpts, DEFAULT_INACTIVE_CONNECTION_TTL, - DEFAULT_POOL_CONSTRAINTS, DEFAULT_STMT_CACHE_SIZE, DEFAULT_TTL_CHECK_INTERVAL, + ChangeUserOpts, Opts, OptsBuilder, PoolConstraints, PoolOpts, SslOpts, + DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_POOL_CONSTRAINTS, DEFAULT_STMT_CACHE_SIZE, + DEFAULT_TTL_CHECK_INTERVAL, }; #[doc(inline)] @@ -476,14 +501,14 @@ pub use self::local_infile_handler::{builtin::WhiteListFsHandler, InfileData}; #[doc(inline)] pub use mysql_common::packets::{ - binlog_request::BinlogRequest, session_state_change::{ Gtids, Schema, SessionStateChange, SystemVariable, TransactionCharacteristics, TransactionState, Unsupported, }, - BinlogDumpFlags, Column, Interval, OkPacket, SessionStateInfo, Sid, + Column, GnoInterval, OkPacket, SessionStateInfo, Sid, }; +#[cfg(feature = "binlog")] pub mod binlog { #[doc(inline)] pub use mysql_common::binlog::consts::*; @@ -525,6 +550,12 @@ pub use self::queryable::{BinaryProtocol, TextProtocol}; #[doc(inline)] pub use self::queryable::stmt::Statement; +#[doc(inline)] +pub use self::conn::pool::Metrics; + +#[doc(inline)] +pub use crate::connection_like::{Connection, ToConnectionResult}; + /// Futures used in this crate pub mod futures { pub use crate::conn::pool::futures::{DisconnectPool, GetConn}; @@ -532,6 +563,8 @@ pub mod futures { /// Traits used in this crate pub mod prelude { + #[doc(inline)] + pub use crate::connection_like::ToConnection; #[doc(inline)] pub use crate::local_infile_handler::GlobalHandler; #[doc(inline)] @@ -541,9 +574,11 @@ pub mod prelude { #[doc(inline)] pub use crate::queryable::Queryable; #[doc(inline)] - pub use mysql_common::row::convert::FromRow; + pub use mysql_common::prelude::ColumnIndex; #[doc(inline)] - pub use mysql_common::value::convert::{ConvIr, FromValue, ToValue}; + pub use mysql_common::prelude::FromRow; + #[doc(inline)] + pub use mysql_common::prelude::{FromValue, ToValue}; /// Everything that is a statement. /// @@ -566,17 +601,6 @@ pub mod prelude { pub trait StatementLike: crate::queryable::stmt::StatementLike {} impl StatementLike for T {} - /// Everything that is a connection. - /// - /// Note that you could obtain a `'static` connection by giving away `Conn` or `Pool`. - pub trait ToConnection<'a, 't: 'a>: crate::connection_like::ToConnection<'a, 't> {} - // explicitly implemented because of rusdoc - impl<'a> ToConnection<'a, 'static> for &'a crate::Pool {} - impl<'a> ToConnection<'static, 'static> for crate::Pool {} - impl ToConnection<'static, 'static> for crate::Conn {} - impl<'a> ToConnection<'a, 'static> for &'a mut crate::Conn {} - impl<'a, 't> ToConnection<'a, 't> for &'a mut crate::Transaction<'t> {} - /// Trait for protocol markers [`crate::TextProtocol`] and [`crate::BinaryProtocol`]. pub trait Protocol: crate::queryable::Protocol {} impl Protocol for crate::BinaryProtocol {} @@ -587,39 +611,29 @@ pub mod prelude { #[doc(hidden)] pub mod test_misc { - use lazy_static::lazy_static; - use std::env; + use std::sync::OnceLock; use crate::opts::{Opts, OptsBuilder, SslOpts}; #[allow(dead_code)] #[allow(unreachable_code)] - fn error_should_implement_send_and_sync() { + fn error_should_implement_send_and_sync(err: crate::Error) { fn _dummy(_: T) {} - _dummy(panic!()); + _dummy(err); } - lazy_static! { - pub static ref DATABASE_URL: String = { + pub fn get_opts() -> OptsBuilder { + static DATABASE_OPTS: OnceLock = OnceLock::new(); + let database_opts = DATABASE_OPTS.get_or_init(|| { if let Ok(url) = env::var("DATABASE_URL") { - let opts = Opts::from_url(&url).expect("DATABASE_URL invalid"); - if opts - .db_name() - .expect("a database name is required") - .is_empty() - { - panic!("database name is empty"); - } - url + Opts::from_url(&url).expect("DATABASE_URL invalid") } else { - "mysql://root:password@localhost:3307/mysql".into() + Opts::from_url("mysql://root:password@localhost:3307/mysql").unwrap() } - }; - } + }); - pub fn get_opts() -> OptsBuilder { - let mut builder = OptsBuilder::from_opts(Opts::from_url(&**DATABASE_URL).unwrap()); + let mut builder = OptsBuilder::from_opts(database_opts.clone()); if test_ssl() { let ssl_opts = SslOpts::default() .with_danger_skip_domain_validation(true) diff --git a/src/local_infile_handler/mod.rs b/src/local_infile_handler/mod.rs index 7d222a2b..39515124 100644 --- a/src/local_infile_handler/mod.rs +++ b/src/local_infile_handler/mod.rs @@ -34,7 +34,7 @@ pub type InfileData = BoxStream<'static, std::io::Result>; /// **Warning:** You should be aware of [Security Considerations for LOAD DATA LOCAL][1]. /// /// The purpose of the handler is to emit infile data in response to a file name. -/// This handler will be called if there is no [`LocalHandler`] installed for the connection. +/// This handler will be called if there is no local handler installed for the connection. /// /// The library will call this handler in response to a LOCAL INFILE request from the server. /// The server, in its turn, will emit LOCAL INFILE requests in response to a `LOAD DATA LOCAL` diff --git a/src/opts/mod.rs b/src/opts/mod.rs index ebb18fc3..11e04402 100644 --- a/src/opts/mod.rs +++ b/src/opts/mod.rs @@ -9,23 +9,25 @@ mod native_tls_opts; mod rustls_opts; -#[cfg(feature = "native-tls")] +#[cfg(feature = "native-tls-tls")] pub use native_tls_opts::ClientIdentity; #[cfg(feature = "rustls-tls")] pub use rustls_opts::ClientIdentity; use percent_encoding::percent_decode; +use rand::Rng; +use tokio::sync::OnceCell; use url::{Host, Url}; use std::{ borrow::Cow, - convert::TryFrom, - net::{Ipv4Addr, Ipv6Addr}, - path::Path, + fmt, io, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + path::{Path, PathBuf}, str::FromStr, sync::Arc, - time::Duration, + time::{Duration, Instant}, vec, }; @@ -41,7 +43,8 @@ pub const DEFAULT_POOL_CONSTRAINTS: PoolConstraints = PoolConstraints { min: 10, // const_assert!( _DEFAULT_POOL_CONSTRAINTS_ARE_CORRECT, - DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max, + DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max + && 0 < DEFAULT_POOL_CONSTRAINTS.max, ); /// Each connection will cache up to this number of statements by default. @@ -65,37 +68,61 @@ pub const DEFAULT_TTL_CHECK_INTERVAL: Duration = Duration::from_secs(30); /// into socket addresses using to_socket_addrs. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) enum HostPortOrUrl { - HostPort(String, u16), + HostPort { + host: String, + port: u16, + /// The resolved IP addresses to use for the TCP connection. If empty, + /// DNS resolution of `host` will be performed. + resolved_ips: Option>, + }, Url(Url), } impl Default for HostPortOrUrl { fn default() -> Self { - HostPortOrUrl::HostPort("127.0.0.1".to_string(), DEFAULT_PORT) + HostPortOrUrl::HostPort { + host: "127.0.0.1".to_string(), + port: DEFAULT_PORT, + resolved_ips: None, + } } } impl HostPortOrUrl { pub fn get_ip_or_hostname(&self) -> &str { match self { - Self::HostPort(host, _) => host, + Self::HostPort { host, .. } => host, Self::Url(url) => url.host_str().unwrap_or("127.0.0.1"), } } pub fn get_tcp_port(&self) -> u16 { match self { - Self::HostPort(_, port) => *port, + Self::HostPort { port, .. } => *port, Self::Url(url) => url.port().unwrap_or(DEFAULT_PORT), } } + pub fn get_resolved_ips(&self) -> &Option> { + match self { + Self::HostPort { resolved_ips, .. } => resolved_ips, + Self::Url(_) => &None, + } + } + pub fn is_loopback(&self) -> bool { match self { - Self::HostPort(host, _) => { + Self::HostPort { + host, resolved_ips, .. + } => { let v4addr: Option = FromStr::from_str(host).ok(); let v6addr: Option = FromStr::from_str(host).ok(); - if let Some(addr) = v4addr { + if resolved_ips + .as_ref() + .is_some_and(|s| s.iter().any(|ip| ip.is_loopback())) + { + true + } else if let Some(addr) = v4addr { addr.is_loopback() } else if let Some(addr) = v6addr { addr.is_loopback() @@ -113,6 +140,55 @@ impl HostPortOrUrl { } } +/// Represents data that is either on-disk or in the buffer. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PathOrBuf<'a> { + Path(Cow<'a, Path>), + Buf(Cow<'a, [u8]>), +} + +impl<'a> PathOrBuf<'a> { + /// Will either read data from disk or return the buffered data. + pub async fn read(&self) -> io::Result> { + match self { + PathOrBuf::Path(x) => tokio::fs::read(x.as_ref()).await.map(Cow::Owned), + PathOrBuf::Buf(x) => Ok(Cow::Borrowed(x.as_ref())), + } + } + + /// Borrows `self`. + pub fn borrow(&self) -> PathOrBuf<'_> { + match self { + PathOrBuf::Path(path) => PathOrBuf::Path(Cow::Borrowed(path.as_ref())), + PathOrBuf::Buf(data) => PathOrBuf::Buf(Cow::Borrowed(data.as_ref())), + } + } +} + +impl From for PathOrBuf<'static> { + fn from(value: PathBuf) -> Self { + Self::Path(Cow::Owned(value)) + } +} + +impl<'a> From<&'a Path> for PathOrBuf<'a> { + fn from(value: &'a Path) -> Self { + Self::Path(Cow::Borrowed(value)) + } +} + +impl From> for PathOrBuf<'static> { + fn from(value: Vec) -> Self { + Self::Buf(Cow::Owned(value)) + } +} + +impl<'a> From<&'a [u8]> for PathOrBuf<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Buf(Cow::Borrowed(value)) + } +} + /// Ssl Options. /// /// ``` @@ -123,7 +199,7 @@ impl HostPortOrUrl { /// // With native-tls /// # #[cfg(feature = "native-tls-tls")] /// let ssl_opts = SslOpts::default() -/// .with_client_identity(Some(ClientIdentity::new(Path::new("/path")) +/// .with_client_identity(Some(ClientIdentity::new(Path::new("/path").into()) /// .with_password("******") /// )); /// @@ -131,21 +207,23 @@ impl HostPortOrUrl { /// # #[cfg(feature = "rustls-tls")] /// let ssl_opts = SslOpts::default() /// .with_client_identity(Some(ClientIdentity::new( -/// Path::new("/path/to/chain"), -/// Path::new("/path/to/priv_key") +/// Path::new("/path/to/chain").into(), +/// Path::new("/path/to/priv_key").into(), /// ))); /// ``` #[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct SslOpts { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] client_identity: Option, - root_cert_path: Option>, + root_certs: Vec>, + disable_built_in_roots: bool, skip_domain_validation: bool, accept_invalid_certs: bool, + tls_hostname_override: Option>, } impl SslOpts { - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub fn with_client_identity(mut self, identity: Option) -> Self { self.client_identity = identity; self @@ -154,35 +232,96 @@ impl SslOpts { /// Sets path to a `pem` or `der` certificate of the root that connector will trust. /// /// Multiple certs are allowed in .pem files. - pub fn with_root_cert_path>>( - mut self, - root_cert_path: Option, - ) -> Self { - self.root_cert_path = root_cert_path.map(Into::into); + /// + /// All the elements in `root_certs` will be merged. + pub fn with_root_certs(mut self, root_certs: Vec>) -> Self { + self.root_certs = root_certs; + self + } + + /// If `true`, use only the root certificates configured via [`SslOpts::with_root_certs`], + /// not any system or built-in certs. By default system built-in certs _will be_ used. + /// + /// # Connection URL + /// + /// Use `built_in_roots` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&built_in_roots=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().disable_built_in_roots(), true); + /// # Ok(()) } + /// ``` + pub fn with_disable_built_in_roots(mut self, disable_built_in_roots: bool) -> Self { + self.disable_built_in_roots = disable_built_in_roots; self } - /// The way to not validate the server's domain - /// name against its certificate (defaults to `false`). + /// The way to not validate the server's domain name against its certificate. + /// By default domain name _will be_ validated. + /// + /// # Connection URL + /// + /// Use `verify_identity` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&verify_identity=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().skip_domain_validation(), true); + /// # Ok(()) } + /// ``` pub fn with_danger_skip_domain_validation(mut self, value: bool) -> Self { self.skip_domain_validation = value; self } - /// If `true` then client will accept invalid certificate (expired, not trusted, ..) - /// (defaults to `false`). + /// If `true` then client will accept invalid certificate (expired, not trusted, ..). + /// Invalid certificates _won't get_ accepted by default. + /// + /// # Connection URL + /// + /// Use `verify_ca` URL parameter to set this value: + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?require_ssl=true&verify_ca=false")?; + /// assert_eq!(opts.ssl_opts().unwrap().accept_invalid_certs(), true); + /// # Ok(()) } + /// ``` pub fn with_danger_accept_invalid_certs(mut self, value: bool) -> Self { self.accept_invalid_certs = value; self } - #[cfg(any(feature = "native-tls", feature = "rustls-tls"))] + /// If set, will override the hostname used to verify the server's certificate. + /// + /// This is useful when connecting to a server via a tunnel, where the server hostname is + /// different from the hostname used to connect to the tunnel. + pub fn with_danger_tls_hostname_override>>( + mut self, + domain: Option, + ) -> Self { + self.tls_hostname_override = domain.map(Into::into); + self + } + + #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))] pub fn client_identity(&self) -> Option<&ClientIdentity> { self.client_identity.as_ref() } - pub fn root_cert_path(&self) -> Option<&Path> { - self.root_cert_path.as_ref().map(AsRef::as_ref) + pub fn root_certs(&self) -> &[PathOrBuf<'static>] { + &self.root_certs + } + + pub fn disable_built_in_roots(&self) -> bool { + self.disable_built_in_roots } pub fn skip_domain_validation(&self) -> bool { @@ -192,6 +331,10 @@ impl SslOpts { pub fn accept_invalid_certs(&self) -> bool { self.accept_invalid_certs } + + pub fn tls_hostname_override(&self) -> Option<&str> { + self.tls_hostname_override.as_deref() + } } /// Connection pool options. @@ -208,9 +351,17 @@ pub struct PoolOpts { constraints: PoolConstraints, inactive_connection_ttl: Duration, ttl_check_interval: Duration, + abs_conn_ttl: Option, + abs_conn_ttl_jitter: Option, + reset_connection: bool, } impl PoolOpts { + /// Calls `Self::default`. + pub fn new() -> Self { + Self::default() + } + /// Creates the default [`PoolOpts`] with the given constraints. pub fn with_constraints(mut self, constraints: PoolConstraints) -> Self { self.constraints = constraints; @@ -222,6 +373,95 @@ impl PoolOpts { self.constraints } + /// Sets whether to reset connection upon returning it to a pool (defaults to `true`). + /// + /// Default behavior increases reliability but comes with cons: + /// + /// * reset procedure removes all prepared statements, i.e. kills prepared statements cache + /// * connection reset is quite fast but requires additional client-server roundtrip + /// (might also requires requthentication for older servers) + /// + /// The purpose of the reset procedure is to: + /// + /// * rollback any opened transactions (`mysql_async` is able to do this without explicit reset) + /// * reset transaction isolation level + /// * reset session variables + /// * delete user variables + /// * remove temporary tables + /// * remove all PREPARE statement (this action kills prepared statements cache) + /// + /// So to encrease overall performance you can safely opt-out of the default behavior + /// if you are not willing to change the session state in an unpleasant way. + /// + /// It is also possible to selectively opt-in/out using [`Conn::reset_connection`][1]. + /// + /// # Connection URL + /// + /// You can use `reset_connection` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?reset_connection=false")?; + /// assert_eq!(opts.pool_opts().reset_connection(), false); + /// # Ok(()) } + /// ``` + /// + /// [1]: crate::Conn::reset_connection + pub fn with_reset_connection(mut self, reset_connection: bool) -> Self { + self.reset_connection = reset_connection; + self + } + + /// Returns the `reset_connection` value (see [`PoolOpts::with_reset_connection`]). + pub fn reset_connection(&self) -> bool { + self.reset_connection + } + + /// Sets an absolute TTL after which a connection is removed from the pool. + /// This may push the pool below the requested minimum pool size and is indepedent of the + /// idle TTL. + /// The absolute TTL is disabled by default. + /// Fractions of seconds are ignored. + pub fn with_abs_conn_ttl(mut self, ttl: Option) -> Self { + self.abs_conn_ttl = ttl; + self + } + + /// Optionally, the absolute TTL can be extended by a per-connection random amount + /// bounded by `jitter`. + /// Setting `abs_conn_ttl_jitter` without `abs_conn_ttl` has no effect. + /// Fractions of seconds are ignored. + pub fn with_abs_conn_ttl_jitter(mut self, jitter: Option) -> Self { + self.abs_conn_ttl_jitter = jitter; + self + } + + /// Returns the absolute TTL, if set. + pub fn abs_conn_ttl(&self) -> Option { + self.abs_conn_ttl + } + + /// Returns the absolute TTL's jitter bound, if set. + pub fn abs_conn_ttl_jitter(&self) -> Option { + self.abs_conn_ttl_jitter + } + + /// Returns a new deadline that's TTL (+ random jitter) in the future. + pub(crate) fn new_connection_ttl_deadline(&self) -> Option { + if let Some(ttl) = self.abs_conn_ttl { + let jitter = if let Some(jitter) = self.abs_conn_ttl_jitter { + Duration::from_secs(rand::rng().random_range(0..=jitter.as_secs())) + } else { + Duration::ZERO + }; + Some(Instant::now() + ttl + jitter) + } else { + None + } + } + /// Pool will recycle inactive connection if it is outside of the lower bound of the pool /// and if it is idling longer than this value (defaults to /// [`DEFAULT_INACTIVE_CONNECTION_TTL`]). @@ -308,6 +548,9 @@ impl Default for PoolOpts { constraints: DEFAULT_POOL_CONSTRAINTS, inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL, ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL, + abs_conn_ttl: None, + abs_conn_ttl_jitter: None, + reset_connection: true, } } } @@ -351,14 +594,18 @@ pub(crate) struct MysqlOpts { /// (defaults to `wait_timeout`). conn_ttl: Option, - /// Commands to execute on each new database connection. + /// Commands to execute once new connection is established. init: Vec, + /// Commands to execute on new connection and every time + /// [`Conn::reset`] or [`Conn::change_user`] is invoked. + setup: Vec, + /// Number of prepared statements cached on the client side (per connection). Defaults to `10`. stmt_cache_size: usize, /// Driver will require SSL connection if this option isn't `None` (default to `None`). - ssl_opts: Option, + ssl_opts: Option, /// Prefer socket connection (defaults to `true`). /// @@ -404,6 +651,23 @@ pub(crate) struct MysqlOpts { /// /// Available via `secure_auth` connection url parameter. secure_auth: bool, + + /// Enables `CLIENT_FOUND_ROWS` capability (defaults to `false`). + /// + /// Changes the behavior of the affected count returned for writes (UPDATE/INSERT etc). + /// It makes MySQL return the FOUND rows instead of the AFFECTED rows. + client_found_rows: bool, + + /// Enables Client-Side Cleartext Pluggable Authentication (defaults to `false`). + /// + /// Enables client to send passwords to the server as cleartext, without hashing or encryption + /// (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + enable_cleartext_plugin: bool, } /// Mysql connection options. @@ -457,6 +721,11 @@ impl Opts { self.inner.address.get_tcp_port() } + /// The resolved IPs for the mysql server, if provided. + pub fn resolved_ips(&self) -> &Option> { + self.inner.address.get_resolved_ips() + } + /// User (defaults to `None`). /// /// # Connection URL @@ -508,11 +777,20 @@ impl Opts { self.inner.mysql_opts.db_name.as_ref().map(AsRef::as_ref) } - /// Commands to execute on each new database connection. + /// Commands to execute once new connection is established. pub fn init(&self) -> &[String] { self.inner.mysql_opts.init.as_ref() } + /// Commands to execute on new connection and every time + /// [`Conn::reset`][1] or [`Conn::change_user`][2] is invoked. + /// + /// [1]: crate::Conn::reset + /// [2]: crate::Conn::change_user + pub fn setup(&self) -> &[String] { + self.inner.mysql_opts.setup.as_ref() + } + /// TCP keep alive timeout in milliseconds (defaults to `None`). /// /// # Connection URL @@ -583,6 +861,49 @@ impl Opts { self.inner.mysql_opts.conn_ttl } + /// The pool will close a connection when this absolute TTL has elapsed. + /// Disabled by default. + /// + /// Enables forced recycling and migration of connections in a guaranteed timeframe. + /// This TTL bypasses pool constraints and an idle pool can go below the min size. + /// + /// # Connection URL + /// + /// You can use `abs_conn_ttl` URL parameter to set this value (in seconds). E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?abs_conn_ttl=86400")?; + /// assert_eq!(opts.abs_conn_ttl(), Some(Duration::from_secs(24 * 60 * 60))); + /// # Ok(()) } + /// ``` + pub fn abs_conn_ttl(&self) -> Option { + self.inner.mysql_opts.pool_opts.abs_conn_ttl + } + + /// Upper bound of a random value added to the absolute TTL, if enabled. + /// Disabled by default. + /// + /// Should be used to prevent connections from closing at the same time. + /// + /// # Connection URL + /// + /// You can use `abs_conn_ttl_jitter` URL parameter to set this value (in seconds). E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # use std::time::Duration; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?abs_conn_ttl=7200&abs_conn_ttl_jitter=3600")?; + /// assert_eq!(opts.abs_conn_ttl_jitter(), Some(Duration::from_secs(60 * 60))); + /// # Ok(()) } + /// ``` + pub fn abs_conn_ttl_jitter(&self) -> Option { + self.inner.mysql_opts.pool_opts.abs_conn_ttl_jitter + } + /// Number of prepared statements cached on the client side (per connection). Defaults to /// [`DEFAULT_STMT_CACHE_SIZE`]. /// @@ -630,7 +951,7 @@ impl Opts { /// /// pub fn ssl_opts(&self) -> Option<&SslOpts> { - self.inner.mysql_opts.ssl_opts.as_ref() + self.inner.mysql_opts.ssl_opts.as_ref().map(|o| &o.ssl_opts) } /// Prefer socket connection (defaults to `true` **temporary `false` on Windows platform**). @@ -721,6 +1042,51 @@ impl Opts { self.inner.mysql_opts.secure_auth } + /// Returns `true` if `CLIENT_FOUND_ROWS` capability is enabled (defaults to `false`). + /// + /// `CLIENT_FOUND_ROWS` changes the behavior of the affected count returned for writes + /// (UPDATE/INSERT etc). It makes MySQL return the FOUND rows instead of the AFFECTED rows. + /// + /// # Connection URL + /// + /// Use `client_found_rows` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?client_found_rows=true")?; + /// assert!(opts.client_found_rows()); + /// # Ok(()) } + /// ``` + pub fn client_found_rows(&self) -> bool { + self.inner.mysql_opts.client_found_rows + } + + /// Returns `true` if `mysql_clear_password` plugin support is enabled (defaults to `false`). + /// + /// `mysql_clear_password` enables client to send passwords to the server as cleartext, without + /// hashing or encryption (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + /// + /// # Connection URL + /// + /// Use `enable_cleartext_plugin` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?enable_cleartext_plugin=true")?; + /// assert!(opts.enable_cleartext_plugin()); + /// # Ok(()) } + /// ``` + pub fn enable_cleartext_plugin(&self) -> bool { + self.inner.mysql_opts.enable_cleartext_plugin + } + pub(crate) fn get_capabilities(&self) -> CapabilityFlags { let mut out = CapabilityFlags::CLIENT_PROTOCOL_41 | CapabilityFlags::CLIENT_SECURE_CONNECTION @@ -743,9 +1109,16 @@ impl Opts { if self.inner.mysql_opts.compression.is_some() { out |= CapabilityFlags::CLIENT_COMPRESS; } + if self.client_found_rows() { + out |= CapabilityFlags::CLIENT_FOUND_ROWS; + } out } + + pub(crate) fn ssl_opts_and_connector(&self) -> Option<&SslOptsAndCachedConnector> { + self.inner.mysql_opts.ssl_opts.as_ref() + } } impl Default for MysqlOpts { @@ -755,6 +1128,7 @@ impl Default for MysqlOpts { pass: None, db_name: None, init: vec![], + setup: vec![], tcp_keepalive: None, tcp_nodelay: true, local_infile_handler: None, @@ -768,10 +1142,53 @@ impl Default for MysqlOpts { max_allowed_packet: None, wait_timeout: None, secure_auth: true, + client_found_rows: false, + enable_cleartext_plugin: false, } } } +#[derive(Clone)] +pub(crate) struct SslOptsAndCachedConnector { + ssl_opts: SslOpts, + tls_connector: Arc>, +} + +impl fmt::Debug for SslOptsAndCachedConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SslOptsAndCachedConnector") + .field("ssl_opts", &self.ssl_opts) + .finish() + } +} + +impl SslOptsAndCachedConnector { + fn new(ssl_opts: SslOpts) -> Self { + Self { + ssl_opts, + tls_connector: Arc::new(OnceCell::new()), + } + } + + pub(crate) fn ssl_opts(&self) -> &SslOpts { + &self.ssl_opts + } + + pub(crate) async fn build_tls_connector(&self) -> Result { + self.tls_connector + .get_or_try_init(move || self.ssl_opts.build_tls_connector()) + .await + .cloned() + } +} + +impl PartialEq for SslOptsAndCachedConnector { + fn eq(&self, other: &Self) -> bool { + self.ssl_opts == other.ssl_opts + } +} +impl Eq for SslOptsAndCachedConnector {} + /// Connection pool constraints. /// /// This type stores `min` and `max` constraints for [`crate::Pool`] and ensures that `min <= max`. @@ -795,8 +1212,8 @@ impl PoolConstraints { /// assert_eq!(opts.pool_opts().constraints(), PoolConstraints::new(0, 151).unwrap()); /// # Ok(()) } /// ``` - pub fn new(min: usize, max: usize) -> Option { - if min <= max { + pub const fn new(min: usize, max: usize) -> Option { + if min <= max && max > 0 { Some(PoolConstraints { min, max }) } else { None @@ -850,6 +1267,7 @@ pub struct OptsBuilder { opts: MysqlOpts, ip_or_hostname: String, tcp_port: u16, + resolved_ips: Option>, } impl Default for OptsBuilder { @@ -859,6 +1277,7 @@ impl Default for OptsBuilder { opts: MysqlOpts::default(), ip_or_hostname: address.get_ip_or_hostname().into(), tcp_port: address.get_tcp_port(), + resolved_ips: None, } } } @@ -879,7 +1298,8 @@ impl OptsBuilder { OptsBuilder { tcp_port: opts.inner.address.get_tcp_port(), ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(), - opts: (*opts.inner).mysql_opts.clone(), + resolved_ips: opts.inner.address.get_resolved_ips().clone(), + opts: opts.inner.mysql_opts.clone(), } } @@ -895,6 +1315,14 @@ impl OptsBuilder { self } + /// Defines already-resolved IPs to use for the connection. When provided + /// the connection will not perform DNS resolution and the hostname will be + /// used only for TLS identity verification purposes. + pub fn resolved_ips>>(mut self, ips: Option) -> Self { + self.resolved_ips = ips.map(Into::into); + self + } + /// Defines user name. See [`Opts::user`]. pub fn user>(mut self, user: Option) -> Self { self.opts.user = user.map(Into::into); @@ -919,6 +1347,12 @@ impl OptsBuilder { self } + /// Defines setup queries. See [`Opts::setup`]. + pub fn setup>(mut self, setup: Vec) -> Self { + self.opts.setup = setup.into_iter().map(Into::into).collect(); + self + } + /// Defines `tcp_keepalive` option. See [`Opts::tcp_keepalive`]. pub fn tcp_keepalive>(mut self, tcp_keepalive: Option) -> Self { self.opts.tcp_keepalive = tcp_keepalive.map(Into::into); @@ -963,7 +1397,7 @@ impl OptsBuilder { /// Defines SSL options. See [`Opts::ssl_opts`]. pub fn ssl_opts>>(mut self, ssl_opts: T) -> Self { - self.opts.ssl_opts = ssl_opts.into(); + self.opts.ssl_opts = ssl_opts.into().map(SslOptsAndCachedConnector::new); self } @@ -990,8 +1424,7 @@ impl OptsBuilder { /// Note that it'll saturate to proper minimum and maximum values /// for this parameter (see MySql documentation). pub fn max_allowed_packet(mut self, max_allowed_packet: Option) -> Self { - self.opts.max_allowed_packet = - max_allowed_packet.map(|x| std::cmp::max(1024, std::cmp::min(1073741824, x))); + self.opts.max_allowed_packet = max_allowed_packet.map(|x| x.clamp(1024, 1073741824)); self } @@ -1018,11 +1451,47 @@ impl OptsBuilder { self.opts.secure_auth = secure_auth; self } + + /// Enables or disables `CLIENT_FOUND_ROWS` capability. See [`Opts::client_found_rows`]. + pub fn client_found_rows(mut self, client_found_rows: bool) -> Self { + self.opts.client_found_rows = client_found_rows; + self + } + + /// Enables Client-Side Cleartext Pluggable Authentication (defaults to `false`). + /// + /// Enables client to send passwords to the server as cleartext, without hashing or encryption + /// (consult MySql documentation for more info). + /// + /// # Security Notes + /// + /// Sending passwords as cleartext may be a security problem in some configurations. Please + /// consider using TLS or encrypted tunnels for server connection. + /// + /// # Connection URL + /// + /// Use `enable_cleartext_plugin` URL parameter to set this value. E.g. + /// + /// ``` + /// # use mysql_async::*; + /// # fn main() -> Result<()> { + /// let opts = Opts::from_url("mysql://localhost/db?enable_cleartext_plugin=true")?; + /// assert!(opts.enable_cleartext_plugin()); + /// # Ok(()) } + /// ``` + pub fn enable_cleartext_plugin(mut self, enable_cleartext_plugin: bool) -> Self { + self.opts.enable_cleartext_plugin = enable_cleartext_plugin; + self + } } impl From for Opts { fn from(builder: OptsBuilder) -> Opts { - let address = HostPortOrUrl::HostPort(builder.ip_or_hostname, builder.tcp_port); + let address = HostPortOrUrl::HostPort { + host: builder.ip_or_hostname, + port: builder.tcp_port, + resolved_ips: builder.resolved_ips, + }; let inner_opts = InnerOpts { mysql_opts: builder.opts, address, @@ -1034,6 +1503,118 @@ impl From for Opts { } } +/// [`COM_CHANGE_USER`][1] options. +/// +/// Connection [`Opts`] are going to be updated accordingly upon `COM_CHANGE_USER`. +/// +/// [`Opts`] won't be updated by default, because default `ChangeUserOpts` will reuse +/// connection's `user`, `pass` and `db_name`. +/// +/// [1]: https://dev.mysql.com/doc/c-api/5.7/en/mysql-change-user.html +#[derive(Clone, Eq, PartialEq)] +pub struct ChangeUserOpts { + user: Option>, + pass: Option>, + db_name: Option>, +} + +impl ChangeUserOpts { + pub(crate) fn update_opts(self, opts: &mut Opts) { + if self.user.is_none() && self.pass.is_none() && self.db_name.is_none() { + return; + } + + let mut builder = OptsBuilder::from_opts(opts.clone()); + + if let Some(user) = self.user { + builder = builder.user(user); + } + + if let Some(pass) = self.pass { + builder = builder.pass(pass); + } + + if let Some(db_name) = self.db_name { + builder = builder.db_name(db_name); + } + + *opts = Opts::from(builder); + } + + /// Creates change user options that'll reuse connection options. + pub fn new() -> Self { + Self { + user: None, + pass: None, + db_name: None, + } + } + + /// Set [`Opts::user`] to the given value. + pub fn with_user(mut self, user: Option) -> Self { + self.user = Some(user); + self + } + + /// Set [`Opts::pass`] to the given value. + pub fn with_pass(mut self, pass: Option) -> Self { + self.pass = Some(pass); + self + } + + /// Set [`Opts::db_name`] to the given value. + pub fn with_db_name(mut self, db_name: Option) -> Self { + self.db_name = Some(db_name); + self + } + + /// Returns user. + /// + /// * if `None` then `self` does not meant to change user + /// * if `Some(None)` then `self` will clear user + /// * if `Some(Some(_))` then `self` will change user + pub fn user(&self) -> Option> { + self.user.as_ref().map(|x| x.as_deref()) + } + + /// Returns password. + /// + /// * if `None` then `self` does not meant to change password + /// * if `Some(None)` then `self` will clear password + /// * if `Some(Some(_))` then `self` will change password + pub fn pass(&self) -> Option> { + self.pass.as_ref().map(|x| x.as_deref()) + } + + /// Returns database name. + /// + /// * if `None` then `self` does not meant to change database name + /// * if `Some(None)` then `self` will clear database name + /// * if `Some(Some(_))` then `self` will change database name + pub fn db_name(&self) -> Option> { + self.db_name.as_ref().map(|x| x.as_deref()) + } +} + +impl Default for ChangeUserOpts { + fn default() -> Self { + Self::new() + } +} + +impl fmt::Debug for ChangeUserOpts { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChangeUserOpts") + .field("user", &self.user) + .field( + "pass", + &self.pass.as_ref().map(|x| x.as_ref().map(|_| "...")), + ) + .field("db_name", &self.db_name) + .finish() + } +} + fn get_opts_user_from_url(url: &Url) -> Option { let user = url.username(); if !user.is_empty() { @@ -1048,24 +1629,23 @@ fn get_opts_user_from_url(url: &Url) -> Option { } fn get_opts_pass_from_url(url: &Url) -> Option { - if let Some(pass) = url.password() { - Some( - percent_decode(pass.as_ref()) - .decode_utf8_lossy() - .into_owned(), - ) - } else { - None - } + url.password().map(|pass| { + percent_decode(pass.as_ref()) + .decode_utf8_lossy() + .into_owned() + }) } fn get_opts_db_name_from_url(url: &Url) -> Option { if let Some(mut segments) = url.path_segments() { - segments.next().map(|db_name| { - percent_decode(db_name.as_ref()) - .decode_utf8_lossy() - .into_owned() - }) + segments + .next() + .map(|db_name| { + percent_decode(db_name.as_ref()) + .decode_utf8_lossy() + .into_owned() + }) + .filter(|db| !db.is_empty()) } else { None } @@ -1080,9 +1660,9 @@ fn from_url_basic(url: &Url) -> std::result::Result<(MysqlOpts, Vec<(String, Str if url.cannot_be_a_base() || !url.has_host() { return Err(UrlError::Invalid); } - let user = get_opts_user_from_url(&url); - let pass = get_opts_pass_from_url(&url); - let db_name = get_opts_db_name_from_url(&url); + let user = get_opts_user_from_url(url); + let pass = get_opts_pass_from_url(url); + let db_name = get_opts_db_name_from_url(url); let query_pairs = url.query_pairs().into_owned().collect(); let opts = MysqlOpts { @@ -1100,12 +1680,14 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { let mut pool_min = DEFAULT_POOL_CONSTRAINTS.min; let mut pool_max = DEFAULT_POOL_CONSTRAINTS.max; + let mut ssl_opts = None; let mut skip_domain_validation = false; let mut accept_invalid_certs = false; + let mut disable_built_in_roots = false; for (key, value) in query_pairs { if key == "pool_min" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(value) => pool_min = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1115,7 +1697,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "pool_max" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(value) => pool_max = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1125,11 +1707,10 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "inactive_connection_ttl" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts - .clone() .with_inactive_connection_ttl(Duration::from_secs(value)) } _ => { @@ -1140,11 +1721,10 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "ttl_check_interval" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => { opts.pool_opts = opts .pool_opts - .clone() .with_ttl_check_interval(Duration::from_secs(value)) } _ => { @@ -1155,7 +1735,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "conn_ttl" { - match u64::from_str(&*value) { + match u64::from_str(&value) { Ok(value) => opts.conn_ttl = Some(Duration::from_secs(value)), _ => { return Err(UrlError::InvalidParamValue { @@ -1164,8 +1744,36 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "abs_conn_ttl" { + match u64::from_str(&value) { + Ok(value) => { + opts.pool_opts = opts + .pool_opts + .with_abs_conn_ttl(Some(Duration::from_secs(value))) + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "abs_conn_ttl".into(), + value, + }); + } + } + } else if key == "abs_conn_ttl_jitter" { + match u64::from_str(&value) { + Ok(value) => { + opts.pool_opts = opts + .pool_opts + .with_abs_conn_ttl_jitter(Some(Duration::from_secs(value))) + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "abs_conn_ttl_jitter".into(), + value, + }); + } + } } else if key == "tcp_keepalive" { - match u32::from_str(&*value) { + match u32::from_str(&value) { Ok(value) => opts.tcp_keepalive = Some(value), _ => { return Err(UrlError::InvalidParamValue { @@ -1175,11 +1783,8 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "max_allowed_packet" { - match usize::from_str(&*value) { - Ok(value) => { - opts.max_allowed_packet = - Some(std::cmp::max(1024, std::cmp::min(1073741824, value))) - } + match usize::from_str(&value) { + Ok(value) => opts.max_allowed_packet = Some(value.clamp(1024, 1073741824)), _ => { return Err(UrlError::InvalidParamValue { param: "max_allowed_packet".into(), @@ -1188,7 +1793,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "wait_timeout" { - match usize::from_str(&*value) { + match usize::from_str(&value) { #[cfg(windows)] Ok(value) => opts.wait_timeout = Some(std::cmp::min(2147483, value)), #[cfg(not(windows))] @@ -1200,8 +1805,28 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "enable_cleartext_plugin" { + match bool::from_str(&value) { + Ok(parsed) => opts.enable_cleartext_plugin = parsed, + Err(_) => { + return Err(UrlError::InvalidParamValue { + param: key.to_string(), + value, + }); + } + } + } else if key == "reset_connection" { + match bool::from_str(&value) { + Ok(parsed) => opts.pool_opts = opts.pool_opts.with_reset_connection(parsed), + Err(_) => { + return Err(UrlError::InvalidParamValue { + param: key.to_string(), + value, + }); + } + } } else if key == "tcp_nodelay" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(value) => opts.tcp_nodelay = value, _ => { return Err(UrlError::InvalidParamValue { @@ -1211,7 +1836,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "stmt_cache_size" { - match usize::from_str(&*value) { + match usize::from_str(&value) { Ok(stmt_cache_size) => { opts.stmt_cache_size = stmt_cache_size; } @@ -1223,7 +1848,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "prefer_socket" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(prefer_socket) => { opts.prefer_socket = prefer_socket; } @@ -1235,7 +1860,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "secure_auth" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(secure_auth) => { opts.secure_auth = secure_auth; } @@ -1246,6 +1871,18 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "client_found_rows" { + match bool::from_str(&value) { + Ok(client_found_rows) => { + opts.client_found_rows = client_found_rows; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "client_found_rows".into(), + value, + }); + } + } } else if key == "socket" { opts.socket = Some(value) } else if key == "compression" { @@ -1266,8 +1903,10 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } else if key == "require_ssl" { - match bool::from_str(&*value) { - Ok(x) => opts.ssl_opts = x.then(SslOpts::default), + match bool::from_str(&value) { + Ok(x) => { + ssl_opts = x.then(SslOpts::default); + } _ => { return Err(UrlError::InvalidParamValue { param: "require_ssl".into(), @@ -1276,7 +1915,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "verify_ca" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(x) => { accept_invalid_certs = !x; } @@ -1288,7 +1927,7 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { } } } else if key == "verify_identity" { - match bool::from_str(&*value) { + match bool::from_str(&value) { Ok(x) => { skip_domain_validation = !x; } @@ -1299,13 +1938,25 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } } + } else if key == "built_in_roots" { + match bool::from_str(&value) { + Ok(x) => { + disable_built_in_roots = !x; + } + _ => { + return Err(UrlError::InvalidParamValue { + param: "built_in_roots".into(), + value, + }); + } + } } else { return Err(UrlError::UnknownParameter { param: key }); } } if let Some(pool_constraints) = PoolConstraints::new(pool_min, pool_max) { - opts.pool_opts = opts.pool_opts.clone().with_constraints(pool_constraints); + opts.pool_opts = opts.pool_opts.with_constraints(pool_constraints); } else { return Err(UrlError::InvalidPoolConstraints { min: pool_min, @@ -1313,11 +1964,14 @@ fn mysqlopts_from_url(url: &Url) -> std::result::Result { }); } - if let Some(ref mut ssl_opts) = opts.ssl_opts.as_mut() { + if let Some(ref mut ssl_opts) = ssl_opts { ssl_opts.accept_invalid_certs = accept_invalid_certs; ssl_opts.skip_domain_validation = skip_domain_validation; + ssl_opts.disable_built_in_roots = disable_built_in_roots; } + opts.ssl_opts = ssl_opts.map(SslOptsAndCachedConnector::new); + Ok(opts) } @@ -1329,7 +1983,7 @@ impl FromStr for Opts { } } -impl<'a> TryFrom<&'a str> for Opts { +impl TryFrom<&str> for Opts { type Error = UrlError; fn try_from(s: &str) -> std::result::Result { @@ -1342,7 +1996,7 @@ mod test { use super::{HostPortOrUrl, MysqlOpts, Opts, Url}; use crate::{error::UrlError::InvalidParamValue, SslOpts}; - use std::str::FromStr; + use std::{net::IpAddr, net::Ipv4Addr, net::Ipv6Addr, str::FromStr}; #[test] fn test_builder_eq_url() { @@ -1362,10 +2016,16 @@ mod test { assert_eq!(url_opts.pass(), builder_opts.pass()); assert_eq!(url_opts.db_name(), builder_opts.db_name()); assert_eq!(url_opts.init(), builder_opts.init()); + assert_eq!(url_opts.setup(), builder_opts.setup()); assert_eq!(url_opts.tcp_keepalive(), builder_opts.tcp_keepalive()); assert_eq!(url_opts.tcp_nodelay(), builder_opts.tcp_nodelay()); assert_eq!(url_opts.pool_opts(), builder_opts.pool_opts()); assert_eq!(url_opts.conn_ttl(), builder_opts.conn_ttl()); + assert_eq!(url_opts.abs_conn_ttl(), builder_opts.abs_conn_ttl()); + assert_eq!( + url_opts.abs_conn_ttl_jitter(), + builder_opts.abs_conn_ttl_jitter() + ); assert_eq!(url_opts.stmt_cache_size(), builder_opts.stmt_cache_size()); assert_eq!(url_opts.ssl_opts(), builder_opts.ssl_opts()); assert_eq!(url_opts.prefer_socket(), builder_opts.prefer_socket()); @@ -1383,13 +2043,15 @@ mod test { #[test] fn should_convert_url_into_opts() { - let url = "mysql://usr:pw@192.168.1.1:3309/dbname"; - let parsed_url = Url::parse("mysql://usr:pw@192.168.1.1:3309/dbname").unwrap(); + let url = "mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true"; + let parsed_url = + Url::parse("mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true").unwrap(); let mysql_opts = MysqlOpts { user: Some("usr".to_string()), pass: Some("pw".to_string()), db_name: Some("dbname".to_string()), + prefer_socket: true, ..MysqlOpts::default() }; let host = HostPortOrUrl::Url(parsed_url); @@ -1427,7 +2089,7 @@ mod test { ); const URL4: &str = - "mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false"; + "mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false&built_in_roots=false"; let opts = Opts::from_url(URL4).unwrap(); assert_eq!( opts.ssl_opts(), @@ -1435,6 +2097,7 @@ mod test { &SslOpts::default() .with_danger_accept_invalid_certs(true) .with_danger_skip_domain_validation(true) + .with_disable_built_in_roots(true) ) ); @@ -1503,4 +2166,58 @@ mod test { let opts = Opts::from_url("mysql://localhost/foo?compression=9").unwrap(); assert_eq!(opts.compression(), Some(crate::Compression::new(9))); } + + #[test] + fn test_builder_eq_url_empty_db() { + let builder = super::OptsBuilder::default(); + let builder_opts = Opts::from(builder); + + let url: &str = "mysql://iq-controller@localhost"; + let url_opts = super::Opts::from_str(url).unwrap(); + assert_eq!(url_opts.db_name(), builder_opts.db_name()); + + let url: &str = "mysql://iq-controller@localhost/"; + let url_opts = super::Opts::from_str(url).unwrap(); + assert_eq!(url_opts.db_name(), builder_opts.db_name()); + } + + #[test] + fn test_builder_update_port_host_resolved_ips() { + let builder = super::OptsBuilder::default() + .ip_or_hostname("foo") + .tcp_port(33306); + + let resolved = vec![ + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 7)), + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)), + ]; + let builder2 = builder + .clone() + .tcp_port(55223) + .resolved_ips(Some(resolved.clone())); + + let builder_opts = Opts::from(builder); + assert_eq!(builder_opts.ip_or_hostname(), "foo"); + assert_eq!(builder_opts.tcp_port(), 33306); + assert_eq!( + builder_opts.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 33306, + resolved_ips: None + } + ); + + let builder_opts2 = Opts::from(builder2); + assert_eq!(builder_opts2.ip_or_hostname(), "foo"); + assert_eq!(builder_opts2.tcp_port(), 55223); + assert_eq!( + builder_opts2.hostport_or_url(), + &HostPortOrUrl::HostPort { + host: "foo".to_string(), + port: 55223, + resolved_ips: Some(resolved), + } + ); + } } diff --git a/src/opts/native_tls_opts.rs b/src/opts/native_tls_opts.rs index 49eb4c46..a02c6041 100644 --- a/src/opts/native_tls_opts.rs +++ b/src/opts/native_tls_opts.rs @@ -1,25 +1,32 @@ -#![cfg(feature = "native-tls")] +#![cfg(feature = "native-tls-tls")] -use std::{borrow::Cow, path::Path}; +use std::borrow::Cow; + +use native_tls::Identity; + +use super::PathOrBuf; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClientIdentity { - pkcs12_path: Cow<'static, Path>, + pkcs12_archive: PathOrBuf<'static>, password: Option>, } impl ClientIdentity { - /// Creates new identity with the given path to the pkcs12 archive. - pub fn new(pkcs12_path: T) -> Self - where - T: Into>, - { + /// Creates new identity with the given pkcs12 archive. + pub fn new(pkcs12_archive: PathOrBuf<'static>) -> Self { Self { - pkcs12_path: pkcs12_path.into(), + pkcs12_archive, password: None, } } + /// Sets the pkcs12 archive. + pub fn with_pkcs12_archive(mut self, pkcs12_archive: PathOrBuf<'static>) -> Self { + self.pkcs12_archive = pkcs12_archive; + self + } + /// Sets the archive password. pub fn with_password(mut self, pass: T) -> Self where @@ -29,13 +36,19 @@ impl ClientIdentity { self } - /// Returns the pkcs12 archive path. - pub fn pkcs12_path(&self) -> &Path { - self.pkcs12_path.as_ref() + /// Returns the pkcs12 archive. + pub fn pkcs12_archive(&self) -> PathOrBuf<'_> { + self.pkcs12_archive.borrow() } /// Returns the archive password. pub fn password(&self) -> Option<&str> { self.password.as_ref().map(AsRef::as_ref) } + + pub(crate) async fn load(&self) -> crate::Result { + let der = self.pkcs12_archive.read().await?; + let password = self.password().unwrap_or_default(); + Ok(Identity::from_pkcs12(der.as_ref(), password)?) + } } diff --git a/src/opts/rustls_opts.rs b/src/opts/rustls_opts.rs index 143dc62d..ef954ea9 100644 --- a/src/opts/rustls_opts.rs +++ b/src/opts/rustls_opts.rs @@ -1,83 +1,81 @@ #![cfg(feature = "rustls-tls")] -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer}; use rustls_pemfile::{certs, rsa_private_keys}; use std::{borrow::Cow, path::Path}; +use super::PathOrBuf; + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ClientIdentity { - cert_chain_path: Cow<'static, Path>, - priv_key_path: Cow<'static, Path>, + cert_chain: PathOrBuf<'static>, + priv_key: PathOrBuf<'static>, } impl ClientIdentity { /// Creates new identity. /// - /// `cert_chain_path` - path to a certificate chain (in PEM or DER) - /// `priv_key_path` - path to a private key (in DER or PEM) (it'll take the first one) - pub fn new(cert_chain_path: T, priv_key_path: U) -> Self - where - T: Into>, - U: Into>, - { + /// `cert_chain` - certificate chain (in PEM or DER) + /// `priv_key` - private key (in DER or PEM) (it'll take the first one) + pub fn new(cert_chain: PathOrBuf<'static>, priv_key: PathOrBuf<'static>) -> Self { Self { - cert_chain_path: cert_chain_path.into(), - priv_key_path: priv_key_path.into(), + cert_chain, + priv_key, } } /// Sets the certificate chain path (in DER or PEM). - pub fn with_cert_chain_path(mut self, cert_chain_path: T) -> Self - where - T: Into>, - { - self.cert_chain_path = cert_chain_path.into(); + pub fn with_cert_chain(mut self, cert_chain: PathOrBuf<'static>) -> Self { + self.cert_chain = cert_chain; self } /// Sets the private key path (in DER or PEM) (it'll take the first one). - pub fn with_priv_key_path(mut self, priv_key_path: T) -> Self + pub fn with_priv_key(mut self, priv_key: PathOrBuf<'static>) -> Self where T: Into>, { - self.priv_key_path = priv_key_path.into(); + self.priv_key = priv_key; self } - /// Returns the certificate chain path. - pub fn cert_chain_path(&self) -> &Path { - self.cert_chain_path.as_ref() + /// Returns the certificate chain. + pub fn cert_chain(&self) -> PathOrBuf<'_> { + self.cert_chain.borrow() } - /// Returns the private key path. - pub fn priv_key_path(&self) -> &Path { - self.priv_key_path.as_ref() + /// Returns the private key. + pub fn priv_key(&self) -> PathOrBuf<'_> { + self.priv_key.borrow() } - pub(crate) fn load(&self) -> crate::Result<(Vec, PrivateKey)> { - let cert_data = std::fs::read(self.cert_chain_path.as_ref())?; - let key_data = std::fs::read(self.priv_key_path.as_ref())?; + pub(crate) async fn load( + &self, + ) -> crate::Result<(Vec>, PrivateKeyDer<'static>)> { + let cert_data = self.cert_chain.read().await?; + let key_data = self.priv_key.read().await?; let mut cert_chain = Vec::new(); if std::str::from_utf8(&cert_data).is_err() { - cert_chain.push(Certificate(cert_data)); + cert_chain.push(CertificateDer::from(cert_data.into_owned())); } else { - for cert in certs(&mut &*cert_data)? { - cert_chain.push(Certificate(cert)); + for cert in certs(&mut &*cert_data) { + cert_chain.push(cert?); } } - let priv_key; - if std::str::from_utf8(&key_data).is_err() { - priv_key = Some(PrivateKey(key_data)); + let priv_key = if std::str::from_utf8(&key_data).is_err() { + Some(PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from( + key_data.into_owned(), + ))) } else { - priv_key = rsa_private_keys(&mut &*key_data)? - .into_iter() - .take(1) - .map(PrivateKey) - .next(); - } + let mut priv_key = None; + for key in rsa_private_keys(&mut &*key_data).take(1) { + priv_key = Some(PrivateKeyDer::Pkcs1(key?.clone_key())); + } + priv_key + }; Ok(( cert_chain, diff --git a/src/query.rs b/src/query.rs index 976c1724..75ea1ca9 100644 --- a/src/query.rs +++ b/src/query.rs @@ -11,7 +11,6 @@ use std::borrow::Cow; use futures_util::FutureExt; use crate::{ - connection_like::ToConnectionResult, from_row, prelude::{FromRow, StatementLike, ToConnection}, tracing_utils::LevelInfo, @@ -217,11 +216,8 @@ impl Query for Q { C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; - conn.raw_query::<'_, _, LevelInfo>(self).await?; + let mut conn = conn.to_connection().resolve().await?; + conn.as_mut().raw_query::<'_, _, LevelInfo>(self).await?; Ok(QueryResult::new(conn)) } .boxed() @@ -264,14 +260,12 @@ where C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; + let mut conn = conn.to_connection().resolve().await?; - let statement = conn.get_statement(self.query).await?; + let statement = conn.as_mut().get_statement(self.query).await?; - conn.execute_statement(&statement, self.params.into()) + conn.as_mut() + .execute_statement(&statement, self.params.into()) .await?; Ok(QueryResult::new(conn)) @@ -324,15 +318,12 @@ where C: ToConnection<'a, 't> + 'a, { async move { - let mut conn = match conn.to_connection() { - ToConnectionResult::Immediate(conn) => conn, - ToConnectionResult::Mediate(fut) => fut.await?, - }; + let mut conn = conn.to_connection().resolve().await?; - let statement = conn.get_statement(self.query).await?; + let statement = conn.as_mut().get_statement(self.query).await?; for params in self.params { - conn.execute_statement(&statement, params).await?; + conn.as_mut().execute_statement(&statement, params).await?; } Ok(()) diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 79062ec6..5a6d93aa 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -31,7 +31,7 @@ use crate::{ query::AsQuery, queryable::query_result::ResultSetMeta, tracing_utils::{LevelInfo, LevelTrace, TracingLevel}, - BoxFuture, Column, Conn, Params, ResultSetStream, Row, + BoxFuture, Column, Conn, Connection, Params, ResultSetStream, Row, }; pub mod query_result; @@ -96,8 +96,7 @@ impl Conn { pub(crate) async fn clean_dirty(&mut self) -> Result<()> { self.drop_result().await?; if self.get_tx_status() == TxStatus::RequiresRollback { - self.set_tx_status(TxStatus::None); - self.exec_drop("ROLLBACK", ()).await?; + self.rollback_transaction().await?; } Ok(()) } @@ -273,7 +272,7 @@ pub trait Queryable: Send { async move { self.query_iter(query).await?.drop_result().await }.boxed() } - /// Exectues the given statement for each item in the given params iterator. + /// Executes the given statement for each item in the given params iterator. /// /// It'll prepare `stmt` (once), if necessary. fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> @@ -283,7 +282,7 @@ pub trait Queryable: Send { I::IntoIter: Send, P: Into + Send; - /// Exectues the given statement and collects the first result set. + /// Executes the given statement and collects the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -307,7 +306,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given statement and returns the first row of the first result set. + /// Executes the given statement and returns the first row of the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -335,7 +334,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given stmt and maps each row of the first result set. + /// Executes the given stmt and maps each row of the first result set. /// /// It'll prepare `stmt`, if necessary. /// @@ -367,7 +366,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given stmt and folds the first result set to a signel value. + /// Executes the given stmt and folds the first result set to a signel value. /// /// It'll prepare `stmt`, if necessary. /// @@ -399,7 +398,7 @@ pub trait Queryable: Send { .boxed() } - /// Exectues the given statement and drops the result. + /// Executes the given statement and drops the result. fn exec_drop<'a: 'b, 'b, S, P>(&'a mut self, stmt: S, params: P) -> BoxFuture<'b, ()> where S: StatementLike + 'b, @@ -538,7 +537,7 @@ impl Queryable for Conn { impl Queryable for Transaction<'_> { fn ping(&mut self) -> BoxFuture<'_, ()> { - self.0.ping() + self.0.as_mut().ping() } fn query_iter<'a, Q>( @@ -548,18 +547,18 @@ impl Queryable for Transaction<'_> { where Q: AsQuery + 'a, { - self.0.query_iter(query) + self.0.as_mut().query_iter(query) } fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement> where Q: AsQuery + 'a, { - self.0.prep(query) + self.0.as_mut().prep(query) } fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { - self.0.close(stmt) + self.0.as_mut().close(stmt) } fn exec_iter<'a: 's, 's, Q, P>( @@ -571,7 +570,7 @@ impl Queryable for Transaction<'_> { Q: StatementLike + 'a, P: Into, { - self.0.exec_iter(stmt, params) + self.0.as_mut().exec_iter(stmt, params) } fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> @@ -581,13 +580,63 @@ impl Queryable for Transaction<'_> { I::IntoIter: Send, P: Into + Send, { - self.0.exec_batch(stmt, params_iter) + self.0.as_mut().exec_batch(stmt, params_iter) + } +} + +impl<'c, 't: 'c> Queryable for Connection<'c, 't> { + #[inline] + fn ping(&mut self) -> BoxFuture<'_, ()> { + self.as_mut().ping() + } + + #[inline] + fn query_iter<'a, Q>( + &'a mut self, + query: Q, + ) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>> + where + Q: AsQuery + 'a, + { + self.as_mut().query_iter(query) + } + + fn prep<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Statement> + where + Q: AsQuery + 'a, + { + self.as_mut().prep(query) + } + + fn close(&mut self, stmt: Statement) -> BoxFuture<'_, ()> { + self.as_mut().close(stmt) + } + + fn exec_iter<'a: 's, 's, Q, P>( + &'a mut self, + stmt: Q, + params: P, + ) -> BoxFuture<'s, QueryResult<'a, 'static, BinaryProtocol>> + where + Q: StatementLike + 'a, + P: Into, + { + self.as_mut().exec_iter(stmt, params) + } + + fn exec_batch<'a: 'b, 'b, S, P, I>(&'a mut self, stmt: S, params_iter: I) -> BoxFuture<'b, ()> + where + S: StatementLike + 'b, + I: IntoIterator + Send + 'b, + I::IntoIter: Send, + P: Into + Send, + { + self.as_mut().exec_batch(stmt, params_iter) } } #[cfg(test)] mod tests { - use super::Queryable; use crate::{error::Result, prelude::*, test_misc::get_opts, Conn}; #[tokio::test] diff --git a/src/queryable/query_result/mod.rs b/src/queryable/query_result/mod.rs index 21f45f96..a83ebd6a 100644 --- a/src/queryable/query_result/mod.rs +++ b/src/queryable/query_result/mod.rs @@ -124,21 +124,21 @@ where if columns.is_empty() { // Empty, but not yet consumed result set. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; } else { // Not yet consumed non-empty result set. - let packet = match self.conn.read_packet().await { + let packet = match self.conn.as_mut().read_packet().await { Ok(packet) => packet, Err(err) => { // Next row contained an error. No more data will follow. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; return Err(err); } }; if P::is_last_result_set_packet(self.conn.capabilities(), &packet) { // `packet` is a result set terminator. - self.conn.set_pending_result(None)?; + self.conn.as_mut().set_pending_result(None)?; } else { // `packet` is a result set row. row = Some(P::read_result_set_row(&packet, columns)?); @@ -154,7 +154,10 @@ where async fn next_set(&mut self) -> crate::Result { if self.conn.more_results_exists() { // More data will follow. - self.conn.routine(NextSetRoutine::

::new()).await?; + self.conn + .as_mut() + .routine(NextSetRoutine::

::new()) + .await?; } Ok(self.conn.has_pending_result()) } @@ -174,23 +177,23 @@ where columns: Arc<[Column]>, ) -> crate::Result> { if let Some(row) = self.next_row(columns).await? { - return Ok(Some(row)); + Ok(Some(row)) } else { self.next_set().await?; - return Ok(None); + Ok(None) } } /// Skips the taken result set. async fn skip_taken(&mut self, meta: Arc) -> crate::Result<()> { - while let Some(_) = self.next_row_or_next_set((*meta).clone()).await? {} + while (self.next_row_or_next_set((*meta).clone()).await?).is_some() {} Ok(()) } #[doc(hidden)] pub async fn next(&mut self) -> Result> { loop { - match self.conn.use_pending_result()?.cloned() { + match self.conn.as_mut().use_pending_result()?.cloned() { Some(PendingResult::Pending(meta)) => return self.next_row_or_next_set(meta).await, Some(PendingResult::Taken(meta)) => self.skip_taken(meta).await?, None => return Ok(None), diff --git a/src/queryable/query_result/result_set_stream.rs b/src/queryable/query_result/result_set_stream.rs index 9210d9b8..b779e93c 100644 --- a/src/queryable/query_result/result_set_stream.rs +++ b/src/queryable/query_result/result_set_stream.rs @@ -102,6 +102,16 @@ impl<'r, 'a: 'r, 't: 'a, T, P> ResultSetStream<'r, 'a, 't, T, P> { .unwrap_or_default() } + /// See [`QueryResult::columns_ref`]. + pub fn columns_ref(&self) -> &[Column] { + &self.columns[..] + } + + /// See [`QueryResult::columns`]. + pub fn columns(&self) -> Arc<[Column]> { + self.columns.clone() + } + /// See [`Conn::info`][1]. /// /// [1]: crate::Conn::info @@ -179,7 +189,7 @@ where async fn setup_stream( &mut self, ) -> crate::Result>, Arc<[Column]>)>> { - match self.conn.use_pending_result()? { + match self.conn.as_mut().use_pending_result()? { Some(PendingResult::Taken(meta)) => { let meta = (*meta).clone(); self.skip_taken(meta).await?; @@ -189,7 +199,7 @@ where } let ok_packet = self.conn.last_ok_packet().cloned(); - let columns = match self.conn.take_pending_result()? { + let columns = match self.conn.as_mut().take_pending_result()? { Some(meta) => meta.columns().clone(), None => return Ok(None), }; diff --git a/src/queryable/query_result/tests.rs b/src/queryable/query_result/tests.rs index 24c870cf..3aa80081 100644 --- a/src/queryable/query_result/tests.rs +++ b/src/queryable/query_result/tests.rs @@ -162,7 +162,7 @@ async fn should_map_resultset() -> super::Result<()> { ) .await?; - let rows_1 = result.map(|row| from_row::<(String, u8)>(row)).await?; + let rows_1 = result.map(from_row::<(String, u8)>).await?; let rows_2 = result.map_and_drop(from_row).await?; conn.disconnect().await?; @@ -219,7 +219,7 @@ async fn should_handle_multi_result_sets_where_some_results_have_no_output() -> r.for_each_and_drop(|x| assert_eq!(from_row::(x), 1)) .await?; let r = t.query_iter(QUERY).await?; - let out = r.map_and_drop(|row| from_row::(row)).await?; + let out = r.map_and_drop(from_row::).await?; assert_eq!(vec![1], out); let r = t.query_iter(QUERY).await?; let out = r diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 6b4a2eb8..0ce7294b 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -9,7 +9,7 @@ use futures_util::FutureExt; use mysql_common::{ io::ParseBuf, - named_params::parse_named_params, + named_params::ParsedNamedParams, packets::{ComStmtClose, StmtPacket}, }; @@ -45,12 +45,22 @@ fn to_statement_move<'a, T: AsQuery + 'a>( ) -> ToStatementResult<'a> { let fut = async move { let query = stmt.as_query(); - let (named_params, raw_query) = parse_named_params(query.as_ref())?; - let inner_stmt = match conn.get_cached_stmt(&*raw_query) { + let parsed = ParsedNamedParams::parse(query.as_ref())?; + let inner_stmt = match conn.get_cached_stmt(parsed.query()) { Some(inner_stmt) => inner_stmt, - None => conn.prepare_statement(raw_query).await?, + None => { + conn.prepare_statement(Cow::Borrowed(parsed.query())) + .await? + } }; - Ok(Statement::new(inner_stmt, named_params)) + Ok(Statement::new( + inner_stmt, + parsed + .params() + .iter() + .map(|x| x.as_ref().to_vec()) + .collect::>(), + )) } .boxed(); ToStatementResult::Mediate(fut) @@ -240,11 +250,12 @@ impl StmtInner { #[derive(Debug, Clone, Eq, PartialEq)] pub struct Statement { pub(crate) inner: Arc, - pub(crate) named_params: Option>>, + /// An empty vector in case of no named params. + pub(crate) named_params: Vec>, } impl Statement { - pub(crate) fn new(inner: Arc, named_params: Option>>) -> Self { + pub(crate) fn new(inner: Arc, named_params: Vec>) -> Self { Self { inner, named_params, @@ -295,7 +306,7 @@ impl crate::Conn { let packets = self.read_packets(num).await?; let defs = packets .into_iter() - .map(|x| ParseBuf(&*x).parse(())) + .map(|x| ParseBuf(&x).parse(())) .collect::, _>>() .map_err(Error::from)?; diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index 53b38ed8..b857f1a7 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -126,6 +126,7 @@ impl fmt::Display for IsolationLevel { /// /// You should always call either `commit` or `rollback`, otherwise transaction will be rolled /// back implicitly when corresponding connection is dropped or queried. +#[must_use = "transaction object must be committed or rolled back explicitly"] #[derive(Debug)] pub struct Transaction<'a>(pub(crate) Connection<'a, 'static>); @@ -142,6 +143,8 @@ impl<'a> Transaction<'a> { let mut conn = conn.into(); + conn.as_mut().clean_dirty().await?; + if conn.get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); } @@ -152,42 +155,55 @@ impl<'a> Transaction<'a> { if let Some(isolation_level) = isolation_level { let query = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation_level); - conn.query_drop(query).await?; + conn.as_mut().query_drop(query).await?; } if let Some(readonly) = readonly { if readonly { - conn.query_drop("SET TRANSACTION READ ONLY").await?; + conn.as_mut() + .query_drop("SET TRANSACTION READ ONLY") + .await?; } else { - conn.query_drop("SET TRANSACTION READ WRITE").await?; + conn.as_mut() + .query_drop("SET TRANSACTION READ WRITE") + .await?; } } if consistent_snapshot { - conn.query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT") + conn.as_mut() + .query_drop("START TRANSACTION WITH CONSISTENT SNAPSHOT") .await? } else { - conn.query_drop("START TRANSACTION").await? + conn.as_mut().query_drop("START TRANSACTION").await? }; - conn.set_tx_status(TxStatus::InTransaction); + conn.as_mut().set_tx_status(TxStatus::InTransaction); Ok(Transaction(conn)) } - /// Performs `COMMIT` query. - pub async fn commit(mut self) -> Result<()> { - let result = self.0.query_iter("COMMIT").await?; + /// Performs `COMMIT` query or returns an error + async fn try_commit(&mut self) -> Result<()> { + let result = self.0.as_mut().query_iter("COMMIT").await?; result.drop_result().await?; - self.0.set_tx_status(TxStatus::None); + self.0.as_mut().set_tx_status(TxStatus::None); Ok(()) } + /// Performs `COMMIT` query or rollbacks when any error occurs and returns the original error. + pub async fn commit(mut self) -> Result<()> { + match self.try_commit().await { + Ok(..) => Ok(()), + Err(e) => { + self.0.as_mut().rollback_transaction().await.unwrap_or(()); + Err(e) + } + } + } + /// Performs `ROLLBACK` query. pub async fn rollback(mut self) -> Result<()> { - let result = self.0.query_iter("ROLLBACK").await?; - result.drop_result().await?; - self.0.set_tx_status(TxStatus::None); - Ok(()) + self.0.as_mut().rollback_transaction().await } } @@ -195,14 +211,14 @@ impl Deref for Transaction<'_> { type Target = Conn; fn deref(&self) -> &Self::Target { - &*self.0 + &self.0 } } impl Drop for Transaction<'_> { fn drop(&mut self) { if self.0.get_tx_status() == TxStatus::InTransaction { - self.0.set_tx_status(TxStatus::RequiresRollback); + self.0.as_mut().set_tx_status(TxStatus::RequiresRollback); } } } diff --git a/src/tracing_utils.rs b/src/tracing_utils.rs index b32170c0..190e5d33 100644 --- a/src/tracing_utils.rs +++ b/src/tracing_utils.rs @@ -44,7 +44,28 @@ macro_rules! instrument_result { ($fut:expr, $span:expr) => {{ let fut = async { $fut.await.or_else(|e| { - tracing::error!(error = %e); + match &e { + $crate::error::Error::Server(server_error) => { + match server_error.code { + // Duplicated entry for key + 1062 => { + tracing::warn!(error = %e) + } + // Foreign key violation + 1451 => { + tracing::warn!(error = %e) + } + // User defined exception condition + 1644 => { + tracing::warn!(error = %e); + } + _ => tracing::error!(error = %e), + } + }, + e => { + tracing::error!(error = %e); + } + } Err(e) }) }; diff --git a/tests/exports.rs b/tests/exports.rs index c8b13137..6f9feef8 100644 --- a/tests/exports.rs +++ b/tests/exports.rs @@ -4,12 +4,17 @@ use mysql_async::{ futures::{DisconnectPool, GetConn}, params, prelude::{ - BatchQuery, ConvIr, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, - StatementLike, ToValue, + BatchQuery, FromRow, FromValue, GlobalHandler, Protocol, Query, Queryable, StatementLike, + ToValue, }, BinaryProtocol, Column, Conn, Deserialized, DriverError, Error, FromRowError, FromValueError, - IoError, IsolationLevel, Opts, OptsBuilder, Params, ParseError, Pool, PoolConstraints, - PoolOpts, QueryResult, Result, Row, Serialized, ServerError, SslOpts, Statement, TextProtocol, - Transaction, TxOpts, UrlError, Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, - DEFAULT_TTL_CHECK_INTERVAL, + GnoInterval, Gtids, IoError, IsolationLevel, OkPacket, Opts, OptsBuilder, Params, ParseError, + Pool, PoolConstraints, PoolOpts, QueryResult, Result, Row, Schema, Serialized, ServerError, + SessionStateChange, SessionStateInfo, Sid, SslOpts, Statement, SystemVariable, TextProtocol, + Transaction, TransactionCharacteristics, TransactionState, TxOpts, Unsupported, UrlError, + Value, WhiteListFsHandler, DEFAULT_INACTIVE_CONNECTION_TTL, DEFAULT_TTL_CHECK_INTERVAL, }; + +#[cfg(feature = "binlog")] +#[allow(unused_imports)] +use mysql_async::{binlog, BinlogStream, BinlogStreamRequest}; diff --git a/tests/generic.rs b/tests/generic.rs index 3f14891d..b0451f20 100644 --- a/tests/generic.rs +++ b/tests/generic.rs @@ -33,7 +33,7 @@ where TupleType: FromRow + Send + 'static, P: Protocol + Send + 'static, { - Ok(result.collect().await?) + result.collect().await } pub async fn get_single_result(result: QueryResult<'_, '_, P>) -> Result @@ -51,7 +51,7 @@ where #[tokio::test] async fn use_generic_code() { - let pool = Pool::new(Opts::from_url(&*get_url()).unwrap()); + let pool = Pool::new(Opts::from_url(&get_url()).unwrap()); let mut conn = pool.get_conn().await.unwrap(); let result = conn.query_iter("SELECT 1, 2, 3").await.unwrap(); let result = get_single_result::<(u8, u8, u8), _>(result).await.unwrap();