Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import json
from collections import defaultdict
from collections.abc import Iterator
from typing import TYPE_CHECKING, Annotated, Any, cast
from typing import TYPE_CHECKING, Annotated, Any, NoReturn, cast
from uuid import UUID

import attrs
Expand Down Expand Up @@ -852,6 +852,29 @@ def ti_skip_downstream(
log.info("Downstream tasks skipped", tasks_skipped=getattr(result, "rowcount", 0))


def _raise_ti_not_in_live_table(task_instance_id: UUID, session: SessionDep) -> NoReturn:
"""Raise 410 Gone if the missing TI id was archived to history, else 404 Not Found."""
if session.scalar(
select(func.count(TIH.task_instance_id)).where(TIH.task_instance_id == task_instance_id)
):
log.error("TaskInstance not in live table but archived in history", ti_id=str(task_instance_id))
raise HTTPException(
status_code=status.HTTP_410_GONE,
detail={
"reason": "not_found",
"message": "Task Instance not found, it may have been moved to the Task Instance History table",
},
)
log.error("Task Instance not found", ti_id=str(task_instance_id))
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": "Task Instance not found",
},
)


@ti_id_router.put(
"/{task_instance_id}/heartbeat",
status_code=status.HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -912,29 +935,7 @@ def ti_heartbeat(
# Check if the TI exists in the Task Instance History table.
# If it does, it was likely cleared while running, so return 410 Gone
# instead of 404 Not Found to give the client a more specific signal.
tih_exists = session.scalar(
select(func.count(TIH.task_instance_id)).where(TIH.task_instance_id == task_instance_id)
)
if tih_exists:
log.error(
"TaskInstance was previously cleared and archived in history, heartbeat skipped",
ti_id=str(task_instance_id),
)
raise HTTPException(
status_code=status.HTTP_410_GONE,
detail={
"reason": "not_found",
"message": "Task Instance not found, it may have been moved to the Task Instance History table",
},
)
log.error("Task Instance not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": "Task Instance not found",
},
)
_raise_ti_not_in_live_table(task_instance_id, session)

if hostname != ti_payload.hostname or pid != ti_payload.pid:
log.warning(
Expand Down Expand Up @@ -981,6 +982,10 @@ def ti_heartbeat(
responses=create_openapi_http_exception_doc(
[
(status.HTTP_404_NOT_FOUND, "Task Instance not found"),
(
status.HTTP_410_GONE,
"Task Instance not found in the TI table but exists in the Task Instance History table",
),
(
HTTP_422_UNPROCESSABLE_CONTENT,
"Invalid payload for the setting rendered task instance fields",
Expand All @@ -999,10 +1004,8 @@ def ti_put_rtif(

task_instance = session.scalar(select(TI).where(TI.id == task_instance_id))
if not task_instance:
log.error("Task Instance not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
)
# On retry/clear, the server regenerates the TI id. Return 410 for the stale id.
_raise_ti_not_in_live_table(task_instance_id, session)
task_instance.update_rtif(put_rtif_payload, session=session)
log.debug("RenderedTaskInstanceFields updated successfully")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2794,7 +2794,36 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance):
random_id = uuid6.uuid7()
response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload)
assert response.status_code == 404
assert response.json()["detail"] == "Not Found"
assert response.json()["detail"] == {
"reason": "not_found",
"message": "Task Instance not found",
}

def test_ti_put_rtif_archived_ti_returns_410(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_put_rtif_archived",
state=State.RUNNING,
session=session,
)
session.commit()
old_ti_id = ti.id

# Archive the current try to TIH and assign a new UUID, mirroring prepare_db_for_next_try().
ti.prepare_db_for_next_try(session)
session.commit()

assert session.get(TaskInstance, old_ti_id) is None

response = client.put(
f"/execution/task-instances/{old_ti_id}/rtif",
json={"field1": "rendered_value1"},
)

assert response.status_code == 410
assert response.json()["detail"] == {
"reason": "not_found",
"message": "Task Instance not found, it may have been moved to the Task Instance History table",
}


class TestPreviousDagRun:
Expand Down
8 changes: 7 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,13 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
elif isinstance(msg, PutVariable):
resp, dump_opts = handle_put_variable(self.client, msg)
elif isinstance(msg, SetRenderedFields):
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
try:
self.client.task_instances.set_rtif(self.id, msg.rendered_fields)
except ServerResponseError as e:
# The TI id was archived when the server regenerated it on retry/clear, so skip 410.
if e.response.status_code != HTTPStatus.GONE:
raise
log.debug("Skipping RTIF overwrite; task instance archived on retry/clear", ti_id=self.id)
elif isinstance(msg, SetRenderedMapIndex):
self.client.task_instances.set_rendered_map_index(self.id, msg.rendered_map_index)
elif isinstance(msg, GetAssetByName):
Expand Down
39 changes: 39 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3139,6 +3139,45 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker):
"detail": error.response.json(),
}

@pytest.mark.parametrize(
("status_code", "expects_error"),
[
pytest.param(410, False, id="410_gone_is_swallowed"),
pytest.param(404, True, id="404_not_found_propagates"),
],
)
def test_set_rendered_fields_swallows_410_but_propagates_404(
self, watched_subprocess, mocker, status_code, expects_error
):
"""A stale-id RTIF overwrite (410) is skipped silently; a bogus-id (404) still propagates as an error."""
watched_subprocess, read_socket = watched_subprocess

error = ServerResponseError(
message="boom",
request=httpx.Request("PUT", "http://test"),
response=httpx.Response(status_code, json={"detail": "boom"}),
)
watched_subprocess.client.task_instances.set_rtif = mocker.Mock(side_effect=error)

generator = watched_subprocess.handle_requests(log=mocker.Mock())
next(generator)

msg = SetRenderedFields(rendered_fields={"field1": "v1"})
req_frame = _RequestFrame(id=randint(1, 2**32 - 1), body=msg.model_dump())
generator.send(req_frame)

read_socket.settimeout(0.1)
frame_len = int.from_bytes(read_socket.recv(4), "big")
frame = msgspec.msgpack.Decoder(_ResponseFrame).decode(read_socket.recv(frame_len))

assert frame.id == req_frame.id
if expects_error:
assert frame.error is not None
assert frame.error["error"] == "API_SERVER_ERROR"
assert frame.error["detail"]["status_code"] == status_code
else:
assert frame.error is None

def test_handle_requests_network_exception_does_not_crash_loop(self, watched_subprocess, mocker):
"""A transient network error must not crash the IPC generator.

Expand Down
Loading