-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask_registry.py
More file actions
164 lines (150 loc) · 5.41 KB
/
task_registry.py
File metadata and controls
164 lines (150 loc) · 5.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
import json
import sqlite3
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
# DDL script run by TaskRegistry._init_db each time a registry is constructed;
# the IF NOT EXISTS clauses make re-running it against an existing database safe.
#
# Column notes, as the columns are used by TaskRegistry in this file:
#   allowed_agents_json  JSON-encoded list of agent names (json.dumps on write,
#                        json.loads on read).
#   status               written as 'pending', 'claimed', 'completed' or 'failed'.
#   claimed_by           agent name recorded by claim_task; NULL when unclaimed.
#   created_at, claim_started_at, finished_at
#                        time.time() values, i.e. epoch seconds as floats.
#
# idx_tasks_status_created matches the "WHERE status = ... ORDER BY created_at"
# lookup issued by list_claimable_tasks.
SCHEMA = """
CREATE TABLE IF NOT EXISTS tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_chat_id TEXT NOT NULL,
source_message_id TEXT,
source_user_id TEXT,
source_text TEXT NOT NULL,
category TEXT NOT NULL,
route_reason TEXT NOT NULL,
allowed_agents_json TEXT NOT NULL,
status TEXT NOT NULL,
claimed_by TEXT,
claim_started_at REAL,
finished_at REAL,
result_summary TEXT,
error_text TEXT,
created_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_tasks_status_created ON tasks(status, created_at);
"""
class TaskRegistry:
    """SQLite-backed registry of routed tasks and their claim lifecycle.

    A task is inserted as ``pending``, moved to ``claimed`` by
    :meth:`claim_task`, and ends as ``completed`` (:meth:`finish_task`) or
    ``failed`` (:meth:`fail_task`).  :meth:`requeue_stale_claims` pushes old
    claims back to ``pending``.  Every method opens its own short-lived
    connection and closes it before returning.
    """

    def __init__(self, db_path: str) -> None:
        """Create the parent directory of ``db_path`` and ensure the schema exists.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows can be read by column name."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _transaction(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed (or rolled back) and then closed.

        ``with conn:`` on a sqlite3 connection only ends the transaction; it
        does not close the connection, so the close is done explicitly here
        to avoid leaking one connection per call.
        """
        conn = self._connect()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Run the module-level ``SCHEMA`` script against the database."""
        with self._transaction() as conn:
            conn.executescript(SCHEMA)

    def create_task(
        self,
        *,
        source_chat_id: str,
        source_message_id: Optional[str],
        source_user_id: Optional[str],
        source_text: str,
        category: str,
        route_reason: str,
        allowed_agents: List[str],
    ) -> int:
        """Insert a new ``pending`` task and return its row id.

        Args:
            source_chat_id: Stored in the ``source_chat_id`` column.
            source_message_id: Stored as-is; the column is nullable.
            source_user_id: Stored as-is; the column is nullable.
            source_text: Stored in the ``source_text`` column.
            category: Stored in the ``category`` column.
            route_reason: Stored in the ``route_reason`` column.
            allowed_agents: Agent names permitted to work on the task;
                persisted as a JSON array.

        Returns:
            The ``id`` assigned to the inserted row.
        """
        created_at = time.time()
        with self._transaction() as conn:
            cursor = conn.execute(
                """
                INSERT INTO tasks (
                    source_chat_id, source_message_id, source_user_id, source_text,
                    category, route_reason, allowed_agents_json, status, created_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)
                """,
                (
                    source_chat_id,
                    source_message_id,
                    source_user_id,
                    source_text,
                    category,
                    route_reason,
                    json.dumps(allowed_agents, ensure_ascii=True),
                    created_at,
                ),
            )
            return int(cursor.lastrowid)

    def get_task(self, task_id: int) -> Optional[Dict[str, Any]]:
        """Return the task with ``task_id`` as a dict, or ``None`` if absent."""
        with self._transaction() as conn:
            row = conn.execute(
                "SELECT * FROM tasks WHERE id = ?", (task_id,)
            ).fetchone()
        return self._row_to_task(row) if row else None

    def list_claimable_tasks(self, agent_name: str, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the oldest pending tasks that ``agent_name`` may work on.

        The allowed-agent list is stored as JSON, so the per-agent filter is
        applied in Python.  ``limit`` is applied *after* that filter: limiting
        in SQL first would return nothing whenever the oldest ``limit``
        pending tasks all belong to other agents.

        Args:
            agent_name: Agent whose name must appear in a task's allowed list.
            limit: Maximum number of tasks to return.  ``0`` returns an empty
                list; a negative value means no upper bound (the meaning
                SQLite gives a negative ``LIMIT``).

        Returns:
            Matching tasks ordered by ``created_at`` and then ``id``.
        """
        if limit == 0:
            return []
        claimable: List[Dict[str, Any]] = []
        with self._transaction() as conn:
            cursor = conn.execute(
                "SELECT * FROM tasks WHERE status = 'pending' "
                "ORDER BY created_at ASC, id ASC"
            )
            for row in cursor:
                task = self._row_to_task(row)
                if agent_name not in task["allowed_agents"]:
                    continue
                claimable.append(task)
                # Never true for a negative limit, which therefore scans all rows.
                if len(claimable) == limit:
                    break
        return claimable

    def claim_task(self, task_id: int, agent_name: str) -> bool:
        """Atomically move a pending task to ``claimed`` for ``agent_name``.

        The update is conditional on the row still being ``pending``, so only
        one of several competing claimers succeeds.  The task's allowed-agent
        list is not checked here.

        Returns:
            ``True`` if this call claimed the task, ``False`` otherwise.
        """
        claim_started_at = time.time()
        with self._transaction() as conn:
            cursor = conn.execute(
                """
                UPDATE tasks
                SET status = 'claimed', claimed_by = ?, claim_started_at = ?
                WHERE id = ? AND status = 'pending'
                """,
                (agent_name, claim_started_at, task_id),
            )
            return cursor.rowcount == 1

    def requeue_stale_claims(self, agent_name: str, stale_secs: int) -> List[int]:
        """Return ``agent_name``'s stale claims to ``pending``.

        A claim is stale when it started more than ``stale_secs`` seconds ago
        (negative values are treated as ``0``).

        Returns:
            Ids of the tasks actually requeued, oldest claim first.
        """
        stale_before = time.time() - max(0, stale_secs)
        requeued: List[int] = []
        with self._transaction() as conn:
            rows = conn.execute(
                """
                SELECT id
                FROM tasks
                WHERE status = 'claimed'
                  AND claimed_by = ?
                  AND claim_started_at IS NOT NULL
                  AND claim_started_at < ?
                ORDER BY claim_started_at ASC
                """,
                (agent_name, stale_before),
            ).fetchall()
            for row in rows:
                task_id = int(row["id"])
                # Re-check the claim inside the UPDATE: the task may have been
                # finished or re-claimed since the SELECT, and resetting it by
                # id alone would discard that newer state.
                cursor = conn.execute(
                    """
                    UPDATE tasks
                    SET status = 'pending', claimed_by = NULL, claim_started_at = NULL
                    WHERE id = ?
                      AND status = 'claimed'
                      AND claimed_by = ?
                      AND claim_started_at < ?
                    """,
                    (task_id, agent_name, stale_before),
                )
                if cursor.rowcount == 1:
                    requeued.append(task_id)
        return requeued

    def finish_task(self, task_id: int, agent_name: str, result_summary: str) -> None:
        """Mark a task claimed by ``agent_name`` as ``completed``.

        Nothing is changed unless the task is currently ``claimed`` by
        ``agent_name``; the status check keeps an already completed or failed
        task from being overwritten.
        """
        with self._transaction() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = 'completed', finished_at = ?, result_summary = ?
                WHERE id = ? AND claimed_by = ? AND status = 'claimed'
                """,
                (time.time(), result_summary, task_id, agent_name),
            )

    def fail_task(self, task_id: int, agent_name: str, error_text: str) -> None:
        """Mark a task claimed by ``agent_name`` as ``failed``.

        Nothing is changed unless the task is currently ``claimed`` by
        ``agent_name``; the status check keeps an already completed or failed
        task from being overwritten.
        """
        with self._transaction() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = 'failed', finished_at = ?, error_text = ?
                WHERE id = ? AND claimed_by = ? AND status = 'claimed'
                """,
                (time.time(), error_text, task_id, agent_name),
            )

    @staticmethod
    def _row_to_task(row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a row to a dict, decoding ``allowed_agents_json``.

        The decoded value is stored under ``allowed_agents`` and the raw
        ``allowed_agents_json`` key is removed.
        """
        task = dict(row)
        task["allowed_agents"] = json.loads(task.pop("allowed_agents_json"))
        return task