diff options
-rw-r--r-- | docs/hazmat/primitives/key-derivation-functions.rst | 3 | ||||
-rw-r--r-- | src/cryptography/hazmat/primitives/kdf/scrypt.py | 10 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_scrypt.py | 17 |
3 files changed, 30 insertions, 0 deletions
diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst index 03260c06..511708d6 100644 --- a/docs/hazmat/primitives/key-derivation-functions.rst +++ b/docs/hazmat/primitives/key-derivation-functions.rst @@ -805,6 +805,9 @@ Different KDFs are suitable for different tasks such as: :class:`~cryptography.hazmat.backends.interfaces.ScryptBackend` :raises TypeError: This exception is raised if ``salt`` is not ``bytes``. + :raises ValueError: This exception is raised if ``n`` is less than 2, if + ``n`` is not a power of 2, if ``r`` is less than 1 or if ``p`` is less + than 1. .. method:: derive(key_material) diff --git a/src/cryptography/hazmat/primitives/kdf/scrypt.py b/src/cryptography/hazmat/primitives/kdf/scrypt.py index 09181d97..20935409 100644 --- a/src/cryptography/hazmat/primitives/kdf/scrypt.py +++ b/src/cryptography/hazmat/primitives/kdf/scrypt.py @@ -25,6 +25,16 @@ class Scrypt(object): self._length = length if not isinstance(salt, bytes): raise TypeError("salt must be bytes.") + + if n < 2 or (n & (n - 1)) != 0: + raise ValueError("n must be greater than 1 and be a power of 2.") + + if r < 1: + raise ValueError("r must be greater than or equal to 1.") + + if p < 1: + raise ValueError("p must be greater than or equal to 1.") + self._used = False self._salt = salt self._n = n diff --git a/tests/hazmat/primitives/test_scrypt.py b/tests/hazmat/primitives/test_scrypt.py index de4100e3..49b304e0 100644 --- a/tests/hazmat/primitives/test_scrypt.py +++ b/tests/hazmat/primitives/test_scrypt.py @@ -117,3 +117,20 @@ class TestScrypt(object): scrypt.derive(password) with pytest.raises(AlreadyFinalized): scrypt.derive(password) + + def test_invalid_n(self, backend): + # n is less than 2 + with pytest.raises(ValueError): + Scrypt(b"NaCl", 64, 1, 8, 16, backend) + + # n is not a power of 2 + with pytest.raises(ValueError): + Scrypt(b"NaCl", 64, 3, 8, 16, backend) + + def test_invalid_r(self, backend): + with pytest.raises(ValueError): + Scrypt(b"NaCl", 64, 2, 0, 16, backend) + + def test_invalid_p(self, backend): + with pytest.raises(ValueError): + Scrypt(b"NaCl", 64, 2, 8, 0, backend) |