diff --git a/src-frontend/src/components/ai-elements/tool.tsx b/src-frontend/src/components/ai-elements/tool.tsx
index 1aa85fb8..a9cfb45d 100644
--- a/src-frontend/src/components/ai-elements/tool.tsx
+++ b/src-frontend/src/components/ai-elements/tool.tsx
@@ -42,6 +42,7 @@ export type ToolHeaderProps = {
toolsetName?: string;
className?: string;
state: ToolState;
+ riskLevel?: number;
};
export const getStatusBadge = (status: ToolState) => {
@@ -73,6 +74,51 @@ export const getStatusBadge = (status: ToolState) => {
);
};
+const getRiskBadge = (riskLevel: number) => {
+ const normalizedRisk = Math.min(
+ 100,
+ Math.max(0, Math.ceil(riskLevel / 10) * 10)
+ );
+
+ const configs = [
+ {
+ label: "安全",
+ range: [0, 20],
+ className: "bg-emerald-50 text-emerald-700 dark:bg-emerald-950/60 dark:text-emerald-400",
+ },
+ {
+ label: "低风险",
+ range: [30, 50],
+ className: "bg-amber-50 text-amber-700 dark:bg-amber-950/60 dark:text-amber-400",
+ },
+ {
+ label: "中风险",
+ range: [60, 70],
+ className: "bg-orange-50 text-orange-700 dark:bg-orange-950/60 dark:text-orange-400",
+ },
+ {
+ label: "高风险",
+ range: [80, 100],
+ className: "bg-red-50 text-red-700 dark:bg-red-950/60 dark:text-red-400",
+ },
+ ] as const;
+
+ const match = configs.find(
+ (config) =>
+ normalizedRisk >= config.range[0] && normalizedRisk <= config.range[1]
+ );
+
+ if (!match) {
+ return null;
+ }
+
+ return (
+
+ {match.label} {normalizedRisk}
+
+ );
+};
+
export type ToolBreadcrumbProps = {
toolsetName?: string;
toolName: string;
@@ -100,6 +146,7 @@ export const ToolHeader = ({
toolsetName,
state,
toolName,
+ riskLevel,
...props
}: ToolHeaderProps) => (
{getStatusBadge(state)}
-
+
+ {riskLevel != null && getRiskBadge(riskLevel)}
+
+
);
diff --git a/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx b/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx
new file mode 100644
index 00000000..5036e875
--- /dev/null
+++ b/src-frontend/src/features/SideBar/views/SettingsView/AgentSettings/index.tsx
@@ -0,0 +1,72 @@
+import { useEffect, useState } from "react";
+import { useTranslation } from "react-i18next";
+import { useDebounceFn } from "ahooks";
+import { SIDEBAR_NAMESPACE } from "@/i18n/resources";
+import { SettingItem } from "@/components/custom/item/SettingItem";
+import { useServerSettingsStore } from "@/stores/server-settings-store";
+import { Skeleton } from "@/components/ui/skeleton";
+import { Switch } from "@/components/ui/switch";
+import { Input } from "@/components/ui/input";
+import { AppSettings } from "@/api/generated/schemas";
+
+export function AgentSettings() {
+ const { t } = useTranslation(SIDEBAR_NAMESPACE);
+ const { current: serverSettings, setPartial: setPartialServerSettings } = useServerSettingsStore();
+ const [localSettings, setLocalSettings] = useState(serverSettings);
+ const [disabled, setDisabled] = useState(false);
+
+ useEffect(() => {
+ setLocalSettings(serverSettings);
+ }, [serverSettings]);
+
+ const { run: handleUpdateSettings } = useDebounceFn((update: Partial) => {
+ const updatePromise = setPartialServerSettings(update);
+ if (updatePromise === null) {
+ return;
+ }
+ setDisabled(true);
+ updatePromise.finally(() => setDisabled(false));
+ }, { wait: 300 });
+
+ const handleValueChange = (update: Partial) => {
+ setLocalSettings((prev) => {
+ if (prev === null) {
+ return null;
+ }
+ return { ...prev, ...update };
+ });
+ handleUpdateSettings(update);
+ };
+
+ return (
+
+
+ {localSettings === null ? (
+
+ ) : (
+ handleValueChange({ smart_approve: checked })}
+ disabled={disabled}
+ />
+ )}
+
+
+
+ {localSettings === null ? (
+
+ ) : (
+ handleValueChange({ smart_approve_threshold: Number(e.target.value) })}
+ disabled={disabled}
+ />
+ )}
+
+
+ );
+}
diff --git a/src-frontend/src/features/SideBar/views/SettingsView/index.tsx b/src-frontend/src/features/SideBar/views/SettingsView/index.tsx
index ab8412d1..14c05324 100644
--- a/src-frontend/src/features/SideBar/views/SettingsView/index.tsx
+++ b/src-frontend/src/features/SideBar/views/SettingsView/index.tsx
@@ -3,11 +3,12 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { SIDEBAR_NAMESPACE } from "@/i18n/resources";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { cn } from "@/lib/utils";
-import { SideBarHeader } from "../../components/SideBarHeader";
-import { DevSettings } from "./DevSettings";
import { GeneralSettings } from "./GeneralSettings";
-import { HelperModelSettings } from "./HelperModelSettings";
import { ProviderSettings } from "./ProviderSettings";
+import { HelperModelSettings } from "./HelperModelSettings";
+import { AgentSettings } from "./AgentSettings";
+import { DevSettings } from "./DevSettings";
+import { SideBarHeader } from "../../components/SideBarHeader";
export function SettingsView() {
const { t } = useTranslation(SIDEBAR_NAMESPACE);
@@ -28,6 +29,11 @@ export function SettingsView() {
title: t("settings.tabs.helper_model"),
content: ,
},
+ {
+ id: "agents",
+ title: t("settings.tabs.agents"),
+ content: ,
+ },
{
id: "dev",
title: t("settings.tabs.dev"),
diff --git a/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx b/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx
index 2e3533bd..e45245df 100644
--- a/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx
+++ b/src-frontend/src/features/Tabs/TaskPanel/components/messages/GeneralToolMessage.tsx
@@ -7,6 +7,7 @@ import { useAgentTaskAction } from "../../hooks/use-agent-task";
import { useToolName } from "../../hooks/use-tool-name";
import { useToolState } from "../../hooks/use-tool-state";
import { shouldShowConfirmation, ToolConfirmation } from "./BuiltInToolMessage/components/ToolConfirmation";
+import { ToolMessageMetadata } from "@/api/generated/schemas";
export function GeneralToolMessage({ message }: ToolMessageProps) {
const { reviewTool } = useAgentTaskAction();
@@ -21,6 +22,7 @@ export function GeneralToolMessage({ message }: ToolMessageProps) {
toolName={toolName}
toolsetName={toolsetName}
state={toolState}
+ riskLevel={(message.metadata as ToolMessageMetadata).risk_level}
/>
diff --git a/src-frontend/src/i18n/locales/en/sidebar.json b/src-frontend/src/i18n/locales/en/sidebar.json
index cb49e227..589de27c 100644
--- a/src-frontend/src/i18n/locales/en/sidebar.json
+++ b/src-frontend/src/i18n/locales/en/sidebar.json
@@ -27,6 +27,14 @@
"placeholder": "Plugins view"
},
"settings": {
+ "agents": {
+ "smart_approve": {
+ "title": "Smart approve"
+ },
+ "smart_approve_threshold": {
+ "title": "Smart approve threshold"
+ }
+ },
"dev": {
"devtools": {
"open_button": "Open",
@@ -85,6 +93,7 @@
}
},
"tabs": {
+ "agents": "Agents",
"dev": "Development",
"general": "General",
"helper_model": "Helper model",
diff --git a/src-frontend/src/i18n/locales/zh_CN/sidebar.json b/src-frontend/src/i18n/locales/zh_CN/sidebar.json
index e65640a3..6de481b7 100644
--- a/src-frontend/src/i18n/locales/zh_CN/sidebar.json
+++ b/src-frontend/src/i18n/locales/zh_CN/sidebar.json
@@ -27,6 +27,14 @@
"placeholder": "插件视图"
},
"settings": {
+ "agents": {
+ "smart_approve": {
+ "title": "智能批准"
+ },
+ "smart_approve_threshold": {
+ "title": "智能批准阈值"
+ }
+ },
"dev": {
"devtools": {
"open_button": "打开",
@@ -85,6 +93,7 @@
}
},
"tabs": {
+ "agents": "智能体",
"dev": "开发",
"general": "通用",
"helper_model": "助手模型",
diff --git a/src-frontend/src/stores/server-settings-store.ts b/src-frontend/src/stores/server-settings-store.ts
index 24e2cccc..5eac2e6e 100644
--- a/src-frontend/src/stores/server-settings-store.ts
+++ b/src-frontend/src/stores/server-settings-store.ts
@@ -6,7 +6,7 @@ type ServerSettingsStore = {
current: AppSettings | null;
currentPromise: Promise;
isLoading: boolean;
- setPartial: (settings: Partial) => void;
+ setPartial: (settings: Partial) => Promise | null;
};
export const useServerSettingsStore = create()((set, get) => ({
@@ -28,5 +28,6 @@ export const useServerSettingsStore = create()((set, get) =
return settings;
});
set({ currentPromise: updatePromise });
+ return updatePromise;
},
}));
diff --git a/src-frontend/src/styles/base.css b/src-frontend/src/styles/base.css
index 5df4de68..ef101d9a 100644
--- a/src-frontend/src/styles/base.css
+++ b/src-frontend/src/styles/base.css
@@ -11,3 +11,7 @@ body {
input[type="password"]::-ms-reveal {
display: none !important;
}
+
+.dark input[type="number"] {
+ color-scheme: dark;
+}
diff --git a/src-server/pyproject.toml b/src-server/pyproject.toml
index 2b086ade..a047a695 100644
--- a/src-server/pyproject.toml
+++ b/src-server/pyproject.toml
@@ -3,7 +3,7 @@ name = "dais-server"
version = "0.1.0"
requires-python = ">=3.14"
dependencies = [
- "dais-sdk==0.8.10",
+ "dais-sdk==0.8.16",
"dais-shell==0.1.2",
"alembic==1.18.4",
diff --git a/src-server/src/agent/context.py b/src-server/src/agent/context.py
index 1ac0ba55..e3235ee8 100644
--- a/src-server/src/agent/context.py
+++ b/src-server/src/agent/context.py
@@ -4,7 +4,7 @@
from dataclasses import asdict
from typing import Self, cast
from dais_sdk.tool import Toolset
-from dais_sdk.types import ToolDef
+from dais_sdk.types import Message, ToolDef
from .tool import use_mcp_toolset_manager, BuiltinToolsetManager, McpToolsetManager, BuiltInToolset
from .prompts import BASE_INSTRUCTION, NO_WORKSPACE_INSTRUCTION, NO_AGENT_INSTRUCTION
from .types import ContextUsage
@@ -56,7 +56,7 @@ def __init__(self,
task_id: int,
*,
usage: task_models.TaskUsage,
- messages: list[task_models.TaskMessage],
+ messages: list[Message],
workspace: workspace_schemas.WorkspaceRead,
agent: agent_schemas.AgentRead,
provider: provider_schemas.ProviderRead,
@@ -142,7 +142,7 @@ def model(self) -> provider_schemas.LlmModelRead:
return self._model
@property
- def messages(self) -> list[task_models.TaskMessage]:
+ def messages(self) -> list[Message]:
return self._messages
@property
diff --git a/src-server/src/agent/prompts/__init__.py b/src-server/src/agent/prompts/__init__.py
index a965bd94..a1b4a17c 100644
--- a/src-server/src/agent/prompts/__init__.py
+++ b/src-server/src/agent/prompts/__init__.py
@@ -1,5 +1,5 @@
from .instruction import BASE_INSTRUCTION
-from .title_summarization import TITLE_SUMMARIZATION_INSTRUCTION
+from .one_turns import *
from .built_in_agents import *
USER_IGNORED_TOOL_CALL_RESULT = "[System Message] User ignored this tool call."
diff --git a/src-server/src/agent/prompts/one_turns/__init__.py b/src-server/src/agent/prompts/one_turns/__init__.py
new file mode 100644
index 00000000..09ba7a4c
--- /dev/null
+++ b/src-server/src/agent/prompts/one_turns/__init__.py
@@ -0,0 +1,23 @@
+from dais_sdk import LLM
+from ....db import db_context
+from ....services.llm_model import LlmModelService
+
+from .title_summarization import TitleSummarization
+from .tool_call_safety_audit import ToolCallSafetyAudit, ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput
+
+async def create_one_turn_llm(model_id: int) -> LLM:
+ async with db_context() as db_session:
+ model = await LlmModelService(db_session).get_model_by_id(model_id)
+ provider = model.provider
+ provider = LLM.create_provider(provider.type,
+ provider.base_url,
+ api_key=provider.api_key)
+ return LLM(model.name, provider=provider)
+
+__all__ = [
+ "create_one_turn_llm",
+ "TitleSummarization",
+ "ToolCallSafetyAudit",
+ "ToolCallSafetyAuditInput",
+ "ToolCallSafetyAuditOutput",
+]
diff --git a/src-server/src/agent/prompts/title_summarization.py b/src-server/src/agent/prompts/one_turns/title_summarization.py
similarity index 66%
rename from src-server/src/agent/prompts/title_summarization.py
rename to src-server/src/agent/prompts/one_turns/title_summarization.py
index ad4f04c2..c1691c22 100644
--- a/src-server/src/agent/prompts/title_summarization.py
+++ b/src-server/src/agent/prompts/one_turns/title_summarization.py
@@ -1,4 +1,4 @@
-TITLE_SUMMARIZATION_INSTRUCTION = """\
+INSTRUCTION = """\
Generate a concise title for the following task/conversation in {LANGUAGE}.
CRITICAL: Your response must contain ONLY the title itself - no explanations, no "Here is the title:", no quotes, no punctuation at the end.
@@ -10,3 +10,11 @@
- Use {LANGUAGE} language
- Output format: plain text title only
"""
+
+# --- --- --- --- --- ---
+
+from dais_sdk import LLM, OneTurn
+
+class TitleSummarization(OneTurn):
+ def __init__(self, llm: LLM):
+ super().__init__(llm, INSTRUCTION, output="text")
diff --git a/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py
new file mode 100644
index 00000000..2077c4bb
--- /dev/null
+++ b/src-server/src/agent/prompts/one_turns/tool_call_safety_audit.py
@@ -0,0 +1,192 @@
+INSTRUCTION = """\
+You are a security review module for an AI agent system. Your sole responsibility is to assess the risk level of pending tool calls before they are executed.
+
+## Inputs
+
+You will receive three sections wrapped in XML tags.
+
+### Tool definitions
+
+```
+
+
+ {name of the tool}
+ {what the tool does}
+ {JSON Schema describing the tool's arguments}
+
+ ...
+
+```
+
+- Contains definitions for all tools referenced in `pending_tool_calls` part.
+- Use the `description` and `input_schema` to understand what each argument means and what the tool is capable of.
+
+### Recent conversation context
+
+```
+
+
+ {message content}
+
+ ...
+
+```
+
+- Messages are listed in chronological order (earlier messages first).
+- `role` is one of: `user`, `assistant`, or `tool` (tool result).
+- This section may be empty or sparse; if so, apply the uncertainty rule (see Rule 5 of Scoring Rules).
+
+### Pending tool calls to review
+
+```
+
+
+ {unique id}
+ {name of the tool}
+ {tool arguments, typically a JSON object}
+
+ ...
+
+```
+
+- There may be one or more `` entries (parallel tool calls).
+- Review each `` independently.
+
+## Your Task
+
+For each pending tool call, assign a **risk level** from 0 to 100 (multiples of 10 only: 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100).
+Output a single JSON object, with each `call_id` and its integer risk level. No explanation, no additional text - only the JSON.
+
+### Output format
+
+```
+{
+ "results": [
+ {"call_id": "tool_call_id1", "risk_level": 10},
+ {"call_id": "tool_call_id2", "risk_level": 50},
+ ...
+ ]
+}
+```
+
+## Risk Level Reference
+
+| Level | Meaning | Typical Examples |
+|-------|---------|-----------------|
+| 0 | Completely safe, fully reversible, read-only | Read a public file, query a database (SELECT only), get current date/time |
+| 10 | Negligible risk, effectively harmless | List directory contents, get system time, echo a string |
+| 20 | Very low risk, easily reversible | Write to a temp file, create a local log entry |
+| 30 | Low risk, reversible with minor effort | Modify a non-critical config, send an internal notification |
+| 40 | Moderate risk, reversible but with some effort | Overwrite an existing file, restart a non-critical service, fetch a public URL (read-only GET, no sensitive data in request) |
+| 50 | Medium risk, partial or uncertain reversibility | Batch-modify records, send an external email, install a package |
+| 60 | Elevated risk, difficult to reverse | Delete non-critical data, modify production config, write to or edit files outside the working directory |
+| 70 | High risk, hard to reverse, potential for significant impact | Delete important files, stop a critical service, bulk data changes |
+| 80 | Very high risk, likely irreversible or severe impact | Drop a database table, revoke access credentials, mass-delete records |
+| 90 | Extremely high risk OR risk level cannot be determined (see Rule 5) | Destructive system commands, unknown/opaque commands, missing critical context |
+| 100 | Catastrophic, certainly irreversible, existential impact | Format a disk, wipe all data, terminate all system processes |
+
+## Scoring Rules
+
+1. **Assess each tool call independently** based on: the tool name, its arguments, and the intent inferred from the recent context.
+
+2. **Consider the reversibility** of the action as the primary factor. An action that is hard or impossible to undo should score higher.
+
+3. **Consider the blast radius**: how many systems, users, or data records could be affected if this call goes wrong.
+
+4. **Consider the context fit**: if the tool call appears inconsistent with the agent's stated goal or the conversation context, increase the score.
+
+5. **Default to 90 when uncertain.** Apply this rule in any of the following situations:
+ - The tool uses commands, APIs, or parameters that you are not familiar with and whose danger cannot be inferred from naming or semantics alone.
+ - The arguments are opaque, heavily encoded, or reference external resources you cannot evaluate (e.g., a dynamically constructed shell command, a raw binary payload, an unfamiliar DSL).
+ - The recent context is missing, too sparse, or ambiguous to determine why this tool call is being made.
+ - You cannot confidently distinguish between a safe and a dangerous interpretation of the call.
+ When in doubt, always err on the side of caution and assign 90.
+
+6. **Never output values other than multiples of 10** (0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100).
+
+7. **Output only the JSON object.** No markdown, no code fences, no explanation.
+"""
+
+# --- --- --- --- --- ---
+
+import json
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass
+from typing import override
+from pydantic import BaseModel
+from dais_sdk import LLM, OneTurn
+from dais_sdk.types import (
+ ToolSchema, Message, ToolMessage
+)
+
+@dataclass
+class ToolCallSafetyAuditInput:
+ tool_definitions: list[ToolSchema]
+ context: list[Message]
+ pending_tool_calls: list[ToolMessage]
+
+class ToolCallSafetyAuditOutput(BaseModel):
+ class OutputItem(BaseModel):
+ call_id: str
+ risk_level: int
+
+    results: list[OutputItem]
+
+class ToolCallSafetyAudit(OneTurn[ToolCallSafetyAuditInput, ToolCallSafetyAuditOutput]):
+ def __init__(self, llm: LLM):
+ super().__init__(llm, INSTRUCTION, output=ToolCallSafetyAuditOutput, validate=True)
+
+ @override
+ def format_input(self, input: ToolCallSafetyAuditInput) -> str:
+ def tool_definitions_xml(tool_definitions: list[ToolSchema]) -> ET.Element:
+ root = ET.Element("tool_definitions")
+ for t in tool_definitions:
+ tool_elem = ET.SubElement(root, "tool")
+ ET.SubElement(tool_elem, "name").text = t["name"]
+ ET.SubElement(tool_elem, "description").text = t["description"]
+ ET.SubElement(tool_elem, "input_schema").text = json.dumps(t["parameters"], ensure_ascii=False)
+ return root
+
+ def context_xml(context: list[Message]) -> ET.Element:
+ root = ET.Element("context")
+ for msg in context:
+ match msg.role:
+ case "user":
+ msg_elem = ET.SubElement(root, "message", role="user")
+ ET.SubElement(msg_elem, "content").text = msg.content
+ case "assistant":
+ msg_elem = ET.SubElement(root, "message", role="assistant")
+ ET.SubElement(msg_elem, "content").text = msg.content
+ if msg.tool_calls is not None:
+ tool_calls_elem = ET.SubElement(msg_elem, "tool_calls")
+ for tool_call in msg.tool_calls:
+ tool_call_elem = ET.SubElement(tool_calls_elem, "tool_call")
+ ET.SubElement(tool_call_elem, "id").text = tool_call.id
+ ET.SubElement(tool_call_elem, "name").text = tool_call.name
+ ET.SubElement(tool_call_elem, "arguments").text = json.dumps(tool_call.arguments, ensure_ascii=False)
+ case "tool":
+ msg_elem = ET.SubElement(root, "message", role="tool")
+ ET.SubElement(msg_elem, "name").text = msg.name
+ ET.SubElement(msg_elem, "arguments").text = json.dumps(msg.arguments, ensure_ascii=False)
+ if msg.result is not None:
+ ET.SubElement(msg_elem, "result").text = msg.result
+ if msg.error is not None:
+ ET.SubElement(msg_elem, "error").text = msg.error
+ case _: ... # do nothing for other message types
+ return root
+
+ def pending_tool_calls_xml(pending_tool_messages: list[ToolMessage]) -> ET.Element:
+ root = ET.Element("pending_tool_calls")
+
+ for message in pending_tool_messages:
+ tool_call_elem = ET.SubElement(root, "tool_call")
+ ET.SubElement(tool_call_elem, "call_id").text = message.call_id
+ ET.SubElement(tool_call_elem, "name").text = message.name
+ ET.SubElement(tool_call_elem, "arguments").text = json.dumps(message.arguments, ensure_ascii=False)
+ return root
+
+ return "".join([ET.tostring(el, encoding="unicode") for el in (
+ tool_definitions_xml(input.tool_definitions),
+ context_xml(input.context),
+ pending_tool_calls_xml(input.pending_tool_calls),
+ )])
diff --git a/src-server/src/agent/task/__init__.py b/src-server/src/agent/task/__init__.py
index c192c82a..99a4473a 100644
--- a/src-server/src/agent/task/__init__.py
+++ b/src-server/src/agent/task/__init__.py
@@ -4,7 +4,7 @@
from loguru import logger
from dais_sdk.tool import ToolCallExecutor
from dais_sdk.types import (
- ToolMessage, UserMessage, AssistantMessage,
+ Message, ToolMessage, UserMessage, AssistantMessage,
ToolDef, ToolDoesNotExistError, ToolArgumentDecodeError, ToolExecutionError,
)
from ..context import AgentContext
@@ -40,10 +40,10 @@ def __init__(self, ctx: AgentContext):
self._tool_call_executor.exception_handler.set_handler(ToolArgumentDecodeError, handle_tool_argument_decode_error)
self._tool_call_executor.exception_handler.set_handler(ToolExecutionError, handle_tool_execution_error)
- self._tool_call_reviewer = ToolCallReviewer()
+ self._tool_call_reviewer = ToolCallReviewer(ctx)
self._tool_call_dispatcher = ToolCallDispatcher(self._ctx, self._tool_call_executor, self._tool_call_reviewer)
- def _find_message(self, predicate: Callable[[task_models.TaskMessage], bool]) -> task_models.TaskMessage:
+ def _find_message(self, predicate: Callable[[Message], bool]) -> Message:
for message in reversed(self._ctx.messages):
if predicate(message):
return message
@@ -113,12 +113,12 @@ async def approve_tool_call(self, call_id: str, approved: bool) -> AsyncGenerato
# so we can safely assert the type of tool_def to ToolDef here.
assert isinstance(tool, ToolDef)
- permission_check_result = self._tool_call_reviewer.check_permission(tool, target_message)
+ permission_check_result = await self._tool_call_reviewer.check_permission(tool, target_message)
match permission_check_result:
case ToolCallBlocked(event):
yield event
case ToolCallApproved():
- yield await self._tool_call_dispatcher.execute(tool, message=target_message)
+ yield await self._tool_call_dispatcher.execute(tool, target_message)
yield MessageReplaceEvent(message=target_message)
async def run(self) -> AgentGenerator:
@@ -161,7 +161,9 @@ async def run(self) -> AgentGenerator:
self._tool_call_dispatcher.dispatch(tool_call_messages)
async for event in dispatch_stream:
yield event
- if dispatch_result.has_finished_task or len(dispatch_result.pendings) > 0:
+ if (dispatch_result.has_finished_task or
+ dispatch_result.has_blocked_tool_calls):
+ self._is_running = False
break
except GeneratorExit:
_exited_by_generator_close = True
diff --git a/src-server/src/agent/task/llm_request_manager.py b/src-server/src/agent/task/llm_request_manager.py
index 12921fe4..88baf9c7 100644
--- a/src-server/src/agent/task/llm_request_manager.py
+++ b/src-server/src/agent/task/llm_request_manager.py
@@ -1,7 +1,6 @@
import asyncio
import uuid
from collections.abc import AsyncGenerator
-from dais_sdk.providers import OpenAIProvider
from loguru import logger
from dais_sdk import LLM
from dais_sdk.types import (
@@ -33,14 +32,13 @@ def llm(self) -> LLM:
return self._llm
def _llm_factory(self) -> LLM:
- provider = OpenAIProvider(
- self._ctx.provider.base_url,
- api_key=self._ctx.provider.api_key)
- return LLM(provider=provider)
+ provider = LLM.create_provider(self._ctx.provider.type,
+ self._ctx.provider.base_url,
+ api_key=self._ctx.provider.api_key)
+ return LLM(self._ctx.model.name, provider=provider)
def _create_request_param(self) -> LlmRequestParams:
params = LlmRequestParams(
- model=self._ctx.model.name,
instructions=self._ctx.system_instruction,
messages=self._ctx.messages)
usable_tool_ids = self._ctx.usable_tool_ids
diff --git a/src-server/src/agent/task/tool_call_dispatcher.py b/src-server/src/agent/task/tool_call_dispatcher.py
index 1a262629..33f64c6c 100644
--- a/src-server/src/agent/task/tool_call_dispatcher.py
+++ b/src-server/src/agent/task/tool_call_dispatcher.py
@@ -4,10 +4,13 @@
from typing import AsyncGenerator
from dais_sdk.tool import ToolCallExecutor
from loguru import logger
-from dais_sdk.types import ToolDef, ToolLike, ToolMessage, ToolDoesNotExistError
+from dais_sdk.types import ToolDef, ToolMessage, ToolDoesNotExistError
from .tool_call_reviewer import ToolCallReviewer, ToolCallBlocked, ToolCallApproved
from ..context import AgentContext
-from ..types import ToolEvent, ToolExecutedEvent, MessageReplaceEvent, ToolCallEndEvent, ErrorEvent
+from ..types import (
+ ToolEvent, ToolExecutedEvent, MessageReplaceEvent, ErrorEvent,
+ ToolRequirePermissionEvent,
+)
from ..exception_handlers import handle_tool_does_not_exist_error
from ..tool import ExecutionControlToolset
@@ -15,7 +18,12 @@
@dataclass
class ToolCallDispatchResult:
has_finished_task: bool
- pendings: list[ToolMessage]
+ has_blocked_tool_calls: bool
+
+@dataclass
+class ToolCallDispatch:
+ message: ToolMessage
+ tool: ToolDef
class ToolCallDispatcher:
_logger = logger.bind(name="ToolCallDispatcher")
@@ -25,7 +33,7 @@ def __init__(self, ctx: AgentContext, tool_call_executor: ToolCallExecutor, tool
self._tool_call_executor = tool_call_executor
self._tool_call_reviewer = tool_call_reviewer
- async def execute(self, tool: ToolLike, message: ToolMessage) -> ToolEvent:
+ async def execute(self, tool: ToolDef, message: ToolMessage) -> ToolExecutedEvent:
"""
Execute tool call and attach the result to the corresponding message.
This method should not throw any exceptions.
@@ -38,46 +46,92 @@ async def execute(self, tool: ToolLike, message: ToolMessage) -> ToolEvent:
call_id=message.call_id,
result=result if error is None else None)
+ async def _classify(self,
+ dispatches: list[ToolCallDispatch]
+ ) -> tuple[
+ list[tuple[ToolCallApproved, ToolCallDispatch]],
+ list[tuple[ToolCallBlocked, ToolCallDispatch]]
+ ]:
+ approved: list[tuple[ToolCallApproved, ToolCallDispatch]] = []
+ blocked: list[tuple[ToolCallBlocked, ToolCallDispatch]] = []
+
+ for dispatch in dispatches:
+ tool, message = dispatch.tool, dispatch.message
+ permission_check_result =\
+ await self._tool_call_reviewer.check_permission(tool, message)
+ match permission_check_result:
+ case ToolCallApproved() as approved_event:
+ approved.append((approved_event, dispatch))
+ case ToolCallBlocked() as blocked_event:
+ blocked.append((blocked_event, dispatch))
+
+ if len(blocked) == 0:
+ return approved, blocked
+
+ waiting_audit: list[ToolCallDispatch] = []
+ remaining_blocked: list[tuple[ToolCallBlocked, ToolCallDispatch]] = []
+ for blocked_event, dispatch in blocked:
+ if isinstance(blocked_event.event, ToolRequirePermissionEvent):
+ waiting_audit.append(dispatch)
+ else:
+ remaining_blocked.append((blocked_event, dispatch))
+
+ audit_result =\
+ await self._tool_call_reviewer.audit_tool_calls(waiting_audit)
+ if audit_result is None:
+ return approved, blocked
+
+ high_risk, low_risk = audit_result
+ approved.extend((ToolCallApproved(), dispatch) for dispatch in low_risk)
+ blocked = remaining_blocked + [
+ (ToolCallBlocked(ToolRequirePermissionEvent(
+ call_id=dispatch.message.call_id,
+ tool_name=dispatch.tool.name,
+ )), dispatch)
+ for dispatch in high_risk
+ ]
+ return approved, blocked
+
async def _dispatch_stream(self,
tool_call_messages: list[ToolMessage],
result: ToolCallDispatchResult,
) -> AsyncGenerator[ToolEvent
| MessageReplaceEvent
| ErrorEvent, None]:
- executables = list[tuple[ToolDef, ToolMessage]]()
+ dispatches: list[ToolCallDispatch] = []
for message in tool_call_messages:
- tool: ToolLike | None = self._ctx.find_tool(message.name)
+ tool = self._ctx.find_tool(message.name)
if tool is None:
message.error = handle_tool_does_not_exist_error(ToolDoesNotExistError(message.name))
yield MessageReplaceEvent(message=message)
continue
-
- # Since the toolsets only contain ToolDefs, and the tools are all under toolsets,
- # so we can safely assert the type of tool_def to ToolDef here.
- assert isinstance(tool, ToolDef)
-
if tool.executes(ExecutionControlToolset.finish_task):
result.has_finished_task = True
+ dispatches.append(ToolCallDispatch(message=message, tool=tool))
- permission_check_result = self._tool_call_reviewer.check_permission(tool, message)
- match permission_check_result:
- case ToolCallBlocked(event):
- yield event
- yield MessageReplaceEvent(message=message)
- result.pendings.append(message)
- case ToolCallApproved():
- executables.append((tool, message))
-
- if len(executables) > 0:
- execute_tasks = [self.execute(tool, message) for tool, message in executables]
- events = await asyncio.gather(*execute_tasks, return_exceptions=True)
- for event, (_, message) in zip(events, executables):
- if isinstance(event, BaseException):
- self._logger.exception(f"Error in tool call {message.call_id}")
- continue
- yield event
- yield MessageReplaceEvent(message=message)
+ approved, blocked = await self._classify(dispatches)
+
+ result.has_blocked_tool_calls = len(blocked) > 0
+ for blocked_event, dispatch in blocked:
+ yield blocked_event.event
+ yield MessageReplaceEvent(message=dispatch.message)
+
+ async def execute_wrapper(dispatch: ToolCallDispatch):
+ executed_event = await self.execute(dispatch.tool, dispatch.message)
+ return executed_event, MessageReplaceEvent(message=dispatch.message)
+
+ execute_tasks = [execute_wrapper(dispatch) for _, dispatch in approved]
+ for item in await asyncio.gather(*execute_tasks, return_exceptions=True):
+ if isinstance(item, BaseException):
+                self._logger.opt(exception=item).error("Tool call execution error")
+ continue
+ executed_event, replace_event = item
+ yield executed_event
+ yield replace_event
def dispatch(self, tool_call_messages: list[ToolMessage]) -> tuple[AsyncGenerator[ToolEvent | MessageReplaceEvent | ErrorEvent, None], ToolCallDispatchResult]:
- result = ToolCallDispatchResult(has_finished_task=False, pendings=[])
+ result = ToolCallDispatchResult(
+ has_finished_task=False,
+ has_blocked_tool_calls=False,
+ )
return self._dispatch_stream(tool_call_messages, result), result
diff --git a/src-server/src/agent/task/tool_call_reviewer.py b/src-server/src/agent/task/tool_call_reviewer.py
index b94bbf9f..368ec551 100644
--- a/src-server/src/agent/task/tool_call_reviewer.py
+++ b/src-server/src/agent/task/tool_call_reviewer.py
@@ -1,13 +1,25 @@
+from typing import TYPE_CHECKING
from dataclasses import dataclass
from loguru import logger
from dais_sdk.types import ToolDef, ToolMessage
+from dais_sdk.tool.prepare import prepare_tools
from ..tool.types import is_tool_metadata
-from ..prompts import USER_DENIED_TOOL_CALL_RESULT
+from ..prompts import (
+ create_one_turn_llm,
+ USER_DENIED_TOOL_CALL_RESULT,
+ ToolCallSafetyAudit, ToolCallSafetyAuditInput,
+)
+from ..context import AgentContext
from ..types import (
ToolDeniedEvent, ToolEvent,
ToolRequirePermissionEvent, ToolRequireUserResponseEvent
)
from ..types.metadata import ToolMessageMetadata, UserApprovalStatus, is_agent_tool_metadata
+from ...settings import use_app_setting_manager
+
+if TYPE_CHECKING:
+ from .tool_call_dispatcher import ToolCallDispatch
+
@dataclass
class ToolCallBlocked:
@@ -19,6 +31,77 @@ class ToolCallApproved: ...
class ToolCallReviewer:
_logger = logger.bind(name="ToolCallReviewer")
+ def __init__(self, ctx: AgentContext):
+ self._ctx = ctx
+
+ async def audit_tool_calls(self,
+ dispatches: list[ToolCallDispatch]
+ ) -> tuple[
+ list[ToolCallDispatch],
+ list[ToolCallDispatch]
+ ] | None:
+ """
+ Side effect: The risk level will be attached to the metadata of each message.
+
+ Returns:
+ - Tuple of (high_risk, low_risk)
+            - None if smart approve is disabled, no flash model is configured, or the audit LLM fails to initialize.
+ """
+ if len(dispatches) == 0:
+ return [], []
+
+ settings = use_app_setting_manager().settings
+ if not settings.smart_approve:
+ self._logger.info("Smart approve is disabled, skipping smart approve")
+ return None
+ if settings.flash_model is None:
+ self._logger.warning("No flash model configured, skipping smart approve")
+ return None
+
+ try:
+ llm = await create_one_turn_llm(settings.flash_model)
+ except Exception:
+ self._logger.exception("Failed to create LLM for smart approve")
+ return None
+
+ audit_context_size = 5
+ context = self._ctx.messages[-audit_context_size:]
+ safety_audit = ToolCallSafetyAudit(llm)
+
+ tools = [dispatch.tool for dispatch in dispatches]
+ messages = [dispatch.message for dispatch in dispatches]
+ input = ToolCallSafetyAuditInput(
+ tool_definitions=prepare_tools(tools),
+ context=context,
+ pending_tool_calls=messages
+ )
+ output = await safety_audit(input)
+
+ # attach risk level to each message
+ for item in output.results:
+ for dispatch in dispatches:
+ if dispatch.message.call_id == item.call_id:
+ assert is_agent_tool_metadata(dispatch.message.metadata)
+ dispatch.message.metadata["risk_level"] = item.risk_level
+ break
+ else:
+ self._logger.warning(f"Tool call {item.call_id} not found")
+ continue
+
+ # split messages into two groups: high risk and low risk
+ high_risk: list[ToolCallDispatch] = []
+ low_risk: list[ToolCallDispatch] = []
+ for dispatch in dispatches:
+ assert is_agent_tool_metadata(dispatch.message.metadata)
+ if "risk_level" not in dispatch.message.metadata:
+ self._logger.warning(f"Tool call {dispatch.message.call_id} has no risk level")
+ continue
+ if dispatch.message.metadata["risk_level"] > settings.smart_approve_threshold:
+ high_risk.append(dispatch)
+ else:
+ low_risk.append(dispatch)
+ return high_risk, low_risk
+
def apply_user_approval(self,
call_id: str,
metadata: ToolMessageMetadata,
@@ -41,7 +124,7 @@ def apply_user_approval(self,
)
return True
- def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlocked | ToolCallApproved:
+ async def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlocked | ToolCallApproved:
# use TypeGuards to assert the type of metadata
assert is_tool_metadata(tool.metadata)
assert is_agent_tool_metadata(message.metadata)
@@ -57,7 +140,6 @@ def check_permission(self, tool: ToolDef, message: ToolMessage) -> ToolCallBlock
message.metadata["user_approval"] = UserApprovalStatus.PENDING
match message.metadata["user_approval"]:
case UserApprovalStatus.PENDING:
- # TODO: implement smart approve here
return ToolCallBlocked(
event=ToolRequirePermissionEvent(call_id=message.call_id, tool_name=tool.name))
case UserApprovalStatus.DENIED:
diff --git a/src-server/src/agent/temp_generation.py b/src-server/src/agent/temp_generation.py
deleted file mode 100644
index 5a212f04..00000000
--- a/src-server/src/agent/temp_generation.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from typing import Self
-from dais_sdk import LLM
-from dais_sdk.providers import OpenAIProvider
-from dais_sdk.types import LlmRequestParams, SystemMessage, UserMessage, AssistantMessage
-from ..db import db_context
-from ..db.models import provider as provider_models
-from ..services.llm_model import LlmModelService
-
-class TempGeneration:
- def __init__(self, instruction: str, model_name: str, provider: provider_models.Provider) -> None:
- self._instruction = instruction
- self._model_name = model_name
- self._provider = provider
-
- @classmethod
- async def create(cls, instruction: str, model_id: int) -> Self:
- async with db_context() as db_session:
- model = await LlmModelService(db_session).get_model_by_id(model_id)
- provider = model.provider
- return cls(instruction, model.name, provider)
-
- def _request_param_factory(self, input: str) -> LlmRequestParams:
- return LlmRequestParams(
- model=self._model_name,
- messages=[
- SystemMessage(content=self._instruction),
- UserMessage(content=input),
- ],
- tool_choice="none")
-
- async def generate(self, input: str) -> str | None:
- provider = OpenAIProvider(self._provider.base_url, api_key=self._provider.api_key)
- llm = LLM(provider=provider)
- request_params = self._request_param_factory(input)
- message = await llm.generate_text(request_params)
- return message.content
diff --git a/src-server/src/agent/tool/builtin_tools/file_system.py b/src-server/src/agent/tool/builtin_tools/file_system.py
index ec79c355..18501c7a 100644
--- a/src-server/src/agent/tool/builtin_tools/file_system.py
+++ b/src-server/src/agent/tool/builtin_tools/file_system.py
@@ -10,6 +10,8 @@
from ....utils.scandir_recursive import scandir_recursive_bfs
from ....utils.ignore_rules import load_gitignore_spec, should_exclude
+# TODO: use to_thread to prevent blocking
+
class FileSystemToolset(BuiltInToolset):
def __init__(self,
ctx: BuiltInToolsetContext,
diff --git a/src-server/src/agent/types/metadata.py b/src-server/src/agent/types/metadata.py
index 08564e45..53f171be 100644
--- a/src-server/src/agent/types/metadata.py
+++ b/src-server/src/agent/types/metadata.py
@@ -8,6 +8,7 @@ class UserApprovalStatus(str, Enum):
class ToolMessageMetadata(TypedDict, total=False):
user_approval: UserApprovalStatus
+ risk_level: int
@staticmethod
def is_agent_tool_metadata(_: dict) -> TypeGuard[ToolMessageMetadata]:
diff --git a/src-server/src/agent/types/stream.py b/src-server/src/agent/types/stream.py
index 46d463db..d7332c38 100644
--- a/src-server/src/agent/types/stream.py
+++ b/src-server/src/agent/types/stream.py
@@ -1,13 +1,13 @@
from collections.abc import AsyncGenerator
-from dataclasses import dataclass
from typing import Annotated, Literal, Self
from dais_sdk.types import (
+ Message,
ToolMessage, AssistantMessage,
TextChunkEvent as SdkTextChunkEvent,
ToolCallChunkEvent as SdkToolCallChunkEvent
)
from pydantic import BaseModel, Discriminator
-from ...db.models.task import TaskMessage, TaskUsage
+from ...db.models.task import TaskUsage
class MessageStartEvent(BaseModel):
@@ -72,7 +72,7 @@ def from_sdk(cls, message: AssistantMessage) -> Self:
return cls(message=message)
class MessageReplaceEvent(BaseModel):
- message: TaskMessage
+ message: Message
event_id: Literal["MESSAGE_REPLACE"] = "MESSAGE_REPLACE"
class ToolCallEndEvent(BaseModel):
diff --git a/src-server/src/api/routes/llm_api.py b/src-server/src/api/routes/llm_api.py
index 22aa8c98..f4122755 100644
--- a/src-server/src/api/routes/llm_api.py
+++ b/src-server/src/api/routes/llm_api.py
@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from pydantic import BaseModel
-from dais_sdk.providers import OpenAIProvider, LlmProviders
+from dais_sdk import LLM
+from dais_sdk.providers import LlmProviders
llm_api_router = APIRouter(tags=["llm_api"])
@@ -16,6 +17,6 @@ class FetchModelsResponse(BaseModel):
@llm_api_router.get("/models", response_model=FetchModelsResponse)
async def fetch_models(params: FetchModelsParams = Depends(FetchModelsParams)):
- provider = OpenAIProvider(params.base_url, api_key=params.api_key)
+ provider = LLM.create_provider(params.type, params.base_url, api_key=params.api_key)
models = await provider.list_models()
return FetchModelsResponse(models=models)
diff --git a/src-server/src/api/routes/settings.py b/src-server/src/api/routes/settings.py
index 0182d900..99ac22bf 100644
--- a/src-server/src/api/routes/settings.py
+++ b/src-server/src/api/routes/settings.py
@@ -1,4 +1,3 @@
-import asyncio
from fastapi import APIRouter
from ...settings import AppSettings, use_app_setting_manager
diff --git a/src-server/src/api/routes/task/background_task.py b/src-server/src/api/routes/task/background_task.py
index 4433e6a0..0b4e0236 100644
--- a/src-server/src/api/routes/task/background_task.py
+++ b/src-server/src/api/routes/task/background_task.py
@@ -1,13 +1,11 @@
import time
from typing import TYPE_CHECKING
from loguru import logger
-from dais_sdk.types import UserMessage
-from ....db.models.task import TaskMessage
+from dais_sdk.types import Message, UserMessage
from ....db import db_context
from ....services.task import TaskService
from ....schemas import task as task_schemas
-from ....agent.prompts import TITLE_SUMMARIZATION_INSTRUCTION
-from ....agent.temp_generation import TempGeneration
+from ....agent.prompts import create_one_turn_llm, TitleSummarization
from ....settings import use_app_setting_manager
from ....api.sse_dispatcher.types import TaskTitleUpdatedEvent
@@ -18,15 +16,16 @@
async def summarize_title_in_background(
task_id: int,
- context: list[TaskMessage],
+ context: list[Message],
sse_dispatcher: SseDispatcher,
):
settings = use_app_setting_manager().settings
if settings.flash_model is None: return
try:
- temp_generation = await TempGeneration.create(TITLE_SUMMARIZATION_INSTRUCTION, settings.flash_model)
+ llm = await create_one_turn_llm(settings.flash_model)
+ summarizer = TitleSummarization(llm)
assert isinstance(context[0], UserMessage)
- title = await temp_generation.generate(context[0].content)
+ title = await summarizer(context[0].content)
except Exception:
_logger.exception("Failed to request title summarization for task {}", task_id)
return
diff --git a/src-server/src/api/routes/task/context_file.py b/src-server/src/api/routes/task/context_file.py
index 8c2aa4c8..b93dfcb3 100644
--- a/src-server/src/api/routes/task/context_file.py
+++ b/src-server/src/api/routes/task/context_file.py
@@ -56,7 +56,7 @@ def _list_directory(workspace_root: Path, path: str) -> list[task_schemas.Contex
type SearchCandidate = tuple[str, str, Literal["folder", "file"]]
@lru_cache(maxsize=8)
def _scan_cached(root: Path, scan_limit: int) -> list[SearchCandidate]:
- candidates = list[SearchCandidate]()
+ candidates: list[SearchCandidate] = []
for entry in scandir_recursive_bfs(root, scan_limit):
rel_path = Path(entry.path).relative_to(root).as_posix()
candidates.append((entry.name, rel_path, "folder" if entry.is_dir() else "file"))
@@ -67,7 +67,7 @@ def _search_file(query: str, workspace_root: Path, match_limit: int) -> list[tas
SCORE_CUTOFF = 60
candidates = _scan_cached(workspace_root, MAX_SCAN_LIMIT)
- results = list[tuple[float, task_schemas.ContextFileItem]]()
+ results: list[tuple[float, task_schemas.ContextFileItem]] = []
for basename, rel_path, node_type in candidates:
name_score = fuzz.WRatio(query, basename)
path_score = fuzz.WRatio(query, rel_path)
@@ -86,6 +86,7 @@ class SearchFileResult(BaseModel):
items: list[task_schemas.ContextFileItem]
total: int
+# TODO: use to_thread
@context_file_router.get("/files/list", response_model=ListDirectoryResult)
async def list_directory(
db_session: DbSessionDep,
diff --git a/src-server/src/db/models/task.py b/src-server/src/db/models/task.py
index 45be29a2..ef0ae024 100644
--- a/src-server/src/db/models/task.py
+++ b/src-server/src/db/models/task.py
@@ -1,8 +1,8 @@
import time
from dataclasses import dataclass
-from typing import Annotated, TYPE_CHECKING, Self
-from dais_sdk.types import AssistantMessage, SystemMessage, ToolMessage, UserMessage
-from pydantic import Discriminator, TypeAdapter
+from typing import TYPE_CHECKING, Self
+from dais_sdk.types import Message
+from pydantic import TypeAdapter
from sqlalchemy import ForeignKey
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -13,12 +13,8 @@
from .agent import Agent
from .workspace import Workspace
-TaskMessage = Annotated[
- UserMessage | AssistantMessage | SystemMessage | ToolMessage,
- Discriminator("role")
-]
-message_adapter = TypeAdapter(TaskMessage)
-messages_adapter = TypeAdapter(list[TaskMessage])
+message_adapter = TypeAdapter(Message)
+messages_adapter = TypeAdapter(list[Message])
@dataclass
class TaskUsage:
@@ -45,7 +41,7 @@ class Task(Base):
id: Mapped[int] = mapped_column(primary_key=True)
title: Mapped[str]
usage: Mapped[TaskUsage] = mapped_column(DataClassJSON(TaskUsage), default=TaskUsage.default)
- messages: Mapped[list[TaskMessage]] = mapped_column(PydanticJSON(messages_adapter), default=list)
+ messages: Mapped[list[Message]] = mapped_column(PydanticJSON(messages_adapter), default=list)
last_run_at: Mapped[int] = mapped_column(default=lambda: int(time.time()))
agent_id: Mapped[int | None] = mapped_column(ForeignKey("agents.id", ondelete="SET NULL"))
diff --git a/src-server/src/schemas/task.py b/src-server/src/schemas/task.py
index d3c34655..82de955e 100644
--- a/src-server/src/schemas/task.py
+++ b/src-server/src/schemas/task.py
@@ -1,10 +1,11 @@
from typing import Literal
from . import DTOBase
-from ..db.models.task import TaskMessage, TaskUsage
+from dais_sdk.types import Message
+from ..db.models.task import TaskUsage
class TaskBase(DTOBase):
title: str
- messages: list[TaskMessage]
+ messages: list[Message]
class TaskBrief(TaskBase):
id: int
@@ -36,7 +37,7 @@ class TaskUpdate(DTOBase):
usage: TaskUsage | None
last_run_at: int
agent_id: int | None
- messages: list[TaskMessage] | None
+ messages: list[Message] | None
# --- --- --- --- --- ---
diff --git a/src-server/src/settings.py b/src-server/src/settings.py
index 5ec12dc6..dc23b1a6 100644
--- a/src-server/src/settings.py
+++ b/src-server/src/settings.py
@@ -71,6 +71,9 @@ class AppSettings(JsonSettings):
reply_language: str = "zh_CN"
flash_model: int | None = None
+ smart_approve: bool = True
+ smart_approve_threshold: int = 50 # 0 ~ 100
+
async def validate_self(self):
if self.flash_model is not None:
async with db_context() as db_session:
diff --git a/src-server/uv.lock b/src-server/uv.lock
index 6555ccff..307f3a77 100644
--- a/src-server/uv.lock
+++ b/src-server/uv.lock
@@ -298,7 +298,7 @@ wheels = [
[[package]]
name = "dais-sdk"
-version = "0.8.10"
+version = "0.8.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
@@ -308,9 +308,9 @@ dependencies = [
{ name = "starlette" },
{ name = "uvicorn" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/aa/10/178658fccb333cbab743adb00789a9555f74d571ad5ac260ef07ce1172c1/dais_sdk-0.8.10.tar.gz", hash = "sha256:1e5ef7c4cbf46d2e956a655ef47884960d7fa70e05d1651ba591c86491755d35", size = 18317, upload-time = "2026-03-14T05:48:43.444Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/b2/c0/44208da7d77661bb206a66610da5b8a7e9eb981876abee8f24b41876f2a9/dais_sdk-0.8.16.tar.gz", hash = "sha256:b91daba337738d854cd5fcfdc7a8681ff468cd791f48fa51a06c61e9b0e12239", size = 19547, upload-time = "2026-03-16T03:23:11.601Z" }
wheels = [
- { url = "https://files.pythonhosted.org/packages/6d/29/8cd77864809ee7d52d6d21f6009923018b2aa0b7046a3c31bd8633fcc2e4/dais_sdk-0.8.10-py3-none-any.whl", hash = "sha256:4bca30e074937b049e340b07b20783d02114b1b1a4996b543441b7f5381b91eb", size = 27381, upload-time = "2026-03-14T05:48:41.889Z" },
+ { url = "https://files.pythonhosted.org/packages/e8/b7/eab632242fa26df0e4aa837f554024f71421b99f3311c87314b59b7f2e7b/dais_sdk-0.8.16-py3-none-any.whl", hash = "sha256:4103f6420093e9fbc971ac0d98e106351afd7038248d1e646c1452099b81e98c", size = 29492, upload-time = "2026-03-16T03:23:10.031Z" },
]
[[package]]
@@ -358,7 +358,7 @@ requires-dist = [
{ name = "aiosqlite", specifier = "==0.22.1" },
{ name = "alembic", specifier = "==1.18.4" },
{ name = "binaryornot", specifier = "==0.4.4" },
- { name = "dais-sdk", specifier = "==0.8.10" },
+ { name = "dais-sdk", specifier = "==0.8.16" },
{ name = "dais-shell", specifier = "==0.1.2" },
{ name = "fastapi", specifier = "==0.135.1" },
{ name = "fastapi-pagination", specifier = "==0.15.10" },