-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_test.go
More file actions
124 lines (109 loc) · 3.03 KB
/
utils_test.go
File metadata and controls
124 lines (109 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package tokenizers
import (
"os"
"path/filepath"
"testing"
"unsafe"
"github.com/stretchr/testify/require"
)
func TestMasksFromBuf(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
buf := Buffer{}
s, a := MasksFromBuf(buf)
require.Empty(t, s)
require.Empty(t, a)
})
t.Run("SpecialTokenMask", func(t *testing.T) {
specialTokens := []uint32{10, 20, 30}
buf := Buffer{
Len: 3,
SpecialTokensMask: &specialTokens[0],
}
s, a := MasksFromBuf(buf)
require.Len(t, s, len(specialTokens))
require.Equal(t, specialTokens, s)
require.Empty(t, a)
})
t.Run("AttentionMask", func(t *testing.T) {
attentionMask := []uint32{10, 20, 30}
buf := Buffer{
Len: 3,
AttentionMask: &attentionMask[0],
}
s, a := MasksFromBuf(buf)
require.Len(t, a, len(attentionMask))
require.Equal(t, attentionMask, a)
require.Empty(t, s)
})
}
// CStringArray is a test helper that owns NUL-terminated copies of a list of
// strings together with a slice of pointers to the first byte of each copy.
type CStringArray struct {
	ptrs  []*byte  // ptrs[i] is the address of bufs[i][0]
	bufs  [][]byte // NUL-terminated copies that ptrs point into
	count int      // number of strings supplied at construction
}

// StringsToPtrArray copies every string in strs into its own buffer with a
// trailing NUL byte and returns a CStringArray referencing those buffers.
func StringsToPtrArray(strs []string) *CStringArray {
	n := len(strs)
	arr := &CStringArray{
		ptrs:  make([]*byte, 0, n),
		bufs:  make([][]byte, 0, n),
		count: n,
	}
	for _, str := range strs {
		// Appending a zero byte to the copy yields len(str)+1 bytes ending in NUL.
		cstr := append([]byte(str), 0)
		arr.bufs = append(arr.bufs, cstr)
		arr.ptrs = append(arr.ptrs, &cstr[0])
	}
	return arr
}

// Ptr returns the address of the first element of the pointer slice, or nil
// when the array holds no strings.
func (a *CStringArray) Ptr() **byte {
	if len(a.ptrs) > 0 {
		return (**byte)(unsafe.Pointer(&a.ptrs[0]))
	}
	return nil
}

// Len reports how many strings the array was built from.
func (a *CStringArray) Len() int {
	return a.count
}
func TestTokensFromBuf(t *testing.T) {
t.Run("Empty", func(t *testing.T) {
buf := Buffer{}
tokens := TokensFromBuf(buf)
require.Empty(t, tokens)
})
t.Run("with tokens", func(t *testing.T) {
buffTokens := []string{"hello", ",", "world"}
ptrs := StringsToPtrArray(buffTokens)
buf := Buffer{
Len: 3,
Tokens: ptrs.Ptr(),
}
tokens := TokensFromBuf(buf)
require.NotEmpty(t, tokens)
require.Len(t, tokens, len(buffTokens))
for i, token := range tokens {
require.Equal(t, buffTokens[i], token)
}
})
}
// TestLoadTokenizerLibrary checks the error messages LoadTokenizerLibrary
// produces for a missing explicit path, a file that is not a loadable
// library, and a TOKENIZERS_LIB_PATH value that does not name a file.
func TestLoadTokenizerLibrary(t *testing.T) {
	t.Run("Invalid path", func(t *testing.T) {
		const missing = "invalid/path/to/libtokenizers.so"
		_, loadErr := LoadTokenizerLibrary(missing)
		require.Error(t, loadErr)
		require.Contains(t, loadErr.Error(), "library file not found at user-provided path")
	})

	t.Run("Invalid library", func(t *testing.T) {
		// The file exists but its contents are plain text, so loading it is
		// expected to fail.
		libFile := filepath.Join(t.TempDir(), "fake_library.so")
		require.NoError(t, os.WriteFile(libFile, []byte("not a valid library"), 0o644))

		_, loadErr := LoadTokenizerLibrary(libFile)
		require.Error(t, loadErr)
		require.Contains(t, loadErr.Error(), "failed to load library from user-provided path")
	})

	t.Run("Invalid env var library path", func(t *testing.T) {
		// No file is created here: the variable is set to a name that is not
		// expected to exist, and an empty path argument is passed. The
		// expected message shows the environment variable is what gets reported.
		t.Setenv("TOKENIZERS_LIB_PATH", "invalid")

		_, loadErr := LoadTokenizerLibrary("")
		require.Error(t, loadErr)
		require.Contains(t, loadErr.Error(), "library file not found at TOKENIZERS_LIB_PATH")
	})
}