# path: root/src/cryptography/hazmat/backends/openssl/ec.py
# blob: 13b0ddbb343cfabd47ce327da5c72468783ca3fe
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, division, print_function

from cryptography import utils
from cryptography.exceptions import (
    InvalidSignature, UnsupportedAlgorithm, _Reasons
)
from cryptography.hazmat.backends.openssl.utils import _truncate_digest
from cryptography.hazmat.primitives import hashes, interfaces
from cryptography.hazmat.primitives.asymmetric import ec


def _truncate_digest_for_ecdsa(ec_key_cdata, digest, backend):
    """
    This function truncates digests that are longer than a given elliptic
    curve key's length so they can be signed. Since elliptic curve keys are
    much shorter than RSA keys many digests (e.g. SHA-512) may require
    truncation.

    :param ec_key_cdata: The ``EC_KEY`` whose group order is measured.
    :param digest: The digest to truncate.
    :param backend: Backend object providing the ``_lib`` and ``_ffi``
        bindings and the ``_tmp_bn_ctx`` context manager.
    :returns: The result of ``_truncate_digest`` for ``digest`` and the
        bit length of the group order.
    """

    _lib = backend._lib
    _ffi = backend._ffi

    group = _lib.EC_KEY_get0_group(ec_key_cdata)
    # Same check _ec_key_curve_sn performs: never hand a NULL group to
    # EC_GROUP_get_order below.
    assert group != _ffi.NULL

    with backend._tmp_bn_ctx() as bn_ctx:
        order = _lib.BN_CTX_get(bn_ctx)
        assert order != _ffi.NULL

        res = _lib.EC_GROUP_get_order(group, order, bn_ctx)
        assert res == 1

        # Read the bit length while the temporary BN context is still open.
        order_bits = _lib.BN_num_bits(order)

    return _truncate_digest(digest, order_bits)


def _ec_key_curve_sn(backend, ec_key):
    """
    Return the short name of the curve ``ec_key`` is defined over.

    The key's group is mapped to a curve NID and the NID to its short
    name, which is decoded as ASCII. Each intermediate result is asserted
    to be valid before it is used.
    """
    lib, ffi = backend._lib, backend._ffi

    ec_group = lib.EC_KEY_get0_group(ec_key)
    assert ec_group != ffi.NULL

    curve_nid = lib.EC_GROUP_get_curve_name(ec_group)
    assert curve_nid != lib.NID_undef

    short_name = lib.OBJ_nid2sn(curve_nid)
    assert short_name != ffi.NULL

    return ffi.string(short_name).decode('ascii')


def _sn_to_elliptic_curve(backend, sn):
    """
    Instantiate the curve type registered in ``ec._CURVE_TYPES`` under the
    short name ``sn``.

    Raises ``UnsupportedAlgorithm`` with the ``UNSUPPORTED_ELLIPTIC_CURVE``
    reason when a ``KeyError`` is raised. ``backend`` is not used.
    """
    try:
        curve = ec._CURVE_TYPES[sn]()
    except KeyError:
        message = "{0} is not a supported elliptic curve".format(sn)
        raise UnsupportedAlgorithm(
            message, _Reasons.UNSUPPORTED_ELLIPTIC_CURVE
        )
    return curve


@utils.register_interface(interfaces.AsymmetricSignatureContext)
class _ECDSASignatureContext(object):
    """
    Signature context that hashes the data it is given and signs the
    resulting digest with ``ECDSA_sign`` when finalized.
    """

    def __init__(self, backend, private_key, algorithm):
        self._backend = backend
        self._private_key = private_key
        self._digest = hashes.Hash(algorithm, backend)

    def update(self, data):
        """Feed ``data`` into the underlying hash context."""
        self._digest.update(data)

    def finalize(self):
        """
        Finalize the hash, truncate the digest for the private key and
        return the signature written by ``ECDSA_sign``.
        """
        backend = self._backend
        lib, ffi = backend._lib, backend._ffi
        ec_key = self._private_key._ec_key

        hashed = _truncate_digest_for_ecdsa(
            ec_key, self._digest.finalize(), backend
        )

        # ECDSA_size gives the buffer size to allocate for the signature.
        capacity = lib.ECDSA_size(ec_key)
        assert capacity > 0

        signature = ffi.new("char[]", capacity)
        written = ffi.new("unsigned int[]", 1)
        status = lib.ECDSA_sign(
            0, hashed, len(hashed), signature, written, ec_key
        )
        assert status == 1

        # Only the first written[0] bytes of the buffer are returned.
        return ffi.buffer(signature)[:written[0]]


@utils.register_interface(interfaces.AsymmetricVerificationContext)
class _ECDSAVerificationContext(object):
    """
    Verification context that hashes the data it is given and checks the
    stored signature against the digest with ``ECDSA_verify``.
    """

    def __init__(self, backend, public_key, signature, algorithm):
        self._backend = backend
        self._public_key = public_key
        self._signature = signature
        self._digest = hashes.Hash(algorithm, backend)

    def update(self, data):
        """Feed ``data`` into the underlying hash context."""
        self._digest.update(data)

    def verify(self):
        """
        Finalize the hash, truncate the digest for the public key and
        verify the stored signature.

        Returns ``True`` when ``ECDSA_verify`` returns 1. For any other
        result the backend's errors are consumed and ``InvalidSignature``
        is raised.
        """
        backend = self._backend
        ec_key = self._public_key._ec_key

        hashed = _truncate_digest_for_ecdsa(
            ec_key, self._digest.finalize(), backend
        )

        signature = self._signature
        outcome = backend._lib.ECDSA_verify(
            0, hashed, len(hashed), signature, len(signature), ec_key
        )
        if outcome == 1:
            return True

        backend._consume_errors()
        raise InvalidSignature


@utils.register_interface(interfaces.EllipticCurvePrivateKeyWithNumbers)
class _EllipticCurvePrivateKey(object):
    """
    Elliptic curve private key wrapping an OpenSSL ``EC_KEY``.

    The curve object is resolved once, at construction time, from the
    short name of the key's group.
    """

    def __init__(self, backend, ec_key_cdata):
        self._backend = backend
        self._ec_key = ec_key_cdata

        # _sn_to_elliptic_curve raises UnsupportedAlgorithm for short names
        # that have no entry in ec._CURVE_TYPES.
        sn = _ec_key_curve_sn(backend, ec_key_cdata)
        self._curve = _sn_to_elliptic_curve(backend, sn)

    curve = utils.read_only_property("_curve")

    def signer(self, signature_algorithm):
        """
        Return an ``_ECDSASignatureContext`` for this key.

        Raises ``UnsupportedAlgorithm`` when ``signature_algorithm`` is not
        an ``ec.ECDSA`` instance.
        """
        if isinstance(signature_algorithm, ec.ECDSA):
            return _ECDSASignatureContext(
                self._backend, self, signature_algorithm.algorithm
            )
        else:
            raise UnsupportedAlgorithm(
                "Unsupported elliptic curve signature algorithm.",
                _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM)

    def public_key(self):
        """
        Return an ``_EllipticCurvePublicKey`` built from a new ``EC_KEY``
        on the same curve, carrying only this key's public point.
        """
        group = self._backend._lib.EC_KEY_get0_group(self._ec_key)
        assert group != self._backend._ffi.NULL

        curve_nid = self._backend._lib.EC_GROUP_get_curve_name(group)

        public_ec_key = self._backend._lib.EC_KEY_new_by_curve_name(curve_nid)
        assert public_ec_key != self._backend._ffi.NULL
        # Register EC_KEY_free as the destructor of the new key.
        public_ec_key = self._backend._ffi.gc(
            public_ec_key, self._backend._lib.EC_KEY_free
        )

        point = self._backend._lib.EC_KEY_get0_public_key(self._ec_key)
        assert point != self._backend._ffi.NULL

        res = self._backend._lib.EC_KEY_set_public_key(public_ec_key, point)
        assert res == 1

        return _EllipticCurvePublicKey(
            self._backend, public_ec_key
        )

    def private_numbers(self):
        """
        Return ``ec.EllipticCurvePrivateNumbers`` holding the private value
        as an integer together with the public numbers of ``public_key()``.
        """
        bn = self._backend._lib.EC_KEY_get0_private_key(self._ec_key)
        # Checked like the other get0 results in this class so that a NULL
        # pointer is never passed on for conversion.
        assert bn != self._backend._ffi.NULL
        private_value = self._backend._bn_to_int(bn)
        return ec.EllipticCurvePrivateNumbers(
            private_value=private_value,
            public_numbers=self.public_key().public_numbers()
        )


@utils.register_interface(interfaces.EllipticCurvePublicKeyWithNumbers)
class _EllipticCurvePublicKey(object):
    """
    Elliptic curve public key wrapping an OpenSSL ``EC_KEY``.

    The curve object is resolved once, at construction time, from the
    short name of the key's group.
    """

    def __init__(self, backend, ec_key_cdata):
        self._backend = backend
        self._ec_key = ec_key_cdata

        # _sn_to_elliptic_curve raises UnsupportedAlgorithm for short names
        # that have no entry in ec._CURVE_TYPES.
        sn = _ec_key_curve_sn(backend, ec_key_cdata)
        self._curve = _sn_to_elliptic_curve(backend, sn)

    curve = utils.read_only_property("_curve")

    def verifier(self, signature, signature_algorithm):
        """
        Return an ``_ECDSAVerificationContext`` for ``signature``.

        Raises ``UnsupportedAlgorithm`` when ``signature_algorithm`` is not
        an ``ec.ECDSA`` instance.
        """
        if isinstance(signature_algorithm, ec.ECDSA):
            return _ECDSAVerificationContext(
                self._backend, self, signature, signature_algorithm.algorithm
            )
        else:
            raise UnsupportedAlgorithm(
                "Unsupported elliptic curve signature algorithm.",
                _Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM)

    def public_numbers(self):
        """
        Return ``ec.EllipticCurvePublicNumbers`` holding the x and y values
        read from the public point, together with this key's curve.
        """
        # Only the getter and the group are used here; the setter returned
        # alongside them is ignored.
        set_func, get_func, group = (
            self._backend._ec_key_determine_group_get_set_funcs(self._ec_key)
        )
        point = self._backend._lib.EC_KEY_get0_public_key(self._ec_key)
        assert point != self._backend._ffi.NULL

        with self._backend._tmp_bn_ctx() as bn_ctx:
            # Check the temporaries the same way _truncate_digest_for_ecdsa
            # checks its BN_CTX_get result.
            bn_x = self._backend._lib.BN_CTX_get(bn_ctx)
            assert bn_x != self._backend._ffi.NULL
            bn_y = self._backend._lib.BN_CTX_get(bn_ctx)
            assert bn_y != self._backend._ffi.NULL

            res = get_func(group, point, bn_x, bn_y, bn_ctx)
            assert res == 1

            # Convert while the temporary BN context is still open.
            x = self._backend._bn_to_int(bn_x)
            y = self._backend._bn_to_int(bn_y)

        return ec.EllipticCurvePublicNumbers(
            x=x,
            y=y,
            curve=self._curve
        )