diff options
author | David Reid <dreid@dreid.org> | 2014-01-03 16:02:51 -0800 |
---|---|---|
committer | David Reid <dreid@dreid.org> | 2014-01-03 16:02:51 -0800 |
commit | 24c9a8d153ed7b1520a87d2ad22d9e9b26f272b8 (patch) | |
tree | dc1ea9ab698f092f1310d6adf8077f1ef9ad95d5 | |
parent | f96db83a64bb0ac40d04d27383d7c2defbcec491 (diff) | |
parent | 267dbc946b4584b7b4ed10a439b2820d3b048356 (diff) | |
download | cryptography-24c9a8d153ed7b1520a87d2ad22d9e9b26f272b8.tar.gz cryptography-24c9a8d153ed7b1520a87d2ad22d9e9b26f272b8.tar.bz2 cryptography-24c9a8d153ed7b1520a87d2ad22d9e9b26f272b8.zip |
Merge pull request #272 from alex/validate-iv
Validate the IV/nonce length for a given algorithm.
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/base.py | 3 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/modes.py | 33 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/interfaces.py | 7 | ||||
-rw-r--r-- | docs/hazmat/primitives/interfaces.rst | 12 | ||||
-rw-r--r-- | tests/hazmat/backends/test_openssl.py | 3 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_block.py | 37 |
6 files changed, 95 insertions, 0 deletions
diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py index 1da0802c..d366e4cf 100644 --- a/cryptography/hazmat/primitives/ciphers/base.py +++ b/cryptography/hazmat/primitives/ciphers/base.py @@ -25,6 +25,9 @@ class Cipher(object): if not isinstance(algorithm, interfaces.CipherAlgorithm): raise TypeError("Expected interface of interfaces.CipherAlgorithm") + if mode is not None: + mode.validate_for_algorithm(algorithm) + self.algorithm = algorithm self.mode = mode self._backend = backend diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py index ab8501c6..739f23dd 100644 --- a/cryptography/hazmat/primitives/ciphers/modes.py +++ b/cryptography/hazmat/primitives/ciphers/modes.py @@ -25,11 +25,20 @@ class CBC(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) class ECB(object): name = "ECB" + def validate_for_algorithm(self, algorithm): + pass + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -39,6 +48,12 @@ class OFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -48,6 +63,12 @@ class CFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithNonce) @@ -57,6 +78,12 @@ class CTR(object): def __init__(self, nonce): self.nonce = nonce + def validate_for_algorithm(self, algorithm): + if len(self.nonce) * 8 != algorithm.block_size: + raise ValueError("Invalid nonce size ({0}) for {1}".format( + len(self.nonce), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -65,6 +92,9 @@ class GCM(object): name = "GCM" def __init__(self, initialization_vector, tag=None): + # len(initialization_vector) must in [1, 2 ** 64), but it's impossible + # to actually construct a bytes object that large, so we don't check + # for it if tag is not None and len(tag) < 4: raise ValueError( "Authentication tag must be 4 bytes or longer" @@ -72,3 +102,6 @@ class GCM(object): self.initialization_vector = initialization_vector self.tag = tag + + def validate_for_algorithm(self, algorithm): + pass diff --git a/cryptography/hazmat/primitives/interfaces.py b/cryptography/hazmat/primitives/interfaces.py index e87c9ca9..7a6bf3e2 100644 --- a/cryptography/hazmat/primitives/interfaces.py +++ b/cryptography/hazmat/primitives/interfaces.py @@ -47,6 +47,13 @@ class Mode(six.with_metaclass(abc.ABCMeta)): A string naming this mode (e.g. "ECB", "CBC"). """ + @abc.abstractmethod + def validate_for_algorithm(self, algorithm): + """ + Checks that all the necessary invariants of this (mode, algorithm) + combination are met. + """ + class ModeWithInitializationVector(six.with_metaclass(abc.ABCMeta)): @abc.abstractproperty diff --git a/docs/hazmat/primitives/interfaces.rst b/docs/hazmat/primitives/interfaces.rst index 361b723e..edb24cd9 100644 --- a/docs/hazmat/primitives/interfaces.rst +++ b/docs/hazmat/primitives/interfaces.rst @@ -67,6 +67,18 @@ Interfaces used by the symmetric cipher modes described in The name may be used by a backend to influence the operation of a cipher in conjunction with the algorithm's name. + .. method:: validate_for_algorithm(algorithm) + + :param CipherAlgorithm algorithm: + + Checks that the combination of this mode with the provided algorithm + meets any necessary invariants. This should raise an exception if they + are not met. + + For example, the :class:`~cryptography.hazmat.primitives.modes.CBC` + mode uses this method to check that the provided initialization + vector's length matches the block size of the algorithm. + .. class:: ModeWithInitializationVector diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index 22cfbe71..ad399594 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -27,6 +27,9 @@ from cryptography.hazmat.primitives.ciphers.modes import CBC class DummyMode(object): name = "dummy-mode" + def validate_for_algorithm(self, algorithm): + pass + @utils.register_interface(interfaces.CipherAlgorithm) class DummyCipher(object): diff --git a/tests/hazmat/primitives/test_block.py b/tests/hazmat/primitives/test_block.py index 30cf1d60..f758ffaa 100644 --- a/tests/hazmat/primitives/test_block.py +++ b/tests/hazmat/primitives/test_block.py @@ -35,6 +35,9 @@ from .utils import ( class DummyMode(object): name = "dummy-mode" + def validate_for_algorithm(self, algorithm): + pass + @utils.register_interface(interfaces.CipherAlgorithm) class DummyCipher(object): @@ -152,3 +155,37 @@ class TestAEADCipherContext(object): algorithms.AES, modes.GCM, ) + + +class TestModeValidation(object): + def test_cbc(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CBC(b"abc"), + backend, + ) + + def test_ofb(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.OFB(b"abc"), + backend, + ) + + def test_cfb(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CFB(b"abc"), + backend, + ) + + def test_ctr(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CTR(b"abc"), + backend, + ) |