diff options
-rw-r--r-- | cryptography/hazmat/backends/openssl/backend.py | 7 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/base.py | 8 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/modes.py | 5 |
3 files changed, 13 insertions, 7 deletions
diff --git a/cryptography/hazmat/backends/openssl/backend.py b/cryptography/hazmat/backends/openssl/backend.py index 6231aadb..0e824165 100644 --- a/cryptography/hazmat/backends/openssl/backend.py +++ b/cryptography/hazmat/backends/openssl/backend.py @@ -354,18 +354,11 @@ class _CipherContext(object): ) assert res != 0 if operation == self._DECRYPT: - if not mode.tag or len(mode.tag) < 4: - raise ValueError("Authentication tag must be provided and " - "be 4 bytes or longer when decrypting") res = self._backend.lib.EVP_CIPHER_CTX_ctrl( ctx, self._backend.lib.EVP_CTRL_GCM_SET_TAG, len(mode.tag), mode.tag ) assert res != 0 - else: - if mode.tag: - raise ValueError("Authentication tag must be None when " - "encrypting") # pass key/iv res = self._backend.lib.EVP_CipherInit_ex(ctx, self._backend.ffi.NULL, diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py index b8615cb9..d1ca6d2a 100644 --- a/cryptography/hazmat/primitives/ciphers/base.py +++ b/cryptography/hazmat/primitives/ciphers/base.py @@ -44,8 +44,16 @@ class Cipher(object): def _wrap_ctx(self, ctx, encrypt): if isinstance(self.mode, interfaces.ModeWithAuthenticationTag): if encrypt: + if self.mode.tag is not None: + raise ValueError( + "Authentication tag must be None when encrypting" + ) return _AEADEncryptionContext(ctx) else: + if self.mode.tag is None: + raise ValueError( + "Authentication tag must be provided when decrypting" + ) return _AEADCipherContext(ctx) else: return _CipherContext(ctx) diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py index e1c70185..ab8501c6 100644 --- a/cryptography/hazmat/primitives/ciphers/modes.py +++ b/cryptography/hazmat/primitives/ciphers/modes.py @@ -65,5 +65,10 @@ class GCM(object): name = "GCM" def __init__(self, initialization_vector, tag=None): + if tag is not None and len(tag) < 4: + raise ValueError( + "Authentication tag must be 4 bytes or longer" + ) + self.initialization_vector = initialization_vector self.tag = tag |