Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 142 additions & 6 deletions utils/image_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class ImageCacheManager:
base_storage_path: Optional[str] = None
# 内存缓存(用于快速查询)
memory_cache: Dict[str, tuple[str, float]] = {}
# 失败记录缓存(hash -> failure_timestamp)
failure_cache: Dict[str, float] = {}
# 记录写入次数,用于周期性保存
write_count: int = 0

Expand All @@ -41,6 +43,7 @@ def init(config: AstrBotConfig):
ImageCacheManager.config = config
ImageCacheManager.write_count = 0 # 重置写入计数
ImageCacheManager.memory_cache.clear() # 清空内存缓存,确保从磁盘重新加载
ImageCacheManager.failure_cache.clear() # 清空失败缓存,确保从磁盘重新加载
# 初始化基础存储路径
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
astrbot_data_path = get_astrbot_data_path()
Expand Down Expand Up @@ -96,9 +99,18 @@ def _load_cache_from_disk() -> None:
with open(cache_file, "r", encoding="utf-8") as f:
cache_data = json.load(f)

# 加载缓存到内存,统一转换为元组格式,使用严格验证
if isinstance(cache_data, dict):
for key, value in cache_data.items():
# 兼容两种格式:
# 1) 旧格式: {hash: [caption, timestamp]}
# 2) 新格式: {"captions": {...}, "failures": {...}}
caption_data = cache_data
failure_data = {}
if isinstance(cache_data, dict) and ("captions" in cache_data or "failures" in cache_data):
caption_data = cache_data.get("captions", {})
failure_data = cache_data.get("failures", {})

# 加载成功缓存到内存
if isinstance(caption_data, dict):
for key, value in caption_data.items():
try:
# 要求恰好2个元素
if isinstance(value, (list, tuple)) and len(value) == 2:
Expand All @@ -115,8 +127,18 @@ def _load_cache_from_disk() -> None:

logger.info(f"成功从磁盘加载 {len(ImageCacheManager.memory_cache)} 条图片缓存")
else:
logger.warning(f"缓存文件格式不正确,跳过加载")

logger.warning(f"图片缓存数据格式不正确,期望 dict,实际为 {type(caption_data).__name__},跳过加载")

# 加载失败缓存到内存
if isinstance(failure_data, dict):
for key, value in failure_data.items():
if isinstance(value, (int, float)):
ImageCacheManager.failure_cache[key] = float(value)
else:
logger.warning(f"失败缓存条目格式不正确,跳过: {key}")
elif failure_data:
logger.warning("失败缓存数据格式不正确,跳过加载")

except Exception as e:
logger.error(f"从磁盘加载缓存失败: {e}")
logger.debug(traceback.format_exc())
Expand Down Expand Up @@ -146,8 +168,23 @@ def _save_cache_to_disk() -> None:
skipped_count += 1
logger.debug(f"跳过格式不正确的缓存条目: {key}")

serializable_failures = {}
for key, value in ImageCacheManager.failure_cache.items():
if isinstance(value, (int, float)):
serializable_failures[key] = float(value)
else:
logger.debug(f"跳过格式不正确的失败缓存条目: {key}")

with open(cache_file, "w", encoding="utf-8") as f:
json.dump(serializable_cache, f, ensure_ascii=False, indent=2)
json.dump(
{
"captions": serializable_cache,
"failures": serializable_failures
},
f,
ensure_ascii=False,
indent=2
)

if skipped_count > 0:
logger.debug(f"成功保存 {len(serializable_cache)} 条有效缓存到磁盘,跳过 {skipped_count} 条格式不正确的条目")
Expand Down Expand Up @@ -238,6 +275,7 @@ def clear() -> bool:
"""
try:
ImageCacheManager.memory_cache.clear()
ImageCacheManager.failure_cache.clear()
ImageCacheManager.write_count = 0

cache_file = ImageCacheManager._get_cache_file_path()
Expand Down Expand Up @@ -294,6 +332,18 @@ def cleanup_old_entries() -> None:

for key in keys_to_remove:
del ImageCacheManager.memory_cache[key]

failure_keys_to_remove = []
for key, timestamp in ImageCacheManager.failure_cache.items():
if not isinstance(timestamp, (int, float)):
failure_keys_to_remove.append(key)
removed_count += 1
elif current_time - timestamp > cleanup_threshold:
failure_keys_to_remove.append(key)
removed_count += 1

for key in failure_keys_to_remove:
del ImageCacheManager.failure_cache[key]

if removed_count > 0:
logger.info(f"清理过期缓存完成,清理了 {removed_count} 条超过 {retention_days} 天的缓存条目")
Expand All @@ -310,3 +360,89 @@ def force_save() -> None:
ImageCacheManager._save_cache_to_disk()
except Exception as e:
logger.error(f"强制保存缓存失败: {e}")

@staticmethod
def get_failed_timestamp(image: str) -> Optional[float]:
    """
    Return the recorded failure timestamp for an image, if any.

    The image is converted to its cache key via ``_generate_image_hash`` and
    looked up in ``failure_cache``.

    Args:
        image: The image, as passed to the other cache methods.

    Returns:
        The stored timestamp as a float, or None when there is no numeric
        entry for the image or when the lookup raised (the error is logged).
    """
    try:
        key = ImageCacheManager._generate_image_hash(image)
        recorded = ImageCacheManager.failure_cache.get(key)
        # Anything that is not a number is treated as "no record".
        return float(recorded) if isinstance(recorded, (int, float)) else None
    except Exception as e:
        logger.error(f"获取失败记录失败: {e}")
        return None

@staticmethod
def is_failed(image: str) -> bool:
    """
    Tell whether a failure record exists for the image.

    Returns:
        True when ``get_failed_timestamp`` yields a timestamp, False otherwise.
    """
    last_failure = ImageCacheManager.get_failed_timestamp(image)
    return last_failure is not None

@staticmethod
def set_failed(image: str) -> bool:
    """
    Record that captioning the image failed at the current time.

    Stores ``time.time()`` in ``failure_cache`` under the image's hash key and
    bumps ``write_count``; once the counter reaches ``WRITE_THRESHOLD`` the
    cache is saved to disk and the counter is reset to zero.

    Returns:
        True on success, False if any step raised (the error is logged).
    """
    manager = ImageCacheManager
    try:
        key = manager._generate_image_hash(image)
        manager.failure_cache[key] = time.time()

        manager.write_count += 1
        flush_due = manager.write_count >= manager.WRITE_THRESHOLD
        if flush_due:
            manager._save_cache_to_disk()
            manager.write_count = 0
    except Exception as e:
        logger.error(f"记录失败缓存失败: {e}")
        return False
    else:
        return True

@staticmethod
def clear_failed(image: str) -> bool:
    """
    Remove the failure record for the image, if one exists.

    When a record is actually removed, ``write_count`` is incremented and,
    once it reaches ``WRITE_THRESHOLD``, the cache is saved to disk and the
    counter reset. When there is no record, nothing is modified.

    Returns:
        True on success (including the no-record case), False if any step
        raised (the error is logged).
    """
    manager = ImageCacheManager
    try:
        key = manager._generate_image_hash(image)
        if key not in manager.failure_cache:
            # Nothing to clear: leave the write counter untouched.
            return True

        del manager.failure_cache[key]
        manager.write_count += 1
        if manager.write_count >= manager.WRITE_THRESHOLD:
            manager._save_cache_to_disk()
            manager.write_count = 0
    except Exception as e:
        logger.error(f"清理失败缓存失败: {e}")
        return False
    else:
        return True

@staticmethod
def should_skip_failed_image(image: str, latest_success_timestamp: Optional[float], window_seconds: int) -> bool:
    """
    Decide whether captioning of a previously failed image should be skipped.

    The image is skipped only when all of the following hold:
    - a failure record exists for it;
    - the failure is strictly earlier than the latest success (the image
      failed before that success happened);
    - the gap between the two is at most ``window_seconds`` (so an image is
      not skipped indefinitely).

    Args:
        image: Base64-encoded image data or an image URL.
        latest_success_timestamp: Timestamp of the most recent successful
            caption, or None when there has been none.
        window_seconds: Maximum distance in seconds between the failure and
            the latest success for them to count as "close"; values <= 0
            disable skipping.

    Returns:
        True if the image should be skipped, False otherwise.
    """
    # Without a reference success or a positive window there is nothing to compare.
    if latest_success_timestamp is None:
        return False
    if window_seconds <= 0:
        return False

    failed_at = ImageCacheManager.get_failed_timestamp(image)
    if failed_at is None:
        return False

    # The failure must precede the success...
    if not (failed_at < latest_success_timestamp):
        return False

    # ...and lie within the allowed window before it.
    gap = latest_success_timestamp - failed_at
    return gap <= window_seconds
35 changes: 34 additions & 1 deletion utils/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class ImageCaptionUtils:
# 保存context和config对象的静态变量
context: Optional[Context] = None
config: Optional[AstrBotConfig] = None
DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS = 300

@staticmethod
def init(context: Context, config: AstrBotConfig):
Expand All @@ -22,11 +23,30 @@ def init(context: Context, config: AstrBotConfig):
# 初始化图片缓存管理器
ImageCacheManager.init(config)

@staticmethod
def get_failed_image_skip_window_seconds() -> int:
    """
    Return the time window (seconds) used by the failed-image skip policy.

    The value is read from ``config["image_processing"]
    ["failed_image_skip_window_seconds"]``. The class default
    ``DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS`` is returned when:
    - no config has been set;
    - the ``image_processing`` section is missing a usable ``get`` (e.g. it
      is ``None``);
    - the configured value is not a non-negative integer. Booleans are
      rejected explicitly: ``bool`` is a subclass of ``int``, so ``True``
      would otherwise be accepted as a 1-second window.

    Returns:
        A non-negative integer number of seconds.
    """
    default_window = ImageCaptionUtils.DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS

    config = ImageCaptionUtils.config
    if not config:
        return default_window

    image_processing_config = config.get("image_processing", {})
    # The section may be present but null/malformed; do not crash on ``.get``.
    section_get = getattr(image_processing_config, "get", None)
    if not callable(section_get):
        return default_window

    skip_window_seconds = section_get(
        "failed_image_skip_window_seconds",
        default_window
    )
    if (
        isinstance(skip_window_seconds, bool)
        or not isinstance(skip_window_seconds, int)
        or skip_window_seconds < 0
    ):
        return default_window
    return skip_window_seconds

@staticmethod
async def generate_image_caption(
image: str, # 图片的base64编码或URL
umo: Optional[str] = None, # unified_msg_origin,用于 UMO 路由
timeout: int = 30
timeout: int = 30,
latest_success_timestamp: Optional[float] = None
) -> Optional[str]:
"""
为单张图片生成文字描述
Expand All @@ -35,13 +55,15 @@ async def generate_image_caption(
image: 图片的base64编码或URL
umo: unified_msg_origin,用于获取对应 UMO 的 provider
timeout: 超时时间(秒)
latest_success_timestamp: 最近一次成功转述时间戳(用于失败图片跳过策略)

Returns:
生成的图片描述文本,如果失败则返回None
"""
# 检查持久化缓存
cached_caption = ImageCacheManager.get(image)
if cached_caption is not None:
ImageCacheManager.clear_failed(image)
logger.debug(f"命中图片描述缓存: {image[:50]}...")
return cached_caption

Expand All @@ -58,6 +80,12 @@ async def generate_image_caption(
if not image_processing_config.get("use_image_caption", False):
return None

skip_window_seconds = ImageCaptionUtils.get_failed_image_skip_window_seconds()

if ImageCacheManager.should_skip_failed_image(image, latest_success_timestamp, skip_window_seconds):
logger.debug(f"跳过失败图片转述(该图片失败记录早于本轮最近一次成功,且时间间隔在窗口内): {image[:50]}...")
return None

provider_id = image_processing_config.get("image_caption_provider_id", "")
# 获取提供商,支持 UMO 路由
if provider_id == "":
Expand Down Expand Up @@ -88,12 +116,17 @@ async def call_llm():
# 缓存结果到持久化缓存
if caption:
ImageCacheManager.set(image, caption)
ImageCacheManager.clear_failed(image)
logger.debug(f"缓存到持久化存储: {image[:50]}...")
else:
ImageCacheManager.set_failed(image)

return caption
except asyncio.TimeoutError:
logger.warning(f"图片转述超时,超过了{timeout}秒")
ImageCacheManager.set_failed(image)
return None
except Exception as e:
logger.error(f"图片转述失败: {e}")
ImageCacheManager.set_failed(image)
return None
8 changes: 7 additions & 1 deletion utils/message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ async def outline_message_list(message_list: List[BaseMessageComponent], umo: Op
umo: unified_msg_origin,用于 UMO 路由
"""
outline = ""
latest_success_timestamp: Optional[float] = None
for i in message_list:
try:
# 获取组件类型
Expand Down Expand Up @@ -110,9 +111,14 @@ async def outline_message_list(message_list: List[BaseMessageComponent], umo: Op
continue
image = image_path

caption = await ImageCaptionUtils.generate_image_caption(image, umo=umo)
caption = await ImageCaptionUtils.generate_image_caption(
image,
umo=umo,
latest_success_timestamp=latest_success_timestamp
)
if caption:
outline += f"[图片: {caption}]"
latest_success_timestamp = time.time()
else:
outline += f"[图片]"
else:
Expand Down