1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from __future__ import absolute_import, division, print_function
import struct
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import ECB
from cryptography.hazmat.primitives.constant_time import bytes_eq
def _wrap_core(wrapping_key, a, r, backend):
    """Run the six wrapping rounds of RFC 3394 section 2.2.1 (index method).

    :param wrapping_key: key handed to ``AES`` for the ECB encryptor.
    :param a: 8-byte initial value; it is replaced on every step.
    :param r: list of 8-byte blocks to wrap. The list is updated in place.
    :param backend: cipher backend handed through to ``Cipher``.
    :returns: the final ``a`` followed by the concatenation of ``r``.
    """
    encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
    n = len(r)
    for j in range(6):
        for i in range(n):
            # every encryption operation is a discrete 16 byte chunk (because
            # AES has a 128-bit block size) and since we're using ECB it is
            # safe to reuse the encryptor for the entire operation
            b = encryptor.update(a + r[i])
            # XOR the step counter t = (n * j) + i + 1 into the first half of
            # the block; pack/unpack are safe as these are always 64-bit
            # chunks
            a = struct.pack(
                ">Q", struct.unpack(">Q", b[:8])[0] ^ ((n * j) + i + 1)
            )
            r[i] = b[-8:]

    # Call finalize() outside the assert: a call placed inside an assert
    # statement is skipped entirely when Python runs with -O.
    leftover = encryptor.finalize()
    assert leftover == b""
    return a + b"".join(r)
def aes_key_wrap(wrapping_key, key_to_wrap, backend):
    """Wrap ``key_to_wrap`` under ``wrapping_key`` (RFC 3394 key wrap).

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param key_to_wrap: data to wrap; at least 16 bytes long and a multiple
        of 8 bytes.
    :param backend: cipher backend handed through to ``_wrap_core``.
    :returns: whatever ``_wrap_core`` produces for the default initial value
        and the 8-byte blocks of ``key_to_wrap``.
    :raises ValueError: if either length requirement is not met.
    """
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    payload_len = len(key_to_wrap)
    if payload_len < 16:
        raise ValueError("The key to wrap must be at least 16 bytes")
    if payload_len % 8:
        raise ValueError("The key to wrap must be a multiple of 8 bytes")

    # Default initial value: eight 0xA6 octets.
    default_iv = b"\xa6" * 8

    # Cut the payload into consecutive 64-bit blocks.
    blocks = []
    for start in range(0, payload_len, 8):
        blocks.append(key_to_wrap[start:start + 8])

    return _wrap_core(wrapping_key, default_iv, blocks, backend)
def _unwrap_core(wrapping_key, a, r, backend):
    """Run the six unwrapping rounds of RFC 3394 section 2.2.2 (index method).

    :param wrapping_key: key handed to ``AES`` for the ECB decryptor.
    :param a: first 8-byte block of the wrapped data.
    :param r: list of the remaining 8-byte blocks. The list is updated in
        place and is also the list that is returned.
    :param backend: cipher backend handed through to ``Cipher``.
    :returns: tuple ``(a, r)`` holding the recovered 8-byte value and the
        list of unwrapped blocks. Checking ``a`` is left to the caller.
    """
    decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
    n = len(r)
    # Walk the wrap steps backwards: rounds 5..0, blocks n-1..0.
    for j in reversed(range(6)):
        for i in reversed(range(n)):
            # Undo the XOR of the step counter t = (n * j) + i + 1;
            # pack/unpack are safe as these are always 64-bit chunks
            atr = struct.pack(
                ">Q", struct.unpack(">Q", a)[0] ^ ((n * j) + i + 1)
            ) + r[i]
            # every decryption operation is a discrete 16 byte chunk so
            # it is safe to reuse the decryptor for the entire operation
            b = decryptor.update(atr)
            a = b[:8]
            r[i] = b[-8:]

    # Call finalize() outside the assert: a call placed inside an assert
    # statement is skipped entirely when Python runs with -O.
    leftover = decryptor.finalize()
    assert leftover == b""
    return a, r
def aes_key_wrap_with_padding(wrapping_key, key_to_wrap, backend):
    """Wrap ``key_to_wrap`` using AES key wrap with padding (RFC 5649).

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param key_to_wrap: non-empty data to wrap. It is zero-padded up to the
        next multiple of 8 bytes before wrapping.
    :param backend: cipher backend handed through to ``Cipher``.
    :returns: the wrapped key as bytes.
    :raises ValueError: if ``wrapping_key`` has an invalid length or
        ``key_to_wrap`` is empty.
    """
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    if len(key_to_wrap) == 0:
        # RFC 5649 - 4.1 - the plaintext is at least one octet long. Without
        # this check an empty input skips every cipher call and the bare,
        # unencrypted AIV comes back as the "wrapped" result.
        raise ValueError("The key to wrap must not be empty")

    # The message length indicator is an unsigned 32-bit big-endian field;
    # pack it with ">I", the same format the unwrap side reads it with.
    aiv = b"\xA6\x59\x59\xA6" + struct.pack(">I", len(key_to_wrap))
    # pad the key to wrap if necessary
    pad = (8 - (len(key_to_wrap) % 8)) % 8
    key_to_wrap = key_to_wrap + b"\x00" * pad

    if len(key_to_wrap) == 8:
        # RFC 5649 - 4.1 - exactly 8 octets after padding: a single ECB
        # encryption of AIV || padded key replaces the wrapping rounds
        encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
        wrapped = encryptor.update(aiv + key_to_wrap)
        # Call finalize() outside the assert so the call is not skipped
        # when Python runs with -O.
        leftover = encryptor.finalize()
        assert leftover == b""
        return wrapped

    r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
    return _wrap_core(wrapping_key, aiv, r, backend)
def aes_key_unwrap_with_padding(wrapping_key, wrapped_key, backend):
    """Unwrap data produced by AES key wrap with padding (RFC 5649).

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param wrapped_key: wrapped data; at least 16 bytes long and a multiple
        of 8 bytes.
    :param backend: cipher backend handed through to ``Cipher``.
    :returns: the unwrapped key with its zero padding removed.
    :raises ValueError: if a length requirement is not met.
    :raises InvalidUnwrap: if the recovered AIV prefix, length indicator or
        padding octets fail the RFC 5649 checks.
    """
    if len(wrapped_key) < 16:
        raise ValueError("Must be at least 16 bytes")

    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    if len(wrapped_key) % 8 != 0:
        # A trailing partial block would otherwise be fed to the decryptor
        # as a short chunk and fail later with an unrelated struct.error.
        raise ValueError("The wrapped key must be a multiple of 8 bytes")

    if len(wrapped_key) == 16:
        # RFC 5649 - 4.2 - exactly two 64-bit blocks: a single ECB
        # decryption replaces the unwrapping rounds
        decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
        block = decryptor.update(wrapped_key)
        # Call finalize() outside the assert so the call is not skipped
        # when Python runs with -O.
        leftover = decryptor.finalize()
        assert leftover == b""
        a = block[:8]
        data = block[8:]
        n = 1
    else:
        r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
        encrypted_aiv = r.pop(0)
        n = len(r)
        a, r = _unwrap_core(wrapping_key, encrypted_aiv, r, backend)
        data = b"".join(r)

    # 1) Check that MSB(32,A) = A65959A6.
    # 2) Check that 8*(n-1) < LSB(32,A) <= 8*n. If so, let
    #    MLI = LSB(32,A).
    # 3) Let b = (8*n)-MLI, and then check that the rightmost b octets of
    #    the output data are zero.
    (mli,) = struct.unpack(">I", a[4:])
    pad_len = (8 * n) - mli
    if (
        not bytes_eq(a[:4], b"\xa6\x59\x59\xa6") or not
        8 * (n - 1) < mli <= 8 * n or (
            pad_len != 0 and
            not bytes_eq(data[-pad_len:], b"\x00" * pad_len)
        )
    ):
        raise InvalidUnwrap()

    # data[:-0] would be empty, so the no-padding case is returned as-is.
    if pad_len == 0:
        return data
    return data[:-pad_len]
def aes_key_unwrap(wrapping_key, wrapped_key, backend):
    """Unwrap ``wrapped_key`` with ``wrapping_key`` (RFC 3394 key unwrap).

    :param wrapping_key: AES key of 16, 24 or 32 bytes.
    :param wrapped_key: wrapped data; at least 24 bytes long and a multiple
        of 8 bytes.
    :param backend: cipher backend handed through to ``_unwrap_core``.
    :returns: the concatenated unwrapped blocks.
    :raises ValueError: if a length requirement is not met.
    :raises InvalidUnwrap: if the recovered initial value is not the
        expected eight 0xA6 octets.
    """
    wrapped_len = len(wrapped_key)
    if wrapped_len < 24:
        raise ValueError("Must be at least 24 bytes")
    if wrapped_len % 8:
        raise ValueError("The wrapped key must be a multiple of 8 bytes")
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    # Split into 64-bit blocks; the first one is the wrapped initial value.
    blocks = [wrapped_key[pos:pos + 8] for pos in range(0, wrapped_len, 8)]
    recovered_iv, plain_blocks = _unwrap_core(
        wrapping_key, blocks[0], blocks[1:], backend
    )

    # Compare against the default initial value using bytes_eq.
    if bytes_eq(recovered_iv, b"\xa6" * 8):
        return b"".join(plain_blocks)
    raise InvalidUnwrap()
class InvalidUnwrap(Exception):
    """Raised by the unwrap functions when the unwrapped data fails its
    integrity checks."""
|