Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from airflow.sdk.definitions.deadline import DeadlineAlert
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.operator_resources import Resources
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.definitions.param import DagParam, Param, ParamsDict
from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.sdk.definitions.xcom_arg import serialize_xcom_arg
from airflow.sdk.execution_time.context import OutletEventAccessor, OutletEventAccessors
Expand Down Expand Up @@ -579,6 +579,10 @@ def serialize(
return TaskGroupSerialization.serialize_task_group(var)
elif isinstance(var, Param):
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, DagParam):
return cls._encode(
cls._serialize_param(Param(default=var._default, source="dag")), type_=DAT.PARAM
)
Comment on lines +582 to +585

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested this? Looking at the code, it seems only _default is stored and _name is dropped, and since DagParam resolves at runtime via dag_run.conf[self._name], I think this could be a problem. Or am I misreading the flow?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I don't think that's the correct thing to do. For op_args, the serialized representation is used after the DagParam has gone through serialize_template_field. For partial_kwargs this path is not taken, so what I did here will not work properly at runtime; I got confused while debugging in the heat of the moment.
I don't really know what a fix for this should look like. I think either serialize_mapped_operator needs to be adapted, possibly to use serialize_template_field for the partial_kwargs ops, or a dedicated serialization class is needed for DagParam, possibly similar to XComArg.
If somebody is willing to point me in a direction, I am also open to contribute a fix.
I'm closing this PR and will open a bug report instead.
We found a workaround using Jinja Templates for DagParams in partial for now.


add.partial(value="{{ params.p }}").expand(value=[1, 2, 3])

Sorry for the confusion.

elif isinstance(var, XComArg):
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
elif isinstance(var, LazySelectSequence):
Expand Down
41 changes: 41 additions & 0 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4768,6 +4768,47 @@ def test_partial_kwargs_end_to_end_deserialization(self):
assert deserialized_task.partial_kwargs["retry_delay"] == timedelta(seconds=600)
assert deserialized_task.partial_kwargs["owner"] == "custom_owner"

def test_partial_kwargs_dag_param_serialization_is_stable(self):
"""DagParam passed to a mapped task's partial() must serialize to a stable representation.

Without dedicated handling the serializer falls back to ``str(var)`` which embeds the
object's memory address, producing a different serialized Dag on every parse.
"""
from airflow.sdk import task

# Build a Dag whose only task is a mapped task that receives a DagParam via partial().
with DAG(dag_id="test_dag_param_partial") as dag:

@task
def add(value):
return value

add.partial(value=dag.param("p", "default_value")).expand(value=[1, 2, 3])

# Pull the serialized ``value`` kwarg out of the first task's partial_kwargs.
serialized_dag = DagSerialization.to_dict(dag)
mapped_task = serialized_dag["dag"]["tasks"][0]["__var"]
serialized_value = mapped_task["partial_kwargs"]["op_kwargs"]["__var"]["value"]

# The DagParam must encode to its stable structure, not a repr with a memory address.
assert serialized_value["__type"] == "param"
assert serialized_value["__var"] == {
"__class": "airflow.sdk.definitions.param.Param",
"description": None,
"source": "dag",
"default": "default_value",
"schema": {"__var": {}, "__type": "dict"},
}
# NOTE(review): the expected payload above carries only the default value; the param
# name "p" is not asserted anywhere in this test -- confirm the name is not needed
# when the value is resolved at runtime.

# Serializing the same Dag twice must produce identical output.
assert DagSerialization.to_dict(dag) == serialized_dag

# The round-trip yields a SerializedParam (not a DagParam) that carries the default
# value and the "dag" source.
deserialized_dag = DagSerialization.from_dict(serialized_dag)
deserialized_task = deserialized_dag.get_task("add")
restored_param = deserialized_task.partial_kwargs["op_kwargs"]["value"]
assert isinstance(restored_param, SerializedParam)
assert restored_param.value == "default_value"
assert restored_param.source == "dag"


@pytest.mark.parametrize(
("callbacks", "expected_has_flags", "absent_keys"),
Expand Down
Loading