getStateClass() {
@@ -436,7 +498,8 @@ public Executing getState() {
userCodeClassLoader,
failureCollection,
stateTransitionManagerFactory,
- rescaleOnFailedCheckpointCount);
+ rescaleOnFailedCheckpointCount,
+ activeCheckpointTriggerEnabled);
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
index 98229a9afd3e3..e057addd40f4a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptive/StateTransitionManager.java
@@ -89,5 +89,24 @@ interface Context extends RescaleContext {
* @return the {@link JobID} of the job
*/
JobID getJobId();
+
+ /**
+ * Requests the context to actively trigger a checkpoint to expedite rescaling. Called by
+ * the {@link DefaultStateTransitionManager} from within phase lifecycle methods:
+ *
+ * <ul>
+ * <li>On entering {@link DefaultStateTransitionManager.Stabilizing} (to overlap
+ * checkpoint with the stabilization wait)
+ *
+ * <li>On each {@link DefaultStateTransitionManager.Stabilizing#onChange} event (retry if
+ * a previous trigger was skipped)
+ *
+ * <li>On entering {@link DefaultStateTransitionManager.Stabilized} (fallback if no
+ * checkpoint completed during stabilization)
+ * </ul>
+ *
+ * The implementation decides whether to actually trigger based on its own guard
+ * conditions (e.g., checkpointing enabled, no checkpoint in progress, config flag).
+ * Multiple calls are safe; guards prevent redundant triggers.
+ */
+ default void requestActiveCheckpointTrigger() {}
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
index f593082f6c75b..b139474e0a4c7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/ExecutingTest.java
@@ -160,7 +160,8 @@ void testNoDeploymentCallOnEnterWhenVertexRunning() throws Exception {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
(context) -> TestingStateTransitionManager.withNoOp(),
- 1);
+ 1,
+ false);
assertThat(mockExecutionVertex.isDeployCalled()).isFalse();
}
}
@@ -186,7 +187,8 @@ void testIllegalStateExceptionOnNotRunningExecutionGraph() {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
context -> TestingStateTransitionManager.withNoOp(),
- 1);
+ 1,
+ false);
}
})
.isInstanceOf(IllegalStateException.class);
@@ -691,6 +693,7 @@ private final class ExecutingStateBuilder {
private Function<StateTransitionManager.Context, StateTransitionManager>
stateTransitionManagerFactory = context -> TestingStateTransitionManager.withNoOp();
private int rescaleOnFailedCheckpointCount = 1;
+ private boolean activeCheckpointTriggerEnabled = false;
private ExecutingStateBuilder() throws JobException, JobExecutionException {
operatorCoordinatorHandler = new TestingOperatorCoordinatorHandler();
@@ -733,7 +736,8 @@ private Executing build(MockExecutingContext ctx) {
ClassLoader.getSystemClassLoader(),
new ArrayList<>(),
stateTransitionManagerFactory::apply,
- rescaleOnFailedCheckpointCount);
+ rescaleOnFailedCheckpointCount,
+ activeCheckpointTriggerEnabled);
} finally {
Preconditions.checkState(
!ctx.hadStateTransition,
@@ -1029,6 +1033,12 @@ public boolean updateState(TaskExecutionStateTransition state) {
public Iterable<ExecutionJobVertex> getVerticesTopologically() {
return getVerticesTopologicallySupplier.get();
}
+
+ @Nullable
+ @Override
+ public CheckpointCoordinator getCheckpointCoordinator() {
+ return null;
+ }
}
private static class FinishingMockExecutionGraph extends StateTrackingMockExecutionGraph {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
index 6a24600f1ace1..e568e0aa826a8 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/RescaleOnCheckpointITCase.java
@@ -175,4 +175,63 @@ void testRescaleOnCheckpoint(
restClusterClient.cancel(jobGraph.getJobID()).join();
}
}
+
+ @Test
+ void testRescaleWithActiveCheckpointTrigger(
+ @InjectMiniCluster MiniCluster miniCluster,
+ @InjectClusterClient RestClusterClient<?> restClusterClient)
+ throws Exception {
+ final Configuration config = new Configuration();
+ config.set(
+ JobManagerOptions.SCHEDULER_RESCALE_TRIGGER_ACTIVE_CHECKPOINT_ENABLED, true);
+
+ final StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(BEFORE_RESCALE_PARALLELISM);
+ // Use a checkpoint interval far exceeding CI timeout (4h) so periodic checkpoints
+ // cannot cause a false-pass. Only the active checkpoint trigger can rescue this test.
+ env.enableCheckpointing(Duration.ofHours(24).toMillis());
+ env.fromSequence(0, Integer.MAX_VALUE).sinkTo(new DiscardingSink<>());
+
+ final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+ final Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
+ assertThat(jobVertexIterator.hasNext()).isTrue();
+ final JobVertexID jobVertexId = jobVertexIterator.next().getID();
+
+ final JobResourceRequirements jobResourceRequirements =
+ JobResourceRequirements.newBuilder()
+ .setParallelismForJobVertex(jobVertexId, 1, AFTER_RESCALE_PARALLELISM)
+ .build();
+
+ restClusterClient.submitJob(jobGraph).join();
+
+ final JobID jobId = jobGraph.getJobID();
+ try {
+ LOG.info(
+ "Waiting for job {} to reach parallelism of {} for vertex {}.",
+ jobId,
+ BEFORE_RESCALE_PARALLELISM,
+ jobVertexId);
+ waitForRunningTasks(restClusterClient, jobId, BEFORE_RESCALE_PARALLELISM);
+
+ LOG.info(
+ "Updating job {} resource requirements: parallelism {} -> {}.",
+ jobId,
+ BEFORE_RESCALE_PARALLELISM,
+ AFTER_RESCALE_PARALLELISM);
+ restClusterClient.updateJobResourceRequirements(jobId, jobResourceRequirements).join();
+ LOG.info(
+ "Waiting for job {} to rescale to parallelism {} via active checkpoint trigger.",
+ jobId,
+ AFTER_RESCALE_PARALLELISM);
+ waitForRunningTasks(restClusterClient, jobId, AFTER_RESCALE_PARALLELISM);
+ final int expectedFreeSlotCount = NUMBER_OF_SLOTS - AFTER_RESCALE_PARALLELISM;
+ LOG.info(
+ "Waiting for {} slot(s) to become available after scale down.",
+ expectedFreeSlotCount);
+ waitForAvailableSlots(restClusterClient, expectedFreeSlotCount);
+ } finally {
+ restClusterClient.cancel(jobGraph.getJobID()).join();
+ }
+ }
}