diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index defeb5c2ac2d..bc4cc27aea1d 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -299,7 +299,13 @@ where } TransactionManagerStatus::InError => panic!("Transaction manager in error"), }; - Self::TransactionManager::begin_transaction(self) + Self::TransactionManager::begin_transaction(self)?; + // set the test transaction flag + // to pervent that this connection gets droped in connection pools + // Tests commonly set the poolsize to 1 and use `begin_test_transaction` + // to prevent modifications to the schema + Self::TransactionManager::transaction_manager_status_mut(self).set_test_transaction_flag(); + Ok(()) } /// Executes the given function inside a transaction, but does not commit diff --git a/diesel/src/connection/transaction_manager.rs b/diesel/src/connection/transaction_manager.rs index 1e5eb773fe89..c5277b893b87 100644 --- a/diesel/src/connection/transaction_manager.rs +++ b/diesel/src/connection/transaction_manager.rs @@ -38,6 +38,9 @@ pub trait TransactionManager { /// Used to ensure that `begin_test_transaction` is not called when already /// inside of a transaction, and that operations are not run in a `InError` /// transaction manager. + #[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + )] fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus; /// Executes the given function inside of a database transaction @@ -66,6 +69,35 @@ pub trait TransactionManager { }, } } + + /// This methods checks if the connection manager is considered to be broken + /// by connection pool implementations + /// + /// A connection manager is considered to be broken by default if it either + /// contains an open transaction (because you don't want to have connections + /// with open transactions in your pool) or when the transaction manager is + /// in an error state. + #[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" + )] + fn is_broken_transaction_manager(conn: &mut Conn) -> bool { + match Self::transaction_manager_status_mut(conn).transaction_state() { + // all transactions are closed + // so we don't consider this connection broken + Ok(ValidTransactionManagerStatus { + in_transaction: None, + }) => false, + // The transaction manager is in an error state + // Therefore we consider this connection broken + Err(_) => true, + // The transaction manager contains a open transaction + // we do consider this connection broken + // if that transaction was not opened by `begin_test_transaction` + Ok(ValidTransactionManagerStatus { + in_transaction: Some(s), + }) => !s.test_transaction, + } + } } /// An implementation of `TransactionManager` which can be used for backends @@ -77,6 +109,9 @@ pub struct AnsiTransactionManager { } /// Status of the transaction manager +#[diesel_derives::__diesel_public_if( + feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes" +)] #[derive(Debug)] pub enum TransactionManagerStatus { /// Valid status, the manager can run operations @@ -157,6 +192,15 @@ impl TransactionManagerStatus { TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager), } } + + pub(crate) fn set_test_transaction_flag(&mut self) { + if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus { + in_transaction: Some(s), + }) = self + { + s.test_transaction = true; + } + } } /// Valid transaction status for the manager. Can return the current transaction depth @@ -171,6 +215,7 @@ pub struct ValidTransactionManagerStatus { struct InTransactionStatus { transaction_depth: NonZeroU32, top_level_transaction_requires_rollback: bool, + test_transaction: bool, } impl ValidTransactionManagerStatus { @@ -209,6 +254,7 @@ impl ValidTransactionManagerStatus { self.in_transaction = Some(InTransactionStatus { transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"), top_level_transaction_requires_rollback: false, + test_transaction: false, }); Ok(()) } @@ -331,6 +377,7 @@ where Some(InTransactionStatus { transaction_depth, top_level_transaction_requires_rollback, + .. }), }) if transaction_depth.get() > 1 && !*top_level_transaction_requires_rollback => @@ -376,6 +423,7 @@ where Some(InTransactionStatus { ref mut transaction_depth, top_level_transaction_requires_rollback: true, + .. }), }) = conn.transaction_state().status { diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 93f510d95632..ac983beaec46 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -205,15 +205,7 @@ impl crate::r2d2::R2D2Connection for MysqlConnection { } fn is_broken(&mut self) -> bool { - match self.transaction_state.status.transaction_depth() { - // all transactions are closed - // so we don't consider this connection broken - Ok(None) => false, - // The transaction manager is in an error state - // or contains an open transaction - // Therefore we consider this connection broken - Err(_) | Ok(Some(_)) => true, - } + AnsiTransactionManager::is_broken_transaction_manager(self) } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 90a4fc57ba89..fd2130da9aff 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -280,20 +280,7 @@ impl crate::r2d2::R2D2Connection for PgConnection { } fn is_broken(&mut self) -> bool { - match self - .connection_and_transaction_manager - .transaction_state - .status - .transaction_depth() - { - // all transactions are closed - // so we don't consider this connection broken - Ok(None) => false, - // The transaction manager is in an error state - // or contains an open transaction - // Therefore we consider this connection broken - Err(_) | Ok(Some(_)) => true, - } + AnsiTransactionManager::is_broken_transaction_manager(self) } } diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 1256832d0a83..2a6f21de2a0a 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -528,4 +528,57 @@ mod tests { assert_eq!(checkin_count.load(Ordering::Relaxed), 3); assert_eq!(checkout_count.load(Ordering::Relaxed), 3); } + + #[cfg(feature = "postgres")] + #[test] + fn verify_that_begin_test_transaction_works_with_pools() { + use crate::prelude::*; + use crate::r2d2::*; + + table! { + users { + id -> Integer, + name -> Text, + } + } + + #[derive(Debug)] + struct TestConnectionCustomizer; + + impl CustomizeConnection for TestConnectionCustomizer { + fn on_acquire(&self, conn: &mut PgConnection) -> Result<(), E> { + conn.begin_test_transaction() + .expect("Failed to start test transaction"); + + Ok(()) + } + } + + let manager = ConnectionManager::::new(database_url()); + let pool = Pool::builder() + .max_size(1) + .connection_customizer(Box::new(TestConnectionCustomizer)) + .build(manager) + .unwrap(); + + let mut conn = pool.get().unwrap(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name TEXT NOT NULL)", + ) + .execute(&mut conn) + .unwrap(); + + crate::insert_into(users::table) + .values(users::name.eq("John")) + .execute(&mut conn) + .unwrap(); + + std::mem::drop(conn); + + let mut conn2 = pool.get().unwrap(); + + let user_count = users::table.count().get_result::(&mut conn2).unwrap(); + assert_eq!(user_count, 1); + } } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 509e463fb6f6..4665ea032420 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -204,15 +204,7 @@ impl crate::r2d2::R2D2Connection for crate::sqlite::SqliteConnection { } fn is_broken(&mut self) -> bool { - match self.transaction_state.status.transaction_depth() { - // all transactions are closed - // so we don't consider this connection broken - Ok(None) => false, - // The transaction manager is in an error state - // or contains an open transaction - // Therefore we consider this connection broken - Err(_) | Ok(Some(_)) => true, - } + AnsiTransactionManager::is_broken_transaction_manager(self) } }