Compare commits
121 Commits
0.0.201908
...
0.0.202011
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da19db415a | ||
|
|
52c834c446 | ||
|
|
913f68ce38 | ||
|
|
60b3766b89 | ||
|
|
82128c47d9 | ||
|
|
c192b2eeec | ||
|
|
a3b231b31e | ||
|
|
65e03a9182 | ||
|
|
3e08b8aee0 | ||
|
|
5ca1218a5c | ||
|
|
3b490f30aa | ||
|
|
e6b7c4eef3 | ||
|
|
8ae09213a7 | ||
|
|
36dc8b6994 | ||
|
|
2057f19a61 | ||
|
|
58a8f05f50 | ||
|
|
0b54907a73 | ||
|
|
2c143dce0f | ||
|
|
22af3890f6 | ||
|
|
c8fe925020 | ||
|
|
0cfa3314ee | ||
|
|
bc3f505efa | ||
|
|
507f148e1c | ||
|
|
31b574ef99 | ||
|
|
3c41141fb4 | ||
|
|
4369db522b | ||
|
|
b84f1d4db2 | ||
|
|
dfb28757f7 | ||
|
|
00bcd865e6 | ||
|
|
f28a6d244b | ||
|
|
c403da6a39 | ||
|
|
d6de6f3ce6 | ||
|
|
59e556f24e | ||
|
|
31faf4c159 | ||
|
|
99eb7896be | ||
|
|
f60b3919be | ||
|
|
da9d300cf8 | ||
|
|
59c9929714 | ||
|
|
db0aa39b76 | ||
|
|
bc77de2aca | ||
|
|
c8596328e7 | ||
|
|
28c4d04304 | ||
|
|
fdba6c183a | ||
|
|
250b9795f3 | ||
|
|
d60857e1a7 | ||
|
|
2fb0a712f0 | ||
|
|
f2c6faad44 | ||
|
|
c76b818466 | ||
|
|
de374bfb44 | ||
|
|
1a1c3d0968 | ||
|
|
85a45a9651 | ||
|
|
abd287159e | ||
|
|
203554620d | ||
|
|
6aefb61355 | ||
|
|
3dce460c88 | ||
|
|
224bc9e60c | ||
|
|
9cd8909df2 | ||
|
|
ae88e2a2cd | ||
|
|
4739708ca4 | ||
|
|
b33219c2cf | ||
|
|
9cbcff10dd | ||
|
|
6ed56ff2df | ||
|
|
cb4bb63030 | ||
|
|
05b03c6750 | ||
|
|
caebdfe9d0 | ||
|
|
4fa2ea6a2d | ||
|
|
89dd065e53 | ||
|
|
ddfad453cf | ||
|
|
2b242f9393 | ||
|
|
4cdf805b29 | ||
|
|
f7d0edd2ec | ||
|
|
ffffbbcc8a | ||
|
|
47b02c618b | ||
|
|
fd23c66fcd | ||
|
|
ae492d1b35 | ||
|
|
95fbfccf60 | ||
|
|
c85e4a410f | ||
|
|
1b6c8ddbe8 | ||
|
|
0abb6b668c | ||
|
|
540d01e54a | ||
|
|
f2ea85e9f9 | ||
|
|
222f0f8000 | ||
|
|
1f146a5e7a | ||
|
|
f2501aa6c8 | ||
|
|
cb8d01f58a | ||
|
|
01f8ef4e84 | ||
|
|
70f6c42556 | ||
|
|
bb0b2514c0 | ||
|
|
7c97fdb1e3 | ||
|
|
84b5a4d83d | ||
|
|
4cd06c0925 | ||
|
|
d12eb91f9a | ||
|
|
73d3bd9cd5 | ||
|
|
f3dba4c194 | ||
|
|
7937840f96 | ||
|
|
e4b957183c | ||
|
|
950ca2ba8c | ||
|
|
df2bf34373 | ||
|
|
a12b765784 | ||
|
|
14df9c3e75 | ||
|
|
353f0956bc | ||
|
|
fa7763c268 | ||
|
|
d94bae8348 | ||
|
|
7689d09336 | ||
|
|
69c26dc258 | ||
|
|
e862131d3c | ||
|
|
da28a3e9f3 | ||
|
|
3bf3322b2c | ||
|
|
7305b4ce93 | ||
|
|
26fb615b11 | ||
|
|
7fbb24afaa | ||
|
|
d9008ac35c | ||
|
|
f8198c0428 | ||
|
|
0c540ad60e | ||
|
|
3cedc22d7b | ||
|
|
68fea631d8 | ||
|
|
ef23100a4f | ||
|
|
eb786cd7c1 | ||
|
|
333de75370 | ||
|
|
d20459dc69 | ||
|
|
01786286c1 |
21
Makefile
21
Makefile
@@ -1,30 +1,16 @@
|
|||||||
PREFIX ?= /usr
|
PREFIX ?= /usr
|
||||||
DESTDIR ?=
|
DESTDIR ?=
|
||||||
BINDIR ?= $(PREFIX)/bin
|
BINDIR ?= $(PREFIX)/bin
|
||||||
export GOPATH ?= $(CURDIR)/.gopath
|
|
||||||
export GO111MODULE := on
|
export GO111MODULE := on
|
||||||
|
|
||||||
all: generate-version-and-build
|
all: generate-version-and-build
|
||||||
|
|
||||||
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
|
|
||||||
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
|
|
||||||
else ifeq ($(shell go env GOOS),linux)
|
|
||||||
ireallywantobuildon_linux.go:
|
|
||||||
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
|
|
||||||
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"
|
|
||||||
clean-ireallywantobuildon_linux.go:
|
|
||||||
@rm -f ireallywantobuildon_linux.go
|
|
||||||
.PHONY: clean-ireallywantobuildon_linux.go
|
|
||||||
clean: clean-ireallywantobuildon_linux.go
|
|
||||||
wireguard-go: ireallywantobuildon_linux.go
|
|
||||||
endif
|
|
||||||
|
|
||||||
MAKEFLAGS += --no-print-directory
|
MAKEFLAGS += --no-print-directory
|
||||||
|
|
||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
|
||||||
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > device/version.go && \
|
echo "$$ver" > device/version.go && \
|
||||||
git update-index --assume-unchanged device/version.go || true
|
git update-index --assume-unchanged device/version.go || true
|
||||||
@@ -36,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
|||||||
install: wireguard-go
|
install: wireguard-go
|
||||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test -v ./...
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f wireguard-go
|
rm -f wireguard-go
|
||||||
|
|
||||||
.PHONY: all clean install generate-version-and-build
|
.PHONY: all clean test install generate-version-and-build
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
|
|||||||
$ wireguard-go -f wg0
|
$ wireguard-go -f wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
|||||||
|
|
||||||
### Linux
|
### Linux
|
||||||
|
|
||||||
This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
|
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
|
|||||||
|
|
||||||
## Building
|
## Building
|
||||||
|
|
||||||
This requires an installation of [go](https://golang.org) ≥ 1.12.
|
This requires an installation of [go](https://golang.org) ≥ 1.13.
|
||||||
|
|
||||||
```
|
```
|
||||||
$ git clone https://git.zx2c4.com/wireguard-go
|
$ git clone https://git.zx2c4.com/wireguard-go
|
||||||
@@ -56,7 +56,7 @@ $ make
|
|||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
this software and associated documentation files (the "Software"), to deal in
|
||||||
|
|||||||
34
conn/boundif_android.go
Normal file
34
conn/boundif_android.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package conn
|
||||||
|
|
||||||
|
func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||||
|
sysconn, err := bind.ipv4.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
err = sysconn.Control(func(f uintptr) {
|
||||||
|
fd = int(f)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||||
|
sysconn, err := bind.ipv6.SyscallConn()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
err = sysconn.Control(func(f uintptr) {
|
||||||
|
fd = int(f)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -1,13 +1,12 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -18,17 +17,13 @@ const (
|
|||||||
sockoptIPV6_UNICAST_IF = 31
|
sockoptIPV6_UNICAST_IF = 31
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||||
bytes := make([]byte, 4)
|
bytes := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
||||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||||
|
|
||||||
if device.net.bind == nil {
|
sysconn, err := bind.ipv4.SyscallConn()
|
||||||
return errors.New("Bind is not yet initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -41,11 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
bind.blackhole4 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
sysconn, err := bind.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -58,5 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
bind.blackhole6 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
111
conn/conn.go
Normal file
111
conn/conn.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package conn implements WireGuard's network connections.
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||||
|
//
|
||||||
|
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||||
|
// depending on the platform-specific implementation.
|
||||||
|
type Bind interface {
|
||||||
|
// LastMark reports the last mark set for this Bind.
|
||||||
|
LastMark() uint32
|
||||||
|
|
||||||
|
// SetMark sets the mark for each packet sent through this Bind.
|
||||||
|
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||||
|
SetMark(mark uint32) error
|
||||||
|
|
||||||
|
// ReceiveIPv6 reads an IPv6 UDP packet into b.
|
||||||
|
//
|
||||||
|
// It reports the number of bytes read, n,
|
||||||
|
// the packet source address ep,
|
||||||
|
// and any error.
|
||||||
|
ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
|
||||||
|
|
||||||
|
// ReceiveIPv4 reads an IPv4 UDP packet into b.
|
||||||
|
//
|
||||||
|
// It reports the number of bytes read, n,
|
||||||
|
// the packet source address ep,
|
||||||
|
// and any error.
|
||||||
|
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
|
||||||
|
|
||||||
|
// Send writes a packet b to address ep.
|
||||||
|
Send(b []byte, ep Endpoint) error
|
||||||
|
|
||||||
|
// Close closes the Bind connection.
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateBind creates a Bind bound to a port.
|
||||||
|
//
|
||||||
|
// The value actualPort reports the actual port number the Bind
|
||||||
|
// object gets bound to.
|
||||||
|
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
|
||||||
|
return createBind(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindSocketToInterface is implemented by Bind objects that support being
|
||||||
|
// tied to a single network interface. Used by wireguard-windows.
|
||||||
|
type BindSocketToInterface interface {
|
||||||
|
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||||
|
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||||
|
// file descriptor peeked at. Used by wireguard-android.
|
||||||
|
type PeekLookAtSocketFd interface {
|
||||||
|
PeekLookAtSocketFd4() (fd int, err error)
|
||||||
|
PeekLookAtSocketFd6() (fd int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Endpoint maintains the source/destination caching for a peer.
|
||||||
|
//
|
||||||
|
// dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
// src : the local address from which datagrams originate going to the peer
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() net.IP
|
||||||
|
SrcIP() net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
|
// ensure that the host is an IP address
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
||||||
|
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
||||||
|
// trying to make sure with a small sanity test that this is a real IP address and
|
||||||
|
// not something that's likely to incur DNS lookups.
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip == nil {
|
||||||
|
return nil, errors.New("Failed to parse IP address: " + host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse address and port
|
||||||
|
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ip4 := addr.IP.To4()
|
||||||
|
if ip4 != nil {
|
||||||
|
addr.IP = ip4
|
||||||
|
}
|
||||||
|
return addr, err
|
||||||
|
}
|
||||||
@@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
@@ -21,8 +21,10 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type nativeBind struct {
|
type nativeBind struct {
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint net.UDPAddr
|
type NativeEndpoint net.UDPAddr
|
||||||
@@ -65,16 +67,12 @@ func (e *NativeEndpoint) SrcToString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
|
||||||
// listen
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieve port
|
// Retrieve port.
|
||||||
|
|
||||||
laddr := conn.LocalAddr()
|
laddr := conn.LocalAddr()
|
||||||
uaddr, err := net.ResolveUDPAddr(
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
laddr.Network(),
|
laddr.Network(),
|
||||||
@@ -98,7 +96,7 @@ func extractErrno(err error) error {
|
|||||||
return syscallErr.Err
|
return syscallErr.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
func createBind(uport uint16) (Bind, uint16, error) {
|
||||||
var err error
|
var err error
|
||||||
var bind nativeBind
|
var bind nativeBind
|
||||||
|
|
||||||
@@ -133,6 +131,8 @@ func (bind *nativeBind) Close() error {
|
|||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bind *nativeBind) LastMark() uint32 { return 0 }
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
if bind.ipv4 == nil {
|
if bind.ipv4 == nil {
|
||||||
return 0, nil, syscall.EAFNOSUPPORT
|
return 0, nil, syscall.EAFNOSUPPORT
|
||||||
@@ -159,11 +159,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
|
|||||||
if bind.ipv4 == nil {
|
if bind.ipv4 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
} else {
|
} else {
|
||||||
if bind.ipv6 == nil {
|
if bind.ipv6 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole6 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
@@ -2,19 +2,10 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
|
||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
|
||||||
* of the sticky-sockets.c example code:
|
|
||||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
|
||||||
*
|
|
||||||
* Currently there is no way to achieve this within the net package:
|
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
|
||||||
* So this code is remains platform dependent.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@@ -25,7 +16,6 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -33,8 +23,8 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type IPv4Source struct {
|
type IPv4Source struct {
|
||||||
src [4]byte
|
Src [4]byte
|
||||||
ifindex int32
|
Ifindex int32
|
||||||
}
|
}
|
||||||
|
|
||||||
type IPv6Source struct {
|
type IPv6Source struct {
|
||||||
@@ -43,11 +33,16 @@ type IPv6Source struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint struct {
|
type NativeEndpoint struct {
|
||||||
|
sync.Mutex
|
||||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||||
src [unsafe.Sizeof(IPv6Source{})]byte
|
src [unsafe.Sizeof(IPv6Source{})]byte
|
||||||
isV6 bool
|
isV6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
|
||||||
|
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
||||||
|
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
|
||||||
|
|
||||||
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
||||||
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||||
}
|
}
|
||||||
@@ -65,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type nativeBind struct {
|
type nativeBind struct {
|
||||||
sock4 int
|
sock4 int
|
||||||
sock6 int
|
sock6 int
|
||||||
netlinkSock int
|
lastMark uint32
|
||||||
netlinkCancel *rwcancel.RWCancel
|
|
||||||
lastMark uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Endpoint = (*NativeEndpoint)(nil)
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
@@ -110,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
|
|||||||
return nil, errors.New("Invalid IP address")
|
return nil, errors.New("Invalid IP address")
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNetlinkRouteSocket() (int, error) {
|
func createBind(port uint16) (Bind, uint16, error) {
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
saddr := &unix.SockaddrNetlink{
|
|
||||||
Family: unix.AF_NETLINK,
|
|
||||||
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
|
||||||
}
|
|
||||||
err = unix.Bind(sock, saddr)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(sock)
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return sock, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|
||||||
var err error
|
var err error
|
||||||
var bind nativeBind
|
var bind nativeBind
|
||||||
var newPort uint16
|
var newPort uint16
|
||||||
|
|
||||||
bind.netlinkSock, err = createNetlinkRouteSocket()
|
// Attempt ipv6 bind, update port if successful.
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(bind.netlinkSock)
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go bind.routineRouteListener(device)
|
|
||||||
|
|
||||||
// attempt ipv6 bind, update port if succesful
|
|
||||||
|
|
||||||
bind.sock6, newPort, err = create6(port)
|
bind.sock6, newPort, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if err != syscall.EAFNOSUPPORT {
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// attempt ipv4 bind, update port if succesful
|
// Attempt ipv4 bind, update port if successful.
|
||||||
|
|
||||||
bind.sock4, newPort, err = create4(port)
|
bind.sock4, newPort, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != syscall.EAFNOSUPPORT {
|
if err != syscall.EAFNOSUPPORT {
|
||||||
bind.netlinkCancel.Cancel()
|
|
||||||
unix.Close(bind.sock6)
|
unix.Close(bind.sock6)
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -177,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|||||||
return &bind, port, nil
|
return &bind, port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (bind *nativeBind) LastMark() uint32 {
|
||||||
|
return bind.lastMark
|
||||||
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(value uint32) error {
|
func (bind *nativeBind) SetMark(value uint32) error {
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err := unix.SetsockoptInt(
|
err := unix.SetsockoptInt(
|
||||||
@@ -215,22 +178,18 @@ func closeUnblock(fd int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) Close() error {
|
func (bind *nativeBind) Close() error {
|
||||||
var err1, err2, err3 error
|
var err1, err2 error
|
||||||
if bind.sock6 != -1 {
|
if bind.sock6 != -1 {
|
||||||
err1 = closeUnblock(bind.sock6)
|
err1 = closeUnblock(bind.sock6)
|
||||||
}
|
}
|
||||||
if bind.sock4 != -1 {
|
if bind.sock4 != -1 {
|
||||||
err2 = closeUnblock(bind.sock4)
|
err2 = closeUnblock(bind.sock4)
|
||||||
}
|
}
|
||||||
err3 = bind.netlinkCancel.Cancel()
|
|
||||||
|
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
if err2 != nil {
|
return err2
|
||||||
return err2
|
|
||||||
}
|
|
||||||
return err3
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
@@ -277,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
|
|||||||
func (end *NativeEndpoint) SrcIP() net.IP {
|
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||||
if !end.isV6 {
|
if !end.isV6 {
|
||||||
return net.IPv4(
|
return net.IPv4(
|
||||||
end.src4().src[0],
|
end.src4().Src[0],
|
||||||
end.src4().src[1],
|
end.src4().Src[1],
|
||||||
end.src4().src[2],
|
end.src4().Src[2],
|
||||||
end.src4().src[3],
|
end.src4().Src[3],
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
return end.src6().src[:]
|
return end.src6().src[:]
|
||||||
@@ -391,6 +350,11 @@ func create4(port uint16) (int, uint16, error) {
|
|||||||
return FD_ERR, 0, err
|
return FD_ERR, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sa, err := unix.Getsockname(fd)
|
||||||
|
if err == nil {
|
||||||
|
addr.Port = sa.(*unix.SockaddrInet4).Port
|
||||||
|
}
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,6 +414,11 @@ func create6(port uint16) (int, uint16, error) {
|
|||||||
return FD_ERR, 0, err
|
return FD_ERR, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sa, err := unix.Getsockname(fd)
|
||||||
|
if err == nil {
|
||||||
|
addr.Port = sa.(*unix.SockaddrInet6).Port
|
||||||
|
}
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,12 +436,14 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||||
},
|
},
|
||||||
unix.Inet4Pktinfo{
|
unix.Inet4Pktinfo{
|
||||||
Spec_dst: end.src4().src,
|
Spec_dst: end.src4().Src,
|
||||||
Ifindex: end.src4().ifindex,
|
Ifindex: end.src4().Ifindex,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -483,7 +454,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -512,7 +485,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
cmsg.pktinfo.Ifindex = 0
|
cmsg.pktinfo.Ifindex = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -523,7 +498,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -531,7 +508,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
|
|
||||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
@@ -554,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||||
end.src4().src = cmsg.pktinfo.Spec_dst
|
end.src4().Src = cmsg.pktinfo.Spec_dst
|
||||||
end.src4().ifindex = cmsg.pktinfo.Ifindex
|
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
||||||
}
|
}
|
||||||
|
|
||||||
return size, nil
|
return size, nil
|
||||||
@@ -563,7 +540,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||||||
|
|
||||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
@@ -592,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||||||
|
|
||||||
return size, nil
|
return size, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bind *nativeBind) routineRouteListener(device *Device) {
|
|
||||||
type peerEndpointPtr struct {
|
|
||||||
peer *Peer
|
|
||||||
endpoint *Endpoint
|
|
||||||
}
|
|
||||||
var reqPeer map[uint32]peerEndpointPtr
|
|
||||||
var reqPeerLock sync.Mutex
|
|
||||||
|
|
||||||
defer unix.Close(bind.netlinkSock)
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !bind.netlinkCancel.ReadyRead() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if uint(hdr.Len) > uint(len(remain)) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
|
||||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
|
||||||
if uint(len(remain)) < uint(hdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
|
||||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
|
||||||
for {
|
|
||||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
|
||||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
|
||||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
if reqPeer == nil {
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr, ok := reqPeer[hdr.Seq]
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.Lock()
|
|
||||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
|
|
||||||
pePtr.peer.Unlock()
|
|
||||||
}
|
|
||||||
attr = attr[attrhdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
go func() {
|
|
||||||
device.peers.RLock()
|
|
||||||
i := uint32(1)
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.RLock()
|
|
||||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
|
||||||
peer.RUnlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
|
||||||
peer.RUnlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
nlmsg := struct {
|
|
||||||
hdr unix.NlMsghdr
|
|
||||||
msg unix.RtMsg
|
|
||||||
dsthdr unix.RtAttr
|
|
||||||
dst [4]byte
|
|
||||||
srchdr unix.RtAttr
|
|
||||||
src [4]byte
|
|
||||||
markhdr unix.RtAttr
|
|
||||||
mark uint32
|
|
||||||
}{
|
|
||||||
unix.NlMsghdr{
|
|
||||||
Type: uint16(unix.RTM_GETROUTE),
|
|
||||||
Flags: unix.NLM_F_REQUEST,
|
|
||||||
Seq: i,
|
|
||||||
},
|
|
||||||
unix.RtMsg{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Dst_len: 32,
|
|
||||||
Src_len: 32,
|
|
||||||
},
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_DST,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_SRC,
|
|
||||||
},
|
|
||||||
peer.endpoint.(*NativeEndpoint).src4().src,
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_MARK,
|
|
||||||
},
|
|
||||||
uint32(bind.lastMark),
|
|
||||||
}
|
|
||||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer[i] = peerEndpointPtr{
|
|
||||||
peer: peer,
|
|
||||||
endpoint: &peer.endpoint,
|
|
||||||
}
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
peer.RUnlock()
|
|
||||||
i++
|
|
||||||
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
func (bind *nativeBind) SetMark(mark uint32) error {
|
func (bind *nativeBind) SetMark(mark uint32) error {
|
||||||
return nil
|
return nil
|
||||||
@@ -2,10 +2,10 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,15 +1,19 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
msg []byte
|
msg []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
world bool // better type
|
world bool // better type
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in6
|
datagram, ok := <-b.in6
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
|||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
||||||
datagram, ok := <-b.in4
|
datagram, ok := <-b.in4
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
|
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
|
||||||
nb, ok := device.net.bind.(*nativeBind)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("no socket exists")
|
|
||||||
}
|
|
||||||
sysconn, err := nb.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
|
|
||||||
nb, ok := device.net.bind.(*nativeBind)
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("no socket exists")
|
|
||||||
}
|
|
||||||
sysconn, err := nb.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
187
device/conn.go
187
device/conn.go
@@ -1,187 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ConnRoutineNumber = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
|
||||||
*/
|
|
||||||
type Bind interface {
|
|
||||||
SetMark(value uint32) error
|
|
||||||
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
|
||||||
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
|
||||||
Send(buff []byte, end Endpoint) error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
/* An Endpoint maintains the source/destination caching for a peer
|
|
||||||
*
|
|
||||||
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
|
||||||
* src : the local address from which datagrams originate going to the peer
|
|
||||||
*/
|
|
||||||
type Endpoint interface {
|
|
||||||
ClearSrc() // clears the source address
|
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
|
||||||
DstToString() string // returns the destination address (ip:port)
|
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
|
||||||
DstIP() net.IP
|
|
||||||
SrcIP() net.IP
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
|
||||||
// ensure that the host is an IP address
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
|
||||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
|
||||||
// trying to make sure with a small sanity test that this is a real IP address and
|
|
||||||
// not something that's likely to incur DNS lookups.
|
|
||||||
host = host[:i]
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
|
||||||
return nil, errors.New("Failed to parse IP address: " + host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse address and port
|
|
||||||
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ip4 := addr.IP.To4()
|
|
||||||
if ip4 != nil {
|
|
||||||
addr.IP = ip4
|
|
||||||
}
|
|
||||||
return addr, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func unsafeCloseBind(device *Device) error {
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
if netc.bind != nil {
|
|
||||||
err = netc.bind.Close()
|
|
||||||
netc.bind = nil
|
|
||||||
}
|
|
||||||
netc.stopping.Wait()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSetMark(mark uint32) error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// check if modified
|
|
||||||
|
|
||||||
if device.net.fwmark == mark {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// update fwmark on existing bind
|
|
||||||
|
|
||||||
device.net.fwmark = mark
|
|
||||||
if device.isUp.Get() && device.net.bind != nil {
|
|
||||||
if err := device.net.bind.SetMark(mark); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindUpdate() error {
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// close existing sockets
|
|
||||||
|
|
||||||
if err := unsafeCloseBind(device); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open new sockets
|
|
||||||
|
|
||||||
if device.isUp.Get() {
|
|
||||||
|
|
||||||
// bind to new port
|
|
||||||
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
|
||||||
if err != nil {
|
|
||||||
netc.bind = nil
|
|
||||||
netc.port = 0
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set fwmark
|
|
||||||
|
|
||||||
if netc.fwmark != 0 {
|
|
||||||
err = netc.bind.SetMark(netc.fwmark)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Lock()
|
|
||||||
defer peer.Unlock()
|
|
||||||
if peer.endpoint != nil {
|
|
||||||
peer.endpoint.ClearSrc()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
// start receiving routines
|
|
||||||
|
|
||||||
device.net.starting.Add(ConnRoutineNumber)
|
|
||||||
device.net.stopping.Add(ConnRoutineNumber)
|
|
||||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
|
||||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
|
||||||
device.net.starting.Wait()
|
|
||||||
|
|
||||||
device.log.Debug.Println("UDP bind has been updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindClose() error {
|
|
||||||
device.net.Lock()
|
|
||||||
err := unsafeCloseBind(device)
|
|
||||||
device.net.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
/* Specification constants */
|
/* Specification constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
|
RekeyAfterMessages = (1 << 60)
|
||||||
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
RejectAfterMessages = (1 << 64) - (1 << 13) - 1
|
||||||
RekeyAfterTime = time.Second * 120
|
RekeyAfterTime = time.Second * 120
|
||||||
RekeyAttemptTime = time.Second * 90
|
RekeyAttemptTime = time.Second * 90
|
||||||
RekeyTimeout = time.Second * 5
|
RekeyTimeout = time.Second * 5
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
174
device/device.go
174
device/device.go
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -11,15 +11,14 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
DeviceRoutineNumberPerCPU = 3
|
|
||||||
DeviceRoutineNumberAdditional = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
||||||
isUp AtomicBool // device is (going) up
|
isUp AtomicBool // device is (going) up
|
||||||
isClosed AtomicBool // device is closed? (acting as guard)
|
isClosed AtomicBool // device is closed? (acting as guard)
|
||||||
@@ -39,9 +38,10 @@ type Device struct {
|
|||||||
starting sync.WaitGroup
|
starting sync.WaitGroup
|
||||||
stopping sync.WaitGroup
|
stopping sync.WaitGroup
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
bind Bind // bind interface
|
bind conn.Bind // bind interface
|
||||||
port uint16 // listening port
|
netlinkCancel *rwcancel.RWCancel
|
||||||
fwmark uint32 // mark value (0 = disabled)
|
port uint16 // listening port
|
||||||
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
staticIdentity struct {
|
staticIdentity struct {
|
||||||
@@ -236,23 +236,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
|
|
||||||
// do static-static DH pre-computations
|
// do static-static DH pre-computations
|
||||||
|
|
||||||
rmKey := device.staticIdentity.privateKey.IsZero()
|
|
||||||
|
|
||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for key, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
if rmKey {
|
expiredPeers = append(expiredPeers, peer)
|
||||||
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
|
|
||||||
} else {
|
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
unsafeRemovePeer(device, peer, key)
|
|
||||||
} else {
|
|
||||||
expiredPeers = append(expiredPeers, peer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range lockedPeers {
|
for _, peer := range lockedPeers {
|
||||||
@@ -311,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
|||||||
cpus := runtime.NumCPU()
|
cpus := runtime.NumCPU()
|
||||||
device.state.starting.Wait()
|
device.state.starting.Wait()
|
||||||
device.state.stopping.Wait()
|
device.state.stopping.Wait()
|
||||||
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
|
||||||
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
|
||||||
for i := 0; i < cpus; i += 1 {
|
for i := 0; i < cpus; i += 1 {
|
||||||
|
device.state.starting.Add(3)
|
||||||
|
device.state.stopping.Add(3)
|
||||||
go device.RoutineEncryption()
|
go device.RoutineEncryption()
|
||||||
go device.RoutineDecryption()
|
go device.RoutineDecryption()
|
||||||
go device.RoutineHandshake()
|
go device.RoutineHandshake()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
device.state.starting.Add(2)
|
||||||
|
device.state.stopping.Add(2)
|
||||||
go device.RoutineReadFromTUN()
|
go device.RoutineReadFromTUN()
|
||||||
go device.RoutineTUNEventReader()
|
go device.RoutineTUNEventReader()
|
||||||
|
|
||||||
@@ -393,10 +383,10 @@ func (device *Device) Close() {
|
|||||||
device.isUp.Set(false)
|
device.isUp.Set(false)
|
||||||
|
|
||||||
close(device.signals.stop)
|
close(device.signals.stop)
|
||||||
|
device.state.stopping.Wait()
|
||||||
|
|
||||||
device.RemoveAllPeers()
|
device.RemoveAllPeers()
|
||||||
|
|
||||||
device.state.stopping.Wait()
|
|
||||||
device.FlushPacketQueues()
|
device.FlushPacketQueues()
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
device.rate.limiter.Close()
|
||||||
@@ -425,3 +415,133 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|||||||
}
|
}
|
||||||
device.peers.RUnlock()
|
device.peers.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unsafeCloseBind(device *Device) error {
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
if netc.netlinkCancel != nil {
|
||||||
|
netc.netlinkCancel.Cancel()
|
||||||
|
}
|
||||||
|
if netc.bind != nil {
|
||||||
|
err = netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
}
|
||||||
|
netc.stopping.Wait()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Bind() conn.Bind {
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
return device.net.bind
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindSetMark(mark uint32) error {
|
||||||
|
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// check if modified
|
||||||
|
|
||||||
|
if device.net.fwmark == mark {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update fwmark on existing bind
|
||||||
|
|
||||||
|
device.net.fwmark = mark
|
||||||
|
if device.isUp.Get() && device.net.bind != nil {
|
||||||
|
if err := device.net.bind.SetMark(mark); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Lock()
|
||||||
|
defer peer.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
|
device.net.Lock()
|
||||||
|
defer device.net.Unlock()
|
||||||
|
|
||||||
|
// close existing sockets
|
||||||
|
|
||||||
|
if err := unsafeCloseBind(device); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open new sockets
|
||||||
|
|
||||||
|
if device.isUp.Get() {
|
||||||
|
|
||||||
|
// bind to new port
|
||||||
|
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind = nil
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set fwmark
|
||||||
|
|
||||||
|
if netc.fwmark != 0 {
|
||||||
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Lock()
|
||||||
|
defer peer.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
|
||||||
|
// start receiving routines
|
||||||
|
|
||||||
|
device.net.starting.Add(2)
|
||||||
|
device.net.stopping.Add(2)
|
||||||
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||||
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||||
|
device.net.starting.Wait()
|
||||||
|
|
||||||
|
device.log.Debug.Println("UDP bind has been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindClose() error {
|
||||||
|
device.net.Lock()
|
||||||
|
err := unsafeCloseBind(device)
|
||||||
|
device.net.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,58 +1,98 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Create two device instances and simulate full WireGuard interaction
|
|
||||||
* without network dependencies
|
|
||||||
*/
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDevice(t *testing.T) {
|
func getFreePort(t *testing.T) string {
|
||||||
|
l, err := net.ListenPacket("udp", "localhost:0")
|
||||||
// prepare tun devices for generating traffic
|
|
||||||
|
|
||||||
tun1 := newDummyTUN("tun1")
|
|
||||||
tun2 := newDummyTUN("tun2")
|
|
||||||
|
|
||||||
_ = tun1
|
|
||||||
_ = tun2
|
|
||||||
|
|
||||||
// prepare endpoints
|
|
||||||
|
|
||||||
end1, err := CreateDummyEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
end2, err := CreateDummyEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = end1
|
|
||||||
_ = end2
|
|
||||||
|
|
||||||
// create binds
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func randDevice(t *testing.T) *Device {
|
|
||||||
sk, err := newPrivateKey()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tun := newDummyTUN("dummy")
|
defer l.Close()
|
||||||
logger := NewLogger(LogLevelError, "")
|
return fmt.Sprintf("%d", l.LocalAddr().(*net.UDPAddr).Port)
|
||||||
device := NewDevice(tun, logger)
|
}
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
return device
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
|
port1 := getFreePort(t)
|
||||||
|
port2 := getFreePort(t)
|
||||||
|
|
||||||
|
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
|
||||||
|
listen_port={{PORT1}}
|
||||||
|
replace_peers=true
|
||||||
|
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
|
||||||
|
protocol_version=1
|
||||||
|
replace_allowed_ips=true
|
||||||
|
allowed_ip=1.0.0.2/32
|
||||||
|
endpoint=127.0.0.1:{{PORT2}}`
|
||||||
|
cfg1 = strings.ReplaceAll(cfg1, "{{PORT1}}", port1)
|
||||||
|
cfg1 = strings.ReplaceAll(cfg1, "{{PORT2}}", port2)
|
||||||
|
|
||||||
|
tun1 := tuntest.NewChannelTUN()
|
||||||
|
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
||||||
|
dev1.Up()
|
||||||
|
defer dev1.Close()
|
||||||
|
if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
|
||||||
|
listen_port={{PORT2}}
|
||||||
|
replace_peers=true
|
||||||
|
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
|
||||||
|
protocol_version=1
|
||||||
|
replace_allowed_ips=true
|
||||||
|
allowed_ip=1.0.0.1/32
|
||||||
|
endpoint=127.0.0.1:{{PORT1}}`
|
||||||
|
cfg2 = strings.ReplaceAll(cfg2, "{{PORT1}}", port1)
|
||||||
|
cfg2 = strings.ReplaceAll(cfg2, "{{PORT2}}", port2)
|
||||||
|
|
||||||
|
tun2 := tuntest.NewChannelTUN()
|
||||||
|
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
||||||
|
dev2.Up()
|
||||||
|
defer dev2.Close()
|
||||||
|
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||||
|
tun2.Outbound <- msg2to1
|
||||||
|
select {
|
||||||
|
case msgRecv := <-tun1.Inbound:
|
||||||
|
if !bytes.Equal(msg2to1, msgRecv) {
|
||||||
|
t.Error("ping did not transit correctly")
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("ping did not transit")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||||
|
tun1.Outbound <- msg1to2
|
||||||
|
select {
|
||||||
|
case msgRecv := <-tun2.Inbound:
|
||||||
|
if !bytes.Equal(msg1to2, msgRecv) {
|
||||||
|
t.Error("return ping did not transit correctly")
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("return ping did not transit")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
func assertNil(t *testing.T, err error) {
|
||||||
@@ -66,3 +106,15 @@ func assertEqual(t *testing.T, a, b []byte) {
|
|||||||
t.Fatal(a, "!=", b)
|
t.Fatal(a, "!=", b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun := newDummyTUN("dummy")
|
||||||
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun, logger)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type IndexTableEntry struct {
|
type IndexTableEntry struct {
|
||||||
@@ -25,7 +25,8 @@ type IndexTable struct {
|
|||||||
func randUint32() (uint32, error) {
|
func randUint32() (uint32, error) {
|
||||||
var integer [4]byte
|
var integer [4]byte
|
||||||
_, err := rand.Read(integer[:])
|
_, err := rand.Read(integer[:])
|
||||||
return *(*uint32)(unsafe.Pointer(&integer[0])), err
|
// Arbitrary endianness; both are intrinsified by the Go compiler.
|
||||||
|
return binary.LittleEndian.Uint32(integer[:]), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Init() {
|
func (table *IndexTable) Init() {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -8,7 +8,9 @@ package device
|
|||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/replay"
|
"golang.zx2c4.com/wireguard/replay"
|
||||||
)
|
)
|
||||||
@@ -24,7 +26,7 @@ type Keypair struct {
|
|||||||
sendNonce uint64
|
sendNonce uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
replayFilter replay.ReplayFilter
|
replayFilter replay.Filter
|
||||||
isInitiator bool
|
isInitiator bool
|
||||||
created time.Time
|
created time.Time
|
||||||
localIndex uint32
|
localIndex uint32
|
||||||
@@ -38,6 +40,14 @@ type Keypairs struct {
|
|||||||
next *Keypair
|
next *Keypair
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) storeNext(next *Keypair) {
|
||||||
|
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kp *Keypairs) loadNext() *Keypair {
|
||||||
|
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
||||||
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
kp.RLock()
|
kp.RLock()
|
||||||
defer kp.RUnlock()
|
defer kp.RUnlock()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
16
device/mobilequirks.go
Normal file
16
device/mobilequirks.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Lock()
|
||||||
|
defer peer.Unlock()
|
||||||
|
peer.disableRoaming = peer.endpoint != nil
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,29 +1,51 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tai64n"
|
"golang.zx2c4.com/wireguard/tai64n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type handshakeState int
|
||||||
|
|
||||||
|
// TODO(crawshaw): add commentary describing each state and the transitions
|
||||||
const (
|
const (
|
||||||
HandshakeZeroed = iota
|
handshakeZeroed = handshakeState(iota)
|
||||||
HandshakeInitiationCreated
|
handshakeInitiationCreated
|
||||||
HandshakeInitiationConsumed
|
handshakeInitiationConsumed
|
||||||
HandshakeResponseCreated
|
handshakeResponseCreated
|
||||||
HandshakeResponseConsumed
|
handshakeResponseConsumed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (hs handshakeState) String() string {
|
||||||
|
switch hs {
|
||||||
|
case handshakeZeroed:
|
||||||
|
return "handshakeZeroed"
|
||||||
|
case handshakeInitiationCreated:
|
||||||
|
return "handshakeInitiationCreated"
|
||||||
|
case handshakeInitiationConsumed:
|
||||||
|
return "handshakeInitiationConsumed"
|
||||||
|
case handshakeResponseCreated:
|
||||||
|
return "handshakeResponseCreated"
|
||||||
|
case handshakeResponseConsumed:
|
||||||
|
return "handshakeResponseConsumed"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||||
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||||
@@ -39,13 +61,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitiationSize = 148 // size of handshake initation message
|
MessageInitiationSize = 148 // size of handshake initiation message
|
||||||
MessageResponseSize = 92 // size of response message
|
MessageResponseSize = 92 // size of response message
|
||||||
MessageCookieReplySize = 64 // size of cookie reply message
|
MessageCookieReplySize = 64 // size of cookie reply message
|
||||||
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
|
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
||||||
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
||||||
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
||||||
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -95,7 +117,7 @@ type MessageCookieReply struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
state int
|
state handshakeState
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
hash [blake2s.Size]byte // hash value
|
hash [blake2s.Size]byte // hash value
|
||||||
chainKey [blake2s.Size]byte // chain key
|
chainKey [blake2s.Size]byte // chain key
|
||||||
@@ -135,7 +157,7 @@ func (h *Handshake) Clear() {
|
|||||||
setZero(h.chainKey[:])
|
setZero(h.chainKey[:])
|
||||||
setZero(h.hash[:])
|
setZero(h.hash[:])
|
||||||
h.localIndex = 0
|
h.localIndex = 0
|
||||||
h.state = HandshakeZeroed
|
h.state = handshakeZeroed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) mixHash(data []byte) {
|
func (h *Handshake) mixHash(data []byte) {
|
||||||
@@ -154,6 +176,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
|
var errZeroECDHResult = errors.New("ECDH returned all zeros")
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.RUnlock()
|
||||||
@@ -162,12 +185,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
return nil, errors.New("static shared secret is zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
handshake.hash = InitialHash
|
handshake.hash = InitialHash
|
||||||
handshake.chainKey = InitialChainKey
|
handshake.chainKey = InitialChainKey
|
||||||
@@ -176,59 +194,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign index
|
|
||||||
|
|
||||||
device.indexTable.Delete(handshake.localIndex)
|
|
||||||
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
handshake.mixHash(handshake.remoteStatic[:])
|
handshake.mixHash(handshake.remoteStatic[:])
|
||||||
|
|
||||||
msg := MessageInitiation{
|
msg := MessageInitiation{
|
||||||
Type: MessageInitiationType,
|
Type: MessageInitiationType,
|
||||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||||
Sender: handshake.localIndex,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt static key
|
// encrypt static key
|
||||||
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
func() {
|
if isZero(ss[:]) {
|
||||||
var key [chacha20poly1305.KeySize]byte
|
return nil, errZeroECDHResult
|
||||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
}
|
||||||
KDF2(
|
var key [chacha20poly1305.KeySize]byte
|
||||||
&handshake.chainKey,
|
KDF2(
|
||||||
&key,
|
&handshake.chainKey,
|
||||||
handshake.chainKey[:],
|
&key,
|
||||||
ss[:],
|
handshake.chainKey[:],
|
||||||
)
|
ss[:],
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
)
|
||||||
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
}()
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
||||||
handshake.mixHash(msg.Static[:])
|
handshake.mixHash(msg.Static[:])
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil, errZeroECDHResult
|
||||||
|
}
|
||||||
|
KDF2(
|
||||||
|
&handshake.chainKey,
|
||||||
|
&key,
|
||||||
|
handshake.chainKey[:],
|
||||||
|
handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
timestamp := tai64n.Now()
|
timestamp := tai64n.Now()
|
||||||
func() {
|
aead, _ = chacha20poly1305.New(key[:])
|
||||||
var key [chacha20poly1305.KeySize]byte
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||||
KDF2(
|
|
||||||
&handshake.chainKey,
|
// assign index
|
||||||
&key,
|
device.indexTable.Delete(handshake.localIndex)
|
||||||
handshake.chainKey[:],
|
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
||||||
handshake.precomputedStaticStatic[:],
|
if err != nil {
|
||||||
)
|
return nil, err
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
}
|
||||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
handshake.localIndex = msg.Sender
|
||||||
}()
|
|
||||||
|
|
||||||
handshake.mixHash(msg.Timestamp[:])
|
handshake.mixHash(msg.Timestamp[:])
|
||||||
handshake.state = HandshakeInitiationCreated
|
handshake.state = handshakeInitiationCreated
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,16 +265,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt static key
|
// decrypt static key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
func() {
|
var key [chacha20poly1305.KeySize]byte
|
||||||
var key [chacha20poly1305.KeySize]byte
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
if isZero(ss[:]) {
|
||||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
return nil
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
}
|
||||||
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||||
}()
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -273,23 +288,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify identity
|
// verify identity
|
||||||
|
|
||||||
var timestamp tai64n.Timestamp
|
var timestamp tai64n.Timestamp
|
||||||
var key [chacha20poly1305.KeySize]byte
|
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
|
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
handshake.mutex.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
KDF2(
|
KDF2(
|
||||||
&chainKey,
|
&chainKey,
|
||||||
&key,
|
&key,
|
||||||
chainKey[:],
|
chainKey[:],
|
||||||
handshake.precomputedStaticStatic[:],
|
handshake.precomputedStaticStatic[:],
|
||||||
)
|
)
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ = chacha20poly1305.New(key[:])
|
||||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
@@ -299,11 +315,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
|
|
||||||
// protect against replay & flood
|
// protect against replay & flood
|
||||||
|
|
||||||
var ok bool
|
replay := !timestamp.After(handshake.lastTimestamp)
|
||||||
ok = timestamp.After(handshake.lastTimestamp)
|
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
|
||||||
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
if !ok {
|
if replay {
|
||||||
|
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if flood {
|
||||||
|
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,9 +335,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
handshake.lastTimestamp = timestamp
|
if timestamp.After(handshake.lastTimestamp) {
|
||||||
handshake.lastInitiationConsumption = time.Now()
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.state = HandshakeInitiationConsumed
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(handshake.lastInitiationConsumption) {
|
||||||
|
handshake.lastInitiationConsumption = now
|
||||||
|
}
|
||||||
|
handshake.state = handshakeInitiationConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
@@ -332,7 +357,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationConsumed {
|
if handshake.state != handshakeInitiationConsumed {
|
||||||
return nil, errors.New("handshake initiation must be consumed first")
|
return nil, errors.New("handshake initiation must be consumed first")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -388,7 +413,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
|||||||
handshake.mixHash(msg.Empty[:])
|
handshake.mixHash(msg.Empty[:])
|
||||||
}()
|
}()
|
||||||
|
|
||||||
handshake.state = HandshakeResponseCreated
|
handshake.state = handshakeResponseCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
@@ -418,7 +443,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
defer handshake.mutex.RUnlock()
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
if handshake.state != HandshakeInitiationCreated {
|
if handshake.state != handshakeInitiationCreated {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +504,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
|||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.state = HandshakeResponseConsumed
|
handshake.state = handshakeResponseConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
@@ -504,7 +529,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
var sendKey [chacha20poly1305.KeySize]byte
|
var sendKey [chacha20poly1305.KeySize]byte
|
||||||
var recvKey [chacha20poly1305.KeySize]byte
|
var recvKey [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
if handshake.state == HandshakeResponseConsumed {
|
if handshake.state == handshakeResponseConsumed {
|
||||||
KDF2(
|
KDF2(
|
||||||
&sendKey,
|
&sendKey,
|
||||||
&recvKey,
|
&recvKey,
|
||||||
@@ -512,7 +537,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
isInitiator = true
|
isInitiator = true
|
||||||
} else if handshake.state == HandshakeResponseCreated {
|
} else if handshake.state == handshakeResponseCreated {
|
||||||
KDF2(
|
KDF2(
|
||||||
&recvKey,
|
&recvKey,
|
||||||
&sendKey,
|
&sendKey,
|
||||||
@@ -521,7 +546,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
)
|
)
|
||||||
isInitiator = false
|
isInitiator = false
|
||||||
} else {
|
} else {
|
||||||
return errors.New("invalid state for keypair derivation")
|
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero handshake
|
// zero handshake
|
||||||
@@ -529,7 +554,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
setZero(handshake.chainKey[:])
|
setZero(handshake.chainKey[:])
|
||||||
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||||
setZero(handshake.localEphemeral[:])
|
setZero(handshake.localEphemeral[:])
|
||||||
peer.handshake.state = HandshakeZeroed
|
peer.handshake.state = handshakeZeroed
|
||||||
|
|
||||||
// create AEAD instances
|
// create AEAD instances
|
||||||
|
|
||||||
@@ -542,7 +567,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
|
|
||||||
keypair.created = time.Now()
|
keypair.created = time.Now()
|
||||||
keypair.sendNonce = 0
|
keypair.sendNonce = 0
|
||||||
keypair.replayFilter.Init()
|
keypair.replayFilter.Reset()
|
||||||
keypair.isInitiator = isInitiator
|
keypair.isInitiator = isInitiator
|
||||||
keypair.localIndex = peer.handshake.localIndex
|
keypair.localIndex = peer.handshake.localIndex
|
||||||
keypair.remoteIndex = peer.handshake.remoteIndex
|
keypair.remoteIndex = peer.handshake.remoteIndex
|
||||||
@@ -559,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.next
|
next := keypairs.loadNext()
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
@@ -573,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.next = keypair
|
keypairs.storeNext(keypair)
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
@@ -584,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
|||||||
|
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
if keypairs.next != receivedKeypair {
|
|
||||||
|
if keypairs.loadNext() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.Unlock()
|
||||||
if keypairs.next != receivedKeypair {
|
if keypairs.loadNext() != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.next
|
keypairs.current = keypairs.loadNext()
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
|
||||||
|
err = loadExactHex(key[:], src)
|
||||||
|
if key.IsZero() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key.clamp()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (key NoisePrivateKey) ToHex() string {
|
func (key NoisePrivateKey) ToHex() string {
|
||||||
return hex.EncodeToString(key[:])
|
return hex.EncodeToString(key[:])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
|
|||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.next
|
key1 := peer1.keypairs.loadNext()
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -24,10 +26,15 @@ type Peer struct {
|
|||||||
keypairs Keypairs
|
keypairs Keypairs
|
||||||
handshake Handshake
|
handshake Handshake
|
||||||
device *Device
|
device *Device
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
persistentKeepaliveInterval uint16
|
persistentKeepaliveInterval uint16
|
||||||
|
disableRoaming bool
|
||||||
|
|
||||||
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
|
// These fields are accessed with atomic operations, which must be
|
||||||
|
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
||||||
|
// allocated struct will be 64-bit aligned. So we place
|
||||||
|
// atomically-accessed fields up front, so that they can share in
|
||||||
|
// this alignment before smaller fields throw it off.
|
||||||
stats struct {
|
stats struct {
|
||||||
txBytes uint64 // bytes send to peer (endpoint)
|
txBytes uint64 // bytes send to peer (endpoint)
|
||||||
rxBytes uint64 // bytes received from peer
|
rxBytes uint64 // bytes received from peer
|
||||||
@@ -51,6 +58,7 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
queue struct {
|
queue struct {
|
||||||
|
sync.RWMutex
|
||||||
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
||||||
outbound chan *QueueOutboundElement // sequential ordering of work
|
outbound chan *QueueOutboundElement // sequential ordering of work
|
||||||
inbound chan *QueueInboundElement // sequential ordering of work
|
inbound chan *QueueInboundElement // sequential ordering of work
|
||||||
@@ -108,7 +116,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
|
|
||||||
handshake.remoteStatic = pk
|
handshake.remoteStatic = pk
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
@@ -116,13 +123,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
|
|
||||||
peer.endpoint = nil
|
peer.endpoint = nil
|
||||||
|
|
||||||
// conditionally add
|
// add
|
||||||
|
|
||||||
if !ssIsZero {
|
device.peers.keyMap[pk] = peer
|
||||||
device.peers.keyMap[pk] = peer
|
|
||||||
} else {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// start peer
|
// start peer
|
||||||
|
|
||||||
@@ -193,10 +196,11 @@ func (peer *Peer) Start() {
|
|||||||
peer.routines.stopping.Add(PeerRoutineNumber)
|
peer.routines.stopping.Add(PeerRoutineNumber)
|
||||||
|
|
||||||
// prepare queues
|
// prepare queues
|
||||||
|
peer.queue.Lock()
|
||||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
|
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
|
||||||
|
peer.queue.Unlock()
|
||||||
|
|
||||||
peer.timersInit()
|
peer.timersInit()
|
||||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||||
@@ -222,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
keypairs.Lock()
|
keypairs.Lock()
|
||||||
device.DeleteKeypair(keypairs.previous)
|
device.DeleteKeypair(keypairs.previous)
|
||||||
device.DeleteKeypair(keypairs.current)
|
device.DeleteKeypair(keypairs.current)
|
||||||
device.DeleteKeypair(keypairs.next)
|
device.DeleteKeypair(keypairs.loadNext())
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
keypairs.current = nil
|
keypairs.current = nil
|
||||||
keypairs.next = nil
|
keypairs.storeNext(nil)
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
|
|
||||||
// clear handshake state
|
// clear handshake state
|
||||||
@@ -253,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
|||||||
keypairs.current.sendNonce = RejectAfterMessages
|
keypairs.current.sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
if keypairs.next != nil {
|
if keypairs.next != nil {
|
||||||
keypairs.next.sendNonce = RejectAfterMessages
|
keypairs.loadNext().sendNonce = RejectAfterMessages
|
||||||
}
|
}
|
||||||
keypairs.Unlock()
|
keypairs.Unlock()
|
||||||
}
|
}
|
||||||
@@ -282,17 +286,17 @@ func (peer *Peer) Stop() {
|
|||||||
|
|
||||||
// close queues
|
// close queues
|
||||||
|
|
||||||
|
peer.queue.Lock()
|
||||||
close(peer.queue.nonce)
|
close(peer.queue.nonce)
|
||||||
close(peer.queue.outbound)
|
close(peer.queue.outbound)
|
||||||
close(peer.queue.inbound)
|
close(peer.queue.inbound)
|
||||||
|
peer.queue.Unlock()
|
||||||
|
|
||||||
peer.ZeroAndFlushAll()
|
peer.ZeroAndFlushAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
var RoamingDisabled bool
|
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||||
|
if peer.disableRoaming {
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
|
|
||||||
if RoamingDisabled {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peer.Lock()
|
peer.Lock()
|
||||||
|
|||||||
43
device/peer_test.go
Normal file
43
device/peer_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkAlignment(t *testing.T, name string, offset uintptr) {
|
||||||
|
t.Helper()
|
||||||
|
if offset%8 != 0 {
|
||||||
|
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerAlignment checks that atomically-accessed fields are
|
||||||
|
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||||
|
//
|
||||||
|
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||||
|
// hard segfault at runtime.
|
||||||
|
func TestPeerAlignment(t *testing.T) {
|
||||||
|
var p Peer
|
||||||
|
|
||||||
|
typ := reflect.TypeOf(p)
|
||||||
|
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
||||||
|
field.Name,
|
||||||
|
field.Offset,
|
||||||
|
field.Type.Size(),
|
||||||
|
field.Type.Align(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
|
||||||
|
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -17,12 +17,13 @@ import (
|
|||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
)
|
)
|
||||||
|
|
||||||
type QueueHandshakeElement struct {
|
type QueueHandshakeElement struct {
|
||||||
msgType uint32
|
msgType uint32
|
||||||
packet []byte
|
packet []byte
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
buffer *[MaxMessageSize]byte
|
buffer *[MaxMessageSize]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
|
|||||||
packet []byte
|
packet []byte
|
||||||
counter uint64
|
counter uint64
|
||||||
keypair *Keypair
|
keypair *Keypair
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (elem *QueueInboundElement) Drop() {
|
func (elem *QueueInboundElement) Drop() {
|
||||||
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
|||||||
* Every time the bind is updated a new routine is started for
|
* Every time the bind is updated a new routine is started for
|
||||||
* IPv4 and IPv6 (separately)
|
* IPv4 and IPv6 (separately)
|
||||||
*/
|
*/
|
||||||
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||||
|
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
|||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
size int
|
size int
|
||||||
endpoint Endpoint
|
endpoint conn.Endpoint
|
||||||
)
|
)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -183,11 +184,13 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
|||||||
|
|
||||||
// add to decryption queues
|
// add to decryption queues
|
||||||
|
|
||||||
|
peer.queue.RLock()
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Get() {
|
||||||
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
|
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
|
||||||
buffer = device.GetMessageBuffer()
|
buffer = device.GetMessageBuffer()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
peer.queue.RUnlock()
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
@@ -107,6 +107,8 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
|
|||||||
/* Queues a keepalive if no packets are queued for peer
|
/* Queues a keepalive if no packets are queued for peer
|
||||||
*/
|
*/
|
||||||
func (peer *Peer) SendKeepalive() bool {
|
func (peer *Peer) SendKeepalive() bool {
|
||||||
|
peer.queue.RLock()
|
||||||
|
defer peer.queue.RUnlock()
|
||||||
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
|
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -220,10 +222,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
|
|||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buff[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||||
if err != nil {
|
return nil
|
||||||
device.log.Error.Println("Failed to send cookie reply:", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) keepKeyFreshSending() {
|
func (peer *Peer) keepKeyFreshSending() {
|
||||||
@@ -313,6 +312,7 @@ func (device *Device) RoutineReadFromTUN() {
|
|||||||
|
|
||||||
// insert into nonce/pre-handshake queue
|
// insert into nonce/pre-handshake queue
|
||||||
|
|
||||||
|
peer.queue.RLock()
|
||||||
if peer.isRunning.Get() {
|
if peer.isRunning.Get() {
|
||||||
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
|
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
@@ -320,6 +320,7 @@ func (device *Device) RoutineReadFromTUN() {
|
|||||||
addToNonceQueue(peer.queue.nonce, elem, device)
|
addToNonceQueue(peer.queue.nonce, elem, device)
|
||||||
elem = nil
|
elem = nil
|
||||||
}
|
}
|
||||||
|
peer.queue.RUnlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,6 +452,21 @@ func (peer *Peer) RoutineNonce() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func calculatePaddingSize(packetSize, mtu int) int {
|
||||||
|
lastUnit := packetSize
|
||||||
|
if mtu == 0 {
|
||||||
|
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
|
||||||
|
}
|
||||||
|
if lastUnit > mtu {
|
||||||
|
lastUnit %= mtu
|
||||||
|
}
|
||||||
|
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
|
||||||
|
if paddedSize > mtu {
|
||||||
|
paddedSize = mtu
|
||||||
|
}
|
||||||
|
return paddedSize - lastUnit
|
||||||
|
}
|
||||||
|
|
||||||
/* Encrypts the elements in the queue
|
/* Encrypts the elements in the queue
|
||||||
* and marks them for sequential consumption (by releasing the mutex)
|
* and marks them for sequential consumption (by releasing the mutex)
|
||||||
*
|
*
|
||||||
@@ -517,13 +533,8 @@ func (device *Device) RoutineEncryption() {
|
|||||||
|
|
||||||
// pad content to multiple of 16
|
// pad content to multiple of 16
|
||||||
|
|
||||||
mtu := int(atomic.LoadInt32(&device.tun.mtu))
|
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
|
||||||
lastUnit := len(elem.packet) % mtu
|
for i := 0; i < paddingSize; i++ {
|
||||||
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
|
|
||||||
if paddedSize > mtu {
|
|
||||||
paddedSize = mtu
|
|
||||||
}
|
|
||||||
for i := len(elem.packet); i < paddedSize; i++ {
|
|
||||||
elem.packet = append(elem.packet, 0)
|
elem.packet = append(elem.packet, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
12
device/sticky_default.go
Normal file
12
device/sticky_default.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// +build !linux android
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
217
device/sticky_linux.go
Normal file
217
device/sticky_linux.go
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
// +build !android
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
|
* of the sticky-sockets.c example code:
|
||||||
|
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||||
|
*
|
||||||
|
* Currently there is no way to achieve this within the net package:
|
||||||
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
|
* So this code is remains platform dependent.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||||
|
netlinkSock, err := createNetlinkRouteSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(netlinkSock)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
|
||||||
|
|
||||||
|
return netlinkCancel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||||
|
type peerEndpointPtr struct {
|
||||||
|
peer *Peer
|
||||||
|
endpoint *conn.Endpoint
|
||||||
|
}
|
||||||
|
var reqPeer map[uint32]peerEndpointPtr
|
||||||
|
var reqPeerLock sync.Mutex
|
||||||
|
|
||||||
|
defer unix.Close(netlinkSock)
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !netlinkCancel.ReadyRead() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if uint(hdr.Len) > uint(len(remain)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||||
|
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
||||||
|
if uint(len(remain)) < uint(hdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||||
|
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||||
|
for {
|
||||||
|
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||||
|
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||||
|
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
if reqPeer == nil {
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr, ok := reqPeer[hdr.Seq]
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.Lock()
|
||||||
|
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
|
||||||
|
pePtr.peer.Unlock()
|
||||||
|
}
|
||||||
|
attr = attr[attrhdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
go func() {
|
||||||
|
device.peers.RLock()
|
||||||
|
i := uint32(1)
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.RLock()
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
peer.RUnlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
|
||||||
|
if nativeEP == nil {
|
||||||
|
peer.RUnlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
||||||
|
peer.RUnlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nlmsg := struct {
|
||||||
|
hdr unix.NlMsghdr
|
||||||
|
msg unix.RtMsg
|
||||||
|
dsthdr unix.RtAttr
|
||||||
|
dst [4]byte
|
||||||
|
srchdr unix.RtAttr
|
||||||
|
src [4]byte
|
||||||
|
markhdr unix.RtAttr
|
||||||
|
mark uint32
|
||||||
|
}{
|
||||||
|
unix.NlMsghdr{
|
||||||
|
Type: uint16(unix.RTM_GETROUTE),
|
||||||
|
Flags: unix.NLM_F_REQUEST,
|
||||||
|
Seq: i,
|
||||||
|
},
|
||||||
|
unix.RtMsg{
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Dst_len: 32,
|
||||||
|
Src_len: 32,
|
||||||
|
},
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_DST,
|
||||||
|
},
|
||||||
|
nativeEP.Dst4().Addr,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_SRC,
|
||||||
|
},
|
||||||
|
nativeEP.Src4().Src,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_MARK,
|
||||||
|
},
|
||||||
|
uint32(bind.LastMark()),
|
||||||
|
}
|
||||||
|
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||||
|
reqPeerLock.Lock()
|
||||||
|
reqPeer[i] = peerEndpointPtr{
|
||||||
|
peer: peer,
|
||||||
|
endpoint: &peer.endpoint,
|
||||||
|
}
|
||||||
|
reqPeerLock.Unlock()
|
||||||
|
peer.RUnlock()
|
||||||
|
i++
|
||||||
|
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: unix.RTMGRP_IPV4_ROUTE,
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(sock)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*
|
*
|
||||||
* This is based heavily on timers.c from the kernel implementation.
|
* This is based heavily on timers.c from the kernel implementation.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,7 +32,7 @@ func (s IPCError) ErrorCode() int64 {
|
|||||||
return s.int64
|
return s.int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
|
||||||
lines := make([]string, 0, 100)
|
lines := make([]string, 0, 100)
|
||||||
send := func(line string) {
|
send := func(line string) {
|
||||||
lines = append(lines, line)
|
lines = append(lines, line)
|
||||||
@@ -105,7 +107,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
|
||||||
scanner := bufio.NewScanner(socket)
|
scanner := bufio.NewScanner(socket)
|
||||||
logError := device.log.Error
|
logError := device.log.Error
|
||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
@@ -113,6 +115,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
var peer *Peer
|
var peer *Peer
|
||||||
|
|
||||||
dummy := false
|
dummy := false
|
||||||
|
createdNewPeer := false
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@@ -137,7 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
switch key {
|
switch key {
|
||||||
case "private_key":
|
case "private_key":
|
||||||
var sk NoisePrivateKey
|
var sk NoisePrivateKey
|
||||||
err := sk.FromHex(value)
|
err := sk.FromMaybeZeroHex(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to set private_key:", err)
|
logError.Println("Failed to set private_key:", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
@@ -237,7 +240,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
peer = device.LookupPeer(publicKey)
|
peer = device.LookupPeer(publicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
createdNewPeer = peer == nil
|
||||||
|
if createdNewPeer {
|
||||||
peer, err = device.NewPeer(publicKey)
|
peer, err = device.NewPeer(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create new peer:", err)
|
logError.Println("Failed to create new peer:", err)
|
||||||
@@ -251,6 +255,20 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "update_only":
|
||||||
|
|
||||||
|
// allow disabling of creation
|
||||||
|
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set update only, invalid value:", value)
|
||||||
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
|
}
|
||||||
|
if createdNewPeer && !dummy {
|
||||||
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
peer = &Peer{}
|
||||||
|
dummy = true
|
||||||
|
}
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
|
|
||||||
// remove currently selected peer from device
|
// remove currently selected peer from device
|
||||||
@@ -290,7 +308,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
err := func() error {
|
err := func() error {
|
||||||
peer.Lock()
|
peer.Lock()
|
||||||
defer peer.Unlock()
|
defer peer.Unlock()
|
||||||
endpoint, err := CreateEndpoint(value)
|
endpoint, err := conn.CreateEndpoint(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -404,10 +422,20 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
|||||||
|
|
||||||
switch op {
|
switch op {
|
||||||
case "set=1\n":
|
case "set=1\n":
|
||||||
status = device.IpcSetOperation(buffered.Reader)
|
err = device.IpcSetOperation(buffered.Reader)
|
||||||
|
if err != nil && !errors.As(err, &status) {
|
||||||
|
// should never happen
|
||||||
|
device.log.Error.Println("Invalid UAPI error:", err)
|
||||||
|
status = &IPCError{1}
|
||||||
|
}
|
||||||
|
|
||||||
case "get=1\n":
|
case "get=1\n":
|
||||||
status = device.IpcGetOperation(buffered.Writer)
|
err = device.IpcGetOperation(buffered.Writer)
|
||||||
|
if err != nil && !errors.As(err, &status) {
|
||||||
|
// should never happen
|
||||||
|
device.log.Error.Println("Invalid UAPI error:", err)
|
||||||
|
status = &IPCError{1}
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
device.log.Error.Println("Invalid UAPI operation:", op)
|
device.log.Error.Println("Invalid UAPI operation:", op)
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
const WireGuardGoVersion = "0.0.20190805"
|
const WireGuardGoVersion = "0.0.20201118"
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
// +build !android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
|
|
||||||
|
|
||||||
// --------------------------------------------------------
|
|
||||||
// Do not use this on Linux. Instead use the kernel module.
|
|
||||||
// See wireguard.com/install for more information.
|
|
||||||
// --------------------------------------------------------
|
|
||||||
8
go.mod
8
go.mod
@@ -1,9 +1,9 @@
|
|||||||
module golang.zx2c4.com/wireguard
|
module golang.zx2c4.com/wireguard
|
||||||
|
|
||||||
go 1.12
|
go 1.13
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56
|
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9
|
||||||
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
|
||||||
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f
|
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7
|
||||||
)
|
)
|
||||||
|
|||||||
18
go.sum
18
go.sum
@@ -1,11 +1,17 @@
|
|||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 h1:ZpKuNIejY8P0ExLOVyKhb0WsgG8UdvHXe6TWjY7eL6k=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
|
||||||
|
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 h1:dfGZHvZk057jK2MCeWus/TowKpJ8y4AmooUzdBSR9GU=
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
|
||||||
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f h1:dHNZYIYdq2QuU6w73vZ/DzesPbVlZVYZTtTZmrnsbQ8=
|
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7 h1:s330+6z/Ko3J0o6rvOcwXe5nzs7UT9tLKHoOXYn6uE0=
|
||||||
|
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
|||||||
@@ -2,32 +2,20 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
var socketDirectory = "/var/run/wireguard"
|
|
||||||
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(unix.EIO)
|
|
||||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
|
||||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
|
||||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
|
||||||
socketName = "%s.sock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
listener net.Listener // unix socket listener
|
listener net.Listener // unix socket listener
|
||||||
connNew chan net.Conn
|
connNew chan net.Conn
|
||||||
@@ -84,10 +72,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
unixListener.SetUnlinkOnClose(true)
|
unixListener.SetUnlinkOnClose(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
socketPath := path.Join(
|
socketPath := sockPath(name)
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
@@ -146,58 +131,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 0755)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUmask := unix.Umask(0077)
|
|
||||||
listener, err := func() (*net.UnixListener, error) {
|
|
||||||
|
|
||||||
// initial connection attempt
|
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if socket already active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", socketPath)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup & attempt again
|
|
||||||
|
|
||||||
err = os.Remove(socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return net.ListenUnix("unix", addr)
|
|
||||||
}()
|
|
||||||
unix.Umask(oldUmask)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,31 +1,18 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/rwcancel"
|
"golang.zx2c4.com/wireguard/rwcancel"
|
||||||
)
|
)
|
||||||
|
|
||||||
var socketDirectory = "/var/run/wireguard"
|
|
||||||
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(unix.EIO)
|
|
||||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
|
||||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
|
||||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
|
||||||
socketName = "%s.sock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UAPIListener struct {
|
type UAPIListener struct {
|
||||||
listener net.Listener // unix socket listener
|
listener net.Listener // unix socket listener
|
||||||
connNew chan net.Conn
|
connNew chan net.Conn
|
||||||
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
// watch for deletion of socket
|
// watch for deletion of socket
|
||||||
|
|
||||||
socketPath := path.Join(
|
socketPath := sockPath(name)
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
uapi.inotifyFd, err = unix.InotifyInit()
|
uapi.inotifyFd, err = unix.InotifyInit()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
|||||||
|
|
||||||
return uapi, nil
|
return uapi, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
|
||||||
|
|
||||||
// check if path exist
|
|
||||||
|
|
||||||
err := os.MkdirAll(socketDirectory, 0755)
|
|
||||||
if err != nil && !os.IsExist(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open UNIX socket
|
|
||||||
|
|
||||||
socketPath := path.Join(
|
|
||||||
socketDirectory,
|
|
||||||
fmt.Sprintf(socketName, name),
|
|
||||||
)
|
|
||||||
|
|
||||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUmask := unix.Umask(0077)
|
|
||||||
listener, err := func() (*net.UnixListener, error) {
|
|
||||||
|
|
||||||
// initial connection attempt
|
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
|
||||||
if err == nil {
|
|
||||||
return listener, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check if socket already active
|
|
||||||
|
|
||||||
_, err = net.Dial("unix", socketPath)
|
|
||||||
if err == nil {
|
|
||||||
return nil, errors.New("unix socket in use")
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup & attempt again
|
|
||||||
|
|
||||||
err = os.Remove(socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return net.ListenUnix("unix", addr)
|
|
||||||
}()
|
|
||||||
unix.Umask(oldUmask)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
||||||
|
|||||||
65
ipc/uapi_unix.go
Normal file
65
ipc/uapi_unix.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// +build linux darwin freebsd openbsd
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ipc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
IpcErrorIO = -int64(unix.EIO)
|
||||||
|
IpcErrorProtocol = -int64(unix.EPROTO)
|
||||||
|
IpcErrorInvalid = -int64(unix.EINVAL)
|
||||||
|
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||||
|
)
|
||||||
|
|
||||||
|
// socketDirectory is variable because it is modified by a linker
|
||||||
|
// flag in wireguard-android.
|
||||||
|
var socketDirectory = "/var/run/wireguard"
|
||||||
|
|
||||||
|
func sockPath(iface string) string {
|
||||||
|
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UAPIOpen(name string) (*os.File, error) {
|
||||||
|
if err := os.MkdirAll(socketDirectory, 0755); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
socketPath := sockPath(name)
|
||||||
|
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldUmask := unix.Umask(0077)
|
||||||
|
defer unix.Umask(oldUmask)
|
||||||
|
|
||||||
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
if err == nil {
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test socket, if not in use cleanup and try again.
|
||||||
|
if _, err := net.Dial("unix", socketPath); err == nil {
|
||||||
|
return nil, errors.New("unix socket in use")
|
||||||
|
}
|
||||||
|
if err := os.Remove(socketPath); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
listener, err = net.ListenUnix("unix", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return listener.File()
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ipc
|
package ipc
|
||||||
@@ -8,6 +8,8 @@ package ipc
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,14 +49,22 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return l.listener.Addr()
|
return l.listener.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
||||||
|
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winpipe.PipeConfig{
|
config := winpipe.PipeConfig{
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}
|
}
|
||||||
listener, err := winpipe.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config)
|
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2005 Microsoft
|
* Copyright (C) 2005 Microsoft
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
package winpipe
|
package winpipe
|
||||||
|
|
||||||
@@ -13,15 +13,16 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
|
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||||
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
|
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||||
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||||
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||||
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||||
|
|
||||||
type atomicBool int32
|
type atomicBool int32
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ func (e *timeoutError) Temporary() bool { return true }
|
|||||||
type timeoutChan chan struct{}
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
var ioInitOnce sync.Once
|
var ioInitOnce sync.Once
|
||||||
var ioCompletionPort syscall.Handle
|
var ioCompletionPort windows.Handle
|
||||||
|
|
||||||
// ioResult contains the result of an asynchronous IO operation
|
// ioResult contains the result of an asynchronous IO operation
|
||||||
type ioResult struct {
|
type ioResult struct {
|
||||||
@@ -65,12 +66,12 @@ type ioResult struct {
|
|||||||
|
|
||||||
// ioOperation represents an outstanding asynchronous Win32 IO
|
// ioOperation represents an outstanding asynchronous Win32 IO
|
||||||
type ioOperation struct {
|
type ioOperation struct {
|
||||||
o syscall.Overlapped
|
o windows.Overlapped
|
||||||
ch chan ioResult
|
ch chan ioResult
|
||||||
}
|
}
|
||||||
|
|
||||||
func initIo() {
|
func initIo() {
|
||||||
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
|
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -81,7 +82,7 @@ func initIo() {
|
|||||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||||
type win32File struct {
|
type win32File struct {
|
||||||
handle syscall.Handle
|
handle windows.Handle
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
wgLock sync.RWMutex
|
wgLock sync.RWMutex
|
||||||
closing atomicBool
|
closing atomicBool
|
||||||
@@ -99,7 +100,7 @@ type deadlineHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// makeWin32File makes a new win32File from an existing file handle
|
// makeWin32File makes a new win32File from an existing file handle
|
||||||
func makeWin32File(h syscall.Handle) (*win32File, error) {
|
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||||
f := &win32File{handle: h}
|
f := &win32File{handle: h}
|
||||||
ioInitOnce.Do(initIo)
|
ioInitOnce.Do(initIo)
|
||||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||||
@@ -115,7 +116,7 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
|
|||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||||
return makeWin32File(h)
|
return makeWin32File(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +130,7 @@ func (f *win32File) closeHandle() {
|
|||||||
cancelIoEx(f.handle, nil)
|
cancelIoEx(f.handle, nil)
|
||||||
f.wg.Wait()
|
f.wg.Wait()
|
||||||
// at this point, no new IO can start
|
// at this point, no new IO can start
|
||||||
syscall.Close(f.handle)
|
windows.Close(f.handle)
|
||||||
f.handle = 0
|
f.handle = 0
|
||||||
} else {
|
} else {
|
||||||
f.wgLock.Unlock()
|
f.wgLock.Unlock()
|
||||||
@@ -158,12 +159,12 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ioCompletionProcessor processes completed async IOs forever
|
// ioCompletionProcessor processes completed async IOs forever
|
||||||
func ioCompletionProcessor(h syscall.Handle) {
|
func ioCompletionProcessor(h windows.Handle) {
|
||||||
for {
|
for {
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
var key uintptr
|
var key uintptr
|
||||||
var op *ioOperation
|
var op *ioOperation
|
||||||
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
|
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
|
||||||
if op == nil {
|
if op == nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -174,7 +175,7 @@ func ioCompletionProcessor(h syscall.Handle) {
|
|||||||
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
||||||
// the operation has actually completed.
|
// the operation has actually completed.
|
||||||
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||||
if err != syscall.ERROR_IO_PENDING {
|
if err != windows.ERROR_IO_PENDING {
|
||||||
return int(bytes), err
|
return int(bytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +194,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
|
|||||||
select {
|
select {
|
||||||
case r = <-c.ch:
|
case r = <-c.ch:
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
if f.closing.isSet() {
|
if f.closing.isSet() {
|
||||||
err = ErrFileClosed
|
err = ErrFileClosed
|
||||||
}
|
}
|
||||||
@@ -206,7 +207,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
|
|||||||
cancelIoEx(f.handle, &c.o)
|
cancelIoEx(f.handle, &c.o)
|
||||||
r = <-c.ch
|
r = <-c.ch
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
err = ErrTimeout
|
err = ErrTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -231,14 +232,14 @@ func (f *win32File) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
|
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||||
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
||||||
runtime.KeepAlive(b)
|
runtime.KeepAlive(b)
|
||||||
|
|
||||||
// Handle EOF conditions.
|
// Handle EOF conditions.
|
||||||
if err == nil && n == 0 && len(b) != 0 {
|
if err == nil && n == 0 && len(b) != 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
} else if err == syscall.ERROR_BROKEN_PIPE {
|
} else if err == windows.ERROR_BROKEN_PIPE {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
} else {
|
} else {
|
||||||
return n, err
|
return n, err
|
||||||
@@ -258,7 +259,7 @@ func (f *win32File) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
|
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||||
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
||||||
runtime.KeepAlive(b)
|
runtime.KeepAlive(b)
|
||||||
return n, err
|
return n, err
|
||||||
@@ -273,7 +274,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *win32File) Flush() error {
|
func (f *win32File) Flush() error {
|
||||||
return syscall.FlushFileBuffers(f.handle)
|
return windows.FlushFileBuffers(f.handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *win32File) Fd() uintptr {
|
func (f *win32File) Fd() uintptr {
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2005 Microsoft
|
* Copyright (C) 2005 Microsoft
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package winpipe
|
package winpipe
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2005 Microsoft
|
* Copyright (C) 2005 Microsoft
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package winpipe
|
package winpipe
|
||||||
@@ -16,18 +16,19 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
|
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
|
||||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
|
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
|
||||||
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
|
//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||||
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||||
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||||
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
||||||
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
||||||
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
||||||
@@ -41,7 +42,7 @@ type objectAttributes struct {
|
|||||||
RootDirectory uintptr
|
RootDirectory uintptr
|
||||||
ObjectName *unicodeString
|
ObjectName *unicodeString
|
||||||
Attributes uintptr
|
Attributes uintptr
|
||||||
SecurityDescriptor *securityDescriptor
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
SecurityQoS uintptr
|
SecurityQoS uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,16 +52,6 @@ type unicodeString struct {
|
|||||||
Buffer uintptr
|
Buffer uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type securityDescriptor struct {
|
|
||||||
Revision byte
|
|
||||||
Sbz1 byte
|
|
||||||
Control uint16
|
|
||||||
Owner uintptr
|
|
||||||
Group uintptr
|
|
||||||
Sacl uintptr
|
|
||||||
Dacl uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type ntstatus int32
|
type ntstatus int32
|
||||||
|
|
||||||
func (status ntstatus) Err() error {
|
func (status ntstatus) Err() error {
|
||||||
@@ -71,11 +62,6 @@ func (status ntstatus) Err() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cERROR_PIPE_BUSY = syscall.Errno(231)
|
|
||||||
cERROR_NO_DATA = syscall.Errno(232)
|
|
||||||
cERROR_PIPE_CONNECTED = syscall.Errno(535)
|
|
||||||
cERROR_SEM_TIMEOUT = syscall.Errno(121)
|
|
||||||
|
|
||||||
cSECURITY_SQOS_PRESENT = 0x100000
|
cSECURITY_SQOS_PRESENT = 0x100000
|
||||||
cSECURITY_ANONYMOUS = 0
|
cSECURITY_ANONYMOUS = 0
|
||||||
|
|
||||||
@@ -88,8 +74,6 @@ const (
|
|||||||
|
|
||||||
cFILE_PIPE_MESSAGE_TYPE = 1
|
cFILE_PIPE_MESSAGE_TYPE = 1
|
||||||
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
||||||
|
|
||||||
cSE_DACL_PRESENT = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -170,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
|||||||
// zero-byte message, ensure that all future Read() calls
|
// zero-byte message, ensure that all future Read() calls
|
||||||
// also return EOF.
|
// also return EOF.
|
||||||
f.readEOF = true
|
f.readEOF = true
|
||||||
} else if err == syscall.ERROR_MORE_DATA {
|
} else if err == windows.ERROR_MORE_DATA {
|
||||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||||
// and the message still has more bytes. Treat this as a success, since
|
// and the message still has more bytes. Treat this as a success, since
|
||||||
// this package presents all named pipes as byte streams.
|
// this package presents all named pipes as byte streams.
|
||||||
@@ -188,17 +172,17 @@ func (s pipeAddress) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||||
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return syscall.Handle(0), ctx.Err()
|
return windows.Handle(0), ctx.Err()
|
||||||
default:
|
default:
|
||||||
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
if err != cERROR_PIPE_BUSY {
|
if err != windows.ERROR_PIPE_BUSY {
|
||||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||||
}
|
}
|
||||||
// Wait 10 msec and try again. This is a rather simplistic
|
// Wait 10 msec and try again. This is a rather simplistic
|
||||||
@@ -211,7 +195,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
|||||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
// takes longer than the specified duration. If timeout is nil, then we use
|
||||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||||
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
var absTimeout time.Time
|
var absTimeout time.Time
|
||||||
if timeout != nil {
|
if timeout != nil {
|
||||||
absTimeout = time.Now().Add(*timeout)
|
absTimeout = time.Now().Add(*timeout)
|
||||||
@@ -219,7 +203,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
|||||||
absTimeout = time.Now().Add(time.Second * 2)
|
absTimeout = time.Now().Add(time.Second * 2)
|
||||||
}
|
}
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
conn, err := DialPipeContext(ctx, path)
|
conn, err := DialPipeContext(ctx, path, expectedOwner)
|
||||||
if err == context.DeadlineExceeded {
|
if err == context.DeadlineExceeded {
|
||||||
return nil, ErrTimeout
|
return nil, ErrTimeout
|
||||||
}
|
}
|
||||||
@@ -228,23 +212,41 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
|||||||
|
|
||||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||||
// cancellation or timeout.
|
// cancellation or timeout.
|
||||||
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
var err error
|
var err error
|
||||||
var h syscall.Handle
|
var h windows.Handle
|
||||||
h, err = tryDialPipe(ctx, &path)
|
h, err = tryDialPipe(ctx, &path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if expectedOwner != nil {
|
||||||
|
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
realOwner, _, err := sd.Owner()
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !realOwner.Equals(expectedOwner) {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, windows.ERROR_ACCESS_DENIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var flags uint32
|
var flags uint32
|
||||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := makeWin32File(h)
|
f, err := makeWin32File(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +266,7 @@ type acceptResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type win32PipeListener struct {
|
type win32PipeListener struct {
|
||||||
firstHandle syscall.Handle
|
firstHandle windows.Handle
|
||||||
path string
|
path string
|
||||||
config PipeConfig
|
config PipeConfig
|
||||||
acceptCh chan (chan acceptResponse)
|
acceptCh chan (chan acceptResponse)
|
||||||
@@ -272,8 +274,8 @@ type win32PipeListener struct {
|
|||||||
doneCh chan int
|
doneCh chan int
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
|
||||||
path16, err := syscall.UTF16FromString(path)
|
path16, err := windows.UTF16FromString(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
@@ -285,31 +287,32 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
|
|||||||
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
defer localFree(ntPath.Buffer)
|
defer windows.LocalFree(windows.Handle(ntPath.Buffer))
|
||||||
oa.ObjectName = &ntPath
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
// The security descriptor is only needed for the first pipe.
|
||||||
if first {
|
if first {
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
len := uint32(len(sd))
|
oa.SecurityDescriptor = sd
|
||||||
sdb := localAlloc(0, len)
|
|
||||||
defer localFree(sdb)
|
|
||||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
|
||||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
|
||||||
} else {
|
} else {
|
||||||
// Construct the default named pipe security descriptor.
|
// Construct the default named pipe security descriptor.
|
||||||
var dacl uintptr
|
var dacl uintptr
|
||||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||||
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
||||||
}
|
}
|
||||||
defer localFree(dacl)
|
defer windows.LocalFree(windows.Handle(dacl))
|
||||||
|
sd, err := windows.NewSecurityDescriptor()
|
||||||
sdb := &securityDescriptor{
|
if err != nil {
|
||||||
Revision: 1,
|
return 0, fmt.Errorf("creating new security descriptor: %s", err)
|
||||||
Control: cSE_DACL_PRESENT,
|
|
||||||
Dacl: dacl,
|
|
||||||
}
|
}
|
||||||
oa.SecurityDescriptor = sdb
|
if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
|
||||||
|
return 0, fmt.Errorf("assigning dacl: %s", err)
|
||||||
|
}
|
||||||
|
sd, err = sd.ToSelfRelative()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("converting to self-relative: %s", err)
|
||||||
|
}
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,22 +322,22 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
|
|||||||
}
|
}
|
||||||
|
|
||||||
disposition := uint32(cFILE_OPEN)
|
disposition := uint32(cFILE_OPEN)
|
||||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
if first {
|
if first {
|
||||||
disposition = cFILE_CREATE
|
disposition = cFILE_CREATE
|
||||||
// By not asking for read or write access, the named pipe file system
|
// By not asking for read or write access, the named pipe file system
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
// client connections until the next call with first == false.
|
// client connections until the next call with first == false.
|
||||||
access = syscall.SYNCHRONIZE
|
access = windows.SYNCHRONIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := int64(-50 * 10000) // 50ms
|
timeout := int64(-50 * 10000) // 50ms
|
||||||
|
|
||||||
var (
|
var (
|
||||||
h syscall.Handle
|
h windows.Handle
|
||||||
iosb ioStatusBlock
|
iosb ioStatusBlock
|
||||||
)
|
)
|
||||||
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
@@ -350,7 +353,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
|||||||
}
|
}
|
||||||
f, err := makeWin32File(h)
|
f, err := makeWin32File(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return f, nil
|
return f, nil
|
||||||
@@ -401,7 +404,7 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
p, err = l.makeConnectedServerPipe()
|
p, err = l.makeConnectedServerPipe()
|
||||||
// If the connection was immediately closed by the client, try
|
// If the connection was immediately closed by the client, try
|
||||||
// again.
|
// again.
|
||||||
if err != cERROR_NO_DATA {
|
if err != windows.ERROR_NO_DATA {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -409,7 +412,7 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
closed = err == ErrPipeListenerClosed
|
closed = err == ErrPipeListenerClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
syscall.Close(l.firstHandle)
|
windows.Close(l.firstHandle)
|
||||||
l.firstHandle = 0
|
l.firstHandle = 0
|
||||||
// Notify Close() and Accept() callers that the handle has been closed.
|
// Notify Close() and Accept() callers that the handle has been closed.
|
||||||
close(l.doneCh)
|
close(l.doneCh)
|
||||||
@@ -417,8 +420,8 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
|
|
||||||
// PipeConfig contain configuration for the pipe listener.
|
// PipeConfig contain configuration for the pipe listener.
|
||||||
type PipeConfig struct {
|
type PipeConfig struct {
|
||||||
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
// SecurityDescriptor contains a Windows security descriptor.
|
||||||
SecurityDescriptor string
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
|
||||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||||
// case the pipe is read in byte mode by default. The only practical difference in
|
// case the pipe is read in byte mode by default. The only practical difference in
|
||||||
@@ -438,20 +441,10 @@ type PipeConfig struct {
|
|||||||
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||||
// The pipe must not already exist.
|
// The pipe must not already exist.
|
||||||
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||||
var (
|
|
||||||
sd []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c == nil {
|
if c == nil {
|
||||||
c = &PipeConfig{}
|
c = &PipeConfig{}
|
||||||
}
|
}
|
||||||
if c.SecurityDescriptor != "" {
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h, err := makeServerPipeHandle(path, sd, c, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -476,7 +469,7 @@ func connectPipe(p *win32File) error {
|
|||||||
|
|
||||||
err = connectNamedPipe(p.handle, &c.o)
|
err = connectNamedPipe(p.handle, &c.o)
|
||||||
_, err = p.asyncIo(c, nil, 0, err)
|
_, err = p.asyncIo(c, nil, 0, err)
|
||||||
if err != nil && err != cERROR_PIPE_CONNECTED {
|
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
// +build windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
|
||||||
//sys localFree(mem uintptr) = LocalFree
|
|
||||||
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
|
||||||
|
|
||||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
|
||||||
var sdBuffer uintptr
|
|
||||||
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer localFree(sdBuffer)
|
|
||||||
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
|
||||||
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
|
||||||
return sd, nil
|
|
||||||
}
|
|
||||||
@@ -39,30 +39,26 @@ func errnoErr(e syscall.Errno) error {
|
|||||||
var (
|
var (
|
||||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
|
||||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||||
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||||
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||||
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||||
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||||
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
||||||
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||||
procLocalFree = modkernel32.NewProc("LocalFree")
|
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||||
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
|
||||||
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
|
||||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -74,7 +70,7 @@ func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
var _p0 *uint16
|
var _p0 *uint16
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -83,10 +79,10 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
|
|||||||
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
||||||
handle = syscall.Handle(r0)
|
handle = windows.Handle(r0)
|
||||||
if handle == syscall.InvalidHandle {
|
if handle == windows.InvalidHandle {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
} else {
|
} else {
|
||||||
@@ -96,7 +92,7 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
var _p0 *uint16
|
var _p0 *uint16
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -105,10 +101,10 @@ func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAtt
|
|||||||
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
||||||
handle = syscall.Handle(r0)
|
handle = windows.Handle(r0)
|
||||||
if handle == syscall.InvalidHandle {
|
if handle == windows.InvalidHandle {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
} else {
|
} else {
|
||||||
@@ -118,7 +114,7 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -130,7 +126,7 @@ func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSiz
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -148,7 +144,7 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
||||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||||
status = ntstatus(r0)
|
status = ntstatus(r0)
|
||||||
return
|
return
|
||||||
@@ -174,39 +170,7 @@ func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
|
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
var _p0 *uint16
|
|
||||||
_p0, err = syscall.UTF16PtrFromString(str)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func localFree(mem uintptr) {
|
|
||||||
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
|
|
||||||
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
|
|
||||||
len = uint32(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -218,9 +182,9 @@ func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
|
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
||||||
newport = syscall.Handle(r0)
|
newport = windows.Handle(r0)
|
||||||
if newport == 0 {
|
if newport == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
@@ -231,7 +195,7 @@ func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintpt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -243,7 +207,7 @@ func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
|
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -255,7 +219,7 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||||
var _p0 uint32
|
var _p0 uint32
|
||||||
if wait {
|
if wait {
|
||||||
_p0 = 1
|
_p0 = 1
|
||||||
|
|||||||
38
main.go
38
main.go
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
@@ -40,31 +40,17 @@ func warning() {
|
|||||||
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
|
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐")
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
|
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │")
|
||||||
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G")
|
fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │")
|
||||||
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
|
fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │")
|
||||||
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
|
fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │")
|
||||||
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
|
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||||
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
|
fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │")
|
||||||
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
|
fmt.Fprintln(os.Stderr, "│ │")
|
||||||
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
|
||||||
if shouldQuit {
|
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W environment variable: G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
|
|
||||||
}
|
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
|
||||||
|
|
||||||
if shouldQuit {
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -75,8 +61,6 @@ func main() {
|
|||||||
|
|
||||||
warning()
|
warning()
|
||||||
|
|
||||||
// parse arguments
|
|
||||||
|
|
||||||
var foreground bool
|
var foreground bool
|
||||||
var interfaceName string
|
var interfaceName string
|
||||||
if len(os.Args) < 2 || len(os.Args) > 3 {
|
if len(os.Args) < 2 || len(os.Args) > 3 {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
@@ -37,7 +37,7 @@ func main() {
|
|||||||
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
||||||
logger.Debug.Println("Debug log enabled")
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tun.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
@@ -20,21 +20,23 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RatelimiterEntry struct {
|
type RatelimiterEntry struct {
|
||||||
sync.Mutex
|
mu sync.Mutex
|
||||||
lastTime time.Time
|
lastTime time.Time
|
||||||
tokens int64
|
tokens int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ratelimiter struct {
|
type Ratelimiter struct {
|
||||||
sync.RWMutex
|
mu sync.RWMutex
|
||||||
stopReset chan struct{}
|
timeNow func() time.Time
|
||||||
|
|
||||||
|
stopReset chan struct{} // send to reset, close to stop
|
||||||
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
||||||
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Close() {
|
func (rate *Ratelimiter) Close() {
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
if rate.stopReset != nil {
|
if rate.stopReset != nil {
|
||||||
close(rate.stopReset)
|
close(rate.stopReset)
|
||||||
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Init() {
|
func (rate *Ratelimiter) Init() {
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
defer rate.Unlock()
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
|
if rate.timeNow == nil {
|
||||||
|
rate.timeNow = time.Now
|
||||||
|
}
|
||||||
|
|
||||||
// stop any ongoing garbage collection routine
|
// stop any ongoing garbage collection routine
|
||||||
|
|
||||||
if rate.stopReset != nil {
|
if rate.stopReset != nil {
|
||||||
close(rate.stopReset)
|
close(rate.stopReset)
|
||||||
}
|
}
|
||||||
@@ -55,50 +60,52 @@ func (rate *Ratelimiter) Init() {
|
|||||||
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
||||||
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
||||||
|
|
||||||
// start garbage collection routine
|
stopReset := rate.stopReset // store in case Init is called again.
|
||||||
|
|
||||||
|
// Start garbage collection routine.
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(time.Second)
|
ticker := time.NewTicker(time.Second)
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-rate.stopReset:
|
case _, ok := <-stopReset:
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
if ok {
|
if !ok {
|
||||||
ticker = time.NewTicker(time.Second)
|
|
||||||
} else {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ticker = time.NewTicker(time.Second)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
func() {
|
if rate.cleanup() {
|
||||||
rate.Lock()
|
ticker.Stop()
|
||||||
defer rate.Unlock()
|
}
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv4 {
|
|
||||||
entry.Lock()
|
|
||||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv4, key)
|
|
||||||
}
|
|
||||||
entry.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, entry := range rate.tableIPv6 {
|
|
||||||
entry.Lock()
|
|
||||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
|
||||||
delete(rate.tableIPv6, key)
|
|
||||||
}
|
|
||||||
entry.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
|
|
||||||
ticker.Stop()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rate *Ratelimiter) cleanup() (empty bool) {
|
||||||
|
rate.mu.Lock()
|
||||||
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
|
for key, entry := range rate.tableIPv4 {
|
||||||
|
entry.mu.Lock()
|
||||||
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
|
delete(rate.tableIPv4, key)
|
||||||
|
}
|
||||||
|
entry.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, entry := range rate.tableIPv6 {
|
||||||
|
entry.mu.Lock()
|
||||||
|
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
|
delete(rate.tableIPv6, key)
|
||||||
|
}
|
||||||
|
entry.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
|
||||||
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||||
var entry *RatelimiterEntry
|
var entry *RatelimiterEntry
|
||||||
var keyIPv4 [net.IPv4len]byte
|
var keyIPv4 [net.IPv4len]byte
|
||||||
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
IPv4 := ip.To4()
|
IPv4 := ip.To4()
|
||||||
IPv6 := ip.To16()
|
IPv6 := ip.To16()
|
||||||
|
|
||||||
rate.RLock()
|
rate.mu.RLock()
|
||||||
|
|
||||||
if IPv4 != nil {
|
if IPv4 != nil {
|
||||||
copy(keyIPv4[:], IPv4)
|
copy(keyIPv4[:], IPv4)
|
||||||
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
entry = rate.tableIPv6[keyIPv6]
|
entry = rate.tableIPv6[keyIPv6]
|
||||||
}
|
}
|
||||||
|
|
||||||
rate.RUnlock()
|
rate.mu.RUnlock()
|
||||||
|
|
||||||
// make new entry if not found
|
// make new entry if not found
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = new(RatelimiterEntry)
|
entry = new(RatelimiterEntry)
|
||||||
entry.tokens = maxTokens - packetCost
|
entry.tokens = maxTokens - packetCost
|
||||||
entry.lastTime = time.Now()
|
entry.lastTime = rate.timeNow()
|
||||||
rate.Lock()
|
rate.mu.Lock()
|
||||||
if IPv4 != nil {
|
if IPv4 != nil {
|
||||||
rate.tableIPv4[keyIPv4] = entry
|
rate.tableIPv4[keyIPv4] = entry
|
||||||
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
||||||
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
rate.stopReset <- struct{}{}
|
rate.stopReset <- struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rate.Unlock()
|
rate.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tokens to entry
|
// add tokens to entry
|
||||||
|
|
||||||
entry.Lock()
|
entry.mu.Lock()
|
||||||
now := time.Now()
|
now := rate.timeNow()
|
||||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||||
entry.lastTime = now
|
entry.lastTime = now
|
||||||
if entry.tokens > maxTokens {
|
if entry.tokens > maxTokens {
|
||||||
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
|||||||
|
|
||||||
if entry.tokens > packetCost {
|
if entry.tokens > packetCost {
|
||||||
entry.tokens -= packetCost
|
entry.tokens -= packetCost
|
||||||
entry.Unlock()
|
entry.mu.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
entry.Unlock()
|
entry.mu.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
@@ -11,22 +11,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RatelimiterResult struct {
|
type result struct {
|
||||||
allowed bool
|
allowed bool
|
||||||
text string
|
text string
|
||||||
wait time.Duration
|
wait time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRatelimiter(t *testing.T) {
|
func TestRatelimiter(t *testing.T) {
|
||||||
|
var rate Ratelimiter
|
||||||
|
var expectedResults []result
|
||||||
|
|
||||||
var ratelimiter Ratelimiter
|
nano := func(nano int64) time.Duration {
|
||||||
var expectedResults []RatelimiterResult
|
|
||||||
|
|
||||||
Nano := func(nano int64) time.Duration {
|
|
||||||
return time.Nanosecond * time.Duration(nano)
|
return time.Nanosecond * time.Duration(nano)
|
||||||
}
|
}
|
||||||
|
|
||||||
Add := func(res RatelimiterResult) {
|
add := func(res result) {
|
||||||
expectedResults = append(
|
expectedResults = append(
|
||||||
expectedResults,
|
expectedResults,
|
||||||
res,
|
res,
|
||||||
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < packetsBurstable; i++ {
|
for i := 0; i < packetsBurstable; i++ {
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "inital burst",
|
text: "initial burst",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "after burst",
|
text: "after burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
|
wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
|
||||||
text: "filling tokens for single packet",
|
text: "filling tokens for single packet",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "not having refilled enough",
|
text: "not having refilled enough",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
||||||
text: "filling tokens for two packet burst",
|
text: "filling tokens for two packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "second packet in 2 packet burst",
|
text: "second packet in 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
Add(RatelimiterResult{
|
add(result{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "packet following 2 packet burst",
|
text: "packet following 2 packet burst",
|
||||||
})
|
})
|
||||||
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||||
}
|
}
|
||||||
|
|
||||||
ratelimiter.Init()
|
now := time.Now()
|
||||||
|
rate.timeNow = func() time.Time {
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
// Lock to avoid data race with cleanup goroutine from Init.
|
||||||
|
rate.mu.Lock()
|
||||||
|
defer rate.mu.Unlock()
|
||||||
|
|
||||||
|
rate.timeNow = time.Now
|
||||||
|
}()
|
||||||
|
timeSleep := func(d time.Duration) {
|
||||||
|
now = now.Add(d + 1)
|
||||||
|
rate.cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
rate.Init()
|
||||||
|
defer rate.Close()
|
||||||
|
|
||||||
for i, res := range expectedResults {
|
for i, res := range expectedResults {
|
||||||
time.Sleep(res.wait)
|
timeSleep(res.wait)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
allowed := ratelimiter.Allow(ip)
|
allowed := rate.Allow(ip)
|
||||||
if allowed != res.allowed {
|
if allowed != res.allowed {
|
||||||
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
|
t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
101
replay/replay.go
101
replay/replay.go
@@ -1,83 +1,62 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||||
package replay
|
package replay
|
||||||
|
|
||||||
/* Implementation of RFC6479
|
type block uint64
|
||||||
* https://tools.ietf.org/html/rfc6479
|
|
||||||
*
|
|
||||||
* The implementation is not safe for concurrent use!
|
|
||||||
*/
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// See: https://golang.org/src/math/big/arith.go
|
blockBitLog = 6 // 1<<6 == 64 bits
|
||||||
_Wordm = ^uintptr(0)
|
blockBits = 1 << blockBitLog // must be power of 2
|
||||||
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
ringBlocks = 1 << 7 // must be power of 2
|
||||||
_WordSize = 1 << _WordLogSize
|
windowSize = (ringBlocks - 1) * blockBits
|
||||||
|
blockMask = ringBlocks - 1
|
||||||
|
bitMask = blockBits - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// A Filter rejects replayed messages by checking if message counter value is
|
||||||
CounterRedundantBitsLog = _WordLogSize + 3
|
// within a sliding window of previously received messages.
|
||||||
CounterRedundantBits = _WordSize * 8
|
// The zero value for Filter is an empty filter ready to use.
|
||||||
CounterBitsTotal = 2048
|
// Filters are unsafe for concurrent use.
|
||||||
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
type Filter struct {
|
||||||
)
|
last uint64
|
||||||
|
ring [ringBlocks]block
|
||||||
const (
|
|
||||||
BacktrackWords = CounterBitsTotal / _WordSize
|
|
||||||
)
|
|
||||||
|
|
||||||
func minUint64(a uint64, b uint64) uint64 {
|
|
||||||
if a > b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ReplayFilter struct {
|
// Reset resets the filter to empty state.
|
||||||
counter uint64
|
func (f *Filter) Reset() {
|
||||||
backtrack [BacktrackWords]uintptr
|
f.last = 0
|
||||||
|
f.ring[0] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (filter *ReplayFilter) Init() {
|
// ValidateCounter checks if the counter should be accepted.
|
||||||
filter.counter = 0
|
// Overlimit counters (>= limit) are always rejected.
|
||||||
filter.backtrack[0] = 0
|
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool {
|
||||||
}
|
|
||||||
|
|
||||||
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
|
|
||||||
if counter >= limit {
|
if counter >= limit {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
indexBlock := counter >> blockBitLog
|
||||||
indexWord := counter >> CounterRedundantBitsLog
|
if counter > f.last { // move window forward
|
||||||
|
current := f.last >> blockBitLog
|
||||||
if counter > filter.counter {
|
diff := indexBlock - current
|
||||||
|
if diff > ringBlocks {
|
||||||
// move window forward
|
diff = ringBlocks // cap diff to clear the whole ring
|
||||||
|
|
||||||
current := filter.counter >> CounterRedundantBitsLog
|
|
||||||
diff := minUint64(indexWord-current, BacktrackWords)
|
|
||||||
for i := uint64(1); i <= diff; i++ {
|
|
||||||
filter.backtrack[(current+i)%BacktrackWords] = 0
|
|
||||||
}
|
}
|
||||||
filter.counter = counter
|
for i := current + 1; i <= current+diff; i++ {
|
||||||
|
f.ring[i&blockMask] = 0
|
||||||
} else if filter.counter-counter > CounterWindowSize {
|
}
|
||||||
|
f.last = counter
|
||||||
// behind current window
|
} else if f.last-counter > windowSize { // behind current window
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
indexWord %= BacktrackWords
|
|
||||||
indexBit := counter & uint64(CounterRedundantBits-1)
|
|
||||||
|
|
||||||
// check and set bit
|
// check and set bit
|
||||||
|
indexBlock &= blockMask
|
||||||
oldValue := filter.backtrack[indexWord]
|
indexBit := counter & bitMask
|
||||||
newValue := oldValue | (1 << indexBit)
|
old := f.ring[indexBlock]
|
||||||
filter.backtrack[indexWord] = newValue
|
new := old | 1<<indexBit
|
||||||
return oldValue != newValue
|
f.ring[indexBlock] = new
|
||||||
|
return old != new
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package replay
|
package replay
|
||||||
@@ -14,22 +14,22 @@ import (
|
|||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
const RejectAfterMessages = 1<<64 - 1<<13 - 1
|
||||||
|
|
||||||
func TestReplay(t *testing.T) {
|
func TestReplay(t *testing.T) {
|
||||||
var filter ReplayFilter
|
var filter Filter
|
||||||
|
|
||||||
T_LIM := CounterWindowSize + 1
|
const T_LIM = windowSize + 1
|
||||||
|
|
||||||
testNumber := 0
|
testNumber := 0
|
||||||
T := func(n uint64, v bool) {
|
T := func(n uint64, expected bool) {
|
||||||
testNumber++
|
testNumber++
|
||||||
if filter.ValidateCounter(n, RejectAfterMessages) != v {
|
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
|
||||||
t.Fatal("Test", testNumber, "failed", n, v)
|
t.Fatal("Test", testNumber, "failed", n, expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
|
|
||||||
T(0, true) /* 1 */
|
T(0, true) /* 1 */
|
||||||
T(1, true) /* 2 */
|
T(1, true) /* 2 */
|
||||||
@@ -67,53 +67,53 @@ func TestReplay(t *testing.T) {
|
|||||||
T(0, false) /* 34 */
|
T(0, false) /* 34 */
|
||||||
|
|
||||||
t.Log("Bulk test 1")
|
t.Log("Bulk test 1")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(1); i <= CounterWindowSize; i++ {
|
for i := uint64(1); i <= windowSize; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 2")
|
t.Log("Bulk test 2")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
for i := uint64(2); i <= windowSize+1; i++ {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(1, true)
|
T(1, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 3")
|
t.Log("Bulk test 3")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 1; i > 0; i-- {
|
for i := uint64(windowSize + 1); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Log("Bulk test 4")
|
t.Log("Bulk test 4")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize + 2; i > 1; i-- {
|
for i := uint64(windowSize + 2); i > 1; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 5")
|
t.Log("Bulk test 5")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
T(0, false)
|
T(0, false)
|
||||||
|
|
||||||
t.Log("Bulk test 6")
|
t.Log("Bulk test 6")
|
||||||
filter.Init()
|
filter.Reset()
|
||||||
testNumber = 0
|
testNumber = 0
|
||||||
for i := CounterWindowSize; i > 0; i-- {
|
for i := uint64(windowSize); i > 0; i-- {
|
||||||
T(i, true)
|
T(i, true)
|
||||||
}
|
}
|
||||||
T(0, true)
|
T(0, true)
|
||||||
T(CounterWindowSize+1, true)
|
T(windowSize+1, true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
// +build !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package rwcancel
|
package rwcancel
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
|
// +build !windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Package rwcancel implements cancelable read/write operations on
|
||||||
|
// a file descriptor.
|
||||||
package rwcancel
|
package rwcancel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -60,7 +64,13 @@ func (rw *RWCancel) ReadyRead() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -75,7 +85,13 @@ func (rw *RWCancel) ReadyWrite() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
8
rwcancel/rwcancel_windows.go
Normal file
8
rwcancel/rwcancel_windows.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
package rwcancel
|
||||||
|
|
||||||
|
type RWCancel struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RWCancel) Cancel() {}
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
// +build !linux
|
// +build !linux,!windows
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package rwcancel
|
package rwcancel
|
||||||
@@ -10,5 +10,6 @@ package rwcancel
|
|||||||
import "golang.org/x/sys/unix"
|
import "golang.org/x/sys/unix"
|
||||||
|
|
||||||
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
|
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
|
||||||
return unix.Select(nfd, r, w, e, timeout)
|
_, err := unix.Select(nfd, r, w, e, timeout)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package rwcancel
|
package rwcancel
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tai64n
|
package tai64n
|
||||||
@@ -17,16 +17,19 @@ const whitenerMask = uint32(0x1000000 - 1)
|
|||||||
|
|
||||||
type Timestamp [TimestampSize]byte
|
type Timestamp [TimestampSize]byte
|
||||||
|
|
||||||
func Now() Timestamp {
|
func stamp(t time.Time) Timestamp {
|
||||||
var tai64n Timestamp
|
var tai64n Timestamp
|
||||||
now := time.Now()
|
secs := base + uint64(t.Unix())
|
||||||
secs := base + uint64(now.Unix())
|
nano := uint32(t.Nanosecond()) &^ whitenerMask
|
||||||
nano := uint32(now.Nanosecond()) &^ whitenerMask
|
|
||||||
binary.BigEndian.PutUint64(tai64n[:], secs)
|
binary.BigEndian.PutUint64(tai64n[:], secs)
|
||||||
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
||||||
return tai64n
|
return tai64n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Now() Timestamp {
|
||||||
|
return stamp(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
func (t1 Timestamp) After(t2 Timestamp) bool {
|
func (t1 Timestamp) After(t2 Timestamp) bool {
|
||||||
return bytes.Compare(t1[:], t2[:]) > 0
|
return bytes.Compare(t1[:], t2[:]) > 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tai64n
|
package tai64n
|
||||||
@@ -10,21 +10,31 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Testing the essential property of the timestamp
|
// Test that timestamps are monotonic as required by Wireguard and that
|
||||||
* as used by WireGuard.
|
// nanosecond-level information is whitened to prevent side channel attacks.
|
||||||
*/
|
|
||||||
func TestMonotonic(t *testing.T) {
|
func TestMonotonic(t *testing.T) {
|
||||||
old := Now()
|
startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
|
||||||
for i := 0; i < 50; i++ {
|
// Whitening should reduce timestamp granularity
|
||||||
next := Now()
|
// to more than 10 but fewer than 20 milliseconds.
|
||||||
if next.After(old) {
|
tests := []struct {
|
||||||
t.Error("Whitening insufficient")
|
name string
|
||||||
}
|
t1, t2 time.Time
|
||||||
time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
|
wantAfter bool
|
||||||
next = Now()
|
}{
|
||||||
if !next.After(old) {
|
{"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
|
||||||
t.Error("Not monotonically increasing on whitened nano-second scale")
|
{"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
|
||||||
}
|
{"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
|
||||||
old = next
|
{"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
|
||||||
|
{"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
|
||||||
|
got := ts2.After(ts1)
|
||||||
|
if got != tt.wantAfter {
|
||||||
|
t.Errorf("after = %v; want %v", got, tt.wantAfter)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -19,19 +20,6 @@ import (
|
|||||||
|
|
||||||
const utunControlName = "com.apple.net.utun_control"
|
const utunControlName = "com.apple.net.utun_control"
|
||||||
|
|
||||||
// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
|
|
||||||
const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
|
|
||||||
|
|
||||||
// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
|
|
||||||
type sockaddrCtl struct {
|
|
||||||
scLen uint8
|
|
||||||
scFamily uint8
|
|
||||||
ssSysaddr uint16
|
|
||||||
scID uint32
|
|
||||||
scUnit uint32
|
|
||||||
scReserved [5]uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
name string
|
name string
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
@@ -40,7 +28,21 @@ type NativeTun struct {
|
|||||||
routeSocket int
|
routeSocket int
|
||||||
}
|
}
|
||||||
|
|
||||||
var sockaddrCtlSize uintptr = 32
|
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
iface, err = net.InterfaceByIndex(index)
|
||||||
|
if err != nil {
|
||||||
|
if opErr, ok := err.(*net.OpError); ok {
|
||||||
|
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
|
||||||
|
time.Sleep(time.Duration(i) * time.Second / 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return iface, err
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||||
var (
|
var (
|
||||||
@@ -74,7 +76,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := net.InterfaceByIndex(ifindex)
|
iface, err := retryInterfaceByIndex(ifindex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.errors <- err
|
tun.errors <- err
|
||||||
return
|
return
|
||||||
@@ -113,43 +115,21 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var ctlInfo = &struct {
|
ctlInfo := &unix.CtlInfo{}
|
||||||
ctlID uint32
|
copy(ctlInfo.Name[:], []byte(utunControlName))
|
||||||
ctlName [96]byte
|
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||||
}{}
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
|
||||||
copy(ctlInfo.ctlName[:], []byte(utunControlName))
|
|
||||||
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(_CTLIOCGINFO),
|
|
||||||
uintptr(unsafe.Pointer(ctlInfo)),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := sockaddrCtl{
|
sc := &unix.SockaddrCtl{
|
||||||
scLen: uint8(sockaddrCtlSize),
|
ID: ctlInfo.Id,
|
||||||
scFamily: unix.AF_SYSTEM,
|
Unit: uint32(ifIndex) + 1,
|
||||||
ssSysaddr: 2,
|
|
||||||
scID: ctlInfo.ctlID,
|
|
||||||
scUnit: uint32(ifIndex) + 1,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scPointer := unsafe.Pointer(&sc)
|
err = unix.Connect(fd, sc)
|
||||||
|
if err != nil {
|
||||||
_, _, errno = unix.RawSyscall(
|
return nil, err
|
||||||
unix.SYS_CONNECT,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(scPointer),
|
|
||||||
uintptr(sockaddrCtlSize),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = syscall.SetNonblock(fd, true)
|
err = syscall.SetNonblock(fd, true)
|
||||||
@@ -213,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
var ifName struct {
|
var err error
|
||||||
name [16]byte
|
|
||||||
}
|
|
||||||
ifNameSize := uintptr(16)
|
|
||||||
|
|
||||||
var errno syscall.Errno
|
|
||||||
tun.operateOnFd(func(fd uintptr) {
|
tun.operateOnFd(func(fd uintptr) {
|
||||||
_, _, errno = unix.Syscall6(
|
tun.name, err = unix.GetsockoptString(
|
||||||
unix.SYS_GETSOCKOPT,
|
int(fd),
|
||||||
fd,
|
|
||||||
2, /* #define SYSPROTO_CONTROL 2 */
|
2, /* #define SYSPROTO_CONTROL 2 */
|
||||||
2, /* #define UTUN_OPT_IFNAME 2 */
|
2, /* #define UTUN_OPT_IFNAME 2 */
|
||||||
uintptr(unsafe.Pointer(&ifName)),
|
)
|
||||||
uintptr(unsafe.Pointer(&ifNameSize)), 0)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if errno != 0 {
|
if err != nil {
|
||||||
return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
|
return "", fmt.Errorf("GetSockoptString: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tun.name = string(ifName.name[:ifNameSize-1])
|
|
||||||
return tun.name, nil
|
return tun.name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,11 +275,6 @@ func (tun *NativeTun) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) setMTU(n int) error {
|
func (tun *NativeTun) setMTU(n int) error {
|
||||||
|
|
||||||
// open datagram socket
|
|
||||||
|
|
||||||
var fd int
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
fd, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
@@ -320,29 +287,18 @@ func (tun *NativeTun) setMTU(n int) error {
|
|||||||
|
|
||||||
defer unix.Close(fd)
|
defer unix.Close(fd)
|
||||||
|
|
||||||
// do ioctl call
|
var ifr unix.IfreqMTU
|
||||||
|
copy(ifr.Name[:], tun.name)
|
||||||
var ifr [32]byte
|
ifr.MTU = int32(n)
|
||||||
copy(ifr[:], tun.name)
|
err = unix.IoctlSetIfreqMTU(fd, &ifr)
|
||||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
if err != nil {
|
||||||
_, _, errno := unix.Syscall(
|
return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unix.SIOCSIFMTU),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
|
|
||||||
if errno != 0 {
|
|
||||||
return fmt.Errorf("failed to set MTU on %s", tun.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) MTU() (int, error) {
|
func (tun *NativeTun) MTU() (int, error) {
|
||||||
|
|
||||||
// open datagram socket
|
|
||||||
|
|
||||||
fd, err := unix.Socket(
|
fd, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
unix.SOCK_DGRAM,
|
unix.SOCK_DGRAM,
|
||||||
@@ -355,19 +311,10 @@ func (tun *NativeTun) MTU() (int, error) {
|
|||||||
|
|
||||||
defer unix.Close(fd)
|
defer unix.Close(fd)
|
||||||
|
|
||||||
// do ioctl call
|
ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
|
||||||
|
if err != nil {
|
||||||
var ifr [64]byte
|
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
|
||||||
copy(ifr[:], tun.name)
|
|
||||||
_, _, errno := unix.Syscall(
|
|
||||||
unix.SYS_IOCTL,
|
|
||||||
uintptr(fd),
|
|
||||||
uintptr(unix.SIOCGIFMTU),
|
|
||||||
uintptr(unsafe.Pointer(&ifr[0])),
|
|
||||||
)
|
|
||||||
if errno != 0 {
|
|
||||||
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
|
return int(ifr.MTU), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
@@ -287,7 +287,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
tunFile.Close()
|
tunFile.Close()
|
||||||
tunDestroy(assignedName)
|
tunDestroy(assignedName)
|
||||||
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
|
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open control sockets
|
// Open control sockets
|
||||||
@@ -328,7 +328,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
tunFile.Close()
|
tunFile.Close()
|
||||||
tunDestroy(assignedName)
|
tunDestroy(assignedName)
|
||||||
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
|
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno)
|
||||||
}
|
}
|
||||||
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
|
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
|
||||||
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
|
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
|
||||||
@@ -341,7 +341,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
tunFile.Close()
|
tunFile.Close()
|
||||||
tunDestroy(assignedName)
|
tunDestroy(assignedName)
|
||||||
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
|
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rename the interface
|
// Rename the interface
|
||||||
@@ -359,7 +359,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
tunFile.Close()
|
tunFile.Close()
|
||||||
tunDestroy(assignedName)
|
tunDestroy(assignedName)
|
||||||
return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
|
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
|
||||||
}
|
}
|
||||||
|
|
||||||
return CreateTUNFromFile(tunFile, mtu)
|
return CreateTUNFromFile(tunFile, mtu)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -32,14 +31,17 @@ const (
|
|||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
index int32 // if index
|
index int32 // if index
|
||||||
name string // name of interface
|
|
||||||
errors chan error // async error handling
|
errors chan error // async error handling
|
||||||
events chan Event // device related events
|
events chan Event // device related events
|
||||||
nopi bool // the device was pased IFF_NO_PI
|
nopi bool // the device was passed IFF_NO_PI
|
||||||
netlinkSock int
|
netlinkSock int
|
||||||
netlinkCancel *rwcancel.RWCancel
|
netlinkCancel *rwcancel.RWCancel
|
||||||
hackListenerClosed sync.Mutex
|
hackListenerClosed sync.Mutex
|
||||||
statusListenersShutdown chan struct{}
|
statusListenersShutdown chan struct{}
|
||||||
|
|
||||||
|
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||||
|
nameCache string // name of interface
|
||||||
|
nameErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) File() *os.File {
|
func (tun *NativeTun) File() *os.File {
|
||||||
@@ -64,14 +66,19 @@ func (tun *NativeTun) routineHackListener() {
|
|||||||
}
|
}
|
||||||
switch err {
|
switch err {
|
||||||
case unix.EINVAL:
|
case unix.EINVAL:
|
||||||
|
// If the tunnel is up, it reports that write() is
|
||||||
|
// allowed but we provided invalid data.
|
||||||
tun.events <- EventUp
|
tun.events <- EventUp
|
||||||
case unix.EIO:
|
case unix.EIO:
|
||||||
|
// If the tunnel is down, it reports that no I/O
|
||||||
|
// is possible, without checking our provided data.
|
||||||
tun.events <- EventDown
|
tun.events <- EventDown
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
|
// nothing
|
||||||
case <-tun.statusListenersShutdown:
|
case <-tun.statusListenersShutdown:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -85,7 +92,7 @@ func createNetlinkSocket() (int, error) {
|
|||||||
}
|
}
|
||||||
saddr := &unix.SockaddrNetlink{
|
saddr := &unix.SockaddrNetlink{
|
||||||
Family: unix.AF_NETLINK,
|
Family: unix.AF_NETLINK,
|
||||||
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
|
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
||||||
}
|
}
|
||||||
err = unix.Bind(sock, saddr)
|
err = unix.Bind(sock, saddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -126,6 +133,7 @@ func (tun *NativeTun) routineNetlinkListener() {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wasEverUp := false
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
@@ -149,10 +157,16 @@ func (tun *NativeTun) routineNetlinkListener() {
|
|||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||||
tun.events <- EventUp
|
tun.events <- EventUp
|
||||||
|
wasEverUp = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||||
tun.events <- EventDown
|
// Don't emit EventDown before we've ever emitted EventUp.
|
||||||
|
// This avoids a startup race with HackListener, which
|
||||||
|
// might detect Up before we have finished reporting Down.
|
||||||
|
if wasEverUp {
|
||||||
|
tun.events <- EventDown
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tun.events <- EventMTUUpdate
|
tun.events <- EventMTUUpdate
|
||||||
@@ -164,11 +178,6 @@ func (tun *NativeTun) routineNetlinkListener() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) isUp() (bool, error) {
|
|
||||||
inter, err := net.InterfaceByName(tun.name)
|
|
||||||
return inter.Flags&net.FlagUp != 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getIFIndex(name string) (int32, error) {
|
func getIFIndex(name string) (int32, error) {
|
||||||
fd, err := unix.Socket(
|
fd, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
@@ -198,6 +207,11 @@ func getIFIndex(name string) (int32, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) setMTU(n int) error {
|
func (tun *NativeTun) setMTU(n int) error {
|
||||||
|
name, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// open datagram socket
|
// open datagram socket
|
||||||
fd, err := unix.Socket(
|
fd, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
@@ -212,9 +226,8 @@ func (tun *NativeTun) setMTU(n int) error {
|
|||||||
defer unix.Close(fd)
|
defer unix.Close(fd)
|
||||||
|
|
||||||
// do ioctl call
|
// do ioctl call
|
||||||
|
|
||||||
var ifr [ifReqSize]byte
|
var ifr [ifReqSize]byte
|
||||||
copy(ifr[:], tun.name)
|
copy(ifr[:], name)
|
||||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_IOCTL,
|
unix.SYS_IOCTL,
|
||||||
@@ -231,6 +244,11 @@ func (tun *NativeTun) setMTU(n int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) MTU() (int, error) {
|
func (tun *NativeTun) MTU() (int, error) {
|
||||||
|
name, err := tun.Name()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
// open datagram socket
|
// open datagram socket
|
||||||
fd, err := unix.Socket(
|
fd, err := unix.Socket(
|
||||||
unix.AF_INET,
|
unix.AF_INET,
|
||||||
@@ -247,7 +265,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
|||||||
// do ioctl call
|
// do ioctl call
|
||||||
|
|
||||||
var ifr [ifReqSize]byte
|
var ifr [ifReqSize]byte
|
||||||
copy(ifr[:], tun.name)
|
copy(ifr[:], name)
|
||||||
_, _, errno := unix.Syscall(
|
_, _, errno := unix.Syscall(
|
||||||
unix.SYS_IOCTL,
|
unix.SYS_IOCTL,
|
||||||
uintptr(fd),
|
uintptr(fd),
|
||||||
@@ -262,6 +280,15 @@ func (tun *NativeTun) MTU() (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
|
tun.nameOnce.Do(tun.initNameCache)
|
||||||
|
return tun.nameCache, tun.nameErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) initNameCache() {
|
||||||
|
tun.nameCache, tun.nameErr = tun.nameSlow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *NativeTun) nameSlow() (string, error) {
|
||||||
sysconn, err := tun.tunFile.SyscallConn()
|
sysconn, err := tun.tunFile.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -282,13 +309,11 @@ func (tun *NativeTun) Name() (string, error) {
|
|||||||
if errno != 0 {
|
if errno != 0 {
|
||||||
return "", errors.New("failed to get name of TUN device: " + errno.Error())
|
return "", errors.New("failed to get name of TUN device: " + errno.Error())
|
||||||
}
|
}
|
||||||
nullStr := ifr[:]
|
name := ifr[:]
|
||||||
i := bytes.IndexByte(nullStr, 0)
|
if i := bytes.IndexByte(name, 0); i != -1 {
|
||||||
if i != -1 {
|
name = name[:i]
|
||||||
nullStr = nullStr[:i]
|
|
||||||
}
|
}
|
||||||
tun.name = string(nullStr)
|
return string(name), nil
|
||||||
return tun.name, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||||
@@ -367,6 +392,9 @@ func (tun *NativeTun) Close() error {
|
|||||||
func CreateTUN(name string, mtu int) (Device, error) {
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
|
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -408,16 +436,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
|||||||
statusListenersShutdown: make(chan struct{}),
|
statusListenersShutdown: make(chan struct{}),
|
||||||
nopi: false,
|
nopi: false,
|
||||||
}
|
}
|
||||||
var err error
|
|
||||||
|
|
||||||
_, err = tun.Name()
|
name, err := tun.Name()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start event listener
|
// start event listener
|
||||||
|
|
||||||
tun.index, err = getIFIndex(tun.name)
|
tun.index, err = getIFIndex(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
@@ -42,34 +42,11 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
|
|
||||||
defer close(tun.events)
|
defer close(tun.events)
|
||||||
|
|
||||||
data := make([]byte, os.Getpagesize())
|
check := func() bool {
|
||||||
for {
|
iface, err := net.InterfaceByIndex(tunIfindex)
|
||||||
retry:
|
|
||||||
n, err := unix.Read(tun.routeSocket, data)
|
|
||||||
if err != nil {
|
|
||||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
|
||||||
goto retry
|
|
||||||
}
|
|
||||||
tun.errors <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if n < 8 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if data[3 /* type */] != unix.RTM_IFINFO {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
|
|
||||||
if ifindex != tunIfindex {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
iface, err := net.InterfaceByIndex(ifindex)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.errors <- err
|
tun.errors <- err
|
||||||
return
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Up / Down event
|
// Up / Down event
|
||||||
@@ -87,6 +64,38 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
tun.events <- EventMTUUpdate
|
tun.events <- EventMTUUpdate
|
||||||
}
|
}
|
||||||
statusMTU = iface.MTU
|
statusMTU = iface.MTU
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if check() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make([]byte, os.Getpagesize())
|
||||||
|
for {
|
||||||
|
n, err := unix.Read(tun.routeSocket, data)
|
||||||
|
if err != nil {
|
||||||
|
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tun.errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < 8 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if data[3 /* type */] != unix.RTM_IFINFO {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
|
||||||
|
if ifindex != tunIfindex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if check() {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +149,6 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
|
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan Event, 10),
|
events: make(chan Event, 10),
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: MIT
|
||||||
*
|
*
|
||||||
* Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2018-2020 WireGuard LLC. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package tun
|
package tun
|
||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
_ "unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
@@ -19,87 +19,79 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings
|
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
|
||||||
packetSizeMax = 0xffff // Maximum packet size
|
spinloopRateThreshold = 800000000 / 8 // 800mbps
|
||||||
packetCapacity = 0x800000 // Ring capacity, 8MiB
|
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
|
||||||
packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
|
|
||||||
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type packetHeader struct {
|
type rateJuggler struct {
|
||||||
size uint32
|
current uint64
|
||||||
}
|
nextByteCount uint64
|
||||||
|
nextStartTime int64
|
||||||
type packet struct {
|
changing int32
|
||||||
packetHeader
|
|
||||||
data [packetSizeMax]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ring struct {
|
|
||||||
head uint32
|
|
||||||
tail uint32
|
|
||||||
alertable int32
|
|
||||||
data [packetCapacity + packetTrailingSize]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ringDescriptor struct {
|
|
||||||
send, receive struct {
|
|
||||||
size uint32
|
|
||||||
ring *ring
|
|
||||||
tailMoved windows.Handle
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
wt *wintun.Wintun
|
wt *wintun.Adapter
|
||||||
handle windows.Handle
|
handle windows.Handle
|
||||||
close bool
|
close bool
|
||||||
rings ringDescriptor
|
|
||||||
events chan Event
|
events chan Event
|
||||||
errors chan error
|
errors chan error
|
||||||
forcedMTU int
|
forcedMTU int
|
||||||
|
rate rateJuggler
|
||||||
|
session wintun.Session
|
||||||
|
readWait windows.Handle
|
||||||
}
|
}
|
||||||
|
|
||||||
func packetAlign(size uint32) uint32 {
|
var WintunPool *wintun.Pool
|
||||||
return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
func init() {
|
||||||
// CreateTUN creates a Wintun adapter with the given name. Should a Wintun
|
|
||||||
// adapter with the same name exist, it is reused.
|
|
||||||
//
|
|
||||||
func CreateTUN(ifname string) (Device, error) {
|
|
||||||
return CreateTUNWithRequestedGUID(ifname, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// CreateTUNWithRequestedGUID creates a Wintun adapter with the given name and
|
|
||||||
// a requested GUID. Should a Wintun adapter with the same name exist, it is reused.
|
|
||||||
//
|
|
||||||
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
|
|
||||||
var err error
|
var err error
|
||||||
var wt *wintun.Wintun
|
WintunPool, err = wintun.MakePool("WireGuard")
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("Failed to make pool: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:linkname procyield runtime.procyield
|
||||||
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
|
//go:linkname nanotime runtime.nanotime
|
||||||
|
func nanotime() int64
|
||||||
|
|
||||||
|
//
|
||||||
|
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
||||||
|
// interface with the same name exist, it is reused.
|
||||||
|
//
|
||||||
|
func CreateTUN(ifname string, mtu int) (Device, error) {
|
||||||
|
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
||||||
|
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
||||||
|
//
|
||||||
|
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
||||||
|
var err error
|
||||||
|
var wt *wintun.Adapter
|
||||||
|
|
||||||
// Does an interface with this name already exist?
|
// Does an interface with this name already exist?
|
||||||
wt, err = wintun.GetInterface(ifname)
|
wt, err = WintunPool.OpenAdapter(ifname)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// If so, we delete it, in case it has weird residual configuration.
|
// If so, we delete it, in case it has weird residual configuration.
|
||||||
_, err = wt.DeleteInterface()
|
_, err = wt.Delete(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err)
|
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
|
||||||
}
|
}
|
||||||
} else if err == windows.ERROR_ALREADY_EXISTS {
|
|
||||||
return nil, fmt.Errorf("Foreign network interface with the same name exists")
|
|
||||||
}
|
}
|
||||||
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", requestedGUID)
|
wt, _, err = WintunPool.CreateAdapter(ifname, requestedGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to create Wintun interface: %v", err)
|
return nil, fmt.Errorf("Error creating interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = wt.SetInterfaceName(ifname)
|
forcedMTU := 1420
|
||||||
if err != nil {
|
if mtu > 0 {
|
||||||
wt.DeleteInterface()
|
forcedMTU = mtu
|
||||||
return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
@@ -107,42 +99,21 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
|
|||||||
handle: windows.InvalidHandle,
|
handle: windows.InvalidHandle,
|
||||||
events: make(chan Event, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 1),
|
errors: make(chan error, 1),
|
||||||
forcedMTU: 1500,
|
forcedMTU: forcedMTU,
|
||||||
}
|
}
|
||||||
|
|
||||||
tun.rings.send.size = uint32(unsafe.Sizeof(ring{}))
|
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
|
||||||
tun.rings.send.ring = &ring{}
|
|
||||||
tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.Close()
|
_, err = tun.wt.Delete(false)
|
||||||
return nil, fmt.Errorf("Error creating event: %v", err)
|
close(tun.events)
|
||||||
}
|
return nil, fmt.Errorf("Error starting session: %w", err)
|
||||||
|
|
||||||
tun.rings.receive.size = uint32(unsafe.Sizeof(ring{}))
|
|
||||||
tun.rings.receive.ring = &ring{}
|
|
||||||
tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
|
||||||
if err != nil {
|
|
||||||
tun.Close()
|
|
||||||
return nil, fmt.Errorf("Error creating event: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tun.handle, err = tun.wt.AdapterHandle()
|
|
||||||
if err != nil {
|
|
||||||
tun.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytesReturned uint32
|
|
||||||
err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil)
|
|
||||||
if err != nil {
|
|
||||||
tun.Close()
|
|
||||||
return nil, fmt.Errorf("Error registering rings: %v", err)
|
|
||||||
}
|
}
|
||||||
|
tun.readWait = tun.session.ReadWaitEvent()
|
||||||
return tun, nil
|
return tun, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
return tun.wt.InterfaceName()
|
return tun.wt.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) File() *os.File {
|
func (tun *NativeTun) File() *os.File {
|
||||||
@@ -155,21 +126,10 @@ func (tun *NativeTun) Events() chan Event {
|
|||||||
|
|
||||||
func (tun *NativeTun) Close() error {
|
func (tun *NativeTun) Close() error {
|
||||||
tun.close = true
|
tun.close = true
|
||||||
if tun.rings.send.tailMoved != 0 {
|
tun.session.End()
|
||||||
windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping
|
|
||||||
}
|
|
||||||
if tun.handle != windows.InvalidHandle {
|
|
||||||
windows.CloseHandle(tun.handle)
|
|
||||||
}
|
|
||||||
if tun.rings.send.tailMoved != 0 {
|
|
||||||
windows.CloseHandle(tun.rings.send.tailMoved)
|
|
||||||
}
|
|
||||||
if tun.rings.send.tailMoved != 0 {
|
|
||||||
windows.CloseHandle(tun.rings.receive.tailMoved)
|
|
||||||
}
|
|
||||||
var err error
|
var err error
|
||||||
if tun.wt != nil {
|
if tun.wt != nil {
|
||||||
_, err = tun.wt.DeleteInterface()
|
_, err = tun.wt.Delete(false)
|
||||||
}
|
}
|
||||||
close(tun.events)
|
close(tun.events)
|
||||||
return err
|
return err
|
||||||
@@ -184,9 +144,6 @@ func (tun *NativeTun) ForceMTU(mtu int) {
|
|||||||
tun.forcedMTU = mtu
|
tun.forcedMTU = mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:linkname procyield runtime.procyield
|
|
||||||
func procyield(cycles uint32)
|
|
||||||
|
|
||||||
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
||||||
|
|
||||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||||
@@ -196,54 +153,34 @@ retry:
|
|||||||
return 0, err
|
return 0, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
if tun.close {
|
start := nanotime()
|
||||||
return 0, os.ErrClosed
|
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
||||||
}
|
|
||||||
|
|
||||||
buffHead := atomic.LoadUint32(&tun.rings.send.ring.head)
|
|
||||||
if buffHead >= packetCapacity {
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
var buffTail uint32
|
|
||||||
for {
|
for {
|
||||||
buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail)
|
|
||||||
if buffHead != buffTail {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if tun.close {
|
if tun.close {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
if time.Since(start) >= time.Millisecond/80 /* ~1gbit/s */ {
|
packet, err := tun.session.ReceivePacket()
|
||||||
windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE)
|
switch err {
|
||||||
goto retry
|
case nil:
|
||||||
|
packetSize := len(packet)
|
||||||
|
copy(buff[offset:], packet)
|
||||||
|
tun.session.ReleaseReceivePacket(packet)
|
||||||
|
tun.rate.update(uint64(packetSize))
|
||||||
|
return packetSize, nil
|
||||||
|
case windows.ERROR_NO_MORE_ITEMS:
|
||||||
|
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||||
|
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
procyield(1)
|
||||||
|
continue
|
||||||
|
case windows.ERROR_HANDLE_EOF:
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
case windows.ERROR_INVALID_DATA:
|
||||||
|
return 0, errors.New("Send ring corrupt")
|
||||||
}
|
}
|
||||||
procyield(1)
|
return 0, fmt.Errorf("Read failed: %w", err)
|
||||||
}
|
}
|
||||||
if buffTail >= packetCapacity {
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
buffContent := tun.rings.send.ring.wrap(buffTail - buffHead)
|
|
||||||
if buffContent < uint32(unsafe.Sizeof(packetHeader{})) {
|
|
||||||
return 0, errors.New("incomplete packet header in send ring")
|
|
||||||
}
|
|
||||||
|
|
||||||
packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead]))
|
|
||||||
if packet.size > packetSizeMax {
|
|
||||||
return 0, errors.New("packet too big in send ring")
|
|
||||||
}
|
|
||||||
|
|
||||||
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size)
|
|
||||||
if alignedPacketSize > buffContent {
|
|
||||||
return 0, errors.New("incomplete packet in send ring")
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(buff[offset:], packet.data[:packet.size])
|
|
||||||
buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize)
|
|
||||||
atomic.StoreUint32(&tun.rings.send.ring.head, buffHead)
|
|
||||||
return int(packet.size), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
@@ -255,40 +192,45 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
packetSize := uint32(len(buff) - offset)
|
packetSize := len(buff) - offset
|
||||||
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize)
|
tun.rate.update(uint64(packetSize))
|
||||||
|
|
||||||
buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head)
|
packet, err := tun.session.AllocateSendPacket(packetSize)
|
||||||
if buffHead >= packetCapacity {
|
if err == nil {
|
||||||
return 0, os.ErrClosed
|
copy(packet, buff[offset:])
|
||||||
|
tun.session.SendPacket(packet)
|
||||||
|
return packetSize, nil
|
||||||
}
|
}
|
||||||
|
switch err {
|
||||||
buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail)
|
case windows.ERROR_HANDLE_EOF:
|
||||||
if buffTail >= packetCapacity {
|
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
}
|
case windows.ERROR_BUFFER_OVERFLOW:
|
||||||
|
|
||||||
buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment)
|
|
||||||
if alignedPacketSize > buffSpace {
|
|
||||||
return 0, nil // Dropping when ring is full.
|
return 0, nil // Dropping when ring is full.
|
||||||
}
|
}
|
||||||
|
return 0, fmt.Errorf("Write failed: %w", err)
|
||||||
packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail]))
|
|
||||||
packet.size = packetSize
|
|
||||||
copy(packet.data[:packetSize], buff[offset:])
|
|
||||||
atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize))
|
|
||||||
if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 {
|
|
||||||
windows.SetEvent(tun.rings.receive.tailMoved)
|
|
||||||
}
|
|
||||||
return int(packetSize), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LUID returns Windows adapter instance ID.
|
// LUID returns Windows interface instance ID.
|
||||||
func (tun *NativeTun) LUID() uint64 {
|
func (tun *NativeTun) LUID() uint64 {
|
||||||
return tun.wt.LUID()
|
return tun.wt.LUID()
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrap returns value modulo ring capacity
|
// RunningVersion returns the running version of the Wintun driver.
|
||||||
func (rb *ring) wrap(value uint32) uint32 {
|
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
|
||||||
return value & (packetCapacity - 1)
|
return wintun.RunningVersion()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rate *rateJuggler) update(packetLen uint64) {
|
||||||
|
now := nanotime()
|
||||||
|
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
||||||
|
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
|
||||||
|
if period >= rateMeasurementGranularity {
|
||||||
|
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(&rate.nextStartTime, now)
|
||||||
|
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
|
||||||
|
atomic.StoreUint64(&rate.nextByteCount, 0)
|
||||||
|
atomic.StoreInt32(&rate.changing, 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
150
tun/tuntest/tuntest.go
Normal file
150
tun/tuntest/tuntest.go
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package tuntest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Ping(dst, src net.IP) []byte {
|
||||||
|
localPort := uint16(1337)
|
||||||
|
seq := uint16(0)
|
||||||
|
|
||||||
|
payload := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||||
|
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||||
|
|
||||||
|
return genICMPv4(payload, dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||||
|
func checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
v := uint32(initial)
|
||||||
|
for i := 0; i < len(buf)-1; i += 2 {
|
||||||
|
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||||
|
}
|
||||||
|
if len(buf)%2 == 1 {
|
||||||
|
v += uint32(buf[len(buf)-1]) << 8
|
||||||
|
}
|
||||||
|
for v > 0xffff {
|
||||||
|
v = (v >> 16) + (v & 0xffff)
|
||||||
|
}
|
||||||
|
return ^uint16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||||
|
const (
|
||||||
|
icmpv4ProtocolNumber = 1
|
||||||
|
icmpv4Echo = 8
|
||||||
|
icmpv4ChecksumOffset = 2
|
||||||
|
icmpv4Size = 8
|
||||||
|
ipv4Size = 20
|
||||||
|
ipv4TotalLenOffset = 2
|
||||||
|
ipv4ChecksumOffset = 10
|
||||||
|
ttl = 65
|
||||||
|
)
|
||||||
|
|
||||||
|
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||||
|
|
||||||
|
ip := hdr[0:ipv4Size]
|
||||||
|
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc792
|
||||||
|
icmpv4[0] = icmpv4Echo // type
|
||||||
|
icmpv4[1] = 0 // code
|
||||||
|
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||||
|
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||||
|
length := uint16(len(hdr) + len(payload))
|
||||||
|
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||||
|
ip[8] = ttl
|
||||||
|
ip[9] = icmpv4ProtocolNumber
|
||||||
|
copy(ip[12:], src.To4())
|
||||||
|
copy(ip[16:], dst.To4())
|
||||||
|
chksum = ^checksum(ip[:], 0)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
var v []byte
|
||||||
|
v = append(v, hdr...)
|
||||||
|
v = append(v, payload...)
|
||||||
|
return []byte(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||||
|
type ChannelTUN struct {
|
||||||
|
Inbound chan []byte // incoming packets, closed on TUN close
|
||||||
|
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
events chan tun.Event
|
||||||
|
tun chTun
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChannelTUN() *ChannelTUN {
|
||||||
|
c := &ChannelTUN{
|
||||||
|
Inbound: make(chan []byte),
|
||||||
|
Outbound: make(chan []byte),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
events: make(chan tun.Event, 1),
|
||||||
|
}
|
||||||
|
c.tun.c = c
|
||||||
|
c.events <- tun.EventUp
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelTUN) TUN() tun.Device {
|
||||||
|
return &c.tun
|
||||||
|
}
|
||||||
|
|
||||||
|
type chTun struct {
|
||||||
|
c *ChannelTUN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) File() *os.File { return nil }
|
||||||
|
|
||||||
|
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case msg := <-t.c.Outbound:
|
||||||
|
return copy(data[offset:], msg), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is called by the wireguard device to deliver a packet for routing.
|
||||||
|
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||||
|
if offset == -1 {
|
||||||
|
close(t.c.closed)
|
||||||
|
close(t.c.events)
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
msg := make([]byte, len(data)-offset)
|
||||||
|
copy(msg, data[offset:])
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case t.c.Inbound <- msg:
|
||||||
|
return len(data) - offset, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const DefaultMTU = 1420
|
||||||
|
|
||||||
|
func (t *chTun) Flush() error { return nil }
|
||||||
|
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||||
|
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||||
|
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||||
|
func (t *chTun) Close() error {
|
||||||
|
t.Write(nil, -1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
50
tun/wintun/dll_fromfile_windows.go
Normal file
50
tun/wintun/dll_fromfile_windows.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
// +build !load_wintun_from_rsrc
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wintun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
type lazyDLL struct {
|
||||||
|
Name string
|
||||||
|
mu sync.Mutex
|
||||||
|
module windows.Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *lazyDLL) Load() error {
|
||||||
|
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
if d.module != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
|
||||||
|
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
|
||||||
|
)
|
||||||
|
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to load library: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lazyProc) nameToAddr() (uintptr, error) {
|
||||||
|
return windows.GetProcAddress(p.dll.module, p.Name)
|
||||||
|
}
|
||||||
58
tun/wintun/dll_fromrsrc_windows.go
Normal file
58
tun/wintun/dll_fromrsrc_windows.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
// +build load_wintun_from_rsrc
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wintun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/wintun/resource"
|
||||||
|
)
|
||||||
|
|
||||||
|
type lazyDLL struct {
|
||||||
|
Name string
|
||||||
|
mu sync.Mutex
|
||||||
|
module *memmod.Module
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *lazyDLL) Load() error {
|
||||||
|
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
if d.module != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const ourModule windows.Handle = 0
|
||||||
|
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
|
||||||
|
}
|
||||||
|
data, err := resource.Load(ourModule, resInfo)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to load resource: %w", err)
|
||||||
|
}
|
||||||
|
module, err := memmod.LoadLibrary(data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Unable to load library: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lazyProc) nameToAddr() (uintptr, error) {
|
||||||
|
return p.dll.module.ProcAddressByName(p.Name)
|
||||||
|
}
|
||||||
59
tun/wintun/dll_windows.go
Normal file
59
tun/wintun/dll_windows.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wintun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newLazyDLL(name string) *lazyDLL {
|
||||||
|
return &lazyDLL{Name: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *lazyDLL) NewProc(name string) *lazyProc {
|
||||||
|
return &lazyProc{dll: d, Name: name}
|
||||||
|
}
|
||||||
|
|
||||||
|
type lazyProc struct {
|
||||||
|
Name string
|
||||||
|
mu sync.Mutex
|
||||||
|
dll *lazyDLL
|
||||||
|
addr uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lazyProc) Find() error {
|
||||||
|
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
if p.addr != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := p.dll.Load()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
|
||||||
|
}
|
||||||
|
addr, err := p.nameToAddr()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *lazyProc) Addr() uintptr {
|
||||||
|
err := p.Find()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return p.addr
|
||||||
|
}
|
||||||
620
tun/wintun/memmod/memmod_windows.go
Normal file
620
tun/wintun/memmod/memmod_windows.go
Normal file
@@ -0,0 +1,620 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
type addressList struct {
|
||||||
|
next *addressList
|
||||||
|
address uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (head *addressList) free() {
|
||||||
|
for node := head; node != nil; node = node.next {
|
||||||
|
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Module struct {
|
||||||
|
headers *IMAGE_NT_HEADERS
|
||||||
|
codeBase uintptr
|
||||||
|
modules []windows.Handle
|
||||||
|
initialized bool
|
||||||
|
isDLL bool
|
||||||
|
isRelocated bool
|
||||||
|
nameExports map[string]uint16
|
||||||
|
entry uintptr
|
||||||
|
blockedMemory *addressList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
|
||||||
|
return &module.headers.OptionalHeader.DataDirectory[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) copySections(address uintptr, size uintptr, old_headers *IMAGE_NT_HEADERS) error {
|
||||||
|
sections := module.headers.Sections()
|
||||||
|
for i := range sections {
|
||||||
|
if sections[i].SizeOfRawData == 0 {
|
||||||
|
// Section doesn't contain data in the dll itself, but may define uninitialized data.
|
||||||
|
sectionSize := old_headers.OptionalHeader.SectionAlignment
|
||||||
|
if sectionSize == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||||
|
uintptr(sectionSize),
|
||||||
|
windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error allocating section: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
|
||||||
|
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
|
||||||
|
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
|
||||||
|
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
|
||||||
|
var dst []byte
|
||||||
|
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
|
||||||
|
for j := range dst {
|
||||||
|
dst[j] = 0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
|
||||||
|
return errors.New("Incomplete section")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit memory block and copy data from dll.
|
||||||
|
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||||
|
uintptr(sections[i].SizeOfRawData),
|
||||||
|
windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error allocating memory block: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
|
||||||
|
memcpy(
|
||||||
|
module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||||
|
address+uintptr(sections[i].PointerToRawData),
|
||||||
|
uintptr(sections[i].SizeOfRawData))
|
||||||
|
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
|
||||||
|
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
|
||||||
|
size := section.SizeOfRawData
|
||||||
|
if size != 0 {
|
||||||
|
return uintptr(size)
|
||||||
|
}
|
||||||
|
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
|
||||||
|
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
|
||||||
|
}
|
||||||
|
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
|
||||||
|
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
type sectionFinalizeData struct {
|
||||||
|
address uintptr
|
||||||
|
alignedAddress uintptr
|
||||||
|
size uintptr
|
||||||
|
characteristics uint32
|
||||||
|
last bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
|
||||||
|
if sectionData.size == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
|
||||||
|
// Section is not needed any more and can safely be freed.
|
||||||
|
if sectionData.address == sectionData.alignedAddress &&
|
||||||
|
(sectionData.last ||
|
||||||
|
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
|
||||||
|
// Only allowed to decommit whole pages.
|
||||||
|
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine protection flags based on characteristics
|
||||||
|
var ProtectionFlags = [8]uint32{
|
||||||
|
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
|
||||||
|
windows.PAGE_EXECUTE, // not writeable, not readable, executable
|
||||||
|
windows.PAGE_READONLY, // not writeable, readable, not executable
|
||||||
|
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
|
||||||
|
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
|
||||||
|
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
|
||||||
|
windows.PAGE_READWRITE, // writeable, readable, not executable
|
||||||
|
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
|
||||||
|
}
|
||||||
|
protect := ProtectionFlags[sectionData.characteristics>>29]
|
||||||
|
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
|
||||||
|
protect |= windows.PAGE_NOCACHE
|
||||||
|
}
|
||||||
|
|
||||||
|
// Change memory access flags.
|
||||||
|
var oldProtect uint32
|
||||||
|
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error protecting memory page: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) finalizeSections() error {
|
||||||
|
sections := module.headers.Sections()
|
||||||
|
imageOffset := module.headers.OptionalHeader.imageOffset()
|
||||||
|
sectionData := sectionFinalizeData{}
|
||||||
|
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
|
||||||
|
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
|
||||||
|
sectionData.size = module.realSectionSize(§ions[0])
|
||||||
|
sectionData.characteristics = sections[0].Characteristics
|
||||||
|
|
||||||
|
// Loop through all sections and change access flags.
|
||||||
|
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
|
||||||
|
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
|
||||||
|
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
|
||||||
|
sectionSize := module.realSectionSize(§ions[i])
|
||||||
|
// Combine access flags of all sections that share a page.
|
||||||
|
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
|
||||||
|
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
|
||||||
|
// Section shares page with previous.
|
||||||
|
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
|
||||||
|
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
|
||||||
|
} else {
|
||||||
|
sectionData.characteristics |= sections[i].Characteristics
|
||||||
|
}
|
||||||
|
sectionData.size = sectionAddress + sectionSize - sectionData.address
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := module.finalizeSection(§ionData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error finalizing section: %w", err)
|
||||||
|
}
|
||||||
|
sectionData.address = sectionAddress
|
||||||
|
sectionData.alignedAddress = alignedAddress
|
||||||
|
sectionData.size = sectionSize
|
||||||
|
sectionData.characteristics = sections[i].Characteristics
|
||||||
|
}
|
||||||
|
sectionData.last = true
|
||||||
|
err := module.finalizeSection(§ionData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error finalizing section: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) executeTLS() {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
|
||||||
|
if directory.VirtualAddress == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
callback := tls.AddressOfCallbacks
|
||||||
|
if callback != 0 {
|
||||||
|
for {
|
||||||
|
f := *(*uintptr)(a2p(callback))
|
||||||
|
if f == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
|
||||||
|
callback += unsafe.Sizeof(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
|
||||||
|
if directory.Size == 0 {
|
||||||
|
return delta == 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
for relocationHdr.VirtualAddress > 0 {
|
||||||
|
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
|
||||||
|
|
||||||
|
var relInfos []uint16
|
||||||
|
unsafeSlice(
|
||||||
|
unsafe.Pointer(&relInfos),
|
||||||
|
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
|
||||||
|
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
|
||||||
|
for _, relInfo := range relInfos {
|
||||||
|
// The upper 4 bits define the type of relocation.
|
||||||
|
relType := relInfo >> 12
|
||||||
|
// The lower 12 bits define the offset.
|
||||||
|
relOffset := uintptr(relInfo & 0xfff)
|
||||||
|
|
||||||
|
switch relType {
|
||||||
|
case IMAGE_REL_BASED_ABSOLUTE:
|
||||||
|
// Skip relocation.
|
||||||
|
|
||||||
|
case IMAGE_REL_BASED_LOW:
|
||||||
|
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
|
||||||
|
break
|
||||||
|
|
||||||
|
case IMAGE_REL_BASED_HIGH:
|
||||||
|
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
|
||||||
|
break
|
||||||
|
|
||||||
|
case IMAGE_REL_BASED_HIGHLOW:
|
||||||
|
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
|
||||||
|
|
||||||
|
case IMAGE_REL_BASED_DIR64:
|
||||||
|
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
|
||||||
|
|
||||||
|
case IMAGE_REL_BASED_THUMB_MOV32:
|
||||||
|
inst := *(*uint32)(a2p(dest + relOffset))
|
||||||
|
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
|
||||||
|
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
|
||||||
|
if (inst & 0x8000fbf0) != 0x0000f240 {
|
||||||
|
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
|
||||||
|
}
|
||||||
|
imm16 += uint32(delta) & 0xffff
|
||||||
|
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
|
||||||
|
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
|
||||||
|
((imm16 >> 12) & 0x000f) +
|
||||||
|
((imm16 << 20) & 0x70000000) +
|
||||||
|
((imm16 << 16) & 0xff0000)
|
||||||
|
if hiDelta != 0 {
|
||||||
|
inst = *(*uint32)(a2p(dest + relOffset + 4))
|
||||||
|
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
|
||||||
|
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
|
||||||
|
if (inst & 0x8000fbf0) != 0x0000f2c0 {
|
||||||
|
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
|
||||||
|
}
|
||||||
|
imm16 += hiDelta
|
||||||
|
if imm16 > 0xffff {
|
||||||
|
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
|
||||||
|
}
|
||||||
|
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
|
||||||
|
((imm16 >> 1) & 0x0400) +
|
||||||
|
((imm16 >> 12) & 0x000f) +
|
||||||
|
((imm16 << 20) & 0x70000000) +
|
||||||
|
((imm16 << 16) & 0xff0000)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("Unsupported relocation: %w", relType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance to next relocation block.
|
||||||
|
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) buildImportTable() error {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
|
||||||
|
if directory.Size == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
module.modules = make([]windows.Handle, 0, 16)
|
||||||
|
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
for !isBadReadPtr(uintptr(unsafe.Pointer(importDesc)), unsafe.Sizeof(*importDesc)) && importDesc.Name != 0 {
|
||||||
|
handle, err := loadLibraryA((*byte)(a2p(module.codeBase + uintptr(importDesc.Name))))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error loading module: %w", err)
|
||||||
|
}
|
||||||
|
var thunkRef, funcRef *uintptr
|
||||||
|
if importDesc.OriginalFirstThunk() != 0 {
|
||||||
|
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
|
||||||
|
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||||
|
} else {
|
||||||
|
// No hint table.
|
||||||
|
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||||
|
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||||
|
}
|
||||||
|
for *thunkRef != 0 {
|
||||||
|
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
|
||||||
|
*funcRef, err = getProcAddress(handle, (*byte)(a2p(IMAGE_ORDINAL(*thunkRef))))
|
||||||
|
} else {
|
||||||
|
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
|
||||||
|
*funcRef, err = getProcAddress(handle, &thunkData.Name[0])
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
windows.FreeLibrary(handle)
|
||||||
|
return fmt.Errorf("Error getting function address: %w", err)
|
||||||
|
}
|
||||||
|
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
|
||||||
|
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
|
||||||
|
}
|
||||||
|
module.modules = append(module.modules, handle)
|
||||||
|
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) buildNameExports() error {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||||
|
if directory.Size == 0 {
|
||||||
|
return errors.New("No export table found")
|
||||||
|
}
|
||||||
|
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
|
||||||
|
return errors.New("No functions exported")
|
||||||
|
}
|
||||||
|
if exports.NumberOfNames == 0 {
|
||||||
|
return errors.New("No functions exported by name")
|
||||||
|
}
|
||||||
|
var nameRefs []uint32
|
||||||
|
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
|
||||||
|
var ordinals []uint16
|
||||||
|
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
|
||||||
|
module.nameExports = make(map[string]uint16)
|
||||||
|
for i := range nameRefs {
|
||||||
|
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
|
||||||
|
module.nameExports[nameArray] = ordinals[i]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadLibrary loads module image to memory.
|
||||||
|
func LoadLibrary(data []byte) (module *Module, err error) {
|
||||||
|
addr := uintptr(unsafe.Pointer(&data[0]))
|
||||||
|
size := uintptr(len(data))
|
||||||
|
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
|
||||||
|
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
|
||||||
|
}
|
||||||
|
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
|
||||||
|
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
|
||||||
|
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
|
||||||
|
}
|
||||||
|
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
|
||||||
|
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
|
||||||
|
}
|
||||||
|
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
|
||||||
|
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
|
||||||
|
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
|
||||||
|
}
|
||||||
|
if oldHeader.FileHeader.Machine != imageFileProcess {
|
||||||
|
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
|
||||||
|
}
|
||||||
|
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
|
||||||
|
return nil, errors.New("Unaligned section")
|
||||||
|
}
|
||||||
|
lastSectionEnd := uintptr(0)
|
||||||
|
sections := oldHeader.Sections()
|
||||||
|
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
|
||||||
|
for i := range sections {
|
||||||
|
var endOfSection uintptr
|
||||||
|
if sections[i].SizeOfRawData == 0 {
|
||||||
|
// Section without data in the DLL
|
||||||
|
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
|
||||||
|
} else {
|
||||||
|
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
|
||||||
|
}
|
||||||
|
if endOfSection > lastSectionEnd {
|
||||||
|
lastSectionEnd = endOfSection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
|
||||||
|
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
|
||||||
|
return nil, errors.New("Section is not page-aligned")
|
||||||
|
}
|
||||||
|
|
||||||
|
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
module.Free()
|
||||||
|
module = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Reserve memory for image of library.
|
||||||
|
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
|
||||||
|
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
|
||||||
|
alignedImageSize,
|
||||||
|
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
// Try to allocate memory at arbitrary position.
|
||||||
|
module.codeBase, err = windows.VirtualAlloc(0,
|
||||||
|
alignedImageSize,
|
||||||
|
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error allocating code: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = module.check4GBBoundaries(alignedImageSize)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error reallocating code: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
|
||||||
|
err = errors.New("Incomplete headers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Commit memory for headers.
|
||||||
|
headers, err := windows.VirtualAlloc(module.codeBase,
|
||||||
|
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
|
||||||
|
windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error allocating headers: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Copy PE header to code.
|
||||||
|
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
|
||||||
|
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
|
||||||
|
|
||||||
|
// Update position.
|
||||||
|
module.headers.OptionalHeader.ImageBase = module.codeBase
|
||||||
|
|
||||||
|
// Copy sections from DLL file block to new memory location.
|
||||||
|
err = module.copySections(addr, size, oldHeader)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error copying sections: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adjust base address of imported data.
|
||||||
|
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
|
||||||
|
if locationDelta != 0 {
|
||||||
|
module.isRelocated, err = module.performBaseRelocation(locationDelta)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error relocating module: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
module.isRelocated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load required dlls and adjust function table of imports.
|
||||||
|
err = module.buildImportTable()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error building import table: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
|
||||||
|
err = module.finalizeSections()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error finalizing sections: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLS callbacks are executed BEFORE the main loading.
|
||||||
|
module.executeTLS()
|
||||||
|
|
||||||
|
// Get entry point of loaded module.
|
||||||
|
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
|
||||||
|
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
|
||||||
|
if module.isDLL {
|
||||||
|
// Notify library about attaching to process.
|
||||||
|
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
|
||||||
|
successful := r0 != 0
|
||||||
|
if !successful {
|
||||||
|
err = windows.ERROR_DLL_INIT_FAILED
|
||||||
|
return
|
||||||
|
}
|
||||||
|
module.initialized = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.buildNameExports()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free releases module resources and unloads it.
|
||||||
|
func (module *Module) Free() {
|
||||||
|
if module.initialized {
|
||||||
|
// Notify library about detaching from process.
|
||||||
|
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
|
||||||
|
module.initialized = false
|
||||||
|
}
|
||||||
|
if module.modules != nil {
|
||||||
|
// Free previously opened libraries.
|
||||||
|
for _, handle := range module.modules {
|
||||||
|
windows.FreeLibrary(handle)
|
||||||
|
}
|
||||||
|
module.modules = nil
|
||||||
|
}
|
||||||
|
if module.codeBase != 0 {
|
||||||
|
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
|
||||||
|
module.codeBase = 0
|
||||||
|
}
|
||||||
|
if module.blockedMemory != nil {
|
||||||
|
module.blockedMemory.free()
|
||||||
|
module.blockedMemory = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcAddressByName returns function address by exported name.
|
||||||
|
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||||
|
if directory.Size == 0 {
|
||||||
|
return 0, errors.New("No export table found")
|
||||||
|
}
|
||||||
|
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
if module.nameExports == nil {
|
||||||
|
return 0, errors.New("No functions exported by name")
|
||||||
|
}
|
||||||
|
if idx, ok := module.nameExports[name]; ok {
|
||||||
|
if uint32(idx) > exports.NumberOfFunctions {
|
||||||
|
return 0, errors.New("Ordinal number too high")
|
||||||
|
}
|
||||||
|
// AddressOfFunctions contains the RVAs to the "real" functions.
|
||||||
|
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("Function not found by name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcAddressByOrdinal returns function address by exported ordinal.
|
||||||
|
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
|
||||||
|
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||||
|
if directory.Size == 0 {
|
||||||
|
return 0, errors.New("No export table found")
|
||||||
|
}
|
||||||
|
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||||
|
if uint32(ordinal) < exports.Base {
|
||||||
|
return 0, errors.New("Ordinal number too low")
|
||||||
|
}
|
||||||
|
idx := ordinal - uint16(exports.Base)
|
||||||
|
if uint32(idx) > exports.NumberOfFunctions {
|
||||||
|
return 0, errors.New("Ordinal number too high")
|
||||||
|
}
|
||||||
|
// AddressOfFunctions contains the RVAs to the "real" functions.
|
||||||
|
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func alignDown(value, alignment uintptr) uintptr {
|
||||||
|
return value & ^(alignment - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func alignUp(value, alignment uintptr) uintptr {
|
||||||
|
return (value + alignment - 1) & ^(alignment - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func a2p(addr uintptr) unsafe.Pointer {
|
||||||
|
return unsafe.Pointer(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func memcpy(dst, src, size uintptr) {
|
||||||
|
var d, s []byte
|
||||||
|
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
|
||||||
|
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
|
||||||
|
copy(d, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// unsafeSlice updates the slice slicePtr to be a slice
|
||||||
|
// referencing the provided data with its length & capacity set to
|
||||||
|
// lenCap.
|
||||||
|
//
|
||||||
|
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
|
||||||
|
// update callers to use unsafe.Slice instead of this.
|
||||||
|
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
|
||||||
|
type sliceHeader struct {
|
||||||
|
Data unsafe.Pointer
|
||||||
|
Len int
|
||||||
|
Cap int
|
||||||
|
}
|
||||||
|
h := (*sliceHeader)(slicePtr)
|
||||||
|
h.Data = data
|
||||||
|
h.Len = lenCap
|
||||||
|
h.Cap = lenCap
|
||||||
|
}
|
||||||
16
tun/wintun/memmod/memmod_windows_32.go
Normal file
16
tun/wintun/memmod/memmod_windows_32.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// +build 386 arm
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
|
||||||
|
return
|
||||||
|
}
|
||||||
8
tun/wintun/memmod/memmod_windows_386.go
Normal file
8
tun/wintun/memmod/memmod_windows_386.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
const imageFileProcess = IMAGE_FILE_MACHINE_I386
|
||||||
36
tun/wintun/memmod/memmod_windows_64.go
Normal file
36
tun/wintun/memmod/memmod_windows_64.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// +build amd64 arm64
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
|
||||||
|
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
|
||||||
|
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
|
||||||
|
node := &addressList{
|
||||||
|
next: module.blockedMemory,
|
||||||
|
address: module.codeBase,
|
||||||
|
}
|
||||||
|
module.blockedMemory = node
|
||||||
|
module.codeBase, err = windows.VirtualAlloc(0,
|
||||||
|
alignedImageSize,
|
||||||
|
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||||
|
windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error allocating memory block: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
8
tun/wintun/memmod/memmod_windows_amd64.go
Normal file
8
tun/wintun/memmod/memmod_windows_amd64.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64
|
||||||
8
tun/wintun/memmod/memmod_windows_arm.go
Normal file
8
tun/wintun/memmod/memmod_windows_arm.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT
|
||||||
8
tun/wintun/memmod/memmod_windows_arm64.go
Normal file
8
tun/wintun/memmod/memmod_windows_arm64.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64
|
||||||
8
tun/wintun/memmod/mksyscall.go
Normal file
8
tun/wintun/memmod/mksyscall.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
||||||
343
tun/wintun/memmod/syscall_windows.go
Normal file
343
tun/wintun/memmod/syscall_windows.go
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
const (
|
||||||
|
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
|
||||||
|
IMAGE_OS2_SIGNATURE = 0x454E // NE
|
||||||
|
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
|
||||||
|
IMAGE_VXD_SIGNATURE = 0x454C // LE
|
||||||
|
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
|
||||||
|
)
|
||||||
|
|
||||||
|
// DOS .EXE header
|
||||||
|
type IMAGE_DOS_HEADER struct {
|
||||||
|
E_magic uint16 // Magic number
|
||||||
|
E_cblp uint16 // Bytes on last page of file
|
||||||
|
E_cp uint16 // Pages in file
|
||||||
|
E_crlc uint16 // Relocations
|
||||||
|
E_cparhdr uint16 // Size of header in paragraphs
|
||||||
|
E_minalloc uint16 // Minimum extra paragraphs needed
|
||||||
|
E_maxalloc uint16 // Maximum extra paragraphs needed
|
||||||
|
E_ss uint16 // Initial (relative) SS value
|
||||||
|
E_sp uint16 // Initial SP value
|
||||||
|
E_csum uint16 // Checksum
|
||||||
|
E_ip uint16 // Initial IP value
|
||||||
|
E_cs uint16 // Initial (relative) CS value
|
||||||
|
E_lfarlc uint16 // File address of relocation table
|
||||||
|
E_ovno uint16 // Overlay number
|
||||||
|
E_res [4]uint16 // Reserved words
|
||||||
|
E_oemid uint16 // OEM identifier (for e_oeminfo)
|
||||||
|
E_oeminfo uint16 // OEM information; e_oemid specific
|
||||||
|
E_res2 [10]uint16 // Reserved words
|
||||||
|
E_lfanew int32 // File address of new exe header
|
||||||
|
}
|
||||||
|
|
||||||
|
// File header format
|
||||||
|
type IMAGE_FILE_HEADER struct {
|
||||||
|
Machine uint16
|
||||||
|
NumberOfSections uint16
|
||||||
|
TimeDateStamp uint32
|
||||||
|
PointerToSymbolTable uint32
|
||||||
|
NumberOfSymbols uint32
|
||||||
|
SizeOfOptionalHeader uint16
|
||||||
|
Characteristics uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
IMAGE_SIZEOF_FILE_HEADER = 20
|
||||||
|
|
||||||
|
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
|
||||||
|
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
|
||||||
|
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
|
||||||
|
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
|
||||||
|
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
|
||||||
|
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
|
||||||
|
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
|
||||||
|
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
|
||||||
|
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
|
||||||
|
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
|
||||||
|
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
|
||||||
|
IMAGE_FILE_SYSTEM = 0x1000 // System File.
|
||||||
|
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
|
||||||
|
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
|
||||||
|
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
|
||||||
|
|
||||||
|
IMAGE_FILE_MACHINE_UNKNOWN = 0
|
||||||
|
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
|
||||||
|
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
|
||||||
|
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
|
||||||
|
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
|
||||||
|
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
|
||||||
|
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
|
||||||
|
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
|
||||||
|
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
|
||||||
|
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
|
||||||
|
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
|
||||||
|
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
|
||||||
|
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
|
||||||
|
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
|
||||||
|
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
|
||||||
|
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
|
||||||
|
IMAGE_FILE_MACHINE_AM33 = 0x01d3
|
||||||
|
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
|
||||||
|
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
|
||||||
|
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
|
||||||
|
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
|
||||||
|
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
|
||||||
|
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
|
||||||
|
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
|
||||||
|
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
|
||||||
|
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
|
||||||
|
IMAGE_FILE_MACHINE_CEF = 0x0CEF
|
||||||
|
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
|
||||||
|
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
|
||||||
|
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
|
||||||
|
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
|
||||||
|
IMAGE_FILE_MACHINE_CEE = 0xC0EE
|
||||||
|
)
|
||||||
|
|
||||||
|
// Directory format
|
||||||
|
type IMAGE_DATA_DIRECTORY struct {
|
||||||
|
VirtualAddress uint32
|
||||||
|
Size uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
|
||||||
|
|
||||||
|
type IMAGE_NT_HEADERS struct {
|
||||||
|
Signature uint32
|
||||||
|
FileHeader IMAGE_FILE_HEADER
|
||||||
|
OptionalHeader IMAGE_OPTIONAL_HEADER
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
|
||||||
|
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
|
||||||
|
(uintptr)(unsafe.Pointer(ntheader)) +
|
||||||
|
unsafe.Offsetof(ntheader.OptionalHeader) +
|
||||||
|
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
|
||||||
|
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
|
||||||
|
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
|
||||||
|
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
|
||||||
|
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
|
||||||
|
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
|
||||||
|
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
|
||||||
|
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
|
||||||
|
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
|
||||||
|
)
|
||||||
|
|
||||||
|
const IMAGE_SIZEOF_SHORT_NAME = 8
|
||||||
|
|
||||||
|
// Section header format
|
||||||
|
type IMAGE_SECTION_HEADER struct {
|
||||||
|
Name [IMAGE_SIZEOF_SHORT_NAME]byte
|
||||||
|
physicalAddressOrVirtualSize uint32
|
||||||
|
VirtualAddress uint32
|
||||||
|
SizeOfRawData uint32
|
||||||
|
PointerToRawData uint32
|
||||||
|
PointerToRelocations uint32
|
||||||
|
PointerToLinenumbers uint32
|
||||||
|
NumberOfRelocations uint16
|
||||||
|
NumberOfLinenumbers uint16
|
||||||
|
Characteristics uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
|
||||||
|
return ishdr.physicalAddressOrVirtualSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
|
||||||
|
ishdr.physicalAddressOrVirtualSize = addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
|
||||||
|
return ishdr.physicalAddressOrVirtualSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
|
||||||
|
ishdr.physicalAddressOrVirtualSize = addr
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Section characteristics.
|
||||||
|
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
|
||||||
|
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
|
||||||
|
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
|
||||||
|
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
|
||||||
|
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
|
||||||
|
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
|
||||||
|
|
||||||
|
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
|
||||||
|
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
|
||||||
|
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
|
||||||
|
|
||||||
|
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
|
||||||
|
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
|
||||||
|
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
|
||||||
|
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
|
||||||
|
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
|
||||||
|
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
|
||||||
|
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
|
||||||
|
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
|
||||||
|
IMAGE_SCN_MEM_FARDATA = 0x00008000
|
||||||
|
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
|
||||||
|
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
|
||||||
|
IMAGE_SCN_MEM_16BIT = 0x00020000
|
||||||
|
IMAGE_SCN_MEM_LOCKED = 0x00040000
|
||||||
|
IMAGE_SCN_MEM_PRELOAD = 0x00080000
|
||||||
|
|
||||||
|
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
|
||||||
|
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
|
||||||
|
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
|
||||||
|
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
|
||||||
|
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
|
||||||
|
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
|
||||||
|
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
|
||||||
|
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
|
||||||
|
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
|
||||||
|
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
|
||||||
|
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
|
||||||
|
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
|
||||||
|
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
|
||||||
|
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
|
||||||
|
IMAGE_SCN_ALIGN_MASK = 0x00F00000
|
||||||
|
|
||||||
|
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
|
||||||
|
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
|
||||||
|
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
|
||||||
|
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
|
||||||
|
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
|
||||||
|
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
|
||||||
|
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
|
||||||
|
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
|
||||||
|
|
||||||
|
// TLS Characteristic Flags
|
||||||
|
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
|
||||||
|
)
|
||||||
|
|
||||||
|
// Based relocation format
|
||||||
|
type IMAGE_BASE_RELOCATION struct {
|
||||||
|
VirtualAddress uint32
|
||||||
|
SizeOfBlock uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
IMAGE_REL_BASED_ABSOLUTE = 0
|
||||||
|
IMAGE_REL_BASED_HIGH = 1
|
||||||
|
IMAGE_REL_BASED_LOW = 2
|
||||||
|
IMAGE_REL_BASED_HIGHLOW = 3
|
||||||
|
IMAGE_REL_BASED_HIGHADJ = 4
|
||||||
|
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
|
||||||
|
IMAGE_REL_BASED_RESERVED = 6
|
||||||
|
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
|
||||||
|
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
|
||||||
|
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
|
||||||
|
IMAGE_REL_BASED_DIR64 = 10
|
||||||
|
|
||||||
|
IMAGE_REL_BASED_IA64_IMM64 = 9
|
||||||
|
|
||||||
|
IMAGE_REL_BASED_MIPS_JMPADDR = 5
|
||||||
|
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
|
||||||
|
|
||||||
|
IMAGE_REL_BASED_ARM_MOV32 = 5
|
||||||
|
IMAGE_REL_BASED_THUMB_MOV32 = 7
|
||||||
|
)
|
||||||
|
|
||||||
|
// Export Format
|
||||||
|
type IMAGE_EXPORT_DIRECTORY struct {
|
||||||
|
Characteristics uint32
|
||||||
|
TimeDateStamp uint32
|
||||||
|
MajorVersion uint16
|
||||||
|
MinorVersion uint16
|
||||||
|
Name uint32
|
||||||
|
Base uint32
|
||||||
|
NumberOfFunctions uint32
|
||||||
|
NumberOfNames uint32
|
||||||
|
AddressOfFunctions uint32 // RVA from base of image
|
||||||
|
AddressOfNames uint32 // RVA from base of image
|
||||||
|
AddressOfNameOrdinals uint32 // RVA from base of image
|
||||||
|
}
|
||||||
|
|
||||||
|
type IMAGE_IMPORT_BY_NAME struct {
|
||||||
|
Hint uint16
|
||||||
|
Name [1]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
|
||||||
|
return ordinal & 0xffff
|
||||||
|
}
|
||||||
|
|
||||||
|
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
|
||||||
|
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thread Local Storage
|
||||||
|
type IMAGE_TLS_DIRECTORY struct {
|
||||||
|
StartAddressOfRawData uintptr
|
||||||
|
EndAddressOfRawData uintptr
|
||||||
|
AddressOfIndex uintptr // PDWORD
|
||||||
|
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
|
||||||
|
SizeOfZeroFill uint32
|
||||||
|
Characteristics uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type IMAGE_IMPORT_DESCRIPTOR struct {
|
||||||
|
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
|
||||||
|
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
|
||||||
|
TimeDateStamp uint32 // 0 if not bound,
|
||||||
|
// -1 if bound, and real date\time stamp
|
||||||
|
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
|
||||||
|
// O.W. date/time stamp of DLL bound to (Old BIND)
|
||||||
|
ForwarderChain uint32 // -1 if no forwarders
|
||||||
|
Name uint32
|
||||||
|
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
|
||||||
|
return imgimpdesc.characteristicsOrOriginalFirstThunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
|
||||||
|
return imgimpdesc.characteristicsOrOriginalFirstThunk
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
DLL_PROCESS_ATTACH = 1
|
||||||
|
DLL_THREAD_ATTACH = 2
|
||||||
|
DLL_THREAD_DETACH = 3
|
||||||
|
DLL_PROCESS_DETACH = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys loadLibraryA(libFileName *byte) (module windows.Handle, err error) = kernel32.LoadLibraryA
|
||||||
|
//sys getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) = kernel32.GetProcAddress
|
||||||
|
//sys isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) = kernel32.IsBadReadPtr
|
||||||
|
|
||||||
|
type SYSTEM_INFO struct {
|
||||||
|
ProcessorArchitecture uint16
|
||||||
|
Reserved uint16
|
||||||
|
PageSize uint32
|
||||||
|
MinimumApplicationAddress uintptr
|
||||||
|
MaximumApplicationAddress uintptr
|
||||||
|
ActiveProcessorMask uintptr
|
||||||
|
NumberOfProcessors uint32
|
||||||
|
ProcessorType uint32
|
||||||
|
AllocationGranularity uint32
|
||||||
|
ProcessorLevel uint16
|
||||||
|
ProcessorRevision uint16
|
||||||
|
}
|
||||||
45
tun/wintun/memmod/syscall_windows_32.go
Normal file
45
tun/wintun/memmod/syscall_windows_32.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
// +build 386 arm
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
// Optional header format
|
||||||
|
type IMAGE_OPTIONAL_HEADER struct {
|
||||||
|
Magic uint16
|
||||||
|
MajorLinkerVersion uint8
|
||||||
|
MinorLinkerVersion uint8
|
||||||
|
SizeOfCode uint32
|
||||||
|
SizeOfInitializedData uint32
|
||||||
|
SizeOfUninitializedData uint32
|
||||||
|
AddressOfEntryPoint uint32
|
||||||
|
BaseOfCode uint32
|
||||||
|
BaseOfData uint32
|
||||||
|
ImageBase uintptr
|
||||||
|
SectionAlignment uint32
|
||||||
|
FileAlignment uint32
|
||||||
|
MajorOperatingSystemVersion uint16
|
||||||
|
MinorOperatingSystemVersion uint16
|
||||||
|
MajorImageVersion uint16
|
||||||
|
MinorImageVersion uint16
|
||||||
|
MajorSubsystemVersion uint16
|
||||||
|
MinorSubsystemVersion uint16
|
||||||
|
Win32VersionValue uint32
|
||||||
|
SizeOfImage uint32
|
||||||
|
SizeOfHeaders uint32
|
||||||
|
CheckSum uint32
|
||||||
|
Subsystem uint16
|
||||||
|
DllCharacteristics uint16
|
||||||
|
SizeOfStackReserve uintptr
|
||||||
|
SizeOfStackCommit uintptr
|
||||||
|
SizeOfHeapReserve uintptr
|
||||||
|
SizeOfHeapCommit uintptr
|
||||||
|
LoaderFlags uint32
|
||||||
|
NumberOfRvaAndSizes uint32
|
||||||
|
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
|
||||||
|
}
|
||||||
|
|
||||||
|
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000
|
||||||
44
tun/wintun/memmod/syscall_windows_64.go
Normal file
44
tun/wintun/memmod/syscall_windows_64.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// +build amd64 arm64
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
// Optional header format
|
||||||
|
type IMAGE_OPTIONAL_HEADER struct {
|
||||||
|
Magic uint16
|
||||||
|
MajorLinkerVersion uint8
|
||||||
|
MinorLinkerVersion uint8
|
||||||
|
SizeOfCode uint32
|
||||||
|
SizeOfInitializedData uint32
|
||||||
|
SizeOfUninitializedData uint32
|
||||||
|
AddressOfEntryPoint uint32
|
||||||
|
BaseOfCode uint32
|
||||||
|
ImageBase uintptr
|
||||||
|
SectionAlignment uint32
|
||||||
|
FileAlignment uint32
|
||||||
|
MajorOperatingSystemVersion uint16
|
||||||
|
MinorOperatingSystemVersion uint16
|
||||||
|
MajorImageVersion uint16
|
||||||
|
MinorImageVersion uint16
|
||||||
|
MajorSubsystemVersion uint16
|
||||||
|
MinorSubsystemVersion uint16
|
||||||
|
Win32VersionValue uint32
|
||||||
|
SizeOfImage uint32
|
||||||
|
SizeOfHeaders uint32
|
||||||
|
CheckSum uint32
|
||||||
|
Subsystem uint16
|
||||||
|
DllCharacteristics uint16
|
||||||
|
SizeOfStackReserve uintptr
|
||||||
|
SizeOfStackCommit uintptr
|
||||||
|
SizeOfHeapReserve uintptr
|
||||||
|
SizeOfHeapCommit uintptr
|
||||||
|
LoaderFlags uint32
|
||||||
|
NumberOfRvaAndSizes uint32
|
||||||
|
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
|
||||||
|
}
|
||||||
|
|
||||||
|
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000
|
||||||
70
tun/wintun/memmod/zsyscall_windows.go
Normal file
70
tun/wintun/memmod/zsyscall_windows.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
// Code generated by 'go generate'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package memmod
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
errERROR_EINVAL error = syscall.EINVAL
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return errERROR_EINVAL
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
|
||||||
|
procGetProcAddress = modkernel32.NewProc("GetProcAddress")
|
||||||
|
procIsBadReadPtr = modkernel32.NewProc("IsBadReadPtr")
|
||||||
|
procLoadLibraryA = modkernel32.NewProc("LoadLibraryA")
|
||||||
|
)
|
||||||
|
|
||||||
|
func getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procName)), 0)
|
||||||
|
addr = uintptr(r0)
|
||||||
|
if addr == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) {
|
||||||
|
r0, _, _ := syscall.Syscall(procIsBadReadPtr.Addr(), 2, uintptr(addr), uintptr(ucb), 0)
|
||||||
|
ret = r0 != 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadLibraryA(libFileName *byte) (module windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall(procLoadLibraryA.Addr(), 1, uintptr(unsafe.Pointer(libFileName)), 0, 0)
|
||||||
|
module = windows.Handle(r0)
|
||||||
|
if module == 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package netshell
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
modnetshell = windows.NewLazySystemDLL("netshell.dll")
|
|
||||||
procHrRenameConnection = modnetshell.NewProc("HrRenameConnection")
|
|
||||||
)
|
|
||||||
|
|
||||||
func HrRenameConnection(guid *windows.GUID, newName *uint16) (err error) {
|
|
||||||
err = procHrRenameConnection.Find()
|
|
||||||
if err != nil {
|
|
||||||
// Missing from servercore, so we can't presume it's always there.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ret, _, _ := syscall.Syscall(procHrRenameConnection.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
|
|
||||||
if ret != 0 {
|
|
||||||
err = syscall.Errno(ret)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package registry
|
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go
|
|
||||||
@@ -1,272 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package registry
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
"golang.org/x/sys/windows/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted.
|
|
||||||
REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001
|
|
||||||
|
|
||||||
// REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information.
|
|
||||||
REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002
|
|
||||||
|
|
||||||
// REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value.
|
|
||||||
REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004
|
|
||||||
|
|
||||||
// REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key.
|
|
||||||
REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008
|
|
||||||
|
|
||||||
// REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later.
|
|
||||||
REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
|
|
||||||
|
|
||||||
func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
|
|
||||||
runtime.LockOSThread()
|
|
||||||
defer runtime.UnlockOSThread()
|
|
||||||
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
pathSpl := strings.Split(path, "\\")
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
keyName := pathSpl[i]
|
|
||||||
isLast := i+1 == len(pathSpl)
|
|
||||||
|
|
||||||
event, err := windows.CreateEvent(nil, 0, 0, nil)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Error creating event: %v", err)
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(event)
|
|
||||||
|
|
||||||
var key registry.Key
|
|
||||||
for {
|
|
||||||
err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var accessFlags uint32
|
|
||||||
if isLast {
|
|
||||||
accessFlags = access
|
|
||||||
} else {
|
|
||||||
accessFlags = registry.NOTIFY
|
|
||||||
}
|
|
||||||
key, err = registry.OpenKey(k, keyName, accessFlags)
|
|
||||||
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
|
|
||||||
timeout := time.Until(deadline) / time.Millisecond
|
|
||||||
if timeout < 0 {
|
|
||||||
timeout = 0
|
|
||||||
}
|
|
||||||
s, err := windows.WaitForSingleObject(event, uint32(timeout))
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("Unable to wait on registry key: %v", err)
|
|
||||||
}
|
|
||||||
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
|
|
||||||
return 0, errors.New("Timeout waiting for registry key")
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return 0, fmt.Errorf("Error opening registry key %v: %v", path, err)
|
|
||||||
} else {
|
|
||||||
if isLast {
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
defer key.Close()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
k = key
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
|
|
||||||
key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
key.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// getValue is more or less the same as windows/registry's getValue.
|
|
||||||
//
|
|
||||||
func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
|
|
||||||
var name16 *uint16
|
|
||||||
name16, err = windows.UTF16PtrFromString(name)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n := uint32(len(buf))
|
|
||||||
for {
|
|
||||||
err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
|
|
||||||
if err == nil {
|
|
||||||
value = buf[:n]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err != windows.ERROR_MORE_DATA {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if n <= uint32(len(buf)) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buf = make([]byte, n)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// getValueRetry function reads any value from registry. It waits for
|
|
||||||
// the registry value to become available or returns error on timeout.
|
|
||||||
//
|
|
||||||
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
|
|
||||||
//
|
|
||||||
func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
|
|
||||||
runtime.LockOSThread()
|
|
||||||
defer runtime.UnlockOSThread()
|
|
||||||
|
|
||||||
event, err := windows.CreateEvent(nil, 0, 0, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Error creating event: %v", err)
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(event)
|
|
||||||
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
for {
|
|
||||||
err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf, valueType, err := getValue(key, name, buf)
|
|
||||||
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
|
|
||||||
timeout := time.Until(deadline) / time.Millisecond
|
|
||||||
if timeout < 0 {
|
|
||||||
timeout = 0
|
|
||||||
}
|
|
||||||
s, err := windows.WaitForSingleObject(event, uint32(timeout))
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
|
|
||||||
}
|
|
||||||
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
|
|
||||||
return nil, 0, errors.New("Timeout waiting for registry value")
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
|
|
||||||
} else {
|
|
||||||
return buf, valueType, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toString(buf []byte, valueType uint32, err error) (string, error) {
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var value string
|
|
||||||
switch valueType {
|
|
||||||
case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
|
|
||||||
if len(buf) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
|
|
||||||
|
|
||||||
default:
|
|
||||||
return "", registry.ErrUnexpectedType
|
|
||||||
}
|
|
||||||
|
|
||||||
if valueType != registry.EXPAND_SZ {
|
|
||||||
// Value does not require expansion.
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
valueExp, err := registry.ExpandString(value)
|
|
||||||
if err != nil {
|
|
||||||
// Expanding failed: return original sting value.
|
|
||||||
return value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return expanded value.
|
|
||||||
return valueExp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch valueType {
|
|
||||||
case registry.DWORD:
|
|
||||||
if len(buf) != 4 {
|
|
||||||
return 0, errors.New("DWORD value is not 4 bytes long")
|
|
||||||
}
|
|
||||||
var val uint32
|
|
||||||
copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
|
|
||||||
return uint64(val), nil
|
|
||||||
|
|
||||||
case registry.QWORD:
|
|
||||||
if len(buf) != 8 {
|
|
||||||
return 0, errors.New("QWORD value is not 8 bytes long")
|
|
||||||
}
|
|
||||||
var val uint64
|
|
||||||
copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
|
|
||||||
return val, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return 0, registry.ErrUnexpectedType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// GetStringValueWait function reads a string value from registry. It waits
|
|
||||||
// for the registry value to become available or returns error on timeout.
|
|
||||||
//
|
|
||||||
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
|
|
||||||
//
|
|
||||||
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
|
|
||||||
// Should expanding fail, original string value and nil error are returned.
|
|
||||||
//
|
|
||||||
// If the value type is REG_MULTI_SZ only the first string is returned.
|
|
||||||
//
|
|
||||||
func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
|
|
||||||
return toString(getValueRetry(key, name, make([]byte, 256), timeout))
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// GetStringValue function reads a string value from registry.
|
|
||||||
//
|
|
||||||
// Key must be opened with at least QUERY_VALUE access.
|
|
||||||
//
|
|
||||||
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
|
|
||||||
// Should expanding fail, original string value and nil error are returned.
|
|
||||||
//
|
|
||||||
// If the value type is REG_MULTI_SZ only the first string is returned.
|
|
||||||
//
|
|
||||||
func GetStringValue(key registry.Key, name string) (string, error) {
|
|
||||||
return toString(getValue(key, name, make([]byte, 256)))
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry.
|
|
||||||
// It waits for the registry value to become available or returns error on
|
|
||||||
// timeout.
|
|
||||||
//
|
|
||||||
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
|
|
||||||
//
|
|
||||||
func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
|
|
||||||
return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
|
|
||||||
}
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package registry
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const keyRoot = registry.CURRENT_USER
|
|
||||||
const pathRoot = "Software\\WireGuardRegistryTest"
|
|
||||||
const path = pathRoot + "\\foobar"
|
|
||||||
const pathFake = pathRoot + "\\raboof"
|
|
||||||
|
|
||||||
func Test_WaitForKey(t *testing.T) {
|
|
||||||
registry.DeleteKey(keyRoot, path)
|
|
||||||
registry.DeleteKey(keyRoot, pathRoot)
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error creating registry key: %v", err)
|
|
||||||
}
|
|
||||||
key.Close()
|
|
||||||
registry.DeleteKey(keyRoot, pathFake)
|
|
||||||
|
|
||||||
key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error creating registry key: %v", err)
|
|
||||||
}
|
|
||||||
key.Close()
|
|
||||||
}()
|
|
||||||
err := WaitForKey(keyRoot, path, time.Second*2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error waiting for registry key: %v", err)
|
|
||||||
}
|
|
||||||
registry.DeleteKey(keyRoot, path)
|
|
||||||
registry.DeleteKey(keyRoot, pathRoot)
|
|
||||||
|
|
||||||
err = WaitForKey(keyRoot, path, time.Second*1)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Registry key notification expected to timeout but it succeeded.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_GetValueWait(t *testing.T) {
|
|
||||||
registry.DeleteKey(keyRoot, path)
|
|
||||||
registry.DeleteKey(keyRoot, pathRoot)
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error creating registry key: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
key.SetStringValue("name1", "eulav")
|
|
||||||
key.SetExpandStringValue("name2", "value")
|
|
||||||
time.Sleep(time.Second * 1)
|
|
||||||
key.SetDWordValue("name3", ^uint32(123))
|
|
||||||
key.SetDWordValue("name4", 123)
|
|
||||||
key.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error waiting for registry key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
valueStr, err := GetStringValueWait(key, "name2", time.Second*2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error waiting for registry value: %v", err)
|
|
||||||
}
|
|
||||||
if valueStr != "value" {
|
|
||||||
t.Errorf("Wrong value read: %v", valueStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = GetStringValueWait(key, "nonexisting", time.Second*1)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Registry value notification expected to timeout but it succeeded.")
|
|
||||||
}
|
|
||||||
|
|
||||||
valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error waiting for registry value: %v", err)
|
|
||||||
}
|
|
||||||
if valueInt != 123 {
|
|
||||||
t.Errorf("Wrong value read: %v", valueInt)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = GetIntegerValueWait(key, "nonexisting", time.Second*1)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Registry value notification expected to timeout but it succeeded.")
|
|
||||||
}
|
|
||||||
|
|
||||||
key.Close()
|
|
||||||
registry.DeleteKey(keyRoot, path)
|
|
||||||
registry.DeleteKey(keyRoot, pathRoot)
|
|
||||||
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package registry
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return nil
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
|
||||||
|
|
||||||
procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
|
|
||||||
)
|
|
||||||
|
|
||||||
func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) {
|
|
||||||
var _p0 uint32
|
|
||||||
if watchSubtree {
|
|
||||||
_p0 = 1
|
|
||||||
} else {
|
|
||||||
_p0 = 0
|
|
||||||
}
|
|
||||||
var _p1 uint32
|
|
||||||
if asynchronous {
|
|
||||||
_p1 = 1
|
|
||||||
} else {
|
|
||||||
_p1 = 0
|
|
||||||
}
|
|
||||||
r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
regerrno = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user