From 6f7ac820291fe9ce9f2df3e26f4c03acd0a7f314 Mon Sep 17 00:00:00 2001 From: EterUltimate <139631158+EterUltimate@users.noreply.github.com> Date: Sun, 7 Jun 2026 14:52:05 +0800 Subject: [PATCH 1/3] fix: defer provider binding during cold start --- README.md | 2 +- README_EN.md | 2 +- __init__.py | 2 +- _conf_schema.json | 6 + config.py | 7 + core/plugin_lifecycle.py | 7 + metadata.yaml | 2 +- .../core_learning/v2_learning_integration.py | 189 ++++++++++++++++-- services/embedding/factory.py | 73 ++++++- services/provider_registry.py | 179 +++++++++++++++++ services/reranker/factory.py | 73 ++++++- tests/unit/test_config.py | 3 + tests/unit/test_provider_registry_rebind.py | 162 +++++++++++++++ 13 files changed, 683 insertions(+), 24 deletions(-) create mode 100644 services/provider_registry.py create mode 100644 tests/unit/test_provider_registry_rebind.py diff --git a/README.md b/README.md index 0b19d4e1..919a1e71 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ 让 AstrBot 在群聊中持续采集、学习、审查并注入上下文,使 Bot 逐步具备表达风格、群组黑话、社交关系、长期记忆和人格演化能力。 -[![Version](https://img.shields.io/badge/version-3.1.5-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) +[![Version](https://img.shields.io/badge/version-3.1.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) diff --git a/README_EN.md b/README_EN.md index 6d656400..3b53222e 100644 --- a/README_EN.md +++ b/README_EN.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-3.1.5-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-3.1.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [Features](#what-we-can-do) · [Quick Start](#quick-start) · [Web UI](#visual-management-interface) · [Community](#community) · [Contributing](CONTRIBUTING.md) diff --git a/__init__.py b/__init__.py index f37e7720..e387b0fc 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,5 @@ # AstrBot 自学习插件 -__version__ = "3.1.5" +__version__ = "3.1.6" # Ensure parent namespace packages ("data", "data.plugins") are # durably registered in sys.modules. AstrBot loads plugins via diff --git a/_conf_schema.json b/_conf_schema.json index fd63c062..01003142 100644 --- a/_conf_schema.json +++ b/_conf_schema.json @@ -721,6 +721,12 @@ "hint": "当知识+记忆候选文档数量低于此阈值时,跳过Reranker以节省延迟。设为0则始终执行重排序", "default": 3 }, + "provider_retry_interval_seconds": { + "description": "Provider 重试间隔(秒)", + "type": "float", + "hint": "AstrBot 冷启动时 Provider 注册表可能尚未就绪。插件会按此间隔重试绑定 Embedding/Reranker Provider。", + "default": 10.0 + }, "knowledge_engine": { "description": "知识引擎", "type": "string", diff --git a/config.py b/config.py index 4e6f4dbb..0ba4a815 100644 --- a/config.py +++ b/config.py @@ -94,6 +94,7 @@ class PluginConfig(BaseModel): rerank_provider_id: Optional[str] = None rerank_top_k: int = 5 rerank_min_candidates: int = 3 # 候选文档数低于此阈值时跳过 rerank 以节省延迟 + provider_retry_interval_seconds: float = 10.0 # Provider 注册表未就绪时的重试间隔 # v2 Architecture: Knowledge engine knowledge_engine: str = "legacy" # "lightrag" | "legacy" @@ -363,6 +364,9 @@ def create_from_config(cls, config: dict, data_dir: Optional[str] = None) -> 'Pl rerank_provider_id=v2_settings.get('rerank_provider_id', None), rerank_top_k=v2_settings.get('rerank_top_k', 5), rerank_min_candidates=v2_settings.get('rerank_min_candidates', 3), + provider_retry_interval_seconds=v2_settings.get( + 'provider_retry_interval_seconds', 10.0 + ), knowledge_engine=v2_settings.get('knowledge_engine', 'legacy'), lightrag_query_mode=v2_settings.get('lightrag_query_mode', 'local'), memory_engine=v2_settings.get('memory_engine', 'legacy'), @@ -616,6 +620,9 @@ def validate_config(self) -> List[str]: if self.topic_detection_interval_messages <= 0: errors.append("话题检测触发消息数必须大于0") + if self.provider_retry_interval_seconds <= 0: + errors.append("Provider重试间隔必须大于0秒") + if self.message_min_length >= self.message_max_length: errors.append("消息最小长度必须小于最大长度") diff --git a/core/plugin_lifecycle.py b/core/plugin_lifecycle.py index 02cf67af..bf6a8c25 100644 --- a/core/plugin_lifecycle.py +++ b/core/plugin_lifecycle.py @@ -590,6 +590,13 @@ async def _delayed_provider_reinitialization(self) -> None: logger.info( f"成功配置了 {p.llm_adapter.providers_configured} 个提供商" ) + + if getattr(p, "v2_integration", None): + refreshed = await p.v2_integration.refresh_provider_bindings( + force=True + ) + if refreshed: + logger.info("V2 Provider 延迟绑定刷新完成") except Exception as e: logger.error(f"延迟重新初始化提供商配置失败: {e}") diff --git a/metadata.yaml b/metadata.yaml index 6ea4b5da..9ac4e272 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning" author: "NickMo, EterUltimate" display_name: "self-learning" description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)" -version: "3.1.5" +version: "3.1.6" repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning" tags: - "自学习" diff --git a/services/core_learning/v2_learning_integration.py b/services/core_learning/v2_learning_integration.py index 53b3c5c7..5f0a1c85 100644 --- a/services/core_learning/v2_learning_integration.py +++ b/services/core_learning/v2_learning_integration.py @@ -30,6 +30,7 @@ import asyncio import hashlib +import time from collections import defaultdict from typing import Any, Dict, List, Optional, Tuple @@ -77,6 +78,15 @@ def __init__( self._db = db_manager self._context = context self._feature_delegation = feature_delegation + self._started = False + self._provider_retry_lock = asyncio.Lock() + self._last_provider_retry: float = 0.0 + self._provider_retry_interval: float = max( + 0.1, + float(getattr(config, "provider_retry_interval_seconds", 10.0) or 10.0), + ) + self._knowledge_manager_retryable = True + self._memory_manager_retryable = True # --- Resolve framework providers via factories --------------- self._embedding_provider = self._create_embedding_provider() @@ -115,7 +125,106 @@ def __init__( async def start(self) -> None: """Start all active v2 modules that expose a ``start`` method.""" - modules: List[Tuple[str, Any]] = [ + await self.refresh_provider_bindings(force=True) + + await asyncio.gather(*( + self._start_one(name, module) + for name, module in self._active_modules() + if module and hasattr(module, "start") + )) + self._started = True + logger.info("[V2Integration] All modules started") + + async def refresh_provider_bindings(self, *, force: bool = False) -> bool: + """Retry framework provider binding and create dependent modules. + + AstrBot can load plugins before provider registries are populated. This + lets startup, warmup, and first-use paths bind providers later without a + manual plugin reload. + """ + if not self._needs_provider_or_module_retry(): + return False + + if not force and not self._provider_retry_due(): + return False + + async with self._provider_retry_lock: + if not self._needs_provider_or_module_retry(): + return False + + if not force and not self._provider_retry_due(): + return False + self._last_provider_retry = time.monotonic() + + changed = False + modules_to_start: List[Tuple[str, Any]] = [] + + if not self._embedding_provider and self._embedding_provider_configured(): + provider = self._create_embedding_provider() + if provider: + self._embedding_provider = provider + changed = True + + if not self._rerank_provider and self._rerank_provider_configured(): + provider = self._create_rerank_provider() + if provider: + self._rerank_provider = provider + changed = True + + if self._embedding_provider: + if self._knowledge_manager is None: + self._knowledge_manager = self._create_knowledge_manager() + if self._knowledge_manager: + changed = True + modules_to_start.append(( + "knowledge_manager", + self._knowledge_manager, + )) + + if self._memory_manager is None: + self._memory_manager = self._create_memory_manager() + if self._memory_manager: + changed = True + modules_to_start.append(( + "memory_manager", + self._memory_manager, + )) + + if self._exemplar_library_needs_embedding_refresh(): + self._exemplar_library = self._create_exemplar_library() + changed = True + + if changed: + self._register_trigger_operations() + if self._started and modules_to_start: + await asyncio.gather(*( + self._start_one(name, module) + for name, module in modules_to_start + if module and hasattr(module, "start") + )) + logger.info( + "[V2Integration] Provider bindings refreshed — " + f"embedding={'yes' if self._embedding_provider else 'no'}, " + f"reranker={'yes' if self._rerank_provider else 'no'}" + ) + return changed + + def _provider_retry_due(self) -> bool: + return ( + time.monotonic() - self._last_provider_retry + >= self._provider_retry_interval + ) + + async def _start_one(self, name: str, module: Any) -> None: + try: + await module.start() + except Exception as exc: + logger.warning( + f"[V2Integration] {name} start failed: {exc}" + ) + + def _active_modules(self) -> List[Tuple[str, Any]]: + return [ ("knowledge_manager", self._knowledge_manager), ("memory_manager", self._memory_manager), ("exemplar_library", self._exemplar_library), @@ -123,20 +232,38 @@ async def start(self) -> None: ("jargon_filter", self._jargon_filter), ] - async def _start_one(name: str, module: Any) -> None: - try: - await module.start() - except Exception as exc: - logger.warning( - f"[V2Integration] {name} start failed: {exc}" - ) + def _needs_provider_or_module_retry(self) -> bool: + if self._embedding_provider_configured() and not self._embedding_provider: + return True + if self._rerank_provider_configured() and not self._rerank_provider: + return True + if self._embedding_provider and self._knowledge_manager is None: + return ( + self._config.knowledge_engine == "lightrag" + and self._knowledge_manager_retryable + ) + if self._embedding_provider and self._memory_manager is None: + return ( + self._config.memory_engine == "mem0" + and not self._memory_delegated() + and self._memory_manager_retryable + ) + return self._exemplar_library_needs_embedding_refresh() - await asyncio.gather(*( - _start_one(name, module) - for name, module in modules - if module and hasattr(module, "start") - )) - logger.info("[V2Integration] All modules started") + def _embedding_provider_configured(self) -> bool: + return bool( + str(getattr(self._config, "embedding_provider_id", "") or "").strip() + ) + + def _rerank_provider_configured(self) -> bool: + return bool( + str(getattr(self._config, "rerank_provider_id", "") or "").strip() + ) + + def _exemplar_library_needs_embedding_refresh(self) -> bool: + if not (self._db and self._embedding_provider and self._exemplar_library): + return False + return getattr(self._exemplar_library, "_embedding", None) is None async def warmup(self, group_ids: List[str]) -> None: """Pre-warm heavyweight module instances for *group_ids*. @@ -146,6 +273,7 @@ async def warmup(self, group_ids: List[str]) -> None: (each cold-start avoids a 12-15s initialisation penalty on the first user query). """ + await self.refresh_provider_bindings() if ( self._knowledge_manager and hasattr(self._knowledge_manager, "warmup_instances") @@ -228,6 +356,7 @@ async def process_message( Tier 1 operations run concurrently on every message. Tier 2 operations fire when their policies are satisfied. """ + await self.refresh_provider_bindings() return await self._trigger.process_message(message, group_id) @monitored @@ -256,6 +385,8 @@ async def get_enhanced_context( All retrieval tasks run concurrently via ``asyncio.gather`` to minimise total latency. """ + await self.refresh_provider_bindings() + # --- Check query result cache --- cache_key = self._make_cache_key(query, group_id) cached_result = self._cache.get("context", cache_key) @@ -403,10 +534,17 @@ def _create_knowledge_manager(self) -> Optional[Any]: """Create knowledge manager based on configured engine.""" if self._config.knowledge_engine == "lightrag": if not self._embedding_provider: - logger.warning( - "[V2Integration] LightRAG requires an embedding provider " - "but none is available; knowledge engine disabled" - ) + if self._embedding_provider_configured(): + logger.info( + "[V2Integration] LightRAG is waiting for the " + "embedding provider registry to become ready" + ) + else: + logger.warning( + "[V2Integration] LightRAG requires an embedding " + "provider; configure embedding_provider_id or use " + "the legacy knowledge engine" + ) return None try: from ..integration import LightRAGKnowledgeManager @@ -414,6 +552,7 @@ def _create_knowledge_manager(self) -> Optional[Any]: self._config, self._llm, self._embedding_provider ) except ImportError: + self._knowledge_manager_retryable = False logger.warning( "[V2Integration] lightrag-hku not installed, " "falling back to legacy knowledge engine" @@ -433,12 +572,26 @@ def _create_memory_manager(self) -> Optional[Any]: logger.info("[V2Integration] Memory engine skipped: delegated to LivingMemory") return None if self._config.memory_engine == "mem0": + if not self._embedding_provider: + if self._embedding_provider_configured(): + logger.info( + "[V2Integration] Mem0 is waiting for the embedding " + "provider registry to become ready" + ) + else: + logger.warning( + "[V2Integration] Mem0 requires an embedding provider; " + "configure embedding_provider_id or use the legacy " + "memory engine" + ) + return None try: from ..integration import Mem0MemoryManager return Mem0MemoryManager( self._config, self._llm, self._embedding_provider ) except ImportError: + self._memory_manager_retryable = False logger.warning( "[V2Integration] mem0ai not installed, " "falling back to legacy memory engine" diff --git a/services/embedding/factory.py b/services/embedding/factory.py index c0a3f6ec..25e4a675 100644 --- a/services/embedding/factory.py +++ b/services/embedding/factory.py @@ -14,6 +14,12 @@ from astrbot.api import logger from astrbot.core.provider.provider import EmbeddingProvider +from ..provider_registry import ( + collect_framework_providers, + find_provider_by_id, + framework_registry_has_any_provider, + normalize_provider_id, +) from .base import IEmbeddingProvider from .framework_adapter import FrameworkEmbeddingAdapter @@ -41,7 +47,9 @@ def create(config, context) -> Optional[IEmbeddingProvider]: An ``IEmbeddingProvider`` instance, or ``None`` if embedding is not configured. """ - provider_id = getattr(config, "embedding_provider_id", None) + provider_id = normalize_provider_id( + getattr(config, "embedding_provider_id", None) + ) if not provider_id: logger.debug( @@ -66,6 +74,50 @@ def _resolve_framework_provider( provider_id: str, context ) -> Optional[IEmbeddingProvider]: """Resolve the framework provider by ID and wrap in adapter.""" + providers, inspected, errors = collect_framework_providers( + context, + EmbeddingProvider, + context_getter_name="get_all_embedding_providers", + manager_list_name="embedding_provider_insts", + ) + for error in errors: + logger.debug(f"[EmbeddingFactory] Registry inspection failed: {error}") + + if inspected: + provider = find_provider_by_id(providers, provider_id) + if provider is not None: + return EmbeddingProviderFactory._wrap_provider( + provider_id, provider + ) + + if not providers: + registry_has_any_provider, _, any_errors = ( + framework_registry_has_any_provider(context) + ) + for error in any_errors: + logger.debug( + f"[EmbeddingFactory] Provider readiness inspection failed: {error}" + ) + if not registry_has_any_provider: + logger.info( + "[EmbeddingFactory] Framework provider registry is " + "not ready; embedding provider resolution will retry later" + ) + return None + + logger.warning( + f"[EmbeddingFactory] No embedding providers are visible " + f"in the framework registry; configured id='{provider_id}'" + ) + return None + + available_ids = EmbeddingProviderFactory._provider_ids(providers) + logger.warning( + f"[EmbeddingFactory] Provider '{provider_id}' not found " + f"in embedding registry; available={available_ids}" + ) + return None + try: provider = context.get_provider_by_id(provider_id) except Exception as exc: @@ -82,6 +134,14 @@ def _resolve_framework_provider( ) return None + return EmbeddingProviderFactory._wrap_provider(provider_id, provider) + + @staticmethod + def _wrap_provider( + provider_id: str, + provider: EmbeddingProvider, + ) -> Optional[IEmbeddingProvider]: + """Validate and wrap an already-resolved framework provider.""" if not isinstance(provider, EmbeddingProvider): logger.warning( f"[EmbeddingFactory] Provider '{provider_id}' is " @@ -96,3 +156,14 @@ def _resolve_framework_provider( f"dim={adapter.get_dim()}" ) return adapter + + @staticmethod + def _provider_ids(providers) -> list[str]: + """Return visible provider IDs for diagnostics.""" + ids: list[str] = [] + for provider in providers: + try: + ids.append(provider.meta().id) + except Exception: + continue + return ids diff --git a/services/provider_registry.py b/services/provider_registry.py new file mode 100644 index 00000000..8d13ff2f --- /dev/null +++ b/services/provider_registry.py @@ -0,0 +1,179 @@ +"""AstrBot provider registry inspection helpers. + +These helpers intentionally inspect provider lists before falling back to +``context.get_provider_by_id``. During AstrBot cold start the registry can be +temporarily empty; probing by ID in that window makes the framework emit +"provider not found" warnings for an otherwise valid configuration. +""" + +from __future__ import annotations + +from typing import Any, List, Optional, Tuple, Type, TypeVar + + +T = TypeVar("T") + + +def normalize_provider_id(value: Any) -> Optional[str]: + """Return a stripped provider ID, or ``None`` for empty values.""" + if value is None: + return None + provider_id = str(value).strip() + return provider_id or None + + +def collect_framework_providers( + context: Any, + provider_cls: Type[T], + *, + context_getter_name: Optional[str] = None, + manager_list_name: Optional[str] = None, +) -> Tuple[List[T], bool, List[str]]: + """Collect known framework providers for a type. + + Returns ``(providers, inspected, errors)``. ``inspected`` is true when the + context exposed at least one list-like registry source. If it is true and + ``providers`` is empty, callers should treat the framework registry as not + ready and avoid noisy ID lookups. + """ + providers: List[T] = [] + inspected = False + errors: List[str] = [] + + if context_getter_name: + getter = _safe_getattr(context, context_getter_name) + if callable(getter): + inspected = True + try: + _extend_provider_list(providers, getter(), provider_cls) + except Exception as exc: + errors.append(f"{context_getter_name}: {exc}") + + provider_manager = _safe_getattr(context, "provider_manager") + if provider_manager is not None: + if manager_list_name: + manager_list = _safe_getattr(provider_manager, manager_list_name) + if manager_list is not None: + inspected = True + try: + _extend_provider_list(providers, manager_list, provider_cls) + except Exception as exc: + errors.append(f"provider_manager.{manager_list_name}: {exc}") + + inst_map = _safe_getattr(provider_manager, "inst_map") + if inst_map is not None: + inspected = True + try: + values = inst_map.values() if isinstance(inst_map, dict) else inst_map + _extend_provider_list(providers, values, provider_cls) + except Exception as exc: + errors.append("provider_manager.inst_map: " + str(exc)) + + return _dedupe_providers(providers), inspected, errors + + +def framework_registry_has_any_provider(context: Any) -> Tuple[bool, bool, List[str]]: + """Return whether any framework provider is visible in known registries.""" + inspected = False + errors: List[str] = [] + + for getter_name in ( + "get_all_providers", + "get_all_embedding_providers", + "get_all_rerank_providers", + ): + getter = _safe_getattr(context, getter_name) + if not callable(getter): + continue + inspected = True + try: + if _as_list(getter()): + return True, inspected, errors + except Exception as exc: + errors.append(f"{getter_name}: {exc}") + + provider_manager = _safe_getattr(context, "provider_manager") + if provider_manager is not None: + for attr_name in ( + "provider_insts", + "embedding_provider_insts", + "rerank_provider_insts", + "inst_map", + ): + value = _safe_getattr(provider_manager, attr_name) + if value is None: + continue + inspected = True + try: + if _as_list(value): + return True, inspected, errors + except Exception as exc: + errors.append(f"provider_manager.{attr_name}: {exc}") + + return False, inspected, errors + + +def find_provider_by_id(providers: List[T], provider_id: str) -> Optional[T]: + """Find a provider in a pre-inspected registry list by metadata ID.""" + for provider in providers: + try: + meta = provider.meta() # type: ignore[attr-defined] + except Exception: + continue + if getattr(meta, "id", None) == provider_id: + return provider + return None + + +def _safe_getattr(obj: Any, name: str, default: Any = None) -> Any: + """Read attributes without letting unconfigured Mock children look real.""" + if obj is None: + return default + + try: + attrs = vars(obj) + except TypeError: + attrs = {} + + if name in attrs: + return attrs[name] + + if type(obj).__module__.startswith("unittest.mock"): + return default + + try: + return getattr(obj, name) + except Exception: + return default + + +def _extend_provider_list( + providers: List[T], + raw_value: Any, + provider_cls: Type[T], +) -> None: + for candidate in _as_list(raw_value): + if isinstance(candidate, provider_cls): + providers.append(candidate) + + +def _as_list(value: Any) -> List[Any]: + if value is None: + return [] + if isinstance(value, dict): + return list(value.values()) + if isinstance(value, (list, tuple, set)): + return list(value) + return [value] + + +def _dedupe_providers(providers: List[T]) -> List[T]: + deduped: List[T] = [] + seen = set() + for provider in providers: + identity = id(provider) + if identity in seen: + continue + seen.add(identity) + deduped.append(provider) + return deduped diff --git a/services/reranker/factory.py b/services/reranker/factory.py index 956ff801..c747ef79 100644 --- a/services/reranker/factory.py +++ b/services/reranker/factory.py @@ -10,6 +10,12 @@ from astrbot.api import logger from astrbot.core.provider.provider import RerankProvider as FrameworkRerankProvider +from ..provider_registry import ( + collect_framework_providers, + find_provider_by_id, + framework_registry_has_any_provider, + normalize_provider_id, +) from .base import IRerankProvider from .framework_adapter import FrameworkRerankAdapter @@ -35,7 +41,9 @@ def create(config, context) -> Optional[IRerankProvider]: Returns: An ``IRerankProvider`` instance, or ``None`` if not configured. """ - provider_id = getattr(config, "rerank_provider_id", None) + provider_id = normalize_provider_id( + getattr(config, "rerank_provider_id", None) + ) if not provider_id: logger.debug( @@ -51,6 +59,50 @@ def create(config, context) -> Optional[IRerankProvider]: ) return None + providers, inspected, errors = collect_framework_providers( + context, + FrameworkRerankProvider, + context_getter_name="get_all_rerank_providers", + manager_list_name="rerank_provider_insts", + ) + for error in errors: + logger.debug(f"[RerankFactory] Registry inspection failed: {error}") + + if inspected: + provider = find_provider_by_id(providers, provider_id) + if provider is not None: + return RerankProviderFactory._wrap_provider( + provider_id, provider + ) + + if not providers: + registry_has_any_provider, _, any_errors = ( + framework_registry_has_any_provider(context) + ) + for error in any_errors: + logger.debug( + f"[RerankFactory] Provider readiness inspection failed: {error}" + ) + if not registry_has_any_provider: + logger.info( + "[RerankFactory] Framework provider registry is " + "not ready; reranker provider resolution will retry later" + ) + return None + + logger.warning( + f"[RerankFactory] No rerank providers are visible " + f"in the framework registry; configured id='{provider_id}'" + ) + return None + + available_ids = RerankProviderFactory._provider_ids(providers) + logger.warning( + f"[RerankFactory] Provider '{provider_id}' not found " + f"in rerank registry; available={available_ids}" + ) + return None + try: provider = context.get_provider_by_id(provider_id) except Exception as exc: @@ -67,6 +119,14 @@ def create(config, context) -> Optional[IRerankProvider]: ) return None + return RerankProviderFactory._wrap_provider(provider_id, provider) + + @staticmethod + def _wrap_provider( + provider_id: str, + provider: FrameworkRerankProvider, + ) -> Optional[IRerankProvider]: + """Validate and wrap an already-resolved framework provider.""" if not isinstance(provider, FrameworkRerankProvider): logger.warning( f"[RerankFactory] Provider '{provider_id}' is " @@ -80,3 +140,14 @@ def create(config, context) -> Optional[IRerankProvider]: f"id={provider_id}, model={adapter.get_model_name()}" ) return adapter + + @staticmethod + def _provider_ids(providers) -> list[str]: + """Return visible provider IDs for diagnostics.""" + ids: list[str] = [] + for provider in providers: + try: + ids.append(provider.meta().id) + except Exception: + continue + return ids diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 44650228..b7ad4608 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -85,6 +85,7 @@ def test_default_provider_ids_none(self): assert config.reinforce_provider_id is None assert config.embedding_provider_id is None assert config.rerank_provider_id is None + assert config.provider_retry_interval_seconds == 10.0 def test_sqlalchemy_always_true(self): """Test that use_sqlalchemy is always True (hardcoded).""" @@ -226,6 +227,7 @@ def test_create_from_config_with_v2_settings(self): 'V2_Architecture_Settings': { 'embedding_provider_id': 'embed_provider', 'rerank_provider_id': 'rerank_provider', + 'provider_retry_interval_seconds': 2.5, 'knowledge_engine': 'lightrag', 'memory_engine': 'mem0', } @@ -235,6 +237,7 @@ def test_create_from_config_with_v2_settings(self): assert config.embedding_provider_id == 'embed_provider' assert config.rerank_provider_id == 'rerank_provider' + assert config.provider_retry_interval_seconds == 2.5 assert config.knowledge_engine == 'lightrag' assert config.memory_engine == 'mem0' diff --git a/tests/unit/test_provider_registry_rebind.py b/tests/unit/test_provider_registry_rebind.py new file mode 100644 index 00000000..314d3dbf --- /dev/null +++ b/tests/unit/test_provider_registry_rebind.py @@ -0,0 +1,162 @@ +"""Regression coverage for AstrBot provider registry cold-start handling.""" + +import importlib +import importlib.util +import sys +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from astrbot.core.provider.provider import ( + EmbeddingProvider, + RerankProvider as FrameworkRerankProvider, +) + +PLUGIN_ROOT = Path(__file__).resolve().parents[2] + + +def _load_plugin_package(alias: str): + spec = importlib.util.spec_from_file_location( + alias, + PLUGIN_ROOT / "__init__.py", + submodule_search_locations=[str(PLUGIN_ROOT)], + ) + module = importlib.util.module_from_spec(spec) + sys.modules[alias] = module + spec.loader.exec_module(module) + return module + + +def _cleanup_alias(alias: str) -> None: + for name in list(sys.modules): + if name == alias or name.startswith(f"{alias}."): + sys.modules.pop(name, None) + + +@pytest.fixture +def plugin_modules(): + alias = "data.plugins.astrbot_plugin_self_learning_provider_rebind_test" + _cleanup_alias(alias) + _load_plugin_package(alias) + try: + yield SimpleNamespace( + PluginConfig=importlib.import_module(f"{alias}.config").PluginConfig, + V2LearningIntegration=importlib.import_module( + f"{alias}.services.core_learning.v2_learning_integration" + ).V2LearningIntegration, + EmbeddingProviderFactory=importlib.import_module( + f"{alias}.services.embedding.factory" + ).EmbeddingProviderFactory, + RerankProviderFactory=importlib.import_module( + f"{alias}.services.reranker.factory" + ).RerankProviderFactory, + ) + finally: + _cleanup_alias(alias) + + +class DummyEmbeddingProvider(EmbeddingProvider): + def __init__(self, provider_id: str = "embed-a") -> None: + super().__init__({"id": provider_id, "type": "dummy_embedding"}, {}) + self._id = provider_id + self.set_model("dummy-embedding") + + def meta(self): + return SimpleNamespace(id=self._id) + + async def get_embedding(self, text: str) -> list[float]: + return [1.0, 2.0, 3.0] + + async def get_embeddings(self, text: list[str]) -> list[list[float]]: + return [[1.0, 2.0, 3.0] for _ in text] + + def get_dim(self) -> int: + return 3 + + +class DummyRerankProvider(FrameworkRerankProvider): + def __init__(self, provider_id: str = "rerank-a") -> None: + super().__init__({"id": provider_id, "type": "dummy_rerank"}, {}) + self._id = provider_id + self.set_model("dummy-rerank") + + def meta(self): + return SimpleNamespace(id=self._id) + + async def rerank(self, query: str, documents: list[str], top_n=None): + return [] + + +def test_embedding_factory_waits_when_framework_registry_is_empty(plugin_modules): + config = SimpleNamespace(embedding_provider_id="embedding") + context = SimpleNamespace( + get_all_providers=Mock(return_value=[]), + get_all_embedding_providers=Mock(return_value=[]), + get_provider_by_id=Mock(side_effect=AssertionError("should not query by id")), + ) + + provider = plugin_modules.EmbeddingProviderFactory.create(config, context) + + assert provider is None + context.get_provider_by_id.assert_not_called() + + +def test_rerank_factory_waits_when_framework_registry_is_empty(plugin_modules): + config = SimpleNamespace(rerank_provider_id="rerank") + context = SimpleNamespace( + get_all_providers=Mock(return_value=[]), + get_all_rerank_providers=Mock(return_value=[]), + get_provider_by_id=Mock(side_effect=AssertionError("should not query by id")), + ) + + provider = plugin_modules.RerankProviderFactory.create(config, context) + + assert provider is None + context.get_provider_by_id.assert_not_called() + + +@pytest.mark.asyncio +async def test_v2_integration_rebinds_providers_after_registry_becomes_ready(plugin_modules): + class MinimalV2LearningIntegration(plugin_modules.V2LearningIntegration): + """Keep provider rebinding tests focused on provider wiring only.""" + + def _create_social_analyzer(self): + return None + + def _create_jargon_filter(self): + return None + + embedding_providers = [] + rerank_providers = [] + context = SimpleNamespace( + get_all_providers=Mock(return_value=[]), + get_all_embedding_providers=Mock(side_effect=lambda: list(embedding_providers)), + get_all_rerank_providers=Mock(side_effect=lambda: list(rerank_providers)), + get_provider_by_id=Mock(side_effect=AssertionError("should not query by id")), + ) + config = plugin_modules.PluginConfig( + embedding_provider_id="embed-a", + rerank_provider_id="rerank-a", + knowledge_engine="legacy", + memory_engine="legacy", + ) + integration = MinimalV2LearningIntegration( + config=config, + llm_adapter=None, + db_manager=None, + context=context, + ) + + assert integration._embedding_provider is None + assert integration._rerank_provider is None + + embedding_providers.append(DummyEmbeddingProvider("embed-a")) + rerank_providers.append(DummyRerankProvider("rerank-a")) + + refreshed = await integration.refresh_provider_bindings(force=True) + + assert refreshed is True + assert integration._embedding_provider.provider_id == "embed-a" + assert integration._rerank_provider.provider_id == "rerank-a" + context.get_provider_by_id.assert_not_called() From 28de144105e6ab0bd1ed6430400da99b6cb73263 Mon Sep 17 00:00:00 2001 From: EterUltimate <1831303476@qq.com> Date: Sun, 7 Jun 2026 18:08:19 +0800 Subject: [PATCH 2/3] feat: modularize learning dashboard for 3.2.0 --- CHANGELOG.md | 16 + README.md | 2 +- README_EN.md | 2 +- __init__.py | 2 +- docs/README.md | 2 +- metadata.yaml | 2 +- .../core_learning/progressive_learning.py | 733 ++------- services/learning/expression_learning.py | 305 ++++ services/learning/jargon_learning.py | 160 ++ services/learning/message_pipeline.py | 121 +- services/learning/persona_learning.py | 515 ++++++ web_res/static/html/dashboard.html | 1406 ++++++++++++++++- 12 files changed, 2526 insertions(+), 740 deletions(-) create mode 100644 services/learning/expression_learning.py create mode 100644 services/learning/jargon_learning.py create mode 100644 services/learning/persona_learning.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c7c2f14b..cf2ba829 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,22 @@ 所有重要更改都将记录在此文件中。 +## [3.2.0] - 2026-06-07 + +### 学习模块 + +- 将黑话学习、表达方式学习和人格学习拆分为独立服务模块,学习流水线与渐进学习流程改为委托给专门模块处理。 + +### WebUI + +- Dashboard 新增学习模块控制台,汇总全局学习效率、待办、内容样本和最近批次,并提供黑话、表达方式、人格、审查、内容和监控快捷入口。 +- 新增黑话学习、表达方式学习、人格学习三个独立 WebUI 子模块页面,分别展示 KPI、列表、快捷操作和 ECharts 可视化图表。 +- 增加 iOS-like 指针物理反馈、按压回弹和移动端响应式收敛,并尊重 `prefers-reduced-motion`。 + +### 版本 + +- 将插件发布版本号提升至 `3.2.0`。 + ## [3.0.15] - 2026-05-30 ### WebUI diff --git a/README.md b/README.md index 919a1e71..ee6e1989 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ 让 AstrBot 在群聊中持续采集、学习、审查并注入上下文,使 Bot 逐步具备表达风格、群组黑话、社交关系、长期记忆和人格演化能力。 -[![Version](https://img.shields.io/badge/version-3.1.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) +[![Version](https://img.shields.io/badge/version-3.2.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) diff --git a/README_EN.md b/README_EN.md index 3b53222e..4b372887 100644 --- a/README_EN.md +++ b/README_EN.md @@ -14,7 +14,7 @@
-[![Version](https://img.shields.io/badge/version-3.1.6-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) +[![Version](https://img.shields.io/badge/version-3.2.0-blue.svg)](https://github.com/NickCharlie/astrbot_plugin_self_learning) [![License](https://img.shields.io/badge/license-AGPL--3.0-green.svg)](LICENSE) [![AstrBot](https://img.shields.io/badge/AstrBot-%3E%3D4.11.4-orange.svg)](https://github.com/Soulter/AstrBot) [![Python](https://img.shields.io/badge/python-3.11%2B-blue.svg)](https://www.python.org/) [Features](#what-we-can-do) · [Quick Start](#quick-start) · [Web UI](#visual-management-interface) · [Community](#community) · [Contributing](CONTRIBUTING.md) diff --git a/__init__.py b/__init__.py index e387b0fc..572ef848 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,5 @@ # AstrBot 自学习插件 -__version__ = "3.1.6" +__version__ = "3.2.0" # Ensure parent namespace packages ("data", "data.plugins") are # durably registered in sys.modules. AstrBot loads plugins via diff --git a/docs/README.md b/docs/README.md index 315be1d9..3842dadb 100644 --- a/docs/README.md +++ b/docs/README.md @@ -6,7 +6,7 @@ AstrBot 自主学习插件的实现文档和使用文档。 - 插件名: `astrbot_plugin_self_learning` - 展示名: `self-learning` -- 当前元数据版本: `3.0.5` +- 当前元数据版本: `3.2.0` - 最低 AstrBot 版本: `4.11.4` - 主要入口: `main.py` - 配置入口: `_conf_schema.json`, `config.py` diff --git a/metadata.yaml b/metadata.yaml index 9ac4e272..e02ef57c 100644 --- a/metadata.yaml +++ b/metadata.yaml @@ -2,7 +2,7 @@ name: "astrbot_plugin_self_learning" author: "NickMo, EterUltimate" display_name: "self-learning" description: "SELF LEARNING 自主学习插件 — 让 AI 聊天机器人自主学习对话风格、理解群组黑话、管理社交关系与好感度、自适应人格演化,像真人一样自然对话。(使用前必须手动备份人格数据)" -version: "3.1.6" +version: "3.2.0" repo: "https://github.com/NickCharlie/astrbot_plugin_self_learning" tags: - "自学习" diff --git a/services/core_learning/progressive_learning.py b/services/core_learning/progressive_learning.py index e88696d3..c21fa3ff 100644 --- a/services/core_learning/progressive_learning.py +++ b/services/core_learning/progressive_learning.py @@ -5,24 +5,19 @@ import json import time from typing import Dict, List, Optional, Any -from datetime import datetime, timedelta +from datetime import datetime from dataclasses import dataclass from astrbot.api import logger from astrbot.api.star import Context from ...config import PluginConfig -from ...constants import ( - UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, - UPDATE_TYPE_STYLE_LEARNING, -) from ...exceptions import LearningError -from ...utils.json_utils import safe_parse_llm_json, clean_llm_json_response -from ...utils.persona_selection import get_persona_identifier, resolve_target_persona - from ..database import DatabaseManager -from ..learning.sample_filter import filter_learning_messages, should_ignore_learning_sample +from ..learning.expression_learning import ExpressionLearningModule +from ..learning.persona_learning import PersonaLearningModule +from ..learning.sample_filter import filter_learning_messages @dataclass @@ -62,6 +57,21 @@ def __init__(self, config: PluginConfig, context: Context, self.persona_manager = persona_manager # 注入 persona_manager self.ml_analyzer = ml_analyzer # 注入 ml_analyzer self.prompts = prompts # 保存 prompts 实例 + + # MaiBot-style learning domains: expression learning and persona + # learning are independent modules; this service only orchestrates + # the batch lifecycle. + self.expression_learning = ExpressionLearningModule(db_manager) + self.persona_learning = PersonaLearningModule( + config=config, + context=context, + db_manager=db_manager, + persona_manager=persona_manager, + multidimensional_analyzer=multidimensional_analyzer, + prompts=prompts, + resolve_umo=self._resolve_umo, + json_serializer=self._json_serializer, + ) # 学习状态 - 使用字典管理每个群组的学习状态 self.learning_active = {} # 改为字典,按群组ID管理 @@ -86,6 +96,33 @@ def _resolve_umo(self, group_id: str) -> str: return self.group_id_to_unified_origin.get(group_id, group_id) return group_id + def _get_expression_learning_module(self) -> ExpressionLearningModule: + """Return the expression-learning domain module, creating it lazily for tests.""" + module = getattr(self, "expression_learning", None) + if module is None: + module = ExpressionLearningModule(self.db_manager) + self.expression_learning = module + return module + + def _get_persona_learning_module(self) -> PersonaLearningModule: + """Return the persona-learning domain module, creating it lazily for tests.""" + module = getattr(self, "persona_learning", None) + if module is None: + module = PersonaLearningModule( + config=getattr(self, "config", None), + context=getattr(self, "context", None), + db_manager=getattr(self, "db_manager", None), + persona_manager=getattr(self, "persona_manager", None), + multidimensional_analyzer=getattr( + self, "multidimensional_analyzer", None + ), + prompts=getattr(self, "prompts", None), + resolve_umo=self._resolve_umo, + json_serializer=self._json_serializer, + ) + self.persona_learning = module + return module + @staticmethod def _quality_value(value) -> Optional[float]: if value is None: @@ -766,65 +803,12 @@ async def _execute_strategy_optimization_background(self, group_id: str): logger.error(f"后台策略优化失败: {e}") async def _generate_updated_persona_with_refinement(self, group_id: str, current_persona: Dict[str, Any], style_analysis: Any) -> Dict[str, Any]: - """使用提炼模型生成更新后的人格""" - try: - # 正确处理AnalysisResult对象和字典类型 - from ...core.interfaces import AnalysisResult - - if isinstance(style_analysis, AnalysisResult): - # 如果是AnalysisResult对象,提取data属性 - analysis_data = style_analysis.data if style_analysis.data else {} - logger.debug(f"从AnalysisResult提取data: success={style_analysis.success}, confidence={style_analysis.confidence}") - elif isinstance(style_analysis, dict): - analysis_data = style_analysis - logger.debug("使用字典形式的style_analysis") - elif hasattr(style_analysis, 'data'): - # 兼容其他具有data属性的对象 - analysis_data = style_analysis.data if style_analysis.data else {} - logger.debug(f"从对象提取data属性: {type(style_analysis)}") - else: - analysis_data = {} - logger.warning(f"style_analysis类型不正确: {type(style_analysis)}, 使用空字典") - - # 使用多维度分析器的框架适配器生成人格更新 - if hasattr(self.multidimensional_analyzer, 'llm_adapter') and self.multidimensional_analyzer.llm_adapter: - llm_adapter = self.multidimensional_analyzer.llm_adapter - - if llm_adapter.has_refine_provider() and llm_adapter.providers_configured >= 2: - # 准备输入数据 - current_persona_json = json.dumps(current_persona, ensure_ascii=False, indent=2, default=self._json_serializer) - style_analysis_json = json.dumps(analysis_data, ensure_ascii=False, indent=2, default=self._json_serializer) - - # 调用框架适配器 - response = await llm_adapter.refine_chat_completion( - prompt=self.prompts.PROGRESSIVE_LEARNING_GENERATE_UPDATED_PERSONA_PROMPT.format( - current_persona_json=current_persona_json, - style_analysis_json=style_analysis_json - ), - temperature=0.6 - ) - - if response: - # 清理响应文本,移除markdown标识符(使用统一的json_utils工具) - clean_response = clean_llm_json_response(response) - - try: - updated_persona = safe_parse_llm_json(clean_response) - logger.info("使用提炼模型成功生成更新后的人格") - return updated_persona - except json.JSONDecodeError as e: - logger.error(f"提炼模型返回的JSON格式不正确: {e}, 响应: {clean_response}") - return await self._generate_updated_persona(group_id, current_persona, style_analysis) - else: - logger.warning("提炼模型Provider未配置,使用传统方法生成人格") - return await self._generate_updated_persona(group_id, current_persona, style_analysis) - else: - logger.warning("框架适配器未找到,使用传统方法生成人格") - return await self._generate_updated_persona(group_id, current_persona, style_analysis) - - except Exception as e: - logger.error(f"使用提炼模型生成人格失败: {e}") - return await self._generate_updated_persona(group_id, current_persona, style_analysis) + """使用提炼模型生成更新后的人格(兼容转发)""" + return await self._get_persona_learning_module().generate_updated_persona_with_refinement( + group_id, + current_persona, + style_analysis, + ) def _json_serializer(self, obj): """自定义JSON序列化器,处理不能直接序列化的对象""" @@ -866,327 +850,35 @@ async def _filter_messages_with_context(self, messages: List[Dict[str, Any]]) -> return messages async def _get_current_persona(self, group_id: str) -> Dict[str, Any]: - """获取当前人格设置 (针对特定群组)""" - try: - # 通过 PersonaManagerService 获取当前人格 - persona = await self.persona_manager.get_current_persona(group_id) - if persona: - return persona - - # 如果没有特定群组的人格,尝试从框架获取默认人格 - if hasattr(self.context, 'persona_manager') and self.context.persona_manager: - try: - default_persona = await resolve_target_persona( - self.context.persona_manager, - self.config, - self._resolve_umo(group_id), - require_existing=True, - log=logger, - ) - if default_persona: - return { - 'prompt': default_persona.get('prompt', '默认人格'), - 'name': get_persona_identifier(default_persona), - 'style_parameters': {}, - 'last_updated': datetime.now().isoformat() - } - except Exception as e: - logger.warning(f"从框架获取默认人格失败: {e}") - - # 如果都失败,返回默认结构 - return { - 'prompt': "默认人格", - 'name': 'default', - 'style_parameters': {}, - 'last_updated': datetime.now().isoformat() - } - except Exception as e: - logger.error(f"获取当前人格失败 for group {group_id}: {e}") - return {'prompt': '默认人格', 'name': 'default', 'style_parameters': {}} + """获取当前人格设置 (针对特定群组,兼容转发)""" + return await self._get_persona_learning_module().get_current_persona(group_id) async def _generate_updated_persona(self, group_id: str, current_persona: Dict[str, Any], style_analysis: Dict[str, Any]) -> Dict[str, Any]: - """生成更新后的人格 - 直接在原有文本后面追加增量学习内容""" - try: - # 使用新版框架API获取当前人格 - if not hasattr(self.context, 'persona_manager') or not self.context.persona_manager: - logger.warning(f"无法获取PersonaManager for group {group_id}") - return current_persona - - default_persona = await resolve_target_persona( - self.context.persona_manager, - self.config, - self._resolve_umo(group_id), - require_existing=True, - log=logger, - ) - if not default_persona: - logger.warning(f"无法获取当前人格 for group {group_id}") - return current_persona - - # 获取原有人格文本 - original_prompt = default_persona.get('prompt', '') - - # 构建增量学习内容 - learning_content = [] - - # 正确处理AnalysisResult对象和字典类型 - from ...core.interfaces import AnalysisResult - - if isinstance(style_analysis, AnalysisResult): - # 如果是AnalysisResult对象,提取data属性 - analysis_data = style_analysis.data if style_analysis.data else {} - logger.debug(f"从AnalysisResult提取data: success={style_analysis.success}, confidence={style_analysis.confidence}") - elif isinstance(style_analysis, dict): - analysis_data = style_analysis - logger.debug("使用字典形式的style_analysis") - elif hasattr(style_analysis, 'data'): - # 兼容其他具有data属性的对象 - analysis_data = style_analysis.data if style_analysis.data else {} - logger.debug(f"从对象提取data属性: {type(style_analysis)}") - else: - analysis_data = {} - logger.warning(f"style_analysis类型不正确: {type(style_analysis)}, 使用空字典") - - # 修复:从实际的 style_analysis 结构中提取内容 - # 优先提取 enhanced_prompt 和 learning_insights(如果有) - if 'enhanced_prompt' in analysis_data: - learning_content.append(analysis_data['enhanced_prompt']) - logger.debug("找到 enhanced_prompt 字段") - - if 'learning_insights' in analysis_data: - insights = analysis_data['learning_insights'] - if insights: - learning_content.append(insights) - logger.debug("找到 learning_insights 字段") - - # 新增:从 style_analysis 字段提取内容(StyleAnalyzer返回的结构) - if not learning_content and 'style_analysis' in analysis_data: - style_report = analysis_data['style_analysis'] - if isinstance(style_report, dict): - # 提取关键的风格分析内容 - extracted_parts = [] - - # 提取文本风格描述 - if 'text_style' in style_report: - extracted_parts.append(f"文本风格: {style_report['text_style']}") - - # 提取表达特点 - if 'expression_features' in style_report: - features = style_report['expression_features'] - if isinstance(features, list): - extracted_parts.append(f"表达特点: {', '.join(features)}") - elif isinstance(features, str): - extracted_parts.append(f"表达特点: {features}") - - # 提取语气倾向 - if 'tone' in style_report: - extracted_parts.append(f"语气倾向: {style_report['tone']}") - - # 提取话题偏好 - if 'topics' in style_report: - topics = style_report['topics'] - if isinstance(topics, list): - extracted_parts.append(f"话题偏好: {', '.join(topics)}") - elif isinstance(topics, str): - extracted_parts.append(f"话题偏好: {topics}") - - if extracted_parts: - learning_content.append("【对话风格学习结果】\n" + "\n".join(extracted_parts)) - logger.debug(f"从 style_analysis 提取了 {len(extracted_parts)} 个风格特征") - - # 新增:如果还是没有内容,从 style_profile 提取 - if not learning_content and 'style_profile' in analysis_data: - style_profile = analysis_data['style_profile'] - if isinstance(style_profile, dict): - profile_parts = [] - - # 提取语气强度 - if 'tone_intensity' in style_profile: - profile_parts.append(f"语气强度: {style_profile['tone_intensity']:.2f}") - - # 提取情感倾向 - if 'sentiment' in style_profile: - profile_parts.append(f"情感倾向: {style_profile['sentiment']:.2f}") - - # 提取词汇丰富度 - if 'vocabulary_richness' in style_profile: - profile_parts.append(f"词汇丰富度: {style_profile['vocabulary_richness']:.2f}") - - if profile_parts: - learning_content.append("【风格量化指标】\n" + "\n".join(profile_parts)) - logger.debug(f"从 style_profile 提取了 {len(profile_parts)} 个量化指标") - - # 新增:如果还是没有内容,尝试提取任何有用的信息 - if not learning_content: - # 尝试从顶层提取任何看起来有用的字段 - useful_fields = ['summary', 'description', 'analysis', 'insights', 'findings'] - for field in useful_fields: - if field in analysis_data and analysis_data[field]: - learning_content.append(f"【{field}】\n{analysis_data[field]}") - logger.debug(f"从顶层字段 {field} 提取了内容") - break - - # 直接在原有文本后面追加新内容 - if learning_content: - timestamp = datetime.now().strftime('%Y-%m-%d %H:%M') - new_content = f"\n\n【学习更新 - {timestamp}】\n" + "\n".join(learning_content) - - # 创建更新后的人格 (Personality是TypedDict) - updated_persona = dict(default_persona) - updated_persona['prompt'] = original_prompt + new_content - updated_persona['last_updated'] = timestamp - - logger.info(f" 成功追加 {len(learning_content)} 项学习内容到人格 for group {group_id}") - return updated_persona - else: - logger.warning(f" style_analysis中没有可提取的学习内容 for group {group_id}, 数据结构: {list(analysis_data.keys())}") - # 即使没有学习内容,也返回一个副本以确保有updated_persona用于对比 - return dict(default_persona) - - except Exception as e: - logger.error(f"生成更新人格失败 for group {group_id}: {e}", exc_info=True) - return current_persona + """生成更新后的人格 - 直接在原有文本后面追加增量学习内容(兼容转发)""" + return await self._get_persona_learning_module().generate_updated_persona( + group_id, + current_persona, + style_analysis, + ) async def _apply_learning_updates(self, group_id: str, style_analysis: Dict[str, Any], messages: List[Dict[str, Any]], current_persona: Dict[str, Any] = None, updated_persona: Dict[str, Any] = None, quality_metrics = None, relearn_mode: bool = False, ml_tuning_info: Dict[str, Any] = None): - """应用学习更新,并创建人格学习审查记录和风格学习记录 - - Args: - group_id: 群组ID - style_analysis: 风格分析结果 - messages: 处理的消息列表 - current_persona: 当前人格 - updated_persona: 更新后的人格 - quality_metrics: 质量指标 - relearn_mode: 重新学习模式,为True时即使内容相同也创建审查记录 - ml_tuning_info: 强化学习调优信息(包含是否使用保守融合策略等) - """ + """应用学习更新,并创建人格学习审查记录和风格学习记录。""" try: - # 处理可能的list类型参数 - if isinstance(current_persona, list): - logger.warning(f"current_persona为list类型(长度{len(current_persona)}),转换为空字典") - current_persona = {} - - if isinstance(updated_persona, list): - logger.warning(f"updated_persona为list类型(长度{len(updated_persona)}),转换为空字典") - updated_persona = {} - - # 1. 保存对话风格学习记录(不需要审查,直接保存) await self._save_style_learning_record(group_id, style_analysis, messages, quality_metrics) - # 2. 更新人格prompt(通过 PersonaManagerService) - logger.info(f"应用人格更新 for group {group_id}") - - # 正确处理 AnalysisResult 对象 - if hasattr(style_analysis, 'success'): - # 这是一个 AnalysisResult 对象 - if not style_analysis.success: - logger.error(f"风格分析失败,跳过人格更新: {style_analysis.error}") - return - - # 使用 AnalysisResult 的 data 属性 - style_analysis_dict = style_analysis.data - confidence = style_analysis.confidence - logger.debug(f"使用 AnalysisResult 对象,置信度: {confidence:.3f}") - elif isinstance(style_analysis, dict): - # 向后兼容:如果传入的是字典 - style_analysis_dict = style_analysis - confidence = style_analysis.get('confidence', 0.5) - logger.debug("使用字典形式的 style_analysis(向后兼容)") - else: - logger.error(f"style_analysis 类型不正确: {type(style_analysis)}") - return - - update_success = await self.persona_manager.update_persona(group_id, style_analysis_dict, messages) - if not update_success: - logger.error(f"通过 PersonaManagerService 更新人格失败 for group {group_id}") - - # 2. 创建人格学习审查记录(新增) - # 重新学习模式:即使内容相同也创建审查记录(作为重新确认) - # 正常模式:只在内容不同时创建审查记录 - should_create_review = False - if relearn_mode: - # 重新学习模式:总是创建审查记录 - should_create_review = bool(updated_persona and current_persona) - if should_create_review: - # 检查是否有实质性变化 - has_changes = updated_persona.get('prompt', '') != current_persona.get('prompt', '') - if has_changes: - logger.info(f" 重新学习模式:检测到人格变化,创建审查记录(group: {group_id})") - else: - logger.info(f" 重新学习模式:未检测到人格变化,但仍创建审查记录供审核(group: {group_id})") - else: - logger.warning(f" 重新学习模式:无法创建审查记录 - updated_persona={bool(updated_persona)}, current_persona={bool(current_persona)}") - elif updated_persona and current_persona and updated_persona.get('prompt') != current_persona.get('prompt'): - # 正常模式:只在内容不同时创建 - should_create_review = True - logger.info(f" 正常模式:检测到人格变化,创建审查记录(group: {group_id})") - else: - logger.debug(f" 正常模式:人格未变化,跳过审查记录 - updated={bool(updated_persona)}, current={bool(current_persona)}, same_prompt={updated_persona.get('prompt') == current_persona.get('prompt') if updated_persona and current_persona else 'N/A'}") - - if should_create_review: - try: - # 提取原人格和新人格的完整文本 - original_prompt = current_persona.get('prompt', '') - new_prompt = updated_persona.get('prompt', '') - - # 计算新增内容(用于单独标记) - if len(new_prompt) > len(original_prompt): - incremental_content = new_prompt[len(original_prompt):].strip() - else: - incremental_content = new_prompt - - # 准备元数据(包含高亮信息) - metadata = { - "progressive_learning": True, - "message_count": len(messages), - "style_analysis_fields": list(style_analysis.data.keys()) if (hasattr(style_analysis, "data") and isinstance(style_analysis.data, dict)) else (list(style_analysis.keys()) if isinstance(style_analysis, dict) else []), - "original_prompt_length": len(original_prompt), - "new_prompt_length": len(new_prompt), - "incremental_content": incremental_content, # 单独记录增量内容,用于高亮 - "incremental_start_pos": len(original_prompt), # 标记新增内容的起始位置 - "relearn_mode": relearn_mode # 标记是否���重新学习模式 - } - - # 添加强化学习调优信息到元数据 - if ml_tuning_info: - metadata['ml_tuning'] = ml_tuning_info - - # 获取质量得分 - confidence_score = quality_metrics.consistency_score if quality_metrics and hasattr(quality_metrics, 'consistency_score') else 0.5 - - # 构建 raw_analysis 说明(包含强化学习信息) - raw_analysis_parts = [f"基于{len(messages)}条消息的风格分析"] - if relearn_mode: - raw_analysis_parts.append("(重新学习)") - if ml_tuning_info and ml_tuning_info.get('applied'): - if ml_tuning_info.get('used_conservative_fusion'): - raw_analysis_parts.append(f"强化学习生成的prompt过短({ml_tuning_info['tuned_length']} vs {ml_tuning_info['original_length']}),采用保守融合策略") - else: - raw_analysis_parts.append(f"已应用强化学习优化,预期改进: {ml_tuning_info['expected_improvement']:.2%}") - raw_analysis = ";".join(raw_analysis_parts) - - # 创建审查记录 - proposed_content 仅包含新增内容,审批时拼接 original_content - review_id = await self.db_manager.add_persona_learning_review( - group_id=group_id, - proposed_content=incremental_content, # 仅新增内容,不重复原始人格 - learning_source=UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, - confidence_score=confidence_score, - raw_analysis=raw_analysis, - metadata=metadata, - original_content=original_prompt, # 原人格完整文本 - new_content=new_prompt # 完整新人格文本(original + incremental),用于审批应用 - ) - - logger.info(f" 已创建人格学习审查记录 (ID: {review_id}),置信度: {confidence_score:.3f}") - - except Exception as review_error: - logger.error(f"创建人格学习审查记录失败: {review_error}", exc_info=True) - else: - logger.debug(f"人格未变化或缺少必要参数,跳过审查记录创建") + await self._get_persona_learning_module().apply_persona_learning( + group_id, + style_analysis, + messages, + current_persona=current_persona, + updated_persona=updated_persona, + quality_metrics=quality_metrics, + relearn_mode=relearn_mode, + ml_tuning_info=ml_tuning_info, + ) - # 3. 记录学习更新 if group_id in self._group_sessions: self._group_sessions[group_id].style_updates += 1 @@ -1270,276 +962,47 @@ async def stop(self): async def _save_style_learning_record(self, group_id: str, style_analysis: Dict[str, Any], messages: List[Dict[str, Any]], quality_metrics=None): - """ - 保存对话风格学习记录(直接保存,不需要审查) - - Args: - group_id: 群组ID - style_analysis: 风格分析结果(可以为空,会基于消息创建简单记录) - messages: 处理的消息列表 - quality_metrics: 质量指标 - """ - try: - messages = filter_learning_messages(messages or []) - - # 处理 AnalysisResult 对象,提取其 data 属性 - if style_analysis and hasattr(style_analysis, 'data'): - style_analysis_dict = style_analysis.data - elif isinstance(style_analysis, dict): - style_analysis_dict = style_analysis - else: - style_analysis_dict = {} - - # 即使没有 style_analysis,也应该基于消息创建学习记录 - if not style_analysis_dict and not messages: - logger.debug(f"群组 {group_id} 没有风格分析结果且没有消息,跳过风格学习记录保存") - return - - # 1. 保存表达模式到 expression_patterns 表 - expression_patterns = style_analysis_dict.get('expression_patterns', []) - expression_patterns = self._filter_expression_patterns(expression_patterns) - - # 在 fewshot 模式下,style_analysis 可能不包含 expression_patterns。 - # 此时从数据库获取 bot 消息与用户消息合并,提取 user->bot 对话对。 - if not expression_patterns and messages: - try: - merged = await self._merge_bot_messages_for_pairs(group_id, messages) - if merged: - expression_patterns = self._extract_fewshot_pairs_from_merged(merged, group_id) - except Exception as pair_err: - logger.debug(f"提取 fewshot 对话对失败: {pair_err}") - - if expression_patterns: - await self._save_expression_patterns(group_id, expression_patterns) - - # 2. 构建 few_shots 内容(仅使用原始 A/B 对话对,不做 LLM 总结) - few_shots_content = '' - if expression_patterns: - few_shots_content = self._build_few_shots_from_patterns(expression_patterns) - - # 如果没有 few_shots_content,从消息中构建简单的学习内容 - if not few_shots_content and messages: - few_shots_content = f"基于 {len(messages)} 条对话消息的风格学习" - - # 3. 构建学习模式列表 - learned_patterns = [] - for pattern in expression_patterns[:10]: # 取前10个模式 - learned_patterns.append({ - 'situation': pattern.get('situation', ''), - 'expression': pattern.get('expression', ''), - 'weight': pattern.get('weight', 1.0), - 'confidence': pattern.get('confidence', 0.8) - }) - - # 4. 获取质量得分 - confidence_score = quality_metrics.consistency_score if quality_metrics and hasattr(quality_metrics, 'consistency_score') else 0.75 - - # 5. 构建描述 - pattern_count = len(learned_patterns) if learned_patterns else 0 - message_count = len(messages) if messages else 0 - description = f"群组 {group_id} 的对话风格学习结果(处理 {message_count} 条消息,提取 {pattern_count} 个表达模式)" - - # 6. 保存风格学习记录(使用 ORM) - # 提取到有效对话对时设为 pending 等待审查,否则自动批准 - try: - async with self.db_manager.get_session() as session: - from ...models.orm.learning import StyleLearningReview - from datetime import datetime - - current_timestamp = time.time() - has_patterns = bool(learned_patterns) - - review = StyleLearningReview( - type=UPDATE_TYPE_STYLE_LEARNING, - group_id=group_id, - timestamp=current_timestamp, - learned_patterns=json.dumps(learned_patterns, ensure_ascii=False), - few_shots_content=few_shots_content, - status='pending' if has_patterns else 'approved', - description=description, - reviewer_comment=None if has_patterns else '自动批准(无有效对话对)', - review_time=None if has_patterns else current_timestamp, - created_at=datetime.fromtimestamp(current_timestamp), - updated_at=datetime.fromtimestamp(current_timestamp), - ) - - session.add(review) - await session.commit() - await session.refresh(review) - - logger.info(f" 对话风格学习记录已保存 (ID: {review.id}),处理 {message_count} 条消息,提取 {pattern_count} 个模式") - - except Exception as e: - logger.error(f"保存对话风格学习记录失败: {e}", exc_info=True) - - except Exception as e: - logger.error(f"保存风格学习记录失败: {e}", exc_info=True) + """保存对话风格学习记录(兼容转发)。""" + await self._get_expression_learning_module().save_style_learning_record( + group_id, + style_analysis, + messages, + quality_metrics, + ) @staticmethod def _filter_expression_patterns(patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Remove command/system-derived pairs before saving style samples.""" - filtered = [] - for pattern in patterns or []: - if not isinstance(pattern, dict): - continue - situation = pattern.get('situation', '') - expression = pattern.get('expression', '') - if should_ignore_learning_sample(situation): - continue - if should_ignore_learning_sample(expression, sender_id='bot', is_bot=True): - continue - filtered.append(pattern) - return filtered + return ExpressionLearningModule.filter_expression_patterns(patterns) def _build_few_shots_from_patterns(self, patterns: List[Dict[str, Any]]) -> str: - """从表达模式构建 few-shots 内容""" - few_shots = "*Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n" - - for i, pattern in enumerate(self._filter_expression_patterns(patterns)[:5], 1): # 只取前5个 - situation = pattern.get('situation', '') - expression = pattern.get('expression', '') - if situation and expression: - few_shots += f"A: {situation}\nB: {expression}\n\n" - - return few_shots.strip() + """从表达模式构建 few-shots 内容。""" + return self._get_expression_learning_module().build_few_shots_from_patterns( + patterns + ) async def _merge_bot_messages_for_pairs( self, group_id: str, user_messages: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: - """Merge user messages with bot messages from DB to form a timeline. - - Fetches recent bot messages for the group and interleaves them with - user messages sorted by timestamp, producing a unified stream that - allows user->bot pair extraction. - """ - user_messages = filter_learning_messages(user_messages) - if not user_messages: - return [] - - bot_texts = await self.db_manager.get_recent_bot_responses(group_id, limit=50) - if not bot_texts: - return [] - - # bot_texts is List[str]; build dicts with bot sender_id - bot_msgs = [] - # Retrieve full BotMessage records to get timestamps - async with self.db_manager.get_session() as session: - from ...models.orm.message import BotMessage - from sqlalchemy import select, desc - stmt = ( - select(BotMessage) - .where(BotMessage.group_id == group_id) - .order_by(desc(BotMessage.timestamp)) - .limit(50) - ) - result = await session.execute(stmt) - for row in result.scalars().all(): - if should_ignore_learning_sample( - row.message, - sender_id='bot', - is_bot=True, - ): - continue - bot_msgs.append({ - 'sender_id': 'bot', - 'message': row.message, - 'timestamp': row.timestamp, - }) - - if not bot_msgs: - return [] - - merged = list(user_messages) + bot_msgs - merged.sort(key=lambda m: m.get('timestamp', 0)) - return merged + """Merge user messages with bot messages from DB to form a timeline.""" + return await self._get_expression_learning_module().merge_bot_messages_for_pairs( + group_id, + user_messages, + ) @staticmethod def _extract_fewshot_pairs_from_merged( merged: List[Dict[str, Any]], group_id: str ) -> List[Dict[str, Any]]: - """Extract user->bot conversation pairs from a merged message timeline. - - Mirrors the logic of ExpressionPatternLearner._extract_few_shot_pairs - but operates on plain dicts and returns expression pattern dicts. - """ - pairs = [] - current_time = time.time() - - for i in range(len(merged) - 1): - msg = merged[i] - nxt = merged[i + 1] - - msg_is_bot = msg.get('sender_id') == 'bot' - nxt_is_bot = nxt.get('sender_id') == 'bot' - msg_text = msg.get('message', '').strip() - nxt_text = nxt.get('message', '').strip() - - if not msg_is_bot and nxt_is_bot and msg_text and nxt_text: - if should_ignore_learning_sample(msg_text): - continue - if should_ignore_learning_sample(nxt_text, sender_id='bot', is_bot=True): - continue - if len(msg_text) < 3 or len(nxt_text) < 3: - continue - if msg_text.startswith(('[', 'http', '@')): - continue - if nxt_text.startswith(('[', 'http', '@')): - continue - if '@' in msg_text or '@' in nxt_text: - continue - - pairs.append({ - 'situation': msg_text[:50], - 'expression': nxt_text[:100], - 'weight': 1.0, - 'confidence': 0.8, - 'group_id': group_id, - 'last_active_time': current_time, - 'create_time': current_time, - }) - - return pairs + """Extract user->bot conversation pairs from a merged message timeline.""" + return ExpressionLearningModule.extract_fewshot_pairs_from_merged( + merged, + group_id, + ) async def _save_expression_patterns(self, group_id: str, patterns: List[Dict[str, Any]]): - """ - 保存表达模式到 expression_patterns 表 - - Args: - group_id: 群组ID - patterns: 表达模式列表 - """ - try: - if not patterns: - return - - # 使用 ORM 批量保存表达模式 - async with self.db_manager.get_session() as session: - from ...models.orm.expression import ExpressionPattern - import time - - current_time = time.time() - objects = [] - - for pattern in patterns: - situation = pattern.get('situation', '').strip() - expression = pattern.get('expression', '').strip() - - if not situation or not expression: - continue - - objects.append(ExpressionPattern( - group_id=group_id, - situation=situation, - expression=expression, - weight=float(pattern.get('weight', 1.0)), - last_active_time=current_time, - create_time=current_time - )) - - if objects: - session.add_all(objects) - await session.commit() - logger.info(f"已保存 {len(objects)} 个表达模式到数据库 (群组: {group_id})") - - except Exception as e: - logger.error(f"保存表达模式失败: {e}", exc_info=True) + """保存表达模式到 expression_patterns 表。""" + await self._get_expression_learning_module().save_expression_patterns( + group_id, + patterns, + ) diff --git a/services/learning/expression_learning.py b/services/learning/expression_learning.py new file mode 100644 index 00000000..1bb8ba1c --- /dev/null +++ b/services/learning/expression_learning.py @@ -0,0 +1,305 @@ +"""Expression learning module. + +Owns expression/few-shot learning persistence so progressive learning can +orchestrate batches without also carrying expression-specific storage rules. +""" + +import json +import time +from datetime import datetime +from typing import Any, Dict, List + +from astrbot.api import logger + +from ...constants import UPDATE_TYPE_STYLE_LEARNING +from .sample_filter import filter_learning_messages, should_ignore_learning_sample + + +class ExpressionLearningModule: + """Persist learned expression patterns and style review records.""" + + def __init__(self, db_manager: Any) -> None: + self.db_manager = db_manager + + async def save_style_learning_record( + self, + group_id: str, + style_analysis: Any, + messages: List[Dict[str, Any]], + quality_metrics: Any = None, + ) -> None: + """Save expression learning output and create a style review record.""" + try: + messages = filter_learning_messages(messages or []) + + if style_analysis and hasattr(style_analysis, "data"): + style_analysis_dict = style_analysis.data + elif isinstance(style_analysis, dict): + style_analysis_dict = style_analysis + else: + style_analysis_dict = {} + + if not style_analysis_dict and not messages: + logger.debug( + f"群组 {group_id} 没有风格分析结果且没有消息,跳过风格学习记录保存" + ) + return + + expression_patterns = style_analysis_dict.get("expression_patterns", []) + expression_patterns = self.filter_expression_patterns(expression_patterns) + + if not expression_patterns and messages: + try: + merged = await self.merge_bot_messages_for_pairs(group_id, messages) + if merged: + expression_patterns = self.extract_fewshot_pairs_from_merged( + merged, group_id + ) + except Exception as pair_err: + logger.debug(f"提取 fewshot 对话对失败: {pair_err}") + + if expression_patterns: + await self.save_expression_patterns(group_id, expression_patterns) + + few_shots_content = "" + if expression_patterns: + few_shots_content = self.build_few_shots_from_patterns( + expression_patterns + ) + + if not few_shots_content and messages: + few_shots_content = f"基于 {len(messages)} 条对话消息的风格学习" + + learned_patterns = [] + for pattern in expression_patterns[:10]: + learned_patterns.append( + { + "situation": pattern.get("situation", ""), + "expression": pattern.get("expression", ""), + "weight": pattern.get("weight", 1.0), + "confidence": pattern.get("confidence", 0.8), + } + ) + + confidence_score = ( + quality_metrics.consistency_score + if quality_metrics and hasattr(quality_metrics, "consistency_score") + else 0.75 + ) + del confidence_score # Stored review currently derives status from patterns. + + pattern_count = len(learned_patterns) if learned_patterns else 0 + message_count = len(messages) if messages else 0 + description = ( + f"群组 {group_id} 的对话风格学习结果" + f"(处理 {message_count} 条消息,提取 {pattern_count} 个表达模式)" + ) + + try: + async with self.db_manager.get_session() as session: + from ...models.orm.learning import StyleLearningReview + + current_timestamp = time.time() + has_patterns = bool(learned_patterns) + + review = StyleLearningReview( + type=UPDATE_TYPE_STYLE_LEARNING, + group_id=group_id, + timestamp=current_timestamp, + learned_patterns=json.dumps( + learned_patterns, ensure_ascii=False + ), + few_shots_content=few_shots_content, + status="pending" if has_patterns else "approved", + description=description, + reviewer_comment=None if has_patterns else "自动批准(无有效对话对)", + review_time=None if has_patterns else current_timestamp, + created_at=datetime.fromtimestamp(current_timestamp), + updated_at=datetime.fromtimestamp(current_timestamp), + ) + + session.add(review) + await session.commit() + await session.refresh(review) + + logger.info( + f" 对话风格学习记录已保存 (ID: {review.id})," + f"处理 {message_count} 条消息,提取 {pattern_count} 个模式" + ) + + except Exception as exc: + logger.error(f"保存对话风格学习记录失败: {exc}", exc_info=True) + + except Exception as exc: + logger.error(f"保存风格学习记录失败: {exc}", exc_info=True) + + @staticmethod + def filter_expression_patterns( + patterns: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """Remove command/system-derived pairs before saving style samples.""" + filtered = [] + for pattern in patterns or []: + if not isinstance(pattern, dict): + continue + situation = pattern.get("situation", "") + expression = pattern.get("expression", "") + if should_ignore_learning_sample(situation): + continue + if should_ignore_learning_sample(expression, sender_id="bot", is_bot=True): + continue + filtered.append(pattern) + return filtered + + def build_few_shots_from_patterns(self, patterns: List[Dict[str, Any]]) -> str: + """Build few-shot dialog text from expression patterns.""" + few_shots = ( + "*Here are few shots of dialogs, you need to imitate the tone of 'B' " + "in the following dialogs to respond:\n" + ) + + for pattern in self.filter_expression_patterns(patterns)[:5]: + situation = pattern.get("situation", "") + expression = pattern.get("expression", "") + if situation and expression: + few_shots += f"A: {situation}\nB: {expression}\n\n" + + return few_shots.strip() + + async def merge_bot_messages_for_pairs( + self, group_id: str, user_messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Merge user messages with bot messages from DB to form a timeline.""" + user_messages = filter_learning_messages(user_messages) + if not user_messages: + return [] + + bot_texts = await self.db_manager.get_recent_bot_responses( + group_id, limit=50 + ) + if not bot_texts: + return [] + + bot_msgs = [] + async with self.db_manager.get_session() as session: + from sqlalchemy import desc, select + + from ...models.orm.message import BotMessage + + stmt = ( + select(BotMessage) + .where(BotMessage.group_id == group_id) + .order_by(desc(BotMessage.timestamp)) + .limit(50) + ) + result = await session.execute(stmt) + for row in result.scalars().all(): + if should_ignore_learning_sample( + row.message, + sender_id="bot", + is_bot=True, + ): + continue + bot_msgs.append( + { + "sender_id": "bot", + "message": row.message, + "timestamp": row.timestamp, + } + ) + + if not bot_msgs: + return [] + + merged = list(user_messages) + bot_msgs + merged.sort(key=lambda message: message.get("timestamp", 0)) + return merged + + @staticmethod + def extract_fewshot_pairs_from_merged( + merged: List[Dict[str, Any]], group_id: str + ) -> List[Dict[str, Any]]: + """Extract user->bot conversation pairs from a merged timeline.""" + pairs = [] + current_time = time.time() + + for idx in range(len(merged) - 1): + msg = merged[idx] + nxt = merged[idx + 1] + + msg_is_bot = msg.get("sender_id") == "bot" + nxt_is_bot = nxt.get("sender_id") == "bot" + msg_text = msg.get("message", "").strip() + nxt_text = nxt.get("message", "").strip() + + if not msg_is_bot and nxt_is_bot and msg_text and nxt_text: + if should_ignore_learning_sample(msg_text): + continue + if should_ignore_learning_sample( + nxt_text, sender_id="bot", is_bot=True + ): + continue + if len(msg_text) < 3 or len(nxt_text) < 3: + continue + if msg_text.startswith(("[", "http", "@")): + continue + if nxt_text.startswith(("[", "http", "@")): + continue + if "@" in msg_text or "@" in nxt_text: + continue + + pairs.append( + { + "situation": msg_text[:50], + "expression": nxt_text[:100], + "weight": 1.0, + "confidence": 0.8, + "group_id": group_id, + "last_active_time": current_time, + "create_time": current_time, + } + ) + + return pairs + + async def save_expression_patterns( + self, group_id: str, patterns: List[Dict[str, Any]] + ) -> None: + """Save expression patterns to the expression_patterns table.""" + try: + if not patterns: + return + + async with self.db_manager.get_session() as session: + from ...models.orm.expression import ExpressionPattern + + current_time = time.time() + objects = [] + + for pattern in patterns: + situation = pattern.get("situation", "").strip() + expression = pattern.get("expression", "").strip() + + if not situation or not expression: + continue + + objects.append( + ExpressionPattern( + group_id=group_id, + situation=situation, + expression=expression, + weight=float(pattern.get("weight", 1.0)), + last_active_time=current_time, + create_time=current_time, + ) + ) + + if objects: + session.add_all(objects) + await session.commit() + logger.info( + f"已保存 {len(objects)} 个表达模式到数据库 (群组: {group_id})" + ) + + except Exception as exc: + logger.error(f"保存表达模式失败: {exc}", exc_info=True) diff --git a/services/learning/jargon_learning.py b/services/learning/jargon_learning.py new file mode 100644 index 00000000..47fe2699 --- /dev/null +++ b/services/learning/jargon_learning.py @@ -0,0 +1,160 @@ +"""Jargon learning module. + +Keeps jargon trigger accounting and mining work outside the main message +pipeline. The pipeline still owns task tracking and event sequencing. +""" + +from typing import Any, Dict, Optional, Set + +from astrbot.api import logger + +from .sample_filter import filter_learning_messages + + +class JargonLearningModule: + """Coordinate jargon statistical updates, triggers, and mining.""" + + def __init__( + self, + *, + config: Any, + message_collector: Any, + jargon_miner_manager: Optional[Any], + jargon_statistical_filter: Optional[Any], + db_manager: Any, + ) -> None: + self._config = config + self._message_collector = message_collector + self._jargon_miner_manager = jargon_miner_manager + self._jargon_statistical_filter = jargon_statistical_filter + self._db_manager = db_manager + + self.active_groups: Set[str] = set() + self.last_trigger_counts: Dict[str, int] = {} + self.group_raw_message_counts: Dict[str, int] = {} + self.groups_seeded: set[str] = set() + + def update_statistical_filter( + self, + message_text: str, + group_id: str, + sender_id: str, + ) -> None: + """Update the cheap statistical pre-filter for one message.""" + if not self._config.enable_jargon_learning: + return + if not self._jargon_statistical_filter: + return + try: + self._jargon_statistical_filter.update_from_message( + message_text, group_id, sender_id + ) + except Exception: + pass # best-effort + + def note_collected_message(self, group_id: str) -> None: + """Track in-memory raw message count after collection succeeds.""" + self.group_raw_message_counts[group_id] = ( + self.group_raw_message_counts.get(group_id, 0) + 1 + ) + + async def get_raw_message_count(self, group_id: str) -> int: + """Get raw message count for a group, seeded from DB once.""" + if group_id not in self.groups_seeded: + try: + stats = await self._message_collector.get_statistics(group_id) + db_count = stats.get("raw_messages", 0) + memory_count = self.group_raw_message_counts.get(group_id, 0) + self.group_raw_message_counts[group_id] = max(db_count, memory_count) + except Exception: + pass + self.groups_seeded.add(group_id) + return self.group_raw_message_counts.get(group_id, 0) + + def should_schedule_mining(self, group_id: str, raw_message_count: int) -> bool: + """Trigger jargon mining once per additional 10 messages per group.""" + if raw_message_count < 10: + return False + if group_id in self.active_groups: + return False + last_trigger = self.last_trigger_counts.get(group_id, 0) + return raw_message_count - last_trigger >= 10 + + def mark_mining_started(self, group_id: str, raw_message_count: int) -> None: + """Record trigger state before spawning a mining task.""" + self.last_trigger_counts[group_id] = raw_message_count + self.active_groups.add(group_id) + + def mark_mining_finished(self, group_id: str) -> None: + """Clear active mining state for a group.""" + self.active_groups.discard(group_id) + + async def mine_jargon(self, group_id: str) -> None: + """Run one jargon mining iteration for a group.""" + try: + if not self._config.enable_jargon_learning: + logger.debug("[JargonMining] Jargon learning disabled, skip") + return + + if not self._jargon_miner_manager: + logger.debug("[JargonMining] JargonMinerManager not initialised, skip") + return + + jargon_miner = self._jargon_miner_manager.get_or_create_miner(group_id) + + stats = await self._message_collector.get_statistics(group_id) + recent_message_count = stats.get("raw_messages", 0) + + if not jargon_miner.should_trigger(recent_message_count): + logger.debug( + f"[JargonMining] Group {group_id} trigger conditions not met" + ) + return + + recent_messages = await self._db_manager.get_recent_raw_messages( + group_id, limit=30 + ) + recent_messages = filter_learning_messages(recent_messages) + + if len(recent_messages) < 10: + logger.debug( + f"[JargonMining] Group {group_id} insufficient messages " + f"({len(recent_messages)}<10)" + ) + return + + logger.debug( + f"[JargonMining] Analysing {len(recent_messages)} messages " + f"from group {group_id}" + ) + + chat_messages = "\n".join( + [ + f"{msg.get('sender_id', 'unknown')}: {msg.get('message', '')}" + for msg in recent_messages + ] + ) + + statistical_candidates = None + if self._jargon_statistical_filter: + statistical_candidates = ( + self._jargon_statistical_filter.get_jargon_candidates( + group_id, top_k=20 + ) + ) + if not statistical_candidates: + statistical_candidates = None + + await jargon_miner.run_once( + chat_messages, + len(recent_messages), + statistical_candidates=statistical_candidates, + ) + + logger.debug(f"[JargonMining] Group {group_id} learning complete") + + except Exception as exc: + logger.error( + f"[JargonMining] Background task failed (group={group_id}): {exc}", + exc_info=True, + ) diff --git a/services/learning/message_pipeline.py b/services/learning/message_pipeline.py index 6e8be3a2..da695f03 100644 --- a/services/learning/message_pipeline.py +++ b/services/learning/message_pipeline.py @@ -7,6 +7,7 @@ from ...core.interfaces import MessageData from ...statics.messages import LogMessages +from .jargon_learning import JargonLearningModule from .sample_filter import ( extract_learning_event_metadata, filter_learning_messages, @@ -43,10 +44,20 @@ def __init__( self._affection_manager = affection_manager self._db_manager = db_manager self._subtasks: Set[asyncio.Task] = set() - self._active_jargon_groups: Set[str] = set() - self._last_jargon_trigger_counts: dict[str, int] = {} - self._group_raw_message_counts: dict[str, int] = {} - self._groups_seeded: set[str] = set() + self._jargon_learning = JargonLearningModule( + config=plugin_config, + message_collector=message_collector, + jargon_miner_manager=jargon_miner_manager, + jargon_statistical_filter=jargon_statistical_filter, + db_manager=db_manager, + ) + # Compatibility attributes for existing tests and integrations. + self._active_jargon_groups = self._jargon_learning.active_groups + self._last_jargon_trigger_counts = self._jargon_learning.last_trigger_counts + self._group_raw_message_counts = ( + self._jargon_learning.group_raw_message_counts + ) + self._groups_seeded = self._jargon_learning.groups_seeded # 后台学习流水线(6 步) @@ -102,9 +113,7 @@ async def process_learning( # Track raw message count in memory for jargon trigger if message_collected: - self._group_raw_message_counts[group_id] = ( - self._group_raw_message_counts.get(group_id, 0) + 1 - ) + self._jargon_learning.note_collected_message(group_id) # 2. 增强交互(多轮对话管理) try: @@ -115,13 +124,9 @@ async def process_learning( logger.error(LogMessages.ENHANCED_INTERACTION_FAILED.format(error=e)) # 2.5 黑话统计预筛(<1ms, 零 LLM 成本) - if self._config.enable_jargon_learning and self._jargon_statistical_filter: - try: - self._jargon_statistical_filter.update_from_message( - message_text, group_id, sender_id - ) - except Exception: - pass # best-effort + self._jargon_learning.update_statistical_filter( + message_text, group_id, sender_id + ) # 3. 黑话挖掘 — 每收集 10 条消息触发一次 if self._config.enable_jargon_learning: @@ -203,66 +208,7 @@ async def mine_jargon(self, group_id: str) -> None: 4. 保存/更新到数据库并在阈值处触发推理 """ try: - if not self._config.enable_jargon_learning: - logger.debug("[JargonMining] Jargon learning disabled, skip") - return - - if not self._jargon_miner_manager: - logger.debug("[JargonMining] JargonMinerManager not initialised, skip") - return - - jargon_miner = self._jargon_miner_manager.get_or_create_miner(group_id) - - stats = await self._message_collector.get_statistics(group_id) - recent_message_count = stats.get("raw_messages", 0) - - if not jargon_miner.should_trigger(recent_message_count): - logger.debug( - f"[JargonMining] Group {group_id} trigger conditions not met" - ) - return - - recent_messages = await self._db_manager.get_recent_raw_messages( - group_id, limit=30 - ) - recent_messages = filter_learning_messages(recent_messages) - - if len(recent_messages) < 10: - logger.debug( - f"[JargonMining] Group {group_id} insufficient messages " - f"({len(recent_messages)}<10)" - ) - return - - logger.debug( - f"[JargonMining] Analysing {len(recent_messages)} messages " - f"from group {group_id}" - ) - - chat_messages = "\n".join( - [ - f"{msg.get('sender_id', 'unknown')}: {msg.get('message', '')}" - for msg in recent_messages - ] - ) - - statistical_candidates = None - if self._jargon_statistical_filter: - statistical_candidates = ( - self._jargon_statistical_filter.get_jargon_candidates( - group_id, top_k=20 - ) - ) - if not statistical_candidates: - statistical_candidates = None - - await jargon_miner.run_once( - chat_messages, - len(recent_messages), - statistical_candidates=statistical_candidates, - ) - - logger.debug(f"[JargonMining] Group {group_id} learning complete") + await self._jargon_learning.mine_jargon(group_id) except Exception as e: logger.error( @@ -295,37 +241,24 @@ async def process_affection( async def _get_raw_message_count(self, group_id: str) -> int: """Get raw message count for a group, seeded from DB once.""" - if group_id not in self._groups_seeded: - try: - stats = await self._message_collector.get_statistics(group_id) - db_count = stats.get("raw_messages", 0) - # Merge DB count with any messages collected in memory since startup - memory_count = self._group_raw_message_counts.get(group_id, 0) - self._group_raw_message_counts[group_id] = max(db_count, memory_count) - except Exception: - pass # fall back to memory-only count - self._groups_seeded.add(group_id) - return self._group_raw_message_counts.get(group_id, 0) + return await self._jargon_learning.get_raw_message_count(group_id) def _should_schedule_jargon_mining( self, group_id: str, raw_message_count: int ) -> bool: """Trigger jargon mining once per additional 10 messages per group.""" - if raw_message_count < 10: - return False - if group_id in self._active_jargon_groups: - return False - last_trigger = self._last_jargon_trigger_counts.get(group_id, 0) - return raw_message_count - last_trigger >= 10 + return self._jargon_learning.should_schedule_mining( + group_id, + raw_message_count, + ) def _spawn_jargon_task(self, group_id: str, raw_message_count: int) -> None: """Spawn a jargon-mining task and track group-level trigger state.""" - self._last_jargon_trigger_counts[group_id] = raw_message_count - self._active_jargon_groups.add(group_id) + self._jargon_learning.mark_mining_started(group_id, raw_message_count) task = self._spawn(self.mine_jargon(group_id)) def _on_complete(_: asyncio.Task) -> None: - self._active_jargon_groups.discard(group_id) + self._jargon_learning.mark_mining_finished(group_id) task.add_done_callback(_on_complete) diff --git a/services/learning/persona_learning.py b/services/learning/persona_learning.py new file mode 100644 index 00000000..11c33632 --- /dev/null +++ b/services/learning/persona_learning.py @@ -0,0 +1,515 @@ +"""Persona learning module. + +Keeps persona-specific learning logic separate from expression and jargon +learning. The progressive learning service remains the batch orchestrator. +""" + +import json +from datetime import datetime +from typing import Any, Dict, List, Optional + +from astrbot.api import logger + +from ...constants import UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING +from ...utils.json_utils import clean_llm_json_response, safe_parse_llm_json +from ...utils.persona_selection import get_persona_identifier, resolve_target_persona + + +class PersonaLearningModule: + """Generate persona candidates and create review records.""" + + def __init__( + self, + *, + config: Any, + context: Any, + db_manager: Any, + persona_manager: Any, + multidimensional_analyzer: Any, + prompts: Any, + resolve_umo, + json_serializer, + ) -> None: + self.config = config + self.context = context + self.db_manager = db_manager + self.persona_manager = persona_manager + self.multidimensional_analyzer = multidimensional_analyzer + self.prompts = prompts + self._resolve_umo = resolve_umo + self._json_serializer = json_serializer + + async def get_current_persona(self, group_id: str) -> Dict[str, Any]: + """Get current persona settings for a group.""" + try: + persona = await self.persona_manager.get_current_persona(group_id) + if persona: + return persona + + if hasattr(self.context, "persona_manager") and self.context.persona_manager: + try: + default_persona = await resolve_target_persona( + self.context.persona_manager, + self.config, + self._resolve_umo(group_id), + require_existing=True, + log=logger, + ) + if default_persona: + return { + "prompt": default_persona.get("prompt", "默认人格"), + "name": get_persona_identifier(default_persona), + "style_parameters": {}, + "last_updated": datetime.now().isoformat(), + } + except Exception as exc: + logger.warning(f"从框架获取默认人格失败: {exc}") + + return { + "prompt": "默认人格", + "name": "default", + "style_parameters": {}, + "last_updated": datetime.now().isoformat(), + } + except Exception as exc: + logger.error(f"获取当前人格失败 for group {group_id}: {exc}") + return {"prompt": "默认人格", "name": "default", "style_parameters": {}} + + async def generate_updated_persona_with_refinement( + self, + group_id: str, + current_persona: Dict[str, Any], + style_analysis: Any, + ) -> Dict[str, Any]: + """Generate a persona candidate, using refine provider when available.""" + try: + analysis_data = self.extract_analysis_data(style_analysis) + + if ( + hasattr(self.multidimensional_analyzer, "llm_adapter") + and self.multidimensional_analyzer.llm_adapter + ): + llm_adapter = self.multidimensional_analyzer.llm_adapter + + if ( + llm_adapter.has_refine_provider() + and llm_adapter.providers_configured >= 2 + ): + current_persona_json = json.dumps( + current_persona, + ensure_ascii=False, + indent=2, + default=self._json_serializer, + ) + style_analysis_json = json.dumps( + analysis_data, + ensure_ascii=False, + indent=2, + default=self._json_serializer, + ) + + response = await llm_adapter.refine_chat_completion( + prompt=self.prompts.PROGRESSIVE_LEARNING_GENERATE_UPDATED_PERSONA_PROMPT.format( + current_persona_json=current_persona_json, + style_analysis_json=style_analysis_json, + ), + temperature=0.6, + ) + + if response: + clean_response = clean_llm_json_response(response) + try: + updated_persona = safe_parse_llm_json(clean_response) + logger.info("使用提炼模型成功生成更新后的人格") + return updated_persona + except json.JSONDecodeError as exc: + logger.error( + f"提炼模型返回的JSON格式不正确: {exc}, 响应: {clean_response}" + ) + return await self.generate_updated_persona( + group_id, current_persona, style_analysis + ) + + logger.warning("提炼模型Provider未配置,使用传统方法生成人格") + return await self.generate_updated_persona( + group_id, current_persona, style_analysis + ) + + logger.warning("框架适配器未找到,使用传统方法生成人格") + return await self.generate_updated_persona( + group_id, current_persona, style_analysis + ) + + except Exception as exc: + logger.error(f"使用提炼模型生成人格失败: {exc}") + return await self.generate_updated_persona( + group_id, current_persona, style_analysis + ) + + async def generate_updated_persona( + self, + group_id: str, + current_persona: Dict[str, Any], + style_analysis: Any, + ) -> Dict[str, Any]: + """Generate a persona candidate by appending incremental learning text.""" + try: + if not hasattr(self.context, "persona_manager") or not self.context.persona_manager: + logger.warning(f"无法获取PersonaManager for group {group_id}") + return current_persona + + default_persona = await resolve_target_persona( + self.context.persona_manager, + self.config, + self._resolve_umo(group_id), + require_existing=True, + log=logger, + ) + if not default_persona: + logger.warning(f"无法获取当前人格 for group {group_id}") + return current_persona + + original_prompt = default_persona.get("prompt", "") + learning_content = self.extract_learning_content(style_analysis) + + if learning_content: + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + new_content = ( + f"\n\n【学习更新 - {timestamp}】\n" + + "\n".join(learning_content) + ) + + updated_persona = dict(default_persona) + updated_persona["prompt"] = original_prompt + new_content + updated_persona["last_updated"] = timestamp + + logger.info( + f" 成功追加 {len(learning_content)} 项学习内容到人格 for group {group_id}" + ) + return updated_persona + + analysis_data = self.extract_analysis_data(style_analysis) + logger.warning( + f" style_analysis中没有可提取的学习内容 for group {group_id}, " + f"数据结构: {list(analysis_data.keys())}" + ) + return dict(default_persona) + + except Exception as exc: + logger.error(f"生成更新人格失败 for group {group_id}: {exc}", exc_info=True) + return current_persona + + async def apply_persona_learning( + self, + group_id: str, + style_analysis: Any, + messages: List[Dict[str, Any]], + *, + current_persona: Optional[Dict[str, Any]] = None, + updated_persona: Optional[Dict[str, Any]] = None, + quality_metrics: Any = None, + relearn_mode: bool = False, + ml_tuning_info: Optional[Dict[str, Any]] = None, + ) -> bool: + """Apply persona-side learning and create a review when needed.""" + try: + current_persona = self._coerce_persona(current_persona, "current_persona") + updated_persona = self._coerce_persona(updated_persona, "updated_persona") + + logger.info(f"应用人格更新 for group {group_id}") + + style_analysis_dict, _confidence = self.style_analysis_to_dict( + style_analysis + ) + if style_analysis_dict is None: + return False + + update_success = await self.persona_manager.update_persona( + group_id, style_analysis_dict, messages + ) + if not update_success: + logger.error(f"通过 PersonaManagerService 更新人格失败 for group {group_id}") + + should_create_review = self.should_create_review( + current_persona=current_persona, + updated_persona=updated_persona, + relearn_mode=relearn_mode, + group_id=group_id, + ) + if should_create_review: + await self.create_persona_review( + group_id, + style_analysis, + messages, + current_persona=current_persona, + updated_persona=updated_persona, + quality_metrics=quality_metrics, + relearn_mode=relearn_mode, + ml_tuning_info=ml_tuning_info, + ) + else: + logger.debug("人格未变化或缺少必要参数,跳过审查记录创建") + + return bool(update_success) + + except Exception as exc: + logger.error(f"应用人格学习失败 for group {group_id}: {exc}") + return False + + @staticmethod + def extract_analysis_data(style_analysis: Any) -> Dict[str, Any]: + """Normalize AnalysisResult/dict-like style analysis to a dict.""" + try: + from ...core.interfaces import AnalysisResult + + if isinstance(style_analysis, AnalysisResult): + analysis_data = style_analysis.data if style_analysis.data else {} + logger.debug( + "从AnalysisResult提取data: " + f"success={style_analysis.success}, confidence={style_analysis.confidence}" + ) + return analysis_data + except Exception: + pass + + if isinstance(style_analysis, dict): + logger.debug("使用字典形式的style_analysis") + return style_analysis + if hasattr(style_analysis, "data"): + analysis_data = style_analysis.data if style_analysis.data else {} + logger.debug(f"从对象提取data属性: {type(style_analysis)}") + return analysis_data + + logger.warning(f"style_analysis类型不正确: {type(style_analysis)}, 使用空字典") + return {} + + def extract_learning_content(self, style_analysis: Any) -> List[str]: + """Extract persona increment text from style analysis data.""" + analysis_data = self.extract_analysis_data(style_analysis) + learning_content: List[str] = [] + + if "enhanced_prompt" in analysis_data: + learning_content.append(analysis_data["enhanced_prompt"]) + logger.debug("找到 enhanced_prompt 字段") + + if "learning_insights" in analysis_data: + insights = analysis_data["learning_insights"] + if insights: + learning_content.append(insights) + logger.debug("找到 learning_insights 字段") + + if not learning_content and "style_analysis" in analysis_data: + style_report = analysis_data["style_analysis"] + if isinstance(style_report, dict): + extracted_parts = [] + + if "text_style" in style_report: + extracted_parts.append(f"文本风格: {style_report['text_style']}") + + if "expression_features" in style_report: + features = style_report["expression_features"] + if isinstance(features, list): + extracted_parts.append(f"表达特点: {', '.join(features)}") + elif isinstance(features, str): + extracted_parts.append(f"表达特点: {features}") + + if "tone" in style_report: + extracted_parts.append(f"语气倾向: {style_report['tone']}") + + if "topics" in style_report: + topics = style_report["topics"] + if isinstance(topics, list): + extracted_parts.append(f"话题偏好: {', '.join(topics)}") + elif isinstance(topics, str): + extracted_parts.append(f"话题偏好: {topics}") + + if extracted_parts: + learning_content.append("【对话风格学习结果】\n" + "\n".join(extracted_parts)) + logger.debug( + f"从 style_analysis 提取了 {len(extracted_parts)} 个风格特征" + ) + + if not learning_content and "style_profile" in analysis_data: + style_profile = analysis_data["style_profile"] + if isinstance(style_profile, dict): + profile_parts = [] + + if "tone_intensity" in style_profile: + profile_parts.append(f"语气强度: {style_profile['tone_intensity']:.2f}") + if "sentiment" in style_profile: + profile_parts.append(f"情感倾向: {style_profile['sentiment']:.2f}") + if "vocabulary_richness" in style_profile: + profile_parts.append( + f"词汇丰富度: {style_profile['vocabulary_richness']:.2f}" + ) + + if profile_parts: + learning_content.append("【风格量化指标】\n" + "\n".join(profile_parts)) + logger.debug(f"从 style_profile 提取了 {len(profile_parts)} 个量化指标") + + if not learning_content: + for field in ("summary", "description", "analysis", "insights", "findings"): + if field in analysis_data and analysis_data[field]: + learning_content.append(f"【{field}】\n{analysis_data[field]}") + logger.debug(f"从顶层字段 {field} 提取了内容") + break + + return learning_content + + def style_analysis_to_dict(self, style_analysis: Any) -> tuple[Optional[Dict[str, Any]], float]: + """Return (data, confidence) for persona manager updates.""" + if hasattr(style_analysis, "success"): + if not style_analysis.success: + logger.error(f"风格分析失败,跳过人格更新: {style_analysis.error}") + return None, 0.0 + logger.debug(f"使用 AnalysisResult 对象,置信度: {style_analysis.confidence:.3f}") + return style_analysis.data, style_analysis.confidence + if isinstance(style_analysis, dict): + logger.debug("使用字典形式的 style_analysis(向后兼容)") + return style_analysis, style_analysis.get("confidence", 0.5) + + logger.error(f"style_analysis 类型不正确: {type(style_analysis)}") + return None, 0.0 + + @staticmethod + def should_create_review( + *, + current_persona: Optional[Dict[str, Any]], + updated_persona: Optional[Dict[str, Any]], + relearn_mode: bool, + group_id: str, + ) -> bool: + """Determine whether a persona review record should be created.""" + if relearn_mode: + should_create = bool(updated_persona and current_persona) + if should_create: + has_changes = updated_persona.get("prompt", "") != current_persona.get( + "prompt", "" + ) + if has_changes: + logger.info( + f" 重新学习模式:检测到人格变化,创建审查记录(group: {group_id})" + ) + else: + logger.info( + f" 重新学习模式:未检测到人格变化,但仍创建审查记录供审核" + f"(group: {group_id})" + ) + else: + logger.warning( + " 重新学习模式:无法创建审查记录 - " + f"updated_persona={bool(updated_persona)}, " + f"current_persona={bool(current_persona)}" + ) + return should_create + + if ( + updated_persona + and current_persona + and updated_persona.get("prompt") != current_persona.get("prompt") + ): + logger.info(f" 正常模式:检测到人格变化,创建审查记录(group: {group_id})") + return True + + logger.debug( + f" 正常模式:人格未变化,跳过审查记录 - " + f"updated={bool(updated_persona)}, current={bool(current_persona)}, " + f"same_prompt={updated_persona.get('prompt') == current_persona.get('prompt') if updated_persona and current_persona else 'N/A'}" + ) + return False + + async def create_persona_review( + self, + group_id: str, + style_analysis: Any, + messages: List[Dict[str, Any]], + *, + current_persona: Dict[str, Any], + updated_persona: Dict[str, Any], + quality_metrics: Any = None, + relearn_mode: bool = False, + ml_tuning_info: Optional[Dict[str, Any]] = None, + ) -> Optional[int]: + """Create a persona learning review record.""" + try: + original_prompt = current_persona.get("prompt", "") + new_prompt = updated_persona.get("prompt", "") + + if len(new_prompt) > len(original_prompt): + incremental_content = new_prompt[len(original_prompt):].strip() + else: + incremental_content = new_prompt + + metadata = { + "progressive_learning": True, + "message_count": len(messages), + "style_analysis_fields": list( + style_analysis.data.keys() + if hasattr(style_analysis, "data") + and isinstance(style_analysis.data, dict) + else style_analysis.keys() + if isinstance(style_analysis, dict) + else [] + ), + "original_prompt_length": len(original_prompt), + "new_prompt_length": len(new_prompt), + "incremental_content": incremental_content, + "incremental_start_pos": len(original_prompt), + "relearn_mode": relearn_mode, + } + + if ml_tuning_info: + metadata["ml_tuning"] = ml_tuning_info + + confidence_score = ( + quality_metrics.consistency_score + if quality_metrics and hasattr(quality_metrics, "consistency_score") + else 0.5 + ) + + raw_analysis_parts = [f"基于{len(messages)}条消息的风格分析"] + if relearn_mode: + raw_analysis_parts.append("(重新学习)") + if ml_tuning_info and ml_tuning_info.get("applied"): + if ml_tuning_info.get("used_conservative_fusion"): + raw_analysis_parts.append( + "强化学习生成的prompt过短" + f"({ml_tuning_info['tuned_length']} vs " + f"{ml_tuning_info['original_length']}),采用保守融合策略" + ) + else: + raw_analysis_parts.append( + "已应用强化学习优化,预期改进: " + f"{ml_tuning_info['expected_improvement']:.2%}" + ) + raw_analysis = ";".join(raw_analysis_parts) + + review_id = await self.db_manager.add_persona_learning_review( + group_id=group_id, + proposed_content=incremental_content, + learning_source=UPDATE_TYPE_PROGRESSIVE_PERSONA_LEARNING, + confidence_score=confidence_score, + raw_analysis=raw_analysis, + metadata=metadata, + original_content=original_prompt, + new_content=new_prompt, + ) + + logger.info( + f" 已创建人格学习审查记录 (ID: {review_id})," + f"置信度: {confidence_score:.3f}" + ) + return review_id + + except Exception as exc: + logger.error(f"创建人格学习审查记录失败: {exc}", exc_info=True) + return None + + @staticmethod + def _coerce_persona( + persona: Optional[Dict[str, Any]], label: str + ) -> Optional[Dict[str, Any]]: + if isinstance(persona, list): + logger.warning(f"{label}为list类型(长度{len(persona)}),转换为空字典") + return {} + return persona diff --git a/web_res/static/html/dashboard.html b/web_res/static/html/dashboard.html index a602100c..b01537aa 100644 --- a/web_res/static/html/dashboard.html +++ b/web_res/static/html/dashboard.html @@ -186,6 +186,7 @@ max-width: 1440px; margin: 0 auto; padding: 24px 28px; + overflow-x: clip; } .topbar { @@ -418,6 +419,358 @@ background-clip: text; } + .learning-dashboard { + display: grid; + gap: 16px; + margin-bottom: 16px; + } + + .learning-command { + display: grid; + grid-template-columns: minmax(0, 1.1fr) minmax(300px, 0.9fr); + gap: 16px; + padding: 18px; + border: 1px solid var(--border); + border-radius: var(--radius-lg); + background: var(--panel); + box-shadow: var(--shadow-sm); + overflow: hidden; + position: relative; + } + + .learning-command::before { + content: ''; + position: absolute; + inset: 0; + background: + radial-gradient(circle at 14% 18%, rgba(16, 185, 129, 0.11), transparent 28%), + radial-gradient(circle at 84% 20%, rgba(245, 158, 11, 0.10), transparent 26%), + linear-gradient(135deg, rgba(99, 102, 241, 0.08), transparent 48%); + pointer-events: none; + } + + .learning-command > * { + position: relative; + z-index: 1; + } + + .learning-command-title { + display: grid; + gap: 10px; + align-content: start; + } + + .learning-command-title h2 { + margin: 0; + font-size: var(--text-2xl); + line-height: 1.15; + } + + .learning-command-title p { + margin: 0; + max-width: 680px; + color: var(--muted); + font-size: 13px; + line-height: 1.6; + } + + .learning-pulse-grid { + display: grid; + grid-template-columns: repeat(4, minmax(0, 1fr)); + gap: 8px; + margin-top: 4px; + } + + .learning-pulse { + min-width: 0; + padding: 10px 12px; + border: 1px solid var(--border); + border-radius: var(--radius-sm); + background: color-mix(in srgb, var(--panel) 76%, transparent); + backdrop-filter: blur(12px); + } + + .learning-pulse .k { + color: var(--muted); + font-size: 11px; + } + + .learning-pulse .v { + margin-top: 5px; + font-size: 18px; + font-weight: 800; + line-height: 1; + overflow-wrap: anywhere; + } + + .quick-dock { + display: flex; + flex-wrap: wrap; + gap: 8px; + align-content: flex-start; + justify-content: flex-end; + } + + .quick-action { + min-height: 38px; + display: inline-flex; + align-items: center; + justify-content: center; + gap: 7px; + padding: 0 12px; + border: 1px solid var(--border); + border-radius: var(--radius-sm); + background: color-mix(in srgb, var(--panel) 82%, transparent); + color: var(--text); + cursor: pointer; + font: inherit; + text-decoration: none; + font-size: 12px; + font-weight: 650; + line-height: 1; + box-shadow: var(--shadow-sm); + backdrop-filter: blur(14px); + transition: border-color 0.2s ease, color 0.2s ease, background 0.2s ease, box-shadow 0.2s ease; + } + + .quick-action:hover, + .quick-action:focus-visible { + border-color: var(--accent); + color: var(--accent); + background: var(--panel); + box-shadow: var(--shadow); + outline: none; + } + + .quick-action .material-icons { + font-size: 17px; + } + + .learning-module-grid { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 14px; + } + + .learning-module-card { + --module-accent: var(--accent); + min-height: 238px; + display: grid; + gap: 12px; + align-content: space-between; + padding: 16px; + border: 1px solid var(--border); + border-radius: var(--radius-lg); + background: var(--panel); + color: var(--text); + text-decoration: none; + box-shadow: var(--shadow-sm); + position: relative; + overflow: hidden; + transition: border-color 0.25s ease, box-shadow 0.25s ease, background 0.25s ease; + } + + .learning-module-card::before { + content: ''; + position: absolute; + inset: 0 0 auto; + height: 4px; + background: linear-gradient(90deg, var(--module-accent), color-mix(in srgb, var(--module-accent) 35%, transparent)); + } + + .learning-module-card:hover, + .learning-module-card:focus-visible, + .learning-module-card.active { + border-color: color-mix(in srgb, var(--module-accent) 58%, var(--border)); + box-shadow: var(--shadow-lg); + outline: none; + } + + .learning-module-card.jargon { + --module-accent: var(--success); + } + + .learning-module-card.expression { + --module-accent: var(--accent-3); + } + + .learning-module-card.persona { + --module-accent: var(--accent-2); + } + + .learning-module-card .module-icon { + background: color-mix(in srgb, var(--module-accent) 13%, transparent); + color: var(--module-accent); + } + + .learning-module-head { + display: flex; + align-items: flex-start; + justify-content: space-between; + gap: 12px; + min-width: 0; + } + + .learning-module-title { + display: grid; + gap: 5px; + min-width: 0; + } + + .learning-module-title strong { + font-size: 17px; + line-height: 1.25; + } + + .learning-module-title span { + color: var(--muted); + font-size: 12px; + line-height: 1.45; + } + + .learning-module-metrics { + display: grid; + grid-template-columns: repeat(3, minmax(0, 1fr)); + gap: 8px; + } + + .learning-module-metric { + min-width: 0; + padding: 9px 10px; + border: 1px solid var(--border); + border-radius: var(--radius-sm); + background: var(--panel-soft); + } + + .learning-module-metric .k { + color: var(--muted); + font-size: 11px; + white-space: nowrap; + } + + .learning-module-metric .v { + margin-top: 5px; + font-size: 16px; + font-weight: 800; + overflow-wrap: anywhere; + } + + .mini-chart { + width: 100%; + height: 72px; + } + + .home-section-label { + margin: 2px 0 10px; + color: var(--muted); + font-size: 12px; + font-weight: 750; + text-transform: uppercase; + letter-spacing: 0; + } + + .learning-module-page-grid { + display: grid; + grid-template-columns: minmax(0, 0.96fr) minmax(0, 1.36fr); + gap: 16px; + } + + .learning-module-hero { + display: grid; + gap: 14px; + padding: 18px; + border: 1px solid var(--border); + border-radius: var(--radius-lg); + background: var(--panel); + box-shadow: var(--shadow-sm); + } + + .learning-module-hero .icon-row { + display: flex; + justify-content: space-between; + align-items: flex-start; + gap: 12px; + } + + .learning-module-hero h3 { + margin: 0; + font-size: var(--text-xl); + line-height: 1.2; + } + + .learning-module-hero p { + margin: 6px 0 0; + color: var(--muted); + font-size: 13px; + line-height: 1.55; + } + + .module-kpi-strip { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 10px; + } + + .module-kpi { + min-height: 84px; + padding: 12px; + border: 1px solid var(--border); + border-radius: var(--radius-sm); + background: var(--panel-soft); + display: grid; + align-content: space-between; + } + + .module-kpi .k { + color: var(--muted); + font-size: 12px; + } + + .module-kpi .v { + margin-top: 8px; + font-size: 24px; + font-weight: 850; + line-height: 1; + } + + .learning-module-actions { + display: flex; + flex-wrap: wrap; + gap: 8px; + } + + .module-list-compact { + display: grid; + gap: 10px; + max-height: 520px; + overflow: auto; + padding-right: 2px; + } + + .module-chart-large { + height: 300px; + } + + .physics-surface { + --spring-x: 0px; + --spring-y: 0px; + --spring-rx: 0deg; + --spring-ry: 0deg; + --spring-scale: 1; + transform: + perspective(900px) + translate3d(var(--spring-x), var(--spring-y), 0) + rotateX(var(--spring-rx)) + rotateY(var(--spring-ry)) + scale(var(--spring-scale)); + transform-style: preserve-3d; + will-change: transform; + } + + .physics-surface.is-pressing { + cursor: grabbing; + } + .page-titlebar { display: flex; justify-content: space-between; @@ -2002,6 +2355,19 @@ grid-template-columns: repeat(2, minmax(0, 1fr)); } + .learning-command, + .learning-module-page-grid { + grid-template-columns: 1fr; + } + + .learning-module-grid { + grid-template-columns: 1fr; + } + + .quick-dock { + justify-content: flex-start; + } + .grid2, .queue-grid { grid-template-columns: 1fr; @@ -2026,12 +2392,36 @@ } @media (max-width: 720px) { + body { + overflow-x: hidden; + } + .shell { padding: 14px; } .topbar { flex-direction: column; + overflow: hidden; + } + + .toolbar { + justify-content: flex-start; + max-width: 100%; + } + + .pill { + max-width: 100%; + } + + .learning-command { + padding: 14px; + } + + .learning-pulse-grid, + .learning-module-metrics, + .module-kpi-strip { + grid-template-columns: repeat(2, minmax(0, 1fr)); } .page-titlebar { @@ -2071,6 +2461,52 @@ grid-template-columns: 1fr; } } + + @media (max-width: 460px) { + .learning-pulse-grid, + .learning-module-metrics, + .module-kpi-strip { + grid-template-columns: 1fr; + } + + .quick-action { + flex: 1 0 calc(33.333% - 8px); + max-width: calc(33.333% - 5px); + min-width: 0; + padding: 0 8px; + } + + .quick-dock, + .learning-module-actions, + .panel-tools { + width: 100%; + justify-content: flex-start; + } + + .content-search { + width: 100% !important; + } + } + + @media (prefers-reduced-motion: reduce) { + .page.active, + .toast, + .spinner { + animation: none !important; + } + + .physics-surface, + .module-card, + .learning-module-card, + .metric, + .panel, + .item, + .small-stat, + .quick-action { + transition: none !important; + transform: none !important; + } + } .confirm-overlay{display:none;position:fixed;inset:0;background:rgba(0,0,0,.5);z-index:9999;align-items:center;justify-content:center;backdrop-filter:blur(4px);-webkit-backdrop-filter:blur(4px);} .confirm-overlay.show{display:flex;} .confirm-box{background:var(--panel);border:1px solid var(--border);border-radius:var(--radius-lg);padding:var(--space-lg);max-width:360px;width:90%;box-shadow:var(--shadow-lg);} @@ -2080,8 +2516,8 @@ .confirm-box .cbtns button{padding:var(--space-sm) 18px;border-radius:var(--radius-md);border:1px solid var(--border);background:var(--panel);color:var(--text);cursor:pointer;font-size:var(--text-sm);transition:opacity .16s,background .16s;} .confirm-box .cbtns button.danger{background:var(--danger);border-color:var(--danger);color:#fff;} .confirm-box .cbtns button:hover{opacity:.85;} - .toast-wrap{position:fixed;top:var(--space-md);right:var(--space-md);z-index:10000;display:flex;flex-direction:column;gap:var(--space-sm);pointer-events:none;} - .toast{padding:12px 20px;border-radius:var(--radius-md);font-size:var(--text-sm);font-weight:500;color:#fff;pointer-events:auto;animation:toast-in .35s cubic-bezier(0.4,0,0.2,1);box-shadow:var(--shadow-lg);cursor:pointer;backdrop-filter:blur(8px);} + .toast-wrap{position:fixed;top:var(--space-md);right:var(--space-md);left:var(--space-md);z-index:10000;display:flex;flex-direction:column;align-items:flex-end;gap:var(--space-sm);pointer-events:none;max-width:calc(100vw - 32px);} + .toast{padding:12px 20px;border-radius:var(--radius-md);font-size:var(--text-sm);font-weight:500;color:#fff;pointer-events:auto;animation:toast-in .35s cubic-bezier(0.4,0,0.2,1);box-shadow:var(--shadow-lg);cursor:pointer;backdrop-filter:blur(8px);max-width:100%;overflow-wrap:anywhere;} .toast.ok{background:var(--toast-success);} .toast.err{background:var(--toast-error);} .toast.warn{background:var(--toast-warning);} .toast.info{background:var(--toast-info);} .toast.out{animation:toast-out .25s ease forwards;} @keyframes toast-in{from{opacity:0;transform:translateX(60px) scale(0.95)}to{opacity:1;transform:translateX(0) scale(1)}} @@ -2130,6 +2566,138 @@

监控板

+
+
+
+

学习模块控制台

+

黑话、表达方式和人格学习分开运行,首页汇总全部队列、内容、批次与系统状态。

+
+
+
学习效率
+
--
+
+
+
待办总量
+
0
+
+
+
内容样本
+
--
+
+
+
最近批次
+
--
+
+
+
+ +
+ +
+ + +
+
+ + - -
+ + + + +
+ + +
+
+
+
+

黑话词库

+

等待黑话统计。

+
+ translate +
+
+
+
候选词条
+
0
+
+
+
待确认
+
0
+
+
+
已确认
+
0
+
+
+
出现次数
+
0
+
+
+ +
+ +
+
+

黑话分布

+ -- +
+
+
+
+
+
+ +
+
+
+

黑话词条

+ 同步审查队列筛选 +
+
+ + +
+
+
+
+
+
+
+ +
+
+
+

表达方式学习

+

独立查看对话样本、分析结果、表达模式和学习批次。

+
+ apps模块 +
+ +
+
+
+
+

表达样本库

+

等待表达学习内容。

+
+ record_voice_over +
+
+
+
原始对话
+
0
+
+
+
分析结果
+
0
+
+
+
表达模式
+
0
+
+
+
待审风格
+
0
+
+
+
+ article内容浏览 + rule风格审查 + +
+
+ +
+
+

表达学习构成

+ -- +
+
+
+
+
+
+ +
+
+
+

表达内容流

+ 最近学习内容 +
+
+ + + + +
+
+
+
+
+
+
+ +
+
+
+

人格学习

+

独立管理当前人格、人格更新、审查历史、备份和恢复。

+
+ apps模块 +
+ +
+
+
+
+

人格状态

+

等待人格状态。

+
+ psychology +
+
+
+
待审人格
+
0
+
+
+
已审人格
+
0
+
+
+
人格备份
+
0
+
+
+
Prompt
+
0
+
+
+
+ rate_review审查队列 + tune人格配置 + +
+
+ +
+
+

人格学习状态

+ -- +
+
+
+
+
+
+ +
+
+
+

当前人格

+ -- +
+
+
+
+
+
+
+

待审人格更新

+ -- +
+
+
+
+
+
@@ -2937,9 +3749,12 @@

手动安装依赖

items: [], loading: false, }, + learningModules: { + expressionType: 'dialogues', + }, activePage: 'home', }; - const DASHBOARD_PAGES = ['home', 'overview', 'insights', 'monitoring', 'reviews', 'content', 'reply-strategy', 'graphs', 'integrations', 'settings']; + const DASHBOARD_PAGES = ['home', 'overview', 'insights', 'monitoring', 'reviews', 'jargon-learning', 'expression-learning', 'persona-learning', 'content', 'reply-strategy', 'graphs', 'integrations', 'settings']; const numberFmt = new Intl.NumberFormat('zh-CN', { maximumFractionDigits: 1, @@ -3031,6 +3846,27 @@

手动安装依赖

return Array.isArray(value) ? value : []; } + function setText(id, value) { + const element = $(id); + if (element) { + element.textContent = value; + } + } + + function setHTML(id, value) { + const element = $(id); + if (element) { + element.innerHTML = value; + } + } + + function syncFieldValue(id, value) { + const element = $(id); + if (element && document.activeElement !== element) { + element.value = value; + } + } + function patternCount(value) { if (Array.isArray(value)) { return value.length; @@ -3412,9 +4248,178 @@

手动安装依赖

if (!state.batch.items.length || force) { loadBatchList(); } + } else if (page === 'jargon-learning') { + if (!state.jargon.items.length || force) { + loadJargonList(); + } else { + renderLearningModules(state.data || {}); + } + } else if (page === 'expression-learning') { + if (force || !state.content.loaded) { + loadLearningContent({ quiet: true }); + } else { + renderLearningModules(state.data || {}); + } + } else if (page === 'persona-learning') { + renderLearningModules(state.data || {}); } } + function learningModuleSnapshot(data) { + const payload = data || {}; + const metrics = payload.metrics || {}; + const trends = payload.trends || {}; + const persona = payload.persona || {}; + const personaReviewed = payload.personaReviewed || {}; + const personaState = payload.personaState || {}; + const personaBackups = payload.personaBackups || {}; + const style = payload.style || {}; + const contentPayload = getContentPayload(); + const jargonStats = extractJargonStats(payload.jargonStats); + const personaUpdates = safeArray(persona.updates); + const styleBacklog = safeNumber(style.total); + const personaTotal = safeNumber(persona.total); + const personaPayloadIncludesStyle = personaUpdates.some((item) => item && item.review_source === 'style_learning'); + const personaBacklog = personaPayloadIncludesStyle ? Math.max(0, personaTotal - styleBacklog) : personaTotal; + const totalCandidates = safeNumber(jargonStats.total_candidates); + const confirmedJargon = safeNumber(jargonStats.confirmed_jargon); + const pendingJargon = Math.max(0, totalCandidates - confirmedJargon); + const contentCounts = { + dialogues: safeArray(contentPayload.dialogues).length, + analysis: safeArray(contentPayload.analysis).length, + features: safeArray(contentPayload.features).length, + history: safeArray(contentPayload.history).length, + }; + const expressionTotal = Object.values(contentCounts).reduce((sum, value) => sum + value, 0); + const backups = safeArray(personaBackups.backups); + + return { + metrics, + trends, + jargon: { + stats: jargonStats, + totalCandidates, + confirmed: confirmedJargon, + pending: pendingJargon, + occurrences: safeNumber(jargonStats.total_occurrences), + completed: safeNumber(jargonStats.completed_inference), + items: state.jargon.items.length ? state.jargon.items : extractJargonItems(payload.jargonList), + }, + expression: { + content: contentPayload, + counts: contentCounts, + total: expressionTotal, + pending: styleBacklog, + reviews: safeArray(style.reviews), + batches: safeArray(trends.recent_batches), + }, + persona: { + state: personaState, + backups, + backupTotal: safeNumber(personaBackups.total || backups.length), + pending: personaBacklog, + reviewed: safeNumber(personaReviewed.total || safeArray(personaReviewed.updates).length), + updates: personaUpdates.filter((item) => !(item && item.review_source === 'style_learning')), + promptLength: safeNumber(personaState.prompt_length), + beginDialogs: safeNumber(personaState.begin_dialog_count), + }, + backlog: personaBacklog + styleBacklog + pendingJargon, + contentTotal: expressionTotal, + batchCount: safeArray(trends.recent_batches).length, + efficiency: safeNumber(metrics.learning_efficiency), + }; + } + + function moduleStatMarkup(rows) { + return rows.map(([label, value]) => ` +
+
${escapeHtml(label)}
+
${escapeHtml(value)}
+
+ `).join(''); + } + + function renderCompactJargonItems(items, limit = 10) { + const rows = safeArray(items).slice(0, limit).map((item) => { + const id = item.id; + const title = item.term || item.content || item.word || '未命名'; + const definition = item.definition || item.meaning || item.review_detail || item.description || ''; + const meta = [ + item.group_id ? `群组 ${item.group_id}` : null, + item.is_confirmed !== undefined ? (item.is_confirmed ? '已确认' : '待确认') : null, + item.is_global ? '全局' : '本地', + item.occurrences !== undefined ? `出现 ${formatCount(item.occurrences)}` : null, + ].filter(Boolean).join(' · '); + return ` +
+
+
${escapeHtml(title)}
+ ${item.is_confirmed ? '已确认' : '待确认'} +
+
${escapeHtml(meta || '未分组')}
+
${escapeHtml(definition || '暂无释义')}
+
+ ${item.is_confirmed ? actionButton('jargon-edit', id, 'edit', '编辑') : actionButton('jargon-approve', id, 'done', '确认', 'success')} + ${item.is_confirmed ? '' : actionButton('jargon-reject', id, 'block', '驳回', 'warning')} + ${actionButton('jargon-toggle-global', id, item.is_global ? 'public_off' : 'public', item.is_global ? '取消全局' : '设为全局')} +
+
+ `; + }); + return rows.length ? rows.join('') : '
translate

暂无黑话记录

'; + } + + function renderCompactContentItems(type, items, limit = 12) { + const rows = safeArray(items).slice(0, limit).map((item) => { + const title = item.title || item.type || contentTypeLabel(type); + const meta = [item.timestamp ? formatTime(item.timestamp) : null, item.metadata, item.status ? `状态 ${item.status}` : null].filter(Boolean).join(' · '); + const text = item.text || item.detail || renderObjectSummary(item.raw) || '暂无内容'; + return ` +
+
+
${escapeHtml(title)}
+ ${escapeHtml(contentTypeLabel(type))} +
+
${escapeHtml(meta || '学习内容')}
+
${escapeHtml(text)}
+
+ `; + }); + return rows.length ? rows.join('') : `
article

暂无${escapeHtml(contentTypeLabel(type))}

`; + } + + function renderPersonaStateModule(snapshot) { + const statePayload = snapshot.persona.state || {}; + const persona = statePayload.persona || {}; + const services = statePayload.available_services || {}; + const activeServices = Object.entries(services).filter(([, available]) => available).map(([name]) => name); + const promptPreview = statePayload.prompt_preview || persona.prompt || persona.system_prompt || ''; + setText('personaModuleStateHint', statePayload.degraded ? '降级预览' : (persona.name || persona.persona_id || 'default')); + setHTML('personaModuleState', ` +
+
+
${escapeHtml(persona.name || persona.persona_id || '默认人格')}
+ ${escapeHtml(statePayload.group_id || 'default')} +
+
Begin dialogs ${formatCount(snapshot.persona.beginDialogs)} · Tools ${formatCount(statePayload.tool_count || 0)} · 服务 ${escapeHtml(activeServices.length ? activeServices.join(', ') : 'fallback')}
+
${escapeHtml(promptPreview || '暂无人格 Prompt')}
+
+ ${snapshot.persona.backups.slice(0, 3).map((item) => ` +
+
+
${escapeHtml(item.backup_name || `备份 ${item.id}`)}
+ ${escapeHtml(item.persona_name || 'persona')} +
+
${escapeHtml(item.timestamp ? formatTime(item.timestamp) : (item.created_at ? formatTime(item.created_at) : '人格备份'))}
+
+ ${actionButton('persona-backup-view', item.id, 'visibility', '查看')} + ${actionButton('persona-backup-restore', item.id, 'settings_backup_restore', '恢复', 'warning')} +
+
+ `).join('')} + `); + } + function renderHomeModules(data) { const metrics = data.metrics || {}; const health = data.health || {}; @@ -3485,6 +4490,97 @@

手动安装依赖

: '配置中心'; } + function renderLearningModuleCharts(snapshot) { + renderMiniBarChart('homeJargonMiniChart', ['候选', '确认', '待确认'], [ + snapshot.jargon.totalCandidates, + snapshot.jargon.confirmed, + snapshot.jargon.pending, + ], 3); + renderMiniBarChart('homeExpressionMiniChart', ['对话', '分析', '模式', '批次'], [ + snapshot.expression.counts.dialogues, + snapshot.expression.counts.analysis, + snapshot.expression.counts.features, + snapshot.expression.counts.history, + ], 2); + renderMiniBarChart('homePersonaMiniChart', ['待审', '已审', '备份'], [ + snapshot.persona.pending, + snapshot.persona.reviewed, + snapshot.persona.backupTotal, + ], 1); + + renderModuleDonutChart('jargonModuleChart', [ + ['待确认', snapshot.jargon.pending], + ['已确认', snapshot.jargon.confirmed], + ['完成推理', snapshot.jargon.completed], + ], `出现 ${formatCount(snapshot.jargon.occurrences)}`); + renderModuleDonutChart('expressionModuleChart', [ + ['原始对话', snapshot.expression.counts.dialogues], + ['分析结果', snapshot.expression.counts.analysis], + ['表达模式', snapshot.expression.counts.features], + ['学习批次', snapshot.expression.counts.history], + ], `总计 ${formatCount(snapshot.expression.total)}`); + renderModuleGaugeChart('personaModuleChart', [ + ['待审人格', snapshot.persona.pending], + ['已审人格', snapshot.persona.reviewed], + ['人格备份', snapshot.persona.backupTotal], + ['Begin dialogs', snapshot.persona.beginDialogs], + ], `Prompt ${formatCount(snapshot.persona.promptLength)}`); + } + + function renderLearningModules(data) { + const snapshot = learningModuleSnapshot(data); + setText('learningPulseEfficiency', formatPercent(snapshot.efficiency)); + setText('learningPulseBacklog', formatCount(snapshot.backlog)); + setText('learningPulseContent', snapshot.contentTotal ? formatCount(snapshot.contentTotal) : '--'); + setText('learningPulseBatches', snapshot.batchCount ? formatCount(snapshot.batchCount) : '--'); + + setText('homeJargonCandidates', formatCount(snapshot.jargon.totalCandidates)); + setText('homeJargonConfirmed', formatCount(snapshot.jargon.confirmed)); + setText('homeJargonPending', formatCount(snapshot.jargon.pending)); + setText('homeExpressionSamples', snapshot.expression.total ? formatCount(snapshot.expression.total) : '--'); + setText('homeExpressionPatterns', snapshot.expression.counts.features ? formatCount(snapshot.expression.counts.features) : '--'); + setText('homeExpressionPending', formatCount(snapshot.expression.pending)); + setText('homePersonaPending', formatCount(snapshot.persona.pending)); + setText('homePersonaReviewed', formatCount(snapshot.persona.reviewed)); + setText('homePersonaBackups', formatCount(snapshot.persona.backupTotal)); + + setText('jargonModuleCandidates', formatCount(snapshot.jargon.totalCandidates)); + setText('jargonModulePending', formatCount(snapshot.jargon.pending)); + setText('jargonModuleConfirmed', formatCount(snapshot.jargon.confirmed)); + setText('jargonModuleOccurrences', formatCount(snapshot.jargon.occurrences)); + setText('jargonLearningSummary', `候选 ${formatCount(snapshot.jargon.totalCandidates)} · 确认 ${formatCount(snapshot.jargon.confirmed)} · 出现 ${formatCount(snapshot.jargon.occurrences)}`); + setText('jargonModuleChartHint', `待确认 ${formatCount(snapshot.jargon.pending)} 条`); + setText('jargonModuleListHint', `当前筛选 ${state.jargon.filter} · ${formatCount(state.jargon.total || snapshot.jargon.items.length)} 条`); + setHTML('jargonModuleList', renderCompactJargonItems(snapshot.jargon.items)); + syncFieldValue('jargonModuleFilterSelect', state.jargon.filter); + syncFieldValue('jargonModuleSearchInput', state.jargon.search || ''); + + setText('expressionModuleDialogues', formatCount(snapshot.expression.counts.dialogues)); + setText('expressionModuleAnalysis', formatCount(snapshot.expression.counts.analysis)); + setText('expressionModuleFeatures', formatCount(snapshot.expression.counts.features)); + setText('expressionModulePending', formatCount(snapshot.expression.pending)); + setText('expressionLearningSummary', `样本 ${formatCount(snapshot.expression.total)} · 待审风格 ${formatCount(snapshot.expression.pending)} · 最近批次 ${formatCount(snapshot.batchCount)}`); + setText('expressionModuleChartHint', `${formatCount(snapshot.expression.total)} 条学习内容`); + const expressionType = state.learningModules.expressionType || 'dialogues'; + document.querySelectorAll('[data-expression-module-type]').forEach((tab) => { + tab.classList.toggle('active', tab.dataset.expressionModuleType === expressionType); + }); + setText('expressionModuleListHint', `${contentTypeLabel(expressionType)} · ${formatCount(safeArray(snapshot.expression.content[expressionType]).length)} 条`); + setHTML('expressionModuleList', renderCompactContentItems(expressionType, snapshot.expression.content[expressionType])); + + setText('personaModulePending', formatCount(snapshot.persona.pending)); + setText('personaModuleReviewed', formatCount(snapshot.persona.reviewed)); + setText('personaModuleBackups', formatCount(snapshot.persona.backupTotal)); + setText('personaModulePrompt', formatCount(snapshot.persona.promptLength)); + setText('personaLearningSummary', `待审 ${formatCount(snapshot.persona.pending)} · 已审 ${formatCount(snapshot.persona.reviewed)} · 备份 ${formatCount(snapshot.persona.backupTotal)}`); + setText('personaModuleChartHint', `Prompt ${formatCount(snapshot.persona.promptLength)} 字 · Begin dialogs ${formatCount(snapshot.persona.beginDialogs)}`); + setText('personaModuleListHint', `${formatCount(snapshot.persona.pending)} 条待审`); + setHTML('personaModuleList', renderPersonaQueue(snapshot.persona.updates)); + renderPersonaStateModule(snapshot); + + renderLearningModuleCharts(snapshot); + } + function collectLLMCalls(llmCalls, llmCallBreakdown, llmCallSummary, filterModelSummary) { const sourceRows = safeArray(llmCallBreakdown).length ? safeArray(llmCallBreakdown) @@ -3864,7 +4960,7 @@

手动安装依赖

llm: 'monitoring', efficiency: 'monitoring', reviews: 'reviews', - jargon: 'reviews', + jargon: 'jargon-learning', batches: 'reviews', messages: 'overview', stable: 'overview', @@ -4074,6 +5170,182 @@

手动安装依赖

return state.charts[id]; } + function renderMiniBarChart(id, labels, values, accentIndex = 0) { + const chart = ensureChart(id); + if (!chart) { + return; + } + const palette = getThemePalette(); + chart.setOption({ + backgroundColor: 'transparent', + animation: true, + grid: { left: 4, right: 4, top: 6, bottom: 4 }, + xAxis: { + type: 'category', + data: labels, + show: false, + }, + yAxis: { + type: 'value', + show: false, + }, + tooltip: { + trigger: 'axis', + backgroundColor: getTooltipBackground(), + borderColor: getTooltipBorderColor(), + borderWidth: 1, + textStyle: { color: getTextColor() }, + }, + series: [{ + type: 'bar', + data: values, + barWidth: '48%', + itemStyle: { + color: palette[accentIndex] || palette[0], + borderRadius: [6, 6, 2, 2], + }, + }], + }, true); + } + + function renderModuleDonutChart(id, rows, centerLabel) { + const chart = ensureChart(id); + if (!chart) { + return; + } + const data = rows + .map(([name, value]) => ({ name, value: Math.max(0, safeNumber(value)) })) + .filter((item) => item.value > 0); + + if (!data.length) { + chart.setOption({ + graphic: [{ + type: 'text', + left: 'center', + top: 'middle', + style: { + text: '暂无数据', + fill: getTextColor(), + fontSize: 14, + }, + }], + }, true); + return; + } + + chart.setOption({ + backgroundColor: 'transparent', + color: getThemePalette(), + textStyle: { color: getTextColor() }, + tooltip: { + trigger: 'item', + backgroundColor: getTooltipBackground(), + borderColor: getTooltipBorderColor(), + borderWidth: 1, + textStyle: { color: getTextColor() }, + }, + legend: { + bottom: 0, + itemWidth: 10, + itemHeight: 10, + textStyle: { color: getMutedColor(), fontSize: 11 }, + }, + series: [{ + type: 'pie', + radius: ['48%', '70%'], + center: ['50%', '43%'], + avoidLabelOverlap: true, + label: { + formatter: '{b}\n{c}', + color: getTextColor(), + fontSize: 12, + }, + labelLine: { + length: 10, + length2: 8, + lineStyle: { color: getBorderColor() }, + }, + emphasis: { + scale: true, + scaleSize: 8, + }, + data, + }], + graphic: [{ + type: 'text', + left: 'center', + top: '38%', + style: { + text: centerLabel || '', + fill: getMutedColor(), + fontSize: 12, + fontWeight: 700, + textAlign: 'center', + }, + }], + }, true); + } + + function renderModuleGaugeChart(id, rows, centerValue) { + const chart = ensureChart(id); + if (!chart) { + return; + } + const maxValue = Math.max(1, ...rows.map(([, value]) => safeNumber(value))); + chart.setOption({ + backgroundColor: 'transparent', + color: getThemePalette(), + textStyle: { color: getTextColor() }, + grid: { left: 100, right: 28, top: 16, bottom: 24 }, + tooltip: { + trigger: 'axis', + axisPointer: { type: 'shadow' }, + backgroundColor: getTooltipBackground(), + borderColor: getTooltipBorderColor(), + borderWidth: 1, + textStyle: { color: getTextColor() }, + }, + xAxis: { + type: 'value', + max: Math.ceil(maxValue * 1.2), + axisLabel: { color: getMutedColor() }, + splitLine: { lineStyle: { color: getGridColor() } }, + }, + yAxis: { + type: 'category', + data: rows.map(([label]) => label), + axisLabel: { color: getTextColor() }, + axisLine: { lineStyle: { color: getBorderColor() } }, + }, + series: [{ + type: 'bar', + data: rows.map(([, value]) => safeNumber(value)), + barWidth: 18, + itemStyle: { + color: getTrendBarGradient(), + borderRadius: [0, 8, 8, 0], + }, + label: { + show: true, + position: 'right', + color: getMutedColor(), + formatter: '{c}', + }, + }], + graphic: centerValue ? [{ + type: 'text', + right: 18, + top: 8, + style: { + text: centerValue, + fill: getMutedColor(), + fontSize: 12, + fontWeight: 700, + }, + }] : [], + }, true); + } + function confidenceLabel(value) { if (value === null || value === undefined || value === '') { return null; @@ -5934,6 +7206,20 @@

${escapeHtml(item.title)}

} async function runDashboardAction(action, id) { + if (action === 'module-refresh') { + if (id === 'jargon') { + await loadJargonList(); + } else if (id === 'expression') { + await loadLearningContent(); + } else if (id === 'persona') { + await loadDashboard(); + } else { + await loadDashboard(); + } + showToast('模块已刷新', 'ok'); + return; + } + if (id === undefined || id === null || id === '') { setQueueStatus('缺少可操作的 ID。', 'error'); return; @@ -6625,6 +7911,7 @@

${escapeHtml(item.title)}

renderLearningContent(); if (state.data) { renderHomeModules(state.data); + renderLearningModules(state.data); } } } @@ -6634,11 +7921,13 @@

${escapeHtml(item.title)}

renderSummary(data); renderInsightPanel(data); renderHomeModules(data); + renderLearningModules(data); renderTrendChart(data.trends || {}); renderEfficiencyChart(data.metrics || {}); if (state.content.data) { renderLearningContent(); renderHomeModules(data); + renderLearningModules(data); } if (Object.values(state.charts).length) { Object.values(state.charts).forEach((chart) => chart && chart.resize()); @@ -6726,6 +8015,61 @@

${escapeHtml(item.title)}

} } + function initPhysicsMotion() { + const reduceMotion = window.matchMedia && window.matchMedia('(prefers-reduced-motion: reduce)').matches; + if (reduceMotion) { + return; + } + + const surfaces = Array.from(document.querySelectorAll('[data-physics]')); + surfaces.forEach((surface) => { + const settle = () => { + surface.classList.remove('is-pressing'); + surface.animate([ + { + transform: getComputedStyle(surface).transform === 'none' + ? 'perspective(900px) translate3d(0, 0, 0) rotateX(0deg) rotateY(0deg) scale(1)' + : getComputedStyle(surface).transform, + }, + { + transform: 'perspective(900px) translate3d(0, 0, 0) rotateX(0deg) rotateY(0deg) scale(1)', + }, + ], { + duration: 520, + easing: 'cubic-bezier(0.22, 1.35, 0.36, 1)', + }); + surface.style.setProperty('--spring-x', '0px'); + surface.style.setProperty('--spring-y', '0px'); + surface.style.setProperty('--spring-rx', '0deg'); + surface.style.setProperty('--spring-ry', '0deg'); + surface.style.setProperty('--spring-scale', '1'); + }; + + surface.addEventListener('pointermove', (event) => { + const rect = surface.getBoundingClientRect(); + if (!rect.width || !rect.height) { + return; + } + const px = (event.clientX - rect.left) / rect.width - 0.5; + const py = (event.clientY - rect.top) / rect.height - 0.5; + const strength = surface.classList.contains('is-pressing') ? 1.2 : 0.72; + surface.style.setProperty('--spring-x', `${(px * 7 * strength).toFixed(2)}px`); + surface.style.setProperty('--spring-y', `${(py * 7 * strength).toFixed(2)}px`); + surface.style.setProperty('--spring-rx', `${(-py * 4 * strength).toFixed(2)}deg`); + surface.style.setProperty('--spring-ry', `${(px * 4 * strength).toFixed(2)}deg`); + }); + + surface.addEventListener('pointerdown', () => { + surface.classList.add('is-pressing'); + surface.style.setProperty('--spring-scale', '0.985'); + }); + surface.addEventListener('pointerup', settle); + surface.addEventListener('pointercancel', settle); + surface.addEventListener('pointerleave', settle); + surface.addEventListener('blur', settle); + }); + } + function updateLastUpdated() { $('lastUpdated').textContent = new Intl.DateTimeFormat('zh-CN', { hour: '2-digit', @@ -6774,6 +8118,7 @@

${escapeHtml(item.title)}

style, jargonStats, jargonList, + contentText, ] = await Promise.all([ safeFetch('/api/metrics'), safeFetch('/api/metrics/trends'), @@ -6786,6 +8131,7 @@

${escapeHtml(item.title)}

safeFetch('/api/style_learning/reviews?limit=5'), safeFetch('/api/jargon/stats'), safeFetch('/api/jargon/list?page_size=5&confirmed=false&pending=true'), + safeFetch('/api/style_learning/content_text'), ]); const fetchErrors = []; @@ -6811,6 +8157,11 @@

${escapeHtml(item.title)}

_errors: fetchErrors, }; + if (contentText) { + state.content.data = contentText; + state.content.loaded = true; + } + renderAll(merged); if (fetchErrors.length) { @@ -6865,6 +8216,9 @@

${escapeHtml(item.title)}

state.jargon.totalPages = Math.max(1, Math.ceil(total / state.jargon.pageSize)); renderJargonPanel(); + if (state.data) { + renderLearningModules(state.data); + } } catch (e) { console.error('Failed to load jargon list:', e); } finally { @@ -6897,6 +8251,9 @@

${escapeHtml(item.title)}

} renderBatchPanel(); + if (state.data) { + renderLearningModules(state.data); + } } catch (e) { console.error('Failed to load batch list:', e); } finally { @@ -6986,6 +8343,9 @@

${escapeHtml(item.title)}

$('jargonNextBtn').disabled = page >= totalPages; $('jargonHint').textContent = `候选 ${formatCount((jargonStats.total_candidates ?? jargonStats.totalCandidates) || 0)} · 第 ${page}/${totalPages} 页`; + if (state.data) { + renderLearningModules(state.data); + } } function renderBatchPanel() { @@ -7105,6 +8465,7 @@

${escapeHtml(item.title)}

$('jargonFilterSelect').addEventListener('change', (event) => { state.jargon.filter = event.target.value; state.jargon.page = 1; + syncFieldValue('jargonModuleFilterSelect', state.jargon.filter); loadJargonList(); }); $('jargonSortSelect').addEventListener('change', (event) => { @@ -7118,9 +8479,32 @@

${escapeHtml(item.title)}

jargonSearchTimer = setTimeout(() => { state.jargon.search = event.target.value || ''; state.jargon.page = 1; + syncFieldValue('jargonModuleSearchInput', state.jargon.search); loadJargonList(); }, 300); }); + const jargonModuleFilterSelect = $('jargonModuleFilterSelect'); + if (jargonModuleFilterSelect) { + jargonModuleFilterSelect.addEventListener('change', (event) => { + state.jargon.filter = event.target.value; + state.jargon.page = 1; + syncFieldValue('jargonFilterSelect', state.jargon.filter); + loadJargonList(); + }); + } + const jargonModuleSearchInput = $('jargonModuleSearchInput'); + if (jargonModuleSearchInput) { + let jargonModuleSearchTimer = null; + jargonModuleSearchInput.addEventListener('input', (event) => { + clearTimeout(jargonModuleSearchTimer); + jargonModuleSearchTimer = setTimeout(() => { + state.jargon.search = event.target.value || ''; + state.jargon.page = 1; + syncFieldValue('jargonSearchInput', state.jargon.search); + loadJargonList(); + }, 300); + }); + } $('jargonPrevBtn').addEventListener('click', () => { if (state.jargon.page > 1) { state.jargon.page--; @@ -7197,6 +8581,15 @@

${escapeHtml(item.title)}

} return; } + + const expressionModuleTab = event.target.closest('[data-expression-module-type]'); + if (expressionModuleTab) { + state.learningModules.expressionType = ['dialogues', 'analysis', 'features', 'history'].includes(expressionModuleTab.dataset.expressionModuleType) + ? expressionModuleTab.dataset.expressionModuleType + : 'dialogues'; + renderLearningModules(state.data || {}); + return; + } }); const graphReloadBtn = $('graphReloadBtn'); @@ -7243,6 +8636,7 @@

${escapeHtml(item.title)}

function init() { applyTheme(state.theme); bindEvents(); + initPhysicsMotion(); navigateToPage(resolvePageFromHash(), { skipHash: true, instant: true }); loadDashboard(); loadConfigPanel(); From 4a43702565c90b0da640cd0d87c6fa7b68de0332 Mon Sep 17 00:00:00 2001 From: EterUltimate <1831303476@qq.com> Date: Sun, 7 Jun 2026 20:27:16 +0800 Subject: [PATCH 3/3] fix: resolve persona backups and command filters --- services/learning/sample_filter.py | 39 ++++++++++++++++- tests/integration/test_webui_static_assets.py | 4 +- tests/unit/test_learning_chain_regressions.py | 7 +++- tests/unit/test_persona_backup_service.py | 25 +++++++++++ web_res/static/html/dashboard.html | 35 ++++++++++------ webui/blueprints/personas.py | 8 ++-- webui/services/persona_backup_service.py | 42 +++++++++++++------ 7 files changed, 128 insertions(+), 32 deletions(-) diff --git a/services/learning/sample_filter.py b/services/learning/sample_filter.py index 109008e5..d0784dde 100644 --- a/services/learning/sample_filter.py +++ b/services/learning/sample_filter.py @@ -21,6 +21,21 @@ "model", "tools", } +COMMAND_LIKE_PATTERN = re.compile(r"^[/!#.][^\W\d_][\w-]*(?:\s+.*)?$") +COMMAND_REFERENCE_PATTERN = re.compile(r"[/!#.][^\W\d_][\w-]*") +COMMAND_GUIDANCE_PATTERNS = ( + re.compile( + r"(?:使用|输入|发送|执行|运行|调用|通过).{0,12}" + r"(?:命令|指令|菜单|帮助|功能|插件).{0,24}" + r"[/!#.][^\W\d_][\w-]*", + re.IGNORECASE, + ), + re.compile( + r"[/!#.][^\W\d_][\w-]*.{0,24}" + r"(?:命令|指令|菜单|帮助|功能|插件)", + re.IGNORECASE, + ), +) SYSTEM_RESPONSE_PATTERNS = ( re.compile(r"^AstrBot\s+v?\d", re.IGNORECASE), @@ -290,10 +305,30 @@ def is_command_message(message_text: Any) -> bool: normalized = first_token[1:].lower() if has_prefix else first_token.lower() if has_prefix: - return normalized in BARE_COMMANDS + if normalized in BARE_COMMANDS: + return True + return bool(COMMAND_LIKE_PATTERN.match(text)) return text.lower() in BARE_COMMANDS +def is_command_guidance(message_text: Any) -> bool: + """Return true for help/menu text that teaches or lists commands.""" + text = _normalize_text(message_text) + if not text: + return False + + command_refs = COMMAND_REFERENCE_PATTERN.findall(text) + if command_refs and any(pattern.search(text) for pattern in COMMAND_GUIDANCE_PATTERNS): + return True + + if len(command_refs) >= 2 and any( + keyword in text + for keyword in ("命令", "指令", "菜单", "帮助", "功能", "使用", "输入", "发送") + ): + return True + return False + + def is_system_response(message_text: Any) -> bool: """Return true for framework help/error output that should not be learned.""" text = _normalize_text(message_text) @@ -302,6 +337,8 @@ def is_system_response(message_text: Any) -> bool: if any(pattern.search(text) for pattern in SYSTEM_RESPONSE_PATTERNS): return True + if is_command_guidance(text): + return True help_lines = sum( 1 diff --git a/tests/integration/test_webui_static_assets.py b/tests/integration/test_webui_static_assets.py index 94b3ac9d..d65a2587 100644 --- a/tests/integration/test_webui_static_assets.py +++ b/tests/integration/test_webui_static_assets.py @@ -162,7 +162,9 @@ def test_dashboard_exposes_persona_state_and_backup_management(): assert "personaStateStats" in text assert "personaBackupList" in text assert "/api/persona_management/current?group_id=default" in text - assert "/api/persona_backups/list?group_id=default&limit=8" in text + assert "/api/persona_backups/list?limit=8" in text + assert "data-${key}" in text + assert "group-id" in text assert "persona-backup-view" in text assert "persona-backup-restore" in text assert "persona-backup-delete" in text diff --git a/tests/unit/test_learning_chain_regressions.py b/tests/unit/test_learning_chain_regressions.py index 948ffc24..15d3f8a0 100644 --- a/tests/unit/test_learning_chain_regressions.py +++ b/tests/unit/test_learning_chain_regressions.py @@ -310,7 +310,12 @@ def test_learning_sample_filter_blocks_commands_and_system_outputs(): assert should_ignore_learning_sample("/help me") is True assert should_ignore_learning_sample("help") is True assert should_ignore_learning_sample("help me") is False - assert should_ignore_learning_sample("/a hello") is False + assert should_ignore_learning_sample("/a hello") is True + assert should_ignore_learning_sample("/帮助 查看菜单") is True + assert should_ignore_learning_sample("使用 /help 命令查看菜单", sender_id="bot", is_bot=True) is True + assert should_ignore_learning_sample("发送 /帮助 查看菜单", sender_id="bot", is_bot=True) is True + assert should_ignore_learning_sample("可用命令:/help /provider", sender_id="bot", is_bot=True) is True + assert should_ignore_learning_sample("我不会使用命令式的语气聊天") is False assert should_ignore_learning_sample(bot_help, sender_id="bot", is_bot=True) is True assert should_ignore_learning_sample(livingmemory_shutdown_log) is True assert should_ignore_learning_sample("MemoryEngine 已关闭") is True diff --git a/tests/unit/test_persona_backup_service.py b/tests/unit/test_persona_backup_service.py index d1fe55d9..088ad999 100644 --- a/tests/unit/test_persona_backup_service.py +++ b/tests/unit/test_persona_backup_service.py @@ -52,6 +52,31 @@ async def test_get_backup_uses_database_detail(mock_container): assert result['summary']['prompt_length'] == len('content') +@pytest.mark.asyncio +async def test_list_backups_without_group_queries_all_groups(mock_container): + mock_container.database_manager.get_persona_backups = AsyncMock(return_value=[ + { + 'id': 4, + 'group_id': 'real-group', + 'backup_name': 'real-group-backup', + 'timestamp': 1710000001, + 'backup_reason': 'Before style update', + 'original_persona': {'name': 'Group Persona', 'prompt': 'hello'}, + } + ]) + service = PersonaBackupService(mock_container) + + result = await service.list_backups(limit=8) + + assert result['group_id'] is None + assert result['backups'][0]['group_id'] == 'real-group' + mock_container.database_manager.get_persona_backups.assert_awaited_once_with( + group_id=None, + limit=8, + include_content=True, + ) + + @pytest.mark.asyncio async def test_restore_backup_falls_back_to_persona_manager(mock_container): mock_container.persona_backup_manager = None diff --git a/web_res/static/html/dashboard.html b/web_res/static/html/dashboard.html index b01537aa..aa835521 100644 --- a/web_res/static/html/dashboard.html +++ b/web_res/static/html/dashboard.html @@ -4412,8 +4412,8 @@

手动安装依赖

${escapeHtml(item.timestamp ? formatTime(item.timestamp) : (item.created_at ? formatTime(item.created_at) : '人格备份'))}
- ${actionButton('persona-backup-view', item.id, 'visibility', '查看')} - ${actionButton('persona-backup-restore', item.id, 'settings_backup_restore', '恢复', 'warning')} + ${actionButton('persona-backup-view', item.id, 'visibility', '查看', '', { 'group-id': item.group_id })} + ${actionButton('persona-backup-restore', item.id, 'settings_backup_restore', '恢复', 'warning', { 'group-id': item.group_id })}
`).join('')} @@ -5355,7 +5355,7 @@

手动安装依赖

return `置信 ${formatPercent(score <= 1 ? score * 100 : score)}`; } - function actionButton(action, id, icon, label, tone = '') { + function actionButton(action, id, icon, label, tone = '', attrs = {}) { if (id === undefined || id === null || id === '') { return ''; } @@ -5364,12 +5364,17 @@

手动安装依赖

const displayIcon = icon; const displayLabel = label; const displayTone = tone; + const attrHtml = Object.entries(attrs || {}) + .filter(([, value]) => value !== undefined && value !== null && value !== '') + .map(([key, value]) => `data-${key}="${escapeHtml(value)}"`) + .join(' '); return `