Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions src/replay_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::io::BufReader;
use std::num::NonZeroI32;
use std::sync::Arc;

use etl::config::{ETL_REPLICATION_OPTIONS, IntoConnectOptions, PgConnectionConfig};
use etl::config::{
ETL_REPLICATION_OPTIONS, ETL_STATE_MANAGEMENT_OPTIONS, IntoConnectOptions, PgConnectionConfig,
};
use etl::error::EtlResult;
use etl_postgres::replication::extract_server_version;
use tokio_postgres::tls::MakeTlsConnect;
Expand Down Expand Up @@ -38,7 +40,10 @@ where
#[derive(Debug, Clone)]
pub struct ReplayClient {
stream_id: StreamId,
client: Arc<Client>,
/// Connection used for COPY streaming of replay events.
copy_client: Arc<Client>,
/// Separate connection used for checkpoint updates during replay.
checkpoint_client: Arc<Client>,
/// Server version extracted from connection - reserved for future version-specific logic
#[allow(dead_code)]
server_version: Option<NonZeroI32>,
Expand Down Expand Up @@ -67,22 +72,28 @@ impl ReplayClient {
stream_id: StreamId,
pg_connection_config: PgConnectionConfig,
) -> EtlResult<Self> {
let config: Config = pg_connection_config
let copy_config: Config = pg_connection_config
.clone()
.with_db(Some(&ETL_REPLICATION_OPTIONS));

let (client, connection) = config.connect(NoTls).await?;
let (copy_client, copy_connection) = copy_config.connect(NoTls).await?;

let server_version = connection
let server_version = copy_connection
.parameter("server_version")
.and_then(extract_server_version);

spawn_postgres_connection::<NoTls>(connection);
spawn_postgres_connection::<NoTls>(copy_connection);

let checkpoint_config: Config =
pg_connection_config.with_db(Some(&ETL_STATE_MANAGEMENT_OPTIONS));
let (checkpoint_client, checkpoint_connection) = checkpoint_config.connect(NoTls).await?;
spawn_postgres_connection::<NoTls>(checkpoint_connection);

info!("successfully connected to postgres without tls");

Ok(ReplayClient {
client: Arc::new(client),
copy_client: Arc::new(copy_client),
checkpoint_client: Arc::new(checkpoint_client),
server_version,
stream_id,
})
Expand All @@ -95,7 +106,7 @@ impl ReplayClient {
stream_id: StreamId,
pg_connection_config: PgConnectionConfig,
) -> EtlResult<Self> {
let config: Config = pg_connection_config
let copy_config: Config = pg_connection_config
.clone()
.with_db(Some(&ETL_REPLICATION_OPTIONS));

Expand All @@ -113,18 +124,28 @@ impl ReplayClient {
.with_root_certificates(root_store)
.with_no_client_auth();

let (client, connection) = config.connect(MakeRustlsConnect::new(tls_config)).await?;
let (copy_client, copy_connection) = copy_config
.connect(MakeRustlsConnect::new(tls_config.clone()))
.await?;

let server_version = connection
let server_version = copy_connection
.parameter("server_version")
.and_then(extract_server_version);

spawn_postgres_connection::<MakeRustlsConnect>(connection);
spawn_postgres_connection::<MakeRustlsConnect>(copy_connection);

let checkpoint_config: Config =
pg_connection_config.with_db(Some(&ETL_STATE_MANAGEMENT_OPTIONS));
let (checkpoint_client, checkpoint_connection) = checkpoint_config
.connect(MakeRustlsConnect::new(tls_config))
.await?;
spawn_postgres_connection::<MakeRustlsConnect>(checkpoint_connection);

info!("successfully connected to postgres with tls");

Ok(ReplayClient {
client: Arc::new(client),
copy_client: Arc::new(copy_client),
checkpoint_client: Arc::new(checkpoint_client),
server_version,
stream_id,
})
Expand All @@ -133,7 +154,7 @@ impl ReplayClient {
/// Checks if the underlying connection is closed.
#[must_use]
pub fn is_closed(&self) -> bool {
self.client.is_closed()
self.copy_client.is_closed() || self.checkpoint_client.is_closed()
}

/// Gets events between two checkpoints (exclusive on both ends).
Expand All @@ -150,12 +171,13 @@ impl ReplayClient {
where (created_at, id) > ('{}'::timestamptz, '{}'::uuid)
and (created_at, id) < ('{}'::timestamptz, '{}'::uuid)
and stream_id = {}
order by created_at, id
) to stdout with (format text);
"#,
from.created_at, from.id, to.created_at, to.id, self.stream_id as i64
);

let stream = self.client.copy_out_simple(&copy_query).await?;
let stream = self.copy_client.copy_out_simple(&copy_query).await?;

Ok(stream)
}
Expand All @@ -164,7 +186,7 @@ impl ReplayClient {
///
/// This is duplicated from the [`StreamStore`] because we want to use a persistent connection during failover.
pub async fn update_checkpoint(&self, checkpoint: &EventIdentifier) -> EtlResult<()> {
self.client
self.checkpoint_client
.execute(
r#"
update pgstream.streams
Expand Down
110 changes: 104 additions & 6 deletions src/slot_recovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ use tracing::{info, warn};

use crate::types::SlotName;

type Checkpoint = (String, DateTime<Utc>);

/// Returns `true` when checkpoint `a` sorts strictly before checkpoint `b`.
///
/// Ordering is by `created_at` first, with the event id string as a
/// tie-breaker — the same `(created_at, id)` ordering used by the replay
/// queries in this crate.
#[must_use]
fn checkpoint_is_earlier(a: &Checkpoint, b: &Checkpoint) -> bool {
    // Lexicographic (timestamp, id) comparison via tuple ordering:
    // equivalent to `a.1 < b.1 || (a.1 == b.1 && a.0 < b.0)`.
    (a.1, &a.0) < (b.1, &b.0)
}

/// Picks the checkpoint to resume from during slot recovery.
///
/// When both an existing failover checkpoint and an LSN-derived checkpoint
/// are present, the earlier of the two wins (we must not skip events). If
/// only one is present it is used as-is, and `None` means there is nothing
/// to resume from.
#[must_use]
fn select_recovery_checkpoint(
    existing_checkpoint: Option<Checkpoint>,
    lsn_checkpoint: Option<Checkpoint>,
) -> Option<Checkpoint> {
    match (existing_checkpoint, lsn_checkpoint) {
        (Some(existing), Some(from_lsn)) => {
            // Both candidates exist: keep whichever is earlier.
            let earlier = if checkpoint_is_earlier(&existing, &from_lsn) {
                existing
            } else {
                from_lsn
            };
            Some(earlier)
        }
        // At most one candidate exists; return it (covers (None, None) too).
        (only, None) | (None, only) => only,
    }
}

/// Checks if an error indicates a replication slot has been invalidated.
///
/// Postgres returns error code 55000 (OBJECT_NOT_IN_PREREQUISITE_STATE) with the message
Expand Down Expand Up @@ -66,6 +92,24 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
// Start a transaction for the checkpoint update
let mut tx = pool.begin().await?;

// Preserve an existing failover checkpoint if we are already in failover mode.
let existing_checkpoint_row: Option<(Option<String>, Option<DateTime<Utc>>)> = sqlx::query_as(
"SELECT failover_checkpoint_id, failover_checkpoint_ts FROM pgstream.streams WHERE id = $1",
)
.bind(stream_id as i64)
.fetch_optional(&mut *tx)
.await?;

let existing_checkpoint = existing_checkpoint_row.and_then(|(id, ts)| id.zip(ts));

if let Some((id, created_at)) = &existing_checkpoint {
info!(
event_id = %id,
event_created_at = %created_at,
"existing failover checkpoint found"
);
}

// 1. Get confirmed_flush_lsn BEFORE dropping the slot
let confirmed_lsn: Option<String> = sqlx::query_scalar(
"SELECT confirmed_flush_lsn::text FROM pg_replication_slots WHERE slot_name = $1",
Expand All @@ -91,7 +135,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
);

// 2. Find the first event after the confirmed LSN
let checkpoint: Option<(String, DateTime<Utc>)> = sqlx::query_as(
let lsn_checkpoint: Option<Checkpoint> = sqlx::query_as(
"SELECT id::text, created_at FROM pgstream.events
WHERE lsn > $1::pg_lsn AND stream_id = $2
ORDER BY created_at, id LIMIT 1",
Expand All @@ -101,8 +145,23 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
.fetch_optional(&mut *tx)
.await?;

// 3. Set failover checkpoint BEFORE dropping slot (crash-safe ordering)
if let Some((id, created_at)) = checkpoint {
// 3. Choose the earliest safe checkpoint.
// If we are already in failover mode, keep the existing checkpoint unless
// the LSN-derived checkpoint is earlier.
let checkpoint = select_recovery_checkpoint(existing_checkpoint.clone(), lsn_checkpoint);

// 4. Set failover checkpoint BEFORE dropping slot (crash-safe ordering)
if checkpoint == existing_checkpoint {
if let Some((id, created_at)) = checkpoint {
info!(
event_id = %id,
event_created_at = %created_at,
"preserving existing failover checkpoint during slot recovery"
);
} else {
info!("no events found after confirmed_flush_lsn, pipeline will start fresh");
}
} else if let Some((id, created_at)) = checkpoint {
info!(
event_id = %id,
event_created_at = %created_at,
Expand All @@ -124,7 +183,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
info!("no events found after confirmed_flush_lsn, pipeline will start fresh");
}

// 4. Delete ETL replication state so ETL will create a fresh slot on restart
// 5. Delete ETL replication state so ETL will create a fresh slot on restart
// This triggers DataSync, but we skip it by returning Ok(()) from write_table_rows.
// The failover checkpoint ensures we COPY missed events when replication starts.
let deleted = sqlx::query("DELETE FROM etl.replication_state WHERE pipeline_id = $1")
Expand All @@ -137,10 +196,10 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
"deleted ETL replication state to trigger fresh slot creation"
);

// 5. Commit the transaction - checkpoint is now durable
// 6. Commit the transaction - checkpoint is now durable
tx.commit().await?;

// 6. Drop the invalidated slot AFTER commit (non-transactional operation)
// 7. Drop the invalidated slot AFTER commit (non-transactional operation)
// This ordering ensures crash safety: if we crash here, the checkpoint is
// already saved, and the next recovery attempt will simply drop the slot.
let drop_result = sqlx::query("SELECT pg_drop_replication_slot($1)")
Expand All @@ -164,6 +223,7 @@ pub async fn handle_slot_recovery(pool: &PgPool, stream_id: u64) -> EtlResult<()
#[cfg(test)]
mod tests {
use super::*;
use chrono::TimeZone;

#[test]
fn test_is_slot_invalidation_error_matches() {
Expand Down Expand Up @@ -206,4 +266,42 @@ mod tests {
let error = etl::etl_error!(etl::error::ErrorKind::InvalidState, "connection refused");
assert!(!is_slot_invalidation_error(&error));
}

#[test]
fn test_select_recovery_checkpoint_prefers_earlier_existing_checkpoint() {
    // The existing failover checkpoint is one second older than the
    // LSN-derived one, so recovery must keep the existing checkpoint.
    let base = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap();
    let existing = Some(("00000000-0000-0000-0000-000000000001".to_string(), base));
    let later = base + chrono::Duration::seconds(1);
    let from_lsn = Some(("00000000-0000-0000-0000-000000000002".to_string(), later));

    let chosen = select_recovery_checkpoint(existing.clone(), from_lsn);
    assert_eq!(chosen, existing);
}

#[test]
fn test_select_recovery_checkpoint_prefers_earlier_lsn_checkpoint() {
    // Here the LSN-derived checkpoint is the older one, so it must win
    // over the (later) existing failover checkpoint.
    let base = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap();
    let later = base + chrono::Duration::seconds(1);
    let existing = Some(("00000000-0000-0000-0000-000000000002".to_string(), later));
    let from_lsn = Some(("00000000-0000-0000-0000-000000000001".to_string(), base));

    let chosen = select_recovery_checkpoint(existing, from_lsn.clone());
    assert_eq!(chosen, from_lsn);
}

#[test]
fn test_select_recovery_checkpoint_uses_existing_when_lsn_checkpoint_missing() {
    // With no LSN-derived candidate, the existing checkpoint is returned
    // unchanged.
    let base = Utc.with_ymd_and_hms(2024, 3, 15, 12, 0, 0).unwrap();
    let existing = Some(("00000000-0000-0000-0000-000000000001".to_string(), base));

    let picked = select_recovery_checkpoint(existing.clone(), None);
    assert_eq!(picked, existing);
}
}
35 changes: 25 additions & 10 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,31 @@ where

let result = self.sink.publish_events(events).await;
if result.is_err() {
info!(
"Publishing events failed, entering failover at checkpoint event id: {:?}",
checkpoint_id
);
metrics::record_failover_entered(self.config.id);
self.store
.store_stream_status(StreamStatus::Failover {
checkpoint_event_id: checkpoint_id,
})
.await?;
let (current_status, _) = self.store.get_stream_state().await?;

match current_status {
StreamStatus::Healthy => {
info!(
"Publishing events failed, entering failover at checkpoint event id: {:?}",
checkpoint_id
);
metrics::record_failover_entered(self.config.id);
self.store
.store_stream_status(StreamStatus::Failover {
checkpoint_event_id: checkpoint_id,
})
.await?;
}
StreamStatus::Failover {
checkpoint_event_id,
} => {
info!(
"Publishing events failed while already in failover, preserving checkpoint event id: {:?}",
checkpoint_event_id
);
}
}

return Ok(());
}

Expand Down
Loading