# feat: add ask_user tool with host-side LLM interception (#1)
Base branch: `scale-customizations`
New file — an agent config that adds the `ask_user` tool (`@@ -0,0 +1,49 @@`; the filename is not captured in the scrape):

```yaml
# Config with ask_user tool for handling underspecified tasks
# Based on tool_use.yaml

agent:
  templates:
    system_template: |-
      You are a helpful assistant that can interact with a computer to solve tasks.
    instance_template: |-
      <uploaded_files>
      {{working_dir}}
      </uploaded_files>
      I've uploaded a python code repository in the directory {{working_dir}}. Consider the following PR description:

      <pr_description>
      {{problem_statement}}
      </pr_description>

      Can you help me implement the necessary changes to the repository so that the requirements specified in the <pr_description> are met?
      I've already taken care of all changes to any of the test files described in the <pr_description>. This means you DON'T have to modify the testing logic or any of the tests in any way!
      Your task is to make the minimal changes to non-test files in the {{working_dir}} directory to ensure the <pr_description> is satisfied.
      Follow these steps to resolve the issue:
      1. As a first step, it might be a good idea to find and read code relevant to the <pr_description>
      2. Create a script to reproduce the error and execute it with `python <filename.py>` using the bash tool, to confirm the error
      3. Edit the source code of the repo to resolve the issue
      4. Rerun your reproduce script and confirm that the error is fixed!
      5. Think about edge cases and make sure your fix handles them as well

      IMPORTANT: Your output will be checked by an auto-grader looking for exact answers.
      This task may be missing critical information.
      Use the ask_user tool to ask me for any missing details.

      Your thinking should be thorough and so it's fine if it's very long.
    next_step_template: |-
      OBSERVATION:
      {{observation}}
    next_step_no_output_template: |-
      Your command ran successfully and did not produce any output.
  tools:
    execution_timeout: 450
    bundles:
      - path: tools/registry
      - path: tools/edit_anthropic
      - path: tools/submit
      - path: tools/ask_user
    env_variables:
      USE_FILEMAP: 'true'
    enable_bash_tool: true
    parse_function:
      type: function_calling
```
`sweagent/agent/agents.py` — updated imports (`@@ -4,10 +4,15 @@`); the PR adds `os`, `re`, `shlex`, `threading`, and `litellm`:

```python
import copy
import json
import logging
import os
import re
import shlex
import threading
import time
from pathlib import Path, PurePosixPath
from typing import Annotated, Any, Literal

import litellm
import yaml
from jinja2 import Template
from pydantic import BaseModel, ConfigDict, Field, model_validator
```
@@ -57,6 +62,183 @@ | |||||
| from sweagent.utils.patch_formatter import PatchFormatter | ||||||
|
|
||||||
|
|
||||||
| # Global task definitions cache for ask_user interception | ||||||
| _task_definitions_lock = threading.Lock() | ||||||
| _task_definitions_cache: dict | None = None | ||||||
| _task_definitions_path: str | None = None | ||||||
|
|
||||||
|
|
||||||
| def _find_task_definitions_file(traj_path: Path | None) -> Path | None: | ||||||
| """Search for task_definitions.json in parent directories of traj_path. | ||||||
|
|
||||||
| Handles different directory structures: | ||||||
| - Single test: output_dir/instance_id/instance_id.traj (2 levels up) | ||||||
| - Benchmark: output_dir/exp_N/instance_id/instance_id.traj (3 levels up) | ||||||
| """ | ||||||
| if traj_path is None: | ||||||
| return None | ||||||
|
|
||||||
| # Search up to 4 parent levels for task_definitions.json | ||||||
| current = traj_path.parent # Start from directory containing .traj | ||||||
| for _ in range(4): | ||||||
| task_def_file = current / "task_definitions.json" | ||||||
| if task_def_file.exists(): | ||||||
| return task_def_file | ||||||
| current = current.parent | ||||||
| return None | ||||||
|
|
||||||
```python
def _load_task_definitions(traj_path: Path | None) -> dict | None:
    """Load task definitions by searching parent directories."""
    global _task_definitions_cache, _task_definitions_path

    task_def_file = _find_task_definitions_file(traj_path)
    if task_def_file is None:
        return None

    task_def_str = str(task_def_file)

    with _task_definitions_lock:
        # Return cached if same file
        if _task_definitions_path == task_def_str and _task_definitions_cache is not None:
            return _task_definitions_cache

        try:
            with open(task_def_file, "r") as f:
                _task_definitions_cache = json.load(f)
            _task_definitions_path = task_def_str
            return _task_definitions_cache
        except Exception:
            pass
    return None
```
```python
def _handle_ask_user_on_host(question: str, context: str, instance_id: str, traj_path: Path | None, logger) -> str:
    """Handle ask_user command on the host side using litellm.

    This function intercepts ask_user calls to run the LLM call on the host,
    which can reach internal API endpoints that the container cannot access.
    """
    task_defs = _load_task_definitions(traj_path)

    if task_defs is None or instance_id not in task_defs:
        return f"Error: No task definition found for instance {instance_id}"

    task_def = task_defs[instance_id]
    primary_task = task_def.get("primary_task", "")
    underspecified_prompt = task_def.get("underspecified_task", "")

    # Extract removed values
    removed_values = []
    if "removed_segments" in task_def:
        for seg in task_def["removed_segments"]:
            if isinstance(seg, dict) and seg.get("value"):
                removed_values.append(seg["value"])

    removed_values_str = ", ".join(removed_values) if removed_values else "None specified"
    underspec_str = underspecified_prompt or "Not provided"

    system_prompt = f"""You are simulating a user who has a task in mind but didn't fully specify it.

The user originally intended to give this COMPLETE prompt:
{primary_task}

But they actually gave this UNDERSPECIFIED version:
{underspec_str}

The parts that were removed/made vague:
{removed_values_str}

An AI assistant (who only sees the underspecified version) is now asking you a clarifying question.

Your job: Compare the two prompts, find what's MISSING from the underspecified version, and provide the EXACT information from the complete prompt.

Guidelines:
- Find the EXACT values that are in the complete prompt but missing from the underspecified one
- Provide those specific values (times, names, dates, numbers, phrases, etc.)
- Be concise - just answer what's asked
- Don't reveal you're a simulation

ENVIRONMENT CONTEXT:
- The agent is working in a repository at /workspace (or the working directory specified in the prompt)
- The agent has full access to the repository files
- The agent can read, write, and execute files in the repository
- When providing file paths, use paths relative to the repository root
"""

    user_prompt = f"The assistant asks: {question}"
    if context:
        user_prompt += f"\n\nContext: {context}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Get API credentials from environment
    api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
    api_base = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL")
    model = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-4.1-2025-04-14")

    if not api_key:
        return "Error: No API key available for user simulation"

    try:
        logger.info(f"ask_user intercepted on host: question='{question[:100]}...'")
        # Drop unsupported params for models like GPT-5 that don't support temperature
        litellm.drop_params = True
```
> **Review comment** on `sweagent/agent/agents.py`, line 189 (`litellm.drop_params = True`):
>
> **Global side-effect on all litellm calls**
>
> `litellm.drop_params` is a module-level global. Setting it to `True` here permanently affects all subsequent `litellm.completion` calls in the process, including the main agent model calls in `models.py:693`. This means unsupported parameters will be silently dropped for the agent's own LLM calls too, which could mask configuration errors or change behavior unexpectedly.
>
> This is especially problematic in multi-worker mode (`ThreadPoolExecutor` in `run_batch.py:288`), where a single `ask_user` call from one worker mutates global state for all workers.
>
> Consider passing `drop_params` as a per-call parameter instead:
>
> ```suggestion
> litellm.drop_params = True  # TODO: this mutates global state; move to per-call param if litellm supports it
> ```
>
> Or scope the change more carefully by saving and restoring the original value.
```python
        response = litellm.completion(
            model=model,
            messages=messages,
            temperature=1,  # GPT-5 only supports temperature=1
            api_key=api_key,
            api_base=api_base,
            timeout=30,
        )
        result = response.choices[0].message.content or ""
        logger.info(f"ask_user response generated: '{result[:100]}...'")
        return result
    except Exception as e:
        logger.error(f"ask_user LLM call failed: {e}")
        return f"Error generating user response: {str(e)}"
```
```python
def _parse_ask_user_command(command: str) -> tuple[str, str] | None:
    """Parse ask_user command to extract question and optional context.

    Handles formats like:
    - ask_user "question"
    - ask_user "question" "context"
    - ask_user 'question'
    - ask_user question_without_quotes

    Returns (question, context) tuple or None if not an ask_user command.
    """
    command = command.strip()
    if not (command == "ask_user" or command.startswith("ask_user ") or command.startswith("ask_user\t")):
        return None

    # Remove the "ask_user" prefix
    args_str = command[8:].strip()
    if not args_str:
        return None

    try:
        # Use shlex to properly parse quoted arguments
        args = shlex.split(args_str)
        if len(args) >= 1:
            question = args[0]
            context = args[1] if len(args) > 1 else ""
            return (question, context)
    except ValueError:
        # Fallback: try simple quote extraction
        match = re.match(r'["\'](.+?)["\'](?:\s+["\'](.+?)["\'])?', args_str)
        if match:
            return (match.group(1), match.group(2) or "")
    # Last resort: treat entire string as question
    return (args_str, "")


class TemplateConfig(BaseModel):
    """This configuration is used to define almost all message templates that are
    formatted by the agent and sent to the LM.
```
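The quoting behavior in `_parse_ask_user_command` comes directly from `shlex.split`. A quick standalone illustration of the happy path (the sample commands are invented for the demo):

```python
import shlex

samples = [
    'ask_user "What timezone should I use?"',
    'ask_user "What timezone?" "The task mentions a deadline"',
    "ask_user 'single quotes work too'",
]

parsed = []
for cmd in samples:
    # Same core logic as the PR: strip the prefix, shlex-split the rest,
    # treat the first token as the question and an optional second as context.
    args = shlex.split(cmd[len("ask_user"):].strip())
    question = args[0]
    context = args[1] if len(args) > 1 else ""
    parsed.append((question, context))
```

`shlex.split` raises `ValueError` on unbalanced quotes, which is why the PR keeps a regex fallback.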
The interception point in `handle_action` (`@@ -943,6 +1125,26 @@ def handle_action(self, step: StepOutput) -> StepOutput:`):

```python
        self._chook.on_action_started(step=step)
        execution_t0 = time.perf_counter()
        run_action: str = self.tools.guard_multiline_input(step.action).strip()

        # Intercept ask_user commands and handle on HOST side
        # This is needed because Modal containers cannot reach internal API endpoints
        ask_user_args = _parse_ask_user_command(run_action)
        if ask_user_args is not None:
            question, context = ask_user_args
            instance_id = self._problem_statement.id if self._problem_statement else "unknown"
            step.observation = _handle_ask_user_on_host(
                question=question,
                context=context,
                instance_id=instance_id,
                traj_path=self.traj_path,
                logger=self.logger,
            )
            step.execution_time = time.perf_counter() - execution_t0
            self._total_execution_time += step.execution_time
            self._chook.on_action_executed(step=step)
            step.state = self.tools.get_state(env=self._env)
            return self.handle_submission(step)

        try:
            step.observation = self._env.communicate(
                input=run_action,
```
> **Review comment** on `_load_task_definitions`:
>
> **Silent exception swallowing hides load failures**
>
> The bare `except Exception: pass` here silently swallows all errors when loading task definitions (malformed JSON, permission errors, encoding issues, etc.). When `ask_user` later fails with "No task definition found for instance X", there will be no indication that the file existed but couldn't be parsed. At minimum, log the exception.