diff options
-rw-r--r-- | cryptography/hazmat/bindings/openssl/backend.py | 8 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_block.py | 12 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_utils.py | 15 | ||||
-rw-r--r-- | tests/hazmat/primitives/utils.py | 35 |
4 files changed, 66 insertions, 4 deletions
diff --git a/cryptography/hazmat/bindings/openssl/backend.py b/cryptography/hazmat/bindings/openssl/backend.py index 1b19ddaa..6ab4dc26 100644 --- a/cryptography/hazmat/bindings/openssl/backend.py +++ b/cryptography/hazmat/bindings/openssl/backend.py @@ -289,12 +289,18 @@ class _CipherContext(object): ) assert res != 0 if operation == self._DECRYPT: - assert mode.tag is not None + if not mode.tag: + raise ValueError("Authentication tag must be supplied " + "when decrypting") res = self._backend.lib.EVP_CIPHER_CTX_ctrl( ctx, self._backend.lib.Cryptography_EVP_CTRL_GCM_SET_TAG, len(mode.tag), mode.tag ) assert res != 0 + else: + if mode.tag: + raise ValueError("Authentication tag must be None when " + "encrypting") # pass key/iv res = self._backend.lib.EVP_CipherInit_ex(ctx, self._backend.ffi.NULL, diff --git a/tests/hazmat/primitives/test_block.py b/tests/hazmat/primitives/test_block.py index 2806efd5..02de3861 100644 --- a/tests/hazmat/primitives/test_block.py +++ b/tests/hazmat/primitives/test_block.py @@ -26,7 +26,9 @@ from cryptography.hazmat.primitives.ciphers import ( Cipher, algorithms, modes ) -from .utils import generate_aead_exception_test +from .utils import ( + generate_aead_exception_test, generate_aead_tag_exception_test +) @utils.register_interface(interfaces.CipherAlgorithm) @@ -135,3 +137,11 @@ class TestAEADCipherContext(object): ), skip_message="Does not support AES GCM", ) + test_aead_tag_exceptions = generate_aead_tag_exception_test( + algorithms.AES, + modes.GCM, + only_if=lambda backend: backend.cipher_supported( + algorithms.AES("\x00" * 16), modes.GCM("\x00" * 12) + ), + skip_message="Does not support AES GCM", + ) diff --git a/tests/hazmat/primitives/test_utils.py b/tests/hazmat/primitives/test_utils.py index ebb8b5c4..c39364c7 100644 --- a/tests/hazmat/primitives/test_utils.py +++ b/tests/hazmat/primitives/test_utils.py @@ -3,7 +3,7 @@ import pytest from .utils import ( base_hash_test, encrypt_test, hash_test, long_string_hash_test, base_hmac_test, hmac_test, stream_encryption_test, aead_test, - aead_exception_test, + aead_exception_test, aead_tag_exception_test, ) @@ -29,7 +29,7 @@ class TestAEADTest(object): assert exc_info.value.args[0] == "message!" -class TestAEADFinalizeTest(object): +class TestAEADExceptionTest(object): def test_skips_if_only_if_returns_false(self): with pytest.raises(pytest.skip.Exception) as exc_info: aead_exception_test( @@ -40,6 +40,17 @@ class TestAEADFinalizeTest(object): assert exc_info.value.args[0] == "message!" +class TestAEADTagExceptionTest(object): + def test_skips_if_only_if_returns_false(self): + with pytest.raises(pytest.skip.Exception) as exc_info: + aead_tag_exception_test( + None, None, None, + only_if=lambda backend: False, + skip_message="message!" + ) + assert exc_info.value.args[0] == "message!" + + class TestHashTest(object): def test_skips_if_only_if_returns_false(self): with pytest.raises(pytest.skip.Exception) as exc_info: diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index 9aa3a89a..705983a0 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -353,3 +353,38 @@ def aead_exception_test(backend, cipher_factory, mode_factory, decryptor.update(b"a" * 16) with pytest.raises(AttributeError): decryptor.tag + + +def generate_aead_tag_exception_test(cipher_factory, mode_factory, + only_if, skip_message): + def test_aead_tag_exception(self): + for backend in _ALL_BACKENDS: + yield ( + aead_tag_exception_test, + backend, + cipher_factory, + mode_factory, + only_if, + skip_message + ) + return test_aead_tag_exception + + +def aead_tag_exception_test(backend, cipher_factory, mode_factory, + only_if, skip_message): + if not only_if(backend): + pytest.skip(skip_message) + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24)), + backend + ) + with pytest.raises(ValueError): + cipher.decryptor() + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), + backend + ) + with pytest.raises(ValueError): + cipher.encryptor() |