aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cryptography/hazmat/primitives/ciphers/base.py3
-rw-r--r--cryptography/hazmat/primitives/ciphers/modes.py27
-rw-r--r--cryptography/hazmat/primitives/interfaces.py7
-rw-r--r--docs/hazmat/primitives/interfaces.rst12
-rw-r--r--tests/hazmat/bindings/test_openssl.py3
-rw-r--r--tests/hazmat/primitives/test_block.py41
6 files changed, 91 insertions, 2 deletions
diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py
index 48e6da6f..9f84d667 100644
--- a/cryptography/hazmat/primitives/ciphers/base.py
+++ b/cryptography/hazmat/primitives/ciphers/base.py
@@ -23,6 +23,9 @@ class Cipher(object):
if not isinstance(algorithm, interfaces.CipherAlgorithm):
raise TypeError("Expected interface of interfaces.CipherAlgorithm")
+ if mode is not None:
+ mode.validate_for_algorithm(algorithm)
+
self.algorithm = algorithm
self.mode = mode
self._backend = backend
diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py
index 1d0de689..63fa163c 100644
--- a/cryptography/hazmat/primitives/ciphers/modes.py
+++ b/cryptography/hazmat/primitives/ciphers/modes.py
@@ -25,11 +25,20 @@ class CBC(object):
def __init__(self, initialization_vector):
self.initialization_vector = initialization_vector
+ def validate_for_algorithm(self, algorithm):
+ if len(self.initialization_vector) * 8 != algorithm.block_size:
+ raise ValueError("Invalid iv size ({0}) for {1}".format(
+ len(self.initialization_vector), self.name
+ ))
+
@utils.register_interface(interfaces.Mode)
class ECB(object):
name = "ECB"
+ def validate_for_algorithm(self, algorithm):
+ pass
+
@utils.register_interface(interfaces.Mode)
@utils.register_interface(interfaces.ModeWithInitializationVector)
@@ -39,6 +48,12 @@ class OFB(object):
def __init__(self, initialization_vector):
self.initialization_vector = initialization_vector
+ def validate_for_algorithm(self, algorithm):
+ if len(self.initialization_vector) * 8 != algorithm.block_size:
+ raise ValueError("Invalid iv size ({0}) for {1}".format(
+ len(self.initialization_vector), self.name
+ ))
+
@utils.register_interface(interfaces.Mode)
@utils.register_interface(interfaces.ModeWithInitializationVector)
@@ -48,6 +63,12 @@ class CFB(object):
def __init__(self, initialization_vector):
self.initialization_vector = initialization_vector
+ def validate_for_algorithm(self, algorithm):
+ if len(self.initialization_vector) * 8 != algorithm.block_size:
+ raise ValueError("Invalid iv size ({0}) for {1}".format(
+ len(self.initialization_vector), self.name
+ ))
+
@utils.register_interface(interfaces.Mode)
@utils.register_interface(interfaces.ModeWithNonce)
@@ -56,3 +77,9 @@ class CTR(object):
def __init__(self, nonce):
self.nonce = nonce
+
+ def validate_for_algorithm(self, algorithm):
+ if len(self.nonce) * 8 != algorithm.block_size:
+ raise ValueError("Invalid nonce size ({0}) for {1}".format(
+ len(self.nonce), self.name
+ ))
diff --git a/cryptography/hazmat/primitives/interfaces.py b/cryptography/hazmat/primitives/interfaces.py
index 8cc9d42c..672ac96a 100644
--- a/cryptography/hazmat/primitives/interfaces.py
+++ b/cryptography/hazmat/primitives/interfaces.py
@@ -39,6 +39,13 @@ class Mode(six.with_metaclass(abc.ABCMeta)):
A string naming this mode. (e.g. ECB, CBC)
"""
+ @abc.abstractmethod
+ def validate_for_algorithm(self, algorithm):
+ """
+ Checks that all the necessary invariants of this (mode, algorithm)
+ combination are met.
+ """
+
class ModeWithInitializationVector(six.with_metaclass(abc.ABCMeta)):
@abc.abstractproperty
diff --git a/docs/hazmat/primitives/interfaces.rst b/docs/hazmat/primitives/interfaces.rst
index 11cff51a..e798c0e6 100644
--- a/docs/hazmat/primitives/interfaces.rst
+++ b/docs/hazmat/primitives/interfaces.rst
@@ -56,6 +56,18 @@ Interfaces used by the symmetric cipher modes described in
The name may be used by a backend to influence the operation of a
cipher in conjunction with the algorithm's name.
+ .. method:: validate_for_algorithm(algorithm)
+
+ :param CipherAlgorithm algorithm:
+
+ Checks that the combination of this mode with the provided algorithm
+ meets any necessary invariants. This should raise an exception if they
+ are not met.
+
+ For example, the :class:`~cryptography.hazmat.primitives.modes.CBC`
+ mode uses this method to check that the provided initialization
+ vector's length matches the block size of the algorithm.
+
.. class:: ModeWithInitializationVector
diff --git a/tests/hazmat/bindings/test_openssl.py b/tests/hazmat/bindings/test_openssl.py
index 1eb6f200..4923d698 100644
--- a/tests/hazmat/bindings/test_openssl.py
+++ b/tests/hazmat/bindings/test_openssl.py
@@ -24,7 +24,8 @@ from cryptography.hazmat.primitives.ciphers.modes import CBC
class DummyMode(object):
- pass
+ def validate_for_algorithm(self, algorithm):
+ pass
@utils.register_interface(interfaces.CipherAlgorithm)
diff --git a/tests/hazmat/primitives/test_block.py b/tests/hazmat/primitives/test_block.py
index f6c44b47..b0bbba09 100644
--- a/tests/hazmat/primitives/test_block.py
+++ b/tests/hazmat/primitives/test_block.py
@@ -30,6 +30,11 @@ class DummyCipher(object):
pass
+class DummyMode(object):
+ def validate_for_algorithm(self, algorithm):
+ pass
+
+
class TestCipher(object):
def test_creates_encryptor(self, backend):
cipher = Cipher(
@@ -97,7 +102,7 @@ class TestCipherContext(object):
def test_nonexistent_cipher(self, backend):
cipher = Cipher(
- DummyCipher(), object(), backend
+ DummyCipher(), DummyMode(), backend
)
with pytest.raises(UnsupportedAlgorithm):
cipher.encryptor()
@@ -120,3 +125,37 @@ class TestCipherContext(object):
decryptor.update(b"1")
with pytest.raises(ValueError):
decryptor.finalize()
+
+
+class TestModeValidation(object):
+ def test_cbc(self, backend):
+ with pytest.raises(ValueError):
+ Cipher(
+ algorithms.AES(b"\x00" * 16),
+ modes.CBC(b"abc"),
+ backend,
+ )
+
+ def test_ofb(self, backend):
+ with pytest.raises(ValueError):
+ Cipher(
+ algorithms.AES(b"\x00" * 16),
+ modes.OFB(b"abc"),
+ backend,
+ )
+
+ def test_cfb(self, backend):
+ with pytest.raises(ValueError):
+ Cipher(
+ algorithms.AES(b"\x00" * 16),
+ modes.CFB(b"abc"),
+ backend,
+ )
+
+ def test_ctr(self, backend):
+ with pytest.raises(ValueError):
+ Cipher(
+ algorithms.AES(b"\x00" * 16),
+ modes.CTR(b"abc"),
+ backend,
+ )