diff --git a/.claude/rules/architecture.md b/.claude/rules/architecture.md index 320050bd..d8694ebd 100644 --- a/.claude/rules/architecture.md +++ b/.claude/rules/architecture.md @@ -30,7 +30,7 @@ │ SessionMonitor │ │ TmuxManager (tmux_manager.py) │ │ (session_monitor.py) │ │ - list/find/create/kill windows│ │ - Poll JSONL every 2s │ │ - send_keys to pane │ -│ - Detect mtime changes │ │ - capture_pane for screenshot │ +│ - Detect size changes │ │ - capture_pane for screenshot │ │ - Parse new lines │ └──────────────┬─────────────────┘ │ - Track pending tools │ │ │ across poll cycles │ │ diff --git a/.claude/rules/message-handling.md b/.claude/rules/message-handling.md index ab108a86..33138307 100644 --- a/.claude/rules/message-handling.md +++ b/.claude/rules/message-handling.md @@ -32,7 +32,7 @@ Per-user message queues + worker pattern for all send tasks: ## Performance Optimizations -**mtime cache**: The monitoring loop maintains an in-memory file mtime cache, skipping reads for unchanged files. +**File size fast path**: The monitoring loop compares file size against the last byte offset, skipping reads for unchanged files. **Byte offset incremental reads**: Each tracked session records `last_byte_offset`, reading only new content. File truncation (offset > file_size) is detected and offset is auto-reset. diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 00000000..eba3817c --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,13 @@ +{ + "permissions": { + "allow": [ + "Bash(cmd.exe:*)", + "Bash(pip show:*)", + "Bash(python3:*)", + "Bash(\"/c/Users/krisd/AppData/Local/Programs/Python/Python314/python.exe\":*)", + "Bash(uv run:*)", + "Bash(~/.local/bin/uv run:*)", + "Bash(ls:*)" + ] + } +} diff --git a/ccbot-workshop-setup.md b/ccbot-workshop-setup.md new file mode 100644 index 00000000..bef0cec5 --- /dev/null +++ b/ccbot-workshop-setup.md @@ -0,0 +1,310 @@ +# CCBot Workshop Setup Guide + +Complete setup from a fresh Windows machine to running CCBot with Claude Code sessions accessible via Telegram. + +--- + +## Prerequisites + +Before you begin, you'll need: + +- Windows 10 (version 2004+) or Windows 11 +- A Telegram account +- A Claude Code subscription (Claude Pro/Team/Enterprise) +- Your project repositories cloned into `C:\GitHub\` + +--- + +## Part 1: Install WSL and Ubuntu + +Open **PowerShell as Administrator** and run: + +```powershell +wsl --install +``` + +This installs WSL 2 with Ubuntu. Restart your computer when prompted. + +After restart, Ubuntu will open automatically and ask you to create a username and password. Remember these — you'll need the password for `sudo` commands. + +Once you're at the Ubuntu prompt, update everything: + +```bash +sudo apt update && sudo apt upgrade -y +``` + +--- + +## Part 2: Install Core Tools + +### Node.js (required for Claude Code) + +```bash +curl -fsSL https://deb.nodesource.com/setup_22.x | sudo -E bash - +sudo apt install -y nodejs +``` + +Verify: + +```bash +node --version +npm --version +``` + +### Claude Code + +```bash +npm install -g @anthropic-ai/claude-code +``` + +Add npm global bin to your PATH if not already there: + +```bash +echo 'export PATH=~/.npm-global/bin:$PATH' >> ~/.bashrc +source ~/.bashrc +``` + +Verify Claude Code works: + +```bash +claude --version +``` + +### tmux + +```bash +sudo apt install -y tmux +``` + +### uv (Python package manager) + +```bash +curl -LsSf https://astral.sh/uv/install.sh | sh +``` + +Then: + +```bash +source ~/.bashrc +``` + +--- + +## Part 3: Create a Telegram Bot + +1. Open Telegram and search for **@BotFather** +2. Send `/newbot` +3. Follow the prompts to name your bot +4. BotFather gives you a **bot token** — save it (looks like `1234567890:ABCdefGHIjklMNOpqrsTUVwxyz`) + +### Get your Telegram user ID + +1. Search for **@userinfobot** in Telegram +2. Start a chat with it +3. It replies with your numeric user ID — save it + +### Create a Telegram group + +1. Create a new group in Telegram +2. Name it something like "Workshop Sessions" +3. Add your bot to the group +4. Go to group settings → **Topics** → Enable topics (use list format) +5. Make the bot an **admin** of the group + +--- + +## Part 4: Install CCBot Workshop + +```bash +uv tool install git+https://github.com/JanusMarko/ccbot-workshop.git +``` + +Verify it installed: + +```bash +which ccbot +``` + +If it's not found, add the path: + +```bash +export PATH="$HOME/.local/bin:$PATH" +echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc +``` + +### Configure CCBot + +Create the config directory and environment file: + +```bash +mkdir -p ~/.ccbot +nano ~/.ccbot/.env +``` + +Paste the following, replacing the placeholder values with your actual token and user ID: + +``` +TELEGRAM_BOT_TOKEN=your_bot_token_here +ALLOWED_USERS=your_telegram_user_id_here +TMUX_SESSION_NAME=ccbot +CLAUDE_COMMAND=claude --dangerously-skip-permissions +CCBOT_BROWSE_ROOT=/mnt/c/GitHub +``` + +Save with `Ctrl+O`, exit with `Ctrl+X`. + +The `CCBOT_BROWSE_ROOT` setting ensures the directory browser always starts from your `C:\GitHub\` folder when creating new sessions. + +### Install the Claude Code hook + +This lets CCBot track which Claude session runs in which tmux window: + +```bash +ccbot hook --install +``` + +Or manually add to `~/.claude/settings.json`: + +```json +{ + "hooks": { + "SessionStart": [ + { + "hooks": [{ "type": "command", "command": "ccbot hook", "timeout": 5 }] + } + ] + } +} +``` + +--- + +## Part 5: Starting CCBot + +### First time startup + +```bash +tmux new -s ccbot +``` + +Inside the tmux session: + +```bash +ccbot +``` + +You should see log output confirming the bot started, including your allowed users and Claude projects path. + +### Detach from tmux + +Press `Ctrl+b`, release, then press `d`. CCBot keeps running in the background. You can close the terminal — it stays alive. + +### Start a session from Telegram + +1. Open your Telegram group +2. Create a new topic (name it after your project, e.g. "PAIOS") +3. Send a message in the topic +4. CCBot shows a directory browser starting from `C:\GitHub\` — tap your project folder +5. Tap **Select** to confirm +6. A new tmux window is created with Claude Code running in that directory +7. Your message is forwarded to Claude Code + +### View sessions in the terminal + +```bash +tmux attach -t ccbot +``` + +Switch between windows using `Ctrl+b` then the window number (shown in the bottom bar). For example: + +- `Ctrl+b` then `1` → ccbot process (don't close this) +- `Ctrl+b` then `2` → first Claude Code session +- `Ctrl+b` then `3` → second Claude Code session + +Detach again with `Ctrl+b` then `d`. + +--- + +## Part 6: Daily Usage + +### Starting CCBot after a reboot + +```bash +tmux new -s ccbot || tmux attach -t ccbot +ccbot +``` + +Then `Ctrl+b` then `d` to detach. + +### Useful Telegram commands + +Send these in a topic: + +- `/screenshot` — see what the terminal looks like right now +- `/history` — browse conversation history +- `/esc` — send Escape key (toggles plan mode, same as Shift+Tab) +- `/cost` — check token usage +- `/kill` — kill the session and delete the topic + +### Ending a session + +Close or delete the topic in Telegram. The tmux window is automatically killed. + +### Multiple projects + +Create a new topic for each project. CCBot's design is **1 topic = 1 window = 1 session**. Each topic can run a separate Claude Code session in a different project directory. + +### Switching between phone and desktop + +From your phone, just use Telegram — all interaction goes through topics. + +To switch to your desktop terminal: + +```bash +tmux attach -t ccbot +``` + +Navigate to the right window with `Ctrl+b` then the window number. You're in the same session with full scrollback. + +--- + +## Part 7: Uninstall and Reinstall + +Use this after pushing updates to your fork. + +### Stop CCBot + +```bash +tmux attach -t ccbot +``` + +Press `Ctrl+C` to stop ccbot. Stay in the tmux session. + +### Uninstall the current version + +```bash +uv tool uninstall ccbot +``` + +### Install the updated version + +```bash +uv tool install git+https://github.com/JanusMarko/ccbot-workshop.git +``` + +If you're getting a cached version and not seeing your changes, force a fresh install: + +```bash +uv tool install --force git+https://github.com/JanusMarko/ccbot-workshop.git +``` + +### Verify and restart + +```bash +which ccbot +ccbot +``` + +Then `Ctrl+b` then `d` to detach. + +Your `~/.ccbot/.env` configuration and `~/.ccbot/state.json` session state are preserved across reinstalls — you don't need to reconfigure anything. diff --git a/pyproject.toml b/pyproject.toml index f02ba25c..81e6d87b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "Pillow>=10.0.0", "aiofiles>=24.0.0", "telegramify-markdown>=0.5.0", + "python-docx>=1.0.0", ] [project.scripts] diff --git a/src/ccbot/bot.py b/src/ccbot/bot.py index 0b746c78..5a25d20b 100644 --- a/src/ccbot/bot.py +++ b/src/ccbot/bot.py @@ -12,6 +12,8 @@ Unbound topics trigger the directory browser to create a new session. - Photo handling: photos sent by user are downloaded and forwarded to Claude Code as file paths (photo_handler). + - Document handling: Markdown and text files sent by user are saved to + {session_cwd}/docs/inbox/ and path forwarded to Claude Code (document_handler). - Automatic cleanup: closing a topic kills the associated window (topic_closed_handler). Unsupported content (stickers, voice, etc.) is rejected with a warning (unsupported_content_handler). @@ -41,7 +43,7 @@ BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, - InputMediaDocument, + InputMediaPhoto, Update, ) from telegram.constants import ChatAction @@ -105,9 +107,9 @@ ) from .handlers.message_queue import ( clear_status_msg_info, + enqueue_callable, enqueue_content_message, enqueue_status_update, - get_message_queue, shutdown_workers, ) from .handlers.message_sender import ( @@ -124,7 +126,7 @@ from .session import session_manager from .session_monitor import NewMessage, SessionMonitor from .terminal_parser import extract_bash_output, is_interactive_ui -from .tmux_manager import tmux_manager +from .tmux_manager import SHELL_COMMANDS, tmux_manager from .utils import ccbot_dir logger = logging.getLogger(__name__) @@ -229,9 +231,8 @@ async def screenshot_command( png_bytes = await text_to_image(text, with_ansi=True) keyboard = _build_screenshot_keyboard(wid) - await update.message.reply_document( - document=io.BytesIO(png_bytes), - filename="screenshot.png", + await update.message.reply_photo( + photo=io.BytesIO(png_bytes), reply_markup=keyboard, ) @@ -285,6 +286,9 @@ async def esc_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non display = session_manager.get_display_name(wid) await safe_reply(update.message, f"❌ Window '{display}' no longer exists.") return + if w.pane_current_command in SHELL_COMMANDS: + await safe_reply(update.message, "❌ Claude Code has exited.") + return # Send Escape control character (no enter) await tmux_manager.send_keys(w.window_id, "\x1b", enter=False) @@ -309,6 +313,9 @@ async def usage_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N if not w: await safe_reply(update.message, f"Window '{wid}' no longer exists.") return + if w.pane_current_command in SHELL_COMMANDS: + await safe_reply(update.message, "❌ Claude Code has exited.") + return # Send /usage command to Claude Code TUI await tmux_manager.send_keys(w.window_id, "/usage") @@ -338,6 +345,64 @@ async def usage_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N await safe_reply(update.message, f"```\n{trimmed}\n```") +async def kill_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Kill the associated tmux window, clean up state, and delete the topic.""" + user = update.effective_user + if not user or not is_user_allowed(user.id): + return + if not update.message: + return + + thread_id = _get_thread_id(update) + if thread_id is None: + await safe_reply(update.message, "❌ This command only works in a topic.") + return + + wid = session_manager.get_window_for_thread(user.id, thread_id) + if not wid: + await safe_reply(update.message, "❌ No session bound to this topic.") + return + + display = session_manager.get_display_name(wid) + + # Kill the tmux window + w = await tmux_manager.find_window_by_id(wid) + if w: + await tmux_manager.kill_window(w.window_id) + logger.info( + "/kill: killed window %s (user=%d, thread=%d)", + display, + user.id, + thread_id, + ) + else: + logger.info( + "/kill: window %s already gone (user=%d, thread=%d)", + display, + user.id, + thread_id, + ) + + # Unbind and clean up all topic state + session_manager.unbind_thread(user.id, thread_id) + await clear_topic_state(user.id, thread_id, context.bot, context.user_data) + + await safe_reply(update.message, f"✅ Killed session '{display}'.") + + # Best-effort: close and delete the forum topic + chat_id = update.effective_chat.id if update.effective_chat else None + if chat_id and thread_id: + try: + await context.bot.close_forum_topic(chat_id, thread_id) + await context.bot.delete_forum_topic(chat_id, thread_id) + except Exception: + logger.debug( + "/kill: could not close/delete topic (user=%d, thread=%d)", + user.id, + thread_id, + ) + + # --- Screenshot keyboard with quick control keys --- # key_id → (tmux_key, enter, literal) @@ -620,6 +685,223 @@ async def photo_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> N await safe_reply(update.message, "📷 Image sent to Claude Code.") +# --- Allowed document MIME types for upload --- +_ALLOWED_DOC_MIME_PREFIXES = ("text/",) +_ALLOWED_DOC_MIME_TYPES = { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", +} +_ALLOWED_DOC_EXTENSIONS = { + ".md", + ".markdown", + ".txt", + ".csv", + ".json", + ".yaml", + ".yml", + ".toml", + ".xml", + ".html", + ".css", + ".js", + ".ts", + ".py", + ".sh", + ".bash", + ".rs", + ".go", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".rb", + ".pl", + ".lua", + ".sql", + ".r", + ".swift", + ".kt", + ".scala", + ".ex", + ".exs", + ".hs", + ".ml", + ".clj", + ".el", + ".vim", + ".conf", + ".ini", + ".cfg", + ".env", + ".log", + ".diff", + ".patch", + ".pdf", + ".docx", + ".doc", +} + + +def _convert_docx_to_markdown(docx_path: Path) -> str: + """Extract text from a .docx file and return as markdown.""" + import docx + + doc = docx.Document(str(docx_path)) + lines: list[str] = [] + for para in doc.paragraphs: + text = para.text + if not text.strip(): + lines.append("") + continue + style_name = (para.style.name or "").lower() if para.style else "" + if style_name.startswith("heading 1"): + lines.append(f"# {text}") + elif style_name.startswith("heading 2"): + lines.append(f"## {text}") + elif style_name.startswith("heading 3"): + lines.append(f"### {text}") + elif style_name.startswith("heading 4"): + lines.append(f"#### {text}") + elif style_name.startswith("list"): + lines.append(f"- {text}") + else: + lines.append(text) + return "\n\n".join(lines) + + +async def document_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + """Handle document uploads: save text/code/PDF/Word files to session cwd and forward path.""" + user = update.effective_user + if not user or not is_user_allowed(user.id): + if update.message: + await safe_reply(update.message, "You are not authorized to use this bot.") + return + + if not update.message or not update.message.document: + return + + doc = update.message.document + file_name = doc.file_name or "unnamed_document" + mime = doc.mime_type or "" + ext = Path(file_name).suffix.lower() + + # Check if file type is allowed + if ( + not any(mime.startswith(p) for p in _ALLOWED_DOC_MIME_PREFIXES) + and mime not in _ALLOWED_DOC_MIME_TYPES + and ext not in _ALLOWED_DOC_EXTENSIONS + ): + await safe_reply( + update.message, + f"⚠ Unsupported file type: {file_name}\n" + "Supported: text files, code, Markdown, PDF, and Word documents.", + ) + return + + chat = update.message.chat + thread_id = _get_thread_id(update) + if chat.type in ("group", "supergroup") and thread_id is not None: + session_manager.set_group_chat_id(user.id, thread_id, chat.id) + + # Must be in a named topic + if thread_id is None: + await safe_reply( + update.message, + "❌ Please use a named topic. Create a new topic to start a session.", + ) + return + + wid = session_manager.get_window_for_thread(user.id, thread_id) + if wid is None: + await safe_reply( + update.message, + "❌ No session bound to this topic. Send a text message first to create one.", + ) + return + + w = await tmux_manager.find_window_by_id(wid) + if not w: + display = session_manager.get_display_name(wid) + session_manager.unbind_thread(user.id, thread_id) + await safe_reply( + update.message, + f"❌ Window '{display}' no longer exists. Binding removed.\n" + "Send a message to start a new session.", + ) + return + + # Resolve session cwd for the inbox directory + ws = session_manager.get_window_state(wid) + if not ws.cwd: + await safe_reply( + update.message, + "❌ Session working directory not yet known. Try again in a moment.", + ) + return + + inbox_dir = Path(ws.cwd) / "docs" / "inbox" + inbox_dir.mkdir(parents=True, exist_ok=True) + + tg_file = await doc.get_file() + is_docx = ext in (".docx", ".doc") or mime in ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/msword", + ) + + if is_docx: + # Convert Word documents to Markdown + import tempfile + + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: + tmp_path = Path(tmp.name) + try: + await tg_file.download_to_drive(tmp_path) + md_content = await asyncio.to_thread(_convert_docx_to_markdown, tmp_path) + finally: + tmp_path.unlink(missing_ok=True) + + save_name = Path(file_name).stem + ".md" + dest = inbox_dir / save_name + if dest.exists(): + dest = inbox_dir / f"{Path(file_name).stem}_{int(time.time())}.md" + dest.write_text(md_content, encoding="utf-8") + else: + # Save PDFs and text files directly + dest = inbox_dir / file_name + if dest.exists(): + stem = Path(file_name).stem + dest = inbox_dir / f"{stem}_{int(time.time())}{ext}" + await tg_file.download_to_drive(dest) + + # Build message for Claude Code — file context first, then user's instruction + rel_path = f"docs/inbox/{dest.name}" + caption = update.message.caption or "" + file_notice = ( + f"A file has been saved to {rel_path} (absolute path: {dest}). " + "Read it with your Read tool." + ) + if caption: + text_to_send = f"{file_notice}\n\n{caption}" + else: + text_to_send = file_notice + + await update.message.chat.send_action(ChatAction.TYPING) + clear_status_msg_info(user.id, thread_id) + + success, message = await session_manager.send_to_window(wid, text_to_send) + if not success: + await safe_reply(update.message, f"❌ {message}") + return + + suffix_note = " (converted from Word to Markdown)" if is_docx else "" + await safe_reply( + update.message, + f"📄 File saved to `{rel_path}`{suffix_note} and sent to Claude Code.", + ) + + # Active bash capture tasks: (user_id, thread_id) → asyncio.Task _bash_capture_tasks: dict[tuple[int, int], asyncio.Task[None]] = {} @@ -776,7 +1058,7 @@ async def text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No if wid is None: # Unbound topic — check for unbound windows first all_windows = await tmux_manager.list_windows() - bound_ids = {wid for _, _, wid in session_manager.iter_thread_bindings()} + bound_ids = {wid for _, _, wid in session_manager.all_thread_bindings()} unbound = [ (w.window_id, w.window_name, w.cwd) for w in all_windows @@ -812,7 +1094,7 @@ async def text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No user.id, thread_id, ) - start_path = str(Path.cwd()) + start_path = config.browse_root or str(Path.cwd()) msg_text, keyboard, subdirs = build_directory_browser(start_path) if context.user_data is not None: context.user_data[STATE_KEY] = STATE_BROWSING_DIRECTORY @@ -864,7 +1146,16 @@ async def text_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No success, message = await session_manager.send_to_window(wid, text) if not success: - await safe_reply(update.message, f"❌ {message}") + if "not running" in message: + # Claude Code exited and auto-resume failed — unbind + session_manager.unbind_thread(user.id, thread_id) + await safe_reply( + update.message, + f"❌ {message}. Binding removed.\n" + "Send a message to start a new session.", + ) + else: + await safe_reply(update.message, f"❌ {message}") return # Start background capture for ! bash command output @@ -970,7 +1261,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - return subdir_name = cached_dirs[idx] - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1000,7 +1291,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - if pending_tid is not None and _get_thread_id(update) != pending_tid: await query.answer("Stale browser (topic mismatch)", show_alert=True) return - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1033,7 +1324,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - except ValueError: await query.answer("Invalid data") return - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) current_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1049,7 +1340,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - await query.answer() elif data == CB_DIR_CONFIRM: - default_path = str(Path.cwd()) + default_path = config.browse_root or str(Path.cwd()) selected_path = ( context.user_data.get(BROWSE_PATH_KEY, default_path) if context.user_data @@ -1084,7 +1375,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - user.id, pending_thread_id, ) - # Wait for Claude Code's SessionStart hook to register in session_map + # Wait for Claude Code's SessionStart hook to register in session_map. + # Return value intentionally ignored: on timeout, the monitor's poll + # cycle will pick up the session_map entry once the hook fires. await session_manager.wait_for_session_map_entry(created_wid) if pending_thread_id is not None: @@ -1123,14 +1416,16 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - created_wname, len(pending_text), ) - if context.user_data is not None: - context.user_data.pop("_pending_thread_text", None) - context.user_data.pop("_pending_thread_id", None) send_ok, send_msg = await session_manager.send_to_window( created_wid, pending_text, ) - if not send_ok: + if send_ok: + # Clear pending text only after successful send + if context.user_data is not None: + context.user_data.pop("_pending_thread_text", None) + context.user_data.pop("_pending_thread_id", None) + else: logger.warning("Failed to forward pending text: %s", send_msg) await safe_send( context.bot, @@ -1224,14 +1519,16 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - pending_text = ( context.user_data.get("_pending_thread_text") if context.user_data else None ) - if context.user_data is not None: - context.user_data.pop("_pending_thread_text", None) - context.user_data.pop("_pending_thread_id", None) if pending_text: send_ok, send_msg = await session_manager.send_to_window( selected_wid, pending_text ) - if not send_ok: + if send_ok: + # Clear pending text only after successful send + if context.user_data is not None: + context.user_data.pop("_pending_thread_text", None) + context.user_data.pop("_pending_thread_id", None) + else: logger.warning("Failed to forward pending text: %s", send_msg) await safe_send( context.bot, @@ -1239,6 +1536,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - f"❌ Failed to send pending message: {send_msg}", message_thread_id=thread_id, ) + elif context.user_data is not None: + # No pending text — clean up thread_id tracking + context.user_data.pop("_pending_thread_id", None) await query.answer("Bound") # Window picker: new session → transition to directory browser @@ -1251,7 +1551,7 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - return # Preserve pending thread info, clear only picker state clear_window_picker_state(context.user_data) - start_path = str(Path.cwd()) + start_path = config.browse_root or str(Path.cwd()) msg_text, keyboard, subdirs = build_directory_browser(start_path) if context.user_data is not None: context.user_data[STATE_KEY] = STATE_BROWSING_DIRECTORY @@ -1293,8 +1593,8 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - keyboard = _build_screenshot_keyboard(window_id) try: await query.edit_message_media( - media=InputMediaDocument( - media=io.BytesIO(png_bytes), filename="screenshot.png" + media=InputMediaPhoto( + media=io.BytesIO(png_bytes), ), reply_markup=keyboard, ) @@ -1432,6 +1732,9 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - if not w: await query.answer("Window not found", show_alert=True) return + if w.pane_current_command in SHELL_COMMANDS: + await query.answer("Claude Code has exited", show_alert=True) + return await tmux_manager.send_keys( w.window_id, tmux_key, enter=enter, literal=literal @@ -1446,14 +1749,13 @@ async def callback_handler(update: Update, context: ContextTypes.DEFAULT_TYPE) - keyboard = _build_screenshot_keyboard(window_id) try: await query.edit_message_media( - media=InputMediaDocument( + media=InputMediaPhoto( media=io.BytesIO(png_bytes), - filename="screenshot.png", ), reply_markup=keyboard, ) - except Exception: - pass # Screenshot unchanged or message too old + except Exception as e: + logger.debug(f"Screenshot edit after key press failed: {e}") # --- Streaming response / notifications --- @@ -1481,30 +1783,42 @@ async def handle_new_message(msg: NewMessage, bot: Bot) -> None: for user_id, wid, thread_id in active_users: # Handle interactive tools specially - capture terminal and send UI if msg.tool_name in INTERACTIVE_TOOL_NAMES and msg.content_type == "tool_use": - # Mark interactive mode BEFORE sleeping so polling skips this window - set_interactive_mode(user_id, wid, thread_id) - # Flush pending messages (e.g. plan content) before sending interactive UI - queue = get_message_queue(user_id) - if queue: - await queue.join() - # Wait briefly for Claude Code to render the question UI - await asyncio.sleep(0.3) - handled = await handle_interactive_ui(bot, user_id, wid, thread_id) - if handled: - # Update user's read offset - session = await session_manager.resolve_session_for_window(wid) - if session and session.file_path: - try: - file_size = Path(session.file_path).stat().st_size - session_manager.update_user_window_offset( - user_id, wid, file_size - ) - except OSError: - pass - continue # Don't send the normal tool_use message - else: - # UI not rendered — clear the early-set mode - clear_interactive_mode(user_id, thread_id) + # Mark interactive mode BEFORE enqueuing so polling skips this window. + # Capture the generation so the callable can detect staleness. + gen = set_interactive_mode(user_id, wid, thread_id) + + # Enqueue the interactive UI handling as a callable task so it + # executes AFTER all pending content messages already in the queue, + # without blocking the monitor loop or any other session's processing. + async def _send_interactive_ui( + _bot: Bot = bot, + _user_id: int = user_id, + _wid: str = wid, + _thread_id: int | None = thread_id, + _gen: int = gen, + ) -> None: + # Wait briefly for Claude Code to render the question UI + await asyncio.sleep(0.3) + handled = await handle_interactive_ui( + _bot, _user_id, _wid, _thread_id, expected_generation=_gen + ) + if handled: + # Update user's read offset + session = await session_manager.resolve_session_for_window(_wid) + if session and session.file_path: + try: + file_size = Path(session.file_path).stat().st_size + session_manager.update_user_window_offset( + _user_id, _wid, file_size + ) + except OSError: + pass + else: + # UI not rendered — clear the early-set mode + clear_interactive_mode(_user_id, _thread_id) + + enqueue_callable(bot, user_id, _send_interactive_ui) + continue # Don't send the normal tool_use message # Any non-interactive message means the interaction is complete — delete the UI message if get_interactive_msg_id(user_id, thread_id): @@ -1631,6 +1945,7 @@ def create_bot() -> Application: application.add_handler(CommandHandler("history", history_command)) application.add_handler(CommandHandler("screenshot", screenshot_command)) application.add_handler(CommandHandler("esc", esc_command)) + application.add_handler(CommandHandler("kill", kill_command)) application.add_handler(CommandHandler("unbind", unbind_command)) application.add_handler(CommandHandler("usage", usage_command)) application.add_handler(CallbackQueryHandler(callback_handler)) @@ -1655,6 +1970,8 @@ def create_bot() -> Application: ) # Photos: download and forward file path to Claude Code application.add_handler(MessageHandler(filters.PHOTO, photo_handler)) + # Documents: save text/markdown files to session cwd and forward path + application.add_handler(MessageHandler(filters.Document.ALL, document_handler)) # Catch-all: non-text content (stickers, voice, etc.) application.add_handler( MessageHandler( diff --git a/src/ccbot/config.py b/src/ccbot/config.py index 1dfd28ed..fb0a1271 100644 --- a/src/ccbot/config.py +++ b/src/ccbot/config.py @@ -93,6 +93,30 @@ def __init__(self) -> None: os.getenv("CCBOT_SHOW_HIDDEN_DIRS", "").lower() == "true" ) + # Starting directory for the directory browser + self.browse_root = os.getenv("CCBOT_BROWSE_ROOT", "") + + # Memory monitoring (on by default, opt-out with CCBOT_MEMORY_MONITOR=false) + self.memory_monitor_enabled = ( + os.getenv("CCBOT_MEMORY_MONITOR", "true").lower() != "false" + ) + self.memory_warning_mb = float(os.getenv("CCBOT_MEMORY_WARNING_MB", "2048")) + self.memory_check_interval = float( + os.getenv("CCBOT_MEMORY_CHECK_INTERVAL", "10") + ) + + # System-wide memory pressure thresholds (MemAvailable from /proc/meminfo) + # Escalation: warn → interrupt (send Escape) → kill (highest-RSS window) + self.mem_avail_warn_mb = float( + os.getenv("CCBOT_MEM_AVAIL_WARN_MB", "1024") + ) + self.mem_avail_interrupt_mb = float( + os.getenv("CCBOT_MEM_AVAIL_INTERRUPT_MB", "512") + ) + self.mem_avail_kill_mb = float( + os.getenv("CCBOT_MEM_AVAIL_KILL_MB", "256") + ) + # Scrub sensitive vars from os.environ so child processes never inherit them. # Values are already captured in Config attributes above. for var in SENSITIVE_ENV_VARS: @@ -100,12 +124,13 @@ def __init__(self) -> None: logger.debug( "Config initialized: dir=%s, token=%s..., allowed_users=%d, " - "tmux_session=%s, claude_projects_path=%s", + "tmux_session=%s, claude_projects_path=%s, browse_root=%s", self.config_dir, self.telegram_bot_token[:8], len(self.allowed_users), self.tmux_session_name, self.claude_projects_path, + self.browse_root, ) def is_user_allowed(self, user_id: int) -> bool: diff --git a/src/ccbot/handlers/directory_browser.py b/src/ccbot/handlers/directory_browser.py index a9e724d7..bb0ddcad 100644 --- a/src/ccbot/handlers/directory_browser.py +++ b/src/ccbot/handlers/directory_browser.py @@ -112,7 +112,7 @@ def build_directory_browser( """ path = Path(current_path).expanduser().resolve() if not path.exists() or not path.is_dir(): - path = Path.cwd() + path = Path(config.browse_root) if config.browse_root else Path.cwd() try: subdirs = sorted( diff --git a/src/ccbot/handlers/interactive_ui.py b/src/ccbot/handlers/interactive_ui.py index 174e3a9e..c90c93dc 100644 --- a/src/ccbot/handlers/interactive_ui.py +++ b/src/ccbot/handlers/interactive_ui.py @@ -15,8 +15,10 @@ """ import logging +import time from telegram import Bot, InlineKeyboardButton, InlineKeyboardMarkup +from telegram.error import BadRequest, RetryAfter from ..session import session_manager from ..terminal_parser import extract_interactive_content, is_interactive_ui @@ -45,6 +47,21 @@ # Track interactive mode: (user_id, thread_id_or_0) -> window_id _interactive_mode: dict[tuple[int, int], str] = {} +# Deduplication: monotonic timestamp of last new interactive message send +_last_interactive_send: dict[tuple[int, int], float] = {} +_INTERACTIVE_DEDUP_WINDOW = 2.0 # seconds — suppress duplicate sends within this window + +# Generation counter: incremented on every state transition (set/clear) so that +# stale callables enqueued by the JSONL monitor can detect invalidation. +_interactive_generation: dict[tuple[int, int], int] = {} + + +def _next_generation(ikey: tuple[int, int]) -> int: + """Increment and return the generation counter for this user/thread.""" + gen = _interactive_generation.get(ikey, 0) + 1 + _interactive_generation[ikey] = gen + return gen + def get_interactive_window(user_id: int, thread_id: int | None = None) -> str | None: """Get the window_id for user's interactive mode.""" @@ -55,21 +72,25 @@ def set_interactive_mode( user_id: int, window_id: str, thread_id: int | None = None, -) -> None: - """Set interactive mode for a user.""" +) -> int: + """Set interactive mode for a user. Returns the generation counter.""" + ikey = (user_id, thread_id or 0) logger.debug( "Set interactive mode: user=%d, window_id=%s, thread=%s", user_id, window_id, thread_id, ) - _interactive_mode[(user_id, thread_id or 0)] = window_id + _interactive_mode[ikey] = window_id + return _next_generation(ikey) def clear_interactive_mode(user_id: int, thread_id: int | None = None) -> None: """Clear interactive mode for a user (without deleting message).""" + ikey = (user_id, thread_id or 0) logger.debug("Clear interactive mode: user=%d, thread=%s", user_id, thread_id) - _interactive_mode.pop((user_id, thread_id or 0), None) + _interactive_mode.pop(ikey, None) + _next_generation(ikey) def get_interactive_msg_id(user_id: int, thread_id: int | None = None) -> int | None: @@ -145,14 +166,36 @@ async def handle_interactive_ui( user_id: int, window_id: str, thread_id: int | None = None, + expected_generation: int | None = None, ) -> bool: """Capture terminal and send interactive UI content to user. Handles AskUserQuestion, ExitPlanMode, Permission Prompt, and RestoreCheckpoint UIs. Returns True if UI was detected and sent, False otherwise. + + If *expected_generation* is provided (from the JSONL monitor path), + the function checks that the current generation still matches before + proceeding. This prevents stale callables from acting after the + interactive mode has been cleared or superseded. """ ikey = (user_id, thread_id or 0) + + # Generation guard: if caller provided an expected generation and it + # doesn't match the current one, this callable is stale — bail out. + if expected_generation is not None: + current_gen = _interactive_generation.get(ikey, 0) + if current_gen != expected_generation: + logger.debug( + "Stale interactive UI callable: user=%d, thread=%s, " + "expected_gen=%d, current_gen=%d — skipping", + user_id, + thread_id, + expected_generation, + current_gen, + ) + return False + chat_id = session_manager.resolve_chat_id(user_id, thread_id) w = await tmux_manager.find_window_by_id(window_id) if not w: @@ -202,14 +245,54 @@ async def handle_interactive_ui( ) _interactive_mode[ikey] = window_id return True - except Exception: - # Edit failed (message deleted, etc.) - clear stale msg_id and send new + except RetryAfter: + raise + except BadRequest as e: + if "is not modified" in str(e).lower(): + # Content identical to what's already displayed — treat as success. + _interactive_mode[ikey] = window_id + return True + # Any other BadRequest (e.g. message deleted, too old to edit): + # clear stale state and try to remove the orphan message. + logger.debug( + "Edit failed for interactive msg %s (%s), sending new", + existing_msg_id, + e, + ) + _interactive_msgs.pop(ikey, None) + try: + await bot.delete_message(chat_id=chat_id, message_id=existing_msg_id) + except Exception: + pass # Already deleted or too old — ignore. + # Fall through to send new message + except Exception as e: + # NetworkError, TimedOut, Forbidden, etc. — message state is uncertain; + # discard the stale ID and fall through to send a fresh message. logger.debug( - "Edit failed for interactive msg %s, sending new", existing_msg_id + "Edit failed (%s) for interactive msg %s, sending new", + e, + existing_msg_id, ) _interactive_msgs.pop(ikey, None) # Fall through to send new message + # Dedup guard: prevent both JSONL monitor and status poller from sending + # a new interactive message in the same short window. No await between + # check and set, so this is atomic in the asyncio event loop. + last_send = _last_interactive_send.get(ikey, 0.0) + now = time.monotonic() + if now - last_send < _INTERACTIVE_DEDUP_WINDOW: + logger.debug( + "Dedup: skipping duplicate interactive UI send " + "(user=%d, thread=%s, %.1fs since last)", + user_id, + thread_id, + now - last_send, + ) + _interactive_mode[ikey] = window_id + return True + _last_interactive_send[ikey] = now + # Send new message (plain text — terminal content is not markdown) logger.info( "Sending interactive UI to user %d for window_id %s", user_id, window_id @@ -222,7 +305,11 @@ async def handle_interactive_ui( link_preview_options=NO_LINK_PREVIEW, **thread_kwargs, # type: ignore[arg-type] ) + except RetryAfter: + _last_interactive_send.pop(ikey, None) + raise except Exception as e: + _last_interactive_send.pop(ikey, None) logger.error("Failed to send interactive UI: %s", e) return False if sent: @@ -241,6 +328,8 @@ async def clear_interactive_msg( ikey = (user_id, thread_id or 0) msg_id = _interactive_msgs.pop(ikey, None) _interactive_mode.pop(ikey, None) + _last_interactive_send.pop(ikey, None) + _next_generation(ikey) logger.debug( "Clear interactive msg: user=%d, thread=%s, msg_id=%s", user_id, diff --git a/src/ccbot/handlers/message_queue.py b/src/ccbot/handlers/message_queue.py index bdd28038..53ff3b63 100644 --- a/src/ccbot/handlers/message_queue.py +++ b/src/ccbot/handlers/message_queue.py @@ -20,8 +20,9 @@ import asyncio import logging import time +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field -from typing import Literal +from typing import Any, Literal from telegram import Bot from telegram.constants import ChatAction @@ -55,7 +56,7 @@ def _ensure_formatted(text: str) -> str: class MessageTask: """Message task for queue processing.""" - task_type: Literal["content", "status_update", "status_clear"] + task_type: Literal["content", "status_update", "status_clear", "callable"] text: str | None = None window_id: str | None = None # content type fields @@ -64,6 +65,13 @@ class MessageTask: content_type: str = "text" thread_id: int | None = None # Telegram topic thread_id for targeted send image_data: list[tuple[str, bytes]] | None = None # From tool_result images + # callable task: a zero-argument coroutine factory executed in-order by the + # worker. Must be a factory (not a bare coroutine object) so the worker can + # safely retry by calling it again — a coroutine can only be awaited once. + callable_fn: Callable[[], Coroutine[Any, Any, None]] | None = None + # Number of times this task has been re-queued after a long RetryAfter. + # Prevents infinite re-queue loops under persistent rate limiting. + requeue_count: int = 0 # Per-user message queues and worker tasks @@ -84,6 +92,12 @@ class MessageTask: # Max seconds to wait for flood control before dropping tasks FLOOD_CONTROL_MAX_WAIT = 10 +# Maximum number of RetryAfter retries per task before giving up +MAX_TASK_RETRIES = 3 + +# Maximum number of times a task can be re-queued after long RetryAfter +MAX_REQUEUE_COUNT = 5 + def get_message_queue(user_id: int) -> asyncio.Queue[MessageTask] | None: """Get the message queue for a user (if exists).""" @@ -144,10 +158,11 @@ async def _merge_content_tasks( additional tasks merged (0 if no merging occurred). Note on queue counter management: - When we put items back, we call task_done() to compensate for the - internal counter increment caused by put_nowait(). This is necessary - because the items were already counted when originally enqueued. - Without this compensation, queue.join() would wait indefinitely. + put_nowait() unconditionally increments _unfinished_tasks. + When we put items back, they already hold a counter slot from when + they were first enqueued, so the compensating task_done() removes the + duplicate increment added by put_nowait(). Without this, _unfinished_tasks + would leak by len(remaining) per merge cycle. """ merged_parts = list(first.parts) current_length = sum(len(p) for p in merged_parts) @@ -212,10 +227,10 @@ async def _message_queue_worker(bot: Bot, user_id: int) -> None: if flood_end > 0: remaining = flood_end - time.monotonic() if remaining > 0: - if task.task_type != "content": + if task.task_type in ("status_update", "status_clear"): # Status is ephemeral — safe to drop continue - # Content is actual Claude output — wait then send + # Content and callable tasks must not be dropped — wait logger.debug( "Flood controlled: waiting %.0fs for content (user %d)", remaining, @@ -226,42 +241,87 @@ async def _message_queue_worker(bot: Bot, user_id: int) -> None: _flood_until.pop(user_id, None) logger.info("Flood control lifted for user %d", user_id) + # Retry loop: retry the task on RetryAfter up to MAX_TASK_RETRIES times. + # Merging is done once before the loop so that merged_task is reused on + # every retry attempt rather than re-merging from a now-empty queue. if task.task_type == "content": - # Try to merge consecutive content tasks merged_task, merge_count = await _merge_content_tasks( queue, task, lock ) if merge_count > 0: logger.debug(f"Merged {merge_count} tasks for user {user_id}") - # Mark merged tasks as done for _ in range(merge_count): queue.task_done() - await _process_content_task(bot, user_id, merged_task) - elif task.task_type == "status_update": - await _process_status_update_task(bot, user_id, task) - elif task.task_type == "status_clear": - await _do_clear_status_message(bot, user_id, task.thread_id or 0) - except RetryAfter as e: - retry_secs = ( - e.retry_after - if isinstance(e.retry_after, int) - else int(e.retry_after.total_seconds()) - ) - if retry_secs > FLOOD_CONTROL_MAX_WAIT: - _flood_until[user_id] = time.monotonic() + retry_secs - logger.warning( - "Flood control for user %d: retry_after=%ds, " - "pausing queue until ban expires", - user_id, - retry_secs, - ) else: - logger.warning( - "Flood control for user %d: waiting %ds", - user_id, - retry_secs, - ) - await asyncio.sleep(retry_secs) + merged_task = task + merge_count = 0 + + for attempt in range(MAX_TASK_RETRIES + 1): + try: + if merged_task.task_type == "content": + await _process_content_task(bot, user_id, merged_task) + elif merged_task.task_type == "status_update": + await _process_status_update_task(bot, user_id, merged_task) + elif merged_task.task_type == "status_clear": + await _do_clear_status_message( + bot, user_id, merged_task.thread_id or 0 + ) + elif merged_task.task_type == "callable": + if merged_task.callable_fn is not None: + await merged_task.callable_fn() + break # Success — exit retry loop + except RetryAfter as e: + retry_secs = ( + e.retry_after + if isinstance(e.retry_after, int) + else int(e.retry_after.total_seconds()) + ) + if retry_secs > FLOOD_CONTROL_MAX_WAIT: + _flood_until[user_id] = time.monotonic() + retry_secs + if merged_task.requeue_count >= MAX_REQUEUE_COUNT: + logger.error( + "Dropping task for user %d after %d re-queues " + "(persistent flood control, task_type=%s)", + user_id, + merged_task.requeue_count, + merged_task.task_type, + ) + break + merged_task.requeue_count += 1 + logger.warning( + "Flood control for user %d: retry_after=%ds, " + "re-queuing task (requeue %d/%d)", + user_id, + retry_secs, + merged_task.requeue_count, + MAX_REQUEUE_COUNT, + ) + # Re-queue so the task is retried once the ban + # expires. put_nowait increments _unfinished_tasks + # for the new slot; the outer finally calls + # task_done() for the slot consumed by dequeuing, + # so the net counter change is zero. + queue.put_nowait(merged_task) + break # Let the flood-control path handle re-queued task + if attempt < MAX_TASK_RETRIES: + logger.warning( + "RetryAfter for user %d: waiting %ds (attempt %d/%d)", + user_id, + retry_secs, + attempt + 1, + MAX_TASK_RETRIES, + ) + await asyncio.sleep(retry_secs) + # Loop back and retry the same task + else: + logger.error( + "Dropping task for user %d after %d retries " + "(last retry_after=%ds, task_type=%s)", + user_id, + MAX_TASK_RETRIES, + retry_secs, + merged_task.task_type, + ) except Exception as e: logger.error(f"Error processing message task for user {user_id}: {e}") finally: @@ -381,8 +441,14 @@ async def _process_content_task(bot: Bot, user_id: int, task: MessageTask) -> No # 4. Send images if present (from tool_result with base64 image blocks) await _send_task_images(bot, chat_id, task) - # 5. After content, check and send status - await _check_and_send_status(bot, user_id, wid, task.thread_id) + # 5. After content, check and send status. + # Catch RetryAfter here: the status message is cosmetic and must never + # propagate RetryAfter to the outer retry loop (which would re-send all + # content messages as duplicates). + try: + await _check_and_send_status(bot, user_id, wid, task.thread_id) + except RetryAfter: + pass async def _convert_status_to_content( @@ -475,7 +541,9 @@ async def _process_status_update_task( if "esc to interrupt" in status_text.lower(): try: await bot.send_chat_action( - chat_id=chat_id, action=ChatAction.TYPING + chat_id=chat_id, + action=ChatAction.TYPING, + message_thread_id=task.thread_id, ) except RetryAfter: raise @@ -534,7 +602,11 @@ async def _do_send_status_message( # Send typing indicator when Claude is working if "esc to interrupt" in text.lower(): try: - await bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) + await bot.send_chat_action( + chat_id=chat_id, + action=ChatAction.TYPING, + message_thread_id=thread_id, + ) except RetryAfter: raise except Exception: @@ -661,6 +733,28 @@ async def enqueue_status_update( queue.put_nowait(task) +def enqueue_callable( + bot: Bot, + user_id: int, + coro_factory: Callable[[], Coroutine[Any, Any, None]], +) -> None: + """Enqueue a coroutine factory for in-order execution by the queue worker. + + *coro_factory* is a zero-argument callable that returns a new coroutine each + time it is called. The worker calls the factory on each attempt so that + retries after ``RetryAfter`` work correctly (a bare coroutine object can + only be awaited once). + + Typically this is just an async function reference (not its invocation):: + + enqueue_callable(bot, uid, my_async_fn) # correct — factory + enqueue_callable(bot, uid, my_async_fn()) # WRONG — bare coroutine + """ + queue = get_or_create_queue(bot, user_id) + task = MessageTask(task_type="callable", callable_fn=coro_factory) + queue.put_nowait(task) + + def clear_status_msg_info(user_id: int, thread_id: int | None = None) -> None: """Clear status message tracking for a user (and optionally a specific thread).""" skey = (user_id, thread_id or 0) diff --git a/src/ccbot/handlers/response_builder.py b/src/ccbot/handlers/response_builder.py index 41b7d0c9..7e7a1953 100644 --- a/src/ccbot/handlers/response_builder.py +++ b/src/ccbot/handlers/response_builder.py @@ -33,7 +33,7 @@ def build_response_parts( # User messages: add emoji prefix (no newline) if role == "user": - prefix = "👤 " + prefix = "💎 " separator = "" # User messages are typically short, no special processing needed if len(text) > 3000: @@ -55,11 +55,18 @@ def build_response_parts( # Format based on content type if content_type == "thinking": - # Thinking: prefix with "∴ Thinking…" and single newline - prefix = "∴ Thinking…" + # Thinking: purple prefix + prefix = "🧠 Thinking…" separator = "\n" + elif content_type in ("tool_use", "tool_result"): + # Tool calls: orange prefix + prefix = "🛠️" + separator = " " + elif content_type == "text": + # Assistant text: green prefix + prefix = "🔮" + separator = " " else: - # Plain text: no prefix prefix = "" separator = "" diff --git a/src/ccbot/handlers/status_polling.py b/src/ccbot/handlers/status_polling.py index c4de1c6e..4f2faa3e 100644 --- a/src/ccbot/handlers/status_polling.py +++ b/src/ccbot/handlers/status_polling.py @@ -5,15 +5,18 @@ - Detects interactive UIs (permission prompts) not triggered via JSONL - Updates status messages in Telegram - Polls thread_bindings (each topic = one window) - - Periodically probes topic existence via unpin_all_forum_topic_messages - (silent no-op when no pins); cleans up deleted topics (kills tmux window - + unbinds thread) + - Periodically probes topic existence via send_chat_action (TYPING); + raises BadRequest on deleted topics; cleans up deleted topics (kills + tmux window + unbinds thread) + - Proactive OOM prevention: system-wide MemAvailable monitoring with + escalating actions (warn → interrupt → kill highest-RSS window) Key components: - STATUS_POLL_INTERVAL: Polling frequency (1 second) - TOPIC_CHECK_INTERVAL: Topic existence probe frequency (60 seconds) - status_poll_loop: Background polling task - update_status_message: Poll and enqueue status updates + - _check_system_memory: Escalating system memory pressure response """ import asyncio @@ -21,21 +24,44 @@ import time from telegram import Bot +from telegram.constants import ChatAction from telegram.error import BadRequest +from ..config import config +from ..process_info import ( + get_mem_available_mb, + get_process_tree_pids, + get_tree_rss_mb, + was_pid_oom_killed, +) from ..session import session_manager from ..terminal_parser import is_interactive_ui, parse_status_line from ..tmux_manager import tmux_manager from .interactive_ui import ( clear_interactive_msg, + get_interactive_msg_id, get_interactive_window, handle_interactive_ui, ) from .cleanup import clear_topic_state from .message_queue import enqueue_status_update, get_message_queue +from .message_sender import safe_send logger = logging.getLogger(__name__) +# Track pane PIDs so we can check OOM after window death +_window_pids: dict[str, int] = {} # window_id → shell PID + +# Per-window memory monitoring state +_last_memory_check: float = 0.0 +_memory_warned: set[str] = set() # window_ids that have been warned + +# System-wide memory pressure escalation state +# Levels: 0=ok, 1=warned, 2=interrupted, 3=killed +_sys_mem_level: int = 0 +_sys_mem_cycles_at_level: int = 0 # how many check cycles at current level +_last_sys_mem_check: float = 0.0 + # Status polling interval STATUS_POLL_INTERVAL = 1.0 # seconds - faster response (rate limiting at send layer) @@ -93,6 +119,15 @@ async def update_status_message( # Check for permission prompt (interactive UI not triggered via JSONL) # ALWAYS check UI, regardless of skip_status if should_check_new_ui and is_interactive_ui(pane_text): + # Skip if another path (e.g. JSONL monitor) already sent an interactive + # message for this user/thread — avoids duplicate messages + if get_interactive_msg_id(user_id, thread_id): + logger.debug( + "Interactive UI already tracked for user=%d thread=%s, skipping", + user_id, + thread_id, + ) + return logger.debug( "Interactive UI detected in polling (user=%d, window=%s, thread=%s)", user_id, @@ -129,13 +164,12 @@ async def status_poll_loop(bot: Bot) -> None: now = time.monotonic() if now - last_topic_check >= TOPIC_CHECK_INTERVAL: last_topic_check = now - for user_id, thread_id, wid in list( - session_manager.iter_thread_bindings() - ): + for user_id, thread_id, wid in session_manager.all_thread_bindings(): try: - await bot.unpin_all_forum_topic_messages( + await bot.send_chat_action( chat_id=session_manager.resolve_chat_id(user_id, thread_id), message_thread_id=thread_id, + action=ChatAction.TYPING, ) except BadRequest as e: if "Topic_id_invalid" in str(e): @@ -165,11 +199,77 @@ async def status_poll_loop(bot: Bot) -> None: e, ) - for user_id, thread_id, wid in list(session_manager.iter_thread_bindings()): + # Fresh snapshot — reflects any unbinds from the topic probe above, + # so bindings cleaned there are naturally excluded. + for user_id, thread_id, wid in session_manager.all_thread_bindings(): try: # Clean up stale bindings (window no longer exists) w = await tmux_manager.find_window_by_id(wid) if not w: + display_name = session_manager.get_display_name(wid) + shell_pid = _window_pids.pop(wid, None) + _memory_warned.discard(wid) + + # Check if the window was killed by OOM + reason = "Session ended" + if shell_pid: + try: + descendants = await get_process_tree_pids(shell_pid) + except Exception: + descendants = [] + oom_info = await was_pid_oom_killed(shell_pid, descendants) + if oom_info: + rss_kb = oom_info.get("rss_kb") + rss_str = ( + f", RSS: {int(str(rss_kb)) // 1024}MB" + if rss_kb + else "" + ) + reason = ( + f"Session killed by OOM killer " + f"(process: {oom_info['process_name']}" + f"{rss_str})" + ) + logger.warning( + "OOM kill detected for window %s (pid=%d): %s", + wid, + shell_pid, + oom_info.get("line", ""), + ) + + # Extract last-activity context from JSONL + death_context = "" + try: + death_context = ( + await session_manager.get_session_death_context(wid) + ) + except Exception as e: + logger.debug( + "Failed to get death context for %s: %s", + wid, + e, + ) + + # Notify user in the topic + try: + chat_id = session_manager.resolve_chat_id( + user_id, thread_id + ) + msg = f"\u26a0\ufe0f {reason}: {display_name}" + if death_context: + msg += f"\n\n{death_context}" + await safe_send( + bot, + chat_id, + msg, + message_thread_id=thread_id, + ) + except Exception as e: + logger.debug( + "Failed to send session-end notification: %s", + e, + ) + session_manager.unbind_thread(user_id, thread_id) await clear_topic_state(user_id, thread_id, bot) logger.info( @@ -180,6 +280,12 @@ async def status_poll_loop(bot: Bot) -> None: ) continue + # Track pane PID for OOM detection on death + if wid not in _window_pids: + pid = await tmux_manager.get_pane_pid(wid) + if pid: + _window_pids[wid] = pid + # UI detection happens unconditionally in update_status_message. # Status enqueue is skipped inside update_status_message when # interactive UI is detected (returns early) or when queue is non-empty. @@ -198,7 +304,257 @@ async def status_poll_loop(bot: Bot) -> None: f"Status update error for user {user_id} " f"thread {thread_id}: {e}" ) + + # Periodic memory monitoring + await _check_memory_usage(bot) + await _check_system_memory(bot) except Exception as e: logger.error(f"Status poll loop error: {e}") await asyncio.sleep(STATUS_POLL_INTERVAL) + + +async def _check_memory_usage(bot: Bot) -> None: + """Check memory usage of all tracked windows and warn if above threshold.""" + global _last_memory_check + + if not config.memory_monitor_enabled: + return + + now = time.monotonic() + if now - _last_memory_check < config.memory_check_interval: + return + _last_memory_check = now + + for user_id, thread_id, wid in session_manager.all_thread_bindings(): + pid = _window_pids.get(wid) + if not pid: + continue + try: + rss_mb = await get_tree_rss_mb(pid) + if rss_mb is None: + continue + + if rss_mb > config.memory_warning_mb and wid not in _memory_warned: + _memory_warned.add(wid) + display_name = session_manager.get_display_name(wid) + chat_id = session_manager.resolve_chat_id(user_id, thread_id) + await safe_send( + bot, + chat_id, + f"\u26a0\ufe0f High memory usage: {display_name} " + f"is using {rss_mb:.0f}MB RSS", + message_thread_id=thread_id, + ) + logger.warning( + "Memory warning for window %s (pid=%d): %.0fMB", + wid, + pid, + rss_mb, + ) + elif rss_mb <= config.memory_warning_mb * 0.8 and wid in _memory_warned: + # Reset warning when memory drops to 80% of threshold + _memory_warned.discard(wid) + except Exception as e: + logger.debug("Memory check error for window %s: %s", wid, e) + + +async def _find_highest_rss_window() -> ( + tuple[str, int, int, float] | None +): + """Find the window with the highest process tree RSS. + + Returns (window_id, user_id, thread_id, rss_mb) or None. + """ + highest: tuple[str, int, int, float] | None = None + for user_id, thread_id, wid in session_manager.all_thread_bindings(): + pid = _window_pids.get(wid) + if not pid: + continue + try: + rss_mb = await get_tree_rss_mb(pid) + if rss_mb is not None and (highest is None or rss_mb > highest[3]): + highest = (wid, user_id, thread_id, rss_mb) + except Exception: + continue + return highest + + +# Cooldown constants (in check cycles, not seconds — cycle = memory_check_interval) +# With default CCBOT_MEMORY_CHECK_INTERVAL=10: 2 cycles ≈ 20s, 3 cycles ≈ 30s +_INTERRUPT_COOLDOWN_CYCLES = 2 # wait 2 cycles after interrupt before kill +_KILL_COOLDOWN_CYCLES = 3 # wait 3 cycles after kill before another kill + + +async def _check_system_memory(bot: Bot) -> None: + """Check system-wide MemAvailable and escalate if memory is critically low. + + Escalation levels (advances at most one level per check cycle): + 0 → normal + 1 → warn (notify all topics) + 2 → interrupt (send Escape to highest-RSS window) + 3 → kill (kill highest-RSS window) + """ + global _sys_mem_level, _sys_mem_cycles_at_level, _last_sys_mem_check + + if not config.memory_monitor_enabled: + return + + now = time.monotonic() + if now - _last_sys_mem_check < config.memory_check_interval: + return + _last_sys_mem_check = now + + available = await get_mem_available_mb() + if available is None: + return # /proc/meminfo not readable (non-Linux) + + # Determine target level based on available memory + if available <= config.mem_avail_kill_mb: + target_level = 3 + elif available <= config.mem_avail_interrupt_mb: + target_level = 2 + elif available <= config.mem_avail_warn_mb: + target_level = 1 + else: + target_level = 0 + + # Recovery: reset when memory is well above warn threshold (hysteresis) + if available > config.mem_avail_warn_mb * 1.5: + if _sys_mem_level > 0: + logger.info( + "System memory recovered: %.0fMB available, resetting escalation", + available, + ) + _sys_mem_level = 0 + _sys_mem_cycles_at_level = 0 + return + + # No escalation needed + if target_level == 0: + _sys_mem_level = 0 + _sys_mem_cycles_at_level = 0 + return + + # Track cycles at current level + _sys_mem_cycles_at_level += 1 + + # Advance at most one level per cycle + new_level = min(target_level, _sys_mem_level + 1) + + # Enforce cooldowns before escalating + if new_level == 3 and _sys_mem_level == 2: + if _sys_mem_cycles_at_level < _INTERRUPT_COOLDOWN_CYCLES: + logger.debug( + "Interrupt cooldown: %d/%d cycles", + _sys_mem_cycles_at_level, + _INTERRUPT_COOLDOWN_CYCLES, + ) + return + if new_level == 3 and _sys_mem_level == 3: + if _sys_mem_cycles_at_level < _KILL_COOLDOWN_CYCLES: + logger.debug( + "Kill cooldown: %d/%d cycles", + _sys_mem_cycles_at_level, + _KILL_COOLDOWN_CYCLES, + ) + return + + # Downgrade level when pressure has eased + if new_level < _sys_mem_level: + _sys_mem_level = new_level + _sys_mem_cycles_at_level = 0 + return + + # Only act when level actually changes (or kill repeats) + if new_level <= _sys_mem_level and new_level < 3: + return + + # Reset cycle counter on level change + if new_level != _sys_mem_level: + _sys_mem_cycles_at_level = 0 + _sys_mem_level = new_level + + # === Level 1: Warn all topics === + if new_level == 1: + logger.warning("System memory low: %.0fMB available", available) + for user_id, thread_id, _wid in session_manager.all_thread_bindings(): + try: + chat_id = session_manager.resolve_chat_id(user_id, thread_id) + await safe_send( + bot, + chat_id, + f"\u26a0\ufe0f System memory low: {available:.0f}MB available. " + "Consider reducing parallel workloads.", + message_thread_id=thread_id, + ) + except Exception as e: + logger.debug("Failed to send memory warning: %s", e) + + # === Level 2: Interrupt highest-RSS window === + elif new_level == 2: + target = await _find_highest_rss_window() + if not target: + logger.warning( + "Memory critical (%.0fMB available) but no windows to interrupt", + available, + ) + return + wid, user_id, thread_id, rss_mb = target + display_name = session_manager.get_display_name(wid) + logger.warning( + "Memory critical (%.0fMB available) — interrupting %s (%.0fMB RSS)", + available, + display_name, + rss_mb, + ) + await tmux_manager.send_keys(wid, "Escape", enter=False, literal=False) + try: + chat_id = session_manager.resolve_chat_id(user_id, thread_id) + await safe_send( + bot, + chat_id, + f"\u26a0\ufe0f Memory critical ({available:.0f}MB available) " + f"\u2014 interrupted {display_name} ({rss_mb:.0f}MB RSS)", + message_thread_id=thread_id, + ) + except Exception as e: + logger.debug("Failed to send interrupt notification: %s", e) + + # === Level 3: Kill highest-RSS window === + elif new_level == 3: + target = await _find_highest_rss_window() + if not target: + logger.warning( + "Memory emergency (%.0fMB available) but no windows to kill", + available, + ) + _sys_mem_level = 0 + _sys_mem_cycles_at_level = 0 + return + wid, user_id, thread_id, rss_mb = target + display_name = session_manager.get_display_name(wid) + logger.error( + "Memory emergency (%.0fMB available) — killing %s (%.0fMB RSS)", + available, + display_name, + rss_mb, + ) + await tmux_manager.kill_window(wid) + _window_pids.pop(wid, None) + _memory_warned.discard(wid) + try: + chat_id = session_manager.resolve_chat_id(user_id, thread_id) + await safe_send( + bot, + chat_id, + f"\U0001f6a8 Memory emergency ({available:.0f}MB available) " + f"\u2014 killed {display_name} ({rss_mb:.0f}MB RSS) " + "to prevent system OOM", + message_thread_id=thread_id, + ) + except Exception as e: + logger.debug("Failed to send kill notification: %s", e) + session_manager.unbind_thread(user_id, thread_id) + await clear_topic_state(user_id, thread_id, bot) + _sys_mem_cycles_at_level = 0 # reset for next kill cooldown diff --git a/src/ccbot/process_info.py b/src/ccbot/process_info.py new file mode 100644 index 00000000..6b3cc1e0 --- /dev/null +++ b/src/ccbot/process_info.py @@ -0,0 +1,181 @@ +"""Process tree inspection, memory usage, and OOM-kill detection. + +Provides Linux-specific helpers that read /proc and dmesg to: + - Walk a process's descendant tree + - Read RSS memory (VmRSS) for a process or its entire tree + - Detect recent OOM kills and correlate them with specific PIDs + +All functions are async-friendly (blocking I/O wrapped in to_thread). +""" + +from __future__ import annotations + +import asyncio +import logging +import re +from pathlib import Path + +logger = logging.getLogger(__name__) + + +async def get_child_pids(pid: int) -> list[int]: + """Get direct child PIDs of a process via /proc/[pid]/task/[tid]/children.""" + children: list[int] = [] + task_dir = Path(f"/proc/{pid}/task") + try: + if not task_dir.exists(): + return children + for tid_dir in task_dir.iterdir(): + children_file = tid_dir / "children" + if children_file.exists(): + text = children_file.read_text().strip() + if text: + children.extend(int(p) for p in text.split()) + except (OSError, ValueError): + pass + return children + + +async def get_process_tree_pids(root_pid: int) -> list[int]: + """Get all descendant PIDs of a process (breadth-first).""" + all_pids: list[int] = [] + queue = [root_pid] + while queue: + pid = queue.pop(0) + kids = await get_child_pids(pid) + all_pids.extend(kids) + queue.extend(kids) + return all_pids + + +async def get_process_rss_mb(pid: int) -> float | None: + """Read VmRSS from /proc/[pid]/status. Returns MB or None if unavailable.""" + try: + status_file = Path(f"/proc/{pid}/status") + if not status_file.exists(): + return None + text = await asyncio.to_thread(status_file.read_text) + for line in text.splitlines(): + if line.startswith("VmRSS:"): + # Format: "VmRSS: 12345 kB" + parts = line.split() + if len(parts) >= 2: + return int(parts[1]) / 1024.0 + except (OSError, ValueError): + pass + return None + + +async def get_tree_rss_mb(root_pid: int) -> float | None: + """Sum RSS of a process and all its descendants. Returns MB or None.""" + root_rss = await get_process_rss_mb(root_pid) + if root_rss is None: + return None + + total = root_rss + descendants = await get_process_tree_pids(root_pid) + for pid in descendants: + rss = await get_process_rss_mb(pid) + if rss is not None: + total += rss + return total + + +def _parse_dmesg_for_oom_kills(dmesg_output: str) -> list[dict[str, object]]: + """Parse dmesg text for OOM kill entries. + + Returns list of dicts with keys: pid, process_name, total_pages, rss_kb. + """ + results: list[dict[str, object]] = [] + # Kernel OOM killer log pattern (varies by kernel version): + # "Killed process 1234 (python3) total-vm:123456kB, ..." + # "Out of memory: Killed process 1234 (python3) ..." + pattern = re.compile( + r"Killed process (\d+) \(([^)]+)\)" + r"(?:.*?total-vm:(\d+)kB)?" + r"(?:.*?anon-rss:(\d+)kB)?", + ) + for line in dmesg_output.splitlines(): + m = pattern.search(line) + if m: + results.append( + { + "pid": int(m.group(1)), + "process_name": m.group(2), + "total_vm_kb": int(m.group(3)) if m.group(3) else None, + "rss_kb": int(m.group(4)) if m.group(4) else None, + "line": line.strip(), + } + ) + return results + + +async def check_recent_oom_kills() -> list[dict[str, object]]: + """Check dmesg for OOM kill entries. + + Returns all OOM kills found in the current dmesg buffer. + Best-effort: returns empty list if dmesg is inaccessible. + """ + try: + proc = await asyncio.create_subprocess_exec( + "dmesg", + "--level=err,warn,info", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + if proc.returncode != 0: + # Try without --level flag (older dmesg versions) + proc = await asyncio.create_subprocess_exec( + "dmesg", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + if proc.returncode != 0: + return [] + return _parse_dmesg_for_oom_kills(stdout.decode("utf-8", errors="replace")) + except Exception as e: + logger.debug("Failed to check dmesg for OOM kills: %s", e) + return [] + + +async def get_mem_available_mb() -> float | None: + """Read MemAvailable from /proc/meminfo. Returns MB or None if unavailable. + + MemAvailable is the kernel's estimate of memory available for new + allocations without swapping — more accurate than MemFree alone. + """ + meminfo = Path("/proc/meminfo") + try: + if not meminfo.exists(): + return None + text = await asyncio.to_thread(meminfo.read_text) + for line in text.splitlines(): + if line.startswith("MemAvailable:"): + # Format: "MemAvailable: 12345678 kB" + parts = line.split() + if len(parts) >= 2: + return int(parts[1]) / 1024.0 + except (OSError, ValueError): + pass + return None + + +async def was_pid_oom_killed( + shell_pid: int, + descendants: list[int] | None = None, +) -> dict[str, object] | None: + """Check if a PID or any of its descendants were OOM-killed. + + Returns the matching OOM kill info dict, or None. + """ + pids_to_check = {shell_pid} + if descendants: + pids_to_check.update(descendants) + + oom_kills = await check_recent_oom_kills() + for kill in oom_kills: + if kill["pid"] in pids_to_check: + return kill + return None diff --git a/src/ccbot/screenshot.py b/src/ccbot/screenshot.py index cbe70f69..c73ef815 100644 --- a/src/ccbot/screenshot.py +++ b/src/ccbot/screenshot.py @@ -104,9 +104,8 @@ def _font_tier(ch: str) -> int: if cp in _SYMBOLA_CODEPOINTS: return 2 # CJK Unified Ideographs + CJK compat + fullwidth forms + Hangul + known Noto-only codepoints - if ( - cp in _NOTO_CODEPOINTS - or cp >= 0x1100 + if cp in _NOTO_CODEPOINTS or ( + cp >= 0x1100 and ( cp <= 0x11FF # Hangul Jamo or 0x2E80 <= cp <= 0x9FFF # CJK radicals, kangxi, ideographs diff --git a/src/ccbot/session.py b/src/ccbot/session.py index c740545c..802820f0 100644 --- a/src/ccbot/session.py +++ b/src/ccbot/session.py @@ -17,27 +17,48 @@ Key class: SessionManager (singleton instantiated as `session_manager`). Key methods for thread binding access: - resolve_window_for_thread: Get window_id for a user's thread - - iter_thread_bindings: Generator for iterating all (user_id, thread_id, window_id) + - all_thread_bindings: Snapshot list of all (user_id, thread_id, window_id) - find_users_for_session: Find all users bound to a session_id """ import asyncio import json import logging +import re from dataclasses import dataclass, field from pathlib import Path -from collections.abc import Iterator from typing import Any import aiofiles from .config import config -from .tmux_manager import tmux_manager +from .tmux_manager import SHELL_COMMANDS, tmux_manager from .transcript_parser import TranscriptParser from .utils import atomic_write_json logger = logging.getLogger(__name__) +# Patterns for detecting Claude Code resume commands in pane output +_RESUME_CMD_RE = re.compile(r"(claude\s+(?:--resume|-r)\s+\S+)") +_STOPPED_RE = re.compile(r"Stopped\s+.*claude", re.IGNORECASE) + + +def _extract_resume_command(pane_text: str) -> str | None: + """Extract a resume command from pane content after Claude Code exit. + + Detects two patterns: + - Suspended process: ``[1]+ Stopped claude`` → returns ``"fg"`` + - Exited with resume hint: ``claude --resume `` → returns the full command + + Returns None if no resume path is detected. + """ + if _STOPPED_RE.search(pane_text): + return "fg" + match = _RESUME_CMD_RE.search(pane_text) + if match: + return match.group(1) + return None + @dataclass class WindowState: @@ -179,7 +200,6 @@ def _load_state(self) -> None: "Detected old-format state (window_name keys), " "will re-resolve on startup" ) - pass except (json.JSONDecodeError, ValueError) as e: logger.warning("Failed to load state: %s", e) @@ -188,7 +208,6 @@ def _load_state(self) -> None: self.thread_bindings = {} self.window_display_names = {} self.group_chat_ids = {} - pass async def resolve_stale_ids(self) -> None: """Re-resolve persisted window IDs against live tmux windows. @@ -472,8 +491,8 @@ async def wait_for_session_map_entry( timeout, ) key = f"{config.tmux_session_name}:{window_id}" - deadline = asyncio.get_event_loop().time() + timeout - while asyncio.get_event_loop().time() < deadline: + deadline = asyncio.get_running_loop().time() + timeout + while asyncio.get_running_loop().time() < deadline: try: if config.session_map_file.exists(): async with aiofiles.open(config.session_map_file, "r") as f: @@ -739,15 +758,18 @@ def resolve_window_for_thread( return None return self.get_window_for_thread(user_id, thread_id) - def iter_thread_bindings(self) -> Iterator[tuple[int, int, str]]: - """Iterate all thread bindings as (user_id, thread_id, window_id). + def all_thread_bindings(self) -> list[tuple[int, int, str]]: + """Return a snapshot of all thread bindings as (user_id, thread_id, window_id). - Provides encapsulated access to thread_bindings without exposing - the internal data structure directly. + Returns a new list each call so callers can safely await between + iterations without risking ``RuntimeError: dictionary changed size + during iteration`` from a concurrent ``unbind_thread`` call. """ - for user_id, bindings in self.thread_bindings.items(): - for thread_id, window_id in bindings.items(): - yield user_id, thread_id, window_id + return [ + (user_id, thread_id, window_id) + for user_id, bindings in self.thread_bindings.items() + for thread_id, window_id in bindings.items() + ] async def find_users_for_session( self, @@ -758,7 +780,7 @@ async def find_users_for_session( Returns list of (user_id, window_id, thread_id) tuples. """ result: list[tuple[int, str, int]] = [] - for user_id, thread_id, window_id in self.iter_thread_bindings(): + for user_id, thread_id, window_id in self.all_thread_bindings(): resolved = await self.resolve_session_for_window(window_id) if resolved and resolved.session_id == session_id: result.append((user_id, window_id, thread_id)) @@ -767,7 +789,11 @@ async def find_users_for_session( # --- Tmux helpers --- async def send_to_window(self, window_id: str, text: str) -> tuple[bool, str]: - """Send text to a tmux window by ID.""" + """Send text to a tmux window by ID. + + If the pane is running a bare shell (Claude Code exited), attempts + to auto-resume via ``fg`` or ``claude --resume `` before sending. + """ display = self.get_display_name(window_id) logger.debug( "send_to_window: window_id=%s (%s), text_len=%d", @@ -778,11 +804,61 @@ async def send_to_window(self, window_id: str, text: str) -> tuple[bool, str]: window = await tmux_manager.find_window_by_id(window_id) if not window: return False, "Window not found (may have been closed)" + if window.pane_current_command in SHELL_COMMANDS: + resumed = await self._try_resume_claude(window_id, display) + if not resumed: + return False, "Claude Code is not running (session exited)" success = await tmux_manager.send_keys(window.window_id, text) if success: return True, f"Sent to {display}" return False, "Failed to send keys" + async def _try_resume_claude(self, window_id: str, display: str) -> bool: + """Attempt to resume Claude Code when pane has dropped to a shell. + + Detects ``fg`` (suspended process) and ``claude --resume `` + (exited session) patterns in the pane content. Sends the appropriate + command and waits for Claude Code to take over the terminal. + + Returns True if Claude Code is running after the attempt. + """ + pane_text = await tmux_manager.capture_pane(window_id) + if not pane_text: + return False + + resume_cmd = _extract_resume_command(pane_text) + if not resume_cmd: + logger.warning( + "No resume command found in %s (%s), cannot auto-resume", + window_id, + display, + ) + return False + + logger.info( + "Auto-resuming Claude Code in %s (%s): %s", + window_id, + display, + resume_cmd, + ) + await tmux_manager.send_keys(window_id, resume_cmd) + + # Wait for Claude Code to take over the terminal + max_wait = 3.0 if resume_cmd == "fg" else 15.0 + elapsed = 0.0 + while elapsed < max_wait: + await asyncio.sleep(0.5) + elapsed += 0.5 + w = await tmux_manager.find_window_by_id(window_id) + if w and w.pane_current_command not in SHELL_COMMANDS: + # Claude Code is running again — give TUI a moment to init + await asyncio.sleep(1.0) + logger.info("Claude Code resumed in %s (%s)", window_id, display) + return True + + logger.warning("Auto-resume timed out for %s (%s)", window_id, display) + return False + # --- Message history --- async def get_recent_messages( @@ -844,5 +920,89 @@ async def get_recent_messages( return all_messages, len(all_messages) + async def get_session_death_context( + self, window_id: str, max_chars: int = 500 + ) -> str: + """Extract last-activity context from a session's JSONL for crash diagnostics. + + Reads the tail of the JSONL file and returns a formatted summary of + what Claude was doing when the session died (last tools, pending + operations, last message text). + + Returns empty string if session or file is unavailable. + """ + session = await self.resolve_session_for_window(window_id) + if not session or not session.file_path: + return "" + + file_path = Path(session.file_path) + if not file_path.exists(): + return "" + + # Read last ~8KB of the JSONL file (efficient tail) + try: + file_size = file_path.stat().st_size + tail_offset = max(0, file_size - 8192) + + entries: list[dict] = [] + async with aiofiles.open(file_path, "r", encoding="utf-8") as f: + if tail_offset > 0: + await f.seek(tail_offset) + await f.readline() # skip partial first line + + while True: + line = await f.readline() + if not line: + break + data = TranscriptParser.parse_line(line) + if data: + entries.append(data) + except OSError as e: + logger.debug("Error reading session file for death context: %s", e) + return "" + + if not entries: + return "" + + parsed_entries, remaining_pending = TranscriptParser.parse_entries(entries) + if not parsed_entries: + return "" + + lines: list[str] = [] + + # Pending tools (were mid-execution when session died) + if remaining_pending: + for tool_id, info in remaining_pending.items(): + lines.append(f"\u2022 Running: {info.summary}") + + # Last tool_use entries (most recent first) + tool_entries = [e for e in parsed_entries if e.content_type == "tool_use"] + if tool_entries and not remaining_pending: + last_tool = tool_entries[-1] + lines.append(f"\u2022 Last tool: {last_tool.text}") + + # Last assistant text message + text_entries = [ + e + for e in parsed_entries + if e.content_type == "text" and e.role == "assistant" + ] + if text_entries: + last_text = text_entries[-1].text.strip() + if last_text: + # Truncate long messages + if len(last_text) > 200: + last_text = last_text[:200] + "..." + lines.append(f'\u2022 Last message: "{last_text}"') + + if not lines: + return "" + + result = "Last activity:\n" + "\n".join(lines) + # Enforce overall length limit + if len(result) > max_chars: + result = result[:max_chars] + "..." + return result + session_manager = SessionManager() diff --git a/src/ccbot/session_monitor.py b/src/ccbot/session_monitor.py index 0a1b3186..9bc7e7e6 100644 --- a/src/ccbot/session_monitor.py +++ b/src/ccbot/session_monitor.py @@ -6,7 +6,7 @@ 3. Reads new JSONL lines from each session file using byte-offset tracking. 4. Parses entries via TranscriptParser and emits NewMessage objects to a callback. -Optimizations: mtime cache skips unchanged files; byte offset avoids re-reading. +Optimizations: file size check skips unchanged files; byte offset avoids re-reading. Key classes: SessionMonitor, NewMessage, SessionInfo. """ @@ -82,8 +82,6 @@ def __init__( # Track last known session_map for detecting changes # Keys may be window_id (@12) or window_name (old format) during transition self._last_session_map: dict[str, str] = {} # window_key -> session_id - # In-memory mtime cache for quick file change detection (not persisted) - self._file_mtimes: dict[str, float] = {} # session_id -> last_seen_mtime def set_message_callback( self, callback: Callable[[NewMessage], Awaitable[None]] @@ -292,41 +290,33 @@ async def check_for_updates(self, active_session_ids: set[str]) -> list[NewMessa # to avoid re-processing old messages try: file_size = session_info.file_path.stat().st_size - current_mtime = session_info.file_path.stat().st_mtime except OSError: file_size = 0 - current_mtime = 0.0 tracked = TrackedSession( session_id=session_info.session_id, file_path=str(session_info.file_path), last_byte_offset=file_size, ) self.state.update_session(tracked) - self._file_mtimes[session_info.session_id] = current_mtime logger.info(f"Started tracking session: {session_info.session_id}") continue - # Check mtime + file size to see if file has changed + # Quick size check: skip reading if file size hasn't changed. + # For append-only JSONL files, size == offset means no new + # content. Size < offset (truncation) and size > offset (new + # data) both need processing — handled inside _read_new_lines. try: - st = session_info.file_path.stat() - current_mtime = st.st_mtime - current_size = st.st_size + current_size = session_info.file_path.stat().st_size except OSError: continue - last_mtime = self._file_mtimes.get(session_info.session_id, 0.0) - if ( - current_mtime <= last_mtime - and current_size <= tracked.last_byte_offset - ): - # File hasn't changed, skip reading + if current_size == tracked.last_byte_offset: continue # File changed, read new content from last offset new_entries = await self._read_new_lines( tracked, session_info.file_path ) - self._file_mtimes[session_info.session_id] = current_mtime if new_entries: logger.debug( @@ -369,7 +359,10 @@ async def check_for_updates(self, active_session_ids: set[str]) -> list[NewMessa except OSError as e: logger.debug(f"Error processing session {session_info.session_id}: {e}") - self.state.save_if_dirty() + # NOTE: save_if_dirty() is intentionally NOT called here. + # Offsets must be persisted only AFTER delivery to Telegram (in + # _monitor_loop) to guarantee at-least-once delivery. Saving before + # delivery would risk silent message loss on crash. return new_messages async def _load_current_session_map(self) -> dict[str, str]: @@ -416,7 +409,7 @@ async def _cleanup_all_stale_sessions(self) -> None: ) for session_id in stale_sessions: self.state.remove_session(session_id) - self._file_mtimes.pop(session_id, None) + self._pending_tools.pop(session_id, None) self.state.save_if_dirty() async def _detect_and_cleanup_changes(self) -> dict[str, str]: @@ -458,7 +451,7 @@ async def _detect_and_cleanup_changes(self) -> dict[str, str]: if sessions_to_remove: for session_id in sessions_to_remove: self.state.remove_session(session_id) - self._file_mtimes.pop(session_id, None) + self._pending_tools.pop(session_id, None) self.state.save_if_dirty() # Update last known map @@ -503,6 +496,12 @@ async def _monitor_loop(self) -> None: except Exception as e: logger.error(f"Message callback error: {e}") + # Persist byte offsets AFTER delivering messages to Telegram. + # This guarantees at-least-once delivery: if the bot crashes + # before this save, messages will be re-read and re-delivered + # on restart (safe duplicate) rather than silently lost. + self.state.save_if_dirty() + except Exception as e: logger.error(f"Monitor loop error: {e}") diff --git a/src/ccbot/tmux_manager.py b/src/ccbot/tmux_manager.py index 84cba5aa..4c6dcb70 100644 --- a/src/ccbot/tmux_manager.py +++ b/src/ccbot/tmux_manager.py @@ -15,6 +15,7 @@ import asyncio import logging +import time from dataclasses import dataclass from pathlib import Path @@ -24,6 +25,22 @@ logger = logging.getLogger(__name__) +# Process names that indicate a bare shell (Claude Code has exited). +# Used to prevent sending user input to a shell prompt. +SHELL_COMMANDS = frozenset( + { + "bash", + "zsh", + "sh", + "fish", + "dash", + "tcsh", + "csh", + "ksh", + "ash", + } +) + @dataclass class TmuxWindow: @@ -38,6 +55,9 @@ class TmuxWindow: class TmuxManager: """Manages tmux windows for Claude Code sessions.""" + # How long cached list_windows results are valid (seconds). + _CACHE_TTL = 1.0 + def __init__(self, session_name: str | None = None): """Initialize tmux manager. @@ -46,6 +66,8 @@ def __init__(self, session_name: str | None = None): """ self.session_name = session_name or config.tmux_session_name self._server: libtmux.Server | None = None + self._windows_cache: list[TmuxWindow] | None = None + self._windows_cache_time: float = 0.0 @property def server(self) -> libtmux.Server: @@ -92,12 +114,23 @@ def _scrub_session_env(session: libtmux.Session) -> None: except Exception: pass # var not set in session env — nothing to remove + def invalidate_cache(self) -> None: + """Invalidate the cached window list (call after mutations).""" + self._windows_cache = None + async def list_windows(self) -> list[TmuxWindow]: """List all windows in the session with their working directories. - Returns: - List of TmuxWindow with window info and cwd + Results are cached for ``_CACHE_TTL`` seconds to avoid hammering + the tmux server when multiple callers need window info in the same + poll cycle. """ + now = time.monotonic() + if ( + self._windows_cache is not None + and (now - self._windows_cache_time) < self._CACHE_TTL + ): + return self._windows_cache def _sync_list_windows() -> list[TmuxWindow]: windows = [] @@ -135,7 +168,10 @@ def _sync_list_windows() -> list[TmuxWindow]: return windows - return await asyncio.to_thread(_sync_list_windows) + result = await asyncio.to_thread(_sync_list_windows) + self._windows_cache = result + self._windows_cache_time = time.monotonic() + return result async def find_window_by_name(self, window_name: str) -> TmuxWindow | None: """Find a window by its name. @@ -169,59 +205,54 @@ async def find_window_by_id(self, window_id: str) -> TmuxWindow | None: logger.debug("Window not found by id: %s", window_id) return None + async def get_pane_pid(self, window_id: str) -> int | None: + """Get the PID of the shell process in a window's active pane.""" + try: + proc = await asyncio.create_subprocess_exec( + "tmux", + "display-message", + "-p", + "-t", + window_id, + "#{pane_pid}", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, _ = await proc.communicate() + if proc.returncode == 0: + pid_str = stdout.decode("utf-8").strip() + if pid_str: + return int(pid_str) + except (OSError, ValueError) as e: + logger.debug("Failed to get pane PID for %s: %s", window_id, e) + return None + async def capture_pane(self, window_id: str, with_ansi: bool = False) -> str | None: """Capture the visible text content of a window's active pane. - Args: - window_id: The window ID to capture - with_ansi: If True, capture with ANSI color codes - - Returns: - The captured text, or None on failure. + Uses a direct ``tmux capture-pane`` subprocess for both plain text + and ANSI modes — avoids the multiple tmux round-trips that libtmux + would generate (list-windows → list-panes → capture-pane). """ + cmd = ["tmux", "capture-pane", "-p", "-t", window_id] if with_ansi: - # Use async subprocess to call tmux capture-pane -e for ANSI colors - try: - proc = await asyncio.create_subprocess_exec( - "tmux", - "capture-pane", - "-e", - "-p", - "-t", - window_id, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await proc.communicate() - if proc.returncode == 0: - return stdout.decode("utf-8") - logger.error( - f"Failed to capture pane {window_id}: {stderr.decode('utf-8')}" - ) - return None - except Exception as e: - logger.error(f"Unexpected error capturing pane {window_id}: {e}") - return None - - # Original implementation for plain text - wrap in thread - def _sync_capture() -> str | None: - session = self.get_session() - if not session: - return None - try: - window = session.windows.get(window_id=window_id) - if not window: - return None - pane = window.active_pane - if not pane: - return None - lines = pane.capture_pane() - return "\n".join(lines) if isinstance(lines, list) else str(lines) - except Exception as e: - logger.error(f"Failed to capture pane {window_id}: {e}") - return None - - return await asyncio.to_thread(_sync_capture) + cmd.insert(2, "-e") + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + if proc.returncode == 0: + return stdout.decode("utf-8") + logger.error( + "Failed to capture pane %s: %s", window_id, stderr.decode("utf-8") + ) + return None + except Exception as e: + logger.error("Unexpected error capturing pane %s: %s", window_id, e) + return None async def send_keys( self, window_id: str, text: str, enter: bool = True, literal: bool = True @@ -326,6 +357,7 @@ def _sync_send_keys() -> bool: async def rename_window(self, window_id: str, new_name: str) -> bool: """Rename a tmux window by its ID.""" + self.invalidate_cache() def _sync_rename() -> bool: session = self.get_session() @@ -346,6 +378,7 @@ def _sync_rename() -> bool: async def kill_window(self, window_id: str) -> bool: """Kill a tmux window by its ID.""" + self.invalidate_cache() def _sync_kill() -> bool: session = self.get_session() @@ -398,6 +431,8 @@ async def create_window( counter += 1 # Create window in thread + self.invalidate_cache() + def _create_and_start() -> tuple[bool, str, str, str]: session = self.get_or_create_session() try: diff --git a/src/ccbot/transcript_parser.py b/src/ccbot/transcript_parser.py index fa0bbf69..8bb4fc1f 100644 --- a/src/ccbot/transcript_parser.py +++ b/src/ccbot/transcript_parser.py @@ -486,6 +486,22 @@ def parse_entries( last_cmd_name = None if msg_type == "assistant": + # Pre-scan: check if this message contains an interactive + # tool_use (ExitPlanMode / AskUserQuestion). When present, + # suppress text entries from this same message — those text + # blocks are preamble that the terminal capture already + # includes. Emitting them as separate content messages + # causes a race: the content message clears the interactive + # UI state set by the status poller, leading to a duplicate + # interactive message being sent by the JSONL callable. + _INTERACTIVE_TOOLS = frozenset({"AskUserQuestion", "ExitPlanMode"}) + has_interactive_tool = any( + isinstance(b, dict) + and b.get("type") == "tool_use" + and b.get("name") in _INTERACTIVE_TOOLS + for b in content + ) + # Process content blocks has_text = False for block in content: @@ -494,6 +510,11 @@ def parse_entries( btype = block.get("type", "") if btype == "text": + # Skip text blocks when an interactive tool_use is + # present in the same message to avoid clearing the + # interactive UI state prematurely. + if has_interactive_tool: + continue t = block.get("text", "").strip() if t and t != cls._NO_CONTENT_PLACEHOLDER: result.append( diff --git a/tests/ccbot/handlers/test_interactive_ui.py b/tests/ccbot/handlers/test_interactive_ui.py index 8d6a98e4..336f9965 100644 --- a/tests/ccbot/handlers/test_interactive_ui.py +++ b/tests/ccbot/handlers/test_interactive_ui.py @@ -32,13 +32,19 @@ def mock_bot(): @pytest.fixture def _clear_interactive_state(): """Ensure interactive state is clean before and after each test.""" - from ccbot.handlers.interactive_ui import _interactive_mode, _interactive_msgs + from ccbot.handlers.interactive_ui import ( + _interactive_mode, + _interactive_msgs, + _last_interactive_send, + ) _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() yield _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() @pytest.mark.usefixtures("_clear_interactive_state") diff --git a/tests/ccbot/handlers/test_status_polling.py b/tests/ccbot/handlers/test_status_polling.py index 9c0f04f7..ad6ec312 100644 --- a/tests/ccbot/handlers/test_status_polling.py +++ b/tests/ccbot/handlers/test_status_polling.py @@ -24,13 +24,19 @@ def mock_bot(): @pytest.fixture def _clear_interactive_state(): """Ensure interactive state is clean before and after each test.""" - from ccbot.handlers.interactive_ui import _interactive_mode, _interactive_msgs + from ccbot.handlers.interactive_ui import ( + _interactive_mode, + _interactive_msgs, + _last_interactive_send, + ) _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() yield _interactive_mode.clear() _interactive_msgs.clear() + _last_interactive_send.clear() @pytest.mark.usefixtures("_clear_interactive_state") diff --git a/tests/ccbot/handlers/test_system_memory.py b/tests/ccbot/handlers/test_system_memory.py new file mode 100644 index 00000000..4dac96af --- /dev/null +++ b/tests/ccbot/handlers/test_system_memory.py @@ -0,0 +1,323 @@ +"""Tests for system-wide memory pressure detection and escalation. + +Verifies the warn → interrupt → kill escalation in _check_system_memory: + - Warn sends notifications to all bound topics + - Interrupt sends Escape to highest-RSS window + - Kill removes highest-RSS window and cleans up bindings + - Escalation advances at most one level per check cycle + - Cooldowns prevent rapid successive kills + - Hysteresis resets escalation when memory recovers +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from ccbot.handlers import status_polling +from ccbot.handlers.status_polling import ( + _check_system_memory, + _find_highest_rss_window, +) + + +@pytest.fixture(autouse=True) +def _reset_escalation_state(): + """Reset module-level escalation state before and after each test.""" + status_polling._sys_mem_level = 0 + status_polling._sys_mem_cycles_at_level = 0 + status_polling._last_sys_mem_check = 0.0 + status_polling._window_pids.clear() + status_polling._memory_warned.clear() + yield + status_polling._sys_mem_level = 0 + status_polling._sys_mem_cycles_at_level = 0 + status_polling._last_sys_mem_check = 0.0 + status_polling._window_pids.clear() + status_polling._memory_warned.clear() + + +@pytest.fixture +def mock_bot(): + return AsyncMock() + + +@pytest.fixture +def mock_config(): + """Patch config with test thresholds.""" + with patch("ccbot.handlers.status_polling.config") as cfg: + cfg.memory_monitor_enabled = True + cfg.memory_check_interval = 0 # no throttle in tests + cfg.mem_avail_warn_mb = 1024.0 + cfg.mem_avail_interrupt_mb = 512.0 + cfg.mem_avail_kill_mb = 256.0 + yield cfg + + +@pytest.fixture +def one_binding(): + """Set up one thread binding with a tracked PID.""" + status_polling._window_pids["@0"] = 1234 + with ( + patch( + "ccbot.handlers.status_polling.session_manager" + ) as mock_sm, + patch("ccbot.handlers.status_polling.tmux_manager") as mock_tmux, + patch( + "ccbot.handlers.status_polling.safe_send", + new_callable=AsyncMock, + ) as mock_send, + patch( + "ccbot.handlers.status_polling.clear_topic_state", + new_callable=AsyncMock, + ), + ): + mock_sm.all_thread_bindings.return_value = [(100, 42, "@0")] + mock_sm.resolve_chat_id.return_value = 100 + mock_sm.get_display_name.return_value = "test-session" + mock_tmux.send_keys = AsyncMock(return_value=True) + mock_tmux.kill_window = AsyncMock(return_value=True) + yield { + "session_manager": mock_sm, + "tmux_manager": mock_tmux, + "safe_send": mock_send, + } + + +class TestCheckSystemMemory: + """Test _check_system_memory escalation logic.""" + + @pytest.mark.asyncio + async def test_no_action_when_memory_ok( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """MemAvailable > warn threshold → no notification.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=4000.0, + ): + await _check_system_memory(mock_bot) + one_binding["safe_send"].assert_not_called() + assert status_polling._sys_mem_level == 0 + + @pytest.mark.asyncio + async def test_warn_on_low_memory( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """MemAvailable <= warn threshold → warn notification to all topics.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=800.0, + ): + await _check_system_memory(mock_bot) + one_binding["safe_send"].assert_called_once() + msg = one_binding["safe_send"].call_args[0][2] + assert "System memory low" in msg + assert "800" in msg + assert status_polling._sys_mem_level == 1 + + @pytest.mark.asyncio + async def test_one_level_per_cycle( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """Even with kill-level pressure, first cycle only warns.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=100.0, # below kill threshold + ), patch( + "ccbot.handlers.status_polling.get_tree_rss_mb", + new_callable=AsyncMock, + return_value=3000.0, + ): + await _check_system_memory(mock_bot) + # Should only be at warn level, not kill + assert status_polling._sys_mem_level == 1 + msg = one_binding["safe_send"].call_args[0][2] + assert "System memory low" in msg + + @pytest.mark.asyncio + async def test_escalates_to_interrupt( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """Warn → interrupt on continued pressure.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=400.0, + ), patch( + "ccbot.handlers.status_polling.get_tree_rss_mb", + new_callable=AsyncMock, + return_value=3000.0, + ): + # Cycle 1: warn + await _check_system_memory(mock_bot) + assert status_polling._sys_mem_level == 1 + + # Cycle 2: interrupt + status_polling._last_sys_mem_check = 0.0 # bypass throttle + await _check_system_memory(mock_bot) + assert status_polling._sys_mem_level == 2 + one_binding["tmux_manager"].send_keys.assert_called_once_with( + "@0", "Escape", enter=False, literal=False + ) + + @pytest.mark.asyncio + async def test_interrupt_cooldown_before_kill( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """After interrupt, must wait cooldown cycles before kill.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=200.0, + ), patch( + "ccbot.handlers.status_polling.get_tree_rss_mb", + new_callable=AsyncMock, + return_value=3000.0, + ): + # Cycle 1: warn + await _check_system_memory(mock_bot) + assert status_polling._sys_mem_level == 1 + + # Cycle 2: interrupt + status_polling._last_sys_mem_check = 0.0 + await _check_system_memory(mock_bot) + assert status_polling._sys_mem_level == 2 + + # Cycle 3: cooldown (should NOT kill yet) + status_polling._last_sys_mem_check = 0.0 + await _check_system_memory(mock_bot) + one_binding["tmux_manager"].kill_window.assert_not_called() + + # Cycle 4: cooldown satisfied → kill + status_polling._last_sys_mem_check = 0.0 + await _check_system_memory(mock_bot) + one_binding["tmux_manager"].kill_window.assert_called_once_with("@0") + assert status_polling._sys_mem_level == 3 + + @pytest.mark.asyncio + async def test_kill_cleans_up_bindings( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """Kill action unbinds thread and clears topic state.""" + # Fast-track to kill level + status_polling._sys_mem_level = 2 + status_polling._sys_mem_cycles_at_level = 10 # past cooldown + + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=200.0, + ), patch( + "ccbot.handlers.status_polling.get_tree_rss_mb", + new_callable=AsyncMock, + return_value=3000.0, + ): + await _check_system_memory(mock_bot) + + one_binding["tmux_manager"].kill_window.assert_called_once_with("@0") + one_binding["session_manager"].unbind_thread.assert_called_once_with(100, 42) + assert "@0" not in status_polling._window_pids + + @pytest.mark.asyncio + async def test_hysteresis_reset( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """Memory recovery above warn*1.5 resets escalation to level 0.""" + status_polling._sys_mem_level = 2 # was at interrupt + + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=2000.0, # well above 1024 * 1.5 = 1536 + ): + await _check_system_memory(mock_bot) + + assert status_polling._sys_mem_level == 0 + assert status_polling._sys_mem_cycles_at_level == 0 + + @pytest.mark.asyncio + async def test_level_downgrade_when_pressure_partially_relieves( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """If at kill level but memory rises to interrupt level, level downgrades.""" + status_polling._sys_mem_level = 3 + status_polling._sys_mem_cycles_at_level = 0 + + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=400.0, # interrupt level, not kill + ): + await _check_system_memory(mock_bot) + + assert status_polling._sys_mem_level == 2 + assert status_polling._sys_mem_cycles_at_level == 0 + + @pytest.mark.asyncio + async def test_disabled_monitor_skips( + self, mock_bot: AsyncMock, mock_config: MagicMock + ) -> None: + """When memory_monitor_enabled=False, no action taken.""" + mock_config.memory_monitor_enabled = False + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + ) as mock_mem: + await _check_system_memory(mock_bot) + mock_mem.assert_not_called() + + @pytest.mark.asyncio + async def test_none_mem_available_skips( + self, mock_bot: AsyncMock, mock_config: MagicMock, one_binding: dict + ) -> None: + """Non-Linux (MemAvailable=None) → no action.""" + with patch( + "ccbot.handlers.status_polling.get_mem_available_mb", + new_callable=AsyncMock, + return_value=None, + ): + await _check_system_memory(mock_bot) + one_binding["safe_send"].assert_not_called() + + +class TestFindHighestRssWindow: + """Test _find_highest_rss_window helper.""" + + @pytest.mark.asyncio + async def test_finds_highest_rss(self) -> None: + status_polling._window_pids = {"@0": 100, "@1": 200} + rss_map = {100: 1500.0, 200: 3000.0} + + with ( + patch( + "ccbot.handlers.status_polling.session_manager" + ) as mock_sm, + patch( + "ccbot.handlers.status_polling.get_tree_rss_mb", + new_callable=AsyncMock, + side_effect=lambda pid: rss_map.get(pid), + ), + ): + mock_sm.all_thread_bindings.return_value = [ + (1, 10, "@0"), + (1, 20, "@1"), + ] + result = await _find_highest_rss_window() + + assert result is not None + wid, user_id, thread_id, rss_mb = result + assert wid == "@1" + assert rss_mb == 3000.0 + + @pytest.mark.asyncio + async def test_returns_none_when_no_pids(self) -> None: + status_polling._window_pids = {} + with patch( + "ccbot.handlers.status_polling.session_manager" + ) as mock_sm: + mock_sm.all_thread_bindings.return_value = [(1, 10, "@0")] + result = await _find_highest_rss_window() + assert result is None diff --git a/tests/ccbot/test_process_info.py b/tests/ccbot/test_process_info.py new file mode 100644 index 00000000..18ddb826 --- /dev/null +++ b/tests/ccbot/test_process_info.py @@ -0,0 +1,196 @@ +"""Tests for process_info module — OOM detection and memory utilities.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from ccbot.process_info import ( + _parse_dmesg_for_oom_kills, + get_mem_available_mb, + get_process_rss_mb, + was_pid_oom_killed, +) + + +class TestParseDmesgForOomKills: + """Test dmesg OOM-kill line parsing.""" + + def test_standard_oom_kill_line(self) -> None: + line = ( + "[12345.678] Out of memory: Killed process 1234 (python3) " + "total-vm:15000000kB, anon-rss:14800000kB, file-rss:1234kB" + ) + results = _parse_dmesg_for_oom_kills(line) + assert len(results) == 1 + assert results[0]["pid"] == 1234 + assert results[0]["process_name"] == "python3" + assert results[0]["total_vm_kb"] == 15000000 + assert results[0]["rss_kb"] == 14800000 + + def test_minimal_oom_kill_line(self) -> None: + line = "[12345.678] Killed process 5678 (node)" + results = _parse_dmesg_for_oom_kills(line) + assert len(results) == 1 + assert results[0]["pid"] == 5678 + assert results[0]["process_name"] == "node" + assert results[0]["total_vm_kb"] is None + assert results[0]["rss_kb"] is None + + def test_multiple_oom_kills(self) -> None: + text = ( + "[100.0] Killed process 111 (a)\n" + "[200.0] some other log line\n" + "[300.0] Killed process 222 (b)\n" + ) + results = _parse_dmesg_for_oom_kills(text) + assert len(results) == 2 + assert results[0]["pid"] == 111 + assert results[1]["pid"] == 222 + + def test_no_oom_kills(self) -> None: + text = "some normal log output\nanother line\n" + results = _parse_dmesg_for_oom_kills(text) + assert len(results) == 0 + + def test_empty_input(self) -> None: + assert _parse_dmesg_for_oom_kills("") == [] + + +class TestGetProcessRssMb: + """Test reading process RSS from /proc.""" + + @pytest.mark.asyncio + async def test_reads_vmrss(self, tmp_path: object) -> None: + status_content = ( + "Name:\tpython3\n" + "VmPeak:\t1000000 kB\n" + "VmRSS:\t512000 kB\n" + "VmSize:\t800000 kB\n" + ) + with patch("ccbot.process_info.Path") as mock_path: + mock_file = mock_path.return_value + mock_file.exists.return_value = True + mock_file.read_text.return_value = status_content + result = await get_process_rss_mb(1234) + assert result is not None + assert abs(result - 500.0) < 0.1 # 512000 kB = 500 MB + + @pytest.mark.asyncio + async def test_returns_none_for_missing_process(self) -> None: + with patch("ccbot.process_info.Path") as mock_path: + mock_file = mock_path.return_value + mock_file.exists.return_value = False + result = await get_process_rss_mb(99999) + assert result is None + + +class TestWasPidOomKilled: + """Test OOM-kill correlation with specific PIDs.""" + + @pytest.mark.asyncio + async def test_detects_oom_for_matching_pid(self) -> None: + oom_kills = [ + { + "pid": 1234, + "process_name": "python3", + "total_vm_kb": 15000000, + "rss_kb": 14800000, + "line": "Killed process 1234 (python3)", + } + ] + with patch( + "ccbot.process_info.check_recent_oom_kills", + new_callable=AsyncMock, + return_value=oom_kills, + ): + result = await was_pid_oom_killed(1234) + assert result is not None + assert result["pid"] == 1234 + + @pytest.mark.asyncio + async def test_detects_oom_for_descendant_pid(self) -> None: + oom_kills = [ + { + "pid": 5678, + "process_name": "node", + "total_vm_kb": None, + "rss_kb": None, + "line": "Killed process 5678 (node)", + } + ] + with patch( + "ccbot.process_info.check_recent_oom_kills", + new_callable=AsyncMock, + return_value=oom_kills, + ): + result = await was_pid_oom_killed(1234, descendants=[5678, 9999]) + assert result is not None + assert result["pid"] == 5678 + + @pytest.mark.asyncio + async def test_returns_none_when_no_match(self) -> None: + oom_kills = [ + { + "pid": 9999, + "process_name": "other", + "total_vm_kb": None, + "rss_kb": None, + "line": "Killed process 9999 (other)", + } + ] + with patch( + "ccbot.process_info.check_recent_oom_kills", + new_callable=AsyncMock, + return_value=oom_kills, + ): + result = await was_pid_oom_killed(1234, descendants=[5678]) + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_when_no_oom_kills(self) -> None: + with patch( + "ccbot.process_info.check_recent_oom_kills", + new_callable=AsyncMock, + return_value=[], + ): + result = await was_pid_oom_killed(1234) + assert result is None + + +class TestGetMemAvailableMb: + """Test reading MemAvailable from /proc/meminfo.""" + + @pytest.mark.asyncio + async def test_parses_standard_meminfo(self) -> None: + meminfo_content = ( + "MemTotal: 16384000 kB\n" + "MemFree: 2048000 kB\n" + "MemAvailable: 8192000 kB\n" + "Buffers: 512000 kB\n" + ) + with patch("ccbot.process_info.Path") as mock_path: + mock_file = mock_path.return_value + mock_file.exists.return_value = True + mock_file.read_text.return_value = meminfo_content + result = await get_mem_available_mb() + assert result is not None + assert abs(result - 8000.0) < 0.1 # 8192000 kB = 8000 MB + + @pytest.mark.asyncio + async def test_returns_none_when_file_missing(self) -> None: + with patch("ccbot.process_info.Path") as mock_path: + mock_file = mock_path.return_value + mock_file.exists.return_value = False + result = await get_mem_available_mb() + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_for_malformed_input(self) -> None: + meminfo_content = "MemTotal: 16384000 kB\nMemFree: garbage\n" + with patch("ccbot.process_info.Path") as mock_path: + mock_file = mock_path.return_value + mock_file.exists.return_value = True + mock_file.read_text.return_value = meminfo_content + result = await get_mem_available_mb() + # MemAvailable line is absent entirely → None + assert result is None diff --git a/tests/ccbot/test_session.py b/tests/ccbot/test_session.py index 022fb55a..0b1dad74 100644 --- a/tests/ccbot/test_session.py +++ b/tests/ccbot/test_session.py @@ -1,8 +1,10 @@ """Tests for SessionManager pure dict operations.""" +import json + import pytest -from ccbot.session import SessionManager +from ccbot.session import ClaudeSession, SessionManager @pytest.fixture @@ -25,13 +27,44 @@ def test_bind_unbind_get_returns_none(self, mgr: SessionManager) -> None: def test_unbind_nonexistent_returns_none(self, mgr: SessionManager) -> None: assert mgr.unbind_thread(100, 999) is None - def test_iter_thread_bindings(self, mgr: SessionManager) -> None: + def test_all_thread_bindings(self, mgr: SessionManager) -> None: mgr.bind_thread(100, 1, "@1") mgr.bind_thread(100, 2, "@2") mgr.bind_thread(200, 3, "@3") - result = set(mgr.iter_thread_bindings()) + result = set(mgr.all_thread_bindings()) assert result == {(100, 1, "@1"), (100, 2, "@2"), (200, 3, "@3")} + def test_all_thread_bindings_returns_list(self, mgr: SessionManager) -> None: + """all_thread_bindings must return a list (snapshot), not a generator. + + A generator would hold a live reference into the internal dict and could + raise RuntimeError if an async coroutine calls unbind_thread between two + consumed values. A list snapshot is safe across await points. + """ + mgr.bind_thread(100, 1, "@1") + result = mgr.all_thread_bindings() + assert isinstance(result, list) + + def test_all_thread_bindings_snapshot_is_independent( + self, mgr: SessionManager + ) -> None: + """Mutating thread_bindings after calling all_thread_bindings must not + affect the already-returned snapshot.""" + mgr.bind_thread(100, 1, "@1") + mgr.bind_thread(100, 2, "@2") + snapshot = mgr.all_thread_bindings() + # Mutate the live dict after snapshot was taken — the snapshot must + # be unaffected (this is the property that prevents RuntimeError + # when unbind_thread runs between await points in async callers) + mgr.unbind_thread(100, 1) + assert (100, 1, "@1") in snapshot + assert len(snapshot) == 2 + + def test_all_thread_bindings_empty(self, mgr: SessionManager) -> None: + """all_thread_bindings returns an empty list when nothing is bound.""" + result = mgr.all_thread_bindings() + assert result == [] + class TestGroupChatId: """Tests for group chat_id routing (supergroup forum topic support). @@ -155,3 +188,197 @@ def test_invalid_ids(self, mgr: SessionManager) -> None: assert mgr._is_window_id("@") is False assert mgr._is_window_id("") is False assert mgr._is_window_id("@abc") is False + + +def _make_jsonl_entry( + msg_type: str, + content: list, + session_id: str = "test-session", +) -> str: + """Build a JSONL line for testing.""" + return json.dumps( + { + "type": msg_type, + "timestamp": "2026-03-07T20:00:00.000Z", + "sessionId": session_id, + "cwd": "/tmp/test", + "message": {"content": content}, + } + ) + + +class TestGetSessionDeathContext: + """Tests for get_session_death_context — crash diagnostics from JSONL.""" + + @pytest.mark.asyncio + async def test_returns_last_tool_use(self, mgr: SessionManager, tmp_path) -> None: + """Shows the last tool that was running when session died.""" + jsonl = tmp_path / "session.jsonl" + lines = [ + _make_jsonl_entry( + "assistant", + [{"type": "text", "text": "Let me read the file."}], + ), + _make_jsonl_entry( + "assistant", + [ + { + "type": "tool_use", + "id": "tool_1", + "name": "Bash", + "input": {"command": "npm run test:e2e"}, + } + ], + ), + # No tool_result — tool was mid-execution when session died + ] + jsonl.write_text("\n".join(lines) + "\n") + + # Mock resolve_session_for_window to return our test file + async def mock_resolve(wid): + return ClaudeSession( + session_id="test", + summary="", + message_count=2, + file_path=str(jsonl), + ) + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + + assert "Last activity:" in result + assert "Running:" in result + assert "Bash" in result + assert "npm run test:e2e" in result + + @pytest.mark.asyncio + async def test_returns_last_text_message( + self, mgr: SessionManager, tmp_path + ) -> None: + """Shows the last assistant text message.""" + jsonl = tmp_path / "session.jsonl" + lines = [ + _make_jsonl_entry( + "assistant", + [{"type": "text", "text": "All 4 test agents running."}], + ), + ] + jsonl.write_text("\n".join(lines) + "\n") + + async def mock_resolve(wid): + return ClaudeSession( + session_id="test", + summary="", + message_count=1, + file_path=str(jsonl), + ) + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + + assert "Last message:" in result + assert "All 4 test agents running." in result + + @pytest.mark.asyncio + async def test_returns_empty_for_missing_session(self, mgr: SessionManager) -> None: + """Returns empty string when session can't be resolved.""" + + async def mock_resolve(wid): + return None + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + assert result == "" + + @pytest.mark.asyncio + async def test_returns_empty_for_missing_file(self, mgr: SessionManager) -> None: + """Returns empty string when JSONL file doesn't exist.""" + + async def mock_resolve(wid): + return ClaudeSession( + session_id="test", + summary="", + message_count=0, + file_path="/nonexistent/path.jsonl", + ) + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + assert result == "" + + @pytest.mark.asyncio + async def test_truncates_long_messages(self, mgr: SessionManager, tmp_path) -> None: + """Long assistant text is truncated to ~200 chars.""" + jsonl = tmp_path / "session.jsonl" + long_text = "x" * 500 + lines = [ + _make_jsonl_entry( + "assistant", + [{"type": "text", "text": long_text}], + ), + ] + jsonl.write_text("\n".join(lines) + "\n") + + async def mock_resolve(wid): + return ClaudeSession( + session_id="test", + summary="", + message_count=1, + file_path=str(jsonl), + ) + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + + assert "..." in result + # The truncated text should be at most 200 chars + "..." + for line in result.split("\n"): + if "Last message:" in line: + # Extract the quoted message content + msg_part = line.split('"')[1] if '"' in line else "" + assert len(msg_part) <= 204 # 200 + "..." + + @pytest.mark.asyncio + async def test_completed_tool_shows_last_tool( + self, mgr: SessionManager, tmp_path + ) -> None: + """When tool completed (has result), shows as 'Last tool' not 'Running'.""" + jsonl = tmp_path / "session.jsonl" + lines = [ + _make_jsonl_entry( + "assistant", + [ + { + "type": "tool_use", + "id": "tool_1", + "name": "Read", + "input": {"file_path": "src/main.py"}, + } + ], + ), + _make_jsonl_entry( + "user", + [ + { + "type": "tool_result", + "tool_use_id": "tool_1", + "content": "file contents here", + } + ], + ), + ] + jsonl.write_text("\n".join(lines) + "\n") + + async def mock_resolve(wid): + return ClaudeSession( + session_id="test", + summary="", + message_count=2, + file_path=str(jsonl), + ) + + mgr.resolve_session_for_window = mock_resolve # type: ignore[assignment] + result = await mgr.get_session_death_context("@1") + + assert "Last tool:" in result + assert "Running:" not in result