From afaa4c65b58e31f36a8162f88ad79829f68d0a86 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Sun, 2 Dec 2018 19:20:33 +0800 Subject: centralize our bytes check (#4622) this will make life a bit easier when we support bytearrays --- src/cryptography/fernet.py | 8 +++--- .../hazmat/backends/openssl/backend.py | 8 +++--- src/cryptography/hazmat/backends/openssl/dsa.py | 3 +-- src/cryptography/hazmat/backends/openssl/ec.py | 3 +-- src/cryptography/hazmat/backends/openssl/rsa.py | 3 +-- .../hazmat/primitives/ciphers/algorithms.py | 3 +-- .../hazmat/primitives/ciphers/modes.py | 29 ++++++---------------- src/cryptography/hazmat/primitives/cmac.py | 7 +++--- src/cryptography/hazmat/primitives/hashes.py | 3 +-- src/cryptography/hazmat/primitives/hmac.py | 6 ++--- .../hazmat/primitives/kdf/concatkdf.py | 13 +++++----- src/cryptography/hazmat/primitives/kdf/hkdf.py | 18 +++++--------- src/cryptography/hazmat/primitives/kdf/kbkdf.py | 9 +++---- src/cryptography/hazmat/primitives/kdf/pbkdf2.py | 6 ++--- src/cryptography/hazmat/primitives/kdf/scrypt.py | 7 ++---- src/cryptography/hazmat/primitives/kdf/x963kdf.py | 10 +++----- src/cryptography/hazmat/primitives/padding.py | 6 ++--- 17 files changed, 50 insertions(+), 92 deletions(-) diff --git a/src/cryptography/fernet.py b/src/cryptography/fernet.py index ac2dd0b6..b990defa 100644 --- a/src/cryptography/fernet.py +++ b/src/cryptography/fernet.py @@ -12,6 +12,7 @@ import time import six +from cryptography import utils from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, padding @@ -51,8 +52,7 @@ class Fernet(object): return self._encrypt_from_parts(data, current_time, iv) def _encrypt_from_parts(self, data, current_time, iv): - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + utils._check_bytes("data", data) padder = padding.PKCS7(algorithms.AES.block_size).padder() padded_data = padder.update(data) + padder.finalize() @@ -82,9 +82,7 @@ class Fernet(object): @staticmethod def _get_unverified_token_data(token): - if not isinstance(token, bytes): - raise TypeError("token must be bytes.") - + utils._check_bytes("token", token) try: data = base64.urlsafe_b64decode(token) except (TypeError, binascii.Error): diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 5a22a555..ae966cd0 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -1203,8 +1203,8 @@ class Backend(object): def _load_key(self, openssl_read_func, convert_func, data, password): mem_bio = self._bytes_to_bio(data) - if password is not None and not isinstance(password, bytes): - raise TypeError("Password must be bytes") + if password is not None: + utils._check_bytes("password", password) userdata = self._ffi.new("CRYPTOGRAPHY_PASSWORD_DATA *") if password is not None: @@ -2132,8 +2132,8 @@ class Backend(object): def load_key_and_certificates_from_pkcs12(self, data, password): if password is None: password = self._ffi.NULL - elif not isinstance(password, bytes): - raise TypeError("Password must be a byte string or None") + else: + utils._check_bytes("password", password) bio = self._bytes_to_bio(data) p12 = self._lib.d2i_PKCS12_bio(bio.bio, self._ffi.NULL) diff --git a/src/cryptography/hazmat/backends/openssl/dsa.py b/src/cryptography/hazmat/backends/openssl/dsa.py index 48886e45..de61f089 100644 --- a/src/cryptography/hazmat/backends/openssl/dsa.py +++ b/src/cryptography/hazmat/backends/openssl/dsa.py @@ -211,8 +211,7 @@ class _DSAPublicKey(object): def verifier(self, signature, signature_algorithm): _warn_sign_verify_deprecated() - if not isinstance(signature, bytes): - raise TypeError("signature must be bytes.") + utils._check_bytes("signature", signature) _check_not_prehashed(signature_algorithm) return _DSAVerificationContext( diff --git a/src/cryptography/hazmat/backends/openssl/ec.py b/src/cryptography/hazmat/backends/openssl/ec.py index 69da2344..852b4918 100644 --- a/src/cryptography/hazmat/backends/openssl/ec.py +++ b/src/cryptography/hazmat/backends/openssl/ec.py @@ -244,8 +244,7 @@ class _EllipticCurvePublicKey(object): def verifier(self, signature, signature_algorithm): _warn_sign_verify_deprecated() - if not isinstance(signature, bytes): - raise TypeError("signature must be bytes.") + utils._check_bytes("signature", signature) _check_signature_algorithm(signature_algorithm) _check_not_prehashed(signature_algorithm.algorithm) diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py index 00f5e377..b7d2173f 100644 --- a/src/cryptography/hazmat/backends/openssl/rsa.py +++ b/src/cryptography/hazmat/backends/openssl/rsa.py @@ -434,8 +434,7 @@ class _RSAPublicKey(object): def verifier(self, signature, padding, algorithm): _warn_sign_verify_deprecated() - if not isinstance(signature, bytes): - raise TypeError("signature must be bytes.") + utils._check_bytes("signature", signature) _check_not_prehashed(algorithm) return _RSAVerificationContext( diff --git a/src/cryptography/hazmat/primitives/ciphers/algorithms.py b/src/cryptography/hazmat/primitives/ciphers/algorithms.py index 68a9e330..21d9ecf0 100644 --- a/src/cryptography/hazmat/primitives/ciphers/algorithms.py +++ b/src/cryptography/hazmat/primitives/ciphers/algorithms.py @@ -153,8 +153,7 @@ class ChaCha20(object): def __init__(self, key, nonce): self.key = _verify_key_size(self, key) - if not isinstance(nonce, bytes): - raise TypeError("nonce must be bytes") + utils._check_bytes("nonce", nonce) if len(nonce) != 16: raise ValueError("nonce must be 128-bits (16 bytes)") diff --git a/src/cryptography/hazmat/primitives/ciphers/modes.py b/src/cryptography/hazmat/primitives/ciphers/modes.py index e82c1a8d..d2444580 100644 --- a/src/cryptography/hazmat/primitives/ciphers/modes.py +++ b/src/cryptography/hazmat/primitives/ciphers/modes.py @@ -88,9 +88,7 @@ class CBC(object): name = "CBC" def __init__(self, initialization_vector): - if not isinstance(initialization_vector, bytes): - raise TypeError("initialization_vector must be bytes") - + utils._check_bytes("initialization_vector", initialization_vector) self._initialization_vector = initialization_vector initialization_vector = utils.read_only_property("_initialization_vector") @@ -103,8 +101,7 @@ class XTS(object): name = "XTS" def __init__(self, tweak): - if not isinstance(tweak, bytes): - raise TypeError("tweak must be bytes") + utils._check_bytes("tweak", tweak) if len(tweak) != 16: raise ValueError("tweak must be 128-bits (16 bytes)") @@ -134,9 +131,7 @@ class OFB(object): name = "OFB" def __init__(self, initialization_vector): - if not isinstance(initialization_vector, bytes): - raise TypeError("initialization_vector must be bytes") - + utils._check_bytes("initialization_vector", initialization_vector) self._initialization_vector = initialization_vector initialization_vector = utils.read_only_property("_initialization_vector") @@ -149,9 +144,7 @@ class CFB(object): name = "CFB" def __init__(self, initialization_vector): - if not isinstance(initialization_vector, bytes): - raise TypeError("initialization_vector must be bytes") - + utils._check_bytes("initialization_vector", initialization_vector) self._initialization_vector = initialization_vector initialization_vector = utils.read_only_property("_initialization_vector") @@ -164,9 +157,7 @@ class CFB8(object): name = "CFB8" def __init__(self, initialization_vector): - if not isinstance(initialization_vector, bytes): - raise TypeError("initialization_vector must be bytes") - + utils._check_bytes("initialization_vector", initialization_vector) self._initialization_vector = initialization_vector initialization_vector = utils.read_only_property("_initialization_vector") @@ -179,9 +170,7 @@ class CTR(object): name = "CTR" def __init__(self, nonce): - if not isinstance(nonce, bytes): - raise TypeError("nonce must be bytes") - + utils._check_bytes("nonce", nonce) self._nonce = nonce nonce = utils.read_only_property("_nonce") @@ -206,14 +195,12 @@ class GCM(object): # len(initialization_vector) must in [1, 2 ** 64), but it's impossible # to actually construct a bytes object that large, so we don't check # for it - if not isinstance(initialization_vector, bytes): - raise TypeError("initialization_vector must be bytes") + utils._check_bytes("initialization_vector", initialization_vector) if len(initialization_vector) == 0: raise ValueError("initialization_vector must be at least 1 byte") self._initialization_vector = initialization_vector if tag is not None: - if not isinstance(tag, bytes): - raise TypeError("tag must be bytes or None") + utils._check_bytes("tag", tag) if min_tag_length < 4: raise ValueError("min_tag_length must be >= 4") if len(tag) < min_tag_length: diff --git a/src/cryptography/hazmat/primitives/cmac.py b/src/cryptography/hazmat/primitives/cmac.py index 77537f04..1404eac3 100644 --- a/src/cryptography/hazmat/primitives/cmac.py +++ b/src/cryptography/hazmat/primitives/cmac.py @@ -36,8 +36,8 @@ class CMAC(object): def update(self, data): if self._ctx is None: raise AlreadyFinalized("Context was already finalized.") - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + + utils._check_bytes("data", data) self._ctx.update(data) def finalize(self): @@ -48,8 +48,7 @@ class CMAC(object): return digest def verify(self, signature): - if not isinstance(signature, bytes): - raise TypeError("signature must be bytes.") + utils._check_bytes("signature", signature) if self._ctx is None: raise AlreadyFinalized("Context was already finalized.") diff --git a/src/cryptography/hazmat/primitives/hashes.py b/src/cryptography/hazmat/primitives/hashes.py index 259a2c01..89e66ab7 100644 --- a/src/cryptography/hazmat/primitives/hashes.py +++ b/src/cryptography/hazmat/primitives/hashes.py @@ -82,8 +82,7 @@ class Hash(object): def update(self, data): if self._ctx is None: raise AlreadyFinalized("Context was already finalized.") - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + utils._check_bytes("data", data) self._ctx.update(data) def copy(self): diff --git a/src/cryptography/hazmat/primitives/hmac.py b/src/cryptography/hazmat/primitives/hmac.py index 2e9a4e2f..590555d9 100644 --- a/src/cryptography/hazmat/primitives/hmac.py +++ b/src/cryptography/hazmat/primitives/hmac.py @@ -38,8 +38,7 @@ class HMAC(object): def update(self, data): if self._ctx is None: raise AlreadyFinalized("Context was already finalized.") - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + utils._check_bytes("data", data) self._ctx.update(data) def copy(self): @@ -60,8 +59,7 @@ class HMAC(object): return digest def verify(self, signature): - if not isinstance(signature, bytes): - raise TypeError("signature must be bytes.") + utils._check_bytes("signature", signature) if self._ctx is None: raise AlreadyFinalized("Context was already finalized.") diff --git a/src/cryptography/hazmat/primitives/kdf/concatkdf.py b/src/cryptography/hazmat/primitives/kdf/concatkdf.py index c6399e4f..89c3b282 100644 --- a/src/cryptography/hazmat/primitives/kdf/concatkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/concatkdf.py @@ -27,14 +27,12 @@ def _common_args_checks(algorithm, length, otherinfo): "Can not derive keys larger than {0} bits.".format( max_length )) - if not (otherinfo is None or isinstance(otherinfo, bytes)): - raise TypeError("otherinfo must be bytes.") + if otherinfo is not None: + utils._check_bytes("otherinfo", otherinfo) def _concatkdf_derive(key_material, length, auxfn, otherinfo): - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") - + utils._check_bytes("key_material", key_material) output = [b""] outlen = 0 counter = 1 @@ -96,10 +94,11 @@ class ConcatKDFHMAC(object): if self._otherinfo is None: self._otherinfo = b"" - if not (salt is None or isinstance(salt, bytes)): - raise TypeError("salt must be bytes.") if salt is None: salt = b"\x00" * algorithm.block_size + else: + utils._check_bytes("salt", salt) + self._salt = salt if not isinstance(backend, HMACBackend): diff --git a/src/cryptography/hazmat/primitives/kdf/hkdf.py b/src/cryptography/hazmat/primitives/kdf/hkdf.py index 917b4e9c..27dc9c93 100644 --- a/src/cryptography/hazmat/primitives/kdf/hkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/hkdf.py @@ -26,11 +26,10 @@ class HKDF(object): self._algorithm = algorithm - if not (salt is None or isinstance(salt, bytes)): - raise TypeError("salt must be bytes.") - if salt is None: salt = b"\x00" * self._algorithm.digest_size + else: + utils._check_bytes("salt", salt) self._salt = salt @@ -44,9 +43,7 @@ class HKDF(object): return h.finalize() def derive(self, key_material): - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") - + utils._check_bytes("key_material", key_material) return self._hkdf_expand.derive(self._extract(key_material)) def verify(self, key_material, expected_key): @@ -77,11 +74,10 @@ class HKDFExpand(object): self._length = length - if not (info is None or isinstance(info, bytes)): - raise TypeError("info must be bytes.") - if info is None: info = b"" + else: + utils._check_bytes("info", info) self._info = info @@ -102,9 +98,7 @@ class HKDFExpand(object): return b"".join(output)[:self._length] def derive(self, key_material): - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") - + utils._check_bytes("key_material", key_material) if self._used: raise AlreadyFinalized diff --git a/src/cryptography/hazmat/primitives/kdf/kbkdf.py b/src/cryptography/hazmat/primitives/kdf/kbkdf.py index 14de56eb..74fe9e29 100644 --- a/src/cryptography/hazmat/primitives/kdf/kbkdf.py +++ b/src/cryptography/hazmat/primitives/kdf/kbkdf.py @@ -73,10 +73,8 @@ class KBKDFHMAC(object): if context is None: context = b'' - if (not isinstance(label, bytes) or - not isinstance(context, bytes)): - raise TypeError('label and context must be of type bytes') - + utils._check_bytes("label", label) + utils._check_bytes("context", context) self._algorithm = algorithm self._mode = mode self._length = length @@ -102,8 +100,7 @@ class KBKDFHMAC(object): if self._used: raise AlreadyFinalized - if not isinstance(key_material, bytes): - raise TypeError('key_material must be bytes') + utils._check_bytes("key_material", key_material) self._used = True # inverse floor division (equivalent to ceiling) diff --git a/src/cryptography/hazmat/primitives/kdf/pbkdf2.py b/src/cryptography/hazmat/primitives/kdf/pbkdf2.py index f8ce7a3b..fbe8964d 100644 --- a/src/cryptography/hazmat/primitives/kdf/pbkdf2.py +++ b/src/cryptography/hazmat/primitives/kdf/pbkdf2.py @@ -31,8 +31,7 @@ class PBKDF2HMAC(object): self._used = False self._algorithm = algorithm self._length = length - if not isinstance(salt, bytes): - raise TypeError("salt must be bytes.") + utils._check_bytes("salt", salt) self._salt = salt self._iterations = iterations self._backend = backend @@ -42,8 +41,7 @@ class PBKDF2HMAC(object): raise AlreadyFinalized("PBKDF2 instances can only be used once.") self._used = True - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") + utils._check_bytes("key_material", key_material) return self._backend.derive_pbkdf2_hmac( self._algorithm, self._length, diff --git a/src/cryptography/hazmat/primitives/kdf/scrypt.py b/src/cryptography/hazmat/primitives/kdf/scrypt.py index 77dcf9ab..44e369fb 100644 --- a/src/cryptography/hazmat/primitives/kdf/scrypt.py +++ b/src/cryptography/hazmat/primitives/kdf/scrypt.py @@ -30,9 +30,7 @@ class Scrypt(object): ) self._length = length - if not isinstance(salt, bytes): - raise TypeError("salt must be bytes.") - + utils._check_bytes("salt", salt) if n < 2 or (n & (n - 1)) != 0: raise ValueError("n must be greater than 1 and be a power of 2.") @@ -54,8 +52,7 @@ class Scrypt(object): raise AlreadyFinalized("Scrypt instances can only be used once.") self._used = True - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") + utils._check_bytes("key_material", key_material) return self._backend.derive_scrypt( key_material, self._salt, self._length, self._n, self._r, self._p ) diff --git a/src/cryptography/hazmat/primitives/kdf/x963kdf.py b/src/cryptography/hazmat/primitives/kdf/x963kdf.py index 83789b31..a8c07751 100644 --- a/src/cryptography/hazmat/primitives/kdf/x963kdf.py +++ b/src/cryptography/hazmat/primitives/kdf/x963kdf.py @@ -27,8 +27,9 @@ class X963KDF(object): if length > max_len: raise ValueError( "Can not derive keys larger than {0} bits.".format(max_len)) - if not (sharedinfo is None or isinstance(sharedinfo, bytes)): - raise TypeError("sharedinfo must be bytes.") + if sharedinfo is not None: + utils._check_bytes("sharedinfo", sharedinfo) + self._algorithm = algorithm self._length = length self._sharedinfo = sharedinfo @@ -45,10 +46,7 @@ class X963KDF(object): if self._used: raise AlreadyFinalized self._used = True - - if not isinstance(key_material, bytes): - raise TypeError("key_material must be bytes.") - + utils._check_bytes("key_material", key_material) output = [b""] outlen = 0 counter = 1 diff --git a/src/cryptography/hazmat/primitives/padding.py b/src/cryptography/hazmat/primitives/padding.py index a081976e..170c8021 100644 --- a/src/cryptography/hazmat/primitives/padding.py +++ b/src/cryptography/hazmat/primitives/padding.py @@ -40,8 +40,7 @@ def _byte_padding_update(buffer_, data, block_size): if buffer_ is None: raise AlreadyFinalized("Context was already finalized.") - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + utils._check_bytes("data", data) buffer_ += data @@ -65,8 +64,7 @@ def _byte_unpadding_update(buffer_, data, block_size): if buffer_ is None: raise AlreadyFinalized("Context was already finalized.") - if not isinstance(data, bytes): - raise TypeError("data must be bytes.") + utils._check_bytes("data", data) buffer_ += data -- cgit v1.2.3