Skip to content

Commit a707cf9

Browse files
committed
Deprecate ScoreFunc, replace it with FilterFunc to provide more control over the result composition
1 parent 2ee49c5 commit a707cf9

6 files changed

Lines changed: 134 additions & 59 deletions

File tree

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,13 @@ See [options.go](./options.go) for the list of available options.
102102
You can provide a custom filter function (it scores each candidate and can reject it entirely) if needed:
103103
104104
```go
105-
var scoreFunc spellchecker.ScoreFunc = func(src, candidate []rune, distance, cnt int) float64 {
106-
return 1.0 // constant score
105+
var fn spellchecker.FilterFunc = func(src, candidate []rune, cnt uint) (float64, bool) {
106+
// you can calculate Levenshtein distance here (see defaultFilterFunc in options.go for example)
107+
108+
return 1.0, true // constant score
107109
}
108110

109-
sc, err := spellchecker.New("abc", spellchecker.WithScoreFunc(scoreFunc))
111+
sc, err := spellchecker.New("abc", spellchecker.WithFilterFunc(fn))
110112
if err != nil {
111113
// handle err
112114
}
@@ -118,7 +120,7 @@ You can provide a custom scoring function if needed:
118120
// handle err
119121
}
120122

121-
err = sc.WithOpts(spellchecker.WithScoreFunc(scoreFunc))
123+
err = sc.WithOpts(spellchecker.WithFilterFunc(fn))
122124
if err != nil {
123125
// handle err
124126
}

dictionary.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,9 @@ import (
77
"sort"
88
"sync/atomic"
99

10-
"github.com/agext/levenshtein"
1110
"github.com/f1monkey/bitmap"
1211
)
1312

14-
type scoreFunc func(src []rune, candidate []rune, distance int, cnt uint) float64
15-
1613
type dictionary struct {
1714
maxErrors int
1815
alphabet alphabet
@@ -24,24 +21,24 @@ type dictionary struct {
2421

2522
index map[uint64][]uint32
2623

27-
scoreFunc scoreFunc
24+
filterFunc FilterFunc
2825
}
2926

30-
func newDictionary(ab string, scoreFunc scoreFunc, maxErrors int) (*dictionary, error) {
27+
func newDictionary(ab string, filterFunc FilterFunc, maxErrors int) (*dictionary, error) {
3128
alphabet, err := newAlphabet(ab)
3229
if err != nil {
3330
return nil, err
3431
}
3532

3633
return &dictionary{
37-
maxErrors: maxErrors,
38-
alphabet: alphabet,
39-
nextID: idSeq(0),
40-
ids: make(map[string]uint32),
41-
words: make(map[uint32]string),
42-
counts: make(map[uint32]uint),
43-
index: make(map[uint64][]uint32),
44-
scoreFunc: scoreFunc,
34+
maxErrors: maxErrors,
35+
alphabet: alphabet,
36+
nextID: idSeq(0),
37+
ids: make(map[string]uint32),
38+
words: make(map[uint32]string),
39+
counts: make(map[uint32]uint),
40+
index: make(map[uint64][]uint32),
41+
filterFunc: filterFunc,
4542
}, nil
4643
}
4744

@@ -60,10 +57,9 @@ func (d *dictionary) add(word string, n uint) (uint32, error) {
6057
id := d.nextID()
6158
d.ids[word] = id
6259

63-
runes := []rune(word)
6460
d.counts[id] = n
6561
d.words[id] = word
66-
key := sum(d.alphabet.encode(runes))
62+
key := sum(d.alphabet.encode([]rune(word)))
6763
d.index[key] = append(d.index[key], id)
6864

6965
return id, nil
@@ -148,14 +144,14 @@ func (d *dictionary) fillWithCandidates(result *priorityQueue, wordRunes []rune,
148144
continue
149145
}
150146

151-
distance, _, _ := levenshtein.Calculate(wordRunes, []rune(docWord), 0, 1, 1, 1)
152-
if distance > d.maxErrors {
147+
score, ok := d.filterFunc(wordRunes, []rune(docWord), d.counts[id])
148+
if !ok {
153149
continue
154150
}
155151

156152
result.Push(Match{
157153
Value: docWord,
158-
Score: d.scoreFunc(wordRunes, []rune(docWord), distance, d.counts[id]),
154+
Score: score,
159155
})
160156
}
161157
}
@@ -206,7 +202,7 @@ func (d *dictionary) UnmarshalBinary(data []byte) error {
206202
d.words = dictData.Words
207203
d.index = dictData.Index
208204
d.maxErrors = dictData.MaxErrors
209-
d.scoreFunc = defaultScorefunc
205+
d.filterFunc = defaultFilterFunc(dictData.MaxErrors)
210206

211207
var max uint32
212208
for _, id := range d.ids {

dictionary_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
)
88

99
func Test_dictionary_id(t *testing.T) {
10-
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
10+
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
1111
require.NoError(t, err)
1212

1313
t.Run("must return 0 for unexisting word", func(t *testing.T) {
@@ -24,7 +24,7 @@ func Test_dictionary_id(t *testing.T) {
2424

2525
func Test_dictionary_add(t *testing.T) {
2626
t.Run("must add word to dictionary index", func(t *testing.T) {
27-
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
27+
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
2828
require.NoError(t, err)
2929

3030
id, err := dict.add("qwe", 1)
@@ -49,7 +49,7 @@ func Test_dictionary_add(t *testing.T) {
4949

5050
func Test_Dictionary_Inc(t *testing.T) {
5151
t.Run("must increase counter value", func(t *testing.T) {
52-
dict, err := newDictionary(DefaultAlphabet, defaultScorefunc, DefaultMaxErrors)
52+
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
5353
dict.counts[1] = 0
5454
require.NoError(t, err)
5555

options.go

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,30 @@ package spellchecker
33
import (
44
"bufio"
55
"math"
6+
7+
"github.com/agext/levenshtein"
68
)
79

10+
// WithOpts applies the given options to the spellchecker and refreshes the dictionary's filter function.
11+
func (s *Spellchecker) WithOpts(opts ...OptionFunc) error {
12+
s.mtx.Lock()
13+
defer s.mtx.Unlock()
14+
15+
for _, o := range opts {
16+
if err := o(s); err != nil {
17+
return err
18+
}
19+
}
20+
21+
if s.scoreFunc != nil {
22+
s.dict.filterFunc = wrapScoreFunc(s.scoreFunc, s.maxErrors)
23+
} else {
24+
s.dict.filterFunc = s.filterFunc
25+
}
26+
27+
return nil
28+
}
29+
830
// WithSplitter sets the splitter func for the AddFrom() reader
931
func WithSplitter(f bufio.SplitFunc) OptionFunc {
1032
return func(s *Spellchecker) error {
@@ -15,28 +37,82 @@ func WithSplitter(f bufio.SplitFunc) OptionFunc {
1537

1638
// WithMaxErrors sets maxErrors — the maximum allowed difference in bits
1739
// between the "search word" and a "dictionary word".
18-
// For example, replacing a single character (problam => problem)
19-
// is treated as a two-bit difference.
20-
// It is not recommended to set a value greater than 2,
21-
// as it can significantly impact performance.
40+
// - deletion is a 1-bit change (proble → problem)
41+
// - insertion is a 1-bit change (problemm → problem)
42+
// - substitution is a 2-bit change (problam → problem)
43+
// - transposition is a 0-bit change (problme → problem)
44+
//
45+
// It is not recommended to set this value greater than 2,
46+
// as it can significantly affect performance.
2247
func WithMaxErrors(maxErrors int) OptionFunc {
2348
return func(s *Spellchecker) error {
2449
s.maxErrors = maxErrors
50+
2551
return nil
2652
}
2753
}
2854

29-
type ScoreFunc = scoreFunc
55+
// FilterFunc compares the source word with a candidate word.
56+
// It returns the candidate's score and a boolean flag.
57+
// If the flag is false, the candidate will be completely filtered out.
58+
type FilterFunc func(src, candidate []rune, count uint) (float64, bool)
59+
60+
// WithFilterFunc set custom scoring function
61+
func WithFilterFunc(f FilterFunc) OptionFunc {
62+
return func(s *Spellchecker) error {
63+
s.filterFunc = f
64+
return nil
65+
}
66+
}
67+
68+
// ScoreFunc is a custom scoring function type
69+
//
70+
// Deprecated: use FilterFunc instead
71+
type ScoreFunc func(src []rune, candidate []rune, distance int, cnt uint) float64
3072

3173
// WithScoreFunc specifies a function that will be used for scoring
74+
//
75+
// Deprecated: use WithFilterFunc instead
3276
func WithScoreFunc(f ScoreFunc) OptionFunc {
3377
return func(s *Spellchecker) error {
34-
s.dict.scoreFunc = f
78+
s.scoreFunc = f
3579
return nil
3680
}
3781
}
3882

39-
var defaultScorefunc scoreFunc = func(src, candidate []rune, distance int, cnt uint) float64 {
83+
func defaultFilterFunc(maxErrors int) FilterFunc {
84+
return func(src, candidate []rune, count uint) (float64, bool) {
85+
distance, _, _ := levenshtein.Calculate(src, candidate, 0, 1, 1, 1)
86+
if distance > maxErrors {
87+
return 0, false
88+
}
89+
90+
mult := math.Log1p(float64(count))
91+
// if first letters are the same, increase score
92+
if src[0] == candidate[0] {
93+
mult *= 1.5
94+
// if second letters are the same too, increase score even more
95+
if len(src) > 1 && len(candidate) > 1 && src[1] == candidate[1] {
96+
mult *= 1.5
97+
}
98+
}
99+
100+
return 1 / (1 + float64(distance*distance)) * mult, true
101+
}
102+
}
103+
104+
func wrapScoreFunc(f ScoreFunc, maxErrors int) FilterFunc {
105+
return func(src, candidate []rune, count uint) (float64, bool) {
106+
distance, _, _ := levenshtein.Calculate(src, candidate, 0, 1, 1, 1)
107+
if distance > maxErrors {
108+
return 0, false
109+
}
110+
111+
return f(src, candidate, distance, count), true
112+
}
113+
}
114+
115+
var defaultScoreFunc ScoreFunc = func(src, candidate []rune, distance int, cnt uint) float64 {
40116
mult := math.Log1p(float64(cnt))
41117
// if first letters are the same, increase score
42118
if src[0] == candidate[0] {

spellchecker.go

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,29 +15,36 @@ type OptionFunc func(s *Spellchecker) error
1515
type Spellchecker struct {
1616
mtx sync.RWMutex
1717

18-
dict *dictionary
19-
splitter bufio.SplitFunc
20-
scoreFunc scoreFunc
21-
maxErrors int
18+
dict *dictionary
19+
splitter bufio.SplitFunc
20+
filterFunc FilterFunc
21+
scoreFunc ScoreFunc
22+
maxErrors int
2223
}
2324

2425
func New(alphabet string, opts ...OptionFunc) (*Spellchecker, error) {
2526
result := &Spellchecker{
26-
maxErrors: DefaultMaxErrors,
27-
scoreFunc: defaultScorefunc,
27+
maxErrors: DefaultMaxErrors,
28+
filterFunc: defaultFilterFunc(DefaultMaxErrors),
2829
}
29-
dict, err := newDictionary(alphabet, result.scoreFunc, result.maxErrors)
30-
if err != nil {
31-
return nil, err
32-
}
33-
result.dict = dict
3430

3531
for _, o := range opts {
3632
if err := o(result); err != nil {
3733
return nil, err
3834
}
3935
}
4036

37+
if result.scoreFunc != nil {
38+
result.filterFunc = wrapScoreFunc(result.scoreFunc, result.maxErrors)
39+
}
40+
41+
dict, err := newDictionary(alphabet, result.filterFunc, result.maxErrors)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
result.dict = dict
47+
4148
return result, nil
4249
}
4350

@@ -162,17 +169,3 @@ func (s *Spellchecker) SuggestScore(word string, n int) SuggestionResult {
162169
Suggestions: s.dict.find(word, n),
163170
}
164171
}
165-
166-
// WithOpt set spellchecker options
167-
func (s *Spellchecker) WithOpts(opts ...OptionFunc) error {
168-
s.mtx.Lock()
169-
defer s.mtx.Unlock()
170-
171-
for _, o := range opts {
172-
if err := o(s); err != nil {
173-
return err
174-
}
175-
}
176-
177-
return nil
178-
}

spellchecker_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ func benchmarkNorvig(b *testing.B, dataPath string) {
164164
total := 0
165165
ok := 0
166166
for i := 0; i < b.N; i++ {
167-
168167
for _, item := range data {
169168
for _, word := range item.words {
170169
if word == "" {
@@ -236,6 +235,15 @@ func Test_Spellchecker_Fix(t *testing.T) {
236235
require.Equal(t, "problem", result)
237236
}
238237

238+
func Test_Spellchecker_Fix_ScoreFunc(t *testing.T) {
239+
s := newSampleSpellchecker()
240+
s.WithOpts(WithScoreFunc(defaultScoreFunc))
241+
242+
result, err := s.Fix("problam")
243+
require.NoError(t, err)
244+
require.Equal(t, "problem", result)
245+
}
246+
239247
func Test_Spellchecker_Suggest(t *testing.T) {
240248
s := newSampleSpellchecker()
241249
result, err := s.Suggest("arang", 5)

0 commit comments

Comments
 (0)