aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cryptography/hazmat/primitives/serialization.py93
1 files changed, 65 insertions, 28 deletions
diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py
index 083f17e5..b95ac1cd 100644
--- a/src/cryptography/hazmat/primitives/serialization.py
+++ b/src/cryptography/hazmat/primitives/serialization.py
@@ -7,11 +7,10 @@ from __future__ import absolute_import, division, print_function
import base64
import struct
+import six
+
from cryptography.exceptions import UnsupportedAlgorithm
-from cryptography.hazmat.primitives.asymmetric.dsa import (
- DSAParameterNumbers, DSAPublicNumbers
-)
-from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
+from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa
def load_pem_private_key(data, password, backend):
@@ -30,6 +29,18 @@ def load_ssh_public_key(data, backend):
'Key is not in the proper format or contains extra data.')
key_type = key_parts[0]
+
+ if key_type == b'ssh-rsa':
+ loader = _load_ssh_rsa_public_key
+ elif key_type == b'ssh-dss':
+ loader = _load_ssh_dss_public_key
+ elif key_type in [
+ b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521',
+ ]:
+ loader = _load_ssh_ecdsa_public_key
+ else:
+ raise UnsupportedAlgorithm('Key type is not supported.')
+
key_body = key_parts[1]
try:
@@ -37,53 +48,79 @@ def load_ssh_public_key(data, backend):
except TypeError:
raise ValueError('Key is not in the proper format.')
- if key_type == b'ssh-rsa':
- return _load_ssh_rsa_public_key(decoded_data, backend)
- elif key_type == b'ssh-dss':
- return _load_ssh_dss_public_key(decoded_data, backend)
- else:
- raise UnsupportedAlgorithm(
- 'Only RSA and DSA keys are currently supported.'
+ inner_key_type, rest = _read_next_string(decoded_data)
+
+ if inner_key_type != key_type:
+ raise ValueError(
+ 'Key header and key body contain different key type values.'
)
+ return loader(key_type, rest, backend)
-def _load_ssh_rsa_public_key(decoded_data, backend):
- key_type, rest = _read_next_string(decoded_data)
- e, rest = _read_next_mpint(rest)
- n, rest = _read_next_mpint(rest)
- if key_type != b'ssh-rsa':
- raise ValueError(
- 'Key header and key body contain different key type values.')
+def _load_ssh_rsa_public_key(key_type, decoded_data, backend):
+ e, rest = _read_next_mpint(decoded_data)
+ n, rest = _read_next_mpint(rest)
if rest:
raise ValueError('Key body contains extra bytes.')
- return RSAPublicNumbers(e, n).public_key(backend)
+ return rsa.RSAPublicNumbers(e, n).public_key(backend)
-def _load_ssh_dss_public_key(decoded_data, backend):
- key_type, rest = _read_next_string(decoded_data)
- p, rest = _read_next_mpint(rest)
+def _load_ssh_dss_public_key(key_type, decoded_data, backend):
+ p, rest = _read_next_mpint(decoded_data)
q, rest = _read_next_mpint(rest)
g, rest = _read_next_mpint(rest)
y, rest = _read_next_mpint(rest)
- if key_type != b'ssh-dss':
+ if rest:
+ raise ValueError('Key body contains extra bytes.')
+
+ parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
+ public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
+
+ return public_numbers.public_key(backend)
+
+
+def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend):
+ curve_name, rest = _read_next_string(decoded_data)
+ data, rest = _read_next_string(rest)
+
+ if expected_key_type != b"ecdsa-sha2-" + curve_name:
raise ValueError(
- 'Key header and key body contain different key type values.')
+ 'Key header and key body contain different key type values.'
+ )
if rest:
raise ValueError('Key body contains extra bytes.')
- parameter_numbers = DSAParameterNumbers(p, q, g)
- public_numbers = DSAPublicNumbers(y, parameter_numbers)
+ if curve_name == b"nistp256":
+ curve = ec.SECP256R1()
+ elif curve_name == b"nistp384":
+ curve = ec.SECP384R1()
+ elif curve_name == b"nistp521":
+ curve = ec.SECP521R1()
- return public_numbers.public_key(backend)
+ if len(data) != 1 + 2 * (curve.key_size // 8):
+ raise ValueError("Malformed key bytes")
+
+ if six.indexbytes(data, 0) != 4:
+ raise NotImplementedError(
+ "Compressed elliptic curve points are not supported"
+ )
+
+ x = _int_from_bytes(data[1:1 + curve.key_size // 8], byteorder='big')
+ y = _int_from_bytes(data[1 + curve.key_size // 8:], byteorder='big')
+ return ec.EllipticCurvePublicNumbers(x, y, curve).public_key(backend)
def _read_next_string(data):
- """Retrieves the next RFC 4251 string value from the data."""
+ """
+ Retrieves the next RFC 4251 string value from the data.
+
+ While the RFC calls these strings, in Python they are bytes objects.
+ """
str_len, = struct.unpack('>I', data[:4])
return data[4:4 + str_len], data[4 + str_len:]