diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cryptography/hazmat/primitives/serialization.py | 49 | ||||
-rw-r--r-- | src/cryptography/utils.py | 7 |
2 files changed, 34 insertions, 22 deletions
diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py index 8a4c8bd8..858ec043 100644 --- a/src/cryptography/hazmat/primitives/serialization.py +++ b/src/cryptography/hazmat/primitives/serialization.py @@ -49,7 +49,7 @@ def load_pem_public_key(data, backend): def load_ssh_public_key(data, backend): key_parts = data.split(b' ') - if len(key_parts) < 2 or len(key_parts) > 3: + if len(key_parts) != 2 and len(key_parts) != 3: raise ValueError( 'Key is not in the proper format or contains extra data.') @@ -62,21 +62,21 @@ def load_ssh_public_key(data, backend): if not key_type.startswith(b'ssh-rsa'): raise UnsupportedAlgorithm('Only RSA keys are currently supported.') - return _load_ssh_rsa_public_key(key_type, key_body, backend) + return _load_ssh_rsa_public_key(key_body, backend) -def _load_ssh_rsa_public_key(key_type, key_body, backend): +def _load_ssh_rsa_public_key(key_body, backend): data = base64.b64decode(key_body) - key_body_type, rest = _read_next_string(data) + key_type, rest = _read_next_string(data) e, rest = _read_next_mpint(rest) n, rest = _read_next_mpint(rest) - if key_type != key_body_type: + if key_type != b'ssh-rsa': raise ValueError( 'Key header and key body contain different key type values.') - if len(rest) != 0: + if rest: raise ValueError('Key body contains extra bytes.') return backend.load_rsa_public_numbers(RSAPublicNumbers(e, n)) @@ -84,26 +84,37 @@ def _load_ssh_rsa_public_key(key_type, key_body, backend): def _read_next_string(data): """Retrieves the next RFC 4251 string value from the data.""" - str_len, = struct.unpack('>I', data[0:4]) + str_len, = struct.unpack('>I', data[:4]) return data[4:4 + str_len], data[4 + str_len:] def _read_next_mpint(data): - """Reads the next mpint from the data. Currently, all mpints are - interpreted as unsigned.""" + """ + Reads the next mpint from the data. + + Currently, all mpints are interpreted as unsigned. + """ mpint_data, rest = _read_next_string(data) - if sys.version_info >= (3, 2): - # If we're using >= 3.2, use int.from_bytes for identical results. - return int.from_bytes(mpint_data, byteorder='big', signed=False), rest + return _int_from_bytes(mpint_data, byteorder='big', signed=False), rest + + + +if hasattr(int, "from_bytes"): + _int_from_bytes = int.from_bytes +else: + def _int_from_bytes(data, byteorder, signed=False): + assert byteorder == 'big' + assert not signed - if len(mpint_data) % 4 != 0: - mpint_data = (b'\x00' * (4 - (len(mpint_data) % 4))) + mpint_data + if len(data) % 4 != 0: + data = (b'\x00' * (4 - (len(data) % 4))) + data - result = 0 + result = 0 - while len(mpint_data) > 0: - result = (result << 32) + struct.unpack('>I', mpint_data[0:4])[0] - mpint_data = mpint_data[4:] + while len(data) > 0: + digit, = struct.unpack('>I', data[:4]) + result = (result << 32) + digit + data = data[4:] - return result, rest + return result diff --git a/src/cryptography/utils.py b/src/cryptography/utils.py index 63464dfa..78f73464 100644 --- a/src/cryptography/utils.py +++ b/src/cryptography/utils.py @@ -48,8 +48,9 @@ def verify_interface(iface, klass): ) -def bit_length(x): - if sys.version_info >= (2, 7): +if sys.version_info >= (2, 7): + def bit_length(x): return x.bit_length() - else: +else: + def bit_length(x): return len(bin(x)) - (2 + (x <= 0)) |