Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions handlers/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,13 @@ async def _final_generate_and_send():
try:
# 再次获取最新状态,以防万一
state = self.active_sessions.get(session_id, {})
final_texts, final_images = state.get("texts", []), state.get("images", [])
final_texts, final_images, image_names = state.get("texts", []), state.get("images", []), state.get("image_names", [])
final_texts = final_texts[:p.max_texts]
final_images = final_images[:p.max_images]

tasks = [self.api_client.upload_image(b) for b in final_images]
image_ids = await asyncio.gather(*tasks)
image_payload = [{"id": img_id, "name": f"img{i}"} for i, img_id in enumerate(image_ids)]
image_payload = [{"id": img_id, "name": image_names[i] if i < len(image_names) else f"img{i}"} for i, img_id in enumerate(image_ids)]
final_payload = {"texts": final_texts, "images": image_payload, "options": state.get("options", {})}

# 更新状态为“正在制作中”,实现状态锁
Expand Down Expand Up @@ -255,13 +255,17 @@ async def _final_generate_and_send():
needs_text = len(session_state["texts"]) < p.min_texts
needs_image = len(session_state["images"]) < p.min_images
provided_text = next_event.get_message_str().strip()
provided_images = await self._get_images_from_message(next_event)
provided_data = await self._get_images_from_message(next_event)
provided_images = [item[0] for item in provided_data]
provided_names = [item[1] for item in provided_data]
is_valid_and_needed_input = (needs_text and provided_text) or (needs_image and provided_images)

if is_valid_and_needed_input:
session_state["invalid_input_count"] = 0
if needs_text and provided_text: session_state["texts"].extend(provided_text.split())
if needs_image and provided_images: session_state["images"].extend(provided_images)
if needs_image and provided_images:
session_state["images"].extend(provided_images)
session_state.setdefault("image_names", []).extend(provided_names)
if len(session_state["texts"]) >= p.min_texts and len(session_state["images"]) >= p.min_images:
await self._send_and_record(next_event, "参数已集齐,开始制作...")
break
Expand Down Expand Up @@ -292,7 +296,7 @@ async def _final_generate_and_send():
self.active_sessions.pop(session_id, None)
logger.debug(f"后台工人任务结束,会话 {session_id} 已清理。")

async def handle_shortcut(self, event: AstrMessageEvent, meme: MemeInfo, shortcut: Dict, match: re.Match):
async def handle_shortcut(self, event: AstrMessageEvent, meme: MemeInfo, shortcut: Dict, match: re.Match, trailing_text: str = ""):
try:
logger.debug(f"快捷指令匹配成功: {meme.key}"); match_dict = match.groupdict()
texts = [t.format(**match_dict) for t in shortcut.get("texts", [])]
Expand All @@ -301,7 +305,7 @@ async def handle_shortcut(self, event: AstrMessageEvent, meme: MemeInfo, shortcu
event.set_extra("shortcut_names", names)

# 【核心修改】直接调用(await)新的“启动器”,而不是迭代
await self.meme_generate_handler(event, meme, "", initial_options=options, initial_texts=texts)
await self.meme_generate_handler(event, meme, trailing_text, initial_options=options, initial_texts=texts)

except Exception as e:
logger.error(f"处理快捷指令失败: {e}", exc_info=True)
Expand All @@ -325,7 +329,7 @@ async def meme_generate_handler(self, event: AstrMessageEvent, meme_info: MemeIn
# 初始化会话状态
shortcut_texts = initial_texts
shortcut_options = initial_options
parsed_texts, initial_images, parsed_options = await self.build_meme_payload(event, meme_info, text)
parsed_texts, initial_images, image_names, parsed_options = await self.build_meme_payload(event, meme_info, text)
final_texts = shortcut_texts + parsed_texts
final_options = shortcut_options
final_options.update(parsed_options)
Expand All @@ -334,7 +338,7 @@ async def meme_generate_handler(self, event: AstrMessageEvent, meme_info: MemeIn
final_texts = p.default_texts

session_state = {
"texts": final_texts, "images": initial_images, "options": final_options,
"texts": final_texts, "images": initial_images, "image_names": image_names, "options": final_options,
"params": p, "invalid_input_count": 0, "status": "waiting_for_input"
}
self.active_sessions[session_id] = session_state
Expand All @@ -350,8 +354,14 @@ async def meme_generate_handler(self, event: AstrMessageEvent, meme_info: MemeIn

# --- 以下是其他辅助函数,保持不变 ---

async def _get_images_from_message(self, event: AstrMessageEvent) -> List[bytes]:
image_bytes_list: List[bytes] = []
async def _get_images_from_message(self, event: AstrMessageEvent) -> List[tuple]:
"""从消息中提取图片数据和对应的用户名。返回 List[(bytes, str)] 的元组列表。"""
image_list: List[tuple] = []
# 获取发送者名称,用于 Comp.Image 类型的默认名称
try:
sender_name = event.get_sender_name() or str(event.get_sender_id())
except Exception:
sender_name = str(event.get_sender_id())
async def _process(seg):
if isinstance(seg, Comp.Image):
img_bytes: Optional[bytes] = None
Expand All @@ -360,31 +370,41 @@ async def _process(seg):
if isinstance(content, str) and content.startswith("base64://"): img_bytes = base64.b64decode(content[len("base64://"):])
elif isinstance(content, bytes): img_bytes = content
if not img_bytes and hasattr(seg, "url") and seg.url: img_bytes = await self.api_client._download_image(seg.url)
if img_bytes: image_bytes_list.append(img_bytes)
if img_bytes: image_list.append((img_bytes, sender_name))
elif isinstance(seg, Comp.At) and seg.qq:
if b := await self._get_avatar(str(seg.qq)): image_bytes_list.append(b)
at_name = getattr(seg, 'name', None) or str(seg.qq)
if b := await self._get_avatar(str(seg.qq)): image_list.append((b, at_name))
msgs = event.get_messages()
if reply := next((s for s in msgs if isinstance(s, Comp.Reply)), None):
if getattr(reply, 'chain', None):
for s in reply.chain: await _process(s)
for s in msgs: await _process(s)
return image_bytes_list
return image_list

async def build_meme_payload(self, event: AstrMessageEvent, meme_info: MemeInfo, text: str) -> (List[str], List[bytes], Dict):
async def build_meme_payload(self, event: AstrMessageEvent, meme_info: MemeInfo, text: str) -> (List[str], List[bytes], List[str], Dict):
image_bytes_list: List[bytes] = []
image_names_list: List[str] = []
shortcut_names = event.get_extra("shortcut_names") or []

initial_images = await self._get_images_from_message(event)
image_bytes_list.extend(initial_images)
for img_bytes, img_name in initial_images:
image_bytes_list.append(img_bytes)
image_names_list.append(img_name)

for name in shortcut_names:
if name.isdigit():
if b := await self._get_avatar(name):
image_bytes_list.append(b)
image_names_list.append(name)

if self.use_sender_when_no_image and len(image_bytes_list) < meme_info.params.min_images:
if b := await self._get_avatar(event.get_sender_id()):
image_bytes_list.insert(0, b)
try:
sender_name = event.get_sender_name() or str(event.get_sender_id())
except Exception:
sender_name = str(event.get_sender_id())
image_names_list.insert(0, sender_name)

text_to_parse = text.strip()

Expand Down Expand Up @@ -424,7 +444,7 @@ async def build_meme_payload(self, event: AstrMessageEvent, meme_info: MemeInfo,
except (ArgumentError, ValueError, ArgParseError) as e:
raise ArgParseError(f"参数解析或类型转换错误: {e}")

return texts, image_bytes_list, options_payload
return texts, image_bytes_list, image_names_list, options_payload

async def _get_avatar(self, user_id: str) -> Optional[bytes]:
if not user_id.isdigit():
Expand All @@ -439,7 +459,7 @@ async def _send_results(self, event: AstrMessageEvent, result_obj: Union[bytes,
async def handle_random_meme(self, event: AstrMessageEvent, arg_text: str):
try:
temp_meme_info = MemeInfo(key="", params=MemeParams(min_images=0, max_images=99, min_texts=0, max_texts=99), date_created=datetime.now(), keywords=[])
initial_texts, initial_images, _ = await self.build_meme_payload(event, temp_meme_info, arg_text)
initial_texts, initial_images, _, _ = await self.build_meme_payload(event, temp_meme_info, arg_text)
n_images_initial, n_texts_initial = len(initial_images), len(initial_texts)
final_arg_text = arg_text
n_images_filter, n_texts_filter = n_images_initial, n_texts_initial
Expand Down
29 changes: 27 additions & 2 deletions handlers/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from astrbot.api import logger

class HelpHandlers:
"""一个 Mixin 类,只包含表情列表指令的处理器"""
"""一个 Mixin 类,包含表情帮助和表情列表相关的指令处理器"""

async def handle_meme_list(self, event: AstrMessageEvent, _=None):
try:
Expand Down Expand Up @@ -59,4 +59,29 @@ async def handle_meme_list(self, event: AstrMessageEvent, _=None):
logger.error(f"生成动态表情列表图失败: {e}", exc_info=True)
yield event.plain_result("生成列表图失败了,呜呜...")
finally:
event.stop_event()
event.stop_event()

async def handle_meme_help(self, event: AstrMessageEvent, _=None):
p = self.prefix
help_text = (
Comment thread
monbed marked this conversation as resolved.
f"【基础指令】\n"
f"{p}表情列表: 查看所有支持表情\n"
f"{p}表情详情 <词>: 查询具体用法\n"
f"{p}表情搜索 <词>: 靠关键词找表情\n"
f"{p}<表情名> [图/文]: 制作表情\n"
f"{p}随机表情: 随机生成一张\n"
f"{p}表情调用统计: 查看使用榜单\n"
f"\n【图片处理】(发送图片附带指令)\n"
f"支持: {p}水平翻转, {p}竖直翻转, {p}旋转, {p}缩放, {p}裁剪, {p}灰度, {p}反色, {p}水平/竖直拼接\n"
f"GIF处理: {p}gif分解, {p}gif合成, {p}gif倒放, {p}gif变速\n"
f"\n【群组管理】\n"
f"{p}管理列表: 查看本群已禁用列表\n"
f"{p}禁用表情 <词>: 本群禁用该表情\n"
f"{p}启用表情 <词>: 重新启用该表情\n"
f"\n【全局管理】(超管可用)\n"
f"{p}刷新表情: 重新加载配置数据\n"
f"{p}全局禁用表情 <词>: 全局禁用\n"
f"{p}全局启用表情 <词>: 全局启用"
)
yield event.plain_result(help_text)
event.stop_event()
3 changes: 2 additions & 1 deletion handlers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ async def handle_image_tool(self, event: AstrMessageEvent, operation: str, arg_t

async def _get_images_for_tool(self, event: AstrMessageEvent, min_images: int = 1) -> List[str]:
"""从消息中提取所需数量的图片,上传并返回 image_id 列表"""
image_bytes_list = await self._get_images_from_message(event)
image_data = await self._get_images_from_message(event)
image_bytes_list = [item[0] for item in image_data]

if len(image_bytes_list) < min_images:
# 如果不够,自动补充发送者头像
Expand Down
14 changes: 11 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from astrbot.api.event import filter, AstrMessageEvent
from astrbot.api.star import Context, Star, register, StarTools
from astrbot.api import logger, AstrBotConfig
from astrbot.api.message_components import Plain
from astrbot.core.star.filter.event_message_type import EventMessageType
from .core.permission import PermissionManager

# --- 从我们自己的模块中导入所有“零件” ---
from .api_client import APIClient
from .manager import MemeManager
from .recorder import StatsRecorder
Expand Down Expand Up @@ -100,6 +100,7 @@ def __init__(self, context: Context, config: AstrBotConfig):

# 3. 构建指令到处理器的映射
self.cmd_map = {
"表情帮助": self.handle_meme_help,
"表情列表": self.handle_meme_list,
"表情详情": self.handle_meme_info,
Comment thread
monbed marked this conversation as resolved.
"表情详细": self.handle_meme_info,
Expand Down Expand Up @@ -159,7 +160,10 @@ async def universal_handler(self, event: AstrMessageEvent):
except Exception: return

try:
message_text = event.get_message_str().strip()
message_text = " ".join(
c.text for c in event.get_messages()
if isinstance(c, Plain) and c.text
).strip()
if not message_text.startswith(self.prefix): return

cleaned_text = message_text[len(self.prefix):].strip()
Expand Down Expand Up @@ -193,7 +197,11 @@ async def universal_handler(self, event: AstrMessageEvent):
for sc_data in self.meme_manager.shortcuts:
if await self.recorder.is_meme_disabled(sc_data["meme"].key, event.get_group_id()): continue
if match := sc_data["pattern"].fullmatch(cleaned_text):
asyncio.create_task(self.handle_shortcut(event, sc_data["meme"], sc_data["shortcut"], match))
asyncio.create_task(self.handle_shortcut(event, sc_data["meme"], sc_data["shortcut"], match, ""))
return
if match := sc_data["pattern"].match(cleaned_text):
trailing_text = cleaned_text[match.end():].strip()
asyncio.create_task(self.handle_shortcut(event, sc_data["meme"], sc_data["shortcut"], match, trailing_text))
return

if keyword := self.meme_manager.find_keyword_in_text(cleaned_text, self.fuzzy_match):
Expand Down
3 changes: 3 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ class MemeOption(BaseModel):
default: Optional[Any] = None
description: Optional[str] = None
parser_flags: Dict[str, Any] = Field(default_factory=dict)
choices: Optional[List[str]] = None
minimum: Optional[float] = None
maximum: Optional[float] = None

class MemeParams(BaseModel):
min_images: int
Expand Down