From 0d492db1be3e287b5f49a5ce408196401bdd0a2b Mon Sep 17 00:00:00 2001 From: David Reid Date: Mon, 27 Jan 2014 17:05:49 -0800 Subject: Closer to proposed interface in #513. --- cryptography/hazmat/primitives/kdf/hkdf.py | 56 ++++++++++++++++-------------- tests/hazmat/primitives/utils.py | 32 +++++++++-------- 2 files changed, 47 insertions(+), 41 deletions(-) diff --git a/cryptography/hazmat/primitives/kdf/hkdf.py b/cryptography/hazmat/primitives/kdf/hkdf.py index 3f3897c1..f2ea114b 100644 --- a/cryptography/hazmat/primitives/kdf/hkdf.py +++ b/cryptography/hazmat/primitives/kdf/hkdf.py @@ -16,38 +16,40 @@ import six from cryptography.hazmat.primitives import hmac -def hkdf_extract(algorithm, ikm, salt, backend): - h = hmac.HMAC(salt, algorithm, backend=backend) - h.update(ikm) - return h.finalize() +class HKDF(object): + def __init__(self, algorithm, length, salt, info, backend): + self._algorithm = algorithm + self._length = length + if salt is None: + salt = b"\x00" * (self._algorithm.digest_size // 8) -def hkdf_expand(algorithm, prk, info, length, backend): - output = [b''] - counter = 1 + self._salt = salt - while (algorithm.digest_size // 8) * len(output) < length: - h = hmac.HMAC(prk, algorithm, backend=backend) - h.update(output[-1]) - h.update(info) - h.update(six.int2byte(counter)) - output.append(h.finalize()) - counter += 1 + if info is None: + info = b"" - return b"".join(output)[:length] + self._info = info + self._backend = backend + def extract(self, key_material): + h = hmac.HMAC(self._salt, self._algorithm, backend=self._backend) + h.update(key_material) + return h.finalize() -def hkdf_derive(key, length, salt, info, algorithm, backend): - if info is None: - info = b"" + def expand(self, key_material): + output = [b''] + counter = 1 - if salt is None: - salt = b"\x00" * (algorithm.digest_size // 8) + while (self._algorithm.digest_size // 8) * len(output) < self._length: + h = hmac.HMAC(key_material, self._algorithm, backend=self._backend) + h.update(output[-1]) + h.update(self._info) + h.update(six.int2byte(counter)) + output.append(h.finalize()) + counter += 1 - return hkdf_expand( - algorithm, - hkdf_extract(algorithm, key, salt, backend=backend), - info, - length, - backend=backend - ) + return b"".join(output)[:self._length] + + def derive(self, key_material): + return self.expand(self.extract(key_material)) diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index 9e9088a3..2584272a 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -8,9 +8,7 @@ import pytest from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.ciphers import Cipher -from cryptography.hazmat.primitives.kdf.hkdf import ( - hkdf_derive, hkdf_extract, hkdf_expand -) +from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.exceptions import ( AlreadyFinalized, NotYetFinalized, AlreadyUpdated, InvalidTag, @@ -306,38 +304,44 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory): def hkdf_derive_test(backend, algorithm, params): - okm = hkdf_derive( - binascii.unhexlify(params["ikm"]), - int(params["l"]), - binascii.unhexlify(params["salt"]), - binascii.unhexlify(params["info"]), + hkdf = HKDF( algorithm, + int(params["l"]), + salt=binascii.unhexlify(params["salt"]) or None, + info=binascii.unhexlify(params["info"]) or None, backend=backend ) + okm = hkdf.derive(binascii.unhexlify(params["ikm"])) + assert okm == binascii.unhexlify(params["okm"]) def hkdf_extract_test(backend, algorithm, params): - prk = hkdf_extract( + hkdf = HKDF( algorithm, - binascii.unhexlify(params["ikm"]), - binascii.unhexlify(params["salt"]), + int(params["l"]), + salt=binascii.unhexlify(params["salt"]) or None, + info=binascii.unhexlify(params["info"]) or None, backend=backend ) + prk = hkdf.extract(binascii.unhexlify(params["ikm"])) + assert prk == binascii.unhexlify(params["prk"]) def hkdf_expand_test(backend, algorithm, params): - okm = hkdf_expand( + hkdf = HKDF( algorithm, - binascii.unhexlify(params["prk"]), - binascii.unhexlify(params["info"]), int(params["l"]), + salt=binascii.unhexlify(params["salt"]) or None, + info=binascii.unhexlify(params["info"]) or None, backend=backend ) + okm = hkdf.expand(binascii.unhexlify(params["prk"])) + assert okm == binascii.unhexlify(params["okm"]) -- cgit v1.2.3