|
1 | 1 | """ |
2 | | -Yutori n1 Sampling Loop |
| 2 | +Yutori n1.5 Sampling Loop |
3 | 3 |
|
4 | | -Implements the agent loop for Yutori's n1-latest computer use model. |
5 | | -n1-latest uses an OpenAI-compatible API with tool_calls: |
| 4 | +Implements the agent loop for Yutori's n1.5-latest computer use model. |
| 5 | +n1.5-latest uses an OpenAI-compatible API with tool_calls: |
6 | 6 | - Actions are returned via tool_calls in the assistant message |
7 | 7 | - Tool results use role: "tool" with matching tool_call_id |
8 | 8 | - The model stops by returning content without tool_calls |
9 | 9 | - Coordinates are returned in 1000x1000 space and need scaling |
10 | 10 |
|
11 | | -@see https://docs.yutori.com/reference/n1 |
| 11 | +@see https://docs.yutori.com/reference/n1-5 |
12 | 12 | """ |
13 | 13 |
|
| 14 | +import copy |
14 | 15 | import json |
15 | 16 | from typing import Any, Optional |
16 | 17 |
|
17 | 18 | from kernel import Kernel |
18 | 19 | from openai import OpenAI |
19 | 20 |
|
20 | | -from tools import ComputerTool, N1Action, ToolResult |
| 21 | +from tools import ComputerTool, N15Action, ToolResult |
| 22 | + |
| 23 | +# Tools that require a Playwright page / DOM access. The default core tool set |
| 24 | +# already excludes them, but we also list them in `disable_tools` so the |
| 25 | +# exclusion is explicit and survives if the default ever changes. |
| 26 | +DISABLED_TOOLS = ["extract_elements", "find", "set_element_value", "execute_js"] |
| 27 | +TOOL_SET = "browser_tools_core-20260403" |
| 28 | + |
| 29 | +# Screenshot-trimming defaults mirror Yutori's reference loop: |
| 30 | +# https://github.com/yutori-ai/yutori-sdk-python/blob/main/yutori/navigator/payload.py |
| 31 | +# Trimming is size-triggered — we only drop old screenshots when the payload |
| 32 | +# exceeds MAX_REQUEST_BYTES, and we always keep at least KEEP_RECENT_SCREENSHOTS. |
| 33 | +MAX_REQUEST_BYTES = 9_500_000 |
| 34 | +KEEP_RECENT_SCREENSHOTS = 6 |
21 | 35 |
|
22 | 36 |
|
23 | 37 | async def sampling_loop( |
24 | 38 | *, |
25 | | - model: str = "n1-latest", |
| 39 | + model: str = "n1.5-latest", |
26 | 40 | task: str, |
27 | 41 | api_key: str, |
28 | 42 | kernel: Kernel, |
@@ -63,12 +77,23 @@ async def sampling_loop( |
63 | 77 | iteration += 1 |
64 | 78 | print(f"\n=== Iteration {iteration} ===") |
65 | 79 |
|
| 80 | + request_messages, dropped = _trimmed_for_request(conversation_messages) |
| 81 | + if dropped: |
| 82 | + print(f"Trimmed {dropped} old screenshot(s) to fit request size limit") |
| 83 | + |
66 | 84 | try: |
67 | 85 | response = client.chat.completions.create( |
68 | 86 | model=model, |
69 | | - messages=conversation_messages, |
| 87 | + messages=request_messages, |
70 | 88 | max_completion_tokens=max_completion_tokens, |
71 | 89 | temperature=0.3, |
| 90 | + # n1.5-specific knobs go in extra_body. |
| 91 | + # tool_set selects the core (coordinate-based) tools. |
| 92 | + # disable_tools is a defense-in-depth exclusion of DOM/Playwright tools. |
| 93 | + extra_body={ |
| 94 | + "tool_set": TOOL_SET, |
| 95 | + "disable_tools": DISABLED_TOOLS, |
| 96 | + }, |
72 | 97 | ) |
73 | 98 | except Exception as api_error: |
74 | 99 | print(f"API call failed: {api_error}") |
@@ -108,7 +133,7 @@ async def sampling_loop( |
108 | 133 | }) |
109 | 134 | continue |
110 | 135 |
|
111 | | - action: N1Action = {"action_type": action_name, **args} |
| 136 | + action: N15Action = {"action_type": action_name, **args} |
112 | 137 | print(f"Executing action: {action_name}", args) |
113 | 138 |
|
114 | 139 | scaled_action = _scale_coordinates(action, viewport_width, viewport_height) |
@@ -155,7 +180,86 @@ async def sampling_loop( |
155 | 180 | } |
156 | 181 |
|
157 | 182 |
|
158 | | -def _scale_coordinates(action: N1Action, viewport_width: int, viewport_height: int) -> N1Action: |
def _trimmed_for_request(
    messages: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], int]:
    """Return a deep copy of ``messages`` with old screenshots stripped to fit MAX_REQUEST_BYTES.

    Images are dropped oldest-first. Image-bearing messages outside the most
    recent KEEP_RECENT_SCREENSHOTS ones are trimmed first; if the payload is
    still too large, the protected window is trimmed as well, but the newest
    image-bearing message is never touched. A message that holds several
    images is drained one image at a time until the payload fits, so a
    multi-image message cannot keep the request over the limit while it
    still has strippable screenshots.

    Args:
        messages: Full conversation history. It is never mutated; all
            stripping happens on a deep copy.

    Returns:
        A ``(trimmed, removed)`` tuple: the deep-copied list to send, and the
        number of images stripped from it. ``trimmed`` may still exceed
        MAX_REQUEST_BYTES when nothing further can be stripped.
    """
    trimmed = copy.deepcopy(messages)
    size = _estimate_size(trimmed)
    if size <= MAX_REQUEST_BYTES:
        return trimmed, 0

    image_indices = [i for i, m in enumerate(trimmed) if _message_has_image(m)]
    if not image_indices:
        return trimmed, 0

    # max(1, ...) guarantees the newest image-bearing message is always protected.
    protected = set(image_indices[-max(1, KEEP_RECENT_SCREENSHOTS):])
    last_idx = image_indices[-1]

    # Trim order: everything outside the protected window first (oldest to
    # newest), then the protected window itself, excluding the latest message.
    candidates = [i for i in image_indices if i not in protected]
    candidates += [i for i in image_indices if i in protected and i != last_idx]

    removed = 0
    for idx in candidates:
        # Drain this message until the payload fits or it has no images left.
        while size > MAX_REQUEST_BYTES and _strip_one_image(trimmed[idx]):
            removed += 1
            size = _estimate_size(trimmed)
        if size <= MAX_REQUEST_BYTES:
            break

    return trimmed, removed
| 225 | + |
| 226 | + |
| 227 | +def _estimate_size(messages: list[dict[str, Any]]) -> int: |
| 228 | + return len(json.dumps(messages, separators=(",", ":"), ensure_ascii=False).encode("utf-8")) |
| 229 | + |
| 230 | + |
| 231 | +def _message_has_image(msg: dict[str, Any]) -> bool: |
| 232 | + content = msg.get("content") |
| 233 | + if not isinstance(content, list): |
| 234 | + return False |
| 235 | + return any(isinstance(p, dict) and p.get("type") == "image_url" for p in content) |
| 236 | + |
| 237 | + |
| 238 | +def _strip_one_image(msg: dict[str, Any]) -> bool: |
| 239 | + content = msg.get("content") |
| 240 | + if not isinstance(content, list): |
| 241 | + return False |
| 242 | + |
| 243 | + removed = False |
| 244 | + new_content: list[dict[str, Any]] = [] |
| 245 | + for part in content: |
| 246 | + if not removed and isinstance(part, dict) and part.get("type") == "image_url": |
| 247 | + removed = True |
| 248 | + continue |
| 249 | + new_content.append(part) |
| 250 | + |
| 251 | + if not removed: |
| 252 | + return False |
| 253 | + |
| 254 | + has_text = any(isinstance(p, dict) and p.get("type") == "text" for p in new_content) |
| 255 | + if not has_text: |
| 256 | + new_content.append({"type": "text", "text": "Screenshot omitted to stay under request size limit."}) |
| 257 | + |
| 258 | + msg["content"] = new_content |
| 259 | + return True |
| 260 | + |
| 261 | + |
| 262 | +def _scale_coordinates(action: N15Action, viewport_width: int, viewport_height: int) -> N15Action: |
159 | 263 | scaled = dict(action) |
160 | 264 |
|
161 | 265 | if "coordinates" in scaled and scaled["coordinates"]: |
|
0 commit comments