-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathagent_cli.py
More file actions
144 lines (119 loc) · 4.99 KB
/
agent_cli.py
File metadata and controls
144 lines (119 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# python agent_cli.py [test_name]
# (c) Terence Lim
import asyncio
import argparse
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from pydantic_ai.exceptions import UnexpectedModelBehavior
import logfire
from agent_delegation import attach_research_delegation_tools
from shared_agents import create_agents
from utils import OUTPUT
import logging
logging.basicConfig(level=logging.ERROR)
# Configure logging
#logfire.configure()
#logfire.instrument_pydantic_ai()
# Directory containing "<name>.query" test prompt files.
TESTS = Path('tests')
# Number of repeated runs per test query in single-query (batch) mode.
K = 5
# Maximum retry attempts after a model-behavior failure.
RETRIES = 2
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the CLI.

    Returns:
        Namespace with a single optional ``test`` attribute: the stem of a
        test prompt/outfile, or ``None`` to run interactively.
    """
    cli = argparse.ArgumentParser(description="Run the Quant Research Agent CLI")
    cli.add_argument("test", nargs="?", help="Optional name of test prompt and outfile")
    return cli.parse_args()
def append_evaluation_log(response: str, evaluation_stem: str, mode: str, **kwargs) -> None:
    """Write one JSON-encoded evaluation record to ``<stem>.responses``.

    Args:
        response: The agent's reply text to record.
        evaluation_stem: File stem under ``OUTPUT``; ".responses" is appended.
        mode: File open mode ("w" truncates, "a" appends).
        **kwargs: Extra key/value pairs merged into the record (they win on
            key collisions, matching dict-union semantics).
    """
    record = {"date": str(datetime.now())[:19], "response": response}
    record.update(kwargs)
    out_path = OUTPUT / (evaluation_stem + ".responses")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open(mode, encoding="utf-8") as fh:
        fh.write(json.dumps(record) + "\n")
def build_conversation_context(messages: list[dict[str, str]], max_messages: int = 50) -> str:
    """Render the trailing conversation turns as an XML-ish prompt block.

    Only the last ``max_messages`` entries are included; each message is
    wrapped in a ``<message>`` tag carrying its 1-based index and role,
    with the content inside a fenced ```text block.

    Args:
        messages: Conversation turns, each a {"role": ..., "content": ...} dict.
        max_messages: Maximum number of trailing messages to include.

    Returns:
        A single newline-joined string from ``<conversation_history>`` to
        ``</conversation_history>``.
    """
    # Fix: the original guarded the slice with a length check, but
    # ``messages[-max_messages:]`` already returns the whole list when it
    # is shorter than the window — the conditional was redundant.
    lines = ["<conversation_history>"]
    for i, m in enumerate(messages[-max_messages:], start=1):
        lines.append(f" <message index=\"{i}\" role=\"{m['role']}\">")
        lines.append("```text")
        lines.append(m["content"])
        lines.append("```")
        lines.append(" </message>")
    lines.append("</conversation_history>")
    return "\n".join(lines)
async def run_agent_safely(agent: Any, query: str, retries: int = RETRIES) -> str:
    """Run ``agent`` on ``query``, retrying transient model-behavior errors.

    Args:
        agent: Any object exposing an async ``run(query)`` method.
        query: Prompt text to send to the agent.
        retries: Additional attempts allowed after the first
            ``UnexpectedModelBehavior`` failure.

    Returns:
        The agent's output string, or a human-readable failure message.
        This function never raises; all errors become return values.
    """
    attempts = 0
    while True:
        try:
            response = await agent.run(query)
            # Some agent implementations return a plain string rather than
            # a result object with an ``output`` attribute.
            return response.output if hasattr(response, "output") else str(response)
        except UnexpectedModelBehavior:
            attempts += 1
            if attempts > retries:
                return "The agent could not complete the request due to model behavior errors."
            # Brief backoff before retrying the same query.
            await asyncio.sleep(0.4)
        except Exception as exc:  # pragma: no cover - defensive fallback
            # Fix: the original printed the raw query to stdout (a debug
            # leftover) with no error detail; log the failure with its
            # traceback instead so interactive output stays clean.
            logging.exception("Agent failed on query: %s", query)
            return f"The agent failed with an unexpected error: {exc}"
async def main() -> None:
    """Entry point: interactive chat loop, or K batch runs of a named test query."""
    args = parse_args()
    bundle = create_agents()
    research_agent = bundle["research_agent"]
    report_agent = bundle["report_agent"]
    coding_agent = bundle["coding_agent"]
    # Running transcript of the conversation; re-rendered into each prompt.
    messages: list[dict[str, str]] = []
    def on_instruction(role: str, message: str) -> None:
        # Record a delegated instruction as a conversation turn.
        messages.append({"role": role, "content": message})
    def on_result(role: str, message: str) -> None:
        # Record a delegated result as a conversation turn.
        messages.append({"role": role, "content": message})
    # Wire delegation tools so the research agent can hand work to the
    # report/coding agents while sharing this conversation's context.
    attach_research_delegation_tools(
        research_agent=research_agent,
        report_agent=report_agent,
        coding_agent=coding_agent,
        build_context=lambda: build_conversation_context(messages),
        store_conversation=lambda _: None,  # persistence intentionally disabled in CLI
        on_instruction=on_instruction,
        on_result=on_result,
        run_agent_safely=lambda a, q, _role, retries=RETRIES: run_agent_safely(a, q, retries=retries),
    )
    if args.test is None:
        # Interactive mode: chat until the user types 'exit' or 'quit'.
        print("Quant Research Agent CLI (type 'exit' to quit)")
        while True:
            user_input = input("\nYou: ").strip()
            if user_input.lower() in {"exit", "quit"}:
                break
            messages.append({"role": "user", "content": user_input})
            prompt = build_conversation_context(messages)
            reply = await run_agent_safely(research_agent, prompt)
            messages.append({"role": "assistant", "content": reply})
            print(f"\nAgent: {reply}")
        return
    # Batch mode: load the query, run it K times, log each response.
    print(f"Quant Research Agent CLI (single-query mode: running {K} times)")
    with open(TESTS / (args.test + ".query"), "r", encoding="utf-8") as prompt_file:
        query = prompt_file.read().strip()
    label = f"Running query from {args.test}.query:"
    print(f"\n{'-'*len(label)}\n{label}\n{'-'*len(label)}\n{query}\n")
    mode = 'a' ### 'w'
    for run_index in range(1, K + 1):
        # Each run starts a fresh single-turn conversation.
        messages = [{"role": "user", "content": query}]
        prompt = build_conversation_context(messages)
        kwargs = {"start": str(datetime.now())[:19]}
        for tries in range(RETRIES + 1):
            reply = await run_agent_safely(research_agent, prompt)
            # NOTE(review): `any(word not in reply ...)` accepts the reply
            # (breaks) unless BOTH "MODEL" and "ERROR" appear in it — if the
            # intent was "retry when any error marker is present", this
            # should likely be `if not any(word in reply ...)`. Confirm.
            if any(word not in reply for word in ["MODEL", "ERROR"]):
                break
            print(f'Retrying {tries}/{RETRIES}')
            # TODO: handle when [high demand Spike] all in reply
        kwargs['retries'] = tries
        messages.append({"role": "assistant", "content": reply})
        append_evaluation_log(reply, args.test, mode=mode, **kwargs)
        ### mode = 'a'
        print(f"\nAgent [{run_index}/{K}]: {reply}")
# Script entry: run the async main loop.
if __name__ == "__main__":
    asyncio.run(main())