diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
index 779c2fc9f1528..317bd444911c9 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_workflow.py
@@ -409,5 +409,21 @@ def __exit__(
 
             for root_task in roots:
                 root_task.set_upstream(create_databricks_workflow_task)
+
+            # ``upstream >> group`` / ``group >> downstream`` may be evaluated before
+            # the ``with`` block is entered. In that case the dependency is recorded
+            # on the task group only, because the launch/leaf tasks don't exist yet.
+            # Transfer those task-group-level dependencies now that the tasks are
+            # created.
+            for upstream_id in self.upstream_task_ids:
+                create_databricks_workflow_task.set_upstream(self.dag.get_task(upstream_id))
+            if self.downstream_task_ids:
+                # Collect the leaves once instead of once per downstream id; attaching
+                # tasks from outside the group is not expected to change its leaves.
+                leaf_tasks = list(self.get_leaves())
+                for downstream_id in self.downstream_task_ids:
+                    downstream_task = self.dag.get_task(downstream_id)
+                    for leaf_task in leaf_tasks:
+                        leaf_task.set_downstream(downstream_task)
         finally:
             super().__exit__(_type, _value, _tb)
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
index 84069ee0ff7f8..0f42298beb772 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_workflow.py
@@ -571,3 +571,79 @@ def test_reset_job_payload_carries_parent_depends_on(self, mock_databricks_hook)
         job_id, job_spec = launch_task._hook.reset_job.call_args.args
         assert job_id == 42
         self._assert_parent_depends_on(job_spec)
+
+
+class TestTaskGroupExternalDependencyTransfer:
+    """Verify that task-group-level ``>>`` deps transfer to the launch / leaf tasks.
+
+    When ``>>`` is called *before* the ``with`` block, the launch task doesn't
+    exist yet; ``__exit__`` must transfer ``self.upstream_task_ids`` to the
+    launch task and ``self.downstream_task_ids`` to the leaf tasks.
+    """
+
+    JOB_CLUSTERS = [
+        {
+            "job_cluster_key": "c",
+            "new_cluster": {
+                "spark_version": "15.4.x-scala2.12",
+                "num_workers": 0,
+                "node_type_id": "i3.xlarge",
+            },
+        }
+    ]
+
+    def test_upstream_set_before_with_block(self):
+        with DAG(dag_id="test_up_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+            start = EmptyOperator(task_id="start")
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            start >> tg
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+        launch = dag.get_task("tg.launch")
+        assert "start" in launch.upstream_task_ids
+
+    def test_upstream_set_after_with_block(self):
+        with DAG(dag_id="test_up_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+            start = EmptyOperator(task_id="start")
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+            start >> tg
+        launch = dag.get_task("tg.launch")
+        assert "start" in launch.upstream_task_ids
+
+    def test_downstream_set_before_with_block(self):
+        with DAG(dag_id="test_down_before", start_date=DEFAULT_DATE, schedule=None) as dag:
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            end = EmptyOperator(task_id="end")
+            tg >> end
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+        nb = dag.get_task("tg.nb")
+        assert "end" in nb.downstream_task_ids
+
+    def test_downstream_set_after_with_block(self):
+        with DAG(dag_id="test_down_after", start_date=DEFAULT_DATE, schedule=None) as dag:
+            tg = DatabricksWorkflowTaskGroup(
+                group_id="tg", databricks_conn_id="databricks_conn", job_clusters=self.JOB_CLUSTERS
+            )
+            end = EmptyOperator(task_id="end")
+            with tg:
+                DatabricksNotebookOperator(
+                    task_id="nb", notebook_path="/t", source="WORKSPACE", job_cluster_key="c"
+                )
+            tg >> end
+        nb = dag.get_task("tg.nb")
+        assert "end" in nb.downstream_task_ids