aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/hazmat/primitives/key-derivation-functions.rst3
-rw-r--r--src/cryptography/hazmat/primitives/kdf/scrypt.py10
-rw-r--r--tests/hazmat/primitives/test_scrypt.py17
3 files changed, 30 insertions, 0 deletions
diff --git a/docs/hazmat/primitives/key-derivation-functions.rst b/docs/hazmat/primitives/key-derivation-functions.rst
index 03260c06..511708d6 100644
--- a/docs/hazmat/primitives/key-derivation-functions.rst
+++ b/docs/hazmat/primitives/key-derivation-functions.rst
@@ -805,6 +805,9 @@ Different KDFs are suitable for different tasks such as:
:class:`~cryptography.hazmat.backends.interfaces.ScryptBackend`
:raises TypeError: This exception is raised if ``salt`` is not ``bytes``.
+ :raises ValueError: This exception is raised if ``n`` is less than 2, if
+ ``n`` is not a power of 2, if ``r`` is less than 1 or if ``p`` is less
+ than 1.
.. method:: derive(key_material)
diff --git a/src/cryptography/hazmat/primitives/kdf/scrypt.py b/src/cryptography/hazmat/primitives/kdf/scrypt.py
index 09181d97..20935409 100644
--- a/src/cryptography/hazmat/primitives/kdf/scrypt.py
+++ b/src/cryptography/hazmat/primitives/kdf/scrypt.py
@@ -25,6 +25,16 @@ class Scrypt(object):
self._length = length
if not isinstance(salt, bytes):
raise TypeError("salt must be bytes.")
+
+ if n < 2 or (n & (n - 1)) != 0:
+ raise ValueError("n must be greater than 1 and be a power of 2.")
+
+ if r < 1:
+ raise ValueError("r must be greater than or equal to 1.")
+
+ if p < 1:
+ raise ValueError("p must be greater than or equal to 1.")
+
self._used = False
self._salt = salt
self._n = n
diff --git a/tests/hazmat/primitives/test_scrypt.py b/tests/hazmat/primitives/test_scrypt.py
index de4100e3..49b304e0 100644
--- a/tests/hazmat/primitives/test_scrypt.py
+++ b/tests/hazmat/primitives/test_scrypt.py
@@ -117,3 +117,20 @@ class TestScrypt(object):
scrypt.derive(password)
with pytest.raises(AlreadyFinalized):
scrypt.derive(password)
+
+ def test_invalid_n(self, backend):
+ # n is less than 2
+ with pytest.raises(ValueError):
+ Scrypt(b"NaCl", 64, 1, 8, 16, backend)
+
+ # n is not a power of 2
+ with pytest.raises(ValueError):
+ Scrypt(b"NaCl", 64, 3, 8, 16, backend)
+
+ def test_invalid_r(self, backend):
+ with pytest.raises(ValueError):
+ Scrypt(b"NaCl", 64, 2, 0, 16, backend)
+
+ def test_invalid_p(self, backend):
+ with pytest.raises(ValueError):
+ Scrypt(b"NaCl", 64, 2, 8, 0, backend)