63 Commits

Author SHA1 Message Date
Jason A. Donenfeld
da19db415a version: bump snapshot 2020-11-18 14:24:17 +01:00
Jason A. Donenfeld
52c834c446 mod: bump
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-18 14:24:00 +01:00
Haichao Liu
913f68ce38 device: add write queue mutex for peer
fix panic: send on closed channel when remove peer

Signed-off-by: Haichao Liu <liuhaichao@bytedance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-18 14:22:15 +01:00
Jason A. Donenfeld
60b3766b89 wintun: load from filesystem by default
We let people loading this from resources opt in via:

    go build -tags load_wintun_from_rsrc

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-11 18:51:44 +01:00
Jason A. Donenfeld
82128c47d9 global: switch to using %w instead of %v for Errorf
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 21:56:32 +01:00
Jason A. Donenfeld
c192b2eeec mod: update deps
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:22:18 +01:00
Simon Rozman
a3b231b31e wintun: ring management moved to wintun.dll
Signed-off-by: Simon Rozman <simon@rozman.si>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:20:49 +01:00
Simon Rozman
65e03a9182 wintun: load wintun.dll from RCDATA resource
Signed-off-by: Simon Rozman <simon@rozman.si>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-07 15:20:49 +01:00
Simon Rozman
3e08b8aee0 wintun: migrate to wintun.dll API
Rather than having every application using Wintun driver reinvent the
wheel, the Wintun device/adapter/interface management has been moved
from wireguard-go to wintun.dll deployed with Wintun itself.

Signed-off-by: Simon Rozman <simon@rozman.si>
2020-11-07 12:46:35 +01:00
Jason A. Donenfeld
5ca1218a5c device: format a few things
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-11-06 18:01:27 +01:00
Tobias Klauser
3b490f30aa tun: use SockaddrCtl from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.Connect with unix.SockaddrCtl instead.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
e6b7c4eef3 tun: use Ioctl{Get,Set}IfreqMTU from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.Ioctl{Get,Set}IfreqMTU to get and set an interface's MTU.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
8ae09213a7 tun: use IoctlCtlInfo from golang.org/x/sys/unix on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Switch to use
unix.IoctlCtlInfo to get the kernel control info.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
36dc8b6994 tun: use GetsockoptString in (*NativeTun).Name on macOS
Direct syscalls using unix.Syscall(unix.SYS_*, ...) are discouraged on
macOS and might not be supported in future versions. Instead, use the
existing unix.GetsockoptString wrapper to get the interface name.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Tobias Klauser
2057f19a61 go.mod: bump golang.org/x/sys to latest version
This adds the fixes for golang/go#41868 which are needed to build
wireguard without direct syscalls on macOS.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-27 16:20:09 +01:00
Brad Fitzpatrick
58a8f05f50 tun/wintun/registry: fix Go 1.15 race/checkptr failure
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
[Jason: ran go mod tidy.]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-21 18:26:10 +02:00
Frank Werner
0b54907a73 Makefile: Add test target
Signed-off-by: Frank Werner <mail@hb9fxq.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-20 12:38:18 +02:00
Riobard Zhan
2c143dce0f replay: minor API changes to more idiomatic Go
Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:46:00 +02:00
Riobard Zhan
22af3890f6 replay: clean up internals and better documentation
Signed-off-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:46:00 +02:00
Jason A. Donenfeld
c8fe925020 device: remove global for roaming escape hatch
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-10-14 10:45:31 +02:00
Jason A. Donenfeld
0cfa3314ee replay: divide by bits-per-byte
Bits / Bytes-per-Word misses the step of also dividing by Bits-per-Byte,
which we need in order for this to make sense.

Reported-by: Riobard Zhan <me@riobard.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-09-07 18:51:49 +02:00
Sina Siadat
bc3f505efa device: get free port when testing
Signed-off-by: Sina Siadat <siadat@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-31 16:18:53 +02:00
David Crawshaw
507f148e1c device: remove bindsocketshim.go
Both wireguard-windows and wireguard-android access Bind
directly for these methods now.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-14 23:18:53 -06:00
Brad Fitzpatrick
31b574ef99 device: remove some unnecessary unsafe
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-07-15 06:59:44 +10:00
Tobias Klauser
3c41141fb4 device: use RTMGRP_IPV4_ROUTE to specify multicast groups mask
Use the RTMGRP_IPV4_ROUTE const from x/sys/unix instead of using the
corresponding RTNLGRP_IPV4_ROUTE const to create the multicast groups
mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-07-13 17:58:10 -06:00
Dmytro Shynkevych
4369db522b device: wait for routines to stop before removing peers
Peers are currently removed after Device's goroutines are signaled to stop,
but without waiting for them to actually do so, which is racy.

For example, RoutineHandshake may be in Peer.SendKeepalive
when the corresponding peer is removed, which closes its nonce channel.
This causes a send on a closed channel, as observed in tailscale/tailscale#487.

This patch seems to be the correct synchronizing action:
Peer's goroutines are receivers and handle channel closure gracefully,
so Device's goroutines are the ones that should be fully stopped first.

Signed-Off-By: Dmytro Shynkevych <dmytro@tailscale.com>
2020-07-04 20:29:31 +10:00
David Crawshaw
b84f1d4db2 device: export Bind and remove socketfd shims for android
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:42:28 +10:00
David Crawshaw
dfb28757f7 ipc: add comment about socketDirectory linker override on android
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:41:19 +10:00
David Crawshaw
00bcd865e6 conn: add comments saying what uses these interfaces
Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-06-22 10:40:59 +10:00
Jason A. Donenfeld
f28a6d244b device: do not include sticky sockets on android
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:50:20 -06:00
Jason A. Donenfeld
c403da6a39 conn: unbreak boundif on android
Another thing never tested ever.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:48:28 -06:00
Jason A. Donenfeld
d6de6f3ce6 conn: remove useless comment
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:37:01 -06:00
Jason A. Donenfeld
59e556f24e conn: fix windows situation with boundif
This was evidently never tested before committing.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-06-07 01:26:25 -06:00
Jason A. Donenfeld
31faf4c159 replay: account for fqcodel reordering
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-19 17:46:35 -06:00
Jason A. Donenfeld
99eb7896be device: rework padding calculation and don't shadow paddedSize
Reported-by: Jayakumar S <jayakumar82.s@gmail.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-18 15:43:22 -06:00
Dmytro Shynkevych
f60b3919be tai64n: make the test deterministic
In the presence of preemption, the current test may fail transiently.
This uses static test data instead to ensure consistent behavior.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
2020-05-06 16:01:48 +10:00
Jason A. Donenfeld
da9d300cf8 main: now that we're upstreamed, relax Linux warning
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:20:47 -06:00
Jason A. Donenfeld
59c9929714 README: specify go 1.13
Due to the use of the new errors module, we now require at least 1.13
instead of 1.12.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:08:52 -06:00
Jason A. Donenfeld
db0aa39b76 global: update header comments and modules
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:08:26 -06:00
David Crawshaw
bc77de2aca ipc: deduplicate some unix-specific code
Cleans up and splits out UAPIOpen to its own file.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
[zx2c4: changed const to var for socketDirectory]
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 02:05:41 -06:00
David Crawshaw
c8596328e7 ipc: remove unnecessary error check
os.MkdirAll never returns an os.IsExist error.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 02:02:09 -06:00
Jason A. Donenfeld
28c4d04304 device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.

Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2020-05-02 01:56:48 -06:00
Simon Rozman
fdba6c183a wintun: make remaining HWID comparisons case insensitive
c85e4a410f introduced preliminary HWID
checking to speed up Wintun adapter enumeration. However, all HWID are
case insensitive by Windows convention.

Furthermore, a device might have multiple HWIDs. When DevInfo's
DeviceRegistryProperty(SPDRP_HARDWAREID) method returns []string, all
strings returned should be checked against given hardware ID.

This issue was discovered when researching Wintun and wireguard-go on
Windows 10 ARM64. The Wintun adapter was created using devcon.exe
utility with "wintun" hardware ID, causing wireguard-go fail to
enumerate the adapter properly.

Signed-off-by: Simon Rozman <simon@rozman.si>
2020-05-02 01:50:47 -06:00
Simon Rozman
250b9795f3 setupapi: extend struct size constant definitions for arm(64)
Signed-off-by: Simon Rozman <simon@rozman.si>
2020-05-02 01:50:47 -06:00
Avery Pennarun
d60857e1a7 device: add debug logs describing handshake rejection
Useful in testing when bad network stacks repeat or
batch large numbers of packets.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:50:47 -06:00
Brad Fitzpatrick
2fb0a712f0 tun: return a better error message if /dev/net/tun doesn't exist
It was just returning "no such file or directory" (the String of the
syscall.Errno returned by CreateTUN).

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:50:47 -06:00
David Anderson
f2c6faad44 device: return generic error from Ipc{Get,Set}Operation.
This makes uapi.go's public API conform to Go style in terms
of error types.

Signed-off-by: David Anderson <danderson@tailscale.com>
2020-05-02 01:49:47 -06:00
Avery Pennarun
c76b818466 tun: NetlinkListener: don't send EventDown before sending EventUp
This works around a startup race condition when competing with
HackListener, which is trying to do the same job. If HackListener
detects that the tundev is running while there is still an event in the
netlink queue that says it isn't running, then the device receives a
string of events like
	EventUp (HackListener)
	EventDown (NetlinkListener)
	EventUp (NetlinkListener)
Unfortunately, after the first EventDown, the device stops itself,
thinking incorrectly that the administrator has downed its tundev.

The device is ignoring the initial EventDown anyway, so just don't emit
it.

Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
de374bfb44 device: give handshake state a type
And unexport handshake constants.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
1a1c3d0968 tuntest: split out testing package
This code is useful to other packages writing tests.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
Brad Fitzpatrick
85a45a9651 tun: fix data race on name field
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:46:42 -06:00
Brad Fitzpatrick
abd287159e tun: remove unused isUp method
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-05-02 01:46:42 -06:00
David Crawshaw
203554620d conn: introduce new package that splits out the Bind and Endpoint types
The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-05-02 01:46:42 -06:00
Avery Pennarun
6aefb61355 wintun: split error message for create vs open namespace.
Signed-off-by: Avery Pennarun <apenwarr@tailscale.com>
2020-05-02 01:44:58 -06:00
David Anderson
3dce460c88 device: add test to ensure Peer fields are safe for atomic access on 32-bit
Adds a test that will fail consistently on 32-bit platforms if the
struct ever changes again to violate the rules. This is likely not
needed because unaligned access crashes reliably, but this will reliably
fail even if tests accidentally pass due to lucky alignment.

Signed-Off-By: David Anderson <danderson@tailscale.com>
2020-05-02 01:44:58 -06:00
David Crawshaw
224bc9e60c rwcancel: no-op builds for windows and darwin
This lets us include the package on those platforms in a
followup commit where we split out a conn package from device.
It also lets us run `go test ./...` when developing on macOS.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-30 18:41:39 +11:00
David Crawshaw
9cd8909df2 ratelimiter: use a fake clock in tests and style cleanups
The existing test would occasionally flake out with:

	--- FAIL: TestRatelimiter (0.12s)
	    ratelimiter_test.go:99: Test failed for 127.0.0.1 , on: 7 ( not having refilled enough ) expected: false got: true
	FAIL
	FAIL    golang.zx2c4.com/wireguard/ratelimiter  0.171s

The fake clock also means the tests run much faster, so
testing this package with -count=1000 now takes < 100ms.

While here, several style cleanups. The most significant one
is unembeding the sync.Mutex fields in the rate limiter objects.
Embedded as they were, the lock methods were accessible
outside the ratelimiter package. As they aren't needed externally,
keep them internal to make them easier to reason about.

Passes `go test -race -count=10000 ./ratelimiter`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-30 18:38:36 +11:00
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
120 changed files with 3406 additions and 5203 deletions

View File

@@ -22,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
test:
go test -v ./...
clean:
rm -f wireguard-go
.PHONY: all clean install generate-version-and-build
.PHONY: all clean test install generate-version-and-build

View File

@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
### Linux
This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
### macOS
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
## Building
This requires an installation of [go](https://golang.org) ≥ 1.12.
This requires an installation of [go](https://golang.org) ≥ 1.13.
```
$ git clone https://git.zx2c4.com/wireguard-go
@@ -56,7 +56,7 @@ $ make
## License
Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in

34
conn/boundif_android.go Normal file
View File

@@ -0,0 +1,34 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package conn
func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}
func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return -1, err
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return -1, err
}
return
}

View File

@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
package conn
import (
"encoding/binary"
"errors"
"unsafe"
"golang.org/x/sys/windows"
@@ -18,17 +17,13 @@ const (
sockoptIPV6_UNICAST_IF = 31
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
if device.net.bind == nil {
return errors.New("Bind is not yet initialized")
}
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole4 = blackhole
bind.blackhole4 = blackhole
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole6 = blackhole
bind.blackhole6 = blackhole
return nil
}

111
conn/conn.go Normal file
View File

@@ -0,0 +1,111 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
package conn
import (
"errors"
"net"
"strings"
)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
// depending on the platform-specific implementation.
type Bind interface {
// LastMark reports the last mark set for this Bind.
LastMark() uint32
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// ReceiveIPv6 reads an IPv6 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
// ReceiveIPv4 reads an IPv4 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
// Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error
// Close closes the Bind connection.
Close() error
}
// CreateBind creates a Bind bound to a port.
//
// The value actualPort reports the actual port number the Bind
// object gets bound to.
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
return createBind(port)
}
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
}
// PeekLookAtSocketFd is implemented by Bind objects that support having their
// file descriptor peeked at. Used by wireguard-android.
type PeekLookAtSocketFd interface {
PeekLookAtSocketFd4() (fd int, err error)
PeekLookAtSocketFd6() (fd int, err error)
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst : the remote address of a peer ("endpoint" in uapi terminology)
// src : the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}

View File

@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
package conn
import (
"net"
@@ -67,16 +67,12 @@ func (e *NativeEndpoint) SrcToString() string {
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
@@ -100,7 +96,7 @@ func extractErrno(err error) error {
return syscallErr.Err
}
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
func createBind(uport uint16) (Bind, uint16, error) {
var err error
var bind nativeBind
@@ -135,6 +131,8 @@ func (bind *nativeBind) Close() error {
return err2
}
func (bind *nativeBind) LastMark() uint32 { return 0 }
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT

View File

@@ -2,19 +2,10 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
package conn
import (
"errors"
@@ -25,7 +16,6 @@ import (
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
const (
@@ -33,8 +23,8 @@ const (
)
type IPv4Source struct {
src [4]byte
ifindex int32
Src [4]byte
Ifindex int32
}
type IPv6Source struct {
@@ -49,6 +39,10 @@ type NativeEndpoint struct {
isV6 bool
}
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
@@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
}
type nativeBind struct {
sock4 int
sock6 int
netlinkSock int
netlinkCancel *rwcancel.RWCancel
lastMark uint32
sock4 int
sock6 int
lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
@@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
return nil, errors.New("Invalid IP address")
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
func createBind(port uint16) (Bind, uint16, error) {
var err error
var bind nativeBind
var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil {
return nil, 0, err
}
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if successful
// Attempt ipv6 bind, update port if successful.
bind.sock6, newPort, err = create6(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
}
// attempt ipv4 bind, update port if successful
// Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
@@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
return &bind, port, nil
}
func (bind *nativeBind) LastMark() uint32 {
return bind.lastMark
}
func (bind *nativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
@@ -216,22 +178,18 @@ func closeUnblock(fd int) error {
}
func (bind *nativeBind) Close() error {
var err1, err2, err3 error
var err1, err2 error
if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
err3 = bind.netlinkCancel.Cancel()
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
return err2
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
@@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
end.src4().Src[0],
end.src4().Src[1],
end.src4().Src[2],
end.src4().Src[3],
)
} else {
return end.src6().src[:]
@@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().src,
Ifindex: end.src4().ifindex,
Spec_dst: end.src4().Src,
Ifindex: end.src4().Ifindex,
},
}
@@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex
end.src4().Src = cmsg.pktinfo.Spec_dst
end.src4().Ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
@@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil
}
func (bind *nativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct {
peer *Peer
endpoint *Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.RUnlock()
continue
}
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
peer.endpoint.(*NativeEndpoint).dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}

View File

@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
package conn
func (bind *nativeBind) SetMark(mark uint32) error {
return nil

View File

@@ -2,10 +2,10 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
package conn
import (
"runtime"

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,15 +1,19 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import "errors"
import (
"errors"
"golang.zx2c4.com/wireguard/conn"
)
type DummyDatagram struct {
msg []byte
endpoint Endpoint
endpoint conn.Endpoint
world bool // better type
}
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return len(datagram.msg), datagram.endpoint, nil
}
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
return nil
}
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
return nil
}

View File

@@ -1,44 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import "errors"
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv4.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv6.SyscallConn()
if err != nil {
return
}
err = sysconn.Control(func(f uintptr) {
fd = int(f)
})
if err != nil {
return
}
return
}

View File

@@ -1,187 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"net"
"strings"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port, device)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(ConnRoutineNumber)
device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -13,7 +13,7 @@ import (
const (
RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RejectAfterMessages = (1 << 64) - (1 << 13) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,15 +11,14 @@ import (
"sync/atomic"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
const (
DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
)
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
@@ -39,9 +38,10 @@ type Device struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
bind conn.Bind // bind interface
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
@@ -236,23 +236,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap {
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
for _, peer := range lockedPeers {
@@ -311,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 {
device.state.starting.Add(3)
device.state.stopping.Add(3)
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
}
device.state.starting.Add(2)
device.state.stopping.Add(2)
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
@@ -393,10 +383,10 @@ func (device *Device) Close() {
device.isUp.Set(false)
close(device.signals.stop)
device.state.stopping.Wait()
device.RemoveAllPeers()
device.state.stopping.Wait()
device.FlushPacketQueues()
device.rate.limiter.Close()
@@ -425,3 +415,133 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
}
device.peers.RUnlock()
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
netc.netlinkCancel.Cancel()
}
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) Bind() conn.Bind {
device.net.Lock()
defer device.net.Unlock()
return device.net.bind
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = conn.CreateBind(netc.port)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(2)
device.net.stopping.Add(2)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,28 +8,40 @@ package device
import (
"bufio"
"bytes"
"encoding/binary"
"io"
"fmt"
"net"
"os"
"strings"
"testing"
"time"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
func getFreePort(t *testing.T) string {
l, err := net.ListenPacket("udp", "localhost:0")
if err != nil {
t.Fatal(err)
}
defer l.Close()
return fmt.Sprintf("%d", l.LocalAddr().(*net.UDPAddr).Port)
}
func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
port1 := getFreePort(t)
port2 := getFreePort(t)
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
listen_port={{PORT1}}
replace_peers=true
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
tun1 := NewChannelTUN()
endpoint=127.0.0.1:{{PORT2}}`
cfg1 = strings.ReplaceAll(cfg1, "{{PORT1}}", port1)
cfg1 = strings.ReplaceAll(cfg1, "{{PORT2}}", port2)
tun1 := tuntest.NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
@@ -38,14 +50,17 @@ endpoint=127.0.0.1:53512`
}
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
listen_port=53512
listen_port={{PORT2}}
replace_peers=true
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
endpoint=127.0.0.1:{{PORT1}}`
cfg2 = strings.ReplaceAll(cfg2, "{{PORT1}}", port1)
cfg2 = strings.ReplaceAll(cfg2, "{{PORT2}}", port2)
tun2 := tuntest.NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
@@ -54,7 +69,7 @@ endpoint=127.0.0.1:53511`
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
@@ -67,7 +82,7 @@ endpoint=127.0.0.1:53511`
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
@@ -80,139 +95,6 @@ endpoint=127.0.0.1:53511`
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,14 +1,14 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/rand"
"encoding/binary"
"sync"
"unsafe"
)
type IndexTableEntry struct {
@@ -25,7 +25,8 @@ type IndexTable struct {
func randUint32() (uint32, error) {
var integer [4]byte
_, err := rand.Read(integer[:])
return *(*uint32)(unsafe.Pointer(&integer[0])), err
// Arbitrary endianness; both are intrinsified by the Go compiler.
return binary.LittleEndian.Uint32(integer[:]), err
}
func (table *IndexTable) Init() {

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
"sync/atomic"
"time"
"unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -24,7 +26,7 @@ type Keypair struct {
sendNonce uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.ReplayFilter
replayFilter replay.Filter
isInitiator bool
created time.Time
localIndex uint32
@@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair
}
func (kp *Keypairs) storeNext(next *Keypair) {
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
}
func (kp *Keypairs) loadNext() *Keypair {
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
}
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

16
device/mobilequirks.go Normal file
View File

@@ -0,0 +1,16 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package device
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
peer.disableRoaming = peer.endpoint != nil
}
device.peers.RUnlock()
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,29 +1,51 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"fmt"
"sync"
"time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
)
type handshakeState int
// TODO(crawshaw): add commentary describing each state and the transitions
const (
HandshakeZeroed = iota
HandshakeInitiationCreated
HandshakeInitiationConsumed
HandshakeResponseCreated
HandshakeResponseConsumed
handshakeZeroed = handshakeState(iota)
handshakeInitiationCreated
handshakeInitiationConsumed
handshakeResponseCreated
handshakeResponseConsumed
)
func (hs handshakeState) String() string {
switch hs {
case handshakeZeroed:
return "handshakeZeroed"
case handshakeInitiationCreated:
return "handshakeInitiationCreated"
case handshakeInitiationConsumed:
return "handshakeInitiationConsumed"
case handshakeResponseCreated:
return "handshakeResponseCreated"
case handshakeResponseConsumed:
return "handshakeResponseConsumed"
default:
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
}
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
@@ -95,7 +117,7 @@ type MessageCookieReply struct {
}
type Handshake struct {
state int
state handshakeState
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
@@ -135,7 +157,7 @@ func (h *Handshake) Clear() {
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
h.state = HandshakeZeroed
h.state = handshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
@@ -154,6 +176,7 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +185,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,59 +194,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}()
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
}
var key [chacha20poly1305.KeySize]byte
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:])
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
timestamp := tai64n.Now()
func() {
var key [chacha20poly1305.KeySize]byte
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
aead, _ = chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
// assign index
device.indexTable.Delete(handshake.localIndex)
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
handshake.state = handshakeInitiationCreated
return &msg, nil
}
@@ -250,16 +265,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil {
return nil
}
@@ -273,23 +288,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity
var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -299,11 +315,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// protect against replay & flood
var ok bool
ok = timestamp.After(handshake.lastTimestamp)
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
replay := !timestamp.After(handshake.lastTimestamp)
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
if !ok {
if replay {
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
return nil
}
if flood {
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
return nil
}
@@ -322,7 +342,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed
handshake.state = handshakeInitiationConsumed
handshake.mutex.Unlock()
@@ -337,7 +357,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed {
if handshake.state != handshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
@@ -393,7 +413,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Empty[:])
}()
handshake.state = HandshakeResponseCreated
handshake.state = handshakeResponseCreated
return &msg, nil
}
@@ -423,7 +443,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated {
if handshake.state != handshakeInitiationCreated {
return false
}
@@ -484,7 +504,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
handshake.state = handshakeResponseConsumed
handshake.mutex.Unlock()
@@ -509,7 +529,7 @@ func (peer *Peer) BeginSymmetricSession() error {
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
if handshake.state == handshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
@@ -517,7 +537,7 @@ func (peer *Peer) BeginSymmetricSession() error {
nil,
)
isInitiator = true
} else if handshake.state == HandshakeResponseCreated {
} else if handshake.state == handshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
@@ -526,7 +546,7 @@ func (peer *Peer) BeginSymmetricSession() error {
)
isInitiator = false
} else {
return errors.New("invalid state for keypair derivation")
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
}
// zero handshake
@@ -534,7 +554,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed
peer.handshake.state = handshakeZeroed
// create AEAD instances
@@ -547,7 +567,7 @@ func (peer *Peer) BeginSymmetricSession() error {
keypair.created = time.Now()
keypair.sendNonce = 0
keypair.replayFilter.Init()
keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
@@ -564,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
next := keypairs.next
next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -578,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
keypairs.next = keypair
keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -589,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
if keypairs.next != receivedKeypair {
if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next
keypairs.next = nil
keypairs.current = keypairs.loadNext()
keypairs.storeNext(nil)
return true
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return
}
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
key1 := peer1.keypairs.next
key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,6 +12,8 @@ import (
"sync"
"sync/atomic"
"time"
"golang.zx2c4.com/wireguard/conn"
)
const (
@@ -24,10 +26,15 @@ type Peer struct {
keypairs Keypairs
handshake Handshake
device *Device
endpoint Endpoint
endpoint conn.Endpoint
persistentKeepaliveInterval uint16
disableRoaming bool
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
// These fields are accessed with atomic operations, which must be
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
// allocated struct will be 64-bit aligned. So we place
// atomically-accessed fields up front, so that they can share in
// this alignment before smaller fields throw it off.
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@@ -51,6 +58,7 @@ type Peer struct {
}
queue struct {
sync.RWMutex
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
@@ -108,7 +116,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock()
@@ -116,13 +123,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
// conditionally add
// add
if !ssIsZero {
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
device.peers.keyMap[pk] = peer
// start peer
@@ -193,10 +196,11 @@ func (peer *Peer) Start() {
peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues
peer.queue.Lock()
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.queue.Unlock()
peer.timersInit()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
@@ -222,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
@@ -253,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages
keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}
@@ -282,17 +286,17 @@ func (peer *Peer) Stop() {
// close queues
peer.queue.Lock()
close(peer.queue.nonce)
close(peer.queue.outbound)
close(peer.queue.inbound)
peer.queue.Unlock()
peer.ZeroAndFlushAll()
}
var RoamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if RoamingDisabled {
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if peer.disableRoaming {
return
}
peer.Lock()

43
device/peer_test.go Normal file
View File

@@ -0,0 +1,43 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"reflect"
"testing"
"unsafe"
)
func checkAlignment(t *testing.T, name string, offset uintptr) {
t.Helper()
if offset%8 != 0 {
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
}
}
// TestPeerAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
// Unfortunately, violating this rule on 32-bit platforms results in a
// hard segfault at runtime.
func TestPeerAlignment(t *testing.T) {
var p Peer
typ := reflect.TypeOf(p)
t.Logf("Peer type size: %d, with fields:", typ.Size())
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
field.Name,
field.Offset,
field.Type.Size(),
field.Type.Align(),
)
}
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,12 +17,13 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
endpoint Endpoint
endpoint conn.Endpoint
buffer *[MaxMessageSize]byte
}
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
packet []byte
counter uint64
keypair *Keypair
endpoint Endpoint
endpoint conn.Endpoint
}
func (elem *QueueInboundElement) Drop() {
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
logDebug := device.log.Debug
defer func() {
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
var (
err error
size int
endpoint Endpoint
endpoint conn.Endpoint
)
for {
@@ -183,11 +184,13 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// add to decryption queues
peer.queue.RLock()
if peer.isRunning.Get() {
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
buffer = device.GetMessageBuffer()
}
}
peer.queue.RUnlock()
continue

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -107,6 +107,8 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() bool {
peer.queue.RLock()
defer peer.queue.RUnlock()
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
return false
}
@@ -310,6 +312,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
peer.queue.RLock()
if peer.isRunning.Get() {
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
peer.SendHandshakeInitiation(false)
@@ -317,6 +320,7 @@ func (device *Device) RoutineReadFromTUN() {
addToNonceQueue(peer.queue.nonce, elem, device)
elem = nil
}
peer.queue.RUnlock()
}
}
@@ -448,6 +452,21 @@ func (peer *Peer) RoutineNonce() {
}
}
func calculatePaddingSize(packetSize, mtu int) int {
lastUnit := packetSize
if mtu == 0 {
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
}
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
if paddedSize > mtu {
paddedSize = mtu
}
return paddedSize - lastUnit
}
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
@@ -514,13 +533,8 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
for i := len(elem.packet); i < paddedSize; i++ {
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
for i := 0; i < paddingSize; i++ {
elem.packet = append(elem.packet, 0)
}

12
device/sticky_default.go Normal file
View File

@@ -0,0 +1,12 @@
// +build !linux android
package device
import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}

217
device/sticky_linux.go Normal file
View File

@@ -0,0 +1,217 @@
// +build !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
package device
import (
"sync"
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
}
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
if err != nil {
unix.Close(netlinkSock)
return nil, err
}
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil {
peer.RUnlock()
continue
}
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
if nativeEP == nil {
peer.RUnlock()
continue
}
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
nativeEP.Dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
nativeEP.Src4().Src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
uint32(bind.LastMark()),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device

View File

@@ -1,12 +1,13 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
"errors"
"fmt"
"io"
"net"
@@ -15,6 +16,7 @@ import (
"sync/atomic"
"time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -30,7 +32,7 @@ func (s IPCError) ErrorCode() int64 {
return s.int64
}
func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
@@ -105,7 +107,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
return nil
}
func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
scanner := bufio.NewScanner(socket)
logError := device.log.Error
logDebug := device.log.Debug
@@ -138,7 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromHex(value)
err := sk.FromMaybeZeroHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
@@ -306,7 +308,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
err := func() error {
peer.Lock()
defer peer.Unlock()
endpoint, err := CreateEndpoint(value)
endpoint, err := conn.CreateEndpoint(value)
if err != nil {
return err
}
@@ -420,10 +422,20 @@ func (device *Device) IpcHandle(socket net.Conn) {
switch op {
case "set=1\n":
status = device.IpcSetOperation(buffered.Reader)
err = device.IpcSetOperation(buffered.Reader)
if err != nil && !errors.As(err, &status) {
// should never happen
device.log.Error.Println("Invalid UAPI error:", err)
status = &IPCError{1}
}
case "get=1\n":
status = device.IpcGetOperation(buffered.Writer)
err = device.IpcGetOperation(buffered.Writer)
if err != nil && !errors.As(err, &status) {
// should never happen
device.log.Error.Println("Invalid UAPI error:", err)
status = &IPCError{1}
}
default:
device.log.Error.Println("Invalid UAPI operation:", op)

View File

@@ -1,3 +1,3 @@
package device
const WireGuardGoVersion = "0.0.20200121"
const WireGuardGoVersion = "0.0.20201118"

9
go.mod
View File

@@ -1,10 +1,9 @@
module golang.zx2c4.com/wireguard
go 1.12
go 1.13
require (
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
golang.org/x/text v0.3.2
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7
)

19
go.sum
View File

@@ -1,14 +1,17 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7 h1:s330+6z/Ko3J0o6rvOcwXe5nzs7UT9tLKHoOXYn6uE0=
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -2,32 +2,20 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
"fmt"
"net"
"os"
"path"
"unsafe"
"golang.org/x/sys/unix"
)
var socketDirectory = "/var/run/wireguard"
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
)
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -84,10 +72,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
unixListener.SetUnlinkOnClose(true)
}
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
socketPath := sockPath(name)
// watch for deletion of socket
@@ -146,58 +131,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

View File

@@ -1,31 +1,18 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
"fmt"
"net"
"os"
"path"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
)
var socketDirectory = "/var/run/wireguard"
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
socketName = "%s.sock"
)
type UAPIListener struct {
listener net.Listener // unix socket listener
connNew chan net.Conn
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
// watch for deletion of socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
socketPath := sockPath(name)
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
return uapi, nil
}
func UAPIOpen(name string) (*os.File, error) {
// check if path exist
err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) {
return nil, err
}
// open UNIX socket
socketPath := path.Join(
socketDirectory,
fmt.Sprintf(socketName, name),
)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) {
// initial connection attempt
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener, nil
}
// check if socket already active
_, err = net.Dial("unix", socketPath)
if err == nil {
return nil, errors.New("unix socket in use")
}
// cleanup & attempt again
err = os.Remove(socketPath)
if err != nil {
return nil, err
}
return net.ListenUnix("unix", addr)
}()
unix.Umask(oldUmask)
if err != nil {
return nil, err
}
return listener.File()
}

65
ipc/uapi_unix.go Normal file
View File

@@ -0,0 +1,65 @@
// +build linux darwin freebsd openbsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc
import (
"errors"
"fmt"
"net"
"os"
"golang.org/x/sys/unix"
)
const (
IpcErrorIO = -int64(unix.EIO)
IpcErrorProtocol = -int64(unix.EPROTO)
IpcErrorInvalid = -int64(unix.EINVAL)
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
)
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
var socketDirectory = "/var/run/wireguard"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
}
func UAPIOpen(name string) (*os.File, error) {
if err := os.MkdirAll(socketDirectory, 0755); err != nil {
return nil, err
}
socketPath := sockPath(name)
addr, err := net.ResolveUnixAddr("unix", socketPath)
if err != nil {
return nil, err
}
oldUmask := unix.Umask(0077)
defer unix.Umask(oldUmask)
listener, err := net.ListenUnix("unix", addr)
if err == nil {
return listener.File()
}
// Test socket, if not in use cleanup and try again.
if _, err := net.Dial("unix", socketPath); err == nil {
return nil, errors.New("unix socket in use")
}
if err := os.Remove(socketPath); err != nil {
return nil, err
}
listener, err = net.ListenUnix("unix", addr)
if err != nil {
return nil, err
}
return listener.File()
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ipc

View File

@@ -3,7 +3,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe

View File

@@ -1,7 +1,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe

View File

@@ -3,7 +3,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package winpipe

24
main.go
View File

@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package main
@@ -41,18 +41,16 @@ func warning() {
return
}
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, " ")
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, ")
fmt.Fprintln(os.Stderr, " because the Linux kernel has built-in first │")
fmt.Fprintln(os.Stderr, " class support for WireGuard, which will be ")
fmt.Fprintln(os.Stderr, " faster, slicker, and better integrated. For ")
fmt.Fprintln(os.Stderr, " information on installing the kernel module, ")
fmt.Fprintln(os.Stderr, " please visit: <https://wireguard.com/install>. ")
fmt.Fprintln(os.Stderr, " ")
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
}
func main() {

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package main

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
@@ -20,21 +20,23 @@ const (
)
type RatelimiterEntry struct {
sync.Mutex
mu sync.Mutex
lastTime time.Time
tokens int64
}
type Ratelimiter struct {
sync.RWMutex
stopReset chan struct{}
mu sync.RWMutex
timeNow func() time.Time
stopReset chan struct{} // send to reset, close to stop
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
}
func (rate *Ratelimiter) Close() {
rate.Lock()
defer rate.Unlock()
rate.mu.Lock()
defer rate.mu.Unlock()
if rate.stopReset != nil {
close(rate.stopReset)
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
}
func (rate *Ratelimiter) Init() {
rate.Lock()
defer rate.Unlock()
rate.mu.Lock()
defer rate.mu.Unlock()
if rate.timeNow == nil {
rate.timeNow = time.Now
}
// stop any ongoing garbage collection routine
if rate.stopReset != nil {
close(rate.stopReset)
}
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
// start garbage collection routine
stopReset := rate.stopReset // store in case Init is called again.
// Start garbage collection routine.
go func() {
ticker := time.NewTicker(time.Second)
ticker.Stop()
for {
select {
case _, ok := <-rate.stopReset:
case _, ok := <-stopReset:
ticker.Stop()
if ok {
ticker = time.NewTicker(time.Second)
} else {
if !ok {
return
}
ticker = time.NewTicker(time.Second)
case <-ticker.C:
func() {
rate.Lock()
defer rate.Unlock()
for key, entry := range rate.tableIPv4 {
entry.Lock()
if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key)
}
entry.Unlock()
}
for key, entry := range rate.tableIPv6 {
entry.Lock()
if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.Unlock()
}
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
ticker.Stop()
}
}()
if rate.cleanup() {
ticker.Stop()
}
}
}
}()
}
func (rate *Ratelimiter) cleanup() (empty bool) {
rate.mu.Lock()
defer rate.mu.Unlock()
for key, entry := range rate.tableIPv4 {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key)
}
entry.mu.Unlock()
}
for key, entry := range rate.tableIPv6 {
entry.mu.Lock()
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key)
}
entry.mu.Unlock()
}
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
}
func (rate *Ratelimiter) Allow(ip net.IP) bool {
var entry *RatelimiterEntry
var keyIPv4 [net.IPv4len]byte
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
IPv4 := ip.To4()
IPv6 := ip.To16()
rate.RLock()
rate.mu.RLock()
if IPv4 != nil {
copy(keyIPv4[:], IPv4)
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
entry = rate.tableIPv6[keyIPv6]
}
rate.RUnlock()
rate.mu.RUnlock()
// make new entry if not found
if entry == nil {
entry = new(RatelimiterEntry)
entry.tokens = maxTokens - packetCost
entry.lastTime = time.Now()
rate.Lock()
entry.lastTime = rate.timeNow()
rate.mu.Lock()
if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.stopReset <- struct{}{}
}
}
rate.Unlock()
rate.mu.Unlock()
return true
}
// add tokens to entry
entry.Lock()
now := time.Now()
entry.mu.Lock()
now := rate.timeNow()
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
entry.lastTime = now
if entry.tokens > maxTokens {
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
if entry.tokens > packetCost {
entry.tokens -= packetCost
entry.Unlock()
entry.mu.Unlock()
return true
}
entry.Unlock()
entry.mu.Unlock()
return false
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package ratelimiter
@@ -11,22 +11,21 @@ import (
"time"
)
type RatelimiterResult struct {
type result struct {
allowed bool
text string
wait time.Duration
}
func TestRatelimiter(t *testing.T) {
var rate Ratelimiter
var expectedResults []result
var ratelimiter Ratelimiter
var expectedResults []RatelimiterResult
Nano := func(nano int64) time.Duration {
nano := func(nano int64) time.Duration {
return time.Nanosecond * time.Duration(nano)
}
Add := func(res RatelimiterResult) {
add := func(res result) {
expectedResults = append(
expectedResults,
res,
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
}
for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{
add(result{
allowed: true,
text: "initial burst",
})
}
Add(RatelimiterResult{
add(result{
allowed: false,
text: "after burst",
})
Add(RatelimiterResult{
add(result{
allowed: true,
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
text: "filling tokens for single packet",
})
Add(RatelimiterResult{
add(result{
allowed: false,
text: "not having refilled enough",
})
Add(RatelimiterResult{
add(result{
allowed: true,
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
text: "filling tokens for two packet burst",
})
Add(RatelimiterResult{
add(result{
allowed: true,
text: "second packet in 2 packet burst",
})
Add(RatelimiterResult{
add(result{
allowed: false,
text: "packet following 2 packet burst",
})
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
}
ratelimiter.Init()
now := time.Now()
rate.timeNow = func() time.Time {
return now
}
defer func() {
// Lock to avoid data race with cleanup goroutine from Init.
rate.mu.Lock()
defer rate.mu.Unlock()
rate.timeNow = time.Now
}()
timeSleep := func(d time.Duration) {
now = now.Add(d + 1)
rate.cleanup()
}
rate.Init()
defer rate.Close()
for i, res := range expectedResults {
time.Sleep(res.wait)
timeSleep(res.wait)
for _, ip := range ips {
allowed := ratelimiter.Allow(ip)
allowed := rate.Allow(ip)
if allowed != res.allowed {
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
}
}
}

View File

@@ -1,83 +1,62 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
package replay
/* Implementation of RFC6479
* https://tools.ietf.org/html/rfc6479
*
* The implementation is not safe for concurrent use!
*/
type block uint64
const (
// See: https://golang.org/src/math/big/arith.go
_Wordm = ^uintptr(0)
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
_WordSize = 1 << _WordLogSize
blockBitLog = 6 // 1<<6 == 64 bits
blockBits = 1 << blockBitLog // must be power of 2
ringBlocks = 1 << 7 // must be power of 2
windowSize = (ringBlocks - 1) * blockBits
blockMask = ringBlocks - 1
bitMask = blockBits - 1
)
const (
CounterRedundantBitsLog = _WordLogSize + 3
CounterRedundantBits = _WordSize * 8
CounterBitsTotal = 2048
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
)
const (
BacktrackWords = CounterBitsTotal / _WordSize
)
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
// A Filter rejects replayed messages by checking if message counter value is
// within a sliding window of previously received messages.
// The zero value for Filter is an empty filter ready to use.
// Filters are unsafe for concurrent use.
type Filter struct {
last uint64
ring [ringBlocks]block
}
type ReplayFilter struct {
counter uint64
backtrack [BacktrackWords]uintptr
// Reset resets the filter to empty state.
func (f *Filter) Reset() {
f.last = 0
f.ring[0] = 0
}
func (filter *ReplayFilter) Init() {
filter.counter = 0
filter.backtrack[0] = 0
}
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
// ValidateCounter checks if the counter should be accepted.
// Overlimit counters (>= limit) are always rejected.
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool {
if counter >= limit {
return false
}
indexWord := counter >> CounterRedundantBitsLog
if counter > filter.counter {
// move window forward
current := filter.counter >> CounterRedundantBitsLog
diff := minUint64(indexWord-current, BacktrackWords)
for i := uint64(1); i <= diff; i++ {
filter.backtrack[(current+i)%BacktrackWords] = 0
indexBlock := counter >> blockBitLog
if counter > f.last { // move window forward
current := f.last >> blockBitLog
diff := indexBlock - current
if diff > ringBlocks {
diff = ringBlocks // cap diff to clear the whole ring
}
filter.counter = counter
} else if filter.counter-counter > CounterWindowSize {
// behind current window
for i := current + 1; i <= current+diff; i++ {
f.ring[i&blockMask] = 0
}
f.last = counter
} else if f.last-counter > windowSize { // behind current window
return false
}
indexWord %= BacktrackWords
indexBit := counter & uint64(CounterRedundantBits-1)
// check and set bit
oldValue := filter.backtrack[indexWord]
newValue := oldValue | (1 << indexBit)
filter.backtrack[indexWord] = newValue
return oldValue != newValue
indexBlock &= blockMask
indexBit := counter & bitMask
old := f.ring[indexBlock]
new := old | 1<<indexBit
f.ring[indexBlock] = new
return old != new
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package replay
@@ -14,22 +14,22 @@ import (
*
*/
const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
const RejectAfterMessages = 1<<64 - 1<<13 - 1
func TestReplay(t *testing.T) {
var filter ReplayFilter
var filter Filter
T_LIM := CounterWindowSize + 1
const T_LIM = windowSize + 1
testNumber := 0
T := func(n uint64, v bool) {
T := func(n uint64, expected bool) {
testNumber++
if filter.ValidateCounter(n, RejectAfterMessages) != v {
t.Fatal("Test", testNumber, "failed", n, v)
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
t.Fatal("Test", testNumber, "failed", n, expected)
}
}
filter.Init()
filter.Reset()
T(0, true) /* 1 */
T(1, true) /* 2 */
@@ -67,53 +67,53 @@ func TestReplay(t *testing.T) {
T(0, false) /* 34 */
t.Log("Bulk test 1")
filter.Init()
filter.Reset()
testNumber = 0
for i := uint64(1); i <= CounterWindowSize; i++ {
for i := uint64(1); i <= windowSize; i++ {
T(i, true)
}
T(0, true)
T(0, false)
t.Log("Bulk test 2")
filter.Init()
filter.Reset()
testNumber = 0
for i := uint64(2); i <= CounterWindowSize+1; i++ {
for i := uint64(2); i <= windowSize+1; i++ {
T(i, true)
}
T(1, true)
T(0, false)
t.Log("Bulk test 3")
filter.Init()
filter.Reset()
testNumber = 0
for i := CounterWindowSize + 1; i > 0; i-- {
for i := uint64(windowSize + 1); i > 0; i-- {
T(i, true)
}
t.Log("Bulk test 4")
filter.Init()
filter.Reset()
testNumber = 0
for i := CounterWindowSize + 2; i > 1; i-- {
for i := uint64(windowSize + 2); i > 1; i-- {
T(i, true)
}
T(0, false)
t.Log("Bulk test 5")
filter.Init()
filter.Reset()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(CounterWindowSize+1, true)
T(windowSize+1, true)
T(0, false)
t.Log("Bulk test 6")
filter.Init()
filter.Reset()
testNumber = 0
for i := CounterWindowSize; i > 0; i-- {
for i := uint64(windowSize); i > 0; i-- {
T(i, true)
}
T(0, true)
T(CounterWindowSize+1, true)
T(windowSize+1, true)
}

View File

@@ -1,6 +1,8 @@
// +build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel

View File

@@ -1,8 +1,12 @@
// +build !windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
// Package rwcancel implements cancelable read/write operations on
// a file descriptor.
package rwcancel
import (

View File

@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT
package rwcancel
type RWCancel struct {
}
func (*RWCancel) Cancel() {}

View File

@@ -1,8 +1,8 @@
// +build !linux
// +build !linux,!windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
@@ -10,5 +10,6 @@ package rwcancel
import "golang.org/x/sys/unix"
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
return unix.Select(nfd, r, w, e, timeout)
_, err := unix.Select(nfd, r, w, e, timeout)
return err
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package rwcancel

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -17,16 +17,19 @@ const whitenerMask = uint32(0x1000000 - 1)
type Timestamp [TimestampSize]byte
func Now() Timestamp {
func stamp(t time.Time) Timestamp {
var tai64n Timestamp
now := time.Now()
secs := base + uint64(now.Unix())
nano := uint32(now.Nanosecond()) &^ whitenerMask
secs := base + uint64(t.Unix())
nano := uint32(t.Nanosecond()) &^ whitenerMask
binary.BigEndian.PutUint64(tai64n[:], secs)
binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n
}
func Now() Timestamp {
return stamp(time.Now())
}
func (t1 Timestamp) After(t2 Timestamp) bool {
return bytes.Compare(t1[:], t2[:]) > 0
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tai64n
@@ -10,21 +10,31 @@ import (
"time"
)
/* Testing the essential property of the timestamp
* as used by WireGuard.
*/
// Test that timestamps are monotonic as required by Wireguard and that
// nanosecond-level information is whitened to prevent side channel attacks.
func TestMonotonic(t *testing.T) {
old := Now()
for i := 0; i < 50; i++ {
next := Now()
if next.After(old) {
t.Error("Whitening insufficient")
}
time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
next = Now()
if !next.After(old) {
t.Error("Not monotonically increasing on whitened nano-second scale")
}
old = next
startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
// Whitening should reduce timestamp granularity
// to more than 10 but fewer than 20 milliseconds.
tests := []struct {
name string
t1, t2 time.Time
wantAfter bool
}{
{"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
{"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
{"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
{"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
{"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
got := ts2.After(ts1)
if got != tt.wantAfter {
t.Errorf("after = %v; want %v", got, tt.wantAfter)
}
})
}
}

View File

@@ -2,7 +2,7 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -20,19 +20,6 @@ import (
const utunControlName = "com.apple.net.utun_control"
// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
type sockaddrCtl struct {
scLen uint8
scFamily uint8
ssSysaddr uint16
scID uint32
scUnit uint32
scReserved [5]uint32
}
type NativeTun struct {
name string
tunFile *os.File
@@ -41,8 +28,6 @@ type NativeTun struct {
routeSocket int
}
var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
@@ -130,43 +115,21 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
var ctlInfo = &struct {
ctlID uint32
ctlName [96]byte
}{}
copy(ctlInfo.ctlName[:], []byte(utunControlName))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(_CTLIOCGINFO),
uintptr(unsafe.Pointer(ctlInfo)),
)
if errno != 0 {
return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
ctlInfo := &unix.CtlInfo{}
copy(ctlInfo.Name[:], []byte(utunControlName))
err = unix.IoctlCtlInfo(fd, ctlInfo)
if err != nil {
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
}
sc := sockaddrCtl{
scLen: uint8(sockaddrCtlSize),
scFamily: unix.AF_SYSTEM,
ssSysaddr: 2,
scID: ctlInfo.ctlID,
scUnit: uint32(ifIndex) + 1,
sc := &unix.SockaddrCtl{
ID: ctlInfo.Id,
Unit: uint32(ifIndex) + 1,
}
scPointer := unsafe.Pointer(&sc)
_, _, errno = unix.RawSyscall(
unix.SYS_CONNECT,
uintptr(fd),
uintptr(scPointer),
uintptr(sockaddrCtlSize),
)
if errno != 0 {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
err = unix.Connect(fd, sc)
if err != nil {
return nil, err
}
err = syscall.SetNonblock(fd, true)
@@ -230,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
}
func (tun *NativeTun) Name() (string, error) {
var ifName struct {
name [16]byte
}
ifNameSize := uintptr(16)
var errno syscall.Errno
var err error
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
tun.name, err = unix.GetsockoptString(
int(fd),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
uintptr(unsafe.Pointer(&ifName)),
uintptr(unsafe.Pointer(&ifNameSize)), 0)
)
})
if errno != 0 {
return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
if err != nil {
return "", fmt.Errorf("GetSockoptString: %w", err)
}
tun.name = string(ifName.name[:ifNameSize-1])
return tun.name, nil
}
@@ -320,11 +275,6 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
// open datagram socket
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
@@ -337,29 +287,18 @@ func (tun *NativeTun) setMTU(n int) error {
defer unix.Close(fd)
// do ioctl call
var ifr [32]byte
copy(ifr[:], tun.name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name)
var ifr unix.IfreqMTU
copy(ifr.Name[:], tun.name)
ifr.MTU = int32(n)
err = unix.IoctlSetIfreqMTU(fd, &ifr)
if err != nil {
return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
@@ -372,19 +311,10 @@ func (tun *NativeTun) MTU() (int, error) {
defer unix.Close(fd)
// do ioctl call
var ifr [64]byte
copy(ifr[:], tun.name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
if err != nil {
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
}
return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
return int(ifr.MTU), nil
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -287,7 +287,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno)
}
// Open control sockets
@@ -328,7 +328,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno)
}
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
@@ -341,7 +341,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno)
}
// Rename the interface
@@ -359,7 +359,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
}
return CreateTUNFromFile(tunFile, mtu)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -12,7 +12,6 @@ import (
"bytes"
"errors"
"fmt"
"net"
"os"
"sync"
"syscall"
@@ -32,7 +31,6 @@ const (
type NativeTun struct {
tunFile *os.File
index int32 // if index
name string // name of interface
errors chan error // async error handling
events chan Event // device related events
nopi bool // the device was passed IFF_NO_PI
@@ -40,6 +38,10 @@ type NativeTun struct {
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
}
func (tun *NativeTun) File() *os.File {
@@ -64,14 +66,19 @@ func (tun *NativeTun) routineHackListener() {
}
switch err {
case unix.EINVAL:
// If the tunnel is up, it reports that write() is
// allowed but we provided invalid data.
tun.events <- EventUp
case unix.EIO:
// If the tunnel is down, it reports that no I/O
// is possible, without checking our provided data.
tun.events <- EventDown
default:
return
}
select {
case <-time.After(time.Second):
// nothing
case <-tun.statusListenersShutdown:
return
}
@@ -85,7 +92,7 @@ func createNetlinkSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -126,6 +133,7 @@ func (tun *NativeTun) routineNetlinkListener() {
default:
}
wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
@@ -149,10 +157,16 @@ func (tun *NativeTun) routineNetlinkListener() {
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
tun.events <- EventDown
// Don't emit EventDown before we've ever emitted EventUp.
// This avoids a startup race with HackListener, which
// might detect Up before we have finished reporting Down.
if wasEverUp {
tun.events <- EventDown
}
}
tun.events <- EventMTUUpdate
@@ -164,11 +178,6 @@ func (tun *NativeTun) routineNetlinkListener() {
}
}
func (tun *NativeTun) isUp() (bool, error) {
inter, err := net.InterfaceByName(tun.name)
return inter.Flags&net.FlagUp != 0, err
}
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
@@ -198,6 +207,11 @@ func getIFIndex(name string) (int32, error) {
}
func (tun *NativeTun) setMTU(n int) error {
name, err := tun.Name()
if err != nil {
return err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
@@ -212,9 +226,8 @@ func (tun *NativeTun) setMTU(n int) error {
defer unix.Close(fd)
// do ioctl call
var ifr [ifReqSize]byte
copy(ifr[:], tun.name)
copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
@@ -231,6 +244,11 @@ func (tun *NativeTun) setMTU(n int) error {
}
func (tun *NativeTun) MTU() (int, error) {
name, err := tun.Name()
if err != nil {
return 0, err
}
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
@@ -247,7 +265,7 @@ func (tun *NativeTun) MTU() (int, error) {
// do ioctl call
var ifr [ifReqSize]byte
copy(ifr[:], tun.name)
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
@@ -262,6 +280,15 @@ func (tun *NativeTun) MTU() (int, error) {
}
func (tun *NativeTun) Name() (string, error) {
tun.nameOnce.Do(tun.initNameCache)
return tun.nameCache, tun.nameErr
}
func (tun *NativeTun) initNameCache() {
tun.nameCache, tun.nameErr = tun.nameSlow()
}
func (tun *NativeTun) nameSlow() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
@@ -282,13 +309,11 @@ func (tun *NativeTun) Name() (string, error) {
if errno != 0 {
return "", errors.New("failed to get name of TUN device: " + errno.Error())
}
nullStr := ifr[:]
i := bytes.IndexByte(nullStr, 0)
if i != -1 {
nullStr = nullStr[:i]
name := ifr[:]
if i := bytes.IndexByte(name, 0); i != -1 {
name = name[:i]
}
tun.name = string(nullStr)
return tun.name, nil
return string(name), nil
}
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
@@ -367,6 +392,9 @@ func (tun *NativeTun) Close() error {
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
}
return nil, err
}
@@ -408,16 +436,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
statusListenersShutdown: make(chan struct{}),
nopi: false,
}
var err error
_, err = tun.Name()
name, err := tun.Name()
if err != nil {
return nil, err
}
// start event listener
tun.index, err = getIFIndex(tun.name)
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tun

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2018-2020 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,10 +9,9 @@ import (
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
_ "unsafe"
"golang.org/x/sys/windows"
@@ -33,18 +32,26 @@ type rateJuggler struct {
}
type NativeTun struct {
wt *wintun.Interface
wt *wintun.Adapter
handle windows.Handle
close bool
events chan Event
errors chan error
forcedMTU int
rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
session wintun.Session
readWait windows.Handle
}
const WintunPool = wintun.Pool("WireGuard")
var WintunPool *wintun.Pool
func init() {
var err error
WintunPool, err = wintun.MakePool("WireGuard")
if err != nil {
panic(fmt.Errorf("Failed to make pool: %w", err))
}
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -66,20 +73,20 @@ func CreateTUN(ifname string, mtu int) (Device, error) {
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error
var wt *wintun.Interface
var wt *wintun.Adapter
// Does an interface with this name already exist?
wt, err = WintunPool.GetInterface(ifname)
wt, err = WintunPool.OpenAdapter(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.DeleteInterface()
_, err = wt.Delete(true)
if err != nil {
return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
}
}
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
wt, _, err = WintunPool.CreateAdapter(ifname, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %v", err)
return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420
@@ -95,17 +102,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
forcedMTU: forcedMTU,
}
tun.rings, err = wintun.NewRingDescriptor()
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err)
}
tun.handle, err = tun.wt.Register(tun.rings)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err)
_, err = tun.wt.Delete(false)
close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err)
}
tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
@@ -123,16 +126,10 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error {
tun.close = true
if tun.rings.Send.TailMoved != 0 {
windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
}
if tun.handle != windows.InvalidHandle {
windows.CloseHandle(tun.handle)
}
tun.rings.Close()
tun.session.End()
var err error
if tun.wt != nil {
_, err = tun.wt.DeleteInterface()
_, err = tun.wt.Delete(false)
}
close(tun.events)
return err
@@ -156,56 +153,34 @@ retry:
return 0, err
default:
}
if tun.close {
return 0, os.ErrClosed
}
buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
var buffTail uint32
for {
buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
if buffHead != buffTail {
break
}
if tun.close {
return 0, os.ErrClosed
}
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
goto retry
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
packetSize := len(packet)
copy(buff[offset:], packet)
tun.session.ReleaseReceivePacket(packet)
tun.rate.update(uint64(packetSize))
return packetSize, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
goto retry
}
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
case windows.ERROR_INVALID_DATA:
return 0, errors.New("Send ring corrupt")
}
procyield(1)
return 0, fmt.Errorf("Read failed: %w", err)
}
if buffTail >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
return 0, errors.New("incomplete packet header in send ring")
}
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
if packet.Size > wintun.PacketSizeMax {
return 0, errors.New("packet too big in send ring")
}
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
if alignedPacketSize > buffContent {
return 0, errors.New("incomplete packet in send ring")
}
copy(buff[offset:], packet.Data[:packet.Size])
buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
tun.rate.update(uint64(packet.Size))
return int(packet.Size), nil
}
func (tun *NativeTun) Flush() error {
@@ -217,36 +192,22 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return 0, os.ErrClosed
}
packetSize := uint32(len(buff) - offset)
packetSize := len(buff) - offset
tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
tun.writeLock.Lock()
defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed
packet, err := tun.session.AllocateSendPacket(packetSize)
if err == nil {
copy(packet, buff[offset:])
tun.session.SendPacket(packet)
return packetSize, nil
}
buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
if buffTail >= wintun.PacketCapacity {
switch err {
case windows.ERROR_HANDLE_EOF:
return 0, os.ErrClosed
}
buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
if alignedPacketSize > buffSpace {
case windows.ERROR_BUFFER_OVERFLOW:
return 0, nil // Dropping when ring is full.
}
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
packet.Size = packetSize
copy(packet.Data[:packetSize], buff[offset:])
atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
windows.SetEvent(tun.rings.Receive.TailMoved)
}
return int(packetSize), nil
return 0, fmt.Errorf("Write failed: %w", err)
}
// LUID returns Windows interface instance ID.
@@ -254,9 +215,9 @@ func (tun *NativeTun) LUID() uint64 {
return tun.wt.LUID()
}
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
return tun.wt.Version()
// RunningVersion returns the running version of the Wintun driver.
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {

150
tun/tuntest/tuntest.go Normal file
View File

@@ -0,0 +1,150 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
*/
package tuntest
import (
"encoding/binary"
"io"
"net"
"os"
"golang.zx2c4.com/wireguard/tun"
)
func Ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
const DefaultMTU = 1420
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}

View File

@@ -0,0 +1,50 @@
// +build !load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module windows.Handle
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != 0 {
return nil
}
const (
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
)
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return windows.GetProcAddress(p.dll.module, p.Name)
}

View File

@@ -0,0 +1,58 @@
// +build load_wintun_from_rsrc
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
"golang.zx2c4.com/wireguard/tun/wintun/resource"
)
type lazyDLL struct {
Name string
mu sync.Mutex
module *memmod.Module
}
func (d *lazyDLL) Load() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
return nil
}
d.mu.Lock()
defer d.mu.Unlock()
if d.module != nil {
return nil
}
const ourModule windows.Handle = 0
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
if err != nil {
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
}
data, err := resource.Load(ourModule, resInfo)
if err != nil {
return fmt.Errorf("Unable to load resource: %w", err)
}
module, err := memmod.LoadLibrary(data)
if err != nil {
return fmt.Errorf("Unable to load library: %w", err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
return nil
}
func (p *lazyProc) nameToAddr() (uintptr, error) {
return p.dll.module.ProcAddressByName(p.Name)
}

59
tun/wintun/dll_windows.go Normal file
View File

@@ -0,0 +1,59 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
func newLazyDLL(name string) *lazyDLL {
return &lazyDLL{Name: name}
}
func (d *lazyDLL) NewProc(name string) *lazyProc {
return &lazyProc{dll: d, Name: name}
}
type lazyProc struct {
Name string
mu sync.Mutex
dll *lazyDLL
addr uintptr
}
func (p *lazyProc) Find() error {
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.addr != 0 {
return nil
}
err := p.dll.Load()
if err != nil {
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
}
addr, err := p.nameToAddr()
if err != nil {
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
}
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
return nil
}
func (p *lazyProc) Addr() uintptr {
err := p.Find()
if err != nil {
panic(err)
}
return p.addr
}

View File

@@ -1,25 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
import "golang.org/x/sys/windows"
//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
var luid uint64
var guid windows.GUID
err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
if err != nil {
return nil, err
}
err = convertInterfaceLUIDToGUID(&luid, &guid)
if err != nil {
return nil, err
}
return &guid, nil
}

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go

View File

@@ -1,60 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package iphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
)
func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@@ -0,0 +1,620 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"errors"
"fmt"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
type addressList struct {
next *addressList
address uintptr
}
func (head *addressList) free() {
for node := head; node != nil; node = node.next {
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
}
}
type Module struct {
headers *IMAGE_NT_HEADERS
codeBase uintptr
modules []windows.Handle
initialized bool
isDLL bool
isRelocated bool
nameExports map[string]uint16
entry uintptr
blockedMemory *addressList
}
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
return &module.headers.OptionalHeader.DataDirectory[idx]
}
func (module *Module) copySections(address uintptr, size uintptr, old_headers *IMAGE_NT_HEADERS) error {
sections := module.headers.Sections()
for i := range sections {
if sections[i].SizeOfRawData == 0 {
// Section doesn't contain data in the dll itself, but may define uninitialized data.
sectionSize := old_headers.OptionalHeader.SectionAlignment
if sectionSize == 0 {
continue
}
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sectionSize),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating section: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
var dst []byte
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
for j := range dst {
dst[j] = 0
}
continue
}
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
return errors.New("Incomplete section")
}
// Commit memory block and copy data from dll.
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
uintptr(sections[i].SizeOfRawData),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
memcpy(
module.codeBase+uintptr(sections[i].VirtualAddress),
address+uintptr(sections[i].PointerToRawData),
uintptr(sections[i].SizeOfRawData))
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
}
return nil
}
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
size := section.SizeOfRawData
if size != 0 {
return uintptr(size)
}
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
}
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
}
return 0
}
type sectionFinalizeData struct {
address uintptr
alignedAddress uintptr
size uintptr
characteristics uint32
last bool
}
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
if sectionData.size == 0 {
return nil
}
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
// Section is not needed any more and can safely be freed.
if sectionData.address == sectionData.alignedAddress &&
(sectionData.last ||
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
// Only allowed to decommit whole pages.
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
}
return nil
}
// determine protection flags based on characteristics
var ProtectionFlags = [8]uint32{
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
windows.PAGE_EXECUTE, // not writeable, not readable, executable
windows.PAGE_READONLY, // not writeable, readable, not executable
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
windows.PAGE_READWRITE, // writeable, readable, not executable
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
}
protect := ProtectionFlags[sectionData.characteristics>>29]
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
protect |= windows.PAGE_NOCACHE
}
// Change memory access flags.
var oldProtect uint32
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
if err != nil {
return fmt.Errorf("Error protecting memory page: %w", err)
}
return nil
}
func (module *Module) finalizeSections() error {
sections := module.headers.Sections()
imageOffset := module.headers.OptionalHeader.imageOffset()
sectionData := sectionFinalizeData{}
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionData.size = module.realSectionSize(&sections[0])
sectionData.characteristics = sections[0].Characteristics
// Loop through all sections and change access flags.
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
sectionSize := module.realSectionSize(&sections[i])
// Combine access flags of all sections that share a page.
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
// Section shares page with previous.
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
} else {
sectionData.characteristics |= sections[i].Characteristics
}
sectionData.size = sectionAddress + sectionSize - sectionData.address
continue
}
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
sectionData.address = sectionAddress
sectionData.alignedAddress = alignedAddress
sectionData.size = sectionSize
sectionData.characteristics = sections[i].Characteristics
}
sectionData.last = true
err := module.finalizeSection(&sectionData)
if err != nil {
return fmt.Errorf("Error finalizing section: %w", err)
}
return nil
}
func (module *Module) executeTLS() {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
if directory.VirtualAddress == 0 {
return
}
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
callback := tls.AddressOfCallbacks
if callback != 0 {
for {
f := *(*uintptr)(a2p(callback))
if f == 0 {
break
}
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
callback += unsafe.Sizeof(f)
}
}
}
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
if directory.Size == 0 {
return delta == 0, nil
}
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for relocationHdr.VirtualAddress > 0 {
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
var relInfos []uint16
unsafeSlice(
unsafe.Pointer(&relInfos),
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
for _, relInfo := range relInfos {
// The upper 4 bits define the type of relocation.
relType := relInfo >> 12
// The lower 12 bits define the offset.
relOffset := uintptr(relInfo & 0xfff)
switch relType {
case IMAGE_REL_BASED_ABSOLUTE:
// Skip relocation.
case IMAGE_REL_BASED_LOW:
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
break
case IMAGE_REL_BASED_HIGH:
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
break
case IMAGE_REL_BASED_HIGHLOW:
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
case IMAGE_REL_BASED_DIR64:
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
case IMAGE_REL_BASED_THUMB_MOV32:
inst := *(*uint32)(a2p(dest + relOffset))
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f240 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
}
imm16 += uint32(delta) & 0xffff
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
if hiDelta != 0 {
inst = *(*uint32)(a2p(dest + relOffset + 4))
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
if (inst & 0x8000fbf0) != 0x0000f2c0 {
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
}
imm16 += hiDelta
if imm16 > 0xffff {
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
}
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
((imm16 >> 1) & 0x0400) +
((imm16 >> 12) & 0x000f) +
((imm16 << 20) & 0x70000000) +
((imm16 << 16) & 0xff0000)
}
default:
return false, fmt.Errorf("Unsupported relocation: %w", relType)
}
}
// Advance to next relocation block.
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
}
return true, nil
}
func (module *Module) buildImportTable() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
if directory.Size == 0 {
return nil
}
module.modules = make([]windows.Handle, 0, 16)
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
for !isBadReadPtr(uintptr(unsafe.Pointer(importDesc)), unsafe.Sizeof(*importDesc)) && importDesc.Name != 0 {
handle, err := loadLibraryA((*byte)(a2p(module.codeBase + uintptr(importDesc.Name))))
if err != nil {
return fmt.Errorf("Error loading module: %w", err)
}
var thunkRef, funcRef *uintptr
if importDesc.OriginalFirstThunk() != 0 {
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
} else {
// No hint table.
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
}
for *thunkRef != 0 {
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
*funcRef, err = getProcAddress(handle, (*byte)(a2p(IMAGE_ORDINAL(*thunkRef))))
} else {
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
*funcRef, err = getProcAddress(handle, &thunkData.Name[0])
}
if err != nil {
windows.FreeLibrary(handle)
return fmt.Errorf("Error getting function address: %w", err)
}
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
}
module.modules = append(module.modules, handle)
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
}
return nil
}
func (module *Module) buildNameExports() error {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
return errors.New("No functions exported")
}
if exports.NumberOfNames == 0 {
return errors.New("No functions exported by name")
}
var nameRefs []uint32
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
var ordinals []uint16
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
module.nameExports = make(map[string]uint16)
for i := range nameRefs {
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
module.nameExports[nameArray] = ordinals[i]
}
return nil
}
// LoadLibrary loads module image to memory.
func LoadLibrary(data []byte) (module *Module, err error) {
addr := uintptr(unsafe.Pointer(&data[0]))
size := uintptr(len(data))
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
}
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
}
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
}
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
}
if oldHeader.FileHeader.Machine != imageFileProcess {
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
}
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
return nil, errors.New("Unaligned section")
}
lastSectionEnd := uintptr(0)
sections := oldHeader.Sections()
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
for i := range sections {
var endOfSection uintptr
if sections[i].SizeOfRawData == 0 {
// Section without data in the DLL
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
} else {
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
}
if endOfSection > lastSectionEnd {
lastSectionEnd = endOfSection
}
}
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
return nil, errors.New("Section is not page-aligned")
}
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
defer func() {
if err != nil {
module.Free()
module = nil
}
}()
// Reserve memory for image of library.
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
// Try to allocate memory at arbitrary position.
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating code: %w", err)
return
}
}
err = module.check4GBBoundaries(alignedImageSize)
if err != nil {
err = fmt.Errorf("Error reallocating code: %w", err)
return
}
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
err = errors.New("Incomplete headers")
return
}
// Commit memory for headers.
headers, err := windows.VirtualAlloc(module.codeBase,
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
err = fmt.Errorf("Error allocating headers: %w", err)
return
}
// Copy PE header to code.
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
// Update position.
module.headers.OptionalHeader.ImageBase = module.codeBase
// Copy sections from DLL file block to new memory location.
err = module.copySections(addr, size, oldHeader)
if err != nil {
err = fmt.Errorf("Error copying sections: %w", err)
return
}
// Adjust base address of imported data.
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
if locationDelta != 0 {
module.isRelocated, err = module.performBaseRelocation(locationDelta)
if err != nil {
err = fmt.Errorf("Error relocating module: %w", err)
return
}
} else {
module.isRelocated = true
}
// Load required dlls and adjust function table of imports.
err = module.buildImportTable()
if err != nil {
err = fmt.Errorf("Error building import table: %w", err)
return
}
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
err = module.finalizeSections()
if err != nil {
err = fmt.Errorf("Error finalizing sections: %w", err)
return
}
// TLS callbacks are executed BEFORE the main loading.
module.executeTLS()
// Get entry point of loaded module.
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
if module.isDLL {
// Notify library about attaching to process.
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
successful := r0 != 0
if !successful {
err = windows.ERROR_DLL_INIT_FAILED
return
}
module.initialized = true
}
}
module.buildNameExports()
return
}
// Free releases module resources and unloads it.
func (module *Module) Free() {
if module.initialized {
// Notify library about detaching from process.
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
module.initialized = false
}
if module.modules != nil {
// Free previously opened libraries.
for _, handle := range module.modules {
windows.FreeLibrary(handle)
}
module.modules = nil
}
if module.codeBase != 0 {
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
module.codeBase = 0
}
if module.blockedMemory != nil {
module.blockedMemory.free()
module.blockedMemory = nil
}
}
// ProcAddressByName returns function address by exported name.
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if module.nameExports == nil {
return 0, errors.New("No functions exported by name")
}
if idx, ok := module.nameExports[name]; ok {
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
return 0, errors.New("Function not found by name")
}
// ProcAddressByOrdinal returns function address by exported ordinal.
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
if directory.Size == 0 {
return 0, errors.New("No export table found")
}
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
if uint32(ordinal) < exports.Base {
return 0, errors.New("Ordinal number too low")
}
idx := ordinal - uint16(exports.Base)
if uint32(idx) > exports.NumberOfFunctions {
return 0, errors.New("Ordinal number too high")
}
// AddressOfFunctions contains the RVAs to the "real" functions.
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
}
func alignDown(value, alignment uintptr) uintptr {
return value & ^(alignment - 1)
}
func alignUp(value, alignment uintptr) uintptr {
return (value + alignment - 1) & ^(alignment - 1)
}
func a2p(addr uintptr) unsafe.Pointer {
return unsafe.Pointer(addr)
}
func memcpy(dst, src, size uintptr) {
var d, s []byte
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
copy(d, s)
}
// unsafeSlice updates the slice slicePtr to be a slice
// referencing the provided data with its length & capacity set to
// lenCap.
//
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
// update callers to use unsafe.Slice instead of this.
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
h := (*sliceHeader)(slicePtr)
h.Data = data
h.Len = lenCap
h.Cap = lenCap
}

View File

@@ -0,0 +1,16 @@
// +build 386 arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return 0
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
return
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_I386

View File

@@ -0,0 +1,36 @@
// +build amd64 arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
import (
"fmt"
"golang.org/x/sys/windows"
)
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
}
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
node := &addressList{
next: module.blockedMemory,
address: module.codeBase,
}
module.blockedMemory = node
module.codeBase, err = windows.VirtualAlloc(0,
alignedImageSize,
windows.MEM_RESERVE|windows.MEM_COMMIT,
windows.PAGE_READWRITE)
if err != nil {
return fmt.Errorf("Error allocating memory block: %w", err)
}
}
return
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go

View File

@@ -0,0 +1,343 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
import "unsafe"
const (
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
IMAGE_OS2_SIGNATURE = 0x454E // NE
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
IMAGE_VXD_SIGNATURE = 0x454C // LE
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
)
// DOS .EXE header
type IMAGE_DOS_HEADER struct {
E_magic uint16 // Magic number
E_cblp uint16 // Bytes on last page of file
E_cp uint16 // Pages in file
E_crlc uint16 // Relocations
E_cparhdr uint16 // Size of header in paragraphs
E_minalloc uint16 // Minimum extra paragraphs needed
E_maxalloc uint16 // Maximum extra paragraphs needed
E_ss uint16 // Initial (relative) SS value
E_sp uint16 // Initial SP value
E_csum uint16 // Checksum
E_ip uint16 // Initial IP value
E_cs uint16 // Initial (relative) CS value
E_lfarlc uint16 // File address of relocation table
E_ovno uint16 // Overlay number
E_res [4]uint16 // Reserved words
E_oemid uint16 // OEM identifier (for e_oeminfo)
E_oeminfo uint16 // OEM information; e_oemid specific
E_res2 [10]uint16 // Reserved words
E_lfanew int32 // File address of new exe header
}
// File header format
type IMAGE_FILE_HEADER struct {
Machine uint16
NumberOfSections uint16
TimeDateStamp uint32
PointerToSymbolTable uint32
NumberOfSymbols uint32
SizeOfOptionalHeader uint16
Characteristics uint16
}
const (
IMAGE_SIZEOF_FILE_HEADER = 20
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
IMAGE_FILE_SYSTEM = 0x1000 // System File.
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
IMAGE_FILE_MACHINE_UNKNOWN = 0
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
IMAGE_FILE_MACHINE_AM33 = 0x01d3
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
IMAGE_FILE_MACHINE_CEF = 0x0CEF
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
IMAGE_FILE_MACHINE_CEE = 0xC0EE
)
// Directory format
type IMAGE_DATA_DIRECTORY struct {
VirtualAddress uint32
Size uint32
}
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
type IMAGE_NT_HEADERS struct {
Signature uint32
FileHeader IMAGE_FILE_HEADER
OptionalHeader IMAGE_OPTIONAL_HEADER
}
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
(uintptr)(unsafe.Pointer(ntheader)) +
unsafe.Offsetof(ntheader.OptionalHeader) +
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
}
const (
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
)
const IMAGE_SIZEOF_SHORT_NAME = 8
// Section header format
type IMAGE_SECTION_HEADER struct {
Name [IMAGE_SIZEOF_SHORT_NAME]byte
physicalAddressOrVirtualSize uint32
VirtualAddress uint32
SizeOfRawData uint32
PointerToRawData uint32
PointerToRelocations uint32
PointerToLinenumbers uint32
NumberOfRelocations uint16
NumberOfLinenumbers uint16
Characteristics uint32
}
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
return ishdr.physicalAddressOrVirtualSize
}
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
ishdr.physicalAddressOrVirtualSize = addr
}
const (
// Section characteristics.
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
IMAGE_SCN_MEM_FARDATA = 0x00008000
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
IMAGE_SCN_MEM_16BIT = 0x00020000
IMAGE_SCN_MEM_LOCKED = 0x00040000
IMAGE_SCN_MEM_PRELOAD = 0x00080000
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
IMAGE_SCN_ALIGN_MASK = 0x00F00000
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
// TLS Characteristic Flags
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
)
// Based relocation format
type IMAGE_BASE_RELOCATION struct {
VirtualAddress uint32
SizeOfBlock uint32
}
const (
IMAGE_REL_BASED_ABSOLUTE = 0
IMAGE_REL_BASED_HIGH = 1
IMAGE_REL_BASED_LOW = 2
IMAGE_REL_BASED_HIGHLOW = 3
IMAGE_REL_BASED_HIGHADJ = 4
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
IMAGE_REL_BASED_RESERVED = 6
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
IMAGE_REL_BASED_DIR64 = 10
IMAGE_REL_BASED_IA64_IMM64 = 9
IMAGE_REL_BASED_MIPS_JMPADDR = 5
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
IMAGE_REL_BASED_ARM_MOV32 = 5
IMAGE_REL_BASED_THUMB_MOV32 = 7
)
// Export Format
type IMAGE_EXPORT_DIRECTORY struct {
Characteristics uint32
TimeDateStamp uint32
MajorVersion uint16
MinorVersion uint16
Name uint32
Base uint32
NumberOfFunctions uint32
NumberOfNames uint32
AddressOfFunctions uint32 // RVA from base of image
AddressOfNames uint32 // RVA from base of image
AddressOfNameOrdinals uint32 // RVA from base of image
}
type IMAGE_IMPORT_BY_NAME struct {
Hint uint16
Name [1]byte
}
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
return ordinal & 0xffff
}
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
}
// Thread Local Storage
type IMAGE_TLS_DIRECTORY struct {
StartAddressOfRawData uintptr
EndAddressOfRawData uintptr
AddressOfIndex uintptr // PDWORD
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
SizeOfZeroFill uint32
Characteristics uint32
}
type IMAGE_IMPORT_DESCRIPTOR struct {
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
TimeDateStamp uint32 // 0 if not bound,
// -1 if bound, and real date\time stamp
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
// O.W. date/time stamp of DLL bound to (Old BIND)
ForwarderChain uint32 // -1 if no forwarders
Name uint32
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
return imgimpdesc.characteristicsOrOriginalFirstThunk
}
const (
DLL_PROCESS_ATTACH = 1
DLL_THREAD_ATTACH = 2
DLL_THREAD_DETACH = 3
DLL_PROCESS_DETACH = 0
)
//sys loadLibraryA(libFileName *byte) (module windows.Handle, err error) = kernel32.LoadLibraryA
//sys getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) = kernel32.GetProcAddress
//sys isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) = kernel32.IsBadReadPtr
type SYSTEM_INFO struct {
ProcessorArchitecture uint16
Reserved uint16
PageSize uint32
MinimumApplicationAddress uintptr
MaximumApplicationAddress uintptr
ActiveProcessorMask uintptr
NumberOfProcessors uint32
ProcessorType uint32
AllocationGranularity uint32
ProcessorLevel uint16
ProcessorRevision uint16
}

View File

@@ -0,0 +1,45 @@
// +build 386 arm
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
BaseOfData uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000

View File

@@ -0,0 +1,44 @@
// +build amd64 arm64
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
*/
package memmod
// Optional header format
type IMAGE_OPTIONAL_HEADER struct {
Magic uint16
MajorLinkerVersion uint8
MinorLinkerVersion uint8
SizeOfCode uint32
SizeOfInitializedData uint32
SizeOfUninitializedData uint32
AddressOfEntryPoint uint32
BaseOfCode uint32
ImageBase uintptr
SectionAlignment uint32
FileAlignment uint32
MajorOperatingSystemVersion uint16
MinorOperatingSystemVersion uint16
MajorImageVersion uint16
MinorImageVersion uint16
MajorSubsystemVersion uint16
MinorSubsystemVersion uint16
Win32VersionValue uint32
SizeOfImage uint32
SizeOfHeaders uint32
CheckSum uint32
Subsystem uint16
DllCharacteristics uint16
SizeOfStackReserve uintptr
SizeOfStackCommit uintptr
SizeOfHeapReserve uintptr
SizeOfHeapCommit uintptr
LoaderFlags uint32
NumberOfRvaAndSizes uint32
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
}
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000

View File

@@ -0,0 +1,70 @@
// Code generated by 'go generate'; DO NOT EDIT.
package memmod
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procGetProcAddress = modkernel32.NewProc("GetProcAddress")
procIsBadReadPtr = modkernel32.NewProc("IsBadReadPtr")
procLoadLibraryA = modkernel32.NewProc("LoadLibraryA")
)
func getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) {
r0, _, e1 := syscall.Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procName)), 0)
addr = uintptr(r0)
if addr == 0 {
err = errnoErr(e1)
}
return
}
func isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) {
r0, _, _ := syscall.Syscall(procIsBadReadPtr.Addr(), 2, uintptr(addr), uintptr(ucb), 0)
ret = r0 != 0
return
}
func loadLibraryA(libFileName *byte) (module windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procLoadLibraryA.Addr(), 1, uintptr(unsafe.Pointer(libFileName)), 0, 0)
module = windows.Handle(r0)
if module == 0 {
err = errnoErr(e1)
}
return
}

View File

@@ -1,98 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"unsafe"
"golang.org/x/crypto/blake2s"
"golang.org/x/sys/windows"
"golang.org/x/text/unicode/norm"
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
)
var (
wintunObjectSecurityAttributes *windows.SecurityAttributes
hasInitializedNamespace bool
initializingNamespace sync.Mutex
)
func initializeNamespace() error {
initializingNamespace.Lock()
defer initializingNamespace.Unlock()
if hasInitializedNamespace {
return nil
}
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
}
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: sd,
}
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
}
boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
if err != nil {
return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
}
err = boundary.AddSid(sid)
if err != nil {
return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
}
for {
_, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
if err == windows.ERROR_ALREADY_EXISTS {
_, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
if err == windows.ERROR_PATH_NOT_FOUND {
continue
}
}
if err != nil {
return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
}
break
}
hasInitializedNamespace = true
return nil
}
func (pool Pool) takeNameMutex() (windows.Handle, error) {
err := initializeNamespace()
if err != nil {
return 0, err
}
const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
b2, _ := blake2s.New256(nil)
b2.Write([]byte(mutexLabel))
b2.Write(norm.NFC.Bytes([]byte(string(pool))))
mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
if err != nil {
err = fmt.Errorf("Error creating name mutex: %v", err)
return 0, err
}
event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
if err != nil {
windows.CloseHandle(mutex)
return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
}
if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
windows.CloseHandle(mutex)
return 0, errors.New("Error with event trigger of name mutex")
}
return mutex, nil
}

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go

View File

@@ -1,83 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
import "golang.org/x/sys/windows"
//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
type BoundaryDescriptor windows.Handle
// CreateBoundaryDescriptor creates a boundary descriptor.
func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return 0, err
}
handle, err := createBoundaryDescriptor(name16, 0)
if err != nil {
return 0, err
}
return BoundaryDescriptor(handle), nil
}
// Delete deletes the specified boundary descriptor.
func (bd BoundaryDescriptor) Delete() {
deleteBoundaryDescriptor(windows.Handle(bd))
}
// AddSid adds a security identifier (SID) to the specified boundary descriptor.
func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
}
// PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace.
func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// OpenPrivateNamespace opens a private namespace.
func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
type ClosePrivateNamespaceFlags uint32
const (
// PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
)
// Close closes an open namespace handle.
func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
return closePrivateNamespace(windows.Handle(pns), uint32(flags))
}

View File

@@ -1,116 +0,0 @@
// Code generated by 'go generate'; DO NOT EDIT.
package namespaceapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
)
func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
return
}
func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

View File

@@ -1,8 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package nci
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go

Some files were not shown because too many files have changed in this diff Show More