diff options
-rw-r--r-- | cryptography/hazmat/primitives/asymmetric/rsa.py | 100 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 8 |
2 files changed, 61 insertions, 47 deletions
diff --git a/cryptography/hazmat/primitives/asymmetric/rsa.py b/cryptography/hazmat/primitives/asymmetric/rsa.py index b256ddcc..4f4e257b 100644 --- a/cryptography/hazmat/primitives/asymmetric/rsa.py +++ b/cryptography/hazmat/primitives/asymmetric/rsa.py @@ -43,6 +43,56 @@ def _verify_rsa_parameters(public_exponent, key_size): raise ValueError("key_size must be at least 512-bits.") +def _check_private_key_components(p, q, private_exponent, dmp1, dmq1, iqmp, + public_exponent, modulus): + if modulus < 3: + raise ValueError("modulus must be >= 3.") + + if p >= modulus: + raise ValueError("p must be < modulus.") + + if q >= modulus: + raise ValueError("q must be < modulus.") + + if dmp1 >= modulus: + raise ValueError("dmp1 must be < modulus.") + + if dmq1 >= modulus: + raise ValueError("dmq1 must be < modulus.") + + if iqmp >= modulus: + raise ValueError("iqmp must be < modulus.") + + if private_exponent >= modulus: + raise ValueError("private_exponent must be < modulus.") + + if public_exponent < 3 or public_exponent >= modulus: + raise ValueError("public_exponent must be >= 3 and < modulus.") + + if public_exponent & 1 == 0: + raise ValueError("public_exponent must be odd.") + + if dmp1 & 1 == 0: + raise ValueError("dmp1 must be odd.") + + if dmq1 & 1 == 0: + raise ValueError("dmq1 must be odd.") + + if p * q != modulus: + raise ValueError("p*q must equal modulus.") + + +def _check_public_key_components(e, n): + if n < 3: + raise ValueError("n must be >= 3.") + + if e < 3 or e >= n: + raise ValueError("e must be >= 3 and < n.") + + if e & 1 == 0: + raise ValueError("e must be odd.") + + @utils.register_interface(interfaces.RSAPublicKey) class RSAPublicKey(object): def __init__(self, public_exponent, modulus): @@ -52,14 +102,7 @@ class RSAPublicKey(object): ): raise TypeError("RSAPublicKey arguments must be integers.") - if modulus < 3: - raise ValueError("modulus must be >= 3.") - - if public_exponent < 3 or public_exponent >= modulus: - raise ValueError("public_exponent must be >= 3 and < modulus.") - - if public_exponent & 1 == 0: - raise ValueError("public_exponent must be odd.") + _check_public_key_components(public_exponent, modulus) self._public_exponent = public_exponent self._modulus = modulus @@ -156,41 +199,8 @@ class RSAPrivateKey(object): ): raise TypeError("RSAPrivateKey arguments must be integers.") - if modulus < 3: - raise ValueError("modulus must be >= 3.") - - if p >= modulus: - raise ValueError("p must be < modulus.") - - if q >= modulus: - raise ValueError("q must be < modulus.") - - if dmp1 >= modulus: - raise ValueError("dmp1 must be < modulus.") - - if dmq1 >= modulus: - raise ValueError("dmq1 must be < modulus.") - - if iqmp >= modulus: - raise ValueError("iqmp must be < modulus.") - - if private_exponent >= modulus: - raise ValueError("private_exponent must be < modulus.") - - if public_exponent < 3 or public_exponent >= modulus: - raise ValueError("public_exponent must be >= 3 and < modulus.") - - if public_exponent & 1 == 0: - raise ValueError("public_exponent must be odd.") - - if dmp1 & 1 == 0: - raise ValueError("dmp1 must be odd.") - - if dmq1 & 1 == 0: - raise ValueError("dmq1 must be odd.") - - if p * q != modulus: - raise ValueError("p*q must equal modulus.") + _check_private_key_components(p, q, private_exponent, dmp1, dmq1, iqmp, + public_exponent, modulus) self._p = p self._q = q @@ -304,6 +314,8 @@ class RSAPrivateNumbers(object): " instance." ) + _check_private_key_components(p, q, d, dmp1, dmq1, iqmp, + public_numbers.e, public_numbers.n) self._p = p self._q = q self._d = d @@ -349,6 +361,8 @@ class RSAPublicNumbers(object): ): raise TypeError("RSAPublicNumbers arguments must be integers.") + _check_public_key_components(e, n) + self._e = e self._n = n diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 8f10fb10..cfb51b0b 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -1681,12 +1681,12 @@ class TestRSAEncryption(object): @pytest.mark.rsa class TestRSANumbers(object): def test_rsa_public_numbers(self): - public_numbers = rsa.RSAPublicNumbers(e=1, n=15) - assert public_numbers.e == 1 + public_numbers = rsa.RSAPublicNumbers(e=3, n=15) + assert public_numbers.e == 3 assert public_numbers.n == 15 def test_rsa_private_numbers(self): - public_numbers = rsa.RSAPublicNumbers(e=1, n=15) + public_numbers = rsa.RSAPublicNumbers(e=3, n=15) private_numbers = rsa.RSAPrivateNumbers( p=3, q=5, @@ -1713,7 +1713,7 @@ class TestRSANumbers(object): rsa.RSAPublicNumbers(e=1, n=None) def test_private_numbers_invalid_types(self): - public_numbers = rsa.RSAPublicNumbers(e=1, n=15) + public_numbers = rsa.RSAPublicNumbers(e=3, n=15) with pytest.raises(TypeError): rsa.RSAPrivateNumbers( |