From dcd018a32b2ec0ea22a5dbfbb1b08517e70601a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Thu, 29 Jan 2026 09:03:36 +0100 Subject: [PATCH] feat(orm)!: transactions support --- cot-macros/src/model.rs | 6 +- cot/Cargo.toml | 2 +- cot/src/auth/db.rs | 30 +- cot/src/db.rs | 723 ++++++++++++++++++++++++++++-------- cot/src/db/impl_mysql.rs | 3 +- cot/src/db/impl_postgres.rs | 3 +- cot/src/db/impl_sqlite.rs | 3 +- cot/src/db/query.rs | 21 +- cot/src/db/relations.rs | 4 +- cot/src/db/sea_query_db.rs | 106 +++++- cot/src/session/store/db.rs | 37 +- cot/tests/db.rs | 42 +++ examples/admin/src/main.rs | 8 +- 13 files changed, 790 insertions(+), 198 deletions(-) diff --git a/cot-macros/src/model.rs b/cot-macros/src/model.rs index 91006c40..e20ab604 100644 --- a/cot-macros/src/model.rs +++ b/cot-macros/src/model.rs @@ -176,8 +176,8 @@ impl ModelBuilder { let fields_as_get_values = &self.fields_as_get_values; quote! { - #[#crate_ident::__private::async_trait] #[automatically_derived] + #[#orm_ident::async_trait] impl #orm_ident::Model for #name { type Fields = #fields_struct_name; type PrimaryKey = #pk_type; @@ -225,11 +225,11 @@ impl ModelBuilder { } async fn get_by_primary_key( - db: &DB, + mut db: DB, pk: Self::PrimaryKey, ) -> #orm_ident::Result> { #orm_ident::query!(Self, $#pk_field_name == pk) - .get(db) + .get(&mut db) .await } } diff --git a/cot/Cargo.toml b/cot/Cargo.toml index f12f86d3..cdfc8c51 100644 --- a/cot/Cargo.toml +++ b/cot/Cargo.toml @@ -61,7 +61,7 @@ subtle = { workspace = true, features = ["std"] } swagger-ui-redist = { workspace = true, optional = true } thiserror.workspace = true time.workspace = true -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util", "sync"] } toml = { workspace = true, features = ["parse", "serde"] } tower = { workspace = true, features = ["util"] } tower-livereload = { workspace = true, optional = true } diff --git a/cot/src/auth/db.rs b/cot/src/auth/db.rs index b7cf90b2..331d5fd6 100644 --- a/cot/src/auth/db.rs +++ b/cot/src/auth/db.rs @@ -99,7 +99,7 @@ impl DatabaseUser { /// # } /// ``` pub async fn create_user, U: Into>( - db: &DB, + mut db: DB, username: T, password: U, ) -> Result { @@ -110,7 +110,9 @@ impl DatabaseUser { })?; let mut user = Self::new(Auto::auto(), username, &password.into()); - user.insert(db).await.map_err(AuthError::backend_error)?; + user.insert(&mut db) + .await + .map_err(AuthError::backend_error)?; Ok(user) } @@ -155,9 +157,9 @@ impl DatabaseUser { /// # Ok(()) /// # } /// ``` - pub async fn get_by_id(db: &DB, id: i64) -> Result> { + pub async fn get_by_id(mut db: DB, id: i64) -> Result> { let db_user = query!(DatabaseUser, $id == id) - .get(db) + .get(&mut db) .await .map_err(AuthError::backend_error)?; @@ -201,14 +203,14 @@ impl DatabaseUser { /// # } /// ``` pub async fn get_by_username( - db: &DB, + mut db: DB, username: &str, ) -> Result> { let username = LimitedString::::new(username).map_err(|_| { AuthError::backend_error(CreateUserError::UsernameTooLong(username.len())) })?; let db_user = query!(DatabaseUser, $username == username) - .get(db) + .get(&mut db) .await .map_err(AuthError::backend_error)?; @@ -221,7 +223,7 @@ impl DatabaseUser { /// /// Returns an error if there was an error querying the database. pub async fn authenticate( - db: &DB, + mut db: DB, credentials: &DatabaseUserCredentials, ) -> Result> { let username = credentials.username(); @@ -230,7 +232,7 @@ impl DatabaseUser { AuthError::backend_error(CreateUserError::UsernameTooLong(username.len())) })?; let user = query!(DatabaseUser, $username == username_limited) - .get(db) + .get(&mut db) .await .map_err(AuthError::backend_error)?; @@ -240,7 +242,7 @@ impl DatabaseUser { PasswordVerificationResult::Ok => Ok(Some(user)), PasswordVerificationResult::OkObsolete(new_hash) => { user.password = new_hash; - user.save(db).await.map_err(AuthError::backend_error)?; + user.save(&mut db).await.map_err(AuthError::backend_error)?; Ok(Some(user)) } PasswordVerificationResult::Invalid => Ok(None), @@ -624,7 +626,7 @@ mod tests { let username = "testuser".to_string(); let password = Password::new("password123"); - let user = DatabaseUser::create_user(&mock_db, username.clone(), &password) + let user = DatabaseUser::create_user(&mut mock_db, username.clone(), &password) .await .unwrap(); assert_eq!(user.username(), username); @@ -644,7 +646,7 @@ mod tests { .expect_get::() .returning(move |_| Ok(Some(user.clone()))); - let result = DatabaseUser::get_by_id(&mock_db, 1).await.unwrap(); + let result = DatabaseUser::get_by_id(&mut mock_db, 1).await.unwrap(); assert!(result.is_some()); assert_eq!(result.unwrap().username(), "testuser"); } @@ -665,7 +667,7 @@ mod tests { let credentials = DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123")); - let result = DatabaseUser::authenticate(&mock_db, &credentials) + let result = DatabaseUser::authenticate(&mut mock_db, &credentials) .await .unwrap(); assert!(result.is_some()); @@ -683,7 +685,7 @@ mod tests { let credentials = DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123")); - let result = DatabaseUser::authenticate(&mock_db, &credentials) + let result = DatabaseUser::authenticate(&mut mock_db, &credentials) .await .unwrap(); assert!(result.is_none()); @@ -705,7 +707,7 @@ mod tests { let credentials = DatabaseUserCredentials::new("testuser".to_string(), Password::new("invalid")); - let result = DatabaseUser::authenticate(&mock_db, &credentials) + let result = DatabaseUser::authenticate(&mut mock_db, &credentials) .await .unwrap(); assert!(result.is_none()); diff --git a/cot/src/db.rs b/cot/src/db.rs index 445dd972..b035dade 100644 --- a/cot/src/db.rs +++ b/cot/src/db.rs @@ -20,7 +20,7 @@ use std::hash::Hash; use std::str::FromStr; use std::sync::Arc; -use async_trait::async_trait; +pub use async_trait::async_trait; use cot_core::error::impl_into_cot_error; pub use cot_macros::{model, query}; use derive_more::{Debug, Deref, Display}; @@ -32,16 +32,18 @@ use sea_query::{ ColumnRef, Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr, }; use sea_query_binder::{SqlxBinder, SqlxValues}; -use sqlx::{Type, TypeInfo}; +use sqlx::{Acquire, Type, TypeInfo}; use thiserror::Error; use tracing::{Instrument, Level, span, trace}; #[cfg(feature = "mysql")] -use crate::db::impl_mysql::{DatabaseMySql, MySqlRow, MySqlValueRef}; +use crate::db::impl_mysql::{DatabaseMySql, MySqlRow, MySqlValueRef, TransactionMySql}; #[cfg(feature = "postgres")] -use crate::db::impl_postgres::{DatabasePostgres, PostgresRow, PostgresValueRef}; +use crate::db::impl_postgres::{ + DatabasePostgres, PostgresRow, PostgresValueRef, TransactionPostgres, +}; #[cfg(feature = "sqlite")] -use crate::db::impl_sqlite::{DatabaseSqlite, SqliteRow, SqliteValueRef}; +use crate::db::impl_sqlite::{DatabaseSqlite, SqliteRow, SqliteValueRef, TransactionSqlite}; use crate::db::migrations::ColumnTypeMapper; const ERROR_PREFIX: &str = "database error:"; @@ -115,6 +117,9 @@ pub enum DatabaseError { /// The actual number of rows returned. actual: usize, }, + /// Nested transactions are not supported yet. + #[error("{ERROR_PREFIX} nested transactions are not supported yet")] + NestedTransactionsNotSupported, } impl_into_cot_error!(DatabaseError, INTERNAL_SERVER_ERROR); @@ -151,12 +156,12 @@ pub type Result = std::result::Result; /// name: String, /// } /// ``` -#[async_trait] #[diagnostic::on_unimplemented( message = "`{Self}` is not marked as a database model", label = "`{Self}` is not annotated with `#[cot::db::model]`", note = "annotate `{Self}` with the `#[cot::db::model]` attribute" )] +#[async_trait] pub trait Model: Sized + Send + 'static { #[allow( clippy::allow_attributes, @@ -229,7 +234,7 @@ pub trait Model: Sized + Send + 'static { /// found in the database, or there was a problem with the database /// connection. async fn get_by_primary_key( - db: &DB, + db: DB, pk: Self::PrimaryKey, ) -> Result>; @@ -245,7 +250,7 @@ pub trait Model: Sized + Send + 'static { /// inserted into the database, for instance because the migrations /// haven't been applied, or there was a problem with the database /// connection. - async fn save(&mut self, db: &DB) -> Result<()> { + async fn save(&mut self, mut db: DB) -> Result<()> { db.insert_or_update(self).await?; Ok(()) } @@ -258,7 +263,7 @@ pub trait Model: Sized + Send + 'static { /// inserted into the database, for instance because the migrations /// haven't been applied, or there was a problem with the database /// connection. - async fn insert(&mut self, db: &DB) -> Result<()> { + async fn insert(&mut self, mut db: DB) -> Result<()> { db.insert(self).await?; Ok(()) } @@ -274,7 +279,7 @@ pub trait Model: Sized + Send + 'static { /// /// This method can return an error if the model with the given primary key /// could not be found in the database. - async fn update(&mut self, db: &DB) -> Result<()> { + async fn update(&mut self, mut db: DB) -> Result<()> { db.update(self).await?; Ok(()) } @@ -318,7 +323,7 @@ pub trait Model: Sized + Send + 'static { /// // After insertion, all todos have populated IDs /// assert!(todos[0].id.is_fixed()); /// ``` - async fn bulk_insert(db: &DB, instances: &mut [Self]) -> Result<()> { + async fn bulk_insert(mut db: DB, instances: &mut [Self]) -> Result<()> { db.bulk_insert(instances).await?; Ok(()) } @@ -358,7 +363,7 @@ pub trait Model: Sized + Send + 'static { /// assert!(todos[0].id.is_fixed()); /// ``` async fn bulk_insert_or_update( - db: &DB, + mut db: DB, instances: &mut [Self], ) -> Result<()> { db.bulk_insert_or_update(instances).await?; @@ -785,6 +790,386 @@ pub trait SqlxValueRef<'r>: Sized { } } +/// A database transaction structure that holds a transaction to the database. +#[derive(Debug)] +pub struct Transaction<'a> { + inner: TransactionImpl<'a>, +} + +#[derive(Debug)] +enum TransactionImpl<'a> { + #[cfg(feature = "sqlite")] + Sqlite(TransactionSqlite<'a>), + #[cfg(feature = "postgres")] + Postgres(TransactionPostgres<'a>), + #[cfg(feature = "mysql")] + MySql(TransactionMySql<'a>), +} + +impl Transaction<'_> { + /// Commits the transaction. + /// + /// # Errors + /// + /// Returns an error if the transaction could not be committed. + pub async fn commit(self) -> Result<()> { + match self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => inner.commit().await, + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => inner.commit().await, + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => inner.commit().await, + } + } + + /// Rolls back the transaction. + /// + /// # Errors + /// + /// Returns an error if the transaction could not be rolled back. + pub async fn rollback(self) -> Result<()> { + match self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => inner.rollback().await, + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => inner.rollback().await, + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => inner.rollback().await, + } + } + + /// Starts a new database transaction (savepoint). + /// + /// # Errors + /// + /// Returns an error if the transaction could not be started. + pub async fn begin(&mut self) -> Result> { + let inner = match &mut self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => { + TransactionImpl::Sqlite(TransactionSqlite::new(inner.inner.begin().await?)) + } + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => { + TransactionImpl::Postgres(TransactionPostgres::new(inner.inner.begin().await?)) + } + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => { + TransactionImpl::MySql(TransactionMySql::new(inner.inner.begin().await?)) + } + }; + + Ok(Transaction { inner }) + } + + async fn fetch_option(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + let result = match &mut self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => { + inner.fetch_option::(statement).await?.map(Row::Sqlite) + } + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => { + inner.fetch_option::(statement).await?.map(Row::Postgres) + } + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => { + inner.fetch_option::(statement).await?.map(Row::MySql) + } + }; + + Ok(result) + } + + async fn fetch_all(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + let result = match &mut self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => inner + .fetch_all::(statement) + .await? + .into_iter() + .map(Row::Sqlite) + .collect(), + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => inner + .fetch_all::(statement) + .await? + .into_iter() + .map(Row::Postgres) + .collect(), + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => inner + .fetch_all::(statement) + .await? + .into_iter() + .map(Row::MySql) + .collect(), + }; + + Ok(result) + } + + async fn execute_statement(&mut self, statement: &T) -> Result + where + T: SqlxBinder + Send + Sync, + { + let result = match &mut self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(inner) => inner.execute_statement::(statement).await?, + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(inner) => inner.execute_statement::(statement).await?, + #[cfg(feature = "mysql")] + TransactionImpl::MySql(inner) => inner.execute_statement::(statement).await?, + }; + + Ok(result) + } +} + +#[async_trait] +impl DatabaseBackend for Transaction<'_> { + async fn insert_or_update(&mut self, data: &mut T) -> Result<()> { + Database::insert_or_update_generic(self, data, true).await + } + + async fn insert(&mut self, data: &mut T) -> Result<()> { + Database::insert_or_update_generic(self, data, false).await + } + + async fn update(&mut self, data: &mut T) -> Result<()> { + Database::update_generic(self, data).await + } + + async fn bulk_insert(&mut self, data: &mut [T]) -> Result<()> { + Database::bulk_insert_generic(self, data, false).await + } + + async fn bulk_insert_or_update(&mut self, data: &mut [T]) -> Result<()> { + Database::bulk_insert_generic(self, data, true).await + } + + async fn query(&mut self, query: &Query) -> Result> { + Database::query_generic(self, query).await + } + + async fn get(&mut self, query: &Query) -> Result> { + Database::get_generic(self, query).await + } + + async fn exists(&mut self, query: &Query) -> Result { + Database::exists_generic(self, query).await + } + + async fn delete(&mut self, query: &Query) -> Result { + let mut delete = sea_query::Query::delete(); + delete.from_table(T::TABLE_NAME); + query.add_filter_to_statement(&mut delete); + + self.execute_statement(&delete).await + } +} + +#[async_trait] +trait RawExecutor { + async fn fetch_option(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync; + + async fn fetch_all(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync; + + async fn execute_statement(&mut self, statement: &T) -> Result + where + T: SqlxBinder + Send + Sync; + + fn supports_returning(&self) -> bool; + + fn max_params(&self) -> usize; + + async fn begin_transaction<'a>(&'a mut self) -> Result>; +} + +#[async_trait] +impl RawExecutor for &Database { + async fn fetch_option(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + let result = match &*self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(inner) => inner.fetch_option(statement).await?.map(Row::Sqlite), + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(inner) => { + inner.fetch_option(statement).await?.map(Row::Postgres) + } + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.fetch_option(statement).await?.map(Row::MySql), + }; + + Ok(result) + } + + async fn fetch_all(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + let result = match &*self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(inner) => inner + .fetch_all(statement) + .await? + .into_iter() + .map(Row::Sqlite) + .collect(), + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(inner) => inner + .fetch_all(statement) + .await? + .into_iter() + .map(Row::Postgres) + .collect(), + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner + .fetch_all(statement) + .await? + .into_iter() + .map(Row::MySql) + .collect(), + }; + + Ok(result) + } + + async fn execute_statement(&mut self, statement: &T) -> Result + where + T: SqlxBinder + Send + Sync, + { + let result = match &*self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(inner) => inner.execute_statement(statement).await?, + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(inner) => inner.execute_statement(statement).await?, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.execute_statement(statement).await?, + }; + + Ok(result) + } + + fn supports_returning(&self) -> bool { + (*self).supports_returning() + } + + fn max_params(&self) -> usize { + match &*self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(_) => 32766, + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(_) => 65535, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(_) => 65535, + } + } + + async fn begin_transaction<'a>(&'a mut self) -> Result> { + self.begin().await + } +} + +#[async_trait] +impl RawExecutor for Transaction<'_> { + async fn fetch_option(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + self.fetch_option::(statement).await + } + + async fn fetch_all(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + self.fetch_all::(statement).await + } + + async fn execute_statement(&mut self, statement: &T) -> Result + where + T: SqlxBinder + Send + Sync, + { + self.execute_statement::(statement).await + } + + fn supports_returning(&self) -> bool { + match &self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(_) => true, + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(_) => true, + #[cfg(feature = "mysql")] + TransactionImpl::MySql(_) => false, + } + } + + fn max_params(&self) -> usize { + match &self.inner { + #[cfg(feature = "sqlite")] + TransactionImpl::Sqlite(_) => 32766, + #[cfg(feature = "postgres")] + TransactionImpl::Postgres(_) => 65535, + #[cfg(feature = "mysql")] + TransactionImpl::MySql(_) => 65535, + } + } + + async fn begin_transaction<'b>(&'b mut self) -> Result> { + self.begin().await + } +} + +#[async_trait] +impl RawExecutor for &mut E { + async fn fetch_option(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + (**self).fetch_option::(statement).await + } + + async fn fetch_all(&mut self, statement: &T) -> Result> + where + T: SqlxBinder + Send + Sync, + { + (**self).fetch_all::(statement).await + } + + async fn execute_statement(&mut self, statement: &T) -> Result + where + T: SqlxBinder + Send + Sync, + { + (**self).execute_statement::(statement).await + } + + fn supports_returning(&self) -> bool { + (**self).supports_returning() + } + + fn max_params(&self) -> usize { + (**self).max_params() + } + + async fn begin_transaction<'a>(&'a mut self) -> Result> { + (**self).begin_transaction().await + } +} + /// A database connection structure that holds the connection to the database. /// /// It is used to execute queries and interact with the database. The connection @@ -893,6 +1278,30 @@ impl Database { } } + /// Starts a new database transaction. + /// + /// # Errors + /// + /// Returns an error if the transaction could not be started. + pub async fn begin(&self) -> Result> { + let inner = match &*self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(inner) => { + TransactionImpl::Sqlite(TransactionSqlite::new(inner.begin().await?)) + } + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(inner) => { + TransactionImpl::Postgres(TransactionPostgres::new(inner.begin().await?)) + } + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => { + TransactionImpl::MySql(TransactionMySql::new(inner.begin().await?)) + } + }; + + Ok(Transaction { inner }) + } + /// Inserts a new row into the database. /// /// # Errors @@ -903,7 +1312,7 @@ impl Database { pub async fn insert(&self, data: &mut T) -> Result<()> { let span = span!(Level::TRACE, "insert", table = %T::TABLE_NAME); - Self::insert_or_update_impl(self, data, false) + Self::insert_or_update_generic(self, data, false) .instrument(span) .await } @@ -923,12 +1332,16 @@ impl Database { table = %T::TABLE_NAME ); - Self::insert_or_update_impl(self, data, true) + Self::insert_or_update_generic(self, data, true) .instrument(span) .await } - async fn insert_or_update_impl(&self, data: &mut T, update: bool) -> Result<()> { + async fn insert_or_update_generic( + mut executor: E, + data: &mut T, + update: bool, + ) -> Result<()> { let column_identifiers = T::COLUMNS .iter() .map(|column| Identifier::from(column.name.as_str())); @@ -979,16 +1392,17 @@ impl Database { } if auto_col_ids.is_empty() { - self.execute_statement(&insert_statement).await?; + executor.execute_statement(&insert_statement).await?; } else { - let row = if self.supports_returning() { + let row = if executor.supports_returning() { insert_statement.returning(ReturningClause::Columns(auto_col_identifiers)); - self.fetch_option(&insert_statement) + executor + .fetch_option(&insert_statement) .await? .expect("query should return the primary key") } else { - let result = self.execute_statement(&insert_statement).await?; + let result = executor.execute_statement(&insert_statement).await?; let row_id = result .last_inserted_row_id .expect("expected last inserted row ID if RETURNING clause is not supported"); @@ -997,7 +1411,7 @@ impl Database { .columns(auto_col_identifiers) .and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).eq(row_id)) .to_owned(); - self.fetch_option(&query).await?.expect( + executor.fetch_option(&query).await?.expect( "expected a row returned from a SELECT if RETURNING clause is not supported", ) }; @@ -1031,10 +1445,10 @@ impl Database { primary_key = ?data.primary_key().to_db_field_value(), ); - Self::update_impl(self, data).instrument(span).await + Self::update_generic(self, data).instrument(span).await } - async fn update_impl(&self, data: &mut T) -> Result<()> { + async fn update_generic(mut executor: E, data: &mut T) -> Result<()> { let column_identifiers = T::COLUMNS .iter() .map(|column| Identifier::from(column.name.as_str())); @@ -1068,7 +1482,7 @@ impl Database { .and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).eq(primary_key.clone())) .to_owned(); - let result = self.execute_statement(&update_statement).await?; + let result = executor.execute_statement(&update_statement).await?; if result.rows_affected == RowsNum(0) { return Err(DatabaseError::RecordNotFound { primary_key }); } @@ -1088,7 +1502,7 @@ impl Database { pub async fn bulk_insert(&self, data: &mut [T]) -> Result<()> { let span = span!(Level::TRACE, "bulk_insert", table = %T::TABLE_NAME, count = data.len()); - Self::bulk_insert_impl(self, data, false) + Self::bulk_insert_generic(self, data, false) .instrument(span) .await } @@ -1104,37 +1518,26 @@ impl Database { pub async fn bulk_insert_or_update(&self, data: &mut [T]) -> Result<()> { let span = span!( Level::TRACE, - "bulk_insert_or_update", + "insert_or_update", table = %T::TABLE_NAME, count = data.len() ); - Self::bulk_insert_impl(self, data, true) + Self::bulk_insert_generic(self, data, true) .instrument(span) .await } - async fn bulk_insert_impl(&self, data: &mut [T], update: bool) -> Result<()> { - // TODO: add transactions when implemented - + async fn bulk_insert_generic( + mut executor: E, + data: &mut [T], + update: bool, + ) -> Result<()> { if data.is_empty() { return Ok(()); } - let max_params = match &*self.inner { - // https://sqlite.org/limits.html#max_variable_number - // Assuming SQLite > 3.32.0 (2020-05-22) - #[cfg(feature = "sqlite")] - DatabaseImpl::Sqlite(_) => 32766, - // https://www.postgresql.org/docs/18/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BIND - // The number of parameter format codes is Int16 - #[cfg(feature = "postgres")] - DatabaseImpl::Postgres(_) => 65535, - // https://dev.mysql.com/doc/dev/mysql-server/9.5.0/page_protocol_com_stmt_prepare.html#sect_protocol_com_stmt_prepare_response - // The number of parameter returned in the COM_STMT_PREPARE_OK packet is int<2> - #[cfg(feature = "mysql")] - DatabaseImpl::MySql(_) => 65535, - }; + let max_params = executor.max_params(); let column_identifiers: Vec<_> = T::COLUMNS .iter() @@ -1186,9 +1589,10 @@ impl Database { return Err(DatabaseError::BulkInsertNoValueColumns); }; - for chunk in data.chunks_mut(batch_size) { - self.bulk_insert_chunk( - chunk, + if data.len() <= batch_size { + Self::bulk_insert_chunk_generic( + executor, + data, update, &value_identifiers, &value_column_indices, @@ -1196,13 +1600,28 @@ impl Database { &auto_col_identifiers, ) .await?; + } else { + let mut transaction = executor.begin_transaction().await?; + for chunk in data.chunks_mut(batch_size) { + Self::bulk_insert_chunk_generic( + &mut transaction, + chunk, + update, + &value_identifiers, + &value_column_indices, + &auto_col_ids, + &auto_col_identifiers, + ) + .await?; + } + transaction.commit().await?; } Ok(()) } - async fn bulk_insert_chunk( - &self, + async fn bulk_insert_chunk_generic( + mut executor: E, chunk: &mut [T], update: bool, value_identifiers: &[Identifier], @@ -1250,12 +1669,12 @@ impl Database { } if auto_col_ids.is_empty() { - self.execute_statement(&insert_statement).await?; - } else if self.supports_returning() { + executor.execute_statement(&insert_statement).await?; + } else if executor.supports_returning() { // PostgreSQL/SQLite: Use RETURNING clause insert_statement.returning(ReturningClause::Columns(auto_col_identifiers.to_vec())); - let rows = self.fetch_all(&insert_statement).await?; + let rows = executor.fetch_all(&insert_statement).await?; if rows.len() != chunk.len() { return Err(DatabaseError::BulkInsertReturnDataInvalid { expected: chunk.len(), @@ -1268,7 +1687,7 @@ impl Database { } } else { // MySQL: Use LAST_INSERT_ID() and fetch rows - let result = self.execute_statement(&insert_statement).await?; + let result = executor.execute_statement(&insert_statement).await?; let first_id = result.last_inserted_row_id.ok_or_else(|| { DatabaseError::BulkInsertReturnDataInvalid { expected: chunk.len(), @@ -1292,7 +1711,7 @@ impl Database { .order_by(T::PRIMARY_KEY_NAME, sea_query::Order::Asc) .to_owned(); - let rows = self.fetch_all(&query).await?; + let rows = executor.fetch_all(&query).await?; if rows.len() != chunk.len() { return Err(DatabaseError::BulkInsertReturnDataInvalid { expected: chunk.len(), @@ -1327,6 +1746,13 @@ impl Database { /// /// Can return an error if the database connection is lost. pub async fn query(&self, query: &Query) -> Result> { + Self::query_generic(self, query).await + } + + async fn query_generic( + mut executor: E, + query: &Query, + ) -> Result> { let columns_to_get: Vec<_> = T::COLUMNS.iter().map(|column| column.name).collect(); let mut select = sea_query::Query::select(); select.columns(columns_to_get).from(T::TABLE_NAME); @@ -1334,7 +1760,7 @@ impl Database { query.add_limit_to_statement(&mut select); query.add_offset_to_statement(&mut select); - let rows = self.fetch_all(&select).await?; + let rows = executor.fetch_all(&select).await?; let result = rows.into_iter().map(T::from_db).collect::>()?; Ok(result) @@ -1353,13 +1779,20 @@ impl Database { /// /// Can return an error if the database connection is lost. pub async fn get(&self, query: &Query) -> Result> { + Self::get_generic(self, query).await + } + + async fn get_generic( + mut executor: E, + query: &Query, + ) -> Result> { let columns_to_get: Vec<_> = T::COLUMNS.iter().map(|column| column.name).collect(); let mut select = sea_query::Query::select(); select.columns(columns_to_get).from(T::TABLE_NAME); query.add_filter_to_statement(&mut select); select.limit(1); - let row = self.fetch_option(&select).await?; + let row = executor.fetch_option(&select).await?; let result = match row { Some(row) => Some(T::from_db(row)?), @@ -1380,12 +1813,19 @@ impl Database { /// /// Can return an error if the database connection is lost. pub async fn exists(&self, query: &Query) -> Result { + Self::exists_generic(self, query).await + } + + async fn exists_generic( + mut executor: E, + query: &Query, + ) -> Result { let mut select = sea_query::Query::select(); select.expr(sea_query::Expr::value(1)).from(T::TABLE_NAME); query.add_filter_to_statement(&mut select); select.limit(1); - let rows = self.fetch_option(&select).await?; + let rows = executor.fetch_option(&select).await?; Ok(rows.is_some()) } @@ -1402,11 +1842,18 @@ impl Database { /// /// Can return an error if the database connection is lost. pub async fn delete(&self, query: &Query) -> Result { + Self::delete_generic(self, query).await + } + + async fn delete_generic( + mut executor: E, + query: &Query, + ) -> Result { let mut delete = sea_query::Query::delete(); delete.from_table(T::TABLE_NAME); query.add_filter_to_statement(&mut delete); - self.execute_statement(&delete).await + executor.execute_statement(&delete).await } /// Executes a raw SQL query. @@ -1480,24 +1927,6 @@ impl Database { Ok(result) } - async fn fetch_option(&self, statement: &T) -> Result> - where - T: SqlxBinder + Send + Sync, - { - let result = match &*self.inner { - #[cfg(feature = "sqlite")] - DatabaseImpl::Sqlite(inner) => inner.fetch_option(statement).await?.map(Row::Sqlite), - #[cfg(feature = "postgres")] - DatabaseImpl::Postgres(inner) => { - inner.fetch_option(statement).await?.map(Row::Postgres) - } - #[cfg(feature = "mysql")] - DatabaseImpl::MySql(inner) => inner.fetch_option(statement).await?.map(Row::MySql), - }; - - Ok(result) - } - fn supports_returning(&self) -> bool { match &*self.inner { #[cfg(feature = "sqlite")] @@ -1509,53 +1938,6 @@ impl Database { } } - async fn fetch_all(&self, statement: &T) -> Result> - where - T: SqlxBinder + Send + Sync, - { - let result = match &*self.inner { - #[cfg(feature = "sqlite")] - DatabaseImpl::Sqlite(inner) => inner - .fetch_all(statement) - .await? - .into_iter() - .map(Row::Sqlite) - .collect(), - #[cfg(feature = "postgres")] - DatabaseImpl::Postgres(inner) => inner - .fetch_all(statement) - .await? - .into_iter() - .map(Row::Postgres) - .collect(), - #[cfg(feature = "mysql")] - DatabaseImpl::MySql(inner) => inner - .fetch_all(statement) - .await? - .into_iter() - .map(Row::MySql) - .collect(), - }; - - Ok(result) - } - - async fn execute_statement(&self, statement: &T) -> Result - where - T: SqlxBinder + Send + Sync, - { - let result = match &*self.inner { - #[cfg(feature = "sqlite")] - DatabaseImpl::Sqlite(inner) => inner.execute_statement(statement).await?, - #[cfg(feature = "postgres")] - DatabaseImpl::Postgres(inner) => inner.execute_statement(statement).await?, - #[cfg(feature = "mysql")] - DatabaseImpl::MySql(inner) => inner.execute_statement(statement).await?, - }; - - Ok(result) - } - async fn execute_schema( &self, statement: T, @@ -1591,7 +1973,7 @@ impl ColumnTypeMapper for Database { /// This trait is used to provide a backend for the database. #[cfg_attr(test, automock)] #[async_trait] -pub trait DatabaseBackend: Send + Sync { +pub trait DatabaseBackend: Send { /// Inserts a new row into the database, or updates an existing row if it /// already exists. /// @@ -1600,7 +1982,7 @@ pub trait DatabaseBackend: Send + Sync { /// This method can return an error if the row could not be inserted into /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. - async fn insert_or_update(&self, data: &mut T) -> Result<()>; + async fn insert_or_update(&mut self, data: &mut T) -> Result<()>; /// Inserts a new row into the database. /// @@ -1609,7 +1991,7 @@ pub trait DatabaseBackend: Send + Sync { /// This method can return an error if the row could not be inserted into /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. - async fn insert(&self, data: &mut T) -> Result<()>; + async fn insert(&mut self, data: &mut T) -> Result<()>; /// Updates an existing row in the database. /// @@ -1618,7 +2000,7 @@ pub trait DatabaseBackend: Send + Sync { /// This method can return an error if the row could not be updated in the /// database, for instance because the migrations haven't been applied, or /// there was a problem with the database connection. - async fn update(&self, data: &mut T) -> Result<()>; + async fn update(&mut self, data: &mut T) -> Result<()>; /// Bulk inserts multiple rows into the database. /// @@ -1627,7 +2009,7 @@ pub trait DatabaseBackend: Send + Sync { /// This method can return an error if the rows could not be inserted into /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. - async fn bulk_insert(&self, data: &mut [T]) -> Result<()>; + async fn bulk_insert(&mut self, data: &mut [T]) -> Result<()>; /// Bulk inserts multiple rows into the database, or updates existing rows /// if they already exist. @@ -1637,7 +2019,7 @@ pub trait DatabaseBackend: Send + Sync { /// This method can return an error if the rows could not be inserted into /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. - async fn bulk_insert_or_update(&self, data: &mut [T]) -> Result<()>; + async fn bulk_insert_or_update(&mut self, data: &mut [T]) -> Result<()>; /// Executes a query and returns the results converted to the model type. /// @@ -1650,7 +2032,7 @@ pub trait DatabaseBackend: Send + Sync { /// generated or applied). /// /// Can return an error if the database connection is lost. - async fn query(&self, query: &Query) -> Result>; + async fn query(&mut self, query: &Query) -> Result>; /// Returns the first row that matches the given query. If no rows match the /// query, returns `None`. @@ -1664,7 +2046,7 @@ pub trait DatabaseBackend: Send + Sync { /// applied). /// /// Can return an error if the database connection is lost. - async fn get(&self, query: &Query) -> Result>; + async fn get(&mut self, query: &Query) -> Result>; /// Returns whether a row exists that matches the given query. /// @@ -1677,7 +2059,7 @@ pub trait DatabaseBackend: Send + Sync { /// applied). /// /// Can return an error if the database connection is lost. - async fn exists(&self, query: &Query) -> Result; + async fn exists(&mut self, query: &Query) -> Result; /// Deletes all rows that match the given query. /// @@ -1690,45 +2072,84 @@ pub trait DatabaseBackend: Send + Sync { /// applied). /// /// Can return an error if the database connection is lost. - async fn delete(&self, query: &Query) -> Result; + async fn delete(&mut self, query: &Query) -> Result; +} + +#[async_trait] +impl DatabaseBackend for &mut DB { + async fn insert_or_update(&mut self, data: &mut T) -> Result<()> { + (**self).insert_or_update(data).await + } + + async fn insert(&mut self, data: &mut T) -> Result<()> { + (**self).insert(data).await + } + + async fn update(&mut self, data: &mut T) -> Result<()> { + (**self).update(data).await + } + + async fn bulk_insert(&mut self, data: &mut [T]) -> Result<()> { + (**self).bulk_insert(data).await + } + + async fn bulk_insert_or_update(&mut self, data: &mut [T]) -> Result<()> { + (**self).bulk_insert_or_update(data).await + } + + async fn query(&mut self, query: &Query) -> Result> { + (**self).query(query).await + } + + async fn get(&mut self, query: &Query) -> Result> { + (**self).get(query).await + } + + async fn exists(&mut self, query: &Query) -> Result { + (**self).exists(query).await + } + + async fn delete(&mut self, query: &Query) -> Result { + (**self).delete(query).await + } } #[async_trait] -impl DatabaseBackend for Database { - async fn insert_or_update(&self, data: &mut T) -> Result<()> { - Database::insert_or_update(self, data).await +impl DatabaseBackend for &Database { + async fn insert_or_update(&mut self, data: &mut T) -> Result<()> { + Database::insert_or_update_generic(*self, data, true).await } - async fn insert(&self, data: &mut T) -> Result<()> { - Database::insert(self, data).await + async fn insert(&mut self, data: &mut T) -> Result<()> { + Database::insert_or_update_generic(*self, data, false).await } - async fn update(&self, data: &mut T) -> Result<()> { - Database::update(self, data).await + async fn update(&mut self, data: &mut T) -> Result<()> { + Database::update_generic(*self, data).await } - async fn bulk_insert(&self, data: &mut [T]) -> Result<()> { - Database::bulk_insert(self, data).await + async fn bulk_insert(&mut self, data: &mut [T]) -> Result<()> { + Database::bulk_insert_generic(*self, data, false).await } - async fn bulk_insert_or_update(&self, data: &mut [T]) -> Result<()> { - Database::bulk_insert_or_update(self, data).await + async fn bulk_insert_or_update(&mut self, data: &mut [T]) -> Result<()> { + Database::bulk_insert_generic(*self, data, true).await } - async fn query(&self, query: &Query) -> Result> { - Database::query(self, query).await + async fn query(&mut self, query: &Query) -> Result> { + Database::query_generic(*self, query).await } - async fn get(&self, query: &Query) -> Result> { - Database::get(self, query).await + async fn get(&mut self, query: &Query) -> Result> { + Database::get_generic(*self, query).await } - async fn exists(&self, query: &Query) -> Result { - Database::exists(self, query).await + async fn exists(&mut self, query: &Query) -> Result { + Database::exists_generic(*self, query).await } - async fn delete(&self, query: &Query) -> Result { - Database::delete(self, query).await + async fn delete(&mut self, query: &Query) -> Result { + Database::delete_generic(*self, query).await } } diff --git a/cot/src/db/impl_mysql.rs b/cot/src/db/impl_mysql.rs index ccd5bfc7..f6324f42 100644 --- a/cot/src/db/impl_mysql.rs +++ b/cot/src/db/impl_mysql.rs @@ -1,9 +1,10 @@ //! Database interface implementation – MySQL backend. use crate::db::ColumnType; -use crate::db::sea_query_db::impl_sea_query_db_backend; +use crate::db::sea_query_db::{impl_sea_query_db_backend, impl_sea_query_transaction_backend}; impl_sea_query_db_backend!(DatabaseMySql: sqlx::mysql::MySql, sqlx::mysql::MySqlPool, MySqlRow, MySqlValueRef, sea_query::MysqlQueryBuilder); +impl_sea_query_transaction_backend!(DatabaseMySql, TransactionMySql: sqlx::mysql::MySql, MySqlRow, sea_query::MysqlQueryBuilder); impl DatabaseMySql { #[expect(clippy::unused_async)] diff --git a/cot/src/db/impl_postgres.rs b/cot/src/db/impl_postgres.rs index f7e813b9..43e73575 100644 --- a/cot/src/db/impl_postgres.rs +++ b/cot/src/db/impl_postgres.rs @@ -1,8 +1,9 @@ //! Database interface implementation – PostgreSQL backend. -use crate::db::sea_query_db::impl_sea_query_db_backend; +use crate::db::sea_query_db::{impl_sea_query_db_backend, impl_sea_query_transaction_backend}; impl_sea_query_db_backend!(DatabasePostgres: sqlx::postgres::Postgres, sqlx::postgres::PgPool, PostgresRow, PostgresValueRef, sea_query::PostgresQueryBuilder); +impl_sea_query_transaction_backend!(DatabasePostgres, TransactionPostgres: sqlx::postgres::Postgres, PostgresRow, sea_query::PostgresQueryBuilder); impl DatabasePostgres { #[expect(clippy::unused_async)] diff --git a/cot/src/db/impl_sqlite.rs b/cot/src/db/impl_sqlite.rs index 974735e3..94d52042 100644 --- a/cot/src/db/impl_sqlite.rs +++ b/cot/src/db/impl_sqlite.rs @@ -2,9 +2,10 @@ use sea_query_binder::SqlxValues; -use crate::db::sea_query_db::impl_sea_query_db_backend; +use crate::db::sea_query_db::{impl_sea_query_db_backend, impl_sea_query_transaction_backend}; impl_sea_query_db_backend!(DatabaseSqlite: sqlx::sqlite::Sqlite, sqlx::sqlite::SqlitePool, SqliteRow, SqliteValueRef, sea_query::SqliteQueryBuilder); +impl_sea_query_transaction_backend!(DatabaseSqlite, TransactionSqlite: sqlx::sqlite::Sqlite, SqliteRow, sea_query::SqliteQueryBuilder); impl DatabaseSqlite { async fn init(&self) -> crate::db::Result<()> { diff --git a/cot/src/db/query.rs b/cot/src/db/query.rs index 139a7ca7..45c87517 100644 --- a/cot/src/db/query.rs +++ b/cot/src/db/query.rs @@ -8,7 +8,7 @@ use sea_query::{ExprTrait, IntoColumnRef}; use crate::db; use crate::db::{ Auto, Database, DatabaseBackend, DbFieldValue, DbValue, ForeignKey, FromDbValue, Identifier, - Model, StatementResult, ToDbFieldValue, + Model, RawExecutor, StatementResult, ToDbFieldValue, }; /// A query that can be executed on a database. Can be used to filter, update, @@ -177,7 +177,7 @@ impl Query { /// # Errors /// /// Returns an error if the query fails. - pub async fn all(&self, db: &DB) -> db::Result> { + pub async fn all(&self, mut db: DB) -> db::Result> { db.query(self).await } @@ -186,7 +186,7 @@ impl Query { /// # Errors /// /// Returns an error if the query fails. - pub async fn get(&self, db: &DB) -> db::Result> { + pub async fn get(&self, mut db: DB) -> db::Result> { // TODO panic/error if more than one result db.get(self).await } @@ -202,7 +202,8 @@ impl Query { .from(T::TABLE_NAME) .expr(sea_query::Expr::col(sea_query::Asterisk).count()); self.add_filter_to_statement(&mut select); - let row = db.fetch_option(&select).await?; + let mut db_ref = db; + let row = db_ref.fetch_option(&select).await?; let count = match row { #[expect(clippy::cast_sign_loss)] Some(row) => row.get::(0)? as u64, @@ -216,7 +217,7 @@ impl Query { /// # Errors /// /// Returns an error if the query fails. - pub async fn exists(&self, db: &DB) -> db::Result { + pub async fn exists(&self, mut db: DB) -> db::Result { db.exists(self).await } @@ -225,7 +226,7 @@ impl Query { /// # Errors /// /// Returns an error if the query fails. - pub async fn delete(&self, db: &DB) -> db::Result { + pub async fn delete(&self, mut db: DB) -> db::Result { db.delete(self).await } @@ -1484,7 +1485,7 @@ mod tests { db.expect_query().returning(|_| Ok(Vec::::new())); let query: Query = Query::new(); - let result = query.all(&db).await; + let result = query.all(&mut db).await; assert_eq!(result.unwrap(), Vec::::new()); } @@ -1495,7 +1496,7 @@ mod tests { db.expect_get().returning(|_| Ok(Option::::None)); let query: Query = Query::new(); - let result = query.get(&db).await; + let result = query.get(&mut db).await; assert_eq!(result.unwrap(), Option::::None); } @@ -1508,7 +1509,7 @@ mod tests { let query: Query = Query::new(); - let result = query.exists(&db).await; + let result = query.exists(&mut db).await; assert!(result.is_ok()); } @@ -1519,7 +1520,7 @@ mod tests { .returning(|_: &Query| Ok(StatementResult::new(RowsNum(0)))); let query: Query = Query::new(); - let result = query.delete(&db).await; + let result = query.delete(&mut db).await; assert!(result.is_ok()); } diff --git a/cot/src/db/relations.rs b/cot/src/db/relations.rs index 5ab6685f..e0fa4860 100644 --- a/cot/src/db/relations.rs +++ b/cot/src/db/relations.rs @@ -92,11 +92,11 @@ impl ForeignKey { /// could not be found in the database. /// /// Returns an error if there was a problem communicating with the database. - pub async fn get(&mut self, db: &DB) -> Result<&T> { + pub async fn get(&mut self, mut db: DB) -> Result<&T> { match self { Self::Model(model) => Ok(model), Self::PrimaryKey(pk) => { - let model = T::get_by_primary_key(db, pk.clone()) + let model = T::get_by_primary_key(&mut db, pk.clone()) .await? .ok_or(DatabaseError::ForeignKeyNotFound)?; *self = Self::Model(Box::new(model)); diff --git a/cot/src/db/sea_query_db.rs b/cot/src/db/sea_query_db.rs index 8b0c4df0..f334df95 100644 --- a/cot/src/db/sea_query_db.rs +++ b/cot/src/db/sea_query_db.rs @@ -27,6 +27,15 @@ macro_rules! impl_sea_query_db_backend { Ok(()) } + pub(super) async fn begin( + &self, + ) -> crate::db::Result> { + self.db_connection + .begin() + .await + .map_err(crate::db::sea_query_db::map_sqlx_error) + } + pub(super) async fn fetch_option( &self, statement: &T, @@ -36,7 +45,7 @@ macro_rules! impl_sea_query_db_backend { let row = Self::sqlx_query_with(&sql, values) .fetch_optional(&self.db_connection) .await - .map_err(|err| crate::db::sea_query_db::map_sqlx_error(err))?; + .map_err(crate::db::sea_query_db::map_sqlx_error)?; Ok(row.map($row_name::new)) } @@ -185,4 +194,97 @@ pub(crate) fn map_sqlx_error(err: sqlx::Error) -> crate::db::DatabaseError { crate::db::DatabaseError::from(err) } -pub(super) use impl_sea_query_db_backend; +pub(crate) use impl_sea_query_db_backend; + +/// Implements the transaction backend for a specific engine using `SeaQuery`. +macro_rules! impl_sea_query_transaction_backend { + ($db_name:ident, $transaction_name:ident : $sqlx_db_ty:ty, $row_name:ident, $query_builder:expr) => { + #[derive(derive_more::Debug)] + pub(super) struct $transaction_name<'a> { + #[debug("...")] + pub(crate) inner: sqlx::Transaction<'a, $sqlx_db_ty>, + } + + impl<'a> $transaction_name<'a> { + pub(super) fn new(transaction: sqlx::Transaction<'a, $sqlx_db_ty>) -> Self { + Self { inner: transaction } + } + + pub(super) async fn commit(self) -> crate::db::Result<()> { + self.inner + .commit() + .await + .map_err(crate::db::sea_query_db::map_sqlx_error) + } + + pub(super) async fn rollback(self) -> crate::db::Result<()> { + self.inner + .rollback() + .await + .map_err(crate::db::sea_query_db::map_sqlx_error) + } + + pub(super) async fn fetch_option( + &mut self, + statement: &T, + ) -> crate::db::Result> { + let (sql, values) = $db_name::build_sql(statement); + + let row = $db_name::sqlx_query_with(&sql, values) + .fetch_optional(&mut *self.inner) + .await + .map_err(crate::db::sea_query_db::map_sqlx_error)?; + Ok(row.map($row_name::new)) + } + + pub(super) async fn fetch_all( + &mut self, + statement: &T, + ) -> crate::db::Result> { + let (sql, values) = $db_name::build_sql(statement); + + let result = $db_name::sqlx_query_with(&sql, values) + .fetch_all(&mut *self.inner) + .await + .map_err(crate::db::sea_query_db::map_sqlx_error)? + .into_iter() + .map($row_name::new) + .collect(); + Ok(result) + } + + pub(super) async fn execute_statement( + &mut self, + statement: &T, + ) -> crate::db::Result { + let (sql, mut values) = $db_name::build_sql(statement); + $db_name::prepare_values(&mut values); + + self.execute_sqlx($db_name::sqlx_query_with(&sql, values)) + .await + } + + async fn execute_sqlx<'b, A>( + &mut self, + sqlx_statement: sqlx::query::Query<'b, $sqlx_db_ty, A>, + ) -> crate::db::Result + where + A: 'b + sqlx::IntoArguments<'b, $sqlx_db_ty>, + { + let result = sqlx_statement + .execute(&mut *self.inner) + .await + .map_err(crate::db::sea_query_db::map_sqlx_error)?; + let result = crate::db::StatementResult { + rows_affected: crate::db::RowsNum(result.rows_affected()), + last_inserted_row_id: $db_name::last_inserted_row_id_for(&result), + }; + + tracing::debug!("Rows affected: {}", result.rows_affected.0); + Ok(result) + } + } + }; +} + +pub(crate) use impl_sea_query_transaction_backend; diff --git a/cot/src/session/store/db.rs b/cot/src/session/store/db.rs index 79b5d4e3..b46e8b93 100644 --- a/cot/src/session/store/db.rs +++ b/cot/src/session/store/db.rs @@ -112,11 +112,12 @@ impl DbStore { pub fn new(connection: Database) -> DbStore { DbStore { connection } } -} -#[async_trait] -impl SessionStore for DbStore { - async fn create(&self, record: &mut Record) -> session_store::Result<()> { + async fn create_in_executor( + &self, + mut db: DB, + record: &mut Record, + ) -> session_store::Result<()> { for _ in 0..=MAX_COLLISION_RETRIES { let key = record.id.to_string(); @@ -132,7 +133,7 @@ impl SessionStore for DbStore { expiry, }; - let res = self.connection.insert(&mut model).await; + let res = db.insert(&mut model).await; match res { Ok(()) => { return Ok(()); @@ -146,27 +147,45 @@ impl SessionStore for DbStore { } Err(DbStoreError::TooManyIdCollisions(MAX_COLLISION_RETRIES))? } +} + +#[async_trait] +impl SessionStore for DbStore { + async fn create(&self, record: &mut Record) -> session_store::Result<()> { + self.create_in_executor(&self.connection, record).await + } async fn save(&self, record: &Record) -> session_store::Result<()> { - // TODO: use transactions when implemented + let mut transaction = self + .connection + .begin() + .await + .map_err(DbStoreError::DatabaseError)?; + let key = record.id.to_string(); let data = serde_json::to_string(&record.data) .map_err(|err| DbStoreError::Serialize(Box::new(err)))?; let query = query!(Session, $key == key) - .get(&self.connection) + .get(&mut transaction) .await .map_err(DbStoreError::DatabaseError)?; if let Some(mut model) = query { model.data = data; model - .update(&self.connection) + .update(&mut transaction) .await .map_err(DbStoreError::DatabaseError)?; } else { let mut record = record.clone(); - self.create(&mut record).await?; + self.create_in_executor(&mut transaction, &mut record) + .await?; } + + transaction + .commit() + .await + .map_err(DbStoreError::DatabaseError)?; Ok(()) } diff --git a/cot/tests/db.rs b/cot/tests/db.rs index 3a6c1a6c..0a336c6b 100644 --- a/cot/tests/db.rs +++ b/cot/tests/db.rs @@ -932,3 +932,45 @@ async fn bulk_insert_with_fixed_pk(test_db: &mut TestDatabase) { .unwrap(); assert_eq!(model300.name, "test300"); } + +#[cot_macros::dbtest] +async fn transaction_commit(test_db: &mut TestDatabase) { + migrate_test_model(&*test_db).await; + let db = &**test_db; + + let mut transaction = db.begin().await.unwrap(); + let mut model = TestModel { + id: Auto::auto(), + name: "test".to_string(), + }; + model.insert(&mut transaction).await.unwrap(); + transaction.commit().await.unwrap(); + + let exists = TestModel::objects() + .filter(::Fields::name.eq("test")) + .exists(db) + .await + .unwrap(); + assert!(exists); +} + +#[cot_macros::dbtest] +async fn transaction_rollback(test_db: &mut TestDatabase) { + migrate_test_model(&*test_db).await; + let db = &**test_db; + + let mut transaction = db.begin().await.unwrap(); + let mut model = TestModel { + id: Auto::auto(), + name: "test_rollback".to_string(), + }; + model.insert(&mut transaction).await.unwrap(); + transaction.rollback().await.unwrap(); + + let exists = TestModel::objects() + .filter(::Fields::name.eq("test_rollback")) + .exists(db) + .await + .unwrap(); + assert!(!exists); +} diff --git a/examples/admin/src/main.rs b/examples/admin/src/main.rs index f85edef8..77af4fb9 100644 --- a/examples/admin/src/main.rs +++ b/examples/admin/src/main.rs @@ -62,12 +62,14 @@ impl App for HelloApp { } async fn init(&self, context: &mut ProjectContext) -> cot::Result<()> { - // TODO use transaction - let user = DatabaseUser::get_by_username(context.database(), "admin").await?; + let mut transaction = context.database().begin().await?; + + let user = DatabaseUser::get_by_username(&mut transaction, "admin").await?; if user.is_none() { - DatabaseUser::create_user(context.database(), "admin", "admin").await?; + DatabaseUser::create_user(&mut transaction, "admin", "admin").await?; } + transaction.commit().await?; Ok(()) }