diff --git a/README.md b/README.md index ef9f6e3..492d426 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Spellchecker +[![Go Reference](https://pkg.go.dev/badge/github.com/f1monkey/spellchecker.svg)](https://pkg.go.dev/github.com/f1monkey/spellchecker) +[![CI](https://github.com/f1monkey/spellchecker/actions/workflows/test.yml/badge.svg)](https://github.com/f1monkey/spellchecker/actions/workflows/test.yml) + Yet another spellchecker written in go. - [Spellchecker](#spellchecker) @@ -11,9 +14,9 @@ Yet another spellchecker written in go. - [Test set 2:](#test-set-2) ## Features: -- very small database: approximately 1mb for 30,000 unique words -- average time to fix one word ~35μs -- about 70-74% accuracy in Peter Norvig's test sets (see [benchmarks](#benchmarks)) +- very compact database: ~1 MB for 30,000 unique words +- average time to fix a single word: ~35 µs +- achieves about 70–74% accuracy on Peter Norvig’s test sets (see [benchmarks](#benchmarks)) ## Installation @@ -29,37 +32,37 @@ go get -v github.com/f1monkey/spellchecker ```go func main() { - // Create new instance + // Create a new instance sc, err := spellchecker.New( "abcdefghijklmnopqrstuvwxyz1234567890", // allowed symbols, other symbols will be ignored - spellchecker.WithMaxErrors(2) + spellchecker.WithMaxErrors(2) // see options.go ) if err != nil { panic(err) } - // Read data from any io.Reader + // Load data from any io.Reader in, err := os.Open("data/sample.txt") if err != nil { panic(err) } sc.AddFrom(in) - // Add some words + // Add words manually sc.Add("lock", "stock", "and", "two", "smoking", "barrels") - // Check if a word is correct + // Check if a word is valid result := sc.IsCorrect("coffee") fmt.Println(result) // true - // Fix one word + // Correct a single word fixed, err := sc.Fix("awepon") if err != nil && !errors.Is(err, spellchecker.ErrUnknownWord) { panic(err) } fmt.Println(fixed) // weapon - // Find max=10 suggestions for a word + // Find up to 10 suggestions for a word matches, err := sc.Suggest("rang", 10) if err != nil && !errors.Is(err, spellchecker.ErrUnknownWord) { panic(err) @@ -67,6 +70,10 @@ func main() { fmt.Println(matches) // [range, orange] ``` +### Options + +See [options.go](./options.go) for the list of available options. + ### Save/load ```go @@ -79,7 +86,7 @@ func main() { } sc.Save(out) - // Load saved data from io.Reader + // Load data back from io.Reader in, err = os.Open("data/out.bin") if err != nil { panic(err) @@ -92,26 +99,28 @@ func main() { ### Custom score function -You can provide a custom score function if you need to. +You can provide a custom scoring function if needed: ```go - var scoreFunc spellchecker.ScoreFunc = func(src, candidate []rune, distance, cnt int) float64 { - return 1.0 // return constant score + var fn spellchecker.FilterFunc = func(src, candidate []rune, cnt int) (float64, bool) { + // you can calculate Levenshtein distance here (see defaultFilterFunc in options.go for example) + + return 1.0, true // constant score } - sc, err := spellchecker.New("abc", spellchecker.WithScoreFunc(scoreFunc)) + sc, err := spellchecker.New("abc", spellchecker.WithFilterFunc(fn)) if err != nil { // handle err } - // after you load spellchecker from file - // you will need to provide the function again: + // After loading a spellchecker from a file, + // you need to set the function again: sc, err = spellchecker.Load(inFile) if err != nil { // handle err } - err = sc.WithOpts(spellchecker.WithScoreFunc(scoreFunc)) + err = sc.WithOpts(spellchecker.WithFilterFunc(fn)) if err != nil { // handle err } @@ -125,27 +134,28 @@ Tests are based on data from [Peter Norvig's article about spelling correction]( #### [Test set 1](http://norvig.com/spell-testset1.txt): ``` -Running tool: /usr/local/go/bin/go test -benchmem -run=^$ -bench ^Benchmark_Norvig1$ github.com/f1monkey/spellchecker +Running tool: /usr/bin/go test -benchmem -run=^$ -bench ^Benchmark_Norvig1$ github.com/f1monkey/spellchecker -count=1 goos: linux goarch: amd64 pkg: github.com/f1monkey/spellchecker cpu: 13th Gen Intel(R) Core(TM) i9-13980HX -Benchmark_Norvig1-32 294 3876229 ns/op 74.07 success_percent 200.0 success_words 270.0 total_words 918275 B/op 2150 allocs/op +Benchmark_Norvig1-32 348 3385868 ns/op 74.44 success_percent 201.0 success_words 270.0 total_words 830803 B/op 15504 allocs/op PASS -ok github.com/f1monkey/spellchecker 3.378s +ok github.com/f1monkey/spellchecker 3.723s ``` #### [Test set 2](http://norvig.com/spell-testset2.txt): ``` -Running tool: /usr/local/go/bin/go test -benchmem -run=^$ -bench ^Benchmark_Norvig2$ github.com/f1monkey/spellchecker +Running tool: /usr/bin/go test -benchmem -run=^$ -bench ^Benchmark_Norvig2$ github.com/f1monkey/spellchecker -count=1 goos: linux goarch: amd64 pkg: github.com/f1monkey/spellchecker cpu: 13th Gen Intel(R) Core(TM) i9-13980HX -Benchmark_Norvig2-32 198 6102429 ns/op 70.00 success_percent 280.0 success_words 400.0 total_words 1327385 B/op 3121 allocs/op +Benchmark_Norvig2-32 231 4935406 ns/op 71.25 success_percent 285.0 success_words 400.0 total_words 1270755 B/op 21801 allocs/op PASS -ok github.com/f1monkey/spellchecker 3.895s +ok github.com/f1monkey/spellchecker 4.057s + ``` diff --git a/dictionary.go b/dictionary.go index f0f4992..222d4ed 100644 --- a/dictionary.go +++ b/dictionary.go @@ -4,44 +4,41 @@ import ( "bytes" "encoding" "encoding/gob" - "sort" + "sync" "sync/atomic" - "github.com/agnivade/levenshtein" "github.com/f1monkey/bitmap" ) -type scoreFunc func(src []rune, candidate []rune, distance int, cnt uint) float64 - type dictionary struct { maxErrors int alphabet alphabet nextID func() uint32 - words map[uint32]string + words map[uint32][]rune ids map[string]uint32 counts map[uint32]uint index map[uint64][]uint32 - scoreFunc scoreFunc + filterFunc FilterFunc } -func newDictionary(ab string, scoreFunc scoreFunc, maxErrors int) (*dictionary, error) { +func newDictionary(ab string, filterFunc FilterFunc, maxErrors int) (*dictionary, error) { alphabet, err := newAlphabet(ab) if err != nil { return nil, err } return &dictionary{ - maxErrors: maxErrors, - alphabet: alphabet, - nextID: idSeq(0), - ids: make(map[string]uint32), - words: make(map[uint32]string), - counts: make(map[uint32]uint), - index: make(map[uint64][]uint32), - scoreFunc: scoreFunc, + maxErrors: maxErrors, + alphabet: alphabet, + nextID: idSeq(0), + ids: make(map[string]uint32), + words: make(map[uint32][]rune), + counts: make(map[uint32]uint), + index: make(map[uint64][]uint32), + filterFunc: filterFunc, }, nil } @@ -60,10 +57,11 @@ func (d *dictionary) add(word string, n uint) (uint32, error) { id := d.nextID() d.ids[word] = id - runes := []rune(word) + wordRunes := []rune(word) + d.counts[id] = n - d.words[id] = word - key := sum(d.alphabet.encode(runes)) + d.words[id] = wordRunes + key := sum(d.alphabet.encode(wordRunes)) d.index[key] = append(d.index[key], id) return id, nil @@ -88,110 +86,80 @@ func (d *dictionary) find(word string, n int) []Match { return nil } - candidates := d.getCandidates(word, n) - sort.Slice(candidates, func(i, j int) bool { return candidates[i].Score > candidates[j].Score }) + result := newPriorityQueue(n) - return candidates -} + wordRunes := []rune(word) + bmSrc := d.alphabet.encode(wordRunes) -func (d *dictionary) getCandidates(word string, max int) []Match { - result := newPriorityQueue(max) + // check for transposition or exact match and do early termination if found + // (the most common mistake is a transposition of letters) + d.fillWithCandidates(result, wordRunes, sum(bmSrc)) + if result.Len() != 0 { + return result.DrainSorted() + } - wordRunes := []rune(word) - bmSrc := d.alphabet.encode([]rune(wordRunes)) + bitmaps := bitmapsPool.Get().(map[uint64]struct{}) + d.computeCandidateBitmaps(bitmaps, bmSrc, d.maxErrors) + for bm := range bitmaps { + d.fillWithCandidates(result, wordRunes, bm) + } - // "exact match" OR "candidate has all the same letters as the word but in different order" - key := sum(bmSrc) - ids := d.index[key] - for _, id := range ids { - docWord, ok := d.words[id] - if !ok { - continue + releaseBitmaps(bitmaps) + + return result.DrainSorted() +} + +func (d *dictionary) computeCandidateBitmaps(bitmaps map[uint64]struct{}, src bitmap.Bitmap32, maxFlips int) { + var dfs func(bm bitmap.Bitmap32, level, start int) + dfs = func(bm bitmap.Bitmap32, level, start int) { + key := sum(bm) + if len(d.index[key]) > 0 { + bitmaps[key] = struct{}{} } - distance := levenshtein.ComputeDistance(word, docWord) - if distance > d.maxErrors { - continue + if level == maxFlips { + return } - result.Push(Match{ - Value: docWord, - Score: d.scoreFunc(wordRunes, []rune(docWord), distance, d.counts[id]), - }) - } - // the most common mistake is a transposition of letters. - // so if we found one here, we do early termination - if result.Len() != 0 { - return result.items - } - // @todo perform phonetic analysis with early termination here - for bm := range d.computeCandidateBitmaps(bmSrc) { - ids := d.index[bm] - for _, id := range ids { - docWord, ok := d.words[id] - if !ok { - continue - } - - distance := levenshtein.ComputeDistance(word, docWord) - if distance > d.maxErrors { - continue - } - result.Push(Match{ - Value: docWord, - Score: d.scoreFunc(wordRunes, []rune(docWord), distance, d.counts[id]), - }) + for i := start; i < d.alphabet.len(); i++ { + bm.Xor(uint32(i)) // change one bit + dfs(bm, level+1, i+1) + bm.Xor(uint32(i)) // revert back } } - return result.items + dfs(src.Clone(), 0, 0) } -func (d *dictionary) computeCandidateBitmaps(bmSrc bitmap.Bitmap32) map[uint64]struct{} { - bitmaps := make(map[uint64]struct{}, d.alphabet.len()*5) - bmSrc = bmSrc.Clone() - - var i, j uint32 - // swap one bit - for i = 0; i < uint32(d.alphabet.len()); i++ { - bmSrc.Xor(i) - - // swap one more bit to be able to fix: - // - two deletions ("rang" => "orange") - // - replacements ("problam" => "problem") - for j = 0; j < uint32(d.alphabet.len()); j++ { - if i == j { - continue - } - - bmSrc.Xor(j) - key := sum(bmSrc) - bmSrc.Xor(j) // return back the changed bit - if len(d.index[key]) == 0 { - continue - } - bitmaps[key] = struct{}{} +func (d *dictionary) fillWithCandidates(result *priorityQueue, wordRunes []rune, bm uint64) { + ids := d.index[bm] + for _, id := range ids { + docWord, ok := d.words[id] + if !ok { + continue } - key := sum(bmSrc) - bmSrc.Xor(i) // return back the changed bit - if len(d.index[key]) == 0 { + score, ok := d.filterFunc(wordRunes, docWord, d.counts[id]) + if !ok { continue } - bitmaps[key] = struct{}{} - } - return bitmaps + result.Push(Match{ + Value: string(docWord), + Score: score, + }) + } } var _ encoding.BinaryMarshaler = (*dictionary)(nil) var _ encoding.BinaryUnmarshaler = (*dictionary)(nil) type dictData struct { - Alphabet alphabet - IDs map[string]uint32 - Words map[uint32]string - Counts map[uint32]uint + Alphabet alphabet + IDs map[string]uint32 + Words map[uint32]string + WordRunes map[uint32][]rune + Counts map[uint32]uint Index map[uint64][]uint32 @@ -202,7 +170,7 @@ func (d *dictionary) MarshalBinary() ([]byte, error) { data := &dictData{ Alphabet: d.alphabet, IDs: d.ids, - Words: d.words, + WordRunes: d.words, Counts: d.counts, Index: d.index, MaxErrors: d.maxErrors, @@ -227,10 +195,21 @@ func (d *dictionary) UnmarshalBinary(data []byte) error { d.alphabet = dictData.Alphabet d.ids = dictData.IDs d.counts = dictData.Counts - d.words = dictData.Words + + // compatibility with previous versions + if len(dictData.Words) > 0 { + wordRunes := make(map[uint32][]rune, len(dictData.Words)) + for k, v := range dictData.Words { + wordRunes[k] = []rune(v) + } + d.words = wordRunes + } else { + d.words = dictData.WordRunes + } + d.index = dictData.Index d.maxErrors = dictData.MaxErrors - d.scoreFunc = defaultScorefunc + d.filterFunc = defaultFilterFunc(dictData.MaxErrors) var max uint32 for _, id := range d.ids { @@ -259,3 +238,17 @@ func sum(b bitmap.Bitmap32) uint64 { return result } + +func releaseBitmaps(m map[uint64]struct{}) { + for k := range m { + delete(m, k) + } + + bitmapsPool.Put(m) +} + +var bitmapsPool = sync.Pool{ + New: func() interface{} { + return make(map[uint64]struct{}, 256) + }, +} diff --git a/dictionary_test.go b/dictionary_test.go index a779557..938ae82 100644 --- a/dictionary_test.go +++ b/dictionary_test.go @@ -7,7 +7,7 @@ import ( ) func Test_dictionary_id(t *testing.T) { - dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors) + dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors) require.NoError(t, err) t.Run("must return 0 for unexisting word", func(t *testing.T) { @@ -24,14 +24,14 @@ func Test_dictionary_id(t *testing.T) { func Test_dictionary_add(t *testing.T) { t.Run("must add word to dictionary index", func(t *testing.T) { - dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors) + dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors) require.NoError(t, err) id, err := dict.add("qwe", 1) require.NoError(t, err) require.Equal(t, uint32(1), id) require.Equal(t, uint(1), dict.counts[id]) - require.Equal(t, "qwe", dict.words[id]) + require.Equal(t, []rune("qwe"), dict.words[id]) require.Equal(t, 1, len(dict.ids)) require.Len(t, dict.index, 1) @@ -39,7 +39,7 @@ func Test_dictionary_add(t *testing.T) { require.NoError(t, err) require.Equal(t, uint32(2), id) require.Equal(t, uint(2), dict.counts[id]) - require.Equal(t, "asd", dict.words[id]) + require.Equal(t, []rune("asd"), dict.words[id]) require.Equal(t, 2, len(dict.ids)) require.Len(t, dict.index, 2) @@ -49,7 +49,7 @@ func Test_dictionary_add(t *testing.T) { func Test_Dictionary_Inc(t *testing.T) { t.Run("must increase counter value", func(t *testing.T) { - dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors) + dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors) dict.counts[1] = 0 require.NoError(t, err) diff --git a/go.mod b/go.mod index a5eb1a0..09f0604 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/f1monkey/spellchecker go 1.19 require ( - github.com/agnivade/levenshtein v1.1.1 + github.com/agext/levenshtein v1.2.3 github.com/f1monkey/bitmap v1.4.0 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 0905a72..45fbd50 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,7 @@ -github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= -github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= -github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= -github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= +github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= -github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= -github.com/f1monkey/bitmap v1.2.0 h1:INxVA52Ckcxb7RXrdWU63K5iJDsToTWTtcXWdLe5aNs= -github.com/f1monkey/bitmap v1.2.0/go.mod h1:qOc9q5FQxdvMyjVDnmvfJxUtz8JIryqOGxpg4Vtg4nY= github.com/f1monkey/bitmap v1.4.0 h1:Is1PqZWrTawUowD/qE7Vnlh9fzXrEs/qxJHDQ47jZ3g= github.com/f1monkey/bitmap v1.4.0/go.mod h1:qOc9q5FQxdvMyjVDnmvfJxUtz8JIryqOGxpg4Vtg4nY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/options.go b/options.go new file mode 100644 index 0000000..1aef963 --- /dev/null +++ b/options.go @@ -0,0 +1,119 @@ +package spellchecker + +import ( + "bufio" + "math" + + "github.com/agext/levenshtein" +) + +// WithOpt set spellchecker options +func (s *Spellchecker) WithOpts(opts ...OptionFunc) error { + s.mtx.Lock() + defer s.mtx.Unlock() + + for _, o := range opts { + if err := o(s); err != nil { + return err + } + } + + if s.scoreFunc != nil { + s.dict.filterFunc = wrapScoreFunc(s.scoreFunc, s.maxErrors) + } else { + s.dict.filterFunc = s.filterFunc + } + + return nil +} + +// WithSplitter set splitter func for AddFrom() reader +func WithSplitter(f bufio.SplitFunc) OptionFunc { + return func(s *Spellchecker) error { + s.splitter = f + return nil + } +} + +// WithMaxErrors sets maxErrors — the maximum allowed difference in bits +// between the "search word" and a "dictionary word". +// - deletion is a 1-bit change (proble → problem) +// - insertion is a 1-bit change (problemm → problem) +// - substitution is a 2-bit change (problam → problem) +// - transposition is a 0-bit change (problme → problem) +// +// It is not recommended to set this value greater than 2, +// as it can significantly affect performance. +func WithMaxErrors(maxErrors int) OptionFunc { + return func(s *Spellchecker) error { + s.maxErrors = maxErrors + + return nil + } +} + +// FilterFunc compares the source word with a candidate word. +// It returns the candidate's score and a boolean flag. +// If the flag is false, the candidate will be completely filtered out. +type FilterFunc func(src, candidate []rune, count uint) (float64, bool) + +// WithFilterFunc set custom scoring function +func WithFilterFunc(f FilterFunc) OptionFunc { + return func(s *Spellchecker) error { + s.filterFunc = f + return nil + } +} + +// ScoreFunc custom scoring function type +// +// Deprecated: use FilterFunc instead +type ScoreFunc func(src []rune, candidate []rune, distance int, cnt uint) float64 + +// WithScoreFunc specify a function that will be used for scoring +// +// Deprecated: use WithFilterFunc instead +func WithScoreFunc(f ScoreFunc) OptionFunc { + return func(s *Spellchecker) error { + s.scoreFunc = f + return nil + } +} + +func defaultFilterFunc(maxErrors int) FilterFunc { + return func(src, candidate []rune, count uint) (float64, bool) { + distance, prefixLen, suffixLen := levenshtein.Calculate(src, candidate, 0, 1, 1, 1) + if distance > maxErrors { + return 0, false + } + + mult := math.Log1p(float64(count)) * math.Pow(1.5, float64(prefixLen+suffixLen)) + + return 1 / (1 + float64(distance*distance)) * mult, true + } +} + +func wrapScoreFunc(f ScoreFunc, maxErrors int) FilterFunc { + return func(src, candidate []rune, count uint) (float64, bool) { + distance, _, _ := levenshtein.Calculate(src, candidate, 0, 1, 1, 1) + if distance > maxErrors { + return 0, false + } + + return f(src, candidate, distance, count), true + } +} + +var defaultScoreFunc ScoreFunc = func(src, candidate []rune, distance int, cnt uint) float64 { + mult := math.Log1p(float64(cnt)) + // if first letters are the same, increase score + if src[0] == candidate[0] { + mult *= 1.5 + // if second letters are the same too, increase score even more + if len(src) > 1 && len(candidate) > 1 && src[1] == candidate[1] { + mult *= 1.5 + } + } + + return 1 / (1 + float64(distance*distance)) * mult +} diff --git a/priority_queue.go b/priority_queue.go index 2a93a1d..0c4b011 100644 --- a/priority_queue.go +++ b/priority_queue.go @@ -2,13 +2,11 @@ package spellchecker import "container/heap" -// priorityQueue implements heap.Interface and holds matches. type priorityQueue struct { items []Match capacity int } -// newPriorityQueue initializes a new priorityQueue with a given capacity func newPriorityQueue(capacity int) *priorityQueue { return &priorityQueue{ items: make([]Match, 0, capacity), @@ -19,7 +17,7 @@ func newPriorityQueue(capacity int) *priorityQueue { func (pq priorityQueue) Len() int { return len(pq.items) } func (pq priorityQueue) Less(i, j int) bool { - return pq.items[j].Score > pq.items[i].Score + return pq.items[i].Score < pq.items[j].Score } func (pq priorityQueue) Swap(i, j int) { @@ -28,19 +26,37 @@ func (pq priorityQueue) Swap(i, j int) { func (pq *priorityQueue) Push(x interface{}) { item := x.(Match) + if len(pq.items) < pq.capacity { pq.items = append(pq.items, item) heap.Fix(pq, len(pq.items)-1) - } else if len(pq.items) > 0 && item.Score >= pq.items[0].Score { - pq.items[0] = item - heap.Fix(pq, 0) + return + } + + if item.Score < pq.items[0].Score { + return + } + + pq.items[0] = item + heap.Fix(pq, 0) } func (pq *priorityQueue) Pop() interface{} { old := pq.items n := len(old) item := old[n-1] - pq.items = old[0 : n-1] + pq.items = old[:n-1] + return item } + +func (pq *priorityQueue) DrainSorted() []Match { + out := make([]Match, pq.Len()) + + for i := len(out) - 1; i >= 0; i-- { + out[i] = heap.Pop(pq).(Match) + } + + return out +} diff --git a/spellchecker.go b/spellchecker.go index e23219f..9b6feba 100644 --- a/spellchecker.go +++ b/spellchecker.go @@ -4,7 +4,6 @@ import ( "bufio" "fmt" "io" - "math" "sync" ) @@ -16,22 +15,18 @@ type OptionFunc func(s *Spellchecker) error type Spellchecker struct { mtx sync.RWMutex - dict *dictionary - splitter bufio.SplitFunc - scoreFunc scoreFunc - maxErrors int + dict *dictionary + splitter bufio.SplitFunc + filterFunc FilterFunc + scoreFunc ScoreFunc + maxErrors int } func New(alphabet string, opts ...OptionFunc) (*Spellchecker, error) { result := &Spellchecker{ - maxErrors: DefaultMaxErrors, - scoreFunc: defaultScorefunc, + maxErrors: DefaultMaxErrors, + filterFunc: defaultFilterFunc(DefaultMaxErrors), } - dict, err := newDictionary(alphabet, result.scoreFunc, result.maxErrors) - if err != nil { - return nil, err - } - result.dict = dict for _, o := range opts { if err := o(result); err != nil { @@ -39,10 +34,21 @@ func New(alphabet string, opts ...OptionFunc) (*Spellchecker, error) { } } + if result.scoreFunc != nil { + result.filterFunc = wrapScoreFunc(result.scoreFunc, result.maxErrors) + } + + dict, err := newDictionary(alphabet, result.filterFunc, result.maxErrors) + if err != nil { + return nil, err + } + + result.dict = dict + return result, nil } -// AddFrom reads input, splits it with spellchecker splitter func and adds words to dictionary +// AddFrom reads input, splits it with spellchecker splitter func and adds words to the dictionary func (m *Spellchecker) AddFrom(input io.Reader) error { words := make([]string, 1000) i := 0 @@ -66,7 +72,7 @@ func (m *Spellchecker) AddFrom(input io.Reader) error { return nil } -// Add adds provided words to dictionary +// Add adds provided words to the dictionary func (m *Spellchecker) Add(words ...string) { m.mtx.Lock() defer m.mtx.Unlock() @@ -81,7 +87,7 @@ func (m *Spellchecker) Add(words ...string) { } } -// AddWeight adds provided words to dictionary with a custom weight +// AddWeight adds provided words to the dictionary with a custom weight func (m *Spellchecker) AddWeight(weight uint, words ...string) { m.mtx.Lock() defer m.mtx.Unlock() @@ -163,58 +169,3 @@ func (s *Spellchecker) SuggestScore(word string, n int) SuggestionResult { Suggestions: s.dict.find(word, n), } } - -// WithOpt set spellchecker options -func (s *Spellchecker) WithOpts(opts ...OptionFunc) error { - s.mtx.Lock() - defer s.mtx.Unlock() - - for _, o := range opts { - if err := o(s); err != nil { - return err - } - } - - return nil -} - -// WithSplitter set splitter func for AddFrom() reader -func WithSplitter(f bufio.SplitFunc) OptionFunc { - return func(s *Spellchecker) error { - s.splitter = f - return nil - } -} - -// WithMaxErrors set maxErrors, which is a max diff in bits betweeen the "search word" and a "dictionary word". -// i.e. one simple symbol replacement (problam => problem ) is a two-bit difference -func WithMaxErrors(maxErrors int) OptionFunc { - return func(s *Spellchecker) error { - s.maxErrors = maxErrors - return nil - } -} - -type ScoreFunc = scoreFunc - -// WithScoreFunc specify a function that will be used for scoring -func WithScoreFunc(f ScoreFunc) OptionFunc { - return func(s *Spellchecker) error { - s.dict.scoreFunc = f - return nil - } -} - -var defaultScorefunc scoreFunc = func(src, candidate []rune, distance int, cnt uint) float64 { - mult := math.Log1p(float64(cnt)) - // if first letters are the same, increase score - if src[0] == candidate[0] { - mult *= 1.5 - // if second letters are the same too, increase score even more - if len(src) > 1 && len(candidate) > 1 && src[1] == candidate[1] { - mult *= 1.5 - } - } - - return 1 / (1 + float64(distance*distance)) * mult -} diff --git a/spellchecker_test.go b/spellchecker_test.go index 93b149f..9361058 100644 --- a/spellchecker_test.go +++ b/spellchecker_test.go @@ -15,24 +15,24 @@ import ( func loadFullSpellchecker() *Spellchecker { var s *Spellchecker ff, err := os.Open("data/spellchecker.bin") - if errors.Is(err, os.ErrNotExist) { - s = newFullSpellchecker() - dst, err := os.Create("data/spellchecker.bin") - if err != nil { - panic(err) - } - - err = s.Save(dst) - if err != nil { - panic(err) - } - } else { + if !errors.Is(err, os.ErrNotExist) { s, err = Load(ff) - if err != nil { - panic(err) + if err == nil { + return s } } + s = newFullSpellchecker() + dst, err := os.Create("data/spellchecker.bin") + if err != nil { + panic(err) + } + + err = s.Save(dst) + if err != nil { + panic(err) + } + return s } @@ -164,7 +164,6 @@ func benchmarkNorvig(b *testing.B, dataPath string) { total := 0 ok := 0 for i := 0; i < b.N; i++ { - for _, item := range data { for _, word := range item.words { if word == "" { @@ -236,6 +235,15 @@ func Test_Spellchecker_Fix(t *testing.T) { require.Equal(t, "problem", result) } +func Test_Spellchecker_Fix_ScoreFunc(t *testing.T) { + s := newSampleSpellchecker() + s.WithOpts(WithScoreFunc(defaultScoreFunc)) + + result, err := s.Fix("problam") + require.NoError(t, err) + require.Equal(t, "problem", result) +} + func Test_Spellchecker_Suggest(t *testing.T) { s := newSampleSpellchecker() result, err := s.Suggest("arang", 5)