136 Commits

Author SHA1 Message Date
Jason A. Donenfeld
21636207a6 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-23 19:12:33 +01:00
Jason A. Donenfeld
c7b76d3d9e device: uniformly check ECDH output for zeros
For some reason, this was omitted for response messages.

Reported-by: z <dzm@unexpl0.red>
Fixes: 8c34c4c ("First set of code review patches")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-16 16:33:14 +01:00
Jordan Whited
1e2c3e5a3c tun: guard Device.Events() against chan writes
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-09 12:35:58 -03:00
Jason A. Donenfeld
ebbd4a4330 global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:39:29 -03:00
Soren L. Hansen
0ae4b3177c tun/netstack: make http examples communicate with each other
This seems like a much better demonstration as it removes the need for
external components.

Signed-off-by: Søren L. Hansen <sorenisanerd@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:38:19 -03:00
Colin Adler
077ce8ecab tun/netstack: bump gvisor
Bump gVisor to a recent known-good version.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2023-02-07 20:10:52 -03:00
Jason A. Donenfeld
bb719d3a6e global: bump copyright year
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:21:32 +02:00
Colin Adler
fde0a9525a tun/netstack: ensure (*netTun).incomingPacket chan is closed
Without this, `device.Close()` will deadlock.

Signed-off-by: Colin Adler <colin1adler@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-20 17:17:29 +02:00
Brad Fitzpatrick
b51010ba13 all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-09-04 12:57:30 +02:00
Jason A. Donenfeld
d1d08426b2 tun/netstack: remove separate module
Now that the gvisor deps aren't insane, we can just do this in the main
module.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:14:05 -04:00
Shengjing Zhu
3381e21b18 tun/netstack: bump to latest gvisor
To build with go1.19, gvisor needs
99325baf ("Bump gVisor build tags to go1.19").

However gvisor.dev/gvisor/pkg/tcpip/buffer is no longer available,
so refactor to use gvisor.dev/gvisor/pkg/tcpip/link/channel directly.

Signed-off-by: Shengjing Zhu <i@zhsj.me>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-08-29 12:01:05 -04:00
Brad Fitzpatrick
c31a7b1ab4 conn, device, tun: set CLOEXEC on fds
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-07-04 01:42:12 +02:00
Tobias Klauser
6a08d81f6b tun: use ByteSliceToString from golang.org/x/sys/unix
Use unix.ByteSliceToString in (*NativeTun).nameSlice to convert the
TUNGETIFF ioctl result []byte to a string.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-06-01 15:00:07 +02:00
Josh Bleecher Snyder
ef5c587f78 conn: remove the final alloc per packet receive
This does bind_std only; other platforms remain.

The remaining alloc per iteration in the Throughput benchmark
comes from the tuntest package, and should not appear in regular use.

name           old time/op      new time/op      delta
Latency-10         25.2µs ± 1%      25.0µs ± 0%   -0.58%  (p=0.006 n=10+10)
Throughput-10      2.44µs ± 3%      2.41µs ± 2%     ~     (p=0.140 n=10+8)

name           old alloc/op     new alloc/op     delta
Latency-10           854B ± 5%        741B ± 3%  -13.22%  (p=0.000 n=10+10)
Throughput-10        265B ±34%        267B ±39%     ~     (p=0.670 n=10+10)

name           old allocs/op    new allocs/op    delta
Latency-10           16.0 ± 0%        14.0 ± 0%  -12.50%  (p=0.000 n=10+10)
Throughput-10        2.00 ± 0%        1.00 ± 0%  -50.00%  (p=0.000 n=10+10)

name           old packet-loss  new packet-loss  delta
Throughput-10        0.01 ±82%       0.01 ±282%     ~     (p=0.321 n=9+8)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-04-07 03:31:10 +02:00
Jason A. Donenfeld
193cf8d6a5 conn: use netip for std bind
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-17 22:23:02 -06:00
Jason A. Donenfeld
ee1c8e0e87 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 21:32:14 -06:00
Jason A. Donenfeld
95b48cdb39 tun/netstack: bump mod
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 18:01:34 -06:00
Jason A. Donenfeld
5aff28b14c mod: bump packages and remove compat netip
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-16 17:51:47 -06:00
Josh Bleecher Snyder
46826fc4e5 all: use any in place of interface{}
Enabled by using Go 1.18. A bit less verbose.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:40:24 -07:00
Josh Bleecher Snyder
42c9af45e1 all: update to Go 1.18
Bump go.mod and README.

Switch to upstream net/netip.

Use strings.Cut.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2022-03-16 16:09:48 -07:00
Alexander Neumann
ae6bc4dd64 tun/netstack: check error returned by SetDeadline()
Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: don't wrap deadline error.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Alexander Neumann
2cec4d1a62 tun/netstack: update to latest wireguard-go
This commit fixes all callsites of netip.AddrFromSlice(), which has
changed its signature and now returns two values.

Signed-off-by: Alexander Neumann <alexander.neumann@redteam-pentesting.de>
[Jason: remove error handling from AddrFromSlice.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-03-09 18:27:36 -07:00
Jason A. Donenfeld
3b95c81cc1 tun/netstack: simplify read timeout on ping socket
I'm not 100% sure this is correct, but it certainly is a lot simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:30:31 +01:00
Thomas H. Ptacek
b9669b734e tun/netstack: implement ICMP ping
Provide a PacketConn interface for netstack's ICMP endpoint; netstack
currently only provides EchoRequest/EchoResponse ICMP support, so this
code exposes only an interface for doing ping.

Signed-off-by: Thomas Ptacek <thomas@sockpuppet.org>
[Jason: rework structure, match std go interfaces, add example code]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-02-02 23:09:37 +01:00
Jason A. Donenfeld
e0b8f11489 version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-17 17:37:42 +01:00
Jason A. Donenfeld
114a3db918 ipc: bsd: try again if kqueue returns EINTR
Reported-by: J. Michael McAtee <mmcatee@jumptrading.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2022-01-14 16:10:43 +01:00
Jason A. Donenfeld
9c9e7e2724 global: apply gofumpt
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-12-09 23:15:55 +01:00
Jason A. Donenfeld
2dd424e2d8 device: handle peer post config on blank line
We missed a function exit point. This was exacerbated by e3134bf
("device: defer state machine transitions until configuration is
complete"), but the bug existed prior. Minus provided the following
useful reproducer script:

    #!/usr/bin/env bash

    set -eux

    make wireguard-go || exit 125

    ip netns del test-ns || true
    ip netns add test-ns
    ip link add test-kernel type wireguard
    wg set test-kernel listen-port 0 private-key <(echo "QMCfZcp1KU27kEkpcMCgASEjDnDZDYsfMLHPed7+538=") peer "eDPZJMdfnb8ZcA/VSUnLZvLB2k8HVH12ufCGa7Z7rHI=" allowed-ips 10.51.234.10/32
    ip link set test-kernel netns test-ns up
    ip -n test-ns addr add 10.51.234.1/24 dev test-kernel
    port=$(ip netns exec test-ns wg show test-kernel listen-port)

    ip link del test-go || true
    ./wireguard-go test-go
    wg set test-go private-key <(echo "WBM7qimR3vFk1QtWNfH+F4ggy/hmO+5hfIHKxxI4nF4=") peer "+nj9Dkqpl4phsHo2dQliGm5aEiWJJgBtYKbh7XjeNjg=" allowed-ips 0.0.0.0/0 endpoint 127.0.0.1:$port
    ip addr add 10.51.234.10/24 dev test-go
    ip link set test-go up

    ping -c2 -W1 10.51.234.1

Reported-by: minus <minus@mnus.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-29 12:31:54 -05:00
Josh Bleecher Snyder
387f7c461a device: reduce peer lock critical section in UAPI
The deferred RUnlock calls weren't executing until all peers
had been processed. Add an anonymous function so that each
peer may be unlocked as soon as it is completed.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Josh Bleecher Snyder
4d87c9e824 device: remove code using unsafe
There is no performance impact.

name                             old time/op  new time/op  delta
TrieIPv4Peers100Addresses1000-8  78.6ns ± 1%  79.4ns ± 3%    ~     (p=0.604 n=10+9)
TrieIPv4Peers10Addresses10-8     29.1ns ± 2%  28.8ns ± 1%  -1.12%  (p=0.014 n=10+9)
TrieIPv6Peers100Addresses1000-8  78.9ns ± 1%  78.6ns ± 1%    ~     (p=0.492 n=10+10)
TrieIPv6Peers10Addresses10-8     29.3ns ± 2%  28.6ns ± 2%  -2.16%  (p=0.000 n=10+10)

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
ef8d6804d7 global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-23 22:03:15 +01:00
Jason A. Donenfeld
de7c702ace device: only propagate roaming value before peer is referenced elsewhere
A peer.endpoint never becomes nil after being not-nil, so creation is
the only time we actually need to set this. This prevents a race from
when the variable is actually used elsewhere, and allows us to avoid an
expensive atomic.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:16:04 +01:00
Jason A. Donenfeld
fc4f975a4d device: align 64-bit atomic member in Device
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
9d699ba730 device: start peers before running handshake test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
Jason A. Donenfeld
425f7c726b Makefile: don't use test -v because it hides failures in scrollback
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 21:07:31 +01:00
David Anderson
3cae233d69 device: fix nil pointer dereference in uapi read
Signed-off-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-16 20:43:26 +01:00
Jason A. Donenfeld
111e0566dc device: make new peers inherit broken mobile semantics
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
e3134bf665 device: defer state machine transitions until configuration is complete
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
63abb5537b device: do not consume handshake messages if not running
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-15 23:40:47 +01:00
Jason A. Donenfeld
851efb1bb6 tun: move wintun to its own repo
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:55 +01:00
Jason A. Donenfeld
c07dd60cdb namedpipe: rename from winpipe to keep in sync with CL299009
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-11-04 12:53:52 +01:00
Jason A. Donenfeld
eb6302c7eb device: timers: use pre-seeded per-thread unlocked fastrandn for jitter
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:47:50 +02:00
Jason A. Donenfeld
60683d7361 device: timers: seed unsafe rng before use for jitter
Forgetting to seed the unsafe rng, the jitter before followed a fixed
pattern, which didn't help when a fleet of computers all boot at once.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-28 13:34:21 +02:00
Jason A. Donenfeld
e42c6c4bc2 wintun: align 64-bit argument on ARM32
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-26 14:53:40 +02:00
Jason A. Donenfeld
828a885a71 README: raise minimum Go to 1.17
Suggested-by: Adam Bliss <abliss@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-25 17:53:11 +02:00
Mikael Magnusson
f1f626090e tun/netstack: update gvisor
Update gvisor to v0.0.0-20211020211948-f76a604701b6, which requires some
changes to tun.go:

WriteRawPacket: Add function with not implemented error.

CreateNetTUN: Replace stack.AddAddress with stack.AddProtocolAddress, and
fix IPv6 address in error message.

Signed-off-by: Mikael Magnusson <mikma@users.sourceforge.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-22 13:22:29 -06:00
Brad Fitzpatrick
82e0b734e5 ipc, rwcancel: compile on js/wasm
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-10-20 14:50:05 -06:00
Jason A. Donenfeld
fdf57a1fa4 wintun: allow retrieving DLL version
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-20 12:13:44 -06:00
Jason A. Donenfeld
f87e87af0d version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 23:27:13 -06:00
Jason A. Donenfeld
ba9e364dab wintun: remove memmod option for dll loading
Only wireguard-windows used this, and it's moving to wgnt exclusively.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-16 22:49:38 -06:00
Jason A. Donenfeld
dfd688b6aa global: remove old-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 12:02:10 -06:00
Jason A. Donenfeld
c01d52b66a global: add newer-style build tags
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 11:46:53 -06:00
Jason A. Donenfeld
82d2aa87aa wintun: use new swdevice-based API for upcoming Wintun 0.14
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-12 00:26:46 -06:00
Jason A. Donenfeld
982d5d2e84 conn,wintun: use unsafe.Slice instead of unsafeSlice
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:57:53 -06:00
Jason A. Donenfeld
642a56e165 memmod: import from wireguard-windows
We'll eventually be getting rid of it here, but keep it sync'd up for
now.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-10-11 14:53:36 -06:00
Jason A. Donenfeld
bb745b2ea3 rwcancel: use unix.Poll again but bump x/sys so it uses ppoll under the hood
This reverts commit fcc601dbf0 but then
bumps go.mod.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-27 14:19:15 -06:00
Jason A. Donenfeld
fcc601dbf0 rwcancel: use ppoll on Linux for Android
This is a temporary measure while we wait for
https://go-review.googlesource.com/c/sys/+/352310 to land.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-26 17:16:38 -06:00
Tobias Klauser
217ac1016b tun: make operateonfd.go build tags more specific
(*NativeTun).operateOnFd is only used on darwin and freebsd. Adjust the
build tags accordingly.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:54:01 -06:00
Tobias Klauser
eae5e0f3a3 tun: avoid leaking sock fd in CreateTUN error cases
At these points, the socket file descriptor is not yet wrapped in an
*os.File, so it needs to be closed explicitly on error.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-23 09:53:49 -06:00
Jason A. Donenfeld
2ef39d4754 global: add new go 1.17 build comments
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-09-05 16:00:43 +02:00
Jason A. Donenfeld
3957e9b9dd memmod: register exception handler tables
Otherwise recent WDK binaries fail on ARM64, where an exception handler
is used for trapping an illegal instruction when ARMv8.1 atomics are
being tested for functionality.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-08-05 14:56:48 +02:00
Jason A. Donenfeld
bad6caeb82 memmod: fix protected delayed load the right way
The reason this was failing before is that dloadsup.h's
DloadObtainSection was doing a linear search of sections to find which
header corresponds with the IMAGE_DELAYLOAD_DESCRIPTOR section, and we
were stupidly overwriting the VirtualSize field, so the linear search
wound up matching the .text section, which then it found to not be
marked writable and failed with FAST_FAIL_DLOAD_PROTECTION_FAILURE.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:27:40 +02:00
Jason A. Donenfeld
c89f5ca665 memmod: disable protected delayed load for now
Probably a bad idea, but we don't currently support it, and those huge
windows.NewCallback trampolines make juicer targets anyway.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-07-29 01:13:03 +02:00
Jason A. Donenfeld
15b24b6179 ipc: allow admins but require high integrity label
Might be more reasonable.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-24 17:01:02 +02:00
Jason A. Donenfeld
f9b48a961c device: zero out allowedip node pointers when removing
This should make it a bit easier for the garbage collector.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-04 16:33:28 +02:00
Jason A. Donenfeld
d0cf96114f device: limit allowedip fuzzer a to 4 times through
Trying this for every peer winds up being very slow and precludes it
from acceptable runtime in the CI, so reduce this to 4.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 18:22:50 +02:00
Jason A. Donenfeld
841756e328 device: simplify allowedips lookup signature
The inliner should handle this for us.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
c382222eab device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 16:29:43 +02:00
Jason A. Donenfeld
b41f4cc768 device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 15:08:42 +02:00
Jason A. Donenfeld
4a57024b94 device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-06-03 13:51:03 +02:00
Josh Bleecher Snyder
64cb82f2b3 go.mod: bump golang.org/x/sys again
To pick up https://go-review.googlesource.com/c/sys/+/307129.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-25 16:34:54 +02:00
Jason A. Donenfeld
c27ff9b9f6 device: allow reducing queue constants on iOS
Heavier network extensions might require the wireguard-go component to
use less ram, so let users of this reduce these as needed.

At some point we'll put this behind a configuration method of sorts, but
for now, just expose the consts as vars.

Requested-by: Josh Bleecher Snyder <josh@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-22 01:00:51 +02:00
Jason A. Donenfeld
99e8b4ba60 tun: linux: account for interface removal from outside
On Linux we can run `ip link del wg0`, in which case the fd becomes
stale, and we should exit. Since this is an intentional action, don't
treat it as an error.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:26:01 +02:00
Jason A. Donenfeld
bd83f0ac99 conn: linux: protect read fds
The -1 protection was removed and the wrong error was returned, causing
us to read from a bogus fd. As well, remove the useless closures that
aren't doing anything, since this is all synchronized anyway.

Fixes: 10533c3 ("all: make conn.Bind.Open return a slice of receive functions")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 18:09:55 +02:00
Jason A. Donenfeld
50d779833e rwcancel: use ordinary os.ErrClosed instead of custom error
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:56:36 +02:00
Jason A. Donenfeld
a9b377e9e1 rwcancel: use poll instead of select
Suggested-by: Lennart Poettering <lennart@poettering.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-20 17:42:34 +02:00
Jason A. Donenfeld
9087e444e6 device: optimize Peer.String even more
This reduces the allocation, branches, and amount of base64 encoding.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-18 17:43:53 +02:00
Josh Bleecher Snyder
25ad08a591 device: optimize Peer.String
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
2021-05-14 00:37:30 +02:00
Jason A. Donenfeld
5846b62283 conn: windows: set count=0 on retry
When retrying, if count is not 0, we forget to dequeue another request,
and so the ring fills up and errors out.

Reported-by: Sascha Dierberg <dierberg@dresearch-fe.de>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-11 16:47:17 +02:00
Jason A. Donenfeld
9844c74f67 main: replace crlf on windows in fmt test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 22:23:32 +02:00
Jason A. Donenfeld
4e9e5dad09 main: check that code is formatted in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-10 17:48:26 +02:00
Jason A. Donenfeld
39e0b6dade tun: format
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:27 +02:00
Jason A. Donenfeld
7121927b87 device: add ID to repeated routines
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:21:21 +02:00
Jason A. Donenfeld
326aec10af device: remove unusual ... in messages
We dont use ... in any other present progressive messages except these.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 12:17:41 +02:00
Jason A. Donenfeld
efb8818550 device: avoid verbose log line during ordinary shutdown sequence
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:39:06 +02:00
Jason A. Donenfeld
69b39db0b4 tun: windows: set event before waiting
In 097af6e ("tun: windows: protect reads from closing") we made sure no
functions are running when End() is called, to avoid a UaF. But we still
need to kick that event somehow, so that Read() is allowed to exit, in
order to release the lock. So this commit calls SetEvent, while moving
the closing boolean to be atomic so it can be modified without locks,
and then moves to a WaitGroup for the RCU-like pattern.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:26:24 +02:00
Jason A. Donenfeld
db733ccd65 tun: windows: rearrange struct to avoid alignment trap on 32bit
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:19:00 +02:00
Jason A. Donenfeld
a7aec4449f tun: windows: check alignment in unit test
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-07 09:15:50 +02:00
Josh Bleecher Snyder
60a26371f4 device: log all errors received by RoutineReceiveIncoming
When debugging, it's useful to know why a receive func exited.

We were already logging that, but only in the "death spiral" case.
Move the logging up, to capture it always.
Reduce the verbosity, since it is not an error case any more.
Put the receive func name in the log line.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-05-06 11:22:13 +02:00
Jason A. Donenfeld
a544776d70 tun/netstack: update go mod and remove GSO argument
Reported-by: John Xiong <xiaoyang1258@yeah.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-06 11:07:26 +02:00
Jason A. Donenfeld
69a42a4eef tun: windows: send MTU update when forced MTU changes
Otherwise the padding doesn't get updated.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-05-05 11:42:45 +02:00
Jason A. Donenfeld
097af6e135 tun: windows: protect reads from closing
The code previously used the old errors channel for checking, rather
than the simpler boolean, which caused issues on shutdown, since the
errors channel was meaningless. However, looking at this exposed a more
basic problem: Close() and all the other functions that check the closed
boolean can race. So protect with a basic RW lock, to ensure that
Close() waits for all pending operations to complete.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:22:45 -04:00
Jason A. Donenfeld
8246d251ea conn: windows: do not error out when receiving UDP jumbogram
If we receive a large UDP packet, don't return an error to receive.go,
which then terminates the receive loop. Instead, simply retry.

Considering Winsock's general finickiness, we might consider other
places where an attacker on the wire can generate error conditions like
this.

Reported-by: Sascha Dierberg <sascha.dierberg@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-26 22:07:03 -04:00
Jason A. Donenfeld
c9db4b7aaa version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-24 13:07:27 -04:00
Jason A. Donenfeld
3625f8d284 tun: freebsd: avoid OOB writes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:10:23 -06:00
Jason A. Donenfeld
0687dc06c8 tun: freebsd: become controlling process when reopening tun FD
When we pass the TUN FD to the child, we have to call TUNSIFPID;
otherwise when we close the device, we get a splat in dmesg.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:02:44 -06:00
Jason A. Donenfeld
71aefa374d tun: freebsd: restructure and cleanup
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 14:54:59 -06:00
Jason A. Donenfeld
3d3e30beb8 tun: freebsd: remove horrific hack for getting tunnel name
As of FreeBSD 12.1, there's TUNGIFNAME.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 12:03:16 -06:00
Jason A. Donenfeld
b0e5b19969 tun: freebsd: set IFF_MULTICAST for routing daemons
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-18 20:09:04 -06:00
Jason A. Donenfeld
3988821442 main: print kernel warning on OpenBSD and FreeBSD too
More kernels!

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-15 23:35:45 -06:00
Jason A. Donenfeld
c7cd2c9eab device: don't defer unlocking from loop
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 16:19:35 -06:00
Jason A. Donenfeld
54dbe2471f conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 15:35:32 -06:00
Kristupas Antanavičius
d2fd0c0cc0 device: allocate new buffer in receive death spiral
Note: this bug is "hidden" by avoiding "death spiral" code path by
6228659 ("device: handle broader range of errors in RoutineReceiveIncoming").

If the code reached "death spiral" mechanism, there would be multiple
double frees happening. This results in a deadlock on iOS, because the
pools are fixed size and goroutine might stop until somebody makes
space in the pool.

This was almost 100% repro on the new ARM Macbooks:

- Build with 'ios' tag for Mac. This will enable bounded pools.
- Somehow call device.IpcSet at least couple of times (update config)
- device.BindUpdate() would be triggered
- RoutineReceiveIncoming would enter "death spiral".
- RoutineReceiveIncoming would stall on double free (pool is already
  full)
- The stuck routine would deadlock 'device.closeBindLocked()' function
  on line 'netc.stopping.Wait()'

Signed-off-by: Kristupas Antanavičius <kristupas.antanavicius@nordsec.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 11:14:53 -06:00
Jason A. Donenfeld
5f6bbe4ae8 conn: windows: reset ring to starting position after free
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 18:09:41 -06:00
Jason A. Donenfeld
75526d6071 conn: windows: compare head and tail properly
By not comparing these with the modulo, the ring became nearly never
full, resulting in completion queue buffers filling up prematurely.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Jason A. Donenfeld
fbf97502cf winrio: test that IOCP-based RIO is supported
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Josh Bleecher Snyder
10533c3e73 all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.

Beneficial consequences:

* If there's no IPv6 support on a system,
  conn.Bind.Open can choose not to return a receive function for it,
  which is simpler than tracking that state in the bind.
  This simplification removes existing data races from both
  conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
  the conn.Bind no longer needs to add a separate muxing layer.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-04-02 11:07:08 -06:00
Jason A. Donenfeld
8ed83e0427 conn: winrio: pass key parameter into struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-02 10:36:41 -06:00
Josh Bleecher Snyder
6228659a91 device: handle broader range of errors in RoutineReceiveIncoming
RoutineReceiveIncoming exits immediately on net.ErrClosed,
but not on other errors. However, for errors that are known
to be permanent, such as syscall.EAFNOSUPPORT,
we may as well exit immediately instead of retrying.

This considerably speeds up the package device tests right now,
because the Bind sometimes (incorrectly) returns syscall.EAFNOSUPPORT
instead of net.ErrClosed.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:41:43 -07:00
Josh Bleecher Snyder
517f0703f5 conn: document retry loop in StdNetBind.Open
It's not obvious on a first read what the loop is doing.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
204140016a conn: use local ipvN vars in StdNetBind.Open
This makes it clearer that they are fresh on each attempt,
and avoids the bookkeeping required to clearing them on failure.

Also, remove an unnecessary err != nil.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
822f5a6d70 conn: unify code in StdNetBind.Send
The sending code is identical for ipv4 and ipv6;
select the conn, then use it.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:32 -07:00
Josh Bleecher Snyder
02e419ed8a device: rename unsafeCloseBind to closeBindLocked
And document a bit.
This name is more idiomatic.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:07:12 -07:00
Jason A. Donenfeld
bc69a3fa60 version: bump snapshot 2021-03-23 13:07:19 -06:00
Jason A. Donenfeld
12ce53271b tun: freebsd: use broadcast mode instead of PPP mode
It makes the routing configuration simpler.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-23 12:41:34 -06:00
Jason A. Donenfeld
5f0c8b942d device: signal to close device in separate routine
Otherwise we wind up deadlocking.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-11 09:29:10 -07:00
Jason A. Donenfeld
c5f382624e tun: linux: do not spam events every second from hack listener
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-11 09:23:11 -07:00
Kay Diam
6005c573e2 tun: freebsd: allow empty names
This change allows omitting the tun interface name setting. When the
name is not set, the kernel automatically picks up the tun name and
index.

Signed-off-by: Kay Diam <kay.diam@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:32:27 -07:00
Jason A. Donenfeld
82f3e9e2af winpipe: move syscalls into x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:32:27 -07:00
Jason A. Donenfeld
4885e7c954 memmod: use resource functions from x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
497ba95de7 memmod: do not use IsBadReadPtr
It should be enough to check for the trailing zero name.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
0eb7206295 conn: linux: unexport mutex
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
20714ca472 mod: bump x/sys
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-08 21:04:09 -07:00
Jason A. Donenfeld
c1e09f1927 mod: rename COPYING to LICENSE
Otherwise the netstack module doesn't show up on the package site.

https://github.com/golang/go/issues/43817#issuecomment-764987580

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 09:09:21 -07:00
Jason A. Donenfeld
79611c64e8 tun/netstack: bump deps and api
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 08:48:14 -07:00
Jason A. Donenfeld
593658d975 device: get rid of peers.empty boolean in timersActive
There's no way for len(peers)==0 when a current peer has
isRunning==false.

This requires some struct reshuffling so that the uint64 pointer is
aligned.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-03-06 08:44:38 -07:00
Jason A. Donenfeld
3c11c0308e conn: implement RIO for fast Windows UDP sockets
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-25 15:08:08 +01:00
Jason A. Donenfeld
f9dac7099e global: remove TODO name graffiti
Googlers have a habit of graffiting their name in TODO items that then
are never addressed, and other people won't go near those because
they're marked territory of another animal. I've been gradually cleaning
these up as I see them, but this commit just goes all the way and
removes the remaining stragglers.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
9a29ae267c device: test up/down using virtual conn
This prevents port clashing bugs.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
6603c05a4a device: cleanup unused test components
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
a4f8e83d5d conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-23 20:00:57 +01:00
Jason A. Donenfeld
c69481f1b3 device: disable waitpool tests
This code is stable, and the test is finicky, especially on high core
count systems, so just disable it.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-22 15:26:47 +01:00
Brad Fitzpatrick
0f4809f366 tun: make NativeTun.Close well behaved, not crash on double close
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-02-22 15:26:29 +01:00
Brad Fitzpatrick
fecb8f482a README: bump document Go requirement to 1.16
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-02-22 15:26:29 +01:00
Jason A. Donenfeld
8bf4204d2e global: stop using ioutil
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-17 22:19:27 +01:00
Jason A. Donenfeld
4e439ea10e conn: bump to 1.16 and get rid of NetErrClosed hack
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-02-16 21:05:25 +01:00
118 changed files with 4335 additions and 5239 deletions

View File

View File

@@ -10,7 +10,7 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\nconst Version = "%s"\n' "$$tag")" && \ ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \ echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true git update-index --assume-unchanged version.go || true
@@ -23,7 +23,7 @@ install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
test: test:
go test -v ./... go test ./...
clean: clean:
rm -f wireguard-go rm -f wireguard-go

View File

@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building ## Building
This requires an installation of [go](https://golang.org) ≥ 1.13. This requires an installation of [go](https://golang.org) ≥ 1.18.
``` ```
$ git clone https://git.zx2c4.com/wireguard-go $ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License ## License
Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in this software and associated documentation files (the "Software"), to deal in

View File

@@ -1,8 +1,6 @@
// +build !android
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
@@ -10,6 +8,7 @@ package conn
import ( import (
"errors" "errors"
"net" "net"
"net/netip"
"strconv" "strconv"
"sync" "sync"
"syscall" "syscall"
@@ -18,101 +17,114 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type IPv4Source struct { type ipv4Source struct {
Src [4]byte Src [4]byte
Ifindex int32 Ifindex int32
} }
type IPv6Source struct { type ipv6Source struct {
src [16]byte src [16]byte
//ifindex belongs in dst.ZoneId // ifindex belongs in dst.ZoneId
} }
type NativeEndpoint struct { type LinuxSocketEndpoint struct {
sync.Mutex mu sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte src [unsafe.Sizeof(ipv6Source{})]byte
isV6 bool isV6 bool
} }
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
func (endpoint *NativeEndpoint) src4() *IPv4Source { func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
} }
func (endpoint *NativeEndpoint) src6() *IPv6Source { func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
} }
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
} }
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
} }
type nativeBind struct { // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
sock4 int type LinuxSocketBind struct {
sock6 int // mu guards sock4 and sock6 and the associated fds.
lastMark uint32 // As long as someone holds mu (read or write), the associated fds are valid.
closing sync.RWMutex mu sync.RWMutex
sock4 int
sock6 int
} }
var _ Endpoint = (*NativeEndpoint)(nil) func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
var _ Bind = (*nativeBind)(nil) func NewDefaultBind() Bind { return NewLinuxSocketBind() }
func CreateEndpoint(s string) (Endpoint, error) { var (
var end NativeEndpoint _ Endpoint = (*LinuxSocketEndpoint)(nil)
addr, err := parseEndpoint(s) _ Bind = (*LinuxSocketBind)(nil)
)
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
var end LinuxSocketEndpoint
e, err := netip.ParseAddrPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ipv4 := addr.IP.To4() if e.Addr().Is4() {
if ipv4 != nil {
dst := end.dst4() dst := end.dst4()
end.isV6 = false end.isV6 = false
dst.Port = addr.Port dst.Port = int(e.Port())
copy(dst.Addr[:], ipv4) dst.Addr = e.Addr().As4()
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
} }
ipv6 := addr.IP.To16() if e.Addr().Is6() {
if ipv6 != nil { zone, err := zoneToUint32(e.Addr().Zone())
zone, err := zoneToUint32(addr.Zone)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dst := end.dst6() dst := end.dst6()
end.isV6 = true end.isV6 = true
dst.Port = addr.Port dst.Port = int(e.Port())
dst.ZoneId = zone dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:]) dst.Addr = e.Addr().As16()
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
} }
return nil, errors.New("Invalid IP address") return nil, errors.New("invalid IP address")
} }
func createBind(port uint16) (Bind, uint16, error) { func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error var err error
var bind nativeBind
var newPort uint16 var newPort uint16
var tries int var tries int
if bind.sock4 != -1 || bind.sock6 != -1 {
return nil, 0, ErrBindAlreadyOpen
}
originalPort := port originalPort := port
again: again:
port = originalPort port = originalPort
var sock4, sock6 int
// Attempt ipv6 bind, update port if successful. // Attempt ipv6 bind, update port if successful.
bind.sock6, newPort, err = create6(port) sock6, newPort, err = create6(port)
if err != nil { if err != nil {
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err return nil, 0, err
} }
} else { } else {
@@ -120,35 +132,39 @@ again:
} }
// Attempt ipv4 bind, update port if successful. // Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port) sock4, newPort, err = create4(port)
if err != nil { if err != nil {
if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
unix.Close(bind.sock6) unix.Close(sock6)
tries++ tries++
goto again goto again
} }
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
unix.Close(bind.sock6) unix.Close(sock6)
return nil, 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
if bind.sock4 == -1 && bind.sock6 == -1 { var fns []ReceiveFunc
return nil, 0, errors.New("ipv4 and ipv6 not supported") if sock4 != -1 {
bind.sock4 = sock4
fns = append(fns, bind.receiveIPv4)
} }
if sock6 != -1 {
return &bind, port, nil bind.sock6 = sock6
fns = append(fns, bind.receiveIPv6)
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, port, nil
} }
func (bind *nativeBind) LastMark() uint32 { func (bind *LinuxSocketBind) SetMark(value uint32) error {
return bind.lastMark bind.mu.RLock()
} defer bind.mu.RUnlock()
func (bind *nativeBind) SetMark(value uint32) error {
bind.closing.RLock()
defer bind.closing.RUnlock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
@@ -157,7 +173,6 @@ func (bind *nativeBind) SetMark(value uint32) error {
unix.SO_MARK, unix.SO_MARK,
int(value), int(value),
) )
if err != nil { if err != nil {
return err return err
} }
@@ -170,27 +185,29 @@ func (bind *nativeBind) SetMark(value uint32) error {
unix.SO_MARK, unix.SO_MARK,
int(value), int(value),
) )
if err != nil { if err != nil {
return err return err
} }
} }
bind.lastMark = value
return nil return nil
} }
func (bind *nativeBind) Close() error { func (bind *LinuxSocketBind) Close() error {
var err1, err2 error // Take a readlock to shut down the sockets...
bind.closing.RLock() bind.mu.RLock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR) unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
} }
if bind.sock4 != -1 { if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR) unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
} }
bind.closing.RUnlock() bind.mu.RUnlock()
bind.closing.Lock() // ...and a write lock to close the fd.
// This ensures that no one else is using the fd.
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6) err1 = unix.Close(bind.sock6)
bind.sock6 = -1 bind.sock6 = -1
@@ -199,7 +216,6 @@ func (bind *nativeBind) Close() error {
err2 = unix.Close(bind.sock4) err2 = unix.Close(bind.sock4)
bind.sock4 = -1 bind.sock4 = -1
} }
bind.closing.Unlock()
if err1 != nil { if err1 != nil {
return err1 return err1
@@ -207,83 +223,65 @@ func (bind *nativeBind) Close() error {
return err2 return err2
} }
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.closing.RLock() bind.mu.RLock()
defer bind.closing.RUnlock() defer bind.mu.RUnlock()
var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, NetErrClosed
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
bind.closing.RLock()
defer bind.closing.RUnlock()
var end NativeEndpoint
if bind.sock4 == -1 { if bind.sock4 == -1 {
return 0, nil, NetErrClosed return 0, nil, net.ErrClosed
} }
n, err := receive4( var end LinuxSocketEndpoint
bind.sock4, n, err := receive4(bind.sock4, buf, &end)
buff,
&end,
)
return n, &end, err return n, &end, err
} }
func (bind *nativeBind) Send(buff []byte, end Endpoint) error { func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.closing.RLock() bind.mu.RLock()
defer bind.closing.RUnlock() defer bind.mu.RUnlock()
if bind.sock6 == -1 {
return 0, nil, net.ErrClosed
}
var end LinuxSocketEndpoint
n, err := receive6(bind.sock6, buf, &end)
return n, &end, err
}
nend := end.(*NativeEndpoint) func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 { if bind.sock4 == -1 {
return NetErrClosed return net.ErrClosed
} }
return send4(bind.sock4, nend, buff) return send4(bind.sock4, nend, buff)
} else { } else {
if bind.sock6 == -1 { if bind.sock6 == -1 {
return NetErrClosed return net.ErrClosed
} }
return send6(bind.sock6, nend, buff) return send6(bind.sock6, nend, buff)
} }
} }
func (end *NativeEndpoint) SrcIP() net.IP { func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return netip.AddrFrom4(end.src4().Src)
end.src4().Src[0],
end.src4().Src[1],
end.src4().Src[2],
end.src4().Src[3],
)
} else { } else {
return end.src6().src[:] return netip.AddrFrom16(end.src6().src)
} }
} }
func (end *NativeEndpoint) DstIP() net.IP { func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return netip.AddrFrom4(end.dst4().Addr)
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else { } else {
return end.dst6().Addr[:] return netip.AddrFrom16(end.dst6().Addr)
} }
} }
func (end *NativeEndpoint) DstToBytes() []byte { func (end *LinuxSocketEndpoint) DstToBytes() []byte {
if !end.isV6 { if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else { } else {
@@ -291,28 +289,27 @@ func (end *NativeEndpoint) DstToBytes() []byte {
} }
} }
func (end *NativeEndpoint) SrcToString() string { func (end *LinuxSocketEndpoint) SrcToString() string {
return end.SrcIP().String() return end.SrcIP().String()
} }
func (end *NativeEndpoint) DstToString() string { func (end *LinuxSocketEndpoint) DstToString() string {
var udpAddr net.UDPAddr var port int
udpAddr.IP = end.DstIP()
if !end.isV6 { if !end.isV6 {
udpAddr.Port = end.dst4().Port port = end.dst4().Port
} else { } else {
udpAddr.Port = end.dst6().Port port = end.dst6().Port
} }
return udpAddr.String() return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
} }
func (end *NativeEndpoint) ClearDst() { func (end *LinuxSocketEndpoint) ClearDst() {
for i := range end.dst { for i := range end.dst {
end.dst[i] = 0 end.dst[i] = 0
} }
} }
func (end *NativeEndpoint) ClearSrc() { func (end *LinuxSocketEndpoint) ClearSrc() {
for i := range end.src { for i := range end.src {
end.src[i] = 0 end.src[i] = 0
} }
@@ -330,15 +327,13 @@ func zoneToUint32(zone string) (uint32, error) {
} }
func create4(port uint16) (int, uint16, error) { func create4(port uint16) (int, uint16, error) {
// create socket // create socket
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return -1, 0, err return -1, 0, err
} }
@@ -374,15 +369,13 @@ func create4(port uint16) (int, uint16, error) {
} }
func create6(port uint16) (int, uint16, error) { func create6(port uint16) (int, uint16, error) {
// create socket // create socket
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET6, unix.AF_INET6,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return -1, 0, err return -1, 0, err
} }
@@ -413,7 +406,6 @@ func create6(port uint16) (int, uint16, error) {
} }
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return -1, 0, err return -1, 0, err
@@ -427,8 +419,7 @@ func create6(port uint16) (int, uint16, error) {
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
func send4(sock int, end *NativeEndpoint, buff []byte) error { func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header // construct message header
cmsg := struct { cmsg := struct {
@@ -446,9 +437,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
}, },
} }
end.Lock() end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock() end.mu.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -459,16 +450,15 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{} cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock() end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock() end.mu.Unlock()
} }
return err return err
} }
func send6(sock int, end *NativeEndpoint, buff []byte) error { func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
// construct message header // construct message header
cmsg := struct { cmsg := struct {
@@ -490,9 +480,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0 cmsg.pktinfo.Ifindex = 0
} }
end.Lock() end.mu.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock() end.mu.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -503,16 +493,15 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{} cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock() end.mu.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock() end.mu.Unlock()
} }
return err return err
} }
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header // construct message header
var cmsg struct { var cmsg struct {
@@ -521,7 +510,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
} }
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -543,8 +531,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil return size, nil
} }
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
// construct message header // construct message header
var cmsg struct { var cmsg struct {
@@ -553,7 +540,6 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
} }
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }

212
conn/bind_std.go Normal file
View File

@@ -0,0 +1,212 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"net"
"net/netip"
"sync"
"syscall"
)
// StdNetBind is meant to be a temporary solution on platforms for which
// the sticky socket / source caching behavior has not yet been implemented.
// It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
}
func NewStdNetBind() Bind { return &StdNetBind{} }
type StdNetEndpoint netip.AddrPort
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
e, err := netip.ParseAddrPort(s)
return asEndpoint(e), err
}
func (StdNetEndpoint) ClearSrc() {}
func (e StdNetEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
}
func (e StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e StdNetEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
return b
}
func (e StdNetEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
}
func (e StdNetEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error
var tries int
if bind.ipv4 != nil || bind.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var ipv4, ipv6 *net.UDPConn
ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
ipv6, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
ipv4.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if ipv4 != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
}
if ipv6 != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
}
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
bind.ipv4 = nil
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
bind.ipv6 = nil
}
bind.blackhole4 = false
bind.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
return n, asEndpoint(endpoint), err
}
}
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
return n, asEndpoint(endpoint), err
}
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(StdNetEndpoint)
if !ok {
return ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
bind.mu.Lock()
blackhole := bind.blackhole4
conn := bind.ipv4
if addrPort.Addr().Is6() {
blackhole = bind.blackhole6
conn = bind.ipv6
}
bind.mu.Unlock()
if blackhole {
return nil
}
if conn == nil {
return syscall.EAFNOSUPPORT
}
_, err = conn.WriteToUDPAddrPort(buff, addrPort)
return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]Endpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = Endpoint(StdNetEndpoint(ap))
m[ap] = e
}
return e
}

582
conn/bind_windows.go Normal file
View File

@@ -0,0 +1,582 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"encoding/binary"
"io"
"net"
"net/netip"
"strconv"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/conn/winrio"
)
const (
packetsPerRing = 1024
bytesPerPacket = 2048 - 32
receiveSpins = 15
)
type ringPacket struct {
addr WinRingEndpoint
data [bytesPerPacket]byte
}
type ringBuffer struct {
packets uintptr
head, tail uint32
id winrio.BufferId
iocp windows.Handle
isFull bool
cq winrio.Cq
mu sync.Mutex
overlapped windows.Overlapped
}
func (rb *ringBuffer) Push() *ringPacket {
for rb.isFull {
panic("ring is full")
}
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true
}
return ret
}
func (rb *ringBuffer) Return(count uint32) {
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
return
}
rb.head += count
rb.isFull = false
}
type afWinRingBind struct {
sock windows.Handle
rx, tx ringBuffer
rq winrio.Rq
mu sync.Mutex
blackhole bool
}
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
v4, v6 afWinRingBind
mu sync.RWMutex
isOpen atomic.Uint32 // 0, 1, or 2
}
func NewDefaultBind() Bind { return NewWinRingBind() }
func NewWinRingBind() Bind {
if !winrio.Initialize() {
return NewStdNetBind()
}
return new(WinRingBind)
}
type WinRingEndpoint struct {
family uint16
data [30]byte
}
var (
_ Bind = (*WinRingBind)(nil)
_ Endpoint = (*WinRingEndpoint)(nil)
)
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
host16, err := windows.UTF16PtrFromString(host)
if err != nil {
return nil, err
}
port16, err := windows.UTF16PtrFromString(port)
if err != nil {
return nil, err
}
hints := windows.AddrinfoW{
Flags: windows.AI_NUMERICHOST,
Family: windows.AF_UNSPEC,
Socktype: windows.SOCK_DGRAM,
Protocol: windows.IPPROTO_UDP,
}
var addrinfo *windows.AddrinfoW
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
if err != nil {
return nil, err
}
defer windows.FreeAddrInfoW(addrinfo)
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
return nil, windows.ERROR_INVALID_ADDRESS
}
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
}
func (*WinRingEndpoint) ClearSrc() {}
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
case windows.AF_INET6:
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
}
return netip.Addr{}
}
func (e *WinRingEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
}
func (e *WinRingEndpoint) DstToBytes() []byte {
switch e.family {
case windows.AF_INET:
b := make([]byte, 0, 6)
b = append(b, e.data[2:6]...)
b = append(b, e.data[1], e.data[0])
return b
case windows.AF_INET6:
b := make([]byte, 0, 18)
b = append(b, e.data[6:22]...)
b = append(b, e.data[1], e.data[0])
return b
}
return nil
}
func (e *WinRingEndpoint) DstToString() string {
switch e.family {
case windows.AF_INET:
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
case windows.AF_INET6:
var zone string
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
zone = strconv.FormatUint(uint64(scope), 10)
}
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
}
return ""
}
func (e *WinRingEndpoint) SrcToString() string {
return ""
}
func (ring *ringBuffer) CloseAndZero() {
if ring.cq != 0 {
winrio.CloseCompletionQueue(ring.cq)
ring.cq = 0
}
if ring.iocp != 0 {
windows.CloseHandle(ring.iocp)
ring.iocp = 0
}
if ring.id != 0 {
winrio.DeregisterBuffer(ring.id)
ring.id = 0
}
if ring.packets != 0 {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0
}
ring.head = 0
ring.tail = 0
ring.isFull = false
}
func (bind *afWinRingBind) CloseAndZero() {
bind.rx.CloseAndZero()
bind.tx.CloseAndZero()
if bind.sock != 0 {
windows.CloseHandle(bind.sock)
bind.sock = 0
}
bind.blackhole = false
}
func (bind *WinRingBind) closeAndZero() {
bind.isOpen.Store(0)
bind.v4.CloseAndZero()
bind.v6.CloseAndZero()
}
func (ring *ringBuffer) Open() error {
var err error
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return err
}
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
if err != nil {
return err
}
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return err
}
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
if err != nil {
return err
}
return nil
}
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
var err error
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return nil, err
}
err = bind.rx.Open()
if err != nil {
return nil, err
}
err = bind.tx.Open()
if err != nil {
return nil, err
}
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
if err != nil {
return nil, err
}
err = windows.Bind(bind.sock, sa)
if err != nil {
return nil, err
}
sa, err = windows.Getsockname(bind.sock)
if err != nil {
return nil, err
}
return sa, nil
}
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock()
defer bind.mu.Unlock()
defer func() {
if err != nil {
bind.closeAndZero()
}
}()
if bind.isOpen.Load() != 0 {
return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil {
return nil, 0, err
}
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil {
return nil, 0, err
}
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
err = bind.v6.InsertReceiveRequest()
if err != nil {
return nil, 0, err
}
}
bind.isOpen.Store(1)
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
bind.mu.RLock()
if bind.isOpen.Load() != 1 {
bind.mu.RUnlock()
return nil
}
bind.isOpen.Store(2)
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
bind.mu.RUnlock()
bind.mu.Lock()
defer bind.mu.Unlock()
bind.closeAndZero()
return nil
}
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
func (bind *afWinRingBind) InsertReceiveRequest() error {
packet := bind.rx.Push()
dataBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
Length: uint32(len(packet.data)),
}
addressBuffer := &winrio.Buffer{
Id: bind.rx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
bind.rx.mu.Lock()
defer bind.rx.mu.Unlock()
var err error
var count uint32
var results [1]winrio.Result
retry:
count = 0
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
if tries > 0 {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
procyield(1)
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
}
if count == 0 {
err = winrio.Notify(bind.rx.cq)
if err != nil {
return 0, nil, err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return 0, nil, err
}
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
if count == 0 {
return 0, nil, io.ErrNoProgress
}
}
bind.rx.Return(1)
err = bind.InsertReceiveRequest()
if err != nil {
return 0, nil, err
}
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
// attacker bandwidth, just like the rest of the receive path.
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
if isOpen.Load() != 1 {
return 0, nil, net.ErrClosed
}
goto retry
}
if results[0].Status != 0 {
return 0, nil, windows.Errno(results[0].Status)
}
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
ep := packet.addr
n := copy(buf, packet.data[:results[0].BytesTransferred])
return n, &ep, nil
}
func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen)
}
func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen)
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
if isOpen.Load() != 1 {
return net.ErrClosed
}
if len(buf) > bytesPerPacket {
return io.ErrShortBuffer
}
bind.tx.mu.Lock()
defer bind.tx.mu.Unlock()
var results [packetsPerRing]winrio.Result
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 && bind.tx.isFull {
err := winrio.Notify(bind.tx.cq)
if err != nil {
return err
}
var bytes uint32
var key uintptr
var overlapped *windows.Overlapped
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
if err != nil {
return err
}
if isOpen.Load() != 1 {
return net.ErrClosed
}
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
if count == 0 {
return io.ErrNoProgress
}
}
if count > 0 {
bind.tx.Return(count)
}
packet := bind.tx.Push()
packet.addr = *nend
copy(packet.data[:], buf)
dataBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
Length: uint32(len(buf)),
}
addressBuffer := &winrio.Buffer{
Id: bind.tx.id,
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
Length: uint32(unsafe.Sizeof(packet.addr)),
}
bind.mu.Lock()
defer bind.mu.Unlock()
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
switch nend.family {
case windows.AF_INET:
if bind.v4.blackhole {
return nil
}
return bind.v4.Send(buf, nend, &bind.isOpen)
case windows.AF_INET6:
if bind.v6.blackhole {
return nil
}
return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
bind.blackhole4 = blackhole
return nil
}
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
bind.blackhole6 = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
if err != nil {
return err
}
bind.v4.blackhole = blackhole
return nil
}
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.isOpen.Load() != 1 {
return net.ErrClosed
}
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
if err != nil {
return err
}
bind.v6.blackhole = blackhole
return nil
}
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
const IP_UNICAST_IF = 31
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
var bytes [4]byte
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
if err != nil {
return err
}
return nil
}
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
const IPV6_UNICAST_IF = 31
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}

129
conn/bindtest/bindtest.go Normal file
View File

@@ -0,0 +1,129 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package bindtest
import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
"golang.zx2c4.com/wireguard/conn"
)
type ChannelBind struct {
rx4, tx4 *chan []byte
rx6, tx6 *chan []byte
closeSignal chan bool
source4, source6 ChannelEndpoint
target4, target6 ChannelEndpoint
}
type ChannelEndpoint uint16
var (
_ conn.Bind = (*ChannelBind)(nil)
_ conn.Endpoint = (*ChannelEndpoint)(nil)
)
func NewChannelBinds() [2]conn.Bind {
arx4 := make(chan []byte, 8192)
brx4 := make(chan []byte, 8192)
arx6 := make(chan []byte, 8192)
brx6 := make(chan []byte, 8192)
var binds [2]ChannelBind
binds[0].rx4 = &arx4
binds[0].tx4 = &brx4
binds[1].rx4 = &brx4
binds[1].tx4 = &arx4
binds[0].rx6 = &arx6
binds[0].tx6 = &brx6
binds[1].rx6 = &brx6
binds[1].tx6 = &arx6
binds[0].target4 = ChannelEndpoint(1)
binds[1].target4 = ChannelEndpoint(2)
binds[0].target6 = ChannelEndpoint(3)
binds[1].target6 = ChannelEndpoint(4)
binds[0].source4 = binds[1].target4
binds[0].source6 = binds[1].target6
binds[1].source4 = binds[0].target4
binds[1].source6 = binds[0].target6
return [2]conn.Bind{&binds[0], &binds[1]}
}
func (c ChannelEndpoint) ClearSrc() {}
func (c ChannelEndpoint) SrcToString() string { return "" }
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
fns = append(fns, c.makeReceiveFunc(*c.rx4))
fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 {
return fns, uint16(c.source4), nil
} else {
return fns, uint16(c.source6), nil
}
}
func (c *ChannelBind) Close() error {
if c.closeSignal != nil {
select {
case <-c.closeSignal:
default:
close(c.closeSignal)
}
}
return nil
}
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case <-c.closeSignal:
return 0, nil, net.ErrClosed
case rx := <-ch:
return copy(b, rx), c.target6, nil
}
}
}
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
select {
case <-c.closeSignal:
return net.ErrClosed
default:
bc := make([]byte, len(b))
copy(bc, b)
if ep.(ChannelEndpoint) == c.target4 {
*c.tx4 <- bc
} else if ep.(ChannelEndpoint) == c.target6 {
*c.tx6 <- bc
} else {
return os.ErrInvalid
}
}
return nil
}
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
addr, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
return ChannelEndpoint(addr.Port()), nil
}

View File

@@ -1,11 +1,11 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) { func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := bind.ipv4.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err
@@ -19,7 +19,7 @@ func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
return return
} }
func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) { func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := bind.ipv6.SyscallConn()
if err != nil { if err != nil {
return -1, err return -1, err

View File

@@ -1,59 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"encoding/binary"
"unsafe"
"golang.org/x/sys/windows"
)
const (
sockoptIP_UNICAST_IF = 31
sockoptIPV6_UNICAST_IF = 31
)
func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
bind.blackhole4 = blackhole
return nil
}
func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
err2 := sysconn.Control(func(fd uintptr) {
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
bind.blackhole6 = blackhole
return nil
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package conn implements WireGuard's network connections. // Package conn implements WireGuard's network connections.
@@ -8,49 +8,41 @@ package conn
import ( import (
"errors" "errors"
"net" "fmt"
"net/netip"
"reflect"
"runtime"
"strings" "strings"
) )
// A ReceiveFunc receives a single inbound packet from the network.
// It writes the data into b. n is the length of the packet.
// ep is the remote endpoint.
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
// //
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation. // depending on the platform-specific implementation.
type Bind interface { type Bind interface {
// LastMark reports the last mark set for this Bind. // Open puts the Bind into a listening state on a given port and reports the actual
LastMark() uint32 // port that it bound to. Passing zero results in a random selection.
// fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind. // SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK. // This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error SetMark(mark uint32) error
// ReceiveIPv6 reads an IPv6 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
// ReceiveIPv4 reads an IPv4 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
// Send writes a packet b to address ep. // Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error Send(b []byte, ep Endpoint) error
// Close closes the Bind connection. // ParseEndpoint creates a new endpoint from a string.
Close() error ParseEndpoint(s string) (Endpoint, error)
}
// CreateBind creates a Bind bound to a port.
//
// The value actualPort reports the actual port number the Bind
// object gets bound to.
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
return createBind(port)
} }
// BindSocketToInterface is implemented by Bind objects that support being // BindSocketToInterface is implemented by Bind objects that support being
@@ -69,43 +61,61 @@ type PeekLookAtSocketFd interface {
// An Endpoint maintains the source/destination caching for a peer. // An Endpoint maintains the source/destination caching for a peer.
// //
// dst : the remote address of a peer ("endpoint" in uapi terminology) // dst: the remote address of a peer ("endpoint" in uapi terminology)
// src : the local address from which datagrams originate going to the peer // src: the local address from which datagrams originate going to the peer
type Endpoint interface { type Endpoint interface {
ClearSrc() // clears the source address ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port) SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port) DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP DstIP() netip.Addr
SrcIP() net.IP SrcIP() netip.Addr
} }
func parseEndpoint(s string) (*net.UDPAddr, error) { var (
// ensure that the host is an IP address ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
host, _, err := net.SplitHostPort(s) func (fn ReceiveFunc) PrettyName() string {
if err != nil { name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
return nil, err // 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
} }
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { for {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just var idx int
// trying to make sure with a small sanity test that this is a real IP address and for idx = len(name) - 1; idx >= 0; idx-- {
// not something that's likely to incur DNS lookups. if name[idx] < '0' || name[idx] > '9' {
host = host[:i] break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
} }
if ip := net.ParseIP(host); ip == nil { if idx := strings.LastIndexByte(name, '.'); idx != -1 {
return nil, errors.New("Failed to parse IP address: " + host) name = name[idx+1:]
// 5. beansIPv6
} }
if name == "" {
// parse address and port return fmt.Sprintf("%p", fn)
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
} }
ip4 := addr.IP.To4() if strings.HasSuffix(name, "IPv4") {
if ip4 != nil { return "v4"
addr.IP = ip4
} }
return addr, err if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
} }

View File

@@ -1,171 +0,0 @@
// +build !linux android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import (
"errors"
"net"
"syscall"
)
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type nativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*nativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (*NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func createBind(uport uint16) (Bind, uint16, error) {
var err error
var bind nativeBind
var tries int
again:
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
bind.ipv4.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
bind.ipv4.Close()
bind.ipv4 = nil
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *nativeBind) Close() error {
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
}
if err1 != nil {
return err1
}
return err2
}
func (bind *nativeBind) LastMark() uint32 { return 0 }
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}

10
conn/default.go Normal file
View File

@@ -0,0 +1,10 @@
//go:build !linux && !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package conn
func NewDefaultBind() Bind { return NewStdNetBind() }

View File

@@ -1,12 +1,12 @@
// +build !linux,!openbsd,!freebsd //go:build !linux && !openbsd && !freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
func (bind *nativeBind) SetMark(mark uint32) error { func (bind *StdNetBind) SetMark(mark uint32) error {
return nil return nil
} }

View File

@@ -1,8 +1,8 @@
// +build android openbsd freebsd //go:build linux || openbsd || freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package conn package conn
@@ -26,7 +26,7 @@ func init() {
} }
} }
func (bind *nativeBind) SetMark(mark uint32) error { func (bind *StdNetBind) SetMark(mark uint32) error {
var operr error var operr error
if fwmarkIoctl == 0 { if fwmarkIoctl == 0 {
return nil return nil

View File

@@ -1,13 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package conn
import _ "unsafe"
//TODO: replace this with net.ErrClosed for Go 1.16
//go:linkname NetErrClosed internal/poll.ErrNetClosing
var NetErrClosed error

254
conn/winrio/rio_windows.go Normal file
View File

@@ -0,0 +1,254 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package winrio
import (
"log"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
const (
MsgDontNotify = 1
MsgDefer = 2
MsgWaitAll = 4
MsgCommitOnly = 8
MaxCqSize = 0x8000000
invalidBufferId = 0xFFFFFFFF
invalidCq = 0
invalidRq = 0
corruptCq = 0xFFFFFFFF
)
var extensionFunctionTable struct {
cbSize uint32
rioReceive uintptr
rioReceiveEx uintptr
rioSend uintptr
rioSendEx uintptr
rioCloseCompletionQueue uintptr
rioCreateCompletionQueue uintptr
rioCreateRequestQueue uintptr
rioDequeueCompletion uintptr
rioDeregisterBuffer uintptr
rioNotify uintptr
rioRegisterBuffer uintptr
rioResizeCompletionQueue uintptr
rioResizeRequestQueue uintptr
}
type Cq uintptr
type Rq uintptr
type BufferId uintptr
type Buffer struct {
Id BufferId
Offset uint32
Length uint32
}
type Result struct {
Status int32
BytesTransferred uint32
SocketContext uint64
RequestContext uint64
}
type notificationCompletionType uint32
const (
eventCompletion notificationCompletionType = 1
iocpCompletion notificationCompletionType = 2
)
type eventNotificationCompletion struct {
completionType notificationCompletionType
event windows.Handle
notifyReset uint32
}
type iocpNotificationCompletion struct {
completionType notificationCompletionType
iocp windows.Handle
key uintptr
overlapped *windows.Overlapped
}
var (
initialized sync.Once
available bool
)
func Initialize() bool {
initialized.Do(func() {
var (
err error
socket windows.Handle
cq Cq
)
defer func() {
if err == nil {
return
}
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
return
}
log.Printf("Registered I/O is unavailable: %v", err)
}()
socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
if err != nil {
return
}
defer windows.CloseHandle(socket)
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
ob := uint32(0)
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
&ob, nil, 0)
if err != nil {
return
}
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
var iocp windows.Handle
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return
}
defer windows.CloseHandle(iocp)
var overlapped windows.Overlapped
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
if err != nil {
return
}
defer CloseCompletionQueue(cq)
_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
if err != nil {
return
}
available = true
})
return available
}
func Socket(af, typ, proto int32) (windows.Handle, error) {
return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
}
func CloseCompletionQueue(cq Cq) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
}
func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
notificationCompletion := &eventNotificationCompletion{
completionType: eventCompletion,
event: event,
}
if notifyReset {
notificationCompletion.notifyReset = 1
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
notificationCompletion := &iocpNotificationCompletion{
completionType: iocpCompletion,
iocp: iocp,
key: key,
overlapped: overlapped,
}
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
if ret == invalidCq {
return 0, err
}
return Cq(ret), nil
}
func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
if ret == invalidRq {
return 0, err
}
return Rq(ret), nil
}
func DequeueCompletion(cq Cq, results []Result) uint32 {
var array uintptr
if len(results) > 0 {
array = uintptr(unsafe.Pointer(&results[0]))
}
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
if ret == corruptCq {
panic("cq is corrupt")
}
return uint32(ret)
}
func DeregisterBuffer(id BufferId) {
_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
}
func RegisterBuffer(buffer []byte) (BufferId, error) {
var buf unsafe.Pointer
if len(buffer) > 0 {
buf = unsafe.Pointer(&buffer[0])
}
return RegisterPointer(buf, uint32(len(buffer)))
}
func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
if ret == invalidBufferId {
return 0, err
}
return BufferId(ret), nil
}
func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
if ret == 0 {
return err
}
return nil
}
func Notify(cq Cq) error {
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
if ret != 0 {
return windows.Errno(ret)
}
return nil
}

View File

@@ -1,68 +1,55 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list" "container/list"
"encoding/binary"
"errors" "errors"
"math/bits" "math/bits"
"net" "net"
"net/netip"
"sync" "sync"
"unsafe" "unsafe"
) )
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct { type trieEntry struct {
child [2]*trieEntry peer *Peer
peer *Peer child [2]*trieEntry
bits net.IP parent parentIndirection
cidr uint cidr uint8
bit_at_byte uint bitAtByte uint8
bit_at_shift uint bitAtShift uint8
perPeerElem *list.Element bits []byte
perPeerElem *list.Element
} }
func isLittleEndian() bool { func commonBits(ip1, ip2 []byte) uint8 {
one := uint32(1)
return *(*byte)(unsafe.Pointer(&one)) != 0
}
func swapU32(i uint32) uint32 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes32(i)
}
func swapU64(i uint64) uint64 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint {
size := len(ip1) size := len(ip1)
if size == net.IPv4len { if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0])) a := binary.BigEndian.Uint32(ip1)
b := (*uint32)(unsafe.Pointer(&ip2[0])) b := binary.BigEndian.Uint32(ip2)
x := *a ^ *b x := a ^ b
return uint(bits.LeadingZeros32(swapU32(x))) return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len { } else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0])) a := binary.BigEndian.Uint64(ip1)
b := (*uint64)(unsafe.Pointer(&ip2[0])) b := binary.BigEndian.Uint64(ip2)
x := *a ^ *b x := a ^ b
if x != 0 { if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x))) return uint8(bits.LeadingZeros64(x))
} }
a = (*uint64)(unsafe.Pointer(&ip1[8])) a = binary.BigEndian.Uint64(ip1[8:])
b = (*uint64)(unsafe.Pointer(&ip2[8])) b = binary.BigEndian.Uint64(ip2[8:])
x = *a ^ *b x = a ^ b
return 64 + uint(bits.LeadingZeros64(swapU64(x))) return 64 + uint8(bits.LeadingZeros64(x))
} else { } else {
panic("Wrong size bit string") panic("Wrong size bit string")
} }
@@ -79,32 +66,8 @@ func (node *trieEntry) removeFromPeerEntries() {
} }
} }
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { func (node *trieEntry) choose(ip []byte) byte {
if node == nil { return (ip[node.bitAtByte] >> node.bitAtShift) & 1
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
} }
func (node *trieEntry) maskSelf() { func (node *trieEntry) maskSelf() {
@@ -114,86 +77,125 @@ func (node *trieEntry) maskSelf() {
} }
} }
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
node.child[1] = nil
node.parent.parentBit = nil
}
// at leaf func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
if node == nil { func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{ node := &trieEntry{
bits: ip, peer: peer,
peer: peer, parent: trie,
cidr: cidr, bits: ip,
bit_at_byte: cidr / 8, cidr: cidr,
bit_at_shift: 7 - (cidr % 8), bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
} }
node.maskSelf() node.maskSelf()
node.addToPeerEntries() node.addToPeerEntries()
return node *trie.parentBit = node
return
} }
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
// traverse deeper if exact {
node.removeFromPeerEntries()
common := commonBits(node.bits, ip) node.peer = peer
if node.cidr <= cidr && common >= node.cidr { node.addToPeerEntries()
if node.cidr == cidr { return
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
} }
// split node
newNode := &trieEntry{ newNode := &trieEntry{
bits: ip, peer: peer,
peer: peer, bits: ip,
cidr: cidr, cidr: cidr,
bit_at_byte: cidr / 8, bitAtByte: cidr / 8,
bit_at_shift: 7 - (cidr % 8), bitAtShift: 7 - (cidr % 8),
} }
newNode.maskSelf() newNode.maskSelf()
newNode.addToPeerEntries() newNode.addToPeerEntries()
cidr = min(cidr, common) var down *trieEntry
if node == nil {
// check for shorter prefix down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
parent := node
if newNode.cidr == cidr { if newNode.cidr == cidr {
bit := newNode.choose(node.bits) bit := newNode.choose(down.bits)
newNode.child[bit] = node down.parent = parentIndirection{&newNode.child[bit], bit}
return newNode newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
} }
// create new parent for node & newNode node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
parent := &trieEntry{ cidr: cidr,
bits: append([]byte{}, ip...), bitAtByte: cidr / 8,
peer: nil, bitAtShift: 7 - (cidr % 8),
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
} }
parent.maskSelf() node.maskSelf()
bit := parent.choose(ip) bit := node.choose(down.bits)
parent.child[bit] = newNode down.parent = parentIndirection{&node.child[bit], bit}
parent.child[bit^1] = node node.child[bit] = down
bit = node.choose(newNode.bits)
return parent newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
} }
func (node *trieEntry) lookup(ip net.IP) *Peer { func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer var found *Peer
size := uint(len(ip)) size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr { for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil { if node.peer != nil {
found = node.peer found = node.peer
} }
if node.bit_at_byte == size { if node.bitAtByte == size {
break break
} }
bit := node.choose(ip) bit := node.choose(ip)
@@ -208,13 +210,14 @@ type AllowedIPs struct {
mutex sync.RWMutex mutex sync.RWMutex
} }
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry) node := elem.Value.(*trieEntry)
if !cb(node.bits, node.cidr) { a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return return
} }
} }
@@ -224,32 +227,68 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer) var next *list.Element
table.IPv6 = table.IPv6.removeByPeer(peer) for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
node := elem.Value.(*trieEntry)
node.removeFromPeerEntries()
node.peer = nil
if node.child[0] != nil && node.child[1] != nil {
continue
}
bit := 0
if node.child[0] == nil {
bit = 1
}
child := node.child[bit]
if child != nil {
child.parent = node.parent
}
*node.parent.parentBit = child
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
node.zeroizePointers()
continue
}
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
if parent.peer != nil {
node.zeroizePointers()
continue
}
child = parent.child[node.parent.parentBitType^1]
if child != nil {
child.parent = parent.parent
}
*parent.parent.parentBit = child
node.zeroizePointers()
parent.zeroizePointers()
}
} }
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock() table.mutex.Lock()
defer table.mutex.Unlock() defer table.mutex.Unlock()
switch len(ip) { if prefix.Addr().Is6() {
case net.IPv6len: ip := prefix.Addr().As16()
table.IPv6 = table.IPv6.insert(ip, cidr, peer) parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
case net.IPv4len: } else if prefix.Addr().Is4() {
table.IPv4 = table.IPv4.insert(ip, cidr, peer) ip := prefix.Addr().As4()
default: parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type")) panic(errors.New("inserting unknown address type"))
} }
} }
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock() table.mutex.RLock()
defer table.mutex.RUnlock() defer table.mutex.RUnlock()
return table.IPv4.lookup(address) switch len(ip) {
} case net.IPv6len:
return table.IPv6.lookup(ip)
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { case net.IPv4len:
table.mutex.RLock() return table.IPv4.lookup(ip)
defer table.mutex.RUnlock() default:
return table.IPv6.lookup(address) panic(errors.New("looking up unknown address type"))
}
} }

View File

@@ -1,25 +1,28 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net"
"net/netip"
"sort" "sort"
"testing" "testing"
) )
const ( const (
NumberOfPeers = 100 NumberOfPeers = 100
NumberOfAddresses = 250 NumberOfPeerRemovals = 4
NumberOfTests = 10000 NumberOfAddresses = 250
NumberOfTests = 10000
) )
type SlowNode struct { type SlowNode struct {
peer *Peer peer *Peer
cidr uint cidr uint8
bits []byte bits []byte
} }
@@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i] r[i], r[j] = r[j], r[i]
} }
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r { for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer t.peer = peer
@@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil return nil
} }
func TestTrieRandomIPv4(t *testing.T) { func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
var trie *trieEntry n := 0
var slow SlowRouter for _, x := range r {
if x.peer != peer {
r[n] = x
n++
}
}
return r[:n]
}
func TestTrieRandom(t *testing.T) {
var slow4, slow6 SlowRouter
var peers []*Peer var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1) rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n++ { for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{}) peers = append(peers, &Peer{})
} }
for n := 0; n < NumberOfAddresses; n++ { for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte var addr4 [4]byte
rand.Read(addr[:]) rand.Read(addr4[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Intn(32) + 1)
index := rand.Int() % NumberOfPeers index := rand.Intn(NumberOfPeers)
trie = trie.insert(addr[:], cidr, peers[index]) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow = slow.Insert(addr[:], cidr, peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
} }
for n := 0; n < NumberOfTests; n++ { var p int
var addr [AddressLength]byte for p = 0; ; p++ {
rand.Read(addr[:]) for n := 0; n < NumberOfTests; n++ {
peer1 := slow.Lookup(addr[:]) var addr4 [4]byte
peer2 := trie.lookup(addr[:]) rand.Read(addr4[:])
if peer1 != peer2 { peer1 := slow4.Lookup(addr4[:])
t.Error("Trie did not match naive implementation, for:", addr) peer2 := allowedIPs.Lookup(addr4[:])
} if peer1 != peer2 {
} t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
} }
func TestTrieRandomIPv6(t *testing.T) { var addr6 [16]byte
var trie *trieEntry rand.Read(addr6[:])
var slow SlowRouter peer1 = slow6.Lookup(addr6[:])
var peers []*Peer peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
rand.Seed(1) t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
const AddressLength = 16
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
} }
if p >= len(peers) || p >= NumberOfPeerRemovals {
break
}
allowedIPs.RemoveByPeer(peers[p])
slow4 = slow4.RemoveByPeer(peers[p])
slow6 = slow6.RemoveByPeer(peers[p])
}
for ; p < len(peers); p++ {
allowedIPs.RemoveByPeer(peers[p])
}
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Failed to remove all nodes from trie by peer")
} }
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -8,20 +8,17 @@ package device
import ( import (
"math/rand" "math/rand"
"net" "net"
"net/netip"
"testing" "testing"
) )
/* Todo: More comprehensive
*/
type testPairCommonBits struct { type testPairCommonBits struct {
s1 []byte s1 []byte
s2 []byte s2 []byte
match uint match uint8
} }
func TestCommonBits(t *testing.T) { func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{ tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -42,9 +39,10 @@ func TestCommonBits(t *testing.T) {
} }
} }
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry var trie *trieEntry
var peers []*Peer var peers []*Peer
root := parentIndirection{&trie, 2}
rand.Seed(1) rand.Seed(1)
@@ -57,9 +55,9 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
for n := 0; n < addressNumber; n++ { for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte var addr [AddressLength]byte
rand.Read(addr[:]) rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8)) cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index]) root.insert(addr[:], cidr, peers[index])
} }
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
@@ -97,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint) { insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d byte) { assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }
} }
assertNEQ := func(peer *Peer, a, b, c, d byte) { assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d}) p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer { if p == peer {
t.Error("Assert NEQ failed") t.Error("Assert NEQ failed")
} }
@@ -153,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0) assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0) assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0)
@@ -161,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0)
trie = nil allowedIPs.RemoveByPeer(a)
allowedIPs.RemoveByPeer(b)
allowedIPs.RemoveByPeer(c)
allowedIPs.RemoveByPeer(d)
allowedIPs.RemoveByPeer(e)
allowedIPs.RemoveByPeer(g)
allowedIPs.RemoveByPeer(h)
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
t.Error("Expected removing all the peers to empty trie, but it did not")
}
insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24) insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a) allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1) assertNEQ(a, 192, 168, 0, 1)
} }
@@ -184,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{} g := &Peer{}
h := &Peer{} h := &Peer{}
var trie *trieEntry var allowedIPs AllowedIPs
expand := func(a uint32) []byte { expand := func(a uint32) []byte {
var out [4]byte var out [4]byte
@@ -195,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:] return out[:]
} }
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte var addr []byte
addr = append(addr, expand(a)...) addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer) allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
} }
assertEQ := func(peer *Peer, a, b, c, d uint32) { assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -210,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...) addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...) addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...) addr = append(addr, expand(d)...)
p := trie.lookup(addr) p := allowedIPs.Lookup(addr)
if p != peer { if p != peer {
t.Error("Assert EQ failed") t.Error("Assert EQ failed")
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -35,7 +35,6 @@ const (
/* Implementation constants */ /* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers MaxPeers = 1 << 16 // maximum number of configured peers
) )

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2]) return hmac.Equal(mac1[:], msg[smac1:smac2])
} }
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32, recv uint32,
src []byte, src []byte,
) (*MessageCookieReply, error) { ) (*MessageCookieReply, error) {
st.RLock() st.RLock()
// refresh cookie secret // refresh cookie secret
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
if err != nil { if err != nil {
return false return false
} }
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
} }
func (st *CookieGenerator) AddMacs(msg []byte) { func (st *CookieGenerator) AddMacs(msg []byte) {
size := len(msg) size := len(msg)
smac2 := size - blake2s.Size128 smac2 := size - blake2s.Size128

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -10,7 +10,6 @@ import (
) )
func TestCookieMAC1(t *testing.T) { func TestCookieMAC1(t *testing.T) {
// setup generator / checker // setup generator / checker
var ( var (
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20 msg[5] ^= 0x20
srcBad1 := []byte{192, 168, 13, 37, 40, 01} srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) { if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }
srcBad2 := []byte{192, 168, 13, 38, 40, 01} srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) { if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed") t.Fatal("MAC2 generation/verification failed")
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -11,9 +11,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
@@ -33,7 +30,7 @@ type Device struct {
// will become the actual state; Up can fail. // will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use. // The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only. // Unsynchronized uses of state must therefore be advisory/best-effort only.
state uint32 // actually a deviceState, but typed uint32 for convenience state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed. // stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup stopping sync.WaitGroup
// mu protects state changes. // mu protects state changes.
@@ -47,6 +44,7 @@ type Device struct {
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
port uint16 // listening port port uint16 // listening port
fwmark uint32 // mark value (0 = disabled) fwmark uint32 // mark value (0 = disabled)
brokenRoaming bool
} }
staticIdentity struct { staticIdentity struct {
@@ -56,20 +54,19 @@ type Device struct {
} }
peers struct { peers struct {
empty AtomicBool // empty reports whether len(keyMap) == 0 sync.RWMutex // protects keyMap
sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer keyMap map[NoisePublicKey]*Peer
} }
rate struct {
underLoadUntil atomic.Int64
limiter ratelimiter.Ratelimiter
}
allowedips AllowedIPs allowedips AllowedIPs
indexTable IndexTable indexTable IndexTable
cookieChecker CookieChecker cookieChecker CookieChecker
rate struct {
underLoadUntil int64
limiter ratelimiter.Ratelimiter
}
pool struct { pool struct {
messageBuffers *WaitPool messageBuffers *WaitPool
inboundElements *WaitPool inboundElements *WaitPool
@@ -84,7 +81,7 @@ type Device struct {
tun struct { tun struct {
device tun.Device device tun.Device
mtu int32 mtu atomic.Int32
} }
ipcMutex sync.RWMutex ipcMutex sync.RWMutex
@@ -96,10 +93,9 @@ type Device struct {
// There are three states: down, up, closed. // There are three states: down, up, closed.
// Transitions: // Transitions:
// //
// down -----+ // down -----+
// ↑↓ ↓ // ↑↓ ↓
// up -> closed // up -> closed
//
type deviceState uint32 type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -112,7 +108,7 @@ const (
// deviceState returns device.state.state as a deviceState // deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value. // See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState { func (device *Device) deviceState() deviceState {
return deviceState(atomic.LoadUint32(&device.state.state)) return deviceState(device.state.state.Load())
} }
// isClosed reports whether the device is closed (or is closing). // isClosed reports whether the device is closed (or is closing).
@@ -135,7 +131,6 @@ func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
// remove from peer map // remove from peer map
delete(device.peers.keyMap, key) delete(device.peers.keyMap, key)
device.peers.empty.Set(len(device.peers.keyMap) == 0)
} }
// changeState attempts to change the device state to match want. // changeState attempts to change the device state to match want.
@@ -152,14 +147,14 @@ func (device *Device) changeState(want deviceState) (err error) {
case old: case old:
return nil return nil
case deviceStateUp: case deviceStateUp:
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked() err = device.upLocked()
if err == nil { if err == nil {
break break
} }
fallthrough // up failed; bring the device all the way back down fallthrough // up failed; bring the device all the way back down
case deviceStateDown: case deviceStateDown:
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked() errDown := device.downLocked()
if err == nil { if err == nil {
err = errDown err = errDown
@@ -177,10 +172,15 @@ func (device *Device) upLocked() error {
return err return err
} }
// The IPC set operation waits for peers to be created before calling Start() on them,
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
device.ipcMutex.Lock()
defer device.ipcMutex.Unlock()
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Start() peer.Start()
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@@ -215,13 +215,13 @@ func (device *Device) Down() error {
func (device *Device) IsUnderLoad() bool { func (device *Device) IsUnderLoad() bool {
// check if currently under load // check if currently under load
now := time.Now() now := time.Now()
underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad { if underLoad {
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano()) device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true return true
} }
// check if recently under load // check if recently under load
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano() return device.rate.underLoadUntil.Load() > now.UnixNano()
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -265,7 +265,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
@@ -279,18 +279,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil return nil
} }
func NewDevice(tunDevice tun.Device, logger *Logger) *Device { func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.state.state = uint32(deviceStateDown) device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{}) device.closed = make(chan struct{})
device.log = logger device.log = logger
device.net.bind = bind
device.tun.device = tunDevice device.tun.device = tunDevice
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
if err != nil { if err != nil {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err) device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU mtu = DefaultMTU
} }
device.tun.mtu = int32(mtu) device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init() device.rate.limiter.Init()
device.indexTable.Init() device.indexTable.Init()
@@ -302,20 +303,15 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device.queue.encryption = newOutboundQueue() device.queue.encryption = newOutboundQueue()
device.queue.decryption = newInboundQueue() device.queue.decryption = newInboundQueue()
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers // start workers
cpus := runtime.NumCPU() cpus := runtime.NumCPU()
device.state.stopping.Wait() device.state.stopping.Wait()
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
for i := 0; i < cpus; i++ { for i := 0; i < cpus; i++ {
go device.RoutineEncryption() go device.RoutineEncryption(i + 1)
go device.RoutineDecryption() go device.RoutineDecryption(i + 1)
go device.RoutineHandshake() go device.RoutineHandshake(i + 1)
} }
device.state.stopping.Add(1) // RoutineReadFromTUN device.state.stopping.Add(1) // RoutineReadFromTUN
@@ -361,7 +357,7 @@ func (device *Device) Close() {
if device.isClosed() { if device.isClosed() {
return return
} }
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing") device.log.Verbosef("Device closing")
device.tun.device.Close() device.tun.device.Close()
@@ -406,7 +402,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RUnlock() device.peers.RUnlock()
} }
func unsafeCloseBind(device *Device) error { // closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
var err error var err error
netc := &device.net netc := &device.net
if netc.netlinkCancel != nil { if netc.netlinkCancel != nil {
@@ -414,7 +412,6 @@ func unsafeCloseBind(device *Device) error {
} }
if netc.bind != nil { if netc.bind != nil {
err = netc.bind.Close() err = netc.bind.Close()
netc.bind = nil
} }
netc.stopping.Wait() netc.stopping.Wait()
return err return err
@@ -462,7 +459,7 @@ func (device *Device) BindUpdate() error {
defer device.net.Unlock() defer device.net.Unlock()
// close existing sockets // close existing sockets
if err := unsafeCloseBind(device); err != nil { if err := closeBindLocked(device); err != nil {
return err return err
} }
@@ -473,17 +470,16 @@ func (device *Device) BindUpdate() error {
// bind to new port // bind to new port
var err error var err error
var recvFns []conn.ReceiveFunc
netc := &device.net netc := &device.net
netc.bind, netc.port, err = conn.CreateBind(netc.port) recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil { if err != nil {
netc.bind = nil
netc.port = 0 netc.port = 0
return err return err
} }
netc.netlinkCancel, err = device.startRouteListener(netc.bind) netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil { if err != nil {
netc.bind.Close() netc.bind.Close()
netc.bind = nil
netc.port = 0 netc.port = 0
return err return err
} }
@@ -508,11 +504,12 @@ func (device *Device) BindUpdate() error {
device.peers.RUnlock() device.peers.RUnlock()
// start receiving routines // start receiving routines
device.net.stopping.Add(2) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) for _, fn := range recvFns {
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(fn)
}
device.log.Verbosef("UDP bind has been updated") device.log.Verbosef("UDP bind has been updated")
return nil return nil
@@ -520,7 +517,7 @@ func (device *Device) BindUpdate() error {
func (device *Device) BindClose() error { func (device *Device) BindClose() error {
device.net.Lock() device.net.Lock()
err := unsafeCloseBind(device) err := closeBindLocked(device)
device.net.Unlock() device.net.Unlock()
return err return err
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -8,19 +8,19 @@ package device
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io/ioutil" "io"
"math/rand" "math/rand"
"net" "net/netip"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"testing" "testing"
"time" "time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest" "golang.zx2c4.com/wireguard/tun/tuntest"
) )
@@ -48,7 +48,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other. // genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports. // The configs use distinct, probably-usable ports.
func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) { func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:]) _, err := rand.Read(key1[:])
if err != nil { if err != nil {
@@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct { type testPeer struct {
tun *tuntest.ChannelTUN tun *tuntest.ChannelTUN
dev *Device dev *Device
ip net.IP ip netip.Addr
} }
type SendDirection bool type SendDirection bool
@@ -147,18 +147,24 @@ func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}
} }
// genTestPair creates a testPair. // genTestPair creates a testPair.
func genTestPair(tb testing.TB) (pair testPair) { func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
cfg, endpointCfg := genConfigs(tb) cfg, endpointCfg := genConfigs(tb)
var binds [2]conn.Bind
if realSocket {
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
} else {
binds = bindtest.NewChannelBinds()
}
// Bring up a ChannelTun for each config. // Bring up a ChannelTun for each config.
for i := range pair { for i := range pair {
p := &pair[i] p := &pair[i]
p.tun = tuntest.NewChannelTUN() p.tun = tuntest.NewChannelTUN()
p.ip = net.IPv4(1, 0, 0, byte(i+1)) p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() { if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError level = LogLevelError
} }
p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i))) p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil { if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err) tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close() p.dev.Close()
@@ -186,7 +192,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
func TestTwoDevicePing(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
pair := genTestPair(t) pair := genTestPair(t, true)
t.Run("ping 1.0.0.1", func(t *testing.T) { t.Run("ping 1.0.0.1", func(t *testing.T) {
pair.Send(t, Ping, nil) pair.Send(t, Ping, nil)
}) })
@@ -197,11 +203,11 @@ func TestTwoDevicePing(t *testing.T) {
func TestUpDown(t *testing.T) { func TestUpDown(t *testing.T) {
goroutineLeakCheck(t) goroutineLeakCheck(t)
const itrials = 20 const itrials = 50
const otrials = 1 const otrials = 10
for n := 0; n < otrials; n++ { for n := 0; n < otrials; n++ {
pair := genTestPair(t) pair := genTestPair(t, false)
for i := range pair { for i := range pair {
for k := range pair[i].dev.peers.keyMap { for k := range pair[i].dev.peers.keyMap {
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
@@ -213,17 +219,8 @@ func TestUpDown(t *testing.T) {
go func(d *Device) { go func(d *Device) {
defer wg.Done() defer wg.Done()
for i := 0; i < itrials; i++ { for i := 0; i < itrials; i++ {
start := time.Now() if err := d.Up(); err != nil {
for { t.Errorf("failed up bring up device: %v", err)
if err := d.Up(); err != nil {
if errors.Is(err, syscall.EADDRINUSE) && time.Now().Sub(start) < time.Second*4 {
// Some other test process is racing with us, so try again.
time.Sleep(time.Millisecond * 10)
continue
}
t.Errorf("failed up bring up device: %v", err)
}
break
} }
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
if err := d.Down(); err != nil { if err := d.Down(); err != nil {
@@ -244,7 +241,7 @@ func TestUpDown(t *testing.T) {
// TestConcurrencySafety does other things concurrently with tunnel use. // TestConcurrencySafety does other things concurrently with tunnel use.
// It is intended to be used with the race detector to catch data races. // It is intended to be used with the race detector to catch data races.
func TestConcurrencySafety(t *testing.T) { func TestConcurrencySafety(t *testing.T) {
pair := genTestPair(t) pair := genTestPair(t, true)
done := make(chan struct{}) done := make(chan struct{})
const warmupIters = 10 const warmupIters = 10
@@ -313,32 +310,8 @@ func TestConcurrencySafety(t *testing.T) {
close(done) close(done)
} }
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}
func BenchmarkLatency(b *testing.B) { func BenchmarkLatency(b *testing.B) {
pair := genTestPair(b) pair := genTestPair(b, true)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@@ -352,7 +325,7 @@ func BenchmarkLatency(b *testing.B) {
} }
func BenchmarkThroughput(b *testing.B) { func BenchmarkThroughput(b *testing.B) {
pair := genTestPair(b) pair := genTestPair(b, true)
// Establish a connection. // Establish a connection.
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
@@ -360,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
// Measure how long it takes to receive b.N packets, // Measure how long it takes to receive b.N packets,
// starting when we receive the first packet. // starting when we receive the first packet.
var recv uint64 var recv atomic.Uint64
var elapsed time.Duration var elapsed time.Duration
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
@@ -369,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time var start time.Time
for { for {
<-pair[0].tun.Inbound <-pair[0].tun.Inbound
new := atomic.AddUint64(&recv, 1) new := recv.Add(1)
if new == 1 { if new == 1 {
start = time.Now() start = time.Now()
} }
@@ -385,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip) ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound pingc := pair[1].tun.Outbound
var sent uint64 var sent uint64
for atomic.LoadUint64(&recv) != uint64(b.N) { for recv.Load() != uint64(b.N) {
sent++ sent++
pingc <- ping pingc <- ping
} }
@@ -396,13 +369,13 @@ func BenchmarkThroughput(b *testing.B) {
} }
func BenchmarkUAPIGet(b *testing.B) { func BenchmarkUAPIGet(b *testing.B) {
pair := genTestPair(b) pair := genTestPair(b, true)
pair.Send(b, Ping, nil) pair.Send(b, Ping, nil)
pair.Send(b, Pong, nil) pair.Send(b, Pong, nil)
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pair[0].dev.IpcGetOperation(ioutil.Discard) pair[0].dev.IpcGetOperation(io.Discard)
} }
} }

View File

@@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"math/rand" "math/rand"
"net" "net/netip"
) )
type DummyEndpoint struct { type DummyEndpoint struct {
src [16]byte src, dst netip.Addr
dst [16]byte
} }
func CreateDummyEndpoint() (*DummyEndpoint, error) { func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint var src, dst [16]byte
if _, err := rand.Read(end.src[:]); err != nil { if _, err := rand.Read(src[:]); err != nil {
return nil, err return nil, err
} }
_, err := rand.Read(end.dst[:]) _, err := rand.Read(dst[:])
return &end, err return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
} }
func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string { func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.SrcIP(), 1000).String()
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) DstToString() string { func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr return netip.AddrPortFrom(e.DstIP(), 1000).String()
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
} }
func (e *DummyEndpoint) SrcToBytes() []byte { func (e *DummyEndpoint) DstToBytes() []byte {
return e.src[:] out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
} }
func (e *DummyEndpoint) DstIP() net.IP { func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst[:] return e.dst
} }
func (e *DummyEndpoint) SrcIP() net.IP { func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src[:] return e.src
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -20,7 +20,7 @@ type KDFTest struct {
t2 string t2 string
} }
func assertEquals(t *testing.T, a string, b string) { func assertEquals(t *testing.T, a, b string) {
if a != b { if a != b {
t.Fatal("expected", a, "=", b) t.Fatal("expected", a, "=", b)
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -10,7 +10,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"golang.zx2c4.com/wireguard/replay" "golang.zx2c4.com/wireguard/replay"
) )
@@ -23,7 +22,7 @@ import (
*/ */
type Keypair struct { type Keypair struct {
sendNonce uint64 // accessed atomically sendNonce atomic.Uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter replay.Filter replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex sync.RWMutex
current *Keypair current *Keypair
previous *Keypair previous *Keypair
next *Keypair next atomic.Pointer[Keypair]
}
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
} }
func (kp *Keypairs) Current() *Keypair { func (kp *Keypairs) Current() *Keypair {

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -16,8 +16,8 @@ import (
// They do not require a trailing newline in the format. // They do not require a trailing newline in the format.
// If nil, that level of logging will be silent. // If nil, that level of logging will be silent.
type Logger struct { type Logger struct {
Verbosef func(format string, args ...interface{}) Verbosef func(format string, args ...any)
Errorf func(format string, args ...interface{}) Errorf func(format string, args ...any)
} }
// Log levels for use with NewLogger. // Log levels for use with NewLogger.
@@ -28,14 +28,14 @@ const (
) )
// Function for use in Logger for discarding logged lines. // Function for use in Logger for discarding logged lines.
func DiscardLogf(format string, args ...interface{}) {} func DiscardLogf(format string, args ...any) {}
// NewLogger constructs a Logger that writes to stdout. // NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above. // It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend. // It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger { func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf} logger := &Logger{DiscardLogf, DiscardLogf}
logf := func(prefix string) func(string, ...interface{}) { logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
} }
if level >= LogLevelVerbose { if level >= LogLevelVerbose {

View File

@@ -1,48 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.int32) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.int32, flag)
}
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

View File

@@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
// though it will try to deal with it, and race maybe, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.net.brokenRoaming = true
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.Lock()
defer peer.Unlock()
peer.disableRoaming = peer.endpoint != nil peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
} }
device.peers.RUnlock() device.peers.RUnlock()
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"errors"
"hash" "hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return return
} }
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { var errInvalidPublicKey = errors.New("invalid public key")
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk) apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk) ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk) curve25519.ScalarMult(&ss, ask, apk)
return ss if isZero(ss[:]) {
return ss, errInvalidPublicKey
}
return ss, nil
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -20,7 +20,6 @@ import (
type handshakeState int type handshakeState int
// TODO(crawshaw): add commentary describing each state and the transitions
const ( const (
handshakeZeroed = handshakeState(iota) handshakeZeroed = handshakeState(iota)
handshakeInitiationCreated handshakeInitiationCreated
@@ -139,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte ZeroNonce [chacha20poly1305.NonceSize]byte
) )
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data) KDF1(dst, c[:], data)
} }
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hash.Write(h[:]) hash.Write(h[:])
hash.Write(data) hash.Write(data)
@@ -176,8 +175,6 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@@ -205,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) { if err != nil {
return nil, errZeroECDHResult return nil, err
} }
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
KDF2( KDF2(
@@ -222,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult return nil, errInvalidPublicKey
} }
KDF2( KDF2(
&handshake.chainKey, &handshake.chainKey,
@@ -265,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) { if err != nil {
return nil return nil
} }
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -283,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer // lookup peer
peer := device.LookupPeer(peerPK) peer := device.LookupPeer(peerPK)
if peer == nil { if peer == nil || !peer.isRunning.Load() {
return nil return nil
} }
@@ -385,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) if err != nil {
handshake.mixKey(ss[:]) return nil, err
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
handshake.mixKey(ss[:]) handshake.mixKey(ss[:])
}() ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if err != nil {
return nil, err
}
handshake.mixKey(ss[:])
// add preshared key // add preshared key
@@ -407,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:]) handshake.mixHash(tau[:])
func() { aead, _ := chacha20poly1305.New(key[:])
aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.mixHash(msg.Empty[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = handshakeResponseCreated handshake.state = handshakeResponseCreated
@@ -437,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
) )
ok := func() bool { ok := func() bool {
// lock handshake state // lock handshake state
handshake.mutex.RLock() handshake.mutex.RLock()
@@ -457,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() { ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
func() { ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if err != nil {
mixKey(&chainKey, &chainKey, ss[:]) return false
setZero(ss[:]) }
}() mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
// add preshared key (psk) // add preshared key (psk)
@@ -485,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript // authenticate transcript
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil { if err != nil {
return false return false
} }
@@ -583,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock() defer keypairs.Unlock()
previous := keypairs.previous previous := keypairs.previous
next := keypairs.loadNext() next := keypairs.next.Load()
current := keypairs.current current := keypairs.current
if isInitiator { if isInitiator {
if next != nil { if next != nil {
keypairs.storeNext(nil) keypairs.next.Store(nil)
keypairs.previous = next keypairs.previous = next
device.DeleteKeypair(current) device.DeleteKeypair(current)
} else { } else {
@@ -597,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
keypairs.current = keypair keypairs.current = keypair
} else { } else {
keypairs.storeNext(keypair) keypairs.next.Store(keypair)
device.DeleteKeypair(next) device.DeleteKeypair(next)
keypairs.previous = nil keypairs.previous = nil
device.DeleteKeypair(previous) device.DeleteKeypair(previous)
@@ -609,18 +608,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs keypairs := &peer.keypairs
if keypairs.loadNext() != receivedKeypair { if keypairs.next.Load() != receivedKeypair {
return false return false
} }
keypairs.Lock() keypairs.Lock()
defer keypairs.Unlock() defer keypairs.Unlock()
if keypairs.loadNext() != receivedKeypair { if keypairs.next.Load() != receivedKeypair {
return false return false
} }
old := keypairs.previous old := keypairs.previous
keypairs.previous = keypairs.current keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old) peer.device.DeleteKeypair(old)
keypairs.current = keypairs.loadNext() keypairs.current = keypairs.next.Load()
keypairs.storeNext(nil) keypairs.next.Store(nil)
return true return true
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -9,6 +9,9 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"testing" "testing"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun/tuntest"
) )
func TestCurveWrappers(t *testing.T) { func TestCurveWrappers(t *testing.T) {
@@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey() pk1 := sk1.publicKey()
pk2 := sk2.publicKey() pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2) ss1, err1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1) ss2, err2 := sk2.sharedSecret(pk1)
if ss1 != ss2 { if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet") t.Fatal("Failed to compute shared secet")
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := tuntest.NewChannelTUN()
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
device.SetPrivateKey(sk)
return device
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
}
func TestNoiseHandshake(t *testing.T) { func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t) dev1 := randDevice(t)
dev2 := randDevice(t) dev2 := randDevice(t)
@@ -44,6 +71,8 @@ func TestNoiseHandshake(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
peer1.Start()
peer2.Start()
assertEqual( assertEqual(
t, t,
@@ -119,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err) t.Fatal("failed to derive keypair for peer 2", err)
} }
key1 := peer1.keypairs.loadNext() key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current key2 := peer2.keypairs.current
// encrypting / decryption test // encrypting / decryption test

View File

@@ -1,15 +1,13 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"container/list" "container/list"
"encoding/base64"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -18,24 +16,16 @@ import (
) )
type Peer struct { type Peer struct {
isRunning AtomicBool isRunning atomic.Bool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint conn.Endpoint endpoint conn.Endpoint
stopping sync.WaitGroup // routines pending stop stopping sync.WaitGroup // routines pending stop
txBytes atomic.Uint64 // bytes send to peer (endpoint)
// These fields are accessed with atomic operations, which must be rxBytes atomic.Uint64 // bytes received from peer
// 64-bit aligned even on 32-bit platforms. Go guarantees that an lastHandshakeNano atomic.Int64 // nano seconds since epoch
// allocated struct will be 64-bit aligned. So we place
// atomically-accessed fields up front, so that they can share in
// this alignment before smaller fields throw it off.
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
disableRoaming bool disableRoaming bool
@@ -45,9 +35,9 @@ type Peer struct {
newHandshake *Timer newHandshake *Timer
zeroKeyMaterial *Timer zeroKeyMaterial *Timer
persistentKeepalive *Timer persistentKeepalive *Timer
handshakeAttempts uint32 handshakeAttempts atomic.Uint32
needAnotherKeepalive AtomicBool needAnotherKeepalive atomic.Bool
sentLastMinuteHandshake AtomicBool sentLastMinuteHandshake atomic.Bool
} }
state struct { state struct {
@@ -62,7 +52,7 @@ type Peer struct {
cookieGenerator CookieGenerator cookieGenerator CookieGenerator
trieEntries list.List trieEntries list.List
persistentKeepaliveInterval uint32 // accessed atomically persistentKeepaliveInterval atomic.Uint32
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -102,22 +92,18 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH // pre-compute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
peer.endpoint = nil peer.endpoint = nil
// init timers
peer.timersInit()
// add // add
device.peers.keyMap[pk] = peer device.peers.keyMap[pk] = peer
device.peers.empty.Set(false)
// start peer
peer.timersInit()
if peer.device.isUp() {
peer.Start()
}
return peer, nil return peer, nil
} }
@@ -126,13 +112,8 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.RLock() peer.device.net.RLock()
defer peer.device.net.RUnlock() defer peer.device.net.RUnlock()
if peer.device.net.bind == nil { if peer.device.isClosed() {
// Packets can leak through to SendBuffer while the device is closing. return nil
// When that happens, drop them silently to avoid spurious errors.
if peer.device.isClosed() {
return nil
}
return errors.New("no bind")
} }
peer.RLock() peer.RLock()
@@ -144,18 +125,35 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
err := peer.device.net.bind.Send(buffer, peer.endpoint) err := peer.device.net.bind.Send(buffer, peer.endpoint)
if err == nil { if err == nil {
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) peer.txBytes.Add(uint64(len(buffer)))
} }
return err return err
} }
func (peer *Peer) String() string { func (peer *Peer) String() string {
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) // The awful goo that follows is identical to:
abbreviatedKey := "invalid" //
if len(base64Key) == 44 { // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
//
// except that it is considerably more efficient.
src := peer.handshake.remoteStatic
b64 := func(input byte) byte {
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
} }
return fmt.Sprintf("peer(%s)", abbreviatedKey) b := []byte("peer(____…____)")
const first = len("peer(")
const second = len("peer(____…")
b[first+0] = b64((src[0] >> 2) & 63)
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
b[first+3] = b64(src[2] & 63)
b[second+0] = b64(src[29] & 63)
b[second+1] = b64((src[30] >> 2) & 63)
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
b[second+3] = b64((src[31] << 2) & 63)
return string(b)
} }
func (peer *Peer) Start() { func (peer *Peer) Start() {
@@ -168,12 +166,12 @@ func (peer *Peer) Start() {
peer.state.Lock() peer.state.Lock()
defer peer.state.Unlock() defer peer.state.Unlock()
if peer.isRunning.Get() { if peer.isRunning.Load() {
return return
} }
device := peer.device device := peer.device
device.log.Verbosef("%v - Starting...", peer) device.log.Verbosef("%v - Starting", peer)
// reset routine state // reset routine state
peer.stopping.Wait() peer.stopping.Wait()
@@ -192,7 +190,7 @@ func (peer *Peer) Start() {
go peer.RoutineSequentialSender() go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver() go peer.RoutineSequentialReceiver()
peer.isRunning.Set(true) peer.isRunning.Store(true)
} }
func (peer *Peer) ZeroAndFlushAll() { func (peer *Peer) ZeroAndFlushAll() {
@@ -204,10 +202,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock() keypairs.Lock()
device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current) device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.loadNext()) device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil keypairs.previous = nil
keypairs.current = nil keypairs.current = nil
keypairs.storeNext(nil) keypairs.next.Store(nil)
keypairs.Unlock() keypairs.Unlock()
// clear handshake state // clear handshake state
@@ -232,11 +230,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs keypairs := &peer.keypairs
keypairs.Lock() keypairs.Lock()
if keypairs.current != nil { if keypairs.current != nil {
atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages) keypairs.current.sendNonce.Store(RejectAfterMessages)
} }
if keypairs.next != nil { if next := keypairs.next.Load(); next != nil {
next := keypairs.loadNext() next.sendNonce.Store(RejectAfterMessages)
atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
} }
keypairs.Unlock() keypairs.Unlock()
} }
@@ -249,7 +246,7 @@ func (peer *Peer) Stop() {
return return
} }
peer.device.log.Verbosef("%v - Stopping...", peer) peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop() peer.timersStop()
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -14,45 +14,45 @@ type WaitPool struct {
pool sync.Pool pool sync.Pool
cond sync.Cond cond sync.Cond
lock sync.Mutex lock sync.Mutex
count uint32 count atomic.Uint32
max uint32 max uint32
} }
func NewWaitPool(max uint32, new func() interface{}) *WaitPool { func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max} p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock} p.cond = sync.Cond{L: &p.lock}
return p return p
} }
func (p *WaitPool) Get() interface{} { func (p *WaitPool) Get() any {
if p.max != 0 { if p.max != 0 {
p.lock.Lock() p.lock.Lock()
for atomic.LoadUint32(&p.count) >= p.max { for p.count.Load() >= p.max {
p.cond.Wait() p.cond.Wait()
} }
atomic.AddUint32(&p.count, 1) p.count.Add(1)
p.lock.Unlock() p.lock.Unlock()
} }
return p.pool.Get() return p.pool.Get()
} }
func (p *WaitPool) Put(x interface{}) { func (p *WaitPool) Put(x any) {
p.pool.Put(x) p.pool.Put(x)
if p.max == 0 { if p.max == 0 {
return return
} }
atomic.AddUint32(&p.count, ^uint32(0)) p.count.Add(^uint32(0))
p.cond.Signal() p.cond.Signal()
} }
func (device *Device) PopulatePools() { func (device *Device) PopulatePools() {
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte) return new([MaxMessageSize]byte)
}) })
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement) return new(QueueInboundElement)
}) })
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement) return new(QueueOutboundElement)
}) })
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -15,30 +15,33 @@ import (
) )
func TestWaitPool(t *testing.T) { func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup var wg sync.WaitGroup
trials := int32(100000) var trials atomic.Int32
startTrials := int32(100000)
if raceEnabled { if raceEnabled {
// This test can be very slow with -race. // This test can be very slow with -race.
trials /= 10 startTrials /= 10
} }
trials.Store(startTrials)
workers := runtime.NumCPU() + 2 workers := runtime.NumCPU() + 2
if workers-4 <= 0 { if workers-4 <= 0 {
t.Skip("Not enough cores") t.Skip("Not enough cores")
} }
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers) wg.Add(workers)
max := uint32(0) var max atomic.Uint32
updateMax := func() { updateMax := func() {
count := atomic.LoadUint32(&p.count) count := p.count.Load()
if count > p.max { if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max) t.Errorf("count (%d) > max (%d)", count, p.max)
} }
for { for {
old := atomic.LoadUint32(&max) old := max.Load()
if count <= old { if count <= old {
break break
} }
if atomic.CompareAndSwapUint32(&max, old, count) { if max.CompareAndSwap(old, count) {
break break
} }
} }
@@ -46,7 +49,7 @@ func TestWaitPool(t *testing.T) {
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
for atomic.AddInt32(&trials, -1) > 0 { for trials.Add(-1) > 0 {
updateMax() updateMax()
x := p.Get() x := p.Get()
updateMax() updateMax()
@@ -58,25 +61,26 @@ func TestWaitPool(t *testing.T) {
}() }()
} }
wg.Wait() wg.Wait()
if max != p.max { if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
} }
} }
func BenchmarkWaitPool(b *testing.B) { func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup var wg sync.WaitGroup
trials := int32(b.N) var trials atomic.Int32
trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2 workers := runtime.NumCPU() + 2
if workers-4 <= 0 { if workers-4 <= 0 {
b.Skip("Not enough cores") b.Skip("Not enough cores")
} }
p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) }) p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers) wg.Add(workers)
b.ResetTimer() b.ResetTimer()
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
for atomic.AddInt32(&trials, -1) > 0 { for trials.Add(-1) > 0 {
x := p.Get() x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x) p.Put(x)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,8 +1,8 @@
// +build !android,!ios //go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,19 +1,21 @@
// +build ios //go:build ios
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ // Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
// These are vars instead of consts, because heavier network extensions might want to reduce
const ( // them further.
QueueStagedSize = 128 var (
QueueOutboundSize = 1024 QueueStagedSize = 128
QueueInboundSize = 1024 QueueOutboundSize = 1024
QueueHandshakeSize = 1024 QueueInboundSize = 1024
MaxSegmentSize = 1700 QueueHandshakeSize = 1024
PreallocatedBuffersPerPool = 1024 PreallocatedBuffersPerPool uint32 = 1024
) )
const MaxSegmentSize = 1700

View File

@@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
const (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)

View File

@@ -1,8 +1,8 @@
//+build !race //go:build !race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,8 +1,8 @@
//+build race //go:build race
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -11,13 +11,11 @@ import (
"errors" "errors"
"net" "net"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
) )
@@ -53,12 +51,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver! * NOTE: Not thread safe, but called by sequential receiver!
*/ */
func (peer *Peer) keepKeyFreshReceiving() { func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Get() { if peer.timers.sentLastMinuteHandshake.Load() {
return return
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Set(true) peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@@ -68,15 +66,16 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done() device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed // receive datagrams until conn is closed
@@ -90,24 +89,21 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
) )
for { for {
switch IP { size, endpoint, err = recv(buffer[:])
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
if err != nil { if err != nil {
device.PutMessageBuffer(buffer) device.PutMessageBuffer(buffer)
if errors.Is(err, conn.NetErrClosed) { if errors.Is(err, net.ErrClosed) {
return
}
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return return
} }
device.log.Errorf("Failed to receive packet: %v", err)
if deathSpiral < 10 { if deathSpiral < 10 {
deathSpiral++ deathSpiral++
time.Sleep(time.Second / 3) time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue continue
} }
return return
@@ -166,7 +162,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
elem.Lock() elem.Lock()
// add to decryption queues // add to decryption queues
if peer.isRunning.Get() { if peer.isRunning.Load() {
peer.queue.inbound.c <- elem peer.queue.inbound.c <- elem
device.queue.decryption.c <- elem device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
@@ -205,11 +201,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
} }
} }
func (device *Device) RoutineDecryption() { func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: decryption worker - stopped") defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker - started") device.log.Verbosef("Routine: decryption worker %d - started", id)
for elem := range device.queue.decryption.c { for elem := range device.queue.decryption.c {
// split message into fields // split message into fields
@@ -236,12 +232,12 @@ func (device *Device) RoutineDecryption() {
/* Handles incoming packets related to handshake /* Handles incoming packets related to handshake
*/ */
func (device *Device) RoutineHandshake() { func (device *Device) RoutineHandshake(id int) {
defer func() { defer func() {
device.log.Verbosef("Routine: handshake worker - stopped") device.log.Verbosef("Routine: handshake worker %d - stopped", id)
device.queue.encryption.wg.Done() device.queue.encryption.wg.Done()
}() }()
device.log.Verbosef("Routine: handshake worker - started") device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c { for elem := range device.queue.handshake.c {
@@ -271,7 +267,7 @@ func (device *Device) RoutineHandshake() {
// consume reply // consume reply
if peer := entry.peer; peer.isRunning.Get() { if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) { if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response") device.log.Verbosef("Could not decrypt invalid cookie response")
@@ -344,7 +340,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer) device.log.Verbosef("%v - Received handshake initiation", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
@@ -372,7 +368,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer) device.log.Verbosef("%v - Received handshake response", peer)
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers // update timers
@@ -429,7 +425,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.keepKeyFreshReceiving() peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived() peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer) device.log.Verbosef("%v - Receiving keepalive packet", peer)
@@ -449,7 +445,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip goto skip
} }
@@ -466,7 +462,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
} }
elem.packet = elem.packet[:length] elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer { if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip goto skip
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -8,9 +8,10 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"net" "net"
"os"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@@ -74,7 +75,7 @@ func (elem *QueueOutboundElement) clearPointers() {
/* Queues a keepalive if no packets are queued for peer /* Queues a keepalive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepalive() { func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Get() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
select { select {
case peer.queue.staged <- elem: case peer.queue.staged <- elem:
@@ -89,7 +90,7 @@ func (peer *Peer) SendKeepalive() {
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry { if !isRetry {
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
} }
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
@@ -191,7 +192,7 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil { if keypair == nil {
return return
} }
nonce := atomic.LoadUint64(&keypair.sendNonce) nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
@@ -224,11 +225,12 @@ func (device *Device) RoutineReadFromTUN() {
offset := MessageTransportHeaderSize offset := MessageTransportHeaderSize
size, err := device.tun.device.Read(elem.buffer[:], offset) size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil { if err != nil {
if !device.isClosed() { if !device.isClosed() {
device.log.Errorf("Failed to read packet from TUN device: %v", err) if !errors.Is(err, os.ErrClosed) {
device.Close() device.log.Errorf("Failed to read packet from TUN device: %v", err)
}
go device.Close()
} }
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
@@ -250,14 +252,14 @@ func (device *Device) RoutineReadFromTUN() {
continue continue
} }
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst) peer = device.allowedips.Lookup(dst)
case ipv6.Version: case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen { if len(elem.packet) < ipv6.HeaderLen {
continue continue
} }
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst) peer = device.allowedips.Lookup(dst)
default: default:
device.log.Verbosef("Received packet with unknown IP version") device.log.Verbosef("Received packet with unknown IP version")
@@ -266,7 +268,7 @@ func (device *Device) RoutineReadFromTUN() {
if peer == nil { if peer == nil {
continue continue
} }
if peer.isRunning.Get() { if peer.isRunning.Load() {
peer.StagePacket(elem) peer.StagePacket(elem)
elem = nil elem = nil
peer.SendStagedPackets() peer.SendStagedPackets()
@@ -297,7 +299,7 @@ top:
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
return return
} }
@@ -306,9 +308,9 @@ top:
select { select {
case elem := <-peer.queue.staged: case elem := <-peer.queue.staged:
elem.peer = peer elem.peer = peer
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) keypair.sendNonce.Store(RejectAfterMessages)
peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
goto top goto top
} }
@@ -317,7 +319,7 @@ top:
elem.Lock() elem.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
if peer.isRunning.Get() { if peer.isRunning.Load() {
peer.queue.outbound.c <- elem peer.queue.outbound.c <- elem
peer.device.queue.encryption.c <- elem peer.device.queue.encryption.c <- elem
} else { } else {
@@ -362,12 +364,12 @@ func calculatePaddingSize(packetSize, mtu int) int {
* *
* Obs. One instance per core * Obs. One instance per core
*/ */
func (device *Device) RoutineEncryption() { func (device *Device) RoutineEncryption(id int) {
var paddingZeros [PaddingMultiple]byte var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte var nonce [chacha20poly1305.NonceSize]byte
defer device.log.Verbosef("Routine: encryption worker - stopped") defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker - started") device.log.Verbosef("Routine: encryption worker %d - started", id)
for elem := range device.queue.encryption.c { for elem := range device.queue.encryption.c {
// populate header fields // populate header fields
@@ -382,7 +384,7 @@ func (device *Device) RoutineEncryption() {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16 // pad content to multiple of 16
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer // encrypt content and release to consumer
@@ -416,12 +418,12 @@ func (peer *Peer) RoutineSequentialSender() {
return return
} }
elem.Lock() elem.Lock()
if !peer.isRunning.Get() { if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool. // peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped // This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed. // immediately after this check, in which case, elem will get processed.
// The timers and SendBuffer code are resilient to a few stragglers. // The timers and SendBuffer code are resilient to a few stragglers.
// TODO(josharian): rework peer shutdown order to ensure // TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary. // that we never accidentally keep timers alive longer than necessary.
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)

View File

@@ -1,4 +1,4 @@
// +build !linux android //go:build !linux
package device package device

View File

@@ -1,8 +1,6 @@
// +build !android
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* *
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
@@ -21,11 +19,16 @@ import (
"unsafe" "unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
) )
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
return nil, nil
}
netlinkSock, err := createNetlinkRouteSocket() netlinkSock, err := createNetlinkRouteSocket()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -109,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
pePtr.peer.Unlock() pePtr.peer.Unlock()
break break
} }
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
pePtr.peer.Unlock() pePtr.peer.Unlock()
break break
} }
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
pePtr.peer.Unlock() pePtr.peer.Unlock()
} }
attr = attr[attrhdr.Len:] attr = attr[attrhdr.Len:]
@@ -133,7 +136,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
peer.RUnlock() peer.RUnlock()
continue continue
} }
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
if nativeEP == nil { if nativeEP == nil {
peer.RUnlock() peer.RUnlock()
continue continue
@@ -176,7 +179,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8, Len: 8,
Type: unix.RTA_MARK, Type: unix.RTA_MARK,
}, },
uint32(bind.LastMark()), device.net.fwmark,
} }
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock() reqPeerLock.Lock()
@@ -201,7 +204,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
} }
func createNetlinkRouteSocket() (int, error) { func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil { if err != nil {
return -1, err return -1, err
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
* *
* This is based heavily on timers.c from the kernel implementation. * This is based heavily on timers.c from the kernel implementation.
*/ */
@@ -8,12 +8,14 @@
package device package device
import ( import (
"math/rand"
"sync" "sync"
"sync/atomic"
"time" "time"
_ "unsafe"
) )
//go:linkname fastrandn runtime.fastrandn
func fastrandn(n uint32) uint32
// A Timer manages time-based aspects of the WireGuard protocol. // A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list. // Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct { type Timer struct {
@@ -71,11 +73,11 @@ func (timer *Timer) IsPending() bool {
} }
func (peer *Peer) timersActive() bool { func (peer *Peer) timersActive() bool {
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() && !peer.device.peers.empty.Get() return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
} }
func expiredRetransmitHandshake(peer *Peer) { func expiredRetransmitHandshake(peer *Peer) {
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() { if peer.timersActive() {
@@ -94,8 +96,8 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
} }
} else { } else {
atomic.AddUint32(&peer.timers.handshakeAttempts, 1) peer.timers.handshakeAttempts.Add(1)
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock() peer.Lock()
@@ -110,8 +112,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) { func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive() peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Get() { if peer.timers.needAnotherKeepalive.Load() {
peer.timers.needAnotherKeepalive.Set(false) peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} }
@@ -127,7 +129,6 @@ func expiredNewHandshake(peer *Peer) {
} }
peer.Unlock() peer.Unlock()
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
@@ -136,7 +137,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
} }
func expiredPersistentKeepalive(peer *Peer) { func expiredPersistentKeepalive(peer *Peer) {
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive() peer.SendKeepalive()
} }
} }
@@ -144,7 +145,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@@ -154,7 +155,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() { if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else { } else {
peer.timers.needAnotherKeepalive.Set(true) peer.timers.needAnotherKeepalive.Store(true)
} }
} }
} }
@@ -176,7 +177,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */ /* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() { func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
} }
} }
@@ -185,9 +186,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Del() peer.timers.retransmitHandshake.Del()
} }
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Set(false) peer.timers.sentLastMinuteHandshake.Store(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) peer.lastHandshakeNano.Store(time.Now().UnixNano())
} }
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -199,7 +200,7 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval) keepalive := peer.persistentKeepaliveInterval.Load()
if keepalive > 0 && peer.timersActive() { if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
} }
@@ -214,9 +215,9 @@ func (peer *Peer) timersInit() {
} }
func (peer *Peer) timersStart() { func (peer *Peer) timersStart() {
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) peer.timers.handshakeAttempts.Store(0)
peer.timers.sentLastMinuteHandshake.Set(false) peer.timers.sentLastMinuteHandshake.Store(false)
peer.timers.needAnotherKeepalive.Set(false) peer.timers.needAnotherKeepalive.Store(false)
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {

View File

@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
import ( import (
"fmt" "fmt"
"sync/atomic"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize mtu = MaxContentSize
} }
old := atomic.SwapInt32(&device.tun.mtu, int32(mtu)) old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu { if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
} }

View File

@@ -1,56 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"os"
"golang.zx2c4.com/wireguard/tun"
)
// newDummyTUN creates a dummy TUN device with the specified name.
func newDummyTUN(name string) tun.Device {
return &dummyTUN{
name: name,
packets: make(chan []byte, 100),
events: make(chan tun.Event, 10),
}
}
// A dummyTUN is a tun.Device which is used in unit tests.
type dummyTUN struct {
name string
mtu int
packets chan []byte
events chan tun.Event
}
func (d *dummyTUN) Events() chan tun.Event { return d.events }
func (*dummyTUN) File() *os.File { return nil }
func (*dummyTUN) Flush() error { return nil }
func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
func (d *dummyTUN) Name() (string, error) { return d.name, nil }
func (d *dummyTUN) Close() error {
close(d.events)
close(d.packets)
return nil
}
func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
buf, ok := <-d.packets
if !ok {
return 0, errors.New("device closed")
}
copy(b[offset:], buf)
return len(buf), nil
}
func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
d.packets <- b[offset:]
return len(b), nil
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package device
@@ -12,13 +12,12 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
@@ -39,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
return s.code return s.code
} }
func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)} return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
} }
var byteBufferPool = &sync.Pool{ var byteBufferPool = &sync.Pool{
New: func() interface{} { return new(bytes.Buffer) }, New: func() any { return new(bytes.Buffer) },
} }
// IpcGetOperation implements the WireGuard configuration protocol "get" operation. // IpcGetOperation implements the WireGuard configuration protocol "get" operation.
@@ -56,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
buf := byteBufferPool.Get().(*bytes.Buffer) buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset() buf.Reset()
defer byteBufferPool.Put(buf) defer byteBufferPool.Put(buf)
sendf := func(format string, args ...interface{}) { sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...) fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n') buf.WriteByte('\n')
} }
@@ -73,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
} }
func() { func() {
// lock required resources // lock required resources
device.net.RLock() device.net.RLock()
@@ -99,33 +97,35 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark) sendf("fwmark=%d", device.net.fwmark)
} }
// serialize each peer state
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.RLock() // Serialize peer state.
defer peer.RUnlock() // Do the work in an anonymous function so that we can use defer.
func() {
peer.RLock()
defer peer.RUnlock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
sendf("protocol_version=1") sendf("protocol_version=1")
if peer.endpoint != nil { if peer.endpoint != nil {
sendf("endpoint=%s", peer.endpoint.DstToString()) sendf("endpoint=%s", peer.endpoint.DstToString())
} }
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds() secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds() nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs) sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano) sendf("last_handshake_time_nsec=%d", nano)
sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)) sendf("tx_bytes=%d", peer.txBytes.Load())
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("rx_bytes=%d", peer.rxBytes.Load())
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr) sendf("allowed_ip=%s", prefix.String())
return true return true
}) })
}()
} }
}() }()
@@ -157,14 +157,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
line := scanner.Text() line := scanner.Text()
if line == "" { if line == "" {
// Blank line means terminate operation. // Blank line means terminate operation.
peer.handlePostConfig()
return nil return nil
} }
parts := strings.Split(line, "=") key, value, ok := strings.Cut(line, "=")
if len(parts) != 2 { if !ok {
return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts)) return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
} }
key := parts[0]
value := parts[1]
if key == "public_key" { if key == "public_key" {
if deviceConfig { if deviceConfig {
@@ -255,10 +254,21 @@ type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on *Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // new reports whether this is a newly created peer created bool // new reports whether this is a newly created peer
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
} }
func (peer *ipcSetPeer) handlePostConfig() { func (peer *ipcSetPeer) handlePostConfig() {
if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() { if peer.Peer == nil || peer.dummy {
return
}
if peer.created {
peer.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint != nil
}
if peer.device.isUp() {
peer.Start()
if peer.pkaOn {
peer.SendKeepalive()
}
peer.SendStagedPackets() peer.SendStagedPackets()
} }
} }
@@ -331,7 +341,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "endpoint": case "endpoint":
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
endpoint, err := conn.CreateEndpoint(value) endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
} }
@@ -347,17 +357,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
} }
old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send immediate keepalive if we're turning it on and before it wasn't on. // Send immediate keepalive if we're turning it on and before it wasn't on.
if old == 0 && secs != 0 { peer.pkaOn = old == 0 && secs != 0
if err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
}
if device.isUp() && !peer.dummy {
peer.SendKeepalive()
}
}
case "replace_allowed_ips": case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
@@ -371,16 +374,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip": case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
prefix, err := netip.ParsePrefix(value)
_, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
} }
if peer.dummy { if peer.dummy {
return nil return nil
} }
ones, _ := network.Mask.Size() device.allowedips.Insert(prefix, peer.Peer)
device.allowedips.Insert(network.IP, uint(ones), peer.Peer)
case "protocol_version": case "protocol_version":
if value != "1" { if value != "1" {

51
format_test.go Normal file
View File

@@ -0,0 +1,51 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"go/format"
"io/fs"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
)
func TestFormatting(t *testing.T) {
var wg sync.WaitGroup
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Errorf("unable to walk %s: %v", path, err)
return nil
}
if d.IsDir() || filepath.Ext(path) != ".go" {
return nil
}
wg.Add(1)
go func(path string) {
defer wg.Done()
src, err := os.ReadFile(path)
if err != nil {
t.Errorf("unable to read %s: %v", path, err)
return
}
if runtime.GOOS == "windows" {
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
}
formatted, err := format.Source(src)
if err != nil {
t.Errorf("unable to format %s: %v", path, err)
return
}
if !bytes.Equal(src, formatted) {
t.Errorf("unformatted code: %s", path)
}
}(path)
return nil
})
wg.Wait()
}

15
go.mod
View File

@@ -1,9 +1,16 @@
module golang.zx2c4.com/wireguard module golang.zx2c4.com/wireguard
go 1.15 go 1.19
require ( require (
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd
golang.org/x/net v0.0.0-20201224014010-6772e930b67b golang.org/x/net v0.0.0-20220225172249-27dd8689420f
golang.org/x/sys v0.0.0-20210105210732-16f7687f5001 golang.org/x/sys v0.2.0
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0
)
require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
) )

31
go.sum
View File

@@ -1,17 +1,14 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd h1:XcWmESyNjXJMLahc3mqVQJcgSTDxFxhETVlfk9uGc38=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/crypto v0.0.0-20220315160706-3147a52a75dd/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b h1:iFwSg7t5GZmB/Q5TjiEAsdoLDrdJRC1RiF2WhuV29Qw= golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc=
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/sys v0.0.0-20210105210732-16f7687f5001 h1:/dSxr6gT0FNI1MO5WLJo8mTmItROeOKTkDn+7OwWBos= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/sys v0.0.0-20210105210732-16f7687f5001/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 h1:Ug9qvr1myri/zFN6xL17LSCBGFDnphBBhzmILHsM5TY=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e h1:FDhOuMEY4JVRztM/gsbk+IKUQ8kj74bxZrgw87eMMVc=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -1,62 +1,31 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
// +build windows // +build windows
/* SPDX-License-Identifier: MIT package namedpipe
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import ( import (
"errors"
"io" "io"
"os"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) swap(new bool) bool {
var newInt int32
if new {
newInt = 1
}
return atomic.SwapInt32((*int32)(b), newInt) == 1
}
const (
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
)
var (
ErrFileClosed = errors.New("file has already been closed")
ErrTimeout = &timeoutError{}
)
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var (
var ioCompletionPort windows.Handle ioInitOnce sync.Once
ioCompletionPort windows.Handle
)
// ioResult contains the result of an asynchronous IO operation // ioResult contains the result of an asynchronous IO operation
type ioResult struct { type ioResult struct {
@@ -71,7 +40,7 @@ type ioOperation struct {
} }
func initIo() { func initIo() {
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -79,13 +48,13 @@ func initIo() {
go ioCompletionProcessor(h) go ioCompletionProcessor(h)
} }
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected. // It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct { type file struct {
handle windows.Handle handle windows.Handle
wg sync.WaitGroup wg sync.WaitGroup
wgLock sync.RWMutex wgLock sync.RWMutex
closing atomicBool closing atomic.Bool
socket bool socket bool
readDeadline deadlineHandler readDeadline deadlineHandler
writeDeadline deadlineHandler writeDeadline deadlineHandler
@@ -96,18 +65,18 @@ type deadlineHandler struct {
channel timeoutChan channel timeoutChan
channelLock sync.RWMutex channelLock sync.RWMutex
timer *time.Timer timer *time.Timer
timedout atomicBool timedout atomic.Bool
} }
// makeWin32File makes a new win32File from an existing file handle // makeFile makes a new file from an existing file handle
func makeWin32File(h windows.Handle) (*win32File, error) { func makeFile(h windows.Handle) (*file, error) {
f := &win32File{handle: h} f := &file{handle: h}
ioInitOnce.Do(initIo) ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -116,18 +85,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) {
return f, nil return f, nil
} }
func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
}
// closeHandle closes the resources associated with a Win32 handle // closeHandle closes the resources associated with a Win32 handle
func (f *win32File) closeHandle() { func (f *file) closeHandle() {
f.wgLock.Lock() f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once. // Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) { if f.closing.Swap(true) == false {
f.wgLock.Unlock() f.wgLock.Unlock()
// cancel all IO and wait for it to complete // cancel all IO and wait for it to complete
cancelIoEx(f.handle, nil) windows.CancelIoEx(f.handle, nil)
f.wg.Wait() f.wg.Wait()
// at this point, no new IO can start // at this point, no new IO can start
windows.Close(f.handle) windows.Close(f.handle)
@@ -137,19 +102,19 @@ func (f *win32File) closeHandle() {
} }
} }
// Close closes a win32File. // Close closes a file.
func (f *win32File) Close() error { func (f *file) Close() error {
f.closeHandle() f.closeHandle()
return nil return nil
} }
// prepareIo prepares for a new IO operation. // prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) { func (f *file) prepareIo() (*ioOperation, error) {
f.wgLock.RLock() f.wgLock.RLock()
if f.closing.isSet() { if f.closing.Load() {
f.wgLock.RUnlock() f.wgLock.RUnlock()
return nil, ErrFileClosed return nil, os.ErrClosed
} }
f.wg.Add(1) f.wg.Add(1)
f.wgLock.RUnlock() f.wgLock.RUnlock()
@@ -164,7 +129,7 @@ func ioCompletionProcessor(h windows.Handle) {
var bytes uint32 var bytes uint32
var key uintptr var key uintptr
var op *ioOperation var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
if op == nil { if op == nil {
panic(err) panic(err)
} }
@@ -174,13 +139,13 @@ func ioCompletionProcessor(h windows.Handle) {
// asyncIo processes the return value from ReadFile or WriteFile, blocking until // asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed. // the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != windows.ERROR_IO_PENDING { if err != windows.ERROR_IO_PENDING {
return int(bytes), err return int(bytes), err
} }
if f.closing.isSet() { if f.closing.Load() {
cancelIoEx(f.handle, &c.o) windows.CancelIoEx(f.handle, &c.o)
} }
var timeout timeoutChan var timeout timeoutChan
@@ -195,20 +160,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == windows.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.isSet() { if f.closing.Load() {
err = ErrFileClosed err = os.ErrClosed
} }
} else if err != nil && f.socket { } else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error. // err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32 var bytes, flags uint32
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
} }
case <-timeout: case <-timeout:
cancelIoEx(f.handle, &c.o) windows.CancelIoEx(f.handle, &c.o)
r = <-c.ch r = <-c.ch
err = r.err err = r.err
if err == windows.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
err = ErrTimeout err = os.ErrDeadlineExceeded
} }
} }
@@ -220,15 +185,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
} }
// Read reads from a file handle. // Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) { func (f *file) Read(b []byte) (int, error) {
c, err := f.prepareIo() c, err := f.prepareIo()
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer f.wg.Done() defer f.wg.Done()
if f.readDeadline.timedout.isSet() { if f.readDeadline.timedout.Load() {
return 0, ErrTimeout return 0, os.ErrDeadlineExceeded
} }
var bytes uint32 var bytes uint32
@@ -247,15 +212,15 @@ func (f *win32File) Read(b []byte) (int, error) {
} }
// Write writes to a file handle. // Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) { func (f *file) Write(b []byte) (int, error) {
c, err := f.prepareIo() c, err := f.prepareIo()
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer f.wg.Done() defer f.wg.Done()
if f.writeDeadline.timedout.isSet() { if f.writeDeadline.timedout.Load() {
return 0, ErrTimeout return 0, os.ErrDeadlineExceeded
} }
var bytes uint32 var bytes uint32
@@ -265,19 +230,19 @@ func (f *win32File) Write(b []byte) (int, error) {
return n, err return n, err
} }
func (f *win32File) SetReadDeadline(deadline time.Time) error { func (f *file) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline) return f.readDeadline.set(deadline)
} }
func (f *win32File) SetWriteDeadline(deadline time.Time) error { func (f *file) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline) return f.writeDeadline.set(deadline)
} }
func (f *win32File) Flush() error { func (f *file) Flush() error {
return windows.FlushFileBuffers(f.handle) return windows.FlushFileBuffers(f.handle)
} }
func (f *win32File) Fd() uintptr { func (f *file) Fd() uintptr {
return uintptr(f.handle) return uintptr(f.handle)
} }
@@ -291,7 +256,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
d.timer = nil d.timer = nil
} }
d.timedout.setFalse() d.timedout.Store(false)
select { select {
case <-d.channel: case <-d.channel:
@@ -306,7 +271,7 @@ func (d *deadlineHandler) set(deadline time.Time) error {
} }
timeoutIO := func() { timeoutIO := func() {
d.timedout.setTrue() d.timedout.Store(true)
close(d.channel) close(d.channel)
} }

486
ipc/namedpipe/namedpipe.go Normal file
View File

@@ -0,0 +1,486 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
package namedpipe
import (
"context"
"io"
"net"
"os"
"runtime"
"sync/atomic"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
type pipe struct {
*file
path string
}
type messageBytePipe struct {
pipe
writeClosed atomic.Bool
readEOF bool
}
type pipeAddress string
func (f *pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *messageBytePipe) CloseWrite() error {
if !f.writeClosed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
err := f.file.Flush()
if err != nil {
f.writeClosed.Store(false)
return err
}
_, err = f.file.Write(nil)
if err != nil {
f.writeClosed.Store(false)
return err
}
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite.
func (f *messageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed.Load() {
return 0, io.ErrClosedPipe
}
if len(b) == 0 {
return 0, nil
}
return f.file.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *messageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.file.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read calls
// also return EOF.
f.readEOF = true
} else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
func (f *pipe) Handle() windows.Handle {
return f.handle
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
path16, err := windows.UTF16PtrFromString(*path)
if err != nil {
return 0, err
}
h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(10 * time.Millisecond)
}
}
}
// DialConfig exposes various options for use in Dial and DialContext.
type DialConfig struct {
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
}
// DialTimeout connects to the specified named pipe by path, timing out if the
// connection takes longer than the specified duration. If timeout is zero, then
// we use a default timeout of 2 seconds.
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
if timeout == 0 {
timeout = time.Second * 2
}
absTimeout := time.Now().Add(timeout)
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := config.DialContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, os.ErrDeadlineExceeded
}
return conn, err
}
// DialContext attempts to connect to the specified named pipe by path.
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
var err error
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if config.ExpectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(config.ExpectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
windows.Close(h)
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite.
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
return &messageBytePipe{
pipe: pipe{file: f, path: path},
}, nil
}
return &pipe{file: f, path: path}, nil
}
var defaultDialer DialConfig
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(path, timeout)
}
// DialContext calls DialConfig.DialContext using an empty configuration.
func DialContext(ctx context.Context, path string) (net.Conn, error) {
return defaultDialer.DialContext(ctx, path)
}
type acceptResponse struct {
f *file
err error
}
type pipeListener struct {
firstHandle windows.Handle
path string
config ListenConfig
acceptCh chan chan acceptResponse
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
path16, err := windows.UTF16PtrFromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa windows.OBJECT_ATTRIBUTES
oa.Length = uint32(unsafe.Sizeof(oa))
var ntPath windows.NTUnicodeString
if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if isFirstPipe {
if sd != nil {
oa.SecurityDescriptor = sd
} else {
// Construct the default named pipe security descriptor.
var acl *windows.ACL
if err := windows.RtlDefaultNpAcl(&acl); err != nil {
return 0, err
}
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
sd, err = windows.NewSecurityDescriptor()
if err != nil {
return 0, err
}
if err = sd.SetDACL(acl, true, false); err != nil {
return 0, err
}
oa.SecurityDescriptor = sd
}
}
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= windows.FILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(windows.FILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if isFirstPipe {
disposition = windows.FILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with isFirstPipe == false.
access = windows.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h windows.Handle
iosb windows.IO_STATUS_BLOCK
)
err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
if err != nil {
if ntstatus, ok := err.(windows.NTStatus); ok {
err = ntstatus.Errno()
}
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *pipeListener) makeServerPipe() (*file, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeFile(h)
if err != nil {
windows.Close(h)
return nil, err
}
return f, nil
}
func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *file) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == os.ErrClosed {
err = net.ErrClosed
}
}
return p, err
}
func (l *pipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *file
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != windows.ERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == net.ErrClosed
}
}
windows.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close and Accept callers that the handle has been closed.
close(l.doneCh)
}
// ListenConfig contains configuration for the pipe listener.
type ListenConfig struct {
// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite is only supported for message mode pipes;
// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
InputBufferSize int32
// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
OutputBufferSize int32
}
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
// The pipe must not already exist.
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
}
l := &pipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan chan acceptResponse),
closeCh: make(chan int),
doneCh: make(chan int),
}
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
path16, err := windows.UTF16PtrFromString(path)
if err == nil {
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
if err == nil {
windows.CloseHandle(h)
}
}
}
go l.listenerRoutine()
return l, nil
}
var defaultListener ListenConfig
// Listen calls ListenConfig.Listen using an empty configuration.
func Listen(path string) (net.Listener, error) {
return defaultListener.Listen(path)
}
func connectPipe(p *file) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = windows.ConnectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *pipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &messageBytePipe{
pipe: pipe{file: response.f, path: l.path},
}, nil
}
return &pipe{file: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, net.ErrClosed
}
}
func (l *pipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *pipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

View File

@@ -0,0 +1,675 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Copyright 2015 Microsoft
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows
// +build windows
package namedpipe_test
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"net"
"os"
"sync"
"syscall"
"testing"
"time"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
)
func randomPipePath() string {
guid, err := windows.GenerateGUID()
if err != nil {
panic(err)
}
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
}
func TestPingPong(t *testing.T) {
const (
ping = 42
pong = 24
)
pipePath := randomPipePath()
listener, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatalf("unable to listen on pipe: %v", err)
}
defer listener.Close()
go func() {
incoming, err := listener.Accept()
if err != nil {
t.Fatalf("unable to accept pipe connection: %v", err)
}
defer incoming.Close()
var data [1]byte
_, err = incoming.Read(data[:])
if err != nil {
t.Fatalf("unable to read ping from pipe: %v", err)
}
if data[0] != ping {
t.Fatalf("expected ping, got %d", data[0])
}
data[0] = pong
_, err = incoming.Write(data[:])
if err != nil {
t.Fatalf("unable to write pong to pipe: %v", err)
}
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatalf("unable to dial pipe: %v", err)
}
defer client.Close()
client.SetDeadline(time.Now().Add(time.Second * 5))
var data [1]byte
data[0] = ping
_, err = client.Write(data[:])
if err != nil {
t.Fatalf("unable to write ping to pipe: %v", err)
}
_, err = client.Read(data[:])
if err != nil {
t.Fatalf("unable to read pong from pipe: %v", err)
}
if data[0] != pong {
t.Fatalf("expected pong, got %d", data[0])
}
}
func TestDialUnknownFailsImmediately(t *testing.T) {
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
if !errors.Is(err, syscall.ENOENT) {
t.Fatalf("expected ENOENT got %v", err)
}
}
func TestDialListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestDialContextListenerTimesOut(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
d := 10 * time.Millisecond
ctx, _ := context.WithTimeout(context.Background(), d)
pipe, err := namedpipe.DialContext(ctx, pipePath)
if err == nil {
pipe.Close()
}
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
}
func TestDialListenerGetsCancelled(t *testing.T) {
pipePath := randomPipePath()
ctx, cancel := context.WithCancel(context.Background())
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan error)
go func(ctx context.Context, ch chan error) {
_, err := namedpipe.DialContext(ctx, pipePath)
ch <- err
}(ctx, ch)
time.Sleep(time.Millisecond * 30)
cancel()
err = <-ch
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
t.Skip("dacls on named pipes are broken on wine")
}
pipePath := randomPipePath()
sd, _ := windows.SecurityDescriptorFromString("D:")
l, err := (&namedpipe.ListenConfig{
SecurityDescriptor: sd,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
pipe.Close()
}
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
}
}
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
pipePath := randomPipePath()
if cfg == nil {
cfg = &namedpipe.ListenConfig{}
}
l, err := cfg.Listen(pipePath)
if err != nil {
return
}
defer l.Close()
type response struct {
c net.Conn
err error
}
ch := make(chan response)
go func() {
c, err := l.Accept()
ch <- response{c, err}
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
return
}
r := <-ch
if err = r.err; err != nil {
c.Close()
return
}
client = c
server = r.c
return
}
func TestReadTimeout(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
buf := make([]byte, 10)
_, err = c.Read(buf)
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func server(l net.Listener, ch chan int) {
c, err := l.Accept()
if err != nil {
panic(err)
}
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
s, err := rw.ReadString('\n')
if err != nil {
panic(err)
}
_, err = rw.WriteString("got " + s)
if err != nil {
panic(err)
}
err = rw.Flush()
if err != nil {
panic(err)
}
c.Close()
ch <- 1
}
func TestFullListenDialReadWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
ch := make(chan int)
go server(l, ch)
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
_, err = rw.WriteString("hello world\n")
if err != nil {
t.Fatal(err)
}
err = rw.Flush()
if err != nil {
t.Fatal(err)
}
s, err := rw.ReadString('\n')
if err != nil {
t.Fatal(err)
}
ms := "got hello world\n"
if s != ms {
t.Errorf("expected '%s', got '%s'", ms, s)
}
<-ch
}
func TestCloseAbortsListen(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
ch := make(chan error)
go func() {
_, err := l.Accept()
ch <- err
}()
time.Sleep(30 * time.Millisecond)
l.Close()
err = <-ch
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
b := make([]byte, 10)
w.Close()
n, err := r.Read(b)
if n > 0 {
t.Errorf("unexpected byte count %d", n)
}
if err != io.EOF {
t.Errorf("expected EOF: %v", err)
}
}
func TestCloseClientEOFServer(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, c, s)
}
func TestCloseServerEOFClient(t *testing.T) {
c, s, err := getConnection(nil)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
ensureEOFOnClose(t, s, c)
}
func TestCloseWriteEOF(t *testing.T) {
cfg := &namedpipe.ListenConfig{
MessageMode: true,
}
c, s, err := getConnection(cfg)
if err != nil {
t.Fatal(err)
}
defer c.Close()
defer s.Close()
type closeWriter interface {
CloseWrite() error
}
err = c.(closeWriter).CloseWrite()
if err != nil {
t.Fatal(err)
}
b := make([]byte, 10)
_, err = s.Read(b)
if err != io.EOF {
t.Fatal(err)
}
}
func TestAcceptAfterCloseFails(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
l.Close()
_, err = l.Accept()
if err != net.ErrClosed {
t.Fatalf("expected net.ErrClosed, got %v", err)
}
}
func TestDialTimesOutByDefault(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
if err == nil {
pipe.Close()
}
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
}
func TestTimeoutPendingRead(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
buf := make([]byte, 10)
_, err = client.Read(buf)
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
client.SetReadDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for read to cancel")
<-clientErr
}
<-serverDone
}
func TestTimeoutPendingWrite(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
serverDone := make(chan struct{})
go func() {
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
time.Sleep(1 * time.Second)
s.Close()
close(serverDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer client.Close()
clientErr := make(chan error)
go func() {
_, err = client.Write([]byte("this should timeout"))
clientErr <- err
}()
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
client.SetWriteDeadline(time.Unix(1, 0))
select {
case err = <-clientErr:
if err != os.ErrDeadlineExceeded {
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("timed out while waiting for write to cancel")
<-clientErr
}
<-serverDone
}
type CloseWriter interface {
CloseWrite() error
}
func TestEchoWithMessaging(t *testing.T) {
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{
MessageMode: true, // Use message mode so that CloseWrite() is supported
InputBufferSize: 65536, // Use 64KB buffers to improve performance
OutputBufferSize: 65536,
}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
listenerDone := make(chan bool)
clientDone := make(chan bool)
go func() {
// server echo
conn, err := l.Accept()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
_, err = io.Copy(conn, conn)
if err != nil {
t.Fatal(err)
}
conn.(CloseWriter).CloseWrite()
close(listenerDone)
}()
client, err := namedpipe.DialTimeout(pipePath, time.Second)
if err != nil {
t.Fatal(err)
}
defer client.Close()
go func() {
// client read back
bytes := make([]byte, 2)
n, e := client.Read(bytes)
if e != nil {
t.Fatal(e)
}
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
t.Fatalf("expected 2 bytes, got %v", n)
}
close(clientDone)
}()
payload := make([]byte, 2)
payload[0] = 0
payload[1] = 1
n, err := client.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("expected 2 bytes, got %v", n)
}
client.(CloseWriter).CloseWrite()
<-listenerDone
<-clientDone
}
func TestConnectRace(t *testing.T) {
pipePath := randomPipePath()
l, err := namedpipe.Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
go func() {
for {
s, err := l.Accept()
if err == net.ErrClosed {
return
}
if err != nil {
t.Fatal(err)
}
s.Close()
}
}()
for i := 0; i < 1000; i++ {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
c.Close()
}
}
func TestMessageReadMode(t *testing.T) {
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
t.Skipf("Skipping on Windows %d", maj)
}
var wg sync.WaitGroup
defer wg.Wait()
pipePath := randomPipePath()
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
if err != nil {
t.Fatal(err)
}
defer l.Close()
msg := ([]byte)("hello world")
wg.Add(1)
go func() {
defer wg.Done()
s, err := l.Accept()
if err != nil {
t.Fatal(err)
}
_, err = s.Write(msg)
if err != nil {
t.Fatal(err)
}
s.Close()
}()
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err != nil {
t.Fatal(err)
}
defer c.Close()
mode := uint32(windows.PIPE_READMODE_MESSAGE)
err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
if err != nil {
t.Fatal(err)
}
ch := make([]byte, 1)
var vmsg []byte
for {
n, err := c.Read(ch)
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
if n != 1 {
t.Fatalf("expected 1, got %d", n)
}
vmsg = append(vmsg, ch[0])
}
if !bytes.Equal(msg, vmsg) {
t.Fatalf("expected %s, got %s", msg, vmsg)
}
}
func TestListenConnectRace(t *testing.T) {
if testing.Short() {
t.Skip("Skipping long race test")
}
pipePath := randomPipePath()
for i := 0; i < 50 && !t.Failed(); i++ {
var wg sync.WaitGroup
wg.Add(1)
go func() {
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
if err == nil {
c.Close()
}
wg.Done()
}()
s, err := namedpipe.Listen(pipePath)
if err != nil {
t.Error(i, err)
} else {
s.Close()
}
wg.Wait()
}
}

View File

@@ -1,8 +1,8 @@
// +build darwin freebsd openbsd //go:build darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@@ -54,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)
@@ -104,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
l.connErr <- err l.connErr <- err
return return
} }
if kerr != nil || n != 1 { if (kerr != nil || n != 1) && kerr != unix.EINTR {
if kerr != nil { if kerr != nil {
l.connErr <- kerr l.connErr <- kerr
} else { } else {

15
ipc/uapi_js.go Normal file
View File

@@ -0,0 +1,15 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package ipc
// Made up sentinel error codes for the js/wasm platform.
const (
IpcErrorIO = 1
IpcErrorInvalid = 2
IpcErrorPortInUse = 3
IpcErrorUnknown = 4
IpcErrorProtocol = 5
)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@@ -51,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr {
} }
func UAPIListen(name string, file *os.File) (net.Listener, error) { func UAPIListen(name string, file *os.File) (net.Listener, error) {
// wrap file in listener // wrap file in listener
listener, err := net.FileListener(file) listener, err := net.FileListener(file)

View File

@@ -1,8 +1,8 @@
// +build linux darwin freebsd openbsd //go:build linux || darwin || freebsd || openbsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@@ -33,7 +33,7 @@ func sockPath(iface string) string {
} }
func UAPIOpen(name string) (*os.File, error) { func UAPIOpen(name string) (*os.File, error) {
if err := os.MkdirAll(socketDirectory, 0755); err != nil { if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
return nil, err return nil, err
} }
@@ -43,7 +43,7 @@ func UAPIOpen(name string) (*os.File, error) {
return nil, err return nil, err
} }
oldUmask := unix.Umask(0077) oldUmask := unix.Umask(0o077)
defer unix.Umask(oldUmask) defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr) listener, err := net.ListenUnix("unix", addr)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ipc package ipc
@@ -9,8 +9,7 @@ import (
"net" "net"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/namedpipe"
"golang.zx2c4.com/wireguard/ipc/winpipe"
) )
// TODO: replace these with actual standard windows error numbers from the win package // TODO: replace these with actual standard windows error numbers from the win package
@@ -54,18 +53,16 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() { func init() {
var err error var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */ UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winpipe.PipeConfig{ listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor, SecurityDescriptor: UAPISecurityDescriptor,
} }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name)
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,9 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go

View File

@@ -1,509 +0,0 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
type ioStatusBlock struct {
Status, Information uintptr
}
type objectAttributes struct {
Length uintptr
RootDirectory uintptr
ObjectName *unicodeString
Attributes uintptr
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
SecurityQoS uintptr
}
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer uintptr
}
type ntstatus int32
func (status ntstatus) Err() error {
if status >= 0 {
return nil
}
return rtlNtStatusToDosError(status)
}
const (
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_TYPE_MESSAGE = 4
cPIPE_READMODE_MESSAGE = 2
cFILE_OPEN = 1
cFILE_CREATE = 2
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
)
var (
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
// This error should match net.errClosing since docker takes a dependency on its text.
ErrPipeListenerClosed = errors.New("use of closed network connection")
errPipeWriteClosed = errors.New("pipe has been closed for write")
)
type win32Pipe struct {
*win32File
path string
}
type win32MessageBytePipe struct {
win32Pipe
writeClosed bool
readEOF bool
}
type pipeAddress string
func (f *win32Pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *win32MessageBytePipe) CloseWrite() error {
if f.writeClosed {
return errPipeWriteClosed
}
err := f.win32File.Flush()
if err != nil {
return err
}
_, err = f.win32File.Write(nil)
if err != nil {
return err
}
f.writeClosed = true
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite().
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed {
return 0, errPipeWriteClosed
}
if len(b) == 0 {
return 0, nil
}
return f.win32File.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.win32File.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read() calls
// also return EOF.
f.readEOF = true
} else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for {
select {
case <-ctx.Done():
return windows.Handle(0), ctx.Err()
default:
h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(time.Millisecond * 10)
}
}
}
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
return conn, err
}
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
var err error
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if expectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(expectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
windows.Close(h)
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
windows.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite().
if flags&cPIPE_TYPE_MESSAGE != 0 {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: f, path: path},
}, nil
}
return &win32Pipe{win32File: f, path: path}, nil
}
type acceptResponse struct {
f *win32File
err error
}
type win32PipeListener struct {
firstHandle windows.Handle
path string
config PipeConfig
acceptCh chan (chan acceptResponse)
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
path16, err := windows.UTF16FromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa objectAttributes
oa.Length = unsafe.Sizeof(oa)
var ntPath unicodeString
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer windows.LocalFree(windows.Handle(ntPath.Buffer))
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if first {
if sd != nil {
oa.SecurityDescriptor = sd
} else {
// Construct the default named pipe security descriptor.
var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
}
defer windows.LocalFree(windows.Handle(dacl))
sd, err := windows.NewSecurityDescriptor()
if err != nil {
return 0, fmt.Errorf("creating new security descriptor: %s", err)
}
if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
return 0, fmt.Errorf("assigning dacl: %s", err)
}
sd, err = sd.ToSelfRelative()
if err != nil {
return 0, fmt.Errorf("converting to self-relative: %s", err)
}
oa.SecurityDescriptor = sd
}
}
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= cFILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(cFILE_OPEN)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first {
disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false.
access = windows.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h windows.Handle
iosb ioStatusBlock
)
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
windows.Close(h)
return nil, err
}
return f, nil
}
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *win32File) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == ErrFileClosed {
err = ErrPipeListenerClosed
}
}
return p, err
}
func (l *win32PipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *win32File
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != windows.ERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == ErrPipeListenerClosed
}
}
windows.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh)
}
// PipeConfig contain configuration for the pipe listener.
type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor.
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite() is only supported for message mode pipes;
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the size the input buffer, in bytes.
InputBufferSize int32
// OutputBufferSize specifies the size the input buffer, in bytes.
OutputBufferSize int32
}
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
if c == nil {
c = &PipeConfig{}
}
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
}
l := &win32PipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan (chan acceptResponse)),
closeCh: make(chan int),
doneCh: make(chan int),
}
go l.listenerRoutine()
return l, nil
}
func connectPipe(p *win32File) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *win32PipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
}, nil
}
return &win32Pipe{win32File: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, ErrPipeListenerClosed
}
}
func (l *win32PipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *win32PipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

View File

@@ -1,238 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winpipe
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
)
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = windows.Handle(r0)
if handle == windows.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = windows.Handle(r0)
if handle == windows.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
ptr = uintptr(r0)
return
}
func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0)
return
}
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
if r0 != 0 {
winerr = syscall.Errno(r0)
}
return
}
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
status = ntstatus(r0)
return
}
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
status = ntstatus(r0)
return
}
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = windows.Handle(r0)
if newport == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32
if wait {
_p0 = 1
} else {
_p0 = 0
}
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

37
main.go
View File

@@ -1,8 +1,8 @@
// +build !windows //go:build !windows
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@@ -15,6 +15,7 @@ import (
"strconv" "strconv"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -32,25 +33,28 @@ const (
) )
func printUsage() { func printUsage() {
fmt.Printf("usage:\n") fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
} }
func warning() { func warning() {
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { switch runtime.GOOS {
case "linux", "freebsd", "openbsd":
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return
}
default:
return return
} }
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐") fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │") fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │") fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │") fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
} }
func main() { func main() {
@@ -165,7 +169,6 @@ func main() {
return os.NewFile(uintptr(fd), ""), nil return os.NewFile(uintptr(fd), ""), nil
}() }()
if err != nil { if err != nil {
logger.Errorf("UAPI listen error: %v", err) logger.Errorf("UAPI listen error: %v", err)
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
@@ -219,7 +222,7 @@ func main() {
return return
} }
device := device.NewDevice(tun, logger) device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
logger.Verbosef("Device started") logger.Verbosef("Device started")

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@@ -11,6 +11,7 @@ import (
"os/signal" "os/signal"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
@@ -47,7 +48,7 @@ func main() {
os.Exit(ExitSetupFailed) os.Exit(ExitSetupFailed)
} }
device := device.NewDevice(tun, logger) device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
err = device.Up() err = device.Up()
if err != nil { if err != nil {
logger.Errorf("Failed to bring up device: %v", err) logger.Errorf("Failed to bring up device: %v", err)

View File

@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"sync" "sync"
"time" "time"
) )
@@ -30,8 +30,7 @@ type Ratelimiter struct {
timeNow func() time.Time timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry table map[netip.Addr]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
func (rate *Ratelimiter) Close() { func (rate *Ratelimiter) Close() {
@@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() {
} }
rate.stopReset = make(chan struct{}) rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.table = make(map[netip.Addr]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
stopReset := rate.stopReset // store in case Init is called again. stopReset := rate.stopReset // store in case Init is called again.
@@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock() rate.mu.Lock()
defer rate.mu.Unlock() defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 { for key, entry := range rate.table {
entry.mu.Lock() entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.table, key)
} }
entry.mu.Unlock() entry.mu.Unlock()
} }
for key, entry := range rate.tableIPv6 { return len(rate.table) == 0
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mu.Unlock()
}
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
} }
func (rate *Ratelimiter) Allow(ip net.IP) bool { func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
var entry *RatelimiterEntry var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
var keyIPv6 [net.IPv6len]byte
// lookup entry // lookup entry
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.mu.RLock() rate.mu.RLock()
entry = rate.table[ip]
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
entry = rate.tableIPv4[keyIPv4]
} else {
copy(keyIPv6[:], IPv6)
entry = rate.tableIPv6[keyIPv6]
}
rate.mu.RUnlock() rate.mu.RUnlock()
// make new entry if not found // make new entry if not found
if entry == nil { if entry == nil {
entry = new(RatelimiterEntry) entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost entry.tokens = maxTokens - packetCost
entry.lastTime = rate.timeNow() entry.lastTime = rate.timeNow()
rate.mu.Lock() rate.mu.Lock()
if IPv4 != nil { rate.table[ip] = entry
rate.tableIPv4[keyIPv4] = entry if len(rate.table) == 1 {
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { rate.stopReset <- struct{}{}
rate.stopReset <- struct{}{}
}
} else {
rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
} }
rate.mu.Unlock() rate.mu.Unlock()
return true return true
} }
// add tokens to entry // add tokens to entry
entry.mu.Lock() entry.mu.Lock()
now := rate.timeNow() now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds() entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
@@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
} }
// subtract cost of packet // subtract cost of packet
if entry.tokens > packetCost { if entry.tokens > packetCost {
entry.tokens -= packetCost entry.tokens -= packetCost
entry.mu.Unlock() entry.mu.Unlock()

View File

@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
) )
@@ -71,21 +71,21 @@ func TestRatelimiter(t *testing.T) {
text: "packet following 2 packet burst", text: "packet following 2 packet burst",
}) })
ips := []net.IP{ ips := []netip.Addr{
net.ParseIP("127.0.0.1"), netip.MustParseAddr("127.0.0.1"),
net.ParseIP("192.168.1.1"), netip.MustParseAddr("192.168.1.1"),
net.ParseIP("172.167.2.3"), netip.MustParseAddr("172.167.2.3"),
net.ParseIP("97.231.252.215"), netip.MustParseAddr("97.231.252.215"),
net.ParseIP("248.97.91.167"), netip.MustParseAddr("248.97.91.167"),
net.ParseIP("188.208.233.47"), netip.MustParseAddr("188.208.233.47"),
net.ParseIP("104.2.183.179"), netip.MustParseAddr("104.2.183.179"),
net.ParseIP("72.129.46.120"), netip.MustParseAddr("72.129.46.120"),
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
} }
now := time.Now() now := time.Now()

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. // Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
@@ -34,7 +34,7 @@ func (f *Filter) Reset() {
// ValidateCounter checks if the counter should be accepted. // ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected. // Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool { func (f *Filter) ValidateCounter(counter, limit uint64) bool {
if counter >= limit { if counter >= limit {
return false return false
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package replay package replay

View File

@@ -1,24 +0,0 @@
// +build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
type fdSet struct {
unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.Bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
}

View File

@@ -1,8 +1,8 @@
// +build !windows //go:build !windows && !js
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
// Package rwcancel implements cancelable read/write operations on // Package rwcancel implements cancelable read/write operations on
@@ -17,13 +17,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func max(a, b int) int {
if a > b {
return a
}
return b
}
type RWCancel struct { type RWCancel struct {
fd int fd int
closingReader *os.File closingReader *os.File
@@ -50,13 +43,12 @@ func RetryAfterError(err error) bool {
} }
func (rw *RWCancel) ReadyRead() bool { func (rw *RWCancel) ReadyRead() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{}
fdset.set(rw.fd) pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}}
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@@ -64,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) { if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) ReadyWrite() bool { func (rw *RWCancel) ReadyWrite() bool {
closeFd := int(rw.closingReader.Fd()) closeFd := int32(rw.closingReader.Fd())
fdset := fdSet{} pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}}
fdset.set(rw.fd)
fdset.set(closeFd)
var err error var err error
for { for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) _, err = unix.Poll(pollFds, -1)
if err == nil || !RetryAfterError(err) { if err == nil || !RetryAfterError(err) {
break break
} }
@@ -85,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool {
if err != nil { if err != nil {
return false return false
} }
if fdset.check(closeFd) {
if pollFds[1].Revents != 0 {
return false return false
} }
return fdset.check(rw.fd) return pollFds[0].Revents != 0
} }
func (rw *RWCancel) Read(p []byte) (n int, err error) { func (rw *RWCancel) Read(p []byte) (n int, err error) {
@@ -98,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyRead() { if !rw.ReadyRead() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }
@@ -110,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) {
return n, err return n, err
} }
if !rw.ReadyWrite() { if !rw.ReadyWrite() {
return 0, errors.New("fd closed") return 0, os.ErrClosed
} }
} }
} }

View File

@@ -1,8 +1,9 @@
//go:build windows || js
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
package rwcancel package rwcancel
type RWCancel struct { type RWCancel struct{}
}
func (*RWCancel) Cancel() {} func (*RWCancel) Cancel() {}

View File

@@ -1,15 +0,0 @@
// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
_, err := unix.Select(nfd, r, w, e, timeout)
return err
}

View File

@@ -1,13 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) {
_, err = unix.Select(nfd, r, w, e, timeout)
return
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n
@@ -11,9 +11,11 @@ import (
"time" "time"
) )
const TimestampSize = 12 const (
const base = uint64(0x400000000000000a) TimestampSize = 12
const whitenerMask = uint32(0x1000000 - 1) base = uint64(0x400000000000000a)
whitenerMask = uint32(0x1000000 - 1)
)
type Timestamp [TimestampSize]byte type Timestamp [TimestampSize]byte

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n

View File

@@ -1,9 +1,9 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package device package tun
import ( import (
"reflect" "reflect"
@@ -18,15 +18,15 @@ func checkAlignment(t *testing.T, name string, offset uintptr) {
} }
} }
// TestPeerAlignment checks that atomically-accessed fields are // TestRateJugglerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package. // aligned to 64-bit boundaries, as required by the atomic package.
// //
// Unfortunately, violating this rule on 32-bit platforms results in a // Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime. // hard segfault at runtime.
func TestPeerAlignment(t *testing.T) { func TestRateJugglerAlignment(t *testing.T) {
var p Peer var r rateJuggler
typ := reflect.TypeOf(&p).Elem() typ := reflect.TypeOf(&r).Elem()
t.Logf("Peer type size: %d, with fields:", typ.Size()) t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
@@ -38,20 +38,21 @@ func TestPeerAlignment(t *testing.T) {
) )
} }
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
} }
// TestDeviceAlignment checks that atomically-accessed fields are // TestNativeTunAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package. // aligned to 64-bit boundaries, as required by the atomic package.
// //
// Unfortunately, violating this rule on 32-bit platforms results in a // Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime. // hard segfault at runtime.
func TestDeviceAlignment(t *testing.T) { func TestNativeTunAlignment(t *testing.T) {
var d Device var tun NativeTun
typ := reflect.TypeOf(&d).Elem() typ := reflect.TypeOf(&tun).Elem()
t.Logf("Device type size: %d, with fields:", typ.Size()) t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
@@ -61,5 +62,6 @@ func TestDeviceAlignment(t *testing.T) {
field.Type.Align(), field.Type.Align(),
) )
} }
checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
} }

View File

@@ -1,8 +1,9 @@
//go:build ignore
// +build ignore // +build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
@@ -10,26 +11,27 @@ package main
import ( import (
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
) )
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.28")},
[]net.IP{net.ParseIP("8.8.8.8")}, []netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420) 1420)
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
dev := device.NewDevice(tun, &device.Logger{log.Default(), log.Default(), log.Default()}) dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0 allowed_ip=0.0.0.0/0
endpoint=127.0.0.1:58120
`) `)
err = dev.Up() err = dev.Up()
if err != nil { if err != nil {
@@ -41,7 +43,7 @@ allowed_ip=0.0.0.0/0
DialContext: tnet.DialContext, DialContext: tnet.DialContext,
}, },
} }
resp, err := client.Get("https://www.zx2c4.com/ip") resp, err := client.Get("http://192.168.4.29/")
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }

View File

@@ -1,37 +1,41 @@
//go:build ignore
// +build ignore // +build ignore
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package main package main
import ( import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"io" "io"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
) )
func main() { func main() {
tun, tnet, err := netstack.CreateNetTUN( tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")}, []netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")}, []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420, 1420,
) )
if err != nil { if err != nil {
log.Panic(err) log.Panic(err)
} }
dev := device.NewDevice(tun, &device.Logger{log.Default(), log.Default(), log.Default()}) dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b listen_port=58120
endpoint=163.172.161.0:12912 public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
allowed_ip=0.0.0.0/0 allowed_ip=192.168.4.28/32
persistent_keepalive_interval=25 persistent_keepalive_interval=25
`) `)
dev.Up() dev.Up()
listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80})
if err != nil { if err != nil {

View File

@@ -0,0 +1,76 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package main
import (
"bytes"
"log"
"math/rand"
"net/netip"
"time"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)
}
dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
endpoint=163.172.161.0:12912
allowed_ip=0.0.0.0/0
`)
err = dev.Up()
if err != nil {
log.Panic(err)
}
socket, err := tnet.Dial("ping4", "zx2c4.com")
if err != nil {
log.Panic(err)
}
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
}
icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
socket.SetReadDeadline(time.Now().Add(time.Second * 10))
start := time.Now()
_, err = socket.Write(icmpBytes)
if err != nil {
log.Panic(err)
}
n, err := socket.Read(icmpBytes[:])
if err != nil {
log.Panic(err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
log.Panic(err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
log.Panicf("invalid reply type: %v", replyPacket)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
log.Panicf("invalid ping reply: %v", replyPing)
}
log.Printf("Ping latency: %v", time.Since(start))
}

View File

@@ -1,9 +0,0 @@
module golang.zx2c4.com/wireguard/tun/netstack
go 1.15
require (
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
golang.zx2c4.com/wireguard v0.0.20201118
gvisor.dev/gvisor v0.0.0-20210109011639-2fb7a49fea98
)

View File

@@ -1,391 +0,0 @@
bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8=
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4=
cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI=
github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0=
github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA=
github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0=
github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc=
github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/Microsoft/go-winio v0.4.15-0.20200908182639-5b44b70ab3ab/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw=
github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg=
github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ=
github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0=
github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE=
github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3/go.mod h1:MA5e5Lr8slmEg9bt0VpxxWqJlO4iwu3FBdHUzV7wQVg=
github.com/cilium/ebpf v0.2.0/go.mod h1:To2CFviqOWL/M0gIMsvSMlqe7em/l1ALkX1PyjrX2Qs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE=
github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE=
github.com/containerd/containerd v1.3.9/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY=
github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0=
github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g=
github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y=
github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg=
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/coreos/go-systemd/v22 v22.1.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.3.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas=
github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0=
github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg=
github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc=
github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I=
github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY=
github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U=
github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk=
github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7 h1:s330+6z/Ko3J0o6rvOcwXe5nzs7UT9tLKHoOXYn6uE0=
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.20201118 h1:QL8y2C7uO8T6z1GY+UX/hSeWiYEBurQkXjOTRFtCvXU=
golang.zx2c4.com/wireguard v0.0.20201118/go.mod h1:Dz+cq5bnrai9EpgYj4GDof/+qaGzbRWbeaAOs1bUYa0=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gvisor.dev/gvisor v0.0.0-20210109011639-2fb7a49fea98 h1:qDiV0V69AVoFfU6AiE1UgpLUorGJrIxSM/P4tgkF8oc=
gvisor.dev/gvisor v0.0.0-20210109011639-2fb7a49fea98/go.mod h1:5DEMKRjYDiM24fvDUWPjBpABm9ROMcv/kEcox3fHtm0=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs=
k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ=
k8s.io/client-go v0.16.13/go.mod h1:UKvVT4cajC2iN7DCjLgT0KVY/cbY6DGdUCyRiIfws5M=
k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI=
sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o=

View File

@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package netstack package netstack
import ( import (
"bytes"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
@@ -13,7 +14,9 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os" "os"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -21,104 +24,69 @@ import (
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
) )
type netTun struct { type netTun struct {
ep *channel.Endpoint
stack *stack.Stack stack *stack.Stack
dispatcher stack.NetworkDispatcher
events chan tun.Event events chan tun.Event
incomingPacket chan buffer.VectorisedView incomingPacket chan *bufferv2.View
mtu int mtu int
dnsServers []net.IP dnsServers []netip.Addr
hasV4, hasV6 bool hasV4, hasV6 bool
} }
type endpoint netTun
type Net netTun type Net netTun
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
mtu, err := (*netTun)(e).MTU()
if err != nil {
panic(err)
}
return uint32(mtu)
}
func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (*endpoint) MaxHeaderLength() uint16 {
return 0
}
func (*endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (*endpoint) Wait() {}
func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.incomingPacket <- buffer.NewVectorisedView(pkt.Size(), pkt.Views())
return nil
}
func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
func (*endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
}
func CreateNetTUN(localAddresses []net.IP, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{ opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
HandleLocal: true, HandleLocal: true,
} }
dev := &netTun{ dev := &netTun{
ep: channel.New(1024, uint32(mtu), ""),
stack: stack.New(opts), stack: stack.New(opts),
events: make(chan tun.Event, 10), events: make(chan tun.Event, 10),
incomingPacket: make(chan buffer.VectorisedView), incomingPacket: make(chan *bufferv2.View),
dnsServers: dnsServers, dnsServers: dnsServers,
mtu: mtu, mtu: mtu,
} }
tcpipErr := dev.stack.CreateNIC(1, (*endpoint)(dev)) dev.ep.AddNotify(dev)
tcpipErr := dev.stack.CreateNIC(1, dev.ep)
if tcpipErr != nil { if tcpipErr != nil {
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
} }
for _, ip := range localAddresses { for _, ip := range localAddresses {
if ip4 := ip.To4(); ip4 != nil { var protoNumber tcpip.NetworkProtocolNumber
tcpipErr = dev.stack.AddAddress(1, ipv4.ProtocolNumber, tcpip.Address(ip4)) if ip.Is4() {
if tcpipErr != nil { protoNumber = ipv4.ProtocolNumber
return nil, nil, fmt.Errorf("AddAddress(%v): %v", ip4, tcpipErr) } else if ip.Is6() {
} protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true dev.hasV4 = true
} else { } else if ip.Is6() {
tcpipErr = dev.stack.AddAddress(1, ipv6.ProtocolNumber, tcpip.Address(ip))
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddAddress(%v): %v", ip4, tcpipErr)
}
dev.hasV6 = true dev.hasV6 = true
} }
} }
@@ -141,7 +109,7 @@ func (tun *netTun) File() *os.File {
return nil return nil
} }
func (tun *netTun) Events() chan tun.Event { func (tun *netTun) Events() <-chan tun.Event {
return tun.events return tun.events
} }
@@ -150,6 +118,7 @@ func (tun *netTun) Read(buf []byte, offset int) (int, error) {
if !ok { if !ok {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
return view.Read(buf[offset:]) return view.Read(buf[offset:])
} }
@@ -159,17 +128,29 @@ func (tun *netTun) Write(buf []byte, offset int) (int, error) {
return 0, nil return 0, nil
} }
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(len(packet), []buffer.View{buffer.NewViewFromBytes(packet)})}) pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)})
switch packet[0] >> 4 { switch packet[0] >> 4 {
case 4: case 4:
tun.dispatcher.DeliverNetworkPacket("", "", ipv4.ProtocolNumber, pkb) tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
case 6: case 6:
tun.dispatcher.DeliverNetworkPacket("", "", ipv6.ProtocolNumber, pkb) tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
} }
return len(buf), nil return len(buf), nil
} }
func (tun *netTun) WriteNotify() {
pkt := tun.ep.Read()
if pkt.IsNil() {
return
}
view := pkt.ToView()
pkt.DecRef()
tun.incomingPacket <- view
}
func (tun *netTun) Flush() error { func (tun *netTun) Flush() error {
return nil return nil
} }
@@ -180,9 +161,13 @@ func (tun *netTun) Close() error {
if tun.events != nil { if tun.events != nil {
close(tun.events) close(tun.events)
} }
tun.ep.Close()
if tun.incomingPacket != nil { if tun.incomingPacket != nil {
close(tun.incomingPacket) close(tun.incomingPacket)
} }
return nil return nil
} }
@@ -190,62 +175,290 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil return tun.mtu, nil
} }
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
if ip4 := ip.To4(); ip4 != nil { var protoNumber tcpip.NetworkProtocolNumber
return tcpip.FullAddress{ if endpoint.Addr().Is4() {
NIC: 1, protoNumber = ipv4.ProtocolNumber
Addr: tcpip.Address(ip4),
Port: uint16(port),
}, ipv4.ProtocolNumber
} else { } else {
return tcpip.FullAddress{ protoNumber = ipv6.ProtocolNumber
NIC: 1,
Addr: tcpip.Address(ip),
Port: uint16(port),
}, ipv6.ProtocolNumber
} }
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
} }
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) ip, _ := netip.AddrFromSlice(addr.IP)
return gonet.DialContextTCP(ctx, net.stack, fa, pn) return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialTCP(net.stack, fa, pn)
} }
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.DialTCPAddrPort(netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) ip, _ := netip.AddrFromSlice(addr.IP)
return gonet.DialTCP(net.stack, fa, pn) return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
}
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
fa, pn := convertToFullAddr(addr)
return gonet.ListenTCP(net.stack, fa, pn)
} }
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil { if addr == nil {
panic("todo: deal with auto addr semantics for nil addr") return net.ListenTCPAddrPort(netip.AddrPort{})
} }
fa, pn := convertToFullAddr(addr.IP, addr.Port) ip, _ := netip.AddrFromSlice(addr.IP)
return gonet.ListenTCP(net.stack, fa, pn) return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
} }
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber var pn tcpip.NetworkProtocolNumber
if laddr != nil { if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr.IP, laddr.Port) addr, pn = convertToFullAddr(laddr)
lfa = &addr lfa = &addr
} }
if raddr != nil { if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr.IP, raddr.Port) addr, pn = convertToFullAddr(raddr)
rfa = &addr rfa = &addr
} }
return gonet.DialUDP(net.stack, lfa, rfa, pn) return gonet.DialUDP(net.stack, lfa, rfa, pn)
} }
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
ip, _ := netip.AddrFromSlice(laddr.IP)
la = netip.AddrPortFrom(ip, uint16(laddr.Port))
}
if raddr != nil {
ip, _ := netip.AddrFromSlice(raddr.IP)
ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
return net.DialUDP(laddr, nil)
}
type PingConn struct {
laddr PingAddr
raddr PingAddr
wq waiter.Queue
ep tcpip.Endpoint
deadline *time.Timer
}
type PingAddr struct{ addr netip.Addr }
func (ia PingAddr) String() string {
return ia.addr.String()
}
func (ia PingAddr) Network() string {
if ia.addr.Is4() {
return "ping4"
} else if ia.addr.Is6() {
return "ping6"
}
return "ping"
}
func (ia PingAddr) Addr() netip.Addr {
return ia.addr
}
func PingAddrFromAddr(addr netip.Addr) *PingAddr {
return &PingAddr{addr}
}
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
if !laddr.IsValid() && !raddr.IsValid() {
return nil, errors.New("ping dial: invalid address")
}
v6 := laddr.Is6() || raddr.Is6()
bind := laddr.IsValid()
if !bind {
if v6 {
laddr = netip.IPv6Unspecified()
} else {
laddr = netip.IPv4Unspecified()
}
}
tn := icmp.ProtocolNumber4
pn := ipv4.ProtocolNumber
if v6 {
tn = icmp.ProtocolNumber6
pn = ipv6.ProtocolNumber
}
pc := &PingConn{
laddr: PingAddr{laddr},
deadline: time.NewTimer(time.Hour << 10),
}
pc.deadline.Stop()
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
if tcpipErr != nil {
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
}
pc.ep = ep
if bind {
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
}
}
if raddr.IsValid() {
pc.raddr = PingAddr{raddr}
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
}
}
return pc, nil
}
func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
return net.DialPingAddr(laddr, netip.Addr{})
}
func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
var la, ra netip.Addr
if laddr != nil {
la = laddr.addr
}
if raddr != nil {
ra = raddr.addr
}
return net.DialPingAddr(la, ra)
}
func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
var la netip.Addr
if laddr != nil {
la = laddr.addr
}
return net.ListenPingAddr(la)
}
func (pc *PingConn) LocalAddr() net.Addr {
return pc.laddr
}
func (pc *PingConn) RemoteAddr() net.Addr {
return pc.raddr
}
func (pc *PingConn) Close() error {
pc.deadline.Reset(0)
pc.ep.Close()
return nil
}
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
return errors.New("not implemented")
}
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
var na netip.Addr
switch v := addr.(type) {
case *PingAddr:
na = v.addr
case *net.IPAddr:
na, _ = netip.AddrFromSlice(v.IP)
default:
return 0, fmt.Errorf("ping write: wrong net.Addr type")
}
if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
return 0, fmt.Errorf("ping write: mismatched protocols")
}
buf := bytes.NewReader(p)
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
// won't block, no deadlines
n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
To: &rfa,
})
if tcpipErr != nil {
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
}
return int(n64), nil
}
func (pc *PingConn) Write(p []byte) (n int, err error) {
return pc.WriteTo(p, &pc.raddr)
}
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
pc.wq.EventRegister(&e)
defer pc.wq.EventUnregister(&e)
select {
case <-pc.deadline.C:
return 0, nil, os.ErrDeadlineExceeded
case <-notifyCh:
}
w := tcpip.SliceWriter(p)
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
NeedRemoteAddr: true,
})
if tcpipErr != nil {
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
}
remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr))
return res.Count, &PingAddr{remoteAddr}, nil
}
func (pc *PingConn) Read(p []byte) (n int, err error) {
n, _, err = pc.ReadFrom(p)
return
}
func (pc *PingConn) SetDeadline(t time.Time) error {
// pc.SetWriteDeadline is unimplemented
return pc.SetReadDeadline(t)
}
func (pc *PingConn) SetReadDeadline(t time.Time) error {
pc.deadline.Reset(time.Until(t))
return nil
}
var ( var (
errNoSuchHost = errors.New("no such host") errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral") errLameReferral = errors.New("lame referral")
@@ -421,7 +634,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil return p, h, nil
} }
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q) id, udpReq, tcpReq, err := newRequest(q)
if err != nil { if err != nil {
@@ -435,16 +648,19 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn var c net.Conn
var err error var err error
if useUDP { if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53}) c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else { } else {
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53}) c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
} }
if err != nil { if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err return dnsmessage.Parser{}, dnsmessage.Header{}, err
} }
if d, ok := ctx.Deadline(); ok && !d.IsZero() { if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d) err := c.SetDeadline(d)
if err != nil {
return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
} }
var p dnsmessage.Parser var p dnsmessage.Parser
var h dnsmessage.Header var h dnsmessage.Header
@@ -588,8 +804,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx zlen = zidx
} }
} }
if ip := net.ParseIP(host[:zlen]); ip != nil { if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
return []string{host[:zlen]}, nil return []string{ip.String()}, nil
} }
if !isDomainName(host) { if !isDomainName(host) {
@@ -600,7 +816,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string server string
error error
} }
var addrsV4, addrsV6 []net.IP var addrsV4, addrsV6 []netip.Addr
lanes := 0 lanes := 0
if tnet.hasV4 { if tnet.hasV4 {
lanes++ lanes++
@@ -655,7 +871,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
break loop break loop
} }
addrsV4 = append(addrsV4, net.IP(a.A[:])) addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource() aaaa, err := result.p.AAAAResource()
@@ -667,7 +883,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
break loop break loop
} }
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:])) addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default: default:
if err := result.p.SkipAnswer(); err != nil { if err := result.p.SkipAnswer(); err != nil {
@@ -683,7 +899,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
} }
} }
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
var addrs []net.IP var addrs []netip.Addr
if tnet.hasV6 { if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...) addrs = append(addrsV6, addrsV4...)
} else { } else {
@@ -720,44 +936,48 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er
return now.Add(timeout), nil return now.Add(timeout), nil
} }
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if ctx == nil { if ctx == nil {
panic("nil context") panic("nil context")
} }
var acceptV4, acceptV6, useUDP bool var acceptV4, acceptV6 bool
if len(network) == 3 { matches := protoSplitter.FindStringSubmatch(network)
if matches == nil {
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
} else if len(matches[2]) == 0 {
acceptV4 = true acceptV4 = true
acceptV6 = true acceptV6 = true
} else if len(network) == 4 { } else {
acceptV4 = network[3] == '4' acceptV4 = matches[2][0] == '4'
acceptV6 = network[3] == '6' acceptV6 = !acceptV4
} }
if !acceptV4 && !acceptV6 { var host string
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} var port int
} if matches[1] == "ping" {
if network[:3] == "udp" { host = address
useUDP = true } else {
} else if network[:3] != "tcp" { var sport string
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} var err error
} host, sport, err = net.SplitHostPort(address)
host, sport, err := net.SplitHostPort(address) if err != nil {
if err != nil { return nil, &net.OpError{Op: "dial", Err: err}
return nil, &net.OpError{Op: "dial", Err: err} }
} port, err = strconv.Atoi(sport)
port, err := strconv.Atoi(sport) if err != nil || port < 0 || port > 65535 {
if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort}
return nil, &net.OpError{Op: "dial", Err: errNumericPort} }
} }
allAddr, err := tnet.LookupContextHost(ctx, host) allAddr, err := tnet.LookupContextHost(ctx, host)
if err != nil { if err != nil {
return nil, &net.OpError{Op: "dial", Err: err} return nil, &net.OpError{Op: "dial", Err: err}
} }
var addrs []net.IP var addrs []netip.AddrPort
for _, addr := range allAddr { for _, addr := range allAddr {
if strings.IndexByte(addr, ':') != -1 && acceptV6 { ip, err := netip.ParseAddr(addr)
addrs = append(addrs, net.ParseIP(addr)) if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
} else if strings.IndexByte(addr, '.') != -1 && acceptV4 { addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
addrs = append(addrs, net.ParseIP(addr))
} }
} }
if len(addrs) == 0 && len(allAddr) != 0 { if len(addrs) == 0 && len(allAddr) != 0 {
@@ -795,10 +1015,13 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
} }
var c net.Conn var c net.Conn
if useUDP { switch matches[1] {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port}) case "tcp":
} else { c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port}) case "udp":
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
case "ping":
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
} }
if err == nil { if err == nil {
return c, nil return c, nil

View File

@@ -1,8 +1,8 @@
// +build !windows //go:build darwin || freebsd
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
@@ -24,6 +24,6 @@ type Device interface {
Flush() error // flush all previous writes to the device Flush() error // flush all previous writes to the device
MTU() (int, error) // returns the MTU of the device MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name Name() (string, error) // fetches and returns the current name
Events() chan Event // returns a constant channel of events related to the device Events() <-chan Event // returns a constant channel of events related to the device
Close() error // stops the device and closes the event channel Close() error // stops the device and closes the event channel
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
@@ -8,9 +8,9 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@@ -27,6 +27,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
@@ -106,8 +107,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
} }
} }
fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -116,6 +116,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
copy(ctlInfo.Name[:], []byte(utunControlName)) copy(ctlInfo.Name[:], []byte(utunControlName))
err = unix.IoctlCtlInfo(fd, ctlInfo) err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil { if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
} }
@@ -126,11 +127,13 @@ func CreateTUN(name string, mtu int) (Device, error) {
err = unix.Connect(fd, sc) err = unix.Connect(fd, sc)
if err != nil { if err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
err = syscall.SetNonblock(fd, true) err = unix.SetNonblock(fd, true)
if err != nil { if err != nil {
unix.Close(fd)
return nil, err return nil, err
} }
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
@@ -138,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "utun" { if err == nil && name == "utun" {
fname := os.Getenv("WG_TUN_NAME_FILE") fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" { if fname != "" {
ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
} }
} }
@@ -170,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err return nil, err
} }
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
return nil, err return nil, err
@@ -210,7 +213,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
@@ -229,7 +232,6 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// reserve space for header // reserve space for header
buff = buff[offset-4:] buff = buff[offset-4:]
@@ -257,14 +259,16 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
if tun.routeSocket != -1 { err1 = tun.tunFile.Close()
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err2 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
} else if tun.events != nil { err2 = unix.Close(tun.routeSocket)
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@@ -272,12 +276,11 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket( fd, err := socketCloexec(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
return err return err
} }
@@ -296,12 +299,11 @@ func (tun *NativeTun) setMTU(n int) error {
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket( fd, err := socketCloexec(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -315,3 +317,15 @@ func (tun *NativeTun) MTU() (int, error) {
return int(ifr.MTU), nil return int(ifr.MTU), nil
} }
func socketCloexec(family, sotype, proto int) (fd int, err error) {
// See go/src/net/sys_cloexec.go for background.
syscall.ForkLock.RLock()
defer syscall.ForkLock.RUnlock()
fd, err = unix.Socket(family, sotype, proto)
if err == nil {
unix.CloseOnExec(fd)
}
return
}

View File

@@ -1,66 +1,57 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const ( const (
_TUNSIFHEAD = 0x80047460 _TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e _TUNSIFMODE = 0x8004745e
_TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f _TUNSIFPID = 0x2000745f
_SIOCGIFINFO_IN6 = 0xc048696c
_SIOCSIFINFO_IN6 = 0xc048696d
_ND6_IFF_AUTO_LINKLOCAL = 0x20
_ND6_IFF_NO_DAD = 0x100
) )
// TODO: move into x/sys/unix // Iface requests with just the name
const ( type ifreqName struct {
SIOCGIFINFO_IN6 = 0xc048696c Name [unix.IFNAMSIZ]byte
SIOCSIFINFO_IN6 = 0xc048696d _ [16]byte
ND6_IFF_AUTO_LINKLOCAL = 0x20 }
ND6_IFF_NO_DAD = 0x100
)
// Iface status string max len // Iface requests with a pointer
const _IFSTATMAX = 800 type ifreqPtr struct {
const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
// structure for iface requests with a pointer
type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
Pad0 [16 - SIZEOF_UINTPTR]byte _ [16 - unsafe.Sizeof(uintptr(0))]byte
} }
// Structure for iface mtu get/set ioctls // Iface requests with MTU
type ifreq_mtu struct { type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
MTU uint32 MTU uint32
Pad0 [12]byte _ [12]byte
} }
// Structure for interface status request ioctl // ND6 flag manipulation
type ifstat struct { type nd6Req struct {
IfsName [unix.IFNAMSIZ]byte
Ascii [_IFSTATMAX]byte
}
// Structures for nd6 flag manipulation
type in6_ndireq struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Linkmtu uint32 Linkmtu uint32
Maxmtu uint32 Maxmtu uint32
@@ -82,6 +73,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -97,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry: retry:
n, err := unix.Read(tun.routeSocket, data) n, err := unix.Read(tun.routeSocket, data)
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { if errors.Is(err, syscall.EINTR) {
goto retry goto retry
} }
tun.errors <- err tun.errors <- err
@@ -141,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
} }
func tunName(fd uintptr) (string, error) { func tunName(fd uintptr) (string, error) {
//Terrible hack to make up for freebsd not having a TUNGIFNAME var ifreq ifreqName
_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
//First, make sure the tun pid matches this proc's pid if err != 0 {
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(_TUNSIFPID),
uintptr(0),
)
if errno != 0 {
return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
}
// Open iface control socket
confd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return "", err return "", err
} }
return unix.ByteSliceToString(ifreq.Name[:]), nil
defer unix.Close(confd)
procPid := os.Getpid()
//Try to find interface with matching PID
for i := 1; ; i++ {
iface, _ := net.InterfaceByIndex(i)
if err != nil || iface == nil {
break
}
// Structs for getting data in and out of SIOCGIFSTATUS ioctl
var ifstatus ifstat
copy(ifstatus.IfsName[:], iface.Name)
// Make the syscall to get the status string
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCGIFSTATUS),
uintptr(unsafe.Pointer(&ifstatus)),
)
if errno != 0 {
continue
}
nullStr := ifstatus.Ascii[:]
i := bytes.IndexByte(nullStr, 0)
if i < 1 {
continue
}
statStr := string(nullStr[:i])
var pidNum int = 0
// Finally get the owning PID
// Format string taken from sys/net/if_tun.c
_, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
if err != nil {
continue
}
if pidNum == procPid {
return iface.Name, nil
}
}
return "", nil
} }
// Destroy a named system interface // Destroy a named system interface
func tunDestroy(name string) error { func tunDestroy(name string) error {
// Open control socket. fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
@@ -233,14 +151,9 @@ func tunDestroy(name string) error {
var ifr [32]byte var ifr [32]byte
copy(ifr[:], name) copy(ifr[:], name)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
} }
return nil return nil
@@ -257,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, fmt.Errorf("interface %s already exists", name) return nil, fmt.Errorf("interface %s already exists", name)
} }
tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0) tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -276,103 +189,94 @@ func CreateTUN(name string, mtu int) (Device, error) {
ifheadmode := 1 ifheadmode := 1
var errno syscall.Errno var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) { tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
unix.SYS_IOCTL,
fd,
uintptr(_TUNSIFHEAD),
uintptr(unsafe.Pointer(&ifheadmode)),
)
}) })
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno) return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
} }
// Open control sockets // Get out of PTP mode.
confd, err := unix.Socket( ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
unix.AF_INET, tun.operateOnFd(func(fd uintptr) {
unix.SOCK_DGRAM, _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
0, })
)
if err != nil { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, err return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
} }
defer unix.Close(confd)
confd6, err := unix.Socket( // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
unix.AF_INET6, // also because there are serious races with attaching and detaching LLv6 addresses
unix.SOCK_DGRAM, // in relation to interface lifetime within the FreeBSD kernel.
0, confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
)
if err != nil { if err != nil {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, err return nil, err
} }
defer unix.Close(confd6) defer unix.Close(confd6)
var ndireq nd6Req
// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
var ndireq in6_ndireq
copy(ndireq.Name[:], assignedName) copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCGIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
} }
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCSIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
} }
// Rename the interface if name != "" {
var newnp [unix.IFNAMSIZ]byte confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
copy(newnp[:], name) if err != nil {
var ifr ifreq_ptr tunFile.Close()
copy(ifr.Name[:], assignedName) tunDestroy(assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) return nil, err
_, _, errno = unix.Syscall( }
unix.SYS_IOCTL, defer unix.Close(confd)
uintptr(confd), var newnp [unix.IFNAMSIZ]byte
uintptr(unix.SIOCSIFNAME), copy(newnp[:], name)
uintptr(unsafe.Pointer(&ifr)), var ifr ifreqPtr
) copy(ifr.Name[:], assignedName)
if errno != 0 { ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
tunFile.Close() _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
tunDestroy(assignedName) if errno != 0 {
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno) tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
}
} }
return CreateTUNFromFile(tunFile, mtu) return CreateTUNFromFile(tunFile, mtu)
} }
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
} }
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
})
if errno != 0 {
tun.tunFile.Close()
return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
}
name, err := tun.Name() name, err := tun.Name()
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
@@ -391,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err return nil, err
} }
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
return nil, err return nil, err
@@ -425,7 +329,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
@@ -443,27 +347,26 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
if offset < 4 {
// reserve space for header return 0, io.ErrShortBuffer
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
} }
buf = buf[offset-4:]
// write if len(buf) < 5 {
return 0, io.ErrShortBuffer
return tun.tunFile.Write(buff) }
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return 0, unix.EAFNOSUPPORT
}
return tun.tunFile.Write(buf)
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -472,16 +375,18 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err3 error var err1, err2, err3 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
err2 := tunDestroy(tun.name) err1 = tun.tunFile.Close()
if tun.routeSocket != -1 { err2 = tunDestroy(tun.name)
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err3 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
tun.routeSocket = -1 err3 = unix.Close(tun.routeSocket)
} else if tun.events != nil { tun.routeSocket = -1
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@@ -492,70 +397,34 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *NativeTun) setMTU(n int) error {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n) ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name) return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
} }
return nil return nil
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name) return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
} }
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
@@ -9,7 +9,7 @@ package tun
*/ */
import ( import (
"bytes" "errors"
"fmt" "fmt"
"os" "os"
"sync" "sync"
@@ -39,6 +39,8 @@ type NativeTun struct {
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface nameCache string // name of interface
nameErr error nameErr error
@@ -53,6 +55,11 @@ func (tun *NativeTun) routineHackListener() {
/* This is needed for the detection to work across network namespaces /* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch. * If you are reading this and know a better method, please get in touch.
*/ */
last := 0
const (
up = 1
down = 2
)
for { for {
sysconn, err := tun.tunFile.SyscallConn() sysconn, err := tun.tunFile.SyscallConn()
if err != nil { if err != nil {
@@ -66,13 +73,19 @@ func (tun *NativeTun) routineHackListener() {
} }
switch err { switch err {
case unix.EINVAL: case unix.EINVAL:
// If the tunnel is up, it reports that write() is if last != up {
// allowed but we provided invalid data. // If the tunnel is up, it reports that write() is
tun.events <- EventUp // allowed but we provided invalid data.
tun.events <- EventUp
last = up
}
case unix.EIO: case unix.EIO:
// If the tunnel is down, it reports that no I/O if last != down {
// is possible, without checking our provided data. // If the tunnel is down, it reports that no I/O
tun.events <- EventDown // is possible, without checking our provided data.
tun.events <- EventDown
last = down
}
default: default:
return return
} }
@@ -86,7 +99,7 @@ func (tun *NativeTun) routineHackListener() {
} }
func createNetlinkSocket() (int, error) { func createNetlinkSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil { if err != nil {
return -1, err return -1, err
} }
@@ -181,7 +194,7 @@ func (tun *NativeTun) routineNetlinkListener() {
func getIFIndex(name string) (int32, error) { func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
@@ -215,10 +228,9 @@ func (tun *NativeTun) setMTU(n int) error {
// open datagram socket // open datagram socket
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return err return err
} }
@@ -252,10 +264,9 @@ func (tun *NativeTun) MTU() (int, error) {
// open datagram socket // open datagram socket
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -309,39 +320,33 @@ func (tun *NativeTun) nameSlow() (string, error) {
if errno != 0 { if errno != 0 {
return "", fmt.Errorf("failed to get name of TUN device: %w", errno) return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
} }
name := ifr[:] return unix.ByteSliceToString(ifr[:]), nil
if i := bytes.IndexByte(name, 0); i != -1 {
name = name[:i]
}
return string(name), nil
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
if tun.nopi { if tun.nopi {
buff = buff[offset:] buf = buf[offset:]
} else { } else {
// reserve space for header // reserve space for header
buf = buf[offset-4:]
buff = buff[offset-4:]
// add packet information header // add packet information header
buf[0] = 0x00
buff[0] = 0x00 buf[1] = 0x00
buff[1] = 0x00 if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
if buff[4]>>4 == ipv6.Version { buf[3] = 0xdd
buff[2] = 0x86
buff[3] = 0xdd
} else { } else {
buff[2] = 0x08 buf[2] = 0x08
buff[3] = 0x00 buf[3] = 0x00
} }
} }
// write n, err := tun.tunFile.Write(buf)
if errors.Is(err, syscall.EBADFD) {
return tun.tunFile.Write(buff) err = os.ErrClosed
}
return n, err
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -349,40 +354,45 @@ func (tun *NativeTun) Flush() error {
return nil return nil
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buf []byte, offset int) (n int, err error) {
select { select {
case err := <-tun.errors: case err = <-tun.errors:
return 0, err
default: default:
if tun.nopi { if tun.nopi {
return tun.tunFile.Read(buff[offset:]) n, err = tun.tunFile.Read(buf[offset:])
} else { } else {
buff := buff[offset-4:] buff := buf[offset-4:]
n, err := tun.tunFile.Read(buff[:]) n, err = tun.tunFile.Read(buff[:])
if n < 4 { if errors.Is(err, syscall.EBADFD) {
return 0, err err = os.ErrClosed
}
if n < 4 {
n = 0
} else {
n -= 4
} }
return n - 4, err
} }
} }
return
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err1 error var err1, err2 error
if tun.statusListenersShutdown != nil { tun.closeOnce.Do(func() {
close(tun.statusListenersShutdown) if tun.statusListenersShutdown != nil {
if tun.netlinkCancel != nil { close(tun.statusListenersShutdown)
err1 = tun.netlinkCancel.Cancel() if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
} }
} else if tun.events != nil { err2 = tun.tunFile.Close()
close(tun.events) })
}
err2 := tun.tunFile.Close()
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@@ -390,7 +400,7 @@ func (tun *NativeTun) Close() error {
} }
func CreateTUN(name string, mtu int) (Device, error) { func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
@@ -402,6 +412,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name) nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ { if len(nameBytes) >= unix.IFNAMSIZ {
unix.Close(nfd)
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG) return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
} }
copy(ifr[:], nameBytes) copy(ifr[:], nameBytes)
@@ -414,17 +425,19 @@ func CreateTUN(name string, mtu int) (Device, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
unix.Close(nfd)
return nil, errno return nil, errno
} }
err = unix.SetNonblock(nfd, true) err = unix.SetNonblock(nfd, true)
if err != nil {
unix.Close(nfd)
return nil, err
}
// Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
fd := os.NewFile(uintptr(nfd), cloneDevicePath) fd := os.NewFile(uintptr(nfd), cloneDevicePath)
if err != nil {
return nil, err
}
return CreateTUNFromFile(fd, mtu) return CreateTUNFromFile(fd, mtu)
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
@@ -8,9 +8,9 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"sync"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -33,6 +33,7 @@ type NativeTun struct {
events chan Event events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
closeOnce sync.Once
} }
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -113,10 +114,10 @@ func CreateTUN(name string, mtu int) (Device, error) {
var err error var err error
if ifIndex != -1 { if ifIndex != -1 {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
} else { } else {
for ifIndex = 0; ifIndex < 256; ifIndex++ { for ifIndex = 0; ifIndex < 256; ifIndex++ {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
if err == nil || !errors.Is(err, syscall.EBUSY) { if err == nil || !errors.Is(err, syscall.EBUSY) {
break break
} }
@@ -132,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "tun" { if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE") fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" { if fname != "" {
ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
} }
} }
@@ -164,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err return nil, err
} }
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
return nil, err return nil, err
@@ -199,7 +200,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
@@ -218,7 +219,6 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// reserve space for header // reserve space for header
buff = buff[offset-4:] buff = buff[offset-4:]
@@ -246,15 +246,17 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
var err2 error var err1, err2 error
err1 := tun.tunFile.Close() tun.closeOnce.Do(func() {
if tun.routeSocket != -1 { err1 = tun.tunFile.Close()
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) if tun.routeSocket != -1 {
err2 = unix.Close(tun.routeSocket) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
tun.routeSocket = -1 err2 = unix.Close(tun.routeSocket)
} else if tun.events != nil { tun.routeSocket = -1
close(tun.events) } else if tun.events != nil {
} close(tun.events)
}
})
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@@ -268,10 +270,9 @@ func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return err return err
} }
@@ -303,10 +304,9 @@ func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0, 0,
) )
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tun package tun
@@ -8,15 +8,15 @@ package tun
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
_ "unsafe" _ "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun" "golang.zx2c4.com/wintun"
) )
const ( const (
@@ -26,26 +26,30 @@ const (
) )
type rateJuggler struct { type rateJuggler struct {
current uint64 current atomic.Uint64
nextByteCount uint64 nextByteCount atomic.Uint64
nextStartTime int64 nextStartTime atomic.Int64
changing int32 changing atomic.Bool
} }
type NativeTun struct { type NativeTun struct {
wt *wintun.Adapter wt *wintun.Adapter
name string
handle windows.Handle handle windows.Handle
close bool
events chan Event
errors chan error
forcedMTU int
rate rateJuggler rate rateJuggler
session wintun.Session session wintun.Session
readWait windows.Handle readWait windows.Handle
events chan Event
running sync.WaitGroup
closeOnce sync.Once
close atomic.Bool
forcedMTU int
} }
var WintunPool, _ = wintun.MakePool("WireGuard") var (
var WintunStaticRequestedGUID *windows.GUID WintunTunnelType = "WireGuard"
WintunStaticRequestedGUID *windows.GUID
)
//go:linkname procyield runtime.procyield //go:linkname procyield runtime.procyield
func procyield(cycles uint32) func procyield(cycles uint32)
@@ -53,38 +57,19 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime //go:linkname nanotime runtime.nanotime
func nanotime() int64 func nanotime() int64
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun // CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused. // interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int) (Device, error) { func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
} }
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.Delete(true)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err) return nil, fmt.Errorf("Error creating interface: %w", err)
} }
if rebootRequired {
log.Println("Windows indicated a reboot is required.")
}
forcedMTU := 1420 forcedMTU := 1420
if mtu > 0 { if mtu > 0 {
@@ -93,15 +78,15 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{ tun := &NativeTun{
wt: wt, wt: wt,
name: ifname,
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: forcedMTU, forcedMTU: forcedMTU,
} }
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil { if err != nil {
tun.wt.Delete(false) tun.wt.Close()
close(tun.events) close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err) return nil, fmt.Errorf("Error starting session: %w", err)
} }
@@ -110,25 +95,29 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
return tun.wt.Name() return tun.name, nil
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
return nil return nil
} }
func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Events() <-chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
tun.close = true
tun.session.End()
var err error var err error
if tun.wt != nil { tun.closeOnce.Do(func() {
_, err = tun.wt.Delete(false) tun.close.Store(true)
} windows.SetEvent(tun.readWait)
close(tun.events) tun.running.Wait()
tun.session.End()
if tun.wt != nil {
tun.wt.Close()
}
close(tun.events)
})
return err return err
} }
@@ -138,22 +127,26 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) { func (tun *NativeTun) ForceMTU(mtu int) {
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu tun.forcedMTU = mtu
if update {
tun.events <- EventMTUUpdate
}
} }
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry: retry:
select { if tun.close.Load() {
case err := <-tun.errors: return 0, os.ErrClosed
return 0, err
default:
} }
start := nanotime() start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for { for {
if tun.close { if tun.close.Load() {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
packet, err := tun.session.ReceivePacket() packet, err := tun.session.ReceivePacket()
@@ -185,7 +178,9 @@ func (tun *NativeTun) Flush() error {
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
if tun.close { tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
@@ -209,6 +204,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
// LUID returns Windows interface instance ID. // LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 { func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
if tun.close.Load() {
return 0
}
return tun.wt.LUID() return tun.wt.LUID()
} }
@@ -219,15 +219,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
func (rate *rateJuggler) update(packetLen uint64) { func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime() now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen) total := rate.nextByteCount.Add(packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity { if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { if !rate.changing.CompareAndSwap(false, true) {
return return
} }
atomic.StoreInt64(&rate.nextStartTime, now) rate.nextStartTime.Store(now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
atomic.StoreUint64(&rate.nextByteCount, 0) rate.nextByteCount.Store(0)
atomic.StoreInt32(&rate.changing, 0) rate.changing.Store(false)
} }
} }

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/ */
package tuntest package tuntest
@@ -8,13 +8,13 @@ package tuntest
import ( import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net/netip"
"os" "os"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
func Ping(dst, src net.IP) []byte { func Ping(dst, src netip.Addr) []byte {
localPort := uint16(1337) localPort := uint16(1337)
seq := uint16(0) seq := uint16(0)
@@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 {
return ^uint16(v) return ^uint16(v)
} }
func genICMPv4(payload []byte, dst, src net.IP) []byte { func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
const ( const (
icmpv4ProtocolNumber = 1 icmpv4ProtocolNumber = 1
icmpv4Echo = 8 icmpv4Echo = 8
@@ -70,8 +70,8 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl ip[8] = ttl
ip[9] = icmpv4ProtocolNumber ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4()) copy(ip[12:], src.AsSlice())
copy(ip[16:], dst.To4()) copy(ip[16:], dst.AsSlice())
chksum = ^checksum(ip[:], 0) chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
@@ -79,7 +79,6 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte {
return pkt return pkt
} }
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct { type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close Outbound chan []byte // outbound packets, blocks forever on TUN close
@@ -114,7 +113,7 @@ func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) { func (t *chTun) Read(data []byte, offset int) (int, error) {
select { select {
case <-t.c.closed: case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value? return 0, os.ErrClosed
case msg := <-t.c.Outbound: case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil return copy(data[offset:], msg), nil
} }
@@ -131,7 +130,7 @@ func (t *chTun) Write(data []byte, offset int) (int, error) {
copy(msg, data[offset:]) copy(msg, data[offset:])
select { select {
case <-t.c.closed: case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value? return 0, os.ErrClosed
case t.c.Inbound <- msg: case t.c.Inbound <- msg:
return len(data) - offset, nil return len(data) - offset, nil
} }
@@ -139,10 +138,10 @@ func (t *chTun) Write(data []byte, offset int) (int, error) {
const DefaultMTU = 1420 const DefaultMTU = 1420
func (t *chTun) Flush() error { return nil } func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events } func (t *chTun) Events() <-chan tun.Event { return t.c.events }
func (t *chTun) Close() error { func (t *chTun) Close() error {
t.Write(nil, -1) t.Write(nil, -1)
return nil return nil

View File

@@ -1,54 +0,0 @@
// +build !load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@@ -1,62 +0,0 @@
// +build load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
"golang.zx2c4.com/wireguard/tun/wintun/resource"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
onLoad func(d *lazyDLL)
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := resource.Load(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
if d.onLoad != nil {
d.onLoad(d)
}
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

View File

@@ -1,59 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
func newLazyDLL(name string, onLoad func(d *lazyDLL)) *lazyDLL {
return &lazyDLL{Name: name, onLoad: onLoad}
}
func (d *lazyDLL) NewProc(name string) *lazyProc {
return &lazyProc{dll: d, Name: name}
}
type lazyProc struct {
Name string
mu sync.Mutex
dll *lazyDLL
addr uintptr
}
func (p *lazyProc) Find() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.addr != 0 {
return nil
}
err := p.dll.Load()
if err != nil {
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
}
addr, err := p.nameToAddr()
if err != nil {
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
return nil
}
func (p *lazyProc) Addr() uintptr {
err := p.Find()
if err != nil {
panic(err)
}
return p.addr
}

Some files were not shown because too many files have changed in this diff Show More