diff --git a/vapid.go b/vapid.go index ae26cfd..e9c010f 100644 --- a/vapid.go +++ b/vapid.go @@ -85,8 +85,8 @@ func generateVAPIDHeaders( return nil, err } - // Unless subscriber is an HTTPS URL, assume an e-mail address - if !strings.HasPrefix(subscriber, "https:") { + // Unless subscriber is already a URI (https: or mailto:), assume an e-mail address. + if !strings.HasPrefix(subscriber, "https:") && !strings.HasPrefix(subscriber, "mailto:") { subscriber = "mailto:" + subscriber } diff --git a/vapid_test.go b/vapid_test.go index c665e39..4190b64 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -108,6 +108,58 @@ func TestVAPID(t *testing.T) { } } +func TestVAPIDSubscriberFormats(t *testing.T) { + s := getStandardEncodedTestSubscription() + vapidPrivateKey, vapidPublicKey, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + input string + expected string + }{ + {"bare email", "test@test.com", "mailto:test@test.com"}, + {"mailto URI", "mailto:test@test.com", "mailto:test@test.com"}, + {"https URI", "https://example.com/contact", "https://example.com/contact"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vapidHeaders, err := generateVAPIDHeaders( + s.Endpoint, + tt.input, + vapidPublicKey, + vapidPrivateKey, + time.Now().Add(time.Hour), + Vapid, + ) + if err != nil { + t.Fatal(err) + } + + tokenString := getTokenFromAuthorizationHeader(vapidHeaders["Authorization"], Vapid, t) + token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + b64 := base64.RawURLEncoding + decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) + if err != nil { + return nil, err + } + return generateVAPIDHeaderKeys(decodedVapidPrivateKey).Public(), nil + }) + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + t.Fatal("invalid token") + } + if claims["sub"] != tt.expected { + t.Fatalf("sub: expected=%q got=%q", tt.expected, claims["sub"]) + } + }) + } +} + func TestVAPIDKeys(t *testing.T) { privateKey, publicKey, err := GenerateVAPIDKeys() if err != nil { diff --git a/webpush.go b/webpush.go index e092270..2b97f97 100644 --- a/webpush.go +++ b/webpush.go @@ -45,7 +45,7 @@ type Options struct { AuthScheme AuthScheme // VAPID authentication scheme, defaults to "vapid" HTTPClient HTTPClient // Will replace with *http.Client by default if not included RecordSize uint32 // Limit the record size - Subscriber string // Sub in VAPID JWT token + Subscriber string // Sub in VAPID JWT token. Accepts a "mailto:" or "https:" URI; a bare value is treated as an e-mail and prefixed with "mailto:". Topic string // Set the Topic header to collapse a pending messages (Optional) TTL int // Set the TTL on the endpoint POST request Urgency Urgency // Set the Urgency header to change a message priority (Optional)