From 2ca18fbfb76d69046e48a870a0977f3f34ca23f4 Mon Sep 17 00:00:00 2001
From: Julian Geiger
Date: Thu, 6 Nov 2025 13:18:07 +0100
Subject: [PATCH 1/4] wip: add rolling task window

---
 src/aiida_workgraph/engine/task_manager.py | 152 ++++++++++++++++++++-
 1 file changed, 151 insertions(+), 1 deletion(-)

diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py
index 1d3621df..7748b307 100644
--- a/src/aiida_workgraph/engine/task_manager.py
+++ b/src/aiida_workgraph/engine/task_manager.py
@@ -44,6 +44,19 @@ def __init__(self, ctx_manager, logger, runner, process: Process, awaitable_mana
         self.state_manager = TaskStateManager(ctx_manager, logger, process, awaitable_manager)
         self.action_manager = TaskActionManager(self.state_manager, logger, process)
 
+        # Initialize window state with defaults (will be loaded from WorkGraph context later)
+        self.window_config = {
+            'enabled': False,
+            'window_size': float('inf'),
+            'max_queued_jobs': None,
+            'task_levels': {},
+        }
+        self.window_state = {
+            'min_active_level': 0,
+            'max_allowed_level': float('inf'),
+        }
+        self._window_initialized = False
+
     def get_task(self, name: str):
         """Get task from the context."""
         task = self.process.wg.tasks[name]
@@ -58,6 +71,120 @@ def get_task(self, name: str):
             socket.value = get_nested_dict(self.ctx._task_results[name], socket._name, default=None)
         return task
 
+    def _init_window_state(self):
+        """Initialize window state from WorkGraph context."""
+        # Check if WorkGraph is available yet
+        if not hasattr(self.process, 'wg') or self.process.wg is None:
+            print("DEBUG: WorkGraph not available yet for window initialization")
+            return  # WorkGraph not loaded yet, use defaults
+
+        if self._window_initialized:
+            return  # Already initialized
+
+        # Load window config from WorkGraph extras (persisted with the WorkGraph)
+        window_config = getattr(self.process.wg, 'extras', {}).get('window_config', {})
+        print(f"DEBUG: Initializing window state, config: {window_config}")
+
+        self.window_config = {
+            'enabled': window_config.get('enabled', False),
+            'window_size': window_config.get('window_size', float('inf')),
+            'max_queued_jobs': window_config.get('max_queued_jobs', None),
+            'task_levels': window_config.get('task_levels', {}),
+        }
+
+        # Initialize window state
+        if self.window_config['enabled']:
+            self.window_state = {
+                'min_active_level': 0,
+                'max_allowed_level': self.window_config['window_size'],
+            }
+        else:
+            self.window_state = {
+                'min_active_level': 0,
+                'max_allowed_level': float('inf'),
+            }
+
+        self._window_initialized = True
+
+    def _update_window(self):
+        """Update the active window based on task completion."""
+        if not self.window_config['enabled']:
+            return
+
+        # Find minimum level of active (CREATED/RUNNING) launcher tasks
+        active_levels = []
+        for task_name, level in self.window_config['task_levels'].items():
+            state = self.state_manager.get_task_runtime_info(task_name, 'state')
+            if state in ['CREATED', 'RUNNING']:
+                active_levels.append(level)
+
+        if not active_levels:
+            # No active tasks - advance window to next pending level
+            old_min = self.window_state['min_active_level']
+            # Find next level with pending tasks
+            for level in range(old_min, max(self.window_config['task_levels'].values()) + 1):
+                tasks_at_level = [
+                    name for name, l in self.window_config['task_levels'].items()
+                    if l == level
+                ]
+                if tasks_at_level:
+                    # Check if any task at this level is not finished
+                    has_pending = any(
+                        self.state_manager.get_task_runtime_info(name, 'state')
+                        not in ['FINISHED', 'FAILED', 'SKIPPED']
+                        for name in tasks_at_level
+                    )
+                    if has_pending:
+                        self.window_state['min_active_level'] = level
+                        break
+            else:
+                # All tasks finished, keep current min
+                self.window_state['min_active_level'] = old_min
+        else:
+            # Set min_active_level to minimum of active tasks
+            self.window_state['min_active_level'] = min(active_levels)
+
+        # Update max_allowed_level
+        window_size = self.window_config['window_size']
+        self.window_state['max_allowed_level'] = (
+            self.window_state['min_active_level'] + window_size
+        )
+
+    def _is_task_in_window(self, task_name: str) -> bool:
+        """Check if task is within the active submission window."""
+        if not self.window_config['enabled']:
+            return True  # No windowing, all tasks allowed
+
+        # get_job_data tasks and other non-launcher tasks are always allowed
+        if not task_name.startswith('launch_'):
+            return True
+
+        # Check topological level
+        task_level = self.window_config['task_levels'].get(task_name)
+        if task_level is None:
+            # Task not in level mapping - allow it
+            return True
+
+        if task_level > self.window_state['max_allowed_level']:
+            return False  # Outside window
+
+        # Check max_queued_jobs threshold if configured
+        if self.window_config.get('max_queued_jobs'):
+            active_count = self._count_active_jobs()
+            if active_count >= self.window_config['max_queued_jobs']:
+                return False  # Too many jobs already
+
+        return True
+
+    def _count_active_jobs(self) -> int:
+        """Count tasks in CREATED or RUNNING state."""
+        count = 0
+        for task in self.process.wg.tasks:
+            state = self.state_manager.get_task_runtime_info(task.name, 'state')
+            if state in ['CREATED', 'RUNNING']:
+                count += 1
+        return count
+
     def set_task_results(self) -> None:
         from node_graph.config import BUILTIN_NODES
 
@@ -103,7 +230,24 @@ def continue_workgraph(self) -> None:
         Resume the WorkGraph by looking for tasks that are ready to run.
         """
         # self.process.report("Continue workgraph.")
+
+        # Initialize window state if not already done (lazy initialization)
+        self._init_window_state()
+
+        # Update window state if rolling window is enabled
+        if self.window_config.get('enabled'):
+            self._update_window()
+            # Report window state
+            if self.window_config['task_levels']:
+                active_count = self._count_active_jobs()
+                self.process.report(
+                    f"Window: levels {self.window_state['min_active_level']}-"
+                    f"{self.window_state['max_allowed_level']}, "
+                    f"active jobs: {active_count}"
+                )
+
         task_to_run = []
+        skipped_by_window = []
         for task in self.process.wg.tasks:
             # update task state
             if (
@@ -121,9 +265,15 @@ def continue_workgraph(self) -> None:
                 continue
             ready, _ = self.state_manager.is_task_ready_to_run(task.name)
             if ready:
-                task_to_run.append(task.name)
+                # Check if task is within active window
+                if self._is_task_in_window(task.name):
+                    task_to_run.append(task.name)
+                else:
+                    skipped_by_window.append(task.name)
 
         # self.process.report('tasks ready to run: {}'.format(','.join(task_to_run)))
+        if skipped_by_window:
+            self.process.report('tasks skipped (outside window): {}'.format(','.join(skipped_by_window)))
         self.run_tasks(task_to_run)
 
     def should_run_task(self, task: 'Task') -> bool:
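Patch 1 only consumes a precomputed `task_levels` mapping from `wg.extras['window_config']`; nothing in the diff builds it. The sketch below shows how such a mapping could be derived from a task -> parents dict with Kahn's algorithm, presumably mirroring the `compute_topological_levels` helper that patch 3's comments refer to. The function body, the shape of `deps`, and the task names are illustrative assumptions, not part of the series.

```python
from collections import deque


def compute_topological_levels(deps: dict[str, list[str]]) -> dict[str, int]:
    """Compute the topological level of each task from a task -> parents mapping.

    Level 0 tasks have no parents; every other task sits one level above its
    deepest parent. Assumes ``deps`` describes a DAG.
    """
    in_degree = {task: len(parents) for task, parents in deps.items()}
    children: dict[str, list[str]] = {task: [] for task in deps}
    for task, parents in deps.items():
        for parent in parents:
            children[parent].append(task)

    levels: dict[str, int] = {}
    queue = deque(task for task, degree in in_degree.items() if degree == 0)
    for task in queue:
        levels[task] = 0

    while queue:
        current = queue.popleft()
        for child in children[current]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                levels[child] = max(levels[p] for p in deps[child]) + 1
                queue.append(child)
    return levels


# A fast two-task chain next to a slow three-task chain:
deps = {
    'launch_a1': [], 'launch_a2': ['launch_a1'],
    'launch_b1': [], 'launch_b2': ['launch_b1'], 'launch_b3': ['launch_b2'],
}
assert compute_topological_levels(deps) == {
    'launch_a1': 0, 'launch_a2': 1,
    'launch_b1': 0, 'launch_b2': 1, 'launch_b3': 2,
}
```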
""" # self.process.report("Continue workgraph.") + + # Initialize window state if not already done (lazy initialization) + self._init_window_state() + + # Update window state if rolling window is enabled + if self.window_config.get('enabled'): + self._update_window() + # Report window state + if self.window_config['task_levels']: + active_count = self._count_active_jobs() + self.process.report( + f"Window: levels {self.window_state['min_active_level']}-" + f"{self.window_state['max_allowed_level']}, " + f"active jobs: {active_count}" + ) + task_to_run = [] + skipped_by_window = [] for task in self.process.wg.tasks: # update task state if ( @@ -121,9 +265,15 @@ def continue_workgraph(self) -> None: continue ready, _ = self.state_manager.is_task_ready_to_run(task.name) if ready: - task_to_run.append(task.name) + # Check if task is within active window + if self._is_task_in_window(task.name): + task_to_run.append(task.name) + else: + skipped_by_window.append(task.name) # self.process.report('tasks ready to run: {}'.format(','.join(task_to_run))) + if skipped_by_window: + self.process.report('tasks skipped (outside window): {}'.format(','.join(skipped_by_window))) self.run_tasks(task_to_run) def should_run_task(self, task: 'Task') -> bool: From f73d0036883cad42462e5c234f01e8752946eb11 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Tue, 11 Nov 2025 12:47:00 +0100 Subject: [PATCH 2/4] wip --- src/aiida_workgraph/workgraph.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index 771a4d24..63d6e93d 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -54,6 +54,7 @@ def __init__( self.max_number_jobs = 1000000 self.max_iteration = 1000000 self._error_handlers = error_handlers or {} + self.extras = {} # Initialize extras dict for custom metadata self.analyzer = NodeGraphAnalysis(self) @property @@ -238,6 +239,7 @@ def to_dict(self, include_sockets: bool = False, should_serialize: bool = False) 'restart_process': self.restart_process.pk if self.restart_process else None, 'max_iteration': self.max_iteration, 'max_number_jobs': self.max_number_jobs, + 'extras': getattr(self, 'extras', {}), # Serialize extras dict } ) # save error handlers @@ -340,6 +342,7 @@ def from_dict(cls, wgdata: Dict[str, Any]) -> 'WorkGraph': 'max_iteration', 'max_number_jobs', 'connectivity', + 'extras', # Restore extras dict (for window_config) ]: if key in wgdata: setattr(wg, key, wgdata[key]) From 93754049272a011617c6b9669df397060995ceb3 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Wed, 26 Nov 2025 10:59:18 +0100 Subject: [PATCH 3/4] implement dynamic task-level computation --- src/aiida_workgraph/engine/task_manager.py | 138 +++++++++++++++++---- 1 file changed, 111 insertions(+), 27 deletions(-) diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index 7748b307..3e7df443 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -48,12 +48,12 @@ def __init__(self, ctx_manager, logger, runner, process: Process, awaitable_mana self.window_config = { 'enabled': False, 'window_size': float('inf'), - 'max_queued_jobs': None, - 'task_levels': {}, + 'task_dependencies': {}, } self.window_state = { 'min_active_level': 0, 'max_allowed_level': float('inf'), + 'dynamic_task_levels': {}, } self._window_initialized = False @@ -75,7 +75,7 @@ def _init_window_state(self): """Initialize window state from WorkGraph context.""" 
# Check if WorkGraph is available yet if not hasattr(self.process, 'wg') or self.process.wg is None: - print("DEBUG: WorkGraph not available yet for window initialization") + self.logger.debug("WorkGraph not available yet for window initialization") return # WorkGraph not loaded yet, use defaults if self._window_initialized: @@ -83,13 +83,13 @@ def _init_window_state(self): # Load window config from WorkGraph extras (persisted with the WorkGraph) window_config = getattr(self.process.wg, 'extras', {}).get('window_config', {}) - print(f"DEBUG: Initializing window state, config: {window_config}") + self.logger.debug(f"Initializing window state, config: {window_config}") self.window_config = { 'enabled': window_config.get('enabled', False), 'window_size': window_config.get('window_size', float('inf')), 'max_queued_jobs': window_config.get('max_queued_jobs', None), - 'task_levels': window_config.get('task_levels', {}), + 'task_dependencies': window_config.get('task_dependencies', {}), } # Initialize window state @@ -97,24 +97,100 @@ def _init_window_state(self): self.window_state = { 'min_active_level': 0, 'max_allowed_level': self.window_config['window_size'], + 'dynamic_task_levels': self._compute_dynamic_levels(), } else: self.window_state = { 'min_active_level': 0, 'max_allowed_level': float('inf'), + 'dynamic_task_levels': {}, } self._window_initialized = True + def _compute_dynamic_levels(self) -> dict[str, int]: + """Compute task levels based on current unfinished tasks only. + + Key idea: Exclude FINISHED/FAILED/SKIPPED tasks from dependency graph, + then run BFS to compute levels. This allows faster branches to collapse + to lower levels as their dependencies complete. + + Returns: + Dict mapping task_name -> current dynamic level + """ + from collections import deque + + if not self.window_config['enabled']: + return {} + + task_deps = self.window_config['task_dependencies'] + + # Step 1: Filter to only unfinished tasks + unfinished_tasks = set() + for task_name in task_deps.keys(): + state = self.state_manager.get_task_runtime_info(task_name, 'state') + if state not in ['FINISHED', 'FAILED', 'SKIPPED']: + unfinished_tasks.add(task_name) + + # Step 2: Build filtered dependency graph (only unfinished tasks) + filtered_deps = {} + for task_name in unfinished_tasks: + unfinished_parents = [ + p for p in task_deps[task_name] + if p in unfinished_tasks + ] + filtered_deps[task_name] = unfinished_parents + + # Step 3: Compute levels using BFS (same algorithm as compute_topological_levels) + levels = {} + in_degree = {task: len(parents) for task, parents in filtered_deps.items()} + + # Find all tasks with no unfinished dependencies (level 0) + queue = deque([task for task, degree in in_degree.items() if degree == 0]) + for task_name in queue: + levels[task_name] = 0 + + # Build reverse dependency graph + children = {task: [] for task in filtered_deps} + for task_name, parents in filtered_deps.items(): + for parent in parents: + if parent not in children: + children[parent] = [] + children[parent].append(task_name) + + # Process tasks in topological order + processed = set() + while queue: + current = queue.popleft() + processed.add(current) + + for child in children.get(current, []): + parents = filtered_deps[child] + if all(p in processed for p in parents): + parent_levels = [levels[p] for p in parents] + levels[child] = max(parent_levels) + 1 if parent_levels else 0 + queue.append(child) + + return levels + def _update_window(self): - """Update the active window based on task 
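With `extras` now round-tripping through `to_dict`/`from_dict`, the window can be configured on the user side before submission. A minimal usage sketch, assuming the config keys the engine reads in `_init_window_state()` as of patches 1-2 (`enabled`, `window_size`, `max_queued_jobs`, `task_levels`); the graph construction, task names, and level values are placeholders, and patch 3 below swaps `task_levels` for a `task_dependencies` parent mapping:

```python
from aiida_workgraph import WorkGraph

wg = WorkGraph('rolling_window_demo')
# ... add launch_* tasks and link them here ...

# Levels as produced by the helper sketched after patch 1 (hardcoded for brevity).
task_levels = {'launch_a1': 0, 'launch_a2': 1, 'launch_b1': 0, 'launch_b2': 1, 'launch_b3': 2}

# TaskManager._init_window_state() reads this back via
# getattr(self.process.wg, 'extras', {}).get('window_config', {}).
wg.extras['window_config'] = {
    'enabled': True,
    'window_size': 1,       # levels allowed beyond the lowest active level
    'max_queued_jobs': 50,  # optional cap on CREATED/RUNNING tasks
    'task_levels': task_levels,
}
wg.submit()
```

Storing the config in `extras` rather than on the `TaskManager` itself is what lets patch 2's serialization round-trip deliver it to the daemon-side engine.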
completion.""" + """Update the active window based on task completion. + + Recomputes dynamic levels after each task completion to allow + faster branches to advance independently. + """ if not self.window_config['enabled']: return + # RECOMPUTE DYNAMIC LEVELS based on current task states + self.window_state['dynamic_task_levels'] = self._compute_dynamic_levels() + # Find minimum level of active (CREATED/RUNNING) launcher tasks active_levels = [] - for task_name, level in self.window_config['task_levels'].items(): + for task_name, level in self.window_state['dynamic_task_levels'].items(): state = self.state_manager.get_task_runtime_info(task_name, 'state') + # PRCOMMENT: These are all the states to be considered here? + # WG uses custom states, not plumpy's `class ProcessState` or core's `class JobState` enums?! if state in ['CREATED', 'RUNNING']: active_levels.append(level) @@ -122,23 +198,28 @@ def _update_window(self): # No active tasks - advance window to next pending level old_min = self.window_state['min_active_level'] # Find next level with pending tasks - for level in range(old_min, max(self.window_config['task_levels'].values()) + 1): - tasks_at_level = [ - name for name, l in self.window_config['task_levels'].items() - if l == level - ] - if tasks_at_level: - # Check if any task at this level is not finished - has_pending = any( - self.state_manager.get_task_runtime_info(name, 'state') - not in ['FINISHED', 'FAILED', 'SKIPPED'] - for name in tasks_at_level - ) - if has_pending: - self.window_state['min_active_level'] = level - break + if self.window_state['dynamic_task_levels']: + max_level = max(self.window_state['dynamic_task_levels'].values()) + for level in range(old_min, max_level + 1): + tasks_at_level = [ + name for name, lvl in self.window_state['dynamic_task_levels'].items() + if lvl == level + ] + if tasks_at_level: + # Check if any task at this level is not finished + has_pending = any( + self.state_manager.get_task_runtime_info(name, 'state') + not in ['FINISHED', 'FAILED', 'SKIPPED'] + for name in tasks_at_level + ) + if has_pending: + self.window_state['min_active_level'] = level + break + else: + # All tasks finished, keep current min + self.window_state['min_active_level'] = old_min else: - # All tasks finished, keep current min + # No tasks in dynamic levels (all finished), keep current min self.window_state['min_active_level'] = old_min else: # Set min_active_level to minimum of active tasks @@ -156,11 +237,13 @@ def _is_task_in_window(self, task_name: str) -> bool: return True # No windowing, all tasks allowed # get_job_data tasks and other non-launcher tasks are always allowed + # FIXME if not task_name.startswith('launch_'): + breakpoint() return True - # Check topological level - task_level = self.window_config['task_levels'].get(task_name) + # Check dynamic topological level + task_level = self.window_state['dynamic_task_levels'].get(task_name) if task_level is None: # Task not in level mapping - allow it return True @@ -238,11 +321,12 @@ def continue_workgraph(self) -> None: if self.window_config.get('enabled'): self._update_window() # Report window state - if self.window_config['task_levels']: + if self.window_state['dynamic_task_levels']: active_count = self._count_active_jobs() + max_level = max(self.window_state['dynamic_task_levels'].values()) if self.window_state['dynamic_task_levels'] else 0 self.process.report( f"Window: levels {self.window_state['min_active_level']}-" - f"{self.window_state['max_allowed_level']}, " + 
f"{self.window_state['max_allowed_level']} (max dynamic level: {max_level}), " f"active jobs: {active_count}" ) From 202d5019358b22dd729c95de3c5fe6f53e9d1dc6 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Thu, 27 Nov 2025 15:55:46 +0100 Subject: [PATCH 4/4] remove breakpoint --- src/aiida_workgraph/engine/task_manager.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index 5010ae9a..cb716b24 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -237,9 +237,7 @@ def _is_task_in_window(self, task_name: str) -> bool: return True # No windowing, all tasks allowed # get_job_data tasks and other non-launcher tasks are always allowed - # FIXME if not task_name.startswith('launch_'): - breakpoint() return True # Check dynamic topological level