diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index fe7f5a5ce050d..2383d275f6b52 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -400,6 +400,9 @@ class TIRunContext(BaseModel): task_reschedule_count: int = 0 """How many times the task has been rescheduled.""" + first_task_reschedule_start_date: UtcDateTime | None = None + """The first reschedule start date for the task instance, if it has been rescheduled.""" + max_tries: int """Maximum number of tries for the task instance (from DB).""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index ae71e41cf8e5f..ec98197c276d5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -296,6 +296,14 @@ def ti_run( ) or 0 ) + first_task_reschedule_start_date = None + if task_reschedule_count > 0: + first_task_reschedule_start_date = session.scalar( + select(TaskReschedule.start_date) + .where(TaskReschedule.ti_id == task_instance_id) + .order_by(TaskReschedule.id.asc()) + .limit(1) + ) dr.team_name = get_team_name_for_ti(task_instance_id, session) @@ -309,6 +317,8 @@ def ti_run( xcom_keys_to_clear=xcom_keys, should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), ) + if first_task_reschedule_start_date is not None: + context.first_task_reschedule_start_date = first_task_reschedule_start_date # Only set if they are non-null if ti.next_method: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 332ddb28704ed..1b6409d1f37d4 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -44,6 +44,7 @@ AddAssetsByAliasEndpoint, AddAwaitingInputStatePayload, AddConnectionTestEndpoint, + AddFirstTaskRescheduleStartDateField, AddRetryPolicyFields, AddTaskAndAssetStateStoreEndpoints, AddTaskInstanceQueueField, @@ -55,6 +56,7 @@ HeadVersion(), Version( "2026-06-30", + AddFirstTaskRescheduleStartDateField, AddVariableKeysEndpoint, AddConnectionTestEndpoint, AddAwaitingInputStatePayload, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index cbd801c0a9b0b..3c2dc26fbeb01 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -53,6 +53,16 @@ class AddConnectionTestEndpoint(VersionChange): ) +class AddFirstTaskRescheduleStartDateField(VersionChange): + """Add first_task_reschedule_start_date field to TIRunContext.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = ( + schema(TIRunContext).field("first_task_reschedule_start_date").didnt_exist, + ) + + class AddTaskInstanceQueueField(VersionChange): """Add the `queue` field to the TaskInstance model.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 395837c0e61ed..b77f890f79ee1 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -289,6 +289,60 @@ def test_ti_run_state_to_running( ) assert response.status_code == 409 + def test_ti_run_state_includes_first_task_reschedule_start_date( + self, + client, + session, + create_task_instance, + ): + """Test that running a rescheduled Task Instance includes its first reschedule start date.""" + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + first_reschedule_start_date = timezone.datetime(2024, 9, 30, 10) + second_reschedule_start_date = timezone.datetime(2024, 9, 30, 11) + + ti = create_task_instance( + task_id="test_ti_run_state_includes_first_task_reschedule_start_date", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.add_all( + [ + TaskReschedule( + ti_id=ti.id, + start_date=first_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ), + TaskReschedule( + ti_id=ti.id, + start_date=second_reschedule_start_date, + end_date=timezone.datetime(2024, 9, 30, 11, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 11, 2), + ), + ] + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 2 + assert result["first_task_reschedule_start_date"] == "2024-09-30T10:00:00Z" + def test_ti_run_returns_execution_token( self, client, exec_app, session, create_task_instance, time_machine ): diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py new file mode 100644 index 0000000000000..7304d0e879ed2 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_task_instances.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from airflow._shared.timezones import timezone +from airflow.models import TaskReschedule +from airflow.utils.state import DagRunState, State + +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Last released execution API before first_task_reschedule_start_date was added.""" + client.headers["Airflow-API-Version"] = "2026-06-16" + return client + + +@pytest.fixture(autouse=True) +def setup_teardown(): + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + yield + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + +def test_first_task_reschedule_start_date_removed_from_previous_version( + old_ver_client, + session, + create_task_instance, +): + ti = create_task_instance( + task_id="test_first_task_reschedule_start_date_removed_from_previous_version", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=timezone.datetime(2024, 9, 30, 12), + dag_id=str(uuid4()), + ) + session.add( + TaskReschedule( + ti_id=ti.id, + start_date=timezone.datetime(2024, 9, 30, 10), + end_date=timezone.datetime(2024, 9, 30, 10, 1), + reschedule_date=timezone.datetime(2024, 9, 30, 10, 2), + ) + ) + session.commit() + + response = old_ver_client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-09-30T12:00:00Z", + }, + ) + + assert response.status_code == 200 + result = response.json() + assert result["task_reschedule_count"] == 1 + assert "first_task_reschedule_start_date" not in result diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 812ca4d041833..1d512ebf49f14 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -786,6 +786,9 @@ class TIRunContext(BaseModel): dag_run: DagRun task_reschedule_count: Annotated[int | None, Field(title="Task Reschedule Count")] = 0 + first_task_reschedule_start_date: Annotated[ + AwareDatetime | None, Field(title="First Task Reschedule Start Date") + ] = None max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json index 4807d9f53b353..0f903b970fbf6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json +++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json @@ -4693,6 +4693,19 @@ "title": "Task Reschedule Count", "type": "integer" }, + "first_task_reschedule_start_date": { + "anyOf": [ + { + "format": "date-time", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "First Task Reschedule Start Date" + }, "max_tries": { "title": "Max Tries", "type": "integer" diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py b/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py index d8ae4e4580ed5..1ee7ac09925b1 100644 --- a/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py +++ b/task-sdk/src/airflow/sdk/execution_time/schema/versions/__init__.py @@ -21,5 +21,6 @@ bundle = VersionBundle( HeadVersion(), + # First supervisor schema version; there is no previous version to migrate from. Version("2026-06-16"), ) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 7eb52038ce9e3..c633b342b21ef 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -585,6 +585,12 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: # If the task has not been rescheduled, there is no need to ask the supervisor return None + first_task_reschedule_start_date = getattr( + self._ti_context_from_server, "first_task_reschedule_start_date", None + ) + if first_task_reschedule_start_date is not None: + return first_task_reschedule_start_date + max_tries: int = self.max_tries retries: int = self.task.retries or 0 first_try_number = max_tries - retries + 1 diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index c1ef3b72c92f4..321dbc7c349de 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -229,6 +229,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf: dict[str, Any] | None = ..., should_retry: bool = ..., max_tries: int = ..., @@ -249,6 +250,7 @@ def __call__( run_after: str | datetime = ..., run_type: str = ..., task_reschedule_count: int = ..., + first_task_reschedule_start_date: str | datetime | None = ..., conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = ..., ) -> dict[str, Any]: ... @@ -271,6 +273,7 @@ def _make_context( run_after: str | datetime = "2024-12-01T01:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf: dict[str, Any] | None = None, should_retry: bool = False, max_tries: int = 0, @@ -292,6 +295,7 @@ def _make_context( consumed_asset_events=list(consumed_asset_events), ), task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, max_tries=max_tries, should_retry=should_retry, ) @@ -314,6 +318,7 @@ def _make_context_dict( run_after: str | datetime = "2024-12-01T00:00:00Z", run_type: str = "manual", task_reschedule_count: int = 0, + first_task_reschedule_start_date: str | datetime | None = None, conf=None, consumed_asset_events: Sequence[AssetEventDagRunReference] = (), ) -> dict[str, Any]: @@ -329,6 +334,7 @@ def _make_context_dict( run_type=run_type, conf=conf, task_reschedule_count=task_reschedule_count, + first_task_reschedule_start_date=first_task_reschedule_start_date, consumed_asset_events=consumed_asset_events, ) return context.model_dump(exclude_unset=True, mode="json") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 8c1312833e80a..fdb91bf675cef 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2695,7 +2695,7 @@ def __init__(self, command, *args, **kwargs): def test_get_first_reschedule_date( self, create_runtime_ti, mock_supervisor_comms, task_reschedule_count, expected_date ): - """Test that the first reschedule date is fetched from the Supervisor.""" + """Test that the first reschedule date falls back to the Supervisor.""" task = BaseOperator(task_id="hello") runtime_ti = create_runtime_ti(task=task, task_reschedule_count=task_reschedule_count) @@ -2706,6 +2706,22 @@ def test_get_first_reschedule_date( context = runtime_ti.get_template_context() assert runtime_ti.get_first_reschedule_date(context=context) == expected_date + def test_get_first_reschedule_date_uses_context_from_server( + self, create_runtime_ti, make_ti_context, mock_supervisor_comms + ): + """Test that first reschedule date from server context avoids a Supervisor request.""" + first_reschedule_date = timezone.datetime(2025, 1, 1) + task = BaseOperator(task_id="hello") + runtime_ti = create_runtime_ti(task=task, task_reschedule_count=1) + runtime_ti._ti_context_from_server = make_ti_context( + task_reschedule_count=1, + first_task_reschedule_start_date=first_reschedule_date, + ) + + context = runtime_ti.get_template_context() + assert runtime_ti.get_first_reschedule_date(context=context) == first_reschedule_date + mock_supervisor_comms.send.assert_not_called() + def test_get_ti_count(self, mock_supervisor_comms): """Test that get_ti_count sends the correct request and returns the count.""" mock_supervisor_comms.send.return_value = TICount(count=2)