-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtools.py
More file actions
333 lines (292 loc) · 12.1 KB
/
tools.py
File metadata and controls
333 lines (292 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""the four tools. bash, read, write, edit. nothing else."""
import asyncio
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from logger import log
# strips ANSI escape sequences from captured command output:
#   CSI: ESC [ <params 0x30-0x3F>* <intermediates 0x20-0x2F>* <final 0x40-0x7E>
#        (colors, cursor moves, ?25h/?25l, and ~-terminated keys like ESC[2~)
#   OSC: ESC ] ... terminated by BEL or by ST (ESC \), e.g. terminal titles.
# the OSC body excludes ESC so an ST-terminated sequence can't run on and
# swallow ordinary text up to some later BEL.
_ANSI = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]|\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)")
MAX_OUTPUT_BYTES = 100_000  # ~100KB; longer output keeps only its last MAX_OUTPUT_BYTES bytes
DEFAULT_READ_LINES = 2000  # default line cap for read(). match pi/claude code.
# pre-flight guardrail: patterns we refuse to run. each entry is (regex, reason).
# order matters — first match wins. most rules match only at command-name
# positions (start of the string or right after |, &, ;) so words inside quoted
# arguments don't trigger them; the curl rule matches anywhere in the command.
_CMD_HEAD = r"(?:^|[|&;])\s*"  # command begins here
_BANNED: tuple[tuple[re.Pattern[str], str], ...] = (
    (re.compile(_CMD_HEAD + r"find\s+"),
     "use `rg --files` (respects .gitignore) instead of `find`"),
    # any short-flag cluster containing r/R: -r, -R, -rn, -nR, ...
    (re.compile(_CMD_HEAD + r"grep\s+-[a-zA-Z]*[rR]"),
     "use `rg 'pattern'` instead of `grep -r`"),
    # any short-flag cluster containing R: -R, -lR, -Rla, ... (lowercase -r is
    # ls's reverse-sort flag, not recursion, so it is not matched)
    (re.compile(_CMD_HEAD + r"ls\s+-[a-zA-Z]*R"),
     "use `rg --files` instead of `ls -R`"),
    (re.compile(_CMD_HEAD + r"cat\s+\S+\.(?:py|js|ts|tsx|jsx|go|rs|rb|java|c|cc|cpp|h|hpp|md|json|ya?ml|toml)\b"),
     "use the `read` tool (not `cat`) for code/config files"),
    (re.compile(r"\bcurl\s+[^|]*\|\s*(?:sh|bash)\b"),
     "refusing: `curl | sh` is unsafe"),
    # recursive rm (-r, -rf, -fr, -Rf, ...) whose target is exactly /, /*, ~,
    # ~user, ~/, ~/* or *. the trailing lookahead requires the target to end
    # there, so an ordinary absolute path like /tmp/build is not mistaken for /.
    (re.compile(_CMD_HEAD + r"rm\s+-[a-zA-Z]*[rR][a-zA-Z]*\s+(?:/\*?|~\w*/?\*?|\*)(?=\s|$|[;&|])"),
     "refusing: `rm -rf` on a dangerous target (/, ~, *)"),
)
def vet(cmd: str) -> str | None:
    """return the refusal hint of the first _BANNED pattern found in cmd.

    patterns are tried in _BANNED order; None means the command may run."""
    return next((reason for rx, reason in _BANNED if rx.search(cmd)), None)
def _clean(raw: bytes) -> str:
    """bytes -> str: utf-8 decode (bad bytes become U+FFFD), drop ANSI
    escapes via _ANSI, then turn CRLF and lone CR into LF."""
    stripped = _ANSI.sub("", raw.decode("utf-8", errors="replace"))
    # CRLF must be handled before lone CR, or it would become two newlines
    for line_break in ("\r\n", "\r"):
        stripped = stripped.replace(line_break, "\n")
    return stripped
def _truncate(text: str, limit: int = MAX_OUTPUT_BYTES) -> tuple[str, bool]:
    """keep at most the last `limit` utf-8 bytes of text. returns (text, truncated).

    when truncating, a header line stating how many bytes were dropped and how
    many are shown is prepended. the cut point is moved forward past UTF-8
    continuation bytes so the kept tail starts on a character boundary (no
    U+FFFD from a split character). a limit <= 0 keeps no bytes.
    """
    data = text.encode("utf-8")
    if len(data) <= limit:
        return text, False
    start = len(data) - max(limit, 0)
    # continuation bytes look like 0b10xxxxxx; skipping them (at most 3) lands
    # on the first byte of the next whole character.
    while start < len(data) and (data[start] & 0xC0) == 0x80:
        start += 1
    kept = data[start:].decode("utf-8", errors="replace")
    shown = len(data) - start
    header = f"[... truncated {start} bytes; showing last {shown} ...]\n"
    return header + kept, True
@dataclass
class ToolResult:
output: str
is_error: bool = False
truncated: bool = False
timed_out: bool = False
aborted: bool = False
# Diff information for file operations
diff_info: dict | None = None # {path: str, old_content: str, new_content: str, operation: str}
async def _kill(proc: asyncio.subprocess.Process) -> None:
    """stop proc: terminate(), wait up to 0.5s, then kill() and wait.

    no-op when the process has already exited. never raises: a process that
    disappears before it can be signalled (ProcessLookupError) is ignored."""
    if proc.returncode is not None:
        return
    grace_period = 0.5
    try:
        proc.terminate()
        exited = True
        try:
            await asyncio.wait_for(proc.wait(), timeout=grace_period)
        except asyncio.TimeoutError:
            exited = False
        if not exited:
            # ignored the polite request; force it
            proc.kill()
            await proc.wait()
    except ProcessLookupError:
        pass
async def bash(cmd: str, cwd: str = ".", timeout: float | None = None) -> ToolResult:
    """run a shell command. combined stdout+stderr, ANSI stripped, tail-truncated.

    pre-flight: vets against banned patterns (find, grep -r, etc.) and returns a
    [blocked] error result without running anything.
    timeout: seconds, None waits indefinitely. on timeout the process is killed
    and whatever it printed before then is returned, followed by a timeout marker.
    honors asyncio cancellation: the process is killed and CancelledError is
    re-raised, so no result is produced for an aborted run.
    failure to start the shell (e.g. cwd does not exist) is returned as an
    error result instead of raising.
    """
    hint = vet(cmd)
    if hint is not None:
        log("cmd_blocked", {"cmd": cmd[:200], "hint": hint})
        return ToolResult(output=f"[blocked] {hint}", is_error=True)

    try:
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
            cwd=cwd,
        )
    except OSError as e:
        return ToolResult(output=f"[error] {e}", is_error=True)

    chunks: list[bytes] = []

    async def _drain() -> None:
        # read incrementally rather than via communicate() so that whatever
        # the command printed before a timeout is still held in `chunks`.
        reader = proc.stdout
        if reader is not None:
            while chunk := await reader.read(65536):
                chunks.append(chunk)
        await proc.wait()

    timed_out = False
    try:
        await asyncio.wait_for(_drain(), timeout=timeout)
    except asyncio.TimeoutError:
        timed_out = True
        await _kill(proc)
    except asyncio.CancelledError:
        await _kill(proc)
        raise  # propagate cancellation after cleanup

    output, truncated = _truncate(_clean(b"".join(chunks)))
    rc = proc.returncode
    is_error = timed_out or rc not in (0, None)
    # anthropic rejects tool_result with is_error=true and empty content, and
    # the model only sees `output` — so always state why an error run ended.
    if timed_out:
        if output:
            sep = "" if output.endswith("\n") else "\n"
            output = f"{output}{sep}[timeout after {timeout}s]"
        else:
            output = f"[timeout after {timeout}s, no output]"
    elif is_error and not output:
        output = f"[exit {rc}, no output]"
    return ToolResult(
        output=output,
        is_error=is_error,
        truncated=truncated,
        timed_out=timed_out,
    )
def _looks_binary(sample: bytes) -> bool:
    """guess whether sample comes from a binary file: any NUL byte => binary.

    read() passes the first 8KB of the file."""
    return sample.find(b"\x00") != -1
async def read(
    path: str,
    cwd: str = ".",
    offset: int = 1,
    limit: int | None = None,
) -> ToolResult:
    """read a text file. 1-indexed line offset + optional limit.

    refuses binary files (NUL byte in the first 8KB). defaults to the first
    DEFAULT_READ_LINES lines. each output row is `{lineno:6d}|{text}`, with a
    trailing marker when lines remain past the returned window.
    lines are split on LF only (a trailing CR is dropped), so numbering agrees
    with the line numbers edit() reports.
    errors (bad offset/limit, missing path, directory, binary, offset past
    the end, OS errors on open) are returned as error results.
    """
    try:
        offset = int(offset)
        if limit is not None:
            limit = int(limit)
    except (TypeError, ValueError):
        return ToolResult(output="[error] offset and limit must be integers", is_error=True)
    if limit is not None and limit < 1:
        return ToolResult(output="[error] limit must be >= 1", is_error=True)

    abs_path = _resolve(path, cwd)
    if not abs_path.exists():
        return ToolResult(output=f"[error] not found: {path}", is_error=True)
    if abs_path.is_dir():
        return ToolResult(output=f"[error] is a directory: {path}", is_error=True)
    try:
        # one open: sniff the head, and only read the rest if it looks like text
        with abs_path.open("rb") as f:
            head = f.read(8192)
            if _looks_binary(head):
                return ToolResult(output=f"[error] binary file: {path}", is_error=True)
            data = head + f.read()
    except OSError as e:
        return ToolResult(output=f"[error] {e}", is_error=True)

    text = data.decode("utf-8", errors="replace")
    lines = text.split("\n")
    if lines[-1] == "":
        lines.pop()  # a trailing newline ends the last line; it doesn't start a new one
    lines = [line.removesuffix("\r") for line in lines]
    total = len(lines)

    start = max(0, offset - 1)
    if total and start >= total:
        return ToolResult(
            output=f"[error] offset {offset} is past end of file ({total} lines): {path}",
            is_error=True,
        )
    cap = limit if limit is not None else DEFAULT_READ_LINES
    end = min(total, start + cap)
    numbered = "\n".join(
        f"{lineno:6d}|{line}" for lineno, line in enumerate(lines[start:end], start=start + 1)
    )
    truncated = end < total
    if truncated:
        numbered += f"\n[... showing lines {start + 1}-{end} of {total}; use offset/limit for more]"
    return ToolResult(output=numbered or "[empty file]", truncated=truncated)
# one asyncio.Lock per resolved absolute path (str key), created lazily on first
# lookup. write() and edit() hold it around their read-then-write, so concurrent
# mutations of the same file run one after another.
_file_locks: dict[str, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
def _resolve(path: str, cwd: str) -> Path:
    """absolute, resolved Path for path; relative paths are taken against cwd."""
    candidate = Path(path)
    if not candidate.is_absolute():
        candidate = Path(cwd) / candidate
    return candidate.resolve()
async def write(path: str, content: str, cwd: str = ".") -> ToolResult:
    """create or overwrite a file with content (utf-8). creates parent dirs.

    serialized per absolute path via _file_locks. on success, diff_info carries
    the previous content ("" when the file is new or could not be read as utf-8
    text) alongside the new content. failures are returned as error results;
    content that cannot be utf-8 encoded is rejected before the file is opened.
    """
    abs_path = _resolve(path, cwd)
    try:
        # validate and measure up front: opening for write truncates the file,
        # so an encode failure inside write_text() would wipe existing content.
        n_bytes = len(content.encode("utf-8"))
    except UnicodeEncodeError as e:
        return ToolResult(output=f"[error] {e}", is_error=True)

    async with _file_locks[str(abs_path)]:
        # existence check and old-content read happen under the lock; otherwise
        # a concurrent write/edit could land in between and the diff would lie.
        is_new_file = not abs_path.exists()
        old_content = ""
        if not is_new_file:
            try:
                old_content = abs_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError):
                # can't read the old content: report as a new file
                is_new_file = True
        try:
            abs_path.parent.mkdir(parents=True, exist_ok=True)
            abs_path.write_text(content, encoding="utf-8")
        except OSError as e:
            return ToolResult(output=f"[error] {e}", is_error=True)

    return ToolResult(
        output=f"wrote {n_bytes} bytes to {path}",
        diff_info={
            "path": path,
            "old_content": old_content,
            "new_content": content,
            "operation": "write",
            "is_new_file": is_new_file,
        },
    )
async def edit(
    path: str,
    old_text: str | None = None,
    new_text: str | None = None,
    cwd: str = ".",
    edits: list[dict] | None = None,
) -> ToolResult:
    """find/replace on a single file. two modes:
      single edit: pass old_text + new_text. old_text must match exactly once.
      batch edits: pass edits=[{old_text, new_text}, ...] applied sequentially
                   against an in-memory buffer. all-or-nothing: if any edit
                   fails, no changes are written.
    "exactly once" counts overlapping matches too ("aa" is ambiguous in "aaa").
    reported line numbers are 1-based and refer to the buffer as it stood when
    each edit was applied (i.e. after the edits before it).
    serialized per absolute path via _file_locks. validation failures and OS
    errors on read/write are returned as error results rather than raised.
    """
    # normalize input shape: collapse to a single internal `ops` list.
    single_given = old_text is not None or new_text is not None
    if edits is not None and single_given:
        return ToolResult(
            output="[error] provide either (old_text, new_text) or edits[], not both",
            is_error=True,
        )
    if edits is None:
        if not single_given:
            return ToolResult(
                output="[error] no edits given: pass old_text+new_text or edits[]",
                is_error=True,
            )
        if old_text is None or new_text is None:
            return ToolResult(
                output="[error] single-edit mode requires both old_text and new_text",
                is_error=True,
            )
        ops: list[dict] = [{"old_text": old_text, "new_text": new_text}]
    else:
        if not isinstance(edits, list) or not edits:
            return ToolResult(output="[error] edits[] is empty", is_error=True)
        ops = edits

    # validate each op shape up front so we fail before opening the file.
    for i, op in enumerate(ops):
        if not isinstance(op, dict) or "old_text" not in op or "new_text" not in op:
            return ToolResult(
                output=f"[error] edits[{i}]: must be an object with old_text and new_text",
                is_error=True,
            )
        ot, nt = op["old_text"], op["new_text"]
        if not isinstance(ot, str) or not isinstance(nt, str):
            return ToolResult(
                output=f"[error] edits[{i}]: old_text and new_text must be strings",
                is_error=True,
            )
        if ot == nt:
            return ToolResult(
                output=f"[error] edits[{i}]: old_text and new_text are identical",
                is_error=True,
            )

    abs_path = _resolve(path, cwd)
    async with _file_locks[str(abs_path)]:
        if not abs_path.exists():
            return ToolResult(output=f"[error] not found: {path}", is_error=True)
        try:
            original_content = abs_path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            return ToolResult(output=f"[error] binary file: {path}", is_error=True)
        except OSError as e:  # e.g. a directory or no read permission
            return ToolResult(output=f"[error] {e}", is_error=True)

        buffer = original_content
        applied_lines: list[int] = []
        for i, op in enumerate(ops):
            ot, nt = op["old_text"], op["new_text"]
            if not ot and buffer:
                # "" matches between every character; only an empty file has a
                # single unambiguous position for it.
                return ToolResult(
                    output=f"[error] edits[{i}]: old_text is empty; include the existing text to replace",
                    is_error=True,
                )
            idx = buffer.find(ot)
            if idx == -1:
                prior = len(applied_lines)
                ctx = f" ({prior} prior edit{'s' if prior != 1 else ''} applied to buffer)" if prior else ""
                return ToolResult(
                    output=f"[error] edits[{i}]: old_text not found in {path}{ctx}",
                    is_error=True,
                )
            if buffer.find(ot, idx + 1) != -1:
                # a second start position exists; str.count() skips overlapping
                # matches and can report 1 here, so never report fewer than 2.
                count = max(buffer.count(ot), 2)
                return ToolResult(
                    output=(
                        f"[error] edits[{i}]: old_text matches {count} times in {path}; "
                        "make it unique with more surrounding context"
                    ),
                    is_error=True,
                )
            applied_lines.append(buffer.count("\n", 0, idx) + 1)
            buffer = buffer[:idx] + nt + buffer[idx + len(ot):]

        # all edits succeeded against the buffer; only now touch the file.
        try:
            # encode first: write_text() truncates the file before encoding, so
            # an unencodable new_text would otherwise leave it empty.
            buffer.encode("utf-8")
            abs_path.write_text(buffer, encoding="utf-8")
        except (OSError, UnicodeEncodeError) as e:
            return ToolResult(output=f"[error] {e}", is_error=True)

    if len(applied_lines) == 1:
        result_msg = f"edited {path} at line {applied_lines[0]}"
    else:
        lines_str = ", ".join(str(n) for n in applied_lines)
        result_msg = f"edited {path}: {len(applied_lines)} changes at lines {lines_str}"
    return ToolResult(
        output=result_msg,
        diff_info={
            "path": path,
            "old_content": original_content,
            "new_content": buffer,
            "operation": "edit",
        },
    )