diff --git a/services/database/facades/learning_facade.py b/services/database/facades/learning_facade.py index fd7d2d53..03dd8fb4 100644 --- a/services/database/facades/learning_facade.py +++ b/services/database/facades/learning_facade.py @@ -219,7 +219,8 @@ async def save_persona_update_record(self, record: Dict[str, Any]) -> int: return await self.add_persona_learning_review(record) async def update_persona_update_record_status( - self, record_id: int, new_status: str, reviewer_comment: str = '' + self, record_id: int, new_status: str, reviewer_comment: str = '', + group_id: str = None, ) -> bool: """更新人格更新记录的状态 @@ -237,6 +238,8 @@ async def update_persona_update_record_status( stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == record_id ) + if group_id: + stmt = stmt.where(PersonaLearningReview.group_id == group_id) result = await session.execute(stmt) record = result.scalar_one_or_none() if not record: @@ -531,7 +534,7 @@ async def get_persona_learning_review_by_id( async def update_persona_learning_review_status( self, review_id, new_status, reviewer_comment='', - modified_content=None, + modified_content=None, group_id: str = None, ) -> bool: """更新人格学习审核记录状态 @@ -550,6 +553,8 @@ async def update_persona_learning_review_status( stmt = select(PersonaLearningReview).where( PersonaLearningReview.id == review_id ) + if group_id: + stmt = stmt.where(PersonaLearningReview.group_id == group_id) result = await session.execute(stmt) record = result.scalar_one_or_none() if not record: @@ -772,7 +777,7 @@ async def get_reviewed_style_learning_updates( return [] async def update_style_review_status( - self, review_id, new_status, reviewer_comment='' + self, review_id, new_status, reviewer_comment='', group_id: str = None, ) -> bool: """更新风格学习审核记录状态 @@ -790,6 +795,8 @@ async def update_style_review_status( stmt = select(StyleLearningReview).where( StyleLearningReview.id == review_id ) + if group_id: + stmt = stmt.where(StyleLearningReview.group_id == group_id) result = await session.execute(stmt) record = result.scalar_one_or_none() if not record: diff --git a/services/database/sqlalchemy_database_manager.py b/services/database/sqlalchemy_database_manager.py index 82fc5b4a..7926a1e9 100644 --- a/services/database/sqlalchemy_database_manager.py +++ b/services/database/sqlalchemy_database_manager.py @@ -783,11 +783,12 @@ async def get_reviewed_persona_update_records( async def update_persona_update_record_status( self, record_id: int, status: str, comment: str = None, + group_id: str = None, ) -> bool: return await self._call_learning( "update_persona_update_record_status", False, - record_id, status, comment, + record_id, status, comment, group_id, ) async def create_style_learning_review( @@ -811,11 +812,12 @@ async def get_reviewed_style_learning_updates( async def update_style_review_status( self, review_id: int, status: str, reviewer_comment: str = '', + group_id: str = None, ) -> bool: return await self._call_learning( "update_style_review_status", False, - review_id, status, reviewer_comment, + review_id, status, reviewer_comment, group_id, ) async def update_style_review_metadata( @@ -870,12 +872,12 @@ async def get_persona_learning_review_by_id( async def update_persona_learning_review_status( self, review_id: int, status: str, comment: str = None, - modified_content: str = None, + modified_content: str = None, group_id: str = None, ) -> bool: return await self._call_learning( "update_persona_learning_review_status", False, - review_id, status, comment, modified_content, + review_id, status, comment, modified_content, group_id, ) async def update_persona_learning_review_metadata( diff --git a/tests/unit/test_persona_backup_service.py b/tests/unit/test_persona_backup_service.py index 088ad999..20994277 100644 --- a/tests/unit/test_persona_backup_service.py +++ b/tests/unit/test_persona_backup_service.py @@ -1,5 +1,6 @@ """Unit tests for WebUI persona backup management service.""" from unittest.mock import AsyncMock +from unittest.mock import Mock import pytest @@ -102,6 +103,50 @@ async def test_restore_backup_falls_back_to_persona_manager(mock_container): mock_container.persona_manager.update_persona.assert_awaited_once() +@pytest.mark.asyncio +async def test_restore_default_backup_targets_current_real_persona(mock_container): + """Backups created from the legacy default placeholder restore into the real current persona.""" + mock_container.persona_backup_manager = None + mock_container.persona_web_manager = Mock() + mock_container.persona_web_manager.get_persona_for_group = AsyncMock(return_value={ + 'persona_id': 'default', + 'system_prompt': '', + 'begin_dialogs': [], + }) + mock_container.persona_web_manager.get_all_personas_for_web = AsyncMock(return_value=[ + { + 'persona_id': 'suleng', + 'name': 'SuLeng', + 'system_prompt': 'current prompt', + 'begin_dialogs': [], + } + ]) + mock_container.persona_web_manager.update_persona_via_web = AsyncMock(return_value={ + 'success': True, + }) + mock_container.database_manager.get_persona_backup = AsyncMock(return_value={ + 'id': 10, + 'group_id': 'default', + 'backup_name': 'legacy-default', + 'original_persona': { + 'persona_id': 'default', + 'name': 'default', + 'prompt': 'restored prompt', + }, + 'imitation_dialogues': [], + }) + service = PersonaBackupService(mock_container) + + success, message = await service.restore_backup(10, group_id='default') + + assert success is True + assert '恢复成功' in message + persona_id, payload = mock_container.persona_web_manager.update_persona_via_web.await_args.args + assert persona_id == 'suleng' + assert payload['persona_id'] == 'suleng' + assert payload['system_prompt'] == 'restored prompt' + + @pytest.mark.asyncio async def test_delete_backup_uses_database_manager(mock_container): mock_container.database_manager.delete_persona_backup = AsyncMock(return_value=True) diff --git a/tests/unit/test_persona_review_service.py b/tests/unit/test_persona_review_service.py index c90e17b2..221d8802 100644 --- a/tests/unit/test_persona_review_service.py +++ b/tests/unit/test_persona_review_service.py @@ -174,6 +174,29 @@ async def test_review_persona_update_approve_style_learning(self, mock_container assert '批准' in message or 'approved' in message.lower() mock_container.database_manager.update_style_review_status.assert_called() + @pytest.mark.asyncio + async def test_style_learning_approval_locks_review_by_group_without_polluting_comment( + self, mock_container, sample_style_review_data + ): + """Approving one style review should lock by row/group and keep comments clean.""" + service = PersonaReviewService(mock_container) + sample_style_review_data["group_id"] = "group-a" + mock_container.database_manager.get_pending_style_reviews.return_value = [ + sample_style_review_data, + {**sample_style_review_data, "id": 2, "group_id": "group-b"}, + ] + mock_container.database_manager.update_style_review_status.return_value = True + + success, message = await service.review_persona_update("style_1", "approve") + + assert success is True + assert "group-a" not in message + mock_container.database_manager.update_style_review_status.assert_awaited_once_with( + 1, + "approved", + group_id="group-a", + ) + @pytest.mark.asyncio async def test_style_learning_preview_targets_begin_dialogs(self, mock_container, sample_style_review_data): """Style review previews should show begin_dialogs changes, not system prompt edits.""" diff --git a/webui/services/persona_backup_service.py b/webui/services/persona_backup_service.py index 2e70a3eb..030ba0a8 100644 --- a/webui/services/persona_backup_service.py +++ b/webui/services/persona_backup_service.py @@ -7,6 +7,17 @@ from .persona_service import _optional_container_attr +try: + from ...utils.persona_selection import ( + resolve_target_persona, + resolve_target_persona_from_web, + ) +except ImportError: + from utils.persona_selection import ( + resolve_target_persona, + resolve_target_persona_from_web, + ) + class PersonaBackupService: """人格备份管理服务。""" @@ -17,6 +28,13 @@ def __init__(self, container): self.persona_backup_manager = _optional_container_attr(container, 'persona_backup_manager') self.persona_manager = _optional_container_attr(container, 'persona_manager') self.persona_web_mgr = _optional_container_attr(container, 'persona_web_manager') + self.astrbot_persona_manager = ( + _optional_container_attr(container, 'astrbot_persona_manager') + or self.persona_manager + ) + self.plugin_config = _optional_container_attr(container, 'plugin_config') + group_mapping = _optional_container_attr(container, 'group_id_to_unified_origin', {}) + self.group_id_to_unified_origin = group_mapping if isinstance(group_mapping, dict) else {} @staticmethod def _normalize_limit(limit: Any, default: int = 20, maximum: int = 100) -> int: @@ -75,6 +93,14 @@ def _ensure_database(self): if not self.database_manager: raise ValueError('数据库服务未初始化,无法管理人格备份') + def _resolve_umo(self, group_id: str) -> str: + return self.group_id_to_unified_origin.get(group_id, group_id) + + @staticmethod + def _is_placeholder_persona_id(persona_id: Any) -> bool: + value = str(persona_id or '').strip().lower() + return not value or value in {'default', '[%none]'} + @staticmethod def _normalize_group_id(group_id: Any) -> Optional[str]: if group_id is None: @@ -142,8 +168,24 @@ async def restore_backup(self, backup_id: int, group_id: Optional[str] = None) - normalized_group_id = self._normalize_group_id(group_id) backup = await self.get_backup(backup_id, group_id=normalized_group_id) effective_group_id = normalized_group_id or backup.get('group_id') or 'default' + persona = self._persona_from_backup(backup) + if not persona: + return False, '备份中没有可恢复的人格内容' - if self.persona_backup_manager and hasattr(self.persona_backup_manager, 'restore_backup'): + target_persona_id = await self._resolve_restore_persona_id(persona, effective_group_id) + if target_persona_id: + persona['persona_id'] = target_persona_id + if self._is_placeholder_persona_id(persona.get('name')): + persona['name'] = target_persona_id + + if ( + self.persona_backup_manager + and hasattr(self.persona_backup_manager, 'restore_backup') + and not self._is_placeholder_persona_id( + (backup.get('original_persona') or {}).get('persona_id') + or (backup.get('original_persona') or {}).get('name') + ) + ): try: success = await self.persona_backup_manager.restore_backup(effective_group_id, backup_id) if success: @@ -151,10 +193,6 @@ async def restore_backup(self, backup_id: int, group_id: Optional[str] = None) - except Exception as e: logger.warning(f"正式备份管理器恢复失败,尝试 WebUI fallback: {e}") - persona = self._persona_from_backup(backup) - if not persona: - return False, '备份中没有可恢复的人格内容' - persona_id = persona['persona_id'] if self.persona_web_mgr and hasattr(self.persona_web_mgr, 'update_persona_via_web'): result = await self.persona_web_mgr.update_persona_via_web(persona_id, persona) @@ -186,6 +224,42 @@ async def restore_backup(self, backup_id: int, group_id: Optional[str] = None) - return False, 'PersonaManager 未初始化,无法恢复备份' + async def _resolve_restore_persona_id( + self, + persona: Dict[str, Any], + group_id: str, + ) -> str: + """Resolve placeholder backup IDs like default to the current real AstrBot persona.""" + persona_id = persona.get('persona_id') or persona.get('name') + if not self._is_placeholder_persona_id(persona_id): + return str(persona_id) + + try: + current = None + if self.persona_web_mgr: + current = await resolve_target_persona_from_web( + self.persona_web_mgr, + self.plugin_config, + group_id, + log=logger, + ) + elif self.astrbot_persona_manager: + current = await resolve_target_persona( + self.astrbot_persona_manager, + self.plugin_config, + self._resolve_umo(group_id), + require_existing=True, + log=logger, + ) + if isinstance(current, dict): + resolved = current.get('persona_id') or current.get('name') + if resolved and not self._is_placeholder_persona_id(resolved): + return str(resolved) + except Exception as e: + logger.warning(f"解析备份恢复目标人格失败,将使用备份内 ID: {e}", exc_info=True) + + return str(persona.get('persona_id') or persona.get('name') or 'default') + async def delete_backup(self, backup_id: int, group_id: Optional[str] = None) -> Tuple[bool, str]: """删除人格备份。""" self._ensure_database() diff --git a/webui/services/persona_review_service.py b/webui/services/persona_review_service.py index f80db739..638e6cb0 100644 --- a/webui/services/persona_review_service.py +++ b/webui/services/persona_review_service.py @@ -992,7 +992,7 @@ async def _approve_style_learning_review(self, review_id: int) -> Tuple[bool, st # 更新状态 success = await self.database_manager.update_style_review_status( - review_id, 'approved', target_review['group_id'] + review_id, 'approved', group_id=group_id ) if success and target_review['few_shots_content']: