diff options
-rw-r--r-- | cryptography/hazmat/backends/openssl/backend.py | 7 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/base.py | 14 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/modes.py | 5 | ||||
-rw-r--r-- | tests/hazmat/primitives/utils.py | 9 |
4 files changed, 20 insertions, 15 deletions
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py index 470aa399..49066466 100644 --- a/cryptography/hazmat/backends/openssl/backend.py +++ b/cryptography/hazmat/backends/openssl/backend.py @@ -235,18 +235,11 @@ class _CipherContext(object): ) assert res != 0 if operation == self._DECRYPT: - if not mode.tag or len(mode.tag) < 4: - raise ValueError("Authentication tag must be provided and " - "be 4 bytes or longer when decrypting") res = self._backend._lib.EVP_CIPHER_CTX_ctrl( ctx, self._backend._lib.EVP_CTRL_GCM_SET_TAG, len(mode.tag), mode.tag ) assert res != 0 - else: - if mode.tag: - raise ValueError("Authentication tag must be None when " - "encrypting") # pass key/iv res = self._backend._lib.EVP_CipherInit_ex( diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py index b8615cb9..1da0802c 100644 --- a/cryptography/hazmat/primitives/ciphers/base.py +++ b/cryptography/hazmat/primitives/ciphers/base.py @@ -30,16 +30,26 @@ class Cipher(object): self._backend = backend def encryptor(self): + if isinstance(self.mode, interfaces.ModeWithAuthenticationTag): + if self.mode.tag is not None: + raise ValueError( + "Authentication tag must be None when encrypting" + ) ctx = self._backend.create_symmetric_encryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, True) + return self._wrap_ctx(ctx, encrypt=True) def decryptor(self): + if isinstance(self.mode, interfaces.ModeWithAuthenticationTag): + if self.mode.tag is None: + raise ValueError( + "Authentication tag must be provided when decrypting" + ) ctx = self._backend.create_symmetric_decryption_ctx( self.algorithm, self.mode ) - return self._wrap_ctx(ctx, False) + return self._wrap_ctx(ctx, encrypt=False) def _wrap_ctx(self, ctx, encrypt): if isinstance(self.mode, interfaces.ModeWithAuthenticationTag): diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py index e1c70185..ab8501c6 100644 --- a/cryptography/hazmat/primitives/ciphers/modes.py +++ b/cryptography/hazmat/primitives/ciphers/modes.py @@ -65,5 +65,10 @@ class GCM(object): name = "GCM" def __init__(self, initialization_vector, tag=None): + if tag is not None and len(tag) < 4: + raise ValueError( + "Authentication tag must be 4 bytes or longer" + ) + self.initialization_vector = initialization_vector self.tag = tag diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index cdcf84cb..6ecc70ff 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -264,13 +264,10 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory): ) with pytest.raises(ValueError): cipher.decryptor() - cipher = Cipher( - cipher_factory(binascii.unhexlify(b"0" * 32)), - mode_factory(binascii.unhexlify(b"0" * 24), b"000"), - backend - ) + with pytest.raises(ValueError): - cipher.decryptor() + mode_factory(binascii.unhexlify(b"0" * 24), b"000") + cipher = Cipher( cipher_factory(binascii.unhexlify(b"0" * 32)), mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), |