From b0db6124d5db2214d70f9f3035952a2605d0b082 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Tue, 6 Jul 2021 12:13:54 +0300 Subject: [PATCH 1/8] refactor: use a b64pad function --- jwt_tool.py | 155 +++++++++------------------------------------------- 1 file changed, 26 insertions(+), 129 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index 99ca1ab..86b0014 100644 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -56,6 +56,10 @@ def cprintc(textval, colval): if not args.bare: cprint(textval, colval) +def b64pad(buf): + """ Restore stripped B64 padding """ + return buf + '=' * (4 - len(buf) % 4 if len(buf) % 4 in (2, 3) else 0) + def createConfig(): privKeyName = path+"/jwttool_custom_private_RSA.pem" pubkeyName = path+"/jwttool_custom_public_RSA.pem" @@ -867,34 +871,13 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey): key = RSA.importKey(open(pubKey).read()) newContents = genContents(headDict, paylDict) newContents = newContents.encode('UTF-8') - if "-" in sig: - try: - sig = base64.urlsafe_b64decode(sig) - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in sig: - try: - sig = base64.b64decode(sig) - except: - pass - try: - sig = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - sig = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") if headDict['alg'] == "RS256": h = SHA256.new(newContents) elif headDict['alg'] == "RS384": @@ -919,34 +902,13 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey): def verifyTokenEC(headDict, paylDict, sig, pubKey): newContents = genContents(headDict, paylDict) message = newContents.encode('UTF-8') - if "-" in str(sig): - try: - signature = base64.urlsafe_b64decode(sig) - except: - pass - try: - signature = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - signature = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in str(sig): - try: - signature = base64.b64decode(sig) - except: - pass - try: - signature = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - signature = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") if headDict['alg'] == "ES256": h = SHA256.new(message) elif headDict['alg'] == "ES384": @@ -971,34 +933,13 @@ def verifyTokenPSS(headDict, paylDict, sig, pubKey): key = RSA.importKey(open(pubKey).read()) newContents = genContents(headDict, paylDict) newContents = newContents.encode('UTF-8') - if "-" in sig: - try: - sig = base64.urlsafe_b64decode(sig) - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in sig: - try: - sig = base64.b64decode(sig) - except: - pass - try: - sig = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - sig = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") if headDict['alg'] == "PS256": h = SHA256.new(newContents) elif headDict['alg'] == "PS384": @@ -1096,30 +1037,8 @@ def parseJWKS(jwksfile): pass def genECPubFromJWKS(x, y, kid, nowtime): - try: - x = int.from_bytes(base64.urlsafe_b64decode(x), byteorder='big') - except: - pass - try: - x = int.from_bytes(base64.urlsafe_b64decode(x+"="), byteorder='big') - except: - pass - try: - x = int.from_bytes(base64.urlsafe_b64decode(x+"=="), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y+"="), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y+"=="), byteorder='big') - except: - pass + x = int.from_bytes(base64.urlsafe_b64decode(b64pad(x)), byteorder='big') + y = int.from_bytes(base64.urlsafe_b64decode(b64pad(y)), byteorder='big') new_key = ECC.construct(curve='P-256', point_x=x, point_y=y) pubKey = new_key.public_key().export_key(format="PEM")+"\n" pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem" @@ -1128,30 +1047,8 @@ def genECPubFromJWKS(x, y, kid, nowtime): return pubkeyName def genRSAPubFromJWKS(n, e, kid, nowtime): - try: - n = int.from_bytes(base64.urlsafe_b64decode(n), byteorder='big') - except: - pass - try: - n = int.from_bytes(base64.urlsafe_b64decode(n+"="), byteorder='big') - except: - pass - try: - n = int.from_bytes(base64.urlsafe_b64decode(n+"=="), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e+"="), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e+"=="), byteorder='big') - except: - pass + n = int.from_bytes(base64.urlsafe_b64decode(b64pad(n)), byteorder='big') + e = int.from_bytes(base64.urlsafe_b64decode(b64pad(e)), byteorder='big') new_key = RSA.construct((n, e)) pubKey = new_key.publickey().exportKey(format="PEM") pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem" From 57a1cd0d7b6619fa53504771a2227de79cd96317 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 16:15:27 +0200 Subject: [PATCH 2/8] fix: NameError in verifyTokenEC shadowed behind catch-all except The correct variable name is `sig`, but under try: it's referred to as `signature`. Normally that'd crash with NameError exception -- but here we have a catch-all except block misinterpreting that as wrong signature. --- jwt_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index 86b0014..49486d6 100644 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -921,10 +921,10 @@ def verifyTokenEC(headDict, paylDict, sig, pubKey): pub_key = ECC.import_key(pubkey.read()) verifier = DSS.new(pub_key, 'fips-186-3') try: - verifier.verify(h, signature) + verifier.verify(h, sig) cprintc("ECC Signature is VALID", "green") valid = True - except: + except ValueError: cprintc("ECC Signature is INVALID", "red") valid = False return valid From 8d8b02defba50f690d8ee62e8be87e774079fbe0 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 16:20:17 +0200 Subject: [PATCH 3/8] fix: git chmod +x shebanged script --- jwt_tool.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 jwt_tool.py diff --git a/jwt_tool.py b/jwt_tool.py old mode 100644 new mode 100755 From b4023b492f0d725d5181fe8f19bd5dce4e337396 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 16:23:28 +0200 Subject: [PATCH 4/8] fix: assert loaded external key curve matches the standard Keys can be/are of wildly different types, including different elliptic curves. IETF RFC7518 (JWA) section 3.4 table mandates this 3-row map: | JWT.alg | Hash, curve | ES256 | SHA256, P-256 | ES384 | SHA384, P-384 | ES512 | SHA512, P-521 The assert (unless disabled with -O) will clearly fail when JWT mismatches the pubkey, (as far as by curve choice). --- jwt_tool.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index 49486d6..7009127 100755 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -909,16 +909,20 @@ def verifyTokenEC(headDict, paylDict, sig, pubKey): sig = base64.b64decode(b64pad(sig)) except ValueError: cprintc("Signature not Base64 encoded HEX", "red") + if headDict['alg'] == "ES256": - h = SHA256.new(message) + h, curvename = SHA256.new(message), 'P-256' elif headDict['alg'] == "ES384": - h = SHA384.new(message) + h, curvename = SHA384.new(message), 'P-384' elif headDict['alg'] == "ES512": - h = SHA512.new(message) + h, curvename = SHA512.new(message), 'P-521' else: cprintc("Invalid ECDSA algorithm", "red") pubkey = open(pubKey, "r") pub_key = ECC.import_key(pubkey.read()) + cprintc("[ ] loaded ECC pubkey on the curve {}".format(pub_key.curve), "cyan") + assert pub_key.curve == 'NIST ' + curvename, "Key on unexpected curve loaded" + verifier = DSS.new(pub_key, 'fips-186-3') try: verifier.verify(h, sig) From 1897c5a54c0ce196082a67e226f8e24d28857f18 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 17:53:16 +0200 Subject: [PATCH 5/8] fix: see https://www.iana.org/assignments/jose/jose.xhtml --- common-headers.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common-headers.txt b/common-headers.txt index 4be1f15..009d1f9 100644 --- a/common-headers.txt +++ b/common-headers.txt @@ -2,4 +2,5 @@ typ jku kid x5u -x5t \ No newline at end of file +x5t +url From c183f265b226788fc5fc5c2b2a7ce69c24565fc2 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 18:39:48 +0200 Subject: [PATCH 6/8] chore: refactor parseSingleJWK --- jwt_tool.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index 7009127..095909f 100755 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -972,14 +972,14 @@ def exportJWKS(jku): return newContents, newSig def parseJWKS(jwksfile): - jwks = open(jwksfile, "r").read() - jwksDict = json.loads(jwks, object_pairs_hook=OrderedDict) + jwksDict = json.load(open(jwksfile, 'r'), object_pairs_hook=OrderedDict) nowtime = int(datetime.now().timestamp()) cprintc("JWKS Contents:", "cyan") try: keyLen = len(jwksDict["keys"]) cprintc("Number of keys: "+str(keyLen), "cyan") i = -1 + valid = False for jkey in range(0,keyLen): i += 1 cprintc("\n--------", "white") @@ -993,32 +993,15 @@ def parseJWKS(jwksfile): for keyVal in jwksDict["keys"][i].items(): keyVal = keyVal[0] cprintc("[+] "+keyVal+" = "+str(jwksDict["keys"][i][keyVal]), "green") - try: - x = str(jwksDict["keys"][i]["x"]) - y = str(jwksDict["keys"][i]["y"]) - cprintc("\nFound ECC key factors, generating a public key", "cyan") - pubkeyName = genECPubFromJWKS(x, y, kid, nowtime) - cprintc("[+] "+pubkeyName, "green") - cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName) - except: - pass - try: - n = str(jwksDict["keys"][i]["n"]) - e = str(jwksDict["keys"][i]["e"]) - cprintc("\nFound RSA key factors, generating a public key", "cyan") - pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime) - cprintc("[+] "+pubkeyName, "green") - cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName) - except: - pass + parseSingleJWK(jwksDict["keys"][i], kid=i) except: cprintc("Single key file", "white") for jkey in jwksDict: cprintc("[+] "+jkey+" = "+str(jwksDict[jkey]), "green") + parseSingleJWK(jwksDict) + +def parseSingleJWK(jwksDict, kid=1): try: - kid = 1 x = str(jwksDict["x"]) y = str(jwksDict["y"]) cprintc("\nFound ECC key factors, generating a public key", "cyan") @@ -1029,7 +1012,6 @@ def parseJWKS(jwksfile): except: pass try: - kid = 1 n = str(jwksDict["n"]) e = str(jwksDict["e"]) cprintc("\nFound RSA key factors, generating a public key", "cyan") From a44528fce05a899f47ccaca889227c5f453e007b Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 3 Feb 2022 19:07:14 +0200 Subject: [PATCH 7/8] feat: verify JWT against a JWKS URL --- jwt_tool.py | 64 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index 095909f..db8cad8 100755 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -17,12 +17,13 @@ import base64 import json import random -from urllib.parse import urljoin, urlparse +import tempfile import argparse from datetime import datetime import configparser from http.cookies import SimpleCookie from collections import OrderedDict +from urllib.parse import urljoin, urlparse try: from Cryptodome.Signature import PKCS1_v1_5, DSS, pss from Cryptodome.Hash import SHA256, SHA384, SHA512 @@ -978,38 +979,37 @@ def parseJWKS(jwksfile): try: keyLen = len(jwksDict["keys"]) cprintc("Number of keys: "+str(keyLen), "cyan") - i = -1 - valid = False - for jkey in range(0,keyLen): - i += 1 + kid_bak = 1 + any1valid = False + for d in jwksDict["keys"]: cprintc("\n--------", "white") - try: - cprintc("Key "+str(i+1), "cyan") - kid = str(jwksDict["keys"][i]["kid"]) - cprintc("kid: "+kid, "cyan") - except: - kid = i - cprintc("Key "+str(i+1), "cyan") - for keyVal in jwksDict["keys"][i].items(): - keyVal = keyVal[0] - cprintc("[+] "+keyVal+" = "+str(jwksDict["keys"][i][keyVal]), "green") - parseSingleJWK(jwksDict["keys"][i], kid=i) - except: + if 'kid' in d: + kid = str(d["kid"]) + else: + kid = kid_bak + kid_bak += 1 + cprintc(f"Key kid {kid}", "cyan") + for k, v in d.items(): + cprintc(f"[+] {k} = {v}", "green") + if parseSingleJWK(d, nowtime, kid=kid): + any1valid = True + return any1valid + except ValueError: cprintc("Single key file", "white") for jkey in jwksDict: cprintc("[+] "+jkey+" = "+str(jwksDict[jkey]), "green") - parseSingleJWK(jwksDict) + return parseSingleJWK(jwksDict, nowtime) -def parseSingleJWK(jwksDict, kid=1): +def parseSingleJWK(jwksDict, nowtime, kid=1): try: x = str(jwksDict["x"]) y = str(jwksDict["y"]) cprintc("\nFound ECC key factors, generating a public key", "cyan") - pubkeyName = genECPubFromJWKS(x, y, kid, nowtime) + pubkeyName = genECPubFromJWKS(x, y, kid, nowtime, curve=jwksDict.get('crv')) cprintc("[+] "+pubkeyName, "green") cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName) - except: + return verifyTokenEC(headDict, paylDict, sig, pubkeyName) + except KeyError: pass try: n = str(jwksDict["n"]) @@ -1018,14 +1018,14 @@ def parseSingleJWK(jwksDict, kid=1): pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime) cprintc("[+] "+pubkeyName, "green") cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName) + return verifyTokenRSA(headDict, paylDict, sig, pubkeyName) except: pass -def genECPubFromJWKS(x, y, kid, nowtime): +def genECPubFromJWKS(x, y, kid, nowtime, curve=None): x = int.from_bytes(base64.urlsafe_b64decode(b64pad(x)), byteorder='big') y = int.from_bytes(base64.urlsafe_b64decode(b64pad(y)), byteorder='big') - new_key = ECC.construct(curve='P-256', point_x=x, point_y=y) + new_key = ECC.construct(curve=curve or 'P-256', point_x=x, point_y=y) pubKey = new_key.public_key().export_key(format="PEM")+"\n" pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem" with open(pubkeyName, 'w') as test_pub_out: @@ -1669,8 +1669,20 @@ def runActions(): else: cprintc("Algorithm not supported for verification", "red") exit(1) + elif args.jwksfile: parseJWKS(config['crypto']['jwks']) + + elif args.jwksurl: + resp = requests.get(args.jwksurl) + assert resp.ok + + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(resp.content) + tmp.flush() + tmp.seek(0) + valid = parseJWKS(tmp.name) + exit(0 if valid else 1) else: cprintc("No Public Key or JWKS file provided (-pk/-jw)\n", "red") parser.print_usage() @@ -1791,8 +1803,6 @@ def printLogo(): os.rename(configFileName, path+"/old_("+config['services']['jwt_tool_version']+")_jwtconf.ini") createConfig() exit(1) - with open(path+"/null.txt", 'w') as nullfile: - pass findJWT = "" if args.request: From 14d16765a3a99cee1f7126e7f87c028cc5a869d5 Mon Sep 17 00:00:00 2001 From: max ulidtko Date: Thu, 17 Mar 2022 18:42:46 +0200 Subject: [PATCH 8/8] fix: avoid kid clashing potential For those JWK's which lack the kid attribute, the logic assigns one. When parsing pubkey bundle (JWKS, a set of JWK), the previous logic enables a clash, consider this JWK sequence: * {"kid": "2", "kty":"EC", "use":"sig", ... } * {"kty":"RS", "use":"sig", ... } -- this saves with kid=1 * {"kty":"RS", "use":"enc", ... } -- this *overwrites* kid=2 --- jwt_tool.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/jwt_tool.py b/jwt_tool.py index db8cad8..d15cbb7 100755 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -979,15 +979,13 @@ def parseJWKS(jwksfile): try: keyLen = len(jwksDict["keys"]) cprintc("Number of keys: "+str(keyLen), "cyan") - kid_bak = 1 + kids_seen = set() + new_kid = lambda: 1 + max([x for x in kids_seen if isinstance(x, int)], default=0) any1valid = False for d in jwksDict["keys"]: cprintc("\n--------", "white") - if 'kid' in d: - kid = str(d["kid"]) - else: - kid = kid_bak - kid_bak += 1 + kid = d['kid'] if 'kid' in d else new_kid() + kids_seen.add(kid) cprintc(f"Key kid {kid}", "cyan") for k, v in d.items(): cprintc(f"[+] {k} = {v}", "green")