diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index c93f0ed8e1e13..8ba1d58b9bdf5 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -658,9 +658,27 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[Da .subquery() ) + available_dagruns_rn = ( + select( + DagRun.dag_id, + DagRun.id, + func.row_number() + .over(partition_by=[DagRun.dag_id, DagRun.backfill_id], order_by=DagRun.logical_date) + .label("rn"), + ) + .where(DagRun.state == DagRunState.QUEUED) + .subquery() + ) + query = ( select(cls) - .where(cls.state == DagRunState.QUEUED) + .join( + available_dagruns_rn, + and_( + available_dagruns_rn.c.id == DagRun.id, + available_dagruns_rn.c.dag_id == DagRun.dag_id, + ), + ) .join( DagModel, and_( @@ -692,8 +710,12 @@ def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[Da # the one done in this query verifies that the dag is not maxed out # it could return many more dag runs than runnable if there is even # capacity for 1. this could be improved. - coalesce(running_drs.c.num_running, text("0")) - < coalesce(Backfill.max_active_runs, DagModel.max_active_runs), + available_dagruns_rn.c.rn + <= coalesce( + Backfill.max_active_runs, + DagModel.max_active_runs, + ) + - coalesce(running_drs.c.num_running, 0), # don't set paused dag runs as running not_(coalesce(cast("ColumnElement[bool]", Backfill.is_paused), False)), ) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index fd77644bc44b5..b1b3e21398107 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -3289,6 +3289,94 @@ def test_runs_are_created_after_max_active_runs_was_reached(self, dag_maker, ses dag_runs = DagRun.find(dag_id=dag.dag_id, session=session) assert len(dag_runs) == 2 + def test_runs_are_not_starved_by_max_active_runs_limit(self, dag_maker, session): + """ + Test that dagruns are not starved by max_active_runs + """ + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) + + dag_ids = ["dag1", "dag2", "dag3"] + + max_active_runs = 3 + + for dag_id in dag_ids: + with dag_maker( + dag_id=dag_id, + max_active_runs=max_active_runs, + session=session, + catchup=True, + schedule=timedelta(seconds=60), + start_date=DEFAULT_DATE, + ): + # Need to use something that doesn't immediately get marked as success by the scheduler + BashOperator(task_id="task", bash_command="true") + + dag_run = dag_maker.create_dagrun( + state=State.QUEUED, session=session, run_type=DagRunType.SCHEDULED + ) + + for _ in range(50): + # create a bunch of dagruns in queued state, to make sure they are filtered by max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.QUEUED + ) + + self.job_runner._start_queued_dagruns(session) + session.flush() + + running_dagrun_count = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + assert running_dagrun_count == max_active_runs * len(dag_ids) + + def test_no_more_dagruns_are_set_to_running_when_max_active_runs_exceeded(self, dag_maker, session): + """ + Test that dagruns are not moved to running if there are more than the max_active_runs running dagruns + """ + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(job=scheduler_job, executors=[self.null_exec]) + + max_active_runs = 1 + with dag_maker( + dag_id="test_dag", + max_active_runs=max_active_runs, + session=session, + catchup=True, + schedule=timedelta(seconds=60), + start_date=DEFAULT_DATE, + ): + # Need to use something that doesn't immediately get marked as success by the scheduler + BashOperator(task_id="task", bash_command="true") + + dag_run = dag_maker.create_dagrun(state=State.RUNNING, session=session, run_type=DagRunType.SCHEDULED) + + for _ in range(5): + # create a bunch of dagruns in queued state, to make sure they are filtered by max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.RUNNING + ) + + running_dagruns_pre = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + for _ in range(5): + # create a bunch of dagruns in queued state, to make sure they are filtered by max_active_runs + dag_run = dag_maker.create_dagrun_after( + dag_run, run_type=DagRunType.SCHEDULED, state=State.QUEUED + ) + + self.job_runner._start_queued_dagruns(session) + session.flush() + + running_dagruns_post = session.scalar( + select(func.count()).select_from(DagRun).where(DagRun.state == DagRunState.RUNNING) + ) + + assert running_dagruns_pre == running_dagruns_post + def test_dagrun_timeout_verify_max_active_runs(self, dag_maker, session): """ Test if a dagrun will not be scheduled if max_dag_runs @@ -5965,14 +6053,14 @@ def _running_counts(): EmptyOperator(task_id="mytask") dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.QUEUED) - for _ in range(9): + for _ in range(29): dr = dag_maker.create_dagrun_after(dr, run_type=DagRunType.SCHEDULED, state=State.QUEUED) # initial state -- nothing is running assert dag1_non_b_running == 0 assert dag1_b_running == 0 assert total_running == 0 - assert session.scalar(select(func.count(DagRun.id))) == 46 + assert session.scalar(select(func.count(DagRun.id))) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 # now let's run it once @@ -5980,26 +6068,40 @@ def _running_counts(): session.flush() # after running the scheduler one time, observe that only one dag run is started - # this is because there are 30 runs for dag 1 so neither the backfills nor + # and 3 backfill dagruns are started + # this is because there are 30 dags, most of which get filtered due to max_active_runs + # and so due to the default dagruns to examine, we look at the first 20 dags which CAN be run + # according to the max_active_runs parameter, meaning 3 backfill runs will start, 1 non backfill and + # all dagruns of dag2 # any runs for dag2 get started assert DagRun.DEFAULT_DAGRUNS_TO_EXAMINE == 20 dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 - assert dag1_b_running == 0 - assert total_running == 1 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert dag1_b_running == 3 + assert total_running == 20 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 + # now we finish all lower priority backfill tasks, and observe new higher priority tasks are started + session.execute( + update(DagRun) + .where(DagRun.dag_id == "test_dag2", DagRun.state == DagRunState.RUNNING) + .values(state=DagRunState.SUCCESS) + ) + session.commit() + session.flush() # we run scheduler again and observe that now all the runs are created + # other than the finished runs of the backfill # this must be because sorting is working + # new tasks from test dag 2 should run, and so they are scheduled self.job_runner._start_queued_dagruns(session) session.flush() dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 assert dag1_b_running == 3 - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert total_running == 18 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 # run it a 3rd time and nothing changes @@ -6009,8 +6111,8 @@ def _running_counts(): dag1_non_b_running, dag1_b_running, total_running = _running_counts() assert dag1_non_b_running == 1 assert dag1_b_running == 3 - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 + assert total_running == 18 + assert session.scalar(select(func.count()).select_from(DagRun)) == 66 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 def test_backfill_runs_are_started_with_lower_priority_catchup_false(self, dag_maker, session): @@ -6230,25 +6332,11 @@ def _running_counts(): assert dag1_non_b_running == 1 assert dag1_b_running == 3 - # this should be 14 but it is not. why? - # answer: because dag2 got starved out by dag1 - # if we run the scheduler again, dag2 should get queued - assert total_running == 4 + assert total_running == 14 assert session.scalar(select(func.count()).select_from(DagRun)) == 46 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36 - # run scheduler a second time - self.job_runner._start_queued_dagruns(session) - session.flush() - - dag1_non_b_running, dag1_b_running, total_running = _running_counts() - assert dag1_non_b_running == 1 - assert dag1_b_running == 3 - - # on the second try, dag 2's 10 runs now start running - assert total_running == 14 - assert session.scalar(select(func.count()).select_from(DagRun)) == 46 assert session.scalar(select(func.count()).where(DagRun.dag_id == dag1_dag_id)) == 36