From ec206b2aa19648db3c78e3c718eb8f373f3a9b91 Mon Sep 17 00:00:00 2001 From: BHznJNs <441768875@qq.com> Date: Sun, 15 Mar 2026 22:47:50 +0800 Subject: [PATCH 1/3] backend tool call audition --- src-server/pyproject.toml | 2 +- src-server/src/agent/context.py | 6 +- src-server/src/agent/prompts/__init__.py | 2 +- .../src/agent/prompts/one_turns/__init__.py | 23 +++ .../{ => one_turns}/title_summarization.py | 10 +- .../one_turns/tool_call_safety_audit.py | 187 ++++++++++++++++++ src-server/src/agent/task/__init__.py | 8 +- .../src/agent/task/llm_request_manager.py | 10 +- .../src/agent/task/tool_call_dispatcher.py | 16 +- .../src/agent/task/tool_call_reviewer.py | 71 ++++++- src-server/src/agent/temp_generation.py | 36 ---- src-server/src/agent/types/metadata.py | 1 + src-server/src/agent/types/stream.py | 6 +- src-server/src/api/routes/llm_api.py | 5 +- .../src/api/routes/task/background_task.py | 13 +- .../src/api/routes/task/context_file.py | 1 + src-server/src/db/models/task.py | 16 +- src-server/src/schemas/task.py | 7 +- src-server/src/settings.py | 3 + src-server/uv.lock | 8 +- 20 files changed, 344 insertions(+), 87 deletions(-) create mode 100644 src-server/src/agent/prompts/one_turns/__init__.py rename src-server/src/agent/prompts/{ => one_turns}/title_summarization.py (66%) create mode 100644 src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py delete mode 100644 src-server/src/agent/temp_generation.py diff --git a/src-server/pyproject.toml b/src-server/pyproject.toml index 2b086ade..5d417bf3 100644 --- a/src-server/pyproject.toml +++ b/src-server/pyproject.toml @@ -3,7 +3,7 @@ name = "dais-server" version = "0.1.0" requires-python = ">=3.14" dependencies = [ - "dais-sdk==0.8.10", + "dais-sdk==0.8.15", "dais-shell==0.1.2", "alembic==1.18.4", diff --git a/src-server/src/agent/context.py b/src-server/src/agent/context.py index 1ac0ba55..e3235ee8 100644 --- a/src-server/src/agent/context.py +++ b/src-server/src/agent/context.py @@ -4,7 +4,7 @@ from dataclasses import asdict from typing import Self, cast from dais_sdk.tool import Toolset -from dais_sdk.types import ToolDef +from dais_sdk.types import Message, ToolDef from .tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager, BuiltInToolset from .prompts import BASE_INSTRUCTION, NO_WORKSPACE_INSTRUCTION, NO_AGENT_INSTRUCTION from .types import ContextUsage @@ -56,7 +56,7 @@ def __init__(self, task_id: int, *, usage: task_models.TaskUsage, - messages: list[task_models.TaskMessage], + messages: list[Message], workspace: workspace_schemas.WorkspaceRead, agent: agent_schemas.AgentRead, provider: provider_schemas.ProviderRead, @@ -142,7 +142,7 @@ def model(self) -> provider_schemas.LlmModelRead: return self._model @property - def messages(self) -> list[task_models.TaskMessage]: + def messages(self) -> list[Message]: return self._messages @property diff --git a/src-server/src/agent/prompts/__init__.py b/src-server/src/agent/prompts/__init__.py index a965bd94..a1b4a17c 100644 --- a/src-server/src/agent/prompts/__init__.py +++ b/src-server/src/agent/prompts/__init__.py @@ -1,5 +1,5 @@ from .instruction import BASE_INSTRUCTION -from .title_summarization import TITLE_SUMMARIZATION_INSTRUCTION +from .one_turns import * from .built_in_agents import * USER_IGNORED_TOOL_CALL_RESULT = "[System Message] User ignored this tool call." diff --git a/src-server/src/agent/prompts/one_turns/__init__.py b/src-server/src/agent/prompts/one_turns/__init__.py new file mode 100644 index 00000000..09ba7a4c --- /dev/null +++ b/src-server/src/agent/prompts/one_turns/__init__.py @@ -0,0 +1,23 @@ +from dais_sdk import LLM +from ....db import db_context +from ....services.llm_model import LlmModelService + +from .title_summarization import TitleSummarization +from .tool_call_safety_audit import ToolCallSafetyAudit, ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput + +async def create_one_turn_llm(model_id: int) -> LLM: + async with db_context() as db_session: + model = await LlmModelService(db_session).get_model_by_id(model_id) + provider = model.provider + provider = LLM.create_provider(provider.type, + provider.base_url, + api_key=provider.api_key) + return LLM(model.name, provider=provider) + +__all__ = [ + "create_one_turn_llm", + "TitleSummarization", + "ToolCallSafetyAudit", + "ToolCallSafetyAuditInput", + "ToolCallSafetyAuditOutput", +] diff --git a/src-server/src/agent/prompts/title_summarization.py b/src-server/src/agent/prompts/one_turns/title_summarization.py similarity index 66% rename from src-server/src/agent/prompts/title_summarization.py rename to src-server/src/agent/prompts/one_turns/title_summarization.py index ad4f04c2..c1691c22 100644 --- a/src-server/src/agent/prompts/title_summarization.py +++ b/src-server/src/agent/prompts/one_turns/title_summarization.py @@ -1,4 +1,4 @@ -TITLE_SUMMARIZATION_INSTRUCTION = """\ +INSTRUCTION = """\ Generate a concise title for the following task/conversation in {LANGUAGE}. CRITICAL: Your response must contain ONLY the title itself - no explanations, no "Here is the title:", no quotes, no punctuation at the end. @@ -10,3 +10,11 @@ - Use {LANGUAGE} language - Output format: plain text title only """ + +# --- --- --- --- --- --- + +from dais_sdk import LLM, OneTurn + +class TitleSummarization(OneTurn): + def __init__(self, llm: LLM): + super().__init__(llm, INSTRUCTION, output="text") diff --git a/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py new file mode 100644 index 00000000..2a126b12 --- /dev/null +++ b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py @@ -0,0 +1,187 @@ +INSTRUCTION = """\ +You are a security review module for an AI agent system. Your sole responsibility is to assess the danger level of pending tool calls before they are executed. + +## Inputs + +You will receive three sections wrapped in XML tags. + +### Tool definitions + +``` + + + {name of the tool} + {what the tool does} + {JSON Schema describing the tool's arguments} + + ... + +``` + +- Contains definitions for all tools referenced in `pending_tool_calls` part. +- Use the `description` and `input_schema` to understand what each argument means and what the tool is capable of. + +### Recent conversation context + +``` + + + {message content} + + ... + +``` + +- Messages are listed in chronological order (earlier messages first). +- `role` is one of: `user`, `assistant`, or `tool` (tool result). +- This section may be empty or sparse; if so, apply the uncertainty rule (see Rule 5 of Scoring Rules). + +### Pending tool calls to review + +``` + + + {unique id} + {name of the tool} + {tool arguments, typically a JSON object} + + ... + +``` + +- There may be one or more `` entries (parallel tool calls). +- Review each `` independently. + +## Your Task + +For each pending tool call, assign a **danger level** from 0 to 100 (multiples of 10 only: 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100). +Output a single JSON object array, with each `tool_call_id` and its danger level string. No explanation, no additional text - only the JSON. + +**Output format:** +[ + {"call_id": "tool_call_id1", "danger_level": 10}, + {"call_id": "tool_call_id2", "danger_level": 50}, + ... +] + +## Danger Level Reference + +| Level | Meaning | Typical Examples | +|-------|---------|-----------------| +| 0 | Completely safe, fully reversible, read-only | Read a public file, query a database (SELECT only), get current date/time | +| 10 | Negligible risk, effectively harmless | List directory contents, get system time, echo a string | +| 20 | Very low risk, easily reversible | Write to a temp file, create a local log entry | +| 30 | Low risk, reversible with minor effort | Modify a non-critical config, send an internal notification | +| 40 | Moderate risk, reversible but with some effort | Overwrite an existing file, restart a non-critical service, fetch a public URL (read-only GET, no sensitive data in request) | +| 50 | Medium risk, partial or uncertain reversibility | Batch-modify records, send an external email, install a package | +| 60 | Elevated risk, difficult to reverse | Delete non-critical data, modify production config | +| 70 | High risk, hard to reverse, potential for significant impact | Delete important files, stop a critical service, bulk data changes | +| 80 | Very high risk, likely irreversible or severe impact | Drop a database table, revoke access credentials, mass-delete records | +| 90 | Extremely high risk OR danger level cannot be determined (see Rule 5) | Destructive system commands, unknown/opaque commands, missing critical context | +| 100 | Catastrophic, certainly irreversible, existential impact | Format a disk, wipe all data, terminate all system processes | + +## Scoring Rules + +1. **Assess each tool call independently** based on: the tool name, its arguments, and the intent inferred from the recent context. + +2. **Consider the reversibility** of the action as the primary factor. An action that is hard or impossible to undo should score higher. + +3. **Consider the blast radius**: how many systems, users, or data records could be affected if this call goes wrong. + +4. **Consider the context fit**: if the tool call appears inconsistent with the agent's stated goal or the conversation context, increase the score. + +5. **Default to 90 when uncertain.** Apply this rule in any of the following situations: + - The tool uses commands, APIs, or parameters that you are not familiar with and whose danger cannot be inferred from naming or semantics alone. + - The arguments are opaque, heavily encoded, or reference external resources you cannot evaluate (e.g., a dynamically constructed shell command, a raw binary payload, an unfamiliar DSL). + - The recent context is missing, too sparse, or ambiguous to determine why this tool call is being made. + - You cannot confidently distinguish between a safe and a dangerous interpretation of the call. + When in doubt, always err on the side of caution and assign 90. + +6. **Never output values other than multiples of 10** (0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100). + +7. **Output only the JSON object.** No markdown, no code fences, no explanation. +""" + +# --- --- --- --- --- --- + +import json +import xml.etree.ElementTree as ET +from dataclasses import dataclass +from typing import override +from pydantic import BaseModel +from dais_sdk import LLM, OneTurn +from dais_sdk.types import ( + ToolSchema, Message, ToolMessage +) + +@dataclass +class ToolCallSafetyAuditInput: + tool_definitions: list[ToolSchema] + context: list[Message] + pending_tool_calls: list[ToolMessage] + +class ToolCallSafetyAuditOutput(BaseModel): + class OutputItem(BaseModel): + call_id: str + risk_level: int + + results: list[ToolCallSafetyAuditOutput.OutputItem] + +class ToolCallSafetyAudit(OneTurn[ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput]): + def __init__(self, llm: LLM): + super().__init__(llm, INSTRUCTION, output=ToolCallSafetyAuditOutput, validate=True) + + @override + def format_input(self, input: ToolCallSafetyAuditInput) -> str: + def tool_definitions_xml(tool_definitions: list[ToolSchema]) -> ET.Element: + root = ET.Element("tool_definitions") + for t in tool_definitions: + tool_elem = ET.SubElement(root, "tool") + ET.SubElement(tool_elem, "name").text = t["name"] + ET.SubElement(tool_elem, "description").text = t["description"] + ET.SubElement(tool_elem, "input_schema").text = json.dumps(t["parameters"], ensure_ascii=False) + return root + + def context_xml(context: list[Message]) -> ET.Element: + root = ET.Element("context") + for msg in context: + match msg.role: + case "user": + msg_elem = ET.SubElement(root, "message", role="user") + ET.SubElement(msg_elem, "content").text = msg.content + case "assistant": + msg_elem = ET.SubElement(root, "message", role="assistant") + ET.SubElement(msg_elem, "content").text = msg.content + if msg.tool_calls is not None: + tool_calls_elem = ET.SubElement(msg_elem, "tool_calls") + for tool_call in msg.tool_calls: + tool_call_elem = ET.SubElement(tool_calls_elem, "tool_call") + ET.SubElement(tool_call_elem, "id").text = tool_call.id + ET.SubElement(tool_call_elem, "name").text = tool_call.name + ET.SubElement(tool_call_elem, "arguments").text = json.dumps(tool_call.arguments, ensure_ascii=False) + case "tool": + msg_elem = ET.SubElement(root, "message", role="tool") + ET.SubElement(msg_elem, "name").text = msg.name + ET.SubElement(msg_elem, "arguments").text = json.dumps(msg.arguments, ensure_ascii=False) + if msg.result is not None: + ET.SubElement(msg_elem, "result").text = msg.result + if msg.error is not None: + ET.SubElement(msg_elem, "error").text = msg.error + case _: ... # do nothing for other message types + return root + + def pending_tool_calls_xml(pending_tool_calls: list[ToolMessage]) -> ET.Element: + root = ET.Element("pending_tool_calls") + + for tc in pending_tool_calls: + tool_call_elem = ET.SubElement(root, "tool_call") + ET.SubElement(tool_call_elem, "tool_call_id").text = tc.id + ET.SubElement(tool_call_elem, "name").text = tc.name + ET.SubElement(tool_call_elem, "arguments").text = json.dumps(tc.arguments, ensure_ascii=False) + return root + + return "".join([ET.tostring(el, encoding="unicode") for el in ( + tool_definitions_xml(input.tool_definitions), + context_xml(input.context), + pending_tool_calls_xml(input.pending_tool_calls), + )]) diff --git a/src-server/src/agent/task/__init__.py b/src-server/src/agent/task/__init__.py index c192c82a..76c75769 100644 --- a/src-server/src/agent/task/__init__.py +++ b/src-server/src/agent/task/__init__.py @@ -4,7 +4,7 @@ from loguru import logger from dais_sdk.tool import ToolCallExecutor from dais_sdk.types import ( - ToolMessage, UserMessage, AssistantMessage, + Message, ToolMessage, UserMessage, AssistantMessage, ToolDef, ToolDoesNotExistError, ToolArgumentDecodeError, ToolExecutionError, ) from ..context import AgentContext @@ -40,10 +40,10 @@ def __init__(self, ctx: AgentContext): self._tool_call_executor.exception_handler.set_handler(ToolArgumentDecodeError, handle_tool_argument_decode_error) self._tool_call_executor.exception_handler.set_handler(ToolExecutionError, handle_tool_execution_error) - self._tool_call_reviewer = ToolCallReviewer() + self._tool_call_reviewer = ToolCallReviewer(ctx) self._tool_call_dispatcher = ToolCallDispatcher(self._ctx, self._tool_call_executor, self._tool_call_reviewer) - def _find_message(self, predicate: Callable[[task_models.TaskMessage], bool]) -> task_models.TaskMessage: + def _find_message(self, predicate: Callable[[Message], bool]) -> Message: for message in reversed(self._ctx.messages): if predicate(message): return message @@ -113,7 +113,7 @@ async def approve_tool_call(self, call_id: str, approved: bool) -> AsyncGenerato # so we can safely assert the type of tool_def to ToolDef here. assert isinstance(tool, ToolDef) - permission_check_result = self._tool_call_reviewer.check_permission(tool, target_message) + permission_check_result = await self._tool_call_reviewer.check_permission(tool, target_message) match permission_check_result: case ToolCallBlocked(event): yield event diff --git a/src-server/src/agent/task/llm_request_manager.py b/src-server/src/agent/task/llm_request_manager.py index 12921fe4..88baf9c7 100644 --- a/src-server/src/agent/task/llm_request_manager.py +++ b/src-server/src/agent/task/llm_request_manager.py @@ -1,7 +1,6 @@ import asyncio import uuid from collections.abc import AsyncGenerator -from dais_sdk.providers import OpenAIProvider from loguru import logger from dais_sdk import LLM from dais_sdk.types import ( @@ -33,14 +32,13 @@ def llm(self) -> LLM: return self._llm def _llm_factory(self) -> LLM: - provider = OpenAIProvider( - self._ctx.provider.base_url, - api_key=self._ctx.provider.api_key) - return LLM(provider=provider) + provider = LLM.create_provider(self._ctx.provider.type, + self._ctx.provider.base_url, + api_key=self._ctx.provider.api_key) + return LLM(self._ctx.model.name, provider=provider) def _create_request_param(self) -> LlmRequestParams: params = LlmRequestParams( - model=self._ctx.model.name, instructions=self._ctx.system_instruction, messages=self._ctx.messages) usable_tool_ids = self._ctx.usable_tool_ids diff --git a/src-server/src/agent/task/tool_call_dispatcher.py b/src-server/src/agent/task/tool_call_dispatcher.py index 1a262629..e4f545af 100644 --- a/src-server/src/agent/task/tool_call_dispatcher.py +++ b/src-server/src/agent/task/tool_call_dispatcher.py @@ -25,7 +25,7 @@ def __init__(self, ctx: AgentContext, tool_call_executor: ToolCallExecutor, tool self._tool_call_executor = tool_call_executor self._tool_call_reviewer = tool_call_reviewer - async def execute(self, tool: ToolLike, message: ToolMessage) -> ToolEvent: + async def execute(self, tool: ToolDef, message: ToolMessage) -> ToolEvent: """ Execute tool call and attach the result to the corresponding message. This method should not throw any exceptions. @@ -46,7 +46,7 @@ async def _dispatch_stream(self, | ErrorEvent, None]: executables = list[tuple[ToolDef, ToolMessage]]() for message in tool_call_messages: - tool: ToolLike | None = self._ctx.find_tool(message.name) + tool = self._ctx.find_tool(message.name) if tool is None: message.error = handle_tool_does_not_exist_error(ToolDoesNotExistError(message.name)) yield MessageReplaceEvent(message=message) @@ -59,7 +59,7 @@ async def _dispatch_stream(self, if tool.executes(ExecutionControlToolset.finish_task): result.has_finished_task = True - permission_check_result = self._tool_call_reviewer.check_permission(tool, message) + permission_check_result = await self._tool_call_reviewer.check_permission(tool, message) match permission_check_result: case ToolCallBlocked(event): yield event @@ -68,6 +68,16 @@ async def _dispatch_stream(self, case ToolCallApproved(): executables.append((tool, message)) + if len(result.pendings) > 0: + audit_result = await self._tool_call_reviewer.audit_tool_calls(result.pendings) + if audit_result is not None: + high_risk, low_risk = audit_result + result.pendings = high_risk + for message in low_risk: + tool = self._ctx.find_tool(message.name) + if tool is None: continue + executables.append((tool, message)) + if len(executables) > 0: execute_tasks = [self.execute(tool, message) for tool, message in executables] events = await asyncio.gather(*execute_tasks, return_exceptions=True) diff --git a/src-server/src/agent/task/tool_call_reviewer.py b/src-server/src/agent/task/tool_call_reviewer.py index b94bbf9f..6c39dff6 100644 --- a/src-server/src/agent/task/tool_call_reviewer.py +++ b/src-server/src/agent/task/tool_call_reviewer.py @@ -1,13 +1,20 @@ from dataclasses import dataclass from loguru import logger from dais_sdk.types import ToolDef, ToolMessage +from dais_sdk.tool.prepare import prepare_tools from ..tool.types import is_tool_metadata -from ..prompts import USER_DENIED_TOOL_CALL_RESULT +from ..prompts import ( + create_one_turn_llm, + USER_DENIED_TOOL_CALL_RESULT, + ToolCallSafetyAudit, ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput, +) +from ..context import AgentContext from ..types import ( ToolDeniedEvent, ToolEvent, ToolRequirePermissionEvent, ToolRequireUserResponseEvent ) from ..types.metadata import ToolMessageMetadata, UserApprovalStatus, is_agent_tool_metadata +from ...settings import use_app_setting_manager @dataclass class ToolCallBlocked: @@ -19,6 +26,65 @@ class ToolCallApproved: ... class ToolCallReviewer: _logger = logger.bind(name="ToolCallReviewer") + def __init__(self, ctx: AgentContext): + self._ctx = ctx + + async def audit_tool_calls(self, messages: list[ToolMessage]) -> tuple[list[ToolMessage], list[ToolMessage]] | None: + """ + Side effect: The risk level will be attached to the metadata of each message. + + Returns: + - Tuple of (high_risk, low_risk) + - None if smart approve is disabled or no flash model is configured. + """ + settings = use_app_setting_manager().settings + if not settings.smart_approve: + self._logger.info("Smart approve is disabled, skipping smart approve") + return None + if settings.flash_model is None: + self._logger.warning("No flash model configured, skipping smart approve") + return None + + llm = await create_one_turn_llm(settings.flash_model) + audit_context_size = 5 + context = self._ctx.messages[-audit_context_size:] + safety_audit = ToolCallSafetyAudit(llm) + + tooldefs: list[ToolDef] = [] + for message in messages: + tool = self._ctx.find_tool(message.name) + if tool is not None: tooldefs.append(tool) + + input = ToolCallSafetyAuditInput( + tool_definitions=prepare_tools(tooldefs), + context=context, + pending_tool_calls=messages + ) + output = await safety_audit(input) + + # attach risk level to each message + for item in output.results: + for message in messages: + if message.call_id == item.call_id: + assert is_agent_tool_metadata(message.metadata) + message.metadata["risk_level"] = item.risk_level + break + else: + self._logger.warning(f"Tool call {item.call_id} not found") + continue + + # split messages into two groups: high risk and low risk + high_risk = [] + low_risk = [] + for message in messages: + assert is_agent_tool_metadata(message.metadata) + assert "risk_level" in message.metadata + if message.metadata["risk_level"] > settings.smart_approve_threshold: + high_risk.append(message) + else: + low_risk.append(message) + return high_risk, low_risk + def apply_user_approval(self, call_id: str, metadata: ToolMessageMetadata, @@ -41,7 +107,7 @@ def apply_user_approval(self, ) return True - def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlocked | ToolCallApproved: + async def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlocked | ToolCallApproved: # use TypeGuards to assert the type of metadata assert is_tool_metadata(tool.metadata) assert is_agent_tool_metadata(message.metadata) @@ -57,7 +123,6 @@ def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlock message.metadata["user_approval"] = UserApprovalStatus.PENDING match message.metadata["user_approval"]: case UserApprovalStatus.PENDING: - # TODO: implement smart approve here return ToolCallBlocked( event=ToolRequirePermissionEvent(call_id=message.call_id, tool_name=tool.name)) case UserApprovalStatus.DENIED: diff --git a/src-server/src/agent/temp_generation.py b/src-server/src/agent/temp_generation.py deleted file mode 100644 index 5a212f04..00000000 --- a/src-server/src/agent/temp_generation.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Self -from dais_sdk import LLM -from dais_sdk.providers import OpenAIProvider -from dais_sdk.types import LlmRequestParams, SystemMessage, UserMessage, AssistantMessage -from ..db import db_context -from ..db.models import provider as provider_models -from ..services.llm_model import LlmModelService - -class TempGeneration: - def __init__(self, instruction: str, model_name: str, provider: provider_models.Provider) -> None: - self._instruction = instruction - self._model_name = model_name - self._provider = provider - - @classmethod - async def create(cls, instruction: str, model_id: int) -> Self: - async with db_context() as db_session: - model = await LlmModelService(db_session).get_model_by_id(model_id) - provider = model.provider - return cls(instruction, model.name, provider) - - def _request_param_factory(self, input: str) -> LlmRequestParams: - return LlmRequestParams( - model=self._model_name, - messages=[ - SystemMessage(content=self._instruction), - UserMessage(content=input), - ], - tool_choice="none") - - async def generate(self, input: str) -> str | None: - provider = OpenAIProvider(self._provider.base_url, api_key=self._provider.api_key) - llm = LLM(provider=provider) - request_params = self._request_param_factory(input) - message = await llm.generate_text(request_params) - return message.content diff --git a/src-server/src/agent/types/metadata.py b/src-server/src/agent/types/metadata.py index 08564e45..53f171be 100644 --- a/src-server/src/agent/types/metadata.py +++ b/src-server/src/agent/types/metadata.py @@ -8,6 +8,7 @@ class UserApprovalStatus(str, Enum): class ToolMessageMetadata(TypedDict, total=False): user_approval: UserApprovalStatus + risk_level: int @staticmethod def is_agent_tool_metadata(_: dict) -> TypeGuard[ToolMessageMetadata]: diff --git a/src-server/src/agent/types/stream.py b/src-server/src/agent/types/stream.py index 46d463db..d7332c38 100644 --- a/src-server/src/agent/types/stream.py +++ b/src-server/src/agent/types/stream.py @@ -1,13 +1,13 @@ from collections.abc import AsyncGenerator -from dataclasses import dataclass from typing import Annotated, Literal, Self from dais_sdk.types import ( + Message, ToolMessage, AssistantMessage, TextChunkEvent as SdkTextChunkEvent, ToolCallChunkEvent as SdkToolCallChunkEvent ) from pydantic import BaseModel, Discriminator -from ...db.models.task import TaskMessage, TaskUsage +from ...db.models.task import TaskUsage class MessageStartEvent(BaseModel): @@ -72,7 +72,7 @@ def from_sdk(cls, message: AssistantMessage) -> Self: return cls(message=message) class MessageReplaceEvent(BaseModel): - message: TaskMessage + message: Message event_id: Literal["MESSAGE_REPLACE"] = "MESSAGE_REPLACE" class ToolCallEndEvent(BaseModel): diff --git a/src-server/src/api/routes/llm_api.py b/src-server/src/api/routes/llm_api.py index 22aa8c98..f4122755 100644 --- a/src-server/src/api/routes/llm_api.py +++ b/src-server/src/api/routes/llm_api.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel -from dais_sdk.providers import OpenAIProvider, LlmProviders +from dais_sdk import LLM +from dais_sdk.providers import LlmProviders llm_api_router = APIRouter(tags=["llm_api"]) @@ -16,6 +17,6 @@ class FetchModelsResponse(BaseModel): @llm_api_router.get("/models", response_model=FetchModelsResponse) async def fetch_models(params: FetchModelsParams = Depends(FetchModelsParams)): - provider = OpenAIProvider(params.base_url, api_key=params.api_key) + provider = LLM.create_provider(params.type, params.base_url, api_key=params.api_key) models = await provider.list_models() return FetchModelsResponse(models=models) diff --git a/src-server/src/api/routes/task/background_task.py b/src-server/src/api/routes/task/background_task.py index 4433e6a0..0b4e0236 100644 --- a/src-server/src/api/routes/task/background_task.py +++ b/src-server/src/api/routes/task/background_task.py @@ -1,13 +1,11 @@ import time from typing import TYPE_CHECKING from loguru import logger -from dais_sdk.types import UserMessage -from ....db.models.task import TaskMessage +from dais_sdk.types import Message, UserMessage from ....db import db_context from ....services.task import TaskService from ....schemas import task as task_schemas -from ....agent.prompts import TITLE_SUMMARIZATION_INSTRUCTION -from ....agent.temp_generation import TempGeneration +from ....agent.prompts import create_one_turn_llm, TitleSummarization from ....settings import use_app_setting_manager from ....api.sse_dispatcher.types import TaskTitleUpdatedEvent @@ -18,15 +16,16 @@ async def summarize_title_in_background( task_id: int, - context: list[TaskMessage], + context: list[Message], sse_dispatcher: SseDispatcher, ): settings = use_app_setting_manager().settings if settings.flash_model is None: return try: - temp_generation = await TempGeneration.create(TITLE_SUMMARIZATION_INSTRUCTION, settings.flash_model) + llm = await create_one_turn_llm(settings.flash_model) + summarizer = TitleSummarization(llm) assert isinstance(context[0], UserMessage) - title = await temp_generation.generate(context[0].content) + title = await summarizer(context[0].content) except Exception: _logger.exception("Failed to request title summarization for task {}", task_id) return diff --git a/src-server/src/api/routes/task/context_file.py b/src-server/src/api/routes/task/context_file.py index 8c2aa4c8..c19ad4b5 100644 --- a/src-server/src/api/routes/task/context_file.py +++ b/src-server/src/api/routes/task/context_file.py @@ -86,6 +86,7 @@ class SearchFileResult(BaseModel): items: list[task_schemas.ContextFileItem] total: int +# TODO: use to_thread @context_file_router.get("/files/list", response_model=ListDirectoryResult) async def list_directory( db_session: DbSessionDep, diff --git a/src-server/src/db/models/task.py b/src-server/src/db/models/task.py index 45be29a2..ef0ae024 100644 --- a/src-server/src/db/models/task.py +++ b/src-server/src/db/models/task.py @@ -1,8 +1,8 @@ import time from dataclasses import dataclass -from typing import Annotated, TYPE_CHECKING, Self -from dais_sdk.types import AssistantMessage, SystemMessage, ToolMessage, UserMessage -from pydantic import Discriminator, TypeAdapter +from typing import TYPE_CHECKING, Self +from dais_sdk.types import Message +from pydantic import TypeAdapter from sqlalchemy import ForeignKey from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -13,12 +13,8 @@ from .agent import Agent from .workspace import Workspace -TaskMessage = Annotated[ - UserMessage | AssistantMessage | SystemMessage | ToolMessage, - Discriminator("role") -] -message_adapter = TypeAdapter(TaskMessage) -messages_adapter = TypeAdapter(list[TaskMessage]) +message_adapter = TypeAdapter(Message) +messages_adapter = TypeAdapter(list[Message]) @dataclass class TaskUsage: @@ -45,7 +41,7 @@ class Task(Base): id: Mapped[int] = mapped_column(primary_key=True) title: Mapped[str] usage: Mapped[TaskUsage] = mapped_column(DataClassJSON(TaskUsage), default=TaskUsage.default) - messages: Mapped[list[TaskMessage]] = mapped_column(PydanticJSON(messages_adapter), default=list) + messages: Mapped[list[Message]] = mapped_column(PydanticJSON(messages_adapter), default=list) last_run_at: Mapped[int] = mapped_column(default=lambda: int(time.time())) agent_id: Mapped[int | None] = mapped_column(ForeignKey("agents.id", ondelete="SET NULL")) diff --git a/src-server/src/schemas/task.py b/src-server/src/schemas/task.py index d3c34655..82de955e 100644 --- a/src-server/src/schemas/task.py +++ b/src-server/src/schemas/task.py @@ -1,10 +1,11 @@ from typing import Literal from . import DTOBase -from ..db.models.task import TaskMessage, TaskUsage +from dais_sdk.types import Message +from ..db.models.task import TaskUsage class TaskBase(DTOBase): title: str - messages: list[TaskMessage] + messages: list[Message] class TaskBrief(TaskBase): id: int @@ -36,7 +37,7 @@ class TaskUpdate(DTOBase): usage: TaskUsage | None last_run_at: int agent_id: int | None - messages: list[TaskMessage] | None + messages: list[Message] | None # --- --- --- --- --- --- diff --git a/src-server/src/settings.py b/src-server/src/settings.py index 5ec12dc6..dc23b1a6 100644 --- a/src-server/src/settings.py +++ b/src-server/src/settings.py @@ -71,6 +71,9 @@ class AppSettings(JsonSettings): reply_language: str = "zh_CN" flash_model: int | None = None + smart_approve: bool = True + smart_approve_threshold: int = 50 # 0 ~ 100 + async def validate_self(self): if self.flash_model is not None: async with db_context() as db_session: diff --git a/src-server/uv.lock b/src-server/uv.lock index 6555ccff..3a1a9676 100644 --- a/src-server/uv.lock +++ b/src-server/uv.lock @@ -298,7 +298,7 @@ wheels = [ [[package]] name = "dais-sdk" -version = "0.8.10" +version = "0.8.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -308,9 +308,9 @@ dependencies = [ { name = "starlette" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/aa/10/178658fccb333cbab743adb00789a9555f74d571ad5ac260ef07ce1172c1/dais_sdk-0.8.10.tar.gz", hash = "sha256:1e5ef7c4cbf46d2e956a655ef47884960d7fa70e05d1651ba591c86491755d35", size = 18317, upload-time = "2026-03-14T05:48:43.444Z" } +sdist = { url = "https://files.pythonhosted.org/packages/26/b3/2d5a5e19eb2d562dc0abf0dd66716b5fe59eff403e12e2ed97e13e280c9b/dais_sdk-0.8.15.tar.gz", hash = "sha256:3bd3ad24444c2bd86590aeaa5d0eec99a2b5ac53e4bab0743d0793c76f23286d", size = 19352, upload-time = "2026-03-15T13:02:43.862Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6d/29/8cd77864809ee7d52d6d21f6009923018b2aa0b7046a3c31bd8633fcc2e4/dais_sdk-0.8.10-py3-none-any.whl", hash = "sha256:4bca30e074937b049e340b07b20783d02114b1b1a4996b543441b7f5381b91eb", size = 27381, upload-time = "2026-03-14T05:48:41.889Z" }, + { url = "https://files.pythonhosted.org/packages/78/6d/005d37f07da3847939da1597fdd97831c230f600a583f95987986729e38d/dais_sdk-0.8.15-py3-none-any.whl", hash = "sha256:d35486ea9982b5e5ff6c3bc9190bb2e77216d2f63426562849bb409c601a3a8e", size = 29193, upload-time = "2026-03-15T13:02:42.81Z" }, ] [[package]] @@ -358,7 +358,7 @@ requires-dist = [ { name = "aiosqlite", specifier = "==0.22.1" }, { name = "alembic", specifier = "==1.18.4" }, { name = "binaryornot", specifier = "==0.4.4" }, - { name = "dais-sdk", specifier = "==0.8.10" }, + { name = "dais-sdk", specifier = "==0.8.15" }, { name = "dais-shell", specifier = "==0.1.2" }, { name = "fastapi", specifier = "==0.135.1" }, { name = "fastapi-pagination", specifier = "==0.15.10" }, From 6cc848ea4166ea231405a917db2c6038c826e799 Mon Sep 17 00:00:00 2001 From: BHznJNs <441768875@qq.com> Date: Mon, 16 Mar 2026 16:38:54 +0800 Subject: [PATCH 2/3] smart approve frontend --- .../src/components/ai-elements/tool.tsx | 52 +++++++++++++- .../SettingsView/AgentSettings/index.tsx | 72 +++++++++++++++++++ .../SideBar/views/SettingsView/index.tsx | 12 +++- .../messages/GeneralToolMessage.tsx | 2 + src-frontend/src/i18n/locales/en/sidebar.json | 9 +++ .../src/i18n/locales/zh_CN/sidebar.json | 9 +++ .../src/stores/server-settings-store.ts | 3 +- src-frontend/src/styles/base.css | 4 ++ src-server/src/api/routes/settings.py | 1 - 9 files changed, 158 insertions(+), 6 deletions(-) create mode 100644 src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx diff --git a/src-frontend/src/components/ai-elements/tool.tsx b/src-frontend/src/components/ai-elements/tool.tsx index 1aa85fb8..a9cfb45d 100644 --- a/src-frontend/src/components/ai-elements/tool.tsx +++ b/src-frontend/src/components/ai-elements/tool.tsx @@ -42,6 +42,7 @@ export type ToolHeaderProps = { toolsetName?: string; className?: string; state: ToolState; + riskLevel?: number; }; export const getStatusBadge = (status: ToolState) => { @@ -73,6 +74,51 @@ export const getStatusBadge = (status: ToolState) => { ); }; +const getRiskBadge = (riskLevel: number) => { + const normalizedRisk = Math.min( + 100, + Math.max(0, Math.ceil(riskLevel / 10) * 10) + ); + + const configs = [ + { + label: "安全", + range: [0, 20], + className: "bg-emerald-50 text-emerald-700 dark:bg-emerald-950/60 dark:text-emerald-400", + }, + { + label: "低风险", + range: [30, 50], + className: "bg-amber-50 text-amber-700 dark:bg-amber-950/60 dark:text-amber-400", + }, + { + label: "中风险", + range: [60, 70], + className: "bg-orange-50 text-orange-700 dark:bg-orange-950/60 dark:text-orange-400", + }, + { + label: "高风险", + range: [80, 100], + className: "bg-red-50 text-red-700 dark:bg-red-950/60 dark:text-red-400", + }, + ] as const; + + const match = configs.find( + (config) => + normalizedRisk >= config.range[0] && normalizedRisk <= config.range[1] + ); + + if (!match) { + return null; + } + + return ( + + {match.label} {normalizedRisk} + + ); +}; + export type ToolBreadcrumbProps = { toolsetName?: string; toolName: string; @@ -100,6 +146,7 @@ export const ToolHeader = ({ toolsetName, state, toolName, + riskLevel, ...props }: ToolHeaderProps) => ( {getStatusBadge(state)} - +
+ {riskLevel && getRiskBadge(riskLevel)} + +
); diff --git a/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx b/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx new file mode 100644 index 00000000..5036e875 --- /dev/null +++ b/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx @@ -0,0 +1,72 @@ +import { useEffect, useState } from "react"; +import { useTranslation } from "react-i18next"; +import { useDebounceFn } from "ahooks"; +import { SIDEBAR_NAMESPACE } from "@/i18n/resources"; +import { SettingItem } from "@/components/custom/item/SettingItem"; +import { useServerSettingsStore } from "@/stores/server-settings-store"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Switch } from "@/components/ui/switch"; +import { Input } from "@/components/ui/input"; +import { AppSettings } from "@/api/generated/schemas"; + +export function AgentSettings() { + const { t } = useTranslation(SIDEBAR_NAMESPACE); + const { current: serverSettings, setPartial: setPartialServerSettings } = useServerSettingsStore(); + const [localSettings, setLocalSettings] = useState(serverSettings); + const [disabled, setDisabled] = useState(false); + + useEffect(() => { + setLocalSettings(serverSettings); + }, [serverSettings]); + + const { run: handleUpdateSettings } = useDebounceFn((update: Partial) => { + const updatePromise = setPartialServerSettings(update); + if (updatePromise === null) { + return; + } + setDisabled(true); + updatePromise.finally(() => setDisabled(false)); + }, { wait: 300 }); + + const handleValueChange = (update: Partial) => { + setLocalSettings((prev) => { + if (prev === null) { + return null; + } + return { ...prev, ...update }; + }); + handleUpdateSettings(update); + }; + + return ( +
+ + {localSettings === null ? ( + + ) : ( + handleValueChange({ smart_approve: checked })} + disabled={disabled} + /> + )} + + + + {localSettings === null ? ( + + ) : ( + handleValueChange({ smart_approve_threshold: Number(e.target.value) })} + disabled={disabled} + /> + )} + +
+ ); +} diff --git a/src-frontend/src/features/SideBar/views/SettingsView/index.tsx b/src-frontend/src/features/SideBar/views/SettingsView/index.tsx index ab8412d1..14c05324 100644 --- a/src-frontend/src/features/SideBar/views/SettingsView/index.tsx +++ b/src-frontend/src/features/SideBar/views/SettingsView/index.tsx @@ -3,11 +3,12 @@ import { ScrollArea } from "@/components/ui/scroll-area"; import { SIDEBAR_NAMESPACE } from "@/i18n/resources"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { cn } from "@/lib/utils"; -import { SideBarHeader } from "../../components/SideBarHeader"; -import { DevSettings } from "./DevSettings"; import { GeneralSettings } from "./GeneralSettings"; -import { HelperModelSettings } from "./HelperModelSettings"; import { ProviderSettings } from "./ProviderSettings"; +import { HelperModelSettings } from "./HelperModelSettings"; +import { AgentSettings } from "./AgentSettings"; +import { DevSettings } from "./DevSettings"; +import { SideBarHeader } from "../../components/SideBarHeader"; export function SettingsView() { const { t } = useTranslation(SIDEBAR_NAMESPACE); @@ -28,6 +29,11 @@ export function SettingsView() { title: t("settings.tabs.helper_model"), content: , }, + { + id: "agents", + title: t("settings.tabs.agents"), + content: , + }, { id: "dev", title: t("settings.tabs.dev"), diff --git a/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx b/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx index 2e3533bd..e45245df 100644 --- a/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx +++ b/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx @@ -7,6 +7,7 @@ import { useAgentTaskAction } from "../../hooks/use-agent-task"; import { useToolName } from "../../hooks/use-tool-name"; import { useToolState } from "../../hooks/use-tool-state"; import { shouldShowConfirmation, ToolConfirmation } from "./BuiltInToolMessage/components/ToolConfirmation"; +import { ToolMessageMetadata } from "@/api/generated/schemas"; export function GeneralToolMessage({ message }: ToolMessageProps) { const { reviewTool } = useAgentTaskAction(); @@ -21,6 +22,7 @@ export function GeneralToolMessage({ message }: ToolMessageProps) { toolName={toolName} toolsetName={toolsetName} state={toolState} + riskLevel={(message.metadata as ToolMessageMetadata).risk_level} /> diff --git a/src-frontend/src/i18n/locales/en/sidebar.json b/src-frontend/src/i18n/locales/en/sidebar.json index cb49e227..589de27c 100644 --- a/src-frontend/src/i18n/locales/en/sidebar.json +++ b/src-frontend/src/i18n/locales/en/sidebar.json @@ -27,6 +27,14 @@ "placeholder": "Plugins view" }, "settings": { + "agents": { + "smart_approve": { + "title": "Smart approve" + }, + "smart_approve_threshold": { + "title": "Smart approve threshold" + } + }, "dev": { "devtools": { "open_button": "Open", @@ -85,6 +93,7 @@ } }, "tabs": { + "agents": "Agents", "dev": "Development", "general": "General", "helper_model": "Helper model", diff --git a/src-frontend/src/i18n/locales/zh_CN/sidebar.json b/src-frontend/src/i18n/locales/zh_CN/sidebar.json index e65640a3..6de481b7 100644 --- a/src-frontend/src/i18n/locales/zh_CN/sidebar.json +++ b/src-frontend/src/i18n/locales/zh_CN/sidebar.json @@ -27,6 +27,14 @@ "placeholder": "插件视图" }, "settings": { + "agents": { + "smart_approve": { + "title": "智能批准" + }, + "smart_approve_threshold": { + "title": "智能批准阈值" + } + }, "dev": { "devtools": { "open_button": "打开", @@ -85,6 +93,7 @@ } }, "tabs": { + "agents": "Agents", "dev": "开发", "general": "通用", "helper_model": "助手模型", diff --git a/src-frontend/src/stores/server-settings-store.ts b/src-frontend/src/stores/server-settings-store.ts index 24e2cccc..5eac2e6e 100644 --- a/src-frontend/src/stores/server-settings-store.ts +++ b/src-frontend/src/stores/server-settings-store.ts @@ -6,7 +6,7 @@ type ServerSettingsStore = { current: AppSettings | null; currentPromise: Promise; isLoading: boolean; - setPartial: (settings: Partial) => void; + setPartial: (settings: Partial) => Promise | null; }; export const useServerSettingsStore = create()((set, get) => ({ @@ -28,5 +28,6 @@ export const useServerSettingsStore = create()((set, get) = return settings; }); set({ currentPromise: updatePromise }); + return updatePromise; }, })); diff --git a/src-frontend/src/styles/base.css b/src-frontend/src/styles/base.css index 5df4de68..ef101d9a 100644 --- a/src-frontend/src/styles/base.css +++ b/src-frontend/src/styles/base.css @@ -11,3 +11,7 @@ body { input[type="password"]::-ms-reveal { display: none !important; } + +.dark input[type="number"] { + color-scheme: dark; +} diff --git a/src-server/src/api/routes/settings.py b/src-server/src/api/routes/settings.py index 0182d900..99ac22bf 100644 --- a/src-server/src/api/routes/settings.py +++ b/src-server/src/api/routes/settings.py @@ -1,4 +1,3 @@ -import asyncio from fastapi import APIRouter from ...settings import AppSettings, use_app_setting_manager From 99ce8c913918b6f7e564b9141154e5e67ffe8400 Mon Sep 17 00:00:00 2001 From: BHznJNs <441768875@qq.com> Date: Mon, 16 Mar 2026 20:07:33 +0800 Subject: [PATCH 3/3] backend tool call safety audit --- src-server/pyproject.toml | 2 +- .../one_turns/tool_call_safety_audit.py | 43 +++--- src-server/src/agent/task/__init__.py | 6 +- .../src/agent/task/tool_call_dispatcher.py | 122 ++++++++++++------ .../src/agent/task/tool_call_reviewer.py | 59 ++++++--- .../agent/tool/builtin_tools/file_system.py | 2 + .../src/api/routes/task/context_file.py | 4 +- src-server/uv.lock | 8 +- 8 files changed, 158 insertions(+), 88 deletions(-) diff --git a/src-server/pyproject.toml b/src-server/pyproject.toml index 5d417bf3..a047a695 100644 --- a/src-server/pyproject.toml +++ b/src-server/pyproject.toml @@ -3,7 +3,7 @@ name = "dais-server" version = "0.1.0" requires-python = ">=3.14" dependencies = [ - "dais-sdk==0.8.15", + "dais-sdk==0.8.16", "dais-shell==0.1.2", "alembic==1.18.4", diff --git a/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py index 2a126b12..2077c4bb 100644 --- a/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py +++ b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py @@ -1,5 +1,5 @@ INSTRUCTION = """\ -You are a security review module for an AI agent system. Your sole responsibility is to assess the danger level of pending tool calls before they are executed. +You are a security review module for an AI agent system. Your sole responsibility is to assess the risk level of pending tool calls before they are executed. ## Inputs @@ -41,8 +41,8 @@ ``` - {unique id} - {name of the tool} + {unique id} + {name of the tool} {tool arguments, typically a JSON object} ... @@ -54,17 +54,22 @@ ## Your Task -For each pending tool call, assign a **danger level** from 0 to 100 (multiples of 10 only: 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100). -Output a single JSON object array, with each `tool_call_id` and its danger level string. No explanation, no additional text - only the JSON. +For each pending tool call, assign a **risk level** from 0 to 100 (multiples of 10 only: 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100). +Output a single JSON object array, with each `tool_call_id` and its risk level string. No explanation, no additional text - only the JSON. -**Output format:** -[ - {"call_id": "tool_call_id1", "danger_level": 10}, - {"call_id": "tool_call_id2", "danger_level": 50}, - ... -] +### Output format + +``` +{ + "results": [ + {"call_id": "tool_call_id1", "risk_level": 10}, + {"call_id": "tool_call_id2", "risk_level": 50}, + ... + ] +} +``` -## Danger Level Reference +## Risk Level Reference | Level | Meaning | Typical Examples | |-------|---------|-----------------| @@ -74,10 +79,10 @@ | 30 | Low risk, reversible with minor effort | Modify a non-critical config, send an internal notification | | 40 | Moderate risk, reversible but with some effort | Overwrite an existing file, restart a non-critical service, fetch a public URL (read-only GET, no sensitive data in request) | | 50 | Medium risk, partial or uncertain reversibility | Batch-modify records, send an external email, install a package | -| 60 | Elevated risk, difficult to reverse | Delete non-critical data, modify production config | +| 60 | Elevated risk, difficult to reverse | Delete non-critical data, modify production config, write to or edit files outside the working directory | | 70 | High risk, hard to reverse, potential for significant impact | Delete important files, stop a critical service, bulk data changes | | 80 | Very high risk, likely irreversible or severe impact | Drop a database table, revoke access credentials, mass-delete records | -| 90 | Extremely high risk OR danger level cannot be determined (see Rule 5) | Destructive system commands, unknown/opaque commands, missing critical context | +| 90 | Extremely high risk OR risk level cannot be determined (see Rule 5) | Destructive system commands, unknown/opaque commands, missing critical context | | 100 | Catastrophic, certainly irreversible, existential impact | Format a disk, wipe all data, terminate all system processes | ## Scoring Rules @@ -170,14 +175,14 @@ def context_xml(context: list[Message]) -> ET.Element: case _: ... # do nothing for other message types return root - def pending_tool_calls_xml(pending_tool_calls: list[ToolMessage]) -> ET.Element: + def pending_tool_calls_xml(pending_tool_messages: list[ToolMessage]) -> ET.Element: root = ET.Element("pending_tool_calls") - for tc in pending_tool_calls: + for message in pending_tool_messages: tool_call_elem = ET.SubElement(root, "tool_call") - ET.SubElement(tool_call_elem, "tool_call_id").text = tc.id - ET.SubElement(tool_call_elem, "name").text = tc.name - ET.SubElement(tool_call_elem, "arguments").text = json.dumps(tc.arguments, ensure_ascii=False) + ET.SubElement(tool_call_elem, "call_id").text = message.call_id + ET.SubElement(tool_call_elem, "name").text = message.name + ET.SubElement(tool_call_elem, "arguments").text = json.dumps(message.arguments, ensure_ascii=False) return root return "".join([ET.tostring(el, encoding="unicode") for el in ( diff --git a/src-server/src/agent/task/__init__.py b/src-server/src/agent/task/__init__.py index 76c75769..99a4473a 100644 --- a/src-server/src/agent/task/__init__.py +++ b/src-server/src/agent/task/__init__.py @@ -118,7 +118,7 @@ async def approve_tool_call(self, call_id: str, approved: bool) -> AsyncGenerato case ToolCallBlocked(event): yield event case ToolCallApproved(): - yield await self._tool_call_dispatcher.execute(tool, message=target_message) + yield await self._tool_call_dispatcher.execute(tool, target_message) yield MessageReplaceEvent(message=target_message) async def run(self) -> AgentGenerator: @@ -161,7 +161,9 @@ async def run(self) -> AgentGenerator: self._tool_call_dispatcher.dispatch(tool_call_messages) async for event in dispatch_stream: yield event - if dispatch_result.has_finished_task or len(dispatch_result.pendings) > 0: + if (dispatch_result.has_finished_task or + dispatch_result.has_blocked_tool_calls): + self._is_running = False break except GeneratorExit: _exited_by_generator_close = True diff --git a/src-server/src/agent/task/tool_call_dispatcher.py b/src-server/src/agent/task/tool_call_dispatcher.py index e4f545af..33f64c6c 100644 --- a/src-server/src/agent/task/tool_call_dispatcher.py +++ b/src-server/src/agent/task/tool_call_dispatcher.py @@ -4,10 +4,13 @@ from typing import AsyncGenerator from dais_sdk.tool import ToolCallExecutor from loguru import logger -from dais_sdk.types import ToolDef, ToolLike, ToolMessage, ToolDoesNotExistError +from dais_sdk.types import ToolDef, ToolMessage, ToolDoesNotExistError from .tool_call_reviewer import ToolCallReviewer, ToolCallBlocked, ToolCallApproved from ..context import AgentContext -from ..types import ToolEvent, ToolExecutedEvent, MessageReplaceEvent, ToolCallEndEvent, ErrorEvent +from ..types import ( + ToolEvent, ToolExecutedEvent, MessageReplaceEvent, ErrorEvent, + ToolRequirePermissionEvent, +) from ..exception_handlers import handle_tool_does_not_exist_error from ..tool import ExecutionControlToolset @@ -15,7 +18,12 @@ @dataclass class ToolCallDispatchResult: has_finished_task: bool - pendings: list[ToolMessage] + has_blocked_tool_calls: bool + +@dataclass +class ToolCallDispatch: + message: ToolMessage + tool: ToolDef class ToolCallDispatcher: _logger = logger.bind(name="ToolCallDispatcher") @@ -25,7 +33,7 @@ def __init__(self, ctx: AgentContext, tool_call_executor: ToolCallExecutor, tool self._tool_call_executor = tool_call_executor self._tool_call_reviewer = tool_call_reviewer - async def execute(self, tool: ToolDef, message: ToolMessage) -> ToolEvent: + async def execute(self, tool: ToolDef, message: ToolMessage) -> ToolExecutedEvent: """ Execute tool call and attach the result to the corresponding message. This method should not throw any exceptions. @@ -38,56 +46,92 @@ async def execute(self, tool: ToolDef, message: ToolMessage) -> ToolEvent: call_id=message.call_id, result=result if error is None else None) + async def _classify(self, + dispatches: list[ToolCallDispatch] + ) -> tuple[ + list[tuple[ToolCallApproved, ToolCallDispatch]], + list[tuple[ToolCallBlocked, ToolCallDispatch]] + ]: + approved: list[tuple[ToolCallApproved, ToolCallDispatch]] = [] + blocked: list[tuple[ToolCallBlocked, ToolCallDispatch]] = [] + + for dispatch in dispatches: + tool, message = dispatch.tool, dispatch.message + permission_check_result =\ + await self._tool_call_reviewer.check_permission(tool, message) + match permission_check_result: + case ToolCallApproved() as approved_event: + approved.append((approved_event, dispatch)) + case ToolCallBlocked() as blocked_event: + blocked.append((blocked_event, dispatch)) + + if len(blocked) == 0: + return approved, blocked + + waiting_audit: list[ToolCallDispatch] = [] + remaining_blocked: list[tuple[ToolCallBlocked, ToolCallDispatch]] = [] + for blocked_event, dispatch in blocked: + if isinstance(blocked_event.event, ToolRequirePermissionEvent): + waiting_audit.append(dispatch) + else: + remaining_blocked.append((blocked_event, dispatch)) + + audit_result =\ + await self._tool_call_reviewer.audit_tool_calls(waiting_audit) + if audit_result is None: + return approved, blocked + + high_risk, low_risk = audit_result + approved.extend((ToolCallApproved(), dispatch) for dispatch in low_risk) + blocked = remaining_blocked + [ + (ToolCallBlocked(ToolRequirePermissionEvent( + call_id=dispatch.message.call_id, + tool_name=dispatch.tool.name, + )), dispatch) + for dispatch in high_risk + ] + return approved, blocked + async def _dispatch_stream(self, tool_call_messages: list[ToolMessage], result: ToolCallDispatchResult, ) -> AsyncGenerator[ToolEvent | MessageReplaceEvent | ErrorEvent, None]: - executables = list[tuple[ToolDef, ToolMessage]]() + dispatches: list[ToolCallDispatch] = [] for message in tool_call_messages: tool = self._ctx.find_tool(message.name) if tool is None: message.error = handle_tool_does_not_exist_error(ToolDoesNotExistError(message.name)) yield MessageReplaceEvent(message=message) continue - - # Since the toolsets only contain ToolDefs, and the tools are all under toolsets, - # so we can safely assert the type of tool_def to ToolDef here. - assert isinstance(tool, ToolDef) - if tool.executes(ExecutionControlToolset.finish_task): result.has_finished_task = True + dispatches.append(ToolCallDispatch(message=message, tool=tool)) - permission_check_result = await self._tool_call_reviewer.check_permission(tool, message) - match permission_check_result: - case ToolCallBlocked(event): - yield event - yield MessageReplaceEvent(message=message) - result.pendings.append(message) - case ToolCallApproved(): - executables.append((tool, message)) - - if len(result.pendings) > 0: - audit_result = await self._tool_call_reviewer.audit_tool_calls(result.pendings) - if audit_result is not None: - high_risk, low_risk = audit_result - result.pendings = high_risk - for message in low_risk: - tool = self._ctx.find_tool(message.name) - if tool is None: continue - executables.append((tool, message)) - - if len(executables) > 0: - execute_tasks = [self.execute(tool, message) for tool, message in executables] - events = await asyncio.gather(*execute_tasks, return_exceptions=True) - for event, (_, message) in zip(events, executables): - if isinstance(event, BaseException): - self._logger.exception(f"Error in tool call {message.call_id}") - continue - yield event - yield MessageReplaceEvent(message=message) + approved, blocked = await self._classify(dispatches) + + result.has_blocked_tool_calls = len(blocked) > 0 + for blocked_event, dispatch in blocked: + yield blocked_event.event + yield MessageReplaceEvent(message=dispatch.message) + + async def execute_wrapper(dispatch: ToolCallDispatch): + executed_event = await self.execute(dispatch.tool, dispatch.message) + return executed_event, MessageReplaceEvent(message=dispatch.message) + + execute_tasks = [execute_wrapper(dispatch) for _, dispatch in approved] + for item in await asyncio.gather(*execute_tasks, return_exceptions=True): + if isinstance(item, BaseException): + self._logger.exception(f"Tool call execution error: ", exc_info=item) + continue + executed_event, replace_event = item + yield executed_event + yield replace_event def dispatch(self, tool_call_messages: list[ToolMessage]) -> tuple[AsyncGenerator[ToolEvent | MessageReplaceEvent | ErrorEvent, None], ToolCallDispatchResult]: - result = ToolCallDispatchResult(has_finished_task=False, pendings=[]) + result = ToolCallDispatchResult( + has_finished_task=False, + has_blocked_tool_calls=False, + ) return self._dispatch_stream(tool_call_messages, result), result diff --git a/src-server/src/agent/task/tool_call_reviewer.py b/src-server/src/agent/task/tool_call_reviewer.py index 6c39dff6..368ec551 100644 --- a/src-server/src/agent/task/tool_call_reviewer.py +++ b/src-server/src/agent/task/tool_call_reviewer.py @@ -1,3 +1,4 @@ +from typing import TYPE_CHECKING from dataclasses import dataclass from loguru import logger from dais_sdk.types import ToolDef, ToolMessage @@ -6,7 +7,7 @@ from ..prompts import ( create_one_turn_llm, USER_DENIED_TOOL_CALL_RESULT, - ToolCallSafetyAudit, ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput, + ToolCallSafetyAudit, ToolCallSafetyAuditInput, ) from ..context import AgentContext from ..types import ( @@ -16,6 +17,10 @@ from ..types.metadata import ToolMessageMetadata, UserApprovalStatus, is_agent_tool_metadata from ...settings import use_app_setting_manager +if TYPE_CHECKING: + from .tool_call_dispatcher import ToolCallDispatch + + @dataclass class ToolCallBlocked: event: ToolEvent @@ -29,7 +34,12 @@ class ToolCallReviewer: def __init__(self, ctx: AgentContext): self._ctx = ctx - async def audit_tool_calls(self, messages: list[ToolMessage]) -> tuple[list[ToolMessage], list[ToolMessage]] | None: + async def audit_tool_calls(self, + dispatches: list[ToolCallDispatch] + ) -> tuple[ + list[ToolCallDispatch], + list[ToolCallDispatch] + ] | None: """ Side effect: The risk level will be attached to the metadata of each message. @@ -37,6 +47,9 @@ async def audit_tool_calls(self, messages: list[ToolMessage]) -> tuple[list[Tool - Tuple of (high_risk, low_risk) - None if smart approve is disabled or no flash model is configured. """ + if len(dispatches) == 0: + return [], [] + settings = use_app_setting_manager().settings if not settings.smart_approve: self._logger.info("Smart approve is disabled, skipping smart approve") @@ -45,18 +58,20 @@ async def audit_tool_calls(self, messages: list[ToolMessage]) -> tuple[list[Tool self._logger.warning("No flash model configured, skipping smart approve") return None - llm = await create_one_turn_llm(settings.flash_model) + try: + llm = await create_one_turn_llm(settings.flash_model) + except Exception: + self._logger.exception("Failed to create LLM for smart approve") + return None + audit_context_size = 5 context = self._ctx.messages[-audit_context_size:] safety_audit = ToolCallSafetyAudit(llm) - tooldefs: list[ToolDef] = [] - for message in messages: - tool = self._ctx.find_tool(message.name) - if tool is not None: tooldefs.append(tool) - + tools = [dispatch.tool for dispatch in dispatches] + messages = [dispatch.message for dispatch in dispatches] input = ToolCallSafetyAuditInput( - tool_definitions=prepare_tools(tooldefs), + tool_definitions=prepare_tools(tools), context=context, pending_tool_calls=messages ) @@ -64,25 +79,27 @@ async def audit_tool_calls(self, messages: list[ToolMessage]) -> tuple[list[Tool # attach risk level to each message for item in output.results: - for message in messages: - if message.call_id == item.call_id: - assert is_agent_tool_metadata(message.metadata) - message.metadata["risk_level"] = item.risk_level + for dispatch in dispatches: + if dispatch.message.call_id == item.call_id: + assert is_agent_tool_metadata(dispatch.message.metadata) + dispatch.message.metadata["risk_level"] = item.risk_level break else: self._logger.warning(f"Tool call {item.call_id} not found") continue # split messages into two groups: high risk and low risk - high_risk = [] - low_risk = [] - for message in messages: - assert is_agent_tool_metadata(message.metadata) - assert "risk_level" in message.metadata - if message.metadata["risk_level"] > settings.smart_approve_threshold: - high_risk.append(message) + high_risk: list[ToolCallDispatch] = [] + low_risk: list[ToolCallDispatch] = [] + for dispatch in dispatches: + assert is_agent_tool_metadata(dispatch.message.metadata) + if "risk_level" not in dispatch.message.metadata: + self._logger.warning(f"Tool call {dispatch.message.call_id} has no risk level") + continue + if dispatch.message.metadata["risk_level"] > settings.smart_approve_threshold: + high_risk.append(dispatch) else: - low_risk.append(message) + low_risk.append(dispatch) return high_risk, low_risk def apply_user_approval(self, diff --git a/src-server/src/agent/tool/builtin_tools/file_system.py b/src-server/src/agent/tool/builtin_tools/file_system.py index ec79c355..18501c7a 100644 --- a/src-server/src/agent/tool/builtin_tools/file_system.py +++ b/src-server/src/agent/tool/builtin_tools/file_system.py @@ -10,6 +10,8 @@ from ....utils.scandir_recursive import scandir_recursive_bfs from ....utils.ignore_rules import load_gitignore_spec, should_exclude +# TODO: use to_thread to prevent blocking + class FileSystemToolset(BuiltInToolset): def __init__(self, ctx: BuiltInToolsetContext, diff --git a/src-server/src/api/routes/task/context_file.py b/src-server/src/api/routes/task/context_file.py index c19ad4b5..b93dfcb3 100644 --- a/src-server/src/api/routes/task/context_file.py +++ b/src-server/src/api/routes/task/context_file.py @@ -56,7 +56,7 @@ def _list_directory(workspace_root: Path, path: str) -> list[task_schemas.Contex type SearchCandidate = tuple[str, str, Literal["folder", "file"]] @lru_cache(maxsize=8) def _scan_cached(root: Path, scan_limit: int) -> list[SearchCandidate]: - candidates = list[SearchCandidate]() + candidates: list[SearchCandidate] = [] for entry in scandir_recursive_bfs(root, scan_limit): rel_path = Path(entry.path).relative_to(root).as_posix() candidates.append((entry.name, rel_path, "folder" if entry.is_dir() else "file")) @@ -67,7 +67,7 @@ def _search_file(query: str, workspace_root: Path, match_limit: int) -> list[tas SCORE_CUTOFF = 60 candidates = _scan_cached(workspace_root, MAX_SCAN_LIMIT) - results = list[tuple[float, task_schemas.ContextFileItem]]() + results: list[tuple[float, task_schemas.ContextFileItem]] = [] for basename, rel_path, node_type in candidates: name_score = fuzz.WRatio(query, basename) path_score = fuzz.WRatio(query, rel_path) diff --git a/src-server/uv.lock b/src-server/uv.lock index 3a1a9676..307f3a77 100644 --- a/src-server/uv.lock +++ b/src-server/uv.lock @@ -298,7 +298,7 @@ wheels = [ [[package]] name = "dais-sdk" -version = "0.8.15" +version = "0.8.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -308,9 +308,9 @@ dependencies = [ { name = "starlette" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/26/b3/2d5a5e19eb2d562dc0abf0dd66716b5fe59eff403e12e2ed97e13e280c9b/dais_sdk-0.8.15.tar.gz", hash = "sha256:3bd3ad24444c2bd86590aeaa5d0eec99a2b5ac53e4bab0743d0793c76f23286d", size = 19352, upload-time = "2026-03-15T13:02:43.862Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/c0/44208da7d77661bb206a66610da5b8a7e9eb981876abee8f24b41876f2a9/dais_sdk-0.8.16.tar.gz", hash = "sha256:b91daba337738d854cd5fcfdc7a8681ff468cd791f48fa51a06c61e9b0e12239", size = 19547, upload-time = "2026-03-16T03:23:11.601Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/6d/005d37f07da3847939da1597fdd97831c230f600a583f95987986729e38d/dais_sdk-0.8.15-py3-none-any.whl", hash = "sha256:d35486ea9982b5e5ff6c3bc9190bb2e77216d2f63426562849bb409c601a3a8e", size = 29193, upload-time = "2026-03-15T13:02:42.81Z" }, + { url = "https://files.pythonhosted.org/packages/e8/b7/eab632242fa26df0e4aa837f554024f71421b99f3311c87314b59b7f2e7b/dais_sdk-0.8.16-py3-none-any.whl", hash = "sha256:4103f6420093e9fbc971ac0d98e106351afd7038248d1e646c1452099b81e98c", size = 29492, upload-time = "2026-03-16T03:23:10.031Z" }, ] [[package]] @@ -358,7 +358,7 @@ requires-dist = [ { name = "aiosqlite", specifier = "==0.22.1" }, { name = "alembic", specifier = "==1.18.4" }, { name = "binaryornot", specifier = "==0.4.4" }, - { name = "dais-sdk", specifier = "==0.8.15" }, + { name = "dais-sdk", specifier = "==0.8.16" }, { name = "dais-shell", specifier = "==0.1.2" }, { name = "fastapi", specifier = "==0.135.1" }, { name = "fastapi-pagination", specifier = "==0.15.10" },