aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cryptography/primitives/block/base.py17
-rw-r--r--tests/primitives/test_block.py20
2 files changed, 35 insertions, 2 deletions
diff --git a/cryptography/primitives/block/base.py b/cryptography/primitives/block/base.py
index 417b1ad8..2a6a5c37 100644
--- a/cryptography/primitives/block/base.py
+++ b/cryptography/primitives/block/base.py
@@ -21,16 +21,29 @@ class BlockCipher(object):
self.cipher = cipher
self.mode = mode
self._ctx = api.create_block_cipher_context(cipher, mode)
+ self._operation = None
def encrypt(self, plaintext):
if self._ctx is None:
raise ValueError("BlockCipher was already finalized")
+
+ if self._operation is None:
+ self._operation = "encrypt"
+ elif self._operation != "encrypt":
+ raise ValueError("BlockCipher cannot encrypt when the operation is"
+ " set to %s" % self._operation)
+
return api.update_encrypt_context(self._ctx, plaintext)
def finalize(self):
if self._ctx is None:
raise ValueError("BlockCipher was already finalized")
- # TODO: this might be a decrypt context
- result = api.finalize_encrypt_context(self._ctx)
+
+ if self._operation == "encrypt":
+ result = api.finalize_encrypt_context(self._ctx)
+ else:
+ raise ValueError("BlockCipher cannot finalize the unknown "
+ "operation %s" % self._operation)
+
self._ctx = None
return result
diff --git a/tests/primitives/test_block.py b/tests/primitives/test_block.py
index f5693431..7dccda4b 100644
--- a/tests/primitives/test_block.py
+++ b/tests/primitives/test_block.py
@@ -30,3 +30,23 @@ class TestBlockCipher(object):
cipher.encrypt(b"b" * 16)
with pytest.raises(ValueError):
cipher.finalize()
+
+ def test_encrypt_with_invalid_operation(self):
+ cipher = BlockCipher(
+ ciphers.AES(binascii.unhexlify(b"0" * 32)),
+ modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding())
+ )
+ cipher._operation = "decrypt"
+
+ with pytest.raises(ValueError):
+ cipher.encrypt(b"b" * 16)
+
+ def test_finalize_with_invalid_operation(self):
+ cipher = BlockCipher(
+ ciphers.AES(binascii.unhexlify(b"0" * 32)),
+ modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding())
+ )
+ cipher._operation = "wat"
+
+ with pytest.raises(ValueError):
+ cipher.encrypt(b"b" * 16)