From 18a9b14ce3717df8f4525b72252c2412677f99eb Mon Sep 17 00:00:00 2001
From: yash-scaleai
Date: Wed, 28 Jan 2026 10:45:29 +0000
Subject: [PATCH 1/3] feat: add ask_user tool with host-side LLM interception

Adds an ask_user tool that allows agents to request clarification on
underspecified tasks. The tool simulates user responses using an LLM based
on the complete task definition.

Key changes:
- Add ask_user tool in tools/ask_user/
- Add TaskDefinitionInjectionHook to inject task definitions into containers
- Auto-detect task_definitions.json in run_batch.py
- **Critical fix**: Intercept ask_user commands on HOST side in agents.py

The host-side interception is necessary because Modal containers cannot
reach internal API endpoints (e.g., litellm.ml-serving-internal.scale.com).
By handling the LLM call on the host (where the agent's own LLM calls are
made), we avoid timeout issues while maintaining the same functionality.

Based on Bryan's PR #17144 with architectural modification for internal
API compatibility.

Co-Authored-By: Claude Opus 4.5 --- config/tool_use_with_ask_user.yaml | 49 +++++ sweagent/agent/agents.py | 184 ++++++++++++++++++ .../run/hooks/task_definition_injection.py | 123 ++++++++++++ sweagent/run/run_batch.py | 12 ++ tools/ask_user/bin/ask_user | 172 ++++++++++++++++ tools/ask_user/bin/llm.py | 70 +++++++ tools/ask_user/config.yaml | 13 ++ tools/ask_user/install.sh | 5 + tools/ask_user/requirements.txt | 2 + 9 files changed, 630 insertions(+) create mode 100644 config/tool_use_with_ask_user.yaml create mode 100644 sweagent/run/hooks/task_definition_injection.py create mode 100755 tools/ask_user/bin/ask_user create mode 100644 tools/ask_user/bin/llm.py create mode 100644 tools/ask_user/config.yaml create mode 100644 tools/ask_user/install.sh create mode 100644 tools/ask_user/requirements.txt diff --git a/config/tool_use_with_ask_user.yaml b/config/tool_use_with_ask_user.yaml new file mode 100644 index 0000000000..ebfa8b624b --- /dev/null +++ b/config/tool_use_with_ask_user.yaml @@ -0,0 +1,49 @@ +# Config with ask_user tool for handling underspecified tasks +# Based on tool_use.yaml + +agent: + templates: + system_template: |- + You are a helpful assistant that can interact with a computer to solve tasks. + instance_template: |- + + {{working_dir}} + + I've uploaded a python code repository in the directory {{working_dir}}. Consider the following PR description: + + + {{problem_statement}} + + + Can you help me implement the necessary changes to the repository so that the requirements specified in the are met? + I've already taken care of all changes to any of the test files described in the . This means you DON'T have to modify the testing logic or any of the tests in any way! + Your task is to make the minimal changes to non-tests files in the {{working_dir}} directory to ensure the is satisfied. + Follow these steps to resolve the issue: + 1. As a first step, it might be a good idea to find and read code relevant to the + 2. 
Create a script to reproduce the error and execute it with `python ` using the bash tool, to confirm the error + 3. Edit the source code of the repo to resolve the issue + 4. Rerun your reproduce script and confirm that the error is fixed! + 5. Think about edgecases and make sure your fix handles them as well + + IMPORTANT: Your output will be checked by an auto-grader looking for exact answers. + This task may be missing critical information. + Use the ask_user tool to ask me for any missing details. + + Your thinking should be thorough and so it's fine if it's very long. + next_step_template: |- + OBSERVATION: + {{observation}} + next_step_no_output_template: |- + Your command ran successfully and did not produce any output. + tools: + execution_timeout: 450 + bundles: + - path: tools/registry + - path: tools/edit_anthropic + - path: tools/submit + - path: tools/ask_user + env_variables: + USE_FILEMAP: 'true' + enable_bash_tool: true + parse_function: + type: function_calling diff --git a/sweagent/agent/agents.py b/sweagent/agent/agents.py index d720b07dc8..2eda9e7a3c 100644 --- a/sweagent/agent/agents.py +++ b/sweagent/agent/agents.py @@ -4,10 +4,14 @@ import copy import json import logging +import os +import re +import shlex import time from pathlib import Path, PurePosixPath from typing import Annotated, Any, Literal +import litellm import yaml from jinja2 import Template from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -57,6 +61,164 @@ from sweagent.utils.patch_formatter import PatchFormatter +# Global task definitions cache for ask_user interception +_task_definitions_cache: dict | None = None +_task_definitions_path: str | None = None + + +def _load_task_definitions(output_dir: Path | None) -> dict | None: + """Load task definitions from output directory if available.""" + global _task_definitions_cache, _task_definitions_path + + if output_dir is None: + return None + + task_def_file = output_dir / "task_definitions.json" + 
task_def_str = str(task_def_file) + + # Return cached if same file + if _task_definitions_path == task_def_str and _task_definitions_cache is not None: + return _task_definitions_cache + + if task_def_file.exists(): + try: + with open(task_def_file, "r") as f: + _task_definitions_cache = json.load(f) + _task_definitions_path = task_def_str + return _task_definitions_cache + except Exception: + pass + return None + + +def _handle_ask_user_on_host(question: str, context: str, instance_id: str, output_dir: Path | None, logger) -> str: + """Handle ask_user command on the host side using litellm. + + This function intercepts ask_user calls to run the LLM call on the host, + which can reach internal API endpoints that the container cannot access. + """ + task_defs = _load_task_definitions(output_dir) + + if task_defs is None or instance_id not in task_defs: + return f"Error: No task definition found for instance {instance_id}" + + task_def = task_defs[instance_id] + primary_task = task_def.get("primary_task", "") + underspecified_prompt = task_def.get("underspecified_task", "") + + # Extract removed values + removed_values = [] + if "removed_segments" in task_def: + for seg in task_def["removed_segments"]: + if isinstance(seg, dict) and seg.get("value"): + removed_values.append(seg["value"]) + + removed_values_str = ", ".join(removed_values) if removed_values else "None specified" + underspec_str = underspecified_prompt or "Not provided" + + system_prompt = f"""You are simulating a user who has a task in mind but didn't fully specify it. + +The user originally intended to give this COMPLETE prompt: +{primary_task} + +But they actually gave this UNDERSPECIFIED version: +{underspec_str} + +The parts that were removed/made vague: +{removed_values_str} + +An AI assistant (who only sees the underspecified version) is now asking you a clarifying question. 
+ +Your job: Compare the two prompts, find what's MISSING from the underspecified version, and provide the EXACT information from the complete prompt. + +Guidelines: +- Find the EXACT values that are in the complete prompt but missing from the underspecified one +- Provide those specific values (times, names, dates, numbers, phrases, etc.) +- Be concise - just answer what's asked +- Don't reveal you're a simulation + +ENVIRONMENT CONTEXT: +- The agent is working in a repository at /workspace (or the working directory specified in the prompt) +- The agent has full access to the repository files +- The agent can read, write, and execute files in the repository +- When providing file paths, use paths relative to the repository root +""" + + user_prompt = f"The assistant asks: {question}" + if context: + user_prompt += f"\n\nContext: {context}" + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + # Get API credentials from environment + api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") + api_base = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL") + model = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-5.2") + + if not api_key: + return "Error: No API key available for user simulation" + + try: + logger.info(f"ask_user intercepted on host: question='{question[:100]}...'") + # Drop unsupported params for models like GPT-5 that don't support temperature + litellm.drop_params = True + response = litellm.completion( + model=model, + messages=messages, + temperature=1, # GPT-5 only supports temperature=1 + api_key=api_key, + api_base=api_base, + timeout=30, + ) + result = response.choices[0].message.content + logger.info(f"ask_user response generated: '{result[:100]}...'") + return result + except Exception as e: + logger.error(f"ask_user LLM call failed: {e}") + return f"Error generating user response: {str(e)}" + + +def _parse_ask_user_command(command: str) -> 
tuple[str, str] | None: + """Parse ask_user command to extract question and optional context. + + Handles formats like: + - ask_user "question" + - ask_user "question" "context" + - ask_user 'question' + - ask_user question_without_quotes + + Returns (question, context) tuple or None if not an ask_user command. + """ + command = command.strip() + if not command.startswith("ask_user"): + return None + + # Remove the "ask_user" prefix + args_str = command[8:].strip() + if not args_str: + return None + + try: + # Use shlex to properly parse quoted arguments + args = shlex.split(args_str) + if len(args) >= 1: + question = args[0] + context = args[1] if len(args) > 1 else "" + return (question, context) + except ValueError: + # Fallback: try simple quote extraction + match = re.match(r'["\'](.+?)["\'](?:\s+["\'](.+?)["\'])?', args_str) + if match: + return (match.group(1), match.group(2) or "") + # Last resort: treat entire string as question + return (args_str, "") + + return None + + class TemplateConfig(BaseModel): """This configuration is used to define almost all message templates that are formatted by the agent and sent to the LM. 
@@ -943,6 +1105,28 @@ def handle_action(self, step: StepOutput) -> StepOutput: self._chook.on_action_started(step=step) execution_t0 = time.perf_counter() run_action: str = self.tools.guard_multiline_input(step.action).strip() + + # Intercept ask_user commands and handle on HOST side + # This is needed because Modal containers cannot reach internal API endpoints + ask_user_args = _parse_ask_user_command(run_action) + if ask_user_args is not None: + question, context = ask_user_args + instance_id = self._problem_statement.id if self._problem_statement else "unknown" + # output_dir is parent of traj_path's parent (traj_path = output_dir/instance_id/instance_id.traj) + output_dir = self.traj_path.parent.parent if self.traj_path else None + step.observation = _handle_ask_user_on_host( + question=question, + context=context, + instance_id=instance_id, + output_dir=output_dir, + logger=self.logger, + ) + step.execution_time = time.perf_counter() - execution_t0 + self._total_execution_time += step.execution_time + self._chook.on_action_executed(step=step) + step.state = self.tools.get_state(env=self._env) + return self.handle_submission(step) + try: step.observation = self._env.communicate( input=run_action, diff --git a/sweagent/run/hooks/task_definition_injection.py b/sweagent/run/hooks/task_definition_injection.py new file mode 100644 index 0000000000..988300baa8 --- /dev/null +++ b/sweagent/run/hooks/task_definition_injection.py @@ -0,0 +1,123 @@ +""" +RunHook for injecting complete task definitions into containers for ask_user tool. + +This hook writes the full task definition (including underspecified version and removed +segments) to a file in the container so the ask_user tool can access it. 
+""" + +import json +import logging +import os +from pathlib import Path +from typing import Any, Dict, Optional + +from sweagent.agent.problem_statement import ProblemStatement, ProblemStatementConfig +from sweagent.environment.swe_env import SWEEnv +from sweagent.run.hooks.abstract import RunHook + +logger = logging.getLogger(__name__) + + +class TaskDefinitionInjectionHook(RunHook): + """ + Inject complete task definitions into container for ask_user tool. + + Writes task definition to /tmp/task_definition.json in the container, which + the ask_user tool reads to provide accurate clarifications. + """ + + def __init__( + self, + task_definitions: Optional[Dict[str, Dict[str, Any]]] = None, + task_definitions_file: Optional[Path] = None, + ): + """ + Initialize hook with task definitions. + + Args: + task_definitions: Dict mapping instance_id to task definition, OR + task_definitions_file: Path to JSON file with task definitions + + Task definition should contain: + - primary_task: Complete task description + - underspecified_task: Partial task given to agent + - removed_segments: List of removed segments + - expected_questions: Expected clarification questions + """ + super().__init__() + self.task_definitions = task_definitions + self.task_definitions_file = task_definitions_file + + def _load_task_definitions(self) -> Dict[str, Dict[str, Any]]: + """Load task definitions from file or return cached dict.""" + if self.task_definitions is not None: + return self.task_definitions + + if self.task_definitions_file and self.task_definitions_file.exists(): + with open(self.task_definitions_file) as f: + return json.load(f) + + return {} + + def on_instance_start( + self, + *, + index: int, + env: SWEEnv, + problem_statement: ProblemStatement | ProblemStatementConfig, + ) -> None: + """ + Inject task definition into container before agent starts. + + Called after environment is ready but before agent.setup(). 
+ """ + instance_id = problem_statement.id + + # Load task definitions + task_definitions = self._load_task_definitions() + + # Check if we have a task definition for this instance + if instance_id not in task_definitions: + logger.debug(f"No task definition found for instance {instance_id}, skipping injection") + return + + task_def = task_definitions[instance_id] + + # Write task definition to container + task_def_path = "/tmp/task_definition.json" + task_def_json = json.dumps(task_def, indent=2) + + try: + # Write file to container using swerex + logger.info(f"Injecting task definition for {instance_id} to {task_def_path}") + + # Create a temporary file write command + command = f"cat > {task_def_path} << 'TASK_DEFINITION_EOF'\n{task_def_json}\nTASK_DEFINITION_EOF" + env.communicate(command, check="raise") + + # Set environment variables for the ask_user tool + env_vars = { + "TASK_DEFINITION_PATH": task_def_path, + "HAS_TASK_DEFINITION": "true", + } + + # Pass through API credentials for the user simulator LLM + api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") + base_url = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL") + simulator_model = os.environ.get("USER_SIMULATOR_MODEL") + + if api_key: + env_vars["OPENAI_API_KEY"] = api_key + env_vars["LLM_API_KEY"] = api_key + if base_url: + env_vars["OPENAI_BASE_URL"] = base_url + env_vars["LLM_BASE_URL"] = base_url + if simulator_model: + env_vars["USER_SIMULATOR_MODEL"] = simulator_model + + env.set_env_variables(env_vars) + + logger.info(f"Successfully injected task definition for {instance_id}") + except Exception as e: + logger.error(f"Failed to inject task definition for {instance_id}: {e}") + # Don't raise - let the run continue even if injection fails diff --git a/sweagent/run/run_batch.py b/sweagent/run/run_batch.py index 99bae59d9f..9ebdcf5fec 100644 --- a/sweagent/run/run_batch.py +++ b/sweagent/run/run_batch.py @@ -231,6 +231,18 @@ def from_config(cls, config: 
RunBatchConfig) -> Self: continuous_submission_every=30, ) ) + + # Auto-add TaskDefinitionInjectionHook if task definitions file exists + # This enables the ask_user tool to access complete task specifications + task_def_file = config.output_dir / "task_definitions.json" + if task_def_file.exists(): + from sweagent.run.hooks.task_definition_injection import TaskDefinitionInjectionHook + + logger.info( + f"Found task definitions file, enabling ask_user tool support: {task_def_file}" + ) + rb.add_hook(TaskDefinitionInjectionHook(task_definitions_file=task_def_file)) + return rb def add_hook(self, hook: RunHook) -> None: diff --git a/tools/ask_user/bin/ask_user b/tools/ask_user/bin/ask_user new file mode 100755 index 0000000000..9670b979ac --- /dev/null +++ b/tools/ask_user/bin/ask_user @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +Ask user tool for SWE-bench Pro. + +Allows the agent to request clarification from the task assigner. +Responses are generated by an LLM to simulate the human who assigned the task +and has access to the complete task requirements. 
+""" + +import json +import os +import sys +from pathlib import Path + +# Add the parent directory to the path to import llm utilities +sys.path.insert(0, str(Path(__file__).parent)) + +from llm import llm_completion, logger # noqa: E402 + +# Path to the primary task description (injected by TaskDefinitionInjectionHook) +PRIMARY_TASK_PATH = os.environ.get("TASK_DEFINITION_PATH", "/tmp/task_definition.json") +# Use gpt-5.2 as default (same model as agent, known to work) +USER_SIMULATOR_MODEL = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-5.2") + +# Default task context note for SWE-bench Pro environment +DEFAULT_TASK_CONTEXT = """ENVIRONMENT CONTEXT: +- The agent is working in a repository at /workspace (or the working directory specified in the prompt) +- The agent has full access to the repository files +- The agent can read, write, and execute files in the repository +- When providing file paths, use paths relative to the repository root +""".strip() + + +def load_primary_task(): + """Load the primary task description from the mounted file. 
+ + Returns: + Tuple of (primary_task, underspecified_prompt, removed_values) + """ + try: + if os.path.exists(PRIMARY_TASK_PATH): + with open(PRIMARY_TASK_PATH, "r", encoding="utf-8") as f: + task_def = json.load(f) + + primary_task = task_def.get("primary_task", "") + underspecified_prompt = task_def.get("underspecified_task", "") + + # Extract removed values from removed_segments + removed_values = [] + if "removed_segments" in task_def: + for seg in task_def["removed_segments"]: + if isinstance(seg, dict) and seg.get("value"): + removed_values.append(seg["value"]) + elif "expected_questions" in task_def: + for eq in task_def["expected_questions"]: + if isinstance(eq, dict) and "segment" in eq: + segment_val = eq.get("segment", {}).get("value", "") + if segment_val: + removed_values.append(segment_val) + + return primary_task, underspecified_prompt, removed_values + else: + return "No primary task definition available.", "", [] + except Exception as e: + return f"Error reading primary task: {str(e)}", "", [] + + +def generate_user_response(question, context, primary_task, underspecified_prompt, removed_values): + """Generate a human-like response using an LLM. + + Args: + question: The agent's question + context: Context about what the agent was doing + primary_task: The complete task requirements + underspecified_prompt: The partial prompt the agent actually sees + removed_values: List of values that were removed from the complete prompt + + Returns: + A natural human-like response that answers the question + """ + # Check if API credentials are available + api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") + base_url = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL") + + if not api_key or not base_url: + raise ValueError("LLM_API_KEY or LLM_BASE_URL not available. 
Cannot generate response.") + + # Format removed values for the prompt + removed_values_str = ", ".join(removed_values) if removed_values else "None specified" + underspec_str = underspecified_prompt or "Not provided" + + system_prompt = f"""You are simulating a user who has a task in mind but didn't fully specify it. + +The user originally intended to give this COMPLETE prompt: +{primary_task} + +But they actually gave this UNDERSPECIFIED version: +{underspec_str} + +The parts that were removed/made vague: +{removed_values_str} + +An AI assistant (who only sees the underspecified version) is now asking you a clarifying question. + +Your job: Compare the two prompts, find what's MISSING from the underspecified version, and provide the EXACT information from the complete prompt. + +Guidelines: +- Find the EXACT values that are in the complete prompt but missing from the underspecified one +- Provide those specific values (times, names, dates, numbers, phrases, etc.) +- Be concise - just answer what's asked +- Don't reveal you're a simulation + +{DEFAULT_TASK_CONTEXT} +""" + + user_prompt = f"The assistant asks: {question}" + if context: + user_prompt += f"\n\nContext: {context}" + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + try: + response = llm_completion( + model=USER_SIMULATOR_MODEL, + messages=messages, + temperature=0.7, + timeout=30, + ) + return response + except Exception as e: + # Print full error details for debugging + import traceback + error_details = traceback.format_exc() + print(f"DEBUG: User simulator error: {e}", file=sys.stderr) + print(f"DEBUG: Full traceback:\n{error_details}", file=sys.stderr) + print(f"DEBUG: API_KEY set: {bool(api_key)}", file=sys.stderr) + print(f"DEBUG: BASE_URL: {base_url}", file=sys.stderr) + print(f"DEBUG: MODEL: {USER_SIMULATOR_MODEL}", file=sys.stderr) + return f"Error calling user simulator: {e}" + + +def main(): + if len(sys.argv) < 2: + 
print("Usage: ask_user []") + sys.exit(1) + + question = sys.argv[1] + context = sys.argv[2] if len(sys.argv) > 2 else "" + + # Load the primary task description + primary_task, underspecified_prompt, removed_values = load_primary_task() + + # Generate and print the user's response + try: + response = generate_user_response( + question=question, + context=context, + primary_task=primary_task, + underspecified_prompt=underspecified_prompt, + removed_values=removed_values, + ) + print(response) + except Exception as e: + print(f"Error: {str(e)}", file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tools/ask_user/bin/llm.py b/tools/ask_user/bin/llm.py new file mode 100644 index 0000000000..70d25aa7ad --- /dev/null +++ b/tools/ask_user/bin/llm.py @@ -0,0 +1,70 @@ +""" +LLM utilities for the ask_user tool. + +Uses OpenAI client with optional LiteLLM proxy for multi-provider routing. +""" + +import logging +import os + +from openai import OpenAI +from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential + +logger = logging.getLogger(__name__) +logging.getLogger("httpx").setLevel(logging.WARNING) + +# Module-level client (initialized lazily) +_client = None + + +def get_openai_client() -> OpenAI: + """ + Get an OpenAI client, optionally configured for LiteLLM proxy. + + Environment variables: + OPENAI_API_KEY or LLM_API_KEY: Your API key + OPENAI_BASE_URL or LLM_BASE_URL: Proxy URL for multi-provider routing + """ + global _client + if _client is not None: + return _client + + api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY or LLM_API_KEY environment variable is required. 
" + "Set with: export OPENAI_API_KEY='your-key' or create a .env file" + ) + + base_url = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL") + # Note: base_url can be None for direct OpenAI API usage + + _client = OpenAI( + api_key=api_key, + base_url=base_url, # None = use default OpenAI API + max_retries=0, # Retries handled by tenacity for better logging + ) + return _client + + +@retry( + stop=stop_after_attempt(1), # No retries - fail fast for debugging + wait=wait_exponential(multiplier=1, exp_base=2, max=5), + before_sleep=before_sleep_log(logger, logging.WARNING), +) +def llm_completion(model: str, messages: list, **kwargs) -> str: + """ + LLM completion with automatic retries. + + Args: + model: Model identifier (e.g., "openai/gpt-4o" for LiteLLM proxy) + messages: List of message dicts with "role" and "content" + **kwargs: Additional arguments passed to chat.completions.create + + Returns: + Response text from the model + """ + client = get_openai_client() + + response = client.chat.completions.create(model=model, messages=messages, **kwargs) + return response.choices[0].message.content diff --git a/tools/ask_user/config.yaml b/tools/ask_user/config.yaml new file mode 100644 index 0000000000..32cf87abd8 --- /dev/null +++ b/tools/ask_user/config.yaml @@ -0,0 +1,13 @@ +tools: + ask_user: + signature: "ask_user []" + docstring: "Ask the user a clarifying question to get more information about the task. Use this when the task is ambiguous or you need specific details to proceed." 
+ arguments: + - name: question + type: string + description: "The clarifying question to ask the user" + required: true + - name: context + type: string + description: "Optional additional context (e.g., conversation history summary)" + required: false diff --git a/tools/ask_user/install.sh b/tools/ask_user/install.sh new file mode 100644 index 0000000000..fa350d6c4e --- /dev/null +++ b/tools/ask_user/install.sh @@ -0,0 +1,5 @@ +#!/bin/bash +# Install dependencies for ask_user tool + +pip install -q -r "$(dirname "$0")/requirements.txt" +echo "ask_user tool dependencies installed" diff --git a/tools/ask_user/requirements.txt b/tools/ask_user/requirements.txt new file mode 100644 index 0000000000..82befcb561 --- /dev/null +++ b/tools/ask_user/requirements.txt @@ -0,0 +1,2 @@ +openai>=1.0.0 +tenacity>=8.0.0 From 4805771f5e232349f94969f8f78a7f2ee435d2f4 Mon Sep 17 00:00:00 2001 From: yash-scaleai Date: Fri, 20 Feb 2026 00:50:08 +0000 Subject: [PATCH 2/3] fix: thread-safe cache, tighten ask_user parsing, guard None result Includes post-experiment working tree fixes (traj_path refactor, _find_task_definitions_file search) that ran during Phase B. 
Safety fixes (no behavior change): - Add threading.Lock around task definitions cache - Tighten _parse_ask_user_command to exact word boundary - Guard result against None before slicing for log - Remove no-op @retry(stop_after_attempt(1)) in container llm.py Co-Authored-By: Claude Opus 4.6 --- sweagent/agent/agents.py | 54 ++++++++++++++++++++++----------- tools/ask_user/bin/llm.py | 8 +---- tools/ask_user/requirements.txt | 1 - 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/sweagent/agent/agents.py b/sweagent/agent/agents.py index 2eda9e7a3c..51122f1560 100644 --- a/sweagent/agent/agents.py +++ b/sweagent/agent/agents.py @@ -7,6 +7,7 @@ import os import re import shlex +import threading import time from pathlib import Path, PurePosixPath from typing import Annotated, Any, Literal @@ -62,25 +63,46 @@ # Global task definitions cache for ask_user interception +_task_definitions_lock = threading.Lock() _task_definitions_cache: dict | None = None _task_definitions_path: str | None = None -def _load_task_definitions(output_dir: Path | None) -> dict | None: - """Load task definitions from output directory if available.""" +def _find_task_definitions_file(traj_path: Path | None) -> Path | None: + """Search for task_definitions.json in parent directories of traj_path. 
+ + Handles different directory structures: + - Single test: output_dir/instance_id/instance_id.traj (2 levels up) + - Benchmark: output_dir/exp_N/instance_id/instance_id.traj (3 levels up) + """ + if traj_path is None: + return None + + # Search up to 4 parent levels for task_definitions.json + current = traj_path.parent # Start from directory containing .traj + for _ in range(4): + task_def_file = current / "task_definitions.json" + if task_def_file.exists(): + return task_def_file + current = current.parent + return None + + +def _load_task_definitions(traj_path: Path | None) -> dict | None: + """Load task definitions by searching parent directories.""" global _task_definitions_cache, _task_definitions_path - if output_dir is None: + task_def_file = _find_task_definitions_file(traj_path) + if task_def_file is None: return None - task_def_file = output_dir / "task_definitions.json" task_def_str = str(task_def_file) - # Return cached if same file - if _task_definitions_path == task_def_str and _task_definitions_cache is not None: - return _task_definitions_cache + with _task_definitions_lock: + # Return cached if same file + if _task_definitions_path == task_def_str and _task_definitions_cache is not None: + return _task_definitions_cache - if task_def_file.exists(): try: with open(task_def_file, "r") as f: _task_definitions_cache = json.load(f) @@ -91,13 +113,13 @@ def _load_task_definitions(output_dir: Path | None) -> dict | None: return None -def _handle_ask_user_on_host(question: str, context: str, instance_id: str, output_dir: Path | None, logger) -> str: +def _handle_ask_user_on_host(question: str, context: str, instance_id: str, traj_path: Path | None, logger) -> str: """Handle ask_user command on the host side using litellm. This function intercepts ask_user calls to run the LLM call on the host, which can reach internal API endpoints that the container cannot access. 
""" - task_defs = _load_task_definitions(output_dir) + task_defs = _load_task_definitions(traj_path) if task_defs is None or instance_id not in task_defs: return f"Error: No task definition found for instance {instance_id}" @@ -156,7 +178,7 @@ def _handle_ask_user_on_host(question: str, context: str, instance_id: str, outp # Get API credentials from environment api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY") api_base = os.environ.get("OPENAI_BASE_URL") or os.environ.get("LLM_BASE_URL") - model = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-5.2") + model = os.environ.get("USER_SIMULATOR_MODEL", "openai/gpt-4.1-2025-04-14") if not api_key: return "Error: No API key available for user simulation" @@ -173,7 +195,7 @@ def _handle_ask_user_on_host(question: str, context: str, instance_id: str, outp api_base=api_base, timeout=30, ) - result = response.choices[0].message.content + result = response.choices[0].message.content or "" logger.info(f"ask_user response generated: '{result[:100]}...'") return result except Exception as e: @@ -193,7 +215,7 @@ def _parse_ask_user_command(command: str) -> tuple[str, str] | None: Returns (question, context) tuple or None if not an ask_user command. 
""" command = command.strip() - if not command.startswith("ask_user"): + if not (command == "ask_user" or command.startswith("ask_user ") or command.startswith("ask_user\t")): return None # Remove the "ask_user" prefix @@ -216,8 +238,6 @@ def _parse_ask_user_command(command: str) -> tuple[str, str] | None: # Last resort: treat entire string as question return (args_str, "") - return None - class TemplateConfig(BaseModel): """This configuration is used to define almost all message templates that are @@ -1112,13 +1132,11 @@ def handle_action(self, step: StepOutput) -> StepOutput: if ask_user_args is not None: question, context = ask_user_args instance_id = self._problem_statement.id if self._problem_statement else "unknown" - # output_dir is parent of traj_path's parent (traj_path = output_dir/instance_id/instance_id.traj) - output_dir = self.traj_path.parent.parent if self.traj_path else None step.observation = _handle_ask_user_on_host( question=question, context=context, instance_id=instance_id, - output_dir=output_dir, + traj_path=self.traj_path, logger=self.logger, ) step.execution_time = time.perf_counter() - execution_t0 diff --git a/tools/ask_user/bin/llm.py b/tools/ask_user/bin/llm.py index 70d25aa7ad..d1cf3c0f5e 100644 --- a/tools/ask_user/bin/llm.py +++ b/tools/ask_user/bin/llm.py @@ -8,7 +8,6 @@ import os from openai import OpenAI -from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.WARNING) @@ -47,14 +46,9 @@ def get_openai_client() -> OpenAI: return _client -@retry( - stop=stop_after_attempt(1), # No retries - fail fast for debugging - wait=wait_exponential(multiplier=1, exp_base=2, max=5), - before_sleep=before_sleep_log(logger, logging.WARNING), -) def llm_completion(model: str, messages: list, **kwargs) -> str: """ - LLM completion with automatic retries. + LLM completion. 
Args: model: Model identifier (e.g., "openai/gpt-4o" for LiteLLM proxy) diff --git a/tools/ask_user/requirements.txt b/tools/ask_user/requirements.txt index 82befcb561..aa2b704464 100644 --- a/tools/ask_user/requirements.txt +++ b/tools/ask_user/requirements.txt @@ -1,2 +1 @@ openai>=1.0.0 -tenacity>=8.0.0 From 36d9a406c62735b36c3093913307af217863a036 Mon Sep 17 00:00:00 2001 From: yash-scaleai Date: Fri, 20 Feb 2026 00:50:53 +0000 Subject: [PATCH 3/3] fix: SWE-ReX Modal fork, top_p=None crash, data volume mount Infrastructure fixes applied during experiment runs: - pyproject.toml: use jeff-da/SWE-ReX sweap-support branch for Modal - models.py: handle top_p=None in model ID string (was crashing) - justfile: mount data/ volume for container access Co-Authored-By: Claude Opus 4.6 --- justfile | 1 + pyproject.toml | 2 +- sweagent/agent/models.py | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/justfile b/justfile index bbd9f9b63a..204eac81b7 100644 --- a/justfile +++ b/justfile @@ -17,6 +17,7 @@ run: --env-file $(pwd)/.env \ -v "$HOME/.modal.toml:/root/.modal.toml" \ -v "$(pwd)/config:/app/config" \ + -v "$(pwd)/data:/app/data" \ -v "$(pwd)/sweagent_wrapper_configs:/app/sweagent_wrapper_configs" \ -v "$(pwd)/sweagent_results:/app/sweagent_results" \ --add-host=host.docker.internal:host-gateway \ diff --git a/pyproject.toml b/pyproject.toml index aff188f514..15c83e0f5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "litellm", "GitPython", "ghapi", - "swe-rex[modal]==1.2.0", + "swe-rex[modal] @ git+https://github.com/jeff-da/SWE-ReX.git@sweap-support", "tabulate", "textual>=1.0.0" ] diff --git a/sweagent/agent/models.py b/sweagent/agent/models.py index 307d8d579a..10ba870db4 100644 --- a/sweagent/agent/models.py +++ b/sweagent/agent/models.py @@ -78,7 +78,7 @@ class GenericAPIModelConfig(PydanticBaseModel): per_instance_call_limit: int = Field(default=0, description="Per instance call limit.") temperature: 
float = 0.0 """Sampling temperature""" - top_p: float | None = 1.0 + top_p: float | None = None """Sampling top-p""" api_base: str | None = None api_version: str | None = None @@ -180,7 +180,8 @@ def choose_api_key(self) -> str | None: @property def id(self) -> str: - return f"{self.name}__t-{self.temperature:.2f}__p-{self.top_p:.2f}__c-{self.per_instance_cost_limit:.2f}" + top_p_str = f"{self.top_p:.2f}" if self.top_p is not None else "none" + return f"{self.name}__t-{self.temperature:.2f}__p-{top_p_str}__c-{self.per_instance_cost_limit:.2f}" class ReplayModelConfig(GenericAPIModelConfig):