aboutsummaryrefslogtreecommitdiffstats
path: root/target/linux/generic/backport-5.4/080-wireguard-0105-wireguard-noise-separate-receive-counter-from-send-c.patch
blob: 87d38d36fe78e1797d60889004a95e5235a77db1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Tue, 19 May 2020 22:49:30 -0600
Subject: [PATCH] wireguard: noise: separate receive counter from send counter

commit a9e90d9931f3a474f04bab782ccd9d77904941e9 upstream.

In "wireguard: queueing: preserve flow hash across packet scrubbing", we
were required to slightly increase the size of the receive replay
counter to something still fairly small, but an increase nonetheless.
It turns out that we can recoup some of the additional memory overhead
by splitting up the prior union type into two distinct types. Before, we
used the same "noise_counter" union for both sending and receiving, with
sending just using a simple atomic64_t, while receiving used the full
replay counter checker. This meant that most of the memory being
allocated for the sending counter was being wasted. Since the old
"noise_counter" type increased in size in the prior commit, now is a
good time to split up that union type into a distinct "noise_replay_
counter" for receiving and a boring atomic64_t for sending, each using
neither more nor less memory than required.

Also, since sometimes the replay counter is accessed without
necessitating additional accesses to the bitmap, we can reduce cache
misses by hoisting the always-necessary lock above the bitmap in the
struct layout. We also change a "noise_replay_counter" stack allocation
to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack
frame warning.

All and all, removing a bit of abstraction in this commit makes the code
simpler and smaller, in addition to the motivating memory usage
recuperation. For example, passing around raw "noise_symmetric_key"
structs is something that really only makes sense within noise.c, in the
one place where the sending and receiving keys can safely be thought of
as the same type of object; subsequent to that, it's important that we
uniformly access these through keypair->{sending,receiving}, where their
distinct roles are always made explicit. So this patch allows us to draw
that distinction clearly as well.

Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 drivers/net/wireguard/noise.c            | 16 +++------
 drivers/net/wireguard/noise.h            | 14 ++++----
 drivers/net/wireguard/receive.c          | 42 ++++++++++++------------
 drivers/net/wireguard/selftest/counter.c | 17 +++++++---
 drivers/net/wireguard/send.c             | 12 +++----
 5 files changed, 48 insertions(+), 53 deletions(-)

--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_cre
 
 	if (unlikely(!keypair))
 		return NULL;
+	spin_lock_init(&keypair->receiving_counter.lock);
 	keypair->internal_id = atomic64_inc_return(&keypair_counter);
 	keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
 	keypair->entry.peer = peer;
@@ -358,25 +359,16 @@ out:
 	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
 }
 
-static void symmetric_key_init(struct noise_symmetric_key *key)
-{
-	spin_lock_init(&key->counter.receive.lock);
-	atomic64_set(&key->counter.counter, 0);
-	memset(key->counter.receive.backtrack, 0,
-	       sizeof(key->counter.receive.backtrack));
-	key->birthdate = ktime_get_coarse_boottime_ns();
-	key->is_valid = true;
-}
-
 static void derive_keys(struct noise_symmetric_key *first_dst,
 			struct noise_symmetric_key *second_dst,
 			const u8 chaining_key[NOISE_HASH_LEN])
 {
+	u64 birthdate = ktime_get_coarse_boottime_ns();
 	kdf(first_dst->key, second_dst->key, NULL, NULL,
 	    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
 	    chaining_key);
-	symmetric_key_init(first_dst);
-	symmetric_key_init(second_dst);
+	first_dst->birthdate = second_dst->birthdate = birthdate;
+	first_dst->is_valid = second_dst->is_valid = true;
 }
 
 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
--- a/drivers/net/wireguard/noise.h
+++ b/drivers/net/wireguard/noise.h
@@ -15,18 +15,14 @@
 #include <linux/mutex.h>
 #include <linux/kref.h>
 
-union noise_counter {
-	struct {
-		u64 counter;
-		unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
-		spinlock_t lock;
-	} receive;
-	atomic64_t counter;
+struct noise_replay_counter {
+	u64 counter;
+	spinlock_t lock;
+	unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
 };
 
 struct noise_symmetric_key {
 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
-	union noise_counter counter;
 	u64 birthdate;
 	bool is_valid;
 };
@@ -34,7 +30,9 @@ struct noise_symmetric_key {
 struct noise_keypair {
 	struct index_hashtable_entry entry;
 	struct noise_symmetric_key sending;
+	atomic64_t sending_counter;
 	struct noise_symmetric_key receiving;
+	struct noise_replay_counter receiving_counter;
 	__le32 remote_index;
 	bool i_am_the_initiator;
 	struct kref refcount;
--- a/drivers/net/wireguard/receive.c
+++ b/drivers/net/wireguard/receive.c
@@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_pee
 	}
 }
 
-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
+static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
 {
 	struct scatterlist sg[MAX_SKB_FRAGS + 8];
 	struct sk_buff *trailer;
 	unsigned int offset;
 	int num_frags;
 
-	if (unlikely(!key))
+	if (unlikely(!keypair))
 		return false;
 
-	if (unlikely(!READ_ONCE(key->is_valid) ||
-		  wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
-		  key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
-		WRITE_ONCE(key->is_valid, false);
+	if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
+		  wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
+		  keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
+		WRITE_ONCE(keypair->receiving.is_valid, false);
 		return false;
 	}
 
@@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buf
 
 	if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
 					         PACKET_CB(skb)->nonce,
-						 key->key))
+						 keypair->receiving.key))
 		return false;
 
 	/* Another ugly situation of pushing and pulling the header so as to
@@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buf
 }
 
 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static bool counter_validate(union noise_counter *counter, u64 their_counter)
+static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
 {
 	unsigned long index, index_current, top, i;
 	bool ret = false;
 
-	spin_lock_bh(&counter->receive.lock);
+	spin_lock_bh(&counter->lock);
 
-	if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+	if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
 		     their_counter >= REJECT_AFTER_MESSAGES))
 		goto out;
 
 	++their_counter;
 
 	if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
-		     counter->receive.counter))
+		     counter->counter))
 		goto out;
 
 	index = their_counter >> ilog2(BITS_PER_LONG);
 
-	if (likely(their_counter > counter->receive.counter)) {
-		index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+	if (likely(their_counter > counter->counter)) {
+		index_current = counter->counter >> ilog2(BITS_PER_LONG);
 		top = min_t(unsigned long, index - index_current,
 			    COUNTER_BITS_TOTAL / BITS_PER_LONG);
 		for (i = 1; i <= top; ++i)
-			counter->receive.backtrack[(i + index_current) &
+			counter->backtrack[(i + index_current) &
 				((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
-		counter->receive.counter = their_counter;
+		counter->counter = their_counter;
 	}
 
 	index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
 	ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
-				&counter->receive.backtrack[index]);
+				&counter->backtrack[index]);
 
 out:
-	spin_unlock_bh(&counter->receive.lock);
+	spin_unlock_bh(&counter->lock);
 	return ret;
 }
 
@@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct
 		if (unlikely(state != PACKET_STATE_CRYPTED))
 			goto next;
 
-		if (unlikely(!counter_validate(&keypair->receiving.counter,
+		if (unlikely(!counter_validate(&keypair->receiving_counter,
 					       PACKET_CB(skb)->nonce))) {
 			net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
 					    peer->device->dev->name,
 					    PACKET_CB(skb)->nonce,
-					    keypair->receiving.counter.receive.counter);
+					    keypair->receiving_counter.counter);
 			goto next;
 		}
 
@@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct wor
 	struct sk_buff *skb;
 
 	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
-		enum packet_state state = likely(decrypt_packet(skb,
-				&PACKET_CB(skb)->keypair->receiving)) ?
+		enum packet_state state =
+			likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
 				PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
 		wg_queue_enqueue_per_peer_napi(skb, state);
 		if (need_resched())
--- a/drivers/net/wireguard/selftest/counter.c
+++ b/drivers/net/wireguard/selftest/counter.c
@@ -6,18 +6,24 @@
 #ifdef DEBUG
 bool __init wg_packet_counter_selftest(void)
 {
+	struct noise_replay_counter *counter;
 	unsigned int test_num = 0, i;
-	union noise_counter counter;
 	bool success = true;
 
-#define T_INIT do {                                               \
-		memset(&counter, 0, sizeof(union noise_counter)); \
-		spin_lock_init(&counter.receive.lock);            \
+	counter = kmalloc(sizeof(*counter), GFP_KERNEL);
+	if (unlikely(!counter)) {
+		pr_err("nonce counter self-test malloc: FAIL\n");
+		return false;
+	}
+
+#define T_INIT do {                                    \
+		memset(counter, 0, sizeof(*counter));  \
+		spin_lock_init(&counter->lock);        \
 	} while (0)
 #define T_LIM (COUNTER_WINDOW_SIZE + 1)
 #define T(n, v) do {                                                  \
 		++test_num;                                           \
-		if (counter_validate(&counter, n) != (v)) {           \
+		if (counter_validate(counter, n) != (v)) {            \
 			pr_err("nonce counter self-test %u: FAIL\n",  \
 			       test_num);                             \
 			success = false;                              \
@@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(v
 
 	if (success)
 		pr_info("nonce counter self-tests: pass\n");
+	kfree(counter);
 	return success;
 }
 #endif
--- a/drivers/net/wireguard/send.c
+++ b/drivers/net/wireguard/send.c
@@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_pee
 	rcu_read_lock_bh();
 	keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
 	send = keypair && READ_ONCE(keypair->sending.is_valid) &&
-	       (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
+	       (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
 		(keypair->i_am_the_initiator &&
 		 wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
 	rcu_read_unlock_bh();
@@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(stru
 
 void wg_packet_send_staged_packets(struct wg_peer *peer)
 {
-	struct noise_symmetric_key *key;
 	struct noise_keypair *keypair;
 	struct sk_buff_head packets;
 	struct sk_buff *skb;
@@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struc
 	rcu_read_unlock_bh();
 	if (unlikely(!keypair))
 		goto out_nokey;
-	key = &keypair->sending;
-	if (unlikely(!READ_ONCE(key->is_valid)))
+	if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
 		goto out_nokey;
-	if (unlikely(wg_birthdate_has_expired(key->birthdate,
+	if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
 					      REJECT_AFTER_TIME)))
 		goto out_invalid;
 
@@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struc
 		 */
 		PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
 		PACKET_CB(skb)->nonce =
-				atomic64_inc_return(&key->counter.counter) - 1;
+				atomic64_inc_return(&keypair->sending_counter) - 1;
 		if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
 			goto out_invalid;
 	}
@@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struc
 	return;
 
 out_invalid:
-	WRITE_ONCE(key->is_valid, false);
+	WRITE_ONCE(keypair->sending.is_valid, false);
 out_nokey:
 	wg_noise_keypair_put(keypair, false);