MultiTaskGNN is a research framework that predicts tool-use workflow graphs directly from natural-language queries. Unlike Large Language Models (LLMs) that generate plans token-by-token (autoregressively), this project treats planning as a one-shot graph link prediction task.
- Efficiency: Reduces planning complexity from $O(N)$ (text generation) to $O(1)$ (graph inference).
- Dataset: Includes a pipeline to generate robust synthetic workflows across 4 domains.
- Unified Architecture: A single GNN that jointly predicts workflow architecture, tool selection, and edge connectivity.
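To make the one-shot framing concrete, the sketch below scores every ordered pair of candidate tools in a single pass instead of emitting a plan step by step. It is an illustration only: the per-node "outgoing" and "incoming" vectors, the dot-product decoder, the `bias` value, and the function name `predict_edges` are all assumptions chosen for the example, not the project's actual decoder.

```python
import math

def predict_edges(node_vectors, bias=-2.0, threshold=0.5):
    """Score all ordered node pairs at once and keep the likely edges.

    node_vectors maps a tool name to (out_vec, in_vec): hypothetical
    embeddings for the node acting as an edge source and as an edge target.
    """
    names = list(node_vectors)
    predicted = []
    for src in names:
        out_vec, _ = node_vectors[src]
        for dst in names:
            if src == dst:
                continue
            _, in_vec = node_vectors[dst]
            # Asymmetric score, so src -> dst can differ from dst -> src.
            logit = sum(a * b for a, b in zip(out_vec, in_vec)) + bias
            prob = 1.0 / (1.0 + math.exp(-logit))
            if prob > threshold:
                predicted.append((src, dst))
    return predicted

# Hand-picked toy vectors (not learned) for three tools.
node_vectors = {
    "fetch_logs": ([2.0, 0.0], [0.0, 0.0]),
    "detect_anomalies": ([0.0, 2.0], [2.0, 0.0]),
    "generate_report": ([0.0, 0.0], [0.0, 2.0]),
}
print(predict_edges(node_vectors))
```

The number of pairs grows with the number of candidate tools, but no pair depends on a previously generated token, which is the property the efficiency claim relies on.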
# 1. Clone the repository
git clone <repo_url>
cd MultiTaskGNN
# 2. Install dependencies
pip install -r requirements.txt

Requirements: torch, torch-geometric, sentence-transformers, scikit-learn.
This project uses a two-step generation pipeline to create high-quality synthetic data.
Step 1 generates 10,000 JSON tasks with domain-specific queries (Finance, DevOps, etc.) and architectural patterns.
python src/agent_graphs/data/generate_synthetic_tasks_robust.py

- Output: `data/synthetic_tasks_robust_10k.json`
Step 2 converts the raw task descriptions into node/edge lists suitable for GNN training.
python src/agent_graphs/data/generate_graphs_from_synthetic.py

- Output: `data/graph_tasks_robust_10k.json` (used for training)
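The conversion from a task description to node/edge lists can be pictured with the sketch below. The task schema is not documented here, so the field names (`query`, `steps`, `tool`, `depends_on`) and the function `task_to_graph` are hypothetical; the real files may be organized differently.

```python
def task_to_graph(task):
    """Turn a task description into a node list and an index-based edge list.

    Assumes a hypothetical schema in which each step names one tool and
    lists the tools it depends on.
    """
    nodes = [step["tool"] for step in task["steps"]]
    index = {name: i for i, name in enumerate(nodes)}
    edges = [
        [index[dep], index[step["tool"]]]  # edge runs dependency -> dependent
        for step in task["steps"]
        for dep in step["depends_on"]
    ]
    return {"query": task["query"], "nodes": nodes, "edges": edges}

example_task = {
    "query": "Analyze audit logs, detect anomalies, and report results.",
    "steps": [
        {"tool": "fetch_logs", "depends_on": []},
        {"tool": "fetch_metrics", "depends_on": []},
        {"tool": "detect_anomalies", "depends_on": ["fetch_logs", "fetch_metrics"]},
        {"tool": "generate_report", "depends_on": ["detect_anomalies"]},
    ],
}
graph = task_to_graph(example_task)
print(graph["nodes"])
print(graph["edges"])
```

Index pairs of this kind are the usual starting point for building an edge index tensor in a graph library.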
To replicate the main results (AUC > 0.99), run the multitask training script.
python run_training_multitask.py

- Output: Saves the best model to `models/multitask/model_multitask.pt`.
To measure how much the Graph Neural Network layers contribute, run the ablation suite. This compares the Full Model against an Embeddings-Only baseline and a Blind Baseline.
python run_ablations.py

Expected Output (Low-Data Regime):

| Variant | Arch Acc | Edge AUC |
|---|---|---|
| Full Model (GNN) | 0.966 | 0.989 |
| Embeddings Only | 0.933 | 0.904 |
| Blind Baseline | 0.906 | 0.928 |
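For readers unfamiliar with the Edge AUC column, the sketch below computes ROC AUC from scratch as the fraction of (true edge, non-edge) pairs in which the true edge receives the higher score. It assumes Edge AUC means ROC AUC over candidate edges; the project lists scikit-learn as a requirement and may compute the metric with it instead. The labels and scores are made up for illustration.

```python
def roc_auc(labels, scores):
    """ROC AUC via pairwise comparison of positive and negative scores."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))

# Illustrative candidate edges: 1 = edge exists in the reference workflow.
labels = [1, 1, 0, 0, 0]
scores = [0.9, 0.4, 0.5, 0.2, 0.1]
print(roc_auc(labels, scores))  # 5 of 6 positive/negative pairs are ordered correctly
```

A value of 0.5 corresponds to random ranking and 1.0 to every true edge outscoring every non-edge.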
You can run the trained model on new, unseen natural language queries. The system outputs a JSON plan and a Mermaid.js visualization.
python run_inference_multitask.py --query "Analyze audit logs, detect anomalies, and report results."

Output (`plan_mermaid.md`):

flowchart TD
user["User Query"] --> agent["Main Agent"]
agent --> fetch_logs
agent --> fetch_metrics
fetch_logs --> detect_anomalies
fetch_metrics --> detect_anomalies
detect_anomalies --> generate_report
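A flowchart like the one above can be produced from a predicted edge list with a few lines of string formatting. The sketch below is one possible rendering, not the project's exporter: the function name `to_mermaid` and the rule that the agent links to every tool with no incoming edge are assumptions inferred from the example output.

```python
def to_mermaid(tool_edges):
    """Render (src, dst) tool edges as Mermaid flowchart text."""
    lines = ["flowchart TD", 'user["User Query"] --> agent["Main Agent"]']
    targets = {dst for _, dst in tool_edges}
    # Root tools have no incoming edge; attach them to the agent node.
    roots = []
    for src, _ in tool_edges:
        if src not in targets and src not in roots:
            roots.append(src)
    lines += [f"agent --> {root}" for root in roots]
    lines += [f"{src} --> {dst}" for src, dst in tool_edges]
    return "\n".join(lines)

tool_edges = [
    ("fetch_logs", "detect_anomalies"),
    ("fetch_metrics", "detect_anomalies"),
    ("detect_anomalies", "generate_report"),
]
print(to_mermaid(tool_edges))
```

With these three edges the function emits the same seven lines as the sample plan.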
- `src/agent_graphs/models`: GNN definitions (SAGEConv layers).
- `src/agent_graphs/data`: Data generators and PyG dataset loaders.
- `src/agent_graphs/training`: Training loops and loss functions.
- `run_ablations.py`: Main reproduction script for the paper's claims.
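As background on the SAGEConv layers mentioned above, the sketch below implements the mean-aggregation update in plain Python: each node combines a linear map of its own features with a linear map of the mean of its in-neighbors' features. It omits the bias term and nonlinearity, uses hand-written weights, and is not the code in `src/agent_graphs/models`; the function name `sage_mean_layer` is hypothetical.

```python
def matvec(weights, vec):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weights]

def sage_mean_layer(features, edges, w_self, w_neigh):
    """One GraphSAGE-style layer with mean aggregation.

    features: list of per-node feature vectors.
    edges: list of (src, dst) index pairs; messages flow src -> dst.
    """
    dim = len(features[0])
    neighbors = [[] for _ in features]
    for src, dst in edges:
        neighbors[dst].append(src)
    out = []
    for v, x_v in enumerate(features):
        if neighbors[v]:
            mean = [
                sum(features[u][k] for u in neighbors[v]) / len(neighbors[v])
                for k in range(dim)
            ]
        else:
            mean = [0.0] * dim  # isolated node: no neighbor message
        out.append([a + b for a, b in zip(matvec(w_self, x_v), matvec(w_neigh, mean))])
    return out

identity = [[1.0, 0.0], [0.0, 1.0]]
features = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
edges = [(0, 2), (1, 2)]
updated = sage_mean_layer(features, edges, identity, identity)
print(updated)
```

In this toy graph only node 2 has incoming edges, so it is the only node whose features change.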