From 2b6eca56cd3653a3bf9fb85454b91dad72b8f34c Mon Sep 17 00:00:00 2001 From: chenmch Date: Sat, 28 Feb 2026 12:20:12 +0800 Subject: [PATCH 1/3] feat: arrow_cli -c support --- src/main.rs | 5 +++- src/session.rs | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 5358b2a..efd3823 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,6 +42,9 @@ struct Args { #[clap(long, default_value = "false", help = "Print resultset schema")] print_schema: bool, + + #[clap(short = 'c', long, help = "Execute SQL command and exit")] + command: Option, } #[tokio::main] @@ -52,7 +55,7 @@ pub async fn main() -> Result<(), ArrowError> { // Authenticate let url = format!("{protocol}://{}:{}", args.host, args.port); let endpoint = endpoint(&args, url)?; - let is_repl = atty::is(Stream::Stdin); + let is_repl = atty::is(Stream::Stdin) && args.command.is_none(); let mut session = session::Session::try_new(endpoint, is_repl, args).await?; session.handle().await; diff --git a/src/session.rs b/src/session.rs index 623de03..e524532 100644 --- a/src/session.rs +++ b/src/session.rs @@ -56,11 +56,75 @@ impl Session { pub async fn handle(&mut self) { if self.is_repl { self.handle_repl().await; + } else if self.args.command.is_some() { + let command = self.args.command.clone().unwrap(); + self.handle_command(&command).await; } else { self.handle_stdin().await; } } + pub async fn handle_command(&mut self, command: &str) { + if let Err(e) = self.handle_query_command(command).await { + eprintln!("handle_query err: {e}"); + } + } + + pub async fn handle_query_command(&mut self, query: &str) -> Result { + let start = Instant::now(); + let flight_info = if self.args.prepared { + let mut stmt = self.client.prepare(query.to_string(), None).await?; + let info = stmt.execute().await?; + stmt.close().await?; + info + } else { + self.client.execute(query.to_string(), None).await? + }; + let ticket_recv_duration = start.elapsed(); + let mut batches: Vec = Vec::new(); + + let mut handles = Vec::with_capacity(flight_info.endpoint.len()); + for endpoint in flight_info.endpoint.iter() { + let ticket = endpoint + .ticket + .as_ref() + .ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))? + .clone(); + let mut client = self.client.clone(); + handles.push(tokio::spawn(async move { + let flight_data = client.do_get(ticket).await?; + let result: Vec = flight_data.try_collect().await.map_err(|e| { + ArrowError::IpcError(format!("Failed to collect record batches: {e}")) + })?; + Ok::, ArrowError>(result) + })); + } + + for handle in handles { + batches.extend(handle.await.unwrap()?); + } + let rows_recv_duration = start.elapsed(); + + // Use pretty format for command mode (like psql -c) + let res = pretty_format_batches(batches.as_slice())?; + println!("{res}"); + + if self.args.print_schema { + let schema = flight_info.try_decode_schema()?; + println!("{schema:#?}"); + } + + let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!( + "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)", + rows, + ticket_recv_duration.as_secs_f64(), + rows_recv_duration.as_secs_f64(), + ); + + Ok(false) + } + pub async fn handle_repl(&mut self) { let mut query = "".to_owned(); let mut rl = Editor::::new().unwrap(); From 6eef99dc20a0835429ae78aef401d3aefa48fb47 Mon Sep 17 00:00:00 2001 From: chenmch Date: Sat, 28 Feb 2026 13:40:30 +0800 Subject: [PATCH 2/3] optimize: Use execute_query in handle_repl to eliminate code repetition --- src/session.rs | 89 +++++++++++++++++++++----------------------------- 1 file changed, 38 insertions(+), 51 deletions(-) diff --git a/src/session.rs b/src/session.rs index e524532..a181e0b 100644 --- a/src/session.rs +++ b/src/session.rs @@ -71,6 +71,41 @@ impl Session { } pub async fn handle_query_command(&mut self, query: &str) -> Result { + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(query).await?; + + // Use pretty format for command mode (like psql -c) + let res = pretty_format_batches(batches.as_slice())?; + println!("{res}"); + + if self.args.print_schema { + let schema = flight_info.try_decode_schema()?; + println!("{schema:#?}"); + } + + let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!( + "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)", + rows, + ticket_recv_duration.as_secs_f64(), + rows_recv_duration.as_secs_f64(), + ); + + Ok(false) + } + + async fn execute_query( + &mut self, + query: &str, + ) -> Result< + ( + Vec, + std::time::Duration, + std::time::Duration, + arrow_flight::FlightInfo, + ), + ArrowError, + > { let start = Instant::now(); let flight_info = if self.args.prepared { let mut stmt = self.client.prepare(query.to_string(), None).await?; @@ -105,24 +140,7 @@ impl Session { } let rows_recv_duration = start.elapsed(); - // Use pretty format for command mode (like psql -c) - let res = pretty_format_batches(batches.as_slice())?; - println!("{res}"); - - if self.args.print_schema { - let schema = flight_info.try_decode_schema()?; - println!("{schema:#?}"); - } - - let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - println!( - "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)", - rows, - ticket_recv_duration.as_secs_f64(), - rows_recv_duration.as_secs_f64(), - ); - - Ok(false) + Ok((batches, ticket_recv_duration, rows_recv_duration, flight_info)) } pub async fn handle_repl(&mut self) { @@ -191,39 +209,8 @@ impl Session { println!("\n{}\n", query); } - let start = Instant::now(); - let flight_info = if self.args.prepared { - let mut stmt = self.client.prepare(query.to_string(), None).await?; - let info = stmt.execute().await?; - stmt.close().await?; - info - } else { - self.client.execute(query.to_string(), None).await? - }; - let ticket_recv_duration = start.elapsed(); - let mut batches: Vec = Vec::new(); - - let mut handles = Vec::with_capacity(flight_info.endpoint.len()); - for endpoint in flight_info.endpoint.iter() { - let ticket = endpoint - .ticket - .as_ref() - .ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))? - .clone(); - let mut client = self.client.clone(); - handles.push(tokio::spawn(async move { - let flight_data = client.do_get(ticket).await?; - let result: Vec = flight_data.try_collect().await.map_err(|e| { - ArrowError::IpcError(format!("Failed to collect record batches: {e}")) - })?; - Ok::, ArrowError>(result) - })); - } - - for handle in handles { - batches.extend(handle.await.unwrap()?); - } - let rows_recv_duration = start.elapsed(); + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(query).await?; if is_repl { let res = pretty_format_batches(batches.as_slice())?; From 20ba119643e732047ca27a79a04ba99179159cd3 Mon Sep 17 00:00:00 2001 From: chenmch Date: Sat, 28 Feb 2026 13:45:45 +0800 Subject: [PATCH 3/3] fmt code --- src/session.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/session.rs b/src/session.rs index a181e0b..8cfefac 100644 --- a/src/session.rs +++ b/src/session.rs @@ -140,7 +140,12 @@ impl Session { } let rows_recv_duration = start.elapsed(); - Ok((batches, ticket_recv_duration, rows_recv_duration, flight_info)) + Ok(( + batches, + ticket_recv_duration, + rows_recv_duration, + flight_info, + )) } pub async fn handle_repl(&mut self) {