Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 210 additions & 2 deletions pycose/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def get_pad_func(cls, hash_cls):


class _Ecdsa(CoseAlgorithm, ABC):
""" Polymorphic ECDSA family. """
@classmethod
@abstractmethod
def get_curve(cls):
Expand Down Expand Up @@ -196,6 +197,46 @@ def verify(cls, key: 'EC2', data: bytes, signature: bytes) -> bool:
except BadSignatureError:
return False

class _Espdsa(CoseAlgorithm, ABC):
""" Fully-specified ECDSA family. """
@classmethod
@abstractmethod
def get_cose_curve_fullname(cls) -> str:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be moved into _Ecdsa and changed to return Optional[str] so that the polymorphic derived classes return default None.

That would avoid mostly duplicate definitions of sign() and verify().

raise NotImplementedError()

@classmethod
@abstractmethod
def get_curve(cls):
raise NotImplementedError()

@classmethod
@abstractmethod
def get_hash_func(cls):
raise NotImplementedError()

@classmethod
def sign(cls, key: 'EC2', data: bytes) -> bytes:
if key.crv.fullname != cls.get_cose_curve_fullname():
raise CoseException(f"Illegal curve for signing: {key.crv}")

sk = SigningKey.from_secret_exponent(int(hexlify(key.d), 16), curve=cls.get_curve())

return sk.sign_deterministic(data, hashfunc=cls.get_hash_func())

@classmethod
def verify(cls, key: 'EC2', data: bytes, signature: bytes) -> bool:
if key.crv.fullname != cls.get_cose_curve_fullname():
raise CoseException(f"Illegal curve for signing: {key.crv}")

p = Point(curve=cls.get_curve().curve, x=int(hexlify(key.x), 16), y=int(hexlify(key.y), 16))

vk = VerifyingKey.from_public_point(p, cls.get_curve(), cls.get_hash_func(), validate_point=True)

try:
return vk.verify(signature=signature, data=data, hashfunc=cls.get_hash_func())
except BadSignatureError:
return False


class _AesMac(CoseAlgorithm, ABC):
@classmethod
Expand Down Expand Up @@ -420,6 +461,105 @@ def get_hash_func(cls):
return SHA256


@CoseAlgorithm.register_attribute()
class Ed448(CoseAlgorithm):
"""
Ed448

Attributes:
identifier -53

fullname ED448
"""

identifier = -53
fullname = "ED448"

@classmethod
def sign(cls, key: 'OKP', data: bytes) -> bytes:
if key.crv.fullname == 'ED448':
sk = Ed448PrivateKey.from_private_bytes(key.d)
else:
raise CoseException(f"Illegal curve for signing: {key.crv}")

return sk.sign(data)

@classmethod
def verify(cls, key: 'OKP', data: bytes, signature: bytes) -> bool:
if key.crv.fullname == 'ED448':
vk = Ed448PublicKey.from_public_bytes(key.x)
else:
raise CoseException(f"Illegal curve for signing: {key.crv}")

try:
vk.verify(signature, data)
return True
except InvalidSignature:
return False


@CoseAlgorithm.register_attribute()
class Esp512(_Espdsa):
"""
ECDSA using P-521 curve and SHA-512

Attributes:
identifier -52

fullname ESP512

"""

identifier = -52
fullname = "ESP512"

@classmethod
def get_cose_curve_fullname(cls) -> str:
""" Returns a curve object used with this algorithm """
return 'P_521'

@classmethod
def get_hash_func(cls):
""" Returns a hash function used with this algorithm """
return sha512

@classmethod
def get_curve(cls) -> Curve:
""" Returns a curve object used with this algorithm """
return NIST521p


@CoseAlgorithm.register_attribute()
class Esp384(_Espdsa):
"""
ECDSA using P-384 curve and SHA-384

Attributes:
identifier -51

fullname ESP384

"""

identifier = -51
fullname = "ESP384"

@classmethod
def get_cose_curve_fullname(cls) -> str:
""" Returns a curve object used with this algorithm """
return 'P_384'

@classmethod
def get_hash_func(cls):
""" Returns a hash function used with this algorithm """
return sha384

@classmethod
def get_curve(cls) -> Curve:
""" Returns a curve object used with this algorithm """
return NIST384p


@CoseAlgorithm.register_attribute()
class Shake256(_HashAlg):
"""
Expand Down Expand Up @@ -860,6 +1000,43 @@ def get_key_wrap_func(cls):
return Direct()


@CoseAlgorithm.register_attribute()
class Ed25519(CoseAlgorithm):
"""
Ed25519

Attributes:
identifier -9

fullname ED25519
"""

identifier = -9
fullname = "ED25519"

@classmethod
def sign(cls, key: 'OKP', data: bytes) -> bytes:
if key.crv.fullname == 'ED25519':
sk = Ed25519PrivateKey.from_private_bytes(key.d)
else:
raise CoseException(f"Illegal curve for signing: {key.crv}")

return sk.sign(data)

@classmethod
def verify(cls, key: 'OKP', data: bytes, signature: bytes) -> bool:
if key.crv.fullname == 'ED25519':
vk = Ed25519PublicKey.from_public_bytes(key.x)
else:
raise CoseException(f"Illegal curve for signing: {key.crv}")

try:
vk.verify(signature, data)
return True
except InvalidSignature:
return False


@CoseAlgorithm.register_attribute()
class Shake128(_HashAlg):
"""
Expand Down Expand Up @@ -1001,6 +1178,37 @@ class DirectHKDFSHA256(CoseAlgorithm):
fullname = "DIRECT_HKDF_SHA_256"


@CoseAlgorithm.register_attribute()
class Esp256(_Espdsa):
"""
ECDSA using P-256 curve and SHA-256

Attributes:
identifier -9

fullname ESP256

"""

identifier = -9
fullname = "ESP256"

@classmethod
def get_cose_curve_fullname(cls) -> str:
""" Returns a curve object used with this algorithm """
return 'P_256'

@classmethod
def get_hash_func(cls):
""" Returns a hash function used with this algorithm """
return sha256

@classmethod
def get_curve(cls) -> Curve:
""" Returns a curve object used with this algorithm """
return NIST256p


@CoseAlgorithm.register_attribute()
class EdDSA(CoseAlgorithm):
"""
Expand All @@ -1022,7 +1230,7 @@ def sign(cls, key: 'OKP', data: bytes) -> bytes:
elif key.crv.fullname == 'ED448':
sk = Ed448PrivateKey.from_private_bytes(key.d)
else:
raise CoseException(f"Illegal curve for OKP singing: {key.crv}")
raise CoseException(f"Illegal curve for OKP signing: {key.crv}")

return sk.sign(data)

Expand All @@ -1033,7 +1241,7 @@ def verify(cls, key: 'OKP', data: bytes, signature: bytes) -> bool:
elif key.crv.fullname == 'ED448':
vk = Ed448PublicKey.from_public_bytes(key.x)
else:
raise CoseException(f"Illegal curve for OKP singing: {key.crv}")
raise CoseException(f"Illegal curve for OKP signing: {key.crv}")

try:
vk.verify(signature, data)
Expand Down
52 changes: 51 additions & 1 deletion tests/test_signmessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import cbor2
import pytest

from pycose.keys import EC2Key
from pycose.keys import EC2Key, OKPKey
from pycose.messages.signmessage import SignMessage
from pycose.messages.signer import CoseSignature
from pycose.exceptions import CoseException
Expand Down Expand Up @@ -114,3 +114,53 @@ def test_fail_on_missing_payload_verification():

with pytest.raises(CoseException, match="Missing payload"):
signer.verify_signature()


def test_ecdsa_allow_key_curve_mismatch():
key = EC2Key.generate_key(crv='P_256')

signer = CoseSignature(phdr={'ALG': 'ES384'})
signer.key = key

msg = SignMessage(phdr={}, signers=[signer])

payload = "signed message".encode('utf-8')
msg.encode(detached_payload=payload)


def test_espdsa_success():
key = EC2Key.generate_key(crv='P_384')

signer = CoseSignature(phdr={'ALG': 'ESP384'})
signer.key = key

msg = SignMessage(phdr={}, signers=[signer])

payload = "signed message".encode('utf-8')
msg.encode(detached_payload=payload)


def test_espdsa_fail_on_key_curve_mismatch():
key = EC2Key.generate_key(crv='P_256')

signer = CoseSignature(phdr={'ALG': 'ESP384'})
signer.key = key

msg = SignMessage(phdr={}, signers=[signer])

payload = "signed message".encode('utf-8')
with pytest.raises(CoseException, match="Illegal curve for signing: .*"):
msg.encode(detached_payload=payload)


def test_edpdsa_fail_on_key_curve_mismatch():
key = OKPKey.generate_key(crv='ED25519')

signer = CoseSignature(phdr={'ALG': 'ED448'})
signer.key = key

msg = SignMessage(phdr={}, signers=[signer])

payload = "signed message".encode('utf-8')
with pytest.raises(CoseException, match="Illegal curve for signing: .*"):
msg.encode(detached_payload=payload)
Loading