From dcf6893e6518e4ce7960d702862305745eff17e5 Mon Sep 17 00:00:00 2001 From: Willi Schinmeyer Date: Mon, 6 Nov 2023 16:01:53 +0100 Subject: [PATCH 1/2] Add end to end test This new test mocks both the user agent (e.g. browser) and the push service (e.g. Firestore) to verify that encryption and decryption works properly. I used the RFCs as reference (RFC8291, RFC8292 & RFC 8188), but didn't follow them to the letter. The result can successfully check all the signatures and decrypt the content, so it seems to be working. Instead of the deprecated crypto/elliptic functions, this makes heavy use of crypto/ecdh, which require Go 1.20. But this is only a test dependency, library users should not be impacted. --- end2end_test.go | 462 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 462 insertions(+) create mode 100644 end2end_test.go diff --git a/end2end_test.go b/end2end_test.go new file mode 100644 index 0000000..f7858f4 --- /dev/null +++ b/end2end_test.go @@ -0,0 +1,462 @@ +package webpush + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/golang-jwt/jwt" + "golang.org/x/crypto/hkdf" +) + +func TestEnd2End(t *testing.T) { + var ( + // the data known to the application server (backend, which uses webpush-go) + applicationServer struct { + publicVAPIDKey string + privateVAPIDKey string + subscription Subscription + } + // the data known to the user agent (browser) + userAgent struct { + publicVAPIDKey *ecdsa.PublicKey + subscriptionKey *ecdsa.PrivateKey + authSecret [16]byte + subscription Subscription + receivedNotifications [][]byte + } + // the data known to the push server (which receives push messages on behalf of the user agent, e.g. Firestore) + pushService struct { + applicationServerKey *ecdsa.PublicKey + receivedNotifications [][]byte + } + + err error + ) + + // a VAPID key pair for the application server, usually only generated once and reused + applicationServer.privateVAPIDKey, applicationServer.publicVAPIDKey, err = GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // The application server needs to inform the user agent of the public VAPID key. + // (We decode it first for ease of use.) + userAgent.publicVAPIDKey, err = decodeVAPIDPublicKey(applicationServer.publicVAPIDKey) + if err != nil { + t.Fatal(err) + } + + // We need a mock push service for webpush-go to send notifications to. + var mockPushService *httptest.Server + mockPushService = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // check that there's a valid vapid JWT + token, err := parseVapidAuthHeader( + r.Header.Get("Authorization"), + // by the time this function is called, this value will be set (see PushManager.subscribe() below) + pushService.applicationServerKey) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "invalid auth: %s", err) + return + } + // verify that the audience matches our URL + aud := token.Claims.(jwt.MapClaims)["aud"] + if aud != mockPushService.URL { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "JWT has bad audience, want %q, got %q", mockPushService.URL, aud) + return + } + // RFC8188 only allows for exactly one content encoding + if contentEncoding := r.Header.Get("Content-Encoding"); contentEncoding != "aes128gcm" { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unsupported Content-Encoding, want %q, got %q", "aes128gcm", contentEncoding) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + // this suggests a broken connection, so log the error instead of sending it back + t.Errorf("failed to read request body: %s", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + // store body for later decoding by user agent + // (the push service doesn't have the key required for decryption) + pushService.receivedNotifications = append(pushService.receivedNotifications, body) + + w.WriteHeader(http.StatusAccepted) + })) + defer mockPushService.Close() + + // what follows is the equivalent of PushManager.subscribe() in JS + { + // the user agent generates its own key pair so it can be sent encrypted messages + userAgent.subscriptionKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generating user agent keys: %s", err) + } + // we need the ECDH representation + ecdhPublicKey, err := userAgent.subscriptionKey.PublicKey.ECDH() + if err != nil { + t.Fatalf("converting user agent public key to ECDH: %s", err) + } + // generate the shared auth secret + _, err = rand.Read(userAgent.authSecret[:]) + if err != nil { + t.Fatalf("generating user agent auth secret: %s", err) + } + // the user agent then performs a registration with the push service using that key, + // while also letting the push service know the application server key to expect. + pushService.applicationServerKey = userAgent.publicVAPIDKey + userAgent.subscription = Subscription{ + Keys: Keys{ + Auth: base64.StdEncoding.EncodeToString(userAgent.authSecret[:]), + P256dh: base64.StdEncoding.EncodeToString(ecdhPublicKey.Bytes()), + }, + Endpoint: mockPushService.URL, + } + } + + // the user agent sends its subscription to the application server... + applicationServer.subscription = userAgent.subscription + + // ...and the application server uses the subscription to send a push notification + sentMessage := "this is our test push notification" + resp, err := SendNotification([]byte(sentMessage), &applicationServer.subscription, &Options{ + HTTPClient: mockPushService.Client(), + VAPIDPublicKey: applicationServer.publicVAPIDKey, + VAPIDPrivateKey: applicationServer.privateVAPIDKey, + Subscriber: "test@example.com", + }) + if err != nil { + t.Fatalf("failed to send notification: %s", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("error closing mock push service response body: %s", err) + } + }() + // check for success + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading mock push service response body: %s", err) + } + if resp.StatusCode/100 != 2 { + t.Errorf("unexpected push service status code %d, body: %s", resp.StatusCode, respBody) + } + + // the push server should now have received the notification + if l := len(pushService.receivedNotifications); l != 1 { + t.Fatalf("Want 1 notification received by push service, got %d", l) + } + // the push service then forwards the notification to the user agent + userAgent.receivedNotifications = pushService.receivedNotifications + // and the user agent can decrypt them + receivedMessage, err := decodeNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) + if err != nil { + t.Fatalf("error decrypting notification in user agent: %s", err) + } + if receivedMessage != sentMessage { + t.Errorf("Sent notification %q, but got %q", sentMessage, receivedMessage) + } +} + +func decodeVAPIDPublicKey(publicVAPIDKey string) (*ecdsa.PublicKey, error) { + publicVAPIDKeyBytes, err := base64.RawURLEncoding.DecodeString(publicVAPIDKey) + if err != nil { + return nil, fmt.Errorf("base64-decoding public VAPID key: %w", err) + } + return decodeECDSAPublicKey(publicVAPIDKeyBytes) +} + +func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { + ecdhKey, err := ecdh.P256().NewPublicKey(bytes) + if err != nil { + return nil, fmt.Errorf("parsing public VAPID key: %w", err) + } + res, err := ecdhPublicKeyToECDSA(ecdhKey) + if err != nil { + return nil, fmt.Errorf("converting public VAPID key from *ecdh.PublicKey to *ecdsa.PublicKey: %w", err) + } + return res, nil +} + +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + // see https://github.com/golang/go/issues/63963 + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + if authHeader == "" { + return nil, fmt.Errorf("missing auth header") + } + // the Authorization header should be of the form "vapid t=JWT, k=key" (RFC8292) + // we need to extract the JWT (JSON Web Token) from t to check the signature using k + authBody, found := strings.CutPrefix(authHeader, "vapid ") + if !found { + return nil, fmt.Errorf("Authorization header is not vapid: %s", authHeader) + } + authFields := strings.Split(authBody, ",") + rawJWT := "" + rawKey := "" + for _, field := range authFields { + kv := strings.SplitN(field, "=", 2) + if len(kv) < 2 { + return nil, fmt.Errorf("push service vapid Authorization header field %q malformed", field) + } + key := strings.TrimSpace(kv[0]) + val := strings.TrimSpace(kv[1]) + switch key { + case "t": + rawJWT = val + case "k": + rawKey = val + default: + // other fields irrelevant to us + } + } + if rawJWT == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"t\" field (JWT)") + } + if rawKey == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"k\" field") + } + key, err := decodeVAPIDPublicKey(rawKey) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization key: %w", err) + } + // check that the key matches the known applicationServerKey + // (RFC8292 4.2) + if !key.Equal(applicationServerKey) { + // in real code, this would mean the user agent needs to resubscribe with the new applicationServerKey + return nil, fmt.Errorf("vapid Authorization key does not match applicationServerKey from subscription") + } + + // verify the JWT signature + token, err := parseJWT(rawJWT, key) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization JWT: %w", err) + } + return token, nil +} + +func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + token, err := jwt.Parse(rawJWT, func(t *jwt.Token) (interface{}, error) { + switch t.Method.Alg() { + case "ES256": + return applicationServerKey, nil + default: + return nil, fmt.Errorf("unsupported JWT signing alg %q", t.Method.Alg()) + } + }) + if err != nil { + return nil, fmt.Errorf("decoding JWT %s: %w", rawJWT, err) + } + return token, nil +} + +func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.PrivateKey) (string, error) { + // remember initial body length, before we start consuming it + bodyLen := len(body) + // the body is aes128gcm-encoded as described in RFC8188, + // starting with this header: + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + salt, body := body[:16], body[16:] + recordSize, body := int(binary.BigEndian.Uint32(body[:4])), body[4:] + idLen, body := int(uint8(body[0])), body[1:] + rawPubKey, body := body[:idLen], body[idLen:] + if bodyLen != recordSize { + // this could mean a multi-record message was sent, this simplified parser does not support those. + return "", fmt.Errorf("expected body length %d, got %d", recordSize, bodyLen) + } + + // parse keys and derive shared secret + pubKey, err := decodeECDSAPublicKey(rawPubKey) + if err != nil { + return "", fmt.Errorf("decoding public key from header: %w", err) + } + pubKeyECDH, err := pubKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting public key to ECDH: %w", err) + } + userAgentECDHKey, err := userAgentKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent private key to ECDH: %w", err) + } + userAgentECDHPublicKey, err := userAgentKey.PublicKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent public key to ECDH: %w", err) + } + + sharedECDHSecret, err := userAgentECDHKey.ECDH(pubKeyECDH) + if err != nil { + return "", fmt.Errorf("deriving shared secret from notification public key and user agent private key: %w", err) + } + + hash := sha256.New + + // ikm + prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) + prkInfoBuf.Write(userAgentECDHPublicKey.Bytes()) // aka "dh" + prkInfoBuf.Write(pubKeyECDH.Bytes()) + + prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret[:], prkInfoBuf.Bytes()) + ikm, err := getHKDFKey(prkHKDF, 32) + if err != nil { + return "", fmt.Errorf("deriving ikm: %w", err) + } + + // Derive Content Encryption Key + contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") + contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) + if err != nil { + return "", fmt.Errorf("deriving content encryption key: %w", err) + } + + // Derive the Nonce + nonceInfo := []byte("Content-Encoding: nonce\x00") + nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonce, err := getHKDFKey(nonceHKDF, 12) + if err != nil { + return "", fmt.Errorf("deriving nonce: %w", err) + } + + // Cipher + c, err := aes.NewCipher(contentEncryptionKey) + if err != nil { + return "", fmt.Errorf("creating cipher block: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return "", fmt.Errorf("creating GCM: %w", err) + } + + // Decrypt + res, err := gcm.Open(nil, nonce, body, nil) + if err != nil { + return "", fmt.Errorf("decrypting: %w", err) + } + + // the message is padded with 0x02 0x00 0x00 0x00 [...] 0x00, we need to remove that + lastNull := len(res) + for ; lastNull > 0 && res[lastNull-1] == 0x00; lastNull-- { + } + if lastNull == 0 { + // we expect at least one 0x02 (or 0x01) before the nulls, not finding one is wrong + return "", fmt.Errorf("decryption yielded only %d null bytes", len(res)) + } + if beforeNull := res[lastNull-1]; beforeNull != 0x02 { + // if we get an 0x01, it means we have a multi-record message, this mock does not implement those + return "", fmt.Errorf("padding nulls in decrypted message should be preceded by 0x02 delimiter, got %02X", beforeNull) + } + // strip trailing nulls and separating 0x02 + res = res[:lastNull-1] + + return string(res), nil +} + +// test for the decoding helper function +func Test_decodeVAPIDPublicKey(t *testing.T) { + privKeyB64, pubKeyB64, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // as a baseline, decode using the library functions + privKeyBytes, err := decodeVapidKey(privKeyB64) + if err != nil { + t.Fatalf("decoding private key: %s", err) + } + privKey := generateVAPIDHeaderKeys(privKeyBytes) + wantPubKey := &privKey.PublicKey + + // now decode using our test helper and compare the results + gotPubKey, err := decodeVAPIDPublicKey(pubKeyB64) + if err != nil { + t.Fatalf("decoding public key") + } + if !gotPubKey.Equal(wantPubKey) { + t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, wantPubKey) + } +} From 3e0c7552ae08118117f79a6c97af5a50e90c2bd7 Mon Sep 17 00:00:00 2001 From: Willi Schinmeyer Date: Tue, 7 Nov 2023 11:52:40 +0100 Subject: [PATCH 2/2] Replace deprecated crypto/elliptic with crypto/ecdh crypto/elliptic is subject for removal soon, use of crypto/ecdh is advised instead. This has a number of side-effects: - the required Go version increases to Go 1.20 - the configured VAPID keys get verified now, and invalid keys are rejected - this necessitated changes to some tests - VAPID is effectively mandatory now (but all push services I know require it anyway) go.mod has been updated to reflect the new requirement, and I ran `go mod tidy` to clean up go.sum. I also added additional error context by wrapping errors with fmt.Errorf's %w verb. This was introduced in Go 1.13. --- README.md | 2 +- end2end_test.go | 74 +++------------------------------------- go.mod | 2 +- go.sum | 34 ------------------ vapid.go | 81 ++++++++++++++++++++++++++++--------------- vapid_test.go | 91 +++++++++++++++++++++++++++++++++++++++++++++++-- webpush.go | 41 ++++++++++------------ webpush_test.go | 16 ++++++--- 8 files changed, 179 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index c313fc6..3a0ab13 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ if err != nil { ## Development -1. Install [Go 1.11+](https://golang.org/) +1. Install [Go 1.20+](https://golang.org/) 2. `go mod vendor` 3. `go test` diff --git a/end2end_test.go b/end2end_test.go index f7858f4..40cbc72 100644 --- a/end2end_test.go +++ b/end2end_test.go @@ -13,7 +13,6 @@ import ( "encoding/binary" "fmt" "io" - "math/big" "net/http" "net/http/httptest" "strings" @@ -195,74 +194,6 @@ func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { return res, nil } -func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { - // see https://github.com/golang/go/issues/63963 - rawKey := key.Bytes() - switch key.Curve() { - case ecdh.P256(): - return &ecdsa.PublicKey{ - Curve: elliptic.P256(), - X: big.NewInt(0).SetBytes(rawKey[1:33]), - Y: big.NewInt(0).SetBytes(rawKey[33:]), - }, nil - case ecdh.P384(): - return &ecdsa.PublicKey{ - Curve: elliptic.P384(), - X: big.NewInt(0).SetBytes(rawKey[1:49]), - Y: big.NewInt(0).SetBytes(rawKey[49:]), - }, nil - case ecdh.P521(): - return &ecdsa.PublicKey{ - Curve: elliptic.P521(), - X: big.NewInt(0).SetBytes(rawKey[1:67]), - Y: big.NewInt(0).SetBytes(rawKey[67:]), - }, nil - default: - return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") - } -} - -func Test_ecdhPublicKeyToECDSA(t *testing.T) { - tests := [...]struct { - name string - curve elliptic.Curve - }{ - // P224 not supported by ecdh - { - name: "P256", - curve: elliptic.P256(), - }, - { - name: "P256", - curve: elliptic.P384(), - }, - { - name: "P521", - curve: elliptic.P521(), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) - if err != nil { - t.Fatalf("generating ecdsa.PrivateKey: %s", err) - } - original := &pk.PublicKey - converted, err := original.ECDH() - if err != nil { - t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) - } - roundtrip, err := ecdhPublicKeyToECDSA(converted) - if err != nil { - t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) - } - if !roundtrip.Equal(original) { - t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) - } - }) - } -} - func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { if authHeader == "" { return nil, fmt.Errorf("missing auth header") @@ -448,7 +379,10 @@ func Test_decodeVAPIDPublicKey(t *testing.T) { if err != nil { t.Fatalf("decoding private key: %s", err) } - privKey := generateVAPIDHeaderKeys(privKeyBytes) + privKey, err := generateVAPIDHeaderKeys(privKeyBytes) + if err != nil { + t.Fatalf("converting private key: %s", err) + } wantPubKey := &privKey.PublicKey // now decode using our test helper and compare the results diff --git a/go.mod b/go.mod index 6b0604f..642b7dd 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,4 @@ require ( golang.org/x/crypto v0.9.0 ) -go 1.13 +go 1.20 diff --git a/go.sum b/go.sum index d9575c4..a2e2828 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,4 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vapid.go b/vapid.go index fe2c580..9e82e96 100644 --- a/vapid.go +++ b/vapid.go @@ -1,6 +1,7 @@ package webpush import ( + "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -16,46 +17,69 @@ import ( // GenerateVAPIDKeys will create a private and public VAPID key pair func GenerateVAPIDKeys() (privateKey, publicKey string, err error) { // Get the private key from the P256 curve - curve := elliptic.P256() + curve := ecdh.P256() - private, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + private, err := curve.GenerateKey(rand.Reader) if err != nil { return } - public := elliptic.Marshal(curve, x, y) - // Convert to base64 - publicKey = base64.RawURLEncoding.EncodeToString(public) - privateKey = base64.RawURLEncoding.EncodeToString(private) - + publicKey = base64.RawURLEncoding.EncodeToString(private.PublicKey().Bytes()) + privateKey = base64.RawURLEncoding.EncodeToString(private.Bytes()) return } // Generates the ECDSA public and private keys for the JWT encryption -func generateVAPIDHeaderKeys(privateKey []byte) *ecdsa.PrivateKey { - // Public key - curve := elliptic.P256() - px, py := curve.ScalarMult( - curve.Params().Gx, - curve.Params().Gy, - privateKey, - ) - - pubKey := ecdsa.PublicKey{ - Curve: curve, - X: px, - Y: py, +func generateVAPIDHeaderKeys(privateKey []byte) (*ecdsa.PrivateKey, error) { + key, err := ecdh.P256().NewPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("validating private key: %w", err) } + converted, err := ecdhPrivateKeyToECDSA(key) + if err != nil { + return nil, fmt.Errorf("converting private key to crypto/ecdsa: %w", err) + } + return converted, nil +} - // Private key - d := &big.Int{} - d.SetBytes(privateKey) +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + // see https://github.com/golang/go/issues/63963 + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} - return &ecdsa.PrivateKey{ - PublicKey: pubKey, - D: d, +func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { + // see https://github.com/golang/go/issues/63963 + pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) + if err != nil { + return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) } + return &ecdsa.PrivateKey{ + PublicKey: *pubKey, + D: big.NewInt(0).SetBytes(key.Bytes()), + }, nil } // getVAPIDAuthorizationHeader @@ -84,7 +108,10 @@ func getVAPIDAuthorizationHeader( return "", err } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + return "", fmt.Errorf("generating VAPID header keys: %w", err) + } // Sign token with private key jwtString, err := token.SignedString(privKey) diff --git a/vapid_test.go b/vapid_test.go index be4cca8..d77d4ec 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -1,6 +1,9 @@ package webpush import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "encoding/base64" "fmt" "strings" @@ -44,10 +47,13 @@ func TestVAPID(t *testing.T) { b64 := base64.RawURLEncoding decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) if err != nil { - t.Fatal("Could not decode VAPID private key") + t.Fatalf("Could not decode VAPID private key: %s", err) } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + t.Fatalf("Could not parse VAPID private key: %s", err) + } return privKey.Public(), nil }) @@ -100,3 +106,84 @@ func getTokenFromAuthorizationHeader(tokenHeader string, t *testing.T) string { return tsplit[1][:len(tsplit[1])-1] } + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func Test_ecdhPrivateKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PrivateKey to ecdh.PrivateKey: %s", err) + } + roundtrip, err := ecdhPrivateKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PrivateKey back to ecdsa.PrivateKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} diff --git a/webpush.go b/webpush.go index 4c85ad6..a2d9768 100644 --- a/webpush.go +++ b/webpush.go @@ -5,12 +5,13 @@ import ( "context" "crypto/aes" "crypto/cipher" - "crypto/elliptic" + "crypto/ecdh" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/binary" "errors" + "fmt" "io" "net/http" "strconv" @@ -77,53 +78,47 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Authentication secret (auth_secret) authSecret, err := decodeSubscriptionKey(s.Keys.Auth) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.auth: %w", err) } // dh (Diffie Hellman) dh, err := decodeSubscriptionKey(s.Keys.P256dh) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.p256dh: %w", err) + } + userAgentPublicKey, err := ecdh.P256().NewPublicKey(dh) + if err != nil { + return nil, fmt.Errorf("validating keys.p256dh: %w", err) } // Generate 16 byte salt salt, err := saltFunc() if err != nil { - return nil, err + return nil, fmt.Errorf("generating salt: %w", err) } // Create the ecdh_secret shared key pair - curve := elliptic.P256() // Application server key pairs (single use) - localPrivateKey, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } - localPublicKey := elliptic.Marshal(curve, x, y) - - // Combine application keys with receiver's EC public key - sharedX, sharedY := elliptic.Unmarshal(curve, dh) - if sharedX == nil { - return nil, errors.New("Unmarshal Error: Public key is not a valid point on the curve") - } + localPublicKey := localPrivateKey.PublicKey() - // Derive ECDH shared secret - sx, sy := curve.ScalarMult(sharedX, sharedY, localPrivateKey) - if !curve.IsOnCurve(sx, sy) { - return nil, errors.New("Encryption error: ECDH shared secret isn't on curve") + // Combine application keys with receiver's EC public key to derive ECDH shared secret + sharedECDHSecret, err := localPrivateKey.ECDH(userAgentPublicKey) + if err != nil { + return nil, fmt.Errorf("deriving shared secret: %w", err) } - mlen := curve.Params().BitSize / 8 - sharedECDHSecret := make([]byte, mlen) - sx.FillBytes(sharedECDHSecret) hash := sha256.New // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) prkInfoBuf.Write(dh) - prkInfoBuf.Write(localPublicKey) + prkInfoBuf.Write(localPublicKey.Bytes()) prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) @@ -173,8 +168,8 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri binary.BigEndian.PutUint32(rs, recordSize) recordBuf.Write(rs) - recordBuf.Write([]byte{byte(len(localPublicKey))}) - recordBuf.Write(localPublicKey) + recordBuf.Write([]byte{byte(len(localPublicKey.Bytes()))}) + recordBuf.Write(localPublicKey.Bytes()) // Data dataBuf := bytes.NewBuffer(message) diff --git a/webpush_test.go b/webpush_test.go index 807a1f7..d1f74c5 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -33,6 +33,10 @@ func getStandardEncodedTestSubscription() *Subscription { } func TestSendNotificationToURLEncodedSubscription(t *testing.T) { + priv, pub, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, RecordSize: 3070, @@ -40,8 +44,8 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPublicKey: "test-public", - VAPIDPrivateKey: "test-private", + VAPIDPublicKey: pub, + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err) @@ -49,7 +53,7 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { if resp.StatusCode != 201 { t.Fatalf( - "Incorreect status code, expected=%d, got=%d", + "Incorrect status code, expected=%d, got=%d", resp.StatusCode, 201, ) @@ -57,13 +61,17 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { } func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { + priv, _, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, Subscriber: "", Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPrivateKey: "testKey", + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err)