Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ struct Args {

#[clap(long, default_value = "false", help = "Print resultset schema")]
print_schema: bool,

#[clap(short = 'c', long, help = "Execute SQL command and exit")]
command: Option<String>,
}

#[tokio::main]
Expand All @@ -52,7 +55,7 @@ pub async fn main() -> Result<(), ArrowError> {
// Authenticate
let url = format!("{protocol}://{}:{}", args.host, args.port);
let endpoint = endpoint(&args, url)?;
let is_repl = atty::is(Stream::Stdin);
let is_repl = atty::is(Stream::Stdin) && args.command.is_none();
let mut session = session::Session::try_new(endpoint, is_repl, args).await?;

session.handle().await;
Expand Down
122 changes: 89 additions & 33 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,98 @@ impl Session {
pub async fn handle(&mut self) {
if self.is_repl {
self.handle_repl().await;
} else if self.args.command.is_some() {
let command = self.args.command.clone().unwrap();
self.handle_command(&command).await;
} else {
self.handle_stdin().await;
}
}

pub async fn handle_command(&mut self, command: &str) {
if let Err(e) = self.handle_query_command(command).await {
eprintln!("handle_query err: {e}");
}
}

pub async fn handle_query_command(&mut self, query: &str) -> Result<bool, ArrowError> {
let (batches, ticket_recv_duration, rows_recv_duration, flight_info) =
self.execute_query(query).await?;

// Use pretty format for command mode (like psql -c)
let res = pretty_format_batches(batches.as_slice())?;
println!("{res}");

if self.args.print_schema {
let schema = flight_info.try_decode_schema()?;
println!("{schema:#?}");
}

let rows: usize = batches.iter().map(|b| b.num_rows()).sum();
println!(
"{} rows in set (tickets received in {:.3} sec, rows received in {:.3} sec)",
rows,
ticket_recv_duration.as_secs_f64(),
rows_recv_duration.as_secs_f64(),
);

Ok(false)
}

async fn execute_query(
&mut self,
query: &str,
) -> Result<
(
Vec<RecordBatch>,
std::time::Duration,
std::time::Duration,
arrow_flight::FlightInfo,
),
ArrowError,
> {
let start = Instant::now();
let flight_info = if self.args.prepared {
let mut stmt = self.client.prepare(query.to_string(), None).await?;
let info = stmt.execute().await?;
stmt.close().await?;
info
} else {
self.client.execute(query.to_string(), None).await?
};
let ticket_recv_duration = start.elapsed();
let mut batches: Vec<RecordBatch> = Vec::new();

let mut handles = Vec::with_capacity(flight_info.endpoint.len());
for endpoint in flight_info.endpoint.iter() {
let ticket = endpoint
.ticket
.as_ref()
.ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))?
.clone();
let mut client = self.client.clone();
handles.push(tokio::spawn(async move {
let flight_data = client.do_get(ticket).await?;
let result: Vec<RecordBatch> = flight_data.try_collect().await.map_err(|e| {
ArrowError::IpcError(format!("Failed to collect record batches: {e}"))
})?;
Ok::<Vec<RecordBatch>, ArrowError>(result)
}));
}

for handle in handles {
batches.extend(handle.await.unwrap()?);
}
let rows_recv_duration = start.elapsed();

Ok((
batches,
ticket_recv_duration,
rows_recv_duration,
flight_info,
))
}

pub async fn handle_repl(&mut self) {
let mut query = "".to_owned();
let mut rl = Editor::<CliHelper, DefaultHistory>::new().unwrap();
Expand Down Expand Up @@ -127,39 +214,8 @@ impl Session {
println!("\n{}\n", query);
}

let start = Instant::now();
let flight_info = if self.args.prepared {
let mut stmt = self.client.prepare(query.to_string(), None).await?;
let info = stmt.execute().await?;
stmt.close().await?;
info
} else {
self.client.execute(query.to_string(), None).await?
};
let ticket_recv_duration = start.elapsed();
let mut batches: Vec<RecordBatch> = Vec::new();

let mut handles = Vec::with_capacity(flight_info.endpoint.len());
for endpoint in flight_info.endpoint.iter() {
let ticket = endpoint
.ticket
.as_ref()
.ok_or_else(|| ArrowError::IpcError("Ticket is emtpy".to_string()))?
.clone();
let mut client = self.client.clone();
handles.push(tokio::spawn(async move {
let flight_data = client.do_get(ticket).await?;
let result: Vec<RecordBatch> = flight_data.try_collect().await.map_err(|e| {
ArrowError::IpcError(format!("Failed to collect record batches: {e}"))
})?;
Ok::<Vec<RecordBatch>, ArrowError>(result)
}));
}

for handle in handles {
batches.extend(handle.await.unwrap()?);
}
let rows_recv_duration = start.elapsed();
let (batches, ticket_recv_duration, rows_recv_duration, flight_info) =
self.execute_query(query).await?;

if is_repl {
let res = pretty_format_batches(batches.as_slice())?;
Expand Down