diff --git a/Cargo.lock b/Cargo.lock index cc95e2e8e3..135fa5a762 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1042,6 +1042,7 @@ dependencies = [ "hex", "openssl", "openssl_kdf", + "rsa", "symcrypt", "thiserror 2.0.16", "wchar", @@ -6167,9 +6168,9 @@ dependencies = [ [[package]] name = "pkcs8" -version = "0.11.0-rc.11" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12922b6296c06eb741b02d7b5161e3aaa22864af38dfa025a1a3ba3f68c84577" +checksum = "451913da69c775a56034ea8d9003d27ee8948e12443eae7c038ba100a4f21cb7" dependencies = [ "der", "spki", @@ -6539,9 +6540,9 @@ checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97" [[package]] name = "rsa" -version = "0.10.0-rc.17" +version = "0.10.0-rc.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ed3e93fc7e473e464b9726f4759659e72bc8665e4b8ea227547024f416d905" +checksum = "30b2aa4ba0d89f73d1e332df05be0eeab8840351c36ca5654341dfdb57bb3caf" dependencies = [ "const-oid", "crypto-bigint", diff --git a/openhcl/underhill_attestation/src/jwt.rs b/openhcl/underhill_attestation/src/jwt.rs index a1ac937cb5..640f4251af 100644 --- a/openhcl/underhill_attestation/src/jwt.rs +++ b/openhcl/underhill_attestation/src/jwt.rs @@ -39,7 +39,7 @@ pub(crate) enum JwtError { #[error("failed to validate certificate chain")] CertificateChainValidation(#[from] CertificateChainValidationError), #[error("failed to verify JWT signature")] - JwtSignatureVerification(#[from] JwtSignatureVerificationError), + JwtSignatureVerification(#[source] JwtSignatureVerificationError), } #[derive(Debug, Error)] @@ -167,11 +167,8 @@ impl JwtHelper { pub fn verify_signature(&self) -> Result { let alg = &self.jwt.header.alg; let pkey = validate_cert_chain(&self.cert_chain()?)?; - - let result = - verify_jwt_signature(alg, &pkey, self.payload.as_bytes(), &self.jwt.signature)?; - - Ok(result) + verify_jwt_signature(alg, &pkey, self.payload.as_bytes(), &self.jwt.signature) + .map_err(JwtError::JwtSignatureVerification) } } @@ -222,13 +219,11 @@ fn verify_jwt_signature( payload: &[u8], signature: &[u8], ) -> Result { - let result = match alg { + match alg { JwtAlgorithm::RS256 => pkey - .verify_pkcs1_sha256(payload, signature) - .map_err(JwtSignatureVerificationError::VerifySignature)?, - }; - - Ok(result) + .pkcs1_verify(payload, signature, crypto::rsa::HashAlgorithm::Sha256) + .map_err(JwtSignatureVerificationError::VerifySignature), + } } /// Helper function for x509 certificate chain validation. @@ -396,7 +391,9 @@ mod tests { let rsa_key = RsaKeyPair::generate(2048).unwrap(); let payload = "test"; - let signature = rsa_key.sign_pkcs1_sha256(payload.as_bytes()).unwrap(); + let signature = rsa_key + .pkcs1_sign(payload.as_bytes(), crypto::rsa::HashAlgorithm::Sha256) + .unwrap(); let cert = crate::test_helpers::generate_x509(&rsa_key); let public = cert.public_key().unwrap(); diff --git a/openhcl/underhill_attestation/src/key_protector.rs b/openhcl/underhill_attestation/src/key_protector.rs index 4984f4ec86..7111a685f4 100644 --- a/openhcl/underhill_attestation/src/key_protector.rs +++ b/openhcl/underhill_attestation/src/key_protector.rs @@ -4,7 +4,7 @@ //! Implementation of the key retrieval logic for the [`KeyProtector`]. use crate::Keys; -use crypto::rsa::OaepHashAlgorithm; +use crypto::rsa::HashAlgorithm; use crypto::rsa::RsaKeyPair; use cvm_tracing::CVM_ALLOWED; use cvm_tracing::CVM_CONFIDENTIAL; @@ -132,7 +132,7 @@ impl KeyProtectorExt for KeyProtector { } ingress_kek - .oaep_decrypt(&wrapped_des_key[..modulus_size], OaepHashAlgorithm::Sha256) + .oaep_decrypt(&wrapped_des_key[..modulus_size], HashAlgorithm::Sha256) .map_err(GetKeysFromKeyProtectorError::DesKeyRsaUnwrap)? } else { // The DEK buffer should contain an RSA-wrapped key. @@ -141,7 +141,7 @@ impl KeyProtectorExt for KeyProtector { ingress_kek .oaep_decrypt( &self.dek[ingress_idx].dek_buffer[..modulus_size], - OaepHashAlgorithm::Sha256, + HashAlgorithm::Sha256, ) .map_err(GetKeysFromKeyProtectorError::IngressDekRsaUnwrap)? }; @@ -217,7 +217,7 @@ impl KeyProtectorExt for KeyProtector { } else { // The DEK buffer should contain an RSA-wrapped key. ingress_kek - .oaep_decrypt(&dek_buffer[..modulus_size], OaepHashAlgorithm::Sha256) + .oaep_decrypt(&dek_buffer[..modulus_size], HashAlgorithm::Sha256) .map_err(GetKeysFromKeyProtectorError::EgressDekRsaUnwrap)? }; let mut key = [0u8; AES_GCM_KEY_LENGTH]; @@ -241,7 +241,7 @@ impl KeyProtectorExt for KeyProtector { } else { // Create an RSA wrapped key ingress_kek - .oaep_encrypt(&encrypt_egress_key, OaepHashAlgorithm::Sha256) + .oaep_encrypt(&encrypt_egress_key, HashAlgorithm::Sha256) .map_err(GetKeysFromKeyProtectorError::EgressKeyRsaWrap)? }; @@ -295,7 +295,7 @@ mod tests { let dek = generate_aes_256(); // Test DEK wrapped by the test RSA KEK - let result = kek.oaep_encrypt(&dek, OaepHashAlgorithm::Sha256); + let result = kek.oaep_encrypt(&dek, HashAlgorithm::Sha256); assert!(result.is_ok()); let rsa_wrapped_dek = result.unwrap(); @@ -346,7 +346,7 @@ mod tests { let result = kek.oaep_decrypt( &key_protector.dek[egress_index].dek_buffer[..kek.modulus_size()], - OaepHashAlgorithm::Sha256, + HashAlgorithm::Sha256, ); assert!(result.is_ok()); let plaintext = result.unwrap(); @@ -379,7 +379,7 @@ mod tests { let result = kek.oaep_decrypt( &key_protector.dek[egress_index].dek_buffer[..kek.modulus_size()], - OaepHashAlgorithm::Sha256, + HashAlgorithm::Sha256, ); assert!(result.is_ok()); let plaintext = result.unwrap(); @@ -402,7 +402,7 @@ mod tests { let aes_wrapped_dek = result.unwrap(); // Test DES key wrapped by the test RSA KEK - let result = kek.oaep_encrypt(&des, OaepHashAlgorithm::Sha256); + let result = kek.oaep_encrypt(&des, HashAlgorithm::Sha256); assert!(result.is_ok()); let rsa_wrapped_des = result.unwrap(); @@ -457,7 +457,7 @@ mod tests { false ); - let result = kek.oaep_decrypt(&rsa_wrapped_des, OaepHashAlgorithm::Sha256); + let result = kek.oaep_decrypt(&rsa_wrapped_des, HashAlgorithm::Sha256); assert!(result.is_ok()); let des_key = result.unwrap(); @@ -499,7 +499,7 @@ mod tests { false ); - let result = kek.oaep_decrypt(&rsa_wrapped_des, OaepHashAlgorithm::Sha256); + let result = kek.oaep_decrypt(&rsa_wrapped_des, HashAlgorithm::Sha256); assert!(result.is_ok()); let des_key = result.unwrap(); @@ -528,7 +528,7 @@ mod tests { let mut aes_wrapped_dek = result.unwrap(); // Test DES key wrapped by the test RSA KEK - let result = kek.oaep_encrypt(&des, OaepHashAlgorithm::Sha256); + let result = kek.oaep_encrypt(&des, HashAlgorithm::Sha256); assert!(result.is_ok()); let rsa_wrapped_des = result.unwrap(); diff --git a/openhcl/underhill_attestation/src/secure_key_release.rs b/openhcl/underhill_attestation/src/secure_key_release.rs index a04cb31eaa..f639499b8b 100644 --- a/openhcl/underhill_attestation/src/secure_key_release.rs +++ b/openhcl/underhill_attestation/src/secure_key_release.rs @@ -76,7 +76,7 @@ fn pkcs11_rsa_aes_key_unwrap( unwrapping_rsa_key: &RsaKeyPair, wrapped_key_blob: &[u8], ) -> Result { - use crypto::rsa::OaepHashAlgorithm; + use crypto::rsa::HashAlgorithm; let modulus_size = unwrapping_rsa_key.modulus_size(); @@ -92,7 +92,7 @@ fn pkcs11_rsa_aes_key_unwrap( } let unwrapped_aes_key = unwrapping_rsa_key - .oaep_decrypt(wrapped_aes_key, OaepHashAlgorithm::Sha1) + .oaep_decrypt(wrapped_aes_key, HashAlgorithm::Sha1) .map_err(Pkcs11RsaAesKeyUnwrapError::RsaUnwrap)?; let unwrapped_rsa_key = crypto::aes_key_wrap::AesKeyWrap::new(&unwrapped_aes_key) .and_then(|kw| kw.unwrapper()?.unwrap(wrapped_rsa_key)) @@ -385,7 +385,6 @@ async fn make_igvm_attest_requests( #[cfg(test)] mod tests { use super::*; - use crypto::rsa::OaepHashAlgorithm; #[test] fn fail_to_unwrap_pkcs11_rsa_aes_with_undersized_wrapped_key_blob() { @@ -420,7 +419,7 @@ mod tests { let wrapping_rsa_key = RsaKeyPair::generate(2048).unwrap(); let wrapped_aes_key = wrapping_rsa_key - .oaep_encrypt(&wrapping_aes_key, OaepHashAlgorithm::Sha1) + .oaep_encrypt(&wrapping_aes_key, crypto::rsa::HashAlgorithm::Sha1) .unwrap(); let wrapped_target_key = crypto::aes_key_wrap::AesKeyWrap::new(&wrapping_aes_key) .unwrap() @@ -432,8 +431,8 @@ mod tests { let unwrapped_target_key = pkcs11_rsa_aes_key_unwrap(&wrapping_rsa_key, wrapped_key_blob.as_slice()).unwrap(); assert_eq!( - unwrapped_target_key.to_private_key_der().unwrap(), - target_key.to_private_key_der().unwrap() + unwrapped_target_key.to_pkcs8_der().unwrap(), + target_key.to_pkcs8_der().unwrap() ); } } diff --git a/openhcl/underhill_attestation/src/test_helpers.rs b/openhcl/underhill_attestation/src/test_helpers.rs index 19c94658e0..3bd2fc1506 100644 --- a/openhcl/underhill_attestation/src/test_helpers.rs +++ b/openhcl/underhill_attestation/src/test_helpers.rs @@ -70,7 +70,9 @@ pub fn generate_base64_encoded_jwt_components(key_pair: &RsaKeyPair) -> (String, // The signature is generated by signing the concatenation of base64_header and base64_body let message = format!("{}.{}", base64_header, base64_body); - let signature = key_pair.sign_pkcs1_sha256(message.as_bytes()).unwrap(); + let signature = key_pair + .pkcs1_sign(message.as_bytes(), crypto::rsa::HashAlgorithm::Sha256) + .unwrap(); let base64_signature = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&signature); (base64_header, base64_body, base64_signature) diff --git a/support/crypto/Cargo.toml b/support/crypto/Cargo.toml index 8af7f3a37d..37d2bc85e2 100644 --- a/support/crypto/Cargo.toml +++ b/support/crypto/Cargo.toml @@ -18,19 +18,23 @@ openssl = ["dep:openssl", "dep:openssl_kdf"] # Use Symcrypt instead of any native backend. # Note that the symcrypt backend does not support vendoring, enabling both # `symcrypt` and `vendored` will trigger a compile error -symcrypt = ["dep:symcrypt", "symcrypt/sha1"] +symcrypt = ["dep:symcrypt", "symcrypt/sha1", "dep:rsa"] [dependencies] openssl_kdf = { workspace = true, optional = true } openssl = { workspace = true, optional = true } symcrypt = { workspace = true, optional = true } +rsa = { workspace = true, optional = true, features = ["std", "encoding"] } thiserror.workspace = true [target.'cfg(windows)'.dependencies] wchar.workspace = true -windows = { workspace = true, features = ["Win32_Foundation","Win32_Security_Cryptography"] } +windows = { workspace = true, features = [ + "Win32_Foundation", + "Win32_Security_Cryptography", +] } windows-result.workspace = true zerocopy.workspace = true diff --git a/support/crypto/src/aes_256_cbc/symcrypt.rs b/support/crypto/src/aes_256_cbc/symcrypt.rs index d621b6fff4..68c2cc5ef7 100644 --- a/support/crypto/src/aes_256_cbc/symcrypt.rs +++ b/support/crypto/src/aes_256_cbc/symcrypt.rs @@ -17,7 +17,7 @@ pub struct Aes256CbcDecCtxInner<'a> { } fn err(e: ::symcrypt::errors::SymCryptError, op: &'static str) -> Aes256CbcError { - Aes256CbcError(crate::BackendError(e, op)) + Aes256CbcError(crate::BackendError::SymCryptError(e, op)) } impl Aes256CbcInner { diff --git a/support/crypto/src/aes_256_gcm/symcrypt.rs b/support/crypto/src/aes_256_gcm/symcrypt.rs index 537c59585f..466044f323 100644 --- a/support/crypto/src/aes_256_gcm/symcrypt.rs +++ b/support/crypto/src/aes_256_gcm/symcrypt.rs @@ -17,7 +17,7 @@ pub struct Aes256GcmDecCtxInner<'a> { } fn err(e: ::symcrypt::errors::SymCryptError, op: &'static str) -> Aes256GcmError { - Aes256GcmError(crate::BackendError(e, op)) + Aes256GcmError(crate::BackendError::SymCryptError(e, op)) } impl Aes256GcmInner { diff --git a/support/crypto/src/hmac_sha_256/symcrypt.rs b/support/crypto/src/hmac_sha_256/symcrypt.rs index 0a8ac54cfb..1479e28ddc 100644 --- a/support/crypto/src/hmac_sha_256/symcrypt.rs +++ b/support/crypto/src/hmac_sha_256/symcrypt.rs @@ -6,7 +6,7 @@ use ::symcrypt::errors::SymCryptError; use ::symcrypt::hmac::hmac_sha256; fn err(e: SymCryptError, op: &'static str) -> HmacSha256Error { - HmacSha256Error(crate::BackendError(e, op)) + HmacSha256Error(crate::BackendError::SymCryptError(e, op)) } pub fn hmac_sha_256(key: &[u8], data: &[u8]) -> Result<[u8; 32], HmacSha256Error> { diff --git a/support/crypto/src/lib.rs b/support/crypto/src/lib.rs index f202fcfe58..67974e7353 100644 --- a/support/crypto/src/lib.rs +++ b/support/crypto/src/lib.rs @@ -45,8 +45,13 @@ pub struct BackendError(#[source] windows_result::Error, &'static str); /// operation being performed when the error occurred. #[cfg(symcrypt)] #[derive(Clone, Debug, thiserror::Error)] -#[error("symcrypt error during {1}")] -pub struct BackendError(#[source] symcrypt::errors::SymCryptError, &'static str); +#[error("symcrypt backend error during {1}")] +pub enum BackendError { + /// An error from the SymCrypt library, with the operation being performed when the error occurred. + SymCryptError(#[source] symcrypt::errors::SymCryptError, &'static str), + /// An error from encoding or decoding PKCS#8, with the operation being performed when the error occurred. + Pkcs8EncodingError(#[source] ::rsa::pkcs8::Error, &'static str), +} #[cfg(all(native, target_os = "macos"))] pub use mac::BackendError; diff --git a/support/crypto/src/pkcs7/ossl.rs b/support/crypto/src/pkcs7/ossl.rs index e6d161794b..7fd9d2f963 100644 --- a/support/crypto/src/pkcs7/ossl.rs +++ b/support/crypto/src/pkcs7/ossl.rs @@ -45,13 +45,11 @@ impl Pkcs7SignedDataInner { key_pair: &crate::rsa::RsaKeyPair, data: &[u8], ) -> Result { - let pkey = openssl::pkey::PKey::from_rsa(key_pair.0.rsa.clone()) - .map_err(|e| err(e, "converting RSA key for pkcs7 signing"))?; let certs = openssl::stack::Stack::new().map_err(|e| err(e, "creating empty certificate stack"))?; let pkcs7 = openssl::pkcs7::Pkcs7::sign( &cert.0.cert, - &pkey, + &key_pair.0.0, &certs, data, openssl::pkcs7::Pkcs7Flags::empty(), diff --git a/support/crypto/src/rsa/mod.rs b/support/crypto/src/rsa/mod.rs index 76bd91cb39..4abc5d52da 100644 --- a/support/crypto/src/rsa/mod.rs +++ b/support/crypto/src/rsa/mod.rs @@ -3,13 +3,18 @@ //! RSA cryptographic operations. -#![cfg(openssl)] +#![cfg(any(openssl, symcrypt))] #[cfg(openssl)] pub(crate) mod ossl; #[cfg(openssl)] use ossl as sys; +#[cfg(symcrypt)] +pub(crate) mod symcrypt; +#[cfg(symcrypt)] +use symcrypt as sys; + use thiserror::Error; /// An error for RSA operations. @@ -17,9 +22,9 @@ use thiserror::Error; #[error("RSA error")] pub struct RsaError(#[source] super::BackendError); -/// Hash algorithm for RSA-OAEP encryption/decryption. +/// Hash algorithm for RSA operations. #[derive(Debug, Clone, Copy)] -pub enum OaepHashAlgorithm { +pub enum HashAlgorithm { /// SHA-1 Sha1, /// SHA-256 @@ -27,6 +32,7 @@ pub enum OaepHashAlgorithm { } /// An RSA private key (key pair). +#[repr(transparent)] // Needed for the transmute in deref. pub struct RsaKeyPair(pub(crate) sys::RsaKeyPairInner); impl RsaKeyPair { @@ -40,61 +46,76 @@ impl RsaKeyPair { sys::RsaKeyPairInner::from_pkcs8_der(der).map(Self) } - /// Returns the size of the RSA modulus in bytes. - pub fn modulus_size(&self) -> usize { - self.0.modulus_size() + /// Convert the RSA private key to PKCS#8 DER-encoded bytes. + pub fn to_pkcs8_der(&self) -> Result, RsaError> { + self.0.to_pkcs8_der() } - /// Returns the RSA modulus as a big-endian byte vector. - pub fn modulus(&self) -> Vec { - self.0.modulus() + /// Decrypt `input` using RSA-OAEP with the specified hash algorithm. + pub fn oaep_decrypt( + &self, + input: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result, RsaError> { + self.0.oaep_decrypt(input, hash_algorithm) } - /// Returns the RSA public exponent as a big-endian byte vector. - pub fn public_exponent(&self) -> Vec { - self.0.public_exponent() + /// Sign `data` using RSA PKCS#1 v1.5 with the specified hash algorithm. + pub fn pkcs1_sign( + &self, + data: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result, RsaError> { + self.0.pkcs1_sign(data, hash_algorithm) } +} + +/// An RSA public key. +#[repr(transparent)] // Needed for the transmute in deref. +pub struct RsaPublicKey(pub(crate) sys::RsaPublicKeyInner); +impl RsaPublicKey { /// Encrypt `input` using RSA-OAEP with the specified hash algorithm. pub fn oaep_encrypt( &self, input: &[u8], - hash_algorithm: OaepHashAlgorithm, + hash_algorithm: HashAlgorithm, ) -> Result, RsaError> { self.0.oaep_encrypt(input, hash_algorithm) } - /// Decrypt `input` using RSA-OAEP with the specified hash algorithm. - pub fn oaep_decrypt( + /// Verify an RSA PKCS#1 v1.5 signature with the specified hash algorithm. Returns `Ok(true)` if the signature is valid. + /// Different backends may return Ok(false) or an error if the signature is invalid, but all return an error for other failures. + pub fn pkcs1_verify( &self, - input: &[u8], - hash_algorithm: OaepHashAlgorithm, - ) -> Result, RsaError> { - self.0.oaep_decrypt(input, hash_algorithm) + message: &[u8], + signature: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result { + self.0.pkcs1_verify(message, signature, hash_algorithm) } - /// Export the private key in PKCS#8 DER format. - pub fn to_pkcs8_der(&self) -> Result, RsaError> { - self.0.to_pkcs8_der() + /// Returns the size of the RSA modulus in bytes. + pub fn modulus_size(&self) -> usize { + self.0.modulus_size() } - /// Export the private key in traditional RSA DER format. - pub fn to_private_key_der(&self) -> Result, RsaError> { - self.0.to_private_key_der() + /// Returns the RSA modulus as a big-endian byte vector. + pub fn modulus(&self) -> Vec { + self.0.modulus() } - /// Sign `data` using RSA PKCS#1 v1.5 with SHA-256. - pub fn sign_pkcs1_sha256(&self, data: &[u8]) -> Result, RsaError> { - self.0.sign_pkcs1_sha256(data) + /// Returns the RSA public exponent as a big-endian byte vector. + pub fn public_exponent(&self) -> Vec { + self.0.public_exponent() } } -/// An RSA public key. -pub struct RsaPublicKey(pub(crate) sys::RsaPublicKeyInner); +impl std::ops::Deref for RsaKeyPair { + type Target = RsaPublicKey; -impl RsaPublicKey { - /// Verify an RSA PKCS#1 v1.5 signature with SHA-256. - pub fn verify_pkcs1_sha256(&self, message: &[u8], signature: &[u8]) -> Result { - self.0.verify_pkcs1_sha256(message, signature) + fn deref(&self) -> &Self::Target { + // SAFETY: RsaPublicKey is just a wrapper around RsaPublicKeyInner. + unsafe { std::mem::transmute::<&sys::RsaPublicKeyInner, &RsaPublicKey>(self.0.as_pub()) } } } diff --git a/support/crypto/src/rsa/ossl.rs b/support/crypto/src/rsa/ossl.rs index 5fdd81c4da..4598f80cd0 100644 --- a/support/crypto/src/rsa/ossl.rs +++ b/support/crypto/src/rsa/ossl.rs @@ -1,115 +1,62 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use super::OaepHashAlgorithm; +use super::HashAlgorithm; use super::RsaError; fn err(err: openssl::error::ErrorStack, op: &'static str) -> RsaError { RsaError(crate::BackendError(err, op)) } -pub struct RsaKeyPairInner { - pub(crate) rsa: openssl::rsa::Rsa, -} +#[repr(transparent)] // Needed for the transmute in as_pub. +pub struct RsaKeyPairInner(pub(crate) openssl::pkey::PKey); impl RsaKeyPairInner { pub fn generate(bits: u32) -> Result { let rsa = openssl::rsa::Rsa::generate(bits).map_err(|e| err(e, "generating RSA key"))?; - Ok(Self { rsa }) + let pkey = + openssl::pkey::PKey::from_rsa(rsa).map_err(|e| err(e, "converting RSA to PKey"))?; + Ok(Self(pkey)) } pub fn from_pkcs8_der(der: &[u8]) -> Result { let pkey = openssl::pkey::PKey::private_key_from_pkcs8(der) .map_err(|e| err(e, "parsing PKCS#8 DER"))?; - let rsa = pkey - .rsa() - .map_err(|e| err(e, "extracting RSA key from PKey"))?; - Ok(Self { rsa }) - } - - pub fn modulus_size(&self) -> usize { - self.rsa.size() as usize - } - - pub fn modulus(&self) -> Vec { - self.rsa.n().to_vec() - } - - pub fn public_exponent(&self) -> Vec { - self.rsa.e().to_vec() + // Ensure the key is actually an RSA key. + pkey.rsa().map_err(|e| err(e, "checking key is RSA"))?; + Ok(Self(pkey)) } - pub fn oaep_encrypt( - &self, - input: &[u8], - hash_algorithm: OaepHashAlgorithm, - ) -> Result, RsaError> { - let pkey = openssl::pkey::PKey::from_rsa(self.rsa.clone()) - .map_err(|e| err(e, "converting RSA to PKey"))?; - let mut ctx = - openssl::pkey_ctx::PkeyCtx::new(&pkey).map_err(|e| err(e, "creating PkeyCtx"))?; - - ctx.encrypt_init().map_err(|e| err(e, "encrypt init"))?; - ctx.set_rsa_padding(openssl::rsa::Padding::PKCS1_OAEP) - .map_err(|e| err(e, "setting RSA padding"))?; - - match hash_algorithm { - OaepHashAlgorithm::Sha1 => ctx.set_rsa_oaep_md(openssl::md::Md::sha1()), - OaepHashAlgorithm::Sha256 => ctx.set_rsa_oaep_md(openssl::md::Md::sha256()), - } - .map_err(|e| err(e, "setting OAEP hash"))?; - - let mut output = vec![]; - ctx.encrypt_to_vec(input, &mut output) - .map_err(|e| err(e, "RSA-OAEP encrypt"))?; - - Ok(output) + pub fn to_pkcs8_der(&self) -> Result, RsaError> { + self.0 + .private_key_to_pkcs8() + .map_err(|e| err(e, "exporting PKCS#8 DER")) } pub fn oaep_decrypt( &self, input: &[u8], - hash_algorithm: OaepHashAlgorithm, + hash_algorithm: HashAlgorithm, ) -> Result, RsaError> { - let pkey = openssl::pkey::PKey::from_rsa(self.rsa.clone()) - .map_err(|e| err(e, "converting RSA to PKey"))?; let mut ctx = - openssl::pkey_ctx::PkeyCtx::new(&pkey).map_err(|e| err(e, "creating PkeyCtx"))?; - + openssl::pkey_ctx::PkeyCtx::new(&self.0).map_err(|e| err(e, "creating PkeyCtx"))?; ctx.decrypt_init().map_err(|e| err(e, "decrypt init"))?; ctx.set_rsa_padding(openssl::rsa::Padding::PKCS1_OAEP) .map_err(|e| err(e, "setting RSA padding"))?; - - match hash_algorithm { - OaepHashAlgorithm::Sha1 => ctx.set_rsa_oaep_md(openssl::md::Md::sha1()), - OaepHashAlgorithm::Sha256 => ctx.set_rsa_oaep_md(openssl::md::Md::sha256()), - } - .map_err(|e| err(e, "setting OAEP hash"))?; - + ctx.set_rsa_oaep_md(conv_mdref(hash_algorithm)) + .map_err(|e| err(e, "setting OAEP hash"))?; let mut output = vec![]; ctx.decrypt_to_vec(input, &mut output) .map_err(|e| err(e, "RSA-OAEP decrypt"))?; - Ok(output) } - pub fn to_pkcs8_der(&self) -> Result, RsaError> { - let pkey = openssl::pkey::PKey::from_rsa(self.rsa.clone()) - .map_err(|e| err(e, "converting RSA to PKey"))?; - pkey.private_key_to_pkcs8() - .map_err(|e| err(e, "exporting PKCS#8 DER")) - } - - pub fn to_private_key_der(&self) -> Result, RsaError> { - self.rsa - .private_key_to_der() - .map_err(|e| err(e, "exporting private key DER")) - } - - pub fn sign_pkcs1_sha256(&self, data: &[u8]) -> Result, RsaError> { - let pkey = openssl::pkey::PKey::from_rsa(self.rsa.clone()) - .map_err(|e| err(e, "converting RSA key to PKey for signing"))?; - let mut signer = openssl::sign::Signer::new(openssl::hash::MessageDigest::sha256(), &pkey) + pub fn pkcs1_sign( + &self, + data: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result, RsaError> { + let mut signer = openssl::sign::Signer::new(conv_md(hash_algorithm), &self.0) .map_err(|e| err(e, "creating signer"))?; signer .set_rsa_padding(openssl::rsa::Padding::PKCS1) @@ -117,21 +64,43 @@ impl RsaKeyPairInner { signer.update(data).map_err(|e| err(e, "signer update"))?; signer.sign_to_vec().map_err(|e| err(e, "signer sign")) } -} -pub struct RsaPublicKeyInner { - pub(crate) pkey: openssl::pkey::PKey, + pub(crate) fn as_pub(&self) -> &RsaPublicKeyInner { + // SAFETY: PKey can be safely treated as PKey for read-only operations. + unsafe { std::mem::transmute::<&RsaKeyPairInner, &RsaPublicKeyInner>(self) } + } } +#[repr(transparent)] // Needed for the transmute in as_pub. +pub struct RsaPublicKeyInner(pub(crate) openssl::pkey::PKey); + impl RsaPublicKeyInner { - pub fn from_pkey(pkey: openssl::pkey::PKey) -> Self { - Self { pkey } + pub fn oaep_encrypt( + &self, + input: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result, RsaError> { + let mut ctx = + openssl::pkey_ctx::PkeyCtx::new(&self.0).map_err(|e| err(e, "creating PkeyCtx"))?; + ctx.encrypt_init().map_err(|e| err(e, "encrypt init"))?; + ctx.set_rsa_padding(openssl::rsa::Padding::PKCS1_OAEP) + .map_err(|e| err(e, "setting RSA padding"))?; + ctx.set_rsa_oaep_md(conv_mdref(hash_algorithm)) + .map_err(|e| err(e, "setting OAEP hash"))?; + let mut output = vec![]; + ctx.encrypt_to_vec(input, &mut output) + .map_err(|e| err(e, "RSA-OAEP encrypt"))?; + Ok(output) } - pub fn verify_pkcs1_sha256(&self, message: &[u8], signature: &[u8]) -> Result { - let mut verifier = - openssl::sign::Verifier::new(openssl::hash::MessageDigest::sha256(), &self.pkey) - .map_err(|e| err(e, "creating verifier"))?; + pub fn pkcs1_verify( + &self, + message: &[u8], + signature: &[u8], + hash_algorithm: HashAlgorithm, + ) -> Result { + let mut verifier = openssl::sign::Verifier::new(conv_md(hash_algorithm), &self.0) + .map_err(|e| err(e, "creating verifier"))?; verifier .set_rsa_padding(openssl::rsa::Padding::PKCS1) .map_err(|e| err(e, "setting RSA padding"))?; @@ -142,4 +111,33 @@ impl RsaPublicKeyInner { .verify(signature) .map_err(|e| err(e, "verifier verify")) } + + pub fn modulus_size(&self) -> usize { + // TODO: This should use EVP_PKEY_get_params but the openssl crate doesn't expose it + self.0.rsa().unwrap().size() as usize + } + + pub fn modulus(&self) -> Vec { + // TODO: This should use EVP_PKEY_get_params but the openssl crate doesn't expose it + self.0.rsa().unwrap().n().to_vec() + } + + pub fn public_exponent(&self) -> Vec { + // TODO: This should use EVP_PKEY_get_params but the openssl crate doesn't expose it + self.0.rsa().unwrap().e().to_vec() + } +} + +fn conv_md(hash_algorithm: HashAlgorithm) -> openssl::hash::MessageDigest { + match hash_algorithm { + HashAlgorithm::Sha1 => openssl::hash::MessageDigest::sha1(), + HashAlgorithm::Sha256 => openssl::hash::MessageDigest::sha256(), + } +} + +fn conv_mdref(hash_algorithm: HashAlgorithm) -> &'static openssl::md::MdRef { + match hash_algorithm { + HashAlgorithm::Sha1 => openssl::md::Md::sha1(), + HashAlgorithm::Sha256 => openssl::md::Md::sha256(), + } } diff --git a/support/crypto/src/rsa/symcrypt.rs b/support/crypto/src/rsa/symcrypt.rs new file mode 100644 index 0000000000..92073e1d46 --- /dev/null +++ b/support/crypto/src/rsa/symcrypt.rs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use super::RsaError; +use rsa::pkcs8::DecodePrivateKey; +use rsa::pkcs8::EncodePrivateKey; +use rsa::traits::PrivateKeyParts; +use rsa::traits::PublicKeyParts; +use symcrypt::rsa::RsaKey; +use symcrypt::rsa::RsaKeyUsage; + +fn err(err: symcrypt::errors::SymCryptError, op: &'static str) -> RsaError { + RsaError(crate::BackendError::SymCryptError(err, op)) +} + +fn pkcs8_err(err: rsa::pkcs8::Error, op: &'static str) -> RsaError { + RsaError(crate::BackendError::Pkcs8EncodingError(err, op)) +} + +#[repr(transparent)] // Needed for the transmute in as_pub. +pub struct RsaKeyPairInner(RsaKey); + +impl RsaKeyPairInner { + pub fn generate(bits: u32) -> Result { + let rsa = RsaKey::generate_key_pair(bits, None, RsaKeyUsage::SignAndEncrypt) + .map_err(|e| err(e, "generating RSA key"))?; + Ok(Self(rsa)) + } + + pub fn from_pkcs8_der(der: &[u8]) -> Result { + let parsed = rsa::RsaPrivateKey::from_pkcs8_der(der) + .map_err(|e| pkcs8_err(e, "parsing PKCS#8 DER"))?; + let primes = parsed.primes(); + if primes.len() != 2 { + return Err(RsaError(crate::BackendError::Pkcs8EncodingError( + rsa::pkcs8::Error::KeyMalformed(rsa::pkcs8::KeyError::Invalid), + "multiprime RSA keys not supported", + ))); + } + let rsa = RsaKey::set_key_pair( + &parsed.n().to_be_bytes_trimmed_vartime(), + &parsed.e().to_be_bytes_trimmed_vartime(), + &primes[0].to_be_bytes_trimmed_vartime(), + &primes[1].to_be_bytes_trimmed_vartime(), + RsaKeyUsage::SignAndEncrypt, + ) + .map_err(|e| err(e, "setting RSA key pair"))?; + Ok(Self(rsa)) + } + + pub fn to_pkcs8_der(&self) -> Result, RsaError> { + let blob = self + .0 + .export_key_pair_blob() + .map_err(|e| err(e, "exporting RSA key blob"))?; + let rsa = rsa::RsaPrivateKey::from_components( + rsa::BoxedUint::from_be_slice_vartime(&blob.modulus), + rsa::BoxedUint::from_be_slice_vartime(&blob.pub_exp), + rsa::BoxedUint::from_be_slice_vartime(&blob.private_exp), + vec![ + rsa::BoxedUint::from_be_slice_vartime(&blob.p), + rsa::BoxedUint::from_be_slice_vartime(&blob.q), + ], + ) + .unwrap(); + Ok(rsa + .to_pkcs8_der() + .map_err(|e| pkcs8_err(e, "converting to DER"))? + .as_bytes() + .to_vec()) + } + + pub fn oaep_decrypt( + &self, + input: &[u8], + hash_algorithm: super::HashAlgorithm, + ) -> Result, RsaError> { + self.0 + .oaep_decrypt(input, conv_hash(hash_algorithm), &[]) + .map_err(|e| err(e, "OAEP decryption")) + } + + pub fn pkcs1_sign( + &self, + data: &[u8], + hash_algorithm: super::HashAlgorithm, + ) -> Result, RsaError> { + self.0 + .pkcs1_sign(data, conv_hash(hash_algorithm)) + .map_err(|e| err(e, "PKCS#1 signing")) + } + + pub(crate) fn as_pub(&self) -> &RsaPublicKeyInner { + // SAFETY: RsaPublicKeyInner is just a wrapper around the same RsaKey. + unsafe { std::mem::transmute::<&RsaKeyPairInner, &RsaPublicKeyInner>(self) } + } +} + +#[repr(transparent)] // Needed for the transmute in as_pub. +pub struct RsaPublicKeyInner(RsaKey); + +impl RsaPublicKeyInner { + pub fn oaep_encrypt( + &self, + input: &[u8], + hash_algorithm: super::HashAlgorithm, + ) -> Result, RsaError> { + self.0 + .oaep_encrypt(input, conv_hash(hash_algorithm), &[]) + .map_err(|e| err(e, "OAEP encryption")) + } + + pub fn pkcs1_verify( + &self, + data: &[u8], + signature: &[u8], + hash_algorithm: super::HashAlgorithm, + ) -> Result { + self.0 + .pkcs1_verify(data, signature, conv_hash(hash_algorithm)) + .map_err(|e| err(e, "PKCS#1 signature verification"))?; + Ok(true) + } + + pub fn modulus_size(&self) -> usize { + self.0.get_size_of_modulus() as usize + } + + pub fn modulus(&self) -> Vec { + // TODO: Maybe cache the pub blob? + self.0.export_public_key_blob().unwrap().modulus + } + + pub fn public_exponent(&self) -> Vec { + // TODO: Maybe cache the pub blob? + self.0.export_public_key_blob().unwrap().pub_exp + } +} + +fn conv_hash(hash_algorithm: super::HashAlgorithm) -> symcrypt::hash::HashAlgorithm { + match hash_algorithm { + super::HashAlgorithm::Sha1 => symcrypt::hash::HashAlgorithm::Sha1, + super::HashAlgorithm::Sha256 => symcrypt::hash::HashAlgorithm::Sha256, + } +} diff --git a/support/crypto/src/x509/ossl.rs b/support/crypto/src/x509/ossl.rs index cbae1ede13..cd3ed08479 100644 --- a/support/crypto/src/x509/ossl.rs +++ b/support/crypto/src/x509/ossl.rs @@ -25,13 +25,13 @@ impl X509CertificateInner { .map_err(|e| err(e, "extracting public key"))?; assert_eq!(pkey.id(), openssl::pkey::Id::RSA); Ok(crate::rsa::RsaPublicKey( - crate::rsa::ossl::RsaPublicKeyInner::from_pkey(pkey), + crate::rsa::ossl::RsaPublicKeyInner(pkey), )) } pub fn verify(&self, issuer_public_key: &crate::rsa::RsaPublicKey) -> Result { self.cert - .verify(&issuer_public_key.0.pkey) + .verify(&issuer_public_key.0.0) .map_err(|e| err(e, "verifying certificate signature")) } @@ -72,10 +72,8 @@ impl X509BuilderInner { &mut self, key_pair: &crate::rsa::RsaKeyPair, ) -> Result<(), X509Error> { - let pkey = openssl::pkey::PKey::from_rsa(key_pair.0.rsa.clone()) - .map_err(|e| err(e, "converting RSA key to PKey"))?; self.builder - .set_pubkey(&pkey) + .set_pubkey(&key_pair.0.0) .map_err(|e| err(e, "setting public key")) } @@ -125,10 +123,8 @@ impl X509BuilderInner { mut self, key_pair: &crate::rsa::RsaKeyPair, ) -> Result { - let pkey = openssl::pkey::PKey::from_rsa(key_pair.0.rsa.clone()) - .map_err(|e| err(e, "converting RSA key for signing"))?; self.builder - .sign(&pkey, openssl::hash::MessageDigest::sha256()) + .sign(&key_pair.0.0, openssl::hash::MessageDigest::sha256()) .map_err(|e| err(e, "signing certificate"))?; Ok(X509CertificateInner { cert: self.builder.build(),