diff options
-rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 45 | ||||
-rw-r--r-- | tests/hazmat/primitives/utils.py | 18 |
2 files changed, 40 insertions, 23 deletions
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index cf201212..97ca3935 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -763,7 +763,12 @@ class TestRSAPSSMGF1VerificationSHA1(object): "SigVerPSS_186-3.rsp", ], hashes.SHA1(), - padding.PSS + lambda params, hash_alg: padding.PSS( + mgf=padding.MGF1( + algorithm=hash_alg, + salt_length=params["salt_length"] + ) + ) ) @@ -782,7 +787,12 @@ class TestRSAPSSMGF1VerificationSHA224(object): "SigVerPSS_186-3.rsp", ], hashes.SHA224(), - padding.PSS + lambda params, hash_alg: padding.PSS( + mgf=padding.MGF1( + algorithm=hash_alg, + salt_length=params["salt_length"] + ) + ) ) @@ -801,7 +811,12 @@ class TestRSAPSSMGF1VerificationSHA256(object): "SigVerPSS_186-3.rsp", ], hashes.SHA256(), - padding.PSS + lambda params, hash_alg: padding.PSS( + mgf=padding.MGF1( + algorithm=hash_alg, + salt_length=params["salt_length"] + ) + ) ) @@ -820,7 +835,12 @@ class TestRSAPSSMGF1VerificationSHA384(object): "SigVerPSS_186-3.rsp", ], hashes.SHA384(), - padding.PSS + lambda params, hash_alg: padding.PSS( + mgf=padding.MGF1( + algorithm=hash_alg, + salt_length=params["salt_length"] + ) + ) ) @@ -839,7 +859,12 @@ class TestRSAPSSMGF1VerificationSHA512(object): "SigVerPSS_186-3.rsp", ], hashes.SHA512(), - padding.PSS + lambda params, hash_alg: padding.PSS( + mgf=padding.MGF1( + algorithm=hash_alg, + salt_length=params["salt_length"] + ) + ) ) @@ -856,7 +881,7 @@ class TestRSAPKCS1SHA1Verification(object): "SigVer15_186-3.rsp", ], hashes.SHA1(), - padding.PKCS1v15 + lambda params, hash_alg: padding.PKCS1v15() ) @@ -873,7 +898,7 @@ class TestRSAPKCS1SHA224Verification(object): "SigVer15_186-3.rsp", ], hashes.SHA224(), - padding.PKCS1v15 + lambda params, hash_alg: padding.PKCS1v15() ) @@ -890,7 +915,7 @@ class TestRSAPKCS1SHA256Verification(object): "SigVer15_186-3.rsp", ], hashes.SHA256(), - padding.PKCS1v15 + lambda params, hash_alg: padding.PKCS1v15() ) @@ -907,7 +932,7 @@ class TestRSAPKCS1SHA384Verification(object): "SigVer15_186-3.rsp", ], hashes.SHA384(), - padding.PKCS1v15 + lambda params, hash_alg: padding.PKCS1v15() ) @@ -924,7 +949,7 @@ class TestRSAPKCS1SHA512Verification(object): "SigVer15_186-3.rsp", ], hashes.SHA512(), - padding.PKCS1v15 + lambda params, hash_alg: padding.PKCS1v15() ) diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index 89d0f5f1..2e838474 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -24,7 +24,7 @@ from cryptography.exceptions import ( NotYetFinalized ) from cryptography.hazmat.primitives import hashes, hmac -from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC @@ -376,32 +376,24 @@ def generate_hkdf_test(param_loader, path, file_names, algorithm): def generate_rsa_verification_test(param_loader, path, file_names, hash_alg, - pad_cls): + pad_factory): all_params = _load_all_params(path, file_names, param_loader) all_params = [i for i in all_params if i["algorithm"] == hash_alg.name.upper()] @pytest.mark.parametrize("params", all_params) def test_rsa_verification(self, backend, params): - rsa_verification_test(backend, params, hash_alg, pad_cls) + rsa_verification_test(backend, params, hash_alg, pad_factory) return test_rsa_verification -def rsa_verification_test(backend, params, hash_alg, pad_cls): +def rsa_verification_test(backend, params, hash_alg, pad_factory): public_key = rsa.RSAPublicKey( public_exponent=params["public_exponent"], modulus=params["modulus"] ) - if pad_cls is padding.PKCS1v15: - pad = padding.PKCS1v15() - else: - pad = padding.PSS( - mgf=padding.MGF1( - algorithm=hash_alg, - salt_length=params["salt_length"] - ) - ) + pad = pad_factory(params, hash_alg) verifier = public_key.verifier( binascii.unhexlify(params["s"]), pad, |