Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions services/database/facades/learning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ async def save_persona_update_record(self, record: Dict[str, Any]) -> int:
return await self.add_persona_learning_review(record)

async def update_persona_update_record_status(
self, record_id: int, new_status: str, reviewer_comment: str = ''
self, record_id: int, new_status: str, reviewer_comment: str = '',
group_id: str = None,
) -> bool:
"""更新人格更新记录的状态

Expand All @@ -237,6 +238,8 @@ async def update_persona_update_record_status(
stmt = select(PersonaLearningReview).where(
PersonaLearningReview.id == record_id
)
if group_id:
stmt = stmt.where(PersonaLearningReview.group_id == group_id)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
if not record:
Expand Down Expand Up @@ -531,7 +534,7 @@ async def get_persona_learning_review_by_id(

async def update_persona_learning_review_status(
self, review_id, new_status, reviewer_comment='',
modified_content=None,
modified_content=None, group_id: str = None,
) -> bool:
"""更新人格学习审核记录状态

Expand All @@ -550,6 +553,8 @@ async def update_persona_learning_review_status(
stmt = select(PersonaLearningReview).where(
PersonaLearningReview.id == review_id
)
if group_id:
stmt = stmt.where(PersonaLearningReview.group_id == group_id)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
if not record:
Expand Down Expand Up @@ -772,7 +777,7 @@ async def get_reviewed_style_learning_updates(
return []

async def update_style_review_status(
self, review_id, new_status, reviewer_comment=''
self, review_id, new_status, reviewer_comment='', group_id: str = None,
) -> bool:
"""更新风格学习审核记录状态

Expand All @@ -790,6 +795,8 @@ async def update_style_review_status(
stmt = select(StyleLearningReview).where(
StyleLearningReview.id == review_id
)
if group_id:
stmt = stmt.where(StyleLearningReview.group_id == group_id)
result = await session.execute(stmt)
record = result.scalar_one_or_none()
if not record:
Expand Down
10 changes: 6 additions & 4 deletions services/database/sqlalchemy_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,11 +783,12 @@ async def get_reviewed_persona_update_records(

async def update_persona_update_record_status(
self, record_id: int, status: str, comment: str = None,
group_id: str = None,
) -> bool:
return await self._call_learning(
"update_persona_update_record_status",
False,
record_id, status, comment,
record_id, status, comment, group_id,
)

async def create_style_learning_review(
Expand All @@ -811,11 +812,12 @@ async def get_reviewed_style_learning_updates(

async def update_style_review_status(
self, review_id: int, status: str, reviewer_comment: str = '',
group_id: str = None,
) -> bool:
return await self._call_learning(
"update_style_review_status",
False,
review_id, status, reviewer_comment,
review_id, status, reviewer_comment, group_id,
)

async def update_style_review_metadata(
Expand Down Expand Up @@ -870,12 +872,12 @@ async def get_persona_learning_review_by_id(

async def update_persona_learning_review_status(
self, review_id: int, status: str, comment: str = None,
modified_content: str = None,
modified_content: str = None, group_id: str = None,
) -> bool:
return await self._call_learning(
"update_persona_learning_review_status",
False,
review_id, status, comment, modified_content,
review_id, status, comment, modified_content, group_id,
)

async def update_persona_learning_review_metadata(
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_persona_backup_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit tests for WebUI persona backup management service."""
from unittest.mock import AsyncMock
from unittest.mock import Mock

import pytest

Expand Down Expand Up @@ -102,6 +103,50 @@ async def test_restore_backup_falls_back_to_persona_manager(mock_container):
mock_container.persona_manager.update_persona.assert_awaited_once()


@pytest.mark.asyncio
async def test_restore_default_backup_targets_current_real_persona(mock_container):
"""Backups created from the legacy default placeholder restore into the real current persona."""
mock_container.persona_backup_manager = None
mock_container.persona_web_manager = Mock()
mock_container.persona_web_manager.get_persona_for_group = AsyncMock(return_value={
'persona_id': 'default',
'system_prompt': '',
'begin_dialogs': [],
})
mock_container.persona_web_manager.get_all_personas_for_web = AsyncMock(return_value=[
{
'persona_id': 'suleng',
'name': 'SuLeng',
'system_prompt': 'current prompt',
'begin_dialogs': [],
}
])
mock_container.persona_web_manager.update_persona_via_web = AsyncMock(return_value={
'success': True,
})
mock_container.database_manager.get_persona_backup = AsyncMock(return_value={
'id': 10,
'group_id': 'default',
'backup_name': 'legacy-default',
'original_persona': {
'persona_id': 'default',
'name': 'default',
'prompt': 'restored prompt',
},
'imitation_dialogues': [],
})
service = PersonaBackupService(mock_container)

success, message = await service.restore_backup(10, group_id='default')

assert success is True
assert '恢复成功' in message
persona_id, payload = mock_container.persona_web_manager.update_persona_via_web.await_args.args
assert persona_id == 'suleng'
assert payload['persona_id'] == 'suleng'
assert payload['system_prompt'] == 'restored prompt'


@pytest.mark.asyncio
async def test_delete_backup_uses_database_manager(mock_container):
mock_container.database_manager.delete_persona_backup = AsyncMock(return_value=True)
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_persona_review_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,29 @@ async def test_review_persona_update_approve_style_learning(self, mock_container
assert '批准' in message or 'approved' in message.lower()
mock_container.database_manager.update_style_review_status.assert_called()

@pytest.mark.asyncio
async def test_style_learning_approval_locks_review_by_group_without_polluting_comment(
self, mock_container, sample_style_review_data
):
"""Approving one style review should lock by row/group and keep comments clean."""
service = PersonaReviewService(mock_container)
sample_style_review_data["group_id"] = "group-a"
mock_container.database_manager.get_pending_style_reviews.return_value = [
sample_style_review_data,
{**sample_style_review_data, "id": 2, "group_id": "group-b"},
]
mock_container.database_manager.update_style_review_status.return_value = True

success, message = await service.review_persona_update("style_1", "approve")

assert success is True
assert "group-a" not in message
mock_container.database_manager.update_style_review_status.assert_awaited_once_with(
1,
"approved",
group_id="group-a",
)

@pytest.mark.asyncio
async def test_style_learning_preview_targets_begin_dialogs(self, mock_container, sample_style_review_data):
"""Style review previews should show begin_dialogs changes, not system prompt edits."""
Expand Down
84 changes: 79 additions & 5 deletions webui/services/persona_backup_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@

from .persona_service import _optional_container_attr

try:
from ...utils.persona_selection import (
resolve_target_persona,
resolve_target_persona_from_web,
)
except ImportError:
from utils.persona_selection import (
resolve_target_persona,
resolve_target_persona_from_web,
)


class PersonaBackupService:
"""人格备份管理服务。"""
Expand All @@ -17,6 +28,13 @@ def __init__(self, container):
self.persona_backup_manager = _optional_container_attr(container, 'persona_backup_manager')
self.persona_manager = _optional_container_attr(container, 'persona_manager')
self.persona_web_mgr = _optional_container_attr(container, 'persona_web_manager')
self.astrbot_persona_manager = (
_optional_container_attr(container, 'astrbot_persona_manager')
or self.persona_manager
)
self.plugin_config = _optional_container_attr(container, 'plugin_config')
group_mapping = _optional_container_attr(container, 'group_id_to_unified_origin', {})
self.group_id_to_unified_origin = group_mapping if isinstance(group_mapping, dict) else {}

@staticmethod
def _normalize_limit(limit: Any, default: int = 20, maximum: int = 100) -> int:
Expand Down Expand Up @@ -75,6 +93,14 @@ def _ensure_database(self):
if not self.database_manager:
raise ValueError('数据库服务未初始化,无法管理人格备份')

def _resolve_umo(self, group_id: str) -> str:
return self.group_id_to_unified_origin.get(group_id, group_id)

@staticmethod
def _is_placeholder_persona_id(persona_id: Any) -> bool:
value = str(persona_id or '').strip().lower()
return not value or value in {'default', '[%none]'}

@staticmethod
def _normalize_group_id(group_id: Any) -> Optional[str]:
if group_id is None:
Expand Down Expand Up @@ -142,19 +168,31 @@ async def restore_backup(self, backup_id: int, group_id: Optional[str] = None) -
normalized_group_id = self._normalize_group_id(group_id)
backup = await self.get_backup(backup_id, group_id=normalized_group_id)
effective_group_id = normalized_group_id or backup.get('group_id') or 'default'
persona = self._persona_from_backup(backup)
if not persona:
return False, '备份中没有可恢复的人格内容'

if self.persona_backup_manager and hasattr(self.persona_backup_manager, 'restore_backup'):
target_persona_id = await self._resolve_restore_persona_id(persona, effective_group_id)
if target_persona_id:
persona['persona_id'] = target_persona_id
if self._is_placeholder_persona_id(persona.get('name')):
persona['name'] = target_persona_id

if (
self.persona_backup_manager
and hasattr(self.persona_backup_manager, 'restore_backup')
and not self._is_placeholder_persona_id(
(backup.get('original_persona') or {}).get('persona_id')
or (backup.get('original_persona') or {}).get('name')
)
):
try:
success = await self.persona_backup_manager.restore_backup(effective_group_id, backup_id)
if success:
return True, '人格备份恢复成功'
except Exception as e:
logger.warning(f"正式备份管理器恢复失败,尝试 WebUI fallback: {e}")

persona = self._persona_from_backup(backup)
if not persona:
return False, '备份中没有可恢复的人格内容'

persona_id = persona['persona_id']
if self.persona_web_mgr and hasattr(self.persona_web_mgr, 'update_persona_via_web'):
result = await self.persona_web_mgr.update_persona_via_web(persona_id, persona)
Expand Down Expand Up @@ -186,6 +224,42 @@ async def restore_backup(self, backup_id: int, group_id: Optional[str] = None) -

return False, 'PersonaManager 未初始化,无法恢复备份'

async def _resolve_restore_persona_id(
self,
persona: Dict[str, Any],
group_id: str,
) -> str:
"""Resolve placeholder backup IDs like default to the current real AstrBot persona."""
persona_id = persona.get('persona_id') or persona.get('name')
if not self._is_placeholder_persona_id(persona_id):
return str(persona_id)

try:
current = None
if self.persona_web_mgr:
current = await resolve_target_persona_from_web(
self.persona_web_mgr,
self.plugin_config,
group_id,
log=logger,
)
elif self.astrbot_persona_manager:
current = await resolve_target_persona(
self.astrbot_persona_manager,
self.plugin_config,
self._resolve_umo(group_id),
require_existing=True,
log=logger,
)
if isinstance(current, dict):
resolved = current.get('persona_id') or current.get('name')
if resolved and not self._is_placeholder_persona_id(resolved):
return str(resolved)
except Exception as e:
logger.warning(f"解析备份恢复目标人格失败,将使用备份内 ID: {e}", exc_info=True)

return str(persona.get('persona_id') or persona.get('name') or 'default')

async def delete_backup(self, backup_id: int, group_id: Optional[str] = None) -> Tuple[bool, str]:
"""删除人格备份。"""
self._ensure_database()
Expand Down
2 changes: 1 addition & 1 deletion webui/services/persona_review_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ async def _approve_style_learning_review(self, review_id: int) -> Tuple[bool, st

# 更新状态
success = await self.database_manager.update_style_review_status(
review_id, 'approved', target_review['group_id']
review_id, 'approved', group_id=group_id
)

if success and target_review['few_shots_content']:
Expand Down
Loading