@@ -359,32 +359,35 @@ def _get_team_names_for_dag_ids(
359359 # Return dict with all None values to ensure graceful degradation
360360 return {}
361361
362- def _get_task_team_name (self , task_instance : TaskInstance , session : Session ) -> str | None :
362+ def _get_workload_team_name (self , workload : SchedulerWorkload , session : Session ) -> str | None :
363363 """
364- Resolve team name for a task instance using the DAG > Bundle > Team relationship chain.
364+ Resolve team name for a workload using the DAG > Bundle > Team relationship chain.
365365
366- TaskInstance > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team
366+ Workload > DagModel (via dag_id) > DagBundleModel (via bundle_name) > Team
367367
368- :param task_instance : The TaskInstance to resolve team name for
368+ :param workload : The Workload to resolve team name for
369369 :param session: Database session for queries
370370 :return: Team name if found or None
371371 """
372372 # Use the batch query function with a single DAG ID
373- dag_id_to_team_name = self ._get_team_names_for_dag_ids ([task_instance .dag_id ], session )
374- team_name = dag_id_to_team_name .get (task_instance .dag_id )
373+ if dag_id := workload .get_dag_id ():
374+ dag_id_to_team_name = self ._get_team_names_for_dag_ids ([dag_id ], session )
375+ team_name = dag_id_to_team_name .get (dag_id )
376+ else :
377+ team_name = None # mypy didn't like the implicit defaulting to None
375378
376379 if team_name :
377380 self .log .debug (
378- "Resolved team name '%s' for task %s (dag_id=%s)" ,
381+ "Resolved team name '%s' for workload %s (dag_id=%s)" ,
379382 team_name ,
380- task_instance . task_id ,
381- task_instance . dag_id ,
383+ workload ,
384+ dag_id ,
382385 )
383386 else :
384387 self .log .debug (
385- "No team found for task %s (dag_id=%s) - DAG may not have bundle or team association" ,
386- task_instance . task_id ,
387- task_instance . dag_id ,
388+ "No team found for workload %s (dag_id=%s) - DAG may not have bundle or team association" ,
389+ workload ,
390+ dag_id ,
388391 )
389392
390393 return team_name
@@ -1002,7 +1005,7 @@ def _enqueue_executor_callbacks(self, session: Session) -> None:
10021005
10031006 :param session: The database session
10041007 """
1005- num_occupied_slots = sum (executor .slots_occupied for executor in self .job . executors )
1008+ num_occupied_slots = sum (executor .slots_occupied for executor in self .executors )
10061009 max_callbacks = conf .getint ("core" , "parallelism" ) - num_occupied_slots
10071010
10081011 if max_callbacks <= 0 :
@@ -1132,11 +1135,11 @@ def process_executor_events(
11321135 ti_primary_key_to_try_number_map [key .primary ] = key .try_number
11331136 cls .logger ().info ("Received executor event with state %s for task instance %s" , state , key )
11341137 if state in (
1135- TaskInstanceState .FAILED ,
1136- TaskInstanceState .SUCCESS ,
1137- TaskInstanceState .QUEUED ,
1138- TaskInstanceState .RUNNING ,
1139- TaskInstanceState .RESTARTING ,
1138+ TaskInstanceState .FAILED ,
1139+ TaskInstanceState .SUCCESS ,
1140+ TaskInstanceState .QUEUED ,
1141+ TaskInstanceState .RUNNING ,
1142+ TaskInstanceState .RESTARTING ,
11401143 ):
11411144 tis_with_right_state .append (key )
11421145 else :
@@ -3247,8 +3250,11 @@ def _executor_to_workloads(
32473250 workloads_list = list (workloads )
32483251 if workloads_list :
32493252 dag_id_to_team_name = self ._get_team_names_for_dag_ids (
3250- {dag_id for workload in workloads_list if
3251- (dag_id := workload .get_dag_id ()) is not None },
3253+ {
3254+ dag_id
3255+ for workload in workloads_list
3256+ if (dag_id := workload .get_dag_id ()) is not None
3257+ },
32523258 session ,
32533259 )
32543260 else :
@@ -3262,9 +3268,9 @@ def _executor_to_workloads(
32623268
32633269 _executor_to_workloads : defaultdict [BaseExecutor , list [SchedulerWorkload ]] = defaultdict (list )
32643270 for workload in workloads_iter :
3265- if executor_obj := self . _try_to_load_executor (
3266- workload , session , team_name = dag_id_to_team_name .get (workload . get_dag_id () , NOTSET )
3267- ):
3271+ _dag_id = workload . get_dag_id ()
3272+ _team = dag_id_to_team_name .get (_dag_id , NOTSET ) if _dag_id else NOTSET
3273+ if executor_obj := self . _try_to_load_executor ( workload , session , team_name = _team ):
32683274 _executor_to_workloads [executor_obj ].append (workload )
32693275
32703276 return _executor_to_workloads
@@ -3287,7 +3293,7 @@ def _try_to_load_executor(
32873293 if conf .getboolean ("core" , "multi_team" ):
32883294 # Use provided team_name if available, otherwise query the database
32893295 if team_name is NOTSET :
3290- team_name = self ._get_task_team_name (workload , session )
3296+ team_name = self ._get_workload_team_name (workload , session )
32913297 else :
32923298 team_name = None
32933299 # Firstly, check if there is no executor set on the workload, if not, we need to fetch the default
@@ -3308,8 +3314,8 @@ def _try_to_load_executor(
33083314 executor = self .executor
33093315 else :
33103316 # An executor is specified on the workload (as a str), so we need to find it in the list of executors
3311- for _executor in self .job . executors :
3312- if workload .get_executor_name () in (_executor .name .alias , _executor .name .module_path ):
3317+ for _executor in self .executors :
3318+ if _executor . name and workload .get_executor_name () in (_executor .name .alias , _executor .name .module_path ):
33133319 # The executor must either match the team or be global (i.e. team_name is None)
33143320 if team_name and _executor .team_name == team_name or _executor .team_name is None :
33153321 executor = _executor
0 commit comments