Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import copy
import json
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
Expand Down Expand Up @@ -270,27 +269,24 @@ def create_xcom_entry(
)

try:
value = json.dumps(request_body.value)
except (ValueError, TypeError):
XComModel.set(
key=request_body.key,
value=request_body.value,
dag_id=dag_id,
task_id=task_id,
run_id=dag_run_id,
map_index=request_body.map_index,
serialize=False,
session=session,
)
except (ValueError, TypeError) as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{request_body.key}`"
)

new = XComModel(
dag_run_id=dag_run.id,
key=request_body.key,
value=value,
run_id=dag_run_id,
task_id=task_id,
dag_id=dag_id,
map_index=request_body.map_index,
)
session.add(new)
session.flush()
) from e

xcom = session.scalar(
select(XComModel)
.filter(
.where(
XComModel.dag_id == dag_id,
XComModel.task_id == task_id,
XComModel.run_id == dag_run_id,
Expand Down Expand Up @@ -324,11 +320,12 @@ def update_xcom_entry(
dag_run_id: str,
xcom_key: str,
patch_body: XComUpdateBody,
*,
session: SessionDep,
) -> XComResponseNative:
"""Update an existing XCom entry."""
# Check if XCom entry exists
xcom_entry = session.scalar(
xcom_query = (
select(XComModel)
.where(
XComModel.dag_id == dag_id,
Expand All @@ -340,16 +337,32 @@ def update_xcom_entry(
.limit(1)
.options(joinedload(XComModel.task), joinedload(XComModel.dag_run).joinedload(DR.dag_model))
)
xcom_entry = session.scalar(xcom_query)

if not xcom_entry:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"The XCom with key: `{xcom_key}` with mentioned task instance doesn't exist.",
)

# Update XCom entry
xcom_entry.value = json.dumps(patch_body.value)
try:
XComModel.set(
key=xcom_key,
value=patch_body.value,
dag_id=dag_id,
task_id=task_id,
run_id=dag_run_id,
map_index=patch_body.map_index,
serialize=False,
session=session,
)
except (ValueError, TypeError) as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, f"Couldn't serialise the XCom with key: `{xcom_key}`"
) from e

# Re-fetch after set (delete + insert) to get fresh object for response
xcom_entry = session.scalar(xcom_query)
return XComResponseNative.model_validate(xcom_entry)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def test_create_xcom_entry(
# Validate the created XCom response
current_data = response.json()
assert current_data["key"] == request_body.key
assert current_data["value"] == XComModel.serialize_value(request_body.value)
assert current_data["value"] == request_body.value
assert current_data["dag_id"] == dag_id
assert current_data["task_id"] == task_id
assert current_data["run_id"] == dag_run_id
Expand Down Expand Up @@ -716,7 +716,7 @@ def test_create_xcom_entry_with_slash_key(self, test_client):
)
assert get_resp.status_code == 200
assert get_resp.json()["key"] == slash_key
assert get_resp.json()["value"] == json.dumps(TEST_XCOM_VALUE)
assert get_resp.json()["value"] == TEST_XCOM_VALUE


class TestDeleteXComEntry(TestXComEndpoint):
Expand Down Expand Up @@ -814,7 +814,7 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai
assert response.status_code == expected_status

if expected_status == 200:
assert response.json()["value"] == json.dumps(patch_body["value"])
assert response.json()["value"] == patch_body["value"]
else:
assert response.json()["detail"] == expected_detail
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Expand Down Expand Up @@ -843,5 +843,5 @@ def test_patch_xcom_entry_with_slash_key(self, test_client, session):
)
assert response.status_code == 200
assert response.json()["key"] == slash_key
assert response.json()["value"] == json.dumps(new_value)
assert response.json()["value"] == new_value
check_last_log(session, dag_id=TEST_DAG_ID, event="update_xcom_entry", logical_date=None)
Loading