diff options
-rw-r--r-- | cryptography/primitives/block/base.py | 17 | ||||
-rw-r--r-- | tests/primitives/test_block.py | 20 |
2 files changed, 35 insertions, 2 deletions
diff --git a/cryptography/primitives/block/base.py b/cryptography/primitives/block/base.py index 417b1ad8..2a6a5c37 100644 --- a/cryptography/primitives/block/base.py +++ b/cryptography/primitives/block/base.py @@ -21,16 +21,29 @@ class BlockCipher(object): self.cipher = cipher self.mode = mode self._ctx = api.create_block_cipher_context(cipher, mode) + self._operation = None def encrypt(self, plaintext): if self._ctx is None: raise ValueError("BlockCipher was already finalized") + + if self._operation is None: + self._operation = "encrypt" + elif self._operation != "encrypt": + raise ValueError("BlockCipher cannot encrypt when the operation is" + " set to %s" % self._operation) + return api.update_encrypt_context(self._ctx, plaintext) def finalize(self): if self._ctx is None: raise ValueError("BlockCipher was already finalized") - # TODO: this might be a decrypt context - result = api.finalize_encrypt_context(self._ctx) + + if self._operation == "encrypt": + result = api.finalize_encrypt_context(self._ctx) + else: + raise ValueError("BlockCipher cannot finalize the unknown " + "operation %s" % self._operation) + self._ctx = None return result diff --git a/tests/primitives/test_block.py b/tests/primitives/test_block.py index f5693431..7dccda4b 100644 --- a/tests/primitives/test_block.py +++ b/tests/primitives/test_block.py @@ -30,3 +30,23 @@ class TestBlockCipher(object): cipher.encrypt(b"b" * 16) with pytest.raises(ValueError): cipher.finalize() + + def test_encrypt_with_invalid_operation(self): + cipher = BlockCipher( + ciphers.AES(binascii.unhexlify(b"0" * 32)), + modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding()) + ) + cipher._operation = "decrypt" + + with pytest.raises(ValueError): + cipher.encrypt(b"b" * 16) + + def test_finalize_with_invalid_operation(self): + cipher = BlockCipher( + ciphers.AES(binascii.unhexlify(b"0" * 32)), + modes.CBC(binascii.unhexlify(b"0" * 32), padding.NoPadding()) + ) + cipher._operation = "wat" + + with pytest.raises(ValueError): + cipher.encrypt(b"b" * 16) |