diff --git a/README.md b/README.md index c313fc6..3a0ab13 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ if err != nil { ## Development -1. Install [Go 1.11+](https://golang.org/) +1. Install [Go 1.20+](https://golang.org/) 2. `go mod vendor` 3. `go test` diff --git a/end2end_test.go b/end2end_test.go new file mode 100644 index 0000000..40cbc72 --- /dev/null +++ b/end2end_test.go @@ -0,0 +1,396 @@ +package webpush + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/golang-jwt/jwt" + "golang.org/x/crypto/hkdf" +) + +func TestEnd2End(t *testing.T) { + var ( + // the data known to the application server (backend, which uses webpush-go) + applicationServer struct { + publicVAPIDKey string + privateVAPIDKey string + subscription Subscription + } + // the data known to the user agent (browser) + userAgent struct { + publicVAPIDKey *ecdsa.PublicKey + subscriptionKey *ecdsa.PrivateKey + authSecret [16]byte + subscription Subscription + receivedNotifications [][]byte + } + // the data known to the push server (which receives push messages on behalf of the user agent, e.g. Firestore) + pushService struct { + applicationServerKey *ecdsa.PublicKey + receivedNotifications [][]byte + } + + err error + ) + + // a VAPID key pair for the application server, usually only generated once and reused + applicationServer.privateVAPIDKey, applicationServer.publicVAPIDKey, err = GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // The application server needs to inform the user agent of the public VAPID key. + // (We decode it first for ease of use.) + userAgent.publicVAPIDKey, err = decodeVAPIDPublicKey(applicationServer.publicVAPIDKey) + if err != nil { + t.Fatal(err) + } + + // We need a mock push service for webpush-go to send notifications to. + var mockPushService *httptest.Server + mockPushService = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // check that there's a valid vapid JWT + token, err := parseVapidAuthHeader( + r.Header.Get("Authorization"), + // by the time this function is called, this value will be set (see PushManager.subscribe() below) + pushService.applicationServerKey) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "invalid auth: %s", err) + return + } + // verify that the audience matches our URL + aud := token.Claims.(jwt.MapClaims)["aud"] + if aud != mockPushService.URL { + w.WriteHeader(http.StatusUnauthorized) + _, _ = fmt.Fprintf(w, "JWT has bad audience, want %q, got %q", mockPushService.URL, aud) + return + } + // RFC8188 only allows for exactly one content encoding + if contentEncoding := r.Header.Get("Content-Encoding"); contentEncoding != "aes128gcm" { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, "unsupported Content-Encoding, want %q, got %q", "aes128gcm", contentEncoding) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + // this suggests a broken connection, so log the error instead of sending it back + t.Errorf("failed to read request body: %s", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + // store body for later decoding by user agent + // (the push service doesn't have the key required for decryption) + pushService.receivedNotifications = append(pushService.receivedNotifications, body) + + w.WriteHeader(http.StatusAccepted) + })) + defer mockPushService.Close() + + // what follows is the equivalent of PushManager.subscribe() in JS + { + // the user agent generates its own key pair so it can be sent encrypted messages + userAgent.subscriptionKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generating user agent keys: %s", err) + } + // we need the ECDH representation + ecdhPublicKey, err := userAgent.subscriptionKey.PublicKey.ECDH() + if err != nil { + t.Fatalf("converting user agent public key to ECDH: %s", err) + } + // generate the shared auth secret + _, err = rand.Read(userAgent.authSecret[:]) + if err != nil { + t.Fatalf("generating user agent auth secret: %s", err) + } + // the user agent then performs a registration with the push service using that key, + // while also letting the push service know the application server key to expect. + pushService.applicationServerKey = userAgent.publicVAPIDKey + userAgent.subscription = Subscription{ + Keys: Keys{ + Auth: base64.StdEncoding.EncodeToString(userAgent.authSecret[:]), + P256dh: base64.StdEncoding.EncodeToString(ecdhPublicKey.Bytes()), + }, + Endpoint: mockPushService.URL, + } + } + + // the user agent sends its subscription to the application server... + applicationServer.subscription = userAgent.subscription + + // ...and the application server uses the subscription to send a push notification + sentMessage := "this is our test push notification" + resp, err := SendNotification([]byte(sentMessage), &applicationServer.subscription, &Options{ + HTTPClient: mockPushService.Client(), + VAPIDPublicKey: applicationServer.publicVAPIDKey, + VAPIDPrivateKey: applicationServer.privateVAPIDKey, + Subscriber: "test@example.com", + }) + if err != nil { + t.Fatalf("failed to send notification: %s", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Errorf("error closing mock push service response body: %s", err) + } + }() + // check for success + respBody, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading mock push service response body: %s", err) + } + if resp.StatusCode/100 != 2 { + t.Errorf("unexpected push service status code %d, body: %s", resp.StatusCode, respBody) + } + + // the push server should now have received the notification + if l := len(pushService.receivedNotifications); l != 1 { + t.Fatalf("Want 1 notification received by push service, got %d", l) + } + // the push service then forwards the notification to the user agent + userAgent.receivedNotifications = pushService.receivedNotifications + // and the user agent can decrypt them + receivedMessage, err := decodeNotification(userAgent.receivedNotifications[0], userAgent.authSecret, userAgent.subscriptionKey) + if err != nil { + t.Fatalf("error decrypting notification in user agent: %s", err) + } + if receivedMessage != sentMessage { + t.Errorf("Sent notification %q, but got %q", sentMessage, receivedMessage) + } +} + +func decodeVAPIDPublicKey(publicVAPIDKey string) (*ecdsa.PublicKey, error) { + publicVAPIDKeyBytes, err := base64.RawURLEncoding.DecodeString(publicVAPIDKey) + if err != nil { + return nil, fmt.Errorf("base64-decoding public VAPID key: %w", err) + } + return decodeECDSAPublicKey(publicVAPIDKeyBytes) +} + +func decodeECDSAPublicKey(bytes []byte) (*ecdsa.PublicKey, error) { + ecdhKey, err := ecdh.P256().NewPublicKey(bytes) + if err != nil { + return nil, fmt.Errorf("parsing public VAPID key: %w", err) + } + res, err := ecdhPublicKeyToECDSA(ecdhKey) + if err != nil { + return nil, fmt.Errorf("converting public VAPID key from *ecdh.PublicKey to *ecdsa.PublicKey: %w", err) + } + return res, nil +} + +func parseVapidAuthHeader(authHeader string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + if authHeader == "" { + return nil, fmt.Errorf("missing auth header") + } + // the Authorization header should be of the form "vapid t=JWT, k=key" (RFC8292) + // we need to extract the JWT (JSON Web Token) from t to check the signature using k + authBody, found := strings.CutPrefix(authHeader, "vapid ") + if !found { + return nil, fmt.Errorf("Authorization header is not vapid: %s", authHeader) + } + authFields := strings.Split(authBody, ",") + rawJWT := "" + rawKey := "" + for _, field := range authFields { + kv := strings.SplitN(field, "=", 2) + if len(kv) < 2 { + return nil, fmt.Errorf("push service vapid Authorization header field %q malformed", field) + } + key := strings.TrimSpace(kv[0]) + val := strings.TrimSpace(kv[1]) + switch key { + case "t": + rawJWT = val + case "k": + rawKey = val + default: + // other fields irrelevant to us + } + } + if rawJWT == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"t\" field (JWT)") + } + if rawKey == "" { + return nil, fmt.Errorf("vapid Authorization header lacks \"k\" field") + } + key, err := decodeVAPIDPublicKey(rawKey) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization key: %w", err) + } + // check that the key matches the known applicationServerKey + // (RFC8292 4.2) + if !key.Equal(applicationServerKey) { + // in real code, this would mean the user agent needs to resubscribe with the new applicationServerKey + return nil, fmt.Errorf("vapid Authorization key does not match applicationServerKey from subscription") + } + + // verify the JWT signature + token, err := parseJWT(rawJWT, key) + if err != nil { + return nil, fmt.Errorf("parsing vapid Authorization JWT: %w", err) + } + return token, nil +} + +func parseJWT(rawJWT string, applicationServerKey *ecdsa.PublicKey) (*jwt.Token, error) { + token, err := jwt.Parse(rawJWT, func(t *jwt.Token) (interface{}, error) { + switch t.Method.Alg() { + case "ES256": + return applicationServerKey, nil + default: + return nil, fmt.Errorf("unsupported JWT signing alg %q", t.Method.Alg()) + } + }) + if err != nil { + return nil, fmt.Errorf("decoding JWT %s: %w", rawJWT, err) + } + return token, nil +} + +func decodeNotification(body []byte, authSecret [16]byte, userAgentKey *ecdsa.PrivateKey) (string, error) { + // remember initial body length, before we start consuming it + bodyLen := len(body) + // the body is aes128gcm-encoded as described in RFC8188, + // starting with this header: + // +-----------+--------+-----------+---------------+ + // | salt (16) | rs (4) | idlen (1) | keyid (idlen) | + // +-----------+--------+-----------+---------------+ + salt, body := body[:16], body[16:] + recordSize, body := int(binary.BigEndian.Uint32(body[:4])), body[4:] + idLen, body := int(uint8(body[0])), body[1:] + rawPubKey, body := body[:idLen], body[idLen:] + if bodyLen != recordSize { + // this could mean a multi-record message was sent, this simplified parser does not support those. + return "", fmt.Errorf("expected body length %d, got %d", recordSize, bodyLen) + } + + // parse keys and derive shared secret + pubKey, err := decodeECDSAPublicKey(rawPubKey) + if err != nil { + return "", fmt.Errorf("decoding public key from header: %w", err) + } + pubKeyECDH, err := pubKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting public key to ECDH: %w", err) + } + userAgentECDHKey, err := userAgentKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent private key to ECDH: %w", err) + } + userAgentECDHPublicKey, err := userAgentKey.PublicKey.ECDH() + if err != nil { + return "", fmt.Errorf("converting user agent public key to ECDH: %w", err) + } + + sharedECDHSecret, err := userAgentECDHKey.ECDH(pubKeyECDH) + if err != nil { + return "", fmt.Errorf("deriving shared secret from notification public key and user agent private key: %w", err) + } + + hash := sha256.New + + // ikm + prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) + prkInfoBuf.Write(userAgentECDHPublicKey.Bytes()) // aka "dh" + prkInfoBuf.Write(pubKeyECDH.Bytes()) + + prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret[:], prkInfoBuf.Bytes()) + ikm, err := getHKDFKey(prkHKDF, 32) + if err != nil { + return "", fmt.Errorf("deriving ikm: %w", err) + } + + // Derive Content Encryption Key + contentEncryptionKeyInfo := []byte("Content-Encoding: aes128gcm\x00") + contentHKDF := hkdf.New(hash, ikm, salt, contentEncryptionKeyInfo) + contentEncryptionKey, err := getHKDFKey(contentHKDF, 16) + if err != nil { + return "", fmt.Errorf("deriving content encryption key: %w", err) + } + + // Derive the Nonce + nonceInfo := []byte("Content-Encoding: nonce\x00") + nonceHKDF := hkdf.New(hash, ikm, salt, nonceInfo) + nonce, err := getHKDFKey(nonceHKDF, 12) + if err != nil { + return "", fmt.Errorf("deriving nonce: %w", err) + } + + // Cipher + c, err := aes.NewCipher(contentEncryptionKey) + if err != nil { + return "", fmt.Errorf("creating cipher block: %w", err) + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return "", fmt.Errorf("creating GCM: %w", err) + } + + // Decrypt + res, err := gcm.Open(nil, nonce, body, nil) + if err != nil { + return "", fmt.Errorf("decrypting: %w", err) + } + + // the message is padded with 0x02 0x00 0x00 0x00 [...] 0x00, we need to remove that + lastNull := len(res) + for ; lastNull > 0 && res[lastNull-1] == 0x00; lastNull-- { + } + if lastNull == 0 { + // we expect at least one 0x02 (or 0x01) before the nulls, not finding one is wrong + return "", fmt.Errorf("decryption yielded only %d null bytes", len(res)) + } + if beforeNull := res[lastNull-1]; beforeNull != 0x02 { + // if we get an 0x01, it means we have a multi-record message, this mock does not implement those + return "", fmt.Errorf("padding nulls in decrypted message should be preceded by 0x02 delimiter, got %02X", beforeNull) + } + // strip trailing nulls and separating 0x02 + res = res[:lastNull-1] + + return string(res), nil +} + +// test for the decoding helper function +func Test_decodeVAPIDPublicKey(t *testing.T) { + privKeyB64, pubKeyB64, err := GenerateVAPIDKeys() + if err != nil { + t.Fatalf("generating VAPID keys: %s", err) + } + + // as a baseline, decode using the library functions + privKeyBytes, err := decodeVapidKey(privKeyB64) + if err != nil { + t.Fatalf("decoding private key: %s", err) + } + privKey, err := generateVAPIDHeaderKeys(privKeyBytes) + if err != nil { + t.Fatalf("converting private key: %s", err) + } + wantPubKey := &privKey.PublicKey + + // now decode using our test helper and compare the results + gotPubKey, err := decodeVAPIDPublicKey(pubKeyB64) + if err != nil { + t.Fatalf("decoding public key") + } + if !gotPubKey.Equal(wantPubKey) { + t.Errorf("result differs:\ngot: %v\nwant: %v", gotPubKey, wantPubKey) + } +} diff --git a/go.mod b/go.mod index 6b0604f..642b7dd 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,4 @@ require ( golang.org/x/crypto v0.9.0 ) -go 1.13 +go 1.20 diff --git a/go.sum b/go.sum index d9575c4..a2e2828 100644 --- a/go.sum +++ b/go.sum @@ -1,38 +1,4 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/vapid.go b/vapid.go index fe2c580..9e82e96 100644 --- a/vapid.go +++ b/vapid.go @@ -1,6 +1,7 @@ package webpush import ( + "crypto/ecdh" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -16,46 +17,69 @@ import ( // GenerateVAPIDKeys will create a private and public VAPID key pair func GenerateVAPIDKeys() (privateKey, publicKey string, err error) { // Get the private key from the P256 curve - curve := elliptic.P256() + curve := ecdh.P256() - private, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + private, err := curve.GenerateKey(rand.Reader) if err != nil { return } - public := elliptic.Marshal(curve, x, y) - // Convert to base64 - publicKey = base64.RawURLEncoding.EncodeToString(public) - privateKey = base64.RawURLEncoding.EncodeToString(private) - + publicKey = base64.RawURLEncoding.EncodeToString(private.PublicKey().Bytes()) + privateKey = base64.RawURLEncoding.EncodeToString(private.Bytes()) return } // Generates the ECDSA public and private keys for the JWT encryption -func generateVAPIDHeaderKeys(privateKey []byte) *ecdsa.PrivateKey { - // Public key - curve := elliptic.P256() - px, py := curve.ScalarMult( - curve.Params().Gx, - curve.Params().Gy, - privateKey, - ) - - pubKey := ecdsa.PublicKey{ - Curve: curve, - X: px, - Y: py, +func generateVAPIDHeaderKeys(privateKey []byte) (*ecdsa.PrivateKey, error) { + key, err := ecdh.P256().NewPrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("validating private key: %w", err) } + converted, err := ecdhPrivateKeyToECDSA(key) + if err != nil { + return nil, fmt.Errorf("converting private key to crypto/ecdsa: %w", err) + } + return converted, nil +} - // Private key - d := &big.Int{} - d.SetBytes(privateKey) +func ecdhPublicKeyToECDSA(key *ecdh.PublicKey) (*ecdsa.PublicKey, error) { + // see https://github.com/golang/go/issues/63963 + rawKey := key.Bytes() + switch key.Curve() { + case ecdh.P256(): + return &ecdsa.PublicKey{ + Curve: elliptic.P256(), + X: big.NewInt(0).SetBytes(rawKey[1:33]), + Y: big.NewInt(0).SetBytes(rawKey[33:]), + }, nil + case ecdh.P384(): + return &ecdsa.PublicKey{ + Curve: elliptic.P384(), + X: big.NewInt(0).SetBytes(rawKey[1:49]), + Y: big.NewInt(0).SetBytes(rawKey[49:]), + }, nil + case ecdh.P521(): + return &ecdsa.PublicKey{ + Curve: elliptic.P521(), + X: big.NewInt(0).SetBytes(rawKey[1:67]), + Y: big.NewInt(0).SetBytes(rawKey[67:]), + }, nil + default: + return nil, fmt.Errorf("cannot convert non-NIST *ecdh.PublicKey to *ecdsa.PublicKey") + } +} - return &ecdsa.PrivateKey{ - PublicKey: pubKey, - D: d, +func ecdhPrivateKeyToECDSA(key *ecdh.PrivateKey) (*ecdsa.PrivateKey, error) { + // see https://github.com/golang/go/issues/63963 + pubKey, err := ecdhPublicKeyToECDSA(key.PublicKey()) + if err != nil { + return nil, fmt.Errorf("converting PublicKey part of *ecdh.PrivateKey: %w", err) } + return &ecdsa.PrivateKey{ + PublicKey: *pubKey, + D: big.NewInt(0).SetBytes(key.Bytes()), + }, nil } // getVAPIDAuthorizationHeader @@ -84,7 +108,10 @@ func getVAPIDAuthorizationHeader( return "", err } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + return "", fmt.Errorf("generating VAPID header keys: %w", err) + } // Sign token with private key jwtString, err := token.SignedString(privKey) diff --git a/vapid_test.go b/vapid_test.go index be4cca8..d77d4ec 100644 --- a/vapid_test.go +++ b/vapid_test.go @@ -1,6 +1,9 @@ package webpush import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "encoding/base64" "fmt" "strings" @@ -44,10 +47,13 @@ func TestVAPID(t *testing.T) { b64 := base64.RawURLEncoding decodedVapidPrivateKey, err := b64.DecodeString(vapidPrivateKey) if err != nil { - t.Fatal("Could not decode VAPID private key") + t.Fatalf("Could not decode VAPID private key: %s", err) } - privKey := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + privKey, err := generateVAPIDHeaderKeys(decodedVapidPrivateKey) + if err != nil { + t.Fatalf("Could not parse VAPID private key: %s", err) + } return privKey.Public(), nil }) @@ -100,3 +106,84 @@ func getTokenFromAuthorizationHeader(tokenHeader string, t *testing.T) string { return tsplit[1][:len(tsplit[1])-1] } + +func Test_ecdhPublicKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pk, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + original := &pk.PublicKey + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PublicKey to ecdh.PublicKey: %s", err) + } + roundtrip, err := ecdhPublicKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PublicKey back to ecdsa.PublicKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} + +func Test_ecdhPrivateKeyToECDSA(t *testing.T) { + tests := [...]struct { + name string + curve elliptic.Curve + }{ + // P224 not supported by ecdh + { + name: "P256", + curve: elliptic.P256(), + }, + { + name: "P256", + curve: elliptic.P384(), + }, + { + name: "P521", + curve: elliptic.P521(), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original, err := ecdsa.GenerateKey(tt.curve, rand.Reader) + if err != nil { + t.Fatalf("generating ecdsa.PrivateKey: %s", err) + } + converted, err := original.ECDH() + if err != nil { + t.Fatalf("converting ecdsa.PrivateKey to ecdh.PrivateKey: %s", err) + } + roundtrip, err := ecdhPrivateKeyToECDSA(converted) + if err != nil { + t.Fatalf("converting ecdh.PrivateKey back to ecdsa.PrivateKey: %s", err) + } + if !roundtrip.Equal(original) { + t.Errorf("Roundtrip changed key from %v to %v", original, roundtrip) + } + }) + } +} diff --git a/webpush.go b/webpush.go index 4c85ad6..a2d9768 100644 --- a/webpush.go +++ b/webpush.go @@ -5,12 +5,13 @@ import ( "context" "crypto/aes" "crypto/cipher" - "crypto/elliptic" + "crypto/ecdh" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/binary" "errors" + "fmt" "io" "net/http" "strconv" @@ -77,53 +78,47 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri // Authentication secret (auth_secret) authSecret, err := decodeSubscriptionKey(s.Keys.Auth) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.auth: %w", err) } // dh (Diffie Hellman) dh, err := decodeSubscriptionKey(s.Keys.P256dh) if err != nil { - return nil, err + return nil, fmt.Errorf("decoding keys.p256dh: %w", err) + } + userAgentPublicKey, err := ecdh.P256().NewPublicKey(dh) + if err != nil { + return nil, fmt.Errorf("validating keys.p256dh: %w", err) } // Generate 16 byte salt salt, err := saltFunc() if err != nil { - return nil, err + return nil, fmt.Errorf("generating salt: %w", err) } // Create the ecdh_secret shared key pair - curve := elliptic.P256() // Application server key pairs (single use) - localPrivateKey, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + localPrivateKey, err := ecdh.P256().GenerateKey(rand.Reader) if err != nil { return nil, err } - localPublicKey := elliptic.Marshal(curve, x, y) - - // Combine application keys with receiver's EC public key - sharedX, sharedY := elliptic.Unmarshal(curve, dh) - if sharedX == nil { - return nil, errors.New("Unmarshal Error: Public key is not a valid point on the curve") - } + localPublicKey := localPrivateKey.PublicKey() - // Derive ECDH shared secret - sx, sy := curve.ScalarMult(sharedX, sharedY, localPrivateKey) - if !curve.IsOnCurve(sx, sy) { - return nil, errors.New("Encryption error: ECDH shared secret isn't on curve") + // Combine application keys with receiver's EC public key to derive ECDH shared secret + sharedECDHSecret, err := localPrivateKey.ECDH(userAgentPublicKey) + if err != nil { + return nil, fmt.Errorf("deriving shared secret: %w", err) } - mlen := curve.Params().BitSize / 8 - sharedECDHSecret := make([]byte, mlen) - sx.FillBytes(sharedECDHSecret) hash := sha256.New // ikm prkInfoBuf := bytes.NewBuffer([]byte("WebPush: info\x00")) prkInfoBuf.Write(dh) - prkInfoBuf.Write(localPublicKey) + prkInfoBuf.Write(localPublicKey.Bytes()) prkHKDF := hkdf.New(hash, sharedECDHSecret, authSecret, prkInfoBuf.Bytes()) ikm, err := getHKDFKey(prkHKDF, 32) @@ -173,8 +168,8 @@ func SendNotificationWithContext(ctx context.Context, message []byte, s *Subscri binary.BigEndian.PutUint32(rs, recordSize) recordBuf.Write(rs) - recordBuf.Write([]byte{byte(len(localPublicKey))}) - recordBuf.Write(localPublicKey) + recordBuf.Write([]byte{byte(len(localPublicKey.Bytes()))}) + recordBuf.Write(localPublicKey.Bytes()) // Data dataBuf := bytes.NewBuffer(message) diff --git a/webpush_test.go b/webpush_test.go index 807a1f7..d1f74c5 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -33,6 +33,10 @@ func getStandardEncodedTestSubscription() *Subscription { } func TestSendNotificationToURLEncodedSubscription(t *testing.T) { + priv, pub, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getURLEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, RecordSize: 3070, @@ -40,8 +44,8 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPublicKey: "test-public", - VAPIDPrivateKey: "test-private", + VAPIDPublicKey: pub, + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err) @@ -49,7 +53,7 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { if resp.StatusCode != 201 { t.Fatalf( - "Incorreect status code, expected=%d, got=%d", + "Incorrect status code, expected=%d, got=%d", resp.StatusCode, 201, ) @@ -57,13 +61,17 @@ func TestSendNotificationToURLEncodedSubscription(t *testing.T) { } func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { + priv, _, err := GenerateVAPIDKeys() + if err != nil { + t.Fatal(err) + } resp, err := SendNotification([]byte("Test"), getStandardEncodedTestSubscription(), &Options{ HTTPClient: &testHTTPClient{}, Subscriber: "", Topic: "test_topic", TTL: 0, Urgency: "low", - VAPIDPrivateKey: "testKey", + VAPIDPrivateKey: priv, }) if err != nil { t.Fatal(err)