234 changes: 233 additions & 1 deletion src/aiida_workgraph/engine/task_manager.py
@@ -45,6 +45,19 @@ def __init__(self, ctx_manager, logger, runner, process: Process, awaitable_manager)
self.state_manager = TaskStateManager(ctx_manager, logger, process, awaitable_manager)
self.action_manager = TaskActionManager(self.state_manager, logger, process)

# Initialize window state with defaults (loaded from the WorkGraph context later)
self.window_config = {
'enabled': False,
'window_size': float('inf'),
'max_queued_jobs': None,
'task_dependencies': {},
}
self.window_state = {
'min_active_level': 0,
'max_allowed_level': float('inf'),
'dynamic_task_levels': {},
}
self._window_initialized = False

def get_task(self, name: str):
"""Get task from the context."""
task = self.process.wg.tasks[name]
@@ -60,6 +73,201 @@ def get_task(self, name: str):
socket.value = get_nested_dict(self.ctx._task_results[name], socket._name, default=None)
return task

def _init_window_state(self):
"""Initialize window state from the WorkGraph context."""
if self._window_initialized:
return  # Already initialized

# Check whether the WorkGraph is available yet
if not hasattr(self.process, 'wg') or self.process.wg is None:
self.logger.debug("WorkGraph not available yet for window initialization")
return  # WorkGraph not loaded yet; keep the defaults

# Load window config from WorkGraph extras (persisted with the WorkGraph)
window_config = getattr(self.process.wg, 'extras', {}).get('window_config', {})
self.logger.debug(f"Initializing window state, config: {window_config}")

self.window_config = {
'enabled': window_config.get('enabled', False),
'window_size': window_config.get('window_size', float('inf')),
'max_queued_jobs': window_config.get('max_queued_jobs', None),
'task_dependencies': window_config.get('task_dependencies', {}),
}

# Initialize window state
if self.window_config['enabled']:
self.window_state = {
'min_active_level': 0,
'max_allowed_level': self.window_config['window_size'],
'dynamic_task_levels': self._compute_dynamic_levels(),
}
else:
self.window_state = {
'min_active_level': 0,
'max_allowed_level': float('inf'),
'dynamic_task_levels': {},
}

self._window_initialized = True

def _compute_dynamic_levels(self) -> dict[str, int]:
"""Compute task levels based on current unfinished tasks only.

Key idea: exclude FINISHED/FAILED/SKIPPED tasks from the dependency graph,
then run a BFS to compute levels. This allows faster branches to collapse
to lower levels as their dependencies complete.

Returns:
Dict mapping task_name -> current dynamic level
"""
from collections import deque

if not self.window_config['enabled']:
return {}

task_deps = self.window_config['task_dependencies']

# Step 1: Filter to only unfinished tasks
unfinished_tasks = set()
for task_name in task_deps.keys():
state = self.state_manager.get_task_runtime_info(task_name, 'state')
if state not in ['FINISHED', 'FAILED', 'SKIPPED']:
unfinished_tasks.add(task_name)

# Step 2: Build filtered dependency graph (only unfinished tasks)
filtered_deps = {}
for task_name in unfinished_tasks:
unfinished_parents = [
p for p in task_deps[task_name]
if p in unfinished_tasks
]
filtered_deps[task_name] = unfinished_parents

# Step 3: Compute levels using BFS (same algorithm as compute_topological_levels)
levels = {}
in_degree = {task: len(parents) for task, parents in filtered_deps.items()}

# Find all tasks with no unfinished dependencies (level 0)
queue = deque([task for task, degree in in_degree.items() if degree == 0])
for task_name in queue:
levels[task_name] = 0

# Build reverse dependency graph
children = {task: [] for task in filtered_deps}
for task_name, parents in filtered_deps.items():
for parent in parents:
if parent not in children:
children[parent] = []
children[parent].append(task_name)

# Process tasks in topological order
processed = set()
while queue:
current = queue.popleft()
processed.add(current)

for child in children.get(current, []):
parents = filtered_deps[child]
if all(p in processed for p in parents):
parent_levels = [levels[p] for p in parents]
levels[child] = max(parent_levels) + 1 if parent_levels else 0
queue.append(child)

return levels
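
To make the level collapse concrete, here is a self-contained sketch of the same BFS on a toy graph, with a plain `finished` set standing in for the state-manager lookup (the function and task names here are illustrative, not part of the PR):

from collections import deque

def dynamic_levels(task_deps, finished):
    """Recompute levels over the unfinished subgraph, mirroring _compute_dynamic_levels."""
    unfinished = {t for t in task_deps if t not in finished}
    deps = {t: [p for p in task_deps[t] if p in unfinished] for t in unfinished}
    levels = {t: 0 for t, ps in deps.items() if not ps}  # no unfinished parents -> level 0
    queue = deque(levels)
    children = {}
    for t, ps in deps.items():
        for p in ps:
            children.setdefault(p, []).append(t)
    processed = set()
    while queue:
        cur = queue.popleft()
        processed.add(cur)
        for child in children.get(cur, []):
            if all(p in processed for p in deps[child]):
                levels[child] = max(levels[p] for p in deps[child]) + 1
                queue.append(child)
    return levels

# Chain a -> b -> c next to an independent chain x -> y
deps = {'a': [], 'b': ['a'], 'c': ['b'], 'x': [], 'y': ['x']}
print(dynamic_levels(deps, finished=set()))        # {'a': 0, 'x': 0, 'b': 1, 'y': 1, 'c': 2}
print(dynamic_levels(deps, finished={'a', 'b'}))   # 'c' collapses to level 0: {'c': 0, 'x': 0, 'y': 1}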

def _update_window(self):
"""Update the active window based on task completion.

Recomputes dynamic levels after each task completion to allow
faster branches to advance independently.
"""
if not self.window_config['enabled']:
return

# RECOMPUTE DYNAMIC LEVELS based on current task states
self.window_state['dynamic_task_levels'] = self._compute_dynamic_levels()

# Find minimum level of active (CREATED/RUNNING) launcher tasks
active_levels = []
for task_name, level in self.window_state['dynamic_task_levels'].items():
state = self.state_manager.get_task_runtime_info(task_name, 'state')
# PRCOMMENT: Are these all the states that need to be considered here?
# WorkGraph uses custom string states, not plumpy's `ProcessState` or aiida-core's `JobState` enums?!
if state in ['CREATED', 'RUNNING']:
active_levels.append(level)

if not active_levels:
# No active tasks - advance the window to the next level that still has pending tasks
old_min = self.window_state['min_active_level']
if self.window_state['dynamic_task_levels']:
max_level = max(self.window_state['dynamic_task_levels'].values())
for level in range(old_min, max_level + 1):
tasks_at_level = [
name for name, lvl in self.window_state['dynamic_task_levels'].items()
if lvl == level
]
has_pending = any(
self.state_manager.get_task_runtime_info(name, 'state')
not in ['FINISHED', 'FAILED', 'SKIPPED']
for name in tasks_at_level
)
if has_pending:
self.window_state['min_active_level'] = level
break
# If no level with pending tasks is found, min_active_level keeps its old value
else:
# Set min_active_level to the minimum level among active tasks
self.window_state['min_active_level'] = min(active_levels)

# Update max_allowed_level
window_size = self.window_config['window_size']
self.window_state['max_allowed_level'] = (
self.window_state['min_active_level'] + window_size
)
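# e.g. window_size=2 with the lowest active launcher at level 3 gives max_allowed_level=5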

def _is_task_in_window(self, task_name: str) -> bool:
"""Check if task is within the active submission window."""
if not self.window_config['enabled']:
return True # No windowing, all tasks allowed

# get_job_data tasks and other non-launcher tasks are always allowed
if not task_name.startswith('launch_'):
return True

# Check dynamic topological level
task_level = self.window_state['dynamic_task_levels'].get(task_name)
if task_level is None:
# Task not in level mapping - allow it
return True

if task_level > self.window_state['max_allowed_level']:
return False # Outside window

# Check max_queued_jobs threshold if configured
if self.window_config.get('max_queued_jobs'):
active_count = self._count_active_jobs()
if active_count >= self.window_config['max_queued_jobs']:
return False # Too many jobs already

return True

def _count_active_jobs(self) -> int:
"""Count tasks in CREATED or RUNNING state."""
count = 0
for task in self.process.wg.tasks:
state = self.state_manager.get_task_runtime_info(task.name, 'state')
if state in ['CREATED', 'RUNNING']:
count += 1
return count

def set_task_results(self) -> None:
from node_graph.config import BUILTIN_TASKS

@@ -105,7 +313,25 @@ def continue_workgraph(self) -> None:
Resume the WorkGraph by looking for tasks that are ready to run.
"""
# self.process.report("Continue workgraph.")

# Initialize window state if not already done (lazy initialization)
self._init_window_state()

# Update window state if rolling window is enabled
if self.window_config.get('enabled'):
self._update_window()
# Report window state
if self.window_state['dynamic_task_levels']:
active_count = self._count_active_jobs()
max_level = max(self.window_state['dynamic_task_levels'].values())
self.process.report(
f"Window: levels {self.window_state['min_active_level']}-"
f"{self.window_state['max_allowed_level']} (max dynamic level: {max_level}), "
f"active jobs: {active_count}"
)
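# Example report (illustrative values): "Window: levels 2-4 (max dynamic level: 7), active jobs: 3"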

task_to_run = []
skipped_by_window = []
for task in self.process.wg.tasks:
# update task state
if (
@@ -123,9 +349,15 @@ def continue_workgraph(self) -> None:
continue
ready, _ = self.state_manager.is_task_ready_to_run(task.name)
if ready:
task_to_run.append(task.name)
# Check if task is within active window
if self._is_task_in_window(task.name):
task_to_run.append(task.name)
else:
skipped_by_window.append(task.name)
#
self.process.report('tasks ready to run: {}'.format(','.join(task_to_run)))
if skipped_by_window:
self.process.report('tasks skipped (outside window): {}'.format(','.join(skipped_by_window)))
self.run_tasks(task_to_run)

def should_run_task(self, task: 'Task') -> bool:
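Taken together, a rough sketch of how a user would opt in to the rolling window from the WorkGraph side. Only the `window_config` extras key and its field names come from the code above; the task names, the hand-written dependency map, and the usual build/submit flow are assumptions for illustration:

from aiida_workgraph import WorkGraph

wg = WorkGraph('windowed_run')
# ... add launch_* tasks and link them as usual ...

wg.extras['window_config'] = {
    'enabled': True,
    'window_size': 2,        # levels allowed past the lowest active level
    'max_queued_jobs': 50,   # optional cap on CREATED/RUNNING tasks
    'task_dependencies': {   # task_name -> list of parent task names (hand-written here)
        'launch_step1': [],
        'launch_step2': ['launch_step1'],
        'launch_step3': ['launch_step2'],
    },
}
wg.submit()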
3 changes: 3 additions & 0 deletions src/aiida_workgraph/workgraph.py
@@ -53,6 +53,7 @@ def __init__(
self.max_number_jobs = 1000000
self.max_iteration = 1000000
self._error_handlers = error_handlers or {}
self.extras = {} # Initialize extras dict for custom metadata
self.analyzer = GraphAnalysis(self)

def to_engine_inputs(self, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
@@ -232,6 +233,7 @@ def to_dict(self, include_sockets: bool = False, should_serialize: bool = False)
'restart_process': self.restart_process.pk if self.restart_process else None,
'max_iteration': self.max_iteration,
'max_number_jobs': self.max_number_jobs,
'extras': getattr(self, 'extras', {}), # Serialize extras dict
}
)
# save error handlers
@@ -331,6 +333,7 @@ def from_dict(cls, wgdata: Dict[str, Any]) -> 'WorkGraph':
'max_iteration',
'max_number_jobs',
'connectivity',
'extras', # Restore extras dict (for window_config)
]:
if key in wgdata:
setattr(wg, key, wgdata[key])
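A quick round-trip check that the new extras field is persisted with the graph, as the to_dict/from_dict changes above intend (a sketch; the bare constructor call is an assumption):

from aiida_workgraph import WorkGraph

wg = WorkGraph('demo')
wg.extras['window_config'] = {'enabled': True, 'window_size': 2}

data = wg.to_dict()
wg2 = WorkGraph.from_dict(data)
assert wg2.extras['window_config']['window_size'] == 2  # extras restored with the graph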