Skip to content

Commit a322da1

Browse files
committed
post-rebase fixes
1 parent 997402a commit a322da1

4 files changed

Lines changed: 35 additions & 29 deletions

File tree

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -359,32 +359,35 @@ def _get_team_names_for_dag_ids(
359359
# Return dict with all None values to ensure graceful degradation
360360
return {}
361361

362-
def _get_task_team_name(self, task_instance: TaskInstance, session: Session) -> str | None:
362+
def _get_workload_team_name(self, workload: SchedulerWorkload, session: Session) -> str | None:
363363
"""
364-
Resolve team name for a task instance using the DAG > Bundle > Team relationship chain.
364+
Resolve team name for a workload using the DAG > Bundle > Team relationship chain.
365365
366-
TaskInstance > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team
366+
Workload > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team
367367
368-
:param task_instance: The TaskInstance to resolve team name for
368+
:param workload: The Workload to resolve team name for
369369
:param session: Database session for queries
370370
:return: Team name if found or None
371371
"""
372372
# Use the batch query function with a single DAG ID
373-
dag_id_to_team_name = self._get_team_names_for_dag_ids([task_instance.dag_id], session)
374-
team_name = dag_id_to_team_name.get(task_instance.dag_id)
373+
if dag_id := workload.get_dag_id():
374+
dag_id_to_team_name = self._get_team_names_for_dag_ids([dag_id], session)
375+
team_name = dag_id_to_team_name.get(dag_id)
376+
else:
377+
team_name = None # mypy didn't like the implicit defaulting to None
375378

376379
if team_name:
377380
self.log.debug(
378-
"Resolved team name '%s' for task %s (dag_id=%s)",
381+
"Resolved team name '%s' for workload %s (dag_id=%s)",
379382
team_name,
380-
task_instance.task_id,
381-
task_instance.dag_id,
383+
workload,
384+
dag_id,
382385
)
383386
else:
384387
self.log.debug(
385-
"No team found for task %s (dag_id=%s) - DAG may not have bundle or team association",
386-
task_instance.task_id,
387-
task_instance.dag_id,
388+
"No team found for workload %s (dag_id=%s) - DAG may not have bundle or team association",
389+
workload,
390+
dag_id,
388391
)
389392

390393
return team_name
@@ -1002,7 +1005,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None:
10021005
10031006
:param session: The database session
10041007
"""
1005-
num_occupied_slots = sum(executor.slots_occupied for executor in self.job.executors)
1008+
num_occupied_slots = sum(executor.slots_occupied for executor in self.executors)
10061009
max_callbacks = conf.getint("core", "parallelism") - num_occupied_slots
10071010

10081011
if max_callbacks <= 0:
@@ -1132,11 +1135,11 @@ def process_executor_events(
11321135
ti_primary_key_to_try_number_map[key.primary] = key.try_number
11331136
cls.logger().info("Received executor event with state %s for task instance %s", state, key)
11341137
if state in (
1135-
TaskInstanceState.FAILED,
1136-
TaskInstanceState.SUCCESS,
1137-
TaskInstanceState.QUEUED,
1138-
TaskInstanceState.RUNNING,
1139-
TaskInstanceState.RESTARTING,
1138+
TaskInstanceState.FAILED,
1139+
TaskInstanceState.SUCCESS,
1140+
TaskInstanceState.QUEUED,
1141+
TaskInstanceState.RUNNING,
1142+
TaskInstanceState.RESTARTING,
11401143
):
11411144
tis_with_right_state.append(key)
11421145
else:
@@ -3247,8 +3250,11 @@ def _executor_to_workloads(
32473250
workloads_list = list(workloads)
32483251
if workloads_list:
32493252
dag_id_to_team_name = self._get_team_names_for_dag_ids(
3250-
{dag_id for workload in workloads_list if
3251-
(dag_id := workload.get_dag_id()) is not None},
3253+
{
3254+
dag_id
3255+
for workload in workloads_list
3256+
if (dag_id := workload.get_dag_id()) is not None
3257+
},
32523258
session,
32533259
)
32543260
else:
@@ -3262,9 +3268,9 @@ def _executor_to_workloads(
32623268

32633269
_executor_to_workloads: defaultdict[BaseExecutor, list[SchedulerWorkload]] = defaultdict(list)
32643270
for workload in workloads_iter:
3265-
if executor_obj := self._try_to_load_executor(
3266-
workload, session, team_name=dag_id_to_team_name.get(workload.get_dag_id(), NOTSET)
3267-
):
3271+
_dag_id = workload.get_dag_id()
3272+
_team = dag_id_to_team_name.get(_dag_id, NOTSET) if _dag_id else NOTSET
3273+
if executor_obj := self._try_to_load_executor(workload, session, team_name=_team):
32683274
_executor_to_workloads[executor_obj].append(workload)
32693275

32703276
return _executor_to_workloads
@@ -3287,7 +3293,7 @@ def _try_to_load_executor(
32873293
if conf.getboolean("core", "multi_team"):
32883294
# Use provided team_name if available, otherwise query the database
32893295
if team_name is NOTSET:
3290-
team_name = self._get_task_team_name(workload, session)
3296+
team_name = self._get_workload_team_name(workload, session)
32913297
else:
32923298
team_name = None
32933299
# Firstly, check if there is no executor set on the workload, if not, we need to fetch the default
@@ -3308,8 +3314,8 @@ def _try_to_load_executor(
33083314
executor = self.executor
33093315
else:
33103316
# An executor is specified on the workload (as a str), so we need to find it in the list of executors
3311-
for _executor in self.job.executors:
3312-
if workload.get_executor_name() in (_executor.name.alias, _executor.name.module_path):
3317+
for _executor in self.executors:
3318+
if _executor.name and workload.get_executor_name() in (_executor.name.alias, _executor.name.module_path):
33133319
# The executor must either match the team or be global (i.e. team_name is None)
33143320
if team_name and _executor.team_name == team_name or _executor.team_name is None:
33153321
executor = _executor

providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def execute_workload(input: str) -> None:
204204
ti=workload.ti, # type: ignore[arg-type]
205205
dag_rel_path=workload.dag_rel_path,
206206
bundle_info=workload.bundle_info,
207-
token=workload.token,
207+
token=workload.identity_token,
208208
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
209209
log_path=workload.log_path,
210210
)

providers/edge3/src/airflow/providers/edge3/cli/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -
217217
ti=ti, # type: ignore[arg-type]
218218
dag_rel_path=workload.dag_rel_path,
219219
bundle_info=workload.bundle_info,
220-
token=workload.token,
220+
token=workload.identity_token,
221221
server=_execution_api_server_url(),
222222
log_path=workload.log_path,
223223
)

task-sdk/src/airflow/sdk/execution_time/execute_workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def execute_workload(workload: ExecuteTask) -> None:
6868
ti=workload.ti, # type: ignore[arg-type]
6969
dag_rel_path=workload.dag_rel_path,
7070
bundle_info=workload.bundle_info,
71-
token=workload.token,
71+
token=workload.identity_token,
7272
server=server,
7373
log_path=workload.log_path,
7474
sentry_integration=workload.sentry_integration,

0 commit comments

Comments
 (0)