diff --git a/src/replay_client.rs b/src/replay_client.rs index d57d129..a3b1b30 100644 --- a/src/replay_client.rs +++ b/src/replay_client.rs @@ -2,7 +2,9 @@ use std::io::BufReader; use std::num::NonZeroI32; use std::sync::Arc; -use etl::config::{ETL_REPLICATION_OPTIONS, IntoConnectOptions, PgConnectionConfig}; +use etl::config::{ + ETL_REPLICATION_OPTIONS, ETL_STATE_MANAGEMENT_OPTIONS, IntoConnectOptions, PgConnectionConfig, +}; use etl::error::EtlResult; use etl_postgres::replication::extract_server_version; use tokio_postgres::tls::MakeTlsConnect; @@ -38,7 +40,10 @@ where #[derive(Debug, Clone)] pub struct ReplayClient { stream_id: StreamId, - client: Arc, + /// Connection used for COPY streaming of replay events. + copy_client: Arc, + /// Separate connection used for checkpoint updates during replay. + checkpoint_client: Arc, /// Server version extracted from connection - reserved for future version-specific logic #[allow(dead_code)] server_version: Option, @@ -67,22 +72,28 @@ impl ReplayClient { stream_id: StreamId, pg_connection_config: PgConnectionConfig, ) -> EtlResult { - let config: Config = pg_connection_config + let copy_config: Config = pg_connection_config .clone() .with_db(Some(&ETL_REPLICATION_OPTIONS)); - let (client, connection) = config.connect(NoTls).await?; + let (copy_client, copy_connection) = copy_config.connect(NoTls).await?; - let server_version = connection + let server_version = copy_connection .parameter("server_version") .and_then(extract_server_version); - spawn_postgres_connection::(connection); + spawn_postgres_connection::(copy_connection); + + let checkpoint_config: Config = + pg_connection_config.with_db(Some(&ETL_STATE_MANAGEMENT_OPTIONS)); + let (checkpoint_client, checkpoint_connection) = checkpoint_config.connect(NoTls).await?; + spawn_postgres_connection::(checkpoint_connection); info!("successfully connected to postgres without tls"); Ok(ReplayClient { - client: Arc::new(client), + copy_client: Arc::new(copy_client), + checkpoint_client: Arc::new(checkpoint_client), server_version, stream_id, }) @@ -95,7 +106,7 @@ impl ReplayClient { stream_id: StreamId, pg_connection_config: PgConnectionConfig, ) -> EtlResult { - let config: Config = pg_connection_config + let copy_config: Config = pg_connection_config .clone() .with_db(Some(&ETL_REPLICATION_OPTIONS)); @@ -113,18 +124,28 @@ impl ReplayClient { .with_root_certificates(root_store) .with_no_client_auth(); - let (client, connection) = config.connect(MakeRustlsConnect::new(tls_config)).await?; + let (copy_client, copy_connection) = copy_config + .connect(MakeRustlsConnect::new(tls_config.clone())) + .await?; - let server_version = connection + let server_version = copy_connection .parameter("server_version") .and_then(extract_server_version); - spawn_postgres_connection::(connection); + spawn_postgres_connection::(copy_connection); + + let checkpoint_config: Config = + pg_connection_config.with_db(Some(&ETL_STATE_MANAGEMENT_OPTIONS)); + let (checkpoint_client, checkpoint_connection) = checkpoint_config + .connect(MakeRustlsConnect::new(tls_config)) + .await?; + spawn_postgres_connection::(checkpoint_connection); info!("successfully connected to postgres with tls"); Ok(ReplayClient { - client: Arc::new(client), + copy_client: Arc::new(copy_client), + checkpoint_client: Arc::new(checkpoint_client), server_version, stream_id, }) @@ -133,7 +154,7 @@ impl ReplayClient { /// Checks if the underlying connection is closed. #[must_use] pub fn is_closed(&self) -> bool { - self.client.is_closed() + self.copy_client.is_closed() || self.checkpoint_client.is_closed() } /// Gets events between two checkpoints (exclusive on both ends). @@ -150,12 +171,13 @@ impl ReplayClient { where (created_at, id) > ('{}'::timestamptz, '{}'::uuid) and (created_at, id) < ('{}'::timestamptz, '{}'::uuid) and stream_id = {} + order by created_at, id ) to stdout with (format text); "#, from.created_at, from.id, to.created_at, to.id, self.stream_id as i64 ); - let stream = self.client.copy_out_simple(©_query).await?; + let stream = self.copy_client.copy_out_simple(©_query).await?; Ok(stream) } @@ -164,7 +186,7 @@ impl ReplayClient { /// /// This is duplicated from the [`StreamStore`] because we want to use a persistent connection during failover. pub async fn update_checkpoint(&self, checkpoint: &EventIdentifier) -> EtlResult<()> { - self.client + self.checkpoint_client .execute( r#" update pgstream.streams diff --git a/src/slot_recovery.rs b/src/slot_recovery.rs index cac9d61..8a34334 100644 --- a/src/slot_recovery.rs +++ b/src/slot_recovery.rs @@ -24,6 +24,32 @@ use tracing::{info, warn}; use crate::types::SlotName; +type Checkpoint = (String, DateTime); + +#[must_use] +fn checkpoint_is_earlier(a: &Checkpoint, b: &Checkpoint) -> bool { + a.1 < b.1 || (a.1 == b.1 && a.0 < b.0) +} + +#[must_use] +fn select_recovery_checkpoint( + existing_checkpoint: Option, + lsn_checkpoint: Option, +) -> Option { + match (existing_checkpoint, lsn_checkpoint) { + (Some(existing), Some(from_lsn)) => { + if checkpoint_is_earlier(&existing, &from_lsn) { + Some(existing) + } else { + Some(from_lsn) + } + } + (Some(existing), None) => Some(existing), + (None, Some(from_lsn)) => Some(from_lsn), + (None, None) => None, + } +} + /// Checks if an error indicates a replication slot has been invalidated. /// /// Postgres returns error code 55000 (OBJECT_NOT_IN_PREREQUISITE_STATE) with the message @@ -66,6 +92,24 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() // Start a transaction for the checkpoint update let mut tx = pool.begin().await?; + // Preserve an existing failover checkpoint if we are already in failover mode. + let existing_checkpoint_row: Option<(Option, Option>)> = sqlx::query_as( + "SELECT failover_checkpoint_id, failover_checkpoint_ts FROM pgstream.streams WHERE id = $1", + ) + .bind(stream_id as i64) + .fetch_optional(&mut *tx) + .await?; + + let existing_checkpoint = existing_checkpoint_row.and_then(|(id, ts)| id.zip(ts)); + + if let Some((id, created_at)) = &existing_checkpoint { + info!( + event_id = %id, + event_created_at = %created_at, + "existing failover checkpoint found" + ); + } + // 1. Get confirmed_flush_lsn BEFORE dropping the slot let confirmed_lsn: Option = sqlx::query_scalar( "SELECT confirmed_flush_lsn::text FROM pg_replication_slots WHERE slot_name = $1", @@ -91,7 +135,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() ); // 2. Find the first event after the confirmed LSN - let checkpoint: Option<(String, DateTime)> = sqlx::query_as( + let lsn_checkpoint: Option = sqlx::query_as( "SELECT id::text, created_at FROM pgstream.events WHERE lsn > $1::pg_lsn AND stream_id = $2 ORDER BY created_at, id LIMIT 1", @@ -101,8 +145,23 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() .fetch_optional(&mut *tx) .await?; - // 3. Set failover checkpoint BEFORE dropping slot (crash-safe ordering) - if let Some((id, created_at)) = checkpoint { + // 3. Choose the earliest safe checkpoint. + // If we are already in failover mode, keep the existing checkpoint unless + // the LSN-derived checkpoint is earlier. + let checkpoint = select_recovery_checkpoint(existing_checkpoint.clone(), lsn_checkpoint); + + // 4. Set failover checkpoint BEFORE dropping slot (crash-safe ordering) + if checkpoint == existing_checkpoint { + if let Some((id, created_at)) = checkpoint { + info!( + event_id = %id, + event_created_at = %created_at, + "preserving existing failover checkpoint during slot recovery" + ); + } else { + info!("no events found after confirmed_flush_lsn, pipeline will start fresh"); + } + } else if let Some((id, created_at)) = checkpoint { info!( event_id = %id, event_created_at = %created_at, @@ -124,7 +183,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() info!("no events found after confirmed_flush_lsn, pipeline will start fresh"); } - // 4. Delete ETL replication state so ETL will create a fresh slot on restart + // 5. Delete ETL replication state so ETL will create a fresh slot on restart // This triggers DataSync, but we skip it by returning Ok(()) from write_table_rows. // The failover checkpoint ensures we COPY missed events when replication starts. let deleted = sqlx::query("DELETE FROM etl.replication_state WHERE pipeline_id = $1") @@ -137,10 +196,10 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() "deleted ETL replication state to trigger fresh slot creation" ); - // 5. Commit the transaction - checkpoint is now durable + // 6. Commit the transaction - checkpoint is now durable tx.commit().await?; - // 6. Drop the invalidated slot AFTER commit (non-transactional operation) + // 7. Drop the invalidated slot AFTER commit (non-transactional operation) // This ordering ensures crash safety: if we crash here, the checkpoint is // already saved, and the next recovery attempt will simply drop the slot. let drop_result = sqlx::query("SELECT pg_drop_replication_slot($1)") @@ -164,6 +223,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<() #[cfg(test)] mod tests { use super::*; + use chrono::TimeZone; #[test] fn test_is_slot_invalidation_error_matches() { @@ -206,4 +266,42 @@ mod tests { let error = etl::etl_error!(etl::error::ErrorKind::InvalidState, "connection refused"); assert!(!is_slot_invalidation_error(&error)); } + + #[test] + fn test_select_recovery_checkpoint_prefers_earlier_existing_checkpoint() { + let ts = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap(); + let existing = Some(("00000000-0000-0000-0000-000000000001".to_string(), ts)); + let from_lsn = Some(( + "00000000-0000-0000-0000-000000000002".to_string(), + ts + chrono::Duration::seconds(1), + )); + + assert_eq!( + select_recovery_checkpoint(existing.clone(), from_lsn), + existing + ); + } + + #[test] + fn test_select_recovery_checkpoint_prefers_earlier_lsn_checkpoint() { + let ts = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap(); + let existing = Some(( + "00000000-0000-0000-0000-000000000002".to_string(), + ts + chrono::Duration::seconds(1), + )); + let from_lsn = Some(("00000000-0000-0000-0000-000000000001".to_string(), ts)); + + assert_eq!( + select_recovery_checkpoint(existing, from_lsn.clone()), + from_lsn + ); + } + + #[test] + fn test_select_recovery_checkpoint_uses_existing_when_lsn_checkpoint_missing() { + let ts = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap(); + let existing = Some(("00000000-0000-0000-0000-000000000001".to_string(), ts)); + + assert_eq!(select_recovery_checkpoint(existing.clone(), None), existing); + } } diff --git a/src/stream.rs b/src/stream.rs index dea85ac..e91122d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -103,16 +103,31 @@ where let result = self.sink.publish_events(events).await; if result.is_err() { - info!( - "Publishing events failed, entering failover at checkpoint event id: {:?}", - checkpoint_id - ); - metrics::record_failover_entered(self.config.id); - self.store - .store_stream_status(StreamStatus::Failover { - checkpoint_event_id: checkpoint_id, - }) - .await?; + let (current_status, _) = self.store.get_stream_state().await?; + + match current_status { + StreamStatus::Healthy => { + info!( + "Publishing events failed, entering failover at checkpoint event id: {:?}", + checkpoint_id + ); + metrics::record_failover_entered(self.config.id); + self.store + .store_stream_status(StreamStatus::Failover { + checkpoint_event_id: checkpoint_id, + }) + .await?; + } + StreamStatus::Failover { + checkpoint_event_id, + } => { + info!( + "Publishing events failed while already in failover, preserving checkpoint event id: {:?}", + checkpoint_event_id + ); + } + } + return Ok(()); } diff --git a/tests/failover_checkpoint_tests.rs b/tests/failover_checkpoint_tests.rs new file mode 100644 index 0000000..4fc2c87 --- /dev/null +++ b/tests/failover_checkpoint_tests.rs @@ -0,0 +1,109 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use etl::destination::Destination; +use etl::error::{ErrorKind, EtlResult}; +use etl::store::both::postgres::PostgresStore; +use postgres_stream::sink::Sink; +use postgres_stream::store::StreamStore; +use postgres_stream::stream::PgStream; +use postgres_stream::test_utils::{ + TestDatabase, create_postgres_store_with_table_id, insert_events_to_db, make_event_with_id, + test_stream_config, +}; +use postgres_stream::types::{StreamStatus, TriggeredEvent}; + +#[derive(Clone)] +struct FailFirstNSink { + fail_first_n: usize, + call_count: Arc, +} + +impl FailFirstNSink { + fn new(fail_first_n: usize) -> Self { + Self { + fail_first_n, + call_count: Arc::new(AtomicUsize::new(0)), + } + } +} + +impl Sink for FailFirstNSink { + fn name() -> &'static str { + "fail_first_n" + } + + fn publish_events( + &self, + _events: Vec, + ) -> impl std::future::Future> + Send { + let call_num = self.call_count.fetch_add(1, Ordering::SeqCst); + let fail_first_n = self.fail_first_n; + + async move { + if call_num < fail_first_n { + return Err(etl::etl_error!( + ErrorKind::InvalidData, + "Simulated sink failure", + "Test failure" + )); + } + + Ok(()) + } + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_failover_does_not_overwrite_checkpoint_on_repeated_failures() { + let db = TestDatabase::spawn().await; + let config = test_stream_config(&db); + let sink = FailFirstNSink::new(3); + let (store, table_id) = + create_postgres_store_with_table_id(config.id, &db.config, &db.pool).await; + + let stream: PgStream = + PgStream::create(config.clone(), sink, store.clone()) + .await + .expect("Failed to create PgStream"); + + let event_ids = insert_events_to_db(&db, 3).await; + + // First write fails and enters failover at event 0. + stream + .write_events(vec![make_event_with_id( + table_id, + event_ids.first().expect("Should have event 0"), + serde_json::json!({"seq": 1}), + )]) + .await + .expect("write_events should succeed even if sink fails"); + + // During failover, both checkpoint re-publish and current event publish fail. + // The checkpoint should remain at event 0 instead of being overwritten. + stream + .write_events(vec![make_event_with_id( + table_id, + event_ids.get(2).expect("Should have event 2"), + serde_json::json!({"seq": 3}), + )]) + .await + .expect("write_events should succeed even if sink fails"); + + let stream_store = StreamStore::create(config, store).await.unwrap(); + let (status, _) = stream_store.get_stream_state().await.unwrap(); + + match status { + StreamStatus::Failover { + checkpoint_event_id, + } => { + let expected = event_ids.first().expect("Should have event 0"); + assert_eq!(checkpoint_event_id.id, expected.id); + assert_eq!( + checkpoint_event_id.created_at.timestamp_micros(), + expected.created_at.timestamp_micros() + ); + } + StreamStatus::Healthy => panic!("Expected Failover status, got Healthy"), + } +} diff --git a/tests/replay_client_tests.rs b/tests/replay_client_tests.rs index c1e810b..2b26188 100644 --- a/tests/replay_client_tests.rs +++ b/tests/replay_client_tests.rs @@ -285,6 +285,68 @@ async fn test_get_events_copy_stream_boundary_conditions() { assert_eq!(count, 4, "Should have 4 events between indices 4 and 9"); } +#[tokio::test(flavor = "multi_thread")] +async fn test_update_checkpoint_while_copy_stream_active() { + let db = TestDatabase::spawn().await; + db.ensure_today_partition().await; + + // Ensure stream row exists for checkpoint updates. + sqlx::query("insert into pgstream.streams (id, next_maintenance_at) values (1, now())") + .execute(&db.pool) + .await + .unwrap(); + + let client = ReplayClient::connect(1, db.config.clone()) + .await + .expect("Failed to connect"); + + // Insert enough events so COPY remains active after reading first batch. + let events = insert_events_to_db(&db, 250).await; + + let config = test_stream_config(&db); + let store_backend = create_postgres_store(config.id, &db.config, &db.pool).await; + let store = StreamStore::create(config, store_backend).await.unwrap(); + let table_schema = store.get_events_table_schema().await.unwrap(); + + // Keep COPY active by only partially consuming it. + use etl::replication::stream::TableCopyStream; + use futures::StreamExt; + use tokio::pin; + + let stream = client + .get_events_copy_stream( + events.first().expect("Should have event 0"), + events.get(249).expect("Should have event 249"), + ) + .await + .expect("Failed to get copy stream"); + + let stream = TableCopyStream::wrap(stream, &table_schema.column_schemas, 1); + pin!(stream); + + for _ in 0..100 { + let row = stream + .next() + .await + .expect("Expected row from copy stream") + .expect("Should parse row successfully"); + drop(row); + } + + // Checkpoint update should not block while COPY stream is still active. + let checkpoint_update = tokio::time::timeout( + std::time::Duration::from_secs(3), + client.update_checkpoint(events.get(100).expect("Should have event 100")), + ) + .await; + + assert!( + checkpoint_update.is_ok(), + "Checkpoint update should not time out while copy stream is active" + ); + checkpoint_update.unwrap().unwrap(); +} + #[tokio::test(flavor = "multi_thread")] async fn test_get_events_copy_stream_filters_by_stream_id() { let db = TestDatabase::spawn().await; diff --git a/tests/slot_recovery_tests.rs b/tests/slot_recovery_tests.rs index 960e00b..60099ca 100644 --- a/tests/slot_recovery_tests.rs +++ b/tests/slot_recovery_tests.rs @@ -659,3 +659,209 @@ async fn test_pipeline_recovers_from_invalidated_slot() { "Failover checkpoint timestamp should be cleared after successful recovery" ); } + +#[tokio::test(flavor = "multi_thread")] +async fn test_slot_recovery_preserves_existing_failover_checkpoint() { + let _lock = acquire_exclusive_test_lock().await; + + let db = TestDatabase::spawn().await; + let stream_config = test_stream_config_with_id(&db, unique_pipeline_id()); + let pipeline_id = stream_config.id; + let slot_name = format!("supabase_etl_apply_{pipeline_id}"); + + migrate_etl(&db.config) + .await + .expect("Failed to run ETL migrations"); + + db.ensure_today_partition().await; + + // Create an existing failover checkpoint before the slot is created. + // Recovery must preserve this checkpoint. + let existing_checkpoint: (String, chrono::DateTime) = sqlx::query_as( + "INSERT INTO pgstream.events (id, payload, stream_id, created_at, lsn) + VALUES (gen_random_uuid(), $1, $2, now(), pg_current_wal_lsn()) + RETURNING id::text, created_at", + ) + .bind(serde_json::json!({"existing_failover_checkpoint": true})) + .bind(pipeline_id as i64) + .fetch_one(&db.pool) + .await + .expect("Should insert existing checkpoint event"); + + // Start and stop pipeline once to create the replication slot. + { + let state_store = PostgresStore::new(pipeline_id, db.config.clone()); + let sink = MemorySink::new(); + let pgstream = PgStream::create(stream_config.clone(), sink, state_store.clone()) + .await + .expect("Failed to create PgStream"); + + let pipeline_config: etl::config::PipelineConfig = stream_config.clone().into(); + let mut pipeline = etl::pipeline::Pipeline::new(pipeline_config, state_store, pgstream); + + pipeline.start().await.expect("Failed to start pipeline"); + + let mut slot_created = false; + for _ in 0..30 { + let slot_exists: bool = sqlx::query_scalar(&format!( + "SELECT EXISTS(SELECT 1 FROM pg_replication_slots WHERE slot_name = '{slot_name}')" + )) + .fetch_one(&db.pool) + .await + .unwrap(); + + if slot_exists { + slot_created = true; + break; + } + + tokio::time::sleep(Duration::from_millis(500)).await; + } + + assert!(slot_created, "Replication slot should be created"); + + let shutdown_tx = pipeline.shutdown_tx(); + shutdown_tx + .shutdown() + .expect("Failed to send shutdown signal"); + pipeline.wait().await.expect("Failed to wait for pipeline"); + } + + // Generate events after slot creation so LSN-derived recovery checkpoint would move forward. + for i in 0..5 { + sqlx::query( + "INSERT INTO pgstream.events (id, payload, stream_id, created_at, lsn) + VALUES (gen_random_uuid(), $1, $2, now(), pg_current_wal_lsn())", + ) + .bind(serde_json::json!({"after_slot_created": i})) + .bind(pipeline_id as i64) + .execute(&db.pool) + .await + .unwrap(); + } + + sqlx::query( + "UPDATE pgstream.streams + SET failover_checkpoint_id = $1, failover_checkpoint_ts = $2 + WHERE id = $3", + ) + .bind(&existing_checkpoint.0) + .bind(existing_checkpoint.1) + .bind(pipeline_id as i64) + .execute(&db.pool) + .await + .expect("Should set existing failover checkpoint"); + + // Invalidate the slot. + sqlx::query("alter system set max_slot_wal_keep_size = '1MB'") + .execute(&db.pool) + .await + .unwrap(); + sqlx::query("select pg_reload_conf()") + .execute(&db.pool) + .await + .unwrap(); + + sqlx::query("create table wal_bloat (id serial, data bytea)") + .execute(&db.pool) + .await + .unwrap(); + + for batch in 0..50 { + for _ in 0..10 { + sqlx::query("insert into wal_bloat (data) select decode(repeat('ab', 50000), 'hex') from generate_series(1, 10)") + .execute(&db.pool) + .await + .unwrap(); + } + + if batch % 10 == 9 { + let _: Option = sqlx::query_scalar("select pg_switch_wal()::text") + .fetch_one(&db.pool) + .await + .unwrap(); + sqlx::query("checkpoint").execute(&db.pool).await.unwrap(); + } + } + + let _: Option = sqlx::query_scalar("select pg_switch_wal()::text") + .fetch_one(&db.pool) + .await + .unwrap(); + sqlx::query("checkpoint").execute(&db.pool).await.unwrap(); + + let wal_status: String = sqlx::query_scalar(&format!( + "SELECT wal_status FROM pg_replication_slots WHERE slot_name = '{slot_name}'" + )) + .fetch_one(&db.pool) + .await + .unwrap(); + assert_eq!(wal_status, "lost", "Slot should be invalidated"); + + let confirmed_lsn: Option = sqlx::query_scalar(&format!( + "SELECT confirmed_flush_lsn::text FROM pg_replication_slots WHERE slot_name = '{slot_name}'" + )) + .fetch_one(&db.pool) + .await + .unwrap(); + + let confirmed_lsn = confirmed_lsn.expect("confirmed_flush_lsn should be present"); + + let lsn_checkpoint: Option<(String, chrono::DateTime)> = sqlx::query_as( + "SELECT id::text, created_at FROM pgstream.events + WHERE lsn > $1::pg_lsn AND stream_id = $2 + ORDER BY created_at, id LIMIT 1", + ) + .bind(&confirmed_lsn) + .bind(pipeline_id as i64) + .fetch_optional(&db.pool) + .await + .expect("Should resolve LSN-based checkpoint"); + + assert!( + lsn_checkpoint.is_some(), + "Recovery should have an LSN-derived checkpoint candidate" + ); + assert_ne!( + lsn_checkpoint.unwrap().0, + existing_checkpoint.0, + "LSN-derived checkpoint must differ from existing failover checkpoint" + ); + + // Reset system setting before recovery so subsequent tests remain isolated. + sqlx::query("alter system reset max_slot_wal_keep_size") + .execute(&db.pool) + .await + .unwrap(); + sqlx::query("select pg_reload_conf()") + .execute(&db.pool) + .await + .unwrap(); + + handle_slot_recovery(&db.pool, pipeline_id) + .await + .expect("Slot recovery should succeed"); + + let recovered_checkpoint: (Option, Option>) = + sqlx::query_as( + "SELECT failover_checkpoint_id, failover_checkpoint_ts FROM pgstream.streams WHERE id = $1", + ) + .bind(pipeline_id as i64) + .fetch_one(&db.pool) + .await + .expect("Should find stream row"); + + assert_eq!( + recovered_checkpoint.0, + Some(existing_checkpoint.0.clone()), + "Recovery must preserve the existing failover checkpoint id" + ); + assert_eq!( + recovered_checkpoint + .1 + .expect("Checkpoint timestamp should be present") + .timestamp_micros(), + existing_checkpoint.1.timestamp_micros(), + "Recovery must preserve the existing failover checkpoint timestamp" + ); +} diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs index d9d4913..d370ecd 100644 --- a/tests/stream_tests.rs +++ b/tests/stream_tests.rs @@ -9,6 +9,7 @@ use postgres_stream::test_utils::{ insert_events_to_db, make_event_with_id, make_test_event, test_stream_config, }; use postgres_stream::types::{StreamStatus, TriggeredEvent}; +use std::time::Duration as StdDuration; // Basic stream tests @@ -422,6 +423,70 @@ async fn test_failover_large_gap_recovery() { ); } +#[tokio::test(flavor = "multi_thread")] +async fn test_failover_recovery_does_not_hang_when_replay_exceeds_batch_size() { + let db = TestDatabase::spawn().await; + let config = test_stream_config(&db); + let sink = FailableSink::new(); + let (store, table_id) = + create_postgres_store_with_table_id(config.id, &db.config, &db.pool).await; + + let stream: PgStream = + PgStream::create(config.clone(), sink.clone(), store.clone()) + .await + .expect("Failed to create PgStream"); + + // 250 events ensure replay window is larger than batch.max_size (100). + let event_ids = insert_events_to_db(&db, 250).await; + + // Event 0 succeeds. + sink.succeed_always(); + stream + .write_events(vec![make_event_with_id( + table_id, + event_ids.first().expect("Should have event 0"), + serde_json::json!({"seq": 1}), + )]) + .await + .unwrap(); + + // Event 1 fails, entering failover mode. + sink.fail_on_call(1); + stream + .write_events(vec![make_event_with_id( + table_id, + event_ids.get(1).expect("Should have event 1"), + serde_json::json!({"seq": 2}), + )]) + .await + .unwrap(); + + // Event 249 should trigger failover replay with a gap > batch size. + sink.succeed_always(); + let recovery_result = tokio::time::timeout( + StdDuration::from_secs(20), + stream.write_events(vec![make_event_with_id( + table_id, + event_ids.get(249).expect("Should have event 249"), + serde_json::json!({"seq": 250}), + )]), + ) + .await; + + assert!( + recovery_result.is_ok(), + "Failover recovery should not hang when replay exceeds batch size" + ); + recovery_result.unwrap().unwrap(); + + let stream_store = StreamStore::create(config, store).await.unwrap(); + let (status, _) = stream_store.get_stream_state().await.unwrap(); + assert!( + matches!(status, StreamStatus::Healthy), + "Stream should return to Healthy after large replay recovery" + ); +} + #[tokio::test(flavor = "multi_thread")] async fn test_failover_checkpoint_persists_across_stream_instances() { let db = TestDatabase::spawn().await;