From f3bd4aae0da6ad23a2b06500b6c9354dec140a77 Mon Sep 17 00:00:00 2001 From: Hemkumar Chheda Date: Thu, 4 Jun 2026 15:23:21 +0530 Subject: [PATCH] Fix rescheduled sensors hanging before poke closes: #68010 --- .../execution_api/datamodels/taskinstance.py | 3 + .../execution_api/routes/task_instances.py | 10 +++ .../execution_api/versions/__init__.py | 2 + .../execution_api/versions/v2026_06_30.py | 11 +++ .../versions/head/test_task_instances.py | 54 ++++++++++++ .../v2026_06_30/test_task_instances.py | 88 +++++++++++++++++++ .../airflow/sdk/api/datamodels/_generated.py | 3 + .../sdk/execution_time/schema/schema.json | 13 +++ .../schema/versions/__init__.py | 1 + .../airflow/sdk/execution_time/task_runner.py | 6 ++ task-sdk/tests/conftest.py | 6 ++ .../execution_time/test_task_runner.py | 18 +++- 12 files changed, 214 insertions(+), 1 deletion(-) create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index fe7f5a5ce050d..2383d275f6b52 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -400,6 +400,9 @@ class TIRunContext(BaseModel): task_reschedule_count: int = 0 """How many times the task has been rescheduled.""" + first_task_reschedule_start_date: UtcDateTime | None = None + """The first reschedule start date for the task instance, if it has been rescheduled.""" + max_tries: int """Maximum number of tries for the task instance (from DB).""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 8a91614e09aeb..dd28ebe467254 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -295,6 +295,14 @@ def ti_run( ) or 0 ) + first_task_reschedule_start_date = None + if task_reschedule_count > 0: + first_task_reschedule_start_date = session.scalar( + select(TaskReschedule.start_date) + .where(TaskReschedule.ti_id == task_instance_id) + .order_by(TaskReschedule.id.asc()) + .limit(1) + ) dr.team_name = get_team_name_for_ti(task_instance_id, session) @@ -308,6 +316,8 @@ def ti_run( xcom_keys_to_clear=xcom_keys, should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), ) + if first_task_reschedule_start_date is not None: + context.first_task_reschedule_start_date = first_task_reschedule_start_date # Only set if they are non-null if ti.next_method: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 656bb8dce107f..178c5d7636188 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -49,6 +49,7 @@ from airflow.api_fastapi.execution_api.versions.v2026_06_30 import ( AddAwaitingInputStatePayload, AddConnectionTestEndpoint, + AddFirstTaskRescheduleStartDateField, AddTaskInstanceQueueField, AddVariableKeysEndpoint, ) @@ -57,6 +58,7 @@ HeadVersion(), Version( "2026-06-30", + AddFirstTaskRescheduleStartDateField, AddVariableKeysEndpoint, AddConnectionTestEndpoint, AddAwaitingInputStatePayload, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index cfa5b616396bf..01595a30bab15 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -22,6 +22,7 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( TaskInstance, TIAwaitingInputStatePayload, + TIRunContext, ) @@ -44,6 +45,16 @@ class AddConnectionTestEndpoint(VersionChange): ) +class AddFirstTaskRescheduleStartDateField(VersionChange): + """Add first_task_reschedule_start_date field to TIRunContext.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(TIRunContext).field("first_task_reschedule_start_date").didnt_exist, + ) + + class AddTaskInstanceQueueField(VersionChange): """Add the `queue` field to the TaskInstance model.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 4d21bb406f9d3..9d18b859d6507 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -289,6 +289,60 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_ti_run_state_includes_first_task_reschedule_start_date( + self, + client, + session, + create_task_instance, + ): + """Test that running a rescheduled Task Instance includes its first reschedule start date.""" + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + first_reschedule_start_date = timezone.datetime(2024, 9, 30, 10) + second_reschedule_start_date = timezone.datetime(2024, 9, 30, 11) + + ti = create_task_instance( + task_id="test_ti_run_state_includes_first_task_reschedule_start_date", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.add_all( + [ + TaskReschedule( + ti_id=ti.id, + start_date=first_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ), + TaskReschedule( + ti_id=ti.id, + start_date=second_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 11, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 11, 2), + ), + ] + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 2 + assert result["first_task_reschedule_start_date"] == "2024-09-30T10:00:00Z" + def test_ti_run_returns_execution_token( self, client, exec_app, session, create_task_instance, time_machine ): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py new file mode 100644 index 0000000000000..7304d0e879ed2 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from airflow._shared.timezones import timezone +from airflow.models import TaskReschedule +from airflow.utils.state import DagRunState, State + +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Last released execution API before first_task_reschedule_start_date was added.""" + client.headers["Airflow-API-Version"] = "2026-06-16" + return client + + +@pytest.fixture(autouse=True) +def setup_teardown(): + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + yield + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + +def test_first_task_reschedule_start_date_removed_from_previous_version( + old_ver_client, + session, + create_task_instance, +): + ti = create_task_instance( + task_id="test_first_task_reschedule_start_date_removed_from_previous_version", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=timezone.datetime(2024, 9, 30, 12), + dag_id=str(uuid4()), + ) + session.add( + TaskReschedule( + ti_id=ti.id, + start_date=timezone.datetime(2024, 9, 30, 10), + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ) + ) + session.commit() + + response = old_ver_client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 1 + assert "first_task_reschedule_start_date" not in result diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 99aedae25b03b..d6d4f259f8131 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -786,6 +786,9 @@ class TIRunContext(BaseModel): dag_run: DagRun task_reschedule_count: Annotated[int | None, Field(title="Task Reschedule Count")] = 0 + first_task_reschedule_start_date: Annotated[ + AwareDatetime | None, Field(title="First Task Reschedule Start Date") + ] = None max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 0d1a3dc61c703..d7f117d43f648 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -4693,6 +4693,19 @@ "title": "Task Reschedule Count", "type": "integer" }, + "first_task_reschedule_start_date": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "First Task Reschedule Start Date" + }, "max_tries": { "title": "Max Tries", "type": "integer" diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py b/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py index d8ae4e4580ed5..1ee7ac09925b1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py +++ b/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py @@ -21,5 +21,6 @@ bundle = VersionBundle( HeadVersion(), + # First supervisor schema version; there is no previous version to migrate from. Version("2026-06-16"), ) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 26d1268df020a..f652f1546b72e 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -585,6 +585,12 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: # If the task has not been rescheduled, there is no need to ask the supervisor return None + first_task_reschedule_start_date = getattr( + self._ti_context_from_server, "first_task_reschedule_start_date", None + ) + if first_task_reschedule_start_date is not None: + return first_task_reschedule_start_date + max_tries: int = self.max_tries retries: int = self.task.retries or 0 first_try_number = max_tries - retries + 1 diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index c1ef3b72c92f4..321dbc7c349de 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -229,6 +229,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf: dict[str, Any] | None = ..., should_retry: bool = ..., max_tries: int = ..., @@ -249,6 +250,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = ..., ) -> dict[str, Any]: ... @@ -271,6 +273,7 @@ def _make_context( run_after: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf: dict[str, Any] | None = None, should_retry: bool = False, max_tries: int = 0, @@ -292,6 +295,7 @@ def _make_context( consumed_asset_events=list(consumed_asset_events), ), task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, max_tries=max_tries, should_retry=should_retry, ) @@ -314,6 +318,7 @@ def _make_context_dict( run_after: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = (), ) -> dict[str, Any]: @@ -329,6 +334,7 @@ def _make_context_dict( run_type=run_type, conf=conf, task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, consumed_asset_events=consumed_asset_events, ) return context.model_dump(exclude_unset=True, mode="json") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 48b5139046963..43aaf309c5da3 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2695,7 +2695,7 @@ def __init__(self, command, *args, **kwargs): def test_get_first_reschedule_date( self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count, expected_date ): - """Test that the first reschedule date is fetched from the Supervisor.""" + """Test that the first reschedule date falls back to the Supervisor.""" task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count) @@ -2706,6 +2706,22 @@ def test_get_first_reschedule_date( context = runtime_ti.get_template_context() assert runtime_ti.get_first_reschedule_date(context=context) == expected_date + def test_get_first_reschedule_date_uses_context_from_server( + self, create_runtime_ti, make_ti_context, mock_supervisor_comms + ): + """Test that first reschedule date from server context avoids a Supervisor request.""" + first_reschedule_date = timezone.datetime(2025, 1, 1) + task = BaseOperator(task_id="hello") + runtime_ti = create_runtime_ti(task=task, task_reschedule_count=1) + runtime_ti._ti_context_from_server = make_ti_context( + task_reschedule_count=1, + first_task_reschedule_start_date=first_reschedule_date, + ) + + context = runtime_ti.get_template_context() + assert runtime_ti.get_first_reschedule_date(context=context) == first_reschedule_date + mock_supervisor_comms.send.assert_not_called() + def test_get_ti_count(self, mock_supervisor_comms): """Test that get_ti_count sends the correct request and returns the count.""" mock_supervisor_comms.send.return_value = TICount(count=2)