diff options
-rw-r--r-- | cryptography/hazmat/backends/interfaces.py | 6 | ||||
-rw-r--r-- | cryptography/hazmat/backends/openssl/backend.py | 75 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/asymmetric/padding.py | 8 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 63 |
4 files changed, 151 insertions, 1 deletions
diff --git a/cryptography/hazmat/backends/interfaces.py b/cryptography/hazmat/backends/interfaces.py index 92413d8c..c5c5a16e 100644 --- a/cryptography/hazmat/backends/interfaces.py +++ b/cryptography/hazmat/backends/interfaces.py @@ -117,6 +117,12 @@ class RSABackend(object): Return True if the hash algorithm is supported for MGF1 in PSS. """ + @abc.abstractmethod + def rsa_decrypt(self, private_key, ciphertext, padding): + """ + Returns decrypted bytes. + """ + @six.add_metaclass(abc.ABCMeta) class DSABackend(object): diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py index 86fa704b..2965c781 100644 --- a/cryptography/hazmat/backends/openssl/backend.py +++ b/cryptography/hazmat/backends/openssl/backend.py @@ -32,7 +32,7 @@ from cryptography.hazmat.bindings.openssl.binding import Binding from cryptography.hazmat.primitives import hashes, interfaces from cryptography.hazmat.primitives.asymmetric import dsa, rsa from cryptography.hazmat.primitives.asymmetric.padding import ( - MGF1, PKCS1v15, PSS + MGF1, OAEP, PKCS1v15, PSS ) from cryptography.hazmat.primitives.ciphers.algorithms import ( AES, ARC4, Blowfish, CAST5, Camellia, IDEA, SEED, TripleDES @@ -473,6 +473,79 @@ class Backend(object): y=self._bn_to_int(ctx.pub_key) ) + def rsa_decrypt(self, private_key, ciphertext, padding): + if isinstance(padding, PKCS1v15): + padding_enum = self._lib.RSA_PKCS1_PADDING + elif isinstance(padding, OAEP): + padding_enum = self._lib.RSA_PKCS1_OAEP_PADDING + if not isinstance(padding._mgf, MGF1): + raise UnsupportedAlgorithm( + "Only MGF1 is supported by this backend" + ) + + if not isinstance(padding._mgf._algorithm, hashes.SHA1): + raise UnsupportedAlgorithm( + "This backend only supports SHA1 inside MGF1 when " + "using OAEP", + _Reasons.UNSUPPORTED_HASH + ) + else: + raise UnsupportedAlgorithm( + "{0} is not supported by this backend".format( + padding.name + ), + _Reasons.UNSUPPORTED_PADDING + ) + + if self._lib.Cryptography_HAS_PKEY_CTX: + evp_pkey = self._rsa_private_key_to_evp_pkey(private_key) + pkey_ctx = self._lib.EVP_PKEY_CTX_new( + evp_pkey, self._ffi.NULL + ) + assert pkey_ctx != self._ffi.NULL + res = self._lib.EVP_PKEY_decrypt_init(pkey_ctx) + assert res == 1 + res = self._lib.EVP_PKEY_CTX_set_rsa_padding( + pkey_ctx, padding_enum) + assert res > 0 + buf_size = self._lib.EVP_PKEY_size(evp_pkey) + assert buf_size > 0 + outlen = self._ffi.new("size_t *", buf_size) + buf = self._ffi.new("char[]", buf_size) + res = self._lib.Cryptography_EVP_PKEY_decrypt( + pkey_ctx, + buf, + outlen, + ciphertext, + len(ciphertext) + ) + assert res >= 0 + if res == 0: + errors = self._consume_errors() + assert errors + raise SystemError # TODO + + return self._ffi.buffer(buf)[:outlen[0]] + else: + rsa_cdata = self._rsa_cdata_from_private_key(private_key) + rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) + key_size = self._lib.RSA_size(rsa_cdata) + assert key_size > 0 + buf = self._ffi.new("unsigned char[]", key_size) + res = self._lib.RSA_private_decrypt( + len(ciphertext), + ciphertext, + buf, + rsa_cdata, + padding_enum + ) + if res < 0: + errors = self._consume_errors() + assert errors + raise SystemError # TODO + + return self._ffi.buffer(buf)[:res] + class GetCipherByName(object): def __init__(self, fmt): diff --git a/cryptography/hazmat/primitives/asymmetric/padding.py b/cryptography/hazmat/primitives/asymmetric/padding.py index 72806a61..899fed17 100644 --- a/cryptography/hazmat/primitives/asymmetric/padding.py +++ b/cryptography/hazmat/primitives/asymmetric/padding.py @@ -54,6 +54,14 @@ class PSS(object): self._salt_length = salt_length +@utils.register_interface(interfaces.AsymmetricPadding) +class OAEP(object): + name = "EME-OAEP" + + def __init__(self, mgf): + self._mgf = mgf + + class MGF1(object): MAX_LENGTH = object() diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 84d0f805..70ae20dc 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -1225,3 +1225,66 @@ class TestMGF1(object): mgf = padding.MGF1(algorithm, padding.MGF1.MAX_LENGTH) assert mgf._algorithm == algorithm assert mgf._salt_length == padding.MGF1.MAX_LENGTH + + +@pytest.mark.rsa +class TestRSADecryption(object): + @pytest.mark.parametrize( + "vector", + _flatten_pkcs1_examples(load_vectors_from_file( + os.path.join( + "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "oaep-vect.txt"), + load_pkcs1_vectors + )) + ) + def test_decrypt_oaep_vectors(self, vector, backend): + private, public, example = vector + skey = rsa.RSAPrivateKey( + p=private["p"], + q=private["q"], + private_exponent=private["private_exponent"], + dmp1=private["dmp1"], + dmq1=private["dmq1"], + iqmp=private["iqmp"], + public_exponent=private["public_exponent"], + modulus=private["modulus"] + ) + message = backend.rsa_decrypt( + skey, + binascii.unhexlify(example["encryption"]), + # TODO: handle MGF1 here + padding.OAEP( + padding.MGF1( + algorithm=hashes.SHA1(), + salt_length=padding.MGF1.MAX_LENGTH + ) + ) + ) + assert message == binascii.unhexlify(example["message"]) + + @pytest.mark.parametrize( + "vector", + _flatten_pkcs1_examples(load_vectors_from_file( + os.path.join( + "asymmetric", "RSA", "pkcs1v15crypt-vectors.txt"), + load_pkcs1_vectors + )) + ) + def test_decrypt_pkcs1v15_vectors(self, vector, backend): + private, public, example = vector + skey = rsa.RSAPrivateKey( + p=private["p"], + q=private["q"], + private_exponent=private["private_exponent"], + dmp1=private["dmp1"], + dmq1=private["dmq1"], + iqmp=private["iqmp"], + public_exponent=private["public_exponent"], + modulus=private["modulus"] + ) + message = backend.rsa_decrypt( + skey, + binascii.unhexlify(example["encryption"]), + padding.PKCS1v15() + ) + assert message == binascii.unhexlify(example["message"]) |