diff options
-rw-r--r-- | src/cryptography/hazmat/primitives/keywrap.py | 9 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_keywrap.py | 24 |
2 files changed, 31 insertions, 2 deletions
diff --git a/src/cryptography/hazmat/primitives/keywrap.py b/src/cryptography/hazmat/primitives/keywrap.py index 3b531318..2b7955f8 100644 --- a/src/cryptography/hazmat/primitives/keywrap.py +++ b/src/cryptography/hazmat/primitives/keywrap.py @@ -118,11 +118,16 @@ def aes_key_unwrap_with_padding(wrapping_key, wrapped_key, backend): b = (8 * n) - mli if ( not bytes_eq(a[:4], b"\xa6\x59\x59\xa6") or not - 8 * (n - 1) < mli <= 8 * n or not bytes_eq(data[-b:], b"\x00" * b) + 8 * (n - 1) < mli <= 8 * n or ( + b != 0 and not bytes_eq(data[-b:], b"\x00" * b) + ) ): raise InvalidUnwrap() - return data[:-b] + if b == 0: + return data + else: + return data[:-b] def aes_key_unwrap(wrapping_key, wrapped_key, backend): diff --git a/tests/hazmat/primitives/test_keywrap.py b/tests/hazmat/primitives/test_keywrap.py index 8311c2a4..9b1e43e4 100644 --- a/tests/hazmat/primitives/test_keywrap.py +++ b/tests/hazmat/primitives/test_keywrap.py @@ -143,6 +143,18 @@ class TestAESKeyWrapWithPadding(object): @pytest.mark.parametrize( "params", + _load_all_params("keywrap", ["kwp_botan.txt"], load_nist_vectors) + ) + def test_wrap_additional_vectors(self, backend, params): + wrapping_key = binascii.unhexlify(params["key"]) + key_to_wrap = binascii.unhexlify(params["input"]) + wrapped_key = keywrap.aes_key_wrap_with_padding( + wrapping_key, key_to_wrap, backend + ) + assert wrapped_key == binascii.unhexlify(params["output"]) + + @pytest.mark.parametrize( + "params", _load_all_params( os.path.join("keywrap", "kwtestvectors"), ["KWP_AD_128.txt", "KWP_AD_192.txt", "KWP_AD_256.txt"], @@ -163,6 +175,18 @@ class TestAESKeyWrapWithPadding(object): ) assert params["p"] == binascii.hexlify(unwrapped_key) + @pytest.mark.parametrize( + "params", + _load_all_params("keywrap", ["kwp_botan.txt"], load_nist_vectors) + ) + def test_unwrap_additional_vectors(self, backend, params): + wrapping_key = binascii.unhexlify(params["key"]) + wrapped_key = binascii.unhexlify(params["output"]) + unwrapped_key = keywrap.aes_key_unwrap_with_padding( + wrapping_key, wrapped_key, backend + ) + assert unwrapped_key == binascii.unhexlify(params["input"]) + def test_unwrap_invalid_wrapped_key_length(self, backend): # Keys to unwrap must be at least 16 bytes with pytest.raises(ValueError, match='Must be at least 16 bytes'): |