diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cryptography/hazmat/backends/openssl/backend.py | 22 | ||||
-rw-r--r-- | src/cryptography/hazmat/backends/openssl/x509.py | 175 | ||||
-rw-r--r-- | src/cryptography/hazmat/primitives/serialization.py | 43 | ||||
-rw-r--r-- | src/cryptography/utils.py | 21 | ||||
-rw-r--r-- | src/cryptography/x509.py | 2 |
5 files changed, 136 insertions, 127 deletions
diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 85f65972..91bc304f 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -119,6 +119,9 @@ def _encode_basic_constraints(backend, basic_constraints, critical): obj = _txt2obj(backend, x509.OID_BASIC_CONSTRAINTS.dotted_string) assert obj is not None constraints = backend._lib.BASIC_CONSTRAINTS_new() + constraints = backend._ffi.gc( + constraints, backend._lib.BASIC_CONSTRAINTS_free + ) constraints.ca = 255 if basic_constraints.ca else 0 if basic_constraints.ca: constraints.pathlen = _encode_asn1_int( @@ -685,8 +688,7 @@ class Backend(object): def generate_dsa_parameters(self, key_size): if key_size not in (1024, 2048, 3072): - raise ValueError( - "Key size must be 1024 or 2048 or 3072 bits.") + raise ValueError("Key size must be 1024 or 2048 or 3072 bits.") if (self._lib.OPENSSL_VERSION_NUMBER < 0x1000000f and key_size > 1024): @@ -833,7 +835,7 @@ class Backend(object): # Set subject name. res = self._lib.X509_REQ_set_subject_name( - x509_req, _encode_name(self, list(builder._subject_name)) + x509_req, _encode_name(self, builder._subject_name) ) assert res == 1 @@ -1351,9 +1353,6 @@ class Backend(object): def _private_key_bytes(self, encoding, format, encryption_algorithm, evp_pkey, cdata): - if not isinstance(encoding, serialization.Encoding): - raise TypeError("encoding must be an item from the Encoding enum") - if not isinstance(format, serialization.PrivateFormat): raise TypeError( "format must be an item from the PrivateFormat enum" @@ -1416,6 +1415,8 @@ class Backend(object): elif format is serialization.PrivateFormat.PKCS8: write_bio = self._lib.i2d_PKCS8PrivateKey_bio key = evp_pkey + else: + raise TypeError("encoding must be an item from the Encoding enum") bio = self._create_mem_bio() res = write_bio( @@ -1448,11 +1449,6 @@ class Backend(object): if not isinstance(encoding, serialization.Encoding): raise TypeError("encoding must be an item from the Encoding enum") - if not isinstance(format, serialization.PublicFormat): - raise TypeError( - "format must be an item from the PublicFormat enum" - ) - if format is serialization.PublicFormat.SubjectPublicKeyInfo: if encoding is serialization.Encoding.PEM: write_bio = self._lib.PEM_write_bio_PUBKEY @@ -1469,6 +1465,10 @@ class Backend(object): write_bio = self._lib.i2d_RSAPublicKey_bio key = cdata + else: + raise TypeError( + "format must be an item from the PublicFormat enum" + ) bio = self._create_mem_bio() res = write_bio(bio, key) diff --git a/src/cryptography/hazmat/backends/openssl/x509.py b/src/cryptography/hazmat/backends/openssl/x509.py index a03414c8..cc805755 100644 --- a/src/cryptography/hazmat/backends/openssl/x509.py +++ b/src/cryptography/hazmat/backends/openssl/x509.py @@ -82,7 +82,17 @@ def _decode_general_names(backend, gns): def _decode_general_name(backend, gn): if gn.type == backend._lib.GEN_DNS: data = backend._ffi.buffer(gn.d.dNSName.data, gn.d.dNSName.length)[:] - return x509.DNSName(idna.decode(data)) + if data.startswith(b"*."): + # This is a wildcard name. We need to remove the leading wildcard, + # IDNA decode, then re-add the wildcard. Wildcard characters should + # always be left-most (RFC 2595 section 2.4). + data = u"*." + idna.decode(data[2:]) + else: + # Not a wildcard, decode away. If the string has a * in it anywhere + # invalid this will raise an InvalidCodePoint + data = idna.decode(data) + + return x509.DNSName(data) elif gn.type == backend._lib.GEN_URI: data = backend._ffi.buffer( gn.d.uniformResourceIdentifier.data, @@ -153,6 +163,45 @@ def _decode_general_name(backend, gn): ) +def _decode_ocsp_no_check(backend, ext): + return x509.OCSPNoCheck() + + +class _X509ExtensionParser(object): + def __init__(self, ext_count, get_ext, handlers): + self.ext_count = ext_count + self.get_ext = get_ext + self.handlers = handlers + + def parse(self, backend, x509_obj): + extensions = [] + seen_oids = set() + for i in range(self.ext_count(backend, x509_obj)): + ext = self.get_ext(backend, x509_obj, i) + assert ext != backend._ffi.NULL + crit = backend._lib.X509_EXTENSION_get_critical(ext) + critical = crit == 1 + oid = x509.ObjectIdentifier(_obj2txt(backend, ext.object)) + if oid in seen_oids: + raise x509.DuplicateExtension( + "Duplicate {0} extension found".format(oid), oid + ) + try: + handler = self.handlers[oid] + except KeyError: + if critical: + raise x509.UnsupportedExtension( + "{0} is not currently supported".format(oid), oid + ) + else: + value = handler(backend, ext) + extensions.append(x509.Extension(oid, critical, value)) + + seen_oids.add(oid) + + return x509.Extensions(extensions) + + @utils.register_interface(x509.Certificate) class _Certificate(object): def __init__(self, backend, x509): @@ -258,68 +307,17 @@ class _Certificate(object): @property def extensions(self): - extensions = [] - seen_oids = set() - extcount = self._backend._lib.X509_get_ext_count(self._x509) - for i in range(0, extcount): - ext = self._backend._lib.X509_get_ext(self._x509, i) - assert ext != self._backend._ffi.NULL - crit = self._backend._lib.X509_EXTENSION_get_critical(ext) - critical = crit == 1 - oid = x509.ObjectIdentifier(_obj2txt(self._backend, ext.object)) - if oid in seen_oids: - raise x509.DuplicateExtension( - "Duplicate {0} extension found".format(oid), oid - ) - elif oid == x509.OID_BASIC_CONSTRAINTS: - value = _decode_basic_constraints(self._backend, ext) - elif oid == x509.OID_SUBJECT_KEY_IDENTIFIER: - value = _decode_subject_key_identifier(self._backend, ext) - elif oid == x509.OID_KEY_USAGE: - value = _decode_key_usage(self._backend, ext) - elif oid == x509.OID_SUBJECT_ALTERNATIVE_NAME: - value = _decode_subject_alt_name(self._backend, ext) - elif oid == x509.OID_EXTENDED_KEY_USAGE: - value = _decode_extended_key_usage(self._backend, ext) - elif oid == x509.OID_AUTHORITY_KEY_IDENTIFIER: - value = _decode_authority_key_identifier(self._backend, ext) - elif oid == x509.OID_AUTHORITY_INFORMATION_ACCESS: - value = _decode_authority_information_access( - self._backend, ext - ) - elif oid == x509.OID_CERTIFICATE_POLICIES: - value = _decode_certificate_policies(self._backend, ext) - elif oid == x509.OID_CRL_DISTRIBUTION_POINTS: - value = _decode_crl_distribution_points(self._backend, ext) - elif oid == x509.OID_OCSP_NO_CHECK: - value = x509.OCSPNoCheck() - elif oid == x509.OID_INHIBIT_ANY_POLICY: - value = _decode_inhibit_any_policy(self._backend, ext) - elif oid == x509.OID_ISSUER_ALTERNATIVE_NAME: - value = _decode_issuer_alt_name(self._backend, ext) - elif critical: - raise x509.UnsupportedExtension( - "{0} is not currently supported".format(oid), oid - ) - else: - # Unsupported non-critical extension, silently skipping for now - seen_oids.add(oid) - continue - - seen_oids.add(oid) - extensions.append(x509.Extension(oid, critical, value)) - - return x509.Extensions(extensions) + return _CERTIFICATE_EXTENSION_PARSER.parse(self._backend, self._x509) def public_bytes(self, encoding): - if not isinstance(encoding, serialization.Encoding): - raise TypeError("encoding must be an item from the Encoding enum") - bio = self._backend._create_mem_bio() if encoding is serialization.Encoding.PEM: res = self._backend._lib.PEM_write_bio_X509(bio, self._x509) elif encoding is serialization.Encoding.DER: res = self._backend._lib.i2d_X509_bio(bio, self._x509) + else: + raise TypeError("encoding must be an item from the Encoding enum") + assert res == 1 return self._backend._read_mem_bio(bio) @@ -694,40 +692,10 @@ class _CertificateSigningRequest(object): @property def extensions(self): - extensions = [] - seen_oids = set() x509_exts = self._backend._lib.X509_REQ_get_extensions(self._x509_req) - extcount = self._backend._lib.sk_X509_EXTENSION_num(x509_exts) - for i in range(0, extcount): - ext = self._backend._lib.sk_X509_EXTENSION_value(x509_exts, i) - assert ext != self._backend._ffi.NULL - crit = self._backend._lib.X509_EXTENSION_get_critical(ext) - critical = crit == 1 - oid = x509.ObjectIdentifier(_obj2txt(self._backend, ext.object)) - if oid in seen_oids: - raise x509.DuplicateExtension( - "Duplicate {0} extension found".format(oid), oid - ) - elif oid == x509.OID_BASIC_CONSTRAINTS: - value = _decode_basic_constraints(self._backend, ext) - elif critical: - raise x509.UnsupportedExtension( - "{0} is not currently supported".format(oid), oid - ) - else: - # Unsupported non-critical extension, silently skipping for now - seen_oids.add(oid) - continue - - seen_oids.add(oid) - extensions.append(x509.Extension(oid, critical, value)) - - return x509.Extensions(extensions) + return _CSR_EXTENSION_PARSER.parse(self._backend, x509_exts) def public_bytes(self, encoding): - if not isinstance(encoding, serialization.Encoding): - raise TypeError("encoding must be an item from the Encoding enum") - bio = self._backend._create_mem_bio() if encoding is serialization.Encoding.PEM: res = self._backend._lib.PEM_write_bio_X509_REQ( @@ -735,5 +703,38 @@ class _CertificateSigningRequest(object): ) elif encoding is serialization.Encoding.DER: res = self._backend._lib.i2d_X509_REQ_bio(bio, self._x509_req) + else: + raise TypeError("encoding must be an item from the Encoding enum") + assert res == 1 return self._backend._read_mem_bio(bio) + + +_CERTIFICATE_EXTENSION_PARSER = _X509ExtensionParser( + ext_count=lambda backend, x: backend._lib.X509_get_ext_count(x), + get_ext=lambda backend, x, i: backend._lib.X509_get_ext(x, i), + handlers={ + x509.OID_BASIC_CONSTRAINTS: _decode_basic_constraints, + x509.OID_SUBJECT_KEY_IDENTIFIER: _decode_subject_key_identifier, + x509.OID_KEY_USAGE: _decode_key_usage, + x509.OID_SUBJECT_ALTERNATIVE_NAME: _decode_subject_alt_name, + x509.OID_EXTENDED_KEY_USAGE: _decode_extended_key_usage, + x509.OID_AUTHORITY_KEY_IDENTIFIER: _decode_authority_key_identifier, + x509.OID_AUTHORITY_INFORMATION_ACCESS: ( + _decode_authority_information_access + ), + x509.OID_CERTIFICATE_POLICIES: _decode_certificate_policies, + x509.OID_CRL_DISTRIBUTION_POINTS: _decode_crl_distribution_points, + x509.OID_OCSP_NO_CHECK: _decode_ocsp_no_check, + x509.OID_INHIBIT_ANY_POLICY: _decode_inhibit_any_policy, + x509.OID_ISSUER_ALTERNATIVE_NAME: _decode_issuer_alt_name, + } +) + +_CSR_EXTENSION_PARSER = _X509ExtensionParser( + ext_count=lambda backend, x: backend._lib.sk_X509_EXTENSION_num(x), + get_ext=lambda backend, x, i: backend._lib.sk_X509_EXTENSION_value(x, i), + handlers={ + x509.OID_BASIC_CONSTRAINTS: _decode_basic_constraints, + } +) diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py index 8699fa91..098b31dc 100644 --- a/src/cryptography/hazmat/primitives/serialization.py +++ b/src/cryptography/hazmat/primitives/serialization.py @@ -106,12 +106,11 @@ def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend): if rest: raise ValueError('Key body contains extra bytes.') - if curve_name == b"nistp256": - curve = ec.SECP256R1() - elif curve_name == b"nistp384": - curve = ec.SECP384R1() - elif curve_name == b"nistp521": - curve = ec.SECP521R1() + curve = { + b"nistp256": ec.SECP256R1, + b"nistp384": ec.SECP384R1, + b"nistp521": ec.SECP521R1, + }[curve_name]() if six.indexbytes(data, 0) != 4: raise NotImplementedError( @@ -123,8 +122,12 @@ def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend): if len(data) != 1 + 2 * ((curve.key_size + 7) // 8): raise ValueError("Malformed key bytes") - x = _int_from_bytes(data[1:1 + (curve.key_size + 7) // 8], byteorder='big') - y = _int_from_bytes(data[1 + (curve.key_size + 7) // 8:], byteorder='big') + x = utils.int_from_bytes( + data[1:1 + (curve.key_size + 7) // 8], byteorder='big' + ) + y = utils.int_from_bytes( + data[1 + (curve.key_size + 7) // 8:], byteorder='big' + ) return ec.EllipticCurvePublicNumbers(x, y, curve).public_key(backend) @@ -146,27 +149,9 @@ def _read_next_mpint(data): """ mpint_data, rest = _read_next_string(data) - return _int_from_bytes(mpint_data, byteorder='big', signed=False), rest - - -if hasattr(int, "from_bytes"): - _int_from_bytes = int.from_bytes -else: - def _int_from_bytes(data, byteorder, signed=False): - assert byteorder == 'big' - assert not signed - - if len(data) % 4 != 0: - data = (b'\x00' * (4 - (len(data) % 4))) + data - - result = 0 - - while len(data) > 0: - digit, = struct.unpack('>I', data[:4]) - result = (result << 32) + digit - data = data[4:] - - return result + return ( + utils.int_from_bytes(mpint_data, byteorder='big', signed=False), rest + ) class Encoding(Enum): diff --git a/src/cryptography/utils.py b/src/cryptography/utils.py index 0bf8c0ea..24afe612 100644 --- a/src/cryptography/utils.py +++ b/src/cryptography/utils.py @@ -6,6 +6,7 @@ from __future__ import absolute_import, division, print_function import abc import inspect +import struct import sys import warnings @@ -25,6 +26,26 @@ def register_interface(iface): return register_decorator +if hasattr(int, "from_bytes"): + int_from_bytes = int.from_bytes +else: + def int_from_bytes(data, byteorder, signed=False): + assert byteorder == 'big' + assert not signed + + if len(data) % 4 != 0: + data = (b'\x00' * (4 - (len(data) % 4))) + data + + result = 0 + + while len(data) > 0: + digit, = struct.unpack('>I', data[:4]) + result = (result << 32) + digit + data = data[4:] + + return result + + class InterfaceNotImplemented(Exception): pass diff --git a/src/cryptography/x509.py b/src/cryptography/x509.py index 0f72abb3..668bc2ef 100644 --- a/src/cryptography/x509.py +++ b/src/cryptography/x509.py @@ -1486,4 +1486,6 @@ class CertificateSigningRequestBuilder(object): """ Signs the request using the requestor's private key. """ + if self._subject_name is None: + raise ValueError("A CertificateSigningRequest must have a subject") return backend.create_x509_csr(self, private_key, algorithm) |