diff --git a/utils/image_cache.py b/utils/image_cache.py index 8f90c82..21a9558 100644 --- a/utils/image_cache.py +++ b/utils/image_cache.py @@ -27,6 +27,8 @@ class ImageCacheManager: base_storage_path: Optional[str] = None # 内存缓存(用于快速查询) memory_cache: Dict[str, tuple[str, float]] = {} + # 失败记录缓存(hash -> failure_timestamp) + failure_cache: Dict[str, float] = {} # 记录写入次数,用于周期性保存 write_count: int = 0 @@ -41,6 +43,7 @@ def init(config: AstrBotConfig): ImageCacheManager.config = config ImageCacheManager.write_count = 0 # 重置写入计数 ImageCacheManager.memory_cache.clear() # 清空内存缓存,确保从磁盘重新加载 + ImageCacheManager.failure_cache.clear() # 清空失败缓存,确保从磁盘重新加载 # 初始化基础存储路径 from astrbot.core.utils.astrbot_path import get_astrbot_data_path astrbot_data_path = get_astrbot_data_path() @@ -96,9 +99,18 @@ def _load_cache_from_disk() -> None: with open(cache_file, "r", encoding="utf-8") as f: cache_data = json.load(f) - # 加载缓存到内存,统一转换为元组格式,使用严格验证 - if isinstance(cache_data, dict): - for key, value in cache_data.items(): + # 兼容两种格式: + # 1) 旧格式: {hash: [caption, timestamp]} + # 2) 新格式: {"captions": {...}, "failures": {...}} + caption_data = cache_data + failure_data = {} + if isinstance(cache_data, dict) and ("captions" in cache_data or "failures" in cache_data): + caption_data = cache_data.get("captions", {}) + failure_data = cache_data.get("failures", {}) + + # 加载成功缓存到内存 + if isinstance(caption_data, dict): + for key, value in caption_data.items(): try: # 要求恰好2个元素 if isinstance(value, (list, tuple)) and len(value) == 2: @@ -115,8 +127,18 @@ def _load_cache_from_disk() -> None: logger.info(f"成功从磁盘加载 {len(ImageCacheManager.memory_cache)} 条图片缓存") else: - logger.warning(f"缓存文件格式不正确,跳过加载") - + logger.warning(f"图片缓存数据格式不正确,期望 dict,实际为 {type(caption_data).__name__},跳过加载") + + # 加载失败缓存到内存 + if isinstance(failure_data, dict): + for key, value in failure_data.items(): + if isinstance(value, (int, float)): + ImageCacheManager.failure_cache[key] = float(value) + else: + logger.warning(f"失败缓存条目格式不正确,跳过: {key}") + elif failure_data: + logger.warning("失败缓存数据格式不正确,跳过加载") + except Exception as e: logger.error(f"从磁盘加载缓存失败: {e}") logger.debug(traceback.format_exc()) @@ -146,8 +168,23 @@ def _save_cache_to_disk() -> None: skipped_count += 1 logger.debug(f"跳过格式不正确的缓存条目: {key}") + serializable_failures = {} + for key, value in ImageCacheManager.failure_cache.items(): + if isinstance(value, (int, float)): + serializable_failures[key] = float(value) + else: + logger.debug(f"跳过格式不正确的失败缓存条目: {key}") + with open(cache_file, "w", encoding="utf-8") as f: - json.dump(serializable_cache, f, ensure_ascii=False, indent=2) + json.dump( + { + "captions": serializable_cache, + "failures": serializable_failures + }, + f, + ensure_ascii=False, + indent=2 + ) if skipped_count > 0: logger.debug(f"成功保存 {len(serializable_cache)} 条有效缓存到磁盘,跳过 {skipped_count} 条格式不正确的条目") @@ -238,6 +275,7 @@ def clear() -> bool: """ try: ImageCacheManager.memory_cache.clear() + ImageCacheManager.failure_cache.clear() ImageCacheManager.write_count = 0 cache_file = ImageCacheManager._get_cache_file_path() @@ -294,6 +332,18 @@ def cleanup_old_entries() -> None: for key in keys_to_remove: del ImageCacheManager.memory_cache[key] + + failure_keys_to_remove = [] + for key, timestamp in ImageCacheManager.failure_cache.items(): + if not isinstance(timestamp, (int, float)): + failure_keys_to_remove.append(key) + removed_count += 1 + elif current_time - timestamp > cleanup_threshold: + failure_keys_to_remove.append(key) + removed_count += 1 + + for key in failure_keys_to_remove: + del ImageCacheManager.failure_cache[key] if removed_count > 0: logger.info(f"清理过期缓存完成,清理了 {removed_count} 条超过 {retention_days} 天的缓存条目") @@ -310,3 +360,89 @@ def force_save() -> None: ImageCacheManager._save_cache_to_disk() except Exception as e: logger.error(f"强制保存缓存失败: {e}") + + @staticmethod + def get_failed_timestamp(image: str) -> Optional[float]: + """ + 获取图片最近一次转述失败时间戳 + """ + try: + image_hash = ImageCacheManager._generate_image_hash(image) + timestamp = ImageCacheManager.failure_cache.get(image_hash) + if isinstance(timestamp, (int, float)): + return float(timestamp) + return None + except Exception as e: + logger.error(f"获取失败记录失败: {e}") + return None + + @staticmethod + def is_failed(image: str) -> bool: + """ + 判断图片是否有失败记录 + """ + return ImageCacheManager.get_failed_timestamp(image) is not None + + @staticmethod + def set_failed(image: str) -> bool: + """ + 记录图片转述失败 + """ + try: + image_hash = ImageCacheManager._generate_image_hash(image) + ImageCacheManager.failure_cache[image_hash] = time.time() + + ImageCacheManager.write_count += 1 + if ImageCacheManager.write_count >= ImageCacheManager.WRITE_THRESHOLD: + ImageCacheManager._save_cache_to_disk() + ImageCacheManager.write_count = 0 + + return True + except Exception as e: + logger.error(f"记录失败缓存失败: {e}") + return False + + @staticmethod + def clear_failed(image: str) -> bool: + """ + 清理图片失败记录 + """ + try: + image_hash = ImageCacheManager._generate_image_hash(image) + if image_hash in ImageCacheManager.failure_cache: + del ImageCacheManager.failure_cache[image_hash] + + ImageCacheManager.write_count += 1 + if ImageCacheManager.write_count >= ImageCacheManager.WRITE_THRESHOLD: + ImageCacheManager._save_cache_to_disk() + ImageCacheManager.write_count = 0 + + return True + except Exception as e: + logger.error(f"清理失败缓存失败: {e}") + return False + + @staticmethod + def should_skip_failed_image(image: str, latest_success_timestamp: Optional[float], window_seconds: int) -> bool: + """ + 判断失败图片是否应跳过转述: + - 存在失败记录 + - 失败时间早于最近成功时间(表示这张图是在该次成功之前失败的) + - 且二者间隔在窗口时间内(避免无限期跳过) + + Args: + image: 图片的base64编码或URL + latest_success_timestamp: 最近一次成功转述的时间戳 + window_seconds: 失败记录与最近成功记录可判定为“相近”的时间窗口(秒) + + Returns: + 是否应跳过该图片转述 + """ + if latest_success_timestamp is None or window_seconds <= 0: + return False + + failed_timestamp = ImageCacheManager.get_failed_timestamp(image) + if failed_timestamp is None: + return False + + return failed_timestamp < latest_success_timestamp and (latest_success_timestamp - failed_timestamp) <= window_seconds diff --git a/utils/image_caption.py b/utils/image_caption.py index 66fae94..c70f4ca 100644 --- a/utils/image_caption.py +++ b/utils/image_caption.py @@ -13,6 +13,7 @@ class ImageCaptionUtils: # 保存context和config对象的静态变量 context: Optional[Context] = None config: Optional[AstrBotConfig] = None + DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS = 300 @staticmethod def init(context: Context, config: AstrBotConfig): @@ -22,11 +23,30 @@ def init(context: Context, config: AstrBotConfig): # 初始化图片缓存管理器 ImageCacheManager.init(config) + @staticmethod + def get_failed_image_skip_window_seconds() -> int: + """ + 获取失败图片跳过策略的时间窗口(秒) + """ + config = ImageCaptionUtils.config + if not config: + return ImageCaptionUtils.DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS + + image_processing_config = config.get("image_processing", {}) + skip_window_seconds = image_processing_config.get( + "failed_image_skip_window_seconds", + ImageCaptionUtils.DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS + ) + if not isinstance(skip_window_seconds, int) or skip_window_seconds < 0: + return ImageCaptionUtils.DEFAULT_FAILED_IMAGE_SKIP_WINDOW_SECONDS + return skip_window_seconds + @staticmethod async def generate_image_caption( image: str, # 图片的base64编码或URL umo: Optional[str] = None, # unified_msg_origin,用于 UMO 路由 - timeout: int = 30 + timeout: int = 30, + latest_success_timestamp: Optional[float] = None ) -> Optional[str]: """ 为单张图片生成文字描述 @@ -35,6 +55,7 @@ async def generate_image_caption( image: 图片的base64编码或URL umo: unified_msg_origin,用于获取对应 UMO 的 provider timeout: 超时时间(秒) + latest_success_timestamp: 最近一次成功转述时间戳(用于失败图片跳过策略) Returns: 生成的图片描述文本,如果失败则返回None @@ -42,6 +63,7 @@ async def generate_image_caption( # 检查持久化缓存 cached_caption = ImageCacheManager.get(image) if cached_caption is not None: + ImageCacheManager.clear_failed(image) logger.debug(f"命中图片描述缓存: {image[:50]}...") return cached_caption @@ -58,6 +80,12 @@ async def generate_image_caption( if not image_processing_config.get("use_image_caption", False): return None + skip_window_seconds = ImageCaptionUtils.get_failed_image_skip_window_seconds() + + if ImageCacheManager.should_skip_failed_image(image, latest_success_timestamp, skip_window_seconds): + logger.debug(f"跳过失败图片转述(该图片失败记录早于本轮最近一次成功,且时间间隔在窗口内): {image[:50]}...") + return None + provider_id = image_processing_config.get("image_caption_provider_id", "") # 获取提供商,支持 UMO 路由 if provider_id == "": @@ -88,12 +116,17 @@ async def call_llm(): # 缓存结果到持久化缓存 if caption: ImageCacheManager.set(image, caption) + ImageCacheManager.clear_failed(image) logger.debug(f"缓存到持久化存储: {image[:50]}...") + else: + ImageCacheManager.set_failed(image) return caption except asyncio.TimeoutError: logger.warning(f"图片转述超时,超过了{timeout}秒") + ImageCacheManager.set_failed(image) return None except Exception as e: logger.error(f"图片转述失败: {e}") + ImageCacheManager.set_failed(image) return None diff --git a/utils/message_utils.py b/utils/message_utils.py index d66dc74..b7342b5 100644 --- a/utils/message_utils.py +++ b/utils/message_utils.py @@ -82,6 +82,7 @@ async def outline_message_list(message_list: List[BaseMessageComponent], umo: Op umo: unified_msg_origin,用于 UMO 路由 """ outline = "" + latest_success_timestamp: Optional[float] = None for i in message_list: try: # 获取组件类型 @@ -110,9 +111,14 @@ async def outline_message_list(message_list: List[BaseMessageComponent], umo: Op continue image = image_path - caption = await ImageCaptionUtils.generate_image_caption(image, umo=umo) + caption = await ImageCaptionUtils.generate_image_caption( + image, + umo=umo, + latest_success_timestamp=latest_success_timestamp + ) if caption: outline += f"[图片: {caption}]" + latest_success_timestamp = time.time() else: outline += f"[图片]" else: