diff options
Diffstat (limited to 'test/modules/md/md_cert_util.py')
-rwxr-xr-x | test/modules/md/md_cert_util.py | 70 |
1 files changed, 20 insertions, 50 deletions
diff --git a/test/modules/md/md_cert_util.py b/test/modules/md/md_cert_util.py index abcd36b938..6cd034a02b 100755 --- a/test/modules/md/md_cert_util.py +++ b/test/modules/md/md_cert_util.py @@ -1,6 +1,5 @@ import logging import re -import os import socket import OpenSSL import time @@ -12,6 +11,7 @@ from datetime import timedelta from http.client import HTTPConnection from urllib.parse import urlparse +from cryptography import x509 SEC_PER_DAY = 24 * 60 * 60 @@ -24,45 +24,6 @@ class MDCertUtil(object): # Uses PyOpenSSL: https://pyopenssl.org/en/stable/index.html @classmethod - def create_self_signed_cert(cls, path, name_list, valid_days, serial=1000): - domain = name_list[0] - if not os.path.exists(path): - os.makedirs(path) - - cert_file = os.path.join(path, 'pubcert.pem') - pkey_file = os.path.join(path, 'privkey.pem') - # create a key pair - if os.path.exists(pkey_file): - key_buffer = open(pkey_file, 'rt').read() - k = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, key_buffer) - else: - k = OpenSSL.crypto.PKey() - k.generate_key(OpenSSL.crypto.TYPE_RSA, 2048) - - # create a self-signed cert - cert = OpenSSL.crypto.X509() - cert.get_subject().C = "DE" - cert.get_subject().ST = "NRW" - cert.get_subject().L = "Muenster" - cert.get_subject().O = "greenbytes GmbH" - cert.get_subject().CN = domain - cert.set_serial_number(serial) - cert.gmtime_adj_notBefore(valid_days["notBefore"] * SEC_PER_DAY) - cert.gmtime_adj_notAfter(valid_days["notAfter"] * SEC_PER_DAY) - cert.set_issuer(cert.get_subject()) - - cert.add_extensions([OpenSSL.crypto.X509Extension( - b"subjectAltName", False, b", ".join(map(lambda n: b"DNS:" + n.encode(), name_list)) - )]) - cert.set_pubkey(k) - cert.sign(k, 'sha1') - - open(cert_file, "wt").write( - OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert).decode('utf-8')) - open(pkey_file, "wt").write( - OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, k).decode('utf-8')) - - @classmethod def load_server_cert(cls, host_ip, host_port, host_name, tls=None, ciphers=None): ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) if tls is not None and tls != 1.0: @@ -138,17 +99,26 @@ class MDCertUtil(object): # add leading 0s to align with word boundaries. return ("%lx" % (self.cert.get_serial_number())).upper() - def same_serial_as(self, other): - if isinstance(other, MDCertUtil): - return self.cert.get_serial_number() == other.cert.get_serial_number() - elif isinstance(other, OpenSSL.crypto.X509): - return self.cert.get_serial_number() == other.get_serial_number() - elif isinstance(other, str): + @staticmethod + def _get_serial(cert) -> int: + if isinstance(cert, x509.Certificate): + return cert.serial_number + if isinstance(cert, MDCertUtil): + return cert.get_serial_number() + elif isinstance(cert, OpenSSL.crypto.X509): + return cert.get_serial_number() + elif isinstance(cert, str): # assume a hex number - return self.cert.get_serial_number() == int(other, 16) - elif isinstance(other, int): - return self.cert.get_serial_number() == other - return False + return int(cert, 16) + elif isinstance(cert, int): + return cert + return 0 + + def get_serial_number(self): + return self._get_serial(self.cert) + + def same_serial_as(self, other): + return self._get_serial(self.cert) == self._get_serial(other) def get_not_before(self): tsp = self.cert.get_notBefore() |