aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cryptography/hazmat/bindings/openssl/backend.py2
-rw-r--r--cryptography/hazmat/primitives/ciphers/base.py19
-rw-r--r--cryptography/hazmat/primitives/interfaces.py16
-rw-r--r--tests/hazmat/primitives/utils.py2
4 files changed, 27 insertions, 12 deletions
diff --git a/cryptography/hazmat/bindings/openssl/backend.py b/cryptography/hazmat/bindings/openssl/backend.py
index e9ecc800..7d3eb3d7 100644
--- a/cryptography/hazmat/bindings/openssl/backend.py
+++ b/cryptography/hazmat/bindings/openssl/backend.py
@@ -239,6 +239,8 @@ class GetCipherByName(object):
@utils.register_interface(interfaces.CipherContext)
@utils.register_interface(interfaces.AEADCipherContext)
+@utils.register_interface(interfaces.AEADEncryptionContext)
+@utils.register_interface(interfaces.AEADDecryptionContext)
class _CipherContext(object):
_ENCRYPT = 1
_DECRYPT = 0
diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py
index a6f06b82..f24fd000 100644
--- a/cryptography/hazmat/primitives/ciphers/base.py
+++ b/cryptography/hazmat/primitives/ciphers/base.py
@@ -43,7 +43,10 @@ class Cipher(object):
def _wrap_ctx(self, ctx, encrypt):
if isinstance(self.mode, interfaces.ModeWithAAD):
- return _AEADCipherContext(ctx, encrypt)
+ if encrypt:
+ return _AEADEncryptionContext(ctx)
+ else:
+ return _AEADDecryptionContext(ctx)
else:
return _CipherContext(ctx)
@@ -69,11 +72,10 @@ class _CipherContext(object):
@utils.register_interface(interfaces.AEADCipherContext)
@utils.register_interface(interfaces.CipherContext)
class _AEADCipherContext(object):
- def __init__(self, ctx, encrypt):
+ def __init__(self, ctx):
self._ctx = ctx
self._tag = None
self._updated = False
- self._encrypt = encrypt
def update(self, data):
if self._ctx is None:
@@ -96,11 +98,16 @@ class _AEADCipherContext(object):
raise AlreadyUpdated("Update has been called on this context")
self._ctx.authenticate_additional_data(data)
+
+@utils.register_interface(interfaces.AEADDecryptionContext)
+class _AEADDecryptionContext(_AEADCipherContext):
+ pass
+
+
+@utils.register_interface(interfaces.AEADEncryptionContext)
+class _AEADEncryptionContext(_AEADCipherContext):
@property
def tag(self):
- if not self._encrypt:
- raise TypeError("The tag attribute is unavailable on a "
- "decryption context")
if self._ctx is not None:
raise NotYetFinalized("You must finalize encryption before "
"getting the tag")
diff --git a/cryptography/hazmat/primitives/interfaces.py b/cryptography/hazmat/primitives/interfaces.py
index c0548dfd..1884e560 100644
--- a/cryptography/hazmat/primitives/interfaces.py
+++ b/cryptography/hazmat/primitives/interfaces.py
@@ -79,17 +79,23 @@ class CipherContext(six.with_metaclass(abc.ABCMeta)):
class AEADCipherContext(six.with_metaclass(abc.ABCMeta)):
+ @abc.abstractmethod
+ def authenticate_additional_data(self, data):
+ """
+ authenticate_additional_data takes bytes and returns nothing.
+ """
+
+
+class AEADEncryptionContext(six.with_metaclass(abc.ABCMeta)):
@abc.abstractproperty
def tag(self):
"""
Returns tag bytes after finalizing encryption.
"""
- @abc.abstractmethod
- def authenticate_additional_data(self, data):
- """
- authenticate_additional_data takes bytes and returns nothing.
- """
+
+class AEADDecryptionContext(six.with_metaclass(abc.ABCMeta)):
+ pass
class PaddingContext(six.with_metaclass(abc.ABCMeta)):
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 58b9a917..9aa3a89a 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -351,5 +351,5 @@ def aead_exception_test(backend, cipher_factory, mode_factory,
)
decryptor = cipher.decryptor()
decryptor.update(b"a" * 16)
- with pytest.raises(TypeError):
+ with pytest.raises(AttributeError):
decryptor.tag