Compare commits
416 Commits
0.0.201909
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d3b6017b8 | ||
|
|
e28a5cc364 | ||
|
|
ddd62a90dd | ||
|
|
12269c2761 | ||
|
|
542e565baa | ||
|
|
7c20311b3d | ||
|
|
4ffa9c2032 | ||
|
|
d0bc03c707 | ||
|
|
1cf89f5339 | ||
|
|
2e0774f246 | ||
|
|
b3df23dcd4 | ||
|
|
f502ec3fad | ||
|
|
5d37bd24e1 | ||
|
|
24ea13351e | ||
|
|
177caa7e44 | ||
|
|
42ec952ead | ||
|
|
ec8f6f82c2 | ||
|
|
1ec454f253 | ||
|
|
8a015f7c76 | ||
|
|
895d6c23cd | ||
|
|
4201e08f1d | ||
|
|
6a84778f2c | ||
|
|
469159ecf7 | ||
|
|
6e755e132a | ||
|
|
1f25eac395 | ||
|
|
25eb973e00 | ||
|
|
b7cd547315 | ||
|
|
052af4a807 | ||
|
|
aad7fca9c5 | ||
|
|
6f895be10d | ||
|
|
6a07b2a355 | ||
|
|
334b605e72 | ||
|
|
3a9e75374f | ||
|
|
cc20c08c96 | ||
|
|
1417a47c8f | ||
|
|
7f511c3bb1 | ||
|
|
07a1e55270 | ||
|
|
fff53afca7 | ||
|
|
0ad14a89f5 | ||
|
|
7d327ed35a | ||
|
|
f41f474466 | ||
|
|
5819c6af28 | ||
|
|
6901984f6a | ||
|
|
2fcdaf9799 | ||
|
|
dbd949307e | ||
|
|
f26efb65f2 | ||
|
|
f67c862a2a | ||
|
|
9e2f386022 | ||
|
|
3bb8fec7e4 | ||
|
|
21636207a6 | ||
|
|
c7b76d3d9e | ||
|
|
1e2c3e5a3c | ||
|
|
ebbd4a4330 | ||
|
|
0ae4b3177c | ||
|
|
077ce8ecab | ||
|
|
bb719d3a6e | ||
|
|
fde0a9525a | ||
|
|
b51010ba13 | ||
|
|
d1d08426b2 | ||
|
|
3381e21b18 | ||
|
|
c31a7b1ab4 | ||
|
|
6a08d81f6b | ||
|
|
ef5c587f78 | ||
|
|
193cf8d6a5 | ||
|
|
ee1c8e0e87 | ||
|
|
95b48cdb39 | ||
|
|
5aff28b14c | ||
|
|
46826fc4e5 | ||
|
|
42c9af45e1 | ||
|
|
ae6bc4dd64 | ||
|
|
2cec4d1a62 | ||
|
|
3b95c81cc1 | ||
|
|
b9669b734e | ||
|
|
e0b8f11489 | ||
|
|
114a3db918 | ||
|
|
9c9e7e2724 | ||
|
|
2dd424e2d8 | ||
|
|
387f7c461a | ||
|
|
4d87c9e824 | ||
|
|
ef8d6804d7 | ||
|
|
de7c702ace | ||
|
|
fc4f975a4d | ||
|
|
9d699ba730 | ||
|
|
425f7c726b | ||
|
|
3cae233d69 | ||
|
|
111e0566dc | ||
|
|
e3134bf665 | ||
|
|
63abb5537b | ||
|
|
851efb1bb6 | ||
|
|
c07dd60cdb | ||
|
|
eb6302c7eb | ||
|
|
60683d7361 | ||
|
|
e42c6c4bc2 | ||
|
|
828a885a71 | ||
|
|
f1f626090e | ||
|
|
82e0b734e5 | ||
|
|
fdf57a1fa4 | ||
|
|
f87e87af0d | ||
|
|
ba9e364dab | ||
|
|
dfd688b6aa | ||
|
|
c01d52b66a | ||
|
|
82d2aa87aa | ||
|
|
982d5d2e84 | ||
|
|
642a56e165 | ||
|
|
bb745b2ea3 | ||
|
|
fcc601dbf0 | ||
|
|
217ac1016b | ||
|
|
eae5e0f3a3 | ||
|
|
2ef39d4754 | ||
|
|
3957e9b9dd | ||
|
|
bad6caeb82 | ||
|
|
c89f5ca665 | ||
|
|
15b24b6179 | ||
|
|
f9b48a961c | ||
|
|
d0cf96114f | ||
|
|
841756e328 | ||
|
|
c382222eab | ||
|
|
b41f4cc768 | ||
|
|
4a57024b94 | ||
|
|
64cb82f2b3 | ||
|
|
c27ff9b9f6 | ||
|
|
99e8b4ba60 | ||
|
|
bd83f0ac99 | ||
|
|
50d779833e | ||
|
|
a9b377e9e1 | ||
|
|
9087e444e6 | ||
|
|
25ad08a591 | ||
|
|
5846b62283 | ||
|
|
9844c74f67 | ||
|
|
4e9e5dad09 | ||
|
|
39e0b6dade | ||
|
|
7121927b87 | ||
|
|
326aec10af | ||
|
|
efb8818550 | ||
|
|
69b39db0b4 | ||
|
|
db733ccd65 | ||
|
|
a7aec4449f | ||
|
|
60a26371f4 | ||
|
|
a544776d70 | ||
|
|
69a42a4eef | ||
|
|
097af6e135 | ||
|
|
8246d251ea | ||
|
|
c9db4b7aaa | ||
|
|
3625f8d284 | ||
|
|
0687dc06c8 | ||
|
|
71aefa374d | ||
|
|
3d3e30beb8 | ||
|
|
b0e5b19969 | ||
|
|
3988821442 | ||
|
|
c7cd2c9eab | ||
|
|
54dbe2471f | ||
|
|
d2fd0c0cc0 | ||
|
|
5f6bbe4ae8 | ||
|
|
75526d6071 | ||
|
|
fbf97502cf | ||
|
|
10533c3e73 | ||
|
|
8ed83e0427 | ||
|
|
6228659a91 | ||
|
|
517f0703f5 | ||
|
|
204140016a | ||
|
|
822f5a6d70 | ||
|
|
02e419ed8a | ||
|
|
bc69a3fa60 | ||
|
|
12ce53271b | ||
|
|
5f0c8b942d | ||
|
|
c5f382624e | ||
|
|
6005c573e2 | ||
|
|
82f3e9e2af | ||
|
|
4885e7c954 | ||
|
|
497ba95de7 | ||
|
|
0eb7206295 | ||
|
|
20714ca472 | ||
|
|
c1e09f1927 | ||
|
|
79611c64e8 | ||
|
|
593658d975 | ||
|
|
3c11c0308e | ||
|
|
f9dac7099e | ||
|
|
9a29ae267c | ||
|
|
6603c05a4a | ||
|
|
a4f8e83d5d | ||
|
|
c69481f1b3 | ||
|
|
0f4809f366 | ||
|
|
fecb8f482a | ||
|
|
8bf4204d2e | ||
|
|
4e439ea10e | ||
|
|
7a0fb5bbb1 | ||
|
|
c7b7998619 | ||
|
|
ef8115f63b | ||
|
|
75e6d810ed | ||
|
|
747f5440bc | ||
|
|
aabc3770ba | ||
|
|
484a9fd324 | ||
|
|
5bf8d73127 | ||
|
|
587a2b2a20 | ||
|
|
6f08a10041 | ||
|
|
a97ef39cd4 | ||
|
|
c040dea798 | ||
|
|
5cdb862f15 | ||
|
|
da32fe328b | ||
|
|
4eab21a7b7 | ||
|
|
30b96ba083 | ||
|
|
78ebce6932 | ||
|
|
cae090d116 | ||
|
|
465261310b | ||
|
|
d117d42ae7 | ||
|
|
ecceaadd16 | ||
|
|
9e728c2eb0 | ||
|
|
eaf664e4e9 | ||
|
|
a816e8511e | ||
|
|
02138f1f81 | ||
|
|
d7bc7508e5 | ||
|
|
d6e76fdbd6 | ||
|
|
6ac1240821 | ||
|
|
4b5d15ec2b | ||
|
|
6548a682a9 | ||
|
|
a60e6dab76 | ||
|
|
d8dd1f254f | ||
|
|
57aadfcb14 | ||
|
|
af408eb940 | ||
|
|
15810daa22 | ||
|
|
d840445e9b | ||
|
|
675ff32e6c | ||
|
|
3516ccc1e2 | ||
|
|
0bcb822e5b | ||
|
|
da95677203 | ||
|
|
9c75f58f3d | ||
|
|
84a42aed63 | ||
|
|
4192036acd | ||
|
|
9c7bd73be2 | ||
|
|
01e176af3c | ||
|
|
91617b4c52 | ||
|
|
7258a8973d | ||
|
|
d9d547a3f3 | ||
|
|
c3bde5f590 | ||
|
|
fd63a233c9 | ||
|
|
8a374a35a0 | ||
|
|
4846070322 | ||
|
|
a9f80d8c58 | ||
|
|
de51129e33 | ||
|
|
beb25cc4fd | ||
|
|
9263014ed3 | ||
|
|
f0f27d7fd2 | ||
|
|
d4112d9096 | ||
|
|
bf3bb88851 | ||
|
|
6a128dde71 | ||
|
|
34c047c762 | ||
|
|
d4725bc456 | ||
|
|
1b092ce584 | ||
|
|
a11dec5dc1 | ||
|
|
ace50a0529 | ||
|
|
8cc99631d0 | ||
|
|
d669c78c43 | ||
|
|
7139279cd0 | ||
|
|
37efdcaccf | ||
|
|
d3a2b74df2 | ||
|
|
8114c9db5f | ||
|
|
e6ec3852a9 | ||
|
|
23b2790aa0 | ||
|
|
18e47795e5 | ||
|
|
a29767dda6 | ||
|
|
cecb41515d | ||
|
|
a9ce4b762c | ||
|
|
d8f2cc87ee | ||
|
|
2b8665f5f9 | ||
|
|
674a4675a1 | ||
|
|
87bdcb2ae4 | ||
|
|
37a239e736 | ||
|
|
6252de0db9 | ||
|
|
a029b942ae | ||
|
|
db3fa1409c | ||
|
|
675aae2423 | ||
|
|
fcc8ad05df | ||
|
|
1d4eb2727a | ||
|
|
294d3bedf9 | ||
|
|
86a58b51c0 | ||
|
|
6a2ecb581b | ||
|
|
f07177c762 | ||
|
|
b00b2c2951 | ||
|
|
7c5d1e355e | ||
|
|
a86492a567 | ||
|
|
7ee95e053c | ||
|
|
291dbcf1f0 | ||
|
|
abc88c82b1 | ||
|
|
23642a13be | ||
|
|
2fe19ce54d | ||
|
|
0cc15e7c7c | ||
|
|
48c3b87eb8 | ||
|
|
675955de5d | ||
|
|
ea6c1cd7e6 | ||
|
|
3b3de758ec | ||
|
|
29b0477585 | ||
|
|
85b4950579 | ||
|
|
8a30415555 | ||
|
|
cdaf4e9a76 | ||
|
|
3d83df9bf3 | ||
|
|
d664444928 | ||
|
|
1481e72107 | ||
|
|
d0f8e9477c | ||
|
|
b42e32047d | ||
|
|
b5f966ac24 | ||
|
|
a1c265b0c5 | ||
|
|
25b01723dd | ||
|
|
40dfc85def | ||
|
|
890cc06ed5 | ||
|
|
ad73ee78e9 | ||
|
|
e9edc16349 | ||
|
|
f7bbdc31a0 | ||
|
|
70861686d3 | ||
|
|
c8faa34cde | ||
|
|
2832e96339 | ||
|
|
63066ce406 | ||
|
|
e1fa1cc556 | ||
|
|
41cd68416c | ||
|
|
94b33ba705 | ||
|
|
ea8fbb5927 | ||
|
|
93a4313c3a | ||
|
|
db1edc7e91 | ||
|
|
fc0aabbae9 | ||
|
|
c9e4a859ae | ||
|
|
3591acba76 | ||
|
|
ca9edf1c63 | ||
|
|
347ce76bbc | ||
|
|
c4895658e6 | ||
|
|
d3ff2d6b62 | ||
|
|
01d3aaa7f4 | ||
|
|
b6303091fc | ||
|
|
c9fabbd5bf | ||
|
|
4cc7a7a455 | ||
|
|
da19db415a | ||
|
|
52c834c446 | ||
|
|
913f68ce38 | ||
|
|
60b3766b89 | ||
|
|
82128c47d9 | ||
|
|
c192b2eeec | ||
|
|
a3b231b31e | ||
|
|
65e03a9182 | ||
|
|
3e08b8aee0 | ||
|
|
5ca1218a5c | ||
|
|
3b490f30aa | ||
|
|
e6b7c4eef3 | ||
|
|
8ae09213a7 | ||
|
|
36dc8b6994 | ||
|
|
2057f19a61 | ||
|
|
58a8f05f50 | ||
|
|
0b54907a73 | ||
|
|
2c143dce0f | ||
|
|
22af3890f6 | ||
|
|
c8fe925020 | ||
|
|
0cfa3314ee | ||
|
|
bc3f505efa | ||
|
|
507f148e1c | ||
|
|
31b574ef99 | ||
|
|
3c41141fb4 | ||
|
|
4369db522b | ||
|
|
b84f1d4db2 | ||
|
|
dfb28757f7 | ||
|
|
00bcd865e6 | ||
|
|
f28a6d244b | ||
|
|
c403da6a39 | ||
|
|
d6de6f3ce6 | ||
|
|
59e556f24e | ||
|
|
31faf4c159 | ||
|
|
99eb7896be | ||
|
|
f60b3919be | ||
|
|
da9d300cf8 | ||
|
|
59c9929714 | ||
|
|
db0aa39b76 | ||
|
|
bc77de2aca | ||
|
|
c8596328e7 | ||
|
|
28c4d04304 | ||
|
|
fdba6c183a | ||
|
|
250b9795f3 | ||
|
|
d60857e1a7 | ||
|
|
2fb0a712f0 | ||
|
|
f2c6faad44 | ||
|
|
c76b818466 | ||
|
|
de374bfb44 | ||
|
|
1a1c3d0968 | ||
|
|
85a45a9651 | ||
|
|
abd287159e | ||
|
|
203554620d | ||
|
|
6aefb61355 | ||
|
|
3dce460c88 | ||
|
|
224bc9e60c | ||
|
|
9cd8909df2 | ||
|
|
ae88e2a2cd | ||
|
|
4739708ca4 | ||
|
|
b33219c2cf | ||
|
|
9cbcff10dd | ||
|
|
6ed56ff2df | ||
|
|
cb4bb63030 | ||
|
|
05b03c6750 | ||
|
|
caebdfe9d0 | ||
|
|
4fa2ea6a2d | ||
|
|
89dd065e53 | ||
|
|
ddfad453cf | ||
|
|
2b242f9393 | ||
|
|
4cdf805b29 | ||
|
|
f7d0edd2ec | ||
|
|
ffffbbcc8a | ||
|
|
47b02c618b | ||
|
|
fd23c66fcd | ||
|
|
ae492d1b35 | ||
|
|
95fbfccf60 | ||
|
|
c85e4a410f | ||
|
|
1b6c8ddbe8 | ||
|
|
0abb6b668c | ||
|
|
540d01e54a | ||
|
|
f2ea85e9f9 | ||
|
|
222f0f8000 | ||
|
|
1f146a5e7a | ||
|
|
f2501aa6c8 | ||
|
|
cb8d01f58a | ||
|
|
01f8ef4e84 | ||
|
|
70f6c42556 | ||
|
|
bb0b2514c0 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1 @@
|
|||||||
wireguard-go
|
wireguard-go
|
||||||
vendor
|
|
||||||
.gopath
|
|
||||||
ireallywantobuildon_linux.go
|
|
||||||
|
|||||||
13
Makefile
13
Makefile
@@ -10,10 +10,10 @@ MAKEFLAGS += --no-print-directory
|
|||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
||||||
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > device/version.go && \
|
echo "$$ver" > version.go && \
|
||||||
git update-index --assume-unchanged device/version.go || true
|
git update-index --assume-unchanged version.go || true
|
||||||
@$(MAKE) wireguard-go
|
@$(MAKE) wireguard-go
|
||||||
|
|
||||||
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
||||||
@@ -22,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
|||||||
install: wireguard-go
|
install: wireguard-go
|
||||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./...
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f wireguard-go
|
rm -f wireguard-go
|
||||||
|
|
||||||
.PHONY: all clean install generate-version-and-build
|
.PHONY: all clean test install generate-version-and-build
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
|
|||||||
$ wireguard-go -f wg0
|
$ wireguard-go -f wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
|||||||
|
|
||||||
### Linux
|
### Linux
|
||||||
|
|
||||||
This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
|
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
|
|||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
This requires an installation of [go](https://golang.org) ≥ 1.12.
|
This requires an installation of the latest version of [Go](https://go.dev/).
|
||||||
|
|
||||||
```
|
```
|
||||||
$ git clone https://git.zx2c4.com/wireguard-go
|
$ git clone https://git.zx2c4.com/wireguard-go
|
||||||
@@ -56,7 +56,7 @@ $ make
|
|||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
this software and associated documentation files (the "Software"), to deal in
|
||||||
|
|||||||
544
conn/bind_std.go
Normal file
544
conn/bind_std.go
Normal file
@@ -0,0 +1,544 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*StdNetBind)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
||||||
|
// (see bind_windows.go), it may fall back to StdNetBind.
|
||||||
|
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
||||||
|
// methods for sending and receiving multiple datagrams per-syscall. See the
|
||||||
|
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
||||||
|
type StdNetBind struct {
|
||||||
|
mu sync.Mutex // protects all fields except as specified
|
||||||
|
ipv4 *net.UDPConn
|
||||||
|
ipv6 *net.UDPConn
|
||||||
|
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
||||||
|
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
||||||
|
ipv4TxOffload bool
|
||||||
|
ipv4RxOffload bool
|
||||||
|
ipv6TxOffload bool
|
||||||
|
ipv6RxOffload bool
|
||||||
|
|
||||||
|
// these two fields are not guarded by mu
|
||||||
|
udpAddrPool sync.Pool
|
||||||
|
msgsPool sync.Pool
|
||||||
|
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStdNetBind() Bind {
|
||||||
|
return &StdNetBind{
|
||||||
|
udpAddrPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: make([]byte, 16),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
msgsPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||||
|
// both aliases for x/net/internal/socket.Message.
|
||||||
|
msgs := make([]ipv6.Message, IdealBatchSize)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type StdNetEndpoint struct {
|
||||||
|
// AddrPort is the endpoint destination.
|
||||||
|
netip.AddrPort
|
||||||
|
// src is the current sticky source address and interface index, if
|
||||||
|
// supported. Typically this is a PKTINFO structure from/for control
|
||||||
|
// messages, see unix.PKTINFO for an example.
|
||||||
|
src []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*StdNetBind)(nil)
|
||||||
|
_ Endpoint = &StdNetEndpoint{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
|
e, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &StdNetEndpoint{
|
||||||
|
AddrPort: e,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) ClearSrc() {
|
||||||
|
if e.src != nil {
|
||||||
|
// Truncate src, no need to reallocate.
|
||||||
|
e.src = e.src[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
||||||
|
return e.AddrPort.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToBytes() []byte {
|
||||||
|
b, _ := e.AddrPort.MarshalBinary()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) DstToString() string {
|
||||||
|
return e.AddrPort.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve port.
|
||||||
|
laddr := conn.LocalAddr()
|
||||||
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
|
laddr.Network(),
|
||||||
|
laddr.String(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return conn.(*net.UDPConn), uaddr.Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var tries int
|
||||||
|
|
||||||
|
if s.ipv4 != nil || s.ipv6 != nil {
|
||||||
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
||||||
|
// If uport is 0, we can retry on failure.
|
||||||
|
again:
|
||||||
|
port := int(uport)
|
||||||
|
var v4conn, v6conn *net.UDPConn
|
||||||
|
var v4pc *ipv4.PacketConn
|
||||||
|
var v6pc *ipv6.PacketConn
|
||||||
|
|
||||||
|
v4conn, port, err = listenNet("udp4", port)
|
||||||
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen on the same port as we're using for ipv4.
|
||||||
|
v6conn, port, err = listenNet("udp6", port)
|
||||||
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
||||||
|
v4conn.Close()
|
||||||
|
tries++
|
||||||
|
goto again
|
||||||
|
}
|
||||||
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||||
|
v4conn.Close()
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
var fns []ReceiveFunc
|
||||||
|
if v4conn != nil {
|
||||||
|
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v4pc = ipv4.NewPacketConn(v4conn)
|
||||||
|
s.ipv4PC = v4pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
||||||
|
s.ipv4 = v4conn
|
||||||
|
}
|
||||||
|
if v6conn != nil {
|
||||||
|
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
v6pc = ipv6.NewPacketConn(v6conn)
|
||||||
|
s.ipv6PC = v6pc
|
||||||
|
}
|
||||||
|
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
||||||
|
s.ipv6 = v6conn
|
||||||
|
}
|
||||||
|
if len(fns) == 0 {
|
||||||
|
return nil, 0, syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
|
||||||
|
return fns, uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
||||||
|
for i := range *msgs {
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||||
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||||
|
}
|
||||||
|
s.msgsPool.Put(msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
||||||
|
return s.msgsPool.Get().(*[]ipv6.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// If compilation fails here these are no longer the same underlying type.
|
||||||
|
_ ipv6.Message = ipv4.Message{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type batchReader interface {
|
||||||
|
ReadBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchWriter interface {
|
||||||
|
WriteBatch([]ipv6.Message, int) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) receiveIP(
|
||||||
|
br batchReader,
|
||||||
|
conn *net.UDPConn,
|
||||||
|
rxOffload bool,
|
||||||
|
bufs [][]byte,
|
||||||
|
sizes []int,
|
||||||
|
eps []Endpoint,
|
||||||
|
) (n int, err error) {
|
||||||
|
msgs := s.getMessages()
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
|
}
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
var numMsgs int
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
if rxOffload {
|
||||||
|
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
||||||
|
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
msg := &(*msgs)[0]
|
||||||
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
numMsgs = 1
|
||||||
|
}
|
||||||
|
for i := 0; i < numMsgs; i++ {
|
||||||
|
msg := &(*msgs)[i]
|
||||||
|
sizes[i] = msg.N
|
||||||
|
if sizes[i] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
|
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
|
eps[i] = ep
|
||||||
|
}
|
||||||
|
return numMsgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
||||||
|
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
|
func (s *StdNetBind) BatchSize() int {
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
return IdealBatchSize
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
var err1, err2 error
|
||||||
|
if s.ipv4 != nil {
|
||||||
|
err1 = s.ipv4.Close()
|
||||||
|
s.ipv4 = nil
|
||||||
|
s.ipv4PC = nil
|
||||||
|
}
|
||||||
|
if s.ipv6 != nil {
|
||||||
|
err2 = s.ipv6.Close()
|
||||||
|
s.ipv6 = nil
|
||||||
|
s.ipv6PC = nil
|
||||||
|
}
|
||||||
|
s.blackhole4 = false
|
||||||
|
s.blackhole6 = false
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
s.ipv4RxOffload = false
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
s.ipv6RxOffload = false
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrUDPGSODisabled struct {
|
||||||
|
onLaddr string
|
||||||
|
RetryErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Error() string {
|
||||||
|
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ErrUDPGSODisabled) Unwrap() error {
|
||||||
|
return e.RetryErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
blackhole := s.blackhole4
|
||||||
|
conn := s.ipv4
|
||||||
|
offload := s.ipv4TxOffload
|
||||||
|
br := batchWriter(s.ipv4PC)
|
||||||
|
is6 := false
|
||||||
|
if endpoint.DstIP().Is6() {
|
||||||
|
blackhole = s.blackhole6
|
||||||
|
conn = s.ipv6
|
||||||
|
br = s.ipv6PC
|
||||||
|
is6 = true
|
||||||
|
offload = s.ipv6TxOffload
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if blackhole {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return syscall.EAFNOSUPPORT
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := s.getMessages()
|
||||||
|
defer s.putMessages(msgs)
|
||||||
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
||||||
|
defer s.udpAddrPool.Put(ua)
|
||||||
|
if is6 {
|
||||||
|
as16 := endpoint.DstIP().As16()
|
||||||
|
copy(ua.IP, as16[:])
|
||||||
|
ua.IP = ua.IP[:16]
|
||||||
|
} else {
|
||||||
|
as4 := endpoint.DstIP().As4()
|
||||||
|
copy(ua.IP, as4[:])
|
||||||
|
ua.IP = ua.IP[:4]
|
||||||
|
}
|
||||||
|
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
||||||
|
var (
|
||||||
|
retried bool
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
retry:
|
||||||
|
if offload {
|
||||||
|
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
||||||
|
err = s.send(conn, br, (*msgs)[:n])
|
||||||
|
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
||||||
|
offload = false
|
||||||
|
s.mu.Lock()
|
||||||
|
if is6 {
|
||||||
|
s.ipv6TxOffload = false
|
||||||
|
} else {
|
||||||
|
s.ipv4TxOffload = false
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range bufs {
|
||||||
|
(*msgs)[i].Addr = ua
|
||||||
|
(*msgs)[i].Buffers[0] = bufs[i]
|
||||||
|
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
||||||
|
}
|
||||||
|
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
||||||
|
}
|
||||||
|
if retried {
|
||||||
|
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
||||||
|
var (
|
||||||
|
n int
|
||||||
|
err error
|
||||||
|
start int
|
||||||
|
)
|
||||||
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
|
for {
|
||||||
|
n, err = pc.WriteBatch(msgs[start:], 0)
|
||||||
|
if err != nil || n == len(msgs[start:]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start += n
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, msg := range msgs {
|
||||||
|
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
||||||
|
// layer4 headers. IPv6 does not need to account for itself as the payload
|
||||||
|
// length field is self excluding.
|
||||||
|
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
||||||
|
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
||||||
|
|
||||||
|
// This is a hard limit imposed by the kernel.
|
||||||
|
udpSegmentMaxDatagrams = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
||||||
|
|
||||||
|
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
||||||
|
var (
|
||||||
|
base = -1 // index of msg we are currently coalescing into
|
||||||
|
gsoSize int // segmentation size of msgs[base]
|
||||||
|
dgramCnt int // number of dgrams coalesced into msgs[base]
|
||||||
|
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
||||||
|
)
|
||||||
|
maxPayloadLen := maxIPv4PayloadLen
|
||||||
|
if ep.DstIP().Is6() {
|
||||||
|
maxPayloadLen = maxIPv6PayloadLen
|
||||||
|
}
|
||||||
|
for i, buf := range bufs {
|
||||||
|
if i > 0 {
|
||||||
|
msgLen := len(buf)
|
||||||
|
baseLenBefore := len(msgs[base].Buffers[0])
|
||||||
|
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
||||||
|
if msgLen+baseLenBefore <= maxPayloadLen &&
|
||||||
|
msgLen <= gsoSize &&
|
||||||
|
msgLen <= freeBaseCap &&
|
||||||
|
dgramCnt < udpSegmentMaxDatagrams &&
|
||||||
|
!endBatch {
|
||||||
|
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
||||||
|
if i == len(bufs)-1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
dgramCnt++
|
||||||
|
if msgLen < gsoSize {
|
||||||
|
// A smaller than gsoSize packet on the tail is legal, but
|
||||||
|
// it must end the batch.
|
||||||
|
endBatch = true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dgramCnt > 1 {
|
||||||
|
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
||||||
|
}
|
||||||
|
// Reset prior to incrementing base since we are preparing to start a
|
||||||
|
// new potential batch.
|
||||||
|
endBatch = false
|
||||||
|
base++
|
||||||
|
gsoSize = len(buf)
|
||||||
|
setSrcControl(&msgs[base].OOB, ep)
|
||||||
|
msgs[base].Buffers[0] = buf
|
||||||
|
msgs[base].Addr = addr
|
||||||
|
dgramCnt = 1
|
||||||
|
}
|
||||||
|
return base + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type getGSOFunc func(control []byte) (int, error)
|
||||||
|
|
||||||
|
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||||
|
for i := firstMsgAt; i < len(msgs); i++ {
|
||||||
|
msg := &msgs[i]
|
||||||
|
if msg.N == 0 {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
gsoSize int
|
||||||
|
start int
|
||||||
|
end = msg.N
|
||||||
|
numToSplit = 1
|
||||||
|
)
|
||||||
|
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if gsoSize > 0 {
|
||||||
|
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||||
|
end = gsoSize
|
||||||
|
}
|
||||||
|
for j := 0; j < numToSplit; j++ {
|
||||||
|
if n > i {
|
||||||
|
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||||
|
}
|
||||||
|
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||||
|
msgs[n].N = copied
|
||||||
|
msgs[n].Addr = msg.Addr
|
||||||
|
start = end
|
||||||
|
end += gsoSize
|
||||||
|
if end > msg.N {
|
||||||
|
end = msg.N
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if i != n-1 {
|
||||||
|
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||||
|
// of splitting, so we only zero the source msg len when it is not
|
||||||
|
// the destination of the last split operation above.
|
||||||
|
msg.N = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
250
conn/bind_std_test.go
Normal file
250
conn/bind_std_test.go
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
||||||
|
bind := NewStdNetBind().(*StdNetBind)
|
||||||
|
fns, _, err := bind.Open(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bind.Close()
|
||||||
|
bufs := make([][]byte, 1)
|
||||||
|
bufs[0] = make([]byte, 1)
|
||||||
|
sizes := make([]int, 1)
|
||||||
|
eps := make([]Endpoint, 1)
|
||||||
|
for _, fn := range fns {
|
||||||
|
// The ReceiveFuncs must not access conn-related fields on StdNetBind
|
||||||
|
// unguarded. Close() nils the conn-related fields resulting in a panic
|
||||||
|
// if they violate the mutex.
|
||||||
|
fn(bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
binary.LittleEndian.PutUint16(*control, gsoSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_coalesceMessages(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
buffs [][]byte
|
||||||
|
wantLens []int
|
||||||
|
wantGSO []int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "one message no coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{1},
|
||||||
|
wantGSO: []int{0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages equal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 1, 2),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{2},
|
||||||
|
wantGSO: []int{1},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two messages unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
},
|
||||||
|
wantLens: []int{3},
|
||||||
|
wantGSO: []int{2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages second unequal len coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 3),
|
||||||
|
make([]byte, 1, 1),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{3, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "three messages limited cap coalesce",
|
||||||
|
buffs: [][]byte{
|
||||||
|
make([]byte, 2, 4),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
make([]byte, 2, 2),
|
||||||
|
},
|
||||||
|
wantLens: []int{4, 2},
|
||||||
|
wantGSO: []int{2, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr := &net.UDPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1").To4(),
|
||||||
|
Port: 1,
|
||||||
|
}
|
||||||
|
msgs := make([]ipv6.Message, len(tt.buffs))
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make([][]byte, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, 2)
|
||||||
|
}
|
||||||
|
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
|
||||||
|
if got != len(tt.wantLens) {
|
||||||
|
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
||||||
|
}
|
||||||
|
for i := 0; i < got; i++ {
|
||||||
|
if msgs[i].Addr != addr {
|
||||||
|
t.Errorf("msgs[%d].Addr != passed addr", i)
|
||||||
|
}
|
||||||
|
gotLen := len(msgs[i].Buffers[0])
|
||||||
|
if gotLen != tt.wantLens[i] {
|
||||||
|
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
||||||
|
}
|
||||||
|
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
||||||
|
}
|
||||||
|
if gotGSO != tt.wantGSO[i] {
|
||||||
|
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mockGetGSOSize(control []byte) (int, error) {
|
||||||
|
if len(control) < 2 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return int(binary.LittleEndian.Uint16(control)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_splitCoalescedMessages(t *testing.T) {
|
||||||
|
newMsg := func(n, gso int) ipv6.Message {
|
||||||
|
msg := ipv6.Message{
|
||||||
|
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
||||||
|
N: n,
|
||||||
|
OOB: make([]byte, 2),
|
||||||
|
}
|
||||||
|
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
||||||
|
if gso > 0 {
|
||||||
|
msg.NN = 2
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
msgs []ipv6.Message
|
||||||
|
firstMsgAt int
|
||||||
|
wantNumEval int
|
||||||
|
wantMsgLens []int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "second last split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 3,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last empty",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 1,
|
||||||
|
wantMsgLens: []int{1, 0, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last no split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 2,
|
||||||
|
wantMsgLens: []int{1, 1, 0, 0},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(3, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last split last split",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(2, 1),
|
||||||
|
newMsg(2, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "second last no split last split overflow",
|
||||||
|
msgs: []ipv6.Message{
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(0, 0),
|
||||||
|
newMsg(1, 0),
|
||||||
|
newMsg(4, 1),
|
||||||
|
},
|
||||||
|
firstMsgAt: 2,
|
||||||
|
wantNumEval: 4,
|
||||||
|
wantMsgLens: []int{1, 1, 1, 1},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
||||||
|
if err != nil && !tt.wantErr {
|
||||||
|
t.Fatalf("err: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.wantNumEval {
|
||||||
|
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
||||||
|
}
|
||||||
|
for i, msg := range tt.msgs {
|
||||||
|
if msg.N != tt.wantMsgLens[i] {
|
||||||
|
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
601
conn/bind_windows.go
Normal file
601
conn/bind_windows.go
Normal file
@@ -0,0 +1,601 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn/winrio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
packetsPerRing = 1024
|
||||||
|
bytesPerPacket = 2048 - 32
|
||||||
|
receiveSpins = 15
|
||||||
|
)
|
||||||
|
|
||||||
|
type ringPacket struct {
|
||||||
|
addr WinRingEndpoint
|
||||||
|
data [bytesPerPacket]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ringBuffer struct {
|
||||||
|
packets uintptr
|
||||||
|
head, tail uint32
|
||||||
|
id winrio.BufferId
|
||||||
|
iocp windows.Handle
|
||||||
|
isFull bool
|
||||||
|
cq winrio.Cq
|
||||||
|
mu sync.Mutex
|
||||||
|
overlapped windows.Overlapped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rb *ringBuffer) Push() *ringPacket {
|
||||||
|
for rb.isFull {
|
||||||
|
panic("ring is full")
|
||||||
|
}
|
||||||
|
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
|
||||||
|
rb.tail += 1
|
||||||
|
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
|
||||||
|
rb.isFull = true
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rb *ringBuffer) Return(count uint32) {
|
||||||
|
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rb.head += count
|
||||||
|
rb.isFull = false
|
||||||
|
}
|
||||||
|
|
||||||
|
type afWinRingBind struct {
|
||||||
|
sock windows.Handle
|
||||||
|
rx, tx ringBuffer
|
||||||
|
rq winrio.Rq
|
||||||
|
mu sync.Mutex
|
||||||
|
blackhole bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
|
||||||
|
type WinRingBind struct {
|
||||||
|
v4, v6 afWinRingBind
|
||||||
|
mu sync.RWMutex
|
||||||
|
isOpen atomic.Uint32 // 0, 1, or 2
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDefaultBind() Bind { return NewWinRingBind() }
|
||||||
|
|
||||||
|
func NewWinRingBind() Bind {
|
||||||
|
if !winrio.Initialize() {
|
||||||
|
return NewStdNetBind()
|
||||||
|
}
|
||||||
|
return new(WinRingBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WinRingEndpoint struct {
|
||||||
|
family uint16
|
||||||
|
data [30]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Bind = (*WinRingBind)(nil)
|
||||||
|
_ Endpoint = (*WinRingEndpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||||
|
host, port, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
host16, err := windows.UTF16PtrFromString(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
port16, err := windows.UTF16PtrFromString(port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
hints := windows.AddrinfoW{
|
||||||
|
Flags: windows.AI_NUMERICHOST,
|
||||||
|
Family: windows.AF_UNSPEC,
|
||||||
|
Socktype: windows.SOCK_DGRAM,
|
||||||
|
Protocol: windows.IPPROTO_UDP,
|
||||||
|
}
|
||||||
|
var addrinfo *windows.AddrinfoW
|
||||||
|
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer windows.FreeAddrInfoW(addrinfo)
|
||||||
|
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
|
||||||
|
return nil, windows.ERROR_INVALID_ADDRESS
|
||||||
|
}
|
||||||
|
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
|
||||||
|
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
|
||||||
|
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*WinRingEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (e *WinRingEndpoint) DstIP() netip.Addr {
|
||||||
|
switch e.family {
|
||||||
|
case windows.AF_INET:
|
||||||
|
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
|
||||||
|
case windows.AF_INET6:
|
||||||
|
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
|
||||||
|
}
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WinRingEndpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{} // not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WinRingEndpoint) DstToBytes() []byte {
|
||||||
|
switch e.family {
|
||||||
|
case windows.AF_INET:
|
||||||
|
b := make([]byte, 0, 6)
|
||||||
|
b = append(b, e.data[2:6]...)
|
||||||
|
b = append(b, e.data[1], e.data[0])
|
||||||
|
return b
|
||||||
|
case windows.AF_INET6:
|
||||||
|
b := make([]byte, 0, 18)
|
||||||
|
b = append(b, e.data[6:22]...)
|
||||||
|
b = append(b, e.data[1], e.data[0])
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WinRingEndpoint) DstToString() string {
|
||||||
|
switch e.family {
|
||||||
|
case windows.AF_INET:
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
|
case windows.AF_INET6:
|
||||||
|
var zone string
|
||||||
|
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
||||||
|
zone = strconv.FormatUint(uint64(scope), 10)
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *WinRingEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) CloseAndZero() {
|
||||||
|
if ring.cq != 0 {
|
||||||
|
winrio.CloseCompletionQueue(ring.cq)
|
||||||
|
ring.cq = 0
|
||||||
|
}
|
||||||
|
if ring.iocp != 0 {
|
||||||
|
windows.CloseHandle(ring.iocp)
|
||||||
|
ring.iocp = 0
|
||||||
|
}
|
||||||
|
if ring.id != 0 {
|
||||||
|
winrio.DeregisterBuffer(ring.id)
|
||||||
|
ring.id = 0
|
||||||
|
}
|
||||||
|
if ring.packets != 0 {
|
||||||
|
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
|
||||||
|
ring.packets = 0
|
||||||
|
}
|
||||||
|
ring.head = 0
|
||||||
|
ring.tail = 0
|
||||||
|
ring.isFull = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *afWinRingBind) CloseAndZero() {
|
||||||
|
bind.rx.CloseAndZero()
|
||||||
|
bind.tx.CloseAndZero()
|
||||||
|
if bind.sock != 0 {
|
||||||
|
windows.CloseHandle(bind.sock)
|
||||||
|
bind.sock = 0
|
||||||
|
}
|
||||||
|
bind.blackhole = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) closeAndZero() {
|
||||||
|
bind.isOpen.Store(0)
|
||||||
|
bind.v4.CloseAndZero()
|
||||||
|
bind.v6.CloseAndZero()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ring *ringBuffer) Open() error {
|
||||||
|
var err error
|
||||||
|
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
|
||||||
|
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
|
||||||
|
var err error
|
||||||
|
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = bind.rx.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = bind.tx.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = windows.Bind(bind.sock, sa)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sa, err = windows.Getsockname(bind.sock)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return sa, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
bind.closeAndZero()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if bind.isOpen.Load() != 0 {
|
||||||
|
return nil, 0, ErrBindAlreadyOpen
|
||||||
|
}
|
||||||
|
var sa windows.Sockaddr
|
||||||
|
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
|
||||||
|
for i := 0; i < packetsPerRing; i++ {
|
||||||
|
err = bind.v4.InsertReceiveRequest()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
err = bind.v6.InsertReceiveRequest()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bind.isOpen.Store(1)
|
||||||
|
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) Close() error {
|
||||||
|
bind.mu.RLock()
|
||||||
|
if bind.isOpen.Load() != 1 {
|
||||||
|
bind.mu.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bind.isOpen.Store(2)
|
||||||
|
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
||||||
|
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
||||||
|
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
||||||
|
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
|
||||||
|
bind.mu.RUnlock()
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
bind.closeAndZero()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
||||||
|
// rename the IdealBatchSize constant to BatchSize.
|
||||||
|
func (bind *WinRingBind) BatchSize() int {
|
||||||
|
// TODO: implement batching in and out of the ring
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) SetMark(mark uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *afWinRingBind) InsertReceiveRequest() error {
|
||||||
|
packet := bind.rx.Push()
|
||||||
|
dataBuffer := &winrio.Buffer{
|
||||||
|
Id: bind.rx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
|
||||||
|
Length: uint32(len(packet.data)),
|
||||||
|
}
|
||||||
|
addressBuffer := &winrio.Buffer{
|
||||||
|
Id: bind.rx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
|
||||||
|
Length: uint32(unsafe.Sizeof(packet.addr)),
|
||||||
|
}
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname procyield runtime.procyield
|
||||||
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
|
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
bind.rx.mu.Lock()
|
||||||
|
defer bind.rx.mu.Unlock()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var count uint32
|
||||||
|
var results [1]winrio.Result
|
||||||
|
retry:
|
||||||
|
count = 0
|
||||||
|
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
||||||
|
if tries > 0 {
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
procyield(1)
|
||||||
|
}
|
||||||
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
err = winrio.Notify(bind.rx.cq)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var overlapped *windows.Overlapped
|
||||||
|
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
||||||
|
if count == 0 {
|
||||||
|
return 0, nil, io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bind.rx.Return(1)
|
||||||
|
err = bind.InsertReceiveRequest()
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil, err
|
||||||
|
}
|
||||||
|
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
|
||||||
|
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
||||||
|
// attacker bandwidth, just like the rest of the receive path.
|
||||||
|
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
if results[0].Status != 0 {
|
||||||
|
return 0, nil, windows.Errno(results[0].Status)
|
||||||
|
}
|
||||||
|
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
|
||||||
|
ep := packet.addr
|
||||||
|
n := copy(buf, packet.data[:results[0].BytesTransferred])
|
||||||
|
return n, &ep, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
|
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
|
||||||
|
sizes[0] = n
|
||||||
|
eps[0] = ep
|
||||||
|
return 1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
|
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
|
||||||
|
sizes[0] = n
|
||||||
|
eps[0] = ep
|
||||||
|
return 1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
if len(buf) > bytesPerPacket {
|
||||||
|
return io.ErrShortBuffer
|
||||||
|
}
|
||||||
|
bind.tx.mu.Lock()
|
||||||
|
defer bind.tx.mu.Unlock()
|
||||||
|
var results [packetsPerRing]winrio.Result
|
||||||
|
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
|
||||||
|
if count == 0 && bind.tx.isFull {
|
||||||
|
err := winrio.Notify(bind.tx.cq)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var overlapped *windows.Overlapped
|
||||||
|
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if isOpen.Load() != 1 {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
||||||
|
if count == 0 {
|
||||||
|
return io.ErrNoProgress
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if count > 0 {
|
||||||
|
bind.tx.Return(count)
|
||||||
|
}
|
||||||
|
packet := bind.tx.Push()
|
||||||
|
packet.addr = *nend
|
||||||
|
copy(packet.data[:], buf)
|
||||||
|
dataBuffer := &winrio.Buffer{
|
||||||
|
Id: bind.tx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
|
||||||
|
Length: uint32(len(buf)),
|
||||||
|
}
|
||||||
|
addressBuffer := &winrio.Buffer{
|
||||||
|
Id: bind.tx.id,
|
||||||
|
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
|
||||||
|
Length: uint32(unsafe.Sizeof(packet.addr)),
|
||||||
|
}
|
||||||
|
bind.mu.Lock()
|
||||||
|
defer bind.mu.Unlock()
|
||||||
|
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
||||||
|
nend, ok := endpoint.(*WinRingEndpoint)
|
||||||
|
if !ok {
|
||||||
|
return ErrWrongEndpointType
|
||||||
|
}
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
|
for _, buf := range bufs {
|
||||||
|
switch nend.family {
|
||||||
|
case windows.AF_INET:
|
||||||
|
if bind.v4.blackhole {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case windows.AF_INET6:
|
||||||
|
if bind.v6.blackhole {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sysconn, err := s.ipv4.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err2 := sysconn.Control(func(fd uintptr) {
|
||||||
|
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
|
||||||
|
})
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.blackhole4 = blackhole
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
sysconn, err := s.ipv6.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err2 := sysconn.Control(func(fd uintptr) {
|
||||||
|
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
|
||||||
|
})
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.blackhole6 = blackhole
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
|
if bind.isOpen.Load() != 1 {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bind.v4.blackhole = blackhole
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
|
bind.mu.RLock()
|
||||||
|
defer bind.mu.RUnlock()
|
||||||
|
if bind.isOpen.Load() != 1 {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bind.v6.blackhole = blackhole
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
|
||||||
|
const IP_UNICAST_IF = 31
|
||||||
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||||
|
var bytes [4]byte
|
||||||
|
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
|
||||||
|
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||||
|
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
|
||||||
|
const IPV6_UNICAST_IF = 31
|
||||||
|
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
|
||||||
|
}
|
||||||
136
conn/bindtest/bindtest.go
Normal file
136
conn/bindtest/bindtest.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package bindtest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChannelBind struct {
|
||||||
|
rx4, tx4 *chan []byte
|
||||||
|
rx6, tx6 *chan []byte
|
||||||
|
closeSignal chan bool
|
||||||
|
source4, source6 ChannelEndpoint
|
||||||
|
target4, target6 ChannelEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChannelEndpoint uint16
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ conn.Bind = (*ChannelBind)(nil)
|
||||||
|
_ conn.Endpoint = (*ChannelEndpoint)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewChannelBinds() [2]conn.Bind {
|
||||||
|
arx4 := make(chan []byte, 8192)
|
||||||
|
brx4 := make(chan []byte, 8192)
|
||||||
|
arx6 := make(chan []byte, 8192)
|
||||||
|
brx6 := make(chan []byte, 8192)
|
||||||
|
var binds [2]ChannelBind
|
||||||
|
binds[0].rx4 = &arx4
|
||||||
|
binds[0].tx4 = &brx4
|
||||||
|
binds[1].rx4 = &brx4
|
||||||
|
binds[1].tx4 = &arx4
|
||||||
|
binds[0].rx6 = &arx6
|
||||||
|
binds[0].tx6 = &brx6
|
||||||
|
binds[1].rx6 = &brx6
|
||||||
|
binds[1].tx6 = &arx6
|
||||||
|
binds[0].target4 = ChannelEndpoint(1)
|
||||||
|
binds[1].target4 = ChannelEndpoint(2)
|
||||||
|
binds[0].target6 = ChannelEndpoint(3)
|
||||||
|
binds[1].target6 = ChannelEndpoint(4)
|
||||||
|
binds[0].source4 = binds[1].target4
|
||||||
|
binds[0].source6 = binds[1].target6
|
||||||
|
binds[1].source4 = binds[0].target4
|
||||||
|
binds[1].source6 = binds[0].target6
|
||||||
|
return [2]conn.Bind{&binds[0], &binds[1]}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) SrcToString() string { return "" }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
||||||
|
|
||||||
|
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
|
||||||
|
|
||||||
|
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
|
c.closeSignal = make(chan bool)
|
||||||
|
fns = append(fns, c.makeReceiveFunc(*c.rx4))
|
||||||
|
fns = append(fns, c.makeReceiveFunc(*c.rx6))
|
||||||
|
if rand.Uint32()&1 == 0 {
|
||||||
|
return fns, uint16(c.source4), nil
|
||||||
|
} else {
|
||||||
|
return fns, uint16(c.source6), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) Close() error {
|
||||||
|
if c.closeSignal != nil {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
default:
|
||||||
|
close(c.closeSignal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) BatchSize() int { return 1 }
|
||||||
|
|
||||||
|
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||||
|
|
||||||
|
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
case rx := <-ch:
|
||||||
|
copied := copy(bufs[0], rx)
|
||||||
|
sizes[0] = copied
|
||||||
|
eps[0] = c.target6
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
|
for _, b := range bufs {
|
||||||
|
select {
|
||||||
|
case <-c.closeSignal:
|
||||||
|
return net.ErrClosed
|
||||||
|
default:
|
||||||
|
bc := make([]byte, len(b))
|
||||||
|
copy(bc, b)
|
||||||
|
if ep.(ChannelEndpoint) == c.target4 {
|
||||||
|
*c.tx4 <- bc
|
||||||
|
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||||
|
*c.tx6 <- bc
|
||||||
|
} else {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||||
|
addr, err := netip.ParseAddrPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ChannelEndpoint(addr.Port()), nil
|
||||||
|
}
|
||||||
34
conn/boundif_android.go
Normal file
34
conn/boundif_android.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||||
|
sysconn, err := s.ipv4.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
err = sysconn.Control(func(f uintptr) {
|
||||||
|
fd = int(f)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||||
|
sysconn, err := s.ipv6.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
err = sysconn.Control(func(f uintptr) {
|
||||||
|
fd = int(f)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
133
conn/conn.go
Normal file
133
conn/conn.go
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package conn implements WireGuard's network connections.
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
||||||
|
)
|
||||||
|
|
||||||
|
// A ReceiveFunc receives at least one packet from the network and writes them
|
||||||
|
// into packets. On a successful read it returns the number of elements of
|
||||||
|
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
||||||
|
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
||||||
|
// and eps slice with a length greater than or equal to the length of packets.
|
||||||
|
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
||||||
|
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
||||||
|
|
||||||
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
|
//
|
||||||
|
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||||
|
// depending on the platform-specific implementation.
|
||||||
|
type Bind interface {
|
||||||
|
// Open puts the Bind into a listening state on a given port and reports the actual
|
||||||
|
// port that it bound to. Passing zero results in a random selection.
|
||||||
|
// fns is the set of functions that will be called to receive packets.
|
||||||
|
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
||||||
|
|
||||||
|
// Close closes the Bind listener.
|
||||||
|
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// SetMark sets the mark for each packet sent through this Bind.
|
||||||
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
|
// Send writes one or more packets in bufs to address ep. The length of
|
||||||
|
// bufs must not exceed BatchSize().
|
||||||
|
Send(bufs [][]byte, ep Endpoint) error
|
||||||
|
|
||||||
|
// ParseEndpoint creates a new endpoint from a string.
|
||||||
|
ParseEndpoint(s string) (Endpoint, error)
|
||||||
|
|
||||||
|
// BatchSize is the number of buffers expected to be passed to
|
||||||
|
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
||||||
|
BatchSize() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindSocketToInterface is implemented by Bind objects that support being
|
||||||
|
// tied to a single network interface. Used by wireguard-windows.
|
||||||
|
type BindSocketToInterface interface {
|
||||||
|
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||||
|
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||||
|
// file descriptor peeked at. Used by wireguard-android.
|
||||||
|
type PeekLookAtSocketFd interface {
|
||||||
|
PeekLookAtSocketFd4() (fd int, err error)
|
||||||
|
PeekLookAtSocketFd6() (fd int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Endpoint maintains the source/destination caching for a peer.
|
||||||
|
//
|
||||||
|
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
// src: the local address from which datagrams originate going to the peer
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() netip.Addr
|
||||||
|
SrcIP() netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrBindAlreadyOpen = errors.New("bind is already open")
|
||||||
|
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (fn ReceiveFunc) PrettyName() string {
|
||||||
|
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||||
|
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
||||||
|
name = strings.TrimSuffix(name, "-fm")
|
||||||
|
// 1. cheese/taco.beansIPv6.func12.func21218
|
||||||
|
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 2. taco.beansIPv6.func12.func21218
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
var idx int
|
||||||
|
for idx = len(name) - 1; idx >= 0; idx-- {
|
||||||
|
if name[idx] < '0' || name[idx] > '9' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx == len(name)-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
const dotFunc = ".func"
|
||||||
|
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
name = name[:idx+1-len(dotFunc)]
|
||||||
|
// 3. taco.beansIPv6.func12
|
||||||
|
// 4. taco.beansIPv6
|
||||||
|
}
|
||||||
|
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
// 5. beansIPv6
|
||||||
|
}
|
||||||
|
if name == "" {
|
||||||
|
return fmt.Sprintf("%p", fn)
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv4") {
|
||||||
|
return "v4"
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(name, "IPv6") {
|
||||||
|
return "v6"
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
24
conn/conn_test.go
Normal file
24
conn/conn_test.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrettyName(t *testing.T) {
|
||||||
|
var (
|
||||||
|
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
|
||||||
|
)
|
||||||
|
|
||||||
|
const want = "TestPrettyName"
|
||||||
|
|
||||||
|
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
|
||||||
|
if got := recvFunc.PrettyName(); got != want {
|
||||||
|
t.Errorf("PrettyName() = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
43
conn/controlfns.go
Normal file
43
conn/controlfns.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
||||||
|
// the max supported by a default configuration of macOS. Some platforms will
|
||||||
|
// silently clamp the value to other maximums, such as linux clamping to
|
||||||
|
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
||||||
|
// around this limitation)
|
||||||
|
const socketBufferSize = 7 << 20
|
||||||
|
|
||||||
|
// controlFn is the callback function signature from net.ListenConfig.Control.
|
||||||
|
// It is used to apply platform specific configuration to the socket prior to
|
||||||
|
// bind.
|
||||||
|
type controlFn func(network, address string, c syscall.RawConn) error
|
||||||
|
|
||||||
|
// controlFns is a list of functions that are called from the listen config
|
||||||
|
// that can apply socket options.
|
||||||
|
var controlFns = []controlFn{}
|
||||||
|
|
||||||
|
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
||||||
|
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
||||||
|
// information OOB configuration for sticky sockets.
|
||||||
|
func listenConfig() *net.ListenConfig {
|
||||||
|
return &net.ListenConfig{
|
||||||
|
Control: func(network, address string, c syscall.RawConn) error {
|
||||||
|
for _, fn := range controlFns {
|
||||||
|
if err := fn(network, address, c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
69
conn/controlfns_linux.go
Normal file
69
conn/controlfns_linux.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
|
||||||
|
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
||||||
|
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
||||||
|
// fail silently - the result of failure is lower performance on very fast
|
||||||
|
// links or high latency links.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
// Set up to *mem_max
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||||
|
// Set beyond *mem_max if CAP_NET_ADMIN
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
||||||
|
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
switch network {
|
||||||
|
case "udp4":
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "udp6":
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
if runtime.GOOS != "android" {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
|
||||||
|
// Attempt to enable UDP_GRO
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
35
conn/controlfns_unix.go
Normal file
35
conn/controlfns_unix.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
//go:build !windows && !linux && !wasm
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
var err error
|
||||||
|
if network == "udp6" {
|
||||||
|
c.Control(func(fd uintptr) {
|
||||||
|
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
23
conn/controlfns_windows.go
Normal file
23
conn/controlfns_windows.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
controlFns = append(controlFns,
|
||||||
|
func(network, address string, c syscall.RawConn) error {
|
||||||
|
return c.Control(func(fd uintptr) {
|
||||||
|
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
|
||||||
|
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
10
conn/default.go
Normal file
10
conn/default.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func NewDefaultBind() Bind { return NewStdNetBind() }
|
||||||
12
conn/errors_default.go
Normal file
12
conn/errors_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
26
conn/errors_linux.go
Normal file
26
conn/errors_linux.go
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func errShouldDisableUDPGSO(err error) bool {
|
||||||
|
var serr *os.SyscallError
|
||||||
|
if errors.As(err, &serr) {
|
||||||
|
// EIO is returned by udp_send_skb() if the device driver does not have
|
||||||
|
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
||||||
|
// See:
|
||||||
|
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
||||||
|
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
||||||
|
return serr.Err == unix.EIO
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
15
conn/features_default.go
Normal file
15
conn/features_default.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
return
|
||||||
|
}
|
||||||
29
conn/features_linux.go
Normal file
29
conn/features_linux.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
||||||
|
rc, err := conn.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = rc.Control(func(fd uintptr) {
|
||||||
|
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
||||||
|
txOffload = errSyscall == nil
|
||||||
|
opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO)
|
||||||
|
rxOffload = errSyscall == nil && opt == 1
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return txOffload, rxOffload
|
||||||
|
}
|
||||||
21
conn/gso_default.go
Normal file
21
conn/gso_default.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
||||||
|
// offloading control data.
|
||||||
|
const gsoControlSize = 0
|
||||||
65
conn/gso_linux.go
Normal file
65
conn/gso_linux.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//go:build linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sizeOfGSOData = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
||||||
|
func getGSOSize(control []byte) (int, error) {
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
||||||
|
}
|
||||||
|
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
||||||
|
var gso uint16
|
||||||
|
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
||||||
|
return int(gso), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
||||||
|
// data in control untouched.
|
||||||
|
func setGSOSize(control *[]byte, gsoSize uint16) {
|
||||||
|
existingLen := len(*control)
|
||||||
|
avail := cap(*control) - existingLen
|
||||||
|
space := unix.CmsgSpace(sizeOfGSOData)
|
||||||
|
if avail < space {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:cap(*control)]
|
||||||
|
gsoControl := (*control)[existingLen:]
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
||||||
|
hdr.Level = unix.SOL_UDP
|
||||||
|
hdr.Type = unix.UDP_SEGMENT
|
||||||
|
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
||||||
|
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
||||||
|
*control = (*control)[:existingLen+space]
|
||||||
|
}
|
||||||
|
|
||||||
|
// gsoControlSize returns the recommended buffer size for pooling UDP
|
||||||
|
// offloading control data.
|
||||||
|
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
||||||
12
conn/mark_default.go
Normal file
12
conn/mark_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux && !openbsd && !freebsd
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
// +build android openbsd freebsd
|
//go:build linux || openbsd || freebsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -26,13 +26,13 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(mark uint32) error {
|
func (s *StdNetBind) SetMark(mark uint32) error {
|
||||||
var operr error
|
var operr error
|
||||||
if fwmarkIoctl == 0 {
|
if fwmarkIoctl == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if bind.ipv4 != nil {
|
if s.ipv4 != nil {
|
||||||
fd, err := bind.ipv4.SyscallConn()
|
fd, err := s.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -46,8 +46,8 @@ func (bind *nativeBind) SetMark(mark uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if bind.ipv6 != nil {
|
if s.ipv6 != nil {
|
||||||
fd, err := bind.ipv6.SyscallConn()
|
fd, err := s.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
42
conn/sticky_default.go
Normal file
42
conn/sticky_default.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import "net/netip"
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
||||||
|
// {get,set}srcControl feature set, but use alternatively named flags and need
|
||||||
|
// ports and require testing.
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
const stickyControlSize = 0
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = false
|
||||||
112
conn/sticky_linux.go
Normal file
112
conn/sticky_linux.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return netip.AddrFrom4(info.Spec_dst)
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
// TODO: set zone. in order to do so we need to check if the address is
|
||||||
|
// link local, and if it is perform a syscall to turn the ifindex into a
|
||||||
|
// zone string because netip uses string zones.
|
||||||
|
return netip.AddrFrom16(info.Addr)
|
||||||
|
}
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
||||||
|
switch len(e.src) {
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return info.Ifindex
|
||||||
|
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
||||||
|
return int32(info.Ifindex)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *StdNetEndpoint) SrcToString() string {
|
||||||
|
return e.SrcIP().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
||||||
|
// the source information found.
|
||||||
|
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
||||||
|
ep.ClearSrc()
|
||||||
|
|
||||||
|
var (
|
||||||
|
hdr unix.Cmsghdr
|
||||||
|
data []byte
|
||||||
|
rem []byte = control
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(rem) > unix.SizeofCmsghdr {
|
||||||
|
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IP &&
|
||||||
|
hdr.Type == unix.IP_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
}
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
|
hdr.Type == unix.IPV6_PKTINFO {
|
||||||
|
|
||||||
|
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
||||||
|
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
||||||
|
|
||||||
|
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
||||||
|
copy(ep.src, hdrBuf)
|
||||||
|
copy(ep.src[unix.CmsgLen(0):], data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
||||||
|
// and source ifindex found in ep. control's len will be set to 0 in the event
|
||||||
|
// that ep is a default value.
|
||||||
|
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
||||||
|
if cap(*control) < len(ep.src) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*control = (*control)[:0]
|
||||||
|
*control = append(*control, ep.src...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stickyControlSize returns the recommended buffer size for pooling sticky
|
||||||
|
// offloading control data.
|
||||||
|
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
||||||
|
|
||||||
|
const StdNetSupportsStickySockets = true
|
||||||
266
conn/sticky_linux_test.go
Normal file
266
conn/sticky_linux_test.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
|
||||||
|
var buf []byte
|
||||||
|
if addr.Is4() {
|
||||||
|
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
hdr := unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IP,
|
||||||
|
Type: unix.IP_PKTINFO,
|
||||||
|
}
|
||||||
|
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
||||||
|
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||||
|
|
||||||
|
info := unix.Inet4Pktinfo{
|
||||||
|
Ifindex: ifidx,
|
||||||
|
Spec_dst: addr.As4(),
|
||||||
|
}
|
||||||
|
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
|
||||||
|
} else {
|
||||||
|
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
||||||
|
hdr := unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IPV6,
|
||||||
|
Type: unix.IPV6_PKTINFO,
|
||||||
|
}
|
||||||
|
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
||||||
|
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
||||||
|
|
||||||
|
info := unix.Inet6Pktinfo{
|
||||||
|
Ifindex: uint32(ifidx),
|
||||||
|
Addr: addr.As16(),
|
||||||
|
}
|
||||||
|
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.src = buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_setSrcControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
||||||
|
}
|
||||||
|
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
||||||
|
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IP {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IP_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
||||||
|
t.Errorf("unexpected address: %v", info.Spec_dst)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
ep := &StdNetEndpoint{
|
||||||
|
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
||||||
|
}
|
||||||
|
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||||
|
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
|
||||||
|
setSrcControl(&control, ep)
|
||||||
|
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
if hdr.Level != unix.IPPROTO_IPV6 {
|
||||||
|
t.Errorf("unexpected level: %d", hdr.Level)
|
||||||
|
}
|
||||||
|
if hdr.Type != unix.IPV6_PKTINFO {
|
||||||
|
t.Errorf("unexpected type: %d", hdr.Type)
|
||||||
|
}
|
||||||
|
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
||||||
|
t.Errorf("unexpected length: %d", hdr.Len)
|
||||||
|
}
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
if info.Addr != ep.SrcIP().As16() {
|
||||||
|
t.Errorf("unexpected address: %v", info.Addr)
|
||||||
|
}
|
||||||
|
if info.Ifindex != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = 1
|
||||||
|
hdr.Type = 2
|
||||||
|
hdr.Len = 3
|
||||||
|
|
||||||
|
setSrcControl(&control, &StdNetEndpoint{})
|
||||||
|
|
||||||
|
if len(control) != 0 {
|
||||||
|
t.Errorf("unexpected control: %v", control)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getSrcFromControl(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
control := make([]byte, stickyControlSize)
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IPV6
|
||||||
|
hdr.Type = unix.IPV6_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
||||||
|
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("ClearOnEmpty", func(t *testing.T) {
|
||||||
|
var control []byte
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
||||||
|
|
||||||
|
getSrcFromControl(control, ep)
|
||||||
|
if ep.SrcIP().IsValid() {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 0 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Multiple", func(t *testing.T) {
|
||||||
|
zeroControl := make([]byte, unix.CmsgSpace(0))
|
||||||
|
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
|
||||||
|
zeroHdr.SetLen(unix.CmsgLen(0))
|
||||||
|
|
||||||
|
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
||||||
|
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
||||||
|
hdr.Level = unix.IPPROTO_IP
|
||||||
|
hdr.Type = unix.IP_PKTINFO
|
||||||
|
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
||||||
|
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
||||||
|
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
||||||
|
info.Ifindex = 5
|
||||||
|
|
||||||
|
combined := make([]byte, 0)
|
||||||
|
combined = append(combined, zeroControl...)
|
||||||
|
combined = append(combined, control...)
|
||||||
|
|
||||||
|
ep := &StdNetEndpoint{}
|
||||||
|
getSrcFromControl(combined, ep)
|
||||||
|
|
||||||
|
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
||||||
|
t.Errorf("unexpected address: %v", ep.SrcIP())
|
||||||
|
}
|
||||||
|
if ep.SrcIfidx() != 5 {
|
||||||
|
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_listenConfig(t *testing.T) {
|
||||||
|
t.Run("IPv4", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IP_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("IPv6", func(t *testing.T) {
|
||||||
|
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sc, err := conn.(*net.UDPConn).SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
var i int
|
||||||
|
sc.Control(func(fd uintptr) {
|
||||||
|
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if i != 1 {
|
||||||
|
t.Error("IPV6_PKTINFO not set!")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
254
conn/winrio/rio_windows.go
Normal file
254
conn/winrio/rio_windows.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package winrio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgDontNotify = 1
|
||||||
|
MsgDefer = 2
|
||||||
|
MsgWaitAll = 4
|
||||||
|
MsgCommitOnly = 8
|
||||||
|
|
||||||
|
MaxCqSize = 0x8000000
|
||||||
|
|
||||||
|
invalidBufferId = 0xFFFFFFFF
|
||||||
|
invalidCq = 0
|
||||||
|
invalidRq = 0
|
||||||
|
corruptCq = 0xFFFFFFFF
|
||||||
|
)
|
||||||
|
|
||||||
|
var extensionFunctionTable struct {
|
||||||
|
cbSize uint32
|
||||||
|
rioReceive uintptr
|
||||||
|
rioReceiveEx uintptr
|
||||||
|
rioSend uintptr
|
||||||
|
rioSendEx uintptr
|
||||||
|
rioCloseCompletionQueue uintptr
|
||||||
|
rioCreateCompletionQueue uintptr
|
||||||
|
rioCreateRequestQueue uintptr
|
||||||
|
rioDequeueCompletion uintptr
|
||||||
|
rioDeregisterBuffer uintptr
|
||||||
|
rioNotify uintptr
|
||||||
|
rioRegisterBuffer uintptr
|
||||||
|
rioResizeCompletionQueue uintptr
|
||||||
|
rioResizeRequestQueue uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cq uintptr
|
||||||
|
|
||||||
|
type Rq uintptr
|
||||||
|
|
||||||
|
type BufferId uintptr
|
||||||
|
|
||||||
|
type Buffer struct {
|
||||||
|
Id BufferId
|
||||||
|
Offset uint32
|
||||||
|
Length uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Status int32
|
||||||
|
BytesTransferred uint32
|
||||||
|
SocketContext uint64
|
||||||
|
RequestContext uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type notificationCompletionType uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
eventCompletion notificationCompletionType = 1
|
||||||
|
iocpCompletion notificationCompletionType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type eventNotificationCompletion struct {
|
||||||
|
completionType notificationCompletionType
|
||||||
|
event windows.Handle
|
||||||
|
notifyReset uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type iocpNotificationCompletion struct {
|
||||||
|
completionType notificationCompletionType
|
||||||
|
iocp windows.Handle
|
||||||
|
key uintptr
|
||||||
|
overlapped *windows.Overlapped
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
initialized sync.Once
|
||||||
|
available bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func Initialize() bool {
|
||||||
|
initialized.Do(func() {
|
||||||
|
var (
|
||||||
|
err error
|
||||||
|
socket windows.Handle
|
||||||
|
cq Cq
|
||||||
|
)
|
||||||
|
defer func() {
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Registered I/O is unavailable: %v", err)
|
||||||
|
}()
|
||||||
|
socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer windows.CloseHandle(socket)
|
||||||
|
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
|
||||||
|
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
|
||||||
|
ob := uint32(0)
|
||||||
|
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
|
||||||
|
(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
|
||||||
|
(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
|
||||||
|
&ob, nil, 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
|
||||||
|
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
|
||||||
|
var iocp windows.Handle
|
||||||
|
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer windows.CloseHandle(iocp)
|
||||||
|
var overlapped windows.Overlapped
|
||||||
|
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer CloseCompletionQueue(cq)
|
||||||
|
_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
available = true
|
||||||
|
})
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
|
||||||
|
func Socket(af, typ, proto int32) (windows.Handle, error) {
|
||||||
|
return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CloseCompletionQueue(cq Cq) {
|
||||||
|
_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
|
||||||
|
notificationCompletion := &eventNotificationCompletion{
|
||||||
|
completionType: eventCompletion,
|
||||||
|
event: event,
|
||||||
|
}
|
||||||
|
if notifyReset {
|
||||||
|
notificationCompletion.notifyReset = 1
|
||||||
|
}
|
||||||
|
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
|
||||||
|
if ret == invalidCq {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return Cq(ret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
|
||||||
|
notificationCompletion := &iocpNotificationCompletion{
|
||||||
|
completionType: iocpCompletion,
|
||||||
|
iocp: iocp,
|
||||||
|
key: key,
|
||||||
|
overlapped: overlapped,
|
||||||
|
}
|
||||||
|
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
|
||||||
|
if ret == invalidCq {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return Cq(ret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
|
||||||
|
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
|
||||||
|
if ret == invalidCq {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return Cq(ret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
|
||||||
|
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
|
||||||
|
if ret == invalidRq {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return Rq(ret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DequeueCompletion(cq Cq, results []Result) uint32 {
|
||||||
|
var array uintptr
|
||||||
|
if len(results) > 0 {
|
||||||
|
array = uintptr(unsafe.Pointer(&results[0]))
|
||||||
|
}
|
||||||
|
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
|
||||||
|
if ret == corruptCq {
|
||||||
|
panic("cq is corrupt")
|
||||||
|
}
|
||||||
|
return uint32(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeregisterBuffer(id BufferId) {
|
||||||
|
_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterBuffer(buffer []byte) (BufferId, error) {
|
||||||
|
var buf unsafe.Pointer
|
||||||
|
if len(buffer) > 0 {
|
||||||
|
buf = unsafe.Pointer(&buffer[0])
|
||||||
|
}
|
||||||
|
return RegisterPointer(buf, uint32(len(buffer)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
|
||||||
|
ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
|
||||||
|
if ret == invalidBufferId {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return BufferId(ret), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
|
||||||
|
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
|
||||||
|
if ret == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
|
||||||
|
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
|
||||||
|
if ret == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Notify(cq Cq) error {
|
||||||
|
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
|
||||||
|
if ret != 0 {
|
||||||
|
return windows.Errno(ret)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,173 +1,201 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"math/bits"
|
"math/bits"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type parentIndirection struct {
|
||||||
|
parentBit **trieEntry
|
||||||
|
parentBitType uint8
|
||||||
|
}
|
||||||
|
|
||||||
type trieEntry struct {
|
type trieEntry struct {
|
||||||
cidr uint
|
peer *Peer
|
||||||
child [2]*trieEntry
|
child [2]*trieEntry
|
||||||
bits net.IP
|
parent parentIndirection
|
||||||
peer *Peer
|
cidr uint8
|
||||||
|
bitAtByte uint8
|
||||||
// index of "branching" bit
|
bitAtShift uint8
|
||||||
|
bits []byte
|
||||||
bit_at_byte uint
|
perPeerElem *list.Element
|
||||||
bit_at_shift uint
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLittleEndian() bool {
|
func commonBits(ip1, ip2 []byte) uint8 {
|
||||||
one := uint32(1)
|
|
||||||
return *(*byte)(unsafe.Pointer(&one)) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func swapU32(i uint32) uint32 {
|
|
||||||
if !isLittleEndian() {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
return bits.ReverseBytes32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func swapU64(i uint64) uint64 {
|
|
||||||
if !isLittleEndian() {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
return bits.ReverseBytes64(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
|
||||||
size := len(ip1)
|
size := len(ip1)
|
||||||
if size == net.IPv4len {
|
if size == net.IPv4len {
|
||||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
a := binary.BigEndian.Uint32(ip1)
|
||||||
b := (*uint32)(unsafe.Pointer(&ip2[0]))
|
b := binary.BigEndian.Uint32(ip2)
|
||||||
x := *a ^ *b
|
x := a ^ b
|
||||||
return uint(bits.LeadingZeros32(swapU32(x)))
|
return uint8(bits.LeadingZeros32(x))
|
||||||
} else if size == net.IPv6len {
|
} else if size == net.IPv6len {
|
||||||
a := (*uint64)(unsafe.Pointer(&ip1[0]))
|
a := binary.BigEndian.Uint64(ip1)
|
||||||
b := (*uint64)(unsafe.Pointer(&ip2[0]))
|
b := binary.BigEndian.Uint64(ip2)
|
||||||
x := *a ^ *b
|
x := a ^ b
|
||||||
if x != 0 {
|
if x != 0 {
|
||||||
return uint(bits.LeadingZeros64(swapU64(x)))
|
return uint8(bits.LeadingZeros64(x))
|
||||||
}
|
}
|
||||||
a = (*uint64)(unsafe.Pointer(&ip1[8]))
|
a = binary.BigEndian.Uint64(ip1[8:])
|
||||||
b = (*uint64)(unsafe.Pointer(&ip2[8]))
|
b = binary.BigEndian.Uint64(ip2[8:])
|
||||||
x = *a ^ *b
|
x = a ^ b
|
||||||
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
|
return 64 + uint8(bits.LeadingZeros64(x))
|
||||||
} else {
|
} else {
|
||||||
panic("Wrong size bit string")
|
panic("Wrong size bit string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
func (node *trieEntry) addToPeerEntries() {
|
||||||
if node == nil {
|
node.perPeerElem = node.peer.trieEntries.PushBack(node)
|
||||||
return node
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) removeFromPeerEntries() {
|
||||||
|
if node.perPeerElem != nil {
|
||||||
|
node.peer.trieEntries.Remove(node.perPeerElem)
|
||||||
|
node.perPeerElem = nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// walk recursively
|
func (node *trieEntry) choose(ip []byte) byte {
|
||||||
|
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||||
|
}
|
||||||
|
|
||||||
node.child[0] = node.child[0].removeByPeer(p)
|
func (node *trieEntry) maskSelf() {
|
||||||
node.child[1] = node.child[1].removeByPeer(p)
|
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||||
|
for i := 0; i < len(mask); i++ {
|
||||||
if node.peer != p {
|
node.bits[i] &= mask[i]
|
||||||
return node
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// remove peer & merge
|
func (node *trieEntry) zeroizePointers() {
|
||||||
|
// Make the garbage collector's life slightly easier
|
||||||
node.peer = nil
|
node.peer = nil
|
||||||
if node.child[0] == nil {
|
node.child[0] = nil
|
||||||
return node.child[1]
|
node.child[1] = nil
|
||||||
}
|
node.parent.parentBit = nil
|
||||||
return node.child[0]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) choose(ip net.IP) byte {
|
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
||||||
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||||
}
|
parent = node
|
||||||
|
if parent.cidr == cidr {
|
||||||
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
exact = true
|
||||||
|
return
|
||||||
// at leaf
|
|
||||||
|
|
||||||
if node == nil {
|
|
||||||
return &trieEntry{
|
|
||||||
bits: ip,
|
|
||||||
peer: peer,
|
|
||||||
cidr: cidr,
|
|
||||||
bit_at_byte: cidr / 8,
|
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// traverse deeper
|
|
||||||
|
|
||||||
common := commonBits(node.bits, ip)
|
|
||||||
if node.cidr <= cidr && common >= node.cidr {
|
|
||||||
if node.cidr == cidr {
|
|
||||||
node.peer = peer
|
|
||||||
return node
|
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
node = node.child[bit]
|
||||||
return node
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
// split node
|
|
||||||
|
|
||||||
newNode := &trieEntry{
|
|
||||||
bits: ip,
|
|
||||||
peer: peer,
|
|
||||||
cidr: cidr,
|
|
||||||
bit_at_byte: cidr / 8,
|
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
|
|
||||||
cidr = min(cidr, common)
|
|
||||||
|
|
||||||
// check for shorter prefix
|
|
||||||
|
|
||||||
if newNode.cidr == cidr {
|
|
||||||
bit := newNode.choose(node.bits)
|
|
||||||
newNode.child[bit] = node
|
|
||||||
return newNode
|
|
||||||
}
|
|
||||||
|
|
||||||
// create new parent for node & newNode
|
|
||||||
|
|
||||||
parent := &trieEntry{
|
|
||||||
bits: ip,
|
|
||||||
peer: nil,
|
|
||||||
cidr: cidr,
|
|
||||||
bit_at_byte: cidr / 8,
|
|
||||||
bit_at_shift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
|
|
||||||
bit := parent.choose(ip)
|
|
||||||
parent.child[bit] = newNode
|
|
||||||
parent.child[bit^1] = node
|
|
||||||
|
|
||||||
return parent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
||||||
|
if *trie.parentBit == nil {
|
||||||
|
node := &trieEntry{
|
||||||
|
peer: peer,
|
||||||
|
parent: trie,
|
||||||
|
bits: ip,
|
||||||
|
cidr: cidr,
|
||||||
|
bitAtByte: cidr / 8,
|
||||||
|
bitAtShift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
node.maskSelf()
|
||||||
|
node.addToPeerEntries()
|
||||||
|
*trie.parentBit = node
|
||||||
|
return
|
||||||
|
}
|
||||||
|
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
|
||||||
|
if exact {
|
||||||
|
node.removeFromPeerEntries()
|
||||||
|
node.peer = peer
|
||||||
|
node.addToPeerEntries()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newNode := &trieEntry{
|
||||||
|
peer: peer,
|
||||||
|
bits: ip,
|
||||||
|
cidr: cidr,
|
||||||
|
bitAtByte: cidr / 8,
|
||||||
|
bitAtShift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
newNode.maskSelf()
|
||||||
|
newNode.addToPeerEntries()
|
||||||
|
|
||||||
|
var down *trieEntry
|
||||||
|
if node == nil {
|
||||||
|
down = *trie.parentBit
|
||||||
|
} else {
|
||||||
|
bit := node.choose(ip)
|
||||||
|
down = node.child[bit]
|
||||||
|
if down == nil {
|
||||||
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
common := commonBits(down.bits, ip)
|
||||||
|
if common < cidr {
|
||||||
|
cidr = common
|
||||||
|
}
|
||||||
|
parent := node
|
||||||
|
|
||||||
|
if newNode.cidr == cidr {
|
||||||
|
bit := newNode.choose(down.bits)
|
||||||
|
down.parent = parentIndirection{&newNode.child[bit], bit}
|
||||||
|
newNode.child[bit] = down
|
||||||
|
if parent == nil {
|
||||||
|
newNode.parent = trie
|
||||||
|
*trie.parentBit = newNode
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(newNode.bits)
|
||||||
|
newNode.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = newNode
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
node = &trieEntry{
|
||||||
|
bits: append([]byte{}, newNode.bits...),
|
||||||
|
cidr: cidr,
|
||||||
|
bitAtByte: cidr / 8,
|
||||||
|
bitAtShift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
node.maskSelf()
|
||||||
|
|
||||||
|
bit := node.choose(down.bits)
|
||||||
|
down.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = down
|
||||||
|
bit = node.choose(newNode.bits)
|
||||||
|
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||||
|
node.child[bit] = newNode
|
||||||
|
if parent == nil {
|
||||||
|
node.parent = trie
|
||||||
|
*trie.parentBit = node
|
||||||
|
} else {
|
||||||
|
bit := parent.choose(node.bits)
|
||||||
|
node.parent = parentIndirection{&parent.child[bit], bit}
|
||||||
|
parent.child[bit] = node
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) lookup(ip []byte) *Peer {
|
||||||
var found *Peer
|
var found *Peer
|
||||||
size := uint(len(ip))
|
size := uint8(len(ip))
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
if node.peer != nil {
|
if node.peer != nil {
|
||||||
found = node.peer
|
found = node.peer
|
||||||
}
|
}
|
||||||
if node.bit_at_byte == size {
|
if node.bitAtByte == size {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
bit := node.choose(ip)
|
bit := node.choose(ip)
|
||||||
@@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
|
|||||||
return found
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
|
|
||||||
if node == nil {
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
if node.peer == p {
|
|
||||||
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
|
||||||
results = append(results, net.IPNet{
|
|
||||||
Mask: mask,
|
|
||||||
IP: node.bits.Mask(mask),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
results = node.child[0].entriesForPeer(p, results)
|
|
||||||
results = node.child[1].entriesForPeer(p, results)
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
type AllowedIPs struct {
|
type AllowedIPs struct {
|
||||||
IPv4 *trieEntry
|
IPv4 *trieEntry
|
||||||
IPv6 *trieEntry
|
IPv6 *trieEntry
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
allowed := make([]net.IPNet, 0, 10)
|
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
||||||
allowed = table.IPv4.entriesForPeer(peer, allowed)
|
node := elem.Value.(*trieEntry)
|
||||||
allowed = table.IPv6.entriesForPeer(peer, allowed)
|
a, _ := netip.AddrFromSlice(node.bits)
|
||||||
return allowed
|
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
|
||||||
}
|
return
|
||||||
|
}
|
||||||
func (table *AllowedIPs) Reset() {
|
}
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
table.IPv4 = nil
|
|
||||||
table.IPv6 = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
table.IPv4 = table.IPv4.removeByPeer(peer)
|
var next *list.Element
|
||||||
table.IPv6 = table.IPv6.removeByPeer(peer)
|
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||||
|
next = elem.Next()
|
||||||
|
node := elem.Value.(*trieEntry)
|
||||||
|
|
||||||
|
node.removeFromPeerEntries()
|
||||||
|
node.peer = nil
|
||||||
|
if node.child[0] != nil && node.child[1] != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bit := 0
|
||||||
|
if node.child[0] == nil {
|
||||||
|
bit = 1
|
||||||
|
}
|
||||||
|
child := node.child[bit]
|
||||||
|
if child != nil {
|
||||||
|
child.parent = node.parent
|
||||||
|
}
|
||||||
|
*node.parent.parentBit = child
|
||||||
|
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||||
|
node.zeroizePointers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||||
|
if parent.peer != nil {
|
||||||
|
node.zeroizePointers()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
child = parent.child[node.parent.parentBitType^1]
|
||||||
|
if child != nil {
|
||||||
|
child.parent = parent.parent
|
||||||
|
}
|
||||||
|
*parent.parent.parentBit = child
|
||||||
|
node.zeroizePointers()
|
||||||
|
parent.zeroizePointers()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||||
table.mutex.Lock()
|
table.mutex.Lock()
|
||||||
defer table.mutex.Unlock()
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
switch len(ip) {
|
if prefix.Addr().Is6() {
|
||||||
case net.IPv6len:
|
ip := prefix.Addr().As16()
|
||||||
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
case net.IPv4len:
|
} else if prefix.Addr().Is4() {
|
||||||
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
ip := prefix.Addr().As4()
|
||||||
default:
|
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||||
|
} else {
|
||||||
panic(errors.New("inserting unknown address type"))
|
panic(errors.New("inserting unknown address type"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
|
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
||||||
table.mutex.RLock()
|
table.mutex.RLock()
|
||||||
defer table.mutex.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
return table.IPv4.lookup(address)
|
switch len(ip) {
|
||||||
}
|
case net.IPv6len:
|
||||||
|
return table.IPv6.lookup(ip)
|
||||||
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
|
case net.IPv4len:
|
||||||
table.mutex.RLock()
|
return table.IPv4.lookup(ip)
|
||||||
defer table.mutex.RUnlock()
|
default:
|
||||||
return table.IPv6.lookup(address)
|
panic(errors.New("looking up unknown address type"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,28 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NumberOfPeers = 100
|
NumberOfPeers = 100
|
||||||
NumberOfAddresses = 250
|
NumberOfPeerRemovals = 4
|
||||||
NumberOfTests = 10000
|
NumberOfAddresses = 250
|
||||||
|
NumberOfTests = 10000
|
||||||
)
|
)
|
||||||
|
|
||||||
type SlowNode struct {
|
type SlowNode struct {
|
||||||
peer *Peer
|
peer *Peer
|
||||||
cidr uint
|
cidr uint8
|
||||||
bits []byte
|
bits []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
|
|||||||
r[i], r[j] = r[j], r[i]
|
r[i], r[j] = r[j], r[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
|
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
|
||||||
for _, t := range r {
|
for _, t := range r {
|
||||||
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
||||||
t.peer = peer
|
t.peer = peer
|
||||||
@@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv4(t *testing.T) {
|
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
|
||||||
var trie *trieEntry
|
n := 0
|
||||||
var slow SlowRouter
|
for _, x := range r {
|
||||||
|
if x.peer != peer {
|
||||||
|
r[n] = x
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrieRandom(t *testing.T) {
|
||||||
|
var slow4, slow6 SlowRouter
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
const AddressLength = 4
|
for n := 0; n < NumberOfPeers; n++ {
|
||||||
|
|
||||||
for n := 0; n < NumberOfPeers; n += 1 {
|
|
||||||
peers = append(peers, &Peer{})
|
peers = append(peers, &Peer{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < NumberOfAddresses; n += 1 {
|
for n := 0; n < NumberOfAddresses; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr4 [4]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr4[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Intn(32) + 1)
|
||||||
index := rand.Int() % NumberOfPeers
|
index := rand.Intn(NumberOfPeers)
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||||
|
|
||||||
|
var addr6 [16]byte
|
||||||
|
rand.Read(addr6[:])
|
||||||
|
cidr = uint8(rand.Intn(128) + 1)
|
||||||
|
index = rand.Intn(NumberOfPeers)
|
||||||
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||||
|
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < NumberOfTests; n += 1 {
|
var p int
|
||||||
var addr [AddressLength]byte
|
for p = 0; ; p++ {
|
||||||
rand.Read(addr[:])
|
for n := 0; n < NumberOfTests; n++ {
|
||||||
peer1 := slow.Lookup(addr[:])
|
var addr4 [4]byte
|
||||||
peer2 := trie.lookup(addr[:])
|
rand.Read(addr4[:])
|
||||||
if peer1 != peer2 {
|
peer1 := slow4.Lookup(addr4[:])
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
peer2 := allowedIPs.Lookup(addr4[:])
|
||||||
}
|
if peer1 != peer2 {
|
||||||
}
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTrieRandomIPv6(t *testing.T) {
|
var addr6 [16]byte
|
||||||
var trie *trieEntry
|
rand.Read(addr6[:])
|
||||||
var slow SlowRouter
|
peer1 = slow6.Lookup(addr6[:])
|
||||||
var peers []*Peer
|
peer2 = allowedIPs.Lookup(addr6[:])
|
||||||
|
if peer1 != peer2 {
|
||||||
rand.Seed(1)
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
|
||||||
|
}
|
||||||
const AddressLength = 16
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfPeers; n += 1 {
|
|
||||||
peers = append(peers, &Peer{})
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfAddresses; n += 1 {
|
|
||||||
var addr [AddressLength]byte
|
|
||||||
rand.Read(addr[:])
|
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
|
||||||
index := rand.Int() % NumberOfPeers
|
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
|
||||||
slow = slow.Insert(addr[:], cidr, peers[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfTests; n += 1 {
|
|
||||||
var addr [AddressLength]byte
|
|
||||||
rand.Read(addr[:])
|
|
||||||
peer1 := slow.Lookup(addr[:])
|
|
||||||
peer2 := trie.lookup(addr[:])
|
|
||||||
if peer1 != peer2 {
|
|
||||||
t.Error("Trie did not match naive implementation, for:", addr)
|
|
||||||
}
|
}
|
||||||
|
if p >= len(peers) || p >= NumberOfPeerRemovals {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
allowedIPs.RemoveByPeer(peers[p])
|
||||||
|
slow4 = slow4.RemoveByPeer(peers[p])
|
||||||
|
slow6 = slow6.RemoveByPeer(peers[p])
|
||||||
|
}
|
||||||
|
for ; p < len(peers); p++ {
|
||||||
|
allowedIPs.RemoveByPeer(peers[p])
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
||||||
|
t.Error("Failed to remove all nodes from trie by peer")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -8,40 +8,17 @@ package device
|
|||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Todo: More comprehensive
|
|
||||||
*/
|
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
s1 []byte
|
s1 []byte
|
||||||
s2 []byte
|
s2 []byte
|
||||||
match uint
|
match uint8
|
||||||
}
|
|
||||||
|
|
||||||
type testPairTrieInsert struct {
|
|
||||||
key []byte
|
|
||||||
cidr uint
|
|
||||||
peer *Peer
|
|
||||||
}
|
|
||||||
|
|
||||||
type testPairTrieLookup struct {
|
|
||||||
key []byte
|
|
||||||
peer *Peer
|
|
||||||
}
|
|
||||||
|
|
||||||
func printTrie(t *testing.T, p *trieEntry) {
|
|
||||||
if p == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Log(p)
|
|
||||||
printTrie(t, p.child[0])
|
|
||||||
printTrie(t, p.child[1])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonBits(t *testing.T) {
|
func TestCommonBits(t *testing.T) {
|
||||||
|
|
||||||
tests := []testPairCommonBits{
|
tests := []testPairCommonBits{
|
||||||
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
||||||
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
||||||
@@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
||||||
var trie *trieEntry
|
var trie *trieEntry
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
|
root := parentIndirection{&trie, 2}
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
const AddressLength = 4
|
const AddressLength = 4
|
||||||
|
|
||||||
for n := 0; n < peerNumber; n += 1 {
|
for n := 0; n < peerNumber; n++ {
|
||||||
peers = append(peers, &Peer{})
|
peers = append(peers, &Peer{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < addressNumber; n += 1 {
|
for n := 0; n < addressNumber; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
trie = trie.insert(addr[:], cidr, peers[index])
|
root.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n += 1 {
|
for n := 0; n < b.N; n++ {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
trie.lookup(addr[:])
|
trie.lookup(addr[:])
|
||||||
@@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||||
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := trie.lookup([]byte{a, b, c, d})
|
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
||||||
if p == peer {
|
if p == peer {
|
||||||
t.Error("Assert NEQ failed")
|
t.Error("Assert NEQ failed")
|
||||||
}
|
}
|
||||||
@@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
assertEQ(a, 192, 0, 0, 0)
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
assertEQ(a, 255, 0, 0, 0)
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 1, 0, 0, 0)
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
assertNEQ(a, 64, 0, 0, 0)
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
@@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
|
|||||||
assertNEQ(a, 192, 0, 0, 0)
|
assertNEQ(a, 192, 0, 0, 0)
|
||||||
assertNEQ(a, 255, 0, 0, 0)
|
assertNEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
trie = nil
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
allowedIPs.RemoveByPeer(b)
|
||||||
|
allowedIPs.RemoveByPeer(c)
|
||||||
|
allowedIPs.RemoveByPeer(d)
|
||||||
|
allowedIPs.RemoveByPeer(e)
|
||||||
|
allowedIPs.RemoveByPeer(g)
|
||||||
|
allowedIPs.RemoveByPeer(h)
|
||||||
|
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
||||||
|
t.Error("Expected removing all the peers to empty trie, but it did not")
|
||||||
|
}
|
||||||
|
|
||||||
insert(a, 192, 168, 0, 0, 16)
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
insert(a, 192, 168, 0, 0, 24)
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
trie = trie.removeByPeer(a)
|
allowedIPs.RemoveByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
}
|
}
|
||||||
@@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var trie *trieEntry
|
var allowedIPs AllowedIPs
|
||||||
|
|
||||||
expand := func(a uint32) []byte {
|
expand := func(a uint32) []byte {
|
||||||
var out [4]byte
|
var out [4]byte
|
||||||
@@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
return out[:]
|
return out[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
|
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
||||||
var addr []byte
|
var addr []byte
|
||||||
addr = append(addr, expand(a)...)
|
addr = append(addr, expand(a)...)
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
trie = trie.insert(addr, cidr, peer)
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
@@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
|
|||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
p := trie.lookup(addr)
|
p := allowedIPs.Lookup(addr)
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,24 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
)
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
msg []byte
|
msg []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
world bool // better type
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DummyBind struct {
|
type DummyBind struct {
|
||||||
in6 chan DummyDatagram
|
in6 chan DummyDatagram
|
||||||
ou6 chan DummyDatagram
|
|
||||||
in4 chan DummyDatagram
|
in4 chan DummyDatagram
|
||||||
ou4 chan DummyDatagram
|
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in6
|
datagram, ok := <-b.in6
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buff, datagram.msg)
|
copy(buf, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in4
|
datagram, ok := <-b.in4
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buff, datagram.msg)
|
copy(buf, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,6 +51,6 @@ func (b *DummyBind) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
|
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
|
||||||
nb, ok := device.net.bind.(*nativeBind)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("no socket exists")
|
|
||||||
}
|
|
||||||
sysconn, err := nb.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
|
|
||||||
nb, ok := device.net.bind.(*nativeBind)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("no socket exists")
|
|
||||||
}
|
|
||||||
sysconn, err := nb.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sockoptIP_UNICAST_IF = 31
|
|
||||||
sockoptIPV6_UNICAST_IF = 31
|
|
||||||
)
|
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
|
||||||
bytes := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
|
||||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
|
||||||
|
|
||||||
if device.net.bind == nil {
|
|
||||||
return errors.New("Bind is not yet initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
137
device/channels.go
Normal file
137
device/channels.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
|
||||||
|
// An outboundQueue is ref-counted using its wg field.
|
||||||
|
// An outboundQueue created with newOutboundQueue has one reference.
|
||||||
|
// Every additional writer must call wg.Add(1).
|
||||||
|
// Every completed writer must call wg.Done().
|
||||||
|
// When no further writers will be added,
|
||||||
|
// call wg.Done to remove the initial reference.
|
||||||
|
// When the refcount hits 0, the queue's channel is closed.
|
||||||
|
type outboundQueue struct {
|
||||||
|
c chan *QueueOutboundElementsContainer
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOutboundQueue() *outboundQueue {
|
||||||
|
q := &outboundQueue{
|
||||||
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
|
}
|
||||||
|
q.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
q.wg.Wait()
|
||||||
|
close(q.c)
|
||||||
|
}()
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// A inboundQueue is similar to an outboundQueue; see those docs.
|
||||||
|
type inboundQueue struct {
|
||||||
|
c chan *QueueInboundElementsContainer
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInboundQueue() *inboundQueue {
|
||||||
|
q := &inboundQueue{
|
||||||
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
|
}
|
||||||
|
q.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
q.wg.Wait()
|
||||||
|
close(q.c)
|
||||||
|
}()
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// A handshakeQueue is similar to an outboundQueue; see those docs.
|
||||||
|
type handshakeQueue struct {
|
||||||
|
c chan QueueHandshakeElement
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandshakeQueue() *handshakeQueue {
|
||||||
|
q := &handshakeQueue{
|
||||||
|
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
|
||||||
|
}
|
||||||
|
q.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
q.wg.Wait()
|
||||||
|
close(q.c)
|
||||||
|
}()
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
type autodrainingInboundQueue struct {
|
||||||
|
c chan *QueueInboundElementsContainer
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
// It is useful in cases in which is it hard to manage the lifetime of the channel.
|
||||||
|
// The returned channel must not be closed. Senders should signal shutdown using
|
||||||
|
// some other means, such as sending a sentinel nil values.
|
||||||
|
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
||||||
|
q := &autodrainingInboundQueue{
|
||||||
|
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
||||||
|
}
|
||||||
|
runtime.SetFinalizer(q, device.flushInboundQueue)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case elemsContainer := <-q.c:
|
||||||
|
elemsContainer.Lock()
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutInboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type autodrainingOutboundQueue struct {
|
||||||
|
c chan *QueueOutboundElementsContainer
|
||||||
|
}
|
||||||
|
|
||||||
|
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
||||||
|
// It is useful in cases in which is it hard to manage the lifetime of the channel.
|
||||||
|
// The returned channel must not be closed. Senders should signal shutdown using
|
||||||
|
// some other means, such as sending a sentinel nil values.
|
||||||
|
// All sends to the channel must be best-effort, because there may be no receivers.
|
||||||
|
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
||||||
|
q := &autodrainingOutboundQueue{
|
||||||
|
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
||||||
|
}
|
||||||
|
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case elemsContainer := <-q.c:
|
||||||
|
elemsContainer.Lock()
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
187
device/conn.go
187
device/conn.go
@@ -1,187 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ConnRoutineNumber = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
|
||||||
*/
|
|
||||||
type Bind interface {
|
|
||||||
SetMark(value uint32) error
|
|
||||||
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
|
||||||
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
|
||||||
Send(buff []byte, end Endpoint) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
/* An Endpoint maintains the source/destination caching for a peer
|
|
||||||
*
|
|
||||||
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
|
||||||
* src : the local address from which datagrams originate going to the peer
|
|
||||||
*/
|
|
||||||
type Endpoint interface {
|
|
||||||
ClearSrc() // clears the source address
|
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
|
||||||
DstToString() string // returns the destination address (ip:port)
|
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
|
||||||
DstIP() net.IP
|
|
||||||
SrcIP() net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|
||||||
// ensure that the host is an IP address
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
|
||||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
|
||||||
// trying to make sure with a small sanity test that this is a real IP address and
|
|
||||||
// not something that's likely to incur DNS lookups.
|
|
||||||
host = host[:i]
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
|
||||||
return nil, errors.New("Failed to parse IP address: " + host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse address and port
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ip4 := addr.IP.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
addr.IP = ip4
|
|
||||||
}
|
|
||||||
return addr, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func unsafeCloseBind(device *Device) error {
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
if netc.bind != nil {
|
|
||||||
err = netc.bind.Close()
|
|
||||||
netc.bind = nil
|
|
||||||
}
|
|
||||||
netc.stopping.Wait()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSetMark(mark uint32) error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// check if modified
|
|
||||||
|
|
||||||
if device.net.fwmark == mark {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// update fwmark on existing bind
|
|
||||||
|
|
||||||
device.net.fwmark = mark
|
|
||||||
if device.isUp.Get() && device.net.bind != nil {
|
|
||||||
if err := device.net.bind.SetMark(mark); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindUpdate() error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// close existing sockets
|
|
||||||
|
|
||||||
if err := unsafeCloseBind(device); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open new sockets
|
|
||||||
|
|
||||||
if device.isUp.Get() {
|
|
||||||
|
|
||||||
// bind to new port
|
|
||||||
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
|
||||||
if err != nil {
|
|
||||||
netc.bind = nil
|
|
||||||
netc.port = 0
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set fwmark
|
|
||||||
|
|
||||||
if netc.fwmark != 0 {
|
|
||||||
err = netc.bind.SetMark(netc.fwmark)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
// start receiving routines
|
|
||||||
|
|
||||||
device.net.starting.Add(ConnRoutineNumber)
|
|
||||||
device.net.stopping.Add(ConnRoutineNumber)
|
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
|
||||||
device.net.starting.Wait()
|
|
||||||
|
|
||||||
device.log.Debug.Println("UDP bind has been updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindClose() error {
|
|
||||||
device.net.Lock()
|
|
||||||
err := unsafeCloseBind(device)
|
|
||||||
device.net.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
// +build !linux android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* This code is meant to be a temporary solution
|
|
||||||
* on platforms for which the sticky socket / source caching behavior
|
|
||||||
* has not yet been implemented.
|
|
||||||
*
|
|
||||||
* See conn_linux.go for an implementation on the linux platform.
|
|
||||||
*/
|
|
||||||
|
|
||||||
type nativeBind struct {
|
|
||||||
ipv4 *net.UDPConn
|
|
||||||
ipv6 *net.UDPConn
|
|
||||||
}
|
|
||||||
|
|
||||||
type NativeEndpoint net.UDPAddr
|
|
||||||
|
|
||||||
var _ Bind = (*nativeBind)(nil)
|
|
||||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
|
||||||
|
|
||||||
func CreateEndpoint(s string) (Endpoint, error) {
|
|
||||||
addr, err := parseEndpoint(s)
|
|
||||||
return (*NativeEndpoint)(addr), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (_ *NativeEndpoint) ClearSrc() {}
|
|
||||||
|
|
||||||
func (e *NativeEndpoint) DstIP() net.IP {
|
|
||||||
return (*net.UDPAddr)(e).IP
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *NativeEndpoint) SrcIP() net.IP {
|
|
||||||
return nil // not supported
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *NativeEndpoint) DstToBytes() []byte {
|
|
||||||
addr := (*net.UDPAddr)(e)
|
|
||||||
out := addr.IP.To4()
|
|
||||||
if out == nil {
|
|
||||||
out = addr.IP
|
|
||||||
}
|
|
||||||
out = append(out, byte(addr.Port&0xff))
|
|
||||||
out = append(out, byte((addr.Port>>8)&0xff))
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *NativeEndpoint) DstToString() string {
|
|
||||||
return (*net.UDPAddr)(e).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *NativeEndpoint) SrcToString() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|
||||||
|
|
||||||
// listen
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// retrieve port
|
|
||||||
|
|
||||||
laddr := conn.LocalAddr()
|
|
||||||
uaddr, err := net.ResolveUDPAddr(
|
|
||||||
laddr.Network(),
|
|
||||||
laddr.String(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
return conn, uaddr.Port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractErrno(err error) error {
|
|
||||||
opErr, ok := err.(*net.OpError)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
syscallErr, ok := opErr.Err.(*os.SyscallError)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return syscallErr.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
|
||||||
var err error
|
|
||||||
var bind nativeBind
|
|
||||||
|
|
||||||
port := int(uport)
|
|
||||||
|
|
||||||
bind.ipv4, port, err = listenNet("udp4", port)
|
|
||||||
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bind.ipv6, port, err = listenNet("udp6", port)
|
|
||||||
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
|
|
||||||
bind.ipv4.Close()
|
|
||||||
bind.ipv4 = nil
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &bind, uint16(port), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) Close() error {
|
|
||||||
var err1, err2 error
|
|
||||||
if bind.ipv4 != nil {
|
|
||||||
err1 = bind.ipv4.Close()
|
|
||||||
}
|
|
||||||
if bind.ipv6 != nil {
|
|
||||||
err2 = bind.ipv6.Close()
|
|
||||||
}
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
|
||||||
if bind.ipv4 == nil {
|
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
|
|
||||||
if endpoint != nil {
|
|
||||||
endpoint.IP = endpoint.IP.To4()
|
|
||||||
}
|
|
||||||
return n, (*NativeEndpoint)(endpoint), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
|
||||||
if bind.ipv6 == nil {
|
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
|
|
||||||
return n, (*NativeEndpoint)(endpoint), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
|
|
||||||
var err error
|
|
||||||
nend := endpoint.(*NativeEndpoint)
|
|
||||||
if nend.IP.To4() != nil {
|
|
||||||
if bind.ipv4 == nil {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
|
||||||
} else {
|
|
||||||
if bind.ipv6 == nil {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,757 +0,0 @@
|
|||||||
// +build !android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
|
||||||
* of the sticky-sockets.c example code:
|
|
||||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
|
||||||
*
|
|
||||||
* Currently there is no way to achieve this within the net package:
|
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
|
||||||
* So this code is remains platform dependent.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
FD_ERR = -1
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPv4Source struct {
|
|
||||||
src [4]byte
|
|
||||||
ifindex int32
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPv6Source struct {
|
|
||||||
src [16]byte
|
|
||||||
//ifindex belongs in dst.ZoneId
|
|
||||||
}
|
|
||||||
|
|
||||||
type NativeEndpoint struct {
|
|
||||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
|
||||||
src [unsafe.Sizeof(IPv6Source{})]byte
|
|
||||||
isV6 bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
|
||||||
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) src6() *IPv6Source {
|
|
||||||
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
|
|
||||||
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
|
||||||
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
type nativeBind struct {
|
|
||||||
sock4 int
|
|
||||||
sock6 int
|
|
||||||
netlinkSock int
|
|
||||||
netlinkCancel *rwcancel.RWCancel
|
|
||||||
lastMark uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
|
||||||
var _ Bind = (*nativeBind)(nil)
|
|
||||||
|
|
||||||
func CreateEndpoint(s string) (Endpoint, error) {
|
|
||||||
var end NativeEndpoint
|
|
||||||
addr, err := parseEndpoint(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv4 := addr.IP.To4()
|
|
||||||
if ipv4 != nil {
|
|
||||||
dst := end.dst4()
|
|
||||||
end.isV6 = false
|
|
||||||
dst.Port = addr.Port
|
|
||||||
copy(dst.Addr[:], ipv4)
|
|
||||||
end.ClearSrc()
|
|
||||||
return &end, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ipv6 := addr.IP.To16()
|
|
||||||
if ipv6 != nil {
|
|
||||||
zone, err := zoneToUint32(addr.Zone)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dst := end.dst6()
|
|
||||||
end.isV6 = true
|
|
||||||
dst.Port = addr.Port
|
|
||||||
dst.ZoneId = zone
|
|
||||||
copy(dst.Addr[:], ipv6[:])
|
|
||||||
end.ClearSrc()
|
|
||||||
return &end, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("Invalid IP address")
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNetlinkRouteSocket() (int, error) {
|
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
saddr := &unix.SockaddrNetlink{
|
|
||||||
Family: unix.AF_NETLINK,
|
|
||||||
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
|
||||||
}
|
|
||||||
err = unix.Bind(sock, saddr)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(sock)
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return sock, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|
||||||
var err error
|
|
||||||
var bind nativeBind
|
|
||||||
var newPort uint16
|
|
||||||
|
|
||||||
bind.netlinkSock, err = createNetlinkRouteSocket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(bind.netlinkSock)
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go bind.routineRouteListener(device)
|
|
||||||
|
|
||||||
// attempt ipv6 bind, update port if succesful
|
|
||||||
|
|
||||||
bind.sock6, newPort, err = create6(port)
|
|
||||||
if err != nil {
|
|
||||||
if err != syscall.EAFNOSUPPORT {
|
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
port = newPort
|
|
||||||
}
|
|
||||||
|
|
||||||
// attempt ipv4 bind, update port if succesful
|
|
||||||
|
|
||||||
bind.sock4, newPort, err = create4(port)
|
|
||||||
if err != nil {
|
|
||||||
if err != syscall.EAFNOSUPPORT {
|
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
unix.Close(bind.sock6)
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
port = newPort
|
|
||||||
}
|
|
||||||
|
|
||||||
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
|
|
||||||
return nil, 0, errors.New("ipv4 and ipv6 not supported")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &bind, port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(value uint32) error {
|
|
||||||
if bind.sock6 != -1 {
|
|
||||||
err := unix.SetsockoptInt(
|
|
||||||
bind.sock6,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if bind.sock4 != -1 {
|
|
||||||
err := unix.SetsockoptInt(
|
|
||||||
bind.sock4,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_MARK,
|
|
||||||
int(value),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bind.lastMark = value
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeUnblock(fd int) error {
|
|
||||||
// shutdown to unblock readers and writers
|
|
||||||
unix.Shutdown(fd, unix.SHUT_RDWR)
|
|
||||||
return unix.Close(fd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) Close() error {
|
|
||||||
var err1, err2, err3 error
|
|
||||||
if bind.sock6 != -1 {
|
|
||||||
err1 = closeUnblock(bind.sock6)
|
|
||||||
}
|
|
||||||
if bind.sock4 != -1 {
|
|
||||||
err2 = closeUnblock(bind.sock4)
|
|
||||||
}
|
|
||||||
err3 = bind.netlinkCancel.Cancel()
|
|
||||||
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
return err3
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
|
||||||
var end NativeEndpoint
|
|
||||||
if bind.sock6 == -1 {
|
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
n, err := receive6(
|
|
||||||
bind.sock6,
|
|
||||||
buff,
|
|
||||||
&end,
|
|
||||||
)
|
|
||||||
return n, &end, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
|
||||||
var end NativeEndpoint
|
|
||||||
if bind.sock4 == -1 {
|
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
n, err := receive4(
|
|
||||||
bind.sock4,
|
|
||||||
buff,
|
|
||||||
&end,
|
|
||||||
)
|
|
||||||
return n, &end, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
|
|
||||||
nend := end.(*NativeEndpoint)
|
|
||||||
if !nend.isV6 {
|
|
||||||
if bind.sock4 == -1 {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
return send4(bind.sock4, nend, buff)
|
|
||||||
} else {
|
|
||||||
if bind.sock6 == -1 {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
return send6(bind.sock6, nend, buff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) SrcIP() net.IP {
|
|
||||||
if !end.isV6 {
|
|
||||||
return net.IPv4(
|
|
||||||
end.src4().src[0],
|
|
||||||
end.src4().src[1],
|
|
||||||
end.src4().src[2],
|
|
||||||
end.src4().src[3],
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
return end.src6().src[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) DstIP() net.IP {
|
|
||||||
if !end.isV6 {
|
|
||||||
return net.IPv4(
|
|
||||||
end.dst4().Addr[0],
|
|
||||||
end.dst4().Addr[1],
|
|
||||||
end.dst4().Addr[2],
|
|
||||||
end.dst4().Addr[3],
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
return end.dst6().Addr[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) DstToBytes() []byte {
|
|
||||||
if !end.isV6 {
|
|
||||||
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
|
||||||
} else {
|
|
||||||
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) SrcToString() string {
|
|
||||||
return end.SrcIP().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) DstToString() string {
|
|
||||||
var udpAddr net.UDPAddr
|
|
||||||
udpAddr.IP = end.DstIP()
|
|
||||||
if !end.isV6 {
|
|
||||||
udpAddr.Port = end.dst4().Port
|
|
||||||
} else {
|
|
||||||
udpAddr.Port = end.dst6().Port
|
|
||||||
}
|
|
||||||
return udpAddr.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) ClearDst() {
|
|
||||||
for i := range end.dst {
|
|
||||||
end.dst[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (end *NativeEndpoint) ClearSrc() {
|
|
||||||
for i := range end.src {
|
|
||||||
end.src[i] = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func zoneToUint32(zone string) (uint32, error) {
|
|
||||||
if zone == "" {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
if intr, err := net.InterfaceByName(zone); err == nil {
|
|
||||||
return uint32(intr.Index), nil
|
|
||||||
}
|
|
||||||
n, err := strconv.ParseUint(zone, 10, 32)
|
|
||||||
return uint32(n), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func create4(port uint16) (int, uint16, error) {
|
|
||||||
|
|
||||||
// create socket
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET,
|
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return FD_ERR, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := unix.SockaddrInet4{
|
|
||||||
Port: int(port),
|
|
||||||
}
|
|
||||||
|
|
||||||
// set sockopts and bind
|
|
||||||
|
|
||||||
if err := func() error {
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_REUSEADDR,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IP,
|
|
||||||
unix.IP_PKTINFO,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.Bind(fd, &addr)
|
|
||||||
}(); err != nil {
|
|
||||||
unix.Close(fd)
|
|
||||||
return FD_ERR, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, err := unix.Getsockname(fd)
|
|
||||||
if err == nil {
|
|
||||||
addr.Port = sa.(*unix.SockaddrInet4).Port
|
|
||||||
}
|
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func create6(port uint16) (int, uint16, error) {
|
|
||||||
|
|
||||||
// create socket
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
|
||||||
unix.AF_INET6,
|
|
||||||
unix.SOCK_DGRAM,
|
|
||||||
0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return FD_ERR, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set sockopts and bind
|
|
||||||
|
|
||||||
addr := unix.SockaddrInet6{
|
|
||||||
Port: int(port),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := func() error {
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.SOL_SOCKET,
|
|
||||||
unix.SO_REUSEADDR,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IPV6,
|
|
||||||
unix.IPV6_RECVPKTINFO,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := unix.SetsockoptInt(
|
|
||||||
fd,
|
|
||||||
unix.IPPROTO_IPV6,
|
|
||||||
unix.IPV6_V6ONLY,
|
|
||||||
1,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return unix.Bind(fd, &addr)
|
|
||||||
|
|
||||||
}(); err != nil {
|
|
||||||
unix.Close(fd)
|
|
||||||
return FD_ERR, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sa, err := unix.Getsockname(fd)
|
|
||||||
if err == nil {
|
|
||||||
addr.Port = sa.(*unix.SockaddrInet6).Port
|
|
||||||
}
|
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
cmsg := struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet4Pktinfo
|
|
||||||
}{
|
|
||||||
unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IP,
|
|
||||||
Type: unix.IP_PKTINFO,
|
|
||||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
|
||||||
},
|
|
||||||
unix.Inet4Pktinfo{
|
|
||||||
Spec_dst: end.src4().src,
|
|
||||||
Ifindex: end.src4().ifindex,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear src and retry
|
|
||||||
|
|
||||||
if err == unix.EINVAL {
|
|
||||||
end.ClearSrc()
|
|
||||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|
||||||
|
|
||||||
// construct message header
|
|
||||||
|
|
||||||
cmsg := struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet6Pktinfo
|
|
||||||
}{
|
|
||||||
unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IPV6,
|
|
||||||
Type: unix.IPV6_PKTINFO,
|
|
||||||
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
|
||||||
},
|
|
||||||
unix.Inet6Pktinfo{
|
|
||||||
Addr: end.src6().src,
|
|
||||||
Ifindex: end.dst6().ZoneId,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmsg.pktinfo.Addr == [16]byte{} {
|
|
||||||
cmsg.pktinfo.Ifindex = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear src and retry
|
|
||||||
|
|
||||||
if err == unix.EINVAL {
|
|
||||||
end.ClearSrc()
|
|
||||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|
||||||
|
|
||||||
// contruct message header
|
|
||||||
|
|
||||||
var cmsg struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet4Pktinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
end.isV6 = false
|
|
||||||
|
|
||||||
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
|
||||||
*end.dst4() = *newDst4
|
|
||||||
}
|
|
||||||
|
|
||||||
// update source cache
|
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
|
||||||
end.src4().src = cmsg.pktinfo.Spec_dst
|
|
||||||
end.src4().ifindex = cmsg.pktinfo.Ifindex
|
|
||||||
}
|
|
||||||
|
|
||||||
return size, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|
||||||
|
|
||||||
// contruct message header
|
|
||||||
|
|
||||||
var cmsg struct {
|
|
||||||
cmsghdr unix.Cmsghdr
|
|
||||||
pktinfo unix.Inet6Pktinfo
|
|
||||||
}
|
|
||||||
|
|
||||||
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
end.isV6 = true
|
|
||||||
|
|
||||||
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
|
||||||
*end.dst6() = *newDst6
|
|
||||||
}
|
|
||||||
|
|
||||||
// update source cache
|
|
||||||
|
|
||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
|
||||||
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
|
||||||
end.src6().src = cmsg.pktinfo.Addr
|
|
||||||
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
|
||||||
}
|
|
||||||
|
|
||||||
return size, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *nativeBind) routineRouteListener(device *Device) {
|
|
||||||
type peerEndpointPtr struct {
|
|
||||||
peer *Peer
|
|
||||||
endpoint *Endpoint
|
|
||||||
}
|
|
||||||
var reqPeer map[uint32]peerEndpointPtr
|
|
||||||
var reqPeerLock sync.Mutex
|
|
||||||
|
|
||||||
defer unix.Close(bind.netlinkSock)
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !bind.netlinkCancel.ReadyRead() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if uint(hdr.Len) > uint(len(remain)) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
|
||||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
|
||||||
if uint(len(remain)) < uint(hdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
|
||||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
|
||||||
for {
|
|
||||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
|
||||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
|
||||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
if reqPeer == nil {
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr, ok := reqPeer[hdr.Seq]
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.Lock()
|
|
||||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
}
|
|
||||||
attr = attr[attrhdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
go func() {
|
|
||||||
device.peers.RLock()
|
|
||||||
i := uint32(1)
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.RLock()
|
|
||||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
|
||||||
peer.RUnlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
|
||||||
peer.RUnlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
nlmsg := struct {
|
|
||||||
hdr unix.NlMsghdr
|
|
||||||
msg unix.RtMsg
|
|
||||||
dsthdr unix.RtAttr
|
|
||||||
dst [4]byte
|
|
||||||
srchdr unix.RtAttr
|
|
||||||
src [4]byte
|
|
||||||
markhdr unix.RtAttr
|
|
||||||
mark uint32
|
|
||||||
}{
|
|
||||||
unix.NlMsghdr{
|
|
||||||
Type: uint16(unix.RTM_GETROUTE),
|
|
||||||
Flags: unix.NLM_F_REQUEST,
|
|
||||||
Seq: i,
|
|
||||||
},
|
|
||||||
unix.RtMsg{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Dst_len: 32,
|
|
||||||
Src_len: 32,
|
|
||||||
},
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_DST,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_SRC,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).src4().src,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_MARK,
|
|
||||||
},
|
|
||||||
uint32(bind.lastMark),
|
|
||||||
}
|
|
||||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer[i] = peerEndpointPtr{
|
|
||||||
peer: peer,
|
|
||||||
endpoint: &peer.endpoint,
|
|
||||||
}
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
peer.RUnlock()
|
|
||||||
i++
|
|
||||||
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
/* Specification constants */
|
/* Specification constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
|
RekeyAfterMessages = (1 << 60)
|
||||||
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
RejectAfterMessages = (1 << 64) - (1 << 13) - 1
|
||||||
RekeyAfterTime = time.Second * 120
|
RekeyAfterTime = time.Second * 120
|
||||||
RekeyAttemptTime = time.Second * 90
|
RekeyAttemptTime = time.Second * 90
|
||||||
RekeyTimeout = time.Second * 5
|
RekeyTimeout = time.Second * 5
|
||||||
@@ -35,7 +35,6 @@ const (
|
|||||||
/* Implementation constants */
|
/* Implementation constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UnderLoadQueueSize = QueueHandshakeSize / 8
|
|
||||||
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
||||||
MaxPeers = 1 << 16 // maximum number of configured peers
|
MaxPeers = 1 << 16 // maximum number of configured peers
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
|||||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
|
||||||
st.RLock()
|
st.RLock()
|
||||||
defer st.RUnlock()
|
defer st.RUnlock()
|
||||||
|
|
||||||
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
|
|||||||
recv uint32,
|
recv uint32,
|
||||||
src []byte,
|
src []byte,
|
||||||
) (*MessageCookieReply, error) {
|
) (*MessageCookieReply, error) {
|
||||||
|
|
||||||
st.RLock()
|
st.RLock()
|
||||||
|
|
||||||
// refresh cookie secret
|
// refresh cookie secret
|
||||||
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
|||||||
|
|
||||||
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
||||||
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
|
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieGenerator) AddMacs(msg []byte) {
|
func (st *CookieGenerator) AddMacs(msg []byte) {
|
||||||
|
|
||||||
size := len(msg)
|
size := len(msg)
|
||||||
|
|
||||||
smac2 := size - blake2s.Size128
|
smac2 := size - blake2s.Size128
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCookieMAC1(t *testing.T) {
|
func TestCookieMAC1(t *testing.T) {
|
||||||
|
|
||||||
// setup generator / checker
|
// setup generator / checker
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
|
|||||||
|
|
||||||
msg[5] ^= 0x20
|
msg[5] ^= 0x20
|
||||||
|
|
||||||
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
|
srcBad1 := []byte{192, 168, 13, 37, 40, 1}
|
||||||
if checker.CheckMAC2(msg, srcBad1) {
|
if checker.CheckMAC2(msg, srcBad1) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
|
srcBad2 := []byte{192, 168, 13, 38, 40, 1}
|
||||||
if checker.CheckMAC2(msg, srcBad2) {
|
if checker.CheckMAC2(msg, srcBad2) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|||||||
527
device/device.go
527
device/device.go
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -11,37 +11,40 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/Lordy82/wireguard-go/ratelimiter"
|
||||||
)
|
"github.com/Lordy82/wireguard-go/rwcancel"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
const (
|
|
||||||
DeviceRoutineNumberPerCPU = 3
|
|
||||||
DeviceRoutineNumberAdditional = 2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
isUp AtomicBool // device is (going) up
|
|
||||||
isClosed AtomicBool // device is closed? (acting as guard)
|
|
||||||
log *Logger
|
|
||||||
|
|
||||||
// synchronized resources (locks acquired in order)
|
|
||||||
|
|
||||||
state struct {
|
state struct {
|
||||||
starting sync.WaitGroup
|
// state holds the device's state. It is accessed atomically.
|
||||||
|
// Use the device.deviceState method to read it.
|
||||||
|
// device.deviceState does not acquire the mutex, so it captures only a snapshot.
|
||||||
|
// During state transitions, the state variable is updated before the device itself.
|
||||||
|
// The state is thus either the current state of the device or
|
||||||
|
// the intended future state of the device.
|
||||||
|
// For example, while executing a call to Up, state will be deviceStateUp.
|
||||||
|
// There is no guarantee that that intended future state of the device
|
||||||
|
// will become the actual state; Up can fail.
|
||||||
|
// The device can also change state multiple times between time of check and time of use.
|
||||||
|
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
||||||
|
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
||||||
|
// stopping blocks until all inputs to Device have been closed.
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
|
// mu protects state changes.
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
changing AtomicBool
|
|
||||||
current bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
net struct {
|
net struct {
|
||||||
starting sync.WaitGroup
|
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
bind Bind // bind interface
|
bind conn.Bind // bind interface
|
||||||
port uint16 // listening port
|
netlinkCancel *rwcancel.RWCancel
|
||||||
fwmark uint32 // mark value (0 = disabled)
|
port uint16 // listening port
|
||||||
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
|
brokenRoaming bool
|
||||||
}
|
}
|
||||||
|
|
||||||
staticIdentity struct {
|
staticIdentity struct {
|
||||||
@@ -51,153 +54,176 @@ type Device struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
peers struct {
|
peers struct {
|
||||||
sync.RWMutex
|
sync.RWMutex // protects keyMap
|
||||||
keyMap map[NoisePublicKey]*Peer
|
keyMap map[NoisePublicKey]*Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// unprotected / "self-synchronising resources"
|
rate struct {
|
||||||
|
underLoadUntil atomic.Int64
|
||||||
|
limiter ratelimiter.Ratelimiter
|
||||||
|
}
|
||||||
|
|
||||||
allowedips AllowedIPs
|
allowedips AllowedIPs
|
||||||
indexTable IndexTable
|
indexTable IndexTable
|
||||||
cookieChecker CookieChecker
|
cookieChecker CookieChecker
|
||||||
|
|
||||||
rate struct {
|
|
||||||
underLoadUntil atomic.Value
|
|
||||||
limiter ratelimiter.Ratelimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
pool struct {
|
pool struct {
|
||||||
messageBufferPool *sync.Pool
|
inboundElementsContainer *WaitPool
|
||||||
messageBufferReuseChan chan *[MaxMessageSize]byte
|
outboundElementsContainer *WaitPool
|
||||||
inboundElementPool *sync.Pool
|
messageBuffers *WaitPool
|
||||||
inboundElementReuseChan chan *QueueInboundElement
|
inboundElements *WaitPool
|
||||||
outboundElementPool *sync.Pool
|
outboundElements *WaitPool
|
||||||
outboundElementReuseChan chan *QueueOutboundElement
|
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
encryption chan *QueueOutboundElement
|
encryption *outboundQueue
|
||||||
decryption chan *QueueInboundElement
|
decryption *inboundQueue
|
||||||
handshake chan QueueHandshakeElement
|
handshake *handshakeQueue
|
||||||
}
|
|
||||||
|
|
||||||
signals struct {
|
|
||||||
stop chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tun struct {
|
tun struct {
|
||||||
device tun.Device
|
device tun.Device
|
||||||
mtu int32
|
mtu atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ipcMutex sync.RWMutex
|
||||||
|
closed chan struct{}
|
||||||
|
log *Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Converts the peer into a "zombie", which remains in the peer map,
|
// deviceState represents the state of a Device.
|
||||||
* but processes no packets and does not exists in the routing table.
|
// There are three states: down, up, closed.
|
||||||
*
|
// Transitions:
|
||||||
* Must hold device.peers.Mutex
|
//
|
||||||
*/
|
// down -----+
|
||||||
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
|
// ↑↓ ↓
|
||||||
|
// up -> closed
|
||||||
|
type deviceState uint32
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
||||||
|
const (
|
||||||
|
deviceStateDown deviceState = iota
|
||||||
|
deviceStateUp
|
||||||
|
deviceStateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
// deviceState returns device.state.state as a deviceState
|
||||||
|
// See those docs for how to interpret this value.
|
||||||
|
func (device *Device) deviceState() deviceState {
|
||||||
|
return deviceState(device.state.state.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClosed reports whether the device is closed (or is closing).
|
||||||
|
// See device.state.state comments for how to interpret this value.
|
||||||
|
func (device *Device) isClosed() bool {
|
||||||
|
return device.deviceState() == deviceStateClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUp reports whether the device is up (or is attempting to come up).
|
||||||
|
// See device.state.state comments for how to interpret this value.
|
||||||
|
func (device *Device) isUp() bool {
|
||||||
|
return device.deviceState() == deviceStateUp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must hold device.peers.Lock()
|
||||||
|
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
|
||||||
// stop routing and processing of packets
|
// stop routing and processing of packets
|
||||||
|
|
||||||
device.allowedips.RemoveByPeer(peer)
|
device.allowedips.RemoveByPeer(peer)
|
||||||
peer.Stop()
|
peer.Stop()
|
||||||
|
|
||||||
// remove from peer map
|
// remove from peer map
|
||||||
|
|
||||||
delete(device.peers.keyMap, key)
|
delete(device.peers.keyMap, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func deviceUpdateState(device *Device) {
|
// changeState attempts to change the device state to match want.
|
||||||
|
func (device *Device) changeState(want deviceState) (err error) {
|
||||||
// check if state already being updated (guard)
|
|
||||||
|
|
||||||
if device.state.changing.Swap(true) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// compare to current state of device
|
|
||||||
|
|
||||||
device.state.Lock()
|
device.state.Lock()
|
||||||
|
defer device.state.Unlock()
|
||||||
newIsUp := device.isUp.Get()
|
old := device.deviceState()
|
||||||
|
if old == deviceStateClosed {
|
||||||
if newIsUp == device.state.current {
|
// once closed, always closed
|
||||||
device.state.changing.Set(false)
|
device.log.Verbosef("Interface closed, ignored requested state %s", want)
|
||||||
device.state.Unlock()
|
return nil
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
switch want {
|
||||||
// change state of device
|
case old:
|
||||||
|
return nil
|
||||||
switch newIsUp {
|
case deviceStateUp:
|
||||||
case true:
|
device.state.state.Store(uint32(deviceStateUp))
|
||||||
if err := device.BindUpdate(); err != nil {
|
err = device.upLocked()
|
||||||
device.log.Error.Printf("Unable to update bind: %v\n", err)
|
if err == nil {
|
||||||
device.isUp.Set(false)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
device.peers.RLock()
|
fallthrough // up failed; bring the device all the way back down
|
||||||
for _, peer := range device.peers.keyMap {
|
case deviceStateDown:
|
||||||
peer.Start()
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
if peer.persistentKeepaliveInterval > 0 {
|
errDown := device.downLocked()
|
||||||
peer.SendKeepalive()
|
if err == nil {
|
||||||
}
|
err = errDown
|
||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
case false:
|
|
||||||
device.BindClose()
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Stop()
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}
|
}
|
||||||
|
device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
||||||
// update state variables
|
return
|
||||||
|
|
||||||
device.state.current = newIsUp
|
|
||||||
device.state.changing.Set(false)
|
|
||||||
device.state.Unlock()
|
|
||||||
|
|
||||||
// check for state change in the mean time
|
|
||||||
|
|
||||||
deviceUpdateState(device)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Up() {
|
// upLocked attempts to bring the device up and reports whether it succeeded.
|
||||||
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
||||||
// closed device cannot be brought up
|
func (device *Device) upLocked() error {
|
||||||
|
if err := device.BindUpdate(); err != nil {
|
||||||
if device.isClosed.Get() {
|
device.log.Errorf("Unable to update bind: %v", err)
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
device.isUp.Set(true)
|
// The IPC set operation waits for peers to be created before calling Start() on them,
|
||||||
deviceUpdateState(device)
|
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
|
||||||
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Start()
|
||||||
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
|
peer.SendKeepalive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Down() {
|
// downLocked attempts to bring the device down.
|
||||||
device.isUp.Set(false)
|
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
||||||
deviceUpdateState(device)
|
func (device *Device) downLocked() error {
|
||||||
|
err := device.BindClose()
|
||||||
|
if err != nil {
|
||||||
|
device.log.Errorf("Bind close failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Stop()
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Up() error {
|
||||||
|
return device.changeState(deviceStateUp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Down() error {
|
||||||
|
return device.changeState(deviceStateDown)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IsUnderLoad() bool {
|
func (device *Device) IsUnderLoad() bool {
|
||||||
|
|
||||||
// check if currently under load
|
// check if currently under load
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
|
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
||||||
if underLoad {
|
if underLoad {
|
||||||
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if recently under load
|
// check if recently under load
|
||||||
|
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
||||||
until := device.rate.underLoadUntil.Load().(time.Time)
|
|
||||||
return until.After(now)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
@@ -224,7 +250,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
publicKey := sk.publicKey()
|
publicKey := sk.publicKey()
|
||||||
for key, peer := range device.peers.keyMap {
|
for key, peer := range device.peers.keyMap {
|
||||||
if peer.handshake.remoteStatic.Equals(publicKey) {
|
if peer.handshake.remoteStatic.Equals(publicKey) {
|
||||||
unsafeRemovePeer(device, peer, key)
|
peer.handshake.mutex.RUnlock()
|
||||||
|
removePeerLocked(device, peer, key)
|
||||||
|
peer.handshake.mutex.RLock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,23 +264,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
|
|
||||||
// do static-static DH pre-computations
|
// do static-static DH pre-computations
|
||||||
|
|
||||||
rmKey := device.staticIdentity.privateKey.IsZero()
|
|
||||||
|
|
||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for key, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
if rmKey {
|
expiredPeers = append(expiredPeers, peer)
|
||||||
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
|
|
||||||
} else {
|
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
unsafeRemovePeer(device, peer, key)
|
|
||||||
} else {
|
|
||||||
expiredPeers = append(expiredPeers, peer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range lockedPeers {
|
for _, peer := range lockedPeers {
|
||||||
@@ -265,68 +281,63 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
|
device.state.state.Store(uint32(deviceStateDown))
|
||||||
device.isUp.Set(false)
|
device.closed = make(chan struct{})
|
||||||
device.isClosed.Set(false)
|
|
||||||
|
|
||||||
device.log = logger
|
device.log = logger
|
||||||
|
device.net.bind = bind
|
||||||
device.tun.device = tunDevice
|
device.tun.device = tunDevice
|
||||||
mtu, err := device.tun.device.MTU()
|
mtu, err := device.tun.device.MTU()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Trouble determining MTU, assuming default:", err)
|
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
||||||
mtu = DefaultMTU
|
mtu = DefaultMTU
|
||||||
}
|
}
|
||||||
device.tun.mtu = int32(mtu)
|
device.tun.mtu.Store(int32(mtu))
|
||||||
|
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
|
|
||||||
device.rate.limiter.Init()
|
device.rate.limiter.Init()
|
||||||
device.rate.underLoadUntil.Store(time.Time{})
|
|
||||||
|
|
||||||
device.indexTable.Init()
|
device.indexTable.Init()
|
||||||
device.allowedips.Reset()
|
|
||||||
|
|
||||||
device.PopulatePools()
|
device.PopulatePools()
|
||||||
|
|
||||||
// create queues
|
// create queues
|
||||||
|
|
||||||
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
device.queue.handshake = newHandshakeQueue()
|
||||||
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
|
device.queue.encryption = newOutboundQueue()
|
||||||
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
|
device.queue.decryption = newInboundQueue()
|
||||||
|
|
||||||
// prepare signals
|
|
||||||
|
|
||||||
device.signals.stop = make(chan struct{})
|
|
||||||
|
|
||||||
// prepare net
|
|
||||||
|
|
||||||
device.net.port = 0
|
|
||||||
device.net.bind = nil
|
|
||||||
|
|
||||||
// start workers
|
// start workers
|
||||||
|
|
||||||
cpus := runtime.NumCPU()
|
cpus := runtime.NumCPU()
|
||||||
device.state.starting.Wait()
|
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
|
||||||
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
for i := 0; i < cpus; i++ {
|
||||||
for i := 0; i < cpus; i += 1 {
|
go device.RoutineEncryption(i + 1)
|
||||||
go device.RoutineEncryption()
|
go device.RoutineDecryption(i + 1)
|
||||||
go device.RoutineDecryption()
|
go device.RoutineHandshake(i + 1)
|
||||||
go device.RoutineHandshake()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.state.stopping.Add(1) // RoutineReadFromTUN
|
||||||
|
device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
|
||||||
go device.RoutineReadFromTUN()
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
|
|
||||||
device.state.starting.Wait()
|
|
||||||
|
|
||||||
return device
|
return device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
||||||
|
// the bind batch size and the tun batch size. The batch size reported by device
|
||||||
|
// is the size used to construct memory pools, and is the allowed batch size for
|
||||||
|
// the lifetime of the device.
|
||||||
|
func (device *Device) BatchSize() int {
|
||||||
|
size := device.net.bind.BatchSize()
|
||||||
|
dSize := device.tun.device.BatchSize()
|
||||||
|
if size < dSize {
|
||||||
|
size = dSize
|
||||||
|
}
|
||||||
|
return size
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||||
device.peers.RLock()
|
device.peers.RLock()
|
||||||
defer device.peers.RUnlock()
|
defer device.peers.RUnlock()
|
||||||
@@ -341,7 +352,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
|
|||||||
|
|
||||||
peer, ok := device.peers.keyMap[key]
|
peer, ok := device.peers.keyMap[key]
|
||||||
if ok {
|
if ok {
|
||||||
unsafeRemovePeer(device, peer, key)
|
removePeerLocked(device, peer, key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,67 +361,50 @@ func (device *Device) RemoveAllPeers() {
|
|||||||
defer device.peers.Unlock()
|
defer device.peers.Unlock()
|
||||||
|
|
||||||
for key, peer := range device.peers.keyMap {
|
for key, peer := range device.peers.keyMap {
|
||||||
unsafeRemovePeer(device, peer, key)
|
removePeerLocked(device, peer, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) FlushPacketQueues() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elem, ok := <-device.queue.decryption:
|
|
||||||
if ok {
|
|
||||||
elem.Drop()
|
|
||||||
}
|
|
||||||
case elem, ok := <-device.queue.encryption:
|
|
||||||
if ok {
|
|
||||||
elem.Drop()
|
|
||||||
}
|
|
||||||
case <-device.queue.handshake:
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Close() {
|
func (device *Device) Close() {
|
||||||
if device.isClosed.Swap(true) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
device.state.starting.Wait()
|
|
||||||
|
|
||||||
device.log.Info.Println("Device closing")
|
|
||||||
device.state.changing.Set(true)
|
|
||||||
device.state.Lock()
|
device.state.Lock()
|
||||||
defer device.state.Unlock()
|
defer device.state.Unlock()
|
||||||
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
|
if device.isClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
device.state.state.Store(uint32(deviceStateClosed))
|
||||||
|
device.log.Verbosef("Device closing")
|
||||||
|
|
||||||
device.tun.device.Close()
|
device.tun.device.Close()
|
||||||
device.BindClose()
|
device.downLocked()
|
||||||
|
|
||||||
device.isUp.Set(false)
|
|
||||||
|
|
||||||
close(device.signals.stop)
|
|
||||||
|
|
||||||
|
// Remove peers before closing queues,
|
||||||
|
// because peers assume that queues are active.
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
|
// We kept a reference to the encryption and decryption queues,
|
||||||
|
// in case we started any new peers that might write to them.
|
||||||
|
// No new peers are coming; we are done with these queues.
|
||||||
|
device.queue.encryption.wg.Done()
|
||||||
|
device.queue.decryption.wg.Done()
|
||||||
|
device.queue.handshake.wg.Done()
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
device.FlushPacketQueues()
|
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
device.rate.limiter.Close()
|
||||||
|
|
||||||
device.state.changing.Set(false)
|
device.log.Verbosef("Device closed")
|
||||||
device.log.Info.Println("Interface closed")
|
close(device.closed)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) Wait() chan struct{} {
|
func (device *Device) Wait() chan struct{} {
|
||||||
return device.signals.stop
|
return device.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||||
if device.isClosed.Get() {
|
if !device.isUp() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,3 +419,118 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// closeBindLocked closes the device's net.bind.
|
||||||
|
// The caller must hold the net mutex.
|
||||||
|
func closeBindLocked(device *Device) error {
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
if netc.netlinkCancel != nil {
|
||||||
|
netc.netlinkCancel.Cancel()
|
||||||
|
}
|
||||||
|
if netc.bind != nil {
|
||||||
|
err = netc.bind.Close()
|
||||||
|
}
|
||||||
|
netc.stopping.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Bind() conn.Bind {
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
return device.net.bind
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindSetMark(mark uint32) error {
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// check if modified
|
||||||
|
if device.net.fwmark == mark {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update fwmark on existing bind
|
||||||
|
device.net.fwmark = mark
|
||||||
|
if device.isUp() && device.net.bind != nil {
|
||||||
|
if err := device.net.bind.SetMark(mark); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.markEndpointSrcForClearing()
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindUpdate() error {
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// close existing sockets
|
||||||
|
if err := closeBindLocked(device); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open new sockets
|
||||||
|
if !device.isUp() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bind to new port
|
||||||
|
var err error
|
||||||
|
var recvFns []conn.ReceiveFunc
|
||||||
|
netc := &device.net
|
||||||
|
|
||||||
|
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||||
|
if err != nil {
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind.Close()
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set fwmark
|
||||||
|
if netc.fwmark != 0 {
|
||||||
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.markEndpointSrcForClearing()
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
// start receiving routines
|
||||||
|
device.net.stopping.Add(len(recvFns))
|
||||||
|
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||||
|
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||||
|
batchSize := netc.bind.BatchSize()
|
||||||
|
for _, fn := range recvFns {
|
||||||
|
go device.RoutineReceiveIncoming(batchSize, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
device.log.Verbosef("UDP bind has been updated")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindClose() error {
|
||||||
|
device.net.Lock()
|
||||||
|
err := closeBindLocked(device)
|
||||||
|
device.net.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,68 +1,476 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Create two device instances and simulate full WireGuard interaction
|
|
||||||
* without network dependencies
|
|
||||||
*/
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"runtime/pprof"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/conn/bindtest"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDevice(t *testing.T) {
|
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
||||||
|
// cfg is a series of alternating key/value strings.
|
||||||
// prepare tun devices for generating traffic
|
// uapiCfg exists because editors and humans like to insert
|
||||||
|
// whitespace into configs, which can cause failures, some of which are silent.
|
||||||
tun1 := newDummyTUN("tun1")
|
// For example, a leading blank newline causes the remainder
|
||||||
tun2 := newDummyTUN("tun2")
|
// of the config to be silently ignored.
|
||||||
|
func uapiCfg(cfg ...string) string {
|
||||||
_ = tun1
|
if len(cfg)%2 != 0 {
|
||||||
_ = tun2
|
panic("odd number of args to uapiReader")
|
||||||
|
|
||||||
// prepare endpoints
|
|
||||||
|
|
||||||
end1, err := CreateDummyEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
|
||||||
}
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
end2, err := CreateDummyEndpoint()
|
for i, s := range cfg {
|
||||||
if err != nil {
|
buf.WriteString(s)
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
sep := byte('\n')
|
||||||
|
if i%2 == 0 {
|
||||||
|
sep = '='
|
||||||
|
}
|
||||||
|
buf.WriteByte(sep)
|
||||||
}
|
}
|
||||||
|
return buf.String()
|
||||||
_ = end1
|
|
||||||
_ = end2
|
|
||||||
|
|
||||||
// create binds
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func randDevice(t *testing.T) *Device {
|
// genConfigs generates a pair of configs that connect to each other.
|
||||||
sk, err := newPrivateKey()
|
// The configs use distinct, probably-usable ports.
|
||||||
|
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
||||||
|
var key1, key2 NoisePrivateKey
|
||||||
|
_, err := rand.Read(key1[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
}
|
}
|
||||||
tun := newDummyTUN("dummy")
|
_, err = rand.Read(key2[:])
|
||||||
logger := NewLogger(LogLevelError, "")
|
if err != nil {
|
||||||
device := NewDevice(tun, logger)
|
tb.Errorf("unable to generate private key random bytes: %v", err)
|
||||||
device.SetPrivateKey(sk)
|
}
|
||||||
return device
|
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
||||||
|
|
||||||
|
cfgs[0] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key1[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.2/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[0] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub2[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
cfgs[1] = uapiCfg(
|
||||||
|
"private_key", hex.EncodeToString(key2[:]),
|
||||||
|
"listen_port", "0",
|
||||||
|
"replace_peers", "true",
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"protocol_version", "1",
|
||||||
|
"replace_allowed_ips", "true",
|
||||||
|
"allowed_ip", "1.0.0.1/32",
|
||||||
|
)
|
||||||
|
endpointCfgs[1] = uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub1[:]),
|
||||||
|
"endpoint", "127.0.0.1:%d",
|
||||||
|
)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
// A testPair is a pair of testPeers.
|
||||||
|
type testPair [2]testPeer
|
||||||
|
|
||||||
|
// A testPeer is a peer used for testing.
|
||||||
|
type testPeer struct {
|
||||||
|
tun *tuntest.ChannelTUN
|
||||||
|
dev *Device
|
||||||
|
ip netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
type SendDirection bool
|
||||||
|
|
||||||
|
const (
|
||||||
|
Ping SendDirection = true
|
||||||
|
Pong SendDirection = false
|
||||||
|
)
|
||||||
|
|
||||||
|
func (d SendDirection) String() string {
|
||||||
|
if d == Ping {
|
||||||
|
return "ping"
|
||||||
|
}
|
||||||
|
return "pong"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
|
||||||
|
tb.Helper()
|
||||||
|
p0, p1 := pair[0], pair[1]
|
||||||
|
if !ping {
|
||||||
|
// pong is the new ping
|
||||||
|
p0, p1 = p1, p0
|
||||||
|
}
|
||||||
|
msg := tuntest.Ping(p0.ip, p1.ip)
|
||||||
|
p1.tun.Outbound <- msg
|
||||||
|
timer := time.NewTimer(5 * time.Second)
|
||||||
|
defer timer.Stop()
|
||||||
|
var err error
|
||||||
|
select {
|
||||||
|
case msgRecv := <-p0.tun.Inbound:
|
||||||
|
if !bytes.Equal(msg, msgRecv) {
|
||||||
|
err = fmt.Errorf("%s did not transit correctly", ping)
|
||||||
|
}
|
||||||
|
case <-timer.C:
|
||||||
|
err = fmt.Errorf("%s did not transit", ping)
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
// The error may have occurred because the test is done.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
// Real error.
|
||||||
|
tb.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertEqual(t *testing.T, a, b []byte) {
|
// genTestPair creates a testPair.
|
||||||
if !bytes.Equal(a, b) {
|
func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||||
t.Fatal(a, "!=", b)
|
cfg, endpointCfg := genConfigs(tb)
|
||||||
|
var binds [2]conn.Bind
|
||||||
|
if realSocket {
|
||||||
|
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
||||||
|
} else {
|
||||||
|
binds = bindtest.NewChannelBinds()
|
||||||
|
}
|
||||||
|
// Bring up a ChannelTun for each config.
|
||||||
|
for i := range pair {
|
||||||
|
p := &pair[i]
|
||||||
|
p.tun = tuntest.NewChannelTUN()
|
||||||
|
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
|
||||||
|
level := LogLevelVerbose
|
||||||
|
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||||
|
level = LogLevelError
|
||||||
|
}
|
||||||
|
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||||
|
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||||
|
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||||
|
p.dev.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := p.dev.Up(); err != nil {
|
||||||
|
tb.Errorf("failed to bring up device %d: %v", i, err)
|
||||||
|
p.dev.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
|
||||||
|
}
|
||||||
|
for i := range pair {
|
||||||
|
p := &pair[i]
|
||||||
|
if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
|
||||||
|
tb.Errorf("failed to configure device endpoint %d: %v", i, err)
|
||||||
|
p.dev.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// The device is ready. Close it when the test completes.
|
||||||
|
tb.Cleanup(p.dev.Close)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
|
goroutineLeakCheck(t)
|
||||||
|
pair := genTestPair(t, true)
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
pair.Send(t, Ping, nil)
|
||||||
|
})
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
pair.Send(t, Pong, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpDown(t *testing.T) {
|
||||||
|
goroutineLeakCheck(t)
|
||||||
|
const itrials = 50
|
||||||
|
const otrials = 10
|
||||||
|
|
||||||
|
for n := 0; n < otrials; n++ {
|
||||||
|
pair := genTestPair(t, false)
|
||||||
|
for i := range pair {
|
||||||
|
for k := range pair[i].dev.peers.keyMap {
|
||||||
|
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(pair))
|
||||||
|
for i := range pair {
|
||||||
|
go func(d *Device) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < itrials; i++ {
|
||||||
|
if err := d.Up(); err != nil {
|
||||||
|
t.Errorf("failed up bring up device: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
||||||
|
if err := d.Down(); err != nil {
|
||||||
|
t.Errorf("failed to bring down device: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
||||||
|
}
|
||||||
|
}(pair[i].dev)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
for i := range pair {
|
||||||
|
pair[i].dev.Up()
|
||||||
|
pair[i].dev.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrencySafety does other things concurrently with tunnel use.
|
||||||
|
// It is intended to be used with the race detector to catch data races.
|
||||||
|
func TestConcurrencySafety(t *testing.T) {
|
||||||
|
pair := genTestPair(t, true)
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
const warmupIters = 10
|
||||||
|
var warmup sync.WaitGroup
|
||||||
|
warmup.Add(warmupIters)
|
||||||
|
go func() {
|
||||||
|
// Send data continuously back and forth until we're done.
|
||||||
|
// Note that we may continue to attempt to send data
|
||||||
|
// even after done is closed.
|
||||||
|
i := warmupIters
|
||||||
|
for ping := Ping; ; ping = !ping {
|
||||||
|
pair.Send(t, ping, done)
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
warmup.Done()
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
warmup.Wait()
|
||||||
|
|
||||||
|
applyCfg := func(cfg string) {
|
||||||
|
err := pair[0].dev.IpcSet(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change persistent_keepalive_interval concurrently with tunnel use.
|
||||||
|
t.Run("persistentKeepaliveInterval", func(t *testing.T) {
|
||||||
|
var pub NoisePublicKey
|
||||||
|
for key := range pair[0].dev.peers.keyMap {
|
||||||
|
pub = key
|
||||||
|
break
|
||||||
|
}
|
||||||
|
cfg := uapiCfg(
|
||||||
|
"public_key", hex.EncodeToString(pub[:]),
|
||||||
|
"persistent_keepalive_interval", "1",
|
||||||
|
)
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
applyCfg(cfg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Change private keys concurrently with tunnel use.
|
||||||
|
t.Run("privateKey", func(t *testing.T) {
|
||||||
|
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
|
||||||
|
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
|
||||||
|
// Set iters to a large number like 1000 to flush out data races quickly.
|
||||||
|
// Don't leave it large. That can cause logical races
|
||||||
|
// in which the handshake is interleaved with key changes
|
||||||
|
// such that the private key appears to be unchanging but
|
||||||
|
// other state gets reset, which can cause handshake failures like
|
||||||
|
// "Received packet with invalid mac1".
|
||||||
|
const iters = 1
|
||||||
|
for i := 0; i < iters; i++ {
|
||||||
|
applyCfg(bad)
|
||||||
|
applyCfg(good)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
||||||
|
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
||||||
|
const iters = 10
|
||||||
|
for i := 0; i < iters; i++ {
|
||||||
|
for _, peer := range pair {
|
||||||
|
peer.dev.BindUpdate()
|
||||||
|
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
close(done)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLatency(b *testing.B) {
|
||||||
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
|
// Establish a connection.
|
||||||
|
pair.Send(b, Ping, nil)
|
||||||
|
pair.Send(b, Pong, nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pair.Send(b, Ping, nil)
|
||||||
|
pair.Send(b, Pong, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkThroughput(b *testing.B) {
|
||||||
|
pair := genTestPair(b, true)
|
||||||
|
|
||||||
|
// Establish a connection.
|
||||||
|
pair.Send(b, Ping, nil)
|
||||||
|
pair.Send(b, Pong, nil)
|
||||||
|
|
||||||
|
// Measure how long it takes to receive b.N packets,
|
||||||
|
// starting when we receive the first packet.
|
||||||
|
var recv atomic.Uint64
|
||||||
|
var elapsed time.Duration
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
var start time.Time
|
||||||
|
for {
|
||||||
|
<-pair[0].tun.Inbound
|
||||||
|
new := recv.Add(1)
|
||||||
|
if new == 1 {
|
||||||
|
start = time.Now()
|
||||||
|
}
|
||||||
|
// Careful! Don't change this to else if; b.N can be equal to 1.
|
||||||
|
if new == uint64(b.N) {
|
||||||
|
elapsed = time.Since(start)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send packets as fast as we can until we've received enough.
|
||||||
|
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
||||||
|
pingc := pair[1].tun.Outbound
|
||||||
|
var sent uint64
|
||||||
|
for recv.Load() != uint64(b.N) {
|
||||||
|
sent++
|
||||||
|
pingc <- ping
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
|
||||||
|
b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUAPIGet(b *testing.B) {
|
||||||
|
pair := genTestPair(b, true)
|
||||||
|
pair.Send(b, Ping, nil)
|
||||||
|
pair.Send(b, Pong, nil)
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
pair[0].dev.IpcGetOperation(io.Discard)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func goroutineLeakCheck(t *testing.T) {
|
||||||
|
goroutines := func() (int, []byte) {
|
||||||
|
p := pprof.Lookup("goroutine")
|
||||||
|
b := new(bytes.Buffer)
|
||||||
|
p.WriteTo(b, 1)
|
||||||
|
return p.Count(), b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
startGoroutines, startStacks := goroutines()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
if t.Failed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Give goroutines time to exit, if they need it.
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
if runtime.NumGoroutine() <= startGoroutines {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
}
|
||||||
|
endGoroutines, endStacks := goroutines()
|
||||||
|
t.Logf("starting stacks:\n%s\n", startStacks)
|
||||||
|
t.Logf("ending stacks:\n%s\n", endStacks)
|
||||||
|
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeBindSized struct {
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
||||||
|
return nil, 0, nil
|
||||||
|
}
|
||||||
|
func (b *fakeBindSized) Close() error { return nil }
|
||||||
|
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
||||||
|
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
||||||
|
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
||||||
|
func (b *fakeBindSized) BatchSize() int { return b.size }
|
||||||
|
|
||||||
|
type fakeTUNDeviceSized struct {
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
||||||
|
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
||||||
|
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
||||||
|
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
||||||
|
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
||||||
|
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
||||||
|
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
||||||
|
|
||||||
|
func TestBatchSize(t *testing.T) {
|
||||||
|
d := Device{}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{1}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{1}
|
||||||
|
if want, got := 1, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{1}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{128}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{128}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{1}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.net.bind = &fakeBindSized{128}
|
||||||
|
d.tun.device = &fakeTUNDeviceSized{128}
|
||||||
|
if want, got := 128, d.BatchSize(); got != want {
|
||||||
|
t.Errorf("expected batch size %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
16
device/devicestate_string.go
Normal file
16
device/devicestate_string.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
const _deviceState_name = "DownUpClosed"
|
||||||
|
|
||||||
|
var _deviceState_index = [...]uint8{0, 4, 6, 12}
|
||||||
|
|
||||||
|
func (i deviceState) String() string {
|
||||||
|
if i >= deviceState(len(_deviceState_index)-1) {
|
||||||
|
return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
|
}
|
||||||
|
return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
|
||||||
|
}
|
||||||
@@ -1,53 +1,49 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DummyEndpoint struct {
|
type DummyEndpoint struct {
|
||||||
src [16]byte
|
src, dst netip.Addr
|
||||||
dst [16]byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
||||||
var end DummyEndpoint
|
var src, dst [16]byte
|
||||||
if _, err := rand.Read(end.src[:]); err != nil {
|
if _, err := rand.Read(src[:]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
_, err := rand.Read(end.dst[:])
|
_, err := rand.Read(dst[:])
|
||||||
return &end, err
|
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) ClearSrc() {}
|
func (e *DummyEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToString() string {
|
func (e *DummyEndpoint) SrcToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
|
||||||
addr.IP = e.SrcIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstToString() string {
|
func (e *DummyEndpoint) DstToString() string {
|
||||||
var addr net.UDPAddr
|
return netip.AddrPortFrom(e.DstIP(), 1000).String()
|
||||||
addr.IP = e.DstIP()
|
|
||||||
addr.Port = 1000
|
|
||||||
return addr.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToBytes() []byte {
|
func (e *DummyEndpoint) DstToBytes() []byte {
|
||||||
return e.src[:]
|
out := e.DstIP().AsSlice()
|
||||||
|
out = append(out, byte(1000&0xff))
|
||||||
|
out = append(out, byte((1000>>8)&0xff))
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstIP() net.IP {
|
func (e *DummyEndpoint) DstIP() netip.Addr {
|
||||||
return e.dst[:]
|
return e.dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcIP() net.IP {
|
func (e *DummyEndpoint) SrcIP() netip.Addr {
|
||||||
return e.src[:]
|
return e.src
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IndexTableEntry struct {
|
type IndexTableEntry struct {
|
||||||
@@ -25,7 +25,8 @@ type IndexTable struct {
|
|||||||
func randUint32() (uint32, error) {
|
func randUint32() (uint32, error) {
|
||||||
var integer [4]byte
|
var integer [4]byte
|
||||||
_, err := rand.Read(integer[:])
|
_, err := rand.Read(integer[:])
|
||||||
return *(*uint32)(unsafe.Pointer(&integer[0])), err
|
// Arbitrary endianness; both are intrinsified by the Go compiler.
|
||||||
|
return binary.LittleEndian.Uint32(integer[:]), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Init() {
|
func (table *IndexTable) Init() {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -20,7 +20,7 @@ type KDFTest struct {
|
|||||||
t2 string
|
t2 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertEquals(t *testing.T, a string, b string) {
|
func assertEquals(t *testing.T, a, b string) {
|
||||||
if a != b {
|
if a != b {
|
||||||
t.Fatal("expected", a, "=", b)
|
t.Fatal("expected", a, "=", b)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -8,9 +8,10 @@ package device
|
|||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"github.com/Lordy82/wireguard-go/replay"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Due to limitations in Go and /x/crypto there is currently
|
/* Due to limitations in Go and /x/crypto there is currently
|
||||||
@@ -21,10 +22,10 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type Keypair struct {
|
type Keypair struct {
|
||||||
sendNonce uint64
|
sendNonce atomic.Uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
replayFilter replay.ReplayFilter
|
replayFilter replay.Filter
|
||||||
isInitiator bool
|
isInitiator bool
|
||||||
created time.Time
|
created time.Time
|
||||||
localIndex uint32
|
localIndex uint32
|
||||||
@@ -35,7 +36,7 @@ type Keypairs struct {
|
|||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next *Keypair
|
next atomic.Pointer[Keypair]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
|
|||||||
@@ -1,59 +1,48 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A Logger provides logging for a Device.
|
||||||
|
// The functions are Printf-style functions.
|
||||||
|
// They must be safe for concurrent use.
|
||||||
|
// They do not require a trailing newline in the format.
|
||||||
|
// If nil, that level of logging will be silent.
|
||||||
|
type Logger struct {
|
||||||
|
Verbosef func(format string, args ...any)
|
||||||
|
Errorf func(format string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log levels for use with NewLogger.
|
||||||
const (
|
const (
|
||||||
LogLevelSilent = iota
|
LogLevelSilent = iota
|
||||||
LogLevelError
|
LogLevelError
|
||||||
LogLevelInfo
|
LogLevelVerbose
|
||||||
LogLevelDebug
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Logger struct {
|
// Function for use in Logger for discarding logged lines.
|
||||||
Debug *log.Logger
|
func DiscardLogf(format string, args ...any) {}
|
||||||
Info *log.Logger
|
|
||||||
Error *log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// NewLogger constructs a Logger that writes to stdout.
|
||||||
|
// It logs at the specified log level and above.
|
||||||
|
// It decorates log lines with the log level, date, time, and prepend.
|
||||||
func NewLogger(level int, prepend string) *Logger {
|
func NewLogger(level int, prepend string) *Logger {
|
||||||
output := os.Stdout
|
logger := &Logger{DiscardLogf, DiscardLogf}
|
||||||
logger := new(Logger)
|
logf := func(prefix string) func(string, ...any) {
|
||||||
|
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
|
||||||
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
|
}
|
||||||
if level >= LogLevelDebug {
|
if level >= LogLevelVerbose {
|
||||||
return output, output, output
|
logger.Verbosef = logf("DEBUG")
|
||||||
}
|
}
|
||||||
if level >= LogLevelInfo {
|
if level >= LogLevelError {
|
||||||
return output, output, ioutil.Discard
|
logger.Errorf = logf("ERROR")
|
||||||
}
|
}
|
||||||
if level >= LogLevelError {
|
|
||||||
return output, ioutil.Discard, ioutil.Discard
|
|
||||||
}
|
|
||||||
return ioutil.Discard, ioutil.Discard, ioutil.Discard
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Debug = log.New(logDebug,
|
|
||||||
"DEBUG: "+prepend,
|
|
||||||
log.Ldate|log.Ltime,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.Info = log.New(logInfo,
|
|
||||||
"INFO: "+prepend,
|
|
||||||
log.Ldate|log.Ltime,
|
|
||||||
)
|
|
||||||
logger.Error = log.New(logErr,
|
|
||||||
"ERROR: "+prepend,
|
|
||||||
log.Ldate|log.Ltime,
|
|
||||||
)
|
|
||||||
return logger
|
return logger
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
// +build !linux,!openbsd,!freebsd
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(mark uint32) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Atomic Boolean */
|
|
||||||
|
|
||||||
const (
|
|
||||||
AtomicFalse = int32(iota)
|
|
||||||
AtomicTrue
|
|
||||||
)
|
|
||||||
|
|
||||||
type AtomicBool struct {
|
|
||||||
int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Get() bool {
|
|
||||||
return atomic.LoadInt32(&a.int32) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Swap(val bool) bool {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AtomicBool) Set(val bool) {
|
|
||||||
flag := AtomicFalse
|
|
||||||
if val {
|
|
||||||
flag = AtomicTrue
|
|
||||||
}
|
|
||||||
atomic.StoreInt32(&a.int32, flag)
|
|
||||||
}
|
|
||||||
|
|
||||||
func min(a, b uint) uint {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
19
device/mobilequirks.go
Normal file
19
device/mobilequirks.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
|
||||||
|
// though it will try to deal with it, and race maybe, if called after.
|
||||||
|
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||||
|
device.net.brokenRoaming = true
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
"hash"
|
"hash"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
var errInvalidPublicKey = errors.New("invalid public key")
|
||||||
|
|
||||||
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
|
||||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
curve25519.ScalarMult(&ss, ask, apk)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
return ss
|
if isZero(ss[:]) {
|
||||||
|
return ss, errInvalidPublicKey
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +1,50 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
|
||||||
|
"github.com/Lordy82/wireguard-go/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type handshakeState int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
HandshakeZeroed = iota
|
handshakeZeroed = handshakeState(iota)
|
||||||
HandshakeInitiationCreated
|
handshakeInitiationCreated
|
||||||
HandshakeInitiationConsumed
|
handshakeInitiationConsumed
|
||||||
HandshakeResponseCreated
|
handshakeResponseCreated
|
||||||
HandshakeResponseConsumed
|
handshakeResponseConsumed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (hs handshakeState) String() string {
|
||||||
|
switch hs {
|
||||||
|
case handshakeZeroed:
|
||||||
|
return "handshakeZeroed"
|
||||||
|
case handshakeInitiationCreated:
|
||||||
|
return "handshakeInitiationCreated"
|
||||||
|
case handshakeInitiationConsumed:
|
||||||
|
return "handshakeInitiationConsumed"
|
||||||
|
case handshakeResponseCreated:
|
||||||
|
return "handshakeResponseCreated"
|
||||||
|
case handshakeResponseConsumed:
|
||||||
|
return "handshakeResponseConsumed"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||||
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||||
@@ -39,13 +60,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitiationSize = 148 // size of handshake initation message
|
MessageInitiationSize = 148 // size of handshake initiation message
|
||||||
MessageResponseSize = 92 // size of response message
|
MessageResponseSize = 92 // size of response message
|
||||||
MessageCookieReplySize = 64 // size of cookie reply message
|
MessageCookieReplySize = 64 // size of cookie reply message
|
||||||
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
|
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
||||||
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
||||||
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
||||||
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -95,11 +116,11 @@ type MessageCookieReply struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
state int
|
state handshakeState
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
hash [blake2s.Size]byte // hash value
|
hash [blake2s.Size]byte // hash value
|
||||||
chainKey [blake2s.Size]byte // chain key
|
chainKey [blake2s.Size]byte // chain key
|
||||||
presharedKey NoiseSymmetricKey // psk
|
presharedKey NoisePresharedKey // psk
|
||||||
localEphemeral NoisePrivateKey // ephemeral secret key
|
localEphemeral NoisePrivateKey // ephemeral secret key
|
||||||
localIndex uint32 // used to clear hash-table
|
localIndex uint32 // used to clear hash-table
|
||||||
remoteIndex uint32 // index for sending
|
remoteIndex uint32 // index for sending
|
||||||
@@ -117,11 +138,11 @@ var (
|
|||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
|
func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
|
||||||
KDF1(dst, c[:], data)
|
KDF1(dst, c[:], data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
|
func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
|
||||||
hash, _ := blake2s.New256(nil)
|
hash, _ := blake2s.New256(nil)
|
||||||
hash.Write(h[:])
|
hash.Write(h[:])
|
||||||
hash.Write(data)
|
hash.Write(data)
|
||||||
@@ -135,7 +156,7 @@ func (h *Handshake) Clear() {
|
|||||||
setZero(h.chainKey[:])
|
setZero(h.chainKey[:])
|
||||||
setZero(h.hash[:])
|
setZero(h.hash[:])
|
||||||
h.localIndex = 0
|
h.localIndex = 0
|
||||||
h.state = HandshakeZeroed
|
h.state = handshakeZeroed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) mixHash(data []byte) {
|
func (h *Handshake) mixHash(data []byte) {
|
||||||
@@ -154,7 +175,6 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
|
|
||||||
@@ -162,12 +182,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
return nil, errors.New("static shared secret is zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
handshake.hash = InitialHash
|
handshake.hash = InitialHash
|
||||||
handshake.chainKey = InitialChainKey
|
handshake.chainKey = InitialChainKey
|
||||||
@@ -176,59 +191,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign index
|
|
||||||
|
|
||||||
device.indexTable.Delete(handshake.localIndex)
|
|
||||||
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
handshake.mixHash(handshake.remoteStatic[:])
|
handshake.mixHash(handshake.remoteStatic[:])
|
||||||
|
|
||||||
msg := MessageInitiation{
|
msg := MessageInitiation{
|
||||||
Type: MessageInitiationType,
|
Type: MessageInitiationType,
|
||||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||||
Sender: handshake.localIndex,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt static key
|
// encrypt static key
|
||||||
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
func() {
|
if err != nil {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
return nil, err
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
}
|
||||||
KDF2(
|
var key [chacha20poly1305.KeySize]byte
|
||||||
&handshake.chainKey,
|
KDF2(
|
||||||
&key,
|
&handshake.chainKey,
|
||||||
handshake.chainKey[:],
|
&key,
|
||||||
ss[:],
|
handshake.chainKey[:],
|
||||||
)
|
ss[:],
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
)
|
||||||
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
}()
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
||||||
handshake.mixHash(msg.Static[:])
|
handshake.mixHash(msg.Static[:])
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil, errInvalidPublicKey
|
||||||
|
}
|
||||||
|
KDF2(
|
||||||
|
&handshake.chainKey,
|
||||||
|
&key,
|
||||||
|
handshake.chainKey[:],
|
||||||
|
handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
timestamp := tai64n.Now()
|
timestamp := tai64n.Now()
|
||||||
func() {
|
aead, _ = chacha20poly1305.New(key[:])
|
||||||
var key [chacha20poly1305.KeySize]byte
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||||
KDF2(
|
|
||||||
&handshake.chainKey,
|
// assign index
|
||||||
&key,
|
device.indexTable.Delete(handshake.localIndex)
|
||||||
handshake.chainKey[:],
|
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
||||||
handshake.precomputedStaticStatic[:],
|
if err != nil {
|
||||||
)
|
return nil, err
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
}
|
||||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
handshake.localIndex = msg.Sender
|
||||||
}()
|
|
||||||
|
|
||||||
handshake.mixHash(msg.Timestamp[:])
|
handshake.mixHash(msg.Timestamp[:])
|
||||||
handshake.state = HandshakeInitiationCreated
|
handshake.state = handshakeInitiationCreated
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,16 +262,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt static key
|
// decrypt static key
|
||||||
|
|
||||||
var err error
|
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
func() {
|
var key [chacha20poly1305.KeySize]byte
|
||||||
var key [chacha20poly1305.KeySize]byte
|
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
return nil
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
}
|
||||||
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||||
}()
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -268,28 +279,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
peer := device.LookupPeer(peerPK)
|
peer := device.LookupPeer(peerPK)
|
||||||
if peer == nil {
|
if peer == nil || !peer.isRunning.Load() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify identity
|
// verify identity
|
||||||
|
|
||||||
var timestamp tai64n.Timestamp
|
var timestamp tai64n.Timestamp
|
||||||
var key [chacha20poly1305.KeySize]byte
|
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
|
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
handshake.mutex.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
KDF2(
|
KDF2(
|
||||||
&chainKey,
|
&chainKey,
|
||||||
&key,
|
&key,
|
||||||
chainKey[:],
|
chainKey[:],
|
||||||
handshake.precomputedStaticStatic[:],
|
handshake.precomputedStaticStatic[:],
|
||||||
)
|
)
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ = chacha20poly1305.New(key[:])
|
||||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
@@ -299,11 +311,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
|
|
||||||
// protect against replay & flood
|
// protect against replay & flood
|
||||||
|
|
||||||
var ok bool
|
replay := !timestamp.After(handshake.lastTimestamp)
|
||||||
ok = timestamp.After(handshake.lastTimestamp)
|
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
|
||||||
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
if !ok {
|
if replay {
|
||||||
|
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if flood {
|
||||||
|
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,9 +331,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
handshake.lastTimestamp = timestamp
|
if timestamp.After(handshake.lastTimestamp) {
|
||||||
handshake.lastInitiationConsumption = time.Now()
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.state = HandshakeInitiationConsumed
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(handshake.lastInitiationConsumption) {
|
||||||
|
handshake.lastInitiationConsumption = now
|
||||||
|
}
|
||||||
|
handshake.state = handshakeInitiationConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
@@ -332,7 +353,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationConsumed {
|
if handshake.state != handshakeInitiationConsumed {
|
||||||
return nil, errors.New("handshake initiation must be consumed first")
|
return nil, errors.New("handshake initiation must be consumed first")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
if err != nil {
|
||||||
handshake.mixKey(ss[:])
|
return nil, err
|
||||||
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
}
|
||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
}()
|
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
handshake.mixKey(ss[:])
|
||||||
|
|
||||||
// add preshared key
|
// add preshared key
|
||||||
|
|
||||||
@@ -382,13 +407,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
|
|
||||||
handshake.mixHash(tau[:])
|
handshake.mixHash(tau[:])
|
||||||
|
|
||||||
func() {
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
handshake.mixHash(msg.Empty[:])
|
|
||||||
}()
|
|
||||||
|
|
||||||
handshake.state = HandshakeResponseCreated
|
handshake.state = handshakeResponseCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
@@ -412,13 +435,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
)
|
)
|
||||||
|
|
||||||
ok := func() bool {
|
ok := func() bool {
|
||||||
|
|
||||||
// lock handshake state
|
// lock handshake state
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
defer handshake.mutex.RUnlock()
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationCreated {
|
if handshake.state != handshakeInitiationCreated {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
||||||
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
func() {
|
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
return false
|
||||||
setZero(ss[:])
|
}
|
||||||
}()
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
|
setZero(ss[:])
|
||||||
|
|
||||||
func() {
|
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
if err != nil {
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
return false
|
||||||
setZero(ss[:])
|
}
|
||||||
}()
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
|
setZero(ss[:])
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
|
|
||||||
@@ -460,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
// authenticate transcript
|
// authenticate transcript
|
||||||
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -479,7 +503,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.state = HandshakeResponseConsumed
|
handshake.state = handshakeResponseConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
@@ -504,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
var sendKey [chacha20poly1305.KeySize]byte
|
var sendKey [chacha20poly1305.KeySize]byte
|
||||||
var recvKey [chacha20poly1305.KeySize]byte
|
var recvKey [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
if handshake.state == HandshakeResponseConsumed {
|
if handshake.state == handshakeResponseConsumed {
|
||||||
KDF2(
|
KDF2(
|
||||||
&sendKey,
|
&sendKey,
|
||||||
&recvKey,
|
&recvKey,
|
||||||
@@ -512,7 +536,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
isInitiator = true
|
isInitiator = true
|
||||||
} else if handshake.state == HandshakeResponseCreated {
|
} else if handshake.state == handshakeResponseCreated {
|
||||||
KDF2(
|
KDF2(
|
||||||
&recvKey,
|
&recvKey,
|
||||||
&sendKey,
|
&sendKey,
|
||||||
@@ -521,7 +545,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
)
|
)
|
||||||
isInitiator = false
|
isInitiator = false
|
||||||
} else {
|
} else {
|
||||||
return errors.New("invalid state for keypair derivation")
|
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero handshake
|
// zero handshake
|
||||||
@@ -529,7 +553,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
setZero(handshake.chainKey[:])
|
setZero(handshake.chainKey[:])
|
||||||
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||||
setZero(handshake.localEphemeral[:])
|
setZero(handshake.localEphemeral[:])
|
||||||
peer.handshake.state = HandshakeZeroed
|
peer.handshake.state = handshakeZeroed
|
||||||
|
|
||||||
// create AEAD instances
|
// create AEAD instances
|
||||||
|
|
||||||
@@ -541,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
setZero(recvKey[:])
|
setZero(recvKey[:])
|
||||||
|
|
||||||
keypair.created = time.Now()
|
keypair.created = time.Now()
|
||||||
keypair.sendNonce = 0
|
keypair.replayFilter.Reset()
|
||||||
keypair.replayFilter.Init()
|
|
||||||
keypair.isInitiator = isInitiator
|
keypair.isInitiator = isInitiator
|
||||||
keypair.localIndex = peer.handshake.localIndex
|
keypair.localIndex = peer.handshake.localIndex
|
||||||
keypair.remoteIndex = peer.handshake.remoteIndex
|
keypair.remoteIndex = peer.handshake.remoteIndex
|
||||||
@@ -559,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.next
|
next := keypairs.next.Load()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.next = nil
|
keypairs.next.Store(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
@@ -573,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.next = keypair
|
keypairs.next.Store(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
@@ -584,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
|
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
if keypairs.next != receivedKeypair {
|
|
||||||
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.next != receivedKeypair {
|
if keypairs.next.Load() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.next
|
keypairs.current = keypairs.next.Load()
|
||||||
keypairs.next = nil
|
keypairs.next.Store(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -9,19 +9,18 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoisePublicKeySize = 32
|
NoisePublicKeySize = 32
|
||||||
NoisePrivateKeySize = 32
|
NoisePrivateKeySize = 32
|
||||||
|
NoisePresharedKeySize = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
NoisePublicKey [NoisePublicKeySize]byte
|
NoisePublicKey [NoisePublicKeySize]byte
|
||||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||||
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
|
NoisePresharedKey [NoisePresharedKeySize]byte
|
||||||
NoiseNonce uint64 // padded to 12-bytes
|
NoiseNonce uint64 // padded to 12-bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,18 +51,19 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key NoisePrivateKey) ToHex() string {
|
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
|
||||||
return hex.EncodeToString(key[:])
|
err = loadExactHex(key[:], src)
|
||||||
|
if key.IsZero() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key.clamp()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePublicKey) FromHex(src string) error {
|
func (key *NoisePublicKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key NoisePublicKey) ToHex() string {
|
|
||||||
return hex.EncodeToString(key[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
func (key NoisePublicKey) IsZero() bool {
|
func (key NoisePublicKey) IsZero() bool {
|
||||||
var zero NoisePublicKey
|
var zero NoisePublicKey
|
||||||
return key.Equals(zero)
|
return key.Equals(zero)
|
||||||
@@ -73,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
|
|||||||
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoiseSymmetricKey) FromHex(src string) error {
|
func (key *NoisePresharedKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key NoiseSymmetricKey) ToHex() string {
|
|
||||||
return hex.EncodeToString(key[:])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCurveWrappers(t *testing.T) {
|
func TestCurveWrappers(t *testing.T) {
|
||||||
@@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) {
|
|||||||
pk1 := sk1.publicKey()
|
pk1 := sk1.publicKey()
|
||||||
pk2 := sk2.publicKey()
|
pk2 := sk2.publicKey()
|
||||||
|
|
||||||
ss1 := sk1.sharedSecret(pk2)
|
ss1, err1 := sk1.sharedSecret(pk2)
|
||||||
ss2 := sk2.sharedSecret(pk1)
|
ss2, err2 := sk2.sharedSecret(pk1)
|
||||||
|
|
||||||
if ss1 != ss2 {
|
if ss1 != ss2 || err1 != nil || err2 != nil {
|
||||||
t.Fatal("Failed to compute shared secet")
|
t.Fatal("Failed to compute shared secet")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun := tuntest.NewChannelTUN()
|
||||||
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNil(t *testing.T, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqual(t *testing.T, a, b []byte) {
|
||||||
|
if !bytes.Equal(a, b) {
|
||||||
|
t.Fatal(a, "!=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNoiseHandshake(t *testing.T) {
|
func TestNoiseHandshake(t *testing.T) {
|
||||||
dev1 := randDevice(t)
|
dev1 := randDevice(t)
|
||||||
dev2 := randDevice(t)
|
dev2 := randDevice(t)
|
||||||
@@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
defer dev1.Close()
|
defer dev1.Close()
|
||||||
defer dev2.Close()
|
defer dev2.Close()
|
||||||
|
|
||||||
peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
|
peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
|
||||||
peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
peer1.Start()
|
||||||
|
peer2.Start()
|
||||||
|
|
||||||
assertEqual(
|
assertEqual(
|
||||||
t,
|
t,
|
||||||
@@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.next
|
key1 := peer1.keypairs.next.Load()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
|||||||
269
device/peer.go
269
device/peer.go
@@ -1,37 +1,35 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"container/list"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
PeerRoutineNumber = 3
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
isRunning AtomicBool
|
isRunning atomic.Bool
|
||||||
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
keypairs Keypairs
|
||||||
keypairs Keypairs
|
handshake Handshake
|
||||||
handshake Handshake
|
device *Device
|
||||||
device *Device
|
stopping sync.WaitGroup // routines pending stop
|
||||||
endpoint Endpoint
|
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
||||||
persistentKeepaliveInterval uint16
|
rxBytes atomic.Uint64 // bytes received from peer
|
||||||
|
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
||||||
|
|
||||||
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
|
endpoint struct {
|
||||||
stats struct {
|
sync.Mutex
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
val conn.Endpoint
|
||||||
rxBytes uint64 // bytes received from peer
|
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
||||||
lastHandshakeNano int64 // nano seconds since epoch
|
disableRoaming bool
|
||||||
}
|
}
|
||||||
|
|
||||||
timers struct {
|
timers struct {
|
||||||
@@ -40,40 +38,32 @@ type Peer struct {
|
|||||||
newHandshake *Timer
|
newHandshake *Timer
|
||||||
zeroKeyMaterial *Timer
|
zeroKeyMaterial *Timer
|
||||||
persistentKeepalive *Timer
|
persistentKeepalive *Timer
|
||||||
handshakeAttempts uint32
|
handshakeAttempts atomic.Uint32
|
||||||
needAnotherKeepalive AtomicBool
|
needAnotherKeepalive atomic.Bool
|
||||||
sentLastMinuteHandshake AtomicBool
|
sentLastMinuteHandshake atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
signals struct {
|
state struct {
|
||||||
newKeypairArrived chan struct{}
|
sync.Mutex // protects against concurrent Start/Stop
|
||||||
flushNonceQueue chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
||||||
outbound chan *QueueOutboundElement // sequential ordering of work
|
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
||||||
inbound chan *QueueInboundElement // sequential ordering of work
|
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
||||||
packetInNonceQueueIsAwaitingKey AtomicBool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
routines struct {
|
cookieGenerator CookieGenerator
|
||||||
sync.Mutex // held when stopping / starting routines
|
trieEntries list.List
|
||||||
starting sync.WaitGroup // routines pending start
|
persistentKeepaliveInterval atomic.Uint32
|
||||||
stopping sync.WaitGroup // routines pending stop
|
|
||||||
stop chan struct{} // size 0, stop all go routines in peer
|
|
||||||
}
|
|
||||||
|
|
||||||
cookieGenerator CookieGenerator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
if device.isClosed.Get() {
|
if device.isClosed() {
|
||||||
return nil, errors.New("device closed")
|
return nil, errors.New("device closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock resources
|
// lock resources
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
|
|
||||||
@@ -81,136 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
defer device.peers.Unlock()
|
defer device.peers.Unlock()
|
||||||
|
|
||||||
// check if over limit
|
// check if over limit
|
||||||
|
|
||||||
if len(device.peers.keyMap) >= MaxPeers {
|
if len(device.peers.keyMap) >= MaxPeers {
|
||||||
return nil, errors.New("too many peers")
|
return nil, errors.New("too many peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
// create peer
|
// create peer
|
||||||
|
|
||||||
peer := new(Peer)
|
peer := new(Peer)
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
|
|
||||||
peer.cookieGenerator.Init(pk)
|
peer.cookieGenerator.Init(pk)
|
||||||
peer.device = device
|
peer.device = device
|
||||||
peer.isRunning.Set(false)
|
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
||||||
|
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
||||||
|
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
||||||
|
|
||||||
// map public key
|
// map public key
|
||||||
|
|
||||||
_, ok := device.peers.keyMap[pk]
|
_, ok := device.peers.keyMap[pk]
|
||||||
if ok {
|
if ok {
|
||||||
return nil, errors.New("adding existing peer")
|
return nil, errors.New("adding existing peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-compute DH
|
// pre-compute DH
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
|
|
||||||
handshake.remoteStatic = pk
|
handshake.remoteStatic = pk
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
// reset endpoint
|
// reset endpoint
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
peer.endpoint.val = nil
|
||||||
|
peer.endpoint.disableRoaming = false
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
peer.endpoint = nil
|
// init timers
|
||||||
|
peer.timersInit()
|
||||||
|
|
||||||
// conditionally add
|
// add
|
||||||
|
device.peers.keyMap[pk] = peer
|
||||||
if !ssIsZero {
|
|
||||||
device.peers.keyMap[pk] = peer
|
|
||||||
} else {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// start peer
|
|
||||||
|
|
||||||
if peer.device.isUp.Get() {
|
|
||||||
peer.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SendBuffer(buffer []byte) error {
|
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||||
peer.device.net.RLock()
|
peer.device.net.RLock()
|
||||||
defer peer.device.net.RUnlock()
|
defer peer.device.net.RUnlock()
|
||||||
|
|
||||||
if peer.device.net.bind == nil {
|
if peer.device.isClosed() {
|
||||||
return errors.New("no bind")
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.RLock()
|
peer.endpoint.Lock()
|
||||||
defer peer.RUnlock()
|
endpoint := peer.endpoint.val
|
||||||
|
if endpoint == nil {
|
||||||
if peer.endpoint == nil {
|
peer.endpoint.Unlock()
|
||||||
return errors.New("no known endpoint for peer")
|
return errors.New("no known endpoint for peer")
|
||||||
}
|
}
|
||||||
|
if peer.endpoint.clearSrcOnTx {
|
||||||
|
endpoint.ClearSrc()
|
||||||
|
peer.endpoint.clearSrcOnTx = false
|
||||||
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
err := peer.device.net.bind.Send(buffers, endpoint)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
|
var totalLen uint64
|
||||||
|
for _, b := range buffers {
|
||||||
|
totalLen += uint64(len(b))
|
||||||
|
}
|
||||||
|
peer.txBytes.Add(totalLen)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
// The awful goo that follows is identical to:
|
||||||
abbreviatedKey := "invalid"
|
//
|
||||||
if len(base64Key) == 44 {
|
// base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
||||||
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
|
// abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
|
||||||
|
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
||||||
|
//
|
||||||
|
// except that it is considerably more efficient.
|
||||||
|
src := peer.handshake.remoteStatic
|
||||||
|
b64 := func(input byte) byte {
|
||||||
|
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
b := []byte("peer(____…____)")
|
||||||
|
const first = len("peer(")
|
||||||
|
const second = len("peer(____…")
|
||||||
|
b[first+0] = b64((src[0] >> 2) & 63)
|
||||||
|
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
|
||||||
|
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
|
||||||
|
b[first+3] = b64(src[2] & 63)
|
||||||
|
b[second+0] = b64(src[29] & 63)
|
||||||
|
b[second+1] = b64((src[30] >> 2) & 63)
|
||||||
|
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
|
||||||
|
b[second+3] = b64((src[31] << 2) & 63)
|
||||||
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) Start() {
|
func (peer *Peer) Start() {
|
||||||
|
|
||||||
// should never start a peer on a closed device
|
// should never start a peer on a closed device
|
||||||
|
if peer.device.isClosed() {
|
||||||
if peer.device.isClosed.Get() {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// prevent simultaneous start/stop operations
|
// prevent simultaneous start/stop operations
|
||||||
|
peer.state.Lock()
|
||||||
|
defer peer.state.Unlock()
|
||||||
|
|
||||||
peer.routines.Lock()
|
if peer.isRunning.Load() {
|
||||||
defer peer.routines.Unlock()
|
|
||||||
|
|
||||||
if peer.isRunning.Get() {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
device.log.Debug.Println(peer, "- Starting...")
|
device.log.Verbosef("%v - Starting", peer)
|
||||||
|
|
||||||
// reset routine state
|
// reset routine state
|
||||||
|
peer.stopping.Wait()
|
||||||
|
peer.stopping.Add(2)
|
||||||
|
|
||||||
peer.routines.starting.Wait()
|
peer.handshake.mutex.Lock()
|
||||||
peer.routines.stopping.Wait()
|
|
||||||
peer.routines.stop = make(chan struct{})
|
|
||||||
peer.routines.starting.Add(PeerRoutineNumber)
|
|
||||||
peer.routines.stopping.Add(PeerRoutineNumber)
|
|
||||||
|
|
||||||
// prepare queues
|
|
||||||
|
|
||||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
|
||||||
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
|
||||||
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
|
|
||||||
|
|
||||||
peer.timersInit()
|
|
||||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||||
peer.signals.newKeypairArrived = make(chan struct{}, 1)
|
peer.handshake.mutex.Unlock()
|
||||||
peer.signals.flushNonceQueue = make(chan struct{}, 1)
|
|
||||||
|
|
||||||
// wait for routines to start
|
peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
|
||||||
|
|
||||||
go peer.RoutineNonce()
|
peer.timersStart()
|
||||||
go peer.RoutineSequentialSender()
|
|
||||||
go peer.RoutineSequentialReceiver()
|
|
||||||
|
|
||||||
peer.routines.starting.Wait()
|
device.flushInboundQueue(peer.queue.inbound)
|
||||||
peer.isRunning.Set(true)
|
device.flushOutboundQueue(peer.queue.outbound)
|
||||||
|
|
||||||
|
// Use the device batch size, not the bind batch size, as the device size is
|
||||||
|
// the size of the batch pools.
|
||||||
|
batchSize := peer.device.BatchSize()
|
||||||
|
go peer.RoutineSequentialSender(batchSize)
|
||||||
|
go peer.RoutineSequentialReceiver(batchSize)
|
||||||
|
|
||||||
|
peer.isRunning.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) ZeroAndFlushAll() {
|
func (peer *Peer) ZeroAndFlushAll() {
|
||||||
@@ -222,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.next)
|
device.DeleteKeypair(keypairs.next.Load())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.next = nil
|
keypairs.next.Store(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
@@ -236,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
handshake.Clear()
|
handshake.Clear()
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
peer.FlushNonceQueue()
|
peer.FlushStagedPackets()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) ExpireCurrentKeypairs() {
|
func (peer *Peer) ExpireCurrentKeypairs() {
|
||||||
@@ -244,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
peer.device.indexTable.Delete(handshake.localIndex)
|
peer.device.indexTable.Delete(handshake.localIndex)
|
||||||
handshake.Clear()
|
handshake.Clear()
|
||||||
handshake.mutex.Unlock()
|
|
||||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
if keypairs.current != nil {
|
if keypairs.current != nil {
|
||||||
keypairs.current.sendNonce = RejectAfterMessages
|
keypairs.current.sendNonce.Store(RejectAfterMessages)
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if next := keypairs.next.Load(); next != nil {
|
||||||
keypairs.next.sendNonce = RejectAfterMessages
|
next.sendNonce.Store(RejectAfterMessages)
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) Stop() {
|
func (peer *Peer) Stop() {
|
||||||
|
peer.state.Lock()
|
||||||
// prevent simultaneous start/stop operations
|
defer peer.state.Unlock()
|
||||||
|
|
||||||
if !peer.isRunning.Swap(false) {
|
if !peer.isRunning.Swap(false) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.routines.starting.Wait()
|
peer.device.log.Verbosef("%v - Stopping", peer)
|
||||||
|
|
||||||
peer.routines.Lock()
|
|
||||||
defer peer.routines.Unlock()
|
|
||||||
|
|
||||||
peer.device.log.Debug.Println(peer, "- Stopping...")
|
|
||||||
|
|
||||||
peer.timersStop()
|
peer.timersStop()
|
||||||
|
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
|
||||||
// stop & wait for ongoing peer routines
|
peer.queue.inbound.c <- nil
|
||||||
|
peer.queue.outbound.c <- nil
|
||||||
close(peer.routines.stop)
|
peer.stopping.Wait()
|
||||||
peer.routines.stopping.Wait()
|
peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
|
||||||
|
|
||||||
// close queues
|
|
||||||
|
|
||||||
close(peer.queue.nonce)
|
|
||||||
close(peer.queue.outbound)
|
|
||||||
close(peer.queue.inbound)
|
|
||||||
|
|
||||||
peer.ZeroAndFlushAll()
|
peer.ZeroAndFlushAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
var RoamingDisabled bool
|
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||||
|
peer.endpoint.Lock()
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
|
defer peer.endpoint.Unlock()
|
||||||
if RoamingDisabled {
|
if peer.endpoint.disableRoaming {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.endpoint.clearSrcOnTx = false
|
||||||
peer.endpoint = endpoint
|
peer.endpoint.val = endpoint
|
||||||
peer.Unlock()
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) markEndpointSrcForClearing() {
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
if peer.endpoint.val == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
peer.endpoint.clearSrcOnTx = true
|
||||||
}
|
}
|
||||||
|
|||||||
157
device/pools.go
157
device/pools.go
@@ -1,89 +1,120 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WaitPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
cond sync.Cond
|
||||||
|
lock sync.Mutex
|
||||||
|
count atomic.Uint32
|
||||||
|
max uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWaitPool(max uint32, new func() any) *WaitPool {
|
||||||
|
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
|
||||||
|
p.cond = sync.Cond{L: &p.lock}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WaitPool) Get() any {
|
||||||
|
if p.max != 0 {
|
||||||
|
p.lock.Lock()
|
||||||
|
for p.count.Load() >= p.max {
|
||||||
|
p.cond.Wait()
|
||||||
|
}
|
||||||
|
p.count.Add(1)
|
||||||
|
p.lock.Unlock()
|
||||||
|
}
|
||||||
|
return p.pool.Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WaitPool) Put(x any) {
|
||||||
|
p.pool.Put(x)
|
||||||
|
if p.max == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.count.Add(^uint32(0))
|
||||||
|
p.cond.Signal()
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) PopulatePools() {
|
func (device *Device) PopulatePools() {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
device.pool.messageBufferPool = &sync.Pool{
|
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
||||||
New: func() interface{} {
|
return &QueueInboundElementsContainer{elems: s}
|
||||||
return new([MaxMessageSize]byte)
|
})
|
||||||
},
|
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
}
|
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
||||||
device.pool.inboundElementPool = &sync.Pool{
|
return &QueueOutboundElementsContainer{elems: s}
|
||||||
New: func() interface{} {
|
})
|
||||||
return new(QueueInboundElement)
|
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
},
|
return new([MaxMessageSize]byte)
|
||||||
}
|
})
|
||||||
device.pool.outboundElementPool = &sync.Pool{
|
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
New: func() interface{} {
|
return new(QueueInboundElement)
|
||||||
return new(QueueOutboundElement)
|
})
|
||||||
},
|
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
||||||
}
|
return new(QueueOutboundElement)
|
||||||
} else {
|
})
|
||||||
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
|
}
|
||||||
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
|
|
||||||
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
|
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
||||||
}
|
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
||||||
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
|
c.Mutex = sync.Mutex{}
|
||||||
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
|
return c
|
||||||
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
|
}
|
||||||
}
|
|
||||||
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
|
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
||||||
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
|
for i := range c.elems {
|
||||||
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
|
c.elems[i] = nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
c.elems = c.elems[:0]
|
||||||
|
device.pool.inboundElementsContainer.Put(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
||||||
|
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
||||||
|
c.Mutex = sync.Mutex{}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
||||||
|
for i := range c.elems {
|
||||||
|
c.elems[i] = nil
|
||||||
|
}
|
||||||
|
c.elems = c.elems[:0]
|
||||||
|
device.pool.outboundElementsContainer.Put(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||||
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
|
|
||||||
} else {
|
|
||||||
return <-device.pool.messageBufferReuseChan
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
device.pool.messageBuffers.Put(msg)
|
||||||
device.pool.messageBufferPool.Put(msg)
|
|
||||||
} else {
|
|
||||||
device.pool.messageBufferReuseChan <- msg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetInboundElement() *QueueInboundElement {
|
func (device *Device) GetInboundElement() *QueueInboundElement {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
return device.pool.inboundElements.Get().(*QueueInboundElement)
|
||||||
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
|
|
||||||
} else {
|
|
||||||
return <-device.pool.inboundElementReuseChan
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PutInboundElement(msg *QueueInboundElement) {
|
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
elem.clearPointers()
|
||||||
device.pool.inboundElementPool.Put(msg)
|
device.pool.inboundElements.Put(elem)
|
||||||
} else {
|
|
||||||
device.pool.inboundElementReuseChan <- msg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) GetOutboundElement() *QueueOutboundElement {
|
func (device *Device) GetOutboundElement() *QueueOutboundElement {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
return device.pool.outboundElements.Get().(*QueueOutboundElement)
|
||||||
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
|
|
||||||
} else {
|
|
||||||
return <-device.pool.outboundElementReuseChan
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
|
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
|
||||||
if PreallocatedBuffersPerPool == 0 {
|
elem.clearPointers()
|
||||||
device.pool.outboundElementPool.Put(msg)
|
device.pool.outboundElements.Put(elem)
|
||||||
} else {
|
|
||||||
device.pool.outboundElementReuseChan <- msg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
139
device/pools_test.go
Normal file
139
device/pools_test.go
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWaitPool(t *testing.T) {
|
||||||
|
t.Skip("Currently disabled")
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
startTrials := int32(100000)
|
||||||
|
if raceEnabled {
|
||||||
|
// This test can be very slow with -race.
|
||||||
|
startTrials /= 10
|
||||||
|
}
|
||||||
|
trials.Store(startTrials)
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
t.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||||
|
wg.Add(workers)
|
||||||
|
var max atomic.Uint32
|
||||||
|
updateMax := func() {
|
||||||
|
count := p.count.Load()
|
||||||
|
if count > p.max {
|
||||||
|
t.Errorf("count (%d) > max (%d)", count, p.max)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
old := max.Load()
|
||||||
|
if count <= old {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if max.CompareAndSwap(old, count) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
|
updateMax()
|
||||||
|
x := p.Get()
|
||||||
|
updateMax()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
updateMax()
|
||||||
|
p.Put(x)
|
||||||
|
updateMax()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
if max.Load() != p.max {
|
||||||
|
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWaitPool(b *testing.B) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
b.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
||||||
|
wg.Add(workers)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
|
x := p.Get()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
p.Put(x)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
b.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
||||||
|
wg.Add(workers)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
|
x := p.Get()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
p.Put(x)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSyncPool(b *testing.B) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var trials atomic.Int32
|
||||||
|
trials.Store(int32(b.N))
|
||||||
|
workers := runtime.NumCPU() + 2
|
||||||
|
if workers-4 <= 0 {
|
||||||
|
b.Skip("Not enough cores")
|
||||||
|
}
|
||||||
|
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
||||||
|
wg.Add(workers)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for trials.Add(-1) > 0 {
|
||||||
|
x := p.Get()
|
||||||
|
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||||
|
p.Put(x)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
@@ -1,16 +1,19 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "github.com/Lordy82/wireguard-go/conn"
|
||||||
|
|
||||||
/* Reduce memory consumption for Android */
|
/* Reduce memory consumption for Android */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
QueueStagedSize = conn.IdealBatchSize
|
||||||
QueueOutboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
MaxSegmentSize = 2200
|
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||||
PreallocatedBuffersPerPool = 4096
|
PreallocatedBuffersPerPool = 4096
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
// +build !android,!ios
|
//go:build !android && !ios && !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "github.com/Lordy82/wireguard-go/conn"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
QueueStagedSize = conn.IdealBatchSize
|
||||||
QueueOutboundSize = 1024
|
QueueOutboundSize = 1024
|
||||||
QueueInboundSize = 1024
|
QueueInboundSize = 1024
|
||||||
QueueHandshakeSize = 1024
|
QueueHandshakeSize = 1024
|
||||||
|
|||||||
@@ -1,18 +1,21 @@
|
|||||||
// +build ios
|
//go:build ios
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */
|
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
|
||||||
|
// These are vars instead of consts, because heavier network extensions might want to reduce
|
||||||
const (
|
// them further.
|
||||||
QueueOutboundSize = 1024
|
var (
|
||||||
QueueInboundSize = 1024
|
QueueStagedSize = 128
|
||||||
QueueHandshakeSize = 1024
|
QueueOutboundSize = 1024
|
||||||
MaxSegmentSize = 1700
|
QueueInboundSize = 1024
|
||||||
PreallocatedBuffersPerPool = 1024
|
QueueHandshakeSize = 1024
|
||||||
|
PreallocatedBuffersPerPool uint32 = 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MaxSegmentSize = 1700
|
||||||
|
|||||||
15
device/queueconstants_windows.go
Normal file
15
device/queueconstants_windows.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
const (
|
||||||
|
QueueStagedSize = 128
|
||||||
|
QueueOutboundSize = 1024
|
||||||
|
QueueInboundSize = 1024
|
||||||
|
QueueHandshakeSize = 1024
|
||||||
|
MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
|
||||||
|
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
|
||||||
|
)
|
||||||
10
device/race_disabled_test.go
Normal file
10
device/race_disabled_test.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !race
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
const raceEnabled = false
|
||||||
10
device/race_enabled_test.go
Normal file
10
device/race_enabled_test.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build race
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
const raceEnabled = true
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -8,12 +8,12 @@ package device
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -22,52 +22,32 @@ import (
|
|||||||
type QueueHandshakeElement struct {
|
type QueueHandshakeElement struct {
|
||||||
msgType uint32
|
msgType uint32
|
||||||
packet []byte
|
packet []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueueInboundElement struct {
|
type QueueInboundElement struct {
|
||||||
dropped int32
|
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
keypair *Keypair
|
keypair *Keypair
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueInboundElement) Drop() {
|
type QueueInboundElementsContainer struct {
|
||||||
atomic.StoreInt32(&elem.dropped, AtomicTrue)
|
sync.Mutex
|
||||||
|
elems []*QueueInboundElement
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueInboundElement) IsDropped() bool {
|
// clearPointers clears elem fields that contain pointers.
|
||||||
return atomic.LoadInt32(&elem.dropped) == AtomicTrue
|
// This makes the garbage collector's life easier and
|
||||||
}
|
// avoids accidentally keeping other objects around unnecessarily.
|
||||||
|
// It also reduces the possible collateral damage from use-after-free bugs.
|
||||||
func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
|
func (elem *QueueInboundElement) clearPointers() {
|
||||||
select {
|
elem.buffer = nil
|
||||||
case inboundQueue <- element:
|
elem.packet = nil
|
||||||
select {
|
elem.keypair = nil
|
||||||
case decryptionQueue <- element:
|
elem.endpoint = nil
|
||||||
return true
|
|
||||||
default:
|
|
||||||
element.Drop()
|
|
||||||
element.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
device.PutInboundElement(element)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
|
|
||||||
select {
|
|
||||||
case queue <- element:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Called when a new authenticated message has been received
|
/* Called when a new authenticated message has been received
|
||||||
@@ -75,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem
|
|||||||
* NOTE: Not thread safe, but called by sequential receiver!
|
* NOTE: Not thread safe, but called by sequential receiver!
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) keepKeyFreshReceiving() {
|
func (peer *Peer) keepKeyFreshReceiving() {
|
||||||
if peer.timers.sentLastMinuteHandshake.Get() {
|
if peer.timers.sentLastMinuteHandshake.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||||
peer.timers.sentLastMinuteHandshake.Set(true)
|
peer.timers.sentLastMinuteHandshake.Store(true)
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -90,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
|||||||
* Every time the bind is updated a new routine is started for
|
* Every time the bind is updated a new routine is started for
|
||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
|
||||||
|
recvName := recv.PrettyName()
|
||||||
logDebug := device.log.Debug
|
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
|
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
||||||
|
device.queue.decryption.wg.Done()
|
||||||
|
device.queue.handshake.wg.Done()
|
||||||
device.net.stopping.Done()
|
device.net.stopping.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
|
device.log.Verbosef("Routine: receive incoming %s - started", recvName)
|
||||||
device.net.starting.Done()
|
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
// receive datagrams until conn is closed
|
||||||
|
|
||||||
buffer := device.GetMessageBuffer()
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
err error
|
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
||||||
size int
|
bufs = make([][]byte, maxBatchSize)
|
||||||
endpoint Endpoint
|
err error
|
||||||
|
sizes = make([]int, maxBatchSize)
|
||||||
|
count int
|
||||||
|
endpoints = make([]conn.Endpoint, maxBatchSize)
|
||||||
|
deathSpiral int
|
||||||
|
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for i := range bufsArrs {
|
||||||
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
|
bufs[i] = bufsArrs[i][:]
|
||||||
|
}
|
||||||
|
|
||||||
// read next datagram
|
defer func() {
|
||||||
|
for i := 0; i < maxBatchSize; i++ {
|
||||||
switch IP {
|
if bufsArrs[i] != nil {
|
||||||
case ipv4.Version:
|
device.PutMessageBuffer(bufsArrs[i])
|
||||||
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
|
}
|
||||||
case ipv6.Version:
|
|
||||||
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
|
|
||||||
default:
|
|
||||||
panic("invalid IP version")
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
count, err = recv(bufs, sizes, endpoints)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.PutMessageBuffer(buffer)
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
|
||||||
|
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if deathSpiral < 10 {
|
||||||
|
deathSpiral++
|
||||||
|
time.Sleep(time.Second / 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
deathSpiral = 0
|
||||||
|
|
||||||
if size < MinMessageSize {
|
// handle each packet in the batch
|
||||||
continue
|
for i, size := range sizes[:count] {
|
||||||
}
|
if size < MinMessageSize {
|
||||||
|
|
||||||
// check size of packet
|
|
||||||
|
|
||||||
packet := buffer[:size]
|
|
||||||
msgType := binary.LittleEndian.Uint32(packet[:4])
|
|
||||||
|
|
||||||
var okay bool
|
|
||||||
|
|
||||||
switch msgType {
|
|
||||||
|
|
||||||
// check if transport
|
|
||||||
|
|
||||||
case MessageTransportType:
|
|
||||||
|
|
||||||
// check size
|
|
||||||
|
|
||||||
if len(packet) < MessageTransportSize {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookup key pair
|
// check size of packet
|
||||||
|
|
||||||
receiver := binary.LittleEndian.Uint32(
|
packet := bufsArrs[i][:size]
|
||||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
msgType := binary.LittleEndian.Uint32(packet[:4])
|
||||||
)
|
|
||||||
value := device.indexTable.Lookup(receiver)
|
|
||||||
keypair := value.keypair
|
|
||||||
if keypair == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check keypair expiry
|
switch msgType {
|
||||||
|
|
||||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
// check if transport
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// create work element
|
case MessageTransportType:
|
||||||
peer := value.peer
|
|
||||||
elem := device.GetInboundElement()
|
|
||||||
elem.packet = packet
|
|
||||||
elem.buffer = buffer
|
|
||||||
elem.keypair = keypair
|
|
||||||
elem.dropped = AtomicFalse
|
|
||||||
elem.endpoint = endpoint
|
|
||||||
elem.counter = 0
|
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.Lock()
|
|
||||||
|
|
||||||
// add to decryption queues
|
// check size
|
||||||
|
|
||||||
if peer.isRunning.Get() {
|
if len(packet) < MessageTransportSize {
|
||||||
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
|
continue
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// lookup key pair
|
||||||
|
|
||||||
|
receiver := binary.LittleEndian.Uint32(
|
||||||
|
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
||||||
|
)
|
||||||
|
value := device.indexTable.Lookup(receiver)
|
||||||
|
keypair := value.keypair
|
||||||
|
if keypair == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// check keypair expiry
|
||||||
|
|
||||||
|
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// create work element
|
||||||
|
peer := value.peer
|
||||||
|
elem := device.GetInboundElement()
|
||||||
|
elem.packet = packet
|
||||||
|
elem.buffer = bufsArrs[i]
|
||||||
|
elem.keypair = keypair
|
||||||
|
elem.endpoint = endpoints[i]
|
||||||
|
elem.counter = 0
|
||||||
|
|
||||||
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
|
if !ok {
|
||||||
|
elemsForPeer = device.GetInboundElementsContainer()
|
||||||
|
elemsForPeer.Lock()
|
||||||
|
elemsByPeer[peer] = elemsForPeer
|
||||||
|
}
|
||||||
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
|
bufs[i] = bufsArrs[i][:]
|
||||||
|
continue
|
||||||
|
|
||||||
|
// otherwise it is a fixed size & handshake related packet
|
||||||
|
|
||||||
|
case MessageInitiationType:
|
||||||
|
if len(packet) != MessageInitiationSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageResponseType:
|
||||||
|
if len(packet) != MessageResponseSize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageCookieReplyType:
|
||||||
|
if len(packet) != MessageCookieReplySize {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.log.Verbosef("Received message with unknown type")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
select {
|
||||||
|
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||||
// otherwise it is a fixed size & handshake related packet
|
msgType: msgType,
|
||||||
|
buffer: bufsArrs[i],
|
||||||
case MessageInitiationType:
|
packet: packet,
|
||||||
okay = len(packet) == MessageInitiationSize
|
endpoint: endpoints[i],
|
||||||
|
}:
|
||||||
case MessageResponseType:
|
bufsArrs[i] = device.GetMessageBuffer()
|
||||||
okay = len(packet) == MessageResponseSize
|
bufs[i] = bufsArrs[i][:]
|
||||||
|
default:
|
||||||
case MessageCookieReplyType:
|
}
|
||||||
okay = len(packet) == MessageCookieReplySize
|
|
||||||
|
|
||||||
default:
|
|
||||||
logDebug.Println("Received message with unknown type")
|
|
||||||
}
|
}
|
||||||
|
for peer, elemsContainer := range elemsByPeer {
|
||||||
if okay {
|
if peer.isRunning.Load() {
|
||||||
if (device.addToHandshakeQueue(
|
peer.queue.inbound.c <- elemsContainer
|
||||||
device.queue.handshake,
|
device.queue.decryption.c <- elemsContainer
|
||||||
QueueHandshakeElement{
|
} else {
|
||||||
msgType: msgType,
|
for _, elem := range elemsContainer.elems {
|
||||||
buffer: buffer,
|
device.PutMessageBuffer(elem.buffer)
|
||||||
packet: packet,
|
device.PutInboundElement(elem)
|
||||||
endpoint: endpoint,
|
}
|
||||||
},
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
)) {
|
|
||||||
buffer = device.GetMessageBuffer()
|
|
||||||
}
|
}
|
||||||
|
delete(elemsByPeer, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) RoutineDecryption() {
|
func (device *Device) RoutineDecryption(id int) {
|
||||||
|
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
||||||
defer func() {
|
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
||||||
logDebug.Println("Routine: decryption worker - stopped")
|
|
||||||
device.state.stopping.Done()
|
|
||||||
}()
|
|
||||||
logDebug.Println("Routine: decryption worker - started")
|
|
||||||
device.state.starting.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-device.signals.stop:
|
|
||||||
return
|
|
||||||
|
|
||||||
case elem, ok := <-device.queue.decryption:
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if dropped
|
|
||||||
|
|
||||||
if elem.IsDropped() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for elemsContainer := range device.queue.decryption.c {
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
// split message into fields
|
// split message into fields
|
||||||
|
|
||||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||||
content := elem.packet[MessageTransportOffsetContent:]
|
content := elem.packet[MessageTransportOffsetContent:]
|
||||||
|
|
||||||
// expand nonce
|
|
||||||
|
|
||||||
nonce[0x4] = counter[0x0]
|
|
||||||
nonce[0x5] = counter[0x1]
|
|
||||||
nonce[0x6] = counter[0x2]
|
|
||||||
nonce[0x7] = counter[0x3]
|
|
||||||
|
|
||||||
nonce[0x8] = counter[0x4]
|
|
||||||
nonce[0x9] = counter[0x5]
|
|
||||||
nonce[0xa] = counter[0x6]
|
|
||||||
nonce[0xb] = counter[0x7]
|
|
||||||
|
|
||||||
// decrypt and release to consumer
|
// decrypt and release to consumer
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||||
|
// copy counter to nonce
|
||||||
|
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
||||||
elem.packet, err = elem.keypair.receive.Open(
|
elem.packet, err = elem.keypair.receive.Open(
|
||||||
content[:0],
|
content[:0],
|
||||||
nonce[:],
|
nonce[:],
|
||||||
@@ -279,51 +260,23 @@ func (device *Device) RoutineDecryption() {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
elem.Drop()
|
elem.packet = nil
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
}
|
||||||
elem.Unlock()
|
|
||||||
}
|
}
|
||||||
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Handles incoming packets related to handshake
|
/* Handles incoming packets related to handshake
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineHandshake() {
|
func (device *Device) RoutineHandshake(id int) {
|
||||||
|
|
||||||
logInfo := device.log.Info
|
|
||||||
logError := device.log.Error
|
|
||||||
logDebug := device.log.Debug
|
|
||||||
|
|
||||||
var elem QueueHandshakeElement
|
|
||||||
var ok bool
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println("Routine: handshake worker - stopped")
|
device.log.Verbosef("Routine: handshake worker %d - stopped", id)
|
||||||
device.state.stopping.Done()
|
device.queue.encryption.wg.Done()
|
||||||
if elem.buffer != nil {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
device.log.Verbosef("Routine: handshake worker %d - started", id)
|
||||||
|
|
||||||
logDebug.Println("Routine: handshake worker - started")
|
for elem := range device.queue.handshake.c {
|
||||||
device.state.starting.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
if elem.buffer != nil {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
elem.buffer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case elem, ok = <-device.queue.handshake:
|
|
||||||
case <-device.signals.stop:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle cookie fields and ratelimiting
|
// handle cookie fields and ratelimiting
|
||||||
|
|
||||||
@@ -337,8 +290,8 @@ func (device *Device) RoutineHandshake() {
|
|||||||
reader := bytes.NewReader(elem.packet)
|
reader := bytes.NewReader(elem.packet)
|
||||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logDebug.Println("Failed to decode cookie reply")
|
device.log.Verbosef("Failed to decode cookie reply")
|
||||||
return
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookup peer from index
|
// lookup peer from index
|
||||||
@@ -346,27 +299,27 @@ func (device *Device) RoutineHandshake() {
|
|||||||
entry := device.indexTable.Lookup(reply.Receiver)
|
entry := device.indexTable.Lookup(reply.Receiver)
|
||||||
|
|
||||||
if entry.peer == nil {
|
if entry.peer == nil {
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume reply
|
// consume reply
|
||||||
|
|
||||||
if peer := entry.peer; peer.isRunning.Get() {
|
if peer := entry.peer; peer.isRunning.Load() {
|
||||||
logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
|
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
|
||||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
||||||
logDebug.Println("Could not decrypt invalid cookie response")
|
device.log.Verbosef("Could not decrypt invalid cookie response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
goto skip
|
||||||
|
|
||||||
case MessageInitiationType, MessageResponseType:
|
case MessageInitiationType, MessageResponseType:
|
||||||
|
|
||||||
// check mac fields and maybe ratelimit
|
// check mac fields and maybe ratelimit
|
||||||
|
|
||||||
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
||||||
logDebug.Println("Received packet with invalid mac1")
|
device.log.Verbosef("Received packet with invalid mac1")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// endpoints destination address is the source of the datagram
|
// endpoints destination address is the source of the datagram
|
||||||
@@ -377,19 +330,19 @@ func (device *Device) RoutineHandshake() {
|
|||||||
|
|
||||||
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
||||||
device.SendHandshakeCookie(&elem)
|
device.SendHandshakeCookie(&elem)
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// check ratelimiter
|
// check ratelimiter
|
||||||
|
|
||||||
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
logError.Println("Invalid packet ended up in the handshake queue")
|
device.log.Errorf("Invalid packet ended up in the handshake queue")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle handshake initiation/response content
|
// handle handshake initiation/response content
|
||||||
@@ -403,19 +356,16 @@ func (device *Device) RoutineHandshake() {
|
|||||||
reader := bytes.NewReader(elem.packet)
|
reader := bytes.NewReader(elem.packet)
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to decode initiation message")
|
device.log.Errorf("Failed to decode initiation message")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume initiation
|
// consume initiation
|
||||||
|
|
||||||
peer := device.ConsumeMessageInitiation(&msg)
|
peer := device.ConsumeMessageInitiation(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
||||||
"Received invalid initiation message from",
|
goto skip
|
||||||
elem.endpoint.DstToString(),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
@@ -426,8 +376,8 @@ func (device *Device) RoutineHandshake() {
|
|||||||
// update endpoint
|
// update endpoint
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Received handshake initiation")
|
device.log.Verbosef("%v - Received handshake initiation", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
peer.SendHandshakeResponse()
|
peer.SendHandshakeResponse()
|
||||||
|
|
||||||
@@ -439,26 +389,23 @@ func (device *Device) RoutineHandshake() {
|
|||||||
reader := bytes.NewReader(elem.packet)
|
reader := bytes.NewReader(elem.packet)
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to decode response message")
|
device.log.Errorf("Failed to decode response message")
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
// consume response
|
// consume response
|
||||||
|
|
||||||
peer := device.ConsumeMessageResponse(&msg)
|
peer := device.ConsumeMessageResponse(&msg)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
logInfo.Println(
|
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
|
||||||
"Received invalid response message from",
|
goto skip
|
||||||
elem.endpoint.DstToString(),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update endpoint
|
// update endpoint
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Received handshake response")
|
device.log.Verbosef("%v - Received handshake response", peer)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
peer.rxBytes.Add(uint64(len(elem.packet)))
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
|
||||||
@@ -470,178 +417,124 @@ func (device *Device) RoutineHandshake() {
|
|||||||
err = peer.BeginSymmetricSession()
|
err = peer.BeginSymmetricSession()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println(peer, "- Failed to derive keypair:", err)
|
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
||||||
continue
|
goto skip
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.timersSessionDerived()
|
peer.timersSessionDerived()
|
||||||
peer.timersHandshakeComplete()
|
peer.timersHandshakeComplete()
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
select {
|
|
||||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
skip:
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) RoutineSequentialReceiver() {
|
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
logInfo := device.log.Info
|
|
||||||
logError := device.log.Error
|
|
||||||
logDebug := device.log.Debug
|
|
||||||
|
|
||||||
var elem *QueueInboundElement
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
|
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
||||||
peer.routines.stopping.Done()
|
peer.stopping.Done()
|
||||||
if elem != nil {
|
|
||||||
if !elem.IsDropped() {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
device.PutInboundElement(elem)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Routine: sequential receiver - started")
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
peer.routines.starting.Done()
|
for elemsContainer := range peer.queue.inbound.c {
|
||||||
|
if elemsContainer == nil {
|
||||||
for {
|
|
||||||
if elem != nil {
|
|
||||||
if !elem.IsDropped() {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
device.PutInboundElement(elem)
|
|
||||||
elem = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var elemOk bool
|
|
||||||
select {
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
return
|
return
|
||||||
case elem, elemOk = <-peer.queue.inbound:
|
}
|
||||||
if !elemOk {
|
elemsContainer.Lock()
|
||||||
return
|
validTailPacket := -1
|
||||||
|
dataPacketReceived := false
|
||||||
|
rxBytesLen := uint64(0)
|
||||||
|
for i, elem := range elemsContainer.elems {
|
||||||
|
if elem.packet == nil {
|
||||||
|
// decryption failed
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// wait for decryption
|
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
elem.Lock()
|
validTailPacket = i
|
||||||
|
if peer.ReceivedWithKeypair(elem.keypair) {
|
||||||
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
peer.timersHandshakeComplete()
|
||||||
|
peer.SendStagedPackets()
|
||||||
|
}
|
||||||
|
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
||||||
|
|
||||||
if elem.IsDropped() {
|
if len(elem.packet) == 0 {
|
||||||
continue
|
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
dataPacketReceived = true
|
||||||
|
|
||||||
// check for replay
|
switch elem.packet[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||||
|
length := binary.BigEndian.Uint16(field)
|
||||||
|
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elem.packet = elem.packet[:length]
|
||||||
|
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||||
|
if device.allowedips.Lookup(src) != peer {
|
||||||
|
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
case 6:
|
||||||
continue
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
}
|
continue
|
||||||
|
}
|
||||||
|
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||||
|
length := binary.BigEndian.Uint16(field)
|
||||||
|
length += ipv6.HeaderLen
|
||||||
|
if int(length) > len(elem.packet) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elem.packet = elem.packet[:length]
|
||||||
|
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||||
|
if device.allowedips.Lookup(src) != peer {
|
||||||
|
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// update endpoint
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
|
||||||
|
|
||||||
// check if using new keypair
|
|
||||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
|
||||||
peer.timersHandshakeComplete()
|
|
||||||
select {
|
|
||||||
case peer.signals.newKeypairArrived <- struct{}{}:
|
|
||||||
default:
|
default:
|
||||||
|
device.log.Verbosef("Packet with invalid IP version from %v", peer)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.keepKeyFreshReceiving()
|
peer.rxBytes.Add(rxBytesLen)
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
if validTailPacket >= 0 {
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
|
peer.keepKeyFreshReceiving()
|
||||||
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
// check for keepalive
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
|
||||||
logDebug.Println(peer, "- Receiving keepalive packet")
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
peer.timersDataReceived()
|
if dataPacketReceived {
|
||||||
|
peer.timersDataReceived()
|
||||||
// verify source and strip padding
|
|
||||||
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case ipv4.Version:
|
|
||||||
|
|
||||||
// strip padding
|
|
||||||
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
|
|
||||||
// verify IPv4 source
|
|
||||||
|
|
||||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
|
||||||
if device.allowedips.LookupIPv4(src) != peer {
|
|
||||||
logInfo.Println(
|
|
||||||
"IPv4 packet with disallowed source address from",
|
|
||||||
peer,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
case ipv6.Version:
|
|
||||||
|
|
||||||
// strip padding
|
|
||||||
|
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
length += ipv6.HeaderLen
|
|
||||||
if int(length) > len(elem.packet) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
|
|
||||||
// verify IPv6 source
|
|
||||||
|
|
||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
|
||||||
if device.allowedips.LookupIPv6(src) != peer {
|
|
||||||
logInfo.Println(
|
|
||||||
"IPv6 packet with disallowed source address from",
|
|
||||||
peer,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
logInfo.Println("Packet with invalid IP version from", peer)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
if len(bufs) > 0 {
|
||||||
// write to tun device
|
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
||||||
|
if err != nil && !device.isClosed() {
|
||||||
offset := MessageTransportOffsetContent
|
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
||||||
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
|
|
||||||
if len(peer.queue.inbound) == 0 {
|
|
||||||
err = device.tun.device.Flush()
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil && !device.isClosed.Get() {
|
for _, elem := range elemsContainer.elems {
|
||||||
logError.Println("Failed to write packet to TUN device:", err)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutInboundElement(elem)
|
||||||
}
|
}
|
||||||
|
bufs = bufs[:0]
|
||||||
|
device.PutInboundElementsContainer(elemsContainer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
719
device/send.go
719
device/send.go
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -8,11 +8,14 @@ package device
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -43,8 +46,6 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type QueueOutboundElement struct {
|
type QueueOutboundElement struct {
|
||||||
dropped int32
|
|
||||||
sync.Mutex
|
|
||||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
buffer *[MaxMessageSize]byte // slice holding the packet data
|
||||||
packet []byte // slice of "buffer" (always!)
|
packet []byte // slice of "buffer" (always!)
|
||||||
nonce uint64 // nonce for encryption
|
nonce uint64 // nonce for encryption
|
||||||
@@ -52,80 +53,52 @@ type QueueOutboundElement struct {
|
|||||||
peer *Peer // related peer
|
peer *Peer // related peer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueueOutboundElementsContainer struct {
|
||||||
|
sync.Mutex
|
||||||
|
elems []*QueueOutboundElement
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
||||||
elem := device.GetOutboundElement()
|
elem := device.GetOutboundElement()
|
||||||
elem.dropped = AtomicFalse
|
|
||||||
elem.buffer = device.GetMessageBuffer()
|
elem.buffer = device.GetMessageBuffer()
|
||||||
elem.Mutex = sync.Mutex{}
|
|
||||||
elem.nonce = 0
|
elem.nonce = 0
|
||||||
elem.keypair = nil
|
// keypair and peer were cleared (if necessary) by clearPointers.
|
||||||
elem.peer = nil
|
|
||||||
return elem
|
return elem
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueOutboundElement) Drop() {
|
// clearPointers clears elem fields that contain pointers.
|
||||||
atomic.StoreInt32(&elem.dropped, AtomicTrue)
|
// This makes the garbage collector's life easier and
|
||||||
}
|
// avoids accidentally keeping other objects around unnecessarily.
|
||||||
|
// It also reduces the possible collateral damage from use-after-free bugs.
|
||||||
func (elem *QueueOutboundElement) IsDropped() bool {
|
func (elem *QueueOutboundElement) clearPointers() {
|
||||||
return atomic.LoadInt32(&elem.dropped) == AtomicTrue
|
elem.buffer = nil
|
||||||
}
|
elem.packet = nil
|
||||||
|
elem.keypair = nil
|
||||||
func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
|
elem.peer = nil
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case queue <- element:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
select {
|
|
||||||
case old := <-queue:
|
|
||||||
device.PutMessageBuffer(old.buffer)
|
|
||||||
device.PutOutboundElement(old)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
|
|
||||||
select {
|
|
||||||
case outboundQueue <- element:
|
|
||||||
select {
|
|
||||||
case encryptionQueue <- element:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
element.Drop()
|
|
||||||
element.peer.device.PutMessageBuffer(element.buffer)
|
|
||||||
element.Unlock()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
element.peer.device.PutMessageBuffer(element.buffer)
|
|
||||||
element.peer.device.PutOutboundElement(element)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Queues a keepalive if no packets are queued for peer
|
/* Queues a keepalive if no packets are queued for peer
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) SendKeepalive() bool {
|
func (peer *Peer) SendKeepalive() {
|
||||||
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
|
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
||||||
return false
|
elem := peer.device.NewOutboundElement()
|
||||||
}
|
elemsContainer := peer.device.GetOutboundElementsContainer()
|
||||||
elem := peer.device.NewOutboundElement()
|
elemsContainer.elems = append(elemsContainer.elems, elem)
|
||||||
elem.packet = nil
|
select {
|
||||||
select {
|
case peer.queue.staged <- elemsContainer:
|
||||||
case peer.queue.nonce <- elem:
|
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
||||||
peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
|
default:
|
||||||
return true
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
default:
|
peer.device.PutOutboundElement(elem)
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
peer.device.PutOutboundElement(elem)
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
peer.SendStagedPackets()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||||
if !isRetry {
|
if !isRetry {
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.handshake.mutex.RLock()
|
peer.handshake.mutex.RLock()
|
||||||
@@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||||||
peer.handshake.lastSentHandshake = time.Now()
|
peer.handshake.lastSentHandshake = time.Now()
|
||||||
peer.handshake.mutex.Unlock()
|
peer.handshake.mutex.Unlock()
|
||||||
|
|
||||||
peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
|
peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
|
||||||
|
|
||||||
msg, err := peer.device.CreateMessageInitiation(peer)
|
msg, err := peer.device.CreateMessageInitiation(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
|
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var buff [MessageInitiationSize]byte
|
var buf [MessageInitiationSize]byte
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
binary.Write(writer, binary.LittleEndian, msg)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
peer.cookieGenerator.AddMacs(packet)
|
||||||
@@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
err = peer.SendBuffer(packet)
|
err = peer.SendBuffers([][]byte{packet})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
|
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||||
}
|
}
|
||||||
peer.timersHandshakeInitiated()
|
peer.timersHandshakeInitiated()
|
||||||
|
|
||||||
@@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||||||
peer.handshake.lastSentHandshake = time.Now()
|
peer.handshake.lastSentHandshake = time.Now()
|
||||||
peer.handshake.mutex.Unlock()
|
peer.handshake.mutex.Unlock()
|
||||||
|
|
||||||
peer.device.log.Debug.Println(peer, "- Sending handshake response")
|
peer.device.log.Verbosef("%v - Sending handshake response", peer)
|
||||||
|
|
||||||
response, err := peer.device.CreateMessageResponse(peer)
|
response, err := peer.device.CreateMessageResponse(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
|
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var buff [MessageResponseSize]byte
|
var buf [MessageResponseSize]byte
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, response)
|
binary.Write(writer, binary.LittleEndian, response)
|
||||||
packet := writer.Bytes()
|
packet := writer.Bytes()
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
peer.cookieGenerator.AddMacs(packet)
|
||||||
|
|
||||||
err = peer.BeginSymmetricSession()
|
err = peer.BeginSymmetricSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
|
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,32 +171,30 @@ func (peer *Peer) SendHandshakeResponse() error {
|
|||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
err = peer.SendBuffer(packet)
|
// TODO: allocation could be avoided
|
||||||
|
err = peer.SendBuffers([][]byte{packet})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
|
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
||||||
|
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
||||||
device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
|
|
||||||
|
|
||||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
||||||
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
|
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
device.log.Error.Println("Failed to create cookie reply:", err)
|
device.log.Errorf("Failed to create cookie reply: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var buff [MessageCookieReplySize]byte
|
var buf [MessageCookieReplySize]byte
|
||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buf[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
// TODO: allocation could be avoided
|
||||||
if err != nil {
|
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
||||||
device.log.Error.Println("Failed to send cookie reply:", err)
|
return nil
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) keepKeyFreshSending() {
|
func (peer *Peer) keepKeyFreshSending() {
|
||||||
@@ -231,280 +202,255 @@ func (peer *Peer) keepKeyFreshSending() {
|
|||||||
if keypair == nil {
|
if keypair == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
nonce := keypair.sendNonce.Load()
|
||||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Reads packets from the TUN and inserts
|
|
||||||
* into nonce queue for peer
|
|
||||||
*
|
|
||||||
* Obs. Single instance per TUN device
|
|
||||||
*/
|
|
||||||
func (device *Device) RoutineReadFromTUN() {
|
func (device *Device) RoutineReadFromTUN() {
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
|
||||||
logError := device.log.Error
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println("Routine: TUN reader - stopped")
|
device.log.Verbosef("Routine: TUN reader - stopped")
|
||||||
device.state.stopping.Done()
|
device.state.stopping.Done()
|
||||||
|
device.queue.encryption.wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logDebug.Println("Routine: TUN reader - started")
|
device.log.Verbosef("Routine: TUN reader - started")
|
||||||
device.state.starting.Done()
|
|
||||||
|
|
||||||
var elem *QueueOutboundElement
|
var (
|
||||||
|
batchSize = device.BatchSize()
|
||||||
|
readErr error
|
||||||
|
elems = make([]*QueueOutboundElement, batchSize)
|
||||||
|
bufs = make([][]byte, batchSize)
|
||||||
|
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
||||||
|
count = 0
|
||||||
|
sizes = make([]int, batchSize)
|
||||||
|
offset = MessageTransportHeaderSize
|
||||||
|
)
|
||||||
|
|
||||||
for {
|
for i := range elems {
|
||||||
if elem != nil {
|
elems[i] = device.NewOutboundElement()
|
||||||
device.PutMessageBuffer(elem.buffer)
|
bufs[i] = elems[i].buffer[:]
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
elem = device.NewOutboundElement()
|
|
||||||
|
|
||||||
// read packet
|
|
||||||
|
|
||||||
offset := MessageTransportHeaderSize
|
|
||||||
size, err := device.tun.device.Read(elem.buffer[:], offset)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if !device.isClosed.Get() {
|
|
||||||
logError.Println("Failed to read packet from TUN device:", err)
|
|
||||||
device.Close()
|
|
||||||
}
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if size == 0 || size > MaxContentSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.packet = elem.buffer[offset : offset+size]
|
|
||||||
|
|
||||||
// lookup peer
|
|
||||||
|
|
||||||
var peer *Peer
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case ipv4.Version:
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
|
||||||
peer = device.allowedips.LookupIPv4(dst)
|
|
||||||
|
|
||||||
case ipv6.Version:
|
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
|
||||||
peer = device.allowedips.LookupIPv6(dst)
|
|
||||||
|
|
||||||
default:
|
|
||||||
logDebug.Println("Received packet with unknown IP version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert into nonce/pre-handshake queue
|
|
||||||
|
|
||||||
if peer.isRunning.Get() {
|
|
||||||
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
|
|
||||||
peer.SendHandshakeInitiation(false)
|
|
||||||
}
|
|
||||||
addToNonceQueue(peer.queue.nonce, elem, device)
|
|
||||||
elem = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) FlushNonceQueue() {
|
|
||||||
select {
|
|
||||||
case peer.signals.flushNonceQueue <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Queues packets when there is no handshake.
|
|
||||||
* Then assigns nonces to packets sequentially
|
|
||||||
* and creates "work" structs for workers
|
|
||||||
*
|
|
||||||
* Obs. A single instance per peer
|
|
||||||
*/
|
|
||||||
func (peer *Peer) RoutineNonce() {
|
|
||||||
var keypair *Keypair
|
|
||||||
|
|
||||||
device := peer.device
|
|
||||||
logDebug := device.log.Debug
|
|
||||||
|
|
||||||
flush := func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elem := <-peer.queue.nonce:
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
flush()
|
for _, elem := range elems {
|
||||||
logDebug.Println(peer, "- Routine: nonce worker - stopped")
|
if elem != nil {
|
||||||
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
|
|
||||||
peer.routines.stopping.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
peer.routines.starting.Done()
|
|
||||||
logDebug.Println(peer, "- Routine: nonce worker - started")
|
|
||||||
|
|
||||||
for {
|
|
||||||
NextPacket:
|
|
||||||
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-peer.signals.flushNonceQueue:
|
|
||||||
flush()
|
|
||||||
goto NextPacket
|
|
||||||
|
|
||||||
case elem, ok := <-peer.queue.nonce:
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure to always pick the newest key
|
|
||||||
|
|
||||||
for {
|
|
||||||
|
|
||||||
// check validity of newest key pair
|
|
||||||
|
|
||||||
keypair = peer.keypairs.Current()
|
|
||||||
if keypair != nil && keypair.sendNonce < RejectAfterMessages {
|
|
||||||
if time.Since(keypair.created) < RejectAfterTime {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
|
|
||||||
|
|
||||||
// no suitable key pair, request for new handshake
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-peer.signals.newKeypairArrived:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.SendHandshakeInitiation(false)
|
|
||||||
|
|
||||||
// wait for key to be established
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- Awaiting keypair")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-peer.signals.newKeypairArrived:
|
|
||||||
logDebug.Println(peer, "- Obtained awaited keypair")
|
|
||||||
|
|
||||||
case <-peer.signals.flushNonceQueue:
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
flush()
|
|
||||||
goto NextPacket
|
|
||||||
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
|
|
||||||
|
|
||||||
// populate work element
|
|
||||||
|
|
||||||
elem.peer = peer
|
|
||||||
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
|
|
||||||
|
|
||||||
// double check in case of race condition added by future code
|
|
||||||
|
|
||||||
if elem.nonce >= RejectAfterMessages {
|
|
||||||
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
goto NextPacket
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// read packets
|
||||||
|
count, readErr = device.tun.device.Read(bufs, sizes, offset)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
if sizes[i] < 1 {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.keypair = keypair
|
elem := elems[i]
|
||||||
elem.dropped = AtomicFalse
|
elem.packet = bufs[i][offset : offset+sizes[i]]
|
||||||
elem.Lock()
|
|
||||||
|
// lookup peer
|
||||||
|
var peer *Peer
|
||||||
|
switch elem.packet[0] >> 4 {
|
||||||
|
case 4:
|
||||||
|
if len(elem.packet) < ipv4.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||||
|
peer = device.allowedips.Lookup(dst)
|
||||||
|
|
||||||
|
case 6:
|
||||||
|
if len(elem.packet) < ipv6.HeaderLen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||||
|
peer = device.allowedips.Lookup(dst)
|
||||||
|
|
||||||
|
default:
|
||||||
|
device.log.Verbosef("Received packet with unknown IP version")
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
elemsForPeer, ok := elemsByPeer[peer]
|
||||||
|
if !ok {
|
||||||
|
elemsForPeer = device.GetOutboundElementsContainer()
|
||||||
|
elemsByPeer[peer] = elemsForPeer
|
||||||
|
}
|
||||||
|
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
||||||
|
elems[i] = device.NewOutboundElement()
|
||||||
|
bufs[i] = elems[i].buffer[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
for peer, elemsForPeer := range elemsByPeer {
|
||||||
|
if peer.isRunning.Load() {
|
||||||
|
peer.StagePackets(elemsForPeer)
|
||||||
|
peer.SendStagedPackets()
|
||||||
|
} else {
|
||||||
|
for _, elem := range elemsForPeer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
|
device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsForPeer)
|
||||||
|
}
|
||||||
|
delete(elemsByPeer, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if readErr != nil {
|
||||||
|
if errors.Is(readErr, tun.ErrTooManySegments) {
|
||||||
|
// TODO: record stat for this
|
||||||
|
// This will happen if MSS is surprisingly small (< 576)
|
||||||
|
// coincident with reasonably high throughput.
|
||||||
|
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !device.isClosed() {
|
||||||
|
if !errors.Is(readErr, os.ErrClosed) {
|
||||||
|
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
||||||
|
}
|
||||||
|
go device.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case peer.queue.staged <- elems:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case tooOld := <-peer.queue.staged:
|
||||||
|
for _, elem := range tooOld.elems {
|
||||||
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(tooOld)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) SendStagedPackets() {
|
||||||
|
top:
|
||||||
|
if len(peer.queue.staged) == 0 || !peer.device.isUp() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keypair := peer.keypairs.Current()
|
||||||
|
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
||||||
|
peer.SendHandshakeInitiation(false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var elemsContainerOOO *QueueOutboundElementsContainer
|
||||||
|
select {
|
||||||
|
case elemsContainer := <-peer.queue.staged:
|
||||||
|
i := 0
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
elem.peer = peer
|
||||||
|
elem.nonce = keypair.sendNonce.Add(1) - 1
|
||||||
|
if elem.nonce >= RejectAfterMessages {
|
||||||
|
keypair.sendNonce.Store(RejectAfterMessages)
|
||||||
|
if elemsContainerOOO == nil {
|
||||||
|
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
||||||
|
}
|
||||||
|
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
elemsContainer.elems[i] = elem
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
elem.keypair = keypair
|
||||||
|
}
|
||||||
|
elemsContainer.Lock()
|
||||||
|
elemsContainer.elems = elemsContainer.elems[:i]
|
||||||
|
|
||||||
|
if elemsContainerOOO != nil {
|
||||||
|
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(elemsContainer.elems) == 0 {
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
goto top
|
||||||
|
}
|
||||||
|
|
||||||
// add to parallel and sequential queue
|
// add to parallel and sequential queue
|
||||||
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
|
if peer.isRunning.Load() {
|
||||||
|
peer.queue.outbound.c <- elemsContainer
|
||||||
|
peer.device.queue.encryption.c <- elemsContainer
|
||||||
|
} else {
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if elemsContainerOOO != nil {
|
||||||
|
goto top
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) FlushStagedPackets() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case elemsContainer := <-peer.queue.staged:
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
peer.device.PutMessageBuffer(elem.buffer)
|
||||||
|
peer.device.PutOutboundElement(elem)
|
||||||
|
}
|
||||||
|
peer.device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculatePaddingSize(packetSize, mtu int) int {
|
||||||
|
lastUnit := packetSize
|
||||||
|
if mtu == 0 {
|
||||||
|
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
|
||||||
|
}
|
||||||
|
if lastUnit > mtu {
|
||||||
|
lastUnit %= mtu
|
||||||
|
}
|
||||||
|
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
|
||||||
|
if paddedSize > mtu {
|
||||||
|
paddedSize = mtu
|
||||||
|
}
|
||||||
|
return paddedSize - lastUnit
|
||||||
|
}
|
||||||
|
|
||||||
/* Encrypts the elements in the queue
|
/* Encrypts the elements in the queue
|
||||||
* and marks them for sequential consumption (by releasing the mutex)
|
* and marks them for sequential consumption (by releasing the mutex)
|
||||||
*
|
*
|
||||||
* Obs. One instance per core
|
* Obs. One instance per core
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineEncryption() {
|
func (device *Device) RoutineEncryption(id int) {
|
||||||
|
var paddingZeros [PaddingMultiple]byte
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
var nonce [chacha20poly1305.NonceSize]byte
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
||||||
|
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
||||||
defer func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elem, ok := <-device.queue.encryption:
|
|
||||||
if ok && !elem.IsDropped() {
|
|
||||||
elem.Drop()
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
elem.Unlock()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
goto out
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out:
|
|
||||||
logDebug.Println("Routine: encryption worker - stopped")
|
|
||||||
device.state.stopping.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
logDebug.Println("Routine: encryption worker - started")
|
|
||||||
device.state.starting.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
|
|
||||||
// fetch next element
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-device.signals.stop:
|
|
||||||
return
|
|
||||||
|
|
||||||
case elem, ok := <-device.queue.encryption:
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if dropped
|
|
||||||
|
|
||||||
if elem.IsDropped() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
|
for elemsContainer := range device.queue.encryption.c {
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
// populate header fields
|
// populate header fields
|
||||||
|
|
||||||
header := elem.buffer[:MessageTransportHeaderSize]
|
header := elem.buffer[:MessageTransportHeaderSize]
|
||||||
|
|
||||||
fieldType := header[0:4]
|
fieldType := header[0:4]
|
||||||
@@ -516,16 +462,8 @@ func (device *Device) RoutineEncryption() {
|
|||||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
||||||
|
|
||||||
// pad content to multiple of 16
|
// pad content to multiple of 16
|
||||||
|
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
||||||
mtu := int(atomic.LoadInt32(&device.tun.mtu))
|
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
||||||
lastUnit := len(elem.packet) % mtu
|
|
||||||
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
|
|
||||||
if paddedSize > mtu {
|
|
||||||
paddedSize = mtu
|
|
||||||
}
|
|
||||||
for i := len(elem.packet); i < paddedSize; i++ {
|
|
||||||
elem.packet = append(elem.packet, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// encrypt content and release to consumer
|
// encrypt content and release to consumer
|
||||||
|
|
||||||
@@ -536,82 +474,73 @@ func (device *Device) RoutineEncryption() {
|
|||||||
elem.packet,
|
elem.packet,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
elem.Unlock()
|
|
||||||
}
|
}
|
||||||
|
elemsContainer.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Sequentially reads packets from queue and sends to endpoint
|
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||||
*
|
|
||||||
* Obs. Single instance per peer.
|
|
||||||
* The routine terminates then the outbound queue is closed.
|
|
||||||
*/
|
|
||||||
func (peer *Peer) RoutineSequentialSender() {
|
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
|
||||||
logError := device.log.Error
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
for {
|
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
||||||
select {
|
peer.stopping.Done()
|
||||||
case elem, ok := <-peer.queue.outbound:
|
|
||||||
if ok {
|
|
||||||
if !elem.IsDropped() {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
elem.Drop()
|
|
||||||
}
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
goto out
|
|
||||||
}
|
|
||||||
}
|
|
||||||
out:
|
|
||||||
logDebug.Println(peer, "- Routine: sequential sender - stopped")
|
|
||||||
peer.routines.stopping.Done()
|
|
||||||
}()
|
}()
|
||||||
|
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Routine: sequential sender - started")
|
bufs := make([][]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
peer.routines.starting.Done()
|
for elemsContainer := range peer.queue.outbound.c {
|
||||||
|
bufs = bufs[:0]
|
||||||
for {
|
if elemsContainer == nil {
|
||||||
select {
|
|
||||||
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
return
|
return
|
||||||
|
}
|
||||||
case elem, ok := <-peer.queue.outbound:
|
if !peer.isRunning.Load() {
|
||||||
|
// peer has been stopped; return re-usable elems to the shared pool.
|
||||||
if !ok {
|
// This is an optimization only. It is possible for the peer to be stopped
|
||||||
return
|
// immediately after this check, in which case, elem will get processed.
|
||||||
}
|
// The timers and SendBuffers code are resilient to a few stragglers.
|
||||||
|
// TODO: rework peer shutdown order to ensure
|
||||||
elem.Lock()
|
// that we never accidentally keep timers alive longer than necessary.
|
||||||
if elem.IsDropped() {
|
elemsContainer.Lock()
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
}
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
dataSent := false
|
||||||
|
elemsContainer.Lock()
|
||||||
// send message and return buffer to pool
|
for _, elem := range elemsContainer.elems {
|
||||||
|
|
||||||
err := peer.SendBuffer(elem.packet)
|
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
if len(elem.packet) != MessageKeepaliveSize {
|
||||||
peer.timersDataSent()
|
dataSent = true
|
||||||
}
|
}
|
||||||
|
bufs = append(bufs, elem.packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
|
peer.timersAnyAuthenticatedPacketSent()
|
||||||
|
|
||||||
|
err := peer.SendBuffers(bufs)
|
||||||
|
if dataSent {
|
||||||
|
peer.timersDataSent()
|
||||||
|
}
|
||||||
|
for _, elem := range elemsContainer.elems {
|
||||||
device.PutMessageBuffer(elem.buffer)
|
device.PutMessageBuffer(elem.buffer)
|
||||||
device.PutOutboundElement(elem)
|
device.PutOutboundElement(elem)
|
||||||
if err != nil {
|
|
||||||
logError.Println(peer, "- Failed to send data packet", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.keepKeyFreshSending()
|
|
||||||
}
|
}
|
||||||
|
device.PutOutboundElementsContainer(elemsContainer)
|
||||||
|
if err != nil {
|
||||||
|
var errGSO conn.ErrUDPGSODisabled
|
||||||
|
if errors.As(err, &errGSO) {
|
||||||
|
device.log.Verbosef(err.Error())
|
||||||
|
err = errGSO.RetryErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.keepKeyFreshSending()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
12
device/sticky_default.go
Normal file
12
device/sticky_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !linux
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
224
device/sticky_linux.go
Normal file
224
device/sticky_linux.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
|
* of the sticky-sockets.c example code:
|
||||||
|
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||||
|
*
|
||||||
|
* Currently there is no way to achieve this within the net package:
|
||||||
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
|
* So this code is remains platform dependent.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
if !conn.StdNetSupportsStickySockets {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if _, ok := bind.(*conn.StdNetBind); !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
netlinkSock, err := createNetlinkRouteSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(netlinkSock)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
|
||||||
|
|
||||||
|
return netlinkCancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||||
|
type peerEndpointPtr struct {
|
||||||
|
peer *Peer
|
||||||
|
endpoint *conn.Endpoint
|
||||||
|
}
|
||||||
|
var reqPeer map[uint32]peerEndpointPtr
|
||||||
|
var reqPeerLock sync.Mutex
|
||||||
|
|
||||||
|
defer netlinkCancel.Close()
|
||||||
|
defer unix.Close(netlinkSock)
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !netlinkCancel.ReadyRead() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if uint(hdr.Len) > uint(len(remain)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||||
|
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
||||||
|
if uint(len(remain)) < uint(hdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||||
|
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||||
|
for {
|
||||||
|
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||||
|
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||||
|
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
if reqPeer == nil {
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr, ok := reqPeer[hdr.Seq]
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.endpoint.Lock()
|
||||||
|
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
||||||
|
pePtr.peer.endpoint.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
||||||
|
pePtr.peer.endpoint.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.endpoint.clearSrcOnTx = true
|
||||||
|
pePtr.peer.endpoint.Unlock()
|
||||||
|
}
|
||||||
|
attr = attr[attrhdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
go func() {
|
||||||
|
device.peers.RLock()
|
||||||
|
i := uint32(1)
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
if peer.endpoint.val == nil {
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
||||||
|
if nativeEP == nil {
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nlmsg := struct {
|
||||||
|
hdr unix.NlMsghdr
|
||||||
|
msg unix.RtMsg
|
||||||
|
dsthdr unix.RtAttr
|
||||||
|
dst [4]byte
|
||||||
|
srchdr unix.RtAttr
|
||||||
|
src [4]byte
|
||||||
|
markhdr unix.RtAttr
|
||||||
|
mark uint32
|
||||||
|
}{
|
||||||
|
unix.NlMsghdr{
|
||||||
|
Type: uint16(unix.RTM_GETROUTE),
|
||||||
|
Flags: unix.NLM_F_REQUEST,
|
||||||
|
Seq: i,
|
||||||
|
},
|
||||||
|
unix.RtMsg{
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Dst_len: 32,
|
||||||
|
Src_len: 32,
|
||||||
|
},
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_DST,
|
||||||
|
},
|
||||||
|
nativeEP.DstIP().As4(),
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_SRC,
|
||||||
|
},
|
||||||
|
nativeEP.SrcIP().As4(),
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_MARK,
|
||||||
|
},
|
||||||
|
device.net.fwmark,
|
||||||
|
}
|
||||||
|
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer[i] = peerEndpointPtr{
|
||||||
|
peer: peer,
|
||||||
|
endpoint: &peer.endpoint.val,
|
||||||
|
}
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
i++
|
||||||
|
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: unix.RTMGRP_IPV4_ROUTE,
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(sock)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
*
|
||||||
* This is based heavily on timers.c from the kernel implementation.
|
* This is based heavily on timers.c from the kernel implementation.
|
||||||
*/
|
*/
|
||||||
@@ -8,16 +8,16 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
_ "unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* This Timer structure and related functions should roughly copy the interface of
|
//go:linkname fastrandn runtime.fastrandn
|
||||||
* the Linux kernel's struct timer_list.
|
func fastrandn(n uint32) uint32
|
||||||
*/
|
|
||||||
|
|
||||||
|
// A Timer manages time-based aspects of the WireGuard protocol.
|
||||||
|
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
|
||||||
type Timer struct {
|
type Timer struct {
|
||||||
*time.Timer
|
*time.Timer
|
||||||
modifyingLock sync.RWMutex
|
modifyingLock sync.RWMutex
|
||||||
@@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
|
|||||||
timer := &Timer{}
|
timer := &Timer{}
|
||||||
timer.Timer = time.AfterFunc(time.Hour, func() {
|
timer.Timer = time.AfterFunc(time.Hour, func() {
|
||||||
timer.runningLock.Lock()
|
timer.runningLock.Lock()
|
||||||
|
defer timer.runningLock.Unlock()
|
||||||
|
|
||||||
timer.modifyingLock.Lock()
|
timer.modifyingLock.Lock()
|
||||||
if !timer.isPending {
|
if !timer.isPending {
|
||||||
timer.modifyingLock.Unlock()
|
timer.modifyingLock.Unlock()
|
||||||
timer.runningLock.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
timer.isPending = false
|
timer.isPending = false
|
||||||
timer.modifyingLock.Unlock()
|
timer.modifyingLock.Unlock()
|
||||||
|
|
||||||
expirationFunction(peer)
|
expirationFunction(peer)
|
||||||
timer.runningLock.Unlock()
|
|
||||||
})
|
})
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
return timer
|
return timer
|
||||||
@@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersActive() bool {
|
func (peer *Peer) timersActive() bool {
|
||||||
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
|
return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredRetransmitHandshake(peer *Peer) {
|
func expiredRetransmitHandshake(peer *Peer) {
|
||||||
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
|
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
|
||||||
peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
|
||||||
|
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.sendKeepalive.Del()
|
peer.timers.sendKeepalive.Del()
|
||||||
@@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||||||
/* We drop all packets without a keypair and don't try again,
|
/* We drop all packets without a keypair and don't try again,
|
||||||
* if we try unsuccessfully for too long to make a handshake.
|
* if we try unsuccessfully for too long to make a handshake.
|
||||||
*/
|
*/
|
||||||
peer.FlushNonceQueue()
|
peer.FlushStagedPackets()
|
||||||
|
|
||||||
/* We set a timer for destroying any residue that might be left
|
/* We set a timer for destroying any residue that might be left
|
||||||
* of a partial exchange.
|
* of a partial exchange.
|
||||||
@@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||||||
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
|
peer.timers.handshakeAttempts.Add(1)
|
||||||
peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
|
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
|
||||||
|
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
|
|
||||||
peer.SendHandshakeInitiation(true)
|
peer.SendHandshakeInitiation(true)
|
||||||
}
|
}
|
||||||
@@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
|
|||||||
|
|
||||||
func expiredSendKeepalive(peer *Peer) {
|
func expiredSendKeepalive(peer *Peer) {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
if peer.timers.needAnotherKeepalive.Get() {
|
if peer.timers.needAnotherKeepalive.Load() {
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
}
|
}
|
||||||
@@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func expiredNewHandshake(peer *Peer) {
|
func expiredNewHandshake(peer *Peer) {
|
||||||
peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
|
||||||
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
/* We clear the endpoint address src address, in case this is the cause of trouble. */
|
||||||
peer.Lock()
|
peer.markEndpointSrcForClearing()
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
peer.Unlock()
|
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredZeroKeyMaterial(peer *Peer) {
|
func expiredZeroKeyMaterial(peer *Peer) {
|
||||||
peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
|
peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
|
||||||
peer.ZeroAndFlushAll()
|
peer.ZeroAndFlushAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
func expiredPersistentKeepalive(peer *Peer) {
|
func expiredPersistentKeepalive(peer *Peer) {
|
||||||
if peer.persistentKeepaliveInterval > 0 {
|
if peer.persistentKeepaliveInterval.Load() > 0 {
|
||||||
peer.SendKeepalive()
|
peer.SendKeepalive()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
|
|||||||
/* Should be called after an authenticated data packet is sent. */
|
/* Should be called after an authenticated data packet is sent. */
|
||||||
func (peer *Peer) timersDataSent() {
|
func (peer *Peer) timersDataSent() {
|
||||||
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
||||||
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
|
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
|
|||||||
if !peer.timers.sendKeepalive.IsPending() {
|
if !peer.timers.sendKeepalive.IsPending() {
|
||||||
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
|
||||||
} else {
|
} else {
|
||||||
peer.timers.needAnotherKeepalive.Set(true)
|
peer.timers.needAnotherKeepalive.Store(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
|
|||||||
/* Should be called after a handshake initiation message is sent. */
|
/* Should be called after a handshake initiation message is sent. */
|
||||||
func (peer *Peer) timersHandshakeInitiated() {
|
func (peer *Peer) timersHandshakeInitiated() {
|
||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
|
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
|
|||||||
if peer.timersActive() {
|
if peer.timersActive() {
|
||||||
peer.timers.retransmitHandshake.Del()
|
peer.timers.retransmitHandshake.Del()
|
||||||
}
|
}
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
|
peer.lastHandshakeNano.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
|
||||||
@@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() {
|
|||||||
|
|
||||||
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
|
||||||
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
|
||||||
if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
|
keepalive := peer.persistentKeepaliveInterval.Load()
|
||||||
peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
|
if keepalive > 0 && peer.timersActive() {
|
||||||
|
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,9 +204,12 @@ func (peer *Peer) timersInit() {
|
|||||||
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
|
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
|
||||||
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
|
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
|
||||||
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
|
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
|
||||||
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
|
}
|
||||||
peer.timers.sentLastMinuteHandshake.Set(false)
|
|
||||||
peer.timers.needAnotherKeepalive.Set(false)
|
func (peer *Peer) timersStart() {
|
||||||
|
peer.timers.handshakeAttempts.Store(0)
|
||||||
|
peer.timers.sentLastMinuteHandshake.Store(false)
|
||||||
|
peer.timers.needAnotherKeepalive.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) timersStop() {
|
func (peer *Peer) timersStop() {
|
||||||
|
|||||||
@@ -1,56 +1,53 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync/atomic"
|
"fmt"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DefaultMTU = 1420
|
const DefaultMTU = 1420
|
||||||
|
|
||||||
func (device *Device) RoutineTUNEventReader() {
|
func (device *Device) RoutineTUNEventReader() {
|
||||||
setUp := false
|
device.log.Verbosef("Routine: event worker - started")
|
||||||
logDebug := device.log.Debug
|
|
||||||
logInfo := device.log.Info
|
|
||||||
logError := device.log.Error
|
|
||||||
|
|
||||||
logDebug.Println("Routine: event worker - started")
|
|
||||||
device.state.starting.Done()
|
|
||||||
|
|
||||||
for event := range device.tun.device.Events() {
|
for event := range device.tun.device.Events() {
|
||||||
if event&tun.EventMTUUpdate != 0 {
|
if event&tun.EventMTUUpdate != 0 {
|
||||||
mtu, err := device.tun.device.MTU()
|
mtu, err := device.tun.device.MTU()
|
||||||
old := atomic.LoadInt32(&device.tun.mtu)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to load updated MTU of device:", err)
|
device.log.Errorf("Failed to load updated MTU of device: %v", err)
|
||||||
} else if int(old) != mtu {
|
continue
|
||||||
if mtu+MessageTransportSize > MaxMessageSize {
|
}
|
||||||
logInfo.Println("MTU updated:", mtu, "(too large)")
|
if mtu < 0 {
|
||||||
} else {
|
device.log.Errorf("MTU not updated to negative value: %v", mtu)
|
||||||
logInfo.Println("MTU updated:", mtu)
|
continue
|
||||||
}
|
}
|
||||||
atomic.StoreInt32(&device.tun.mtu, int32(mtu))
|
var tooLarge string
|
||||||
|
if mtu > MaxContentSize {
|
||||||
|
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
||||||
|
mtu = MaxContentSize
|
||||||
|
}
|
||||||
|
old := device.tun.mtu.Swap(int32(mtu))
|
||||||
|
if int(old) != mtu {
|
||||||
|
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if event&tun.EventUp != 0 && !setUp {
|
if event&tun.EventUp != 0 {
|
||||||
logInfo.Println("Interface set up")
|
device.log.Verbosef("Interface up requested")
|
||||||
setUp = true
|
|
||||||
device.Up()
|
device.Up()
|
||||||
}
|
}
|
||||||
|
|
||||||
if event&tun.EventDown != 0 && setUp {
|
if event&tun.EventDown != 0 {
|
||||||
logInfo.Println("Interface set down")
|
device.log.Verbosef("Interface down requested")
|
||||||
setUp = false
|
|
||||||
device.Down()
|
device.Down()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logDebug.Println("Routine: event worker - stopped")
|
device.log.Verbosef("Routine: event worker - stopped")
|
||||||
device.state.stopping.Done()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newDummyTUN creates a dummy TUN device with the specified name.
|
|
||||||
func newDummyTUN(name string) tun.Device {
|
|
||||||
return &dummyTUN{
|
|
||||||
name: name,
|
|
||||||
packets: make(chan []byte, 100),
|
|
||||||
events: make(chan tun.Event, 10),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A dummyTUN is a tun.Device which is used in unit tests.
|
|
||||||
type dummyTUN struct {
|
|
||||||
name string
|
|
||||||
mtu int
|
|
||||||
packets chan []byte
|
|
||||||
events chan tun.Event
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dummyTUN) Events() chan tun.Event { return d.events }
|
|
||||||
func (*dummyTUN) File() *os.File { return nil }
|
|
||||||
func (*dummyTUN) Flush() error { return nil }
|
|
||||||
func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
|
|
||||||
func (d *dummyTUN) Name() (string, error) { return d.name, nil }
|
|
||||||
|
|
||||||
func (d *dummyTUN) Close() error {
|
|
||||||
close(d.events)
|
|
||||||
close(d.packets)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
|
|
||||||
buf, ok := <-d.packets
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("device closed")
|
|
||||||
}
|
|
||||||
copy(b[offset:], buf)
|
|
||||||
return len(buf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
|
|
||||||
d.packets <- b[offset:]
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
680
device/uapi.go
680
device/uapi.go
@@ -1,43 +1,77 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"github.com/Lordy82/wireguard-go/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IPCError struct {
|
type IPCError struct {
|
||||||
int64
|
code int64 // error code
|
||||||
|
err error // underlying/wrapped error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s IPCError) Error() string {
|
func (s IPCError) Error() string {
|
||||||
return fmt.Sprintf("IPC error: %d", s.int64)
|
return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s IPCError) Unwrap() error {
|
||||||
|
return s.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s IPCError) ErrorCode() int64 {
|
func (s IPCError) ErrorCode() int64 {
|
||||||
return s.int64
|
return s.code
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
func ipcErrorf(code int64, msg string, args ...any) *IPCError {
|
||||||
lines := make([]string, 0, 100)
|
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
||||||
send := func(line string) {
|
}
|
||||||
lines = append(lines, line)
|
|
||||||
|
var byteBufferPool = &sync.Pool{
|
||||||
|
New: func() any { return new(bytes.Buffer) },
|
||||||
|
}
|
||||||
|
|
||||||
|
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
|
||||||
|
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
||||||
|
func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||||
|
device.ipcMutex.RLock()
|
||||||
|
defer device.ipcMutex.RUnlock()
|
||||||
|
|
||||||
|
buf := byteBufferPool.Get().(*bytes.Buffer)
|
||||||
|
buf.Reset()
|
||||||
|
defer byteBufferPool.Put(buf)
|
||||||
|
sendf := func(format string, args ...any) {
|
||||||
|
fmt.Fprintf(buf, format, args...)
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
}
|
||||||
|
keyf := func(prefix string, key *[32]byte) {
|
||||||
|
buf.Grow(len(key)*2 + 2 + len(prefix))
|
||||||
|
buf.WriteString(prefix)
|
||||||
|
buf.WriteByte('=')
|
||||||
|
const hex = "0123456789abcdef"
|
||||||
|
for i := 0; i < len(key); i++ {
|
||||||
|
buf.WriteByte(hex[key[i]>>4])
|
||||||
|
buf.WriteByte(hex[key[i]&0xf])
|
||||||
|
}
|
||||||
|
buf.WriteByte('\n')
|
||||||
}
|
}
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
|
||||||
// lock required resources
|
// lock required resources
|
||||||
|
|
||||||
device.net.RLock()
|
device.net.RLock()
|
||||||
@@ -52,337 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
|||||||
// serialize device related values
|
// serialize device related values
|
||||||
|
|
||||||
if !device.staticIdentity.privateKey.IsZero() {
|
if !device.staticIdentity.privateKey.IsZero() {
|
||||||
send("private_key=" + device.staticIdentity.privateKey.ToHex())
|
keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.net.port != 0 {
|
if device.net.port != 0 {
|
||||||
send(fmt.Sprintf("listen_port=%d", device.net.port))
|
sendf("listen_port=%d", device.net.port)
|
||||||
}
|
}
|
||||||
|
|
||||||
if device.net.fwmark != 0 {
|
if device.net.fwmark != 0 {
|
||||||
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
|
sendf("fwmark=%d", device.net.fwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
// serialize each peer state
|
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.RLock()
|
// Serialize peer state.
|
||||||
defer peer.RUnlock()
|
peer.handshake.mutex.RLock()
|
||||||
|
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
||||||
send("public_key=" + peer.handshake.remoteStatic.ToHex())
|
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
||||||
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
|
peer.handshake.mutex.RUnlock()
|
||||||
send("protocol_version=1")
|
sendf("protocol_version=1")
|
||||||
if peer.endpoint != nil {
|
peer.endpoint.Lock()
|
||||||
send("endpoint=" + peer.endpoint.DstToString())
|
if peer.endpoint.val != nil {
|
||||||
|
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
||||||
}
|
}
|
||||||
|
peer.endpoint.Unlock()
|
||||||
|
|
||||||
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
|
nano := peer.lastHandshakeNano.Load()
|
||||||
secs := nano / time.Second.Nanoseconds()
|
secs := nano / time.Second.Nanoseconds()
|
||||||
nano %= time.Second.Nanoseconds()
|
nano %= time.Second.Nanoseconds()
|
||||||
|
|
||||||
send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
|
sendf("last_handshake_time_sec=%d", secs)
|
||||||
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
|
sendf("last_handshake_time_nsec=%d", nano)
|
||||||
send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
|
sendf("tx_bytes=%d", peer.txBytes.Load())
|
||||||
send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
|
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
||||||
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
|
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
||||||
|
|
||||||
for _, ip := range device.allowedips.EntriesForPeer(peer) {
|
|
||||||
send("allowed_ip=" + ip.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||||
|
sendf("allowed_ip=%s", prefix.String())
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// send lines (does not require resource locks)
|
// send lines (does not require resource locks)
|
||||||
|
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||||
for _, line := range lines {
|
return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
|
||||||
_, err := socket.WriteString(line + "\n")
|
|
||||||
if err != nil {
|
|
||||||
return &IPCError{ipc.IpcErrorIO}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
|
||||||
scanner := bufio.NewScanner(socket)
|
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
||||||
logError := device.log.Error
|
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||||
logDebug := device.log.Debug
|
device.ipcMutex.Lock()
|
||||||
|
defer device.ipcMutex.Unlock()
|
||||||
|
|
||||||
var peer *Peer
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
device.log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dummy := false
|
peer := new(ipcSetPeer)
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
|
|
||||||
// parse line
|
|
||||||
|
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if line == "" {
|
if line == "" {
|
||||||
|
// Blank line means terminate operation.
|
||||||
|
peer.handlePostConfig()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
parts := strings.Split(line, "=")
|
key, value, ok := strings.Cut(line, "=")
|
||||||
if len(parts) != 2 {
|
if !ok {
|
||||||
return &IPCError{ipc.IpcErrorProtocol}
|
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
|
||||||
}
|
}
|
||||||
key := parts[0]
|
|
||||||
value := parts[1]
|
|
||||||
|
|
||||||
/* device configuration */
|
if key == "public_key" {
|
||||||
|
if deviceConfig {
|
||||||
if deviceConfig {
|
|
||||||
|
|
||||||
switch key {
|
|
||||||
case "private_key":
|
|
||||||
var sk NoisePrivateKey
|
|
||||||
err := sk.FromHex(value)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set private_key:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
logDebug.Println("UAPI: Updating private key")
|
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
|
|
||||||
case "listen_port":
|
|
||||||
|
|
||||||
// parse port number
|
|
||||||
|
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to parse listen_port:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update port and rebind
|
|
||||||
|
|
||||||
logDebug.Println("UAPI: Updating listen port")
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
device.net.port = uint16(port)
|
|
||||||
device.net.Unlock()
|
|
||||||
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
|
||||||
logError.Println("Failed to set listen_port:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorPortInUse}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "fwmark":
|
|
||||||
|
|
||||||
// parse fwmark field
|
|
||||||
|
|
||||||
fwmark, err := func() (uint32, error) {
|
|
||||||
if value == "" {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
mark, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
return uint32(mark), err
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Invalid fwmark", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
logDebug.Println("UAPI: Updating fwmark")
|
|
||||||
|
|
||||||
if err := device.BindSetMark(uint32(fwmark)); err != nil {
|
|
||||||
logError.Println("Failed to update fwmark:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorPortInUse}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "public_key":
|
|
||||||
// switch to peer configuration
|
|
||||||
logDebug.Println("UAPI: Transition to peer configuration")
|
|
||||||
deviceConfig = false
|
deviceConfig = false
|
||||||
|
|
||||||
case "replace_peers":
|
|
||||||
if value != "true" {
|
|
||||||
logError.Println("Failed to set replace_peers, invalid value:", value)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
logDebug.Println("UAPI: Removing all peers")
|
|
||||||
device.RemoveAllPeers()
|
|
||||||
|
|
||||||
default:
|
|
||||||
logError.Println("Invalid UAPI device key:", key)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
}
|
||||||
|
peer.handlePostConfig()
|
||||||
|
// Load/create the peer we are now configuring.
|
||||||
|
err := device.handlePublicKeyLine(peer, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
/* peer configuration */
|
var err error
|
||||||
|
if deviceConfig {
|
||||||
if !deviceConfig {
|
err = device.handleDeviceLine(key, value)
|
||||||
|
} else {
|
||||||
switch key {
|
err = device.handlePeerLine(peer, key, value)
|
||||||
|
|
||||||
case "public_key":
|
|
||||||
var publicKey NoisePublicKey
|
|
||||||
err := publicKey.FromHex(value)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to get peer by public key:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ignore peer with public key of device
|
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
|
||||||
dummy = device.staticIdentity.publicKey.Equals(publicKey)
|
|
||||||
device.staticIdentity.RUnlock()
|
|
||||||
|
|
||||||
if dummy {
|
|
||||||
peer = &Peer{}
|
|
||||||
} else {
|
|
||||||
peer = device.LookupPeer(publicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer == nil {
|
|
||||||
peer, err = device.NewPeer(publicKey)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to create new peer:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
if peer == nil {
|
|
||||||
dummy = true
|
|
||||||
peer = &Peer{}
|
|
||||||
} else {
|
|
||||||
logDebug.Println(peer, "- UAPI: Created")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "remove":
|
|
||||||
|
|
||||||
// remove currently selected peer from device
|
|
||||||
|
|
||||||
if value != "true" {
|
|
||||||
logError.Println("Failed to set remove, invalid value:", value)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
if !dummy {
|
|
||||||
logDebug.Println(peer, "- UAPI: Removing")
|
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
|
||||||
}
|
|
||||||
peer = &Peer{}
|
|
||||||
dummy = true
|
|
||||||
|
|
||||||
case "preshared_key":
|
|
||||||
|
|
||||||
// update PSK
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- UAPI: Updating preshared key")
|
|
||||||
|
|
||||||
peer.handshake.mutex.Lock()
|
|
||||||
err := peer.handshake.presharedKey.FromHex(value)
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set preshared key:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "endpoint":
|
|
||||||
|
|
||||||
// set endpoint destination
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- UAPI: Updating endpoint")
|
|
||||||
|
|
||||||
err := func() error {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
endpoint, err := CreateEndpoint(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer.endpoint = endpoint
|
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set endpoint:", err, ":", value)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
|
||||||
|
|
||||||
// update persistent keepalive interval
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
|
|
||||||
|
|
||||||
secs, err := strconv.ParseUint(value, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set persistent keepalive interval:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
old := peer.persistentKeepaliveInterval
|
|
||||||
peer.persistentKeepaliveInterval = uint16(secs)
|
|
||||||
|
|
||||||
// send immediate keepalive if we're turning it on and before it wasn't on
|
|
||||||
|
|
||||||
if old == 0 && secs != 0 {
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to get tun device status:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorIO}
|
|
||||||
}
|
|
||||||
if device.isUp.Get() && !dummy {
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "replace_allowed_ips":
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- UAPI: Removing all allowedips")
|
|
||||||
|
|
||||||
if value != "true" {
|
|
||||||
logError.Println("Failed to replace allowedips, invalid value:", value)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dummy {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
device.allowedips.RemoveByPeer(peer)
|
|
||||||
|
|
||||||
case "allowed_ip":
|
|
||||||
|
|
||||||
logDebug.Println(peer, "- UAPI: Adding allowedip")
|
|
||||||
|
|
||||||
_, network, err := net.ParseCIDR(value)
|
|
||||||
if err != nil {
|
|
||||||
logError.Println("Failed to set allowed ip:", err)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
if dummy {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
ones, _ := network.Mask.Size()
|
|
||||||
device.allowedips.Insert(network.IP, uint(ones), peer)
|
|
||||||
|
|
||||||
case "protocol_version":
|
|
||||||
|
|
||||||
if value != "1" {
|
|
||||||
logError.Println("Invalid protocol version:", value)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
logError.Println("Invalid UAPI peer key:", key)
|
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
peer.handlePostConfig()
|
||||||
|
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) handleDeviceLine(key, value string) error {
|
||||||
|
switch key {
|
||||||
|
case "private_key":
|
||||||
|
var sk NoisePrivateKey
|
||||||
|
err := sk.FromMaybeZeroHex(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Updating private key")
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
|
||||||
|
case "listen_port":
|
||||||
|
port, err := strconv.ParseUint(value, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// update port and rebind
|
||||||
|
device.log.Verbosef("UAPI: Updating listen port")
|
||||||
|
|
||||||
|
device.net.Lock()
|
||||||
|
device.net.port = uint16(port)
|
||||||
|
device.net.Unlock()
|
||||||
|
|
||||||
|
if err := device.BindUpdate(); err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "fwmark":
|
||||||
|
mark, err := strconv.ParseUint(value, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
device.log.Verbosef("UAPI: Updating fwmark")
|
||||||
|
if err := device.BindSetMark(uint32(mark)); err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "replace_peers":
|
||||||
|
if value != "true" {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("UAPI: Removing all peers")
|
||||||
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
|
default:
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An ipcSetPeer is the current state of an IPC set operation on a peer.
|
||||||
|
type ipcSetPeer struct {
|
||||||
|
*Peer // Peer is the current peer being operated on
|
||||||
|
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
|
||||||
|
created bool // new reports whether this is a newly created peer
|
||||||
|
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *ipcSetPeer) handlePostConfig() {
|
||||||
|
if peer.Peer == nil || peer.dummy {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if peer.created {
|
||||||
|
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
||||||
|
}
|
||||||
|
if peer.device.isUp() {
|
||||||
|
peer.Start()
|
||||||
|
if peer.pkaOn {
|
||||||
|
peer.SendKeepalive()
|
||||||
|
}
|
||||||
|
peer.SendStagedPackets()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
|
||||||
|
// Load/create the peer we are configuring.
|
||||||
|
var publicKey NoisePublicKey
|
||||||
|
err := publicKey.FromHex(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ignore peer with the same public key as this device.
|
||||||
|
device.staticIdentity.RLock()
|
||||||
|
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
|
||||||
|
device.staticIdentity.RUnlock()
|
||||||
|
|
||||||
|
if peer.dummy {
|
||||||
|
peer.Peer = &Peer{}
|
||||||
|
} else {
|
||||||
|
peer.Peer = device.LookupPeer(publicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.created = peer.Peer == nil
|
||||||
|
if peer.created {
|
||||||
|
peer.Peer, err = device.NewPeer(publicKey)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
|
||||||
|
}
|
||||||
|
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
|
||||||
|
switch key {
|
||||||
|
case "update_only":
|
||||||
|
// allow disabling of creation
|
||||||
|
if value != "true" {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
||||||
|
}
|
||||||
|
if peer.created && !peer.dummy {
|
||||||
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
peer.Peer = &Peer{}
|
||||||
|
peer.dummy = true
|
||||||
|
}
|
||||||
|
|
||||||
|
case "remove":
|
||||||
|
// remove currently selected peer from device
|
||||||
|
if value != "true" {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
|
||||||
|
}
|
||||||
|
if !peer.dummy {
|
||||||
|
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
|
||||||
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
}
|
||||||
|
peer.Peer = &Peer{}
|
||||||
|
peer.dummy = true
|
||||||
|
|
||||||
|
case "preshared_key":
|
||||||
|
device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
|
||||||
|
|
||||||
|
peer.handshake.mutex.Lock()
|
||||||
|
err := peer.handshake.presharedKey.FromHex(value)
|
||||||
|
peer.handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "endpoint":
|
||||||
|
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
||||||
|
endpoint, err := device.net.bind.ParseEndpoint(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||||
|
}
|
||||||
|
peer.endpoint.Lock()
|
||||||
|
defer peer.endpoint.Unlock()
|
||||||
|
peer.endpoint.val = endpoint
|
||||||
|
|
||||||
|
case "persistent_keepalive_interval":
|
||||||
|
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
||||||
|
|
||||||
|
secs, err := strconv.ParseUint(value, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
||||||
|
|
||||||
|
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
||||||
|
peer.pkaOn = old == 0 && secs != 0
|
||||||
|
|
||||||
|
case "replace_allowed_ips":
|
||||||
|
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
||||||
|
if value != "true" {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
||||||
|
}
|
||||||
|
if peer.dummy {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
device.allowedips.RemoveByPeer(peer.Peer)
|
||||||
|
|
||||||
|
case "allowed_ip":
|
||||||
|
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||||
|
prefix, err := netip.ParsePrefix(value)
|
||||||
|
if err != nil {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||||
|
}
|
||||||
|
if peer.dummy {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
device.allowedips.Insert(prefix, peer.Peer)
|
||||||
|
|
||||||
|
case "protocol_version":
|
||||||
|
if value != "1" {
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) IpcGet() (string, error) {
|
||||||
|
buf := new(strings.Builder)
|
||||||
|
if err := device.IpcGetOperation(buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) IpcSet(uapiConf string) error {
|
||||||
|
return device.IpcSetOperation(strings.NewReader(uapiConf))
|
||||||
|
}
|
||||||
|
|
||||||
func (device *Device) IpcHandle(socket net.Conn) {
|
func (device *Device) IpcHandle(socket net.Conn) {
|
||||||
|
|
||||||
// create buffered read/writer
|
|
||||||
|
|
||||||
defer socket.Close()
|
defer socket.Close()
|
||||||
|
|
||||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
||||||
@@ -391,35 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||||||
return bufio.NewReadWriter(reader, writer)
|
return bufio.NewReadWriter(reader, writer)
|
||||||
}(socket)
|
}(socket)
|
||||||
|
|
||||||
defer buffered.Flush()
|
for {
|
||||||
|
op, err := buffered.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
op, err := buffered.ReadString('\n')
|
// handle operation
|
||||||
if err != nil {
|
switch op {
|
||||||
return
|
case "set=1\n":
|
||||||
}
|
err = device.IpcSetOperation(buffered.Reader)
|
||||||
|
case "get=1\n":
|
||||||
|
var nextByte byte
|
||||||
|
nextByte, err = buffered.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if nextByte != '\n' {
|
||||||
|
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = device.IpcGetOperation(buffered.Writer)
|
||||||
|
default:
|
||||||
|
device.log.Errorf("invalid UAPI operation: %v", op)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// handle operation
|
// write status
|
||||||
|
var status *IPCError
|
||||||
var status *IPCError
|
if err != nil && !errors.As(err, &status) {
|
||||||
|
// shouldn't happen
|
||||||
switch op {
|
status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
|
||||||
case "set=1\n":
|
}
|
||||||
status = device.IpcSetOperation(buffered.Reader)
|
if status != nil {
|
||||||
|
device.log.Errorf("%v", status)
|
||||||
case "get=1\n":
|
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
|
||||||
status = device.IpcGetOperation(buffered.Writer)
|
} else {
|
||||||
|
fmt.Fprintf(buffered, "errno=0\n\n")
|
||||||
default:
|
}
|
||||||
device.log.Error.Println("Invalid UAPI operation:", op)
|
buffered.Flush()
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// write status
|
|
||||||
|
|
||||||
if status != nil {
|
|
||||||
device.log.Error.Println(status)
|
|
||||||
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
package device
|
|
||||||
|
|
||||||
const WireGuardGoVersion = "0.0.20190908"
|
|
||||||
51
format_test.go
Normal file
51
format_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"go/format"
|
||||||
|
"io/fs"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatting(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to walk %s: %v", path, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if d.IsDir() || filepath.Ext(path) != ".go" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(path string) {
|
||||||
|
defer wg.Done()
|
||||||
|
src, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to read %s: %v", path, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
|
||||||
|
}
|
||||||
|
formatted, err := format.Source(src)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unable to format %s: %v", path, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(src, formatted) {
|
||||||
|
t.Errorf("unformatted code: %s", path)
|
||||||
|
}
|
||||||
|
}(path)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
19
go.mod
19
go.mod
@@ -1,10 +1,17 @@
|
|||||||
module golang.zx2c4.com/wireguard
|
module github.com/Lordy82/wireguard-go
|
||||||
|
|
||||||
go 1.12
|
go 1.23.1
|
||||||
|
|
||||||
|
toolchain go1.24.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472
|
golang.org/x/crypto v0.34.0
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad
|
golang.org/x/sys v0.30.0
|
||||||
golang.org/x/text v0.3.2
|
gvisor.dev/gvisor v0.0.0-20250218181608-84670a4fc612
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/google/btree v1.1.3 // indirect
|
||||||
|
golang.org/x/time v0.10.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
34
go.sum
34
go.sum
@@ -1,14 +1,20 @@
|
|||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
|
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM=
|
golang.org/x/crypto v0.34.0 h1:+/C6tk6rf/+t5DhUketUbD1aNGqiSX3j15Z6xuIDlBA=
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/crypto v0.34.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad h1:cCejgArrk10gX6kFqjWeLwXD7aVMqWoRpyUCaaJSggc=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
||||||
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4=
|
||||||
|
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250218181608-84670a4fc612 h1:Ah91Og1rlLUY0Sm/fP3vQsAtFoWrOoPTJ+4320GDL2k=
|
||||||
|
gvisor.dev/gvisor v0.0.0-20250218181608-84670a4fc612/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
|
||||||
|
|||||||
287
ipc/namedpipe/file.go
Normal file
287
ipc/namedpipe/file.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package namedpipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ioInitOnce sync.Once
|
||||||
|
ioCompletionPort windows.Handle
|
||||||
|
)
|
||||||
|
|
||||||
|
// ioResult contains the result of an asynchronous IO operation
|
||||||
|
type ioResult struct {
|
||||||
|
bytes uint32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioOperation represents an outstanding asynchronous Win32 IO
|
||||||
|
type ioOperation struct {
|
||||||
|
o windows.Overlapped
|
||||||
|
ch chan ioResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func initIo() {
|
||||||
|
h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ioCompletionPort = h
|
||||||
|
go ioCompletionProcessor(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||||
|
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||||
|
type file struct {
|
||||||
|
handle windows.Handle
|
||||||
|
wg sync.WaitGroup
|
||||||
|
wgLock sync.RWMutex
|
||||||
|
closing atomic.Bool
|
||||||
|
socket bool
|
||||||
|
readDeadline deadlineHandler
|
||||||
|
writeDeadline deadlineHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineHandler struct {
|
||||||
|
setLock sync.Mutex
|
||||||
|
channel timeoutChan
|
||||||
|
channelLock sync.RWMutex
|
||||||
|
timer *time.Timer
|
||||||
|
timedout atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeFile makes a new file from an existing file handle
|
||||||
|
func makeFile(h windows.Handle) (*file, error) {
|
||||||
|
f := &file{handle: h}
|
||||||
|
ioInitOnce.Do(initIo)
|
||||||
|
_, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.readDeadline.channel = make(timeoutChan)
|
||||||
|
f.writeDeadline.channel = make(timeoutChan)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeHandle closes the resources associated with a Win32 handle
|
||||||
|
func (f *file) closeHandle() {
|
||||||
|
f.wgLock.Lock()
|
||||||
|
// Atomically set that we are closing, releasing the resources only once.
|
||||||
|
if f.closing.Swap(true) == false {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
// cancel all IO and wait for it to complete
|
||||||
|
windows.CancelIoEx(f.handle, nil)
|
||||||
|
f.wg.Wait()
|
||||||
|
// at this point, no new IO can start
|
||||||
|
windows.Close(f.handle)
|
||||||
|
f.handle = 0
|
||||||
|
} else {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes a file.
|
||||||
|
func (f *file) Close() error {
|
||||||
|
f.closeHandle()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareIo prepares for a new IO operation.
|
||||||
|
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||||
|
func (f *file) prepareIo() (*ioOperation, error) {
|
||||||
|
f.wgLock.RLock()
|
||||||
|
if f.closing.Load() {
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
return nil, os.ErrClosed
|
||||||
|
}
|
||||||
|
f.wg.Add(1)
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
c := &ioOperation{}
|
||||||
|
c.ch = make(chan ioResult)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioCompletionProcessor processes completed async IOs forever
|
||||||
|
func ioCompletionProcessor(h windows.Handle) {
|
||||||
|
for {
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var op *ioOperation
|
||||||
|
err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
|
||||||
|
if op == nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
op.ch <- ioResult{bytes, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
||||||
|
// the operation has actually completed.
|
||||||
|
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||||
|
if err != windows.ERROR_IO_PENDING {
|
||||||
|
return int(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.closing.Load() {
|
||||||
|
windows.CancelIoEx(f.handle, &c.o)
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout timeoutChan
|
||||||
|
if d != nil {
|
||||||
|
d.channelLock.Lock()
|
||||||
|
timeout = d.channel
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var r ioResult
|
||||||
|
select {
|
||||||
|
case r = <-c.ch:
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
|
if f.closing.Load() {
|
||||||
|
err = os.ErrClosed
|
||||||
|
}
|
||||||
|
} else if err != nil && f.socket {
|
||||||
|
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||||
|
var bytes, flags uint32
|
||||||
|
err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
windows.CancelIoEx(f.handle, &c.o)
|
||||||
|
r = <-c.ch
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
|
err = os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runtime.KeepAlive is needed, as c is passed via native
|
||||||
|
// code to ioCompletionProcessor, c must remain alive
|
||||||
|
// until the channel read is complete.
|
||||||
|
runtime.KeepAlive(c)
|
||||||
|
return int(r.bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from a file handle.
|
||||||
|
func (f *file) Read(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.readDeadline.timedout.Load() {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
|
||||||
|
// Handle EOF conditions.
|
||||||
|
if err == nil && n == 0 && len(b) != 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
} else if err == windows.ERROR_BROKEN_PIPE {
|
||||||
|
return 0, io.EOF
|
||||||
|
} else {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to a file handle.
|
||||||
|
func (f *file) Write(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.writeDeadline.timedout.Load() {
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *file) SetReadDeadline(deadline time.Time) error {
|
||||||
|
return f.readDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *file) SetWriteDeadline(deadline time.Time) error {
|
||||||
|
return f.writeDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *file) Flush() error {
|
||||||
|
return windows.FlushFileBuffers(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *file) Fd() uintptr {
|
||||||
|
return uintptr(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||||
|
d.setLock.Lock()
|
||||||
|
defer d.setLock.Unlock()
|
||||||
|
|
||||||
|
if d.timer != nil {
|
||||||
|
if !d.timer.Stop() {
|
||||||
|
<-d.channel
|
||||||
|
}
|
||||||
|
d.timer = nil
|
||||||
|
}
|
||||||
|
d.timedout.Store(false)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-d.channel:
|
||||||
|
d.channelLock.Lock()
|
||||||
|
d.channel = make(chan struct{})
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if deadline.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutIO := func() {
|
||||||
|
d.timedout.Store(true)
|
||||||
|
close(d.channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
duration := deadline.Sub(now)
|
||||||
|
if deadline.After(now) {
|
||||||
|
// Deadline is in the future, set a timer to wait
|
||||||
|
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||||
|
} else {
|
||||||
|
// Deadline is in the past. Cancel all pending IO now.
|
||||||
|
timeoutIO()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
485
ipc/namedpipe/namedpipe.go
Normal file
485
ipc/namedpipe/namedpipe.go
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
||||||
|
package namedpipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pipe struct {
|
||||||
|
*file
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type messageBytePipe struct {
|
||||||
|
pipe
|
||||||
|
writeClosed atomic.Bool
|
||||||
|
readEOF bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeAddress string
|
||||||
|
|
||||||
|
func (f *pipe) LocalAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *pipe) RemoteAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *pipe) SetDeadline(t time.Time) error {
|
||||||
|
f.SetReadDeadline(t)
|
||||||
|
f.SetWriteDeadline(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
|
func (f *messageBytePipe) CloseWrite() error {
|
||||||
|
if !f.writeClosed.CompareAndSwap(false, true) {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
err := f.file.Flush()
|
||||||
|
if err != nil {
|
||||||
|
f.writeClosed.Store(false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = f.file.Write(nil)
|
||||||
|
if err != nil {
|
||||||
|
f.writeClosed.Store(false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
|
// they are used to implement CloseWrite.
|
||||||
|
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
||||||
|
if f.writeClosed.Load() {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return f.file.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
||||||
|
// mode pipe will return io.EOF, as will all subsequent reads.
|
||||||
|
func (f *messageBytePipe) Read(b []byte) (int, error) {
|
||||||
|
if f.readEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n, err := f.file.Read(b)
|
||||||
|
if err == io.EOF {
|
||||||
|
// If this was the result of a zero-byte read, then
|
||||||
|
// it is possible that the read was due to a zero-size
|
||||||
|
// message. Since we are simulating CloseWrite with a
|
||||||
|
// zero-byte message, ensure that all future Read calls
|
||||||
|
// also return EOF.
|
||||||
|
f.readEOF = true
|
||||||
|
} else if err == windows.ERROR_MORE_DATA {
|
||||||
|
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||||
|
// and the message still has more bytes. Treat this as a success, since
|
||||||
|
// this package presents all named pipes as byte streams.
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *pipe) Handle() windows.Handle {
|
||||||
|
return f.handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s pipeAddress) Network() string {
|
||||||
|
return "pipe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s pipeAddress) String() string {
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
|
||||||
|
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
default:
|
||||||
|
path16, err := windows.UTF16PtrFromString(*path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
||||||
|
if err == nil {
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
if err != windows.ERROR_PIPE_BUSY {
|
||||||
|
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||||
|
}
|
||||||
|
// Wait 10 msec and try again. This is a rather simplistic
|
||||||
|
// view, as we always try each 10 milliseconds.
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialConfig exposes various options for use in Dial and DialContext.
|
||||||
|
type DialConfig struct {
|
||||||
|
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialTimeout connects to the specified named pipe by path, timing out if the
|
||||||
|
// connection takes longer than the specified duration. If timeout is zero, then
|
||||||
|
// we use a default timeout of 2 seconds.
|
||||||
|
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = time.Second * 2
|
||||||
|
}
|
||||||
|
absTimeout := time.Now().Add(timeout)
|
||||||
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
|
conn, err := config.DialContext(ctx, path)
|
||||||
|
if err == context.DeadlineExceeded {
|
||||||
|
return nil, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext attempts to connect to the specified named pipe by path.
|
||||||
|
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
|
var err error
|
||||||
|
var h windows.Handle
|
||||||
|
h, err = tryDialPipe(ctx, &path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.ExpectedOwner != nil {
|
||||||
|
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
realOwner, _, err := sd.Owner()
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !realOwner.Equals(config.ExpectedOwner) {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, windows.ERROR_ACCESS_DENIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags uint32
|
||||||
|
err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := makeFile(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the pipe is in message mode, return a message byte pipe, which
|
||||||
|
// supports CloseWrite.
|
||||||
|
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
|
||||||
|
return &messageBytePipe{
|
||||||
|
pipe: pipe{file: f, path: path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &pipe{file: f, path: path}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultDialer DialConfig
|
||||||
|
|
||||||
|
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
|
||||||
|
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialTimeout(path, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext calls DialConfig.DialContext using an empty configuration.
|
||||||
|
func DialContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
|
return defaultDialer.DialContext(ctx, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
type acceptResponse struct {
|
||||||
|
f *file
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeListener struct {
|
||||||
|
firstHandle windows.Handle
|
||||||
|
path string
|
||||||
|
config ListenConfig
|
||||||
|
acceptCh chan chan acceptResponse
|
||||||
|
closeCh chan int
|
||||||
|
doneCh chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
|
||||||
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var oa windows.OBJECT_ATTRIBUTES
|
||||||
|
oa.Length = uint32(unsafe.Sizeof(oa))
|
||||||
|
|
||||||
|
var ntPath windows.NTUnicodeString
|
||||||
|
if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
|
||||||
|
if ntstatus, ok := err.(windows.NTStatus); ok {
|
||||||
|
err = ntstatus.Errno()
|
||||||
|
}
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
|
||||||
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
|
// The security descriptor is only needed for the first pipe.
|
||||||
|
if isFirstPipe {
|
||||||
|
if sd != nil {
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
|
} else {
|
||||||
|
// Construct the default named pipe security descriptor.
|
||||||
|
var acl *windows.ACL
|
||||||
|
if err := windows.RtlDefaultNpAcl(&acl); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
||||||
|
sd, err = windows.NewSecurityDescriptor()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if err = sd.SetDACL(acl, true, false); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||||
|
if c.MessageMode {
|
||||||
|
typ |= windows.FILE_PIPE_MESSAGE_TYPE
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := uint32(windows.FILE_OPEN)
|
||||||
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
|
if isFirstPipe {
|
||||||
|
disposition = windows.FILE_CREATE
|
||||||
|
// By not asking for read or write access, the named pipe file system
|
||||||
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
|
// client connections until the next call with isFirstPipe == false.
|
||||||
|
access = windows.SYNCHRONIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := int64(-50 * 10000) // 50ms
|
||||||
|
|
||||||
|
var (
|
||||||
|
h windows.Handle
|
||||||
|
iosb windows.IO_STATUS_BLOCK
|
||||||
|
)
|
||||||
|
err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
|
||||||
|
if err != nil {
|
||||||
|
if ntstatus, ok := err.(windows.NTStatus); ok {
|
||||||
|
err = ntstatus.Errno()
|
||||||
|
}
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.KeepAlive(ntPath)
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) makeServerPipe() (*file, error) {
|
||||||
|
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := makeFile(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
|
||||||
|
p, err := l.makeServerPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the client to connect.
|
||||||
|
ch := make(chan error)
|
||||||
|
go func(p *file) {
|
||||||
|
ch <- connectPipe(p)
|
||||||
|
}(p)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ch:
|
||||||
|
if err != nil {
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
}
|
||||||
|
case <-l.closeCh:
|
||||||
|
// Abort the connect request by closing the handle.
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
err = <-ch
|
||||||
|
if err == nil || err == os.ErrClosed {
|
||||||
|
err = net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) listenerRoutine() {
|
||||||
|
closed := false
|
||||||
|
for !closed {
|
||||||
|
select {
|
||||||
|
case <-l.closeCh:
|
||||||
|
closed = true
|
||||||
|
case responseCh := <-l.acceptCh:
|
||||||
|
var (
|
||||||
|
p *file
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
p, err = l.makeConnectedServerPipe()
|
||||||
|
// If the connection was immediately closed by the client, try
|
||||||
|
// again.
|
||||||
|
if err != windows.ERROR_NO_DATA {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseCh <- acceptResponse{p, err}
|
||||||
|
closed = err == net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
windows.Close(l.firstHandle)
|
||||||
|
l.firstHandle = 0
|
||||||
|
// Notify Close and Accept callers that the handle has been closed.
|
||||||
|
close(l.doneCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenConfig contains configuration for the pipe listener.
|
||||||
|
type ListenConfig struct {
|
||||||
|
// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
|
||||||
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
|
||||||
|
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||||
|
// case the pipe is read in byte mode by default. The only practical difference in
|
||||||
|
// this implementation is that CloseWrite is only supported for message mode pipes;
|
||||||
|
// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
|
||||||
|
// transferred to the reader (and returned as io.EOF in this implementation)
|
||||||
|
// when the pipe is in message mode.
|
||||||
|
MessageMode bool
|
||||||
|
|
||||||
|
// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
|
||||||
|
InputBufferSize int32
|
||||||
|
|
||||||
|
// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
|
||||||
|
OutputBufferSize int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
||||||
|
// The pipe must not already exist.
|
||||||
|
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
|
||||||
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := &pipeListener{
|
||||||
|
firstHandle: h,
|
||||||
|
path: path,
|
||||||
|
config: *c,
|
||||||
|
acceptCh: make(chan chan acceptResponse),
|
||||||
|
closeCh: make(chan int),
|
||||||
|
doneCh: make(chan int),
|
||||||
|
}
|
||||||
|
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
||||||
|
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
|
||||||
|
path16, err := windows.UTF16PtrFromString(path)
|
||||||
|
if err == nil {
|
||||||
|
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
||||||
|
if err == nil {
|
||||||
|
windows.CloseHandle(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go l.listenerRoutine()
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultListener ListenConfig
|
||||||
|
|
||||||
|
// Listen calls ListenConfig.Listen using an empty configuration.
|
||||||
|
func Listen(path string) (net.Listener, error) {
|
||||||
|
return defaultListener.Listen(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectPipe(p *file) error {
|
||||||
|
c, err := p.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer p.wg.Done()
|
||||||
|
|
||||||
|
err = windows.ConnectNamedPipe(p.handle, &c.o)
|
||||||
|
_, err = p.asyncIo(c, nil, 0, err)
|
||||||
|
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) Accept() (net.Conn, error) {
|
||||||
|
ch := make(chan acceptResponse)
|
||||||
|
select {
|
||||||
|
case l.acceptCh <- ch:
|
||||||
|
response := <-ch
|
||||||
|
err := response.err
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if l.config.MessageMode {
|
||||||
|
return &messageBytePipe{
|
||||||
|
pipe: pipe{file: response.f, path: l.path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &pipe{file: response.f, path: l.path}, nil
|
||||||
|
case <-l.doneCh:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) Close() error {
|
||||||
|
select {
|
||||||
|
case l.closeCh <- 1:
|
||||||
|
<-l.doneCh
|
||||||
|
case <-l.doneCh:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *pipeListener) Addr() net.Addr {
|
||||||
|
return pipeAddress(l.path)
|
||||||
|
}
|
||||||
674
ipc/namedpipe/namedpipe_test.go
Normal file
674
ipc/namedpipe/namedpipe_test.go
Normal file
@@ -0,0 +1,674 @@
|
|||||||
|
// Copyright 2021 The Go Authors. All rights reserved.
|
||||||
|
// Copyright 2015 Microsoft
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package namedpipe_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/ipc/namedpipe"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomPipePath() string {
|
||||||
|
guid, err := windows.GenerateGUID()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPingPong(t *testing.T) {
|
||||||
|
const (
|
||||||
|
ping = 42
|
||||||
|
pong = 24
|
||||||
|
)
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
listener, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to listen on pipe: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
go func() {
|
||||||
|
incoming, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to accept pipe connection: %v", err)
|
||||||
|
}
|
||||||
|
defer incoming.Close()
|
||||||
|
var data [1]byte
|
||||||
|
_, err = incoming.Read(data[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to read ping from pipe: %v", err)
|
||||||
|
}
|
||||||
|
if data[0] != ping {
|
||||||
|
t.Fatalf("expected ping, got %d", data[0])
|
||||||
|
}
|
||||||
|
data[0] = pong
|
||||||
|
_, err = incoming.Write(data[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to write pong to pipe: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to dial pipe: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
client.SetDeadline(time.Now().Add(time.Second * 5))
|
||||||
|
var data [1]byte
|
||||||
|
data[0] = ping
|
||||||
|
_, err = client.Write(data[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to write ping to pipe: %v", err)
|
||||||
|
}
|
||||||
|
_, err = client.Read(data[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to read pong from pipe: %v", err)
|
||||||
|
}
|
||||||
|
if data[0] != pong {
|
||||||
|
t.Fatalf("expected pong, got %d", data[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialUnknownFailsImmediately(t *testing.T) {
|
||||||
|
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
|
||||||
|
if !errors.Is(err, syscall.ENOENT) {
|
||||||
|
t.Fatalf("expected ENOENT got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialListenerTimesOut(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
|
if err != os.ErrDeadlineExceeded {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialContextListenerTimesOut(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
d := 10 * time.Millisecond
|
||||||
|
ctx, _ := context.WithTimeout(context.Background(), d)
|
||||||
|
pipe, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
|
if err != context.DeadlineExceeded {
|
||||||
|
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialListenerGetsCancelled(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
ch := make(chan error)
|
||||||
|
go func(ctx context.Context, ch chan error) {
|
||||||
|
_, err := namedpipe.DialContext(ctx, pipePath)
|
||||||
|
ch <- err
|
||||||
|
}(ctx, ch)
|
||||||
|
time.Sleep(time.Millisecond * 30)
|
||||||
|
cancel()
|
||||||
|
err = <-ch
|
||||||
|
if err != context.Canceled {
|
||||||
|
t.Fatalf("expected context.Canceled, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
|
||||||
|
if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
|
||||||
|
t.Skip("dacls on named pipes are broken on wine")
|
||||||
|
}
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
sd, _ := windows.SecurityDescriptorFromString("D:")
|
||||||
|
l, err := (&namedpipe.ListenConfig{
|
||||||
|
SecurityDescriptor: sd,
|
||||||
|
}).Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
|
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
||||||
|
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &namedpipe.ListenConfig{}
|
||||||
|
}
|
||||||
|
l, err := cfg.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
type response struct {
|
||||||
|
c net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
ch := make(chan response)
|
||||||
|
go func() {
|
||||||
|
c, err := l.Accept()
|
||||||
|
ch <- response{c, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r := <-ch
|
||||||
|
if err = r.err; err != nil {
|
||||||
|
c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client = c
|
||||||
|
server = r.c
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadTimeout(t *testing.T) {
|
||||||
|
c, s, err := getConnection(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
_, err = c.Read(buf)
|
||||||
|
if err != os.ErrDeadlineExceeded {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func server(l net.Listener, ch chan int) {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
||||||
|
s, err := rw.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
_, err = rw.WriteString("got " + s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
err = rw.Flush()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
ch <- 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFullListenDialReadWrite(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
ch := make(chan int)
|
||||||
|
go server(l, ch)
|
||||||
|
|
||||||
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
||||||
|
_, err = rw.WriteString("hello world\n")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
err = rw.Flush()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := rw.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ms := "got hello world\n"
|
||||||
|
if s != ms {
|
||||||
|
t.Errorf("expected '%s', got '%s'", ms, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseAbortsListen(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := l.Accept()
|
||||||
|
ch <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(30 * time.Millisecond)
|
||||||
|
l.Close()
|
||||||
|
|
||||||
|
err = <-ch
|
||||||
|
if err != net.ErrClosed {
|
||||||
|
t.Fatalf("expected net.ErrClosed, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
|
||||||
|
b := make([]byte, 10)
|
||||||
|
w.Close()
|
||||||
|
n, err := r.Read(b)
|
||||||
|
if n > 0 {
|
||||||
|
t.Errorf("unexpected byte count %d", n)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Errorf("expected EOF: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseClientEOFServer(t *testing.T) {
|
||||||
|
c, s, err := getConnection(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
defer s.Close()
|
||||||
|
ensureEOFOnClose(t, c, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseServerEOFClient(t *testing.T) {
|
||||||
|
c, s, err := getConnection(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
defer s.Close()
|
||||||
|
ensureEOFOnClose(t, s, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseWriteEOF(t *testing.T) {
|
||||||
|
cfg := &namedpipe.ListenConfig{
|
||||||
|
MessageMode: true,
|
||||||
|
}
|
||||||
|
c, s, err := getConnection(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
type closeWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.(closeWriter).CloseWrite()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 10)
|
||||||
|
_, err = s.Read(b)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAcceptAfterCloseFails(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
l.Close()
|
||||||
|
_, err = l.Accept()
|
||||||
|
if err != net.ErrClosed {
|
||||||
|
t.Fatalf("expected net.ErrClosed, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDialTimesOutByDefault(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
|
||||||
|
if err == nil {
|
||||||
|
pipe.Close()
|
||||||
|
}
|
||||||
|
if err != os.ErrDeadlineExceeded {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeoutPendingRead(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
s.Close()
|
||||||
|
close(serverDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
clientErr := make(chan error)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, 10)
|
||||||
|
_, err = client.Read(buf)
|
||||||
|
clientErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
|
||||||
|
client.SetReadDeadline(time.Unix(1, 0))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-clientErr:
|
||||||
|
if err != os.ErrDeadlineExceeded {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatalf("timed out while waiting for read to cancel")
|
||||||
|
<-clientErr
|
||||||
|
}
|
||||||
|
<-serverDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTimeoutPendingWrite(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
serverDone := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
s.Close()
|
||||||
|
close(serverDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
clientErr := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err = client.Write([]byte("this should timeout"))
|
||||||
|
clientErr <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
|
||||||
|
client.SetWriteDeadline(time.Unix(1, 0))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-clientErr:
|
||||||
|
if err != os.ErrDeadlineExceeded {
|
||||||
|
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatalf("timed out while waiting for write to cancel")
|
||||||
|
<-clientErr
|
||||||
|
}
|
||||||
|
<-serverDone
|
||||||
|
}
|
||||||
|
|
||||||
|
type CloseWriter interface {
|
||||||
|
CloseWrite() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEchoWithMessaging(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := (&namedpipe.ListenConfig{
|
||||||
|
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
||||||
|
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
||||||
|
OutputBufferSize: 65536,
|
||||||
|
}).Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
listenerDone := make(chan bool)
|
||||||
|
clientDone := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
// server echo
|
||||||
|
conn, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
||||||
|
_, err = io.Copy(conn, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
conn.(CloseWriter).CloseWrite()
|
||||||
|
close(listenerDone)
|
||||||
|
}()
|
||||||
|
client, err := namedpipe.DialTimeout(pipePath, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// client read back
|
||||||
|
bytes := make([]byte, 2)
|
||||||
|
n, e := client.Read(bytes)
|
||||||
|
if e != nil {
|
||||||
|
t.Fatal(e)
|
||||||
|
}
|
||||||
|
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
|
||||||
|
t.Fatalf("expected 2 bytes, got %v", n)
|
||||||
|
}
|
||||||
|
close(clientDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := make([]byte, 2)
|
||||||
|
payload[0] = 0
|
||||||
|
payload[1] = 1
|
||||||
|
|
||||||
|
n, err := client.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != 2 {
|
||||||
|
t.Fatalf("expected 2 bytes, got %v", n)
|
||||||
|
}
|
||||||
|
client.(CloseWriter).CloseWrite()
|
||||||
|
<-listenerDone
|
||||||
|
<-clientDone
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectRace(t *testing.T) {
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
s, err := l.Accept()
|
||||||
|
if err == net.ErrClosed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessageReadMode(t *testing.T) {
|
||||||
|
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
|
||||||
|
t.Skipf("Skipping on Windows %d", maj)
|
||||||
|
}
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
defer wg.Wait()
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
msg := ([]byte)("hello world")
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
s, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = s.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
s.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
mode := uint32(windows.PIPE_READMODE_MESSAGE)
|
||||||
|
err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make([]byte, 1)
|
||||||
|
var vmsg []byte
|
||||||
|
for {
|
||||||
|
n, err := c.Read(ch)
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("expected 1, got %d", n)
|
||||||
|
}
|
||||||
|
vmsg = append(vmsg, ch[0])
|
||||||
|
}
|
||||||
|
if !bytes.Equal(msg, vmsg) {
|
||||||
|
t.Fatalf("expected %s, got %s", msg, vmsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenConnectRace(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping long race test")
|
||||||
|
}
|
||||||
|
pipePath := randomPipePath()
|
||||||
|
for i := 0; i < 50 && !t.Failed(); i++ {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
||||||
|
if err == nil {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
s, err := namedpipe.Listen(pipePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(i, err)
|
||||||
|
} else {
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,33 +1,21 @@
|
|||||||
// +build darwin freebsd openbsd
|
//go:build darwin || freebsd || openbsd
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var socketDirectory = "/var/run/wireguard"
|
|
||||||
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(unix.EIO)
|
|
||||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
|
||||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
|
||||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
|
||||||
socketName = "%s.sock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
listener net.Listener // unix socket listener
|
listener net.Listener // unix socket listener
|
||||||
connNew chan net.Conn
|
connNew chan net.Conn
|
||||||
@@ -66,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// wrap file in listener
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.FileListener(file)
|
listener, err := net.FileListener(file)
|
||||||
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
unixListener.SetUnlinkOnClose(true)
|
unixListener.SetUnlinkOnClose(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
socketPath := path.Join(
|
socketPath := sockPath(name)
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
@@ -119,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if kerr != nil || n != 1 {
|
if (kerr != nil || n != 1) && kerr != unix.EINTR {
|
||||||
if kerr != nil {
|
if kerr != nil {
|
||||||
l.connErr <- kerr
|
l.connErr <- kerr
|
||||||
} else {
|
} else {
|
||||||
@@ -146,58 +130,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 0755)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUmask := unix.Umask(0077)
|
|
||||||
listener, err := func() (*net.UnixListener, error) {
|
|
||||||
|
|
||||||
// initial connection attempt
|
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if socket already active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", socketPath)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup & attempt again
|
|
||||||
|
|
||||||
err = os.Remove(socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return net.ListenUnix("unix", addr)
|
|
||||||
}()
|
|
||||||
unix.Umask(oldUmask)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,29 +1,16 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
|
|
||||||
|
"github.com/Lordy82/wireguard-go/rwcancel"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
|
||||||
|
|
||||||
var socketDirectory = "/var/run/wireguard"
|
|
||||||
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(unix.EIO)
|
|
||||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
|
||||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
|
||||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
|
||||||
socketName = "%s.sock"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
@@ -64,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||||
|
|
||||||
// wrap file in listener
|
// wrap file in listener
|
||||||
|
|
||||||
listener, err := net.FileListener(file)
|
listener, err := net.FileListener(file)
|
||||||
@@ -84,10 +70,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
socketPath := path.Join(
|
socketPath := sockPath(name)
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
uapi.inotifyFd, err = unix.InotifyInit()
|
uapi.inotifyFd, err = unix.InotifyInit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -113,14 +96,15 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func(l *UAPIListener) {
|
go func(l *UAPIListener) {
|
||||||
var buff [0]byte
|
var buf [0]byte
|
||||||
for {
|
for {
|
||||||
|
defer uapi.inotifyRWCancel.Close()
|
||||||
// start with lstat to avoid race condition
|
// start with lstat to avoid race condition
|
||||||
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err := uapi.inotifyRWCancel.Read(buff[:])
|
_, err := uapi.inotifyRWCancel.Read(buf[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.connErr <- err
|
l.connErr <- err
|
||||||
return
|
return
|
||||||
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 0755)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUmask := unix.Umask(0077)
|
|
||||||
listener, err := func() (*net.UnixListener, error) {
|
|
||||||
|
|
||||||
// initial connection attempt
|
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if socket already active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", socketPath)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup & attempt again
|
|
||||||
|
|
||||||
err = os.Remove(socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return net.ListenUnix("unix", addr)
|
|
||||||
}()
|
|
||||||
unix.Umask(oldUmask)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
||||||
|
|||||||
66
ipc/uapi_unix.go
Normal file
66
ipc/uapi_unix.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build linux || darwin || freebsd || openbsd
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ipc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
IpcErrorIO = -int64(unix.EIO)
|
||||||
|
IpcErrorProtocol = -int64(unix.EPROTO)
|
||||||
|
IpcErrorInvalid = -int64(unix.EINVAL)
|
||||||
|
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||||
|
IpcErrorUnknown = -55 // ENOANO
|
||||||
|
)
|
||||||
|
|
||||||
|
// socketDirectory is variable because it is modified by a linker
|
||||||
|
// flag in wireguard-android.
|
||||||
|
var socketDirectory = "/var/run/wireguard"
|
||||||
|
|
||||||
|
func sockPath(iface string) string {
|
||||||
|
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
|
if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := sockPath(name)
|
||||||
|
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldUmask := unix.Umask(0o077)
|
||||||
|
defer unix.Umask(oldUmask)
|
||||||
|
|
||||||
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
if err == nil {
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test socket, if not in use cleanup and try again.
|
||||||
|
if _, err := net.Dial("unix", socketPath); err == nil {
|
||||||
|
return nil, errors.New("unix socket in use")
|
||||||
|
}
|
||||||
|
if err := os.Remove(socketPath); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
listener, err = net.ListenUnix("unix", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
15
ipc/uapi_wasm.go
Normal file
15
ipc/uapi_wasm.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ipc
|
||||||
|
|
||||||
|
// Made up sentinel error codes for {js,wasip1}/wasm.
|
||||||
|
const (
|
||||||
|
IpcErrorIO = 1
|
||||||
|
IpcErrorInvalid = 2
|
||||||
|
IpcErrorPortInUse = 3
|
||||||
|
IpcErrorUnknown = 4
|
||||||
|
IpcErrorProtocol = 5
|
||||||
|
)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
@@ -8,7 +8,8 @@ package ipc
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
"github.com/Lordy82/wireguard-go/ipc/namedpipe"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: replace these with actual standard windows error numbers from the win package
|
// TODO: replace these with actual standard windows error numbers from the win package
|
||||||
@@ -17,6 +18,7 @@ const (
|
|||||||
IpcErrorProtocol = -int64(71)
|
IpcErrorProtocol = -int64(71)
|
||||||
IpcErrorInvalid = -int64(22)
|
IpcErrorInvalid = -int64(22)
|
||||||
IpcErrorPortInUse = -int64(98)
|
IpcErrorPortInUse = -int64(98)
|
||||||
|
IpcErrorUnknown = -int64(55)
|
||||||
)
|
)
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
@@ -47,14 +49,20 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return l.listener.Addr()
|
return l.listener.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winpipe.PipeConfig{
|
listener, err := (&namedpipe.ListenConfig{
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}
|
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
|
||||||
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,322 +0,0 @@
|
|||||||
// +build windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
|
|
||||||
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
|
|
||||||
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
|
||||||
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
|
||||||
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
|
||||||
|
|
||||||
type atomicBool int32
|
|
||||||
|
|
||||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
|
||||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
|
||||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
|
||||||
func (b *atomicBool) swap(new bool) bool {
|
|
||||||
var newInt int32
|
|
||||||
if new {
|
|
||||||
newInt = 1
|
|
||||||
}
|
|
||||||
return atomic.SwapInt32((*int32)(b), newInt) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
|
|
||||||
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrFileClosed = errors.New("file has already been closed")
|
|
||||||
ErrTimeout = &timeoutError{}
|
|
||||||
)
|
|
||||||
|
|
||||||
type timeoutError struct{}
|
|
||||||
|
|
||||||
func (e *timeoutError) Error() string { return "i/o timeout" }
|
|
||||||
func (e *timeoutError) Timeout() bool { return true }
|
|
||||||
func (e *timeoutError) Temporary() bool { return true }
|
|
||||||
|
|
||||||
type timeoutChan chan struct{}
|
|
||||||
|
|
||||||
var ioInitOnce sync.Once
|
|
||||||
var ioCompletionPort syscall.Handle
|
|
||||||
|
|
||||||
// ioResult contains the result of an asynchronous IO operation
|
|
||||||
type ioResult struct {
|
|
||||||
bytes uint32
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ioOperation represents an outstanding asynchronous Win32 IO
|
|
||||||
type ioOperation struct {
|
|
||||||
o syscall.Overlapped
|
|
||||||
ch chan ioResult
|
|
||||||
}
|
|
||||||
|
|
||||||
func initIo() {
|
|
||||||
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ioCompletionPort = h
|
|
||||||
go ioCompletionProcessor(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
|
||||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
|
||||||
type win32File struct {
|
|
||||||
handle syscall.Handle
|
|
||||||
wg sync.WaitGroup
|
|
||||||
wgLock sync.RWMutex
|
|
||||||
closing atomicBool
|
|
||||||
socket bool
|
|
||||||
readDeadline deadlineHandler
|
|
||||||
writeDeadline deadlineHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
type deadlineHandler struct {
|
|
||||||
setLock sync.Mutex
|
|
||||||
channel timeoutChan
|
|
||||||
channelLock sync.RWMutex
|
|
||||||
timer *time.Timer
|
|
||||||
timedout atomicBool
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeWin32File makes a new win32File from an existing file handle
|
|
||||||
func makeWin32File(h syscall.Handle) (*win32File, error) {
|
|
||||||
f := &win32File{handle: h}
|
|
||||||
ioInitOnce.Do(initIo)
|
|
||||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f.readDeadline.channel = make(timeoutChan)
|
|
||||||
f.writeDeadline.channel = make(timeoutChan)
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
|
||||||
return makeWin32File(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeHandle closes the resources associated with a Win32 handle
|
|
||||||
func (f *win32File) closeHandle() {
|
|
||||||
f.wgLock.Lock()
|
|
||||||
// Atomically set that we are closing, releasing the resources only once.
|
|
||||||
if !f.closing.swap(true) {
|
|
||||||
f.wgLock.Unlock()
|
|
||||||
// cancel all IO and wait for it to complete
|
|
||||||
cancelIoEx(f.handle, nil)
|
|
||||||
f.wg.Wait()
|
|
||||||
// at this point, no new IO can start
|
|
||||||
syscall.Close(f.handle)
|
|
||||||
f.handle = 0
|
|
||||||
} else {
|
|
||||||
f.wgLock.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes a win32File.
|
|
||||||
func (f *win32File) Close() error {
|
|
||||||
f.closeHandle()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareIo prepares for a new IO operation.
|
|
||||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
|
||||||
func (f *win32File) prepareIo() (*ioOperation, error) {
|
|
||||||
f.wgLock.RLock()
|
|
||||||
if f.closing.isSet() {
|
|
||||||
f.wgLock.RUnlock()
|
|
||||||
return nil, ErrFileClosed
|
|
||||||
}
|
|
||||||
f.wg.Add(1)
|
|
||||||
f.wgLock.RUnlock()
|
|
||||||
c := &ioOperation{}
|
|
||||||
c.ch = make(chan ioResult)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ioCompletionProcessor processes completed async IOs forever
|
|
||||||
func ioCompletionProcessor(h syscall.Handle) {
|
|
||||||
for {
|
|
||||||
var bytes uint32
|
|
||||||
var key uintptr
|
|
||||||
var op *ioOperation
|
|
||||||
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
|
|
||||||
if op == nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
op.ch <- ioResult{bytes, err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
|
||||||
// the operation has actually completed.
|
|
||||||
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
|
||||||
if err != syscall.ERROR_IO_PENDING {
|
|
||||||
return int(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.closing.isSet() {
|
|
||||||
cancelIoEx(f.handle, &c.o)
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeout timeoutChan
|
|
||||||
if d != nil {
|
|
||||||
d.channelLock.Lock()
|
|
||||||
timeout = d.channel
|
|
||||||
d.channelLock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
var r ioResult
|
|
||||||
select {
|
|
||||||
case r = <-c.ch:
|
|
||||||
err = r.err
|
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
|
||||||
if f.closing.isSet() {
|
|
||||||
err = ErrFileClosed
|
|
||||||
}
|
|
||||||
} else if err != nil && f.socket {
|
|
||||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
|
||||||
var bytes, flags uint32
|
|
||||||
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
|
||||||
}
|
|
||||||
case <-timeout:
|
|
||||||
cancelIoEx(f.handle, &c.o)
|
|
||||||
r = <-c.ch
|
|
||||||
err = r.err
|
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
|
||||||
err = ErrTimeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runtime.KeepAlive is needed, as c is passed via native
|
|
||||||
// code to ioCompletionProcessor, c must remain alive
|
|
||||||
// until the channel read is complete.
|
|
||||||
runtime.KeepAlive(c)
|
|
||||||
return int(r.bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads from a file handle.
|
|
||||||
func (f *win32File) Read(b []byte) (int, error) {
|
|
||||||
c, err := f.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer f.wg.Done()
|
|
||||||
|
|
||||||
if f.readDeadline.timedout.isSet() {
|
|
||||||
return 0, ErrTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytes uint32
|
|
||||||
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
|
|
||||||
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
|
||||||
runtime.KeepAlive(b)
|
|
||||||
|
|
||||||
// Handle EOF conditions.
|
|
||||||
if err == nil && n == 0 && len(b) != 0 {
|
|
||||||
return 0, io.EOF
|
|
||||||
} else if err == syscall.ERROR_BROKEN_PIPE {
|
|
||||||
return 0, io.EOF
|
|
||||||
} else {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes to a file handle.
|
|
||||||
func (f *win32File) Write(b []byte) (int, error) {
|
|
||||||
c, err := f.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer f.wg.Done()
|
|
||||||
|
|
||||||
if f.writeDeadline.timedout.isSet() {
|
|
||||||
return 0, ErrTimeout
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytes uint32
|
|
||||||
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
|
|
||||||
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
|
||||||
runtime.KeepAlive(b)
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32File) SetReadDeadline(deadline time.Time) error {
|
|
||||||
return f.readDeadline.set(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
|
||||||
return f.writeDeadline.set(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32File) Flush() error {
|
|
||||||
return syscall.FlushFileBuffers(f.handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32File) Fd() uintptr {
|
|
||||||
return uintptr(f.handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
|
||||||
d.setLock.Lock()
|
|
||||||
defer d.setLock.Unlock()
|
|
||||||
|
|
||||||
if d.timer != nil {
|
|
||||||
if !d.timer.Stop() {
|
|
||||||
<-d.channel
|
|
||||||
}
|
|
||||||
d.timer = nil
|
|
||||||
}
|
|
||||||
d.timedout.setFalse()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-d.channel:
|
|
||||||
d.channelLock.Lock()
|
|
||||||
d.channel = make(chan struct{})
|
|
||||||
d.channelLock.Unlock()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if deadline.IsZero() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
timeoutIO := func() {
|
|
||||||
d.timedout.setTrue()
|
|
||||||
close(d.channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
duration := deadline.Sub(now)
|
|
||||||
if deadline.After(now) {
|
|
||||||
// Deadline is in the future, set a timer to wait
|
|
||||||
d.timer = time.AfterFunc(duration, timeoutIO)
|
|
||||||
} else {
|
|
||||||
// Deadline is in the past. Cancel all pending IO now.
|
|
||||||
timeoutIO()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go
|
|
||||||
@@ -1,532 +0,0 @@
|
|||||||
// +build windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
|
|
||||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
|
|
||||||
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
|
|
||||||
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
|
||||||
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
|
||||||
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
|
||||||
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
|
||||||
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
|
||||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
|
||||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
|
||||||
|
|
||||||
type ioStatusBlock struct {
|
|
||||||
Status, Information uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type objectAttributes struct {
|
|
||||||
Length uintptr
|
|
||||||
RootDirectory uintptr
|
|
||||||
ObjectName *unicodeString
|
|
||||||
Attributes uintptr
|
|
||||||
SecurityDescriptor *securityDescriptor
|
|
||||||
SecurityQoS uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type unicodeString struct {
|
|
||||||
Length uint16
|
|
||||||
MaximumLength uint16
|
|
||||||
Buffer uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type securityDescriptor struct {
|
|
||||||
Revision byte
|
|
||||||
Sbz1 byte
|
|
||||||
Control uint16
|
|
||||||
Owner uintptr
|
|
||||||
Group uintptr
|
|
||||||
Sacl uintptr
|
|
||||||
Dacl uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type ntstatus int32
|
|
||||||
|
|
||||||
func (status ntstatus) Err() error {
|
|
||||||
if status >= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return rtlNtStatusToDosError(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
cERROR_PIPE_BUSY = syscall.Errno(231)
|
|
||||||
cERROR_NO_DATA = syscall.Errno(232)
|
|
||||||
cERROR_PIPE_CONNECTED = syscall.Errno(535)
|
|
||||||
cERROR_SEM_TIMEOUT = syscall.Errno(121)
|
|
||||||
|
|
||||||
cSECURITY_SQOS_PRESENT = 0x100000
|
|
||||||
cSECURITY_ANONYMOUS = 0
|
|
||||||
|
|
||||||
cPIPE_TYPE_MESSAGE = 4
|
|
||||||
|
|
||||||
cPIPE_READMODE_MESSAGE = 2
|
|
||||||
|
|
||||||
cFILE_OPEN = 1
|
|
||||||
cFILE_CREATE = 2
|
|
||||||
|
|
||||||
cFILE_PIPE_MESSAGE_TYPE = 1
|
|
||||||
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
|
||||||
|
|
||||||
cSE_DACL_PRESENT = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
|
|
||||||
// This error should match net.errClosing since docker takes a dependency on its text.
|
|
||||||
ErrPipeListenerClosed = errors.New("use of closed network connection")
|
|
||||||
|
|
||||||
errPipeWriteClosed = errors.New("pipe has been closed for write")
|
|
||||||
)
|
|
||||||
|
|
||||||
type win32Pipe struct {
|
|
||||||
*win32File
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
|
|
||||||
type win32MessageBytePipe struct {
|
|
||||||
win32Pipe
|
|
||||||
writeClosed bool
|
|
||||||
readEOF bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type pipeAddress string
|
|
||||||
|
|
||||||
func (f *win32Pipe) LocalAddr() net.Addr {
|
|
||||||
return pipeAddress(f.path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32Pipe) RemoteAddr() net.Addr {
|
|
||||||
return pipeAddress(f.path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *win32Pipe) SetDeadline(t time.Time) error {
|
|
||||||
f.SetReadDeadline(t)
|
|
||||||
f.SetWriteDeadline(t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
|
||||||
func (f *win32MessageBytePipe) CloseWrite() error {
|
|
||||||
if f.writeClosed {
|
|
||||||
return errPipeWriteClosed
|
|
||||||
}
|
|
||||||
err := f.win32File.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = f.win32File.Write(nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
f.writeClosed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
|
||||||
// they are used to implement CloseWrite().
|
|
||||||
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
|
|
||||||
if f.writeClosed {
|
|
||||||
return 0, errPipeWriteClosed
|
|
||||||
}
|
|
||||||
if len(b) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return f.win32File.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
|
||||||
// mode pipe will return io.EOF, as will all subsequent reads.
|
|
||||||
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
|
||||||
if f.readEOF {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
n, err := f.win32File.Read(b)
|
|
||||||
if err == io.EOF {
|
|
||||||
// If this was the result of a zero-byte read, then
|
|
||||||
// it is possible that the read was due to a zero-size
|
|
||||||
// message. Since we are simulating CloseWrite with a
|
|
||||||
// zero-byte message, ensure that all future Read() calls
|
|
||||||
// also return EOF.
|
|
||||||
f.readEOF = true
|
|
||||||
} else if err == syscall.ERROR_MORE_DATA {
|
|
||||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
|
||||||
// and the message still has more bytes. Treat this as a success, since
|
|
||||||
// this package presents all named pipes as byte streams.
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s pipeAddress) Network() string {
|
|
||||||
return "pipe"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s pipeAddress) String() string {
|
|
||||||
return string(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
|
||||||
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return syscall.Handle(0), ctx.Err()
|
|
||||||
default:
|
|
||||||
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
|
||||||
if err == nil {
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
if err != cERROR_PIPE_BUSY {
|
|
||||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
|
||||||
}
|
|
||||||
// Wait 10 msec and try again. This is a rather simplistic
|
|
||||||
// view, as we always try each 10 milliseconds.
|
|
||||||
time.Sleep(time.Millisecond * 10)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
|
||||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
|
||||||
func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
|
|
||||||
var absTimeout time.Time
|
|
||||||
if timeout != nil {
|
|
||||||
absTimeout = time.Now().Add(*timeout)
|
|
||||||
} else {
|
|
||||||
absTimeout = time.Now().Add(time.Second * 2)
|
|
||||||
}
|
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
|
||||||
conn, err := DialPipeContext(ctx, path, expectedOwner)
|
|
||||||
if err == context.DeadlineExceeded {
|
|
||||||
return nil, ErrTimeout
|
|
||||||
}
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
|
||||||
// cancellation or timeout.
|
|
||||||
func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
|
|
||||||
var err error
|
|
||||||
var h syscall.Handle
|
|
||||||
h, err = tryDialPipe(ctx, &path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if expectedOwner != nil {
|
|
||||||
var realOwner *syscall.SID
|
|
||||||
var realSd uintptr
|
|
||||||
err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
|
|
||||||
if err != nil {
|
|
||||||
syscall.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer localFree(realSd)
|
|
||||||
if !equalSid(realOwner, expectedOwner) {
|
|
||||||
syscall.Close(h)
|
|
||||||
return nil, syscall.ERROR_ACCESS_DENIED
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var flags uint32
|
|
||||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
syscall.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := makeWin32File(h)
|
|
||||||
if err != nil {
|
|
||||||
syscall.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the pipe is in message mode, return a message byte pipe, which
|
|
||||||
// supports CloseWrite().
|
|
||||||
if flags&cPIPE_TYPE_MESSAGE != 0 {
|
|
||||||
return &win32MessageBytePipe{
|
|
||||||
win32Pipe: win32Pipe{win32File: f, path: path},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &win32Pipe{win32File: f, path: path}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type acceptResponse struct {
|
|
||||||
f *win32File
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type win32PipeListener struct {
|
|
||||||
firstHandle syscall.Handle
|
|
||||||
path string
|
|
||||||
config PipeConfig
|
|
||||||
acceptCh chan (chan acceptResponse)
|
|
||||||
closeCh chan int
|
|
||||||
doneCh chan int
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
|
||||||
path16, err := syscall.UTF16FromString(path)
|
|
||||||
if err != nil {
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
var oa objectAttributes
|
|
||||||
oa.Length = unsafe.Sizeof(oa)
|
|
||||||
|
|
||||||
var ntPath unicodeString
|
|
||||||
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
defer localFree(ntPath.Buffer)
|
|
||||||
oa.ObjectName = &ntPath
|
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
|
||||||
if first {
|
|
||||||
if sd != nil {
|
|
||||||
len := uint32(len(sd))
|
|
||||||
sdb := localAlloc(0, len)
|
|
||||||
defer localFree(sdb)
|
|
||||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
|
||||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
|
||||||
} else {
|
|
||||||
// Construct the default named pipe security descriptor.
|
|
||||||
var dacl uintptr
|
|
||||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
|
||||||
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
|
||||||
}
|
|
||||||
defer localFree(dacl)
|
|
||||||
|
|
||||||
sdb := &securityDescriptor{
|
|
||||||
Revision: 1,
|
|
||||||
Control: cSE_DACL_PRESENT,
|
|
||||||
Dacl: dacl,
|
|
||||||
}
|
|
||||||
oa.SecurityDescriptor = sdb
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
|
|
||||||
if c.MessageMode {
|
|
||||||
typ |= cFILE_PIPE_MESSAGE_TYPE
|
|
||||||
}
|
|
||||||
|
|
||||||
disposition := uint32(cFILE_OPEN)
|
|
||||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
|
|
||||||
if first {
|
|
||||||
disposition = cFILE_CREATE
|
|
||||||
// By not asking for read or write access, the named pipe file system
|
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
|
||||||
// client connections until the next call with first == false.
|
|
||||||
access = syscall.SYNCHRONIZE
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := int64(-50 * 10000) // 50ms
|
|
||||||
|
|
||||||
var (
|
|
||||||
h syscall.Handle
|
|
||||||
iosb ioStatusBlock
|
|
||||||
)
|
|
||||||
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
|
||||||
if err != nil {
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
runtime.KeepAlive(ntPath)
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
|
||||||
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f, err := makeWin32File(h)
|
|
||||||
if err != nil {
|
|
||||||
syscall.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
|
|
||||||
p, err := l.makeServerPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the client to connect.
|
|
||||||
ch := make(chan error)
|
|
||||||
go func(p *win32File) {
|
|
||||||
ch <- connectPipe(p)
|
|
||||||
}(p)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-ch:
|
|
||||||
if err != nil {
|
|
||||||
p.Close()
|
|
||||||
p = nil
|
|
||||||
}
|
|
||||||
case <-l.closeCh:
|
|
||||||
// Abort the connect request by closing the handle.
|
|
||||||
p.Close()
|
|
||||||
p = nil
|
|
||||||
err = <-ch
|
|
||||||
if err == nil || err == ErrFileClosed {
|
|
||||||
err = ErrPipeListenerClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) listenerRoutine() {
|
|
||||||
closed := false
|
|
||||||
for !closed {
|
|
||||||
select {
|
|
||||||
case <-l.closeCh:
|
|
||||||
closed = true
|
|
||||||
case responseCh := <-l.acceptCh:
|
|
||||||
var (
|
|
||||||
p *win32File
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
p, err = l.makeConnectedServerPipe()
|
|
||||||
// If the connection was immediately closed by the client, try
|
|
||||||
// again.
|
|
||||||
if err != cERROR_NO_DATA {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
responseCh <- acceptResponse{p, err}
|
|
||||||
closed = err == ErrPipeListenerClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
syscall.Close(l.firstHandle)
|
|
||||||
l.firstHandle = 0
|
|
||||||
// Notify Close() and Accept() callers that the handle has been closed.
|
|
||||||
close(l.doneCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PipeConfig contain configuration for the pipe listener.
|
|
||||||
type PipeConfig struct {
|
|
||||||
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
|
||||||
SecurityDescriptor string
|
|
||||||
|
|
||||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
|
||||||
// case the pipe is read in byte mode by default. The only practical difference in
|
|
||||||
// this implementation is that CloseWrite() is only supported for message mode pipes;
|
|
||||||
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
|
|
||||||
// transferred to the reader (and returned as io.EOF in this implementation)
|
|
||||||
// when the pipe is in message mode.
|
|
||||||
MessageMode bool
|
|
||||||
|
|
||||||
// InputBufferSize specifies the size the input buffer, in bytes.
|
|
||||||
InputBufferSize int32
|
|
||||||
|
|
||||||
// OutputBufferSize specifies the size the input buffer, in bytes.
|
|
||||||
OutputBufferSize int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
|
||||||
// The pipe must not already exist.
|
|
||||||
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
|
||||||
var (
|
|
||||||
sd []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c == nil {
|
|
||||||
c = &PipeConfig{}
|
|
||||||
}
|
|
||||||
if c.SecurityDescriptor != "" {
|
|
||||||
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h, err := makeServerPipeHandle(path, sd, c, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
l := &win32PipeListener{
|
|
||||||
firstHandle: h,
|
|
||||||
path: path,
|
|
||||||
config: *c,
|
|
||||||
acceptCh: make(chan (chan acceptResponse)),
|
|
||||||
closeCh: make(chan int),
|
|
||||||
doneCh: make(chan int),
|
|
||||||
}
|
|
||||||
go l.listenerRoutine()
|
|
||||||
return l, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func connectPipe(p *win32File) error {
|
|
||||||
c, err := p.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer p.wg.Done()
|
|
||||||
|
|
||||||
err = connectNamedPipe(p.handle, &c.o)
|
|
||||||
_, err = p.asyncIo(c, nil, 0, err)
|
|
||||||
if err != nil && err != cERROR_PIPE_CONNECTED {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) Accept() (net.Conn, error) {
|
|
||||||
ch := make(chan acceptResponse)
|
|
||||||
select {
|
|
||||||
case l.acceptCh <- ch:
|
|
||||||
response := <-ch
|
|
||||||
err := response.err
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if l.config.MessageMode {
|
|
||||||
return &win32MessageBytePipe{
|
|
||||||
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &win32Pipe{win32File: response.f, path: l.path}, nil
|
|
||||||
case <-l.doneCh:
|
|
||||||
return nil, ErrPipeListenerClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) Close() error {
|
|
||||||
select {
|
|
||||||
case l.closeCh <- 1:
|
|
||||||
<-l.doneCh
|
|
||||||
case <-l.doneCh:
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *win32PipeListener) Addr() net.Addr {
|
|
||||||
return pipeAddress(l.path)
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
// +build windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
|
||||||
//sys localFree(mem uintptr) = LocalFree
|
|
||||||
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
|
||||||
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
|
|
||||||
//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
|
|
||||||
|
|
||||||
const (
|
|
||||||
SE_FILE_OBJECT = 1
|
|
||||||
OWNER_SECURITY_INFORMATION = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
|
||||||
var sdBuffer uintptr
|
|
||||||
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer localFree(sdBuffer)
|
|
||||||
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
|
||||||
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
|
||||||
return sd, nil
|
|
||||||
}
|
|
||||||
@@ -1,290 +0,0 @@
|
|||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return nil
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
|
||||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
|
||||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
|
||||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
|
||||||
|
|
||||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
|
||||||
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
|
||||||
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
|
||||||
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
|
||||||
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
|
||||||
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
|
||||||
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
|
||||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
|
||||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
|
||||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
|
||||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
|
||||||
procLocalFree = modkernel32.NewProc("LocalFree")
|
|
||||||
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
|
||||||
procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
|
|
||||||
procEqualSid = modadvapi32.NewProc("EqualSid")
|
|
||||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
|
||||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
|
||||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
|
||||||
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
|
||||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
|
||||||
)
|
|
||||||
|
|
||||||
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
|
||||||
var _p0 *uint16
|
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
|
||||||
handle = syscall.Handle(r0)
|
|
||||||
if handle == syscall.InvalidHandle {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
|
||||||
var _p0 *uint16
|
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
|
||||||
handle = syscall.Handle(r0)
|
|
||||||
if handle == syscall.InvalidHandle {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
|
||||||
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
|
|
||||||
ptr = uintptr(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
|
||||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
|
||||||
status = ntstatus(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
|
|
||||||
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
winerr = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
|
|
||||||
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
|
|
||||||
status = ntstatus(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
|
|
||||||
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
|
|
||||||
status = ntstatus(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
|
|
||||||
var _p0 *uint16
|
|
||||||
_p0, err = syscall.UTF16PtrFromString(str)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func localFree(mem uintptr) {
|
|
||||||
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
|
|
||||||
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
|
|
||||||
len = uint32(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
|
|
||||||
r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
ret = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
|
|
||||||
r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
|
|
||||||
isEqual = r0 != 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
|
|
||||||
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
|
||||||
newport = syscall.Handle(r0)
|
|
||||||
if newport == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if wait {
|
|
||||||
_p0 = 1
|
|
||||||
} else {
|
|
||||||
_p0 = 0
|
|
||||||
}
|
|
||||||
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
87
main.go
87
main.go
@@ -1,8 +1,8 @@
|
|||||||
// +build !windows
|
//go:build !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
@@ -13,11 +13,12 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"github.com/Lordy82/wireguard-go/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/Lordy82/wireguard-go/ipc"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -32,32 +33,33 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func printUsage() {
|
func printUsage() {
|
||||||
fmt.Printf("usage:\n")
|
fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
|
||||||
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func warning() {
|
func warning() {
|
||||||
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
switch runtime.GOOS {
|
||||||
|
case "linux", "freebsd", "openbsd":
|
||||||
|
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
|
fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
|
||||||
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
|
fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
|
||||||
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
|
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||||
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
|
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
||||||
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
|
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
|
||||||
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
|
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
|
||||||
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
||||||
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", device.WireGuardGoVersion, runtime.GOOS, runtime.GOARCH)
|
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,21 +99,19 @@ func main() {
|
|||||||
|
|
||||||
logLevel := func() int {
|
logLevel := func() int {
|
||||||
switch os.Getenv("LOG_LEVEL") {
|
switch os.Getenv("LOG_LEVEL") {
|
||||||
case "debug":
|
case "verbose", "debug":
|
||||||
return device.LogLevelDebug
|
return device.LogLevelVerbose
|
||||||
case "info":
|
|
||||||
return device.LogLevelInfo
|
|
||||||
case "error":
|
case "error":
|
||||||
return device.LogLevelError
|
return device.LogLevelError
|
||||||
case "silent":
|
case "silent":
|
||||||
return device.LogLevelSilent
|
return device.LogLevelSilent
|
||||||
}
|
}
|
||||||
return device.LogLevelInfo
|
return device.LogLevelError
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// open TUN device (or use supplied fd)
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
tun, err := func() (tun.Device, error) {
|
tdev, err := func() (tun.Device, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
if tunFdStr == "" {
|
if tunFdStr == "" {
|
||||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||||
@@ -124,7 +124,7 @@ func main() {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = syscall.SetNonblock(int(fd), true)
|
err = unix.SetNonblock(int(fd), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func main() {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tdev.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
interfaceName = realInterfaceName
|
interfaceName = realInterfaceName
|
||||||
}
|
}
|
||||||
@@ -145,12 +145,10 @@ func main() {
|
|||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||||
|
|
||||||
logger.Debug.Println("Debug log enabled")
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to create TUN device:", err)
|
logger.Errorf("Failed to create TUN device: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,9 +169,8 @@ func main() {
|
|||||||
|
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("UAPI listen error:", err)
|
logger.Errorf("UAPI listen error: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -199,7 +196,7 @@ func main() {
|
|||||||
files[0], // stdin
|
files[0], // stdin
|
||||||
files[1], // stdout
|
files[1], // stdout
|
||||||
files[2], // stderr
|
files[2], // stderr
|
||||||
tun.File(),
|
tdev.File(),
|
||||||
fileUAPI,
|
fileUAPI,
|
||||||
},
|
},
|
||||||
Dir: ".",
|
Dir: ".",
|
||||||
@@ -208,7 +205,7 @@ func main() {
|
|||||||
|
|
||||||
path, err := os.Executable()
|
path, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to determine executable:", err)
|
logger.Errorf("Failed to determine executable: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,23 +215,23 @@ func main() {
|
|||||||
attr,
|
attr,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to daemonize:", err)
|
logger.Errorf("Failed to daemonize: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
process.Release()
|
process.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device := device.NewDevice(tun, logger)
|
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
||||||
|
|
||||||
logger.Info.Println("Device started")
|
logger.Verbosef("Device started")
|
||||||
|
|
||||||
errs := make(chan error)
|
errs := make(chan error)
|
||||||
term := make(chan os.Signal, 1)
|
term := make(chan os.Signal, 1)
|
||||||
|
|
||||||
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to listen on uapi socket:", err)
|
logger.Errorf("Failed to listen on uapi socket: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -249,11 +246,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger.Info.Println("UAPI listener started")
|
logger.Verbosef("UAPI listener started")
|
||||||
|
|
||||||
// wait for program to terminate
|
// wait for program to terminate
|
||||||
|
|
||||||
signal.Notify(term, syscall.SIGTERM)
|
signal.Notify(term, unix.SIGTERM)
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -267,5 +264,5 @@ func main() {
|
|||||||
uapi.Close()
|
uapi.Close()
|
||||||
device.Close()
|
device.Close()
|
||||||
|
|
||||||
logger.Info.Println("Shutting down")
|
logger.Verbosef("Shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
@@ -9,12 +9,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"github.com/Lordy82/wireguard-go/conn"
|
||||||
|
"github.com/Lordy82/wireguard-go/device"
|
||||||
|
"github.com/Lordy82/wireguard-go/ipc"
|
||||||
|
"github.com/Lordy82/wireguard-go/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -31,30 +32,33 @@ func main() {
|
|||||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
|
||||||
|
|
||||||
logger := device.NewLogger(
|
logger := device.NewLogger(
|
||||||
device.LogLevelDebug,
|
device.LogLevelVerbose,
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
logger.Verbosef("Starting wireguard-go version %s", Version)
|
||||||
logger.Debug.Println("Debug log enabled")
|
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tun.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
interfaceName = realInterfaceName
|
interfaceName = realInterfaceName
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
logger.Error.Println("Failed to create TUN device:", err)
|
logger.Errorf("Failed to create TUN device: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
device := device.NewDevice(tun, logger)
|
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
||||||
device.Up()
|
err = device.Up()
|
||||||
logger.Info.Println("Device started")
|
if err != nil {
|
||||||
|
logger.Errorf("Failed to bring up device: %v", err)
|
||||||
|
os.Exit(ExitSetupFailed)
|
||||||
|
}
|
||||||
|
logger.Verbosef("Device started")
|
||||||
|
|
||||||
uapi, err := ipc.UAPIListen(interfaceName)
|
uapi, err := ipc.UAPIListen(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error.Println("Failed to listen on uapi socket:", err)
|
logger.Errorf("Failed to listen on uapi socket: %v", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,13 +75,13 @@ func main() {
|
|||||||
go device.IpcHandle(conn)
|
go device.IpcHandle(conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
logger.Info.Println("UAPI listener started")
|
logger.Verbosef("UAPI listener started")
|
||||||
|
|
||||||
// wait for program to terminate
|
// wait for program to terminate
|
||||||
|
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
signal.Notify(term, os.Kill)
|
signal.Notify(term, os.Kill)
|
||||||
signal.Notify(term, syscall.SIGTERM)
|
signal.Notify(term, windows.SIGTERM)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-term:
|
case <-term:
|
||||||
@@ -90,5 +94,5 @@ func main() {
|
|||||||
uapi.Close()
|
uapi.Close()
|
||||||
device.Close()
|
device.Close()
|
||||||
|
|
||||||
logger.Info.Println("Shutting down")
|
logger.Verbosef("Shutting down")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -20,21 +20,22 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RatelimiterEntry struct {
|
type RatelimiterEntry struct {
|
||||||
sync.Mutex
|
mu sync.Mutex
|
||||||
lastTime time.Time
|
lastTime time.Time
|
||||||
tokens int64
|
tokens int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ratelimiter struct {
|
type Ratelimiter struct {
|
||||||
sync.RWMutex
|
mu sync.RWMutex
|
||||||
stopReset chan struct{}
|
timeNow func() time.Time
|
||||||
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
|
||||||
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
stopReset chan struct{} // send to reset, close to stop
|
||||||
|
table map[netip.Addr]*RatelimiterEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Close() {
|
func (rate *Ratelimiter) Close() {
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
if rate.stopReset != nil {
|
if rate.stopReset != nil {
|
||||||
close(rate.stopReset)
|
close(rate.stopReset)
|
||||||
@@ -42,111 +43,83 @@ func (rate *Ratelimiter) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Init() {
|
func (rate *Ratelimiter) Init() {
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
|
if rate.timeNow == nil {
|
||||||
|
rate.timeNow = time.Now
|
||||||
|
}
|
||||||
|
|
||||||
// stop any ongoing garbage collection routine
|
// stop any ongoing garbage collection routine
|
||||||
|
|
||||||
if rate.stopReset != nil {
|
if rate.stopReset != nil {
|
||||||
close(rate.stopReset)
|
close(rate.stopReset)
|
||||||
}
|
}
|
||||||
|
|
||||||
rate.stopReset = make(chan struct{})
|
rate.stopReset = make(chan struct{})
|
||||||
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
rate.table = make(map[netip.Addr]*RatelimiterEntry)
|
||||||
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
|
||||||
|
|
||||||
// start garbage collection routine
|
stopReset := rate.stopReset // store in case Init is called again.
|
||||||
|
|
||||||
|
// Start garbage collection routine.
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(time.Second)
|
ticker := time.NewTicker(time.Second)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-rate.stopReset:
|
case _, ok := <-stopReset:
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
if ok {
|
if !ok {
|
||||||
ticker = time.NewTicker(time.Second)
|
|
||||||
} else {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ticker = time.NewTicker(time.Second)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
func() {
|
if rate.cleanup() {
|
||||||
rate.Lock()
|
ticker.Stop()
|
||||||
defer rate.Unlock()
|
}
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv4 {
|
|
||||||
entry.Lock()
|
|
||||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv4, key)
|
|
||||||
}
|
|
||||||
entry.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv6 {
|
|
||||||
entry.Lock()
|
|
||||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv6, key)
|
|
||||||
}
|
|
||||||
entry.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
|
|
||||||
ticker.Stop()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
func (rate *Ratelimiter) cleanup() (empty bool) {
|
||||||
var entry *RatelimiterEntry
|
rate.mu.Lock()
|
||||||
var keyIPv4 [net.IPv4len]byte
|
defer rate.mu.Unlock()
|
||||||
var keyIPv6 [net.IPv6len]byte
|
|
||||||
|
|
||||||
// lookup entry
|
for key, entry := range rate.table {
|
||||||
|
entry.mu.Lock()
|
||||||
IPv4 := ip.To4()
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
IPv6 := ip.To16()
|
delete(rate.table, key)
|
||||||
|
}
|
||||||
rate.RLock()
|
entry.mu.Unlock()
|
||||||
|
|
||||||
if IPv4 != nil {
|
|
||||||
copy(keyIPv4[:], IPv4)
|
|
||||||
entry = rate.tableIPv4[keyIPv4]
|
|
||||||
} else {
|
|
||||||
copy(keyIPv6[:], IPv6)
|
|
||||||
entry = rate.tableIPv6[keyIPv6]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rate.RUnlock()
|
return len(rate.table) == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
|
||||||
|
var entry *RatelimiterEntry
|
||||||
|
// lookup entry
|
||||||
|
rate.mu.RLock()
|
||||||
|
entry = rate.table[ip]
|
||||||
|
rate.mu.RUnlock()
|
||||||
|
|
||||||
// make new entry if not found
|
// make new entry if not found
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = new(RatelimiterEntry)
|
entry = new(RatelimiterEntry)
|
||||||
entry.tokens = maxTokens - packetCost
|
entry.tokens = maxTokens - packetCost
|
||||||
entry.lastTime = time.Now()
|
entry.lastTime = rate.timeNow()
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
if IPv4 != nil {
|
rate.table[ip] = entry
|
||||||
rate.tableIPv4[keyIPv4] = entry
|
if len(rate.table) == 1 {
|
||||||
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
rate.stopReset <- struct{}{}
|
||||||
rate.stopReset <- struct{}{}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rate.tableIPv6[keyIPv6] = entry
|
|
||||||
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
|
|
||||||
rate.stopReset <- struct{}{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
rate.Unlock()
|
rate.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tokens to entry
|
// add tokens to entry
|
||||||
|
entry.mu.Lock()
|
||||||
entry.Lock()
|
now := rate.timeNow()
|
||||||
now := time.Now()
|
|
||||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||||
entry.lastTime = now
|
entry.lastTime = now
|
||||||
if entry.tokens > maxTokens {
|
if entry.tokens > maxTokens {
|
||||||
@@ -154,12 +127,11 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// subtract cost of packet
|
// subtract cost of packet
|
||||||
|
|
||||||
if entry.tokens > packetCost {
|
if entry.tokens > packetCost {
|
||||||
entry.tokens -= packetCost
|
entry.tokens -= packetCost
|
||||||
entry.Unlock()
|
entry.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
entry.Unlock()
|
entry.mu.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +1,31 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RatelimiterResult struct {
|
type result struct {
|
||||||
allowed bool
|
allowed bool
|
||||||
text string
|
text string
|
||||||
wait time.Duration
|
wait time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRatelimiter(t *testing.T) {
|
func TestRatelimiter(t *testing.T) {
|
||||||
|
var rate Ratelimiter
|
||||||
|
var expectedResults []result
|
||||||
|
|
||||||
var ratelimiter Ratelimiter
|
nano := func(nano int64) time.Duration {
|
||||||
var expectedResults []RatelimiterResult
|
|
||||||
|
|
||||||
Nano := func(nano int64) time.Duration {
|
|
||||||
return time.Nanosecond * time.Duration(nano)
|
return time.Nanosecond * time.Duration(nano)
|
||||||
}
|
}
|
||||||
|
|
||||||
Add := func(res RatelimiterResult) {
|
add := func(res result) {
|
||||||
expectedResults = append(
|
expectedResults = append(
|
||||||
expectedResults,
|
expectedResults,
|
||||||
res,
|
res,
|
||||||
@@ -34,69 +33,86 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < packetsBurstable; i++ {
|
for i := 0; i < packetsBurstable; i++ {
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "inital burst",
|
text: "initial burst",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "after burst",
|
text: "after burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
|
wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
|
||||||
text: "filling tokens for single packet",
|
text: "filling tokens for single packet",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "not having refilled enough",
|
text: "not having refilled enough",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
||||||
text: "filling tokens for two packet burst",
|
text: "filling tokens for two packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "second packet in 2 packet burst",
|
text: "second packet in 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "packet following 2 packet burst",
|
text: "packet following 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
ips := []net.IP{
|
ips := []netip.Addr{
|
||||||
net.ParseIP("127.0.0.1"),
|
netip.MustParseAddr("127.0.0.1"),
|
||||||
net.ParseIP("192.168.1.1"),
|
netip.MustParseAddr("192.168.1.1"),
|
||||||
net.ParseIP("172.167.2.3"),
|
netip.MustParseAddr("172.167.2.3"),
|
||||||
net.ParseIP("97.231.252.215"),
|
netip.MustParseAddr("97.231.252.215"),
|
||||||
net.ParseIP("248.97.91.167"),
|
netip.MustParseAddr("248.97.91.167"),
|
||||||
net.ParseIP("188.208.233.47"),
|
netip.MustParseAddr("188.208.233.47"),
|
||||||
net.ParseIP("104.2.183.179"),
|
netip.MustParseAddr("104.2.183.179"),
|
||||||
net.ParseIP("72.129.46.120"),
|
netip.MustParseAddr("72.129.46.120"),
|
||||||
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
||||||
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
||||||
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
||||||
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
||||||
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
||||||
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||||
}
|
}
|
||||||
|
|
||||||
ratelimiter.Init()
|
now := time.Now()
|
||||||
|
rate.timeNow = func() time.Time {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Lock to avoid data race with cleanup goroutine from Init.
|
||||||
|
rate.mu.Lock()
|
||||||
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
|
rate.timeNow = time.Now
|
||||||
|
}()
|
||||||
|
timeSleep := func(d time.Duration) {
|
||||||
|
now = now.Add(d + 1)
|
||||||
|
rate.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
rate.Init()
|
||||||
|
defer rate.Close()
|
||||||
|
|
||||||
for i, res := range expectedResults {
|
for i, res := range expectedResults {
|
||||||
time.Sleep(res.wait)
|
timeSleep(res.wait)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
allowed := ratelimiter.Allow(ip)
|
allowed := rate.Allow(ip)
|
||||||
if allowed != res.allowed {
|
if allowed != res.allowed {
|
||||||
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
|
t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
101
replay/replay.go
101
replay/replay.go
@@ -1,83 +1,62 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||||
package replay
|
package replay
|
||||||
|
|
||||||
/* Implementation of RFC6479
|
type block uint64
|
||||||
* https://tools.ietf.org/html/rfc6479
|
|
||||||
*
|
|
||||||
* The implementation is not safe for concurrent use!
|
|
||||||
*/
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// See: https://golang.org/src/math/big/arith.go
|
blockBitLog = 6 // 1<<6 == 64 bits
|
||||||
_Wordm = ^uintptr(0)
|
blockBits = 1 << blockBitLog // must be power of 2
|
||||||
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
ringBlocks = 1 << 7 // must be power of 2
|
||||||
_WordSize = 1 << _WordLogSize
|
windowSize = (ringBlocks - 1) * blockBits
|
||||||
|
blockMask = ringBlocks - 1
|
||||||
|
bitMask = blockBits - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// A Filter rejects replayed messages by checking if message counter value is
|
||||||
CounterRedundantBitsLog = _WordLogSize + 3
|
// within a sliding window of previously received messages.
|
||||||
CounterRedundantBits = _WordSize * 8
|
// The zero value for Filter is an empty filter ready to use.
|
||||||
CounterBitsTotal = 2048
|
// Filters are unsafe for concurrent use.
|
||||||
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
type Filter struct {
|
||||||
)
|
last uint64
|
||||||
|
ring [ringBlocks]block
|
||||||
const (
|
|
||||||
BacktrackWords = CounterBitsTotal / _WordSize
|
|
||||||
)
|
|
||||||
|
|
||||||
func minUint64(a uint64, b uint64) uint64 {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReplayFilter struct {
|
// Reset resets the filter to empty state.
|
||||||
counter uint64
|
func (f *Filter) Reset() {
|
||||||
backtrack [BacktrackWords]uintptr
|
f.last = 0
|
||||||
|
f.ring[0] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (filter *ReplayFilter) Init() {
|
// ValidateCounter checks if the counter should be accepted.
|
||||||
filter.counter = 0
|
// Overlimit counters (>= limit) are always rejected.
|
||||||
filter.backtrack[0] = 0
|
func (f *Filter) ValidateCounter(counter, limit uint64) bool {
|
||||||
}
|
|
||||||
|
|
||||||
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
|
|
||||||
if counter >= limit {
|
if counter >= limit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
indexBlock := counter >> blockBitLog
|
||||||
indexWord := counter >> CounterRedundantBitsLog
|
if counter > f.last { // move window forward
|
||||||
|
current := f.last >> blockBitLog
|
||||||
if counter > filter.counter {
|
diff := indexBlock - current
|
||||||
|
if diff > ringBlocks {
|
||||||
// move window forward
|
diff = ringBlocks // cap diff to clear the whole ring
|
||||||
|
|
||||||
current := filter.counter >> CounterRedundantBitsLog
|
|
||||||
diff := minUint64(indexWord-current, BacktrackWords)
|
|
||||||
for i := uint64(1); i <= diff; i++ {
|
|
||||||
filter.backtrack[(current+i)%BacktrackWords] = 0
|
|
||||||
}
|
}
|
||||||
filter.counter = counter
|
for i := current + 1; i <= current+diff; i++ {
|
||||||
|
f.ring[i&blockMask] = 0
|
||||||
} else if filter.counter-counter > CounterWindowSize {
|
}
|
||||||
|
f.last = counter
|
||||||
// behind current window
|
} else if f.last-counter > windowSize { // behind current window
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
indexWord %= BacktrackWords
|
|
||||||
indexBit := counter & uint64(CounterRedundantBits-1)
|
|
||||||
|
|
||||||
// check and set bit
|
// check and set bit
|
||||||
|
indexBlock &= blockMask
|
||||||
oldValue := filter.backtrack[indexWord]
|
indexBit := counter & bitMask
|
||||||
newValue := oldValue | (1 << indexBit)
|
old := f.ring[indexBlock]
|
||||||
filter.backtrack[indexWord] = newValue
|
new := old | 1<<indexBit
|
||||||
return oldValue != newValue
|
f.ring[indexBlock] = new
|
||||||
|
return old != new
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package replay
|
package replay
|
||||||
@@ -14,22 +14,22 @@ import (
|
|||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
const RejectAfterMessages = 1<<64 - 1<<13 - 1
|
||||||
|
|
||||||
func TestReplay(t *testing.T) {
|
func TestReplay(t *testing.T) {
|
||||||
var filter ReplayFilter
|
var filter Filter
|
||||||
|
|
||||||
T_LIM := CounterWindowSize + 1
|
const T_LIM = windowSize + 1
|
||||||
|
|
||||||
testNumber := 0
|
testNumber := 0
|
||||||
T := func(n uint64, v bool) {
|
T := func(n uint64, expected bool) {
|
||||||
testNumber++
|
testNumber++
|
||||||
if filter.ValidateCounter(n, RejectAfterMessages) != v {
|
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
|
||||||
t.Fatal("Test", testNumber, "failed", n, v)
|
t.Fatal("Test", testNumber, "failed", n, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
|
|
||||||
T(0, true) /* 1 */
|
T(0, true) /* 1 */
|
||||||
T(1, true) /* 2 */
|
T(1, true) /* 2 */
|
||||||
@@ -67,53 +67,53 @@ func TestReplay(t *testing.T) {
|
|||||||
T(0, false) /* 34 */
|
T(0, false) /* 34 */
|
||||||
|
|
||||||
t.Log("Bulk test 1")
|
t.Log("Bulk test 1")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(1); i <= CounterWindowSize; i++ {
|
for i := uint64(1); i <= windowSize; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 2")
|
t.Log("Bulk test 2")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
for i := uint64(2); i <= windowSize+1; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(1, true)
|
T(1, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 3")
|
t.Log("Bulk test 3")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 1; i > 0; i-- {
|
for i := uint64(windowSize + 1); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Bulk test 4")
|
t.Log("Bulk test 4")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 2; i > 1; i-- {
|
for i := uint64(windowSize + 2); i > 1; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 5")
|
t.Log("Bulk test 5")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 6")
|
t.Log("Bulk test 6")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package rwcancel
|
|
||||||
|
|
||||||
import "golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
type fdSet struct {
|
|
||||||
unix.FdSet
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fdset *fdSet) set(i int) {
|
|
||||||
bits := 32 << (^uint(0) >> 63)
|
|
||||||
fdset.Bits[i/bits] |= 1 << uint(i%bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fdset *fdSet) check(i int) bool {
|
|
||||||
bits := 32 << (^uint(0) >> 63)
|
|
||||||
return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,12 @@
|
|||||||
|
//go:build !windows && !wasm
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Package rwcancel implements cancelable read/write operations on
|
||||||
|
// a file descriptor.
|
||||||
package rwcancel
|
package rwcancel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,13 +17,6 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func max(a, b int) int {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
type RWCancel struct {
|
type RWCancel struct {
|
||||||
fd int
|
fd int
|
||||||
closingReader *os.File
|
closingReader *os.File
|
||||||
@@ -42,47 +39,47 @@ func NewRWCancel(fd int) (*RWCancel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RetryAfterError(err error) bool {
|
func RetryAfterError(err error) bool {
|
||||||
if pe, ok := err.(*os.PathError); ok {
|
return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR)
|
||||||
err = pe.Err
|
|
||||||
}
|
|
||||||
if errno, ok := err.(syscall.Errno); ok {
|
|
||||||
switch errno {
|
|
||||||
case syscall.EAGAIN, syscall.EINTR:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) ReadyRead() bool {
|
func (rw *RWCancel) ReadyRead() bool {
|
||||||
closeFd := int(rw.closingReader.Fd())
|
closeFd := int32(rw.closingReader.Fd())
|
||||||
fdset := fdSet{}
|
|
||||||
fdset.set(rw.fd)
|
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
|
||||||
fdset.set(closeFd)
|
var err error
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
for {
|
||||||
|
_, err = unix.Poll(pollFds, -1)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if fdset.check(closeFd) {
|
if pollFds[1].Revents != 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return fdset.check(rw.fd)
|
return pollFds[0].Revents != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) ReadyWrite() bool {
|
func (rw *RWCancel) ReadyWrite() bool {
|
||||||
closeFd := int(rw.closingReader.Fd())
|
closeFd := int32(rw.closingReader.Fd())
|
||||||
fdset := fdSet{}
|
pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
|
||||||
fdset.set(rw.fd)
|
var err error
|
||||||
fdset.set(closeFd)
|
for {
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
_, err = unix.Poll(pollFds, -1)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if fdset.check(closeFd) {
|
|
||||||
|
if pollFds[1].Revents != 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return fdset.check(rw.fd)
|
return pollFds[0].Revents != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
||||||
@@ -92,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
if !rw.ReadyRead() {
|
if !rw.ReadyRead() {
|
||||||
return 0, errors.New("fd closed")
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -104,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
if !rw.ReadyWrite() {
|
if !rw.ReadyWrite() {
|
||||||
return 0, errors.New("fd closed")
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -113,3 +110,8 @@ func (rw *RWCancel) Cancel() (err error) {
|
|||||||
_, err = rw.closingWriter.Write([]byte{0})
|
_, err = rw.closingWriter.Write([]byte{0})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *RWCancel) Close() {
|
||||||
|
rw.closingReader.Close()
|
||||||
|
rw.closingWriter.Close()
|
||||||
|
}
|
||||||
|
|||||||
9
rwcancel/rwcancel_stub.go
Normal file
9
rwcancel/rwcancel_stub.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build windows || wasm
|
||||||
|
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
package rwcancel
|
||||||
|
|
||||||
|
type RWCancel struct{}
|
||||||
|
|
||||||
|
func (*RWCancel) Cancel() {}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user