diff --git a/Cargo.lock b/Cargo.lock index 4318682..6e20797 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -571,6 +571,15 @@ dependencies = [ "slab", ] +[[package]] +name = "getopts" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe4fbac503b8d1f88e6676011885f34b7174f46e59956bba534ba83abded4df" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -1420,6 +1429,25 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f86ba2052aebccc42cbbb3ed234b8b13ce76f75c3551a303cb2bcffcff12bb14" +dependencies = [ + "bitflags", + "getopts", + "memchr", + "pulldown-cmark-escape", + "unicase", +] + +[[package]] +name = "pulldown-cmark-escape" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "007d8adb5ddab6f8e3f491ac63566a7d5002cc7ed73901f72057943fa71ae1ae" + [[package]] name = "quote" version = "1.0.44" @@ -1567,13 +1595,16 @@ dependencies = [ "base64", "chrono", "futures", + "http", "pastey", "pin-project-lite", "process-wrap", + "reqwest", "rmcp-macros", "schemars 1.2.1", "serde", "serde_json", + "sse-stream", "thiserror", "tokio", "tokio-stream", @@ -1625,6 +1656,7 @@ dependencies = [ "chrono", "futures", "futures-util", + "pulldown-cmark", "reqwest", "rmcp", "rusqlite", @@ -1954,6 +1986,19 @@ dependencies = [ "cc", ] +[[package]] +name = "sse-stream" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -2456,6 +2501,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" diff --git a/Cargo.toml b/Cargo.toml index 20044dc..35f4b93 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ serde_json = "1" toml = "0.8" # MCP client -rmcp = { version = "0.15", features = ["client", "transport-child-process", "transport-io"] } +rmcp = { version = "0.15", features = ["client", "transport-child-process", "transport-io", "transport-streamable-http-client-reqwest"] } # Logging tracing = "0.1" @@ -49,6 +49,9 @@ chrono = { version = "0.4", features = ["serde"] } # Background task scheduler tokio-cron-scheduler = "0.13" +# Markdown parser for entity-based Telegram message formatting +pulldown-cmark = "0.12" + # SQLite vector search extension sqlite-vec = "0.1" diff --git a/README.md b/README.md index 967b5e4..5b617eb 100644 --- a/README.md +++ b/README.md @@ -290,3 +290,6 @@ If you find RustFox useful, consider supporting the project: - [reqwest](https://github.com/seanmonstar/reqwest) — HTTP client for OpenRouter - [tokio](https://tokio.rs/) — Async runtime - [tokio-cron-scheduler](https://github.com/mvniekerk/tokio-cron-scheduler) — Task scheduling +- [pulldown-cmark](https://github.com/pulldown-cmark/pulldown-cmark) — Markdown parser (entity-based Telegram formatting) + +> **Thanks:** Markdown-to-entities conversion approach inspired by [telegramify-markdown](https://github.com/sudoskys/telegramify-markdown) by sudoskys. 
diff --git a/config.example.toml b/config.example.toml index fd25e26..8a0c605 100644 --- a/config.example.toml +++ b/config.example.toml @@ -32,6 +32,13 @@ allowed_directory = "/tmp/rustfox-sandbox" # Stores conversations, knowledge base, and vector embeddings database_path = "rustfox.db" +# Query rewriting for memory search (optional; default: false) +# When enabled, ambiguous follow-up questions are rewritten into self-contained +# search queries using an LLM call before the RAG vector search. +# This improves recall accuracy but adds one extra LLM round-trip per message. +# Can also be toggled per-user at runtime via the /query-rewrite Telegram command. +# query_rewriter_enabled = false + [skills] # Directory containing skill markdown files # Skills are natural-language instructions loaded at startup @@ -136,3 +143,21 @@ directory = "skills" # args = ["-y", "@brave/brave-search-mcp-server"] # [mcp_servers.env] # BRAVE_API_KEY = "your-brave-api-key" + +# ── HTTP-based MCP Servers (Streamable HTTP transport) ──────────────────────── +# These servers are reached over HTTPS and do not require a local command. +# Use `url` instead of `command`; optionally set `auth_token` for Bearer auth. 
+ +# Example: Exa AI web search (https://mcp.exa.ai) +# Get your API key at https://dashboard.exa.ai/api-keys +# +# Option A — Bearer token (recommended; keeps the key out of URLs and logs) +# [[mcp_servers]] +# name = "exa" +# url = "https://mcp.exa.ai/mcp" +# auth_token = "your-exa-api-key" +# +# Option B — Inline API key in URL (simpler; key will appear in logs) +# [[mcp_servers]] +# name = "exa" +# url = "https://mcp.exa.ai/mcp?exaApiKey=your-exa-api-key" diff --git a/src/agent.rs b/src/agent.rs index c2f209e..760d223 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -180,9 +180,29 @@ impl Agent { let rewrite_start = filtered_msgs.len().saturating_sub(6); let recent_for_rewrite = filtered_msgs[rewrite_start..].to_vec(); + // Determine if query rewriting is enabled: per-user setting overrides config default. + let per_user_setting = self + .memory + .recall( + "settings", + &format!("query_rewrite_enabled_{}", incoming.user_id), + ) + .await + .unwrap_or(None); + let rewrite_enabled = match per_user_setting.as_deref() { + Some("true") => true, + Some("false") => false, + _ => self.config.memory.query_rewriter_enabled, + }; + let llm_for_rewrite = if rewrite_enabled { + Some(&self.llm) + } else { + None + }; + if let Ok(Some(rag_block)) = crate::memory::rag::auto_retrieve_context( &self.memory, - Some(&self.llm), + llm_for_rewrite, &incoming.text, &recent_for_rewrite, &conversation_id, @@ -242,6 +262,10 @@ impl Agent { let max_iterations = self.config.max_iterations(); let mut iteration_count = 0u32; + // Clone the stream sender so tool status can be pushed into the same Telegram + // message during tool execution, before the final response starts streaming. 
+ let stream_status_tx = stream_token_tx.clone(); + for iteration in 0..max_iterations { debug!( "Trying iteration {}: messages length: {}", @@ -344,6 +368,20 @@ impl Agent { }); } + // Stream tool status into the Telegram message only when + // tool-progress notifications are enabled, to avoid + // prepending status lines to otherwise silent/final output. + if tool_event_tx.is_some() { + if let Some(ref tx) = stream_status_tx { + let status = + crate::platform::tool_notifier::format_tool_status_line( + &tool_call.function.name, + &args_preview, + ); + tx.try_send(status).ok(); + } + } + let tool_result = self .execute_tool(&tool_call.function.name, &arguments, user_id, chat_id) .await; @@ -401,35 +439,42 @@ impl Agent { ); } - // Stream the final response token-by-token if a channel is provided - if let Some(ref tx) = stream_token_tx { - let words: Vec<&str> = content.split_inclusive(' ').collect(); - let chunk_size = 4usize; - for chunk in words.chunks(chunk_size) { - let piece = chunk.join(""); - if tx.send(piece).await.is_err() { - break; - } - tokio::time::sleep(tokio::time::Duration::from_millis(30)).await; - } - } + // Stream the final response directly from the already-complete content. + // Previously this made a second chat_stream() API call, which could return + // Ok(partial) if the SSE connection was dropped mid-generation (e.g. after an + // 11-minute kimi-k2.5 response), silently saving a truncated reply. + // Now we pipe the guaranteed-complete content through the channel in small + // chunks so Telegram still sees tokens arrive progressively. 
+ let final_content = if let Some(tx) = stream_token_tx { + LlmClient::stream_text(content.clone(), tx).await.ok(); + content.clone() + } else { + content.clone() + }; + // Save the delivered content to persistent memory + let save_msg = crate::llm::ChatMessage { + role: response.role.clone(), + content: Some(final_content.clone()), + tool_calls: response.tool_calls.clone(), + tool_call_id: response.tool_call_id.clone(), + }; self.memory - .save_message(&conversation_id, &response) + .save_message(&conversation_id, &save_msg) .await?; // --- LangSmith: end chain run (success) --- self.langsmith.end_run(crate::langsmith::EndRunParams { id: chain_run_id, outputs: Some(serde_json::json!({ - "response": content, + "response": final_content, "iterations": iteration, })), error: None, end_time: Self::now_iso8601_static(), }); - return Ok(content); + return Ok(final_content); } // Reached max iterations diff --git a/src/bin/setup.rs b/src/bin/setup.rs index 0d4fd4b..f0ae7e9 100644 --- a/src/bin/setup.rs +++ b/src/bin/setup.rs @@ -124,6 +124,14 @@ struct RawMcpServer { args: Vec, #[serde(default)] env: HashMap, + // `url` and `auth_token` are parsed but not used by the setup wizard; + // they are accepted so configs with HTTP MCP servers load without error. + #[serde(default)] + #[allow(dead_code)] + url: Option, + #[serde(default)] + #[allow(dead_code)] + auth_token: Option, } // ── Handlers ─────────────────────────────────────────────────────────────────── diff --git a/src/config.rs b/src/config.rs index 20bef84..5dc39fe 100644 --- a/src/config.rs +++ b/src/config.rs @@ -62,11 +62,24 @@ pub struct SandboxConfig { #[derive(Debug, Deserialize, Clone)] pub struct McpServerConfig { pub name: String, - pub command: String, + /// Command to run for stdio-based MCP servers (e.g. "uvx", "npx"). + /// Required for stdio servers; omit for HTTP servers. 
+ #[serde(default)] + pub command: Option, #[serde(default)] pub args: Vec, #[serde(default)] pub env: std::collections::HashMap, + /// URL for HTTP-based MCP servers using the Streamable HTTP transport. + /// Required for HTTP servers; omit for stdio servers. + /// The API key may be embedded as a query parameter (e.g. `?exaApiKey=KEY`) + /// or provided separately via `auth_token`. + #[serde(default)] + pub url: Option, + /// Bearer token sent in the `Authorization` header for HTTP servers. + /// Used with `url`; ignored for stdio servers. + #[serde(default)] + pub auth_token: Option, } #[derive(Debug, Deserialize, Clone)] @@ -83,6 +96,12 @@ pub struct MemoryConfig { #[serde(default = "default_summarize_cron")] #[allow(dead_code)] pub summarize_cron: String, + /// When `true`, an LLM call rewrites ambiguous follow-up questions into + /// self-contained search queries before the RAG vector search. + /// Defaults to `false` to avoid the extra LLM round-trip. + /// Can be toggled per-user at runtime via the `/query-rewrite` command. 
+ #[serde(default)] + pub query_rewriter_enabled: bool, } #[derive(Debug, Deserialize, Clone)] @@ -211,6 +230,7 @@ fn default_memory_config() -> MemoryConfig { max_raw_messages: default_max_raw_messages(), summarize_threshold: default_summarize_threshold(), summarize_cron: default_summarize_cron(), + query_rewriter_enabled: false, } } @@ -331,4 +351,109 @@ mod tests { let ls = cfg.langsmith.unwrap(); assert_eq!(ls.project, "default"); } + + #[test] + fn test_mcp_server_url_config_parses() { + let toml = r#" + [telegram] + bot_token = "tok" + allowed_user_ids = [1] + [openrouter] + api_key = "key" + [sandbox] + allowed_directory = "/tmp" + [[mcp_servers]] + name = "exa" + url = "https://mcp.exa.ai/mcp" + auth_token = "exa-key-123" + "#; + let cfg: Config = toml::from_str(toml).unwrap(); + assert_eq!(cfg.mcp_servers.len(), 1); + let server = &cfg.mcp_servers[0]; + assert_eq!(server.name, "exa"); + assert_eq!(server.url.as_deref(), Some("https://mcp.exa.ai/mcp")); + assert_eq!(server.auth_token.as_deref(), Some("exa-key-123")); + assert!( + server.command.is_none(), + "HTTP server should have no command" + ); + } + + #[test] + fn test_mcp_server_stdio_command_optional() { + let toml = r#" + [telegram] + bot_token = "tok" + allowed_user_ids = [1] + [openrouter] + api_key = "key" + [sandbox] + allowed_directory = "/tmp" + [[mcp_servers]] + name = "git" + command = "uvx" + args = ["mcp-server-git"] + "#; + let cfg: Config = toml::from_str(toml).unwrap(); + assert_eq!(cfg.mcp_servers[0].command.as_deref(), Some("uvx")); + assert!(cfg.mcp_servers[0].url.is_none()); + } + + #[test] + fn test_mcp_server_url_without_auth_token() { + let toml = r#" + [telegram] + bot_token = "tok" + allowed_user_ids = [1] + [openrouter] + api_key = "key" + [sandbox] + allowed_directory = "/tmp" + [[mcp_servers]] + name = "exa" + url = "https://mcp.exa.ai/mcp?exaApiKey=inline-key" + "#; + let cfg: Config = toml::from_str(toml).unwrap(); + let s = &cfg.mcp_servers[0]; + assert!(s.url.is_some()); 
+ assert!(s.auth_token.is_none()); + } + + #[test] + fn test_query_rewriter_disabled_by_default() { + let toml = r#" + [telegram] + bot_token = "tok" + allowed_user_ids = [1] + [openrouter] + api_key = "key" + [sandbox] + allowed_directory = "/tmp" + "#; + let cfg: Config = toml::from_str(toml).unwrap(); + assert!( + !cfg.memory.query_rewriter_enabled, + "query_rewriter_enabled must default to false" + ); + } + + #[test] + fn test_query_rewriter_can_be_enabled() { + let toml = r#" + [telegram] + bot_token = "tok" + allowed_user_ids = [1] + [openrouter] + api_key = "key" + [sandbox] + allowed_directory = "/tmp" + [memory] + query_rewriter_enabled = true + "#; + let cfg: Config = toml::from_str(toml).unwrap(); + assert!( + cfg.memory.query_rewriter_enabled, + "query_rewriter_enabled should be true when set" + ); + } } diff --git a/src/llm.rs b/src/llm.rs index fb8bd5d..3c37676 100644 --- a/src/llm.rs +++ b/src/llm.rs @@ -1,5 +1,4 @@ use anyhow::{Context, Result}; -use futures_util::StreamExt; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -55,19 +54,6 @@ struct ChatRequest { max_tokens: u32, } -/// Like ChatRequest but with stream=true for SSE streaming. -#[derive(Debug, Serialize)] -struct StreamRequest { - model: String, - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, - max_tokens: u32, - stream: bool, -} - #[derive(Debug, Deserialize)] struct ChatResponse { choices: Vec, @@ -80,18 +66,197 @@ struct Choice { finish_reason: Option, } -/// Parse a single SSE line and extract the text content token, if any. -/// Returns `None` for non-data lines, `[DONE]`, empty deltas, or parse errors. -fn parse_sse_content(line: &str) -> Option { - let data = line.strip_prefix("data: ")?; - if data == "[DONE]" { +/// Sanitize a JSON Schema parameter object so it is accepted by strict providers +/// (e.g. Google Gemini via OpenRouter). 
+/// +/// Gemini enforces that every entry in the `required` array corresponds to a key +/// that is actually defined in `properties`. Some MCP servers return schemas where +/// `required` contains field names that do not exist in `properties` — this causes a +/// 400 INVALID_ARGUMENT from Google AI Studio. +/// +/// Additional Gemini restrictions handled here: +/// - `additionalProperties`, `$schema`, `$defs`, `$ref` are not accepted. +/// - `required: []` (empty array) is rejected; the key must be omitted entirely. +/// - `anyOf`/`oneOf`/`allOf` variants with `{"type": "null"}` are stripped because +/// Gemini does not support nullable types expressed as `null` union members. +/// If stripping leaves exactly one variant, it is inlined (unwrapped) into the +/// parent object. If stripping leaves zero variants, the key is removed entirely. +/// +/// This function mutates the schema in-place and recurses into `properties`, +/// `items`, `anyOf`, `oneOf`, and `allOf` sub-schemas. +fn sanitize_parameters(schema: &mut serde_json::Value) { + let obj = match schema.as_object_mut() { + Some(o) => o, + None => return, + }; + + // Remove fields that Gemini rejects. + obj.remove("additionalProperties"); + obj.remove("$schema"); + obj.remove("$defs"); + obj.remove("$ref"); + + // Collect the set of property names that are actually defined. + let known_props: std::collections::HashSet = obj + .get("properties") + .and_then(|p| p.as_object()) + .map(|p| p.keys().cloned().collect()) + .unwrap_or_default(); + + // Filter `required` so it only lists names that appear in `properties`. + // Gemini also rejects an empty `required: []`, so remove the key entirely + // if nothing remains after filtering. 
+ if let Some(required) = obj.get_mut("required") { + if let Some(arr) = required.as_array_mut() { + arr.retain(|v| v.as_str().is_some_and(|s| known_props.contains(s))); + } + } + if obj + .get("required") + .and_then(|r| r.as_array()) + .is_some_and(|a| a.is_empty()) + { + obj.remove("required"); + } + + // Recurse into property sub-schemas. + if let Some(properties) = obj.get_mut("properties") { + if let Some(props_obj) = properties.as_object_mut() { + for prop_schema in props_obj.values_mut() { + sanitize_parameters(prop_schema); + } + } + } + + // Recurse into array item schema. + if let Some(items) = obj.get_mut("items") { + sanitize_parameters(items); + } + + // Recurse into anyOf / oneOf / allOf variant schemas and strip null variants. + // Gemini does not support {"type": "null"} as a union member. + for key in &["anyOf", "oneOf", "allOf"] { + if let Some(variants) = obj.get_mut(*key) { + if let Some(arr) = variants.as_array_mut() { + // Recurse into each variant first. + for v in arr.iter_mut() { + sanitize_parameters(v); + } + // Remove variants that are purely {"type": "null"}. + arr.retain(|v| { + v.get("type") + .and_then(|t| t.as_str()) + .map(|t| t != "null") + .unwrap_or(true) + }); + } + } + + // If only one variant remains, unwrap it by merging into the parent. + // If zero variants remain, remove the key entirely. + let variant_count = obj.get(*key).and_then(|v| v.as_array()).map(|a| a.len()); + match variant_count { + Some(0) => { + obj.remove(*key); + } + Some(1) => { + if let Some(single) = obj + .remove(*key) + .and_then(|mut v| v.as_array_mut().and_then(|a| a.pop())) + { + if let Some(inner) = single.as_object() { + for (k, v) in inner { + obj.entry(k.clone()).or_insert(v.clone()); + } + } + } + } + _ => {} + } + } +} + +/// Parse Kimi's native tool-call text format and convert it into `ToolCall` structs. +/// +/// Some models (e.g. 
`moonshotai/kimi-k2.5`) occasionally leak their internal +/// tool-invocation syntax into the `content` field instead of populating the +/// standard `tool_calls` API field. The leaked text looks like: +/// +/// ```text +/// <|tool_calls_section_begin|> <|tool_call_begin|> functions.my_tool:0 +/// <|tool_call_argument_begin|> {"arg": "value"} <|tool_call_end|> +/// <|tool_calls_section_end|> +/// ``` +/// +/// Returns `Some(Vec)` with at least one entry when the format is +/// detected, or `None` if the content does not contain the Kimi markers. +fn parse_kimi_tool_calls(content: &str) -> Option> { + if !content.contains("<|tool_calls_section_begin|>") { return None; } - let value: serde_json::Value = serde_json::from_str(data).ok()?; - let content = value.get("choices")?.get(0)?.get("delta")?.get("content")?; - match content { - serde_json::Value::String(s) if !s.is_empty() => Some(s.clone()), - _ => None, + + let mut calls = Vec::new(); + + // Split on the per-call begin marker; the first chunk is the preamble/section + // header and is discarded. + for block in content.split("<|tool_call_begin|>").skip(1) { + // Strip everything from the closing marker onwards (handles trailing + // section-end marker and whitespace). + let block = block + .split("<|tool_call_end|>") + .next() + .unwrap_or(block) + .trim(); + + // Split into function descriptor and JSON arguments. + let (descriptor, args_raw) = if let Some(pos) = block.find("<|tool_call_argument_begin|>") { + let d = block[..pos].trim(); + let a = block[pos + "<|tool_call_argument_begin|>".len()..].trim(); + (d, a) + } else { + continue; + }; + + // Descriptor format: `functions.{name}:{index}` or just `functions.{name}`. + // Extract the plain function name. 
+ let func_name = descriptor + .trim_start_matches("functions.") + .split(':') + .next() + .unwrap_or(descriptor) + .trim() + .to_string(); + + if func_name.is_empty() { + continue; + } + + // Use the call index (if present) as part of the synthetic tool-call ID. + let call_index = descriptor.split(':').nth(1).unwrap_or("0").trim(); + let call_id = format!("kimi_fallback_{func_name}_{call_index}"); + + // Verify the arguments are valid JSON; fall back to an empty object on + // parse failure so the tool handler can still attempt execution. + let arguments = if serde_json::from_str::(args_raw).is_ok() { + args_raw.to_string() + } else { + "{}".to_string() + }; + + calls.push(ToolCall { + id: call_id, + call_type: "function".to_string(), + function: FunctionCall { + name: func_name, + arguments, + }, + }); + } + + if calls.is_empty() { + None + } else { + Some(calls) } } @@ -119,7 +284,15 @@ impl LlmClient { let tools_param = if tools.is_empty() { None } else { - Some(tools.to_vec()) + let sanitized = tools + .iter() + .map(|t| { + let mut t = t.clone(); + sanitize_parameters(&mut t.function.parameters); + t + }) + .collect(); + Some(sanitized) }; let tool_choice = if tools_param.is_some() { @@ -184,12 +357,35 @@ impl LlmClient { } } - chat_response + let mut choice = chat_response .choices .into_iter() .next() - .map(|c| c.message) - .context("No response from OpenRouter") + .context("No response from OpenRouter")?; + + // Kimi-family models occasionally leak their native tool-call syntax into + // the `content` field instead of populating `tool_calls`. Detect and fix. 
+ let has_tool_calls = choice + .message + .tool_calls + .as_ref() + .is_some_and(|t| !t.is_empty()); + if !has_tool_calls { + if let Some(ref content) = choice.message.content.clone() { + if let Some(parsed) = parse_kimi_tool_calls(content) { + warn!( + tool_count = parsed.len(), + "Kimi native tool-call format detected in content — \ + extracting tool calls and clearing content" + ); + choice.message.tool_calls = Some(parsed); + choice.message.content = None; + choice.finish_reason = Some("tool_calls".to_string()); + } + } + } + + Ok(choice.message) } /// Chat using the model configured in config.toml (delegates to chat_with_model). @@ -202,87 +398,33 @@ impl LlmClient { .await } - /// Stream the final LLM response token-by-token via an mpsc channel. - /// Sends each content token as a separate `String` message. - /// Closes the sender when the stream ends or on error. - /// Does NOT pass tools — use this only for the final text-only response. - #[allow(dead_code)] - pub async fn chat_stream( - &self, - messages: &[ChatMessage], - model: &str, + /// Stream an already-complete `content` string through a token channel in small chunks. + /// + /// This avoids a second LLM API call when the full response is already available from a + /// preceding non-streaming call. The caller still sees tokens arrive progressively, + /// keeping the Telegram streaming UX intact, but there is no risk of the SSE connection + /// being dropped mid-stream and silently returning a truncated response. 
+ pub async fn stream_text( + content: String, token_tx: tokio::sync::mpsc::Sender, ) -> Result<()> { - let request = StreamRequest { - model: model.to_string(), - messages: messages.to_vec(), - tools: None, - tool_choice: None, - max_tokens: self.config.max_tokens, - stream: true, - }; + const CHUNK_SIZE: usize = 30; - let url = format!("{}/chat/completions", self.config.base_url); + let chars: Vec = content.chars().collect(); + let mut start = 0; - debug!( - url = %url, - model = %model, - message_count = messages.len(), - "Starting streaming request to OpenRouter" - ); + while start < chars.len() { + let end = (start + CHUNK_SIZE).min(chars.len()); + let chunk: String = chars[start..end].iter().collect(); + start = end; - let response = self - .client - .post(&url) - .header("Authorization", format!("Bearer {}", self.config.api_key)) - .header("Content-Type", "application/json") - .header("Accept", "text/event-stream") - .json(&request) - .send() - .await - .context("Failed to send streaming request to OpenRouter")?; - - let status = response.status(); - if !status.is_success() { - let error_body = response.text().await.unwrap_or_default(); - anyhow::bail!( - "OpenRouter streaming API error ({}): {}", - status, - error_body - ); - } - - // Accumulate bytes into lines (SSE lines end with \n) - let mut stream = response.bytes_stream(); - let mut line_buf = String::new(); - - while let Some(chunk) = stream.next().await { - let bytes = chunk.context("Stream read error")?; - let text = String::from_utf8_lossy(&bytes); - - for ch in text.chars() { - if ch == '\n' { - let line = line_buf.trim().to_string(); - line_buf.clear(); - - if let Some(token) = parse_sse_content(&line) { - if token_tx.send(token).await.is_err() { - debug!("Stream receiver dropped — stopping early"); - return Ok(()); - } - } - } else { - line_buf.push(ch); - } + if token_tx.send(chunk).await.is_err() { + // Receiver dropped — stop early, this is not an error + debug!("stream_text: receiver 
dropped — stopping early"); + return Ok(()); } - } - // Process any remaining buffered line - if !line_buf.is_empty() { - let line = line_buf.trim().to_string(); - if let Some(token) = parse_sse_content(&line) { - token_tx.send(token).await.ok(); - } + tokio::time::sleep(tokio::time::Duration::from_millis(2)).await; } Ok(()) @@ -342,51 +484,247 @@ mod tests { } #[test] - fn test_parse_sse_line_data_returns_content() { - let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#; - let result = parse_sse_content(line); - assert_eq!(result, Some("Hello".to_string())); + fn test_parse_kimi_tool_calls_single_call() { + let content = " <|tool_calls_section_begin|> <|tool_call_begin|> functions.read_skill_file:5 \ + <|tool_call_argument_begin|> {\"skill_name\": \"reddit-fetcher\", \"relative_path\": \"SKILL.md\"} \ + <|tool_call_end|> <|tool_calls_section_end|>"; + let calls = parse_kimi_tool_calls(content).expect("should parse"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.name, "read_skill_file"); + assert_eq!(calls[0].call_type, "function"); + let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap(); + assert_eq!(args["skill_name"], "reddit-fetcher"); } #[test] - fn test_parse_sse_line_done_returns_none() { - let result = parse_sse_content("data: [DONE]"); - assert_eq!(result, None); + fn test_parse_kimi_tool_calls_multiple_calls() { + let content = "<|tool_calls_section_begin|>\ + <|tool_call_begin|> functions.tool_a:0 <|tool_call_argument_begin|> {\"x\": 1} <|tool_call_end|>\ + <|tool_call_begin|> functions.tool_b:1 <|tool_call_argument_begin|> {\"y\": 2} <|tool_call_end|>\ + <|tool_calls_section_end|>"; + let calls = parse_kimi_tool_calls(content).expect("should parse"); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].function.name, "tool_a"); + assert_eq!(calls[1].function.name, "tool_b"); } #[test] - fn test_parse_sse_line_empty_delta_returns_none() { - let line = r#"data: 
{"choices":[{"delta":{},"finish_reason":null}]}"#; - let result = parse_sse_content(line); - assert_eq!(result, None); + fn test_parse_kimi_tool_calls_no_markers_returns_none() { + assert!(parse_kimi_tool_calls("Hello, world!").is_none()); + assert!(parse_kimi_tool_calls("").is_none()); } #[test] - fn test_parse_sse_line_non_data_prefix_returns_none() { - assert_eq!(parse_sse_content(": OPENROUTER PROCESSING"), None); - assert_eq!(parse_sse_content(""), None); - assert_eq!(parse_sse_content("event: ping"), None); + fn test_parse_kimi_tool_calls_invalid_json_falls_back_to_empty_object() { + let content = "<|tool_calls_section_begin|>\ + <|tool_call_begin|> functions.my_tool:0 \ + <|tool_call_argument_begin|> not valid json <|tool_call_end|>\ + <|tool_calls_section_end|>"; + let calls = parse_kimi_tool_calls(content).expect("should parse"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].function.arguments, "{}"); } #[test] - fn test_parse_sse_line_null_content_returns_none() { - let line = r#"data: {"choices":[{"delta":{"content":null},"finish_reason":"stop"}]}"#; - let result = parse_sse_content(line); - assert_eq!(result, None); + fn test_parse_kimi_tool_calls_id_uses_index() { + let content = "<|tool_calls_section_begin|>\ + <|tool_call_begin|> functions.do_thing:7 \ + <|tool_call_argument_begin|> {} <|tool_call_end|>\ + <|tool_calls_section_end|>"; + let calls = parse_kimi_tool_calls(content).expect("should parse"); + assert!(calls[0].id.contains("7"), "id should embed the call index"); + } + + #[tokio::test] + async fn test_stream_text_sends_all_content() { + let content = "Hello, world! 
This is a test of stream_text.".to_string(); + let (tx, mut rx) = tokio::sync::mpsc::channel(64); + + LlmClient::stream_text(content.clone(), tx).await.unwrap(); + + let mut received = String::new(); + while let Ok(chunk) = rx.try_recv() { + received.push_str(&chunk); + } + assert_eq!(received, content); + } + + #[tokio::test] + async fn test_stream_text_stops_when_receiver_dropped() { + let content = "A".repeat(1000); + let (tx, rx) = tokio::sync::mpsc::channel(2); + // Drop the receiver immediately — stream_text should return Ok without panic + drop(rx); + + let result = LlmClient::stream_text(content, tx).await; + assert!( + result.is_ok(), + "stream_text must return Ok even when receiver is dropped" + ); } #[test] - fn test_stream_request_serializes_stream_true() { - let req = StreamRequest { - model: "test-model".to_string(), - messages: vec![], - tools: None, - tool_choice: None, - max_tokens: 100, - stream: true, - }; - let json = serde_json::to_value(&req).unwrap(); - assert_eq!(json["stream"], true); - assert_eq!(json["model"], "test-model"); + fn test_sanitize_parameters_removes_undefined_required_entries() { + // Google Gemini rejects required entries not present in properties. 
+ let mut schema = serde_json::json!({ + "type": "object", + "properties": { + "foo": { "type": "string" } + }, + "required": ["foo", "bar"] // "bar" is not in properties + }); + sanitize_parameters(&mut schema); + let required = schema["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "foo"); + } + + #[test] + fn test_sanitize_parameters_removes_additional_properties() { + let mut schema = serde_json::json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }); + sanitize_parameters(&mut schema); + assert!(schema.get("additionalProperties").is_none()); + } + + #[test] + fn test_sanitize_parameters_removes_schema_metadata_fields() { + let mut schema = serde_json::json!({ + "type": "object", + "$schema": "http://json-schema.org/draft-07/schema#", + "$defs": {}, + "$ref": "#/$defs/SomeType", + "properties": {} + }); + sanitize_parameters(&mut schema); + assert!(schema.get("$schema").is_none()); + assert!(schema.get("$defs").is_none()); + assert!(schema.get("$ref").is_none()); + } + + #[test] + fn test_sanitize_parameters_recurses_into_properties() { + let mut schema = serde_json::json!({ + "type": "object", + "properties": { + "nested": { + "type": "object", + "properties": { + "x": { "type": "string" } + }, + "required": ["x", "missing"], + "additionalProperties": true + } + } + }); + sanitize_parameters(&mut schema); + let nested = &schema["properties"]["nested"]; + let required = nested["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "x"); + assert!(nested.get("additionalProperties").is_none()); + } + + #[test] + fn test_sanitize_parameters_recurses_into_array_items() { + let mut schema = serde_json::json!({ + "type": "array", + "items": { + "type": "object", + "properties": { + "a": { "type": "number" } + }, + "required": ["a", "b"], + "additionalProperties": false + } + }); + sanitize_parameters(&mut schema); + let items = &schema["items"]; + let 
required = items["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "a"); + assert!(items.get("additionalProperties").is_none()); + } + + #[test] + fn test_sanitize_parameters_valid_schema_unchanged() { + // A schema that is already valid should pass through unmodified. + let mut schema = serde_json::json!({ + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path" } + }, + "required": ["path"] + }); + let original = schema.clone(); + sanitize_parameters(&mut schema); + assert_eq!(schema, original); + } + + #[test] + fn test_sanitize_parameters_removes_empty_required_array() { + // Gemini rejects required: [] — it must be omitted entirely. + let mut schema = serde_json::json!({ + "type": "object", + "properties": {}, + "required": ["a", "b"] // neither "a" nor "b" is in properties + }); + sanitize_parameters(&mut schema); + // required should be gone, not left as [] + assert!(schema.get("required").is_none()); + } + + #[test] + fn test_sanitize_parameters_strips_null_anyof_variants() { + // Gemini does not support {"type": "null"} as a union member. + // When only one variant remains after stripping, it is unwrapped. + let mut schema = serde_json::json!({ + "type": "object", + "properties": { + "name": { + "anyOf": [ + { "type": "string" }, + { "type": "null" } + ] + } + } + }); + sanitize_parameters(&mut schema); + let name = &schema["properties"]["name"]; + // The null variant was stripped, leaving one variant which was unwrapped. + assert!(name.get("anyOf").is_none(), "anyOf should be unwrapped"); + assert_eq!(name["type"], "string"); + } + + #[test] + fn test_sanitize_parameters_recurses_into_anyof() { + // Nested schemas inside anyOf/oneOf should also be sanitized. + // Single-variant anyOf is unwrapped into the parent. 
+ let mut schema = serde_json::json!({ + "type": "object", + "properties": { + "val": { + "anyOf": [ + { + "type": "object", + "properties": { "x": { "type": "string" } }, + "required": ["x", "missing"], + "additionalProperties": false + } + ] + } + } + }); + sanitize_parameters(&mut schema); + // Single-variant anyOf is unwrapped: inner object fields are inlined into "val". + let val = &schema["properties"]["val"]; + assert!(val.get("anyOf").is_none(), "anyOf should be unwrapped"); + let required = val["required"].as_array().unwrap(); + assert_eq!(required.len(), 1); + assert_eq!(required[0], "x"); + assert!(val.get("additionalProperties").is_none()); } } diff --git a/src/mcp.rs b/src/mcp.rs index ca20eef..94da412 100644 --- a/src/mcp.rs +++ b/src/mcp.rs @@ -2,7 +2,10 @@ use anyhow::{Context, Result}; use rmcp::{ model::{CallToolRequestParams, Tool as McpTool}, service::RunningService, - transport::{ConfigureCommandExt, TokioChildProcess}, + transport::{ + streamable_http_client::StreamableHttpClientTransportConfig, ConfigureCommandExt, + StreamableHttpClientTransport, TokioChildProcess, + }, ServiceExt, }; use serde_json::Value; @@ -32,32 +35,80 @@ impl McpManager { } } - /// Connect to an MCP server via stdio child process + /// Connect to an MCP server — dispatches to HTTP or stdio based on config. pub async fn connect(&mut self, config: &McpServerConfig) -> Result<()> { + if config.url.is_some() { + self.connect_http(config).await + } else { + self.connect_stdio(config).await + } + } + + /// Connect to an HTTP-based MCP server using the Streamable HTTP transport. 
+ async fn connect_http(&mut self, config: &McpServerConfig) -> Result<()> { + let url = config + .url + .as_deref() + .context("HTTP MCP server config missing 'url'")?; + + info!("Connecting to HTTP MCP server '{}': {}", config.name, url); + + let transport_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()) + .auth_header(config.auth_token.clone().unwrap_or_default()); + + let transport = StreamableHttpClientTransport::from_config(transport_config); + + // `()` implements rmcp's `ServiceExt` as the default no-op client handler; + // calling `.serve(transport)` on it returns a `RunningService` connected + // to the given transport without any application-level request handling. + let client = ().serve(transport).await.with_context(|| { + format!("Failed to initialize HTTP MCP connection: {}", config.name) + })?; + + self.register_client(config, client).await + } + + /// Connect to a stdio-based MCP server via a child process. + async fn connect_stdio(&mut self, config: &McpServerConfig) -> Result<()> { + let command_str = config + .command + .as_deref() + .context("Stdio MCP server config missing 'command'")?; + info!( - "Connecting to MCP server '{}': {} {:?}", - config.name, config.command, config.args + "Connecting to stdio MCP server '{}': {} {:?}", + config.name, command_str, config.args ); let args = config.args.clone(); let env = config.env.clone(); - let command_str = config.command.clone(); + let cmd = command_str.to_string(); - let transport = TokioChildProcess::new(Command::new(&command_str).configure(move |cmd| { + let transport = TokioChildProcess::new(Command::new(&cmd).configure(move |c| { for arg in &args { - cmd.arg(arg); + c.arg(arg); } for (key, value) in &env { - cmd.env(key, value); + c.env(key, value); } })) .with_context(|| format!("Failed to start MCP server process: {}", config.name))?; + // `()` is rmcp's default no-op client handler; see `connect_http` for details. 
let client = () .serve(transport) .await .with_context(|| format!("Failed to initialize MCP connection: {}", config.name))?; + self.register_client(config, client).await + } + + /// Register a connected client, listing its tools and storing it. + async fn register_client( + &mut self, + config: &McpServerConfig, + client: RunningService, + ) -> Result<()> { let server_info = client.peer_info(); info!( "Connected to MCP server '{}': {:?}", diff --git a/src/platform/telegram.rs b/src/platform/telegram.rs index 640f8a2..e01062f 100644 --- a/src/platform/telegram.rs +++ b/src/platform/telegram.rs @@ -2,10 +2,13 @@ use std::sync::Arc; use anyhow::Result; use teloxide::prelude::*; +use teloxide::types::ParseMode; use tracing::{error, info, warn}; use crate::agent::Agent; use crate::platform::IncomingMessage; +use crate::utils::markdown_entities::{markdown_to_entities, split_entities}; +use crate::utils::telegram_markdown::escape_text; /// Split long messages for Telegram's 4096 char limit #[cfg(test)] @@ -105,22 +108,25 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe { error!("Failed to clear conversation: {}", e); } - bot.send_message(msg.chat.id, "Conversation cleared.") + bot.send_message(msg.chat.id, escape_text("Conversation cleared.")) + .parse_mode(ParseMode::MarkdownV2) .await?; return Ok(()); } if text == "/start" { - bot.send_message( - msg.chat.id, + let help = escape_text( "Hello! I'm your AI assistant. 
Send me a message and I'll help you.\n\n\ Commands:\n\ /clear - Clear conversation history\n\ /tools - List available tools\n\ /skills - List loaded skills\n\ - /verbose - Toggle tool call progress display", - ) - .await?; + /verbose - Toggle tool call progress display\n\ + /query-rewrite - Toggle query rewriting for memory search", + ); + bot.send_message(msg.chat.id, help) + .parse_mode(ParseMode::MarkdownV2) + .await?; return Ok(()); } @@ -133,7 +139,9 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe tool.function.name, tool.function.description )); } - bot.send_message(msg.chat.id, tool_list).await?; + bot.send_message(msg.chat.id, escape_text(&tool_list)) + .parse_mode(ParseMode::MarkdownV2) + .await?; return Ok(()); } @@ -141,13 +149,17 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe let skills_guard = agent.skills.read().await; let skills = skills_guard.list(); if skills.is_empty() { - bot.send_message(msg.chat.id, "No skills loaded.").await?; + bot.send_message(msg.chat.id, escape_text("No skills loaded.")) + .parse_mode(ParseMode::MarkdownV2) + .await?; } else { let mut skill_list = String::from("Loaded skills:\n\n"); for skill in &skills { skill_list.push_str(&format!(" - {}: {}\n", skill.name, skill.description)); } - bot.send_message(msg.chat.id, skill_list).await?; + bot.send_message(msg.chat.id, escape_text(&skill_list)) + .parse_mode(ParseMode::MarkdownV2) + .await?; } return Ok(()); } @@ -175,7 +187,43 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe } else { "🔇 Tool call UI disabled. I'll respond silently." 
}; - bot.send_message(msg.chat.id, reply).await?; + bot.send_message(msg.chat.id, escape_text(reply)) + .parse_mode(ParseMode::MarkdownV2) + .await?; + return Ok(()); + } + + if text == "/query-rewrite" { + let current = agent + .memory + .recall("settings", &format!("query_rewrite_enabled_{}", user_id)) + .await + .unwrap_or(None); + // When no per-user setting exists, fall back to the global config default. + let currently_on = match current.as_deref() { + Some("true") => true, + Some("false") => false, + _ => agent.config.memory.query_rewriter_enabled, + }; + let new_value = if currently_on { "false" } else { "true" }; + agent + .memory + .remember( + "settings", + &format!("query_rewrite_enabled_{}", user_id), + new_value, + None, + ) + .await + .ok(); + let reply = if new_value == "true" { + "🔍 Query rewriting enabled. Follow-up questions will be rewritten before memory search." + } else { + "🔍 Query rewriting disabled. Messages will be searched as-is." + }; + bot.send_message(msg.chat.id, escape_text(reply)) + .parse_mode(ParseMode::MarkdownV2) + .await?; return Ok(()); } @@ -218,6 +266,21 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe None }; + // When verbose is OFF, send a transient "Thinking..." placeholder so the user + // knows the bot is processing. The stream handle will edit it in-place with the + // first LLM tokens, so only one message is ever visible. 
+ let placeholder_msg_id: Option = if !verbose_enabled { + match bot.send_message(msg.chat.id, "⏳ Thinking...").await { + Ok(sent) => Some(sent.id), + Err(e) => { + tracing::warn!(error = %e, "Failed to send thinking placeholder"); + None + } + } + } else { + None + }; + // Streaming: set up token channel for progressive message display const TELEGRAM_STREAM_SPLIT: usize = 3800; @@ -230,7 +293,8 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe use std::time::{Duration, Instant}; let mut buffer = String::new(); - let mut current_msg_id: Option = None; + // Seed current_msg_id with the placeholder so the first edit reuses it. + let mut current_msg_id: Option = placeholder_msg_id; let mut last_action = Instant::now(); let mut rx = stream_token_rx; @@ -270,17 +334,43 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe } } - // Final: flush whatever is left in the buffer + // Final: flush whatever is left in the buffer. + // Use the entity-based approach: convert completed Markdown to (plain_text, entities). + // This is robust for LLM output — no escaping needed, no risk of Telegram 400 errors. + // Intermediate streaming edits remain plain text (partial markdown is fragile). 
if !buffer.is_empty() { - if let Some(msg_id) = current_msg_id { - stream_bot - .edit_message_text(stream_chat_id, msg_id, &buffer) - .await - .ok(); - } else { - // No intermediate message was sent — deliver the complete response now - stream_bot.send_message(stream_chat_id, &buffer).await.ok(); + const MAX_UTF16: usize = 4090; + let (plain_text, entities) = markdown_to_entities(&buffer); + let chunks = split_entities(&plain_text, &entities, MAX_UTF16); + + for (i, (chunk_text, chunk_entities)) in chunks.iter().enumerate() { + if i == 0 { + // First chunk: edit or replace the existing in-progress message + if let Some(msg_id) = current_msg_id { + stream_bot + .edit_message_text(stream_chat_id, msg_id, chunk_text) + .entities(chunk_entities.clone()) + .await + .ok(); + } else { + stream_bot + .send_message(stream_chat_id, chunk_text) + .entities(chunk_entities.clone()) + .await + .ok(); + } + } else { + // Subsequent chunks: send as new messages + stream_bot + .send_message(stream_chat_id, chunk_text) + .entities(chunk_entities.clone()) + .await + .ok(); + } } + } else if let Some(msg_id) = current_msg_id { + // Edge case: no tokens were streamed but we have a placeholder — delete it + stream_bot.delete_message(stream_chat_id, msg_id).await.ok(); } }); @@ -316,7 +406,8 @@ async fn handle_message(bot: Bot, msg: Message, agent: Arc) -> ResponseRe if let Err(e) = process_result { warn!(error = %e, "Agent processing failed"); - bot.send_message(msg.chat.id, format!("Error: {:#}", e)) + bot.send_message(msg.chat.id, escape_text(&format!("Error: {:#}", e))) + .parse_mode(ParseMode::MarkdownV2) .await?; } // Success: response already delivered via streaming @@ -366,6 +457,32 @@ mod tests { } } + #[test] + fn test_final_flush_uses_entity_based_conversion() { + // The final flush must call markdown_to_entities (entity-based approach) instead of + // MarkdownV2 parse_mode. This is a source inspection test. 
+ let source = include_str!("telegram.rs"); + assert!( + source.contains("markdown_to_entities"), + "Final flush must call markdown_to_entities for robust formatting" + ); + assert!( + source.contains("split_entities"), + "Final flush must call split_entities for long message handling" + ); + } + + #[test] + fn test_command_responses_use_escape_text() { + // All non-streaming command responses must escape plain text and use MarkdownV2 + // so that special chars like `.`, `-`, `!`, `_`, `(`, `)` don't break the parser. + let source = include_str!("telegram.rs"); + assert!( + source.contains("escape_text"), + "Command responses must call escape_text() before sending with MarkdownV2" + ); + } + #[test] fn test_stream_handle_does_not_require_placeholder_send() { // If the initial send fails, the stream handle must NOT silently swallow diff --git a/src/platform/tool_notifier.rs b/src/platform/tool_notifier.rs index 7a7124d..fa062ed 100644 --- a/src/platform/tool_notifier.rs +++ b/src/platform/tool_notifier.rs @@ -17,6 +17,94 @@ pub enum ToolEvent { Completed { name: String, success: bool }, } +/// Convert a technical tool name to a human-readable label with an emoji prefix. +/// +/// Priority: +/// 1. Exact match for known built-in tools. +/// 2. MCP tools prefixed with `mcp__` — server icon + humanised function name. +/// 3. Fallback — replace underscores with spaces and capitalise the first letter. +pub fn friendly_tool_name(name: &str) -> String { + // 1. 
Built-in tools — exact matches + let label = match name { + "read_file" => return "📄 Reading a file".to_string(), + "write_file" => return "✏️ Writing a file".to_string(), + "list_files" => return "📁 Listing files".to_string(), + "execute_command" => return "💻 Running a command".to_string(), + "schedule_task" => return "🗓️ Scheduling a task".to_string(), + "list_scheduled_tasks" => return "🗓️ Checking scheduled tasks".to_string(), + "cancel_scheduled_task" => return "🗓️ Cancelling a task".to_string(), + "invoke_agent" | "invoke_subagent" => return "🤖 Calling a specialist".to_string(), + "plan_create" | "plan_update" | "plan_view" => return "📋 Managing plan".to_string(), + "read_skill_file" | "write_skill_file" => return "📖 Reading skill".to_string(), + "reload_skills" | "reload_agents" => return "🔄 Reloading".to_string(), + "read_agent_file" | "write_agent_file" => return "🤖 Agent file".to_string(), + _ => name, + }; + + // 2. MCP tools: mcp__ + if let Some(rest) = label.strip_prefix("mcp_") { + // Known server names (with both hyphen and underscore variants) + // Sorted by length descending to match longest first (handles server names with underscores) + let known_servers = [ + ("google-workspace", "📧"), + ("google_workspace", "📧"), + ("brave-search", "🔍"), + ("brave_search", "🔍"), + ("puppeteer", "🌐"), + ("filesystem", "📁"), + ("github", "🐙"), + ("sqlite", "🗄️"), + ("threads", "🧵"), + ("fetch", "🌐"), + ("git", "📦"), + ]; + + // Try to match against known server names + for (server_name, icon) in &known_servers { + if let Some(func) = rest.strip_prefix(&format!("{}_", server_name)) { + let human = humanise_function_name(func); + return format!("{} {}", icon, human); + } + } + + // Fallback: split on first underscore (for unknown servers) + if let Some(sep) = rest.find('_') { + let func = &rest[sep + 1..]; + let human = humanise_function_name(func); + return format!("🔧 {}", human); + } + } + + // 3. 
Fallback — snake_case → "Snake case" + let human = humanise_function_name(label); + format!("🔧 {}", human) +} + +/// Convert a `snake_case_function_name` to a human-readable sentence. +/// Strips common verbose verb prefixes and capitalises the first letter. +fn humanise_function_name(func: &str) -> String { + // Strip common verbose prefixes that don't add meaning to the user + let stripped = func + .strip_prefix("query_") + .or_else(|| func.strip_prefix("search_")) + .or_else(|| func.strip_prefix("get_")) + .or_else(|| func.strip_prefix("list_")) + .unwrap_or(func); + + // Replace underscores with spaces + let spaced = stripped.replace('_', " "); + + // Capitalise first letter + let mut chars = spaced.chars(); + match chars.next() { + None => String::new(), + Some(first) => { + let upper: String = first.to_uppercase().collect(); + upper + chars.as_str() + } + } +} + /// Formats `args_preview` for display: truncate to 60 chars, strip outer braces for common single-arg calls. pub fn format_args_preview(args_json: &str) -> String { // Try to extract a single-value preview for readability @@ -30,7 +118,7 @@ pub fn format_args_preview(args_json: &str) -> String { other => other.to_string(), }; let truncated = crate::utils::strings::truncate_chars(&s, 60); - return format!("\"{}\"", truncated); + return truncated; } } } @@ -39,6 +127,17 @@ pub fn format_args_preview(args_json: &str) -> String { crate::utils::strings::truncate_chars(args_json, 60) } +/// Build the one-line tool status string streamed into the Telegram message +/// while the tool is running. Ends with `\n` so multiple calls stack visibly. +pub fn format_tool_status_line(name: &str, args_preview: &str) -> String { + let label = friendly_tool_name(name); + if args_preview.is_empty() { + format!("⏳ {}\n", label) + } else { + format!("⏳ {}: {}\n", label, args_preview) + } +} + /// Manages the live-edited Telegram status message during agent tool execution. 
#[allow(dead_code)] pub struct ToolCallNotifier { @@ -124,38 +223,43 @@ impl ToolCallNotifier { } else { "❌" }; - s.push_str(&format!("\n{} {}({})", icon, name, args_preview)); + let label = friendly_tool_name(name); + if args_preview.is_empty() { + s.push_str(&format!("\n{} {}", icon, label)); + } else { + s.push_str(&format!("\n{} {}: {}", icon, label, args_preview)); + } } s } - /// Finalise the status message. + /// Finalise the status message by deleting it. /// - /// - If no tools were called: delete the placeholder "⏳ Working..." (not useful). - /// - If tools were called: edit to a persistent summary so the user can see - /// which tools ran after the response has arrived. + /// The status message is always deleted once agent processing is complete. + /// The final LLM response (streamed separately) is the only message that + /// remains visible to the user. pub async fn finish(&self) { let Some(ref msg) = self.status_msg else { return; }; - if self.tool_log.is_empty() { - self.bot.delete_message(self.chat_id, msg.id).await.ok(); - } else { - let text = self.format_final(); - self.bot - .edit_message_text(self.chat_id, msg.id, &text) - .await - .ok(); - } + self.bot.delete_message(self.chat_id, msg.id).await.ok(); } - /// Final compact summary shown after tools have run. + /// Format a compact summary of tools that ran. + /// Currently unused because `finish()` always deletes the status message, + /// but kept for potential future use (e.g. re-enabling persistent summaries). 
+ #[allow(dead_code)] fn format_final(&self) -> String { let mut s = String::from("🔧 Tools used:"); for (name, args_preview, _done, success) in &self.tool_log { let icon = if *success { "✅" } else { "❌" }; - s.push_str(&format!("\n{} {}({})", icon, name, args_preview)); + let label = friendly_tool_name(name); + if args_preview.is_empty() { + s.push_str(&format!("\n{} {}", icon, label)); + } else { + s.push_str(&format!("\n{} {}: {}", icon, label, args_preview)); + } } s } @@ -169,7 +273,8 @@ mod tests { fn test_format_args_preview_single_string_arg() { let json = r#"{"query":"Docker setup preferences"}"#; let preview = format_args_preview(json); - assert_eq!(preview, r#""Docker setup preferences""#); + // Quotes are no longer added — value is returned directly + assert_eq!(preview, "Docker setup preferences"); } #[test] @@ -178,7 +283,7 @@ mod tests { let json = format!(r#"{{"query":"{}"}}"#, long); let preview = format_args_preview(&json); assert!(preview.len() <= 70, "Preview should be truncated"); - assert!(preview.ends_with("...\"") || preview.contains("...")); + assert!(preview.contains("...")); } #[test] @@ -199,67 +304,192 @@ mod tests { #[test] fn test_format_final_includes_all_tools() { - // Build a notifier-like tool_log directly and call format_final via a helper. - // format_final is private — test it through a thin wrapper. + // Build a notifier-like tool_log directly and call format_final via a helper + // that mirrors the real implementation using friendly_tool_name. 
fn fake_format_final(tool_log: &[(String, String, bool, bool)]) -> String { let mut s = String::from("🔧 Tools used:"); for (name, args_preview, _done, success) in tool_log { let icon = if *success { "✅" } else { "❌" }; - s.push_str(&format!("\n{} {}({})", icon, name, args_preview)); + let label = friendly_tool_name(name); + if args_preview.is_empty() { + s.push_str(&format!("\n{} {}", icon, label)); + } else { + s.push_str(&format!("\n{} {}: {}", icon, label, args_preview)); + } } s } let log = vec![ - ( - "search".to_string(), - r#""Docker setup""#.to_string(), - true, - true, - ), + ("search".to_string(), "Docker setup".to_string(), true, true), ( "read_file".to_string(), - r#""/etc/config""#.to_string(), + "/etc/config".to_string(), true, false, ), ]; let result = fake_format_final(&log); assert!(result.contains("🔧 Tools used:"), "header missing"); - assert!(result.contains("✅ search"), "successful tool icon wrong"); - assert!(result.contains("❌ read_file"), "failed tool icon wrong"); + assert!(result.contains("✅"), "successful tool icon wrong"); + assert!(result.contains("❌"), "failed tool icon wrong"); assert!(result.contains("Docker setup"), "args missing for search"); assert!( !result.contains("⏳ Working"), "should not contain in-progress text" ); + // read_file should be humanised + assert!( + result.contains("📄 Reading a file"), + "read_file must be humanised" + ); } #[test] fn test_format_args_preview_single_arg_with_chinese() { - // Tests the single-arg extraction path with a Chinese string. - // This particular string's byte-60 happens to fall on a valid UTF-8 boundary, - // so it currently passes — after the UTF-8 truncation fix it will continue to pass. 
let long_chinese = "每日上午10點 arXiv AI 論文摘要(香港時間)很長的標題讓我們繼續寫下去直到超過六十個字"; let json = format!(r#"{{"query":"{}"}}"#, long_chinese); let preview = format_args_preview(&json); + assert!(!preview.is_empty()); + assert!(std::str::from_utf8(preview.as_bytes()).is_ok()); + } + + #[test] + fn test_format_tool_status_line_shows_hourglass_and_friendly_name() { + // web_search has no built-in mapping → falls through to 🔧 fallback + let line = format_tool_status_line("web_search", "Docker setup"); assert!( - preview.contains("\""), - "should be quoted single-arg preview" + line.starts_with("⏳"), + "status line must start with hourglass: {line}" ); - assert!(std::str::from_utf8(preview.as_bytes()).is_ok()); + assert!( + line.contains("Docker setup"), + "status line must include args preview: {line}" + ); + assert!( + line.ends_with('\n'), + "status line must end with newline: {line}" + ); + } + + #[test] + fn test_format_tool_status_line_builtin_tool_humanised() { + let line = format_tool_status_line("read_file", "/etc/config"); + assert!( + line.contains("📄 Reading a file"), + "built-in tool must be humanised: {line}" + ); + assert!(line.contains("/etc/config"), "args must be shown: {line}"); + assert!(line.ends_with('\n'), "must end with newline: {line}"); } #[test] - fn test_format_args_preview_multi_arg_chinese_panics_before_fix() { - // Multi-arg JSON falls through to the raw-JSON fallback path (lines 43-44). - // This test currently PANICS (fails) because &args_json[..60] hits byte 60 - // inside the multi-byte character '香'. After the UTF-8 truncation fix is - // applied, the slice will be adjusted to a safe boundary and this test will pass. 
+ fn test_format_tool_status_line_ends_with_newline() { + let line = format_tool_status_line("read_file", "/etc/config"); + assert!( + line.ends_with('\n'), + "status line must end with newline for streaming: {line}" + ); + } + + #[test] + fn test_format_tool_status_line_empty_args() { + let line = format_tool_status_line("list_files", ""); + assert!( + !line.is_empty(), + "status line must not be empty even with no args" + ); + assert!( + line.contains("📁 Listing files"), + "list_files must be humanised: {line}" + ); + } + + #[test] + fn test_format_args_preview_multi_arg_chinese_truncates_safely() { let args = r#"{"description":"每日上午10點 arXiv AI 論文摘要(香港時間)","prompt":"使用 arxiv-daily-briefing skill","trigger_type":"recurring","trigger_value":"0 0 2 * * *"}"#; let preview = format_args_preview(args); assert!(!preview.is_empty()); assert!(std::str::from_utf8(preview.as_bytes()).is_ok()); } + + // --- friendly_tool_name --- + + #[test] + fn test_friendly_tool_name_builtin_read_file() { + assert_eq!(friendly_tool_name("read_file"), "📄 Reading a file"); + } + + #[test] + fn test_friendly_tool_name_builtin_execute_command() { + assert_eq!( + friendly_tool_name("execute_command"), + "💻 Running a command" + ); + } + + #[test] + fn test_friendly_tool_name_builtin_invoke_agent() { + assert_eq!( + friendly_tool_name("invoke_agent"), + "🤖 Calling a specialist" + ); + assert_eq!( + friendly_tool_name("invoke_subagent"), + "🤖 Calling a specialist" + ); + } + + #[test] + fn test_friendly_tool_name_mcp_brave_search() { + let name = "mcp_brave-search_search_web"; + let friendly = friendly_tool_name(name); + assert!(friendly.contains("🔍"), "brave-search icon: {friendly}"); + assert!( + !friendly.contains("mcp_"), + "should not contain raw prefix: {friendly}" + ); + } + + #[test] + fn test_friendly_tool_name_mcp_google_workspace() { + let name = "mcp_google-workspace_query_gmail_emails"; + let friendly = friendly_tool_name(name); + assert!(friendly.contains("📧"), "google-workspace 
icon: {friendly}"); + } + + #[test] + fn test_friendly_tool_name_mcp_unknown_server() { + let name = "mcp_myserver_do_something"; + let friendly = friendly_tool_name(name); + assert!( + friendly.contains("🔧"), + "unknown server must use fallback icon: {friendly}" + ); + } + + #[test] + fn test_friendly_tool_name_fallback_snake_case() { + let friendly = friendly_tool_name("some_unknown_tool"); + assert!( + friendly.starts_with("🔧"), + "unknown tool must use 🔧: {friendly}" + ); + assert!( + friendly.contains("Some unknown tool"), + "should humanise snake_case: {friendly}" + ); + } + + #[test] + fn test_friendly_tool_name_strips_verb_prefixes() { + // "query_" prefix is stripped in MCP humanisation + let name = "mcp_fetch_query_url"; + let friendly = friendly_tool_name(name); + assert!( + !friendly.contains("query_"), + "verb prefix should be stripped: {friendly}" + ); + } } diff --git a/src/utils/markdown_entities.rs b/src/utils/markdown_entities.rs new file mode 100644 index 0000000..cade990 --- /dev/null +++ b/src/utils/markdown_entities.rs @@ -0,0 +1,598 @@ +//! Convert a Markdown string to a `(plain_text, Vec)` pair suitable for +//! sending via the Telegram Bot API without any `parse_mode`. +//! +//! Inspired by [telegramify-markdown](https://github.com/sudoskys/telegramify-markdown) +//! by sudoskys. +//! +//! # Why entities instead of MarkdownV2? +//! +//! Telegram's `MarkdownV2` parse mode requires escaping 17+ special characters precisely. +//! Any mistake causes a 400 error and the bot falls back to raw unformatted text. The +//! entity approach sends plain text alongside a list of formatting spans — no escaping +//! needed, zero risk of parse failures. +//! +//! # Telegram entity offset semantics +//! +//! Telegram measures entity offsets and lengths in **UTF-16 code units**, not bytes or +//! Unicode scalar values. All offset conversions in this module use +//! `str[..n].encode_utf16().count()` to stay correct for CJK, emoji, and other +//! 
characters whose UTF-16 representation differs from UTF-8. + +use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; +use teloxide::types::MessageEntity; +use tracing::warn; + +/// Convert `markdown` to a `(plain_text, entities)` pair ready to pass to Telegram. +/// +/// The returned `plain_text` contains no Markdown syntax — all formatting information +/// is encoded in the `entities` list. Offsets and lengths in each entity are in +/// UTF-16 code units as required by the Telegram Bot API. +/// +/// Supported conversions: +/// - `**bold**` → `Bold` +/// - `*italic*` / `_italic_` → `Italic` +/// - `` `code` `` → `Code` +/// - ` ```lang\n...\n``` ` → `Pre { language }` +/// - `[text](url)` → `TextLink { url }` +/// - `# Heading` / `## Heading` / `### Heading` → `Bold` +/// - `~~strikethrough~~` → `Strikethrough` +pub fn markdown_to_entities(markdown: &str) -> (String, Vec) { + let mut options = Options::empty(); + options.insert(Options::ENABLE_STRIKETHROUGH); + options.insert(Options::ENABLE_GFM); + + let parser = Parser::new_ext(markdown, options); + + let mut plain = String::new(); + let mut entities: Vec = Vec::new(); + + // Stack of (tag, utf16_start_offset_in_plain_text) + // We push on Start and pop+emit on End. 
+ let mut stack: Vec<(StackTag, usize)> = Vec::new(); + + // Track UTF-16 length incrementally to avoid O(n²) rescanning + let mut plain_utf16_len = 0usize; + + for event in parser { + match event { + // --- Text content --- + Event::Text(text) => { + plain.push_str(&text); + plain_utf16_len += text.encode_utf16().count(); + } + Event::Code(text) => { + // Inline code: emit as a Code entity + let start_utf16 = plain_utf16_len; + plain.push_str(&text); + let text_utf16_len = text.encode_utf16().count(); + plain_utf16_len += text_utf16_len; + let length = text_utf16_len; + if length > 0 { + entities.push(MessageEntity::code(start_utf16, length)); + } + } + Event::SoftBreak => { + plain.push('\n'); + plain_utf16_len += 1; + } + Event::HardBreak => { + plain.push('\n'); + plain_utf16_len += 1; + } + + // --- Block / inline formatting starts --- + Event::Start(tag) => match tag { + Tag::Strong => { + stack.push((StackTag::Bold, plain_utf16_len)); + } + Tag::Emphasis => { + stack.push((StackTag::Italic, plain_utf16_len)); + } + Tag::Strikethrough => { + stack.push((StackTag::Strikethrough, plain_utf16_len)); + } + Tag::Link { dest_url, .. } => { + stack.push((StackTag::Link(dest_url.to_string()), plain_utf16_len)); + } + Tag::Heading { .. } => { + stack.push((StackTag::Heading, plain_utf16_len)); + } + Tag::CodeBlock(kind) => { + let lang = match &kind { + CodeBlockKind::Fenced(lang) => { + let s = lang.trim().to_string(); + if s.is_empty() { + None + } else { + Some(s) + } + } + CodeBlockKind::Indented => None, + }; + stack.push((StackTag::CodeBlock(lang), plain_utf16_len)); + } + // Paragraph, list, etc. — no entity emitted on start. 
+ _ => {} + }, + + // --- Block / inline formatting ends --- + Event::End(tag_end) => { + match tag_end { + TagEnd::Strong => { + if let Some((StackTag::Bold, start)) = stack.pop() { + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + entities.push(MessageEntity::bold(start, length)); + } + } + } + TagEnd::Emphasis => { + if let Some((StackTag::Italic, start)) = stack.pop() { + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + entities.push(MessageEntity::italic(start, length)); + } + } + } + TagEnd::Strikethrough => { + if let Some((StackTag::Strikethrough, start)) = stack.pop() { + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + entities.push(MessageEntity::strikethrough(start, length)); + } + } + } + TagEnd::Link => { + if let Some((StackTag::Link(url_str), start)) = stack.pop() { + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + // Parse the URL; if invalid, skip the entity (text is still kept) + if let Ok(url) = reqwest::Url::parse(&url_str) { + entities.push(MessageEntity::text_link(url, start, length)); + } else { + warn!( + "markdown_to_entities: invalid link URL ignored: {}", + url_str + ); + } + } + } + } + TagEnd::Heading(_) => { + if let Some((StackTag::Heading, start)) = stack.pop() { + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + entities.push(MessageEntity::bold(start, length)); + } + } + // Headings are block elements — add a newline after + plain.push('\n'); + plain_utf16_len += 1; + } + TagEnd::CodeBlock => { + if let Some((StackTag::CodeBlock(lang), start)) = stack.pop() { + // Trim trailing newline added by pulldown-cmark inside the block + if plain.ends_with('\n') { + plain.pop(); + plain_utf16_len -= 1; + } + let length = plain_utf16_len.saturating_sub(start); + if length > 0 { + entities.push(MessageEntity::pre(lang, start, length)); + } + } + plain.push('\n'); + plain_utf16_len += 1; + } + TagEnd::Paragraph => { + // Add 
blank line after each paragraph (double newline to preserve paragraph breaks) + plain.push_str("\n\n"); + plain_utf16_len += 2; + } + TagEnd::Item => { + plain.push('\n'); + plain_utf16_len += 1; + } + _ => {} + } + } + + // Ignore HTML, footnotes, rules, etc. + _ => {} + } + } + + // Trim trailing newlines (at most two, from the last paragraph's \n\n) + while plain.ends_with('\n') && plain_utf16_len > 0 { + plain.pop(); + plain_utf16_len -= 1; + } + + (plain, entities) +} + +/// Split a `(text, entities)` pair into chunks whose UTF-16 length does not exceed +/// `max_utf16_len`. Offsets in child entity lists are adjusted to be relative to +/// each chunk's start. +/// +/// Splitting tries to break at `\n` or space boundaries; it falls back to a hard +/// character boundary if no such split point exists in the window. +pub fn split_entities( + text: &str, + entities: &[MessageEntity], + max_utf16_len: usize, +) -> Vec<(String, Vec<MessageEntity>)> { + // Precompute cumulative UTF-16 lengths up to each char boundary to make + // offset lookups O(1) instead of O(n). + let char_utf16_boundaries: Vec<(usize, usize)> = { + let mut v = Vec::new(); + let mut utf16_acc = 0usize; + for (byte_pos, ch) in text.char_indices() { + v.push((byte_pos, utf16_acc)); + utf16_acc += ch.len_utf16(); + } + v.push((text.len(), utf16_acc)); + v + }; + + let total_utf16 = char_utf16_boundaries.last().map(|x| x.1).unwrap_or(0); + + if total_utf16 <= max_utf16_len { + return vec![(text.to_string(), entities.to_vec())]; + } + + let mut result = Vec::new(); + let mut chunk_utf16_start = 0usize; // UTF-16 offset into original text where chunk starts + + while chunk_utf16_start < total_utf16 { + let chunk_utf16_end_ideal = (chunk_utf16_start + max_utf16_len).min(total_utf16); + + // Find the byte position corresponding to chunk_utf16_end_ideal (or the last + // char boundary at or before it).
+ let mut split_utf16 = find_split_point( + text, + &char_utf16_boundaries, + chunk_utf16_start, + chunk_utf16_end_ideal, + ); + + // Ensure progress even if `find_split_point` cannot find a boundary inside + // the requested window (for example, when `max_utf16_len` is smaller than + // a single character's UTF-16 length, such as an emoji/surrogate pair). + if split_utf16 <= chunk_utf16_start { + split_utf16 = char_utf16_boundaries + .iter() + .map(|&(_, utf16)| utf16) + .find(|&utf16| utf16 > chunk_utf16_start) + .unwrap_or(total_utf16); + } + + // Convert UTF-16 offsets back to byte offsets for slicing + let start_byte = utf16_to_byte(&char_utf16_boundaries, chunk_utf16_start); + let end_byte = utf16_to_byte(&char_utf16_boundaries, split_utf16); + + let chunk_text = text[start_byte..end_byte].to_string(); + + // Collect entities that overlap this chunk and adjust their offsets + let chunk_entities: Vec<MessageEntity> = entities + .iter() + .filter_map(|e| { + let e_start = e.offset; + let e_end = e_start + e.length; + let chunk_end = split_utf16; + + // Entity must overlap the chunk + if e_end <= chunk_utf16_start || e_start >= chunk_end { + return None; + } + + let clamped_start = e_start.max(chunk_utf16_start); + let clamped_end = e_end.min(chunk_end); + let new_offset = clamped_start - chunk_utf16_start; + let new_length = clamped_end - clamped_start; + + if new_length == 0 { + return None; + } + + let mut cloned = e.clone(); + cloned.offset = new_offset; + cloned.length = new_length; + Some(cloned) + }) + .collect(); + + result.push((chunk_text, chunk_entities)); + chunk_utf16_start = split_utf16; + } + + result +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +/// Internal tag discriminant stored on the formatting stack.
+enum StackTag { + Bold, + Italic, + Strikethrough, + Link(String), + Heading, + CodeBlock(Option<String>), +} + +/// Given a `(byte_pos, cumulative_utf16)` table, convert a UTF-16 offset to a byte offset. +/// Returns the byte position of the char boundary at or just before `utf16_off`. +/// When `utf16_off` falls in the middle of a surrogate pair (i.e. it does not exactly +/// match any entry), the byte position of the preceding character is returned. +fn utf16_to_byte(boundaries: &[(usize, usize)], utf16_off: usize) -> usize { + match boundaries.binary_search_by_key(&utf16_off, |&(_, u)| u) { + Ok(idx) => boundaries[idx].0, + Err(idx) => { + // idx is the insertion point — snap to the preceding char boundary. + if idx == 0 { + 0 + } else { + boundaries[idx - 1].0 + } + } + } +} + +/// Find a good UTF-16 split point at or before `ideal_end` that lands on a `\n` or +/// space character if possible, otherwise falls back to the exact boundary. +fn find_split_point( + text: &str, + boundaries: &[(usize, usize)], + start_utf16: usize, + ideal_end_utf16: usize, +) -> usize { + if ideal_end_utf16 >= boundaries.last().map(|x| x.1).unwrap_or(0) { + return ideal_end_utf16; + } + + let ideal_byte = utf16_to_byte(boundaries, ideal_end_utf16); + let start_byte = utf16_to_byte(boundaries, start_utf16); + let window = &text[start_byte..ideal_byte]; + + // Prefer newline, then space + let split_byte_in_window = window + .rfind('\n') + .or_else(|| window.rfind(' ')) + .map(|pos| pos + 1); // keep the delimiter in the previous chunk + + let split_byte = split_byte_in_window + .map(|off| start_byte + off) + .unwrap_or(ideal_byte); + + // Convert split_byte back to utf16 offset + match boundaries.binary_search_by_key(&split_byte, |&(b, _)| b) { + Ok(idx) => boundaries[idx].1, + Err(idx) => { + if idx == 0 { + start_utf16 + } else { + boundaries[idx - 1].1 + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn entity_kind_name(e: &MessageEntity) -> &str { + use
teloxide::types::MessageEntityKind::*; + match &e.kind { + Bold => "bold", + Italic => "italic", + Code => "code", + Pre { .. } => "pre", + TextLink { .. } => "text_link", + Strikethrough => "strikethrough", + _ => "other", + } + } + + // --- Basic formatting --- + + #[test] + fn test_bold_converts_to_entity() { + let (text, entities) = markdown_to_entities("**bold text**"); + assert_eq!(text, "bold text"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "bold"); + assert_eq!(entities[0].offset, 0); + assert_eq!(entities[0].length, 9); + } + + #[test] + fn test_italic_asterisk_converts_to_entity() { + let (text, entities) = markdown_to_entities("*italic*"); + assert_eq!(text, "italic"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "italic"); + } + + #[test] + fn test_italic_underscore_converts_to_entity() { + let (text, entities) = markdown_to_entities("_italic_"); + assert_eq!(text, "italic"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "italic"); + } + + #[test] + fn test_inline_code_converts_to_entity() { + let (text, entities) = markdown_to_entities("`code`"); + assert_eq!(text, "code"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "code"); + } + + #[test] + fn test_fenced_code_block_converts_to_pre_entity() { + let input = "```rust\nfn main() {}\n```"; + let (text, entities) = markdown_to_entities(input); + assert_eq!(text, "fn main() {}"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "pre"); + if let teloxide::types::MessageEntityKind::Pre { + language: Some(lang), + } = &entities[0].kind + { + assert_eq!(lang, "rust"); + } else { + panic!("Expected Pre with language: {:?}", entities[0].kind); + } + } + + #[test] + fn test_link_converts_to_text_link_entity() { + let (text, entities) = markdown_to_entities("[RustFox](https://github.com)"); + assert_eq!(text, "RustFox"); + assert_eq!(entities.len(), 
1); + assert_eq!(entity_kind_name(&entities[0]), "text_link"); + if let teloxide::types::MessageEntityKind::TextLink { url } = &entities[0].kind { + assert_eq!(url.as_str(), "https://github.com/"); + } else { + panic!("Expected TextLink: {:?}", entities[0].kind); + } + } + + #[test] + fn test_heading_converts_to_bold_entity() { + let (text, entities) = markdown_to_entities("# Hello"); + assert!( + text.contains("Hello"), + "text must contain heading content: {text}" + ); + assert!( + entities.iter().any(|e| entity_kind_name(e) == "bold"), + "heading must produce a bold entity" + ); + } + + #[test] + fn test_strikethrough_converts_to_entity() { + let (text, entities) = markdown_to_entities("~~strikethrough~~"); + assert_eq!(text, "strikethrough"); + assert_eq!(entities.len(), 1); + assert_eq!(entity_kind_name(&entities[0]), "strikethrough"); + } + + // --- UTF-16 offset correctness --- + + #[test] + fn test_bold_with_cjk_correct_utf16_offsets() { + // "你好 **world**" — "你好 " is 3 chars, each 1 UTF-16 unit + let input = "你好 **world**"; + let (text, entities) = markdown_to_entities(input); + assert!(text.contains("你好"), "CJK must appear in plain text"); + assert!( + text.contains("world"), + "bold text must appear in plain text" + ); + + let bold = entities + .iter() + .find(|e| entity_kind_name(e) == "bold") + .unwrap(); + // UTF-16 offset of "world" after "你好 " (3 units) + let expected_offset: usize = "你好 ".encode_utf16().count(); + assert_eq!( + bold.offset, expected_offset, + "UTF-16 offset must account for CJK chars" + ); + assert_eq!(bold.length, 5); // "world" = 5 UTF-16 units + } + + #[test] + fn test_bold_with_emoji_correct_utf16_offsets() { + // Emoji like 🦊 = 2 UTF-16 code units + let input = "🦊 **bold**"; + let (text, entities) = markdown_to_entities(input); + assert!(text.contains("bold"), "bold text must be in plain text"); + let bold = entities + .iter() + .find(|e| entity_kind_name(e) == "bold") + .unwrap(); + // "🦊 " = 3 UTF-16 units (2 for emoji + 1 
space) + let expected_offset: usize = "🦊 ".encode_utf16().count(); + assert_eq!( + bold.offset, expected_offset, + "UTF-16 offset must account for surrogate-pair emoji" + ); + } + + #[test] + fn test_plain_text_no_entities() { + let (text, entities) = markdown_to_entities("Hello world"); + assert_eq!(text, "Hello world"); + assert!(entities.is_empty(), "plain text must produce no entities"); + } + + #[test] + fn test_empty_string_returns_empty() { + let (text, entities) = markdown_to_entities(""); + assert!(text.is_empty()); + assert!(entities.is_empty()); + } + + #[test] + fn test_mixed_bold_and_code() { + let input = "**bold** and `code`"; + let (text, entities) = markdown_to_entities(input); + assert!(text.contains("bold"), "bold text in output"); + assert!(text.contains("code"), "code text in output"); + assert!(entities.iter().any(|e| entity_kind_name(e) == "bold")); + assert!(entities.iter().any(|e| entity_kind_name(e) == "code")); + } + + // --- split_entities --- + + #[test] + fn test_split_entities_short_text_not_split() { + let (text, entities) = markdown_to_entities("**hello**"); + let chunks = split_entities(&text, &entities, 4096); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].0, "hello"); + } + + #[test] + fn test_split_entities_long_text_splits() { + let long = "a ".repeat(3000); // 6000 UTF-16 chars + let (text, entities) = markdown_to_entities(&long); + let chunks = split_entities(&text, &entities, 4096); + assert!(chunks.len() > 1, "long text must be split"); + for (chunk_text, _) in &chunks { + let utf16_len: usize = chunk_text.encode_utf16().count(); + assert!( + utf16_len <= 4096, + "chunk must not exceed max_utf16_len: {} > 4096", + utf16_len + ); + } + } + + #[test] + fn test_split_entities_entity_offsets_adjusted() { + // Two bold words separated by enough filler to force a split + let filler = " ".repeat(4090); + let input = format!("**A**{}**B**", filler); + let (text, entities) = markdown_to_entities(&input); + let chunks = 
split_entities(&text, &entities, 4096); + + // The first chunk should have offset-0 entity for "A" + let first_bold = chunks[0].1.iter().find(|e| entity_kind_name(e) == "bold"); + assert!(first_bold.is_some(), "first chunk must have a bold entity"); + assert_eq!( + first_bold.unwrap().offset, + 0, + "first chunk bold must start at offset 0" + ); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index e8dfd78..eccfbdb 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1 +1,3 @@ +pub mod markdown_entities; pub mod strings; +pub mod telegram_markdown; diff --git a/src/utils/telegram_markdown.rs b/src/utils/telegram_markdown.rs new file mode 100644 index 0000000..5da9e6e --- /dev/null +++ b/src/utils/telegram_markdown.rs @@ -0,0 +1,556 @@ +/// MarkdownV2 special characters that must be escaped with `\` in plain text context. +/// See: https://core.telegram.org/bots/api#markdownv2-style +/// Backslash is listed first so the intent (escape the escaper) is self-evident. +const SPECIAL_CHARS_V2: &[char] = &[ + '\\', '_', '*', '[', ']', '(', ')', '~', '`', '>', '#', '+', '-', '=', '|', '{', '}', '.', '!', +]; + +/// Escape all MarkdownV2 special characters in a plain-text string. +/// Use this for text that should be rendered as literal text (not markup). +pub fn escape_text(s: &str) -> String { + let mut out = String::with_capacity(s.len() * 2); + for c in s.chars() { + if SPECIAL_CHARS_V2.contains(&c) { + out.push('\\'); + } + out.push(c); + } + out +} + +/// Escape characters inside a code span or code block. +/// Only backtick (`) and backslash (\) need escaping inside code. +#[allow(dead_code)] +fn escape_code(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for c in s.chars() { + if c == '\\' || c == '`' { + out.push('\\'); + } + out.push(c); + } + out +} + +/// Find the position of the first unescaped occurrence of `needle` in `haystack`. +/// Returns `None` if not found. 
+#[allow(dead_code)] +fn find_unescaped(haystack: &str, needle: &str) -> Option<usize> { + let mut i = 0; + let bytes = haystack.as_bytes(); + let n = needle.len(); + while i + n <= bytes.len() { + if haystack.is_char_boundary(i + n) && &haystack[i..i + n] == needle { + // Check it's not preceded by backslash (simple check) + if i == 0 || bytes[i - 1] != b'\\' { + return Some(i); + } + } + // Advance by one UTF-8 character + let ch_len = haystack[i..] + .chars() + .next() + .map(|c| c.len_utf8()) + .unwrap_or(1); + i += ch_len; + } + None +} + +/// Convert a line of inline-formatted text to MarkdownV2. +/// Handles: **bold**, `inline code`, [links](url), and plain text escaping. +/// Leaves `_italic_` as-is (MarkdownV2 already uses `_` for italic). +#[allow(dead_code)] +fn convert_inline(s: &str) -> String { + let mut result = String::new(); + let mut remaining = s; + + while !remaining.is_empty() { + // Bold: **text** → *text* + if remaining.starts_with("**") { + let after = &remaining[2..]; + if let Some(close) = find_unescaped(after, "**") { + let inner = &after[..close]; + if !inner.is_empty() { + result.push('*'); + result.push_str(&convert_inline(inner)); + result.push('*'); + remaining = &after[close + 2..]; + continue; + } + } + } + + // Inline code: `text` (but not ``` fenced blocks — those are handled at line level) + if remaining.starts_with('`') && !remaining.starts_with("```") { + let after = &remaining[1..]; + if let Some(close) = after.find('`') { + let inner = &after[..close]; + result.push('`'); + result.push_str(&escape_code(inner)); + result.push('`'); + remaining = &after[close + 1..]; + continue; + } + } + + // Link: [display text](url) + if remaining.starts_with('[') { + if let Some(bracket_close) = find_matching_bracket(remaining) { + let display = &remaining[1..bracket_close]; + let after_bracket = &remaining[bracket_close + 1..]; + if let Some(inside_parens) = after_bracket.strip_prefix('(') { + if let Some(paren_close) = inside_parens.find(')') {
+ let url = &inside_parens[..paren_close]; + result.push('['); + result.push_str(&convert_inline(display)); + result.push_str("]("); + // URL only needs minimal escaping: ) must be escaped + result.push_str(&url.replace(')', "\\)")); + result.push(')'); + remaining = &inside_parens[paren_close + 1..]; + continue; + } + } + } + } + + // `_italic_` in standard markdown stays as `_italic_` in MarkdownV2 + // (both use `_` for italic — no conversion needed, just escape surrounding chars) + // `*single-asterisk*` italic in standard markdown → `_italic_` in MarkdownV2 + if remaining.starts_with('_') && !remaining.starts_with("__") { + let after = &remaining[1..]; + if !after.starts_with(' ') { + if let Some(close) = find_unescaped(after, "_") { + let inner = &after[..close]; + if !inner.is_empty() && !inner.ends_with(' ') { + result.push('_'); + result.push_str(&escape_text(inner)); + result.push('_'); + remaining = &after[close + 1..]; + continue; + } + } + } + } + + // `*single-asterisk*` italic in standard markdown → `_italic_` in MarkdownV2 + if remaining.starts_with('*') && !remaining.starts_with("**") { + let after = &remaining[1..]; + // Only treat as italic if there's non-space content and a closing * + if !after.starts_with(' ') { + if let Some(close) = find_unescaped(after, "*") { + let inner = &after[..close]; + if !inner.is_empty() && !inner.ends_with(' ') { + result.push('_'); + result.push_str(&escape_text(inner)); + result.push('_'); + remaining = &after[close + 1..]; + continue; + } + } + } + } + + // Regular character — escape if special + let ch = remaining.chars().next().unwrap(); + let ch_len = ch.len_utf8(); + if SPECIAL_CHARS_V2.contains(&ch) { + result.push('\\'); + } + result.push(ch); + remaining = &remaining[ch_len..]; + } + + result +} + +/// Find the index of the matching `]` for a `[` at the start of `s`. 
+#[allow(dead_code)] +fn find_matching_bracket(s: &str) -> Option<usize> { + debug_assert!(s.starts_with('[')); + let mut depth = 0usize; + for (i, c) in s.char_indices() { + match c { + '[' => depth += 1, + ']' => { + depth -= 1; + if depth == 0 { + return Some(i); + } + } + _ => {} + } + } + None +} + +/// Convert a single non-code line, handling headers and inline formatting. +#[allow(dead_code)] +fn convert_line(line: &str) -> String { + // ATX headers: # / ## / ### → *Heading* (bold) + for prefix in &["### ", "## ", "# "] { + if let Some(rest) = line.strip_prefix(prefix) { + return format!("*{}*", convert_inline(rest)); + } + } + convert_inline(line) +} + +/// Convert a string of standard markdown to Telegram MarkdownV2 format. +/// +/// Rules applied: +/// - Fenced code blocks (` ``` `) — content is preserved except `` ` `` and `\` are escaped. +/// - Inline code spans (`` ` ... ` ``) — same limited escaping. +/// - `**bold**` → `*bold*` +/// - `*italic*` → `_italic_` +/// - `_italic_` → `_italic_` (unchanged — MarkdownV2 already uses `_`) +/// - `# Heading` / `## Heading` / `### Heading` → `*Heading*` +/// - `[text](url)` → `[text](url)` (text part escaped, URL left as-is) +/// - All other MarkdownV2 special characters in plain text are escaped with `\`. +#[allow(dead_code)] +pub fn markdown_to_telegram_v2(text: &str) -> String { + let mut result = String::new(); + let mut in_code_block = false; + let mut code_fence = String::new(); // tracks the opening fence (e.g. "```" or "```rust") + + // Collect lines once to avoid re-scanning for count on every iteration.
+ let lines: Vec<&str> = text.lines().collect(); + let line_count = lines.len(); + + for (idx, line) in lines.iter().enumerate() { + let is_last_line = idx == line_count.saturating_sub(1); + let newline = if is_last_line && !text.ends_with('\n') { + "" + } else { + "\n" + }; + + let trimmed = line.trim_start(); + + if !in_code_block && trimmed.starts_with("```") { + // Entering a fenced code block + let lang = trimmed.strip_prefix("```").unwrap_or("").trim(); + code_fence = format!("```{}", lang); + result.push_str(&code_fence); + result.push_str(newline); + in_code_block = true; + } else if in_code_block && trimmed.starts_with("```") { + // Closing the fenced code block + result.push_str("```"); + result.push_str(newline); + in_code_block = false; + code_fence.clear(); + } else if in_code_block { + // Inside code block — only escape ` and \ + result.push_str(&escape_code(line)); + result.push_str(newline); + } else { + // Regular line — convert inline formatting + result.push_str(&convert_line(line)); + result.push_str(newline); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + // --- escape_text --- + + #[test] + fn test_escape_plain_text_no_special_chars() { + assert_eq!(escape_text("Hello world"), "Hello world"); + } + + #[test] + fn test_escape_plain_text_dot() { + assert_eq!(escape_text("Hello.World"), "Hello\\.World"); + } + + #[test] + fn test_escape_plain_text_dash() { + assert_eq!(escape_text("foo-bar"), "foo\\-bar"); + } + + #[test] + fn test_escape_plain_text_multiple_special() { + assert_eq!(escape_text("2+2=4"), "2\\+2\\=4"); + } + + #[test] + fn test_escape_plain_text_underscore() { + assert_eq!(escape_text("foo_bar"), "foo\\_bar"); + } + + #[test] + fn test_escape_plain_text_backslash() { + assert_eq!(escape_text("a\\b"), "a\\\\b"); + } + + // --- markdown_to_telegram_v2: plain text --- + + #[test] + fn test_plain_text_no_special_chars_unchanged() { + assert_eq!(markdown_to_telegram_v2("Hello world"), "Hello world"); + } + + 
#[test] + fn test_plain_text_dot_escaped() { + assert_eq!(markdown_to_telegram_v2("Hello."), "Hello\\."); + } + + #[test] + fn test_plain_text_dash_escaped() { + assert_eq!(markdown_to_telegram_v2("foo-bar"), "foo\\-bar"); + } + + #[test] + fn test_plain_text_exclamation_escaped() { + assert_eq!(markdown_to_telegram_v2("Hi!"), "Hi\\!"); + } + + // --- markdown_to_telegram_v2: fenced code blocks --- + + #[test] + fn test_fenced_code_block_content_not_escaped() { + let input = "```\nfoo_bar.baz\n```"; + let output = markdown_to_telegram_v2(input); + assert!( + output.contains("foo_bar.baz"), + "code content must not escape _ or .: {}", + output + ); + } + + #[test] + fn test_fenced_code_block_backtick_escaped() { + let input = "```\nlet x = `template`;\n```"; + let output = markdown_to_telegram_v2(input); + assert!( + output.contains("\\`template\\`"), + "backticks inside code must be escaped: {}", + output + ); + } + + #[test] + fn test_fenced_code_block_backslash_escaped() { + let input = "```\npath\\to\\file\n```"; + let output = markdown_to_telegram_v2(input); + assert!( + output.contains("path\\\\to\\\\file"), + "backslashes inside code must be escaped: {}", + output + ); + } + + #[test] + fn test_fenced_code_block_with_language() { + let input = "```rust\nfn main() {}\n```"; + let output = markdown_to_telegram_v2(input); + assert!( + output.starts_with("```rust\n"), + "language tag must be kept: {}", + output + ); + assert!( + output.contains("fn main() {}"), + "function body must not be escaped: {}", + output + ); + } + + // --- markdown_to_telegram_v2: inline code --- + + #[test] + fn test_inline_code_content_not_escaped_for_dots() { + let input = "Use `foo.bar()` here"; + let output = markdown_to_telegram_v2(input); + // dot must NOT be escaped inside inline code + assert!( + output.contains("`foo.bar()`"), + "inline code dot must not be escaped: {}", + output + ); + } + + #[test] + fn test_inline_code_backtick_escaped() { + let input = "This `has \\` 
inside` it"; + let output = markdown_to_telegram_v2(input); + assert!( + output.contains("\\`"), + "backtick in inline code must be escaped: {}", + output + ); + } + + // --- markdown_to_telegram_v2: bold --- + + #[test] + fn test_bold_double_asterisk_converted() { + assert_eq!(markdown_to_telegram_v2("**bold**"), "*bold*"); + } + + #[test] + fn test_bold_with_special_char_in_text() { + assert_eq!(markdown_to_telegram_v2("**foo.bar**"), "*foo\\.bar*"); + } + + #[test] + fn test_bold_in_sentence() { + let out = markdown_to_telegram_v2("This is **important** text."); + assert!( + out.contains("*important*"), + "bold must be converted: {}", + out + ); + assert!( + out.ends_with("\\."), + "trailing dot must be escaped: {}", + out + ); + } + + // --- markdown_to_telegram_v2: italic --- + + #[test] + fn test_italic_single_asterisk_converted() { + assert_eq!(markdown_to_telegram_v2("*italic*"), "_italic_"); + } + + #[test] + fn test_italic_underscore_unchanged() { + // MarkdownV2 already uses _ for italic; no conversion needed + assert_eq!(markdown_to_telegram_v2("_italic_"), "_italic_"); + } + + // --- markdown_to_telegram_v2: headers --- + + #[test] + fn test_h1_converted_to_bold() { + assert_eq!(markdown_to_telegram_v2("# Heading"), "*Heading*"); + } + + #[test] + fn test_h2_converted_to_bold() { + assert_eq!(markdown_to_telegram_v2("## Sub"), "*Sub*"); + } + + #[test] + fn test_h3_converted_to_bold() { + assert_eq!(markdown_to_telegram_v2("### Sub"), "*Sub*"); + } + + #[test] + fn test_header_with_special_char() { + assert_eq!(markdown_to_telegram_v2("# Hello."), "*Hello\\.*"); + } + + // --- markdown_to_telegram_v2: links --- + + #[test] + fn test_link_passes_through() { + let out = markdown_to_telegram_v2("[RustFox](https://github.com)"); + assert!( + out.starts_with("[RustFox]"), + "display text must be kept: {}", + out + ); + assert!( + out.contains("(https://github.com)"), + "URL must be kept: {}", + out + ); + } + + #[test] + fn 
test_link_display_special_chars_escaped() { + let out = markdown_to_telegram_v2("[foo.bar](https://x.com)"); + assert!( + out.contains("foo\\.bar"), + "dot in display text must be escaped: {}", + out + ); + } + + // --- markdown_to_telegram_v2: bullet lists --- + + #[test] + fn test_bullet_list_dash_escaped() { + let out = markdown_to_telegram_v2("- item one"); + assert!( + out.starts_with("\\- "), + "leading dash must be escaped: {}", + out + ); + } + + #[test] + fn test_bullet_list_asterisk_escaped() { + let out = markdown_to_telegram_v2("* item"); + // * followed by space is a list marker — it is treated as italic formatting + // which converts to _item_ (italic). If there's no valid closing *, it's escaped. + // Either way the output must be valid MarkdownV2. + assert!(!out.is_empty()); + } + + // --- markdown_to_telegram_v2: mixed content --- + + #[test] + fn test_mixed_bold_and_code_block() { + let input = "**Summary:**\n\n```rust\nfn hello() {}\n```"; + let output = markdown_to_telegram_v2(input); + assert!(output.contains("*Summary"), "bold must be converted"); + assert!(output.contains("```rust"), "code block must be kept"); + assert!( + output.contains("fn hello() {}"), + "code body must not be over-escaped" + ); + } + + #[test] + fn test_trailing_newline_preserved_if_present() { + let input = "Hello\n"; + let output = markdown_to_telegram_v2(input); + assert!(output.ends_with('\n'), "trailing newline must be preserved"); + } + + #[test] + fn test_no_trailing_newline_preserved_if_absent() { + let input = "Hello"; + let output = markdown_to_telegram_v2(input); + assert!(!output.ends_with('\n'), "no extra newline should be added"); + } + + #[test] + fn test_empty_string_returns_empty() { + assert_eq!(markdown_to_telegram_v2(""), ""); + } + + // --- UTF-8 multi-byte character safety --- + + #[test] + fn test_bold_in_chinese_text_does_not_panic() { + // Reproduces: byte index 14 is not a char boundary inside '年' + let input = "ComfyUI 2025 年度回顧** 🔥"; + // 
Must not panic; output must be valid UTF-8 + let out = markdown_to_telegram_v2(input); + assert!(std::str::from_utf8(out.as_bytes()).is_ok()); + } + + #[test] + fn test_bold_wrapping_chinese_text_converts_correctly() { + let input = "**2025 年度回顧**"; + let out = markdown_to_telegram_v2(input); + assert!( + out.starts_with('*') && out.ends_with('*'), + "bold must wrap: {out}" + ); + assert!(std::str::from_utf8(out.as_bytes()).is_ok()); + } +}