diff --git a/src/main.rs b/src/main.rs index 5358b2a..efd3823 100644 --- a/src/main.rs +++ b/src/main.rs @@ -42,6 +42,9 @@ struct Args { #[clap(long, default_value = "false", help = "Print resultset schema")] print_schema: bool, + + #[clap(short = 'c', long, help = "Execute SQL command and exit")] + command: Option, } #[tokio::main] @@ -52,7 +55,7 @@ pub async fn main() -> Result<(), ArrowError> { // Authenticate let url = format!("{protocol}://{}:{}", args.host, args.port); let endpoint = endpoint(&args, url)?; - let is_repl = atty::is(Stream::Stdin); + let is_repl = atty::is(Stream::Stdin) && args.command.is_none(); let mut session = session::Session::try_new(endpoint, is_repl, args).await?; session.handle().await; diff --git a/src/session.rs b/src/session.rs index 623de03..8cfefac 100644 --- a/src/session.rs +++ b/src/session.rs @@ -56,11 +56,98 @@ impl Session { pub async fn handle(&mut self) { if self.is_repl { self.handle_repl().await; + } else if self.args.command.is_some() { + let command = self.args.command.clone().unwrap(); + self.handle_command(&command).await; } else { self.handle_stdin().await; } } + pub async fn handle_command(&mut self, command: &str) { + if let Err(e) = self.handle_query_command(command).await { + eprintln!("handle_query err: {e}"); + } + } + + pub async fn handle_query_command(&mut self, query: &str) -> Result { + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(query).await?; + + // Use pretty format for command mode (like psql -c) + let res = pretty_format_batches(batches.as_slice())?; + println!("{res}"); + + if self.args.print_schema { + let schema = flight_info.try_decode_schema()?; + println!("{schema:#?}"); + } + + let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!( + "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)", + rows, + ticket_recv_duration.as_secs_f64(), + rows_recv_duration.as_secs_f64(), + ); + + Ok(false) + } + + async fn execute_query( + &mut self, + query: &str, + ) -> Result< + ( + Vec, + std::time::Duration, + std::time::Duration, + arrow_flight::FlightInfo, + ), + ArrowError, + > { + let start = Instant::now(); + let flight_info = if self.args.prepared { + let mut stmt = self.client.prepare(query.to_string(), None).await?; + let info = stmt.execute().await?; + stmt.close().await?; + info + } else { + self.client.execute(query.to_string(), None).await? + }; + let ticket_recv_duration = start.elapsed(); + let mut batches: Vec = Vec::new(); + + let mut handles = Vec::with_capacity(flight_info.endpoint.len()); + for endpoint in flight_info.endpoint.iter() { + let ticket = endpoint + .ticket + .as_ref() + .ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))? + .clone(); + let mut client = self.client.clone(); + handles.push(tokio::spawn(async move { + let flight_data = client.do_get(ticket).await?; + let result: Vec = flight_data.try_collect().await.map_err(|e| { + ArrowError::IpcError(format!("Failed to collect record batches: {e}")) + })?; + Ok::, ArrowError>(result) + })); + } + + for handle in handles { + batches.extend(handle.await.unwrap()?); + } + let rows_recv_duration = start.elapsed(); + + Ok(( + batches, + ticket_recv_duration, + rows_recv_duration, + flight_info, + )) + } + pub async fn handle_repl(&mut self) { let mut query = "".to_owned(); let mut rl = Editor::::new().unwrap(); @@ -127,39 +214,8 @@ impl Session { println!("\n{}\n", query); } - let start = Instant::now(); - let flight_info = if self.args.prepared { - let mut stmt = self.client.prepare(query.to_string(), None).await?; - let info = stmt.execute().await?; - stmt.close().await?; - info - } else { - self.client.execute(query.to_string(), None).await? - }; - let ticket_recv_duration = start.elapsed(); - let mut batches: Vec = Vec::new(); - - let mut handles = Vec::with_capacity(flight_info.endpoint.len()); - for endpoint in flight_info.endpoint.iter() { - let ticket = endpoint - .ticket - .as_ref() - .ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))? - .clone(); - let mut client = self.client.clone(); - handles.push(tokio::spawn(async move { - let flight_data = client.do_get(ticket).await?; - let result: Vec = flight_data.try_collect().await.map_err(|e| { - ArrowError::IpcError(format!("Failed to collect record batches: {e}")) - })?; - Ok::, ArrowError>(result) - })); - } - - for handle in handles { - batches.extend(handle.await.unwrap()?); - } - let rows_recv_duration = start.elapsed(); + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(query).await?; if is_repl { let res = pretty_format_batches(batches.as_slice())?;