Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 110 additions & 28 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ lazy_static::lazy_static! {
static ref FIXED_MARIADB_VERSION_RE: Regex =
Regex::new(r"^(?:5.5.5-)?(\d{1,2})\.(\d{1,2})\.(\d{1,3})-MariaDB").unwrap();
}
const DEFAULT_WAIT_TIMEOUT: usize = 28800;

/// Helper that asynchronously disconnects the givent connection on the default tokio executor.
fn disconnect(mut conn: Conn) {
Expand Down Expand Up @@ -917,42 +918,123 @@ impl Conn {
/// * It reads and stores `wait_timeout` in the connection unless it's already in [`Opts`]
///
async fn read_settings(&mut self) -> Result<()> {
let read_socket = self.inner.opts.prefer_socket() && self.inner.socket.is_none();
let read_max_allowed_packet = self.opts().max_allowed_packet().is_none();
let read_wait_timeout = self.opts().wait_timeout().is_none();
enum Action {
Load(Cfg),
Apply(CfgData),
}

let settings: Option<Row> = if read_socket || read_max_allowed_packet || read_wait_timeout {
self.query_internal("SELECT @@socket, @@max_allowed_packet, @@wait_timeout")
.await?
} else {
None
};
enum CfgData {
MaxAllowedPacket(usize),
WaitTimeout(usize),
}

// set socket inside the connection
if read_socket {
self.inner.socket = settings.as_ref().map(|s| s.get("@@socket")).unwrap_or(None);
impl CfgData {
fn apply(&self, conn: &mut Conn) {
match self {
Self::MaxAllowedPacket(value) => {
if let Some(stream) = conn.inner.stream.as_mut() {
stream.set_max_allowed_packet(*value);
}
}
Self::WaitTimeout(value) => {
conn.inner.wait_timeout = Duration::from_secs(*value as u64);
}
}
}
}

// set max_allowed_packet
let max_allowed_packet = if read_max_allowed_packet {
settings
.as_ref()
.map(|s| s.get("@@max_allowed_packet"))
.unwrap()
} else {
self.opts().max_allowed_packet()
};
if let Some(stream) = self.inner.stream.as_mut() {
stream.set_max_allowed_packet(max_allowed_packet.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET));
enum Cfg {
Socket,
MaxAllowedPacket,
WaitTimeout,
}

impl Cfg {
const fn name(&self) -> &'static str {
match self {
Self::Socket => "@@socket",
Self::MaxAllowedPacket => "@@max_allowed_packet",
Self::WaitTimeout => "@@wait_timeout",
}
}

fn apply(&self, conn: &mut Conn, value: Option<crate::Value>) {
match self {
Cfg::Socket => {
conn.inner.socket = value.map(crate::from_value).flatten();
}
Cfg::MaxAllowedPacket => {
if let Some(stream) = conn.inner.stream.as_mut() {
stream.set_max_allowed_packet(
value
.map(crate::from_value)
.flatten()
.unwrap_or(DEFAULT_MAX_ALLOWED_PACKET),
);
}
}
Cfg::WaitTimeout => {
conn.inner.wait_timeout = Duration::from_secs(
value
.map(crate::from_value)
.flatten()
.unwrap_or(DEFAULT_WAIT_TIMEOUT) as u64,
);
}
}
}
}

// set read_wait_timeout
let wait_timeout = if read_wait_timeout {
settings.as_ref().map(|s| s.get("@@wait_timeout")).unwrap()
let mut actions = vec![
if let Some(x) = self.opts().max_allowed_packet() {
Action::Apply(CfgData::MaxAllowedPacket(x))
} else {
Action::Load(Cfg::MaxAllowedPacket)
},
if let Some(x) = self.opts().wait_timeout() {
Action::Apply(CfgData::WaitTimeout(x))
} else {
Action::Load(Cfg::WaitTimeout)
},
];

if self.inner.opts.prefer_socket() && self.inner.socket.is_none() {
actions.push(Action::Load(Cfg::Socket))
}

let loads = actions
.iter()
.filter_map(|x| match x {
Action::Load(x) => Some(x),
Action::Apply(_) => None,
})
.collect::<Vec<_>>();

let loaded = if !loads.is_empty() {
let query = loads
.iter()
.zip(std::iter::once(' ').chain(std::iter::repeat(',')))
.fold("SELECT".to_owned(), |mut acc, (cfg, prefix)| {
acc.push(prefix);
acc.push_str(cfg.name());
acc
});

self.query_internal::<Row, String>(query)
.await?
.map(|row| row.unwrap())
.unwrap_or_else(|| vec![crate::Value::NULL; loads.len()])
} else {
self.opts().wait_timeout()
vec![]
};
self.inner.wait_timeout = Duration::from_secs(wait_timeout.unwrap_or(28800) as u64);
let mut loaded = loaded.into_iter();

for action in actions {
match action {
Action::Load(cfg) => cfg.apply(self, loaded.next()),
Action::Apply(cfg) => cfg.apply(self),
}
}

Ok(())
}
Expand Down