
nbashyal/AGENT-GRAPH-PLANNER


MultiTaskGNN: Learning to Plan with Tools

MultiTaskGNN is a research framework that predicts tool-use workflow graphs directly from natural-language queries. Large Language Models (LLMs) generate plans token by token (autoregressively). This project instead treats planning as a one-shot graph link-prediction task, so the whole workflow is predicted in a single pass rather than decoded step by step.
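
The following plain-Python sketch shows what one-shot link prediction can look like. It assumes the GNN has already produced a source vector and a target vector for every candidate tool node. The values below are hand-picked toy numbers, not learned ones. The sketch scores every ordered pair with the sigmoid of their dot product and keeps pairs above a 0.5 threshold. The names `predict_edges`, `SRC`, and `DST` are hypothetical, and the repository's actual decoder may differ.

```python
import math
from itertools import permutations


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def predict_edges(src_emb, dst_emb, threshold=0.5):
    """Score all ordered node pairs in one pass and keep those above threshold.

    Hypothetical decoder: separate source/target vectors make the score
    asymmetric, so u -> v and v -> u can get different probabilities.
    """
    edges = []
    for u, v in permutations(src_emb, 2):
        score = sum(a * b for a, b in zip(src_emb[u], dst_emb[v]))
        if sigmoid(score) >= threshold:
            edges.append((u, v))
    return edges


# Toy embeddings chosen by hand to encode fetch_logs -> detect_anomalies -> generate_report.
SRC = {
    "fetch_logs": [2.0, 0.0, 0.0],
    "detect_anomalies": [0.0, 2.0, 0.0],
    "generate_report": [0.0, 0.0, 2.0],
}
DST = {
    "fetch_logs": [-1.0, -1.0, -1.0],
    "detect_anomalies": [1.0, -1.0, -1.0],
    "generate_report": [-1.0, 1.0, -1.0],
}

print(predict_edges(SRC, DST))
# [('fetch_logs', 'detect_anomalies'), ('detect_anomalies', 'generate_report')]
```

Each pair is scored independently, so the plan comes out of one batched computation instead of a token-by-token loop. The number of scored pairs still grows quadratically with the number of candidate nodes.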

Key Contributions

  1. Efficiency: Reduces the number of sequential inference steps from $O(N)$ (autoregressive text generation, with $N$ the plan length in tokens) to $O(1)$ (a single graph-inference pass).
  2. Dataset: Includes a pipeline to generate robust synthetic workflows across 4 domains.
  3. Unified Architecture: A single GNN that jointly predicts workflow architecture, tool selection, and edge connectivity.

1. Setup & Installation

# 1. Clone the repository
git clone <repo_url>
cd AGENT-GRAPH-PLANNER   # directory created by git clone; adjust if your clone path differs

# 2. Install dependencies
pip install -r requirements.txt

Requirements: torch, torch-geometric, sentence-transformers, scikit-learn.
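
Three of the four packages are imported under a different name than the one passed to `pip`. The following standard-library check reports which requirements cannot be imported in the current environment. `missing_packages` is a hypothetical helper and is not part of the repository.

```python
import importlib.util

# pip distribution name -> module name used in `import`
REQUIRED = {
    "torch": "torch",
    "torch-geometric": "torch_geometric",
    "sentence-transformers": "sentence_transformers",
    "scikit-learn": "sklearn",
}


def missing_packages(required: dict[str, str] = REQUIRED) -> list[str]:
    """Return the pip names of requirements whose module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("Missing: " + ", ".join(missing) if missing else "All requirements importable.")
```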


2. Data Generation (Reproduce the Dataset)

This project uses a two-step generation pipeline to create high-quality synthetic data.

Step A: Generate Raw Tasks

Generates 10,000 JSON tasks with domain-specific queries (Finance, DevOps, etc.) and architectural patterns.

python src/agent_graphs/data/generate_synthetic_tasks_robust.py
  • Output: data/synthetic_tasks_robust_10k.json
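
The raw record schema is not documented here. As a purely hypothetical illustration of "domain-specific query plus architectural pattern", the sketch below samples a domain and a pattern and fills in a query template. Every field name, domain vocabulary, and pattern label in it is an assumption, not the format written by `generate_synthetic_tasks_robust.py`.

```python
import json
import random

# Hypothetical vocabulary: domain -> ordered tool chain. The real generator's
# domains, tools, and patterns are not documented in this README.
DOMAINS = {
    "Finance": ["fetch_transactions", "detect_fraud", "generate_report"],
    "DevOps": ["fetch_logs", "detect_anomalies", "generate_report"],
}
PATTERNS = ["sequential", "parallel"]


def make_task(rng: random.Random, task_id: int) -> dict:
    """Sample one raw task record (hypothetical schema)."""
    domain = rng.choice(sorted(DOMAINS))
    tools = DOMAINS[domain]
    verbs = [t.replace("_", " ") for t in tools]
    return {
        "id": task_id,
        "domain": domain,
        "pattern": rng.choice(PATTERNS),
        "query": f"{verbs[0].capitalize()}, {verbs[1]}, and {verbs[2]}.",
        "tools": tools,
    }


def generate_tasks(n: int, seed: int = 0) -> list[dict]:
    rng = random.Random(seed)  # a fixed seed keeps the synthetic set reproducible
    return [make_task(rng, i) for i in range(n)]


if __name__ == "__main__":
    print(json.dumps(generate_tasks(2), indent=2))
```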

Step B: Convert to Graph Format

Converts the raw task descriptions into node/edge lists suitable for GNN training.

python src/agent_graphs/data/generate_graphs_from_synthetic.py
  • Output: data/graph_tasks_robust_10k.json (Used for training)
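
Conceptually, this step turns each record into a node list and an edge list of the kind a PyTorch Geometric `Data` object expects. In that format, `edge_index` is a 2 x num_edges array of (source, target) node ids.

The sketch below rests on several assumptions:

- The record has the hypothetical fields `tools` and `pattern`.
- `pattern` is either a `"sequential"` chain or a `"parallel"` fan-in onto the last tool.
- The graph includes the `user` and `agent` nodes seen in the inference output.

The real converter's rules may differ.

```python
def task_to_graph(task: dict) -> dict:
    """Convert one raw task record into node and edge lists (hypothetical rules)."""
    tools = task["tools"]
    nodes = ["user", "agent"] + tools
    index = {name: i for i, name in enumerate(nodes)}
    edges = [("user", "agent")]
    if task["pattern"] == "sequential":
        # agent -> t0 -> t1 -> ... -> t_last
        edges.append(("agent", tools[0]))
        edges += list(zip(tools, tools[1:]))
    elif task["pattern"] == "parallel":
        # agent fans out to every tool except the last, which they all feed into
        *workers, sink = tools
        edges += [("agent", w) for w in workers]
        edges += [(w, sink) for w in workers]
    else:
        raise ValueError(f"unknown pattern: {task['pattern']!r}")
    # COO layout: row 0 holds source ids, row 1 holds target ids (PyG edge_index style)
    edge_index = [[index[s] for s, _ in edges], [index[t] for _, t in edges]]
    return {"nodes": nodes, "edges": edges, "edge_index": edge_index}


graph = task_to_graph({
    "tools": ["fetch_logs", "fetch_metrics", "detect_anomalies"],
    "pattern": "parallel",
})
print(graph["edge_index"])  # [[0, 1, 1, 2, 3], [1, 2, 3, 4, 4]]
```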

3. Training & Experiments

A. Train the Full Model

To replicate the main results (AUC > 0.99), run the multitask training script.

python run_training_multitask.py
  • Output: Saves the best model to models/multitask/model_multitask.pt.
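
The training loss is not spelled out here. The model jointly predicts architecture, tool selection, and edge connectivity, so a natural (assumed) objective is a weighted sum of three terms:

- softmax cross-entropy for the architecture class,
- binary cross-entropy for the multi-label tool set,
- binary cross-entropy for each candidate edge.

The plain-Python sketch below uses equal weights by default. In PyTorch, the same terms correspond to `torch.nn.functional.cross_entropy` and `torch.nn.functional.binary_cross_entropy_with_logits`. `multitask_loss` and its arguments are hypothetical.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """Softmax cross-entropy for one example, via a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def bce_with_logits(logits: list[float], labels: list[int]) -> float:
    """Mean binary cross-entropy on raw logits, in the numerically stable form."""
    total = 0.0
    for x, y in zip(logits, labels):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def multitask_loss(arch_logits, arch_target, tool_logits, tool_labels,
                   edge_logits, edge_labels, weights=(1.0, 1.0, 1.0)) -> float:
    """Hypothetical joint objective: weighted sum of the three task losses."""
    w_arch, w_tool, w_edge = weights
    return (w_arch * cross_entropy(arch_logits, arch_target)
            + w_tool * bce_with_logits(tool_logits, tool_labels)
            + w_edge * bce_with_logits(edge_logits, edge_labels))


# With all-zero logits, each term equals log 2, so the total is 3 * log 2.
print(round(multitask_loss([0.0, 0.0], 0, [0.0], [1], [0.0, 0.0], [1, 0]), 4))  # 2.0794
```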

B. Run Ablation Studies (Table 3)

To measure the contribution of the Graph Neural Network layers, run the ablation suite. It compares the Full Model against an Embeddings-Only baseline and a Blind baseline.

python run_ablations.py

Expected Output (Low-Data Regime):

Variant            Arch. Accuracy   Edge AUC
Full Model (GNN)   0.966            0.989
Embeddings Only    0.933            0.904
Blind Baseline     0.906            0.928
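
Edge AUC is presumably the ROC AUC of predicted edge scores against the true adjacency. The README does not say how non-edges are sampled. ROC AUC equals the probability that a randomly chosen true edge outscores a randomly chosen non-edge, with ties counting one half. The quadratic-time sketch below computes that directly. `sklearn.metrics.roc_auc_score` returns the same value, and `roc_auc` here is a hypothetical helper.

```python
def roc_auc(labels: list[int], scores: list[float]) -> float:
    """ROC AUC as P(score of a positive > score of a negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# True edges scored 0.8 and 0.3; non-edges scored 0.6 and 0.1: 3 of 4 pairs ranked correctly.
print(roc_auc([1, 0, 1, 0], [0.8, 0.6, 0.3, 0.1]))  # 0.75
```

An AUC of 0.5 corresponds to a random ranking, and 1.0 means every true edge outscores every non-edge.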

4. Inference & Visualization

You can run the trained model on unseen natural-language queries. The system outputs a JSON plan and a Mermaid.js visualization.

python run_inference_multitask.py --query "Analyze audit logs, detect anomalies, and report results."

Output (plan_mermaid.md):

flowchart TD
    user["User Query"] --> agent["Main Agent"]
    agent --> fetch_logs
    agent --> fetch_metrics
    fetch_logs --> detect_anomalies
    fetch_metrics --> detect_anomalies
    detect_anomalies --> generate_report
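
Rendering a predicted edge list in this format is simple string assembly. The sketch below attaches a display label the first time a node appears, and it reproduces the diagram above from its edge list. `to_mermaid` is a hypothetical helper, and the repository's exporter may format the output differently.

```python
def to_mermaid(edges, labels=None):
    """Render (src, dst) pairs as a Mermaid top-down flowchart."""
    labels = labels or {}
    declared = set()

    def ref(node):
        # Attach the display label only on the node's first appearance.
        if node in labels and node not in declared:
            declared.add(node)
            return f'{node}["{labels[node]}"]'
        return node

    lines = ["flowchart TD"]
    for src, dst in edges:
        lines.append(f"    {ref(src)} --> {ref(dst)}")
    return "\n".join(lines)


plan = [
    ("user", "agent"),
    ("agent", "fetch_logs"),
    ("agent", "fetch_metrics"),
    ("fetch_logs", "detect_anomalies"),
    ("fetch_metrics", "detect_anomalies"),
    ("detect_anomalies", "generate_report"),
]
print(to_mermaid(plan, {"user": "User Query", "agent": "Main Agent"}))
```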

5. Project Structure

  • src/agent_graphs/models: GNN Definitions (SAGEConv layers).
  • src/agent_graphs/data: Data generators and PyG dataset loaders.
  • src/agent_graphs/training: Training loops and loss functions.
  • run_ablations.py: Main reproduction script for the paper's claims.
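
The models are described as using SAGEConv layers. In PyTorch Geometric's default configuration, SAGEConv updates each node as $x_i' = W_1 x_i + W_2 \cdot \mathrm{mean}_{j \in \mathcal{N}(i)} x_j$. The plain-Python sketch below computes one such layer on lists so the arithmetic is visible. The repository's layer count, hidden sizes, and activations are not specified here, and `sage_mean_layer` is a hypothetical name.

```python
def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def sage_mean_layer(x, edges, W_root, W_neigh, bias=None):
    """One GraphSAGE layer with mean aggregation.

    x: per-node feature vectors; edges: (src, dst) pairs, where src sends a
    message to dst. Nodes with no incoming edges aggregate to a zero vector.
    """
    n, d = len(x), len(x[0])
    sums = [[0.0] * d for _ in range(n)]
    counts = [0] * n
    for src, dst in edges:
        counts[dst] += 1
        sums[dst] = [a + b for a, b in zip(sums[dst], x[src])]
    out = []
    for i in range(n):
        mean = [s / counts[i] for s in sums[i]] if counts[i] else [0.0] * d
        h = [a + b for a, b in zip(matvec(W_root, x[i]), matvec(W_neigh, mean))]
        if bias is not None:
            h = [a + b for a, b in zip(h, bias)]
        out.append(h)
    return out


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(sage_mean_layer(x, [(0, 2), (1, 2)], identity, identity))
# [[1.0, 0.0], [0.0, 1.0], [1.5, 1.5]]
```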

About

GNN as a multi-agent orchestrator
