Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.25

- name: Test
run: go test ./...
49 changes: 21 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Yet another spellchecker written in go.
## Installation

```
go get -v github.com/f1monkey/spellchecker
go get -v github.com/f1monkey/spellchecker/v2
```

## Usage
Expand All @@ -35,39 +35,43 @@ func main() {
// Create a new instance
sc, err := spellchecker.New(
"abcdefghijklmnopqrstuvwxyz1234567890", // allowed symbols, other symbols will be ignored
spellchecker.WithMaxErrors(2) // see options.go
)
if err != nil {
panic(err)
}

// The weight increases the likelihood that the word will be chosen as a correction.
weight := uint(1)

// Load data from any io.Reader
in, err := os.Open("data/sample.txt")
if err != nil {
panic(err)
}
sc.AddFrom(in)

sc.AddFrom(&spellchecker.AddOptions{Weight: weight}, in)
// OR
sc.AddFrom(nil, in)

// Add words manually
sc.Add("lock", "stock", "and", "two", "smoking", "barrels")
sc.Add(nil, "lock", "stock", "and", "two", "smoking", "barrels")

// Check if a word is valid
result := sc.IsCorrect("coffee")
fmt.Println(result) // true

// Correct a single word
fixed, err := sc.Fix("awepon")
if err != nil && !errors.Is(err, spellchecker.ErrUnknownWord) {
panic(err)
}
fixed, isCorrect := sc.Fix(nil, "awepon")
fmt.Println(isCorrect) // false
fmt.Println(fixed) // weapon

// Find up to 10 suggestions for a word
matches, err := sc.Suggest("rang", 10)
if err != nil && !errors.Is(err, spellchecker.ErrUnknownWord) {
panic(err)
}
matches := sc.Suggest(nil, "rang", 10)
fmt.Println(matches) // [range, orange]

if len(os.Args) < 2 {
log.Fatal("dict path must be provided")
}
```

### Options
Expand Down Expand Up @@ -113,17 +117,7 @@ You can provide a custom scoring function if needed:
// handle err
}

// After loading a spellchecker from a file,
// you need to set the function again:
sc, err = spellchecker.Load(inFile)
if err != nil {
// handle err
}

err = sc.WithOpts(spellchecker.WithFilterFunc(fn))
if err != nil {
// handle err
}
sc.Fix(fn, "word")
```


Expand All @@ -140,9 +134,9 @@ goos: linux
goarch: amd64
pkg: github.com/f1monkey/spellchecker
cpu: 13th Gen Intel(R) Core(TM) i9-13980HX
Benchmark_Norvig1-32 348 3385868 ns/op 74.44 success_percent 201.0 success_words 270.0 total_words 830803 B/op 15504 allocs/op
Benchmark_Norvig1-32 357 3305052 ns/op 74.44 success_percent 201.0 success_words 270.0 total_words 768899 B/op 13302 allocs/op
PASS
ok github.com/f1monkey/spellchecker 3.723s
ok github.com/f1monkey/spellchecker 3.801s
```

#### [Test set 2](http://norvig.com/spell-testset2.txt):
Expand All @@ -154,8 +148,7 @@ goos: linux
goarch: amd64
pkg: github.com/f1monkey/spellchecker
cpu: 13th Gen Intel(R) Core(TM) i9-13980HX
Benchmark_Norvig2-32 231 4935406 ns/op 71.25 success_percent 285.0 success_words 400.0 total_words 1270755 B/op 21801 allocs/op
Benchmark_Norvig2-32 236 5257185 ns/op 71.25 success_percent 285.0 success_words 400.0 total_words 1201260 B/op 19346 allocs/op
PASS
ok github.com/f1monkey/spellchecker 4.057s

ok github.com/f1monkey/spellchecker 4.350s
```
74 changes: 26 additions & 48 deletions dictionary.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,29 @@ import (
)

type dictionary struct {
maxErrors int
alphabet alphabet
nextID func() uint32
alphabet alphabet
nextID func() uint32

words map[uint32][]rune
ids map[string]uint32
counts map[uint32]uint

index map[uint64][]uint32

filterFunc FilterFunc
}

func newDictionary(ab string, filterFunc FilterFunc, maxErrors int) (*dictionary, error) {
func newDictionary(ab string) (*dictionary, error) {
alphabet, err := newAlphabet(ab)
if err != nil {
return nil, err
}

return &dictionary{
maxErrors: maxErrors,
alphabet: alphabet,
nextID: idSeq(0),
ids: make(map[string]uint32),
words: make(map[uint32][]rune),
counts: make(map[uint32]uint),
index: make(map[uint64][]uint32),
filterFunc: filterFunc,
alphabet: alphabet,
nextID: idSeq(0),
ids: make(map[string]uint32),
words: make(map[uint32][]rune),
counts: make(map[uint32]uint),
index: make(map[uint64][]uint32),
}, nil
}

Expand Down Expand Up @@ -81,8 +76,8 @@ type Match struct {
Score float64
}

func (d *dictionary) find(word string, n int) []Match {
if d.maxErrors <= 0 {
func (d *dictionary) find(word string, n int, maxErrors int, fn FilterFunc) []Match {
if maxErrors <= 0 {
return nil
}

Expand All @@ -93,15 +88,15 @@ func (d *dictionary) find(word string, n int) []Match {

// check for transposition or exact match and do early termination if found
// (the most common mistake is a transposition of letters)
d.fillWithCandidates(result, wordRunes, sum(bmSrc))
d.fillWithCandidates(result, wordRunes, sum(bmSrc), fn)
if result.Len() != 0 {
return result.DrainSorted()
}

bitmaps := bitmapsPool.Get().(map[uint64]struct{})
d.computeCandidateBitmaps(bitmaps, bmSrc, d.maxErrors)
d.computeCandidateBitmaps(bitmaps, bmSrc, maxErrors)
for bm := range bitmaps {
d.fillWithCandidates(result, wordRunes, bm)
d.fillWithCandidates(result, wordRunes, bm, fn)
}

releaseBitmaps(bitmaps)
Expand Down Expand Up @@ -131,15 +126,15 @@ func (d *dictionary) computeCandidateBitmaps(bitmaps map[uint64]struct{}, src bi
dfs(src.Clone(), 0, 0)
}

func (d *dictionary) fillWithCandidates(result *priorityQueue, wordRunes []rune, bm uint64) {
func (d *dictionary) fillWithCandidates(result *priorityQueue, wordRunes []rune, bm uint64, filter FilterFunc) {
ids := d.index[bm]
for _, id := range ids {
docWord, ok := d.words[id]
if !ok {
continue
}

score, ok := d.filterFunc(wordRunes, docWord, d.counts[id])
score, ok := filter(wordRunes, docWord, d.counts[id])
if !ok {
continue
}
Expand All @@ -155,25 +150,21 @@ var _ encoding.BinaryMarshaler = (*dictionary)(nil)
var _ encoding.BinaryUnmarshaler = (*dictionary)(nil)

type dictData struct {
Alphabet alphabet
IDs map[string]uint32
Words map[uint32]string
WordRunes map[uint32][]rune
Counts map[uint32]uint
Alphabet alphabet
IDs map[string]uint32
Words map[uint32][]rune
Counts map[uint32]uint

Index map[uint64][]uint32

MaxErrors int
}

func (d *dictionary) MarshalBinary() ([]byte, error) {
data := &dictData{
Alphabet: d.alphabet,
IDs: d.ids,
WordRunes: d.words,
Counts: d.counts,
Index: d.index,
MaxErrors: d.maxErrors,
Alphabet: d.alphabet,
IDs: d.ids,
Words: d.words,
Counts: d.counts,
Index: d.index,
}

buf := &bytes.Buffer{}
Expand All @@ -195,21 +186,8 @@ func (d *dictionary) UnmarshalBinary(data []byte) error {
d.alphabet = dictData.Alphabet
d.ids = dictData.IDs
d.counts = dictData.Counts

// compatibility with previous versions
if len(dictData.Words) > 0 {
wordRunes := make(map[uint32][]rune, len(dictData.Words))
for k, v := range dictData.Words {
wordRunes[k] = []rune(v)
}
d.words = wordRunes
} else {
d.words = dictData.WordRunes
}

d.index = dictData.Index
d.maxErrors = dictData.MaxErrors
d.filterFunc = defaultFilterFunc(dictData.MaxErrors)
d.words = dictData.Words

var max uint32
for _, id := range d.ids {
Expand Down
6 changes: 3 additions & 3 deletions dictionary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func Test_dictionary_id(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet)
require.NoError(t, err)

t.Run("must return 0 for unexisting word", func(t *testing.T) {
Expand All @@ -24,7 +24,7 @@ func Test_dictionary_id(t *testing.T) {

func Test_dictionary_add(t *testing.T) {
t.Run("must add word to dictionary index", func(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet)
require.NoError(t, err)

id, err := dict.add("qwe", 1)
Expand All @@ -49,7 +49,7 @@ func Test_dictionary_add(t *testing.T) {

func Test_Dictionary_Inc(t *testing.T) {
t.Run("must increase counter value", func(t *testing.T) {
dict, err := newDictionary(DefaultAlphabet, nil, DefaultMaxErrors)
dict, err := newDictionary(DefaultAlphabet)
dict.counts[1] = 0
require.NoError(t, err)

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/f1monkey/spellchecker

go 1.19
go 1.24

require (
github.com/agext/levenshtein v1.2.3
Expand Down
Loading