aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Reid <dreid@dreid.org>2014-01-27 17:05:49 -0800
committerDavid Reid <dreid@dreid.org>2014-02-03 10:05:27 -0800
commit0d492db1be3e287b5f49a5ce408196401bdd0a2b (patch)
tree8c9cbc30464819cfd173d751600d92e21fb7a836
parent14367303f16bc271f4a8f11f09b02342f44c3a7e (diff)
downloadcryptography-0d492db1be3e287b5f49a5ce408196401bdd0a2b.tar.gz
cryptography-0d492db1be3e287b5f49a5ce408196401bdd0a2b.tar.bz2
cryptography-0d492db1be3e287b5f49a5ce408196401bdd0a2b.zip
Closer to proposed interface in #513.
-rw-r--r--cryptography/hazmat/primitives/kdf/hkdf.py56
-rw-r--r--tests/hazmat/primitives/utils.py32
2 files changed, 47 insertions, 41 deletions
diff --git a/cryptography/hazmat/primitives/kdf/hkdf.py b/cryptography/hazmat/primitives/kdf/hkdf.py
index 3f3897c1..f2ea114b 100644
--- a/cryptography/hazmat/primitives/kdf/hkdf.py
+++ b/cryptography/hazmat/primitives/kdf/hkdf.py
@@ -16,38 +16,40 @@ import six
from cryptography.hazmat.primitives import hmac
-def hkdf_extract(algorithm, ikm, salt, backend):
- h = hmac.HMAC(salt, algorithm, backend=backend)
- h.update(ikm)
- return h.finalize()
+class HKDF(object):
+ def __init__(self, algorithm, length, salt, info, backend):
+ self._algorithm = algorithm
+ self._length = length
+ if salt is None:
+ salt = b"\x00" * (self._algorithm.digest_size // 8)
-def hkdf_expand(algorithm, prk, info, length, backend):
- output = [b'']
- counter = 1
+ self._salt = salt
- while (algorithm.digest_size // 8) * len(output) < length:
- h = hmac.HMAC(prk, algorithm, backend=backend)
- h.update(output[-1])
- h.update(info)
- h.update(six.int2byte(counter))
- output.append(h.finalize())
- counter += 1
+ if info is None:
+ info = b""
- return b"".join(output)[:length]
+ self._info = info
+ self._backend = backend
+ def extract(self, key_material):
+ h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend)
+ h.update(key_material)
+ return h.finalize()
-def hkdf_derive(key, length, salt, info, algorithm, backend):
- if info is None:
- info = b""
+ def expand(self, key_material):
+ output = [b'']
+ counter = 1
- if salt is None:
- salt = b"\x00" * (algorithm.digest_size // 8)
+ while (self._algorithm.digest_size // 8) * len(output) < self._length:
+ h = hmac.HMAC(key_material, self._algorithm, backend=self._backend)
+ h.update(output[-1])
+ h.update(self._info)
+ h.update(six.int2byte(counter))
+ output.append(h.finalize())
+ counter += 1
- return hkdf_expand(
- algorithm,
- hkdf_extract(algorithm, key, salt, backend=backend),
- info,
- length,
- backend=backend
- )
+ return b"".join(output)[:self._length]
+
+ def derive(self, key_material):
+ return self.expand(self.extract(key_material))
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 9e9088a3..2584272a 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -8,9 +8,7 @@ import pytest
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.ciphers import Cipher
-from cryptography.hazmat.primitives.kdf.hkdf import (
- hkdf_derive, hkdf_extract, hkdf_expand
-)
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.exceptions import (
AlreadyFinalized, NotYetFinalized, AlreadyUpdated, InvalidTag,
@@ -306,38 +304,44 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory):
def hkdf_derive_test(backend, algorithm, params):
- okm = hkdf_derive(
- binascii.unhexlify(params["ikm"]),
- int(params["l"]),
- binascii.unhexlify(params["salt"]),
- binascii.unhexlify(params["info"]),
+ hkdf = HKDF(
algorithm,
+ int(params["l"]),
+ salt=binascii.unhexlify(params["salt"]) or None,
+ info=binascii.unhexlify(params["info"]) or None,
backend=backend
)
+ okm = hkdf.derive(binascii.unhexlify(params["ikm"]))
+
assert okm == binascii.unhexlify(params["okm"])
def hkdf_extract_test(backend, algorithm, params):
- prk = hkdf_extract(
+ hkdf = HKDF(
algorithm,
- binascii.unhexlify(params["ikm"]),
- binascii.unhexlify(params["salt"]),
+ int(params["l"]),
+ salt=binascii.unhexlify(params["salt"]) or None,
+ info=binascii.unhexlify(params["info"]) or None,
backend=backend
)
+ prk = hkdf.extract(binascii.unhexlify(params["ikm"]))
+
assert prk == binascii.unhexlify(params["prk"])
def hkdf_expand_test(backend, algorithm, params):
- okm = hkdf_expand(
+ hkdf = HKDF(
algorithm,
- binascii.unhexlify(params["prk"]),
- binascii.unhexlify(params["info"]),
int(params["l"]),
+ salt=binascii.unhexlify(params["salt"]) or None,
+ info=binascii.unhexlify(params["info"]) or None,
backend=backend
)
+ okm = hkdf.expand(binascii.unhexlify(params["prk"]))
+
assert okm == binascii.unhexlify(params["okm"])