Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -409,5 +409,14 @@ def __exit__(

for root_task in roots:
root_task.set_upstream(create_databricks_workflow_task)

# When ``>>`` is called before the ``with`` block, the dependency is
# recorded on the task group but the launch/leaf tasks don't exist yet.
# Transfer those task-group-level dependencies now that the tasks are created.
for upstream_id in self.upstream_task_ids:
create_databricks_workflow_task.set_upstream(self.dag.get_task(upstream_id))
for downstream_id in self.downstream_task_ids:
for leaf_task in self.get_leaves():
leaf_task.set_downstream(self.dag.get_task(downstream_id))
finally:
super().__exit__(_type, _value, _tb)
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,79 @@ def test_reset_job_payload_carries_parent_depends_on(self, mock_databricks_hook)
job_id, job_spec = launch_task._hook.reset_job.call_args.args
assert job_id == 42
self._assert_parent_depends_on(job_spec)


class TestTaskGroupExternalDependencyTransfer:
    """Check that ``>>`` applied to the workflow task group reaches its real tasks.

    A dependency declared on the group *before* its ``with`` block is entered
    cannot be attached to the launch task at that moment, because the launch
    task does not exist yet. Each scenario is therefore exercised twice, with
    ``>>`` placed before and after the ``with`` block:

    * an upstream task must end up upstream of ``tg.launch``;
    * a downstream task must end up downstream of the notebook task ``tg.nb``.
    """

    JOB_CLUSTERS = [
        {
            "job_cluster_key": "c",
            "new_cluster": {
                "spark_version": "15.4.x-scala2.12",
                "num_workers": 0,
                "node_type_id": "i3.xlarge",
            },
        }
    ]

    def _workflow_group(self):
        """Build the ``tg`` workflow task group used by every test.

        Must be called while a ``DAG`` context is active so the group is
        registered on that DAG.
        """
        return DatabricksWorkflowTaskGroup(
            group_id="tg",
            databricks_conn_id="databricks_conn",
            job_clusters=self.JOB_CLUSTERS,
        )

    @staticmethod
    def _add_notebook():
        """Create the single notebook task ``nb``.

        Must be called inside the task group's ``with`` block so the task is
        added to that group.
        """
        DatabricksNotebookOperator(
            task_id="nb",
            notebook_path="/t",
            source="WORKSPACE",
            job_cluster_key="c",
        )

    def test_upstream_set_before_with_block(self):
        with DAG(dag_id="test_up_before", start_date=DEFAULT_DATE, schedule=None) as dag:
            start = EmptyOperator(task_id="start")
            group = self._workflow_group()
            # Dependency is declared while the group is still empty.
            start >> group
            with group:
                self._add_notebook()

        assert "start" in dag.get_task("tg.launch").upstream_task_ids

    def test_upstream_set_after_with_block(self):
        with DAG(dag_id="test_up_after", start_date=DEFAULT_DATE, schedule=None) as dag:
            start = EmptyOperator(task_id="start")
            group = self._workflow_group()
            with group:
                self._add_notebook()
            # Dependency is declared once the group has been populated.
            start >> group

        assert "start" in dag.get_task("tg.launch").upstream_task_ids

    def test_downstream_set_before_with_block(self):
        with DAG(dag_id="test_down_before", start_date=DEFAULT_DATE, schedule=None) as dag:
            group = self._workflow_group()
            end = EmptyOperator(task_id="end")
            # Dependency is declared while the group is still empty.
            group >> end
            with group:
                self._add_notebook()

        assert "end" in dag.get_task("tg.nb").downstream_task_ids

    def test_downstream_set_after_with_block(self):
        with DAG(dag_id="test_down_after", start_date=DEFAULT_DATE, schedule=None) as dag:
            group = self._workflow_group()
            end = EmptyOperator(task_id="end")
            with group:
                self._add_notebook()
            # Dependency is declared once the group has been populated.
            group >> end

        assert "end" in dag.get_task("tg.nb").downstream_task_ids
Loading