diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 4b860c07..ee4d46ce 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -61,6 +61,7 @@ lazy_static::lazy_static! { static ref FIXED_MARIADB_VERSION_RE: Regex = Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap(); } +const DEFAULT_WAIT_TIMEOUT: usize = 28800; /// Helper that asynchronously disconnects the givent connection on the default tokio executor. fn disconnect(mut conn: Conn) { @@ -917,42 +918,123 @@ impl Conn { /// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`] /// async fn read_settings(&mut self) -> Result<()> { - let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none(); - let read_max_allowed_packet = self.opts().max_allowed_packet().is_none(); - let read_wait_timeout = self.opts().wait_timeout().is_none(); + enum Action { + Load(Cfg), + Apply(CfgData), + } - let settings: Option = if read_socket || read_max_allowed_packet || read_wait_timeout { - self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout") - .await? - } else { - None - }; + enum CfgData { + MaxAllowedPacket(usize), + WaitTimeout(usize), + } - // set socket inside the connection - if read_socket { - self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None); + impl CfgData { + fn apply(&self, conn: &mut Conn) { + match self { + Self::MaxAllowedPacket(value) => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet(*value); + } + } + Self::WaitTimeout(value) => { + conn.inner.wait_timeout = Duration::from_secs(*value as u64); + } + } + } } - // set max_allowed_packet - let max_allowed_packet = if read_max_allowed_packet { - settings - .as_ref() - .map(|s| s.get("@@max_allowed_packet")) - .unwrap() - } else { - self.opts().max_allowed_packet() - }; - if let Some(stream) = self.inner.stream.as_mut() { - stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET)); + enum Cfg { + Socket, + MaxAllowedPacket, + WaitTimeout, + } + + impl Cfg { + const fn name(&self) -> &'static str { + match self { + Self::Socket => "@@socket", + Self::MaxAllowedPacket => "@@max_allowed_packet", + Self::WaitTimeout => "@@wait_timeout", + } + } + + fn apply(&self, conn: &mut Conn, value: Option) { + match self { + Cfg::Socket => { + conn.inner.socket = value.map(crate::from_value).flatten(); + } + Cfg::MaxAllowedPacket => { + if let Some(stream) = conn.inner.stream.as_mut() { + stream.set_max_allowed_packet( + value + .map(crate::from_value) + .flatten() + .unwrap_or(DEFAULT_MAX_ALLOWED_PACKET), + ); + } + } + Cfg::WaitTimeout => { + conn.inner.wait_timeout = Duration::from_secs( + value + .map(crate::from_value) + .flatten() + .unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64, + ); + } + } + } } - // set read_wait_timeout - let wait_timeout = if read_wait_timeout { - settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap() + let mut actions = vec![ + if let Some(x) = self.opts().max_allowed_packet() { + Action::Apply(CfgData::MaxAllowedPacket(x)) + } else { + Action::Load(Cfg::MaxAllowedPacket) + }, + if let Some(x) = self.opts().wait_timeout() { + Action::Apply(CfgData::WaitTimeout(x)) + } else { + Action::Load(Cfg::WaitTimeout) + }, + ]; + + if self.inner.opts.prefer_socket() && self.inner.socket.is_none() { + actions.push(Action::Load(Cfg::Socket)) + } + + let loads = actions + .iter() + .filter_map(|x| match x { + Action::Load(x) => Some(x), + Action::Apply(_) => None, + }) + .collect::>(); + + let loaded = if !loads.is_empty() { + let query = loads + .iter() + .zip(std::iter::once(' ').chain(std::iter::repeat(','))) + .fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| { + acc.push(prefix); + acc.push_str(cfg.name()); + acc + }); + + self.query_internal::(query) + .await? + .map(|row| row.unwrap()) + .unwrap_or_else(|| vec![crate::Value::NULL; loads.len()]) } else { - self.opts().wait_timeout() + vec![] }; - self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64); + let mut loaded = loaded.into_iter(); + + for action in actions { + match action { + Action::Load(cfg) => cfg.apply(self, loaded.next()), + Action::Apply(cfg) => cfg.apply(self), + } + } Ok(()) }