-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph.py
More file actions
103 lines (77 loc) · 3.55 KB
/
graph.py
File metadata and controls
103 lines (77 loc) · 3.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import sys
import os
# Make root importable when running graph.py directly
sys.path.insert(0, os.path.dirname(__file__))
from langgraph.graph import StateGraph, END
from state import FAQAgentState
from nodes.generate_node import generate_node
from nodes.validate_node import validate_node
from nodes.load_question_node import load_question_node
from nodes.save_result_node import save_result_node
# Maximum generation + validation attempts per question.
# route_after_validation stops looping back to "generate_node" once
# state["retry_count"] reaches this value.
MAX_RETRIES = 3
# ──────────────────────────────────────────────
# Conditional routing functions
# ──────────────────────────────────────────────
def route_after_validation(state: FAQAgentState) -> str:
    """
    Decide where the graph goes after the validation step.

    Reads ``state.get("validation")`` and ``state["retry_count"]`` and
    returns the name of the next node:
    - "save_result"   → the validation is truthy and its ``is_valid`` is
                        truthy, or ``retry_count`` has reached MAX_RETRIES
    - "generate_node" → otherwise (another attempt)

    A one-line "[Router]" trace describing the decision is printed.
    """
    # Imported locally; only needed for the annotation below.
    from state import ValidationResult

    verdict: ValidationResult | None = state.get("validation")
    attempts = state["retry_count"]

    if verdict and verdict.is_valid:
        target = "save_result"
        trace = f"[Router] Answer VALID after {attempts} attempt(s) → saving result"
    elif attempts >= MAX_RETRIES:
        target = "save_result"
        trace = f"[Router] Max retries ({MAX_RETRIES}) reached → saving with no_answer"
    else:
        target = "generate_node"
        trace = f"[Router] Answer invalid (attempt {attempts}/{MAX_RETRIES}) → retrying"

    print(trace)
    return target
def route_after_save(state: FAQAgentState) -> str:
    """
    Pick the next step once a result has been saved.

    Compares ``state["current_index"]`` with the number of entries in
    ``state["questions"]``:
    - "load_question" → the index is still below the question count
    - END             → otherwise; a closing "[Router]" line is printed
    """
    position = state["current_index"]
    total = len(state["questions"])

    if not position < total:
        print("\n[Router] All questions processed → ending graph.")
        return END

    return "load_question"
def _save_graph_image(compiled, image_path) -> bool:
    """
    Best-effort export of the compiled graph's Mermaid diagram to a PNG file.

    The diagram is a debugging aid only. ``draw_mermaid_png()`` renders via a
    remote service by default (see the LangGraph visualization docs), and the
    target directory may not be writable, so any failure here is reported and
    swallowed instead of propagating to the caller.

    Args:
        compiled: The object returned by ``StateGraph.compile()``.
        image_path: Destination file path for the PNG bytes.

    Returns:
        True if the image was rendered and written, False otherwise.
    """
    try:
        png_bytes = compiled.get_graph().draw_mermaid_png()
        with open(image_path, "wb") as image_file:
            image_file.write(png_bytes)
    except Exception as exc:  # optional artifact: never block graph creation
        print(f"[Graph] Skipping diagram export to {image_path!r}: {exc}")
        return False
    return True


def build_graph(image_path: "str | None" = "graph_image.png") -> StateGraph:
    """
    Constructs and compiles the FAQ Reflection Agent LangGraph.

    Topology:
        load_question → generate_node → validate_node
        validate_node → generate_node | save_result   (route_after_validation)
        save_result   → load_question | END           (route_after_save)

    Args:
        image_path: Where to write a PNG rendering of the graph. Defaults to
            "graph_image.png" in the current working directory. Pass ``None``
            to skip rendering. Rendering is best-effort: a failure prints a
            warning and does not prevent the graph from being returned.

    Returns:
        The compiled graph produced by ``graph.compile()``, ready for
        invocation. (The annotation says ``StateGraph``; the value is the
        result of ``compile()``, not the builder itself.)
    """
    graph = StateGraph(FAQAgentState)

    # ── Register nodes ─────────────────────────
    graph.add_node("load_question", load_question_node)
    graph.add_node("generate_node", generate_node)
    graph.add_node("validate_node", validate_node)
    graph.add_node("save_result", save_result_node)

    # ── Entry point ────────────────────────────
    graph.set_entry_point("load_question")

    # ── Static edges ───────────────────────────
    graph.add_edge("load_question", "generate_node")
    graph.add_edge("generate_node", "validate_node")

    # ── Conditional edge: after validation ─────
    graph.add_conditional_edges(
        "validate_node",
        route_after_validation,
        {
            "generate_node": "generate_node",  # retry loop
            "save_result": "save_result",      # valid or exhausted
        },
    )

    # ── Conditional edge: after saving result ──
    graph.add_conditional_edges(
        "save_result",
        route_after_save,
        {
            "load_question": "load_question",  # next question
            END: END,                          # finished all questions
        },
    )

    compiled = graph.compile()

    # ── Optional diagram export ────────────────
    # Kept out of the critical path: this function runs at module import
    # (see faq_graph), so a rendering failure must not break the import.
    if image_path is not None:
        _save_graph_image(compiled, image_path)

    return compiled
# ── Singleton compiled graph (imported by main.py) ──
# NOTE: build_graph() is called at import time, so importing this module
# compiles the graph and also attempts to render and write graph_image.png.
faq_graph = build_graph()