diff --git a/.claude/rules/architecture.md b/.claude/rules/architecture.md index 320050bd..d8694ebd 100644 --- a/.claude/rules/architecture.md +++ b/.claude/rules/architecture.md @@ -30,7 +30,7 @@ │ SessionMonitor │ │ TmuxManager (tmux_manager.py) │ │ (session_monitor.py) │ │ - list/find/create/kill windows│ │ - Poll JSONL every 2s │ │ - send_keys to pane │ -│ - Detect mtime changes │ │ - capture_pane for screenshot │ +│ - Detect size changes │ │ - capture_pane for screenshot │ │ - Parse new lines │ └──────────────┬─────────────────┘ │ - Track pending tools │ │ │ across poll cycles │ │ diff --git a/.claude/rules/message-handling.md b/.claude/rules/message-handling.md index ab108a86..33138307 100644 --- a/.claude/rules/message-handling.md +++ b/.claude/rules/message-handling.md @@ -32,7 +32,7 @@ Per-user message queues + worker pattern for all send tasks: ## Performance Optimizations -**mtime cache**: The monitoring loop maintains an in-memory file mtime cache, skipping reads for unchanged files. +**File size fast path**: The monitoring loop compares file size against the last byte offset, skipping reads for unchanged files. **Byte offset incremental reads**: Each tracked session records `last_byte_offset`, reading only new content. File truncation (offset > file_size) is detected and offset is auto-reset. 
diff --git a/src/ccbot/bot.py b/src/ccbot/bot.py index 0b746c78..ff782aa6 100644 --- a/src/ccbot/bot.py +++ b/src/ccbot/bot.py @@ -41,7 +41,7 @@ BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, - InputMediaDocument, + InputMediaPhoto, Update, ) from telegram.constants import ChatAction @@ -105,9 +105,9 @@ ) from .handlers.message_queue import ( clear_status_msg_info, + enqueue_callable, enqueue_content_message, enqueue_status_update, - get_message_queue, shutdown_workers, ) from .handlers.message_sender import ( @@ -124,7 +124,7 @@ from .session import session_manager from .session_monitor import NewMessage, SessionMonitor from .terminal_parser import extract_bash_output, is_interactive_ui -from .tmux_manager import tmux_manager +from .tmux_manager import SHELL_COMMANDS, tmux_manager from .utils import ccbot_dir logger = logging.getLogger(__name__) @@ -229,9 +229,8 @@ async def screenshot_command( png_bytes = await text_to_image(text, with_ansi=True) keyboard = _build_screenshot_keyboard(wid) - await update.message.reply_document( - document=io.BytesIO(png_bytes), - filename="screenshot.png", + await update.message.reply_photo( + photo=io.BytesIO(png_bytes), reply_markup=keyboard, ) @@ -285,6 +284,9 @@ async def esc_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non display = session_manager.get_display_name(wid) await safe_reply(update.message, f"❌ Window '{display}' no longer exists.") return + if w.pane_current_command in SHELL_COMMANDS: + await safe_reply(update.message, "❌ Claude Code has exited.") + return # Send Escape control character (no enter) await tmux_manager.send_keys(w.window_id, "\x1b", enter=False) @@ -309,6 +311,9 @@ async def usage_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N if not w: await safe_reply(update.message, f"Window '{wid}' no longer exists.") return + if w.pane_current_command in SHELL_COMMANDS: + await safe_reply(update.message, "❌ Claude Code has exited.") + return # Send /usage command to 
Claude Code TUI await tmux_manager.send_keys(w.window_id, "/usage") @@ -338,6 +343,64 @@ async def usage_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N await safe_reply(update.message, f"```\n{trimmed}\n```") +async def kill_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Kill the associated tmux window, clean up state, and delete the topic.""" + user = update.effective_user + if not user or not is_user_allowed(user.id): + return + if not update.message: + return + + thread_id = _get_thread_id(update) + if thread_id is None: + await safe_reply(update.message, "❌ This command only works in a topic.") + return + + wid = session_manager.get_window_for_thread(user.id, thread_id) + if not wid: + await safe_reply(update.message, "❌ No session bound to this topic.") + return + + display = session_manager.get_display_name(wid) + + # Kill the tmux window + w = await tmux_manager.find_window_by_id(wid) + if w: + await tmux_manager.kill_window(w.window_id) + logger.info( + "/kill: killed window %s (user=%d, thread=%d)", + display, + user.id, + thread_id, + ) + else: + logger.info( + "/kill: window %s already gone (user=%d, thread=%d)", + display, + user.id, + thread_id, + ) + + # Unbind and clean up all topic state + session_manager.unbind_thread(user.id, thread_id) + await clear_topic_state(user.id, thread_id, context.bot, context.user_data) + + await safe_reply(update.message, f"✅ Killed session '{display}'.") + + # Best-effort: close and delete the forum topic + chat_id = update.effective_chat.id if update.effective_chat else None + if chat_id and thread_id: + try: + await context.bot.close_forum_topic(chat_id, thread_id) + await context.bot.delete_forum_topic(chat_id, thread_id) + except Exception: + logger.debug( + "/kill: could not close/delete topic (user=%d, thread=%d)", + user.id, + thread_id, + ) + + # --- Screenshot keyboard with quick control keys --- # key_id → (tmux_key, enter, literal) @@ -776,7 +839,7 @@ async def 
text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No if wid is None: # Unbound topic — check for unbound windows first all_windows = await tmux_manager.list_windows() - bound_ids = {wid for _, _, wid in session_manager.iter_thread_bindings()} + bound_ids = {wid for _, _, wid in session_manager.all_thread_bindings()} unbound = [ (w.window_id, w.window_name, w.cwd) for w in all_windows @@ -812,7 +875,7 @@ async def text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No user.id, thread_id, ) - start_path = str(Path.cwd()) + start_path = config.browse_root or str(Path.cwd()) msg_text, keyboard, subdirs = build_directory_browser(start_path) if context.user_data is not None: context.user_data[STATE_KEY] = STATE_BROWSING_DIRECTORY @@ -864,7 +927,16 @@ async def text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No success, message = await session_manager.send_to_window(wid, text) if not success: - await safe_reply(update.message, f"❌ {message}") + if "not running" in message: + # Claude Code exited and auto-resume failed — unbind + session_manager.unbind_thread(user.id, thread_id) + await safe_reply( + update.message, + f"❌ {message}. Binding removed.\n" + "Send a message to start a new session.", + ) + else: + await safe_reply(update.message, f"❌ {message}") return # Start background capture for ! 
bash command output @@ -970,7 +1042,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - return subdir_name = cached_dirs[idx] - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1000,7 +1072,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - if pending_tid is not None and _get_thread_id(update) != pending_tid: await query.answer("Stale browser (topic mismatch)", show_alert=True) return - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1033,7 +1105,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - except ValueError: await query.answer("Invalid data") return - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1049,7 +1121,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - await query.answer() elif data == CB_DIR_CONFIRM: - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) selected_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1084,7 +1156,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - user.id, pending_thread_id, ) - # Wait for Claude Code's SessionStart hook to register in session_map + # Wait for Claude Code's SessionStart hook to register in session_map. + # Return value intentionally ignored: on timeout, the monitor's poll + # cycle will pick up the session_map entry once the hook fires. 
await session_manager.wait_for_session_map_entry(created_wid) if pending_thread_id is not None: @@ -1123,14 +1197,16 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - created_wname, len(pending_text), ) - if context.user_data is not None: - context.user_data.pop("_pending_thread_text", None) - context.user_data.pop("_pending_thread_id", None) send_ok, send_msg = await session_manager.send_to_window( created_wid, pending_text, ) - if not send_ok: + if send_ok: + # Clear pending text only after successful send + if context.user_data is not None: + context.user_data.pop("_pending_thread_text", None) + context.user_data.pop("_pending_thread_id", None) + else: logger.warning("Failed to forward pending text: %s", send_msg) await safe_send( context.bot, @@ -1224,14 +1300,16 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - pending_text = ( context.user_data.get("_pending_thread_text") if context.user_data else None ) - if context.user_data is not None: - context.user_data.pop("_pending_thread_text", None) - context.user_data.pop("_pending_thread_id", None) if pending_text: send_ok, send_msg = await session_manager.send_to_window( selected_wid, pending_text ) - if not send_ok: + if send_ok: + # Clear pending text only after successful send + if context.user_data is not None: + context.user_data.pop("_pending_thread_text", None) + context.user_data.pop("_pending_thread_id", None) + else: logger.warning("Failed to forward pending text: %s", send_msg) await safe_send( context.bot, @@ -1239,6 +1317,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - f"❌ Failed to send pending message: {send_msg}", message_thread_id=thread_id, ) + elif context.user_data is not None: + # No pending text — clean up thread_id tracking + context.user_data.pop("_pending_thread_id", None) await query.answer("Bound") # Window picker: new session → transition to directory browser @@ -1251,7 +1332,7 @@ 
async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - return # Preserve pending thread info, clear only picker state clear_window_picker_state(context.user_data) - start_path = str(Path.cwd()) + start_path = config.browse_root or str(Path.cwd()) msg_text, keyboard, subdirs = build_directory_browser(start_path) if context.user_data is not None: context.user_data[STATE_KEY] = STATE_BROWSING_DIRECTORY @@ -1293,8 +1374,8 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - keyboard = _build_screenshot_keyboard(window_id) try: await query.edit_message_media( - media=InputMediaDocument( - media=io.BytesIO(png_bytes), filename="screenshot.png" + media=InputMediaPhoto( + media=io.BytesIO(png_bytes), ), reply_markup=keyboard, ) @@ -1432,6 +1513,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - if not w: await query.answer("Window not found", show_alert=True) return + if w.pane_current_command in SHELL_COMMANDS: + await query.answer("Claude Code has exited", show_alert=True) + return await tmux_manager.send_keys( w.window_id, tmux_key, enter=enter, literal=literal @@ -1446,14 +1530,13 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - keyboard = _build_screenshot_keyboard(window_id) try: await query.edit_message_media( - media=InputMediaDocument( + media=InputMediaPhoto( media=io.BytesIO(png_bytes), - filename="screenshot.png", ), reply_markup=keyboard, ) - except Exception: - pass # Screenshot unchanged or message too old + except Exception as e: + logger.debug(f"Screenshot edit after key press failed: {e}") # --- Streaming response / notifications --- @@ -1481,30 +1564,42 @@ async def handle_new_message(msg: NewMessage, bot: Bot) -> None: for user_id, wid, thread_id in active_users: # Handle interactive tools specially - capture terminal and send UI if msg.tool_name in INTERACTIVE_TOOL_NAMES and msg.content_type == "tool_use": - # Mark interactive 
mode BEFORE sleeping so polling skips this window - set_interactive_mode(user_id, wid, thread_id) - # Flush pending messages (e.g. plan content) before sending interactive UI - queue = get_message_queue(user_id) - if queue: - await queue.join() - # Wait briefly for Claude Code to render the question UI - await asyncio.sleep(0.3) - handled = await handle_interactive_ui(bot, user_id, wid, thread_id) - if handled: - # Update user's read offset - session = await session_manager.resolve_session_for_window(wid) - if session and session.file_path: - try: - file_size = Path(session.file_path).stat().st_size - session_manager.update_user_window_offset( - user_id, wid, file_size - ) - except OSError: - pass - continue # Don't send the normal tool_use message - else: - # UI not rendered — clear the early-set mode - clear_interactive_mode(user_id, thread_id) + # Mark interactive mode BEFORE enqueuing so polling skips this window. + # Capture the generation so the callable can detect staleness. + gen = set_interactive_mode(user_id, wid, thread_id) + + # Enqueue the interactive UI handling as a callable task so it + # executes AFTER all pending content messages already in the queue, + # without blocking the monitor loop or any other session's processing. 
+ async def _send_interactive_ui( + _bot: Bot = bot, + _user_id: int = user_id, + _wid: str = wid, + _thread_id: int | None = thread_id, + _gen: int = gen, + ) -> None: + # Wait briefly for Claude Code to render the question UI + await asyncio.sleep(0.3) + handled = await handle_interactive_ui( + _bot, _user_id, _wid, _thread_id, expected_generation=_gen + ) + if handled: + # Update user's read offset + session = await session_manager.resolve_session_for_window(_wid) + if session and session.file_path: + try: + file_size = Path(session.file_path).stat().st_size + session_manager.update_user_window_offset( + _user_id, _wid, file_size + ) + except OSError: + pass + else: + # UI not rendered — clear the early-set mode + clear_interactive_mode(_user_id, _thread_id) + + enqueue_callable(bot, user_id, _send_interactive_ui) + continue # Don't send the normal tool_use message # Any non-interactive message means the interaction is complete — delete the UI message if get_interactive_msg_id(user_id, thread_id): @@ -1631,6 +1726,7 @@ def create_bot() -> Application: application.add_handler(CommandHandler("history", history_command)) application.add_handler(CommandHandler("screenshot", screenshot_command)) application.add_handler(CommandHandler("esc", esc_command)) + application.add_handler(CommandHandler("kill", kill_command)) application.add_handler(CommandHandler("unbind", unbind_command)) application.add_handler(CommandHandler("usage", usage_command)) application.add_handler(CallbackQueryHandler(callback_handler)) diff --git a/src/ccbot/config.py b/src/ccbot/config.py index 1dfd28ed..6735abd3 100644 --- a/src/ccbot/config.py +++ b/src/ccbot/config.py @@ -93,6 +93,9 @@ def __init__(self) -> None: os.getenv("CCBOT_SHOW_HIDDEN_DIRS", "").lower() == "true" ) + # Starting directory for the directory browser + self.browse_root = os.getenv("CCBOT_BROWSE_ROOT", "") + # Scrub sensitive vars from os.environ so child processes never inherit them. 
# Values are already captured in Config attributes above. for var in SENSITIVE_ENV_VARS: @@ -100,12 +103,13 @@ def __init__(self) -> None: logger.debug( "Config initialized: dir=%s, token=%s..., allowed_users=%d, " - "tmux_session=%s, claude_projects_path=%s", + "tmux_session=%s, claude_projects_path=%s, browse_root=%s", self.config_dir, self.telegram_bot_token[:8], len(self.allowed_users), self.tmux_session_name, self.claude_projects_path, + self.browse_root, ) def is_user_allowed(self, user_id: int) -> bool: diff --git a/src/ccbot/handlers/directory_browser.py b/src/ccbot/handlers/directory_browser.py index a9e724d7..bb0ddcad 100644 --- a/src/ccbot/handlers/directory_browser.py +++ b/src/ccbot/handlers/directory_browser.py @@ -112,7 +112,7 @@ def build_directory_browser( """ path = Path(current_path).expanduser().resolve() if not path.exists() or not path.is_dir(): - path = Path.cwd() + path = Path(config.browse_root) if config.browse_root else Path.cwd() try: subdirs = sorted( diff --git a/src/ccbot/handlers/interactive_ui.py b/src/ccbot/handlers/interactive_ui.py index 174e3a9e..c90c93dc 100644 --- a/src/ccbot/handlers/interactive_ui.py +++ b/src/ccbot/handlers/interactive_ui.py @@ -15,8 +15,10 @@ """ import logging +import time from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.error import BadRequest, RetryAfter from ..session import session_manager from ..terminal_parser import extract_interactive_content, is_interactive_ui @@ -45,6 +47,21 @@ # Track interactive mode: (user_id, thread_id_or_0) -> window_id _interactive_mode: dict[tuple[int, int], str] = {} +# Deduplication: monotonic timestamp of last new interactive message send +_last_interactive_send: dict[tuple[int, int], float] = {} +_INTERACTIVE_DEDUP_WINDOW = 2.0 # seconds — suppress duplicate sends within this window + +# Generation counter: incremented on every state transition (set/clear) so that +# stale callables enqueued by the JSONL monitor can detect 
invalidation. +_interactive_generation: dict[tuple[int, int], int] = {} + + +def _next_generation(ikey: tuple[int, int]) -> int: + """Increment and return the generation counter for this user/thread.""" + gen = _interactive_generation.get(ikey, 0) + 1 + _interactive_generation[ikey] = gen + return gen + def get_interactive_window(user_id: int, thread_id: int | None = None) -> str | None: """Get the window_id for user's interactive mode.""" @@ -55,21 +72,25 @@ def set_interactive_mode( user_id: int, window_id: str, thread_id: int | None = None, -) -> None: - """Set interactive mode for a user.""" +) -> int: + """Set interactive mode for a user. Returns the generation counter.""" + ikey = (user_id, thread_id or 0) logger.debug( "Set interactive mode: user=%d, window_id=%s, thread=%s", user_id, window_id, thread_id, ) - _interactive_mode[(user_id, thread_id or 0)] = window_id + _interactive_mode[ikey] = window_id + return _next_generation(ikey) def clear_interactive_mode(user_id: int, thread_id: int | None = None) -> None: """Clear interactive mode for a user (without deleting message).""" + ikey = (user_id, thread_id or 0) logger.debug("Clear interactive mode: user=%d, thread=%s", user_id, thread_id) - _interactive_mode.pop((user_id, thread_id or 0), None) + _interactive_mode.pop(ikey, None) + _next_generation(ikey) def get_interactive_msg_id(user_id: int, thread_id: int | None = None) -> int | None: @@ -145,14 +166,36 @@ async def handle_interactive_ui( user_id: int, window_id: str, thread_id: int | None = None, + expected_generation: int | None = None, ) -> bool: """Capture terminal and send interactive UI content to user. Handles AskUserQuestion, ExitPlanMode, Permission Prompt, and RestoreCheckpoint UIs. Returns True if UI was detected and sent, False otherwise. + + If *expected_generation* is provided (from the JSONL monitor path), + the function checks that the current generation still matches before + proceeding. 
This prevents stale callables from acting after the + interactive mode has been cleared or superseded. """ ikey = (user_id, thread_id or 0) + + # Generation guard: if caller provided an expected generation and it + # doesn't match the current one, this callable is stale — bail out. + if expected_generation is not None: + current_gen = _interactive_generation.get(ikey, 0) + if current_gen != expected_generation: + logger.debug( + "Stale interactive UI callable: user=%d, thread=%s, " + "expected_gen=%d, current_gen=%d — skipping", + user_id, + thread_id, + expected_generation, + current_gen, + ) + return False + chat_id = session_manager.resolve_chat_id(user_id, thread_id) w = await tmux_manager.find_window_by_id(window_id) if not w: @@ -202,14 +245,54 @@ async def handle_interactive_ui( ) _interactive_mode[ikey] = window_id return True - except Exception: - # Edit failed (message deleted, etc.) - clear stale msg_id and send new + except RetryAfter: + raise + except BadRequest as e: + if "is not modified" in str(e).lower(): + # Content identical to what's already displayed — treat as success. + _interactive_mode[ikey] = window_id + return True + # Any other BadRequest (e.g. message deleted, too old to edit): + # clear stale state and try to remove the orphan message. + logger.debug( + "Edit failed for interactive msg %s (%s), sending new", + existing_msg_id, + e, + ) + _interactive_msgs.pop(ikey, None) + try: + await bot.delete_message(chat_id=chat_id, message_id=existing_msg_id) + except Exception: + pass # Already deleted or too old — ignore. + # Fall through to send new message + except Exception as e: + # NetworkError, TimedOut, Forbidden, etc. — message state is uncertain; + # discard the stale ID and fall through to send a fresh message. 
logger.debug( - "Edit failed for interactive msg %s, sending new", existing_msg_id + "Edit failed (%s) for interactive msg %s, sending new", + e, + existing_msg_id, ) _interactive_msgs.pop(ikey, None) # Fall through to send new message + # Dedup guard: prevent both JSONL monitor and status poller from sending + # a new interactive message in the same short window. No await between + # check and set, so this is atomic in the asyncio event loop. + last_send = _last_interactive_send.get(ikey, 0.0) + now = time.monotonic() + if now - last_send < _INTERACTIVE_DEDUP_WINDOW: + logger.debug( + "Dedup: skipping duplicate interactive UI send " + "(user=%d, thread=%s, %.1fs since last)", + user_id, + thread_id, + now - last_send, + ) + _interactive_mode[ikey] = window_id + return True + _last_interactive_send[ikey] = now + # Send new message (plain text — terminal content is not markdown) logger.info( "Sending interactive UI to user %d for window_id %s", user_id, window_id @@ -222,7 +305,11 @@ async def handle_interactive_ui( link_preview_options=NO_LINK_PREVIEW, **thread_kwargs, # type: ignore[arg-type] ) + except RetryAfter: + _last_interactive_send.pop(ikey, None) + raise except Exception as e: + _last_interactive_send.pop(ikey, None) logger.error("Failed to send interactive UI: %s", e) return False if sent: @@ -241,6 +328,8 @@ async def clear_interactive_msg( ikey = (user_id, thread_id or 0) msg_id = _interactive_msgs.pop(ikey, None) _interactive_mode.pop(ikey, None) + _last_interactive_send.pop(ikey, None) + _next_generation(ikey) logger.debug( "Clear interactive msg: user=%d, thread=%s, msg_id=%s", user_id, diff --git a/src/ccbot/handlers/message_queue.py b/src/ccbot/handlers/message_queue.py index bdd28038..53ff3b63 100644 --- a/src/ccbot/handlers/message_queue.py +++ b/src/ccbot/handlers/message_queue.py @@ -20,8 +20,9 @@ import asyncio import logging import time +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field -from typing 
import Literal +from typing import Any, Literal from telegram import Bot from telegram.constants import ChatAction @@ -55,7 +56,7 @@ def _ensure_formatted(text: str) -> str: class MessageTask: """Message task for queue processing.""" - task_type: Literal["content", "status_update", "status_clear"] + task_type: Literal["content", "status_update", "status_clear", "callable"] text: str | None = None window_id: str | None = None # content type fields @@ -64,6 +65,13 @@ class MessageTask: content_type: str = "text" thread_id: int | None = None # Telegram topic thread_id for targeted send image_data: list[tuple[str, bytes]] | None = None # From tool_result images + # callable task: a zero-argument coroutine factory executed in-order by the + # worker. Must be a factory (not a bare coroutine object) so the worker can + # safely retry by calling it again — a coroutine can only be awaited once. + callable_fn: Callable[[], Coroutine[Any, Any, None]] | None = None + # Number of times this task has been re-queued after a long RetryAfter. + # Prevents infinite re-queue loops under persistent rate limiting. + requeue_count: int = 0 # Per-user message queues and worker tasks @@ -84,6 +92,12 @@ class MessageTask: # Max seconds to wait for flood control before dropping tasks FLOOD_CONTROL_MAX_WAIT = 10 +# Maximum number of RetryAfter retries per task before giving up +MAX_TASK_RETRIES = 3 + +# Maximum number of times a task can be re-queued after long RetryAfter +MAX_REQUEUE_COUNT = 5 + def get_message_queue(user_id: int) -> asyncio.Queue[MessageTask] | None: """Get the message queue for a user (if exists).""" @@ -144,10 +158,11 @@ async def _merge_content_tasks( additional tasks merged (0 if no merging occurred). Note on queue counter management: - When we put items back, we call task_done() to compensate for the - internal counter increment caused by put_nowait(). This is necessary - because the items were already counted when originally enqueued. 
- Without this compensation, queue.join() would wait indefinitely. + put_nowait() unconditionally increments _unfinished_tasks. + When we put items back, they already hold a counter slot from when + they were first enqueued, so the compensating task_done() removes the + duplicate increment added by put_nowait(). Without this, _unfinished_tasks + would leak by len(remaining) per merge cycle. """ merged_parts = list(first.parts) current_length = sum(len(p) for p in merged_parts) @@ -212,10 +227,10 @@ async def _message_queue_worker(bot: Bot, user_id: int) -> None: if flood_end > 0: remaining = flood_end - time.monotonic() if remaining > 0: - if task.task_type != "content": + if task.task_type in ("status_update", "status_clear"): # Status is ephemeral — safe to drop continue - # Content is actual Claude output — wait then send + # Content and callable tasks must not be dropped — wait logger.debug( "Flood controlled: waiting %.0fs for content (user %d)", remaining, @@ -226,42 +241,87 @@ async def _message_queue_worker(bot: Bot, user_id: int) -> None: _flood_until.pop(user_id, None) logger.info("Flood control lifted for user %d", user_id) + # Retry loop: retry the task on RetryAfter up to MAX_TASK_RETRIES times. + # Merging is done once before the loop so that merged_task is reused on + # every retry attempt rather than re-merging from a now-empty queue. 
if task.task_type == "content": - # Try to merge consecutive content tasks merged_task, merge_count = await _merge_content_tasks( queue, task, lock ) if merge_count > 0: logger.debug(f"Merged {merge_count} tasks for user {user_id}") - # Mark merged tasks as done for _ in range(merge_count): queue.task_done() - await _process_content_task(bot, user_id, merged_task) - elif task.task_type == "status_update": - await _process_status_update_task(bot, user_id, task) - elif task.task_type == "status_clear": - await _do_clear_status_message(bot, user_id, task.thread_id or 0) - except RetryAfter as e: - retry_secs = ( - e.retry_after - if isinstance(e.retry_after, int) - else int(e.retry_after.total_seconds()) - ) - if retry_secs > FLOOD_CONTROL_MAX_WAIT: - _flood_until[user_id] = time.monotonic() + retry_secs - logger.warning( - "Flood control for user %d: retry_after=%ds, " - "pausing queue until ban expires", - user_id, - retry_secs, - ) else: - logger.warning( - "Flood control for user %d: waiting %ds", - user_id, - retry_secs, - ) - await asyncio.sleep(retry_secs) + merged_task = task + merge_count = 0 + + for attempt in range(MAX_TASK_RETRIES + 1): + try: + if merged_task.task_type == "content": + await _process_content_task(bot, user_id, merged_task) + elif merged_task.task_type == "status_update": + await _process_status_update_task(bot, user_id, merged_task) + elif merged_task.task_type == "status_clear": + await _do_clear_status_message( + bot, user_id, merged_task.thread_id or 0 + ) + elif merged_task.task_type == "callable": + if merged_task.callable_fn is not None: + await merged_task.callable_fn() + break # Success — exit retry loop + except RetryAfter as e: + retry_secs = ( + e.retry_after + if isinstance(e.retry_after, int) + else int(e.retry_after.total_seconds()) + ) + if retry_secs > FLOOD_CONTROL_MAX_WAIT: + _flood_until[user_id] = time.monotonic() + retry_secs + if merged_task.requeue_count >= MAX_REQUEUE_COUNT: + logger.error( + "Dropping task for user 
%d after %d re-queues " + "(persistent flood control, task_type=%s)", + user_id, + merged_task.requeue_count, + merged_task.task_type, + ) + break + merged_task.requeue_count += 1 + logger.warning( + "Flood control for user %d: retry_after=%ds, " + "re-queuing task (requeue %d/%d)", + user_id, + retry_secs, + merged_task.requeue_count, + MAX_REQUEUE_COUNT, + ) + # Re-queue so the task is retried once the ban + # expires. put_nowait increments _unfinished_tasks + # for the new slot; the outer finally calls + # task_done() for the slot consumed by dequeuing, + # so the net counter change is zero. + queue.put_nowait(merged_task) + break # Let the flood-control path handle re-queued task + if attempt < MAX_TASK_RETRIES: + logger.warning( + "RetryAfter for user %d: waiting %ds (attempt %d/%d)", + user_id, + retry_secs, + attempt + 1, + MAX_TASK_RETRIES, + ) + await asyncio.sleep(retry_secs) + # Loop back and retry the same task + else: + logger.error( + "Dropping task for user %d after %d retries " + "(last retry_after=%ds, task_type=%s)", + user_id, + MAX_TASK_RETRIES, + retry_secs, + merged_task.task_type, + ) except Exception as e: logger.error(f"Error processing message task for user {user_id}: {e}") finally: @@ -381,8 +441,14 @@ async def _process_content_task(bot: Bot, user_id: int, task: MessageTask) -> No # 4. Send images if present (from tool_result with base64 image blocks) await _send_task_images(bot, chat_id, task) - # 5. After content, check and send status - await _check_and_send_status(bot, user_id, wid, task.thread_id) + # 5. After content, check and send status. + # Catch RetryAfter here: the status message is cosmetic and must never + # propagate RetryAfter to the outer retry loop (which would re-send all + # content messages as duplicates). 
+ try: + await _check_and_send_status(bot, user_id, wid, task.thread_id) + except RetryAfter: + pass async def _convert_status_to_content( @@ -475,7 +541,9 @@ async def _process_status_update_task( if "esc to interrupt" in status_text.lower(): try: await bot.send_chat_action( - chat_id=chat_id, action=ChatAction.TYPING + chat_id=chat_id, + action=ChatAction.TYPING, + message_thread_id=task.thread_id, ) except RetryAfter: raise @@ -534,7 +602,11 @@ async def _do_send_status_message( # Send typing indicator when Claude is working if "esc to interrupt" in text.lower(): try: - await bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) + await bot.send_chat_action( + chat_id=chat_id, + action=ChatAction.TYPING, + message_thread_id=thread_id, + ) except RetryAfter: raise except Exception: @@ -661,6 +733,28 @@ async def enqueue_status_update( queue.put_nowait(task) +def enqueue_callable( + bot: Bot, + user_id: int, + coro_factory: Callable[[], Coroutine[Any, Any, None]], +) -> None: + """Enqueue a coroutine factory for in-order execution by the queue worker. + + *coro_factory* is a zero-argument callable that returns a new coroutine each + time it is called. The worker calls the factory on each attempt so that + retries after ``RetryAfter`` work correctly (a bare coroutine object can + only be awaited once). 
+ + Typically this is just an async function reference (not its invocation):: + + enqueue_callable(bot, uid, my_async_fn) # correct — factory + enqueue_callable(bot, uid, my_async_fn()) # WRONG — bare coroutine + """ + queue = get_or_create_queue(bot, user_id) + task = MessageTask(task_type="callable", callable_fn=coro_factory) + queue.put_nowait(task) + + def clear_status_msg_info(user_id: int, thread_id: int | None = None) -> None: """Clear status message tracking for a user (and optionally a specific thread).""" skey = (user_id, thread_id or 0) diff --git a/src/ccbot/handlers/status_polling.py b/src/ccbot/handlers/status_polling.py index c4de1c6e..dde0e9ea 100644 --- a/src/ccbot/handlers/status_polling.py +++ b/src/ccbot/handlers/status_polling.py @@ -5,9 +5,9 @@ - Detects interactive UIs (permission prompts) not triggered via JSONL - Updates status messages in Telegram - Polls thread_bindings (each topic = one window) - - Periodically probes topic existence via unpin_all_forum_topic_messages - (silent no-op when no pins); cleans up deleted topics (kills tmux window - + unbinds thread) + - Periodically probes topic existence via send_chat_action (TYPING); + raises BadRequest on deleted topics; cleans up deleted topics (kills + tmux window + unbinds thread) Key components: - STATUS_POLL_INTERVAL: Polling frequency (1 second) @@ -21,6 +21,7 @@ import time from telegram import Bot +from telegram.constants import ChatAction from telegram.error import BadRequest from ..session import session_manager @@ -28,6 +29,7 @@ from ..tmux_manager import tmux_manager from .interactive_ui import ( clear_interactive_msg, + get_interactive_msg_id, get_interactive_window, handle_interactive_ui, ) @@ -93,6 +95,15 @@ async def update_status_message( # Check for permission prompt (interactive UI not triggered via JSONL) # ALWAYS check UI, regardless of skip_status if should_check_new_ui and is_interactive_ui(pane_text): + # Skip if another path (e.g. 
JSONL monitor) already sent an interactive + # message for this user/thread — avoids duplicate messages + if get_interactive_msg_id(user_id, thread_id): + logger.debug( + "Interactive UI already tracked for user=%d thread=%s, skipping", + user_id, + thread_id, + ) + return logger.debug( "Interactive UI detected in polling (user=%d, window=%s, thread=%s)", user_id, @@ -129,13 +140,12 @@ async def status_poll_loop(bot: Bot) -> None: now = time.monotonic() if now - last_topic_check >= TOPIC_CHECK_INTERVAL: last_topic_check = now - for user_id, thread_id, wid in list( - session_manager.iter_thread_bindings() - ): + for user_id, thread_id, wid in session_manager.all_thread_bindings(): try: - await bot.unpin_all_forum_topic_messages( + await bot.send_chat_action( chat_id=session_manager.resolve_chat_id(user_id, thread_id), message_thread_id=thread_id, + action=ChatAction.TYPING, ) except BadRequest as e: if "Topic_id_invalid" in str(e): @@ -165,7 +175,9 @@ async def status_poll_loop(bot: Bot) -> None: e, ) - for user_id, thread_id, wid in list(session_manager.iter_thread_bindings()): + # Fresh snapshot — reflects any unbinds from the topic probe above, + # so bindings cleaned there are naturally excluded. 
+ for user_id, thread_id, wid in session_manager.all_thread_bindings(): try: # Clean up stale bindings (window no longer exists) w = await tmux_manager.find_window_by_id(wid) diff --git a/src/ccbot/screenshot.py b/src/ccbot/screenshot.py index cbe70f69..c73ef815 100644 --- a/src/ccbot/screenshot.py +++ b/src/ccbot/screenshot.py @@ -104,9 +104,8 @@ def _font_tier(ch: str) -> int: if cp in _SYMBOLA_CODEPOINTS: return 2 # CJK Unified Ideographs + CJK compat + fullwidth forms + Hangul + known Noto-only codepoints - if ( - cp in _NOTO_CODEPOINTS - or cp >= 0x1100 + if cp in _NOTO_CODEPOINTS or ( + cp >= 0x1100 and ( cp <= 0x11FF # Hangul Jamo or 0x2E80 <= cp <= 0x9FFF # CJK radicals, kangxi, ideographs diff --git a/src/ccbot/session.py b/src/ccbot/session.py index c740545c..22ca2464 100644 --- a/src/ccbot/session.py +++ b/src/ccbot/session.py @@ -17,27 +17,48 @@ Key class: SessionManager (singleton instantiated as `session_manager`). Key methods for thread binding access: - resolve_window_for_thread: Get window_id for a user's thread - - iter_thread_bindings: Generator for iterating all (user_id, thread_id, window_id) + - all_thread_bindings: Snapshot list of all (user_id, thread_id, window_id) - find_users_for_session: Find all users bound to a session_id """ import asyncio import json import logging +import re from dataclasses import dataclass, field from pathlib import Path -from collections.abc import Iterator from typing import Any import aiofiles from .config import config -from .tmux_manager import tmux_manager +from .tmux_manager import SHELL_COMMANDS, tmux_manager from .transcript_parser import TranscriptParser from .utils import atomic_write_json logger = logging.getLogger(__name__) +# Patterns for detecting Claude Code resume commands in pane output +_RESUME_CMD_RE = re.compile(r"(claude\s+(?:--resume|-r)\s+\S+)") +_STOPPED_RE = re.compile(r"Stopped\s+.*claude", re.IGNORECASE) + + +def _extract_resume_command(pane_text: str) -> str | None: + """Extract a 
resume command from pane content after Claude Code exit. + + Detects two patterns: + - Suspended process: ``[1]+ Stopped claude`` → returns ``"fg"`` + - Exited with resume hint: ``claude --resume SESSION_ID`` → returns the full command + + Returns None if no resume path is detected. + """ + if _STOPPED_RE.search(pane_text): + return "fg" + match = _RESUME_CMD_RE.search(pane_text) + if match: + return match.group(1) + return None + @dataclass class WindowState: @@ -179,7 +200,6 @@ def _load_state(self) -> None: "Detected old-format state (window_name keys), " "will re-resolve on startup" ) - pass except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to load state: %s", e) @@ -188,7 +208,6 @@ def _load_state(self) -> None: self.thread_bindings = {} self.window_display_names = {} self.group_chat_ids = {} - pass async def resolve_stale_ids(self) -> None: """Re-resolve persisted window IDs against live tmux windows. @@ -472,8 +491,8 @@ async def wait_for_session_map_entry( timeout, ) key = f"{config.tmux_session_name}:{window_id}" - deadline = asyncio.get_event_loop().time() + timeout - while asyncio.get_event_loop().time() < deadline: + deadline = asyncio.get_running_loop().time() + timeout + while asyncio.get_running_loop().time() < deadline: try: if config.session_map_file.exists(): async with aiofiles.open(config.session_map_file, "r") as f: @@ -739,15 +758,18 @@ def resolve_window_for_thread( return None return self.get_window_for_thread(user_id, thread_id) - def iter_thread_bindings(self) -> Iterator[tuple[int, int, str]]: - """Iterate all thread bindings as (user_id, thread_id, window_id). + def all_thread_bindings(self) -> list[tuple[int, int, str]]: + """Return a snapshot of all thread bindings as (user_id, thread_id, window_id). - Provides encapsulated access to thread_bindings without exposing - the internal data structure directly.
+ Returns a new list each call so callers can safely await between + iterations without risking ``RuntimeError: dictionary changed size + during iteration`` from a concurrent ``unbind_thread`` call. """ - for user_id, bindings in self.thread_bindings.items(): - for thread_id, window_id in bindings.items(): - yield user_id, thread_id, window_id + return [ + (user_id, thread_id, window_id) + for user_id, bindings in self.thread_bindings.items() + for thread_id, window_id in bindings.items() + ] async def find_users_for_session( self, @@ -758,7 +780,7 @@ async def find_users_for_session( Returns list of (user_id, window_id, thread_id) tuples. """ result: list[tuple[int, str, int]] = [] - for user_id, thread_id, window_id in self.iter_thread_bindings(): + for user_id, thread_id, window_id in self.all_thread_bindings(): resolved = await self.resolve_session_for_window(window_id) if resolved and resolved.session_id == session_id: result.append((user_id, window_id, thread_id)) @@ -767,7 +789,11 @@ async def find_users_for_session( # --- Tmux helpers --- async def send_to_window(self, window_id: str, text: str) -> tuple[bool, str]: - """Send text to a tmux window by ID.""" + """Send text to a tmux window by ID. + + If the pane is running a bare shell (Claude Code exited), attempts + to auto-resume via ``fg`` or ``claude --resume SESSION_ID`` before sending.
+ """ display = self.get_display_name(window_id) logger.debug( "send_to_window: window_id=%s (%s), text_len=%d", @@ -778,11 +804,61 @@ async def send_to_window(self, window_id: str, text: str) -> tuple[bool, str]: window = await tmux_manager.find_window_by_id(window_id) if not window: return False, "Window not found (may have been closed)" + if window.pane_current_command in SHELL_COMMANDS: + resumed = await self._try_resume_claude(window_id, display) + if not resumed: + return False, "Claude Code is not running (session exited)" success = await tmux_manager.send_keys(window.window_id, text) if success: return True, f"Sent to {display}" return False, "Failed to send keys" + async def _try_resume_claude(self, window_id: str, display: str) -> bool: + """Attempt to resume Claude Code when pane has dropped to a shell. + + Detects ``fg`` (suspended process) and ``claude --resume SESSION_ID`` + (exited session) patterns in the pane content. Sends the appropriate + command and waits for Claude Code to take over the terminal. + + Returns True if Claude Code is running after the attempt.
+ """ + pane_text = await tmux_manager.capture_pane(window_id) + if not pane_text: + return False + + resume_cmd = _extract_resume_command(pane_text) + if not resume_cmd: + logger.warning( + "No resume command found in %s (%s), cannot auto-resume", + window_id, + display, + ) + return False + + logger.info( + "Auto-resuming Claude Code in %s (%s): %s", + window_id, + display, + resume_cmd, + ) + await tmux_manager.send_keys(window_id, resume_cmd) + + # Wait for Claude Code to take over the terminal + max_wait = 3.0 if resume_cmd == "fg" else 15.0 + elapsed = 0.0 + while elapsed < max_wait: + await asyncio.sleep(0.5) + elapsed += 0.5 + w = await tmux_manager.find_window_by_id(window_id) + if w and w.pane_current_command not in SHELL_COMMANDS: + # Claude Code is running again — give TUI a moment to init + await asyncio.sleep(1.0) + logger.info("Claude Code resumed in %s (%s)", window_id, display) + return True + + logger.warning("Auto-resume timed out for %s (%s)", window_id, display) + return False + # --- Message history --- async def get_recent_messages( diff --git a/src/ccbot/session_monitor.py b/src/ccbot/session_monitor.py index 0a1b3186..9bc7e7e6 100644 --- a/src/ccbot/session_monitor.py +++ b/src/ccbot/session_monitor.py @@ -6,7 +6,7 @@ 3. Reads new JSONL lines from each session file using byte-offset tracking. 4. Parses entries via TranscriptParser and emits NewMessage objects to a callback. -Optimizations: mtime cache skips unchanged files; byte offset avoids re-reading. +Optimizations: file size check skips unchanged files; byte offset avoids re-reading. Key classes: SessionMonitor, NewMessage, SessionInfo. 
""" @@ -82,8 +82,6 @@ def __init__( # Track last known session_map for detecting changes # Keys may be window_id (@12) or window_name (old format) during transition self._last_session_map: dict[str, str] = {} # window_key -> session_id - # In-memory mtime cache for quick file change detection (not persisted) - self._file_mtimes: dict[str, float] = {} # session_id -> last_seen_mtime def set_message_callback( self, callback: Callable[[NewMessage], Awaitable[None]] @@ -292,41 +290,33 @@ async def check_for_updates(self, active_session_ids: set[str]) -> list[NewMessa # to avoid re-processing old messages try: file_size = session_info.file_path.stat().st_size - current_mtime = session_info.file_path.stat().st_mtime except OSError: file_size = 0 - current_mtime = 0.0 tracked = TrackedSession( session_id=session_info.session_id, file_path=str(session_info.file_path), last_byte_offset=file_size, ) self.state.update_session(tracked) - self._file_mtimes[session_info.session_id] = current_mtime logger.info(f"Started tracking session: {session_info.session_id}") continue - # Check mtime + file size to see if file has changed + # Quick size check: skip reading if file size hasn't changed. + # For append-only JSONL files, size == offset means no new + # content. Size < offset (truncation) and size > offset (new + # data) both need processing — handled inside _read_new_lines. 
try: - st = session_info.file_path.stat() - current_mtime = st.st_mtime - current_size = st.st_size + current_size = session_info.file_path.stat().st_size except OSError: continue - last_mtime = self._file_mtimes.get(session_info.session_id, 0.0) - if ( - current_mtime <= last_mtime - and current_size <= tracked.last_byte_offset - ): - # File hasn't changed, skip reading + if current_size == tracked.last_byte_offset: continue # File changed, read new content from last offset new_entries = await self._read_new_lines( tracked, session_info.file_path ) - self._file_mtimes[session_info.session_id] = current_mtime if new_entries: logger.debug( @@ -369,7 +359,10 @@ async def check_for_updates(self, active_session_ids: set[str]) -> list[NewMessa except OSError as e: logger.debug(f"Error processing session {session_info.session_id}: {e}") - self.state.save_if_dirty() + # NOTE: save_if_dirty() is intentionally NOT called here. + # Offsets must be persisted only AFTER delivery to Telegram (in + # _monitor_loop) to guarantee at-least-once delivery. Saving before + # delivery would risk silent message loss on crash. 
return new_messages async def _load_current_session_map(self) -> dict[str, str]: @@ -416,7 +409,7 @@ async def _cleanup_all_stale_sessions(self) -> None: ) for session_id in stale_sessions: self.state.remove_session(session_id) - self._file_mtimes.pop(session_id, None) + self._pending_tools.pop(session_id, None) self.state.save_if_dirty() async def _detect_and_cleanup_changes(self) -> dict[str, str]: @@ -458,7 +451,7 @@ async def _detect_and_cleanup_changes(self) -> dict[str, str]: if sessions_to_remove: for session_id in sessions_to_remove: self.state.remove_session(session_id) - self._file_mtimes.pop(session_id, None) + self._pending_tools.pop(session_id, None) self.state.save_if_dirty() # Update last known map @@ -503,6 +496,12 @@ async def _monitor_loop(self) -> None: except Exception as e: logger.error(f"Message callback error: {e}") + # Persist byte offsets AFTER delivering messages to Telegram. + # This guarantees at-least-once delivery: if the bot crashes + # before this save, messages will be re-read and re-delivered + # on restart (safe duplicate) rather than silently lost. + self.state.save_if_dirty() + except Exception as e: logger.error(f"Monitor loop error: {e}") diff --git a/src/ccbot/tmux_manager.py b/src/ccbot/tmux_manager.py index 84cba5aa..31a5e775 100644 --- a/src/ccbot/tmux_manager.py +++ b/src/ccbot/tmux_manager.py @@ -15,6 +15,7 @@ import asyncio import logging +import time from dataclasses import dataclass from pathlib import Path @@ -24,6 +25,22 @@ logger = logging.getLogger(__name__) +# Process names that indicate a bare shell (Claude Code has exited). +# Used to prevent sending user input to a shell prompt. +SHELL_COMMANDS = frozenset( + { + "bash", + "zsh", + "sh", + "fish", + "dash", + "tcsh", + "csh", + "ksh", + "ash", + } +) + @dataclass class TmuxWindow: @@ -38,6 +55,9 @@ class TmuxWindow: class TmuxManager: """Manages tmux windows for Claude Code sessions.""" + # How long cached list_windows results are valid (seconds). 
+ _CACHE_TTL = 1.0 + def __init__(self, session_name: str | None = None): """Initialize tmux manager. @@ -46,6 +66,8 @@ def __init__(self, session_name: str | None = None): """ self.session_name = session_name or config.tmux_session_name self._server: libtmux.Server | None = None + self._windows_cache: list[TmuxWindow] | None = None + self._windows_cache_time: float = 0.0 @property def server(self) -> libtmux.Server: @@ -92,12 +114,23 @@ def _scrub_session_env(session: libtmux.Session) -> None: except Exception: pass # var not set in session env — nothing to remove + def invalidate_cache(self) -> None: + """Invalidate the cached window list (call after mutations).""" + self._windows_cache = None + async def list_windows(self) -> list[TmuxWindow]: """List all windows in the session with their working directories. - Returns: - List of TmuxWindow with window info and cwd + Results are cached for ``_CACHE_TTL`` seconds to avoid hammering + the tmux server when multiple callers need window info in the same + poll cycle. """ + now = time.monotonic() + if ( + self._windows_cache is not None + and (now - self._windows_cache_time) < self._CACHE_TTL + ): + return self._windows_cache def _sync_list_windows() -> list[TmuxWindow]: windows = [] @@ -135,7 +168,10 @@ def _sync_list_windows() -> list[TmuxWindow]: return windows - return await asyncio.to_thread(_sync_list_windows) + result = await asyncio.to_thread(_sync_list_windows) + self._windows_cache = result + self._windows_cache_time = time.monotonic() + return result async def find_window_by_name(self, window_name: str) -> TmuxWindow | None: """Find a window by its name. @@ -172,56 +208,29 @@ async def find_window_by_id(self, window_id: str) -> TmuxWindow | None: async def capture_pane(self, window_id: str, with_ansi: bool = False) -> str | None: """Capture the visible text content of a window's active pane. 
- Args: - window_id: The window ID to capture - with_ansi: If True, capture with ANSI color codes - - Returns: - The captured text, or None on failure. + Uses a direct ``tmux capture-pane`` subprocess for both plain text + and ANSI modes — avoids the multiple tmux round-trips that libtmux + would generate (list-windows → list-panes → capture-pane). """ + cmd = ["tmux", "capture-pane", "-p", "-t", window_id] if with_ansi: - # Use async subprocess to call tmux capture-pane -e for ANSI colors - try: - proc = await asyncio.create_subprocess_exec( - "tmux", - "capture-pane", - "-e", - "-p", - "-t", - window_id, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - if proc.returncode == 0: - return stdout.decode("utf-8") - logger.error( - f"Failed to capture pane {window_id}: {stderr.decode('utf-8')}" - ) - return None - except Exception as e: - logger.error(f"Unexpected error capturing pane {window_id}: {e}") - return None - - # Original implementation for plain text - wrap in thread - def _sync_capture() -> str | None: - session = self.get_session() - if not session: - return None - try: - window = session.windows.get(window_id=window_id) - if not window: - return None - pane = window.active_pane - if not pane: - return None - lines = pane.capture_pane() - return "\n".join(lines) if isinstance(lines, list) else str(lines) - except Exception as e: - logger.error(f"Failed to capture pane {window_id}: {e}") - return None - - return await asyncio.to_thread(_sync_capture) + cmd.insert(2, "-e") + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode == 0: + return stdout.decode("utf-8") + logger.error( + "Failed to capture pane %s: %s", window_id, stderr.decode("utf-8") + ) + return None + except Exception as e: + logger.error("Unexpected error capturing pane %s: %s", 
window_id, e) + return None async def send_keys( self, window_id: str, text: str, enter: bool = True, literal: bool = True @@ -326,6 +335,7 @@ def _sync_send_keys() -> bool: async def rename_window(self, window_id: str, new_name: str) -> bool: """Rename a tmux window by its ID.""" + self.invalidate_cache() def _sync_rename() -> bool: session = self.get_session() @@ -346,6 +356,7 @@ def _sync_rename() -> bool: async def kill_window(self, window_id: str) -> bool: """Kill a tmux window by its ID.""" + self.invalidate_cache() def _sync_kill() -> bool: session = self.get_session() @@ -398,6 +409,8 @@ async def create_window( counter += 1 # Create window in thread + self.invalidate_cache() + def _create_and_start() -> tuple[bool, str, str, str]: session = self.get_or_create_session() try: diff --git a/tests/ccbot/handlers/test_interactive_ui.py b/tests/ccbot/handlers/test_interactive_ui.py index 8d6a98e4..336f9965 100644 --- a/tests/ccbot/handlers/test_interactive_ui.py +++ b/tests/ccbot/handlers/test_interactive_ui.py @@ -32,13 +32,19 @@ def mock_bot(): @pytest.fixture def _clear_interactive_state(): """Ensure interactive state is clean before and after each test.""" - from ccbot.handlers.interactive_ui import _interactive_mode, _interactive_msgs + from ccbot.handlers.interactive_ui import ( + _interactive_mode, + _interactive_msgs, + _last_interactive_send, + ) _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() yield _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() @pytest.mark.usefixtures("_clear_interactive_state") diff --git a/tests/ccbot/handlers/test_status_polling.py b/tests/ccbot/handlers/test_status_polling.py index 9c0f04f7..ad6ec312 100644 --- a/tests/ccbot/handlers/test_status_polling.py +++ b/tests/ccbot/handlers/test_status_polling.py @@ -24,13 +24,19 @@ def mock_bot(): @pytest.fixture def _clear_interactive_state(): """Ensure interactive state is clean before and after each test.""" - 
from ccbot.handlers.interactive_ui import _interactive_mode, _interactive_msgs + from ccbot.handlers.interactive_ui import ( + _interactive_mode, + _interactive_msgs, + _last_interactive_send, + ) _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() yield _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() @pytest.mark.usefixtures("_clear_interactive_state") diff --git a/tests/ccbot/test_session.py b/tests/ccbot/test_session.py index 022fb55a..96cfa4a7 100644 --- a/tests/ccbot/test_session.py +++ b/tests/ccbot/test_session.py @@ -25,13 +25,44 @@ def test_bind_unbind_get_returns_none(self, mgr: SessionManager) -> None: def test_unbind_nonexistent_returns_none(self, mgr: SessionManager) -> None: assert mgr.unbind_thread(100, 999) is None - def test_iter_thread_bindings(self, mgr: SessionManager) -> None: + def test_all_thread_bindings(self, mgr: SessionManager) -> None: mgr.bind_thread(100, 1, "@1") mgr.bind_thread(100, 2, "@2") mgr.bind_thread(200, 3, "@3") - result = set(mgr.iter_thread_bindings()) + result = set(mgr.all_thread_bindings()) assert result == {(100, 1, "@1"), (100, 2, "@2"), (200, 3, "@3")} + def test_all_thread_bindings_returns_list(self, mgr: SessionManager) -> None: + """all_thread_bindings must return a list (snapshot), not a generator. + + A generator would hold a live reference into the internal dict and could + raise RuntimeError if an async coroutine calls unbind_thread between two + consumed values. A list snapshot is safe across await points. 
+ """ + mgr.bind_thread(100, 1, "@1") + result = mgr.all_thread_bindings() + assert isinstance(result, list) + + def test_all_thread_bindings_snapshot_is_independent( + self, mgr: SessionManager + ) -> None: + """Mutating thread_bindings after calling all_thread_bindings must not + affect the already-returned snapshot.""" + mgr.bind_thread(100, 1, "@1") + mgr.bind_thread(100, 2, "@2") + snapshot = mgr.all_thread_bindings() + # Mutate the live dict after snapshot was taken — the snapshot must + # be unaffected (this is the property that prevents RuntimeError + # when unbind_thread runs between await points in async callers) + mgr.unbind_thread(100, 1) + assert (100, 1, "@1") in snapshot + assert len(snapshot) == 2 + + def test_all_thread_bindings_empty(self, mgr: SessionManager) -> None: + """all_thread_bindings returns an empty list when nothing is bound.""" + result = mgr.all_thread_bindings() + assert result == [] + class TestGroupChatId: """Tests for group chat_id routing (supergroup forum topic support).