diff --git a/CHANGELOG.md b/CHANGELOG.md index 29f6c3c..c75260c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +[Unreleased] + +- Add multi-language support for Markdown files: fenced code blocks are now spell-checked using the appropriate language grammar (Python, Rust, Bash, etc.) +- Add language injection system via `@injection.*` capture tags in `.scm` query files. Adding multi-language support to a new file type requires only a `.scm` change, no Rust code +- Add HTML block injection in Markdown. Block-level HTML is spell-checked using the HTML grammar +- Add language alias resolution for Markdown code blocks (e.g., `py`, `js`, `sh`, `rs`, `yml`, `c++`) +- Pre-compile all tree-sitter queries at startup for faster spell-checking and earlier error detection +- Reduce per-word memory allocations in the spell-check pipeline +- Fix Erlang query producing duplicate captures for function name atoms +- Refactor: split spell-checking into separate extraction (`parser.rs`) and checking (`checker.rs`) modules + [0.3.35] - Add tag-based filtering (`include_tags`/`exclude_tags`) to control which parts of code are spell-checked (comments, strings, identifiers, etc.) diff --git a/crates/codebook/src/checker.rs b/crates/codebook/src/checker.rs new file mode 100644 index 0000000..95a2dff --- /dev/null +++ b/crates/codebook/src/checker.rs @@ -0,0 +1,123 @@ +use std::collections::{HashMap, HashSet}; + +use crate::dictionaries::dictionary::Dictionary; +use crate::parser::{TextRange, WordLocation}; +use codebook_config::CodebookConfig; + +/// A candidate word extracted from a text node, with its position +/// in original-document byte offsets. Borrows the word text from the +/// source document to avoid per-word String allocations. +#[derive(Debug, Clone, PartialEq)] +pub struct WordCandidate<'a> { + pub word: &'a str, + pub start_byte: usize, + pub end_byte: usize, +} + +/// Check candidate words against dictionaries and config rules. +/// Returns WordLocations for misspelled words, grouping all locations +/// of the same word together. +pub fn check_words( + candidates: &[WordCandidate<'_>], + dictionaries: &[std::sync::Arc], + config: &dyn CodebookConfig, +) -> Vec { + // Group candidates by word text, deduplicating identical spans. + let mut word_positions: HashMap<&str, HashSet> = HashMap::new(); + for candidate in candidates { + let location = TextRange { + start_byte: candidate.start_byte, + end_byte: candidate.end_byte, + }; + let added = word_positions + .entry(candidate.word) + .or_default() + .insert(location); + + debug_assert!( + added, + "Two of the same locations found. Make a better query. Word: {}, Location: {:?}", + candidate.word, location + ); + } + + // Check each unique word once + let mut results = Vec::new(); + for (word, positions) in word_positions { + let positions: Vec = positions.into_iter().collect(); + if config.should_flag_word(word) { + results.push(WordLocation::new(word.to_string(), positions)); + continue; + } + if word.len() < config.get_min_word_length() { + continue; + } + if config.is_allowed_word(word) { + continue; + } + let is_correct = dictionaries.iter().any(|dict| dict.check(word)); + if !is_correct { + results.push(WordLocation::new(word.to_string(), positions)); + } + } + results +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dictionaries::dictionary::TextDictionary; + use std::sync::Arc; + + fn make_candidates<'a>(words: &[(&'a str, usize, usize)]) -> Vec> { + words + .iter() + .map(|(word, start, end)| WordCandidate { + word, + start_byte: *start, + end_byte: *end, + }) + .collect() + } + + #[test] + fn test_check_words_flags_unknown() { + let dict = Arc::new(TextDictionary::new("hello\nworld\n")); + let config = Arc::new(codebook_config::CodebookConfigMemory::default()); + let candidates = make_candidates(&[("hello", 0, 5), ("wrld", 6, 10)]); + let results = check_words(&candidates, &[dict], config.as_ref()); + assert_eq!(results.len(), 1); + assert_eq!(results[0].word, "wrld"); + } + + #[test] + fn test_check_words_groups_locations() { + let dict = Arc::new(TextDictionary::new("hello\n")); + let config = Arc::new(codebook_config::CodebookConfigMemory::default()); + let candidates = make_candidates(&[("wrld", 0, 4), ("wrld", 10, 14)]); + let results = check_words(&candidates, &[dict], config.as_ref()); + assert_eq!(results.len(), 1); + assert_eq!(results[0].word, "wrld"); + assert_eq!(results[0].locations.len(), 2); + } + + #[test] + fn test_check_words_respects_min_length() { + let dict = Arc::new(TextDictionary::new("")); + let config = Arc::new(codebook_config::CodebookConfigMemory::default()); + // Default min word length is 3 + let candidates = make_candidates(&[("ab", 0, 2)]); + let results = check_words(&candidates, &[dict], config.as_ref()); + assert!(results.is_empty(), "Short words should be skipped"); + } + + #[test] + fn test_check_words_respects_allowed_words() { + let dict = Arc::new(TextDictionary::new("")); + let config = Arc::new(codebook_config::CodebookConfigMemory::default()); + config.add_word("codebook").unwrap(); + let candidates = make_candidates(&[("codebook", 0, 8)]); + let results = check_words(&candidates, &[dict], config.as_ref()); + assert!(results.is_empty(), "Allowed words should not be flagged"); + } +} diff --git a/crates/codebook/src/dictionaries/dictionary.rs b/crates/codebook/src/dictionaries/dictionary.rs index 82e75d7..3ba0d80 100644 --- a/crates/codebook/src/dictionaries/dictionary.rs +++ b/crates/codebook/src/dictionaries/dictionary.rs @@ -7,10 +7,6 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::parser::{WordLocation, find_locations}; -use crate::queries::LanguageType; -use regex::Regex; - pub trait Dictionary: Send + Sync { fn check(&self, word: &str) -> bool; fn suggest(&self, word: &str) -> Vec; @@ -170,17 +166,6 @@ impl TextDictionary { } } -/// Integration helper to use any Dictionary trait with optimized batch processing -pub fn find_locations_with_dictionary_batch( - text: &str, - language: LanguageType, - dictionary: &dyn Dictionary, - skip_patterns: &[Regex], -) -> Vec { - // For non-HashSet dictionaries, we still get deduplication benefits - find_locations(text, language, |word| dictionary.check(word), |_| true, skip_patterns) -} - #[cfg(test)] mod dictionary_tests { use super::*; diff --git a/crates/codebook/src/lib.rs b/crates/codebook/src/lib.rs index 5cc6840..19cfe95 100644 --- a/crates/codebook/src/lib.rs +++ b/crates/codebook/src/lib.rs @@ -1,3 +1,4 @@ +pub mod checker; pub mod dictionaries; mod logging; pub mod parser; @@ -6,6 +7,7 @@ pub mod regexes; mod splitter; use crate::regexes::get_default_skip_patterns; +use std::collections::HashSet; use std::path::Path; use std::sync::Arc; @@ -38,48 +40,35 @@ impl Codebook { file_path: Option<&str>, ) -> Vec { if let Some(file_path) = file_path { - // ignore_paths is a blocklist and has higher precedence than include_paths if self.config.should_ignore_path(Path::new(file_path)) { return Vec::new(); } - // include_paths is an allowlist; empty list means "include everything" if !self.config.should_include_path(Path::new(file_path)) { return Vec::new(); } } - // get needed dictionary names - // get needed dictionaries - // call spell check on each dictionary + let language = self.resolve_language(language, file_path); - let dictionaries = self.get_dictionaries(Some(language)); - // Combine default and user patterns + + // Combine default and user skip patterns let mut all_patterns = get_default_skip_patterns().clone(); if let Some(user_patterns) = self.config.get_ignore_patterns() { all_patterns.extend(user_patterns); } - parser::find_locations( + + // Extract all words, recursively following injections + let (candidates, languages_found) = parser::extract_all_words( text, language, - |word| { - if self.config.should_flag_word(word) { - return false; - } - if word.len() < self.config.get_min_word_length() { - return true; - } - if self.config.is_allowed_word(word) { - return true; - } - for dictionary in &dictionaries { - if dictionary.check(word) { - return true; - } - } - false - }, - |tag| self.config.should_check_tag(tag), + &|tag| self.config.should_check_tag(tag), &all_patterns, - ) + ); + + // Load dictionaries for all languages encountered + let dictionaries = self.get_dictionaries_for_languages(&languages_found); + + // Check words against dictionaries + checker::check_words(&candidates, &dictionaries, self.config.as_ref()) } fn resolve_language( @@ -87,7 +76,6 @@ impl Codebook { language_type: Option, path: Option<&str>, ) -> queries::LanguageType { - // Check if we have a language_id first, fallback to path, fall back to text match language_type { Some(lang) => lang, None => match path { @@ -97,21 +85,26 @@ impl Codebook { } } - fn get_dictionaries( + /// Gather dictionaries for all languages encountered in a file. + fn get_dictionaries_for_languages( &self, - language: Option, + languages: &HashSet, ) -> Vec> { let mut dictionary_ids = self.config.get_dictionary_ids(); - if let Some(lang) = language { - let language_dictionary_ids = lang.dictionary_ids(); - dictionary_ids.extend(language_dictionary_ids); - }; + + for lang in languages { + dictionary_ids.extend(lang.dictionary_ids()); + } + dictionary_ids.extend(DEFAULT_DICTIONARIES.iter().map(|f| f.to_string())); + + dictionary_ids.sort(); + dictionary_ids.dedup(); + let mut dictionaries = Vec::with_capacity(dictionary_ids.len()); debug!("Checking text with dictionaries: {dictionary_ids:?}"); for dictionary_id in dictionary_ids { - let dictionary = self.manager.get_dictionary(&dictionary_id); - if let Some(d) = dictionary { + if let Some(d) = self.manager.get_dictionary(&dictionary_id) { dictionaries.push(d); } } @@ -125,9 +118,8 @@ impl Codebook { } pub fn get_suggestions(&self, word: &str) -> Option> { - // Get top suggestions and return the first 5 suggestions in round robin order let max_results = 5; - let dictionaries = self.get_dictionaries(None); + let dictionaries = self.get_dictionaries_for_languages(&HashSet::new()); let mut is_misspelled = false; let suggestions: Vec> = dictionaries .iter() @@ -176,9 +168,7 @@ mod tests { vec!["date", "elderberry", "fig"], vec!["grape", "honeydew", "kiwi"], ]; - let result = collect_round_robin(&sources, 5); - // Round-robin order: first from each source, then second from each source assert_eq!( result, vec!["apple", "date", "grape", "banana", "elderberry"] @@ -192,13 +182,6 @@ mod tests { vec!["banana", "cherry", "date"], vec!["cherry", "date", "elderberry"], ]; - - // In round-robin, we get: - // 1. apple (1st from 1st source) - // 2. banana (1st from 2nd source) - cherry already taken - // 3. cherry (1st from 3rd source) - // 4. banana (2nd from 1st source) - // 5. date (3rd from 2nd source) - cherry already taken let result = collect_round_robin(&sources, 5); assert_eq!( result, @@ -213,8 +196,6 @@ mod tests { vec!["elderberry"], vec!["fig", "grape"], ]; - - // Round-robin order with uneven sources let result = collect_round_robin(&sources, 7); assert_eq!( result, @@ -240,8 +221,6 @@ mod tests { #[test] fn test_collect_round_robin_some_empty_sources() { let sources = vec![vec!["apple", "banana"], vec![], vec!["cherry", "date"]]; - - // Round-robin order, skipping empty source let result = collect_round_robin(&sources, 4); assert_eq!(result, vec!["apple", "cherry", "banana", "date"]); } @@ -249,8 +228,6 @@ mod tests { #[test] fn test_collect_round_robin_with_numbers() { let sources = vec![vec![1, 3, 5], vec![2, 4, 6]]; - - // Round-robin order with numbers let result = collect_round_robin(&sources, 6); assert_eq!(result, vec![1, 2, 3, 4, 5, 6]); } @@ -262,8 +239,6 @@ mod tests { vec!["date", "elderberry", "fig"], vec!["grape", "honeydew", "kiwi"], ]; - - // First round of round-robin (first from each source) let result = collect_round_robin(&sources, 3); assert_eq!(result, vec!["apple", "date", "grape"]); } @@ -271,8 +246,6 @@ mod tests { #[test] fn test_collect_round_robin_max_count_higher_than_available() { let sources = vec![vec!["apple", "banana"], vec!["cherry", "date"]]; - - // Round-robin order for all available elements let result = collect_round_robin(&sources, 10); assert_eq!(result, vec!["apple", "banana", "cherry", "date"]); } diff --git a/crates/codebook/src/parser.rs b/crates/codebook/src/parser.rs index 894c3b1..0ba3759 100644 --- a/crates/codebook/src/parser.rs +++ b/crates/codebook/src/parser.rs @@ -1,8 +1,9 @@ -use crate::splitter::{self}; - -use crate::queries::{LanguageType, get_language_setting}; +use crate::checker::WordCandidate; +use crate::queries::{LANGUAGE_SETTINGS, LanguageType, get_language_setting}; +use crate::splitter; use regex::Regex; use std::collections::{HashMap, HashSet}; +use std::str::FromStr; use std::sync::{LazyLock, Mutex}; use streaming_iterator::StreamingIterator; use tree_sitter::{Parser, Query, QueryCursor}; @@ -14,6 +15,43 @@ use unicode_segmentation::UnicodeSegmentation; static PARSER_CACHE: LazyLock>> = LazyLock::new(|| Mutex::new(HashMap::new())); +/// Pre-compiled query for a language, with its capture names. +struct CompiledQuery { + query: Query, + capture_names: Vec, +} + +/// All tree-sitter queries compiled eagerly at startup. Since queries come +/// from static `include_str!` data, they never change at runtime. Compiling +/// them once here means bad queries panic immediately rather than hiding +/// until a user opens that file type. +static COMPILED_QUERIES: LazyLock> = LazyLock::new(|| { + let mut map = HashMap::new(); + for setting in LANGUAGE_SETTINGS { + let Some(lang) = setting.language() else { + continue; + }; + if setting.query.is_empty() { + continue; + } + let query = Query::new(&lang, setting.query) + .unwrap_or_else(|e| panic!("Failed to compile query for {:?}: {e}", setting.type_)); + let capture_names = query + .capture_names() + .iter() + .map(|s| s.to_string()) + .collect(); + map.insert( + setting.type_, + CompiledQuery { + query, + capture_names, + }, + ); + } + map +}); + #[derive(Debug, Clone, Copy, PartialEq, Ord, Eq, PartialOrd, Hash)] pub struct TextRange { /// Start position in utf-8 byte offset @@ -24,27 +62,21 @@ pub struct TextRange { #[derive(Debug, Clone, Copy, PartialEq)] struct SkipRange { - /// Start position in utf-8 byte offset start_byte: usize, - /// End position in utf-8 byte offset end_byte: usize, } -/// Check if a word at [start, end) is entirely within any skip range fn is_within_skip_range(start: usize, end: usize, skip_ranges: &[SkipRange]) -> bool { skip_ranges .iter() .any(|r| start >= r.start_byte && end <= r.end_byte) } -/// Find skip ranges from pattern matches in text. fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec { if patterns.is_empty() { return Vec::new(); } - let mut ranges = Vec::new(); - for pattern in patterns { for regex_match in pattern.find_iter(text) { ranges.push(SkipRange { @@ -53,20 +85,16 @@ fn find_skip_ranges(text: &str, patterns: &[Regex]) -> Vec { }); } } - ranges.sort_by_key(|r| r.start_byte); merge_overlapping_ranges(ranges) } -/// Merge overlapping or adjacent ranges fn merge_overlapping_ranges(ranges: Vec) -> Vec { if ranges.is_empty() { return ranges; } - let mut merged = Vec::new(); let mut current = ranges[0]; - for range in ranges.into_iter().skip(1) { if range.start_byte <= current.end_byte { current.end_byte = current.end_byte.max(range.end_byte); @@ -79,88 +107,6 @@ fn merge_overlapping_ranges(ranges: Vec) -> Vec { merged } -/// Helper struct to handle text position tracking and word extraction -struct TextProcessor { - text: String, - skip_ranges: Vec, -} - -impl TextProcessor { - fn new(text: &str, skip_patterns: &[Regex]) -> Self { - let skip_ranges = find_skip_ranges(text, skip_patterns); - Self { - text: text.to_string(), - skip_ranges, - } - } - - fn should_skip(&self, start_byte: usize, word_len: usize) -> bool { - is_within_skip_range(start_byte, start_byte + word_len, &self.skip_ranges) - } - - fn process_words_with_check(&self, mut check_function: F) -> Vec - where - F: FnMut(&str) -> bool, - { - // First pass: collect all unique words with their positions - let estimated_words = (self.text.len() as f64 / 6.0).ceil() as usize; - let mut word_positions: HashMap<&str, Vec> = - HashMap::with_capacity(estimated_words); - - for (offset, word) in self.text.split_word_bound_indices() { - if is_alphabetic(word) && !self.should_skip(offset, word.len()) { - self.collect_split_words(word, offset, &mut word_positions); - } - } - - // Second pass: batch check unique words and filter - let mut result_locations: HashMap> = HashMap::new(); - for (word_text, positions) in word_positions { - if !check_function(word_text) { - result_locations.insert(word_text.to_string(), positions); - } - } - - result_locations - .into_iter() - .map(|(word, locations)| WordLocation::new(word, locations)) - .collect() - } - - fn extract_words(&self) -> Vec { - // Reuse the word collection logic by collecting all words (check always returns false) - self.process_words_with_check(|_| false) - } - - fn collect_split_words<'a>( - &self, - word: &'a str, - offset: usize, - word_positions: &mut HashMap<&'a str, Vec>, - ) { - if !word.is_empty() { - let split = splitter::split(word); - for split_word in split { - if !is_numeric(split_word.word) { - let word_start_byte = offset + split_word.start_byte; - let location = TextRange { - start_byte: word_start_byte, - end_byte: word_start_byte + split_word.word.len(), - }; - let word_text = split_word.word; - word_positions.entry(word_text).or_default().push(location); - } - } - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct WordRef<'a> { - pub word: &'a str, - pub position: (u32, u32), // (start_char, line) -} - #[derive(Debug, Clone, PartialEq)] pub struct WordLocation { pub word: String, @@ -173,40 +119,77 @@ impl WordLocation { } } -pub fn find_locations( - text: &str, +// ============================================================================= +// Main entry point: recursive word extraction with injection support +// ============================================================================= + +/// Extract all candidate words from a document, recursively following +/// `@injection.*` captures in .scm query files to handle multi-language files. +/// +/// Returns the candidates and the set of all languages encountered (for +/// dictionary loading). +pub fn extract_all_words<'a>( + document_text: &'a str, language: LanguageType, - check_function: impl Fn(&str) -> bool, - tag_filter: impl Fn(&str) -> bool, + tag_filter: &dyn Fn(&str) -> bool, skip_patterns: &[Regex], -) -> Vec { - match language { - LanguageType::Text => { - let processor = TextProcessor::new(text, skip_patterns); - processor.process_words_with_check(|word| check_function(word)) - } - _ => find_locations_code( - text, - language, - |word| check_function(word), - &tag_filter, - skip_patterns, - ), - } +) -> (Vec>, HashSet) { + let skip_ranges = find_skip_ranges(document_text, skip_patterns); + let mut result = ExtractionResult { + candidates: Vec::new(), + languages: HashSet::from([language]), + }; + + extract_recursive( + document_text, + 0, + document_text.len(), + language, + tag_filter, + &skip_ranges, + &mut result, + ); + + (result.candidates, result.languages) +} + +/// Accumulated output from recursive word extraction. +struct ExtractionResult<'a> { + candidates: Vec>, + languages: HashSet, } -fn find_locations_code( - text: &str, +/// Recursively extract words from a byte range of the document. +/// +/// For languages with a tree-sitter grammar and .scm query: +/// - Text captures (`@string`, `@comment`, `@identifier.*`) → word-split +/// - Static injections (`@injection.{lang}`) → recurse with that language +/// - Dynamic injections (`@injection.content` + `@injection.language`) → read +/// the language name from the sibling capture, then recurse +/// +/// For LanguageType::Text (no grammar): word-split the entire range. +fn extract_recursive<'a>( + document_text: &'a str, + start_byte: usize, + end_byte: usize, language: LanguageType, - check_function: impl Fn(&str) -> bool, tag_filter: &dyn Fn(&str) -> bool, - skip_patterns: &[Regex], -) -> Vec { - let language_setting = - get_language_setting(language).expect("This _should_ never happen. Famous last words."); + skip_ranges: &[SkipRange], + result: &mut ExtractionResult<'a>, +) { + let language_setting = match get_language_setting(language) { + Some(s) => s, + None => { + // No grammar (e.g. Text): word-split the whole range + let text = &document_text[start_byte..end_byte]; + extract_words_from_text(text, start_byte, skip_ranges, &mut result.candidates); + return; + } + }; - // Parse under global lock to protect external scanners with global C state. - // The lock covers create + parse; Tree is fully owned after parse returns. + let region_text = &document_text[start_byte..end_byte]; + + // Parse under global lock let tree = { let mut cache = PARSER_CACHE.lock().unwrap(); let parser = cache.entry(language).or_insert_with(|| { @@ -215,82 +198,139 @@ fn find_locations_code( parser.set_language(&lang).unwrap(); parser }); - parser.parse(text, None).unwrap() + parser.parse(region_text, None).unwrap() }; let root_node = tree.root_node(); - let lang = language_setting.language().unwrap(); - let query = Query::new(&lang, language_setting.query).unwrap(); - let capture_names = query.capture_names(); + let compiled = COMPILED_QUERIES + .get(&language) + .expect("Language has a LanguageSetting but no compiled query; this should not happen"); let mut cursor = QueryCursor::new(); - let mut word_locations: HashMap> = HashMap::new(); - let provider = text.as_bytes(); - let mut matches_query = cursor.matches(&query, root_node, provider); - - // Find all skip ranges from patterns matched against the full source text - let all_skip_ranges = find_skip_ranges(text, skip_patterns); + let provider = region_text.as_bytes(); + let mut matches_query = cursor.matches(&compiled.query, root_node, provider); while let Some(match_) = matches_query.next() { + // First pass: look for dynamic injection pairs in this match + let mut injection_content: Option = None; + let mut injection_language_text: Option<&str> = None; + for capture in match_.captures { - // Filter by tag - let tag = &capture_names[capture.index as usize]; - if !tag_filter(tag) { - continue; + let tag = &compiled.capture_names[capture.index as usize]; + if tag == "injection.content" { + injection_content = Some(capture.node); + } else if tag == "injection.language" { + injection_language_text = Some(capture.node.utf8_text(provider).unwrap_or("")); + } + } + + // Handle dynamic injection pair + if let Some(content_node) = injection_content { + if let Some(lang_text) = injection_language_text { + let lowered = lang_text.trim().to_lowercase(); + let child_lang = LanguageType::from_str(&lowered); + if let Ok(child_lang) = child_lang + && child_lang != LanguageType::Text + { + let child_start = content_node.start_byte() + start_byte; + let child_end = content_node.end_byte() + start_byte; + if child_start < child_end { + result.languages.insert(child_lang); + extract_recursive( + document_text, + child_start, + child_end, + child_lang, + tag_filter, + skip_ranges, + result, + ); + } + } } + continue; + } + // Second pass: handle text captures and static injections + for capture in match_.captures { + let tag = &compiled.capture_names[capture.index as usize]; let node = capture.node; - let node_start_byte = node.start_byte(); + let node_start = node.start_byte() + start_byte; + let node_end = node.end_byte() + start_byte; - let node_text = node.utf8_text(provider).unwrap(); - let processor = TextProcessor::new(node_text, &[]); - let words = processor.extract_words(); - - // Check words against global skip ranges and dictionary - for word_pos in words { - if !check_function(&word_pos.word) { - for range in word_pos.locations { - let global_start = range.start_byte + node_start_byte; - let global_end = range.end_byte + node_start_byte; - - // Skip if word is entirely within a skip range - if is_within_skip_range(global_start, global_end, &all_skip_ranges) { - continue; - } - - let location = TextRange { - start_byte: global_start, - end_byte: global_end, - }; - if let Some(existing_result) = word_locations.get_mut(&word_pos.word) { - let added = existing_result.insert(location); - debug_assert!( - added, - "Two of the same locations found. Make a better query. Word: {}, Location: {:?}", - word_pos.word, location - ); - } else { - let mut set = HashSet::new(); - set.insert(location); - word_locations.insert(word_pos.word.clone(), set); - } - } + if node_start >= node_end { + continue; + } + + if tag == "language" || tag == "injection.language" { + continue; + } + + if let Some(lang_name) = tag.strip_prefix("injection.") { + // Static injection: @injection.html, @injection.javascript, etc. + if let Ok(child_lang) = LanguageType::from_str(lang_name) + && child_lang != LanguageType::Text + { + result.languages.insert(child_lang); + extract_recursive( + document_text, + node_start, + node_end, + child_lang, + tag_filter, + skip_ranges, + result, + ); } + continue; } + + // Normal text capture: extract words if tag passes filter + if !tag_filter(tag) { + continue; + } + + let node_text = node.utf8_text(provider).unwrap(); + extract_words_from_text(node_text, node_start, skip_ranges, &mut result.candidates); } } +} - word_locations - .keys() - .map(|word| WordLocation { - word: word.clone(), - locations: word_locations - .get(word) - .cloned() - .unwrap_or_default() - .into_iter() - .collect(), - }) - .collect() +// ============================================================================= +// Word extraction from plain text +// ============================================================================= + +fn extract_words_from_text<'a>( + text: &'a str, + base_offset: usize, + skip_ranges: &[SkipRange], + candidates: &mut Vec>, +) { + let mut split_buf = Vec::new(); + for (offset, word) in text.split_word_bound_indices() { + if !is_alphabetic(word) { + continue; + } + let global_offset = base_offset + offset; + if is_within_skip_range(global_offset, global_offset + word.len(), skip_ranges) { + continue; + } + splitter::split_into(word, &mut split_buf); + for split_word in &split_buf { + if is_numeric(split_word.word) { + continue; + } + let word_start = global_offset + split_word.start_byte; + let word_end = word_start + split_word.word.len(); + if is_within_skip_range(word_start, word_end, skip_ranges) { + continue; + } + candidates.push(WordCandidate { + word: split_word.word, + start_byte: word_start, + end_byte: word_end, + }); + } + } } fn is_numeric(s: &str) -> bool { @@ -312,155 +352,141 @@ pub fn get_word_from_string(start_utf16: usize, end_utf16: usize, text: &str) -> } #[cfg(test)] -mod parser_tests { +mod tests { use super::*; #[test] - fn test_spell_checking() { + fn test_extract_words_plain_text() { let text = "HelloWorld calc_wrld"; - let results = find_locations(text, LanguageType::Text, |_| false, |_| true, &[]); - println!("{results:?}"); - assert_eq!(results.len(), 4); + let (words, langs) = extract_all_words(text, LanguageType::Text, &|_| true, &[]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(word_strings.contains(&"Hello")); + assert!(word_strings.contains(&"World")); + assert!(word_strings.contains(&"calc")); + assert!(word_strings.contains(&"wrld")); + assert_eq!(words.len(), 4); + assert!(langs.contains(&LanguageType::Text)); } #[test] - fn test_get_words_from_text() { - let text = r#" - HelloWorld calc_wrld - I'm a contraction, don't ignore me - this is a 3rd line. - "#; - let expected = vec![ - ("Hello", (13, 18)), - ("World", (18, 23)), - ("calc", (24, 28)), - ("wrld", (29, 33)), - ("I'm", (46, 49)), - ("a", (50, 51)), - ("contraction", (52, 63)), - ("don't", (65, 70)), - ("ignore", (71, 77)), - ("me", (78, 80)), - ("this", (93, 97)), - ("is", (98, 100)), - ("a", (101, 102)), - ("rd", (104, 106)), - ("line", (107, 111)), - ]; - let processor = TextProcessor::new(text, &[]); - let words = processor.extract_words(); - println!("{words:?}"); - for word in words { - let loc = word.locations.first().unwrap(); - let pos = (loc.start_byte, loc.end_byte); - assert!( - expected.contains(&(word.word.as_str(), pos)), - "Expected word '{}' to be at position {:?}", - word.word, - pos - ); + fn test_extract_words_contraction() { + let text = "I'm a contraction, wouldn't you agree'?"; + let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"]; + for e in &expected { + assert!(word_strings.contains(e), "Expected word '{e}' not found"); } } #[test] - fn test_contraction() { - let text = "I'm a contraction, wouldn't you agree'?"; - let processor = TextProcessor::new(text, &[]); - let words = processor.extract_words(); - println!("{words:?}"); - let expected = ["I'm", "a", "contraction", "wouldn't", "you", "agree"]; - for word in words { - assert!(expected.contains(&word.word.as_str())); - } + fn test_extract_words_code() { + let text = "// a comment\nfn main() {}"; + let (words, langs) = extract_all_words(text, LanguageType::Rust, &|_| true, &[]); + assert!(!words.is_empty()); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!( + word_strings.contains(&"comment"), + "Should find 'comment' in Rust comment" + ); + assert!(langs.contains(&LanguageType::Rust)); + } + + #[test] + fn test_extract_words_tag_filter() { + let text = "// comment\nlet x = \"string value\";"; + let (words, _) = extract_all_words( + text, + LanguageType::Rust, + &|tag| tag.starts_with("comment"), + &[], + ); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(word_strings.contains(&"comment")); + assert!(!word_strings.contains(&"string")); + assert!(!word_strings.contains(&"value")); + } + + #[test] + fn test_extract_words_with_skip_patterns() { + let text = "check https://example.com this"; + let url_pattern = Regex::new(r"https?://[^\s]+").unwrap(); + let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[url_pattern]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(word_strings.contains(&"check")); + assert!(word_strings.contains(&"this")); + assert!(!word_strings.contains(&"https")); + assert!(!word_strings.contains(&"example")); + } + + #[test] + fn test_extract_words_code_duplicates() { + let text = "// wrld foo wrld"; + let (words, _) = extract_all_words(text, LanguageType::Rust, &|_| true, &[]); + let wrld_words: Vec<_> = words.iter().filter(|w| w.word == "wrld").collect(); + assert_eq!(wrld_words.len(), 2, "Expected two occurrences of 'wrld'"); + } + + #[test] + fn test_markdown_injection_discovers_languages() { + let text = + "# Hello\n\nSome text.\n\n```python\ndef foo(): pass\n```\n\n```bash\necho hi\n```\n"; + let (_, langs) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]); + assert!(langs.contains(&LanguageType::Markdown)); + assert!(langs.contains(&LanguageType::Python)); + assert!(langs.contains(&LanguageType::Bash)); + } + + #[test] + fn test_markdown_injection_extracts_code_words() { + let text = "# Hello\n\n```python\ndef some_functin(): pass\n```\n"; + let (words, _) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(word_strings.contains(&"functin")); + assert!(word_strings.contains(&"Hello")); + } + + #[test] + fn test_markdown_unknown_language_skipped() { + let text = "# Hello\n\n```unknownlang\nbadwwword\n```\n"; + let (words, _) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(!word_strings.contains(&"badwwword")); + } + + #[test] + fn test_markdown_html_block_injection() { + let text = "# Hello\n\n
\n

A misspeled word

\n
\n\nMore text.\n"; + let (words, langs) = extract_all_words(text, LanguageType::Markdown, &|_| true, &[]); + let word_strings: Vec<&str> = words.iter().map(|w| w.word).collect(); + assert!(langs.contains(&LanguageType::HTML)); + assert!(word_strings.contains(&"misspeled")); + assert!(!word_strings.contains(&"div")); } #[test] fn test_get_word_from_string() { - // Test with ASCII characters let text = "Hello World"; assert_eq!(get_word_from_string(0, 5, text), "Hello"); assert_eq!(get_word_from_string(6, 11, text), "World"); - // Test with partial words - assert_eq!(get_word_from_string(2, 5, text), "llo"); - - // Test with Unicode characters let unicode_text = "こんにちは世界"; assert_eq!(get_word_from_string(0, 5, unicode_text), "こんにちは"); assert_eq!(get_word_from_string(5, 7, unicode_text), "世界"); - // Test with emoji (which can be multi-codepoint) let emoji_text = "Hello 👨‍👩‍👧‍👦 World"; assert_eq!(get_word_from_string(6, 17, emoji_text), "👨‍👩‍👧‍👦"); } + #[test] fn test_unicode_character_handling() { crate::logging::init_test_logging(); let text = "©
badword
"; - let processor = TextProcessor::new(text, &[]); - let words = processor.extract_words(); - println!("{words:?}"); - - // Make sure "badword" is included and correctly positioned - assert!(words.iter().any(|word| word.word == "badword")); - - // If "badword" is found, verify its position - if let Some(pos) = words.iter().find(|word| word.word == "badword") { - // The correct position should be 6 (after ©
) - let start_byte = pos.locations.first().unwrap().start_byte; - let end_byte = pos.locations.first().unwrap().end_byte; - assert_eq!( - start_byte, 7, - "Expected 'badword' to start at character position 7" - ); - assert_eq!(end_byte, 14, "Expected 'badword' to be on end_byte 14"); - } else { - panic!("Word 'badword' not found in the text"); - } - } - - #[test] - fn test_duplicate_word_locations() { - // Use a code language to exercise find_locations_code path - let text = "// wrld foo wrld"; - let results = find_locations(text, LanguageType::Rust, |_| false, |_| true, &[]); - let wrld = results.iter().find(|loc| loc.word == "wrld").unwrap(); - assert_eq!( - wrld.locations.len(), - 2, - "Expected two locations for repeated word 'wrld'" - ); + let (words, _) = extract_all_words(text, LanguageType::Text, &|_| true, &[]); + let bad_word = words.iter().find(|w| w.word == "badword"); + assert!(bad_word.is_some(), "Expected 'badword' to be found"); + let bw = bad_word.unwrap(); + assert_eq!(bw.start_byte, 7); + assert_eq!(bw.end_byte, 14); } - - // Something is up with the HTML tree-sitter package - // #[test] - // fn test_spell_checking_with_unicode() { - // crate::log::init_test_logging(); - // let text = "©
badword
"; - - // // Mock spell check function that flags "badword" - // let results = find_locations(text, LanguageType::Html, |word| word != "badword"); - - // println!("{:?}", results); - - // // Ensure "badword" is flagged - // let badword_result = results.iter().find(|loc| loc.word == "badword"); - // assert!(badword_result.is_some(), "Expected 'badword' to be flagged"); - - // // Check if the location is correct - // if let Some(location) = badword_result { - // assert_eq!( - // location.locations.len(), - // 1, - // "Expected exactly one location for 'badword'" - // ); - // let range = &location.locations[0]; - - // // The word should start after "©
" which is 6 characters - // assert_eq!(range.start_char, 6, "Wrong start position for 'badword'"); - - // // The word should end after "badword" which is 13 characters from the start - // assert_eq!(range.end_char, 13, "Wrong end position for 'badword'"); - // } - // } } diff --git a/crates/codebook/src/queries.rs b/crates/codebook/src/queries.rs index ccf83df..df2ce00 100644 --- a/crates/codebook/src/queries.rs +++ b/crates/codebook/src/queries.rs @@ -45,6 +45,11 @@ impl FromStr for LanguageType { return Ok(language.type_); } } + for ext in language.extensions.iter() { + if s == *ext { + return Ok(language.type_); + } + } } Ok(LanguageType::Text) } @@ -209,7 +214,15 @@ pub static LANGUAGE_SETTINGS: &[LanguageSetting] = &[ }, LanguageSetting { type_: LanguageType::Bash, - ids: &["bash", "shellscript", "sh", "shell script"], + ids: &[ + "bash", + "shellscript", + "sh", + "shell script", + "shell", + "zsh", + "fish", + ], dictionary_ids: &["bash"], query: include_str!("queries/bash.scm"), extensions: &["sh", "bash"], @@ -237,7 +250,7 @@ pub static LANGUAGE_SETTINGS: &[LanguageSetting] = &[ }, LanguageSetting { type_: LanguageType::YAML, - ids: &["yaml"], + ids: &["yaml", "yml"], dictionary_ids: &["yaml"], query: include_str!("queries/yaml.scm"), extensions: &["yaml", "yml"], @@ -399,9 +412,9 @@ mod tests { continue; } - let language = language_setting - .language() - .unwrap_or_else(|| panic!("Failed to get language for {:?}", language_setting.type_)); + let language = language_setting.language().unwrap_or_else(|| { + panic!("Failed to get language for {:?}", language_setting.type_) + }); let query = Query::new(&language, language_setting.query).unwrap_or_else(|e| { panic!( @@ -411,13 +424,46 @@ mod tests { }); for name in query.capture_names() { + let is_allowed = ALLOWED_TAGS.contains(name) || name.starts_with("injection."); assert!( - ALLOWED_TAGS.contains(&name.as_ref()), + is_allowed, "Language {:?} uses unknown capture tag @{name}. \ - Allowed tags: {ALLOWED_TAGS:?}", + Allowed tags: {ALLOWED_TAGS:?} (plus injection.* tags)", language_setting.type_, ); } } } + + #[test] + fn test_no_overlap_in_ids_and_extensions() { + use std::collections::HashMap; + + // Map every id and extension to the language that owns it + let mut seen: HashMap<&str, LanguageType> = HashMap::new(); + + for setting in LANGUAGE_SETTINGS { + for &id in setting.ids { + if let Some(&prev) = seen.get(id) { + panic!( + "Duplicate id/extension {id:?}: used by both {:?} and {:?}", + prev, setting.type_ + ); + } + seen.insert(id, setting.type_); + } + for &ext in setting.extensions { + if let Some(&prev) = seen.get(ext) { + // Allow overlap within the same language (e.g. "hs" in both ids and extensions) + if prev != setting.type_ { + panic!( + "Duplicate id/extension {ext:?}: used by both {:?} and {:?}", + prev, setting.type_ + ); + } + } + seen.insert(ext, setting.type_); + } + } + } } diff --git a/crates/codebook/src/queries/README.md b/crates/codebook/src/queries/README.md index dcebba7..9b413a4 100644 --- a/crates/codebook/src/queries/README.md +++ b/crates/codebook/src/queries/README.md @@ -27,6 +27,34 @@ Every capture name is a **tag** that categorizes the matched text. Tags use a do Not every language needs every tag. HTML, for example, only uses `@comment` and `@string`. You can get a feel for which tags are available for a specific language by looking at the `scm` file for that language in this directory. +### Injection Tags (Multi-Language Support) + +Injection tags tell codebook to re-parse a region of the file using a different language's grammar. This is how Markdown code blocks, HTML `