diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/hazmat/primitives/test_hash_vectors.py | 77 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_hashes.py | 21 | ||||
-rw-r--r-- | tests/utils.py | 2 |
3 files changed, 97 insertions, 3 deletions
diff --git a/tests/hazmat/primitives/test_hash_vectors.py b/tests/hazmat/primitives/test_hash_vectors.py index f8561fcd..5225a00b 100644 --- a/tests/hazmat/primitives/test_hash_vectors.py +++ b/tests/hazmat/primitives/test_hash_vectors.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function +import binascii import os import pytest @@ -11,8 +12,8 @@ import pytest from cryptography.hazmat.backends.interfaces import HashBackend from cryptography.hazmat.primitives import hashes -from .utils import generate_hash_test -from ...utils import load_hash_vectors +from .utils import _load_all_params, generate_hash_test +from ...utils import load_hash_vectors, load_nist_vectors @pytest.mark.supported( @@ -250,3 +251,75 @@ class TestSHA3512(object): ], hashes.SHA3_512(), ) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported( + hashes.SHAKE128(digest_size=16)), + skip_message="Does not support SHAKE128", +) +@pytest.mark.requires_backend_interface(interface=HashBackend) +class TestSHAKE128(object): + test_shake128 = generate_hash_test( + load_hash_vectors, + os.path.join("hashes", "SHAKE"), + [ + "SHAKE128LongMsg.rsp", + "SHAKE128ShortMsg.rsp", + ], + hashes.SHAKE128(digest_size=16), + ) + + @pytest.mark.parametrize( + "vector", + _load_all_params( + os.path.join("hashes", "SHAKE"), + [ + "SHAKE128VariableOut.rsp", + ], + load_nist_vectors, + ) + ) + def test_shake128_variable(self, vector, backend): + output_length = int(vector['outputlen']) // 8 + msg = binascii.unhexlify(vector['msg']) + shake = hashes.SHAKE128(digest_size=output_length) + m = hashes.Hash(shake, backend=backend) + m.update(msg) + assert m.finalize() == binascii.unhexlify(vector['output']) + + +@pytest.mark.supported( + only_if=lambda backend: backend.hash_supported( + hashes.SHAKE256(digest_size=32)), + skip_message="Does not support SHAKE256", +) +@pytest.mark.requires_backend_interface(interface=HashBackend) +class TestSHAKE256(object): + test_shake256 = generate_hash_test( + load_hash_vectors, + os.path.join("hashes", "SHAKE"), + [ + "SHAKE256LongMsg.rsp", + "SHAKE256ShortMsg.rsp", + ], + hashes.SHAKE256(digest_size=32), + ) + + @pytest.mark.parametrize( + "vector", + _load_all_params( + os.path.join("hashes", "SHAKE"), + [ + "SHAKE256VariableOut.rsp", + ], + load_nist_vectors, + ) + ) + def test_shake256_variable(self, vector, backend): + output_length = int(vector['outputlen']) // 8 + msg = binascii.unhexlify(vector['msg']) + shake = hashes.SHAKE256(digest_size=output_length) + m = hashes.Hash(shake, backend=backend) + m.update(msg) + assert m.finalize() == binascii.unhexlify(vector['output']) diff --git a/tests/hazmat/primitives/test_hashes.py b/tests/hazmat/primitives/test_hashes.py index 6cba84b5..b10fadcd 100644 --- a/tests/hazmat/primitives/test_hashes.py +++ b/tests/hazmat/primitives/test_hashes.py @@ -179,3 +179,24 @@ def test_buffer_protocol_hash(backend): assert h.finalize() == binascii.unhexlify( b"dff2e73091f6c05e528896c4c831b9448653dc2ff043528f6769437bc7b975c2" ) + + +class TestSHAKE(object): + @pytest.mark.parametrize( + "xof", + [hashes.SHAKE128, hashes.SHAKE256] + ) + def test_invalid_digest_type(self, xof): + with pytest.raises(TypeError): + xof(digest_size=object()) + + @pytest.mark.parametrize( + "xof", + [hashes.SHAKE128, hashes.SHAKE256] + ) + def test_invalid_digest_size(self, xof): + with pytest.raises(ValueError): + xof(digest_size=-5) + + with pytest.raises(ValueError): + xof(digest_size=0) diff --git a/tests/utils.py b/tests/utils.py index 364a349b..b4812808 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -134,7 +134,7 @@ def load_hash_vectors(vector_data): # string as hex 00, which is of course not actually an empty # string. So we parse the provided length and catch this edge case. msg = line.split(" = ")[1].encode("ascii") if length > 0 else b"" - elif line.startswith("MD"): + elif line.startswith("MD") or line.startswith("Output"): md = line.split(" = ")[1] # after MD is found the Msg+MD (+ potential key) tuple is complete if key is not None: |