From c81b4a7bf84538aaaa4c146851534698188f5ee4 Mon Sep 17 00:00:00 2001 From: PoAn Yang Date: Wed, 24 Jun 2026 17:17:02 +0900 Subject: [PATCH] Return 410 for stale id for set rtif API When a task whose operator sets overwrite_rtif_after_execution=True leaves RUNNING via a retry or a clear, the server regenerates the task instance id and archives the old one. finalize() still overwrites RTIF with the stale id, so the API server returns 404 and the worker logs a spurious error traceback on top of the task's real outcome. The RTIF will be wrote in next id. Signed-off-by: PoAn Yang --- .../execution_api/routes/task_instances.py | 59 ++++++++++--------- .../versions/head/test_task_instances.py | 31 +++++++++- .../airflow/sdk/execution_time/supervisor.py | 8 ++- .../execution_time/test_supervisor.py | 39 ++++++++++++ 4 files changed, 107 insertions(+), 30 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 0f862e1625832..e94ca7b05848d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -22,7 +22,7 @@ import json from collections import defaultdict from collections.abc import Iterator -from typing import TYPE_CHECKING, Annotated, Any, cast +from typing import TYPE_CHECKING, Annotated, Any, NoReturn, cast from uuid import UUID import attrs @@ -852,6 +852,29 @@ def ti_skip_downstream( log.info("Downstream tasks skipped", tasks_skipped=getattr(result, "rowcount", 0)) +def _raise_ti_not_in_live_table(task_instance_id: UUID, session: SessionDep) -> NoReturn: + """Raise 410 Gone if the missing TI id was archived to history, else 404 Not Found.""" + if session.scalar( + select(func.count(TIH.task_instance_id)).where(TIH.task_instance_id == task_instance_id) + ): + log.error("TaskInstance not in live table but archived in history", ti_id=str(task_instance_id)) + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail={ + "reason": "not_found", + "message": "Task Instance not found, it may have been moved to the Task Instance History table", + }, + ) + log.error("Task Instance not found", ti_id=str(task_instance_id)) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": "Task Instance not found", + }, + ) + + @ti_id_router.put( "/{task_instance_id}/heartbeat", status_code=status.HTTP_204_NO_CONTENT, @@ -912,29 +935,7 @@ def ti_heartbeat( # Check if the TI exists in the Task Instance History table. # If it does, it was likely cleared while running, so return 410 Gone # instead of 404 Not Found to give the client a more specific signal. - tih_exists = session.scalar( - select(func.count(TIH.task_instance_id)).where(TIH.task_instance_id == task_instance_id) - ) - if tih_exists: - log.error( - "TaskInstance was previously cleared and archived in history, heartbeat skipped", - ti_id=str(task_instance_id), - ) - raise HTTPException( - status_code=status.HTTP_410_GONE, - detail={ - "reason": "not_found", - "message": "Task Instance not found, it may have been moved to the Task Instance History table", - }, - ) - log.error("Task Instance not found") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={ - "reason": "not_found", - "message": "Task Instance not found", - }, - ) + _raise_ti_not_in_live_table(task_instance_id, session) if hostname != ti_payload.hostname or pid != ti_payload.pid: log.warning( @@ -981,6 +982,10 @@ def ti_heartbeat( responses=create_openapi_http_exception_doc( [ (status.HTTP_404_NOT_FOUND, "Task Instance not found"), + ( + status.HTTP_410_GONE, + "Task Instance not found in the TI table but exists in the Task Instance History table", + ), ( HTTP_422_UNPROCESSABLE_CONTENT, "Invalid payload for the setting rendered task instance fields", @@ -999,10 +1004,8 @@ def ti_put_rtif( task_instance = session.scalar(select(TI).where(TI.id == task_instance_id)) if not task_instance: - log.error("Task Instance not found") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - ) + # On retry/clear, the server regenerates the TI id. Return 410 for the stale id. + _raise_ti_not_in_live_table(task_instance_id, session) task_instance.update_rtif(put_rtif_payload, session=session) log.debug("RenderedTaskInstanceFields updated successfully") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 90325bb9c8542..c5ebdfc55654d 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -2794,7 +2794,36 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): random_id = uuid6.uuid7() response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload) assert response.status_code == 404 - assert response.json()["detail"] == "Not Found" + assert response.json()["detail"] == { + "reason": "not_found", + "message": "Task Instance not found", + } + + def test_ti_put_rtif_archived_ti_returns_410(self, client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_put_rtif_archived", + state=State.RUNNING, + session=session, + ) + session.commit() + old_ti_id = ti.id + + # Archive the current try to TIH and assign a new UUID, mirroring prepare_db_for_next_try(). + ti.prepare_db_for_next_try(session) + session.commit() + + assert session.get(TaskInstance, old_ti_id) is None + + response = client.put( + f"/execution/task-instances/{old_ti_id}/rtif", + json={"field1": "rendered_value1"}, + ) + + assert response.status_code == 410 + assert response.json()["detail"] == { + "reason": "not_found", + "message": "Task Instance not found, it may have been moved to the Task Instance History table", + } class TestPreviousDagRun: diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 76dfd2009ac18..3c3a18bbbef44 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1729,7 +1729,13 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, PutVariable): resp, dump_opts = handle_put_variable(self.client, msg) elif isinstance(msg, SetRenderedFields): - self.client.task_instances.set_rtif(self.id, msg.rendered_fields) + try: + self.client.task_instances.set_rtif(self.id, msg.rendered_fields) + except ServerResponseError as e: + # The TI id was archived when the server regenerated it on retry/clear, so skip 410. + if e.response.status_code != HTTPStatus.GONE: + raise + log.debug("Skipping RTIF overwrite; task instance archived on retry/clear", ti_id=self.id) elif isinstance(msg, SetRenderedMapIndex): self.client.task_instances.set_rendered_map_index(self.id, msg.rendered_map_index) elif isinstance(msg, GetAssetByName): diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 8dea3d0793f49..81b60c6a7d8a3 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -3139,6 +3139,45 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): "detail": error.response.json(), } + @pytest.mark.parametrize( + ("status_code", "expects_error"), + [ + pytest.param(410, False, id="410_gone_is_swallowed"), + pytest.param(404, True, id="404_not_found_propagates"), + ], + ) + def test_set_rendered_fields_swallows_410_but_propagates_404( + self, watched_subprocess, mocker, status_code, expects_error + ): + """A stale-id RTIF overwrite (410) is skipped silently; a bogus-id (404) still propagates as an error.""" + watched_subprocess, read_socket = watched_subprocess + + error = ServerResponseError( + message="boom", + request=httpx.Request("PUT", "http://test"), + response=httpx.Response(status_code, json={"detail": "boom"}), + ) + watched_subprocess.client.task_instances.set_rtif = mocker.Mock(side_effect=error) + + generator = watched_subprocess.handle_requests(log=mocker.Mock()) + next(generator) + + msg = SetRenderedFields(rendered_fields={"field1": "v1"}) + req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump()) + generator.send(req_frame) + + read_socket.settimeout(0.1) + frame_len = int.from_bytes(read_socket.recv(4), "big") + frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(read_socket.recv(frame_len)) + + assert frame.id == req_frame.id + if expects_error: + assert frame.error is not None + assert frame.error["error"] == "API_SERVER_ERROR" + assert frame.error["detail"]["status_code"] == status_code + else: + assert frame.error is None + def test_handle_requests_network_exception_does_not_crash_loop(self, watched_subprocess, mocker): """A transient network error must not crash the IPC generator.