From 9e931bde2fd1c11f4f7008b19ca555d4a3330811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Sat, 28 Feb 2026 14:50:44 +0800 Subject: [PATCH] Refactor handling part code --- src/session.rs | 231 +++++++++++++++++++++++++------------------------ 1 file changed, 116 insertions(+), 115 deletions(-) diff --git a/src/session.rs b/src/session.rs index 8cfefac..2c18876 100644 --- a/src/session.rs +++ b/src/session.rs @@ -2,14 +2,14 @@ use arrow_array::RecordBatch; use arrow_cast::pretty::pretty_format_batches; use arrow_csv::WriterBuilder; use arrow_flight::{ - flight_service_client::FlightServiceClient, sql::client::FlightSqlServiceClient, + FlightInfo, flight_service_client::FlightServiceClient, sql::client::FlightSqlServiceClient, }; use arrow_schema::ArrowError; use futures::TryStreamExt; use rustyline::Editor; use rustyline::error::ReadlineError; use rustyline::history::DefaultHistory; -use std::io::BufRead; +use std::{io::BufRead, time::Duration}; use tokio::time::Instant; use tonic::transport::{Channel, Endpoint}; @@ -64,34 +64,102 @@ impl Session { } } - pub async fn handle_command(&mut self, command: &str) { - if let Err(e) = self.handle_query_command(command).await { - eprintln!("handle_query err: {e}"); - } - } + pub async fn handle_repl(&mut self) { + let mut query = "".to_owned(); + let mut rl = Editor::::new().unwrap(); + rl.set_helper(Some(CliHelper::new())); + rl.load_history(&get_history_path()).ok(); + + loop { + match rl.readline(&self.prompt) { + Ok(line) if line.starts_with("--") => { + continue; + } + Ok(line) => { + let line = line.trim_end(); + query.push_str(&line.replace("\\\n", "")); + } + Err(e) => match e { + ReadlineError::Io(err) => { + eprintln!("io err: {err}"); + } + ReadlineError::Interrupted => { + println!("^C"); + } + ReadlineError::Eof => { + break; + } + _ => {} + }, + } + if !query.is_empty() { + let _ = rl.add_history_entry(query.trim_end()); - pub async fn handle_query_command(&mut self, query: &str) -> Result { - let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = - self.execute_query(query).await?; + if query == "exit" || query == "quit" || query == r#"\q"# { + break; + } - // Use pretty format for command mode (like psql -c) - let res = pretty_format_batches(batches.as_slice())?; - println!("{res}"); + println!("\n{}\n", query); - if self.args.print_schema { - let schema = flight_info.try_decode_schema()?; - println!("{schema:#?}"); + if let Err(e) = async { + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(&query).await?; + print_batches( + &batches, + ticket_recv_duration, + rows_recv_duration, + flight_info, + &self.args, + )?; + Ok::<_, ArrowError>(()) + } + .await + { + eprintln!("handle query err: {e}"); + } + } + query.clear(); } - let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - println!( - "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)", - rows, - ticket_recv_duration.as_secs_f64(), - rows_recv_duration.as_secs_f64(), - ); + println!("Bye"); + let _ = rl.save_history(&get_history_path()); + } + + pub async fn handle_command(&mut self, command: &str) { + if let Err(e) = async { + let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = + self.execute_query(command).await?; + + print_batches( + &batches, + ticket_recv_duration, + rows_recv_duration, + flight_info, + &self.args, + )?; + Ok::<_, ArrowError>(()) + } + .await + { + eprintln!("handle command {command} err: {e}"); + } + } - Ok(false) + pub async fn handle_stdin(&mut self) { + let mut lines = std::io::stdin().lock().lines(); + // TODO support multi line + while let Some(Ok(line)) = lines.next() { + let line = line.trim_end(); + if let Err(e) = async { + let (batches, _, _, _) = self.execute_query(line).await?; + print_batches_with_sep(batches.as_slice(), b'\t')?; + Ok::<_, ArrowError>(()) + } + .await + { + eprintln!("handle query {line} err: {e}"); + } + } } async fn execute_query( @@ -147,103 +215,35 @@ impl Session { flight_info, )) } +} - pub async fn handle_repl(&mut self) { - let mut query = "".to_owned(); - let mut rl = Editor::::new().unwrap(); - rl.set_helper(Some(CliHelper::new())); - rl.load_history(&get_history_path()).ok(); - - loop { - match rl.readline(&self.prompt) { - Ok(line) if line.starts_with("--") => { - continue; - } - Ok(line) => { - let line = line.trim_end(); - query.push_str(&line.replace("\\\n", "")); - } - Err(e) => match e { - ReadlineError::Io(err) => { - eprintln!("io err: {err}"); - } - ReadlineError::Interrupted => { - println!("^C"); - } - ReadlineError::Eof => { - break; - } - _ => {} - }, - } - if !query.is_empty() { - let _ = rl.add_history_entry(query.trim_end()); - match self.handle_query(true, &query).await { - Ok(true) => { - break; - } - Ok(false) => {} - Err(e) => { - eprintln!("handle_query err: {e}"); - } - } - } - query.clear(); - } +fn print_batches( + batches: &[RecordBatch], + ticket_recv_duration: Duration, + rows_recv_duration: Duration, + flight_info: FlightInfo, + args: &Args, +) -> Result<(), ArrowError> { + let res = pretty_format_batches(batches)?; - println!("Bye"); - let _ = rl.save_history(&get_history_path()); - } + println!("{res}\n"); - pub async fn handle_stdin(&mut self) { - let mut lines = std::io::stdin().lock().lines(); - // TODO support multi line - while let Some(Ok(line)) = lines.next() { - let line = line.trim_end(); - if let Err(e) = self.handle_query(false, line).await { - eprintln!("handle_query err: {e}"); - } - } + if args.print_schema { + let schema = flight_info.try_decode_schema()?; + println!("{schema:#?}\n"); } - pub async fn handle_query(&mut self, is_repl: bool, query: &str) -> Result { - if is_repl { - if query == "exit" || query == "quit" || query == r#"\q"# { - return Ok(true); - } - println!("\n{}\n", query); - } - - let (batches, ticket_recv_duration, rows_recv_duration, flight_info) = - self.execute_query(query).await?; - - if is_repl { - let res = pretty_format_batches(batches.as_slice())?; - - println!("{res}\n"); - - if self.args.print_schema { - let schema = flight_info.try_decode_schema()?; - println!("{schema:#?}\n"); - } - - let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); - println!( - "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)\n", - rows, - ticket_recv_duration.as_secs_f64(), - rows_recv_duration.as_secs_f64(), - ); - } else { - let res = print_batches_with_sep(batches.as_slice(), b'\t')?; - print!("{res}"); - } - - Ok(false) - } + let rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + println!( + "{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)\n", + rows, + ticket_recv_duration.as_secs_f64(), + rows_recv_duration.as_secs_f64(), + ); + Ok(()) } -fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result { +fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result<(), ArrowError> { let mut bytes = vec![]; { let builder = WriterBuilder::new() @@ -255,7 +255,8 @@ fn print_batches_with_sep(batches: &[RecordBatch], delimiter: u8) -> Result String {