diff options
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/base.py | 3 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/modes.py | 27 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/interfaces.py | 7 | ||||
-rw-r--r-- | docs/hazmat/primitives/interfaces.rst | 12 | ||||
-rw-r--r-- | tests/hazmat/bindings/test_openssl.py | 3 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_block.py | 41 |
6 files changed, 91 insertions, 2 deletions
diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py index 48e6da6f..9f84d667 100644 --- a/cryptography/hazmat/primitives/ciphers/base.py +++ b/cryptography/hazmat/primitives/ciphers/base.py @@ -23,6 +23,9 @@ class Cipher(object): if not isinstance(algorithm, interfaces.CipherAlgorithm): raise TypeError("Expected interface of interfaces.CipherAlgorithm") + if mode is not None: + mode.validate_for_algorithm(algorithm) + self.algorithm = algorithm self.mode = mode self._backend = backend diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py index 1d0de689..63fa163c 100644 --- a/cryptography/hazmat/primitives/ciphers/modes.py +++ b/cryptography/hazmat/primitives/ciphers/modes.py @@ -25,11 +25,20 @@ class CBC(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) class ECB(object): name = "ECB" + def validate_for_algorithm(self, algorithm): + pass + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -39,6 +48,12 @@ class OFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -48,6 +63,12 @@ class CFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithNonce) @@ -56,3 +77,9 @@ class CTR(object): def __init__(self, nonce): self.nonce = nonce + + def validate_for_algorithm(self, algorithm): + if len(self.nonce) * 8 != algorithm.block_size: + raise ValueError("Invalid nonce size ({0}) for {1}".format( + len(self.nonce), self.name + )) diff --git a/cryptography/hazmat/primitives/interfaces.py b/cryptography/hazmat/primitives/interfaces.py index 8cc9d42c..672ac96a 100644 --- a/cryptography/hazmat/primitives/interfaces.py +++ b/cryptography/hazmat/primitives/interfaces.py @@ -39,6 +39,13 @@ class Mode(six.with_metaclass(abc.ABCMeta)): A string naming this mode. (e.g. ECB, CBC) """ + @abc.abstractmethod + def validate_for_algorithm(self, algorithm): + """ + Checks that all the necessary invariants of this (mode, algorithm) + combination are met. + """ + class ModeWithInitializationVector(six.with_metaclass(abc.ABCMeta)): @abc.abstractproperty diff --git a/docs/hazmat/primitives/interfaces.rst b/docs/hazmat/primitives/interfaces.rst index 11cff51a..e798c0e6 100644 --- a/docs/hazmat/primitives/interfaces.rst +++ b/docs/hazmat/primitives/interfaces.rst @@ -56,6 +56,18 @@ Interfaces used by the symmetric cipher modes described in The name may be used by a backend to influence the operation of a cipher in conjunction with the algorithm's name. + .. method:: validate_for_algorithm(algorithm) + + :param CipherAlgorithm algorithm: + + Checks that the combination of this mode with the provided algorithm + meets any necessary invariants. This should raise an exception if they + are not met. + + For example, the :class:`~cryptography.hazmat.primitives.modes.CBC` + mode uses this method to check that the provided initialization + vector's length matches the block size of the algorithm. + .. class:: ModeWithInitializationVector diff --git a/tests/hazmat/bindings/test_openssl.py b/tests/hazmat/bindings/test_openssl.py index 1eb6f200..4923d698 100644 --- a/tests/hazmat/bindings/test_openssl.py +++ b/tests/hazmat/bindings/test_openssl.py @@ -24,7 +24,8 @@ from cryptography.hazmat.primitives.ciphers.modes import CBC class DummyMode(object): - pass + def validate_for_algorithm(self, algorithm): + pass @utils.register_interface(interfaces.CipherAlgorithm) diff --git a/tests/hazmat/primitives/test_block.py b/tests/hazmat/primitives/test_block.py index 4a8e88b4..2a3e82d4 100644 --- a/tests/hazmat/primitives/test_block.py +++ b/tests/hazmat/primitives/test_block.py @@ -30,6 +30,11 @@ class DummyCipher(object): pass +class DummyMode(object): + def validate_for_algorithm(self, algorithm): + pass + + class TestCipher(object): def test_creates_encryptor(self, backend): cipher = Cipher( @@ -97,7 +102,7 @@ class TestCipherContext(object): def test_nonexistent_cipher(self, backend): cipher = Cipher( - DummyCipher(), object(), backend + DummyCipher(), DummyMode(), backend ) with pytest.raises(UnsupportedAlgorithm): cipher.encryptor() @@ -120,3 +125,37 @@ class TestCipherContext(object): decryptor.update(b"1") with pytest.raises(ValueError): decryptor.finalize() + + +class TestModeValidation(object): + def test_cbc(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CBC(b"abc"), + backend, + ) + + def test_ofb(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.OFB(b"abc"), + backend, + ) + + def test_cfb(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CFB(b"abc"), + backend, + ) + + def test_ctr(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CTR(b"abc"), + backend, + ) |