Skip to content

Commit 80d7956

Browse files
committed
Refactor certificate handling and add key password support
Updated file extension conventions for PEM files and improved attribute handling with stricter checks. Added support for encrypted private keys and clarified SSL operations, ensuring better consistency and maintainability.
1 parent 82e79c9 commit 80d7956

2 files changed

Lines changed: 49 additions & 32 deletions

File tree

chaski/utils/certificate_authority.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import ssl
33
import datetime
4-
from ipaddress import ip_address
4+
from ipaddress import ip_address as _ip_address
55
from typing import Literal, Optional
66

77
# Importing cryptography modules for X509 certificates, private key generation,
@@ -24,9 +24,10 @@ def __init__(
2424
id: str,
2525
ip_address: str,
2626
ssl_certificates_location: str = None,
27-
ssl_certificate_attributes: dict = {},
27+
ssl_certificate_attributes: dict = None,
2828
key_password: bytes | None = None,
2929
end_entity_key_size: int = 4096,
30+
ca_key_size: int = 4096,
3031
):
3132
"""
3233
Initialize the Certificate Authority (CA).
@@ -46,10 +47,11 @@ def __init__(
4647
self.ssl_certificates_location = ssl_certificates_location
4748

4849
os.makedirs(self.ssl_certificates_location, exist_ok=True)
49-
self.ssl_certificate_attributes = ssl_certificate_attributes
50-
self.ip_address = ip_address
50+
self.ssl_certificate_attributes = ssl_certificate_attributes or {}
51+
self._ip = _ip_address(ip_address)
5152
self.key_password = key_password
5253
self.end_entity_key_size = end_entity_key_size
54+
self.ca_key_size = ca_key_size
5355

5456
def setup_certificate_authority(self) -> None:
5557
"""
@@ -69,12 +71,14 @@ def setup_certificate_authority(self) -> None:
6971
IOError
7072
If there is an error reading or writing the key and certificate files.
7173
"""
72-
self.ca_key_path_ = os.path.join(self.ssl_certificates_location, "ca.key")
73-
self.ca_cert_path_ = os.path.join(self.ssl_certificates_location, "ca.cert")
74+
self._require_attrs()
75+
76+
self.ca_key_path_ = os.path.join(self.ssl_certificates_location, "ca.key.pem")
77+
self.ca_cert_path_ = os.path.join(self.ssl_certificates_location, "ca.crt.pem")
7478

7579
# Generate CA key
7680
ca_key = rsa.generate_private_key(
77-
public_exponent=65537, key_size=self.end_entity_key_size
81+
public_exponent=65537, key_size=self.ca_key_size
7882
)
7983

8084
# Write CA key to file
@@ -274,7 +278,8 @@ def sign_csr(
274278

275279
# Load CA key
276280
ca_key = serialization.load_pem_private_key(
277-
self.load_certificate(self.ca_private_key_path), password=None
281+
self.load_certificate(self.ca_private_key_path),
282+
password=self.key_password,
278283
)
279284

280285
# Load CA certificate
@@ -352,9 +357,7 @@ def sign_csr(
352357

353358
if role == "server":
354359
builder = builder.add_extension(
355-
x509.SubjectAlternativeName(
356-
[x509.IPAddress(ip_address(self.ip_address))]
357-
),
360+
x509.SubjectAlternativeName([x509.IPAddress(self._ip)]),
358361
critical=False,
359362
)
360363

@@ -387,12 +390,14 @@ def _key_and_csr(self, name="client") -> tuple:
387390
IOError
388391
If there is an error writing the private key or CSR files to the filesystem.
389392
"""
393+
self._require_attrs()
394+
390395
private_key_path_ = os.path.join(
391-
self.ssl_certificates_location, f"{name}_{self.id}.key"
396+
self.ssl_certificates_location, f"{name}_{self.id}.key.pem"
392397
)
393398

394399
certificate_path_ = os.path.join(
395-
self.ssl_certificates_location, f"{name}_{self.id}.csr"
400+
self.ssl_certificates_location, f"{name}_{self.id}.csr.pem"
396401
)
397402

398403
key = rsa.generate_private_key(
@@ -572,25 +577,21 @@ def certificate_paths(self) -> dict:
572577
@property
573578
def certificate_signed_paths(self) -> dict:
574579
"""
575-
Provide the path to the signed certificate.
580+
Get the paths to the signed certificates for both client and server.
576581
577-
This property retrieves the file path to the signed certificate
578-
by replacing the '.csr.pem' suffix of the CSR path with
579-
'.sign.csr.pem'.
582+
This property returns a dictionary containing the file paths to the signed
583+
certificates for both the client and server. The paths are derived from
584+
the CSR paths by replacing the '.csr' extension with '.cert'.
580585
581586
Returns
582587
-------
583-
str
584-
The file path to the signed certificate.
585-
586-
Raises
587-
------
588-
Exception
589-
If the CSR path is not set.
588+
dict
589+
A dictionary with 'client' and 'server' keys mapping to their respective
590+
signed certificate file paths.
590591
"""
591592
return {
592-
"client": self.certificate_paths["client"].replace(".csr", ".cert"),
593-
"server": self.certificate_paths["server"].replace(".csr", ".cert"),
593+
"client": self.certificate_paths["client"].replace(".csr.pem", ".crt.pem"),
594+
"server": self.certificate_paths["server"].replace(".csr.pem", ".crt.pem"),
594595
}
595596

596597
def load_certificate(self, path: str) -> bytes:
@@ -670,6 +671,7 @@ def get_context(self) -> tuple[ssl.SSLContext, ssl.SSLContext]:
670671
ssl_context_client.load_cert_chain(
671672
certfile=self.certificate_signed_paths["client"],
672673
keyfile=self.private_key_paths["client"],
674+
password=self.key_password,
673675
)
674676
# Load and set the Certificate Authority (CA) certificate to verify the client's server certificate
675677
ssl_context_client.load_verify_locations(cafile=self.ca_certificate_path)
@@ -682,6 +684,7 @@ def get_context(self) -> tuple[ssl.SSLContext, ssl.SSLContext]:
682684
ssl_context_server.load_cert_chain(
683685
certfile=self.certificate_signed_paths["server"],
684686
keyfile=self.private_key_paths["server"],
687+
password=self.key_password,
685688
)
686689
# Load and set the Certificate Authority (CA) certificate to verify the server's client certificate
687690
ssl_context_server.load_verify_locations(cafile=self.ca_certificate_path)
@@ -720,3 +723,17 @@ def sign_and_store(self, name: Literal["server", "client"]) -> str:
720723
self.write_certificate(cert_path, cert_pem)
721724
os.chmod(cert_path, 0o644)
722725
return cert_path
726+
727+
def _require_attrs(
728+
self,
729+
keys=(
730+
"Country Name",
731+
"State or Province Name",
732+
"Locality Name",
733+
"Organization Name",
734+
"Common Name",
735+
),
736+
):
737+
missing = [k for k in keys if k not in self.ssl_certificate_attributes]
738+
if missing:
739+
raise ValueError(f"Missing certificate attributes: {missing}")

test/test_certificate_authority.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,22 +251,22 @@ async def test_sign(self, certificate_authority: CertificateAuthority) -> None:
251251
ca = certificate_authority
252252

253253
ca.load_ca(
254-
ca_key_path=os.path.join(self.SSL_CERTIFICATES_LOCATION, "ca.key"),
255-
ca_cert_path=os.path.join(self.SSL_CERTIFICATES_LOCATION, "ca.cert"),
254+
ca_key_path=os.path.join(self.SSL_CERTIFICATES_LOCATION, "ca.key.pem"),
255+
ca_cert_path=os.path.join(self.SSL_CERTIFICATES_LOCATION, "ca.crt.pem"),
256256
)
257257

258258
ca.load_key_and_csr(
259259
private_key_client_path=os.path.join(
260-
self.SSL_CERTIFICATES_LOCATION, f"client_{self.TEST_NAME}.key"
260+
self.SSL_CERTIFICATES_LOCATION, f"client_{self.TEST_NAME}.key.pem"
261261
),
262262
certificate_client_path=os.path.join(
263-
self.SSL_CERTIFICATES_LOCATION, f"client_{self.TEST_NAME}.csr"
263+
self.SSL_CERTIFICATES_LOCATION, f"client_{self.TEST_NAME}.csr.pem"
264264
),
265265
private_key_server_path=os.path.join(
266-
self.SSL_CERTIFICATES_LOCATION, f"server_{self.TEST_NAME}.key"
266+
self.SSL_CERTIFICATES_LOCATION, f"server_{self.TEST_NAME}.key.pem"
267267
),
268268
certificate_server_path=os.path.join(
269-
self.SSL_CERTIFICATES_LOCATION, f"server_{self.TEST_NAME}.csr"
269+
self.SSL_CERTIFICATES_LOCATION, f"server_{self.TEST_NAME}.csr.pem"
270270
),
271271
)
272272

0 commit comments

Comments
 (0)