diff --git a/src/webserver/database/error_highlighting.rs b/src/webserver/database/error_highlighting.rs index 7c9a01eb..9c38d6f3 100644 --- a/src/webserver/database/error_highlighting.rs +++ b/src/webserver/database/error_highlighting.rs @@ -17,6 +17,24 @@ struct NiceDatabaseError { query_position: Option, } +fn write_source_position_info( + f: &mut std::fmt::Formatter<'_>, + source_file: &Path, + query_position: Option, +) -> Result<(), std::fmt::Error> { + write!(f, "\n{}", source_file.display())?; + if let Some(query_position) = query_position { + let start_line = query_position.start.line; + let end_line = query_position.end.line; + if start_line == end_line { + write!(f, ": line {start_line}")?; + } else { + write!(f, ": lines {start_line} to {end_line}")?; + } + } + Ok(()) +} + impl std::fmt::Display for NiceDatabaseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -51,22 +69,32 @@ impl std::fmt::Display for NiceDatabaseError { impl NiceDatabaseError { fn show_position_info(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(f, "\n{}", self.source_file.display())?; - let _: () = if let Some(query_position) = self.query_position { - let start_line = query_position.start.line; - let end_line = query_position.end.line; - if start_line == end_line { - write!(f, ": line {start_line}")?; - } else { - write!(f, ": lines {start_line} to {end_line}")?; - } - }; - Ok(()) + write_source_position_info(f, &self.source_file, self.query_position) } } impl std::error::Error for NiceDatabaseError {} +#[derive(Debug)] +struct NicePositionedError { + source_file: PathBuf, + query_position: SourceSpan, + error: anyhow::Error, +} + +impl std::fmt::Display for NicePositionedError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "In \"{}\": {}", self.source_file.display(), self.error)?; + write_source_position_info(f, &self.source_file, Some(self.query_position)) + } +} + +impl std::error::Error for NicePositionedError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.error.as_ref()) + } +} + /// Display a database error without any position information #[must_use] pub fn display_db_error( @@ -97,6 +125,19 @@ pub fn display_stmt_db_error( }) } +#[must_use] +pub fn display_stmt_error( + source_file: &Path, + query_position: SourceSpan, + error: anyhow::Error, +) -> anyhow::Error { + anyhow::Error::new(NicePositionedError { + source_file: source_file.to_path_buf(), + query_position, + error, + }) +} + /// Highlight a line with a character offset. pub fn highlight_line_offset(msg: &mut W, line: &str, offset: usize) { writeln!(msg, "{line}").unwrap(); @@ -124,6 +165,27 @@ pub fn quote_source_with_highlight(source: &str, line_num: u64, col_num: u64) -> msg } +#[test] +fn test_display_stmt_error_includes_file_and_line() { + let err = display_stmt_error( + Path::new("example.sql"), + SourceSpan { + start: super::sql::SourceLocation { + line: 12, + column: 3, + }, + end: super::sql::SourceLocation { + line: 12, + column: 17, + }, + }, + anyhow::anyhow!("boom"), + ); + let message = err.to_string(); + assert!(message.contains("In \"example.sql\": boom")); + assert!(message.contains("example.sql: line 12")); +} + #[test] fn test_quote_source_with_highlight() { let source = "SELECT *\nFROM table\nWHERE "; diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 253bd17e..3edeecfc 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -7,7 +7,7 @@ use std::path::Path; use std::pin::Pin; use super::csv_import::run_csv_import; -use super::error_highlighting::display_stmt_db_error; +use super::error_highlighting::{display_stmt_db_error, display_stmt_error}; use super::sql::{ DelayedFunctionCall, ParsedSqlFile, ParsedStatement, SimpleSelectValue, StmtWithParams, }; @@ -17,6 +17,7 @@ use crate::webserver::database::sql_to_json::row_to_string; use crate::webserver::http_request_info::ExecutionContext; use crate::webserver::request_variables::SetVariablesMap; use crate::webserver::single_or_vec::SingleOrVec; +use crate::webserver::ErrorWithStatus; use super::syntax_tree::{extract_req_param, StmtParam}; use super::{error_highlighting::display_db_error, Database, DbItem}; @@ -57,7 +58,9 @@ pub fn stream_query_results_with_conn<'a>( run_csv_import(connection, csv_import, request).await.with_context(|| format!("Failed to import the CSV file {:?} into the table {:?}", csv_import.uploaded_file, csv_import.table_name))?; }, ParsedStatement::StmtWithParams(stmt) => { - let query = bind_parameters(stmt, request, db_connection).await?; + let query = bind_parameters(stmt, request, db_connection) + .await + .map_err(|e| with_stmt_position(source_file, stmt.query_position, e))?; request.server_timing.record("bind_params"); let connection = take_connection(&request.app_state.db, db_connection, request).await?; log::trace!("Executing query {:?}", query.sql); @@ -93,8 +96,11 @@ pub fn stream_query_results_with_conn<'a>( format!("Failed to set the {variable} variable to {value:?}") )?; }, - ParsedStatement::StaticSimpleSelect(value) => { - for i in parse_dynamic_rows(DbItem::Row(exec_static_simple_select(value, request, db_connection).await?)) { + ParsedStatement::StaticSimpleSelect { values, query_position } => { + let row = exec_static_simple_select(values, request, db_connection) + .await + .map_err(|e| with_stmt_position(source_file, *query_position, e))?; + for i in parse_dynamic_rows(DbItem::Row(row)) { yield i; } } @@ -105,6 +111,18 @@ pub fn stream_query_results_with_conn<'a>( .map(|res| res.unwrap_or_else(DbItem::Error)) } +fn with_stmt_position( + source_file: &Path, + query_position: super::sql::SourceSpan, + error: anyhow::Error, +) -> anyhow::Error { + if error.downcast_ref::().is_some() { + error + } else { + display_stmt_error(source_file, query_position, error) + } +} + /// Transforms a stream of database items to stop processing after encountering the first error. /// The error item itself is still emitted before stopping. pub fn stop_at_first_error( diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 7acac1a0..dad66f54 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -118,7 +118,10 @@ pub(super) struct SourceLocation { #[derive(Debug)] pub(super) enum ParsedStatement { StmtWithParams(StmtWithParams), - StaticSimpleSelect(Vec<(String, SimpleSelectValue)>), + StaticSimpleSelect { + values: Vec<(String, SimpleSelectValue)>, + query_position: SourceSpan, + }, SetVariable { variable: StmtParam, value: StmtWithParams, @@ -217,7 +220,10 @@ fn parse_single_statement( } if let Some(static_statement) = extract_static_simple_select(&stmt, ¶ms) { log::debug!("Optimised a static simple select to avoid a trivial database query: {stmt} optimized to {static_statement:?}"); - return Some(ParsedStatement::StaticSimpleSelect(static_statement)); + return Some(ParsedStatement::StaticSimpleSelect { + values: static_statement, + query_position: extract_query_start(&stmt), + }); } let delayed_functions = extract_toplevel_functions(&mut stmt); @@ -1042,7 +1048,7 @@ mod test { }; let parsed: Vec = parse_sql(&db_info, dialect, sql).unwrap().collect(); match &parsed[..] { - [ParsedStatement::StaticSimpleSelect(q)] => assert_eq!( + [ParsedStatement::StaticSimpleSelect { values: q, .. }] => assert_eq!( q, &[ ("component".into(), Static("text".into())),