-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_test.go
More file actions
100 lines (93 loc) · 2.31 KB
/
main_test.go
File metadata and controls
100 lines (93 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package main
import (
"encoding/base64"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
)
// withEnv runs fn with the environment variable key set to val, then restores
// the variable's previous state (its old value, or its absence) after fn
// returns. Restoration is deferred, so it also happens when fn exits early via
// t.Fatal (runtime.Goexit) or a panic.
//
// An empty val means "unset key", not "set key to the empty string", so tests
// can model a variable that is not configured at all.
//
// The process environment is global, so this helper is not safe to use from
// tests that call t.Parallel.
func withEnv(t *testing.T, key, val string, fn func()) {
	t.Helper()
	old, had := os.LookupEnv(key)

	var err error
	if val == "" {
		err = os.Unsetenv(key)
	} else {
		err = os.Setenv(key, val)
	}
	if err != nil {
		// Running fn against an environment we failed to prepare would
		// produce misleading results, so stop here.
		t.Fatalf("withEnv: setting %s: %v", key, err)
	}

	defer func() {
		var rerr error
		if had {
			rerr = os.Setenv(key, old)
		} else {
			rerr = os.Unsetenv(key)
		}
		if rerr != nil {
			// Report the failure but keep going: the test body has already run.
			t.Errorf("withEnv: restoring %s: %v", key, rerr)
		}
	}()
	fn()
}
// TestReadBasicAuth_Empty checks that when none of the credential variables
// are set, readBasicAuth returns an empty user, an empty password and no
// error.
//
// The *_FILE variants are unset too. readBasicAuth also consults them, and a
// value inherited from the developer's or CI's environment would otherwise
// leak into this test.
func TestReadBasicAuth_Empty(t *testing.T) {
	withEnv(t, "BASIC_AUTH_USER", "", func() {
		withEnv(t, "BASIC_AUTH_PASSWORD", "", func() {
			withEnv(t, "BASIC_AUTH_USER_FILE", "", func() {
				withEnv(t, "BASIC_AUTH_PASSWORD_FILE", "", func() {
					u, p, err := readBasicAuth()
					if err != nil || u != "" || p != "" {
						t.Fatalf("got %q %q %v", u, p, err)
					}
				})
			})
		})
	})
}
// TestReadBasicAuth_PairEnv checks that when BASIC_AUTH_USER and
// BASIC_AUTH_PASSWORD are both set inline, readBasicAuth returns exactly
// those values and no error.
//
// The *_FILE variants are explicitly unset so that an inherited value cannot
// change which credential source readBasicAuth ends up using.
func TestReadBasicAuth_PairEnv(t *testing.T) {
	withEnv(t, "BASIC_AUTH_USER", "u", func() {
		withEnv(t, "BASIC_AUTH_PASSWORD", "p", func() {
			withEnv(t, "BASIC_AUTH_USER_FILE", "", func() {
				withEnv(t, "BASIC_AUTH_PASSWORD_FILE", "", func() {
					u, p, err := readBasicAuth()
					if err != nil || u != "u" || p != "p" {
						t.Fatalf("got %q %q %v", u, p, err)
					}
				})
			})
		})
	})
}
// TestReadBasicAuth_File checks that, with the inline BASIC_AUTH_USER and
// BASIC_AUTH_PASSWORD variables unset, readBasicAuth reads the credentials
// from the files named by BASIC_AUTH_USER_FILE and BASIC_AUTH_PASSWORD_FILE.
// It also checks that a trailing newline in those files is not part of the
// returned value.
func TestReadBasicAuth_File(t *testing.T) {
	dir := t.TempDir()
	uf := filepath.Join(dir, "u.txt")
	pf := filepath.Join(dir, "p.txt")

	// Fail fast on setup errors. Otherwise a write failure would show up as a
	// confusing readBasicAuth failure instead of its real cause.
	if err := os.WriteFile(uf, []byte("u\n"), 0o600); err != nil {
		t.Fatalf("writing %s: %v", uf, err)
	}
	if err := os.WriteFile(pf, []byte("p\n"), 0o600); err != nil {
		t.Fatalf("writing %s: %v", pf, err)
	}

	withEnv(t, "BASIC_AUTH_USER", "", func() {
		withEnv(t, "BASIC_AUTH_PASSWORD", "", func() {
			withEnv(t, "BASIC_AUTH_USER_FILE", uf, func() {
				withEnv(t, "BASIC_AUTH_PASSWORD_FILE", pf, func() {
					u, p, err := readBasicAuth()
					if err != nil || u != "u" || p != "p" {
						t.Fatalf("got %q %q %v", u, p, err)
					}
				})
			})
		})
	})
}
// TestReadBasicAuth_Mismatch checks that readBasicAuth returns an error when
// a user is configured but no password is.
//
// The *_FILE variants are explicitly unset. An inherited
// BASIC_AUTH_PASSWORD_FILE could otherwise supply the missing password and
// mask the error this test expects.
func TestReadBasicAuth_Mismatch(t *testing.T) {
	withEnv(t, "BASIC_AUTH_USER", "u", func() {
		withEnv(t, "BASIC_AUTH_PASSWORD", "", func() {
			withEnv(t, "BASIC_AUTH_USER_FILE", "", func() {
				withEnv(t, "BASIC_AUTH_PASSWORD_FILE", "", func() {
					_, _, err := readBasicAuth()
					if err == nil {
						t.Fatalf("expected error")
					}
				})
			})
		})
	})
}
// TestBasicAuthMiddleware checks a BasicAuthMiddleware configured with user
// "u" and password "p" wrapping a handler that replies 204 No Content. It
// expects the handler to be reached only when the request carries exactly
// those credentials. Missing or wrong credentials must get 401 Unauthorized.
func TestBasicAuthMiddleware(t *testing.T) {
	h := BasicAuthMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNoContent)
	}), "u", "p")

	// basic builds a Basic-scheme Authorization header value for user:pass.
	basic := func(user, pass string) string {
		return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
	}

	tests := []struct {
		name string
		auth string // Authorization header value; "" means the header is omitted
		want int
	}{
		{"no credentials", "", http.StatusUnauthorized},
		{"valid credentials", basic("u", "p"), http.StatusNoContent},
		// The rejection cases are the security-relevant ones: the middleware
		// must not let through a request that merely has some credentials.
		{"wrong password", basic("u", "x"), http.StatusUnauthorized},
		{"wrong user", basic("x", "p"), http.StatusUnauthorized},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "http://x/", nil)
			if tc.auth != "" {
				req.Header.Set("Authorization", tc.auth)
			}
			rec := httptest.NewRecorder()
			h.ServeHTTP(rec, req)
			if rec.Code != tc.want {
				t.Fatalf("want %d got %d", tc.want, rec.Code)
			}
		})
	}
}