diff options
-rw-r--r-- | cryptography/hazmat/backends/openssl/backend.py | 55 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/asymmetric/dsa.py | 8 |
2 files changed, 32 insertions, 31 deletions
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py index ca3992fe..3e6b1b5b 100644 --- a/cryptography/hazmat/backends/openssl/backend.py +++ b/cryptography/hazmat/backends/openssl/backend.py @@ -29,7 +29,7 @@ from cryptography.hazmat.backends.interfaces import ( ) from cryptography.hazmat.bindings.openssl.binding import Binding from cryptography.hazmat.primitives import hashes, interfaces -from cryptography.hazmat.primitives.asymmetric import rsa, dsa +from cryptography.hazmat.primitives.asymmetric import dsa, rsa from cryptography.hazmat.primitives.asymmetric.padding import ( MGF1, PKCS1v15, PSS ) @@ -409,22 +409,30 @@ class Backend(object): return _RSAVerificationContext(self, public_key, signature, padding, algorithm) - def generate_dsa_parameters(self, key_size, ctx=None): + def mgf1_hash_supported(self, algorithm): + if self._lib.Cryptography_HAS_MGF1_MD: + return self.hash_supported(algorithm) + else: + return isinstance(algorithm, hashes.SHA1) + + def generate_dsa_parameters(self, key_size): if key_size not in (1024, 2048, 3072): - raise ValueError("Key size must be 1024 or 2048 or" - "3072 bits") + raise ValueError( + "Key size must be 1024 or 2048 or 3072 bits") - if ctx is None: - ctx = self._lib.DSA_new() - assert ctx != self._ffi.NULL - ctx = self._ffi.gc(ctx, self._lib.DSA_free) + if backend._lib.OPENSSL_VERSION_NUMBER < 0x1000000f \ + and key_size > 1024: + raise ValueError( + "Key size must be 1024 because OpenSSL < 1.0.0 doesn't " + "support larger key sizes") - bn = self._int_to_bn(key_size) - bn = self._ffi.gc(bn, self._lib.BN_free) + ctx = self._lib.DSA_new() + assert ctx != self._ffi.NULL + ctx = self._ffi.gc(ctx, self._lib.DSA_free) res = self._lib.DSA_generate_parameters_ex( - ctx, bn, self._ffi.NULL, self._ffi.NULL, - self._ffi.NULL, self._ffi.NULL + ctx, key_size, self._ffi.NULL, self._ffi.NULL, + self._ffi.NULL, self._ffi.NULL, self._ffi.NULL ) assert res == 1 @@ -435,22 +443,13 @@ class Backend(object): generator=self._bn_to_int(ctx.g) ) - def generate_dsa_private_key(self, parameters, key_size): + def generate_dsa_private_key(self, parameters): ctx = self._lib.DSA_new() assert ctx != self._ffi.NULL ctx = self._ffi.gc(ctx, self._lib.DSA_free) - if all([parameters.p, parameters.q, parameters.g]): - ctx.p = self._int_to_bn(parameters.p) - ctx.q = self._int_to_bn(parameters.q) - ctx.g = self._int_to_bn(parameters.g) - - else: - if key_size not in (1024, 2048, 3072): - raise ValueError("Key size must be 1024 or 2048 or" - "3072 bits") - bn = self._int_to_bn(key_size) - bn = self._ffi.gc(bn, self._lib.BN_free) - self.generate_dsa_parameters(bn, ctx) + ctx.p = self._int_to_bn(parameters.p) + ctx.q = self._int_to_bn(parameters.q) + ctx.g = self._int_to_bn(parameters.g) self._lib.DSA_generate_key(ctx) @@ -462,12 +461,6 @@ class Backend(object): y=self._bn_to_int(ctx.pub_key) ) - def mgf1_hash_supported(self, algorithm): - if self._lib.Cryptography_HAS_MGF1_MD: - return self.hash_supported(algorithm) - else: - return isinstance(algorithm, hashes.SHA1) - class GetCipherByName(object): def __init__(self, fmt): diff --git a/cryptography/hazmat/primitives/asymmetric/dsa.py b/cryptography/hazmat/primitives/asymmetric/dsa.py index 974db0a6..eb4a162c 100644 --- a/cryptography/hazmat/primitives/asymmetric/dsa.py +++ b/cryptography/hazmat/primitives/asymmetric/dsa.py @@ -49,6 +49,10 @@ class DSAParameters(object): self._subgroup_order = subgroup_order self._generator = generator + @classmethod + def generate(cls, backend, key_size): + return backend.generate_dsa_parameters(key_size) + @property def modulus(self): return self._modulus @@ -96,6 +100,10 @@ class DSAPrivateKey(object): self._x = x self._y = y + @classmethod + def generate(cls, backend, parameters): + return backend.generate_dsa_private_key(parameters) + @property def key_size(self): return utils.bit_length(self._modulus) |