diff options
Diffstat (limited to 'tests/hazmat/primitives/test_rsa.py')
-rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 183 |
1 files changed, 180 insertions, 3 deletions
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 6d8e6874..e6d0ac28 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -15,8 +15,10 @@ from cryptography import utils from cryptography.exceptions import ( AlreadyFinalized, InvalidSignature, _Reasons ) -from cryptography.hazmat.backends.interfaces import RSABackend -from cryptography.hazmat.primitives import hashes, interfaces +from cryptography.hazmat.backends.interfaces import ( + PEMSerializationBackend, RSABackend +) +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding, rsa from cryptography.hazmat.primitives.asymmetric.rsa import ( RSAPrivateNumbers, RSAPublicNumbers @@ -46,6 +48,11 @@ class DummyMGF(object): _salt_length = 0 +@utils.register_interface(serialization.KeySerializationEncryption) +class DummyKeyEncryption(object): + pass + + def _flatten_pkcs1_examples(vectors): flattened_vectors = [] for vector in vectors: @@ -78,6 +85,21 @@ def test_modular_inverse(): ) +def _skip_if_no_serialization(key, backend): + if not isinstance( + key, + (rsa.RSAPrivateKeyWithSerialization, rsa.RSAPublicKeyWithSerialization) + ): + pytest.skip( + "{0} does not support RSA key serialization".format(backend) + ) + + +def test_skip_if_no_serialization(): + with pytest.raises(pytest.skip.Exception): + _skip_if_no_serialization("notakeywithserialization", "backend") + + @pytest.mark.requires_backend_interface(interface=RSABackend) class TestRSA(object): @pytest.mark.parametrize( @@ -91,7 +113,7 @@ class TestRSA(object): skey = rsa.generate_private_key(public_exponent, key_size, backend) assert skey.key_size == key_size - if isinstance(skey, interfaces.RSAPrivateKeyWithNumbers): + if isinstance(skey, rsa.RSAPrivateKeyWithNumbers): _check_rsa_private_numbers(skey.private_numbers()) pkey = skey.public_key() assert isinstance(pkey.public_numbers(), rsa.RSAPublicNumbers) @@ -1725,3 +1747,158 @@ class TestRSAPrimeFactorRecovery(object): def test_invalid_recover_prime_factors(self): with pytest.raises(ValueError): rsa.rsa_recover_prime_factors(34, 3, 7) + + +@pytest.mark.requires_backend_interface(interface=RSABackend) +@pytest.mark.requires_backend_interface(interface=PEMSerializationBackend) +class TestRSAPEMPrivateKeySerialization(object): + @pytest.mark.parametrize( + ("fmt", "password"), + itertools.product( + [ + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.PrivateFormat.PKCS8 + ], + [ + b"s", + b"longerpassword", + b"!*$&(@#$*&($T@%_somesymbols", + b"\x01" * 1000, + ] + ) + ) + def test_private_bytes_encrypted_pem(self, backend, fmt, password): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + serialized = key.private_bytes( + serialization.Encoding.PEM, + fmt, + serialization.BestAvailableEncryption(password) + ) + loaded_key = serialization.load_pem_private_key( + serialized, password, backend + ) + loaded_priv_num = loaded_key.private_numbers() + priv_num = key.private_numbers() + assert loaded_priv_num == priv_num + + @pytest.mark.parametrize( + "fmt", + [ + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.PrivateFormat.PKCS8 + ], + ) + def test_private_bytes_unencrypted_pem(self, backend, fmt): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + serialized = key.private_bytes( + serialization.Encoding.PEM, + fmt, + serialization.NoEncryption() + ) + loaded_key = serialization.load_pem_private_key( + serialized, None, backend + ) + loaded_priv_num = loaded_key.private_numbers() + priv_num = key.private_numbers() + assert loaded_priv_num == priv_num + + def test_private_bytes_traditional_openssl_unencrypted_pem(self, backend): + key_bytes = load_vectors_from_file( + os.path.join( + "asymmetric", + "Traditional_OpenSSL_Serialization", + "testrsa.pem" + ), + lambda pemfile: pemfile.read().encode() + ) + key = serialization.load_pem_private_key(key_bytes, None, backend) + serialized = key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption() + ) + assert serialized == key_bytes + + def test_private_bytes_invalid_encoding(self, backend): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + with pytest.raises(TypeError): + key.private_bytes( + "notencoding", + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption() + ) + + def test_private_bytes_invalid_format(self, backend): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + with pytest.raises(TypeError): + key.private_bytes( + serialization.Encoding.PEM, + "invalidformat", + serialization.NoEncryption() + ) + + def test_private_bytes_invalid_encryption_algorithm(self, backend): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + with pytest.raises(TypeError): + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + "notanencalg" + ) + + def test_private_bytes_unsupported_encryption_type(self, backend): + key = RSA_KEY_2048.private_key(backend) + _skip_if_no_serialization(key, backend) + with pytest.raises(ValueError): + key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + DummyKeyEncryption() + ) + + +@pytest.mark.requires_backend_interface(interface=RSABackend) +@pytest.mark.requires_backend_interface(interface=PEMSerializationBackend) +class TestRSAPEMPublicKeySerialization(object): + def test_public_bytes_unencrypted_pem(self, backend): + key_bytes = load_vectors_from_file( + os.path.join("asymmetric", "PKCS8", "unenc-rsa-pkcs8.pub.pem"), + lambda pemfile: pemfile.read().encode() + ) + key = serialization.load_pem_public_key(key_bytes, backend) + _skip_if_no_serialization(key, backend) + serialized = key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo, + ) + assert serialized == key_bytes + + def test_public_bytes_pkcs1_unencrypted_pem(self, backend): + key_bytes = load_vectors_from_file( + os.path.join("asymmetric", "public", "PKCS1", "rsa.pub.pem"), + lambda pemfile: pemfile.read().encode() + ) + key = serialization.load_pem_public_key(key_bytes, backend) + _skip_if_no_serialization(key, backend) + serialized = key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.PKCS1, + ) + assert serialized == key_bytes + + def test_public_bytes_invalid_encoding(self, backend): + key = RSA_KEY_2048.private_key(backend).public_key() + _skip_if_no_serialization(key, backend) + with pytest.raises(TypeError): + key.public_bytes("notencoding", serialization.PublicFormat.PKCS1) + + def test_public_bytes_invalid_format(self, backend): + key = RSA_KEY_2048.private_key(backend).public_key() + _skip_if_no_serialization(key, backend) + with pytest.raises(TypeError): + key.public_bytes(serialization.Encoding.PEM, "invalidformat") |